from dapr_agents.types.llm import AzureOpenAIModelConfig, OpenAIModelConfig
from dapr_agents.llm.utils import RequestHandler, ResponseHandler
from dapr_agents.llm.openai.client.base import OpenAIClientBase
from dapr_agents.types.message import BaseMessage
from dapr_agents.llm.chat import ChatClientBase
from dapr_agents.prompt.prompty import Prompty
from dapr_agents.tool import AgentTool
from typing import (
    Union,
    Optional,
    Iterable,
    Dict,
    Any,
    List,
    Iterator,
    Type,
    Literal,
    ClassVar,
)
from openai.types.chat import ChatCompletion
from pydantic import BaseModel, Field, model_validator
from pathlib import Path
import logging

logger = logging.getLogger(__name__)

class OpenAIChatClient(OpenAIClientBase, ChatClientBase):
    """
    Chat client for OpenAI models.
    Combines OpenAI client management with Prompty-specific functionality.
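
    Example:
        A minimal sketch, assuming ``OPENAI_API_KEY`` is set in the
        environment (the model name below is illustrative)::

            llm = OpenAIChatClient(model="gpt-4o")
            reply = llm.generate(messages="Say hello in one word.")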
    """

    model: str = Field(default=None, description="Model name to use, e.g., 'gpt-4'.")

    SUPPORTED_STRUCTURED_MODES: ClassVar[set] = {"json", "function_call"}

    @model_validator(mode="before")
    @classmethod
    def validate_and_initialize(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """
        Ensures 'model' is set during validation: uses 'azure_deployment' if no
        model is specified, falling back to 'gpt-4o'.
        """
        if "model" not in values or values["model"] is None:
            values["model"] = values.get("azure_deployment", "gpt-4o")
        return values

    def model_post_init(self, __context: Any) -> None:
        """
        Initializes chat-specific attributes after validation.
        """
        self._api = "chat"
        super().model_post_init(__context)

    @classmethod
    def from_prompty(
        cls,
        prompty_source: Union[str, Path],
        timeout: Union[int, float, Dict[str, Any]] = 1500,
    ) -> "OpenAIChatClient":
        """
        Initializes an OpenAIChatClient from a Prompty source, which can be a
        file path or inline Prompty content.

        Args:
            prompty_source (Union[str, Path]): The source of the Prompty file: a path
                to a file or inline Prompty content as a string.
            timeout (Union[int, float, Dict[str, Any]], optional): Timeout for requests.
                Defaults to 1500 seconds.

        Returns:
            OpenAIChatClient: An instance configured with the model settings from
                the Prompty source.
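
        Example:
            A minimal usage sketch; the ``basic.prompty`` path and its
            ``question`` input variable are illustrative, not part of the
            library::

                client = OpenAIChatClient.from_prompty("basic.prompty")
                answer = client.generate(input_data={"question": "What is Dapr?"})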
        """
        # Load the Prompty instance from the provided source
        prompty_instance = Prompty.load(prompty_source)

        # Generate the prompt template from the Prompty instance
        prompt_template = Prompty.to_prompt_template(prompty_instance)

        # Extract the model configuration from Prompty
        model_config = prompty_instance.model

        # Initialize the OpenAIChatClient instance using model_validate
        if isinstance(model_config.configuration, OpenAIModelConfig):
            return cls.model_validate(
                {
                    "model": model_config.configuration.name,
                    "api_key": model_config.configuration.api_key,
                    "base_url": model_config.configuration.base_url,
                    "organization": model_config.configuration.organization,
                    "project": model_config.configuration.project,
                    "timeout": timeout,
                    "prompty": prompty_instance,
                    "prompt_template": prompt_template,
                }
            )
        elif isinstance(model_config.configuration, AzureOpenAIModelConfig):
            return cls.model_validate(
                {
                    "model": model_config.configuration.azure_deployment,
                    "api_key": model_config.configuration.api_key,
                    "azure_endpoint": model_config.configuration.azure_endpoint,
                    "azure_deployment": model_config.configuration.azure_deployment,
                    "api_version": model_config.configuration.api_version,
                    "organization": model_config.configuration.organization,
                    "project": model_config.configuration.project,
                    "azure_ad_token": model_config.configuration.azure_ad_token,
                    "azure_client_id": model_config.configuration.azure_client_id,
                    "timeout": timeout,
                    "prompty": prompty_instance,
                    "prompt_template": prompt_template,
                }
            )
        else:
            raise ValueError(
                f"Unsupported model configuration type: {type(model_config.configuration)}"
            )

    def generate(
        self,
        messages: Optional[
            Union[
                str,
                Dict[str, Any],
                BaseMessage,
                Iterable[Union[Dict[str, Any], BaseMessage]],
            ]
        ] = None,
        input_data: Optional[Dict[str, Any]] = None,
        model: Optional[str] = None,
        tools: Optional[List[Union[AgentTool, Dict[str, Any]]]] = None,
        response_format: Optional[Type[BaseModel]] = None,
        structured_mode: Literal["json", "function_call"] = "json",
        **kwargs,
    ) -> Union[Iterator[Dict[str, Any]], Dict[str, Any]]:
        """
        Generate chat completions from explicit messages or from input_data
        rendered through the client's prompt template.

        Args:
            messages (Optional): Pre-set messages, or None when using input_data.
            input_data (Optional[Dict[str, Any]]): Input variables for the prompt template.
            model (Optional[str]): Specific model to use for this request, overriding the default.
            tools (Optional[List[Union[AgentTool, Dict[str, Any]]]]): Tools to expose to the model.
            response_format (Optional[Type[BaseModel]]): Pydantic model for structured response parsing.
            structured_mode (Literal["json", "function_call"]): Mode for structured output.
            **kwargs: Additional parameters for the language model.

        Returns:
            Union[Iterator[Dict[str, Any]], Dict[str, Any]]: The chat completion response(s).
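
        Example:
            A sketch of structured output; the ``Dog`` model and prompt are
            illustrative only::

                from pydantic import BaseModel

                class Dog(BaseModel):
                    name: str
                    breed: str

                llm = OpenAIChatClient()
                dog = llm.generate(
                    messages="Describe a famous dog.",
                    response_format=Dog,
                    structured_mode="json",
                )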
        """

        if structured_mode not in self.SUPPORTED_STRUCTURED_MODES:
            raise ValueError(
                f"Invalid structured_mode '{structured_mode}'. Must be one of {self.SUPPORTED_STRUCTURED_MODES}."
            )

        # If input_data is provided, a prompt template is required to render it
        if input_data:
            if not self.prompt_template:
                raise ValueError(
                    "Inputs are provided but no 'prompt_template' is set. Please set a 'prompt_template' to use the input_data."
                )

            logger.info("Using prompt template to generate messages.")
            messages = self.prompt_template.format_prompt(**input_data)

        # Ensure we have messages at this point
        if not messages:
            raise ValueError("Either 'messages' or 'input_data' must be provided.")

        # Process and normalize the messages
        params = {"messages": RequestHandler.normalize_chat_messages(messages)}

        # Merge Prompty parameters if available, then override with any explicit kwargs
        if self.prompty:
            params = {**self.prompty.model.parameters.model_dump(), **params, **kwargs}
        else:
            params.update(kwargs)

        # If a model is provided, it overrides the client's default model
        params["model"] = model or self.model

        # Prepare request parameters
        params = RequestHandler.process_params(
            params,
            llm_provider=self.provider,
            tools=tools,
            response_format=response_format,
            structured_mode=structured_mode,
        )

        try:
            logger.info("Invoking ChatCompletion API.")
            logger.debug(f"ChatCompletion API Parameters: {params}")
            response: ChatCompletion = self.client.chat.completions.create(
                **params, timeout=self.timeout
            )
            logger.info("Chat completion retrieved successfully.")

            return ResponseHandler.process_response(
                response,
                llm_provider=self.provider,
                response_format=response_format,
                structured_mode=structured_mode,
                stream=params.get("stream", False),
            )
        except Exception as e:
            logger.error(f"An error occurred during the ChatCompletion API call: {e}")
            raise