diff --git a/lib/automation/llm_triage.rb b/lib/automation/llm_triage.rb index b4184ab3..e4d12e51 100644 --- a/lib/automation/llm_triage.rb +++ b/lib/automation/llm_triage.rb @@ -30,7 +30,7 @@ module DiscourseAi result = nil - llm = DiscourseAi::Completions::Llm.proxy(model) + llm = DiscourseAi::Completions::Llm.proxy(translate_model(model)) result = llm.generate( @@ -67,6 +67,17 @@ module DiscourseAi post.topic.update!(visible: false) if hide_topic end end + + def self.translate_model(model) + return "google:gemini-pro" if model == "gemini-pro" + return "open_ai:#{model}" if model != "claude-2" + + if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2") + "aws_bedrock:claude-2" + else + "anthropic:claude-2" + end + end end end end diff --git a/lib/automation/report_runner.rb b/lib/automation/report_runner.rb index 375062b2..0ec2ddef 100644 --- a/lib/automation/report_runner.rb +++ b/lib/automation/report_runner.rb @@ -60,7 +60,7 @@ module DiscourseAi I18n.t("discourse_automation.llm_report.title") end @model = model - @llm = DiscourseAi::Completions::Llm.proxy(model) + @llm = DiscourseAi::Completions::Llm.proxy(translate_model(model)) @category_ids = category_ids @tags = tags @allow_secure_categories = allow_secure_categories @@ -176,6 +176,17 @@ module DiscourseAi end end end + + def translate_model(model) + return "google:gemini-pro" if model == "gemini-pro" + return "open_ai:#{model}" if model != "claude-2" + + if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2") + "aws_bedrock:claude-2" + else + "anthropic:claude-2" + end + end end end end diff --git a/spec/lib/modules/automation/llm_triage_spec.rb b/spec/lib/modules/automation/llm_triage_spec.rb index c1bb8188..bd3830cd 100644 --- a/spec/lib/modules/automation/llm_triage_spec.rb +++ b/spec/lib/modules/automation/llm_triage_spec.rb @@ -10,7 +10,7 @@ describe DiscourseAi::Automation::LlmTriage do DiscourseAi::Completions::Llm.with_prepared_responses(["good"]) do triage( post: post, - model: "fake:fake", + model: "gpt-4", hide_topic: true, system_prompt: "test %%POST%%", search_for_text: "bad", @@ -24,7 +24,7 @@ describe DiscourseAi::Automation::LlmTriage do DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do triage( post: post, - model: "fake:fake", + model: "gpt-4", hide_topic: true, system_prompt: "test %%POST%%", search_for_text: "bad", @@ -40,7 +40,7 @@ describe DiscourseAi::Automation::LlmTriage do DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do triage( post: post, - model: "fake:fake", + model: "gpt-4", category_id: category.id, system_prompt: "test %%POST%%", search_for_text: "bad", @@ -55,7 +55,7 @@ describe DiscourseAi::Automation::LlmTriage do DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do triage( post: post, - model: "fake:fake", + model: "gpt-4", system_prompt: "test %%POST%%", search_for_text: "bad", canned_reply: "test canned reply 123", diff --git a/spec/lib/modules/automation/report_runner_spec.rb b/spec/lib/modules/automation/report_runner_spec.rb index 5e650641..ca424bf2 100644 --- a/spec/lib/modules/automation/report_runner_spec.rb +++ b/spec/lib/modules/automation/report_runner_spec.rb @@ -22,7 +22,7 @@ module DiscourseAi sender_username: user.username, receivers: ["fake@discourse.com"], title: "test report %DATE%", - model: "fake:fake", + model: "gpt-4", category_ids: nil, tags: nil, allow_secure_categories: false, @@ -48,7 +48,7 @@ module DiscourseAi sender_username: user.username, receivers: [receiver.username], title: "test report", - model: "fake:fake", + model: "gpt-4", category_ids: nil, tags: nil, allow_secure_categories: false,