From 6aaf8a061909d71873e40d13d247a3ead7a26d22 Mon Sep 17 00:00:00 2001 From: Keegan George Date: Wed, 12 Mar 2025 18:52:07 -0700 Subject: [PATCH] DEV: Use existing topic embeddings when suggesting tags/categories on edit (#1189) When editing a topic (instead of creating one) and using the tag/category suggestion buttons. We want to use existing topic embeddings instead of creating new ones. --- .../ai_helper/assistant_controller.rb | 24 +++++++------ .../ai-category-suggester.gjs | 32 +++++++---------- .../suggestion-menus/ai-tag-suggester.gjs | 34 ++++++++----------- .../ai-category-suggestion.gjs | 2 +- .../ai-tag-suggestion.gjs | 2 +- .../ai-category-suggestion.gjs | 5 ++- .../ai-tag-suggestion.gjs | 2 +- lib/ai_helper/semantic_categorizer.rb | 27 ++++++++++++--- .../ai_helper/semantic_categorizer_spec.rb | 2 +- 9 files changed, 72 insertions(+), 58 deletions(-) diff --git a/app/controllers/discourse_ai/ai_helper/assistant_controller.rb b/app/controllers/discourse_ai/ai_helper/assistant_controller.rb index 0d0f85ee..d2ee8738 100644 --- a/app/controllers/discourse_ai/ai_helper/assistant_controller.rb +++ b/app/controllers/discourse_ai/ai_helper/assistant_controller.rb @@ -78,22 +78,26 @@ module DiscourseAi end def suggest_category - input = get_text_param! - input_hash = { text: input } + if params[:topic_id] + opts = { topic_id: params[:topic_id] } + else + input = get_text_param! + opts = { text: input } + end - render json: - DiscourseAi::AiHelper::SemanticCategorizer.new( - input_hash, - current_user, - ).categories, + render json: DiscourseAi::AiHelper::SemanticCategorizer.new(current_user, opts).categories, status: 200 end def suggest_tags - input = get_text_param! - input_hash = { text: input } + if params[:topic_id] + opts = { topic_id: params[:topic_id] } + else + input = get_text_param! + opts = { text: input } + end - render json: DiscourseAi::AiHelper::SemanticCategorizer.new(input_hash, current_user).tags, + render json: DiscourseAi::AiHelper::SemanticCategorizer.new(current_user, opts).tags, status: 200 end diff --git a/assets/javascripts/discourse/components/suggestion-menus/ai-category-suggester.gjs b/assets/javascripts/discourse/components/suggestion-menus/ai-category-suggester.gjs index 18051b52..ff44af91 100644 --- a/assets/javascripts/discourse/components/suggestion-menus/ai-category-suggester.gjs +++ b/assets/javascripts/discourse/components/suggestion-menus/ai-category-suggester.gjs @@ -20,27 +20,13 @@ export default class AiCategorySuggester extends Component { @tracked untriggers = []; @tracked triggerIcon = "discourse-sparkles"; @tracked content = null; - @tracked topicContent = null; - - constructor() { - super(...arguments); - if (!this.topicContent && this.args.composer?.reply === undefined) { - this.fetchTopicContent(); - } - } - - async fetchTopicContent() { - await ajax(`/t/${this.args.buffered.content.id}.json`).then( - ({ post_stream }) => { - this.topicContent = post_stream.posts[0].cooked; - } - ); - } get showSuggestionButton() { const composerFields = document.querySelector(".composer-fields"); - this.content = this.args.composer?.reply || this.topicContent; - const showTrigger = this.content?.length > MIN_CHARACTER_COUNT; + this.content = this.args.composer?.reply; + const showTrigger = + this.content?.length > MIN_CHARACTER_COUNT || + this.args.topicState === "edit"; if (composerFields) { if (showTrigger) { @@ -62,12 +48,20 @@ export default class AiCategorySuggester extends Component { this.loading = true; this.triggerIcon = "spinner"; + const data = {}; + + if (this.content) { + data.text = this.content; + } else { + data.topic_id = this.args.buffered.content.id; + } + try { const { assistant } = await ajax( "/discourse-ai/ai-helper/suggest_category", { method: "POST", - data: { text: this.content }, + data, } ); this.suggestions = assistant; diff --git a/assets/javascripts/discourse/components/suggestion-menus/ai-tag-suggester.gjs b/assets/javascripts/discourse/components/suggestion-menus/ai-tag-suggester.gjs index 47ad38f6..84ce1958 100644 --- a/assets/javascripts/discourse/components/suggestion-menus/ai-tag-suggester.gjs +++ b/assets/javascripts/discourse/components/suggestion-menus/ai-tag-suggester.gjs @@ -21,27 +21,13 @@ export default class AiTagSuggester extends Component { @tracked untriggers = []; @tracked triggerIcon = "discourse-sparkles"; @tracked content = null; - @tracked topicContent = null; - - constructor() { - super(...arguments); - if (!this.topicContent && this.args.composer?.reply === undefined) { - this.fetchTopicContent(); - } - } - - async fetchTopicContent() { - await ajax(`/t/${this.args.buffered.content.id}.json`).then( - ({ post_stream }) => { - this.topicContent = post_stream.posts[0].cooked; - } - ); - } get showSuggestionButton() { const composerFields = document.querySelector(".composer-fields"); - this.content = this.args.composer?.reply || this.topicContent; - const showTrigger = this.content?.length > MIN_CHARACTER_COUNT; + this.content = this.args.composer?.reply; + const showTrigger = + this.content?.length > MIN_CHARACTER_COUNT || + this.args.topicState === "edit"; if (composerFields) { if (showTrigger) { @@ -74,15 +60,25 @@ export default class AiTagSuggester extends Component { this.loading = true; this.triggerIcon = "spinner"; + const data = {}; + + if (this.content) { + data.text = this.content; + } else { + data.topic_id = this.args.buffered.content.id; + } + try { const { assistant } = await ajax("/discourse-ai/ai-helper/suggest_tags", { method: "POST", - data: { text: this.content }, + data, }); this.suggestions = assistant; + const model = this.args.composer ? this.args.composer : this.args.buffered; + if (this.#tagSelectorHasValues()) { this.suggestions = this.suggestions.filter( (s) => !model.get("tags").includes(s.name) diff --git a/assets/javascripts/discourse/connectors/after-composer-category-input/ai-category-suggestion.gjs b/assets/javascripts/discourse/connectors/after-composer-category-input/ai-category-suggestion.gjs index d7bef642..47cc9d88 100644 --- a/assets/javascripts/discourse/connectors/after-composer-category-input/ai-category-suggestion.gjs +++ b/assets/javascripts/discourse/connectors/after-composer-category-input/ai-category-suggestion.gjs @@ -13,6 +13,6 @@ export default class AiCategorySuggestion extends Component { } } diff --git a/assets/javascripts/discourse/connectors/after-composer-tag-input/ai-tag-suggestion.gjs b/assets/javascripts/discourse/connectors/after-composer-tag-input/ai-tag-suggestion.gjs index ac6ad686..9f02ed19 100644 --- a/assets/javascripts/discourse/connectors/after-composer-tag-input/ai-tag-suggestion.gjs +++ b/assets/javascripts/discourse/connectors/after-composer-tag-input/ai-tag-suggestion.gjs @@ -13,6 +13,6 @@ export default class AiTagSuggestion extends Component { } } diff --git a/assets/javascripts/discourse/connectors/edit-topic-category__after/ai-category-suggestion.gjs b/assets/javascripts/discourse/connectors/edit-topic-category__after/ai-category-suggestion.gjs index c1f5c0c0..1dcf3483 100644 --- a/assets/javascripts/discourse/connectors/edit-topic-category__after/ai-category-suggestion.gjs +++ b/assets/javascripts/discourse/connectors/edit-topic-category__after/ai-category-suggestion.gjs @@ -13,6 +13,9 @@ export default class AiCategorySuggestion extends Component { } } diff --git a/assets/javascripts/discourse/connectors/edit-topic-tags__after/ai-tag-suggestion.gjs b/assets/javascripts/discourse/connectors/edit-topic-tags__after/ai-tag-suggestion.gjs index 3ab8656f..7404822b 100644 --- a/assets/javascripts/discourse/connectors/edit-topic-tags__after/ai-tag-suggestion.gjs +++ b/assets/javascripts/discourse/connectors/edit-topic-tags__after/ai-tag-suggestion.gjs @@ -13,6 +13,6 @@ export default class AiCategorySuggestion extends Component { } } diff --git a/lib/ai_helper/semantic_categorizer.rb b/lib/ai_helper/semantic_categorizer.rb index b05c3ece..488741de 100644 --- a/lib/ai_helper/semantic_categorizer.rb +++ b/lib/ai_helper/semantic_categorizer.rb @@ -2,15 +2,16 @@ module DiscourseAi module AiHelper class SemanticCategorizer - def initialize(input, user) + def initialize(user, opts) @user = user - @text = input[:text] + @text = opts[:text] @vector = DiscourseAi::Embeddings::Vector.instance @schema = DiscourseAi::Embeddings::Schema.for(Topic) + @topic_id = opts[:topic_id] end def categories - return [] if @text.blank? + return [] if @text.blank? && @topic_id.nil? return [] if !DiscourseAi::Embeddings.enabled? candidates = nearest_neighbors @@ -55,7 +56,7 @@ module DiscourseAi end def tags - return [] if @text.blank? + return [] if @text.blank? && @topic_id.nil? return [] if !DiscourseAi::Embeddings.enabled? candidates = nearest_neighbors(limit: 100) @@ -100,7 +101,23 @@ module DiscourseAi private def nearest_neighbors(limit: 50) - raw_vector = @vector.vector_from(@text) + if @topic_id + target = Topic.find_by(id: @topic_id) + embeddings = @schema.find_by_target(target)&.embeddings + + if embeddings.blank? + @text = + DiscourseAi::Summarization::Strategies::TopicSummary + .new(target) + .targets_data + .pluck(:text) + raw_vector = @vector.vector_from(@text) + else + raw_vector = JSON.parse(embeddings) + end + else + raw_vector = @vector.vector_from(@text) + end muted_category_ids = nil if @user.present? diff --git a/spec/lib/modules/ai_helper/semantic_categorizer_spec.rb b/spec/lib/modules/ai_helper/semantic_categorizer_spec.rb index bbbfe6af..4390959b 100644 --- a/spec/lib/modules/ai_helper/semantic_categorizer_spec.rb +++ b/spec/lib/modules/ai_helper/semantic_categorizer_spec.rb @@ -16,7 +16,7 @@ RSpec.describe DiscourseAi::AiHelper::SemanticCategorizer do fab!(:topic) { Fabricate(:topic, category: category) } let(:vector) { DiscourseAi::Embeddings::Vector.instance } - let(:categorizer) { DiscourseAi::AiHelper::SemanticCategorizer.new({ text: "hello" }, user) } + let(:categorizer) { DiscourseAi::AiHelper::SemanticCategorizer.new(user, { text: "hello" }) } let(:expected_embedding) { [0.0038493] * vector.vdef.dimensions } before do