mirror of https://github.com/dapr/dapr-agents.git
Hierarchical LLM config on Tasks + workflow/ decorator refactor (#92)
* updated dependencies versions * Updated model validator for HuggingFace-Hub client to catch model and hub url early * split task registration into discover + register phases, improved LLM client init in tasks and workflow wrappers * improve decorator to access method attributes * wrapping workflow decotrators to log, validate, etc., without losing signature/docs * Improved LLM-based task client and cleaned execution of LLM, agent and python function * Added an example of multiple models being defined per workflow task after updates * Updated quickstarts basic agent runs to async * Added model attribute to huggingface-hub client class * Fixed random and roundrobin orchestrators TriggerAction schema, trigger action and task to process agent response * Updated quickstart multi-agent workflows and actor-based agents docs * added .dapr to gitignore
This commit is contained in:
parent
b939d7d2f5
commit
099dc5d2fb
|
@ -2,6 +2,7 @@
|
|||
.DS_Store
|
||||
secrets.json
|
||||
test
|
||||
.dapr
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
from dapr_agents import OpenAIChatClient, NVIDIAChatClient
|
||||
from dapr_agents.types import DaprWorkflowContext
|
||||
from dapr_agents. workflow import WorkflowApp, task, workflow
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
import logging
|
||||
|
||||
load_dotenv()
|
||||
|
||||
nvidia_llm = NVIDIAChatClient(
|
||||
model="meta/llama-3.1-8b-instruct",
|
||||
api_key=os.getenv("NVIDIA_API_KEY")
|
||||
)
|
||||
|
||||
oai_llm = OpenAIChatClient(
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
model="gpt-4o",
|
||||
base_url=os.getenv("OPENAI_API_BASE_URL"),
|
||||
)
|
||||
|
||||
azoai_llm = OpenAIChatClient(
|
||||
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
|
||||
azure_deployment="gpt-4o-mini",
|
||||
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
|
||||
azure_api_version="2024-12-01-preview"
|
||||
)
|
||||
|
||||
|
||||
@workflow
|
||||
def test_workflow(ctx: DaprWorkflowContext):
|
||||
"""
|
||||
A simple workflow that uses a multi-modal task chain.
|
||||
"""
|
||||
oai_results = yield ctx.call_activity(invoke_oai, input="Peru")
|
||||
azoai_results = yield ctx.call_activity(invoke_azoai, input=oai_results)
|
||||
nvidia_results = yield ctx.call_activity(invoke_nvidia, input=azoai_results)
|
||||
return nvidia_results
|
||||
|
||||
@task(description="What is the name of the capital of {country}?. Reply with just the name.", llm=oai_llm)
|
||||
def invoke_oai(country: str) -> str:
|
||||
pass
|
||||
|
||||
@task(description="What is a famous thing about {capital}?", llm=azoai_llm)
|
||||
def invoke_azoai(capital: str) -> str:
|
||||
pass
|
||||
|
||||
@task(description="Context: {context}. From the previous context. Pick one thing to do.", llm=nvidia_llm)
|
||||
def invoke_nvidia(context: str) -> str:
|
||||
pass
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
wfapp = WorkflowApp()
|
||||
|
||||
results = wfapp.run_and_monitor_workflow(workflow=test_workflow)
|
||||
|
||||
logging.info("Workflow results: %s", results)
|
||||
logging.info("Workflow completed successfully.")
|
|
@ -52,9 +52,21 @@ class HFHubInferenceClientBase(LLMClientBase):
|
|||
|
||||
values['api_key'] = api_key
|
||||
|
||||
# Ensure mutual exclusivity of `model` and `base_url`
|
||||
# mutual‑exclusivity
|
||||
if model is not None and base_url is not None:
|
||||
raise ValueError("Cannot provide both 'model' and 'base_url'. They are mutually exclusive.")
|
||||
raise ValueError("Cannot provide both 'model' and 'base_url'.")
|
||||
|
||||
# require at least one
|
||||
if model is None and base_url is None:
|
||||
raise ValueError(
|
||||
"HF Inference needs either `model` or `base_url`. "
|
||||
"E.g. model='gpt2' or base_url='https://…/models/gpt2'."
|
||||
)
|
||||
|
||||
# auto‑derive model from base_url
|
||||
if model is None:
|
||||
derived = base_url.rstrip("/").split("/")[-1]
|
||||
values["model"] = derived
|
||||
|
||||
return values
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ from dapr.ext.workflow.workflow_state import WorkflowState
|
|||
from dapr_agents.llm.chat import ChatClientBase
|
||||
from dapr_agents.types.workflow import DaprWorkflowStatus
|
||||
from dapr_agents.workflow.task import WorkflowTask
|
||||
from dapr_agents.workflow.utils import get_callable_decorated_methods
|
||||
from dapr_agents.workflow.utils import get_decorated_methods
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -48,27 +48,148 @@ class WorkflowApp(BaseModel):
|
|||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
"""
|
||||
Post-initialization configuration for the WorkflowApp.
|
||||
|
||||
Initializes the Dapr Workflow runtime, client, and state store, and ensures
|
||||
that workflows and tasks are registered.
|
||||
Initialize the Dapr workflow runtime and register tasks & workflows.
|
||||
"""
|
||||
|
||||
# Initialize WorkflowRuntime and DaprWorkflowClient
|
||||
# initialize clients and runtime
|
||||
self.wf_runtime = WorkflowRuntime()
|
||||
self.wf_runtime_is_running = False
|
||||
self.wf_client = DaprWorkflowClient()
|
||||
self.client = DaprClient()
|
||||
logger.info("WorkflowApp initialized; discovering tasks and workflows.")
|
||||
|
||||
logger.info(f"Initialized WorkflowApp.")
|
||||
# Discover and register
|
||||
discovered_tasks = self._discover_tasks()
|
||||
self._register_tasks(discovered_tasks)
|
||||
discovered_wfs = self._discover_workflows()
|
||||
self._register_workflows(discovered_wfs)
|
||||
|
||||
# Register workflows and tasks after the instance is created
|
||||
self.register_all_workflows()
|
||||
self.register_all_tasks()
|
||||
|
||||
# Proceed with base model setup
|
||||
super().model_post_init(__context)
|
||||
|
||||
def get_chat_history(self) -> List[Any]:
|
||||
"""
|
||||
Stub for fetching past conversation history. Override in subclasses.
|
||||
"""
|
||||
logger.debug("Fetching chat history (default stub)")
|
||||
return []
|
||||
|
||||
def _choose_llm_for(self, method: Callable) -> Optional[ChatClientBase]:
|
||||
"""
|
||||
Encapsulate LLM selection logic.
|
||||
1. Use per-task override if provided on decorator.
|
||||
2. Else if description-based, fall back to default app LLM.
|
||||
3. Otherwise, None.
|
||||
"""
|
||||
per_task = getattr(method, '_task_llm', None)
|
||||
if per_task:
|
||||
return per_task
|
||||
if getattr(method, '_explicit_llm', False):
|
||||
return self.llm
|
||||
return None
|
||||
|
||||
def _discover_tasks(self) -> Dict[str, Callable]:
|
||||
"""Gather all @task-decorated functions and methods."""
|
||||
module = sys.modules['__main__']
|
||||
tasks: Dict[str, Callable] = {}
|
||||
# free functions
|
||||
for name, fn in inspect.getmembers(module, inspect.isfunction):
|
||||
if getattr(fn, '_is_task', False) and fn.__module__ == module.__name__:
|
||||
tasks[getattr(fn, '_task_name', name)] = fn
|
||||
# bound methods
|
||||
for name, method in get_decorated_methods(self, '_is_task').items():
|
||||
tasks[getattr(method, '_task_name', name)] = method
|
||||
logger.debug(f"Discovered tasks: {list(tasks)}")
|
||||
return tasks
|
||||
|
||||
def _register_tasks(self, tasks: Dict[str, Callable]) -> None:
|
||||
"""Register each discovered task with the Dapr runtime."""
|
||||
for task_name, method in tasks.items():
|
||||
llm = self._choose_llm_for(method)
|
||||
logger.debug(f"Registering task '{task_name}' with llm={getattr(llm, '__class__', None)}")
|
||||
kwargs = getattr(method, '_task_kwargs', {})
|
||||
task_instance = WorkflowTask(
|
||||
func=method,
|
||||
description=getattr(method, '_task_description', None),
|
||||
agent=getattr(method, '_task_agent', None),
|
||||
llm=llm,
|
||||
include_chat_history=getattr(method, '_task_include_chat_history', False),
|
||||
workflow_app=self,
|
||||
**kwargs
|
||||
)
|
||||
# wrap for Dapr
|
||||
wrapped = self._make_task_wrapper(task_name, method, task_instance)
|
||||
activity_decorator = self.wf_runtime.activity(name=task_name)
|
||||
self.tasks[task_name] = activity_decorator(wrapped)
|
||||
|
||||
def _make_task_wrapper(
|
||||
self,
|
||||
task_name: str,
|
||||
method: Callable,
|
||||
task_instance: WorkflowTask
|
||||
) -> Callable:
|
||||
"""Produce the function that Dapr will invoke for each activity."""
|
||||
def run_sync(coro):
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
return loop.run_until_complete(coro)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
return loop.run_until_complete(coro)
|
||||
|
||||
@functools.wraps(method)
|
||||
def wrapper(ctx: WorkflowActivityContext, *args, **kwargs):
|
||||
wf_ctx = WorkflowActivityContext(ctx)
|
||||
try:
|
||||
call = task_instance(wf_ctx, *args, **kwargs)
|
||||
if asyncio.iscoroutine(call):
|
||||
return run_sync(call)
|
||||
return call
|
||||
except Exception as e:
|
||||
logger.exception(f"Task '{task_name}' failed")
|
||||
raise
|
||||
return wrapper
|
||||
|
||||
def _discover_workflows(self) -> Dict[str, Callable]:
|
||||
"""Gather all @workflow-decorated functions and methods."""
|
||||
module = sys.modules['__main__']
|
||||
wfs: Dict[str, Callable] = {}
|
||||
for name, fn in inspect.getmembers(module, inspect.isfunction):
|
||||
if getattr(fn, '_is_workflow', False) and fn.__module__ == module.__name__:
|
||||
wfs[getattr(fn, '_workflow_name', name)] = fn
|
||||
for name, method in get_decorated_methods(self, '_is_workflow').items():
|
||||
wfs[getattr(method, '_workflow_name', name)] = method
|
||||
logger.info(f"Discovered workflows: {list(wfs)}")
|
||||
return wfs
|
||||
|
||||
def _register_workflows(self, wfs: Dict[str, Callable]) -> None:
|
||||
"""Register each discovered workflow with the Dapr runtime."""
|
||||
for wf_name, method in wfs.items():
|
||||
|
||||
@functools.wraps(method)
|
||||
def wrapped(*args, **kwargs):
|
||||
return method(*args, **kwargs)
|
||||
|
||||
decorator = self.wf_runtime.workflow(name=wf_name)
|
||||
self.workflows[wf_name] = decorator(wrapped)
|
||||
|
||||
def start_runtime(self):
|
||||
"""Idempotently start the Dapr workflow runtime."""
|
||||
if not self.wf_runtime_is_running:
|
||||
logger.info("Starting workflow runtime.")
|
||||
self.wf_runtime.start()
|
||||
self.wf_runtime_is_running = True
|
||||
else:
|
||||
logger.debug("Workflow runtime already running; skipping.")
|
||||
|
||||
def stop_runtime(self):
|
||||
"""Idempotently stop the Dapr workflow runtime."""
|
||||
if self.wf_runtime_is_running:
|
||||
logger.info("Stopping workflow runtime.")
|
||||
self.wf_runtime.shutdown()
|
||||
self.wf_runtime_is_running = False
|
||||
else:
|
||||
logger.debug("Workflow runtime already stopped; skipping.")
|
||||
|
||||
def register_agent(self, store_name: str, store_key: str, agent_name: str, agent_metadata: dict) -> None:
|
||||
"""
|
||||
Merges the existing data with the new data and updates the store.
|
||||
|
@ -142,132 +263,6 @@ class WorkflowApp(BaseModel):
|
|||
except Exception as e:
|
||||
logger.warning(f"Error retrieving data for key '{key}' from store '{store_name}'")
|
||||
return None
|
||||
|
||||
def register_all_tasks(self):
|
||||
"""
|
||||
Registers all collected tasks with Dapr while preserving execution logic.
|
||||
"""
|
||||
current_module = sys.modules["__main__"]
|
||||
|
||||
all_functions = {}
|
||||
for name, func in inspect.getmembers(current_module, inspect.isfunction):
|
||||
if hasattr(func, "_is_task") and func.__module__ == current_module.__name__:
|
||||
task_name = getattr(func, "_task_name", None) or name
|
||||
all_functions[task_name] = func
|
||||
|
||||
# Load instance methods that are tasks
|
||||
task_methods = get_callable_decorated_methods(self, "_is_task")
|
||||
for method_name, method in task_methods.items():
|
||||
task_name = getattr(method, "_task_name", method_name)
|
||||
all_functions[task_name] = method
|
||||
|
||||
logger.debug(f"Discovered tasks: {list(all_functions.keys())}")
|
||||
|
||||
def make_task_wrapper(method):
|
||||
"""Creates a unique task wrapper bound to a specific method reference."""
|
||||
# Extract stored metadata from the function
|
||||
task_name = getattr(method, "_task_name", method.__name__)
|
||||
explicit_llm = getattr(method, "_explicit_llm", False)
|
||||
|
||||
# Always initialize `llm` as `None` explicitly first
|
||||
llm = None
|
||||
|
||||
# If task is explicitly LLM-based, but has no LLM, use `self.llm`
|
||||
if explicit_llm and self.llm is not None:
|
||||
llm = self.llm
|
||||
|
||||
task_kwargs = getattr(method, "_task_kwargs", {})
|
||||
|
||||
task_instance = WorkflowTask(
|
||||
func=method,
|
||||
description=getattr(method, "_task_description", None),
|
||||
agent=getattr(method, "_task_agent", None),
|
||||
llm=llm,
|
||||
include_chat_history=getattr(method, "_task_include_chat_history", False),
|
||||
workflow_app=self,
|
||||
**task_kwargs
|
||||
)
|
||||
|
||||
def run_in_event_loop(coroutine):
|
||||
"""Ensures that an async function runs synchronously if needed."""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
return loop.run_until_complete(coroutine)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
return loop.run_until_complete(coroutine)
|
||||
|
||||
@functools.wraps(method)
|
||||
def task_wrapper(ctx: WorkflowActivityContext, *args, **kwargs):
|
||||
"""Wrapper function for executing tasks in a Dapr workflow, handling both sync and async tasks."""
|
||||
wf_ctx = WorkflowActivityContext(ctx)
|
||||
|
||||
try:
|
||||
if inspect.iscoroutinefunction(method) or asyncio.iscoroutinefunction(task_instance.__call__):
|
||||
return run_in_event_loop(task_instance(wf_ctx, *args, **kwargs))
|
||||
else:
|
||||
return task_instance(wf_ctx, *args, **kwargs)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Task '{task_name}' execution failed: {e}")
|
||||
|
||||
return task_name, task_wrapper # Return both name and wrapper
|
||||
|
||||
for method in all_functions.values():
|
||||
# Ensure function reference is properly preserved inside a function scope
|
||||
task_name, task_wrapper = make_task_wrapper(method)
|
||||
|
||||
# Register the task with Dapr Workflow using the correct task name
|
||||
activity_decorator = self.wf_runtime.activity(name=task_name)
|
||||
registered_activity = activity_decorator(task_wrapper)
|
||||
|
||||
# Store task reference
|
||||
self.tasks[task_name] = registered_activity
|
||||
|
||||
def register_all_workflows(self):
|
||||
"""
|
||||
Registers all workflow functions dynamically with Dapr.
|
||||
"""
|
||||
current_module = sys.modules["__main__"]
|
||||
|
||||
all_workflows = {}
|
||||
# Load global-level workflow functions
|
||||
for name, func in inspect.getmembers(current_module, inspect.isfunction):
|
||||
if hasattr(func, "_is_workflow") and func.__module__ == current_module.__name__:
|
||||
workflow_name = getattr(func, "_workflow_name", None) or name
|
||||
all_workflows[workflow_name] = func
|
||||
|
||||
# Load instance methods that are workflows
|
||||
workflow_methods = get_callable_decorated_methods(self, "_is_workflow")
|
||||
for method_name, method in workflow_methods.items():
|
||||
workflow_name = getattr(method, "_workflow_name", method_name)
|
||||
all_workflows[workflow_name] = method
|
||||
|
||||
logger.info(f"Discovered workflows: {list(all_workflows.keys())}")
|
||||
|
||||
def make_workflow_wrapper(method):
|
||||
"""Creates a wrapper to prevent pointer overwrites during workflow registration."""
|
||||
workflow_name = getattr(method, "_workflow_name", method.__name__)
|
||||
|
||||
@functools.wraps(method)
|
||||
def workflow_wrapper(*args, **kwargs):
|
||||
"""Directly calls the method without modifying ctx injection (already handled)."""
|
||||
try:
|
||||
return method(*args, **kwargs)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Workflow '{workflow_name}' execution failed: {e}")
|
||||
|
||||
return workflow_name, workflow_wrapper
|
||||
|
||||
for method in all_workflows.values():
|
||||
workflow_name, workflow_wrapper = make_workflow_wrapper(method)
|
||||
|
||||
# Register the workflow with Dapr using the correct name
|
||||
workflow_decorator = self.wf_runtime.workflow(name=workflow_name)
|
||||
registered_workflow = workflow_decorator(workflow_wrapper)
|
||||
|
||||
# Store workflow reference
|
||||
self.workflows[workflow_name] = registered_workflow
|
||||
|
||||
def resolve_task(self, task: Union[str, Callable]) -> Callable:
|
||||
"""
|
||||
|
@ -416,7 +411,7 @@ class WorkflowApp(BaseModel):
|
|||
if state.serialized_output:
|
||||
logger.debug(f"Output: {json.dumps(state.serialized_output, indent=2)}")
|
||||
|
||||
elif workflow_status == "FAILED":
|
||||
elif workflow_status in ("FAILED", "ABORTED"):
|
||||
# Ensure `failure_details` exists before accessing attributes
|
||||
error_type = getattr(failure_details, "error_type", "Unknown")
|
||||
message = getattr(failure_details, "message", "No message provided")
|
||||
|
@ -430,6 +425,8 @@ class WorkflowApp(BaseModel):
|
|||
f"Input: {json.dumps(state.serialized_input, indent=2)}"
|
||||
)
|
||||
|
||||
self.terminate_workflow(instance_id)
|
||||
|
||||
else:
|
||||
logger.warning(
|
||||
f"Workflow '{instance_id}' ended with status '{workflow_status}'.\n"
|
||||
|
@ -633,19 +630,4 @@ class WorkflowApp(BaseModel):
|
|||
Returns:
|
||||
dtask.WhenAnyTask: A task that completes when the first task finishes.
|
||||
"""
|
||||
return dtask.when_any(tasks)
|
||||
|
||||
def start_runtime(self):
|
||||
"""
|
||||
Starts the Dapr workflow runtime
|
||||
"""
|
||||
|
||||
logger.info("Starting workflow runtime.")
|
||||
self.wf_runtime.start()
|
||||
|
||||
def stop_runtime(self):
|
||||
"""
|
||||
Stops the Dapr workflow runtime.
|
||||
"""
|
||||
logger.info("Stopping workflow runtime.")
|
||||
self.wf_runtime.shutdown()
|
||||
return dtask.when_any(tasks)
|
|
@ -1,6 +1,7 @@
|
|||
import functools
|
||||
import inspect
|
||||
from typing import Any, Callable, Optional
|
||||
import logging
|
||||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
|
@ -65,18 +66,31 @@ def task(
|
|||
def decorator(f: Callable) -> Callable:
|
||||
if not callable(f):
|
||||
raise ValueError(f"@task must be applied to a function, got {type(f)}.")
|
||||
|
||||
# Attach task metadata for later consumption
|
||||
f._is_task = True
|
||||
f._task_name = name or f.__name__
|
||||
f._task_description = description
|
||||
f._task_agent = agent
|
||||
f._task_llm = llm
|
||||
|
||||
# Attach task metadata
|
||||
f._is_task = True
|
||||
f._task_name = name or f.__name__
|
||||
f._task_description = description
|
||||
f._task_agent = agent
|
||||
f._task_llm = llm
|
||||
f._task_include_chat_history = include_chat_history
|
||||
f._explicit_llm = llm is not None or bool(description)
|
||||
f._task_kwargs = task_kwargs # Forward anything else (e.g., structured_mode)
|
||||
|
||||
return f
|
||||
f._explicit_llm = llm is not None or bool(description)
|
||||
f._task_kwargs = task_kwargs
|
||||
|
||||
# wrap it so we can log, validate, etc., without losing signature/docs
|
||||
@functools.wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
logging.getLogger(__name__).debug(f"Calling task '{f._task_name}'")
|
||||
return f(*args, **kwargs)
|
||||
|
||||
# copy our metadata onto the wrapper so discovery still sees it
|
||||
for attr in (
|
||||
"_is_task", "_task_name", "_task_description", "_task_agent",
|
||||
"_task_llm", "_task_include_chat_history", "_explicit_llm", "_task_kwargs"
|
||||
):
|
||||
setattr(wrapper, attr, getattr(f, attr))
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator(func) if func else decorator # Supports both @task and @task(name="custom")
|
||||
|
||||
|
@ -146,6 +160,8 @@ def workflow(func: Optional[Callable] = None, *, name: Optional[str] = None) ->
|
|||
def wrapper(*args, **kwargs):
|
||||
"""Wrapper for handling input validation and execution."""
|
||||
|
||||
logging.getLogger(__name__).info(f"Starting workflow '{f._workflow_name}'")
|
||||
|
||||
bound_args = sig.bind_partial(*args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
|
||||
|
@ -179,6 +195,8 @@ def workflow(func: Optional[Callable] = None, *, name: Optional[str] = None) ->
|
|||
|
||||
return f(*bound_args.args, **bound_args.kwargs)
|
||||
|
||||
wrapper._is_workflow = True
|
||||
wrapper._workflow_name = f._workflow_name
|
||||
return wrapper
|
||||
|
||||
return decorator(func) if func else decorator # Supports both `@workflow` and `@workflow(name="custom")`
|
|
@ -11,7 +11,7 @@ from dapr.clients.grpc.subscription import StreamInactiveError
|
|||
from dapr.common.pubsub.subscription import StreamCancelledError, SubscriptionMessage
|
||||
from dapr_agents.workflow.messaging.parser import extract_cloudevent_data, validate_message_model
|
||||
from dapr_agents.workflow.messaging.utils import is_valid_routable_model
|
||||
from dapr_agents.workflow.utils import get_callable_decorated_methods
|
||||
from dapr_agents.workflow.utils import get_decorated_methods
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -46,7 +46,7 @@ class MessageRoutingMixin:
|
|||
- Wraps each handler and maps it by `(pubsub_name, topic_name)` and schema name.
|
||||
- Ensures only one handler per schema per topic is allowed.
|
||||
"""
|
||||
message_handlers = get_callable_decorated_methods(self, "_is_message_handler")
|
||||
message_handlers = get_decorated_methods(self, "_is_message_handler")
|
||||
|
||||
for method_name, method in message_handlers.items():
|
||||
try:
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
from dapr_agents.workflow.orchestrators.base import OrchestratorWorkflowBase
|
||||
from dapr_agents.types import DaprWorkflowContext, BaseMessage, EventMessageMetadata
|
||||
from dapr_agents.types import DaprWorkflowContext, BaseMessage
|
||||
from dapr_agents.workflow.decorators import workflow, task
|
||||
from dapr_agents.workflow.messaging.decorator import message_router
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi import Response, status
|
||||
from typing import Any, Optional, Dict, Any
|
||||
from datetime import timedelta
|
||||
from pydantic import BaseModel, Field
|
||||
|
@ -23,8 +21,9 @@ class TriggerAction(BaseModel):
|
|||
"""
|
||||
Represents a message used to trigger an agent's activity within the workflow.
|
||||
"""
|
||||
task: Optional[str] = Field(None, description="The specific task to execute. If not provided, the agent can act based on its memory or predefined behavior.")
|
||||
iteration: Optional[int] = Field(default=0, description="The current iteration of the workflow loop.")
|
||||
task: Optional[str] = Field(None, description="The specific task to execute. If not provided, the agent will act based on its memory or predefined behavior.")
|
||||
iteration: Optional[int] = Field(0, description="")
|
||||
workflow_instance_id: Optional[str] = Field(default=None, description="Dapr workflow instance id from source if available")
|
||||
|
||||
class RandomOrchestrator(OrchestratorWorkflowBase):
|
||||
"""
|
||||
|
@ -170,29 +169,35 @@ class RandomOrchestrator(OrchestratorWorkflowBase):
|
|||
name (str): Name of the agent to trigger.
|
||||
instance_id (str): Workflow instance ID for context.
|
||||
"""
|
||||
logger.info(f"Triggering agent {name} (Instance ID: {instance_id})")
|
||||
|
||||
await self.send_message_to_agent(
|
||||
name=name,
|
||||
message=TriggerAction(task=None),
|
||||
workflow_instance_id=instance_id,
|
||||
message=TriggerAction(workflow_instance_id=instance_id),
|
||||
)
|
||||
|
||||
@message_router
|
||||
async def process_agent_response(self, message: AgentTaskResponse, metadata: EventMessageMetadata) -> Response:
|
||||
async def process_agent_response(self, message: AgentTaskResponse):
|
||||
"""
|
||||
Processes agent response messages sent directly to the agent's topic.
|
||||
|
||||
Args:
|
||||
message (AgentTaskResponse): The agent's response containing task results.
|
||||
metadata (EventMessageMetadata): Metadata associated with the message, including headers.
|
||||
|
||||
|
||||
Returns:
|
||||
Response: A JSON response confirming the workflow event was successfully triggered.
|
||||
None: The function raises a workflow event with the agent's response.
|
||||
"""
|
||||
agent_response = (message).model_dump()
|
||||
workflow_instance_id = metadata.headers.get("workflow_instance_id")
|
||||
event_name = metadata.headers.get("event_name", "AgentTaskResponse")
|
||||
try:
|
||||
workflow_instance_id = message.get("workflow_instance_id")
|
||||
|
||||
# Raise a workflow event with the Agent's Task Response!
|
||||
self.raise_workflow_event(instance_id=workflow_instance_id, event_name=event_name, data=agent_response)
|
||||
if not workflow_instance_id:
|
||||
logger.error(f"{self.name} received an agent response without a valid workflow_instance_id. Ignoring.")
|
||||
return
|
||||
|
||||
return JSONResponse(content={"message": "Workflow event triggered successfully."}, status_code=status.HTTP_200_OK)
|
||||
logger.info(f"{self.name} processing agent response for workflow instance '{workflow_instance_id}'.")
|
||||
|
||||
# Raise a workflow event with the Agent's Task Response
|
||||
self.raise_workflow_event(instance_id=workflow_instance_id, event_name="AgentTaskResponse", data=message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing agent response: {e}", exc_info=True)
|
|
@ -1,9 +1,7 @@
|
|||
from dapr_agents.workflow.messaging.decorator import message_router
|
||||
from dapr_agents.workflow.orchestrators.base import OrchestratorWorkflowBase
|
||||
from dapr_agents.types import DaprWorkflowContext, BaseMessage, EventMessageMetadata
|
||||
from dapr_agents.types import DaprWorkflowContext, BaseMessage
|
||||
from dapr_agents.workflow.decorators import workflow, task
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi import Response, status
|
||||
from typing import Any, Optional, Dict
|
||||
from pydantic import BaseModel, Field
|
||||
from datetime import timedelta
|
||||
|
@ -21,8 +19,9 @@ class TriggerAction(BaseModel):
|
|||
"""
|
||||
Represents a message used to trigger an agent's activity within the workflow.
|
||||
"""
|
||||
task: Optional[str] = None
|
||||
iteration: Optional[int] = 0
|
||||
task: Optional[str] = Field(None, description="The specific task to execute. If not provided, the agent will act based on its memory or predefined behavior.")
|
||||
iteration: Optional[int] = Field(0, description="")
|
||||
workflow_instance_id: Optional[str] = Field(default=None, description="Dapr workflow instance id from source if available")
|
||||
|
||||
class RoundRobinOrchestrator(OrchestratorWorkflowBase):
|
||||
"""
|
||||
|
@ -161,30 +160,31 @@ class RoundRobinOrchestrator(OrchestratorWorkflowBase):
|
|||
"""
|
||||
await self.send_message_to_agent(
|
||||
name=name,
|
||||
message=TriggerAction(task=None),
|
||||
workflow_instance_id=instance_id,
|
||||
message=TriggerAction(workflow_instance_id=instance_id),
|
||||
)
|
||||
|
||||
|
||||
@message_router
|
||||
async def process_agent_response(self, message: AgentTaskResponse,
|
||||
metadata: EventMessageMetadata) -> Response:
|
||||
async def process_agent_response(self, message: AgentTaskResponse):
|
||||
"""
|
||||
Processes agent response messages sent directly to the agent's topic.
|
||||
|
||||
Args:
|
||||
message (AgentTaskResponse): The agent's response containing task results.
|
||||
metadata (EventMessageMetadata): Metadata associated with the message, including headers.
|
||||
|
||||
Returns:
|
||||
Response: A JSON response confirming the workflow event was successfully triggered.
|
||||
None: The function raises a workflow event with the agent's response.
|
||||
"""
|
||||
agent_response = (message).model_dump()
|
||||
workflow_instance_id = metadata.headers.get("workflow_instance_id")
|
||||
event_name = metadata.headers.get("event_name", "AgentTaskResponse")
|
||||
try:
|
||||
workflow_instance_id = message.get("workflow_instance_id")
|
||||
|
||||
# Raise a workflow event with the Agent's Task Response!
|
||||
self.raise_workflow_event(instance_id=workflow_instance_id, event_name=event_name,
|
||||
data=agent_response)
|
||||
if not workflow_instance_id:
|
||||
logger.error(f"{self.name} received an agent response without a valid workflow_instance_id. Ignoring.")
|
||||
return
|
||||
|
||||
return JSONResponse(content={"message": "Workflow event triggered successfully."},
|
||||
status_code=status.HTTP_200_OK)
|
||||
logger.info(f"{self.name} processing agent response for workflow instance '{workflow_instance_id}'.")
|
||||
|
||||
# Raise a workflow event with the Agent's Task Response
|
||||
self.raise_workflow_event(instance_id=workflow_instance_id, event_name="AgentTaskResponse", data=message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing agent response: {e}", exc_info=True)
|
|
@ -4,7 +4,7 @@ import logging
|
|||
from dataclasses import is_dataclass
|
||||
from functools import update_wrapper
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
@ -44,249 +44,232 @@ class WorkflowTask(BaseModel):
|
|||
"""
|
||||
Post-initialization to set up function signatures and default LLM clients.
|
||||
"""
|
||||
# Default to OpenAIChatClient if prompt‐based but no llm provided
|
||||
if self.description and not self.llm:
|
||||
self.llm = OpenAIChatClient()
|
||||
|
||||
|
||||
if self.func:
|
||||
# Preserve name / docs for stack traces
|
||||
update_wrapper(self, self.func)
|
||||
|
||||
# Capture signature for input / output handling
|
||||
self.signature = inspect.signature(self.func) if self.func else None
|
||||
|
||||
# Honor any structured_mode override
|
||||
if not self.structured_mode and "structured_mode" in self.task_kwargs:
|
||||
self.structured_mode = self.task_kwargs["structured_mode"]
|
||||
|
||||
# Proceed with base model setup
|
||||
super().model_post_init(__context)
|
||||
|
||||
async def __call__(self, ctx: WorkflowActivityContext, input: Any = None) -> Any:
|
||||
async def __call__(self, ctx: WorkflowActivityContext, payload: Any = None) -> Any:
|
||||
"""
|
||||
Executes the task and validates its output.
|
||||
Ensures all coroutines are awaited before returning.
|
||||
Executes the task, routing to agent, LLM, or pure-Python logic.
|
||||
|
||||
Dispatches to Python, Agent, or LLM paths and validates output.
|
||||
|
||||
Args:
|
||||
ctx (WorkflowActivityContext): The workflow execution context.
|
||||
input (Any): The task input.
|
||||
payload (Any): The task input.
|
||||
|
||||
Returns:
|
||||
Any: The result of the task.
|
||||
"""
|
||||
input = self._normalize_input(input) if input is not None else {}
|
||||
# Prepare input dict
|
||||
data = self._normalize_input(payload) if payload is not None else {}
|
||||
logger.info(f"Executing task '{self.func.__name__}'")
|
||||
logger.debug(f"Executing task '{self.func.__name__}' with input {data!r}")
|
||||
|
||||
try:
|
||||
if self.agent or self.llm:
|
||||
executor = self._choose_executor()
|
||||
if executor in ("agent", "llm"):
|
||||
if not self.description:
|
||||
raise ValueError(f"Task {self.func.__name__} is LLM-based but has no description!")
|
||||
|
||||
result = await self._run_task(self.format_description(self.description, input))
|
||||
result = await self._validate_output(result)
|
||||
elif self.func:
|
||||
# Task is a Python function
|
||||
logger.info(f"Invoking Regular Task")
|
||||
if asyncio.iscoroutinefunction(self.func):
|
||||
# Await async function
|
||||
result = await self.func(**input)
|
||||
else:
|
||||
# Call sync function
|
||||
result = self._execute_function(input)
|
||||
result = await self._validate_output(result)
|
||||
raise ValueError("LLM/agent tasks require a description template")
|
||||
prompt = self.format_description(self.description, data)
|
||||
raw = await self._run_via_ai(prompt, executor)
|
||||
else:
|
||||
raise ValueError("Task must have an LLM, agent, or regular function for execution.")
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Task execution error: {e}")
|
||||
raw = await self._run_python(data)
|
||||
|
||||
validated = await self._validate_output(raw)
|
||||
return validated
|
||||
|
||||
except Exception:
|
||||
logger.exception(f"Error in task '{self.func.__name__}'")
|
||||
raise
|
||||
|
||||
def _normalize_input(self, input: Any) -> dict:
|
||||
def _choose_executor(self) -> Literal["agent", "llm", "python"]:
|
||||
"""
|
||||
Converts input into a normalized dictionary.
|
||||
|
||||
Args:
|
||||
input (Any): Input to normalize (e.g., dictionary, dataclass, or object).
|
||||
Pick execution path.
|
||||
|
||||
Returns:
|
||||
dict: Normalized dictionary representation of the input.
|
||||
"""
|
||||
if is_dataclass(input):
|
||||
return input.__dict__
|
||||
elif isinstance(input, SimpleNamespace):
|
||||
return vars(input)
|
||||
elif not isinstance(input, dict):
|
||||
return self._single_value_to_dict(input)
|
||||
return input
|
||||
|
||||
def _single_value_to_dict(self, value: Any) -> dict:
|
||||
"""
|
||||
Wraps a single input value in a dictionary.
|
||||
|
||||
Args:
|
||||
value (Any): Single input value.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary with parameter name as the key.
|
||||
One of "agent", "llm", or "python".
|
||||
|
||||
Raises:
|
||||
ValueError: If no function signature is available.
|
||||
ValueError: If no valid executor is configured.
|
||||
"""
|
||||
if not self.signature:
|
||||
raise ValueError("Cannot convert single input to dict: function signature is missing.")
|
||||
|
||||
param_name = list(self.signature.parameters.keys())[0]
|
||||
return {param_name: value}
|
||||
|
||||
def format_description(self, description: str, input: dict) -> str:
|
||||
"""
|
||||
Formats a description string with input parameters.
|
||||
|
||||
Args:
|
||||
description (str): Description template.
|
||||
input (dict): Input parameters for formatting.
|
||||
|
||||
Returns:
|
||||
str: Formatted description string.
|
||||
"""
|
||||
if self.signature:
|
||||
bound_args = self.signature.bind(**input)
|
||||
bound_args.apply_defaults()
|
||||
return description.format(**bound_args.arguments)
|
||||
return description.format(**input)
|
||||
|
||||
async def _run_task(self, formatted_description: str) -> Any:
|
||||
"""
|
||||
Determine whether to run the task using an agent or an LLM.
|
||||
|
||||
Args:
|
||||
formatted_description (str): The formatted description to pass to the agent or LLM.
|
||||
|
||||
Returns:
|
||||
Any: The result of the agent or LLM execution.
|
||||
|
||||
Raises:
|
||||
ValueError: If neither an agent nor an LLM is provided.
|
||||
"""
|
||||
logger.debug(f"Task Description: {formatted_description}")
|
||||
|
||||
if self.agent:
|
||||
return await self._run_agent(formatted_description)
|
||||
elif self.llm:
|
||||
return await self._run_llm(formatted_description)
|
||||
return "agent"
|
||||
if self.llm:
|
||||
return "llm"
|
||||
if self.func:
|
||||
return "python"
|
||||
raise ValueError("No execution path found for this task")
|
||||
|
||||
async def _run_python(self, data: dict) -> Any:
|
||||
"""
|
||||
Invoke the Python function directly.
|
||||
|
||||
Args:
|
||||
data: Keyword arguments for the function.
|
||||
|
||||
Returns:
|
||||
The function's return value.
|
||||
"""
|
||||
logger.info("Invoking regular Python function")
|
||||
if asyncio.iscoroutinefunction(self.func):
|
||||
return await self.func(**data)
|
||||
else:
|
||||
raise ValueError("No agent or LLM provided.")
|
||||
|
||||
async def _run_agent(self, description: str) -> Any:
|
||||
return self.func(**data)
|
||||
|
||||
async def _run_via_ai(self, prompt: str, executor: Literal["agent", "llm"]) -> Any:
|
||||
"""
|
||||
Execute the task using the provided agent.
|
||||
Run the prompt through an Agent or LLM.
|
||||
|
||||
Args:
|
||||
description (str): The formatted description to pass to the agent.
|
||||
prompt: The fully formatted prompt string.
|
||||
kind: "agent" or "llm".
|
||||
|
||||
Returns:
|
||||
Any: The result of the agent execution.
|
||||
Raw result from the AI path.
|
||||
"""
|
||||
logger.info("Invoking Task with AI Agent...")
|
||||
|
||||
result = await self.agent.run(description)
|
||||
|
||||
logger.debug(f"Agent result type: {type(result)}, value: {result}")
|
||||
logger.info(f"Invoking task via {executor.upper()}")
|
||||
logger.debug(f"Invoking task with prompt: {prompt!r}")
|
||||
if executor == "agent":
|
||||
result = await self.agent.run(prompt)
|
||||
else:
|
||||
result = await self._invoke_llm(prompt)
|
||||
return self._convert_result(result)
|
||||
|
||||
async def _run_llm(self, description: Union[str, List[BaseMessage]]) -> Any:
|
||||
logger.info("Invoking Task with LLM...")
|
||||
async def _invoke_llm(self, prompt: str) -> Any:
|
||||
"""
|
||||
Build messages and call the LLM client.
|
||||
|
||||
# 1. Get chat history if enabled
|
||||
conversation_history = []
|
||||
Args:
|
||||
prompt: The formatted prompt string.
|
||||
|
||||
Returns:
|
||||
LLM-generated result.
|
||||
"""
|
||||
# Gather history if needed
|
||||
history: List[BaseMessage] = []
|
||||
if self.include_chat_history and self.workflow_app:
|
||||
logger.info("Retrieving conversation history...")
|
||||
conversation_history = self.workflow_app.get_chat_history()
|
||||
logger.debug(f"Conversation history retrieved: {conversation_history}")
|
||||
logger.debug("Retrieving chat history")
|
||||
history = self.workflow_app.get_chat_history()
|
||||
|
||||
# 2. Convert string input to structured messages
|
||||
if isinstance(description, str):
|
||||
description = [UserMessage(description)]
|
||||
llm_messages = conversation_history + description
|
||||
messages: List[BaseMessage] = history + [UserMessage(prompt)]
|
||||
params: Dict[str, Any] = {"messages": messages}
|
||||
|
||||
# 3. Base LLM parameters
|
||||
llm_params = {"messages": llm_messages}
|
||||
|
||||
# 4. Add structured response config if a valid Pydantic model is the return type
|
||||
# Add structured formatting if return type is a Pydantic model
|
||||
if self.signature and self.signature.return_annotation is not inspect.Signature.empty:
|
||||
return_type = self.signature.return_annotation
|
||||
model_cls = StructureHandler.resolve_response_model(return_type)
|
||||
|
||||
# Only proceed if we resolved a Pydantic model
|
||||
model_cls = StructureHandler.resolve_response_model(
|
||||
self.signature.return_annotation
|
||||
)
|
||||
if model_cls:
|
||||
if not hasattr(self.llm, "provider"):
|
||||
raise AttributeError(
|
||||
f"{type(self.llm).__name__} is missing the `.provider` attribute — required for structured response generation."
|
||||
)
|
||||
params["response_format"] = self.signature.return_annotation
|
||||
params["structured_mode"] = self.structured_mode
|
||||
|
||||
logger.debug(f"Using LLM provider: {self.llm.provider}")
|
||||
|
||||
llm_params["response_format"] = return_type
|
||||
llm_params["structured_mode"] = self.structured_mode or "json"
|
||||
|
||||
# 5. Call the LLM client
|
||||
result = self.llm.generate(**llm_params)
|
||||
|
||||
logger.debug(f"LLM result type: {type(result)}, value: {result}")
|
||||
return self._convert_result(result)
|
||||
|
||||
def _convert_result(self, result: Any) -> Any:
|
||||
"""
|
||||
Convert the task result to a dictionary if necessary.
|
||||
|
||||
Args:
|
||||
result (Any): The raw task result.
|
||||
|
||||
Returns:
|
||||
Any: The converted result.
|
||||
"""
|
||||
if isinstance(result, ChatCompletion):
|
||||
logger.debug("Extracted message content from ChatCompletion.")
|
||||
return result.get_content()
|
||||
|
||||
if isinstance(result, BaseModel):
|
||||
logger.debug("Converting Pydantic model to dictionary.")
|
||||
return result.model_dump()
|
||||
|
||||
if isinstance(result, list) and all(isinstance(item, BaseModel) for item in result):
|
||||
logger.debug("Converting list of Pydantic models to list of dictionaries.")
|
||||
return [item.model_dump() for item in result]
|
||||
|
||||
# If no specific conversion is necessary, return as-is
|
||||
logger.info("Returning final task result.")
|
||||
return result
|
||||
logger.debug(f"LLM call params: {params}")
|
||||
return self.llm.generate(**params)
|
||||
|
||||
def _execute_function(self, input: dict) -> Any:
|
||||
def _normalize_input(self, raw_input: Any) -> dict:
|
||||
"""
|
||||
Execute the wrapped function with the provided input.
|
||||
Normalize various input types into a dict.
|
||||
|
||||
Args:
|
||||
input (dict): The input data to pass to the function.
|
||||
raw_input: Dataclass, SimpleNamespace, single value, or dict.
|
||||
|
||||
Returns:
|
||||
Any: The result of the function execution.
|
||||
A dict suitable for function invocation.
|
||||
|
||||
Raises:
|
||||
ValueError: If signature is missing when wrapping a single value.
|
||||
"""
|
||||
return self.func(**input)
|
||||
if is_dataclass(raw_input):
|
||||
return raw_input.__dict__
|
||||
if isinstance(raw_input, SimpleNamespace):
|
||||
return vars(raw_input)
|
||||
if not isinstance(raw_input, dict):
|
||||
# wrap single argument
|
||||
if not self.signature:
|
||||
raise ValueError("Cannot infer param name without signature")
|
||||
name = next(iter(self.signature.parameters))
|
||||
return {name: raw_input}
|
||||
return raw_input
|
||||
|
||||
async def _validate_output(self, result: Any) -> Any:
|
||||
"""
|
||||
Validates the output of the task against the expected return type.
|
||||
Await and validate the result against return-type model.
|
||||
|
||||
Supports coroutine outputs and structured type validation.
|
||||
Args:
|
||||
result: Raw result from executor.
|
||||
|
||||
Returns:
|
||||
Any: The validated and potentially transformed result.
|
||||
Validated/transformed result.
|
||||
"""
|
||||
if asyncio.iscoroutine(result):
|
||||
logger.warning("Result is a coroutine; awaiting.")
|
||||
result = await result
|
||||
|
||||
if not self.signature or self.signature.return_annotation is inspect.Signature.empty:
|
||||
if (
|
||||
not self.signature
|
||||
or self.signature.return_annotation is inspect.Signature.empty
|
||||
):
|
||||
return result
|
||||
|
||||
expected_type = self.signature.return_annotation
|
||||
return StructureHandler.validate_against_signature(result, expected_type)
|
||||
return StructureHandler.validate_against_signature(
|
||||
result, self.signature.return_annotation
|
||||
)
|
||||
|
||||
def _convert_result(self, result: Any) -> Any:
|
||||
"""
|
||||
Unwrap AI return types into plain Python.
|
||||
|
||||
Args:
|
||||
result: ChatCompletion, BaseModel, or list of BaseModel.
|
||||
|
||||
Returns:
|
||||
A primitive, dict, or list of dicts.
|
||||
"""
|
||||
# Unwrap ChatCompletion
|
||||
if isinstance(result, ChatCompletion):
|
||||
logger.debug("Extracted message content from ChatCompletion.")
|
||||
return result.get_content()
|
||||
# Pydantic → dict
|
||||
if isinstance(result, BaseModel):
|
||||
logger.debug("Converting Pydantic model to dictionary.")
|
||||
return result.model_dump()
|
||||
if isinstance(result, list) and all(isinstance(x, BaseModel) for x in result):
|
||||
logger.debug("Converting list of Pydantic models to list of dictionaries.")
|
||||
return [x.model_dump() for x in result]
|
||||
# If no specific conversion is necessary, return as-is
|
||||
logger.info("Returning final task result.")
|
||||
return result
|
||||
|
||||
def format_description(self, template: str, data: dict) -> str:
|
||||
"""
|
||||
Interpolate inputs into the prompt template.
|
||||
|
||||
Args:
|
||||
template: The `{}`-style template string.
|
||||
data: Mapping of variable names to values.
|
||||
|
||||
Returns:
|
||||
The fully formatted prompt.
|
||||
"""
|
||||
if self.signature:
|
||||
bound = self.signature.bind(**data)
|
||||
bound.apply_defaults()
|
||||
return template.format(**bound.arguments)
|
||||
return template.format(**data)
|
||||
|
||||
class TaskWrapper:
|
||||
"""
|
||||
|
|
|
@ -1,35 +1,53 @@
|
|||
import inspect
|
||||
import logging
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_callable_decorated_methods(instance, decorator_attr: str) -> dict:
|
||||
def get_decorated_methods(instance: Any, attribute_name: str) -> Dict[str, Callable]:
|
||||
"""
|
||||
Safely retrieves all instance methods decorated with a specific attribute (e.g. `_is_task`, `_is_workflow`).
|
||||
Find all **public** bound methods on `instance` that carry a given decorator attribute.
|
||||
|
||||
This will:
|
||||
1. Inspect the class for functions or methods.
|
||||
2. Bind them to the instance (so `self` is applied).
|
||||
3. Filter in only those where `hasattr(method, attribute_name) is True`.
|
||||
|
||||
Args:
|
||||
instance: The class instance to inspect.
|
||||
decorator_attr (str): The attribute name set by a decorator (e.g. "_is_task").
|
||||
instance: Any object whose methods you want to inspect.
|
||||
attribute_name:
|
||||
The name of the attribute set by your decorator
|
||||
(e.g. "_is_task" or "_is_workflow").
|
||||
|
||||
Returns:
|
||||
dict: Mapping of method names to bound method callables.
|
||||
A dict mapping `method_name` → `bound method`.
|
||||
|
||||
Example:
|
||||
>>> class A:
|
||||
... @task
|
||||
... def foo(self): ...
|
||||
...
|
||||
>>> get_decorated_methods(A(), "_is_task")
|
||||
{"foo": <bound method A.foo of <A object ...>>}
|
||||
"""
|
||||
discovered = {}
|
||||
for method_name in dir(instance):
|
||||
if method_name.startswith("_"):
|
||||
continue # Skip private/protected
|
||||
discovered: Dict[str, Callable] = {}
|
||||
|
||||
raw_attr = getattr(type(instance), method_name, None)
|
||||
if not (inspect.isfunction(raw_attr) or inspect.ismethod(raw_attr)):
|
||||
continue # Skip non-methods (e.g., @property)
|
||||
|
||||
try:
|
||||
method = getattr(instance, method_name)
|
||||
except Exception as e:
|
||||
logger.warning(f"Skipping method '{method_name}' due to error: {e}")
|
||||
cls = type(instance)
|
||||
for name, member in inspect.getmembers(cls, predicate=inspect.isfunction):
|
||||
# skip private/protected
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
|
||||
if hasattr(method, decorator_attr):
|
||||
discovered[method_name] = method
|
||||
# bind to instance so that signature(self, ...) works
|
||||
try:
|
||||
bound = getattr(instance, name)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not bind method '{name}': {e}")
|
||||
continue
|
||||
|
||||
# pick up only those with our decorator flag
|
||||
if hasattr(bound, attribute_name):
|
||||
discovered[name] = bound
|
||||
logger.debug(f"Discovered decorated method: {name}")
|
||||
|
||||
return discovered
|
|
@ -27,13 +27,13 @@ classifiers = [
|
|||
|
||||
dependencies = [
|
||||
"durabletask-dapr >= 0.2.0a7",
|
||||
"pydantic == 2.10.5",
|
||||
"openai == 1.59.6",
|
||||
"pydantic == 2.11.3",
|
||||
"openai == 1.75.0",
|
||||
"openapi-pydantic == 0.5.1",
|
||||
"openapi-schema-pydantic==1.2.4",
|
||||
"regex >= 2023.12.25",
|
||||
"Jinja2 >= 3.1.6",
|
||||
"azure-identity == 1.19.0",
|
||||
"azure-identity == 1.21.0",
|
||||
"dapr >= 1.15.0",
|
||||
"dapr-ext-fastapi == 1.15.0",
|
||||
"dapr-ext-workflow == 1.15.0",
|
||||
|
@ -41,7 +41,7 @@ dependencies = [
|
|||
"cloudevents == 1.11.0",
|
||||
"pyyaml == 6.0.2",
|
||||
"rich == 13.9.4",
|
||||
"huggingface_hub == 0.27.1",
|
||||
"huggingface_hub == 0.30.2",
|
||||
"numpy == 2.2.2",
|
||||
]
|
||||
|
||||
|
|
|
@ -1,18 +1,24 @@
|
|||
import asyncio
|
||||
from dapr_agents import tool, Agent
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@tool
|
||||
def my_weather_func() -> str:
|
||||
"""Get current weather."""
|
||||
return "It's 72°F and sunny"
|
||||
|
||||
weather_agent = Agent(
|
||||
name="WeatherAgent",
|
||||
role="Weather Assistant",
|
||||
instructions=["Help users with weather information"],
|
||||
tools=[my_weather_func]
|
||||
)
|
||||
async def main():
|
||||
weather_agent = Agent(
|
||||
name="WeatherAgent",
|
||||
role="Weather Assistant",
|
||||
instructions=["Help users with weather information"],
|
||||
tools=[my_weather_func]
|
||||
)
|
||||
|
||||
response = weather_agent.run("What's the weather?")
|
||||
print(response)
|
||||
response = await weather_agent.run("What's the weather?")
|
||||
print(response)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
|
@ -1,7 +1,9 @@
|
|||
import asyncio
|
||||
from dapr_agents import tool, ReActAgent
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@tool
|
||||
def search_weather(city: str) -> str:
|
||||
"""Get weather information for a city."""
|
||||
|
@ -14,14 +16,17 @@ def get_activities(weather: str) -> str:
|
|||
activities = {"rainy": "Visit museums", "sunny": "Go hiking"}
|
||||
return activities.get(weather.lower(), "Stay comfortable")
|
||||
|
||||
react_agent = ReActAgent(
|
||||
name="TravelAgent",
|
||||
role="Travel Assistant",
|
||||
instructions=["Check weather, then suggest activities"],
|
||||
tools=[search_weather, get_activities]
|
||||
)
|
||||
async def main():
|
||||
react_agent = ReActAgent(
|
||||
name="TravelAgent",
|
||||
role="Travel Assistant",
|
||||
instructions=["Check weather, then suggest activities"],
|
||||
tools=[search_weather, get_activities]
|
||||
)
|
||||
|
||||
result = react_agent.run("What should I do in London today?")
|
||||
result = await react_agent.run("What should I do in London today?")
|
||||
if result:
|
||||
print("Result:", result)
|
||||
|
||||
if len(result) > 0:
|
||||
print ("Result:", result)
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
|
@ -92,24 +92,30 @@ python 02_build_agent.py
|
|||
This example shows how to create a basic agent with a custom tool:
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from dapr_agents import tool, Agent
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@tool
|
||||
def my_weather_func() -> str:
|
||||
"""Get current weather."""
|
||||
return "It's 72°F and sunny"
|
||||
|
||||
weather_agent = Agent(
|
||||
name="WeatherAgent",
|
||||
role="Weather Assistant",
|
||||
instructions=["Help users with weather information"],
|
||||
tools=[my_weather_func]
|
||||
)
|
||||
async def main():
|
||||
weather_agent = Agent(
|
||||
name="WeatherAgent",
|
||||
role="Weather Assistant",
|
||||
instructions=["Help users with weather information"],
|
||||
tools=[my_weather_func]
|
||||
)
|
||||
|
||||
response = weather_agent.run("What's the weather?")
|
||||
print(response)
|
||||
response = await weather_agent.run("What's the weather?")
|
||||
print(response)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
**Expected output:** The agent will use the weather tool to provide the current weather.
|
||||
|
@ -141,10 +147,12 @@ python 03_reason_act.py
|
|||
<!-- END_STEP -->
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from dapr_agents import tool, ReActAgent
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@tool
|
||||
def search_weather(city: str) -> str:
|
||||
"""Get weather information for a city."""
|
||||
|
@ -154,17 +162,23 @@ def search_weather(city: str) -> str:
|
|||
@tool
|
||||
def get_activities(weather: str) -> str:
|
||||
"""Get activity recommendations."""
|
||||
activities = {"rainy": "Visit museums", "Sunny": "Go hiking"}
|
||||
activities = {"rainy": "Visit museums", "sunny": "Go hiking"}
|
||||
return activities.get(weather.lower(), "Stay comfortable")
|
||||
|
||||
react_agent = ReActAgent(
|
||||
name="TravelAgent",
|
||||
role="Travel Assistant",
|
||||
instructions=["Check weather, then suggest activities"],
|
||||
tools=[search_weather, get_activities]
|
||||
)
|
||||
async def main():
|
||||
react_agent = ReActAgent(
|
||||
name="TravelAgent",
|
||||
role="Travel Assistant",
|
||||
instructions=["Check weather, then suggest activities"],
|
||||
tools=[search_weather, get_activities]
|
||||
)
|
||||
|
||||
react_agent.run("What should I do in London today?")
|
||||
result = await react_agent.run("What should I do in London today?")
|
||||
if result:
|
||||
print("Result:", result)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
**Expected output:** The agent will first check the weather in London, find it's rainy, and then recommend visiting museums.
|
||||
|
|
|
@ -70,7 +70,7 @@ if len(response.get_content()) > 0:
|
|||
print("Response with prompty: ", response.get_content())
|
||||
|
||||
# Chat completion with user input
|
||||
llm = HFHubChatClient()
|
||||
llm = HFHubChatClient(model="microsoft/Phi-3-mini-4k-instruct")
|
||||
response = llm.generate(messages=[UserMessage("hello")])
|
||||
|
||||
print("Response with user input: ", response.get_content())
|
||||
|
|
|
@ -22,7 +22,7 @@ if len(response.get_content()) > 0:
|
|||
print("Response with prompty: ", response.get_content())
|
||||
|
||||
# Chat completion with user input
|
||||
llm = HFHubChatClient()
|
||||
llm = HFHubChatClient(model="microsoft/Phi-3-mini-4k-instruct")
|
||||
response = llm.generate(messages=[UserMessage("hello")])
|
||||
|
||||
|
||||
|
|
|
@ -80,31 +80,38 @@ from dotenv import load_dotenv
|
|||
import asyncio
|
||||
import logging
|
||||
|
||||
|
||||
async def main():
|
||||
try:
|
||||
# Define Agent
|
||||
hobbit_agent = Agent(
|
||||
role="Hobbit",
|
||||
name="Frodo",
|
||||
goal="Take the ring to Mordor",
|
||||
instructions=["Speak like Frodo"]
|
||||
)
|
||||
|
||||
# Expose Agent as an Actor Service
|
||||
hobbit_service = AgentActor(
|
||||
hobbit_agent = Agent(role="Hobbit", name="Frodo",
|
||||
goal="Carry the One Ring to Mount Doom, resisting its corruptive power while navigating danger and uncertainty.",
|
||||
instructions=[
|
||||
"Speak like Frodo, with humility, determination, and a growing sense of resolve.",
|
||||
"Endure hardships and temptations, staying true to the mission even when faced with doubt.",
|
||||
"Seek guidance and trust allies, but bear the ultimate burden alone when necessary.",
|
||||
"Move carefully through enemy-infested lands, avoiding unnecessary risks.",
|
||||
"Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task."])
|
||||
|
||||
# Expose Agent as an Actor over a Service
|
||||
hobbit_actor = AgentActor(
|
||||
agent=hobbit_agent,
|
||||
message_bus_name="messagepubsub",
|
||||
agents_state_store_name="agentstatestore",
|
||||
service_port=8001,
|
||||
agents_registry_store_name="agentstatestore",
|
||||
agents_registry_key="agents_registry",
|
||||
service_port=8001
|
||||
)
|
||||
|
||||
await hobbit_service.start()
|
||||
await hobbit_actor.start()
|
||||
except Exception as e:
|
||||
print(f"Error starting service: {e}")
|
||||
print(f"Error starting actor: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
|
@ -120,24 +127,29 @@ from dotenv import load_dotenv
|
|||
import asyncio
|
||||
import logging
|
||||
|
||||
|
||||
async def main():
|
||||
try:
|
||||
random_workflow = RandomOrchestrator(
|
||||
name="RandomOrchestrator",
|
||||
message_bus_name="messagepubsub",
|
||||
state_store_name="agenticworkflowstate",
|
||||
state_store_name="workflowstatestore",
|
||||
state_key="workflow_state",
|
||||
agents_registry_store_name="agentstatestore",
|
||||
agents_registry_key="agents_registry",
|
||||
max_iterations=3
|
||||
).as_service(port=8004)
|
||||
|
||||
await random_workflow.start()
|
||||
except Exception as e:
|
||||
print(f"Error starting workflow: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
pydantic==2.10.5
|
||||
openai==1.59.6
|
||||
pydantic==2.11.3
|
||||
openai==1.75.0
|
||||
openapi-pydantic==0.5.1
|
||||
openapi-schema-pydantic==1.2.4
|
||||
regex>=2023.12.25
|
||||
Jinja2>=3.1.6
|
||||
azure-identity==1.19.0
|
||||
azure-identity==1.21.0
|
||||
dapr>=1.15.0
|
||||
dapr-ext-fastapi==1.15.0
|
||||
dapr-ext-workflow==1.15.0
|
||||
|
@ -12,6 +12,6 @@ colorama==0.4.6
|
|||
cloudevents==1.11.0
|
||||
pyyaml==6.0.2
|
||||
rich==13.9.4
|
||||
huggingface_hub==0.27.1
|
||||
huggingface_hub==0.30.2
|
||||
numpy==2.2.2
|
||||
mcp==1.6.0
|
Loading…
Reference in New Issue