# mirror of https://github.com/dapr/dapr-agents.git
import asyncio
import functools
import inspect
import json
import logging
import sys
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, ClassVar, Dict, List, Optional, TypeVar, Union

from dapr.ext.workflow import (
    DaprWorkflowClient,
    WorkflowActivityContext,
    WorkflowRuntime,
)
from dapr.ext.workflow.workflow_state import WorkflowState
from durabletask import task as dtask
from pydantic import BaseModel, ConfigDict, Field

from dapr_agents.agents.base import ChatClientBase
from dapr_agents.types.workflow import DaprWorkflowStatus
from dapr_agents.workflow.task import WorkflowTask
from dapr_agents.workflow.utils.core import get_decorated_methods, is_pydantic_model

logger = logging.getLogger(__name__)

T = TypeVar("T")


class WorkflowApp(BaseModel):
    """
    A Pydantic-based class to encapsulate a Dapr Workflow runtime and manage workflows and tasks.
    """

    # Class-level attribute name for the workflow instrumentation callback.
    # The callback is registered on the class and looked up lazily because the
    # instrumentor is not yet available when the WorkflowApp is initialized.
    INSTRUMENTATION_CALLBACK_ATTR: ClassVar[str] = "_workflow_instrumentation_callback"

    llm: Optional[ChatClientBase] = Field(
        default=None,
        description="The default LLM client for tasks that explicitly require an LLM but don't specify one (optional).",
    )
    # TODO: this should arguably live on the workflow client or runtime.
    timeout: int = Field(
        default=300,
        description="Default timeout duration in seconds for workflow tasks.",
    )

    # Initialized in model_post_init
    wf_runtime: Optional[WorkflowRuntime] = Field(
        default=None, init=False, description="Workflow runtime instance."
    )
    wf_runtime_is_running: Optional[bool] = Field(
        default=None, init=False, description="Is the workflow runtime running?"
    )
    wf_client: Optional[DaprWorkflowClient] = Field(
        default=None, init=False, description="Workflow client instance."
    )
    tasks: Dict[str, Callable] = Field(
        default_factory=dict, init=False, description="Dictionary of registered tasks."
    )
    workflows: Dict[str, Callable] = Field(
        default_factory=dict,
        init=False,
        description="Dictionary of registered workflows.",
    )

    model_config = ConfigDict(arbitrary_types_allowed=True)
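
    # Usage sketch (hypothetical, not part of this module): construct the app,
    # which discovers @task/@workflow-decorated callables in __main__, then
    # drive it with the runtime helpers below. Assumes a running Dapr sidecar.
    #
    #     wfapp = WorkflowApp(timeout=600)
    #     wfapp.start_runtime()
    #     instance_id = wfapp.run_workflow("my_workflow", input={"query": "hi"})
    #     state = wfapp.wait_for_workflow_completion(instance_id)
    #     wfapp.stop_runtime()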

    def model_post_init(self, __context: Any) -> None:
        """
        Initialize the Dapr workflow runtime and register tasks & workflows.
        """
        # Initialize clients and runtime
        self.wf_runtime = WorkflowRuntime()
        self.wf_runtime_is_running = False
        self.wf_client = DaprWorkflowClient()
        logger.info("WorkflowApp initialized; discovering tasks and workflows.")

        # Discover and register tasks and workflows
        discovered_tasks = self._discover_tasks()
        self._register_tasks(discovered_tasks)
        discovered_wfs = self._discover_workflows()
        self._register_workflows(discovered_wfs)

        super().model_post_init(__context)

    def _choose_llm_for(self, method: Callable) -> Optional[ChatClientBase]:
        """
        Encapsulate LLM selection logic:

        1. Use the per-task override if one was provided on the decorator.
        2. Else, if the task is marked as explicitly requiring an LLM, fall
           back to the app's default LLM (if available).
        3. Otherwise, return None.
        """
        per_task = getattr(method, "_task_llm", None)
        if per_task:
            return per_task
        if getattr(method, "_explicit_llm", False):
            if self.llm is None:
                logger.warning(
                    f"Task '{getattr(method, '_task_name', getattr(method, '__name__', str(method)))}' requires an LLM "
                    "but no default LLM is configured in WorkflowApp and no explicit LLM was provided."
                )
            return self.llm
        return None
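
    # Selection sketch (hypothetical decorator usage; the attribute names match
    # those read above, which the @task decorator is assumed to set):
    #
    #     @task(description="Summarize", llm=my_llm)  # sets _task_llm -> used as-is
    #     def summarize(text: str) -> str: ...
    #
    #     @task(description="Plan a trip")            # sets _explicit_llm -> falls
    #     def plan(goal: str) -> str: ...             # back to WorkflowApp.llm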

    def _discover_tasks(self) -> Dict[str, Callable]:
        """Gather all @task-decorated functions and methods."""
        module = sys.modules["__main__"]
        tasks: Dict[str, Callable] = {}
        # Free functions in __main__
        for name, fn in inspect.getmembers(module, inspect.isfunction):
            if getattr(fn, "_is_task", False) and fn.__module__ == module.__name__:
                tasks[getattr(fn, "_task_name", name)] = fn
        # Bound methods (if any) discovered via helper
        for name, method in get_decorated_methods(self, "_is_task").items():
            tasks[getattr(method, "_task_name", name)] = method
        logger.debug(f"Discovered tasks: {list(tasks)}")
        return tasks

    def register_task(self, task_func: Callable, name: Optional[str] = None) -> None:
        """
        Manually register a @task-decorated function, mirroring the native Dapr pattern.

        This allows explicit registration of tasks from other modules or packages.
        The task is registered immediately with the Dapr runtime.

        Args:
            task_func: The @task-decorated function to register.
            name: Optional custom name (defaults to the function's _task_name or __name__).

        Example:
            from tasks.my_tasks import generate_queries
            wfapp.register_task(generate_queries)
            wfapp.register_task(generate_queries, name="custom_name")
        """
        # Validate input
        if not callable(task_func):
            raise ValueError(
                f"task_func must be callable, got {type(task_func)}: {task_func}"
            )

        if not getattr(task_func, "_is_task", False):
            raise ValueError(
                f"Function {getattr(task_func, '__name__', str(task_func))} is not decorated with @task"
            )

        task_name = name or getattr(
            task_func, "_task_name", getattr(task_func, "__name__", "unknown_task")
        )

        # Skip if already registered
        if task_name in self.tasks:
            logger.warning(f"Task '{task_name}' is already registered; skipping.")
            return

        # Register immediately with the Dapr runtime using the shared registration logic
        llm = self._choose_llm_for(task_func)
        logger.debug(
            f"Manually registering task '{task_name}' with llm={getattr(llm, '__class__', None)}"
        )

        kwargs = getattr(task_func, "_task_kwargs", {})
        task_instance = WorkflowTask(
            func=task_func,
            description=getattr(task_func, "_task_description", None),
            agent=getattr(task_func, "_task_agent", None),
            llm=llm,
            include_chat_history=getattr(
                task_func, "_task_include_chat_history", False
            ),
            workflow_app=self,
            **kwargs,
        )

        # Wrap for Dapr invocation
        wrapped = self._make_task_wrapper(task_name, task_func, task_instance)
        self.wf_runtime.register_activity(wrapped)
        self.tasks[task_name] = wrapped

    def register_tasks_from_module(
        self, module_name_or_object: Union[str, Any]
    ) -> None:
        """
        Register all @task-decorated functions from a specific module.

        Args:
            module_name_or_object: Module name string (e.g., "tasks.queries") or an imported module object.

        Example:
            # Using a string name
            wfapp.register_tasks_from_module("tasks.queries")

            # Using an imported module
            import tasks.queries
            wfapp.register_tasks_from_module(tasks.queries)

            # Using a from-import
            from tasks import queries
            wfapp.register_tasks_from_module(queries)
        """
        try:
            # Handle both string names and module objects
            if isinstance(module_name_or_object, str):
                import importlib

                module = importlib.import_module(module_name_or_object)
                module_name = module_name_or_object
            else:
                # Assume it's a module object
                module = module_name_or_object
                module_name = getattr(module, "__name__", str(module))

            registered_count = 0
            for name, fn in inspect.getmembers(module, inspect.isfunction):
                if getattr(fn, "_is_task", False):
                    task_name = getattr(fn, "_task_name", name)
                    if task_name not in self.tasks:  # Skip if already registered
                        self.register_task(fn)
                        registered_count += 1

            logger.info(
                f"Registered {registered_count} tasks from module '{module_name}'"
            )

        except ImportError as e:
            raise ImportError(
                f"Could not import module '{module_name_or_object}': {e}"
            ) from e
        except Exception as e:
            raise RuntimeError(
                f"Error registering tasks from module '{module_name_or_object}': {e}"
            ) from e

    def register_tasks_from_package(self, package_name: str) -> None:
        """
        Register all @task-decorated functions from all modules in a package.

        Args:
            package_name: Name of the package to scan (e.g., "tasks").

        Example:
            wfapp.register_tasks_from_package("tasks")  # Scans tasks/*.py
        """
        try:
            import importlib
            import pkgutil

            package = importlib.import_module(package_name)

            # Collect all tasks first, then register them via the shared
            # _register_tasks method.
            discovered_tasks: Dict[str, Callable] = {}

            total_tasks = 0
            for _importer, modname, ispkg in pkgutil.iter_modules(package.__path__):
                if not ispkg:  # Only scan modules, not sub-packages
                    full_module_name = f"{package_name}.{modname}"
                    try:
                        module = importlib.import_module(full_module_name)

                        for name, fn in inspect.getmembers(module, inspect.isfunction):
                            if getattr(fn, "_is_task", False):
                                task_name = getattr(fn, "_task_name", name)
                                # Skip if already registered
                                if task_name not in self.tasks:
                                    discovered_tasks[task_name] = fn
                                    total_tasks += 1

                    except Exception as e:
                        logger.warning(f"Failed to scan module {full_module_name}: {e}")

            # Now register all discovered tasks
            if discovered_tasks:
                self._register_tasks(discovered_tasks)

            logger.info(f"Registered {total_tasks} tasks from package '{package_name}'")
        except ImportError as e:
            raise ImportError(f"Could not import package '{package_name}': {e}") from e
        except Exception as e:
            raise RuntimeError(
                f"Error registering tasks from package '{package_name}': {e}"
            ) from e
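
    # Package layout sketch (hypothetical): register_tasks_from_package("tasks")
    # imports each top-level module and collects its @task functions.
    #
    #     tasks/
    #       __init__.py
    #       queries.py     # @task def generate_queries(...): ...
    #       summaries.py   # @task def summarize(...): ...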

    def _register_tasks(self, tasks: Dict[str, Callable]) -> None:
        """Register each discovered task with the Dapr runtime using direct registration."""
        for task_name, method in tasks.items():
            llm = self._choose_llm_for(method)
            logger.debug(
                f"Registering task '{task_name}' with llm={getattr(llm, '__class__', None)}"
            )
            kwargs = getattr(method, "_task_kwargs", {})
            task_instance = WorkflowTask(
                func=method,
                description=getattr(method, "_task_description", None),
                agent=getattr(method, "_task_agent", None),
                llm=llm,
                include_chat_history=getattr(
                    method, "_task_include_chat_history", False
                ),
                workflow_app=self,
                **kwargs,
            )
            # Wrap for Dapr invocation
            wrapped = self._make_task_wrapper(task_name, method, task_instance)

            # Use direct registration like the official Dapr examples
            self.wf_runtime.register_activity(wrapped)
            self.tasks[task_name] = wrapped

    def _make_task_wrapper(
        self, task_name: str, method: Callable, task_instance: WorkflowTask
    ) -> Callable:
        """Produce the function that Dapr will invoke for each activity."""

        def run_sync(coro):
            # Activities usually run on worker threads with no event loop, so
            # asyncio.run() can create and tear one down for us.
            try:
                asyncio.get_running_loop()
            except RuntimeError:
                return asyncio.run(coro)
            # A loop is already running on this thread; run_until_complete()
            # would raise "this event loop is already running", so drive the
            # coroutine to completion on a dedicated thread instead.
            with ThreadPoolExecutor(max_workers=1) as pool:
                return pool.submit(asyncio.run, coro).result()

        @functools.wraps(method)
        def wrapper(ctx: WorkflowActivityContext, *args, **kwargs):
            wf_ctx = WorkflowActivityContext(ctx)
            try:
                call = task_instance(wf_ctx, *args, **kwargs)
                if asyncio.iscoroutine(call):
                    return run_sync(call)
                return call
            except Exception:
                logger.exception(f"Task '{task_name}' failed")
                raise

        return wrapper

    # TODO: workflow discovery can also come from the Dapr runtime.
    # Python workflows can be registered in a variety of ways, and we need to
    # support all of them. This supports decorator-based registration; however,
    # there is also a manual registration approach. For example:
    #
    #     def setup_workflow_runtime():
    #         wf_runtime = WorkflowRuntime()
    #         wf_runtime.register_workflow(order_processing_workflow)
    #         wf_runtime.register_workflow(fulfillment_workflow)
    #         wf_runtime.register_activity(process_payment)
    #         wf_runtime.register_activity(send_notification)
    #         return wf_runtime
    #
    #     runtime = setup_workflow_runtime()
    #     runtime.start()
    def _discover_workflows(self) -> Dict[str, Callable]:
        """Gather all @workflow-decorated functions and methods."""
        module = sys.modules["__main__"]
        wfs: Dict[str, Callable] = {}
        for name, fn in inspect.getmembers(module, inspect.isfunction):
            if getattr(fn, "_is_workflow", False) and fn.__module__ == module.__name__:
                wfs[getattr(fn, "_workflow_name", name)] = fn
        for name, method in get_decorated_methods(self, "_is_workflow").items():
            wfs[getattr(method, "_workflow_name", name)] = method
        logger.info(f"Discovered workflows: {list(wfs)}")
        return wfs

    def _register_workflows(self, wfs: Dict[str, Callable]) -> None:
        """Register each discovered workflow with the Dapr runtime."""
        for wf_name, method in wfs.items():
            # Use a closure helper to avoid late-binding capture issues.
            def make_wrapped(meth: Callable) -> Callable:
                @functools.wraps(meth)
                def wrapped(*args, **kwargs):
                    return meth(*args, **kwargs)

                return wrapped

            decorator = self.wf_runtime.workflow(name=wf_name)
            self.workflows[wf_name] = decorator(make_wrapped(method))
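
    # Workflow sketch (hypothetical; @workflow and generate_queries are
    # illustrative names on the decorator-based registration path):
    #
    #     @workflow(name="doc_pipeline")
    #     def doc_pipeline(ctx: DaprWorkflowContext, topic: str):
    #         queries = yield ctx.call_activity(generate_queries, input=topic)
    #         return queries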

    def resolve_task(self, task: Union[str, Callable]) -> Callable:
        """
        Resolves a registered task function by its name or decorated function.

        Args:
            task (Union[str, Callable]): The task name or callable function.

        Returns:
            Callable: The resolved task function.

        Raises:
            AttributeError: If the task is not found.
        """
        if isinstance(task, str):
            task_name = task
        elif callable(task):
            task_name = getattr(
                task, "_task_name", getattr(task, "__name__", "unknown_task")
            )
        else:
            raise ValueError(f"Invalid task reference: {task}")

        task_func = self.tasks.get(task_name)
        if not task_func:
            raise AttributeError(f"Task '{task_name}' not found.")

        return task_func
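
    # Lookup sketch (hypothetical task name): both strings and decorated
    # callables resolve to the same wrapped activity.
    #
    #     activity = wfapp.resolve_task("generate_queries")
    #     activity = wfapp.resolve_task(generate_queries)  # same result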

    def resolve_workflow(self, workflow: Union[str, Callable]) -> Callable:
        """
        Resolves a registered workflow function by its name or decorated function.

        Args:
            workflow (Union[str, Callable]): The workflow name or callable function.

        Returns:
            Callable: The resolved workflow function.

        Raises:
            AttributeError: If the workflow is not found.
        """
        if isinstance(workflow, str):
            workflow_name = workflow  # Direct lookup by string name
        elif callable(workflow):
            workflow_name = getattr(
                workflow,
                "_workflow_name",
                getattr(workflow, "__name__", "unknown_workflow"),
            )
        else:
            raise ValueError(f"Invalid workflow reference: {workflow}")

        workflow_func = self.workflows.get(workflow_name)
        if not workflow_func:
            raise AttributeError(f"Workflow '{workflow_name}' not found.")

        return workflow_func

    def start_runtime(self):
        """Idempotently start the Dapr workflow runtime."""
        if not self.wf_runtime_is_running:
            logger.info("Starting workflow runtime.")
            self.wf_runtime.start()
            self.wf_runtime_is_running = True

            # Apply workflow instrumentation if a callback was registered
            if hasattr(self.__class__, self.INSTRUMENTATION_CALLBACK_ATTR):
                getattr(self.__class__, self.INSTRUMENTATION_CALLBACK_ATTR)()
        else:
            logger.debug("Workflow runtime already running; skipping.")

    def stop_runtime(self):
        """Idempotently stop the Dapr workflow runtime."""
        if self.wf_runtime_is_running:
            logger.info("Stopping workflow runtime.")
            self.wf_runtime.shutdown()
            self.wf_runtime_is_running = False
        else:
            logger.debug("Workflow runtime already stopped; skipping.")
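
    # Lifecycle sketch (hypothetical workflow name): both helpers are
    # idempotent, so they pair naturally with try/finally.
    #
    #     wfapp.start_runtime()
    #     try:
    #         output = wfapp.run_and_monitor_workflow_sync("my_workflow")
    #     finally:
    #         wfapp.stop_runtime()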

    def run_workflow(
        self,
        workflow: Union[str, Callable],
        input: Optional[Union[str, Dict[str, Any]]] = None,
    ) -> str:
        """
        Starts a workflow execution.

        Args:
            workflow (Union[str, Callable]): The workflow name or callable.
            input (Optional[Union[str, Dict[str, Any]]]): Input data for the workflow.

        Returns:
            str: The instance ID of the started workflow.

        Raises:
            Exception: If workflow execution fails.
        """
        try:
            # Generate a unique instance ID
            instance_id = uuid.uuid4().hex

            # Resolve the workflow function
            workflow_func = self.resolve_workflow(workflow)

            # Schedule workflow execution
            instance_id = self.wf_client.schedule_new_workflow(
                workflow=workflow_func, input=input, instance_id=instance_id
            )

            logger.info(f"Started workflow with instance ID {instance_id}.")
            return instance_id
        except Exception as e:
            logger.error(f"Failed to start workflow {workflow}: {e}")
            raise

    async def monitor_workflow_state(self, instance_id: str) -> Optional[WorkflowState]:
        """
        Monitors and retrieves the final state of a workflow instance.

        Args:
            instance_id (str): The workflow instance ID.

        Returns:
            Optional[WorkflowState]: The final state of the workflow, or None if not found.
        """
        try:
            state = await asyncio.to_thread(
                self.wait_for_workflow_completion,
                instance_id,
                fetch_payloads=True,
                timeout_in_seconds=self.timeout,
            )

            if not state:
                logger.error(f"Workflow '{instance_id}' not found.")
                return None

            return state
        except TimeoutError:
            logger.error(f"Workflow '{instance_id}' monitoring timed out.")
            return None
        except Exception as e:
            logger.error(f"Error retrieving workflow state for '{instance_id}': {e}")
            return None

    async def monitor_workflow_completion(self, instance_id: str) -> None:
        """
        Monitors the execution of a workflow and logs its final state.

        Args:
            instance_id (str): The workflow instance ID.
        """
        try:
            logger.info(f"Monitoring workflow '{instance_id}'...")

            # Retrieve the workflow state
            state = await self.monitor_workflow_state(instance_id)
            if not state:
                return  # Error already logged in monitor_workflow_state

            # Extract relevant details
            workflow_status = state.runtime_status.name
            failure_details = state.failure_details  # an object, not a dictionary

            if workflow_status.upper() == DaprWorkflowStatus.COMPLETED.value.upper():
                logger.info(
                    f"Workflow '{instance_id}' completed successfully. Status: {workflow_status}."
                )

                if state.serialized_output:
                    logger.debug(
                        f"Output: {json.dumps(state.serialized_output, indent=2)}"
                    )

            elif workflow_status.upper() in (
                DaprWorkflowStatus.FAILED.value.upper(),
                "ABORTED",
            ):
                # Guard attribute access in case failure_details is missing
                error_type = getattr(failure_details, "error_type", "Unknown")
                message = getattr(failure_details, "message", "No message provided")
                stack_trace = getattr(
                    failure_details, "stack_trace", "No stack trace available"
                )

                logger.error(
                    f"Workflow '{instance_id}' failed.\n"
                    f"Error Type: {error_type}\n"
                    f"Message: {message}\n"
                    f"Stack Trace:\n{stack_trace}\n"
                    f"Input: {json.dumps(state.serialized_input, indent=2)}"
                )

                self.terminate_workflow(instance_id)

            else:
                logger.warning(
                    f"Workflow '{instance_id}' ended with status '{workflow_status}'.\n"
                    f"Input: {json.dumps(state.serialized_input, indent=2)}"
                )

            logger.debug(
                f"Workflow Details: Instance ID={state.instance_id}, Name={state.name}, "
                f"Created At={state.created_at}, Last Updated At={state.last_updated_at}"
            )

        except Exception as e:
            logger.error(
                f"Error monitoring workflow '{instance_id}': {e}", exc_info=True
            )
        finally:
            logger.info(f"Finished monitoring workflow '{instance_id}'.")

    async def run_and_monitor_workflow_async(
        self,
        workflow: Union[str, Callable],
        input: Optional[Union[str, Dict[str, Any]]] = None,
    ) -> Optional[str]:
        """
        Runs a workflow asynchronously and monitors its completion.

        Args:
            workflow (Union[str, Callable]): The workflow name or callable.
            input (Optional[Union[str, Dict[str, Any]]]): The workflow input payload.

        Returns:
            Optional[str]: The serialized output of the workflow.
        """
        # Defined before the try block so except/finally can log it even if
        # run_workflow itself raises.
        instance_id: Optional[str] = None
        try:
            # Off-load the potentially blocking run_workflow call to a thread.
            instance_id = await asyncio.to_thread(self.run_workflow, workflow, input)

            logger.debug(
                f"Workflow '{workflow}' started with instance ID: {instance_id}"
            )

            # Await the asynchronous monitoring of the workflow state.
            state = await self.monitor_workflow_state(instance_id)

            if not state:
                raise RuntimeError(f"Workflow '{instance_id}' not found.")

            workflow_status = (
                DaprWorkflowStatus[state.runtime_status.name]
                if state.runtime_status.name in DaprWorkflowStatus.__members__
                else DaprWorkflowStatus.UNKNOWN
            )

            if workflow_status == DaprWorkflowStatus.COMPLETED:
                logger.info(f"Workflow '{instance_id}' completed successfully!")
                logger.debug(f"Output: {state.serialized_output}")
            else:
                logger.error(
                    f"Workflow '{instance_id}' ended with status '{workflow_status.value}'."
                )

            # Return the final state output
            return state.serialized_output

        except Exception as e:
            logger.error(f"Error during workflow '{instance_id}': {e}")
            raise
        finally:
            logger.info(f"Finished workflow with Instance ID: {instance_id}.")

    def run_and_monitor_workflow_sync(
        self,
        workflow: Union[str, Callable],
        input: Optional[Union[str, Dict[str, Any]]] = None,
    ) -> Optional[str]:
        """
        Synchronous wrapper for running and monitoring a workflow.
        This allows calling code that is not async to still run the workflow.

        Args:
            workflow (Union[str, Callable]): The workflow name or callable.
            input (Optional[Union[str, Dict[str, Any]]]): The workflow input payload.

        Returns:
            Optional[str]: The serialized output of the workflow.
        """
        return asyncio.run(self.run_and_monitor_workflow_async(workflow, input))

    def terminate_workflow(
        self, instance_id: str, *, output: Optional[Any] = None
    ) -> None:
        """
        Terminates a running workflow.

        Args:
            instance_id (str): The workflow instance ID.
            output (Optional[Any]): Optional output to set for the terminated workflow.

        Raises:
            Exception: If the termination fails.
        """
        try:
            self.wf_client.terminate_workflow(instance_id=instance_id, output=output)
            logger.info(
                f"Successfully terminated workflow '{instance_id}' with output: {output}"
            )
        except Exception as e:
            logger.error(f"Failed to terminate workflow '{instance_id}'. Error: {e}")
            raise Exception(f"Error terminating workflow '{instance_id}': {e}") from e

    def get_workflow_state(self, instance_id: str) -> Optional[Any]:
        """
        Retrieves the state of a workflow instance.

        Args:
            instance_id (str): The workflow instance ID.

        Returns:
            Optional[Any]: The workflow state if found, otherwise None.
        """
        try:
            state = self.wf_client.get_workflow_state(instance_id)
            logger.info(
                f"Retrieved state for workflow {instance_id}: {state.runtime_status}."
            )
            return state
        except Exception as e:
            logger.error(f"Failed to retrieve workflow state for {instance_id}: {e}")
            return None

    def wait_for_workflow_completion(
        self,
        instance_id: str,
        fetch_payloads: bool = True,
        timeout_in_seconds: int = 120,
    ) -> Optional[WorkflowState]:
        """
        Waits for a workflow to complete and retrieves its state.

        Args:
            instance_id (str): The workflow instance ID.
            fetch_payloads (bool): Whether to fetch input/output payloads.
            timeout_in_seconds (int): Maximum wait time in seconds.

        Returns:
            Optional[WorkflowState]: The final state, or None on timeout or error.
        """
        try:
            state = self.wf_client.wait_for_workflow_completion(
                instance_id,
                fetch_payloads=fetch_payloads,
                timeout_in_seconds=timeout_in_seconds,
            )
            if state:
                logger.info(
                    f"Workflow {instance_id} completed with status: {state.runtime_status}."
                )
            else:
                logger.warning(
                    f"Workflow {instance_id} did not complete within the timeout period."
                )
            return state
        except Exception as e:
            logger.error(
                f"Error while waiting for workflow {instance_id} completion: {e}"
            )
            return None

    def raise_workflow_event(
        self, instance_id: str, event_name: str, *, data: Optional[Any] = None
    ) -> None:
        """
        Raises an event for a running workflow instance.

        Args:
            instance_id (str): The workflow instance ID.
            event_name (str): The name of the event to raise.
            data (Optional[Any]): Optional event data.

        Raises:
            Exception: If raising the event fails.
        """
        try:
            logger.info(
                f"Raising workflow event '{event_name}' for instance '{instance_id}'"
            )
            # Ensure the data is in a serializable format
            if is_pydantic_model(type(data)):
                # Convert the Pydantic model to a dict
                data = data.model_dump()
            # Raise the event via the Dapr workflow client with serialized data
            self.wf_client.raise_workflow_event(
                instance_id=instance_id, event_name=event_name, data=data
            )
            logger.info(
                f"Successfully raised workflow event '{event_name}' for instance '{instance_id}'!"
            )
        except Exception as e:
            logger.error(
                f"Error raising workflow event '{event_name}' for instance '{instance_id}'. "
                f"Data: {data}, Error: {e}"
            )
            raise Exception(
                f"Failed to raise workflow event '{event_name}' for instance '{instance_id}': {str(e)}"
            ) from e
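
    # Event sketch (hypothetical event name): a workflow waits for the event
    # by name; an external caller delivers it through this helper.
    #
    #     # workflow side:
    #     approval = yield ctx.wait_for_external_event("approval")
    #     # client side:
    #     wfapp.raise_workflow_event(instance_id, "approval", data={"ok": True})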

    def when_all(self, tasks: List[dtask.Task[T]]) -> dtask.WhenAllTask[T]:
        """
        Waits for all given tasks to complete.

        Args:
            tasks (List[dtask.Task[T]]): The tasks to wait for.

        Returns:
            dtask.WhenAllTask[T]: A task that completes when all tasks finish.
        """
        return dtask.when_all(tasks)

    def when_any(self, tasks: List[dtask.Task[T]]) -> dtask.WhenAnyTask:
        """
        Waits for any one of the given tasks to complete.

        Args:
            tasks (List[dtask.Task[T]]): The tasks to monitor.

        Returns:
            dtask.WhenAnyTask: A task that completes when the first task finishes.
        """
        return dtask.when_any(tasks)
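
    # Fan-out/fan-in sketch (hypothetical task and workflow names) inside a
    # workflow registered on this app:
    #
    #     @workflow
    #     def fan_out(ctx: DaprWorkflowContext, items: list):
    #         branches = [ctx.call_activity(process_item, input=i) for i in items]
    #         results = yield wfapp.when_all(branches)  # wait for every branch
    #         return results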