mirror of https://github.com/dapr/dapr-agents.git
Async‑first workflow runner, sync wrapper & registration closure fix (#93)
* Async workflow runner, sync wrapper & registration closure fix * updated cookbook notebooks to show sync and async with workflow monitoring * updated quickstarts to use the updated sync workflow monitoring
This commit is contained in:
parent
099dc5d2fb
commit
bd0859d181
|
|
@ -1,6 +1,7 @@
|
|||
import logging
|
||||
|
||||
from dapr_agents.workflow import WorkflowApp, workflow, task
|
||||
from dapr_agents.types import DaprWorkflowContext
|
||||
import logging
|
||||
|
||||
@workflow(name='random_workflow')
|
||||
def task_chain_workflow(ctx:DaprWorkflowContext, input: int):
|
||||
|
|
@ -32,6 +33,6 @@ if __name__ == '__main__':
|
|||
|
||||
wfapp = WorkflowApp()
|
||||
|
||||
results = wfapp.run_and_monitor_workflow(task_chain_workflow, input=10)
|
||||
results = wfapp.run_and_monitor_workflow_sync(task_chain_workflow, input=10)
|
||||
|
||||
print(f"Results: {results}")
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
import asyncio
|
||||
import logging
|
||||
|
||||
from dapr_agents.workflow import WorkflowApp, workflow, task
|
||||
from dapr_agents.types import DaprWorkflowContext
|
||||
|
||||
@workflow(name="random_workflow")
|
||||
def task_chain_workflow(ctx: DaprWorkflowContext, input: int):
|
||||
result1 = yield ctx.call_activity(step1, input=input)
|
||||
result2 = yield ctx.call_activity(step2, input=result1)
|
||||
result3 = yield ctx.call_activity(step3, input=result2)
|
||||
return [result1, result2, result3]
|
||||
|
||||
@task
|
||||
def step1(activity_input: int) -> int:
|
||||
print(f"Step 1: Received input: {activity_input}.")
|
||||
return activity_input + 1
|
||||
|
||||
@task
|
||||
def step2(activity_input: int) -> int:
|
||||
print(f"Step 2: Received input: {activity_input}.")
|
||||
return activity_input * 2
|
||||
|
||||
@task
|
||||
def step3(activity_input: int) -> int:
|
||||
print(f"Step 3: Received input: {activity_input}.")
|
||||
return activity_input ^ 2
|
||||
|
||||
async def main():
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
wfapp = WorkflowApp()
|
||||
|
||||
result = await wfapp.run_and_monitor_workflow_async(
|
||||
task_chain_workflow,
|
||||
input=10
|
||||
)
|
||||
print(f"Results: {result}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
@ -31,5 +31,5 @@ if __name__ == '__main__':
|
|||
wfapp = WorkflowApp()
|
||||
|
||||
# Run workflow
|
||||
results = wfapp.run_and_monitor_workflow(task_chain_workflow)
|
||||
results = wfapp.run_and_monitor_workflow_sync(task_chain_workflow)
|
||||
print(results)
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
import asyncio
|
||||
import logging
|
||||
|
||||
from dapr_agents.workflow import WorkflowApp, workflow, task
|
||||
from dapr_agents.types import DaprWorkflowContext
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Define Workflow logic
|
||||
@workflow(name='lotr_workflow')
|
||||
def task_chain_workflow(ctx: DaprWorkflowContext):
|
||||
result1 = yield ctx.call_activity(get_character)
|
||||
result2 = yield ctx.call_activity(get_line, input={"character": result1})
|
||||
return result2
|
||||
|
||||
@task(description="""
|
||||
Pick a random character from The Lord of the Rings\n
|
||||
and respond with the character's name ONLY
|
||||
""")
|
||||
def get_character() -> str:
|
||||
pass
|
||||
|
||||
@task(description="What is a famous line by {character}",)
|
||||
def get_line(character: str) -> str:
|
||||
pass
|
||||
|
||||
async def main():
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Initialize the WorkflowApp
|
||||
wfapp = WorkflowApp()
|
||||
|
||||
# Run workflow
|
||||
result = await wfapp.run_and_monitor_workflow_async(task_chain_workflow)
|
||||
print(f"Results: {result}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
@ -1,8 +1,8 @@
|
|||
import logging
|
||||
from dapr_agents.workflow import WorkflowApp, workflow, task
|
||||
from dapr_agents.types import DaprWorkflowContext
|
||||
from pydantic import BaseModel
|
||||
from dotenv import load_dotenv
|
||||
import logging
|
||||
|
||||
@workflow
|
||||
def question(ctx:DaprWorkflowContext, input:int):
|
||||
|
|
@ -25,5 +25,9 @@ if __name__ == '__main__':
|
|||
|
||||
wfapp = WorkflowApp()
|
||||
|
||||
results = wfapp.run_and_monitor_workflow(workflow=question, input="Scooby Doo")
|
||||
results = wfapp.run_and_monitor_workflow_sync(
|
||||
workflow=question,
|
||||
input="Scooby Doo"
|
||||
)
|
||||
|
||||
print(results)
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
import asyncio
|
||||
import logging
|
||||
|
||||
from dapr_agents.workflow import WorkflowApp, workflow, task
|
||||
from dapr_agents.types import DaprWorkflowContext
|
||||
from pydantic import BaseModel
|
||||
from dotenv import load_dotenv
|
||||
|
||||
@workflow
|
||||
def question(ctx:DaprWorkflowContext, input:int):
|
||||
step1 = yield ctx.call_activity(ask, input=input)
|
||||
return step1
|
||||
|
||||
class Dog(BaseModel):
|
||||
name: str
|
||||
bio: str
|
||||
breed: str
|
||||
|
||||
@task("Who was {name}?")
|
||||
def ask(name:str) -> Dog:
|
||||
pass
|
||||
|
||||
async def main():
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Initialize the WorkflowApp
|
||||
wfapp = WorkflowApp()
|
||||
|
||||
# Run workflow
|
||||
result = await wfapp.run_and_monitor_workflow_async(
|
||||
workflow=question,
|
||||
input="Scooby Doo"
|
||||
)
|
||||
print(f"Results: {result}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
@ -358,4 +358,4 @@ if __name__ == '__main__':
|
|||
raise ValueError("PDF URL must be provided via CLI or config file.")
|
||||
|
||||
# Run the workflow
|
||||
wfapp.run_and_monitor_workflow(workflow=doc2podcast, input=user_input)
|
||||
wfapp.run_and_monitor_workflow_sync(workflow=doc2podcast, input=user_input)
|
||||
|
|
@ -54,7 +54,7 @@ if __name__ == '__main__':
|
|||
|
||||
wfapp = WorkflowApp()
|
||||
|
||||
results = wfapp.run_and_monitor_workflow(workflow=test_workflow)
|
||||
results = wfapp.run_and_monitor_workflow_sync(workflow=test_workflow)
|
||||
|
||||
logging.info("Workflow results: %s", results)
|
||||
logging.info("Workflow completed successfully.")
|
||||
|
|
@ -6,7 +6,7 @@ import logging
|
|||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
|
@ -38,7 +38,7 @@ class WorkflowApp(BaseModel):
|
|||
|
||||
# Initialized in model_post_init
|
||||
wf_runtime: Optional[WorkflowRuntime] = Field(default=None, init=False, description="Workflow runtime instance.")
|
||||
wf_runtime_is_running: Optional[bool] = Field(default=None, init=False, description="Is the Workflow runtime running?.")
|
||||
wf_runtime_is_running: Optional[bool] = Field(default=None, init=False, description="Is the Workflow runtime running?")
|
||||
wf_client: Optional[DaprWorkflowClient] = Field(default=None, init=False, description="Workflow client instance.")
|
||||
client: Optional[DaprClient] = Field(default=None, init=False, description="Dapr client instance.")
|
||||
tasks: Dict[str, Callable] = Field(default_factory=dict, init=False, description="Dictionary of registered tasks.")
|
||||
|
|
@ -50,14 +50,14 @@ class WorkflowApp(BaseModel):
|
|||
"""
|
||||
Initialize the Dapr workflow runtime and register tasks & workflows.
|
||||
"""
|
||||
# initialize clients and runtime
|
||||
# Initialize clients and runtime
|
||||
self.wf_runtime = WorkflowRuntime()
|
||||
self.wf_runtime_is_running = False
|
||||
self.wf_client = DaprWorkflowClient()
|
||||
self.client = DaprClient()
|
||||
logger.info("WorkflowApp initialized; discovering tasks and workflows.")
|
||||
|
||||
# Discover and register
|
||||
# Discover and register tasks and workflows
|
||||
discovered_tasks = self._discover_tasks()
|
||||
self._register_tasks(discovered_tasks)
|
||||
discovered_wfs = self._discover_workflows()
|
||||
|
|
@ -75,9 +75,9 @@ class WorkflowApp(BaseModel):
|
|||
def _choose_llm_for(self, method: Callable) -> Optional[ChatClientBase]:
|
||||
"""
|
||||
Encapsulate LLM selection logic.
|
||||
1. Use per-task override if provided on decorator.
|
||||
2. Else if description-based, fall back to default app LLM.
|
||||
3. Otherwise, None.
|
||||
1. Use per-task override if provided on decorator.
|
||||
2. Else if marked as explicitly requiring an LLM, fall back to default app LLM.
|
||||
3. Otherwise, returns None.
|
||||
"""
|
||||
per_task = getattr(method, '_task_llm', None)
|
||||
if per_task:
|
||||
|
|
@ -90,11 +90,11 @@ class WorkflowApp(BaseModel):
|
|||
"""Gather all @task-decorated functions and methods."""
|
||||
module = sys.modules['__main__']
|
||||
tasks: Dict[str, Callable] = {}
|
||||
# free functions
|
||||
# Free functions in __main__
|
||||
for name, fn in inspect.getmembers(module, inspect.isfunction):
|
||||
if getattr(fn, '_is_task', False) and fn.__module__ == module.__name__:
|
||||
tasks[getattr(fn, '_task_name', name)] = fn
|
||||
# bound methods
|
||||
# Bound methods (if any) discovered via helper
|
||||
for name, method in get_decorated_methods(self, '_is_task').items():
|
||||
tasks[getattr(method, '_task_name', name)] = method
|
||||
logger.debug(f"Discovered tasks: {list(tasks)}")
|
||||
|
|
@ -115,7 +115,7 @@ class WorkflowApp(BaseModel):
|
|||
workflow_app=self,
|
||||
**kwargs
|
||||
)
|
||||
# wrap for Dapr
|
||||
# Wrap for Dapr invocation
|
||||
wrapped = self._make_task_wrapper(task_name, method, task_instance)
|
||||
activity_decorator = self.wf_runtime.activity(name=task_name)
|
||||
self.tasks[task_name] = activity_decorator(wrapped)
|
||||
|
|
@ -128,10 +128,12 @@ class WorkflowApp(BaseModel):
|
|||
) -> Callable:
|
||||
"""Produce the function that Dapr will invoke for each activity."""
|
||||
def run_sync(coro):
|
||||
# Try to get the running event loop and run until complete
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
return loop.run_until_complete(coro)
|
||||
except RuntimeError:
|
||||
# If no running loop, create one
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
return loop.run_until_complete(coro)
|
||||
|
|
@ -164,13 +166,15 @@ class WorkflowApp(BaseModel):
|
|||
def _register_workflows(self, wfs: Dict[str, Callable]) -> None:
|
||||
"""Register each discovered workflow with the Dapr runtime."""
|
||||
for wf_name, method in wfs.items():
|
||||
|
||||
@functools.wraps(method)
|
||||
def wrapped(*args, **kwargs):
|
||||
return method(*args, **kwargs)
|
||||
# Use a closure helper to avoid late-binding capture issues.
|
||||
def make_wrapped(meth: Callable) -> Callable:
|
||||
@functools.wraps(meth)
|
||||
def wrapped(*args, **kwargs):
|
||||
return meth(*args, **kwargs)
|
||||
return wrapped
|
||||
|
||||
decorator = self.wf_runtime.workflow(name=wf_name)
|
||||
self.workflows[wf_name] = decorator(wrapped)
|
||||
self.workflows[wf_name] = decorator(make_wrapped(method))
|
||||
|
||||
def start_runtime(self):
|
||||
"""Idempotently start the Dapr workflow runtime."""
|
||||
|
|
@ -244,8 +248,8 @@ class WorkflowApp(BaseModel):
|
|||
logger.debug(f"Sleeping for 1 second before retrying transaction...")
|
||||
time.sleep(1)
|
||||
raise Exception(f"Failed to update state store key: {store_key} after 10 attempts.")
|
||||
|
||||
def get_data_from_store(self, store_name: str, key: str) -> Tuple[bool, dict]:
|
||||
|
||||
def get_data_from_store(self, store_name: str, key: str) -> Optional[dict]:
|
||||
"""
|
||||
Retrieves data from the Dapr state store using the given key.
|
||||
|
||||
|
|
@ -254,7 +258,7 @@ class WorkflowApp(BaseModel):
|
|||
key (str): The key to fetch data from.
|
||||
|
||||
Returns:
|
||||
Tuple[bool, dict]: A tuple indicating if data was found (bool) and the retrieved data (dict).
|
||||
Optional[dict]: the retrieved dictionary or None if not found.
|
||||
"""
|
||||
try:
|
||||
response: StateResponse = self.client.get_state(store_name=store_name, key=key)
|
||||
|
|
@ -335,10 +339,9 @@ class WorkflowApp(BaseModel):
|
|||
# Start Workflow Runtime
|
||||
if not self.wf_runtime_is_running:
|
||||
self.start_runtime()
|
||||
self.wf_runtime_is_running = True
|
||||
|
||||
# Generate unique instance ID
|
||||
instance_id = str(uuid.uuid4()).replace("-", "")
|
||||
instance_id = uuid.uuid4().hex
|
||||
|
||||
# Resolve the workflow function
|
||||
workflow_func = self.resolve_workflow(workflow)
|
||||
|
|
@ -443,29 +446,37 @@ class WorkflowApp(BaseModel):
|
|||
finally:
|
||||
logger.info(f"Finished monitoring workflow '{instance_id}'.")
|
||||
|
||||
def run_and_monitor_workflow(self, workflow: Union[str, Callable], input: Optional[Union[str, Dict[str, Any]]] = None) -> Optional[str]:
|
||||
async def run_and_monitor_workflow_async(
|
||||
self,
|
||||
workflow: Union[str, Callable],
|
||||
input: Optional[Union[str, Dict[str, Any]]] = None
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Runs a workflow synchronously and monitors its completion.
|
||||
|
||||
Runs a workflow asynchronously and monitors its completion.
|
||||
|
||||
Args:
|
||||
workflow (Union[str, Callable]): The workflow name or callable.
|
||||
input (Optional[Union[str, Dict[str, Any]]]): The workflow input.
|
||||
|
||||
input (Optional[Union[str, Dict[str, Any]]]): The workflow input payload.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The serialized output of the workflow.
|
||||
"""
|
||||
instance_id = None
|
||||
try:
|
||||
# Schedule the workflow
|
||||
instance_id = self.run_workflow(workflow, input=input)
|
||||
|
||||
# Ensure we run within a new asyncio event loop
|
||||
state = asyncio.run(self.monitor_workflow_state(instance_id))
|
||||
|
||||
# Off-load the potentially blocking run_workflow call to a thread.
|
||||
instance_id = await asyncio.to_thread(self.run_workflow, workflow, input)
|
||||
|
||||
# Await the asynchronous monitoring of the workflow state.
|
||||
state = await self.monitor_workflow_state(instance_id)
|
||||
|
||||
if not state:
|
||||
raise RuntimeError(f"Workflow '{instance_id}' not found.")
|
||||
|
||||
workflow_status = DaprWorkflowStatus[state.runtime_status.name] if state.runtime_status.name in DaprWorkflowStatus.__members__ else DaprWorkflowStatus.UNKNOWN
|
||||
|
||||
workflow_status = (
|
||||
DaprWorkflowStatus[state.runtime_status.name]
|
||||
if state.runtime_status.name in DaprWorkflowStatus.__members__
|
||||
else DaprWorkflowStatus.UNKNOWN
|
||||
)
|
||||
|
||||
if workflow_status == DaprWorkflowStatus.COMPLETED:
|
||||
logger.info(f"Workflow '{instance_id}' completed successfully!")
|
||||
|
|
@ -475,13 +486,32 @@ class WorkflowApp(BaseModel):
|
|||
|
||||
# Return the final state output
|
||||
return state.serialized_output
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during workflow '{instance_id}': {e}")
|
||||
raise
|
||||
finally:
|
||||
logger.info(f"Finished workflow with Instance ID: {instance_id}.")
|
||||
self.stop_runtime()
|
||||
# Off-load the stop_runtime call as it may block.
|
||||
await asyncio.to_thread(self.stop_runtime)
|
||||
|
||||
def run_and_monitor_workflow_sync(
|
||||
self,
|
||||
workflow: Union[str, Callable],
|
||||
input: Optional[Union[str, Dict[str, Any]]] = None
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Synchronous wrapper for running and monitoring a workflow.
|
||||
This allows calling code that is not async to still run the workflow.
|
||||
|
||||
Args:
|
||||
workflow (Union[str, Callable]): The workflow name or callable.
|
||||
input (Optional[Union[str, Dict[str, Any]]]): The workflow input payload.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The serialized output of the workflow.
|
||||
"""
|
||||
return asyncio.run(self.run_and_monitor_workflow_async(workflow, input))
|
||||
|
||||
def terminate_workflow(self, instance_id: str, *, output: Optional[Any] = None) -> None:
|
||||
"""
|
||||
|
|
@ -539,7 +569,9 @@ class WorkflowApp(BaseModel):
|
|||
"""
|
||||
try:
|
||||
state = self.wf_client.wait_for_workflow_completion(
|
||||
instance_id, fetch_payloads=fetch_payloads, timeout_in_seconds=timeout_in_seconds
|
||||
instance_id,
|
||||
fetch_payloads=fetch_payloads,
|
||||
timeout_in_seconds=timeout_in_seconds
|
||||
)
|
||||
if state:
|
||||
logger.info(f"Workflow {instance_id} completed with status: {state.runtime_status}.")
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ def write_blog(outline: str) -> str:
|
|||
if __name__ == '__main__':
|
||||
wfapp = WorkflowApp()
|
||||
|
||||
results = wfapp.run_and_monitor_workflow(
|
||||
results = wfapp.run_and_monitor_workflow_sync(
|
||||
analyze_topic,
|
||||
input="AI Agents"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -234,7 +234,7 @@ def write_blog(outline: str) -> str:
|
|||
if __name__ == '__main__':
|
||||
wfapp = WorkflowApp()
|
||||
|
||||
results = wfapp.run_and_monitor_workflow(
|
||||
results = wfapp.run_and_monitor_workflow_sync(
|
||||
analyze_topic,
|
||||
input="AI Agents"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -103,7 +103,7 @@ def get_line(character: str) -> str:
|
|||
if __name__ == '__main__':
|
||||
wfapp = WorkflowApp()
|
||||
|
||||
results = wfapp.run_and_monitor_workflow(task_chain_workflow)
|
||||
results = wfapp.run_and_monitor_workflow_sync(task_chain_workflow)
|
||||
print(f"Famous Line: {results}")
|
||||
```
|
||||
|
||||
|
|
@ -209,7 +209,7 @@ if __name__ == "__main__":
|
|||
research_topic = "The environmental impact of quantum computing"
|
||||
|
||||
logging.info(f"Starting research workflow on: {research_topic}")
|
||||
results = wfapp.run_and_monitor_workflow(research_workflow, input=research_topic)
|
||||
results = wfapp.run_and_monitor_workflow_sync(research_workflow, input=research_topic)
|
||||
logging.info(f"\nResearch Report:\n{results}")
|
||||
```
|
||||
|
||||
|
|
|
|||
|
|
@ -73,6 +73,6 @@ if __name__ == "__main__":
|
|||
research_topic = "The environmental impact of quantum computing"
|
||||
|
||||
logging.info(f"Starting research workflow on: {research_topic}")
|
||||
results = wfapp.run_and_monitor_workflow(research_workflow, input=research_topic)
|
||||
results = wfapp.run_and_monitor_workflow_sync(research_workflow, input=research_topic)
|
||||
if len(results) > 0:
|
||||
logging.info(f"\nResearch Report:\n{results}")
|
||||
|
|
@ -26,5 +26,5 @@ def get_line(character: str) -> str:
|
|||
if __name__ == '__main__':
|
||||
wfapp = WorkflowApp()
|
||||
|
||||
results = wfapp.run_and_monitor_workflow(task_chain_workflow)
|
||||
results = wfapp.run_and_monitor_workflow_sync(task_chain_workflow)
|
||||
print(f"Famous Line: {results}")
|
||||
|
|
@ -5,8 +5,6 @@ from dotenv import load_dotenv
|
|||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Initialize the WorkflowApp
|
||||
|
||||
# Define Workflow logic
|
||||
@workflow(name='task_chain_workflow')
|
||||
def task_chain_workflow(ctx: DaprWorkflowContext):
|
||||
|
|
@ -31,5 +29,5 @@ def get_line(character: str) -> str:
|
|||
if __name__ == '__main__':
|
||||
wfapp = WorkflowApp()
|
||||
|
||||
results = wfapp.run_and_monitor_workflow(task_chain_workflow)
|
||||
results = wfapp.run_and_monitor_workflow_sync(task_chain_workflow)
|
||||
print(f"Results: {results}")
|
||||
Loading…
Reference in New Issue