# frozen_string_literal: true

# Adds the storage table for Gemini topic embeddings (model id 5, version 1).
# Rows are keyed one-per-topic and carry the model/strategy versions plus a
# digest of the embedded content so stale vectors can be detected and refreshed.
class CreateGeminiTopicEmbeddingsTable < ActiveRecord::Migration[7.0]
  def change
    create_table :ai_topic_embeddings_5_1, id: false do |t|
      # One integer call keeps the original column order intact.
      t.integer :topic_id, :model_version, :strategy_version, null: false
      t.text :digest, null: false
      # pgvector column sized to Gemini's 768-dimensional embeddings.
      t.column :embeddings, "vector(768)", null: false
      t.timestamps

      # Exactly one embedding row per topic.
      t.index :topic_id, unique: true
    end
  end
end
# frozen_string_literal: true

module DiscourseAi
  module Embeddings
    module VectorRepresentations
      # Vector representation backed by Google's Gemini embedding API
      # (768-dimensional vectors, cosine distance).
      class Gemini < Base
        def id
          5
        end

        def version
          1
        end

        def name
          "gemini"
        end

        def dimensions
          768
        end

        def max_sequence_length
          2048
        end

        # Cosine-distance operator / index opclass for pgvector.
        def pg_function
          "<=>"
        end

        def pg_index_type
          "vector_cosine_ops"
        end

        # Calls the Gemini inference client and extracts the raw float vector
        # from the API response payload.
        def vector_from(text)
          DiscourseAi::Inference::GeminiEmbeddings.perform!(text).dig(:embedding, :values)
        end

        # Gemini ships no public tokenizer. Of the tokenizers bundled with the
        # plugin, OpenAI's tracks Gemini most closely and overestimates by
        # roughly 10%, so using it errs on the safe side for length checks.
        def tokenizer
          DiscourseAi::Tokenizer::OpenAiTokenizer
        end
      end
    end
  end
end
# frozen_string_literal: true

module ::DiscourseAi
  module Inference
    # Minimal HTTP client for Google's Gemini `embedding-001:embedContent`
    # endpoint.
    class GeminiEmbeddings
      # Requests an embedding for +content+ (a String).
      #
      # Returns the parsed JSON response as a Hash with symbolized keys; the
      # vector itself lives under [:embedding][:values].
      # Raises Net::HTTPBadResponse on any non-200 HTTP status.
      def self.perform!(content)
        headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }

        # NOTE(review): the API key is sent as a query parameter; confirm it
        # cannot leak into request logs upstream.
        url =
          "https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent?key=#{SiteSetting.ai_gemini_api_key}"

        body = { content: { parts: [{ text: content }] } }

        response = Faraday.post(url, body.to_json, headers)

        raise Net::HTTPBadResponse if response.status != 200

        JSON.parse(response.body, symbolize_names: true)
      end
    end
  end
end