mirror of https://github.com/dapr/dapr-agents.git
Merge pull request #2 from dapr-sandbox/feature/sentencetransformer-cachedir
Add SentenceTransformer Cache Directory
Commit 1a323f525a
pyproject.toml:

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "floki-ai"
-version = "0.10.2"
+version = "0.11.2"
 description = "Agentic Workflows Made Simple"
 readme = "README.md"
 authors = [{ name = "Roberto Rodriguez" }]
@@ -49,7 +49,7 @@ homepage = "https://github.com/Cyb3rWard0g/floki"
 [tool.poetry]
 name = "floki"
-version = "0.10.2"
+version = "0.11.2"
 description = "Agentic Workflows Made Simple"
 authors = ["Roberto Rodriguez"]
 license = "MIT"
setup.py (2 changes):

@@ -5,7 +5,7 @@ with open("README.md", encoding="utf-8") as f:
 setup(
     name="floki-ai",
-    version="0.10.2",
+    version="0.11.2",
     author="Roberto Rodriguez",
     description="Agentic Workflows Made Simple",
     long_description=long_description,
SentenceTransformer embedder (under floki/document/embedder/):

@@ -2,6 +2,7 @@ from floki.document.embedder.base import EmbedderBase
 from typing import List, Any, Optional, Union, Literal
 from pydantic import Field
 import logging
+import os
 
 logger = logging.getLogger(__name__)
 
@@ -15,6 +16,7 @@ class SentenceTransformerEmbedder(EmbedderBase):
     device: Literal["cpu", "cuda", "mps", "npu"] = Field(default="cpu", description="Device for computation.")
     normalize_embeddings: bool = Field(default=False, description="Whether to normalize embeddings.")
     multi_process: bool = Field(default=False, description="Whether to use multi-process encoding.")
+    cache_dir: Optional[str] = Field(default=None, description="Directory to cache or load the model.")
 
     client: Optional[Any] = Field(default=None, init=False, description="Loaded SentenceTransformer model.")
 
@@ -32,9 +34,25 @@ class SentenceTransformerEmbedder(EmbedderBase):
                 "Install it using `pip install sentence-transformers`."
             )
 
-        logger.info(f"Loading SentenceTransformer model: {self.model}")
-        self.client: SentenceTransformer = SentenceTransformer(model_name_or_path=self.model, device=self.device)
-        logger.info("Model loaded successfully.")
+        # Determine whether to load from cache or download
+        model_path = self.cache_dir if self.cache_dir and os.path.exists(self.cache_dir) else self.model
+        # Attempt to load the model
+        try:
+            if os.path.exists(model_path):
+                logger.info(f"Loading SentenceTransformer model from local path: {model_path}")
+            else:
+                logger.info(f"Downloading SentenceTransformer model: {self.model}")
+                if self.cache_dir:
+                    logger.info(f"Model will be cached to: {self.cache_dir}")
+            self.client: SentenceTransformer = SentenceTransformer(model_name_or_path=model_path, device=self.device)
+            logger.info("Model loaded successfully.")
+        except Exception as e:
+            logger.error(f"Failed to load SentenceTransformer model: {e}")
+            raise
+        # Save to cache directory if downloaded
+        if model_path == self.model and self.cache_dir and not os.path.exists(self.cache_dir):
+            logger.info(f"Saving the downloaded model to: {self.cache_dir}")
+            self.client.save(self.cache_dir)
 
     def embed(self, input: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
         """
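As a quick illustration of what this commit enables, here is a hedged usage sketch of the new cache_dir option. SentenceTransformerEmbedder, cache_dir, device, and embed() come from the diff above; the import path and model id are assumptions for illustration, not values from this repository.

# Usage sketch (hypothetical import path and model id).
from floki.document.embedder import SentenceTransformerEmbedder

embedder = SentenceTransformerEmbedder(
    model="all-MiniLM-L6-v2",      # assumed sentence-transformers model id
    device="cpu",
    cache_dir="./st-model-cache",  # new in this commit: load from / save to this directory
)

# Constructing the embedder downloads the model on the first run and saves it to
# cache_dir; subsequent runs find the directory on disk and load it locally.
vector = embedder.embed("Dapr agents make workflows simple.")    # str -> List[float]
vectors = embedder.embed(["first document", "second document"])  # list -> List[List[float]]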
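The core of the new loading logic, distilled outside the class: a minimal standalone sketch that assumes sentence-transformers is installed; the model id and cache path are placeholders.

# Standalone sketch of the cache-or-download pattern used above; values are illustrative.
import os
from sentence_transformers import SentenceTransformer

model_id = "all-MiniLM-L6-v2"
cache_dir = "./st-model-cache"

# Prefer an existing local cache; otherwise fall back to the model id (download).
model_path = cache_dir if os.path.exists(cache_dir) else model_id
model = SentenceTransformer(model_name_or_path=model_path, device="cpu")

# Persist a freshly downloaded model so the next run loads it from disk.
if model_path == model_id and not os.path.exists(cache_dir):
    model.save(cache_dir)

embeddings = model.encode(["hello", "world"], normalize_embeddings=True)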