dapr-agents/dapr_agents/llm/openai/chat.py

206 lines
8.3 KiB
Python

from dapr_agents.types.llm import AzureOpenAIModelConfig, OpenAIModelConfig
from dapr_agents.llm.utils import RequestHandler, ResponseHandler
from dapr_agents.llm.openai.client.base import OpenAIClientBase
from dapr_agents.types.message import BaseMessage
from dapr_agents.llm.chat import ChatClientBase
from dapr_agents.prompt.prompty import Prompty
from dapr_agents.tool import AgentTool
from typing import (
Union,
Optional,
Iterable,
Dict,
Any,
List,
Iterator,
Type,
Literal,
ClassVar,
)
from openai.types.chat import ChatCompletionMessage
from pydantic import BaseModel, Field, model_validator
from pathlib import Path
import logging
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class OpenAIChatClient(OpenAIClientBase, ChatClientBase):
    """
    Chat client for OpenAI models.

    Combines OpenAI client management with Prompty-specific functionality.
    """

    # The "before" validator below fills this in when the caller omits it or
    # passes None, so the None default is only a placeholder.
    model: str = Field(default=None, description="Model name to use, e.g., 'gpt-4'.")

    # Structured-output modes accepted by `generate`.
    SUPPORTED_STRUCTURED_MODES: ClassVar[set] = {"json", "function_call"}

    @model_validator(mode="before")
    @classmethod
    def validate_and_initialize(cls, values: Any) -> Any:
        """
        Ensures the 'model' is set during validation.

        Uses 'azure_deployment' if no model is specified, defaults to 'gpt-4o'.

        Args:
            values (Any): Raw input handed to the model before field validation.
                Only dict input is inspected; anything else is returned untouched.

        Returns:
            Any: The input, with 'model' filled in when it was missing or None.
        """
        if not isinstance(values, dict):
            return values

        if values.get("model") is None:
            # `or` (rather than a dict.get default) also covers the case where
            # 'azure_deployment' is present but None, which would otherwise
            # leave 'model' as None. Build a new dict so the caller's input is
            # not mutated.
            fallback_model = values.get("azure_deployment") or "gpt-4o"
            values = {**values, "model": fallback_model}
        return values

    def model_post_init(self, __context: Any) -> None:
        """
        Initializes chat-specific attributes after validation.

        Sets the private '_api' marker to "chat" before delegating to the
        parent classes' post-init hook.
        """
        self._api = "chat"
        super().model_post_init(__context)

    @classmethod
    def from_prompty(
        cls,
        prompty_source: Union[str, Path],
        timeout: Union[int, float, Dict[str, Any]] = 1500,
    ) -> "OpenAIChatClient":
        """
        Initializes an OpenAIChatClient client using a Prompty source, which can be a file path or inline content.

        Args:
            prompty_source (Union[str, Path]): The source of the Prompty file, which can be a path to a file
                or inline Prompty content as a string.
            timeout (Union[int, float, Dict[str, Any]], optional): Timeout for requests, defaults to 1500 seconds.

        Returns:
            OpenAIChatClient: An instance of OpenAIChatClient configured with the model settings from the Prompty source.

        Raises:
            ValueError: If the Prompty model configuration is neither an
                OpenAIModelConfig nor an AzureOpenAIModelConfig.
        """
        # Load the Prompty instance and derive its prompt template.
        prompty_instance = Prompty.load(prompty_source)
        prompt_template = Prompty.to_prompt_template(prompty_instance)

        # Provider-specific settings come from the Prompty model configuration.
        configuration = prompty_instance.model.configuration

        if isinstance(configuration, OpenAIModelConfig):
            provider_settings = {
                "model": configuration.name,
                "api_key": configuration.api_key,
                "base_url": configuration.base_url,
                "organization": configuration.organization,
                "project": configuration.project,
            }
        elif isinstance(configuration, AzureOpenAIModelConfig):
            provider_settings = {
                "model": configuration.azure_deployment,
                "api_key": configuration.api_key,
                "azure_endpoint": configuration.azure_endpoint,
                "azure_deployment": configuration.azure_deployment,
                "api_version": configuration.api_version,
                "organization": configuration.organization,
                "project": configuration.project,
                "azure_ad_token": configuration.azure_ad_token,
                "azure_client_id": configuration.azure_client_id,
            }
        else:
            raise ValueError(
                f"Unsupported model configuration type: {type(configuration)}"
            )

        # Settings shared by both providers are added once here.
        return cls.model_validate(
            {
                **provider_settings,
                "timeout": timeout,
                "prompty": prompty_instance,
                "prompt_template": prompt_template,
            }
        )

    def generate(
        self,
        messages: Optional[
            Union[
                str,
                Dict[str, Any],
                BaseMessage,
                Iterable[Union[Dict[str, Any], BaseMessage]],
            ]
        ] = None,
        input_data: Optional[Dict[str, Any]] = None,
        model: Optional[str] = None,
        tools: Optional[List[Union[AgentTool, Dict[str, Any]]]] = None,
        response_format: Optional[Type[BaseModel]] = None,
        structured_mode: Literal["json", "function_call"] = "json",
        **kwargs,
    ) -> Union[Iterator[Dict[str, Any]], Dict[str, Any]]:
        """
        Generate chat completions based on provided messages or input_data for prompt templates.

        Args:
            messages (Optional): Either pre-set messages or None if using input_data.
            input_data (Optional[Dict[str, Any]]): Input variables for prompt templates.
                When non-empty, the formatted prompt template replaces `messages`.
            model (str): Specific model to use for the request, overriding the default.
            tools (List[Union[AgentTool, Dict[str, Any]]]): List of tools for the request.
            response_format (Type[BaseModel]): Optional Pydantic model for structured response parsing.
            structured_mode (Literal["json", "function_call"]): Mode for structured output: "json" or "function_call".
            **kwargs: Additional parameters for the language model. These take
                precedence over parameters coming from the Prompty definition.

        Returns:
            Union[Iterator[Dict[str, Any]], Dict[str, Any]]: The chat completion response(s).

        Raises:
            ValueError: If `structured_mode` is unsupported, if `input_data` is given
                without a configured prompt template, or if no messages are available.
        """
        if structured_mode not in self.SUPPORTED_STRUCTURED_MODES:
            raise ValueError(
                f"Invalid structured_mode '{structured_mode}'. Must be one of {self.SUPPORTED_STRUCTURED_MODES}."
            )

        # If input_data is provided, check for a prompt_template
        if input_data:
            if not self.prompt_template:
                raise ValueError(
                    "Inputs are provided but no 'prompt_template' is set. Please set a 'prompt_template' to use the input_data."
                )
            logger.info("Using prompt template to generate messages.")
            messages = self.prompt_template.format_prompt(**input_data)

        # Ensure we have messages at this point
        if not messages:
            raise ValueError("Either 'messages' or 'input_data' must be provided.")

        # Process and normalize the messages
        params = {"messages": RequestHandler.normalize_chat_messages(messages)}

        # Merge prompty parameters if available, then override with any explicit kwargs
        if self.prompty:
            params = {**self.prompty.model.parameters.model_dump(), **params, **kwargs}
        else:
            params.update(kwargs)

        # If a model is provided, override the default model
        params["model"] = model or self.model

        # Prepare request parameters
        params = RequestHandler.process_params(
            params,
            llm_provider=self.provider,
            tools=tools,
            response_format=response_format,
            structured_mode=structured_mode,
        )

        try:
            logger.info("Invoking ChatCompletion API.")
            # Lazy %-formatting: the (potentially large) payload is only
            # stringified when DEBUG logging is actually enabled.
            logger.debug("ChatCompletion API Parameters: %s", params)

            # Use the client-level timeout unless a per-call 'timeout' is
            # already present in params. Passing it as a separate keyword
            # alongside **params would raise TypeError on a duplicate.
            params.setdefault("timeout", self.timeout)

            # NOTE(review): the annotation below looks too narrow — the OpenAI
            # SDK documents create() as returning a ChatCompletion (or a stream
            # of chunks when streaming). Confirm and adjust the import.
            response: ChatCompletionMessage = self.client.chat.completions.create(
                **params
            )
            logger.info("Chat completion retrieved successfully.")

            return ResponseHandler.process_response(
                response,
                llm_provider=self.provider,
                response_format=response_format,
                structured_mode=structured_mode,
                stream=params.get("stream", False),
            )
        except Exception as e:
            # Log for visibility, then propagate unchanged to the caller.
            logger.error(
                "An error occurred during the ChatCompletion API call: %s", e
            )
            raise