# frozen_string_literal: true

module DiscourseAi
  module Embeddings
    # Embedding-based ("semantic") search over topics and posts.
    #
    # Query-side work is cached in `Discourse.cache` for one week. Keys are built
    # from the SHA1 of the search term plus the configured models:
    #   "semantic-search-<digest>-<hyde model>"                    HyDE hypothetical post
    #   "semantic-search-<digest>-<hyde model>-<embedding model>"  embedding of that post
    #   "semantic-search-<digest>--<embedding model>"              embedding of the raw term
    #   "semantic-search-quick-<digest>-<embedding model>"         asymmetric raw-term embedding (quick search)
    class SemanticSearch
      # Deletes every cache entry this class may have stored for +query+ under the
      # currently configured HyDE and embedding models.
      #
      # @param query [String] search term whose cached entries should be dropped
      def self.clear_cache_for(query)
        digest = OpenSSL::Digest::SHA1.hexdigest(query)
        hyde_model = SiteSetting.ai_embeddings_semantic_search_hyde_model
        embedding_model = SiteSetting.ai_embeddings_selected_model

        # These key formats must stay in sync with the private build_*_key helpers.
        hyde_key = "semantic-search-#{digest}-#{hyde_model}"

        Discourse.cache.delete(hyde_key)
        Discourse.cache.delete("#{hyde_key}-#{embedding_model}")
        # Raw-term embedding written by #embedding (empty HyDE-model segment).
        Discourse.cache.delete("semantic-search-#{digest}--#{embedding_model}")
        # Asymmetric raw-term embedding written by #quick_search.
        Discourse.cache.delete("semantic-search-quick-#{digest}-#{embedding_model}")
      end

      # @param guardian permission context; its `user` and
      #   `filter_allowed_categories` are used to scope results
      def initialize(guardian)
        @guardian = guardian
      end

      # Whether the embedding of the HyDE post for +query+ is present in the cache
      # for the currently configured HyDE and embedding models.
      def cached_query?(query)
        digest = OpenSSL::Digest::SHA1.hexdigest(query)
        embedding_key =
          build_embedding_key(
            digest,
            SiteSetting.ai_embeddings_semantic_search_hyde_model,
            SiteSetting.ai_embeddings_selected_model,
          )

        Discourse.cache.read(embedding_key).present?
      end

      # Memoized embeddings client.
      def vector
        @vector ||= DiscourseAi::Embeddings::Vector.instance
      end

      # Embeds a hypothetical post (HyDE) generated by the LLM from +search_term+.
      # Both the generated post and its embedding are cached for one week.
      #
      # @param search_term [String]
      # @return the embedding of the hypothetical post, or nil when no post could be
      #   generated (e.g. no HyDE persona or LLM is configured). Blank posts are neither
      #   cached nor embedded, so callers can fall back to embedding the raw term.
      def hyde_embedding(search_term)
        digest = OpenSSL::Digest::SHA1.hexdigest(search_term)
        hyde_model = SiteSetting.ai_embeddings_semantic_search_hyde_model
        hyde_key = build_hyde_key(digest, hyde_model)
        embedding_key =
          build_embedding_key(digest, hyde_model, SiteSetting.ai_embeddings_selected_model)

        hypothetical_post = Discourse.cache.read(hyde_key)
        if hypothetical_post.blank?
          hypothetical_post = hypothetical_post_from(search_term)
          # A nil/empty post would otherwise be cached for a week and sent to the
          # embedding model.
          return nil if hypothetical_post.blank?

          Discourse.cache.write(hyde_key, hypothetical_post, expires_in: 1.week)
        end

        Discourse
          .cache
          .fetch(embedding_key, expires_in: 1.week) { vector.vector_from(hypothetical_post) }
      end

      # Embeds +search_term+ itself (no HyDE), caching the result for one week.
      def embedding(search_term)
        digest = OpenSSL::Digest::SHA1.hexdigest(search_term)
        embedding_key = build_embedding_key(digest, "", SiteSetting.ai_embeddings_selected_model)

        Discourse.cache.fetch(embedding_key, expires_in: 1.week) { vector.vector_from(search_term) }
      end

      # this ensures the candidate topics are over selected
      # that way we have a much better chance of finding topics
      # if the user filtered the results or index is a bit out of date
      OVER_SELECTION_FACTOR = 4

      # Finds topics semantically similar to +query+.
      #
      # @param query [String] raw query; Search extracts the term and filters from it
      # @param page [Integer] 1-based page number
      # @param hyde [Boolean] try the HyDE embedding first; the raw term is embedded
      #   when that yields nothing
      # @return a relation of first posts (post_number 1) of candidate topics, in the
      #   order returned by the similarity search, narrowed by the query's filters and
      #   the guardian's allowed categories; Post.none for a blank/too-short term or
      #   when the similarity search finds no candidates
      def search_for_topics(query, page = 1, hyde: true)
        max_results_per_page = 100
        # One row over the page size — presumably so callers can detect a further
        # page (TODO confirm against callers).
        limit = [Search.per_filter, max_results_per_page].min + 1
        offset = (page - 1) * limit
        search = Search.new(query, { guardian: guardian })
        search_term = search.term

        if search_term.blank? || search_term.length < SiteSetting.min_search_term_length
          return Post.none
        end

        search_embedding = nil
        search_embedding = hyde_embedding(search_term) if hyde
        search_embedding = embedding(search_term) if search_embedding.blank?

        over_selection_limit = limit * OVER_SELECTION_FACTOR

        schema = DiscourseAi::Embeddings::Schema.for(Topic)

        candidate_topic_ids =
          schema.asymmetric_similarity_search(
            search_embedding,
            limit: over_selection_limit,
            offset: offset,
          ).map(&:topic_id)

        semantic_results =
          visible_public_posts_in_order(:topic_id, candidate_topic_ids)
            .where(post_number: 1)
            .limit(limit)

        query_filter_results = search.apply_filters(semantic_results)

        guardian.filter_allowed_categories(query_filter_results)
      end

      # Fast semantic search over posts: embeds the raw term asymmetrically (no HyDE),
      # takes up to 100 nearest posts, applies the query's filters, reranks them with
      # the HuggingFace reranker and keeps the top 5.
      #
      # @param query [String] raw query; Search extracts the term and filters from it
      # @return a relation of at most 5 posts in reranked order, restricted to the
      #   guardian's allowed categories; [] for a nil/too-short term or when no
      #   candidate post survives filtering
      def quick_search(query)
        max_semantic_results_per_page = 100
        search = Search.new(query, { guardian: guardian })
        search_term = search.term

        return [] if search_term.nil? || search_term.length < SiteSetting.min_search_term_length

        digest = OpenSSL::Digest::SHA1.hexdigest(search_term)
        # Own key: this is the raw-term (asymmetric) embedding, which must not share a
        # cache slot with the HyDE-post embedding stored by #hyde_embedding.
        embedding_key = build_quick_search_key(digest, SiteSetting.ai_embeddings_selected_model)

        search_term_embedding =
          Discourse
            .cache
            .fetch(embedding_key, expires_in: 1.week) do
              vector.vector_from(search_term, asymetric: true)
            end

        candidate_post_ids =
          DiscourseAi::Embeddings::Schema
            .for(Post)
            .asymmetric_similarity_search(
              search_term_embedding,
              limit: max_semantic_results_per_page,
              offset: 0,
            )
            .map(&:post_id)

        filtered_results =
          search.apply_filters(visible_public_posts_in_order(:id, candidate_post_ids))

        # Plain text of each cooked post, capped at 2000 chars, as reranker input.
        rerank_posts_payload =
          filtered_results.map do |post|
            Nokogiri::HTML5.fragment(post.cooked).text.truncate(2000, omission: "")
          end

        # Nothing to rerank: skip the remote reranker call entirely.
        return [] if rerank_posts_payload.empty?

        reranked_results =
          DiscourseAi::Inference::HuggingFaceTextEmbeddings.rerank(
            search_term,
            rerank_posts_payload,
          )

        # Each reranker result carries the :index of a document in the payload, which
        # lines up with filtered_results.
        reordered_ids =
          reranked_results.take(5).map { |result| filtered_results[result[:index]].id }

        guardian.filter_allowed_categories(visible_public_posts_in_order(:id, reordered_ids))
      end

      # Asks the HyDE persona's bot to write a hypothetical post answering
      # +search_term+.
      #
      # @return [String, nil] the structured-output property named by the persona's
      #   first response_format entry when the bot streams structured output, else
      #   the concatenated plain completion; nil when no bot can be built
      def hypothetical_post_from(search_term)
        bot = build_bot(guardian.user)
        return nil if bot.nil?

        context =
          DiscourseAi::Personas::BotContext.new(
            user: guardian.user,
            skip_tool_details: true,
            feature_name: "semantic_search_hyde",
            messages: [{ type: :user, content: search_term }],
          )

        structured_output = nil
        raw_response = +""
        # {} when the persona declares no response_format.
        hyde_schema_key = bot.persona.response_format&.first.to_h

        bot.reply(context) do |partial, _, type|
          if type == :structured_output
            structured_output = partial
          elsif type.blank?
            # Assume response is a regular completion.
            raw_response << partial
          end
        end

        structured_output&.read_buffered_property(hyde_schema_key["key"]&.to_sym) || raw_response
      end

      # Priorities are:
      # 1. Persona's default LLM
      # 2. `ai_embeddings_semantic_search_hyde_model` setting.
      #
      # The setting value may carry a "prefix:" before the id; only the part after
      # the last ":" is used.
      def find_ai_hyde_model(persona_klass)
        model_id =
          persona_klass.default_llm_id ||
            SiteSetting.ai_embeddings_semantic_search_hyde_model&.split(":")&.last

        return if model_id.blank?

        LlmModel.find_by(id: model_id)
      end

      private

      attr_reader :guardian

      # Cache key for the HyDE hypothetical post.
      def build_hyde_key(digest, hyde_model)
        "semantic-search-#{digest}-#{hyde_model}"
      end

      # Cache key for an embedding; pass "" as +hyde_model+ for raw-term embeddings.
      def build_embedding_key(digest, hyde_model, embedding_model)
        "#{build_hyde_key(digest, hyde_model)}-#{embedding_model}"
      end

      # Cache key for the asymmetric raw-term embedding used by #quick_search.
      def build_quick_search_key(digest, embedding_model)
        "semantic-search-quick-#{digest}-#{embedding_model}"
      end

      # Visible, public posts whose +column+ (:id or :topic_id) is in +ids+, ordered
      # to match +ids+. Returns Post.none for an empty list, because the untyped
      # `ARRAY[]` literal it would otherwise generate is rejected by PostgreSQL.
      def visible_public_posts_in_order(column, ids)
        return ::Post.none if ids.empty?

        ::Post
          .where(post_type: ::Topic.visible_post_types(guardian.user))
          .public_posts
          .where("topics.visible")
          .where(column => ids)
          .order("array_position(ARRAY#{ids}, posts.#{column})")
      end

      # Bot for the persona configured in `ai_embeddings_semantic_search_hyde_persona`,
      # or nil when that persona or a usable LLM cannot be found.
      def build_bot(user)
        persona_id = SiteSetting.ai_embeddings_semantic_search_hyde_persona

        persona_klass = AiPersona.find_by(id: persona_id)&.class_instance
        return if persona_klass.nil?

        llm_model = find_ai_hyde_model(persona_klass)
        return if llm_model.nil?

        DiscourseAi::Personas::Bot.as(user, persona: persona_klass.new, model: llm_model)
      end
    end
  end
end