DEV: Use existing topic embeddings when suggesting tags/categories on edit (#1189)

When editing a topic (instead of creating one) and using the
tag/category suggestion buttons. We want to use existing topic
embeddings instead of creating new ones.
This commit is contained in:
Keegan George 2025-03-12 18:52:07 -07:00 committed by GitHub
parent ac29d3080f
commit 6aaf8a0619
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 72 additions and 58 deletions

View File

@ -78,22 +78,26 @@ module DiscourseAi
end end
def suggest_category def suggest_category
input = get_text_param! if params[:topic_id]
input_hash = { text: input } opts = { topic_id: params[:topic_id] }
else
input = get_text_param!
opts = { text: input }
end
render json: render json: DiscourseAi::AiHelper::SemanticCategorizer.new(current_user, opts).categories,
DiscourseAi::AiHelper::SemanticCategorizer.new(
input_hash,
current_user,
).categories,
status: 200 status: 200
end end
def suggest_tags def suggest_tags
input = get_text_param! if params[:topic_id]
input_hash = { text: input } opts = { topic_id: params[:topic_id] }
else
input = get_text_param!
opts = { text: input }
end
render json: DiscourseAi::AiHelper::SemanticCategorizer.new(input_hash, current_user).tags, render json: DiscourseAi::AiHelper::SemanticCategorizer.new(current_user, opts).tags,
status: 200 status: 200
end end

View File

@ -20,27 +20,13 @@ export default class AiCategorySuggester extends Component {
@tracked untriggers = []; @tracked untriggers = [];
@tracked triggerIcon = "discourse-sparkles"; @tracked triggerIcon = "discourse-sparkles";
@tracked content = null; @tracked content = null;
@tracked topicContent = null;
constructor() {
super(...arguments);
if (!this.topicContent && this.args.composer?.reply === undefined) {
this.fetchTopicContent();
}
}
async fetchTopicContent() {
await ajax(`/t/${this.args.buffered.content.id}.json`).then(
({ post_stream }) => {
this.topicContent = post_stream.posts[0].cooked;
}
);
}
get showSuggestionButton() { get showSuggestionButton() {
const composerFields = document.querySelector(".composer-fields"); const composerFields = document.querySelector(".composer-fields");
this.content = this.args.composer?.reply || this.topicContent; this.content = this.args.composer?.reply;
const showTrigger = this.content?.length > MIN_CHARACTER_COUNT; const showTrigger =
this.content?.length > MIN_CHARACTER_COUNT ||
this.args.topicState === "edit";
if (composerFields) { if (composerFields) {
if (showTrigger) { if (showTrigger) {
@ -62,12 +48,20 @@ export default class AiCategorySuggester extends Component {
this.loading = true; this.loading = true;
this.triggerIcon = "spinner"; this.triggerIcon = "spinner";
const data = {};
if (this.content) {
data.text = this.content;
} else {
data.topic_id = this.args.buffered.content.id;
}
try { try {
const { assistant } = await ajax( const { assistant } = await ajax(
"/discourse-ai/ai-helper/suggest_category", "/discourse-ai/ai-helper/suggest_category",
{ {
method: "POST", method: "POST",
data: { text: this.content }, data,
} }
); );
this.suggestions = assistant; this.suggestions = assistant;

View File

@ -21,27 +21,13 @@ export default class AiTagSuggester extends Component {
@tracked untriggers = []; @tracked untriggers = [];
@tracked triggerIcon = "discourse-sparkles"; @tracked triggerIcon = "discourse-sparkles";
@tracked content = null; @tracked content = null;
@tracked topicContent = null;
constructor() {
super(...arguments);
if (!this.topicContent && this.args.composer?.reply === undefined) {
this.fetchTopicContent();
}
}
async fetchTopicContent() {
await ajax(`/t/${this.args.buffered.content.id}.json`).then(
({ post_stream }) => {
this.topicContent = post_stream.posts[0].cooked;
}
);
}
get showSuggestionButton() { get showSuggestionButton() {
const composerFields = document.querySelector(".composer-fields"); const composerFields = document.querySelector(".composer-fields");
this.content = this.args.composer?.reply || this.topicContent; this.content = this.args.composer?.reply;
const showTrigger = this.content?.length > MIN_CHARACTER_COUNT; const showTrigger =
this.content?.length > MIN_CHARACTER_COUNT ||
this.args.topicState === "edit";
if (composerFields) { if (composerFields) {
if (showTrigger) { if (showTrigger) {
@ -74,15 +60,25 @@ export default class AiTagSuggester extends Component {
this.loading = true; this.loading = true;
this.triggerIcon = "spinner"; this.triggerIcon = "spinner";
const data = {};
if (this.content) {
data.text = this.content;
} else {
data.topic_id = this.args.buffered.content.id;
}
try { try {
const { assistant } = await ajax("/discourse-ai/ai-helper/suggest_tags", { const { assistant } = await ajax("/discourse-ai/ai-helper/suggest_tags", {
method: "POST", method: "POST",
data: { text: this.content }, data,
}); });
this.suggestions = assistant; this.suggestions = assistant;
const model = this.args.composer const model = this.args.composer
? this.args.composer ? this.args.composer
: this.args.buffered; : this.args.buffered;
if (this.#tagSelectorHasValues()) { if (this.#tagSelectorHasValues()) {
this.suggestions = this.suggestions.filter( this.suggestions = this.suggestions.filter(
(s) => !model.get("tags").includes(s.name) (s) => !model.get("tags").includes(s.name)

View File

@ -13,6 +13,6 @@ export default class AiCategorySuggestion extends Component {
} }
<template> <template>
<AiCategorySuggester @composer={{@outletArgs.composer}} /> <AiCategorySuggester @composer={{@outletArgs.composer}} @topicState="new" />
</template> </template>
} }

View File

@ -13,6 +13,6 @@ export default class AiTagSuggestion extends Component {
} }
<template> <template>
<AiTagSuggester @composer={{@outletArgs.composer}} /> <AiTagSuggester @composer={{@outletArgs.composer}} @topicState="new" />
</template> </template>
} }

View File

@ -13,6 +13,9 @@ export default class AiCategorySuggestion extends Component {
} }
<template> <template>
<AiCategorySuggester @buffered={{@outletArgs.buffered}} /> <AiCategorySuggester
@buffered={{@outletArgs.buffered}}
@topicState="edit"
/>
</template> </template>
} }

View File

@ -13,6 +13,6 @@ export default class AiCategorySuggestion extends Component {
} }
<template> <template>
<AiTagSuggester @buffered={{@outletArgs.buffered}} /> <AiTagSuggester @buffered={{@outletArgs.buffered}} @topicState="edit" />
</template> </template>
} }

View File

@ -2,15 +2,16 @@
module DiscourseAi module DiscourseAi
module AiHelper module AiHelper
class SemanticCategorizer class SemanticCategorizer
def initialize(input, user) def initialize(user, opts)
@user = user @user = user
@text = input[:text] @text = opts[:text]
@vector = DiscourseAi::Embeddings::Vector.instance @vector = DiscourseAi::Embeddings::Vector.instance
@schema = DiscourseAi::Embeddings::Schema.for(Topic) @schema = DiscourseAi::Embeddings::Schema.for(Topic)
@topic_id = opts[:topic_id]
end end
def categories def categories
return [] if @text.blank? return [] if @text.blank? && @topic_id.nil?
return [] if !DiscourseAi::Embeddings.enabled? return [] if !DiscourseAi::Embeddings.enabled?
candidates = nearest_neighbors candidates = nearest_neighbors
@ -55,7 +56,7 @@ module DiscourseAi
end end
def tags def tags
return [] if @text.blank? return [] if @text.blank? && @topic_id.nil?
return [] if !DiscourseAi::Embeddings.enabled? return [] if !DiscourseAi::Embeddings.enabled?
candidates = nearest_neighbors(limit: 100) candidates = nearest_neighbors(limit: 100)
@ -100,7 +101,23 @@ module DiscourseAi
private private
def nearest_neighbors(limit: 50) def nearest_neighbors(limit: 50)
raw_vector = @vector.vector_from(@text) if @topic_id
target = Topic.find_by(id: @topic_id)
embeddings = @schema.find_by_target(target)&.embeddings
if embeddings.blank?
@text =
DiscourseAi::Summarization::Strategies::TopicSummary
.new(target)
.targets_data
.pluck(:text)
raw_vector = @vector.vector_from(@text)
else
raw_vector = JSON.parse(embeddings)
end
else
raw_vector = @vector.vector_from(@text)
end
muted_category_ids = nil muted_category_ids = nil
if @user.present? if @user.present?

View File

@ -16,7 +16,7 @@ RSpec.describe DiscourseAi::AiHelper::SemanticCategorizer do
fab!(:topic) { Fabricate(:topic, category: category) } fab!(:topic) { Fabricate(:topic, category: category) }
let(:vector) { DiscourseAi::Embeddings::Vector.instance } let(:vector) { DiscourseAi::Embeddings::Vector.instance }
let(:categorizer) { DiscourseAi::AiHelper::SemanticCategorizer.new({ text: "hello" }, user) } let(:categorizer) { DiscourseAi::AiHelper::SemanticCategorizer.new(user, { text: "hello" }) }
let(:expected_embedding) { [0.0038493] * vector.vdef.dimensions } let(:expected_embedding) { [0.0038493] * vector.vdef.dimensions }
before do before do