From ff2e18f9cae80644cce86b7135e2ba913d68b484 Mon Sep 17 00:00:00 2001 From: Roman Rizzi Date: Thu, 15 May 2025 11:32:10 -0300 Subject: [PATCH] FIX: Structured output discrepancies. (#1340) This change fixes two bugs and adds a safeguard. The first issue was that the schema Gemini expected differed from the one sent, resulting in 400 errors when performing completions. The second issue was that creating a new persona didn't define a method for `response_format`. This has to be explicitly defined when we wrap it inside the Persona class. Also, there was a mismatch between the default value and what we stored in the DB. Some parts of the code expected symbols as keys and others as strings. Finally, we add a safeguard for when, even if asked to, the model refuses to reply with valid JSON. In this case, we make a best effort to recover and stream the raw response. --- app/models/ai_persona.rb | 1 + lib/completions/endpoints/gemini.rb | 10 +++-- lib/completions/json_streaming_tracker.rb | 4 ++ lib/completions/structured_output.rb | 37 ++++++++++------- lib/personas/bot.rb | 2 +- lib/personas/short_summarizer.rb | 2 +- lib/personas/summarizer.rb | 2 +- lib/summarization/fold_content.rb | 2 +- .../completions/endpoints/anthropic_spec.rb | 2 +- .../completions/endpoints/aws_bedrock_spec.rb | 2 +- spec/lib/completions/endpoints/cohere_spec.rb | 2 +- spec/lib/completions/endpoints/gemini_spec.rb | 6 ++- .../lib/completions/structured_output_spec.rb | 40 +++++++++++++------ 13 files changed, 73 insertions(+), 39 deletions(-) diff --git a/app/models/ai_persona.rb b/app/models/ai_persona.rb index 8793b398..eba075bf 100644 --- a/app/models/ai_persona.rb +++ b/app/models/ai_persona.rb @@ -266,6 +266,7 @@ class AiPersona < ActiveRecord::Base define_method(:top_p) { @ai_persona&.top_p } define_method(:system_prompt) { @ai_persona&.system_prompt || "You are a helpful bot." 
} define_method(:uploads) { @ai_persona&.uploads } + define_method(:response_format) { @ai_persona&.response_format } define_method(:examples) { @ai_persona&.examples } end end diff --git a/lib/completions/endpoints/gemini.rb b/lib/completions/endpoints/gemini.rb index 025d4fbc..17107455 100644 --- a/lib/completions/endpoints/gemini.rb +++ b/lib/completions/endpoints/gemini.rb @@ -87,9 +87,13 @@ module DiscourseAi if model_params.present? payload[:generationConfig].merge!(model_params.except(:response_format)) - if model_params[:response_format].present? - # https://ai.google.dev/api/generate-content#generationconfig - payload[:generationConfig][:responseSchema] = model_params[:response_format] + # https://ai.google.dev/api/generate-content#generationconfig + gemini_schema = model_params[:response_format].dig(:json_schema, :schema) + + if gemini_schema.present? + payload[:generationConfig][:responseSchema] = gemini_schema.except( + :additionalProperties, + ) payload[:generationConfig][:responseMimeType] = "application/json" end end diff --git a/lib/completions/json_streaming_tracker.rb b/lib/completions/json_streaming_tracker.rb index aa687ef1..849fd8ac 100644 --- a/lib/completions/json_streaming_tracker.rb +++ b/lib/completions/json_streaming_tracker.rb @@ -24,6 +24,10 @@ module DiscourseAi end end + def broken? + @broken + end + def <<(json) # llm could send broken json # in that case just deal with it later diff --git a/lib/completions/structured_output.rb b/lib/completions/structured_output.rb index fadf5722..7f13f536 100644 --- a/lib/completions/structured_output.rb +++ b/lib/completions/structured_output.rb @@ -13,31 +13,40 @@ module DiscourseAi @tracked = {} + @raw_response = +"" + @raw_cursor = 0 + @partial_json_tracker = JsonStreamingTracker.new(self) end attr_reader :last_chunk_buffer def <<(raw) + @raw_response << raw @partial_json_tracker << raw end - def read_latest_buffered_chunk - @property_names.reduce({}) do |memo, pn| - if @tracked[pn].present? 
- # This means this property is a string and we want to return unread chunks. - if @property_cursors[pn].present? - unread = @tracked[pn][@property_cursors[pn]..] + def read_buffered_property(prop_name) + # Safeguard: If the model is misbehaving and generating something that's not a JSON, + # treat response as a normal string. + # This is a best-effort to recover from an unexpected scenario. + if @partial_json_tracker.broken? + unread_chunk = @raw_response[@raw_cursor..] + @raw_cursor = @raw_response.length + return unread_chunk + end - memo[pn] = unread if unread.present? - @property_cursors[pn] = @tracked[pn].length - else - # Ints and bools are always returned as is. - memo[pn] = @tracked[pn] - end - end + # Maybe we haven't read that part of the JSON yet. + return nil if @tracked[prop_name].blank? - memo + # This means this property is a string and we want to return unread chunks. + if @property_cursors[prop_name].present? + unread = @tracked[prop_name][@property_cursors[prop_name]..] + @property_cursors[prop_name] = @tracked[prop_name].length + unread + else + # Ints and bools are always returned as is. 
+ @tracked[prop_name] end end diff --git a/lib/personas/bot.rb b/lib/personas/bot.rb index 3686edac..b6e852c5 100644 --- a/lib/personas/bot.rb +++ b/lib/personas/bot.rb @@ -316,7 +316,7 @@ module DiscourseAi response_format .to_a .reduce({}) do |memo, format| - memo[format[:key].to_sym] = { type: format[:type] } + memo[format["key"].to_sym] = { type: format["type"] } memo end diff --git a/lib/personas/short_summarizer.rb b/lib/personas/short_summarizer.rb index 5b9b5195..26af56b9 100644 --- a/lib/personas/short_summarizer.rb +++ b/lib/personas/short_summarizer.rb @@ -33,7 +33,7 @@ module DiscourseAi end def response_format - [{ key: "summary", type: "string" }] + [{ "key" => "summary", "type" => "string" }] end end end diff --git a/lib/personas/summarizer.rb b/lib/personas/summarizer.rb index b8b0a95a..c2b4a714 100644 --- a/lib/personas/summarizer.rb +++ b/lib/personas/summarizer.rb @@ -34,7 +34,7 @@ module DiscourseAi end def response_format - [{ key: "summary", type: "string" }] + [{ "key" => "summary", "type" => "string" }] end def examples diff --git a/lib/summarization/fold_content.rb b/lib/summarization/fold_content.rb index 8dbcda38..17df679b 100644 --- a/lib/summarization/fold_content.rb +++ b/lib/summarization/fold_content.rb @@ -116,7 +116,7 @@ module DiscourseAi if type == :structured_output json_summary_schema_key = bot.persona.response_format&.first.to_h partial_summary = - partial.read_latest_buffered_chunk[json_summary_schema_key[:key].to_sym] + partial.read_buffered_property(json_summary_schema_key["key"]&.to_sym) if partial_summary.present? 
summary << partial_summary diff --git a/spec/lib/completions/endpoints/anthropic_spec.rb b/spec/lib/completions/endpoints/anthropic_spec.rb index b07586e0..96a4ae07 100644 --- a/spec/lib/completions/endpoints/anthropic_spec.rb +++ b/spec/lib/completions/endpoints/anthropic_spec.rb @@ -845,7 +845,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do response_format: schema, ) { |partial, cancel| structured_output = partial } - expect(structured_output.read_latest_buffered_chunk).to eq({ key: "Hello!" }) + expect(structured_output.read_buffered_property(:key)).to eq("Hello!") expected_body = { model: "claude-3-opus-20240229", diff --git a/spec/lib/completions/endpoints/aws_bedrock_spec.rb b/spec/lib/completions/endpoints/aws_bedrock_spec.rb index e5e5d8b7..aa86548e 100644 --- a/spec/lib/completions/endpoints/aws_bedrock_spec.rb +++ b/spec/lib/completions/endpoints/aws_bedrock_spec.rb @@ -607,7 +607,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do } expect(JSON.parse(request.body)).to eq(expected) - expect(structured_output.read_latest_buffered_chunk).to eq({ key: "Hello!" }) + expect(structured_output.read_buffered_property(:key)).to eq("Hello!") end end end diff --git a/spec/lib/completions/endpoints/cohere_spec.rb b/spec/lib/completions/endpoints/cohere_spec.rb index d8ae70ef..546922b3 100644 --- a/spec/lib/completions/endpoints/cohere_spec.rb +++ b/spec/lib/completions/endpoints/cohere_spec.rb @@ -366,6 +366,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do ) expect(parsed_body[:message]).to eq("user1: thanks") - expect(structured_output.read_latest_buffered_chunk).to eq({ key: "Hello!" 
}) + expect(structured_output.read_buffered_property(:key)).to eq("Hello!") end end diff --git a/spec/lib/completions/endpoints/gemini_spec.rb b/spec/lib/completions/endpoints/gemini_spec.rb index 1cf8ca1d..429e8108 100644 --- a/spec/lib/completions/endpoints/gemini_spec.rb +++ b/spec/lib/completions/endpoints/gemini_spec.rb @@ -565,12 +565,14 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do structured_response = partial end - expect(structured_response.read_latest_buffered_chunk).to eq({ key: "Hello!" }) + expect(structured_response.read_buffered_property(:key)).to eq("Hello!") parsed = JSON.parse(req_body, symbolize_names: true) # Verify that schema is passed following Gemini API specs. - expect(parsed.dig(:generationConfig, :responseSchema)).to eq(schema) + expect(parsed.dig(:generationConfig, :responseSchema)).to eq( + schema.dig(:json_schema, :schema).except(:additionalProperties), + ) expect(parsed.dig(:generationConfig, :responseMimeType)).to eq("application/json") end end diff --git a/spec/lib/completions/structured_output_spec.rb b/spec/lib/completions/structured_output_spec.rb index 178f5ab3..322cd0e2 100644 --- a/spec/lib/completions/structured_output_spec.rb +++ b/spec/lib/completions/structured_output_spec.rb @@ -34,36 +34,50 @@ RSpec.describe DiscourseAi::Completions::StructuredOutput do ] structured_output << chunks[0] - expect(structured_output.read_latest_buffered_chunk).to eq({ message: "Line 1\n" }) + expect(structured_output.read_buffered_property(:message)).to eq("Line 1\n") structured_output << chunks[1] - expect(structured_output.read_latest_buffered_chunk).to eq({ message: "Line 2\n" }) + expect(structured_output.read_buffered_property(:message)).to eq("Line 2\n") structured_output << chunks[2] - expect(structured_output.read_latest_buffered_chunk).to eq({ message: "Line 3" }) + expect(structured_output.read_buffered_property(:message)).to eq("Line 3") structured_output << chunks[3] - 
expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true }) + expect(structured_output.read_buffered_property(:bool)).to eq(true) # Waiting for number to be fully buffered. structured_output << chunks[4] - expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true }) + expect(structured_output.read_buffered_property(:bool)).to eq(true) + expect(structured_output.read_buffered_property(:number)).to be_nil structured_output << chunks[5] - expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true, number: 42 }) + expect(structured_output.read_buffered_property(:number)).to eq(42) structured_output << chunks[6] - expect(structured_output.read_latest_buffered_chunk).to eq( - { bool: true, number: 42, status: "o" }, - ) + expect(structured_output.read_buffered_property(:number)).to eq(42) + expect(structured_output.read_buffered_property(:bool)).to eq(true) + expect(structured_output.read_buffered_property(:status)).to eq("o") structured_output << chunks[7] - expect(structured_output.read_latest_buffered_chunk).to eq( - { bool: true, number: 42, status: "\"k\"" }, - ) + expect(structured_output.read_buffered_property(:status)).to eq("\"k\"") # No partial string left to read. - expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true, number: 42 }) + expect(structured_output.read_buffered_property(:status)).to eq("") + end + end + + describe "dealing with non-JSON responses" do + it "treat it as plain text once we determined it's invalid JSON" do + chunks = [+"I'm not", +"a", +"JSON :)"] + + structured_output << chunks[0] + expect(structured_output.read_buffered_property(nil)).to eq("I'm not") + + structured_output << chunks[1] + expect(structured_output.read_buffered_property(nil)).to eq("a") + + structured_output << chunks[2] + expect(structured_output.read_buffered_property(nil)).to eq("JSON :)") end end end