discourse-ai/lib/ai_helper/assistant.rb

448 lines
14 KiB
Ruby

# frozen_string_literal: true
module DiscourseAi
module AiHelper
class Assistant
# Maximum number of words kept when a generated image caption is truncated.
IMAGE_CAPTION_MAX_WORDS = 50
# Helper mode identifiers. Each mode is paired with a persona id in
# personas_prompt_map and drives the icon, location and prompt type lookups.
TRANSLATE = "translate"
GENERATE_TITLES = "generate_titles"
PROOFREAD = "proofread"
MARKDOWN_TABLE = "markdown_table"
CUSTOM_PROMPT = "custom_prompt"
EXPLAIN = "explain"
ILLUSTRATE_POST = "illustrate_post"
REPLACE_DATES = "replace_dates"
# Only included in the persona map when include_image_caption is requested.
IMAGE_CAPTION = "image_caption"
# Class-level cache used by available_prompts to store prompt definitions.
# The backing DiscourseAi::MultisiteHash is created on first access and then
# memoized on the class.
def self.prompt_cache
  return @prompt_cache if @prompt_cache

  @prompt_cache = ::DiscourseAi::MultisiteHash.new("prompt_cache")
end
# Empties the class-level prompt cache by flushing it.
# Returns whatever the cache's flush! returns.
def self.clear_prompt_cache!
  cache = prompt_cache
  cache.flush!
end
# Stores optional LLM overrides; find_ai_helper_model consults them when the
# persona has no default LLM.
#
# @param helper_llm [Object, nil] override kept in @helper_llm
# @param image_caption_llm [Object, nil] override kept in @image_caption_llm
def initialize(helper_llm: nil, image_caption_llm: nil)
  @helper_llm, @image_caption_llm = helper_llm, image_caption_llm
end
# Lists the helper prompts +user+ may use.
#
# Prompt definitions come from the class-level prompt cache, keyed by the
# current I18n locale and populated from all_prompts on a miss. The per-user
# checks (group membership, illustrate-post availability, translate label)
# run on every call.
#
# @param user [Object] must respond to in_any_groups? and effective_locale
# @return [Array<Hash>] prompt hashes; the translate prompt is a merged copy
#   carrying a per-user :translated_name
def available_prompts(user)
  cache_key = "prompt_cache_#{I18n.locale}"
  cached_prompts = self.class.prompt_cache.fetch(cache_key) { all_prompts }

  visible_prompts =
    cached_prompts.select do |prompt|
      next false unless user.in_any_groups?(prompt[:allowed_group_ids])

      illustrate_disabled =
        prompt[:name] == ILLUSTRATE_POST &&
          SiteSetting.ai_helper_illustrate_post_model == "disabled"
      !illustrate_disabled
    end

  visible_prompts.map do |prompt|
    next prompt unless prompt[:name] == TRANSLATE

    # We cannot cache this. It depends on the user's effective_locale.
    locale = user.effective_locale
    language_names = LocaleSiteSetting.language_names
    locale_hash = language_names[locale] || language_names[locale.split("_")[0]]
    # Neither the full locale nor its language prefix may be listed; use the
    # raw locale code rather than raising on a nil lookup.
    language = (locale_hash && locale_hash["nativeName"]) || locale

    # default: nil (the same pattern all_prompts uses) lets the fallback to the
    # prompt name take effect when the translation is missing.
    translation =
      I18n.t(
        "discourse_ai.ai_helper.prompts.translate",
        language: language,
        default: nil,
      ) || prompt[:name]

    prompt.merge(translated_name: translation)
  end
end
# Builds the extra instruction that tells the model which language to answer in.
#
# The site default locale is used unless a user is given and
# +force_default_locale+ is falsy, in which case the user's effective locale wins.
#
# @param user [Object, nil] must respond to effective_locale when present
# @param force_default_locale [Boolean] ignore the user's locale when truthy
# @return [String, nil] nil for "en" or for a locale missing from
#   LocaleSiteSetting.language_names
def custom_locale_instructions(user = nil, force_default_locale)
  use_user_locale = !force_default_locale && user
  locale = use_user_locale ? user.effective_locale : SiteSetting.default_locale

  locale_hash = LocaleSiteSetting.language_names[locale]
  return if locale == "en" || !locale_hash

  locale_description = format("%s (%s)", locale_hash["name"], locale_hash["nativeName"])
  "It is imperative that you write your answer in #{locale_description}, you are interacting with a #{locale_description} speaking user. Leave tag names in English."
end
# Adds language and, when a user is given, time information to +context+.
#
# The language name is looked up for the user's effective locale, or for the
# site default locale when there is no user or +force_default_locale+ is set.
# The lookup falls back to the language prefix ("pt" for "pt_BR") and finally
# to the raw locale code, so an unlisted locale no longer raises.
#
# @param context [Object] must respond to user_language= and temporal_context=
# @param user [Object, nil] must respond to effective_locale and user_option
# @param force_default_locale [Boolean] ignore the user's locale when true
# @return [Object] the same +context+ object, mutated
def attach_user_context(context, user = nil, force_default_locale: false)
  locale = SiteSetting.default_locale
  locale = user.effective_locale if user && !force_default_locale

  language_names = LocaleSiteSetting.language_names
  locale_hash = language_names[locale] || language_names[locale.to_s.split("_").first]
  context.user_language = locale_hash ? locale_hash["name"].to_s : locale.to_s

  if user
    # A missing user_option or a blank timezone both fall back to UTC.
    timezone = user.user_option&.timezone.presence || "UTC"
    current_time = Time.now.in_time_zone(timezone)

    temporal_context = {
      utc_date_time: current_time.iso8601,
      local_time: current_time.strftime("%H:%M"),
      user: {
        timezone: timezone,
        weekday: current_time.strftime("%A"),
      },
    }

    # Stored as a JSON string, not as a hash.
    context.temporal_context = temporal_context.to_json
  end

  context
end
# Runs the persona bot for +helper_mode+ against +input+ and returns the
# accumulated response text.
#
# The input is wrapped in <input> tags; for CUSTOM_PROMPT with a non-blank
# +custom_prompt+ the prompt is prepended inside the tags.
#
# @param helper_mode [String] one of the helper mode constants
# @param input [String] text to process
# @param user [Object] forwarded to build_bot and the bot context
# @param force_default_locale [Boolean] use the site locale, not the user's
# @param custom_prompt [String, nil] only used when helper_mode is CUSTOM_PROMPT
# @yield [String] every chunk appended to the response, as it arrives
# @return [String] the full response
def generate_prompt(
  helper_mode,
  input,
  user,
  force_default_locale: false,
  custom_prompt: nil,
  &block
)
  bot = build_bot(helper_mode, user)

  user_input =
    if helper_mode == CUSTOM_PROMPT && custom_prompt.present?
      "<input>#{custom_prompt}:\n#{input}</input>"
    else
      "<input>#{input}</input>"
    end

  context =
    DiscourseAi::Personas::BotContext.new(
      user: user,
      skip_tool_details: true,
      feature_name: "ai_helper",
      messages: [{ type: :user, content: user_input }],
      format_dates: helper_mode == REPLACE_DATES,
      custom_instructions: custom_locale_instructions(user, force_default_locale),
    )
  context = attach_user_context(context, user, force_default_locale: force_default_locale)

  helper_response = +""

  # Proc.new (not a lambda) so the bot may call it with any number of arguments.
  buffer_blk =
    Proc.new do |partial, _, type|
      if type == :structured_output
        # The first response_format entry names the property to read.
        json_summary_schema_key = bot.persona.response_format&.first.to_h
        helper_chunk = partial.read_buffered_property(json_summary_schema_key["key"]&.to_sym)

        if helper_chunk.present?
          helper_response << helper_chunk
          block.call(helper_chunk) if block
        end
      elsif type.blank?
        # Assume response is a regular completion: the partial itself is the
        # text (helper_chunk is only assigned in the structured branch).
        helper_response << partial
        block.call(partial) if block
      end
    end

  bot.reply(context, &buffer_blk)

  helper_response
end
# Generates a full (non-streamed) helper response and shapes it for the caller.
#
# @param helper_mode [String] one of the helper mode constants
# @param input [String] text to process
# @param user [Object] forwarded to generate_prompt
# @param force_default_locale [Boolean] forwarded to generate_prompt
# @param custom_prompt [String, nil] forwarded to generate_prompt
# @return [Hash] :type and :suggestions, plus :diff when the type is :diff
def generate_and_send_prompt(
  helper_mode,
  input,
  user,
  force_default_locale: false,
  custom_prompt: nil
)
  raw_response =
    generate_prompt(
      helper_mode,
      input,
      user,
      force_default_locale: force_default_locale,
      custom_prompt: custom_prompt,
    )

  response_type = prompt_type(helper_mode)
  result = { type: response_type }

  # List responses carry one sanitized suggestion per parsed item.
  if response_type == :list
    result[:suggestions] = parse_list(raw_response).map { |item| sanitize_result(item) }
    return result
  end

  # Everything else is a single sanitized suggestion, with a diff against the
  # input for diff-type prompts.
  cleaned = sanitize_result(raw_response)
  result[:diff] = parse_diff(input, cleaned) if response_type == :diff
  result[:suggestions] = [cleaned]
  result
end
# Streams a helper response to +channel+ via publish_update.
#
# Interim payloads (done: false) are throttled; a final payload (done: true)
# is published once generation ends, provided the sanitized result is not blank.
# For diff-type prompts every payload carries a diff of +input+ against the
# text accumulated so far.
#
# @param helper_mode [String] one of the helper mode constants
# @param input [String] text to process
# @param user [Object] recipient; forwarded to generate_prompt and publish_update
# @param channel [String] channel passed to publish_update
# @param force_default_locale [Boolean] forwarded to generate_prompt
# @param client_id [String, nil] forwarded to publish_update
# @param custom_prompt [String, nil] forwarded to generate_prompt
# @return [Object, nil] result of the final publish_update, or nil if skipped
def stream_prompt(
  helper_mode,
  input,
  user,
  channel,
  force_default_locale: false,
  client_id: nil,
  custom_prompt: nil
)
  streamed_result = +""
  last_publish_at = Time.now
  type = prompt_type(helper_mode)

  generate_prompt(
    helper_mode,
    input,
    user,
    force_default_locale: force_default_locale,
    custom_prompt: custom_prompt,
  ) do |partial_response|
    streamed_result << partial_response

    # Throttle updates and check for safe stream points
    throttle_elapsed = streamed_result.length > 10 && (Time.now - last_publish_at > 0.3)
    next unless throttle_elapsed || Rails.env.test?

    # Diff against everything streamed so far (not just the latest chunk), and
    # only when a payload is actually published.
    streamed_diff = type == :diff ? parse_diff(input, streamed_result) : ""

    payload = { result: sanitize_result(streamed_result), diff: streamed_diff, done: false }
    publish_update(channel, payload, user, client_id: client_id)
    last_publish_at = Time.now
  end

  final_diff = parse_diff(input, streamed_result) if type == :diff

  sanitized_result = sanitize_result(streamed_result)
  if sanitized_result.present?
    publish_update(
      channel,
      { result: sanitized_result, diff: final_diff, done: true },
      user,
      client_id: client_id,
    )
  end
end
# Asks the image caption persona to describe +upload+ in one sentence.
#
# @param upload [Object] must respond to id
# @param user [Object] forwarded to build_bot and the bot context
# @return [String] caption with "|" removed, whitespace squished and truncated
#   to IMAGE_CAPTION_MAX_WORDS words; empty when no caption was produced
def generate_image_caption(upload, user)
  bot = build_bot(IMAGE_CAPTION, user)
  force_default_locale = false

  context =
    DiscourseAi::Personas::BotContext.new(
      user: user,
      skip_tool_details: true,
      feature_name: IMAGE_CAPTION,
      messages: [
        {
          type: :user,
          content: ["Describe this image in a single sentence.", { upload_id: upload.id }],
        },
      ],
      custom_instructions: custom_locale_instructions(user, force_default_locale),
    )

  # Only the latest structured output is kept; its property is read once the
  # reply has finished.
  structured_output = nil
  buffer_blk =
    Proc.new { |partial, _, type| structured_output = partial if type == :structured_output }

  bot.reply(context, llm_args: { max_tokens: 1024 }, &buffer_blk)

  raw_caption = ""

  if structured_output
    json_summary_schema_key = bot.persona.response_format&.first.to_h
    # to_s: the buffered property may be nil when the model returned nothing.
    raw_caption =
      structured_output.read_buffered_property(json_summary_schema_key["key"]&.to_sym).to_s
  end

  raw_caption.delete("|").squish.truncate_words(IMAGE_CAPTION_MAX_WORDS)
end
private
# Builds the persona bot that serves +helper_mode+ for +user+.
#
# @param helper_mode [String] one of the helper mode constants
# @param user [Object] passed through to Bot.as
# @return [Object, nil] the bot, or nil when no persona record or class is found
# @raise [Discourse::InvalidParameters] when no persona id maps to the mode
def build_bot(helper_mode, user)
  # Invert the persona-id => mode map to look the persona up by mode.
  persona_ids_by_mode = personas_prompt_map(include_image_caption: true).invert
  persona_id = persona_ids_by_mode[helper_mode]
  raise Discourse::InvalidParameters.new(:mode) if persona_id.blank?

  persona_record = AiPersona.find_by(id: persona_id)
  persona_klass = persona_record&.class_instance
  return nil if persona_klass.nil?

  DiscourseAi::Personas::Bot.as(
    user,
    persona: persona_klass.new,
    model: find_ai_helper_model(helper_mode, persona_klass),
  )
end
# Picks the LLM used for +helper_mode+. Priorities are:
# 1. Persona's default LLM
# 2. The override given to the constructor (`image_caption_llm` for image_caption,
#    `helper_llm` for every other mode)
# 3. Hidden `ai_helper_model` setting, or `ai_helper_image_caption_model` for image_caption.
# 4. Newest LLM config
#
# @param helper_mode [String] one of the helper mode constants
# @param persona_klass [Object] must respond to default_llm_id
# @return [Object, nil] result of LlmModel.find_by, or LlmModel.last when no
#   id could be resolved
def find_ai_helper_model(helper_mode, persona_klass)
  model_id = persona_klass.default_llm_id

  if !model_id
    model_id =
      if helper_mode == IMAGE_CAPTION
        @image_caption_llm || SiteSetting.ai_helper_image_caption_model&.split(":")&.last
      else
        @helper_llm || SiteSetting.ai_helper_model&.split(":")&.last
      end
  end

  model_id.present? ? LlmModel.find_by(id: model_id) : LlmModel.last
end
# Maps persona ids (site settings coerced with to_i) to helper mode names.
#
# When two settings resolve to the same id, the later entry overwrites the
# earlier one.
#
# @param include_image_caption [Boolean] also map the image caption persona
# @return [Hash{Integer => String}]
def personas_prompt_map(include_image_caption: false)
  map = {}
  map[SiteSetting.ai_helper_translator_persona.to_i] = TRANSLATE
  map[SiteSetting.ai_helper_tittle_suggestions_persona.to_i] = GENERATE_TITLES
  map[SiteSetting.ai_helper_proofreader_persona.to_i] = PROOFREAD
  map[SiteSetting.ai_helper_markdown_tables_persona.to_i] = MARKDOWN_TABLE
  map[SiteSetting.ai_helper_custom_prompt_persona.to_i] = CUSTOM_PROMPT
  map[SiteSetting.ai_helper_explain_persona.to_i] = EXPLAIN
  map[SiteSetting.ai_helper_post_illustrator_persona.to_i] = ILLUSTRATE_POST
  map[SiteSetting.ai_helper_smart_dates_persona.to_i] = REPLACE_DATES

  return map unless include_image_caption

  image_caption_persona = SiteSetting.ai_helper_image_caption_persona.to_i
  # NOTE(review): an Integer is always truthy in Ruby, so this guard does not
  # filter out an unset (0) persona id - confirm whether that is intended.
  map[image_caption_persona] = IMAGE_CAPTION if image_caption_persona
  map
end
# Builds the prompt definitions for every persona configured in
# personas_prompt_map (image caption excluded).
#
# @return [Array<Hash>] one hash per persona with :name, :translated_name,
#   :prompt_type, :icon, :location and :allowed_group_ids
def all_prompts
  # Read the settings-backed map once and reuse it for the query and each lookup.
  personas_and_prompts = personas_prompt_map

  AiPersona
    .where(id: personas_and_prompts.keys)
    .map do |ai_persona|
      prompt_name = personas_and_prompts[ai_persona.id]
      next unless prompt_name

      {
        name: prompt_name,
        # Falls back to the raw prompt name when no translation exists.
        translated_name:
          I18n.t("discourse_ai.ai_helper.prompts.#{prompt_name}", default: nil) ||
            prompt_name,
        prompt_type: prompt_type(prompt_name),
        icon: icon_map(prompt_name),
        location: location_map(prompt_name),
        allowed_group_ids: ai_persona.allowed_group_ids,
      }
    end
    .compact
end
# Alternation matching the opening and closing forms of the wrapper tags the
# helper prompts use. An opening tag may be followed by one newline and a
# closing tag may be preceded by one; that newline is part of the match.
SANITIZE_REGEX_STR =
  %w[term context topic replyTo input output result]
    .flat_map { |tag| ["<#{tag}>\\n?", "\\n?</#{tag}>"] }
    .join("|")

# Case-insensitive, multiline regexp compiled from SANITIZE_REGEX_STR.
SANITIZE_REGEX = /#{SANITIZE_REGEX_STR}/im

# Removes every wrapper tag matched by SANITIZE_REGEX from +result+.
#
# @param result [String]
# @return [String] a new string; +result+ is not modified
def sanitize_result(result)
  result.gsub(SANITIZE_REGEX, "")
end
# Publishes +payload+ on +channel+ for +user+ only, optionally restricted to a
# single client.
#
# @param channel [String] MessageBus channel
# @param payload [Hash] message body
# @param user [Object] must respond to id
# @param client_id [String, nil] when truthy, limits delivery to that client
# @return [Object] whatever MessageBus.publish returns
def publish_update(channel, payload, user, client_id: nil)
  # when publishing we make sure we do not keep large backlogs on the channel
  # and make sure we clear the streaming info after 60 seconds
  # this ensures we do not bloat redis
  publish_options = { user_ids: [user.id], max_backlog_age: 60 }
  publish_options[:client_ids] = [client_id] if client_id

  MessageBus.publish(channel, payload, **publish_options)
end
# Icon name shown for a helper prompt.
#
# @param name [String] one of the helper mode constants
# @return [String, nil] nil for a mode without an icon entry
def icon_map(name)
  icons = {
    TRANSLATE => "language",
    GENERATE_TITLES => "heading",
    PROOFREAD => "spell-check",
    MARKDOWN_TABLE => "table",
    CUSTOM_PROMPT => "comment",
    EXPLAIN => "question",
    ILLUSTRATE_POST => "images",
    REPLACE_DATES => "calendar-days",
  }

  icons[name]
end
# UI locations in which a helper prompt is offered.
#
# @param name [String] one of the helper mode constants
# @return [Array<String>] empty for a mode without a location entry
def location_map(name)
  locations = {
    TRANSLATE => %w[composer post],
    GENERATE_TITLES => %w[composer],
    PROOFREAD => %w[composer post],
    MARKDOWN_TABLE => %w[composer],
    CUSTOM_PROMPT => %w[composer post],
    EXPLAIN => %w[post],
    ILLUSTRATE_POST => %w[composer],
    REPLACE_DATES => %w[composer],
  }

  locations.fetch(name) { [] }
end
# Classifies how a helper response is presented.
#
# @param prompt_name [String] one of the helper mode constants
# @return [Symbol] :diff, :list, or :text for everything else
def prompt_type(prompt_name)
  case prompt_name
  when PROOFREAD, MARKDOWN_TABLE, REPLACE_DATES, CUSTOM_PROMPT
    :diff
  when ILLUSTRATE_POST, GENERATE_TITLES
    :list
  else
    :text
  end
end
# Renders an inline HTML diff between the cooked original text and the cooked
# suggestion.
#
# @param text [String] original raw text
# @param suggestion [String] suggested raw text
# @return [String] result of DiscourseDiff#inline_html
def parse_diff(text, suggestion)
  cooked_before, cooked_after = [text, suggestion].map { |raw| PrettyText.cook(raw) }

  DiscourseDiff.new(cooked_before, cooked_after).inline_html
end
# Extracts the text of every <item> element found in +list+.
#
# @param list [String] markup containing <item> elements
# @return [Array<String>] item texts in document order
def parse_list(list)
  fragment = Nokogiri::HTML5.fragment(list)
  fragment.css("item").map { |item| item.text }
end
end
end
end