DEV: Use existing topic embeddings when suggesting tags/categories on edit (#1189)

When editing a topic (instead of creating one) and using the
tag/category suggestion buttons, we want to use the existing topic
embeddings instead of creating new ones.
This commit is contained in:
Keegan George 2025-03-12 18:52:07 -07:00 committed by GitHub
parent ac29d3080f
commit 6aaf8a0619
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 72 additions and 58 deletions

View File

@ -78,22 +78,26 @@ module DiscourseAi
end
# Renders category suggestions from SemanticCategorizer as JSON with status 200.
#
# When params[:topic_id] is present it is passed through as opts[:topic_id];
# otherwise the value returned by get_text_param! is passed as opts[:text].
#
# NOTE(review): this span is rendered diff output without +/- markers, so the
# input_hash lines and the opts lines appear side by side — presumably the
# input_hash form is the removed side of the change; confirm against the repo.
# NOTE(review): no permission check on params[:topic_id] is visible in this
# action — confirm the current user may see that topic before suggestions
# derived from it are returned.
def suggest_category
input = get_text_param!
input_hash = { text: input }
if params[:topic_id]
opts = { topic_id: params[:topic_id] }
else
input = get_text_param!
opts = { text: input }
end
render json:
DiscourseAi::AiHelper::SemanticCategorizer.new(
input_hash,
current_user,
).categories,
render json: DiscourseAi::AiHelper::SemanticCategorizer.new(current_user, opts).categories,
status: 200
end
# Renders tag suggestions from SemanticCategorizer as JSON with status 200.
#
# When params[:topic_id] is present it is passed through as opts[:topic_id];
# otherwise the value returned by get_text_param! is passed as opts[:text].
#
# NOTE(review): this span is rendered diff output without +/- markers, so the
# input_hash lines and the opts lines appear side by side — presumably the
# input_hash form is the removed side of the change; confirm against the repo.
# NOTE(review): no permission check on params[:topic_id] is visible in this
# action — confirm the current user may see that topic before suggestions
# derived from it are returned.
def suggest_tags
input = get_text_param!
input_hash = { text: input }
if params[:topic_id]
opts = { topic_id: params[:topic_id] }
else
input = get_text_param!
opts = { text: input }
end
render json: DiscourseAi::AiHelper::SemanticCategorizer.new(input_hash, current_user).tags,
render json: DiscourseAi::AiHelper::SemanticCategorizer.new(current_user, opts).tags,
status: 200
end

View File

@ -20,27 +20,13 @@ export default class AiCategorySuggester extends Component {
@tracked untriggers = [];
@tracked triggerIcon = "discourse-sparkles";
@tracked content = null;
@tracked topicContent = null;
constructor() {
super(...arguments);
if (!this.topicContent && this.args.composer?.reply === undefined) {
this.fetchTopicContent();
}
}
async fetchTopicContent() {
await ajax(`/t/${this.args.buffered.content.id}.json`).then(
({ post_stream }) => {
this.topicContent = post_stream.posts[0].cooked;
}
);
}
get showSuggestionButton() {
const composerFields = document.querySelector(".composer-fields");
this.content = this.args.composer?.reply || this.topicContent;
const showTrigger = this.content?.length > MIN_CHARACTER_COUNT;
this.content = this.args.composer?.reply;
const showTrigger =
this.content?.length > MIN_CHARACTER_COUNT ||
this.args.topicState === "edit";
if (composerFields) {
if (showTrigger) {
@ -62,12 +48,20 @@ export default class AiCategorySuggester extends Component {
this.loading = true;
this.triggerIcon = "spinner";
const data = {};
if (this.content) {
data.text = this.content;
} else {
data.topic_id = this.args.buffered.content.id;
}
try {
const { assistant } = await ajax(
"/discourse-ai/ai-helper/suggest_category",
{
method: "POST",
data: { text: this.content },
data,
}
);
this.suggestions = assistant;

View File

@ -21,27 +21,13 @@ export default class AiTagSuggester extends Component {
@tracked untriggers = [];
@tracked triggerIcon = "discourse-sparkles";
@tracked content = null;
@tracked topicContent = null;
constructor() {
super(...arguments);
if (!this.topicContent && this.args.composer?.reply === undefined) {
this.fetchTopicContent();
}
}
async fetchTopicContent() {
await ajax(`/t/${this.args.buffered.content.id}.json`).then(
({ post_stream }) => {
this.topicContent = post_stream.posts[0].cooked;
}
);
}
get showSuggestionButton() {
const composerFields = document.querySelector(".composer-fields");
this.content = this.args.composer?.reply || this.topicContent;
const showTrigger = this.content?.length > MIN_CHARACTER_COUNT;
this.content = this.args.composer?.reply;
const showTrigger =
this.content?.length > MIN_CHARACTER_COUNT ||
this.args.topicState === "edit";
if (composerFields) {
if (showTrigger) {
@ -74,15 +60,25 @@ export default class AiTagSuggester extends Component {
this.loading = true;
this.triggerIcon = "spinner";
const data = {};
if (this.content) {
data.text = this.content;
} else {
data.topic_id = this.args.buffered.content.id;
}
try {
const { assistant } = await ajax("/discourse-ai/ai-helper/suggest_tags", {
method: "POST",
data: { text: this.content },
data,
});
this.suggestions = assistant;
const model = this.args.composer
? this.args.composer
: this.args.buffered;
if (this.#tagSelectorHasValues()) {
this.suggestions = this.suggestions.filter(
(s) => !model.get("tags").includes(s.name)

View File

@ -13,6 +13,6 @@ export default class AiCategorySuggestion extends Component {
}
<template>
<AiCategorySuggester @composer={{@outletArgs.composer}} />
<AiCategorySuggester @composer={{@outletArgs.composer}} @topicState="new" />
</template>
}

View File

@ -13,6 +13,6 @@ export default class AiTagSuggestion extends Component {
}
<template>
<AiTagSuggester @composer={{@outletArgs.composer}} />
<AiTagSuggester @composer={{@outletArgs.composer}} @topicState="new" />
</template>
}

View File

@ -13,6 +13,9 @@ export default class AiCategorySuggestion extends Component {
}
<template>
<AiCategorySuggester @buffered={{@outletArgs.buffered}} />
<AiCategorySuggester
@buffered={{@outletArgs.buffered}}
@topicState="edit"
/>
</template>
}

View File

@ -13,6 +13,6 @@ export default class AiCategorySuggestion extends Component {
}
<template>
<AiTagSuggester @buffered={{@outletArgs.buffered}} />
<AiTagSuggester @buffered={{@outletArgs.buffered}} @topicState="edit" />
</template>
}

View File

@ -2,15 +2,16 @@
module DiscourseAi
module AiHelper
class SemanticCategorizer
# Builds a categorizer for a user plus an options hash.
#
# Reads opts[:text] into @text and opts[:topic_id] into @topic_id, and keeps
# the shared DiscourseAi::Embeddings::Vector instance and the embeddings
# schema for Topic.
#
# NOTE(review): this span is rendered diff output without +/- markers, so two
# signatures — (input, user) and (user, opts) — and two @text assignments
# appear together. Presumably (user, opts) is the current form; note that the
# positional order of the arguments differs between the two, so any caller
# still using the (input, user) order needs updating — confirm.
def initialize(input, user)
def initialize(user, opts)
@user = user
@text = input[:text]
@text = opts[:text]
@vector = DiscourseAi::Embeddings::Vector.instance
@schema = DiscourseAi::Embeddings::Schema.for(Topic)
@topic_id = opts[:topic_id]
end
def categories
return [] if @text.blank?
return [] if @text.blank? && @topic_id.nil?
return [] if !DiscourseAi::Embeddings.enabled?
candidates = nearest_neighbors
@ -55,7 +56,7 @@ module DiscourseAi
end
def tags
return [] if @text.blank?
return [] if @text.blank? && @topic_id.nil?
return [] if !DiscourseAi::Embeddings.enabled?
candidates = nearest_neighbors(limit: 100)
@ -100,7 +101,23 @@ module DiscourseAi
private
def nearest_neighbors(limit: 50)
raw_vector = @vector.vector_from(@text)
if @topic_id
target = Topic.find_by(id: @topic_id)
embeddings = @schema.find_by_target(target)&.embeddings
if embeddings.blank?
@text =
DiscourseAi::Summarization::Strategies::TopicSummary
.new(target)
.targets_data
.pluck(:text)
raw_vector = @vector.vector_from(@text)
else
raw_vector = JSON.parse(embeddings)
end
else
raw_vector = @vector.vector_from(@text)
end
muted_category_ids = nil
if @user.present?

View File

@ -16,7 +16,7 @@ RSpec.describe DiscourseAi::AiHelper::SemanticCategorizer do
fab!(:topic) { Fabricate(:topic, category: category) }
let(:vector) { DiscourseAi::Embeddings::Vector.instance }
let(:categorizer) { DiscourseAi::AiHelper::SemanticCategorizer.new({ text: "hello" }, user) }
let(:categorizer) { DiscourseAi::AiHelper::SemanticCategorizer.new(user, { text: "hello" }) }
let(:expected_embedding) { [0.0038493] * vector.vdef.dimensions }
before do