# Mirror of https://github.com/dapr/dapr-agents.git (Python, 230 lines, 9.3 KiB)
import asyncio
|
|
import inspect
|
|
import logging
|
|
import threading
|
|
import functools
|
|
from typing import Callable
|
|
|
|
from dapr.aio.clients.grpc.subscription import Subscription
|
|
from dapr.clients.grpc._response import TopicEventResponse
|
|
from dapr.clients.grpc.subscription import StreamInactiveError
|
|
from dapr.common.pubsub.subscription import StreamCancelledError, SubscriptionMessage
|
|
from dapr_agents.workflow.messaging.parser import (
|
|
extract_cloudevent_data,
|
|
validate_message_model,
|
|
)
|
|
from dapr_agents.workflow.messaging.utils import is_valid_routable_model
|
|
from dapr_agents.workflow.utils import get_decorated_methods
|
|
|
|
# Module-level logger for the message-routing mixin, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class MessageRoutingMixin:
    """
    Mixin class providing dynamic message routing capabilities for agentic services using Dapr pub/sub.

    This mixin enables:
    - Auto-registration of message handlers via the `@message_router` decorator.
    - CloudEvent-based dispatch to appropriate handlers based on `type`.
    - Topic subscription management and graceful shutdown.
    - Support for both synchronous and asynchronous handler methods.
    - Workflow-aware message handling for registered workflow entrypoints.

    Expected attributes provided by the consuming service:
    - `self._dapr_client`: A configured Dapr client instance.
    - `self.name`: The agent's name (used for default topic routing).
    - `self.message_bus_name`: Pub/Sub component name in Dapr.
    - `self.broadcast_topic_name`: Optional default topic name for broadcasts.
    - `self._topic_handlers`: Dict storing routing info by (pubsub, topic).
    - `self._subscriptions`: Dict storing unsubscribe functions for active subscriptions.

    Workflow entrypoint handlers additionally rely on `self.run_workflow(name, input=...)`
    and on `self.monitor_workflow_completion(instance_id)` returning an awaitable
    coroutine (it is scheduled with `asyncio.create_task`).
    """

    def register_message_routes(self) -> None:
        """
        Registers message handlers dynamically by subscribing once per topic.
        Incoming messages are dispatched by CloudEvent `type` to the appropriate handler.

        This function:
        - Scans all class methods for the `@message_router` decorator.
        - Extracts routing metadata and message model schemas.
        - Wraps each handler and maps it by `(pubsub_name, topic_name)` and schema name.
        - Ensures only one handler per schema per topic is allowed.

        Registration is best-effort per handler: a handler with invalid routing
        metadata (no schemas, an unsupported schema, or a schema already claimed
        on the same topic) is logged and skipped, and none of its schemas are
        registered. Only topics that end up with at least one handler are
        subscribed.

        Must be called while an asyncio event loop is running in the current
        thread, because every subscription dispatches its messages onto that loop.
        """
        message_handlers = get_decorated_methods(self, "_is_message_handler")

        for method_name, method in message_handlers.items():
            try:
                self._register_handler(method_name, method)
            except Exception as e:
                logger.error(
                    f"Failed to register handler '{method_name}': {e}", exc_info=True
                )

        # Subscribe once per topic. A topic with no handlers would only ever
        # drop its messages, so it is not subscribed.
        for (pubsub_name, topic_name), handlers in self._topic_handlers.items():
            if handlers:
                self._subscribe_with_router(pubsub_name, topic_name)

        logger.info("All message routes registered.")

    def _register_handler(self, method_name: str, method: Callable) -> None:
        """
        Validates one decorated handler and adds it to `self._topic_handlers`.

        Every schema is validated before anything is written, so a handler that
        fails validation leaves the routing table unchanged.

        Args:
            method_name (str): Name of the decorated method (used in messages).
            method (Callable): The decorated method carrying `_message_router_data`.

        Raises:
            ValueError: If the handler declares no message models, declares an
                unsupported model, or a model name is already handled on the topic.
        """
        router_data = method._message_router_data.copy()
        pubsub_name = router_data.get("pubsub") or self.message_bus_name
        is_broadcast = router_data.get("is_broadcast", False)
        topic_name = router_data.get("topic") or (
            self.broadcast_topic_name if is_broadcast else self.name
        )
        message_schemas = router_data.get("message_schemas", [])

        if not message_schemas:
            raise ValueError(f"No message models found for handler '{method_name}'.")

        topic_key = (pubsub_name, topic_name)
        existing = self._topic_handlers.get(topic_key, {})

        # Phase 1: validate everything; collect schemas keyed by model name.
        pending = {}
        for schema in message_schemas:
            if not is_valid_routable_model(schema):
                raise ValueError(
                    f"Unsupported message model for handler '{method_name}': {schema}"
                )

            schema_name = schema.__name__
            logger.debug(
                f"Registering handler '{method_name}' for topic '{topic_name}' with model '{schema_name}'"
            )

            if pending.get(schema_name) is schema:
                # The same model listed twice on one handler is harmless.
                continue

            # Prevent multiple handlers for the same schema
            if schema_name in existing or schema_name in pending:
                raise ValueError(
                    f"Duplicate handler for model '{schema_name}' on topic '{topic_name}'. "
                    f"Each model can only be handled by one function per topic."
                )
            pending[schema_name] = schema

        # Phase 2: commit all schemas for this handler at once.
        wrapped_method = self._create_wrapped_method(method)
        handlers = self._topic_handlers.setdefault(topic_key, {})
        for schema_name, schema in pending.items():
            handlers[schema_name] = {
                "schema": schema,
                "handler": wrapped_method,
            }

    def _create_wrapped_method(self, method: Callable) -> Callable:
        """
        Wraps a message handler method so it can always be awaited,
        with special handling for workflows.

        - Methods flagged `_is_workflow` are started via `self.run_workflow`
          and their completion is monitored in a background task; the wrapper
          returns None.
        - Coroutine functions are awaited; plain functions are called directly.
        - Any exception raised while invoking the handler is logged and the
          wrapper returns None (the message is therefore still acknowledged).
        """

        @functools.wraps(method)
        async def wrapped_method(message: dict):
            try:
                if getattr(method, "_is_workflow", False):
                    workflow_name = getattr(method, "_workflow_name", method.__name__)
                    instance_id = self.run_workflow(workflow_name, input=message)
                    self._spawn_background_task(
                        self.monitor_workflow_completion(instance_id)
                    )
                    return None

                if inspect.iscoroutinefunction(method):
                    return await method(message=message)
                return method(message=message)

            except Exception as e:
                logger.error(
                    f"Error invoking handler '{method.__name__}': {e}", exc_info=True
                )
                return None

        return wrapped_method

    def _spawn_background_task(self, coro) -> "asyncio.Task":
        """
        Schedules `coro` on the running event loop and keeps a strong reference to it.

        The event loop keeps only weak references to tasks, so a task from
        `asyncio.create_task` that is not stored anywhere may be garbage
        collected before it completes. The reference is released once the task
        finishes. This runs on the event loop thread (called from the wrapped
        handler), so the reference set is only touched from that thread.
        """
        tasks = getattr(self, "_message_router_tasks", None)
        if tasks is None:
            tasks = set()
            self._message_router_tasks = tasks
        task = asyncio.create_task(coro)
        tasks.add(task)
        task.add_done_callback(tasks.discard)
        return task

    def _subscribe_with_router(self, pubsub_name: str, topic_name: str):
        """
        Opens a streaming subscription for one topic and routes its messages.

        A daemon thread iterates the subscription. For each message it schedules
        `self._route_message` on the event loop that was running when this
        method was called, waits for the result, and reports the returned status
        via `sub.respond`. The thread stops when the stream becomes inactive or
        is cancelled, or after logging any other stream error. A callable that
        closes the subscription is stored in
        `self._subscriptions[(pubsub_name, topic_name)]`.

        Raises:
            RuntimeError: If no event loop is running in the calling thread.
        """
        # Resolve the loop before opening the stream so a missing loop does not
        # leave an open subscription behind.
        loop = asyncio.get_running_loop()

        # NOTE(review): `Subscription` is imported from the aio package, yet the
        # object is iterated synchronously below; confirm the client type.
        subscription: Subscription = self._dapr_client.subscribe(
            pubsub_name, topic_name
        )

        def handle_one(sub: Subscription, message) -> None:
            """Route one message on the event loop and acknowledge it."""
            try:
                future = asyncio.run_coroutine_threadsafe(
                    self._route_message(pubsub_name, topic_name, message),
                    loop,
                )
                response = future.result()
            except Exception as e:
                logger.error(
                    f"Error routing message on topic '{topic_name}': {e}",
                    exc_info=True,
                )
                # Always answer the sidecar; ask for redelivery on failure.
                response = TopicEventResponse("retry")

            try:
                sub.respond(message, response.status)
            except Exception as e:
                logger.error(f"Error handling message: {e}", exc_info=True)

        def stream_messages(sub: Subscription):
            while True:
                try:
                    for message in sub:
                        if message:
                            handle_one(sub, message)
                except (StreamInactiveError, StreamCancelledError):
                    break
                except Exception as e:
                    logger.error(
                        f"Subscription stream for topic '{topic_name}' on pubsub "
                        f"'{pubsub_name}' failed: {e}",
                        exc_info=True,
                    )
                    break

        def close_subscription():
            subscription.close()

        self._subscriptions[(pubsub_name, topic_name)] = close_subscription
        threading.Thread(
            target=stream_messages, args=(subscription,), daemon=True
        ).start()

    async def _route_message(
        self, pubsub_name: str, topic_name: str, message: SubscriptionMessage
    ) -> TopicEventResponse:
        """
        Routes an incoming message to the correct handler based on CloudEvent `type`.

        Args:
            pubsub_name (str): The name of the pubsub component.
            topic_name (str): The topic from which the message was received.
            message (SubscriptionMessage): The incoming Dapr message.

        Returns:
            TopicEventResponse: The response status for the message:
            "drop" when no handler is registered for the topic or the event type;
            "retry" when the payload fails schema validation or routing fails
            unexpectedly; otherwise "success". A handler's own return value is
            not propagated (the caller only consumes the status).
        """
        try:
            handler_map = self._topic_handlers.get((pubsub_name, topic_name), {})
            if not handler_map:
                logger.warning(
                    f"No handlers for topic '{topic_name}' on pubsub '{pubsub_name}'. Dropping message."
                )
                return TopicEventResponse("drop")

            # Step 1: Extract CloudEvent metadata and data
            event_data, metadata = extract_cloudevent_data(message)
            event_type = metadata.get("type")

            route_entry = handler_map.get(event_type)
            if not route_entry:
                logger.warning(
                    f"No handler matched CloudEvent type '{event_type}' on topic '{topic_name}'"
                )
                return TopicEventResponse("drop")

            schema = route_entry["schema"]
            handler = route_entry["handler"]

            # Step 2: Validate the payload against the handler's model.
            try:
                parsed_message = validate_message_model(schema, event_data)
                parsed_message["_message_metadata"] = metadata
            except Exception as e:
                logger.warning(
                    f"Failed to validate message against schema '{schema.__name__}': {e}"
                )
                return TopicEventResponse("retry")

            # Step 3: Dispatch. The wrapped handler logs and swallows its own errors.
            logger.info(
                f"Dispatched to handler '{handler.__name__}' for event type '{event_type}'"
            )
            result = await handler(parsed_message)
            if result is not None:
                logger.debug(
                    f"Handler '{handler.__name__}' returned a value; it is not "
                    f"included in the subscription response."
                )

            return TopicEventResponse("success")

        except Exception as e:
            logger.error(f"Unexpected error during message routing: {e}", exc_info=True)
            return TopicEventResponse("retry")
|