# discourse-ai/lib/completions/endpoints/base.rb

# frozen_string_literal: true
module DiscourseAi
  module Completions
    module Endpoints
      # Abstract base for LLM completion endpoints.
      #
      # Provider-specific subclasses implement the hooks declared below (model_uri,
      # prepare_payload, prepare_request, decode, decode_chunk, normalize_model_params,
      # xml_tools_enabled?, default_options, provider_id, ...). This class performs the
      # HTTP request, handles streaming and non-streaming responses, applies optional
      # XML tool-call processing / XML tag stripping, and records an AiApiAuditLog entry
      # for every request that got a 200 response.
      class Base
        attr_reader :partial_tool_calls, :output_thinking

        # Raised when the provider responds with a non-200 status.
        CompletionFailed = Class.new(StandardError)

        # 6 minutes.
        # Reasoning LLMs can take a very long time to respond; generally it will be under 5 minutes.
        # The alternative is per-LLM timeouts, but that would make configuration extra confusing
        # for people. Let's try this simple solution first.
        TIMEOUT = 360

        class << self
          # Returns the endpoint class that can contact +provider_name+.
          # Ollama is only offered in development; Fake only in test and development.
          # Raises DiscourseAi::Completions::Llm::UNKNOWN_MODEL when no endpoint matches.
          def endpoint_for(provider_name)
            endpoints = [
              DiscourseAi::Completions::Endpoints::AwsBedrock,
              DiscourseAi::Completions::Endpoints::OpenAi,
              DiscourseAi::Completions::Endpoints::HuggingFace,
              DiscourseAi::Completions::Endpoints::Gemini,
              DiscourseAi::Completions::Endpoints::Vllm,
              DiscourseAi::Completions::Endpoints::Anthropic,
              DiscourseAi::Completions::Endpoints::Cohere,
              DiscourseAi::Completions::Endpoints::SambaNova,
              DiscourseAi::Completions::Endpoints::Mistral,
              DiscourseAi::Completions::Endpoints::OpenRouter,
            ]

            endpoints << DiscourseAi::Completions::Endpoints::Ollama if Rails.env.development?
            if Rails.env.test? || Rails.env.development?
              endpoints << DiscourseAi::Completions::Endpoints::Fake
            end

            endpoints.find(-> { raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL }) do |ek|
              ek.can_contact?(provider_name)
            end
          end

          # Whether this endpoint class can serve the given provider; implemented by subclasses.
          def can_contact?(_model_provider)
            raise NotImplementedError
          end
        end

        # @param llm_model the model record this endpoint sends requests for
        def initialize(llm_model)
          @llm_model = llm_model
        end

        # TLS is used unless the model URI explicitly carries a non-"https" scheme.
        def use_ssl?
          if model_uri&.scheme.present?
            model_uri.scheme == "https"
          else
            true
          end
        end

        # XML tag names whose contents are stripped from text output.
        # None by default; subclasses may override.
        def xml_tags_to_strip(dialect)
          []
        end

        # Sends the prompt translated by +dialect+ to the provider.
        #
        # Without a block, the whole response is returned: a single partial, an array of
        # partials, or a StructuredOutput when +model_params[:response_format]+ carries a
        # JSON schema with properties.
        #
        # With a block, partials are delivered as +blk.call(partial, cancel)+ while they
        # arrive, and the accumulated text is returned. Calling +cancel+ stops delivery
        # (and closes the HTTP connection when actually streaming). If the model has the
        # "disable_streaming" custom param, a non-streaming request is made and its
        # partials are replayed into the block.
        #
        # @raise [CompletionFailed] when the provider responds with a non-200 status
        def perform_completion!(
          dialect,
          user,
          model_params = {},
          feature_name: nil,
          feature_context: nil,
          partial_tool_calls: false,
          output_thinking: false,
          &blk
        )
          # Decided before quota checks and param normalization: the recursive
          # non-streaming call performs both itself, so doing them here as well would
          # check quotas twice and normalize already-normalized params.
          if block_given? && disable_streaming?
            result =
              perform_completion!(
                dialect,
                user,
                model_params,
                feature_name: feature_name,
                feature_context: feature_context,
                partial_tool_calls: partial_tool_calls,
                output_thinking: output_thinking,
              )

            cancelled_by_caller = false
            cancel_proc = -> { cancelled_by_caller = true }
            replayed = result.is_a?(Array) ? result : [result]
            replayed.each do |partial|
              blk.call(partial, cancel_proc)
              break if cancelled_by_caller
            end
            return result
          end

          LlmQuota.check_quotas!(@llm_model, user)
          start_time = Time.now

          @forced_json_through_prefill = false
          @partial_tool_calls = partial_tool_calls
          @output_thinking = output_thinking

          model_params = normalize_model_params(model_params)
          orig_blk = blk

          @streaming_mode = block_given?

          prompt = dialect.translate
          structured_output = build_structured_output(model_params)

          FinalDestination::HTTP.start(
            model_uri.host,
            model_uri.port,
            use_ssl: use_ssl?,
            read_timeout: TIMEOUT,
            open_timeout: TIMEOUT,
            write_timeout: TIMEOUT,
          ) do |http|
            response_data = +""
            response_raw = +""

            # Needed for response token calculations. Cannot rely on response_data due to function buffering.
            partials_raw = +""
            request_body = prepare_payload(prompt, model_params, dialect).to_json
            request = prepare_request(request_body)

            # Some providers rely on prefill to return structured outputs, so the start
            # of the JSON won't be included in the response. Supply it to keep JSON valid.
            structured_output << +"{" if structured_output && @forced_json_through_prefill

            http.request(request) do |response|
              if response.code.to_i != 200
                Rails.logger.error(
                  "#{self.class.name}: status: #{response.code.to_i} - body: #{response.body}",
                )
                raise CompletionFailed, response.body
              end

              xml_tool_processor =
                XmlToolProcessor.new(
                  partial_tool_calls: partial_tool_calls,
                ) if xml_tools_enabled? && dialect.prompt.has_tools?

              to_strip = xml_tags_to_strip(dialect)
              xml_stripper =
                DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present?

              if @streaming_mode
                # Wrap the caller's block so text partials are tag-stripped and, when a
                # schema was requested, accumulated into the StructuredOutput first.
                blk =
                  lambda do |partial, cancel|
                    if partial.is_a?(String)
                      partial = xml_stripper << partial if xml_stripper
                      # The stripper may return nil (nothing to emit yet); never feed
                      # that into the structured output.
                      if partial && structured_output.present?
                        structured_output << partial
                        partial = structured_output
                      end
                    end
                    orig_blk.call(partial, cancel) if partial
                  end
              end

              log =
                start_log(
                  provider_id: provider_id,
                  request_body: request_body,
                  dialect: dialect,
                  prompt: prompt,
                  user: user,
                  feature_name: feature_name,
                  feature_context: feature_context,
                )

              if !@streaming_mode
                response_data =
                  non_streaming_response(
                    response: response,
                    xml_tool_processor: xml_tool_processor,
                    xml_stripper: xml_stripper,
                    partials_raw: partials_raw,
                    response_raw: response_raw,
                    structured_output: structured_output,
                  )
                return response_data
              end

              begin
                cancelled = false
                cancel = -> do
                  cancelled = true
                  http.finish
                end

                response.read_body do |chunk|
                  break if cancelled

                  response_raw << chunk

                  decode_chunk(chunk).each do |partial|
                    break if cancelled
                    partials_raw << partial.to_s
                    response_data << partial if partial.is_a?(String)
                    partials = [partial]
                    if xml_tool_processor && partial.is_a?(String)
                      partials = (xml_tool_processor << partial)
                      if xml_tool_processor.should_cancel?
                        cancel.call
                        break
                      end
                    end
                    partials.each { |inner_partial| blk.call(inner_partial, cancel) }
                  end
                end
              rescue StandardError
                # Closing the connection via +cancel+ makes read_body raise (IOError is a
                # StandardError); that is expected. Anything else is re-raised.
                raise if !cancelled
              end

              if xml_stripper
                stripped = xml_stripper.finish
                if stripped.present?
                  response_data << stripped
                  # NOTE(review): without an xml_tool_processor this leftover text is
                  # recorded in response_data but never delivered to the caller's block —
                  # confirm whether that is intended.
                  result = []
                  result = (xml_tool_processor << stripped) if xml_tool_processor
                  result.each { |partial| blk.call(partial, cancel) }
                end
              end

              if xml_tool_processor
                xml_tool_processor.finish.each { |partial| blk.call(partial, cancel) }
              end

              decode_chunk_finish.each { |partial| blk.call(partial, cancel) }
              return response_data
            ensure
              if log
                save_log!(
                  log,
                  user: user,
                  start_time: start_time,
                  response_raw: response_raw,
                  partials_raw: partials_raw,
                )
                write_llm_audit_logs(
                  log,
                  request: request,
                  response_data: response_data,
                  start_time: start_time,
                )
              end
            end
          end
        end

        # Hook for subclasses to adjust the audit log before it is saved.
        def final_log_update(log)
          # for people that need to override
        end

        def default_options
          raise NotImplementedError
        end

        def provider_id
          raise NotImplementedError
        end

        # Token count of the prompt text, measured with the model's tokenizer.
        def prompt_size(prompt)
          tokenizer.size(extract_prompt_for_tokenizer(prompt))
        end

        attr_reader :llm_model

        protected

        def tokenizer
          llm_model.tokenizer_class
        end

        # should normalize temperature, max_tokens, stop_words to endpoint specific values
        def normalize_model_params(model_params)
          raise NotImplementedError
        end

        def model_uri
          raise NotImplementedError
        end

        # Called with (prompt, model_params, dialect); the dialect is optional so that
        # a subclass which forgets to override still gets NotImplementedError.
        def prepare_payload(_prompt, _model_params, _dialect = nil)
          raise NotImplementedError
        end

        def prepare_request(_payload)
          raise NotImplementedError
        end

        def decode(_response_raw)
          raise NotImplementedError
        end

        # Partials to emit once a streamed response has been fully read; none by default.
        def decode_chunk_finish
          []
        end

        def decode_chunk(_chunk)
          raise NotImplementedError
        end

        # Joins each message's content (symbol or string key) with newlines.
        def extract_prompt_for_tokenizer(prompt)
          prompt.map { |message| message[:content] || message["content"] || "" }.join("\n")
        end

        def xml_tools_enabled?
          raise NotImplementedError
        end

        def disable_streaming?
          @disable_streaming = !!llm_model.lookup_custom_param("disable_streaming")
        end

        private

        # Pretty-prints +payload+ when it is valid JSON, otherwise returns it unchanged.
        def format_possible_json_payload(payload)
          JSON.pretty_generate(JSON.parse(payload))
        rescue JSON::ParserError
          payload
        end

        # Returns a StructuredOutput for the requested JSON schema's properties, or nil
        # when no response_format with schema properties was requested.
        def build_structured_output(model_params)
          return if model_params[:response_format].blank?

          schema_properties = model_params[:response_format].dig(:json_schema, :schema, :properties)
          return if schema_properties.blank?

          DiscourseAi::Completions::StructuredOutput.new(schema_properties)
        end

        def start_log(
          provider_id:,
          request_body:,
          dialect:,
          prompt:,
          user:,
          feature_name:,
          feature_context:
        )
          AiApiAuditLog.new(
            provider_id: provider_id,
            user_id: user&.id,
            raw_request_payload: request_body,
            request_tokens: prompt_size(prompt),
            topic_id: dialect.prompt.topic_id,
            post_id: dialect.prompt.post_id,
            feature_name: feature_name,
            language_model: llm_model.name,
            feature_context: feature_context.present? ? feature_context.as_json : nil,
          )
        end

        # Fills in the response side of +log+, saves it and records quota usage.
        def save_log!(log, user:, start_time:, response_raw:, partials_raw:)
          log.raw_response_payload = response_raw
          final_log_update(log)
          log.response_tokens = tokenizer.size(partials_raw) if log.response_tokens.blank?
          log.created_at = start_time
          log.updated_at = Time.now
          log.duration_msecs = (Time.now - start_time) * 1000
          log.save!
          LlmQuota.log_usage(@llm_model, user, log.request_tokens, log.response_tokens)

          if Rails.env.development? && !ENV["DISCOURSE_AI_NO_DEBUG"]
            puts "#{self.class.name}: request_tokens #{log.request_tokens} response_tokens #{log.response_tokens}"
            puts "raw response: #{response_raw}"
          end
        end

        # Writes the call to the optional per-thread audit loggers, if any are installed.
        def write_llm_audit_logs(log, request:, response_data:, start_time:)
          if (logger = Thread.current[:llm_audit_log])
            call_data = <<~LOG
              #{self.class.name}: request_tokens #{log.request_tokens} response_tokens #{log.response_tokens}
              request:
              #{format_possible_json_payload(log.raw_request_payload)}
              response:
              #{response_data}
            LOG
            logger.info(call_data)
          end

          if (structured_logger = Thread.current[:llm_audit_structured_log])
            llm_request =
              begin
                JSON.parse(log.raw_request_payload)
              rescue StandardError
                log.raw_request_payload
              end

            # gemini puts passwords in query params
            # we don't want to log that
            structured_logger.log(
              "llm_call",
              args: {
                class: self.class.name,
                completion_url: request.uri.to_s.split("?")[0],
                request: llm_request,
                result: response_data,
                request_tokens: log.request_tokens,
                response_tokens: log.response_tokens,
                duration: log.duration_msecs,
                stream: @streaming_mode,
              },
              start_time: start_time.utc,
              end_time: Time.now.utc,
            )
          end
        end

        # Reads and decodes a complete (non-streamed) response body.
        # Returns a StructuredOutput when one is supplied, a single partial when exactly
        # one remains, otherwise the array of non-blank partials.
        def non_streaming_response(
          response:,
          xml_tool_processor:,
          xml_stripper:,
          partials_raw:,
          response_raw:,
          structured_output:
        )
          response_raw << response.read_body
          response_data = decode(response_raw)
          response_data.each { |partial| partials_raw << partial.to_s }

          if xml_tool_processor
            # Feed every text partial through the processor and flush it once at the end,
            # keeping all outputs. Non-text partials bypass the processor, as when streaming.
            processed = []
            response_data.each do |partial|
              if partial.is_a?(String)
                processed.concat(xml_tool_processor << partial)
              else
                processed << partial
              end
            end
            processed.concat(xml_tool_processor.finish)
            response_data = processed.flatten.compact
          end

          if xml_stripper
            # Use the stripper's output even when it is empty/nil (everything was inside a
            # stripped tag, or is still buffered) instead of falling back to the raw
            # partial; whatever it still holds is appended from #finish.
            response_data.map! do |partial|
              partial.is_a?(String) ? xml_stripper << partial : partial
            end
            response_data << xml_stripper.finish
          end

          response_data.reject!(&:blank?)

          if structured_output.present?
            response_data.each { |data| structured_output << data if data.is_a?(String) }
            return structured_output
          end

          # this is to keep stuff backwards compatible
          response_data.length == 1 ? response_data.first : response_data
        end
      end
    end
  end
end