dapr-agents/dapr_agents/workflow/task.py

import asyncio
import inspect
import logging
from dataclasses import is_dataclass
from functools import update_wrapper
from types import SimpleNamespace
from typing import Any, Callable, Dict, List, Literal, Optional
from dapr.ext.workflow import WorkflowActivityContext
from pydantic import BaseModel, ConfigDict, Field
from dapr_agents.agents.base import AgentBase
from dapr_agents.llm.chat import ChatClientBase
from dapr_agents.llm.openai import OpenAIChatClient
from dapr_agents.llm.utils import StructureHandler
from dapr_agents.prompt.utils.chat import ChatPromptHelper
from dapr_agents.types import BaseMessage, UserMessage, LLMChatResponse
logger = logging.getLogger(__name__)
class WorkflowTask(BaseModel):
"""
Encapsulates task logic for execution by an LLM, agent, or Python function.
Supports both synchronous and asynchronous tasks, with optional output validation
using Pydantic models or specified return types.
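
    Example (illustrative sketch; in practice the ``@task`` decorator from
    ``dapr_agents.workflow`` builds a ``WorkflowTask`` for you):

        @task(description="Summarize the following text: {text}")
        def summarize(text: str) -> str:
            ...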
"""
func: Optional[Callable] = Field(
None, description="The original function to be executed, if provided."
)
description: Optional[str] = Field(
None, description="A description template for the task, used with LLM or agent."
)
agent: Optional[AgentBase] = Field(
None, description="The agent used for task execution, if applicable."
)
llm: Optional[ChatClientBase] = Field(
None, description="The LLM client for executing the task, if applicable."
)
include_chat_history: Optional[bool] = Field(
False,
description="Whether to include past conversation history in the LLM call.",
)
workflow_app: Optional[Any] = Field(
None, description="Reference to the WorkflowApp instance."
)
structured_mode: Literal["json", "function_call"] = Field(
default="json",
description="Structured response mode for LLM output. Valid values: 'json', 'function_call'.",
)
task_kwargs: Dict[str, Any] = Field(
default_factory=dict,
exclude=True,
description="Additional keyword arguments passed via the @task decorator.",
)
# Initialized during setup
signature: Optional[inspect.Signature] = Field(
None, init=False, description="The signature of the provided function."
)
model_config = ConfigDict(arbitrary_types_allowed=True)
def model_post_init(self, __context: Any) -> None:
"""
Post-initialization to set up function signatures and default LLM clients.
"""
        # Default to OpenAIChatClient if prompt-based but no llm provided
if self.description and not self.llm:
self.llm = OpenAIChatClient()
if self.func:
# Preserve name / docs for stack traces
update_wrapper(self, self.func)
# Capture signature for input / output handling
self.signature = inspect.signature(self.func) if self.func else None
        # Honor any structured_mode override passed via the @task decorator.
        # structured_mode defaults to "json", so the override must be checked
        # against task_kwargs directly rather than against a falsy field value.
        if "structured_mode" in self.task_kwargs:
            self.structured_mode = self.task_kwargs["structured_mode"]
# Proceed with base model setup
super().model_post_init(__context)
async def __call__(self, ctx: WorkflowActivityContext, payload: Any = None) -> Any:
"""
Executes the task, routing to agent, LLM, or pure-Python logic.
Dispatches to Python, Agent, or LLM paths and validates output.
Args:
ctx (WorkflowActivityContext): The workflow execution context.
payload (Any): The task input.
Returns:
Any: The result of the task.
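
        Example (sketch; ``ctx`` is supplied by the Dapr workflow runtime):

            result = await task_instance(ctx, {"text": "hello"})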
"""
        # Prepare input dict
        data = self._normalize_input(payload) if payload is not None else {}
        # `func` is Optional, so guard the name lookup for function-less tasks
        task_name = getattr(self.func, "__name__", "anonymous_task")
        logger.info(f"Executing task '{task_name}'")
        logger.debug(f"Executing task '{task_name}' with input {data!r}")
try:
            executor = self._choose_executor()
            if executor == "agent":
                # For agents, prefer string input for natural conversation
                if self.description:
                    # Use the description template with parameter substitution
                    prompt = self.format_description(self.description, data)
                else:
                    # Pass the input through naturally for direct agent conversation
                    prompt = self._format_natural_agent_input(payload, data)
                raw = await self._run_via_ai(prompt, executor)
            elif executor == "llm":
                # LLM tasks always format their prompt from the description
                if not self.description:
                    raise ValueError("LLM tasks require a description template")
                prompt = self.format_description(self.description, data)
                raw = await self._run_via_ai(prompt, executor)
            else:
                raw = await self._run_python(data)
validated = await self._validate_output(raw)
return validated
        except Exception:
            logger.exception(f"Error in task '{task_name}'")
            raise
def _choose_executor(self) -> Literal["agent", "llm", "python"]:
"""
        Pick the execution path. Precedence: agent first, then LLM, then plain Python.

        Returns:
            One of "agent", "llm", or "python".
Raises:
ValueError: If no valid executor is configured.
"""
if self.agent:
return "agent"
if self.llm:
return "llm"
if self.func:
return "python"
raise ValueError("No execution path found for this task")
async def _run_python(self, data: dict) -> Any:
"""
Invoke the Python function directly.
Args:
data: Keyword arguments for the function.
Returns:
The function's return value.
"""
logger.debug("Invoking regular Python function")
if asyncio.iscoroutinefunction(self.func):
return await self.func(**data)
else:
return self.func(**data)
async def _run_via_ai(self, prompt: Any, executor: Literal["agent", "llm"]) -> Any:
"""
Run the prompt through an Agent or LLM.
Args:
prompt: The prompt data - string for LLM, string/dict/Any for agent.
executor: "agent" or "llm".
Returns:
Raw result from the AI path.
"""
logger.debug(f"Invoking task via {executor.upper()}")
logger.debug(f"Invoking task with prompt: {prompt!r}")
if executor == "agent":
# Agents can handle string, dict, or other input types
result = await self.agent.run(prompt)
else:
# LLM expects a string prompt
if not isinstance(prompt, str):
raise ValueError(f"LLM executor requires string prompt, got {type(prompt)}")
result = await self._invoke_llm(prompt)
return self._convert_result(result)
async def _invoke_llm(self, prompt: str) -> Any:
"""
Build messages and call the LLM client.
Args:
prompt: The formatted prompt string.
Returns:
LLM-generated result.
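
        Note:
            If the wrapped function's return annotation resolves to a Pydantic
            model, ``response_format`` and ``structured_mode`` are added to the
            call so the client can return structured output.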
"""
# Gather history if needed
history: List[BaseMessage] = []
if self.include_chat_history and self.workflow_app:
logger.debug("Retrieving chat history")
history_dicts = self.workflow_app.get_chat_history()
history = ChatPromptHelper.normalize_chat_messages(history_dicts)
messages: List[BaseMessage] = history + [UserMessage(prompt)]
params: Dict[str, Any] = {"messages": messages}
# Add structured formatting if return type is a Pydantic model
if (
self.signature
and self.signature.return_annotation is not inspect.Signature.empty
):
model_cls = StructureHandler.resolve_response_model(
self.signature.return_annotation
)
if model_cls:
params["response_format"] = self.signature.return_annotation
params["structured_mode"] = self.structured_mode
logger.debug(f"LLM call params: {params}")
return self.llm.generate(**params)
def _normalize_input(self, raw_input: Any) -> dict:
"""
Normalize various input types into a dict.
Args:
raw_input: Dataclass, SimpleNamespace, single value, or dict.
        Returns:
            A dict suitable for function invocation. A single non-dict value
            is wrapped under the function's first parameter name; if no
            signature or parameters are available, an empty dict is returned.
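
        Example (illustrative; assumes the wrapped function's first parameter
        is named ``text``):

            >>> task._normalize_input({"a": 1})
            {'a': 1}
            >>> task._normalize_input("hello")
            {'text': 'hello'}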
"""
if is_dataclass(raw_input):
return raw_input.__dict__
if isinstance(raw_input, SimpleNamespace):
return vars(raw_input)
if not isinstance(raw_input, dict):
# wrap single argument
if not self.signature or len(self.signature.parameters) == 0:
# No signature or no parameters - return empty dict for consistency
return {}
name = next(iter(self.signature.parameters))
return {name: raw_input}
return raw_input
async def _validate_output(self, result: Any) -> Any:
"""
Await and validate the result against return-type model.
Args:
result: Raw result from executor.
Returns:
Validated/transformed result.
"""
if asyncio.iscoroutine(result):
result = await result
if (
not self.signature
or self.signature.return_annotation is inspect.Signature.empty
):
return result
return StructureHandler.validate_against_signature(
result, self.signature.return_annotation
)
def _convert_result(self, result: Any) -> Any:
"""
Unwrap AI return types into plain Python.
Args:
result: One of:
- LLMChatResponse
- BaseModel (Pydantic)
- List[BaseModel]
- primitive (str/int/etc) or dict
Returns:
• str (assistant content) when `LLMChatResponse`
• dict when a single BaseModel
• List[dict] when a list of BaseModels
• otherwise, the raw `result`
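
        Example (illustrative, with a hypothetical ``Summary`` model):

            >>> task._convert_result(Summary(title="Hi"))
            {'title': 'Hi'}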
"""
# 1) Unwrap our unified LLMChatResponse → return the assistant's text
if isinstance(result, LLMChatResponse):
logger.debug("Extracted message content from LLMChatResponse.")
msg = result.get_message()
return getattr(msg, "content", None)
# 2) Single Pydantic model → dict
if isinstance(result, BaseModel):
logger.debug("Converting Pydantic model to dictionary.")
return result.model_dump()
# 3) List of Pydantic models → list of dicts
if isinstance(result, list) and all(isinstance(x, BaseModel) for x in result):
logger.debug("Converting list of Pydantic models to list of dictionaries.")
return [x.model_dump() for x in result]
# 4) Fallback: primitive, dict, etc.
logger.info("Returning final task result.")
return result
def format_description(self, template: str, data: dict) -> str:
"""
Interpolate inputs into the prompt template.
Args:
template: The `{}`-style template string.
data: Mapping of variable names to values.
Returns:
The fully formatted prompt.
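
        Example (illustrative):

            >>> task.format_description("Summarize: {text}", {"text": "hi"})
            'Summarize: hi'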
"""
if self.signature:
bound = self.signature.bind(**data)
bound.apply_defaults()
return template.format(**bound.arguments)
return template.format(**data)
def _format_natural_agent_input(self, payload: Any, data: dict) -> str:
"""
Format input for natural agent conversation.
Favors string input over dictionary for better agent interaction.
Args:
payload: The original raw payload from the workflow
data: The normalized dictionary version
Returns:
String input for natural agent conversation
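
        Example (illustrative; a multi-parameter dict becomes "key: value" lines):

            >>> task._format_natural_agent_input(
            ...     {"city": "Paris", "days": 3}, {"city": "Paris", "days": 3}
            ... )
            'city: Paris\ndays: 3'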
"""
if payload is None:
return ""
# If payload is already a simple string/number, use it directly
if isinstance(payload, (str, int, float, bool)):
return str(payload)
# If we have function parameters, format them naturally
if data and len(data) == 1:
# Single parameter: extract the value
value = next(iter(data.values()))
return str(value) if value is not None else ""
elif data:
# Multiple parameters: format as natural text
parts = []
for key, value in data.items():
if value is not None:
parts.append(f"{key}: {value}")
return "\n".join(parts)
else:
# Fallback to string representation of payload
return str(payload)
class TaskWrapper:
"""
A wrapper for WorkflowTask that preserves callable behavior and attributes like __name__.
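
    Example (sketch):

        wrapped = TaskWrapper(task_instance, name="my_task")
        wrapped.__name__     # 'my_task'
        wrapped.description  # delegated to the underlying WorkflowTask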
"""
def __init__(self, task_instance: WorkflowTask, name: str):
"""
Initialize the TaskWrapper.
Args:
task_instance (WorkflowTask): The task instance to wrap.
name (str): The task name.
"""
self.task_instance = task_instance
self.__name__ = name
self.__doc__ = getattr(task_instance.func, "__doc__", None)
self.__module__ = getattr(task_instance.func, "__module__", None)
def __call__(self, *args, **kwargs):
"""
Delegate the call to the wrapped WorkflowTask instance.
"""
return self.task_instance(*args, **kwargs)
def __getattr__(self, item):
"""
Delegate attribute access to the wrapped task.
"""
return getattr(self.task_instance, item)
def __repr__(self):
return f"<TaskWrapper name={self.__name__}>"