# frozen_string_literal: true
module DiscourseAi
module Personas
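# Executes the JavaScript of a custom AI tool (AiTool) inside a sandboxed
# MiniRacer context. The sandbox is capped by MAX_MEMORY and a millisecond
# timeout, and the script is given a small bridged API (http, llm, index,
# upload, chain, discourse) that proxies back into Discourse through the
# attach_* methods below.
#
# Illustrative usage (a sketch; the llm, bot_user and ai_tool objects are
# placeholders for whatever the caller already has):
#
#   runner =
#     DiscourseAi::Personas::ToolRunner.new(
#       parameters: { "query" => "hello" },
#       llm: llm,
#       bot_user: bot_user,
#       tool: ai_tool,
#     )
#   runner.invoke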
class ToolRunner
attr_reader :tool, :parameters, :llm
attr_accessor :running_attached_function, :timeout, :custom_raw
TooManyRequestsError = Class.new(StandardError)
# approximate milliseconds of JavaScript execution allowed per eval (see eval_with_timeout)
DEFAULT_TIMEOUT = 2000
# maximum MiniRacer heap size, in bytes
MAX_MEMORY = 10_000_000
# recursion depth limit when marshaling values between Ruby and JavaScript
MARSHAL_STACK_DEPTH = 20
# HTTP requests allowed per tool invocation
MAX_HTTP_REQUESTS = 20
def initialize(parameters:, llm:, bot_user:, context: nil, tool:, timeout: nil)
if context && !context.is_a?(DiscourseAi::Personas::BotContext)
raise ArgumentError, "context must be a BotContext object"
end
context ||= DiscourseAi::Personas::BotContext.new
@parameters = parameters
@llm = llm
@bot_user = bot_user
@context = context
@tool = tool
@timeout = timeout || DEFAULT_TIMEOUT
@running_attached_function = false
@http_requests_made = 0
end
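# Lazily builds the sandboxed MiniRacer context, attaches every Ruby-backed
# bridge function and evaluates the JavaScript framework the tool script runs
# against.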
def mini_racer_context
@mini_racer_context ||=
begin
ctx =
MiniRacer::Context.new(
max_memory: MAX_MEMORY,
marshal_stack_depth: MARSHAL_STACK_DEPTH,
)
attach_truncate(ctx)
attach_http(ctx)
attach_index(ctx)
attach_upload(ctx)
attach_chain(ctx)
attach_discourse(ctx)
ctx.eval(framework_script)
ctx
end
end
def framework_script
http_methods = %i[get post put patch delete].map { |method| <<~JS }.join("\n")
#{method}: function(url, options) {
return _http_#{method}(url, options);
},
JS
<<~JS
const http = {
#{http_methods}
};
const llm = {
truncate: _llm_truncate,
generate: _llm_generate,
};
const index = {
search: _index_search,
}
const upload = {
create: _upload_create,
}
const chain = {
setCustomRaw: _chain_set_custom_raw,
};
const discourse = {
search: function(params) {
return _discourse_search(params);
},
updatePersona: function(persona_id_or_name, updates) {
const result = _discourse_update_persona(persona_id_or_name, updates);
if (result.error) {
throw new Error(result.error);
}
return result;
},
getPost: _discourse_get_post,
getTopic: _discourse_get_topic,
getUser: _discourse_get_user,
getPersona: function(name) {
const personaDetails = _discourse_get_persona(name);
if (personaDetails.error) {
throw new Error(personaDetails.error);
}
// merge the persona attributes with update/respondTo helpers
return Object.assign({
update: function(updates) {
const result = _discourse_update_persona(name, updates);
if (result.error) {
throw new Error(result.error);
}
return result;
},
respondTo: function(params) {
const result = _discourse_respond_to_persona(name, params);
if (result.error) {
throw new Error(result.error);
}
return result;
}
}, personaDetails.persona);
},
createChatMessage: function(params) {
const result = _discourse_create_chat_message(params);
if (result.error) {
throw new Error(result.error);
}
return result;
},
};
const context = #{JSON.generate(@context.to_json)};
function details() { return ""; };
JS
end
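# A tool script evaluated on top of the framework above is expected to define
# invoke(parameters) and may override details() or define customContext().
# Illustrative sketch of such a script (the URL and parameter names are made up):
#
#   function invoke(params) {
#     const result = http.get("https://example.com/api?q=" + params.query);
#     chain.setCustomRaw("Fetched with status " + result.status);
#     return llm.truncate(result.body, 2000);
#   }
#
#   function details() { return "Fetched data from example.com"; }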
def details
eval_with_timeout("details()")
end
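# Evaluates a script under a watchdog thread: the thread ticks roughly every
# millisecond, pauses the clock while an attached Ruby function is running
# (running_attached_function) and calls MiniRacer::Context#stop once the
# timeout is exceeded.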
def eval_with_timeout(script, timeout: nil)
timeout ||= @timeout
mutex = Mutex.new
done = false
elapsed = 0
t =
Thread.new do
begin
while !done
# this tick is not precise, but it is close enough for a timeout
sleep(0.001)
elapsed += 1 if !self.running_attached_function
if elapsed > timeout
mutex.synchronize { mini_racer_context.stop unless done }
break
end
end
rescue => e
STDERR.puts e
STDERR.puts "FAILED TO TERMINATE DUE TO TIMEOUT"
end
end
rval = mini_racer_context.eval(script)
mutex.synchronize { done = true }
# ensure we do not leak the timeout thread
t.join
t = nil
rval
ensure
# always join the timeout thread, even when eval raised
t&.join
end
def invoke
mini_racer_context.eval(tool.script)
eval_with_timeout("invoke(#{JSON.generate(parameters)})")
rescue MiniRacer::ScriptTerminatedError
{ error: "Script terminated due to timeout" }
end
def has_custom_context?
mini_racer_context.eval(tool.script)
mini_racer_context.eval("typeof customContext === 'function'")
rescue StandardError
false
end
def custom_context
mini_racer_context.eval(tool.script)
mini_racer_context.eval("customContext()")
rescue StandardError
nil
end
private
MAX_FRAGMENTS = 200
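# Embedding-based similarity search over the RAG fragments uploaded for this
# tool, optionally restricted to specific original filenames and capped at
# MAX_FRAGMENTS results.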
def rag_search(query, filenames: nil, limit: 10)
limit = limit.to_i
return [] if limit < 1
limit = [MAX_FRAGMENTS, limit].min
upload_refs =
UploadReference.where(target_id: tool.id, target_type: "AiTool").pluck(:upload_id)
if filenames
upload_refs = Upload.where(id: upload_refs).where(original_filename: filenames).pluck(:id)
end
return [] if upload_refs.empty?
query_vector = DiscourseAi::Embeddings::Vector.instance.vector_from(query)
fragment_ids =
DiscourseAi::Embeddings::Schema
.for(RagDocumentFragment)
.asymmetric_similarity_search(query_vector, limit: limit, offset: 0) do |builder|
builder.join(<<~SQL, target_id: tool.id, target_type: "AiTool")
rag_document_fragments ON
rag_document_fragments.id = rag_document_fragment_id AND
rag_document_fragments.target_id = :target_id AND
rag_document_fragments.target_type = :target_type
SQL
end
.map(&:rag_document_fragment_id)
fragments =
RagDocumentFragment.where(id: fragment_ids, upload_id: upload_refs).pluck(
:id,
:fragment,
:metadata,
)
mapped = {}
fragments.each do |id, fragment, metadata|
mapped[id] = { fragment: fragment, metadata: metadata }
end
fragment_ids.take(limit).map { |fragment_id| mapped[fragment_id] }
end
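# Attaches the llm.* bridge: llm.truncate uses the model's tokenizer, while
# llm.generate runs a completion as the current post's author (or the bot user)
# under the feature name "custom_tool_<tool name>".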
def attach_truncate(mini_racer_context)
mini_racer_context.attach(
"_llm_truncate",
->(text, length) { @llm.tokenizer.truncate(text, length) },
)
mini_racer_context.attach(
"_llm_generate",
->(prompt) do
in_attached_function do
@llm.generate(
convert_js_prompt_to_ruby(prompt),
user: llm_user,
feature_name: "custom_tool_#{tool.name}",
)
end
end,
)
end
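# Accepts either a plain string prompt or a hash shaped like
# { "messages" => [{ "type" => "user", "content" => "..." }, ...] } coming from
# JavaScript and converts it into a DiscourseAi::Completions::Prompt.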
def convert_js_prompt_to_ruby(prompt)
if prompt.is_a?(String)
prompt
elsif prompt.is_a?(Hash)
messages = prompt["messages"]
if messages.blank? || !messages.is_a?(Array)
raise Discourse::InvalidParameters.new("Prompt must have messages")
end
messages.each(&:symbolize_keys!)
messages.each { |message| message[:type] = message[:type].to_sym }
DiscourseAi::Completions::Prompt.new(messages: prompt["messages"])
else
raise Discourse::InvalidParameters.new("Prompt must be a string or a hash")
end
end
def llm_user
@llm_user ||=
begin
post&.user || @bot_user
end
end
def post
return @post if defined?(@post)
post_id = @context.post_id
@post = post_id && Post.find_by(id: post_id)
end
def attach_index(mini_racer_context)
mini_racer_context.attach(
"_index_search",
->(*params) do
in_attached_function do
query, options = params
options ||= {}
options = options.symbolize_keys
self.rag_search(query, **options)
end
end,
)
end
def attach_chain(mini_racer_context)
mini_racer_context.attach("_chain_set_custom_raw", ->(raw) { self.custom_raw = raw })
end
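# Attaches the discourse.* bridge: post, topic and user lookups, search,
# reading and updating personas, responding as another persona and creating
# chat messages. Lookups are serialized to plain JSON-safe structures via
# recursive_as_json and run with a system-user guardian.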
def attach_discourse(mini_racer_context)
mini_racer_context.attach(
"_discourse_get_post",
->(post_id) do
in_attached_function do
post = Post.find_by(id: post_id)
return nil if post.nil?
guardian = Guardian.new(Discourse.system_user)
obj =
recursive_as_json(
PostSerializer.new(post, scope: guardian, root: false, add_raw: true),
)
topic_obj =
recursive_as_json(
ListableTopicSerializer.new(post.topic, scope: guardian, root: false),
)
obj["topic"] = topic_obj
obj
end
end,
)
mini_racer_context.attach(
"_discourse_get_topic",
->(topic_id) do
in_attached_function do
topic = Topic.find_by(id: topic_id)
return nil if topic.nil?
guardian = Guardian.new(Discourse.system_user)
recursive_as_json(ListableTopicSerializer.new(topic, scope: guardian, root: false))
end
end,
)
mini_racer_context.attach(
"_discourse_get_user",
->(user_id_or_username) do
in_attached_function do
user = nil
if user_id_or_username.is_a?(Integer) ||
user_id_or_username.to_i.to_s == user_id_or_username
user = User.find_by(id: user_id_or_username.to_i)
else
user = User.find_by(username: user_id_or_username)
end
return nil if user.nil?
guardian = Guardian.new(Discourse.system_user)
recursive_as_json(UserSerializer.new(user, scope: guardian, root: false))
end
end,
)
mini_racer_context.attach(
"_discourse_respond_to_persona",
->(persona_name, params) do
in_attached_function do
# if we have 1000s of personas this can be slow ... we may need to optimize
persona_class = AiPersona.all_personas.find { |persona| persona.name == persona_name }
return { error: "Persona not found" } if persona_class.nil?
persona = persona_class.new
bot = DiscourseAi::Personas::Bot.as(@bot_user || persona.user, persona: persona)
playground = DiscourseAi::AiBot::Playground.new(bot)
if @context.post_id
post = Post.find_by(id: @context.post_id)
return { error: "Post not found" } if post.nil?
reply_post =
playground.reply_to(
post,
custom_instructions: params["instructions"],
whisper: params["whisper"],
)
if reply_post
return(
{ success: true, post_id: reply_post.id, post_number: reply_post.post_number }
)
else
return { error: "Failed to create reply" }
end
elsif @context.message_id && @context.channel_id
message = Chat::Message.find_by(id: @context.message_id)
channel = Chat::Channel.find_by(id: @context.channel_id)
return { error: "Message or channel not found" } if message.nil? || channel.nil?
reply =
playground.reply_to_chat_message(message, channel, @context.context_post_ids)
if reply
return { success: true, message_id: reply.id }
else
return { error: "Failed to create chat reply" }
end
else
return { error: "No valid context for response" }
end
end
end,
)
mini_racer_context.attach(
"_discourse_create_chat_message",
->(params) do
in_attached_function do
params = params.symbolize_keys
channel_name = params[:channel_name]
username = params[:username]
message = params[:message]
# Validate parameters
return { error: "Missing required parameter: channel_name" } if channel_name.blank?
return { error: "Missing required parameter: username" } if username.blank?
return { error: "Missing required parameter: message" } if message.blank?
# Find the user
user = User.find_by(username: username)
return { error: "User not found: #{username}" } if user.nil?
# Find the channel
channel = Chat::Channel.find_by(name: channel_name)
if channel.nil?
# Try finding by slug if not found by name
channel = Chat::Channel.find_by(slug: channel_name.parameterize)
end
return { error: "Channel not found: #{channel_name}" } if channel.nil?
begin
guardian = Guardian.new(user)
message =
ChatSDK::Message.create(
raw: message,
channel_id: channel.id,
guardian: guardian,
enforce_membership: !channel.direct_message_channel?,
)
{
success: true,
message_id: message.id,
message: message.message,
created_at: message.created_at.iso8601,
}
rescue => e
{ error: "Failed to create chat message: #{e.message}" }
end
end
end,
)
mini_racer_context.attach(
"_discourse_search",
->(params) do
in_attached_function do
search_params = params.symbolize_keys
if search_params.delete(:with_private)
search_params[:current_user] = Discourse.system_user
end
search_params[:result_style] = :detailed
results = DiscourseAi::Utils::Search.perform_search(**search_params)
recursive_as_json(results)
end
end,
)
mini_racer_context.attach(
"_discourse_get_persona",
->(persona_name) do
in_attached_function do
persona = AiPersona.find_by(name: persona_name)
return { error: "Persona not found" } if persona.nil?
# Return a subset of relevant persona attributes
{
persona:
persona.attributes.slice(
"id",
"name",
"description",
"enabled",
"system_prompt",
"temperature",
"top_p",
"vision_enabled",
"tools",
"max_context_posts",
"allow_chat_channel_mentions",
"allow_chat_direct_messages",
"allow_topic_mentions",
"allow_personal_messages",
),
}
end
end,
)
mini_racer_context.attach(
"_discourse_update_persona",
->(persona_id_or_name, updates) do
in_attached_function do
# Find persona by ID or name
persona = nil
if persona_id_or_name.is_a?(Integer) ||
persona_id_or_name.to_i.to_s == persona_id_or_name
persona = AiPersona.find_by(id: persona_id_or_name.to_i)
else
persona = AiPersona.find_by(name: persona_id_or_name)
end
return { error: "Persona not found" } if persona.nil?
allowed_updates = {}
if updates["system_prompt"].present?
allowed_updates[:system_prompt] = updates["system_prompt"]
end
if updates["temperature"].is_a?(Numeric)
allowed_updates[:temperature] = updates["temperature"]
end
allowed_updates[:top_p] = updates["top_p"] if updates["top_p"].is_a?(Numeric)
if updates["description"].present?
allowed_updates[:description] = updates["description"]
end
allowed_updates[:enabled] = updates["enabled"] if updates["enabled"].is_a?(
TrueClass,
) || updates["enabled"].is_a?(FalseClass)
if persona.update(allowed_updates)
return(
{
success: true,
persona:
persona.attributes.slice(
"id",
"name",
"description",
"enabled",
"system_prompt",
"temperature",
"top_p",
),
}
)
else
return { error: persona.errors.full_messages.join(", ") }
end
end
end,
)
end
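# Attaches upload.create(filename, base64Content): decodes the content into a
# tempfile and creates an upload owned by the bot user.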
def attach_upload(mini_racer_context)
mini_racer_context.attach(
"_upload_create",
->(filename, base_64_content) do
begin
in_attached_function do
# protect against misuse
filename = File.basename(filename)
Tempfile.create(filename) do |file|
file.binmode
file.write(Base64.decode64(base_64_content))
file.rewind
upload =
UploadCreator.new(
file,
filename,
for_private_message: @context.private_message,
).create_for(@bot_user.id)
{ id: upload.id, short_url: upload.short_url, url: upload.url }
end
end
end
end,
)
end
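# Attaches http.get/post/put/patch/delete. Every call counts toward
# MAX_HTTP_REQUESTS for the invocation; requests go through
# DiscourseAi::Personas::Tools::Tool.send_http_request and return
# { body, status }.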
def attach_http(mini_racer_context)
mini_racer_context.attach(
"_http_get",
->(url, options) do
begin
@http_requests_made += 1
if @http_requests_made > MAX_HTTP_REQUESTS
raise TooManyRequestsError.new("Tool made too many HTTP requests")
end
in_attached_function do
headers = (options && options["headers"]) || {}
result = {}
DiscourseAi::Personas::Tools::Tool.send_http_request(
url,
headers: headers,
) do |response|
result[:body] = response.body
result[:status] = response.code.to_i
end
result
end
end
end,
)
%i[post put patch delete].each do |method|
mini_racer_context.attach(
"_http_#{method}",
->(url, options) do
begin
@http_requests_made += 1
if @http_requests_made > MAX_HTTP_REQUESTS
raise TooManyRequestsError.new("Tool made too many HTTP requests")
end
in_attached_function do
headers = (options && options["headers"]) || {}
body = options && options["body"]
result = {}
DiscourseAi::Personas::Tools::Tool.send_http_request(
url,
method: method,
headers: headers,
body: body,
) do |response|
result[:body] = response.body
result[:status] = response.code.to_i
end
result
rescue => e
if Rails.env.development?
p url
p options
p e
puts e.backtrace
end
raise e
end
end
end,
)
end
end
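# Marks a section where Ruby code runs on behalf of the script so that
# eval_with_timeout stops counting elapsed time against the JavaScript timeout.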
def in_attached_function
self.running_attached_function = true
yield
ensure
self.running_attached_function = false
end
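# Deeply converts serializers, ActiveRecord objects, arrays and hashes into
# plain JSON-safe values so results can be marshaled back into the MiniRacer
# context.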
def recursive_as_json(obj)
case obj
when Array
obj.map { |item| recursive_as_json(item) }
when Hash
obj.transform_values { |value| recursive_as_json(value) }
when ActiveModel::Serializer, ActiveModel::ArraySerializer
recursive_as_json(obj.as_json)
when ActiveRecord::Base
recursive_as_json(obj.as_json)
else
# Handle objects that respond to as_json but aren't handled above
if obj.respond_to?(:as_json)
result = obj.as_json
if result.equal?(obj)
# If as_json returned the same object, return it to avoid infinite recursion
result
else
recursive_as_json(result)
end
else
# Primitive values like strings, numbers, booleans, nil
obj
end
end
end
end
end
end