diff --git a/app/controllers/discourse_ai/ai_bot/bot_controller.rb b/app/controllers/discourse_ai/ai_bot/bot_controller.rb
index e5d5bcf0..5ea13795 100644
--- a/app/controllers/discourse_ai/ai_bot/bot_controller.rb
+++ b/app/controllers/discourse_ai/ai_bot/bot_controller.rb
@@ -6,6 +6,14 @@ module DiscourseAi
     requires_plugin ::DiscourseAi::PLUGIN_NAME
     requires_login

+    def show_debug_info_by_id
+      log = AiApiAuditLog.find(params[:id])
+      raise Discourse::NotFound if !log.topic
+
+      guardian.ensure_can_debug_ai_bot_conversation!(log.topic)
+      render json: AiApiAuditLogSerializer.new(log, root: false), status: 200
+    end
+
     def show_debug_info
       post = Post.find(params[:post_id])
       guardian.ensure_can_debug_ai_bot_conversation!(post)
diff --git a/app/models/ai_api_audit_log.rb b/app/models/ai_api_audit_log.rb
index 2fa9f5c3..2fa0a214 100644
--- a/app/models/ai_api_audit_log.rb
+++ b/app/models/ai_api_audit_log.rb
@@ -14,6 +14,14 @@ class AiApiAuditLog < ActiveRecord::Base
     Ollama = 7
     SambaNova = 8
   end
+
+  def next_log_id
+    self.class.where("id > ?", id).where(topic_id: topic_id).order(id: :asc).pluck(:id).first
+  end
+
+  def prev_log_id
+    self.class.where("id < ?", id).where(topic_id: topic_id).order(id: :desc).pluck(:id).first
+  end
 end

 # == Schema Information
diff --git a/app/serializers/ai_api_audit_log_serializer.rb b/app/serializers/ai_api_audit_log_serializer.rb
index 0c438a7b..eeb3843a 100644
--- a/app/serializers/ai_api_audit_log_serializer.rb
+++ b/app/serializers/ai_api_audit_log_serializer.rb
@@ -12,5 +12,7 @@ class AiApiAuditLogSerializer < ApplicationSerializer
              :post_id,
              :feature_name,
              :language_model,
-             :created_at
+             :created_at,
+             :prev_log_id,
+             :next_log_id
 end
diff --git a/assets/javascripts/discourse/components/modal/debug-ai-modal.gjs b/assets/javascripts/discourse/components/modal/debug-ai-modal.gjs
index 5d0cdf69..c21e8df3 100644
--- a/assets/javascripts/discourse/components/modal/debug-ai-modal.gjs
+++ b/assets/javascripts/discourse/components/modal/debug-ai-modal.gjs
@@ -7,6 +7,7 @@ import { htmlSafe } from "@ember/template";
 import DButton from "discourse/components/d-button";
 import DModal from "discourse/components/d-modal";
 import { ajax } from "discourse/lib/ajax";
+import { popupAjaxError } from "discourse/lib/ajax-error";
 import { clipboardCopy, escapeExpression } from "discourse/lib/utilities";
 import i18n from "discourse-common/helpers/i18n";
 import discourseLater from "discourse-common/lib/later";
@@ -63,6 +64,28 @@ export default class DebugAiModal extends Component {
     this.copy(this.info.raw_response_payload);
   }

+  async loadLog(logId) {
+    try {
+      await ajax(`/discourse-ai/ai-bot/show-debug-info/${logId}.json`).then(
+        (result) => {
+          this.info = result;
+        }
+      );
+    } catch (e) {
+      popupAjaxError(e);
+    }
+  }
+
+  @action
+  prevLog() {
+    this.loadLog(this.info.prev_log_id);
+  }
+
+  @action
+  nextLog() {
+    this.loadLog(this.info.next_log_id);
+  }
+
   copy(text) {
     clipboardCopy(text);
     this.justCopiedText = I18n.t("discourse_ai.ai_bot.conversation_shared");
@@ -73,11 +96,13 @@ export default class DebugAiModal extends Component {
   }

   loadApiRequestInfo() {
-    ajax(
-      `/discourse-ai/ai-bot/post/${this.args.model.id}/show-debug-info.json`
-    ).then((result) => {
-      this.info = result;
-    });
+    ajax(`/discourse-ai/ai-bot/post/${this.args.model.id}/show-debug-info.json`)
+      .then((result) => {
+        this.info = result;
+      })
+      .catch((e) => {
+        popupAjaxError(e);
+      });
   }

   get requestActive() {
@@ -147,6 +172,22 @@ export default class DebugAiModal extends Component {
             @action={{this.copyResponse}}
             @label="discourse_ai.ai_bot.debug_ai_modal.copy_response"
           />
+          {{#if this.info.prev_log_id}}
+            <DButton
+              @action={{this.prevLog}}
+              @label="discourse_ai.ai_bot.debug_ai_modal.previous_log"
+            />
+          {{/if}}
+          {{#if this.info.next_log_id}}
+            <DButton
+              @action={{this.nextLog}}
+              @label="discourse_ai.ai_bot.debug_ai_modal.next_log"
+            />
+          {{/if}}
           {{this.justCopiedText}}
diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml
index 2d517946..82898c91 100644
--- a/config/locales/client.en.yml
+++ b/config/locales/client.en.yml
@@ -415,6 +415,8 @@ en:
             response_tokens: "Response tokens:"
             request: "Request"
             response: "Response"
+            next_log: "Next"
+            previous_log: "Previous"

       share_full_topic_modal:
         title: "Share Conversation Publicly"
diff --git a/config/routes.rb b/config/routes.rb
index a5c009ff..322e67ce 100644
--- a/config/routes.rb
+++ b/config/routes.rb
@@ -22,6 +22,7 @@ DiscourseAi::Engine.routes.draw do
   scope module: :ai_bot, path: "/ai-bot", defaults: { format: :json } do
     get "bot-username" => "bot#show_bot_username"
     get "post/:post_id/show-debug-info" => "bot#show_debug_info"
+    get "show-debug-info/:id" => "bot#show_debug_info_by_id"
     post "post/:post_id/stop-streaming" => "bot#stop_streaming_response"
   end
diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb
index 834ae059..b965b1f6 100644
--- a/lib/ai_bot/bot.rb
+++ b/lib/ai_bot/bot.rb
@@ -100,6 +100,7 @@ module DiscourseAi
         llm_kwargs[:top_p] = persona.top_p if persona.top_p

         needs_newlines = false
+        tools_ran = 0

         while total_completions <= MAX_COMPLETIONS && ongoing_chain
           tool_found = false
@@ -107,9 +108,10 @@ module DiscourseAi

           result =
             llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel|
-              tools = persona.find_tools(partial, bot_user: user, llm: llm, context: context)
+              tool = persona.find_tool(partial, bot_user: user, llm: llm, context: context)
+              tool = nil if tools_ran >= MAX_TOOLS

-              if (tools.present?)
+              if tool.present?
                 tool_found = true
                 # a bit hacky, but extra newlines do no harm
                 if needs_newlines
@@ -117,13 +119,16 @@ module DiscourseAi
                   needs_newlines = false
                 end

-                tools[0..MAX_TOOLS].each do |tool|
-                  process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
-                  ongoing_chain &&= tool.chain_next_response?
-                end
+                process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
+                tools_ran += 1
+                ongoing_chain &&= tool.chain_next_response?
               else
                 needs_newlines = true
-                update_blk.call(partial, cancel)
+                if partial.is_a?(DiscourseAi::Completions::ToolCall)
+                  Rails.logger.warn("DiscourseAi: Tool not found: #{partial.name}")
+                else
+                  update_blk.call(partial, cancel)
+                end
               end
             end
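
The hunk above relies on the new streaming contract introduced later in this patch: the block passed to llm.generate now receives either plain String partials or DiscourseAi::Completions::ToolCall objects, so callers dispatch on type instead of scanning the text for XML markers. A minimal consumer, purely illustrative (the llm/prompt setup and run_tool helper are assumed):

    llm.generate(prompt, user: user) do |partial, cancel|
      if partial.is_a?(DiscourseAi::Completions::ToolCall)
        # a fully parsed invocation: id, name, and symbolized parameters
        run_tool(partial.name, partial.parameters)
      else
        reply << partial # a plain text fragment
      end
    end
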
diff --git a/lib/ai_bot/personas/persona.rb b/lib/ai_bot/personas/persona.rb
index 73224808..63255a17 100644
--- a/lib/ai_bot/personas/persona.rb
+++ b/lib/ai_bot/personas/persona.rb
@@ -199,23 +199,16 @@ module DiscourseAi
         prompt
       end

-      def find_tools(partial, bot_user:, llm:, context:)
-        return [] if !partial.include?("</invoke>")
-
-        parsed_function = Nokogiri::HTML5.fragment(partial)
-        parsed_function
-          .css("invoke")
-          .map do |fragment|
-            tool_instance(fragment, bot_user: bot_user, llm: llm, context: context)
-          end
-          .compact
+      def find_tool(partial, bot_user:, llm:, context:)
+        return nil if !partial.is_a?(DiscourseAi::Completions::ToolCall)
+        tool_instance(partial, bot_user: bot_user, llm: llm, context: context)
       end

       protected

-      def tool_instance(parsed_function, bot_user:, llm:, context:)
-        function_id = parsed_function.at("tool_id")&.text
-        function_name = parsed_function.at("tool_name")&.text
+      def tool_instance(tool_call, bot_user:, llm:, context:)
+        function_id = tool_call.id
+        function_name = tool_call.name
         return nil if function_name.nil?

         tool_klass = available_tools.find { |c| c.signature.dig(:name) == function_name }
@@ -224,7 +217,7 @@ module DiscourseAi
         arguments = {}
         tool_klass.signature[:parameters].to_a.each do |param|
           name = param[:name]
-          value = parsed_function.at(name)&.text
+          value = tool_call.parameters[name.to_sym]

           if param[:type] == "array" && value
             value =
diff --git a/lib/completions/anthropic_message_processor.rb b/lib/completions/anthropic_message_processor.rb
index 1d1516fa..5d5602ef 100644
--- a/lib/completions/anthropic_message_processor.rb
+++ b/lib/completions/anthropic_message_processor.rb
@@ -13,6 +13,11 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
     def append(json)
       @raw_json << json
     end
+
+    def to_tool_call
+      parameters = JSON.parse(raw_json, symbolize_names: true)
+      DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: parameters)
+    end
   end

   attr_reader :tool_calls, :input_tokens, :output_tokens
@@ -20,80 +25,69 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
   def initialize(streaming_mode:)
     @streaming_mode = streaming_mode
     @tool_calls = []
+    @current_tool_call = nil
   end

-  def to_xml_tool_calls(function_buffer)
-    return function_buffer if @tool_calls.blank?
+  def to_tool_calls
+    @tool_calls.map { |tool_call| tool_call.to_tool_call }
+  end

-    function_buffer = Nokogiri::HTML5.fragment(<<~TEXT)
-      <function_calls>
-      </function_calls>
-    TEXT
-
-    @tool_calls.each do |tool_call|
-      node =
-        function_buffer.at("function_calls").add_child(
-          Nokogiri::HTML5::DocumentFragment.parse(
-            DiscourseAi::Completions::Endpoints::Base.noop_function_call_text + "\n",
-          ),
-        )
-
-      params = JSON.parse(tool_call.raw_json, symbolize_names: true)
-      xml =
-        params.map { |name, value| "<#{name}>#{CGI.escapeHTML(value.to_s)}</#{name}>" }.join("\n")
-
-      node.at("tool_name").content = tool_call.name
-      node.at("tool_id").content = tool_call.id
-      node.at("parameters").children = Nokogiri::HTML5::DocumentFragment.parse(xml) if xml.present?
+  def process_streamed_message(parsed)
+    result = nil
+    if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use"
+      tool_name = parsed.dig(:content_block, :name)
+      tool_id = parsed.dig(:content_block, :id)
+      result = @current_tool_call.to_tool_call if @current_tool_call
+      @current_tool_call = AnthropicToolCall.new(tool_name, tool_id) if tool_name
+    elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
+      if @current_tool_call
+        tool_delta = parsed.dig(:delta, :partial_json).to_s
+        @current_tool_call.append(tool_delta)
+      else
+        result = parsed.dig(:delta, :text).to_s
+      end
+    elsif parsed[:type] == "content_block_stop"
+      if @current_tool_call
+        result = @current_tool_call.to_tool_call
+        @current_tool_call = nil
+      end
+    elsif parsed[:type] == "message_start"
+      @input_tokens = parsed.dig(:message, :usage, :input_tokens)
+    elsif parsed[:type] == "message_delta"
+      @output_tokens =
+        parsed.dig(:usage, :output_tokens) || parsed.dig(:delta, :usage, :output_tokens)
+    elsif parsed[:type] == "message_stop"
+      # bedrock has this ...
+      if bedrock_stats = parsed.dig("amazon-bedrock-invocationMetrics".to_sym)
+        @input_tokens = bedrock_stats[:inputTokenCount] || @input_tokens
+        @output_tokens = bedrock_stats[:outputTokenCount] || @output_tokens
+      end
     end
-
-    function_buffer
+    result
   end

   def process_message(payload)
     result = ""
-    parsed = JSON.parse(payload, symbolize_names: true)
+    parsed = payload
+    parsed = JSON.parse(payload, symbolize_names: true) if payload.is_a?(String)

-    if @streaming_mode
-      if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use"
-        tool_name = parsed.dig(:content_block, :name)
-        tool_id = parsed.dig(:content_block, :id)
-        @tool_calls << AnthropicToolCall.new(tool_name, tool_id) if tool_name
-      elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
-        if @tool_calls.present?
-          result = parsed.dig(:delta, :partial_json).to_s
-          @tool_calls.last.append(result)
-        else
-          result = parsed.dig(:delta, :text).to_s
+    content = parsed.dig(:content)
+    if content.is_a?(Array)
+      result =
+        content.map do |data|
+          if data[:type] == "tool_use"
+            call = AnthropicToolCall.new(data[:name], data[:id])
+            call.append(data[:input].to_json)
+            call.to_tool_call
+          else
+            data[:text]
+          end
         end
-      elsif parsed[:type] == "message_start"
-        @input_tokens = parsed.dig(:message, :usage, :input_tokens)
-      elsif parsed[:type] == "message_delta"
-        @output_tokens =
-          parsed.dig(:usage, :output_tokens) || parsed.dig(:delta, :usage, :output_tokens)
-      elsif parsed[:type] == "message_stop"
-        # bedrock has this ...
-        if bedrock_stats = parsed.dig("amazon-bedrock-invocationMetrics".to_sym)
-          @input_tokens = bedrock_stats[:inputTokenCount] || @input_tokens
-          @output_tokens = bedrock_stats[:outputTokenCount] || @output_tokens
-        end
-      end
-    else
-      content = parsed.dig(:content)
-      if content.is_a?(Array)
-        tool_call = content.find { |c| c[:type] == "tool_use" }
-        if tool_call
-          @tool_calls << AnthropicToolCall.new(tool_call[:name], tool_call[:id])
-          @tool_calls.last.append(tool_call[:input].to_json)
-        else
-          result = parsed.dig(:content, 0, :text).to_s
-        end
-      end
-
-      @input_tokens = parsed.dig(:usage, :input_tokens)
-      @output_tokens = parsed.dig(:usage, :output_tokens)
     end

+    @input_tokens = parsed.dig(:usage, :input_tokens)
+    @output_tokens = parsed.dig(:usage, :output_tokens)
+
     result
   end
 end
diff --git a/lib/completions/dialects/ollama.rb b/lib/completions/dialects/ollama.rb
index 541d0e73..3a32e592 100644
--- a/lib/completions/dialects/ollama.rb
+++ b/lib/completions/dialects/ollama.rb
@@ -63,8 +63,23 @@ module DiscourseAi
       def user_msg(msg)
         user_message = { role: "user", content: msg[:content] }

-        # TODO: Add support for user messages with empbeded user ids
-        # TODO: Add support for user messages with attachments
+        encoded_uploads = prompt.encoded_uploads(msg)
+        if encoded_uploads.present?
+          images =
+            encoded_uploads
+              .map do |upload|
+                if upload[:mime_type].start_with?("image/")
+                  upload[:base64]
+                else
+                  nil
+                end
+              end
+              .compact
+
+          user_message[:images] = images if images.present?
+        end
+
+        # TODO: Add support for user messages with embedded user ids

         user_message
       end
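
The upload handling above targets the message shape Ollama's chat API expects: a base64-encoded images array next to the text content. Roughly, for an image-bearing message (values abbreviated; the encoding itself comes from prompt.encoded_uploads):

    user_message = {
      role: "user",
      content: "what is in this picture?",
      images: ["iVBORw0KGgoAAAANSUhEUg..."], # raw base64 image data
    }
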
diff --git a/lib/completions/endpoints/anthropic.rb b/lib/completions/endpoints/anthropic.rb
index 44762b88..6576ef3b 100644
--- a/lib/completions/endpoints/anthropic.rb
+++ b/lib/completions/endpoints/anthropic.rb
@@ -63,6 +63,10 @@ module DiscourseAi
           URI(llm_model.url)
         end

+        def xml_tools_enabled?
+          !@native_tool_support
+        end
+
         def prepare_payload(prompt, model_params, dialect)
           @native_tool_support = dialect.native_tool_support?

@@ -90,35 +94,34 @@ module DiscourseAi
           Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
         end

+        def decode_chunk(partial_data)
+          @decoder ||= JsonStreamDecoder.new
+          (@decoder << partial_data)
+            .map { |parsed_json| processor.process_streamed_message(parsed_json) }
+            .compact
+        end
+
+        def decode(response_data)
+          processor.process_message(response_data)
+        end
+
         def processor
           @processor ||=
             DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
         end

-        def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
-          processor.to_xml_tool_calls(function_buffer) if !partial
-        end
-
-        def extract_completion_from(response_raw)
-          processor.process_message(response_raw)
-        end
-
         def has_tool?(_response_data)
           processor.tool_calls.present?
         end

+        def tool_calls
+          processor.to_tool_calls
+        end
+
         def final_log_update(log)
           log.request_tokens = processor.input_tokens if processor.input_tokens
           log.response_tokens = processor.output_tokens if processor.output_tokens
         end
-
-        def native_tool_support?
-          @native_tool_support
-        end
-
-        def partials_from(decoded_chunk)
-          decoded_chunk.split("\n").map { |line| line.split("data: ", 2)[1] }.compact
-        end
       end
     end
   end
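
The decode/decode_chunk pair above replaces the old extract_completion_from/partials_from hooks: raw SSE bytes are split into JSON events by JsonStreamDecoder, and each event is folded by the processor into either a text partial or a ToolCall. A rough sketch of the streaming path (event payload abbreviated):

    endpoint.decode_chunk(<<~SSE)
      data: {"type":"content_block_delta","delta":{"text":"Hello"}}
    SSE
    # => ["Hello"]
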
diff --git a/lib/completions/endpoints/aws_bedrock.rb b/lib/completions/endpoints/aws_bedrock.rb
index f3146c2d..c17a051f 100644
--- a/lib/completions/endpoints/aws_bedrock.rb
+++ b/lib/completions/endpoints/aws_bedrock.rb
@@ -117,7 +117,24 @@ module DiscourseAi
           end
         end

-        def decode(chunk)
+        def decode_chunk(partial_data)
+          bedrock_decode(partial_data)
+            .map do |decoded_partial_data|
+              @raw_response ||= +""
+              @raw_response << decoded_partial_data
+              @raw_response << "\n"
+
+              parsed_json = JSON.parse(decoded_partial_data, symbolize_names: true)
+              processor.process_streamed_message(parsed_json)
+            end
+            .compact
+        end
+
+        def decode(response_data)
+          processor.process_message(response_data)
+        end
+
+        def bedrock_decode(chunk)
           @decoder ||= Aws::EventStream::Decoder.new

           decoded, _done = @decoder.decode_chunk(chunk)
@@ -147,12 +164,13 @@ module DiscourseAi
                Aws::EventStream::Errors::MessageChecksumError,
                Aws::EventStream::Errors::PreludeChecksumError => e
           Rails.logger.error("#{self.class.name}: #{e.message}")
-          nil
+          []
         end

         def final_log_update(log)
           log.request_tokens = processor.input_tokens if processor.input_tokens
           log.response_tokens = processor.output_tokens if processor.output_tokens
+          log.raw_response_payload = @raw_response
         end

         def processor
@@ -160,30 +178,8 @@ module DiscourseAi
             DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
         end

-        def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
-          processor.to_xml_tool_calls(function_buffer) if !partial
-        end
-
-        def extract_completion_from(response_raw)
-          processor.process_message(response_raw)
-        end
-
-        def has_tool?(_response_data)
-          processor.tool_calls.present?
-        end
-
-        def partials_from(decoded_chunks)
-          decoded_chunks
-        end
-
-        def native_tool_support?
-          @native_tool_support
-        end
-
-        def chunk_to_string(chunk)
-          joined = +chunk.join("\n")
-          joined << "\n" if joined.length > 0
-          joined
+        def xml_tools_enabled?
+          !@native_tool_support
         end
       end
     end
diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb
index a0405b42..c78fcdd9 100644
--- a/lib/completions/endpoints/base.rb
+++ b/lib/completions/endpoints/base.rb
@@ -40,10 +40,6 @@ module DiscourseAi
           @llm_model = llm_model
         end

-        def native_tool_support?
-          false
-        end
-
         def use_ssl?
           if model_uri&.scheme.present?
             model_uri.scheme == "https"
@@ -64,22 +60,10 @@ module DiscourseAi
           feature_context: nil,
           &blk
         )
-          allow_tools = dialect.prompt.has_tools?
           model_params = normalize_model_params(model_params)
           orig_blk = blk

           @streaming_mode = block_given?
-          to_strip = xml_tags_to_strip(dialect)
-          @xml_stripper =
-            DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present?
-
-          if @streaming_mode && @xml_stripper
-            blk =
-              lambda do |partial, cancel|
-                partial = @xml_stripper << partial
-                orig_blk.call(partial, cancel) if partial
-              end
-          end

           prompt = dialect.translate

@@ -108,177 +92,91 @@ module DiscourseAi
             raise CompletionFailed, response.body
           end

+          xml_tool_processor = XmlToolProcessor.new if xml_tools_enabled? &&
+            dialect.prompt.has_tools?
+
+          to_strip = xml_tags_to_strip(dialect)
+          xml_stripper =
+            DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present?
+
+          if @streaming_mode && xml_stripper
+            blk =
+              lambda do |partial, cancel|
+                partial = xml_stripper << partial if partial.is_a?(String)
+                orig_blk.call(partial, cancel) if partial
+              end
+          end
+
           log =
-            AiApiAuditLog.new(
+            start_log(
               provider_id: provider_id,
-              user_id: user&.id,
-              raw_request_payload: request_body,
-              request_tokens: prompt_size(prompt),
-              topic_id: dialect.prompt.topic_id,
-              post_id: dialect.prompt.post_id,
+              request_body: request_body,
+              dialect: dialect,
+              prompt: prompt,
+              user: user,
               feature_name: feature_name,
-              language_model: llm_model.name,
-              feature_context: feature_context.present? ? feature_context.as_json : nil,
+              feature_context: feature_context,
             )

           if !@streaming_mode
-            response_raw = response.read_body
-            response_data = extract_completion_from(response_raw)
-            partials_raw = response_data.to_s
-
-            if native_tool_support?
-              if allow_tools && has_tool?(response_data)
-                function_buffer = build_buffer # Nokogiri document
-                function_buffer =
-                  add_to_function_buffer(function_buffer, payload: response_data)
-                FunctionCallNormalizer.normalize_function_ids!(function_buffer)
-
-                response_data = +function_buffer.at("function_calls").to_s
-                response_data << "\n"
-              end
-            else
-              if allow_tools
-                response_data, function_calls = FunctionCallNormalizer.normalize(response_data)
-                response_data = function_calls if function_calls.present?
-              end
-            end
-
-            return response_data
+            return(
+              non_streaming_response(
+                response: response,
+                xml_tool_processor: xml_tool_processor,
+                xml_stripper: xml_stripper,
+                partials_raw: partials_raw,
+                response_raw: response_raw,
+              )
+            )
           end

-          has_tool = false
-
           begin
             cancelled = false
             cancel = -> { cancelled = true }
-
-            wrapped_blk = ->(partial, inner_cancel) do
-              response_data << partial
-              blk.call(partial, inner_cancel)
-            end
-
-            normalizer = FunctionCallNormalizer.new(wrapped_blk, cancel)
-
-            leftover = ""
-            function_buffer = build_buffer # Nokogiri document
-            prev_processed_partials = 0

             response.read_body do |chunk|
-              if cancelled
-                http.finish
-                break
-              end
-
-              decoded_chunk = decode(chunk)
-              if decoded_chunk.nil?
+              if cancelled
+                http.finish
+                break
+              end
-              raise CompletionFailed, "#{self.class.name}: Failed to decode LLM completion"
-              end
-              response_raw << chunk_to_string(decoded_chunk)
-
-              if decoded_chunk.is_a?(String)
-                redo_chunk = leftover + decoded_chunk
-              else
-                # custom implementation for endpoint
-                # no implicit leftover support
-                redo_chunk = decoded_chunk
-              end
-
-              raw_partials = partials_from(redo_chunk)
-
-              raw_partials =
-                raw_partials[prev_processed_partials..-1] if prev_processed_partials > 0
-
-              if raw_partials.blank? || (raw_partials.size == 1 && raw_partials.first.blank?)
-                leftover = redo_chunk
-                next
-              end
-
-              json_error = false
-
-              raw_partials.each do |raw_partial|
-                json_error = false
-                prev_processed_partials += 1
-
-                next if cancelled
-                next if raw_partial.blank?
-
-                begin
-                  partial = extract_completion_from(raw_partial)
-                  next if partial.nil?
-                  # empty vs blank... we still accept " "
-                  next if response_data.empty? && partial.empty?
-                  partials_raw << partial.to_s
-
-                  if native_tool_support?
-                    # Stop streaming the response as soon as you find a tool.
-                    # We'll buffer and yield it later.
-                    has_tool = true if allow_tools && has_tool?(partials_raw)
-
-                    if has_tool
-                      function_buffer =
-                        add_to_function_buffer(function_buffer, partial: partial)
-                    else
-                      response_data << partial
-                      blk.call(partial, cancel) if partial
-                    end
-                  else
-                    if allow_tools
-                      normalizer << partial
-                    else
-                      response_data << partial
-                      blk.call(partial, cancel) if partial
-                    end
+              response_raw << chunk
+
+              decode_chunk(chunk).each do |partial|
+                partials_raw << partial.to_s
+                response_data << partial if partial.is_a?(String)
+
+                partials = [partial]
+                if xml_tool_processor && partial.is_a?(String)
+                  partials = (xml_tool_processor << partial)
+                  if xml_tool_processor.should_cancel?
+                    cancel.call
+                    break
                   end
-                rescue JSON::ParserError
-                  leftover = redo_chunk
-                  json_error = true
                 end
+                partials.each { |inner_partial| blk.call(inner_partial, cancel) }
               end
-
-              if json_error
-                prev_processed_partials -= 1
-              else
-                leftover = ""
-              end
-
-              prev_processed_partials = 0 if leftover.blank?
             end
           rescue IOError, StandardError
             raise if !cancelled
           end
-
-          has_tool ||= has_tool?(partials_raw)
-          # Once we have the full response, try to return the tool as a XML doc.
-          if has_tool && native_tool_support?
-            function_buffer = add_to_function_buffer(function_buffer, payload: partials_raw)
-
-            if function_buffer.at("tool_name").text.present?
-              FunctionCallNormalizer.normalize_function_ids!(function_buffer)
-
-              invocation = +function_buffer.at("function_calls").to_s
-              invocation << "\n"
-
-              response_data << invocation
-              blk.call(invocation, cancel)
+          if xml_stripper
+            stripped = xml_stripper.finish
+            if stripped.present?
+              response_data << stripped
+              result = []
+              result = (xml_tool_processor << stripped) if xml_tool_processor
+              result.each { |partial| blk.call(partial, cancel) }
             end
           end
-
-          if !native_tool_support? && function_calls = normalizer.function_calls
-            response_data << function_calls
-            blk.call(function_calls, cancel)
+          if xml_tool_processor
+            xml_tool_processor.finish.each { |partial| blk.call(partial, cancel) }
           end
-
-          if @xml_stripper
-            leftover = @xml_stripper.finish
-            orig_blk.call(leftover, cancel) if leftover.present?
-          end
-
+          decode_chunk_finish.each { |partial| blk.call(partial, cancel) }
           return response_data
         ensure
           if log
             log.raw_response_payload = response_raw
-            log.response_tokens = tokenizer.size(partials_raw)
             final_log_update(log)
+
+            log.response_tokens = tokenizer.size(partials_raw) if log.response_tokens.blank?
+
             log.save!

             if Rails.env.development?
@@ -330,15 +228,15 @@ module DiscourseAi
           raise NotImplementedError
         end

-        def extract_completion_from(_response_raw)
+        def decode(_response_raw)
           raise NotImplementedError
         end

-        def decode(chunk)
-          chunk
+        def decode_chunk_finish
+          []
         end

-        def partials_from(_decoded_chunk)
+        def decode_chunk(_chunk)
           raise NotImplementedError
         end

@@ -346,49 +244,73 @@ module DiscourseAi
           prompt.map { |message| message[:content] || message["content"] || "" }.join("\n")
         end

-        def build_buffer
-          Nokogiri::HTML5.fragment(<<~TEXT)
-            <function_calls>
-            #{noop_function_call_text}
-            </function_calls>
-          TEXT
+        def xml_tools_enabled?
+          raise NotImplementedError
         end

-        def self.noop_function_call_text
-          (<<~TEXT).strip
-            <invoke>
-            <tool_name></tool_name>
-            <parameters>
-            </parameters>
-            <tool_id></tool_id>
-            </invoke>
-          TEXT
+        private
+
+        def start_log(
+          provider_id:,
+          request_body:,
+          dialect:,
+          prompt:,
+          user:,
+          feature_name:,
+          feature_context:
+        )
+          AiApiAuditLog.new(
+            provider_id: provider_id,
+            user_id: user&.id,
+            raw_request_payload: request_body,
+            request_tokens: prompt_size(prompt),
+            topic_id: dialect.prompt.topic_id,
+            post_id: dialect.prompt.post_id,
+            feature_name: feature_name,
+            language_model: llm_model.name,
+            feature_context: feature_context.present? ? feature_context.as_json : nil,
+          )
         end

-        def noop_function_call_text
-          self.class.noop_function_call_text
-        end
+        def non_streaming_response(
+          response:,
+          xml_tool_processor:,
+          xml_stripper:,
+          partials_raw:,
+          response_raw:
+        )
+          response_raw << response.read_body
+          response_data = decode(response_raw)

-        def has_tool?(response)
-          response.include?("<function_calls>")
-        end
+          response_data.each { |partial| partials_raw << partial.to_s }

-        def chunk_to_string(chunk)
-          if chunk.is_a?(String)
-            chunk
-          else
-            chunk.to_s
-          end
-        end
-
-        def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
-          if payload&.include?("</function_calls>")
-            matches = payload.match(%r{<function_calls>.*</function_calls>}m)
-            function_buffer =
-              Nokogiri::HTML5.fragment(matches[0] + "\n") if matches
+          if xml_tool_processor
+            response_data.each do |partial|
+              processed = (xml_tool_processor << partial)
+              processed << xml_tool_processor.finish
+              response_data = []
+              processed.flatten.compact.each { |inner| response_data << inner }
+            end
           end

-          function_buffer
+          if xml_stripper
+            response_data.map! do |partial|
+              stripped = (xml_stripper << partial) if partial.is_a?(String)
+              if stripped.present?
+                stripped
+              else
+                partial
+              end
+            end
+            response_data << xml_stripper.finish
+          end
+
+          response_data.reject!(&:blank?)
+
+          # this is to keep stuff backwards compatible
+          response_data = response_data.first if response_data.length == 1
+
+          response_data
+        end
       end
     end
   end
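
After this refactor the per-provider surface is small: subclasses implement decode (full response body to an array of String/ToolCall partials), decode_chunk (raw streamed bytes to an array of partials), and xml_tools_enabled? (true when tool calls arrive as XML text for XmlToolProcessor, false when the provider emits native structured tool calls). A skeletal subclass, for orientation only:

    class MyEndpoint < DiscourseAi::Completions::Endpoints::Base
      def xml_tools_enabled?
        false # provider has native tool calling
      end

      def decode(response_raw)
        json = JSON.parse(response_raw, symbolize_names: true)
        [json.dig(:choices, 0, :message, :content)]
      end

      def decode_chunk(chunk)
        @decoder ||= DiscourseAi::Completions::JsonStreamDecoder.new
        (@decoder << chunk).map { |json| json.dig(:choices, 0, :delta, :content) }.compact
      end
    end
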
diff --git a/lib/completions/endpoints/canned_response.rb b/lib/completions/endpoints/canned_response.rb
index eaef21da..bd3ae4ea 100644
--- a/lib/completions/endpoints/canned_response.rb
+++ b/lib/completions/endpoints/canned_response.rb
@@ -45,17 +45,21 @@ module DiscourseAi
       cancel_fn = lambda { cancelled = true }

       # We buffer and return tool invocations in one go.
-      if is_tool?(response)
-        yield(response, cancel_fn)
-      else
-        response.each_char do |char|
-          break if cancelled
-          yield(char, cancel_fn)
+      as_array = response.is_a?(Array) ? response : [response]
+      as_array.each do |response|
+        if is_tool?(response)
+          yield(response, cancel_fn)
+        else
+          response.each_char do |char|
+            break if cancelled
+            yield(char, cancel_fn)
+          end
         end
       end
-    else
-      response
     end
+
+    response = response.first if response.is_a?(Array) && response.length == 1
+    response
   end

   def tokenizer
@@ -65,7 +69,7 @@ module DiscourseAi
     private

     def is_tool?(response)
-      Nokogiri::HTML5.fragment(response).at("function_calls").present?
+      response.is_a?(DiscourseAi::Completions::ToolCall)
     end
   end
 end
diff --git a/lib/completions/endpoints/cohere.rb b/lib/completions/endpoints/cohere.rb
index 180c27c8..258062a1 100644
--- a/lib/completions/endpoints/cohere.rb
+++ b/lib/completions/endpoints/cohere.rb
@@ -49,6 +49,47 @@ module DiscourseAi
           Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
         end

+        def decode(response_raw)
+          rval = []
+
+          parsed = JSON.parse(response_raw, symbolize_names: true)
+
+          text = parsed[:text]
+          rval << parsed[:text] if !text.to_s.empty? # also allow " "
+
+          # TODO tool calls
+
+          update_usage(parsed)
+
+          rval
+        end
+
+        def decode_chunk(chunk)
+          @tool_idx ||= -1
+          @json_decoder ||= JsonStreamDecoder.new(line_regex: /^\s*({.*})$/)
+          (@json_decoder << chunk)
+            .map do |parsed|
+              update_usage(parsed)
+              rval = []
+
+              rval << parsed[:text] if !parsed[:text].to_s.empty?
+
+              if tool_calls = parsed[:tool_calls]
+                tool_calls&.each do |tool_call|
+                  @tool_idx += 1
+                  tool_name = tool_call[:name]
+                  tool_params = tool_call[:parameters]
+                  tool_id = "tool_#{@tool_idx}"
+                  rval << ToolCall.new(id: tool_id, name: tool_name, parameters: tool_params)
+                end
+              end
+
+              rval
+            end
+            .flatten
+            .compact
+        end
+
         def extract_completion_from(response_raw)
           parsed = JSON.parse(response_raw, symbolize_names: true)

@@ -77,36 +118,8 @@ module DiscourseAi
           end
         end

-        def has_tool?(_ignored)
-          @has_tool
-        end
-
-        def native_tool_support?
-          true
-        end
-
-        def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
-          if partial
-            tools = JSON.parse(partial)
-            tools.each do |tool|
-              name = tool["name"]
-              parameters = tool["parameters"]
-              xml_params = parameters.map { |k, v| "<#{k}>#{v}</#{k}>\n" }.join
-
-              current_function = function_buffer.at("invoke")
-              if current_function.nil? || current_function.at("tool_name").content.present?
-                current_function =
-                  function_buffer.at("function_calls").add_child(
-                    Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
-                  )
-              end
-
-              current_function.at("tool_name").content = name == "search_local" ? "search" : name
-              current_function.at("parameters").children =
-                Nokogiri::HTML5::DocumentFragment.parse(xml_params)
-            end
-          end
-          function_buffer
+        def xml_tools_enabled?
+          false
         end

         def final_log_update(log)
@@ -114,10 +127,6 @@ module DiscourseAi
           log.response_tokens = @output_tokens if @output_tokens
         end

-        def partials_from(decoded_chunk)
-          decoded_chunk.split("\n").compact
-        end
-
         def extract_prompt_for_tokenizer(prompt)
           text = +""
           if prompt[:chat_history]
@@ -131,6 +140,18 @@ module DiscourseAi

           text
         end
+
+        private
+
+        def update_usage(parsed)
+          input_tokens = parsed.dig(:meta, :billed_units, :input_tokens)
+          input_tokens ||= parsed.dig(:response, :meta, :billed_units, :input_tokens)
+          @input_tokens = input_tokens if input_tokens.present?
+
+          output_tokens = parsed.dig(:meta, :billed_units, :output_tokens)
+          output_tokens ||= parsed.dig(:response, :meta, :billed_units, :output_tokens)
+          @output_tokens = output_tokens if output_tokens.present?
+        end
       end
     end
   end
diff --git a/lib/completions/endpoints/fake.rb b/lib/completions/endpoints/fake.rb
index a51ff3ac..15cc254d 100644
--- a/lib/completions/endpoints/fake.rb
+++ b/lib/completions/endpoints/fake.rb
@@ -133,31 +133,35 @@ module DiscourseAi
           content = content.shift if content.is_a?(Array)

           if block_given?
-            split_indices = (1...content.length).to_a.sample(self.class.chunk_count - 1).sort
-            indexes = [0, *split_indices, content.length]
+            if content.is_a?(DiscourseAi::Completions::ToolCall)
+              yield(content, -> {})
+            else
+              split_indices = (1...content.length).to_a.sample(self.class.chunk_count - 1).sort
+              indexes = [0, *split_indices, content.length]

-            original_content = content
-            content = +""
+              original_content = content
+              content = +""

-            cancel = false
-            cancel_proc = -> { cancel = true }
+              cancel = false
+              cancel_proc = -> { cancel = true }

-            i = 0
-            indexes
-              .each_cons(2)
-              .map { |start, finish| original_content[start...finish] }
-              .each do |chunk|
-                break if cancel
-                if self.class.delays.present? &&
-                     (delay = self.class.delays[i % self.class.delays.length])
-                  sleep(delay)
-                  i += 1
+              i = 0
+              indexes
+                .each_cons(2)
+                .map { |start, finish| original_content[start...finish] }
+                .each do |chunk|
+                  break if cancel
+                  if self.class.delays.present? &&
+                       (delay = self.class.delays[i % self.class.delays.length])
+                    sleep(delay)
+                    i += 1
+                  end
+                  break if cancel
+
+                  content << chunk
+                  yield(chunk, cancel_proc)
                 end
-                break if cancel
-
-                content << chunk
-                yield(chunk, cancel_proc)
-              end
+            end
           end

           content
diff --git a/lib/completions/endpoints/gemini.rb b/lib/completions/endpoints/gemini.rb
index ddf607b2..2450dc99 100644
--- a/lib/completions/endpoints/gemini.rb
+++ b/lib/completions/endpoints/gemini.rb
@@ -103,15 +103,7 @@ module DiscourseAi
           end
         end

-        def partials_from(decoded_chunk)
-          decoded_chunk
-        end
-
-        def chunk_to_string(chunk)
-          chunk.to_s
-        end
-
-        class Decoder
+        class GeminiStreamingDecoder
           def initialize
             @buffer = +""
           end
@@ -151,43 +143,87 @@ module DiscourseAi
           end

           def decode(chunk)
-            @decoder ||= Decoder.new
-            @decoder.decode(chunk)
+            json = JSON.parse(chunk, symbolize_names: true)
+            idx = -1
+            json
+              .dig(:candidates, 0, :content, :parts)
+              .map do |part|
+                if part[:functionCall]
+                  idx += 1
+                  ToolCall.new(
+                    id: "tool_#{idx}",
+                    name: part[:functionCall][:name],
+                    parameters: part[:functionCall][:args],
+                  )
+                else
+                  part = part[:text]
+                  if part != ""
+                    part
+                  else
+                    nil
+                  end
+                end
+              end
+          end
+
+          def decode_chunk(chunk)
+            @tool_index ||= -1
+
+            streaming_decoder
+              .decode(chunk)
+              .map do |parsed|
+                update_usage(parsed)
+                parsed
+                  .dig(:candidates, 0, :content, :parts)
+                  .map do |part|
+                    if part[:text]
+                      part = part[:text]
+                      if part != ""
+                        part
+                      else
+                        nil
+                      end
+                    elsif part[:functionCall]
+                      @tool_index += 1
+                      ToolCall.new(
+                        id: "tool_#{@tool_index}",
+                        name: part[:functionCall][:name],
+                        parameters: part[:functionCall][:args],
+                      )
+                    end
+                  end
+              end
+              .flatten
+              .compact
+          end
+
+          def update_usage(parsed)
+            usage = parsed.dig(:usageMetadata)
+            if usage
+              if prompt_token_count = usage[:promptTokenCount]
+                @prompt_token_count = prompt_token_count
+              end
+              if candidate_token_count = usage[:candidatesTokenCount]
+                @candidate_token_count = candidate_token_count
+              end
+            end
+          end
+
+          def final_log_update(log)
+            log.request_tokens = @prompt_token_count if @prompt_token_count
+            log.response_tokens = @candidate_token_count if @candidate_token_count
+          end
+
+          def streaming_decoder
+            @decoder ||= GeminiStreamingDecoder.new
           end

           def extract_prompt_for_tokenizer(prompt)
            prompt.to_s
          end

-          def has_tool?(_response_data)
-            @has_function_call
-          end
-
-          def native_tool_support?
-            true
-          end
-
-          def add_to_function_buffer(function_buffer, payload: nil, partial: nil)
-            if @streaming_mode
-              return function_buffer if !partial
-            else
-              partial = payload
-            end
-
-            function_buffer.at("tool_name").content = partial[:name] if partial[:name].present?
-
-            if partial[:args]
-              argument_fragments =
-                partial[:args].reduce(+"") do |memo, (arg_name, value)|
-                  memo << "\n<#{arg_name}>#{CGI.escapeHTML(value.to_s)}</#{arg_name}>"
-                end
-              argument_fragments << "\n"
-
-              function_buffer.at("parameters").children =
-                Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
-            end
-
-            function_buffer
+          def xml_tools_enabled?
+            false
           end
         end
     end
   end
diff --git a/lib/completions/endpoints/hugging_face.rb b/lib/completions/endpoints/hugging_face.rb
index bd7edc06..b0b14722 100644
--- a/lib/completions/endpoints/hugging_face.rb
+++ b/lib/completions/endpoints/hugging_face.rb
@@ -59,22 +59,30 @@ module DiscourseAi
           Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
         end

-        def extract_completion_from(response_raw)
-          parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
-          # half a line sent here
-          return if !parsed
-
-          response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
-
-          response_h.dig(:content)
+        def xml_tools_enabled?
+          true
         end

-        def partials_from(decoded_chunk)
-          decoded_chunk
-            .split("\n")
-            .map do |line|
-              data = line.split("data:", 2)[1]
-              data&.squish == "[DONE]" ? nil : data
+        def decode(response_raw)
+          parsed = JSON.parse(response_raw, symbolize_names: true)
+          text = parsed.dig(:choices, 0, :message, :content)
+          if text.to_s.empty?
+            [""]
+          else
+            [text]
+          end
+        end
+
+        def decode_chunk(chunk)
+          @json_decoder ||= JsonStreamDecoder.new
+          (@json_decoder << chunk)
+            .map do |parsed|
+              text = parsed.dig(:choices, 0, :delta, :content)
+              if text.to_s.empty?
+                nil
+              else
+                text
+              end
             end
             .compact
         end
diff --git a/lib/completions/endpoints/ollama.rb b/lib/completions/endpoints/ollama.rb
index cc58006a..dd4ca2c7 100644
--- a/lib/completions/endpoints/ollama.rb
+++ b/lib/completions/endpoints/ollama.rb
@@ -37,12 +37,8 @@ module DiscourseAi
           URI(llm_model.url)
         end

-        def native_tool_support?
-          @native_tool_support
-        end
-
-        def has_tool?(_response_data)
-          @has_function_call
+        def xml_tools_enabled?
+          !@native_tool_support
         end

         def prepare_payload(prompt, model_params, dialect)
@@ -67,74 +63,30 @@ module DiscourseAi
           Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
         end

-        def partials_from(decoded_chunk)
-          decoded_chunk.split("\n").compact
+        def decode_chunk(chunk)
+          # Native tool calls are not working right in streaming mode, use XML
+          @json_decoder ||= JsonStreamDecoder.new(line_regex: /^\s*({.*})$/)
+          (@json_decoder << chunk).map { |parsed| parsed.dig(:message, :content) }.compact
         end

-        def extract_completion_from(response_raw)
+        def decode(response_raw)
+          rval = []
           parsed = JSON.parse(response_raw, symbolize_names: true)
-          return if !parsed

-          response_h = parsed.dig(:message)
-
-          @has_function_call ||= response_h.dig(:tool_calls).present?
-          @has_function_call ?
-          @has_function_call ? response_h.dig(:tool_calls, 0) : response_h.dig(:content)
-        end
-
-        def add_to_function_buffer(function_buffer, payload: nil, partial: nil)
-          @args_buffer ||= +""
-
-          if @streaming_mode
-            return function_buffer if !partial
-          else
-            partial = payload
-          end
-
-          f_name = partial.dig(:function, :name)
-
-          @current_function ||= function_buffer.at("invoke")
-
-          if f_name
-            current_name = function_buffer.at("tool_name").content
-
-            if current_name.blank?
-              # first call
-            else
-              # we have a previous function, so we need to add a noop
-              @args_buffer = +""
-              @current_function =
-                function_buffer.at("function_calls").add_child(
-                  Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
-                )
+          content = parsed.dig(:message, :content)
+          rval << content if !content.to_s.empty?
+
+          idx = -1
+          parsed
+            .dig(:message, :tool_calls)
+            &.each do |tool_call|
+              idx += 1
+              id = "tool_#{idx}"
+              name = tool_call.dig(:function, :name)
+              args = tool_call.dig(:function, :arguments)
+              rval << ToolCall.new(id: id, name: name, parameters: args)
             end
-          end
-          @current_function.at("tool_name").content = f_name if f_name
-          @current_function.at("tool_id").content = partial[:id] if partial[:id]
-
-          args = partial.dig(:function, :arguments)
-
-          # allow for SPACE within arguments
-          if args && args != ""
-            @args_buffer << args.to_json
-
-            begin
-              json_args = JSON.parse(@args_buffer, symbolize_names: true)
-
-              argument_fragments =
-                json_args.reduce(+"") do |memo, (arg_name, value)|
-                  memo << "\n<#{arg_name}>#{value}</#{arg_name}>"
-                end
-              argument_fragments << "\n"
-
-              @current_function.at("parameters").children =
-                Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
-            rescue JSON::ParserError
-              return function_buffer
-            end
-          end
-
-          function_buffer
+          rval
         end
       end
     end
   end
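
Note the asymmetry the decode_chunk comment calls out: streamed Ollama output is treated as plain text (tools fall back to XML via xml_tools_enabled?), while non-streaming responses do map native tool_calls into ToolCall objects. A sketch of that non-streaming path, with a made-up response body:

    body = {
      message: {
        content: "",
        tool_calls: [{ function: { name: "google", arguments: { query: "sydney weather" } } }],
      },
    }.to_json
    endpoint.decode(body)
    # => [ToolCall(id: "tool_0", name: "google", parameters: { query: "sydney weather" })]
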
diff --git a/lib/completions/endpoints/open_ai.rb b/lib/completions/endpoints/open_ai.rb
index 92315ed5..a185a840 100644
--- a/lib/completions/endpoints/open_ai.rb
+++ b/lib/completions/endpoints/open_ai.rb
@@ -93,98 +93,34 @@ module DiscourseAi
         end

         def final_log_update(log)
-          log.request_tokens = @prompt_tokens if @prompt_tokens
-          log.response_tokens = @completion_tokens if @completion_tokens
+          log.request_tokens = processor.prompt_tokens if processor.prompt_tokens
+          log.response_tokens = processor.completion_tokens if processor.completion_tokens
         end

-        def extract_completion_from(response_raw)
-          json = JSON.parse(response_raw, symbolize_names: true)
-
-          if @streaming_mode
-            @prompt_tokens ||= json.dig(:usage, :prompt_tokens)
-            @completion_tokens ||= json.dig(:usage, :completion_tokens)
-          end
-
-          parsed = json.dig(:choices, 0)
-          return if !parsed
-
-          response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
-          @has_function_call ||= response_h.dig(:tool_calls).present?
-          @has_function_call ? response_h.dig(:tool_calls, 0) : response_h.dig(:content)
+        def decode(response_raw)
+          processor.process_message(JSON.parse(response_raw, symbolize_names: true))
         end

-        def partials_from(decoded_chunk)
-          decoded_chunk
-            .split("\n")
-            .map do |line|
-              data = line.split("data: ", 2)[1]
-              data == "[DONE]" ? nil : data
-            end
+        def decode_chunk(chunk)
+          @decoder ||= JsonStreamDecoder.new
+          (@decoder << chunk)
+            .map { |parsed_json| processor.process_streamed_message(parsed_json) }
+            .flatten
             .compact
         end

-        def has_tool?(_response_data)
-          @has_function_call
+        def decode_chunk_finish
+          @processor.finish
         end

-        def native_tool_support?
-          true
+        def xml_tools_enabled?
+          false
         end

-        def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
-          if @streaming_mode
-            return function_buffer if !partial
-          else
-            partial = payload
-          end
+        private

-          @args_buffer ||= +""
-
-          f_name = partial.dig(:function, :name)
-
-          @current_function ||= function_buffer.at("invoke")
-
-          if f_name
-            current_name = function_buffer.at("tool_name").content
-
-            if current_name.blank?
-              # first call
-            else
-              # we have a previous function, so we need to add a noop
-              @args_buffer = +""
-              @current_function =
-                function_buffer.at("function_calls").add_child(
-                  Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
-                )
-            end
-          end
-
-          @current_function.at("tool_name").content = f_name if f_name
-          @current_function.at("tool_id").content = partial[:id] if partial[:id]
-
-          args = partial.dig(:function, :arguments)
-
-          # allow for SPACE within arguments
-          if args && args != ""
-            @args_buffer << args
-
-            begin
-              json_args = JSON.parse(@args_buffer, symbolize_names: true)
-
-              argument_fragments =
-                json_args.reduce(+"") do |memo, (arg_name, value)|
-                  memo << "\n<#{arg_name}>#{CGI.escapeHTML(value.to_s)}</#{arg_name}>"
-                end
-              argument_fragments << "\n"
-
-              @current_function.at("parameters").children =
-                Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
-            rescue JSON::ParserError
-              return function_buffer
-            end
-          end
-
-          function_buffer
+        def processor
+          @processor ||= OpenAiMessageProcessor.new
         end
       end
     end
   end
diff --git a/lib/completions/endpoints/samba_nova.rb b/lib/completions/endpoints/samba_nova.rb
index ccb883cc..cc81e786 100644
--- a/lib/completions/endpoints/samba_nova.rb
+++ b/lib/completions/endpoints/samba_nova.rb
@@ -55,27 +55,31 @@ module DiscourseAi
           log.response_tokens = @completion_tokens if @completion_tokens
         end

-        def extract_completion_from(response_raw)
-          json = JSON.parse(response_raw, symbolize_names: true)
-
-          if @streaming_mode
-            @prompt_tokens ||= json.dig(:usage, :prompt_tokens)
-            @completion_tokens ||= json.dig(:usage, :completion_tokens)
-          end
-
-          parsed = json.dig(:choices, 0)
-          return if !parsed
-
-          @streaming_mode ? parsed.dig(:delta, :content) : parsed.dig(:message, :content)
+        def xml_tools_enabled?
+          true
         end

-        def partials_from(decoded_chunk)
-          decoded_chunk
-            .split("\n")
-            .map do |line|
-              data = line.split("data: ", 2)[1]
-              data == "[DONE]" ? nil : data
+        def decode(response_raw)
+          json = JSON.parse(response_raw, symbolize_names: true)
+          [json.dig(:choices, 0, :message, :content)]
+        end
+
+        def decode_chunk(chunk)
+          @json_decoder ||= JsonStreamDecoder.new
+          (@json_decoder << chunk)
+            .map do |json|
+              text = json.dig(:choices, 0, :delta, :content)
+
+              @prompt_tokens ||= json.dig(:usage, :prompt_tokens)
+              @completion_tokens ||= json.dig(:usage, :completion_tokens)
+
+              if !text.to_s.empty?
+                text
+              else
+                nil
+              end
             end
+            .flatten
             .compact
         end
       end
     end
   end
diff --git a/lib/completions/endpoints/vllm.rb b/lib/completions/endpoints/vllm.rb
index 57fcf051..6b371a09 100644
--- a/lib/completions/endpoints/vllm.rb
+++ b/lib/completions/endpoints/vllm.rb
@@ -42,7 +42,10 @@ module DiscourseAi
         def prepare_payload(prompt, model_params, dialect)
           payload = default_options.merge(model_params).merge(messages: prompt)

-          payload[:stream] = true if @streaming_mode
+          if @streaming_mode
+            payload[:stream] = true if @streaming_mode
+            payload[:stream_options] = { include_usage: true }
+          end

           payload
         end
@@ -56,24 +59,42 @@ module DiscourseAi
           Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
         end

-        def partials_from(decoded_chunk)
-          decoded_chunk
-            .split("\n")
-            .map do |line|
-              data = line.split("data: ", 2)[1]
-              data == "[DONE]" ? nil : data
-            end
-            .compact
+        def xml_tools_enabled?
+          true
         end

-        def extract_completion_from(response_raw)
-          parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
-          # half a line sent here
-          return if !parsed
+        def final_log_update(log)
+          log.request_tokens = @prompt_tokens if @prompt_tokens
+          log.response_tokens = @completion_tokens if @completion_tokens
+        end

-          response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
+        def decode(response_raw)
+          json = JSON.parse(response_raw, symbolize_names: true)
+          @prompt_tokens = json.dig(:usage, :prompt_tokens)
+          @completion_tokens = json.dig(:usage, :completion_tokens)
+          [json.dig(:choices, 0, :message, :content)]
+        end

-          response_h.dig(:content)
+        def decode_chunk(chunk)
+          @json_decoder ||= JsonStreamDecoder.new
+          (@json_decoder << chunk)
+            .map do |parsed|
+              # vLLM keeps sending usage over and over again
+              prompt_tokens = parsed.dig(:usage, :prompt_tokens)
+              completion_tokens = parsed.dig(:usage, :completion_tokens)
+
+              @prompt_tokens = prompt_tokens if prompt_tokens
+
+              @completion_tokens = completion_tokens if completion_tokens
+
+              text = parsed.dig(:choices, 0, :delta, :content)
+              if text.to_s.empty?
+                nil
+              else
+                text
+              end
+            end
+            .compact
         end
       end
     end
   end
diff --git a/lib/completions/function_call_normalizer.rb b/lib/completions/function_call_normalizer.rb
deleted file mode 100644
index ef40809c..00000000
--- a/lib/completions/function_call_normalizer.rb
+++ /dev/null
@@ -1,113 +0,0 @@
-# frozen_string_literal: true
-
-class DiscourseAi::Completions::FunctionCallNormalizer
-  attr_reader :done
-
-  # blk is the block to call with filtered data
-  def initialize(blk, cancel)
-    @blk = blk
-    @cancel = cancel
-    @done = false
-
-    @in_tool = false
-
-    @buffer = +""
-    @function_buffer = +""
-  end
-
-  def self.normalize(data)
-    text = +""
-    cancel = -> {}
-    blk = ->(partial, _) { text << partial }
-
-    normalizer = self.new(blk, cancel)
-    normalizer << data
-
-    [text, normalizer.function_calls]
-  end
-
-  def function_calls
-    return nil if @function_buffer.blank?
-
-    xml = Nokogiri::HTML5.fragment(@function_buffer)
-    self.class.normalize_function_ids!(xml)
-    last_invoke = xml.at("invoke:last")
-    if last_invoke
-      last_invoke.next_sibling.remove while last_invoke.next_sibling
-      xml.at("invoke:last").add_next_sibling("\n") if !last_invoke.next_sibling
-    end
-    xml.at("function_calls").to_s.dup.force_encoding("UTF-8")
-  end
-
-  def <<(text)
-    @buffer << text
-
-    if !@in_tool
-      # double check if we are clearly in a tool
-      search_length = text.length + 20
-      search_string = @buffer[-search_length..-1] || @buffer
-
-      index = search_string.rindex("<function_calls>")
-      @in_tool = !!index
-      if @in_tool
-        @function_buffer = @buffer[index..-1]
-        text_index = text.rindex("<function_calls>")
-        @blk.call(text[0..text_index - 1].strip, @cancel) if text_index && text_index > 0
-      end
-    else
-      @function_buffer << text
-    end
-
-    if !@in_tool
-      if maybe_has_tool?(@buffer)
-        split_index = text.rindex("<").to_i - 1
-        if split_index >= 0
-          @function_buffer = text[split_index + 1..-1] || ""
-          text = text[0..split_index] || ""
-        else
-          @function_buffer << text
-          text = ""
-        end
-      else
-        if @function_buffer.length > 0
-          @blk.call(@function_buffer, @cancel)
-          @function_buffer = +""
-        end
-      end
-
-      @blk.call(text, @cancel) if text.length > 0
-    else
-      if text.include?("</function_calls>")
-        @done = true
-        @cancel.call
-      end
-    end
-  end
-
-  def self.normalize_function_ids!(function_buffer)
-    function_buffer
-      .css("invoke")
-      .each_with_index do |invoke, index|
-        if invoke.at("tool_id")
-          invoke.at("tool_id").content = "tool_#{index}" if invoke.at("tool_id").content.blank?
-        else
-          invoke.add_child("<tool_id>tool_#{index}</tool_id>\n") if !invoke.at("tool_id")
-        end
-      end
-  end
-
-  private
-
-  def maybe_has_tool?(text)
-    # 16 is the length of <function_calls>
-    substring = text[-16..-1] || text
-    split = substring.split("<")
-
-    if split.length > 1
-      match = "<" + split.last
-      "<function_calls>".start_with?(match)
-    else
-      substring.ends_with?("<")
-    end
-  end
-end
diff --git a/lib/completions/json_stream_decoder.rb b/lib/completions/json_stream_decoder.rb
new file mode 100644
index 00000000..e575a3b7
--- /dev/null
+++ b/lib/completions/json_stream_decoder.rb
@@ -0,0 +1,48 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Completions
+    # will work for anthropic and open ai compatible
+    class JsonStreamDecoder
+      attr_reader :buffer
+
+      LINE_REGEX = /data: ({.*})\s*$/
+
+      def initialize(symbolize_keys: true, line_regex: LINE_REGEX)
+        @symbolize_keys = symbolize_keys
+        @buffer = +""
+        @line_regex = line_regex
+      end
+
+      def <<(raw)
+        @buffer << raw.to_s
+        rval = []
+
+        split = @buffer.scan(/.*\n?/)
+        split.pop if split.last.blank?
+
+        @buffer = +(split.pop.to_s)
+
+        split.each do |line|
+          matches = line.match(@line_regex)
+          next if !matches
+          rval << JSON.parse(matches[1], symbolize_names: @symbolize_keys)
+        end
+
+        if @buffer.present?
+          matches = @buffer.match(@line_regex)
+          if matches
+            begin
+              rval << JSON.parse(matches[1], symbolize_names: @symbolize_keys)
+              @buffer = +""
+            rescue JSON::ParserError
+              # maybe it is a partial line
+            end
+          end
+        end
+
+        rval
+      end
+    end
+  end
+end
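
JsonStreamDecoder is deliberately simple: it buffers raw SSE text, yields one parsed hash per complete data: {...} line, and holds incomplete lines back until more bytes arrive. For example:

    decoder = DiscourseAi::Completions::JsonStreamDecoder.new
    decoder << "data: {\"a\":1}\ndata: {\"b\""  # => [{ a: 1 }] (second line stays buffered)
    decoder << ":2}\n"                          # => [{ b: 2 }]
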
diff --git a/lib/completions/open_ai_message_processor.rb b/lib/completions/open_ai_message_processor.rb
new file mode 100644
index 00000000..02369bec
--- /dev/null
+++ b/lib/completions/open_ai_message_processor.rb
@@ -0,0 +1,103 @@
+# frozen_string_literal: true
+module DiscourseAi::Completions
+  class OpenAiMessageProcessor
+    attr_reader :prompt_tokens, :completion_tokens
+
+    def initialize
+      @tool = nil
+      @tool_arguments = +""
+      @prompt_tokens = nil
+      @completion_tokens = nil
+    end
+
+    def process_message(json)
+      result = []
+      tool_calls = json.dig(:choices, 0, :message, :tool_calls)
+
+      message = json.dig(:choices, 0, :message, :content)
+      result << message if message.present?
+
+      if tool_calls.present?
+        tool_calls.each do |tool_call|
+          id = tool_call.dig(:id)
+          name = tool_call.dig(:function, :name)
+          arguments = tool_call.dig(:function, :arguments)
+          parameters = arguments.present? ? JSON.parse(arguments, symbolize_names: true) : {}
+          result << ToolCall.new(id: id, name: name, parameters: parameters)
+        end
+      end
+
+      update_usage(json)
+
+      result
+    end
+
+    def process_streamed_message(json)
+      rval = nil
+
+      tool_calls = json.dig(:choices, 0, :delta, :tool_calls)
+      content = json.dig(:choices, 0, :delta, :content)
+
+      finished_tools = json.dig(:choices, 0, :finish_reason) || tool_calls == []
+
+      if tool_calls.present?
+        id = tool_calls.dig(0, :id)
+        name = tool_calls.dig(0, :function, :name)
+        arguments = tool_calls.dig(0, :function, :arguments)
+
+        # TODO: multiple tool support may require index
+        #index = tool_calls[0].dig(:index)
+
+        if id.present? && @tool && @tool.id != id
+          process_arguments
+          rval = @tool
+          @tool = nil
+        end
+
+        if id.present? && name.present?
+          @tool_arguments = +""
+          @tool = ToolCall.new(id: id, name: name)
+        end
+
+        @tool_arguments << arguments.to_s
+      elsif finished_tools && @tool
+        parsed_args = JSON.parse(@tool_arguments, symbolize_names: true)
+        @tool.parameters = parsed_args
+        rval = @tool
+        @tool = nil
+      elsif content.present?
+        rval = content
+      end
+
+      update_usage(json)
+
+      rval
+    end
+
+    def finish
+      rval = []
+      if @tool
+        process_arguments
+        rval << @tool
+        @tool = nil
+      end
+
+      rval
+    end
+
+    private
+
+    def process_arguments
+      if @tool_arguments.present?
+        parsed_args = JSON.parse(@tool_arguments, symbolize_names: true)
+        @tool.parameters = parsed_args
+        @tool_arguments = nil
+      end
+    end
+
+    def update_usage(json)
+      @prompt_tokens ||= json.dig(:usage, :prompt_tokens)
+      @completion_tokens ||= json.dig(:usage, :completion_tokens)
+    end
+  end
+end
diff --git a/lib/completions/tool_call.rb b/lib/completions/tool_call.rb
new file mode 100644
index 00000000..15be7b3f
--- /dev/null
+++ b/lib/completions/tool_call.rb
@@ -0,0 +1,29 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Completions
+    class ToolCall
+      attr_reader :id, :name, :parameters
+
+      def initialize(id:, name:, parameters: nil)
+        @id = id
+        @name = name
+        self.parameters = parameters if parameters
+        @parameters ||= {}
+      end
+
+      def parameters=(parameters)
+        raise ArgumentError, "parameters must be a hash" unless parameters.is_a?(Hash)
+        @parameters = parameters.symbolize_keys
+      end
+
+      def ==(other)
+        id == other.id && name == other.name && parameters == other.parameters
+      end
+
+      def to_s
+        "#{name} - #{id} (\n#{parameters.map(&:to_s).join("\n")}\n)"
+      end
+    end
+  end
+end
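
ToolCall is a plain value object, which is what lets the specs below compare results with eq: parameters are always symbolized, and == compares id, name and parameters. For instance:

    call =
      DiscourseAi::Completions::ToolCall.new(
        id: "tool_0",
        name: "google",
        parameters: { "query" => "sydney weather" }, # string keys are fine
      )
    call.parameters # => { query: "sydney weather" }
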
diff --git a/lib/completions/xml_tool_processor.rb b/lib/completions/xml_tool_processor.rb
new file mode 100644
index 00000000..1b42b333
--- /dev/null
+++ b/lib/completions/xml_tool_processor.rb
@@ -0,0 +1,124 @@
+# frozen_string_literal: true
+
+# This class can be used to process a stream of text that may contain XML tool
+# calls.
+# It will return either text or ToolCall objects.
+
+module DiscourseAi
+  module Completions
+    class XmlToolProcessor
+      def initialize
+        @buffer = +""
+        @function_buffer = +""
+        @should_cancel = false
+        @in_tool = false
+      end
+
+      def <<(text)
+        @buffer << text
+        result = []
+
+        if !@in_tool
+          # double check if we are clearly in a tool
+          search_length = text.length + 20
+          search_string = @buffer[-search_length..-1] || @buffer
+
+          index = search_string.rindex("<function_calls>")
+          @in_tool = !!index
+          if @in_tool
+            @function_buffer = @buffer[index..-1]
+            text_index = text.rindex("<function_calls>")
+            result << text[0..text_index - 1].strip if text_index && text_index > 0
+          end
+        else
+          @function_buffer << text
+        end
+
+        if !@in_tool
+          if maybe_has_tool?(@buffer)
+            split_index = text.rindex("<").to_i - 1
+            if split_index >= 0
+              @function_buffer = text[split_index + 1..-1] || ""
+              text = text[0..split_index] || ""
+            else
+              @function_buffer << text
+              text = ""
+            end
+          else
+            if @function_buffer.length > 0
+              result << @function_buffer
+              @function_buffer = +""
+            end
+          end
+
+          result << text if text.length > 0
+        else
+          @should_cancel = true if text.include?("</function_calls>")
+        end
+
+        result
+      end
+
+      def finish
+        return [] if @function_buffer.blank?
+
+        xml = Nokogiri::HTML5.fragment(@function_buffer)
+        normalize_function_ids!(xml)
+        last_invoke = xml.at("invoke:last")
+        if last_invoke
+          last_invoke.next_sibling.remove while last_invoke.next_sibling
+          xml.at("invoke:last").add_next_sibling("\n") if !last_invoke.next_sibling
+        end
+
+        xml
+          .css("invoke")
+          .map do |invoke|
+            tool_name = invoke.at("tool_name").content.force_encoding("UTF-8")
+            tool_id = invoke.at("tool_id").content.force_encoding("UTF-8")
+            parameters = {}
+            invoke
+              .at("parameters")
+              &.children
+              &.each do |node|
+                next if node.text?
+                name = node.name
+                value = node.content.to_s
+                parameters[name.to_sym] = value.to_s.force_encoding("UTF-8")
+              end
+            ToolCall.new(id: tool_id, name: tool_name, parameters: parameters)
+          end
+      end
+
+      def should_cancel?
+        @should_cancel
+      end
+
+      private
+
+      def normalize_function_ids!(function_buffer)
+        function_buffer
+          .css("invoke")
+          .each_with_index do |invoke, index|
+            if invoke.at("tool_id")
+              invoke.at("tool_id").content = "tool_#{index}" if invoke.at("tool_id").content.blank?
+            else
+              invoke.add_child("<tool_id>tool_#{index}</tool_id>\n") if !invoke.at("tool_id")
+            end
+          end
+      end
+
+      def maybe_has_tool?(text)
+        # 16 is the length of <function_calls>
+        substring = text[-16..-1] || text
+        split = substring.split("<")
+
+        if split.length > 1
+          match = "<" + split.last
+          "<function_calls>".start_with?(match)
+        else
+          substring.ends_with?("<")
+        end
+      end
+    end
+  end
+end
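
XmlToolProcessor is the streaming replacement for the deleted FunctionCallNormalizer: String partials go in, plain text comes back out until a <function_calls> block starts buffering, and finish emits the parsed ToolCall objects. A small sketch of the intended use (tag layout as produced by the XML dialects):

    processor = DiscourseAi::Completions::XmlToolProcessor.new
    emitted = []
    emitted.concat(processor << "hello ")
    emitted.concat(processor << "<function_calls><invoke><tool_name>google</tool_name>")
    emitted.concat(processor << "<parameters><query>sydney</query></parameters></invoke>")
    emitted.concat(processor.finish)
    # emitted => ["hello ", ToolCall(id: "tool_0", name: "google", parameters: { query: "sydney" })]
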
"hello\n", + DiscourseAi::Completions::ToolCall.new( + id: "tool_0", + name: "google", + parameters: { + query: "sydney weather today", + }, + ), + ] - - - - google - sydney weather today - tool_0 - - - XML - - expect(response.strip).to eq(function_call) + expect(response).to eq(expected) end end @@ -230,23 +227,23 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do } prompt.tools = [tool] - response = +"" + response = [] proxy.generate(prompt, user: user) { |partial| response << partial } expect(request.headers["Authorization"]).to be_present expect(request.headers["X-Amz-Content-Sha256"]).to be_present - expected_response = (<<~RESPONSE).strip - - - google - sydney weather today - toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7 - - - RESPONSE + expected_response = [ + DiscourseAi::Completions::ToolCall.new( + id: "toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7", + name: "google", + parameters: { + query: "sydney weather today", + }, + ), + ] - expect(response.strip).to eq(expected_response) + expect(response).to eq(expected_response) expected = { "max_tokens" => 3000, diff --git a/spec/lib/completions/endpoints/cohere_spec.rb b/spec/lib/completions/endpoints/cohere_spec.rb index 4bb213ff..bdff8fc3 100644 --- a/spec/lib/completions/endpoints/cohere_spec.rb +++ b/spec/lib/completions/endpoints/cohere_spec.rb @@ -66,7 +66,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do TEXT parsed_body = nil - result = +"" + result = [] sig = { name: "google", @@ -91,21 +91,20 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do }, ).to_return(status: 200, body: body.split("|")) - result = llm.generate(prompt, user: user) { |partial, cancel| result << partial } + llm.generate(prompt, user: user) { |partial, cancel| result << partial } end - expected = <<~TEXT - - - google - who is sam saffron - - tool_0 - - - TEXT + text = "I will search for 'who is sam saffron' and relay the information to the user." 
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(
+        id: "tool_0",
+        name: "google",
+        parameters: {
+          query: "who is sam saffron",
+        },
+      )
 
-    expect(result.strip).to eq(expected.strip)
+    expect(result).to eq([text, tool_call])
 
     expected = {
       model: "command-r-plus",
diff --git a/spec/lib/completions/endpoints/endpoint_compliance.rb b/spec/lib/completions/endpoints/endpoint_compliance.rb
index 372c529b..130c735b 100644
--- a/spec/lib/completions/endpoints/endpoint_compliance.rb
+++ b/spec/lib/completions/endpoints/endpoint_compliance.rb
@@ -62,18 +62,14 @@ class EndpointMock
   end
 
   def invocation_response
-    <<~TEXT
-      <function_calls>
-      <invoke>
-      <tool_name>get_weather</tool_name>
-      <parameters>
-      <location>Sydney</location>
-      <unit>c</unit>
-      </parameters>
-      </invoke>
-      </function_calls>
-    TEXT
+    DiscourseAi::Completions::ToolCall.new(
+      id: "tool_0",
+      name: "get_weather",
+      parameters: {
+        location: "Sydney",
+        unit: "c",
+      },
+    )
   end
 
   def tool_id
@@ -185,7 +181,7 @@ class EndpointsCompliance
     mock.stub_tool_call(a_dialect.translate)
 
     completion_response = endpoint.perform_completion!(a_dialect, user)
-    expect(completion_response.strip).to eq(mock.invocation_response.strip)
+    expect(completion_response).to eq(mock.invocation_response)
   end
 
   def streaming_mode_simple_prompt(mock)
@@ -205,6 +201,7 @@ class EndpointsCompliance
       expect(log.raw_request_payload).to be_present
       expect(log.raw_response_payload).to be_present
       expect(log.request_tokens).to eq(endpoint.prompt_size(dialect.translate))
+
       expect(log.response_tokens).to eq(
         endpoint.llm_model.tokenizer_class.size(mock.streamed_simple_deltas[0...-1].join),
       )
@@ -216,14 +213,14 @@ class EndpointsCompliance
     a_dialect = dialect(prompt: prompt)
 
     mock.stub_streamed_tool_call(a_dialect.translate) do
-      buffered_partial = +""
+      buffered_partial = []
 
       endpoint.perform_completion!(a_dialect, user) do |partial, cancel|
         buffered_partial << partial
-        cancel.call if buffered_partial.include?("</function_calls>")
+        cancel.call if partial.is_a?(DiscourseAi::Completions::ToolCall)
       end
 
-      expect(buffered_partial.strip).to eq(mock.invocation_response.strip)
+      expect(buffered_partial).to eq([mock.invocation_response])
     end
   end
 
diff --git a/spec/lib/completions/endpoints/gemini_spec.rb b/spec/lib/completions/endpoints/gemini_spec.rb
index 2f602d3a..18933843 100644
--- a/spec/lib/completions/endpoints/gemini_spec.rb
+++ b/spec/lib/completions/endpoints/gemini_spec.rb
@@ -195,19 +195,16 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
 
     response = llm.generate(prompt, user: user)
 
-    expected = (<<~XML).strip
-      <function_calls>
-      <invoke>
-      <tool_name>echo</tool_name>
-      <parameters>
-      <text><S>ydney</text>
-      </parameters>
-      <tool_id>tool_0</tool_id>
-      </invoke>
-      </function_calls>
-    XML
+    tool =
+      DiscourseAi::Completions::ToolCall.new(
+        id: "tool_0",
+        name: "echo",
+        parameters: {
+          text: "<S>ydney",
+        },
+      )
 
-    expect(response.strip).to eq(expected)
+    expect(response).to eq(tool)
   end
 
   it "Supports Vision API" do
@@ -265,6 +262,68 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
     expect(JSON.parse(req_body)).to eq(expected_prompt)
   end
 
+  it "Can stream tool calls correctly" do
+    rows = [
+      {
+        candidates: [
+          {
+            content: {
+              parts: [{ functionCall: { name: "echo", args: { text: "sam<>wh!s" } } }],
+              role: "model",
+            },
+            safetyRatings: [
+              { category: "HARM_CATEGORY_HATE_SPEECH", probability: "NEGLIGIBLE" },
+              { category: "HARM_CATEGORY_DANGEROUS_CONTENT", probability: "NEGLIGIBLE" },
+              { category: "HARM_CATEGORY_HARASSMENT", probability: "NEGLIGIBLE" },
+              { category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", probability: "NEGLIGIBLE" },
+            ],
+          },
+        ],
+        usageMetadata: {
+          promptTokenCount: 625,
+          totalTokenCount: 625,
+        },
+        modelVersion: "gemini-1.5-pro-002",
+      },
+      {
+        candidates: [{ content: { parts: [{ text: "" }], role: "model" }, finishReason: "STOP" }],
+        usageMetadata: {
+          promptTokenCount: 625,
+          candidatesTokenCount: 4,
+          totalTokenCount: 629,
+        },
+        modelVersion: "gemini-1.5-pro-002",
+      },
+    ]
+
+    payload = rows.map { |r| "data: #{r.to_json}\n\n" }.join
+
+    llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
+    url = "#{model.url}:streamGenerateContent?alt=sse&key=123"
+
+    prompt = DiscourseAi::Completions::Prompt.new("Hello", tools: [echo_tool])
+
+    output = []
+
+    stub_request(:post, url).to_return(status: 200, body: payload)
+    llm.generate(prompt, user: user) { |partial| output << partial }
+
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(
+        id: "tool_0",
+        name: "echo",
+        parameters: {
+          text: "sam<>wh!s",
+        },
+      )
+
+    expect(output).to eq([tool_call])
+
+    log = AiApiAuditLog.order(:id).last
+    expect(log.request_tokens).to eq(625)
+    expect(log.response_tokens).to eq(4)
+  end
+
   it "Can correctly handle streamed responses even if they are chunked badly" do
     data = +""
     data << "da|ta: |"
@@ -279,12 +338,12 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
     llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
     url = "#{model.url}:streamGenerateContent?alt=sse&key=123"
 
-    output = +""
+    output = []
     gemini_mock.with_chunk_array_support do
       stub_request(:post, url).to_return(status: 200, body: split)
       llm.generate("Hello", user: user) { |partial| output << partial }
     end
 
-    expect(output).to eq("Hello World Sam")
+    expect(output.join).to eq("Hello World Sam")
   end
 end
diff --git a/spec/lib/completions/endpoints/ollama_spec.rb b/spec/lib/completions/endpoints/ollama_spec.rb
index eb6bc63c..4f458283 100644
--- a/spec/lib/completions/endpoints/ollama_spec.rb
+++ b/spec/lib/completions/endpoints/ollama_spec.rb
@@ -150,7 +150,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Ollama do
   end
 
   describe "when using streaming mode" do
-    context "with simpel prompts" do
+    context "with simple prompts" do
       it "completes a trivial prompt and logs the response" do
         compliance.streaming_mode_simple_prompt(ollama_mock)
       end
diff --git a/spec/lib/completions/endpoints/open_ai_spec.rb b/spec/lib/completions/endpoints/open_ai_spec.rb
index 60df1d67..c4d7758a 100644
--- a/spec/lib/completions/endpoints/open_ai_spec.rb
+++ b/spec/lib/completions/endpoints/open_ai_spec.rb
@@ -17,8 +17,8 @@ class OpenAiMock < EndpointMock
       created: 1_678_464_820,
       model: "gpt-3.5-turbo-0301",
       usage: {
-        prompt_tokens: 337,
-        completion_tokens: 162,
+        prompt_tokens: 8,
+        completion_tokens: 13,
         total_tokens: 499,
       },
       choices: [
@@ -231,19 +231,16 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
 
       result = llm.generate(prompt, user: user)
 
-      expected = (<<~TXT).strip
-        <function_calls>
-        <invoke>
-        <tool_name>echo</tool_name>
-        <parameters>
-        <text>hello</text>
-        </parameters>
-        <tool_id>call_I8LKnoijVuhKOM85nnEQgWwd</tool_id>
-        </invoke>
-        </function_calls>
-      TXT
+      tool_call =
+        DiscourseAi::Completions::ToolCall.new(
+          id: "call_I8LKnoijVuhKOM85nnEQgWwd",
+          name: "echo",
+          parameters: {
+            text: "hello",
+          },
+        )
 
-      expect(result.strip).to eq(expected)
+      expect(result).to eq(tool_call)
 
       stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
         body: { choices: [message: { content: "OK" }] }.to_json,
@@ -320,19 +317,20 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
 
       expect(body_json[:tool_choice]).to eq({ type: "function", function: { name: "echo" } })
 
-      expected = (<<~TXT).strip
-        <function_calls>
-        <invoke>
-        <tool_name>echo</tool_name>
-        <parameters>
-        <text>h<e>llo</text>
-        </parameters>
-        <tool_id>call_I8LKnoijVuhKOM85nnEQgWwd</tool_id>
-        </invoke>
-        </function_calls>
-      TXT
+      log = AiApiAuditLog.order(:id).last
+      expect(log.request_tokens).to eq(55)
+      expect(log.response_tokens).to eq(13)
 
-      expect(result.strip).to eq(expected)
+      expected =
+        DiscourseAi::Completions::ToolCall.new(
+          id: "call_I8LKnoijVuhKOM85nnEQgWwd",
+          name: "echo",
+          parameters: {
+            text: "h<e>llo",
+          },
+        )
+
+      expect(result).to eq(expected)
 
       stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
         body: { choices: [message: { content: "OK" }] }.to_json,
@@ -487,7 +485,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
 
      data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"e AI "}}]},"logprobs":null,"finish_reason":null}]}
 
-      data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"bot\\"}"}}]},"logprobs":null,"finish_reason":null}]}
+      data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"bot2\\"}"}}]},"logprobs":null,"finish_reason":null}]}
 
      data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}
 
@@ -495,32 +493,30 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
     TEXT
 
     open_ai_mock.stub_raw(raw_data)
-    content = +""
+    response = []
 
     dialect = compliance.dialect(prompt: compliance.generic_prompt(tools: tools))
-    endpoint.perform_completion!(dialect, user) { |partial| content << partial }
+    endpoint.perform_completion!(dialect, user) { |partial| response << partial }
 
-    expected = <<~TEXT
-      <function_calls>
-      <invoke>
-      <tool_name>search</tool_name>
-      <parameters>
-      <search_query>Discourse AI bot</search_query>
-      </parameters>
-      <tool_id>call_3Gyr3HylFJwfrtKrL6NaIit1</tool_id>
-      </invoke>
-      <invoke>
-      <tool_name>search</tool_name>
-      <parameters>
-      <query>Discourse AI bot</query>
-      </parameters>
-      <tool_id>call_H7YkbgYurHpyJqzwUN4bghwN</tool_id>
-      </invoke>
-      </function_calls>
-    TEXT
+    tool_calls = [
+      DiscourseAi::Completions::ToolCall.new(
+        name: "search",
+        id: "call_3Gyr3HylFJwfrtKrL6NaIit1",
+        parameters: {
+          search_query: "Discourse AI bot",
+        },
+      ),
+      DiscourseAi::Completions::ToolCall.new(
+        name: "search",
+        id: "call_H7YkbgYurHpyJqzwUN4bghwN",
+        parameters: {
+          query: "Discourse AI bot2",
+        },
+      ),
+    ]
 
-    expect(content).to eq(expected)
+    expect(response).to eq(tool_calls)
   end
 
   it "uses proper token accounting" do
@@ -593,21 +589,16 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
       dialect = compliance.dialect(prompt: compliance.generic_prompt(tools: tools))
       endpoint.perform_completion!(dialect, user) { |partial| partials << partial }
 
-      expect(partials.length).to eq(1)
-
-      function_call = (<<~TXT).strip
-        <function_calls>
-        <invoke>
-        <tool_name>google</tool_name>
-        <parameters>
-        <query>Adabas 9.1</query>
-        </parameters>
-        <tool_id>func_id</tool_id>
-        </invoke>
-        </function_calls>
-      TXT
+      tool_call =
+        DiscourseAi::Completions::ToolCall.new(
+          id: "func_id",
+          name: "google",
+          parameters: {
+            query: "Adabas 9.1",
+          },
+        )
 
-      expect(partials[0].strip).to eq(function_call)
+      expect(partials).to eq([tool_call])
     end
   end
 end
diff --git a/spec/lib/completions/endpoints/samba_nova_spec.rb b/spec/lib/completions/endpoints/samba_nova_spec.rb
index 0f1f68ac..83839bf4 100644
--- a/spec/lib/completions/endpoints/samba_nova_spec.rb
+++ b/spec/lib/completions/endpoints/samba_nova_spec.rb
@@ -22,10 +22,15 @@
 data: [DONE]
 
       },
     ).to_return(status: 200, body: body, headers: {})
 
-    response = +""
+    response = []
 
     llm.generate("who are you?", user: Discourse.system_user) { |partial| response << partial }
 
-    expect(response).to eq("I am a bot")
+    expect(response).to eq(["I am a bot"])
+
+    log = AiApiAuditLog.order(:id).last
+
+    expect(log.request_tokens).to eq(21)
+    expect(log.response_tokens).to eq(41)
   end
 
   it "can perform regular completions" do
diff --git a/spec/lib/completions/endpoints/vllm_spec.rb b/spec/lib/completions/endpoints/vllm_spec.rb
index 6f5387c0..824bcbe0 100644
--- a/spec/lib/completions/endpoints/vllm_spec.rb
+++ b/spec/lib/completions/endpoints/vllm_spec.rb
@@ -51,7 +51,13 @@ class VllmMock < EndpointMock
 
     WebMock
       .stub_request(:post, "https://test.dev/v1/chat/completions")
-      .with(body: model.default_options.merge(messages: prompt, stream: true).to_json)
+      .with(
+        body:
+          model
+            .default_options
+            .merge(messages: prompt, stream: true, stream_options: { include_usage: true })
+            .to_json,
+      )
       .to_return(status: 200, body: chunks)
   end
 end
@@ -136,29 +142,115 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
 
       result = llm.generate(prompt, user: Discourse.system_user)
 
-      expected = <<~TEXT
-        <function_calls>
-        <invoke>
-        <tool_name>calculate</tool_name>
-        <parameters>
-        <expression>1+1</expression></parameters>
-        <tool_id>tool_0</tool_id>
-        </invoke>
-        </function_calls>
-      TEXT
+      expected =
+        DiscourseAi::Completions::ToolCall.new(
+          name: "calculate",
+          id: "tool_0",
+          parameters: {
+            expression: "1+1",
+          },
+        )
 
-      expect(result.strip).to eq(expected.strip)
+      expect(result).to eq(expected)
     end
   end
 
+  it "correctly accounts for tokens in non streaming mode" do
+    body = (<<~TEXT).strip
+      {"id":"chat-c580e4a9ebaa44a0becc802ed5dc213a","object":"chat.completion","created":1731294404,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"Random Number Generator Produces Smallest Possible Result","tool_calls":[]},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":146,"total_tokens":156,"completion_tokens":10},"prompt_logprobs":null}
+    TEXT
+
+    stub_request(:post, "https://test.dev/v1/chat/completions").to_return(status: 200, body: body)
+
+    result = llm.generate("generate a title", user: Discourse.system_user)
+
+    expect(result).to eq("Random Number Generator Produces Smallest Possible Result")
+
+    log = AiApiAuditLog.order(:id).last
+    expect(log.request_tokens).to eq(146)
+    expect(log.response_tokens).to eq(10)
+  end
+
+  it "can properly include usage in streaming mode" do
+    payload = <<~TEXT.strip
+      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":46,"completion_tokens":0}}
+
+      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":47,"completion_tokens":1}}
+
+      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" Sam"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":48,"completion_tokens":2}}
+
+      data: 
{"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"."},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":49,"completion_tokens":3}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" It"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":50,"completion_tokens":4}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"'s"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":51,"completion_tokens":5}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" nice"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":52,"completion_tokens":6}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" to"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":53,"completion_tokens":7}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" meet"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":54,"completion_tokens":8}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":55,"completion_tokens":9}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"."},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":56,"completion_tokens":10}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" Is"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":57,"completion_tokens":11}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" there"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":58,"completion_tokens":12}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" something"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":59,"completion_tokens":13}} + + data: 
{"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":60,"completion_tokens":14}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" can"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":61,"completion_tokens":15}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" help"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":62,"completion_tokens":16}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":63,"completion_tokens":17}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" with"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":64,"completion_tokens":18}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" or"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":65,"completion_tokens":19}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" would"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":66,"completion_tokens":20}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":67,"completion_tokens":21}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" like"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":68,"completion_tokens":22}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" to"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":69,"completion_tokens":23}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" chat"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":70,"completion_tokens":24}} + + data: 
{"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":71,"completion_tokens":25}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":""},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":72,"completion_tokens":26}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[],"usage":{"prompt_tokens":46,"total_tokens":72,"completion_tokens":26}} + + data: [DONE] + TEXT + + stub_request(:post, "https://test.dev/v1/chat/completions").to_return( + status: 200, + body: payload, + ) + + response = [] + llm.generate("say hello", user: Discourse.system_user) { |partial| response << partial } + + expect(response.join).to eq( + "Hello Sam. It's nice to meet you. Is there something I can help you with or would you like to chat?", + ) + + log = AiApiAuditLog.order(:id).last + expect(log.request_tokens).to eq(46) + expect(log.response_tokens).to eq(26) + end + describe "#perform_completion!" do context "when using regular mode" do - context "with simple prompts" do - it "completes a trivial prompt and logs the response" do - compliance.regular_mode_simple_prompt(vllm_mock) - end - end - context "with tools" do it "returns a function invocation" do compliance.regular_mode_tools(vllm_mock) diff --git a/spec/lib/completions/function_call_normalizer_spec.rb b/spec/lib/completions/function_call_normalizer_spec.rb deleted file mode 100644 index dd78ed7f..00000000 --- a/spec/lib/completions/function_call_normalizer_spec.rb +++ /dev/null @@ -1,182 +0,0 @@ -# frozen_string_literal: true - -RSpec.describe DiscourseAi::Completions::FunctionCallNormalizer do - let(:buffer) { +"" } - - let(:normalizer) do - blk = ->(data, cancel) { buffer << data } - cancel = -> { @done = true } - DiscourseAi::Completions::FunctionCallNormalizer.new(blk, cancel) - end - - def pass_through!(data) - normalizer << data - expect(buffer[-data.length..-1]).to eq(data) - end - - it "is usable in non streaming mode" do - xml = (<<~XML).strip - hello - - - hello - - XML - - text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(xml) - - expect(text).to eq("hello") - - expected_function_calls = (<<~XML).strip - - - hello - tool_0 - - - XML - - expect(function_calls).to eq(expected_function_calls) - end - - it "strips junk from end of function calls" do - xml = (<<~XML).strip - hello - - - hello - - junk - XML - - _text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(xml) - - expected_function_calls = (<<~XML).strip - - - hello - tool_0 - - - XML - - expect(function_calls).to eq(expected_function_calls) - end - - it "returns nil for function calls if there are none" do - input = "hello world\n" - text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(input) - - expect(text).to eq(input) - expect(function_calls).to eq(nil) - end - - it "passes through data if there are no function calls detected" do - pass_through!("hello") - pass_through!("hello") - pass_through!("world") - pass_through!("") - end - - it "properly handles non English 
tools" do - normalizer << "hello\n" - - normalizer << (<<~XML).strip - - hello - - 世界 - - - XML - - expected = (<<~XML).strip - - - hello - - 世界 - - tool_0 - - - XML - - function_calls = normalizer.function_calls - expect(function_calls).to eq(expected) - end - - it "works correctly even if you only give it 1 letter at a time" do - xml = (<<~XML).strip - abc - - - hello - - world - - abc - - - hello2 - - world - - aba - - - XML - - xml.each_char { |char| normalizer << char } - - expect(buffer + normalizer.function_calls).to eq(xml) - end - - it "supports multiple invokes" do - xml = (<<~XML).strip - - - hello - - world - - abc - - - hello2 - - world - - aba - - - XML - - normalizer << xml - - expect(normalizer.function_calls).to eq(xml) - end - - it "can will cancel if it encounteres " do - normalizer << "" - expect(normalizer.done).to eq(false) - normalizer << "" - expect(normalizer.done).to eq(true) - expect(@done).to eq(true) - - expect(normalizer.function_calls).to eq("") - end - - it "pauses on function call and starts buffering" do - normalizer << "hello" - expect(buffer).to eq("hello") - expect(normalizer.done).to eq(false) - end -end diff --git a/spec/lib/completions/json_stream_decoder_spec.rb b/spec/lib/completions/json_stream_decoder_spec.rb new file mode 100644 index 00000000..831bad6f --- /dev/null +++ b/spec/lib/completions/json_stream_decoder_spec.rb @@ -0,0 +1,47 @@ +# frozen_string_literal: true + +describe DiscourseAi::Completions::JsonStreamDecoder do + let(:decoder) { DiscourseAi::Completions::JsonStreamDecoder.new } + + it "should be able to parse simple messages" do + result = decoder << "data: #{{ hello: "world" }.to_json}" + expect(result).to eq([{ hello: "world" }]) + end + + it "should handle anthropic mixed stlye streams" do + stream = (<<~TEXT).split("|") + event: |message_start| + data: |{"hel|lo": "world"}| + + event: |message_start + data: {"foo": "bar"} + + event: |message_start + data: {"ba|z": "qux"|} + + [DONE] + TEXT + + results = [] + stream.each { |chunk| results << (decoder << chunk) } + + expect(results.flatten.compact).to eq([{ hello: "world" }, { foo: "bar" }, { baz: "qux" }]) + end + + it "should be able to handle complex overlaps" do + stream = (<<~TEXT).split("|") + data: |{"hel|lo": "world"} + + data: {"foo": "bar"} + + data: {"ba|z": "qux"|} + + [DONE] + TEXT + + results = [] + stream.each { |chunk| results << (decoder << chunk) } + + expect(results.flatten.compact).to eq([{ hello: "world" }, { foo: "bar" }, { baz: "qux" }]) + end +end diff --git a/spec/lib/completions/xml_tool_processor_spec.rb b/spec/lib/completions/xml_tool_processor_spec.rb new file mode 100644 index 00000000..003f4356 --- /dev/null +++ b/spec/lib/completions/xml_tool_processor_spec.rb @@ -0,0 +1,188 @@ +# frozen_string_literal: true + +RSpec.describe DiscourseAi::Completions::XmlToolProcessor do + let(:processor) { DiscourseAi::Completions::XmlToolProcessor.new } + + it "can process simple text" do + result = [] + result << (processor << "hello") + result << (processor << " world ") + expect(result).to eq([["hello"], [" world "]]) + expect(processor.finish).to eq([]) + expect(processor.should_cancel?).to eq(false) + end + + it "is usable for simple single message mode" do + xml = (<<~XML).strip + hello + + + hello + + world + value + + + XML + + result = [] + result << (processor << xml) + result << (processor.finish) + + tool_call = + DiscourseAi::Completions::ToolCall.new( + id: "tool_0", + name: "hello", + parameters: { + hello: "world", + test: "value", + }, + ) + 
+    expect(result).to eq([["hello"], [tool_call]])
+    expect(processor.should_cancel?).to eq(false)
+  end
+
+  it "handles multiple tool calls in sequence" do
+    xml = (<<~XML).strip
+      start
+      <function_calls>
+      <invoke>
+      <tool_name>first_tool</tool_name>
+      <parameters>
+      <param1>value1</param1>
+      </parameters>
+      </invoke>
+      <invoke>
+      <tool_name>second_tool</tool_name>
+      <parameters>
+      <param2>value2</param2>
+      </parameters>
+      </invoke>
+      </function_calls>
+      end
+    XML
+
+    result = []
+    result << (processor << xml)
+    result << (processor.finish)
+
+    first_tool =
+      DiscourseAi::Completions::ToolCall.new(
+        id: "tool_0",
+        name: "first_tool",
+        parameters: {
+          param1: "value1",
+        },
+      )
+
+    second_tool =
+      DiscourseAi::Completions::ToolCall.new(
+        id: "tool_1",
+        name: "second_tool",
+        parameters: {
+          param2: "value2",
+        },
+      )
+
+    expect(result).to eq([["start"], [first_tool, second_tool]])
+    expect(processor.should_cancel?).to eq(true)
+  end
+
+  it "handles non-English parameters correctly" do
+    xml = (<<~XML).strip
+      こんにちは
+      <function_calls>
+      <invoke>
+      <tool_name>translator</tool_name>
+      <parameters>
+      <text>世界</text>
+      </parameters>
+      </invoke>
+      </function_calls>
+    XML
+
+    result = []
+    result << (processor << xml)
+    result << (processor.finish)
+
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(
+        id: "tool_0",
+        name: "translator",
+        parameters: {
+          text: "世界",
+        },
+      )
+
+    expect(result).to eq([["こんにちは"], [tool_call]])
+  end
+
+  it "processes input character by character" do
+    xml =
+      "hi<function_calls><invoke><tool_name>test</tool_name><parameters><p>v</p></parameters></invoke></function_calls>"
+
+    result = []
+    xml.each_char { |char| result << (processor << char) }
+    result << processor.finish
+
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(id: "tool_0", name: "test", parameters: { p: "v" })
+
+    filtered_result = result.reject(&:empty?)
+    expect(filtered_result).to eq([["h"], ["i"], [tool_call]])
+  end
+
+  it "handles malformed XML gracefully" do
+    xml = (<<~XML).strip
+      text
+      <function_calls>
+      <invoke>
+      <tool_name>test</tool_name>
+      <parameters>
+      <param>value
+      </parameters>
+      </invoke>
+      malformed
+    XML
+
+    result = []
+    result << (processor << xml)
+    result << (processor.finish)
+
+    # Should just do its best to parse the XML
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(id: "tool_0", name: "test", parameters: { param: "" })
+    expect(result).to eq([["text"], [tool_call]])
+  end
+
+  it "correctly processes empty parameter sets" do
+    xml = (<<~XML).strip
+      hello
+      <function_calls>
+      <invoke>
+      <tool_name>no_params</tool_name>
+      <parameters>
+      </parameters>
+      </invoke>
+      </function_calls>
+    XML
+
+    result = []
+    result << (processor << xml)
+    result << (processor.finish)
+
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(id: "tool_0", name: "no_params", parameters: {})
+
+    expect(result).to eq([["hello"], [tool_call]])
+  end
+
+  it "properly handles cancelled processing" do
+    xml = "start<function_calls></function_calls>"
+    result = []
+    result << (processor << xml)
+    result << (processor << "more text")
+    result << processor.finish
+
+    expect(result).to eq([["start"], [], []])
+    expect(processor.should_cancel?).to eq(true)
+  end
+end
diff --git a/spec/lib/modules/ai_bot/personas/persona_spec.rb b/spec/lib/modules/ai_bot/personas/persona_spec.rb
index 2fb95d19..5271374d 100644
--- a/spec/lib/modules/ai_bot/personas/persona_spec.rb
+++ b/spec/lib/modules/ai_bot/personas/persona_spec.rb
@@ -72,40 +72,27 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
   it "can parse string that are wrapped in quotes" do
     SiteSetting.ai_stability_api_key = "123"
 
-      xml = <<~XML
-        <function_calls>
-        <invoke>
-        <tool_name>image</tool_name>
-        <tool_id>call_JtYQMful5QKqw97XFsHzPweB</tool_id>
-        <parameters>
-        <prompts>["cat oil painting", "big car"]</prompts>
-        <aspect_ratio>"16:9"</aspect_ratio>
-        </parameters>
-        </invoke>
-        <invoke>
-        <tool_name>image</tool_name>
-        <tool_id>call_JtYQMful5QKqw97XFsHzPweB</tool_id>
-        <parameters>
-        <prompts>["cat oil painting", "big car"]</prompts>
-        <aspect_ratio>'16:9'</aspect_ratio>
-        </parameters>
-        </invoke>
-        </function_calls>
-      XML
-
-      image1, image2 =
-        tools =
-          DiscourseAi::AiBot::Personas::Artist.new.find_tools(
-            xml,
-            bot_user: nil,
-            llm: nil,
-            context: nil,
-          )
-
-      expect(image1.parameters[:prompts]).to eq(["cat oil painting", "big car"])
-      expect(image1.parameters[:aspect_ratio]).to eq("16:9")
-      expect(image2.parameters[:aspect_ratio]).to eq("16:9")
-
-      expect(tools.length).to eq(2)
+      tool_call =
+        DiscourseAi::Completions::ToolCall.new(
+          name: "image",
+          id: "call_JtYQMful5QKqw97XFsHzPweB",
+          parameters: {
+            prompts: ["cat oil painting", "big car"],
+            aspect_ratio: "16:9",
+          },
+        )
 
+      tool_instance =
+        DiscourseAi::AiBot::Personas::Artist.new.find_tool(
+          tool_call,
+          bot_user: nil,
+          llm: nil,
+          context: nil,
+        )
+
+      expect(tool_instance.parameters[:prompts]).to eq(["cat oil painting", "big car"])
+      expect(tool_instance.parameters[:aspect_ratio]).to eq("16:9")
   end
 
   it "enforces enums" do
@@ -132,42 +119,68 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
 
       XML
 
-      search1, search2 =
-        tools =
-          DiscourseAi::AiBot::Personas::General.new.find_tools(
-            xml,
-            bot_user: nil,
-            llm: nil,
-            context: nil,
-          )
+      tool_call =
+        DiscourseAi::Completions::ToolCall.new(
+          name: "search",
+          id: "call_JtYQMful5QKqw97XFsHzPweB",
+          parameters: {
+            max_posts: "3.2",
+            status: "cow",
+            foo: "bar",
+          },
+        )
 
-      expect(search1.parameters.key?(:status)).to eq(false)
-      expect(search2.parameters[:status]).to eq("open")
+      tool_instance =
+        DiscourseAi::AiBot::Personas::General.new.find_tool(
+          tool_call,
+          bot_user: nil,
+          llm: nil,
+          context: nil,
+        )
+      expect(tool_instance.parameters.key?(:status)).to eq(false)
+
+      tool_call =
+        DiscourseAi::Completions::ToolCall.new(
+          name: "search",
+          id: "call_JtYQMful5QKqw97XFsHzPweB",
+          parameters: {
+            max_posts: "3.2",
+            status: "open",
+            foo: "bar",
+          },
+        )
+
+      tool_instance =
+        DiscourseAi::AiBot::Personas::General.new.find_tool(
+          tool_call,
+          bot_user: nil,
+          llm: nil,
+          context: nil,
+        )
+
+      expect(tool_instance.parameters[:status]).to eq("open")
     end
 
     it "can coerce integers" do
-      xml = <<~XML
-        <function_calls>
-        <invoke>
-        <tool_name>search</tool_name>
-        <tool_id>call_JtYQMful5QKqw97XFsHzPweB</tool_id>
-        <parameters>
-        <max_posts>"3.2"</max_posts>
-        <search_query>hello world</search_query>
-        <foo>bar</foo>
-        </parameters>
-        </invoke>
-        </function_calls>
-      XML
+      tool_call =
+        DiscourseAi::Completions::ToolCall.new(
+          name: "search",
+          id: "call_JtYQMful5QKqw97XFsHzPweB",
+          parameters: {
+            max_posts: "3.2",
+            search_query: "hello world",
+            foo: "bar",
+          },
+        )
 
-      search, =
-        tools =
-          DiscourseAi::AiBot::Personas::General.new.find_tools(
-            xml,
-            bot_user: nil,
-            llm: nil,
-            context: nil,
-          )
+      search =
+        DiscourseAi::AiBot::Personas::General.new.find_tool(
+          tool_call,
+          bot_user: nil,
+          llm: nil,
+          context: nil,
+        )
 
       expect(search.parameters[:max_posts]).to eq(3)
       expect(search.parameters[:search_query]).to eq("hello world")
@@ -177,43 +190,23 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
   it "can correctly parse arrays in tools" do
     SiteSetting.ai_openai_api_key = "123"
 
-      # Dall E tool uses an array for params
-      xml = <<~XML
-        <function_calls>
-        <invoke>
-        <tool_name>dall_e</tool_name>
-        <tool_id>call_JtYQMful5QKqw97XFsHzPweB</tool_id>
-        <parameters>
-        <prompts>["cat oil painting", "big car"]</prompts>
-        </parameters>
-        </invoke>
-        <invoke>
-        <tool_name>dall_e</tool_name>
-        <tool_id>abc</tool_id>
-        <parameters>
-        <prompts>["pic3"]</prompts>
-        </parameters>
-        </invoke>
-        <invoke>
-        <tool_name>unknown</tool_name>
-        <tool_id>abc</tool_id>
-        <parameters>
-        <prompts>["pic3"]</prompts>
-        </parameters>
-        </invoke>
-        </function_calls>
-      XML
-      dall_e1, dall_e2 =
-        tools =
-          DiscourseAi::AiBot::Personas::DallE3.new.find_tools(
-            xml,
-            bot_user: nil,
-            llm: nil,
-            context: nil,
-          )
 
-      expect(dall_e1.parameters[:prompts]).to eq(["cat oil painting", "big car"])
-      expect(dall_e2.parameters[:prompts]).to eq(["pic3"])
-      expect(tools.length).to eq(2)
+      tool_call =
+        DiscourseAi::Completions::ToolCall.new(
+          name: "dall_e",
+          id: "call_JtYQMful5QKqw97XFsHzPweB",
+          parameters: {
+            prompts: ["cat oil painting", "big car"],
+          },
+        )
+
+      tool_instance =
+        DiscourseAi::AiBot::Personas::DallE3.new.find_tool(
+          tool_call,
+          bot_user: nil,
+          llm: nil,
+          context: nil,
+        )
+      expect(tool_instance.parameters[:prompts]).to eq(["cat oil painting", "big car"])
   end
 
   describe "custom personas" do
diff --git a/spec/lib/modules/ai_bot/playground_spec.rb b/spec/lib/modules/ai_bot/playground_spec.rb
index 9c98a08a..2a07ad52 100644
--- a/spec/lib/modules/ai_bot/playground_spec.rb
+++ b/spec/lib/modules/ai_bot/playground_spec.rb
@@ -55,6 +55,8 @@ RSpec.describe DiscourseAi::AiBot::Playground do
     )
   end
 
+  before { SiteSetting.ai_embeddings_enabled = false }
+
   after do
     # we must reset cache on persona cause data can be rolled back
     AiPersona.persona_cache.flush!
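Review note: the persona specs above pin down the new calling convention, where a tool is resolved from a structured ToolCall value instead of being parsed out of an XML fragment. A minimal sketch of that flow, mirroring the spec setup (the nil bot_user/llm/context arguments and the "call_1" id are placeholders that only make sense in a test like these):

    tool_call =
      DiscourseAi::Completions::ToolCall.new(
        name: "search",
        id: "call_1",
        parameters: { max_posts: "3.2", search_query: "hello world" },
      )

    # find_tool instantiates the matching tool for the persona; parameter
    # coercion happens on the tool instance, so the string "3.2" comes back
    # as the integer 3 via search.parameters[:max_posts]
    search =
      DiscourseAi::AiBot::Personas::General.new.find_tool(
        tool_call,
        bot_user: nil,
        llm: nil,
        context: nil,
      )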
@@ -83,17 +85,15 @@ RSpec.describe DiscourseAi::AiBot::Playground do
   end
 
   let!(:ai_persona) { Fabricate(:ai_persona, tools: ["custom-#{custom_tool.id}"]) }
 
-  let(:function_call) { (<<~XML).strip }
-    <function_calls>
-    <invoke>
-    <tool_name>search</tool_name>
-    <tool_id>666</tool_id>
-    <parameters>
-    <query>Can you use the custom tool</query>
-    </parameters>
-    </invoke>
-    </function_calls>",
-  XML
+  let(:tool_call) do
+    DiscourseAi::Completions::ToolCall.new(
+      name: "search",
+      id: "666",
+      parameters: {
+        query: "Can you use the custom tool",
+      },
+    )
+  end
 
   let(:bot) { DiscourseAi::AiBot::Bot.as(bot_user, persona: ai_persona.class_instance.new) }
 
@@ -115,7 +115,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
       reply_post = nil
       prompts = nil
 
-      responses = [function_call]
+      responses = [tool_call]
       DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompts|
         new_post = Fabricate(:post, raw: "Can you use the custom tool?")
         reply_post = playground.reply_to(new_post)
@@ -133,7 +133,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
     it "can force usage of a tool" do
       tool_name = "custom-#{custom_tool.id}"
       ai_persona.update!(tools: [[tool_name, nil, true]], forced_tool_count: 1)
-      responses = [function_call, "custom tool did stuff (maybe)"]
+      responses = [tool_call, "custom tool did stuff (maybe)"]
       prompts = nil
       reply_post = nil
 
@@ -166,7 +166,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
       bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_klass.new)
       playground = DiscourseAi::AiBot::Playground.new(bot)
 
-      responses = [function_call, "custom tool did stuff (maybe)"]
+      responses = [tool_call, "custom tool did stuff (maybe)"]
 
       reply_post = nil
 
@@ -206,13 +206,15 @@ RSpec.describe DiscourseAi::AiBot::Playground do
       bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_klass.new)
       playground = DiscourseAi::AiBot::Playground.new(bot)
 
+      responses = ["custom tool did stuff (maybe)", tool_call]
+
       # lets ensure tool does not run...
       DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompt|
         new_post = Fabricate(:post, raw: "Can you use the custom tool?")
         reply_post = playground.reply_to(new_post)
       end
 
-      expect(reply_post.raw.strip).to eq(function_call)
+      expect(reply_post.raw.strip).to eq("custom tool did stuff (maybe)")
     end
   end
 
@@ -452,10 +454,25 @@ RSpec.describe DiscourseAi::AiBot::Playground do
     it "can run tools" do
       persona.update!(tools: ["Time"])
 
-      responses = [
-        "<function_calls><invoke><tool_name>time</tool_name><tool_id>time</tool_id><parameters><timezone>Buenos Aires</timezone></parameters></invoke></function_calls>",
-        "The time is 2023-12-14 17:24:00 -0300",
-      ]
+      tool_call1 =
+        DiscourseAi::Completions::ToolCall.new(
+          name: "time",
+          id: "time",
+          parameters: {
+            timezone: "Buenos Aires",
+          },
+        )
+
+      tool_call2 =
+        DiscourseAi::Completions::ToolCall.new(
+          name: "time",
+          id: "time",
+          parameters: {
+            timezone: "Sydney",
+          },
+        )
+
+      responses = [[tool_call1, tool_call2], "The time is 2023-12-14 17:24:00 -0300"]
 
       message =
         DiscourseAi::Completions::Llm.with_prepared_responses(responses) do
@@ -470,7 +487,8 @@ RSpec.describe DiscourseAi::AiBot::Playground do
 
       # it also needs to have tool details now set on message
       prompt = ChatMessageCustomPrompt.find_by(message_id: reply.id)
-      expect(prompt.custom_prompt.length).to eq(3)
+
+      expect(prompt.custom_prompt.length).to eq(5)
 
       # TODO in chat I am mixed on including this in the context, but I guess maybe?
      # thinking about this
@@ -782,30 +800,29 @@ RSpec.describe DiscourseAi::AiBot::Playground do
     end
 
     it "supports multiple function calls" do
-      response1 = (<<~TXT).strip
-        <function_calls>
-        <invoke>
-        <tool_name>search</tool_name>
-        <tool_id>search</tool_id>
-        <parameters>
-        <search_query>testing various things</search_query>
-        </parameters>
-        </invoke>
-        <invoke>
-        <tool_name>search</tool_name>
-        <tool_id>search</tool_id>
-        <parameters>
-        <search_query>another search</search_query>
-        </parameters>
-        </invoke>
-        </function_calls>
-      TXT
+      tool_call1 =
+        DiscourseAi::Completions::ToolCall.new(
+          name: "search",
+          id: "search",
+          parameters: {
+            search_query: "testing various things",
+          },
+        )
+
+      tool_call2 =
+        DiscourseAi::Completions::ToolCall.new(
+          name: "search",
+          id: "search",
+          parameters: {
+            search_query: "another search",
+          },
+        )
 
       response2 = "I found stuff"
 
-      DiscourseAi::Completions::Llm.with_prepared_responses([response1, response2]) do
-        playground.reply_to(third_post)
-      end
+      DiscourseAi::Completions::Llm.with_prepared_responses(
+        [[tool_call1, tool_call2], response2],
+      ) { playground.reply_to(third_post) }
 
       last_post = third_post.topic.reload.posts.order(:post_number).last
 
@@ -819,17 +836,14 @@ RSpec.describe DiscourseAi::AiBot::Playground do
       bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona.class_instance.new)
       playground = described_class.new(bot)
 
-      response1 = (<<~TXT).strip
-        <function_calls>
-        <invoke>
-        <tool_name>search</tool_name>
-        <tool_id>search</tool_id>
-        <parameters>
-        <search_query>testing various things</search_query>
-        </parameters>
-        </invoke>
-        </function_calls>
-      TXT
+      response1 =
+        DiscourseAi::Completions::ToolCall.new(
+          name: "search",
+          id: "search",
+          parameters: {
+            search_query: "testing various things",
+          },
+        )
 
       response2 = "I found stuff"
 
@@ -843,17 +857,14 @@ RSpec.describe DiscourseAi::AiBot::Playground do
     end
 
     it "does not include placeholders in conversation context but includes all completions" do
-      response1 = (<<~TXT).strip
-        <function_calls>
-        <invoke>
-        <tool_name>search</tool_name>
-        <tool_id>search</tool_id>
-        <parameters>
-        <search_query>testing various things</search_query>
-        </parameters>
-        </invoke>
-        </function_calls>
-      TXT
+      response1 =
+        DiscourseAi::Completions::ToolCall.new(
+          name: "search",
+          id: "search",
+          parameters: {
+            search_query: "testing various things",
+          },
+        )
 
       response2 = "I found some really amazing stuff!"
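Review note: the playground specs can now hand whole ToolCall objects (or arrays of them for parallel calls) to with_prepared_responses because that is exactly what the streaming pipeline emits. A minimal sketch of that pipeline in isolation, using the XmlToolProcessor added in this diff (the chunk boundaries are arbitrary; a real stream can split anywhere):

    processor = DiscourseAi::Completions::XmlToolProcessor.new
    output = []
    output.concat(processor << "hello <function_calls><invoke>")
    output.concat(processor << "<tool_name>time</tool_name></invoke></function_calls>")
    output.concat(processor.finish)
    # output is ["hello", ToolCall(name: "time", id: "tool_0", parameters: {})];
    # should_cancel? is true because the closing tag means the completion is done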
@@ -889,17 +900,15 @@ RSpec.describe DiscourseAi::AiBot::Playground do
         [{ b64_json: image, revised_prompt: "a pink cow 1" }]
       end
 
-      let(:response) { (<<~TXT).strip }
-        <function_calls>
-        <invoke>
-        <tool_name>dall_e</tool_name>
-        <tool_id>dall_e</tool_id>
-        <parameters>
-        <prompts>["a pink cow"]</prompts>
-        </parameters>
-        </invoke>
-        </function_calls>
-      TXT
+      let(:response) do
+        DiscourseAi::Completions::ToolCall.new(
+          name: "dall_e",
+          id: "dall_e",
+          parameters: {
+            prompts: ["a pink cow"],
+          },
+        )
+      end
 
       it "properly returns an image when skipping tool details" do
         persona.update!(tool_details: false)
diff --git a/spec/requests/admin/ai_personas_controller_spec.rb b/spec/requests/admin/ai_personas_controller_spec.rb
index fb42506e..16e0001b 100644
--- a/spec/requests/admin/ai_personas_controller_spec.rb
+++ b/spec/requests/admin/ai_personas_controller_spec.rb
@@ -541,16 +541,10 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
       expect(topic.title).to eq("An amazing title")
       expect(topic.posts.count).to eq(2)
 
-      # now let's try to make a reply with a tool call
-      function_call = <<~XML
-        <function_calls>
-        <invoke>
-        <tool_name>categories</tool_name>
-        </invoke>
-        </function_calls>
-      XML
+      tool_call =
+        DiscourseAi::Completions::ToolCall.new(name: "categories", parameters: {}, id: "tool_1")
 
-      fake_endpoint.fake_content = [function_call, "this is the response after the tool"]
+      fake_endpoint.fake_content = [tool_call, "this is the response after the tool"]
 
       # this simplifies function calls
       fake_endpoint.chunk_count = 1
diff --git a/spec/requests/ai_bot/bot_controller_spec.rb b/spec/requests/ai_bot/bot_controller_spec.rb
index e7430185..007e868b 100644
--- a/spec/requests/ai_bot/bot_controller_spec.rb
+++ b/spec/requests/ai_bot/bot_controller_spec.rb
@@ -4,6 +4,8 @@ RSpec.describe DiscourseAi::AiBot::BotController do
   fab!(:user)
   fab!(:pm_topic) { Fabricate(:private_message_topic) }
   fab!(:pm_post) { Fabricate(:post, topic: pm_topic) }
+  fab!(:pm_post2) { Fabricate(:post, topic: pm_topic) }
+  fab!(:pm_post3) { Fabricate(:post, topic: pm_topic) }
 
   before { sign_in(user) }
 
@@ -22,15 +24,37 @@ RSpec.describe DiscourseAi::AiBot::BotController do
       user = pm_topic.topic_allowed_users.first.user
       sign_in(user)
 
-      AiApiAuditLog.create!(
-        post_id: pm_post.id,
-        provider_id: 1,
-        topic_id: pm_topic.id,
-        raw_request_payload: "request",
-        raw_response_payload: "response",
-        request_tokens: 1,
-        response_tokens: 2,
-      )
+      log1 =
+        AiApiAuditLog.create!(
+          provider_id: 1,
+          topic_id: pm_topic.id,
+          raw_request_payload: "request",
+          raw_response_payload: "response",
+          request_tokens: 1,
+          response_tokens: 2,
+        )
+
+      log2 =
+        AiApiAuditLog.create!(
+          post_id: pm_post.id,
+          provider_id: 1,
+          topic_id: pm_topic.id,
+          raw_request_payload: "request",
+          raw_response_payload: "response",
+          request_tokens: 1,
+          response_tokens: 2,
+        )
+
+      log3 =
+        AiApiAuditLog.create!(
+          post_id: pm_post2.id,
+          provider_id: 1,
+          topic_id: pm_topic.id,
+          raw_request_payload: "request",
+          raw_response_payload: "response",
+          request_tokens: 1,
+          response_tokens: 2,
+        )
 
       Group.refresh_automatic_groups!
SiteSetting.ai_bot_debugging_allowed_groups = user.groups.first.id.to_s @@ -38,18 +62,26 @@ RSpec.describe DiscourseAi::AiBot::BotController do get "/discourse-ai/ai-bot/post/#{pm_post.id}/show-debug-info" expect(response.status).to eq(200) + expect(response.parsed_body["id"]).to eq(log2.id) + expect(response.parsed_body["next_log_id"]).to eq(log3.id) + expect(response.parsed_body["prev_log_id"]).to eq(log1.id) + expect(response.parsed_body["topic_id"]).to eq(pm_topic.id) + expect(response.parsed_body["request_tokens"]).to eq(1) expect(response.parsed_body["response_tokens"]).to eq(2) expect(response.parsed_body["raw_request_payload"]).to eq("request") expect(response.parsed_body["raw_response_payload"]).to eq("response") - post2 = Fabricate(:post, topic: pm_topic) - # return previous post if current has no debug info - get "/discourse-ai/ai-bot/post/#{post2.id}/show-debug-info" + get "/discourse-ai/ai-bot/post/#{pm_post3.id}/show-debug-info" expect(response.status).to eq(200) expect(response.parsed_body["request_tokens"]).to eq(1) expect(response.parsed_body["response_tokens"]).to eq(2) + + # can return debug info by id as well + get "/discourse-ai/ai-bot/show-debug-info/#{log1.id}" + expect(response.status).to eq(200) + expect(response.parsed_body["id"]).to eq(log1.id) end end
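Review note: taken together, these controller specs describe the debugging flow end to end. Each AiApiAuditLog row in a topic links to its neighbours through prev_log_id/next_log_id, and the new show-debug-info/:id route lets the modal's Previous/Next buttons walk that chain even for logs that are not attached to a post. A sketch of the traversal (the records here are the fabricated ones from the spec above):

    log = AiApiAuditLog.find_by(post_id: pm_post.id)
    log.prev_log_id # => id of the earlier call in the same topic, or nil
    log.next_log_id # => id of the later call in the same topic, or nil

    # what the client does when Next is clicked:
    # GET /discourse-ai/ai-bot/show-debug-info/#{log.next_log_id}.json
    # which renders the same AiApiAuditLogSerializer payload for that row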