discourse-ai/lib/completions/prompt.rb

219 lines
6.7 KiB
Ruby

# frozen_string_literal: true
module DiscourseAi
  module Completions
    # An ordered list of conversation messages plus the tool definitions and
    # metadata (topic/post ids, image pixel budget, tool choice) that travel
    # with a completion request.
    #
    # Each message is a Hash with a :type (:system, :user, :model, :tool or
    # :tool_call), a :content (String, or Array of Strings and
    # `{ upload_id: ... }` Hashes) and optional :id, :name, :thinking,
    # :thinking_signature and :redacted_thinking_signature keys. Messages are
    # validated both individually and as a turn sequence.
    class Prompt
      # Raised when a message is not allowed to follow the previous one.
      INVALID_TURN = Class.new(StandardError)

      MESSAGE_TYPES = %i[system user model tool tool_call].freeze
      TURN_TYPES = %i[tool tool_call model user].freeze
      MESSAGE_KEYS = %i[
        type
        content
        id
        name
        thinking
        thinking_signature
        redacted_thinking_signature
      ].freeze
      private_constant :MESSAGE_TYPES, :TURN_TYPES, :MESSAGE_KEYS

      attr_reader :messages
      attr_accessor :tools, :topic_id, :post_id, :max_pixels, :tool_choice

      # @param system_message_text [String, nil] when given, becomes the first
      #   (:system) message
      # @param messages [Array<Hash>] initial messages, appended after the
      #   system message
      # @param tools [Array] tool definitions, stored as-is
      # @param max_pixels [Integer, nil] forwarded to UploadEncoder; falls back
      #   to 1_048_576 when not given
      # @raise [ArgumentError] if messages/tools are not arrays or a message is
      #   malformed
      # @raise [INVALID_TURN] if two consecutive messages form an invalid turn
      def initialize(
        system_message_text = nil,
        messages: [],
        tools: [],
        topic_id: nil,
        post_id: nil,
        max_pixels: nil,
        tool_choice: nil
      )
        raise ArgumentError, "messages must be an array" if !messages.is_a?(Array)
        raise ArgumentError, "tools must be an array" if !tools.is_a?(Array)

        @max_pixels = max_pixels || 1_048_576
        @topic_id = topic_id
        @post_id = post_id

        @messages = []
        @messages << { type: :system, content: system_message_text } if system_message_text
        @messages.concat(messages)

        @messages.each { |message| validate_message(message) }
        @messages.each_cons(2) { |last_turn, new_turn| validate_turn(last_turn, new_turn) }

        @tools = tools
        @tool_choice = tool_choice
      end

      # This API tries to create symmetry between responses and prompts:
      # anything we get back from the model via an endpoint can be appended
      # here directly.
      #
      # @param response [String, ToolCall, Thinking, Array] a single response
      #   item or an array of them
      # @raise [ArgumentError] if an item is not a String, ToolCall or Thinking
      def push_model_response(response)
        response = [response] if !response.is_a?(Array)

        thinking = nil
        thinking_signature = nil
        redacted_thinking_signature = nil

        response.each do |message|
          case message
          when Thinking
            # we can safely skip partials here
            next if message.partial?

            if message.redacted
              redacted_thinking_signature = message.signature
            else
              thinking = message.message
              thinking_signature = message.signature
            end
          when ToolCall
            next if message.partial?

            # this is a bit surprising about the API:
            # needing to wrap the parameters in "arguments" is not ideal
            push(
              type: :tool_call,
              content: { arguments: message.parameters }.to_json,
              id: message.id,
              name: message.name,
            )
          when String
            push(type: :model, content: message)
          else
            raise ArgumentError, "response must be an array of strings, ToolCalls, or Thinkings"
          end
        end

        # Anthropic rules are that thinking is attached to the last message of
        # the response. It is odd; long term thinking may be better kept as a
        # separate object.
        # NOTE(review): this writes to whatever message is currently last, so
        # it assumes the prompt holds at least one message at this point.
        if thinking || redacted_thinking_signature
          messages.last[:thinking] = thinking
          messages.last[:thinking_signature] = thinking_signature
          messages.last[:redacted_thinking_signature] = redacted_thinking_signature
        end
      end

      # Validates and appends a single message. Requests to push a :system
      # message are ignored (the system message can only be set through the
      # constructor).
      #
      # @raise [ArgumentError] if the message is malformed
      # @raise [INVALID_TURN] if the message may not follow the current last one
      def push(
        type:,
        content:,
        id: nil,
        name: nil,
        thinking: nil,
        thinking_signature: nil,
        redacted_thinking_signature: nil
      )
        return if type == :system

        new_message = { type: type, content: content }
        new_message[:name] = name.to_s if name
        new_message[:id] = id.to_s if id
        new_message[:thinking] = thinking if thinking
        new_message[:thinking_signature] = thinking_signature if thinking_signature
        if redacted_thinking_signature
          new_message[:redacted_thinking_signature] = redacted_thinking_signature
        end

        validate_message(new_message)
        # messages.last is nil for the very first message; validate_turn copes with that
        validate_turn(messages.last, new_message)

        messages << new_message
      end

      # Relies on `present?` being available on the tools object
      # (presumably ActiveSupport's Object#present?).
      def has_tools?
        tools.present?
      end

      # Collects every upload_id found in an Array-shaped message content and
      # hands them to UploadEncoder.encode. Returns [] when the content is not
      # an Array or references no uploads.
      def encoded_uploads(message)
        content = message[:content]
        return [] if !content.is_a?(Array)

        upload_ids =
          content
            .map { |part| part[:upload_id] if part.is_a?(Hash) && part.key?(:upload_id) }
            .compact
        return [] if upload_ids.empty?

        UploadEncoder.encode(upload_ids: upload_ids, max_pixels: max_pixels)
      end

      # Returns the message content with upload references removed: the String
      # parts of an Array content joined together, or the content itself when
      # it is not an Array.
      def text_only(message)
        content = message[:content]
        return content if !content.is_a?(Array)

        content.grep(String).join
      end

      # Encodes a single upload; returns the first element of what
      # UploadEncoder.encode produces for it.
      def encode_upload(upload_id)
        UploadEncoder.encode(upload_ids: [upload_id], max_pixels: max_pixels).first
      end

      # Always returns an Array: non-Array content is wrapped, and within an
      # Array every `{ upload_id: ... }` Hash is replaced by its encoded upload
      # while other elements are passed through untouched.
      def content_with_encoded_uploads(content)
        return [content] unless content.is_a?(Array)

        content.map do |part|
          part.is_a?(Hash) && part.key?(:upload_id) ? encode_upload(part[:upload_id]) : part
        end
      end

      # Two prompts are equal when all of their attributes are equal.
      def ==(other)
        return false unless other.is_a?(Prompt)

        messages == other.messages && tools == other.tools && topic_id == other.topic_id &&
          post_id == other.post_id && max_pixels == other.max_pixels &&
          tool_choice == other.tool_choice
      end

      def eql?(other)
        self == other
      end

      # Kept consistent with #== so prompts can be used as Hash keys.
      def hash
        [messages, tools, topic_id, post_id, max_pixels, tool_choice].hash
      end

      private

      # Checks the shape of a single message Hash.
      #
      # @raise [ArgumentError] on an unknown type, unknown keys, or content
      #   that is neither a String nor an Array of Strings / `{ upload_id: }`
      def validate_message(message)
        if !MESSAGE_TYPES.include?(message[:type])
          raise ArgumentError, "message type must be one of #{MESSAGE_TYPES}"
        end

        if (invalid_keys = message.keys - MESSAGE_KEYS).any?
          raise ArgumentError, "message contains invalid keys: #{invalid_keys}"
        end

        content = message[:content]
        if content.is_a?(Array)
          content.each do |part|
            next if part.is_a?(String)
            next if part.is_a?(Hash) && part.keys == [:upload_id]

            raise ArgumentError, "Array message content must be a string or {upload_id: ...} "
          end
        elsif !content.is_a?(String)
          raise ArgumentError, "Message content must be a string or an array"
        end
      end

      # Checks that new_turn may follow last_turn.
      #
      # last_turn is nil when new_turn is the first message of the prompt; in
      # that case only the rules that do not depend on a predecessor apply
      # (so a :tool result, which needs a preceding :tool_call, is rejected).
      #
      # @raise [INVALID_TURN] when the sequence is not allowed
      def validate_turn(last_turn, new_turn)
        new_type = new_turn[:type]
        last_type = last_turn && last_turn[:type]

        raise INVALID_TURN if !TURN_TYPES.include?(new_type)

        # only a user message may directly follow the system message
        raise INVALID_TURN if last_type == :system && new_type != :user

        raise INVALID_TURN if new_type == :tool && last_type != :tool_call
        raise INVALID_TURN if new_type == :model && last_type == :model
      end
    end
  end
end