dapr-agents/dapr_agents/workflow/messaging/routing.py

230 lines
9.3 KiB
Python

import asyncio
import inspect
import logging
import threading
import functools
from typing import Callable
from dapr.aio.clients.grpc.subscription import Subscription
from dapr.clients.grpc._response import TopicEventResponse
from dapr.clients.grpc.subscription import StreamInactiveError
from dapr.common.pubsub.subscription import StreamCancelledError, SubscriptionMessage
from dapr_agents.workflow.messaging.parser import (
extract_cloudevent_data,
validate_message_model,
)
from dapr_agents.workflow.messaging.utils import is_valid_routable_model
from dapr_agents.workflow.utils import get_decorated_methods
logger = logging.getLogger(__name__)
class MessageRoutingMixin:
"""
Mixin class providing dynamic message routing capabilities for agentic services using Dapr pub/sub.
This mixin enables:
- Auto-registration of message handlers via the `@message_router` decorator.
- CloudEvent-based dispatch to appropriate handlers based on `type`.
- Topic subscription management and graceful shutdown.
- Support for both synchronous and asynchronous handler methods.
- Workflow-aware message handling for registered workflow entrypoints.
Expected attributes provided by the consuming service:
- `self._dapr_client`: A configured Dapr client instance.
- `self.name`: The agent's name (used for default topic routing).
- `self.message_bus_name`: Pub/Sub component name in Dapr.
- `self.broadcast_topic_name`: Optional default topic name for broadcasts.
- `self._topic_handlers`: Dict storing routing info by (pubsub, topic).
- `self._subscriptions`: Dict storing unsubscribe functions for active subscriptions.
"""
def register_message_routes(self) -> None:
"""
Registers message handlers dynamically by subscribing once per topic.
Incoming messages are dispatched by CloudEvent `type` to the appropriate handler.
This function:
- Scans all class methods for the `@message_router` decorator.
- Extracts routing metadata and message model schemas.
- Wraps each handler and maps it by `(pubsub_name, topic_name)` and schema name.
- Ensures only one handler per schema per topic is allowed.
"""
message_handlers = get_decorated_methods(self, "_is_message_handler")
for method_name, method in message_handlers.items():
try:
router_data = method._message_router_data.copy()
pubsub_name = router_data.get("pubsub") or self.message_bus_name
is_broadcast = router_data.get("is_broadcast", False)
topic_name = router_data.get("topic") or (
self.broadcast_topic_name if is_broadcast else self.name
)
message_schemas = router_data.get("message_schemas", [])
if not message_schemas:
raise ValueError(
f"No message models found for handler '{method_name}'."
)
wrapped_method = self._create_wrapped_method(method)
topic_key = (pubsub_name, topic_name)
self._topic_handlers.setdefault(topic_key, {})
for schema in message_schemas:
if not is_valid_routable_model(schema):
raise ValueError(
f"Unsupported message model for handler '{method_name}': {schema}"
)
schema_name = schema.__name__
logger.debug(
f"Registering handler '{method_name}' for topic '{topic_name}' with model '{schema_name}'"
)
# Prevent multiple handlers for the same schema
if schema_name in self._topic_handlers[topic_key]:
raise ValueError(
f"Duplicate handler for model '{schema_name}' on topic '{topic_name}'. "
f"Each model can only be handled by one function per topic."
)
self._topic_handlers[topic_key][schema_name] = {
"schema": schema,
"handler": wrapped_method,
}
except Exception as e:
logger.error(
f"Failed to register handler '{method_name}': {e}", exc_info=True
)
# Subscribe once per topic
for pubsub_name, topic_name in self._topic_handlers.keys():
self._subscribe_with_router(pubsub_name, topic_name)
logger.info("All message routes registered.")
def _create_wrapped_method(self, method: Callable) -> Callable:
"""
Wraps a message handler method to ensure it runs asynchronously,
with special handling for workflows.
"""
@functools.wraps(method)
async def wrapped_method(message: dict):
try:
if getattr(method, "_is_workflow", False):
workflow_name = getattr(method, "_workflow_name", method.__name__)
instance_id = self.run_workflow(workflow_name, input=message)
asyncio.create_task(self.monitor_workflow_completion(instance_id))
return None
if inspect.iscoroutinefunction(method):
return await method(message=message)
else:
return method(message=message)
except Exception as e:
logger.error(
f"Error invoking handler '{method.__name__}': {e}", exc_info=True
)
return None
return wrapped_method
def _subscribe_with_router(self, pubsub_name: str, topic_name: str):
subscription: Subscription = self._dapr_client.subscribe(
pubsub_name, topic_name
)
loop = asyncio.get_running_loop()
def stream_messages(sub: Subscription):
while True:
try:
for message in sub:
if message:
try:
future = asyncio.run_coroutine_threadsafe(
self._route_message(
pubsub_name, topic_name, message
),
loop,
)
response = future.result()
sub.respond(message, response.status)
except Exception as e:
print(f"Error handling message: {e}")
else:
continue
except (StreamInactiveError, StreamCancelledError):
break
def close_subscription():
subscription.close()
self._subscriptions[(pubsub_name, topic_name)] = close_subscription
threading.Thread(
target=stream_messages, args=(subscription,), daemon=True
).start()
async def _route_message(
self, pubsub_name: str, topic_name: str, message: SubscriptionMessage
) -> TopicEventResponse:
"""
Routes an incoming message to the correct handler based on CloudEvent `type`.
Args:
pubsub_name (str): The name of the pubsub component.
topic_name (str): The topic from which the message was received.
message (SubscriptionMessage): The incoming Dapr message.
Returns:
TopicEventResponse: The response status for the message (success, drop, retry).
"""
try:
handler_map = self._topic_handlers.get((pubsub_name, topic_name), {})
if not handler_map:
logger.warning(
f"No handlers for topic '{topic_name}' on pubsub '{pubsub_name}'. Dropping message."
)
return TopicEventResponse("drop")
# Step 1: Extract CloudEvent metadata and data
event_data, metadata = extract_cloudevent_data(message)
event_type = metadata.get("type")
route_entry = handler_map.get(event_type)
if not route_entry:
logger.warning(
f"No handler matched CloudEvent type '{event_type}' on topic '{topic_name}'"
)
return TopicEventResponse("drop")
schema = route_entry["schema"]
handler = route_entry["handler"]
try:
parsed_message = validate_message_model(schema, event_data)
parsed_message["_message_metadata"] = metadata
logger.info(
f"Dispatched to handler '{handler.__name__}' for event type '{event_type}'"
)
result = await handler(parsed_message)
if result is not None:
return TopicEventResponse("success"), result
return TopicEventResponse("success")
except Exception as e:
logger.warning(
f"Failed to validate message against schema '{schema.__name__}': {e}"
)
return TopicEventResponse("retry")
except Exception as e:
logger.error(f"Unexpected error during message routing: {e}", exc_info=True)
return TopicEventResponse("retry")