diff --git a/lib/completions/dialects/claude.rb b/lib/completions/dialects/claude.rb
index 1cea2215..0ad75624 100644
--- a/lib/completions/dialects/claude.rb
+++ b/lib/completions/dialects/claude.rb
@@ -13,14 +13,13 @@ module DiscourseAi
         end
 
         class ClaudePrompt
-          attr_reader :system_prompt
-          attr_reader :messages
-          attr_reader :tools
+          attr_reader :system_prompt, :messages, :tools, :tool_choice
 
-          def initialize(system_prompt, messages, tools)
+          def initialize(system_prompt, messages, tools, tool_choice)
             @system_prompt = system_prompt
             @messages = messages
             @tools = tools
+            @tool_choice = tool_choice
           end
 
           def has_tools?
@@ -55,7 +54,7 @@ module DiscourseAi
           tools = nil
           tools = tools_dialect.translated_tools if native_tool_support?
 
-          ClaudePrompt.new(system_prompt.presence, interleving_messages, tools)
+          ClaudePrompt.new(system_prompt.presence, interleving_messages, tools, tool_choice)
         end
 
         def max_prompt_tokens
diff --git a/spec/lib/completions/endpoints/anthropic_spec.rb b/spec/lib/completions/endpoints/anthropic_spec.rb
index f0e53bea..24d7e0f5 100644
--- a/spec/lib/completions/endpoints/anthropic_spec.rb
+++ b/spec/lib/completions/endpoints/anthropic_spec.rb
@@ -770,6 +770,55 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
     end
   end
 
+  describe "forced tool use" do
+    it "can properly force tool use" do
+      prompt =
+        DiscourseAi::Completions::Prompt.new(
+          "You are a bot",
+          messages: [type: :user, id: "user1", content: "echo hello"],
+          tools: [echo_tool],
+          tool_choice: "echo",
+        )
+
+      response_body = {
+        id: "msg_01RdJkxCbsEj9VFyFYAkfy2S",
+        type: "message",
+        role: "assistant",
+        model: "claude-3-haiku-20240307",
+        content: [
+          {
+            type: "tool_use",
+            id: "toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7",
+            name: "echo",
+            input: {
+              text: "hello",
+            },
+          },
+        ],
+        stop_reason: "end_turn",
+        stop_sequence: nil,
+        usage: {
+          input_tokens: 345,
+          output_tokens: 65,
+        },
+      }.to_json
+
+      parsed_body = nil
+      stub_request(:post, url).with(
+        body:
+          proc do |req_body|
+            parsed_body = JSON.parse(req_body, symbolize_names: true)
+            true
+          end,
+      ).to_return(status: 200, body: response_body)
+
+      llm.generate(prompt, user: Discourse.system_user)
+
+      # Verify that tool_choice: "echo" is present
+      expect(parsed_body.dig(:tool_choice, :name)).to eq("echo")
+    end
+  end
+
   describe "structured output via prefilling" do
     it "forces the response to be a JSON and using the given JSON schema" do
       schema = {
diff --git a/spec/lib/completions/endpoints/aws_bedrock_spec.rb b/spec/lib/completions/endpoints/aws_bedrock_spec.rb
index bd60f988..364c3b6b 100644
--- a/spec/lib/completions/endpoints/aws_bedrock_spec.rb
+++ b/spec/lib/completions/endpoints/aws_bedrock_spec.rb
@@ -547,6 +547,67 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
     end
   end
 
+  describe "forced tool use" do
+    it "can properly force tool use" do
+      proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
+      request = nil
+
+      tools = [
+        {
+          name: "echo",
+          description: "echo something",
+          parameters: [
+            { name: "text", type: "string", description: "text to echo", required: true },
+          ],
+        },
+      ]
+
+      prompt =
+        DiscourseAi::Completions::Prompt.new(
+          "You are a bot",
+          messages: [type: :user, id: "user1", content: "echo hello"],
+          tools: tools,
+          tool_choice: "echo",
+        )
+
+      # Mock response from Bedrock
+      content = {
+        content: [
+          {
+            type: "tool_use",
+            id: "toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7",
+            name: "echo",
+            input: {
+              text: "hello",
+            },
+          },
+        ],
+        usage: {
+          input_tokens: 25,
+          output_tokens: 15,
+        },
+      }.to_json
+
+      stub_request(
+        :post,
+        "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke",
+      )
+        .with do |inner_request|
+          request = inner_request
+          true
+        end
+        .to_return(status: 200, body: content)
+
+      proxy.generate(prompt, user: user)
+
+      # Parse the request body
+      request_body = JSON.parse(request.body)
+
+      # Verify that tool_choice: "echo" is present
+      expect(request_body.dig("tool_choice", "name")).to eq("echo")
+    end
+  end
+
   describe "structured output via prefilling" do
     it "forces the response to be a JSON and using the given JSON schema" do
       schema = {