From 2c0f535babd41b0681827f7b67fc48ab1d6b454d Mon Sep 17 00:00:00 2001
From: Rafael dos Santos Silva
Date: Tue, 5 Sep 2023 11:08:23 -0300
Subject: [PATCH] FEATURE: HyDE-powered semantic search. (#136)

* FEATURE: HyDE-powered semantic search.

It relies on the new outlet added in discourse/discourse#23390 to display
semantic search results in an unobtrusive way.

We use a HyDE-backed approach for semantic search, which consists of
generating a hypothetical document from the given keywords; that document
is then transformed into a vector and used in an asymmetric similarity
search against topics.
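Concretely, the flow is roughly the following (a sketch using the classes
introduced in this PR; the controller-side caching and rate limiting are
omitted, and "query" stands for the user's search string):

    # 1. Have the configured LLM write a fake ("hypothetical") forum post
    #    for the user's query string.
    hyde_post =
      DiscourseAi::Embeddings::HydeGenerators::Base.current_hyde_model.new
        .hypothetical_post_from(query)

    # 2. Turn that post into a vector with the configured representation.
    strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
    vector_rep =
      DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
    embedding = vector_rep.vector_from(hyde_post)

    # 3. Use the vector in an asymmetric similarity search over topic embeddings.
    topic_ids =
      vector_rep.asymmetric_topics_similarity_search(embedding, limit: 50, offset: 0)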
This PR also reorganizes the internals to have fewer moving parts,
maintaining a single hierarchy of DAO-ish classes for vector-related
operations like transformations and querying.

Completions and vectors created by HyDE will remain cached in Redis for
now, but we could later move them to Postgres instead.

* Missing translation and rate limiting

---------

Co-authored-by: Roman Rizzi
---
 .../embeddings/embeddings_controller.rb       |  18 +-
 .../semantic-search.hbs                       |  51 ++++++
 .../semantic-search.js                        |  82 +++++++++
 .../initializers/semantic-full-page-search.js |  63 -------
 .../embeddings/common/semantic-search.scss    |  39 ++++
 config/locales/client.en.yml                  |   4 +
 config/locales/server.en.yml                  |   1 +
 config/settings.yml                           |  12 ++
 lib/modules/ai_helper/semantic_categorizer.rb |  11 +-
 lib/modules/embeddings/entry_point.rb         |  17 +-
 .../embeddings/hyde_generators/anthropic.rb   |  32 ++++
 .../embeddings/hyde_generators/base.rb        |  17 ++
 .../embeddings/hyde_generators/llama2.rb      |  34 ++++
 .../embeddings/hyde_generators/llama2_ftos.rb |  27 +++
 .../embeddings/hyde_generators/openai.rb      |  30 ++++
 .../jobs/regular/generate_embeddings.rb       |   6 +-
 lib/modules/embeddings/manager.rb             |  64 -------
 .../embeddings/models/all_mpnet_base_v2.rb    |  52 ------
 lib/modules/embeddings/models/base.rb         |  10 --
 .../models/multilingual_e5_large.rb           |  52 ------
 .../models/text_embedding_ada_002.rb          |  48 -----
 lib/modules/embeddings/semantic_related.rb    | 138 ++++++---------
 lib/modules/embeddings/semantic_search.rb     |  83 +++++----
 .../embeddings/semantic_topic_query.rb        |   2 +-
 .../embeddings/strategies/truncation.rb       |  74 +++-----
 .../all_mpnet_base_v2.rb                      |  50 ++++++
 .../embeddings/vector_representations/base.rb | 166 ++++++++++++++++++
 .../multilingual_e5_large.rb                  |  50 ++++++
 .../text_embedding_ada_002.rb                 |  46 +++++
 .../inference/hugging_face_text_generation.rb |   2 +-
 lib/tasks/modules/embeddings/database.rake    |  34 ++--
 plugin.rb                                     |   1 +
 spec/integration/embeddings/manager_spec.rb   |  44 -----
 .../modules/embeddings/entry_point_spec.rb    |  93 ----------
 .../models/all_mpnet_base_v2_spec.rb          |  24 ---
 .../models/text_embedding_ada_002_spec.rb     |  22 ---
 .../embeddings/semantic_related_spec.rb       |   6 +-
 .../embeddings/semantic_search_spec.rb        |  25 ++-
 .../embeddings/semantic_topic_query_spec.rb   |   7 +-
 .../embeddings/strategies/truncation_spec.rb  |  16 +-
 .../all_mpnet_base_v2_spec.rb                 |  18 ++
 .../multilingual_e5_large_spec.rb             |  22 +++
 .../text_embedding_ada_002_spec.rb            |  16 ++
 .../vector_rep_shared_examples.rb             |  54 ++++++
 spec/requests/topic_spec.rb                   |   7 +-
 45 files changed, 970 insertions(+), 700 deletions(-)
 create mode 100644 assets/javascripts/discourse/connectors/full-page-search-below-search-header/semantic-search.hbs
 create mode 100644 assets/javascripts/discourse/connectors/full-page-search-below-search-header/semantic-search.js
 delete mode 100644 assets/javascripts/initializers/semantic-full-page-search.js
 create mode 100644 assets/stylesheets/modules/embeddings/common/semantic-search.scss
 create mode 100644 lib/modules/embeddings/hyde_generators/anthropic.rb
 create mode 100644 lib/modules/embeddings/hyde_generators/base.rb
 create mode 100644 lib/modules/embeddings/hyde_generators/llama2.rb
 create mode 100644 lib/modules/embeddings/hyde_generators/llama2_ftos.rb
 create mode 100644 lib/modules/embeddings/hyde_generators/openai.rb
 delete mode 100644 lib/modules/embeddings/manager.rb
 delete mode 100644 lib/modules/embeddings/models/all_mpnet_base_v2.rb
 delete mode 100644 lib/modules/embeddings/models/base.rb
 delete mode 100644 lib/modules/embeddings/models/multilingual_e5_large.rb
 delete mode 100644 lib/modules/embeddings/models/text_embedding_ada_002.rb
 create mode 100644 lib/modules/embeddings/vector_representations/all_mpnet_base_v2.rb
 create mode 100644 lib/modules/embeddings/vector_representations/base.rb
 create mode 100644 lib/modules/embeddings/vector_representations/multilingual_e5_large.rb
 create mode 100644 lib/modules/embeddings/vector_representations/text_embedding_ada_002.rb
 delete mode 100644 spec/integration/embeddings/manager_spec.rb
 delete mode 100644 spec/lib/modules/embeddings/models/all_mpnet_base_v2_spec.rb
 delete mode 100644 spec/lib/modules/embeddings/models/text_embedding_ada_002_spec.rb
 create mode 100644 spec/lib/modules/embeddings/vector_representations/all_mpnet_base_v2_spec.rb
 create mode 100644 spec/lib/modules/embeddings/vector_representations/multilingual_e5_large_spec.rb
 create mode 100644 spec/lib/modules/embeddings/vector_representations/text_embedding_ada_002_spec.rb
 create mode 100644 spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb

diff --git a/app/controllers/discourse_ai/embeddings/embeddings_controller.rb b/app/controllers/discourse_ai/embeddings/embeddings_controller.rb
index 71a7c435..a202c709 100644
--- a/app/controllers/discourse_ai/embeddings/embeddings_controller.rb
+++ b/app/controllers/discourse_ai/embeddings/embeddings_controller.rb
@@ -9,7 +9,6 @@ module DiscourseAi
 
       def search
         query = params[:q]
-        page = (params[:page] || 1).to_i
 
         grouped_results =
           Search::GroupedSearchResults.new(
@@ -19,12 +18,19 @@ module DiscourseAi
             use_pg_headlines_for_excerpt: false,
           )
 
-        DiscourseAi::Embeddings::SemanticSearch
-          .new(guardian)
-          .search_for_topics(query, page)
-          .each { |topic_post| grouped_results.add(topic_post) }
+        semantic_search = DiscourseAi::Embeddings::SemanticSearch.new(guardian)
 
-        render_serialized(grouped_results, GroupedSearchResultSerializer, result: grouped_results)
+        if !semantic_search.cached_query?(query)
+          RateLimiter.new(current_user, "semantic-search", 4, 1.minutes).performed!
+        end
+
+        hijack do
+          semantic_search
+            .search_for_topics(query)
+            .each { |topic_post| grouped_results.add(topic_post) }
+
+          render_serialized(grouped_results, GroupedSearchResultSerializer, result: grouped_results)
+        end
       end
     end
   end
diff --git a/assets/javascripts/discourse/connectors/full-page-search-below-search-header/semantic-search.hbs b/assets/javascripts/discourse/connectors/full-page-search-below-search-header/semantic-search.hbs
new file mode 100644
index 00000000..9b71af92
--- /dev/null
+++ b/assets/javascripts/discourse/connectors/full-page-search-below-search-header/semantic-search.hbs
@@ -0,0 +1,51 @@
+{{#if this.searchEnabled}}
+  <div
+    class="semantic-search__container"
+    {{did-insert this.setup}}
+    {{will-destroy this.teardown}}
+  >
+    <div class="semantic-search__results">
+      {{#if this.searching}}
+        <div class="semantic-search__searching">
+          <span class="semantic-search__searching-text">
+            {{i18n "discourse_ai.embeddings.semantic_search_loading"}}
+          </span>
+          <span class="semantic-search__indicator-wave">
+            <span class="semantic-search__indicator-dot">.</span>
+            <span class="semantic-search__indicator-dot">.</span>
+            <span class="semantic-search__indicator-dot">.</span>
+          </span>
+        </div>
+      {{else}}
+        {{#if this.results.length}}
+          <div class="semantic-search__searched">
+            <DButton
+              @class="btn-link semantic-search__results-toggle"
+              @translatedLabel={{this.collapsedResultsTitle}}
+              @action={{fn (mut this.collapsedResults) (not this.collapsedResults)}}
+            />
+          </div>
+          {{#unless this.collapsedResults}}
+            <div class="semantic-search__entries">
+              <SearchResultEntries
+                @posts={{this.results}}
+                @highlightQuery={{this.searchTerm}}
+              />
+            </div>
+          {{/unless}}
+        {{else}}
+          <div class="semantic-search__searching">
+            {{i18n "discourse_ai.embeddings.semantic_search_results.none"}}
+          </div>
+        {{/if}}
+      {{/if}}
+    </div>
+  </div>
+{{/if}} \ No newline at end of file diff --git a/assets/javascripts/discourse/connectors/full-page-search-below-search-header/semantic-search.js b/assets/javascripts/discourse/connectors/full-page-search-below-search-header/semantic-search.js new file mode 100644 index 00000000..083994c3 --- /dev/null +++ b/assets/javascripts/discourse/connectors/full-page-search-below-search-header/semantic-search.js @@ -0,0 +1,82 @@ +import Component from "@glimmer/component"; +import { action, computed } from "@ember/object"; +import I18n from "I18n"; +import { tracked } from "@glimmer/tracking"; +import { ajax } from "discourse/lib/ajax"; +import { translateResults } from "discourse/lib/search"; +import discourseDebounce from "discourse-common/lib/debounce"; +import { inject as service } from "@ember/service"; +import { bind } from "discourse-common/utils/decorators"; +import { SEARCH_TYPE_DEFAULT } from "discourse/controllers/full-page-search"; + +export default class extends Component { + static shouldRender(_args, { siteSettings }) { + return siteSettings.ai_embeddings_semantic_search_enabled; + } + + @service appEvents; + + @tracked searching = false; + @tracked collapsedResults = true; + @tracked results = []; + + @computed("args.outletArgs.search") + get searchTerm() { + return this.args.outletArgs.search; + } + + @computed("args.outletArgs.type") + get searchEnabled() { + return this.args.outletArgs.type === SEARCH_TYPE_DEFAULT; + } + + @computed("results") + get collapsedResultsTitle() { + return I18n.t("discourse_ai.embeddings.semantic_search_results.toggle", { + count: this.results.length, + }); + } + + @action + setup() { + this.appEvents.on( + "full-page-search:trigger-search", + this, + "debouncedSearch" + ); + } + + @action + teardown() { + this.appEvents.off( + "full-page-search:trigger-search", + this, + "debouncedSearch" + ); + } + + @bind + performHyDESearch() { + if (!this.searchTerm || !this.searchEnabled || this.searching) { + return; + } + + this.searching = true; + this.collapsedResults = true; + this.results = []; + + ajax("/discourse-ai/embeddings/semantic-search", { + data: { q: this.searchTerm }, + }) + .then(async (results) => { + const model = (await translateResults(results)) || {}; + this.results = model.posts; + }) + .finally(() => (this.searching = false)); + } + + @action + debouncedSearch() { + discourseDebounce(this, this.performHyDESearch, 500); + } +} diff --git a/assets/javascripts/initializers/semantic-full-page-search.js b/assets/javascripts/initializers/semantic-full-page-search.js deleted file mode 100644 index 4db7f979..00000000 --- a/assets/javascripts/initializers/semantic-full-page-search.js +++ /dev/null @@ -1,63 +0,0 @@ -import { withPluginApi } from "discourse/lib/plugin-api"; -import { translateResults, updateRecentSearches } from "discourse/lib/search"; -import { ajax } from "discourse/lib/ajax"; - -const SEMANTIC_SEARCH = "semantic_search"; - -function initializeSemanticSearch(api) { - api.addFullPageSearchType( - "discourse_ai.embeddings.semantic_search", - SEMANTIC_SEARCH, - (searchController, args) => { - if (searchController.currentUser) { - updateRecentSearches(searchController.currentUser, args.searchTerm); - } - - ajax("/discourse-ai/embeddings/semantic-search", { data: args }) - .then(async (results) => { - const model = (await translateResults(results)) || {}; - - if (results.grouped_search_result) { - searchController.set("q", results.grouped_search_result.term); - } - - if (args.page > 1) { - if (model) { - 
searchController.model.posts.pushObjects(model.posts); - searchController.model.topics.pushObjects(model.topics); - searchController.model.set( - "grouped_search_result", - results.grouped_search_result - ); - } - } else { - model.grouped_search_result = results.grouped_search_result; - searchController.set("model", model); - } - searchController.set("error", null); - }) - .catch((e) => { - searchController.set("error", e.jqXHR.responseJSON?.message); - }) - .finally(() => { - searchController.setProperties({ - searching: false, - loading: false, - }); - }); - } - ); -} - -export default { - name: "discourse-ai-full-page-semantic-search", - - initialize(container) { - const settings = container.lookup("service:site-settings"); - const semanticSearch = settings.ai_embeddings_semantic_search_enabled; - - if (settings.ai_embeddings_enabled && semanticSearch) { - withPluginApi("1.6.0", initializeSemanticSearch); - } - }, -}; diff --git a/assets/stylesheets/modules/embeddings/common/semantic-search.scss b/assets/stylesheets/modules/embeddings/common/semantic-search.scss new file mode 100644 index 00000000..fcf09cb6 --- /dev/null +++ b/assets/stylesheets/modules/embeddings/common/semantic-search.scss @@ -0,0 +1,39 @@ +.semantic-search__container { + background: var(--primary-very-low); + margin: 1rem 0 1rem 0; + + .semantic-search__results { + display: flex; + flex-direction: column; + align-items: baseline; + + .semantic-search { + &__searching-text { + display: inline-block; + margin-left: 3px; + } + &__indicator-wave { + flex: 0 0 auto; + display: inline-flex; + } + &__indicator-dot { + display: inline-block; + animation: ai-summary__indicator-wave 1.8s linear infinite; + &:nth-child(2) { + animation-delay: -1.6s; + } + &:nth-child(3) { + animation-delay: -1.4s; + } + } + } + + .semantic-search__entries { + margin-top: 10px; + } + + .semantic-search__searching { + margin-left: 5px; + } + } +} diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml index f026b06a..271ea9e4 100644 --- a/config/locales/client.en.yml +++ b/config/locales/client.en.yml @@ -34,6 +34,10 @@ en: embeddings: semantic_search: "Topics (Semantic)" + semantic_search_loading: "Searching for more results using AI" + semantic_search_results: + toggle: "Found %{count} results using AI" + none: "Sorry, our AI search found no matching topics." ai_bot: pm_warning: "AI chatbot messages are monitored regularly by moderators." diff --git a/config/locales/server.en.yml b/config/locales/server.en.yml index 2b5ec6b3..987e5013 100644 --- a/config/locales/server.en.yml +++ b/config/locales/server.en.yml @@ -55,6 +55,7 @@ en: ai_embeddings_pg_connection_string: "PostgreSQL connection string for the embeddings module. Needs pgvector extension enabled and a series of tables created. See docs for more info." ai_embeddings_semantic_search_enabled: "Enable full-page semantic search." ai_embeddings_semantic_related_include_closed_topics: "Include closed topics in semantic search results" + ai_embeddings_semantic_search_hyde_model: "Model used to expand keywords to get better results during a semantic search" ai_summarization_discourse_service_api_endpoint: "URL where the Discourse summarization API is running." ai_summarization_discourse_service_api_key: "API key for the Discourse summarization API." 
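The enum site setting introduced below controls which LLM writes the
hypothetical post, and only the listed model names are accepted. Assuming a
standard Rails console for the site, switching models is a one-liner
(illustrative value):

    SiteSetting.ai_embeddings_semantic_search_hyde_model = "claude-2"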
diff --git a/config/settings.yml b/config/settings.yml
index 621b2635..6663a133 100644
--- a/config/settings.yml
+++ b/config/settings.yml
@@ -177,6 +177,18 @@ discourse_ai:
   ai_embeddings_semantic_search_enabled:
     default: false
     client: true
+  ai_embeddings_semantic_search_hyde_model:
+    default: "gpt-3.5-turbo"
+    type: enum
+    allow_any: false
+    choices:
+      - Llama2-*-chat-hf
+      - claude-instant-1
+      - claude-2
+      - gpt-3.5-turbo
+      - gpt-4
+      - StableBeluga2
+      - Upstage-Llama-2-*-instruct-v2
   ai_summarization_discourse_service_api_endpoint: ""
   ai_summarization_discourse_service_api_key:
diff --git a/lib/modules/ai_helper/semantic_categorizer.rb b/lib/modules/ai_helper/semantic_categorizer.rb
index 5acb3f5b..c3918c5d 100644
--- a/lib/modules/ai_helper/semantic_categorizer.rb
+++ b/lib/modules/ai_helper/semantic_categorizer.rb
@@ -11,13 +11,12 @@ module DiscourseAi
       return [] if @text.blank?
       return [] unless SiteSetting.ai_embeddings_enabled
 
+      strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
+      vector_rep =
+        DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
+
       candidates =
-        ::DiscourseAi::Embeddings::SemanticSearch.new(nil).asymmetric_semantic_search(
-          @text,
-          100,
-          0,
-          return_distance: true,
-        )
+        vector_rep.asymmetric_topics_similarity_search(
+          vector_rep.vector_from(@text),
+          limit: 100,
+          offset: 0,
+          return_distance: true,
+        )
 
       candidate_ids = candidates.map(&:first)
 
       ::Topic
diff --git a/lib/modules/embeddings/entry_point.rb b/lib/modules/embeddings/entry_point.rb
index bcbc7979..208e297b 100644
--- a/lib/modules/embeddings/entry_point.rb
+++ b/lib/modules/embeddings/entry_point.rb
@@ -4,16 +4,21 @@ module DiscourseAi
   module Embeddings
     class EntryPoint
       def load_files
-        require_relative "models/base"
-        require_relative "models/all_mpnet_base_v2"
-        require_relative "models/text_embedding_ada_002"
-        require_relative "models/multilingual_e5_large"
+        require_relative "vector_representations/base"
+        require_relative "vector_representations/all_mpnet_base_v2"
+        require_relative "vector_representations/text_embedding_ada_002"
+        require_relative "vector_representations/multilingual_e5_large"
         require_relative "strategies/truncation"
-        require_relative "manager"
         require_relative "jobs/regular/generate_embeddings"
         require_relative "semantic_related"
-        require_relative "semantic_search"
         require_relative "semantic_topic_query"
+
+        require_relative "hyde_generators/base"
+        require_relative "hyde_generators/openai"
+        require_relative "hyde_generators/anthropic"
+        require_relative "hyde_generators/llama2"
+        require_relative "hyde_generators/llama2_ftos"
+        require_relative "semantic_search"
       end
 
       def inject_into(plugin)
diff --git a/lib/modules/embeddings/hyde_generators/anthropic.rb b/lib/modules/embeddings/hyde_generators/anthropic.rb
new file mode 100644
index 00000000..693ea0dc
--- /dev/null
+++ b/lib/modules/embeddings/hyde_generators/anthropic.rb
@@ -0,0 +1,32 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Embeddings
+    module HydeGenerators
+      class Anthropic < DiscourseAi::Embeddings::HydeGenerators::Base
+        def prompt(search_term)
+          <<~TEXT
+            Given a search term given between <input> tags, generate a forum post about the search term.
+            Respond with the generated post between <ai> tags.
+
+            <input>#{search_term}</input>
+          TEXT
+        end
+
+        def models
+          %w[claude-instant-1 claude-2]
+        end
+
+        def hypothetical_post_from(query)
+          response =
+            ::DiscourseAi::Inference::AnthropicCompletions.perform!(
+              prompt(query),
+              SiteSetting.ai_embeddings_semantic_search_hyde_model,
+            ).dig(:completion)
+
+          Nokogiri::HTML5.fragment(response).at("ai").text
+        end
+      end
+    end
+  end
+end
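The <ai> wrapper requested from Claude makes the completion easy to parse:
the generator keeps only the text inside those tags and discards any
preamble around them. A self-contained sketch of that extraction, with a
made-up response string:

    require "nokogiri"

    # Hypothetical Claude completion; real responses wrap the post in <ai> tags.
    response = "Sure, here is the post! <ai>Tuning autovacuum is mostly about...</ai>"

    # The same call the generator performs:
    Nokogiri::HTML5.fragment(response).at("ai").text
    # => "Tuning autovacuum is mostly about..."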
diff --git a/lib/modules/embeddings/hyde_generators/base.rb b/lib/modules/embeddings/hyde_generators/base.rb
new file mode 100644
index 00000000..8514b414
--- /dev/null
+++ b/lib/modules/embeddings/hyde_generators/base.rb
@@ -0,0 +1,17 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Embeddings
+    module HydeGenerators
+      class Base
+        def self.current_hyde_model
+          DiscourseAi::Embeddings::HydeGenerators::Base.descendants.find do |generator_klass|
+            generator_klass.new.models.include?(
+              SiteSetting.ai_embeddings_semantic_search_hyde_model,
+            )
+          end
+        end
+      end
+    end
+  end
+end
diff --git a/lib/modules/embeddings/hyde_generators/llama2.rb b/lib/modules/embeddings/hyde_generators/llama2.rb
new file mode 100644
index 00000000..6a72bb8c
--- /dev/null
+++ b/lib/modules/embeddings/hyde_generators/llama2.rb
@@ -0,0 +1,34 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Embeddings
+    module HydeGenerators
+      class Llama2 < DiscourseAi::Embeddings::HydeGenerators::Base
+        def prompt(search_term)
+          <<~TEXT
+            [INST] <<SYS>>
+            You are a helpful bot
+            You create forum posts about a given topic
+            <</SYS>>
+
+            Topic: #{search_term}
+            [/INST]
+            Here is a forum post about the above topic:
+          TEXT
+        end
+
+        def models
+          ["Llama2-*-chat-hf"]
+        end
+
+        def hypothetical_post_from(query)
+          ::DiscourseAi::Inference::HuggingFaceTextGeneration.perform!(
+            prompt(query),
+            SiteSetting.ai_embeddings_semantic_search_hyde_model,
+            token_limit: 400,
+          ).dig(:generated_text)
+        end
+      end
+    end
+  end
+end
diff --git a/lib/modules/embeddings/hyde_generators/llama2_ftos.rb b/lib/modules/embeddings/hyde_generators/llama2_ftos.rb
new file mode 100644
index 00000000..fd4245ba
--- /dev/null
+++ b/lib/modules/embeddings/hyde_generators/llama2_ftos.rb
@@ -0,0 +1,27 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Embeddings
+    module HydeGenerators
+      class Llama2Ftos < DiscourseAi::Embeddings::HydeGenerators::Llama2
+        def prompt(search_term)
+          <<~TEXT
+            ### System:
+            You are a helpful bot
+            You create forum posts about a given topic
+
+            ### User:
+            Topic: #{search_term}
+
+            ### Assistant:
+            Here is a forum post about the above topic:
+          TEXT
+        end
+
+        def models
+          %w[StableBeluga2 Upstage-Llama-2-*-instruct-v2]
+        end
+      end
+    end
+  end
+end
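Generator selection is driven by HydeGenerators::Base.current_hyde_model
above: the first descendant whose models list contains the site setting
value wins. For example, in a Rails console with the plugin loaded
(illustrative):

    SiteSetting.ai_embeddings_semantic_search_hyde_model = "StableBeluga2"
    DiscourseAi::Embeddings::HydeGenerators::Base.current_hyde_model
    # => DiscourseAi::Embeddings::HydeGenerators::Llama2Ftos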
diff --git a/lib/modules/embeddings/hyde_generators/openai.rb b/lib/modules/embeddings/hyde_generators/openai.rb
new file mode 100644
index 00000000..f44ca8fe
--- /dev/null
+++ b/lib/modules/embeddings/hyde_generators/openai.rb
@@ -0,0 +1,30 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Embeddings
+    module HydeGenerators
+      class OpenAi < DiscourseAi::Embeddings::HydeGenerators::Base
+        def prompt(search_term)
+          [
+            {
+              role: "system",
+              content: "You are a helpful bot. You create forum posts about a given topic.",
+            },
+            { role: "user", content: "Create a forum post about the topic: #{search_term}" },
+          ]
+        end
+
+        def models
+          %w[gpt-3.5-turbo gpt-4]
+        end
+
+        def hypothetical_post_from(query)
+          ::DiscourseAi::Inference::OpenAiCompletions.perform!(
+            prompt(query),
+            SiteSetting.ai_embeddings_semantic_search_hyde_model,
+          ).dig(:choices, 0, :message, :content)
+        end
+      end
+    end
+  end
+end
diff --git a/lib/modules/embeddings/jobs/regular/generate_embeddings.rb b/lib/modules/embeddings/jobs/regular/generate_embeddings.rb
index 7d41cd30..919b43c6 100644
--- a/lib/modules/embeddings/jobs/regular/generate_embeddings.rb
+++ b/lib/modules/embeddings/jobs/regular/generate_embeddings.rb
@@ -11,7 +11,11 @@ module Jobs
       post = topic.first_post
       return if post.nil? || post.raw.blank?
 
-      DiscourseAi::Embeddings::Manager.new(topic).generate!
+      strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
+      vector_rep =
+        DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
+
+      vector_rep.generate_topic_representation_from(topic)
     end
   end
 end
diff --git a/lib/modules/embeddings/manager.rb b/lib/modules/embeddings/manager.rb
deleted file mode 100644
index 2a0c5aee..00000000
--- a/lib/modules/embeddings/manager.rb
+++ /dev/null
@@ -1,64 +0,0 @@
-# frozen_string_literal: true
-
-module DiscourseAi
-  module Embeddings
-    class Manager
-      attr_reader :target, :model, :strategy
-
-      def initialize(target)
-        @target = target
-        @model =
-          DiscourseAi::Embeddings::Models::Base.subclasses.find do
-            _1.name == SiteSetting.ai_embeddings_model
-          end
-        @strategy = DiscourseAi::Embeddings::Strategies::Truncation.new(@target, @model)
-      end
-
-      def generate!
-        @strategy.process!
-
-        # TODO bail here if we already have an embedding with matching version and digest
-
-        @embeddings = @model.generate_embeddings(@strategy.processed_target)
-
-        persist!
-      end
-
-      def persist!
- begin - DB.exec( - <<~SQL, - INSERT INTO ai_topic_embeddings_#{table_suffix} (topic_id, model_version, strategy_version, digest, embeddings, created_at, updated_at) - VALUES (:topic_id, :model_version, :strategy_version, :digest, '[:embeddings]', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) - ON CONFLICT (topic_id) - DO UPDATE SET - model_version = :model_version, - strategy_version = :strategy_version, - digest = :digest, - embeddings = '[:embeddings]', - updated_at = CURRENT_TIMESTAMP - - SQL - topic_id: @target.id, - model_version: @model.version, - strategy_version: @strategy.version, - digest: @strategy.digest, - embeddings: @embeddings, - ) - rescue PG::Error => e - Rails.logger.error( - "Error #{e} persisting embedding for topic #{topic.id} and model #{model.name}", - ) - end - end - - def table_suffix - "#{@model.id}_#{@strategy.id}" - end - - def topic_embeddings_table - "ai_topic_embeddings_#{table_suffix}" - end - end - end -end diff --git a/lib/modules/embeddings/models/all_mpnet_base_v2.rb b/lib/modules/embeddings/models/all_mpnet_base_v2.rb deleted file mode 100644 index 7160052a..00000000 --- a/lib/modules/embeddings/models/all_mpnet_base_v2.rb +++ /dev/null @@ -1,52 +0,0 @@ -# frozen_string_literal: true - -module DiscourseAi - module Embeddings - module Models - class AllMpnetBaseV2 < Base - class << self - def id - 1 - end - - def version - 1 - end - - def name - "all-mpnet-base-v2" - end - - def dimensions - 768 - end - - def max_sequence_length - 384 - end - - def pg_function - "<#>" - end - - def pg_index_type - "vector_ip_ops" - end - - def generate_embeddings(text) - DiscourseAi::Inference::DiscourseClassifier.perform!( - "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify", - name, - text, - SiteSetting.ai_embeddings_discourse_service_api_key, - ) - end - - def tokenizer - DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer - end - end - end - end - end -end diff --git a/lib/modules/embeddings/models/base.rb b/lib/modules/embeddings/models/base.rb deleted file mode 100644 index c888308a..00000000 --- a/lib/modules/embeddings/models/base.rb +++ /dev/null @@ -1,10 +0,0 @@ -# frozen_string_literal: true - -module DiscourseAi - module Embeddings - module Models - class Base - end - end - end -end diff --git a/lib/modules/embeddings/models/multilingual_e5_large.rb b/lib/modules/embeddings/models/multilingual_e5_large.rb deleted file mode 100644 index 4de1b4d4..00000000 --- a/lib/modules/embeddings/models/multilingual_e5_large.rb +++ /dev/null @@ -1,52 +0,0 @@ -# frozen_string_literal: true - -module DiscourseAi - module Embeddings - module Models - class MultilingualE5Large < Base - class << self - def id - 3 - end - - def version - 1 - end - - def name - "multilingual-e5-large" - end - - def dimensions - 1024 - end - - def max_sequence_length - 512 - end - - def pg_function - "<=>" - end - - def pg_index_type - "vector_cosine_ops" - end - - def generate_embeddings(text) - DiscourseAi::Inference::DiscourseClassifier.perform!( - "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify", - name, - "query: #{text}", - SiteSetting.ai_embeddings_discourse_service_api_key, - ) - end - - def tokenizer - DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer - end - end - end - end - end -end diff --git a/lib/modules/embeddings/models/text_embedding_ada_002.rb b/lib/modules/embeddings/models/text_embedding_ada_002.rb deleted file mode 100644 index 167418d4..00000000 --- a/lib/modules/embeddings/models/text_embedding_ada_002.rb +++ /dev/null @@ 
-1,48 +0,0 @@
-# frozen_string_literal: true
-
-module DiscourseAi
-  module Embeddings
-    module Models
-      class TextEmbeddingAda002 < Base
-        class << self
-          def id
-            2
-          end
-
-          def version
-            1
-          end
-
-          def name
-            "text-embedding-ada-002"
-          end
-
-          def dimensions
-            1536
-          end
-
-          def max_sequence_length
-            8191
-          end
-
-          def pg_function
-            "<=>"
-          end
-
-          def pg_index_type
-            "vector_cosine_ops"
-          end
-
-          def generate_embeddings(text)
-            response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text)
-            response[:data].first[:embedding]
-          end
-
-          def tokenizer
-            DiscourseAi::Tokenizer::OpenAiTokenizer
-          end
-        end
-      end
-    end
-  end
-end
diff --git a/lib/modules/embeddings/semantic_related.rb b/lib/modules/embeddings/semantic_related.rb
index 667866ad..683cf66d 100644
--- a/lib/modules/embeddings/semantic_related.rb
+++ b/lib/modules/embeddings/semantic_related.rb
@@ -5,101 +5,67 @@ module DiscourseAi
     class SemanticRelated
       MissingEmbeddingError = Class.new(StandardError)
 
-      class << self
-        def semantic_suggested_key(topic_id)
-          "semantic-suggested-topic-#{topic_id}"
-        end
+      def self.clear_cache_for(topic)
+        Discourse.cache.delete("semantic-suggested-topic-#{topic.id}")
+        Discourse.redis.del("build-semantic-suggested-topic-#{topic.id}")
+      end
 
-        def build_semantic_suggested_key(topic_id)
-          "build-semantic-suggested-topic-#{topic_id}"
-        end
+      def related_topic_ids_for(topic)
+        return [] if SiteSetting.ai_embeddings_semantic_related_topics < 1
 
-        def clear_cache_for(topic)
-          Discourse.cache.delete(semantic_suggested_key(topic.id))
-          Discourse.redis.del(build_semantic_suggested_key(topic.id))
-        end
+        strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
+        vector_rep =
+          DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
+        cache_for = results_ttl(topic)
 
-      def related_topic_ids_for(topic)
-        return [] if SiteSetting.ai_embeddings_semantic_related_topics < 1
-
-        manager = DiscourseAi::Embeddings::Manager.new(topic)
-        cache_for = results_ttl(topic)
-
-        begin
-          Discourse
-            .cache
-            .fetch(semantic_suggested_key(topic.id), expires_in: cache_for) do
-              symmetric_semantic_search(manager)
-            end
-        rescue MissingEmbeddingError
-          # avoid a flood of jobs when visiting topic
-          if Discourse.redis.set(
-              build_semantic_suggested_key(topic.id),
-              "queued",
-              ex: 15.minutes.to_i,
-              nx: true,
-            )
-            Jobs.enqueue(:generate_embeddings, topic_id: topic.id)
+        Discourse
+          .cache
+          .fetch(semantic_suggested_key(topic.id), expires_in: cache_for) do
+            vector_rep
+              .symmetric_topics_similarity_search(topic)
+              .tap do |candidate_ids|
+                # Happens when the topic doesn't have any embeddings
+                # I'd rather not use Exceptions to control the flow, so this should be refactored soon
+                if candidate_ids.empty? || !candidate_ids.include?(topic.id)
+                  raise MissingEmbeddingError, "No embeddings found for topic #{topic.id}"
+                end
+              end
           end
-          []
-        end
+      rescue MissingEmbeddingError
+        # avoid a flood of jobs when visiting topic
+        if Discourse.redis.set(
+            build_semantic_suggested_key(topic.id),
+            "queued",
+            ex: 15.minutes.to_i,
+            nx: true,
+          )
+          Jobs.enqueue(:generate_embeddings, topic_id: topic.id)
+        end
+        []
+      end
 
-      def symmetric_semantic_search(manager)
-        topic = manager.target
-        candidate_ids = self.query_symmetric_embeddings(manager)
-
-        # Happens when the topic doesn't have any embeddings
-        # I'd rather not use Exceptions to control the flow, so this should be refactored soon
-        if candidate_ids.empty?
|| !candidate_ids.include?(topic.id) - raise MissingEmbeddingError, "No embeddings found for topic #{topic.id}" - end - - candidate_ids + def results_ttl(topic) + case topic.created_at + when 6.hour.ago..Time.now + 15.minutes + when 3.day.ago..6.hour.ago + 1.hour + when 15.days.ago..3.day.ago + 12.hours + else + 1.week end + end - def query_symmetric_embeddings(manager) - topic = manager.target - model = manager.model - table = manager.topic_embeddings_table - begin - DB.query(<<~SQL, topic_id: topic.id).map(&:topic_id) - SELECT - topic_id - FROM - #{table} - ORDER BY - embeddings #{model.pg_function} ( - SELECT - embeddings - FROM - #{table} - WHERE - topic_id = :topic_id - LIMIT 1 - ) - LIMIT 100 - SQL - rescue PG::Error => e - Rails.logger.error( - "Error #{e} querying embeddings for topic #{topic.id} and model #{model.name}", - ) - raise MissingEmbeddingError - end - end + private - def results_ttl(topic) - case topic.created_at - when 6.hour.ago..Time.now - 15.minutes - when 3.day.ago..6.hour.ago - 1.hour - when 15.days.ago..3.day.ago - 12.hours - else - 1.week - end - end + def semantic_suggested_key(topic_id) + "semantic-suggested-topic-#{topic_id}" + end + + def build_semantic_suggested_key(topic_id) + "build-semantic-suggested-topic-#{topic_id}" end end end diff --git a/lib/modules/embeddings/semantic_search.rb b/lib/modules/embeddings/semantic_search.rb index e35fdc56..2e2e3fd2 100644 --- a/lib/modules/embeddings/semantic_search.rb +++ b/lib/modules/embeddings/semantic_search.rb @@ -3,59 +3,66 @@ module DiscourseAi module Embeddings class SemanticSearch + def self.clear_cache_for(query) + digest = OpenSSL::Digest::SHA1.hexdigest(query) + + Discourse.cache.delete("hyde-doc-#{digest}") + Discourse.cache.delete("hyde-doc-embedding-#{digest}") + end + def initialize(guardian) @guardian = guardian - @manager = DiscourseAi::Embeddings::Manager.new(nil) - @model = @manager.model + end + + def cached_query?(query) + digest = OpenSSL::Digest::SHA1.hexdigest(query) + Discourse.cache.read("hyde-doc-embedding-#{digest}").present? 
end def search_for_topics(query, page = 1) - limit = Search.per_filter + 1 - offset = (page - 1) * Search.per_filter + max_results_per_page = 50 + limit = [Search.per_filter, max_results_per_page].min + 1 + offset = (page - 1) * limit - candidate_ids = asymmetric_semantic_search(query, limit, offset) + strategy = DiscourseAi::Embeddings::Strategies::Truncation.new + vector_rep = + DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy) + + digest = OpenSSL::Digest::SHA1.hexdigest(query) + + hypothetical_post = + Discourse + .cache + .fetch("hyde-doc-#{digest}", expires_in: 1.week) do + hyde_generator = DiscourseAi::Embeddings::HydeGenerators::Base.current_hyde_model.new + hyde_generator.hypothetical_post_from(query) + end + + hypothetical_post_embedding = + Discourse + .cache + .fetch("hyde-doc-embedding-#{digest}", expires_in: 1.week) do + vector_rep.vector_from(hypothetical_post) + end + + candidate_topic_ids = + vector_rep.asymmetric_topics_similarity_search( + hypothetical_post_embedding, + limit: limit, + offset: offset, + ) ::Post .where(post_type: ::Topic.visible_post_types(guardian.user)) .public_posts .where("topics.visible") - .where(topic_id: candidate_ids, post_number: 1) - .order("array_position(ARRAY#{candidate_ids}, topic_id)") - end - - def asymmetric_semantic_search(query, limit, offset, return_distance: false) - embedding = model.generate_embeddings(query) - table = @manager.topic_embeddings_table - - begin - candidate_ids = DB.query(<<~SQL, query_embedding: embedding, limit: limit, offset: offset) - SELECT - topic_id, - embeddings #{@model.pg_function} '[:query_embedding]' AS distance - FROM - #{table} - ORDER BY - embeddings #{@model.pg_function} '[:query_embedding]' - LIMIT :limit - OFFSET :offset - SQL - rescue PG::Error => e - Rails.logger.error( - "Error #{e} querying embeddings for model #{model.name} and search #{query}", - ) - raise MissingEmbeddingError - end - - if return_distance - candidate_ids.map { |c| [c.topic_id, c.distance] } - else - candidate_ids.map(&:topic_id) - end + .where(topic_id: candidate_topic_ids, post_number: 1) + .order("array_position(ARRAY#{candidate_topic_ids}, topic_id)") end private - attr_reader :model, :guardian + attr_reader :guardian end end end diff --git a/lib/modules/embeddings/semantic_topic_query.rb b/lib/modules/embeddings/semantic_topic_query.rb index c1034baf..2ee85b65 100644 --- a/lib/modules/embeddings/semantic_topic_query.rb +++ b/lib/modules/embeddings/semantic_topic_query.rb @@ -14,7 +14,7 @@ class DiscourseAi::Embeddings::SemanticTopicQuery < TopicQuery list = create_list(:semantic_related, query_opts) do |topics| - candidate_ids = DiscourseAi::Embeddings::SemanticRelated.related_topic_ids_for(topic) + candidate_ids = DiscourseAi::Embeddings::SemanticRelated.new.related_topic_ids_for(topic) list = topics diff --git a/lib/modules/embeddings/strategies/truncation.rb b/lib/modules/embeddings/strategies/truncation.rb index f7d76340..4b2c977a 100644 --- a/lib/modules/embeddings/strategies/truncation.rb +++ b/lib/modules/embeddings/strategies/truncation.rb @@ -4,77 +4,57 @@ module DiscourseAi module Embeddings module Strategies class Truncation - attr_reader :processed_target, :digest - - def self.id - 1 - end - def id - self.class.id + 1 end def version 1 end - def initialize(target, model) - @model = model - @target = target - @tokenizer = @model.tokenizer - @max_length = @model.max_sequence_length - 2 - @processed_target = nil - end - - # Need a better name for this method - def process! 
-        case @target
+      def prepare_text_from(target, tokenizer, max_length)
+        case target
         when Topic
-          @processed_target = topic_truncation(@target)
+          topic_truncation(target, tokenizer, max_length)
         when Post
-          @processed_target = post_truncation(@target)
+          post_truncation(target, tokenizer, max_length)
         else
           raise ArgumentError, "Invalid target type"
         end
-
-        @digest = OpenSSL::Digest::SHA1.hexdigest(@processed_target)
       end
 
-      def topic_truncation(topic)
-        t = +""
+      private
 
-        t << topic.title
-        t << "\n\n"
-        t << topic.category.name
+      def topic_information(topic)
+        info = +""
+
+        info << topic.title
+        info << "\n\n"
+        info << topic.category.name
         if SiteSetting.tagging_enabled
-          t << "\n\n"
-          t << topic.tags.pluck(:name).join(", ")
+          info << "\n\n"
+          info << topic.tags.pluck(:name).join(", ")
         end
-        t << "\n\n"
+        info << "\n\n"
+      end
+
+      def topic_truncation(topic, tokenizer, max_length)
+        text = +topic_information(topic)
 
         topic.posts.find_each do |post|
-          t << post.raw
-          break if @tokenizer.size(t) >= @max_length #maybe keep a partial counter to speed this up?
-          t << "\n\n"
+          text << post.raw
+          break if tokenizer.size(text) >= max_length #maybe keep a partial counter to speed this up?
+          text << "\n\n"
         end
 
-        @tokenizer.truncate(t, @max_length)
+        tokenizer.truncate(text, max_length)
       end
 
-      def post_truncation(post)
-        t = +""
+      def post_truncation(post, tokenizer, max_length)
+        text = +topic_information(post.topic)
+        text << post.raw
 
-        t << post.topic.title
-        t << "\n\n"
-        t << post.topic.category.name
-        if SiteSetting.tagging_enabled
-          t << "\n\n"
-          t << post.topic.tags.pluck(:name).join(", ")
-        end
-        t << "\n\n"
-        t << post.raw
-
-        @tokenizer.truncate(t, @max_length)
+        tokenizer.truncate(text, max_length)
       end
     end
   end
 end
diff --git a/lib/modules/embeddings/vector_representations/all_mpnet_base_v2.rb b/lib/modules/embeddings/vector_representations/all_mpnet_base_v2.rb
new file mode 100644
index 00000000..8dfb2a47
--- /dev/null
+++ b/lib/modules/embeddings/vector_representations/all_mpnet_base_v2.rb
@@ -0,0 +1,50 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Embeddings
+    module VectorRepresentations
+      class AllMpnetBaseV2 < Base
+        def vector_from(text)
+          DiscourseAi::Inference::DiscourseClassifier.perform!(
+            "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
+            name,
+            text,
+            SiteSetting.ai_embeddings_discourse_service_api_key,
+          )
+        end
+
+        def name
+          "all-mpnet-base-v2"
+        end
+
+        def dimensions
+          768
+        end
+
+        def max_sequence_length
+          384
+        end
+
+        def id
+          1
+        end
+
+        def version
+          1
+        end
+
+        def pg_function
+          "<#>"
+        end
+
+        def pg_index_type
+          "vector_ip_ops"
+        end
+
+        def tokenizer
+          DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer
+        end
+      end
+    end
+  end
+end
diff --git a/lib/modules/embeddings/vector_representations/base.rb b/lib/modules/embeddings/vector_representations/base.rb
new file mode 100644
index 00000000..e89bf259
--- /dev/null
+++ b/lib/modules/embeddings/vector_representations/base.rb
@@ -0,0 +1,166 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Embeddings
+    module VectorRepresentations
+      class Base
+        def self.current_representation(strategy)
+          subclasses.map { _1.new(strategy) }.find { _1.name == SiteSetting.ai_embeddings_model }
+        end
+
+        def initialize(strategy)
+          @strategy = strategy
+        end
+
+        def create_index(lists, probes)
+          index_name = "#{table_name}_search"
+
+          DB.exec(<<~SQL)
+            DROP INDEX IF EXISTS #{index_name};
+            CREATE INDEX IF NOT EXISTS
+              #{index_name}
+            ON
+              #{table_name}
+            USING
+              ivfflat (embeddings
#{pg_index_type}) + WITH + (lists = #{lists}) + WHERE + model_version = #{version} AND + strategy_version = #{@strategy.version}; + SQL + end + + def vector_from(text) + raise NotImplementedError + end + + def generate_topic_representation_from(target, persist: true) + text = @strategy.prepare_text_from(target, tokenizer, max_sequence_length - 2) + + vector_from(text).tap do |vector| + if persist + digest = OpenSSL::Digest::SHA1.hexdigest(text) + save_to_db(target, vector, digest) + end + end + end + + def topic_id_from_representation(raw_vector) + DB.query_single(<<~SQL, query_embedding: raw_vector).first + SELECT + topic_id + FROM + #{table_name} + ORDER BY + embeddings #{pg_function} '[:query_embedding]' + LIMIT 1 + SQL + end + + def asymmetric_topics_similarity_search(raw_vector, limit:, offset:, return_distance: false) + results = DB.query(<<~SQL, query_embedding: raw_vector, limit: limit, offset: offset) + SELECT + topic_id, + embeddings #{pg_function} '[:query_embedding]' AS distance + FROM + #{table_name} + ORDER BY + embeddings #{pg_function} '[:query_embedding]' + LIMIT :limit + OFFSET :offset + SQL + + if return_distance + results.map { |r| [r.topic_id, r.distance] } + else + results.map(&:topic_id) + end + rescue PG::Error => e + Rails.logger.error("Error #{e} querying embeddings for model #{name}") + raise MissingEmbeddingError + end + + def symmetric_topics_similarity_search(topic) + DB.query(<<~SQL, topic_id: topic.id).map(&:topic_id) + SELECT + topic_id + FROM + #{table_name} + ORDER BY + embeddings #{pg_function} ( + SELECT + embeddings + FROM + #{table_name} + WHERE + topic_id = :topic_id + LIMIT 1 + ) + LIMIT 100 + SQL + rescue PG::Error => e + Rails.logger.error( + "Error #{e} querying embeddings for topic #{topic.id} and model #{name}", + ) + raise MissingEmbeddingError + end + + def table_name + "ai_topic_embeddings_#{id}_#{@strategy.id}" + end + + def name + raise NotImplementedError + end + + def dimensions + raise NotImplementedError + end + + def max_sequence_length + raise NotImplementedError + end + + def id + raise NotImplementedError + end + + def pg_function + raise NotImplementedError + end + + def version + raise NotImplementedError + end + + def tokenizer + raise NotImplementedError + end + + protected + + def save_to_db(target, vector, digest) + DB.exec( + <<~SQL, + INSERT INTO #{table_name} (topic_id, model_version, strategy_version, digest, embeddings, created_at, updated_at) + VALUES (:topic_id, :model_version, :strategy_version, :digest, '[:embeddings]', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) + ON CONFLICT (topic_id) + DO UPDATE SET + model_version = :model_version, + strategy_version = :strategy_version, + digest = :digest, + embeddings = '[:embeddings]', + updated_at = CURRENT_TIMESTAMP + SQL + topic_id: target.id, + model_version: version, + strategy_version: @strategy.version, + digest: digest, + embeddings: vector, + ) + end + end + end + end +end diff --git a/lib/modules/embeddings/vector_representations/multilingual_e5_large.rb b/lib/modules/embeddings/vector_representations/multilingual_e5_large.rb new file mode 100644 index 00000000..30ffb4d8 --- /dev/null +++ b/lib/modules/embeddings/vector_representations/multilingual_e5_large.rb @@ -0,0 +1,50 @@ +# frozen_string_literal: true + +module DiscourseAi + module Embeddings + module VectorRepresentations + class MultilingualE5Large < Base + def vector_from(text) + DiscourseAi::Inference::DiscourseClassifier.perform!( + "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify", 
+            name,
+            "query: #{text}",
+            SiteSetting.ai_embeddings_discourse_service_api_key,
+          )
+        end
+
+        def id
+          3
+        end
+
+        def version
+          1
+        end
+
+        def name
+          "multilingual-e5-large"
+        end
+
+        def dimensions
+          1024
+        end
+
+        def max_sequence_length
+          512
+        end
+
+        def pg_function
+          "<=>"
+        end
+
+        def pg_index_type
+          "vector_cosine_ops"
+        end
+
+        def tokenizer
+          DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer
+        end
+      end
+    end
+  end
+end
diff --git a/lib/modules/embeddings/vector_representations/text_embedding_ada_002.rb b/lib/modules/embeddings/vector_representations/text_embedding_ada_002.rb
new file mode 100644
index 00000000..3d1fc0ff
--- /dev/null
+++ b/lib/modules/embeddings/vector_representations/text_embedding_ada_002.rb
@@ -0,0 +1,46 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Embeddings
+    module VectorRepresentations
+      class TextEmbeddingAda002 < Base
+        def id
+          2
+        end
+
+        def version
+          1
+        end
+
+        def name
+          "text-embedding-ada-002"
+        end
+
+        def dimensions
+          1536
+        end
+
+        def max_sequence_length
+          8191
+        end
+
+        def pg_function
+          "<=>"
+        end
+
+        def pg_index_type
+          "vector_cosine_ops"
+        end
+
+        def vector_from(text)
+          response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text)
+          response[:data].first[:embedding]
+        end
+
+        def tokenizer
+          DiscourseAi::Tokenizer::OpenAiTokenizer
+        end
+      end
+    end
+  end
+end
diff --git a/lib/shared/inference/hugging_face_text_generation.rb b/lib/shared/inference/hugging_face_text_generation.rb
index f3753a21..9a8cd22e 100644
--- a/lib/shared/inference/hugging_face_text_generation.rb
+++ b/lib/shared/inference/hugging_face_text_generation.rb
@@ -4,7 +4,7 @@ module ::DiscourseAi
   module Inference
     class HuggingFaceTextGeneration
       CompletionFailed = Class.new(StandardError)
-      TIMEOUT = 60
+      TIMEOUT = 120
 
       def self.perform!(
         prompt,
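For context on the numbers in the index task below: an ivfflat index
clusters the stored vectors into lists buckets and only scans probes of
them per query, trading recall for speed. The task derives both from the
number of stored embeddings, e.g. (illustrative count):

    count = 500_000                                                    # embeddings rows
    lists = count < 1_000_000 ? count / 1000 : Math.sqrt(count).to_i   # => 500
    probes = count < 1_000_000 ? lists / 10 : Math.sqrt(lists).to_i    # => 50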
diff --git a/lib/tasks/modules/embeddings/database.rake b/lib/tasks/modules/embeddings/database.rake
index 96b9f15d..4ee260cb 100644
--- a/lib/tasks/modules/embeddings/database.rake
+++ b/lib/tasks/modules/embeddings/database.rake
@@ -4,18 +4,22 @@ desc "Backfill embeddings for all topics"
 task "ai:embeddings:backfill", [:start_topic] => [:environment] do |_, args|
   public_categories = Category.where(read_restricted: false).pluck(:id)
-  manager = DiscourseAi::Embeddings::Manager.new(Topic.first)
+
+  strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
+  vector_rep =
+    DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
+  table_name = vector_rep.table_name
+
   Topic
-    .joins(
-      "LEFT JOIN #{manager.topic_embeddings_table} ON #{manager.topic_embeddings_table}.topic_id = topics.id",
-    )
-    .where("#{manager.topic_embeddings_table}.topic_id IS NULL")
+    .joins("LEFT JOIN #{table_name} ON #{table_name}.topic_id = topics.id")
+    .where("#{table_name}.topic_id IS NULL")
     .where("topics.id >= ?", args[:start_topic].to_i || 0)
     .where("category_id IN (?)", public_categories)
     .where(deleted_at: nil)
     .order("topics.id ASC")
     .find_each do |t|
       print "."
-      DiscourseAi::Embeddings::Manager.new(t).generate!
+      vector_rep.generate_topic_representation_from(t)
     end
 end
 
@@ -28,25 +32,11 @@ task "ai:embeddings:index", [:work_mem] => [:environment] do |_, args|
   lists = count < 1_000_000 ? count / 1000 : Math.sqrt(count).to_i
   probes = count < 1_000_000 ? lists / 10 : Math.sqrt(lists).to_i
 
-  manager = DiscourseAi::Embeddings::Manager.new(Topic.first)
-  table = manager.topic_embeddings_table
-  index = "#{table}_search"
+  strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
+  vector_rep =
+    DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
 
   DB.exec("SET work_mem TO '#{args[:work_mem] || "1GB"}';")
-  DB.exec(<<~SQL)
-    DROP INDEX IF EXISTS #{index};
-    CREATE INDEX IF NOT EXISTS
-      #{index}
-    ON
-      #{table}
-    USING
-      ivfflat (embeddings #{manager.model.pg_index_type})
-    WITH
-      (lists = #{lists})
-    WHERE
-      model_version = #{manager.model.version} AND
-      strategy_version = #{manager.strategy.version};
-  SQL
+  vector_rep.create_index(lists, probes)
   DB.exec("RESET work_mem;")
   DB.exec("SET ivfflat.probes = #{probes};")
 end
diff --git a/plugin.rb b/plugin.rb
index cf01059c..0ad43c07 100644
--- a/plugin.rb
+++ b/plugin.rb
@@ -17,6 +17,7 @@ register_asset "stylesheets/modules/ai-helper/common/ai-helper.scss"
 register_asset "stylesheets/modules/ai-bot/common/bot-replies.scss"
 
 register_asset "stylesheets/modules/embeddings/common/semantic-related-topics.scss"
+register_asset "stylesheets/modules/embeddings/common/semantic-search.scss"
 
 module ::DiscourseAi
   PLUGIN_NAME = "discourse-ai"
diff --git a/spec/integration/embeddings/manager_spec.rb b/spec/integration/embeddings/manager_spec.rb
deleted file mode 100644
index d515c5ed..00000000
--- a/spec/integration/embeddings/manager_spec.rb
+++ /dev/null
@@ -1,44 +0,0 @@
-# frozen_string_literal: true
-
-require_relative "../../support/embeddings_generation_stubs"
-
-RSpec.describe DiscourseAi::Embeddings::Manager do
-  let(:user) { Fabricate(:user) }
-  let(:expected_embedding) do
-    JSON.parse(
-      File.read("#{Rails.root}/plugins/discourse-ai/spec/fixtures/embeddings/embedding.txt"),
-    )
-  end
-  let(:discourse_model) { "all-mpnet-base-v2" }
-
-  before do
-    SiteSetting.discourse_ai_enabled = true
-    SiteSetting.ai_embeddings_enabled = true
-    SiteSetting.ai_embeddings_model = "all-mpnet-base-v2"
-    SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
-    Jobs.run_immediately!
- end - - it "generates embeddings for new topics automatically" do - pc = - PostCreator.new( - user, - raw: "this is the new content for my topic", - title: "this is my new topic title", - ) - input = - "This is my new topic title\n\nUncategorized\n\n\n\nthis is the new content for my topic\n\n" - EmbeddingsGenerationStubs.discourse_service(discourse_model, input, expected_embedding) - post = pc.create - manager = DiscourseAi::Embeddings::Manager.new(post.topic) - - embeddings = - DB.query_single( - "SELECT embeddings FROM #{manager.topic_embeddings_table} WHERE topic_id = #{post.topic.id}", - ).first - - expect(embeddings.split(",")[1].to_f).to be_within(0.0001).of(expected_embedding[1]) - expect(embeddings.split(",")[13].to_f).to be_within(0.0001).of(expected_embedding[13]) - expect(embeddings.split(",")[135].to_f).to be_within(0.0001).of(expected_embedding[135]) - end -end diff --git a/spec/lib/modules/embeddings/entry_point_spec.rb b/spec/lib/modules/embeddings/entry_point_spec.rb index 72545005..3d4f14e8 100644 --- a/spec/lib/modules/embeddings/entry_point_spec.rb +++ b/spec/lib/modules/embeddings/entry_point_spec.rb @@ -28,97 +28,4 @@ describe DiscourseAi::Embeddings::EntryPoint do end end end - - describe "SemanticTopicQuery extension" do - describe "#list_semantic_related_topics" do - subject(:topic_query) { DiscourseAi::Embeddings::SemanticTopicQuery.new(user) } - - fab!(:target) { Fabricate(:topic) } - - def stub_semantic_search_with(results) - DiscourseAi::Embeddings::SemanticRelated.expects(:related_topic_ids_for).returns(results) - end - - context "when the semantic search returns an unlisted topic" do - fab!(:unlisted_topic) { Fabricate(:topic, visible: false) } - - before { stub_semantic_search_with([unlisted_topic.id]) } - - it "filters it out" do - expect(topic_query.list_semantic_related_topics(target).topics).to be_empty - end - end - - context "when the semantic search returns a private topic" do - fab!(:private_topic) { Fabricate(:private_message_topic) } - - before { stub_semantic_search_with([private_topic.id]) } - - it "filters it out" do - expect(topic_query.list_semantic_related_topics(target).topics).to be_empty - end - end - - context "when the semantic search returns a topic from a restricted category" do - fab!(:group) { Fabricate(:group) } - fab!(:category) { Fabricate(:private_category, group: group) } - fab!(:secured_category_topic) { Fabricate(:topic, category: category) } - - before { stub_semantic_search_with([secured_category_topic.id]) } - - it "filters it out" do - expect(topic_query.list_semantic_related_topics(target).topics).to be_empty - end - - it "doesn't filter it out if the user has access to the category" do - group.add(user) - - expect(topic_query.list_semantic_related_topics(target).topics).to contain_exactly( - secured_category_topic, - ) - end - end - - context "when the semantic search returns a closed topic and we explicitly exclude them" do - fab!(:closed_topic) { Fabricate(:topic, closed: true) } - - before do - SiteSetting.ai_embeddings_semantic_related_include_closed_topics = false - stub_semantic_search_with([closed_topic.id]) - end - - it "filters it out" do - expect(topic_query.list_semantic_related_topics(target).topics).to be_empty - end - end - - context "when the semantic search returns public topics" do - fab!(:normal_topic_1) { Fabricate(:topic) } - fab!(:normal_topic_2) { Fabricate(:topic) } - fab!(:normal_topic_3) { Fabricate(:topic) } - fab!(:closed_topic) { Fabricate(:topic, closed: true) } - - before do - 
stub_semantic_search_with( - [closed_topic.id, normal_topic_1.id, normal_topic_2.id, normal_topic_3.id], - ) - end - - it "filters it out" do - expect(topic_query.list_semantic_related_topics(target).topics).to eq( - [closed_topic, normal_topic_1, normal_topic_2, normal_topic_3], - ) - end - - it "returns the plugin limit for the number of results" do - SiteSetting.ai_embeddings_semantic_related_topics = 2 - - expect(topic_query.list_semantic_related_topics(target).topics).to contain_exactly( - closed_topic, - normal_topic_1, - ) - end - end - end - end end diff --git a/spec/lib/modules/embeddings/models/all_mpnet_base_v2_spec.rb b/spec/lib/modules/embeddings/models/all_mpnet_base_v2_spec.rb deleted file mode 100644 index 6766d72a..00000000 --- a/spec/lib/modules/embeddings/models/all_mpnet_base_v2_spec.rb +++ /dev/null @@ -1,24 +0,0 @@ -# frozen_string_literal: true - -require_relative "../../../../support/embeddings_generation_stubs" - -RSpec.describe DiscourseAi::Embeddings::Models::AllMpnetBaseV2 do - describe "#generate_embeddings" do - let(:input) { "test" } - let(:expected_embedding) { [0.0038493, 0.482001] } - - context "when the model uses the discourse service to create embeddings" do - before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" } - - let(:discourse_model) { "all-mpnet-base-v2" } - - it "returns an embedding for a given string" do - EmbeddingsGenerationStubs.discourse_service(discourse_model, input, expected_embedding) - - embedding = described_class.generate_embeddings(input) - - expect(embedding).to contain_exactly(*expected_embedding) - end - end - end -end diff --git a/spec/lib/modules/embeddings/models/text_embedding_ada_002_spec.rb b/spec/lib/modules/embeddings/models/text_embedding_ada_002_spec.rb deleted file mode 100644 index 60d41d02..00000000 --- a/spec/lib/modules/embeddings/models/text_embedding_ada_002_spec.rb +++ /dev/null @@ -1,22 +0,0 @@ -# frozen_string_literal: true - -require_relative "../../../../support/embeddings_generation_stubs" - -RSpec.describe DiscourseAi::Embeddings::Models::TextEmbeddingAda002 do - describe "#generate_embeddings" do - let(:input) { "test" } - let(:expected_embedding) { [0.0038493, 0.482001] } - - context "when the model uses OpenAI to create embeddings" do - let(:openai_model) { "text-embedding-ada-002" } - - it "returns an embedding for a given string" do - EmbeddingsGenerationStubs.openai_service(openai_model, input, expected_embedding) - - embedding = described_class.generate_embeddings(input) - - expect(embedding).to contain_exactly(*expected_embedding) - end - end - end -end diff --git a/spec/lib/modules/embeddings/semantic_related_spec.rb b/spec/lib/modules/embeddings/semantic_related_spec.rb index 22c85b51..1911218f 100644 --- a/spec/lib/modules/embeddings/semantic_related_spec.rb +++ b/spec/lib/modules/embeddings/semantic_related_spec.rb @@ -3,6 +3,8 @@ require "rails_helper" describe DiscourseAi::Embeddings::SemanticRelated do + subject(:semantic_related) { described_class.new } + fab!(:target) { Fabricate(:topic) } fab!(:normal_topic_1) { Fabricate(:topic) } fab!(:normal_topic_2) { Fabricate(:topic) } @@ -25,13 +27,13 @@ describe DiscourseAi::Embeddings::SemanticRelated do results = nil expect_enqueued_with(job: :generate_embeddings, args: { topic_id: topic.id }) do - results = described_class.related_topic_ids_for(topic) + results = semantic_related.related_topic_ids_for(topic) end expect(results).to eq([]) expect_not_enqueued_with(job: :generate_embeddings, args: { topic_id: 
topic.id }) do - results = described_class.related_topic_ids_for(topic) + results = semantic_related.related_topic_ids_for(topic) end expect(results).to eq([]) diff --git a/spec/lib/modules/embeddings/semantic_search_spec.rb b/spec/lib/modules/embeddings/semantic_search_spec.rb index 48ac4584..49826dd0 100644 --- a/spec/lib/modules/embeddings/semantic_search_spec.rb +++ b/spec/lib/modules/embeddings/semantic_search_spec.rb @@ -1,5 +1,8 @@ # frozen_string_literal: true +require_relative "../../../support/embeddings_generation_stubs" +require_relative "../../../support/openai_completions_inference_stubs" + RSpec.describe DiscourseAi::Embeddings::SemanticSearch do fab!(:post) { Fabricate(:post) } fab!(:user) { Fabricate(:user) } @@ -8,10 +11,28 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do let(:subject) { described_class.new(Guardian.new(user)) } describe "#search_for_topics" do + let(:hypothetical_post) { "This is an hypothetical post generated from the keyword test_query" } + + before do + SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" + + prompt = DiscourseAi::Embeddings::HydeGenerators::OpenAi.new.prompt(query) + OpenAiCompletionsInferenceStubs.stub_response(prompt, hypothetical_post) + + hyde_embedding = [0.049382, 0.9999] + EmbeddingsGenerationStubs.discourse_service( + SiteSetting.ai_embeddings_model, + hypothetical_post, + hyde_embedding, + ) + end + + after { described_class.clear_cache_for(query) } + def stub_candidate_ids(candidate_ids) - DiscourseAi::Embeddings::SemanticSearch + DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2 .any_instance - .expects(:asymmetric_semantic_search) + .expects(:asymmetric_topics_similarity_search) .returns(candidate_ids) end diff --git a/spec/lib/modules/embeddings/semantic_topic_query_spec.rb b/spec/lib/modules/embeddings/semantic_topic_query_spec.rb index cfc4bc20..911bb4d8 100644 --- a/spec/lib/modules/embeddings/semantic_topic_query_spec.rb +++ b/spec/lib/modules/embeddings/semantic_topic_query_spec.rb @@ -12,9 +12,14 @@ describe DiscourseAi::Embeddings::EntryPoint do fab!(:target) { Fabricate(:topic) } def stub_semantic_search_with(results) - DiscourseAi::Embeddings::SemanticRelated.expects(:related_topic_ids_for).returns(results) + DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2 + .any_instance + .expects(:symmetric_topics_similarity_search) + .returns(results.concat([target.id])) end + after { DiscourseAi::Embeddings::SemanticRelated.clear_cache_for(target) } + context "when the semantic search returns an unlisted topic" do fab!(:unlisted_topic) { Fabricate(:topic, visible: false) } diff --git a/spec/lib/modules/embeddings/strategies/truncation_spec.rb b/spec/lib/modules/embeddings/strategies/truncation_spec.rb index c25ade73..850db322 100644 --- a/spec/lib/modules/embeddings/strategies/truncation_spec.rb +++ b/spec/lib/modules/embeddings/strategies/truncation_spec.rb @@ -1,8 +1,10 @@ # frozen_string_literal: true RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do - describe "#process!" 
diff --git a/spec/lib/modules/embeddings/strategies/truncation_spec.rb b/spec/lib/modules/embeddings/strategies/truncation_spec.rb
index c25ade73..850db322 100644
--- a/spec/lib/modules/embeddings/strategies/truncation_spec.rb
+++ b/spec/lib/modules/embeddings/strategies/truncation_spec.rb
@@ -1,8 +1,10 @@
 # frozen_string_literal: true
 
 RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do
-  describe "#process!" do
-    context "when the model uses OpenAI to create embeddings" do
+  subject(:truncation) { described_class.new }
+
+  describe "#prepare_text_from" do
+    context "when using a vector from OpenAI" do
       before { SiteSetting.max_post_length = 100_000 }
 
       fab!(:topic) { Fabricate(:topic) }
@@ -18,13 +20,15 @@ RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do
       end
 
       fab!(:post) { Fabricate(:post, topic: topic, raw: "Surfin' bird\n" * 800) }
-      let(:model) { DiscourseAi::Embeddings::Models::Base.descendants.sample(1).first }
-      let(:truncation) { described_class.new(topic, model) }
+      let(:model) do
+        DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002.new(truncation)
+      end
 
       it "truncates a topic" do
-        truncation.process!
+        prepared_text =
+          truncation.prepare_text_from(topic, model.tokenizer, model.max_sequence_length)
 
-        expect(model.tokenizer.size(truncation.processed_target)).to be <= model.max_sequence_length
+        expect(model.tokenizer.size(prepared_text)).to be <= model.max_sequence_length
       end
     end
   end
diff --git a/spec/lib/modules/embeddings/vector_representations/all_mpnet_base_v2_spec.rb b/spec/lib/modules/embeddings/vector_representations/all_mpnet_base_v2_spec.rb
new file mode 100644
index 00000000..16f20abf
--- /dev/null
+++ b/spec/lib/modules/embeddings/vector_representations/all_mpnet_base_v2_spec.rb
@@ -0,0 +1,18 @@
+# frozen_string_literal: true
+
+require_relative "../../../../support/embeddings_generation_stubs"
+require_relative "vector_rep_shared_examples"
+
+RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2 do
+  subject(:vector_rep) { described_class.new(truncation) }
+
+  let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
+
+  before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
+
+  def stub_vector_mapping(text, expected_embedding)
+    EmbeddingsGenerationStubs.discourse_service(vector_rep.name, text, expected_embedding)
+  end
+
+  it_behaves_like "generates and stores embeddings with vector representation"
+end
diff --git a/spec/lib/modules/embeddings/vector_representations/multilingual_e5_large_spec.rb b/spec/lib/modules/embeddings/vector_representations/multilingual_e5_large_spec.rb
new file mode 100644
index 00000000..1c1b2a5f
--- /dev/null
+++ b/spec/lib/modules/embeddings/vector_representations/multilingual_e5_large_spec.rb
@@ -0,0 +1,22 @@
+# frozen_string_literal: true
+
+require_relative "../../../../support/embeddings_generation_stubs"
+require_relative "vector_rep_shared_examples"
+
+RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large do
+  subject(:vector_rep) { described_class.new(truncation) }
+
+  let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
+
+  before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
+
+  def stub_vector_mapping(text, expected_embedding)
+    EmbeddingsGenerationStubs.discourse_service(
+      vector_rep.name,
+      "query: #{text}",
+      expected_embedding,
+    )
+  end
+
+  it_behaves_like "generates and stores embeddings with vector representation"
+end
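Each per-model spec only defines stub_vector_mapping and delegates the behaviour checks to the shared examples; note that the MultilingualE5Large stub above expects the text to arrive prefixed with "query: ", since E5 models distinguish queries from passages. A spec for a future model would follow the same shape (SomeNewModel is a hypothetical name for illustration):

    RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::SomeNewModel do
      subject(:vector_rep) { described_class.new(truncation) }

      let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }

      def stub_vector_mapping(text, expected_embedding)
        # Stub whichever inference service the model calls.
      end

      it_behaves_like "generates and stores embeddings with vector representation"
    end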
"../../../../support/embeddings_generation_stubs" +require_relative "vector_rep_shared_examples" + +RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002 do + subject(:vector_rep) { described_class.new(truncation) } + + let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new } + + def stub_vector_mapping(text, expected_embedding) + EmbeddingsGenerationStubs.openai_service(vector_rep.name, text, expected_embedding) + end + + it_behaves_like "generates and store embedding using with vector representation" +end diff --git a/spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb b/spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb new file mode 100644 index 00000000..7d5cc213 --- /dev/null +++ b/spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb @@ -0,0 +1,54 @@ +# frozen_string_literal: true + +RSpec.shared_examples "generates and store embedding using with vector representation" do + before { @expected_embedding = [0.0038493] * vector_rep.dimensions } + + describe "#vector_from" do + it "creates a vector from a given string" do + text = "This is a piece of text" + stub_vector_mapping(text, @expected_embedding) + + expect(vector_rep.vector_from(text)).to eq(@expected_embedding) + end + end + + describe "#generate_topic_representation_from" do + fab!(:topic) { Fabricate(:topic) } + fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) } + + it "creates a vector from a topic and stores it in the database" do + text = + truncation.prepare_text_from( + topic, + vector_rep.tokenizer, + vector_rep.max_sequence_length - 2, + ) + stub_vector_mapping(text, @expected_embedding) + + vector_rep.generate_topic_representation_from(topic) + + expect(vector_rep.topic_id_from_representation(@expected_embedding)).to eq(topic.id) + end + end + + describe "#asymmetric_topics_similarity_search" do + fab!(:topic) { Fabricate(:topic) } + fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) } + + it "finds IDs of similar topics with a given embedding" do + similar_vector = [0.0038494] * vector_rep.dimensions + text = + truncation.prepare_text_from( + topic, + vector_rep.tokenizer, + vector_rep.max_sequence_length - 2, + ) + stub_vector_mapping(text, @expected_embedding) + vector_rep.generate_topic_representation_from(topic) + + expect( + vector_rep.asymmetric_topics_similarity_search(similar_vector, limit: 1, offset: 0), + ).to contain_exactly(topic.id) + end + end +end diff --git a/spec/requests/topic_spec.rb b/spec/requests/topic_spec.rb index 25a00895..a4f2141d 100644 --- a/spec/requests/topic_spec.rb +++ b/spec/requests/topic_spec.rb @@ -16,9 +16,10 @@ describe ::TopicsController do context "when a user is logged on" do it "includes related topics in payload when configured" do - DiscourseAi::Embeddings::SemanticRelated.stubs(:related_topic_ids_for).returns( - [topic1.id, topic2.id, topic3.id], - ) + DiscourseAi::Embeddings::SemanticRelated + .any_instance + .stubs(:related_topic_ids_for) + .returns([topic1.id, topic2.id, topic3.id]) get("#{topic.relative_url}.json") expect(response.status).to eq(200)