FIX: Fix embeddings truncation strategy (#139)
This commit is contained in:
parent
525c8b0913
commit
0738f67fa4
|
@ -22,17 +22,17 @@ module DiscourseAi
|
||||||
@model = model
|
@model = model
|
||||||
@target = target
|
@target = target
|
||||||
@tokenizer = @model.tokenizer
|
@tokenizer = @model.tokenizer
|
||||||
@max_length = @model.max_sequence_length
|
@max_length = @model.max_sequence_length - 2
|
||||||
@processed_target = +""
|
@processed_target = nil
|
||||||
end
|
end
|
||||||
|
|
||||||
# Need a better name for this method
|
# Need a better name for this method
|
||||||
def process!
|
def process!
|
||||||
case @target
|
case @target
|
||||||
when Topic
|
when Topic
|
||||||
topic_truncation(@target)
|
@processed_target = topic_truncation(@target)
|
||||||
when Post
|
when Post
|
||||||
post_truncation(@target)
|
@processed_target = post_truncation(@target)
|
||||||
else
|
else
|
||||||
raise ArgumentError, "Invalid target type"
|
raise ArgumentError, "Invalid target type"
|
||||||
end
|
end
|
||||||
|
@ -41,7 +41,7 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
def topic_truncation(topic)
|
def topic_truncation(topic)
|
||||||
t = @processed_target
|
t = +""
|
||||||
|
|
||||||
t << topic.title
|
t << topic.title
|
||||||
t << "\n\n"
|
t << "\n\n"
|
||||||
|
@ -54,7 +54,7 @@ module DiscourseAi
|
||||||
|
|
||||||
topic.posts.find_each do |post|
|
topic.posts.find_each do |post|
|
||||||
t << post.raw
|
t << post.raw
|
||||||
break if @tokenizer.size(t) >= @max_length
|
break if @tokenizer.size(t) >= @max_length #maybe keep a partial counter to speed this up?
|
||||||
t << "\n\n"
|
t << "\n\n"
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -62,7 +62,7 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
def post_truncation(post)
|
def post_truncation(post)
|
||||||
t = processed_target
|
t = +""
|
||||||
|
|
||||||
t << post.topic.title
|
t << post.topic.title
|
||||||
t << "\n\n"
|
t << "\n\n"
|
||||||
|
|
|
@ -0,0 +1,31 @@
|
||||||
|
# frozen_string_literal: true

# Checks that the Truncation strategy keeps the processed text of a topic
# within the token budget of the embeddings model it is given.
RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do
  describe "#process!" do
    context "when the model uses OpenAI to create embeddings" do
      # Raise the post length limit so the oversized fixture posts below can be created.
      before { SiteSetting.max_post_length = 100_000 }

      fab!(:topic) { Fabricate(:topic) }

      # Each fixture post has its own name. Declaring all three under the same
      # name makes the later declarations shadow the earlier ones, which hides
      # the intent of having three distinct large posts in the topic.
      fab!(:bird_post) do
        Fabricate(:post, topic: topic, raw: "Baby, bird, bird, bird\nBird is the word\n" * 500)
      end

      fab!(:word_post) do
        Fabricate(
          :post,
          topic: topic,
          raw: "Don't you know about the bird?\nEverybody knows that the bird is a word\n" * 400,
        )
      end

      fab!(:surfin_post) { Fabricate(:post, topic: topic, raw: "Surfin' bird\n" * 800) }

      # NOTE(review): the model is picked at random from every Base descendant,
      # so a single run is not guaranteed to exercise an OpenAI model and a
      # failure may not reproduce on the next run — confirm this is intended.
      let(:model) { DiscourseAi::Embeddings::Models::Base.descendants.sample }
      let(:truncation) { described_class.new(topic, model) }

      it "truncates a topic" do
        truncation.process!

        # The processed text must fit in the model's maximum sequence length.
        expect(model.tokenizer.size(truncation.processed_target)).to be <= model.max_sequence_length
      end
    end
  end
end
|
Loading…
Reference in New Issue