diff --git a/app/controllers/discourse_ai/admin/ai_personas_controller.rb b/app/controllers/discourse_ai/admin/ai_personas_controller.rb
index ca3059fc..58a61e0e 100644
--- a/app/controllers/discourse_ai/admin/ai_personas_controller.rb
+++ b/app/controllers/discourse_ai/admin/ai_personas_controller.rb
@@ -221,6 +221,10 @@ module DiscourseAi
permitted[:tools] = permit_tools(tools)
end
+ if response_format = params.dig(:ai_persona, :response_format)
+ permitted[:response_format] = permit_response_format(response_format)
+ end
+
permitted
end
@@ -235,6 +239,18 @@ module DiscourseAi
[tool, options, !!force_tool]
end
end
+
+ def permit_response_format(response_format)
+ return [] if !response_format.is_a?(Array)
+
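+ # Each entry is expected to be a hash like { key: "summary", type: "string" }.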
+ response_format.map do |element|
+ if element && element.is_a?(ActionController::Parameters)
+ element.permit!
+ else
+ false
+ end
+ end
+ end
end
end
end
diff --git a/app/models/ai_persona.rb b/app/models/ai_persona.rb
index e654cffc..8347a414 100644
--- a/app/models/ai_persona.rb
+++ b/app/models/ai_persona.rb
@@ -325,9 +325,11 @@ class AiPersona < ActiveRecord::Base
end
def system_persona_unchangeable
+ error_msg = I18n.t("discourse_ai.ai_bot.personas.cannot_edit_system_persona")
+
if top_p_changed? || temperature_changed? || system_prompt_changed? || name_changed? ||
description_changed?
- errors.add(:base, I18n.t("discourse_ai.ai_bot.personas.cannot_edit_system_persona"))
+ errors.add(:base, error_msg)
elsif tools_changed?
old_tools = tools_change[0]
new_tools = tools_change[1]
@@ -335,9 +337,12 @@ class AiPersona < ActiveRecord::Base
old_tool_names = old_tools.map { |t| t.is_a?(Array) ? t[0] : t }.to_set
new_tool_names = new_tools.map { |t| t.is_a?(Array) ? t[0] : t }.to_set
- if old_tool_names != new_tool_names
- errors.add(:base, I18n.t("discourse_ai.ai_bot.personas.cannot_edit_system_persona"))
- end
+ errors.add(:base, error_msg) if old_tool_names != new_tool_names
+ elsif response_format_changed?
+ old_format = response_format_change[0].map { |f| f["key"] }.to_set
+ new_format = response_format_change[1].map { |f| f["key"] }.to_set
+
+ errors.add(:base, error_msg) if old_format != new_format
end
end
@@ -395,6 +400,7 @@ end
# rag_llm_model_id :bigint
# default_llm_id :bigint
# question_consolidator_llm_id :bigint
+# response_format :jsonb
#
# Indexes
#
diff --git a/app/serializers/localized_ai_persona_serializer.rb b/app/serializers/localized_ai_persona_serializer.rb
index dde41dfb..57945e29 100644
--- a/app/serializers/localized_ai_persona_serializer.rb
+++ b/app/serializers/localized_ai_persona_serializer.rb
@@ -30,7 +30,8 @@ class LocalizedAiPersonaSerializer < ApplicationSerializer
:allow_chat_direct_messages,
:allow_topic_mentions,
:allow_personal_messages,
- :force_default_llm
+ :force_default_llm,
+ :response_format
has_one :user, serializer: BasicUserSerializer, embed: :object
has_many :rag_uploads, serializer: UploadSerializer, embed: :object
diff --git a/assets/javascripts/discourse/admin/models/ai-persona.js b/assets/javascripts/discourse/admin/models/ai-persona.js
index 4c7ff28e..f313dee5 100644
--- a/assets/javascripts/discourse/admin/models/ai-persona.js
+++ b/assets/javascripts/discourse/admin/models/ai-persona.js
@@ -33,6 +33,7 @@ const CREATE_ATTRIBUTES = [
"allow_topic_mentions",
"allow_chat_channel_mentions",
"allow_chat_direct_messages",
+ "response_format",
];
const SYSTEM_ATTRIBUTES = [
@@ -60,6 +61,7 @@ const SYSTEM_ATTRIBUTES = [
"allow_topic_mentions",
"allow_chat_channel_mentions",
"allow_chat_direct_messages",
+ "response_format",
];
export default class AiPersona extends RestModel {
@@ -151,6 +153,7 @@ export default class AiPersona extends RestModel {
const attrs = this.getProperties(CREATE_ATTRIBUTES);
this.populateTools(attrs);
attrs.forced_tool_count = this.forced_tool_count || -1;
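+ // Always send an array so a cleared response format still serializes as [].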
+ attrs.response_format = attrs.response_format || [];
return attrs;
}
diff --git a/assets/javascripts/discourse/components/ai-persona-editor.gjs b/assets/javascripts/discourse/components/ai-persona-editor.gjs
index edaf13c8..f0b56d8a 100644
--- a/assets/javascripts/discourse/components/ai-persona-editor.gjs
+++ b/assets/javascripts/discourse/components/ai-persona-editor.gjs
@@ -15,6 +15,7 @@ import Group from "discourse/models/group";
import { i18n } from "discourse-i18n";
import AdminUser from "admin/models/admin-user";
import GroupChooser from "select-kit/components/group-chooser";
+import AiPersonaResponseFormatEditor from "../components/modal/ai-persona-response-format-editor";
import AiLlmSelector from "./ai-llm-selector";
import AiPersonaToolOptions from "./ai-persona-tool-options";
import AiToolSelector from "./ai-tool-selector";
@@ -325,6 +326,8 @@ export default class PersonaEditor extends Component {
+          <AiPersonaResponseFormatEditor @form={{form}} @data={{data}} />
+
diff --git a/assets/javascripts/discourse/components/modal/ai-persona-response-format-editor.gjs b/assets/javascripts/discourse/components/modal/ai-persona-response-format-editor.gjs
new file mode 100644
--- /dev/null
+++ b/assets/javascripts/discourse/components/modal/ai-persona-response-format-editor.gjs
+export default class AiPersonaResponseFormatEditor extends Component {
+  @tracked showJsonEditorModal = false;
+
+  get displayJSON() {
+    const toDisplay = {};
+
+    this.args.data.response_format.forEach((keyDesc) => {
+      toDisplay[keyDesc.key] = keyDesc.type;
+    });
+
+    return prettyJSON(toDisplay);
+  }
+
+  @action
+  openModal() {
+    this.showJsonEditorModal = true;
+  }
+
+  @action
+  closeModal() {
+    this.showJsonEditorModal = false;
+  }
+
+  @action
+  updateResponseFormat(form, value) {
+    form.set("response_format", JSON.parse(value));
+  }
+
+  <template>
+    <@form.Container @title={{this.editorTitle}} @format="large">
+
+    </@form.Container>
+
+    {{#if this.showJsonEditorModal}}
+
+    {{/if}}
+  </template>
+}
diff --git a/assets/stylesheets/modules/ai-bot/common/ai-persona.scss b/assets/stylesheets/modules/ai-bot/common/ai-persona.scss
index 39eeeb5f..a5fdfe5b 100644
--- a/assets/stylesheets/modules/ai-bot/common/ai-persona.scss
+++ b/assets/stylesheets/modules/ai-bot/common/ai-persona.scss
@@ -52,6 +52,21 @@
margin-bottom: 10px;
font-size: var(--font-down-1);
}
+
+ &__response-format {
+ width: 100%;
+ display: block;
+ }
+
+ &__response-format-pre {
+ margin-bottom: 0;
+ white-space: pre-line;
+ }
+
+ &__response-format-none {
+ margin-bottom: 1em;
+ margin-top: 0.5em;
+ }
}
.rag-options {
diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml
index 3a4fb037..0c2c98d2 100644
--- a/config/locales/client.en.yml
+++ b/config/locales/client.en.yml
@@ -323,6 +323,13 @@ en:
rag_conversation_chunks: "Search conversation chunks"
rag_conversation_chunks_help: "The number of chunks to use for the RAG model searches. Increase to increase the amount of context the AI can use."
persona_description: "Personas are a powerful feature that allows you to customize the behavior of the AI engine in your Discourse forum. They act as a 'system message' that guides the AI's responses and interactions, helping to create a more personalized and engaging user experience."
+ response_format:
+ title: "JSON response format"
+ no_format: "No JSON format specified"
+ open_modal: "Edit"
+ modal:
+ root_title: "Response structure"
+ key_title: "Key"
list:
enabled: "AI Bot?"
diff --git a/db/fixtures/personas/603_ai_personas.rb b/db/fixtures/personas/603_ai_personas.rb
index b0d4854d..7d52e8a9 100644
--- a/db/fixtures/personas/603_ai_personas.rb
+++ b/db/fixtures/personas/603_ai_personas.rb
@@ -72,6 +72,8 @@ DiscourseAi::Personas::Persona.system_personas.each do |persona_class, id|
persona.tools = tools.map { |name, value| [name, value] }
+ persona.response_format = instance.response_format
+
persona.system_prompt = instance.system_prompt
persona.top_p = instance.top_p
persona.temperature = instance.temperature
diff --git a/db/migrate/20250411121705_add_response_format_json_to_personas.rb b/db/migrate/20250411121705_add_response_format_json_to_personas.rb
new file mode 100644
index 00000000..2f5a3e10
--- /dev/null
+++ b/db/migrate/20250411121705_add_response_format_json_to_personas.rb
@@ -0,0 +1,6 @@
+# frozen_string_literal: true
+class AddResponseFormatJsonToPersonas < ActiveRecord::Migration[7.2]
+ def change
+ add_column :ai_personas, :response_format, :jsonb
+ end
+end
diff --git a/lib/completions/anthropic_message_processor.rb b/lib/completions/anthropic_message_processor.rb
index e67be109..a8ba9b4a 100644
--- a/lib/completions/anthropic_message_processor.rb
+++ b/lib/completions/anthropic_message_processor.rb
@@ -10,7 +10,7 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
@raw_json = +""
@tool_call = DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: {})
@streaming_parser =
- DiscourseAi::Completions::ToolCallProgressTracker.new(self) if partial_tool_calls
+ DiscourseAi::Completions::JsonStreamingTracker.new(self) if partial_tool_calls
end
def append(json)
diff --git a/lib/completions/dialects/nova.rb b/lib/completions/dialects/nova.rb
index b078e79d..9dc88097 100644
--- a/lib/completions/dialects/nova.rb
+++ b/lib/completions/dialects/nova.rb
@@ -42,6 +42,7 @@ module DiscourseAi
result = { system: system, messages: messages }
result[:inferenceConfig] = inference_config if inference_config.present?
result[:toolConfig] = tool_config if tool_config.present?
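+ # Only generic JSON mode is requested here; the schema itself is not forwarded to Nova.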
+ result[:response_format] = { type: "json_object" } if options[:response_format].present?
result
end
diff --git a/lib/completions/endpoints/anthropic.rb b/lib/completions/endpoints/anthropic.rb
index dd44a82b..4155db2a 100644
--- a/lib/completions/endpoints/anthropic.rb
+++ b/lib/completions/endpoints/anthropic.rb
@@ -88,10 +88,16 @@ module DiscourseAi
def prepare_payload(prompt, model_params, dialect)
@native_tool_support = dialect.native_tool_support?
- payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
+ payload =
+ default_options(dialect).merge(model_params.except(:response_format)).merge(
+ messages: prompt.messages,
+ )
payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
payload[:stream] = true if @streaming_mode
+
+ prefilled_message = +""
+
if prompt.has_tools?
payload[:tools] = prompt.tools
if dialect.tool_choice.present?
@@ -100,16 +106,24 @@ module DiscourseAi
# prefill prompt to nudge LLM to generate a response that is useful.
# without this LLM (even 3.7) can get confused and start text preambles for a tool calls.
- payload[:messages] << {
- role: "assistant",
- content: dialect.no_more_tool_calls_text,
- }
+ prefilled_message << dialect.no_more_tool_calls_text
else
payload[:tool_choice] = { type: "tool", name: prompt.tool_choice }
end
end
end
+ # Prefill prompt to force JSON output.
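+ # The assistant prefill ends with "{", so the model continues mid-JSON; base.rb re-adds the brace to the parsed stream.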
+ if model_params[:response_format].present?
+ prefilled_message << " " if !prefilled_message.empty?
+ prefilled_message << "{"
+ @forced_json_through_prefill = true
+ end
+
+ if !prefilled_message.empty?
+ payload[:messages] << { role: "assistant", content: prefilled_message }
+ end
+
payload
end
diff --git a/lib/completions/endpoints/aws_bedrock.rb b/lib/completions/endpoints/aws_bedrock.rb
index 915e9d3b..f4336ac1 100644
--- a/lib/completions/endpoints/aws_bedrock.rb
+++ b/lib/completions/endpoints/aws_bedrock.rb
@@ -116,9 +116,14 @@ module DiscourseAi
payload = nil
if dialect.is_a?(DiscourseAi::Completions::Dialects::Claude)
- payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
+ payload =
+ default_options(dialect).merge(model_params.except(:response_format)).merge(
+ messages: prompt.messages,
+ )
payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
+ prefilled_message = +""
+
if prompt.has_tools?
payload[:tools] = prompt.tools
if dialect.tool_choice.present?
@@ -128,15 +133,23 @@ module DiscourseAi
# payload[:tool_choice] = { type: "none" }
# prefill prompt to nudge LLM to generate a response that is useful, instead of trying to call a tool
- payload[:messages] << {
- role: "assistant",
- content: dialect.no_more_tool_calls_text,
- }
+ prefilled_message << dialect.no_more_tool_calls_text
else
payload[:tool_choice] = { type: "tool", name: prompt.tool_choice }
end
end
end
+
+ # Prefill prompt to force JSON output.
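+ # As with the Anthropic endpoint, base.rb re-adds the "{" consumed by this prefill to the parsed stream.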
+ if model_params[:response_format].present?
+ prefilled_message << " " if !prefilled_message.empty?
+ prefilled_message << "{"
+ @forced_json_through_prefill = true
+ end
+
+ if !prefilled_message.empty?
+ payload[:messages] << { role: "assistant", content: prefilled_message }
+ end
elsif dialect.is_a?(DiscourseAi::Completions::Dialects::Nova)
payload = prompt.to_payload(default_options(dialect).merge(model_params))
else
diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb
index 800381e4..7f79a819 100644
--- a/lib/completions/endpoints/base.rb
+++ b/lib/completions/endpoints/base.rb
@@ -73,6 +73,7 @@ module DiscourseAi
LlmQuota.check_quotas!(@llm_model, user)
start_time = Time.now
+ @forced_json_through_prefill = false
@partial_tool_calls = partial_tool_calls
@output_thinking = output_thinking
@@ -106,6 +107,18 @@ module DiscourseAi
prompt = dialect.translate
+ structured_output = nil
+
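+ # When a response_format schema was requested, stream the response through an incremental JSON parser.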
+ if model_params[:response_format].present?
+ schema_properties =
+ model_params[:response_format].dig(:json_schema, :schema, :properties)
+
+ if schema_properties.present?
+ structured_output =
+ DiscourseAi::Completions::StructuredOutput.new(schema_properties)
+ end
+ end
+
FinalDestination::HTTP.start(
model_uri.host,
model_uri.port,
@@ -123,6 +136,10 @@ module DiscourseAi
request = prepare_request(request_body)
+ # Some providers rely on prefill to return structured outputs, so the start
+ # of the JSON won't be included in the response. Supply it to keep JSON valid.
+ structured_output << +"{" if structured_output && @forced_json_through_prefill
+
http.request(request) do |response|
if response.code.to_i != 200
Rails.logger.error(
@@ -140,10 +157,17 @@ module DiscourseAi
xml_stripper =
DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present?
- if @streaming_mode && xml_stripper
+ if @streaming_mode
blk =
lambda do |partial, cancel|
- partial = xml_stripper << partial if partial.is_a?(String)
+ if partial.is_a?(String)
+ partial = xml_stripper << partial if xml_stripper
+
+ if structured_output.present?
+ structured_output << partial
+ partial = structured_output
+ end
+ end
orig_blk.call(partial, cancel) if partial
end
end
@@ -167,6 +191,7 @@ module DiscourseAi
xml_stripper: xml_stripper,
partials_raw: partials_raw,
response_raw: response_raw,
+ structured_output: structured_output,
)
return response_data
end
@@ -373,7 +398,8 @@ module DiscourseAi
xml_tool_processor:,
xml_stripper:,
partials_raw:,
- response_raw:
+ response_raw:,
+ structured_output:
)
response_raw << response.read_body
response_data = decode(response_raw)
@@ -403,6 +429,26 @@ module DiscourseAi
response_data.reject!(&:blank?)
+ if structured_output.present?
+ has_string_response = false
+
+ response_data =
+ response_data.reduce([]) do |memo, data|
+ if data.is_a?(String)
+ structured_output << data
+ has_string_response = true
+ next(memo)
+ else
+ memo << data
+ end
+
+ memo
+ end
+
+ # We only include the structured output if there was actually a structured response
+ response_data << structured_output if has_string_response
+ end
+
# this is to keep stuff backwards compatible
response_data = response_data.first if response_data.length == 1
diff --git a/lib/completions/endpoints/canned_response.rb b/lib/completions/endpoints/canned_response.rb
index a7ccc1da..b4dafc74 100644
--- a/lib/completions/endpoints/canned_response.rb
+++ b/lib/completions/endpoints/canned_response.rb
@@ -40,6 +40,8 @@ module DiscourseAi
"The number of completions you requested exceed the number of canned responses"
end
+ response = as_structured_output(response) if model_params[:response_format].present?
+
raise response if response.is_a?(StandardError)
@completions += 1
@@ -54,6 +56,8 @@ module DiscourseAi
yield(response, cancel_fn)
elsif is_thinking?(response)
yield(response, cancel_fn)
+ elsif is_structured_output?(response)
+ yield(response, cancel_fn)
else
response.each_char do |char|
break if cancelled
@@ -80,6 +84,20 @@ module DiscourseAi
def is_tool?(response)
response.is_a?(DiscourseAi::Completions::ToolCall)
end
+
+ def is_structured_output?(response)
+ response.is_a?(DiscourseAi::Completions::StructuredOutput)
+ end
+
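+ # Wrap the canned text under the schema's first key so specs exercise the structured-output path.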
+ def as_structured_output(response)
+ schema_properties = model_params[:response_format].dig(:json_schema, :schema, :properties)
+ return response if schema_properties.blank?
+
+ output = DiscourseAi::Completions::StructuredOutput.new(schema_properties)
+ output << { schema_properties.keys.first => response }.to_json
+
+ output
+ end
end
end
end
diff --git a/lib/completions/endpoints/gemini.rb b/lib/completions/endpoints/gemini.rb
index a99bb80b..e52fd46f 100644
--- a/lib/completions/endpoints/gemini.rb
+++ b/lib/completions/endpoints/gemini.rb
@@ -84,7 +84,16 @@ module DiscourseAi
payload[:tool_config] = { function_calling_config: function_calling_config }
end
- payload[:generationConfig].merge!(model_params) if model_params.present?
+ if model_params.present?
+ payload[:generationConfig].merge!(model_params.except(:response_format))
+
+ if model_params[:response_format].present?
+ # https://ai.google.dev/api/generate-content#generationconfig
+ payload[:generationConfig][:responseSchema] = model_params[:response_format]
+ payload[:generationConfig][:responseMimeType] = "application/json"
+ end
+ end
+
payload
end
diff --git a/lib/completions/endpoints/samba_nova.rb b/lib/completions/endpoints/samba_nova.rb
index cc81e786..9e6b3817 100644
--- a/lib/completions/endpoints/samba_nova.rb
+++ b/lib/completions/endpoints/samba_nova.rb
@@ -34,7 +34,12 @@ module DiscourseAi
end
def prepare_payload(prompt, model_params, dialect)
- payload = default_options.merge(model_params).merge(messages: prompt)
+ payload =
+ default_options.merge(model_params.except(:response_format)).merge(messages: prompt)
+
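+ # SambaNova only gets generic JSON mode; the schema details are not forwarded.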
+ if model_params[:response_format].present?
+ payload[:response_format] = { type: "json_object" }
+ end
payload[:stream] = true if @streaming_mode
diff --git a/lib/completions/tool_call_progress_tracker.rb b/lib/completions/json_streaming_tracker.rb
similarity index 74%
rename from lib/completions/tool_call_progress_tracker.rb
rename to lib/completions/json_streaming_tracker.rb
index 0f6d9158..aa687ef1 100644
--- a/lib/completions/tool_call_progress_tracker.rb
+++ b/lib/completions/json_streaming_tracker.rb
@@ -2,11 +2,11 @@
module DiscourseAi
module Completions
- class ToolCallProgressTracker
- attr_reader :current_key, :current_value, :tool_call
+ class JsonStreamingTracker
+ attr_reader :current_key, :current_value, :stream_consumer
- def initialize(tool_call)
- @tool_call = tool_call
+ def initialize(stream_consumer)
+ @stream_consumer = stream_consumer
@current_key = nil
@current_value = nil
@parser = DiscourseAi::Completions::JsonStreamingParser.new
@@ -18,7 +18,7 @@ module DiscourseAi
@parser.value do |v|
if @current_key
- tool_call.notify_progress(@current_key, v)
+ stream_consumer.notify_progress(@current_key, v)
@current_key = nil
end
end
@@ -39,7 +39,7 @@ module DiscourseAi
if @parser.state == :start_string && @current_key
# this is is worth notifying
- tool_call.notify_progress(@current_key, @parser.buf)
+ stream_consumer.notify_progress(@current_key, @parser.buf)
end
@current_key = nil if @parser.state == :end_value
diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb
index 5f82667e..f90de091 100644
--- a/lib/completions/llm.rb
+++ b/lib/completions/llm.rb
@@ -304,6 +304,8 @@ module DiscourseAi
# @param feature_context { Hash - Optional } - The feature context to use for the completion.
# @param partial_tool_calls { Boolean - Optional } - If true, the completion will return partial tool calls.
# @param output_thinking { Boolean - Optional } - If true, the completion will return the thinking output for thinking models.
+ # @param response_format { Hash - Optional } - JSON schema passed to the API as the desired structured output.
+ # @param extra_model_params { Hash - Optional } - (Experimental) Other params that are not available across models.
#
# @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response alongside a cancel function.
#
@@ -321,6 +323,7 @@ module DiscourseAi
feature_context: nil,
partial_tool_calls: false,
output_thinking: false,
+ response_format: nil,
extra_model_params: nil,
&partial_read_blk
)
@@ -336,6 +339,7 @@ module DiscourseAi
feature_context: feature_context,
partial_tool_calls: partial_tool_calls,
output_thinking: output_thinking,
+ response_format: response_format,
extra_model_params: extra_model_params,
},
)
@@ -344,6 +348,7 @@ module DiscourseAi
model_params[:temperature] = temperature if temperature
model_params[:top_p] = top_p if top_p
+ model_params[:response_format] = response_format if response_format
model_params.merge!(extra_model_params) if extra_model_params
if prompt.is_a?(String)
diff --git a/lib/completions/nova_message_processor.rb b/lib/completions/nova_message_processor.rb
index efe54330..80710c9c 100644
--- a/lib/completions/nova_message_processor.rb
+++ b/lib/completions/nova_message_processor.rb
@@ -10,7 +10,7 @@ class DiscourseAi::Completions::NovaMessageProcessor
@raw_json = +""
@tool_call = DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: {})
@streaming_parser =
- DiscourseAi::Completions::ToolCallProgressTracker.new(self) if partial_tool_calls
+ DiscourseAi::Completions::JsonStreamingTracker.new(self) if partial_tool_calls
end
def append(json)
diff --git a/lib/completions/open_ai_message_processor.rb b/lib/completions/open_ai_message_processor.rb
index 33182995..979c2043 100644
--- a/lib/completions/open_ai_message_processor.rb
+++ b/lib/completions/open_ai_message_processor.rb
@@ -59,7 +59,7 @@ module DiscourseAi::Completions
if id.present? && name.present?
@tool_arguments = +""
@tool = ToolCall.new(id: id, name: name)
- @streaming_parser = ToolCallProgressTracker.new(self) if @partial_tool_calls
+ @streaming_parser = JsonStreamingTracker.new(self) if @partial_tool_calls
end
@tool_arguments << arguments.to_s
diff --git a/lib/completions/structured_output.rb b/lib/completions/structured_output.rb
new file mode 100644
index 00000000..fadf5722
--- /dev/null
+++ b/lib/completions/structured_output.rb
@@ -0,0 +1,52 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+ module Completions
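+ # Buffers a streamed JSON response and exposes per-property deltas as they arrive.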
+ class StructuredOutput
+ def initialize(json_schema_properties)
+ @property_names = json_schema_properties.keys.map(&:to_sym)
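+ # A cursor per string property tracks how much of its value has already been read.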
+ @property_cursors =
+ json_schema_properties.reduce({}) do |m, (k, prop)|
+ m[k.to_sym] = 0 if prop[:type] == "string"
+ m
+ end
+
+ @tracked = {}
+
+ @partial_json_tracker = JsonStreamingTracker.new(self)
+ end
+
+ attr_reader :last_chunk_buffer
+
+ def <<(raw)
+ @partial_json_tracker << raw
+ end
+
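+ # Returns only values buffered since the last read; string properties are diffed via their cursors.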
+ def read_latest_buffered_chunk
+ @property_names.reduce({}) do |memo, pn|
+ if @tracked[pn].present?
+ # This means this property is a string and we want to return unread chunks.
+ if @property_cursors[pn].present?
+ unread = @tracked[pn][@property_cursors[pn]..]
+
+ memo[pn] = unread if unread.present?
+ @property_cursors[pn] = @tracked[pn].length
+ else
+ # Ints and bools are always returned as is.
+ memo[pn] = @tracked[pn]
+ end
+ end
+
+ memo
+ end
+ end
+
+ def notify_progress(key, value)
+ key_sym = key.to_sym
+ return if !@property_names.include?(key_sym)
+
+ @tracked[key_sym] = value
+ end
+ end
+ end
+end
diff --git a/lib/personas/bot.rb b/lib/personas/bot.rb
index aa701abc..2c9b5a3b 100644
--- a/lib/personas/bot.rb
+++ b/lib/personas/bot.rb
@@ -64,10 +64,13 @@ module DiscourseAi
user = context.user
- llm_kwargs = { user: user }
+ llm_kwargs = llm_args.dup
+ llm_kwargs[:user] = user
llm_kwargs[:temperature] = persona.temperature if persona.temperature
llm_kwargs[:top_p] = persona.top_p if persona.top_p
- llm_kwargs[:max_tokens] = llm_args[:max_tokens] if llm_args[:max_tokens].present?
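+ # Personas may define a response format; translate it into a JSON schema for the endpoint.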
+ llm_kwargs[:response_format] = build_json_schema(
+ persona.response_format,
+ ) if persona.response_format.present?
needs_newlines = false
tools_ran = 0
@@ -148,6 +151,8 @@ module DiscourseAi
raw_context << partial
current_thinking << partial
end
+ elsif partial.is_a?(DiscourseAi::Completions::StructuredOutput)
+ update_blk.call(partial, cancel, nil, :structured_output)
else
update_blk.call(partial, cancel)
end
@@ -176,6 +181,10 @@ module DiscourseAi
embed_thinking(raw_context)
end
+ def returns_json?
+ persona.response_format.present?
+ end
+
private
def embed_thinking(raw_context)
@@ -301,6 +310,30 @@ module DiscourseAi
placeholder
end
+
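+ # Converts the persona's [{ key:, type: }] entries into an OpenAI-style json_schema response_format hash.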
+ def build_json_schema(response_format)
+ properties =
+ response_format
+ .to_a
+ .reduce({}) do |memo, format|
+ memo[format[:key].to_sym] = { type: format[:type] }
+ memo
+ end
+
+ {
+ type: "json_schema",
+ json_schema: {
+ name: "reply",
+ schema: {
+ type: "object",
+ properties: properties,
+ required: properties.keys.map(&:to_s),
+ additionalProperties: false,
+ },
+ strict: true,
+ },
+ }
+ end
end
end
end
diff --git a/lib/personas/bot_context.rb b/lib/personas/bot_context.rb
index 94a67010..5f7dd99e 100644
--- a/lib/personas/bot_context.rb
+++ b/lib/personas/bot_context.rb
@@ -49,6 +49,8 @@ module DiscourseAi
@site_title = site_title
@site_description = site_description
@time = time
+ @resource_url = resource_url
+
@feature_name = feature_name
- @resource_url = resource_url
diff --git a/lib/personas/persona.rb b/lib/personas/persona.rb
index 3332e5a9..a8b08785 100644
--- a/lib/personas/persona.rb
+++ b/lib/personas/persona.rb
@@ -160,6 +160,10 @@ module DiscourseAi
{}
end
+ def response_format
+ nil
+ end
+
def available_tools
self
.class
diff --git a/lib/personas/short_summarizer.rb b/lib/personas/short_summarizer.rb
index d460d6c0..e7cef54a 100644
--- a/lib/personas/short_summarizer.rb
+++ b/lib/personas/short_summarizer.rb
@@ -18,9 +18,19 @@ module DiscourseAi
- Limit the summary to a maximum of 40 words.
- Do *NOT* repeat the discussion title in the summary.
-          - Return the summary inside <ai></ai> tags.
+          Format your response as a JSON object with a single key named "summary", which has the summary as the value.
+          Your output should be in the following format:
+            <output>
+              {"summary": "xx"}
+            </output>
+
+          Where "xx" is replaced by the summary.
PROMPT
end
+
+ def response_format
+ [{ key: "summary", type: "string" }]
+ end
end
end
end
diff --git a/lib/personas/summarizer.rb b/lib/personas/summarizer.rb
index a2f81463..64540ff0 100644
--- a/lib/personas/summarizer.rb
+++ b/lib/personas/summarizer.rb
@@ -18,8 +18,20 @@ module DiscourseAi
- Example: link to the 6th post by jane: [agreed with]({resource_url}/6)
- Example: link to the 13th post by joe: [joe]({resource_url}/13)
- When formatting usernames either use @USERNAME OR [USERNAME]({resource_url}/POST_NUMBER)
+
+ Format your response as a JSON object with a single key named "summary", which has the summary as the value.
+          Your output should be in the following format:
+            <output>
+              {"summary": "xx"}
+            </output>
+
+          Where "xx" is replaced by the summary.
PROMPT
end
+
+ def response_format
+ [{ key: "summary", type: "string" }]
+ end
end
end
end
diff --git a/lib/summarization/fold_content.rb b/lib/summarization/fold_content.rb
index 6fcd1876..15800087 100644
--- a/lib/summarization/fold_content.rb
+++ b/lib/summarization/fold_content.rb
@@ -29,18 +29,10 @@ module DiscourseAi
summary = fold(truncated_content, user, &on_partial_blk)
- clean_summary = Nokogiri::HTML5.fragment(summary).css("ai")&.first&.text || summary
-
if persist_summaries
- AiSummary.store!(
- strategy,
- llm_model,
- clean_summary,
- truncated_content,
- human: user&.human?,
- )
+ AiSummary.store!(strategy, llm_model, summary, truncated_content, human: user&.human?)
else
- AiSummary.new(summarized_text: clean_summary)
+ AiSummary.new(summarized_text: summary)
end
end
@@ -118,9 +110,20 @@ module DiscourseAi
)
summary = +""
+
buffer_blk =
- Proc.new do |partial, cancel, placeholder, type|
- if type.blank?
+ Proc.new do |partial, cancel, _, type|
+ if type == :structured_output
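+ # Structured outputs buffer the summary under the persona's response_format key.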
+ json_summary_schema_key = bot.persona.response_format&.first.to_h
+ partial_summary =
+ partial.read_latest_buffered_chunk[json_summary_schema_key[:key].to_sym]
+
+ if partial_summary.present?
+ summary << partial_summary
+ on_partial_blk.call(partial_summary, cancel) if on_partial_blk
+ end
+ elsif type.blank?
+ # Assume response is a regular completion.
summary << partial
on_partial_blk.call(partial, cancel) if on_partial_blk
end
@@ -154,12 +157,6 @@ module DiscourseAi
item
end
-
- def text_only_update(&on_partial_blk)
- Proc.new do |partial, cancel, placeholder, type|
- on_partial_blk.call(partial, cancel) if type.blank?
- end
- end
end
end
end
diff --git a/spec/jobs/regular/fast_track_topic_gist_spec.rb b/spec/jobs/regular/fast_track_topic_gist_spec.rb
index 3eccc006..ef7dbc47 100644
--- a/spec/jobs/regular/fast_track_topic_gist_spec.rb
+++ b/spec/jobs/regular/fast_track_topic_gist_spec.rb
@@ -21,6 +21,7 @@ RSpec.describe Jobs::FastTrackTopicGist do
created_at: 10.minutes.ago,
)
end
+
let(:updated_gist) { "They updated me :(" }
context "when it's up to date" do
diff --git a/spec/jobs/regular/stream_topic_ai_summary_spec.rb b/spec/jobs/regular/stream_topic_ai_summary_spec.rb
index 87848560..1591e701 100644
--- a/spec/jobs/regular/stream_topic_ai_summary_spec.rb
+++ b/spec/jobs/regular/stream_topic_ai_summary_spec.rb
@@ -51,7 +51,9 @@ RSpec.describe Jobs::StreamTopicAiSummary do
end
it "publishes updates with a partial summary" do
- with_responses(["dummy"]) do
+ summary = "dummy"
+
+ with_responses([summary]) do
messages =
MessageBus.track_publish("/discourse-ai/summaries/topic/#{topic.id}") do
job.execute(topic_id: topic.id, user_id: user.id)
@@ -59,12 +61,16 @@ RSpec.describe Jobs::StreamTopicAiSummary do
partial_summary_update = messages.first.data
expect(partial_summary_update[:done]).to eq(false)
- expect(partial_summary_update.dig(:ai_topic_summary, :summarized_text)).to eq("dummy")
+ expect(partial_summary_update.dig(:ai_topic_summary, :summarized_text).chomp("\"}")).to eq(
+ summary,
+ )
end
end
it "publishes a final update to signal we're done and provide metadata" do
- with_responses(["dummy"]) do
+ summary = "dummy"
+
+ with_responses([summary]) do
messages =
MessageBus.track_publish("/discourse-ai/summaries/topic/#{topic.id}") do
job.execute(topic_id: topic.id, user_id: user.id)
@@ -73,6 +79,7 @@ RSpec.describe Jobs::StreamTopicAiSummary do
final_update = messages.last.data
expect(final_update[:done]).to eq(true)
+ expect(final_update.dig(:ai_topic_summary, :summarized_text)).to eq(summary)
expect(final_update.dig(:ai_topic_summary, :algorithm)).to eq("fake")
expect(final_update.dig(:ai_topic_summary, :outdated)).to eq(false)
expect(final_update.dig(:ai_topic_summary, :can_regenerate)).to eq(true)
diff --git a/spec/lib/completions/endpoints/anthropic_spec.rb b/spec/lib/completions/endpoints/anthropic_spec.rb
index ca3bcb46..b07586e0 100644
--- a/spec/lib/completions/endpoints/anthropic_spec.rb
+++ b/spec/lib/completions/endpoints/anthropic_spec.rb
@@ -590,7 +590,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
data: {"type": "content_block_stop", "index": 0}
event: content_block_start
-data: {"type":"content_block_start","index":0,"content_block":{"type":"redacted_thinking","data":"AAA=="} }
+ data: {"type":"content_block_start","index":0,"content_block":{"type":"redacted_thinking","data":"AAA=="} }
event: ping
data: {"type": "ping"}
@@ -769,4 +769,92 @@ data: {"type":"content_block_start","index":0,"content_block":{"type":"redacted_
expect(result).to eq("I won't use any tools. Here's a direct response instead.")
end
end
+
+ describe "structured output via prefilling" do
+ it "forces the response to be a JSON and using the given JSON schema" do
+ schema = {
+ type: "json_schema",
+ json_schema: {
+ name: "reply",
+ schema: {
+ type: "object",
+ properties: {
+ key: {
+ type: "string",
+ },
+ },
+ required: ["key"],
+ additionalProperties: false,
+ },
+ strict: true,
+ },
+ }
+
+ body = (<<~STRING).strip
+ event: message_start
+ data: {"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-opus-20240229", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 25, "output_tokens": 1}}}
+
+ event: content_block_start
+ data: {"type": "content_block_start", "index":0, "content_block": {"type": "text", "text": ""}}
+
+ event: content_block_delta
+ data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "\\""}}
+
+ event: content_block_delta
+ data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "key"}}
+
+ event: content_block_delta
+ data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "\\":\\""}}
+
+ event: content_block_delta
+ data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello!"}}
+
+ event: content_block_delta
+ data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "\\"}"}}
+
+ event: content_block_stop
+ data: {"type": "content_block_stop", "index": 0}
+
+ event: message_delta
+ data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence":null, "usage":{"output_tokens": 15}}}
+
+ event: message_stop
+ data: {"type": "message_stop"}
+ STRING
+
+ parsed_body = nil
+
+ stub_request(:post, url).with(
+ body:
+ proc do |req_body|
+ parsed_body = JSON.parse(req_body, symbolize_names: true)
+ true
+ end,
+ headers: {
+ "Content-Type" => "application/json",
+ "X-Api-Key" => "123",
+ "Anthropic-Version" => "2023-06-01",
+ },
+ ).to_return(status: 200, body: body)
+
+ structured_output = nil
+ llm.generate(
+ prompt,
+ user: Discourse.system_user,
+ feature_name: "testing",
+ response_format: schema,
+ ) { |partial, cancel| structured_output = partial }
+
+ expect(structured_output.read_latest_buffered_chunk).to eq({ key: "Hello!" })
+
+ expected_body = {
+ model: "claude-3-opus-20240229",
+ max_tokens: 4096,
+ messages: [{ role: "user", content: "user1: hello" }, { role: "assistant", content: "{" }],
+ system: "You are hello bot",
+ stream: true,
+ }
+ expect(parsed_body).to eq(expected_body)
+ end
+ end
end
diff --git a/spec/lib/completions/endpoints/aws_bedrock_spec.rb b/spec/lib/completions/endpoints/aws_bedrock_spec.rb
index 3a424451..e5e5d8b7 100644
--- a/spec/lib/completions/endpoints/aws_bedrock_spec.rb
+++ b/spec/lib/completions/endpoints/aws_bedrock_spec.rb
@@ -546,4 +546,69 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
)
end
end
+
+ describe "structured output via prefilling" do
+ it "forces the response to be a JSON and using the given JSON schema" do
+ schema = {
+ type: "json_schema",
+ json_schema: {
+ name: "reply",
+ schema: {
+ type: "object",
+ properties: {
+ key: {
+ type: "string",
+ },
+ },
+ required: ["key"],
+ additionalProperties: false,
+ },
+ strict: true,
+ },
+ }
+
+ messages =
+ [
+ { type: "message_start", message: { usage: { input_tokens: 9 } } },
+ { type: "content_block_delta", delta: { text: "\"" } },
+ { type: "content_block_delta", delta: { text: "key" } },
+ { type: "content_block_delta", delta: { text: "\":\"" } },
+ { type: "content_block_delta", delta: { text: "Hello!" } },
+ { type: "content_block_delta", delta: { text: "\"}" } },
+ { type: "message_delta", delta: { usage: { output_tokens: 25 } } },
+ ].map { |message| encode_message(message) }
+
+ proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
+ request = nil
+ bedrock_mock.with_chunk_array_support do
+ stub_request(
+ :post,
+ "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke-with-response-stream",
+ )
+ .with do |inner_request|
+ request = inner_request
+ true
+ end
+ .to_return(status: 200, body: messages)
+
+ structured_output = nil
+ proxy.generate("hello world", response_format: schema, user: user) do |partial|
+ structured_output = partial
+ end
+
+ expected = {
+ "max_tokens" => 4096,
+ "anthropic_version" => "bedrock-2023-05-31",
+ "messages" => [
+ { "role" => "user", "content" => "hello world" },
+ { "role" => "assistant", "content" => "{" },
+ ],
+ "system" => "You are a helpful bot",
+ }
+ expect(JSON.parse(request.body)).to eq(expected)
+
+ expect(structured_output.read_latest_buffered_chunk).to eq({ key: "Hello!" })
+ end
+ end
+ end
end
diff --git a/spec/lib/completions/endpoints/cohere_spec.rb b/spec/lib/completions/endpoints/cohere_spec.rb
index bdff8fc3..d8ae70ef 100644
--- a/spec/lib/completions/endpoints/cohere_spec.rb
+++ b/spec/lib/completions/endpoints/cohere_spec.rb
@@ -308,4 +308,64 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
expect(audit.request_tokens).to eq(14)
expect(audit.response_tokens).to eq(11)
end
+
+ it "is able to return structured outputs" do
+ schema = {
+ type: "json_schema",
+ json_schema: {
+ name: "reply",
+ schema: {
+ type: "object",
+ properties: {
+ key: {
+ type: "string",
+ },
+ },
+ required: ["key"],
+ additionalProperties: false,
+ },
+ strict: true,
+ },
+ }
+
+ body = <<~TEXT
+ {"is_finished":false,"event_type":"stream-start","generation_id":"eb889b0f-c27d-45ea-98cf-567bdb7fc8bf"}
+ {"is_finished":false,"event_type":"text-generation","text":"{\\""}
+ {"is_finished":false,"event_type":"text-generation","text":"key"}
+ {"is_finished":false,"event_type":"text-generation","text":"\\":\\""}
+ {"is_finished":false,"event_type":"text-generation","text":"Hello!"}
+ {"is_finished":false,"event_type":"text-generation","text":"\\"}"}|
+ {"is_finished":true,"event_type":"stream-end","response":{"response_id":"d235db17-8555-493b-8d91-e601f76de3f9","text":"{\\"key\\":\\"Hello!\\"}","generation_id":"eb889b0f-c27d-45ea-98cf-567bdb7fc8bf","chat_history":[{"role":"USER","message":"user1: hello"},{"role":"CHATBOT","message":"hi user"},{"role":"USER","message":"user1: thanks"},{"role":"CHATBOT","message":"You're welcome! Is there anything else I can help you with?"}],"token_count":{"prompt_tokens":29,"response_tokens":14,"total_tokens":43,"billed_tokens":28},"meta":{"api_version":{"version":"1"},"billed_units":{"input_tokens":14,"output_tokens":14}}},"finish_reason":"COMPLETE"}
+ TEXT
+
+ parsed_body = nil
+ structured_output = nil
+
+ EndpointMock.with_chunk_array_support do
+ stub_request(:post, "https://api.cohere.ai/v1/chat").with(
+ body:
+ proc do |req_body|
+ parsed_body = JSON.parse(req_body, symbolize_names: true)
+ true
+ end,
+ headers: {
+ "Content-Type" => "application/json",
+ "Authorization" => "Bearer ABC",
+ },
+ ).to_return(status: 200, body: body.split("|"))
+
+ result =
+ llm.generate(prompt, response_format: schema, user: user) do |partial, cancel|
+ structured_output = partial
+ end
+ end
+
+ expect(parsed_body[:preamble]).to eq("You are hello bot")
+ expect(parsed_body[:chat_history]).to eq(
+ [{ role: "USER", message: "user1: hello" }, { role: "CHATBOT", message: "hi user" }],
+ )
+ expect(parsed_body[:message]).to eq("user1: thanks")
+
+ expect(structured_output.read_latest_buffered_chunk).to eq({ key: "Hello!" })
+ end
end
diff --git a/spec/lib/completions/endpoints/gemini_spec.rb b/spec/lib/completions/endpoints/gemini_spec.rb
index 05a87b41..a4355d92 100644
--- a/spec/lib/completions/endpoints/gemini_spec.rb
+++ b/spec/lib/completions/endpoints/gemini_spec.rb
@@ -420,7 +420,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
body:
proc do |_req_body|
req_body = _req_body
- true
+ _req_body
end,
).to_return(status: 200, body: response)
@@ -433,4 +433,67 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
{ function_calling_config: { mode: "ANY", allowed_function_names: ["echo"] } },
)
end
+
+ describe "structured output via JSON Schema" do
+ it "forces the response to be a JSON" do
+ schema = {
+ type: "json_schema",
+ json_schema: {
+ name: "reply",
+ schema: {
+ type: "object",
+ properties: {
+ key: {
+ type: "string",
+ },
+ },
+ required: ["key"],
+ additionalProperties: false,
+ },
+ strict: true,
+ },
+ }
+
+ response = <<~TEXT.strip
+ data: {"candidates": [{"content": {"parts": [{"text": "{\\""}],"role": "model"}}],"usageMetadata": {"promptTokenCount": 399,"totalTokenCount": 399},"modelVersion": "gemini-1.5-pro-002"}
+
+ data: {"candidates": [{"content": {"parts": [{"text": "key"}],"role": "model"},"safetyRatings": [{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 399,"totalTokenCount": 399},"modelVersion": "gemini-1.5-pro-002"}
+
+ data: {"candidates": [{"content": {"parts": [{"text": "\\":\\""}],"role": "model"},"safetyRatings": [{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 399,"totalTokenCount": 399},"modelVersion": "gemini-1.5-pro-002"}
+
+ data: {"candidates": [{"content": {"parts": [{"text": "Hello!"}],"role": "model"},"finishReason": "STOP"}],"usageMetadata": {"promptTokenCount": 399,"candidatesTokenCount": 191,"totalTokenCount": 590},"modelVersion": "gemini-1.5-pro-002"}
+
+ data: {"candidates": [{"content": {"parts": [{"text": "\\"}"}],"role": "model"},"finishReason": "STOP"}],"usageMetadata": {"promptTokenCount": 399,"candidatesTokenCount": 191,"totalTokenCount": 590},"modelVersion": "gemini-1.5-pro-002"}
+
+ data: {"candidates": [{"finishReason": "MALFORMED_FUNCTION_CALL"}],"usageMetadata": {"promptTokenCount": 399,"candidatesTokenCount": 191,"totalTokenCount": 590},"modelVersion": "gemini-1.5-pro-002"}
+
+ TEXT
+
+ req_body = nil
+
+ llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
+ url = "#{model.url}:streamGenerateContent?alt=sse&key=123"
+
+ stub_request(:post, url).with(
+ body:
+ proc do |_req_body|
+ req_body = _req_body
+ true
+ end,
+ ).to_return(status: 200, body: response)
+
+ structured_response = nil
+ llm.generate("Hello", response_format: schema, user: user) do |partial|
+ structured_response = partial
+ end
+
+ expect(structured_response.read_latest_buffered_chunk).to eq({ key: "Hello!" })
+
+ parsed = JSON.parse(req_body, symbolize_names: true)
+
+ # Verify that schema is passed following Gemini API specs.
+ expect(parsed.dig(:generationConfig, :responseSchema)).to eq(schema)
+ expect(parsed.dig(:generationConfig, :responseMimeType)).to eq("application/json")
+ end
+ end
end
diff --git a/spec/lib/completions/structured_output_spec.rb b/spec/lib/completions/structured_output_spec.rb
new file mode 100644
index 00000000..178f5ab3
--- /dev/null
+++ b/spec/lib/completions/structured_output_spec.rb
@@ -0,0 +1,69 @@
+# frozen_string_literal: true
+
+RSpec.describe DiscourseAi::Completions::StructuredOutput do
+ subject(:structured_output) do
+ described_class.new(
+ {
+ message: {
+ type: "string",
+ },
+ bool: {
+ type: "boolean",
+ },
+ number: {
+ type: "integer",
+ },
+ status: {
+ type: "string",
+ },
+ },
+ )
+ end
+
+ describe "Parsing structured output on the fly" do
+ it "acts as a buffer for an streamed JSON" do
+ chunks = [
+ +"{\"message\": \"Line 1\\n",
+ +"Line 2\\n",
+ +"Line 3\", ",
+ +"\"bool\": true,",
+ +"\"number\": 4",
+ +"2,",
+ +"\"status\": \"o",
+ +"\\\"k\\\"\"}",
+ ]
+
+ structured_output << chunks[0]
+ expect(structured_output.read_latest_buffered_chunk).to eq({ message: "Line 1\n" })
+
+ structured_output << chunks[1]
+ expect(structured_output.read_latest_buffered_chunk).to eq({ message: "Line 2\n" })
+
+ structured_output << chunks[2]
+ expect(structured_output.read_latest_buffered_chunk).to eq({ message: "Line 3" })
+
+ structured_output << chunks[3]
+ expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true })
+
+ # Waiting for number to be fully buffered.
+ structured_output << chunks[4]
+ expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true })
+
+ structured_output << chunks[5]
+ expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true, number: 42 })
+
+ structured_output << chunks[6]
+ expect(structured_output.read_latest_buffered_chunk).to eq(
+ { bool: true, number: 42, status: "o" },
+ )
+
+ structured_output << chunks[7]
+ expect(structured_output.read_latest_buffered_chunk).to eq(
+ { bool: true, number: 42, status: "\"k\"" },
+ )
+
+ # No partial string left to read.
+ expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true, number: 42 })
+ end
+ end
+end
diff --git a/spec/lib/modules/summarization/fold_content_spec.rb b/spec/lib/modules/summarization/fold_content_spec.rb
index 7f6fafaf..d2497c70 100644
--- a/spec/lib/modules/summarization/fold_content_spec.rb
+++ b/spec/lib/modules/summarization/fold_content_spec.rb
@@ -23,17 +23,17 @@ RSpec.describe DiscourseAi::Summarization::FoldContent do
llm_model.update!(max_prompt_tokens: model_tokens)
end
- let(:single_summary) { "this is a summary" }
+ let(:summary) { "this is a summary" }
fab!(:user)
it "summarizes the content" do
result =
- DiscourseAi::Completions::Llm.with_prepared_responses([single_summary]) do |spy|
+ DiscourseAi::Completions::Llm.with_prepared_responses([summary]) do |spy|
summarizer.summarize(user).tap { expect(spy.completions).to eq(1) }
end
- expect(result.summarized_text).to eq(single_summary)
+ expect(result.summarized_text).to eq(summary)
end
end
diff --git a/spec/models/ai_persona_spec.rb b/spec/models/ai_persona_spec.rb
index 2e3a55de..c776e699 100644
--- a/spec/models/ai_persona_spec.rb
+++ b/spec/models/ai_persona_spec.rb
@@ -267,6 +267,7 @@ RSpec.describe AiPersona do
description: "system persona",
system_prompt: "system persona",
tools: %w[Search Time],
+ response_format: [{ key: "summary", type: "string" }],
system: true,
)
end
@@ -285,6 +286,22 @@ RSpec.describe AiPersona do
I18n.t("discourse_ai.ai_bot.personas.cannot_edit_system_persona"),
)
end
+
+ it "doesn't accept response format changes" do
+ new_format = [{ key: "summary2", type: "string" }]
+
+ expect { system_persona.update!(response_format: new_format) }.to raise_error(
+ ActiveRecord::RecordInvalid,
+ )
+ end
+
+ it "doesn't accept additional format changes" do
+ new_format = [{ key: "summary", type: "string" }, { key: "summary2", type: "string" }]
+
+ expect { system_persona.update!(response_format: new_format) }.to raise_error(
+ ActiveRecord::RecordInvalid,
+ )
+ end
end
end
end
diff --git a/spec/requests/admin/ai_personas_controller_spec.rb b/spec/requests/admin/ai_personas_controller_spec.rb
index 7fbb016f..0e5662f5 100644
--- a/spec/requests/admin/ai_personas_controller_spec.rb
+++ b/spec/requests/admin/ai_personas_controller_spec.rb
@@ -185,6 +185,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
default_llm_id: llm_model.id,
question_consolidator_llm_id: llm_model.id,
forced_tool_count: 2,
+ response_format: [{ key: "summary", type: "string" }],
}
end
@@ -209,6 +210,9 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
expect(persona_json["allow_chat_channel_mentions"]).to eq(true)
expect(persona_json["allow_chat_direct_messages"]).to eq(true)
expect(persona_json["question_consolidator_llm_id"]).to eq(llm_model.id)
+ expect(persona_json["response_format"].map { |rf| rf["key"] }).to contain_exactly(
+ "summary",
+ )
persona = AiPersona.find(persona_json["id"])
diff --git a/spec/requests/summarization/summary_controller_spec.rb b/spec/requests/summarization/summary_controller_spec.rb
index 01ab28bd..5c051f3d 100644
--- a/spec/requests/summarization/summary_controller_spec.rb
+++ b/spec/requests/summarization/summary_controller_spec.rb
@@ -74,6 +74,7 @@ RSpec.describe DiscourseAi::Summarization::SummaryController do
it "returns a summary" do
summary_text = "This is a summary"
+
DiscourseAi::Completions::Llm.with_prepared_responses([summary_text]) do
get "/discourse-ai/summarization/t/#{topic.id}.json"
diff --git a/spec/services/discourse_ai/topic_summarization_spec.rb b/spec/services/discourse_ai/topic_summarization_spec.rb
index a75dd527..444cfc67 100644
--- a/spec/services/discourse_ai/topic_summarization_spec.rb
+++ b/spec/services/discourse_ai/topic_summarization_spec.rb
@@ -13,6 +13,8 @@ describe DiscourseAi::TopicSummarization do
let(:strategy) { DiscourseAi::Summarization.topic_summary(topic) }
+ let(:summary) { "This is the final summary" }
+
describe "#summarize" do
subject(:summarization) { described_class.new(strategy, user) }
@@ -27,8 +29,6 @@ describe DiscourseAi::TopicSummarization do
end
context "when the content was summarized in a single chunk" do
- let(:summary) { "This is the final summary" }
-
it "caches the summary" do
DiscourseAi::Completions::Llm.with_prepared_responses([summary]) do
section = summarization.summarize
@@ -54,7 +54,6 @@ describe DiscourseAi::TopicSummarization do
describe "invalidating cached summaries" do
let(:cached_text) { "This is a cached summary" }
- let(:updated_summary) { "This is the final summary" }
def cached_summary
AiSummary.find_by(target: topic, summary_type: AiSummary.summary_types[:complete])
@@ -86,10 +85,10 @@ describe DiscourseAi::TopicSummarization do
before { cached_summary.update!(original_content_sha: "outdated_sha") }
it "returns a new summary" do
- DiscourseAi::Completions::Llm.with_prepared_responses([updated_summary]) do
+ DiscourseAi::Completions::Llm.with_prepared_responses([summary]) do
section = summarization.summarize
- expect(section.summarized_text).to eq(updated_summary)
+ expect(section.summarized_text).to eq(summary)
end
end
@@ -106,10 +105,10 @@ describe DiscourseAi::TopicSummarization do
end
it "returns a new summary if the skip_age_check flag is passed" do
- DiscourseAi::Completions::Llm.with_prepared_responses([updated_summary]) do
+ DiscourseAi::Completions::Llm.with_prepared_responses([summary]) do
section = summarization.summarize(skip_age_check: true)
- expect(section.summarized_text).to eq(updated_summary)
+ expect(section.summarized_text).to eq(summary)
end
end
end
@@ -118,8 +117,6 @@ describe DiscourseAi::TopicSummarization do
end
describe "stream partial updates" do
- let(:summary) { "This is the final summary" }
-
it "receives a blk that is passed to the underlying strategy and called with partial summaries" do
partial_result = +""
@@ -127,7 +124,8 @@ describe DiscourseAi::TopicSummarization do
summarization.summarize { |partial_summary| partial_result << partial_summary }
end
- expect(partial_result).to eq(summary)
+ # In a real-world scenario, this trailing fragment is removed from the returned AiSummary object.
+ expect(partial_result.chomp("\"}")).to eq(summary)
end
end
end
diff --git a/spec/system/summarization/topic_summarization_spec.rb b/spec/system/summarization/topic_summarization_spec.rb
index 8a23772b..30b147a0 100644
--- a/spec/system/summarization/topic_summarization_spec.rb
+++ b/spec/system/summarization/topic_summarization_spec.rb
@@ -12,7 +12,9 @@ RSpec.describe "Summarize a topic ", type: :system do
"I like to eat pie. It is a very good dessert. Some people are wasteful by throwing pie at others but I do not do that. I always eat the pie.",
)
end
+
let(:summarization_result) { "This is a summary" }
+
let(:topic_page) { PageObjects::Pages::Topic.new }
let(:summary_box) { PageObjects::Components::AiSummaryTrigger.new }