FIX: Improve MessageBus efficiency and correctly stop streaming (#1362)
* FIX: Improve MessageBus efficiency and correctly stop streaming This commit enhances the message bus implementation for AI helper streaming by: - Adding client_id targeting for message bus publications to ensure only the requesting client receives streaming updates - Limiting MessageBus backlog size (2) and age (60 seconds) to prevent Redis bloat - Replacing clearTimeout with Ember's cancel method for proper runloop management, we were leaking a stop - Adding tests for client-specific message delivery These changes improve memory usage and make streaming more reliable by ensuring messages are properly directed to the requesting client. * composer suggestion needed a fix as well. * backlog size of 2 is risky here cause same channel name is reused between clients
This commit is contained in:
parent
61ef1932fa
commit
cf220c530c
|
@ -124,6 +124,10 @@ module DiscourseAi
|
|||
raise Discourse::InvalidParameters.new(:custom_prompt) if params[:custom_prompt].blank?
|
||||
end
|
||||
|
||||
# to stream we must have an appropriate client_id
|
||||
# otherwise we may end up streaming the data to the wrong client
|
||||
raise Discourse::InvalidParameters.new(:client_id) if params[:client_id].blank?
|
||||
|
||||
if location == "composer"
|
||||
Jobs.enqueue(
|
||||
:stream_composer_helper,
|
||||
|
@ -132,6 +136,7 @@ module DiscourseAi
|
|||
prompt: prompt.name,
|
||||
custom_prompt: params[:custom_prompt],
|
||||
force_default_locale: params[:force_default_locale] || false,
|
||||
client_id: params[:client_id],
|
||||
)
|
||||
else
|
||||
post_id = get_post_param!
|
||||
|
@ -146,6 +151,7 @@ module DiscourseAi
|
|||
text: text,
|
||||
prompt: prompt.name,
|
||||
custom_prompt: params[:custom_prompt],
|
||||
client_id: params[:client_id],
|
||||
)
|
||||
end
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ module Jobs
|
|||
return unless args[:prompt]
|
||||
return unless user = User.find_by(id: args[:user_id])
|
||||
return unless args[:text]
|
||||
return unless args[:client_id]
|
||||
|
||||
prompt = CompletionPrompt.enabled_by_name(args[:prompt])
|
||||
|
||||
|
@ -21,6 +22,7 @@ module Jobs
|
|||
user,
|
||||
"/discourse-ai/ai-helper/stream_composer_suggestion",
|
||||
force_default_locale: args[:force_default_locale],
|
||||
client_id: args[:client_id],
|
||||
)
|
||||
end
|
||||
end
|
||||
|
|
|
@ -242,6 +242,7 @@ export default class AiPostHelperMenu extends Component {
|
|||
text: this.args.data.selectedText,
|
||||
post_id: this.args.data.quoteState.postId,
|
||||
custom_prompt: this.customPromptValue,
|
||||
client_id: this.messageBus.clientId,
|
||||
},
|
||||
});
|
||||
|
||||
|
|
|
@ -108,6 +108,7 @@ export default class ModalDiffModal extends Component {
|
|||
text: this.selectedText,
|
||||
custom_prompt: this.args.model.customPromptValue,
|
||||
force_default_locale: true,
|
||||
client_id: this.messageBus.clientId,
|
||||
},
|
||||
});
|
||||
} catch (e) {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import { tracked } from "@glimmer/tracking";
|
||||
import { later } from "@ember/runloop";
|
||||
import { cancel, later } from "@ember/runloop";
|
||||
import loadJSDiff from "discourse/lib/load-js-diff";
|
||||
import { parseAsync } from "discourse/lib/text";
|
||||
|
||||
|
@ -45,7 +45,7 @@ export default class DiffStreamer {
|
|||
this.words = [];
|
||||
|
||||
if (this.typingTimer) {
|
||||
clearTimeout(this.typingTimer);
|
||||
cancel(this.typingTimer);
|
||||
this.typingTimer = null;
|
||||
}
|
||||
|
||||
|
@ -100,7 +100,7 @@ export default class DiffStreamer {
|
|||
this.currentCharIndex = 0;
|
||||
this.isStreaming = false;
|
||||
if (this.typingTimer) {
|
||||
clearTimeout(this.typingTimer);
|
||||
cancel(this.typingTimer);
|
||||
this.typingTimer = null;
|
||||
}
|
||||
}
|
||||
|
@ -254,6 +254,8 @@ export default class DiffStreamer {
|
|||
|
||||
#formatDiffWithTags(diffArray, highlightLastWord = true) {
|
||||
const wordsWithType = [];
|
||||
const output = [];
|
||||
|
||||
diffArray.forEach((part) => {
|
||||
const tokens = part.value.match(/\S+|\s+/g) || [];
|
||||
tokens.forEach((token) => {
|
||||
|
@ -277,8 +279,6 @@ export default class DiffStreamer {
|
|||
}
|
||||
}
|
||||
|
||||
const output = [];
|
||||
|
||||
for (let i = 0; i <= lastWordIndex; i++) {
|
||||
const { text, type } = wordsWithType[i];
|
||||
|
||||
|
|
|
@ -166,7 +166,14 @@ module DiscourseAi
|
|||
result
|
||||
end
|
||||
|
||||
def stream_prompt(completion_prompt, input, user, channel, force_default_locale: false)
|
||||
def stream_prompt(
|
||||
completion_prompt,
|
||||
input,
|
||||
user,
|
||||
channel,
|
||||
force_default_locale: false,
|
||||
client_id: nil
|
||||
)
|
||||
streamed_diff = +""
|
||||
streamed_result = +""
|
||||
start = Time.now
|
||||
|
@ -178,7 +185,6 @@ module DiscourseAi
|
|||
force_default_locale: force_default_locale,
|
||||
) do |partial_response, cancel_function|
|
||||
streamed_result << partial_response
|
||||
|
||||
streamed_diff = parse_diff(input, partial_response) if completion_prompt.diff?
|
||||
|
||||
# Throttle updates and check for safe stream points
|
||||
|
@ -186,7 +192,7 @@ module DiscourseAi
|
|||
sanitized = sanitize_result(streamed_result)
|
||||
|
||||
payload = { result: sanitized, diff: streamed_diff, done: false }
|
||||
publish_update(channel, payload, user)
|
||||
publish_update(channel, payload, user, client_id: client_id)
|
||||
start = Time.now
|
||||
end
|
||||
end
|
||||
|
@ -195,7 +201,12 @@ module DiscourseAi
|
|||
|
||||
sanitized_result = sanitize_result(streamed_result)
|
||||
if sanitized_result.present?
|
||||
publish_update(channel, { result: sanitized_result, diff: final_diff, done: true }, user)
|
||||
publish_update(
|
||||
channel,
|
||||
{ result: sanitized_result, diff: final_diff, done: true },
|
||||
user,
|
||||
client_id: client_id,
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -238,8 +249,21 @@ module DiscourseAi
|
|||
result.gsub(SANITIZE_REGEX, "")
|
||||
end
|
||||
|
||||
def publish_update(channel, payload, user)
|
||||
MessageBus.publish(channel, payload, user_ids: [user.id])
|
||||
def publish_update(channel, payload, user, client_id: nil)
|
||||
# when publishing we make sure we do not keep large backlogs on the channel
|
||||
# and make sure we clear the streaming info after 60 seconds
|
||||
# this ensures we do not bloat redis
|
||||
if client_id
|
||||
MessageBus.publish(
|
||||
channel,
|
||||
payload,
|
||||
user_ids: [user.id],
|
||||
client_ids: [client_id],
|
||||
max_backlog_age: 60,
|
||||
)
|
||||
else
|
||||
MessageBus.publish(channel, payload, user_ids: [user.id], max_backlog_age: 60)
|
||||
end
|
||||
end
|
||||
|
||||
def icon_map(name)
|
||||
|
|
|
@ -35,6 +35,7 @@ RSpec.describe Jobs::StreamComposerHelper do
|
|||
text: nil,
|
||||
prompt: prompt.name,
|
||||
force_default_locale: false,
|
||||
client_id: "123",
|
||||
)
|
||||
end
|
||||
|
||||
|
@ -58,6 +59,7 @@ RSpec.describe Jobs::StreamComposerHelper do
|
|||
text: input,
|
||||
prompt: prompt.name,
|
||||
force_default_locale: true,
|
||||
client_id: "123",
|
||||
)
|
||||
end
|
||||
|
||||
|
@ -78,6 +80,7 @@ RSpec.describe Jobs::StreamComposerHelper do
|
|||
text: input,
|
||||
prompt: prompt.name,
|
||||
force_default_locale: true,
|
||||
client_id: "123",
|
||||
)
|
||||
end
|
||||
|
||||
|
|
|
@ -2,6 +2,41 @@
|
|||
|
||||
RSpec.describe DiscourseAi::AiHelper::AssistantController do
|
||||
before { assign_fake_provider_to(:ai_helper_model) }
|
||||
fab!(:newuser)
|
||||
fab!(:user) { Fabricate(:user, refresh_auto_groups: true) }
|
||||
|
||||
describe "#stream_suggestion" do
|
||||
before do
|
||||
Jobs.run_immediately!
|
||||
SiteSetting.composer_ai_helper_allowed_groups = Group::AUTO_GROUPS[:trust_level_0]
|
||||
end
|
||||
|
||||
it "is able to stream suggestions back on appropriate channel" do
|
||||
sign_in(user)
|
||||
messages =
|
||||
MessageBus.track_publish("/discourse-ai/ai-helper/stream_composer_suggestion") do
|
||||
results = [["hello ", "world"]]
|
||||
DiscourseAi::Completions::Llm.with_prepared_responses(results) do
|
||||
post "/discourse-ai/ai-helper/stream_suggestion.json",
|
||||
params: {
|
||||
text: "hello wrld",
|
||||
location: "composer",
|
||||
client_id: "1234",
|
||||
mode: CompletionPrompt::PROOFREAD,
|
||||
}
|
||||
|
||||
expect(response.status).to eq(200)
|
||||
end
|
||||
end
|
||||
|
||||
last_message = messages.last
|
||||
expect(messages.all? { |m| m.client_ids == ["1234"] }).to eq(true)
|
||||
expect(messages.all? { |m| m == last_message || !m.data[:done] }).to eq(true)
|
||||
|
||||
expect(last_message.data[:result]).to eq("hello world")
|
||||
expect(last_message.data[:done]).to eq(true)
|
||||
end
|
||||
end
|
||||
|
||||
describe "#suggest" do
|
||||
let(:text_to_proofread) { "The rain in spain stays mainly in the plane." }
|
||||
|
@ -17,10 +52,8 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
|
|||
end
|
||||
|
||||
context "when logged in as an user without enough privileges" do
|
||||
fab!(:user) { Fabricate(:newuser) }
|
||||
|
||||
before do
|
||||
sign_in(user)
|
||||
sign_in(newuser)
|
||||
SiteSetting.composer_ai_helper_allowed_groups = Group::AUTO_GROUPS[:staff]
|
||||
end
|
||||
|
||||
|
@ -32,8 +65,6 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
|
|||
end
|
||||
|
||||
context "when logged in as an allowed user" do
|
||||
fab!(:user)
|
||||
|
||||
before do
|
||||
sign_in(user)
|
||||
user.group_ids = [Group::AUTO_GROUPS[:trust_level_1]]
|
||||
|
@ -141,8 +172,6 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
|
|||
fab!(:post_2) { Fabricate(:post, topic: topic, raw: "I love bananas") }
|
||||
|
||||
context "when logged in as an allowed user" do
|
||||
fab!(:user)
|
||||
|
||||
before do
|
||||
sign_in(user)
|
||||
user.group_ids = [Group::AUTO_GROUPS[:trust_level_1]]
|
||||
|
@ -219,8 +248,6 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
|
|||
end
|
||||
|
||||
context "when logged in as an allowed user" do
|
||||
fab!(:user) { Fabricate(:user, refresh_auto_groups: true) }
|
||||
|
||||
before do
|
||||
sign_in(user)
|
||||
|
||||
|
|
Loading…
Reference in New Issue