diff --git a/app/controllers/discourse_ai/ai_helper/assistant_controller.rb b/app/controllers/discourse_ai/ai_helper/assistant_controller.rb index dd7db451..9f3a1baa 100644 --- a/app/controllers/discourse_ai/ai_helper/assistant_controller.rb +++ b/app/controllers/discourse_ai/ai_helper/assistant_controller.rb @@ -124,6 +124,10 @@ module DiscourseAi raise Discourse::InvalidParameters.new(:custom_prompt) if params[:custom_prompt].blank? end + # to stream we must have an appropriate client_id + # otherwise we may end up streaming the data to the wrong client + raise Discourse::InvalidParameters.new(:client_id) if params[:client_id].blank? + if location == "composer" Jobs.enqueue( :stream_composer_helper, @@ -132,6 +136,7 @@ module DiscourseAi prompt: prompt.name, custom_prompt: params[:custom_prompt], force_default_locale: params[:force_default_locale] || false, + client_id: params[:client_id], ) else post_id = get_post_param! @@ -146,6 +151,7 @@ module DiscourseAi text: text, prompt: prompt.name, custom_prompt: params[:custom_prompt], + client_id: params[:client_id], ) end diff --git a/app/jobs/regular/stream_composer_helper.rb b/app/jobs/regular/stream_composer_helper.rb index 5e8f13d6..c3066c32 100644 --- a/app/jobs/regular/stream_composer_helper.rb +++ b/app/jobs/regular/stream_composer_helper.rb @@ -8,6 +8,7 @@ module Jobs return unless args[:prompt] return unless user = User.find_by(id: args[:user_id]) return unless args[:text] + return unless args[:client_id] prompt = CompletionPrompt.enabled_by_name(args[:prompt]) @@ -21,6 +22,7 @@ module Jobs user, "/discourse-ai/ai-helper/stream_composer_suggestion", force_default_locale: args[:force_default_locale], + client_id: args[:client_id], ) end end diff --git a/assets/javascripts/discourse/components/ai-post-helper-menu.gjs b/assets/javascripts/discourse/components/ai-post-helper-menu.gjs index f181d390..52c34f82 100644 --- a/assets/javascripts/discourse/components/ai-post-helper-menu.gjs +++ b/assets/javascripts/discourse/components/ai-post-helper-menu.gjs @@ -242,6 +242,7 @@ export default class AiPostHelperMenu extends Component { text: this.args.data.selectedText, post_id: this.args.data.quoteState.postId, custom_prompt: this.customPromptValue, + client_id: this.messageBus.clientId, }, }); diff --git a/assets/javascripts/discourse/components/modal/diff-modal.gjs b/assets/javascripts/discourse/components/modal/diff-modal.gjs index 00cb4124..0fb4f262 100644 --- a/assets/javascripts/discourse/components/modal/diff-modal.gjs +++ b/assets/javascripts/discourse/components/modal/diff-modal.gjs @@ -108,6 +108,7 @@ export default class ModalDiffModal extends Component { text: this.selectedText, custom_prompt: this.args.model.customPromptValue, force_default_locale: true, + client_id: this.messageBus.clientId, }, }); } catch (e) { diff --git a/assets/javascripts/discourse/lib/diff-streamer.gjs b/assets/javascripts/discourse/lib/diff-streamer.gjs index a83095f5..f02a4be2 100644 --- a/assets/javascripts/discourse/lib/diff-streamer.gjs +++ b/assets/javascripts/discourse/lib/diff-streamer.gjs @@ -1,5 +1,5 @@ import { tracked } from "@glimmer/tracking"; -import { later } from "@ember/runloop"; +import { cancel, later } from "@ember/runloop"; import loadJSDiff from "discourse/lib/load-js-diff"; import { parseAsync } from "discourse/lib/text"; @@ -45,7 +45,7 @@ export default class DiffStreamer { this.words = []; if (this.typingTimer) { - clearTimeout(this.typingTimer); + cancel(this.typingTimer); this.typingTimer = null; } @@ -100,7 +100,7 @@ export default class DiffStreamer { this.currentCharIndex = 0; this.isStreaming = false; if (this.typingTimer) { - clearTimeout(this.typingTimer); + cancel(this.typingTimer); this.typingTimer = null; } } @@ -254,6 +254,8 @@ export default class DiffStreamer { #formatDiffWithTags(diffArray, highlightLastWord = true) { const wordsWithType = []; + const output = []; + diffArray.forEach((part) => { const tokens = part.value.match(/\S+|\s+/g) || []; tokens.forEach((token) => { @@ -277,8 +279,6 @@ export default class DiffStreamer { } } - const output = []; - for (let i = 0; i <= lastWordIndex; i++) { const { text, type } = wordsWithType[i]; diff --git a/lib/ai_helper/assistant.rb b/lib/ai_helper/assistant.rb index c9d4691e..a771b9cd 100644 --- a/lib/ai_helper/assistant.rb +++ b/lib/ai_helper/assistant.rb @@ -166,7 +166,14 @@ module DiscourseAi result end - def stream_prompt(completion_prompt, input, user, channel, force_default_locale: false) + def stream_prompt( + completion_prompt, + input, + user, + channel, + force_default_locale: false, + client_id: nil + ) streamed_diff = +"" streamed_result = +"" start = Time.now @@ -178,7 +185,6 @@ module DiscourseAi force_default_locale: force_default_locale, ) do |partial_response, cancel_function| streamed_result << partial_response - streamed_diff = parse_diff(input, partial_response) if completion_prompt.diff? # Throttle updates and check for safe stream points @@ -186,7 +192,7 @@ module DiscourseAi sanitized = sanitize_result(streamed_result) payload = { result: sanitized, diff: streamed_diff, done: false } - publish_update(channel, payload, user) + publish_update(channel, payload, user, client_id: client_id) start = Time.now end end @@ -195,7 +201,12 @@ module DiscourseAi sanitized_result = sanitize_result(streamed_result) if sanitized_result.present? - publish_update(channel, { result: sanitized_result, diff: final_diff, done: true }, user) + publish_update( + channel, + { result: sanitized_result, diff: final_diff, done: true }, + user, + client_id: client_id, + ) end end @@ -238,8 +249,21 @@ module DiscourseAi result.gsub(SANITIZE_REGEX, "") end - def publish_update(channel, payload, user) - MessageBus.publish(channel, payload, user_ids: [user.id]) + def publish_update(channel, payload, user, client_id: nil) + # when publishing we make sure we do not keep large backlogs on the channel + # and make sure we clear the streaming info after 60 seconds + # this ensures we do not bloat redis + if client_id + MessageBus.publish( + channel, + payload, + user_ids: [user.id], + client_ids: [client_id], + max_backlog_age: 60, + ) + else + MessageBus.publish(channel, payload, user_ids: [user.id], max_backlog_age: 60) + end end def icon_map(name) diff --git a/spec/jobs/regular/stream_composer_helper_spec.rb b/spec/jobs/regular/stream_composer_helper_spec.rb index 03afc2f8..7d6c623b 100644 --- a/spec/jobs/regular/stream_composer_helper_spec.rb +++ b/spec/jobs/regular/stream_composer_helper_spec.rb @@ -35,6 +35,7 @@ RSpec.describe Jobs::StreamComposerHelper do text: nil, prompt: prompt.name, force_default_locale: false, + client_id: "123", ) end @@ -58,6 +59,7 @@ RSpec.describe Jobs::StreamComposerHelper do text: input, prompt: prompt.name, force_default_locale: true, + client_id: "123", ) end @@ -78,6 +80,7 @@ RSpec.describe Jobs::StreamComposerHelper do text: input, prompt: prompt.name, force_default_locale: true, + client_id: "123", ) end diff --git a/spec/requests/ai_helper/assistant_controller_spec.rb b/spec/requests/ai_helper/assistant_controller_spec.rb index 87b29578..deae8a2a 100644 --- a/spec/requests/ai_helper/assistant_controller_spec.rb +++ b/spec/requests/ai_helper/assistant_controller_spec.rb @@ -2,6 +2,41 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do before { assign_fake_provider_to(:ai_helper_model) } + fab!(:newuser) + fab!(:user) { Fabricate(:user, refresh_auto_groups: true) } + + describe "#stream_suggestion" do + before do + Jobs.run_immediately! + SiteSetting.composer_ai_helper_allowed_groups = Group::AUTO_GROUPS[:trust_level_0] + end + + it "is able to stream suggestions back on appropriate channel" do + sign_in(user) + messages = + MessageBus.track_publish("/discourse-ai/ai-helper/stream_composer_suggestion") do + results = [["hello ", "world"]] + DiscourseAi::Completions::Llm.with_prepared_responses(results) do + post "/discourse-ai/ai-helper/stream_suggestion.json", + params: { + text: "hello wrld", + location: "composer", + client_id: "1234", + mode: CompletionPrompt::PROOFREAD, + } + + expect(response.status).to eq(200) + end + end + + last_message = messages.last + expect(messages.all? { |m| m.client_ids == ["1234"] }).to eq(true) + expect(messages.all? { |m| m == last_message || !m.data[:done] }).to eq(true) + + expect(last_message.data[:result]).to eq("hello world") + expect(last_message.data[:done]).to eq(true) + end + end describe "#suggest" do let(:text_to_proofread) { "The rain in spain stays mainly in the plane." } @@ -17,10 +52,8 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do end context "when logged in as an user without enough privileges" do - fab!(:user) { Fabricate(:newuser) } - before do - sign_in(user) + sign_in(newuser) SiteSetting.composer_ai_helper_allowed_groups = Group::AUTO_GROUPS[:staff] end @@ -32,8 +65,6 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do end context "when logged in as an allowed user" do - fab!(:user) - before do sign_in(user) user.group_ids = [Group::AUTO_GROUPS[:trust_level_1]] @@ -141,8 +172,6 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do fab!(:post_2) { Fabricate(:post, topic: topic, raw: "I love bananas") } context "when logged in as an allowed user" do - fab!(:user) - before do sign_in(user) user.group_ids = [Group::AUTO_GROUPS[:trust_level_1]] @@ -219,8 +248,6 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do end context "when logged in as an allowed user" do - fab!(:user) { Fabricate(:user, refresh_auto_groups: true) } - before do sign_in(user)