From c02794cf2ed088f9737c434dbbdff4d7a160d1ff Mon Sep 17 00:00:00 2001 From: Sam Date: Sat, 2 Mar 2024 07:53:21 +1100 Subject: [PATCH] FIX: support multiple tool calls (#502) * FIX: support multiple tool calls Prior to this change we had a hard limit of 1 tool call per llm round trip. This meant you could not google multiple things at once or perform searches across two tools. Also: - Hint when Google stops working - Log topic_id / post_id when performing completions * Also track id for title --- lib/ai_bot/bot.rb | 78 ++++++++++-------- lib/ai_bot/personas/persona.rb | 12 ++- lib/ai_bot/playground.rb | 4 +- lib/ai_bot/tools/google.rb | 14 ++++ lib/completions/dialects/dialect.rb | 4 +- lib/completions/endpoints/base.rb | 20 +++-- lib/completions/endpoints/open_ai.rb | 24 +++++- lib/completions/prompt.rb | 14 +++- .../endpoints/endpoint_compliance.rb | 2 - .../lib/completions/endpoints/open_ai_spec.rb | 82 +++++++++++++++++-- spec/lib/completions/llm_spec.rb | 34 ++++++++ .../modules/ai_bot/personas/persona_spec.rb | 12 ++- spec/lib/modules/ai_bot/playground_spec.rb | 33 ++++++++ 13 files changed, 275 insertions(+), 58 deletions(-) diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb index 909a3811..d4db4d0d 100644 --- a/lib/ai_bot/bot.rb +++ b/lib/ai_bot/bot.rb @@ -7,6 +7,7 @@ module DiscourseAi BOT_NOT_FOUND = Class.new(StandardError) MAX_COMPLETIONS = 5 + MAX_TOOLS = 5 def self.as(bot_user, persona: DiscourseAi::AiBot::Personas::General.new, model: nil) new(bot_user, persona, model) @@ -21,14 +22,19 @@ module DiscourseAi attr_reader :bot_user attr_accessor :persona - def get_updated_title(conversation_context, post_user) + def get_updated_title(conversation_context, post) system_insts = <<~TEXT.strip You are titlebot. Given a topic, you will figure out a title. You will never respond with anything but 7 word topic title. TEXT title_prompt = - DiscourseAi::Completions::Prompt.new(system_insts, messages: conversation_context) + DiscourseAi::Completions::Prompt.new( + system_insts, + messages: conversation_context, + topic_id: post.topic_id, + post_id: post.id, + ) title_prompt.push( type: :user, @@ -38,7 +44,7 @@ module DiscourseAi DiscourseAi::Completions::Llm .proxy(model) - .generate(title_prompt, user: post_user) + .generate(title_prompt, user: post.user) .strip .split("\n") .last @@ -64,37 +70,14 @@ module DiscourseAi result = llm.generate(prompt, **llm_kwargs) do |partial, cancel| - if (tool = persona.find_tool(partial)) + tools = persona.find_tools(partial) + + if (tools.present?) tool_found = true - ongoing_chain = tool.chain_next_response? - tool_call_id = tool.tool_call_id - invocation_result_json = invoke_tool(tool, llm, cancel, &update_blk).to_json - - tool_call_message = { - type: :tool_call, - id: tool_call_id, - content: { name: tool.name, arguments: tool.parameters }.to_json, - } - - tool_message = { type: :tool, id: tool_call_id, content: invocation_result_json } - - if tool.standalone? - standalone_context = - context.dup.merge( - conversation_context: [ - context[:conversation_context].last, - tool_call_message, - tool_message, - ], - ) - prompt = persona.craft_prompt(standalone_context) - else - prompt.push(**tool_call_message) - prompt.push(**tool_message) + tools[0..MAX_TOOLS].each do |tool| + ongoing_chain &&= tool.chain_next_response? + process_tool(tool, raw_context, llm, cancel, update_blk, prompt) end - - raw_context << [tool_call_message[:content], tool_call_id, "tool_call"] - raw_context << [invocation_result_json, tool_call_id, "tool"] else update_blk.call(partial, cancel, nil) end @@ -115,6 +98,37 @@ module DiscourseAi private + def process_tool(tool, raw_context, llm, cancel, update_blk, prompt) + tool_call_id = tool.tool_call_id + invocation_result_json = invoke_tool(tool, llm, cancel, &update_blk).to_json + + tool_call_message = { + type: :tool_call, + id: tool_call_id, + content: { name: tool.name, arguments: tool.parameters }.to_json, + } + + tool_message = { type: :tool, id: tool_call_id, content: invocation_result_json } + + if tool.standalone? + standalone_context = + context.dup.merge( + conversation_context: [ + context[:conversation_context].last, + tool_call_message, + tool_message, + ], + ) + prompt = persona.craft_prompt(standalone_context) + else + prompt.push(**tool_call_message) + prompt.push(**tool_message) + end + + raw_context << [tool_call_message[:content], tool_call_id, "tool_call"] + raw_context << [invocation_result_json, tool_call_id, "tool"] + end + def invoke_tool(tool, llm, cancel, &update_blk) update_blk.call("", cancel, build_placeholder(tool.summary, "")) diff --git a/lib/ai_bot/personas/persona.rb b/lib/ai_bot/personas/persona.rb index 920e24d2..81e8d7a1 100644 --- a/lib/ai_bot/personas/persona.rb +++ b/lib/ai_bot/personas/persona.rb @@ -117,6 +117,8 @@ module DiscourseAi #{available_tools.map(&:custom_system_message).compact_blank.join("\n")} TEXT messages: context[:conversation_context].to_a, + topic_id: context[:topic_id], + post_id: context[:post_id], ) prompt.tools = available_tools.map(&:signature) if available_tools @@ -124,8 +126,16 @@ module DiscourseAi prompt end - def find_tool(partial) + def find_tools(partial) + return [] if !partial.include?("") + parsed_function = Nokogiri::HTML5.fragment(partial) + parsed_function.css("invoke").map { |fragment| find_tool(fragment) }.compact + end + + protected + + def find_tool(parsed_function) function_id = parsed_function.at("tool_id")&.text function_name = parsed_function.at("tool_name")&.text return false if function_name.nil? diff --git a/lib/ai_bot/playground.rb b/lib/ai_bot/playground.rb index 27386b8b..26150762 100644 --- a/lib/ai_bot/playground.rb +++ b/lib/ai_bot/playground.rb @@ -156,7 +156,7 @@ module DiscourseAi context = conversation_context(post) bot - .get_updated_title(context, post.user) + .get_updated_title(context, post) .tap do |new_title| PostRevisor.new(post.topic.first_post, post.topic).revise!( bot.bot_user, @@ -182,6 +182,8 @@ module DiscourseAi participants: post.topic.allowed_users.map(&:username).join(", "), conversation_context: conversation_context(post), user: post.user, + post_id: post.id, + topic_id: post.topic_id, } reply_user = bot.bot_user diff --git a/lib/ai_bot/tools/google.rb b/lib/ai_bot/tools/google.rb index bcb0aff7..a41a9a85 100644 --- a/lib/ai_bot/tools/google.rb +++ b/lib/ai_bot/tools/google.rb @@ -37,6 +37,7 @@ module DiscourseAi URI( "https://www.googleapis.com/customsearch/v1?key=#{api_key}&cx=#{cx}&q=#{escaped_query}&num=10", ) + body = Net::HTTP.get(uri) parse_search_json(body, escaped_query, llm) @@ -65,6 +66,19 @@ module DiscourseAi def parse_search_json(json_data, escaped_query, llm) parsed = JSON.parse(json_data) + error_code = parsed.dig("error", "code") + if error_code == 429 + Rails.logger.warn( + "Google Custom Search is Rate Limited, no search can be performed at the moment. #{json_data[0..1000]}", + ) + return( + "Google Custom Search is Rate Limited, no search can be performed at the moment. Let the user know there is a problem." + ) + elsif error_code + Rails.logger.warn("Google Custom Search returned an error. #{json_data[0..1000]}") + return "Google Custom Search returned an error. Let the user know there is a problem." + end + results = parsed["items"] @results_count = parsed.dig("searchInformation", "totalResults").to_i diff --git a/lib/completions/dialects/dialect.rb b/lib/completions/dialects/dialect.rb index 8b6acd59..bed1c884 100644 --- a/lib/completions/dialects/dialect.rb +++ b/lib/completions/dialects/dialect.rb @@ -106,9 +106,11 @@ module DiscourseAi raise NotImplemented end + attr_reader :prompt + private - attr_reader :prompt, :model_name, :opts + attr_reader :model_name, :opts def trim_messages(messages) prompt_limit = max_prompt_tokens diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb index d9962a5b..718d021d 100644 --- a/lib/completions/endpoints/base.rb +++ b/lib/completions/endpoints/base.rb @@ -100,6 +100,8 @@ module DiscourseAi user_id: user&.id, raw_request_payload: request_body, request_tokens: prompt_size(prompt), + topic_id: dialect.prompt.topic_id, + post_id: dialect.prompt.post_id, ) if !@streaming_mode @@ -273,16 +275,22 @@ module DiscourseAi def build_buffer Nokogiri::HTML5.fragment(<<~TEXT) - - - - - - + #{noop_function_call_text} TEXT end + def noop_function_call_text + (<<~TEXT).strip + + + + + + + TEXT + end + def has_tool?(response) response.include?(" + + search + call_3Gyr3HylFJwfrtKrL6NaIit1 + + Discourse AI bot + + + + search + call_H7YkbgYurHpyJqzwUN4bghwN + + Discourse AI bot + + + + TEXT + + expect(content).to eq(expected) + end + it "properly handles spaces in tools payload" do raw_data = <<~TEXT.strip - data: {"choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"func_id","type":"function","function":{"name":"google","arguments":""}}]}}]} + data: {"choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"func_id","type":"function","function":{"name":"go|ogle","arg|uments":""}}]}}]} data: {"choices": [{"index": 0, "delta": {"tool_calls": [{"index": 0, "function": {"arguments": "{\\""}}]}}]} @@ -253,9 +321,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do open_ai_mock.stub_raw(chunks) partials = [] - endpoint.perform_completion!(compliance.dialect, user) do |partial, x, y| - partials << partial - end + endpoint.perform_completion!(compliance.dialect, user) { |partial| partials << partial } expect(partials.length).to eq(1) diff --git a/spec/lib/completions/llm_spec.rb b/spec/lib/completions/llm_spec.rb index 556f1810..10c7d8ee 100644 --- a/spec/lib/completions/llm_spec.rb +++ b/spec/lib/completions/llm_spec.rb @@ -21,6 +21,40 @@ RSpec.describe DiscourseAi::Completions::Llm do end end + describe "AiApiAuditLog" do + it "is able to keep track of post and topic id" do + prompt = + DiscourseAi::Completions::Prompt.new( + "You are fake", + messages: [{ type: :user, content: "fake orders" }], + topic_id: 123, + post_id: 1, + ) + + result = <<~TEXT + data: {"id":"chatcmpl-8xoPOYRmiuBANTmGqdCGVk4ZA3Orz","object":"chat.completion.chunk","created":1709265814,"model":"gpt-4-0125-preview","system_fingerprint":"fp_70b2088885","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]} + + data: {"id":"chatcmpl-8xoPOYRmiuBANTmGqdCGVk4ZA3Orz","object":"chat.completion.chunk","created":1709265814,"model":"gpt-4-0125-preview","system_fingerprint":"fp_70b2088885","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}]} + + data: [DONE] + TEXT + + WebMock.stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return( + status: 200, + body: result, + ) + result = +"" + described_class + .proxy("open_ai:gpt-3.5-turbo") + .generate(prompt, user: user) { |partial| result << partial } + + expect(result).to eq("Hello") + log = AiApiAuditLog.order("id desc").first + expect(log.topic_id).to eq(123) + expect(log.post_id).to eq(1) + end + end + describe "#generate with fake model" do before do DiscourseAi::Completions::Endpoints::Fake.delays = [] diff --git a/spec/lib/modules/ai_bot/personas/persona_spec.rb b/spec/lib/modules/ai_bot/personas/persona_spec.rb index 2a30ead0..0e2cf504 100644 --- a/spec/lib/modules/ai_bot/personas/persona_spec.rb +++ b/spec/lib/modules/ai_bot/personas/persona_spec.rb @@ -82,10 +82,18 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do ["cat oil painting", "big car"] + + dall_e + abc + + ["pic3"] + + XML - dall_e = DiscourseAi::AiBot::Personas::DallE3.new.find_tool(xml) - expect(dall_e.parameters[:prompts]).to eq(["cat oil painting", "big car"]) + dall_e1, dall_e2 = DiscourseAi::AiBot::Personas::DallE3.new.find_tools(xml) + expect(dall_e1.parameters[:prompts]).to eq(["cat oil painting", "big car"]) + expect(dall_e2.parameters[:prompts]).to eq(["pic3"]) end describe "custom personas" do diff --git a/spec/lib/modules/ai_bot/playground_spec.rb b/spec/lib/modules/ai_bot/playground_spec.rb index 5b77cf1c..9a3c7734 100644 --- a/spec/lib/modules/ai_bot/playground_spec.rb +++ b/spec/lib/modules/ai_bot/playground_spec.rb @@ -212,6 +212,39 @@ RSpec.describe DiscourseAi::AiBot::Playground do end end + it "supports multiple function calls" do + response1 = (<<~TXT).strip + + + search + search + + testing various things + + + + search + search + + another search + + + + TXT + + response2 = "I found stuff" + + DiscourseAi::Completions::Llm.with_prepared_responses([response1, response2]) do + playground.reply_to(third_post) + end + + last_post = third_post.topic.reload.posts.order(:post_number).last + + expect(last_post.raw).to include("testing various things") + expect(last_post.raw).to include("another search") + expect(last_post.raw).to include("I found stuff") + end + it "does not include placeholders in conversation context but includes all completions" do response1 = (<<~TXT).strip