mirror of https://github.com/dapr/dapr-agents.git
232 lines
8.0 KiB
Python
232 lines
8.0 KiB
Python
from dapr_agents.workflow.messaging.decorator import message_router
|
|
from dapr_agents.workflow.orchestrators.base import OrchestratorWorkflowBase
|
|
from dapr.ext.workflow import DaprWorkflowContext
|
|
from dapr_agents.types import BaseMessage
|
|
from dapr_agents.workflow.decorators import workflow, task
|
|
from typing import Any, Optional, Dict
|
|
from pydantic import BaseModel, Field
|
|
from datetime import timedelta
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class AgentTaskResponse(BaseMessage):
    """
    Message an agent sends back to the orchestrator after finishing a task.

    Extends BaseMessage with an optional ``workflow_instance_id``. The
    orchestrator's ``process_agent_response`` reads this id to decide which
    workflow instance receives the "AgentTaskResponse" event; responses
    without it are ignored.
    """

    # Id of the Dapr workflow instance that triggered the agent, if known.
    workflow_instance_id: Optional[str] = Field(
        description="Dapr workflow instance id from source if available",
        default=None,
    )
|
|
|
|
|
|
class TriggerAction(BaseModel):
    """
    Message used to trigger an agent's activity within the workflow.

    The round-robin orchestrator sends one of these to the selected agent each
    round. The same key layout ("task", "iteration") is also used as the
    workflow's input state.

    Attributes:
        task: Optional task text for the agent. When omitted, the agent acts
            on its memory or predefined behavior.
        iteration: Zero-based round counter of the orchestrating workflow
            (defaults to 0).
        workflow_instance_id: Id of the Dapr workflow instance that issued the
            trigger, if available.
    """

    task: Optional[str] = Field(
        None,
        description="The specific task to execute. If not provided, the agent will act based on its memory or predefined behavior.",
    )
    # Previously had an empty description; documented so the schema is self-describing.
    iteration: Optional[int] = Field(
        0,
        description="Zero-based iteration (round) counter of the orchestrating workflow.",
    )
    workflow_instance_id: Optional[str] = Field(
        default=None, description="Dapr workflow instance id from source if available"
    )
|
|
|
|
|
|
class RoundRobinOrchestrator(OrchestratorWorkflowBase):
    """
    Implements a round-robin workflow where agents take turns performing tasks.

    Each workflow run handles one iteration:
      1. Pick the next agent in circular order.
      2. Trigger that agent.
      3. Wait for its response or a timeout.
      4. Call ``continue_as_new`` with the updated task and iteration counter,
         so iteration state is carried into the next run.
    """

    def model_post_init(self, __context: Any) -> None:
        """
        Initializes and configures the round-robin workflow service.

        Sets the workflow name, then delegates to the base class
        initialization, which registers tasks and workflows and starts the
        workflow runtime.

        Args:
            __context (Any): Pydantic post-init context, forwarded unchanged.
        """
        # Set before super(); presumably the base class reads it during init.
        self._workflow_name = "RoundRobinWorkflow"
        super().model_post_init(__context)

    @workflow(name="RoundRobinWorkflow")
    def main_workflow(self, ctx: DaprWorkflowContext, input: TriggerAction):
        """
        Executes one iteration of the round-robin workflow.

        Steps:
            1. On the first iteration, processes the input and broadcasts the
               initial message to all agents.
            2. Selects the next speaker in round-robin order.
            3. Triggers that agent and waits for its response or a timeout.
            4. Restarts the workflow via ``continue_as_new`` with the updated
               task and iteration counter.
            5. Terminates once ``max_iterations`` is reached.

        Orchestrator code re-executes on every replay, so every log call is
        guarded by ``ctx.is_replaying`` to emit each message only once.

        Args:
            ctx (DaprWorkflowContext): The workflow execution context.
            input (TriggerAction): The current workflow state. It is read with
                mapping access (``.get`` / ``[...]``), so it is expected to
                arrive as a dict with "task" and "iteration" keys.

        Returns:
            str: The last processed message when the workflow terminates.
        """
        task = input.get("task")
        iteration = input.get("iteration", 0)
        instance_id = ctx.instance_id

        if not ctx.is_replaying:
            logger.info(
                f"Round-robin iteration {iteration + 1} started (Instance ID: {instance_id})."
            )

        # Check termination condition
        if iteration >= self.max_iterations:
            if not ctx.is_replaying:
                logger.info(
                    f"Max iterations reached. Ending round-robin workflow (Instance ID: {instance_id})."
                )
            return task

        # First iteration: process the input and broadcast it to every agent
        if iteration == 0:
            message = yield ctx.call_activity(self.process_input, input={"task": task})
            if not ctx.is_replaying:
                logger.info(f"Initial message from {message['role']} -> {self.name}")

            yield ctx.call_activity(
                self.broadcast_message_to_agents, input={"message": message}
            )

        # Select the next speaker in circular order
        next_speaker = yield ctx.call_activity(
            self.select_next_speaker, input={"iteration": iteration}
        )

        # Trigger the selected agent
        yield ctx.call_activity(
            self.trigger_agent, input={"name": next_speaker, "instance_id": instance_id}
        )

        # Wait for the agent's response or a timeout, whichever comes first
        if not ctx.is_replaying:
            logger.info("Waiting for agent response...")
        event_data = ctx.wait_for_external_event("AgentTaskResponse")
        timeout_task = ctx.create_timer(timedelta(seconds=self.timeout))
        any_results = yield self.when_any([event_data, timeout_task])

        # when_any yields whichever task finished first; compare by identity.
        if any_results is timeout_task:
            if not ctx.is_replaying:
                logger.warning(
                    f"Agent response timed out (Iteration: {iteration + 1}, Instance ID: {instance_id})."
                )
            task_results = {
                "name": "timeout",
                "content": "Timeout occurred. Continuing...",
            }
        else:
            task_results = yield event_data
            if not ctx.is_replaying:
                logger.info(f"{task_results['name']} -> {self.name}")

        # Build the next run's state from a copy instead of mutating the input.
        next_input = dict(input)
        next_input["task"] = task_results["content"]
        next_input["iteration"] = iteration + 1

        # Restart the workflow instance with the updated state
        ctx.continue_as_new(next_input)

    @task
    async def process_input(self, task: str) -> Dict[str, Any]:
        """
        Processes the input message for the workflow.

        Args:
            task (str): The user-provided input task.

        Returns:
            dict: Message dict with role "user", this orchestrator's name as
            "name", and the task as "content".
        """
        return {"role": "user", "name": self.name, "content": task}

    @task
    async def broadcast_message_to_agents(self, message: Dict[str, Any]):
        """
        Broadcasts a message to all agents, excluding the orchestrator itself.

        Args:
            message (Dict[str, Any]): Keyword fields used to build a BaseMessage
                (e.g. role, name, content).
        """
        await self.broadcast_message(
            message=BaseMessage(**message), exclude_orchestrator=True
        )

    @task
    async def select_next_speaker(self, iteration: int) -> str:
        """
        Selects the next speaker in round-robin order.

        Args:
            iteration (int): The current iteration number. It is taken modulo
                the number of registered agents.

        Returns:
            str: The name of the selected agent.

        Raises:
            ValueError: If no agents (other than the orchestrator) are registered.
        """
        agents_metadata = self.get_agents_metadata(exclude_orchestrator=True)
        if not agents_metadata:
            logger.warning("No agents available for selection.")
            raise ValueError("Agents metadata is empty. Cannot select next speaker.")

        agent_names = list(agents_metadata.keys())

        # Determine the next agent in round-robin order
        next_speaker = agent_names[iteration % len(agent_names)]
        logger.info(
            f"{self.name} selected agent {next_speaker} for iteration {iteration}."
        )
        return next_speaker

    @task
    async def trigger_agent(self, name: str, instance_id: str) -> None:
        """
        Triggers the specified agent to perform its activity.

        Sends a TriggerAction carrying only the workflow instance id (no task
        text), so the agent acts on its memory of previously broadcast
        messages.

        Args:
            name (str): Name of the agent to trigger.
            instance_id (str): Workflow instance ID. The agent is expected to
                echo it back in its AgentTaskResponse.
        """
        await self.send_message_to_agent(
            name=name,
            message=TriggerAction(workflow_instance_id=instance_id),
        )

    @message_router
    async def process_agent_response(self, message: AgentTaskResponse):
        """
        Processes agent response messages sent directly to the agent's topic.

        Forwards the response to the waiting workflow instance as an
        "AgentTaskResponse" event. Handling is best-effort: a missing instance
        id or any raised exception is logged, never propagated.

        Args:
            message (AgentTaskResponse): The agent's response containing task
                results. It is read with ``.get``, so it is expected to arrive
                as a dict.

        Returns:
            None: The function raises a workflow event with the agent's response.
        """
        try:
            workflow_instance_id = message.get("workflow_instance_id")

            if not workflow_instance_id:
                logger.error(
                    f"{self.name} received an agent response without a valid workflow_instance_id. Ignoring."
                )
                return

            logger.info(
                f"{self.name} processing agent response for workflow instance '{workflow_instance_id}'."
            )

            # Raise a workflow event with the agent's task response
            self.raise_workflow_event(
                instance_id=workflow_instance_id,
                event_name="AgentTaskResponse",
                data=message,
            )

        except Exception as e:
            # Message-handler boundary: log with traceback, don't crash the subscriber.
            logger.exception(f"Error processing agent response: {e}")
|