diff --git a/lib/modules/embeddings/strategies/truncation.rb b/lib/modules/embeddings/strategies/truncation.rb
index 8a9b9fa4..f7d76340 100644
--- a/lib/modules/embeddings/strategies/truncation.rb
+++ b/lib/modules/embeddings/strategies/truncation.rb
@@ -22,17 +22,17 @@ module DiscourseAi
           @model = model
           @target = target
           @tokenizer = @model.tokenizer
-          @max_length = @model.max_sequence_length
-          @processed_target = +""
+          @max_length = @model.max_sequence_length - 2
+          @processed_target = nil
         end
 
         # Need a better name for this method
         def process!
           case @target
           when Topic
-            topic_truncation(@target)
+            @processed_target = topic_truncation(@target)
           when Post
-            post_truncation(@target)
+            @processed_target = post_truncation(@target)
           else
             raise ArgumentError, "Invalid target type"
           end
@@ -41,7 +41,7 @@ module DiscourseAi
         end
 
         def topic_truncation(topic)
-          t = @processed_target
+          t = +""
 
           t << topic.title
           t << "\n\n"
@@ -54,7 +54,7 @@ module DiscourseAi
           topic.posts.find_each do |post|
             t << post.raw
 
-            break if @tokenizer.size(t) >= @max_length
+            break if @tokenizer.size(t) >= @max_length #maybe keep a partial counter to speed this up?
 
             t << "\n\n"
           end
@@ -62,7 +62,7 @@ module DiscourseAi
         end
 
         def post_truncation(post)
-          t = processed_target
+          t = +""
 
           t << post.topic.title
           t << "\n\n"
diff --git a/spec/lib/modules/embeddings/strategies/truncation_spec.rb b/spec/lib/modules/embeddings/strategies/truncation_spec.rb
new file mode 100644
index 00000000..c25ade73
--- /dev/null
+++ b/spec/lib/modules/embeddings/strategies/truncation_spec.rb
@@ -0,0 +1,31 @@
+# frozen_string_literal: true
+
+RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do
+  describe "#process!" do
+    context "when the model uses OpenAI to create embeddings" do
+      before { SiteSetting.max_post_length = 100_000 }
+
+      fab!(:topic) { Fabricate(:topic) }
+      fab!(:post) do
+        Fabricate(:post, topic: topic, raw: "Baby, bird, bird, bird\nBird is the word\n" * 500)
+      end
+      fab!(:post) do
+        Fabricate(
+          :post,
+          topic: topic,
+          raw: "Don't you know about the bird?\nEverybody knows that the bird is a word\n" * 400,
+        )
+      end
+      fab!(:post) { Fabricate(:post, topic: topic, raw: "Surfin' bird\n" * 800) }
+
+      let(:model) { DiscourseAi::Embeddings::Models::Base.descendants.sample(1).first }
+      let(:truncation) { described_class.new(topic, model) }
+
+      it "truncates a topic" do
+        truncation.process!
+
+        expect(model.tokenizer.size(truncation.processed_target)).to be <= model.max_sequence_length
+      end
+    end
+  end
+end