FIX: Correctly pass tool_choice when using Claude models. (#1364)

The `ClaudePrompt` object couldn't access the original prompt's tool_choice attribute, affecting both Anthropic and Bedrock.
Authored by Roman Rizzi on 2025-05-23 10:36:52 -03:00; committed on the author's behalf by GitHub.
parent cf220c530c
commit 0ce17a122f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 114 additions and 5 deletions

View File

@ -13,14 +13,13 @@ module DiscourseAi
end
class ClaudePrompt
attr_reader :system_prompt
attr_reader :messages
attr_reader :tools
attr_reader :system_prompt, :messages, :tools, :tool_choice
def initialize(system_prompt, messages, tools)
def initialize(system_prompt, messages, tools, tool_choice)
@system_prompt = system_prompt
@messages = messages
@tools = tools
@tool_choice = tool_choice
end
def has_tools?
@ -55,7 +54,7 @@ module DiscourseAi
tools = nil
tools = tools_dialect.translated_tools if native_tool_support?
ClaudePrompt.new(system_prompt.presence, interleving_messages, tools)
ClaudePrompt.new(system_prompt.presence, interleving_messages, tools, tool_choice)
end
def max_prompt_tokens

View File

@ -770,6 +770,55 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
end
end
describe "forced tool use" do
  it "can properly force tool use" do
    prompt =
      DiscourseAi::Completions::Prompt.new(
        "You are a bot",
        messages: [type: :user, id: "user1", content: "echo hello"],
        tools: [echo_tool],
        tool_choice: "echo",
      )

    # Canned Anthropic API response in which the model invokes the `echo` tool.
    api_response = {
      id: "msg_01RdJkxCbsEj9VFyFYAkfy2S",
      type: "message",
      role: "assistant",
      model: "claude-3-haiku-20240307",
      content: [
        {
          type: "tool_use",
          id: "toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7",
          name: "echo",
          input: {
            text: "hello",
          },
        },
      ],
      stop_reason: "end_turn",
      stop_sequence: nil,
      usage: {
        input_tokens: 345,
        output_tokens: 65,
      },
    }.to_json

    # Capture the outgoing request body so we can inspect it after the call.
    captured_body = nil
    body_capturer =
      proc do |req_body|
        captured_body = JSON.parse(req_body, symbolize_names: true)
        true
      end

    stub_request(:post, url).with(body: body_capturer).to_return(
      status: 200,
      body: api_response,
    )

    llm.generate(prompt, user: Discourse.system_user)

    # The request sent to Anthropic must carry the forced tool selection.
    expect(captured_body.dig(:tool_choice, :name)).to eq("echo")
  end
end
describe "structured output via prefilling" do
it "forces the response to be a JSON and using the given JSON schema" do
schema = {

View File

@ -547,6 +547,67 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
end
end
describe "forced tool use" do
  it "can properly force tool use" do
    proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
    captured_request = nil

    tool_list = [
      {
        name: "echo",
        description: "echo something",
        parameters: [
          { name: "text", type: "string", description: "text to echo", required: true },
        ],
      },
    ]

    prompt =
      DiscourseAi::Completions::Prompt.new(
        "You are a bot",
        messages: [type: :user, id: "user1", content: "echo hello"],
        tools: tool_list,
        tool_choice: "echo",
      )

    # Canned Bedrock response body: the model calls the `echo` tool.
    response_payload = {
      content: [
        {
          type: "tool_use",
          id: "toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7",
          name: "echo",
          input: {
            text: "hello",
          },
        },
      ],
      usage: {
        input_tokens: 25,
        output_tokens: 15,
      },
    }.to_json

    # Stub the invoke endpoint and hold on to the raw request for inspection.
    stub_request(
      :post,
      "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke",
    )
      .with do |inner_request|
        captured_request = inner_request
        true
      end
      .to_return(status: 200, body: response_payload)

    proxy.generate(prompt, user: user)

    # The request sent to Bedrock must carry the forced tool selection.
    request_body = JSON.parse(captured_request.body)
    expect(request_body.dig("tool_choice", "name")).to eq("echo")
  end
end
describe "structured output via prefilling" do
it "forces the response to be a JSON and using the given JSON schema" do
schema = {