FIX: Structured output discrepancies. (#1340)

This change fixes two bugs and adds a safeguard.

The first issue is that the schema Gemini expected differed from the one sent, resulting in 400 errors when performing completions.

The second issue was that creating a new persona won't define a method
for `response_format`. This has to be explicitly defined when we wrap it inside the Persona class. Also, There was a mismatch between the default value and what we stored in the DB. Some parts of the code expected symbols as keys and others as strings.

Finally, we add a safeguard when, even if asked to, the model refuses to reply with a valid JSON. In this case, we are making a best-effort to recover and stream the raw response.
This commit is contained in:
Roman Rizzi 2025-05-15 11:32:10 -03:00 committed by GitHub
parent 1b3fdad5c7
commit ff2e18f9ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 73 additions and 39 deletions

View File

@ -266,6 +266,7 @@ class AiPersona < ActiveRecord::Base
define_method(:top_p) { @ai_persona&.top_p }
define_method(:system_prompt) { @ai_persona&.system_prompt || "You are a helpful bot." }
define_method(:uploads) { @ai_persona&.uploads }
define_method(:response_format) { @ai_persona&.response_format }
define_method(:examples) { @ai_persona&.examples }
end
end

View File

@ -87,9 +87,13 @@ module DiscourseAi
if model_params.present?
payload[:generationConfig].merge!(model_params.except(:response_format))
if model_params[:response_format].present?
# https://ai.google.dev/api/generate-content#generationconfig
payload[:generationConfig][:responseSchema] = model_params[:response_format]
# https://ai.google.dev/api/generate-content#generationconfig
gemini_schema = model_params[:response_format].dig(:json_schema, :schema)
if gemini_schema.present?
payload[:generationConfig][:responseSchema] = gemini_schema.except(
:additionalProperties,
)
payload[:generationConfig][:responseMimeType] = "application/json"
end
end

View File

@ -24,6 +24,10 @@ module DiscourseAi
end
end
def broken?
@broken
end
def <<(json)
# llm could send broken json
# in that case just deal with it later

View File

@ -13,31 +13,40 @@ module DiscourseAi
@tracked = {}
@raw_response = +""
@raw_cursor = 0
@partial_json_tracker = JsonStreamingTracker.new(self)
end
attr_reader :last_chunk_buffer
def <<(raw)
@raw_response << raw
@partial_json_tracker << raw
end
def read_latest_buffered_chunk
@property_names.reduce({}) do |memo, pn|
if @tracked[pn].present?
# This means this property is a string and we want to return unread chunks.
if @property_cursors[pn].present?
unread = @tracked[pn][@property_cursors[pn]..]
def read_buffered_property(prop_name)
# Safeguard: If the model is misbehaving and generating something that's not a JSON,
# treat response as a normal string.
# This is a best-effort to recover from an unexpected scenario.
if @partial_json_tracker.broken?
unread_chunk = @raw_response[@raw_cursor..]
@raw_cursor = @raw_response.length
return unread_chunk
end
memo[pn] = unread if unread.present?
@property_cursors[pn] = @tracked[pn].length
else
# Ints and bools are always returned as is.
memo[pn] = @tracked[pn]
end
end
# Maybe we haven't read that part of the JSON yet.
return nil if @tracked[prop_name].blank?
memo
# This means this property is a string and we want to return unread chunks.
if @property_cursors[prop_name].present?
unread = @tracked[prop_name][@property_cursors[prop_name]..]
@property_cursors[prop_name] = @tracked[prop_name].length
unread
else
# Ints and bools are always returned as is.
@tracked[prop_name]
end
end

View File

@ -316,7 +316,7 @@ module DiscourseAi
response_format
.to_a
.reduce({}) do |memo, format|
memo[format[:key].to_sym] = { type: format[:type] }
memo[format["key"].to_sym] = { type: format["type"] }
memo
end

View File

@ -33,7 +33,7 @@ module DiscourseAi
end
def response_format
[{ key: "summary", type: "string" }]
[{ "key" => "summary", "type" => "string" }]
end
end
end

View File

@ -34,7 +34,7 @@ module DiscourseAi
end
def response_format
[{ key: "summary", type: "string" }]
[{ "key" => "summary", "type" => "string" }]
end
def examples

View File

@ -116,7 +116,7 @@ module DiscourseAi
if type == :structured_output
json_summary_schema_key = bot.persona.response_format&.first.to_h
partial_summary =
partial.read_latest_buffered_chunk[json_summary_schema_key[:key].to_sym]
partial.read_buffered_property(json_summary_schema_key["key"]&.to_sym)
if partial_summary.present?
summary << partial_summary

View File

@ -845,7 +845,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
response_format: schema,
) { |partial, cancel| structured_output = partial }
expect(structured_output.read_latest_buffered_chunk).to eq({ key: "Hello!" })
expect(structured_output.read_buffered_property(:key)).to eq("Hello!")
expected_body = {
model: "claude-3-opus-20240229",

View File

@ -607,7 +607,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
}
expect(JSON.parse(request.body)).to eq(expected)
expect(structured_output.read_latest_buffered_chunk).to eq({ key: "Hello!" })
expect(structured_output.read_buffered_property(:key)).to eq("Hello!")
end
end
end

View File

@ -366,6 +366,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
)
expect(parsed_body[:message]).to eq("user1: thanks")
expect(structured_output.read_latest_buffered_chunk).to eq({ key: "Hello!" })
expect(structured_output.read_buffered_property(:key)).to eq("Hello!")
end
end

View File

@ -565,12 +565,14 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
structured_response = partial
end
expect(structured_response.read_latest_buffered_chunk).to eq({ key: "Hello!" })
expect(structured_response.read_buffered_property(:key)).to eq("Hello!")
parsed = JSON.parse(req_body, symbolize_names: true)
# Verify that schema is passed following Gemini API specs.
expect(parsed.dig(:generationConfig, :responseSchema)).to eq(schema)
expect(parsed.dig(:generationConfig, :responseSchema)).to eq(
schema.dig(:json_schema, :schema).except(:additionalProperties),
)
expect(parsed.dig(:generationConfig, :responseMimeType)).to eq("application/json")
end
end

View File

@ -34,36 +34,50 @@ RSpec.describe DiscourseAi::Completions::StructuredOutput do
]
structured_output << chunks[0]
expect(structured_output.read_latest_buffered_chunk).to eq({ message: "Line 1\n" })
expect(structured_output.read_buffered_property(:message)).to eq("Line 1\n")
structured_output << chunks[1]
expect(structured_output.read_latest_buffered_chunk).to eq({ message: "Line 2\n" })
expect(structured_output.read_buffered_property(:message)).to eq("Line 2\n")
structured_output << chunks[2]
expect(structured_output.read_latest_buffered_chunk).to eq({ message: "Line 3" })
expect(structured_output.read_buffered_property(:message)).to eq("Line 3")
structured_output << chunks[3]
expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true })
expect(structured_output.read_buffered_property(:bool)).to eq(true)
# Waiting for number to be fully buffered.
structured_output << chunks[4]
expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true })
expect(structured_output.read_buffered_property(:bool)).to eq(true)
expect(structured_output.read_buffered_property(:number)).to be_nil
structured_output << chunks[5]
expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true, number: 42 })
expect(structured_output.read_buffered_property(:number)).to eq(42)
structured_output << chunks[6]
expect(structured_output.read_latest_buffered_chunk).to eq(
{ bool: true, number: 42, status: "o" },
)
expect(structured_output.read_buffered_property(:number)).to eq(42)
expect(structured_output.read_buffered_property(:bool)).to eq(true)
expect(structured_output.read_buffered_property(:status)).to eq("o")
structured_output << chunks[7]
expect(structured_output.read_latest_buffered_chunk).to eq(
{ bool: true, number: 42, status: "\"k\"" },
)
expect(structured_output.read_buffered_property(:status)).to eq("\"k\"")
# No partial string left to read.
expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true, number: 42 })
expect(structured_output.read_buffered_property(:status)).to eq("")
end
end
describe "dealing with non-JSON responses" do
it "treat it as plain text once we determined it's invalid JSON" do
chunks = [+"I'm not", +"a", +"JSON :)"]
structured_output << chunks[0]
expect(structured_output.read_buffered_property(nil)).to eq("I'm not")
structured_output << chunks[1]
expect(structured_output.read_buffered_property(nil)).to eq("a")
structured_output << chunks[2]
expect(structured_output.read_buffered_property(nil)).to eq("JSON :)")
end
end
end