diff --git a/app/controllers/discourse_ai/ai_bot/bot_controller.rb b/app/controllers/discourse_ai/ai_bot/bot_controller.rb
index e5d5bcf0..5ea13795 100644
--- a/app/controllers/discourse_ai/ai_bot/bot_controller.rb
+++ b/app/controllers/discourse_ai/ai_bot/bot_controller.rb
@@ -6,6 +6,14 @@ module DiscourseAi
requires_plugin ::DiscourseAi::PLUGIN_NAME
requires_login
+ def show_debug_info_by_id
+ log = AiApiAuditLog.find(params[:id])
+ raise Discourse::NotFound if !log.topic
+
+ guardian.ensure_can_debug_ai_bot_conversation!(log.topic)
+ render json: AiApiAuditLogSerializer.new(log, root: false), status: 200
+ end
+
def show_debug_info
post = Post.find(params[:post_id])
guardian.ensure_can_debug_ai_bot_conversation!(post)
diff --git a/app/models/ai_api_audit_log.rb b/app/models/ai_api_audit_log.rb
index 2fa9f5c3..2fa0a214 100644
--- a/app/models/ai_api_audit_log.rb
+++ b/app/models/ai_api_audit_log.rb
@@ -14,6 +14,14 @@ class AiApiAuditLog < ActiveRecord::Base
Ollama = 7
SambaNova = 8
end
+
+ def next_log_id
+ self.class.where("id > ?", id).where(topic_id: topic_id).order(id: :asc).pluck(:id).first
+ end
+
+ def prev_log_id
+ self.class.where("id < ?", id).where(topic_id: topic_id).order(id: :desc).pluck(:id).first
+ end
end
# == Schema Information
diff --git a/app/serializers/ai_api_audit_log_serializer.rb b/app/serializers/ai_api_audit_log_serializer.rb
index 0c438a7b..eeb3843a 100644
--- a/app/serializers/ai_api_audit_log_serializer.rb
+++ b/app/serializers/ai_api_audit_log_serializer.rb
@@ -12,5 +12,7 @@ class AiApiAuditLogSerializer < ApplicationSerializer
:post_id,
:feature_name,
:language_model,
- :created_at
+ :created_at,
+ :prev_log_id,
+ :next_log_id
end
diff --git a/assets/javascripts/discourse/components/modal/debug-ai-modal.gjs b/assets/javascripts/discourse/components/modal/debug-ai-modal.gjs
index 5d0cdf69..c21e8df3 100644
--- a/assets/javascripts/discourse/components/modal/debug-ai-modal.gjs
+++ b/assets/javascripts/discourse/components/modal/debug-ai-modal.gjs
@@ -7,6 +7,7 @@ import { htmlSafe } from "@ember/template";
import DButton from "discourse/components/d-button";
import DModal from "discourse/components/d-modal";
import { ajax } from "discourse/lib/ajax";
+import { popupAjaxError } from "discourse/lib/ajax-error";
import { clipboardCopy, escapeExpression } from "discourse/lib/utilities";
import i18n from "discourse-common/helpers/i18n";
import discourseLater from "discourse-common/lib/later";
@@ -63,6 +64,28 @@ export default class DebugAiModal extends Component {
this.copy(this.info.raw_response_payload);
}
+ async loadLog(logId) {
+ try {
+ this.info = await ajax(`/discourse-ai/ai-bot/show-debug-info/${logId}.json`);
+ } catch (e) {
+ popupAjaxError(e);
+ }
+ }
+
+ @action
+ prevLog() {
+ this.loadLog(this.info.prev_log_id);
+ }
+
+ @action
+ nextLog() {
+ this.loadLog(this.info.next_log_id);
+ }
+
copy(text) {
clipboardCopy(text);
this.justCopiedText = I18n.t("discourse_ai.ai_bot.conversation_shared");
@@ -73,11 +96,13 @@ export default class DebugAiModal extends Component {
}
loadApiRequestInfo() {
- ajax(
- `/discourse-ai/ai-bot/post/${this.args.model.id}/show-debug-info.json`
- ).then((result) => {
- this.info = result;
- });
+ ajax(`/discourse-ai/ai-bot/post/${this.args.model.id}/show-debug-info.json`)
+ .then((result) => {
+ this.info = result;
+ })
+ .catch((e) => {
+ popupAjaxError(e);
+ });
}
get requestActive() {
@@ -147,6 +172,22 @@ export default class DebugAiModal extends Component {
<DButton
@action={{this.copyResponse}}
@label="discourse_ai.ai_bot.debug_ai_modal.copy_response"
/>
+ {{#if this.info.prev_log_id}}
+ <DButton
+ @action={{this.prevLog}}
+ @label="discourse_ai.ai_bot.debug_ai_modal.previous_log"
+ />
+ {{/if}}
+ {{#if this.info.next_log_id}}
+ <DButton
+ @action={{this.nextLog}}
+ @label="discourse_ai.ai_bot.debug_ai_modal.next_log"
+ />
+ {{/if}}
{{this.justCopiedText}}
diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml
index 2d517946..82898c91 100644
--- a/config/locales/client.en.yml
+++ b/config/locales/client.en.yml
@@ -415,6 +415,8 @@ en:
response_tokens: "Response tokens:"
request: "Request"
response: "Response"
+ next_log: "Next"
+ previous_log: "Previous"
share_full_topic_modal:
title: "Share Conversation Publicly"
diff --git a/config/routes.rb b/config/routes.rb
index a5c009ff..322e67ce 100644
--- a/config/routes.rb
+++ b/config/routes.rb
@@ -22,6 +22,7 @@ DiscourseAi::Engine.routes.draw do
scope module: :ai_bot, path: "/ai-bot", defaults: { format: :json } do
get "bot-username" => "bot#show_bot_username"
get "post/:post_id/show-debug-info" => "bot#show_debug_info"
+ get "show-debug-info/:id" => "bot#show_debug_info_by_id"
post "post/:post_id/stop-streaming" => "bot#stop_streaming_response"
end
diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb
index 834ae059..b965b1f6 100644
--- a/lib/ai_bot/bot.rb
+++ b/lib/ai_bot/bot.rb
@@ -100,6 +100,7 @@ module DiscourseAi
llm_kwargs[:top_p] = persona.top_p if persona.top_p
needs_newlines = false
+ tools_ran = 0
while total_completions <= MAX_COMPLETIONS && ongoing_chain
tool_found = false
@@ -107,9 +108,10 @@ module DiscourseAi
result =
llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel|
- tools = persona.find_tools(partial, bot_user: user, llm: llm, context: context)
+ tool = persona.find_tool(partial, bot_user: user, llm: llm, context: context)
+ tool = nil if tools_ran >= MAX_TOOLS
- if (tools.present?)
+ if tool.present?
tool_found = true
# a bit hacky, but extra newlines do no harm
if needs_newlines
@@ -117,13 +119,16 @@ module DiscourseAi
needs_newlines = false
end
- tools[0..MAX_TOOLS].each do |tool|
- process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
- ongoing_chain &&= tool.chain_next_response?
- end
+ process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
+ tools_ran += 1
+ ongoing_chain &&= tool.chain_next_response?
else
needs_newlines = true
- update_blk.call(partial, cancel)
+ if partial.is_a?(DiscourseAi::Completions::ToolCall)
+ Rails.logger.warn("DiscourseAi: Tool not found: #{partial.name}")
+ else
+ update_blk.call(partial, cancel)
+ end
end
end
diff --git a/lib/ai_bot/personas/persona.rb b/lib/ai_bot/personas/persona.rb
index 73224808..63255a17 100644
--- a/lib/ai_bot/personas/persona.rb
+++ b/lib/ai_bot/personas/persona.rb
@@ -199,23 +199,16 @@ module DiscourseAi
prompt
end
- def find_tools(partial, bot_user:, llm:, context:)
- return [] if !partial.include?("</invoke>")
-
- parsed_function = Nokogiri::HTML5.fragment(partial)
- parsed_function
- .css("invoke")
- .map do |fragment|
- tool_instance(fragment, bot_user: bot_user, llm: llm, context: context)
- end
- .compact
+ def find_tool(partial, bot_user:, llm:, context:)
+ return nil if !partial.is_a?(DiscourseAi::Completions::ToolCall)
+ tool_instance(partial, bot_user: bot_user, llm: llm, context: context)
end
protected
- def tool_instance(parsed_function, bot_user:, llm:, context:)
- function_id = parsed_function.at("tool_id")&.text
- function_name = parsed_function.at("tool_name")&.text
+ def tool_instance(tool_call, bot_user:, llm:, context:)
+ function_id = tool_call.id
+ function_name = tool_call.name
return nil if function_name.nil?
tool_klass = available_tools.find { |c| c.signature.dig(:name) == function_name }
@@ -224,7 +217,7 @@ module DiscourseAi
arguments = {}
tool_klass.signature[:parameters].to_a.each do |param|
name = param[:name]
- value = parsed_function.at(name)&.text
+ value = tool_call.parameters[name.to_sym]
if param[:type] == "array" && value
value =
diff --git a/lib/completions/anthropic_message_processor.rb b/lib/completions/anthropic_message_processor.rb
index 1d1516fa..5d5602ef 100644
--- a/lib/completions/anthropic_message_processor.rb
+++ b/lib/completions/anthropic_message_processor.rb
@@ -13,6 +13,11 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
def append(json)
@raw_json << json
end
+
+ def to_tool_call
+ parameters = JSON.parse(raw_json, symbolize_names: true)
+ DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: parameters)
+ end
end
attr_reader :tool_calls, :input_tokens, :output_tokens
@@ -20,80 +25,69 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
def initialize(streaming_mode:)
@streaming_mode = streaming_mode
@tool_calls = []
+ @current_tool_call = nil
end
- def to_xml_tool_calls(function_buffer)
- return function_buffer if @tool_calls.blank?
+ def to_tool_calls
+ @tool_calls.map { |tool_call| tool_call.to_tool_call }
+ end
- function_buffer = Nokogiri::HTML5.fragment(<<~TEXT)
- <function_calls>
- </function_calls>
- TEXT
-
- @tool_calls.each do |tool_call|
- node =
- function_buffer.at("function_calls").add_child(
- Nokogiri::HTML5::DocumentFragment.parse(
- DiscourseAi::Completions::Endpoints::Base.noop_function_call_text + "\n",
- ),
- )
-
- params = JSON.parse(tool_call.raw_json, symbolize_names: true)
- xml =
- params.map { |name, value| "<#{name}>#{CGI.escapeHTML(value.to_s)}</#{name}>" }.join("\n")
-
- node.at("tool_name").content = tool_call.name
- node.at("tool_id").content = tool_call.id
- node.at("parameters").children = Nokogiri::HTML5::DocumentFragment.parse(xml) if xml.present?
+ def process_streamed_message(parsed)
+ result = nil
+ if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use"
+ tool_name = parsed.dig(:content_block, :name)
+ tool_id = parsed.dig(:content_block, :id)
+ result = @current_tool_call.to_tool_call if @current_tool_call
+ @current_tool_call = AnthropicToolCall.new(tool_name, tool_id) if tool_name
+ elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
+ if @current_tool_call
+ tool_delta = parsed.dig(:delta, :partial_json).to_s
+ @current_tool_call.append(tool_delta)
+ else
+ result = parsed.dig(:delta, :text).to_s
+ end
+ elsif parsed[:type] == "content_block_stop"
+ if @current_tool_call
+ result = @current_tool_call.to_tool_call
+ @current_tool_call = nil
+ end
+ elsif parsed[:type] == "message_start"
+ @input_tokens = parsed.dig(:message, :usage, :input_tokens)
+ elsif parsed[:type] == "message_delta"
+ @output_tokens =
+ parsed.dig(:usage, :output_tokens) || parsed.dig(:delta, :usage, :output_tokens)
+ elsif parsed[:type] == "message_stop"
+ # bedrock has this ...
+ if bedrock_stats = parsed.dig("amazon-bedrock-invocationMetrics".to_sym)
+ @input_tokens = bedrock_stats[:inputTokenCount] || @input_tokens
+ @output_tokens = bedrock_stats[:outputTokenCount] || @output_tokens
+ end
end
-
- function_buffer
+ result
end
def process_message(payload)
result = ""
- parsed = JSON.parse(payload, symbolize_names: true)
+ parsed = payload
+ parsed = JSON.parse(payload, symbolize_names: true) if payload.is_a?(String)
- if @streaming_mode
- if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use"
- tool_name = parsed.dig(:content_block, :name)
- tool_id = parsed.dig(:content_block, :id)
- @tool_calls << AnthropicToolCall.new(tool_name, tool_id) if tool_name
- elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
- if @tool_calls.present?
- result = parsed.dig(:delta, :partial_json).to_s
- @tool_calls.last.append(result)
- else
- result = parsed.dig(:delta, :text).to_s
+ content = parsed.dig(:content)
+ if content.is_a?(Array)
+ result =
+ content.map do |data|
+ if data[:type] == "tool_use"
+ call = AnthropicToolCall.new(data[:name], data[:id])
+ call.append(data[:input].to_json)
+ call.to_tool_call
+ else
+ data[:text]
+ end
end
- elsif parsed[:type] == "message_start"
- @input_tokens = parsed.dig(:message, :usage, :input_tokens)
- elsif parsed[:type] == "message_delta"
- @output_tokens =
- parsed.dig(:usage, :output_tokens) || parsed.dig(:delta, :usage, :output_tokens)
- elsif parsed[:type] == "message_stop"
- # bedrock has this ...
- if bedrock_stats = parsed.dig("amazon-bedrock-invocationMetrics".to_sym)
- @input_tokens = bedrock_stats[:inputTokenCount] || @input_tokens
- @output_tokens = bedrock_stats[:outputTokenCount] || @output_tokens
- end
- end
- else
- content = parsed.dig(:content)
- if content.is_a?(Array)
- tool_call = content.find { |c| c[:type] == "tool_use" }
- if tool_call
- @tool_calls << AnthropicToolCall.new(tool_call[:name], tool_call[:id])
- @tool_calls.last.append(tool_call[:input].to_json)
- else
- result = parsed.dig(:content, 0, :text).to_s
- end
- end
-
- @input_tokens = parsed.dig(:usage, :input_tokens)
- @output_tokens = parsed.dig(:usage, :output_tokens)
end
+ @input_tokens = parsed.dig(:usage, :input_tokens)
+ @output_tokens = parsed.dig(:usage, :output_tokens)
+
result
end
end
diff --git a/lib/completions/dialects/ollama.rb b/lib/completions/dialects/ollama.rb
index 541d0e73..3a32e592 100644
--- a/lib/completions/dialects/ollama.rb
+++ b/lib/completions/dialects/ollama.rb
@@ -63,8 +63,23 @@ module DiscourseAi
def user_msg(msg)
user_message = { role: "user", content: msg[:content] }
- # TODO: Add support for user messages with empbeded user ids
- # TODO: Add support for user messages with attachments
+ encoded_uploads = prompt.encoded_uploads(msg)
+ if encoded_uploads.present?
+ images =
+ encoded_uploads
+ .map do |upload|
+ if upload[:mime_type].start_with?("image/")
+ upload[:base64]
+ else
+ nil
+ end
+ end
+ .compact
+
+ user_message[:images] = images if images.present?
+ end
+
+ # TODO: Add support for user messages with embedded user ids
user_message
end
diff --git a/lib/completions/endpoints/anthropic.rb b/lib/completions/endpoints/anthropic.rb
index 44762b88..6576ef3b 100644
--- a/lib/completions/endpoints/anthropic.rb
+++ b/lib/completions/endpoints/anthropic.rb
@@ -63,6 +63,10 @@ module DiscourseAi
URI(llm_model.url)
end
+ def xml_tools_enabled?
+ !@native_tool_support
+ end
+
def prepare_payload(prompt, model_params, dialect)
@native_tool_support = dialect.native_tool_support?
@@ -90,35 +94,34 @@ module DiscourseAi
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
end
+ def decode_chunk(partial_data)
+ @decoder ||= JsonStreamDecoder.new
+ (@decoder << partial_data)
+ .map { |parsed_json| processor.process_streamed_message(parsed_json) }
+ .compact
+ end
+
+ def decode(response_data)
+ processor.process_message(response_data)
+ end
+
def processor
@processor ||=
DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
end
- def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
- processor.to_xml_tool_calls(function_buffer) if !partial
- end
-
- def extract_completion_from(response_raw)
- processor.process_message(response_raw)
- end
-
def has_tool?(_response_data)
processor.tool_calls.present?
end
+ def tool_calls
+ processor.to_tool_calls
+ end
+
def final_log_update(log)
log.request_tokens = processor.input_tokens if processor.input_tokens
log.response_tokens = processor.output_tokens if processor.output_tokens
end
-
- def native_tool_support?
- @native_tool_support
- end
-
- def partials_from(decoded_chunk)
- decoded_chunk.split("\n").map { |line| line.split("data: ", 2)[1] }.compact
- end
end
end
end
diff --git a/lib/completions/endpoints/aws_bedrock.rb b/lib/completions/endpoints/aws_bedrock.rb
index f3146c2d..c17a051f 100644
--- a/lib/completions/endpoints/aws_bedrock.rb
+++ b/lib/completions/endpoints/aws_bedrock.rb
@@ -117,7 +117,24 @@ module DiscourseAi
end
end
- def decode(chunk)
+ def decode_chunk(partial_data)
+ bedrock_decode(partial_data)
+ .map do |decoded_partial_data|
+ @raw_response ||= +""
+ @raw_response << decoded_partial_data
+ @raw_response << "\n"
+
+ parsed_json = JSON.parse(decoded_partial_data, symbolize_names: true)
+ processor.process_streamed_message(parsed_json)
+ end
+ .compact
+ end
+
+ def decode(response_data)
+ processor.process_message(response_data)
+ end
+
+ def bedrock_decode(chunk)
@decoder ||= Aws::EventStream::Decoder.new
decoded, _done = @decoder.decode_chunk(chunk)
@@ -147,12 +164,13 @@ module DiscourseAi
Aws::EventStream::Errors::MessageChecksumError,
Aws::EventStream::Errors::PreludeChecksumError => e
Rails.logger.error("#{self.class.name}: #{e.message}")
- nil
+ []
end
def final_log_update(log)
log.request_tokens = processor.input_tokens if processor.input_tokens
log.response_tokens = processor.output_tokens if processor.output_tokens
+ log.raw_response_payload = @raw_response
end
def processor
@@ -160,30 +178,8 @@ module DiscourseAi
DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
end
- def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
- processor.to_xml_tool_calls(function_buffer) if !partial
- end
-
- def extract_completion_from(response_raw)
- processor.process_message(response_raw)
- end
-
- def has_tool?(_response_data)
- processor.tool_calls.present?
- end
-
- def partials_from(decoded_chunks)
- decoded_chunks
- end
-
- def native_tool_support?
- @native_tool_support
- end
-
- def chunk_to_string(chunk)
- joined = +chunk.join("\n")
- joined << "\n" if joined.length > 0
- joined
+ def xml_tools_enabled?
+ !@native_tool_support
end
end
end
diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb
index a0405b42..c78fcdd9 100644
--- a/lib/completions/endpoints/base.rb
+++ b/lib/completions/endpoints/base.rb
@@ -40,10 +40,6 @@ module DiscourseAi
@llm_model = llm_model
end
- def native_tool_support?
- false
- end
-
def use_ssl?
if model_uri&.scheme.present?
model_uri.scheme == "https"
@@ -64,22 +60,10 @@ module DiscourseAi
feature_context: nil,
&blk
)
- allow_tools = dialect.prompt.has_tools?
model_params = normalize_model_params(model_params)
orig_blk = blk
@streaming_mode = block_given?
- to_strip = xml_tags_to_strip(dialect)
- @xml_stripper =
- DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present?
-
- if @streaming_mode && @xml_stripper
- blk =
- lambda do |partial, cancel|
- partial = @xml_stripper << partial
- orig_blk.call(partial, cancel) if partial
- end
- end
prompt = dialect.translate
@@ -108,177 +92,91 @@ module DiscourseAi
raise CompletionFailed, response.body
end
+ xml_tool_processor = XmlToolProcessor.new if xml_tools_enabled? &&
+ dialect.prompt.has_tools?
+
+ to_strip = xml_tags_to_strip(dialect)
+ xml_stripper =
+ DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present?
+
+ if @streaming_mode && xml_stripper
+ blk =
+ lambda do |partial, cancel|
+ partial = xml_stripper << partial if partial.is_a?(String)
+ orig_blk.call(partial, cancel) if partial
+ end
+ end
+
log =
- AiApiAuditLog.new(
+ start_log(
provider_id: provider_id,
- user_id: user&.id,
- raw_request_payload: request_body,
- request_tokens: prompt_size(prompt),
- topic_id: dialect.prompt.topic_id,
- post_id: dialect.prompt.post_id,
+ request_body: request_body,
+ dialect: dialect,
+ prompt: prompt,
+ user: user,
feature_name: feature_name,
- language_model: llm_model.name,
- feature_context: feature_context.present? ? feature_context.as_json : nil,
+ feature_context: feature_context,
)
if !@streaming_mode
- response_raw = response.read_body
- response_data = extract_completion_from(response_raw)
- partials_raw = response_data.to_s
-
- if native_tool_support?
- if allow_tools && has_tool?(response_data)
- function_buffer = build_buffer # Nokogiri document
- function_buffer =
- add_to_function_buffer(function_buffer, payload: response_data)
- FunctionCallNormalizer.normalize_function_ids!(function_buffer)
-
- response_data = +function_buffer.at("function_calls").to_s
- response_data << "\n"
- end
- else
- if allow_tools
- response_data, function_calls = FunctionCallNormalizer.normalize(response_data)
- response_data = function_calls if function_calls.present?
- end
- end
-
- return response_data
+ return(
+ non_streaming_response(
+ response: response,
+ xml_tool_processor: xml_tool_processor,
+ xml_stripper: xml_stripper,
+ partials_raw: partials_raw,
+ response_raw: response_raw,
+ )
+ )
end
- has_tool = false
-
begin
cancelled = false
cancel = -> { cancelled = true }
-
- wrapped_blk = ->(partial, inner_cancel) do
- response_data << partial
- blk.call(partial, inner_cancel)
- end
-
- normalizer = FunctionCallNormalizer.new(wrapped_blk, cancel)
-
- leftover = ""
- function_buffer = build_buffer # Nokogiri document
- prev_processed_partials = 0
-
response.read_body do |chunk|
if cancelled
http.finish
break
end

- decoded_chunk = decode(chunk)
- if decoded_chunk.nil?
- raise CompletionFailed, "#{self.class.name}: Failed to decode LLM completion"
- end
- response_raw << chunk_to_string(decoded_chunk)
-
- if decoded_chunk.is_a?(String)
- redo_chunk = leftover + decoded_chunk
- else
- # custom implementation for endpoint
- # no implicit leftover support
- redo_chunk = decoded_chunk
- end
-
- raw_partials = partials_from(redo_chunk)
-
- raw_partials =
- raw_partials[prev_processed_partials..-1] if prev_processed_partials > 0
-
- if raw_partials.blank? || (raw_partials.size == 1 && raw_partials.first.blank?)
- leftover = redo_chunk
- next
- end
-
- json_error = false
-
- raw_partials.each do |raw_partial|
- json_error = false
- prev_processed_partials += 1
-
- next if cancelled
- next if raw_partial.blank?
-
- begin
- partial = extract_completion_from(raw_partial)
- next if partial.nil?
- # empty vs blank... we still accept " "
- next if response_data.empty? && partial.empty?
- partials_raw << partial.to_s
-
- if native_tool_support?
- # Stop streaming the response as soon as you find a tool.
- # We'll buffer and yield it later.
- has_tool = true if allow_tools && has_tool?(partials_raw)
-
- if has_tool
- function_buffer =
- add_to_function_buffer(function_buffer, partial: partial)
- else
- response_data << partial
- blk.call(partial, cancel) if partial
- end
- else
- if allow_tools
- normalizer << partial
- else
- response_data << partial
- blk.call(partial, cancel) if partial
- end
+ response_raw << chunk
+ decode_chunk(chunk).each do |partial|
+ partials_raw << partial.to_s
+ response_data << partial if partial.is_a?(String)
+ partials = [partial]
+ if xml_tool_processor && partial.is_a?(String)
+ partials = (xml_tool_processor << partial)
+ if xml_tool_processor.should_cancel?
+ cancel.call
+ break
end
- rescue JSON::ParserError
- leftover = redo_chunk
- json_error = true
end
+ partials.each { |inner_partial| blk.call(inner_partial, cancel) }
end
-
- if json_error
- prev_processed_partials -= 1
- else
- leftover = ""
- end
-
- prev_processed_partials = 0 if leftover.blank?
end
rescue IOError, StandardError
raise if !cancelled
end
-
- has_tool ||= has_tool?(partials_raw)
- # Once we have the full response, try to return the tool as a XML doc.
- if has_tool && native_tool_support?
- function_buffer = add_to_function_buffer(function_buffer, payload: partials_raw)
-
- if function_buffer.at("tool_name").text.present?
- FunctionCallNormalizer.normalize_function_ids!(function_buffer)
-
- invocation = +function_buffer.at("function_calls").to_s
- invocation << "\n"
-
- response_data << invocation
- blk.call(invocation, cancel)
+ if xml_stripper
+ stripped = xml_stripper.finish
+ if stripped.present?
+ response_data << stripped
+ result = []
+ result = (xml_tool_processor << stripped) if xml_tool_processor
+ result.each { |partial| blk.call(partial, cancel) }
end
end
-
- if !native_tool_support? && function_calls = normalizer.function_calls
- response_data << function_calls
- blk.call(function_calls, cancel)
+ if xml_tool_processor
+ xml_tool_processor.finish.each { |partial| blk.call(partial, cancel) }
end
-
- if @xml_stripper
- leftover = @xml_stripper.finish
- orig_blk.call(leftover, cancel) if leftover.present?
- end
-
+ decode_chunk_finish.each { |partial| blk.call(partial, cancel) }
return response_data
ensure
if log
log.raw_response_payload = response_raw
- log.response_tokens = tokenizer.size(partials_raw)
final_log_update(log)
+
+ log.response_tokens = tokenizer.size(partials_raw) if log.response_tokens.blank?
log.save!
if Rails.env.development?
@@ -330,15 +228,15 @@ module DiscourseAi
raise NotImplementedError
end
- def extract_completion_from(_response_raw)
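+ # decodes a full non-streaming response body into an array of partials
+ # (strings and ToolCall objects)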
+ def decode(_response_raw)
raise NotImplementedError
end
- def decode(chunk)
- chunk
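+ # called once the stream is exhausted so processors can flush any
+ # buffered partials (for example a pending tool call)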
+ def decode_chunk_finish
+ []
end
- def partials_from(_decoded_chunk)
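+ # decodes a raw streaming chunk into an array of partials; implementations
+ # may buffer incomplete data between calls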
+ def decode_chunk(_chunk)
raise NotImplementedError
end
@@ -346,49 +244,73 @@ module DiscourseAi
prompt.map { |message| message[:content] || message["content"] || "" }.join("\n")
end
- def build_buffer
- Nokogiri::HTML5.fragment(<<~TEXT)
- <function_calls>
- #{noop_function_call_text}
- </function_calls>
- TEXT
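+ # endpoints with no native tool support return true here and rely on the
+ # XML tool protocol implemented by XmlToolProcessor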
+ def xml_tools_enabled?
+ raise NotImplementedError
end
- def self.noop_function_call_text
- (<<~TEXT).strip
- <invoke>
- <tool_name></tool_name>
- <parameters>
- </parameters>
- <tool_id></tool_id>
- </invoke>
- TEXT
+ private
+
+ def start_log(
+ provider_id:,
+ request_body:,
+ dialect:,
+ prompt:,
+ user:,
+ feature_name:,
+ feature_context:
+ )
+ AiApiAuditLog.new(
+ provider_id: provider_id,
+ user_id: user&.id,
+ raw_request_payload: request_body,
+ request_tokens: prompt_size(prompt),
+ topic_id: dialect.prompt.topic_id,
+ post_id: dialect.prompt.post_id,
+ feature_name: feature_name,
+ language_model: llm_model.name,
+ feature_context: feature_context.present? ? feature_context.as_json : nil,
+ )
end
- def noop_function_call_text
- self.class.noop_function_call_text
- end
+ def non_streaming_response(
+ response:,
+ xml_tool_processor:,
+ xml_stripper:,
+ partials_raw:,
+ response_raw:
+ )
+ response_raw << response.read_body
+ response_data = decode(response_raw)
- def has_tool?(response)
- response.include?("<function_calls>")
- end
+ response_data.each { |partial| partials_raw << partial.to_s }
- def chunk_to_string(chunk)
- if chunk.is_a?(String)
- chunk
- else
- chunk.to_s
- end
- end
-
- def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
- if payload&.include?("</invoke>")
- matches = payload.match(%r{<function_calls>.*</invoke>}m)
- function_buffer =
- Nokogiri::HTML5.fragment(matches[0] + "\n</function_calls>") if matches
+ if xml_tool_processor
+ response_data.each do |partial|
+ processed = (xml_tool_processor << partial)
+ processed << xml_tool_processor.finish
+ response_data = []
+ processed.flatten.compact.each { |inner| response_data << inner }
+ end
end
- function_buffer
+ if xml_stripper
+ response_data.map! do |partial|
+ stripped = (xml_stripper << partial) if partial.is_a?(String)
+ if stripped.present?
+ stripped
+ else
+ partial
+ end
+ end
+ response_data << xml_stripper.finish
+ end
+
+ response_data.reject!(&:blank?)
+
+ # this is to keep stuff backwards compatible
+ response_data = response_data.first if response_data.length == 1
+
+ response_data
end
end
end
diff --git a/lib/completions/endpoints/canned_response.rb b/lib/completions/endpoints/canned_response.rb
index eaef21da..bd3ae4ea 100644
--- a/lib/completions/endpoints/canned_response.rb
+++ b/lib/completions/endpoints/canned_response.rb
@@ -45,17 +45,21 @@ module DiscourseAi
cancel_fn = lambda { cancelled = true }
# We buffer and return tool invocations in one go.
- if is_tool?(response)
- yield(response, cancel_fn)
- else
- response.each_char do |char|
- break if cancelled
- yield(char, cancel_fn)
+ as_array = response.is_a?(Array) ? response : [response]
+ as_array.each do |response|
+ if is_tool?(response)
+ yield(response, cancel_fn)
+ else
+ response.each_char do |char|
+ break if cancelled
+ yield(char, cancel_fn)
+ end
end
end
- else
- response
end
+
+ response = response.first if response.is_a?(Array) && response.length == 1
+ response
end
def tokenizer
@@ -65,7 +69,7 @@ module DiscourseAi
private
def is_tool?(response)
- Nokogiri::HTML5.fragment(response).at("function_calls").present?
+ response.is_a?(DiscourseAi::Completions::ToolCall)
end
end
end
diff --git a/lib/completions/endpoints/cohere.rb b/lib/completions/endpoints/cohere.rb
index 180c27c8..258062a1 100644
--- a/lib/completions/endpoints/cohere.rb
+++ b/lib/completions/endpoints/cohere.rb
@@ -49,6 +49,47 @@ module DiscourseAi
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
end
+ def decode(response_raw)
+ rval = []
+
+ parsed = JSON.parse(response_raw, symbolize_names: true)
+
+ text = parsed[:text]
+ rval << parsed[:text] if !text.to_s.empty? # also allow " "
+
+ # TODO tool calls
+
+ update_usage(parsed)
+
+ rval
+ end
+
+ def decode_chunk(chunk)
+ @tool_idx ||= -1
+ @json_decoder ||= JsonStreamDecoder.new(line_regex: /^\s*({.*})$/)
+ (@json_decoder << chunk)
+ .map do |parsed|
+ update_usage(parsed)
+ rval = []
+
+ rval << parsed[:text] if !parsed[:text].to_s.empty?
+
+ if tool_calls = parsed[:tool_calls]
+ tool_calls&.each do |tool_call|
+ @tool_idx += 1
+ tool_name = tool_call[:name]
+ tool_params = tool_call[:parameters]
+ tool_id = "tool_#{@tool_idx}"
+ rval << ToolCall.new(id: tool_id, name: tool_name, parameters: tool_params)
+ end
+ end
+
+ rval
+ end
+ .flatten
+ .compact
+ end
+
def extract_completion_from(response_raw)
parsed = JSON.parse(response_raw, symbolize_names: true)
@@ -77,36 +118,8 @@ module DiscourseAi
end
end
- def has_tool?(_ignored)
- @has_tool
- end
-
- def native_tool_support?
- true
- end
-
- def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
- if partial
- tools = JSON.parse(partial)
- tools.each do |tool|
- name = tool["name"]
- parameters = tool["parameters"]
- xml_params = parameters.map { |k, v| "<#{k}>#{v}</#{k}>\n" }.join
-
- current_function = function_buffer.at("invoke")
- if current_function.nil? || current_function.at("tool_name").content.present?
- current_function =
- function_buffer.at("function_calls").add_child(
- Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
- )
- end
-
- current_function.at("tool_name").content = name == "search_local" ? "search" : name
- current_function.at("parameters").children =
- Nokogiri::HTML5::DocumentFragment.parse(xml_params)
- end
- end
- function_buffer
+ def xml_tools_enabled?
+ false
end
def final_log_update(log)
@@ -114,10 +127,6 @@ module DiscourseAi
log.response_tokens = @output_tokens if @output_tokens
end
- def partials_from(decoded_chunk)
- decoded_chunk.split("\n").compact
- end
-
def extract_prompt_for_tokenizer(prompt)
text = +""
if prompt[:chat_history]
@@ -131,6 +140,18 @@ module DiscourseAi
text
end
+
+ private
+
+ def update_usage(parsed)
+ input_tokens = parsed.dig(:meta, :billed_units, :input_tokens)
+ input_tokens ||= parsed.dig(:response, :meta, :billed_units, :input_tokens)
+ @input_tokens = input_tokens if input_tokens.present?
+
+ output_tokens = parsed.dig(:meta, :billed_units, :output_tokens)
+ output_tokens ||= parsed.dig(:response, :meta, :billed_units, :output_tokens)
+ @output_tokens = output_tokens if output_tokens.present?
+ end
end
end
end
diff --git a/lib/completions/endpoints/fake.rb b/lib/completions/endpoints/fake.rb
index a51ff3ac..15cc254d 100644
--- a/lib/completions/endpoints/fake.rb
+++ b/lib/completions/endpoints/fake.rb
@@ -133,31 +133,35 @@ module DiscourseAi
content = content.shift if content.is_a?(Array)
if block_given?
- split_indices = (1...content.length).to_a.sample(self.class.chunk_count - 1).sort
- indexes = [0, *split_indices, content.length]
+ if content.is_a?(DiscourseAi::Completions::ToolCall)
+ yield(content, -> {})
+ else
+ split_indices = (1...content.length).to_a.sample(self.class.chunk_count - 1).sort
+ indexes = [0, *split_indices, content.length]
- original_content = content
- content = +""
+ original_content = content
+ content = +""
- cancel = false
- cancel_proc = -> { cancel = true }
+ cancel = false
+ cancel_proc = -> { cancel = true }
- i = 0
- indexes
- .each_cons(2)
- .map { |start, finish| original_content[start...finish] }
- .each do |chunk|
- break if cancel
- if self.class.delays.present? &&
- (delay = self.class.delays[i % self.class.delays.length])
- sleep(delay)
- i += 1
+ i = 0
+ indexes
+ .each_cons(2)
+ .map { |start, finish| original_content[start...finish] }
+ .each do |chunk|
+ break if cancel
+ if self.class.delays.present? &&
+ (delay = self.class.delays[i % self.class.delays.length])
+ sleep(delay)
+ i += 1
+ end
+ break if cancel
+
+ content << chunk
+ yield(chunk, cancel_proc)
end
- break if cancel
-
- content << chunk
- yield(chunk, cancel_proc)
- end
+ end
end
content
diff --git a/lib/completions/endpoints/gemini.rb b/lib/completions/endpoints/gemini.rb
index ddf607b2..2450dc99 100644
--- a/lib/completions/endpoints/gemini.rb
+++ b/lib/completions/endpoints/gemini.rb
@@ -103,15 +103,7 @@ module DiscourseAi
end
end
- def partials_from(decoded_chunk)
- decoded_chunk
- end
-
- def chunk_to_string(chunk)
- chunk.to_s
- end
-
- class Decoder
+ class GeminiStreamingDecoder
def initialize
@buffer = +""
end
@@ -151,43 +143,87 @@ module DiscourseAi
end
def decode(chunk)
- @decoder ||= Decoder.new
- @decoder.decode(chunk)
+ json = JSON.parse(chunk, symbolize_names: true)
+ idx = -1
+ json
+ .dig(:candidates, 0, :content, :parts)
+ .map do |part|
+ if part[:functionCall]
+ idx += 1
+ ToolCall.new(
+ id: "tool_#{idx}",
+ name: part[:functionCall][:name],
+ parameters: part[:functionCall][:args],
+ )
+ else
+ part = part[:text]
+ if part != ""
+ part
+ else
+ nil
+ end
+ end
+ end
+ end
+
+ def decode_chunk(chunk)
+ @tool_index ||= -1
+
+ streaming_decoder
+ .decode(chunk)
+ .map do |parsed|
+ update_usage(parsed)
+ parsed
+ .dig(:candidates, 0, :content, :parts)
+ .map do |part|
+ if part[:text]
+ part = part[:text]
+ if part != ""
+ part
+ else
+ nil
+ end
+ elsif part[:functionCall]
+ @tool_index += 1
+ ToolCall.new(
+ id: "tool_#{@tool_index}",
+ name: part[:functionCall][:name],
+ parameters: part[:functionCall][:args],
+ )
+ end
+ end
+ end
+ .flatten
+ .compact
+ end
+
+ def update_usage(parsed)
+ usage = parsed.dig(:usageMetadata)
+ if usage
+ if prompt_token_count = usage[:promptTokenCount]
+ @prompt_token_count = prompt_token_count
+ end
+ if candidate_token_count = usage[:candidatesTokenCount]
+ @candidate_token_count = candidate_token_count
+ end
+ end
+ end
+
+ def final_log_update(log)
+ log.request_tokens = @prompt_token_count if @prompt_token_count
+ log.response_tokens = @candidate_token_count if @candidate_token_count
+ end
+
+ def streaming_decoder
+ @decoder ||= GeminiStreamingDecoder.new
end
def extract_prompt_for_tokenizer(prompt)
prompt.to_s
end
- def has_tool?(_response_data)
- @has_function_call
- end
-
- def native_tool_support?
- true
- end
-
- def add_to_function_buffer(function_buffer, payload: nil, partial: nil)
- if @streaming_mode
- return function_buffer if !partial
- else
- partial = payload
- end
-
- function_buffer.at("tool_name").content = partial[:name] if partial[:name].present?
-
- if partial[:args]
- argument_fragments =
- partial[:args].reduce(+"") do |memo, (arg_name, value)|
- memo << "\n<#{arg_name}>#{CGI.escapeHTML(value.to_s)}</#{arg_name}>"
- end
- argument_fragments << "\n"
-
- function_buffer.at("parameters").children =
- Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
- end
-
- function_buffer
+ def xml_tools_enabled?
+ false
end
end
end
diff --git a/lib/completions/endpoints/hugging_face.rb b/lib/completions/endpoints/hugging_face.rb
index bd7edc06..b0b14722 100644
--- a/lib/completions/endpoints/hugging_face.rb
+++ b/lib/completions/endpoints/hugging_face.rb
@@ -59,22 +59,30 @@ module DiscourseAi
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
end
- def extract_completion_from(response_raw)
- parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
- # half a line sent here
- return if !parsed
-
- response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
-
- response_h.dig(:content)
+ def xml_tools_enabled?
+ true
end
- def partials_from(decoded_chunk)
- decoded_chunk
- .split("\n")
- .map do |line|
- data = line.split("data:", 2)[1]
- data&.squish == "[DONE]" ? nil : data
+ def decode(response_raw)
+ parsed = JSON.parse(response_raw, symbolize_names: true)
+ text = parsed.dig(:choices, 0, :message, :content)
+ if text.to_s.empty?
+ [""]
+ else
+ [text]
+ end
+ end
+
+ def decode_chunk(chunk)
+ @json_decoder ||= JsonStreamDecoder.new
+ (@json_decoder << chunk)
+ .map do |parsed|
+ text = parsed.dig(:choices, 0, :delta, :content)
+ if text.to_s.empty?
+ nil
+ else
+ text
+ end
end
.compact
end
diff --git a/lib/completions/endpoints/ollama.rb b/lib/completions/endpoints/ollama.rb
index cc58006a..dd4ca2c7 100644
--- a/lib/completions/endpoints/ollama.rb
+++ b/lib/completions/endpoints/ollama.rb
@@ -37,12 +37,8 @@ module DiscourseAi
URI(llm_model.url)
end
- def native_tool_support?
- @native_tool_support
- end
-
- def has_tool?(_response_data)
- @has_function_call
+ def xml_tools_enabled?
+ !@native_tool_support
end
def prepare_payload(prompt, model_params, dialect)
@@ -67,74 +63,30 @@ module DiscourseAi
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
end
- def partials_from(decoded_chunk)
- decoded_chunk.split("\n").compact
+ def decode_chunk(chunk)
+ # Native tool calls are not working right in streaming mode, use XML
+ @json_decoder ||= JsonStreamDecoder.new(line_regex: /^\s*({.*})$/)
+ (@json_decoder << chunk).map { |parsed| parsed.dig(:message, :content) }.compact
end
- def extract_completion_from(response_raw)
+ def decode(response_raw)
+ rval = []
parsed = JSON.parse(response_raw, symbolize_names: true)
- return if !parsed
+ content = parsed.dig(:message, :content)
+ rval << content if !content.to_s.empty?
- response_h = parsed.dig(:message)
-
- @has_function_call ||= response_h.dig(:tool_calls).present?
- @has_function_call ? response_h.dig(:tool_calls, 0) : response_h.dig(:content)
- end
-
- def add_to_function_buffer(function_buffer, payload: nil, partial: nil)
- @args_buffer ||= +""
-
- if @streaming_mode
- return function_buffer if !partial
- else
- partial = payload
- end
-
- f_name = partial.dig(:function, :name)
-
- @current_function ||= function_buffer.at("invoke")
-
- if f_name
- current_name = function_buffer.at("tool_name").content
-
- if current_name.blank?
- # first call
- else
- # we have a previous function, so we need to add a noop
- @args_buffer = +""
- @current_function =
- function_buffer.at("function_calls").add_child(
- Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
- )
+ idx = -1
+ parsed
+ .dig(:message, :tool_calls)
+ &.each do |tool_call|
+ idx += 1
+ id = "tool_#{idx}"
+ name = tool_call.dig(:function, :name)
+ args = tool_call.dig(:function, :arguments)
+ rval << ToolCall.new(id: id, name: name, parameters: args)
end
- end
- @current_function.at("tool_name").content = f_name if f_name
- @current_function.at("tool_id").content = partial[:id] if partial[:id]
-
- args = partial.dig(:function, :arguments)
-
- # allow for SPACE within arguments
- if args && args != ""
- @args_buffer << args.to_json
-
- begin
- json_args = JSON.parse(@args_buffer, symbolize_names: true)
-
- argument_fragments =
- json_args.reduce(+"") do |memo, (arg_name, value)|
- memo << "\n<#{arg_name}>#{value}</#{arg_name}>"
- end
- argument_fragments << "\n"
-
- @current_function.at("parameters").children =
- Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
- rescue JSON::ParserError
- return function_buffer
- end
- end
-
- function_buffer
+ rval
end
end
end
diff --git a/lib/completions/endpoints/open_ai.rb b/lib/completions/endpoints/open_ai.rb
index 92315ed5..a185a840 100644
--- a/lib/completions/endpoints/open_ai.rb
+++ b/lib/completions/endpoints/open_ai.rb
@@ -93,98 +93,34 @@ module DiscourseAi
end
def final_log_update(log)
- log.request_tokens = @prompt_tokens if @prompt_tokens
- log.response_tokens = @completion_tokens if @completion_tokens
+ log.request_tokens = processor.prompt_tokens if processor.prompt_tokens
+ log.response_tokens = processor.completion_tokens if processor.completion_tokens
end
- def extract_completion_from(response_raw)
- json = JSON.parse(response_raw, symbolize_names: true)
-
- if @streaming_mode
- @prompt_tokens ||= json.dig(:usage, :prompt_tokens)
- @completion_tokens ||= json.dig(:usage, :completion_tokens)
- end
-
- parsed = json.dig(:choices, 0)
- return if !parsed
-
- response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
- @has_function_call ||= response_h.dig(:tool_calls).present?
- @has_function_call ? response_h.dig(:tool_calls, 0) : response_h.dig(:content)
+ def decode(response_raw)
+ processor.process_message(JSON.parse(response_raw, symbolize_names: true))
end
- def partials_from(decoded_chunk)
- decoded_chunk
- .split("\n")
- .map do |line|
- data = line.split("data: ", 2)[1]
- data == "[DONE]" ? nil : data
- end
+ def decode_chunk(chunk)
+ @decoder ||= JsonStreamDecoder.new
+ (@decoder << chunk)
+ .map { |parsed_json| processor.process_streamed_message(parsed_json) }
+ .flatten
.compact
end
- def has_tool?(_response_data)
- @has_function_call
+ def decode_chunk_finish
+ @processor.finish
end
- def native_tool_support?
- true
+ def xml_tools_enabled?
+ false
end
- def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
- if @streaming_mode
- return function_buffer if !partial
- else
- partial = payload
- end
+ private
- @args_buffer ||= +""
-
- f_name = partial.dig(:function, :name)
-
- @current_function ||= function_buffer.at("invoke")
-
- if f_name
- current_name = function_buffer.at("tool_name").content
-
- if current_name.blank?
- # first call
- else
- # we have a previous function, so we need to add a noop
- @args_buffer = +""
- @current_function =
- function_buffer.at("function_calls").add_child(
- Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
- )
- end
- end
-
- @current_function.at("tool_name").content = f_name if f_name
- @current_function.at("tool_id").content = partial[:id] if partial[:id]
-
- args = partial.dig(:function, :arguments)
-
- # allow for SPACE within arguments
- if args && args != ""
- @args_buffer << args
-
- begin
- json_args = JSON.parse(@args_buffer, symbolize_names: true)
-
- argument_fragments =
- json_args.reduce(+"") do |memo, (arg_name, value)|
- memo << "\n<#{arg_name}>#{CGI.escapeHTML(value.to_s)}</#{arg_name}>"
- end
- argument_fragments << "\n"
-
- @current_function.at("parameters").children =
- Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
- rescue JSON::ParserError
- return function_buffer
- end
- end
-
- function_buffer
+ def processor
+ @processor ||= OpenAiMessageProcessor.new
end
end
end
diff --git a/lib/completions/endpoints/samba_nova.rb b/lib/completions/endpoints/samba_nova.rb
index ccb883cc..cc81e786 100644
--- a/lib/completions/endpoints/samba_nova.rb
+++ b/lib/completions/endpoints/samba_nova.rb
@@ -55,27 +55,31 @@ module DiscourseAi
log.response_tokens = @completion_tokens if @completion_tokens
end
- def extract_completion_from(response_raw)
- json = JSON.parse(response_raw, symbolize_names: true)
-
- if @streaming_mode
- @prompt_tokens ||= json.dig(:usage, :prompt_tokens)
- @completion_tokens ||= json.dig(:usage, :completion_tokens)
- end
-
- parsed = json.dig(:choices, 0)
- return if !parsed
-
- @streaming_mode ? parsed.dig(:delta, :content) : parsed.dig(:message, :content)
+ def xml_tools_enabled?
+ true
end
- def partials_from(decoded_chunk)
- decoded_chunk
- .split("\n")
- .map do |line|
- data = line.split("data: ", 2)[1]
- data == "[DONE]" ? nil : data
+ def decode(response_raw)
+ json = JSON.parse(response_raw, symbolize_names: true)
+ [json.dig(:choices, 0, :message, :content)]
+ end
+
+ def decode_chunk(chunk)
+ @json_decoder ||= JsonStreamDecoder.new
+ (@json_decoder << chunk)
+ .map do |json|
+ text = json.dig(:choices, 0, :delta, :content)
+
+ @prompt_tokens ||= json.dig(:usage, :prompt_tokens)
+ @completion_tokens ||= json.dig(:usage, :completion_tokens)
+
+ if !text.to_s.empty?
+ text
+ else
+ nil
+ end
end
+ .flatten
.compact
end
end
diff --git a/lib/completions/endpoints/vllm.rb b/lib/completions/endpoints/vllm.rb
index 57fcf051..6b371a09 100644
--- a/lib/completions/endpoints/vllm.rb
+++ b/lib/completions/endpoints/vllm.rb
@@ -42,7 +42,10 @@ module DiscourseAi
def prepare_payload(prompt, model_params, dialect)
payload = default_options.merge(model_params).merge(messages: prompt)
- payload[:stream] = true if @streaming_mode
+        if @streaming_mode
+          payload[:stream] = true
+          payload[:stream_options] = { include_usage: true }
+        end
payload
end
@@ -56,24 +59,42 @@ module DiscourseAi
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
end
- def partials_from(decoded_chunk)
- decoded_chunk
- .split("\n")
- .map do |line|
- data = line.split("data: ", 2)[1]
- data == "[DONE]" ? nil : data
- end
- .compact
+ def xml_tools_enabled?
+ true
end
- def extract_completion_from(response_raw)
- parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
- # half a line sent here
- return if !parsed
+ def final_log_update(log)
+ log.request_tokens = @prompt_tokens if @prompt_tokens
+ log.response_tokens = @completion_tokens if @completion_tokens
+ end
- response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
+ def decode(response_raw)
+ json = JSON.parse(response_raw, symbolize_names: true)
+ @prompt_tokens = json.dig(:usage, :prompt_tokens)
+ @completion_tokens = json.dig(:usage, :completion_tokens)
+ [json.dig(:choices, 0, :message, :content)]
+ end
- response_h.dig(:content)
+ def decode_chunk(chunk)
+ @json_decoder ||= JsonStreamDecoder.new
+ (@json_decoder << chunk)
+ .map do |parsed|
+ # vLLM keeps sending usage over and over again
+ prompt_tokens = parsed.dig(:usage, :prompt_tokens)
+ completion_tokens = parsed.dig(:usage, :completion_tokens)
+
+ @prompt_tokens = prompt_tokens if prompt_tokens
+
+ @completion_tokens = completion_tokens if completion_tokens
+
+ text = parsed.dig(:choices, 0, :delta, :content)
+ if text.to_s.empty?
+ nil
+ else
+ text
+ end
+ end
+ .compact
end
end
end
diff --git a/lib/completions/function_call_normalizer.rb b/lib/completions/function_call_normalizer.rb
deleted file mode 100644
index ef40809c..00000000
--- a/lib/completions/function_call_normalizer.rb
+++ /dev/null
@@ -1,113 +0,0 @@
-# frozen_string_literal: true
-
-class DiscourseAi::Completions::FunctionCallNormalizer
- attr_reader :done
-
- # blk is the block to call with filtered data
- def initialize(blk, cancel)
- @blk = blk
- @cancel = cancel
- @done = false
-
- @in_tool = false
-
- @buffer = +""
- @function_buffer = +""
- end
-
- def self.normalize(data)
- text = +""
- cancel = -> {}
- blk = ->(partial, _) { text << partial }
-
- normalizer = self.new(blk, cancel)
- normalizer << data
-
- [text, normalizer.function_calls]
- end
-
- def function_calls
- return nil if @function_buffer.blank?
-
- xml = Nokogiri::HTML5.fragment(@function_buffer)
- self.class.normalize_function_ids!(xml)
- last_invoke = xml.at("invoke:last")
- if last_invoke
- last_invoke.next_sibling.remove while last_invoke.next_sibling
- xml.at("invoke:last").add_next_sibling("\n") if !last_invoke.next_sibling
- end
- xml.at("function_calls").to_s.dup.force_encoding("UTF-8")
- end
-
- def <<(text)
- @buffer << text
-
- if !@in_tool
- # double check if we are clearly in a tool
- search_length = text.length + 20
- search_string = @buffer[-search_length..-1] || @buffer
-
- index = search_string.rindex("<function_calls>")
- @in_tool = !!index
- if @in_tool
- @function_buffer = @buffer[index..-1]
- text_index = text.rindex("<function_calls>")
- @blk.call(text[0..text_index - 1].strip, @cancel) if text_index && text_index > 0
- end
- else
- @function_buffer << text
- end
-
- if !@in_tool
- if maybe_has_tool?(@buffer)
- split_index = text.rindex("<").to_i - 1
- if split_index >= 0
- @function_buffer = text[split_index + 1..-1] || ""
- text = text[0..split_index] || ""
- else
- @function_buffer << text
- text = ""
- end
- else
- if @function_buffer.length > 0
- @blk.call(@function_buffer, @cancel)
- @function_buffer = +""
- end
- end
-
- @blk.call(text, @cancel) if text.length > 0
- else
- if text.include?("</function_calls>")
- @done = true
- @cancel.call
- end
- end
- end
-
- def self.normalize_function_ids!(function_buffer)
- function_buffer
- .css("invoke")
- .each_with_index do |invoke, index|
- if invoke.at("tool_id")
- invoke.at("tool_id").content = "tool_#{index}" if invoke.at("tool_id").content.blank?
- else
- invoke.add_child("<tool_id>tool_#{index}</tool_id>\n") if !invoke.at("tool_id")
- end
- end
- end
-
- private
-
- def maybe_has_tool?(text)
- # 16 is the length of <function_calls>
- substring = text[-16..-1] || text
- split = substring.split("<")
-
- if split.length > 1
- match = "<" + split.last
- "".start_with?(match)
- else
- substring.ends_with?("<")
- end
- end
-end
diff --git a/lib/completions/json_stream_decoder.rb b/lib/completions/json_stream_decoder.rb
new file mode 100644
index 00000000..e575a3b7
--- /dev/null
+++ b/lib/completions/json_stream_decoder.rb
@@ -0,0 +1,48 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+ module Completions
+ # will work for anthropic and open ai compatible
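+ #
+ # A rough usage sketch (behaviour inferred from the implementation below;
+ # incomplete lines are buffered until a full JSON object arrives):
+ #
+ #   decoder = JsonStreamDecoder.new
+ #   decoder << "data: {\"text\":"  # => [] (partial line, buffered)
+ #   decoder << " \"hi\"}\n"        # => [{ text: "hi" }]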
+ class JsonStreamDecoder
+ attr_reader :buffer
+
+ LINE_REGEX = /data: ({.*})\s*$/
+
+ def initialize(symbolize_keys: true, line_regex: LINE_REGEX)
+ @symbolize_keys = symbolize_keys
+ @buffer = +""
+ @line_regex = line_regex
+ end
+
+ def <<(raw)
+ @buffer << raw.to_s
+ rval = []
+
+ split = @buffer.scan(/.*\n?/)
+ split.pop if split.last.blank?
+
+ @buffer = +(split.pop.to_s)
+
+ split.each do |line|
+ matches = line.match(@line_regex)
+ next if !matches
+ rval << JSON.parse(matches[1], symbolize_names: @symbolize_keys)
+ end
+
+ if @buffer.present?
+ matches = @buffer.match(@line_regex)
+ if matches
+ begin
+ rval << JSON.parse(matches[1], symbolize_names: @symbolize_keys)
+ @buffer = +""
+ rescue JSON::ParserError
+ # maybe it is a partial line
+ end
+ end
+ end
+
+ rval
+ end
+ end
+ end
+end
diff --git a/lib/completions/open_ai_message_processor.rb b/lib/completions/open_ai_message_processor.rb
new file mode 100644
index 00000000..02369bec
--- /dev/null
+++ b/lib/completions/open_ai_message_processor.rb
@@ -0,0 +1,103 @@
+# frozen_string_literal: true
+module DiscourseAi::Completions
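+  # Translates OpenAI chat completion payloads (full messages or streamed
+  # deltas) into plain text and ToolCall objects; streamed tool arguments
+  # are buffered until the call is complete.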
+ class OpenAiMessageProcessor
+ attr_reader :prompt_tokens, :completion_tokens
+
+ def initialize
+ @tool = nil
+ @tool_arguments = +""
+ @prompt_tokens = nil
+ @completion_tokens = nil
+ end
+
+ def process_message(json)
+ result = []
+ tool_calls = json.dig(:choices, 0, :message, :tool_calls)
+
+ message = json.dig(:choices, 0, :message, :content)
+ result << message if message.present?
+
+ if tool_calls.present?
+ tool_calls.each do |tool_call|
+ id = tool_call.dig(:id)
+ name = tool_call.dig(:function, :name)
+ arguments = tool_call.dig(:function, :arguments)
+ parameters = arguments.present? ? JSON.parse(arguments, symbolize_names: true) : {}
+ result << ToolCall.new(id: id, name: name, parameters: parameters)
+ end
+ end
+
+ update_usage(json)
+
+ result
+ end
+
+ def process_streamed_message(json)
+ rval = nil
+
+ tool_calls = json.dig(:choices, 0, :delta, :tool_calls)
+ content = json.dig(:choices, 0, :delta, :content)
+
+ finished_tools = json.dig(:choices, 0, :finish_reason) || tool_calls == []
+
+ if tool_calls.present?
+ id = tool_calls.dig(0, :id)
+ name = tool_calls.dig(0, :function, :name)
+ arguments = tool_calls.dig(0, :function, :arguments)
+
+ # TODO: multiple tool support may require index
+ #index = tool_calls[0].dig(:index)
+
+ if id.present? && @tool && @tool.id != id
+ process_arguments
+ rval = @tool
+ @tool = nil
+ end
+
+ if id.present? && name.present?
+ @tool_arguments = +""
+ @tool = ToolCall.new(id: id, name: name)
+ end
+
+ @tool_arguments << arguments.to_s
+ elsif finished_tools && @tool
+ parsed_args = JSON.parse(@tool_arguments, symbolize_names: true)
+ @tool.parameters = parsed_args
+ rval = @tool
+ @tool = nil
+ elsif content.present?
+ rval = content
+ end
+
+ update_usage(json)
+
+ rval
+ end
+
+ def finish
+ rval = []
+ if @tool
+ process_arguments
+ rval << @tool
+ @tool = nil
+ end
+
+ rval
+ end
+
+ private
+
+ def process_arguments
+ if @tool_arguments.present?
+ parsed_args = JSON.parse(@tool_arguments, symbolize_names: true)
+ @tool.parameters = parsed_args
+ @tool_arguments = nil
+ end
+ end
+
+ def update_usage(json)
+ @prompt_tokens ||= json.dig(:usage, :prompt_tokens)
+ @completion_tokens ||= json.dig(:usage, :completion_tokens)
+ end
+ end
+end
diff --git a/lib/completions/tool_call.rb b/lib/completions/tool_call.rb
new file mode 100644
index 00000000..15be7b3f
--- /dev/null
+++ b/lib/completions/tool_call.rb
@@ -0,0 +1,29 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+ module Completions
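+    # Provider-agnostic representation of a single tool invocation. For
+    # example (mirroring the specs in this change):
+    #
+    #   ToolCall.new(id: "tool_0", name: "google", parameters: { query: "sydney weather today" })
+    #
+    # parameter keys are symbolized on assignment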
+ class ToolCall
+ attr_reader :id, :name, :parameters
+
+ def initialize(id:, name:, parameters: nil)
+ @id = id
+ @name = name
+ self.parameters = parameters if parameters
+ @parameters ||= {}
+ end
+
+ def parameters=(parameters)
+ raise ArgumentError, "parameters must be a hash" unless parameters.is_a?(Hash)
+ @parameters = parameters.symbolize_keys
+ end
+
+ def ==(other)
+ id == other.id && name == other.name && parameters == other.parameters
+ end
+
+ def to_s
+ "#{name} - #{id} (\n#{parameters.map(&:to_s).join("\n")}\n)"
+ end
+ end
+ end
+end
diff --git a/lib/completions/xml_tool_processor.rb b/lib/completions/xml_tool_processor.rb
new file mode 100644
index 00000000..1b42b333
--- /dev/null
+++ b/lib/completions/xml_tool_processor.rb
@@ -0,0 +1,124 @@
+# frozen_string_literal: true
+
+# This class can be used to process a stream of text that may contain XML tool
+# calls.
+# It will return either text or ToolCall objects.
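+#
+# A rough usage sketch (tag names follow the <function_calls> convention used
+# by the XML tool prompts; a tool_id is auto-assigned when missing):
+#
+#   processor = XmlToolProcessor.new
+#   processor << "Hello <function_calls>"                         # => ["Hello"]
+#   processor << "<invoke><tool_name>search</tool_name></invoke>" # => []
+#   processor.finish  # => [ToolCall with id: "tool_0", name: "search"]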
+
+module DiscourseAi
+ module Completions
+ class XmlToolProcessor
+ def initialize
+ @buffer = +""
+ @function_buffer = +""
+ @should_cancel = false
+ @in_tool = false
+ end
+
+ def <<(text)
+ @buffer << text
+ result = []
+
+ if !@in_tool
+ # double check if we are clearly in a tool
+ search_length = text.length + 20
+ search_string = @buffer[-search_length..-1] || @buffer
+
+ index = search_string.rindex("<function_calls>")
+ @in_tool = !!index
+ if @in_tool
+ @function_buffer = @buffer[index..-1]
+ text_index = text.rindex("<function_calls>")
+ result << text[0..text_index - 1].strip if text_index && text_index > 0
+ end
+ else
+ @function_buffer << text
+ end
+
+ if !@in_tool
+ if maybe_has_tool?(@buffer)
+ split_index = text.rindex("<").to_i - 1
+ if split_index >= 0
+ @function_buffer = text[split_index + 1..-1] || ""
+ text = text[0..split_index] || ""
+ else
+ @function_buffer << text
+ text = ""
+ end
+ else
+ if @function_buffer.length > 0
+ result << @function_buffer
+ @function_buffer = +""
+ end
+ end
+
+ result << text if text.length > 0
+ else
+ @should_cancel = true if text.include?("</function_calls>")
+ end
+
+ result
+ end
+
+ def finish
+ return [] if @function_buffer.blank?
+
+ xml = Nokogiri::HTML5.fragment(@function_buffer)
+ normalize_function_ids!(xml)
+ last_invoke = xml.at("invoke:last")
+ if last_invoke
+ last_invoke.next_sibling.remove while last_invoke.next_sibling
+ xml.at("invoke:last").add_next_sibling("\n") if !last_invoke.next_sibling
+ end
+
+ xml
+ .css("invoke")
+ .map do |invoke|
+ tool_name = invoke.at("tool_name").content.force_encoding("UTF-8")
+ tool_id = invoke.at("tool_id").content.force_encoding("UTF-8")
+ parameters = {}
+ invoke
+ .at("parameters")
+ &.children
+ &.each do |node|
+ next if node.text?
+ name = node.name
+ value = node.content.to_s
+ parameters[name.to_sym] = value.to_s.force_encoding("UTF-8")
+ end
+ ToolCall.new(id: tool_id, name: tool_name, parameters: parameters)
+ end
+ end
+
+ def should_cancel?
+ @should_cancel
+ end
+
+ private
+
+ def normalize_function_ids!(function_buffer)
+ function_buffer
+ .css("invoke")
+ .each_with_index do |invoke, index|
+ if invoke.at("tool_id")
+ invoke.at("tool_id").content = "tool_#{index}" if invoke.at("tool_id").content.blank?
+ else
+ invoke.add_child("<tool_id>tool_#{index}</tool_id>\n") if !invoke.at("tool_id")
+ end
+ end
+ end
+
+ def maybe_has_tool?(text)
+ # 16 is the length of <function_calls>
+ substring = text[-16..-1] || text
+ split = substring.split("<")
+
+ if split.length > 1
+ match = "<" + split.last
+ "".start_with?(match)
+ else
+ substring.ends_with?("<")
+ end
+ end
+ end
+ end
+end
diff --git a/spec/lib/completions/endpoints/anthropic_spec.rb b/spec/lib/completions/endpoints/anthropic_spec.rb
index 94d5d655..40eca30f 100644
--- a/spec/lib/completions/endpoints/anthropic_spec.rb
+++ b/spec/lib/completions/endpoints/anthropic_spec.rb
@@ -104,7 +104,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
data: {"type":"message_stop"}
STRING
- result = +""
+ result = []
body = body.scan(/.*\n/)
EndpointMock.with_chunk_array_support do
stub_request(:post, url).to_return(status: 200, body: body)
@@ -114,18 +114,17 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
end
end
- expected = (<<~TEXT).strip
- <function_calls>
- <invoke>
- <tool_name>search</tool_name>
- <parameters><search_query>s&lt;a&gt;m sam</search_query>
- <category>general</category></parameters>
- <tool_id>toolu_01DjrShFRRHp9SnHYRFRc53F</tool_id>
- </invoke>
- </function_calls>
- TEXT
+ tool_call =
+ DiscourseAi::Completions::ToolCall.new(
+ name: "search",
+ id: "toolu_01DjrShFRRHp9SnHYRFRc53F",
+ parameters: {
+ search_query: "s<a>m sam",
+ category: "general",
+ },
+ )
- expect(result.strip).to eq(expected)
+ expect(result).to eq([tool_call])
end
it "can stream a response" do
@@ -191,6 +190,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
expect(log.feature_name).to eq("testing")
expect(log.response_tokens).to eq(15)
expect(log.request_tokens).to eq(25)
+ expect(log.raw_request_payload).to eq(expected_body.to_json)
+ expect(log.raw_response_payload.strip).to eq(body.strip)
end
it "supports non streaming tool calls" do
@@ -242,17 +243,20 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
result = llm.generate(prompt, user: Discourse.system_user)
- expected = <<~TEXT.strip
- <function_calls>
- <invoke>
- <tool_name>calculate</tool_name>
- <parameters><expression>2758975 + 21.11</expression></parameters>
- <tool_id>toolu_012kBdhG4eHaV68W56p4N94h</tool_id>
- </invoke>
- </function_calls>
- TEXT
+ tool_call =
+ DiscourseAi::Completions::ToolCall.new(
+ name: "calculate",
+ id: "toolu_012kBdhG4eHaV68W56p4N94h",
+ parameters: {
+ expression: "2758975 + 21.11",
+ },
+ )
- expect(result.strip).to eq(expected)
+ expect(result).to eq(["Here is the calculation:", tool_call])
+
+ log = AiApiAuditLog.order(:id).last
+ expect(log.request_tokens).to eq(345)
+ expect(log.response_tokens).to eq(65)
end
it "can send images via a completion prompt" do
diff --git a/spec/lib/completions/endpoints/aws_bedrock_spec.rb b/spec/lib/completions/endpoints/aws_bedrock_spec.rb
index d9519344..2a9cc77f 100644
--- a/spec/lib/completions/endpoints/aws_bedrock_spec.rb
+++ b/spec/lib/completions/endpoints/aws_bedrock_spec.rb
@@ -79,7 +79,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
}
prompt.tools = [tool]
- response = +""
+ response = []
proxy.generate(prompt, user: user) { |partial| response << partial }
expect(request.headers["Authorization"]).to be_present
@@ -90,21 +90,18 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
expect(parsed_body["tools"]).to eq(nil)
expect(parsed_body["stop_sequences"]).to eq([""])
- # note we now have a tool_id cause we were normalized
- function_call = <<~XML.strip
- hello
+ expected = [
+ "hello\n",
+ DiscourseAi::Completions::ToolCall.new(
+ id: "tool_0",
+ name: "google",
+ parameters: {
+ query: "sydney weather today",
+ },
+ ),
+ ]
-
- <function_calls>
- <invoke>
- <tool_name>google</tool_name>
- <parameters><query>sydney weather today</query></parameters>
- <tool_id>tool_0</tool_id>
- </invoke>
- </function_calls>
- XML
-
- expect(response.strip).to eq(function_call)
+ expect(response).to eq(expected)
end
end
@@ -230,23 +227,23 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
}
prompt.tools = [tool]
- response = +""
+ response = []
proxy.generate(prompt, user: user) { |partial| response << partial }
expect(request.headers["Authorization"]).to be_present
expect(request.headers["X-Amz-Content-Sha256"]).to be_present
- expected_response = (<<~RESPONSE).strip
- <function_calls>
- <invoke>
- <tool_name>google</tool_name>
- <parameters><query>sydney weather today</query></parameters>
- <tool_id>toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7</tool_id>
- </invoke>
- </function_calls>
- RESPONSE
+ expected_response = [
+ DiscourseAi::Completions::ToolCall.new(
+ id: "toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7",
+ name: "google",
+ parameters: {
+ query: "sydney weather today",
+ },
+ ),
+ ]
- expect(response.strip).to eq(expected_response)
+ expect(response).to eq(expected_response)
expected = {
"max_tokens" => 3000,
diff --git a/spec/lib/completions/endpoints/cohere_spec.rb b/spec/lib/completions/endpoints/cohere_spec.rb
index 4bb213ff..bdff8fc3 100644
--- a/spec/lib/completions/endpoints/cohere_spec.rb
+++ b/spec/lib/completions/endpoints/cohere_spec.rb
@@ -66,7 +66,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
TEXT
parsed_body = nil
- result = +""
+ result = []
sig = {
name: "google",
@@ -91,21 +91,20 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
},
).to_return(status: 200, body: body.split("|"))
- result = llm.generate(prompt, user: user) { |partial, cancel| result << partial }
+ llm.generate(prompt, user: user) { |partial, cancel| result << partial }
end
- expected = <<~TEXT
- <function_calls>
- <invoke>
- <tool_name>google</tool_name>
- <parameters><query>who is sam saffron</query>
- </parameters>
- <tool_id>tool_0</tool_id>
- </invoke>
- </function_calls>
- TEXT
+ text = "I will search for 'who is sam saffron' and relay the information to the user."
+ tool_call =
+ DiscourseAi::Completions::ToolCall.new(
+ id: "tool_0",
+ name: "google",
+ parameters: {
+ query: "who is sam saffron",
+ },
+ )
- expect(result.strip).to eq(expected.strip)
+ expect(result).to eq([text, tool_call])
expected = {
model: "command-r-plus",
diff --git a/spec/lib/completions/endpoints/endpoint_compliance.rb b/spec/lib/completions/endpoints/endpoint_compliance.rb
index 372c529b..130c735b 100644
--- a/spec/lib/completions/endpoints/endpoint_compliance.rb
+++ b/spec/lib/completions/endpoints/endpoint_compliance.rb
@@ -62,18 +62,14 @@ class EndpointMock
end
def invocation_response
- <<~TEXT
- <function_calls>
- <invoke>
- <tool_name>get_weather</tool_name>
- <parameters>
- <location>Sydney</location>
- <unit>c</unit>
- </parameters>
- <tool_id>tool_0</tool_id>
- </invoke>
- </function_calls>
- TEXT
+ DiscourseAi::Completions::ToolCall.new(
+ id: "tool_0",
+ name: "get_weather",
+ parameters: {
+ location: "Sydney",
+ unit: "c",
+ },
+ )
end
def tool_id
@@ -185,7 +181,7 @@ class EndpointsCompliance
mock.stub_tool_call(a_dialect.translate)
completion_response = endpoint.perform_completion!(a_dialect, user)
- expect(completion_response.strip).to eq(mock.invocation_response.strip)
+ expect(completion_response).to eq(mock.invocation_response)
end
def streaming_mode_simple_prompt(mock)
@@ -205,6 +201,7 @@ class EndpointsCompliance
expect(log.raw_request_payload).to be_present
expect(log.raw_response_payload).to be_present
expect(log.request_tokens).to eq(endpoint.prompt_size(dialect.translate))
+
expect(log.response_tokens).to eq(
endpoint.llm_model.tokenizer_class.size(mock.streamed_simple_deltas[0...-1].join),
)
@@ -216,14 +213,14 @@ class EndpointsCompliance
a_dialect = dialect(prompt: prompt)
mock.stub_streamed_tool_call(a_dialect.translate) do
- buffered_partial = +""
+ buffered_partial = []
endpoint.perform_completion!(a_dialect, user) do |partial, cancel|
buffered_partial << partial
- cancel.call if buffered_partial.include?("<function_calls>")
+ cancel.call if partial.is_a?(DiscourseAi::Completions::ToolCall)
end
- expect(buffered_partial.strip).to eq(mock.invocation_response.strip)
+ expect(buffered_partial).to eq([mock.invocation_response])
end
end
diff --git a/spec/lib/completions/endpoints/gemini_spec.rb b/spec/lib/completions/endpoints/gemini_spec.rb
index 2f602d3a..18933843 100644
--- a/spec/lib/completions/endpoints/gemini_spec.rb
+++ b/spec/lib/completions/endpoints/gemini_spec.rb
@@ -195,19 +195,16 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
response = llm.generate(prompt, user: user)
- expected = (<<~XML).strip
- <function_calls>
- <invoke>
- <tool_name>echo</tool_name>
- <parameters>
- <text>&lt;S&gt;ydney</text>
- </parameters>
- <tool_id>tool_0</tool_id>
- </invoke>
- </function_calls>
- XML
+ tool =
+ DiscourseAi::Completions::ToolCall.new(
+ id: "tool_0",
+ name: "echo",
+ parameters: {
+ text: "ydney",
+ },
+ )
- expect(response.strip).to eq(expected)
+ expect(response).to eq(tool)
end
it "Supports Vision API" do
@@ -265,6 +262,68 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
expect(JSON.parse(req_body)).to eq(expected_prompt)
end
+ it "Can stream tool calls correctly" do
+ rows = [
+ {
+ candidates: [
+ {
+ content: {
+ parts: [{ functionCall: { name: "echo", args: { text: "sam<>wh!s" } } }],
+ role: "model",
+ },
+ safetyRatings: [
+ { category: "HARM_CATEGORY_HATE_SPEECH", probability: "NEGLIGIBLE" },
+ { category: "HARM_CATEGORY_DANGEROUS_CONTENT", probability: "NEGLIGIBLE" },
+ { category: "HARM_CATEGORY_HARASSMENT", probability: "NEGLIGIBLE" },
+ { category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", probability: "NEGLIGIBLE" },
+ ],
+ },
+ ],
+ usageMetadata: {
+ promptTokenCount: 625,
+ totalTokenCount: 625,
+ },
+ modelVersion: "gemini-1.5-pro-002",
+ },
+ {
+ candidates: [{ content: { parts: [{ text: "" }], role: "model" }, finishReason: "STOP" }],
+ usageMetadata: {
+ promptTokenCount: 625,
+ candidatesTokenCount: 4,
+ totalTokenCount: 629,
+ },
+ modelVersion: "gemini-1.5-pro-002",
+ },
+ ]
+
+ payload = rows.map { |r| "data: #{r.to_json}\n\n" }.join
+
+ llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
+ url = "#{model.url}:streamGenerateContent?alt=sse&key=123"
+
+ prompt = DiscourseAi::Completions::Prompt.new("Hello", tools: [echo_tool])
+
+ output = []
+
+ stub_request(:post, url).to_return(status: 200, body: payload)
+ llm.generate(prompt, user: user) { |partial| output << partial }
+
+ tool_call =
+ DiscourseAi::Completions::ToolCall.new(
+ id: "tool_0",
+ name: "echo",
+ parameters: {
+ text: "sam<>wh!s",
+ },
+ )
+
+ expect(output).to eq([tool_call])
+
+ log = AiApiAuditLog.order(:id).last
+ expect(log.request_tokens).to eq(625)
+ expect(log.response_tokens).to eq(4)
+ end
+
it "Can correctly handle streamed responses even if they are chunked badly" do
data = +""
data << "da|ta: |"
@@ -279,12 +338,12 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
url = "#{model.url}:streamGenerateContent?alt=sse&key=123"
- output = +""
+ output = []
gemini_mock.with_chunk_array_support do
stub_request(:post, url).to_return(status: 200, body: split)
llm.generate("Hello", user: user) { |partial| output << partial }
end
- expect(output).to eq("Hello World Sam")
+ expect(output.join).to eq("Hello World Sam")
end
end
diff --git a/spec/lib/completions/endpoints/ollama_spec.rb b/spec/lib/completions/endpoints/ollama_spec.rb
index eb6bc63c..4f458283 100644
--- a/spec/lib/completions/endpoints/ollama_spec.rb
+++ b/spec/lib/completions/endpoints/ollama_spec.rb
@@ -150,7 +150,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Ollama do
end
describe "when using streaming mode" do
- context "with simpel prompts" do
+ context "with simple prompts" do
it "completes a trivial prompt and logs the response" do
compliance.streaming_mode_simple_prompt(ollama_mock)
end
diff --git a/spec/lib/completions/endpoints/open_ai_spec.rb b/spec/lib/completions/endpoints/open_ai_spec.rb
index 60df1d67..c4d7758a 100644
--- a/spec/lib/completions/endpoints/open_ai_spec.rb
+++ b/spec/lib/completions/endpoints/open_ai_spec.rb
@@ -17,8 +17,8 @@ class OpenAiMock < EndpointMock
created: 1_678_464_820,
model: "gpt-3.5-turbo-0301",
usage: {
- prompt_tokens: 337,
- completion_tokens: 162,
+ prompt_tokens: 8,
+ completion_tokens: 13,
total_tokens: 499,
},
choices: [
@@ -231,19 +231,16 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
result = llm.generate(prompt, user: user)
- expected = (<<~TXT).strip
- <function_calls>
- <invoke>
- <tool_name>echo</tool_name>
- <parameters>
- <text>hello</text>
- </parameters>
- <tool_id>call_I8LKnoijVuhKOM85nnEQgWwd</tool_id>
- </invoke>
- </function_calls>
- TXT
+ tool_call =
+ DiscourseAi::Completions::ToolCall.new(
+ id: "call_I8LKnoijVuhKOM85nnEQgWwd",
+ name: "echo",
+ parameters: {
+ text: "hello",
+ },
+ )
- expect(result.strip).to eq(expected)
+ expect(result).to eq(tool_call)
stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
body: { choices: [message: { content: "OK" }] }.to_json,
@@ -320,19 +317,20 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
expect(body_json[:tool_choice]).to eq({ type: "function", function: { name: "echo" } })
- expected = (<<~TXT).strip
- <function_calls>
- <invoke>
- <tool_name>echo</tool_name>
- <parameters>
- <text>h&lt;e&gt;llo</text>
- </parameters>
- <tool_id>call_I8LKnoijVuhKOM85nnEQgWwd</tool_id>
- </invoke>
- </function_calls>
- TXT
+ log = AiApiAuditLog.order(:id).last
+ expect(log.request_tokens).to eq(55)
+ expect(log.response_tokens).to eq(13)
- expect(result.strip).to eq(expected)
+ expected =
+ DiscourseAi::Completions::ToolCall.new(
+ id: "call_I8LKnoijVuhKOM85nnEQgWwd",
+ name: "echo",
+ parameters: {
+ text: "hllo",
+ },
+ )
+
+ expect(result).to eq(expected)
stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
body: { choices: [message: { content: "OK" }] }.to_json,
@@ -487,7 +485,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"e AI "}}]},"logprobs":null,"finish_reason":null}]}
- data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"bot\\"}"}}]},"logprobs":null,"finish_reason":null}]}
+ data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"bot2\\"}"}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}
@@ -495,32 +493,30 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
TEXT
open_ai_mock.stub_raw(raw_data)
- content = +""
+ response = []
dialect = compliance.dialect(prompt: compliance.generic_prompt(tools: tools))
- endpoint.perform_completion!(dialect, user) { |partial| content << partial }
+ endpoint.perform_completion!(dialect, user) { |partial| response << partial }
- expected = <<~TEXT
- <function_calls>
- <invoke>
- <tool_name>search</tool_name>
- <parameters>
- <search_query>Discourse AI bot</search_query>
- </parameters>
- <tool_id>call_3Gyr3HylFJwfrtKrL6NaIit1</tool_id>
- </invoke>
- <invoke>
- <tool_name>search</tool_name>
- <parameters>
- <query>Discourse AI bot</query>
- </parameters>
- <tool_id>call_H7YkbgYurHpyJqzwUN4bghwN</tool_id>
- </invoke>
- </function_calls>
- TEXT
+ tool_calls = [
+ DiscourseAi::Completions::ToolCall.new(
+ name: "search",
+ id: "call_3Gyr3HylFJwfrtKrL6NaIit1",
+ parameters: {
+ search_query: "Discourse AI bot",
+ },
+ ),
+ DiscourseAi::Completions::ToolCall.new(
+ name: "search",
+ id: "call_H7YkbgYurHpyJqzwUN4bghwN",
+ parameters: {
+ query: "Discourse AI bot2",
+ },
+ ),
+ ]
- expect(content).to eq(expected)
+ expect(response).to eq(tool_calls)
end
it "uses proper token accounting" do
@@ -593,21 +589,16 @@ TEXT
dialect = compliance.dialect(prompt: compliance.generic_prompt(tools: tools))
endpoint.perform_completion!(dialect, user) { |partial| partials << partial }
- expect(partials.length).to eq(1)
+ tool_call =
+ DiscourseAi::Completions::ToolCall.new(
+ id: "func_id",
+ name: "google",
+ parameters: {
+ query: "Adabas 9.1",
+ },
+ )
- function_call = (<<~TXT).strip
- <function_calls>
- <invoke>
- <tool_name>google</tool_name>
- <parameters>
- <query>Adabas 9.1</query>
- </parameters>
- <tool_id>func_id</tool_id>
- </invoke>
- </function_calls>
- TXT
-
- expect(partials[0].strip).to eq(function_call)
+ expect(partials).to eq([tool_call])
end
end
end
diff --git a/spec/lib/completions/endpoints/samba_nova_spec.rb b/spec/lib/completions/endpoints/samba_nova_spec.rb
index 0f1f68ac..83839bf4 100644
--- a/spec/lib/completions/endpoints/samba_nova_spec.rb
+++ b/spec/lib/completions/endpoints/samba_nova_spec.rb
@@ -22,10 +22,15 @@ data: [DONE]
},
).to_return(status: 200, body: body, headers: {})
- response = +""
+ response = []
llm.generate("who are you?", user: Discourse.system_user) { |partial| response << partial }
- expect(response).to eq("I am a bot")
+ expect(response).to eq(["I am a bot"])
+
+ log = AiApiAuditLog.order(:id).last
+
+ expect(log.request_tokens).to eq(21)
+ expect(log.response_tokens).to eq(41)
end
it "can perform regular completions" do
diff --git a/spec/lib/completions/endpoints/vllm_spec.rb b/spec/lib/completions/endpoints/vllm_spec.rb
index 6f5387c0..824bcbe0 100644
--- a/spec/lib/completions/endpoints/vllm_spec.rb
+++ b/spec/lib/completions/endpoints/vllm_spec.rb
@@ -51,7 +51,13 @@ class VllmMock < EndpointMock
WebMock
.stub_request(:post, "https://test.dev/v1/chat/completions")
- .with(body: model.default_options.merge(messages: prompt, stream: true).to_json)
+ .with(
+ body:
+ model
+ .default_options
+ .merge(messages: prompt, stream: true, stream_options: { include_usage: true })
+ .to_json,
+ )
.to_return(status: 200, body: chunks)
end
end
@@ -136,29 +142,115 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
result = llm.generate(prompt, user: Discourse.system_user)
- expected = <<~TEXT
- <function_calls>
- <invoke>
- <tool_name>calculate</tool_name>
- <parameters>
- <expression>1+1</expression></parameters>
- <tool_id>tool_0</tool_id>
- </invoke>
- </function_calls>
- TEXT
+ expected =
+ DiscourseAi::Completions::ToolCall.new(
+ name: "calculate",
+ id: "tool_0",
+ parameters: {
+ expression: "1+1",
+ },
+ )
- expect(result.strip).to eq(expected.strip)
+ expect(result).to eq(expected)
end
end
+ it "correctly accounts for tokens in non streaming mode" do
+ body = (<<~TEXT).strip
+ {"id":"chat-c580e4a9ebaa44a0becc802ed5dc213a","object":"chat.completion","created":1731294404,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"Random Number Generator Produces Smallest Possible Result","tool_calls":[]},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":146,"total_tokens":156,"completion_tokens":10},"prompt_logprobs":null}
+ TEXT
+
+ stub_request(:post, "https://test.dev/v1/chat/completions").to_return(status: 200, body: body)
+
+ result = llm.generate("generate a title", user: Discourse.system_user)
+
+ expect(result).to eq("Random Number Generator Produces Smallest Possible Result")
+
+ log = AiApiAuditLog.order(:id).last
+ expect(log.request_tokens).to eq(146)
+ expect(log.response_tokens).to eq(10)
+ end
+
+ it "can properly include usage in streaming mode" do
+ payload = <<~TEXT.strip
+ data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":46,"completion_tokens":0}}
+
+ data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":47,"completion_tokens":1}}
+
+ data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" Sam"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":48,"completion_tokens":2}}
+
+ data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"."},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":49,"completion_tokens":3}}
+
+ data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" It"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":50,"completion_tokens":4}}
+
+ data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"'s"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":51,"completion_tokens":5}}
+
+ data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" nice"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":52,"completion_tokens":6}}
+
+ data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" to"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":53,"completion_tokens":7}}
+
+ data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" meet"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":54,"completion_tokens":8}}
+
+ data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":55,"completion_tokens":9}}
+
+ data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"."},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":56,"completion_tokens":10}}
+
+ data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" Is"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":57,"completion_tokens":11}}
+
+ data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" there"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":58,"completion_tokens":12}}
+
+ data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" something"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":59,"completion_tokens":13}}
+
+ data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":60,"completion_tokens":14}}
+
+ data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" can"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":61,"completion_tokens":15}}
+
+ data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" help"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":62,"completion_tokens":16}}
+
+ data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":63,"completion_tokens":17}}
+
+ data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" with"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":64,"completion_tokens":18}}
+
+ data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" or"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":65,"completion_tokens":19}}
+
+ data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" would"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":66,"completion_tokens":20}}
+
+ data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":67,"completion_tokens":21}}
+
+ data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" like"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":68,"completion_tokens":22}}
+
+ data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" to"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":69,"completion_tokens":23}}
+
+ data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" chat"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":70,"completion_tokens":24}}
+
+ data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":71,"completion_tokens":25}}
+
+ data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":""},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":72,"completion_tokens":26}}
+
+ data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[],"usage":{"prompt_tokens":46,"total_tokens":72,"completion_tokens":26}}
+
+ data: [DONE]
+ TEXT
+
+ stub_request(:post, "https://test.dev/v1/chat/completions").to_return(
+ status: 200,
+ body: payload,
+ )
+
+ response = []
+ llm.generate("say hello", user: Discourse.system_user) { |partial| response << partial }
+
+ expect(response.join).to eq(
+ "Hello Sam. It's nice to meet you. Is there something I can help you with or would you like to chat?",
+ )
+
+ log = AiApiAuditLog.order(:id).last
+ expect(log.request_tokens).to eq(46)
+ expect(log.response_tokens).to eq(26)
+ end
+
describe "#perform_completion!" do
context "when using regular mode" do
- context "with simple prompts" do
- it "completes a trivial prompt and logs the response" do
- compliance.regular_mode_simple_prompt(vllm_mock)
- end
- end
-
context "with tools" do
it "returns a function invocation" do
compliance.regular_mode_tools(vllm_mock)
diff --git a/spec/lib/completions/function_call_normalizer_spec.rb b/spec/lib/completions/function_call_normalizer_spec.rb
deleted file mode 100644
index dd78ed7f..00000000
--- a/spec/lib/completions/function_call_normalizer_spec.rb
+++ /dev/null
@@ -1,182 +0,0 @@
-# frozen_string_literal: true
-
-RSpec.describe DiscourseAi::Completions::FunctionCallNormalizer do
- let(:buffer) { +"" }
-
- let(:normalizer) do
- blk = ->(data, cancel) { buffer << data }
- cancel = -> { @done = true }
- DiscourseAi::Completions::FunctionCallNormalizer.new(blk, cancel)
- end
-
- def pass_through!(data)
- normalizer << data
- expect(buffer[-data.length..-1]).to eq(data)
- end
-
- it "is usable in non streaming mode" do
- xml = (<<~XML).strip
- hello
- <function_calls>
- <invoke>
- <tool_name>hello</tool_name>
- </invoke>
- XML
-
- text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(xml)
-
- expect(text).to eq("hello")
-
- expected_function_calls = (<<~XML).strip
- <function_calls>
- <invoke>
- <tool_name>hello</tool_name>
- <tool_id>tool_0</tool_id>
- </invoke>
- </function_calls>
- XML
-
- expect(function_calls).to eq(expected_function_calls)
- end
-
- it "strips junk from end of function calls" do
- xml = (<<~XML).strip
- hello
- <function_calls>
- <invoke>
- <tool_name>hello</tool_name>
- </invoke>
- junk
- XML
-
- _text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(xml)
-
- expected_function_calls = (<<~XML).strip
- <function_calls>
- <invoke>
- <tool_name>hello</tool_name>
- <tool_id>tool_0</tool_id>
- </invoke>
- </function_calls>
- XML
-
- expect(function_calls).to eq(expected_function_calls)
- end
-
- it "returns nil for function calls if there are none" do
- input = "hello world\n"
- text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(input)
-
- expect(text).to eq(input)
- expect(function_calls).to eq(nil)
- end
-
- it "passes through data if there are no function calls detected" do
- pass_through!("hello")
- pass_through!("hello")
- pass_through!("world")
- pass_through!("")
- end
-
- it "properly handles non English tools" do
- normalizer << "hello\n"
-
- normalizer << (<<~XML).strip
- <function_calls>
- <tool_name>hello</tool_name>
- <parameters>
- <hello>世界</hello>
- </parameters>
- </function_calls>
- XML
-
- expected = (<<~XML).strip
- <function_calls>
- <invoke>
- <tool_name>hello</tool_name>
- <parameters>
- <hello>世界</hello>
- </parameters>
- <tool_id>tool_0</tool_id>
- </invoke>
- </function_calls>
- XML
-
- function_calls = normalizer.function_calls
- expect(function_calls).to eq(expected)
- end
-
- it "works correctly even if you only give it 1 letter at a time" do
- xml = (<<~XML).strip
- abc
- <function_calls>
- <invoke>
- <tool_name>hello</tool_name>
- <parameters>
- <hello>world</hello>
- </parameters>
- <tool_id>abc</tool_id>
- </invoke>
- <invoke>
- <tool_name>hello2</tool_name>
- <parameters>
- <hello>world</hello>
- </parameters>
- <tool_id>aba</tool_id>
- </invoke>
- </function_calls>
- XML
-
- xml.each_char { |char| normalizer << char }
-
- expect(buffer + normalizer.function_calls).to eq(xml)
- end
-
- it "supports multiple invokes" do
- xml = (<<~XML).strip
- <function_calls>
- <invoke>
- <tool_name>hello</tool_name>
- <parameters>
- <hello>world</hello>
- </parameters>
- <tool_id>abc</tool_id>
- </invoke>
- <invoke>
- <tool_name>hello2</tool_name>
- <parameters>
- <hello>world</hello>
- </parameters>
- <tool_id>aba</tool_id>
- </invoke>
- </function_calls>
- XML
-
- normalizer << xml
-
- expect(normalizer.function_calls).to eq(xml)
- end
-
- it "can will cancel if it encounteres " do
- normalizer << ""
- expect(normalizer.done).to eq(false)
- normalizer << ""
- expect(normalizer.done).to eq(true)
- expect(@done).to eq(true)
-
- expect(normalizer.function_calls).to eq("")
- end
-
- it "pauses on function call and starts buffering" do
- normalizer << "hello"
- expect(buffer).to eq("hello")
- expect(normalizer.done).to eq(false)
- end
-end
diff --git a/spec/lib/completions/json_stream_decoder_spec.rb b/spec/lib/completions/json_stream_decoder_spec.rb
new file mode 100644
index 00000000..831bad6f
--- /dev/null
+++ b/spec/lib/completions/json_stream_decoder_spec.rb
@@ -0,0 +1,47 @@
+# frozen_string_literal: true
+
+describe DiscourseAi::Completions::JsonStreamDecoder do
+ let(:decoder) { DiscourseAi::Completions::JsonStreamDecoder.new }
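+ # the decoder buffers partial SSE lines and only yields parsed hashes once a
+ # full "data: {...}" payload has arrived, so chunk boundaries can fall anywhere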
+
+ it "should be able to parse simple messages" do
+ result = decoder << "data: #{{ hello: "world" }.to_json}"
+ expect(result).to eq([{ hello: "world" }])
+ end
+
+ it "should handle anthropic mixed stlye streams" do
+ stream = (<<~TEXT).split("|")
+ event: |message_start|
+ data: |{"hel|lo": "world"}|
+
+ event: |message_start
+ data: {"foo": "bar"}
+
+ event: |message_start
+ data: {"ba|z": "qux"|}
+
+ [DONE]
+ TEXT
+
+ results = []
+ stream.each { |chunk| results << (decoder << chunk) }
+
+ expect(results.flatten.compact).to eq([{ hello: "world" }, { foo: "bar" }, { baz: "qux" }])
+ end
+
+ it "should be able to handle complex overlaps" do
+ stream = (<<~TEXT).split("|")
+ data: |{"hel|lo": "world"}
+
+ data: {"foo": "bar"}
+
+ data: {"ba|z": "qux"|}
+
+ [DONE]
+ TEXT
+
+ results = []
+ stream.each { |chunk| results << (decoder << chunk) }
+
+ expect(results.flatten.compact).to eq([{ hello: "world" }, { foo: "bar" }, { baz: "qux" }])
+ end
+end
diff --git a/spec/lib/completions/xml_tool_processor_spec.rb b/spec/lib/completions/xml_tool_processor_spec.rb
new file mode 100644
index 00000000..003f4356
--- /dev/null
+++ b/spec/lib/completions/xml_tool_processor_spec.rb
@@ -0,0 +1,188 @@
+# frozen_string_literal: true
+
+RSpec.describe DiscourseAi::Completions::XmlToolProcessor do
+ let(:processor) { DiscourseAi::Completions::XmlToolProcessor.new }
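+ # `processor << chunk` returns whatever is safe to emit so far: plain text
+ # segments and/or ToolCall objects; `finish` flushes the buffered remainder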
+
+ it "can process simple text" do
+ result = []
+ result << (processor << "hello")
+ result << (processor << " world ")
+ expect(result).to eq([["hello"], [" world "]])
+ expect(processor.finish).to eq([])
+ expect(processor.should_cancel?).to eq(false)
+ end
+
+ it "is usable for simple single message mode" do
+ xml = (<<~XML).strip
+ hello
+ <function_calls>
+ <invoke>
+ <tool_name>hello</tool_name>
+ <parameters>
+ <hello>world</hello>
+ <test>value</test>
+ </parameters>
+ </invoke>
+ XML
+
+ result = []
+ result << (processor << xml)
+ result << (processor.finish)
+
+ tool_call =
+ DiscourseAi::Completions::ToolCall.new(
+ id: "tool_0",
+ name: "hello",
+ parameters: {
+ hello: "world",
+ test: "value",
+ },
+ )
+ expect(result).to eq([["hello"], [tool_call]])
+ expect(processor.should_cancel?).to eq(false)
+ end
+
+ it "handles multiple tool calls in sequence" do
+ xml = (<<~XML).strip
+ start
+ <function_calls>
+ <invoke>
+ <tool_name>first_tool</tool_name>
+ <parameters>
+ <param1>value1</param1>
+ </parameters>
+ </invoke>
+ <invoke>
+ <tool_name>second_tool</tool_name>
+ <parameters>
+ <param2>value2</param2>
+ </parameters>
+ </invoke>
+ </function_calls>
+ end
+ XML
+
+ result = []
+ result << (processor << xml)
+ result << (processor.finish)
+
+ first_tool =
+ DiscourseAi::Completions::ToolCall.new(
+ id: "tool_0",
+ name: "first_tool",
+ parameters: {
+ param1: "value1",
+ },
+ )
+
+ second_tool =
+ DiscourseAi::Completions::ToolCall.new(
+ id: "tool_1",
+ name: "second_tool",
+ parameters: {
+ param2: "value2",
+ },
+ )
+
+ expect(result).to eq([["start"], [first_tool, second_tool]])
+ expect(processor.should_cancel?).to eq(true)
+ end
+
+ it "handles non-English parameters correctly" do
+ xml = (<<~XML).strip
+ こんにちは
+ <function_calls>
+ <invoke>
+ <tool_name>translator</tool_name>
+ <parameters>
+ <text>世界</text>
+ </parameters>
+ </invoke>
+ XML
+
+ result = []
+ result << (processor << xml)
+ result << (processor.finish)
+
+ tool_call =
+ DiscourseAi::Completions::ToolCall.new(
+ id: "tool_0",
+ name: "translator",
+ parameters: {
+ text: "世界",
+ },
+ )
+
+ expect(result).to eq([["こんにちは"], [tool_call]])
+ end
+
+ it "processes input character by character" do
+ xml =
+ "hitestv
"
+
+ result = []
+ xml.each_char { |char| result << (processor << char) }
+ result << processor.finish
+
+ tool_call =
+ DiscourseAi::Completions::ToolCall.new(id: "tool_0", name: "test", parameters: { p: "v" })
+
+ filtered_result = result.reject(&:empty?)
+ expect(filtered_result).to eq([["h"], ["i"], [tool_call]])
+ end
+
+ it "handles malformed XML gracefully" do
+ xml = (<<~XML).strip
+ text
+ <function_calls>
+ <invoke>
+ <tool_name>test</tool_name>
+ <parameters>
+ <param>value
+ </parameters>
+ </invoke>
+ malformed
+ XML
+
+ result = []
+ result << (processor << xml)
+ result << (processor.finish)
+
+ # Should just do its best to parse the XML
+ tool_call =
+ DiscourseAi::Completions::ToolCall.new(id: "tool_0", name: "test", parameters: { param: "" })
+ expect(result).to eq([["text"], [tool_call]])
+ end
+
+ it "correctly processes empty parameter sets" do
+ xml = (<<~XML).strip
+ hello
+ <function_calls>
+ <invoke>
+ <tool_name>no_params</tool_name>
+ <parameters>
+ </parameters>
+ </invoke>
+ XML
+
+ result = []
+ result << (processor << xml)
+ result << (processor.finish)
+
+ tool_call =
+ DiscourseAi::Completions::ToolCall.new(id: "tool_0", name: "no_params", parameters: {})
+
+ expect(result).to eq([["hello"], [tool_call]])
+ end
+
+ it "properly handles cancelled processing" do
+ xml = "start"
+ result = []
+ result << (processor << xml)
+ result << (processor << "more text")
+ result << processor.finish
+
+ expect(result).to eq([["start"], [], []])
+ expect(processor.should_cancel?).to eq(true)
+ end
+end
diff --git a/spec/lib/modules/ai_bot/personas/persona_spec.rb b/spec/lib/modules/ai_bot/personas/persona_spec.rb
index 2fb95d19..5271374d 100644
--- a/spec/lib/modules/ai_bot/personas/persona_spec.rb
+++ b/spec/lib/modules/ai_bot/personas/persona_spec.rb
@@ -72,40 +72,27 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
it "can parse string that are wrapped in quotes" do
SiteSetting.ai_stability_api_key = "123"
- xml = <<~XML
- <function_calls>
- <invoke>
- <tool_name>image</tool_name>
- <tool_id>call_JtYQMful5QKqw97XFsHzPweB</tool_id>
- <parameters>
- <prompts>["cat oil painting", "big car"]</prompts>
- <aspect_ratio>"16:9"</aspect_ratio>
- </parameters>
- </invoke>
- <invoke>
- <tool_name>image</tool_name>
- <tool_id>call_JtYQMful5QKqw97XFsHzPweB</tool_id>
- <parameters>
- <prompts>["cat oil painting", "big car"]</prompts>
- <aspect_ratio>'16:9'</aspect_ratio>
- </parameters>
- </invoke>
- </function_calls>
- XML
- image1, image2 =
- tools =
- DiscourseAi::AiBot::Personas::Artist.new.find_tools(
- xml,
- bot_user: nil,
- llm: nil,
- context: nil,
- )
- expect(image1.parameters[:prompts]).to eq(["cat oil painting", "big car"])
- expect(image1.parameters[:aspect_ratio]).to eq("16:9")
- expect(image2.parameters[:aspect_ratio]).to eq("16:9")
+ tool_call =
+ DiscourseAi::Completions::ToolCall.new(
+ name: "image",
+ id: "call_JtYQMful5QKqw97XFsHzPweB",
+ parameters: {
+ prompts: ["cat oil painting", "big car"],
+ aspect_ratio: "16:9",
+ },
+ )
- expect(tools.length).to eq(2)
+ tool_instance =
+ DiscourseAi::AiBot::Personas::Artist.new.find_tool(
+ tool_call,
+ bot_user: nil,
+ llm: nil,
+ context: nil,
+ )
+
+ expect(tool_instance.parameters[:prompts]).to eq(["cat oil painting", "big car"])
+ expect(tool_instance.parameters[:aspect_ratio]).to eq("16:9")
end
it "enforces enums" do
@@ -132,42 +119,68 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
XML
- search1, search2 =
- tools =
- DiscourseAi::AiBot::Personas::General.new.find_tools(
- xml,
- bot_user: nil,
- llm: nil,
- context: nil,
- )
+ tool_call =
+ DiscourseAi::Completions::ToolCall.new(
+ name: "search",
+ id: "call_JtYQMful5QKqw97XFsHzPweB",
+ parameters: {
+ max_posts: "3.2",
+ status: "cow",
+ foo: "bar",
+ },
+ )
- expect(search1.parameters.key?(:status)).to eq(false)
- expect(search2.parameters[:status]).to eq("open")
+ tool_instance =
+ DiscourseAi::AiBot::Personas::General.new.find_tool(
+ tool_call,
+ bot_user: nil,
+ llm: nil,
+ context: nil,
+ )
+
+ expect(tool_instance.parameters.key?(:status)).to eq(false)
+
+ tool_call =
+ DiscourseAi::Completions::ToolCall.new(
+ name: "search",
+ id: "call_JtYQMful5QKqw97XFsHzPweB",
+ parameters: {
+ max_posts: "3.2",
+ status: "open",
+ foo: "bar",
+ },
+ )
+
+ tool_instance =
+ DiscourseAi::AiBot::Personas::General.new.find_tool(
+ tool_call,
+ bot_user: nil,
+ llm: nil,
+ context: nil,
+ )
+
+ expect(tool_instance.parameters[:status]).to eq("open")
end
it "can coerce integers" do
- xml = <<~XML
- <function_calls>
- <invoke>
- <tool_name>search</tool_name>
- <tool_id>call_JtYQMful5QKqw97XFsHzPweB</tool_id>
- <parameters>
- <max_posts>"3.2"</max_posts>
- <search_query>hello world</search_query>
- <foo>bar</foo>
- </parameters>
- </invoke>
- </function_calls>
- XML
+ tool_call =
+ DiscourseAi::Completions::ToolCall.new(
+ name: "search",
+ id: "call_JtYQMful5QKqw97XFsHzPweB",
+ parameters: {
+ max_posts: "3.2",
+ search_query: "hello world",
+ foo: "bar",
+ },
+ )
- search, =
- tools =
- DiscourseAi::AiBot::Personas::General.new.find_tools(
- xml,
- bot_user: nil,
- llm: nil,
- context: nil,
- )
+ search =
+ DiscourseAi::AiBot::Personas::General.new.find_tool(
+ tool_call,
+ bot_user: nil,
+ llm: nil,
+ context: nil,
+ )
expect(search.parameters[:max_posts]).to eq(3)
expect(search.parameters[:search_query]).to eq("hello world")
@@ -177,43 +190,23 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
it "can correctly parse arrays in tools" do
SiteSetting.ai_openai_api_key = "123"
- # Dall E tool uses an array for params
- xml = <<~XML
- <function_calls>
- <invoke>
- <tool_name>dall_e</tool_name>
- <tool_id>call_JtYQMful5QKqw97XFsHzPweB</tool_id>
- <parameters>
- <prompts>["cat oil painting", "big car"]</prompts>
- </parameters>
- </invoke>
- <invoke>
- <tool_name>dall_e</tool_name>
- <tool_id>abc</tool_id>
- <parameters>
- <prompts>["pic3"]</prompts>
- </parameters>
- </invoke>
- <invoke>
- <tool_name>unknown</tool_name>
- <tool_id>abc</tool_id>
- <parameters>
- <prompts>["pic3"]</prompts>
- </parameters>
- </invoke>
- </function_calls>
- XML
- dall_e1, dall_e2 =
- tools =
- DiscourseAi::AiBot::Personas::DallE3.new.find_tools(
- xml,
- bot_user: nil,
- llm: nil,
- context: nil,
- )
- expect(dall_e1.parameters[:prompts]).to eq(["cat oil painting", "big car"])
- expect(dall_e2.parameters[:prompts]).to eq(["pic3"])
- expect(tools.length).to eq(2)
+ tool_call =
+ DiscourseAi::Completions::ToolCall.new(
+ name: "dall_e",
+ id: "call_JtYQMful5QKqw97XFsHzPweB",
+ parameters: {
+ prompts: ["cat oil painting", "big car"],
+ },
+ )
+
+ tool_instance =
+ DiscourseAi::AiBot::Personas::DallE3.new.find_tool(
+ tool_call,
+ bot_user: nil,
+ llm: nil,
+ context: nil,
+ )
+ expect(tool_instance.parameters[:prompts]).to eq(["cat oil painting", "big car"])
end
describe "custom personas" do
diff --git a/spec/lib/modules/ai_bot/playground_spec.rb b/spec/lib/modules/ai_bot/playground_spec.rb
index 9c98a08a..2a07ad52 100644
--- a/spec/lib/modules/ai_bot/playground_spec.rb
+++ b/spec/lib/modules/ai_bot/playground_spec.rb
@@ -55,6 +55,8 @@ RSpec.describe DiscourseAi::AiBot::Playground do
)
end
+ before { SiteSetting.ai_embeddings_enabled = false }
+
after do
# we must reset cache on persona cause data can be rolled back
AiPersona.persona_cache.flush!
@@ -83,17 +85,15 @@ RSpec.describe DiscourseAi::AiBot::Playground do
end
let!(:ai_persona) { Fabricate(:ai_persona, tools: ["custom-#{custom_tool.id}"]) }
- let(:function_call) { (<<~XML).strip }
- <function_calls>
- <invoke>
- <tool_name>search</tool_name>
- <tool_id>666</tool_id>
- <parameters>
- <query>Can you use the custom tool</query>
- </parameters>
- </invoke>
- </function_calls>",
- XML
+ let(:tool_call) do
+ DiscourseAi::Completions::ToolCall.new(
+ name: "search",
+ id: "666",
+ parameters: {
+ query: "Can you use the custom tool",
+ },
+ )
+ end
let(:bot) { DiscourseAi::AiBot::Bot.as(bot_user, persona: ai_persona.class_instance.new) }
@@ -115,7 +115,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
reply_post = nil
prompts = nil
- responses = [function_call]
+ responses = [tool_call]
DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompts|
new_post = Fabricate(:post, raw: "Can you use the custom tool?")
reply_post = playground.reply_to(new_post)
@@ -133,7 +133,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
it "can force usage of a tool" do
tool_name = "custom-#{custom_tool.id}"
ai_persona.update!(tools: [[tool_name, nil, true]], forced_tool_count: 1)
- responses = [function_call, "custom tool did stuff (maybe)"]
+ responses = [tool_call, "custom tool did stuff (maybe)"]
prompts = nil
reply_post = nil
@@ -166,7 +166,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_klass.new)
playground = DiscourseAi::AiBot::Playground.new(bot)
- responses = [function_call, "custom tool did stuff (maybe)"]
+ responses = [tool_call, "custom tool did stuff (maybe)"]
reply_post = nil
@@ -206,13 +206,15 @@ RSpec.describe DiscourseAi::AiBot::Playground do
bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_klass.new)
playground = DiscourseAi::AiBot::Playground.new(bot)
+ responses = ["custom tool did stuff (maybe)", tool_call]
+
# lets ensure tool does not run...
DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompt|
new_post = Fabricate(:post, raw: "Can you use the custom tool?")
reply_post = playground.reply_to(new_post)
end
- expect(reply_post.raw.strip).to eq(function_call)
+ expect(reply_post.raw.strip).to eq("custom tool did stuff (maybe)")
end
end
@@ -452,10 +454,25 @@ RSpec.describe DiscourseAi::AiBot::Playground do
it "can run tools" do
persona.update!(tools: ["Time"])
- responses = [
- "timetimeBuenos Aires",
- "The time is 2023-12-14 17:24:00 -0300",
- ]
+ tool_call1 =
+ DiscourseAi::Completions::ToolCall.new(
+ name: "time",
+ id: "time",
+ parameters: {
+ timezone: "Buenos Aires",
+ },
+ )
+
+ tool_call2 =
+ DiscourseAi::Completions::ToolCall.new(
+ name: "time",
+ id: "time",
+ parameters: {
+ timezone: "Sydney",
+ },
+ )
+
+ responses = [[tool_call1, tool_call2], "The time is 2023-12-14 17:24:00 -0300"]
message =
DiscourseAi::Completions::Llm.with_prepared_responses(responses) do
@@ -470,7 +487,8 @@ RSpec.describe DiscourseAi::AiBot::Playground do
# it also needs to have tool details now set on message
prompt = ChatMessageCustomPrompt.find_by(message_id: reply.id)
- expect(prompt.custom_prompt.length).to eq(3)
+
+ expect(prompt.custom_prompt.length).to eq(5)
# TODO in chat I am mixed on including this in the context, but I guess maybe?
# thinking about this
@@ -782,30 +800,29 @@ RSpec.describe DiscourseAi::AiBot::Playground do
end
it "supports multiple function calls" do
- response1 = (<<~TXT).strip
- <function_calls>
- <invoke>
- <tool_name>search</tool_name>
- <tool_id>search</tool_id>
- <parameters>
- <search_query>testing various things</search_query>
- </parameters>
- </invoke>
- <invoke>
- <tool_name>search</tool_name>
- <tool_id>search</tool_id>
- <parameters>
- <search_query>another search</search_query>
- </parameters>
- </invoke>
- </function_calls>
- TXT
+ tool_call1 =
+ DiscourseAi::Completions::ToolCall.new(
+ name: "search",
+ id: "search",
+ parameters: {
+ search_query: "testing various things",
+ },
+ )
+
+ tool_call2 =
+ DiscourseAi::Completions::ToolCall.new(
+ name: "search",
+ id: "search",
+ parameters: {
+ search_query: "another search",
+ },
+ )
response2 = "I found stuff"
- DiscourseAi::Completions::Llm.with_prepared_responses([response1, response2]) do
- playground.reply_to(third_post)
- end
+ DiscourseAi::Completions::Llm.with_prepared_responses(
+ [[tool_call1, tool_call2], response2],
+ ) { playground.reply_to(third_post) }
last_post = third_post.topic.reload.posts.order(:post_number).last
@@ -819,17 +836,14 @@ RSpec.describe DiscourseAi::AiBot::Playground do
bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona.class_instance.new)
playground = described_class.new(bot)
- response1 = (<<~TXT).strip
- <function_calls>
- <invoke>
- <tool_name>search</tool_name>
- <tool_id>search</tool_id>
- <parameters>
- <search_query>testing various things</search_query>
- </parameters>
- </invoke>
- </function_calls>
- TXT
+ response1 =
+ DiscourseAi::Completions::ToolCall.new(
+ name: "search",
+ id: "search",
+ parameters: {
+ search_query: "testing various things",
+ },
+ )
response2 = "I found stuff"
@@ -843,17 +857,14 @@ RSpec.describe DiscourseAi::AiBot::Playground do
end
it "does not include placeholders in conversation context but includes all completions" do
- response1 = (<<~TXT).strip
- <function_calls>
- <invoke>
- <tool_name>search</tool_name>
- <tool_id>search</tool_id>
- <parameters>
- <search_query>testing various things</search_query>
- </parameters>
- </invoke>
- </function_calls>
- TXT
+ response1 =
+ DiscourseAi::Completions::ToolCall.new(
+ name: "search",
+ id: "search",
+ parameters: {
+ search_query: "testing various things",
+ },
+ )
response2 = "I found some really amazing stuff!"
@@ -889,17 +900,15 @@ RSpec.describe DiscourseAi::AiBot::Playground do
[{ b64_json: image, revised_prompt: "a pink cow 1" }]
end
- let(:response) { (<<~TXT).strip }
- <function_calls>
- <invoke>
- <tool_name>dall_e</tool_name>
- <tool_id>dall_e</tool_id>
- <parameters>
- <prompts>["a pink cow"]</prompts>
- </parameters>
- </invoke>
- </function_calls>
- TXT
+ let(:response) do
+ DiscourseAi::Completions::ToolCall.new(
+ name: "dall_e",
+ id: "dall_e",
+ parameters: {
+ prompts: ["a pink cow"],
+ },
+ )
+ end
it "properly returns an image when skipping tool details" do
persona.update!(tool_details: false)
diff --git a/spec/requests/admin/ai_personas_controller_spec.rb b/spec/requests/admin/ai_personas_controller_spec.rb
index fb42506e..16e0001b 100644
--- a/spec/requests/admin/ai_personas_controller_spec.rb
+++ b/spec/requests/admin/ai_personas_controller_spec.rb
@@ -541,16 +541,10 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
expect(topic.title).to eq("An amazing title")
expect(topic.posts.count).to eq(2)
- # now let's try to make a reply with a tool call
- function_call = <<~XML
- <function_calls>
- <invoke>
- <tool_name>categories</tool_name>
- </invoke>
- </function_calls>
- XML
+ tool_call =
+ DiscourseAi::Completions::ToolCall.new(name: "categories", parameters: {}, id: "tool_1")
- fake_endpoint.fake_content = [function_call, "this is the response after the tool"]
+ fake_endpoint.fake_content = [tool_call, "this is the response after the tool"]
# this simplifies function calls
fake_endpoint.chunk_count = 1
diff --git a/spec/requests/ai_bot/bot_controller_spec.rb b/spec/requests/ai_bot/bot_controller_spec.rb
index e7430185..007e868b 100644
--- a/spec/requests/ai_bot/bot_controller_spec.rb
+++ b/spec/requests/ai_bot/bot_controller_spec.rb
@@ -4,6 +4,8 @@ RSpec.describe DiscourseAi::AiBot::BotController do
fab!(:user)
fab!(:pm_topic) { Fabricate(:private_message_topic) }
fab!(:pm_post) { Fabricate(:post, topic: pm_topic) }
+ fab!(:pm_post2) { Fabricate(:post, topic: pm_topic) }
+ fab!(:pm_post3) { Fabricate(:post, topic: pm_topic) }
before { sign_in(user) }
@@ -22,15 +24,37 @@ RSpec.describe DiscourseAi::AiBot::BotController do
user = pm_topic.topic_allowed_users.first.user
sign_in(user)
- AiApiAuditLog.create!(
- post_id: pm_post.id,
- provider_id: 1,
- topic_id: pm_topic.id,
- raw_request_payload: "request",
- raw_response_payload: "response",
- request_tokens: 1,
- response_tokens: 2,
- )
+ log1 =
+ AiApiAuditLog.create!(
+ provider_id: 1,
+ topic_id: pm_topic.id,
+ raw_request_payload: "request",
+ raw_response_payload: "response",
+ request_tokens: 1,
+ response_tokens: 2,
+ )
+
+ log2 =
+ AiApiAuditLog.create!(
+ post_id: pm_post.id,
+ provider_id: 1,
+ topic_id: pm_topic.id,
+ raw_request_payload: "request",
+ raw_response_payload: "response",
+ request_tokens: 1,
+ response_tokens: 2,
+ )
+
+ log3 =
+ AiApiAuditLog.create!(
+ post_id: pm_post2.id,
+ provider_id: 1,
+ topic_id: pm_topic.id,
+ raw_request_payload: "request",
+ raw_response_payload: "response",
+ request_tokens: 1,
+ response_tokens: 2,
+ )
Group.refresh_automatic_groups!
SiteSetting.ai_bot_debugging_allowed_groups = user.groups.first.id.to_s
@@ -38,18 +62,26 @@ RSpec.describe DiscourseAi::AiBot::BotController do
get "/discourse-ai/ai-bot/post/#{pm_post.id}/show-debug-info"
expect(response.status).to eq(200)
+ expect(response.parsed_body["id"]).to eq(log2.id)
+ expect(response.parsed_body["next_log_id"]).to eq(log3.id)
+ expect(response.parsed_body["prev_log_id"]).to eq(log1.id)
+ expect(response.parsed_body["topic_id"]).to eq(pm_topic.id)
+
expect(response.parsed_body["request_tokens"]).to eq(1)
expect(response.parsed_body["response_tokens"]).to eq(2)
expect(response.parsed_body["raw_request_payload"]).to eq("request")
expect(response.parsed_body["raw_response_payload"]).to eq("response")
- post2 = Fabricate(:post, topic: pm_topic)
-
# return previous post if current has no debug info
- get "/discourse-ai/ai-bot/post/#{post2.id}/show-debug-info"
+ get "/discourse-ai/ai-bot/post/#{pm_post3.id}/show-debug-info"
expect(response.status).to eq(200)
expect(response.parsed_body["request_tokens"]).to eq(1)
expect(response.parsed_body["response_tokens"]).to eq(2)
+
+ # can return debug info by id as well
+ get "/discourse-ai/ai-bot/show-debug-info/#{log1.id}"
+ expect(response.status).to eq(200)
+ expect(response.parsed_body["id"]).to eq(log1.id)
end
end