diff --git a/app/controllers/discourse_ai/admin/ai_personas_controller.rb b/app/controllers/discourse_ai/admin/ai_personas_controller.rb index 2f3563ac..d272910b 100644 --- a/app/controllers/discourse_ai/admin/ai_personas_controller.rb +++ b/app/controllers/discourse_ai/admin/ai_personas_controller.rb @@ -106,6 +106,7 @@ module DiscourseAi :question_consolidator_llm, :allow_chat, :tool_details, + :forced_tool_count, allowed_group_ids: [], rag_uploads: [:id], ) diff --git a/app/models/ai_persona.rb b/app/models/ai_persona.rb index 54eb6310..350ec064 100644 --- a/app/models/ai_persona.rb +++ b/app/models/ai_persona.rb @@ -20,6 +20,7 @@ class AiPersona < ActiveRecord::Base validates :rag_chunk_tokens, numericality: { greater_than: 0, maximum: 50_000 } validates :rag_chunk_overlap_tokens, numericality: { greater_than: -1, maximum: 200 } validates :rag_conversation_chunks, numericality: { greater_than: 0, maximum: 1000 } + validates :forced_tool_count, numericality: { greater_than: -2, maximum: 100_000 } has_many :rag_document_fragments, dependent: :destroy, as: :target belongs_to :created_by, class_name: "User" @@ -185,6 +186,7 @@ class AiPersona < ActiveRecord::Base define_method(:tools) { tools } define_method(:force_tool_use) { force_tool_use } + define_method(:forced_tool_count) { @ai_persona&.forced_tool_count } define_method(:options) { options } define_method(:temperature) { @ai_persona&.temperature } define_method(:top_p) { @ai_persona&.top_p } @@ -265,32 +267,40 @@ end # # Table name: ai_personas # -# id :bigint not null, primary key -# name :string(100) not null -# description :string(2000) not null -# system_prompt :string(10000000) not null -# allowed_group_ids :integer default([]), not null, is an Array -# created_by_id :integer -# enabled :boolean default(TRUE), not null -# created_at :datetime not null -# updated_at :datetime not null -# system :boolean default(FALSE), not null -# priority :boolean default(FALSE), not null -# temperature :float -# top_p :float -# user_id :integer -# mentionable :boolean default(FALSE), not null -# default_llm :text -# max_context_posts :integer -# vision_enabled :boolean default(FALSE), not null -# vision_max_pixels :integer default(1048576), not null -# rag_chunk_tokens :integer default(374), not null -# rag_chunk_overlap_tokens :integer default(10), not null -# rag_conversation_chunks :integer default(10), not null -# question_consolidator_llm :text -# allow_chat :boolean default(FALSE), not null -# tool_details :boolean default(TRUE), not null -# tools :json not null +# id :bigint not null, primary key +# name :string(100) not null +# description :string(2000) not null +# system_prompt :string(10000000) not null +# allowed_group_ids :integer default([]), not null, is an Array +# created_by_id :integer +# enabled :boolean default(TRUE), not null +# created_at :datetime not null +# updated_at :datetime not null +# system :boolean default(FALSE), not null +# priority :boolean default(FALSE), not null +# temperature :float +# top_p :float +# user_id :integer +# mentionable :boolean default(FALSE), not null +# default_llm :text +# max_context_posts :integer +# max_post_context_tokens :integer +# max_context_tokens :integer +# vision_enabled :boolean default(FALSE), not null +# vision_max_pixels :integer default(1048576), not null +# rag_chunk_tokens :integer default(374), not null +# rag_chunk_overlap_tokens :integer default(10), not null +# rag_conversation_chunks :integer default(10), not null +# role :enum default("bot"), not null +# role_category_ids :integer default([]), not null, is an Array +# role_tags :string default([]), not null, is an Array +# role_group_ids :integer default([]), not null, is an Array +# role_whispers :boolean default(FALSE), not null +# role_max_responses_per_hour :integer default(50), not null +# question_consolidator_llm :text +# allow_chat :boolean default(FALSE), not null +# tool_details :boolean default(TRUE), not null +# tools :json not null # # Indexes # diff --git a/app/serializers/localized_ai_persona_serializer.rb b/app/serializers/localized_ai_persona_serializer.rb index da6660b0..69c9812b 100644 --- a/app/serializers/localized_ai_persona_serializer.rb +++ b/app/serializers/localized_ai_persona_serializer.rb @@ -25,7 +25,8 @@ class LocalizedAiPersonaSerializer < ApplicationSerializer :rag_conversation_chunks, :question_consolidator_llm, :allow_chat, - :tool_details + :tool_details, + :forced_tool_count has_one :user, serializer: BasicUserSerializer, embed: :object has_many :rag_uploads, serializer: UploadSerializer, embed: :object diff --git a/assets/javascripts/discourse/admin/models/ai-persona.js b/assets/javascripts/discourse/admin/models/ai-persona.js index 9344f579..53be8076 100644 --- a/assets/javascripts/discourse/admin/models/ai-persona.js +++ b/assets/javascripts/discourse/admin/models/ai-persona.js @@ -28,6 +28,7 @@ const CREATE_ATTRIBUTES = [ "question_consolidator_llm", "allow_chat", "tool_details", + "forced_tool_count", ]; const SYSTEM_ATTRIBUTES = [ @@ -154,6 +155,7 @@ export default class AiPersona extends RestModel { const persona = AiPersona.create(attrs); persona.forcedTools = (this.forcedTools || []).slice(); + persona.forced_tool_count = this.forced_tool_count || -1; return persona; } } diff --git a/assets/javascripts/discourse/components/ai-forced-tool-strategy-selector.gjs b/assets/javascripts/discourse/components/ai-forced-tool-strategy-selector.gjs new file mode 100644 index 00000000..b4e6e40a --- /dev/null +++ b/assets/javascripts/discourse/components/ai-forced-tool-strategy-selector.gjs @@ -0,0 +1,29 @@ +import { computed } from "@ember/object"; +import I18n from "discourse-i18n"; +import ComboBox from "select-kit/components/combo-box"; + +export default ComboBox.extend({ + content: computed(function () { + const content = [ + { + id: -1, + name: I18n.t("discourse_ai.ai_persona.tool_strategies.all"), + }, + ]; + + [1, 2, 5].forEach((i) => { + content.push({ + id: i, + name: I18n.t("discourse_ai.ai_persona.tool_strategies.replies", { + count: i, + }), + }); + }); + + return content; + }), + + selectKitOptions: { + filterable: false, + }, +}); diff --git a/assets/javascripts/discourse/components/ai-persona-editor.gjs b/assets/javascripts/discourse/components/ai-persona-editor.gjs index 574911ec..c2dd6182 100644 --- a/assets/javascripts/discourse/components/ai-persona-editor.gjs +++ b/assets/javascripts/discourse/components/ai-persona-editor.gjs @@ -20,6 +20,7 @@ import AdminUser from "admin/models/admin-user"; import ComboBox from "select-kit/components/combo-box"; import GroupChooser from "select-kit/components/group-chooser"; import DTooltip from "float-kit/components/d-tooltip"; +import AiForcedToolStrategySelector from "./ai-forced-tool-strategy-selector"; import AiLlmSelector from "./ai-llm-selector"; import AiPersonaToolOptions from "./ai-persona-tool-options"; import AiToolSelector from "./ai-tool-selector"; @@ -49,7 +50,11 @@ export default class PersonaEditor extends Component { } get allowForceTools() { - return !this.editingModel?.system && this.editingModel?.tools?.length > 0; + return !this.editingModel?.system && this.selectedToolNames.length > 0; + } + + get hasForcedTools() { + return this.forcedToolNames.length > 0; } @action @@ -381,12 +386,23 @@ export default class PersonaEditor extends Component {
+ {{#if this.hasForcedTools}} +
+ + +
+ {{/if}} {{/if}} {{#unless this.editingModel.system}} 0 + user_turns = prompt.messages.select { |m| m[:type] == :user }.length + force_tool = false if user_turns > persona.forced_tool_count + end + if force_tool context[:chosen_tools] << force_tool prompt.tool_choice = force_tool diff --git a/lib/ai_bot/personas/persona.rb b/lib/ai_bot/personas/persona.rb index 0a31598c..dd2a729a 100644 --- a/lib/ai_bot/personas/persona.rb +++ b/lib/ai_bot/personas/persona.rb @@ -117,6 +117,10 @@ module DiscourseAi [] end + def forced_tool_count + -1 + end + def required_tools [] end diff --git a/spec/lib/modules/ai_bot/playground_spec.rb b/spec/lib/modules/ai_bot/playground_spec.rb index 5aae0b7e..47f529a4 100644 --- a/spec/lib/modules/ai_bot/playground_spec.rb +++ b/spec/lib/modules/ai_bot/playground_spec.rb @@ -127,19 +127,33 @@ RSpec.describe DiscourseAi::AiBot::Playground do it "can force usage of a tool" do tool_name = "custom-#{custom_tool.id}" - ai_persona.update!(tools: [[tool_name, nil, true]]) + ai_persona.update!(tools: [[tool_name, nil, true]], forced_tool_count: 1) responses = [function_call, "custom tool did stuff (maybe)"] prompts = nil + reply_post = nil + DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompts| new_post = Fabricate(:post, raw: "Can you use the custom tool?") - _reply_post = playground.reply_to(new_post) + reply_post = playground.reply_to(new_post) prompts = _prompts end expect(prompts.length).to eq(2) expect(prompts[0].tool_choice).to eq("search") expect(prompts[1].tool_choice).to eq(nil) + + ai_persona.update!(forced_tool_count: 1) + responses = ["no tool call here"] + + DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompts| + new_post = Fabricate(:post, raw: "Will you use the custom tool?", topic: reply_post.topic) + _reply_post = playground.reply_to(new_post) + prompts = _prompts + end + + expect(prompts.length).to eq(1) + expect(prompts[0].tool_choice).to eq(nil) end it "uses custom tool in conversation" do diff --git a/spec/requests/admin/ai_personas_controller_spec.rb b/spec/requests/admin/ai_personas_controller_spec.rb index 1b1da5fc..61e831ce 100644 --- a/spec/requests/admin/ai_personas_controller_spec.rb +++ b/spec/requests/admin/ai_personas_controller_spec.rb @@ -39,9 +39,10 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do Fabricate( :ai_persona, name: "search2", - tools: [["SearchCommand", { base_query: "test" }]], + tools: [["SearchCommand", { base_query: "test" }, true]], mentionable: true, default_llm: "anthropic:claude-2", + forced_tool_count: 2, ) persona2.create_user! @@ -55,6 +56,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do expect(serializer_persona2["default_llm"]).to eq("anthropic:claude-2") expect(serializer_persona2["user_id"]).to eq(persona2.user_id) expect(serializer_persona2["user"]["id"]).to eq(persona2.user_id) + expect(serializer_persona2["forced_tool_count"]).to eq(2) tools = response.parsed_body["meta"]["tools"] search_tool = tools.find { |c| c["id"] == "Search" } @@ -85,7 +87,9 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do ) expect(serializer_persona1["tools"]).to eq(["SearchCommand"]) - expect(serializer_persona2["tools"]).to eq([["SearchCommand", { "base_query" => "test" }]]) + expect(serializer_persona2["tools"]).to eq( + [["SearchCommand", { "base_query" => "test" }, true]], + ) end context "with translations" do @@ -165,6 +169,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do temperature: 0.5, mentionable: true, default_llm: "anthropic:claude-2", + forced_tool_count: 2, } end @@ -183,6 +188,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do expect(persona_json["temperature"]).to eq(0.5) expect(persona_json["mentionable"]).to eq(true) expect(persona_json["default_llm"]).to eq("anthropic:claude-2") + expect(persona_json["forced_tool_count"]).to eq(2) persona = AiPersona.find(persona_json["id"]) diff --git a/spec/system/admin_ai_persona_spec.rb b/spec/system/admin_ai_persona_spec.rb index 171cfcc8..4bee1f14 100644 --- a/spec/system/admin_ai_persona_spec.rb +++ b/spec/system/admin_ai_persona_spec.rb @@ -19,6 +19,17 @@ RSpec.describe "Admin AI persona configuration", type: :system, js: true do tool_selector = PageObjects::Components::SelectKit.new(".ai-persona-editor__tools") tool_selector.expand tool_selector.select_row_by_value("Read") + tool_selector.collapse + + tool_selector = PageObjects::Components::SelectKit.new(".ai-persona-editor__forced_tools") + tool_selector.expand + tool_selector.select_row_by_value("Read") + tool_selector.collapse + + strategy_selector = + PageObjects::Components::SelectKit.new(".ai-persona-editor__forced_tool_strategy") + strategy_selector.expand + strategy_selector.select_row_by_value(1) find(".ai-persona-editor__save").click() @@ -30,7 +41,8 @@ RSpec.describe "Admin AI persona configuration", type: :system, js: true do expect(persona.name).to eq("Test Persona") expect(persona.description).to eq("I am a test persona") expect(persona.system_prompt).to eq("You are a helpful bot") - expect(persona.tools).to eq([["Read", { "read_private" => nil }, false]]) + expect(persona.tools).to eq([["Read", { "read_private" => nil }, true]]) + expect(persona.forced_tool_count).to eq(1) end it "will not allow deletion or editing of system personas" do diff --git a/test/javascripts/unit/models/ai-persona-test.js b/test/javascripts/unit/models/ai-persona-test.js index c1f25aeb..f9674f44 100644 --- a/test/javascripts/unit/models/ai-persona-test.js +++ b/test/javascripts/unit/models/ai-persona-test.js @@ -51,6 +51,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () { question_consolidator_llm: "Question Consolidator LLM", allow_chat: false, tool_details: true, + forced_tool_count: -1, }; const aiPersona = AiPersona.create({ ...properties }); @@ -92,6 +93,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () { question_consolidator_llm: "Question Consolidator LLM", allow_chat: false, tool_details: true, + forced_tool_count: -1, }; const aiPersona = AiPersona.create({ ...properties });