diff --git a/app/controllers/discourse_ai/admin/ai_personas_controller.rb b/app/controllers/discourse_ai/admin/ai_personas_controller.rb index ca3059fc..58a61e0e 100644 --- a/app/controllers/discourse_ai/admin/ai_personas_controller.rb +++ b/app/controllers/discourse_ai/admin/ai_personas_controller.rb @@ -221,6 +221,10 @@ module DiscourseAi permitted[:tools] = permit_tools(tools) end + if response_format = params.dig(:ai_persona, :response_format) + permitted[:response_format] = permit_response_format(response_format) + end + permitted end @@ -235,6 +239,18 @@ module DiscourseAi [tool, options, !!force_tool] end end + + def permit_response_format(response_format) + return [] if !response_format.is_a?(Array) + + response_format.map do |element| + if element && element.is_a?(ActionController::Parameters) + element.permit! + else + false + end + end + end end end end diff --git a/app/models/ai_persona.rb b/app/models/ai_persona.rb index e654cffc..8347a414 100644 --- a/app/models/ai_persona.rb +++ b/app/models/ai_persona.rb @@ -325,9 +325,11 @@ class AiPersona < ActiveRecord::Base end def system_persona_unchangeable + error_msg = I18n.t("discourse_ai.ai_bot.personas.cannot_edit_system_persona") + if top_p_changed? || temperature_changed? || system_prompt_changed? || name_changed? || description_changed? - errors.add(:base, I18n.t("discourse_ai.ai_bot.personas.cannot_edit_system_persona")) + errors.add(:base, error_msg) elsif tools_changed? old_tools = tools_change[0] new_tools = tools_change[1] @@ -335,9 +337,12 @@ class AiPersona < ActiveRecord::Base old_tool_names = old_tools.map { |t| t.is_a?(Array) ? t[0] : t }.to_set new_tool_names = new_tools.map { |t| t.is_a?(Array) ? t[0] : t }.to_set - if old_tool_names != new_tool_names - errors.add(:base, I18n.t("discourse_ai.ai_bot.personas.cannot_edit_system_persona")) - end + errors.add(:base, error_msg) if old_tool_names != new_tool_names + elsif response_format_changed? 
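+        # System personas may only keep their existing response_format keys;
+        # other attributes of each entry may still change.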
+      old_format = response_format_change[0].map { |f| f["key"] }.to_set
+      new_format = response_format_change[1].map { |f| f["key"] }.to_set
+
+      errors.add(:base, error_msg) if old_format != new_format
     end
   end

@@ -395,6 +400,7 @@ end
 #  rag_llm_model_id             :bigint
 #  default_llm_id               :bigint
 #  question_consolidator_llm_id :bigint
+#  response_format              :jsonb
 #
 # Indexes
 #
diff --git a/app/serializers/localized_ai_persona_serializer.rb b/app/serializers/localized_ai_persona_serializer.rb
index dde41dfb..57945e29 100644
--- a/app/serializers/localized_ai_persona_serializer.rb
+++ b/app/serializers/localized_ai_persona_serializer.rb
@@ -30,7 +30,8 @@ class LocalizedAiPersonaSerializer < ApplicationSerializer
                :allow_chat_direct_messages,
                :allow_topic_mentions,
                :allow_personal_messages,
-               :force_default_llm
+               :force_default_llm,
+               :response_format

  has_one :user, serializer: BasicUserSerializer, embed: :object
  has_many :rag_uploads, serializer: UploadSerializer, embed: :object
diff --git a/assets/javascripts/discourse/admin/models/ai-persona.js b/assets/javascripts/discourse/admin/models/ai-persona.js
index 4c7ff28e..f313dee5 100644
--- a/assets/javascripts/discourse/admin/models/ai-persona.js
+++ b/assets/javascripts/discourse/admin/models/ai-persona.js
@@ -33,6 +33,7 @@ const CREATE_ATTRIBUTES = [
   "allow_topic_mentions",
   "allow_chat_channel_mentions",
   "allow_chat_direct_messages",
+  "response_format",
 ];

 const SYSTEM_ATTRIBUTES = [
@@ -60,6 +61,7 @@ const SYSTEM_ATTRIBUTES = [
   "allow_topic_mentions",
   "allow_chat_channel_mentions",
   "allow_chat_direct_messages",
+  "response_format",
 ];

 export default class AiPersona extends RestModel {
@@ -151,6 +153,7 @@ export default class AiPersona extends RestModel {
     const attrs = this.getProperties(CREATE_ATTRIBUTES);
     this.populateTools(attrs);
     attrs.forced_tool_count = this.forced_tool_count || -1;
+    attrs.response_format = attrs.response_format || [];

     return attrs;
   }
diff --git a/assets/javascripts/discourse/components/ai-persona-editor.gjs b/assets/javascripts/discourse/components/ai-persona-editor.gjs
index edaf13c8..f0b56d8a 100644
--- a/assets/javascripts/discourse/components/ai-persona-editor.gjs
+++ b/assets/javascripts/discourse/components/ai-persona-editor.gjs
@@ -15,6 +15,7 @@ import Group from "discourse/models/group";
 import { i18n } from "discourse-i18n";
 import AdminUser from "admin/models/admin-user";
 import GroupChooser from "select-kit/components/group-chooser";
+import AiPersonaResponseFormatEditor from "../components/modal/ai-persona-response-format-editor";
 import AiLlmSelector from "./ai-llm-selector";
 import AiPersonaToolOptions from "./ai-persona-tool-options";
 import AiToolSelector from "./ai-tool-selector";
@@ -325,6 +326,8 @@ export default class PersonaEditor extends Component {
[template additions lost in extraction; this hunk renders the new <AiPersonaResponseFormatEditor> inside the editor form]
diff --git a/assets/javascripts/discourse/components/modal/ai-persona-response-format-editor.gjs b/assets/javascripts/discourse/components/modal/ai-persona-response-format-editor.gjs
new file mode 100644
[the diff header index, imports, template, and the opening of the JSON-preview getter were lost in extraction; the surviving body of the new component follows]
+      toDisplay[keyDesc.key] = keyDesc.type;
+    });
+
+    return prettyJSON(toDisplay);
+  }
+
+  @action
+  openModal() {
+    this.showJsonEditorModal = true;
+  }
+
+  @action
+  closeModal() {
+    this.showJsonEditorModal = false;
+  }
+
+  @action
+  updateResponseFormat(form, value) {
+    form.set("response_format", JSON.parse(value));
+  }
+}
diff --git a/assets/stylesheets/modules/ai-bot/common/ai-persona.scss b/assets/stylesheets/modules/ai-bot/common/ai-persona.scss
index 39eeeb5f..a5fdfe5b 100644
--- a/assets/stylesheets/modules/ai-bot/common/ai-persona.scss
+++ b/assets/stylesheets/modules/ai-bot/common/ai-persona.scss
@@ -52,6 +52,21 @@
     margin-bottom: 10px;
     font-size: var(--font-down-1);
   }
+
+  &__response-format {
+    width: 100%;
+    display: block;
+  }
+
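+  // The preview relies on prettyJSON output; the pre-line rule that follows
+  // keeps its newlines without preserving indentation.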
&__response-format-pre { + margin-bottom: 0; + white-space: pre-line; + } + + &__response-format-none { + margin-bottom: 1em; + margin-top: 0.5em; + } } .rag-options { diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml index 3a4fb037..0c2c98d2 100644 --- a/config/locales/client.en.yml +++ b/config/locales/client.en.yml @@ -323,6 +323,13 @@ en: rag_conversation_chunks: "Search conversation chunks" rag_conversation_chunks_help: "The number of chunks to use for the RAG model searches. Increase to increase the amount of context the AI can use." persona_description: "Personas are a powerful feature that allows you to customize the behavior of the AI engine in your Discourse forum. They act as a 'system message' that guides the AI's responses and interactions, helping to create a more personalized and engaging user experience." + response_format: + title: "JSON response format" + no_format: "No JSON format specified" + open_modal: "Edit" + modal: + root_title: "Response structure" + key_title: "Key" list: enabled: "AI Bot?" diff --git a/db/fixtures/personas/603_ai_personas.rb b/db/fixtures/personas/603_ai_personas.rb index b0d4854d..7d52e8a9 100644 --- a/db/fixtures/personas/603_ai_personas.rb +++ b/db/fixtures/personas/603_ai_personas.rb @@ -72,6 +72,8 @@ DiscourseAi::Personas::Persona.system_personas.each do |persona_class, id| persona.tools = tools.map { |name, value| [name, value] } + persona.response_format = instance.response_format + persona.system_prompt = instance.system_prompt persona.top_p = instance.top_p persona.temperature = instance.temperature diff --git a/db/migrate/20250411121705_add_response_format_json_to_personass.rb b/db/migrate/20250411121705_add_response_format_json_to_personass.rb new file mode 100644 index 00000000..2f5a3e10 --- /dev/null +++ b/db/migrate/20250411121705_add_response_format_json_to_personass.rb @@ -0,0 +1,6 @@ +# frozen_string_literal: true +class AddResponseFormatJsonToPersonass < ActiveRecord::Migration[7.2] + def change + add_column :ai_personas, :response_format, :jsonb + end +end diff --git a/lib/completions/anthropic_message_processor.rb b/lib/completions/anthropic_message_processor.rb index e67be109..a8ba9b4a 100644 --- a/lib/completions/anthropic_message_processor.rb +++ b/lib/completions/anthropic_message_processor.rb @@ -10,7 +10,7 @@ class DiscourseAi::Completions::AnthropicMessageProcessor @raw_json = +"" @tool_call = DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: {}) @streaming_parser = - DiscourseAi::Completions::ToolCallProgressTracker.new(self) if partial_tool_calls + DiscourseAi::Completions::JsonStreamingTracker.new(self) if partial_tool_calls end def append(json) diff --git a/lib/completions/dialects/nova.rb b/lib/completions/dialects/nova.rb index b078e79d..9dc88097 100644 --- a/lib/completions/dialects/nova.rb +++ b/lib/completions/dialects/nova.rb @@ -42,6 +42,7 @@ module DiscourseAi result = { system: system, messages: messages } result[:inferenceConfig] = inference_config if inference_config.present? result[:toolConfig] = tool_config if tool_config.present? + result[:response_format] = { type: "json_object" } if options[:response_format].present? 
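+          # Nova is only asked for generic JSON output here; the schema itself is
+          # not forwarded, unlike the Gemini path in this change.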
result end diff --git a/lib/completions/endpoints/anthropic.rb b/lib/completions/endpoints/anthropic.rb index dd44a82b..4155db2a 100644 --- a/lib/completions/endpoints/anthropic.rb +++ b/lib/completions/endpoints/anthropic.rb @@ -88,10 +88,16 @@ module DiscourseAi def prepare_payload(prompt, model_params, dialect) @native_tool_support = dialect.native_tool_support? - payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages) + payload = + default_options(dialect).merge(model_params.except(:response_format)).merge( + messages: prompt.messages, + ) payload[:system] = prompt.system_prompt if prompt.system_prompt.present? payload[:stream] = true if @streaming_mode + + prefilled_message = +"" + if prompt.has_tools? payload[:tools] = prompt.tools if dialect.tool_choice.present? @@ -100,16 +106,24 @@ module DiscourseAi # prefill prompt to nudge LLM to generate a response that is useful. # without this LLM (even 3.7) can get confused and start text preambles for a tool calls. - payload[:messages] << { - role: "assistant", - content: dialect.no_more_tool_calls_text, - } + prefilled_message << dialect.no_more_tool_calls_text else payload[:tool_choice] = { type: "tool", name: prompt.tool_choice } end end end + # Prefill prompt to force JSON output. + if model_params[:response_format].present? + prefilled_message << " " if !prefilled_message.empty? + prefilled_message << "{" + @forced_json_through_prefill = true + end + + if !prefilled_message.empty? + payload[:messages] << { role: "assistant", content: prefilled_message } + end + payload end diff --git a/lib/completions/endpoints/aws_bedrock.rb b/lib/completions/endpoints/aws_bedrock.rb index 915e9d3b..f4336ac1 100644 --- a/lib/completions/endpoints/aws_bedrock.rb +++ b/lib/completions/endpoints/aws_bedrock.rb @@ -116,9 +116,14 @@ module DiscourseAi payload = nil if dialect.is_a?(DiscourseAi::Completions::Dialects::Claude) - payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages) + payload = + default_options(dialect).merge(model_params.except(:response_format)).merge( + messages: prompt.messages, + ) payload[:system] = prompt.system_prompt if prompt.system_prompt.present? + prefilled_message = +"" + if prompt.has_tools? payload[:tools] = prompt.tools if dialect.tool_choice.present? @@ -128,15 +133,23 @@ module DiscourseAi # payload[:tool_choice] = { type: "none" } # prefill prompt to nudge LLM to generate a response that is useful, instead of trying to call a tool - payload[:messages] << { - role: "assistant", - content: dialect.no_more_tool_calls_text, - } + prefilled_message << dialect.no_more_tool_calls_text else payload[:tool_choice] = { type: "tool", name: prompt.tool_choice } end end end + + # Prefill prompt to force JSON output. + if model_params[:response_format].present? + prefilled_message << " " if !prefilled_message.empty? + prefilled_message << "{" + @forced_json_through_prefill = true + end + + if !prefilled_message.empty? 
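+            # A trailing assistant message acts as a prefill Claude must continue,
+            # so when forcing JSON the reply begins after the opening "{".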
+ payload[:messages] << { role: "assistant", content: prefilled_message } + end elsif dialect.is_a?(DiscourseAi::Completions::Dialects::Nova) payload = prompt.to_payload(default_options(dialect).merge(model_params)) else diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb index 800381e4..7f79a819 100644 --- a/lib/completions/endpoints/base.rb +++ b/lib/completions/endpoints/base.rb @@ -73,6 +73,7 @@ module DiscourseAi LlmQuota.check_quotas!(@llm_model, user) start_time = Time.now + @forced_json_through_prefill = false @partial_tool_calls = partial_tool_calls @output_thinking = output_thinking @@ -106,6 +107,18 @@ module DiscourseAi prompt = dialect.translate + structured_output = nil + + if model_params[:response_format].present? + schema_properties = + model_params[:response_format].dig(:json_schema, :schema, :properties) + + if schema_properties.present? + structured_output = + DiscourseAi::Completions::StructuredOutput.new(schema_properties) + end + end + FinalDestination::HTTP.start( model_uri.host, model_uri.port, @@ -123,6 +136,10 @@ module DiscourseAi request = prepare_request(request_body) + # Some providers rely on prefill to return structured outputs, so the start + # of the JSON won't be included in the response. Supply it to keep JSON valid. + structured_output << +"{" if structured_output && @forced_json_through_prefill + http.request(request) do |response| if response.code.to_i != 200 Rails.logger.error( @@ -140,10 +157,17 @@ module DiscourseAi xml_stripper = DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present? - if @streaming_mode && xml_stripper + if @streaming_mode blk = lambda do |partial, cancel| - partial = xml_stripper << partial if partial.is_a?(String) + if partial.is_a?(String) + partial = xml_stripper << partial if xml_stripper + + if structured_output.present? + structured_output << partial + partial = structured_output + end + end orig_blk.call(partial, cancel) if partial end end @@ -167,6 +191,7 @@ module DiscourseAi xml_stripper: xml_stripper, partials_raw: partials_raw, response_raw: response_raw, + structured_output: structured_output, ) return response_data end @@ -373,7 +398,8 @@ module DiscourseAi xml_tool_processor:, xml_stripper:, partials_raw:, - response_raw: + response_raw:, + structured_output: ) response_raw << response.read_body response_data = decode(response_raw) @@ -403,6 +429,26 @@ module DiscourseAi response_data.reject!(&:blank?) + if structured_output.present? + has_string_response = false + + response_data = + response_data.reduce([]) do |memo, data| + if data.is_a?(String) + structured_output << data + has_string_response = true + next(memo) + else + memo << data + end + + memo + end + + # We only include the structured output if there was actually a structured response + response_data << structured_output if has_string_response + end + # this is to keep stuff backwards compatible response_data = response_data.first if response_data.length == 1 diff --git a/lib/completions/endpoints/canned_response.rb b/lib/completions/endpoints/canned_response.rb index a7ccc1da..b4dafc74 100644 --- a/lib/completions/endpoints/canned_response.rb +++ b/lib/completions/endpoints/canned_response.rb @@ -40,6 +40,8 @@ module DiscourseAi "The number of completions you requested exceed the number of canned responses" end + response = as_structured_output(response) if model_params[:response_format].present? 
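+        # Wraps the canned string in a StructuredOutput buffer so tests exercise
+        # the same interface real endpoints return when response_format is set.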
+ raise response if response.is_a?(StandardError) @completions += 1 @@ -54,6 +56,8 @@ module DiscourseAi yield(response, cancel_fn) elsif is_thinking?(response) yield(response, cancel_fn) + elsif is_structured_output?(response) + yield(response, cancel_fn) else response.each_char do |char| break if cancelled @@ -80,6 +84,20 @@ module DiscourseAi def is_tool?(response) response.is_a?(DiscourseAi::Completions::ToolCall) end + + def is_structured_output?(response) + response.is_a?(DiscourseAi::Completions::StructuredOutput) + end + + def as_structured_output(response) + schema_properties = model_params[:response_format].dig(:json_schema, :schema, :properties) + return response if schema_properties.blank? + + output = DiscourseAi::Completions::StructuredOutput.new(schema_properties) + output << { schema_properties.keys.first => response }.to_json + + output + end end end end diff --git a/lib/completions/endpoints/gemini.rb b/lib/completions/endpoints/gemini.rb index a99bb80b..e52fd46f 100644 --- a/lib/completions/endpoints/gemini.rb +++ b/lib/completions/endpoints/gemini.rb @@ -84,7 +84,16 @@ module DiscourseAi payload[:tool_config] = { function_calling_config: function_calling_config } end - payload[:generationConfig].merge!(model_params) if model_params.present? + if model_params.present? + payload[:generationConfig].merge!(model_params.except(:response_format)) + + if model_params[:response_format].present? + # https://ai.google.dev/api/generate-content#generationconfig + payload[:generationConfig][:responseSchema] = model_params[:response_format] + payload[:generationConfig][:responseMimeType] = "application/json" + end + end + payload end diff --git a/lib/completions/endpoints/samba_nova.rb b/lib/completions/endpoints/samba_nova.rb index cc81e786..9e6b3817 100644 --- a/lib/completions/endpoints/samba_nova.rb +++ b/lib/completions/endpoints/samba_nova.rb @@ -34,7 +34,12 @@ module DiscourseAi end def prepare_payload(prompt, model_params, dialect) - payload = default_options.merge(model_params).merge(messages: prompt) + payload = + default_options.merge(model_params.except(:response_format)).merge(messages: prompt) + + if model_params[:response_format].present? 
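+          # Only the generic JSON-mode flag is sent; the schema itself was
+          # stripped from model_params via except(:response_format) above.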
+            payload[:response_format] = { type: "json_object" }
+          end

          payload[:stream] = true if @streaming_mode
diff --git a/lib/completions/tool_call_progress_tracker.rb b/lib/completions/json_streaming_tracker.rb
similarity index 74%
rename from lib/completions/tool_call_progress_tracker.rb
rename to lib/completions/json_streaming_tracker.rb
index 0f6d9158..aa687ef1 100644
--- a/lib/completions/tool_call_progress_tracker.rb
+++ b/lib/completions/json_streaming_tracker.rb
@@ -2,11 +2,11 @@

 module DiscourseAi
   module Completions
-    class ToolCallProgressTracker
-      attr_reader :current_key, :current_value, :tool_call
+    class JsonStreamingTracker
+      attr_reader :current_key, :current_value, :stream_consumer

-      def initialize(tool_call)
-        @tool_call = tool_call
+      def initialize(stream_consumer)
+        @stream_consumer = stream_consumer
         @current_key = nil
         @current_value = nil
         @parser = DiscourseAi::Completions::JsonStreamingParser.new
@@ -18,7 +18,7 @@ module DiscourseAi
        @parser.value do |v|
          if @current_key
-            tool_call.notify_progress(@current_key, v)
+            stream_consumer.notify_progress(@current_key, v)
            @current_key = nil
          end
        end
@@ -39,7 +39,7 @@ module DiscourseAi

        if @parser.state == :start_string && @current_key
          # this is worth notifying
-          tool_call.notify_progress(@current_key, @parser.buf)
+          stream_consumer.notify_progress(@current_key, @parser.buf)
        end

        @current_key = nil if @parser.state == :end_value
diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb
index 5f82667e..f90de091 100644
--- a/lib/completions/llm.rb
+++ b/lib/completions/llm.rb
@@ -304,6 +304,8 @@ module DiscourseAi
    #   @param feature_context { Hash - Optional } - The feature context to use for the completion.
    #   @param partial_tool_calls { Boolean - Optional } - If true, the completion will return partial tool calls.
    #   @param output_thinking { Boolean - Optional } - If true, the completion will return the thinking output for thinking models.
+    #   @param response_format { Hash - Optional } - JSON schema passed to the API as the desired structured output.
+    #   @param [Experimental] extra_model_params { Hash - Optional } - Other params that are not available across models, e.g. a response_format JSON schema.
    #
    #   @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response alongside a cancel function.
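+    #
+    #   Example:
+    #     llm.generate(prompt, user: user, response_format: schema) do |partial, cancel|
+    #       # partial is a DiscourseAi::Completions::StructuredOutput when response_format is set
+    #       partial.read_latest_buffered_chunk
+    #     end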
# @@ -321,6 +323,7 @@ module DiscourseAi feature_context: nil, partial_tool_calls: false, output_thinking: false, + response_format: nil, extra_model_params: nil, &partial_read_blk ) @@ -336,6 +339,7 @@ module DiscourseAi feature_context: feature_context, partial_tool_calls: partial_tool_calls, output_thinking: output_thinking, + response_format: response_format, extra_model_params: extra_model_params, }, ) @@ -344,6 +348,7 @@ module DiscourseAi model_params[:temperature] = temperature if temperature model_params[:top_p] = top_p if top_p + model_params[:response_format] = response_format if response_format model_params.merge!(extra_model_params) if extra_model_params if prompt.is_a?(String) diff --git a/lib/completions/nova_message_processor.rb b/lib/completions/nova_message_processor.rb index efe54330..80710c9c 100644 --- a/lib/completions/nova_message_processor.rb +++ b/lib/completions/nova_message_processor.rb @@ -10,7 +10,7 @@ class DiscourseAi::Completions::NovaMessageProcessor @raw_json = +"" @tool_call = DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: {}) @streaming_parser = - DiscourseAi::Completions::ToolCallProgressTracker.new(self) if partial_tool_calls + DiscourseAi::Completions::JsonStreamingTracker.new(self) if partial_tool_calls end def append(json) diff --git a/lib/completions/open_ai_message_processor.rb b/lib/completions/open_ai_message_processor.rb index 33182995..979c2043 100644 --- a/lib/completions/open_ai_message_processor.rb +++ b/lib/completions/open_ai_message_processor.rb @@ -59,7 +59,7 @@ module DiscourseAi::Completions if id.present? && name.present? @tool_arguments = +"" @tool = ToolCall.new(id: id, name: name) - @streaming_parser = ToolCallProgressTracker.new(self) if @partial_tool_calls + @streaming_parser = JsonStreamingTracker.new(self) if @partial_tool_calls end @tool_arguments << arguments.to_s diff --git a/lib/completions/structured_output.rb b/lib/completions/structured_output.rb new file mode 100644 index 00000000..fadf5722 --- /dev/null +++ b/lib/completions/structured_output.rb @@ -0,0 +1,52 @@ +# frozen_string_literal: true + +module DiscourseAi + module Completions + class StructuredOutput + def initialize(json_schema_properties) + @property_names = json_schema_properties.keys.map(&:to_sym) + @property_cursors = + json_schema_properties.reduce({}) do |m, (k, prop)| + m[k.to_sym] = 0 if prop[:type] == "string" + m + end + + @tracked = {} + + @partial_json_tracker = JsonStreamingTracker.new(self) + end + + attr_reader :last_chunk_buffer + + def <<(raw) + @partial_json_tracker << raw + end + + def read_latest_buffered_chunk + @property_names.reduce({}) do |memo, pn| + if @tracked[pn].present? + # This means this property is a string and we want to return unread chunks. + if @property_cursors[pn].present? + unread = @tracked[pn][@property_cursors[pn]..] + + memo[pn] = unread if unread.present? + @property_cursors[pn] = @tracked[pn].length + else + # Ints and bools are always returned as is. 
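+            # Their values are stored whole by notify_progress and returned on
+            # every read; only strings are consumed via the cursor.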
+            memo[pn] = @tracked[pn]
+          end
+        end
+
+        memo
+      end
+    end
+
+    def notify_progress(key, value)
+      key_sym = key.to_sym
+      return if !@property_names.include?(key_sym)
+
+      @tracked[key_sym] = value
+    end
+  end
+end
diff --git a/lib/personas/bot.rb b/lib/personas/bot.rb
index aa701abc..2c9b5a3b 100644
--- a/lib/personas/bot.rb
+++ b/lib/personas/bot.rb
@@ -64,10 +64,13 @@ module DiscourseAi

      user = context.user

-      llm_kwargs = { user: user }
+      llm_kwargs = llm_args.dup
+      llm_kwargs[:user] = user
      llm_kwargs[:temperature] = persona.temperature if persona.temperature
      llm_kwargs[:top_p] = persona.top_p if persona.top_p
-      llm_kwargs[:max_tokens] = llm_args[:max_tokens] if llm_args[:max_tokens].present?
+      llm_kwargs[:response_format] = build_json_schema(
+        persona.response_format,
+      ) if persona.response_format.present?

      needs_newlines = false
      tools_ran = 0
@@ -148,6 +151,8 @@ module DiscourseAi
              raw_context << partial
              current_thinking << partial
            end
+          elsif partial.is_a?(DiscourseAi::Completions::StructuredOutput)
+            update_blk.call(partial, cancel, nil, :structured_output)
          else
            update_blk.call(partial, cancel)
          end
@@ -176,6 +181,10 @@ module DiscourseAi
      embed_thinking(raw_context)
    end

+    def returns_json?
+      persona.response_format.present?
+    end
+
    private

    def embed_thinking(raw_context)
@@ -301,6 +310,30 @@ module DiscourseAi

      placeholder
    end
+
+    def build_json_schema(response_format)
+      properties =
+        response_format
+          .to_a
+          .reduce({}) do |memo, format|
+            memo[format[:key].to_sym] = { type: format[:type] }
+            memo
+          end
+
+      {
+        type: "json_schema",
+        json_schema: {
+          name: "reply",
+          schema: {
+            type: "object",
+            properties: properties,
+            required: properties.keys.map(&:to_s),
+            additionalProperties: false,
+          },
+          strict: true,
+        },
+      }
+    end
  end
 end
end
diff --git a/lib/personas/bot_context.rb b/lib/personas/bot_context.rb
index 94a67010..5f7dd99e 100644
--- a/lib/personas/bot_context.rb
+++ b/lib/personas/bot_context.rb
@@ -49,6 +49,7 @@ module DiscourseAi
      @site_title = site_title
      @site_description = site_description
      @time = time
+      @feature_name = feature_name

      @resource_url = resource_url
diff --git a/lib/personas/persona.rb b/lib/personas/persona.rb
index 3332e5a9..a8b08785 100644
--- a/lib/personas/persona.rb
+++ b/lib/personas/persona.rb
@@ -160,6 +160,10 @@ module DiscourseAi
        {}
      end

+      def response_format
+        nil
+      end
+
      def available_tools
        self
          .class
diff --git a/lib/personas/short_summarizer.rb b/lib/personas/short_summarizer.rb
index d460d6c0..e7cef54a 100644
--- a/lib/personas/short_summarizer.rb
+++ b/lib/personas/short_summarizer.rb
@@ -18,9 +18,19 @@ module DiscourseAi
          - Limit the summary to a maximum of 40 words.
          - Do *NOT* repeat the discussion title in the summary.
-          - Return the summary inside <ai> tags.
+          Format your response as a JSON object with a single key named "summary", which has the summary as the value.
+          Your output should be in the following format:
+          <output>
+            {"summary": "xx"}
+          </output>
+
+          Where "xx" is replaced by the summary.
        PROMPT
      end
+
+      def response_format
+        [{ key: "summary", type: "string" }]
+      end
    end
  end
 end
diff --git a/lib/personas/summarizer.rb b/lib/personas/summarizer.rb
index a2f81463..64540ff0 100644
--- a/lib/personas/summarizer.rb
+++ b/lib/personas/summarizer.rb
@@ -18,8 +18,20 @@ module DiscourseAi
          - Example: link to the 6th post by jane: [agreed with]({resource_url}/6)
          - Example: link to the 13th post by joe: [joe]({resource_url}/13)
          - When formatting usernames either use @USERNAME OR [USERNAME]({resource_url}/POST_NUMBER)
+
+          Format your response as a JSON object with a single key named "summary", which has the summary as the value.
+          Your output should be in the following format:
+          <output>
+            {"summary": "xx"}
+          </output>
+
+          Where "xx" is replaced by the summary.
        PROMPT
      end
+
+      def response_format
+        [{ key: "summary", type: "string" }]
+      end
    end
  end
 end
diff --git a/lib/summarization/fold_content.rb b/lib/summarization/fold_content.rb
index 6fcd1876..15800087 100644
--- a/lib/summarization/fold_content.rb
+++ b/lib/summarization/fold_content.rb
@@ -29,18 +29,10 @@ module DiscourseAi

      summary = fold(truncated_content, user, &on_partial_blk)

-      clean_summary = Nokogiri::HTML5.fragment(summary).css("ai")&.first&.text || summary
-
      if persist_summaries
-        AiSummary.store!(
-          strategy,
-          llm_model,
-          clean_summary,
-          truncated_content,
-          human: user&.human?,
-        )
+        AiSummary.store!(strategy, llm_model, summary, truncated_content, human: user&.human?)
      else
-        AiSummary.new(summarized_text: clean_summary)
+        AiSummary.new(summarized_text: summary)
      end
    end

@@ -118,9 +110,20 @@ module DiscourseAi
      )

      summary = +""
+
      buffer_blk =
-        Proc.new do |partial, cancel, placeholder, type|
-          if type.blank?
+        Proc.new do |partial, cancel, _, type|
+          if type == :structured_output
+            json_summary_schema_key = bot.persona.response_format&.first.to_h
+            partial_summary =
+              partial.read_latest_buffered_chunk[json_summary_schema_key[:key].to_sym]
+
+            if partial_summary.present?
+              summary << partial_summary
+              on_partial_blk.call(partial_summary, cancel) if on_partial_blk
+            end
+          elsif type.blank?
+            # Assume the response is a regular completion.
            summary << partial
            on_partial_blk.call(partial, cancel) if on_partial_blk
          end
@@ -154,12 +157,6 @@ module DiscourseAi

        item
      end
-
-      def text_only_update(&on_partial_blk)
-        Proc.new do |partial, cancel, placeholder, type|
-          on_partial_blk.call(partial, cancel) if type.blank?
- end - end end end end diff --git a/spec/jobs/regular/fast_track_topic_gist_spec.rb b/spec/jobs/regular/fast_track_topic_gist_spec.rb index 3eccc006..ef7dbc47 100644 --- a/spec/jobs/regular/fast_track_topic_gist_spec.rb +++ b/spec/jobs/regular/fast_track_topic_gist_spec.rb @@ -21,6 +21,7 @@ RSpec.describe Jobs::FastTrackTopicGist do created_at: 10.minutes.ago, ) end + let(:updated_gist) { "They updated me :(" } context "when it's up to date" do diff --git a/spec/jobs/regular/stream_topic_ai_summary_spec.rb b/spec/jobs/regular/stream_topic_ai_summary_spec.rb index 87848560..1591e701 100644 --- a/spec/jobs/regular/stream_topic_ai_summary_spec.rb +++ b/spec/jobs/regular/stream_topic_ai_summary_spec.rb @@ -51,7 +51,9 @@ RSpec.describe Jobs::StreamTopicAiSummary do end it "publishes updates with a partial summary" do - with_responses(["dummy"]) do + summary = "dummy" + + with_responses([summary]) do messages = MessageBus.track_publish("/discourse-ai/summaries/topic/#{topic.id}") do job.execute(topic_id: topic.id, user_id: user.id) @@ -59,12 +61,16 @@ RSpec.describe Jobs::StreamTopicAiSummary do partial_summary_update = messages.first.data expect(partial_summary_update[:done]).to eq(false) - expect(partial_summary_update.dig(:ai_topic_summary, :summarized_text)).to eq("dummy") + expect(partial_summary_update.dig(:ai_topic_summary, :summarized_text).chomp("\"}")).to eq( + summary, + ) end end it "publishes a final update to signal we're done and provide metadata" do - with_responses(["dummy"]) do + summary = "dummy" + + with_responses([summary]) do messages = MessageBus.track_publish("/discourse-ai/summaries/topic/#{topic.id}") do job.execute(topic_id: topic.id, user_id: user.id) @@ -73,6 +79,7 @@ RSpec.describe Jobs::StreamTopicAiSummary do final_update = messages.last.data expect(final_update[:done]).to eq(true) + expect(final_update.dig(:ai_topic_summary, :summarized_text)).to eq(summary) expect(final_update.dig(:ai_topic_summary, :algorithm)).to eq("fake") expect(final_update.dig(:ai_topic_summary, :outdated)).to eq(false) expect(final_update.dig(:ai_topic_summary, :can_regenerate)).to eq(true) diff --git a/spec/lib/completions/endpoints/anthropic_spec.rb b/spec/lib/completions/endpoints/anthropic_spec.rb index ca3bcb46..b07586e0 100644 --- a/spec/lib/completions/endpoints/anthropic_spec.rb +++ b/spec/lib/completions/endpoints/anthropic_spec.rb @@ -590,7 +590,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do data: {"type": "content_block_stop", "index": 0} event: content_block_start -data: {"type":"content_block_start","index":0,"content_block":{"type":"redacted_thinking","data":"AAA=="} } + data: {"type":"content_block_start","index":0,"content_block":{"type":"redacted_thinking","data":"AAA=="} } event: ping data: {"type": "ping"} @@ -769,4 +769,92 @@ data: {"type":"content_block_start","index":0,"content_block":{"type":"redacted_ expect(result).to eq("I won't use any tools. 
Here's a direct response instead.") end end + + describe "structured output via prefilling" do + it "forces the response to be a JSON and using the given JSON schema" do + schema = { + type: "json_schema", + json_schema: { + name: "reply", + schema: { + type: "object", + properties: { + key: { + type: "string", + }, + }, + required: ["key"], + additionalProperties: false, + }, + strict: true, + }, + } + + body = (<<~STRING).strip + event: message_start + data: {"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-opus-20240229", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 25, "output_tokens": 1}}} + + event: content_block_start + data: {"type": "content_block_start", "index":0, "content_block": {"type": "text", "text": ""}} + + event: content_block_delta + data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "\\""}} + + event: content_block_delta + data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "key"}} + + event: content_block_delta + data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "\\":\\""}} + + event: content_block_delta + data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello!"}} + + event: content_block_delta + data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "\\"}"}} + + event: content_block_stop + data: {"type": "content_block_stop", "index": 0} + + event: message_delta + data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence":null, "usage":{"output_tokens": 15}}} + + event: message_stop + data: {"type": "message_stop"} + STRING + + parsed_body = nil + + stub_request(:post, url).with( + body: + proc do |req_body| + parsed_body = JSON.parse(req_body, symbolize_names: true) + true + end, + headers: { + "Content-Type" => "application/json", + "X-Api-Key" => "123", + "Anthropic-Version" => "2023-06-01", + }, + ).to_return(status: 200, body: body) + + structured_output = nil + llm.generate( + prompt, + user: Discourse.system_user, + feature_name: "testing", + response_format: schema, + ) { |partial, cancel| structured_output = partial } + + expect(structured_output.read_latest_buffered_chunk).to eq({ key: "Hello!" 
}) + + expected_body = { + model: "claude-3-opus-20240229", + max_tokens: 4096, + messages: [{ role: "user", content: "user1: hello" }, { role: "assistant", content: "{" }], + system: "You are hello bot", + stream: true, + } + expect(parsed_body).to eq(expected_body) + end + end end diff --git a/spec/lib/completions/endpoints/aws_bedrock_spec.rb b/spec/lib/completions/endpoints/aws_bedrock_spec.rb index 3a424451..e5e5d8b7 100644 --- a/spec/lib/completions/endpoints/aws_bedrock_spec.rb +++ b/spec/lib/completions/endpoints/aws_bedrock_spec.rb @@ -546,4 +546,69 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do ) end end + + describe "structured output via prefilling" do + it "forces the response to be a JSON and using the given JSON schema" do + schema = { + type: "json_schema", + json_schema: { + name: "reply", + schema: { + type: "object", + properties: { + key: { + type: "string", + }, + }, + required: ["key"], + additionalProperties: false, + }, + strict: true, + }, + } + + messages = + [ + { type: "message_start", message: { usage: { input_tokens: 9 } } }, + { type: "content_block_delta", delta: { text: "\"" } }, + { type: "content_block_delta", delta: { text: "key" } }, + { type: "content_block_delta", delta: { text: "\":\"" } }, + { type: "content_block_delta", delta: { text: "Hello!" } }, + { type: "content_block_delta", delta: { text: "\"}" } }, + { type: "message_delta", delta: { usage: { output_tokens: 25 } } }, + ].map { |message| encode_message(message) } + + proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}") + request = nil + bedrock_mock.with_chunk_array_support do + stub_request( + :post, + "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke-with-response-stream", + ) + .with do |inner_request| + request = inner_request + true + end + .to_return(status: 200, body: messages) + + structured_output = nil + proxy.generate("hello world", response_format: schema, user: user) do |partial| + structured_output = partial + end + + expected = { + "max_tokens" => 4096, + "anthropic_version" => "bedrock-2023-05-31", + "messages" => [ + { "role" => "user", "content" => "hello world" }, + { "role" => "assistant", "content" => "{" }, + ], + "system" => "You are a helpful bot", + } + expect(JSON.parse(request.body)).to eq(expected) + + expect(structured_output.read_latest_buffered_chunk).to eq({ key: "Hello!" 
}) + end + end + end end diff --git a/spec/lib/completions/endpoints/cohere_spec.rb b/spec/lib/completions/endpoints/cohere_spec.rb index bdff8fc3..d8ae70ef 100644 --- a/spec/lib/completions/endpoints/cohere_spec.rb +++ b/spec/lib/completions/endpoints/cohere_spec.rb @@ -308,4 +308,64 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do expect(audit.request_tokens).to eq(14) expect(audit.response_tokens).to eq(11) end + + it "is able to return structured outputs" do + schema = { + type: "json_schema", + json_schema: { + name: "reply", + schema: { + type: "object", + properties: { + key: { + type: "string", + }, + }, + required: ["key"], + additionalProperties: false, + }, + strict: true, + }, + } + + body = <<~TEXT + {"is_finished":false,"event_type":"stream-start","generation_id":"eb889b0f-c27d-45ea-98cf-567bdb7fc8bf"} + {"is_finished":false,"event_type":"text-generation","text":"{\\""} + {"is_finished":false,"event_type":"text-generation","text":"key"} + {"is_finished":false,"event_type":"text-generation","text":"\\":\\""} + {"is_finished":false,"event_type":"text-generation","text":"Hello!"} + {"is_finished":false,"event_type":"text-generation","text":"\\"}"}| + {"is_finished":true,"event_type":"stream-end","response":{"response_id":"d235db17-8555-493b-8d91-e601f76de3f9","text":"{\\"key\\":\\"Hello!\\"}","generation_id":"eb889b0f-c27d-45ea-98cf-567bdb7fc8bf","chat_history":[{"role":"USER","message":"user1: hello"},{"role":"CHATBOT","message":"hi user"},{"role":"USER","message":"user1: thanks"},{"role":"CHATBOT","message":"You're welcome! Is there anything else I can help you with?"}],"token_count":{"prompt_tokens":29,"response_tokens":14,"total_tokens":43,"billed_tokens":28},"meta":{"api_version":{"version":"1"},"billed_units":{"input_tokens":14,"output_tokens":14}}},"finish_reason":"COMPLETE"} + TEXT + + parsed_body = nil + structured_output = nil + + EndpointMock.with_chunk_array_support do + stub_request(:post, "https://api.cohere.ai/v1/chat").with( + body: + proc do |req_body| + parsed_body = JSON.parse(req_body, symbolize_names: true) + true + end, + headers: { + "Content-Type" => "application/json", + "Authorization" => "Bearer ABC", + }, + ).to_return(status: 200, body: body.split("|")) + + result = + llm.generate(prompt, response_format: schema, user: user) do |partial, cancel| + structured_output = partial + end + end + + expect(parsed_body[:preamble]).to eq("You are hello bot") + expect(parsed_body[:chat_history]).to eq( + [{ role: "USER", message: "user1: hello" }, { role: "CHATBOT", message: "hi user" }], + ) + expect(parsed_body[:message]).to eq("user1: thanks") + + expect(structured_output.read_latest_buffered_chunk).to eq({ key: "Hello!" 
}) + end end diff --git a/spec/lib/completions/endpoints/gemini_spec.rb b/spec/lib/completions/endpoints/gemini_spec.rb index 05a87b41..a4355d92 100644 --- a/spec/lib/completions/endpoints/gemini_spec.rb +++ b/spec/lib/completions/endpoints/gemini_spec.rb @@ -420,7 +420,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do body: proc do |_req_body| req_body = _req_body - true + _req_body end, ).to_return(status: 200, body: response) @@ -433,4 +433,67 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do { function_calling_config: { mode: "ANY", allowed_function_names: ["echo"] } }, ) end + + describe "structured output via JSON Schema" do + it "forces the response to be a JSON" do + schema = { + type: "json_schema", + json_schema: { + name: "reply", + schema: { + type: "object", + properties: { + key: { + type: "string", + }, + }, + required: ["key"], + additionalProperties: false, + }, + strict: true, + }, + } + + response = <<~TEXT.strip + data: {"candidates": [{"content": {"parts": [{"text": "{\\""}],"role": "model"}}],"usageMetadata": {"promptTokenCount": 399,"totalTokenCount": 399},"modelVersion": "gemini-1.5-pro-002"} + + data: {"candidates": [{"content": {"parts": [{"text": "key"}],"role": "model"},"safetyRatings": [{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 399,"totalTokenCount": 399},"modelVersion": "gemini-1.5-pro-002"} + + data: {"candidates": [{"content": {"parts": [{"text": "\\":\\""}],"role": "model"},"safetyRatings": [{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 399,"totalTokenCount": 399},"modelVersion": "gemini-1.5-pro-002"} + + data: {"candidates": [{"content": {"parts": [{"text": "Hello!"}],"role": "model"},"finishReason": "STOP"}],"usageMetadata": {"promptTokenCount": 399,"candidatesTokenCount": 191,"totalTokenCount": 590},"modelVersion": "gemini-1.5-pro-002"} + + data: {"candidates": [{"content": {"parts": [{"text": "\\"}"}],"role": "model"},"finishReason": "STOP"}],"usageMetadata": {"promptTokenCount": 399,"candidatesTokenCount": 191,"totalTokenCount": 590},"modelVersion": "gemini-1.5-pro-002"} + + data: {"candidates": [{"finishReason": "MALFORMED_FUNCTION_CALL"}],"usageMetadata": {"promptTokenCount": 399,"candidatesTokenCount": 191,"totalTokenCount": 590},"modelVersion": "gemini-1.5-pro-002"} + + TEXT + + req_body = nil + + llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}") + url = "#{model.url}:streamGenerateContent?alt=sse&key=123" + + stub_request(:post, url).with( + body: + proc do |_req_body| + req_body = _req_body + true + end, + ).to_return(status: 200, body: response) + + structured_response = nil + llm.generate("Hello", response_format: schema, user: user) do |partial| + structured_response = partial + end + + expect(structured_response.read_latest_buffered_chunk).to eq({ key: "Hello!" }) + + parsed = JSON.parse(req_body, symbolize_names: true) + + # Verify that schema is passed following Gemini API specs. 
+ expect(parsed.dig(:generationConfig, :responseSchema)).to eq(schema) + expect(parsed.dig(:generationConfig, :responseMimeType)).to eq("application/json") + end + end end diff --git a/spec/lib/completions/structured_output_spec.rb b/spec/lib/completions/structured_output_spec.rb new file mode 100644 index 00000000..178f5ab3 --- /dev/null +++ b/spec/lib/completions/structured_output_spec.rb @@ -0,0 +1,69 @@ +# frozen_string_literal: true + +RSpec.describe DiscourseAi::Completions::StructuredOutput do + subject(:structured_output) do + described_class.new( + { + message: { + type: "string", + }, + bool: { + type: "boolean", + }, + number: { + type: "integer", + }, + status: { + type: "string", + }, + }, + ) + end + + describe "Parsing structured output on the fly" do + it "acts as a buffer for an streamed JSON" do + chunks = [ + +"{\"message\": \"Line 1\\n", + +"Line 2\\n", + +"Line 3\", ", + +"\"bool\": true,", + +"\"number\": 4", + +"2,", + +"\"status\": \"o", + +"\\\"k\\\"\"}", + ] + + structured_output << chunks[0] + expect(structured_output.read_latest_buffered_chunk).to eq({ message: "Line 1\n" }) + + structured_output << chunks[1] + expect(structured_output.read_latest_buffered_chunk).to eq({ message: "Line 2\n" }) + + structured_output << chunks[2] + expect(structured_output.read_latest_buffered_chunk).to eq({ message: "Line 3" }) + + structured_output << chunks[3] + expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true }) + + # Waiting for number to be fully buffered. + structured_output << chunks[4] + expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true }) + + structured_output << chunks[5] + expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true, number: 42 }) + + structured_output << chunks[6] + expect(structured_output.read_latest_buffered_chunk).to eq( + { bool: true, number: 42, status: "o" }, + ) + + structured_output << chunks[7] + expect(structured_output.read_latest_buffered_chunk).to eq( + { bool: true, number: 42, status: "\"k\"" }, + ) + + # No partial string left to read. 
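+      # bool and number keep appearing on every read because non-string values
+      # are returned whole rather than consumed.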
+ expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true, number: 42 }) + end + end +end diff --git a/spec/lib/modules/summarization/fold_content_spec.rb b/spec/lib/modules/summarization/fold_content_spec.rb index 7f6fafaf..d2497c70 100644 --- a/spec/lib/modules/summarization/fold_content_spec.rb +++ b/spec/lib/modules/summarization/fold_content_spec.rb @@ -23,17 +23,17 @@ RSpec.describe DiscourseAi::Summarization::FoldContent do llm_model.update!(max_prompt_tokens: model_tokens) end - let(:single_summary) { "this is a summary" } + let(:summary) { "this is a summary" } fab!(:user) it "summarizes the content" do result = - DiscourseAi::Completions::Llm.with_prepared_responses([single_summary]) do |spy| + DiscourseAi::Completions::Llm.with_prepared_responses([summary]) do |spy| summarizer.summarize(user).tap { expect(spy.completions).to eq(1) } end - expect(result.summarized_text).to eq(single_summary) + expect(result.summarized_text).to eq(summary) end end diff --git a/spec/models/ai_persona_spec.rb b/spec/models/ai_persona_spec.rb index 2e3a55de..c776e699 100644 --- a/spec/models/ai_persona_spec.rb +++ b/spec/models/ai_persona_spec.rb @@ -267,6 +267,7 @@ RSpec.describe AiPersona do description: "system persona", system_prompt: "system persona", tools: %w[Search Time], + response_format: [{ key: "summary", type: "string" }], system: true, ) end @@ -285,6 +286,22 @@ RSpec.describe AiPersona do I18n.t("discourse_ai.ai_bot.personas.cannot_edit_system_persona"), ) end + + it "doesn't accept response format changes" do + new_format = [{ key: "summary2", type: "string" }] + + expect { system_persona.update!(response_format: new_format) }.to raise_error( + ActiveRecord::RecordInvalid, + ) + end + + it "doesn't accept additional format changes" do + new_format = [{ key: "summary", type: "string" }, { key: "summary2", type: "string" }] + + expect { system_persona.update!(response_format: new_format) }.to raise_error( + ActiveRecord::RecordInvalid, + ) + end end end end diff --git a/spec/requests/admin/ai_personas_controller_spec.rb b/spec/requests/admin/ai_personas_controller_spec.rb index 7fbb016f..0e5662f5 100644 --- a/spec/requests/admin/ai_personas_controller_spec.rb +++ b/spec/requests/admin/ai_personas_controller_spec.rb @@ -185,6 +185,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do default_llm_id: llm_model.id, question_consolidator_llm_id: llm_model.id, forced_tool_count: 2, + response_format: [{ key: "summary", type: "string" }], } end @@ -209,6 +210,9 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do expect(persona_json["allow_chat_channel_mentions"]).to eq(true) expect(persona_json["allow_chat_direct_messages"]).to eq(true) expect(persona_json["question_consolidator_llm_id"]).to eq(llm_model.id) + expect(persona_json["response_format"].map { |rf| rf["key"] }).to contain_exactly( + "summary", + ) persona = AiPersona.find(persona_json["id"]) diff --git a/spec/requests/summarization/summary_controller_spec.rb b/spec/requests/summarization/summary_controller_spec.rb index 01ab28bd..5c051f3d 100644 --- a/spec/requests/summarization/summary_controller_spec.rb +++ b/spec/requests/summarization/summary_controller_spec.rb @@ -74,6 +74,7 @@ RSpec.describe DiscourseAi::Summarization::SummaryController do it "returns a summary" do summary_text = "This is a summary" + DiscourseAi::Completions::Llm.with_prepared_responses([summary_text]) do get "/discourse-ai/summarization/t/#{topic.id}.json" diff --git 
a/spec/services/discourse_ai/topic_summarization_spec.rb b/spec/services/discourse_ai/topic_summarization_spec.rb index a75dd527..444cfc67 100644 --- a/spec/services/discourse_ai/topic_summarization_spec.rb +++ b/spec/services/discourse_ai/topic_summarization_spec.rb @@ -13,6 +13,8 @@ describe DiscourseAi::TopicSummarization do let(:strategy) { DiscourseAi::Summarization.topic_summary(topic) } + let(:summary) { "This is the final summary" } + describe "#summarize" do subject(:summarization) { described_class.new(strategy, user) } @@ -27,8 +29,6 @@ describe DiscourseAi::TopicSummarization do end context "when the content was summarized in a single chunk" do - let(:summary) { "This is the final summary" } - it "caches the summary" do DiscourseAi::Completions::Llm.with_prepared_responses([summary]) do section = summarization.summarize @@ -54,7 +54,6 @@ describe DiscourseAi::TopicSummarization do describe "invalidating cached summaries" do let(:cached_text) { "This is a cached summary" } - let(:updated_summary) { "This is the final summary" } def cached_summary AiSummary.find_by(target: topic, summary_type: AiSummary.summary_types[:complete]) @@ -86,10 +85,10 @@ describe DiscourseAi::TopicSummarization do before { cached_summary.update!(original_content_sha: "outdated_sha") } it "returns a new summary" do - DiscourseAi::Completions::Llm.with_prepared_responses([updated_summary]) do + DiscourseAi::Completions::Llm.with_prepared_responses([summary]) do section = summarization.summarize - expect(section.summarized_text).to eq(updated_summary) + expect(section.summarized_text).to eq(summary) end end @@ -106,10 +105,10 @@ describe DiscourseAi::TopicSummarization do end it "returns a new summary if the skip_age_check flag is passed" do - DiscourseAi::Completions::Llm.with_prepared_responses([updated_summary]) do + DiscourseAi::Completions::Llm.with_prepared_responses([summary]) do section = summarization.summarize(skip_age_check: true) - expect(section.summarized_text).to eq(updated_summary) + expect(section.summarized_text).to eq(summary) end end end @@ -118,8 +117,6 @@ describe DiscourseAi::TopicSummarization do end describe "stream partial updates" do - let(:summary) { "This is the final summary" } - it "receives a blk that is passed to the underlying strategy and called with partial summaries" do partial_result = +"" @@ -127,7 +124,8 @@ describe DiscourseAi::TopicSummarization do summarization.summarize { |partial_summary| partial_result << partial_summary } end - expect(partial_result).to eq(summary) + # In a real world example, this is removed in the returned AiSummary obj. + expect(partial_result.chomp("\"}")).to eq(summary) end end end diff --git a/spec/system/summarization/topic_summarization_spec.rb b/spec/system/summarization/topic_summarization_spec.rb index 8a23772b..30b147a0 100644 --- a/spec/system/summarization/topic_summarization_spec.rb +++ b/spec/system/summarization/topic_summarization_spec.rb @@ -12,7 +12,9 @@ RSpec.describe "Summarize a topic ", type: :system do "I like to eat pie. It is a very good dessert. Some people are wasteful by throwing pie at others but I do not do that. I always eat the pie.", ) end + let(:summarization_result) { "This is a summary" } + let(:topic_page) { PageObjects::Pages::Topic.new } let(:summary_box) { PageObjects::Components::AiSummaryTrigger.new }
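---

Usage sketch (not part of the patch): how the new pieces fit together end to end, following the call shapes exercised in the specs above. `model` and `user` are assumed fixtures; the `schema` hash mirrors what `Bot#build_json_schema` emits for a persona whose response_format is `[{ key: "summary", type: "string" }]`.

  # Build the response_format hash the endpoints expect.
  schema = {
    type: "json_schema",
    json_schema: {
      name: "reply",
      schema: {
        type: "object",
        properties: { summary: { type: "string" } },
        required: ["summary"],
        additionalProperties: false,
      },
      strict: true,
    },
  }

  llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
  summary = +""

  llm.generate("Summarize this topic", user: user, response_format: schema) do |partial, cancel|
    # With response_format set, streamed partials arrive as StructuredOutput
    # buffers instead of raw strings.
    if partial.is_a?(DiscourseAi::Completions::StructuredOutput)
      chunk = partial.read_latest_buffered_chunk[:summary]
      summary << chunk if chunk
    end
  end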