DEV: Use structured responses for summaries (#1252)
* DEV: Use structured responses for summaries
* Fix system specs
* Make response_format a first class citizen and update endpoints to support it
* Response format can be specified in the persona
* lint
* switch to jsonb and make column nullable
* Reify structured output chunks. Move JSON parsing to the depths of Completion
* Switch to JsonStreamingTracker for partial JSON parsing
This commit is contained in:
parent c6a307b473
commit c0a2d4c935
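For orientation, here is a minimal sketch of the API this commit introduces, pieced together from the diff and the specs further down. The `model` and `user` records are stand-ins for whatever exists in your install; the schema hash mirrors the one the new endpoint specs send.

```ruby
# Sketch only: assumes an existing LlmModel record (`model`) and a `user`.
schema = {
  type: "json_schema",
  json_schema: {
    name: "reply",
    schema: {
      type: "object",
      properties: { key: { type: "string" } },
      required: ["key"],
      additionalProperties: false,
    },
    strict: true,
  },
}

llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
llm.generate("Hello", response_format: schema, user: user) do |partial|
  # With a response_format present, partials arrive as StructuredOutput
  # objects instead of plain strings.
  partial.read_latest_buffered_chunk # => e.g. { key: "Hello!" }
end
```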
@@ -221,6 +221,10 @@ module DiscourseAi
           permitted[:tools] = permit_tools(tools)
         end
 
+        if response_format = params.dig(:ai_persona, :response_format)
+          permitted[:response_format] = permit_response_format(response_format)
+        end
+
         permitted
       end
 
@@ -235,6 +239,18 @@ module DiscourseAi
           [tool, options, !!force_tool]
         end
       end
+
+      def permit_response_format(response_format)
+        return [] if !response_format.is_a?(Array)
+
+        response_format.map do |element|
+          if element && element.is_a?(ActionController::Parameters)
+            element.permit!
+          else
+            false
+          end
+        end
+      end
     end
   end
 end
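For context, the persona payload this controller change starts accepting looks roughly like the hypothetical request below; only the `response_format` part is new, and `permit_response_format` reduces any non-array value to `[]`.

```ruby
# Hypothetical params hash as seen by the admin controller. Each element must
# arrive as ActionController::Parameters to survive permit_response_format.
params = {
  ai_persona: {
    name: "Summarizer",
    response_format: [
      { key: "summary", type: "string" },
    ],
  },
}
```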
@@ -325,9 +325,11 @@ class AiPersona < ActiveRecord::Base
   end
 
   def system_persona_unchangeable
+    error_msg = I18n.t("discourse_ai.ai_bot.personas.cannot_edit_system_persona")
+
     if top_p_changed? || temperature_changed? || system_prompt_changed? || name_changed? ||
          description_changed?
-      errors.add(:base, I18n.t("discourse_ai.ai_bot.personas.cannot_edit_system_persona"))
+      errors.add(:base, error_msg)
     elsif tools_changed?
       old_tools = tools_change[0]
       new_tools = tools_change[1]
@@ -335,9 +337,12 @@ class AiPersona < ActiveRecord::Base
       old_tool_names = old_tools.map { |t| t.is_a?(Array) ? t[0] : t }.to_set
       new_tool_names = new_tools.map { |t| t.is_a?(Array) ? t[0] : t }.to_set
 
-      if old_tool_names != new_tool_names
-        errors.add(:base, I18n.t("discourse_ai.ai_bot.personas.cannot_edit_system_persona"))
-      end
+      errors.add(:base, error_msg) if old_tool_names != new_tool_names
+    elsif response_format_changed?
+      old_format = response_format_change[0].map { |f| f["key"] }.to_set
+      new_format = response_format_change[1].map { |f| f["key"] }.to_set
+
+      errors.add(:base, error_msg) if old_format != new_format
     end
   end
 
@@ -395,6 +400,7 @@ end
 #  rag_llm_model_id             :bigint
 #  default_llm_id               :bigint
 #  question_consolidator_llm_id :bigint
+#  response_format              :jsonb
 #
 # Indexes
 #
@@ -30,7 +30,8 @@ class LocalizedAiPersonaSerializer < ApplicationSerializer
              :allow_chat_direct_messages,
              :allow_topic_mentions,
              :allow_personal_messages,
-             :force_default_llm
+             :force_default_llm,
+             :response_format
 
  has_one :user, serializer: BasicUserSerializer, embed: :object
  has_many :rag_uploads, serializer: UploadSerializer, embed: :object
@@ -33,6 +33,7 @@ const CREATE_ATTRIBUTES = [
   "allow_topic_mentions",
   "allow_chat_channel_mentions",
   "allow_chat_direct_messages",
+  "response_format",
 ];
 
 const SYSTEM_ATTRIBUTES = [
@@ -60,6 +61,7 @@ const SYSTEM_ATTRIBUTES = [
   "allow_topic_mentions",
   "allow_chat_channel_mentions",
   "allow_chat_direct_messages",
+  "response_format",
 ];
 
 export default class AiPersona extends RestModel {
@@ -151,6 +153,7 @@ export default class AiPersona extends RestModel {
     const attrs = this.getProperties(CREATE_ATTRIBUTES);
     this.populateTools(attrs);
     attrs.forced_tool_count = this.forced_tool_count || -1;
+    attrs.response_format = attrs.response_format || [];
 
     return attrs;
   }
@@ -15,6 +15,7 @@ import Group from "discourse/models/group";
 import { i18n } from "discourse-i18n";
 import AdminUser from "admin/models/admin-user";
 import GroupChooser from "select-kit/components/group-chooser";
+import AiPersonaResponseFormatEditor from "../components/modal/ai-persona-response-format-editor";
 import AiLlmSelector from "./ai-llm-selector";
 import AiPersonaToolOptions from "./ai-persona-tool-options";
 import AiToolSelector from "./ai-tool-selector";
@@ -325,6 +326,8 @@ export default class PersonaEditor extends Component {
           <field.Textarea />
         </form.Field>
 
+        <AiPersonaResponseFormatEditor @form={{form}} @data={{data}} />
+
         <form.Field
           @name="default_llm_id"
           @title={{i18n "discourse_ai.ai_persona.default_llm"}}
@@ -0,0 +1,99 @@
+import Component from "@glimmer/component";
+import { tracked } from "@glimmer/tracking";
+import { fn, hash } from "@ember/helper";
+import { action } from "@ember/object";
+import { gt } from "truth-helpers";
+import ModalJsonSchemaEditor from "discourse/components/modal/json-schema-editor";
+import { prettyJSON } from "discourse/lib/formatter";
+import { i18n } from "discourse-i18n";
+
+export default class AiPersonaResponseFormatEditor extends Component {
+  @tracked showJsonEditorModal = false;
+
+  jsonSchema = {
+    type: "array",
+    uniqueItems: true,
+    title: i18n("discourse_ai.ai_persona.response_format.modal.root_title"),
+    items: {
+      type: "object",
+      title: i18n("discourse_ai.ai_persona.response_format.modal.key_title"),
+      properties: {
+        key: {
+          type: "string",
+        },
+        type: {
+          type: "string",
+          enum: ["string", "integer", "boolean"],
+        },
+      },
+    },
+  };
+
+  get editorTitle() {
+    return i18n("discourse_ai.ai_persona.response_format.title");
+  }
+
+  get responseFormatAsJSON() {
+    return JSON.stringify(this.args.data.response_format);
+  }
+
+  get displayJSON() {
+    const toDisplay = {};
+
+    this.args.data.response_format.forEach((keyDesc) => {
+      toDisplay[keyDesc.key] = keyDesc.type;
+    });
+
+    return prettyJSON(toDisplay);
+  }
+
+  @action
+  openModal() {
+    this.showJsonEditorModal = true;
+  }
+
+  @action
+  closeModal() {
+    this.showJsonEditorModal = false;
+  }
+
+  @action
+  updateResponseFormat(form, value) {
+    form.set("response_format", JSON.parse(value));
+  }
+
+  <template>
+    <@form.Container @title={{this.editorTitle}} @format="large">
+      <div class="ai-persona-editor__response-format">
+        {{#if (gt @data.response_format.length 0)}}
+          <pre class="ai-persona-editor__response-format-pre">
+            <code
+            >{{this.displayJSON}}</code>
+          </pre>
+        {{else}}
+          <div class="ai-persona-editor__response-format-none">
+            {{i18n "discourse_ai.ai_persona.response_format.no_format"}}
+          </div>
+        {{/if}}
+
+        <@form.Button
+          @action={{this.openModal}}
+          @label="discourse_ai.ai_persona.response_format.open_modal"
+          @disabled={{@data.system}}
+        />
+      </div>
+    </@form.Container>
+
+    {{#if this.showJsonEditorModal}}
+      <ModalJsonSchemaEditor
+        @model={{hash
+          value=this.responseFormatAsJSON
+          updateValue=(fn this.updateResponseFormat @form)
+          settingName=this.editorTitle
+          jsonSchema=this.jsonSchema
+        }}
+        @closeModal={{this.closeModal}}
+      />
+    {{/if}}
+  </template>
+}
@@ -52,6 +52,21 @@
     margin-bottom: 10px;
     font-size: var(--font-down-1);
   }
 
+  &__response-format {
+    width: 100%;
+    display: block;
+  }
+
+  &__response-format-pre {
+    margin-bottom: 0;
+    white-space: pre-line;
+  }
+
+  &__response-format-none {
+    margin-bottom: 1em;
+    margin-top: 0.5em;
+  }
 }
 
 .rag-options {
@@ -323,6 +323,13 @@ en:
       rag_conversation_chunks: "Search conversation chunks"
       rag_conversation_chunks_help: "The number of chunks to use for the RAG model searches. Increase to increase the amount of context the AI can use."
       persona_description: "Personas are a powerful feature that allows you to customize the behavior of the AI engine in your Discourse forum. They act as a 'system message' that guides the AI's responses and interactions, helping to create a more personalized and engaging user experience."
+      response_format:
+        title: "JSON response format"
+        no_format: "No JSON format specified"
+        open_modal: "Edit"
+        modal:
+          root_title: "Response structure"
+          key_title: "Key"
 
     list:
       enabled: "AI Bot?"
@@ -72,6 +72,8 @@ DiscourseAi::Personas::Persona.system_personas.each do |persona_class, id|
 
       persona.tools = tools.map { |name, value| [name, value] }
 
+      persona.response_format = instance.response_format
+
       persona.system_prompt = instance.system_prompt
       persona.top_p = instance.top_p
       persona.temperature = instance.temperature
@@ -0,0 +1,6 @@
+# frozen_string_literal: true
+class AddResponseFormatJsonToPersonass < ActiveRecord::Migration[7.2]
+  def change
+    add_column :ai_personas, :response_format, :jsonb
+  end
+end
@@ -10,7 +10,7 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
       @raw_json = +""
       @tool_call = DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: {})
       @streaming_parser =
-        DiscourseAi::Completions::ToolCallProgressTracker.new(self) if partial_tool_calls
+        DiscourseAi::Completions::JsonStreamingTracker.new(self) if partial_tool_calls
     end
 
     def append(json)
@@ -42,6 +42,7 @@ module DiscourseAi
         result = { system: system, messages: messages }
         result[:inferenceConfig] = inference_config if inference_config.present?
         result[:toolConfig] = tool_config if tool_config.present?
+        result[:response_format] = { type: "json_object" } if options[:response_format].present?
 
         result
       end
@@ -88,10 +88,16 @@ module DiscourseAi
       def prepare_payload(prompt, model_params, dialect)
         @native_tool_support = dialect.native_tool_support?
 
-        payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
+        payload =
+          default_options(dialect).merge(model_params.except(:response_format)).merge(
+            messages: prompt.messages,
+          )
 
         payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
         payload[:stream] = true if @streaming_mode
 
+        prefilled_message = +""
+
         if prompt.has_tools?
           payload[:tools] = prompt.tools
           if dialect.tool_choice.present?
@@ -100,16 +106,24 @@ module DiscourseAi
 
             # prefill prompt to nudge LLM to generate a response that is useful.
             # without this LLM (even 3.7) can get confused and start text preambles for a tool calls.
-            payload[:messages] << {
-              role: "assistant",
-              content: dialect.no_more_tool_calls_text,
-            }
+            prefilled_message << dialect.no_more_tool_calls_text
           else
             payload[:tool_choice] = { type: "tool", name: prompt.tool_choice }
           end
         end
       end
 
+      # Prefill prompt to force JSON output.
+      if model_params[:response_format].present?
+        prefilled_message << " " if !prefilled_message.empty?
+        prefilled_message << "{"
+        @forced_json_through_prefill = true
+      end
+
+      if !prefilled_message.empty?
+        payload[:messages] << { role: "assistant", content: prefilled_message }
+      end
+
      payload
    end
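The prefill trick above is worth spelling out. Anthropic-style APIs continue the conversation from the last assistant message, so seeding an assistant turn with `{` forces the completion to be the remainder of a JSON object. A self-contained sketch, simplified away from the real endpoint classes:

```ruby
# Minimal illustration of the prefill mechanism; message contents are made up.
messages = [{ role: "user", content: "Summarize this topic" }]

prefilled_message = +""
prefilled_message << "{" # the real code also sets @forced_json_through_prefill = true

messages << { role: "assistant", content: prefilled_message } if !prefilled_message.empty?

# The model's output now starts *after* the "{", e.g. `"summary": "..."}`,
# which is why the base endpoint later feeds the missing brace back in:
#   structured_output << +"{" if structured_output && @forced_json_through_prefill
```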
@@ -116,9 +116,14 @@ module DiscourseAi
           payload = nil
 
           if dialect.is_a?(DiscourseAi::Completions::Dialects::Claude)
-            payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
+            payload =
+              default_options(dialect).merge(model_params.except(:response_format)).merge(
+                messages: prompt.messages,
+              )
             payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
 
+            prefilled_message = +""
+
             if prompt.has_tools?
               payload[:tools] = prompt.tools
               if dialect.tool_choice.present?
@@ -128,15 +133,23 @@ module DiscourseAi
                 # payload[:tool_choice] = { type: "none" }
 
                 # prefill prompt to nudge LLM to generate a response that is useful, instead of trying to call a tool
-                payload[:messages] << {
-                  role: "assistant",
-                  content: dialect.no_more_tool_calls_text,
-                }
+                prefilled_message << dialect.no_more_tool_calls_text
               else
                 payload[:tool_choice] = { type: "tool", name: prompt.tool_choice }
               end
             end
 
+            # Prefill prompt to force JSON output.
+            if model_params[:response_format].present?
+              prefilled_message << " " if !prefilled_message.empty?
+              prefilled_message << "{"
+              @forced_json_through_prefill = true
+            end
+
+            if !prefilled_message.empty?
+              payload[:messages] << { role: "assistant", content: prefilled_message }
+            end
           elsif dialect.is_a?(DiscourseAi::Completions::Dialects::Nova)
             payload = prompt.to_payload(default_options(dialect).merge(model_params))
           else
@@ -73,6 +73,7 @@ module DiscourseAi
         LlmQuota.check_quotas!(@llm_model, user)
         start_time = Time.now
 
+        @forced_json_through_prefill = false
         @partial_tool_calls = partial_tool_calls
         @output_thinking = output_thinking
 
@@ -106,6 +107,18 @@ module DiscourseAi
 
         prompt = dialect.translate
 
+        structured_output = nil
+
+        if model_params[:response_format].present?
+          schema_properties =
+            model_params[:response_format].dig(:json_schema, :schema, :properties)
+
+          if schema_properties.present?
+            structured_output =
+              DiscourseAi::Completions::StructuredOutput.new(schema_properties)
+          end
+        end
+
         FinalDestination::HTTP.start(
           model_uri.host,
           model_uri.port,
@@ -123,6 +136,10 @@ module DiscourseAi
 
           request = prepare_request(request_body)
 
+          # Some providers rely on prefill to return structured outputs, so the start
+          # of the JSON won't be included in the response. Supply it to keep JSON valid.
+          structured_output << +"{" if structured_output && @forced_json_through_prefill
+
          http.request(request) do |response|
            if response.code.to_i != 200
              Rails.logger.error(
@@ -140,10 +157,17 @@ module DiscourseAi
            xml_stripper =
              DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present?
 
-           if @streaming_mode && xml_stripper
+           if @streaming_mode
              blk =
                lambda do |partial, cancel|
-                 partial = xml_stripper << partial if partial.is_a?(String)
+                 if partial.is_a?(String)
+                   partial = xml_stripper << partial if xml_stripper
+
+                   if structured_output.present?
+                     structured_output << partial
+                     partial = structured_output
+                   end
+                 end
                  orig_blk.call(partial, cancel) if partial
                end
            end
@@ -167,6 +191,7 @@ module DiscourseAi
              xml_stripper: xml_stripper,
              partials_raw: partials_raw,
              response_raw: response_raw,
+             structured_output: structured_output,
            )
            return response_data
          end
@@ -373,7 +398,8 @@ module DiscourseAi
         xml_tool_processor:,
         xml_stripper:,
         partials_raw:,
-        response_raw:
+        response_raw:,
+        structured_output:
       )
         response_raw << response.read_body
         response_data = decode(response_raw)
@@ -403,6 +429,26 @@ module DiscourseAi
 
         response_data.reject!(&:blank?)
 
+        if structured_output.present?
+          has_string_response = false
+
+          response_data =
+            response_data.reduce([]) do |memo, data|
+              if data.is_a?(String)
+                structured_output << data
+                has_string_response = true
+                next(memo)
+              else
+                memo << data
+              end
+
+              memo
+            end
+
+          # We only include the structured output if there was actually a structured response
+          response_data << structured_output if has_string_response
+        end
+
         # this is to keep stuff backwards compatible
         response_data = response_data.first if response_data.length == 1
@@ -40,6 +40,8 @@ module DiscourseAi
             "The number of completions you requested exceed the number of canned responses"
         end
 
+        response = as_structured_output(response) if model_params[:response_format].present?
+
         raise response if response.is_a?(StandardError)
 
         @completions += 1
@@ -54,6 +56,8 @@ module DiscourseAi
           yield(response, cancel_fn)
         elsif is_thinking?(response)
           yield(response, cancel_fn)
+        elsif is_structured_output?(response)
+          yield(response, cancel_fn)
         else
           response.each_char do |char|
             break if cancelled
@@ -80,6 +84,20 @@ module DiscourseAi
       def is_tool?(response)
         response.is_a?(DiscourseAi::Completions::ToolCall)
       end
 
+      def is_structured_output?(response)
+        response.is_a?(DiscourseAi::Completions::StructuredOutput)
+      end
+
+      def as_structured_output(response)
+        schema_properties = model_params[:response_format].dig(:json_schema, :schema, :properties)
+        return response if schema_properties.blank?
+
+        output = DiscourseAi::Completions::StructuredOutput.new(schema_properties)
+        output << { schema_properties.keys.first => response }.to_json
+
+        output
+      end
     end
   end
 end
@@ -84,7 +84,16 @@ module DiscourseAi
 
           payload[:tool_config] = { function_calling_config: function_calling_config }
         end
-        payload[:generationConfig].merge!(model_params) if model_params.present?
+        if model_params.present?
+          payload[:generationConfig].merge!(model_params.except(:response_format))
+
+          if model_params[:response_format].present?
+            # https://ai.google.dev/api/generate-content#generationconfig
+            payload[:generationConfig][:responseSchema] = model_params[:response_format]
+            payload[:generationConfig][:responseMimeType] = "application/json"
+          end
+        end
 
         payload
       end
@@ -34,7 +34,12 @@ module DiscourseAi
         end
 
         def prepare_payload(prompt, model_params, dialect)
-          payload = default_options.merge(model_params).merge(messages: prompt)
+          payload =
+            default_options.merge(model_params.except(:response_format)).merge(messages: prompt)
+
+          if model_params[:response_format].present?
+            payload[:response_format] = { type: "json_object" }
+          end
 
           payload[:stream] = true if @streaming_mode
 
@@ -2,11 +2,11 @@
 
 module DiscourseAi
   module Completions
-    class ToolCallProgressTracker
-      attr_reader :current_key, :current_value, :tool_call
+    class JsonStreamingTracker
+      attr_reader :current_key, :current_value, :stream_consumer
 
-      def initialize(tool_call)
-        @tool_call = tool_call
+      def initialize(stream_consumer)
+        @stream_consumer = stream_consumer
         @current_key = nil
         @current_value = nil
         @parser = DiscourseAi::Completions::JsonStreamingParser.new
@@ -18,7 +18,7 @@ module DiscourseAi
 
         @parser.value do |v|
           if @current_key
-            tool_call.notify_progress(@current_key, v)
+            stream_consumer.notify_progress(@current_key, v)
             @current_key = nil
           end
         end
@@ -39,7 +39,7 @@ module DiscourseAi
 
         if @parser.state == :start_string && @current_key
           # this is is worth notifying
-          tool_call.notify_progress(@current_key, @parser.buf)
+          stream_consumer.notify_progress(@current_key, @parser.buf)
         end
 
         @current_key = nil if @parser.state == :end_value
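After the rename, the tracker is no longer coupled to tool calls: anything responding to `notify_progress(key, value)` can consume the partially parsed JSON, which is what lets tool calls and `StructuredOutput` share one parser. A hypothetical consumer (the `LoggingConsumer` class below is illustrative, not part of the commit):

```ruby
class LoggingConsumer
  # Called by JsonStreamingTracker as key/value pairs become available,
  # including partially buffered string values.
  def notify_progress(key, value)
    puts "partial #{key}: #{value.inspect}"
  end
end

tracker = DiscourseAi::Completions::JsonStreamingTracker.new(LoggingConsumer.new)
tracker << "{\"summary\": \"Hel"
tracker << "lo\"}"
# Prints the "summary" value as it streams in, including the partial "Hel".
```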
@@ -304,6 +308,8 @@ module DiscourseAi
       # @param feature_context { Hash - Optional } - The feature context to use for the completion.
       # @param partial_tool_calls { Boolean - Optional } - If true, the completion will return partial tool calls.
       # @param output_thinking { Boolean - Optional } - If true, the completion will return the thinking output for thinking models.
+      # @param response_format { Hash - Optional } - JSON schema passed to the API as the desired structured output.
       # @param [Experimental] extra_model_params { Hash - Optional } - Other params that are not available accross models. e.g. response_format JSON schema.
       #
       # @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response alongside a cancel function.
       #
@@ -321,6 +323,7 @@ module DiscourseAi
         feature_context: nil,
         partial_tool_calls: false,
         output_thinking: false,
+        response_format: nil,
         extra_model_params: nil,
         &partial_read_blk
       )
@@ -336,6 +339,7 @@ module DiscourseAi
             feature_context: feature_context,
             partial_tool_calls: partial_tool_calls,
             output_thinking: output_thinking,
+            response_format: response_format,
             extra_model_params: extra_model_params,
           },
         )
@@ -344,6 +348,7 @@ module DiscourseAi
 
         model_params[:temperature] = temperature if temperature
         model_params[:top_p] = top_p if top_p
+        model_params[:response_format] = response_format if response_format
         model_params.merge!(extra_model_params) if extra_model_params
 
         if prompt.is_a?(String)
@@ -10,7 +10,7 @@ class DiscourseAi::Completions::NovaMessageProcessor
       @raw_json = +""
       @tool_call = DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: {})
       @streaming_parser =
-        DiscourseAi::Completions::ToolCallProgressTracker.new(self) if partial_tool_calls
+        DiscourseAi::Completions::JsonStreamingTracker.new(self) if partial_tool_calls
     end
 
     def append(json)
@@ -59,7 +59,7 @@ module DiscourseAi::Completions
         if id.present? && name.present?
           @tool_arguments = +""
           @tool = ToolCall.new(id: id, name: name)
-          @streaming_parser = ToolCallProgressTracker.new(self) if @partial_tool_calls
+          @streaming_parser = JsonStreamingTracker.new(self) if @partial_tool_calls
         end
 
         @tool_arguments << arguments.to_s
@@ -0,0 +1,52 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Completions
+    class StructuredOutput
+      def initialize(json_schema_properties)
+        @property_names = json_schema_properties.keys.map(&:to_sym)
+        @property_cursors =
+          json_schema_properties.reduce({}) do |m, (k, prop)|
+            m[k.to_sym] = 0 if prop[:type] == "string"
+            m
+          end
+
+        @tracked = {}
+
+        @partial_json_tracker = JsonStreamingTracker.new(self)
+      end
+
+      attr_reader :last_chunk_buffer
+
+      def <<(raw)
+        @partial_json_tracker << raw
+      end
+
+      def read_latest_buffered_chunk
+        @property_names.reduce({}) do |memo, pn|
+          if @tracked[pn].present?
+            # This means this property is a string and we want to return unread chunks.
+            if @property_cursors[pn].present?
+              unread = @tracked[pn][@property_cursors[pn]..]
+
+              memo[pn] = unread if unread.present?
+              @property_cursors[pn] = @tracked[pn].length
+            else
+              # Ints and bools are always returned as is.
+              memo[pn] = @tracked[pn]
+            end
+          end
+
+          memo
+        end
+      end
+
+      def notify_progress(key, value)
+        key_sym = key.to_sym
+        return if !@property_names.include?(key_sym)
+
+        @tracked[key_sym] = value
+      end
+    end
+  end
+end
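Usage in a nutshell, mirroring the spec added near the end of this diff: feed raw chunks in with `<<` and drain string properties incrementally with `read_latest_buffered_chunk`. String values come back as unread deltas; integers and booleans only appear once fully parsed.

```ruby
# Sketch of the incremental-read behavior; chunk contents are made up.
output = DiscourseAi::Completions::StructuredOutput.new({ summary: { type: "string" } })

output << "{\"summary\": \"Line 1"
output.read_latest_buffered_chunk # => { summary: "Line 1" }

output << " and Line 2\"}"
output.read_latest_buffered_chunk # => { summary: " and Line 2" } (only the unread part)
```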
@@ -64,10 +64,13 @@ module DiscourseAi
 
         user = context.user
 
-        llm_kwargs = { user: user }
+        llm_kwargs = llm_args.dup
+        llm_kwargs[:user] = user
         llm_kwargs[:temperature] = persona.temperature if persona.temperature
         llm_kwargs[:top_p] = persona.top_p if persona.top_p
-        llm_kwargs[:max_tokens] = llm_args[:max_tokens] if llm_args[:max_tokens].present?
+        llm_kwargs[:response_format] = build_json_schema(
+          persona.response_format,
+        ) if persona.response_format.present?
 
         needs_newlines = false
         tools_ran = 0
@@ -148,6 +151,8 @@ module DiscourseAi
               raw_context << partial
               current_thinking << partial
             end
+          elsif partial.is_a?(DiscourseAi::Completions::StructuredOutput)
+            update_blk.call(partial, cancel, nil, :structured_output)
           else
             update_blk.call(partial, cancel)
           end
@@ -176,6 +181,10 @@ module DiscourseAi
         embed_thinking(raw_context)
       end
 
+      def returns_json?
+        persona.response_format.present?
+      end
+
       private
 
       def embed_thinking(raw_context)
@@ -301,6 +310,30 @@ module DiscourseAi
 
         placeholder
       end
+
+      def build_json_schema(response_format)
+        properties =
+          response_format
+            .to_a
+            .reduce({}) do |memo, format|
+              memo[format[:key].to_sym] = { type: format[:type] }
+              memo
+            end
+
+        {
+          type: "json_schema",
+          json_schema: {
+            name: "reply",
+            schema: {
+              type: "object",
+              properties: properties,
+              required: properties.keys.map(&:to_s),
+              additionalProperties: false,
+            },
+            strict: true,
+          },
+        }
+      end
     end
   end
 end
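To make `build_json_schema` concrete: the persona-level format `[{ key: "summary", type: "string" }]` expands to the schema hash the endpoints receive, as the method above shows.

```ruby
# Expected expansion, derived directly from build_json_schema above.
build_json_schema([{ key: "summary", type: "string" }])
# =>
# {
#   type: "json_schema",
#   json_schema: {
#     name: "reply",
#     schema: {
#       type: "object",
#       properties: { summary: { type: "string" } },
#       required: ["summary"],
#       additionalProperties: false,
#     },
#     strict: true,
#   },
# }
```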
@@ -49,6 +49,8 @@ module DiscourseAi
       @site_title = site_title
       @site_description = site_description
       @time = time
+      @resource_url = resource_url
 
       @feature_name = feature_name
       @resource_url = resource_url
@@ -160,6 +160,10 @@ module DiscourseAi
         {}
       end
 
+      def response_format
+        nil
+      end
+
       def available_tools
         self
           .class
@@ -18,9 +18,19 @@ module DiscourseAi
           - Limit the summary to a maximum of 40 words.
           - Do *NOT* repeat the discussion title in the summary.
 
-          Return the summary inside <ai></ai> tags.
+          Format your response as a JSON object with a single key named "summary", which has the summary as the value.
+          Your output should be in the following format:
+            <output>
+              {"summary": "xx"}
+            </output>
+
+          Where "xx" is replaced by the summary.
         PROMPT
       end
+
+      def response_format
+        [{ key: "summary", type: "string" }]
+      end
     end
   end
 end
@@ -18,8 +18,20 @@ module DiscourseAi
           - Example: link to the 6th post by jane: [agreed with]({resource_url}/6)
           - Example: link to the 13th post by joe: [joe]({resource_url}/13)
           - When formatting usernames either use @USERNAME OR [USERNAME]({resource_url}/POST_NUMBER)
+
+          Format your response as a JSON object with a single key named "summary", which has the summary as the value.
+          Your output should be in the following format:
+            <output>
+              {"summary": "xx"}
+            </output>
+
+          Where "xx" is replaced by the summary.
         PROMPT
       end
+
+      def response_format
+        [{ key: "summary", type: "string" }]
+      end
     end
   end
 end
@@ -29,18 +29,10 @@ module DiscourseAi
 
         summary = fold(truncated_content, user, &on_partial_blk)
 
-        clean_summary = Nokogiri::HTML5.fragment(summary).css("ai")&.first&.text || summary
-
         if persist_summaries
-          AiSummary.store!(
-            strategy,
-            llm_model,
-            clean_summary,
-            truncated_content,
-            human: user&.human?,
-          )
+          AiSummary.store!(strategy, llm_model, summary, truncated_content, human: user&.human?)
         else
-          AiSummary.new(summarized_text: clean_summary)
+          AiSummary.new(summarized_text: summary)
         end
       end
 
@@ -118,9 +110,20 @@ module DiscourseAi
           )
 
        summary = +""
 
        buffer_blk =
-         Proc.new do |partial, cancel, placeholder, type|
-           if type.blank?
+         Proc.new do |partial, cancel, _, type|
+           if type == :structured_output
+             json_summary_schema_key = bot.persona.response_format&.first.to_h
+             partial_summary =
+               partial.read_latest_buffered_chunk[json_summary_schema_key[:key].to_sym]
+
+             if partial_summary.present?
+               summary << partial_summary
+               on_partial_blk.call(partial_summary, cancel) if on_partial_blk
+             end
+           elsif type.blank?
+             # Assume response is a regular completion.
             summary << partial
             on_partial_blk.call(partial, cancel) if on_partial_blk
           end
@@ -154,12 +157,6 @@ module DiscourseAi
 
          item
        end
 
-       def text_only_update(&on_partial_blk)
-         Proc.new do |partial, cancel, placeholder, type|
-           on_partial_blk.call(partial, cancel) if type.blank?
-         end
-       end
      end
    end
  end
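Putting the summarizer changes together: `buffer_blk` drains the streamed `StructuredOutput` one unread chunk at a time, keyed by the persona's `response_format`. A self-contained sketch of that loop, with made-up chunk contents:

```ruby
# Same read pattern as buffer_blk's :structured_output branch above.
structured = DiscourseAi::Completions::StructuredOutput.new({ summary: { type: "string" } })
summary = +""

["{\"summary\": \"A topic about", " pie.\"}"].each do |chunk|
  structured << chunk
  partial_summary = structured.read_latest_buffered_chunk[:summary]
  summary << partial_summary if partial_summary
end

summary # => "A topic about pie."
```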
@@ -21,6 +21,7 @@ RSpec.describe Jobs::FastTrackTopicGist do
       created_at: 10.minutes.ago,
     )
   end
 
+  let(:updated_gist) { "They updated me :(" }
 
   context "when it's up to date" do
@@ -51,7 +51,9 @@ RSpec.describe Jobs::StreamTopicAiSummary do
     end
 
     it "publishes updates with a partial summary" do
-      with_responses(["dummy"]) do
+      summary = "dummy"
+
+      with_responses([summary]) do
         messages =
           MessageBus.track_publish("/discourse-ai/summaries/topic/#{topic.id}") do
             job.execute(topic_id: topic.id, user_id: user.id)
@@ -59,12 +61,16 @@ RSpec.describe Jobs::StreamTopicAiSummary do
 
         partial_summary_update = messages.first.data
         expect(partial_summary_update[:done]).to eq(false)
-        expect(partial_summary_update.dig(:ai_topic_summary, :summarized_text)).to eq("dummy")
+        expect(partial_summary_update.dig(:ai_topic_summary, :summarized_text).chomp("\"}")).to eq(
+          summary,
+        )
       end
     end
 
     it "publishes a final update to signal we're done and provide metadata" do
-      with_responses(["dummy"]) do
+      summary = "dummy"
+
+      with_responses([summary]) do
        messages =
          MessageBus.track_publish("/discourse-ai/summaries/topic/#{topic.id}") do
            job.execute(topic_id: topic.id, user_id: user.id)
@@ -73,6 +79,7 @@ RSpec.describe Jobs::StreamTopicAiSummary do
        final_update = messages.last.data
        expect(final_update[:done]).to eq(true)
 
+       expect(final_update.dig(:ai_topic_summary, :summarized_text)).to eq(summary)
        expect(final_update.dig(:ai_topic_summary, :algorithm)).to eq("fake")
        expect(final_update.dig(:ai_topic_summary, :outdated)).to eq(false)
        expect(final_update.dig(:ai_topic_summary, :can_regenerate)).to eq(true)
@@ -590,7 +590,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
     data: {"type": "content_block_stop", "index": 0}
 
     event: content_block_start
-    data: {"type":"content_block_start","index":0,"content_block":{"type":"redacted_thinking","data":"AAA=="} }
+    data: {"type":"content_block_start","index":0,"content_block":{"type":"redacted_thinking","data":"AAA=="} }
 
     event: ping
     data: {"type": "ping"}
@@ -769,4 +769,92 @@ data: {"type":"content_block_start","index":0,"content_block":{"type":"redacted_
       expect(result).to eq("I won't use any tools. Here's a direct response instead.")
     end
   end
+
+  describe "structured output via prefilling" do
+    it "forces the response to be a JSON and using the given JSON schema" do
+      schema = {
+        type: "json_schema",
+        json_schema: {
+          name: "reply",
+          schema: {
+            type: "object",
+            properties: {
+              key: {
+                type: "string",
+              },
+            },
+            required: ["key"],
+            additionalProperties: false,
+          },
+          strict: true,
+        },
+      }
+
+      body = (<<~STRING).strip
+        event: message_start
+        data: {"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-opus-20240229", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 25, "output_tokens": 1}}}
+
+        event: content_block_start
+        data: {"type": "content_block_start", "index":0, "content_block": {"type": "text", "text": ""}}
+
+        event: content_block_delta
+        data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "\\""}}
+
+        event: content_block_delta
+        data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "key"}}
+
+        event: content_block_delta
+        data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "\\":\\""}}
+
+        event: content_block_delta
+        data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello!"}}
+
+        event: content_block_delta
+        data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "\\"}"}}
+
+        event: content_block_stop
+        data: {"type": "content_block_stop", "index": 0}
+
+        event: message_delta
+        data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence":null, "usage":{"output_tokens": 15}}}
+
+        event: message_stop
+        data: {"type": "message_stop"}
+      STRING
+
+      parsed_body = nil
+
+      stub_request(:post, url).with(
+        body:
+          proc do |req_body|
+            parsed_body = JSON.parse(req_body, symbolize_names: true)
+            true
+          end,
+        headers: {
+          "Content-Type" => "application/json",
+          "X-Api-Key" => "123",
+          "Anthropic-Version" => "2023-06-01",
+        },
+      ).to_return(status: 200, body: body)
+
+      structured_output = nil
+      llm.generate(
+        prompt,
+        user: Discourse.system_user,
+        feature_name: "testing",
+        response_format: schema,
+      ) { |partial, cancel| structured_output = partial }
+
+      expect(structured_output.read_latest_buffered_chunk).to eq({ key: "Hello!" })
+
+      expected_body = {
+        model: "claude-3-opus-20240229",
+        max_tokens: 4096,
+        messages: [{ role: "user", content: "user1: hello" }, { role: "assistant", content: "{" }],
+        system: "You are hello bot",
+        stream: true,
+      }
+      expect(parsed_body).to eq(expected_body)
+    end
+  end
 end
@@ -546,4 +546,69 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
       )
     end
   end
+
+  describe "structured output via prefilling" do
+    it "forces the response to be a JSON and using the given JSON schema" do
+      schema = {
+        type: "json_schema",
+        json_schema: {
+          name: "reply",
+          schema: {
+            type: "object",
+            properties: {
+              key: {
+                type: "string",
+              },
+            },
+            required: ["key"],
+            additionalProperties: false,
+          },
+          strict: true,
+        },
+      }
+
+      messages =
+        [
+          { type: "message_start", message: { usage: { input_tokens: 9 } } },
+          { type: "content_block_delta", delta: { text: "\"" } },
+          { type: "content_block_delta", delta: { text: "key" } },
+          { type: "content_block_delta", delta: { text: "\":\"" } },
+          { type: "content_block_delta", delta: { text: "Hello!" } },
+          { type: "content_block_delta", delta: { text: "\"}" } },
+          { type: "message_delta", delta: { usage: { output_tokens: 25 } } },
+        ].map { |message| encode_message(message) }
+
+      proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
+      request = nil
+      bedrock_mock.with_chunk_array_support do
+        stub_request(
+          :post,
+          "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke-with-response-stream",
+        )
+          .with do |inner_request|
+            request = inner_request
+            true
+          end
+          .to_return(status: 200, body: messages)
+
+        structured_output = nil
+        proxy.generate("hello world", response_format: schema, user: user) do |partial|
+          structured_output = partial
+        end
+
+        expected = {
+          "max_tokens" => 4096,
+          "anthropic_version" => "bedrock-2023-05-31",
+          "messages" => [
+            { "role" => "user", "content" => "hello world" },
+            { "role" => "assistant", "content" => "{" },
+          ],
+          "system" => "You are a helpful bot",
+        }
+        expect(JSON.parse(request.body)).to eq(expected)
+
+        expect(structured_output.read_latest_buffered_chunk).to eq({ key: "Hello!" })
+      end
+    end
+  end
 end
@@ -308,4 +308,64 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
     expect(audit.request_tokens).to eq(14)
     expect(audit.response_tokens).to eq(11)
   end
+
+  it "is able to return structured outputs" do
+    schema = {
+      type: "json_schema",
+      json_schema: {
+        name: "reply",
+        schema: {
+          type: "object",
+          properties: {
+            key: {
+              type: "string",
+            },
+          },
+          required: ["key"],
+          additionalProperties: false,
+        },
+        strict: true,
+      },
+    }
+
+    body = <<~TEXT
+      {"is_finished":false,"event_type":"stream-start","generation_id":"eb889b0f-c27d-45ea-98cf-567bdb7fc8bf"}
+      {"is_finished":false,"event_type":"text-generation","text":"{\\""}
+      {"is_finished":false,"event_type":"text-generation","text":"key"}
+      {"is_finished":false,"event_type":"text-generation","text":"\\":\\""}
+      {"is_finished":false,"event_type":"text-generation","text":"Hello!"}
+      {"is_finished":false,"event_type":"text-generation","text":"\\"}"}|
+      {"is_finished":true,"event_type":"stream-end","response":{"response_id":"d235db17-8555-493b-8d91-e601f76de3f9","text":"{\\"key\\":\\"Hello!\\"}","generation_id":"eb889b0f-c27d-45ea-98cf-567bdb7fc8bf","chat_history":[{"role":"USER","message":"user1: hello"},{"role":"CHATBOT","message":"hi user"},{"role":"USER","message":"user1: thanks"},{"role":"CHATBOT","message":"You're welcome! Is there anything else I can help you with?"}],"token_count":{"prompt_tokens":29,"response_tokens":14,"total_tokens":43,"billed_tokens":28},"meta":{"api_version":{"version":"1"},"billed_units":{"input_tokens":14,"output_tokens":14}}},"finish_reason":"COMPLETE"}
+    TEXT
+
+    parsed_body = nil
+    structured_output = nil
+
+    EndpointMock.with_chunk_array_support do
+      stub_request(:post, "https://api.cohere.ai/v1/chat").with(
+        body:
+          proc do |req_body|
+            parsed_body = JSON.parse(req_body, symbolize_names: true)
+            true
+          end,
+        headers: {
+          "Content-Type" => "application/json",
+          "Authorization" => "Bearer ABC",
+        },
+      ).to_return(status: 200, body: body.split("|"))
+
+      result =
+        llm.generate(prompt, response_format: schema, user: user) do |partial, cancel|
+          structured_output = partial
+        end
+    end
+
+    expect(parsed_body[:preamble]).to eq("You are hello bot")
+    expect(parsed_body[:chat_history]).to eq(
+      [{ role: "USER", message: "user1: hello" }, { role: "CHATBOT", message: "hi user" }],
+    )
+    expect(parsed_body[:message]).to eq("user1: thanks")
+
+    expect(structured_output.read_latest_buffered_chunk).to eq({ key: "Hello!" })
+  end
 end
@@ -420,7 +420,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
       body:
         proc do |_req_body|
           req_body = _req_body
-          true
+          _req_body
         end,
     ).to_return(status: 200, body: response)
 
@@ -433,4 +433,67 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
       { function_calling_config: { mode: "ANY", allowed_function_names: ["echo"] } },
     )
   end
+
+  describe "structured output via JSON Schema" do
+    it "forces the response to be a JSON" do
+      schema = {
+        type: "json_schema",
+        json_schema: {
+          name: "reply",
+          schema: {
+            type: "object",
+            properties: {
+              key: {
+                type: "string",
+              },
+            },
+            required: ["key"],
+            additionalProperties: false,
+          },
+          strict: true,
+        },
+      }
+
+      response = <<~TEXT.strip
+        data: {"candidates": [{"content": {"parts": [{"text": "{\\""}],"role": "model"}}],"usageMetadata": {"promptTokenCount": 399,"totalTokenCount": 399},"modelVersion": "gemini-1.5-pro-002"}
+
+        data: {"candidates": [{"content": {"parts": [{"text": "key"}],"role": "model"},"safetyRatings": [{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 399,"totalTokenCount": 399},"modelVersion": "gemini-1.5-pro-002"}
+
+        data: {"candidates": [{"content": {"parts": [{"text": "\\":\\""}],"role": "model"},"safetyRatings": [{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 399,"totalTokenCount": 399},"modelVersion": "gemini-1.5-pro-002"}
+
+        data: {"candidates": [{"content": {"parts": [{"text": "Hello!"}],"role": "model"},"finishReason": "STOP"}],"usageMetadata": {"promptTokenCount": 399,"candidatesTokenCount": 191,"totalTokenCount": 590},"modelVersion": "gemini-1.5-pro-002"}
+
+        data: {"candidates": [{"content": {"parts": [{"text": "\\"}"}],"role": "model"},"finishReason": "STOP"}],"usageMetadata": {"promptTokenCount": 399,"candidatesTokenCount": 191,"totalTokenCount": 590},"modelVersion": "gemini-1.5-pro-002"}
+
+        data: {"candidates": [{"finishReason": "MALFORMED_FUNCTION_CALL"}],"usageMetadata": {"promptTokenCount": 399,"candidatesTokenCount": 191,"totalTokenCount": 590},"modelVersion": "gemini-1.5-pro-002"}
+
+      TEXT
+
+      req_body = nil
+
+      llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
+      url = "#{model.url}:streamGenerateContent?alt=sse&key=123"
+
+      stub_request(:post, url).with(
+        body:
+          proc do |_req_body|
+            req_body = _req_body
+            true
+          end,
+      ).to_return(status: 200, body: response)
+
+      structured_response = nil
+      llm.generate("Hello", response_format: schema, user: user) do |partial|
+        structured_response = partial
+      end
+
+      expect(structured_response.read_latest_buffered_chunk).to eq({ key: "Hello!" })
+
+      parsed = JSON.parse(req_body, symbolize_names: true)
+
+      # Verify that schema is passed following Gemini API specs.
+      expect(parsed.dig(:generationConfig, :responseSchema)).to eq(schema)
+      expect(parsed.dig(:generationConfig, :responseMimeType)).to eq("application/json")
+    end
+  end
 end
@@ -0,0 +1,69 @@
+# frozen_string_literal: true
+
+RSpec.describe DiscourseAi::Completions::StructuredOutput do
+  subject(:structured_output) do
+    described_class.new(
+      {
+        message: {
+          type: "string",
+        },
+        bool: {
+          type: "boolean",
+        },
+        number: {
+          type: "integer",
+        },
+        status: {
+          type: "string",
+        },
+      },
+    )
+  end
+
+  describe "Parsing structured output on the fly" do
+    it "acts as a buffer for an streamed JSON" do
+      chunks = [
+        +"{\"message\": \"Line 1\\n",
+        +"Line 2\\n",
+        +"Line 3\", ",
+        +"\"bool\": true,",
+        +"\"number\": 4",
+        +"2,",
+        +"\"status\": \"o",
+        +"\\\"k\\\"\"}",
+      ]
+
+      structured_output << chunks[0]
+      expect(structured_output.read_latest_buffered_chunk).to eq({ message: "Line 1\n" })
+
+      structured_output << chunks[1]
+      expect(structured_output.read_latest_buffered_chunk).to eq({ message: "Line 2\n" })
+
+      structured_output << chunks[2]
+      expect(structured_output.read_latest_buffered_chunk).to eq({ message: "Line 3" })
+
+      structured_output << chunks[3]
+      expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true })
+
+      # Waiting for number to be fully buffered.
+      structured_output << chunks[4]
+      expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true })
+
+      structured_output << chunks[5]
+      expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true, number: 42 })
+
+      structured_output << chunks[6]
+      expect(structured_output.read_latest_buffered_chunk).to eq(
+        { bool: true, number: 42, status: "o" },
+      )
+
+      structured_output << chunks[7]
+      expect(structured_output.read_latest_buffered_chunk).to eq(
+        { bool: true, number: 42, status: "\"k\"" },
+      )
+
+      # No partial string left to read.
+      expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true, number: 42 })
+    end
+  end
+end
@@ -23,17 +23,17 @@ RSpec.describe DiscourseAi::Summarization::FoldContent do
     llm_model.update!(max_prompt_tokens: model_tokens)
   end
 
-  let(:single_summary) { "this is a summary" }
+  let(:summary) { "this is a summary" }
 
   fab!(:user)
 
   it "summarizes the content" do
     result =
-      DiscourseAi::Completions::Llm.with_prepared_responses([single_summary]) do |spy|
+      DiscourseAi::Completions::Llm.with_prepared_responses([summary]) do |spy|
         summarizer.summarize(user).tap { expect(spy.completions).to eq(1) }
       end
 
-    expect(result.summarized_text).to eq(single_summary)
+    expect(result.summarized_text).to eq(summary)
   end
 end
@@ -267,6 +267,7 @@ RSpec.describe AiPersona do
         description: "system persona",
         system_prompt: "system persona",
         tools: %w[Search Time],
+        response_format: [{ key: "summary", type: "string" }],
         system: true,
       )
     end
@@ -285,6 +286,22 @@ RSpec.describe AiPersona do
         I18n.t("discourse_ai.ai_bot.personas.cannot_edit_system_persona"),
       )
     end
+
+    it "doesn't accept response format changes" do
+      new_format = [{ key: "summary2", type: "string" }]
+
+      expect { system_persona.update!(response_format: new_format) }.to raise_error(
+        ActiveRecord::RecordInvalid,
+      )
+    end
+
+    it "doesn't accept additional format changes" do
+      new_format = [{ key: "summary", type: "string" }, { key: "summary2", type: "string" }]
+
+      expect { system_persona.update!(response_format: new_format) }.to raise_error(
+        ActiveRecord::RecordInvalid,
+      )
+    end
   end
 end
@@ -185,6 +185,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
         default_llm_id: llm_model.id,
         question_consolidator_llm_id: llm_model.id,
         forced_tool_count: 2,
+        response_format: [{ key: "summary", type: "string" }],
       }
     end
 
@@ -209,6 +210,9 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
       expect(persona_json["allow_chat_channel_mentions"]).to eq(true)
       expect(persona_json["allow_chat_direct_messages"]).to eq(true)
       expect(persona_json["question_consolidator_llm_id"]).to eq(llm_model.id)
+      expect(persona_json["response_format"].map { |rf| rf["key"] }).to contain_exactly(
+        "summary",
+      )
 
       persona = AiPersona.find(persona_json["id"])
@@ -74,6 +74,7 @@ RSpec.describe DiscourseAi::Summarization::SummaryController do
 
     it "returns a summary" do
+      summary_text = "This is a summary"
 
       DiscourseAi::Completions::Llm.with_prepared_responses([summary_text]) do
         get "/discourse-ai/summarization/t/#{topic.id}.json"
@@ -13,6 +13,8 @@ describe DiscourseAi::TopicSummarization do
 
   let(:strategy) { DiscourseAi::Summarization.topic_summary(topic) }
 
+  let(:summary) { "This is the final summary" }
+
   describe "#summarize" do
     subject(:summarization) { described_class.new(strategy, user) }
 
@@ -27,8 +29,6 @@ describe DiscourseAi::TopicSummarization do
     end
 
     context "when the content was summarized in a single chunk" do
-      let(:summary) { "This is the final summary" }
-
       it "caches the summary" do
         DiscourseAi::Completions::Llm.with_prepared_responses([summary]) do
           section = summarization.summarize
@@ -54,7 +54,6 @@ describe DiscourseAi::TopicSummarization do
 
     describe "invalidating cached summaries" do
       let(:cached_text) { "This is a cached summary" }
-      let(:updated_summary) { "This is the final summary" }
 
       def cached_summary
         AiSummary.find_by(target: topic, summary_type: AiSummary.summary_types[:complete])
@@ -86,10 +85,10 @@ describe DiscourseAi::TopicSummarization do
       before { cached_summary.update!(original_content_sha: "outdated_sha") }
 
       it "returns a new summary" do
-        DiscourseAi::Completions::Llm.with_prepared_responses([updated_summary]) do
+        DiscourseAi::Completions::Llm.with_prepared_responses([summary]) do
          section = summarization.summarize
 
-          expect(section.summarized_text).to eq(updated_summary)
+          expect(section.summarized_text).to eq(summary)
        end
      end
 
@@ -106,10 +105,10 @@ describe DiscourseAi::TopicSummarization do
     end
 
     it "returns a new summary if the skip_age_check flag is passed" do
-      DiscourseAi::Completions::Llm.with_prepared_responses([updated_summary]) do
+      DiscourseAi::Completions::Llm.with_prepared_responses([summary]) do
        section = summarization.summarize(skip_age_check: true)
 
-        expect(section.summarized_text).to eq(updated_summary)
+        expect(section.summarized_text).to eq(summary)
      end
    end
 
@@ -118,8 +117,6 @@ describe DiscourseAi::TopicSummarization do
     end
 
     describe "stream partial updates" do
-      let(:summary) { "This is the final summary" }
-
       it "receives a blk that is passed to the underlying strategy and called with partial summaries" do
         partial_result = +""
 
@@ -127,7 +124,8 @@ describe DiscourseAi::TopicSummarization do
         summarization.summarize { |partial_summary| partial_result << partial_summary }
       end
 
-      expect(partial_result).to eq(summary)
+      # In a real world example, this is removed in the returned AiSummary obj.
+      expect(partial_result.chomp("\"}")).to eq(summary)
     end
   end
 end
@@ -12,7 +12,9 @@ RSpec.describe "Summarize a topic ", type: :system do
       "I like to eat pie. It is a very good dessert. Some people are wasteful by throwing pie at others but I do not do that. I always eat the pie.",
     )
   end
 
+  let(:summarization_result) { "This is a summary" }
+
   let(:topic_page) { PageObjects::Pages::Topic.new }
   let(:summary_box) { PageObjects::Components::AiSummaryTrigger.new }