mirror of https://github.com/dapr/dapr-agents.git
153 lines
7.0 KiB
Python
153 lines
7.0 KiB
Python
from dapr_agents.document.embedder.base import EmbedderBase
|
|
from dapr_agents.llm.openai.embeddings import OpenAIEmbeddingClient
|
|
from typing import List, Any, Union, Optional
|
|
from pydantic import Field, ConfigDict
|
|
import numpy as np
|
|
import logging
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
class OpenAIEmbedder(OpenAIEmbeddingClient, EmbedderBase):
    """
    OpenAI-based embedder for generating text embeddings with handling for long inputs.

    Inherits functionality from OpenAIEmbeddingClient for API interactions.

    Inputs whose token count exceeds `max_tokens` are split into token chunks,
    each chunk is embedded separately, and the chunk embeddings are merged with
    an average weighted by chunk token count. Inputs that fit in a single chunk
    return the API embedding unchanged; `normalize` only applies to merged
    (multi-chunk) embeddings.
    """

    max_tokens: int = Field(default=8191, description="Maximum tokens allowed per input.")
    chunk_size: int = Field(default=1000, description="Batch size for embedding requests.")
    normalize: bool = Field(default=True, description="Whether to normalize embeddings.")
    encoding_name: Optional[str] = Field(default=None, description="Token encoding name (if provided).")
    encoder: Optional[Any] = Field(default=None, init=False, description="TikToken Encoder")

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def model_post_init(self, __context: Any) -> None:
        """
        Initialize attributes after model validation.

        Resolves the tiktoken encoder: `encoding_name` is used when provided,
        otherwise the encoding is looked up from `self.model`. If the model is
        not recognized, falls back to the 'cl100k_base' encoding and resets
        `self.model` to 'text-embedding-ada-002'.

        Raises:
            ImportError: If the `tiktoken` library is not installed.
        """
        super().model_post_init(__context)

        try:
            import tiktoken
        except ImportError as exc:
            # Chain the original error so the underlying import failure stays visible.
            raise ImportError(
                "The `tiktoken` library is required for tokenizing inputs. "
                "Install it using `pip install tiktoken`."
            ) from exc

        if self.encoding_name:
            # Use the explicitly provided encoding
            self.encoder = tiktoken.get_encoding(self.encoding_name)
            return

        # Automatically determine encoding based on model
        try:
            self.encoder = tiktoken.encoding_for_model(self.model)
        except KeyError:
            # Fallback to default encoding and model
            logger.warning(
                f"Model '{self.model}' not recognized. "
                "Defaulting to 'cl100k_base' encoding and 'text-embedding-ada-002' model."
            )
            self.encoder = tiktoken.get_encoding("cl100k_base")
            self.model = "text-embedding-ada-002"

    def _tokenize_text(self, text: str) -> List[int]:
        """Tokenizes the input text using the configured encoder."""
        return self.encoder.encode(text)

    def _chunk_tokens(self, tokens: List[int], chunk_length: int) -> List[List[int]]:
        """
        Splits tokens into consecutive chunks of at most `chunk_length` tokens.

        Raises:
            ValueError: If `chunk_length` is not positive.
        """
        if chunk_length <= 0:
            # A zero step makes range() fail and a negative one silently yields
            # no chunks; fail loudly with a clear message instead.
            raise ValueError("chunk_length must be a positive integer.")
        return [tokens[i:i + chunk_length] for i in range(0, len(tokens), chunk_length)]

    def _process_embeddings(self, embeddings: List[List[float]], weights: List[int]) -> List[float]:
        """
        Combines embeddings using weighted averaging.

        Args:
            embeddings (List[List[float]]): One embedding vector per chunk.
            weights (List[int]): One weight per embedding (the chunk token counts).

        Returns:
            List[float]: The weighted average, scaled to unit L2 norm when
            `normalize` is enabled and the average is not the zero vector.
        """
        weighted_avg = np.average(embeddings, axis=0, weights=weights)
        if self.normalize:
            norm = np.linalg.norm(weighted_avg)
            # Dividing a zero vector by its norm would produce NaNs; leave it as-is.
            if norm > 0:
                weighted_avg = weighted_avg / norm
        return weighted_avg.tolist()

    def embed(self, input: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
        """
        Embeds input text(s) with support for both single and multiple inputs, handling long texts via chunking and batching.

        Args:
            input (Union[str, List[str]]): The input text(s) to embed. Can be a single string or a list of strings.

        Returns:
            Union[List[float], List[List[float]]]: Embedding vector(s) for the input(s).
                - Returns a single list of floats for a single string input.
                - Returns a list of lists of floats for a list of string inputs.

        Raises:
            ValueError: If the input is empty, if any entry of a list input is
                empty, or if `chunk_size` / `max_tokens` is not positive.

        Notes:
            - Handles long inputs by chunking them into smaller parts based on `max_tokens` and reassembling embeddings.
            - Batches API calls for efficiency using `chunk_size`.
            - Automatically combines chunk embeddings using weighted averaging for long inputs.
        """
        # Validate input
        if not input:
            raise ValueError("Input must contain valid text.")

        # Check if the input is a single string or a list of strings
        single_input = isinstance(input, str)
        texts = [input] if single_input else list(input)

        # Reject any empty entry up front: an empty string would otherwise be
        # forwarded to the embeddings endpoint as-is.
        if not texts or any(not text for text in texts):
            raise ValueError("Input must contain valid text.")

        batch_size = self.chunk_size
        if batch_size <= 0:
            raise ValueError("chunk_size must be a positive integer.")

        chunks: List[str] = []  # Text pieces sent to the API, in request order
        chunk_owners: List[int] = []  # Index of the input each piece belongs to
        chunk_weights: List[List[int]] = [[] for _ in texts]  # Token counts per piece, per input

        # Tokenize each input to decide whether it needs to be split
        for idx, text in enumerate(texts):
            tokens = self._tokenize_text(text)
            if len(tokens) <= self.max_tokens:
                # Within the limit: send the original text untouched
                chunks.append(text)
                chunk_owners.append(idx)
                chunk_weights[idx].append(len(tokens))
                continue

            # Too long: split on token boundaries and decode each piece back to text
            for piece in self._chunk_tokens(tokens, self.max_tokens):
                chunks.append(self.encoder.decode(piece))
                chunk_owners.append(idx)
                chunk_weights[idx].append(len(piece))

        # Process the chunks in batches for efficiency
        chunk_embeddings = []
        for start in range(0, len(chunks), batch_size):
            batch = chunks[start:start + batch_size]
            response = self.create_embedding(input=batch)  # Batch API call
            chunk_embeddings.extend(item.embedding for item in response.data)

        # Group chunk embeddings by the input they came from
        grouped_embeddings = [[] for _ in texts]
        for owner, embedding in zip(chunk_owners, chunk_embeddings):
            grouped_embeddings[owner].append(embedding)

        # Combine chunk embeddings for each input
        results = []
        for embeddings, weights in zip(grouped_embeddings, chunk_weights):
            if len(embeddings) == 1:
                # Only one chunk: use its embedding directly
                results.append(embeddings[0])
            else:
                # Several chunks: merge them, weighting by chunk token count
                results.append(self._process_embeddings(embeddings, weights))

        # Return a single embedding if the input was a single string; otherwise, return a list
        return results[0] if single_input else results

    def __call__(self, input: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
        """
        Allows the instance to be called directly to embed text(s).

        Args:
            input (Union[str, List[str]]): The input text(s) to embed.

        Returns:
            Union[List[float], List[List[float]]]: Embedding vector(s) for the input(s).
        """
        return self.embed(input)
|