FEATURE: allow specifying tool use none in completion prompt
This PR adds support for disabling further tool calls by setting tool_choice to :none across all supported LLM providers:

- OpenAI: uses the "none" tool_choice parameter
- Anthropic: uses tool_choice { type: "none" } and adds a prefill message to prevent confusion
- Gemini: sets function_calling_config mode to "NONE"
- AWS Bedrock: does not natively support disabling tools, so a prefill message is added instead

We previously disabled tool calls by simply removing the tool definitions, but that caused errors with some providers. This implementation uses the method each provider supports, with a prefill fallback for Bedrock.

Co-authored-by: Natalie Tay <natalie.tay@gmail.com>

* remove stray puts
* cleaner chain breaker for last tool call (works in thinking); remove unused code
* improve test

---------

Co-authored-by: Natalie Tay <natalie.tay@gmail.com>
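For reviewers, a minimal end-to-end sketch distilled from the specs in this diff (echo_tool and model come from the test setup; they are assumptions here, not part of the change):

    # Tool definitions stay on the prompt; tool_choice: :none tells the
    # provider not to invoke them for this completion.
    prompt =
      DiscourseAi::Completions::Prompt.new(
        "You are a bot",
        messages: [{ type: :user, content: "don't use any tools please" }],
        tools: [echo_tool], # assumed: a tool definition like the echo tool in the specs below
        tool_choice: :none,
      )

    llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}") # assumed: an LlmModel fixture
    result = llm.generate(prompt, user: Discourse.system_user)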
parent 50e1bc774a
commit 1dde82eb58
@@ -6,8 +6,10 @@ module DiscourseAi
     attr_reader :model

     BOT_NOT_FOUND = Class.new(StandardError)

+    # the future is agentic, allow for more turns
     MAX_COMPLETIONS = 8

+    # limit is arbitrary, but 5 which was used in the past was too low
     MAX_TOOLS = 20
@@ -71,6 +73,8 @@ module DiscourseAi
     end

     def force_tool_if_needed(prompt, context)
+      return if prompt.tool_choice == :none
+
       context[:chosen_tools] ||= []
       forced_tools = persona.force_tool_use.map { |tool| tool.name }
       force_tool = forced_tools.find { |name| !context[:chosen_tools].include?(name) }
@@ -105,7 +109,7 @@ module DiscourseAi
       needs_newlines = false
       tools_ran = 0

-      while total_completions <= MAX_COMPLETIONS && ongoing_chain
+      while total_completions < MAX_COMPLETIONS && ongoing_chain
         tool_found = false
         force_tool_if_needed(prompt, context)
@@ -202,8 +206,8 @@ module DiscourseAi

         total_completions += 1

-        # do not allow tools when we are at the end of a chain (total_completions == MAX_COMPLETIONS)
-        prompt.tools = [] if total_completions == MAX_COMPLETIONS
+        # do not allow tools when we are at the end of a chain (total_completions == MAX_COMPLETIONS - 1)
+        prompt.tool_choice = :none if total_completions == MAX_COMPLETIONS - 1
       end

       embed_thinking(raw_context)
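In other words, the chain breaker changed shape: the last allowed completion used to run with prompt.tools = [] (stripping definitions, which some providers reject when earlier messages contain tool calls), and now runs with prompt.tool_choice = :none, so each endpoint can disable tools the way its provider supports.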
@@ -46,10 +46,6 @@ module DiscourseAi

       VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/

-      def can_end_with_assistant_msg?
-        false
-      end
-
       def native_tool_support?
         false
       end
@@ -66,16 +62,58 @@ module DiscourseAi
        prompt.tool_choice
      end

-     def translate
-       messages = prompt.messages
-
-       # Some models use an assistant msg to improve long-context responses.
-       if messages.last[:type] == :model && can_end_with_assistant_msg?
-         messages = messages.dup
-         messages.pop
-       end
-
-       trim_messages(messages).map { |msg| send("#{msg[:type]}_msg", msg) }.compact
-     end
+     def self.no_more_tool_calls_text
+       # note, Anthropic must never prefill with an ending whitespace
+       "I WILL NOT USE TOOLS IN THIS REPLY, user expressed they wanted to stop using tool calls.\nHere is the best, complete, answer I can come up with given the information I have."
+     end
+
+     def self.no_more_tool_calls_text_user
+       "DO NOT USE TOOLS IN YOUR REPLY. Return the best answer you can given the information I supplied you."
+     end
+
+     def no_more_tool_calls_text
+       self.class.no_more_tool_calls_text
+     end
+
+     def no_more_tool_calls_text_user
+       self.class.no_more_tool_calls_text_user
+     end
+
+     def translate
+       messages = trim_messages(prompt.messages)
+       last_message = messages.last
+       inject_done_on_last_tool_call = false
+
+       if !native_tool_support? && last_message && last_message[:type].to_sym == :tool &&
+            prompt.tool_choice == :none
+         inject_done_on_last_tool_call = true
+       end
+
+       translated =
+         messages
+           .map do |msg|
+             case msg[:type].to_sym
+             when :system
+               system_msg(msg)
+             when :user
+               user_msg(msg)
+             when :model
+               model_msg(msg)
+             when :tool
+               if inject_done_on_last_tool_call && msg == last_message
+                 tools_dialect.inject_done { tool_msg(msg) }
+               else
+                 tool_msg(msg)
+               end
+             when :tool_call
+               tool_call_msg(msg)
+             else
+               raise ArgumentError, "Unknown message type: #{msg[:type]}"
+             end
+           end
+           .compact
+
+       translated
+     end

      def conversation_context
@@ -54,8 +54,11 @@ module DiscourseAi
         end
       end

+      DONE_MESSAGE =
+        "Regardless of what you think, REPLY IMMEDIATELY, WITHOUT MAKING ANY FURTHER TOOL CALLS, YOU ARE OUT OF TOOL CALL QUOTA!"
+
       def from_raw_tool(raw_message)
-        (<<~TEXT).strip
+        result = (<<~TEXT).strip
           <function_results>
           <result>
           <tool_name>#{raw_message[:name] || raw_message[:id]}</tool_name>
@@ -65,6 +68,12 @@ module DiscourseAi
           </result>
           </function_results>
         TEXT
+
+        if @injecting_done
+          "#{result}\n\n#{DONE_MESSAGE}"
+        else
+          result
+        end
       end

       def from_raw_tool_call(raw_message)
@@ -86,6 +95,13 @@ module DiscourseAi
         TEXT
       end

+      def inject_done(&blk)
+        @injecting_done = true
+        blk.call
+      ensure
+        @injecting_done = false
+      end
+
       private

       attr_reader :raw_tools
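For illustration, the injected chain breaker renders the last tool result like this (taken from the dialect spec later in this diff; the echo tool and payload are the test's):

    <function_results>
    <result>
    <tool_name>echo</tool_name>
    <json>
    "test message"
    </json>
    </result>
    </function_results>

    Regardless of what you think, REPLY IMMEDIATELY, WITHOUT MAKING ANY FURTHER TOOL CALLS, YOU ARE OUT OF TOOL CALL QUOTA!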
@@ -95,7 +95,18 @@ module DiscourseAi
        if prompt.has_tools?
          payload[:tools] = prompt.tools
          if dialect.tool_choice.present?
-           payload[:tool_choice] = { type: "tool", name: dialect.tool_choice }
+           if dialect.tool_choice == :none
+             payload[:tool_choice] = { type: "none" }
+
+             # prefill prompt to nudge LLM to generate a response that is useful.
+             # without this the LLM (even 3.7) can get confused and start text preambles for a tool call.
+             payload[:messages] << {
+               role: "assistant",
+               content: dialect.no_more_tool_calls_text,
+             }
+           else
+             payload[:tool_choice] = { type: "tool", name: prompt.tool_choice }
+           end
          end
        end
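For reference, a sketch of the Anthropic request body this produces when tool_choice is :none (model name and user message are illustrative; only tools, tool_choice, and the trailing assistant prefill are the behavior under test):

    payload = {
      model: "claude-3-haiku-20240307", # illustrative
      messages: [
        { role: "user", content: "don't use any tools please" },
        # prefill appended by the endpoint so the model answers instead of preambling a tool call
        { role: "assistant", content: dialect.no_more_tool_calls_text },
      ],
      tools: prompt.tools,           # tool definitions are still sent
      tool_choice: { type: "none" }, # provider-level switch added by this change
    }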
@@ -122,7 +122,19 @@ module DiscourseAi
          if prompt.has_tools?
            payload[:tools] = prompt.tools
            if dialect.tool_choice.present?
-             payload[:tool_choice] = { type: "tool", name: dialect.tool_choice }
+             if dialect.tool_choice == :none
+               # not supported on bedrock as of 2025-03-24
+               # retest in 6 months
+               # payload[:tool_choice] = { type: "none" }
+
+               # prefill prompt to nudge LLM to generate a response that is useful, instead of trying to call a tool
+               payload[:messages] << {
+                 role: "assistant",
+                 content: dialect.no_more_tool_calls_text,
+               }
+             else
+               payload[:tool_choice] = { type: "tool", name: prompt.tool_choice }
+             end
            end
          end
        elsif dialect.is_a?(DiscourseAi::Completions::Dialects::Nova)
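The Bedrock fallback builds the same messages array but, since { type: "none" } is rejected there (as of 2025-03-24), sends no tool_choice key at all; a sketch of the only difference, matching what the Bedrock spec below asserts:

    payload[:messages] << { role: "assistant", content: dialect.no_more_tool_calls_text }
    payload.key?(:tool_choice) # => false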
@@ -72,10 +72,14 @@ module DiscourseAi

        function_calling_config = { mode: "AUTO" }
        if dialect.tool_choice.present?
-         function_calling_config = {
-           mode: "ANY",
-           allowed_function_names: [dialect.tool_choice],
-         }
+         if dialect.tool_choice == :none
+           function_calling_config = { mode: "NONE" }
+         else
+           function_calling_config = {
+             mode: "ANY",
+             allowed_function_names: [dialect.tool_choice],
+           }
+         end
        end

        payload[:tool_config] = { function_calling_config: function_calling_config }
@@ -92,12 +92,16 @@ module DiscourseAi
        if dialect.tools.present?
          payload[:tools] = dialect.tools
          if dialect.tool_choice.present?
-           payload[:tool_choice] = {
-             type: "function",
-             function: {
-               name: dialect.tool_choice,
-             },
-           }
+           if dialect.tool_choice == :none
+             payload[:tool_choice] = "none"
+           else
+             payload[:tool_choice] = {
+               type: "function",
+               function: {
+                 name: dialect.tool_choice,
+               },
+             }
+           end
          end
        end
      end
@@ -7,6 +7,18 @@ class TestDialect < DiscourseAi::Completions::Dialects::Dialect
    trim_messages(messages)
  end

+  def system_msg(msg)
+    msg
+  end
+
+  def user_msg(msg)
+    msg
+  end
+
+  def model_msg(msg)
+    msg
+  end
+
  def tokenizer
    DiscourseAi::Tokenizer::OpenAiTokenizer
  end
@@ -15,6 +27,57 @@ end
RSpec.describe DiscourseAi::Completions::Dialects::Dialect do
  fab!(:llm_model)

+  describe "#translate" do
+    let(:five_token_msg) { "This represents five tokens." }
+    let(:tools) do
+      [
+        {
+          name: "echo",
+          description: "echo a string",
+          parameters: [
+            { name: "text", type: "string", description: "string to echo", required: true },
+          ],
+        },
+      ]
+    end
+
+    it "injects done message when tool_choice is :none and last message follows tool pattern" do
+      tool_call_prompt = { name: "echo", arguments: { text: "test message" } }
+
+      prompt = DiscourseAi::Completions::Prompt.new("System instructions", tools: tools)
+      prompt.push(type: :user, content: "echo test message")
+      prompt.push(type: :tool_call, content: tool_call_prompt.to_json, id: "123", name: "echo")
+      prompt.push(type: :tool, content: "test message".to_json, name: "echo", id: "123")
+      prompt.tool_choice = :none
+
+      dialect = TestDialect.new(prompt, llm_model)
+      dialect.max_prompt_tokens = 100 # Set high enough to avoid trimming
+
+      translated = dialect.translate
+
+      expect(translated).to eq(
+        [
+          { type: :system, content: "System instructions" },
+          { type: :user, content: "echo test message" },
+          {
+            type: :tool_call,
+            content:
+              "<function_calls>\n<invoke>\n<tool_name>echo</tool_name>\n<parameters>\n<text>test message</text>\n</parameters>\n</invoke>\n</function_calls>",
+            id: "123",
+            name: "echo",
+          },
+          {
+            type: :tool,
+            id: "123",
+            name: "echo",
+            content:
+              "<function_results>\n<result>\n<tool_name>echo</tool_name>\n<json>\n\"test message\"\n</json>\n</result>\n</function_results>\n\n#{::DiscourseAi::Completions::Dialects::XmlTools::DONE_MESSAGE}",
+          },
+        ],
+      )
+    end
+  end
+
  describe "#trim_messages" do
    let(:five_token_msg) { "This represents five tokens." }
@@ -714,4 +714,59 @@ data: {"type":"content_block_start","index":0,"content_block":{"type":"redacted_
      expect(parsed_body[:max_tokens]).to eq(500)
    end
  end

+  describe "disabled tool use" do
+    it "can properly disable tool use with :none" do
+      prompt =
+        DiscourseAi::Completions::Prompt.new(
+          "You are a bot",
+          messages: [type: :user, id: "user1", content: "don't use any tools please"],
+          tools: [echo_tool],
+          tool_choice: :none,
+        )
+
+      response_body = {
+        id: "msg_01RdJkxCbsEj9VFyFYAkfy2S",
+        type: "message",
+        role: "assistant",
+        model: "claude-3-haiku-20240307",
+        content: [
+          { type: "text", text: "I won't use any tools. Here's a direct response instead." },
+        ],
+        stop_reason: "end_turn",
+        stop_sequence: nil,
+        usage: {
+          input_tokens: 345,
+          output_tokens: 65,
+        },
+      }.to_json
+
+      parsed_body = nil
+      stub_request(:post, url).with(
+        body:
+          proc do |req_body|
+            parsed_body = JSON.parse(req_body, symbolize_names: true)
+            true
+          end,
+      ).to_return(status: 200, body: response_body)
+
+      result = llm.generate(prompt, user: Discourse.system_user)
+
+      # Verify that tool_choice is set to { type: "none" }
+      expect(parsed_body[:tool_choice]).to eq({ type: "none" })
+
+      # Verify that an assistant message with no_more_tool_calls_text was added
+      messages = parsed_body[:messages]
+      expect(messages.length).to eq(2) # user message + added assistant message
+
+      last_message = messages.last
+      expect(last_message[:role]).to eq("assistant")
+
+      expect(last_message[:content]).to eq(
+        DiscourseAi::Completions::Dialects::Dialect.no_more_tool_calls_text,
+      )
+
+      expect(result).to eq("I won't use any tools. Here's a direct response instead.")
+    end
+  end
end
@@ -484,4 +484,66 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
      expect(request_body["max_tokens"]).to eq(500)
    end
  end

+  describe "disabled tool use" do
+    it "handles tool_choice: :none by adding a prefill message instead of using tool_choice param" do
+      proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
+      request = nil
+
+      # Create a prompt with tool_choice: :none
+      prompt =
+        DiscourseAi::Completions::Prompt.new(
+          "You are a helpful assistant",
+          messages: [{ type: :user, content: "don't use any tools please" }],
+          tools: [
+            {
+              name: "echo",
+              description: "echo something",
+              parameters: [
+                { name: "text", type: "string", description: "text to echo", required: true },
+              ],
+            },
+          ],
+          tool_choice: :none,
+        )
+
+      # Mock response from Bedrock
+      content = {
+        content: [text: "I won't use any tools. Here's a direct response instead."],
+        usage: {
+          input_tokens: 25,
+          output_tokens: 15,
+        },
+      }.to_json
+
+      stub_request(
+        :post,
+        "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke",
+      )
+        .with do |inner_request|
+          request = inner_request
+          true
+        end
+        .to_return(status: 200, body: content)
+
+      proxy.generate(prompt, user: user)
+
+      # Parse the request body
+      request_body = JSON.parse(request.body)
+
+      # Verify that tool_choice is NOT present (not supported in Bedrock)
+      expect(request_body).not_to have_key("tool_choice")
+
+      # Verify that an assistant message was added with no_more_tool_calls_text
+      messages = request_body["messages"]
+      expect(messages.length).to eq(2) # user message + added assistant message
+
+      last_message = messages.last
+      expect(last_message["role"]).to eq("assistant")
+
+      expect(last_message["content"]).to eq(
+        DiscourseAi::Completions::Dialects::Dialect.no_more_tool_calls_text,
+      )
+    end
+  end
end
@@ -377,4 +377,60 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do

    expect(output.join).to eq("Hello World Sam")
  end

+  it "can properly disable tool use with :none" do
+    prompt = DiscourseAi::Completions::Prompt.new("Hello", tools: [echo_tool], tool_choice: :none)
+
+    response = gemini_mock.response("I won't use any tools").to_json
+
+    req_body = nil
+
+    llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
+    url = "#{model.url}:generateContent?key=123"
+
+    stub_request(:post, url).with(
+      body:
+        proc do |_req_body|
+          req_body = _req_body
+          true
+        end,
+    ).to_return(status: 200, body: response)
+
+    response = llm.generate(prompt, user: user)
+
+    expect(response).to eq("I won't use any tools")
+
+    parsed = JSON.parse(req_body, symbolize_names: true)
+
+    # Verify that function_calling_config mode is set to "NONE"
+    expect(parsed[:tool_config]).to eq({ function_calling_config: { mode: "NONE" } })
+  end
+
+  it "can properly force specific tool use" do
+    prompt = DiscourseAi::Completions::Prompt.new("Hello", tools: [echo_tool], tool_choice: "echo")
+
+    response = gemini_mock.response("World").to_json
+
+    req_body = nil
+
+    llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
+    url = "#{model.url}:generateContent?key=123"
+
+    stub_request(:post, url).with(
+      body:
+        proc do |_req_body|
+          req_body = _req_body
+          true
+        end,
+    ).to_return(status: 200, body: response)
+
+    response = llm.generate(prompt, user: user)
+
+    parsed = JSON.parse(req_body, symbolize_names: true)
+
+    # Verify that function_calling_config is correctly set to ANY mode with the specified tool
+    expect(parsed[:tool_config]).to eq(
+      { function_calling_config: { mode: "ANY", allowed_function_names: ["echo"] } },
+    )
+  end
end
@@ -395,6 +395,65 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
    end
  end

+  describe "disabled tool use" do
+    it "can properly disable tool use with :none" do
+      llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
+
+      tools = [
+        {
+          name: "echo",
+          description: "echo something",
+          parameters: [
+            { name: "text", type: "string", description: "text to echo", required: true },
+          ],
+        },
+      ]
+
+      prompt =
+        DiscourseAi::Completions::Prompt.new(
+          "You are a bot",
+          messages: [type: :user, id: "user1", content: "don't use any tools please"],
+          tools: tools,
+          tool_choice: :none,
+        )
+
+      response = {
+        id: "chatcmpl-9JxkAzzaeO4DSV3omWvok9TKhCjBH",
+        object: "chat.completion",
+        created: 1_714_544_914,
+        model: "gpt-4-turbo-2024-04-09",
+        choices: [
+          {
+            index: 0,
+            message: {
+              role: "assistant",
+              content: "I won't use any tools. Here's a direct response instead.",
+            },
+            logprobs: nil,
+            finish_reason: "stop",
+          },
+        ],
+        usage: {
+          prompt_tokens: 55,
+          completion_tokens: 13,
+          total_tokens: 68,
+        },
+        system_fingerprint: "fp_ea6eb70039",
+      }.to_json
+
+      body_json = nil
+      stub_request(:post, "https://api.openai.com/v1/chat/completions").with(
+        body: proc { |body| body_json = JSON.parse(body, symbolize_names: true) },
+      ).to_return(body: response)
+
+      result = llm.generate(prompt, user: user)
+
+      # Verify that tool_choice is set to "none" in the request
+      expect(body_json[:tool_choice]).to eq("none")
+      expect(result).to eq("I won't use any tools. Here's a direct response instead.")
+    end
+  end
+
  describe "parameter disabling" do
    it "excludes disabled parameters from the request" do
      model.update!(provider_params: { disable_top_p: true, disable_temperature: true })