diff --git a/lib/modules/embeddings/topic.rb b/lib/modules/embeddings/topic.rb
index 23e4827e..b513cae7 100644
--- a/lib/modules/embeddings/topic.rb
+++ b/lib/modules/embeddings/topic.rb
@@ -33,22 +33,28 @@ module DiscourseAi
       def asymmetric_semantic_search(model, query, limit, offset)
         embedding = model.generate_embedding(query)
 
-        candidate_ids =
-          DiscourseAi::Database::Connection
-            .db
-            .query(<<~SQL, query_embedding: embedding, limit: limit, offset: offset)
-            SELECT
-              topic_id
-            FROM
-              topic_embeddings_#{model.name.underscore}
-            ORDER BY
-              embedding #{model.pg_function} '[:query_embedding]'
-            LIMIT :limit
-            OFFSET :offset
-          SQL
-            .map(&:topic_id)
-
-        raise StandardError, "No embeddings found for topic #{topic.id}" if candidate_ids.empty?
+        begin
+          candidate_ids =
+            DiscourseAi::Database::Connection
+              .db
+              .query(<<~SQL, query_embedding: embedding, limit: limit, offset: offset)
+              SELECT
+                topic_id
+              FROM
+                topic_embeddings_#{model.name.underscore}
+              ORDER BY
+                embedding #{model.pg_function} '[:query_embedding]'
+              LIMIT :limit
+              OFFSET :offset
+            SQL
+              .map(&:topic_id)
+        rescue PG::Error => e
+          # No topic is in scope for a free-text search, so only the model is logged.
+          Rails.logger.error(
+            "Error #{e} querying embeddings for model #{model.name}",
+          )
+          raise MissingEmbeddingError
+        end
 
         candidate_ids
       end
@@ -56,32 +62,49 @@ module DiscourseAi
       private
 
       def query_symmetric_embeddings(model, topic)
-        DiscourseAi::Database::Connection.db.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
-          SELECT
-            topic_id
-          FROM
-            topic_embeddings_#{model.name.underscore}
-          ORDER BY
-            embedding #{model.pg_function} (
-              SELECT
-                embedding
-              FROM
-                topic_embeddings_#{model.name.underscore}
-              WHERE
-                topic_id = :topic_id
-              LIMIT 1
-            )
-          LIMIT 100
-        SQL
+        begin
+          DiscourseAi::Database::Connection.db.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
+            SELECT
+              topic_id
+            FROM
+              topic_embeddings_#{model.name.underscore}
+            ORDER BY
+              embedding #{model.pg_function} (
+                SELECT
+                  embedding
+                FROM
+                  topic_embeddings_#{model.name.underscore}
+                WHERE
+                  topic_id = :topic_id
+                LIMIT 1
+              )
+            LIMIT 100
+          SQL
+        rescue PG::Error => e
+          Rails.logger.error(
+            "Error #{e} querying embeddings for topic #{topic.id} and model #{model.name}",
+          )
+          raise MissingEmbeddingError
+        end
       end
       def persist_embedding(topic, model, embedding)
-        DiscourseAi::Database::Connection.db.exec(<<~SQL, topic_id: topic.id, embedding: embedding)
-          INSERT INTO topic_embeddings_#{model.name.underscore} (topic_id, embedding)
-          VALUES (:topic_id, '[:embedding]')
-          ON CONFLICT (topic_id)
-          DO UPDATE SET embedding = '[:embedding]'
-        SQL
+        begin
+          DiscourseAi::Database::Connection.db.exec(
+            <<~SQL,
+            INSERT INTO topic_embeddings_#{model.name.underscore} (topic_id, embedding)
+            VALUES (:topic_id, '[:embedding]')
+            ON CONFLICT (topic_id)
+            DO UPDATE SET embedding = '[:embedding]'
+          SQL
+            topic_id: topic.id,
+            embedding: embedding,
+          )
+        rescue PG::Error => e
+          Rails.logger.error(
+            "Error #{e} persisting embedding for topic #{topic.id} and model #{model.name}",
+          )
+        end
       end
     end
   end
 end
diff --git a/spec/requests/topic_spec.rb b/spec/requests/topic_spec.rb
index fd7b0e99..41eb0000 100644
--- a/spec/requests/topic_spec.rb
+++ b/spec/requests/topic_spec.rb
@@ -25,6 +25,7 @@ describe ::TopicsController do
         .returns([topic1.id, topic2.id, topic3.id])
 
       get("#{topic.relative_url}.json")
+      expect(response.status).to eq(200)
       json = response.parsed_body
 
       expect(json["suggested_topics"].length).to eq(0)
@@ -38,5 +39,16 @@ describe ::TopicsController do
       expect(json["suggested_topics"].length).to eq(0)
       expect(json["related_topics"].length).to eq(2)
     end
+
+    it "excludes embeddings when the database is offline" do
+      DiscourseAi::Database::Connection.stubs(:db).raises(PG::ConnectionBad)
+
+      get "#{topic.relative_url}.json"
+      expect(response.status).to eq(200)
+      json = response.parsed_body
+
+      expect(json["suggested_topics"].length).not_to eq(0)
+      expect(json["related_topics"].length).to eq(0)
+    end
   end
 end