dapr-agents/dapr_agents/document/embedder/openai.py

from dapr_agents.document.embedder.base import EmbedderBase
from dapr_agents.llm.openai.embeddings import OpenAIEmbeddingClient
from typing import List, Any, Union, Optional
from pydantic import Field, ConfigDict
import numpy as np
import logging

logger = logging.getLogger(__name__)


class OpenAIEmbedder(OpenAIEmbeddingClient, EmbedderBase):
    """
    OpenAI-based embedder for generating text embeddings with handling for long inputs.
    Inherits functionality from OpenAIEmbeddingClient for API interactions.
    """

    max_tokens: int = Field(default=8191, description="Maximum tokens allowed per input.")
    chunk_size: int = Field(default=1000, description="Batch size for embedding requests.")
    normalize: bool = Field(default=True, description="Whether to normalize embeddings.")
    encoding_name: Optional[str] = Field(default=None, description="Token encoding name (if provided).")
    encoder: Optional[Any] = Field(default=None, init=False, description="TikToken encoder.")

    model_config = ConfigDict(arbitrary_types_allowed=True)
    def model_post_init(self, __context: Any) -> None:
        """
        Initialize attributes after model validation.
        Automatically determines the appropriate encoding for the model.
        """
        super().model_post_init(__context)

        try:
            import tiktoken
            from tiktoken.core import Encoding
        except ImportError:
            raise ImportError(
                "The `tiktoken` library is required for tokenizing inputs. "
                "Install it using `pip install tiktoken`."
            )

        if self.encoding_name:
            # Use the explicitly provided encoding
            self.encoder: Encoding = tiktoken.get_encoding(self.encoding_name)
        else:
            # Automatically determine encoding based on the model
            try:
                self.encoder: Encoding = tiktoken.encoding_for_model(self.model)
            except KeyError:
                # Fall back to the default encoding and model
                logger.warning(
                    f"Model '{self.model}' not recognized. "
                    "Defaulting to 'cl100k_base' encoding and 'text-embedding-ada-002' model."
                )
                self.encoder = tiktoken.get_encoding("cl100k_base")
                self.model = "text-embedding-ada-002"
    def _tokenize_text(self, text: str) -> List[int]:
        """Tokenizes the input text using the specified encoding."""
        return self.encoder.encode(text)
    def _chunk_tokens(self, tokens: List[int], chunk_length: int) -> List[List[int]]:
        """Splits tokens into chunks of at most the specified length."""
        return [tokens[i:i + chunk_length] for i in range(0, len(tokens), chunk_length)]
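
    # For illustration (hypothetical values): with chunk_length=4, a 10-token
    # list [t1..t10] splits into [t1..t4], [t5..t8], [t9, t10]. The final chunk
    # may be shorter than the rest, which is why `embed` weights each chunk by
    # its token count when averaging.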
    def _process_embeddings(self, embeddings: List[List[float]], weights: List[int]) -> List[float]:
        """Combines embeddings using weighted averaging."""
        weighted_avg = np.average(embeddings, axis=0, weights=weights)
        if self.normalize:
            norm = np.linalg.norm(weighted_avg)
            return (weighted_avg / norm).tolist()
        return weighted_avg.tolist()
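
    # Worked example (illustrative numbers, not real model output): chunk
    # embeddings [1.0, 0.0] and [0.0, 1.0] with token-count weights 3 and 1
    # average to [0.75, 0.25]; with normalize=True the result is divided by its
    # L2 norm (sqrt(0.625) ~= 0.7906), giving roughly [0.9487, 0.3162], so the
    # combined vector is unit length like the per-chunk embeddings.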
    def embed(self, input: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
        """
        Embeds input text(s), supporting both single and multiple inputs and
        handling long texts via chunking and batching.

        Args:
            input (Union[str, List[str]]): The input text(s) to embed. Can be a single string or a list of strings.

        Returns:
            Union[List[float], List[List[float]]]: Embedding vector(s) for the input(s).
                - Returns a single list of floats for a single string input.
                - Returns a list of lists of floats for a list of string inputs.

        Notes:
            - Handles long inputs by chunking them into smaller parts based on `max_tokens` and reassembling embeddings.
            - Batches API calls for efficiency using `chunk_size`.
            - Automatically combines chunk embeddings using weighted averaging for long inputs.
        """
        # Validate input
        if not input or (isinstance(input, list) and all(not q for q in input)):
            raise ValueError("Input must contain valid text.")

        # Check whether the input is a single string or a list of strings
        single_input = isinstance(input, str)
        input_strings = [input] if single_input else input

        # Tokenize the input strings to find long texts requiring chunking
        tokenized_inputs = [self._tokenize_text(q) for q in input_strings]
        chunks = []  # Holds text chunks for API calls
        chunk_indices = []  # Maps each chunk to its original input index

        # Chunk long inputs and map each chunk back to its originating input
        for idx, tokens in enumerate(tokenized_inputs):
            if len(tokens) <= self.max_tokens:
                # Use the text directly if it is within the max token limit
                chunks.append(self.encoder.decode(tokens))
                chunk_indices.append(idx)
            else:
                # Split long inputs into smaller chunks
                token_chunks = self._chunk_tokens(tokens, self.max_tokens)
                chunks.extend([self.encoder.decode(chunk) for chunk in token_chunks])
                chunk_indices.extend([idx] * len(token_chunks))

        # Process the chunks in batches for efficiency
        batch_size = self.chunk_size
        chunk_embeddings = []  # Holds embeddings for all chunks
        for i in range(0, len(chunks), batch_size):
            batch = chunks[i:i + batch_size]
            response = self.create_embedding(input=batch)  # Batched API call
            chunk_embeddings.extend(r.embedding for r in response.data)

        # Group chunk embeddings by their original input indices
        grouped_embeddings = [[] for _ in range(len(input_strings))]
        for idx, embedding in zip(chunk_indices, chunk_embeddings):
            grouped_embeddings[idx].append(embedding)

        # Combine the chunk embeddings for each input
        results = []
        for embeddings, tokens in zip(grouped_embeddings, tokenized_inputs):
            if len(embeddings) == 1:
                # A single chunk: use its embedding directly
                results.append(embeddings[0])
            else:
                # Combine chunk embeddings using token-count-weighted averaging
                weights = [len(chunk) for chunk in self._chunk_tokens(tokens, self.max_tokens)]
                results.append(self._process_embeddings(embeddings, weights))

        # Return a single embedding if the input was a single string; otherwise, return a list
        return results[0] if single_input else results
    def __call__(self, input: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
        """
        Allows the instance to be called directly to embed text(s).

        Args:
            input (Union[str, List[str]]): The input text(s) to embed.

        Returns:
            Union[List[float], List[List[float]]]: Embedding vector(s) for the input(s).
        """
        return self.embed(input)
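

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only). The `model` field is used by
    # this class, but any other constructor parameters (e.g. API credentials)
    # depend on the inherited OpenAIEmbeddingClient; this sketch assumes the
    # API key is picked up from the environment.
    embedder = OpenAIEmbedder(model="text-embedding-ada-002")

    # A single string returns one vector (List[float]).
    vector = embedder("Dapr Agents embeds documents for retrieval.")
    print(f"Single input -> vector of dimension {len(vector)}")

    # A list of strings returns one vector per input (List[List[float]]).
    # Inputs longer than `max_tokens` are transparently chunked, embedded in
    # batches of `chunk_size`, and recombined by weighted averaging.
    vectors = embedder.embed(["first document", "second document"])
    print(f"Batch input -> {len(vectors)} vectors")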