UX: Re-introduce embedding settings validations (#457)

* Revert "Revert "UX: Validate embeddings settings (#455)" (#456)"

This reverts commit 392e2e8aef.

* Restore previous default
This commit is contained in:
Roman Rizzi 2024-02-01 16:54:09 -03:00 committed by GitHub
parent 392e2e8aef
commit fba9c1bf2c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 278 additions and 52 deletions

View File

@ -251,3 +251,12 @@ en:
configuration_hint: configuration_hint:
one: "Make sure the `%{settings}` setting was configured." one: "Make sure the `%{settings}` setting was configured."
other: "Make sure these settings were configured: %{settings}" other: "Make sure these settings were configured: %{settings}"
embeddings:
configuration:
disable_embeddings: "You have to disable 'ai embeddings enabled' first."
choose_model: "Set 'ai embeddings model' first."
model_unreachable: "We failed to generate a test embedding with this model. Check your settings are correct."
hint:
one: "Make sure the `%{settings}` setting was configured."
other: "Make sure the settings of the provider you want were configured. Options are: %{settings}"

View File

@ -216,6 +216,7 @@ discourse_ai:
ai_embeddings_enabled: ai_embeddings_enabled:
default: false default: false
client: true client: true
validator: "DiscourseAi::Configuration::EmbeddingsModuleValidator"
ai_embeddings_discourse_service_api_endpoint: "" ai_embeddings_discourse_service_api_endpoint: ""
ai_embeddings_discourse_service_api_endpoint_srv: ai_embeddings_discourse_service_api_endpoint_srv:
default: "" default: ""
@ -225,7 +226,6 @@ discourse_ai:
secret: true secret: true
ai_embeddings_model: ai_embeddings_model:
type: enum type: enum
list_type: compact
default: "bge-large-en" default: "bge-large-en"
allow_any: false allow_any: false
choices: choices:
@ -236,6 +236,7 @@ discourse_ai:
- multilingual-e5-large - multilingual-e5-large
- bge-large-en - bge-large-en
- gemini - gemini
validator: "DiscourseAi::Configuration::EmbeddingsModelValidator"
ai_embeddings_per_post_enabled: ai_embeddings_per_post_enabled:
default: false default: false
hidden: true hidden: true

View File

@ -0,0 +1,46 @@
# frozen_string_literal: true
module DiscourseAi
  module Configuration
    # Site-setting validator for `ai_embeddings_model`.
    #
    # A model name is accepted when a vector representation with that name
    # exists, reports itself as correctly configured, and yields a non-blank
    # result for a short probe text. Validation is skipped entirely in the
    # test environment.
    class EmbeddingsModelValidator
      # @param opts [Hash] validator options; stored but not read by this class.
      def initialize(opts = {})
        @opts = opts
      end

      # Decides whether +val+ may be stored as the embeddings model.
      #
      # On failure, records why (misconfigured representation or unreachable
      # model) so that #error_message can report it.
      #
      # @param val [String] candidate model name.
      # @return [Boolean]
      def valid_value?(val)
        return true if Rails.env.test?

        model_klass = DiscourseAi::Embeddings::VectorRepresentations::Base.find_representation(val)

        if model_klass.nil?
          # Unknown model name: rejected without recording a reason.
          false
        elsif !model_klass.correctly_configured?
          @representation = model_klass
          false
        elsif can_generate_embeddings?(val)
          true
        else
          @unreachable = true
          false
        end
      end

      # Message describing the last failure recorded by #valid_value?.
      #
      # @return [String, nil] nil when no failure reason was recorded.
      def error_message
        if @unreachable
          I18n.t("discourse_ai.embeddings.configuration.model_unreachable")
        else
          @representation&.configuration_hint
        end
      end

      # Builds the representation named +val+ with a truncation strategy and
      # asks it to embed a fixed probe sentence.
      #
      # @param val [String] model name; must resolve to a known representation.
      # @return [Boolean] true when the returned vector is present.
      def can_generate_embeddings?(val)
        model_klass = DiscourseAi::Embeddings::VectorRepresentations::Base.find_representation(val)
        truncation = DiscourseAi::Embeddings::Strategies::Truncation.new

        model_klass.new(truncation).vector_from("this is a test").present?
      end
    end
  end
end

View File

@ -0,0 +1,51 @@
# frozen_string_literal: true
module DiscourseAi
  module Configuration
    # Site-setting validator for `ai_embeddings_enabled`.
    #
    # Turning the module on is only allowed when an embeddings model has been
    # chosen, its vector representation reports itself as correctly
    # configured, and a probe embedding can be generated with it.
    class EmbeddingsModuleValidator
      # @param opts [Hash] validator options; stored but not read by this class.
      def initialize(opts = {})
        @opts = opts
      end

      # Decides whether the module may be switched to +val+.
      #
      # On failure, records the reason so #error_message can report it.
      #
      # @param val [String] the new setting value.
      # @return [Boolean]
      def valid_value?(val)
        # Clear any reason left over from a previous call so a reused
        # instance never reports a stale failure.
        @representation = nil
        @unreachable = false
        @model_missing = false

        # NOTE(review): "f" is presumably the serialized "disabled" value;
        # disabling never needs validation.
        return true if val == "f"
        return true if Rails.env.test?

        chosen_model = SiteSetting.ai_embeddings_model
        # blank? also catches an empty string, which could never resolve to
        # a representation anyway.
        if chosen_model.blank?
          @model_missing = true
          return false
        end

        representation =
          DiscourseAi::Embeddings::VectorRepresentations::Base.find_representation(chosen_model)
        return false if representation.nil?

        if !representation.correctly_configured?
          @representation = representation
          return false
        end

        if !can_generate_embeddings?(chosen_model)
          @unreachable = true
          return false
        end

        true
      end

      # Message describing the last failure recorded by #valid_value?.
      #
      # @return [String, nil] nil when no failure reason was recorded
      #   (e.g. the chosen model name matches no known representation).
      def error_message
        return I18n.t("discourse_ai.embeddings.configuration.choose_model") if @model_missing
        return I18n.t("discourse_ai.embeddings.configuration.model_unreachable") if @unreachable

        @representation&.configuration_hint
      end

      # Builds the representation named +val+ with a truncation strategy and
      # asks it to embed a fixed probe sentence.
      #
      # @param val [String] model name; must resolve to a known representation.
      # @return [Boolean] true when the returned vector is present.
      def can_generate_embeddings?(val)
        DiscourseAi::Embeddings::VectorRepresentations::Base
          .find_representation(val)
          .new(DiscourseAi::Embeddings::Strategies::Truncation.new)
          .vector_from("this is a test")
          .present?
      end
    end
  end
end

View File

@ -4,19 +4,34 @@ module DiscourseAi
module Embeddings module Embeddings
module VectorRepresentations module VectorRepresentations
class AllMpnetBaseV2 < Base class AllMpnetBaseV2 < Base
class << self
def name
"all-mpnet-base-v2"
end
def correctly_configured?
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
end
def dependant_setting_names
%w[
ai_embeddings_discourse_service_api_key
ai_embeddings_discourse_service_api_endpoint_srv
ai_embeddings_discourse_service_api_endpoint
]
end
end
def vector_from(text) def vector_from(text)
DiscourseAi::Inference::DiscourseClassifier.perform!( DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{discourse_embeddings_endpoint}/api/v1/classify", "#{discourse_embeddings_endpoint}/api/v1/classify",
name, self.class.name,
text, text,
SiteSetting.ai_embeddings_discourse_service_api_key, SiteSetting.ai_embeddings_discourse_service_api_key,
) )
end end
def name
"all-mpnet-base-v2"
end
def dimensions def dimensions
768 768
end end

View File

@ -4,18 +4,41 @@ module DiscourseAi
module Embeddings module Embeddings
module VectorRepresentations module VectorRepresentations
class Base class Base
def self.current_representation(strategy) class << self
# we are explicit here cause the loader may have not def find_representation(model_name)
# loaded the subclasses yet # we are explicit here cause the loader may have not
[ # loaded the subclasses yet
DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2, [
DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn, DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2,
DiscourseAi::Embeddings::VectorRepresentations::Gemini, DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn,
DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large, DiscourseAi::Embeddings::VectorRepresentations::Gemini,
DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002, DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large,
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small, DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002,
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large, DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small,
].map { _1.new(strategy) }.find { _1.name == SiteSetting.ai_embeddings_model } DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large,
].find { _1.name == model_name }
end
def current_representation(strategy)
find_representation(SiteSetting.ai_embeddings_model).new(strategy)
end
def correctly_configured?
raise NotImplementedError
end
def dependant_setting_names
raise NotImplementedError
end
def configuration_hint
settings = dependant_setting_names
I18n.t(
"discourse_ai.embeddings.configuration.hint",
settings: settings.join(", "),
count: settings.length,
)
end
end end
def initialize(strategy) def initialize(strategy)

View File

@ -4,6 +4,32 @@ module DiscourseAi
module Embeddings module Embeddings
module VectorRepresentations module VectorRepresentations
class BgeLargeEn < Base class BgeLargeEn < Base
class << self
def name
"bge-large-en"
end
def correctly_configured?
SiteSetting.ai_cloudflare_workers_api_token.present? ||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? ||
(
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
)
end
def dependant_setting_names
%w[
ai_cloudflare_workers_api_token
ai_hugging_face_tei_endpoint_srv
ai_hugging_face_tei_endpoint
ai_embeddings_discourse_service_api_key
ai_embeddings_discourse_service_api_endpoint_srv
ai_embeddings_discourse_service_api_endpoint
]
end
end
def vector_from(text) def vector_from(text)
if SiteSetting.ai_cloudflare_workers_api_token.present? if SiteSetting.ai_cloudflare_workers_api_token.present?
DiscourseAi::Inference::CloudflareWorkersAi DiscourseAi::Inference::CloudflareWorkersAi
@ -25,10 +51,6 @@ module DiscourseAi
end end
end end
def name
"bge-large-en"
end
def inference_model_name def inference_model_name
"baai/bge-large-en-v1.5" "baai/bge-large-en-v1.5"
end end

View File

@ -4,6 +4,20 @@ module DiscourseAi
module Embeddings module Embeddings
module VectorRepresentations module VectorRepresentations
class Gemini < Base class Gemini < Base
class << self
def name
"gemini"
end
def correctly_configured?
SiteSetting.ai_gemini_api_key.present?
end
def dependant_setting_names
%w[ai_gemini_api_key]
end
end
def id def id
5 5
end end
@ -12,10 +26,6 @@ module DiscourseAi
1 1
end end
def name
"gemini"
end
def dimensions def dimensions
768 768
end end

View File

@ -4,6 +4,30 @@ module DiscourseAi
module Embeddings module Embeddings
module VectorRepresentations module VectorRepresentations
class MultilingualE5Large < Base class MultilingualE5Large < Base
class << self
def name
"multilingual-e5-large"
end
def correctly_configured?
DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? ||
(
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
)
end
def dependant_setting_names
%w[
ai_hugging_face_tei_endpoint_srv
ai_hugging_face_tei_endpoint
ai_embeddings_discourse_service_api_key
ai_embeddings_discourse_service_api_endpoint_srv
ai_embeddings_discourse_service_api_endpoint
]
end
end
def vector_from(text) def vector_from(text)
if DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? if DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
truncated_text = tokenizer.truncate(text, max_sequence_length - 2) truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
@ -11,7 +35,7 @@ module DiscourseAi
elsif discourse_embeddings_endpoint.present? elsif discourse_embeddings_endpoint.present?
DiscourseAi::Inference::DiscourseClassifier.perform!( DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{discourse_embeddings_endpoint}/api/v1/classify", "#{discourse_embeddings_endpoint}/api/v1/classify",
name, self.class.name,
"query: #{text}", "query: #{text}",
SiteSetting.ai_embeddings_discourse_service_api_key, SiteSetting.ai_embeddings_discourse_service_api_key,
) )
@ -28,10 +52,6 @@ module DiscourseAi
1 1
end end
def name
"multilingual-e5-large"
end
def dimensions def dimensions
1024 1024
end end

View File

@ -4,6 +4,20 @@ module DiscourseAi
module Embeddings module Embeddings
module VectorRepresentations module VectorRepresentations
class TextEmbedding3Large < Base class TextEmbedding3Large < Base
class << self
def name
"text-embedding-3-large"
end
def correctly_configured?
SiteSetting.ai_openai_api_key.present?
end
def dependant_setting_names
%w[ai_openai_api_key]
end
end
def id def id
7 7
end end
@ -12,10 +26,6 @@ module DiscourseAi
1 1
end end
def name
"text-embedding-3-large"
end
def dimensions def dimensions
# real dimensions are 3072, but we only support up to 2000 in the # real dimensions are 3072, but we only support up to 2000 in the
# indexes, so we downsample to 2000 via API # indexes, so we downsample to 2000 via API
@ -38,7 +48,7 @@ module DiscourseAi
response = response =
DiscourseAi::Inference::OpenAiEmbeddings.perform!( DiscourseAi::Inference::OpenAiEmbeddings.perform!(
text, text,
model: name, model: self.class.name,
dimensions: dimensions, dimensions: dimensions,
) )
response[:data].first[:embedding] response[:data].first[:embedding]

View File

@ -4,6 +4,20 @@ module DiscourseAi
module Embeddings module Embeddings
module VectorRepresentations module VectorRepresentations
class TextEmbedding3Small < Base class TextEmbedding3Small < Base
class << self
def name
"text-embedding-3-small"
end
def correctly_configured?
SiteSetting.ai_openai_api_key.present?
end
def dependant_setting_names
%w[ai_openai_api_key]
end
end
def id def id
6 6
end end
@ -12,10 +26,6 @@ module DiscourseAi
1 1
end end
def name
"text-embedding-3-small"
end
def dimensions def dimensions
1536 1536
end end
@ -33,7 +43,7 @@ module DiscourseAi
end end
def vector_from(text) def vector_from(text)
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: name) response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: self.class.name)
response[:data].first[:embedding] response[:data].first[:embedding]
end end

View File

@ -4,6 +4,20 @@ module DiscourseAi
module Embeddings module Embeddings
module VectorRepresentations module VectorRepresentations
class TextEmbeddingAda002 < Base class TextEmbeddingAda002 < Base
class << self
def name
"text-embedding-ada-002"
end
def correctly_configured?
SiteSetting.ai_openai_api_key.present?
end
def dependant_setting_names
%w[ai_openai_api_key]
end
end
def id def id
2 2
end end
@ -12,10 +26,6 @@ module DiscourseAi
1 1
end end
def name
"text-embedding-ada-002"
end
def dimensions def dimensions
1536 1536
end end
@ -33,7 +43,7 @@ module DiscourseAi
end end
def vector_from(text) def vector_from(text)
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: name) response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: self.class.name)
response[:data].first[:embedding] response[:data].first[:embedding]
end end

View File

@ -7,7 +7,6 @@ RSpec.describe Jobs::GenerateEmbeddings do
before do before do
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
SiteSetting.ai_embeddings_enabled = true SiteSetting.ai_embeddings_enabled = true
SiteSetting.ai_embeddings_model = "bge-large-en"
end end
fab!(:topic) { Fabricate(:topic) } fab!(:topic) { Fabricate(:topic) }
@ -27,7 +26,7 @@ RSpec.describe Jobs::GenerateEmbeddings do
vector_rep.tokenizer, vector_rep.tokenizer,
vector_rep.max_sequence_length - 2, vector_rep.max_sequence_length - 2,
) )
EmbeddingsGenerationStubs.discourse_service(vector_rep.name, text, expected_embedding) EmbeddingsGenerationStubs.discourse_service(vector_rep.class.name, text, expected_embedding)
job.execute(target_id: topic.id, target_type: "Topic") job.execute(target_id: topic.id, target_type: "Topic")
@ -39,7 +38,7 @@ RSpec.describe Jobs::GenerateEmbeddings do
text = text =
truncation.prepare_text_from(post, vector_rep.tokenizer, vector_rep.max_sequence_length - 2) truncation.prepare_text_from(post, vector_rep.tokenizer, vector_rep.max_sequence_length - 2)
EmbeddingsGenerationStubs.discourse_service(vector_rep.name, text, expected_embedding) EmbeddingsGenerationStubs.discourse_service(vector_rep.class.name, text, expected_embedding)
job.execute(target_id: post.id, target_type: "Post") job.execute(target_id: post.id, target_type: "Post")

View File

@ -10,7 +10,7 @@ RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2 do
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" } before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
def stub_vector_mapping(text, expected_embedding) def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.discourse_service(vector_rep.name, text, expected_embedding) EmbeddingsGenerationStubs.discourse_service(described_class.name, text, expected_embedding)
end end
it_behaves_like "generates and store embedding using with vector representation" it_behaves_like "generates and store embedding using with vector representation"

View File

@ -11,7 +11,7 @@ RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Lar
def stub_vector_mapping(text, expected_embedding) def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.discourse_service( EmbeddingsGenerationStubs.discourse_service(
vector_rep.name, described_class.name,
"query: #{text}", "query: #{text}",
expected_embedding, expected_embedding,
) )

View File

@ -8,7 +8,7 @@ RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda0
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new } let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
def stub_vector_mapping(text, expected_embedding) def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.openai_service(vector_rep.name, text, expected_embedding) EmbeddingsGenerationStubs.openai_service(described_class.name, text, expected_embedding)
end end
it_behaves_like "generates and store embedding using with vector representation" it_behaves_like "generates and store embedding using with vector representation"