DEV: Build sentiment clients outside of promises (#1117)
This commit is contained in:
parent
e52045ebdc
commit
90bcb8b503
|
@ -12,11 +12,6 @@ module ::DiscourseAi
|
||||||
attr_reader :endpoint, :key, :referer
|
attr_reader :endpoint, :key, :referer
|
||||||
|
|
||||||
class << self
|
class << self
|
||||||
def configured?
|
|
||||||
SiteSetting.ai_hugging_face_tei_endpoint.present? ||
|
|
||||||
SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
|
|
||||||
end
|
|
||||||
|
|
||||||
def reranker_configured?
|
def reranker_configured?
|
||||||
SiteSetting.ai_hugging_face_tei_reranker_endpoint.present? ||
|
SiteSetting.ai_hugging_face_tei_reranker_endpoint.present? ||
|
||||||
SiteSetting.ai_hugging_face_tei_reranker_endpoint_srv.present?
|
SiteSetting.ai_hugging_face_tei_reranker_endpoint_srv.present?
|
||||||
|
@ -50,32 +45,23 @@ module ::DiscourseAi
|
||||||
|
|
||||||
JSON.parse(response.body, symbolize_names: true)
|
JSON.parse(response.body, symbolize_names: true)
|
||||||
end
|
end
|
||||||
|
end
|
||||||
|
|
||||||
def classify(content, model_config, base_url = Discourse.base_url)
|
def classify_by_sentiment!(content)
|
||||||
headers = { "Referer" => base_url, "Content-Type" => "application/json" }
|
response = do_request!(content)
|
||||||
headers["X-API-KEY"] = model_config.api_key
|
|
||||||
headers["Authorization"] = "Bearer #{model_config.api_key}"
|
|
||||||
|
|
||||||
body = { inputs: content, truncate: true }.to_json
|
JSON.parse(response.body, symbolize_names: true)
|
||||||
|
|
||||||
api_endpoint = model_config.endpoint
|
|
||||||
if api_endpoint.present? && api_endpoint.start_with?("srv://")
|
|
||||||
service = DiscourseAi::Utils::DnsSrv.lookup(api_endpoint.delete_prefix("srv://"))
|
|
||||||
api_endpoint = "https://#{service.target}:#{service.port}"
|
|
||||||
end
|
|
||||||
|
|
||||||
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
|
|
||||||
response = conn.post(api_endpoint, body, headers)
|
|
||||||
|
|
||||||
if response.status != 200
|
|
||||||
raise Net::HTTPBadResponse.new("Status: #{response.status}\n\n#{response.body}")
|
|
||||||
end
|
|
||||||
|
|
||||||
JSON.parse(response.body, symbolize_names: true)
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
|
||||||
def perform!(content)
|
def perform!(content)
|
||||||
|
response = do_request!(content)
|
||||||
|
|
||||||
|
JSON.parse(response.body, symbolize_names: true).first
|
||||||
|
end
|
||||||
|
|
||||||
|
private
|
||||||
|
|
||||||
|
def do_request!(content)
|
||||||
headers = { "Referer" => referer, "Content-Type" => "application/json" }
|
headers = { "Referer" => referer, "Content-Type" => "application/json" }
|
||||||
body = { inputs: content, truncate: true }.to_json
|
body = { inputs: content, truncate: true }.to_json
|
||||||
|
|
||||||
|
@ -89,7 +75,7 @@ module ::DiscourseAi
|
||||||
|
|
||||||
raise Net::HTTPBadResponse.new(response.body.to_s) if ![200].include?(response.status)
|
raise Net::HTTPBadResponse.new(response.body.to_s) if ![200].include?(response.status)
|
||||||
|
|
||||||
JSON.parse(response.body, symbolize_names: true).first
|
response
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -55,7 +55,6 @@ module DiscourseAi
|
||||||
|
|
||||||
available_classifiers = classifiers
|
available_classifiers = classifiers
|
||||||
return if available_classifiers.blank?
|
return if available_classifiers.blank?
|
||||||
base_url = Discourse.base_url
|
|
||||||
|
|
||||||
promised_classifications =
|
promised_classifications =
|
||||||
relation
|
relation
|
||||||
|
@ -70,12 +69,14 @@ module DiscourseAi
|
||||||
already_classified = w_text[:target].sentiment_classifications.map(&:model_used)
|
already_classified = w_text[:target].sentiment_classifications.map(&:model_used)
|
||||||
|
|
||||||
classifiers_for_target =
|
classifiers_for_target =
|
||||||
available_classifiers.reject { |ac| already_classified.include?(ac.model_name) }
|
available_classifiers.reject do |ac|
|
||||||
|
already_classified.include?(ac[:model_name])
|
||||||
|
end
|
||||||
|
|
||||||
promised_target_results =
|
promised_target_results =
|
||||||
classifiers_for_target.map do |c|
|
classifiers_for_target.map do |cft|
|
||||||
Concurrent::Promises.future_on(pool) do
|
Concurrent::Promises.future_on(pool) do
|
||||||
results[c.model_name] = request_with(w_text[:text], c, base_url)
|
results[cft[:model_name]] = request_with(cft[:client], w_text[:text])
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -98,18 +99,19 @@ module DiscourseAi
|
||||||
|
|
||||||
def classify!(target)
|
def classify!(target)
|
||||||
return if target.blank?
|
return if target.blank?
|
||||||
return if classifiers.blank?
|
available_classifiers = classifiers
|
||||||
|
return if available_classifiers.blank?
|
||||||
|
|
||||||
to_classify = prepare_text(target)
|
to_classify = prepare_text(target)
|
||||||
return if to_classify.blank?
|
return if to_classify.blank?
|
||||||
|
|
||||||
already_classified = target.sentiment_classifications.map(&:model_used)
|
already_classified = target.sentiment_classifications.map(&:model_used)
|
||||||
classifiers_for_target =
|
classifiers_for_target =
|
||||||
classifiers.reject { |ac| already_classified.include?(ac.model_name) }
|
available_classifiers.reject { |ac| already_classified.include?(ac[:model_name]) }
|
||||||
|
|
||||||
results =
|
results =
|
||||||
classifiers_for_target.reduce({}) do |memo, model|
|
classifiers_for_target.reduce({}) do |memo, cft|
|
||||||
memo[model.model_name] = request_with(to_classify, model)
|
memo[cft[:model_name]] = request_with(cft[:client], to_classify)
|
||||||
memo
|
memo
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -117,7 +119,20 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
def classifiers
|
def classifiers
|
||||||
DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values
|
DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values.map do |config|
|
||||||
|
api_endpoint = config.endpoint
|
||||||
|
|
||||||
|
if api_endpoint.present? && api_endpoint.start_with?("srv://")
|
||||||
|
service = DiscourseAi::Utils::DnsSrv.lookup(api_endpoint.delete_prefix("srv://"))
|
||||||
|
api_endpoint = "https://#{service.target}:#{service.port}"
|
||||||
|
end
|
||||||
|
|
||||||
|
{
|
||||||
|
model_name: config.model_name,
|
||||||
|
client:
|
||||||
|
DiscourseAi::Inference::HuggingFaceTextEmbeddings.new(api_endpoint, config.api_key),
|
||||||
|
}
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
def has_classifiers?
|
def has_classifiers?
|
||||||
|
@ -137,9 +152,9 @@ module DiscourseAi
|
||||||
Tokenizer::BertTokenizer.truncate(content, 512)
|
Tokenizer::BertTokenizer.truncate(content, 512)
|
||||||
end
|
end
|
||||||
|
|
||||||
def request_with(content, config, base_url = Discourse.base_url)
|
def request_with(client, content)
|
||||||
result =
|
result = client.classify_by_sentiment!(content)
|
||||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.classify(content, config, base_url)
|
|
||||||
transform_result(result)
|
transform_result(result)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue