# frozen_string_literal: true

module DiscourseAi
  module Completions
    module Dialects
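      # Base class for provider "dialects": a dialect translates the generic
      # DiscourseAi prompt into the message format a specific LLM provider
      # expects. A minimal subclass (hypothetical names, shown only to
      # illustrate the contract) looks roughly like:
      #
      #   class MyDialect < Dialect
      #     def self.can_translate?(llm_model)
      #       llm_model.provider == "my_provider"
      #     end
      #
      #     def max_prompt_tokens
      #       llm_model.max_prompt_tokens
      #     end
      #
      #     private
      #
      #     def system_msg(msg) = { role: "system", content: msg[:content] }
      #     def user_msg(msg) = { role: "user", content: msg[:content] }
      #     def model_msg(msg) = { role: "assistant", content: msg[:content] }
      #   end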
      class Dialect
        class << self
          def can_translate?(llm_model)
            raise NotImplementedError
          end
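
          # Registry of concrete dialects. dialect_for checks these in order,
          # so more specific dialects must appear before the generic
          # OpenAiCompatible fallback.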
          def all_dialects
            [
              DiscourseAi::Completions::Dialects::ChatGpt,
              DiscourseAi::Completions::Dialects::Gemini,
              DiscourseAi::Completions::Dialects::Claude,
              DiscourseAi::Completions::Dialects::Command,
              DiscourseAi::Completions::Dialects::Ollama,
              DiscourseAi::Completions::Dialects::Mistral,
              DiscourseAi::Completions::Dialects::Nova,
              DiscourseAi::Completions::Dialects::OpenAiCompatible,
            ]
          end
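
          # Picks the first dialect whose can_translate? accepts the model and
          # raises UNKNOWN_MODEL when none does. In test and development the
          # Fake dialect is consulted first so completions can be stubbed.
          # Typical usage:
          #
          #   dialect_klass = Dialect.dialect_for(llm_model)
          #   payload = dialect_klass.new(prompt, llm_model).translate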
          def dialect_for(llm_model)
            dialects = []

            if Rails.env.test? || Rails.env.development?
              dialects = [DiscourseAi::Completions::Dialects::Fake]
            end

            dialects = dialects.concat(all_dialects)

            dialect = dialects.find { |d| d.can_translate?(llm_model) }
            raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL if !dialect

            dialect
          end
        end

        def initialize(generic_prompt, llm_model, opts: {})
          @prompt = generic_prompt
          @opts = opts
          @llm_model = llm_model
        end

        VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
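
        # Providers with first-class tool/function-calling override this to
        # return true; when false, tool definitions and calls are emulated in
        # plain text via the XmlTools dialect below.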
        def native_tool_support?
          false
        end

        def vision_support?
          llm_model.vision_enabled?
        end

        def tools
          @tools ||= tools_dialect.translated_tools
        end

        def tool_choice
          prompt.tool_choice
        end

        def self.no_more_tool_calls_text
          # NOTE: an Anthropic prefill must never end with trailing whitespace.
          "I WILL NOT USE TOOLS IN THIS REPLY, user expressed they wanted to stop using tool calls.\nHere is the best, complete, answer I can come up with given the information I have."
        end

        def self.no_more_tool_calls_text_user
          "DO NOT USE TOOLS IN YOUR REPLY. Return the best answer you can given the information I supplied you."
        end

        def no_more_tool_calls_text
          self.class.no_more_tool_calls_text
        end

        def no_more_tool_calls_text_user
          self.class.no_more_tool_calls_text_user
        end
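
        # Converts the trimmed generic messages into the provider format via
        # the system_msg/user_msg/model_msg hooks. When tools are emulated and
        # the caller set tool_choice to :none, a "done" marker is injected
        # into the final tool result so the model stops issuing tool calls.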
        def translate
          messages = trim_messages(prompt.messages)
          last_message = messages.last
          inject_done_on_last_tool_call = false

          if !native_tool_support? && last_message && last_message[:type].to_sym == :tool &&
               prompt.tool_choice == :none
            inject_done_on_last_tool_call = true
          end

          translated =
            messages
              .map do |msg|
                case msg[:type].to_sym
                when :system
                  system_msg(msg)
                when :user
                  user_msg(msg)
                when :model
                  model_msg(msg)
                when :tool
                  if inject_done_on_last_tool_call && msg == last_message
                    tools_dialect.inject_done { tool_msg(msg) }
                  else
                    tool_msg(msg)
                  end
                when :tool_call
                  tool_call_msg(msg)
                else
                  raise ArgumentError, "Unknown message type: #{msg[:type]}"
                end
              end
              .compact

          translated
        end

        def conversation_context
          raise NotImplementedError
        end

        def max_prompt_tokens
          raise NotImplementedError
        end

        attr_reader :prompt

        private

        attr_reader :opts, :llm_model
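
        # Fits the conversation into max_prompt_tokens. A leading system
        # message is always kept, truncated to at most 60% of the budget; the
        # remaining messages are then walked newest-first, chopping oversized
        # content from the end in fixed-size steps and dropping whatever no
        # longer fits.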
        def trim_messages(messages)
          prompt_limit = max_prompt_tokens
          current_token_count = 0
          message_step_size = (prompt_limit / 25).to_i * -1

          trimmed_messages = []

          range = (0..-1)
          if messages.dig(0, :type) == :system
            max_system_tokens = prompt_limit * 0.6
            system_message = messages[0]
            system_size = calculate_message_token(system_message)

            if system_size > max_system_tokens
              # Truncate the system message to fit within its share of the budget.
              system_message[:content] = llm_model.tokenizer_class.truncate(
                system_message[:content],
                max_system_tokens,
              )
            end

            trimmed_messages << system_message
            current_token_count += calculate_message_token(system_message)
            range = (1..-1)
          end

          reversed_trimmed_msgs = []

          messages[range].reverse.each do |msg|
            break if current_token_count >= prompt_limit

            message_tokens = calculate_message_token(msg)

            dupped_msg = msg.dup

            # Don't trim tool call metadata.
            if msg[:type] == :tool_call
              break if current_token_count + message_tokens + per_message_overhead > prompt_limit

              current_token_count += message_tokens + per_message_overhead
              reversed_trimmed_msgs << dupped_msg
              next
            end

            # Trim content from the end until the message fits the token limit.
            while dupped_msg[:content].present? &&
                    message_tokens + current_token_count + per_message_overhead > prompt_limit
              dupped_msg[:content] = dupped_msg[:content][0..message_step_size] || ""
              message_tokens = calculate_message_token(dupped_msg)
            end

            next if dupped_msg[:content].blank?

            current_token_count += message_tokens + per_message_overhead

            reversed_trimmed_msgs << dupped_msg
          end

          # Drop an orphaned tool result whose matching tool call was trimmed away.
          reversed_trimmed_msgs.pop if reversed_trimmed_msgs.last&.dig(:type) == :tool

          trimmed_messages.concat(reversed_trimmed_msgs.reverse)
        end

        def per_message_overhead
          0
        end

        def calculate_message_token(msg)
          llm_model.tokenizer_class.size(msg[:content].to_s)
        end

        def tools_dialect
          @tools_dialect ||= DiscourseAi::Completions::Dialects::XmlTools.new(prompt.tools)
        end
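
        # Message builders every concrete dialect must implement.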
        def system_msg(msg)
          raise NotImplementedError
        end

        def model_msg(msg)
          raise NotImplementedError
        end

        def user_msg(msg)
          raise NotImplementedError
        end
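
        # With emulated tools, a tool call is serialized (via XmlTools) into a
        # model message and a tool result into a user message.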
        def tool_call_msg(msg)
          new_content = tools_dialect.from_raw_tool_call(msg)
          msg = msg.merge(content: new_content)
          model_msg(msg)
        end

        def tool_msg(msg)
          new_content = tools_dialect.from_raw_tool(msg)
          msg = msg.merge(content: new_content)
          user_msg(msg)
        end
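
        # Normalizes mixed message content (strings, upload references, other
        # parts) into an array of encoded chunks: consecutive strings are
        # coalesced through text_encoder, uploads go through image_encoder
        # when vision is allowed, and anything else through other_encoder when
        # one is given.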
        def to_encoded_content_array(
          content:,
          image_encoder:,
          text_encoder:,
          other_encoder: nil,
          allow_vision:
        )
          content = [content] if !content.is_a?(Array)

          current_string = +""
          result = []

          content.each do |c|
            if c.is_a?(String)
              current_string << c
            elsif c.is_a?(Hash) && c.key?(:upload_id) && allow_vision
              if !current_string.empty?
                result << text_encoder.call(current_string)
                current_string = +""
              end
              encoded = prompt.encode_upload(c[:upload_id])
              result << image_encoder.call(encoded) if encoded
            elsif other_encoder
              encoded = other_encoder.call(c)
              result << encoded if encoded
            end
          end

          result << text_encoder.call(current_string) if !current_string.empty?
          result
        end
      end
    end
  end
end