706 lines
22 KiB
Ruby
706 lines
22 KiB
Ruby
# frozen_string_literal: true
|
|
|
|
module DiscourseAi
|
|
module Agents
|
|
class ToolRunner
|
|
# Collaborators handed to the constructor; read-only from the outside.
attr_reader :tool, :parameters, :llm

# Mutable runner state: the "Ruby bridge is currently executing" flag read by
# the timeout watchdog, the watchdog budget, and the raw override that scripts
# can set through chain.setCustomRaw.
attr_accessor :running_attached_function, :timeout, :custom_raw

# Raised by the HTTP bridges once a runner exceeds MAX_HTTP_REQUESTS.
class TooManyRequestsError < StandardError
end

# Watchdog budget, counted in ~1ms sleep ticks by eval_with_timeout.
DEFAULT_TIMEOUT = 2000

# Passed to MiniRacer::Context.new as max_memory.
MAX_MEMORY = 10_000_000

# Passed to MiniRacer::Context.new as marshal_stack_depth.
MARSHAL_STACK_DEPTH = 20

# Upper bound on HTTP bridge calls made through a single runner instance.
MAX_HTTP_REQUESTS = 20
|
|
|
|
# Builds a runner for one custom tool.
#
# parameters - value serialized to JSON and passed to the script's invoke()
# llm        - object exposed to scripts via llm.truncate / llm.generate
# bot_user   - fallback user for LLM calls; also owns created uploads
# context    - optional BotContext; a fresh one is created when omitted
# tool       - the tool record providing #script, #name and #id
# timeout    - watchdog budget; DEFAULT_TIMEOUT when not given
#
# Raises ArgumentError when a context is given that is not a BotContext.
def initialize(parameters:, llm:, bot_user:, context: nil, tool:, timeout: nil)
  acceptable_context = !context || context.is_a?(DiscourseAi::Agents::BotContext)
  raise ArgumentError, "context must be a BotContext object" unless acceptable_context

  @context = context || DiscourseAi::Agents::BotContext.new
  @parameters = parameters
  @llm = llm
  @bot_user = bot_user
  @tool = tool
  @timeout = timeout || DEFAULT_TIMEOUT

  # Bookkeeping consulted by the watchdog and by the HTTP bridges.
  @running_attached_function = false
  @http_requests_made = 0
end
|
|
|
|
# Lazily builds and memoizes the JavaScript sandbox: a MiniRacer context with
# every Ruby bridge attached and the framework script already evaluated.
def mini_racer_context
  return @mini_racer_context if @mini_racer_context

  sandbox =
    MiniRacer::Context.new(max_memory: MAX_MEMORY, marshal_stack_depth: MARSHAL_STACK_DEPTH)

  # Bridges must exist before framework_script runs, because that script
  # captures them (e.g. `truncate: _llm_truncate`).
  %i[
    attach_truncate
    attach_http
    attach_index
    attach_upload
    attach_chain
    attach_discourse
  ].each { |bridge| send(bridge, sandbox) }

  sandbox.eval(framework_script)

  @mini_racer_context = sandbox
end
|
|
|
|
# Returns the JavaScript prelude evaluated once per sandbox. It wraps the
# attached `_…` bridge functions in the `http`, `llm`, `index`, `upload`,
# `chain` and `discourse` objects, exposes the bot context as `context`, and
# defines a default `details()` that returns an empty string.
def framework_script
  # One wrapper per HTTP verb, each delegating to its `_http_<verb>` bridge.
  http_methods =
    %i[get post put patch delete]
      .map do |verb|
        <<~JS
          #{verb}: function(url, options) {
            return _http_#{verb}(url, options);
          },
        JS
      end
      .join("\n")

  context_literal = JSON.generate(@context.to_json)

  <<~JS
    const http = {
      #{http_methods}
    };

    const llm = {
      truncate: _llm_truncate,
      generate: _llm_generate,
    };

    const index = {
      search: _index_search,
    }

    const upload = {
      create: _upload_create,
    }

    const chain = {
      setCustomRaw: _chain_set_custom_raw,
    };

    const discourse = {
      search: function(params) {
        return _discourse_search(params);
      },
      updateAgent: function(agent_id_or_name, updates) {
        const result = _discourse_update_agent(agent_id_or_name, updates);
        if (result.error) {
          throw new Error(result.error);
        }
        return result;
      },
      getPost: _discourse_get_post,
      getTopic: _discourse_get_topic,
      getUser: _discourse_get_user,
      getAgent: function(name) {
        const agentDetails = _discourse_get_agent(name);
        if (agentDetails.error) {
          throw new Error(agentDetails.error);
        }

        // merge result.agent with {}..
        return Object.assign({
          update: function(updates) {
            const result = _discourse_update_agent(name, updates);
            if (result.error) {
              throw new Error(result.error);
            }
            return result;
          },
          respondTo: function(params) {
            const result = _discourse_respond_to_agent(name, params);
            if (result.error) {
              throw new Error(result.error);
            }
            return result;
          }
        }, agentDetails.agent);
      },
      createChatMessage: function(params) {
        const result = _discourse_create_chat_message(params);
        if (result.error) {
          throw new Error(result.error);
        }
        return result;
      },
    };

    const context = #{context_literal};

    function details() { return ""; };
  JS
end
|
|
|
|
# Calls the script's `details()` function under the timeout watchdog and
# returns whatever it yields. The framework script supplies a default
# `details()` that returns an empty string.
def details
  details_call = "details()"
  eval_with_timeout(details_call)
end
|
|
|
|
# Evaluates `script` in the sandbox while a watchdog thread enforces a time
# budget.
#
# script  - JavaScript source to evaluate
# timeout - number of ~1ms watchdog ticks allowed; defaults to @timeout
#
# Ticks are only counted while no attached Ruby function is running, so time
# spent inside bridges (HTTP, LLM, search…) does not count against the script.
# When the budget is exceeded the watchdog calls `stop` on the context.
#
# Returns the value of the evaluated script. Any error raised by the
# evaluation propagates to the caller.
def eval_with_timeout(script, timeout: nil)
  timeout ||= @timeout
  mutex = Mutex.new
  done = false
  elapsed = 0

  watchdog =
    Thread.new do
      begin
        until done
          # this is not accurate. but reasonable enough for a timeout
          sleep(0.001)
          elapsed += 1 if !self.running_attached_function
          if elapsed > timeout
            # Re-check `done` under the lock so an evaluation that has
            # already finished is never stopped.
            mutex.synchronize { mini_racer_context.stop unless done }
            break
          end
        end
      rescue => e
        STDERR.puts e
        STDERR.puts "FAILED TO TERMINATE DUE TO TIMEOUT"
      end
    end

  mini_racer_context.eval(script)
ensure
  # Runs on success and on failure alike. Marking the evaluation as done
  # before joining lets the watchdog exit right away; without it a failed
  # evaluation would block here until the whole budget elapsed and then
  # `stop` a context that is no longer running anything.
  mutex&.synchronize { done = true }
  watchdog&.join
end
|
|
|
|
# Loads the tool's script and calls its `invoke` function with the runner's
# parameters serialized as JSON.
#
# Both steps run under the timeout watchdog: the tool script is user-authored,
# so a top-level loop in it must not be able to hang the caller any more than
# a loop inside `invoke` can.
#
# Returns the script's result, or `{ error: "Script terminated due to timeout" }`
# when the sandbox terminates the script.
def invoke
  eval_with_timeout(tool.script)
  eval_with_timeout("invoke(#{JSON.generate(parameters)})")
rescue MiniRacer::ScriptTerminatedError
  { error: "Script terminated due to timeout" }
end
|
|
|
|
# Reports whether the tool's script defines a `customContext` function.
#
# The user-authored script is loaded under the timeout watchdog so that a
# runaway top-level statement cannot hang this check. The `typeof` probe
# itself is a constant expression and is evaluated directly.
#
# Returns the boolean result of the probe, or false when loading or probing
# raises any StandardError.
def has_custom_context?
  eval_with_timeout(tool.script)
  mini_racer_context.eval("typeof customContext === 'function'")
rescue StandardError
  false
end
|
|
|
|
# Loads the tool's script and returns the value of its `customContext()`
# function.
#
# Both evaluations run user-authored JavaScript, so both go through the
# timeout watchdog instead of calling the sandbox directly.
#
# Returns nil when loading or calling raises any StandardError.
def custom_context
  eval_with_timeout(tool.script)
  eval_with_timeout("customContext()")
rescue StandardError
  nil
end
|
|
|
|
private
|
|
|
|
# Hard cap on the number of fragments a single search may return.
MAX_FRAGMENTS = 200

# Searches the document fragments of the uploads attached to this tool.
#
# query     - text embedded into the query vector
# filenames - optional original filename(s); restricts results to uploads
#             with a matching original_filename
# limit     - maximum number of results; coerced with to_i and capped at
#             MAX_FRAGMENTS
#
# Returns an Array of `{ fragment:, metadata: }` hashes in the order produced
# by the similarity search. Returns [] when limit is below 1 or no upload
# matches.
def rag_search(query, filenames: nil, limit: 10)
  limit = limit.to_i
  return [] if limit < 1
  limit = [MAX_FRAGMENTS, limit].min

  upload_refs =
    UploadReference.where(target_id: tool.id, target_type: "AiTool").pluck(:upload_id)

  if filenames
    upload_refs = Upload.where(id: upload_refs).where(original_filename: filenames).pluck(:id)
  end

  return [] if upload_refs.empty?

  query_vector = DiscourseAi::Embeddings::Vector.instance.vector_from(query)
  fragment_ids =
    DiscourseAi::Embeddings::Schema
      .for(RagDocumentFragment)
      .asymmetric_similarity_search(query_vector, limit: limit, offset: 0) do |builder|
        builder.join(<<~SQL, target_id: tool.id, target_type: "AiTool")
          rag_document_fragments ON
            rag_document_fragments.id = rag_document_fragment_id AND
            rag_document_fragments.target_id = :target_id AND
            rag_document_fragments.target_type = :target_type
        SQL
      end
      .map(&:rag_document_fragment_id)

  fragments_by_id =
    RagDocumentFragment
      .where(id: fragment_ids, upload_id: upload_refs)
      .pluck(:id, :fragment, :metadata)
      .to_h { |id, fragment, metadata| [id, { fragment: fragment, metadata: metadata }] }

  # The similarity search is scoped to the tool, not to `upload_refs`, so some
  # ids may have been dropped by the upload filter above. Skip those instead
  # of handing nil entries back to the script.
  fragment_ids.take(limit).filter_map { |fragment_id| fragments_by_id[fragment_id] }
end
|
|
|
|
# Attaches the two LLM bridges used by the script-side `llm` object:
#
# _llm_truncate(text, length) - truncates text with the LLM's tokenizer
# _llm_generate(prompt)       - runs a completion; flagged as an attached
#                               function so the watchdog does not count the
#                               time spent in it
def attach_truncate(mini_racer_context)
  truncate_bridge = ->(text, length) { @llm.tokenizer.truncate(text, length) }

  generate_bridge =
    lambda do |prompt|
      in_attached_function do
        @llm.generate(
          convert_js_prompt_to_ruby(prompt),
          user: llm_user,
          feature_name: "custom_tool_#{tool.name}",
        )
      end
    end

  mini_racer_context.attach("_llm_truncate", truncate_bridge)
  mini_racer_context.attach("_llm_generate", generate_bridge)
end
|
|
|
|
# Converts a prompt received from JavaScript into what the LLM expects.
#
# prompt - either a String, returned unchanged, or a Hash with a non-empty
#          "messages" Array of message hashes
#
# For the Hash form every message has its keys symbolized in place and its
# :type converted to a Symbol before being wrapped in a Completions::Prompt.
#
# Raises Discourse::InvalidParameters for any other shape.
def convert_js_prompt_to_ruby(prompt)
  return prompt if prompt.is_a?(String)

  unless prompt.is_a?(Hash)
    raise Discourse::InvalidParameters.new("Prompt must be a string or a hash")
  end

  messages = prompt["messages"]
  unless messages.is_a?(Array) && !messages.blank?
    raise Discourse::InvalidParameters.new("Prompt must have messages")
  end

  # Two passes on purpose: all keys are symbolized before any type is touched.
  messages.each { |message| message.symbolize_keys! }
  messages.each { |message| message[:type] = message[:type].to_sym }

  DiscourseAi::Completions::Prompt.new(messages: messages)
end
|
|
|
|
# User attributed to LLM calls made by the script: the author of the context
# post when there is one, otherwise the bot user. Memoized once truthy.
def llm_user
  @llm_user ||= (post&.user || @bot_user)
end
|
|
|
|
# Post referenced by the context's post_id, looked up at most once. The
# `defined?` check memoizes a nil/false result as well, so a missing post is
# not queried again.
def post
  unless defined?(@post)
    context_post_id = @context.post_id
    @post = context_post_id && Post.find_by(id: context_post_id)
  end

  @post
end
|
|
|
|
# Attaches `_index_search(query, options)`, the bridge behind the script-side
# `index.search`.
#
# `options` is an optional JS object whose keys are symbolized and forwarded
# to rag_search as keyword arguments. The call is wrapped in
# in_attached_function, which already raises and clears the
# running_attached_function flag, so the flag is not set a second time here,
# matching the other bridges.
def attach_index(mini_racer_context)
  mini_racer_context.attach(
    "_index_search",
    ->(*params) do
      in_attached_function do
        query, options = params
        options = (options || {}).symbolize_keys
        rag_search(query, **options)
      end
    end,
  )
end
|
|
|
|
# Attaches `_chain_set_custom_raw(raw)`, the bridge behind the script-side
# `chain.setCustomRaw`; it stores the given value in #custom_raw.
def attach_chain(mini_racer_context)
  store_custom_raw = lambda { |raw| self.custom_raw = raw }
  mini_racer_context.attach("_chain_set_custom_raw", store_custom_raw)
end
|
|
|
|
# Attaches the bridges behind the script-side `discourse` object.
#
# Every bridge runs inside in_attached_function so the timeout watchdog does
# not count the time spent in Ruby. Failures that the script is expected to
# handle are reported as `{ error: "…" }` hashes; the framework script turns
# most of them into thrown JavaScript errors.
#
# Bridges attached:
#   _discourse_get_post(post_id)
#   _discourse_get_topic(topic_id)
#   _discourse_get_user(user_id_or_username)
#   _discourse_respond_to_agent(agent_name, params)
#   _discourse_create_chat_message(params)
#   _discourse_search(params)
#   _discourse_get_agent(agent_name)
#   _discourse_update_agent(agent_id_or_name, updates)
def attach_discourse(mini_racer_context)
  # Serialized post (including raw) plus its topic under "topic"; nil when
  # the post does not exist. Serializers are scoped to the system user.
  mini_racer_context.attach(
    "_discourse_get_post",
    ->(post_id) do
      in_attached_function do
        post = Post.find_by(id: post_id)
        return nil if post.nil?
        guardian = Guardian.new(Discourse.system_user)
        obj =
          recursive_as_json(
            PostSerializer.new(post, scope: guardian, root: false, add_raw: true),
          )
        topic_obj =
          recursive_as_json(
            ListableTopicSerializer.new(post.topic, scope: guardian, root: false),
          )
        obj["topic"] = topic_obj
        obj
      end
    end,
  )

  # Serialized topic, or nil when it does not exist.
  mini_racer_context.attach(
    "_discourse_get_topic",
    ->(topic_id) do
      in_attached_function do
        topic = Topic.find_by(id: topic_id)
        return nil if topic.nil?
        guardian = Guardian.new(Discourse.system_user)
        recursive_as_json(ListableTopicSerializer.new(topic, scope: guardian, root: false))
      end
    end,
  )

  # Serialized user looked up by id (Integer or all-digit String) or by
  # username; nil when no user matches.
  mini_racer_context.attach(
    "_discourse_get_user",
    ->(user_id_or_username) do
      in_attached_function do
        user = nil

        if user_id_or_username.is_a?(Integer) ||
             user_id_or_username.to_i.to_s == user_id_or_username
          user = User.find_by(id: user_id_or_username.to_i)
        else
          user = User.find_by(username: user_id_or_username)
        end

        return nil if user.nil?

        guardian = Guardian.new(Discourse.system_user)
        recursive_as_json(UserSerializer.new(user, scope: guardian, root: false))
      end
    end,
  )

  # Makes the named agent reply in the current context: to the context post
  # when there is one, otherwise to the context chat message.
  mini_racer_context.attach(
    "_discourse_respond_to_agent",
    ->(agent_name, params) do
      in_attached_function do
        # if we have 1000s of agents this can be slow ... we may need to optimize
        agent_class = AiAgent.all_agents.find { |agent| agent.name == agent_name }
        return { error: "Agent not found" } if agent_class.nil?

        agent = agent_class.new
        bot = DiscourseAi::Agents::Bot.as(@bot_user || agent.user, agent: agent)
        playground = DiscourseAi::AiBot::Playground.new(bot)

        if @context.post_id
          post = Post.find_by(id: @context.post_id)
          return { error: "Post not found" } if post.nil?

          reply_post =
            playground.reply_to(
              post,
              custom_instructions: params["instructions"],
              whisper: params["whisper"],
            )

          if reply_post
            return(
              { success: true, post_id: reply_post.id, post_number: reply_post.post_number }
            )
          else
            return { error: "Failed to create reply" }
          end
        elsif @context.message_id && @context.channel_id
          message = Chat::Message.find_by(id: @context.message_id)
          channel = Chat::Channel.find_by(id: @context.channel_id)
          return { error: "Message or channel not found" } if message.nil? || channel.nil?

          reply =
            playground.reply_to_chat_message(message, channel, @context.context_post_ids)

          if reply
            return { success: true, message_id: reply.id }
          else
            return { error: "Failed to create chat reply" }
          end
        else
          return { error: "No valid context for response" }
        end
      end
    end,
  )

  # Posts a chat message as `username` into the channel found by name, or by
  # the parameterized name used as a slug.
  mini_racer_context.attach(
    "_discourse_create_chat_message",
    ->(params) do
      in_attached_function do
        params = params.symbolize_keys
        channel_name = params[:channel_name]
        username = params[:username]
        message = params[:message]

        # Validate parameters
        return { error: "Missing required parameter: channel_name" } if channel_name.blank?
        return { error: "Missing required parameter: username" } if username.blank?
        return { error: "Missing required parameter: message" } if message.blank?

        # Find the user
        user = User.find_by(username: username)
        return { error: "User not found: #{username}" } if user.nil?

        # Find the channel
        channel = Chat::Channel.find_by(name: channel_name)
        if channel.nil?
          # Try finding by slug if not found by name
          channel = Chat::Channel.find_by(slug: channel_name.parameterize)
        end
        return { error: "Channel not found: #{channel_name}" } if channel.nil?

        begin
          guardian = Guardian.new(user)
          message =
            ChatSDK::Message.create(
              raw: message,
              channel_id: channel.id,
              guardian: guardian,
              enforce_membership: !channel.direct_message_channel?,
            )

          {
            success: true,
            message_id: message.id,
            message: message.message,
            created_at: message.created_at.iso8601,
          }
        rescue => e
          { error: "Failed to create chat message: #{e.message}" }
        end
      end
    end,
  )

  # Runs a detailed search; `with_private` switches the search user to the
  # system user.
  mini_racer_context.attach(
    "_discourse_search",
    ->(params) do
      in_attached_function do
        search_params = params.symbolize_keys
        if search_params.delete(:with_private)
          search_params[:current_user] = Discourse.system_user
        end
        search_params[:result_style] = :detailed
        results = DiscourseAi::Utils::Search.perform_search(**search_params)
        recursive_as_json(results)
      end
    end,
  )

  mini_racer_context.attach(
    "_discourse_get_agent",
    ->(agent_name) do
      in_attached_function do
        agent = AiAgent.find_by(name: agent_name)

        return { error: "Agent not found" } if agent.nil?

        # Return a subset of relevant agent attributes
        {
          agent:
            agent.attributes.slice(
              "id",
              "name",
              "description",
              "enabled",
              "system_prompt",
              "temperature",
              "top_p",
              "vision_enabled",
              "tools",
              "max_context_posts",
              "allow_chat_channel_mentions",
              "allow_chat_direct_messages",
              "allow_topic_mentions",
              # Was "allow_agentl_messages", which reads as a rename artifact
              # and would be silently dropped by slice.
              # NOTE(review): confirm the attribute is allow_personal_messages.
              "allow_personal_messages",
            ),
        }
      end
    end,
  )

  # Updates a whitelisted set of agent attributes; everything else in
  # `updates` is ignored.
  mini_racer_context.attach(
    "_discourse_update_agent",
    ->(agent_id_or_name, updates) do
      in_attached_function do
        # Find agent by ID or name
        agent = nil
        if agent_id_or_name.is_a?(Integer) ||
             agent_id_or_name.to_i.to_s == agent_id_or_name
          agent = AiAgent.find_by(id: agent_id_or_name.to_i)
        else
          agent = AiAgent.find_by(name: agent_id_or_name)
        end

        return { error: "Agent not found" } if agent.nil?

        allowed_updates = {}

        if updates["system_prompt"].present?
          allowed_updates[:system_prompt] = updates["system_prompt"]
        end

        if updates["temperature"].is_a?(Numeric)
          allowed_updates[:temperature] = updates["temperature"]
        end

        allowed_updates[:top_p] = updates["top_p"] if updates["top_p"].is_a?(Numeric)

        if updates["description"].present?
          allowed_updates[:description] = updates["description"]
        end

        # Only a real boolean may toggle the agent.
        if [true, false].include?(updates["enabled"])
          allowed_updates[:enabled] = updates["enabled"]
        end

        if agent.update(allowed_updates)
          return(
            {
              success: true,
              agent:
                agent.attributes.slice(
                  "id",
                  "name",
                  "description",
                  "enabled",
                  "system_prompt",
                  "temperature",
                  "top_p",
                ),
            }
          )
        else
          return { error: agent.errors.full_messages.join(", ") }
        end
      end
    end,
  )
end
|
|
|
|
# Attaches `_upload_create(filename, base_64_content)`, the bridge behind the
# script-side `upload.create`.
#
# The decoded content is written to a temporary file and handed to
# UploadCreator on behalf of the bot user. Returns a hash with the upload's
# id, short_url and url.
def attach_upload(mini_racer_context)
  create_upload =
    lambda do |filename, base_64_content|
      in_attached_function do
        # protect against misuse
        safe_name = File.basename(filename)

        Tempfile.create(safe_name) do |tmp|
          tmp.binmode
          tmp.write(Base64.decode64(base_64_content))
          tmp.rewind

          creator =
            UploadCreator.new(tmp, safe_name, for_private_message: @context.private_message)
          upload = creator.create_for(@bot_user.id)

          { id: upload.id, short_url: upload.short_url, url: upload.url }
        end
      end
    end

  mini_racer_context.attach("_upload_create", create_upload)
end
|
|
|
|
# Attaches the `_http_get`, `_http_post`, `_http_put`, `_http_patch` and
# `_http_delete` bridges behind the script-side `http` object.
#
# Each bridge takes `(url, options)`, where options may carry "headers" and,
# for non-GET verbs, "body". It returns `{ body:, status: }` from the response.
# All verbs share one counter: once more than MAX_HTTP_REQUESTS calls were
# made through this runner, TooManyRequestsError is raised.
def attach_http(mini_racer_context)
  # Counts the request first, then enforces the cap.
  track_request =
    lambda do
      @http_requests_made += 1
      if @http_requests_made > MAX_HTTP_REQUESTS
        raise TooManyRequestsError.new("Tool made too many HTTP requests")
      end
    end

  get_bridge =
    lambda do |url, options|
      track_request.call

      in_attached_function do
        headers = (options && options["headers"]) || {}

        result = {}
        DiscourseAi::Agents::Tools::Tool.send_http_request(url, headers: headers) do |response|
          result[:body] = response.body
          result[:status] = response.code.to_i
        end

        result
      end
    end

  mini_racer_context.attach("_http_get", get_bridge)

  %i[post put patch delete].each do |verb|
    verb_bridge =
      lambda do |url, options|
        track_request.call

        in_attached_function do
          begin
            headers = (options && options["headers"]) || {}
            body = options && options["body"]

            result = {}
            DiscourseAi::Agents::Tools::Tool.send_http_request(
              url,
              method: verb,
              headers: headers,
              body: body,
            ) do |response|
              result[:body] = response.body
              result[:status] = response.code.to_i
            end

            result
          rescue => e
            # Development-only diagnostics; the error is always re-raised.
            if Rails.env.development?
              p url
              p options
              p e
              puts e.backtrace
            end
            raise e
          end
        end
      end

    mini_racer_context.attach("_http_#{verb}", verb_bridge)
  end
end
|
|
|
|
# Runs the given block with running_attached_function set to true, so the
# timeout watchdog does not count the time spent in Ruby bridges. The flag is
# reset to false afterwards, even when the block raises or returns early.
#
# Returns the block's value.
def in_attached_function
  begin
    self.running_attached_function = true
    yield
  ensure
    self.running_attached_function = false
  end
end
|
|
|
|
# Recursively converts `obj` into plain data by calling as_json wherever it
# is available.
#
# Arrays and hash values are converted element by element. Serializers and
# ActiveRecord models are converted through as_json and the result is
# processed again. Any other object is converted through as_json when it
# responds to it, and returned untouched otherwise.
def recursive_as_json(obj)
  case obj
  when Array
    obj.map { |element| recursive_as_json(element) }
  when Hash
    obj.transform_values { |entry| recursive_as_json(entry) }
  when ActiveModel::Serializer, ActiveModel::ArraySerializer, ActiveRecord::Base
    recursive_as_json(obj.as_json)
  else
    # Primitive values like strings, numbers, booleans, nil
    return obj unless obj.respond_to?(:as_json)

    converted = obj.as_json
    # An object whose as_json returns itself is already final; recursing on it
    # would never terminate.
    converted.equal?(obj) ? converted : recursive_as_json(converted)
  end
end
|
|
end
|
|
end
|
|
end
|