FEATURE: improve context management (#1260)
1. Add the age of each post to the topic context (1 month ago, 1 year ago, etc.). 2. Refactor code for simplicity. 3. Fix handling of post context in DMs, which was not using the new handling of uploads.
This commit is contained in:
parent
67e3a610cb
commit
38b492529f
|
@ -6,7 +6,6 @@ module DiscourseAi
|
|||
MAX_CHAT_UPLOADS = 5
|
||||
MAX_TOPIC_UPLOADS = 5
|
||||
attr_reader :chat_context_posts
|
||||
attr_reader :chat_context_post_upload_ids
|
||||
attr_accessor :topic
|
||||
|
||||
def self.messages_from_chat(
|
||||
|
@ -113,12 +112,13 @@ module DiscourseAi
|
|||
FROM upload_references ref
|
||||
WHERE ref.target_type = 'Post' AND ref.target_id = posts.id
|
||||
) as upload_ids",
|
||||
"posts.created_at",
|
||||
)
|
||||
|
||||
builder = new
|
||||
builder.topic = post.topic
|
||||
|
||||
context.reverse_each do |raw, username, custom_prompt, upload_ids|
|
||||
context.reverse_each do |raw, username, custom_prompt, upload_ids, created_at|
|
||||
custom_prompt_translation =
|
||||
Proc.new do |message|
|
||||
# We can't keep backwards-compatibility for stored functions.
|
||||
|
@ -134,6 +134,7 @@ module DiscourseAi
|
|||
|
||||
thinking = message[4]
|
||||
custom_context[:thinking] = thinking if thinking
|
||||
custom_context[:created_at] = created_at
|
||||
|
||||
builder.push(**custom_context)
|
||||
end
|
||||
|
@ -149,6 +150,7 @@ module DiscourseAi
|
|||
if upload_ids.present? && context[:type] == :user && include_uploads
|
||||
context[:upload_ids] = upload_ids.compact
|
||||
end
|
||||
context[:created_at] = created_at
|
||||
|
||||
builder.push(**context)
|
||||
end
|
||||
|
@ -159,6 +161,7 @@ module DiscourseAi
|
|||
|
||||
# Sets up an empty builder.
#
# @raw_messages collects the message hashes in the order they are pushed.
# @timestamps maps a pushed message hash to its created_at value. It compares
# keys by identity: message hashes are mutable and compare by value, so a
# regular Hash would make two identical messages share a single entry and
# would lose an entry once its message is modified after being stored.
def initialize
  @raw_messages = []
  @timestamps = {}.compare_by_identity
end
|
||||
|
||||
def set_chat_context_posts(post_ids, guardian, include_uploads:)
|
||||
|
@ -171,27 +174,66 @@ module DiscourseAi
|
|||
posts << post
|
||||
end
|
||||
if posts.present?
|
||||
posts_context =
|
||||
+"\nThis chat is in the context of the Discourse topic '#{posts[0].topic.title}':\n\n"
|
||||
posts_context = +"{{{\n"
|
||||
posts_context = []
|
||||
posts_context << "\nThis chat is in the context of the Discourse topic '#{posts[0].topic.title}':\n\n"
|
||||
posts_context << "{{{\n"
|
||||
posts.each do |post|
|
||||
posts_context << "url: #{post.url}\n"
|
||||
posts_context << "#{post.username}: #{post.raw}\n\n"
|
||||
if include_uploads
|
||||
post.uploads.each { |upload| posts_context << { upload_id: upload.id } }
|
||||
end
|
||||
end
|
||||
posts_context << "}}}"
|
||||
@chat_context_posts = posts_context
|
||||
if include_uploads
|
||||
uploads = []
|
||||
posts.each { |post| uploads.concat(post.uploads.pluck(:id)) }
|
||||
uploads.uniq!
|
||||
@chat_context_post_upload_ids = uploads.take(MAX_CHAT_UPLOADS)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# Renders the collected messages in the requested style.
#
# The :chat and :topic styles are special: each produces a single message
# that contains the whole history. Every other style may yield several
# messages, and :chat_with_context additionally prepends the chat post
# context to the first message.
#
# @param limit [Integer, nil] upper bound used to trim the result
# @param style [Symbol, nil] :chat, :topic, :chat_with_context or nil
# @return [Array] the messages for the requested style
def to_a(limit: nil, style: nil)
  case style
  when :chat
    return chat_array(limit: limit)
  when :topic
    return topic_array
  end

  messages = valid_messages_array(@raw_messages)
  prepend_chat_post_context(messages) if style == :chat_with_context

  # NOTE(review): the range is inclusive, so up to limit + 1 messages are
  # returned — confirm this is intended.
  limit ? messages[0..limit] : messages
end
|
||||
|
||||
# Appends one message to the builder and returns the stored message hash.
#
# @param type [Symbol] one of :user, :model, :tool, :tool_call, :system
# @param content [Object] message content; when upload_ids is given it is
#   wrapped in an array followed by one { upload_id: } hash per upload
# @param name [#to_s, nil] stored as a string under :name when given
# @param upload_ids [Array, nil] ids of uploads attached to the message
# @param id [#to_s, nil] stored as a string under :id when given
# @param thinking [Hash, nil] string-keyed hash; the "thinking",
#   "thinking_signature" and "redacted_thinking_signature" entries are copied
#   onto the message under the matching symbol keys when present
# @param created_at [Object, nil] recorded in @timestamps for this message
# @raise [ArgumentError] for an unknown type or a non-array upload_ids
# @return [Hash] the message that was stored
def push(type:, content:, name: nil, upload_ids: nil, id: nil, thinking: nil, created_at: nil)
  unless %i[user model tool tool_call system].include?(type)
    raise ArgumentError, "type must be either :user, :model, :tool, :tool_call or :system"
  end

  if upload_ids && !upload_ids.is_a?(Array)
    raise ArgumentError, "upload_ids must be an array"
  end

  if upload_ids
    attachments = upload_ids.map { |upload_id| { upload_id: upload_id } }
    content = [content] + attachments
  end

  entry = { type: type, content: content }
  entry[:name] = name.to_s if name
  entry[:id] = id.to_s if id

  if thinking
    %w[thinking thinking_signature redacted_thinking_signature].each do |key|
      value = thinking[key]
      entry[key.to_sym] = value if value
    end
  end

  @raw_messages << entry
  @timestamps[entry] = created_at if created_at

  entry
end
|
||||
|
||||
private
|
||||
|
||||
def valid_messages_array(messages)
|
||||
result = []
|
||||
|
||||
# this will create a "valid" messages array
|
||||
|
@ -199,7 +241,7 @@ module DiscourseAi
|
|||
# 2. ensures we always end with a user message
|
||||
# 3. ensures we always interleave user and model messages
|
||||
last_type = nil
|
||||
@raw_messages.each do |message|
|
||||
messages.each do |message|
|
||||
next if !last_type && message[:type] != :user
|
||||
|
||||
if last_type == :tool_call && message[:type] != :tool
|
||||
|
@ -239,52 +281,27 @@ module DiscourseAi
|
|||
last_type = message[:type]
|
||||
end
|
||||
|
||||
if style == :chat_with_context && @chat_context_posts
|
||||
buffer = +"You are replying inside a Discourse chat."
|
||||
buffer << "\n"
|
||||
buffer << @chat_context_posts
|
||||
buffer << "\n"
|
||||
buffer << "Your instructions are:\n"
|
||||
result[0][:content] = "#{buffer}#{result[0][:content]}"
|
||||
if @chat_context_post_upload_ids.present?
|
||||
result[0][:upload_ids] = (result[0][:upload_ids] || []).concat(
|
||||
@chat_context_post_upload_ids,
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
if limit
|
||||
result[0..limit]
|
||||
else
|
||||
result
|
||||
end
|
||||
result
|
||||
end
|
||||
|
||||
def push(type:, content:, name: nil, upload_ids: nil, id: nil, thinking: nil)
|
||||
if !%i[user model tool tool_call system].include?(type)
|
||||
raise ArgumentError, "type must be either :user, :model, :tool, :tool_call or :system"
|
||||
end
|
||||
raise ArgumentError, "upload_ids must be an array" if upload_ids && !upload_ids.is_a?(Array)
|
||||
def prepend_chat_post_context(messages)
|
||||
return if @chat_context_posts.blank?
|
||||
|
||||
content = [content, *upload_ids.map { |upload_id| { upload_id: upload_id } }] if upload_ids
|
||||
message = { type: type, content: content }
|
||||
message[:name] = name.to_s if name
|
||||
message[:id] = id.to_s if id
|
||||
if thinking
|
||||
message[:thinking] = thinking["thinking"] if thinking["thinking"]
|
||||
message[:thinking_signature] = thinking["thinking_signature"] if thinking[
|
||||
"thinking_signature"
|
||||
]
|
||||
message[:redacted_thinking_signature] = thinking[
|
||||
"redacted_thinking_signature"
|
||||
] if thinking["redacted_thinking_signature"]
|
||||
end
|
||||
old_content = messages[0][:content]
|
||||
old_content = [old_content] if !old_content.is_a?(Array)
|
||||
|
||||
@raw_messages << message
|
||||
new_content = []
|
||||
new_content << "You are replying inside a Discourse chat.\n"
|
||||
new_content.concat(@chat_context_posts)
|
||||
new_content << "\n"
|
||||
new_content << "Your instructions are:\n"
|
||||
new_content.concat(old_content)
|
||||
|
||||
compressed = compress_messages_buffer(new_content.flatten, max_uploads: MAX_CHAT_UPLOADS)
|
||||
|
||||
messages[0][:content] = compressed
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def format_user_info(user)
|
||||
info = []
|
||||
info << user_role(user)
|
||||
|
@ -294,6 +311,34 @@ module DiscourseAi
|
|||
"#{user.username} (#{user.name}): #{info.compact.join(", ")}"
|
||||
end
|
||||
|
||||
# Describes how long ago +timestamp+ was as a short English phrase such as
# "just now", "5 minutes ago", "1 month ago" or "2 years ago".
#
# Counts are rounded to the nearest whole unit. When the rounded count would
# reach the size of the next unit (for example 59.6 minutes rounding to 60),
# the next unit is used instead, so the result never reads "60 minutes ago",
# "24 hours ago", "7 days ago" or "12 months ago". Months are approximated as
# 30 days and years as 365 days.
#
# @param timestamp [Time, nil] the moment being described
# @param now [Time] reference point; defaults to the current time
# @return [String, nil] the phrase, or nil when no timestamp is given
def format_timestamp(timestamp, now: Time.now)
  return nil unless timestamp

  elapsed = (now - timestamp).to_f
  return "just now" if elapsed < 60

  # Each row: unit name, seconds per unit, elapsed seconds from which the unit
  # no longer applies, and the rounded count that belongs to the next unit.
  units = [
    ["minute", 60, 3_600, 60],
    ["hour", 3_600, 86_400, 24],
    ["day", 86_400, 604_800, 7],
    ["week", 604_800, 2_592_000, 5],
    ["month", 2_592_000, 31_536_000, 12],
  ]

  units.each do |name, seconds, upper_bound, rollover|
    next if elapsed >= upper_bound

    count = (elapsed / seconds).round
    # Rounding pushed the count into the next unit; let that unit report it.
    next if count >= rollover

    return "#{count} #{count == 1 ? name : "#{name}s"} ago"
  end

  years = (elapsed / 31_536_000).round
  "#{years} #{years == 1 ? "year" : "years"} ago"
end
|
||||
|
||||
def user_role(user)
|
||||
return "moderator" if user.moderator?
|
||||
return "admin" if user.admin?
|
||||
|
@ -323,45 +368,57 @@ module DiscourseAi
|
|||
end
|
||||
end
|
||||
|
||||
# Builds the text block that describes a topic: a heading (different for
# private messages), URL, title, visible tags when tagging is enabled, the
# category for non-private topics, and the number of replies.
#
# @param topic [Object] topic responding to private_message?, url, title,
#   tags, category and posts_count
# @return [String] the formatted description, ending in a blank line
def format_topic_info(topic)
  is_pm = topic.private_message?

  lines = [is_pm ? "Private message info.\n" : "Topic information:\n"]
  lines << "- URL: #{topic.url}\n"
  lines << "- Title: #{topic.title}\n"

  if SiteSetting.tagging_enabled
    tag_names = topic.tags.pluck(:name)
    # Hidden tags are removed before the tags are listed.
    tag_names -= DiscourseTagging.hidden_tag_names if tag_names.present?
    lines << "- Tags: #{tag_names.join(", ")}\n" if tag_names.present?
  end

  lines << "- Category: #{topic.category.name}\n" if !is_pm && topic.category
  # The first post is not counted as a reply.
  lines << "- Number of replies: #{topic.posts_count - 1}\n\n"

  lines.join
end
|
||||
|
||||
# Builds the "User information" text block for the given usernames.
#
# Users are looked up by username with their user_stat included, and each one
# is rendered through format_user_info.
#
# @param usernames [Array<String>, nil] usernames to describe
# @return [String] the formatted block, or an empty string when no usernames
#   are given
def format_user_infos(usernames)
  return +"" unless usernames.present?

  details =
    User
      .where(username: usernames)
      .includes(:user_stat)
      .map { |user| format_user_info(user) }
      .compact

  output = +"User information:\n"
  output << "- #{details.join("\n- ")}\n\n" if details.present?
  output
end
|
||||
|
||||
def topic_array
|
||||
raw_messages = @raw_messages.dup
|
||||
content_array = []
|
||||
content_array << "You are operating in a Discourse forum.\n\n"
|
||||
|
||||
if @topic
|
||||
if @topic.private_message?
|
||||
content_array << "Private message info.\n"
|
||||
else
|
||||
content_array << "Topic information:\n"
|
||||
end
|
||||
|
||||
content_array << "- URL: #{@topic.url}\n"
|
||||
content_array << "- Title: #{@topic.title}\n"
|
||||
if SiteSetting.tagging_enabled
|
||||
tags = @topic.tags.pluck(:name)
|
||||
tags -= DiscourseTagging.hidden_tag_names if tags.present?
|
||||
content_array << "- Tags: #{tags.join(", ")}\n" if tags.present?
|
||||
end
|
||||
if !@topic.private_message?
|
||||
content_array << "- Category: #{@topic.category.name}\n" if @topic.category
|
||||
end
|
||||
content_array << "- Number of replies: #{@topic.posts_count - 1}\n\n"
|
||||
end
|
||||
content_array << format_topic_info(@topic) if @topic
|
||||
|
||||
if raw_messages.present?
|
||||
usernames =
|
||||
raw_messages.filter { |message| message[:type] == :user }.map { |message| message[:id] }
|
||||
|
||||
if usernames.present?
|
||||
users_details =
|
||||
User
|
||||
.where(username: usernames)
|
||||
.includes(:user_stat)
|
||||
.map { |user| format_user_info(user) }
|
||||
.compact
|
||||
content_array << "User information:\n"
|
||||
content_array << "- #{users_details.join("\n- ")}\n\n" if users_details.present?
|
||||
end
|
||||
content_array << format_user_infos(usernames) if usernames.present?
|
||||
end
|
||||
|
||||
last_user_message = raw_messages.pop
|
||||
|
@ -370,6 +427,8 @@ module DiscourseAi
|
|||
content_array << "Here is the conversation so far:\n"
|
||||
raw_messages.each do |message|
|
||||
content_array << "#{message[:id] || "User"}: "
|
||||
timestamp = @timestamps[message]
|
||||
content_array << "(#{format_timestamp(timestamp)}) " if timestamp
|
||||
content_array << message[:content]
|
||||
content_array << "\n\n"
|
||||
end
|
||||
|
|
|
@ -95,6 +95,72 @@ describe DiscourseAi::Completions::PromptMessagesBuilder do
|
|||
expect(content).to include("How do I solve this")
|
||||
end
|
||||
|
||||
describe "chat context posts in direct messages" do
  fab!(:dm_channel) { Fabricate(:direct_message_channel, users: [user, bot_user]) }
  fab!(:dm_message) do
    Fabricate(
      :chat_message,
      chat_channel: dm_channel,
      user: user,
      message: "I have a question about the topic",
    )
  end

  fab!(:topic) { Fabricate(:topic, title: "Important topic for context") }
  fab!(:post1) { Fabricate(:post, topic: topic, user: other_user, raw: "This is the first post") }
  fab!(:post2) { Fabricate(:post, topic: topic, user: user, raw: "And here's a follow-up") }

  # Builds the prompt messages for dm_message using the given posts as context.
  def dm_context(post_ids, include_uploads:)
    described_class.messages_from_chat(
      dm_message,
      channel: dm_channel,
      context_post_ids: post_ids,
      max_messages: 10,
      include_uploads: include_uploads,
      bot_user_ids: [bot_user.id],
      instruction_message: nil,
    )
  end

  it "correctly includes topic posts as context in direct message channels" do
    context = dm_context([post1.id, post2.id], include_uploads: false)

    expect(context.length).to eq(1)
    content = context.first[:content]

    # The context intro and both posts come first, the user's message last.
    [
      "You are replying inside a Discourse chat",
      "This chat is in the context of the Discourse topic 'Important topic for context'",
      post1.username,
      "This is the first post",
      post2.username,
      "And here's a follow-up",
      "I have a question about the topic",
    ].each { |fragment| expect(content).to include(fragment) }
  end

  it "includes uploads from context posts when include_uploads is true" do
    upload = Fabricate(:upload, user: user)
    UploadReference.create!(target: post1, upload: upload)

    context = dm_context([post1.id], include_uploads: true)

    # The upload attached to the context post must be referenced in the content.
    upload_ids = context.first[:content].filter_map { |item| item[:upload_id] if item.is_a?(Hash) }
    expect(upload_ids).to be_present
    expect(upload_ids.first).to eq(upload.id)
  end
end
|
||||
|
||||
describe ".messages_from_chat" do
|
||||
fab!(:dm_channel) { Fabricate(:direct_message_channel, users: [user, bot_user]) }
|
||||
fab!(:dm_message1) do
|
||||
|
@ -366,7 +432,9 @@ describe DiscourseAi::Completions::PromptMessagesBuilder do
|
|||
# will be brittle, but open to changing this
|
||||
end
|
||||
|
||||
it "handles uploads correctly in topic style messages" do
|
||||
it "handles uploads correctly in topic style messages (and times)" do
|
||||
freeze_time 1.month.ago
|
||||
|
||||
# Use Discourse's upload format in the post raw content
|
||||
upload_markdown = ""
|
||||
|
||||
|
@ -382,6 +450,8 @@ describe DiscourseAi::Completions::PromptMessagesBuilder do
|
|||
|
||||
upload2_markdown = ""
|
||||
|
||||
freeze_time 1.month.from_now
|
||||
|
||||
post2_with_upload =
|
||||
Fabricate(
|
||||
:post,
|
||||
|
@ -415,6 +485,7 @@ describe DiscourseAi::Completions::PromptMessagesBuilder do
|
|||
# second image
|
||||
expect(content.length).to eq(4)
|
||||
expect(content[0]).to include("This is the original")
|
||||
expect(content[0]).to include("(1 month ago)")
|
||||
expect(content[1]).to eq({ upload_id: image_upload1.id })
|
||||
expect(content[2]).to include("different image")
|
||||
expect(content[3]).to eq({ upload_id: image_upload2.id })
|
||||
|
|
Loading…
Reference in New Issue