dapr-agents/dapr_agents/workflow/orchestrators/roundrobin.py

232 lines
8.0 KiB
Python

from dapr_agents.workflow.messaging.decorator import message_router
from dapr_agents.workflow.orchestrators.base import OrchestratorWorkflowBase
from dapr.ext.workflow import DaprWorkflowContext
from dapr_agents.types import BaseMessage
from dapr_agents.workflow.decorators import workflow, task
from typing import Any, Optional, Dict
from pydantic import BaseModel, Field
from datetime import timedelta
import logging
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class AgentTaskResponse(BaseMessage):
    """
    Message an agent sends back to the orchestrator once it has finished a task.

    Extends ``BaseMessage`` with the id of the workflow instance that requested
    the work, so the orchestrator can raise the response as an event on that
    specific workflow instance.
    """

    # Optional: responses without it are ignored by the orchestrator's router.
    workflow_instance_id: Optional[str] = Field(
        None, description="Dapr workflow instance id from source if available"
    )
class TriggerAction(BaseModel):
    """
    Represents a message used to trigger an agent's activity within the workflow.

    The round-robin orchestrator also uses this model as the schema of its
    workflow input, carrying the current task text and iteration counter
    between ``continue_as_new`` cycles.
    """

    task: Optional[str] = Field(
        None,
        description="The specific task to execute. If not provided, the agent will act based on its memory or predefined behavior.",
    )
    # Previously declared with an empty description; filled in so the schema is self-documenting.
    iteration: Optional[int] = Field(
        0,
        description="Zero-based iteration (round) counter of the orchestration loop.",
    )
    workflow_instance_id: Optional[str] = Field(
        default=None, description="Dapr workflow instance id from source if available"
    )
class RoundRobinOrchestrator(OrchestratorWorkflowBase):
    """
    Implements a round-robin workflow where agents take turns performing tasks.

    Each workflow execution handles exactly one iteration: it selects the next
    agent in circular order, triggers it, waits for its response (or a timeout),
    and then restarts itself via ``continue_as_new`` with the updated task and
    iteration counter. The loop ends once ``max_iterations`` is reached.
    """

    def model_post_init(self, __context: Any) -> None:
        """
        Set the workflow name, then delegate remaining initialization to the base class.

        NOTE(review): the base class presumably registers tasks/workflows and
        starts the workflow runtime; confirm in ``OrchestratorWorkflowBase``.
        """
        self._workflow_name = "RoundRobinWorkflow"
        super().model_post_init(__context)

    @workflow(name="RoundRobinWorkflow")
    def main_workflow(self, ctx: DaprWorkflowContext, input: TriggerAction):
        """
        Execute one round-robin iteration, then continue as a new execution.

        Steps:
            1. On iteration 0, process the input task and broadcast it to all agents.
            2. Select the next speaker by ``iteration % number_of_agents``.
            3. Trigger that agent and wait for its ``AgentTaskResponse`` event,
               or for a timer of ``self.timeout`` seconds.
            4. Call ``continue_as_new`` with the response content as the new
               task and the iteration counter incremented.
            5. Return the current task once ``self.max_iterations`` is reached.

        Args:
            ctx (DaprWorkflowContext): The workflow execution context.
            input (TriggerAction): The current workflow state (task and iteration).
                It is accessed as a mapping here (``.get``), so the decorator
                presumably delivers it as a dict; confirm in ``@workflow``.

        Returns:
            str: The last processed task/message when the workflow terminates.
        """
        task = input.get("task")
        # ``iteration`` is Optional; an explicit None must not break the
        # arithmetic and comparisons below.
        iteration = input.get("iteration") or 0
        instance_id = ctx.instance_id

        # Logs are guarded so they are not repeated each time the
        # orchestrator replays its history.
        if not ctx.is_replaying:
            logger.info(
                f"Round-robin iteration {iteration + 1} started (Instance ID: {instance_id})."
            )

        # Termination condition
        if iteration >= self.max_iterations:
            if not ctx.is_replaying:
                logger.info(
                    f"Max iterations reached. Ending round-robin workflow (Instance ID: {instance_id})."
                )
            return task

        # First iteration only: turn the raw task into a message and broadcast it.
        if iteration == 0:
            message = yield ctx.call_activity(self.process_input, input={"task": task})
            if not ctx.is_replaying:
                logger.info(f"Initial message from {message['role']} -> {self.name}")
            yield ctx.call_activity(
                self.broadcast_message_to_agents, input={"message": message}
            )

        # Pick whose turn it is and ask that agent to act.
        next_speaker = yield ctx.call_activity(
            self.select_next_speaker, input={"iteration": iteration}
        )
        yield ctx.call_activity(
            self.trigger_agent, input={"name": next_speaker, "instance_id": instance_id}
        )

        # Race the agent's response against a timer.
        if not ctx.is_replaying:
            logger.info("Waiting for agent response...")
        event_data = ctx.wait_for_external_event("AgentTaskResponse")
        timeout_task = ctx.create_timer(timedelta(seconds=self.timeout))
        winner = yield self.when_any([event_data, timeout_task])

        if winner == timeout_task:
            if not ctx.is_replaying:
                logger.warning(
                    f"Agent response timed out (Iteration: {iteration + 1}, Instance ID: {instance_id})."
                )
            task_results = {
                "name": "timeout",
                "content": "Timeout occurred. Continuing...",
            }
        else:
            # The event task already completed (it won the race), so its
            # payload can be read directly.
            task_results = event_data.get_result()
            if not ctx.is_replaying:
                logger.info(f"{task_results['name']} -> {self.name}")

        # Restart the workflow with the updated state instead of looping, so
        # the history does not grow without bound.
        ctx.continue_as_new(
            {**input, "task": task_results["content"], "iteration": iteration + 1}
        )

    @task
    async def process_input(self, task: str) -> Dict[str, Any]:
        """
        Wrap the user-provided task in a user-role message.

        Args:
            task (str): The user-provided input task.

        Returns:
            dict: ``{"role": "user", "name": <orchestrator name>, "content": task}``.
        """
        return {"role": "user", "name": self.name, "content": task}

    @task
    async def broadcast_message_to_agents(self, message: Dict[str, Any]):
        """
        Broadcast a message to all agents, excluding the orchestrator itself.

        Args:
            message (Dict[str, Any]): Fields used to build a ``BaseMessage``.
        """
        await self.broadcast_message(
            message=BaseMessage(**message), exclude_orchestrator=True
        )

    @task
    async def select_next_speaker(self, iteration: int) -> str:
        """
        Select the next speaker in round-robin order.

        Args:
            iteration (int): The current iteration number.

        Returns:
            str: The name of the agent at index ``iteration % len(agents)``.

        Raises:
            ValueError: If no agents (other than the orchestrator) are registered.
        """
        agents_metadata = self.get_agents_metadata(exclude_orchestrator=True)
        if not agents_metadata:
            logger.warning("No agents available for selection.")
            raise ValueError("Agents metadata is empty. Cannot select next speaker.")

        # Rotation order follows the key order of the metadata mapping.
        agent_names = list(agents_metadata.keys())
        next_speaker = agent_names[iteration % len(agent_names)]
        logger.info(
            f"{self.name} selected agent {next_speaker} for iteration {iteration}."
        )
        return next_speaker

    @task
    async def trigger_agent(self, name: str, instance_id: str) -> None:
        """
        Send a ``TriggerAction`` to the named agent.

        Args:
            name (str): Name of the agent to trigger.
            instance_id (str): Workflow instance ID the agent should reply to.
        """
        await self.send_message_to_agent(
            name=name,
            message=TriggerAction(workflow_instance_id=instance_id),
        )

    @message_router
    async def process_agent_response(self, message: AgentTaskResponse):
        """
        Forward an agent's response to the waiting workflow instance.

        Raises an ``AgentTaskResponse`` workflow event on the instance named by
        the message's ``workflow_instance_id``. Messages without that id are
        ignored. Any error is logged and swallowed so one bad message does not
        break the handler.

        Args:
            message (AgentTaskResponse): The agent's response. It is accessed as a
                mapping (``.get``); presumably the router delivers a dict — confirm.

        Returns:
            None
        """
        try:
            workflow_instance_id = message.get("workflow_instance_id")
            if not workflow_instance_id:
                logger.error(
                    f"{self.name} received an agent response without a valid workflow_instance_id. Ignoring."
                )
                return

            logger.info(
                f"{self.name} processing agent response for workflow instance '{workflow_instance_id}'."
            )
            self.raise_workflow_event(
                instance_id=workflow_instance_id,
                event_name="AgentTaskResponse",
                data=message,
            )
        except Exception as e:
            # Deliberate best-effort: log with traceback and keep the subscriber alive.
            logger.error(f"Error processing agent response: {e}", exc_info=True)