From 6add06af8f5a97e36ce41ea262ceb01fedfe9d55 Mon Sep 17 00:00:00 2001 From: Sam Date: Fri, 27 Oct 2023 14:48:12 +1100 Subject: [PATCH] FEATURE: Make artist more creative (#266) This allows for 2 big features: 1. Artist can ship up to 4 prompts for image generation 2. Artist can regenerate images because it is aware of the seed This allows for iteration on images while maintaining visual style --- lib/modules/ai_bot/commands/command.rb | 5 +- lib/modules/ai_bot/commands/image_command.rb | 103 ++++++++++++------ lib/modules/ai_bot/personas/artist.rb | 4 + lib/shared/inference/function.rb | 24 ++-- lib/shared/inference/stability_generator.rb | 15 ++- .../ai_bot/commands/image_command_spec.rb | 28 ++++- 6 files changed, 125 insertions(+), 54 deletions(-) diff --git a/lib/modules/ai_bot/commands/command.rb b/lib/modules/ai_bot/commands/command.rb index 024f6dfa..47ded9d5 100644 --- a/lib/modules/ai_bot/commands/command.rb +++ b/lib/modules/ai_bot/commands/command.rb @@ -4,13 +4,14 @@ module DiscourseAi module AiBot module Commands class Parameter - attr_reader :name, :description, :type, :enum, :required - def initialize(name:, description:, type:, enum: nil, required: false) + attr_reader :item_type, :name, :description, :type, :enum, :required + def initialize(name:, description:, type:, enum: nil, required: false, item_type: nil) @name = name @description = description @type = type @enum = enum @required = required + @item_type = item_type end end diff --git a/lib/modules/ai_bot/commands/image_command.rb b/lib/modules/ai_bot/commands/image_command.rb index b5843b2a..22cd83fc 100644 --- a/lib/modules/ai_bot/commands/image_command.rb +++ b/lib/modules/ai_bot/commands/image_command.rb @@ -14,12 +14,20 @@ module DiscourseAi::AiBot::Commands def parameters [ Parameter.new( - name: "prompt", + name: "prompts", description: - "The prompt used to generate or create or draw the image (40 words or less, be creative)", - type: "string", + "The prompts used to generate or create or draw 
the image (40 words or less, be creative) up to 4 prompts", + type: "array", + item_type: "string", required: true, ), + Parameter.new( + name: "seeds", + description: + "The seed used to generate the image (optional) - can be used to retain image style on amended prompts", + type: "array", + item_type: "integer", + ), ] end end @@ -40,8 +48,12 @@ module DiscourseAi::AiBot::Commands @custom_raw end - def process(prompt:) - @last_prompt = prompt + def process(prompts:, seeds: nil) + # max 4 prompts + prompts = prompts[0..3] + seeds = seeds[0..3] if seeds + + @last_prompt = prompts[0] show_progress(localized_description) @@ -53,41 +65,55 @@ module DiscourseAi::AiBot::Commands engine = SiteSetting.ai_stability_engine api_url = SiteSetting.ai_stability_api_url - # API is flaky, so try a few times - 3.times do - begin - thread = - Thread.new do - begin - results = - DiscourseAi::Inference::StabilityGenerator.perform!( - prompt, - engine: engine, - api_key: api_key, - api_url: api_url, - ) - rescue => e - Rails.logger.warn("Failed to generate image for prompt #{prompt}: #{e}") - end - end - - show_progress(".", progress_caret: true) while !thread.join(2) - - break if results + threads = [] + prompts.each_with_index do |prompt, index| + seed = seeds ? seeds[index] : nil + threads << Thread.new(seed, prompt) do |inner_seed, inner_prompt| + attempts = 0 + begin + DiscourseAi::Inference::StabilityGenerator.perform!( + inner_prompt, + engine: engine, + api_key: api_key, + api_url: api_url, + image_count: 1, + seed: inner_seed, + ) + rescue => e + attempts += 1 + retry if attempts < 3 + Rails.logger.warn("Failed to generate image for prompt #{prompt}: #{e}") + nil + end end end - return { prompt: prompt, error: "Something went wrong, could not generate image" } if !results + while true + show_progress(".", progress_caret: true) + break if threads.all? { |t| t.join(2) } + end + + results = threads.map(&:value).compact + + if !results.present? 
+ return { prompt: prompt, error: "Something went wrong, could not generate image" } + end uploads = [] - results[:artifacts].each_with_index do |image, i| - f = Tempfile.new("v1_txt2img_#{i}.png") - f.binmode - f.write(Base64.decode64(image[:base64])) - f.rewind - uploads << UploadCreator.new(f, "image.png").create_for(bot_user.id) - f.unlink + results.each_with_index do |result, index| + result[:artifacts].each do |image| + Tempfile.create("v1_txt2img_#{index}.png") do |file| + file.binmode + file.write(Base64.decode64(image[:base64])) + file.rewind + uploads << { + prompt: prompts[index], + upload: UploadCreator.new(file, "image.png").create_for(bot_user.id), + seed: image[:seed], + } + end + end end @custom_raw = <<~RAW @@ -95,13 +121,18 @@ module DiscourseAi::AiBot::Commands [grid] #{ uploads - .map { |upload| "![#{prompt.gsub(/\|\'\"/, "")}|512x512, 50%](#{upload.short_url})" } + .map do |item| + "![#{item[:prompt].gsub(/\|\'\"/, "")}|512x512, 50%](#{item[:upload].short_url})" + end .join(" ") } [/grid] RAW - { prompt: prompt, displayed_to_user: true } + { + prompts: uploads.map { |item| { prompt: item[:prompt], seed: item[:seed] } }, + displayed_to_user: true, + } end end end diff --git a/lib/modules/ai_bot/personas/artist.rb b/lib/modules/ai_bot/personas/artist.rb index 9f520c57..a90e7c5a 100644 --- a/lib/modules/ai_bot/personas/artist.rb +++ b/lib/modules/ai_bot/personas/artist.rb @@ -23,6 +23,10 @@ module DiscourseAi - Do not include any connector words such as "and" or "but" etc. - You are extremely creative, when given short non descriptive prompts from a user you add your own details + - When generating images, usually opt to generate 4 images unless the user specifies otherwise. 
+ - Be creative with your prompts, offer diverse options + - You can use the seeds to regenerate the same image and amend the prompt keeping general style + {commands} PROMPT diff --git a/lib/shared/inference/function.rb b/lib/shared/inference/function.rb index d84acd42..e19ae931 100644 --- a/lib/shared/inference/function.rb +++ b/lib/shared/inference/function.rb @@ -20,20 +20,26 @@ module ::DiscourseAi description: parameter.description, required: parameter.required, enum: parameter.enum, + item_type: parameter.item_type, ) else add_parameter_kwargs(**kwargs) end end - def add_parameter_kwargs(name:, type:, description:, enum: nil, required: false) - @parameters << { - name: name, - type: type, - description: description, - enum: enum, - required: required, - } + def add_parameter_kwargs( + name:, + type:, + description:, + enum: nil, + required: false, + item_type: nil + ) + param = { name: name, type: type, description: description, enum: enum, required: required } + param[:enum] = enum if enum + param[:item_type] = item_type if item_type + + @parameters << param end def to_json(*args) @@ -47,7 +53,7 @@ module ::DiscourseAi parameters.each do |parameter| definition = { type: parameter[:type], description: parameter[:description] } definition[:enum] = parameter[:enum] if parameter[:enum] - + definition[:items] = { type: parameter[:item_type] } if parameter[:item_type] required_params << parameter[:name] if parameter[:required] properties[parameter[:name]] = definition end diff --git a/lib/shared/inference/stability_generator.rb b/lib/shared/inference/stability_generator.rb index 04a8ca83..33447e1d 100644 --- a/lib/shared/inference/stability_generator.rb +++ b/lib/shared/inference/stability_generator.rb @@ -3,7 +3,16 @@ module ::DiscourseAi module Inference class StabilityGenerator - def self.perform!(prompt, width: nil, height: nil, api_key: nil, engine: nil, api_url: nil) + def self.perform!( + prompt, + width: nil, + height: nil, + api_key: nil, + engine: nil, 
+ api_url: nil, + image_count: 4, + seed: nil + ) api_key ||= SiteSetting.ai_stability_api_key engine ||= SiteSetting.ai_stability_engine api_url ||= SiteSetting.ai_stability_api_url @@ -40,10 +49,12 @@ module ::DiscourseAi clip_guidance_preset: "FAST_BLUE", height: width, width: height, - samples: 4, + samples: image_count, steps: 30, } + payload[:seed] = seed if seed + endpoint = "v1/generation/#{engine}/text-to-image" response = Faraday.post("#{api_url}/#{endpoint}", payload.to_json, headers) diff --git a/spec/lib/modules/ai_bot/commands/image_command_spec.rb b/spec/lib/modules/ai_bot/commands/image_command_spec.rb index c2fc7b81..ddecb5f2 100644 --- a/spec/lib/modules/ai_bot/commands/image_command_spec.rb +++ b/spec/lib/modules/ai_bot/commands/image_command_spec.rb @@ -1,7 +1,5 @@ #frozen_string_literal: true -require_relative "../../../../support/stable_difussion_stubs" - RSpec.describe DiscourseAi::AiBot::Commands::ImageCommand do let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) } @@ -17,16 +15,36 @@ RSpec.describe DiscourseAi::AiBot::Commands::ImageCommand do image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==" - StableDiffusionStubs.new.stub_response("a pink cow", [image, image]) + artifacts = [{ base64: image, seed: 99 }] + prompts = ["a pink cow", "a red cow"] + + WebMock + .stub_request( + :post, + "https://api.stability.dev/v1/generation/#{SiteSetting.ai_stability_engine}/text-to-image", + ) + .with do |request| + json = JSON.parse(request.body, symbolize_names: true) + expect(prompts).to include(json[:text_prompts][0][:text]) + true + end + .to_return(status: 200, body: { artifacts: artifacts }.to_json) image = described_class.new(bot_user: bot_user, post: post, args: nil) - info = image.process(prompt: "a pink cow").to_json + info = image.process(prompts: prompts).to_json - expect(JSON.parse(info)).to eq("prompt" => "a pink cow", "displayed_to_user" => true) + 
expect(JSON.parse(info)).to eq( + "prompts" => [ + { "prompt" => "a pink cow", "seed" => 99 }, + { "prompt" => "a red cow", "seed" => 99 }, + ], + "displayed_to_user" => true, + ) expect(image.custom_raw).to include("upload://") expect(image.custom_raw).to include("[grid]") expect(image.custom_raw).to include("a pink cow") + expect(image.custom_raw).to include("a red cow") end end end