FIX: Correctly pass tool_choice when using Claude models. (#1364)
The `ClaudePrompt` object couldn't access the original prompt's tool_choice attribute, affecting both Anthropic and Bedrock.
This commit is contained in:
parent
cf220c530c
commit
0ce17a122f
|
@ -13,14 +13,13 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
class ClaudePrompt
|
||||
attr_reader :system_prompt
|
||||
attr_reader :messages
|
||||
attr_reader :tools
|
||||
attr_reader :system_prompt, :messages, :tools, :tool_choice
|
||||
|
||||
def initialize(system_prompt, messages, tools)
|
||||
def initialize(system_prompt, messages, tools, tool_choice)
|
||||
@system_prompt = system_prompt
|
||||
@messages = messages
|
||||
@tools = tools
|
||||
@tool_choice = tool_choice
|
||||
end
|
||||
|
||||
def has_tools?
|
||||
|
@ -55,7 +54,7 @@ module DiscourseAi
|
|||
tools = nil
|
||||
tools = tools_dialect.translated_tools if native_tool_support?
|
||||
|
||||
ClaudePrompt.new(system_prompt.presence, interleving_messages, tools)
|
||||
ClaudePrompt.new(system_prompt.presence, interleving_messages, tools, tool_choice)
|
||||
end
|
||||
|
||||
def max_prompt_tokens
|
||||
|
|
|
@ -770,6 +770,55 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
|||
end
|
||||
end
|
||||
|
||||
describe "forced tool use" do
|
||||
it "can properly force tool use" do
|
||||
prompt =
|
||||
DiscourseAi::Completions::Prompt.new(
|
||||
"You are a bot",
|
||||
messages: [type: :user, id: "user1", content: "echo hello"],
|
||||
tools: [echo_tool],
|
||||
tool_choice: "echo",
|
||||
)
|
||||
|
||||
response_body = {
|
||||
id: "msg_01RdJkxCbsEj9VFyFYAkfy2S",
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
model: "claude-3-haiku-20240307",
|
||||
content: [
|
||||
{
|
||||
type: "tool_use",
|
||||
id: "toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7",
|
||||
name: "echo",
|
||||
input: {
|
||||
text: "hello",
|
||||
},
|
||||
},
|
||||
],
|
||||
stop_reason: "end_turn",
|
||||
stop_sequence: nil,
|
||||
usage: {
|
||||
input_tokens: 345,
|
||||
output_tokens: 65,
|
||||
},
|
||||
}.to_json
|
||||
|
||||
parsed_body = nil
|
||||
stub_request(:post, url).with(
|
||||
body:
|
||||
proc do |req_body|
|
||||
parsed_body = JSON.parse(req_body, symbolize_names: true)
|
||||
true
|
||||
end,
|
||||
).to_return(status: 200, body: response_body)
|
||||
|
||||
llm.generate(prompt, user: Discourse.system_user)
|
||||
|
||||
# Verify that tool_choice: "echo" is present
|
||||
expect(parsed_body.dig(:tool_choice, :name)).to eq("echo")
|
||||
end
|
||||
end
|
||||
|
||||
describe "structured output via prefilling" do
|
||||
it "forces the response to be a JSON and using the given JSON schema" do
|
||||
schema = {
|
||||
|
|
|
@ -547,6 +547,67 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
|||
end
|
||||
end
|
||||
|
||||
describe "forced tool use" do
|
||||
it "can properly force tool use" do
|
||||
proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
|
||||
request = nil
|
||||
|
||||
tools = [
|
||||
{
|
||||
name: "echo",
|
||||
description: "echo something",
|
||||
parameters: [
|
||||
{ name: "text", type: "string", description: "text to echo", required: true },
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
prompt =
|
||||
DiscourseAi::Completions::Prompt.new(
|
||||
"You are a bot",
|
||||
messages: [type: :user, id: "user1", content: "echo hello"],
|
||||
tools: tools,
|
||||
tool_choice: "echo",
|
||||
)
|
||||
|
||||
# Mock response from Bedrock
|
||||
content = {
|
||||
content: [
|
||||
{
|
||||
type: "tool_use",
|
||||
id: "toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7",
|
||||
name: "echo",
|
||||
input: {
|
||||
text: "hello",
|
||||
},
|
||||
},
|
||||
],
|
||||
usage: {
|
||||
input_tokens: 25,
|
||||
output_tokens: 15,
|
||||
},
|
||||
}.to_json
|
||||
|
||||
stub_request(
|
||||
:post,
|
||||
"https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke",
|
||||
)
|
||||
.with do |inner_request|
|
||||
request = inner_request
|
||||
true
|
||||
end
|
||||
.to_return(status: 200, body: content)
|
||||
|
||||
proxy.generate(prompt, user: user)
|
||||
|
||||
# Parse the request body
|
||||
request_body = JSON.parse(request.body)
|
||||
|
||||
# Verify that tool_choice: "echo" is present
|
||||
expect(request_body.dig("tool_choice", "name")).to eq("echo")
|
||||
end
|
||||
end
|
||||
|
||||
describe "structured output via prefilling" do
|
||||
it "forces the response to be a JSON and using the given JSON schema" do
|
||||
schema = {
|
||||
|
|
Loading…
Reference in New Issue