diff --git a/config/settings.yml b/config/settings.yml
index 3b78e2d9..00e33c2e 100644
--- a/config/settings.yml
+++ b/config/settings.yml
@@ -122,6 +122,8 @@ discourse_ai:
     default: 4096
   ai_hugging_face_model_display_name:
     default: ""
+  ai_hugging_face_tei_endpoint:
+    default: ""
   ai_google_custom_search_api_key:
     default: ""
     secret: true
diff --git a/lib/modules/embeddings/vector_representations/bge_large_en.rb b/lib/modules/embeddings/vector_representations/bge_large_en.rb
index bc56f867..dbdd377f 100644
--- a/lib/modules/embeddings/vector_representations/bge_large_en.rb
+++ b/lib/modules/embeddings/vector_representations/bge_large_en.rb
@@ -5,10 +5,23 @@ module DiscourseAi
   module VectorRepresentations
     class BgeLargeEn < Base
       def vector_from(text)
-        DiscourseAi::Inference::CloudflareWorkersAi
-          .perform!(inference_model_name, { text: text })
-          .dig(:result, :data)
-          .first
+        if SiteSetting.ai_cloudflare_workers_api_token.present?
+          DiscourseAi::Inference::CloudflareWorkersAi
+            .perform!(inference_model_name, { text: text })
+            .dig(:result, :data)
+            .first
+        elsif SiteSetting.ai_hugging_face_tei_endpoint.present?
+          DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(text).first
+        elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
+          DiscourseAi::Inference::DiscourseClassifier.perform!(
+            "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
+            inference_model_name.split("/").last,
+            text,
+            SiteSetting.ai_embeddings_discourse_service_api_key,
+          )
+        else
+          raise "No inference endpoint configured"
+        end
       end
 
       def name
diff --git a/lib/modules/embeddings/vector_representations/multilingual_e5_large.rb b/lib/modules/embeddings/vector_representations/multilingual_e5_large.rb
index 30ffb4d8..9789f199 100644
--- a/lib/modules/embeddings/vector_representations/multilingual_e5_large.rb
+++ b/lib/modules/embeddings/vector_representations/multilingual_e5_large.rb
@@ -5,12 +5,18 @@ module DiscourseAi
   module VectorRepresentations
     class MultilingualE5Large < Base
       def vector_from(text)
-        DiscourseAi::Inference::DiscourseClassifier.perform!(
-          "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
-          name,
-          "query: #{text}",
-          SiteSetting.ai_embeddings_discourse_service_api_key,
-        )
+        if SiteSetting.ai_hugging_face_tei_endpoint.present?
+          DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!("query: #{text}").first
+        elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
+          DiscourseAi::Inference::DiscourseClassifier.perform!(
+            "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
+            name,
+            "query: #{text}",
+            SiteSetting.ai_embeddings_discourse_service_api_key,
+          )
+        else
+          raise "No inference endpoint configured"
+        end
       end
 
       def id
diff --git a/lib/shared/inference/hugging_face_text_embeddings.rb b/lib/shared/inference/hugging_face_text_embeddings.rb
new file mode 100644
index 00000000..e8968dde
--- /dev/null
+++ b/lib/shared/inference/hugging_face_text_embeddings.rb
@@ -0,0 +1,20 @@
+# frozen_string_literal: true
+
+module ::DiscourseAi
+  module Inference
+    class HuggingFaceTextEmbeddings
+      def self.perform!(content)
+        headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
+        body = { inputs: content }.to_json
+
+        api_endpoint = SiteSetting.ai_hugging_face_tei_endpoint
+
+        response = Faraday.post(api_endpoint, body, headers)
+
+        raise Net::HTTPBadResponse if response.status != 200
+
+        JSON.parse(response.body, symbolize_names: true)
+      end
+    end
+  end
+end
diff --git a/plugin.rb b/plugin.rb
index e2e410c4..4e415cd2 100644
--- a/plugin.rb
+++ b/plugin.rb
@@ -42,6 +42,7 @@ after_initialize do
   require_relative "lib/shared/inference/hugging_face_text_generation"
   require_relative "lib/shared/inference/amazon_bedrock_inference"
   require_relative "lib/shared/inference/cloudflare_workers_ai"
+  require_relative "lib/shared/inference/hugging_face_text_embeddings"
  require_relative "lib/shared/inference/function"
  require_relative "lib/shared/inference/function_list"