FIX: Structured output discrepancies. (#1340)
This change fixes two bugs and adds a safeguard. The first issue was that the schema Gemini expected differed from the one we sent, resulting in 400 errors when performing completions. The second issue was that creating a new persona didn't define a method for `response_format`; it has to be explicitly defined when we wrap the persona inside the Persona class. There was also a mismatch between the default value and what we store in the DB: some parts of the code expected symbols as keys while others expected strings. Finally, we add a safeguard for when the model refuses to reply with valid JSON even when asked to. In that case, we make a best effort to recover and stream the raw response.
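For context, a minimal sketch of the Gemini payload fix; the wrapper shape is the OpenAI-style `json_schema` format the endpoint digs into, and the schema contents here are illustrative, not taken from this change:

# Illustrative OpenAI-style response_format wrapper:
response_format = {
  type: "json_schema",
  json_schema: {
    name: "reply",
    schema: {
      type: "object",
      properties: { key: { type: "string" } },
      additionalProperties: false,
    },
  },
}

# Before, the whole wrapper was assigned to responseSchema, which Gemini
# rejects with a 400. Now only the inner schema is sent, minus
# additionalProperties, and the JSON mime type is set explicitly:
generation_config = {
  responseSchema: response_format.dig(:json_schema, :schema).except(:additionalProperties),
  responseMimeType: "application/json",
}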
This commit is contained in:
parent 1b3fdad5c7
commit ff2e18f9ca
@@ -266,6 +266,7 @@ class AiPersona < ActiveRecord::Base
       define_method(:top_p) { @ai_persona&.top_p }
       define_method(:system_prompt) { @ai_persona&.system_prompt || "You are a helpful bot." }
       define_method(:uploads) { @ai_persona&.uploads }
+      define_method(:response_format) { @ai_persona&.response_format }
       define_method(:examples) { @ai_persona&.examples }
     end
   end
@@ -87,9 +87,13 @@ module DiscourseAi
         if model_params.present?
           payload[:generationConfig].merge!(model_params.except(:response_format))

           if model_params[:response_format].present?
             # https://ai.google.dev/api/generate-content#generationconfig
-            payload[:generationConfig][:responseSchema] = model_params[:response_format]
+            gemini_schema = model_params[:response_format].dig(:json_schema, :schema)
+
+            if gemini_schema.present?
+              payload[:generationConfig][:responseSchema] = gemini_schema.except(
+                :additionalProperties,
+              )
+              payload[:generationConfig][:responseMimeType] = "application/json"
+            end
           end
@@ -24,6 +24,10 @@ module DiscourseAi
         end
       end

+      def broken?
+        @broken
+      end
+
       def <<(json)
         # llm could send broken json
         # in that case just deal with it later
@@ -13,31 +13,40 @@ module DiscourseAi
       @tracked = {}
+
+      @raw_response = +""
+      @raw_cursor = 0

       @partial_json_tracker = JsonStreamingTracker.new(self)
     end

     attr_reader :last_chunk_buffer

     def <<(raw)
+      @raw_response << raw
       @partial_json_tracker << raw
     end

-    def read_latest_buffered_chunk
-      @property_names.reduce({}) do |memo, pn|
-        if @tracked[pn].present?
-          # This means this property is a string and we want to return unread chunks.
-          if @property_cursors[pn].present?
-            unread = @tracked[pn][@property_cursors[pn]..]
+    def read_buffered_property(prop_name)
+      # Safeguard: If the model is misbehaving and generating something that's not a JSON,
+      # treat response as a normal string.
+      # This is a best-effort to recover from an unexpected scenario.
+      if @partial_json_tracker.broken?
+        unread_chunk = @raw_response[@raw_cursor..]
+        @raw_cursor = @raw_response.length
+        return unread_chunk
+      end

-            memo[pn] = unread if unread.present?
-            @property_cursors[pn] = @tracked[pn].length
-          else
-            # Ints and bools are always returned as is.
-            memo[pn] = @tracked[pn]
-          end
-        end
+      # Maybe we haven't read that part of the JSON yet.
+      return nil if @tracked[prop_name].blank?

-        memo
+      # This means this property is a string and we want to return unread chunks.
+      if @property_cursors[prop_name].present?
+        unread = @tracked[prop_name][@property_cursors[prop_name]..]
+        @property_cursors[prop_name] = @tracked[prop_name].length
+        unread
+      else
+        # Ints and bools are always returned as is.
+        @tracked[prop_name]
+      end
     end
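With this change, callers stop pulling a full hash via `read_latest_buffered_chunk` and instead read one property at a time. A minimal usage sketch, assuming a `StructuredOutput` instance configured with a string "key" property as in the endpoint specs below (the chunk boundaries are illustrative):

structured_output << '{"key": "Hel'
structured_output.read_buffered_property(:key) # => "Hel"

structured_output << 'lo!"}'
structured_output.read_buffered_property(:key) # => "lo!"

# Once the tracker marks the stream as broken (not valid JSON), reads fall
# back to returning the raw unread text, so callers can still stream it.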
@@ -316,7 +316,7 @@ module DiscourseAi
         response_format
           .to_a
           .reduce({}) do |memo, format|
-            memo[format[:key].to_sym] = { type: format[:type] }
+            memo[format["key"].to_sym] = { type: format["type"] }
             memo
           end
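For clarity, a short sketch of the key mismatch this hunk fixes: `response_format` rows loaded from the DB come back string-keyed, while the old code indexed them with symbols (values here are illustrative):

response_format = [{ "key" => "summary", "type" => "string" }]

# Before: format[:key] was nil for DB-backed rows, yielding an empty schema.
# After: read the string keys, symbolizing only for the schema being built.
schema = response_format.reduce({}) do |memo, format|
  memo[format["key"].to_sym] = { type: format["type"] }
  memo
end
# => { summary: { type: "string" } }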
@@ -33,7 +33,7 @@ module DiscourseAi
       end

       def response_format
-        [{ key: "summary", type: "string" }]
+        [{ "key" => "summary", "type" => "string" }]
       end
     end
   end
@@ -34,7 +34,7 @@ module DiscourseAi
       end

       def response_format
-        [{ key: "summary", type: "string" }]
+        [{ "key" => "summary", "type" => "string" }]
       end

       def examples
@@ -116,7 +116,7 @@ module DiscourseAi
           if type == :structured_output
             json_summary_schema_key = bot.persona.response_format&.first.to_h
             partial_summary =
-              partial.read_latest_buffered_chunk[json_summary_schema_key[:key].to_sym]
+              partial.read_buffered_property(json_summary_schema_key["key"]&.to_sym)

             if partial_summary.present?
               summary << partial_summary
@@ -845,7 +845,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
           response_format: schema,
         ) { |partial, cancel| structured_output = partial }

-        expect(structured_output.read_latest_buffered_chunk).to eq({ key: "Hello!" })
+        expect(structured_output.read_buffered_property(:key)).to eq("Hello!")

         expected_body = {
           model: "claude-3-opus-20240229",
@@ -607,7 +607,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
         }
         expect(JSON.parse(request.body)).to eq(expected)

-        expect(structured_output.read_latest_buffered_chunk).to eq({ key: "Hello!" })
+        expect(structured_output.read_buffered_property(:key)).to eq("Hello!")
       end
     end
   end
@@ -366,6 +366,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
         )
         expect(parsed_body[:message]).to eq("user1: thanks")

-        expect(structured_output.read_latest_buffered_chunk).to eq({ key: "Hello!" })
+        expect(structured_output.read_buffered_property(:key)).to eq("Hello!")
       end
     end
@@ -565,12 +565,14 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
           structured_response = partial
         end

-        expect(structured_response.read_latest_buffered_chunk).to eq({ key: "Hello!" })
+        expect(structured_response.read_buffered_property(:key)).to eq("Hello!")

         parsed = JSON.parse(req_body, symbolize_names: true)

         # Verify that schema is passed following Gemini API specs.
-        expect(parsed.dig(:generationConfig, :responseSchema)).to eq(schema)
+        expect(parsed.dig(:generationConfig, :responseSchema)).to eq(
+          schema.dig(:json_schema, :schema).except(:additionalProperties),
+        )
+        expect(parsed.dig(:generationConfig, :responseMimeType)).to eq("application/json")
       end
     end
@@ -34,36 +34,50 @@ RSpec.describe DiscourseAi::Completions::StructuredOutput do
       ]

       structured_output << chunks[0]
-      expect(structured_output.read_latest_buffered_chunk).to eq({ message: "Line 1\n" })
+      expect(structured_output.read_buffered_property(:message)).to eq("Line 1\n")

       structured_output << chunks[1]
-      expect(structured_output.read_latest_buffered_chunk).to eq({ message: "Line 2\n" })
+      expect(structured_output.read_buffered_property(:message)).to eq("Line 2\n")

       structured_output << chunks[2]
-      expect(structured_output.read_latest_buffered_chunk).to eq({ message: "Line 3" })
+      expect(structured_output.read_buffered_property(:message)).to eq("Line 3")

       structured_output << chunks[3]
-      expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true })
+      expect(structured_output.read_buffered_property(:bool)).to eq(true)

       # Waiting for number to be fully buffered.
       structured_output << chunks[4]
-      expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true })
+      expect(structured_output.read_buffered_property(:bool)).to eq(true)
+      expect(structured_output.read_buffered_property(:number)).to be_nil

       structured_output << chunks[5]
-      expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true, number: 42 })
+      expect(structured_output.read_buffered_property(:number)).to eq(42)

       structured_output << chunks[6]
-      expect(structured_output.read_latest_buffered_chunk).to eq(
-        { bool: true, number: 42, status: "o" },
-      )
+      expect(structured_output.read_buffered_property(:number)).to eq(42)
+      expect(structured_output.read_buffered_property(:bool)).to eq(true)
+      expect(structured_output.read_buffered_property(:status)).to eq("o")

       structured_output << chunks[7]
-      expect(structured_output.read_latest_buffered_chunk).to eq(
-        { bool: true, number: 42, status: "\"k\"" },
-      )
+      expect(structured_output.read_buffered_property(:status)).to eq("\"k\"")

       # No partial string left to read.
-      expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true, number: 42 })
+      expect(structured_output.read_buffered_property(:status)).to eq("")
     end
   end
+
+  describe "dealing with non-JSON responses" do
+    it "treat it as plain text once we determined it's invalid JSON" do
+      chunks = [+"I'm not", +"a", +"JSON :)"]
+
+      structured_output << chunks[0]
+      expect(structured_output.read_buffered_property(nil)).to eq("I'm not")
+
+      structured_output << chunks[1]
+      expect(structured_output.read_buffered_property(nil)).to eq("a")
+
+      structured_output << chunks[2]
+      expect(structured_output.read_buffered_property(nil)).to eq("JSON :)")
+    end
+  end
 end