diff --git a/lib/modules/embeddings/jobs/regular/generate_embeddings.rb b/lib/modules/embeddings/jobs/regular/generate_embeddings.rb index 919b43c6..a76f1ad8 100644 --- a/lib/modules/embeddings/jobs/regular/generate_embeddings.rb +++ b/lib/modules/embeddings/jobs/regular/generate_embeddings.rb @@ -13,9 +13,9 @@ module Jobs strategy = DiscourseAi::Embeddings::Strategies::Truncation.new vector_rep = - DiscourseAi::Embeddings::VectorRepresentations::Base.find_vector_representation.new + DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy) - vector_rep.generate_topic_representation_from(topic, strategy) + vector_rep.generate_topic_representation_from(topic) end end end diff --git a/spec/lib/modules/embeddings/jobs/generate_embeddings_spec.rb b/spec/lib/modules/embeddings/jobs/generate_embeddings_spec.rb new file mode 100644 index 00000000..258161ee --- /dev/null +++ b/spec/lib/modules/embeddings/jobs/generate_embeddings_spec.rb @@ -0,0 +1,39 @@ +# frozen_string_literal: true + +require_relative "../../../../support/embeddings_generation_stubs" + +RSpec.describe Jobs::GenerateEmbeddings do + subject(:job) { described_class.new } + + describe "#execute" do + before do + SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" + SiteSetting.ai_embeddings_enabled = true + SiteSetting.ai_embeddings_model = "all-mpnet-base-v2" + end + + fab!(:topic) { Fabricate(:topic) } + fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) } + + let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new } + let(:vector_rep) do + DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation) + end + + it "works" do + expected_embedding = [0.0038493] * vector_rep.dimensions + + text = + truncation.prepare_text_from( + topic, + vector_rep.tokenizer, + vector_rep.max_sequence_length - 2, + ) + EmbeddingsGenerationStubs.discourse_service(vector_rep.name, text, expected_embedding) + + job.execute(topic_id: topic.id) + + expect(vector_rep.topic_id_from_representation(expected_embedding)).to eq(topic.id) + end + end +end