From 5f9597474c1cf78cd67bcd4d0ec40b0d9149861a Mon Sep 17 00:00:00 2001 From: Roman Rizzi Date: Fri, 24 Feb 2023 13:25:02 -0300 Subject: [PATCH] REFACTOR: Streamline flag and classification process --- lib/modules/nsfw/entry_point.rb | 5 +- lib/modules/nsfw/evaluation.rb | 50 --------- .../jobs/regular/evaluate_post_uploads.rb | 8 +- lib/modules/nsfw/nsfw_classification.rb | 72 ++++++++++++ lib/modules/sentiment/entry_point.rb | 2 +- .../jobs/regular/post_sentiment_analysis.rb | 4 +- lib/modules/sentiment/post_classifier.rb | 42 ------- .../sentiment/sentiment_classification.rb | 52 +++++++++ .../toxicity/chat_message_classifier.rb | 33 ------ lib/modules/toxicity/classifier.rb | 66 ----------- lib/modules/toxicity/entry_point.rb | 4 +- .../regular/toxicity_classify_chat_message.rb | 4 +- .../jobs/regular/toxicity_classify_post.rb | 4 +- lib/modules/toxicity/post_classifier.rb | 28 ----- .../toxicity/toxicity_classification.rb | 61 ++++++++++ lib/shared/chat_message_classification.rb | 24 ++++ lib/shared/classification.rb | 39 +++++++ lib/shared/flag_manager.rb | 27 ----- lib/shared/inference_manager.rb | 2 +- lib/shared/post_classification.rb | 23 ++++ plugin.rb | 4 +- spec/lib/modules/nsfw/evaluation_spec.rb | 49 --------- .../regular/evaluate_post_uploads_spec.rb | 20 ---- .../modules/nsfw/nsfw_classification_spec.rb | 104 ++++++++++++++++++ .../modules/sentiment/post_classifier_spec.rb | 26 ----- .../sentiment_classification_spec.rb | 22 ++++ .../toxicity/chat_message_classifier_spec.rb | 48 -------- .../modules/toxicity/post_classifier_spec.rb | 51 --------- .../toxicity/toxicity_classification_spec.rb | 56 ++++++++++ .../chat_message_classification_spec.rb | 42 +++++++ spec/shared/post_classification_spec.rb | 44 ++++++++ spec/support/sentiment_inference_stubs.rb | 2 +- 32 files changed, 560 insertions(+), 458 deletions(-) delete mode 100644 lib/modules/nsfw/evaluation.rb create mode 100644 lib/modules/nsfw/nsfw_classification.rb delete mode 100644 
lib/modules/sentiment/post_classifier.rb create mode 100644 lib/modules/sentiment/sentiment_classification.rb delete mode 100644 lib/modules/toxicity/chat_message_classifier.rb delete mode 100644 lib/modules/toxicity/classifier.rb delete mode 100644 lib/modules/toxicity/post_classifier.rb create mode 100644 lib/modules/toxicity/toxicity_classification.rb create mode 100644 lib/shared/chat_message_classification.rb create mode 100644 lib/shared/classification.rb delete mode 100644 lib/shared/flag_manager.rb create mode 100644 lib/shared/post_classification.rb delete mode 100644 spec/lib/modules/nsfw/evaluation_spec.rb create mode 100644 spec/lib/modules/nsfw/nsfw_classification_spec.rb delete mode 100644 spec/lib/modules/sentiment/post_classifier_spec.rb create mode 100644 spec/lib/modules/sentiment/sentiment_classification_spec.rb delete mode 100644 spec/lib/modules/toxicity/chat_message_classifier_spec.rb delete mode 100644 spec/lib/modules/toxicity/post_classifier_spec.rb create mode 100644 spec/lib/modules/toxicity/toxicity_classification_spec.rb create mode 100644 spec/shared/chat_message_classification_spec.rb create mode 100644 spec/shared/post_classification_spec.rb diff --git a/lib/modules/nsfw/entry_point.rb b/lib/modules/nsfw/entry_point.rb index 07dd7a44..2b3445d0 100644 --- a/lib/modules/nsfw/entry_point.rb +++ b/lib/modules/nsfw/entry_point.rb @@ -4,14 +4,15 @@ module DiscourseAI module NSFW class EntryPoint def load_files - require_relative "evaluation" + require_relative "nsfw_classification" require_relative "jobs/regular/evaluate_post_uploads" end def inject_into(plugin) nsfw_detection_cb = Proc.new do |post| - if SiteSetting.ai_nsfw_detection_enabled && post.uploads.present? 
+ if SiteSetting.ai_nsfw_detection_enabled && + DiscourseAI::NSFW::NSFWClassification.new.can_classify?(post) Jobs.enqueue(:evaluate_post_uploads, post_id: post.id) end end diff --git a/lib/modules/nsfw/evaluation.rb b/lib/modules/nsfw/evaluation.rb deleted file mode 100644 index 7f060cd9..00000000 --- a/lib/modules/nsfw/evaluation.rb +++ /dev/null @@ -1,50 +0,0 @@ -# frozen_string_literal: true - -module DiscourseAI - module NSFW - class Evaluation - def perform(upload) - result = { verdict: false, evaluation: {} } - - SiteSetting - .ai_nsfw_models - .split("|") - .each do |model| - model_result = evaluate_with_model(model, upload).symbolize_keys! - - result[:evaluation][model.to_sym] = model_result - - result[:verdict] = send("#{model}_verdict?", model_result) - end - - result - end - - private - - def evaluate_with_model(model, upload) - upload_url = Discourse.store.cdn_url(upload.url) - upload_url = "#{Discourse.base_url_no_prefix}#{upload_url}" if upload_url.starts_with?("/") - - DiscourseAI::InferenceManager.perform!( - "#{SiteSetting.ai_nsfw_inference_service_api_endpoint}/api/v1/classify", - model, - upload_url, - SiteSetting.ai_nsfw_inference_service_api_key, - ) - end - - def opennsfw2_verdict?(clasification) - clasification.values.first.to_i >= SiteSetting.ai_nsfw_flag_threshold_general - end - - def nsfw_detector_verdict?(classification) - classification.each do |key, value| - next if key == :neutral - return true if value.to_i >= SiteSetting.send("ai_nsfw_flag_threshold_#{key}") - end - false - end - end - end -end diff --git a/lib/modules/nsfw/jobs/regular/evaluate_post_uploads.rb b/lib/modules/nsfw/jobs/regular/evaluate_post_uploads.rb index 0077d662..da34a630 100644 --- a/lib/modules/nsfw/jobs/regular/evaluate_post_uploads.rb +++ b/lib/modules/nsfw/jobs/regular/evaluate_post_uploads.rb @@ -9,13 +9,9 @@ module Jobs post = Post.includes(:uploads).find_by_id(post_id) return if post.nil? || post.uploads.empty? 
- nsfw_evaluation = DiscourseAI::NSFW::Evaluation.new + return if post.uploads.none? { |u| FileHelper.is_supported_image?(u.url) } - image_uploads = post.uploads.select { |upload| FileHelper.is_supported_image?(upload.url) } - - results = image_uploads.map { |upload| nsfw_evaluation.perform(upload) } - - DiscourseAI::FlagManager.new(post).flag! if results.any? { |r| r[:verdict] } + DiscourseAI::PostClassification.new(DiscourseAI::NSFW::NSFWClassification.new).classify!(post) end end end diff --git a/lib/modules/nsfw/nsfw_classification.rb b/lib/modules/nsfw/nsfw_classification.rb new file mode 100644 index 00000000..90d688fd --- /dev/null +++ b/lib/modules/nsfw/nsfw_classification.rb @@ -0,0 +1,72 @@ +# frozen_string_literal: true + +module DiscourseAI + module NSFW + class NSFWClassification + def type + :nsfw + end + + def can_classify?(target) + content_of(target).present? + end + + def should_flag_based_on?(classification_data) + return false if !SiteSetting.ai_nsfw_flag_automatically + + # Flat representation of each model classification of each upload. + # Each element looks like [model_name, data] + all_classifications = classification_data.values.flatten.map { |x| x.to_a.flatten } + + all_classifications.any? 
{ |(model_name, data)| send("#{model_name}_verdict?", data) } + end + + def request(target_to_classify) + uploads_to_classify = content_of(target_to_classify) + + uploads_to_classify.reduce({}) do |memo, upload| + memo[upload.id] = available_models.reduce({}) do |per_model, model| + per_model[model] = evaluate_with_model(model, upload) + per_model + end + + memo + end + end + + private + + def evaluate_with_model(model, upload) + upload_url = Discourse.store.cdn_url(upload.url) + upload_url = "#{Discourse.base_url_no_prefix}#{upload_url}" if upload_url.starts_with?("/") + + DiscourseAI::InferenceManager.perform!( + "#{SiteSetting.ai_nsfw_inference_service_api_endpoint}/api/v1/classify", + model, + upload_url, + SiteSetting.ai_nsfw_inference_service_api_key, + ) + end + + def available_models + SiteSetting.ai_nsfw_models.split("|") + end + + def content_of(target_to_classify) + target_to_classify.uploads.to_a.select { |u| FileHelper.is_supported_image?(u.url) } + end + + def opennsfw2_verdict?(clasification) + clasification.values.first.to_i >= SiteSetting.ai_nsfw_flag_threshold_general + end + + def nsfw_detector_verdict?(classification) + classification.each do |key, value| + next if key == :neutral + return true if value.to_i >= SiteSetting.send("ai_nsfw_flag_threshold_#{key}") + end + false + end + end + end +end diff --git a/lib/modules/sentiment/entry_point.rb b/lib/modules/sentiment/entry_point.rb index 286b9dd6..d813b5c4 100644 --- a/lib/modules/sentiment/entry_point.rb +++ b/lib/modules/sentiment/entry_point.rb @@ -3,7 +3,7 @@ module DiscourseAI module Sentiment class EntryPoint def load_files - require_relative "post_classifier" + require_relative "sentiment_classification" require_relative "jobs/regular/post_sentiment_analysis" end diff --git a/lib/modules/sentiment/jobs/regular/post_sentiment_analysis.rb b/lib/modules/sentiment/jobs/regular/post_sentiment_analysis.rb index b6bbd063..a29a0162 100644 --- 
a/lib/modules/sentiment/jobs/regular/post_sentiment_analysis.rb +++ b/lib/modules/sentiment/jobs/regular/post_sentiment_analysis.rb @@ -9,7 +9,9 @@ module ::Jobs post = Post.find_by(id: post_id, post_type: Post.types[:regular]) return if post&.raw.blank? - ::DiscourseAI::Sentiment::PostClassifier.new.classify!(post) + DiscourseAI::PostClassification.new( + DiscourseAI::Sentiment::SentimentClassification.new, + ).classify!(post) end end end diff --git a/lib/modules/sentiment/post_classifier.rb b/lib/modules/sentiment/post_classifier.rb deleted file mode 100644 index 301da2a2..00000000 --- a/lib/modules/sentiment/post_classifier.rb +++ /dev/null @@ -1,42 +0,0 @@ -# frozen_string_literal: true - -module ::DiscourseAI - module Sentiment - class PostClassifier - def classify!(post) - available_models.each do |model| - classification = request_classification(post, model) - - store_classification(post, model, classification) - end - end - - def available_models - SiteSetting.ai_sentiment_models.split("|") - end - - private - - def request_classification(post, model) - ::DiscourseAI::InferenceManager.perform!( - "#{SiteSetting.ai_sentiment_inference_service_api_endpoint}/api/v1/classify", - model, - content(post), - SiteSetting.ai_sentiment_inference_service_api_key, - ) - end - - def content(post) - post.post_number == 1 ? 
"#{post.topic.title}\n#{post.raw}" : post.raw - end - - def store_classification(post, model, classification) - PostCustomField.create!( - post_id: post.id, - name: "ai-sentiment-#{model}", - value: { classification: classification }.to_json, - ) - end - end - end -end diff --git a/lib/modules/sentiment/sentiment_classification.rb b/lib/modules/sentiment/sentiment_classification.rb new file mode 100644 index 00000000..ebd40b73 --- /dev/null +++ b/lib/modules/sentiment/sentiment_classification.rb @@ -0,0 +1,52 @@ +# frozen_string_literal: true + +module DiscourseAI + module Sentiment + class SentimentClassification + def type + :sentiment + end + + def available_models + SiteSetting.ai_sentiment_models.split("|") + end + + def can_classify?(target) + content_of(target).present? + end + + def should_flag_based_on?(classification_data) + # We don't flag based on sentiment classification. + false + end + + def request(target_to_classify) + target_content = content_of(target_to_classify) + + available_models.reduce({}) do |memo, model| + memo[model] = request_with(model, target_content) + memo + end + end + + private + + def request_with(model, content) + ::DiscourseAI::InferenceManager.perform!( + "#{SiteSetting.ai_sentiment_inference_service_api_endpoint}/api/v1/classify", + model, + content, + SiteSetting.ai_sentiment_inference_service_api_key, + ) + end + + def content_of(target_to_classify) + if target_to_classify.post_number == 1 + "#{target_to_classify.topic.title}\n#{target_to_classify.raw}" + else + target_to_classify.raw + end + end + end + end +end diff --git a/lib/modules/toxicity/chat_message_classifier.rb b/lib/modules/toxicity/chat_message_classifier.rb deleted file mode 100644 index f36f5b9f..00000000 --- a/lib/modules/toxicity/chat_message_classifier.rb +++ /dev/null @@ -1,33 +0,0 @@ -# frozen_string_literal: true - -module ::DiscourseAI - module Toxicity - class ChatMessageClassifier < Classifier - private - - def content(chat_message) - 
chat_message.message - end - - def store_classification(chat_message, classification) - PluginStore.set( - "toxicity", - "chat_message_#{chat_message.id}", - { - classification: classification, - model: SiteSetting.ai_toxicity_inference_service_api_model, - date: Time.now.utc, - }, - ) - end - - def flag!(chat_message, _toxic_labels) - Chat::ChatReviewQueue.new.flag_message( - chat_message, - Guardian.new(flagger), - ReviewableScore.types[:inappropriate], - ) - end - end - end -end diff --git a/lib/modules/toxicity/classifier.rb b/lib/modules/toxicity/classifier.rb deleted file mode 100644 index 5a5e6fdf..00000000 --- a/lib/modules/toxicity/classifier.rb +++ /dev/null @@ -1,66 +0,0 @@ -# frozen_string_literal: true - -module ::DiscourseAI - module Toxicity - class Classifier - CLASSIFICATION_LABELS = %w[ - toxicity - severe_toxicity - obscene - identity_attack - insult - threat - sexual_explicit - ] - - def classify!(target) - classification = request_classification(target) - - store_classification(target, classification) - - toxic_labels = filter_toxic_labels(classification) - - flag!(target, toxic_labels) if should_flag_based_on?(toxic_labels) - end - - protected - - def flag!(_target, _toxic_labels) - raise NotImplemented - end - - def store_classification(_target, _classification) - raise NotImplemented - end - - def content(_target) - raise NotImplemented - end - - def flagger - Discourse.system_user - end - - private - - def request_classification(target) - ::DiscourseAI::InferenceManager.perform!( - "#{SiteSetting.ai_toxicity_inference_service_api_endpoint}/api/v1/classify", - SiteSetting.ai_toxicity_inference_service_api_model, - content(target), - SiteSetting.ai_toxicity_inference_service_api_key, - ) - end - - def filter_toxic_labels(classification) - CLASSIFICATION_LABELS.filter do |label| - classification[label] >= SiteSetting.send("ai_toxicity_flag_threshold_#{label}") - end - end - - def should_flag_based_on?(toxic_labels) - 
SiteSetting.ai_toxicity_flag_automatically && toxic_labels.present? - end - end - end -end diff --git a/lib/modules/toxicity/entry_point.rb b/lib/modules/toxicity/entry_point.rb index b628ac2c..352f3fd8 100644 --- a/lib/modules/toxicity/entry_point.rb +++ b/lib/modules/toxicity/entry_point.rb @@ -4,9 +4,7 @@ module DiscourseAI class EntryPoint def load_files require_relative "scan_queue" - require_relative "classifier" - require_relative "post_classifier" - require_relative "chat_message_classifier" + require_relative "toxicity_classification" require_relative "jobs/regular/toxicity_classify_post" require_relative "jobs/regular/toxicity_classify_chat_message" diff --git a/lib/modules/toxicity/jobs/regular/toxicity_classify_chat_message.rb b/lib/modules/toxicity/jobs/regular/toxicity_classify_chat_message.rb index 5c316456..c0167a5f 100644 --- a/lib/modules/toxicity/jobs/regular/toxicity_classify_chat_message.rb +++ b/lib/modules/toxicity/jobs/regular/toxicity_classify_chat_message.rb @@ -10,7 +10,9 @@ module ::Jobs chat_message = ChatMessage.find_by(id: chat_message_id) return if chat_message&.message.blank? - ::DiscourseAI::Toxicity::ChatMessageClassifier.new.classify!(chat_message) + DiscourseAI::ChatMessageClassification.new( + DiscourseAI::Toxicity::ToxicityClassification.new, + ).classify!(chat_message) end end end diff --git a/lib/modules/toxicity/jobs/regular/toxicity_classify_post.rb b/lib/modules/toxicity/jobs/regular/toxicity_classify_post.rb index 3db16428..f59b16d2 100644 --- a/lib/modules/toxicity/jobs/regular/toxicity_classify_post.rb +++ b/lib/modules/toxicity/jobs/regular/toxicity_classify_post.rb @@ -11,7 +11,9 @@ module ::Jobs post = Post.find_by(id: post_id, post_type: Post.types[:regular]) return if post&.raw.blank? 
- ::DiscourseAI::Toxicity::PostClassifier.new.classify!(post) + DiscourseAI::PostClassification.new( + DiscourseAI::Toxicity::ToxicityClassification.new, + ).classify!(post) end end end diff --git a/lib/modules/toxicity/post_classifier.rb b/lib/modules/toxicity/post_classifier.rb deleted file mode 100644 index 65413a4c..00000000 --- a/lib/modules/toxicity/post_classifier.rb +++ /dev/null @@ -1,28 +0,0 @@ -# frozen_string_literal: true - -module ::DiscourseAI - module Toxicity - class PostClassifier < Classifier - private - - def content(post) - post.post_number == 1 ? "#{post.topic.title}\n#{post.raw}" : post.raw - end - - def store_classification(post, classification) - PostCustomField.create!( - post_id: post.id, - name: "toxicity", - value: { - classification: classification, - model: SiteSetting.ai_toxicity_inference_service_api_model, - }.to_json, - ) - end - - def flag!(target, toxic_labels) - ::DiscourseAI::FlagManager.new(target, reasons: toxic_labels).flag! - end - end - end -end diff --git a/lib/modules/toxicity/toxicity_classification.rb b/lib/modules/toxicity/toxicity_classification.rb new file mode 100644 index 00000000..5767d51c --- /dev/null +++ b/lib/modules/toxicity/toxicity_classification.rb @@ -0,0 +1,61 @@ +# frozen_string_literal: true + +module DiscourseAI + module Toxicity + class ToxicityClassification + CLASSIFICATION_LABELS = %i[ + toxicity + severe_toxicity + obscene + identity_attack + insult + threat + sexual_explicit + ] + + def type + :toxicity + end + + def can_classify?(target) + content_of(target).present? + end + + def should_flag_based_on?(classification_data) + return false if !SiteSetting.ai_toxicity_flag_automatically + + # We only use one model for this classification. + # Classification_data looks like { model_name => classification } + _model_used, data = classification_data.to_a.first + + CLASSIFICATION_LABELS.any? 
do |label| + data[label] >= SiteSetting.send("ai_toxicity_flag_threshold_#{label}") + end + end + + def request(target_to_classify) + data = + ::DiscourseAI::InferenceManager.perform!( + "#{SiteSetting.ai_toxicity_inference_service_api_endpoint}/api/v1/classify", + SiteSetting.ai_toxicity_inference_service_api_model, + content_of(target_to_classify), + SiteSetting.ai_toxicity_inference_service_api_key, + ) + + { SiteSetting.ai_toxicity_inference_service_api_model => data } + end + + private + + def content_of(target_to_classify) + return target_to_classify.message if target_to_classify.is_a?(ChatMessage) + + if target_to_classify.post_number == 1 + "#{target_to_classify.topic.title}\n#{target_to_classify.raw}" + else + target_to_classify.raw + end + end + end + end +end diff --git a/lib/shared/chat_message_classification.rb b/lib/shared/chat_message_classification.rb new file mode 100644 index 00000000..adad5df2 --- /dev/null +++ b/lib/shared/chat_message_classification.rb @@ -0,0 +1,24 @@ +# frozen_string_literal: true + +module ::DiscourseAI + class ChatMessageClassification < Classification + private + + def store_classification(chat_message, type, classification_data) + PluginStore.set( + type, + "chat_message_#{chat_message.id}", + classification_data.merge(date: Time.now.utc), + ) + end + + def flag!(chat_message, _toxic_labels) + Chat::ChatReviewQueue.new.flag_message( + chat_message, + Guardian.new(flagger), + ReviewableScore.types[:inappropriate], + queue_for_review: true, + ) + end + end +end diff --git a/lib/shared/classification.rb b/lib/shared/classification.rb new file mode 100644 index 00000000..1b916866 --- /dev/null +++ b/lib/shared/classification.rb @@ -0,0 +1,39 @@ +# frozen_string_literal: true + +module ::DiscourseAI + class Classification + def initialize(classification_model) + @classification_model = classification_model + end + + def classify!(target) + return :cannot_classify unless classification_model.can_classify?(target) + + 
classification_model + .request(target) + .tap do |classification| + store_classification(target, classification_model.type, classification) + + if classification_model.should_flag_based_on?(classification) + flag!(target, classification) + end + end + end + + protected + + attr_reader :classification_model + + def flag!(_target, _classification) + raise NotImplemented + end + + def store_classification(_target, _classification) + raise NotImplemented + end + + def flagger + Discourse.system_user + end + end +end diff --git a/lib/shared/flag_manager.rb b/lib/shared/flag_manager.rb deleted file mode 100644 index 405964f0..00000000 --- a/lib/shared/flag_manager.rb +++ /dev/null @@ -1,27 +0,0 @@ -# frozen_string_literal: true - -module ::DiscourseAI - class FlagManager - DEFAULT_FLAGGER = Discourse.system_user - DEFAULT_REASON = "discourse-ai" - - def initialize(object, flagger: DEFAULT_FLAGGER, type: :inappropriate, reasons: DEFAULT_REASON) - @flagger = flagger - @object = object - @type = type - @reasons = reasons - end - - def flag! - PostActionCreator.new( - @flagger, - @object, - PostActionType.types[:inappropriate], - reason: @reasons, - queue_for_review: true, - ).perform - - @object.publish_change_to_clients! 
:acted - end - end -end diff --git a/lib/shared/inference_manager.rb b/lib/shared/inference_manager.rb index 554cd941..6f7b12ee 100644 --- a/lib/shared/inference_manager.rb +++ b/lib/shared/inference_manager.rb @@ -11,7 +11,7 @@ module ::DiscourseAI raise Net::HTTPBadResponse unless response.status == 200 - JSON.parse(response.body) + JSON.parse(response.body, symbolize_names: true) end end end diff --git a/lib/shared/post_classification.rb b/lib/shared/post_classification.rb new file mode 100644 index 00000000..807eb139 --- /dev/null +++ b/lib/shared/post_classification.rb @@ -0,0 +1,23 @@ +# frozen_string_literal: true + +module ::DiscourseAI + class PostClassification < Classification + private + + def store_classification(post, type, classification_data) + PostCustomField.create!(post_id: post.id, name: type, value: classification_data.to_json) + end + + def flag!(post, classification_type) + PostActionCreator.new( + flagger, + post, + PostActionType.types[:inappropriate], + reason: classification_type, + queue_for_review: true, + ).perform + + post.publish_change_to_clients! 
:acted + end + end +end diff --git a/plugin.rb b/plugin.rb index 0330d805..7334f7e5 100644 --- a/plugin.rb +++ b/plugin.rb @@ -15,7 +15,9 @@ after_initialize do end require_relative "lib/shared/inference_manager" - require_relative "lib/shared/flag_manager" + require_relative "lib/shared/classification" + require_relative "lib/shared/post_classification" + require_relative "lib/shared/chat_message_classification" require_relative "lib/modules/nsfw/entry_point" require_relative "lib/modules/toxicity/entry_point" diff --git a/spec/lib/modules/nsfw/evaluation_spec.rb b/spec/lib/modules/nsfw/evaluation_spec.rb deleted file mode 100644 index c5d36b9c..00000000 --- a/spec/lib/modules/nsfw/evaluation_spec.rb +++ /dev/null @@ -1,49 +0,0 @@ -# frozen_string_literal: true - -require "rails_helper" -require_relative "../../../support/nsfw_inference_stubs" - -describe DiscourseAI::NSFW::Evaluation do - before do - SiteSetting.ai_nsfw_inference_service_api_endpoint = "http://test.com" - SiteSetting.ai_nsfw_detection_enabled = true - end - - fab!(:image) { Fabricate(:s3_image_upload) } - - let(:available_models) { SiteSetting.ai_nsfw_models.split("|") } - - describe "perform" do - context "when we determine content is NSFW" do - before { NSFWInferenceStubs.positive(image) } - - it "returns true alongside the evaluation" do - result = subject.perform(image) - - expect(result[:verdict]).to eq(true) - - available_models.each do |model| - expect(result.dig(:evaluation, model.to_sym)).to eq( - NSFWInferenceStubs.positive_result(model), - ) - end - end - end - - context "when we determine content is safe" do - before { NSFWInferenceStubs.negative(image) } - - it "returns false alongside the evaluation" do - result = subject.perform(image) - - expect(result[:verdict]).to eq(false) - - available_models.each do |model| - expect(result.dig(:evaluation, model.to_sym)).to eq( - NSFWInferenceStubs.negative_result(model), - ) - end - end - end - end -end diff --git 
a/spec/lib/modules/nsfw/jobs/regular/evaluate_post_uploads_spec.rb b/spec/lib/modules/nsfw/jobs/regular/evaluate_post_uploads_spec.rb index 38b51859..718da1fe 100644 --- a/spec/lib/modules/nsfw/jobs/regular/evaluate_post_uploads_spec.rb +++ b/spec/lib/modules/nsfw/jobs/regular/evaluate_post_uploads_spec.rb @@ -76,25 +76,5 @@ describe Jobs::EvaluatePostUploads do end end end - - context "when the post has multiple uploads" do - fab!(:upload_2) { Fabricate(:upload) } - - before { post.uploads << upload_2 } - - context "when we conclude content is NSFW" do - before do - NSFWInferenceStubs.negative(upload_1) - NSFWInferenceStubs.positive(upload_2) - end - - it "flags and hides the post if at least one upload is considered NSFW" do - subject.execute({ post_id: post.id }) - - expect(ReviewableFlaggedPost.where(target: post).count).to eq(1) - expect(post.reload.hidden?).to eq(true) - end - end - end end end diff --git a/spec/lib/modules/nsfw/nsfw_classification_spec.rb b/spec/lib/modules/nsfw/nsfw_classification_spec.rb new file mode 100644 index 00000000..4e5176e1 --- /dev/null +++ b/spec/lib/modules/nsfw/nsfw_classification_spec.rb @@ -0,0 +1,104 @@ +# frozen_string_literal: true + +require "rails_helper" +require_relative "../../../support/nsfw_inference_stubs" + +describe DiscourseAI::NSFW::NSFWClassification do + before { SiteSetting.ai_nsfw_inference_service_api_endpoint = "http://test.com" } + + let(:available_models) { SiteSetting.ai_nsfw_models.split("|") } + + describe "#request" do + fab!(:upload_1) { Fabricate(:s3_image_upload) } + fab!(:post) { Fabricate(:post, uploads: [upload_1]) } + + def assert_correctly_classified(upload, results, expected) + available_models.each do |model| + model_result = results.dig(upload.id, model) + + expect(model_result).to eq(expected[model]) + end + end + + def build_expected_classification(positive: true) + available_models.reduce({}) do |memo, model| + model_expected = + if positive + NSFWInferenceStubs.positive_result(model) 
+ else + NSFWInferenceStubs.negative_result(model) + end + + memo[model] = model_expected + memo + end + end + + context "when the target has one upload" do + it "returns the classification and the model used for it" do + NSFWInferenceStubs.positive(upload_1) + expected = build_expected_classification + + classification = subject.request(post) + + assert_correctly_classified(upload_1, classification, expected) + end + + context "when the target has multiple uploads" do + fab!(:upload_2) { Fabricate(:upload) } + + before { post.uploads << upload_2 } + + it "returns a classification for each one" do + NSFWInferenceStubs.positive(upload_1) + NSFWInferenceStubs.negative(upload_2) + expected_upload_1 = build_expected_classification + expected_upload_2 = build_expected_classification(positive: false) + + classification = subject.request(post) + + assert_correctly_classified(upload_1, classification, expected_upload_1) + assert_correctly_classified(upload_2, classification, expected_upload_2) + end + end + end + end + + describe "#should_flag_based_on?" 
do + before { SiteSetting.ai_nsfw_flag_automatically = true } + + let(:positive_classification) do + { + 1 => available_models.map { |m| { m => NSFWInferenceStubs.negative_result(m) } }, + 2 => available_models.map { |m| { m => NSFWInferenceStubs.positive_result(m) } }, + } + end + + let(:negative_classification) do + { + 1 => available_models.map { |m| { m => NSFWInferenceStubs.negative_result(m) } }, + 2 => available_models.map { |m| { m => NSFWInferenceStubs.negative_result(m) } }, + } + end + + it "returns false when NSFW flaggin is disabled" do + SiteSetting.ai_nsfw_flag_automatically = false + + should_flag = subject.should_flag_based_on?(positive_classification) + + expect(should_flag).to eq(false) + end + + it "returns true if the response is NSFW based on our thresholds" do + should_flag = subject.should_flag_based_on?(positive_classification) + + expect(should_flag).to eq(true) + end + + it "returns false if the response is safe based on our thresholds" do + should_flag = subject.should_flag_based_on?(negative_classification) + + expect(should_flag).to eq(false) + end + end +end diff --git a/spec/lib/modules/sentiment/post_classifier_spec.rb b/spec/lib/modules/sentiment/post_classifier_spec.rb deleted file mode 100644 index 35f14abf..00000000 --- a/spec/lib/modules/sentiment/post_classifier_spec.rb +++ /dev/null @@ -1,26 +0,0 @@ -# frozen_string_literal: true - -require "rails_helper" -require_relative "../../../support/sentiment_inference_stubs" - -describe DiscourseAI::Sentiment::PostClassifier do - fab!(:post) { Fabricate(:post) } - - before { SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com" } - - describe "#classify!" 
do - it "stores each model classification in a post custom field" do - SentimentInferenceStubs.stub_classification(post) - - subject.classify!(post) - - subject.available_models.each do |model| - stored_classification = PostCustomField.find_by(post: post, name: "ai-sentiment-#{model}") - expect(stored_classification).to be_present - expect(stored_classification.value).to eq( - { classification: SentimentInferenceStubs.model_response(model) }.to_json, - ) - end - end - end -end diff --git a/spec/lib/modules/sentiment/sentiment_classification_spec.rb b/spec/lib/modules/sentiment/sentiment_classification_spec.rb new file mode 100644 index 00000000..8fe2a3b8 --- /dev/null +++ b/spec/lib/modules/sentiment/sentiment_classification_spec.rb @@ -0,0 +1,22 @@ +# frozen_string_literal: true + +require "rails_helper" +require_relative "../../../support/sentiment_inference_stubs" + +describe DiscourseAI::Sentiment::SentimentClassification do + describe "#request" do + fab!(:target) { Fabricate(:post) } + + before { SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com" } + + it "returns the classification and the model used for it" do + SentimentInferenceStubs.stub_classification(target) + + result = subject.request(target) + + subject.available_models.each do |model| + expect(result[model]).to eq(SentimentInferenceStubs.model_response(model)) + end + end + end +end diff --git a/spec/lib/modules/toxicity/chat_message_classifier_spec.rb b/spec/lib/modules/toxicity/chat_message_classifier_spec.rb deleted file mode 100644 index a40f9842..00000000 --- a/spec/lib/modules/toxicity/chat_message_classifier_spec.rb +++ /dev/null @@ -1,48 +0,0 @@ -# frozen_string_literal: true - -require "rails_helper" -require_relative "../../../support/toxicity_inference_stubs" - -describe DiscourseAI::Toxicity::ChatMessageClassifier do - before { SiteSetting.ai_toxicity_flag_automatically = true } - - fab!(:chat_message) { Fabricate(:chat_message) } - - describe "#classify!" 
do - it "creates a reviewable when the post is classified as toxic" do - ToxicityInferenceStubs.stub_chat_message_classification(chat_message, toxic: true) - - subject.classify!(chat_message) - - expect(ReviewableChatMessage.where(target: chat_message).count).to eq(1) - end - - it "doesn't create a reviewable if the post is not classified as toxic" do - ToxicityInferenceStubs.stub_chat_message_classification(chat_message, toxic: false) - - subject.classify!(chat_message) - - expect(ReviewableChatMessage.where(target: chat_message).count).to be_zero - end - - it "doesn't create a reviewable if flagging is disabled" do - SiteSetting.ai_toxicity_flag_automatically = false - ToxicityInferenceStubs.stub_chat_message_classification(chat_message, toxic: true) - - subject.classify!(chat_message) - - expect(ReviewableChatMessage.where(target: chat_message).count).to be_zero - end - - it "stores the classification in a custom field" do - ToxicityInferenceStubs.stub_chat_message_classification(chat_message, toxic: false) - - subject.classify!(chat_message) - store_row = PluginStore.get("toxicity", "chat_message_#{chat_message.id}").deep_symbolize_keys - - expect(store_row[:classification]).to eq(ToxicityInferenceStubs.civilized_response) - expect(store_row[:model]).to eq(SiteSetting.ai_toxicity_inference_service_api_model) - expect(store_row[:date]).to be_present - end - end -end diff --git a/spec/lib/modules/toxicity/post_classifier_spec.rb b/spec/lib/modules/toxicity/post_classifier_spec.rb deleted file mode 100644 index bc9832ce..00000000 --- a/spec/lib/modules/toxicity/post_classifier_spec.rb +++ /dev/null @@ -1,51 +0,0 @@ -# frozen_string_literal: true - -require "rails_helper" -require_relative "../../../support/toxicity_inference_stubs" - -describe DiscourseAI::Toxicity::PostClassifier do - before { SiteSetting.ai_toxicity_flag_automatically = true } - - fab!(:post) { Fabricate(:post) } - - describe "#classify!" 
do - it "creates a reviewable when the post is classified as toxic" do - ToxicityInferenceStubs.stub_post_classification(post, toxic: true) - - subject.classify!(post) - - expect(ReviewableFlaggedPost.where(target: post).count).to eq(1) - end - - it "doesn't create a reviewable if the post is not classified as toxic" do - ToxicityInferenceStubs.stub_post_classification(post, toxic: false) - - subject.classify!(post) - - expect(ReviewableFlaggedPost.where(target: post).count).to be_zero - end - - it "doesn't create a reviewable if flagging is disabled" do - SiteSetting.ai_toxicity_flag_automatically = false - ToxicityInferenceStubs.stub_post_classification(post, toxic: true) - - subject.classify!(post) - - expect(ReviewableFlaggedPost.where(target: post).count).to be_zero - end - - it "stores the classification in a custom field" do - ToxicityInferenceStubs.stub_post_classification(post, toxic: false) - - subject.classify!(post) - custom_field = PostCustomField.find_by(post: post, name: "toxicity") - - expect(custom_field.value).to eq( - { - classification: ToxicityInferenceStubs.civilized_response, - model: SiteSetting.ai_toxicity_inference_service_api_model, - }.to_json, - ) - end - end -end diff --git a/spec/lib/modules/toxicity/toxicity_classification_spec.rb b/spec/lib/modules/toxicity/toxicity_classification_spec.rb new file mode 100644 index 00000000..3edd2b91 --- /dev/null +++ b/spec/lib/modules/toxicity/toxicity_classification_spec.rb @@ -0,0 +1,56 @@ +# frozen_string_literal: true + +require "rails_helper" +require_relative "../../../support/toxicity_inference_stubs" + +describe DiscourseAI::Toxicity::ToxicityClassification do + describe "#request" do + fab!(:target) { Fabricate(:post) } + + it "returns the classification and the model used for it" do + ToxicityInferenceStubs.stub_post_classification(target, toxic: false) + + result = subject.request(target) + + expect(result[SiteSetting.ai_toxicity_inference_service_api_model]).to eq( + 
ToxicityInferenceStubs.civilized_response, + ) + end + end + + describe "#should_flag_based_on?" do + before { SiteSetting.ai_toxicity_flag_automatically = true } + + let(:toxic_response) do + { + SiteSetting.ai_toxicity_inference_service_api_model => + ToxicityInferenceStubs.toxic_response, + } + end + + it "returns false when toxicity flagging is disabled" do + SiteSetting.ai_toxicity_flag_automatically = false + + should_flag = subject.should_flag_based_on?(toxic_response) + + expect(should_flag).to eq(false) + end + + it "returns true if the response is toxic based on our thresholds" do + should_flag = subject.should_flag_based_on?(toxic_response) + + expect(should_flag).to eq(true) + end + + it "returns false if the response is civilized based on our thresholds" do + civilized_response = { + SiteSetting.ai_toxicity_inference_service_api_model => + ToxicityInferenceStubs.civilized_response, + } + + should_flag = subject.should_flag_based_on?(civilized_response) + + expect(should_flag).to eq(false) + end + end +end diff --git a/spec/shared/chat_message_classification_spec.rb b/spec/shared/chat_message_classification_spec.rb new file mode 100644 index 00000000..1b7d49e0 --- /dev/null +++ b/spec/shared/chat_message_classification_spec.rb @@ -0,0 +1,42 @@ +# frozen_string_literal: true + +require "rails_helper" +require_relative "../support/toxicity_inference_stubs" + +describe DiscourseAI::ChatMessageClassification do + fab!(:chat_message) { Fabricate(:chat_message) } + + let(:model) { DiscourseAI::Toxicity::ToxicityClassification.new } + let(:classification) { described_class.new(model) } + + describe "#classify!" 
do + before { ToxicityInferenceStubs.stub_chat_message_classification(chat_message, toxic: true) } + + it "stores the model classification data in a custom field" do + classification.classify!(chat_message) + store_row = PluginStore.get("toxicity", "chat_message_#{chat_message.id}") + + classified_data = + store_row[SiteSetting.ai_toxicity_inference_service_api_model].symbolize_keys + + expect(classified_data).to eq(ToxicityInferenceStubs.toxic_response) + expect(store_row[:date]).to be_present + end + + it "flags the message when the model decides we should" do + SiteSetting.ai_toxicity_flag_automatically = true + + classification.classify!(chat_message) + + expect(ReviewableChatMessage.where(target: chat_message).count).to eq(1) + end + + it "doesn't flag the message if the model decides we shouldn't" do + SiteSetting.ai_toxicity_flag_automatically = false + + classification.classify!(chat_message) + + expect(ReviewableChatMessage.where(target: chat_message).count).to be_zero + end + end +end diff --git a/spec/shared/post_classification_spec.rb b/spec/shared/post_classification_spec.rb new file mode 100644 index 00000000..ca5714fd --- /dev/null +++ b/spec/shared/post_classification_spec.rb @@ -0,0 +1,44 @@ +# frozen_string_literal: true + +require "rails_helper" +require_relative "../support/toxicity_inference_stubs" + +describe DiscourseAI::PostClassification do + fab!(:post) { Fabricate(:post) } + + let(:model) { DiscourseAI::Toxicity::ToxicityClassification.new } + let(:classification) { described_class.new(model) } + + describe "#classify!" 
do + before { ToxicityInferenceStubs.stub_post_classification(post, toxic: true) } + + it "stores the model classification data in a custom field" do + classification.classify!(post) + custom_field = PostCustomField.find_by(post: post, name: model.type) + + expect(custom_field.value).to eq( + { + SiteSetting.ai_toxicity_inference_service_api_model => + ToxicityInferenceStubs.toxic_response, + }.to_json, + ) + end + + it "flags the message and hides the post when the model decides we should" do + SiteSetting.ai_toxicity_flag_automatically = true + + classification.classify!(post) + + expect(ReviewableFlaggedPost.where(target: post).count).to eq(1) + expect(post.reload.hidden?).to eq(true) + end + + it "doesn't flag the message if the model decides we shouldn't" do + SiteSetting.ai_toxicity_flag_automatically = false + + classification.classify!(post) + + expect(ReviewableFlaggedPost.where(target: post).count).to be_zero + end + end +end diff --git a/spec/support/sentiment_inference_stubs.rb b/spec/support/sentiment_inference_stubs.rb index 67351988..49ec2c9d 100644 --- a/spec/support/sentiment_inference_stubs.rb +++ b/spec/support/sentiment_inference_stubs.rb @@ -15,7 +15,7 @@ class SentimentInferenceStubs def stub_classification(post) content = post.post_number == 1 ? "#{post.topic.title}\n#{post.raw}" : post.raw - DiscourseAI::Sentiment::PostClassifier.new.available_models.each do |model| + DiscourseAI::Sentiment::SentimentClassification.new.available_models.each do |model| WebMock .stub_request(:post, endpoint) .with(body: JSON.dump(model: model, content: content))