diff --git a/app/jobs/regular/generate_embeddings.rb b/app/jobs/regular/generate_embeddings.rb index b67cbb1f..05b22ea6 100644 --- a/app/jobs/regular/generate_embeddings.rb +++ b/app/jobs/regular/generate_embeddings.rb @@ -6,18 +6,20 @@ module Jobs def execute(args) return unless SiteSetting.ai_embeddings_enabled - return if (topic_id = args[:topic_id]).blank? + return if args[:target_type].blank? || args[:target_id].blank? + target = args[:target_type].constantize.find_by_id(args[:target_id]) + return if target.nil? || target.deleted_at.present? - topic = Topic.find_by_id(topic_id) - return if topic.nil? || topic.private_message? && !SiteSetting.ai_embeddings_generate_for_pms - post = topic.first_post - return if post.nil? || post.raw.blank? + topic = target.is_a?(Topic) ? target : target.topic + post = target.is_a?(Post) ? target : target.first_post + return if topic.private_message? && !SiteSetting.ai_embeddings_generate_for_pms + return if post.raw.blank? strategy = DiscourseAi::Embeddings::Strategies::Truncation.new vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy) - vector_rep.generate_topic_representation_from(topic) + vector_rep.generate_representation_from(target) end end end diff --git a/app/jobs/scheduled/embeddings_backfill.rb b/app/jobs/scheduled/embeddings_backfill.rb index fc59fff4..d0384da4 100644 --- a/app/jobs/scheduled/embeddings_backfill.rb +++ b/app/jobs/scheduled/embeddings_backfill.rb @@ -15,7 +15,7 @@ module Jobs strategy = DiscourseAi::Embeddings::Strategies::Truncation.new vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy) - table_name = vector_rep.table_name + table_name = vector_rep.topic_table_name topics = Topic @@ -28,7 +28,7 @@ module Jobs topics .where("#{table_name}.topic_id IS NULL") .find_each do |t| - vector_rep.generate_topic_representation_from(t) + vector_rep.generate_representation_from(t) rebaked += 1 end @@ -45,7 +45,7 @@ module Jobs #{table_name}.strategy_version < #{strategy.version} SQL .find_each do |t| - vector_rep.generate_topic_representation_from(t) + vector_rep.generate_representation_from(t) rebaked += 1 end @@ -59,7 +59,57 @@ module Jobs .limit((limit - rebaked) / 10) .pluck(:id) .each do |id| - vector_rep.generate_topic_representation_from(Topic.find_by(id: id)) + vector_rep.generate_representation_from(Topic.find_by(id: id)) + rebaked += 1 + end + + return if rebaked >= limit + + # Now for posts + table_name = vector_rep.post_table_name + + posts = + Post + .joins("LEFT JOIN #{table_name} ON #{table_name}.post_id = posts.id") + .where(deleted_at: nil) + .limit(limit - rebaked) + + # First, we'll try to backfill embeddings for posts that have none + posts + .where("#{table_name}.post_id IS NULL") + .find_each do |t| + vector_rep.generate_representation_from(t) + rebaked += 1 + end + + vector_rep.consider_indexing + + return if rebaked >= limit + + # Then, we'll try to backfill embeddings for posts that have outdated + # embeddings, be it model or strategy version + posts + .where(<<~SQL) + #{table_name}.model_version < #{vector_rep.version} + OR + #{table_name}.strategy_version < #{strategy.version} + SQL + .find_each do |t| + vector_rep.generate_representation_from(t) + rebaked += 1 + end + + return if rebaked >= limit + + # Finally, we'll try to backfill embeddings for posts that have outdated + # embeddings due to edits. Here we only do 10% of the limit + posts + .where("#{table_name}.updated_at < ?", 7.days.ago) + .order("random()") + .limit((limit - rebaked) / 10) + .pluck(:id) + .each do |id| + vector_rep.generate_representation_from(Post.find_by(id: id)) rebaked += 1 end diff --git a/db/migrate/20230710171143_migrate_embeddings_from_dedicated_database.rb b/db/migrate/20230710171143_migrate_embeddings_from_dedicated_database.rb index 70eaa864..5bac1c0c 100644 --- a/db/migrate/20230710171143_migrate_embeddings_from_dedicated_database.rb +++ b/db/migrate/20230710171143_migrate_embeddings_from_dedicated_database.rb @@ -14,7 +14,7 @@ class MigrateEmbeddingsFromDedicatedDatabase < ActiveRecord::Migration[7.0] ].map { |k| k.new(truncation) } vector_reps.each do |vector_rep| - new_table_name = vector_rep.table_name + new_table_name = vector_rep.topic_table_name old_table_name = "topic_embeddings_#{vector_rep.name.underscore}" begin diff --git a/db/migrate/20231228213036_create_ai_post_embeddings_tables.rb b/db/migrate/20231228213036_create_ai_post_embeddings_tables.rb new file mode 100644 index 00000000..f6ad3174 --- /dev/null +++ b/db/migrate/20231228213036_create_ai_post_embeddings_tables.rb @@ -0,0 +1,60 @@ +# frozen_string_literal: true + +class CreateAiPostEmbeddingsTables < ActiveRecord::Migration[7.0] + def change + create_table :ai_post_embeddings_1_1, id: false do |t| + t.integer :post_id, null: false + t.integer :model_version, null: false + t.integer :strategy_version, null: false + t.text :digest, null: false + t.column :embeddings, "vector(768)", null: false + t.timestamps + + t.index :post_id, unique: true + end + + create_table :ai_post_embeddings_2_1, id: false do |t| + t.integer :post_id, null: false + t.integer :model_version, null: false + t.integer :strategy_version, null: false + t.text :digest, null: false + t.column :embeddings, "vector(1536)", null: false + t.timestamps + + t.index :post_id, unique: true + end + + create_table :ai_post_embeddings_3_1, id: false do |t| + t.integer :post_id, null: false + t.integer :model_version, null: false + t.integer :strategy_version, null: false + t.text :digest, null: false + t.column :embeddings, "vector(1024)", null: false + t.timestamps + + t.index :post_id, unique: true + end + + create_table :ai_post_embeddings_4_1, id: false do |t| + t.integer :post_id, null: false + t.integer :model_version, null: false + t.integer :strategy_version, null: false + t.text :digest, null: false + t.column :embeddings, "vector(1024)", null: false + t.timestamps + + t.index :post_id, unique: true + end + + create_table :ai_post_embeddings_5_1, id: false do |t| + t.integer :post_id, null: false + t.integer :model_version, null: false + t.integer :strategy_version, null: false + t.text :digest, null: false + t.column :embeddings, "vector(768)", null: false + t.timestamps + + t.index :post_id, unique: true + end + end +end diff --git a/lib/embeddings/entry_point.rb b/lib/embeddings/entry_point.rb index 3ce2551f..bb46b319 100644 --- a/lib/embeddings/entry_point.rb +++ b/lib/embeddings/entry_point.rb @@ -43,14 +43,20 @@ module DiscourseAi # embeddings generation. callback = - Proc.new do |topic| + Proc.new do |target| if SiteSetting.ai_embeddings_enabled - Jobs.enqueue(:generate_embeddings, topic_id: topic.id) + Jobs.enqueue( + :generate_embeddings, + target_id: target.id, + target_type: target.class.name, + ) end end plugin.on(:topic_created, &callback) plugin.on(:topic_edited, &callback) + plugin.on(:post_created, &callback) + plugin.on(:post_edited, &callback) end end end diff --git a/lib/embeddings/strategies/truncation.rb b/lib/embeddings/strategies/truncation.rb index e88693cb..39d8871f 100644 --- a/lib/embeddings/strategies/truncation.rb +++ b/lib/embeddings/strategies/truncation.rb @@ -50,9 +50,9 @@ module DiscourseAi tokenizer.truncate(text, max_length) end - def post_truncation(topic, tokenizer, max_length) + def post_truncation(post, tokenizer, max_length) text = +topic_information(post.topic) - text << post.raw + text << Nokogiri::HTML5.fragment(post.cooked).text tokenizer.truncate(text, max_length) end diff --git a/lib/embeddings/vector_representations/base.rb b/lib/embeddings/vector_representations/base.rb index 3a5cb469..b6e65ffa 100644 --- a/lib/embeddings/vector_representations/base.rb +++ b/lib/embeddings/vector_representations/base.rb @@ -21,62 +21,66 @@ module DiscourseAi end def consider_indexing(memory: "100MB") - # Using extension maintainer's recommendation for ivfflat indexes - # Results are not as good as without indexes, but it's much faster - # Disk usage is ~1x the size of the table, so this doubles table total size - count = DB.query_single("SELECT count(*) FROM #{table_name};").first - lists = [count < 1_000_000 ? count / 1000 : Math.sqrt(count).to_i, 10].max - probes = [count < 1_000_000 ? lists / 10 : Math.sqrt(lists).to_i, 1].max + [topic_table_name, post_table_name].each do |table_name| + index_name = index_name(table_name) + # Using extension maintainer's recommendation for ivfflat indexes + # Results are not as good as without indexes, but it's much faster + # Disk usage is ~1x the size of the table, so this doubles table total size + count = DB.query_single("SELECT count(*) FROM #{table_name};").first + lists = [count < 1_000_000 ? count / 1000 : Math.sqrt(count).to_i, 10].max + probes = [count < 1_000_000 ? lists / 10 : Math.sqrt(lists).to_i, 1].max - existing_index = DB.query_single(<<~SQL, index_name: index_name).first - SELECT - indexdef - FROM - pg_indexes - WHERE - indexname = :index_name - LIMIT 1 - SQL + existing_index = DB.query_single(<<~SQL, index_name: index_name).first + SELECT + indexdef + FROM + pg_indexes + WHERE + indexname = :index_name + LIMIT 1 + SQL - if !existing_index.present? - Rails.logger.info("Index #{index_name} does not exist, creating...") - return create_index!(memory, lists, probes) - end - - existing_index_age = - DB - .query_single( - "SELECT pg_catalog.obj_description((:index_name)::regclass, 'pg_class');", - index_name: index_name, - ) - .first - .to_i || 0 - new_rows = - DB.query_single( - "SELECT count(*) FROM #{table_name} WHERE created_at > '#{Time.at(existing_index_age)}';", - ).first - existing_lists = existing_index.match(/lists='(\d+)'/)&.captures&.first&.to_i - - if existing_index_age > 0 && existing_index_age < 1.hour.ago.to_i - if new_rows > 10_000 - Rails.logger.info( - "Index #{index_name} is #{existing_index_age} seconds old, and there are #{new_rows} new rows, updating...", - ) - return create_index!(memory, lists, probes) - elsif existing_lists != lists - Rails.logger.info( - "Index #{index_name} already exists, but lists is #{existing_lists} instead of #{lists}, updating...", - ) - return create_index!(memory, lists, probes) + if !existing_index.present? + Rails.logger.info("Index #{index_name} does not exist, creating...") + return create_index!(table_name, memory, lists, probes) end - end - Rails.logger.info( - "Index #{index_name} kept. #{Time.now.to_i - existing_index_age} seconds old, #{new_rows} new rows, #{existing_lists} lists, #{probes} probes.", - ) + existing_index_age = + DB + .query_single( + "SELECT pg_catalog.obj_description((:index_name)::regclass, 'pg_class');", + index_name: index_name, + ) + .first + .to_i || 0 + new_rows = + DB.query_single( + "SELECT count(*) FROM #{table_name} WHERE created_at > '#{Time.at(existing_index_age)}';", + ).first + existing_lists = existing_index.match(/lists='(\d+)'/)&.captures&.first&.to_i + + if existing_index_age > 0 && existing_index_age < 1.hour.ago.to_i + if new_rows > 10_000 + Rails.logger.info( + "Index #{index_name} is #{existing_index_age} seconds old, and there are #{new_rows} new rows, updating...", + ) + return create_index!(table_name, memory, lists, probes) + elsif existing_lists != lists + Rails.logger.info( + "Index #{index_name} already exists, but lists is #{existing_lists} instead of #{lists}, updating...", + ) + return create_index!(table_name, memory, lists, probes) + end + end + + Rails.logger.info( + "Index #{index_name} kept. #{Time.now.to_i - existing_index_age} seconds old, #{new_rows} new rows, #{existing_lists} lists, #{probes} probes.", + ) + end end - def create_index!(memory, lists, probes) + def create_index!(table_name, memory, lists, probes) + index_name = index_name(table_name) DB.exec("SET work_mem TO '#{memory}';") DB.exec("SET maintenance_work_mem TO '#{memory}';") DB.exec(<<~SQL) @@ -102,17 +106,17 @@ module DiscourseAi raise NotImplementedError end - def generate_topic_representation_from(target, persist: true) + def generate_representation_from(target, persist: true) text = @strategy.prepare_text_from(target, tokenizer, max_sequence_length - 2) new_digest = OpenSSL::Digest::SHA1.hexdigest(text) - current_digest = DB.query_single(<<~SQL, topic_id: target.id).first + current_digest = DB.query_single(<<~SQL, target_id: target.id).first SELECT digest FROM - #{table_name} + #{table_name(target)} WHERE - topic_id = :topic_id + #{target.is_a?(Topic) ? "topic_id" : "post_id"} = :target_id LIMIT 1 SQL return if current_digest == new_digest @@ -127,7 +131,19 @@ module DiscourseAi SELECT topic_id FROM - #{table_name} + #{topic_table_name} + ORDER BY + embeddings #{pg_function} '[:query_embedding]' + LIMIT 1 + SQL + end + + def post_id_from_representation(raw_vector) + DB.query_single(<<~SQL, query_embedding: raw_vector).first + SELECT + post_id + FROM + #{post_table_name} ORDER BY embeddings #{pg_function} '[:query_embedding]' LIMIT 1 @@ -140,7 +156,7 @@ module DiscourseAi topic_id, embeddings #{pg_function} '[:query_embedding]' AS distance FROM - #{table_name} + #{topic_table_name} ORDER BY embeddings #{pg_function} '[:query_embedding]' LIMIT :limit @@ -162,13 +178,13 @@ module DiscourseAi SELECT topic_id FROM - #{table_name} + #{topic_table_name} ORDER BY embeddings #{pg_function} ( SELECT embeddings FROM - #{table_name} + #{topic_table_name} WHERE topic_id = :topic_id LIMIT 1 @@ -182,11 +198,26 @@ module DiscourseAi raise MissingEmbeddingError end - def table_name + def topic_table_name "ai_topic_embeddings_#{id}_#{@strategy.id}" end - def index_name + def post_table_name + "ai_post_embeddings_#{id}_#{@strategy.id}" + end + + def table_name(target) + case target + when Topic + topic_table_name + when Post + post_table_name + else + raise ArgumentError, "Invalid target type" + end + end + + def index_name(table_name) "#{table_name}_search" end @@ -221,24 +252,47 @@ module DiscourseAi protected def save_to_db(target, vector, digest) - DB.exec( - <<~SQL, - INSERT INTO #{table_name} (topic_id, model_version, strategy_version, digest, embeddings, created_at, updated_at) - VALUES (:topic_id, :model_version, :strategy_version, :digest, '[:embeddings]', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) - ON CONFLICT (topic_id) - DO UPDATE SET - model_version = :model_version, - strategy_version = :strategy_version, - digest = :digest, - embeddings = '[:embeddings]', - updated_at = CURRENT_TIMESTAMP - SQL - topic_id: target.id, - model_version: version, - strategy_version: @strategy.version, - digest: digest, - embeddings: vector, - ) + if target.is_a?(Topic) + DB.exec( + <<~SQL, + INSERT INTO #{topic_table_name} (topic_id, model_version, strategy_version, digest, embeddings, created_at, updated_at) + VALUES (:topic_id, :model_version, :strategy_version, :digest, '[:embeddings]', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) + ON CONFLICT (topic_id) + DO UPDATE SET + model_version = :model_version, + strategy_version = :strategy_version, + digest = :digest, + embeddings = '[:embeddings]', + updated_at = CURRENT_TIMESTAMP + SQL + topic_id: target.id, + model_version: version, + strategy_version: @strategy.version, + digest: digest, + embeddings: vector, + ) + elsif target.is_a?(Post) + DB.exec( + <<~SQL, + INSERT INTO #{post_table_name} (post_id, model_version, strategy_version, digest, embeddings, created_at, updated_at) + VALUES (:post_id, :model_version, :strategy_version, :digest, '[:embeddings]', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) + ON CONFLICT (post_id) + DO UPDATE SET + model_version = :model_version, + strategy_version = :strategy_version, + digest = :digest, + embeddings = '[:embeddings]', + updated_at = CURRENT_TIMESTAMP + SQL + post_id: target.id, + model_version: version, + strategy_version: @strategy.version, + digest: digest, + embeddings: vector, + ) + else + raise ArgumentError, "Invalid target type" + end end end end diff --git a/lib/tasks/modules/embeddings/database.rake b/lib/tasks/modules/embeddings/database.rake index c89e7926..e6249453 100644 --- a/lib/tasks/modules/embeddings/database.rake +++ b/lib/tasks/modules/embeddings/database.rake @@ -1,23 +1,33 @@ # frozen_string_literal: true -desc "Backfill embeddings for all topics" -task "ai:embeddings:backfill", [:start_topic] => [:environment] do |_, args| +desc "Backfill embeddings for all topics and posts" +task "ai:embeddings:backfill" => [:environment] do public_categories = Category.where(read_restricted: false).pluck(:id) strategy = DiscourseAi::Embeddings::Strategies::Truncation.new vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy) - table_name = vector_rep.table_name + table_name = vector_rep.topic_table_name Topic .joins("LEFT JOIN #{table_name} ON #{table_name}.topic_id = topics.id") .where("#{table_name}.topic_id IS NULL") - .where("topics.id >= ?", args[:start_topic].to_i || 0) .where("category_id IN (?)", public_categories) .where(deleted_at: nil) - .order("topics.id ASC") + .order("topics.id DESC") .find_each do |t| print "." - vector_rep.generate_topic_representation_from(t) + vector_rep.generate_representation_from(t) + end + + table_name = vector_rep.post_table_name + Post + .joins("LEFT JOIN #{table_name} ON #{table_name}.post_id = posts.id") + .where("#{table_name}.post_id IS NULL") + .where(deleted_at: nil) + .order("posts.id DESC") + .find_each do |t| + print "." + vector_rep.generate_representation_from(t) end end diff --git a/spec/lib/modules/embeddings/entry_point_spec.rb b/spec/lib/modules/embeddings/entry_point_spec.rb index 4139d704..cc848bf9 100644 --- a/spec/lib/modules/embeddings/entry_point_spec.rb +++ b/spec/lib/modules/embeddings/entry_point_spec.rb @@ -18,7 +18,7 @@ describe DiscourseAi::Embeddings::EntryPoint do it "queues a job on create if embeddings is enabled" do SiteSetting.ai_embeddings_enabled = true - expect { creator.create }.to change(Jobs::GenerateEmbeddings.jobs, :size).by(1) + expect { creator.create }.to change(Jobs::GenerateEmbeddings.jobs, :size).by(2) # topic_created and post_created end it "does nothing if sentiment analysis is disabled" do diff --git a/spec/lib/modules/embeddings/jobs/generate_embeddings_spec.rb b/spec/lib/modules/embeddings/jobs/generate_embeddings_spec.rb index c611c8d8..d8758289 100644 --- a/spec/lib/modules/embeddings/jobs/generate_embeddings_spec.rb +++ b/spec/lib/modules/embeddings/jobs/generate_embeddings_spec.rb @@ -18,7 +18,7 @@ RSpec.describe Jobs::GenerateEmbeddings do DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation) end - it "works" do + it "works for topics" do expected_embedding = [0.0038493] * vector_rep.dimensions text = @@ -29,9 +29,21 @@ RSpec.describe Jobs::GenerateEmbeddings do ) EmbeddingsGenerationStubs.discourse_service(vector_rep.name, text, expected_embedding) - job.execute(topic_id: topic.id) + job.execute(target_id: topic.id, target_type: "Topic") expect(vector_rep.topic_id_from_representation(expected_embedding)).to eq(topic.id) end + + it "works for posts" do + expected_embedding = [0.0038493] * vector_rep.dimensions + + text = + truncation.prepare_text_from(post, vector_rep.tokenizer, vector_rep.max_sequence_length - 2) + EmbeddingsGenerationStubs.discourse_service(vector_rep.name, text, expected_embedding) + + job.execute(target_id: post.id, target_type: "Post") + + expect(vector_rep.post_id_from_representation(expected_embedding)).to eq(post.id) + end end end diff --git a/spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb b/spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb index 7d5cc213..9689a3c6 100644 --- a/spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb +++ b/spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb @@ -12,9 +12,10 @@ RSpec.shared_examples "generates and store embedding using with vector represent end end - describe "#generate_topic_representation_from" do + describe "#generate_representation_from" do fab!(:topic) { Fabricate(:topic) } fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) } + fab!(:post2) { Fabricate(:post, post_number: 2, topic: topic) } it "creates a vector from a topic and stores it in the database" do text = @@ -25,10 +26,24 @@ RSpec.shared_examples "generates and store embedding using with vector represent ) stub_vector_mapping(text, @expected_embedding) - vector_rep.generate_topic_representation_from(topic) + vector_rep.generate_representation_from(topic) expect(vector_rep.topic_id_from_representation(@expected_embedding)).to eq(topic.id) end + + it "creates a vector from a post and stores it in the database" do + text = + truncation.prepare_text_from( + post2, + vector_rep.tokenizer, + vector_rep.max_sequence_length - 2, + ) + stub_vector_mapping(text, @expected_embedding) + + vector_rep.generate_representation_from(post) + + expect(vector_rep.post_id_from_representation(@expected_embedding)).to eq(post.id) + end end describe "#asymmetric_topics_similarity_search" do @@ -44,7 +59,7 @@ RSpec.shared_examples "generates and store embedding using with vector represent vector_rep.max_sequence_length - 2, ) stub_vector_mapping(text, @expected_embedding) - vector_rep.generate_topic_representation_from(topic) + vector_rep.generate_representation_from(topic) expect( vector_rep.asymmetric_topics_similarity_search(similar_vector, limit: 1, offset: 0),