diff --git a/config/locales/server.en.yml b/config/locales/server.en.yml
index 37c661c8..0b63278c 100644
--- a/config/locales/server.en.yml
+++ b/config/locales/server.en.yml
@@ -31,11 +31,21 @@ en:
     ai_nsfw_flag_threshold_sexy: "Threshold for an image classified as sexy to be considered NSFW."
     ai_nsfw_models: "Models to use for NSFW inference."

+    ai_openai_api_key: "API key for OpenAI API"
+
     composer_ai_helper_enabled: "Enable the Composer's AI helper."
-    ai_openai_api_key: "API key for the AI helper"
     ai_helper_allowed_groups: "Users on these groups will see the AI helper button in the composer."
     ai_helper_allowed_in_pm: "Enable the composer's AI helper in PMs."

+    ai_embeddings_enabled: "Enable the embeddings module."
+    ai_embeddings_discourse_service_api_endpoint: "URL where the API is running for the embeddings module"
+    ai_embeddings_discourse_service_api_key: "API key for the embeddings API"
+    ai_embeddings_models: "Discourse will generate embeddings for each of the models enabled here"
+    ai_embeddings_semantic_suggested_model: "Model to use for suggested topics."
+    ai_embeddings_generate_for_pms: "Generate embeddings for personal messages."
+    ai_embeddings_semantic_suggested_topics_anons_enabled: "Use Semantic Search for suggested topics for anonymous users."
+    ai_embeddings_pg_connection_string: "PostgreSQL connection string for the embeddings module. Needs pgvector extension enabled and a series of tables created. See docs for more info."
+
   reviewables:
     reasons:
       flagged_by_toxicity: The AI plugin flagged this after classifying it as toxic.
diff --git a/config/settings.yml b/config/settings.yml
index 79ed96b7..4d9a10ea 100644
--- a/config/settings.yml
+++ b/config/settings.yml
@@ -101,3 +101,32 @@ plugins:
   ai_helper_allowed_in_pm:
     default: false
     client: true
+
+  ai_embeddings_enabled: false
+  ai_embeddings_discourse_service_api_endpoint: ""
+  ai_embeddings_discourse_service_api_key: ""
+  ai_embeddings_models:
+    type: list
+    list_type: compact
+    default: ""
+    allow_any: false
+    choices:
+      - all-mpnet-base-v2
+      - all-distilroberta-v1
+      - multi-qa-mpnet-base-dot-v1
+      - paraphrase-multilingual-mpnet-base-v2
+      - msmarco-distilbert-base-v4
+      - msmarco-distilbert-base-tas-b
+      - text-embedding-ada-002
+  ai_embeddings_semantic_suggested_model:
+    type: enum
+    default: all-mpnet-base-v2
+    choices:
+      - all-mpnet-base-v2
+      - text-embedding-ada-002
+      - all-distilroberta-v1
+      - multi-qa-mpnet-base-dot-v1
+      - paraphrase-multilingual-mpnet-base-v2
+  ai_embeddings_generate_for_pms: false
+  ai_embeddings_semantic_suggested_topics_anons_enabled: false
+  ai_embeddings_pg_connection_string: ""
diff --git a/lib/modules/embeddings/entry_point.rb b/lib/modules/embeddings/entry_point.rb
new file mode 100644
index 00000000..93c4c499
--- /dev/null
+++ b/lib/modules/embeddings/entry_point.rb
@@ -0,0 +1,35 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Embeddings
+    # Loads the embeddings module's files and hooks it into the plugin.
+    class EntryPoint
+      # Requires every file the embeddings module is made of.
+      def load_files
+        require_relative "models"
+        require_relative "topic"
+        require_relative "jobs/regular/generate_embeddings"
+        require_relative "semantic_suggested"
+      end
+
+      # Enqueues an embeddings job whenever a topic is created or edited, and
+      # registers the semantic suggested-topics provider.
+      def inject_into(plugin)
+        callback =
+          Proc.new do |topic|
+            if SiteSetting.ai_embeddings_enabled
+              Jobs.enqueue(:generate_embeddings, topic_id: topic.id)
+            end
+          end
+
+        plugin.on(:topic_created, &callback)
+        plugin.on(:topic_edited, &callback)
+
+        DiscoursePluginRegistry.register_list_suggested_for_provider(
+          SemanticSuggested.method(:build_suggested_topics),
+          plugin,
+        )
+      end
+    end
+  end
+end
diff --git a/lib/modules/embeddings/jobs/regular/generate_embeddings.rb b/lib/modules/embeddings/jobs/regular/generate_embeddings.rb
new file mode 100644
index 00000000..2004aaa3
--- /dev/null
+++ b/lib/modules/embeddings/jobs/regular/generate_embeddings.rb
@@ -0,0 +1,22 @@
+# frozen_string_literal: true
+
+module Jobs
+  # Background job that computes and stores embeddings for a single topic.
+  class GenerateEmbeddings < ::Jobs::Base
+    # @param args [Hash] job arguments; reads :topic_id
+    def execute(args)
+      return unless SiteSetting.ai_embeddings_enabled
+      return if (topic_id = args[:topic_id]).blank?
+
+      topic = Topic.find_by_id(topic_id)
+      return if topic.nil?
+      return if topic.private_message? && !SiteSetting.ai_embeddings_generate_for_pms
+
+      # Reuse the topic loaded above instead of querying for it a second time.
+      post = topic.first_post
+      return if post.nil? || post.raw.blank?
+
+      DiscourseAi::Embeddings::Topic.new(topic).perform!
+    end
+  end
+end
diff --git a/lib/modules/embeddings/models.rb b/lib/modules/embeddings/models.rb
new file mode 100644
index 00000000..1d1dc391
--- /dev/null
+++ b/lib/modules/embeddings/models.rb
@@ -0,0 +1,70 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Embeddings
+    # Catalog of the embedding models the plugin knows about.
+    class Models
+      # NOTE: `max_sequence_lenght` is misspelled, but it is part of the
+      # value object's public interface, so it is kept as is.
+      MODEL = Data.define(:name, :dimensions, :max_sequence_lenght, :functions, :type, :provider)
+
+      # pgvector index operator class for each search function.
+      SEARCH_FUNCTION_TO_PG_INDEX = {
+        dot: "vector_ip_ops",
+        cosine: "vector_cosine_ops",
+        euclidean: "vector_l2_ops",
+      }.freeze
+
+      # pgvector distance operator for each search function.
+      SEARCH_FUNCTION_TO_PG_FUNCTION = { dot: "<#>", cosine: "<=>", euclidean: "<->" }.freeze
+
+      # Models from .list whose name is enabled in the ai_embeddings_models setting.
+      def self.enabled_models
+        setting = SiteSetting.ai_embeddings_models.split("|").map(&:strip)
+        list.filter { |model| setting.include?(model.name) }
+      end
+
+      # Every known model, memoized in a class-level instance variable.
+      def self.list
+        @list ||= [
+          MODEL.new(
+            "all-mpnet-base-v2",
+            768,
+            384,
+            %i[dot cosine euclidean],
+            [:symmetric],
+            "discourse",
+          ),
+          MODEL.new(
+            "all-distilroberta-v1",
+            768,
+            512,
+            %i[dot cosine euclidean],
+            [:symmetric],
+            "discourse",
+          ),
+          MODEL.new("multi-qa-mpnet-base-dot-v1", 768, 512, [:dot], [:symmetric], "discourse"),
+          MODEL.new(
+            "paraphrase-multilingual-mpnet-base-v2",
+            768,
+            128,
+            [:cosine],
+            [:symmetric],
+            "discourse",
+          ),
+          MODEL.new("msmarco-distilbert-base-v4", 768, 512, [:cosine], [:asymmetric], "discourse"),
+          MODEL.new("msmarco-distilbert-base-tas-b", 768, 512, [:dot], [:asymmetric], "discourse"),
+          MODEL.new(
+            "text-embedding-ada-002",
+            1536,
+            2048,
+            [:cosine],
+            # %i takes bare words; a leading colon would become part of the symbol name.
+            %i[symmetric asymmetric],
+            "openai",
+          ),
+        ]
+      end
+    end
+  end
+end
diff --git a/lib/modules/embeddings/semantic_suggested.rb b/lib/modules/embeddings/semantic_suggested.rb
new file mode 100644
index 00000000..27e3d4a5
--- /dev/null
+++ b/lib/modules/embeddings/semantic_suggested.rb
@@ -0,0 +1,85 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Embeddings
+    # Suggested-topics provider backed by embedding similarity.
+    class SemanticSuggested
+      # Builds the suggested topics list for anonymous visitors of a non-PM topic.
+      #
+      # Returns nil when the provider does not apply, otherwise a hash with
+      # :result (topics ordered by similarity) and :params.
+      def self.build_suggested_topics(topic, pm_params, topic_query)
+        return unless SiteSetting.ai_embeddings_semantic_suggested_topics_anons_enabled
+        return if topic_query.user
+        return if topic.private_message?
+
+        cache_for =
+          case topic.created_at
+          when 6.hour.ago..Time.now
+            15.minutes
+          when 1.day.ago..6.hour.ago
+            1.hour
+          else
+            1.day
+          end
+
+        begin
+          candidate_ids =
+            Discourse
+              .cache
+              .fetch("semantic-suggested-topic-#{topic.id}", expires_in: cache_for) do
+                suggested = search_suggestions(topic)
+
+                # Happens when the topic doesn't have any embeddings
+                if suggested.empty? || !suggested.include?(topic.id)
+                  return { result: [], params: {} }
+                end
+
+                suggested
+              end
+        rescue StandardError => e
+          Rails.logger.error("SemanticSuggested: #{e}")
+          # Without ids the ORDER BY below would be built from a bare "ARRAY",
+          # which is invalid SQL, so fall back to an empty result.
+          return { result: [], params: {} }
+        end
+
+        # A cached empty/nil value would produce the same invalid SQL.
+        return { result: [], params: {} } if candidate_ids.blank?
+
+        # array_position forces the order of the topics to be preserved
+        candidates =
+          ::Topic.where(id: candidate_ids).order("array_position(ARRAY#{candidate_ids}, id)")
+
+        { result: candidates, params: {} }
+      end
+
+      # Ids of up to 11 topics ordered by distance to the given topic's embedding,
+      # using the model selected in ai_embeddings_semantic_suggested_model.
+      def self.search_suggestions(topic)
+        model_name = SiteSetting.ai_embeddings_semantic_suggested_model
+        model = DiscourseAi::Embeddings::Models.list.find { |m| m.name == model_name }
+        function =
+          DiscourseAi::Embeddings::Models::SEARCH_FUNCTION_TO_PG_FUNCTION[model.functions.first]
+
+        DiscourseAi::Database::Connection.db.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
+          SELECT
+            topic_id
+          FROM
+            topic_embeddings_#{model_name.underscore}
+          ORDER BY
+            embedding #{function} (
+              SELECT
+                embedding
+              FROM
+                topic_embeddings_#{model_name.underscore}
+              WHERE
+                topic_id = :topic_id
+              LIMIT 1
+            )
+          LIMIT 11
+        SQL
+      end
+    end
+  end
+end
diff --git a/lib/modules/embeddings/topic.rb b/lib/modules/embeddings/topic.rb
new file mode 100644
index 00000000..73a3cc2b
--- /dev/null
+++ b/lib/modules/embeddings/topic.rb
@@ -0,0 +1,65 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Embeddings
+    # Computes and stores the embeddings of a topic's first post for every enabled model.
+    class Topic
+      def initialize(topic)
+        @topic = topic
+        @embeddings = {}
+      end
+
+      # Calculates the embeddings and persists them when any were produced.
+      def perform!
+        return unless SiteSetting.ai_embeddings_enabled
+        return if DiscourseAi::Embeddings::Models.enabled_models.empty?
+
+        calculate_embeddings!
+        persist_embeddings! unless @embeddings.empty?
+      end
+
+      # Fills @embeddings, keyed by model name, by dispatching to the
+      # "<provider>_embeddings" method of each enabled model.
+      def calculate_embeddings!
+        return if @topic.blank? || @topic.first_post.blank?
+
+        DiscourseAi::Embeddings::Models.enabled_models.each do |model|
+          @embeddings[model.name] = send("#{model.provider}_embeddings", model.name)
+        end
+      end
+
+      # Upserts one row per model into that model's topic_embeddings_* table.
+      def persist_embeddings!
+        @embeddings.each do |model, model_embedding|
+          DiscourseAi::Database::Connection.db.exec(
+            <<~SQL,
+              INSERT INTO topic_embeddings_#{model.underscore} (topic_id, embedding)
+              VALUES (:topic_id, '[:embedding]')
+              ON CONFLICT (topic_id)
+              DO UPDATE SET embedding = '[:embedding]'
+            SQL
+            topic_id: @topic.id,
+            embedding: model_embedding,
+          )
+        end
+      end
+
+      # Requests the embedding of the first post from the Discourse-hosted service.
+      def discourse_embeddings(model)
+        DiscourseAi::Inference::DiscourseClassifier.perform!(
+          "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
+          model.to_s,
+          @topic.first_post.raw,
+          SiteSetting.ai_embeddings_discourse_service_api_key,
+        )
+      end
+
+      # Requests the embedding of the first post from OpenAI. The model argument
+      # is accepted for signature parity with discourse_embeddings but is unused.
+      def openai_embeddings(model)
+        response = DiscourseAi::Inference::OpenAIEmbeddings.perform!(@topic.first_post.raw)
+        response[:data].first[:embedding]
+      end
+    end
+  end
+end
diff --git a/lib/shared/database/connection.rb b/lib/shared/database/connection.rb
new file mode 100644
index 00000000..13ee2aac
--- /dev/null
+++ b/lib/shared/database/connection.rb
@@ -0,0 +1,18 @@
+# frozen_string_literal: true
+
+module ::DiscourseAi
+  module Database
+    # MiniSql connection to the database named by ai_embeddings_pg_connection_string.
+    class Connection
+      def self.connect!
+        pg_conn = PG.connect(SiteSetting.ai_embeddings_pg_connection_string)
+        @@db = MiniSql::Connection.get(pg_conn)
+      end
+
+      # Memoized connection; connects on first use.
+      def self.db
+        @@db ||= connect!
+      end
+    end
+  end
+end
diff --git a/lib/tasks/modules/embeddings/database.rake b/lib/tasks/modules/embeddings/database.rake
new file mode 100644
index 00000000..cee4d3f3
--- /dev/null
+++ b/lib/tasks/modules/embeddings/database.rake
@@ -0,0 +1,47 @@
+# frozen_string_literal: true
+
+desc "Creates tables to store embeddings"
+task "ai:embeddings:create_table" => [:environment] do
+  DiscourseAi::Embeddings::Models.enabled_models.each do |model|
+    DiscourseAi::Database::Connection.db.exec(<<~SQL)
+      CREATE TABLE IF NOT EXISTS topic_embeddings_#{model.name.underscore} (
+        topic_id bigint PRIMARY KEY,
+        embedding vector(#{model.dimensions})
+      );
+    SQL
+  end
+end
+
+desc "Backfill embeddings for all topics"
+task "ai:embeddings:backfill" => [:environment] do
+  public_categories = Category.where(read_restricted: false).pluck(:id)
+  Topic
+    .where(category_id: public_categories)
+    .where(deleted_at: nil)
+    .find_each do |t|
+      print "."
+      DiscourseAi::Embeddings::Topic.new(t).perform!
+    end
+end
+
+desc "Creates indexes for embeddings"
+task "ai:embeddings:index" => [:environment] do
+  # Using 4 * sqrt(number of topics) as a rule of thumb for now
+  # Results are not as good as without indexes, but it's much faster
+  # Disk usage is ~1x the size of the table, so this doubles the table's total size
+  # Clamp to 1 so a site without topics does not ask for an index with zero lists
+  lists = [4 * Math.sqrt(Topic.count).to_i, 1].max
+
+  DiscourseAi::Embeddings::Models.enabled_models.each do |model|
+    DiscourseAi::Database::Connection.db.exec(<<~SQL)
+      CREATE INDEX IF NOT EXISTS
+        topic_embeddings_#{model.name.underscore}_search
+      ON
+        topic_embeddings_#{model.name.underscore}
+      USING
+        ivfflat (embedding #{DiscourseAi::Embeddings::Models::SEARCH_FUNCTION_TO_PG_INDEX[model.functions.first]})
+      WITH
+        (lists = #{lists});
+    SQL
+  end
+end
diff --git a/plugin.rb b/plugin.rb
index f8714b6e..456b8903 100644
--- a/plugin.rb
+++ b/plugin.rb
@@ -27,12 +27,16 @@ after_initialize do
   require_relative "lib/shared/post_classificator"
   require_relative "lib/shared/chat_message_classificator"

+  require_relative "lib/shared/database/connection"
+
   require_relative "lib/modules/nsfw/entry_point"
   require_relative "lib/modules/toxicity/entry_point"
   require_relative "lib/modules/sentiment/entry_point"
   require_relative "lib/modules/ai_helper/entry_point"
+  require_relative "lib/modules/embeddings/entry_point"

   [
+    DiscourseAi::Embeddings::EntryPoint.new,
     DiscourseAi::NSFW::EntryPoint.new,
     DiscourseAi::Toxicity::EntryPoint.new,
     DiscourseAi::Sentiment::EntryPoint.new,