FIX: Restore the accidentally deleted query prefix. (#1079)
Additionally, we add a prefix for embedding generation. Both are stored in the definitions table.
This commit is contained in:
parent
f5cf1019fb
commit
3b66fb3e87
|
@ -111,6 +111,8 @@ module DiscourseAi
|
||||||
:url,
|
:url,
|
||||||
:api_key,
|
:api_key,
|
||||||
:tokenizer_class,
|
:tokenizer_class,
|
||||||
|
:embed_prompt,
|
||||||
|
:search_prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
extra_field_names = EmbeddingDefinition.provider_params.dig(permitted[:provider]&.to_sym)
|
extra_field_names = EmbeddingDefinition.provider_params.dig(permitted[:provider]&.to_sym)
|
||||||
|
|
|
@ -42,6 +42,7 @@ class EmbeddingDefinition < ActiveRecord::Base
|
||||||
pg_function: "<#>",
|
pg_function: "<#>",
|
||||||
tokenizer_class: "DiscourseAi::Tokenizer::BgeLargeEnTokenizer",
|
tokenizer_class: "DiscourseAi::Tokenizer::BgeLargeEnTokenizer",
|
||||||
provider: HUGGING_FACE,
|
provider: HUGGING_FACE,
|
||||||
|
search_prompt: "Represent this sentence for searching relevant passages:",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
preset_id: "bge-m3",
|
preset_id: "bge-m3",
|
||||||
|
@ -228,4 +229,6 @@ end
|
||||||
# provider_params :jsonb
|
# provider_params :jsonb
|
||||||
# created_at :datetime not null
|
# created_at :datetime not null
|
||||||
# updated_at :datetime not null
|
# updated_at :datetime not null
|
||||||
|
# embed_prompt :string default(""), not null
|
||||||
|
# search_prompt :string default(""), not null
|
||||||
#
|
#
|
||||||
|
|
|
@ -13,6 +13,8 @@ class AiEmbeddingDefinitionSerializer < ApplicationSerializer
|
||||||
:api_key,
|
:api_key,
|
||||||
:seeded,
|
:seeded,
|
||||||
:tokenizer_class,
|
:tokenizer_class,
|
||||||
|
:embed_prompt,
|
||||||
|
:search_prompt,
|
||||||
:provider_params
|
:provider_params
|
||||||
|
|
||||||
def api_key
|
def api_key
|
||||||
|
|
|
@ -14,7 +14,9 @@ export default class AiEmbedding extends RestModel {
|
||||||
"api_key",
|
"api_key",
|
||||||
"max_sequence_length",
|
"max_sequence_length",
|
||||||
"provider_params",
|
"provider_params",
|
||||||
"pg_function"
|
"pg_function",
|
||||||
|
"embed_prompt",
|
||||||
|
"search_prompt"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -290,6 +290,24 @@ export default class AiEmbeddingEditor extends Component {
|
||||||
{{/if}}
|
{{/if}}
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<div class="control-group">
|
||||||
|
<label>{{i18n "discourse_ai.embeddings.embed_prompt"}}</label>
|
||||||
|
<Input
|
||||||
|
@type="text"
|
||||||
|
class="ai-embedding-editor-input ai-embedding-editor__embed_prompt"
|
||||||
|
@value={{this.editingModel.embed_prompt}}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="control-group">
|
||||||
|
<label>{{i18n "discourse_ai.embeddings.search_prompt"}}</label>
|
||||||
|
<Input
|
||||||
|
@type="text"
|
||||||
|
class="ai-embedding-editor-input ai-embedding-editor__search_prompt"
|
||||||
|
@value={{this.editingModel.search_prompt}}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
<div class="control-group">
|
<div class="control-group">
|
||||||
<label>{{i18n "discourse_ai.embeddings.max_sequence_length"}}</label>
|
<label>{{i18n "discourse_ai.embeddings.max_sequence_length"}}</label>
|
||||||
<Input
|
<Input
|
||||||
|
|
|
@ -530,6 +530,8 @@ en:
|
||||||
tokenizer: "Tokenizer"
|
tokenizer: "Tokenizer"
|
||||||
dimensions: "Embedding dimensions"
|
dimensions: "Embedding dimensions"
|
||||||
max_sequence_length: "Sequence length"
|
max_sequence_length: "Sequence length"
|
||||||
|
embed_prompt: "Embed prompt"
|
||||||
|
search_prompt: "Search prompt"
|
||||||
|
|
||||||
distance_function: "Distance function"
|
distance_function: "Distance function"
|
||||||
distance_functions:
|
distance_functions:
|
||||||
|
|
|
@ -0,0 +1,18 @@
|
||||||
|
# frozen_string_literal: true

# Adds the two configurable prompt columns to embedding_definitions and
# backfills the search prompt on the row with id 4.
class ConfigurableEmbeddingsPrefixes < ActiveRecord::Migration[7.2]
  def up
    # Both columns are non-null strings that default to an empty string.
    %i[embed_prompt search_prompt].each do |column_name|
      add_column :embedding_definitions, column_name, :string, null: false, default: ""
    end

    # 4 is bge-large-en. Default model and the only one using this so far.
    execute(<<~SQL)
      UPDATE embedding_definitions
      SET search_prompt='Represent this sentence for searching relevant passages:'
      WHERE id = 4
    SQL
  end

  # Rolling back is not supported for this migration.
  def down
    raise ActiveRecord::IrreversibleMigration
  end
end
|
|
@ -15,23 +15,28 @@ module DiscourseAi
|
||||||
# Builds the text that is embedded for +target+, optionally prefixed with the
# definition's embed prompt.
#
# target - a Topic, Post or RagDocumentFragment.
# vdef   - the embedding definition; this method reads its
#          max_sequence_length, tokenizer and embed_prompt.
#
# Returns the truncated text, preceded by "<embed_prompt> " when the prompt
# is not blank.
# Raises ArgumentError when +target+ is none of the supported types.
def prepare_target_text(target, vdef)
  embed_prompt = vdef.embed_prompt
  has_prompt = embed_prompt.present?

  # Two tokens are held back from the model's sequence length
  # (presumably for special tokens -- TODO confirm).
  max_length = vdef.max_sequence_length - 2

  # Reserve room for the prompt up front; otherwise prepending it to text
  # that already fills the budget would push the result over the limit.
  max_length -= vdef.tokenizer.size(embed_prompt) if has_prompt

  prepared_text =
    case target
    when Topic
      topic_truncation(target, vdef.tokenizer, max_length)
    when Post
      post_truncation(target, vdef.tokenizer, max_length)
    when RagDocumentFragment
      vdef.tokenizer.truncate(target.fragment, max_length)
    else
      raise ArgumentError, "Invalid target type"
    end

  return prepared_text unless has_prompt

  [embed_prompt, prepared_text].join(" ")
end
|
||||||
|
|
||||||
# Builds the text that is embedded for a search query.
#
# text      - the raw query string.
# vdef      - the embedding definition; this method reads its
#             max_sequence_length and tokenizer, plus search_prompt when
#             +asymetric+ is true.
# asymetric - when true, the definition's search prompt is prepended to the
#             query (only if that prompt is not blank).
#
# Returns the query, optionally prefixed, truncated by the definition's
# tokenizer to max_sequence_length - 2.
def prepare_query_text(text, vdef, asymetric: false)
  qtext = text

  if asymetric
    search_prompt = vdef.search_prompt
    # Skip blank prompts so the query does not gain a stray leading space.
    qtext = "#{search_prompt} #{text}" if search_prompt.present?
  end

  max_length = vdef.max_sequence_length - 2

  vdef.tokenizer.truncate(qtext, max_length)
end
|
||||||
|
|
||||||
private
|
private
|
||||||
|
|
|
@ -3,29 +3,51 @@
|
||||||
# Specs for the truncation strategy: target text must fit the model's
# sequence length, and the configured prompts must lead the prepared text.
RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do
  subject(:truncation) { described_class.new }

  # Shared by every example below, so it is declared only once.
  fab!(:open_ai_embedding_def)
  let(:prefix) { "I come first:" }

  describe "#prepare_target_text" do
    before { SiteSetting.max_post_length = 100_000 }

    fab!(:topic)

    # Three long posts on the same topic. Each gets its own name so that no
    # declaration shadows another.
    fab!(:long_post_1) do
      Fabricate(:post, topic: topic, raw: "Baby, bird, bird, bird\nBird is the word\n" * 500)
    end
    fab!(:long_post_2) do
      Fabricate(
        :post,
        topic: topic,
        raw: "Don't you know about the bird?\nEverybody knows that the bird is a word\n" * 400,
      )
    end
    fab!(:long_post_3) { Fabricate(:post, topic: topic, raw: "Surfin' bird\n" * 800) }

    it "truncates a topic" do
      prepared_text = truncation.prepare_target_text(topic, open_ai_embedding_def)

      expect(open_ai_embedding_def.tokenizer.size(prepared_text)).to be <=
        open_ai_embedding_def.max_sequence_length
    end

    it "includes embed prefix" do
      open_ai_embedding_def.update!(embed_prompt: prefix)

      prepared_text = truncation.prepare_target_text(topic, open_ai_embedding_def)

      expect(prepared_text).to start_with(prefix)
    end
  end

  describe "#prepare_query_text" do
    context "when search is asymetric" do
      it "includes search prefix" do
        open_ai_embedding_def.update!(search_prompt: prefix)

        prepared_query_text =
          truncation.prepare_query_text("searching", open_ai_embedding_def, asymetric: true)

        expect(prepared_query_text).to start_with(prefix)
      end
    end

    context "when search is not asymetric" do
      it "does not include search prefix" do
        open_ai_embedding_def.update!(search_prompt: prefix)

        prepared_query_text = truncation.prepare_query_text("searching", open_ai_embedding_def)

        expect(prepared_query_text).not_to start_with(prefix)
      end
    end
  end
end
|
||||||
|
|
|
@ -15,6 +15,8 @@ RSpec.describe DiscourseAi::Admin::AiEmbeddingsController do
|
||||||
url: "https://test.com/api/v1/embeddings",
|
url: "https://test.com/api/v1/embeddings",
|
||||||
api_key: "test",
|
api_key: "test",
|
||||||
tokenizer_class: "DiscourseAi::Tokenizer::BgeM3Tokenizer",
|
tokenizer_class: "DiscourseAi::Tokenizer::BgeM3Tokenizer",
|
||||||
|
embed_prompt: "I come first:",
|
||||||
|
search_prompt: "prefix for search",
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -27,6 +29,8 @@ RSpec.describe DiscourseAi::Admin::AiEmbeddingsController do
|
||||||
|
|
||||||
expect(response.status).to eq(201)
|
expect(response.status).to eq(201)
|
||||||
expect(created_def.display_name).to eq(valid_attrs[:display_name])
|
expect(created_def.display_name).to eq(valid_attrs[:display_name])
|
||||||
|
expect(created_def.embed_prompt).to eq(valid_attrs[:embed_prompt])
|
||||||
|
expect(created_def.search_prompt).to eq(valid_attrs[:search_prompt])
|
||||||
end
|
end
|
||||||
|
|
||||||
it "stores provider-specific config params" do
|
it "stores provider-specific config params" do
|
||||||
|
|
|
@ -61,6 +61,11 @@ RSpec.describe "Managing Embeddings configurations", type: :system, js: true do
|
||||||
select_kit.expand
|
select_kit.expand
|
||||||
select_kit.select_row_by_value("DiscourseAi::Tokenizer::OpenAiTokenizer")
|
select_kit.select_row_by_value("DiscourseAi::Tokenizer::OpenAiTokenizer")
|
||||||
|
|
||||||
|
embed_prefix = "On creation:"
|
||||||
|
search_prefix = "On search:"
|
||||||
|
find("input.ai-embedding-editor__embed_prompt").fill_in(with: embed_prefix)
|
||||||
|
find("input.ai-embedding-editor__search_prompt").fill_in(with: search_prefix)
|
||||||
|
|
||||||
find("input.ai-embedding-editor__dimensions").fill_in(with: 1536)
|
find("input.ai-embedding-editor__dimensions").fill_in(with: 1536)
|
||||||
find("input.ai-embedding-editor__max_sequence_length").fill_in(with: 8191)
|
find("input.ai-embedding-editor__max_sequence_length").fill_in(with: 8191)
|
||||||
|
|
||||||
|
@ -83,5 +88,7 @@ RSpec.describe "Managing Embeddings configurations", type: :system, js: true do
|
||||||
expect(embedding_def.max_sequence_length).to eq(preset[:max_sequence_length])
|
expect(embedding_def.max_sequence_length).to eq(preset[:max_sequence_length])
|
||||||
expect(embedding_def.pg_function).to eq(preset[:pg_function])
|
expect(embedding_def.pg_function).to eq(preset[:pg_function])
|
||||||
expect(embedding_def.provider).to eq(preset[:provider])
|
expect(embedding_def.provider).to eq(preset[:provider])
|
||||||
|
expect(embedding_def.embed_prompt).to eq(embed_prefix)
|
||||||
|
expect(embedding_def.search_prompt).to eq(search_prefix)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
Loading…
Reference in New Issue