448 lines
14 KiB
Ruby
448 lines
14 KiB
Ruby
# frozen_string_literal: true
|
|
|
|
module DiscourseAi
|
|
module AiHelper
|
|
class Assistant
|
|
IMAGE_CAPTION_MAX_WORDS = 50
|
|
|
|
TRANSLATE = "translate"
|
|
GENERATE_TITLES = "generate_titles"
|
|
PROOFREAD = "proofread"
|
|
MARKDOWN_TABLE = "markdown_table"
|
|
CUSTOM_PROMPT = "custom_prompt"
|
|
EXPLAIN = "explain"
|
|
ILLUSTRATE_POST = "illustrate_post"
|
|
REPLACE_DATES = "replace_dates"
|
|
IMAGE_CAPTION = "image_caption"
|
|
|
|
def self.prompt_cache
|
|
@prompt_cache ||= ::DiscourseAi::MultisiteHash.new("prompt_cache")
|
|
end
|
|
|
|
def self.clear_prompt_cache!
|
|
prompt_cache.flush!
|
|
end
|
|
|
|
def initialize(helper_llm: nil, image_caption_llm: nil)
|
|
@helper_llm = helper_llm
|
|
@image_caption_llm = image_caption_llm
|
|
end
|
|
|
|
def available_prompts(user)
|
|
key = "prompt_cache_#{I18n.locale}"
|
|
prompts = self.class.prompt_cache.fetch(key) { self.all_prompts }
|
|
|
|
prompts
|
|
.map do |prompt|
|
|
next if !user.in_any_groups?(prompt[:allowed_group_ids])
|
|
|
|
if prompt[:name] == ILLUSTRATE_POST &&
|
|
SiteSetting.ai_helper_illustrate_post_model == "disabled"
|
|
next
|
|
end
|
|
|
|
# We cannot cache this. It depends on the user's effective_locale.
|
|
if prompt[:name] == TRANSLATE
|
|
locale = user.effective_locale
|
|
locale_hash =
|
|
LocaleSiteSetting.language_names[locale] ||
|
|
LocaleSiteSetting.language_names[locale.split("_")[0]]
|
|
translation =
|
|
I18n.t(
|
|
"discourse_ai.ai_helper.prompts.translate",
|
|
language: locale_hash["nativeName"],
|
|
) || prompt[:name]
|
|
|
|
prompt.merge(translated_name: translation)
|
|
else
|
|
prompt
|
|
end
|
|
end
|
|
.compact
|
|
end
|
|
|
|
def custom_locale_instructions(user = nil, force_default_locale)
|
|
locale = SiteSetting.default_locale
|
|
locale = user.effective_locale if !force_default_locale && user
|
|
locale_hash = LocaleSiteSetting.language_names[locale]
|
|
|
|
if locale != "en" && locale_hash
|
|
locale_description = "#{locale_hash["name"]} (#{locale_hash["nativeName"]})"
|
|
"It is imperative that you write your answer in #{locale_description}, you are interacting with a #{locale_description} speaking user. Leave tag names in English."
|
|
else
|
|
nil
|
|
end
|
|
end
|
|
|
|
def attach_user_context(context, user = nil, force_default_locale: false)
|
|
locale = SiteSetting.default_locale
|
|
locale = user.effective_locale if user && !force_default_locale
|
|
locale_hash = LocaleSiteSetting.language_names[locale]
|
|
|
|
context.user_language = "#{locale_hash["name"]}"
|
|
|
|
if user
|
|
timezone = user.user_option.timezone || "UTC"
|
|
current_time = Time.now.in_time_zone(timezone)
|
|
|
|
temporal_context = {
|
|
utc_date_time: current_time.iso8601,
|
|
local_time: current_time.strftime("%H:%M"),
|
|
user: {
|
|
timezone: timezone,
|
|
weekday: current_time.strftime("%A"),
|
|
},
|
|
}
|
|
|
|
context.temporal_context = temporal_context.to_json
|
|
end
|
|
|
|
context
|
|
end
|
|
|
|
def generate_prompt(
|
|
helper_mode,
|
|
input,
|
|
user,
|
|
force_default_locale: false,
|
|
custom_prompt: nil,
|
|
&block
|
|
)
|
|
bot = build_bot(helper_mode, user)
|
|
|
|
user_input = "<input>#{input}</input>"
|
|
if helper_mode == CUSTOM_PROMPT && custom_prompt.present?
|
|
user_input = "<input>#{custom_prompt}:\n#{input}</input>"
|
|
end
|
|
|
|
context =
|
|
DiscourseAi::Personas::BotContext.new(
|
|
user: user,
|
|
skip_tool_details: true,
|
|
feature_name: "ai_helper",
|
|
messages: [{ type: :user, content: user_input }],
|
|
format_dates: helper_mode == REPLACE_DATES,
|
|
custom_instructions: custom_locale_instructions(user, force_default_locale),
|
|
)
|
|
context = attach_user_context(context, user, force_default_locale: force_default_locale)
|
|
|
|
helper_response = +""
|
|
|
|
buffer_blk =
|
|
Proc.new do |partial, _, type|
|
|
if type == :structured_output
|
|
json_summary_schema_key = bot.persona.response_format&.first.to_h
|
|
helper_chunk = partial.read_buffered_property(json_summary_schema_key["key"]&.to_sym)
|
|
|
|
if helper_chunk.present?
|
|
helper_response << helper_chunk
|
|
block.call(helper_chunk) if block
|
|
end
|
|
elsif type.blank?
|
|
# Assume response is a regular completion.
|
|
helper_response << helper_chunk
|
|
block.call(helper_chunk) if block
|
|
end
|
|
end
|
|
|
|
bot.reply(context, &buffer_blk)
|
|
|
|
helper_response
|
|
end
|
|
|
|
def generate_and_send_prompt(
|
|
helper_mode,
|
|
input,
|
|
user,
|
|
force_default_locale: false,
|
|
custom_prompt: nil
|
|
)
|
|
helper_response =
|
|
generate_prompt(
|
|
helper_mode,
|
|
input,
|
|
user,
|
|
force_default_locale: force_default_locale,
|
|
custom_prompt: custom_prompt,
|
|
)
|
|
result = { type: prompt_type(helper_mode) }
|
|
|
|
result[:suggestions] = (
|
|
if result[:type] == :list
|
|
parse_list(helper_response).map { |suggestion| sanitize_result(suggestion) }
|
|
else
|
|
sanitized = sanitize_result(helper_response)
|
|
result[:diff] = parse_diff(input, sanitized) if result[:type] == :diff
|
|
[sanitized]
|
|
end
|
|
)
|
|
|
|
result
|
|
end
|
|
|
|
def stream_prompt(
|
|
helper_mode,
|
|
input,
|
|
user,
|
|
channel,
|
|
force_default_locale: false,
|
|
client_id: nil,
|
|
custom_prompt: nil
|
|
)
|
|
streamed_diff = +""
|
|
streamed_result = +""
|
|
start = Time.now
|
|
type = prompt_type(helper_mode)
|
|
|
|
generate_prompt(
|
|
helper_mode,
|
|
input,
|
|
user,
|
|
force_default_locale: force_default_locale,
|
|
custom_prompt: custom_prompt,
|
|
) do |partial_response|
|
|
streamed_result << partial_response
|
|
streamed_diff = parse_diff(input, partial_response) if type == :diff
|
|
|
|
# Throttle updates and check for safe stream points
|
|
if (streamed_result.length > 10 && (Time.now - start > 0.3)) || Rails.env.test?
|
|
sanitized = sanitize_result(streamed_result)
|
|
|
|
payload = { result: sanitized, diff: streamed_diff, done: false }
|
|
publish_update(channel, payload, user, client_id: client_id)
|
|
start = Time.now
|
|
end
|
|
end
|
|
|
|
final_diff = parse_diff(input, streamed_result) if type == :diff
|
|
|
|
sanitized_result = sanitize_result(streamed_result)
|
|
if sanitized_result.present?
|
|
publish_update(
|
|
channel,
|
|
{ result: sanitized_result, diff: final_diff, done: true },
|
|
user,
|
|
client_id: client_id,
|
|
)
|
|
end
|
|
end
|
|
|
|
def generate_image_caption(upload, user)
|
|
bot = build_bot(IMAGE_CAPTION, user)
|
|
force_default_locale = false
|
|
|
|
context =
|
|
DiscourseAi::Personas::BotContext.new(
|
|
user: user,
|
|
skip_tool_details: true,
|
|
feature_name: IMAGE_CAPTION,
|
|
messages: [
|
|
{
|
|
type: :user,
|
|
content: ["Describe this image in a single sentence.", { upload_id: upload.id }],
|
|
},
|
|
],
|
|
custom_instructions: custom_locale_instructions(user, force_default_locale),
|
|
)
|
|
|
|
structured_output = nil
|
|
|
|
buffer_blk =
|
|
Proc.new do |partial, _, type|
|
|
if type == :structured_output
|
|
structured_output = partial
|
|
json_summary_schema_key = bot.persona.response_format&.first.to_h
|
|
end
|
|
end
|
|
|
|
bot.reply(context, llm_args: { max_tokens: 1024 }, &buffer_blk)
|
|
|
|
raw_caption = ""
|
|
|
|
if structured_output
|
|
json_summary_schema_key = bot.persona.response_format&.first.to_h
|
|
raw_caption =
|
|
structured_output.read_buffered_property(json_summary_schema_key["key"]&.to_sym)
|
|
end
|
|
|
|
raw_caption.delete("|").squish.truncate_words(IMAGE_CAPTION_MAX_WORDS)
|
|
end
|
|
|
|
private
|
|
|
|
def build_bot(helper_mode, user)
|
|
persona_id = personas_prompt_map(include_image_caption: true).invert[helper_mode]
|
|
raise Discourse::InvalidParameters.new(:mode) if persona_id.blank?
|
|
|
|
persona_klass = AiPersona.find_by(id: persona_id)&.class_instance
|
|
return if persona_klass.nil?
|
|
|
|
llm_model = find_ai_helper_model(helper_mode, persona_klass)
|
|
|
|
DiscourseAi::Personas::Bot.as(user, persona: persona_klass.new, model: llm_model)
|
|
end
|
|
|
|
# Priorities are:
|
|
# 1. Persona's default LLM
|
|
# 2. Hidden `ai_helper_model` setting, or `ai_helper_image_caption_model` for image_caption.
|
|
# 3. Newest LLM config
|
|
def find_ai_helper_model(helper_mode, persona_klass)
|
|
model_id = persona_klass.default_llm_id
|
|
|
|
if !model_id
|
|
if helper_mode == IMAGE_CAPTION
|
|
model_id = @helper_llm || SiteSetting.ai_helper_image_caption_model&.split(":")&.last
|
|
else
|
|
model_id = @image_caption_llm || SiteSetting.ai_helper_model&.split(":")&.last
|
|
end
|
|
end
|
|
|
|
if model_id.present?
|
|
LlmModel.find_by(id: model_id)
|
|
else
|
|
LlmModel.last
|
|
end
|
|
end
|
|
|
|
def personas_prompt_map(include_image_caption: false)
|
|
map = {
|
|
SiteSetting.ai_helper_translator_persona.to_i => TRANSLATE,
|
|
SiteSetting.ai_helper_tittle_suggestions_persona.to_i => GENERATE_TITLES,
|
|
SiteSetting.ai_helper_proofreader_persona.to_i => PROOFREAD,
|
|
SiteSetting.ai_helper_markdown_tables_persona.to_i => MARKDOWN_TABLE,
|
|
SiteSetting.ai_helper_custom_prompt_persona.to_i => CUSTOM_PROMPT,
|
|
SiteSetting.ai_helper_explain_persona.to_i => EXPLAIN,
|
|
SiteSetting.ai_helper_post_illustrator_persona.to_i => ILLUSTRATE_POST,
|
|
SiteSetting.ai_helper_smart_dates_persona.to_i => REPLACE_DATES,
|
|
}
|
|
|
|
if include_image_caption
|
|
image_caption_persona = SiteSetting.ai_helper_image_caption_persona.to_i
|
|
map[image_caption_persona] = IMAGE_CAPTION if image_caption_persona
|
|
end
|
|
|
|
map
|
|
end
|
|
|
|
def all_prompts
|
|
personas_and_prompts = personas_prompt_map
|
|
|
|
AiPersona
|
|
.where(id: personas_prompt_map.keys)
|
|
.map do |ai_persona|
|
|
prompt_name = personas_prompt_map[ai_persona.id]
|
|
|
|
if prompt_name
|
|
{
|
|
name: prompt_name,
|
|
translated_name:
|
|
I18n.t("discourse_ai.ai_helper.prompts.#{prompt_name}", default: nil) ||
|
|
prompt_name,
|
|
prompt_type: prompt_type(prompt_name),
|
|
icon: icon_map(prompt_name),
|
|
location: location_map(prompt_name),
|
|
allowed_group_ids: ai_persona.allowed_group_ids,
|
|
}
|
|
end
|
|
end
|
|
.compact
|
|
end
|
|
|
|
SANITIZE_REGEX_STR =
|
|
%w[term context topic replyTo input output result]
|
|
.map { |tag| "<#{tag}>\\n?|\\n?</#{tag}>" }
|
|
.join("|")
|
|
|
|
SANITIZE_REGEX = Regexp.new(SANITIZE_REGEX_STR, Regexp::IGNORECASE | Regexp::MULTILINE)
|
|
|
|
def sanitize_result(result)
|
|
result.gsub(SANITIZE_REGEX, "")
|
|
end
|
|
|
|
def publish_update(channel, payload, user, client_id: nil)
|
|
# when publishing we make sure we do not keep large backlogs on the channel
|
|
# and make sure we clear the streaming info after 60 seconds
|
|
# this ensures we do not bloat redis
|
|
if client_id
|
|
MessageBus.publish(
|
|
channel,
|
|
payload,
|
|
user_ids: [user.id],
|
|
client_ids: [client_id],
|
|
max_backlog_age: 60,
|
|
)
|
|
else
|
|
MessageBus.publish(channel, payload, user_ids: [user.id], max_backlog_age: 60)
|
|
end
|
|
end
|
|
|
|
def icon_map(name)
|
|
case name
|
|
when TRANSLATE
|
|
"language"
|
|
when GENERATE_TITLES
|
|
"heading"
|
|
when PROOFREAD
|
|
"spell-check"
|
|
when MARKDOWN_TABLE
|
|
"table"
|
|
when CUSTOM_PROMPT
|
|
"comment"
|
|
when EXPLAIN
|
|
"question"
|
|
when ILLUSTRATE_POST
|
|
"images"
|
|
when REPLACE_DATES
|
|
"calendar-days"
|
|
else
|
|
nil
|
|
end
|
|
end
|
|
|
|
def location_map(name)
|
|
case name
|
|
when TRANSLATE
|
|
%w[composer post]
|
|
when GENERATE_TITLES
|
|
%w[composer]
|
|
when PROOFREAD
|
|
%w[composer post]
|
|
when MARKDOWN_TABLE
|
|
%w[composer]
|
|
when CUSTOM_PROMPT
|
|
%w[composer post]
|
|
when EXPLAIN
|
|
%w[post]
|
|
when ILLUSTRATE_POST
|
|
%w[composer]
|
|
when REPLACE_DATES
|
|
%w[composer]
|
|
else
|
|
%w[]
|
|
end
|
|
end
|
|
|
|
def prompt_type(prompt_name)
|
|
if [PROOFREAD, MARKDOWN_TABLE, REPLACE_DATES, CUSTOM_PROMPT].include?(prompt_name)
|
|
return :diff
|
|
end
|
|
|
|
return :list if [ILLUSTRATE_POST, GENERATE_TITLES].include?(prompt_name)
|
|
|
|
:text
|
|
end
|
|
|
|
def parse_diff(text, suggestion)
|
|
cooked_text = PrettyText.cook(text)
|
|
cooked_suggestion = PrettyText.cook(suggestion)
|
|
|
|
DiscourseDiff.new(cooked_text, cooked_suggestion).inline_html
|
|
end
|
|
|
|
def parse_list(list)
|
|
Nokogiri::HTML5.fragment(list).css("item").map(&:text)
|
|
end
|
|
end
|
|
end
|
|
end
|