mirror of https://github.com/dapr/dapr-agents.git
786 lines
29 KiB
Python
786 lines
29 KiB
Python
import asyncio
|
|
import functools
|
|
import inspect
|
|
import json
|
|
import logging
|
|
import os
|
|
import sys
|
|
import time
|
|
import uuid
|
|
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
|
|
|
|
from pydantic import BaseModel, ConfigDict, Field
|
|
|
|
from durabletask import task as dtask
|
|
|
|
from dapr.clients import DaprClient
|
|
from dapr.clients.grpc._request import (
|
|
TransactionOperationType,
|
|
TransactionalStateOperation,
|
|
)
|
|
from dapr.clients.grpc._response import StateResponse
|
|
from dapr.clients.grpc._state import Concurrency, Consistency, StateOptions
|
|
from dapr.ext.workflow import (
|
|
DaprWorkflowClient,
|
|
WorkflowActivityContext,
|
|
WorkflowRuntime,
|
|
)
|
|
from dapr.ext.workflow.workflow_state import WorkflowState
|
|
|
|
from dapr_agents.llm.chat import ChatClientBase
|
|
from dapr_agents.types.workflow import DaprWorkflowStatus
|
|
from dapr_agents.workflow.task import WorkflowTask
|
|
from dapr_agents.workflow.utils import get_decorated_methods
|
|
|
|
from pydantic import PrivateAttr
|
|
from dapr_agents.agent.telemetry import (
|
|
DaprAgentsOTel,
|
|
span_decorator,
|
|
)
|
|
|
|
from opentelemetry.trace import Tracer, set_tracer_provider
|
|
|
|
# Module-level logger, named after this module so log records can be filtered per module.
logger = logging.getLogger(__name__)

# Generic result type used in the signatures of the `when_all` / `when_any` helpers.
T = TypeVar("T")
|
|
|
|
|
|
class WorkflowApp(BaseModel):
    """
    A Pydantic-based class to encapsulate a Dapr Workflow runtime and manage workflows and tasks.

    During ``model_post_init`` the app creates a ``WorkflowRuntime``, a
    ``DaprWorkflowClient`` and a ``DaprClient``, discovers task- and
    workflow-decorated callables (free functions in ``__main__`` plus decorated
    methods on the instance) and registers them with the runtime.
    """

    llm: Optional[ChatClientBase] = Field(
        default=None, description="The default LLM client for all LLM-based tasks."
    )
    timeout: int = Field(
        default=300,
        description="Default timeout duration in seconds for workflow tasks.",
    )

    # Initialized in model_post_init
    wf_runtime: Optional[WorkflowRuntime] = Field(
        default=None, init=False, description="Workflow runtime instance."
    )
    wf_runtime_is_running: Optional[bool] = Field(
        default=None, init=False, description="Is the Workflow runtime running?"
    )
    wf_client: Optional[DaprWorkflowClient] = Field(
        default=None, init=False, description="Workflow client instance."
    )
    client: Optional[DaprClient] = Field(
        default=None, init=False, description="Dapr client instance."
    )
    tasks: Dict[str, Callable] = Field(
        default_factory=dict, init=False, description="Dictionary of registered tasks."
    )
    workflows: Dict[str, Callable] = Field(
        default_factory=dict,
        init=False,
        description="Dictionary of registered workflows.",
    )

    # Tracer obtained from the OpenTelemetry provider; None when telemetry setup failed.
    _tracer: Optional[Tracer] = PrivateAttr(default=None)

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def model_post_init(self, __context: Any) -> None:
        """
        Initialize the Dapr workflow runtime and register tasks & workflows.

        Telemetry setup is best-effort: any failure is logged as a warning and
        the app continues with ``_tracer`` set to None.

        Args:
            __context (Any): Pydantic post-init context, forwarded to the parent class.
        """
        # Initialize clients and runtime
        self.wf_runtime = WorkflowRuntime()
        self.wf_runtime_is_running = False
        self.wf_client = DaprWorkflowClient()
        self.client = DaprClient()
        logger.info("WorkflowApp initialized; discovering tasks and workflows.")

        # Discover and register tasks first, then workflows (same order as before).
        self._register_tasks(self._discover_tasks())
        self._register_workflows(self._discover_workflows())

        try:
            # This class declares no `name` field; subclasses may. Fall back to
            # the class name so telemetry is not silently disabled by an
            # AttributeError swallowed below.
            service_name = getattr(self, "name", None) or type(self).__name__
            otel_client = DaprAgentsOTel(
                service_name=service_name,
                otlp_endpoint=os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT", ""),
            )
            provider = otel_client.create_and_instrument_tracer_provider()
            set_tracer_provider(provider)

            self._tracer = provider.get_tracer("wf_tracer")

        except Exception as e:
            logger.warning(
                f"OpenTelemetry initialization failed: {e}. Continuing without telemetry."
            )
            self._tracer = None

        super().model_post_init(__context)

    def get_chat_history(self) -> List[Any]:
        """
        Stub for fetching past conversation history. Override in subclasses.

        Returns:
            List[Any]: Always an empty list in this base implementation.
        """
        logger.debug("Fetching chat history (default stub)")
        return []

    def _choose_llm_for(self, method: Callable) -> Optional[ChatClientBase]:
        """
        Encapsulate LLM selection logic.

        1. Use per-task override (``_task_llm``) if provided on decorator.
        2. Else if marked as explicitly requiring an LLM (``_explicit_llm``),
           fall back to default app LLM.
        3. Otherwise, returns None.

        Args:
            method (Callable): The decorated task callable to inspect.

        Returns:
            Optional[ChatClientBase]: The selected LLM client, if any.
        """
        per_task = getattr(method, "_task_llm", None)
        if per_task:
            return per_task
        if getattr(method, "_explicit_llm", False):
            return self.llm
        return None

    def _discover_decorated(self, marker: str, name_attr: str) -> Dict[str, Callable]:
        """
        Collect callables flagged with ``marker``, keyed by their registered name.

        Free functions defined in ``__main__`` are scanned first, then decorated
        methods on this instance; a method with the same key replaces a function.

        Args:
            marker (str): Attribute set truthy on decorated callables (e.g. ``_is_task``).
            name_attr (str): Attribute holding the custom name; the member name
                is used when it is absent.

        Returns:
            Dict[str, Callable]: Mapping of registered name to callable.
        """
        module = sys.modules["__main__"]
        found: Dict[str, Callable] = {}
        # Free functions in __main__ (skip functions merely imported into it).
        for name, fn in inspect.getmembers(module, inspect.isfunction):
            if getattr(fn, marker, False) and fn.__module__ == module.__name__:
                found[getattr(fn, name_attr, name)] = fn
        # Bound methods (if any) discovered via helper
        for name, method in get_decorated_methods(self, marker).items():
            found[getattr(method, name_attr, name)] = method
        return found

    def _discover_tasks(self) -> Dict[str, Callable]:
        """Gather all @task-decorated functions and methods."""
        tasks = self._discover_decorated("_is_task", "_task_name")
        logger.debug(f"Discovered tasks: {list(tasks)}")
        return tasks

    def _register_tasks(self, tasks: Dict[str, Callable]) -> None:
        """
        Register each discovered task with the Dapr runtime.

        Every callable is wrapped in a ``WorkflowTask`` configured from the
        ``_task_*`` attributes found on it, then registered as a runtime
        activity under its task name and stored in ``self.tasks``.

        Args:
            tasks (Dict[str, Callable]): Mapping of task name to callable.
        """
        for task_name, method in tasks.items():
            llm = self._choose_llm_for(method)
            logger.debug(
                f"Registering task '{task_name}' with llm={getattr(llm, '__class__', None)}"
            )
            kwargs = getattr(method, "_task_kwargs", {})
            task_instance = WorkflowTask(
                func=method,
                description=getattr(method, "_task_description", None),
                agent=getattr(method, "_task_agent", None),
                llm=llm,
                include_chat_history=getattr(
                    method, "_task_include_chat_history", False
                ),
                workflow_app=self,
                **kwargs,
            )
            # Wrap for Dapr invocation
            wrapped = self._make_task_wrapper(task_name, method, task_instance)
            activity_decorator = self.wf_runtime.activity(name=task_name)
            self.tasks[task_name] = activity_decorator(wrapped)

    def _make_task_wrapper(
        self, task_name: str, method: Callable, task_instance: WorkflowTask
    ) -> Callable:
        """
        Produce the function that Dapr will invoke for each activity.

        Args:
            task_name (str): Name used in failure log messages.
            method (Callable): Original callable; its metadata is copied onto the wrapper.
            task_instance (WorkflowTask): Task object that is actually invoked.

        Returns:
            Callable: A synchronous wrapper that runs the task and, when the
            task returns a coroutine, drives it to completion.
        """

        def run_sync(coro):
            """Run ``coro`` to completion from synchronous code and return its result."""
            try:
                asyncio.get_running_loop()
            except RuntimeError:
                # No loop is running in this thread: asyncio.run creates a
                # fresh loop and closes it afterwards, so loops do not pile up
                # across activity invocations.
                return asyncio.run(coro)

            # A loop is already running in this thread, and a running loop
            # cannot be re-entered with run_until_complete. Run the coroutine
            # on its own loop in a helper thread and block for the result.
            import concurrent.futures

            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
                return pool.submit(asyncio.run, coro).result()

        @functools.wraps(method)
        def wrapper(ctx: WorkflowActivityContext, *args, **kwargs):
            # NOTE(review): the workflow runtime may already hand over a
            # WorkflowActivityContext here, making this a double wrap — confirm.
            wf_ctx = WorkflowActivityContext(ctx)
            try:
                call = task_instance(wf_ctx, *args, **kwargs)
                if asyncio.iscoroutine(call):
                    return run_sync(call)
                return call
            except Exception:
                logger.exception(f"Task '{task_name}' failed")
                raise

        return wrapper

    def _discover_workflows(self) -> Dict[str, Callable]:
        """Gather all @workflow-decorated functions and methods."""
        wfs = self._discover_decorated("_is_workflow", "_workflow_name")
        logger.info(f"Discovered workflows: {list(wfs)}")
        return wfs

    def _register_workflows(self, wfs: Dict[str, Callable]) -> None:
        """
        Register each discovered workflow with the Dapr runtime.

        Args:
            wfs (Dict[str, Callable]): Mapping of workflow name to callable.
        """

        # Factory binds each workflow callable explicitly, avoiding
        # late-binding capture of the loop variable.
        def make_wrapped(meth: Callable) -> Callable:
            @functools.wraps(meth)
            def wrapped(*args, **kwargs):
                return meth(*args, **kwargs)

            return wrapped

        for wf_name, method in wfs.items():
            decorator = self.wf_runtime.workflow(name=wf_name)
            self.workflows[wf_name] = decorator(make_wrapped(method))

    def start_runtime(self):
        """Idempotently start the Dapr workflow runtime."""
        if not self.wf_runtime_is_running:
            logger.info("Starting workflow runtime.")
            self.wf_runtime.start()
            self.wf_runtime_is_running = True
        else:
            logger.debug("Workflow runtime already running; skipping.")

    def stop_runtime(self):
        """Idempotently stop the Dapr workflow runtime."""
        if self.wf_runtime_is_running:
            logger.info("Stopping workflow runtime.")
            self.wf_runtime.shutdown()
            self.wf_runtime_is_running = False
        else:
            logger.debug("Workflow runtime already stopped; skipping.")

    @span_decorator("register_agent")
    def register_agent(
        self, store_name: str, store_key: str, agent_name: str, agent_metadata: dict
    ) -> None:
        """
        Merge an agent's metadata into the registry stored under ``store_key``.

        The current value is read, ``{agent_name: agent_metadata}`` is merged
        into it and the result is written back with an etag-guarded
        transaction. The whole read-merge-write cycle is retried up to ten
        times, sleeping one second between attempts.

        Args:
            store_name (str): The name of the Dapr state store component.
            store_key (str): The key holding the agent registry.
            agent_name (str): The name of the agent to register.
            agent_metadata (dict): The metadata to store for the agent.

        Raises:
            Exception: If the store could not be updated after ten attempts.
        """
        max_attempts = 10
        last_error: Optional[Exception] = None

        for attempt in range(1, max_attempts + 1):
            try:
                response: StateResponse = self.client.get_state(
                    store_name=store_name, key=store_key
                )
                if not response.etag:
                    # if there is no etag the following transaction won't work as expected
                    # so we need to save an empty object with a strong consistency to force the etag to be created
                    self.client.save_state(
                        store_name=store_name,
                        key=store_key,
                        value=json.dumps({}),
                        state_metadata={"contentType": "application/json"},
                        options=StateOptions(
                            concurrency=Concurrency.first_write,
                            consistency=Consistency.strong,
                        ),
                    )
                    # raise an exception to retry the entire operation
                    raise Exception(f"No etag found for key: {store_key}")

                existing_data = json.loads(response.data) if response.data else {}
                if (agent_name, agent_metadata) in existing_data.items():
                    logger.debug(f"agent {agent_name} already registered.")
                    return None

                merged_data = {**existing_data, agent_name: agent_metadata}
                logger.debug(f"merged data: {merged_data} etag: {response.etag}")

                # using the transactional API to be able to later support the Dapr outbox pattern
                self.client.execute_state_transaction(
                    store_name=store_name,
                    operations=[
                        TransactionalStateOperation(
                            key=store_key,
                            data=json.dumps(merged_data),
                            etag=response.etag,
                            operation_type=TransactionOperationType.upsert,
                        )
                    ],
                    transactional_metadata={"contentType": "application/json"},
                )
                return None
            except Exception as e:
                last_error = e
                logger.error(f"Error on transaction attempt: {attempt}: {e}")
                # No point sleeping after the final attempt; fail right away.
                if attempt < max_attempts:
                    logger.info("Sleeping for 1 second before retrying transaction...")
                    time.sleep(1)

        raise Exception(
            f"Failed to update state store key: {store_key} after {max_attempts} attempts."
        ) from last_error

    def get_data_from_store(self, store_name: str, key: str) -> Optional[dict]:
        """
        Retrieves data from the Dapr state store using the given key.

        Args:
            store_name (str): The name of the Dapr state store component.
            key (str): The key to fetch data from.

        Returns:
            Optional[dict]: The JSON-decoded value, or None if the key holds no
            data or the lookup/decoding failed (the failure is logged).
        """
        try:
            response: StateResponse = self.client.get_state(
                store_name=store_name, key=key
            )
            data = response.data
            return json.loads(data) if data else None
        except Exception:
            # Best-effort read: keep returning None, but record the cause.
            logger.warning(
                f"Error retrieving data for key '{key}' from store '{store_name}'",
                exc_info=True,
            )
            return None

    def resolve_task(self, task: Union[str, Callable]) -> Callable:
        """
        Resolves a registered task function by its name or decorated function.

        Args:
            task (Union[str, Callable]): The task name or callable function.

        Returns:
            Callable: The resolved task function.

        Raises:
            ValueError: If ``task`` is neither a string nor a callable.
            AttributeError: If the task is not found.
        """
        if isinstance(task, str):
            task_name = task
        elif callable(task):
            task_name = getattr(task, "_task_name", task.__name__)
        else:
            raise ValueError(f"Invalid task reference: {task}")

        task_func = self.tasks.get(task_name)
        if not task_func:
            raise AttributeError(f"Task '{task_name}' not found.")

        return task_func

    def resolve_workflow(self, workflow: Union[str, Callable]) -> Callable:
        """
        Resolves a registered workflow function by its name or decorated function.

        Args:
            workflow (Union[str, Callable]): The workflow name or callable function.

        Returns:
            Callable: The resolved workflow function.

        Raises:
            ValueError: If ``workflow`` is neither a string nor a callable.
            AttributeError: If the workflow is not found.
        """
        if isinstance(workflow, str):
            workflow_name = workflow  # Direct lookup by string name
        elif callable(workflow):
            workflow_name = getattr(workflow, "_workflow_name", workflow.__name__)
        else:
            raise ValueError(f"Invalid workflow reference: {workflow}")

        workflow_func = self.workflows.get(workflow_name)
        if not workflow_func:
            raise AttributeError(f"Workflow '{workflow_name}' not found.")

        return workflow_func

    @span_decorator("run_workflow")
    def run_workflow(
        self,
        workflow: Union[str, Callable],
        input: Optional[Union[str, Dict[str, Any]]] = None,
    ) -> str:
        """
        Starts a workflow execution.

        The workflow runtime is started first if it is not already running.

        Args:
            workflow (Union[str, Callable]): The workflow name or callable.
            input (Union[str, Dict[str, Any]], optional): Input data for the workflow.

        Returns:
            str: The instance ID of the started workflow.

        Raises:
            Exception: If workflow execution fails.
        """
        try:
            # Start Workflow Runtime
            if not self.wf_runtime_is_running:
                self.start_runtime()

            # Generate unique instance ID
            instance_id = uuid.uuid4().hex

            # Resolve the workflow function
            workflow_func = self.resolve_workflow(workflow)

            # Schedule workflow execution
            instance_id = self.wf_client.schedule_new_workflow(
                workflow=workflow_func, input=input, instance_id=instance_id
            )

            logger.info(f"Started workflow with instance ID {instance_id}.")
            return instance_id
        except Exception as e:
            logger.error(f"Failed to start workflow {workflow}: {e}")
            raise

    async def monitor_workflow_state(self, instance_id: str) -> Optional[WorkflowState]:
        """
        Monitors and retrieves the final state of a workflow instance.

        The blocking wait runs on a worker thread and is bounded by ``self.timeout``.

        Args:
            instance_id (str): The workflow instance ID.

        Returns:
            Optional[WorkflowState]: The final state of the workflow, or None if
            it was not found, timed out or an error occurred (logged).
        """
        try:
            state: WorkflowState = await asyncio.to_thread(
                self.wait_for_workflow_completion,
                instance_id,
                fetch_payloads=True,
                timeout_in_seconds=self.timeout,
            )

            if not state:
                logger.error(f"Workflow '{instance_id}' not found.")
                return None

            return state
        except TimeoutError:
            logger.error(f"Workflow '{instance_id}' monitoring timed out.")
            return None
        except Exception as e:
            logger.error(f"Error retrieving workflow state for '{instance_id}': {e}")
            return None

    async def monitor_workflow_completion(self, instance_id: str) -> None:
        """
        Monitors the execution of a workflow and logs its final state.

        A workflow that ends as FAILED or ABORTED is additionally terminated
        via ``terminate_workflow``. Errors are logged and never propagated.

        Args:
            instance_id (str): The workflow instance ID.
        """
        try:
            logger.info(f"Monitoring workflow '{instance_id}'...")

            # Retrieve workflow state
            state = await self.monitor_workflow_state(instance_id)
            if not state:
                return  # Error already logged in monitor_workflow_state

            # Extract relevant details
            workflow_status = state.runtime_status.name
            # This is an object, not a dictionary
            failure_details = state.failure_details

            if workflow_status == "COMPLETED":
                logger.info(
                    f"Workflow '{instance_id}' completed successfully. Status: {workflow_status}."
                )

                if state.serialized_output:
                    logger.debug(
                        f"Output: {json.dumps(state.serialized_output, indent=2)}"
                    )

            elif workflow_status in ("FAILED", "ABORTED"):
                # getattr with defaults keeps this safe when attributes are missing
                error_type = getattr(failure_details, "error_type", "Unknown")
                message = getattr(failure_details, "message", "No message provided")
                stack_trace = getattr(
                    failure_details, "stack_trace", "No stack trace available"
                )

                logger.error(
                    f"Workflow '{instance_id}' failed.\n"
                    f"Error Type: {error_type}\n"
                    f"Message: {message}\n"
                    f"Stack Trace:\n{stack_trace}\n"
                    f"Input: {json.dumps(state.serialized_input, indent=2)}"
                )

                self.terminate_workflow(instance_id)

            else:
                logger.warning(
                    f"Workflow '{instance_id}' ended with status '{workflow_status}'.\n"
                    f"Input: {json.dumps(state.serialized_input, indent=2)}"
                )

            logger.debug(
                f"Workflow Details: Instance ID={state.instance_id}, Name={state.name}, "
                f"Created At={state.created_at}, Last Updated At={state.last_updated_at}"
            )

        except Exception as e:
            logger.error(
                f"Error monitoring workflow '{instance_id}': {e}", exc_info=True
            )
        finally:
            logger.info(f"Finished monitoring workflow '{instance_id}'.")

    async def run_and_monitor_workflow_async(
        self,
        workflow: Union[str, Callable],
        input: Optional[Union[str, Dict[str, Any]]] = None,
    ) -> Optional[str]:
        """
        Runs a workflow asynchronously and monitors its completion.

        The workflow runtime is stopped once monitoring finishes, whether the
        run succeeded or raised.

        Args:
            workflow (Union[str, Callable]): The workflow name or callable.
            input (Optional[Union[str, Dict[str, Any]]]): The workflow input payload.

        Returns:
            Optional[str]: The serialized output of the workflow.

        Raises:
            RuntimeError: If no final state could be retrieved for the workflow.
            Exception: Any error raised while starting the workflow.
        """
        instance_id = None
        try:
            # Off-load the potentially blocking run_workflow call to a thread.
            instance_id = await asyncio.to_thread(self.run_workflow, workflow, input)

            # Await the asynchronous monitoring of the workflow state.
            state = await self.monitor_workflow_state(instance_id)

            if not state:
                raise RuntimeError(f"Workflow '{instance_id}' not found.")

            # Map the runtime status by name; unrecognised names become UNKNOWN.
            workflow_status = (
                DaprWorkflowStatus[state.runtime_status.name]
                if state.runtime_status.name in DaprWorkflowStatus.__members__
                else DaprWorkflowStatus.UNKNOWN
            )

            if workflow_status == DaprWorkflowStatus.COMPLETED:
                logger.info(f"Workflow '{instance_id}' completed successfully!")
                logger.debug(f"Output: {state.serialized_output}")
            else:
                logger.error(
                    f"Workflow '{instance_id}' ended with status '{workflow_status.value}'."
                )

            # Return the final state output
            return state.serialized_output

        except Exception as e:
            logger.error(f"Error during workflow '{instance_id}': {e}")
            raise
        finally:
            logger.info(f"Finished workflow with Instance ID: {instance_id}.")
            # Off-load the stop_runtime call as it may block.
            await asyncio.to_thread(self.stop_runtime)

    def run_and_monitor_workflow_sync(
        self,
        workflow: Union[str, Callable],
        input: Optional[Union[str, Dict[str, Any]]] = None,
    ) -> Optional[str]:
        """
        Synchronous wrapper for running and monitoring a workflow.
        This allows calling code that is not async to still run the workflow.

        Args:
            workflow (Union[str, Callable]): The workflow name or callable.
            input (Optional[Union[str, Dict[str, Any]]]): The workflow input payload.

        Returns:
            Optional[str]: The serialized output of the workflow.
        """
        return asyncio.run(self.run_and_monitor_workflow_async(workflow, input))

    def terminate_workflow(
        self, instance_id: str, *, output: Optional[Any] = None
    ) -> None:
        """
        Terminates a running workflow.

        Args:
            instance_id (str): The workflow instance ID.
            output (Optional[Any]): Optional output to set for the terminated workflow.

        Raises:
            Exception: If the termination fails (chained to the original error).
        """
        try:
            self.wf_client.terminate_workflow(instance_id=instance_id, output=output)
            logger.info(
                f"Successfully terminated workflow '{instance_id}' with output: {output}"
            )
        except Exception as e:
            logger.error(f"Failed to terminate workflow '{instance_id}'. Error: {e}")
            raise Exception(f"Error terminating workflow '{instance_id}': {e}") from e

    def get_workflow_state(self, instance_id: str) -> Optional[Any]:
        """
        Retrieves the state of a workflow instance.

        Args:
            instance_id (str): The workflow instance ID.

        Returns:
            Optional[Any]: The workflow state if found, or None when retrieval
            fails (the failure is logged, not raised).
        """
        try:
            state = self.wf_client.get_workflow_state(instance_id)
            logger.info(
                f"Retrieved state for workflow {instance_id}: {state.runtime_status}."
            )
            return state
        except Exception as e:
            logger.error(f"Failed to retrieve workflow state for {instance_id}: {e}")
            return None

    def wait_for_workflow_completion(
        self,
        instance_id: str,
        fetch_payloads: bool = True,
        timeout_in_seconds: int = 120,
    ) -> Optional[WorkflowState]:
        """
        Waits for a workflow to complete and retrieves its state.

        Args:
            instance_id (str): The workflow instance ID.
            fetch_payloads (bool): Whether to fetch input/output payloads.
            timeout_in_seconds (int): Maximum wait time in seconds.

        Returns:
            Optional[WorkflowState]: The final state, or None if no state was
            returned or waiting failed (the failure is logged, not raised).
        """
        try:
            state = self.wf_client.wait_for_workflow_completion(
                instance_id,
                fetch_payloads=fetch_payloads,
                timeout_in_seconds=timeout_in_seconds,
            )
            if state:
                logger.info(
                    f"Workflow {instance_id} completed with status: {state.runtime_status}."
                )
            else:
                logger.warning(
                    f"Workflow {instance_id} did not complete within the timeout period."
                )
            return state
        except Exception as e:
            logger.error(
                f"Error while waiting for workflow {instance_id} completion: {e}"
            )
            return None

    def raise_workflow_event(
        self, instance_id: str, event_name: str, *, data: Any | None = None
    ) -> None:
        """
        Raises an event for a running workflow instance.

        Args:
            instance_id (str): The workflow instance ID.
            event_name (str): The name of the event to raise.
            data (Any | None): Optional event data.

        Raises:
            Exception: If raising the event fails (chained to the original error).
        """
        try:
            logger.info(
                f"Raising workflow event '{event_name}' for instance '{instance_id}'"
            )
            self.wf_client.raise_workflow_event(
                instance_id=instance_id, event_name=event_name, data=data
            )
            logger.info(
                f"Successfully raised workflow event '{event_name}' for instance '{instance_id}'!"
            )
        except Exception as e:
            logger.error(
                f"Error raising workflow event '{event_name}' for instance '{instance_id}'. "
                f"Data: {data}, Error: {e}"
            )
            raise Exception(
                f"Failed to raise workflow event '{event_name}' for instance '{instance_id}': {str(e)}"
            ) from e

    def invoke_service(
        self,
        service: str,
        method: str,
        http_method: str = "POST",
        input: Optional[Dict[str, Any]] = None,
        timeout: Optional[int] = None,
    ) -> Any:
        """
        Invokes an external service via Dapr.

        Args:
            service (str): The service name.
            method (str): The method to call.
            http_method (str, optional): The HTTP method (default: "POST").
            input (Optional[Dict[str, Any]], optional): The request payload;
                a falsy payload is sent as no body.
            timeout (Optional[int], optional): Timeout in seconds.

        Returns:
            Any: The JSON-decoded response from the service.

        Raises:
            Exception: If the invocation fails or the status code is not 200.
        """
        try:
            resp = self.client.invoke_method(
                app_id=service,
                method_name=method,
                http_verb=http_method,
                data=json.dumps(input) if input else None,
                timeout=timeout,
            )
            if resp.status_code != 200:
                # NOTE(review): `text` appears to be a method on the Dapr
                # response object — confirm. Calling it when callable keeps the
                # error message readable either way.
                body = resp.text() if callable(resp.text) else resp.text
                raise Exception(
                    f"Error calling {service}.{method}: {resp.status_code}: {body}"
                )

            agent_response = json.loads(resp.data.decode("utf-8"))
            logger.info(f"Agent's Response: {agent_response}")
            return agent_response
        except Exception as e:
            logger.error(f"Failed to invoke {service}.{method}: {e}")
            raise

    def when_all(self, tasks: List[dtask.Task[T]]) -> dtask.WhenAllTask[T]:
        """
        Waits for all given tasks to complete.

        Args:
            tasks (List[dtask.Task[T]]): The tasks to wait for.

        Returns:
            dtask.WhenAllTask[T]: A task that completes when all tasks finish.
        """
        return dtask.when_all(tasks)

    def when_any(self, tasks: List[dtask.Task[T]]) -> dtask.WhenAnyTask:
        """
        Waits for any one of the given tasks to complete.

        Args:
            tasks (List[dtask.Task[T]]): The tasks to monitor.

        Returns:
            dtask.WhenAnyTask: A task that completes when the first task finishes.
        """
        return dtask.when_any(tasks)
|