FIX: Structured output discrepancies. (#1340)
This change fixes two bugs and adds a safeguard. The first issue is that the schema we sent to Gemini differed from the one its API expects, resulting in 400 errors when performing completions. The second issue is that creating a new persona did not define a method for `response_format`; it has to be defined explicitly when we wrap it inside the Persona class. There was also a mismatch between the default value and what we store in the DB, with some parts of the code expecting symbol keys and others strings. Finally, we add a safeguard for the case where the model, even when asked to, refuses to reply with valid JSON. In that case, we make a best effort to recover and stream the raw response.
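For context, the recovery safeguard works roughly as follows. The sketch below is a minimal, self-contained illustration in plain Ruby; the class name `SketchStructuredOutput`, the "looks like JSON" heuristic, and the sample values are invented for illustration and are not the plugin's actual code, which relies on a streaming JSON parser instead.

# Simplified sketch of the best-effort recovery idea (not the exact plugin code):
# while the streamed payload still looks like JSON we would read individual
# properties; once we decide it is not JSON at all, we fall back to streaming
# the raw buffer from a cursor so the caller still sees the model's output.
class SketchStructuredOutput
  def initialize
    @raw_response = +""
    @raw_cursor = 0
    @broken = false
  end

  def <<(chunk)
    @raw_response << chunk
    # A real implementation feeds a streaming JSON parser here and flags the
    # response as broken when the parser gives up; we fake it with a heuristic.
    @broken = true unless @raw_response.lstrip.start_with?("{", "[")
  end

  # When the response is not JSON, return the unread part of the raw buffer.
  def read_buffered_property(_prop_name)
    return nil unless @broken

    unread = @raw_response[@raw_cursor..]
    @raw_cursor = @raw_response.length
    unread
  end
end

out = SketchStructuredOutput.new
out << "I'm not "
puts out.read_buffered_property(:summary) # => "I'm not "
out << "a JSON :)"
puts out.read_buffered_property(:summary) # => "a JSON :)"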
parent 1b3fdad5c7
commit ff2e18f9ca
@@ -266,6 +266,7 @@ class AiPersona < ActiveRecord::Base
       define_method(:top_p) { @ai_persona&.top_p }
       define_method(:system_prompt) { @ai_persona&.system_prompt || "You are a helpful bot." }
       define_method(:uploads) { @ai_persona&.uploads }
+      define_method(:response_format) { @ai_persona&.response_format }
       define_method(:examples) { @ai_persona&.examples }
     end
   end

@@ -87,9 +87,13 @@ module DiscourseAi
         if model_params.present?
           payload[:generationConfig].merge!(model_params.except(:response_format))

-          if model_params[:response_format].present?
-            # https://ai.google.dev/api/generate-content#generationconfig
-            payload[:generationConfig][:responseSchema] = model_params[:response_format]
+          # https://ai.google.dev/api/generate-content#generationconfig
+          gemini_schema = model_params[:response_format].dig(:json_schema, :schema)
+
+          if gemini_schema.present?
+            payload[:generationConfig][:responseSchema] = gemini_schema.except(
+              :additionalProperties,
+            )
             payload[:generationConfig][:responseMimeType] = "application/json"
           end
         end

@@ -24,6 +24,10 @@ module DiscourseAi
         end
       end

+      def broken?
+        @broken
+      end
+
       def <<(json)
         # llm could send broken json
         # in that case just deal with it later

@@ -13,31 +13,40 @@ module DiscourseAi

       @tracked = {}

+      @raw_response = +""
+      @raw_cursor = 0
+
       @partial_json_tracker = JsonStreamingTracker.new(self)
     end

     attr_reader :last_chunk_buffer

     def <<(raw)
+      @raw_response << raw
       @partial_json_tracker << raw
     end

-    def read_latest_buffered_chunk
-      @property_names.reduce({}) do |memo, pn|
-        if @tracked[pn].present?
-          # This means this property is a string and we want to return unread chunks.
-          if @property_cursors[pn].present?
-            unread = @tracked[pn][@property_cursors[pn]..]
-
-            memo[pn] = unread if unread.present?
-            @property_cursors[pn] = @tracked[pn].length
-          else
-            # Ints and bools are always returned as is.
-            memo[pn] = @tracked[pn]
-          end
-        end
-
-        memo
+    def read_buffered_property(prop_name)
+      # Safeguard: If the model is misbehaving and generating something that's not a JSON,
+      # treat response as a normal string.
+      # This is a best-effort to recover from an unexpected scenario.
+      if @partial_json_tracker.broken?
+        unread_chunk = @raw_response[@raw_cursor..]
+        @raw_cursor = @raw_response.length
+        return unread_chunk
+      end
+
+      # Maybe we haven't read that part of the JSON yet.
+      return nil if @tracked[prop_name].blank?
+
+      # This means this property is a string and we want to return unread chunks.
+      if @property_cursors[prop_name].present?
+        unread = @tracked[prop_name][@property_cursors[prop_name]..]
+        @property_cursors[prop_name] = @tracked[prop_name].length
+        unread
+      else
+        # Ints and bools are always returned as is.
+        @tracked[prop_name]
       end
     end

@@ -316,7 +316,7 @@ module DiscourseAi
         response_format
           .to_a
           .reduce({}) do |memo, format|
-            memo[format[:key].to_sym] = { type: format[:type] }
+            memo[format["key"].to_sym] = { type: format["type"] }
             memo
           end

@@ -33,7 +33,7 @@ module DiscourseAi
       end

       def response_format
-        [{ key: "summary", type: "string" }]
+        [{ "key" => "summary", "type" => "string" }]
       end
     end
   end

@@ -34,7 +34,7 @@ module DiscourseAi
       end

       def response_format
-        [{ key: "summary", type: "string" }]
+        [{ "key" => "summary", "type" => "string" }]
       end

       def examples

@@ -116,7 +116,7 @@ module DiscourseAi
           if type == :structured_output
             json_summary_schema_key = bot.persona.response_format&.first.to_h
             partial_summary =
-              partial.read_latest_buffered_chunk[json_summary_schema_key[:key].to_sym]
+              partial.read_buffered_property(json_summary_schema_key["key"]&.to_sym)

             if partial_summary.present?
               summary << partial_summary

@@ -845,7 +845,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
          response_format: schema,
        ) { |partial, cancel| structured_output = partial }

-      expect(structured_output.read_latest_buffered_chunk).to eq({ key: "Hello!" })
+      expect(structured_output.read_buffered_property(:key)).to eq("Hello!")

       expected_body = {
         model: "claude-3-opus-20240229",

@@ -607,7 +607,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
        }
        expect(JSON.parse(request.body)).to eq(expected)

-       expect(structured_output.read_latest_buffered_chunk).to eq({ key: "Hello!" })
+       expect(structured_output.read_buffered_property(:key)).to eq("Hello!")
      end
    end
  end

@@ -366,6 +366,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
      )
      expect(parsed_body[:message]).to eq("user1: thanks")

-     expect(structured_output.read_latest_buffered_chunk).to eq({ key: "Hello!" })
+     expect(structured_output.read_buffered_property(:key)).to eq("Hello!")
    end
  end

@@ -565,12 +565,14 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
          structured_response = partial
        end

-      expect(structured_response.read_latest_buffered_chunk).to eq({ key: "Hello!" })
+      expect(structured_response.read_buffered_property(:key)).to eq("Hello!")

       parsed = JSON.parse(req_body, symbolize_names: true)

       # Verify that schema is passed following Gemini API specs.
-      expect(parsed.dig(:generationConfig, :responseSchema)).to eq(schema)
+      expect(parsed.dig(:generationConfig, :responseSchema)).to eq(
+        schema.dig(:json_schema, :schema).except(:additionalProperties),
+      )
       expect(parsed.dig(:generationConfig, :responseMimeType)).to eq("application/json")
     end
   end

@@ -34,36 +34,50 @@ RSpec.describe DiscourseAi::Completions::StructuredOutput do
      ]

      structured_output << chunks[0]
-     expect(structured_output.read_latest_buffered_chunk).to eq({ message: "Line 1\n" })
+     expect(structured_output.read_buffered_property(:message)).to eq("Line 1\n")

      structured_output << chunks[1]
-     expect(structured_output.read_latest_buffered_chunk).to eq({ message: "Line 2\n" })
+     expect(structured_output.read_buffered_property(:message)).to eq("Line 2\n")

      structured_output << chunks[2]
-     expect(structured_output.read_latest_buffered_chunk).to eq({ message: "Line 3" })
+     expect(structured_output.read_buffered_property(:message)).to eq("Line 3")

      structured_output << chunks[3]
-     expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true })
+     expect(structured_output.read_buffered_property(:bool)).to eq(true)

      # Waiting for number to be fully buffered.
      structured_output << chunks[4]
-     expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true })
+     expect(structured_output.read_buffered_property(:bool)).to eq(true)
+     expect(structured_output.read_buffered_property(:number)).to be_nil

      structured_output << chunks[5]
-     expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true, number: 42 })
+     expect(structured_output.read_buffered_property(:number)).to eq(42)

      structured_output << chunks[6]
-     expect(structured_output.read_latest_buffered_chunk).to eq(
-       { bool: true, number: 42, status: "o" },
-     )
+     expect(structured_output.read_buffered_property(:number)).to eq(42)
+     expect(structured_output.read_buffered_property(:bool)).to eq(true)
+     expect(structured_output.read_buffered_property(:status)).to eq("o")

      structured_output << chunks[7]
-     expect(structured_output.read_latest_buffered_chunk).to eq(
-       { bool: true, number: 42, status: "\"k\"" },
-     )
+     expect(structured_output.read_buffered_property(:status)).to eq("\"k\"")

      # No partial string left to read.
-     expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true, number: 42 })
+     expect(structured_output.read_buffered_property(:status)).to eq("")
+   end
+ end
+
+ describe "dealing with non-JSON responses" do
+   it "treat it as plain text once we determined it's invalid JSON" do
+     chunks = [+"I'm not", +"a", +"JSON :)"]
+
+     structured_output << chunks[0]
+     expect(structured_output.read_buffered_property(nil)).to eq("I'm not")
+
+     structured_output << chunks[1]
+     expect(structured_output.read_buffered_property(nil)).to eq("a")
+
+     structured_output << chunks[2]
+     expect(structured_output.read_buffered_property(nil)).to eq("JSON :)")
    end
  end
end