Merge pull request #2 from dapr-sandbox/feature/sentencetransformer-cachedir

Add SentenceTransformer Cache Directory
This commit is contained in:
Roberto Rodriguez 2025-01-25 14:10:25 -08:00 committed by GitHub
commit 1a323f525a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 24 additions and 6 deletions

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "floki-ai"
version = "0.10.2"
version = "0.11.2"
description = "Agentic Workflows Made Simple"
readme = "README.md"
authors = [{ name = "Roberto Rodriguez" }]
@@ -49,7 +49,7 @@ homepage = "https://github.com/Cyb3rWard0g/floki"
[tool.poetry]
name = "floki"
version = "0.10.2"
version = "0.11.2"
description = "Agentic Workflows Made Simple"
authors = ["Roberto Rodriguez"]
license = "MIT"

View File

@@ -5,7 +5,7 @@ with open("README.md", encoding="utf-8") as f:
setup(
name="floki-ai",
version="0.10.2",
version="0.11.2",
author="Roberto Rodriguez",
description="Agentic Workflows Made Simple",
long_description=long_description,

View File

@@ -2,6 +2,7 @@ from floki.document.embedder.base import EmbedderBase
from typing import List, Any, Optional, Union, Literal
from pydantic import Field
import logging
import os
logger = logging.getLogger(__name__)
@@ -15,6 +16,7 @@ class SentenceTransformerEmbedder(EmbedderBase):
device: Literal["cpu", "cuda", "mps", "npu"] = Field(default="cpu", description="Device for computation.")
normalize_embeddings: bool = Field(default=False, description="Whether to normalize embeddings.")
multi_process: bool = Field(default=False, description="Whether to use multi-process encoding.")
cache_dir: Optional[str] = Field(default=None, description="Directory to cache or load the model.")
client: Optional[Any] = Field(default=None, init=False, description="Loaded SentenceTransformer model.")
@@ -32,9 +34,25 @@ class SentenceTransformerEmbedder(EmbedderBase):
"Install it using `pip install sentence-transformers`."
)
logger.info(f"Loading SentenceTransformer model: {self.model}")
self.client: SentenceTransformer = SentenceTransformer(model_name_or_path=self.model, device=self.device)
logger.info("Model loaded successfully.")
# Determine whether to load from cache or download
model_path = self.cache_dir if self.cache_dir and os.path.exists(self.cache_dir) else self.model
# Attempt to load the model
try:
if os.path.exists(model_path):
logger.info(f"Loading SentenceTransformer model from local path: {model_path}")
else:
logger.info(f"Downloading SentenceTransformer model: {self.model}")
if self.cache_dir:
logger.info(f"Model will be cached to: {self.cache_dir}")
self.client: SentenceTransformer = SentenceTransformer(model_name_or_path=model_path, device=self.device)
logger.info("Model loaded successfully.")
except Exception as e:
logger.error(f"Failed to load SentenceTransformer model: {e}")
raise
# Save to cache directory if downloaded
if model_path == self.model and self.cache_dir and not os.path.exists(self.cache_dir):
logger.info(f"Saving the downloaded model to: {self.cache_dir}")
self.client.save(self.cache_dir)
def embed(self, input: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
"""