# dapr-agents/dapr_agents/llm/openai/chat.py  (152 lines, 7.2 KiB, Python)

from dapr_agents.types.llm import AzureOpenAIModelConfig, OpenAIModelConfig
from dapr_agents.llm.utils import RequestHandler, ResponseHandler
from dapr_agents.llm.openai.client.base import OpenAIClientBase
from dapr_agents.types.message import BaseMessage
from dapr_agents.llm.chat import ChatClientBase
from dapr_agents.prompt.prompty import Prompty
from dapr_agents.tool import AgentTool
from typing import Union, Optional, Iterable, Dict, Any, List, Iterator, Type
from openai.types.chat import ChatCompletionMessage
from pydantic import BaseModel, Field, model_validator
from pathlib import Path
import logging
logger = logging.getLogger(__name__)
class OpenAIChatClient(OpenAIClientBase, ChatClientBase):
    """
    Chat client for OpenAI and Azure OpenAI chat-completion models.

    Combines OpenAI client management (from ``OpenAIClientBase``) with
    Prompty-specific functionality: a client can be built directly or loaded
    from a Prompty source via :meth:`from_prompty`, and ``generate`` renders
    prompt templates before calling the Chat Completions API.
    """

    # Default is None and resolved by the validator below (falls back to
    # 'azure_deployment', then 'gpt-4o'), hence Optional rather than str.
    model: Optional[str] = Field(default=None, description="Model name to use, e.g., 'gpt-4'.")

    @model_validator(mode="before")
    @classmethod
    def validate_and_initialize(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """
        Ensure 'model' is set before field validation.

        Uses 'azure_deployment' if no model is specified, defaulting to
        'gpt-4o'. A 'before' validator may receive non-dict input (e.g. an
        already-constructed model), so non-dict values pass through untouched.
        """
        if isinstance(values, dict) and values.get('model') is None:
            values['model'] = values.get('azure_deployment', 'gpt-4o')
        return values

    def model_post_init(self, __context: Any) -> None:
        """
        Initialize chat-specific attributes after validation.

        Sets the API surface to "chat" before the base class finishes its
        own post-init (which may depend on ``self._api``).
        """
        self._api = "chat"
        super().model_post_init(__context)

    @classmethod
    def from_prompty(cls, prompty_source: Union[str, Path], timeout: Union[int, float, Dict[str, Any]] = 1500) -> 'OpenAIChatClient':
        """
        Initialize an OpenAIChatClient from a Prompty source (file path or inline content).

        Args:
            prompty_source (Union[str, Path]): Path to a Prompty file or inline
                Prompty content as a string.
            timeout (Union[int, float, Dict[str, Any]], optional): Timeout for
                requests; defaults to 1500 seconds.

        Returns:
            OpenAIChatClient: Instance configured from the Prompty model settings.

        Raises:
            ValueError: If the Prompty model configuration type is unsupported.
        """
        prompty_instance = Prompty.load(prompty_source)
        prompt_template = Prompty.to_prompt_template(prompty_instance)
        config = prompty_instance.model.configuration

        # Fields shared by both provider branches.
        common: Dict[str, Any] = {
            'timeout': timeout,
            'prompty': prompty_instance,
            'prompt_template': prompt_template,
        }

        if isinstance(config, OpenAIModelConfig):
            return cls.model_validate({
                'model': config.name,
                'api_key': config.api_key,
                'base_url': config.base_url,
                'organization': config.organization,
                'project': config.project,
                **common,
            })
        if isinstance(config, AzureOpenAIModelConfig):
            # Azure routes requests by deployment name, which doubles as 'model'.
            return cls.model_validate({
                'model': config.azure_deployment,
                'api_key': config.api_key,
                'azure_endpoint': config.azure_endpoint,
                'azure_deployment': config.azure_deployment,
                'api_version': config.api_version,
                'organization': config.organization,
                'project': config.project,
                'azure_ad_token': config.azure_ad_token,
                'azure_client_id': config.azure_client_id,
                **common,
            })
        raise ValueError(f"Unsupported model configuration type: {type(config)}")

    def generate(
        self,
        messages: Union[str, Dict[str, Any], BaseMessage, Iterable[Union[Dict[str, Any], BaseMessage]]] = None,
        input_data: Optional[Dict[str, Any]] = None,
        model: Optional[str] = None,
        tools: Optional[List[Union[AgentTool, Dict[str, Any]]]] = None,
        response_model: Optional[Type[BaseModel]] = None,
        **kwargs
    ) -> Union[Iterator[Dict[str, Any]], Dict[str, Any]]:
        """
        Generate chat completions from messages or from template input_data.

        Args:
            messages: Pre-set messages, or None when using ``input_data``.
            input_data: Input variables for the configured prompt template.
            model: Model override for this request (defaults to ``self.model``).
            tools: Tools (AgentTool instances or raw dicts) to expose to the model.
            response_model: Optional Pydantic model for structured response parsing.
            **kwargs: Additional parameters forwarded to the Chat Completions API.

        Returns:
            Union[Iterator[Dict[str, Any]], Dict[str, Any]]: The completion
            response, or an iterator of chunks when streaming is enabled.

        Raises:
            ValueError: If neither messages nor input_data are usable.
        """
        # Render the prompt template when structured inputs are supplied.
        if input_data:
            if not self.prompt_template:
                raise ValueError("Inputs are provided but no 'prompt_template' is set. Please set a 'prompt_template' to use the input_data.")
            logger.info("Using prompt template to generate messages.")
            messages = self.prompt_template.format_prompt(**input_data)

        if not messages:
            raise ValueError("Either 'messages' or 'input_data' must be provided.")

        params: Dict[str, Any] = {'messages': RequestHandler.normalize_chat_messages(messages)}

        # Prompty defaults are lowest priority, then normalized messages,
        # then explicit kwargs override everything.
        if self.prompty:
            params = {**self.prompty.model.parameters.model_dump(), **params, **kwargs}
        else:
            params.update(kwargs)

        # Per-call model override beats the client default.
        params['model'] = model or self.model
        params = RequestHandler.process_params(params, llm_provider=self.provider, tools=tools, response_model=response_model)

        try:
            logger.info("Invoking ChatCompletion API.")
            # Lazy %-formatting: skip rendering large param dicts unless DEBUG is on.
            logger.debug("ChatCompletion API Parameters: %s", params)
            response: ChatCompletionMessage = self.client.chat.completions.create(**params, timeout=self.timeout)
            logger.info("Chat completion retrieved successfully.")
            return ResponseHandler.process_response(response, llm_provider=self.provider, response_model=response_model, stream=params.get('stream', False))
        except Exception as e:
            logger.error(f"An error occurred during the ChatCompletion API call: {e}")
            raise