diff --git a/app/models/ai_api_audit_log.rb b/app/models/ai_api_audit_log.rb
index f426925a..45752927 100644
--- a/app/models/ai_api_audit_log.rb
+++ b/app/models/ai_api_audit_log.rb
@@ -22,12 +22,13 @@ end
 #  id                   :bigint           not null, primary key
 #  provider_id          :integer          not null
 #  user_id              :integer
-#  topic_id             :integer
-#  post_id              :integer
 #  request_tokens       :integer
 #  response_tokens      :integer
 #  raw_request_payload  :string
 #  raw_response_payload :string
 #  created_at           :datetime         not null
 #  updated_at           :datetime         not null
+#  topic_id             :integer
+#  post_id              :integer
+#  feature_name         :string(255)
 #
diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml
index fe95f782..9ed71e1e 100644
--- a/config/locales/client.en.yml
+++ b/config/locales/client.en.yml
@@ -20,6 +20,7 @@ en:
         mistral_7b_instruct_v0_2: Mistral 7B Instruct V0.2
         command_r: Cohere Command R
         command_r_plus: Cohere Command R+
+        gpt_4o: GPT 4 Omni
       scriptables:
         llm_report:
           fields:
@@ -328,6 +329,7 @@ en:
           cohere-command-r-plus: "Cohere Command R Plus"
           gpt-4: "GPT-4"
           gpt-4-turbo: "GPT-4 Turbo"
+          gpt-4o: "GPT-4 Omni"
           gpt-3:
             5-turbo: "GPT-3.5"
           claude-2: "Claude 2"
diff --git a/config/locales/client.pl_PL.yml b/config/locales/client.pl_PL.yml
index 2e495d46..8412822d 100644
--- a/config/locales/client.pl_PL.yml
+++ b/config/locales/client.pl_PL.yml
@@ -215,6 +215,7 @@ pl_PL:
         bot_names:
           gpt-4: "GPT-4"
           gpt-4-turbo: "GPT-4 Turbo"
+          gpt-4o: "GPT-4 Omni"
           gpt-3:
             5-turbo: "GPT-3.5"
           claude-2: "Claude 2"
diff --git a/config/locales/server.en.yml b/config/locales/server.en.yml
index d3438549..689b9db7 100644
--- a/config/locales/server.en.yml
+++ b/config/locales/server.en.yml
@@ -43,6 +43,7 @@ en:
     ai_openai_gpt35_url: "Custom URL used for GPT 3.5 chat completions. (for Azure support)"
     ai_openai_gpt35_16k_url: "Custom URL used for GPT 3.5 16k chat completions. (for Azure support)"
     ai_openai_gpt4_url: "Custom URL used for GPT 4 chat completions. (for Azure support)"
+    ai_openai_gpt4o_url: "Custom URL used for GPT 4 Omni chat completions. (for Azure support)"
     ai_openai_gpt4_32k_url: "Custom URL used for GPT 4 32k chat completions. (for Azure support)"
     ai_openai_gpt4_turbo_url: "Custom URL used for GPT 4 Turbo chat completions. (for Azure support)"
     ai_openai_dall_e_3_url: "Custom URL used for DALL-E 3 image generation. (for Azure support)"
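Note: the new `feature_name` column on `ai_api_audit_logs` is what makes per-feature token attribution possible. As a rough sketch of the kind of reporting it enables (the query below is illustrative only, not part of this patch):

```ruby
# Hypothetical usage report: total tokens per feature over the last 30 days,
# built from the columns shown in the ai_api_audit_log annotation above.
AiApiAuditLog
  .where("created_at > ?", 30.days.ago)
  .group(:feature_name)
  .sum("COALESCE(request_tokens, 0) + COALESCE(response_tokens, 0)")
# => { "bot" => 123_456, "summarize" => 42_000, ... } (example output)
```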
diff --git a/config/settings.yml b/config/settings.yml
index cff41ce5..02e4f386 100644
--- a/config/settings.yml
+++ b/config/settings.yml
@@ -98,6 +98,7 @@ discourse_ai:
   ai_openai_gpt35_url: "https://api.openai.com/v1/chat/completions"
   ai_openai_gpt35_16k_url: "https://api.openai.com/v1/chat/completions"
+  ai_openai_gpt4o_url: "https://api.openai.com/v1/chat/completions"
   ai_openai_gpt4_url: "https://api.openai.com/v1/chat/completions"
   ai_openai_gpt4_32k_url: "https://api.openai.com/v1/chat/completions"
   ai_openai_gpt4_turbo_url: "https://api.openai.com/v1/chat/completions"
@@ -343,6 +344,7 @@ discourse_ai:
       - gpt-3.5-turbo
       - gpt-4
       - gpt-4-turbo
+      - gpt-4o
       - claude-2
       - gemini-1.5-pro
       - mixtral-8x7B-Instruct-V0.1
diff --git a/db/migrate/20240514001334_add_feature_name_to_ai_api_audit_log.rb b/db/migrate/20240514001334_add_feature_name_to_ai_api_audit_log.rb
new file mode 100644
index 00000000..1f1f08dc
--- /dev/null
+++ b/db/migrate/20240514001334_add_feature_name_to_ai_api_audit_log.rb
@@ -0,0 +1,7 @@
+# frozen_string_literal: true
+
+class AddFeatureNameToAiApiAuditLog < ActiveRecord::Migration[7.0]
+  def change
+    add_column :ai_api_audit_logs, :feature_name, :string, limit: 255
+  end
+end
diff --git a/db/post_migrate/20240119152348_explicit_provider_backwards_compat.rb b/db/post_migrate/20240119152348_explicit_provider_backwards_compat.rb
index c57acb09..86ee7f16 100644
--- a/db/post_migrate/20240119152348_explicit_provider_backwards_compat.rb
+++ b/db/post_migrate/20240119152348_explicit_provider_backwards_compat.rb
@@ -52,7 +52,7 @@ class ExplicitProviderBackwardsCompat < ActiveRecord::Migration[7.0]
   end
 
   def append_provider(value)
-    open_ai_models = %w[gpt-3.5-turbo gpt-4 gpt-3.5-turbo-16k gpt-4-32k gpt-4-turbo]
+    open_ai_models = %w[gpt-3.5-turbo gpt-4 gpt-3.5-turbo-16k gpt-4-32k gpt-4-turbo gpt-4o]
     return "open_ai:#{value}" if open_ai_models.include?(value)
 
     return "google:#{value}" if value == "gemini-pro"
diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb
index d2ca71f5..b97a6abc 100644
--- a/lib/ai_bot/bot.rb
+++ b/lib/ai_bot/bot.rb
@@ -43,7 +43,7 @@ module DiscourseAi
         DiscourseAi::Completions::Llm
           .proxy(model)
-          .generate(title_prompt, user: post.user)
+          .generate(title_prompt, user: post.user, feature_name: "bot_title")
           .strip
           .split("\n")
           .last
@@ -67,7 +67,7 @@ module DiscourseAi
         tool_found = false
 
         result =
-          llm.generate(prompt, **llm_kwargs) do |partial, cancel|
+          llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel|
             tools = persona.find_tools(partial, bot_user: user, llm: llm, context: context)
 
             if (tools.present?)
@@ -162,6 +162,8 @@ module DiscourseAi
         "open_ai:gpt-4"
       when DiscourseAi::AiBot::EntryPoint::GPT4_TURBO_ID
         "open_ai:gpt-4-turbo"
+      when DiscourseAi::AiBot::EntryPoint::GPT4O_ID
+        "open_ai:gpt-4o"
       when DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID
         "open_ai:gpt-3.5-turbo-16k"
       when DiscourseAi::AiBot::EntryPoint::MIXTRAL_ID
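Note: in the post-migration above, `append_provider` backfills an explicit provider prefix onto legacy bare model names, so "gpt-4o" has to be in the open_ai list for old rows to resolve correctly. A condensed restatement of that logic (the fallthrough branch is an assumption; the real method has more provider branches than the hunk shows):

```ruby
# Mirrors the backfill in explicit_provider_backwards_compat.rb above.
OPEN_AI_MODELS = %w[gpt-3.5-turbo gpt-4 gpt-3.5-turbo-16k gpt-4-32k gpt-4-turbo gpt-4o]

def append_provider(value)
  return "open_ai:#{value}" if OPEN_AI_MODELS.include?(value)
  return "google:#{value}" if value == "gemini-pro"
  value # assumption: remaining providers are handled by branches elided from the hunk
end

append_provider("gpt-4o")     # => "open_ai:gpt-4o"
append_provider("gemini-pro") # => "google:gemini-pro"
```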
diff --git a/lib/ai_bot/entry_point.rb b/lib/ai_bot/entry_point.rb
index 8d943c79..63e7721f 100644
--- a/lib/ai_bot/entry_point.rb
+++ b/lib/ai_bot/entry_point.rb
@@ -18,6 +18,7 @@ module DiscourseAi
     CLAUDE_3_SONNET_ID = -118
     CLAUDE_3_HAIKU_ID = -119
     COHERE_COMMAND_R_PLUS = -120
+    GPT4O_ID = -121
 
     BOTS = [
       [GPT4_ID, "gpt4_bot", "gpt-4"],
@@ -31,6 +32,7 @@ module DiscourseAi
       [CLAUDE_3_SONNET_ID, "claude_3_sonnet_bot", "claude-3-sonnet"],
       [CLAUDE_3_HAIKU_ID, "claude_3_haiku_bot", "claude-3-haiku"],
       [COHERE_COMMAND_R_PLUS, "cohere_command_bot", "cohere-command-r-plus"],
+      [GPT4O_ID, "gpt4o_bot", "gpt-4o"],
     ]
 
     BOT_USER_IDS = BOTS.map(&:first)
@@ -49,6 +51,8 @@ module DiscourseAi
 
     def self.map_bot_model_to_user_id(model_name)
       case model_name
+      in "gpt-4o"
+        GPT4O_ID
       in "gpt-4-turbo"
        GPT4_TURBO_ID
       in "gpt-3.5-turbo"
diff --git a/lib/ai_bot/question_consolidator.rb b/lib/ai_bot/question_consolidator.rb
index 4a4e7612..59159148 100644
--- a/lib/ai_bot/question_consolidator.rb
+++ b/lib/ai_bot/question_consolidator.rb
@@ -17,7 +17,7 @@ module DiscourseAi
     end
 
     def consolidate_question
-      @llm.generate(revised_prompt, user: @user)
+      @llm.generate(revised_prompt, user: @user, feature_name: "question_consolidator")
     end
 
     def revised_prompt
diff --git a/lib/ai_bot/tools/summarize.rb b/lib/ai_bot/tools/summarize.rb
index 3a916301..214d45fa 100644
--- a/lib/ai_bot/tools/summarize.rb
+++ b/lib/ai_bot/tools/summarize.rb
@@ -135,7 +135,14 @@ module DiscourseAi
 
             prompt = section_prompt(topic, section, guidance)
 
-            summary = llm.generate(prompt, temperature: 0.6, max_tokens: 400, user: bot_user)
+            summary =
+              llm.generate(
+                prompt,
+                temperature: 0.6,
+                max_tokens: 400,
+                user: bot_user,
+                feature_name: "summarize_tool",
+              )
 
             summaries << summary
           end
@@ -150,7 +157,13 @@ module DiscourseAi
                 "concatenated the disjoint summaries, creating a cohesive narrative:\n#{summaries.join("\n")}}",
             }
 
-          llm.generate(concatenation_prompt, temperature: 0.6, max_tokens: 500, user: bot_user)
+          llm.generate(
+            concatenation_prompt,
+            temperature: 0.6,
+            max_tokens: 500,
+            user: bot_user,
+            feature_name: "summarize_tool",
+          )
         else
           summaries.first
         end
diff --git a/lib/ai_helper/assistant.rb b/lib/ai_helper/assistant.rb
index cbb16e3c..25c177a9 100644
--- a/lib/ai_helper/assistant.rb
+++ b/lib/ai_helper/assistant.rb
@@ -85,6 +85,7 @@ module DiscourseAi
         user: user,
         temperature: completion_prompt.temperature,
         stop_sequences: completion_prompt.stop_sequences,
+        feature_name: "ai_helper",
         &block
       )
     end
@@ -163,6 +164,7 @@ module DiscourseAi
           prompt,
           user: Discourse.system_user,
           max_tokens: 1024,
+          feature_name: "image_caption",
         )
       end
     end
diff --git a/lib/ai_helper/chat_thread_titler.rb b/lib/ai_helper/chat_thread_titler.rb
index 36b4517f..2b2c1828 100644
--- a/lib/ai_helper/chat_thread_titler.rb
+++ b/lib/ai_helper/chat_thread_titler.rb
@@ -32,6 +32,7 @@ module DiscourseAi
           prompt,
           user: Discourse.system_user,
           stop_sequences: [""],
+          feature_name: "chat_thread_title",
         )
       end
diff --git a/lib/ai_helper/painter.rb b/lib/ai_helper/painter.rb
index 63a6bc7d..cfa79b4c 100644
--- a/lib/ai_helper/painter.rb
+++ b/lib/ai_helper/painter.rb
@@ -68,6 +68,7 @@ module DiscourseAi
       DiscourseAi::Completions::Llm.proxy(SiteSetting.ai_helper_model).generate(
         prompt,
         user: user,
+        feature_name: "illustrate_post",
       )
     end
   end
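Note: each bot account is keyed by a reserved negative user id, so the new `GPT4O_ID = -121` plus the `case/in` branch above wire the `gpt4o_bot` user to the model. Sketched usage (the `case/in` form is Ruby 3 pattern matching):

```ruby
# Mirrors the dispatch added to entry_point.rb above.
DiscourseAi::AiBot::EntryPoint.map_bot_model_to_user_id("gpt-4o")
# => -121 (GPT4O_ID)
```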
diff --git a/lib/automation.rb b/lib/automation.rb
index 71cde97b..ba0a36ed 100644
--- a/lib/automation.rb
+++ b/lib/automation.rb
@@ -3,6 +3,7 @@
 module DiscourseAi
   module Automation
     AVAILABLE_MODELS = [
+      { id: "gpt-4o", name: "discourse_automation.ai_models.gpt_4o" },
       { id: "gpt-4-turbo", name: "discourse_automation.ai_models.gpt_4_turbo" },
       { id: "gpt-4", name: "discourse_automation.ai_models.gpt_4" },
       { id: "gpt-3.5-turbo", name: "discourse_automation.ai_models.gpt_3_5_turbo" },
diff --git a/lib/automation/llm_triage.rb b/lib/automation/llm_triage.rb
index 4bdcc257..c116043b 100644
--- a/lib/automation/llm_triage.rb
+++ b/lib/automation/llm_triage.rb
@@ -41,6 +41,7 @@ module DiscourseAi
         temperature: 0,
         max_tokens: llm.tokenizer.tokenize(search_for_text).length * 2 + 10,
         user: Discourse.system_user,
+        feature_name: "llm_triage",
       )
 
     if result.present? && result.strip.downcase.include?(search_for_text)
diff --git a/lib/automation/report_runner.rb b/lib/automation/report_runner.rb
index ea218ee1..5205e701 100644
--- a/lib/automation/report_runner.rb
+++ b/lib/automation/report_runner.rb
@@ -154,6 +154,7 @@ Follow the provided writing composition instructions carefully and precisely ste
           temperature: @temperature,
           top_p: @top_p,
           user: Discourse.system_user,
+          feature_name: "ai_report",
         ) do |response|
           print response if Rails.env.development? && @debug_mode
           result << response
diff --git a/lib/completions/dialects/chat_gpt.rb b/lib/completions/dialects/chat_gpt.rb
index f6142d09..196a0015 100644
--- a/lib/completions/dialects/chat_gpt.rb
+++ b/lib/completions/dialects/chat_gpt.rb
@@ -83,7 +83,8 @@ module DiscourseAi
       end
 
       def inline_images(content, message)
-        if model_name.include?("gpt-4-vision") || model_name == "gpt-4-turbo"
+        if model_name.include?("gpt-4-vision") || model_name == "gpt-4-turbo" ||
+             model_name == "gpt-4o"
           content = message[:content]
           encoded_uploads = prompt.encoded_uploads(message)
           if encoded_uploads.present?
@@ -125,6 +126,8 @@ module DiscourseAi
           32_768
         when "gpt-4-turbo"
           131_072
+        when "gpt-4o"
+          131_072
         else
           8192
         end
diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb
index eede850f..69fbfcc9 100644
--- a/lib/completions/endpoints/base.rb
+++ b/lib/completions/endpoints/base.rb
@@ -73,7 +73,7 @@ module DiscourseAi
         true
       end
 
-      def perform_completion!(dialect, user, model_params = {}, &blk)
+      def perform_completion!(dialect, user, model_params = {}, feature_name: nil, &blk)
         allow_tools = dialect.prompt.has_tools?
         model_params = normalize_model_params(model_params)
 
@@ -114,6 +114,7 @@ module DiscourseAi
               request_tokens: prompt_size(prompt),
               topic_id: dialect.prompt.topic_id,
               post_id: dialect.prompt.post_id,
+              feature_name: feature_name,
             )
 
           if !@streaming_mode
diff --git a/lib/completions/endpoints/canned_response.rb b/lib/completions/endpoints/canned_response.rb
index ab04961e..1beed66f 100644
--- a/lib/completions/endpoints/canned_response.rb
+++ b/lib/completions/endpoints/canned_response.rb
@@ -23,7 +23,7 @@ module DiscourseAi
 
       attr_reader :responses, :completions, :prompt
 
-      def perform_completion!(prompt, _user, _model_params)
+      def perform_completion!(prompt, _user, _model_params, feature_name: nil)
         @prompt = prompt
         response = responses[completions]
         if response.nil?
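Note: because `perform_completion!` in `base.rb` now takes a `feature_name:` keyword, every override — including the test doubles (`canned_response.rb` above, `fake.rb` below) — must accept it, or calls from `Llm#generate` raise `ArgumentError: unknown keyword`. A minimal sketch of the contract, with a hypothetical subclass name:

```ruby
# Hypothetical endpoint subclass; the keyword must appear in the signature
# even if the override ignores it, since Llm#generate always passes it.
class MyEndpoint < DiscourseAi::Completions::Endpoints::Base
  def perform_completion!(dialect, user, model_params = {}, feature_name: nil, &blk)
    super # Base persists feature_name onto the AiApiAuditLog row it creates
  end
end
```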
diff --git a/lib/completions/endpoints/fake.rb b/lib/completions/endpoints/fake.rb
index 982d4242..08720b73 100644
--- a/lib/completions/endpoints/fake.rb
+++ b/lib/completions/endpoints/fake.rb
@@ -110,7 +110,7 @@ module DiscourseAi
         @last_call = params
       end
 
-      def perform_completion!(dialect, user, model_params = {})
+      def perform_completion!(dialect, user, model_params = {}, feature_name: nil)
         self.class.last_call = { dialect: dialect, user: user, model_params: model_params }
 
         content = self.class.fake_content
diff --git a/lib/completions/endpoints/open_ai.rb b/lib/completions/endpoints/open_ai.rb
index 9aa00ea1..4d75a330 100644
--- a/lib/completions/endpoints/open_ai.rb
+++ b/lib/completions/endpoints/open_ai.rb
@@ -12,6 +12,7 @@ module DiscourseAi
         def dependant_setting_names
           %w[
             ai_openai_api_key
+            ai_openai_gpt4o_url
             ai_openai_gpt4_32k_url
             ai_openai_gpt4_turbo_url
             ai_openai_gpt4_url
@@ -33,6 +34,8 @@ module DiscourseAi
           else
             if model.include?("1106") || model.include?("turbo")
               SiteSetting.ai_openai_gpt4_turbo_url
+            elsif model.include?("gpt-4o")
+              SiteSetting.ai_openai_gpt4o_url
             else
               SiteSetting.ai_openai_gpt4_url
             end
@@ -98,35 +101,47 @@ module DiscourseAi
       end
 
       def prepare_payload(prompt, model_params, dialect)
-        default_options
-          .merge(model_params)
-          .merge(messages: prompt)
-          .tap do |payload|
-            payload[:stream] = true if @streaming_mode
-            payload[:tools] = dialect.tools if dialect.tools.present?
-          end
+        payload = default_options.merge(model_params).merge(messages: prompt)
+
+        if @streaming_mode
+          payload[:stream] = true
+          payload[:stream_options] = { include_usage: true }
+        end
+
+        payload[:tools] = dialect.tools if dialect.tools.present?
+
+        payload
       end
 
       def prepare_request(payload)
-        headers =
-          { "Content-Type" => "application/json" }.tap do |h|
-            if model_uri.host.include?("azure")
-              h["api-key"] = SiteSetting.ai_openai_api_key
-            else
-              h["Authorization"] = "Bearer #{SiteSetting.ai_openai_api_key}"
-            end
+        headers = { "Content-Type" => "application/json" }
 
-            if SiteSetting.ai_openai_organization.present?
-              h["OpenAI-Organization"] = SiteSetting.ai_openai_organization
-            end
-          end
+        if model_uri.host.include?("azure")
+          headers["api-key"] = SiteSetting.ai_openai_api_key
+        else
+          headers["Authorization"] = "Bearer #{SiteSetting.ai_openai_api_key}"
+        end
+
+        if SiteSetting.ai_openai_organization.present?
+          headers["OpenAI-Organization"] = SiteSetting.ai_openai_organization
+        end
 
         Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
       end
 
+      def final_log_update(log)
+        log.request_tokens = @prompt_tokens if @prompt_tokens
+        log.response_tokens = @completion_tokens if @completion_tokens
+      end
+
       def extract_completion_from(response_raw)
-        parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
-        # half a line sent here
+        json = JSON.parse(response_raw, symbolize_names: true)
+
+        if @streaming_mode
+          @prompt_tokens ||= json.dig(:usage, :prompt_tokens)
+          @completion_tokens ||= json.dig(:usage, :completion_tokens)
+        end
+
+        parsed = json.dig(:choices, 0)
         return if !parsed
 
         response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
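Note: with `stream_options: { include_usage: true }`, OpenAI appends one final chunk whose `choices` array is empty and whose `usage` object carries real token counts; `extract_completion_from` memoizes those counts and `final_log_update` copies them onto the audit log in place of tokenizer-based estimates. A trimmed-down sketch of the parsing:

```ruby
require "json"

# Shape of the terminal streamed chunk (compare the fixture in
# open_ai_spec.rb below).
chunk = '{"choices":[],"usage":{"prompt_tokens":20,"completion_tokens":9,"total_tokens":29}}'

json = JSON.parse(chunk, symbolize_names: true)
json.dig(:usage, :prompt_tokens)     # => 20
json.dig(:usage, :completion_tokens) # => 9
json.dig(:choices, 0)                # => nil, so there is no delta to emit
```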
diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb
index d5548890..21b4a67c 100644
--- a/lib/completions/llm.rb
+++ b/lib/completions/llm.rb
@@ -54,6 +54,7 @@ module DiscourseAi
               gpt-4-32k
               gpt-4-turbo
               gpt-4-vision-preview
+              gpt-4o
             ],
             google: %w[gemini-pro gemini-1.5-pro],
           }.tap do |h|
@@ -106,12 +107,6 @@ module DiscourseAi
         dialect_klass =
           DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name_without_prov)
 
-        if is_custom_model
-          tokenizer = llm_model.tokenizer_class
-        else
-          tokenizer = dialect_klass.tokenizer
-        end
-
         if @canned_response
           if @canned_llm && @canned_llm != model_name
             raise "Invalid call LLM call, expected #{@canned_llm} but got #{model_name}"
@@ -164,6 +159,7 @@ module DiscourseAi
         max_tokens: nil,
         stop_sequences: nil,
         user:,
+        feature_name: nil,
         &partial_read_blk
       )
         self.class.record_prompt(prompt)
@@ -196,7 +192,13 @@ module DiscourseAi
             model_name,
             opts: model_params.merge(max_prompt_tokens: @max_prompt_tokens),
           )
-        gateway.perform_completion!(dialect, user, model_params, &partial_read_blk)
+        gateway.perform_completion!(
+          dialect,
+          user,
+          model_params,
+          feature_name: feature_name,
+          &partial_read_blk
+        )
       end
 
       def max_prompt_tokens
diff --git a/lib/configuration/llm_validator.rb b/lib/configuration/llm_validator.rb
index 8391f085..c39f3b7c 100644
--- a/lib/configuration/llm_validator.rb
+++ b/lib/configuration/llm_validator.rb
@@ -69,7 +69,7 @@ module DiscourseAi
       def can_talk_to_model?(model_name)
         DiscourseAi::Completions::Llm
           .proxy(model_name)
-          .generate("How much is 1 + 1?", user: nil)
+          .generate("How much is 1 + 1?", user: nil, feature_name: "llm_validator")
           .present?
       rescue StandardError
         false
diff --git a/lib/embeddings/semantic_search.rb b/lib/embeddings/semantic_search.rb
index 8de23797..824f280c 100644
--- a/lib/embeddings/semantic_search.rb
+++ b/lib/embeddings/semantic_search.rb
@@ -169,7 +169,7 @@ module DiscourseAi
       llm_response =
         DiscourseAi::Completions::Llm.proxy(
           SiteSetting.ai_embeddings_semantic_search_hyde_model,
-        ).generate(prompt, user: @guardian.user)
+        ).generate(prompt, user: @guardian.user, feature_name: "semantic_search_hyde")
 
       Nokogiri::HTML5.fragment(llm_response).at("ai")&.text&.presence || llm_response
     end
diff --git a/lib/summarization/entry_point.rb b/lib/summarization/entry_point.rb
index 8e4a18c1..8f42c346 100644
--- a/lib/summarization/entry_point.rb
+++ b/lib/summarization/entry_point.rb
@@ -8,6 +8,7 @@ module DiscourseAi
           Models::OpenAi.new("open_ai:gpt-4", max_tokens: 8192),
           Models::OpenAi.new("open_ai:gpt-4-32k", max_tokens: 32_768),
           Models::OpenAi.new("open_ai:gpt-4-turbo", max_tokens: 100_000),
+          Models::OpenAi.new("open_ai:gpt-4o", max_tokens: 100_000),
           Models::OpenAi.new("open_ai:gpt-3.5-turbo", max_tokens: 4096),
           Models::OpenAi.new("open_ai:gpt-3.5-turbo-16k", max_tokens: 16_384),
           Models::Gemini.new("google:gemini-pro", max_tokens: 32_768),
@@ -50,24 +51,31 @@ module DiscourseAi
             max_tokens: 32_000,
           )
 
-        LlmModel.all.each do |model|
-          foldable_models << Models::CustomLlm.new(
-            "custom:#{model.id}",
-            max_tokens: model.max_prompt_tokens,
-          )
-        end
+        # TODO(Roman): we need to de-register custom LLMs from the summarization
+        # strategy on destroy and clear the cache.
+        # It may be better to pull all of this code into Discourse AI because, as it
+        # stands, the coupling is making it really hard to reason about summarization.
+        #
+        # Auto registration and de-registration need to be tested.
+
+        #LlmModel.all.each do |model|
+        #  foldable_models << Models::CustomLlm.new(
+        #    "custom:#{model.id}",
+        #    max_tokens: model.max_prompt_tokens,
+        #  )
+        #end
 
         foldable_models.each do |model|
           plugin.register_summarization_strategy(Strategies::FoldContent.new(model))
         end
 
-        plugin.add_model_callback(LlmModel, :after_create) do
-          new_model = Models::CustomLlm.new("custom:#{self.id}", max_tokens: self.max_prompt_tokens)
+        #plugin.add_model_callback(LlmModel, :after_create) do
+        #  new_model = Models::CustomLlm.new("custom:#{self.id}", max_tokens: self.max_prompt_tokens)
 
-          if ::Summarization::Base.find_strategy("custom:#{self.id}").nil?
-            plugin.register_summarization_strategy(Strategies::FoldContent.new(new_model))
-          end
-        end
+        #  if ::Summarization::Base.find_strategy("custom:#{self.id}").nil?
+        #    plugin.register_summarization_strategy(Strategies::FoldContent.new(new_model))
+        #  end
+        #end
       end
     end
   end
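Note: `feature_name` defaults to `nil` on the public `generate` API, so untagged callers keep working unchanged. A tagged call looks like this (values taken from call sites in this patch; `llm_triage` passes `feature_name: "llm_triage"`):

```ruby
DiscourseAi::Completions::Llm
  .proxy("open_ai:gpt-4o")
  .generate("Is this spam?", user: Discourse.system_user, feature_name: "llm_triage")
```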
"custom:#{model.id}", + # max_tokens: model.max_prompt_tokens, + # ) + #end foldable_models.each do |model| plugin.register_summarization_strategy(Strategies::FoldContent.new(model)) end - plugin.add_model_callback(LlmModel, :after_create) do - new_model = Models::CustomLlm.new("custom:#{self.id}", max_tokens: self.max_prompt_tokens) + #plugin.add_model_callback(LlmModel, :after_create) do + # new_model = Models::CustomLlm.new("custom:#{self.id}", max_tokens: self.max_prompt_tokens) - if ::Summarization::Base.find_strategy("custom:#{self.id}").nil? - plugin.register_summarization_strategy(Strategies::FoldContent.new(new_model)) - end - end + # if ::Summarization::Base.find_strategy("custom:#{self.id}").nil? + # plugin.register_summarization_strategy(Strategies::FoldContent.new(new_model)) + # end + #end end end end diff --git a/lib/summarization/strategies/fold_content.rb b/lib/summarization/strategies/fold_content.rb index d064c325..2184e05a 100644 --- a/lib/summarization/strategies/fold_content.rb +++ b/lib/summarization/strategies/fold_content.rb @@ -99,14 +99,19 @@ module DiscourseAi def summarize_single(llm, text, user, opts, &on_partial_blk) prompt = summarization_prompt(text, opts) - llm.generate(prompt, user: user, &on_partial_blk) + llm.generate(prompt, user: user, feature_name: "summarize", &on_partial_blk) end def summarize_in_chunks(llm, chunks, user, opts) chunks.map do |chunk| prompt = summarization_prompt(chunk[:summary], opts) - chunk[:summary] = llm.generate(prompt, user: user, max_tokens: 300) + chunk[:summary] = llm.generate( + prompt, + user: user, + max_tokens: 300, + feature_name: "summarize", + ) chunk end end diff --git a/spec/lib/completions/endpoints/anthropic_spec.rb b/spec/lib/completions/endpoints/anthropic_spec.rb index dffbab95..9faae853 100644 --- a/spec/lib/completions/endpoints/anthropic_spec.rb +++ b/spec/lib/completions/endpoints/anthropic_spec.rb @@ -268,7 +268,9 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do ).to_return(status: 200, body: body) result = +"" - llm.generate(prompt, user: Discourse.system_user) { |partial, cancel| result << partial } + llm.generate(prompt, user: Discourse.system_user, feature_name: "testing") do |partial, cancel| + result << partial + end expect(result).to eq("Hello!") @@ -285,6 +287,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do expect(log.provider_id).to eq(AiApiAuditLog::Provider::Anthropic) expect(log.request_tokens).to eq(25) expect(log.response_tokens).to eq(15) + expect(log.feature_name).to eq("testing") end it "can return multiple function calls" do diff --git a/spec/lib/completions/endpoints/open_ai_spec.rb b/spec/lib/completions/endpoints/open_ai_spec.rb index 1da22be0..b4d942f2 100644 --- a/spec/lib/completions/endpoints/open_ai_spec.rb +++ b/spec/lib/completions/endpoints/open_ai_spec.rb @@ -135,7 +135,10 @@ class OpenAiMock < EndpointMock .default_options .merge(messages: prompt) .tap do |b| - b[:stream] = true if stream + if stream + b[:stream] = true + b[:stream_options] = { include_usage: true } + end b[:tools] = [tool_payload] if tool_call end .to_json @@ -431,6 +434,36 @@ TEXT expect(content).to eq(expected) end + it "uses proper token accounting" do + response = <<~TEXT.strip + data: 
{"id":"chatcmpl-9OZidiHncpBhhNMcqCus9XiJ3TkqR","object":"chat.completion.chunk","created":1715644203,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_729ea513f7","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}],"usage":null}| + + data: {"id":"chatcmpl-9OZidiHncpBhhNMcqCus9XiJ3TkqR","object":"chat.completion.chunk","created":1715644203,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_729ea513f7","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}],"usage":null}| + + data: {"id":"chatcmpl-9OZidiHncpBhhNMcqCus9XiJ3TkqR","object":"chat.completion.chunk","created":1715644203,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_729ea513f7","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null}| + + data: {"id":"chatcmpl-9OZidiHncpBhhNMcqCus9XiJ3TkqR","object":"chat.completion.chunk","created":1715644203,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_729ea513f7","choices":[],"usage":{"prompt_tokens":20,"completion_tokens":9,"total_tokens":29}}| + + data: [DONE] + TEXT + + chunks = response.split("|") + open_ai_mock.with_chunk_array_support do + open_ai_mock.stub_raw(chunks) + partials = [] + + dialect = compliance.dialect(prompt: compliance.generic_prompt) + endpoint.perform_completion!(dialect, user) { |partial| partials << partial } + + expect(partials).to eq(["Hello"]) + + log = AiApiAuditLog.order("id desc").first + + expect(log.request_tokens).to eq(20) + expect(log.response_tokens).to eq(9) + end + end + it "properly handles spaces in tools payload" do raw_data = <<~TEXT.strip data: {"choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"func_id","type":"function","function":{"name":"go|ogle","arg|uments":""}}]}}]}