Hierarchical LLM config on Tasks + workflow/ decorator refactor (#92)

* updated dependencies versions

* Updated model validator for HuggingFace-Hub client to catch model and hub url early

* split task registration into discover + register phases, improved LLM client init in tasks and workflow wrappers

* improve decorator to access method attributes

* wrapping workflow decotrators to log, validate, etc., without losing signature/docs

* Improved LLM-based task client and cleaned execution of LLM, agent and python function

* Added an example of multiple models being defined per workflow task after updates

* Updated quickstarts basic agent runs to async

* Added model attribute to huggingface-hub client class

* Fixed random and roundrobin orchestrators TriggerAction schema, trigger action and task to process agent response

* Updated quickstart multi-agent workflows and actor-based agents docs

* added .dapr to gitignore
This commit is contained in:
Roberto Rodriguez 2025-04-22 09:17:36 -04:00 committed by GitHub
parent b939d7d2f5
commit 099dc5d2fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 575 additions and 459 deletions

1
.gitignore vendored
View File

@ -2,6 +2,7 @@
.DS_Store
secrets.json
test
.dapr
# Byte-compiled / optimized / DLL files
__pycache__/

View File

@ -0,0 +1,60 @@
from dapr_agents import OpenAIChatClient, NVIDIAChatClient
from dapr_agents.types import DaprWorkflowContext
from dapr_agents. workflow import WorkflowApp, task, workflow
from dotenv import load_dotenv
import os
import logging
load_dotenv()
nvidia_llm = NVIDIAChatClient(
model="meta/llama-3.1-8b-instruct",
api_key=os.getenv("NVIDIA_API_KEY")
)
oai_llm = OpenAIChatClient(
api_key=os.getenv("OPENAI_API_KEY"),
model="gpt-4o",
base_url=os.getenv("OPENAI_API_BASE_URL"),
)
azoai_llm = OpenAIChatClient(
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
azure_deployment="gpt-4o-mini",
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
azure_api_version="2024-12-01-preview"
)
@workflow
def test_workflow(ctx: DaprWorkflowContext):
"""
A simple workflow that uses a multi-modal task chain.
"""
oai_results = yield ctx.call_activity(invoke_oai, input="Peru")
azoai_results = yield ctx.call_activity(invoke_azoai, input=oai_results)
nvidia_results = yield ctx.call_activity(invoke_nvidia, input=azoai_results)
return nvidia_results
@task(description="What is the name of the capital of {country}?. Reply with just the name.", llm=oai_llm)
def invoke_oai(country: str) -> str:
pass
@task(description="What is a famous thing about {capital}?", llm=azoai_llm)
def invoke_azoai(capital: str) -> str:
pass
@task(description="Context: {context}. From the previous context. Pick one thing to do.", llm=nvidia_llm)
def invoke_nvidia(context: str) -> str:
pass
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO)
wfapp = WorkflowApp()
results = wfapp.run_and_monitor_workflow(workflow=test_workflow)
logging.info("Workflow results: %s", results)
logging.info("Workflow completed successfully.")

View File

@ -52,9 +52,21 @@ class HFHubInferenceClientBase(LLMClientBase):
values['api_key'] = api_key
# Ensure mutual exclusivity of `model` and `base_url`
# mutualexclusivity
if model is not None and base_url is not None:
raise ValueError("Cannot provide both 'model' and 'base_url'. They are mutually exclusive.")
raise ValueError("Cannot provide both 'model' and 'base_url'.")
# require at least one
if model is None and base_url is None:
raise ValueError(
"HF Inference needs either `model` or `base_url`. "
"E.g. model='gpt2' or base_url='https://…/models/gpt2'."
)
# autoderive model from base_url
if model is None:
derived = base_url.rstrip("/").split("/")[-1]
values["model"] = derived
return values

View File

@ -22,7 +22,7 @@ from dapr.ext.workflow.workflow_state import WorkflowState
from dapr_agents.llm.chat import ChatClientBase
from dapr_agents.types.workflow import DaprWorkflowStatus
from dapr_agents.workflow.task import WorkflowTask
from dapr_agents.workflow.utils import get_callable_decorated_methods
from dapr_agents.workflow.utils import get_decorated_methods
logger = logging.getLogger(__name__)
@ -48,27 +48,148 @@ class WorkflowApp(BaseModel):
def model_post_init(self, __context: Any) -> None:
"""
Post-initialization configuration for the WorkflowApp.
Initializes the Dapr Workflow runtime, client, and state store, and ensures
that workflows and tasks are registered.
Initialize the Dapr workflow runtime and register tasks & workflows.
"""
# Initialize WorkflowRuntime and DaprWorkflowClient
# initialize clients and runtime
self.wf_runtime = WorkflowRuntime()
self.wf_runtime_is_running = False
self.wf_client = DaprWorkflowClient()
self.client = DaprClient()
logger.info("WorkflowApp initialized; discovering tasks and workflows.")
logger.info(f"Initialized WorkflowApp.")
# Discover and register
discovered_tasks = self._discover_tasks()
self._register_tasks(discovered_tasks)
discovered_wfs = self._discover_workflows()
self._register_workflows(discovered_wfs)
# Register workflows and tasks after the instance is created
self.register_all_workflows()
self.register_all_tasks()
# Proceed with base model setup
super().model_post_init(__context)
def get_chat_history(self) -> List[Any]:
"""
Stub for fetching past conversation history. Override in subclasses.
"""
logger.debug("Fetching chat history (default stub)")
return []
def _choose_llm_for(self, method: Callable) -> Optional[ChatClientBase]:
"""
Encapsulate LLM selection logic.
1. Use per-task override if provided on decorator.
2. Else if description-based, fall back to default app LLM.
3. Otherwise, None.
"""
per_task = getattr(method, '_task_llm', None)
if per_task:
return per_task
if getattr(method, '_explicit_llm', False):
return self.llm
return None
def _discover_tasks(self) -> Dict[str, Callable]:
"""Gather all @task-decorated functions and methods."""
module = sys.modules['__main__']
tasks: Dict[str, Callable] = {}
# free functions
for name, fn in inspect.getmembers(module, inspect.isfunction):
if getattr(fn, '_is_task', False) and fn.__module__ == module.__name__:
tasks[getattr(fn, '_task_name', name)] = fn
# bound methods
for name, method in get_decorated_methods(self, '_is_task').items():
tasks[getattr(method, '_task_name', name)] = method
logger.debug(f"Discovered tasks: {list(tasks)}")
return tasks
def _register_tasks(self, tasks: Dict[str, Callable]) -> None:
"""Register each discovered task with the Dapr runtime."""
for task_name, method in tasks.items():
llm = self._choose_llm_for(method)
logger.debug(f"Registering task '{task_name}' with llm={getattr(llm, '__class__', None)}")
kwargs = getattr(method, '_task_kwargs', {})
task_instance = WorkflowTask(
func=method,
description=getattr(method, '_task_description', None),
agent=getattr(method, '_task_agent', None),
llm=llm,
include_chat_history=getattr(method, '_task_include_chat_history', False),
workflow_app=self,
**kwargs
)
# wrap for Dapr
wrapped = self._make_task_wrapper(task_name, method, task_instance)
activity_decorator = self.wf_runtime.activity(name=task_name)
self.tasks[task_name] = activity_decorator(wrapped)
def _make_task_wrapper(
self,
task_name: str,
method: Callable,
task_instance: WorkflowTask
) -> Callable:
"""Produce the function that Dapr will invoke for each activity."""
def run_sync(coro):
try:
loop = asyncio.get_running_loop()
return loop.run_until_complete(coro)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop.run_until_complete(coro)
@functools.wraps(method)
def wrapper(ctx: WorkflowActivityContext, *args, **kwargs):
wf_ctx = WorkflowActivityContext(ctx)
try:
call = task_instance(wf_ctx, *args, **kwargs)
if asyncio.iscoroutine(call):
return run_sync(call)
return call
except Exception as e:
logger.exception(f"Task '{task_name}' failed")
raise
return wrapper
def _discover_workflows(self) -> Dict[str, Callable]:
"""Gather all @workflow-decorated functions and methods."""
module = sys.modules['__main__']
wfs: Dict[str, Callable] = {}
for name, fn in inspect.getmembers(module, inspect.isfunction):
if getattr(fn, '_is_workflow', False) and fn.__module__ == module.__name__:
wfs[getattr(fn, '_workflow_name', name)] = fn
for name, method in get_decorated_methods(self, '_is_workflow').items():
wfs[getattr(method, '_workflow_name', name)] = method
logger.info(f"Discovered workflows: {list(wfs)}")
return wfs
def _register_workflows(self, wfs: Dict[str, Callable]) -> None:
"""Register each discovered workflow with the Dapr runtime."""
for wf_name, method in wfs.items():
@functools.wraps(method)
def wrapped(*args, **kwargs):
return method(*args, **kwargs)
decorator = self.wf_runtime.workflow(name=wf_name)
self.workflows[wf_name] = decorator(wrapped)
def start_runtime(self):
"""Idempotently start the Dapr workflow runtime."""
if not self.wf_runtime_is_running:
logger.info("Starting workflow runtime.")
self.wf_runtime.start()
self.wf_runtime_is_running = True
else:
logger.debug("Workflow runtime already running; skipping.")
def stop_runtime(self):
"""Idempotently stop the Dapr workflow runtime."""
if self.wf_runtime_is_running:
logger.info("Stopping workflow runtime.")
self.wf_runtime.shutdown()
self.wf_runtime_is_running = False
else:
logger.debug("Workflow runtime already stopped; skipping.")
def register_agent(self, store_name: str, store_key: str, agent_name: str, agent_metadata: dict) -> None:
"""
Merges the existing data with the new data and updates the store.
@ -142,132 +263,6 @@ class WorkflowApp(BaseModel):
except Exception as e:
logger.warning(f"Error retrieving data for key '{key}' from store '{store_name}'")
return None
def register_all_tasks(self):
"""
Registers all collected tasks with Dapr while preserving execution logic.
"""
current_module = sys.modules["__main__"]
all_functions = {}
for name, func in inspect.getmembers(current_module, inspect.isfunction):
if hasattr(func, "_is_task") and func.__module__ == current_module.__name__:
task_name = getattr(func, "_task_name", None) or name
all_functions[task_name] = func
# Load instance methods that are tasks
task_methods = get_callable_decorated_methods(self, "_is_task")
for method_name, method in task_methods.items():
task_name = getattr(method, "_task_name", method_name)
all_functions[task_name] = method
logger.debug(f"Discovered tasks: {list(all_functions.keys())}")
def make_task_wrapper(method):
"""Creates a unique task wrapper bound to a specific method reference."""
# Extract stored metadata from the function
task_name = getattr(method, "_task_name", method.__name__)
explicit_llm = getattr(method, "_explicit_llm", False)
# Always initialize `llm` as `None` explicitly first
llm = None
# If task is explicitly LLM-based, but has no LLM, use `self.llm`
if explicit_llm and self.llm is not None:
llm = self.llm
task_kwargs = getattr(method, "_task_kwargs", {})
task_instance = WorkflowTask(
func=method,
description=getattr(method, "_task_description", None),
agent=getattr(method, "_task_agent", None),
llm=llm,
include_chat_history=getattr(method, "_task_include_chat_history", False),
workflow_app=self,
**task_kwargs
)
def run_in_event_loop(coroutine):
"""Ensures that an async function runs synchronously if needed."""
try:
loop = asyncio.get_running_loop()
return loop.run_until_complete(coroutine)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop.run_until_complete(coroutine)
@functools.wraps(method)
def task_wrapper(ctx: WorkflowActivityContext, *args, **kwargs):
"""Wrapper function for executing tasks in a Dapr workflow, handling both sync and async tasks."""
wf_ctx = WorkflowActivityContext(ctx)
try:
if inspect.iscoroutinefunction(method) or asyncio.iscoroutinefunction(task_instance.__call__):
return run_in_event_loop(task_instance(wf_ctx, *args, **kwargs))
else:
return task_instance(wf_ctx, *args, **kwargs)
except Exception as e:
raise RuntimeError(f"Task '{task_name}' execution failed: {e}")
return task_name, task_wrapper # Return both name and wrapper
for method in all_functions.values():
# Ensure function reference is properly preserved inside a function scope
task_name, task_wrapper = make_task_wrapper(method)
# Register the task with Dapr Workflow using the correct task name
activity_decorator = self.wf_runtime.activity(name=task_name)
registered_activity = activity_decorator(task_wrapper)
# Store task reference
self.tasks[task_name] = registered_activity
def register_all_workflows(self):
"""
Registers all workflow functions dynamically with Dapr.
"""
current_module = sys.modules["__main__"]
all_workflows = {}
# Load global-level workflow functions
for name, func in inspect.getmembers(current_module, inspect.isfunction):
if hasattr(func, "_is_workflow") and func.__module__ == current_module.__name__:
workflow_name = getattr(func, "_workflow_name", None) or name
all_workflows[workflow_name] = func
# Load instance methods that are workflows
workflow_methods = get_callable_decorated_methods(self, "_is_workflow")
for method_name, method in workflow_methods.items():
workflow_name = getattr(method, "_workflow_name", method_name)
all_workflows[workflow_name] = method
logger.info(f"Discovered workflows: {list(all_workflows.keys())}")
def make_workflow_wrapper(method):
"""Creates a wrapper to prevent pointer overwrites during workflow registration."""
workflow_name = getattr(method, "_workflow_name", method.__name__)
@functools.wraps(method)
def workflow_wrapper(*args, **kwargs):
"""Directly calls the method without modifying ctx injection (already handled)."""
try:
return method(*args, **kwargs)
except Exception as e:
raise RuntimeError(f"Workflow '{workflow_name}' execution failed: {e}")
return workflow_name, workflow_wrapper
for method in all_workflows.values():
workflow_name, workflow_wrapper = make_workflow_wrapper(method)
# Register the workflow with Dapr using the correct name
workflow_decorator = self.wf_runtime.workflow(name=workflow_name)
registered_workflow = workflow_decorator(workflow_wrapper)
# Store workflow reference
self.workflows[workflow_name] = registered_workflow
def resolve_task(self, task: Union[str, Callable]) -> Callable:
"""
@ -416,7 +411,7 @@ class WorkflowApp(BaseModel):
if state.serialized_output:
logger.debug(f"Output: {json.dumps(state.serialized_output, indent=2)}")
elif workflow_status == "FAILED":
elif workflow_status in ("FAILED", "ABORTED"):
# Ensure `failure_details` exists before accessing attributes
error_type = getattr(failure_details, "error_type", "Unknown")
message = getattr(failure_details, "message", "No message provided")
@ -430,6 +425,8 @@ class WorkflowApp(BaseModel):
f"Input: {json.dumps(state.serialized_input, indent=2)}"
)
self.terminate_workflow(instance_id)
else:
logger.warning(
f"Workflow '{instance_id}' ended with status '{workflow_status}'.\n"
@ -633,19 +630,4 @@ class WorkflowApp(BaseModel):
Returns:
dtask.WhenAnyTask: A task that completes when the first task finishes.
"""
return dtask.when_any(tasks)
def start_runtime(self):
"""
Starts the Dapr workflow runtime
"""
logger.info("Starting workflow runtime.")
self.wf_runtime.start()
def stop_runtime(self):
"""
Stops the Dapr workflow runtime.
"""
logger.info("Stopping workflow runtime.")
self.wf_runtime.shutdown()
return dtask.when_any(tasks)

View File

@ -1,6 +1,7 @@
import functools
import inspect
from typing import Any, Callable, Optional
import logging
from pydantic import BaseModel, ValidationError
@ -65,18 +66,31 @@ def task(
def decorator(f: Callable) -> Callable:
if not callable(f):
raise ValueError(f"@task must be applied to a function, got {type(f)}.")
# Attach task metadata for later consumption
f._is_task = True
f._task_name = name or f.__name__
f._task_description = description
f._task_agent = agent
f._task_llm = llm
# Attach task metadata
f._is_task = True
f._task_name = name or f.__name__
f._task_description = description
f._task_agent = agent
f._task_llm = llm
f._task_include_chat_history = include_chat_history
f._explicit_llm = llm is not None or bool(description)
f._task_kwargs = task_kwargs # Forward anything else (e.g., structured_mode)
return f
f._explicit_llm = llm is not None or bool(description)
f._task_kwargs = task_kwargs
# wrap it so we can log, validate, etc., without losing signature/docs
@functools.wraps(f)
def wrapper(*args, **kwargs):
logging.getLogger(__name__).debug(f"Calling task '{f._task_name}'")
return f(*args, **kwargs)
# copy our metadata onto the wrapper so discovery still sees it
for attr in (
"_is_task", "_task_name", "_task_description", "_task_agent",
"_task_llm", "_task_include_chat_history", "_explicit_llm", "_task_kwargs"
):
setattr(wrapper, attr, getattr(f, attr))
return wrapper
return decorator(func) if func else decorator # Supports both @task and @task(name="custom")
@ -146,6 +160,8 @@ def workflow(func: Optional[Callable] = None, *, name: Optional[str] = None) ->
def wrapper(*args, **kwargs):
"""Wrapper for handling input validation and execution."""
logging.getLogger(__name__).info(f"Starting workflow '{f._workflow_name}'")
bound_args = sig.bind_partial(*args, **kwargs)
bound_args.apply_defaults()
@ -179,6 +195,8 @@ def workflow(func: Optional[Callable] = None, *, name: Optional[str] = None) ->
return f(*bound_args.args, **bound_args.kwargs)
wrapper._is_workflow = True
wrapper._workflow_name = f._workflow_name
return wrapper
return decorator(func) if func else decorator # Supports both `@workflow` and `@workflow(name="custom")`

View File

@ -11,7 +11,7 @@ from dapr.clients.grpc.subscription import StreamInactiveError
from dapr.common.pubsub.subscription import StreamCancelledError, SubscriptionMessage
from dapr_agents.workflow.messaging.parser import extract_cloudevent_data, validate_message_model
from dapr_agents.workflow.messaging.utils import is_valid_routable_model
from dapr_agents.workflow.utils import get_callable_decorated_methods
from dapr_agents.workflow.utils import get_decorated_methods
logger = logging.getLogger(__name__)
@ -46,7 +46,7 @@ class MessageRoutingMixin:
- Wraps each handler and maps it by `(pubsub_name, topic_name)` and schema name.
- Ensures only one handler per schema per topic is allowed.
"""
message_handlers = get_callable_decorated_methods(self, "_is_message_handler")
message_handlers = get_decorated_methods(self, "_is_message_handler")
for method_name, method in message_handlers.items():
try:

View File

@ -1,9 +1,7 @@
from dapr_agents.workflow.orchestrators.base import OrchestratorWorkflowBase
from dapr_agents.types import DaprWorkflowContext, BaseMessage, EventMessageMetadata
from dapr_agents.types import DaprWorkflowContext, BaseMessage
from dapr_agents.workflow.decorators import workflow, task
from dapr_agents.workflow.messaging.decorator import message_router
from fastapi.responses import JSONResponse
from fastapi import Response, status
from typing import Any, Optional, Dict, Any
from datetime import timedelta
from pydantic import BaseModel, Field
@ -23,8 +21,9 @@ class TriggerAction(BaseModel):
"""
Represents a message used to trigger an agent's activity within the workflow.
"""
task: Optional[str] = Field(None, description="The specific task to execute. If not provided, the agent can act based on its memory or predefined behavior.")
iteration: Optional[int] = Field(default=0, description="The current iteration of the workflow loop.")
task: Optional[str] = Field(None, description="The specific task to execute. If not provided, the agent will act based on its memory or predefined behavior.")
iteration: Optional[int] = Field(0, description="")
workflow_instance_id: Optional[str] = Field(default=None, description="Dapr workflow instance id from source if available")
class RandomOrchestrator(OrchestratorWorkflowBase):
"""
@ -170,29 +169,35 @@ class RandomOrchestrator(OrchestratorWorkflowBase):
name (str): Name of the agent to trigger.
instance_id (str): Workflow instance ID for context.
"""
logger.info(f"Triggering agent {name} (Instance ID: {instance_id})")
await self.send_message_to_agent(
name=name,
message=TriggerAction(task=None),
workflow_instance_id=instance_id,
message=TriggerAction(workflow_instance_id=instance_id),
)
@message_router
async def process_agent_response(self, message: AgentTaskResponse, metadata: EventMessageMetadata) -> Response:
async def process_agent_response(self, message: AgentTaskResponse):
"""
Processes agent response messages sent directly to the agent's topic.
Args:
message (AgentTaskResponse): The agent's response containing task results.
metadata (EventMessageMetadata): Metadata associated with the message, including headers.
Returns:
Response: A JSON response confirming the workflow event was successfully triggered.
None: The function raises a workflow event with the agent's response.
"""
agent_response = (message).model_dump()
workflow_instance_id = metadata.headers.get("workflow_instance_id")
event_name = metadata.headers.get("event_name", "AgentTaskResponse")
try:
workflow_instance_id = message.get("workflow_instance_id")
# Raise a workflow event with the Agent's Task Response!
self.raise_workflow_event(instance_id=workflow_instance_id, event_name=event_name, data=agent_response)
if not workflow_instance_id:
logger.error(f"{self.name} received an agent response without a valid workflow_instance_id. Ignoring.")
return
return JSONResponse(content={"message": "Workflow event triggered successfully."}, status_code=status.HTTP_200_OK)
logger.info(f"{self.name} processing agent response for workflow instance '{workflow_instance_id}'.")
# Raise a workflow event with the Agent's Task Response
self.raise_workflow_event(instance_id=workflow_instance_id, event_name="AgentTaskResponse", data=message)
except Exception as e:
logger.error(f"Error processing agent response: {e}", exc_info=True)

View File

@ -1,9 +1,7 @@
from dapr_agents.workflow.messaging.decorator import message_router
from dapr_agents.workflow.orchestrators.base import OrchestratorWorkflowBase
from dapr_agents.types import DaprWorkflowContext, BaseMessage, EventMessageMetadata
from dapr_agents.types import DaprWorkflowContext, BaseMessage
from dapr_agents.workflow.decorators import workflow, task
from fastapi.responses import JSONResponse
from fastapi import Response, status
from typing import Any, Optional, Dict
from pydantic import BaseModel, Field
from datetime import timedelta
@ -21,8 +19,9 @@ class TriggerAction(BaseModel):
"""
Represents a message used to trigger an agent's activity within the workflow.
"""
task: Optional[str] = None
iteration: Optional[int] = 0
task: Optional[str] = Field(None, description="The specific task to execute. If not provided, the agent will act based on its memory or predefined behavior.")
iteration: Optional[int] = Field(0, description="")
workflow_instance_id: Optional[str] = Field(default=None, description="Dapr workflow instance id from source if available")
class RoundRobinOrchestrator(OrchestratorWorkflowBase):
"""
@ -161,30 +160,31 @@ class RoundRobinOrchestrator(OrchestratorWorkflowBase):
"""
await self.send_message_to_agent(
name=name,
message=TriggerAction(task=None),
workflow_instance_id=instance_id,
message=TriggerAction(workflow_instance_id=instance_id),
)
@message_router
async def process_agent_response(self, message: AgentTaskResponse,
metadata: EventMessageMetadata) -> Response:
async def process_agent_response(self, message: AgentTaskResponse):
"""
Processes agent response messages sent directly to the agent's topic.
Args:
message (AgentTaskResponse): The agent's response containing task results.
metadata (EventMessageMetadata): Metadata associated with the message, including headers.
Returns:
Response: A JSON response confirming the workflow event was successfully triggered.
None: The function raises a workflow event with the agent's response.
"""
agent_response = (message).model_dump()
workflow_instance_id = metadata.headers.get("workflow_instance_id")
event_name = metadata.headers.get("event_name", "AgentTaskResponse")
try:
workflow_instance_id = message.get("workflow_instance_id")
# Raise a workflow event with the Agent's Task Response!
self.raise_workflow_event(instance_id=workflow_instance_id, event_name=event_name,
data=agent_response)
if not workflow_instance_id:
logger.error(f"{self.name} received an agent response without a valid workflow_instance_id. Ignoring.")
return
return JSONResponse(content={"message": "Workflow event triggered successfully."},
status_code=status.HTTP_200_OK)
logger.info(f"{self.name} processing agent response for workflow instance '{workflow_instance_id}'.")
# Raise a workflow event with the Agent's Task Response
self.raise_workflow_event(instance_id=workflow_instance_id, event_name="AgentTaskResponse", data=message)
except Exception as e:
logger.error(f"Error processing agent response: {e}", exc_info=True)

View File

@ -4,7 +4,7 @@ import logging
from dataclasses import is_dataclass
from functools import update_wrapper
from types import SimpleNamespace
from typing import Any, Callable, Dict, List, Literal, Optional, Union
from typing import Any, Callable, Dict, List, Literal, Optional
from pydantic import BaseModel, ConfigDict, Field
@ -44,249 +44,232 @@ class WorkflowTask(BaseModel):
"""
Post-initialization to set up function signatures and default LLM clients.
"""
# Default to OpenAIChatClient if promptbased but no llm provided
if self.description and not self.llm:
self.llm = OpenAIChatClient()
if self.func:
# Preserve name / docs for stack traces
update_wrapper(self, self.func)
# Capture signature for input / output handling
self.signature = inspect.signature(self.func) if self.func else None
# Honor any structured_mode override
if not self.structured_mode and "structured_mode" in self.task_kwargs:
self.structured_mode = self.task_kwargs["structured_mode"]
# Proceed with base model setup
super().model_post_init(__context)
async def __call__(self, ctx: WorkflowActivityContext, input: Any = None) -> Any:
async def __call__(self, ctx: WorkflowActivityContext, payload: Any = None) -> Any:
"""
Executes the task and validates its output.
Ensures all coroutines are awaited before returning.
Executes the task, routing to agent, LLM, or pure-Python logic.
Dispatches to Python, Agent, or LLM paths and validates output.
Args:
ctx (WorkflowActivityContext): The workflow execution context.
input (Any): The task input.
payload (Any): The task input.
Returns:
Any: The result of the task.
"""
input = self._normalize_input(input) if input is not None else {}
# Prepare input dict
data = self._normalize_input(payload) if payload is not None else {}
logger.info(f"Executing task '{self.func.__name__}'")
logger.debug(f"Executing task '{self.func.__name__}' with input {data!r}")
try:
if self.agent or self.llm:
executor = self._choose_executor()
if executor in ("agent", "llm"):
if not self.description:
raise ValueError(f"Task {self.func.__name__} is LLM-based but has no description!")
result = await self._run_task(self.format_description(self.description, input))
result = await self._validate_output(result)
elif self.func:
# Task is a Python function
logger.info(f"Invoking Regular Task")
if asyncio.iscoroutinefunction(self.func):
# Await async function
result = await self.func(**input)
else:
# Call sync function
result = self._execute_function(input)
result = await self._validate_output(result)
raise ValueError("LLM/agent tasks require a description template")
prompt = self.format_description(self.description, data)
raw = await self._run_via_ai(prompt, executor)
else:
raise ValueError("Task must have an LLM, agent, or regular function for execution.")
return result
except Exception as e:
logger.error(f"Task execution error: {e}")
raw = await self._run_python(data)
validated = await self._validate_output(raw)
return validated
except Exception:
logger.exception(f"Error in task '{self.func.__name__}'")
raise
def _normalize_input(self, input: Any) -> dict:
def _choose_executor(self) -> Literal["agent", "llm", "python"]:
"""
Converts input into a normalized dictionary.
Args:
input (Any): Input to normalize (e.g., dictionary, dataclass, or object).
Pick execution path.
Returns:
dict: Normalized dictionary representation of the input.
"""
if is_dataclass(input):
return input.__dict__
elif isinstance(input, SimpleNamespace):
return vars(input)
elif not isinstance(input, dict):
return self._single_value_to_dict(input)
return input
def _single_value_to_dict(self, value: Any) -> dict:
"""
Wraps a single input value in a dictionary.
Args:
value (Any): Single input value.
Returns:
dict: Dictionary with parameter name as the key.
One of "agent", "llm", or "python".
Raises:
ValueError: If no function signature is available.
ValueError: If no valid executor is configured.
"""
if not self.signature:
raise ValueError("Cannot convert single input to dict: function signature is missing.")
param_name = list(self.signature.parameters.keys())[0]
return {param_name: value}
def format_description(self, description: str, input: dict) -> str:
"""
Formats a description string with input parameters.
Args:
description (str): Description template.
input (dict): Input parameters for formatting.
Returns:
str: Formatted description string.
"""
if self.signature:
bound_args = self.signature.bind(**input)
bound_args.apply_defaults()
return description.format(**bound_args.arguments)
return description.format(**input)
async def _run_task(self, formatted_description: str) -> Any:
"""
Determine whether to run the task using an agent or an LLM.
Args:
formatted_description (str): The formatted description to pass to the agent or LLM.
Returns:
Any: The result of the agent or LLM execution.
Raises:
ValueError: If neither an agent nor an LLM is provided.
"""
logger.debug(f"Task Description: {formatted_description}")
if self.agent:
return await self._run_agent(formatted_description)
elif self.llm:
return await self._run_llm(formatted_description)
return "agent"
if self.llm:
return "llm"
if self.func:
return "python"
raise ValueError("No execution path found for this task")
async def _run_python(self, data: dict) -> Any:
"""
Invoke the Python function directly.
Args:
data: Keyword arguments for the function.
Returns:
The function's return value.
"""
logger.info("Invoking regular Python function")
if asyncio.iscoroutinefunction(self.func):
return await self.func(**data)
else:
raise ValueError("No agent or LLM provided.")
async def _run_agent(self, description: str) -> Any:
return self.func(**data)
async def _run_via_ai(self, prompt: str, executor: Literal["agent", "llm"]) -> Any:
"""
Execute the task using the provided agent.
Run the prompt through an Agent or LLM.
Args:
description (str): The formatted description to pass to the agent.
prompt: The fully formatted prompt string.
kind: "agent" or "llm".
Returns:
Any: The result of the agent execution.
Raw result from the AI path.
"""
logger.info("Invoking Task with AI Agent...")
result = await self.agent.run(description)
logger.debug(f"Agent result type: {type(result)}, value: {result}")
logger.info(f"Invoking task via {executor.upper()}")
logger.debug(f"Invoking task with prompt: {prompt!r}")
if executor == "agent":
result = await self.agent.run(prompt)
else:
result = await self._invoke_llm(prompt)
return self._convert_result(result)
async def _run_llm(self, description: Union[str, List[BaseMessage]]) -> Any:
logger.info("Invoking Task with LLM...")
async def _invoke_llm(self, prompt: str) -> Any:
"""
Build messages and call the LLM client.
# 1. Get chat history if enabled
conversation_history = []
Args:
prompt: The formatted prompt string.
Returns:
LLM-generated result.
"""
# Gather history if needed
history: List[BaseMessage] = []
if self.include_chat_history and self.workflow_app:
logger.info("Retrieving conversation history...")
conversation_history = self.workflow_app.get_chat_history()
logger.debug(f"Conversation history retrieved: {conversation_history}")
logger.debug("Retrieving chat history")
history = self.workflow_app.get_chat_history()
# 2. Convert string input to structured messages
if isinstance(description, str):
description = [UserMessage(description)]
llm_messages = conversation_history + description
messages: List[BaseMessage] = history + [UserMessage(prompt)]
params: Dict[str, Any] = {"messages": messages}
# 3. Base LLM parameters
llm_params = {"messages": llm_messages}
# 4. Add structured response config if a valid Pydantic model is the return type
# Add structured formatting if return type is a Pydantic model
if self.signature and self.signature.return_annotation is not inspect.Signature.empty:
return_type = self.signature.return_annotation
model_cls = StructureHandler.resolve_response_model(return_type)
# Only proceed if we resolved a Pydantic model
model_cls = StructureHandler.resolve_response_model(
self.signature.return_annotation
)
if model_cls:
if not hasattr(self.llm, "provider"):
raise AttributeError(
f"{type(self.llm).__name__} is missing the `.provider` attribute — required for structured response generation."
)
params["response_format"] = self.signature.return_annotation
params["structured_mode"] = self.structured_mode
logger.debug(f"Using LLM provider: {self.llm.provider}")
llm_params["response_format"] = return_type
llm_params["structured_mode"] = self.structured_mode or "json"
# 5. Call the LLM client
result = self.llm.generate(**llm_params)
logger.debug(f"LLM result type: {type(result)}, value: {result}")
return self._convert_result(result)
def _convert_result(self, result: Any) -> Any:
"""
Convert the task result to a dictionary if necessary.
Args:
result (Any): The raw task result.
Returns:
Any: The converted result.
"""
if isinstance(result, ChatCompletion):
logger.debug("Extracted message content from ChatCompletion.")
return result.get_content()
if isinstance(result, BaseModel):
logger.debug("Converting Pydantic model to dictionary.")
return result.model_dump()
if isinstance(result, list) and all(isinstance(item, BaseModel) for item in result):
logger.debug("Converting list of Pydantic models to list of dictionaries.")
return [item.model_dump() for item in result]
# If no specific conversion is necessary, return as-is
logger.info("Returning final task result.")
return result
logger.debug(f"LLM call params: {params}")
return self.llm.generate(**params)
def _execute_function(self, input: dict) -> Any:
def _normalize_input(self, raw_input: Any) -> dict:
"""
Execute the wrapped function with the provided input.
Normalize various input types into a dict.
Args:
input (dict): The input data to pass to the function.
raw_input: Dataclass, SimpleNamespace, single value, or dict.
Returns:
Any: The result of the function execution.
A dict suitable for function invocation.
Raises:
ValueError: If signature is missing when wrapping a single value.
"""
return self.func(**input)
if is_dataclass(raw_input):
return raw_input.__dict__
if isinstance(raw_input, SimpleNamespace):
return vars(raw_input)
if not isinstance(raw_input, dict):
# wrap single argument
if not self.signature:
raise ValueError("Cannot infer param name without signature")
name = next(iter(self.signature.parameters))
return {name: raw_input}
return raw_input
async def _validate_output(self, result: Any) -> Any:
"""
Validates the output of the task against the expected return type.
Await and validate the result against return-type model.
Supports coroutine outputs and structured type validation.
Args:
result: Raw result from executor.
Returns:
Any: The validated and potentially transformed result.
Validated/transformed result.
"""
if asyncio.iscoroutine(result):
logger.warning("Result is a coroutine; awaiting.")
result = await result
if not self.signature or self.signature.return_annotation is inspect.Signature.empty:
if (
not self.signature
or self.signature.return_annotation is inspect.Signature.empty
):
return result
expected_type = self.signature.return_annotation
return StructureHandler.validate_against_signature(result, expected_type)
return StructureHandler.validate_against_signature(
result, self.signature.return_annotation
)
def _convert_result(self, result: Any) -> Any:
"""
Unwrap AI return types into plain Python.
Args:
result: ChatCompletion, BaseModel, or list of BaseModel.
Returns:
A primitive, dict, or list of dicts.
"""
# Unwrap ChatCompletion
if isinstance(result, ChatCompletion):
logger.debug("Extracted message content from ChatCompletion.")
return result.get_content()
# Pydantic → dict
if isinstance(result, BaseModel):
logger.debug("Converting Pydantic model to dictionary.")
return result.model_dump()
if isinstance(result, list) and all(isinstance(x, BaseModel) for x in result):
logger.debug("Converting list of Pydantic models to list of dictionaries.")
return [x.model_dump() for x in result]
# If no specific conversion is necessary, return as-is
logger.info("Returning final task result.")
return result
def format_description(self, template: str, data: dict) -> str:
"""
Interpolate inputs into the prompt template.
Args:
template: The `{}`-style template string.
data: Mapping of variable names to values.
Returns:
The fully formatted prompt.
"""
if self.signature:
bound = self.signature.bind(**data)
bound.apply_defaults()
return template.format(**bound.arguments)
return template.format(**data)
class TaskWrapper:
"""

View File

@ -1,35 +1,53 @@
import inspect
import logging
from typing import Any, Callable, Dict
logger = logging.getLogger(__name__)
def get_callable_decorated_methods(instance, decorator_attr: str) -> dict:
def get_decorated_methods(instance: Any, attribute_name: str) -> Dict[str, Callable]:
"""
Safely retrieves all instance methods decorated with a specific attribute (e.g. `_is_task`, `_is_workflow`).
Find all **public** bound methods on `instance` that carry a given decorator attribute.
This will:
1. Inspect the class for functions or methods.
2. Bind them to the instance (so `self` is applied).
3. Filter in only those where `hasattr(method, attribute_name) is True`.
Args:
instance: The class instance to inspect.
decorator_attr (str): The attribute name set by a decorator (e.g. "_is_task").
instance: Any object whose methods you want to inspect.
attribute_name:
The name of the attribute set by your decorator
(e.g. "_is_task" or "_is_workflow").
Returns:
dict: Mapping of method names to bound method callables.
A dict mapping `method_name` `bound method`.
Example:
>>> class A:
... @task
... def foo(self): ...
...
>>> get_decorated_methods(A(), "_is_task")
{"foo": <bound method A.foo of <A object ...>>}
"""
discovered = {}
for method_name in dir(instance):
if method_name.startswith("_"):
continue # Skip private/protected
discovered: Dict[str, Callable] = {}
raw_attr = getattr(type(instance), method_name, None)
if not (inspect.isfunction(raw_attr) or inspect.ismethod(raw_attr)):
continue # Skip non-methods (e.g., @property)
try:
method = getattr(instance, method_name)
except Exception as e:
logger.warning(f"Skipping method '{method_name}' due to error: {e}")
cls = type(instance)
for name, member in inspect.getmembers(cls, predicate=inspect.isfunction):
# skip private/protected
if name.startswith("_"):
continue
if hasattr(method, decorator_attr):
discovered[method_name] = method
# bind to instance so that signature(self, ...) works
try:
bound = getattr(instance, name)
except Exception as e:
logger.warning(f"Could not bind method '{name}': {e}")
continue
# pick up only those with our decorator flag
if hasattr(bound, attribute_name):
discovered[name] = bound
logger.debug(f"Discovered decorated method: {name}")
return discovered

View File

@ -27,13 +27,13 @@ classifiers = [
dependencies = [
"durabletask-dapr >= 0.2.0a7",
"pydantic == 2.10.5",
"openai == 1.59.6",
"pydantic == 2.11.3",
"openai == 1.75.0",
"openapi-pydantic == 0.5.1",
"openapi-schema-pydantic==1.2.4",
"regex >= 2023.12.25",
"Jinja2 >= 3.1.6",
"azure-identity == 1.19.0",
"azure-identity == 1.21.0",
"dapr >= 1.15.0",
"dapr-ext-fastapi == 1.15.0",
"dapr-ext-workflow == 1.15.0",
@ -41,7 +41,7 @@ dependencies = [
"cloudevents == 1.11.0",
"pyyaml == 6.0.2",
"rich == 13.9.4",
"huggingface_hub == 0.27.1",
"huggingface_hub == 0.30.2",
"numpy == 2.2.2",
]

View File

@ -1,18 +1,24 @@
import asyncio
from dapr_agents import tool, Agent
from dotenv import load_dotenv
load_dotenv()
@tool
def my_weather_func() -> str:
"""Get current weather."""
return "It's 72°F and sunny"
weather_agent = Agent(
name="WeatherAgent",
role="Weather Assistant",
instructions=["Help users with weather information"],
tools=[my_weather_func]
)
async def main():
weather_agent = Agent(
name="WeatherAgent",
role="Weather Assistant",
instructions=["Help users with weather information"],
tools=[my_weather_func]
)
response = weather_agent.run("What's the weather?")
print(response)
response = await weather_agent.run("What's the weather?")
print(response)
if __name__ == "__main__":
asyncio.run(main())

View File

@ -1,7 +1,9 @@
import asyncio
from dapr_agents import tool, ReActAgent
from dotenv import load_dotenv
load_dotenv()
@tool
def search_weather(city: str) -> str:
"""Get weather information for a city."""
@ -14,14 +16,17 @@ def get_activities(weather: str) -> str:
activities = {"rainy": "Visit museums", "sunny": "Go hiking"}
return activities.get(weather.lower(), "Stay comfortable")
react_agent = ReActAgent(
name="TravelAgent",
role="Travel Assistant",
instructions=["Check weather, then suggest activities"],
tools=[search_weather, get_activities]
)
async def main():
react_agent = ReActAgent(
name="TravelAgent",
role="Travel Assistant",
instructions=["Check weather, then suggest activities"],
tools=[search_weather, get_activities]
)
result = react_agent.run("What should I do in London today?")
result = await react_agent.run("What should I do in London today?")
if result:
print("Result:", result)
if len(result) > 0:
print ("Result:", result)
if __name__ == "__main__":
asyncio.run(main())

View File

@ -92,24 +92,30 @@ python 02_build_agent.py
This example shows how to create a basic agent with a custom tool:
```python
import asyncio
from dapr_agents import tool, Agent
from dotenv import load_dotenv
load_dotenv()
@tool
def my_weather_func() -> str:
"""Get current weather."""
return "It's 72°F and sunny"
weather_agent = Agent(
name="WeatherAgent",
role="Weather Assistant",
instructions=["Help users with weather information"],
tools=[my_weather_func]
)
async def main():
weather_agent = Agent(
name="WeatherAgent",
role="Weather Assistant",
instructions=["Help users with weather information"],
tools=[my_weather_func]
)
response = weather_agent.run("What's the weather?")
print(response)
response = await weather_agent.run("What's the weather?")
print(response)
if __name__ == "__main__":
asyncio.run(main())
```
**Expected output:** The agent will use the weather tool to provide the current weather.
@ -141,10 +147,12 @@ python 03_reason_act.py
<!-- END_STEP -->
```python
import asyncio
from dapr_agents import tool, ReActAgent
from dotenv import load_dotenv
load_dotenv()
@tool
def search_weather(city: str) -> str:
"""Get weather information for a city."""
@ -154,17 +162,23 @@ def search_weather(city: str) -> str:
@tool
def get_activities(weather: str) -> str:
"""Get activity recommendations."""
activities = {"rainy": "Visit museums", "Sunny": "Go hiking"}
activities = {"rainy": "Visit museums", "sunny": "Go hiking"}
return activities.get(weather.lower(), "Stay comfortable")
react_agent = ReActAgent(
name="TravelAgent",
role="Travel Assistant",
instructions=["Check weather, then suggest activities"],
tools=[search_weather, get_activities]
)
async def main():
react_agent = ReActAgent(
name="TravelAgent",
role="Travel Assistant",
instructions=["Check weather, then suggest activities"],
tools=[search_weather, get_activities]
)
react_agent.run("What should I do in London today?")
result = await react_agent.run("What should I do in London today?")
if result:
print("Result:", result)
if __name__ == "__main__":
asyncio.run(main())
```
**Expected output:** The agent will first check the weather in London, find it's rainy, and then recommend visiting museums.

View File

@ -70,7 +70,7 @@ if len(response.get_content()) > 0:
print("Response with prompty: ", response.get_content())
# Chat completion with user input
llm = HFHubChatClient()
llm = HFHubChatClient(model="microsoft/Phi-3-mini-4k-instruct")
response = llm.generate(messages=[UserMessage("hello")])
print("Response with user input: ", response.get_content())

View File

@ -22,7 +22,7 @@ if len(response.get_content()) > 0:
print("Response with prompty: ", response.get_content())
# Chat completion with user input
llm = HFHubChatClient()
llm = HFHubChatClient(model="microsoft/Phi-3-mini-4k-instruct")
response = llm.generate(messages=[UserMessage("hello")])

View File

@ -80,31 +80,38 @@ from dotenv import load_dotenv
import asyncio
import logging
async def main():
try:
# Define Agent
hobbit_agent = Agent(
role="Hobbit",
name="Frodo",
goal="Take the ring to Mordor",
instructions=["Speak like Frodo"]
)
# Expose Agent as an Actor Service
hobbit_service = AgentActor(
hobbit_agent = Agent(role="Hobbit", name="Frodo",
goal="Carry the One Ring to Mount Doom, resisting its corruptive power while navigating danger and uncertainty.",
instructions=[
"Speak like Frodo, with humility, determination, and a growing sense of resolve.",
"Endure hardships and temptations, staying true to the mission even when faced with doubt.",
"Seek guidance and trust allies, but bear the ultimate burden alone when necessary.",
"Move carefully through enemy-infested lands, avoiding unnecessary risks.",
"Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task."])
# Expose Agent as an Actor over a Service
hobbit_actor = AgentActor(
agent=hobbit_agent,
message_bus_name="messagepubsub",
agents_state_store_name="agentstatestore",
service_port=8001,
agents_registry_store_name="agentstatestore",
agents_registry_key="agents_registry",
service_port=8001
)
await hobbit_service.start()
await hobbit_actor.start()
except Exception as e:
print(f"Error starting service: {e}")
print(f"Error starting actor: {e}")
if __name__ == "__main__":
load_dotenv()
logging.basicConfig(level=logging.INFO)
asyncio.run(main())
```
@ -120,24 +127,29 @@ from dotenv import load_dotenv
import asyncio
import logging
async def main():
try:
random_workflow = RandomOrchestrator(
name="RandomOrchestrator",
message_bus_name="messagepubsub",
state_store_name="agenticworkflowstate",
state_store_name="workflowstatestore",
state_key="workflow_state",
agents_registry_store_name="agentstatestore",
agents_registry_key="agents_registry",
max_iterations=3
).as_service(port=8004)
await random_workflow.start()
except Exception as e:
print(f"Error starting workflow: {e}")
if __name__ == "__main__":
load_dotenv()
logging.basicConfig(level=logging.INFO)
asyncio.run(main())
```

View File

@ -1,10 +1,10 @@
pydantic==2.10.5
openai==1.59.6
pydantic==2.11.3
openai==1.75.0
openapi-pydantic==0.5.1
openapi-schema-pydantic==1.2.4
regex>=2023.12.25
Jinja2>=3.1.6
azure-identity==1.19.0
azure-identity==1.21.0
dapr>=1.15.0
dapr-ext-fastapi==1.15.0
dapr-ext-workflow==1.15.0
@ -12,6 +12,6 @@ colorama==0.4.6
cloudevents==1.11.0
pyyaml==6.0.2
rich==13.9.4
huggingface_hub==0.27.1
huggingface_hub==0.30.2
numpy==2.2.2
mcp==1.6.0