From 8b81ff45b8dadfdd489cf0f8a86621080dee2af3 Mon Sep 17 00:00:00 2001 From: Sam Date: Fri, 7 Jun 2024 23:52:01 +1000 Subject: [PATCH] FIX: switch off native tools on Anthropic Claude Opus (#659) Native tools do not work well on Opus. Chain of Thought prompting means it consumes enormous amounts of tokens and has poor latency. This commit introduces an XML stripper to remove various chain of thought XML islands from anthropic prompts when tools are involved. This means Opus native tools now function (albeit slowly) From local testing XML just works better now. Also fixes enum support in Anthropic native tools --- config/locales/server.en.yml | 1 + config/settings.yml | 9 ++ lib/completions/dialects/claude.rb | 33 +++-- lib/completions/dialects/claude_tools.rb | 48 +++----- lib/completions/dialects/xml_tools.rb | 16 ++- lib/completions/endpoints/anthropic.rb | 21 +++- lib/completions/endpoints/aws_bedrock.rb | 17 ++- lib/completions/endpoints/base.rb | 21 ++++ lib/completions/xml_tag_stripper.rb | 115 ++++++++++++++++++ spec/lib/completions/dialects/claude_spec.rb | 65 +++++++++- .../completions/endpoints/anthropic_spec.rb | 1 + .../completions/endpoints/aws_bedrock_spec.rb | 107 ++++++++++++++-- spec/lib/completions/xml_tag_stripper_spec.rb | 51 ++++++++ 13 files changed, 439 insertions(+), 66 deletions(-) create mode 100644 lib/completions/xml_tag_stripper.rb create mode 100644 spec/lib/completions/xml_tag_stripper_spec.rb diff --git a/config/locales/server.en.yml b/config/locales/server.en.yml index c656283a..2bb29c83 100644 --- a/config/locales/server.en.yml +++ b/config/locales/server.en.yml @@ -51,6 +51,7 @@ en: ai_openai_embeddings_url: "Custom URL used for the OpenAI embeddings API. 
(in the case of Azure it can be: https://COMPANY.openai.azure.com/openai/deployments/DEPLOYMENT/embeddings?api-version=2023-05-15)" ai_openai_api_key: "API key for OpenAI API" ai_anthropic_api_key: "API key for Anthropic API" + ai_anthropic_native_tool_call_models: "List of models that will use native tool calls vs legacy XML based tools." ai_cohere_api_key: "API key for Cohere API" ai_hugging_face_api_url: "Custom URL used for OpenSource LLM inference. Compatible with https://github.com/huggingface/text-generation-inference" ai_hugging_face_api_key: API key for Hugging Face API diff --git a/config/settings.yml b/config/settings.yml index fe880e2c..2b0e56a5 100644 --- a/config/settings.yml +++ b/config/settings.yml @@ -111,6 +111,15 @@ discourse_ai: ai_anthropic_api_key: default: "" secret: true + ai_anthropic_native_tool_call_models: + type: list + list_type: compact + default: "claude-3-sonnet|claude-3-haiku" + allow_any: false + choices: + - claude-3-opus + - claude-3-sonnet + - claude-3-haiku ai_cohere_api_key: default: "" secret: true diff --git a/lib/completions/dialects/claude.rb b/lib/completions/dialects/claude.rb index 6a8a9543..2c6fd131 100644 --- a/lib/completions/dialects/claude.rb +++ b/lib/completions/dialects/claude.rb @@ -22,6 +22,10 @@ module DiscourseAi @messages = messages @tools = tools end + + def has_tools? + tools.present? + end end def tokenizer @@ -33,6 +37,10 @@ module DiscourseAi system_prompt = messages.shift[:content] if messages.first[:role] == "system" + if !system_prompt && !native_tool_support? + system_prompt = tools_dialect.instructions.presence + end + interleving_messages = [] previous_message = nil @@ -48,11 +56,10 @@ module DiscourseAi previous_message = message end - ClaudePrompt.new( - system_prompt.presence, - interleving_messages, - tools_dialect.translated_tools, - ) + tools = nil + tools = tools_dialect.translated_tools if native_tool_support? 
+ + ClaudePrompt.new(system_prompt.presence, interleving_messages, tools) end def max_prompt_tokens @@ -62,18 +69,28 @@ module DiscourseAi 200_000 # Claude-3 has a 200k context window for now end + def native_tool_support? + SiteSetting.ai_anthropic_native_tool_call_models_map.include?(model_name) + end + private def tools_dialect - @tools_dialect ||= DiscourseAi::Completions::Dialects::ClaudeTools.new(prompt.tools) + if native_tool_support? + @tools_dialect ||= DiscourseAi::Completions::Dialects::ClaudeTools.new(prompt.tools) + else + super + end end def tool_call_msg(msg) - tools_dialect.from_raw_tool_call(msg) + translated = tools_dialect.from_raw_tool_call(msg) + { role: "assistant", content: translated } end def tool_msg(msg) - tools_dialect.from_raw_tool(msg) + translated = tools_dialect.from_raw_tool(msg) + { role: "user", content: translated } end def model_msg(msg) diff --git a/lib/completions/dialects/claude_tools.rb b/lib/completions/dialects/claude_tools.rb index 8708497f..b42a1833 100644 --- a/lib/completions/dialects/claude_tools.rb +++ b/lib/completions/dialects/claude_tools.rb @@ -15,13 +15,14 @@ module DiscourseAi required = [] if t[:parameters] - properties = - t[:parameters].each_with_object({}) do |param, h| - h[param[:name]] = { - type: param[:type], - description: param[:description], - }.tap { |hash| hash[:items] = { type: param[:item_type] } if param[:item_type] } - end + properties = {} + + t[:parameters].each do |param| + mapped = { type: param[:type], description: param[:description] } + mapped[:items] = { type: param[:item_type] } if param[:item_type] + mapped[:enum] = param[:enum] if param[:enum] + properties[param[:name]] = mapped + end required = t[:parameters].select { |param| param[:required] }.map { |param| param[:name] } end @@ -39,37 +40,24 @@ module DiscourseAi end def instructions - "" # Noop. Tools are listed separate. 
+ "" end def from_raw_tool_call(raw_message) call_details = JSON.parse(raw_message[:content], symbolize_names: true) tool_call_id = raw_message[:id] - - { - role: "assistant", - content: [ - { - type: "tool_use", - id: tool_call_id, - name: raw_message[:name], - input: call_details[:arguments], - }, - ], - } + [ + { + type: "tool_use", + id: tool_call_id, + name: raw_message[:name], + input: call_details[:arguments], + }, + ] end def from_raw_tool(raw_message) - { - role: "user", - content: [ - { - type: "tool_result", - tool_use_id: raw_message[:id], - content: raw_message[:content], - }, - ], - } + [{ type: "tool_result", tool_use_id: raw_message[:id], content: raw_message[:content] }] end private diff --git a/lib/completions/dialects/xml_tools.rb b/lib/completions/dialects/xml_tools.rb index 47988a71..9eabfadf 100644 --- a/lib/completions/dialects/xml_tools.rb +++ b/lib/completions/dialects/xml_tools.rb @@ -41,13 +41,17 @@ module DiscourseAi def instructions return "" if raw_tools.blank? - has_arrays = raw_tools.any? { |tool| tool[:parameters]&.any? { |p| p[:type] == "array" } } + @instructions ||= + begin + has_arrays = + raw_tools.any? { |tool| tool[:parameters]&.any? { |p| p[:type] == "array" } } - (<<~TEXT).strip - #{tool_preamble(include_array_tip: has_arrays)} - - #{translated_tools} - TEXT + (<<~TEXT).strip + #{tool_preamble(include_array_tip: has_arrays)} + + #{translated_tools} + TEXT + end end def from_raw_tool(raw_message) diff --git a/lib/completions/endpoints/anthropic.rb b/lib/completions/endpoints/anthropic.rb index c1ec288f..2739b02d 100644 --- a/lib/completions/endpoints/anthropic.rb +++ b/lib/completions/endpoints/anthropic.rb @@ -45,7 +45,12 @@ module DiscourseAi raise "Unsupported model: #{model}" end - { model: mapped_model, max_tokens: 3_000 } + options = { model: mapped_model, max_tokens: 3_000 } + + options[:stop_sequences] = [""] if !dialect.native_tool_support? && + dialect.prompt.has_tools? 
+ + options end def provider_id @@ -54,6 +59,14 @@ module DiscourseAi private + def xml_tags_to_strip(dialect) + if dialect.prompt.has_tools? + %w[thinking search_quality_reflection search_quality_score] + else + [] + end + end + # this is an approximation, we will update it later if request goes through def prompt_size(prompt) tokenizer.size(prompt.system_prompt.to_s + " " + prompt.messages.to_s) @@ -66,11 +79,13 @@ module DiscourseAi end def prepare_payload(prompt, model_params, dialect) + @native_tool_support = dialect.native_tool_support? + payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages) payload[:system] = prompt.system_prompt if prompt.system_prompt.present? payload[:stream] = true if @streaming_mode - payload[:tools] = prompt.tools if prompt.tools.present? + payload[:tools] = prompt.tools if prompt.has_tools? payload end @@ -108,7 +123,7 @@ module DiscourseAi end def native_tool_support? - true + @native_tool_support end def partials_from(decoded_chunk) diff --git a/lib/completions/endpoints/aws_bedrock.rb b/lib/completions/endpoints/aws_bedrock.rb index bbd87749..d0ef6274 100644 --- a/lib/completions/endpoints/aws_bedrock.rb +++ b/lib/completions/endpoints/aws_bedrock.rb @@ -36,6 +36,9 @@ module DiscourseAi def default_options(dialect) options = { max_tokens: 3_000, anthropic_version: "bedrock-2023-05-31" } + + options[:stop_sequences] = [""] if !dialect.native_tool_support? && + dialect.prompt.has_tools? options end @@ -43,6 +46,14 @@ module DiscourseAi AiApiAuditLog::Provider::Anthropic end + def xml_tags_to_strip(dialect) + if dialect.prompt.has_tools? + %w[thinking search_quality_reflection search_quality_score] + else + [] + end + end + private def prompt_size(prompt) @@ -79,9 +90,11 @@ module DiscourseAi end def prepare_payload(prompt, model_params, dialect) + @native_tool_support = dialect.native_tool_support? 
+ payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages) payload[:system] = prompt.system_prompt if prompt.system_prompt.present? - payload[:tools] = prompt.tools if prompt.tools.present? + payload[:tools] = prompt.tools if prompt.has_tools? payload end @@ -169,7 +182,7 @@ module DiscourseAi end def native_tool_support? - true + @native_tool_support end def chunk_to_string(chunk) diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb index 8d9d5f68..3c0c5984 100644 --- a/lib/completions/endpoints/base.rb +++ b/lib/completions/endpoints/base.rb @@ -78,11 +78,27 @@ module DiscourseAi end end + def xml_tags_to_strip(dialect) + [] + end + def perform_completion!(dialect, user, model_params = {}, feature_name: nil, &blk) allow_tools = dialect.prompt.has_tools? model_params = normalize_model_params(model_params) + orig_blk = blk @streaming_mode = block_given? + to_strip = xml_tags_to_strip(dialect) + @xml_stripper = + DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present? + + if @streaming_mode && @xml_stripper + blk = + lambda do |partial, cancel| + partial = @xml_stripper << partial + orig_blk.call(partial, cancel) if partial + end + end prompt = dialect.translate @@ -270,6 +286,11 @@ module DiscourseAi blk.call(function_calls, cancel) end + if @xml_stripper + leftover = @xml_stripper.finish + orig_blk.call(leftover, cancel) if leftover.present? 
+ end + return response_data ensure if log diff --git a/lib/completions/xml_tag_stripper.rb b/lib/completions/xml_tag_stripper.rb new file mode 100644 index 00000000..729c14f7 --- /dev/null +++ b/lib/completions/xml_tag_stripper.rb @@ -0,0 +1,115 @@ +# frozen_string_literal: true + +module DiscourseAi + module Completions + class XmlTagStripper + def initialize(tags_to_strip) + @tags_to_strip = tags_to_strip + @longest_tag = tags_to_strip.map(&:length).max + @parsed = [] + end + + def <<(text) + if node = @parsed[-1] + if node[:type] == :maybe_tag + @parsed.pop + text = node[:content] + text + end + end + @parsed.concat(parse_tags(text)) + @parsed, result = process_parsed(@parsed) + result + end + + def finish + @parsed.map { |node| node[:content] }.join + end + + def process_parsed(parsed) + output = [] + buffer = [] + stack = [] + + parsed.each do |node| + case node[:type] + when :text + if stack.empty? + output << node[:content] + else + buffer << node + end + when :open_tag + stack << node[:name] + buffer << node + when :close_tag + if stack.empty? + output << node[:content] + else + if stack[0] == node[:name] + buffer = [] + stack = [] + else + buffer << node + end + end + when :maybe_tag + buffer << node + end + end + + result = output.join + result = nil if result.empty? + + [buffer, result] + end + + def parse_tags(text) + parsed = [] + + while true + before, after = text.split("<", 2) + + parsed << { type: :text, content: before } + + break if after.nil? 
+ + tag, after = after.split(">", 2) + + is_end_tag = tag[0] == "/" + tag_name = tag + tag_name = tag[1..-1] || "" if is_end_tag + + if !after + found = false + if tag_name.length <= @longest_tag + @tags_to_strip.each do |tag_to_strip| + if tag_to_strip.start_with?(tag_name) + parsed << { type: :maybe_tag, content: "<" + tag } + found = true + break + end + end + end + parsed << { type: :text, content: "<" + tag } if !found + break + end + + raw_tag = "<" + tag + ">" + + if @tags_to_strip.include?(tag_name) + parsed << { + type: is_end_tag ? :close_tag : :open_tag, + content: raw_tag, + name: tag_name, + } + else + parsed << { type: :text, content: raw_tag } + end + text = after + end + + parsed + end + end + end +end diff --git a/spec/lib/completions/dialects/claude_spec.rb b/spec/lib/completions/dialects/claude_spec.rb index 8eb00256..a201cf5b 100644 --- a/spec/lib/completions/dialects/claude_spec.rb +++ b/spec/lib/completions/dialects/claude_spec.rb @@ -1,6 +1,10 @@ # frozen_string_literal: true RSpec.describe DiscourseAi::Completions::Dialects::Claude do + let :opus_dialect_klass do + DiscourseAi::Completions::Dialects::Dialect.dialect_for("claude-3-opus") + end + describe "#translate" do it "can insert OKs to make stuff interleve properly" do messages = [ @@ -13,8 +17,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do prompt = DiscourseAi::Completions::Prompt.new("You are a helpful bot", messages: messages) - dialectKlass = DiscourseAi::Completions::Dialects::Dialect.dialect_for("claude-3-opus") - dialect = dialectKlass.new(prompt, "claude-3-opus") + dialect = opus_dialect_klass.new(prompt, "claude-3-opus") translated = dialect.translate expected_messages = [ @@ -29,8 +32,8 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do expect(translated.messages).to eq(expected_messages) end - it "can properly translate a prompt" do - dialect = DiscourseAi::Completions::Dialects::Dialect.dialect_for("claude-3-opus") + it "can properly translate 
a prompt (legacy tools)" do + SiteSetting.ai_anthropic_native_tool_call_models = "" tools = [ { @@ -59,7 +62,59 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do tools: tools, ) - dialect = dialect.new(prompt, "claude-3-opus") + dialect = opus_dialect_klass.new(prompt, "claude-3-opus") + translated = dialect.translate + + expect(translated.system_prompt).to start_with("You are a helpful bot") + + expected = [ + { role: "user", content: "user1: echo something" }, + { + role: "assistant", + content: + "\n\necho\n\nsomething\n\n\n", + }, + { + role: "user", + content: + "\n\ntool_id\n\n\"something\"\n\n\n", + }, + { role: "assistant", content: "I did it" }, + { role: "user", content: "user1: echo something else" }, + ] + expect(translated.messages).to eq(expected) + end + + it "can properly translate a prompt (native tools)" do + SiteSetting.ai_anthropic_native_tool_call_models = "claude-3-opus" + + tools = [ + { + name: "echo", + description: "echo a string", + parameters: [ + { name: "text", type: "string", description: "string to echo", required: true }, + ], + }, + ] + + tool_call_prompt = { name: "echo", arguments: { text: "something" } } + + messages = [ + { type: :user, id: "user1", content: "echo something" }, + { type: :tool_call, name: "echo", id: "tool_id", content: tool_call_prompt.to_json }, + { type: :tool, id: "tool_id", content: "something".to_json }, + { type: :model, content: "I did it" }, + { type: :user, id: "user1", content: "echo something else" }, + ] + + prompt = + DiscourseAi::Completions::Prompt.new( + "You are a helpful bot", + messages: messages, + tools: tools, + ) + dialect = opus_dialect_klass.new(prompt, "claude-3-opus") translated = dialect.translate expect(translated.system_prompt).to start_with("You are a helpful bot") diff --git a/spec/lib/completions/endpoints/anthropic_spec.rb b/spec/lib/completions/endpoints/anthropic_spec.rb index ce185dcf..0c47f0e8 100644 --- a/spec/lib/completions/endpoints/anthropic_spec.rb +++ 
b/spec/lib/completions/endpoints/anthropic_spec.rb @@ -48,6 +48,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do before { SiteSetting.ai_anthropic_api_key = "123" } it "does not eat spaces with tool calls" do + SiteSetting.ai_anthropic_native_tool_call_models = "claude-3-opus" body = <<~STRING event: message_start data: {"type":"message_start","message":{"id":"msg_01Ju4j2MiGQb9KV9EEQ522Y3","type":"message","role":"assistant","model":"claude-3-haiku-20240307","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":1293,"output_tokens":1}} } diff --git a/spec/lib/completions/endpoints/aws_bedrock_spec.rb b/spec/lib/completions/endpoints/aws_bedrock_spec.rb index c6c2399b..b6c96112 100644 --- a/spec/lib/completions/endpoints/aws_bedrock_spec.rb +++ b/spec/lib/completions/endpoints/aws_bedrock_spec.rb @@ -18,6 +18,13 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Claude, user) end + def encode_message(message) + wrapped = { bytes: Base64.encode64(message.to_json) }.to_json + io = StringIO.new(wrapped) + aws_message = Aws::EventStream::Message.new(payload: io) + Aws::EventStream::Encoder.new.encode(aws_message) + end + before do SiteSetting.ai_bedrock_access_key_id = "123456" SiteSetting.ai_bedrock_secret_access_key = "asd-asd-asd" @@ -25,6 +32,85 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do end describe "function calling" do + it "supports old school xml function calls" do + SiteSetting.ai_anthropic_native_tool_call_models = "" + proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet") + + incomplete_tool_call = <<~XML.strip + I should be ignored + also ignored + 0 + + + google + sydney weather today + + + XML + + messages = + [ + { type: "message_start", message: { usage: { input_tokens: 9 } } }, + { type: "content_block_delta", delta: { text: "hello\n" } }, + { type: 
"content_block_delta", delta: { text: incomplete_tool_call } }, + { type: "message_delta", delta: { usage: { output_tokens: 25 } } }, + ].map { |message| encode_message(message) } + + request = nil + bedrock_mock.with_chunk_array_support do + stub_request( + :post, + "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke-with-response-stream", + ) + .with do |inner_request| + request = inner_request + true + end + .to_return(status: 200, body: messages) + + prompt = + DiscourseAi::Completions::Prompt.new( + messages: [{ type: :user, content: "what is the weather in sydney" }], + ) + + tool = { + name: "google", + description: "Will search using Google", + parameters: [ + { name: "query", description: "The search query", type: "string", required: true }, + ], + } + + prompt.tools = [tool] + response = +"" + proxy.generate(prompt, user: user) { |partial| response << partial } + + expect(request.headers["Authorization"]).to be_present + expect(request.headers["X-Amz-Content-Sha256"]).to be_present + + parsed_body = JSON.parse(request.body) + expect(parsed_body["system"]).to include("") + expect(parsed_body["tools"]).to eq(nil) + expect(parsed_body["stop_sequences"]).to eq([""]) + + # note we now have a tool_id cause we were normalized + function_call = <<~XML.strip + hello + + + + + google + sydney weather today + tool_0 + + + XML + + expect(response.strip).to eq(function_call) + end + end + it "supports streaming function calls" do proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet") @@ -48,6 +134,13 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do stop_reason: nil, }, }, + { + type: "content_block_start", + index: 0, + delta: { + text: "I should be ignored", + }, + }, { type: "content_block_start", index: 0, @@ -111,12 +204,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do firstByteLatency: 402, }, }, - ].map do |message| - wrapped = { bytes: 
Base64.encode64(message.to_json) }.to_json - io = StringIO.new(wrapped) - aws_message = Aws::EventStream::Message.new(payload: io) - Aws::EventStream::Encoder.new.encode(aws_message) - end + ].map { |message| encode_message(message) } messages = messages.join("").split @@ -248,12 +336,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do { type: "content_block_delta", delta: { text: "hello " } }, { type: "content_block_delta", delta: { text: "sam" } }, { type: "message_delta", delta: { usage: { output_tokens: 25 } } }, - ].map do |message| - wrapped = { bytes: Base64.encode64(message.to_json) }.to_json - io = StringIO.new(wrapped) - aws_message = Aws::EventStream::Message.new(payload: io) - Aws::EventStream::Encoder.new.encode(aws_message) - end + ].map { |message| encode_message(message) } # stream 1 letter at a time # cause we need to handle this case diff --git a/spec/lib/completions/xml_tag_stripper_spec.rb b/spec/lib/completions/xml_tag_stripper_spec.rb new file mode 100644 index 00000000..02ac36c4 --- /dev/null +++ b/spec/lib/completions/xml_tag_stripper_spec.rb @@ -0,0 +1,51 @@ +# frozen_string_literal: true + +describe DiscourseAi::Completions::PromptMessagesBuilder do + let(:tag_stripper) { DiscourseAi::Completions::XmlTagStripper.new(%w[thinking results]) } + + it "should strip tags correctly in simple cases" do + result = tag_stripper << "xhelloz" + expect(result).to eq("z") + + result = tag_stripper << "king>hello" + expect(result).to eq("king>hello") + + result = tag_stripper << "123" + expect(result).to eq("123") + end + + it "supports odd nesting" do + text = <<~TEXT + + well lets see what happens if I say here... + + hello + TEXT + + result = tag_stripper << text + expect(result).to eq("\nhello\n") + end + + it "works when nesting unrelated tags it strips correctly" do + text = <<~TEXT + + well lets see what happens if I say

here... + + abc hello + TEXT + + result = tag_stripper << text + + expect(result).to eq("\nabc hello\n") + end + + it "handles maybe tags correctly" do + result = tag_stripper << "