Add proper support for array types in structured outputs
This commit is contained in:
parent
5682e8db0d
commit
a73e9e6724
|
@ -22,10 +22,20 @@ export default class AiPersonaResponseFormatEditor extends Component {
|
|||
type: "string",
|
||||
},
|
||||
type: {
|
||||
type: "string",
|
||||
enum: ["string", "integer", "boolean", "array"],
|
||||
},
|
||||
array_type: {
|
||||
type: "string",
|
||||
enum: ["string", "integer", "boolean"],
|
||||
options: {
|
||||
dependencies: {
|
||||
type: "array",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
required: ["key", "type"],
|
||||
},
|
||||
};
|
||||
|
||||
|
@ -41,7 +51,11 @@ export default class AiPersonaResponseFormatEditor extends Component {
|
|||
const toDisplay = {};
|
||||
|
||||
this.args.data.response_format.forEach((keyDesc) => {
|
||||
toDisplay[keyDesc.key] = keyDesc.type;
|
||||
if (keyDesc.type === "array") {
|
||||
toDisplay[keyDesc.key] = `[${keyDesc.array_type}]`;
|
||||
} else {
|
||||
toDisplay[keyDesc.key] = keyDesc.type;
|
||||
}
|
||||
});
|
||||
|
||||
return prettyJSON(toDisplay);
|
||||
|
|
|
@ -9,6 +9,7 @@ module DiscourseAi
|
|||
@stream_consumer = stream_consumer
|
||||
@current_key = nil
|
||||
@current_value = nil
|
||||
@tracking_array = false
|
||||
@parser = DiscourseAi::Completions::JsonStreamingParser.new
|
||||
|
||||
@parser.key do |k|
|
||||
|
@ -16,12 +17,28 @@ module DiscourseAi
|
|||
@current_value = nil
|
||||
end
|
||||
|
||||
@parser.value do |v|
|
||||
@parser.value do |value|
|
||||
if @current_key
|
||||
stream_consumer.notify_progress(@current_key, v)
|
||||
@current_key = nil
|
||||
if @tracking_array
|
||||
@current_value << value
|
||||
stream_consumer.notify_progress(@current_key, @current_value)
|
||||
else
|
||||
stream_consumer.notify_progress(@current_key, value)
|
||||
@current_key = nil
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
@parser.start_array do
|
||||
@tracking_array = true
|
||||
@current_value = []
|
||||
end
|
||||
|
||||
@parser.end_array do
|
||||
@tracking_array = false
|
||||
@current_key = nil
|
||||
@current_value = nil
|
||||
end
|
||||
end
|
||||
|
||||
def broken?
|
||||
|
@ -46,8 +63,9 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
if @parser.state == :start_string && @current_key
|
||||
buffered = @tracking_array ? [@parser.buf] : @parser.buf
|
||||
# this is worth notifying
|
||||
stream_consumer.notify_progress(@current_key, @parser.buf)
|
||||
stream_consumer.notify_progress(@current_key, buffered)
|
||||
end
|
||||
|
||||
@current_key = nil if @parser.state == :end_value
|
||||
|
|
|
@ -45,7 +45,7 @@ module DiscourseAi
|
|||
@property_cursors[prop_name] = @tracked[prop_name].length
|
||||
unread
|
||||
else
|
||||
# Ints and bools are always returned as is.
|
||||
# Ints, bools, and arrays are always returned as is.
|
||||
@tracked[prop_name]
|
||||
end
|
||||
end
|
||||
|
|
|
@ -122,12 +122,13 @@ module DiscourseAi
|
|||
)
|
||||
|
||||
bot = DiscourseAi::Personas::Bot.as(Discourse.system_user, persona: persona, model: llm)
|
||||
structured_output = nil
|
||||
|
||||
response = bot.reply(context)
|
||||
bot.reply(context) do |partial, _, type|
|
||||
structured_output = partial if type == :structured_output
|
||||
end
|
||||
|
||||
matching_concepts = JSON.parse(response[0][0]).dig("matching_concepts")
|
||||
|
||||
matching_concepts || []
|
||||
structured_output&.read_buffered_property(:matching_concepts) || []
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -24,11 +24,13 @@ module DiscourseAi
|
|||
)
|
||||
|
||||
bot = DiscourseAi::Personas::Bot.as(Discourse.system_user, persona: persona, model: llm)
|
||||
structured_output = nil
|
||||
|
||||
response = bot.reply(context)
|
||||
bot.reply(context) do |partial, _, type|
|
||||
structured_output = partial if type == :structured_output
|
||||
end
|
||||
|
||||
concepts = JSON.parse(response[0][0]).dig("concepts")
|
||||
concepts || []
|
||||
structured_output&.read_buffered_property(:concepts) || []
|
||||
end
|
||||
|
||||
# Creates or finds concepts in the database from provided names
|
||||
|
@ -161,10 +163,13 @@ module DiscourseAi
|
|||
DiscourseAi::Personas::BotContext.new(messages: [input], user: Discourse.system_user)
|
||||
|
||||
bot = DiscourseAi::Personas::Bot.as(Discourse.system_user, persona: persona, model: llm)
|
||||
structured_output = nil
|
||||
|
||||
response = bot.reply(context)
|
||||
bot.reply(context) do |partial, _, type|
|
||||
structured_output = partial if type == :structured_output
|
||||
end
|
||||
|
||||
concepts = JSON.parse(response[0][0]).dig("streamlined_tags")
|
||||
structured_output&.read_buffered_property(:streamlined_tags) || []
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -318,8 +318,13 @@ module DiscourseAi
|
|||
response_format
|
||||
.to_a
|
||||
.reduce({}) do |memo, format|
|
||||
memo[format["key"].to_sym] = { type: format["type"] }
|
||||
memo[format["key"].to_sym][:items] = format["items"] if format["items"]
|
||||
type_desc = { type: format["type"] }
|
||||
|
||||
if format["type"] == "array"
|
||||
type_desc[:items] = { type: format["array_type"] || "string" }
|
||||
end
|
||||
|
||||
memo[format["key"].to_sym] = type_desc
|
||||
memo
|
||||
end
|
||||
|
||||
|
|
|
@ -46,7 +46,7 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def response_format
|
||||
[{ "key" => "streamlined_tags", "type" => "array" }]
|
||||
[{ "key" => "streamlined_tags", "type" => "array", "array_type" => "string" }]
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -42,7 +42,7 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def response_format
|
||||
[{ "key" => "concepts", "type" => "array", "items" => { "type" => "string" } }]
|
||||
[{ "key" => "concepts", "type" => "array", "array_type" => "string" }]
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -36,7 +36,7 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def response_format
|
||||
[{ "key" => "matching_concepts", "type" => "array" }]
|
||||
[{ "key" => "matching_concepts", "type" => "array", "array_type" => "string" }]
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -672,5 +672,87 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
|||
expect(structured_output.read_buffered_property(:key)).to eq("Hello!\n There")
|
||||
end
|
||||
end
|
||||
|
||||
it "works with JSON schema array types" do
|
||||
schema = {
|
||||
type: "json_schema",
|
||||
json_schema: {
|
||||
name: "reply",
|
||||
schema: {
|
||||
type: "object",
|
||||
properties: {
|
||||
plain: {
|
||||
type: "string",
|
||||
},
|
||||
key: {
|
||||
type: "array",
|
||||
items: {
|
||||
type: "string",
|
||||
},
|
||||
},
|
||||
},
|
||||
required: %w[plain key],
|
||||
additionalProperties: false,
|
||||
},
|
||||
strict: true,
|
||||
},
|
||||
}
|
||||
|
||||
messages =
|
||||
[
|
||||
{ type: "message_start", message: { usage: { input_tokens: 9 } } },
|
||||
{ type: "content_block_delta", delta: { text: "\"" } },
|
||||
{ type: "content_block_delta", delta: { text: "key" } },
|
||||
{ type: "content_block_delta", delta: { text: "\":" } },
|
||||
{ type: "content_block_delta", delta: { text: " [\"" } },
|
||||
{ type: "content_block_delta", delta: { text: "Hello!" } },
|
||||
{ type: "content_block_delta", delta: { text: " I am" } },
|
||||
{ type: "content_block_delta", delta: { text: " a " } },
|
||||
{ type: "content_block_delta", delta: { text: "chunk\"," } },
|
||||
{ type: "content_block_delta", delta: { text: "\"There" } },
|
||||
{ type: "content_block_delta", delta: { text: "\"]," } },
|
||||
{ type: "content_block_delta", delta: { text: " \"plain" } },
|
||||
{ type: "content_block_delta", delta: { text: "\":\"" } },
|
||||
{ type: "content_block_delta", delta: { text: "I'm here" } },
|
||||
{ type: "content_block_delta", delta: { text: " too\"}" } },
|
||||
{ type: "message_delta", delta: { usage: { output_tokens: 25 } } },
|
||||
].map { |message| encode_message(message) }
|
||||
|
||||
proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
|
||||
request = nil
|
||||
bedrock_mock.with_chunk_array_support do
|
||||
stub_request(
|
||||
:post,
|
||||
"https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke-with-response-stream",
|
||||
)
|
||||
.with do |inner_request|
|
||||
request = inner_request
|
||||
true
|
||||
end
|
||||
.to_return(status: 200, body: messages)
|
||||
|
||||
structured_output = nil
|
||||
proxy.generate("hello world", response_format: schema, user: user) do |partial|
|
||||
structured_output = partial
|
||||
end
|
||||
|
||||
expected = {
|
||||
"max_tokens" => 4096,
|
||||
"anthropic_version" => "bedrock-2023-05-31",
|
||||
"messages" => [
|
||||
{ "role" => "user", "content" => "hello world" },
|
||||
{ "role" => "assistant", "content" => "{" },
|
||||
],
|
||||
"system" => "You are a helpful bot",
|
||||
}
|
||||
expect(JSON.parse(request.body)).to eq(expected)
|
||||
|
||||
expect(structured_output.read_buffered_property(:key)).to contain_exactly(
|
||||
"Hello! I am a chunk",
|
||||
"There",
|
||||
)
|
||||
expect(structured_output.read_buffered_property(:plain)).to eq("I'm here too")
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -16,6 +16,12 @@ RSpec.describe DiscourseAi::Completions::StructuredOutput do
|
|||
status: {
|
||||
type: "string",
|
||||
},
|
||||
list: {
|
||||
type: "array",
|
||||
items: {
|
||||
type: "string",
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
end
|
||||
|
@ -64,6 +70,48 @@ RSpec.describe DiscourseAi::Completions::StructuredOutput do
|
|||
# No partial string left to read.
|
||||
expect(structured_output.read_buffered_property(:status)).to eq("")
|
||||
end
|
||||
|
||||
it "supports array types" do
|
||||
chunks = [
|
||||
+"{ \"",
|
||||
+"list",
|
||||
+"\":",
|
||||
+" [\"",
|
||||
+"Hello!",
|
||||
+" I am",
|
||||
+" a ",
|
||||
+"chunk\",",
|
||||
+"\"There\"",
|
||||
+"]}",
|
||||
]
|
||||
|
||||
structured_output << chunks[0]
|
||||
structured_output << chunks[1]
|
||||
structured_output << chunks[2]
|
||||
expect(structured_output.read_buffered_property(:list)).to eq(nil)
|
||||
|
||||
structured_output << chunks[3]
|
||||
expect(structured_output.read_buffered_property(:list)).to eq([""])
|
||||
|
||||
structured_output << chunks[4]
|
||||
expect(structured_output.read_buffered_property(:list)).to eq(["Hello!"])
|
||||
|
||||
structured_output << chunks[5]
|
||||
structured_output << chunks[6]
|
||||
structured_output << chunks[7]
|
||||
|
||||
expect(structured_output.read_buffered_property(:list)).to eq(["Hello! I am a chunk"])
|
||||
|
||||
structured_output << chunks[8]
|
||||
expect(structured_output.read_buffered_property(:list)).to eq(
|
||||
["Hello! I am a chunk", "There"],
|
||||
)
|
||||
|
||||
structured_output << chunks[9]
|
||||
expect(structured_output.read_buffered_property(:list)).to eq(
|
||||
["Hello! I am a chunk", "There"],
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
describe "dealing with non-JSON responses" do
|
||||
|
|
Loading…
Reference in New Issue