diff --git a/app/models/classification_result.rb b/app/models/classification_result.rb
new file mode 100644
index 00000000..1eeb35b7
--- /dev/null
+++ b/app/models/classification_result.rb
@@ -0,0 +1,23 @@
+# frozen_string_literal: true
+
+class ClassificationResult < ActiveRecord::Base
+  belongs_to :target, polymorphic: true
+end
+
+# == Schema Information
+#
+# Table name: classification_results
+#
+#  id                  :bigint           not null, primary key
+#  model_used          :string
+#  classification_type :string
+#  target_id           :integer
+#  target_type         :string
+#  classification      :jsonb
+#  created_at          :datetime         not null
+#  updated_at          :datetime         not null
+#
+# Indexes
+#
+#  unique_classification_target_per_type  (target_id,target_type,model_used) UNIQUE
+#
diff --git a/db/.gitkeep b/db/.gitkeep
deleted file mode 100644
index e69de29b..00000000
diff --git a/db/migrate/20230224165056_create_classification_results_table.rb b/db/migrate/20230224165056_create_classification_results_table.rb
new file mode 100644
index 00000000..4b6e8a57
--- /dev/null
+++ b/db/migrate/20230224165056_create_classification_results_table.rb
@@ -0,0 +1,19 @@
+# frozen_string_literal: true
+class CreateClassificationResultsTable < ActiveRecord::Migration[7.0]
+  def change
+    create_table :classification_results do |t|
+      t.string :model_used, null: true
+      t.string :classification_type, null: true
+      t.integer :target_id, null: true
+      t.string :target_type, null: true
+
+      t.jsonb :classification, null: true
+      t.timestamps
+    end
+
+    add_index :classification_results,
+              %i[target_id target_type model_used],
+              unique: true,
+              name: "unique_classification_target_per_type"
+  end
+end
diff --git a/lib/modules/nsfw/nsfw_classification.rb b/lib/modules/nsfw/nsfw_classification.rb
index 90d688fd..97f59e41 100644
--- a/lib/modules/nsfw/nsfw_classification.rb
+++ b/lib/modules/nsfw/nsfw_classification.rb
@@ -14,20 +14,23 @@ module DiscourseAI
       def should_flag_based_on?(classification_data)
         return false if !SiteSetting.ai_nsfw_flag_automatically

-        # Flat representation of each model classification of each upload.
-        # Each element looks like [model_name, data]
-        all_classifications = classification_data.values.flatten.map { |x| x.to_a.flatten }
-
-        all_classifications.any? { |(model_name, data)| send("#{model_name}_verdict?", data) }
+        classification_data.any? do |model_name, classifications|
+          classifications.values.any? do |data|
+            send("#{model_name}_verdict?", data.except(:neutral, :target_classified_type))
+          end
+        end
       end

       def request(target_to_classify)
         uploads_to_classify = content_of(target_to_classify)

-        uploads_to_classify.reduce({}) do |memo, upload|
-          memo[upload.id] = available_models.reduce({}) do |per_model, model|
-            per_model[model] = evaluate_with_model(model, upload)
-            per_model
+        available_models.reduce({}) do |memo, model|
+          memo[model] = uploads_to_classify.reduce({}) do |upl_memo, upload|
+            upl_memo[upload.id] = evaluate_with_model(model, upload).merge(
+              target_classified_type: upload.class.name,
+            )
+
+            upl_memo
           end

           memo
@@ -61,11 +64,9 @@ module DiscourseAI
       end

       def nsfw_detector_verdict?(classification)
-        classification.each do |key, value|
-          next if key == :neutral
-          return true if value.to_i >= SiteSetting.send("ai_nsfw_flag_threshold_#{key}")
+        classification.any? do |key, value|
+          value.to_i >= SiteSetting.send("ai_nsfw_flag_threshold_#{key}")
         end
-        false
       end
     end
   end
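For reference, NSFWClassification#request now nests results by model first and upload id second (the old shape was upload id first, then model), and each per-upload result carries the upload's class name under :target_classified_type. A minimal sketch of the new shape, assuming the two model names exercised in the specs below and a single upload with id 42; the score keys are illustrative, not the inference API's exact fields:

    # Hypothetical return value of #request for a post with one upload (id 42).
    {
      "opennsfw2" => {
        42 => { nsfw_probability: 99, target_classified_type: "Upload" },
      },
      "nsfw_detector" => {
        42 => { porn: 90, hentai: 2, sexy: 5, neutral: 3, target_classified_type: "Upload" },
      },
    }

should_flag_based_on? walks this structure model by model, stripping :neutral and :target_classified_type before handing the remaining scores to the matching "#{model_name}_verdict?" check.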
diff --git a/lib/modules/toxicity/toxicity_classification.rb b/lib/modules/toxicity/toxicity_classification.rb
index 5767d51c..66702bd6 100644
--- a/lib/modules/toxicity/toxicity_classification.rb
+++ b/lib/modules/toxicity/toxicity_classification.rb
@@ -42,11 +42,15 @@ module DiscourseAI
             SiteSetting.ai_toxicity_inference_service_api_key,
           )

-        { SiteSetting.ai_toxicity_inference_service_api_model => data }
+        { available_model => data }
       end

       private

+      def available_model
+        SiteSetting.ai_toxicity_inference_service_api_model
+      end
+
       def content_of(target_to_classify)
         return target_to_classify.message if target_to_classify.is_a?(ChatMessage)

diff --git a/lib/shared/chat_message_classification.rb b/lib/shared/chat_message_classification.rb
index adad5df2..2b031faa 100644
--- a/lib/shared/chat_message_classification.rb
+++ b/lib/shared/chat_message_classification.rb
@@ -4,14 +4,6 @@ module ::DiscourseAI
   class ChatMessageClassification < Classification
     private

-    def store_classification(chat_message, type, classification_data)
-      PluginStore.set(
-        type,
-        "chat_message_#{chat_message.id}",
-        classification_data.merge(date: Time.now.utc),
-      )
-    end
-
     def flag!(chat_message, _toxic_labels)
       Chat::ChatReviewQueue.new.flag_message(
         chat_message,
diff --git a/lib/shared/classification.rb b/lib/shared/classification.rb
index 1b916866..35e0ee70 100644
--- a/lib/shared/classification.rb
+++ b/lib/shared/classification.rb
@@ -12,7 +12,7 @@ module ::DiscourseAI
       classification_model
         .request(target)
         .tap do |classification|
-          store_classification(target, classification_model.type, classification)
+          store_classification(target, classification)

           if classification_model.should_flag_based_on?(classification)
             flag!(target, classification)
@@ -28,8 +28,25 @@ module ::DiscourseAI
       raise NotImplemented
     end

-    def store_classification(_target, _classification)
-      raise NotImplemented
+    def store_classification(target, classification)
+      attrs =
+        classification.map do |model_name, classifications|
+          {
+            model_used: model_name,
+            target_id: target.id,
+            target_type: target.class.name,
+            classification_type: classification_model.type,
+            classification: classifications,
+            updated_at: DateTime.now,
+            created_at: DateTime.now,
+          }
+        end
+
+      ClassificationResult.upsert_all(
+        attrs,
+        unique_by: %i[target_id target_type model_used],
+        update_only: %i[classification],
+      )
     end

     def flagger
diff --git a/lib/shared/post_classification.rb b/lib/shared/post_classification.rb
index 807eb139..62cfd83e 100644
--- a/lib/shared/post_classification.rb
+++ b/lib/shared/post_classification.rb
@@ -4,10 +4,6 @@ module ::DiscourseAI
   class PostClassification < Classification
     private

-    def store_classification(post, type, classification_data)
-      PostCustomField.create!(post_id: post.id, name: type, value: classification_data.to_json)
-    end
-
     def flag!(post, classification_type)
       PostActionCreator.new(
         flagger,
diff --git a/plugin.rb b/plugin.rb
index 7334f7e5..12023bd6 100644
--- a/plugin.rb
+++ b/plugin.rb
@@ -14,6 +14,8 @@ after_initialize do
     PLUGIN_NAME = "discourse-ai"
   end

+  require_relative "app/models/classification_result"
+
   require_relative "lib/shared/inference_manager"
   require_relative "lib/shared/classification"
   require_relative "lib/shared/post_classification"
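The new store_classification in lib/shared/classification.rb above is the heart of the change: it builds one row per model and writes them idempotently against the composite unique index created by the migration. A sketch of the resulting upsert behavior, assuming a Post target and a hypothetical model name:

    # Minimal sketch, not the plugin's code. The first call inserts a row; a
    # repeat call for the same (target_id, target_type, model_used) triple hits
    # the unique index and, because of update_only, only refreshes the
    # classification payload. created_at is preserved and, per the new
    # classification_spec.rb below, updated_at still moves forward on conflict.
    ClassificationResult.upsert_all(
      [
        {
          model_used: "example-model", # hypothetical model name
          target_id: post.id,
          target_type: "Post",
          classification_type: "sentiment",
          classification: { "positive" => 57, "negative" => 3 }, # illustrative payload
          created_at: DateTime.now,
          updated_at: DateTime.now,
        },
      ],
      unique_by: %i[target_id target_type model_used],
      update_only: %i[classification],
    )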
diff --git a/spec/lib/modules/nsfw/nsfw_classification_spec.rb b/spec/lib/modules/nsfw/nsfw_classification_spec.rb
index 4e5176e1..727dbc22 100644
--- a/spec/lib/modules/nsfw/nsfw_classification_spec.rb
+++ b/spec/lib/modules/nsfw/nsfw_classification_spec.rb
@@ -8,19 +8,15 @@ describe DiscourseAI::NSFW::NSFWClassification do

   let(:available_models) { SiteSetting.ai_nsfw_models.split("|") }

+  fab!(:upload_1) { Fabricate(:s3_image_upload) }
+  fab!(:post) { Fabricate(:post, uploads: [upload_1]) }
+
   describe "#request" do
-    fab!(:upload_1) { Fabricate(:s3_image_upload) }
-    fab!(:post) { Fabricate(:post, uploads: [upload_1]) }
-
-    def assert_correctly_classified(upload, results, expected)
-      available_models.each do |model|
-        model_result = results.dig(upload.id, model)
-
-        expect(model_result).to eq(expected[model])
-      end
+    def assert_correctly_classified(results, expected)
+      available_models.each { |model| expect(results[model]).to eq(expected[model]) }
     end

-    def build_expected_classification(positive: true)
+    def build_expected_classification(target, positive: true)
       available_models.reduce({}) do |memo, model|
         model_expected =
           if positive
@@ -29,7 +25,9 @@ describe DiscourseAI::NSFW::NSFWClassification do
             NSFWInferenceStubs.negative_result(model)
           end

-        memo[model] = model_expected
+        memo[model] = {
+          target.id => model_expected.merge(target_classified_type: target.class.name),
+        }
         memo
       end
     end
@@ -37,11 +35,11 @@ describe DiscourseAI::NSFW::NSFWClassification do
     context "when the target has one upload" do
       it "returns the classification and the model used for it" do
         NSFWInferenceStubs.positive(upload_1)
-        expected = build_expected_classification
+        expected = build_expected_classification(upload_1)

         classification = subject.request(post)

-        assert_correctly_classified(upload_1, classification, expected)
+        assert_correctly_classified(classification, expected)
       end

       context "when the target has multiple uploads" do
@@ -52,13 +50,14 @@ describe DiscourseAI::NSFW::NSFWClassification do
         it "returns a classification for each one" do
           NSFWInferenceStubs.positive(upload_1)
           NSFWInferenceStubs.negative(upload_2)
-          expected_upload_1 = build_expected_classification
-          expected_upload_2 = build_expected_classification(positive: false)
+          expected_classification = build_expected_classification(upload_1)
+          expected_classification.deep_merge!(
+            build_expected_classification(upload_2, positive: false),
+          )

           classification = subject.request(post)

-          assert_correctly_classified(upload_1, classification, expected_upload_1)
-          assert_correctly_classified(upload_2, classification, expected_upload_2)
+          assert_correctly_classified(classification, expected_classification)
         end
       end
     end
@@ -69,15 +68,23 @@ describe DiscourseAI::NSFW::NSFWClassification do

     let(:positive_classification) do
       {
-        1 => available_models.map { |m| { m => NSFWInferenceStubs.negative_result(m) } },
-        2 => available_models.map { |m| { m => NSFWInferenceStubs.positive_result(m) } },
+        "opennsfw2" => {
+          1 => NSFWInferenceStubs.negative_result("opennsfw2"),
+          2 => NSFWInferenceStubs.positive_result("opennsfw2"),
+        },
+        "nsfw_detector" => {
+          1 => NSFWInferenceStubs.negative_result("nsfw_detector"),
+          2 => NSFWInferenceStubs.positive_result("nsfw_detector"),
+        },
       }
     end

     let(:negative_classification) do
       {
-        1 => available_models.map { |m| { m => NSFWInferenceStubs.negative_result(m) } },
-        2 => available_models.map { |m| { m => NSFWInferenceStubs.negative_result(m) } },
+        "opennsfw2" => {
+          1 => NSFWInferenceStubs.negative_result("opennsfw2"),
+          2 => NSFWInferenceStubs.negative_result("opennsfw2"),
+        },
       }
     end
   end
diff --git a/spec/lib/modules/sentiment/jobs/regular/post_sentiment_analysis_spec.rb b/spec/lib/modules/sentiment/jobs/regular/post_sentiment_analysis_spec.rb
index e4237ba1..cf108e0d 100644
--- a/spec/lib/modules/sentiment/jobs/regular/post_sentiment_analysis_spec.rb
+++ b/spec/lib/modules/sentiment/jobs/regular/post_sentiment_analysis_spec.rb
@@ -18,19 +18,19 @@ describe Jobs::PostSentimentAnalysis do

       subject.execute({ post_id: post.id })

-      expect(PostCustomField.where(post: post).count).to be_zero
+      expect(ClassificationResult.where(target: post).count).to be_zero
     end

     it "does nothing if there's no arg called post_id" do
       subject.execute({})

-      expect(PostCustomField.where(post: post).count).to be_zero
+      expect(ClassificationResult.where(target: post).count).to be_zero
     end

     it "does nothing if no post match the given id" do
       subject.execute({ post_id: nil })

-      expect(PostCustomField.where(post: post).count).to be_zero
+      expect(ClassificationResult.where(target: post).count).to be_zero
     end

     it "does nothing if the post content is blank" do
@@ -38,7 +38,7 @@ describe Jobs::PostSentimentAnalysis do

       subject.execute({ post_id: post.id })

-      expect(PostCustomField.where(post: post).count).to be_zero
+      expect(ClassificationResult.where(target: post).count).to be_zero
     end
   end

@@ -48,7 +48,7 @@ describe Jobs::PostSentimentAnalysis do

       subject.execute({ post_id: post.id })

-      expect(PostCustomField.where(post: post).count).to eq(expected_analysis)
+      expect(ClassificationResult.where(target: post).count).to eq(expected_analysis)
     end
   end
 end
diff --git a/spec/lib/modules/sentiment/sentiment_classification_spec.rb b/spec/lib/modules/sentiment/sentiment_classification_spec.rb
index 8fe2a3b8..a1c827fb 100644
--- a/spec/lib/modules/sentiment/sentiment_classification_spec.rb
+++ b/spec/lib/modules/sentiment/sentiment_classification_spec.rb
@@ -4,9 +4,9 @@ require "rails_helper"
 require_relative "../../../support/sentiment_inference_stubs"

 describe DiscourseAI::Sentiment::SentimentClassification do
-  describe "#request" do
-    fab!(:target) { Fabricate(:post) }
+  fab!(:target) { Fabricate(:post) }

+  describe "#request" do
     before { SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com" }

     it "returns the classification and the model used for it" do
diff --git a/spec/lib/modules/toxicity/toxicity_classification_spec.rb b/spec/lib/modules/toxicity/toxicity_classification_spec.rb
index 3edd2b91..5ad0fa36 100644
--- a/spec/lib/modules/toxicity/toxicity_classification_spec.rb
+++ b/spec/lib/modules/toxicity/toxicity_classification_spec.rb
@@ -4,9 +4,9 @@ require "rails_helper"
 require_relative "../../../support/toxicity_inference_stubs"

 describe DiscourseAI::Toxicity::ToxicityClassification do
-  describe "#request" do
-    fab!(:target) { Fabricate(:post) }
+  fab!(:target) { Fabricate(:post) }

+  describe "#request" do
     it "returns the classification and the model used for it" do
       ToxicityInferenceStubs.stub_post_classification(target, toxic: false)

diff --git a/spec/shared/chat_message_classification_spec.rb b/spec/shared/chat_message_classification_spec.rb
index 1b7d49e0..735055c9 100644
--- a/spec/shared/chat_message_classification_spec.rb
+++ b/spec/shared/chat_message_classification_spec.rb
@@ -12,15 +12,14 @@ describe DiscourseAI::ChatMessageClassification do
   describe "#classify!" do
     before { ToxicityInferenceStubs.stub_chat_message_classification(chat_message, toxic: true) }

-    it "stores the model classification data in a custom field" do
+    it "stores the model classification data" do
       classification.classify!(chat_message)

-      store_row = PluginStore.get("toxicity", "chat_message_#{chat_message.id}")
-      classified_data =
-        store_row[SiteSetting.ai_toxicity_inference_service_api_model].symbolize_keys
+      result = ClassificationResult.find_by(target: chat_message, classification_type: model.type)

-      expect(classified_data).to eq(ToxicityInferenceStubs.toxic_response)
-      expect(store_row[:date]).to be_present
+      classification = result.classification.symbolize_keys
+
+      expect(classification).to eq(ToxicityInferenceStubs.toxic_response)
     end

     it "flags the message when the model decides we should" do
diff --git a/spec/shared/classification_spec.rb b/spec/shared/classification_spec.rb
new file mode 100644
index 00000000..da26cd3b
--- /dev/null
+++ b/spec/shared/classification_spec.rb
@@ -0,0 +1,80 @@
+# frozen_string_literal: true
+
+require "rails_helper"
+require_relative "../support/sentiment_inference_stubs"
+
+describe DiscourseAI::Classification do
+  describe "#classify!" do
+    describe "saving the classification result" do
+      let(:classification_raw_result) do
+        model
+          .available_models
+          .reduce({}) do |memo, model_name|
+            memo[model_name] = SentimentInferenceStubs.model_response(model_name)
+            memo
+          end
+      end
+
+      let(:model) { DiscourseAI::Sentiment::SentimentClassification.new }
+      let(:classification) { DiscourseAI::PostClassification.new(model) }
+      fab!(:target) { Fabricate(:post) }
+
+      before do
+        SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com"
+        SentimentInferenceStubs.stub_classification(target)
+      end
+
+      it "stores one result per model used" do
+        classification.classify!(target)
+
+        stored_results = ClassificationResult.where(target: target)
+        expect(stored_results.length).to eq(model.available_models.length)
+
+        model.available_models.each do |model_name|
+          result = stored_results.detect { |c| c.model_used == model_name }
+
+          expect(result.classification_type).to eq(model.type.to_s)
+          expect(result.created_at).to be_present
+          expect(result.updated_at).to be_present
+
+          expected_classification = SentimentInferenceStubs.model_response(model_name)
+
+          expect(result.classification.deep_symbolize_keys).to eq(expected_classification)
+        end
+      end
+
+      it "updates an existing classification result" do
+        original_creation = 3.days.ago
+
+        model.available_models.each do |model_name|
+          ClassificationResult.create!(
+            target: target,
+            model_used: model_name,
+            classification_type: model.type,
+            created_at: original_creation,
+            updated_at: original_creation,
+            classification: {
+            },
+          )
+        end
+
+        classification.classify!(target)
+
+        stored_results = ClassificationResult.where(target: target)
+        expect(stored_results.length).to eq(model.available_models.length)
+
+        model.available_models.each do |model_name|
+          result = stored_results.detect { |c| c.model_used == model_name }
+
+          expect(result.classification_type).to eq(model.type.to_s)
+          expect(result.updated_at).to be > original_creation
+          expect(result.created_at).to eq_time(original_creation)
+
+          expect(result.classification.deep_symbolize_keys).to eq(
+            classification_raw_result[model_name],
+          )
+        end
+      end
+    end
+  end
+end
diff --git a/spec/shared/post_classification_spec.rb b/spec/shared/post_classification_spec.rb
index ca5714fd..6c0da133 100644
--- a/spec/shared/post_classification_spec.rb
+++ b/spec/shared/post_classification_spec.rb
@@ -12,16 +12,13 @@ describe DiscourseAI::PostClassification do
   describe "#classify!" do
     before { ToxicityInferenceStubs.stub_post_classification(post, toxic: true) }
-    it "stores the model classification data in a custom field" do
+    it "stores the model classification data" do
       classification.classify!(post)

-      custom_field = PostCustomField.find_by(post: post, name: model.type)
+      result = ClassificationResult.find_by(target: post, classification_type: model.type)

-      expect(custom_field.value).to eq(
-        {
-          SiteSetting.ai_toxicity_inference_service_api_model =>
-            ToxicityInferenceStubs.toxic_response,
-        }.to_json,
-      )
+      classification = result.classification.symbolize_keys
+
+      expect(classification).to eq(ToxicityInferenceStubs.toxic_response)
     end

     it "flags the message and hides the post when the model decides we should" do
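With the PluginStore and PostCustomField storage gone, callers read results back through the model and its polymorphic association, as the updated specs do. A short read-side sketch; post and "toxicity" are illustrative, and the payload keys depend on the model that produced the row:

    result = ClassificationResult.find_by(target: post, classification_type: "toxicity")
    result.model_used     # the inference model that produced this row
    result.classification # jsonb hash of model-specific scores
    result.target         # => post, via belongs_to :target, polymorphic: true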