From 6c4c96e83c2e09940b4536023c22dbbc01db0c07 Mon Sep 17 00:00:00 2001 From: Sam Date: Fri, 11 Oct 2024 07:23:42 +1100 Subject: [PATCH] FEATURE: allow persona to only force tool calls on limited replies (#827) This introduces another configuration that allows operators to limit the amount of interactions with forced tool usage. Forced tools are very handy in initial llm interactions, but as conversation progresses they can hinder by slowing down stuff and adding confusion. --- .../admin/ai_personas_controller.rb | 1 + app/models/ai_persona.rb | 62 +++++++++++-------- .../localized_ai_persona_serializer.rb | 3 +- .../discourse/admin/models/ai-persona.js | 2 + .../ai-forced-tool-strategy-selector.gjs | 29 +++++++++ .../components/ai-persona-editor.gjs | 20 +++++- config/locales/client.en.yml | 6 ++ ...24_add_forced_tool_count_to_ai_personas.rb | 7 +++ lib/ai_bot/bot.rb | 5 ++ lib/ai_bot/personas/persona.rb | 4 ++ spec/lib/modules/ai_bot/playground_spec.rb | 18 +++++- .../admin/ai_personas_controller_spec.rb | 10 ++- spec/system/admin_ai_persona_spec.rb | 14 ++++- .../unit/models/ai-persona-test.js | 2 + 14 files changed, 149 insertions(+), 34 deletions(-) create mode 100644 assets/javascripts/discourse/components/ai-forced-tool-strategy-selector.gjs create mode 100644 db/migrate/20241009230724_add_forced_tool_count_to_ai_personas.rb diff --git a/app/controllers/discourse_ai/admin/ai_personas_controller.rb b/app/controllers/discourse_ai/admin/ai_personas_controller.rb index 2f3563ac..d272910b 100644 --- a/app/controllers/discourse_ai/admin/ai_personas_controller.rb +++ b/app/controllers/discourse_ai/admin/ai_personas_controller.rb @@ -106,6 +106,7 @@ module DiscourseAi :question_consolidator_llm, :allow_chat, :tool_details, + :forced_tool_count, allowed_group_ids: [], rag_uploads: [:id], ) diff --git a/app/models/ai_persona.rb b/app/models/ai_persona.rb index 54eb6310..350ec064 100644 --- a/app/models/ai_persona.rb +++ b/app/models/ai_persona.rb @@ -20,6 +20,7 @@ class AiPersona < ActiveRecord::Base validates :rag_chunk_tokens, numericality: { greater_than: 0, maximum: 50_000 } validates :rag_chunk_overlap_tokens, numericality: { greater_than: -1, maximum: 200 } validates :rag_conversation_chunks, numericality: { greater_than: 0, maximum: 1000 } + validates :forced_tool_count, numericality: { greater_than: -2, maximum: 100_000 } has_many :rag_document_fragments, dependent: :destroy, as: :target belongs_to :created_by, class_name: "User" @@ -185,6 +186,7 @@ class AiPersona < ActiveRecord::Base define_method(:tools) { tools } define_method(:force_tool_use) { force_tool_use } + define_method(:forced_tool_count) { @ai_persona&.forced_tool_count } define_method(:options) { options } define_method(:temperature) { @ai_persona&.temperature } define_method(:top_p) { @ai_persona&.top_p } @@ -265,32 +267,40 @@ end # # Table name: ai_personas # -# id :bigint not null, primary key -# name :string(100) not null -# description :string(2000) not null -# system_prompt :string(10000000) not null -# allowed_group_ids :integer default([]), not null, is an Array -# created_by_id :integer -# enabled :boolean default(TRUE), not null -# created_at :datetime not null -# updated_at :datetime not null -# system :boolean default(FALSE), not null -# priority :boolean default(FALSE), not null -# temperature :float -# top_p :float -# user_id :integer -# mentionable :boolean default(FALSE), not null -# default_llm :text -# max_context_posts :integer -# vision_enabled :boolean default(FALSE), not null -# vision_max_pixels :integer default(1048576), not null -# rag_chunk_tokens :integer default(374), not null -# rag_chunk_overlap_tokens :integer default(10), not null -# rag_conversation_chunks :integer default(10), not null -# question_consolidator_llm :text -# allow_chat :boolean default(FALSE), not null -# tool_details :boolean default(TRUE), not null -# tools :json not null +# id :bigint not null, primary key +# name :string(100) not null +# description :string(2000) not null +# system_prompt :string(10000000) not null +# allowed_group_ids :integer default([]), not null, is an Array +# created_by_id :integer +# enabled :boolean default(TRUE), not null +# created_at :datetime not null +# updated_at :datetime not null +# system :boolean default(FALSE), not null +# priority :boolean default(FALSE), not null +# temperature :float +# top_p :float +# user_id :integer +# mentionable :boolean default(FALSE), not null +# default_llm :text +# max_context_posts :integer +# max_post_context_tokens :integer +# max_context_tokens :integer +# vision_enabled :boolean default(FALSE), not null +# vision_max_pixels :integer default(1048576), not null +# rag_chunk_tokens :integer default(374), not null +# rag_chunk_overlap_tokens :integer default(10), not null +# rag_conversation_chunks :integer default(10), not null +# role :enum default("bot"), not null +# role_category_ids :integer default([]), not null, is an Array +# role_tags :string default([]), not null, is an Array +# role_group_ids :integer default([]), not null, is an Array +# role_whispers :boolean default(FALSE), not null +# role_max_responses_per_hour :integer default(50), not null +# question_consolidator_llm :text +# allow_chat :boolean default(FALSE), not null +# tool_details :boolean default(TRUE), not null +# tools :json not null # # Indexes # diff --git a/app/serializers/localized_ai_persona_serializer.rb b/app/serializers/localized_ai_persona_serializer.rb index da6660b0..69c9812b 100644 --- a/app/serializers/localized_ai_persona_serializer.rb +++ b/app/serializers/localized_ai_persona_serializer.rb @@ -25,7 +25,8 @@ class LocalizedAiPersonaSerializer < ApplicationSerializer :rag_conversation_chunks, :question_consolidator_llm, :allow_chat, - :tool_details + :tool_details, + :forced_tool_count has_one :user, serializer: BasicUserSerializer, embed: :object has_many :rag_uploads, serializer: UploadSerializer, embed: :object diff --git a/assets/javascripts/discourse/admin/models/ai-persona.js b/assets/javascripts/discourse/admin/models/ai-persona.js index 9344f579..53be8076 100644 --- a/assets/javascripts/discourse/admin/models/ai-persona.js +++ b/assets/javascripts/discourse/admin/models/ai-persona.js @@ -28,6 +28,7 @@ const CREATE_ATTRIBUTES = [ "question_consolidator_llm", "allow_chat", "tool_details", + "forced_tool_count", ]; const SYSTEM_ATTRIBUTES = [ @@ -154,6 +155,7 @@ export default class AiPersona extends RestModel { const persona = AiPersona.create(attrs); persona.forcedTools = (this.forcedTools || []).slice(); + persona.forced_tool_count = this.forced_tool_count || -1; return persona; } } diff --git a/assets/javascripts/discourse/components/ai-forced-tool-strategy-selector.gjs b/assets/javascripts/discourse/components/ai-forced-tool-strategy-selector.gjs new file mode 100644 index 00000000..b4e6e40a --- /dev/null +++ b/assets/javascripts/discourse/components/ai-forced-tool-strategy-selector.gjs @@ -0,0 +1,29 @@ +import { computed } from "@ember/object"; +import I18n from "discourse-i18n"; +import ComboBox from "select-kit/components/combo-box"; + +export default ComboBox.extend({ + content: computed(function () { + const content = [ + { + id: -1, + name: I18n.t("discourse_ai.ai_persona.tool_strategies.all"), + }, + ]; + + [1, 2, 5].forEach((i) => { + content.push({ + id: i, + name: I18n.t("discourse_ai.ai_persona.tool_strategies.replies", { + count: i, + }), + }); + }); + + return content; + }), + + selectKitOptions: { + filterable: false, + }, +}); diff --git a/assets/javascripts/discourse/components/ai-persona-editor.gjs b/assets/javascripts/discourse/components/ai-persona-editor.gjs index 574911ec..c2dd6182 100644 --- a/assets/javascripts/discourse/components/ai-persona-editor.gjs +++ b/assets/javascripts/discourse/components/ai-persona-editor.gjs @@ -20,6 +20,7 @@ import AdminUser from "admin/models/admin-user"; import ComboBox from "select-kit/components/combo-box"; import GroupChooser from "select-kit/components/group-chooser"; import DTooltip from "float-kit/components/d-tooltip"; +import AiForcedToolStrategySelector from "./ai-forced-tool-strategy-selector"; import AiLlmSelector from "./ai-llm-selector"; import AiPersonaToolOptions from "./ai-persona-tool-options"; import AiToolSelector from "./ai-tool-selector"; @@ -49,7 +50,11 @@ export default class PersonaEditor extends Component { } get allowForceTools() { - return !this.editingModel?.system && this.editingModel?.tools?.length > 0; + return !this.editingModel?.system && this.selectedToolNames.length > 0; + } + + get hasForcedTools() { + return this.forcedToolNames.length > 0; } @action @@ -381,12 +386,23 @@ export default class PersonaEditor extends Component {
+ {{#if this.hasForcedTools}} +
+ + +
+ {{/if}} {{/if}} {{#unless this.editingModel.system}} 0 + user_turns = prompt.messages.select { |m| m[:type] == :user }.length + force_tool = false if user_turns > persona.forced_tool_count + end + if force_tool context[:chosen_tools] << force_tool prompt.tool_choice = force_tool diff --git a/lib/ai_bot/personas/persona.rb b/lib/ai_bot/personas/persona.rb index 0a31598c..dd2a729a 100644 --- a/lib/ai_bot/personas/persona.rb +++ b/lib/ai_bot/personas/persona.rb @@ -117,6 +117,10 @@ module DiscourseAi [] end + def forced_tool_count + -1 + end + def required_tools [] end diff --git a/spec/lib/modules/ai_bot/playground_spec.rb b/spec/lib/modules/ai_bot/playground_spec.rb index 5aae0b7e..47f529a4 100644 --- a/spec/lib/modules/ai_bot/playground_spec.rb +++ b/spec/lib/modules/ai_bot/playground_spec.rb @@ -127,19 +127,33 @@ RSpec.describe DiscourseAi::AiBot::Playground do it "can force usage of a tool" do tool_name = "custom-#{custom_tool.id}" - ai_persona.update!(tools: [[tool_name, nil, true]]) + ai_persona.update!(tools: [[tool_name, nil, true]], forced_tool_count: 1) responses = [function_call, "custom tool did stuff (maybe)"] prompts = nil + reply_post = nil + DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompts| new_post = Fabricate(:post, raw: "Can you use the custom tool?") - _reply_post = playground.reply_to(new_post) + reply_post = playground.reply_to(new_post) prompts = _prompts end expect(prompts.length).to eq(2) expect(prompts[0].tool_choice).to eq("search") expect(prompts[1].tool_choice).to eq(nil) + + ai_persona.update!(forced_tool_count: 1) + responses = ["no tool call here"] + + DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompts| + new_post = Fabricate(:post, raw: "Will you use the custom tool?", topic: reply_post.topic) + _reply_post = playground.reply_to(new_post) + prompts = _prompts + end + + expect(prompts.length).to eq(1) + expect(prompts[0].tool_choice).to eq(nil) end it "uses custom tool in conversation" do diff --git a/spec/requests/admin/ai_personas_controller_spec.rb b/spec/requests/admin/ai_personas_controller_spec.rb index 1b1da5fc..61e831ce 100644 --- a/spec/requests/admin/ai_personas_controller_spec.rb +++ b/spec/requests/admin/ai_personas_controller_spec.rb @@ -39,9 +39,10 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do Fabricate( :ai_persona, name: "search2", - tools: [["SearchCommand", { base_query: "test" }]], + tools: [["SearchCommand", { base_query: "test" }, true]], mentionable: true, default_llm: "anthropic:claude-2", + forced_tool_count: 2, ) persona2.create_user! @@ -55,6 +56,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do expect(serializer_persona2["default_llm"]).to eq("anthropic:claude-2") expect(serializer_persona2["user_id"]).to eq(persona2.user_id) expect(serializer_persona2["user"]["id"]).to eq(persona2.user_id) + expect(serializer_persona2["forced_tool_count"]).to eq(2) tools = response.parsed_body["meta"]["tools"] search_tool = tools.find { |c| c["id"] == "Search" } @@ -85,7 +87,9 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do ) expect(serializer_persona1["tools"]).to eq(["SearchCommand"]) - expect(serializer_persona2["tools"]).to eq([["SearchCommand", { "base_query" => "test" }]]) + expect(serializer_persona2["tools"]).to eq( + [["SearchCommand", { "base_query" => "test" }, true]], + ) end context "with translations" do @@ -165,6 +169,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do temperature: 0.5, mentionable: true, default_llm: "anthropic:claude-2", + forced_tool_count: 2, } end @@ -183,6 +188,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do expect(persona_json["temperature"]).to eq(0.5) expect(persona_json["mentionable"]).to eq(true) expect(persona_json["default_llm"]).to eq("anthropic:claude-2") + expect(persona_json["forced_tool_count"]).to eq(2) persona = AiPersona.find(persona_json["id"]) diff --git a/spec/system/admin_ai_persona_spec.rb b/spec/system/admin_ai_persona_spec.rb index 171cfcc8..4bee1f14 100644 --- a/spec/system/admin_ai_persona_spec.rb +++ b/spec/system/admin_ai_persona_spec.rb @@ -19,6 +19,17 @@ RSpec.describe "Admin AI persona configuration", type: :system, js: true do tool_selector = PageObjects::Components::SelectKit.new(".ai-persona-editor__tools") tool_selector.expand tool_selector.select_row_by_value("Read") + tool_selector.collapse + + tool_selector = PageObjects::Components::SelectKit.new(".ai-persona-editor__forced_tools") + tool_selector.expand + tool_selector.select_row_by_value("Read") + tool_selector.collapse + + strategy_selector = + PageObjects::Components::SelectKit.new(".ai-persona-editor__forced_tool_strategy") + strategy_selector.expand + strategy_selector.select_row_by_value(1) find(".ai-persona-editor__save").click() @@ -30,7 +41,8 @@ RSpec.describe "Admin AI persona configuration", type: :system, js: true do expect(persona.name).to eq("Test Persona") expect(persona.description).to eq("I am a test persona") expect(persona.system_prompt).to eq("You are a helpful bot") - expect(persona.tools).to eq([["Read", { "read_private" => nil }, false]]) + expect(persona.tools).to eq([["Read", { "read_private" => nil }, true]]) + expect(persona.forced_tool_count).to eq(1) end it "will not allow deletion or editing of system personas" do diff --git a/test/javascripts/unit/models/ai-persona-test.js b/test/javascripts/unit/models/ai-persona-test.js index c1f25aeb..f9674f44 100644 --- a/test/javascripts/unit/models/ai-persona-test.js +++ b/test/javascripts/unit/models/ai-persona-test.js @@ -51,6 +51,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () { question_consolidator_llm: "Question Consolidator LLM", allow_chat: false, tool_details: true, + forced_tool_count: -1, }; const aiPersona = AiPersona.create({ ...properties }); @@ -92,6 +93,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () { question_consolidator_llm: "Question Consolidator LLM", allow_chat: false, tool_details: true, + forced_tool_count: -1, }; const aiPersona = AiPersona.create({ ...properties });