mirror of https://github.com/dapr/dapr-agents.git

Merge pull request #2 from dapr-sandbox/feature/sentencetransformer-cachedir

Add SentenceTransformer Cache Directory

This commit is contained in:
commit 1a323f525a
pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "floki-ai"
-version = "0.10.2"
+version = "0.11.2"
 description = "Agentic Workflows Made Simple"
 readme = "README.md"
 authors = [{ name = "Roberto Rodriguez" }]
@@ -49,7 +49,7 @@ homepage = "https://github.com/Cyb3rWard0g/floki"
 
 [tool.poetry]
 name = "floki"
-version = "0.10.2"
+version = "0.11.2"
 description = "Agentic Workflows Made Simple"
 authors = ["Roberto Rodriguez"]
 license = "MIT"
setup.py

@@ -5,7 +5,7 @@ with open("README.md", encoding="utf-8") as f:
 
 setup(
     name="floki-ai",
-    version="0.10.2",
+    version="0.11.2",
     author="Roberto Rodriguez",
     description="Agentic Workflows Made Simple",
     long_description=long_description,
			
			
@@ -2,6 +2,7 @@ from floki.document.embedder.base import EmbedderBase
 from typing import List, Any, Optional, Union, Literal
 from pydantic import Field
 import logging
+import os
 
 logger = logging.getLogger(__name__)
 
@@ -15,6 +16,7 @@ class SentenceTransformerEmbedder(EmbedderBase):
     device: Literal["cpu", "cuda", "mps", "npu"] = Field(default="cpu", description="Device for computation.")
     normalize_embeddings: bool = Field(default=False, description="Whether to normalize embeddings.")
     multi_process: bool = Field(default=False, description="Whether to use multi-process encoding.")
+    cache_dir: Optional[str] = Field(default=None, description="Directory to cache or load the model.")
 
     client: Optional[Any] = Field(default=None, init=False, description="Loaded SentenceTransformer model.")
 
			
@@ -32,9 +34,25 @@ class SentenceTransformerEmbedder(EmbedderBase):
                 "Install it using `pip install sentence-transformers`."
             )
 
-        logger.info(f"Loading SentenceTransformer model: {self.model}")
-        self.client: SentenceTransformer = SentenceTransformer(model_name_or_path=self.model, device=self.device)
-        logger.info("Model loaded successfully.")
+        # Determine whether to load from cache or download
+        model_path = self.cache_dir if self.cache_dir and os.path.exists(self.cache_dir) else self.model
+        # Attempt to load the model
+        try:
+            if os.path.exists(model_path):
+                logger.info(f"Loading SentenceTransformer model from local path: {model_path}")
+            else:
+                logger.info(f"Downloading SentenceTransformer model: {self.model}")
+                if self.cache_dir:
+                    logger.info(f"Model will be cached to: {self.cache_dir}")
+            self.client: SentenceTransformer = SentenceTransformer(model_name_or_path=model_path, device=self.device)
+            logger.info("Model loaded successfully.")
+        except Exception as e:
+            logger.error(f"Failed to load SentenceTransformer model: {e}")
+            raise
+        # Save to cache directory if downloaded
+        if model_path == self.model and self.cache_dir and not os.path.exists(self.cache_dir):
+            logger.info(f"Saving the downloaded model to: {self.cache_dir}")
+            self.client.save(self.cache_dir)
 
     def embed(self, input: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
         """
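For reference, a minimal usage sketch of the new cache_dir option introduced by this commit. The SentenceTransformerEmbedder class, its model and device fields, and the embed() signature come from the diff above; the import path, model name, and cache path below are illustrative assumptions, not part of the change.

from floki.document.embedder import SentenceTransformerEmbedder  # import path is an assumption

# First run: ./model_cache does not exist yet, so the model is downloaded
# by name and then saved to cache_dir (the save() call in the diff).
embedder = SentenceTransformerEmbedder(
    model="all-MiniLM-L6-v2",      # illustrative model name
    device="cpu",
    cache_dir="./model_cache",     # illustrative local path
)
vector = embedder.embed("Dapr agents make workflows simple.")

# Later runs: cache_dir exists, so model_path resolves to ./model_cache
# and the model is loaded locally instead of being downloaded again.
embedder = SentenceTransformerEmbedder(
    model="all-MiniLM-L6-v2",
    device="cpu",
    cache_dir="./model_cache",
)
vectors = embedder.embed(["first document", "second document"])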