diff --git a/app/controllers/discourse_ai/ai_helper/assistant_controller.rb b/app/controllers/discourse_ai/ai_helper/assistant_controller.rb
index 0d0f85ee..d2ee8738 100644
--- a/app/controllers/discourse_ai/ai_helper/assistant_controller.rb
+++ b/app/controllers/discourse_ai/ai_helper/assistant_controller.rb
@@ -78,22 +78,26 @@ module DiscourseAi
end
def suggest_category
- input = get_text_param!
- input_hash = { text: input }
+ if params[:topic_id]
+ opts = { topic_id: params[:topic_id] }
+ else
+ input = get_text_param!
+ opts = { text: input }
+ end
- render json:
- DiscourseAi::AiHelper::SemanticCategorizer.new(
- input_hash,
- current_user,
- ).categories,
+ render json: DiscourseAi::AiHelper::SemanticCategorizer.new(current_user, opts).categories,
status: 200
end
def suggest_tags
- input = get_text_param!
- input_hash = { text: input }
+ if params[:topic_id]
+ opts = { topic_id: params[:topic_id] }
+ else
+ input = get_text_param!
+ opts = { text: input }
+ end
- render json: DiscourseAi::AiHelper::SemanticCategorizer.new(input_hash, current_user).tags,
+ render json: DiscourseAi::AiHelper::SemanticCategorizer.new(current_user, opts).tags,
status: 200
end
diff --git a/assets/javascripts/discourse/components/suggestion-menus/ai-category-suggester.gjs b/assets/javascripts/discourse/components/suggestion-menus/ai-category-suggester.gjs
index 18051b52..ff44af91 100644
--- a/assets/javascripts/discourse/components/suggestion-menus/ai-category-suggester.gjs
+++ b/assets/javascripts/discourse/components/suggestion-menus/ai-category-suggester.gjs
@@ -20,27 +20,13 @@ export default class AiCategorySuggester extends Component {
@tracked untriggers = [];
@tracked triggerIcon = "discourse-sparkles";
@tracked content = null;
- @tracked topicContent = null;
-
- constructor() {
- super(...arguments);
- if (!this.topicContent && this.args.composer?.reply === undefined) {
- this.fetchTopicContent();
- }
- }
-
- async fetchTopicContent() {
- await ajax(`/t/${this.args.buffered.content.id}.json`).then(
- ({ post_stream }) => {
- this.topicContent = post_stream.posts[0].cooked;
- }
- );
- }
get showSuggestionButton() {
const composerFields = document.querySelector(".composer-fields");
- this.content = this.args.composer?.reply || this.topicContent;
- const showTrigger = this.content?.length > MIN_CHARACTER_COUNT;
+ this.content = this.args.composer?.reply;
+ const showTrigger =
+ this.content?.length > MIN_CHARACTER_COUNT ||
+ this.args.topicState === "edit";
if (composerFields) {
if (showTrigger) {
@@ -62,12 +48,20 @@ export default class AiCategorySuggester extends Component {
this.loading = true;
this.triggerIcon = "spinner";
+ const data = {};
+
+ if (this.content) {
+ data.text = this.content;
+ } else {
+ data.topic_id = this.args.buffered.content.id;
+ }
+
try {
const { assistant } = await ajax(
"/discourse-ai/ai-helper/suggest_category",
{
method: "POST",
- data: { text: this.content },
+ data,
}
);
this.suggestions = assistant;
diff --git a/assets/javascripts/discourse/components/suggestion-menus/ai-tag-suggester.gjs b/assets/javascripts/discourse/components/suggestion-menus/ai-tag-suggester.gjs
index 47ad38f6..84ce1958 100644
--- a/assets/javascripts/discourse/components/suggestion-menus/ai-tag-suggester.gjs
+++ b/assets/javascripts/discourse/components/suggestion-menus/ai-tag-suggester.gjs
@@ -21,27 +21,13 @@ export default class AiTagSuggester extends Component {
@tracked untriggers = [];
@tracked triggerIcon = "discourse-sparkles";
@tracked content = null;
- @tracked topicContent = null;
-
- constructor() {
- super(...arguments);
- if (!this.topicContent && this.args.composer?.reply === undefined) {
- this.fetchTopicContent();
- }
- }
-
- async fetchTopicContent() {
- await ajax(`/t/${this.args.buffered.content.id}.json`).then(
- ({ post_stream }) => {
- this.topicContent = post_stream.posts[0].cooked;
- }
- );
- }
get showSuggestionButton() {
const composerFields = document.querySelector(".composer-fields");
- this.content = this.args.composer?.reply || this.topicContent;
- const showTrigger = this.content?.length > MIN_CHARACTER_COUNT;
+ this.content = this.args.composer?.reply;
+ const showTrigger =
+ this.content?.length > MIN_CHARACTER_COUNT ||
+ this.args.topicState === "edit";
if (composerFields) {
if (showTrigger) {
@@ -74,15 +60,25 @@ export default class AiTagSuggester extends Component {
this.loading = true;
this.triggerIcon = "spinner";
+ const data = {};
+
+ if (this.content) {
+ data.text = this.content;
+ } else {
+ data.topic_id = this.args.buffered.content.id;
+ }
+
try {
const { assistant } = await ajax("/discourse-ai/ai-helper/suggest_tags", {
method: "POST",
- data: { text: this.content },
+ data,
});
this.suggestions = assistant;
+
const model = this.args.composer
? this.args.composer
: this.args.buffered;
+
if (this.#tagSelectorHasValues()) {
this.suggestions = this.suggestions.filter(
(s) => !model.get("tags").includes(s.name)
diff --git a/assets/javascripts/discourse/connectors/after-composer-category-input/ai-category-suggestion.gjs b/assets/javascripts/discourse/connectors/after-composer-category-input/ai-category-suggestion.gjs
index d7bef642..47cc9d88 100644
--- a/assets/javascripts/discourse/connectors/after-composer-category-input/ai-category-suggestion.gjs
+++ b/assets/javascripts/discourse/connectors/after-composer-category-input/ai-category-suggestion.gjs
@@ -13,6 +13,6 @@ export default class AiCategorySuggestion extends Component {
}
-
+
}
diff --git a/assets/javascripts/discourse/connectors/after-composer-tag-input/ai-tag-suggestion.gjs b/assets/javascripts/discourse/connectors/after-composer-tag-input/ai-tag-suggestion.gjs
index ac6ad686..9f02ed19 100644
--- a/assets/javascripts/discourse/connectors/after-composer-tag-input/ai-tag-suggestion.gjs
+++ b/assets/javascripts/discourse/connectors/after-composer-tag-input/ai-tag-suggestion.gjs
@@ -13,6 +13,6 @@ export default class AiTagSuggestion extends Component {
}
-
+
}
diff --git a/assets/javascripts/discourse/connectors/edit-topic-category__after/ai-category-suggestion.gjs b/assets/javascripts/discourse/connectors/edit-topic-category__after/ai-category-suggestion.gjs
index c1f5c0c0..1dcf3483 100644
--- a/assets/javascripts/discourse/connectors/edit-topic-category__after/ai-category-suggestion.gjs
+++ b/assets/javascripts/discourse/connectors/edit-topic-category__after/ai-category-suggestion.gjs
@@ -13,6 +13,9 @@ export default class AiCategorySuggestion extends Component {
}
-
+
}
diff --git a/assets/javascripts/discourse/connectors/edit-topic-tags__after/ai-tag-suggestion.gjs b/assets/javascripts/discourse/connectors/edit-topic-tags__after/ai-tag-suggestion.gjs
index 3ab8656f..7404822b 100644
--- a/assets/javascripts/discourse/connectors/edit-topic-tags__after/ai-tag-suggestion.gjs
+++ b/assets/javascripts/discourse/connectors/edit-topic-tags__after/ai-tag-suggestion.gjs
@@ -13,6 +13,6 @@ export default class AiCategorySuggestion extends Component {
}
-
+
}
diff --git a/lib/ai_helper/semantic_categorizer.rb b/lib/ai_helper/semantic_categorizer.rb
index b05c3ece..488741de 100644
--- a/lib/ai_helper/semantic_categorizer.rb
+++ b/lib/ai_helper/semantic_categorizer.rb
@@ -2,15 +2,16 @@
module DiscourseAi
module AiHelper
class SemanticCategorizer
- def initialize(input, user)
+ def initialize(user, opts)
@user = user
- @text = input[:text]
+ @text = opts[:text]
@vector = DiscourseAi::Embeddings::Vector.instance
@schema = DiscourseAi::Embeddings::Schema.for(Topic)
+ @topic_id = opts[:topic_id]
end
def categories
- return [] if @text.blank?
+ return [] if @text.blank? && @topic_id.nil?
return [] if !DiscourseAi::Embeddings.enabled?
candidates = nearest_neighbors
@@ -55,7 +56,7 @@ module DiscourseAi
end
def tags
- return [] if @text.blank?
+ return [] if @text.blank? && @topic_id.nil?
return [] if !DiscourseAi::Embeddings.enabled?
candidates = nearest_neighbors(limit: 100)
@@ -100,7 +101,23 @@ module DiscourseAi
private
def nearest_neighbors(limit: 50)
- raw_vector = @vector.vector_from(@text)
+ if @topic_id
+ target = Topic.find_by(id: @topic_id)
+ embeddings = @schema.find_by_target(target)&.embeddings
+
+ if embeddings.blank?
+ @text =
+ DiscourseAi::Summarization::Strategies::TopicSummary
+ .new(target)
+ .targets_data
+ .pluck(:text)
+ raw_vector = @vector.vector_from(@text)
+ else
+ raw_vector = JSON.parse(embeddings)
+ end
+ else
+ raw_vector = @vector.vector_from(@text)
+ end
muted_category_ids = nil
if @user.present?
diff --git a/spec/lib/modules/ai_helper/semantic_categorizer_spec.rb b/spec/lib/modules/ai_helper/semantic_categorizer_spec.rb
index bbbfe6af..4390959b 100644
--- a/spec/lib/modules/ai_helper/semantic_categorizer_spec.rb
+++ b/spec/lib/modules/ai_helper/semantic_categorizer_spec.rb
@@ -16,7 +16,7 @@ RSpec.describe DiscourseAi::AiHelper::SemanticCategorizer do
fab!(:topic) { Fabricate(:topic, category: category) }
let(:vector) { DiscourseAi::Embeddings::Vector.instance }
- let(:categorizer) { DiscourseAi::AiHelper::SemanticCategorizer.new({ text: "hello" }, user) }
+ let(:categorizer) { DiscourseAi::AiHelper::SemanticCategorizer.new(user, { text: "hello" }) }
let(:expected_embedding) { [0.0038493] * vector.vdef.dimensions }
before do