diff --git a/config/settings.yml b/config/settings.yml index 757f1871..25cdd20d 100644 --- a/config/settings.yml +++ b/config/settings.yml @@ -128,6 +128,9 @@ discourse_ai: default: "" ai_hugging_face_tei_endpoint: default: "" + ai_hugging_face_tei_endpoint_srv: + default: "" + hidden: true ai_google_custom_search_api_key: default: "" secret: true diff --git a/lib/embeddings/vector_representations/bge_large_en.rb b/lib/embeddings/vector_representations/bge_large_en.rb index 19a25ad1..af600b7e 100644 --- a/lib/embeddings/vector_representations/bge_large_en.rb +++ b/lib/embeddings/vector_representations/bge_large_en.rb @@ -10,7 +10,7 @@ module DiscourseAi .perform!(inference_model_name, { text: text }) .dig(:result, :data) .first - elsif SiteSetting.ai_hugging_face_tei_endpoint.present? + elsif DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? truncated_text = tokenizer.truncate(text, max_sequence_length - 2) DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint.present? diff --git a/lib/embeddings/vector_representations/multilingual_e5_large.rb b/lib/embeddings/vector_representations/multilingual_e5_large.rb index e968f060..d7fcab4d 100644 --- a/lib/embeddings/vector_representations/multilingual_e5_large.rb +++ b/lib/embeddings/vector_representations/multilingual_e5_large.rb @@ -5,7 +5,7 @@ module DiscourseAi module VectorRepresentations class MultilingualE5Large < Base def vector_from(text) - if SiteSetting.ai_hugging_face_tei_endpoint.present? + if DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? truncated_text = tokenizer.truncate(text, max_sequence_length - 2) DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint.present? diff --git a/lib/inference/hugging_face_text_embeddings.rb b/lib/inference/hugging_face_text_embeddings.rb index 47dd3f27..fbc4ddec 100644 --- a/lib/inference/hugging_face_text_embeddings.rb +++ b/lib/inference/hugging_face_text_embeddings.rb @@ -7,7 +7,12 @@ module ::DiscourseAi headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" } body = { inputs: content, truncate: true }.to_json - api_endpoint = SiteSetting.ai_hugging_face_tei_endpoint + if SiteSetting.ai_hugging_face_tei_endpoint_srv.present? + service = DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_hugging_face_tei_endpoint_srv) + api_endpoint = "https://#{service.target}:#{service.port}" + else + api_endpoint = SiteSetting.ai_hugging_face_tei_endpoint + end response = Faraday.post(api_endpoint, body, headers) @@ -15,6 +20,11 @@ module ::DiscourseAi JSON.parse(response.body, symbolize_names: true) end + + def self.configured? + SiteSetting.ai_hugging_face_tei_endpoint.present? || + SiteSetting.ai_hugging_face_tei_endpoint_srv.present? + end end end end diff --git a/lib/utils/dns_srv.rb b/lib/utils/dns_srv.rb new file mode 100644 index 00000000..ddc4038d --- /dev/null +++ b/lib/utils/dns_srv.rb @@ -0,0 +1,45 @@ +# frozen_string_literal: true + +require "resolv" + +module DiscourseAi + module Utils + module DnsSrv + def self.lookup(domain) + Discourse + .cache + .fetch("dns_srv_lookup:#{domain}", expires_in: 5.minutes) do + resources = dns_srv_lookup_for_domain(domain) + + select_server(resources) + end + end + + private + + def self.dns_srv_lookup_for_domain(domain) + resolver = Resolv::DNS.new + resources = resolver.getresources(domain, Resolv::DNS::Resource::IN::SRV) + end + + def self.select_server(resources) + priority = resources.group_by(&:priority).keys.min + + priority_resources = resources.select { |r| r.priority == priority } + + total_weight = priority_resources.map(&:weight).sum + + random_weight = rand(total_weight) + + priority_resources.each do |resource| + random_weight -= resource.weight + + return resource if random_weight < 0 + end + + # fallback + resources.first + end + end + end +end diff --git a/spec/lib/utils/dns_srv_spec.rb b/spec/lib/utils/dns_srv_spec.rb new file mode 100644 index 00000000..33a2115a --- /dev/null +++ b/spec/lib/utils/dns_srv_spec.rb @@ -0,0 +1,34 @@ +# frozen_string_literal: true + +describe DiscourseAi::Utils::DnsSrv do + let(:domain) { "example.com" } + let(:weighted_dns_results) do + [ + Resolv::DNS::Resource::IN::SRV.new(1, 1, 443, "service1.example.com"), + Resolv::DNS::Resource::IN::SRV.new(1, 2, 443, "service2.example.com"), + Resolv::DNS::Resource::IN::SRV.new(1, 2, 443, "service3.example.com"), + Resolv::DNS::Resource::IN::SRV.new(2, 1, 443, "service4.example.com"), + Resolv::DNS::Resource::IN::SRV.new(2, 1, 443, "service5.example.com"), + ] + end + + context "when there are several servers with the same priority" do + before do + Resolv::DNS.any_instance.stubs(:getresources).returns(weighted_dns_results) + + Discourse.cache.delete("dns_srv_lookup:#{domain}") + end + + it "picks a server" do + selected_server = DiscourseAi::Utils::DnsSrv.lookup(domain) + + expect(weighted_dns_results).to include(selected_server) + expect(selected_server.port).to eq(443) + end + + it "doesn't pick a server with lower priority" do + selected_server = DiscourseAi::Utils::DnsSrv.lookup(domain) + expect(weighted_dns_results.filter { |r| r.priority == 1 }).to include(selected_server) + end + end +end