from dapr.ext.workflow import WorkflowRuntime, WorkflowActivityContext, DaprWorkflowContext, DaprWorkflowClient
from dapr.ext.workflow.workflow_state import WorkflowState
from dapr_agents.types.workflow import WorkflowStatus, WorkflowStateMap, WorkflowMessage, WorkflowEntry
from typing import Any, Callable, Generator, Optional, Dict, TypeVar, Union, List
from dapr_agents.storage.daprstores.statestore import DaprStateStore
from dapr_agents.workflow.task import Task, TaskWrapper
from pydantic import BaseModel, Field, ValidationError, ConfigDict
from dapr.conf import settings as dapr_settings
from dapr.clients import DaprClient
from durabletask import task as dtask
from datetime import datetime
import asyncio
import functools
import logging
import uuid
import json
import os

logger = logging.getLogger(__name__)

T = TypeVar('T')
TInput = TypeVar('TInput')
TOutput = TypeVar('TOutput')
Workflow = Callable[..., Union[Generator[dtask.Task, Any, Any], TOutput]]


class WorkflowApp(BaseModel):
    """
    A Pydantic-based class to encapsulate a Dapr Workflow runtime and manage workflows and tasks.

    Attributes:
        daprGrpcHost (Optional[str]): Host address for the Dapr gRPC endpoint.
        daprGrpcPort (Optional[int]): Port number for the Dapr gRPC endpoint.
        workflow_state_store_name (str): Name of the Dapr state store component used to store workflow metadata.
        workflow_timeout (int): Default timeout for workflow completion in seconds. Defaults to 300.
        wf_runtime (WorkflowRuntime): The Dapr Workflow runtime instance.
        wf_client (DaprWorkflowClient): The Dapr Workflow client instance for invoking and interacting with workflows.
        wf_state_store (DaprStateStore): State store instance for accessing and managing workflow state.
        wf_state_key (str): State store key under which the workflow state is persisted.
        state (Optional[WorkflowStateMap]): In-memory copy of the persisted workflow state.
        tasks (Dict[str, Callable]): A dictionary storing registered task functions by name.
        workflows (Dict[str, Callable]): A dictionary storing registered workflows by name.
    """

    daprGrpcHost: Optional[str] = Field(None, description="Host address for the Dapr gRPC endpoint.")
    daprGrpcPort: Optional[int] = Field(None, description="Port number for the Dapr gRPC endpoint.")
    workflow_state_store_name: str = Field(default="workflowstatestore", description="The name of the Dapr state store component used to store workflow metadata.")
    workflow_timeout: int = Field(default=300, description="Default timeout duration in seconds for workflow tasks.")

    # Initialized in model_post_init
    wf_runtime: Optional[WorkflowRuntime] = Field(default=None, init=False, description="Workflow runtime instance.")
    wf_client: Optional[DaprWorkflowClient] = Field(default=None, init=False, description="Workflow client instance.")
    wf_state_store: Optional[DaprStateStore] = Field(default=None, init=False, description="Dapr state store instance for accessing and managing workflow state.")
    wf_state_key: str = Field(default="workflow_state", init=False, description="Dapr state store key for the workflow state.")
    state: Optional[WorkflowStateMap] = Field(default=None, init=False, description="Workflow Dapr state.")
    tasks: Dict[str, Callable] = Field(default_factory=dict, description="Dictionary of registered tasks.")
    workflows: Dict[str, Callable] = Field(default_factory=dict, description="Dictionary of registered workflows.")

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def model_post_init(self, __context: Any) -> None:
        """
        Post-initialization to configure the Dapr Workflow runtime and client.
        """
        # Configure Dapr gRPC settings, using environment variables if provided
        env_daprGrpcHost = os.getenv('DAPR_RUNTIME_HOST')
        env_daprGrpcPort = os.getenv('DAPR_GRPC_PORT')

        # Resolve final values for Dapr settings
        self.daprGrpcHost = self.daprGrpcHost or env_daprGrpcHost or dapr_settings.DAPR_RUNTIME_HOST
        self.daprGrpcPort = int(self.daprGrpcPort or env_daprGrpcPort or dapr_settings.DAPR_GRPC_PORT)

        # Initialize WorkflowRuntime and DaprWorkflowClient
        self.wf_runtime = WorkflowRuntime(host=self.daprGrpcHost, port=self.daprGrpcPort)
        self.wf_client = DaprWorkflowClient(host=self.daprGrpcHost, port=self.daprGrpcPort)

        # Initialize the workflow state store
        self.wf_state_store = DaprStateStore(store_name=self.workflow_state_store_name, address=f"{self.daprGrpcHost}:{self.daprGrpcPort}")

        # Initialize or load workflow metadata
        self.register_workflow_metadata()

        logger.info(f"Initialized WorkflowApp with Dapr gRPC host '{self.daprGrpcHost}' and port '{self.daprGrpcPort}'.")

        # Proceed with base model setup
        super().model_post_init(__context)

    def task(self, func: Optional[Callable] = None, *, name: Optional[str] = None, description: Optional[str] = None, agent: Optional[Any] = None, agent_method: Optional[Union[str, Callable]] = "run", llm: Optional[Any] = None, llm_method: Optional[Union[str, Callable]] = "generate") -> Callable:
        """
        Custom decorator to create and register a workflow task, supporting async and extended capabilities.

        This decorator allows for the creation and registration of tasks that can be executed
        as part of a Dapr workflow. The task can optionally integrate with an agent or LLM
        for enhanced functionality. It supports both synchronous and asynchronous functions.

        Args:
            func (Callable, optional): The function to be decorated as a workflow task. Defaults to None.
            name (Optional[str]): The name to register the task with. Defaults to the function's name.
            description (Optional[str]): A textual description of the task. Defaults to None.
            agent (Optional[Any]): The agent to use for executing the task if a description is provided. Defaults to None.
            agent_method (Optional[Union[str, Callable]]): The method or callable to invoke the agent. Defaults to "run".
            llm (Optional[Any]): The LLM client to use for executing the task if a description is provided. Defaults to None.
            llm_method (Optional[Union[str, Callable]]): The method or callable to invoke the LLM client. Defaults to "generate".

        Returns:
            Callable: The decorated function wrapped with task logic and registered as an activity.
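
        Example:
            Illustrative sketch, not from the original source; assumes an existing
            ``WorkflowApp`` instance named ``app`` and, for the LLM-backed variant, an
            ``llm`` client compatible with dapr_agents whose task descriptions accept
            ``str.format``-style placeholders bound to the function's parameters::

                @app.task
                def get_count() -> int:
                    return 3

                @app.task(description="Write a one-line greeting for {name}.", llm=llm)
                def greet(name: str) -> str:
                    ...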
        """
        # If the first positional argument is a string, treat it as the description
        if isinstance(func, str):
            description = func
            func = None

        def decorator(f: Callable):
            """
            Decorator to wrap a function as a Dapr workflow activity task.

            Args:
                f (Callable): The function to be wrapped and registered as a Dapr task.

            Returns:
                Callable: A decorated function wrapped with the task execution logic.
            """
            # Wrap the original function with Task logic
            task_instance = Task(
                func=f,
                description=description,
                agent=agent,
                agent_method=agent_method,
                llm=llm,
                llm_method=llm_method,
            )

            @functools.wraps(f)
            def task_wrapper(ctx: WorkflowActivityContext, input: Any = None):
                """
                Wrapper function for executing tasks in a Dapr workflow.
                Handles both sync and async tasks.
                """
                async def async_execution():
                    """
                    Handles the actual asynchronous execution of the task.
                    """
                    try:
                        result = await task_instance(ctx, input)
                        return result
                    except Exception as e:
                        logger.error(f"Async task execution failed: {e}")
                        raise

                def run_in_event_loop(coroutine):
                    """
                    Helper function to run a coroutine in the current or a new event loop.
                    """
                    try:
                        loop = asyncio.get_event_loop()
                    except RuntimeError:
                        loop = asyncio.new_event_loop()
                        asyncio.set_event_loop(loop)

                    return loop.run_until_complete(coroutine)

                try:
                    if asyncio.iscoroutinefunction(f) or asyncio.iscoroutinefunction(task_instance.__call__):
                        # Handle async tasks
                        return run_in_event_loop(async_execution())
                    else:
                        # Handle sync tasks
                        result = task_instance(ctx, input)
                        if asyncio.iscoroutine(result):
                            logger.warning("Sync task returned a coroutine. Running it in the event loop.")
                            return run_in_event_loop(result)

                        logger.info("Sync task completed.")
                        return result

                except Exception as e:
                    logger.error(f"Task execution failed: {e}")
                    raise

            # Register the task with the Dapr Workflow runtime
            activity_decorator = self.wf_runtime.activity(name=name or f.__name__)
            registered_activity = activity_decorator(task_wrapper)

            # Store the task in the registry for easier access
            task_name = name or f.__name__
            self.tasks[task_name] = registered_activity

            return registered_activity

        return decorator(func) if func else decorator

    def workflow(self, func: Optional[Callable] = None, *, name: Optional[str] = None) -> Callable:
        """
        Custom decorator to register a function as a workflow, building on top of Dapr's workflow decorator.

        This decorator allows you to register a function as a Dapr workflow while injecting additional
        context and custom logic. It leverages the existing Dapr `workflow` decorator to ensure compatibility
        with the Dapr workflow runtime and adds the workflow to an internal registry for easy management.

        Args:
            func (Callable, optional): The function to be decorated as a workflow. Defaults to None.
            name (Optional[str]): The name to register the workflow with. Defaults to the function's name.

        Returns:
            Callable: The decorated function with context injection and registered as a workflow.
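
        Example:
            Illustrative sketch, not from the original source; assumes a ``WorkflowApp``
            instance named ``app`` with a task ``get_count`` registered via ``@app.task``::

                @app.workflow(name="counting_workflow")
                def counting_workflow(ctx: DaprWorkflowContext, payload: dict):
                    count = yield ctx.call_activity(get_count)
                    return {"requested_by": payload.get("user"), "count": count}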
        """
        def decorator(f: Callable):
            @functools.wraps(f)
            def workflow_wrapper(ctx: DaprWorkflowContext, *args, **kwargs):
                # Inject the context into the function's closure
                return f(ctx, *args, **kwargs)

            # Use the original Dapr workflow decorator to register the wrapper
            workflow_decorator = self.wf_runtime.workflow(name=name)
            registered_workflow = workflow_decorator(workflow_wrapper)

            # Store the workflow in the registry for easier access
            workflow_name = name or f.__name__
            self.workflows[workflow_name] = registered_workflow
            return workflow_wrapper

        return decorator(func) if func else decorator

    def create_task(
        self,
        *,
        name: Optional[str],
        description: Optional[str],
        agent: Optional[Any] = None,
        agent_method: Optional[Union[str, Callable]] = "run",
        llm: Optional[Any] = None,
        llm_method: Optional[Union[str, Callable]] = "generate"
    ) -> Callable:
        """
        Method to create and register a task directly, without using it as a decorator.

        Args:
            name (Optional[str]): The name to register the task with.
            description (Optional[str]): A textual description of the task, which can be used by an agent or LLM.
            agent (Optional[Any]): The agent to use for executing the task if a description is provided. Defaults to None.
            agent_method (Optional[Union[str, Callable]]): The method or callable to invoke the agent. Defaults to "run".
            llm (Optional[Any]): The LLM client to use for executing the task if a description is provided. Defaults to None.
            llm_method (Optional[Union[str, Callable]]): The method or callable to invoke the LLM client. Defaults to "generate".

        Returns:
            Callable: The wrapped Task object, ready to be used in a workflow.
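
        Example:
            Illustrative sketch, not from the original source; assumes ``app`` is a
            ``WorkflowApp`` and ``llm`` is a compatible LLM client::

                summarize = app.create_task(
                    name="summarize",
                    description="Summarize the given text in one sentence.",
                    llm=llm,
                )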
        """
        # Create the Task instance directly (no wrapped function; description-driven)
        task_instance = Task(
            func=None,
            description=description,
            agent=agent,
            agent_method=agent_method,
            llm=llm,
            llm_method=llm_method,
        )

        # Wrap the Task instance with a TaskWrapper that provides a __name__
        wrapped_task = TaskWrapper(task_instance, name)

        # Register the wrapped Task instance with the provided name
        self.wf_runtime.register_activity(wrapped_task, name=name)

        # Store the wrapped task in the task registry
        self.tasks[name] = wrapped_task
        return wrapped_task

    def register_workflow_metadata(self) -> None:
        """
        Initializes or loads the workflow metadata from the Dapr state store.
        """
        logger.info("Registering Workflow metadata.")

        # Attempt to retrieve existing state
        has_state, state_data = self.wf_state_store.try_get_state(self.wf_state_key)

        if not has_state:
            # No existing state: initialize with default values
            logger.info("Initializing state for workflow.")
            self.state = WorkflowStateMap()
            # Save the newly initialized state
            self.save_state(self.state)
        else:
            # Load the existing state
            logger.info("Loading existing workflow state.")
            logger.debug(f"Existing state data: {state_data}")
            try:
                self.state = WorkflowStateMap(**state_data)
            except ValidationError as e:
                # Existing state is invalid: reinitialize with default values and save
                logger.error(f"Failed to validate existing state: {e}")
                self.state = WorkflowStateMap()
                self.save_state(self.state)

    def save_state(self, value: Optional[WorkflowStateMap] = None) -> None:
        """
        Saves the workflow state to the Dapr state store using the predefined workflow state key.

        Args:
            value (Optional[WorkflowStateMap]): The state data to save. If not provided, uses `self.state`.
        """
        try:
            # Use the provided state or fall back to the local state
            state_to_save = value or self.state
            if not state_to_save:
                raise ValueError("No state to save. Both `value` and `self.state` are None.")

            self.wf_state_store.save_state(self.wf_state_key, state_to_save.model_dump_json())
            logger.info(f"Successfully saved state for key '{self.wf_state_key}'.")
        except Exception as e:
            logger.error(f"Failed to save state for key '{self.wf_state_key}': {e}")
            raise

    def add_message(self, instance_id: str, message: Dict[str, Any]) -> None:
        """
        Adds a message to the workflow entry for the given instance ID.

        Args:
            instance_id (str): The workflow instance ID.
            message (Dict[str, Any]): The message to add.
        """
        workflow_entry = self.state.instances.get(instance_id)
        if not workflow_entry:
            logger.error(f"Workflow with instance ID {instance_id} not found.")
            return

        workflow_message = WorkflowMessage(**message)
        workflow_entry.messages.append(workflow_message)
        self.save_state()

    def resolve_task(self, task_name: str) -> Callable:
        """
        Resolve the task function by its registered name.

        Args:
            task_name (str): The name of the task to resolve.

        Returns:
            Callable: The resolved task function.

        Raises:
            AttributeError: If the task function is not found.
        """
        task_func = self.tasks.get(task_name)
        if not task_func:
            raise AttributeError(f"Task '{task_name}' not found.")
        return task_func

    def resolve_workflow(self, workflow_name: str) -> Callable:
        """
        Resolve the workflow function by its registered name.

        Args:
            workflow_name (str): The name of the workflow to resolve.

        Returns:
            Callable: The resolved workflow function.

        Raises:
            AttributeError: If the workflow function is not found.
        """
        workflow_func = self.workflows.get(workflow_name)
        if not workflow_func:
            raise AttributeError(f"Workflow '{workflow_name}' not found.")
        return workflow_func

    def run_workflow(self, workflow: Union[str, Callable], input: Optional[Union[str, Dict[str, Any]]] = None) -> str:
        """
        Start a workflow and manage its lifecycle.

        Args:
            workflow (Union[str, Callable]): Workflow name or callable instance.
            input (Optional[Union[str, Dict[str, Any]]]): Input for the workflow. Defaults to None.

        Returns:
            str: Workflow instance ID.
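
        Example:
            Illustrative sketch, not from the original source; assumes the
            ``counting_workflow`` registered above and a running Dapr sidecar::

                instance_id = app.run_workflow("counting_workflow", input={"user": "alice"})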
        """
        try:
            # Start the workflow runtime
            self.start_runtime()

            # Generate a unique instance ID
            instance_id = str(uuid.uuid4()).replace("-", "")

            # Guard against an (unlikely) collision with an existing instance
            if instance_id in self.state.instances:
                logger.warning(f"Workflow instance {instance_id} already exists.")
                return instance_id

            # Prepare the workflow input for the state entry
            entry_input = input if isinstance(input, str) else json.dumps(input) if input else ""

            # Initialize the workflow entry
            workflow_entry = WorkflowEntry(input=entry_input, status=WorkflowStatus.RUNNING)
            self.state.instances[instance_id] = workflow_entry
            self.save_state()

            # Resolve the workflow function
            workflow_func = self.resolve_workflow(workflow) if isinstance(workflow, str) else workflow

            # Schedule workflow execution
            instance_id = self.wf_client.schedule_new_workflow(
                workflow=workflow_func,
                input=input,
                instance_id=instance_id
            )

            logger.info(f"Started workflow with instance ID {instance_id}.")
            return instance_id
        except Exception as e:
            logger.error(f"Failed to start workflow {workflow}: {e}")
            raise

    async def monitor_workflow_completion(self, instance_id: str):
        """
        Monitor a workflow instance in the background and handle its final state.
        """
        try:
            logger.info(f"Starting to monitor workflow '{instance_id}'...")

            state: WorkflowState = await asyncio.to_thread(
                self.wait_for_workflow_completion,
                instance_id,
                fetch_payloads=True,
                timeout_in_seconds=self.workflow_timeout,
            )

            if not state:
                logger.error(f"Workflow '{instance_id}' not found.")
                self.handle_workflow_output(instance_id, "Workflow not found.", WorkflowStatus.FAILED)
                return

            # Map the runtime status directly to a WorkflowStatus
            workflow_status = WorkflowStatus[state.runtime_status.name]

            if workflow_status == WorkflowStatus.COMPLETED:
                logger.info(f"Workflow '{instance_id}' completed successfully!")
                logger.debug(f"Output: {state.serialized_output}")
                self.handle_workflow_output(instance_id, state.serialized_output, WorkflowStatus.COMPLETED)
            else:
                logger.error(f"Workflow '{instance_id}' ended with status '{workflow_status.value}'.")
                self.handle_workflow_output(
                    instance_id,
                    f"Workflow ended with status: {workflow_status.value}.",
                    workflow_status,
                )

        except TimeoutError:
            logger.error(f"Workflow '{instance_id}' monitoring timed out.")
            self.handle_workflow_output(instance_id, "Workflow monitoring timed out.", WorkflowStatus.FAILED)
        except Exception as e:
            logger.error(f"Error monitoring workflow '{instance_id}': {e}")
            self.handle_workflow_output(instance_id, f"Error monitoring workflow: {e}", WorkflowStatus.FAILED)
        finally:
            logger.info(f"Finished monitoring workflow '{instance_id}'.")
            self.stop_runtime()

    def run_and_monitor_workflow(self, workflow: Union[str, Callable], input: Optional[Union[str, Dict[str, Any]]] = None) -> Any:
        """
        Run a workflow synchronously and handle its completion.

        Args:
            workflow (Union[str, Callable]): Workflow name or callable instance.
            input (Optional[Union[str, Dict[str, Any]]]): The input for the workflow.

        Returns:
            Any: The serialized output of the workflow after completion.
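
        Example:
            Illustrative sketch, not from the original source; assumes the
            ``counting_workflow`` registered above::

                output = app.run_and_monitor_workflow("counting_workflow", input={"user": "alice"})
                print(output)  # serialized workflow output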
        """
        instance_id = None
        try:
            # Schedule the workflow
            instance_id = self.run_workflow(workflow, input=input)

            # Wait for workflow completion
            state: WorkflowState = self.wait_for_workflow_completion(
                instance_id,
                fetch_payloads=True,
                timeout_in_seconds=self.workflow_timeout,
            )

            if not state:
                logger.error(f"Workflow '{instance_id}' not found.")
                self.handle_workflow_output(instance_id, "Workflow not found.", WorkflowStatus.FAILED)
                raise RuntimeError(f"Workflow '{instance_id}' not found.")

            # Determine the workflow status
            try:
                workflow_status = WorkflowStatus[state.runtime_status.name]
            except KeyError:
                workflow_status = WorkflowStatus.UNKNOWN
                logger.warning(f"Unrecognized workflow status '{state.runtime_status.name}'. Defaulting to UNKNOWN.")

            if workflow_status == WorkflowStatus.COMPLETED:
                logger.info(f"Workflow '{instance_id}' completed successfully!")
                logger.debug(f"Output: {state.serialized_output}")
                self.handle_workflow_output(instance_id, state.serialized_output, WorkflowStatus.COMPLETED)
            else:
                logger.error(f"Workflow '{instance_id}' ended with status '{workflow_status.value}'.")
                self.handle_workflow_output(
                    instance_id,
                    f"Workflow ended with status: {workflow_status.value}.",
                    workflow_status,
                )

            # Return the final output
            logger.info(f"Returning final output for workflow '{instance_id}'.")
            logger.debug(f"Serialized Output: {state.serialized_output}")
            return state.serialized_output

        except TimeoutError:
            logger.error(f"Workflow '{instance_id}' monitoring timed out.")
            self.handle_workflow_output(instance_id, "Workflow monitoring timed out.", WorkflowStatus.FAILED)
            raise
        except Exception as e:
            logger.error(f"Error during workflow '{instance_id}': {e}")
            self.handle_workflow_output(instance_id, f"Error during workflow: {e}", WorkflowStatus.FAILED)
            raise
        finally:
            logger.info(f"Finished workflow with Instance ID: {instance_id}.")
            self.stop_runtime()

    def handle_workflow_output(self, instance_id: str, output: Any, status: WorkflowStatus):
        """
        Handle the output of a completed workflow.
        """
        try:
            # Check if the workflow exists
            workflow_entry = self.state.instances.get(instance_id)
            if not workflow_entry:
                logger.error(f"Workflow with instance ID {instance_id} not found.")
                return

            # Update the workflow entry
            workflow_entry.output = output
            workflow_entry.status = status
            workflow_entry.end_time = datetime.now()

            # Persist the updated state
            self.save_state()

            logger.info(f"Workflow '{instance_id}' output persisted successfully.")
        except Exception as e:
            logger.error(f"Failed to persist workflow output for '{instance_id}': {e}")

    def terminate_workflow(self, instance_id: str, *, output: Optional[Any] = None) -> None:
        """
        Terminates a running workflow instance.

        Args:
            instance_id (str): The ID of the workflow instance to terminate.
            output (Optional[Any]): The optional output to set for the terminated workflow instance.

        Raises:
            Exception: If the termination request fails.
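
        Example:
            Illustrative sketch, not from the original source::

                app.terminate_workflow(instance_id, output={"reason": "cancelled by user"})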
        """
        try:
            self.wf_client.terminate_workflow(instance_id=instance_id, output=output)
            logger.info(f"Successfully terminated workflow '{instance_id}' with output: {output}")
        except Exception as e:
            logger.error(f"Failed to terminate workflow '{instance_id}'. Error: {e}")
            raise Exception(f"Error terminating workflow '{instance_id}': {e}") from e

    def get_workflow_state(self, instance_id: str) -> Optional[Any]:
        """
        Retrieve the state of the workflow instance with the given ID.

        Args:
            instance_id (str): The ID of the workflow instance to retrieve the state for.

        Returns:
            Optional[Any]: The state of the workflow instance if found, otherwise None.
            Errors are logged and None is returned rather than raised.
        """
        try:
            state = self.wf_client.get_workflow_state(instance_id)
            logger.info(f"Retrieved state for workflow {instance_id}: {state.runtime_status}.")
            return state
        except Exception as e:
            logger.error(f"Failed to retrieve workflow state for {instance_id}: {e}")
            return None

    def wait_for_workflow_completion(self, instance_id: str, fetch_payloads: bool = True, timeout_in_seconds: int = 120) -> Optional[WorkflowState]:
        """
        Wait for the workflow instance to complete and retrieve its state.

        Args:
            instance_id (str): The unique ID of the workflow instance to wait for.
            fetch_payloads (bool): Whether to fetch the input and output payloads,
                and the custom status for the workflow instance. Defaults to True.
            timeout_in_seconds (int): The maximum time in seconds to wait for the workflow
                instance to complete. Defaults to 120 seconds.

        Returns:
            Optional[WorkflowState]: The state of the workflow instance if it completes
            within the timeout; otherwise None. Errors are logged and None is returned.
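
        Example:
            Illustrative sketch, not from the original source::

                state = app.wait_for_workflow_completion(instance_id, timeout_in_seconds=60)
                if state:
                    print(state.runtime_status, state.serialized_output)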
        """
        try:
            state = self.wf_client.wait_for_workflow_completion(
                instance_id, fetch_payloads=fetch_payloads, timeout_in_seconds=timeout_in_seconds
            )
            if state:
                logger.info(f"Workflow {instance_id} completed with status: {state.runtime_status}.")
            else:
                logger.warning(f"Workflow {instance_id} did not complete within the timeout period.")
            return state
        except Exception as e:
            logger.error(f"Error while waiting for workflow {instance_id} completion: {e}")
            return None

    def raise_workflow_event(self, instance_id: str, event_name: str, *, data: Optional[Any] = None) -> None:
        """
        Raises an event for a running Dapr workflow instance.

        Args:
            instance_id (str): The unique identifier of the workflow instance.
            event_name (str): The name of the event to raise in the workflow.
            data (Optional[Any]): The optional data payload for the event.

        Raises:
            Exception: If raising the event fails.
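
        Example:
            Illustrative sketch, not from the original source; assumes a workflow instance
            that is blocked on ``ctx.wait_for_external_event("approval")``::

                app.raise_workflow_event(instance_id, "approval", data={"approved": True})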
        """
        try:
            logger.info(f"Raising workflow event '{event_name}' for instance '{instance_id}'")
            self.wf_client.raise_workflow_event(instance_id=instance_id, event_name=event_name, data=data)
            logger.info(f"Successfully raised workflow event '{event_name}' for instance '{instance_id}'!")
        except Exception as e:
            logger.error(
                f"Error raising workflow event '{event_name}' for instance '{instance_id}'. "
                f"Data: {data}, Error: {e}"
            )
            raise Exception(f"Failed to raise workflow event '{event_name}' for instance '{instance_id}': {e}") from e

    def call_service(
        self,
        service: str,
        http_method: str = "POST",
        input: Optional[Dict[str, Any]] = None,
        timeout: Optional[int] = None
    ) -> Any:
        """
        Call an external agent service via Dapr service invocation.

        The request is sent to the service's `generate` method.

        Args:
            service (str): The name (Dapr app ID) of the agent service to call.
            http_method (str, optional): The HTTP method to use (e.g., "GET", "POST"). Defaults to "POST".
            input (Optional[Dict[str, Any]], optional): The input data to pass to the agent service. Defaults to None.
            timeout (Optional[int], optional): Timeout for the service call in seconds. Defaults to None.

        Returns:
            Any: The response from the agent service.

        Raises:
            Exception: If there is an error invoking the agent service.
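
        Example:
            Illustrative sketch, not from the original source; assumes an agent service
            registered with Dapr under the app ID ``weather-agent`` that exposes a
            ``generate`` endpoint::

                response = app.call_service("weather-agent", input={"task": "Weather in Seattle?"})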
        """
        try:
            with DaprClient() as d:
                resp = d.invoke_method(
                    service,
                    "generate",
                    http_verb=http_method,
                    data=json.dumps(input) if input else None,
                    timeout=timeout
                )
                if resp.status_code != 200:
                    raise Exception(f"Error calling {service} service: {resp.status_code}: {resp.text()}")
                agent_response = json.loads(resp.data.decode("utf-8"))
                logger.info(f"Agent's Result: {agent_response}")
                return agent_response
        except Exception as e:
            logger.error(f"Failed to call agent service: {e}")
            raise

    def when_all(self, tasks: List[dtask.Task[T]]) -> dtask.WhenAllTask[T]:
        """
        Returns a task that completes when all of the provided tasks complete or when one of the tasks fails.

        This is useful in orchestrating multiple tasks in a workflow where you want to wait for all tasks
        to either complete or for the first one to fail.

        Args:
            tasks (List[dtask.Task[T]]): A list of task instances that should all complete.

        Returns:
            dtask.WhenAllTask[T]: A task that represents the combined completion of all the provided tasks.
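
        Example:
            Illustrative fan-out/fan-in sketch inside a workflow function, not from the
            original source; assumes tasks ``get_count`` and ``greet`` registered on ``app``::

                results = yield app.when_all([
                    ctx.call_activity(get_count),
                    ctx.call_activity(greet, input="alice"),
                ])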
        """
        return dtask.when_all(tasks)

    def when_any(self, tasks: List[dtask.Task[T]]) -> dtask.WhenAnyTask:
        """
        Returns a task that completes when any one of the provided tasks completes or fails.

        This is useful in scenarios where you want to proceed as soon as one of the tasks finishes, without
        waiting for the others to complete.

        Args:
            tasks (List[dtask.Task[T]]): A list of task instances, any of which can complete or fail to trigger completion.

        Returns:
            dtask.WhenAnyTask: A task that represents the completion of the first of the provided tasks to finish.
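
        Example:
            Illustrative sketch inside a workflow function, not from the original source;
            races an activity against an external cancellation event::

                winner = yield app.when_any([
                    ctx.call_activity(get_count),
                    ctx.wait_for_external_event("cancel"),
                ])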
        """
        return dtask.when_any(tasks)

    def start_runtime(self):
        """
        Start the workflow runtime.
        """
        logger.info("Starting workflow runtime.")
        self.wf_runtime.start()

    def stop_runtime(self):
        """
        Stop the workflow runtime.
        """
        logger.info("Stopping workflow runtime.")
        self.wf_runtime.shutdown()