diff --git a/app/controllers/discourse_ai/ai_helper/assistant_controller.rb b/app/controllers/discourse_ai/ai_helper/assistant_controller.rb
index d2ee8738..dd7db451 100644
--- a/app/controllers/discourse_ai/ai_helper/assistant_controller.rb
+++ b/app/controllers/discourse_ai/ai_helper/assistant_controller.rb
@@ -43,7 +43,7 @@ module DiscourseAi
prompt,
input,
current_user,
- force_default_locale,
+ force_default_locale: force_default_locale,
),
status: 200
end
@@ -110,26 +110,44 @@ module DiscourseAi
end
def stream_suggestion
- post_id = get_post_param!
text = get_text_param!
- post = Post.includes(:topic).find_by(id: post_id)
+
+ location = params[:location]
+ raise Discourse::InvalidParameters.new(:location) if !location
+
prompt = CompletionPrompt.find_by(id: params[:mode])
raise Discourse::InvalidParameters.new(:mode) if !prompt || !prompt.enabled?
- raise Discourse::InvalidParameters.new(:post_id) unless post
+ return suggest_thumbnails(input) if prompt.id == CompletionPrompt::ILLUSTRATE_POST
if prompt.id == CompletionPrompt::CUSTOM_PROMPT
raise Discourse::InvalidParameters.new(:custom_prompt) if params[:custom_prompt].blank?
end
- Jobs.enqueue(
- :stream_post_helper,
- post_id: post.id,
- user_id: current_user.id,
- text: text,
- prompt: prompt.name,
- custom_prompt: params[:custom_prompt],
- )
+ if location == "composer"
+ Jobs.enqueue(
+ :stream_composer_helper,
+ user_id: current_user.id,
+ text: text,
+ prompt: prompt.name,
+ custom_prompt: params[:custom_prompt],
+ force_default_locale: params[:force_default_locale] || false,
+ )
+ else
+ post_id = get_post_param!
+ post = Post.includes(:topic).find_by(id: post_id)
+
+ raise Discourse::InvalidParameters.new(:post_id) unless post
+
+ Jobs.enqueue(
+ :stream_post_helper,
+ post_id: post.id,
+ user_id: current_user.id,
+ text: text,
+ prompt: prompt.name,
+ custom_prompt: params[:custom_prompt],
+ )
+ end
render json: { success: true }, status: 200
rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed
diff --git a/app/jobs/regular/stream_composer_helper.rb b/app/jobs/regular/stream_composer_helper.rb
new file mode 100644
index 00000000..5e8f13d6
--- /dev/null
+++ b/app/jobs/regular/stream_composer_helper.rb
@@ -0,0 +1,27 @@
+# frozen_string_literal: true
+
+module Jobs
+ class StreamComposerHelper < ::Jobs::Base
+ sidekiq_options retry: false
+
+ def execute(args)
+ return unless args[:prompt]
+ return unless user = User.find_by(id: args[:user_id])
+ return unless args[:text]
+
+ prompt = CompletionPrompt.enabled_by_name(args[:prompt])
+
+ if prompt.id == CompletionPrompt::CUSTOM_PROMPT
+ prompt.custom_instruction = args[:custom_prompt]
+ end
+
+ DiscourseAi::AiHelper::Assistant.new.stream_prompt(
+ prompt,
+ args[:text],
+ user,
+ "/discourse-ai/ai-helper/stream_composer_suggestion",
+ force_default_locale: args[:force_default_locale],
+ )
+ end
+ end
+end
diff --git a/assets/javascripts/discourse/components/ai-post-helper-menu.gjs b/assets/javascripts/discourse/components/ai-post-helper-menu.gjs
index df626e86..f181d390 100644
--- a/assets/javascripts/discourse/components/ai-post-helper-menu.gjs
+++ b/assets/javascripts/discourse/components/ai-post-helper-menu.gjs
@@ -237,6 +237,7 @@ export default class AiPostHelperMenu extends Component {
this._activeAiRequest = ajax(fetchUrl, {
method: "POST",
data: {
+ location: "post",
mode: option.id,
text: this.args.data.selectedText,
post_id: this.args.data.quoteState.postId,
diff --git a/assets/javascripts/discourse/components/modal/diff-modal.gjs b/assets/javascripts/discourse/components/modal/diff-modal.gjs
index 4647726d..8731cde3 100644
--- a/assets/javascripts/discourse/components/modal/diff-modal.gjs
+++ b/assets/javascripts/discourse/components/modal/diff-modal.gjs
@@ -1,45 +1,92 @@
import Component from "@glimmer/component";
import { tracked } from "@glimmer/tracking";
import { action } from "@ember/object";
+import didInsert from "@ember/render-modifiers/modifiers/did-insert";
+import willDestroy from "@ember/render-modifiers/modifiers/will-destroy";
import { service } from "@ember/service";
import { htmlSafe } from "@ember/template";
import CookText from "discourse/components/cook-text";
import DButton from "discourse/components/d-button";
import DModal from "discourse/components/d-modal";
+import concatClass from "discourse/helpers/concat-class";
import { ajax } from "discourse/lib/ajax";
import { popupAjaxError } from "discourse/lib/ajax-error";
+import { bind } from "discourse/lib/decorators";
import { i18n } from "discourse-i18n";
+import SmoothStreamer from "../../lib/smooth-streamer";
import AiIndicatorWave from "../ai-indicator-wave";
export default class ModalDiffModal extends Component {
@service currentUser;
+ @service messageBus;
@tracked loading = false;
@tracked diff;
@tracked suggestion = "";
+ @tracked
+ smoothStreamer = new SmoothStreamer(
+ () => this.suggestion,
+ (newValue) => (this.suggestion = newValue)
+ );
constructor() {
super(...arguments);
this.suggestChanges();
}
+ @bind
+ subscribe() {
+ const channel = "/discourse-ai/ai-helper/stream_composer_suggestion";
+ this.messageBus.subscribe(channel, this.updateResult);
+ }
+
+ @bind
+ unsubscribe() {
+ const channel = "/discourse-ai/ai-helper/stream_composer_suggestion";
+ this.messageBus.subscribe(channel, this.updateResult);
+ }
+
+ @action
+ async updateResult(result) {
+ if (result) {
+ this.loading = false;
+ }
+ await this.smoothStreamer.updateResult(result, "result");
+
+ if (result.done) {
+ this.diff = result.diff;
+ }
+
+ const mdTablePromptId = this.currentUser?.ai_helper_prompts.find(
+ (prompt) => prompt.name === "markdown_table"
+ ).id;
+
+ // Markdown table prompt looks better with
+ // before/after results than diff
+ // despite having `type: diff`
+ if (this.args.model.mode === mdTablePromptId) {
+ this.diff = null;
+ }
+ }
+
@action
async suggestChanges() {
+ this.smoothStreamer.resetStreaming();
+ this.diff = null;
+ this.suggestion = "";
this.loading = true;
try {
- const suggestion = await ajax("/discourse-ai/ai-helper/suggest", {
+ return await ajax("/discourse-ai/ai-helper/stream_suggestion", {
method: "POST",
data: {
+ location: "composer",
mode: this.args.model.mode,
text: this.args.model.selectedText,
custom_prompt: this.args.model.customPromptValue,
force_default_locale: true,
},
});
-
- this.diff = suggestion.diff;
- this.suggestion = suggestion.suggestions[0];
} catch (e) {
popupAjaxError(e);
} finally {
@@ -66,24 +113,42 @@ export default class ModalDiffModal extends Component {
@closeModal={{@closeModal}}
>
<:body>
- {{#if this.loading}}
-
-
-
- {{else}}
- {{#if this.diff}}
- {{htmlSafe this.diff}}
- {{else}}
-
- {{@model.selectedText}}
+
+ {{#if this.loading}}
+
+
-
-
- {{this.suggestion}}
+ {{else}}
+
+ {{#if this.smoothStreamer.isStreaming}}
+
+ {{else}}
+ {{#if this.diff}}
+ {{htmlSafe this.diff}}
+ {{else}}
+
+ {{@model.selectedText}}
+
+
+
+
+ {{/if}}
+ {{/if}}
{{/if}}
- {{/if}}
-
+
<:footer>
diff --git a/assets/javascripts/discourse/components/thumbnail-suggestion-item.gjs b/assets/javascripts/discourse/components/thumbnail-suggestion-item.gjs
index 2e791a33..8a6f9325 100644
--- a/assets/javascripts/discourse/components/thumbnail-suggestion-item.gjs
+++ b/assets/javascripts/discourse/components/thumbnail-suggestion-item.gjs
@@ -18,7 +18,7 @@ export default class ThumbnailSuggestionItem extends Component {
return this.args.removeSelection(thumbnail);
}
- this.selectIcon = "check-circle";
+ this.selectIcon = "circle-check";
this.selectLabel = "discourse_ai.ai_helper.thumbnail_suggestions.selected";
this.selected = true;
return this.args.addSelection(thumbnail);
diff --git a/lib/ai_helper/assistant.rb b/lib/ai_helper/assistant.rb
index c6716aff..c15bd358 100644
--- a/lib/ai_helper/assistant.rb
+++ b/lib/ai_helper/assistant.rb
@@ -85,7 +85,7 @@ module DiscourseAi
end
end
- def localize_prompt!(prompt, user = nil, force_default_locale = false)
+ def localize_prompt!(prompt, user = nil, force_default_locale: false)
locale_instructions = custom_locale_instructions(user, force_default_locale)
if locale_instructions
prompt.messages[0][:content] = prompt.messages[0][:content] + locale_instructions
@@ -128,10 +128,10 @@ module DiscourseAi
end
end
- def generate_prompt(completion_prompt, input, user, force_default_locale = false, &block)
+ def generate_prompt(completion_prompt, input, user, force_default_locale: false, &block)
llm = helper_llm
prompt = completion_prompt.messages_with_input(input)
- localize_prompt!(prompt, user, force_default_locale)
+ localize_prompt!(prompt, user, force_default_locale: force_default_locale)
llm.generate(
prompt,
@@ -143,8 +143,14 @@ module DiscourseAi
)
end
- def generate_and_send_prompt(completion_prompt, input, user, force_default_locale = false)
- completion_result = generate_prompt(completion_prompt, input, user, force_default_locale)
+ def generate_and_send_prompt(completion_prompt, input, user, force_default_locale: false)
+ completion_result =
+ generate_prompt(
+ completion_prompt,
+ input,
+ user,
+ force_default_locale: force_default_locale,
+ )
result = { type: completion_prompt.prompt_type }
result[:suggestions] = (
@@ -160,24 +166,37 @@ module DiscourseAi
result
end
- def stream_prompt(completion_prompt, input, user, channel)
+ def stream_prompt(completion_prompt, input, user, channel, force_default_locale: false)
+ streamed_diff = +""
streamed_result = +""
start = Time.now
- generate_prompt(completion_prompt, input, user) do |partial_response, cancel_function|
+ generate_prompt(
+ completion_prompt,
+ input,
+ user,
+ force_default_locale: force_default_locale,
+ ) do |partial_response, cancel_function|
streamed_result << partial_response
- # Throttle the updates
- if (Time.now - start > 0.5) || Rails.env.test?
- payload = { result: sanitize_result(streamed_result), done: false }
+ streamed_diff = parse_diff(input, partial_response) if completion_prompt.diff?
+
+ # Throttle the updates and
+ # checking length prevents partial tags
+ # that aren't sanitized correctly yet (i.e. '