Add proper support for array types in structured outputs

This commit is contained in:
Roman Rizzi 2025-05-29 17:57:30 -03:00
parent 5682e8db0d
commit a73e9e6724
No known key found for this signature in database
GPG Key ID: 64024A71CE7330D3
11 changed files with 193 additions and 20 deletions

View File

@ -22,10 +22,20 @@ export default class AiPersonaResponseFormatEditor extends Component {
type: "string",
},
type: {
type: "string",
enum: ["string", "integer", "boolean", "array"],
},
array_type: {
type: "string",
enum: ["string", "integer", "boolean"],
options: {
dependencies: {
type: "array",
},
},
},
},
required: ["key", "type"],
},
};
@ -41,7 +51,11 @@ export default class AiPersonaResponseFormatEditor extends Component {
const toDisplay = {};
this.args.data.response_format.forEach((keyDesc) => {
toDisplay[keyDesc.key] = keyDesc.type;
if (keyDesc.type === "array") {
toDisplay[keyDesc.key] = `[${keyDesc.array_type}]`;
} else {
toDisplay[keyDesc.key] = keyDesc.type;
}
});
return prettyJSON(toDisplay);

View File

@ -9,6 +9,7 @@ module DiscourseAi
@stream_consumer = stream_consumer
@current_key = nil
@current_value = nil
@tracking_array = false
@parser = DiscourseAi::Completions::JsonStreamingParser.new
@parser.key do |k|
@ -16,12 +17,28 @@ module DiscourseAi
@current_value = nil
end
@parser.value do |v|
@parser.value do |value|
if @current_key
stream_consumer.notify_progress(@current_key, v)
@current_key = nil
if @tracking_array
@current_value << value
stream_consumer.notify_progress(@current_key, @current_value)
else
stream_consumer.notify_progress(@current_key, value)
@current_key = nil
end
end
end
@parser.start_array do
@tracking_array = true
@current_value = []
end
@parser.end_array do
@tracking_array = false
@current_key = nil
@current_value = nil
end
end
def broken?
@ -46,8 +63,9 @@ module DiscourseAi
end
if @parser.state == :start_string && @current_key
buffered = @tracking_array ? [@parser.buf] : @parser.buf
# this is is worth notifying
stream_consumer.notify_progress(@current_key, @parser.buf)
stream_consumer.notify_progress(@current_key, buffered)
end
@current_key = nil if @parser.state == :end_value

View File

@ -45,7 +45,7 @@ module DiscourseAi
@property_cursors[prop_name] = @tracked[prop_name].length
unread
else
# Ints and bools are always returned as is.
# Ints and bools, and arrays are always returned as is.
@tracked[prop_name]
end
end

View File

@ -122,12 +122,13 @@ module DiscourseAi
)
bot = DiscourseAi::Personas::Bot.as(Discourse.system_user, persona: persona, model: llm)
structured_output = nil
response = bot.reply(context)
bot.reply(context) do |partial, _, type|
structured_output = partial if type == :structured_output
end
matching_concepts = JSON.parse(response[0][0]).dig("matching_concepts")
matching_concepts || []
structured_output&.read_buffered_property(:matching_concepts) || []
end
end
end

View File

@ -24,11 +24,13 @@ module DiscourseAi
)
bot = DiscourseAi::Personas::Bot.as(Discourse.system_user, persona: persona, model: llm)
structured_output = nil
response = bot.reply(context)
bot.reply(context) do |partial, _, type|
structured_output = partial if type == :structured_output
end
concepts = JSON.parse(response[0][0]).dig("concepts")
concepts || []
structured_output&.read_buffered_property(:concepts) || []
end
# Creates or finds concepts in the database from provided names
@ -161,10 +163,13 @@ module DiscourseAi
DiscourseAi::Personas::BotContext.new(messages: [input], user: Discourse.system_user)
bot = DiscourseAi::Personas::Bot.as(Discourse.system_user, persona: persona, model: llm)
structured_output = nil
response = bot.reply(context)
bot.reply(context) do |partial, _, type|
structured_output = partial if type == :structured_output
end
concepts = JSON.parse(response[0][0]).dig("streamlined_tags")
structured_output&.read_buffered_property(:streamlined_tags) || []
end
end
end

View File

@ -318,8 +318,13 @@ module DiscourseAi
response_format
.to_a
.reduce({}) do |memo, format|
memo[format["key"].to_sym] = { type: format["type"] }
memo[format["key"].to_sym][:items] = format["items"] if format["items"]
type_desc = { type: format["type"] }
if format["type"] == "array"
type_desc[:items] = { type: format["array_type"] || "string" }
end
memo[format["key"].to_sym] = type_desc
memo
end

View File

@ -46,7 +46,7 @@ module DiscourseAi
end
def response_format
[{ "key" => "streamlined_tags", "type" => "array" }]
[{ "key" => "streamlined_tags", "type" => "array", "array_type" => "string" }]
end
end
end

View File

@ -42,7 +42,7 @@ module DiscourseAi
end
def response_format
[{ "key" => "concepts", "type" => "array", "items" => { "type" => "string" } }]
[{ "key" => "concepts", "type" => "array", "array_type" => "string" }]
end
end
end

View File

@ -36,7 +36,7 @@ module DiscourseAi
end
def response_format
[{ "key" => "matching_concepts", "type" => "array" }]
[{ "key" => "matching_concepts", "type" => "array", "array_type" => "string" }]
end
end
end

View File

@ -672,5 +672,87 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
expect(structured_output.read_buffered_property(:key)).to eq("Hello!\n There")
end
end
it "works with JSON schema array types" do
schema = {
type: "json_schema",
json_schema: {
name: "reply",
schema: {
type: "object",
properties: {
plain: {
type: "string",
},
key: {
type: "array",
items: {
type: "string",
},
},
},
required: %w[plain key],
additionalProperties: false,
},
strict: true,
},
}
messages =
[
{ type: "message_start", message: { usage: { input_tokens: 9 } } },
{ type: "content_block_delta", delta: { text: "\"" } },
{ type: "content_block_delta", delta: { text: "key" } },
{ type: "content_block_delta", delta: { text: "\":" } },
{ type: "content_block_delta", delta: { text: " [\"" } },
{ type: "content_block_delta", delta: { text: "Hello!" } },
{ type: "content_block_delta", delta: { text: " I am" } },
{ type: "content_block_delta", delta: { text: " a " } },
{ type: "content_block_delta", delta: { text: "chunk\"," } },
{ type: "content_block_delta", delta: { text: "\"There" } },
{ type: "content_block_delta", delta: { text: "\"]," } },
{ type: "content_block_delta", delta: { text: " \"plain" } },
{ type: "content_block_delta", delta: { text: "\":\"" } },
{ type: "content_block_delta", delta: { text: "I'm here" } },
{ type: "content_block_delta", delta: { text: " too\"}" } },
{ type: "message_delta", delta: { usage: { output_tokens: 25 } } },
].map { |message| encode_message(message) }
proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
request = nil
bedrock_mock.with_chunk_array_support do
stub_request(
:post,
"https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke-with-response-stream",
)
.with do |inner_request|
request = inner_request
true
end
.to_return(status: 200, body: messages)
structured_output = nil
proxy.generate("hello world", response_format: schema, user: user) do |partial|
structured_output = partial
end
expected = {
"max_tokens" => 4096,
"anthropic_version" => "bedrock-2023-05-31",
"messages" => [
{ "role" => "user", "content" => "hello world" },
{ "role" => "assistant", "content" => "{" },
],
"system" => "You are a helpful bot",
}
expect(JSON.parse(request.body)).to eq(expected)
expect(structured_output.read_buffered_property(:key)).to contain_exactly(
"Hello! I am a chunk",
"There",
)
expect(structured_output.read_buffered_property(:plain)).to eq("I'm here too")
end
end
end
end

View File

@ -16,6 +16,12 @@ RSpec.describe DiscourseAi::Completions::StructuredOutput do
status: {
type: "string",
},
list: {
type: "array",
items: {
type: "string",
},
},
},
)
end
@ -64,6 +70,48 @@ RSpec.describe DiscourseAi::Completions::StructuredOutput do
# No partial string left to read.
expect(structured_output.read_buffered_property(:status)).to eq("")
end
it "supports array types" do
chunks = [
+"{ \"",
+"list",
+"\":",
+" [\"",
+"Hello!",
+" I am",
+" a ",
+"chunk\",",
+"\"There\"",
+"]}",
]
structured_output << chunks[0]
structured_output << chunks[1]
structured_output << chunks[2]
expect(structured_output.read_buffered_property(:list)).to eq(nil)
structured_output << chunks[3]
expect(structured_output.read_buffered_property(:list)).to eq([""])
structured_output << chunks[4]
expect(structured_output.read_buffered_property(:list)).to eq(["Hello!"])
structured_output << chunks[5]
structured_output << chunks[6]
structured_output << chunks[7]
expect(structured_output.read_buffered_property(:list)).to eq(["Hello! I am a chunk"])
structured_output << chunks[8]
expect(structured_output.read_buffered_property(:list)).to eq(
["Hello! I am a chunk", "There"],
)
structured_output << chunks[9]
expect(structured_output.read_buffered_property(:list)).to eq(
["Hello! I am a chunk", "There"],
)
end
end
describe "dealing with non-JSON responses" do