diff --git a/config/settings.yml b/config/settings.yml index 0f86f47e..0dbcf624 100644 --- a/config/settings.yml +++ b/config/settings.yml @@ -122,6 +122,8 @@ discourse_ai: default: "stable-diffusion-xl-1024-v1-0" type: enum choices: + - "sd3" + - "sd3-turbo" - "stable-diffusion-xl-1024-v1-0" - "stable-diffusion-768-v2-1" - "stable-diffusion-v1-5" diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb index b9efb83d..d2aa873e 100644 --- a/lib/ai_bot/bot.rb +++ b/lib/ai_bot/bot.rb @@ -203,10 +203,6 @@ module DiscourseAi end end - def tool_invocation?(partial) - Nokogiri::HTML5.fragment(partial).at("invoke").present? - end - def build_placeholder(summary, details, custom_raw: nil) placeholder = +(<<~HTML)
diff --git a/lib/ai_bot/personas/persona.rb b/lib/ai_bot/personas/persona.rb index 5c14d418..86ee6889 100644 --- a/lib/ai_bot/personas/persona.rb +++ b/lib/ai_bot/personas/persona.rb @@ -188,7 +188,7 @@ module DiscourseAi begin JSON.parse(value) rescue JSON::ParserError - nil + [value.to_s] end end diff --git a/lib/ai_bot/tools/image.rb b/lib/ai_bot/tools/image.rb index 92f3a3d4..d25810c4 100644 --- a/lib/ai_bot/tools/image.rb +++ b/lib/ai_bot/tools/image.rb @@ -24,7 +24,13 @@ module DiscourseAi "The seed used to generate the image (optional) - can be used to retain image style on amended prompts", type: "array", item_type: "integer", - required: true, + }, + { + name: "aspect_ratio", + description: "The aspect ratio of the image (optional defaults to 1:1)", + type: "string", + required: false, + enum: %w[16:9 1:1 21:9 2:3 3:2 4:5 5:4 9:16 9:21], }, ], } @@ -35,7 +41,11 @@ module DiscourseAi end def prompts - JSON.parse(parameters[:prompts].to_s) + parameters[:prompts] + end + + def aspect_ratio + parameters[:aspect_ratio] end def seeds @@ -75,6 +85,7 @@ module DiscourseAi api_url: api_url, image_count: 1, seed: inner_seed, + aspect_ratio: aspect_ratio, ) rescue => e attempts += 1 @@ -116,7 +127,7 @@ module DiscourseAi #{ uploads .map do |item| - "![#{item[:prompt].gsub(/\|\'\"/, "")}|512x512, 50%](#{item[:upload].short_url})" + "![#{item[:prompt].gsub(/\|\'\"/, "")}|50%](#{item[:upload].short_url})" end .join(" ") } diff --git a/lib/ai_helper/painter.rb b/lib/ai_helper/painter.rb index c50d2de9..63a6bc7d 100644 --- a/lib/ai_helper/painter.rb +++ b/lib/ai_helper/painter.rb @@ -7,7 +7,6 @@ module DiscourseAi return [] if input.blank? model = SiteSetting.ai_helper_illustrate_post_model - attribution = "discourse_ai.ai_helper.painter.attribution.#{model}" if model == "stable_diffusion_xl" stable_diffusion_prompt = diffusion_prompt(input, user) diff --git a/lib/inference/stability_generator.rb b/lib/inference/stability_generator.rb index 47bc6814..124f5d8a 100644 --- a/lib/inference/stability_generator.rb +++ b/lib/inference/stability_generator.rb @@ -3,10 +3,74 @@ module ::DiscourseAi module Inference class StabilityGenerator + TIMEOUT = 120 + + # there is a new api for sd3 + def self.perform_sd3!( + prompt, + aspect_ratio: nil, + api_key: nil, + engine: nil, + api_url: nil, + output_format: "png", + seed: nil + ) + api_key ||= SiteSetting.ai_stability_api_key + engine ||= SiteSetting.ai_stability_engine + api_url ||= SiteSetting.ai_stability_api_url + + allowed_ratios = %w[16:9 1:1 21:9 2:3 3:2 4:5 5:4 9:16 9:21] + + aspect_ratio = "1:1" if !aspect_ratio || !allowed_ratios.include?(aspect_ratio) + + payload = { + prompt: prompt, + mode: "text-to-image", + model: engine, + output_format: output_format, + aspect_ratio: aspect_ratio, + } + + payload[:seed] = seed if seed + + endpoint = "v2beta/stable-image/generate/sd3" + + form_data = payload.to_a.map { |k, v| [k.to_s, v.to_s] } + + uri = URI("#{api_url}/#{endpoint}") + request = FinalDestination::HTTP::Post.new(uri) + + request["authorization"] = "Bearer #{api_key}" + request["accept"] = "application/json" + request["User-Agent"] = DiscourseAi::AiBot::USER_AGENT + request.set_form form_data, "multipart/form-data" + + response = + FinalDestination::HTTP.start( + uri.hostname, + uri.port, + use_ssl: uri.port != 80, + read_timeout: TIMEOUT, + open_timeout: TIMEOUT, + write_timeout: TIMEOUT, + ) { |http| http.request(request) } + + if response.code != "200" + Rails.logger.error( + "AI stability generator failed with status #{response.code}: #{response.body}}", + ) + raise Net::HTTPBadResponse + end + + parsed = JSON.parse(response.body, symbolize_names: true) + + # remap to old format + { artifacts: [{ base64: parsed[:image], seed: parsed[:seed] }] } + end + def self.perform!( prompt, - width: nil, - height: nil, + aspect_ratio: nil, api_key: nil, engine: nil, api_url: nil, @@ -17,30 +81,52 @@ module ::DiscourseAi engine ||= SiteSetting.ai_stability_engine api_url ||= SiteSetting.ai_stability_api_url + image_count = 4 if image_count > 4 + + if engine.start_with? "sd3" + artifacts = + image_count.times.map do + perform_sd3!( + prompt, + api_key: api_key, + engine: engine, + api_url: api_url, + aspect_ratio: aspect_ratio, + seed: seed, + )[ + :artifacts + ][ + 0 + ] + end + + return { artifacts: artifacts } + end + headers = { "Content-Type" => "application/json", "Accept" => "application/json", "Authorization" => "Bearer #{api_key}", } - sdxl_allowed_dimensions = [ - [1024, 1024], - [1152, 896], - [1216, 832], - [1344, 768], - [1536, 640], - [640, 1536], - [768, 1344], - [832, 1216], - [896, 1152], - ] + ratio_to_dimension = { + "16:9" => [1536, 640], + "1:1" => [1024, 1024], + "21:9" => [1344, 768], + "2:3" => [896, 1152], + "3:2" => [1152, 896], + "4:5" => [832, 1216], + "5:4" => [1216, 832], + "9:16" => [640, 1536], + "9:21" => [768, 1344], + } - if (!width && !height) - if engine.include? "xl" - width, height = sdxl_allowed_dimensions[0] - else - width, height = [512, 512] - end + if engine.include? "xl" + width, height = ratio_to_dimension[aspect_ratio] if aspect_ratio + + width, height = [1024, 1024] if !width || !height + else + width, height = [512, 512] end payload = { diff --git a/spec/shared/inference/stability_generator_spec.rb b/spec/shared/inference/stability_generator_spec.rb index 34f98ef3..45ee4dbb 100644 --- a/spec/shared/inference/stability_generator_spec.rb +++ b/spec/shared/inference/stability_generator_spec.rb @@ -5,6 +5,36 @@ describe DiscourseAi::Inference::StabilityGenerator do DiscourseAi::Inference::StabilityGenerator.perform!(prompt) end + let :sd3_response do + { image: "BASE64", seed: 1 }.to_json + end + + it "is able to generate sd3 images" do + SiteSetting.ai_stability_engine = "sd3" + SiteSetting.ai_stability_api_url = "http://www.a.b.c" + SiteSetting.ai_stability_api_key = "123" + + # webmock does not support multipart form data :( + stub_request(:post, "http://www.a.b.c/v2beta/stable-image/generate/sd3").with( + headers: { + "Accept" => "application/json", + "Authorization" => "Bearer 123", + "Content-Type" => "multipart/form-data", + "Host" => "www.a.b.c", + "User-Agent" => DiscourseAi::AiBot::USER_AGENT, + }, + ).to_return(status: 200, body: sd3_response, headers: {}) + + json = + DiscourseAi::Inference::StabilityGenerator.perform!( + "a cow", + aspect_ratio: "16:9", + image_count: 2, + ) + + expect(json).to eq(artifacts: [{ base64: "BASE64", seed: 1 }, { base64: "BASE64", seed: 1 }]) + end + it "sets dimensions to 512x512 for non XL model" do SiteSetting.ai_stability_engine = "stable-diffusion-v1-5" SiteSetting.ai_stability_api_url = "http://www.a.b.c"