diff --git a/lib/completions/dialects/dialect.rb b/lib/completions/dialects/dialect.rb index d36c8f1d..7d49e396 100644 --- a/lib/completions/dialects/dialect.rb +++ b/lib/completions/dialects/dialect.rb @@ -95,7 +95,17 @@ module DiscourseAi range = (0..-1) if messages.dig(0, :type) == :system + max_system_tokens = prompt_limit * 0.6 system_message = messages[0] + system_size = calculate_message_token(system_message) + + if system_size > max_system_tokens + system_message[:content] = tokenizer.truncate( + system_message[:content], + max_system_tokens, + ) + end + trimmed_messages << system_message current_token_count += calculate_message_token(system_message) range = (1..-1) diff --git a/spec/lib/completions/dialects/dialect_spec.rb b/spec/lib/completions/dialects/dialect_spec.rb index e5511bbf..a7667a7c 100644 --- a/spec/lib/completions/dialects/dialect_spec.rb +++ b/spec/lib/completions/dialects/dialect_spec.rb @@ -8,22 +8,20 @@ class TestDialect < DiscourseAi::Completions::Dialects::Dialect end def tokenizer - Class.new do - def self.size(str) - str.length - end - end + DiscourseAi::Tokenizer::OpenAiTokenizer end end RSpec.describe DiscourseAi::Completions::Dialects::Dialect do describe "#trim_messages" do + let(:five_token_msg) { "This represents five tokens." } + it "should trim tool messages if tool_calls are trimmed" do - prompt = DiscourseAi::Completions::Prompt.new("12345") - prompt.push(type: :user, content: "12345") - prompt.push(type: :tool_call, content: "12345", id: 1) - prompt.push(type: :tool, content: "12345", id: 1) - prompt.push(type: :user, content: "12345") + prompt = DiscourseAi::Completions::Prompt.new(five_token_msg) + prompt.push(type: :user, content: five_token_msg) + prompt.push(type: :tool_call, content: five_token_msg, id: 1) + prompt.push(type: :tool, content: five_token_msg, id: 1) + prompt.push(type: :user, content: five_token_msg) dialect = TestDialect.new(prompt, "test") dialect.max_prompt_tokens = 15 # fits the user messages and the tool_call message @@ -31,7 +29,24 @@ RSpec.describe DiscourseAi::Completions::Dialects::Dialect do trimmed = dialect.trim(prompt.messages) expect(trimmed).to eq( - [{ type: :system, content: "12345" }, { type: :user, content: "12345" }], + [{ type: :system, content: five_token_msg }, { type: :user, content: five_token_msg }], + ) + end + + it "limits the system message to 60% of available tokens" do + prompt = DiscourseAi::Completions::Prompt.new("I'm a system message consisting of 10 tokens") + prompt.push(type: :user, content: five_token_msg) + + dialect = TestDialect.new(prompt, "test") + dialect.max_prompt_tokens = 15 + + trimmed = dialect.trim(prompt.messages) + + expect(trimmed).to eq( + [ + { type: :system, content: "I'm a system message consisting of 10" }, + { type: :user, content: five_token_msg }, + ], ) end end