mirror of https://github.com/dapr/dapr-agents.git
import asyncio
import inspect
import logging
from dataclasses import is_dataclass
from functools import update_wrapper
from types import SimpleNamespace
from typing import Any, Callable, Dict, List, Literal, Optional

from pydantic import BaseModel, ConfigDict, Field

from dapr.ext.workflow import WorkflowActivityContext

from dapr_agents.agent.base import AgentBase
from dapr_agents.llm.chat import ChatClientBase
from dapr_agents.llm.openai import OpenAIChatClient
from dapr_agents.llm.utils import StructureHandler
from dapr_agents.types import BaseMessage, ChatCompletion, UserMessage

logger = logging.getLogger(__name__)

class WorkflowTask(BaseModel):
    """
    Encapsulates task logic for execution by an LLM, agent, or Python function.

    Supports both synchronous and asynchronous tasks, with optional output validation
    using Pydantic models or specified return types.
    """

    func: Optional[Callable] = Field(
        None, description="The original function to be executed, if provided."
    )
    description: Optional[str] = Field(
        None, description="A description template for the task, used with LLM or agent."
    )
    agent: Optional[AgentBase] = Field(
        None, description="The agent used for task execution, if applicable."
    )
    llm: Optional[ChatClientBase] = Field(
        None, description="The LLM client for executing the task, if applicable."
    )
    include_chat_history: Optional[bool] = Field(
        False,
        description="Whether to include past conversation history in the LLM call.",
    )
    workflow_app: Optional[Any] = Field(
        None, description="Reference to the WorkflowApp instance."
    )
    structured_mode: Literal["json", "function_call"] = Field(
        default="json",
        description="Structured response mode for LLM output. Valid values: 'json', 'function_call'.",
    )
    task_kwargs: Dict[str, Any] = Field(
        default_factory=dict,
        exclude=True,
        description="Additional keyword arguments passed via the @task decorator.",
    )

    # Initialized during setup
    signature: Optional[inspect.Signature] = Field(
        None, init=False, description="The signature of the provided function."
    )

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def model_post_init(self, __context: Any) -> None:
        """
        Post-initialization to set up function signatures and default LLM clients.
        """
        # Default to OpenAIChatClient if prompt-based but no llm provided
        if self.description and not self.llm:
            self.llm = OpenAIChatClient()

        if self.func:
            # Preserve name / docs for stack traces
            update_wrapper(self, self.func)

        # Capture signature for input / output handling
        self.signature = inspect.signature(self.func) if self.func else None

        # Honor a structured_mode override passed via the @task decorator.
        # structured_mode defaults to "json", so check task_kwargs directly
        # rather than relying on the field being falsy.
        if "structured_mode" in self.task_kwargs:
            self.structured_mode = self.task_kwargs["structured_mode"]

        # Proceed with base model setup
        super().model_post_init(__context)

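    # Construction sketch (illustrative; `summarize`, `my_agent`, and the
    # templates below are hypothetical). The execution path is inferred from
    # which fields are set, in the priority order of `_choose_executor`:
    #
    #     WorkflowTask(func=summarize)                               # -> "python"
    #     WorkflowTask(description="Summarize: {text}")              # -> "llm"
    #     WorkflowTask(description="Plan: {goal}", agent=my_agent)   # -> "agent"
    #
    # A prompt-only task gets a default OpenAIChatClient in model_post_init.
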
    async def __call__(self, ctx: WorkflowActivityContext, payload: Any = None) -> Any:
        """
        Executes the task, routing to agent, LLM, or pure-Python logic,
        and validates the output.

        Args:
            ctx (WorkflowActivityContext): The workflow execution context.
            payload (Any): The task input.

        Returns:
            Any: The result of the task.
        """
        # Prepare input dict
        data = self._normalize_input(payload) if payload is not None else {}
        # func may be None for prompt-only tasks, so derive a safe display name
        task_name = self.func.__name__ if self.func else "workflow_task"
        logger.info(f"Executing task '{task_name}'")
        logger.debug(f"Executing task '{task_name}' with input {data!r}")

        try:
            executor = self._choose_executor()
            if executor in ("agent", "llm"):
                if not self.description:
                    raise ValueError("LLM/agent tasks require a description template")
                prompt = self.format_description(self.description, data)
                raw = await self._run_via_ai(prompt, executor)
            else:
                raw = await self._run_python(data)

            validated = await self._validate_output(raw)
            return validated

        except Exception:
            logger.exception(f"Error in task '{task_name}'")
            raise

    def _choose_executor(self) -> Literal["agent", "llm", "python"]:
        """
        Pick execution path.

        Returns:
            One of "agent", "llm", or "python".

        Raises:
            ValueError: If no valid executor is configured.
        """
        if self.agent:
            return "agent"
        if self.llm:
            return "llm"
        if self.func:
            return "python"
        raise ValueError("No execution path found for this task")

    async def _run_python(self, data: dict) -> Any:
        """
        Invoke the Python function directly.

        Args:
            data: Keyword arguments for the function.

        Returns:
            The function's return value.
        """
        logger.info("Invoking regular Python function")
        if asyncio.iscoroutinefunction(self.func):
            return await self.func(**data)
        else:
            return self.func(**data)

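    # Both synchronous and asynchronous task functions are handled
    # transparently; a sketch (hypothetical functions, not part of this module):
    #
    #     def parse(record: dict) -> dict: ...     # called directly
    #     async def fetch(url: str) -> str: ...    # awaited on the event loop
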
    async def _run_via_ai(self, prompt: str, executor: Literal["agent", "llm"]) -> Any:
        """
        Run the prompt through an Agent or LLM.

        Args:
            prompt: The fully formatted prompt string.
            executor: "agent" or "llm".

        Returns:
            Raw result from the AI path.
        """
        logger.info(f"Invoking task via {executor.upper()}")
        logger.debug(f"Invoking task with prompt: {prompt!r}")
        if executor == "agent":
            result = await self.agent.run(prompt)
        else:
            result = await self._invoke_llm(prompt)
        return self._convert_result(result)

    async def _invoke_llm(self, prompt: str) -> Any:
        """
        Build messages and call the LLM client.

        Args:
            prompt: The formatted prompt string.

        Returns:
            LLM-generated result.
        """
        # Gather history if needed
        history: List[BaseMessage] = []
        if self.include_chat_history and self.workflow_app:
            logger.debug("Retrieving chat history")
            history = self.workflow_app.get_chat_history()

        messages: List[BaseMessage] = history + [UserMessage(prompt)]
        params: Dict[str, Any] = {"messages": messages}

        # Add structured formatting if return type is a Pydantic model
        if (
            self.signature
            and self.signature.return_annotation is not inspect.Signature.empty
        ):
            model_cls = StructureHandler.resolve_response_model(
                self.signature.return_annotation
            )
            if model_cls:
                params["response_format"] = self.signature.return_annotation
                params["structured_mode"] = self.structured_mode

        logger.debug(f"LLM call params: {params}")
        return self.llm.generate(**params)

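    # Structured-output sketch: a Pydantic return annotation on the task
    # function is forwarded to the client. `TicketSummary` and `triage` are
    # hypothetical:
    #
    #     class TicketSummary(BaseModel):
    #         title: str
    #         severity: str
    #
    #     def triage(ticket: str) -> TicketSummary: ...
    #
    # For this signature, generate() is called with
    # response_format=TicketSummary and structured_mode="json" (the default),
    # and _validate_output later validates the raw result against the model.
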
    def _normalize_input(self, raw_input: Any) -> dict:
        """
        Normalize various input types into a dict.

        Args:
            raw_input: Dataclass, SimpleNamespace, single value, or dict.

        Returns:
            A dict suitable for function invocation.

        Raises:
            ValueError: If signature is missing when wrapping a single value.
        """
        if is_dataclass(raw_input):
            return raw_input.__dict__
        if isinstance(raw_input, SimpleNamespace):
            return vars(raw_input)
        if not isinstance(raw_input, dict):
            # Wrap a single argument as the function's first parameter
            if not self.signature:
                raise ValueError("Cannot infer param name without signature")
            name = next(iter(self.signature.parameters))
            return {name: raw_input}
        return raw_input

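    # Normalization sketch (illustrative values; `Point` is a hypothetical
    # dataclass and `f` a hypothetical single-parameter task function):
    #
    #     {"x": 1}                   -> {"x": 1}          # dicts pass through
    #     SimpleNamespace(x=1)       -> {"x": 1}          # vars()
    #     Point(x=1, y=2)            -> {"x": 1, "y": 2}  # dataclass __dict__
    #     42, with def f(n): ...     -> {"n": 42}         # bound to first param
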
    async def _validate_output(self, result: Any) -> Any:
        """
        Await and validate the result against the return-type model.

        Args:
            result: Raw result from the executor.

        Returns:
            Validated/transformed result.
        """
        if asyncio.iscoroutine(result):
            result = await result

        if (
            not self.signature
            or self.signature.return_annotation is inspect.Signature.empty
        ):
            return result

        return StructureHandler.validate_against_signature(
            result, self.signature.return_annotation
        )

    def _convert_result(self, result: Any) -> Any:
        """
        Unwrap AI return types into plain Python.

        Args:
            result: ChatCompletion, BaseModel, or list of BaseModel.

        Returns:
            A primitive, dict, or list of dicts.
        """
        # Unwrap ChatCompletion
        if isinstance(result, ChatCompletion):
            logger.debug("Extracted message content from ChatCompletion.")
            return result.get_content()
        # Pydantic → dict
        if isinstance(result, BaseModel):
            logger.debug("Converting Pydantic model to dictionary.")
            return result.model_dump()
        if isinstance(result, list) and all(isinstance(x, BaseModel) for x in result):
            logger.debug("Converting list of Pydantic models to list of dictionaries.")
            return [x.model_dump() for x in result]
        # If no specific conversion is necessary, return as-is
        logger.info("Returning final task result.")
        return result

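    # Conversion sketch (illustrative; `Reply` is a hypothetical model):
    #
    #     ChatCompletion(...)                -> completion.get_content()
    #     Reply(text="hi")                   -> {"text": "hi"}
    #     [Reply(text="a"), Reply(text="b")] -> [{"text": "a"}, {"text": "b"}]
    #     "plain string"                     -> "plain string" (unchanged)
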
    def format_description(self, template: str, data: dict) -> str:
        """
        Interpolate inputs into the prompt template.

        Args:
            template: The `{}`-style template string.
            data: Mapping of variable names to values.

        Returns:
            The fully formatted prompt.
        """
        if self.signature:
            bound = self.signature.bind(**data)
            bound.apply_defaults()
            return template.format(**bound.arguments)
        return template.format(**data)

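    # Formatting sketch: for a hypothetical task
    #
    #     def greet(name: str, greeting: str = "Hello") -> str: ...
    #
    # with description "{greeting}, {name}!", calling
    # format_description("{greeting}, {name}!", {"name": "Ada"}) binds the
    # signature, applies the default for `greeting`, and returns "Hello, Ada!".

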
class TaskWrapper:
    """
    A wrapper for WorkflowTask that preserves callable behavior and attributes like __name__.
    """

    def __init__(self, task_instance: WorkflowTask, name: str):
        """
        Initialize the TaskWrapper.

        Args:
            task_instance (WorkflowTask): The task instance to wrap.
            name (str): The task name.
        """
        self.task_instance = task_instance
        self.__name__ = name
        self.__doc__ = getattr(task_instance.func, "__doc__", None)
        self.__module__ = getattr(task_instance.func, "__module__", None)

    def __call__(self, *args, **kwargs):
        """
        Delegate the call to the wrapped WorkflowTask instance.
        """
        return self.task_instance(*args, **kwargs)

    def __getattr__(self, item):
        """
        Delegate attribute access to the wrapped task.
        """
        return getattr(self.task_instance, item)

    def __repr__(self):
        return f"<TaskWrapper name={self.__name__}>"
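

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (illustrative, not part of the library API):
# exercises only the pure-Python execution path, with no Dapr runtime, agent,
# or LLM involved. `add` and the None context are stand-ins for real wiring.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    def add(a, b):
        return a + b

    demo = TaskWrapper(WorkflowTask(func=add), name="add")
    # WorkflowTask.__call__ is async and ignores the activity context on the
    # pure-Python path, so None stands in for a WorkflowActivityContext here.
    print(asyncio.run(demo(None, {"a": 1, "b": 2})))  # -> 3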