diff --git a/.gitignore b/.gitignore index 3f60c8fc..120fc597 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,6 @@ node_modules evals/log evals/cases config/eval-llms.local.yml +# this gets rid of search results from ag, ripgrep, etc +tokenizers/ +public/ai-share/highlight.min.js diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml index 5f4eb92c..05326ba6 100644 --- a/config/locales/client.en.yml +++ b/config/locales/client.en.yml @@ -7,6 +7,7 @@ en: discourse_ai: search: "Allows AI search" stream_completion: "Allows streaming AI persona completions" + update_personas: "Allows updating AI personas" site_settings: categories: diff --git a/lib/personas/tool_runner.rb b/lib/personas/tool_runner.rb index 33a5a657..6ce68476 100644 --- a/lib/personas/tool_runner.rb +++ b/lib/personas/tool_runner.rb @@ -82,19 +82,39 @@ module DiscourseAi search: function(params) { return _discourse_search(params); }, + updatePersona: function(persona_id_or_name, updates) { + const result = _discourse_update_persona(persona_id_or_name, updates); + if (result.error) { + throw new Error(result.error); + } + return result; + }, getPost: _discourse_get_post, getTopic: _discourse_get_topic, getUser: _discourse_get_user, getPersona: function(name) { - return { - respondTo: function(params) { - result = _discourse_respond_to_persona(name, params); + const personaDetails = _discourse_get_persona(name); + if (personaDetails.error) { + throw new Error(personaDetails.error); + } + + // merge result.persona with {}.. + return Object.assign({ + update: function(updates) { + const result = _discourse_update_persona(name, updates); if (result.error) { throw new Error(result.error); } return result; }, - }; + respondTo: function(params) { + const result = _discourse_respond_to_persona(name, params); + if (result.error) { + throw new Error(result.error); + } + return result; + } + }, personaDetails.persona); }, createChatMessage: function(params) { const result = _discourse_create_chat_message(params); @@ -160,6 +180,20 @@ module DiscourseAi { error: "Script terminated due to timeout" } end + def has_custom_context? + mini_racer_context.eval(tool.script) + mini_racer_context.eval("typeof customContext === 'function'") + rescue StandardError + false + end + + def custom_context + mini_racer_context.eval(tool.script) + mini_racer_context.eval("customContext()") + rescue StandardError + nil + end + private MAX_FRAGMENTS = 200 @@ -443,6 +477,96 @@ module DiscourseAi end end, ) + + mini_racer_context.attach( + "_discourse_get_persona", + ->(persona_name) do + in_attached_function do + persona = AiPersona.find_by(name: persona_name) + + return { error: "Persona not found" } if persona.nil? + + # Return a subset of relevant persona attributes + { + persona: + persona.attributes.slice( + "id", + "name", + "description", + "enabled", + "system_prompt", + "temperature", + "top_p", + "vision_enabled", + "tools", + "max_context_posts", + "allow_chat_channel_mentions", + "allow_chat_direct_messages", + "allow_topic_mentions", + "allow_personal_messages", + ), + } + end + end, + ) + + mini_racer_context.attach( + "_discourse_update_persona", + ->(persona_id_or_name, updates) do + in_attached_function do + # Find persona by ID or name + persona = nil + if persona_id_or_name.is_a?(Integer) || + persona_id_or_name.to_i.to_s == persona_id_or_name + persona = AiPersona.find_by(id: persona_id_or_name.to_i) + else + persona = AiPersona.find_by(name: persona_id_or_name) + end + + return { error: "Persona not found" } if persona.nil? + + allowed_updates = {} + + if updates["system_prompt"].present? + allowed_updates[:system_prompt] = updates["system_prompt"] + end + + if updates["temperature"].is_a?(Numeric) + allowed_updates[:temperature] = updates["temperature"] + end + + allowed_updates[:top_p] = updates["top_p"] if updates["top_p"].is_a?(Numeric) + + if updates["description"].present? + allowed_updates[:description] = updates["description"] + end + + allowed_updates[:enabled] = updates["enabled"] if updates["enabled"].is_a?( + TrueClass, + ) || updates["enabled"].is_a?(FalseClass) + + if persona.update(allowed_updates) + return( + { + success: true, + persona: + persona.attributes.slice( + "id", + "name", + "description", + "enabled", + "system_prompt", + "temperature", + "top_p", + ), + } + ) + else + return { error: persona.errors.full_messages.join(", ") } + end + end + end, + ) end def attach_upload(mini_racer_context) diff --git a/lib/personas/tools/custom.rb b/lib/personas/tools/custom.rb index 505b051e..29dbb12d 100644 --- a/lib/personas/tools/custom.rb +++ b/lib/personas/tools/custom.rb @@ -29,10 +29,38 @@ module DiscourseAi # Backwards compatibility: if tool_name is not set (existing custom tools), use name def self.name name, tool_name = AiTool.where(id: tool_id).pluck(:name, :tool_name).first - tool_name.presence || name end + def self.has_custom_context? + # note on safety, this can be cached safely, we bump the whole persona cache when an ai tool is saved + # which will expire this class + return @has_custom_context if defined?(@has_custom_context) + + @has_custom_context = false + ai_tool = AiTool.find_by(id: tool_id) + if ai_tool.script.include?("customContext") + runner = ai_tool.runner({}, llm: nil, bot_user: nil, context: nil) + @has_custom_context = runner.has_custom_context? + end + + @has_custom_context + end + + def self.inject_prompt(prompt:, context:, persona:) + if has_custom_context? + ai_tool = AiTool.find_by(id: tool_id) + if ai_tool + runner = ai_tool.runner({}, llm: nil, bot_user: nil, context: context) + custom_context = runner.custom_context + if custom_context.present? + last_message = prompt.messages.last + last_message[:content] = "#{custom_context}\n\n#{last_message[:content]}" + end + end + end + end + def initialize(*args, **kwargs) @chain_next_response = true super(*args, **kwargs) diff --git a/plugin.rb b/plugin.rb index 77bdde7b..de9fdf5f 100644 --- a/plugin.rb +++ b/plugin.rb @@ -127,6 +127,11 @@ after_initialize do end end + add_api_key_scope( + :discourse_ai, + { update_personas: { actions: %w[discourse_ai/admin/ai_personas#update] } }, + ) + plugin_icons = %w[ chart-column spell-check diff --git a/spec/lib/modules/ai_bot/playground_spec.rb b/spec/lib/modules/ai_bot/playground_spec.rb index e825ebc2..6633d522 100644 --- a/spec/lib/modules/ai_bot/playground_spec.rb +++ b/spec/lib/modules/ai_bot/playground_spec.rb @@ -1151,4 +1151,50 @@ RSpec.describe DiscourseAi::AiBot::Playground do expect(playground.available_bot_usernames).to include(persona.user.username) end end + + describe "custom tool context injection" do + let!(:custom_tool) do + AiTool.create!( + name: "context_tool", + tool_name: "context_tool", + summary: "tool with custom context", + description: "A test custom tool that injects context", + parameters: [{ name: "query", type: "string", description: "Input for the custom tool" }], + script: <<~JS, + function invoke(params) { + return 'Custom tool result: ' + params.query; + } + + function customContext() { + return "This is additional context from the tool"; + } + + function details() { + return 'executed with custom context'; + } + JS + created_by: user, + ) + end + + let!(:ai_persona) { Fabricate(:ai_persona, tools: ["custom-#{custom_tool.id}"]) } + let(:bot) { DiscourseAi::Personas::Bot.as(bot_user, persona: ai_persona.class_instance.new) } + let(:playground) { DiscourseAi::AiBot::Playground.new(bot) } + + it "injects custom context into the prompt" do + prompts = nil + response = "I received the additional context" + + DiscourseAi::Completions::Llm.with_prepared_responses([response]) do |_, _, _prompts| + new_post = Fabricate(:post, raw: "Can you use the custom context tool?") + playground.reply_to(new_post) + prompts = _prompts + end + + # The first prompt should have the custom context prepended to the user message + user_message = prompts[0].messages.last + expect(user_message[:content]).to include("This is additional context from the tool") + expect(user_message[:content]).to include("Can you use the custom context tool?") + end + end end diff --git a/spec/models/ai_tool_spec.rb b/spec/models/ai_tool_spec.rb index 30aa8a23..c1c706b3 100644 --- a/spec/models/ai_tool_spec.rb +++ b/spec/models/ai_tool_spec.rb @@ -560,4 +560,114 @@ RSpec.describe AiTool do expect(Chat::Message.count).to eq(initial_message_count) # Verify no message created end end + + context "when updating personas" do + fab!(:ai_persona) do + Fabricate(:ai_persona, name: "TestPersona", system_prompt: "Original prompt") + end + + it "can update a persona with proper permissions" do + script = <<~JS + function invoke(params) { + return discourse.updatePersona(params.persona_name, { + system_prompt: params.new_prompt, + temperature: 0.7, + top_p: 0.9 + }); + } + JS + + tool = create_tool(script: script) + runner = + tool.runner( + { persona_name: "TestPersona", new_prompt: "Updated system prompt" }, + llm: nil, + bot_user: bot_user, + ) + + result = runner.invoke + expect(result["success"]).to eq(true) + expect(result["persona"]["system_prompt"]).to eq("Updated system prompt") + expect(result["persona"]["temperature"]).to eq(0.7) + + ai_persona.reload + expect(ai_persona.system_prompt).to eq("Updated system prompt") + expect(ai_persona.temperature).to eq(0.7) + expect(ai_persona.top_p).to eq(0.9) + end + end + + context "when fetching persona information" do + fab!(:ai_persona) do + Fabricate( + :ai_persona, + name: "TestPersona", + description: "Test description", + system_prompt: "Test system prompt", + temperature: 0.8, + top_p: 0.9, + vision_enabled: true, + tools: ["Search", ["WebSearch", { param: "value" }, true]], + ) + end + + it "can fetch a persona by name" do + script = <<~JS + function invoke(params) { + const persona = discourse.getPersona(params.persona_name); + return persona; + } + JS + + tool = create_tool(script: script) + runner = tool.runner({ persona_name: "TestPersona" }, llm: nil, bot_user: bot_user) + + result = runner.invoke + + expect(result["id"]).to eq(ai_persona.id) + expect(result["name"]).to eq("TestPersona") + expect(result["description"]).to eq("Test description") + expect(result["system_prompt"]).to eq("Test system prompt") + expect(result["temperature"]).to eq(0.8) + expect(result["top_p"]).to eq(0.9) + expect(result["vision_enabled"]).to eq(true) + expect(result["tools"]).to include("Search") + expect(result["tools"][1]).to be_a(Array) + end + + it "raises an error when the persona doesn't exist" do + script = <<~JS + function invoke(params) { + return discourse.getPersona("NonExistentPersona"); + } + JS + + tool = create_tool(script: script) + runner = tool.runner({}, llm: nil, bot_user: bot_user) + + expect { runner.invoke }.to raise_error(MiniRacer::RuntimeError, /Persona not found/) + end + + it "can update a persona after fetching it" do + script = <<~JS + function invoke(params) { + const persona = discourse.getPersona("TestPersona"); + return persona.update({ + system_prompt: "Updated through getPersona().update()", + temperature: 0.5 + }); + } + JS + + tool = create_tool(script: script) + runner = tool.runner({}, llm: nil, bot_user: bot_user) + + result = runner.invoke + expect(result["success"]).to eq(true) + + ai_persona.reload + expect(ai_persona.system_prompt).to eq("Updated through getPersona().update()") + expect(ai_persona.temperature).to eq(0.5) + end + end end diff --git a/spec/requests/admin/ai_personas_controller_spec.rb b/spec/requests/admin/ai_personas_controller_spec.rb index c7989e28..7fbb016f 100644 --- a/spec/requests/admin/ai_personas_controller_spec.rb +++ b/spec/requests/admin/ai_personas_controller_spec.rb @@ -239,6 +239,54 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do end describe "PUT #update" do + context "with scoped api key" do + it "allows updates with a properly scoped API key" do + api_key = Fabricate(:api_key, user: admin, created_by: admin) + + scope = + ApiKeyScope.create!( + resource: "discourse_ai", + action: "update_personas", + api_key_id: api_key.id, + allowed_parameters: { + }, + ) + + put "/admin/plugins/discourse-ai/ai-personas/#{ai_persona.id}.json", + params: { + ai_persona: { + name: "UpdatedByAPI", + description: "Updated via API key", + }, + }, + headers: { + "Api-Key" => api_key.key, + "Api-Username" => admin.username, + } + + expect(response).to have_http_status(:ok) + ai_persona.reload + expect(ai_persona.name).to eq("UpdatedByAPI") + expect(ai_persona.description).to eq("Updated via API key") + + scope.update!(action: "fake") + + put "/admin/plugins/discourse-ai/ai-personas/#{ai_persona.id}.json", + params: { + ai_persona: { + name: "UpdatedByAPI 2", + description: "Updated via API key", + }, + }, + headers: { + "Api-Key" => api_key.key, + "Api-Username" => admin.username, + } + + expect(response).not_to have_http_status(:ok) + end + end + it "allows us to trivially clear top_p and temperature" do persona = Fabricate(:ai_persona, name: "test_bot2", top_p: 0.5, temperature: 0.1) put "/admin/plugins/discourse-ai/ai-personas/#{persona.id}.json",