DEV: Use structured responses for summaries (#1252)

* DEV: Use structured responses for summaries

* Fix system specs

* Make response_format a first class citizen and update endpoints to support it

* Response format can be specified in the persona

* lint

* switch to jsonb and make column nullable

* Reify structured output chunks. Move JSON parsing to the depths of Completion

* Switch to JsonStreamingTracker for partial JSON parsing
Roman Rizzi 2025-05-06 10:09:39 -03:00 committed by GitHub
parent c6a307b473
commit c0a2d4c935
42 changed files with 822 additions and 68 deletions
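At a high level, this makes structured output a first-class argument to the completion API. A minimal caller sketch (schema values borrowed from the endpoint specs further down; when a format is set, streamed partials arrive as StructuredOutput objects rather than Strings):

schema = {
  type: "json_schema",
  json_schema: {
    name: "reply",
    schema: {
      type: "object",
      properties: { summary: { type: "string" } },
      required: ["summary"],
      additionalProperties: false,
    },
    strict: true,
  },
}

llm.generate(prompt, user: Discourse.system_user, response_format: schema) do |partial, cancel|
  # partial is a DiscourseAi::Completions::StructuredOutput; each read returns
  # only the values buffered since the previous read.
  print partial.read_latest_buffered_chunk[:summary]
end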

View File

@ -221,6 +221,10 @@ module DiscourseAi
permitted[:tools] = permit_tools(tools)
end
if response_format = params.dig(:ai_persona, :response_format)
permitted[:response_format] = permit_response_format(response_format)
end
permitted
end
@ -235,6 +239,18 @@ module DiscourseAi
[tool, options, !!force_tool]
end
end
def permit_response_format(response_format)
return [] if !response_format.is_a?(Array)
response_format.map do |element|
if element && element.is_a?(ActionController::Parameters)
element.permit!
else
false
end
end
end
end
end
end
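For illustration, an assumed request shape this permits (each element must be an ActionController::Parameters hash; anything else is mapped to false):

# Hypothetical admin update payload; the persona name is illustrative.
params = {
  ai_persona: {
    name: "Summarizer",
    response_format: [{ key: "summary", type: "string" }],
  },
}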

View File

@ -325,9 +325,11 @@ class AiPersona < ActiveRecord::Base
end
def system_persona_unchangeable
error_msg = I18n.t("discourse_ai.ai_bot.personas.cannot_edit_system_persona")
if top_p_changed? || temperature_changed? || system_prompt_changed? || name_changed? ||
description_changed?
errors.add(:base, I18n.t("discourse_ai.ai_bot.personas.cannot_edit_system_persona"))
errors.add(:base, error_msg)
elsif tools_changed?
old_tools = tools_change[0]
new_tools = tools_change[1]
@ -335,9 +337,12 @@ class AiPersona < ActiveRecord::Base
old_tool_names = old_tools.map { |t| t.is_a?(Array) ? t[0] : t }.to_set
new_tool_names = new_tools.map { |t| t.is_a?(Array) ? t[0] : t }.to_set
if old_tool_names != new_tool_names
errors.add(:base, I18n.t("discourse_ai.ai_bot.personas.cannot_edit_system_persona"))
end
errors.add(:base, error_msg) if old_tool_names != new_tool_names
elsif response_format_changed?
old_format = response_format_change[0].map { |f| f["key"] }.to_set
new_format = response_format_change[1].map { |f| f["key"] }.to_set
errors.add(:base, error_msg) if old_format != new_format
end
end
@ -395,6 +400,7 @@ end
# rag_llm_model_id :bigint
# default_llm_id :bigint
# question_consolidator_llm_id :bigint
# response_format :jsonb
#
# Indexes
#

View File

@ -30,7 +30,8 @@ class LocalizedAiPersonaSerializer < ApplicationSerializer
:allow_chat_direct_messages,
:allow_topic_mentions,
:allow_personal_messages,
:force_default_llm
:force_default_llm,
:response_format
has_one :user, serializer: BasicUserSerializer, embed: :object
has_many :rag_uploads, serializer: UploadSerializer, embed: :object

View File

@ -33,6 +33,7 @@ const CREATE_ATTRIBUTES = [
"allow_topic_mentions",
"allow_chat_channel_mentions",
"allow_chat_direct_messages",
"response_format",
];
const SYSTEM_ATTRIBUTES = [
@ -60,6 +61,7 @@ const SYSTEM_ATTRIBUTES = [
"allow_topic_mentions",
"allow_chat_channel_mentions",
"allow_chat_direct_messages",
"response_format",
];
export default class AiPersona extends RestModel {
@ -151,6 +153,7 @@ export default class AiPersona extends RestModel {
const attrs = this.getProperties(CREATE_ATTRIBUTES);
this.populateTools(attrs);
attrs.forced_tool_count = this.forced_tool_count || -1;
attrs.response_format = attrs.response_format || [];
return attrs;
}

View File

@ -15,6 +15,7 @@ import Group from "discourse/models/group";
import { i18n } from "discourse-i18n";
import AdminUser from "admin/models/admin-user";
import GroupChooser from "select-kit/components/group-chooser";
import AiPersonaResponseFormatEditor from "../components/modal/ai-persona-response-format-editor";
import AiLlmSelector from "./ai-llm-selector";
import AiPersonaToolOptions from "./ai-persona-tool-options";
import AiToolSelector from "./ai-tool-selector";
@ -325,6 +326,8 @@ export default class PersonaEditor extends Component {
<field.Textarea />
</form.Field>
<AiPersonaResponseFormatEditor @form={{form}} @data={{data}} />
<form.Field
@name="default_llm_id"
@title={{i18n "discourse_ai.ai_persona.default_llm"}}

View File

@ -0,0 +1,99 @@
import Component from "@glimmer/component";
import { tracked } from "@glimmer/tracking";
import { fn, hash } from "@ember/helper";
import { action } from "@ember/object";
import { gt } from "truth-helpers";
import ModalJsonSchemaEditor from "discourse/components/modal/json-schema-editor";
import { prettyJSON } from "discourse/lib/formatter";
import { i18n } from "discourse-i18n";
export default class AiPersonaResponseFormatEditor extends Component {
@tracked showJsonEditorModal = false;
jsonSchema = {
type: "array",
uniqueItems: true,
title: i18n("discourse_ai.ai_persona.response_format.modal.root_title"),
items: {
type: "object",
title: i18n("discourse_ai.ai_persona.response_format.modal.key_title"),
properties: {
key: {
type: "string",
},
type: {
type: "string",
enum: ["string", "integer", "boolean"],
},
},
},
};
get editorTitle() {
return i18n("discourse_ai.ai_persona.response_format.title");
}
get responseFormatAsJSON() {
return JSON.stringify(this.args.data.response_format);
}
get displayJSON() {
const toDisplay = {};
this.args.data.response_format.forEach((keyDesc) => {
toDisplay[keyDesc.key] = keyDesc.type;
});
return prettyJSON(toDisplay);
}
@action
openModal() {
this.showJsonEditorModal = true;
}
@action
closeModal() {
this.showJsonEditorModal = false;
}
@action
updateResponseFormat(form, value) {
form.set("response_format", JSON.parse(value));
}
<template>
<@form.Container @title={{this.editorTitle}} @format="large">
<div class="ai-persona-editor__response-format">
{{#if (gt @data.response_format.length 0)}}
<pre class="ai-persona-editor__response-format-pre">
<code
>{{this.displayJSON}}</code>
</pre>
{{else}}
<div class="ai-persona-editor__response-format-none">
{{i18n "discourse_ai.ai_persona.response_format.no_format"}}
</div>
{{/if}}
<@form.Button
@action={{this.openModal}}
@label="discourse_ai.ai_persona.response_format.open_modal"
@disabled={{@data.system}}
/>
</div>
</@form.Container>
{{#if this.showJsonEditorModal}}
<ModalJsonSchemaEditor
@model={{hash
value=this.responseFormatAsJSON
updateValue=(fn this.updateResponseFormat @form)
settingName=this.editorTitle
jsonSchema=this.jsonSchema
}}
@closeModal={{this.closeModal}}
/>
{{/if}}
</template>
}

View File

@ -52,6 +52,21 @@
margin-bottom: 10px;
font-size: var(--font-down-1);
}
&__response-format {
width: 100%;
display: block;
}
&__response-format-pre {
margin-bottom: 0;
white-space: pre-line;
}
&__response-format-none {
margin-bottom: 1em;
margin-top: 0.5em;
}
}
.rag-options {

View File

@ -323,6 +323,13 @@ en:
rag_conversation_chunks: "Search conversation chunks"
rag_conversation_chunks_help: "The number of chunks to use for the RAG model searches. Increase to give the AI more context to work with."
persona_description: "Personas are a powerful feature that allows you to customize the behavior of the AI engine in your Discourse forum. They act as a 'system message' that guides the AI's responses and interactions, helping to create a more personalized and engaging user experience."
response_format:
title: "JSON response format"
no_format: "No JSON format specified"
open_modal: "Edit"
modal:
root_title: "Response structure"
key_title: "Key"
list:
enabled: "AI Bot?"

View File

@ -72,6 +72,8 @@ DiscourseAi::Personas::Persona.system_personas.each do |persona_class, id|
persona.tools = tools.map { |name, value| [name, value] }
persona.response_format = instance.response_format
persona.system_prompt = instance.system_prompt
persona.top_p = instance.top_p
persona.temperature = instance.temperature

View File

@ -0,0 +1,6 @@
# frozen_string_literal: true
class AddResponseFormatJsonToPersonass < ActiveRecord::Migration[7.2]
def change
add_column :ai_personas, :response_format, :jsonb
end
end

View File

@ -10,7 +10,7 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
@raw_json = +""
@tool_call = DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: {})
@streaming_parser =
DiscourseAi::Completions::ToolCallProgressTracker.new(self) if partial_tool_calls
DiscourseAi::Completions::JsonStreamingTracker.new(self) if partial_tool_calls
end
def append(json)

View File

@ -42,6 +42,7 @@ module DiscourseAi
result = { system: system, messages: messages }
result[:inferenceConfig] = inference_config if inference_config.present?
result[:toolConfig] = tool_config if tool_config.present?
result[:response_format] = { type: "json_object" } if options[:response_format].present?
result
end

View File

@ -88,10 +88,16 @@ module DiscourseAi
def prepare_payload(prompt, model_params, dialect)
@native_tool_support = dialect.native_tool_support?
payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
payload =
default_options(dialect).merge(model_params.except(:response_format)).merge(
messages: prompt.messages,
)
payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
payload[:stream] = true if @streaming_mode
prefilled_message = +""
if prompt.has_tools?
payload[:tools] = prompt.tools
if dialect.tool_choice.present?
@ -100,16 +106,24 @@ module DiscourseAi
# prefill prompt to nudge the LLM to generate a useful response.
# without this, the LLM (even 3.7) can get confused and emit text preambles for tool calls.
payload[:messages] << {
role: "assistant",
content: dialect.no_more_tool_calls_text,
}
prefilled_message << dialect.no_more_tool_calls_text
else
payload[:tool_choice] = { type: "tool", name: prompt.tool_choice }
end
end
end
# Prefill prompt to force JSON output.
if model_params[:response_format].present?
prefilled_message << " " if !prefilled_message.empty?
prefilled_message << "{"
@forced_json_through_prefill = true
end
if !prefilled_message.empty?
payload[:messages] << { role: "assistant", content: prefilled_message }
end
payload
end
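The observable effect on the request (asserted by the Anthropic endpoint spec near the end of this diff) is a trailing assistant prefill that opens the JSON object:

# With response_format present, the payload ends with a prefill message;
# the endpoint later feeds "{" back into the parser to keep the JSON valid
# (see @forced_json_through_prefill in endpoints/base.rb).
payload[:messages].last # => { role: "assistant", content: "{" }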

View File

@ -116,9 +116,14 @@ module DiscourseAi
payload = nil
if dialect.is_a?(DiscourseAi::Completions::Dialects::Claude)
payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
payload =
default_options(dialect).merge(model_params.except(:response_format)).merge(
messages: prompt.messages,
)
payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
prefilled_message = +""
if prompt.has_tools?
payload[:tools] = prompt.tools
if dialect.tool_choice.present?
@ -128,15 +133,23 @@ module DiscourseAi
# payload[:tool_choice] = { type: "none" }
# prefill prompt to nudge the LLM to generate a useful response instead of trying to call a tool
payload[:messages] << {
role: "assistant",
content: dialect.no_more_tool_calls_text,
}
prefilled_message << dialect.no_more_tool_calls_text
else
payload[:tool_choice] = { type: "tool", name: prompt.tool_choice }
end
end
end
# Prefill prompt to force JSON output.
if model_params[:response_format].present?
prefilled_message << " " if !prefilled_message.empty?
prefilled_message << "{"
@forced_json_through_prefill = true
end
if !prefilled_message.empty?
payload[:messages] << { role: "assistant", content: prefilled_message }
end
elsif dialect.is_a?(DiscourseAi::Completions::Dialects::Nova)
payload = prompt.to_payload(default_options(dialect).merge(model_params))
else

View File

@ -73,6 +73,7 @@ module DiscourseAi
LlmQuota.check_quotas!(@llm_model, user)
start_time = Time.now
@forced_json_through_prefill = false
@partial_tool_calls = partial_tool_calls
@output_thinking = output_thinking
@ -106,6 +107,18 @@ module DiscourseAi
prompt = dialect.translate
structured_output = nil
if model_params[:response_format].present?
schema_properties =
model_params[:response_format].dig(:json_schema, :schema, :properties)
if schema_properties.present?
structured_output =
DiscourseAi::Completions::StructuredOutput.new(schema_properties)
end
end
FinalDestination::HTTP.start(
model_uri.host,
model_uri.port,
@ -123,6 +136,10 @@ module DiscourseAi
request = prepare_request(request_body)
# Some providers rely on prefill to return structured outputs, so the start
# of the JSON won't be included in the response. Supply it to keep JSON valid.
structured_output << +"{" if structured_output && @forced_json_through_prefill
http.request(request) do |response|
if response.code.to_i != 200
Rails.logger.error(
@ -140,10 +157,17 @@ module DiscourseAi
xml_stripper =
DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present?
if @streaming_mode && xml_stripper
if @streaming_mode
blk =
lambda do |partial, cancel|
partial = xml_stripper << partial if partial.is_a?(String)
if partial.is_a?(String)
partial = xml_stripper << partial if xml_stripper
if structured_output.present?
structured_output << partial
partial = structured_output
end
end
orig_blk.call(partial, cancel) if partial
end
end
@ -167,6 +191,7 @@ module DiscourseAi
xml_stripper: xml_stripper,
partials_raw: partials_raw,
response_raw: response_raw,
structured_output: structured_output,
)
return response_data
end
@ -373,7 +398,8 @@ module DiscourseAi
xml_tool_processor:,
xml_stripper:,
partials_raw:,
response_raw:
response_raw:,
structured_output:
)
response_raw << response.read_body
response_data = decode(response_raw)
@ -403,6 +429,26 @@ module DiscourseAi
response_data.reject!(&:blank?)
if structured_output.present?
has_string_response = false
response_data =
response_data.reduce([]) do |memo, data|
if data.is_a?(String)
structured_output << data
has_string_response = true
next(memo)
else
memo << data
end
memo
end
# We only include the structured output if there was actually a structured response
response_data << structured_output if has_string_response
end
# keep things backwards compatible: unwrap single-element responses
response_data = response_data.first if response_data.length == 1

View File

@ -40,6 +40,8 @@ module DiscourseAi
"The number of completions you requested exceed the number of canned responses"
end
response = as_structured_output(response) if model_params[:response_format].present?
raise response if response.is_a?(StandardError)
@completions += 1
@ -54,6 +56,8 @@ module DiscourseAi
yield(response, cancel_fn)
elsif is_thinking?(response)
yield(response, cancel_fn)
elsif is_structured_output?(response)
yield(response, cancel_fn)
else
response.each_char do |char|
break if cancelled
@ -80,6 +84,20 @@ module DiscourseAi
def is_tool?(response)
response.is_a?(DiscourseAi::Completions::ToolCall)
end
def is_structured_output?(response)
response.is_a?(DiscourseAi::Completions::StructuredOutput)
end
def as_structured_output(response)
schema_properties = model_params[:response_format].dig(:json_schema, :schema, :properties)
return response if schema_properties.blank?
output = DiscourseAi::Completions::StructuredOutput.new(schema_properties)
output << { schema_properties.keys.first => response }.to_json
output
end
end
end
end
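In effect, a canned String response gets wrapped so tests exercise the same interface as real structured responses. A quick sketch of what as_structured_output yields for a single-key schema:

output = DiscourseAi::Completions::StructuredOutput.new(summary: { type: "string" })
output << { summary: "This is a summary" }.to_json
output.read_latest_buffered_chunk # => { summary: "This is a summary" }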

View File

@ -84,7 +84,16 @@ module DiscourseAi
payload[:tool_config] = { function_calling_config: function_calling_config }
end
payload[:generationConfig].merge!(model_params) if model_params.present?
if model_params.present?
payload[:generationConfig].merge!(model_params.except(:response_format))
if model_params[:response_format].present?
# https://ai.google.dev/api/generate-content#generationconfig
payload[:generationConfig][:responseSchema] = model_params[:response_format]
payload[:generationConfig][:responseMimeType] = "application/json"
end
end
payload
end
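Per the linked Gemini docs, the schema rides inside generationConfig rather than in a top-level field. A sketch of the resulting request fragment (other keys illustrative):

payload[:generationConfig]
# => {
#      temperature: 0.7,                      # other model params, illustrative
#      responseSchema: schema,                # the response_format hash, passed through
#      responseMimeType: "application/json",
#    }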

View File

@ -34,7 +34,12 @@ module DiscourseAi
end
def prepare_payload(prompt, model_params, dialect)
payload = default_options.merge(model_params).merge(messages: prompt)
payload =
default_options.merge(model_params.except(:response_format)).merge(messages: prompt)
if model_params[:response_format].present?
payload[:response_format] = { type: "json_object" }
end
payload[:stream] = true if @streaming_mode

View File

@ -2,11 +2,11 @@
module DiscourseAi
module Completions
class ToolCallProgressTracker
attr_reader :current_key, :current_value, :tool_call
class JsonStreamingTracker
attr_reader :current_key, :current_value, :stream_consumer
def initialize(tool_call)
@tool_call = tool_call
def initialize(stream_consumer)
@stream_consumer = stream_consumer
@current_key = nil
@current_value = nil
@parser = DiscourseAi::Completions::JsonStreamingParser.new
@ -18,7 +18,7 @@ module DiscourseAi
@parser.value do |v|
if @current_key
tool_call.notify_progress(@current_key, v)
stream_consumer.notify_progress(@current_key, v)
@current_key = nil
end
end
@ -39,7 +39,7 @@ module DiscourseAi
if @parser.state == :start_string && @current_key
# this is worth notifying
tool_call.notify_progress(@current_key, @parser.buf)
stream_consumer.notify_progress(@current_key, @parser.buf)
end
@current_key = nil if @parser.state == :end_value
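The rename reflects that the tracker is no longer tool-call specific: any object responding to notify_progress(key, value) can consume partial JSON. A minimal, hypothetical consumer (EchoConsumer is illustrative only; StructuredOutput below is the real one):

class EchoConsumer
  def notify_progress(key, value)
    puts "#{key} => #{value.inspect}"
  end
end

tracker = DiscourseAi::Completions::JsonStreamingTracker.new(EchoConsumer.new)
tracker << "{\"summary\": \"Hel"  # may print: summary => "Hel"
tracker << "lo\"}"                # may print: summary => "lo"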

View File

@ -304,6 +304,8 @@ module DiscourseAi
# @param feature_context { Hash - Optional } - The feature context to use for the completion.
# @param partial_tool_calls { Boolean - Optional } - If true, the completion will return partial tool calls.
# @param output_thinking { Boolean - Optional } - If true, the completion will return the thinking output for thinking models.
# @param response_format { Hash - Optional } - JSON schema passed to the API as the desired structured output.
# @param [Experimental] extra_model_params { Hash - Optional } - Other params that are not available across models.
#
# @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response alongside a cancel function.
#
@ -321,6 +323,7 @@ module DiscourseAi
feature_context: nil,
partial_tool_calls: false,
output_thinking: false,
response_format: nil,
extra_model_params: nil,
&partial_read_blk
)
@ -336,6 +339,7 @@ module DiscourseAi
feature_context: feature_context,
partial_tool_calls: partial_tool_calls,
output_thinking: output_thinking,
response_format: response_format,
extra_model_params: extra_model_params,
},
)
@ -344,6 +348,7 @@ module DiscourseAi
model_params[:temperature] = temperature if temperature
model_params[:top_p] = top_p if top_p
model_params[:response_format] = response_format if response_format
model_params.merge!(extra_model_params) if extra_model_params
if prompt.is_a?(String)

View File

@ -10,7 +10,7 @@ class DiscourseAi::Completions::NovaMessageProcessor
@raw_json = +""
@tool_call = DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: {})
@streaming_parser =
DiscourseAi::Completions::ToolCallProgressTracker.new(self) if partial_tool_calls
DiscourseAi::Completions::JsonStreamingTracker.new(self) if partial_tool_calls
end
def append(json)

View File

@ -59,7 +59,7 @@ module DiscourseAi::Completions
if id.present? && name.present?
@tool_arguments = +""
@tool = ToolCall.new(id: id, name: name)
@streaming_parser = ToolCallProgressTracker.new(self) if @partial_tool_calls
@streaming_parser = JsonStreamingTracker.new(self) if @partial_tool_calls
end
@tool_arguments << arguments.to_s

View File

@ -0,0 +1,52 @@
# frozen_string_literal: true
module DiscourseAi
module Completions
class StructuredOutput
def initialize(json_schema_properties)
@property_names = json_schema_properties.keys.map(&:to_sym)
@property_cursors =
json_schema_properties.reduce({}) do |m, (k, prop)|
m[k.to_sym] = 0 if prop[:type] == "string"
m
end
@tracked = {}
@partial_json_tracker = JsonStreamingTracker.new(self)
end
attr_reader :last_chunk_buffer
def <<(raw)
@partial_json_tracker << raw
end
def read_latest_buffered_chunk
@property_names.reduce({}) do |memo, pn|
if @tracked[pn].present?
# This means this property is a string and we want to return unread chunks.
if @property_cursors[pn].present?
unread = @tracked[pn][@property_cursors[pn]..]
memo[pn] = unread if unread.present?
@property_cursors[pn] = @tracked[pn].length
else
# Ints and bools are always returned as is.
memo[pn] = @tracked[pn]
end
end
memo
end
end
def notify_progress(key, value)
key_sym = key.to_sym
return if !@property_names.include?(key_sym)
@tracked[key_sym] = value
end
end
end
end
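A usage sketch mirroring the new spec at the bottom of this diff: raw JSON chunks go in, and each read returns only what was buffered since the previous read (strings stream incrementally; other types appear once complete):

so =
  DiscourseAi::Completions::StructuredOutput.new(
    summary: { type: "string" },
    done: { type: "boolean" },
  )

so << "{\"summary\": \"Line 1\\n"
so.read_latest_buffered_chunk # => { summary: "Line 1\n" }

so << "Line 2\", \"done\": true}"
so.read_latest_buffered_chunk # => { summary: "Line 2", done: true }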

View File

@ -64,10 +64,13 @@ module DiscourseAi
user = context.user
llm_kwargs = { user: user }
llm_kwargs = llm_args.dup
llm_kwargs[:user] = user
llm_kwargs[:temperature] = persona.temperature if persona.temperature
llm_kwargs[:top_p] = persona.top_p if persona.top_p
llm_kwargs[:max_tokens] = llm_args[:max_tokens] if llm_args[:max_tokens].present?
llm_kwargs[:response_format] = build_json_schema(
persona.response_format,
) if persona.response_format.present?
needs_newlines = false
tools_ran = 0
@ -148,6 +151,8 @@ module DiscourseAi
raw_context << partial
current_thinking << partial
end
elsif partial.is_a?(DiscourseAi::Completions::StructuredOutput)
update_blk.call(partial, cancel, nil, :structured_output)
else
update_blk.call(partial, cancel)
end
@ -176,6 +181,10 @@ module DiscourseAi
embed_thinking(raw_context)
end
def returns_json?
persona.response_format.present?
end
private
def embed_thinking(raw_context)
@ -301,6 +310,30 @@ module DiscourseAi
placeholder
end
def build_json_schema(response_format)
properties =
response_format
.to_a
.reduce({}) do |memo, format|
memo[format[:key].to_sym] = { type: format[:type] }
memo
end
{
type: "json_schema",
json_schema: {
name: "reply",
schema: {
type: "object",
properties: properties,
required: properties.keys.map(&:to_s),
additionalProperties: false,
},
strict: true,
},
}
end
end
end
end
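As a worked example, the summarizer personas below declare [{ key: "summary", type: "string" }], which build_json_schema expands into the provider-agnostic schema:

build_json_schema([{ key: "summary", type: "string" }])
# => {
#      type: "json_schema",
#      json_schema: {
#        name: "reply",
#        schema: {
#          type: "object",
#          properties: { summary: { type: "string" } },
#          required: ["summary"],
#          additionalProperties: false,
#        },
#        strict: true,
#      },
#    }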

View File

@ -49,6 +49,8 @@ module DiscourseAi
@site_title = site_title
@site_description = site_description
@time = time
@resource_url = resource_url
@feature_name = feature_name

View File

@ -160,6 +160,10 @@ module DiscourseAi
{}
end
def response_format
nil
end
def available_tools
self
.class

View File

@ -18,9 +18,19 @@ module DiscourseAi
- Limit the summary to a maximum of 40 words.
- Do *NOT* repeat the discussion title in the summary.
Return the summary inside <ai></ai> tags.
Format your response as a JSON object with a single key named "summary", which has the summary as the value.
Your output should be in the following format:
<output>
{"summary": "xx"}
</output>
Where "xx" is replaced by the summary.
PROMPT
end
def response_format
[{ key: "summary", type: "string" }]
end
end
end
end

View File

@ -18,8 +18,20 @@ module DiscourseAi
- Example: link to the 6th post by jane: [agreed with]({resource_url}/6)
- Example: link to the 13th post by joe: [joe]({resource_url}/13)
- When formatting usernames either use @USERNAME OR [USERNAME]({resource_url}/POST_NUMBER)
Format your response as a JSON object with a single key named "summary", which has the summary as the value.
Your output should be in the following format:
<output>
{"summary": "xx"}
</output>
Where "xx" is replaced by the summary.
PROMPT
end
def response_format
[{ key: "summary", type: "string" }]
end
end
end
end

View File

@ -29,18 +29,10 @@ module DiscourseAi
summary = fold(truncated_content, user, &on_partial_blk)
clean_summary = Nokogiri::HTML5.fragment(summary).css("ai")&.first&.text || summary
if persist_summaries
AiSummary.store!(
strategy,
llm_model,
clean_summary,
truncated_content,
human: user&.human?,
)
AiSummary.store!(strategy, llm_model, summary, truncated_content, human: user&.human?)
else
AiSummary.new(summarized_text: clean_summary)
AiSummary.new(summarized_text: summary)
end
end
@ -118,9 +110,20 @@ module DiscourseAi
)
summary = +""
buffer_blk =
Proc.new do |partial, cancel, placeholder, type|
if type.blank?
Proc.new do |partial, cancel, _, type|
if type == :structured_output
json_summary_schema_key = bot.persona.response_format&.first.to_h
partial_summary =
partial.read_latest_buffered_chunk[json_summary_schema_key[:key].to_sym]
if partial_summary.present?
summary << partial_summary
on_partial_blk.call(partial_summary, cancel) if on_partial_blk
end
elsif type.blank?
# Assume response is a regular completion.
summary << partial
on_partial_blk.call(partial, cancel) if on_partial_blk
end
@ -154,12 +157,6 @@ module DiscourseAi
item
end
def text_only_update(&on_partial_blk)
Proc.new do |partial, cancel, placeholder, type|
on_partial_blk.call(partial, cancel) if type.blank?
end
end
end
end
end
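Concretely, during streaming the buffer block receives the StructuredOutput object and appends only the unread slice of the summary string on each call (values illustrative):

# Each read drains the buffered-but-unread part of the "summary" property:
partial.read_latest_buffered_chunk # => { summary: "This is " }
partial.read_latest_buffered_chunk # => {}  (no new chunks arrived yet)
partial.read_latest_buffered_chunk # => { summary: "a summary" }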

View File

@ -21,6 +21,7 @@ RSpec.describe Jobs::FastTrackTopicGist do
created_at: 10.minutes.ago,
)
end
let(:updated_gist) { "They updated me :(" }
context "when it's up to date" do

View File

@ -51,7 +51,9 @@ RSpec.describe Jobs::StreamTopicAiSummary do
end
it "publishes updates with a partial summary" do
with_responses(["dummy"]) do
summary = "dummy"
with_responses([summary]) do
messages =
MessageBus.track_publish("/discourse-ai/summaries/topic/#{topic.id}") do
job.execute(topic_id: topic.id, user_id: user.id)
@ -59,12 +61,16 @@ RSpec.describe Jobs::StreamTopicAiSummary do
partial_summary_update = messages.first.data
expect(partial_summary_update[:done]).to eq(false)
expect(partial_summary_update.dig(:ai_topic_summary, :summarized_text)).to eq("dummy")
expect(partial_summary_update.dig(:ai_topic_summary, :summarized_text).chomp("\"}")).to eq(
summary,
)
end
end
it "publishes a final update to signal we're done and provide metadata" do
with_responses(["dummy"]) do
summary = "dummy"
with_responses([summary]) do
messages =
MessageBus.track_publish("/discourse-ai/summaries/topic/#{topic.id}") do
job.execute(topic_id: topic.id, user_id: user.id)
@ -73,6 +79,7 @@ RSpec.describe Jobs::StreamTopicAiSummary do
final_update = messages.last.data
expect(final_update[:done]).to eq(true)
expect(final_update.dig(:ai_topic_summary, :summarized_text)).to eq(summary)
expect(final_update.dig(:ai_topic_summary, :algorithm)).to eq("fake")
expect(final_update.dig(:ai_topic_summary, :outdated)).to eq(false)
expect(final_update.dig(:ai_topic_summary, :can_regenerate)).to eq(true)

View File

@ -590,7 +590,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
data: {"type": "content_block_stop", "index": 0}
event: content_block_start
data: {"type":"content_block_start","index":0,"content_block":{"type":"redacted_thinking","data":"AAA=="} }
data: {"type":"content_block_start","index":0,"content_block":{"type":"redacted_thinking","data":"AAA=="} }
event: ping
data: {"type": "ping"}
@ -769,4 +769,92 @@ data: {"type":"content_block_start","index":0,"content_block":{"type":"redacted_
expect(result).to eq("I won't use any tools. Here's a direct response instead.")
end
end
describe "structured output via prefilling" do
it "forces the response to be a JSON and using the given JSON schema" do
schema = {
type: "json_schema",
json_schema: {
name: "reply",
schema: {
type: "object",
properties: {
key: {
type: "string",
},
},
required: ["key"],
additionalProperties: false,
},
strict: true,
},
}
body = (<<~STRING).strip
event: message_start
data: {"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-opus-20240229", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 25, "output_tokens": 1}}}
event: content_block_start
data: {"type": "content_block_start", "index":0, "content_block": {"type": "text", "text": ""}}
event: content_block_delta
data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "\\""}}
event: content_block_delta
data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "key"}}
event: content_block_delta
data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "\\":\\""}}
event: content_block_delta
data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello!"}}
event: content_block_delta
data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "\\"}"}}
event: content_block_stop
data: {"type": "content_block_stop", "index": 0}
event: message_delta
data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence":null, "usage":{"output_tokens": 15}}}
event: message_stop
data: {"type": "message_stop"}
STRING
parsed_body = nil
stub_request(:post, url).with(
body:
proc do |req_body|
parsed_body = JSON.parse(req_body, symbolize_names: true)
true
end,
headers: {
"Content-Type" => "application/json",
"X-Api-Key" => "123",
"Anthropic-Version" => "2023-06-01",
},
).to_return(status: 200, body: body)
structured_output = nil
llm.generate(
prompt,
user: Discourse.system_user,
feature_name: "testing",
response_format: schema,
) { |partial, cancel| structured_output = partial }
expect(structured_output.read_latest_buffered_chunk).to eq({ key: "Hello!" })
expected_body = {
model: "claude-3-opus-20240229",
max_tokens: 4096,
messages: [{ role: "user", content: "user1: hello" }, { role: "assistant", content: "{" }],
system: "You are hello bot",
stream: true,
}
expect(parsed_body).to eq(expected_body)
end
end
end

View File

@ -546,4 +546,69 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
)
end
end
describe "structured output via prefilling" do
it "forces the response to be a JSON and using the given JSON schema" do
schema = {
type: "json_schema",
json_schema: {
name: "reply",
schema: {
type: "object",
properties: {
key: {
type: "string",
},
},
required: ["key"],
additionalProperties: false,
},
strict: true,
},
}
messages =
[
{ type: "message_start", message: { usage: { input_tokens: 9 } } },
{ type: "content_block_delta", delta: { text: "\"" } },
{ type: "content_block_delta", delta: { text: "key" } },
{ type: "content_block_delta", delta: { text: "\":\"" } },
{ type: "content_block_delta", delta: { text: "Hello!" } },
{ type: "content_block_delta", delta: { text: "\"}" } },
{ type: "message_delta", delta: { usage: { output_tokens: 25 } } },
].map { |message| encode_message(message) }
proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
request = nil
bedrock_mock.with_chunk_array_support do
stub_request(
:post,
"https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke-with-response-stream",
)
.with do |inner_request|
request = inner_request
true
end
.to_return(status: 200, body: messages)
structured_output = nil
proxy.generate("hello world", response_format: schema, user: user) do |partial|
structured_output = partial
end
expected = {
"max_tokens" => 4096,
"anthropic_version" => "bedrock-2023-05-31",
"messages" => [
{ "role" => "user", "content" => "hello world" },
{ "role" => "assistant", "content" => "{" },
],
"system" => "You are a helpful bot",
}
expect(JSON.parse(request.body)).to eq(expected)
expect(structured_output.read_latest_buffered_chunk).to eq({ key: "Hello!" })
end
end
end
end

View File

@ -308,4 +308,64 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
expect(audit.request_tokens).to eq(14)
expect(audit.response_tokens).to eq(11)
end
it "is able to return structured outputs" do
schema = {
type: "json_schema",
json_schema: {
name: "reply",
schema: {
type: "object",
properties: {
key: {
type: "string",
},
},
required: ["key"],
additionalProperties: false,
},
strict: true,
},
}
body = <<~TEXT
{"is_finished":false,"event_type":"stream-start","generation_id":"eb889b0f-c27d-45ea-98cf-567bdb7fc8bf"}
{"is_finished":false,"event_type":"text-generation","text":"{\\""}
{"is_finished":false,"event_type":"text-generation","text":"key"}
{"is_finished":false,"event_type":"text-generation","text":"\\":\\""}
{"is_finished":false,"event_type":"text-generation","text":"Hello!"}
{"is_finished":false,"event_type":"text-generation","text":"\\"}"}|
{"is_finished":true,"event_type":"stream-end","response":{"response_id":"d235db17-8555-493b-8d91-e601f76de3f9","text":"{\\"key\\":\\"Hello!\\"}","generation_id":"eb889b0f-c27d-45ea-98cf-567bdb7fc8bf","chat_history":[{"role":"USER","message":"user1: hello"},{"role":"CHATBOT","message":"hi user"},{"role":"USER","message":"user1: thanks"},{"role":"CHATBOT","message":"You're welcome! Is there anything else I can help you with?"}],"token_count":{"prompt_tokens":29,"response_tokens":14,"total_tokens":43,"billed_tokens":28},"meta":{"api_version":{"version":"1"},"billed_units":{"input_tokens":14,"output_tokens":14}}},"finish_reason":"COMPLETE"}
TEXT
parsed_body = nil
structured_output = nil
EndpointMock.with_chunk_array_support do
stub_request(:post, "https://api.cohere.ai/v1/chat").with(
body:
proc do |req_body|
parsed_body = JSON.parse(req_body, symbolize_names: true)
true
end,
headers: {
"Content-Type" => "application/json",
"Authorization" => "Bearer ABC",
},
).to_return(status: 200, body: body.split("|"))
result =
llm.generate(prompt, response_format: schema, user: user) do |partial, cancel|
structured_output = partial
end
end
expect(parsed_body[:preamble]).to eq("You are hello bot")
expect(parsed_body[:chat_history]).to eq(
[{ role: "USER", message: "user1: hello" }, { role: "CHATBOT", message: "hi user" }],
)
expect(parsed_body[:message]).to eq("user1: thanks")
expect(structured_output.read_latest_buffered_chunk).to eq({ key: "Hello!" })
end
end

View File

@ -420,7 +420,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
body:
proc do |_req_body|
req_body = _req_body
true
_req_body
end,
).to_return(status: 200, body: response)
@ -433,4 +433,67 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
{ function_calling_config: { mode: "ANY", allowed_function_names: ["echo"] } },
)
end
describe "structured output via JSON Schema" do
it "forces the response to be a JSON" do
schema = {
type: "json_schema",
json_schema: {
name: "reply",
schema: {
type: "object",
properties: {
key: {
type: "string",
},
},
required: ["key"],
additionalProperties: false,
},
strict: true,
},
}
response = <<~TEXT.strip
data: {"candidates": [{"content": {"parts": [{"text": "{\\""}],"role": "model"}}],"usageMetadata": {"promptTokenCount": 399,"totalTokenCount": 399},"modelVersion": "gemini-1.5-pro-002"}
data: {"candidates": [{"content": {"parts": [{"text": "key"}],"role": "model"},"safetyRatings": [{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 399,"totalTokenCount": 399},"modelVersion": "gemini-1.5-pro-002"}
data: {"candidates": [{"content": {"parts": [{"text": "\\":\\""}],"role": "model"},"safetyRatings": [{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 399,"totalTokenCount": 399},"modelVersion": "gemini-1.5-pro-002"}
data: {"candidates": [{"content": {"parts": [{"text": "Hello!"}],"role": "model"},"finishReason": "STOP"}],"usageMetadata": {"promptTokenCount": 399,"candidatesTokenCount": 191,"totalTokenCount": 590},"modelVersion": "gemini-1.5-pro-002"}
data: {"candidates": [{"content": {"parts": [{"text": "\\"}"}],"role": "model"},"finishReason": "STOP"}],"usageMetadata": {"promptTokenCount": 399,"candidatesTokenCount": 191,"totalTokenCount": 590},"modelVersion": "gemini-1.5-pro-002"}
data: {"candidates": [{"finishReason": "MALFORMED_FUNCTION_CALL"}],"usageMetadata": {"promptTokenCount": 399,"candidatesTokenCount": 191,"totalTokenCount": 590},"modelVersion": "gemini-1.5-pro-002"}
TEXT
req_body = nil
llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
url = "#{model.url}:streamGenerateContent?alt=sse&key=123"
stub_request(:post, url).with(
body:
proc do |_req_body|
req_body = _req_body
true
end,
).to_return(status: 200, body: response)
structured_response = nil
llm.generate("Hello", response_format: schema, user: user) do |partial|
structured_response = partial
end
expect(structured_response.read_latest_buffered_chunk).to eq({ key: "Hello!" })
parsed = JSON.parse(req_body, symbolize_names: true)
# Verify that schema is passed following Gemini API specs.
expect(parsed.dig(:generationConfig, :responseSchema)).to eq(schema)
expect(parsed.dig(:generationConfig, :responseMimeType)).to eq("application/json")
end
end
end

View File

@ -0,0 +1,69 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::Completions::StructuredOutput do
subject(:structured_output) do
described_class.new(
{
message: {
type: "string",
},
bool: {
type: "boolean",
},
number: {
type: "integer",
},
status: {
type: "string",
},
},
)
end
describe "Parsing structured output on the fly" do
it "acts as a buffer for an streamed JSON" do
chunks = [
+"{\"message\": \"Line 1\\n",
+"Line 2\\n",
+"Line 3\", ",
+"\"bool\": true,",
+"\"number\": 4",
+"2,",
+"\"status\": \"o",
+"\\\"k\\\"\"}",
]
structured_output << chunks[0]
expect(structured_output.read_latest_buffered_chunk).to eq({ message: "Line 1\n" })
structured_output << chunks[1]
expect(structured_output.read_latest_buffered_chunk).to eq({ message: "Line 2\n" })
structured_output << chunks[2]
expect(structured_output.read_latest_buffered_chunk).to eq({ message: "Line 3" })
structured_output << chunks[3]
expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true })
# Waiting for number to be fully buffered.
structured_output << chunks[4]
expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true })
structured_output << chunks[5]
expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true, number: 42 })
structured_output << chunks[6]
expect(structured_output.read_latest_buffered_chunk).to eq(
{ bool: true, number: 42, status: "o" },
)
structured_output << chunks[7]
expect(structured_output.read_latest_buffered_chunk).to eq(
{ bool: true, number: 42, status: "\"k\"" },
)
# No partial string left to read.
expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true, number: 42 })
end
end
end

View File

@ -23,17 +23,17 @@ RSpec.describe DiscourseAi::Summarization::FoldContent do
llm_model.update!(max_prompt_tokens: model_tokens)
end
let(:single_summary) { "this is a summary" }
let(:summary) { "this is a summary" }
fab!(:user)
it "summarizes the content" do
result =
DiscourseAi::Completions::Llm.with_prepared_responses([single_summary]) do |spy|
DiscourseAi::Completions::Llm.with_prepared_responses([summary]) do |spy|
summarizer.summarize(user).tap { expect(spy.completions).to eq(1) }
end
expect(result.summarized_text).to eq(single_summary)
expect(result.summarized_text).to eq(summary)
end
end

View File

@ -267,6 +267,7 @@ RSpec.describe AiPersona do
description: "system persona",
system_prompt: "system persona",
tools: %w[Search Time],
response_format: [{ key: "summary", type: "string" }],
system: true,
)
end
@ -285,6 +286,22 @@ RSpec.describe AiPersona do
I18n.t("discourse_ai.ai_bot.personas.cannot_edit_system_persona"),
)
end
it "doesn't accept response format changes" do
new_format = [{ key: "summary2", type: "string" }]
expect { system_persona.update!(response_format: new_format) }.to raise_error(
ActiveRecord::RecordInvalid,
)
end
it "doesn't accept additional format changes" do
new_format = [{ key: "summary", type: "string" }, { key: "summary2", type: "string" }]
expect { system_persona.update!(response_format: new_format) }.to raise_error(
ActiveRecord::RecordInvalid,
)
end
end
end
end

View File

@ -185,6 +185,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
default_llm_id: llm_model.id,
question_consolidator_llm_id: llm_model.id,
forced_tool_count: 2,
response_format: [{ key: "summary", type: "string" }],
}
end
@ -209,6 +210,9 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
expect(persona_json["allow_chat_channel_mentions"]).to eq(true)
expect(persona_json["allow_chat_direct_messages"]).to eq(true)
expect(persona_json["question_consolidator_llm_id"]).to eq(llm_model.id)
expect(persona_json["response_format"].map { |rf| rf["key"] }).to contain_exactly(
"summary",
)
persona = AiPersona.find(persona_json["id"])

View File

@ -74,6 +74,7 @@ RSpec.describe DiscourseAi::Summarization::SummaryController do
it "returns a summary" do
summary_text = "This is a summary"
DiscourseAi::Completions::Llm.with_prepared_responses([summary_text]) do
get "/discourse-ai/summarization/t/#{topic.id}.json"

View File

@ -13,6 +13,8 @@ describe DiscourseAi::TopicSummarization do
let(:strategy) { DiscourseAi::Summarization.topic_summary(topic) }
let(:summary) { "This is the final summary" }
describe "#summarize" do
subject(:summarization) { described_class.new(strategy, user) }
@ -27,8 +29,6 @@ describe DiscourseAi::TopicSummarization do
end
context "when the content was summarized in a single chunk" do
let(:summary) { "This is the final summary" }
it "caches the summary" do
DiscourseAi::Completions::Llm.with_prepared_responses([summary]) do
section = summarization.summarize
@ -54,7 +54,6 @@ describe DiscourseAi::TopicSummarization do
describe "invalidating cached summaries" do
let(:cached_text) { "This is a cached summary" }
let(:updated_summary) { "This is the final summary" }
def cached_summary
AiSummary.find_by(target: topic, summary_type: AiSummary.summary_types[:complete])
@ -86,10 +85,10 @@ describe DiscourseAi::TopicSummarization do
before { cached_summary.update!(original_content_sha: "outdated_sha") }
it "returns a new summary" do
DiscourseAi::Completions::Llm.with_prepared_responses([updated_summary]) do
DiscourseAi::Completions::Llm.with_prepared_responses([summary]) do
section = summarization.summarize
expect(section.summarized_text).to eq(updated_summary)
expect(section.summarized_text).to eq(summary)
end
end
@ -106,10 +105,10 @@ describe DiscourseAi::TopicSummarization do
end
it "returns a new summary if the skip_age_check flag is passed" do
DiscourseAi::Completions::Llm.with_prepared_responses([updated_summary]) do
DiscourseAi::Completions::Llm.with_prepared_responses([summary]) do
section = summarization.summarize(skip_age_check: true)
expect(section.summarized_text).to eq(updated_summary)
expect(section.summarized_text).to eq(summary)
end
end
end
@ -118,8 +117,6 @@ describe DiscourseAi::TopicSummarization do
end
describe "stream partial updates" do
let(:summary) { "This is the final summary" }
it "receives a blk that is passed to the underlying strategy and called with partial summaries" do
partial_result = +""
@ -127,7 +124,8 @@ describe DiscourseAi::TopicSummarization do
summarization.summarize { |partial_summary| partial_result << partial_summary }
end
expect(partial_result).to eq(summary)
# In a real-world scenario, this is removed from the returned AiSummary object.
expect(partial_result.chomp("\"}")).to eq(summary)
end
end
end

View File

@ -12,7 +12,9 @@ RSpec.describe "Summarize a topic ", type: :system do
"I like to eat pie. It is a very good dessert. Some people are wasteful by throwing pie at others but I do not do that. I always eat the pie.",
)
end
let(:summarization_result) { "This is a summary" }
let(:topic_page) { PageObjects::Pages::Topic.new }
let(:summary_box) { PageObjects::Components::AiSummaryTrigger.new }