mirror of https://github.com/dapr/dapr-agents.git
96 lines
3.8 KiB
Python
96 lines
3.8 KiB
Python
from openai.types.create_embedding_response import CreateEmbeddingResponse
|
|
from dapr_agents.llm.openai.client.base import OpenAIClientBase
|
|
from typing import Union, Dict, Any, Literal, List, Optional
|
|
from pydantic import Field, model_validator
|
|
import logging
|
|
|
|
# Module-level logger named after this module, used by the embedding client below.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class OpenAIEmbeddingClient(OpenAIClientBase):
    """
    Client for handling OpenAI's embedding functionalities, supporting both OpenAI and Azure OpenAI configurations.

    Attributes:
        model (str): The ID of the model to use for embedding. When it is omitted or ``None``,
            it is filled in before validation from ``azure_deployment`` if that is set,
            otherwise it falls back to ``text-embedding-ada-002``.
        encoding_format (Optional[Literal["float", "base64"]]): The format of the embeddings. Defaults to 'float'.
        dimensions (Optional[int]): Number of dimensions for the output embeddings. Only supported in specific models like `text-embedding-3`.
        user (Optional[str]): A unique identifier representing the end-user.
    """

    # Declared with a ``None`` default so callers may omit it; the "before"
    # validator below always replaces a missing/None value with a real model ID.
    model: str = Field(
        default=None, description="ID of the model to use for embedding."
    )
    encoding_format: Optional[Literal["float", "base64"]] = Field(
        "float", description="Format for the embeddings. Defaults to 'float'."
    )
    dimensions: Optional[int] = Field(
        None,
        description="Number of dimensions for the output embeddings. Supported in text-embedding-3 and later models.",
    )
    user: Optional[str] = Field(
        None, description="Unique identifier representing the end-user."
    )

    @model_validator(mode="before")
    @classmethod
    def validate_and_initialize(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """
        Ensures that the 'model' attribute is set during validation.

        If 'model' is missing or ``None``, it is set to the value of
        'azure_deployment' when that is provided and non-empty, and to
        'text-embedding-ada-002' otherwise.

        Args:
            values (Dict[str, Any]): Raw input for the model. Input that is not a
                dict is returned untouched so that Pydantic can validate it itself.

        Returns:
            Dict[str, Any]: The attributes to validate, with 'model' populated. A
                new dict is returned when a default is applied; the caller's dict
                is never modified.
        """
        # "before" validators receive the raw input, which is not guaranteed to
        # be a dict; only dict input can be given a default here.
        if not isinstance(values, dict):
            return values

        if values.get("model") is None:
            # Use `or` rather than a `dict.get` default: the key may be present
            # with a None (or empty) value, which must still fall back to the
            # default model instead of leaving `model` unset.
            fallback_model = values.get("azure_deployment") or "text-embedding-ada-002"
            # Build a new dict instead of mutating the caller's input in place.
            values = {**values, "model": fallback_model}

        return values

    def model_post_init(self, __context: Any) -> None:
        """
        Post-initialization setup for private attributes.

        Sets the private ``_api`` attribute to ``"embeddings"`` and then delegates
        to the parent class's ``model_post_init``.

        Args:
            __context (Any): Context provided during model initialization.
        """
        # NOTE(review): `_api` is presumably a private attribute declared on
        # OpenAIClientBase and read by its post-init logic, hence it is set
        # before delegating - confirm against the base class.
        self._api = "embeddings"
        return super().model_post_init(__context)

    def create_embedding(
        self,
        input: Union[str, List[Union[str, List[int]]]],
        model: Optional[str] = None,
    ) -> CreateEmbeddingResponse:
        """
        Generate embeddings for the given input text(s).

        Args:
            input (Union[str, List[Union[str, List[int]]]]): Input text(s) or tokenized input(s) to generate embeddings for.
                - A single string for one input.
                - A list of strings or tokenized lists for multiple inputs.
            model (Optional[str]): Model to use for embedding. Overrides the default model if provided.

        Returns:
            CreateEmbeddingResponse: A response object containing the generated embeddings and associated metadata.

        Raises:
            Exception: Any error raised by the underlying client's
                ``embeddings.create`` call propagates unchanged; this method
                performs no error handling of its own.
        """
        # Resolve the effective model first so the log line reports the model
        # that is actually sent, including a per-call override.
        effective_model = model or self.model
        logger.info("Using model '%s' for embedding generation.", effective_model)

        return self.client.embeddings.create(
            model=effective_model,
            input=input,
            encoding_format=self.encoding_format,
            dimensions=self.dimensions,
            user=self.user,
        )
|