Async‑first workflow runner, sync wrapper & registration closure fix (#93)

* Async workflow runner, sync wrapper & registration closure fix

* updated cookbook notebooks to show sync and async with workflow monitoring

* updated quickstarts to use the updated sync workflow monitoring
This commit is contained in:
Roberto Rodriguez 2025-04-22 12:20:04 -04:00 committed by GitHub
parent 099dc5d2fb
commit bd0859d181
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 207 additions and 52 deletions

View File

@ -1,6 +1,7 @@
import logging
from dapr_agents.workflow import WorkflowApp, workflow, task
from dapr_agents.types import DaprWorkflowContext
import logging
@workflow(name='random_workflow')
def task_chain_workflow(ctx:DaprWorkflowContext, input: int):
@ -32,6 +33,6 @@ if __name__ == '__main__':
wfapp = WorkflowApp()
results = wfapp.run_and_monitor_workflow(task_chain_workflow, input=10)
results = wfapp.run_and_monitor_workflow_sync(task_chain_workflow, input=10)
print(f"Results: {results}")

View File

@ -0,0 +1,40 @@
import asyncio
import logging

from dapr_agents.workflow import WorkflowApp, workflow, task
from dapr_agents.types import DaprWorkflowContext


@workflow(name="random_workflow")
def task_chain_workflow(ctx: DaprWorkflowContext, input: int):
    """Chain three activities sequentially, feeding each result into the next.

    Dapr drives this generator: each `yield ctx.call_activity(...)` suspends
    the workflow until the activity completes and resumes with its result.

    Args:
        ctx: Workflow context supplied by the Dapr runtime.
        input: Initial integer fed into step1.

    Returns:
        List of the three intermediate results.
    """
    result1 = yield ctx.call_activity(step1, input=input)
    result2 = yield ctx.call_activity(step2, input=result1)
    result3 = yield ctx.call_activity(step3, input=result2)
    return [result1, result2, result3]


@task
def step1(activity_input: int) -> int:
    """Increment the input by one."""
    print(f"Step 1: Received input: {activity_input}.")
    return activity_input + 1


@task
def step2(activity_input: int) -> int:
    """Double the input."""
    print(f"Step 2: Received input: {activity_input}.")
    return activity_input * 2


@task
def step3(activity_input: int) -> int:
    """Bitwise-XOR the input with 2."""
    # NOTE(review): '^' is bitwise XOR, not exponentiation. Given the
    # +1 / *2 progression above, squaring (`** 2`) may have been the
    # intent — confirm before relying on this value.
    print(f"Step 3: Received input: {activity_input}.")
    return activity_input ^ 2


async def main():
    """Run the workflow via the async run-and-monitor API and print the result."""
    logging.basicConfig(level=logging.INFO)

    wfapp = WorkflowApp()

    # Schedules the workflow, awaits completion, and returns the
    # serialized output (see WorkflowApp.run_and_monitor_workflow_async).
    result = await wfapp.run_and_monitor_workflow_async(
        task_chain_workflow,
        input=10
    )
    print(f"Results: {result}")


if __name__ == "__main__":
    asyncio.run(main())

View File

@ -31,5 +31,5 @@ if __name__ == '__main__':
wfapp = WorkflowApp()
# Run workflow
results = wfapp.run_and_monitor_workflow(task_chain_workflow)
results = wfapp.run_and_monitor_workflow_sync(task_chain_workflow)
print(results)

View File

@ -0,0 +1,40 @@
import asyncio
import logging

from dapr_agents.workflow import WorkflowApp, workflow, task
from dapr_agents.types import DaprWorkflowContext
from dotenv import load_dotenv


# Define Workflow logic
@workflow(name='lotr_workflow')
def task_chain_workflow(ctx: DaprWorkflowContext):
    """Pick a LOTR character, then fetch a famous line spoken by them.

    Each `yield ctx.call_activity(...)` suspends until the activity result
    is available; the second activity receives the first's output.
    """
    result1 = yield ctx.call_activity(get_character)
    result2 = yield ctx.call_activity(get_line, input={"character": result1})
    return result2


# Description-based task with an empty body: the prompt text is the
# implementation — presumably executed by the app's configured LLM
# (see WorkflowApp._choose_llm_for) — confirm against the runtime.
@task(description="""
Pick a random character from The Lord of the Rings\n
and respond with the character's name ONLY
""")
def get_character() -> str:
    pass


@task(description="What is a famous line by {character}",)
def get_line(character: str) -> str:
    # {character} in the description is filled from the activity input.
    pass


async def main():
    """Load env vars, run the workflow asynchronously, and print the result."""
    logging.basicConfig(level=logging.INFO)

    # Load environment variables
    load_dotenv()

    # Initialize the WorkflowApp
    wfapp = WorkflowApp()

    # Run workflow
    result = await wfapp.run_and_monitor_workflow_async(task_chain_workflow)
    print(f"Results: {result}")


if __name__ == "__main__":
    asyncio.run(main())

View File

@ -1,8 +1,8 @@
import logging
from dapr_agents.workflow import WorkflowApp, workflow, task
from dapr_agents.types import DaprWorkflowContext
from pydantic import BaseModel
from dotenv import load_dotenv
import logging
@workflow
def question(ctx:DaprWorkflowContext, input:int):
@ -25,5 +25,9 @@ if __name__ == '__main__':
wfapp = WorkflowApp()
results = wfapp.run_and_monitor_workflow(workflow=question, input="Scooby Doo")
results = wfapp.run_and_monitor_workflow_sync(
workflow=question,
input="Scooby Doo"
)
print(results)

View File

@ -0,0 +1,40 @@
import asyncio
import logging

from dapr_agents.workflow import WorkflowApp, workflow, task
from dapr_agents.types import DaprWorkflowContext
from pydantic import BaseModel
from dotenv import load_dotenv


@workflow
def question(ctx:DaprWorkflowContext, input:int):
    """Single-step workflow: ask about a name and return the structured answer."""
    step1 = yield ctx.call_activity(ask, input=input)
    return step1


class Dog(BaseModel):
    """Structured answer schema that the `ask` task's result is parsed into."""
    # Character's name.
    name: str
    # Short biography text.
    bio: str
    # Breed of the dog.
    breed: str


# Positional-description task with an empty body: the prompt is the
# implementation, with {name} filled from the activity input; the
# Dog return annotation presumably drives structured output parsing
# — confirm against the task executor.
@task("Who was {name}?")
def ask(name:str) -> Dog:
    pass


async def main():
    """Load env vars, run the workflow asynchronously, and print the result."""
    logging.basicConfig(level=logging.INFO)

    # Load environment variables
    load_dotenv()

    # Initialize the WorkflowApp
    wfapp = WorkflowApp()

    # Run workflow
    result = await wfapp.run_and_monitor_workflow_async(
        workflow=question,
        input="Scooby Doo"
    )
    print(f"Results: {result}")


if __name__ == "__main__":
    asyncio.run(main())

View File

@ -358,4 +358,4 @@ if __name__ == '__main__':
raise ValueError("PDF URL must be provided via CLI or config file.")
# Run the workflow
wfapp.run_and_monitor_workflow(workflow=doc2podcast, input=user_input)
wfapp.run_and_monitor_workflow_sync(workflow=doc2podcast, input=user_input)

View File

@ -54,7 +54,7 @@ if __name__ == '__main__':
wfapp = WorkflowApp()
results = wfapp.run_and_monitor_workflow(workflow=test_workflow)
results = wfapp.run_and_monitor_workflow_sync(workflow=test_workflow)
logging.info("Workflow results: %s", results)
logging.info("Workflow completed successfully.")

View File

@ -6,7 +6,7 @@ import logging
import sys
import time
import uuid
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
from pydantic import BaseModel, ConfigDict, Field
@ -38,7 +38,7 @@ class WorkflowApp(BaseModel):
# Initialized in model_post_init
wf_runtime: Optional[WorkflowRuntime] = Field(default=None, init=False, description="Workflow runtime instance.")
wf_runtime_is_running: Optional[bool] = Field(default=None, init=False, description="Is the Workflow runtime running?.")
wf_runtime_is_running: Optional[bool] = Field(default=None, init=False, description="Is the Workflow runtime running?")
wf_client: Optional[DaprWorkflowClient] = Field(default=None, init=False, description="Workflow client instance.")
client: Optional[DaprClient] = Field(default=None, init=False, description="Dapr client instance.")
tasks: Dict[str, Callable] = Field(default_factory=dict, init=False, description="Dictionary of registered tasks.")
@ -50,14 +50,14 @@ class WorkflowApp(BaseModel):
"""
Initialize the Dapr workflow runtime and register tasks & workflows.
"""
# initialize clients and runtime
# Initialize clients and runtime
self.wf_runtime = WorkflowRuntime()
self.wf_runtime_is_running = False
self.wf_client = DaprWorkflowClient()
self.client = DaprClient()
logger.info("WorkflowApp initialized; discovering tasks and workflows.")
# Discover and register
# Discover and register tasks and workflows
discovered_tasks = self._discover_tasks()
self._register_tasks(discovered_tasks)
discovered_wfs = self._discover_workflows()
@ -75,9 +75,9 @@ class WorkflowApp(BaseModel):
def _choose_llm_for(self, method: Callable) -> Optional[ChatClientBase]:
"""
Encapsulate LLM selection logic.
1. Use per-task override if provided on decorator.
2. Else if description-based, fall back to default app LLM.
3. Otherwise, None.
1. Use per-task override if provided on decorator.
2. Else if marked as explicitly requiring an LLM, fall back to default app LLM.
3. Otherwise, returns None.
"""
per_task = getattr(method, '_task_llm', None)
if per_task:
@ -90,11 +90,11 @@ class WorkflowApp(BaseModel):
"""Gather all @task-decorated functions and methods."""
module = sys.modules['__main__']
tasks: Dict[str, Callable] = {}
# free functions
# Free functions in __main__
for name, fn in inspect.getmembers(module, inspect.isfunction):
if getattr(fn, '_is_task', False) and fn.__module__ == module.__name__:
tasks[getattr(fn, '_task_name', name)] = fn
# bound methods
# Bound methods (if any) discovered via helper
for name, method in get_decorated_methods(self, '_is_task').items():
tasks[getattr(method, '_task_name', name)] = method
logger.debug(f"Discovered tasks: {list(tasks)}")
@ -115,7 +115,7 @@ class WorkflowApp(BaseModel):
workflow_app=self,
**kwargs
)
# wrap for Dapr
# Wrap for Dapr invocation
wrapped = self._make_task_wrapper(task_name, method, task_instance)
activity_decorator = self.wf_runtime.activity(name=task_name)
self.tasks[task_name] = activity_decorator(wrapped)
@ -128,10 +128,12 @@ class WorkflowApp(BaseModel):
) -> Callable:
"""Produce the function that Dapr will invoke for each activity."""
def run_sync(coro):
# Try to get the running event loop and run until complete
try:
loop = asyncio.get_running_loop()
return loop.run_until_complete(coro)
except RuntimeError:
# If no running loop, create one
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop.run_until_complete(coro)
@ -164,13 +166,15 @@ class WorkflowApp(BaseModel):
def _register_workflows(self, wfs: Dict[str, Callable]) -> None:
"""Register each discovered workflow with the Dapr runtime."""
for wf_name, method in wfs.items():
@functools.wraps(method)
def wrapped(*args, **kwargs):
return method(*args, **kwargs)
# Use a closure helper to avoid late-binding capture issues.
def make_wrapped(meth: Callable) -> Callable:
@functools.wraps(meth)
def wrapped(*args, **kwargs):
return meth(*args, **kwargs)
return wrapped
decorator = self.wf_runtime.workflow(name=wf_name)
self.workflows[wf_name] = decorator(wrapped)
self.workflows[wf_name] = decorator(make_wrapped(method))
def start_runtime(self):
"""Idempotently start the Dapr workflow runtime."""
@ -244,8 +248,8 @@ class WorkflowApp(BaseModel):
logger.debug(f"Sleeping for 1 second before retrying transaction...")
time.sleep(1)
raise Exception(f"Failed to update state store key: {store_key} after 10 attempts.")
def get_data_from_store(self, store_name: str, key: str) -> Tuple[bool, dict]:
def get_data_from_store(self, store_name: str, key: str) -> Optional[dict]:
"""
Retrieves data from the Dapr state store using the given key.
@ -254,7 +258,7 @@ class WorkflowApp(BaseModel):
key (str): The key to fetch data from.
Returns:
Tuple[bool, dict]: A tuple indicating if data was found (bool) and the retrieved data (dict).
Optional[dict]: the retrieved dictionary or None if not found.
"""
try:
response: StateResponse = self.client.get_state(store_name=store_name, key=key)
@ -335,10 +339,9 @@ class WorkflowApp(BaseModel):
# Start Workflow Runtime
if not self.wf_runtime_is_running:
self.start_runtime()
self.wf_runtime_is_running = True
# Generate unique instance ID
instance_id = str(uuid.uuid4()).replace("-", "")
instance_id = uuid.uuid4().hex
# Resolve the workflow function
workflow_func = self.resolve_workflow(workflow)
@ -443,29 +446,37 @@ class WorkflowApp(BaseModel):
finally:
logger.info(f"Finished monitoring workflow '{instance_id}'.")
def run_and_monitor_workflow(self, workflow: Union[str, Callable], input: Optional[Union[str, Dict[str, Any]]] = None) -> Optional[str]:
async def run_and_monitor_workflow_async(
self,
workflow: Union[str, Callable],
input: Optional[Union[str, Dict[str, Any]]] = None
) -> Optional[str]:
"""
Runs a workflow synchronously and monitors its completion.
Runs a workflow asynchronously and monitors its completion.
Args:
workflow (Union[str, Callable]): The workflow name or callable.
input (Optional[Union[str, Dict[str, Any]]]): The workflow input.
input (Optional[Union[str, Dict[str, Any]]]): The workflow input payload.
Returns:
Optional[str]: The serialized output of the workflow.
"""
instance_id = None
try:
# Schedule the workflow
instance_id = self.run_workflow(workflow, input=input)
# Ensure we run within a new asyncio event loop
state = asyncio.run(self.monitor_workflow_state(instance_id))
# Off-load the potentially blocking run_workflow call to a thread.
instance_id = await asyncio.to_thread(self.run_workflow, workflow, input)
# Await the asynchronous monitoring of the workflow state.
state = await self.monitor_workflow_state(instance_id)
if not state:
raise RuntimeError(f"Workflow '{instance_id}' not found.")
workflow_status = DaprWorkflowStatus[state.runtime_status.name] if state.runtime_status.name in DaprWorkflowStatus.__members__ else DaprWorkflowStatus.UNKNOWN
workflow_status = (
DaprWorkflowStatus[state.runtime_status.name]
if state.runtime_status.name in DaprWorkflowStatus.__members__
else DaprWorkflowStatus.UNKNOWN
)
if workflow_status == DaprWorkflowStatus.COMPLETED:
logger.info(f"Workflow '{instance_id}' completed successfully!")
@ -475,13 +486,32 @@ class WorkflowApp(BaseModel):
# Return the final state output
return state.serialized_output
except Exception as e:
logger.error(f"Error during workflow '{instance_id}': {e}")
raise
finally:
logger.info(f"Finished workflow with Instance ID: {instance_id}.")
self.stop_runtime()
# Off-load the stop_runtime call as it may block.
await asyncio.to_thread(self.stop_runtime)
def run_and_monitor_workflow_sync(
    self,
    workflow: Union[str, Callable],
    input: Optional[Union[str, Dict[str, Any]]] = None
) -> Optional[str]:
    """
    Blocking convenience wrapper around `run_and_monitor_workflow_async`.

    Lets non-async callers schedule a workflow and wait for its final
    output. Internally builds the coroutine and drives it to completion
    on a fresh event loop via `asyncio.run`, so it must not be called
    from code that is already executing inside a running event loop.

    Args:
        workflow (Union[str, Callable]): The workflow name or callable.
        input (Optional[Union[str, Dict[str, Any]]]): The workflow input payload.

    Returns:
        Optional[str]: The serialized output of the workflow.
    """
    pending = self.run_and_monitor_workflow_async(workflow, input)
    return asyncio.run(pending)
def terminate_workflow(self, instance_id: str, *, output: Optional[Any] = None) -> None:
"""
@ -539,7 +569,9 @@ class WorkflowApp(BaseModel):
"""
try:
state = self.wf_client.wait_for_workflow_completion(
instance_id, fetch_payloads=fetch_payloads, timeout_in_seconds=timeout_in_seconds
instance_id,
fetch_payloads=fetch_payloads,
timeout_in_seconds=timeout_in_seconds
)
if state:
logger.info(f"Workflow {instance_id} completed with status: {state.runtime_status}.")

View File

@ -27,7 +27,7 @@ def write_blog(outline: str) -> str:
if __name__ == '__main__':
wfapp = WorkflowApp()
results = wfapp.run_and_monitor_workflow(
results = wfapp.run_and_monitor_workflow_sync(
analyze_topic,
input="AI Agents"
)

View File

@ -234,7 +234,7 @@ def write_blog(outline: str) -> str:
if __name__ == '__main__':
wfapp = WorkflowApp()
results = wfapp.run_and_monitor_workflow(
results = wfapp.run_and_monitor_workflow_sync(
analyze_topic,
input="AI Agents"
)

View File

@ -103,7 +103,7 @@ def get_line(character: str) -> str:
if __name__ == '__main__':
wfapp = WorkflowApp()
results = wfapp.run_and_monitor_workflow(task_chain_workflow)
results = wfapp.run_and_monitor_workflow_sync(task_chain_workflow)
print(f"Famous Line: {results}")
```
@ -209,7 +209,7 @@ if __name__ == "__main__":
research_topic = "The environmental impact of quantum computing"
logging.info(f"Starting research workflow on: {research_topic}")
results = wfapp.run_and_monitor_workflow(research_workflow, input=research_topic)
results = wfapp.run_and_monitor_workflow_sync(research_workflow, input=research_topic)
logging.info(f"\nResearch Report:\n{results}")
```

View File

@ -73,6 +73,6 @@ if __name__ == "__main__":
research_topic = "The environmental impact of quantum computing"
logging.info(f"Starting research workflow on: {research_topic}")
results = wfapp.run_and_monitor_workflow(research_workflow, input=research_topic)
results = wfapp.run_and_monitor_workflow_sync(research_workflow, input=research_topic)
if len(results) > 0:
logging.info(f"\nResearch Report:\n{results}")

View File

@ -26,5 +26,5 @@ def get_line(character: str) -> str:
if __name__ == '__main__':
wfapp = WorkflowApp()
results = wfapp.run_and_monitor_workflow(task_chain_workflow)
results = wfapp.run_and_monitor_workflow_sync(task_chain_workflow)
print(f"Famous Line: {results}")

View File

@ -5,8 +5,6 @@ from dotenv import load_dotenv
# Load environment variables
load_dotenv()
# Initialize the WorkflowApp
# Define Workflow logic
@workflow(name='task_chain_workflow')
def task_chain_workflow(ctx: DaprWorkflowContext):
@ -31,5 +29,5 @@ def get_line(character: str) -> str:
if __name__ == '__main__':
wfapp = WorkflowApp()
results = wfapp.run_and_monitor_workflow(task_chain_workflow)
results = wfapp.run_and_monitor_workflow_sync(task_chain_workflow)
print(f"Results: {results}")