dapr-agents/dapr_agents/workflow/service.py

143 lines
5.9 KiB
Python

from fastapi import HTTPException, Request, Response, status
from fastapi import HTTPException, status, Request
from cloudevents.http.conversion import from_http
from cloudevents.http.event import CloudEvent
from fastapi.responses import JSONResponse
from dapr_agents.service.fastapi import DaprEnabledService
from dapr_agents.workflow import WorkflowApp
from fastapi import Request
from pydantic import Field
from typing import Any
import asyncio
import logging
logger = logging.getLogger(__name__)
class WorkflowAppService(DaprEnabledService, WorkflowApp):
"""
Abstract base class for agentic workflows, providing a template for common workflow operations.
"""
# Fields initialized later
workflow_name: str = Field(default=None, init=False, description="The main workflow name for this service.")
def model_post_init(self, __context: Any) -> None:
"""
Configure workflows and initialize AgentService and WorkflowApp.
"""
super().model_post_init(__context)
# Register API Routes
self.app.add_api_route("/RunWorkflow", self.run_workflow_from_request, methods=["POST"])
self.app.add_api_route("/RaiseWorkflowEvent", self.raise_workflow_event_from_request, methods=["POST"])
async def run_workflow_from_request(self, request: Request) -> JSONResponse:
"""
Run a workflow instance triggered by an incoming HTTP request.
Handles both CloudEvents and plain JSON input, with background monitoring.
"""
try:
# Extract headers and body
headers = request.headers
body = await request.body()
# Attempt to parse as CloudEvent
try:
event: CloudEvent = from_http(dict(headers), body)
workflow_name = event.get("subject") or headers.get("workflow_name", self.workflow_name)
input_data = event.data
except Exception:
# Fallback to plain JSON
data = await request.json()
workflow_name = headers.get("workflow_name", self.workflow_name)
input_data = data
if not workflow_name:
raise ValueError("Workflow name must be provided in headers or as CloudEvent subject.")
logger.info(f"Starting '{workflow_name}' from request with input: {input_data}")
# Start the workflow
instance_id = self.run_workflow(workflow=workflow_name, input=input_data)
# Schedule background monitoring
asyncio.create_task(self.monitor_workflow_completion(instance_id))
# Respond with the workflow instance ID immediately
return JSONResponse(
content={"message": "Workflow initiated successfully.", "workflow_instance_id": instance_id},
status_code=status.HTTP_200_OK,
)
except Exception as e:
logger.error(f"Error triggering workflow: {e}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error triggering workflow: {str(e)}",
)
async def raise_workflow_event_from_request(self, request: Request) -> Response:
"""
Handles an API request or pub/sub message to trigger a workflow event.
This method processes the incoming CloudEvent or HTTP request, extracts metadata
(workflow_instance_id, event_name), and triggers the corresponding event for a running workflow instance.
Args:
request (Request): The incoming request containing event details.
Returns:
Response: Success or error response based on event processing.
"""
try:
# Parse the incoming CloudEvent
body = await request.body()
headers = request.headers
event: CloudEvent = from_http(dict(headers), body)
# Extract essential metadata from headers
workflow_instance_id = headers.get("workflow_instance_id")
event_name = headers.get("event_name")
if not workflow_instance_id:
logger.warning("Workflow event missing 'workflow_instance_id'.")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Header 'workflow_instance_id' is required."
)
if not event_name:
logger.warning("Workflow event missing 'event_name'.")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Header 'event_name' is required."
)
# Extract event details
source = event.get("source")
event_data = event.data
if not event_data:
logger.warning("Event data is empty or missing.")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Event data is required."
)
logger.info(f"Processing event '{event_name}' for workflow '{workflow_instance_id}' from '{source}'.")
# Trigger the workflow event, passing all event data as `data`
self.raise_workflow_event(instance_id=workflow_instance_id, event_name=event_name, data=event_data)
return JSONResponse(content={"message": "Workflow event triggered successfully."}, status_code=status.HTTP_200_OK)
except HTTPException as e:
logger.error(f"Error processing workflow event: {e.detail}", exc_info=True)
raise e
except Exception as e:
logger.error(f"Unexpected error processing workflow event: {e}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error processing workflow event: {str(e)}"
)