mirror of https://github.com/dapr/dapr-agents.git

Refactor LLM Workflows and Orchestrators for Unified Response Handling and Iteration (#163)

* Refactor ChatClientBase: drop Pydantic inheritance and add typed generate() overloads
* Align all LLM chat clients with refactored base and unified response models
* Unify LLM utils across providers and delegate streaming/response to provider-specific handlers
* Refactor LLM pipeline: add HuggingFace tool calls, unify chat client/response types, and switch DurableAgent to loop-based workflow
* Refactor orchestrators with loops and unify LLM response handling using LLMChatResponse
* Test remaining quickstarts after all changes
* Run pytest after all changes
* Run linting and formatting checks to ensure code quality
* Update logging, Orchestrator Name and OTel module name

Signed-off-by: Roberto Rodriguez <9653181+Cyb3rWard0g@users.noreply.github.com>
Co-authored-by: Yaron Schneider <schneider.yaron@live.com>

This commit is contained in:
parent 3e767e03fb
commit 29edfc419b
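
At the core of this change, every chat client now returns the unified LLMChatResponse model and agents/orchestrators consume it through one calling pattern. A minimal sketch of that pattern, based on the methods that appear in the hunks below (illustrative only; provider setup and error handling omitted):

    from dapr_agents.llm.openai import OpenAIChatClient

    llm = OpenAIChatClient()  # any ChatClientBase implementation works the same way
    response = llm.generate(messages=[{"role": "user", "content": "What is Dapr?"}])

    assistant = response.get_message()            # first candidate, or None
    if assistant is not None and assistant.has_tool_calls():
        for tool_call in assistant.get_tool_calls():
            ...                                   # execute tools, then feed ToolMessage results back
    else:
        print(assistant.content if assistant else "no assistant message")
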
@@ -1,30 +1,20 @@
import asyncio
import logging
from enum import Enum
from typing import Any, Dict, List, Optional, Union

from dapr_agents.agents.base import AgentBase
from dapr_agents.types import (
    AgentError,
    AssistantMessage,
    ChatCompletion,
    ToolCall,
    ToolExecutionRecord,
    ToolMessage,
    UserMessage,
    LLMChatResponse,
)

logger = logging.getLogger(__name__)


class FinishReason(str, Enum):
    STOP = "stop"
    LENGTH = "length"
    CONTENT_FILTER = "content_filter"
    TOOL_CALLS = "tool_calls"
    FUNCTION_CALL = "function_call"  # deprecated


class Agent(AgentBase):
    """
    Agent that manages tool calls and conversations using a language model.

@@ -168,8 +158,10 @@ class Agent(AgentBase):
                        name=function_name,
                        content=result_str,
                    )
                    # Printing the tool message for visibility
                    # Print the tool message for visibility
                    self.text_formatter.print_message(tool_message)
                    # Add tool message to memory
                    self.memory.add_message(tool_message)
                    # Append tool message to the persistent audit log
                    tool_execution_record = ToolExecutionRecord(
                        tool_call_id=tool_id,

@@ -201,69 +193,54 @@ class Agent(AgentBase):
        Raises:
            AgentError: On chat failure or tool issues.
        """
        for iteration in range(self.max_iterations):
            logger.info(f"Iteration {iteration + 1}/{self.max_iterations} started.")

        final_reply = None
        for turn in range(1, self.max_iterations + 1):
            logger.info(f"Iteration {turn}/{self.max_iterations} started.")
            try:
                # Generate response using the LLM
                response = self.llm.generate(
                response: LLMChatResponse = self.llm.generate(
                    messages=messages,
                    tools=self.get_llm_tools(),
                    tool_choice=self.tool_choice,
                )
                # If response is a dict, convert to ChatCompletion
                if isinstance(response, dict):
                    response = ChatCompletion(**response)
                elif not isinstance(response, ChatCompletion):
                    # If response is an iterator (stream), raise TypeError
                    raise TypeError(f"Expected ChatCompletion, got {type(response)}")
                # Get the response message and print it
                # Get the first candidate from the response
                response_message = response.get_message()
                if response_message is not None:
                    self.text_formatter.print_message(response_message)

                # Get Reason for the response
                reason = FinishReason(response.get_reason())
                # Check if the response contains an assistant message
                if response_message is None:
                    raise AgentError("LLM returned no assistant message")
                else:
                    assistant = response_message
                    self.text_formatter.print_message(assistant)
                    self.memory.add_message(assistant)

                # Handle tool calls response
                if reason == FinishReason.TOOL_CALLS:
                    tool_calls = response.get_tool_calls()
                if assistant is not None and assistant.has_tool_calls():
                    tool_calls = assistant.get_tool_calls()
                    if tool_calls:
                        # Add the assistant message with tool calls to the conversation
                        if response_message is not None:
                            messages.append(response_message)
                        # Execute tools and collect results for this iteration only
                        tool_messages = await self.execute_tools(tool_calls)
                        # Add tool results to messages for the next iteration
                        messages.extend([tm.model_dump() for tm in tool_messages])
                        # Continue to next iteration to let LLM process tool results
                        messages.append(assistant.model_dump())
                        tool_msgs = await self.execute_tools(tool_calls)
                        messages.extend([tm.model_dump() for tm in tool_msgs])
                        if turn == self.max_iterations:
                            final_reply = assistant
                            logger.info("Reached max turns after tool calls; stopping.")
                            break
                        continue
                # Handle stop response
                elif reason == FinishReason.STOP:
                    # Append AssistantMessage to memory
                    msg = AssistantMessage(content=response.get_content() or "")
                    self.memory.add_message(msg)
                    return msg.content
                # Handle Function call response
                elif reason == FinishReason.FUNCTION_CALL:
                    logger.warning(
                        "LLM returned a deprecated function_call. Function calls are not processed by this agent."
                    )
                    msg = AssistantMessage(
                        content="Function calls are not supported or processed by this agent."
                    )
                    self.memory.add_message(msg)
                    return msg.content
                else:
                    logger.error(f"Unknown finish reason: {reason}")
                    raise AgentError(f"Unknown finish reason: {reason}")

                # No tool calls => done
                final_reply = assistant
                break

            except Exception as e:
                logger.error(f"Error during chat generation: {e}")
                logger.error(f"Error on turn {turn}: {e}")
                raise AgentError(f"Failed during chat generation: {e}") from e

        logger.info("Max iterations reached. Agent has stopped.")
        return None
        # Post-loop
        if final_reply is None:
            logger.warning("No reply generated; hitting max iterations.")
            return None

        logger.info(f"Agent conversation completed after {turn} turns.")
        return final_reply

    async def run_tool(self, tool_name: str, *args, **kwargs) -> Any:
        """
@@ -26,18 +26,11 @@ from typing import (
    ClassVar,
)
from pydantic import BaseModel, Field, PrivateAttr, model_validator, ConfigDict
from dapr_agents.llm.chat import ChatClientBase
from dapr_agents.llm.openai import OpenAIChatClient
from dapr_agents.llm.huggingface import HFHubChatClient
from dapr_agents.llm.nvidia import NVIDIAChatClient
from dapr_agents.llm.dapr import DaprChatClient

logger = logging.getLogger(__name__)

# Type alias for all concrete chat client implementations
ChatClientType = Union[
    OpenAIChatClient, HFHubChatClient, NVIDIAChatClient, DaprChatClient
]


class AgentBase(BaseModel, ABC):
    """

@@ -73,7 +66,7 @@ class AgentBase(BaseModel, ABC):
        default=None,
        description="A custom system prompt, overriding name, role, goal, and instructions.",
    )
    llm: ChatClientType = Field(
    llm: ChatClientBase = Field(
        default_factory=OpenAIChatClient,
        description="Language model client for generating responses.",
    )
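
Widening the field from the ChatClientType union to the ChatClientBase abstract base means any client that subclasses ChatClientBase can be injected without touching AgentBase again. A minimal sketch of what that enables (constructor arguments are illustrative, not taken from this diff):

    from dapr_agents import Agent
    from dapr_agents.llm.nvidia import NVIDIAChatClient

    agent = Agent(
        name="WeatherAgent",     # hypothetical agent configuration
        llm=NVIDIAChatClient(),  # any ChatClientBase subclass now type-checks
    )
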
@@ -1,16 +1,16 @@
import json
import logging
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, List, Optional, Union

from dapr.ext.workflow import DaprWorkflowContext  # type: ignore
from pydantic import BaseModel, Field, model_validator
from pydantic import Field, model_validator

from dapr_agents.agents.base import AgentBase
from dapr_agents.types import (
    AgentError,
    AssistantMessage,
    LLMChatResponse,
    ToolExecutionRecord,
    ToolMessage,
    UserMessage,

@@ -32,15 +32,6 @@ from .state import (
logger = logging.getLogger(__name__)


class FinishReason(str, Enum):
    UNKNOWN = "unknown"
    STOP = "stop"
    LENGTH = "length"
    CONTENT_FILTER = "content_filter"
    TOOL_CALLS = "tool_calls"
    FUNCTION_CALL = "function_call"  # deprecated


# TODO(@Sicoyle): Clear up the lines between DurableAgent and AgentWorkflow
class DurableAgent(AgenticWorkflow, AgentBase):
    """

@@ -137,156 +128,150 @@ class DurableAgent(AgenticWorkflow, AgentBase):
        Returns:
            Dict[str, Any]: The final response message when the workflow completes, or None if continuing to the next iteration.
        """
        # Step 0: Retrieve task, iteration, and sourceworkflow instance ID from the message
        # Step 1: pull out task + metadata
        if isinstance(message, dict):
            task = message.get("task", None)
            iteration = message.get("iteration", 0)
            source_workflow_instance_id = message.get("workflow_instance_id")
            metadata = message.get("_message_metadata", {}) or {}
        else:
            task = getattr(message, "task", None)
            iteration = getattr(message, "iteration", 0)
            source_workflow_instance_id = getattr(message, "workflow_instance_id", None)
        # This is the instance ID of the current workflow execution
            metadata = getattr(message, "_message_metadata", {}) or {}

        instance_id = ctx.instance_id
        source = metadata.get("source")
        final_message: Optional[Dict[str, Any]] = None

        if not ctx.is_replaying:
            logger.info(
                f"Workflow iteration {iteration + 1} started (Instance ID: {instance_id})."
            )
            logger.debug(f"Initial message from {source} -> {self.name}")

        # Step 1: Initialize workflow entry and state if this is the first iteration
        if iteration == 0:
            # Get metadata from the message, if available
            if isinstance(message, dict):
                metadata = message.get("_message_metadata", {})
            else:
                metadata = getattr(message, "_message_metadata", {})
            source = metadata.get("source", None)
            # Use activity to record initial entry for replay safety
            yield ctx.call_activity(
                self.record_initial_entry,
                input={
                    "instance_id": instance_id,
                    "input": task or "Triggered without input.",
                    "source": source,
                    "source_workflow_instance_id": source_workflow_instance_id,
                    "output": "",
                },
            )

            if not ctx.is_replaying:
                logger.info(f"Initial message from {source} -> {self.name}")

        # Step 2: Retrieve workflow entry info for this instance
        entry_info = yield ctx.call_activity(
            self.get_workflow_entry_info, input={"instance_id": instance_id}
        )

        source = entry_info.get("source")
        source_workflow_instance_id = entry_info.get("source_workflow_instance_id")

        # Step 3: Generate Response via LLM
        response = yield ctx.call_activity(
            self.generate_response, input={"task": task, "instance_id": instance_id}
        )

        # Step 4: Extract Response Message from LLM Response
        response_message = yield ctx.call_activity(
            self.get_response_message, input={"response": response}
        )

        # Step 5: Extract Finish Reason from LLM Response
        finish_reason = yield ctx.call_activity(
            self.get_finish_reason, input={"response": response}
        )

        # Step 6:Add the assistant's response message to the chat history
        yield ctx.call_activity(
            self.append_assistant_message,
            input={"instance_id": instance_id, "message": response_message},
        )

        # Step 7: Handle tool calls response
        if finish_reason == FinishReason.TOOL_CALLS:
            if not ctx.is_replaying:
                logger.info("Tool calls detected in LLM response.")
            # Retrieve the list of tool calls extracted from the LLM response
            tool_calls = yield ctx.call_activity(
                self.get_tool_calls, input={"response": response}
            )
            if tool_calls:
        try:
            # Loop up to max_iterations
            for turn in range(1, self.max_iterations + 1):
                if not ctx.is_replaying:
                    logger.debug(f"Executing {len(tool_calls)} tool call(s)..")
                # Run the tool calls in parallel
                parallel = [
                    ctx.call_activity(self.run_tool, input={"tool_call": tc})
                    for tc in tool_calls
                ]
                tool_results = yield self.when_all(parallel)
                # Add tool results for the next iteration
                for tr in tool_results:
                    logger.info(
                        f"Workflow turn {turn}/{self.max_iterations} (Instance ID: {instance_id})"
                    )

                # Step 2: On turn 1, record the initial entry
                if turn == 1:
                    yield ctx.call_activity(
                        self.append_tool_message,
                        input={"instance_id": instance_id, "tool_result": tr},
                        self.record_initial_entry,
                        input={
                            "instance_id": instance_id,
                            "input": task or "Triggered without input.",
                            "source": source,
                            "source_workflow_instance_id": source_workflow_instance_id,
                            "output": "",
                        },
                    )
        # Step 8: Process iteration count and finish reason
        next_iteration_count = iteration + 1
        max_iterations_reached = next_iteration_count > self.max_iterations
        if finish_reason == FinishReason.STOP or max_iterations_reached:
            # Process max iterations reached
            if max_iterations_reached:
                if not ctx.is_replaying:
                    logger.warning(
                        f"Workflow {instance_id} reached the max iteration limit ({self.max_iterations}) before finishing naturally."
                    )
                # Modify the response message to indicate forced stop
                response_message[
                    "content"
                ] += "\n\nThe workflow was terminated because it reached the maximum iteration limit. The task may not be fully complete."

            # Broadcast the final response if a broadcast topic is set
            if self.broadcast_topic_name:
                yield ctx.call_activity(
                    self.broadcast_message_to_agents,
                    input={"message": response_message},
                # Step 3: Retrieve workflow entry info for this instance
                entry_info: dict = yield ctx.call_activity(
                    self.get_workflow_entry_info, input={"instance_id": instance_id}
                )
                source = entry_info.get("source")
                source_workflow_instance_id = entry_info.get(
                    "source_workflow_instance_id"
                )

            # Respond to source agent if available
            if source and source_workflow_instance_id:
                yield ctx.call_activity(
                    self.send_response_back,
                    input={
                        "response": response_message,
                        "target_agent": source,
                        "target_instance_id": source_workflow_instance_id,
                    },
                # Step 4: Generate Response with LLM
                response_message: dict = yield ctx.call_activity(
                    self.generate_response,
                    input={"task": task, "instance_id": instance_id},
                )

            # Share Final Message
                # Step 5: Add the assistant's response message to the chat history
                yield ctx.call_activity(
                    self.append_assistant_message,
                    input={"instance_id": instance_id, "message": response_message},
                )

                # Step 6: Handle tool calls response
                tool_calls = response_message.get("tool_calls") or []
                if tool_calls:
                    if not ctx.is_replaying:
                        logger.info(
                            f"Turn {turn}: executing {len(tool_calls)} tool call(s)"
                        )
                    # fan-out parallel tool executions
                    parallel = [
                        ctx.call_activity(self.run_tool, input={"tool_call": tc})
                        for tc in tool_calls
                    ]
                    tool_results: List[Dict[str, Any]] = yield self.when_all(parallel)
                    # Add tool results for the next iteration
                    for tr in tool_results:
                        yield ctx.call_activity(
                            self.append_tool_message,
                            input={"instance_id": instance_id, "tool_result": tr},
                        )
                    # 🔴 If this was the last turn, stop here—even though there were tool calls
                    if turn == self.max_iterations:
                        final_message = response_message
                        final_message[
                            "content"
                        ] += "\n\n⚠️ Stopped: reached max iterations."
                        break

                    # Otherwise, prepare for next turn: clear task so that generate_response() uses memory/history
                    task = None
                    continue  # bump to next turn

                # No tool calls → this is your final answer
                final_message = response_message

                # 🔴 If it happened to be the last turn, banner it
                if turn == self.max_iterations:
                    final_message["content"] += "\n\n⚠️ Stopped: reached max iterations."

                break  # exit loop with final_message
            else:
                raise AgentError("Workflow ended without producing a final response")

        except Exception as e:
            logger.exception("Workflow error", exc_info=e)
            final_message = {
                "role": "assistant",
                "content": f"⚠️ Unexpected error: {e}",
            }

        # Step 7: Broadcast the final response if a broadcast topic is set
        if self.broadcast_topic_name:
            yield ctx.call_activity(
                self.finalize_workflow,
                self.broadcast_message_to_agents,
                input={"message": final_message},
            )

        # Respond to source agent if available
        if source and source_workflow_instance_id:
            yield ctx.call_activity(
                self.send_response_back,
                input={
                    "instance_id": instance_id,
                    "final_output": response_message["content"],
                    "response": final_message,
                    "target_agent": source,
                    "target_instance_id": source_workflow_instance_id,
                },
            )
            # Log the finalization of the workflow
            if not ctx.is_replaying:
                verdict = "max_iterations_reached" if max_iterations_reached else "stop"
                logger.info(
                    f"Workflow {instance_id} has been finalized with verdict: {verdict}"
                )
            return response_message

        # Step 9: Continue Workflow Execution
        if isinstance(message, dict):
            message.update({"task": None, "iteration": next_iteration_count})
            next_message = message
        else:
            # For Pydantic model, create a new dict with updated fields
            next_message = message.model_dump()
            next_message.update({"task": None, "iteration": next_iteration_count})
        ctx.continue_as_new(next_message)
        # Save final output to workflow state
        yield ctx.call_activity(
            self.finalize_workflow,
            input={
                "instance_id": instance_id,
                "final_output": final_message["content"],
            },
        )

        # Set verdict for the workflow instance
        if not ctx.is_replaying:
            verdict = (
                "max_iterations_reached" if turn == self.max_iterations else "completed"
            )
            logger.info(f"Workflow {instance_id} finalized: {verdict}")

        # Return the final response message
        return final_message

    @task
    def record_initial_entry(
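
Here too the old iteration/continue_as_new logic and the new single-loop workflow are interleaved in the rendered hunk. A condensed reading of the new flow, reconstructed by the editor from the added lines (a sketch, not code from the commit; logging and entry-info bookkeeping omitted):

    final_message: Optional[Dict[str, Any]] = None
    try:
        for turn in range(1, self.max_iterations + 1):
            if turn == 1:
                # record the initial workflow entry once, via an activity for replay safety
                yield ctx.call_activity(self.record_initial_entry, input={"instance_id": instance_id, "input": task or "Triggered without input.", "source": source, "source_workflow_instance_id": source_workflow_instance_id, "output": ""})
            response_message = yield ctx.call_activity(
                self.generate_response, input={"task": task, "instance_id": instance_id}
            )
            yield ctx.call_activity(
                self.append_assistant_message,
                input={"instance_id": instance_id, "message": response_message},
            )
            tool_calls = response_message.get("tool_calls") or []
            if tool_calls:
                # fan-out/fan-in: run every tool call as its own activity
                tool_results = yield self.when_all(
                    [ctx.call_activity(self.run_tool, input={"tool_call": tc}) for tc in tool_calls]
                )
                for tr in tool_results:
                    yield ctx.call_activity(
                        self.append_tool_message,
                        input={"instance_id": instance_id, "tool_result": tr},
                    )
                if turn == self.max_iterations:
                    final_message = response_message
                    break
                task = None  # next turn answers from memory/history
                continue
            final_message = response_message
            break
        else:
            raise AgentError("Workflow ended without producing a final response")
    except Exception as e:
        final_message = {"role": "assistant", "content": f"Unexpected error: {e}"}
    # then: optional broadcast, send_response_back to the source agent, finalize_workflow
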
@@ -391,96 +376,23 @@ class DurableAgent(AgenticWorkflow, AgentBase):

        # Generate response using the LLM
        try:
            response = self.llm.generate(
            response: LLMChatResponse = self.llm.generate(
                messages=messages,
                tools=self.get_llm_tools(),
                tool_choice=self.tool_choice,
            )
            if isinstance(response, BaseModel):
                return response.model_dump()
            elif isinstance(response, dict):
                return response
            else:
                # Defensive: raise error for unexpected type
                raise AgentError(f"Unexpected response type: {type(response)}")
            # Get the first candidate from the response
            response_message = response.get_message()
            # Check if the response contains an assistant message
            if response_message is None:
                raise AgentError("LLM returned no assistant message")
            # Convert the response message to a dict to work with JSON serialization
            assistant_message = response_message.model_dump()
            return assistant_message
        except Exception as e:
            logger.error(f"Error during chat generation: {e}")
            raise AgentError(f"Failed during chat generation: {e}") from e

    @task
    def get_response_message(self, response: Dict[str, Any]) -> Dict[str, Any]:
        """
        Extracts the response message from the first choice in the LLM response.

        Args:
            response (Dict[str, Any]): The response dictionary from the LLM, expected to contain a "choices" key.

        Returns:
            Dict[str, Any]: The extracted response message with the agent's name added.

        Raises:
            AgentError: If no response message is found.
        """
        choices = response.get("choices", [])
        if choices:
            response_message = choices[0].get("message", {})
            if response_message:
                return response_message
        raise AgentError("No response message found in LLM response.")

    @task
    def get_finish_reason(self, response: Dict[str, Any]) -> str:
        """
        Extracts the finish reason from the LLM response, indicating why generation stopped.

        Args:
            response (Dict[str, Any]): The response dictionary from the LLM, expected to contain a "choices" key.

        Returns:
            FinishReason: The reason the model stopped generating tokens as an enum value.
        """
        try:
            choices = response.get("choices", [])
            if choices and len(choices) > 0:
                reason_str = choices[0].get("finish_reason", FinishReason.UNKNOWN.value)
                try:
                    return FinishReason(reason_str)
                except ValueError:
                    logger.warning(f"Unrecognized finish reason: {reason_str}")
                    return FinishReason.UNKNOWN
            # If choices is empty, return UNKNOWN
            return FinishReason.UNKNOWN
        except Exception as e:
            logger.error(f"Error extracting finish reason: {e}")
            return FinishReason.UNKNOWN

    @task
    def get_tool_calls(
        self, response: Dict[str, Any]
    ) -> Optional[List[Dict[str, Any]]]:
        """
        Extracts tool calls from the first choice in the LLM response, if available.

        Args:
            response (Dict[str, Any]): The response dictionary from the LLM, expected to contain "choices"
                                    and potentially tool call information.

        Returns:
            Optional[List[Dict[str, Any]]]: A list of tool calls if present, otherwise None.
        """
        choices = response.get("choices", [])
        if not choices:
            logger.warning("No choices found in LLM response.")
            return None

        tool_calls = choices[0].get("message", {}).get("tool_calls")
        if tool_calls:
            return tool_calls

        # Only log if choices exist but no tool_calls
        logger.info("No tool calls found in the first LLM response choice.")
        return None

    @task
    async def run_tool(self, tool_call: Dict[str, Any]) -> Dict[str, Any]:
        """
@@ -28,7 +28,6 @@ class TriggerAction(BaseModel):
        None,
        description="The specific task to execute. If not provided, the agent will act based on its memory or predefined behavior.",
    )
    iteration: Optional[int] = Field(0, description="")
    workflow_instance_id: Optional[str] = Field(
        default=None, description="Dapr workflow instance id from source if available"
    )

@@ -1,3 +1,3 @@
from .otel import DaprAgentsOTel
from .otel import DaprAgentsOtel

__all__ = ["DaprAgentsOTel"]
__all__ = ["DaprAgentsOtel"]

@@ -16,7 +16,7 @@ from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExp
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter


class DaprAgentsOTel:
class DaprAgentsOtel:
    """
    OpenTelemetry configuration for Dapr agents.
    """

@@ -1,27 +1,41 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, Iterable, Iterator, List, Optional, Type, Union
from typing import (
    Any,
    Dict,
    Iterable,
    Iterator,
    List,
    Literal,
    Optional,
    Type,
    TypeVar,
    Union,
    overload,
)

from pydantic import BaseModel, Field
from pydantic import BaseModel

from dapr_agents.prompt.base import PromptTemplateBase
from dapr_agents.prompt.prompty import Prompty
from dapr_agents.tool.base import AgentTool
from dapr_agents.types.message import ChatCompletion
from dapr_agents.types.message import LLMChatCandidateChunk, LLMChatResponse

T = TypeVar("T", bound=BaseModel)


class ChatClientBase(BaseModel, ABC):
class ChatClientBase(ABC):
    """
    Base class for chat-specific functionality.
    Handles Prompty integration and provides abstract methods for chat client configuration.

    Attributes:
        prompty: Optional Prompty spec used to render `input_data` into messages.
        prompt_template: Optional prompt template object for rendering.
    """

    prompty: Optional[Prompty] = Field(
        default=None, description="Instance of the Prompty object (optional)."
    )
    prompt_template: Optional[PromptTemplateBase] = Field(
        default=None, description="Prompt template for rendering (optional)."
    )
    prompty: Optional[Prompty]
    prompt_template: Optional[PromptTemplateBase]

    @classmethod
    @abstractmethod

@@ -31,43 +45,97 @@ class ChatClientBase(BaseModel, ABC):
        timeout: Union[int, float, Dict[str, Any]] = 1500,
    ) -> "ChatClientBase":
        """
        Abstract method to load a Prompty source and configure the chat client.
        Load a Prompty spec (path or inline), extract its model config and
        prompt template, and return a configured chat client.

        Args:
            prompty_source (Union[str, Path]): Source of the Prompty, either a file path or inline Prompty content.
            timeout (Union[int, float, Dict[str, Any]]): Timeout for requests.
            prompty_source: Path or inline YAML/JSON for a Prompty spec.
            timeout: HTTP timeout (seconds or HTTPX-style dict).

        Returns:
            ChatClientBase: Configured chat client instance.
            A ready-to-use ChatClientBase subclass instance.
        """
        pass
        ...

    @overload
    def generate(
        self,
        messages: Union[
            str, Dict[str, Any], Any, Iterable[Union[Dict[str, Any], Any]]
        ] = None,
        *,
        input_data: Optional[Dict[str, Any]] = None,
        model: Optional[str] = None,
        tools: Optional[List[Union[AgentTool, Dict[str, Any]]]] = None,
        response_format: None = None,
        structured_mode: Optional[str] = None,
        stream: Literal[False] = False,
        **kwargs: Any,
    ) -> LLMChatResponse:
        ...

    """If `stream=False` and no `response_format`, returns raw LLMChatResponse."""

    @overload
    def generate(
        self,
        messages: Union[
            str, Dict[str, Any], Any, Iterable[Union[Dict[str, Any], Any]]
        ] = None,
        *,
        input_data: Optional[Dict[str, Any]] = None,
        model: Optional[str] = None,
        tools: Optional[List[Union[AgentTool, Dict[str, Any]]]] = None,
        response_format: Type[T],
        structured_mode: Optional[str] = None,
        stream: Literal[False] = False,
        **kwargs: Any,
    ) -> Union[T, List[T]]:
        ...

    """If `stream=False` and `response_format=SomeModel`, returns that model or a list thereof."""

    @overload
    def generate(
        self,
        messages: Union[
            str, Dict[str, Any], Any, Iterable[Union[Dict[str, Any], Any]]
        ] = None,
        *,
        input_data: Optional[Dict[str, Any]] = None,
        model: Optional[str] = None,
        tools: Optional[List[Union[AgentTool, Dict[str, Any]]]] = None,
        response_format: Optional[Type[T]] = None,
        structured_mode: Optional[str] = None,
        stream: Literal[True],
        **kwargs: Any,
    ) -> Iterator[LLMChatCandidateChunk]:
        ...

    """If `stream=True`, returns a streaming iterator of chunks."""

    @abstractmethod
    def generate(
        self,
        messages: Union[
            str, Dict[str, Any], BaseModel, Iterable[Union[Dict[str, Any], BaseModel]]
            str, Dict[str, Any], Any, Iterable[Union[Dict[str, Any], Any]]
        ] = None,
        *,
        input_data: Optional[Dict[str, Any]] = None,
        model: Optional[str] = None,
        tools: Optional[List[Union[AgentTool, Dict[str, Any]]]] = None,
        response_format: Optional[Type[BaseModel]] = None,
        response_format: Optional[Type[T]] = None,
        structured_mode: Optional[str] = None,
        **kwargs,
    ) -> Union[Iterator[Dict[str, Any]], ChatCompletion]:
        stream: bool = False,
        **kwargs: Any,
    ) -> Union[
        Iterator[LLMChatCandidateChunk],
        LLMChatResponse,
        T,
        List[T],
    ]:
        """
        Abstract method to generate chat completions.

        Args:
            messages (Optional): Either pre-set messages or None if using input_data.
            input_data (Optional[Dict[str, Any]]): Input variables for prompt templates.
            model (Optional[str]): Specific model to use for the request, overriding the default.
            tools (Optional[List[Union[AgentTool, Dict[str, Any]]]]): List of tools for the request.
            response_format (Optional[Type[BaseModel]]): Optional Pydantic model for structured response parsing.
            structured_mode (Optional[str]): Mode for structured output.
            **kwargs: Additional parameters for the chat completion API.

        Returns:
            Union[Iterator[Dict[str, Any]], ChatCompletion]: The chat completion response(s).
        The implementation must accept the full set of kwargs and return
        the union of all possible overload returns.
        """
        pass
        ...
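
The three overloads exist so a type checker can narrow the return type from the call site. A sketch of how they resolve for a caller, with `client` standing for any concrete ChatClientBase implementation (illustrative, not taken from the repository's tests):

    from pydantic import BaseModel

    resp = client.generate("Summarize Dapr workflows")
    # -> LLMChatResponse: no response_format, stream=False

    for chunk in client.generate("Summarize Dapr workflows", stream=True):
        ...  # -> Iterator[LLMChatCandidateChunk]

    class Weather(BaseModel):
        city: str
        temperature_c: float

    weather = client.generate("Weather in Rome?", response_format=Weather)
    # -> Weather (or List[Weather]), per the second overload
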
@@ -1,48 +1,62 @@
from dapr_agents.llm.dapr.client import DaprInferenceClientBase
from dapr_agents.llm.utils import RequestHandler, ResponseHandler
from dapr_agents.prompt.prompty import Prompty
from dapr_agents.types.message import BaseMessage
from dapr_agents.llm.chat import ChatClientBase
from dapr_agents.tool import AgentTool
from dapr.clients.grpc._request import ConversationInput
from typing import (
    Union,
    Optional,
    Iterable,
    Dict,
    Any,
    List,
    Iterator,
    Type,
    Literal,
    ClassVar,
)
from pydantic import BaseModel
from pathlib import Path
import logging
import os
import time
from pathlib import Path
from typing import (
    Any,
    ClassVar,
    Dict,
    Iterable,
    List,
    Literal,
    Optional,
    Type,
    Union,
)

from dapr.clients.grpc._request import ConversationInput
from pydantic import BaseModel, Field

from dapr_agents.llm.chat import ChatClientBase
from dapr_agents.llm.dapr.client import DaprInferenceClientBase
from dapr_agents.llm.utils import RequestHandler, ResponseHandler
from dapr_agents.prompt.base import PromptTemplateBase
from dapr_agents.prompt.prompty import Prompty
from dapr_agents.tool import AgentTool
from dapr_agents.types.message import (
    BaseMessage,
    LLMChatResponse,
)

logger = logging.getLogger(__name__)


class DaprChatClient(DaprInferenceClientBase, ChatClientBase):
    """
    Concrete class for Dapr's chat completion API using the Inference API.
    This class extends the ChatClientBase.
    Chat client for Dapr's Inference API.

    Integrates Prompty-driven prompt templates, tool injection,
    PII scrubbing, and normalizes the Dapr output into our unified
    LLMChatResponse schema.  **Streaming is not supported.**
    """

    SUPPORTED_STRUCTURED_MODES: ClassVar[set] = {"function_call"}
    prompty: Optional[Prompty] = Field(
        default=None, description="Optional Prompty instance for templating."
    )
    prompt_template: Optional[PromptTemplateBase] = Field(
        default=None, description="Optional prompt-template to format inputs."
    )

    # Only function_call–style structured output is supported
    SUPPORTED_STRUCTURED_MODES: ClassVar[set[str]] = {"function_call"}

    def model_post_init(self, __context: Any) -> None:
        """
        Initializes private attributes for provider, api, config, and client after validation.
        After Pydantic init, set up API/type and default LLM component from env.
        """
        # Set the private provider and api attributes
        self._api = "chat"
        self._llm_component = os.environ["DAPR_LLM_COMPONENT_DEFAULT"]

        return super().model_post_init(__context)
        super().model_post_init(__context)

    @classmethod
    def from_prompty(

@@ -51,23 +65,17 @@ class DaprChatClient(DaprInferenceClientBase, ChatClientBase):
        timeout: Union[int, float, Dict[str, Any]] = 1500,
    ) -> "DaprChatClient":
        """
        Initializes an DaprChatClient client using a Prompty source, which can be a file path or inline content.
        Build a DaprChatClient from a Prompty spec.

        Args:
            prompty_source (Union[str, Path]): The source of the Prompty file, which can be a path to a file
                or inline Prompty content as a string.
            timeout (Union[int, float, Dict[str, Any]], optional): Timeout for requests, defaults to 1500 seconds.
            prompty_source: Path or inline Prompty YAML/JSON.
            timeout:       Request timeout in seconds or HTTPX-style dict.

        Returns:
            DaprChatClient: An instance of DaprChatClient configured with the model settings from the Prompty source.
            Configured DaprChatClient.
        """
        # Load the Prompty instance from the provided source
        prompty_instance = Prompty.load(prompty_source)

        # Generate the prompt template from the Prompty instance
        prompt_template = Prompty.to_prompt_template(prompty_instance)

        # Initialize the DaprChatClient based on the Prompty model configuration
        return cls.model_validate(
            {
                "timeout": timeout,

@@ -77,17 +85,18 @@ class DaprChatClient(DaprInferenceClientBase, ChatClientBase):
        )

    def translate_response(self, response: dict, model: str) -> dict:
        """Converts a Dapr response dict into a structure compatible with Choice and ChatCompletion."""
        """
        Convert Dapr response into OpenAI-style ChatCompletion dict.
        """
        choices = [
            {
                "finish_reason": "stop",
                "index": i,
                "message": {"content": output["result"], "role": "assistant"},
                "index": idx,
                "message": {"role": "assistant", "content": out["result"]},
                "logprobs": None,
            }
            for i, output in enumerate(response.get("outputs", []))
            for idx, out in enumerate(response.get("outputs", []))
        ]

        return {
            "choices": choices,
            "created": int(time.time()),
@@ -99,11 +108,14 @@ class DaprChatClient(DaprInferenceClientBase, ChatClientBase):
    def convert_to_conversation_inputs(
        self, inputs: List[Dict[str, Any]]
    ) -> List[ConversationInput]:
        """
        Map normalized messages into Dapr ConversationInput objects.
        """
        return [
            ConversationInput(
                content=item["content"],
                role=item.get("role"),
                scrub_pii=item.get("scrubPII") == "true",
                scrub_pii=bool(item.get("scrubPII")),
            )
            for item in inputs
        ]
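
Note the behavioural nuance in the scrub_pii change: the old check only treated the literal string "true" as truthy, while bool(item.get("scrubPII")) also treats any non-empty value, including the string "false", as truthy. A quick illustration (editor's sketch, not a test from the repository):

    def old(item): return item.get("scrubPII") == "true"
    def new(item): return bool(item.get("scrubPII"))

    print(old({"scrubPII": "true"}), new({"scrubPII": "true"}))    # True  True
    print(old({"scrubPII": "false"}), new({"scrubPII": "false"}))  # False True  <- difference
    print(old({}), new({}))                                        # False False
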
@@ -116,59 +128,75 @@ class DaprChatClient(DaprInferenceClientBase, ChatClientBase):
            BaseMessage,
            Iterable[Union[Dict[str, Any], BaseMessage]],
        ] = None,
        *,
        input_data: Optional[Dict[str, Any]] = None,
        llm_component: Optional[str] = None,
        tools: Optional[List[Union[AgentTool, Dict[str, Any]]]] = None,
        response_format: Optional[Type[BaseModel]] = None,
        structured_mode: Literal["function_call"] = "function_call",
        scrubPII: Optional[bool] = False,
        scrubPII: bool = False,
        temperature: Optional[float] = None,
        **kwargs,
    ) -> Union[Iterator[Dict[str, Any]], Dict[str, Any]]:
        **kwargs: Any,
    ) -> Union[
        LLMChatResponse,
        BaseModel,
        List[BaseModel],
    ]:
        """
        Generate chat completions based on provided messages or input_data for prompt templates.
        Issue a non-streaming chat completion via Dapr.

        - **Streaming is not supported** and setting `stream=True` will raise.
        - Returns a unified `LLMChatResponse` (if no `response_format`), or
          validated Pydantic model(s) when `response_format` is provided.

        Args:
            messages (Optional): Either pre-set messages or None if using input_data.
            input_data (Optional[Dict[str, Any]]): Input variables for prompt templates.
            llm_component (str): Name of the LLM component to use for the request.
            tools (List[Union[AgentTool, Dict[str, Any]]]): List of tools for the request.
            response_format (Type[BaseModel]): Optional Pydantic model for structured response parsing.
            structured_mode (Literal["function_call"]): Mode for structured output: "function_call" (Limited Support).
            scrubPII (Type[bool]): Optional flag to obfuscate any sensitive information coming back from the LLM.
            **kwargs: Additional parameters for the language model.
            messages:        Prebuilt messages or None to use `input_data`.
            input_data:      Variables for Prompty template rendering.
            llm_component:   Dapr component name (defaults from env).
            tools:           AgentTool or dict specifications.
            response_format: Pydantic model for structured output.
            structured_mode: Must be "function_call".
            scrubPII:        Obfuscate sensitive output if True.
            temperature:     Sampling temperature.
            **kwargs:        Other Dapr API parameters.

        Returns:
            Union[Iterator[Dict[str, Any]], Dict[str, Any]]: The chat completion response(s).
            • `LLMChatResponse` if no `response_format`
            • Pydantic model (or `List[...]`) when `response_format` is set

        Raises:
            ValueError: on invalid `structured_mode`, missing inputs, or if `stream=True`.
        """
        # 1) Validate structured_mode
        if structured_mode not in self.SUPPORTED_STRUCTURED_MODES:
            raise ValueError(
                f"Invalid structured_mode '{structured_mode}'. Must be one of {self.SUPPORTED_STRUCTURED_MODES}."
                f"structured_mode must be one of {self.SUPPORTED_STRUCTURED_MODES}"
            )
        # 2) Disallow response_format + streaming
        if response_format is not None:
            raise ValueError("`response_format` is not supported by DaprChatClient.")
        if kwargs.get("stream"):
            raise ValueError("Streaming is not supported by DaprChatClient.")

        # If input_data is provided, check for a prompt_template
        # 3) Build messages via Prompty
        if input_data:
            if not self.prompt_template:
                raise ValueError(
                    "Inputs are provided but no 'prompt_template' is set. Please set a 'prompt_template' to use the input_data."
                )

            logger.info("Using prompt template to generate messages.")
                raise ValueError("input_data provided but no prompt_template is set.")
            messages = self.prompt_template.format_prompt(**input_data)

        # Ensure we have messages at this point
        if not messages:
            raise ValueError("Either 'messages' or 'input_data' must be provided.")

        # Process and normalize the messages
        params = {"inputs": RequestHandler.normalize_chat_messages(messages)}
        # Merge Prompty parameters if available, then override with any explicit kwargs
        # 4) Normalize + merge defaults
        params: Dict[str, Any] = {
            "inputs": RequestHandler.normalize_chat_messages(messages)
        }
        if self.prompty:
            params = {**self.prompty.model.parameters.model_dump(), **params, **kwargs}
 | 
			
		||||
        else:
 | 
			
		||||
            params.update(kwargs)
 | 
			
		||||
 | 
			
		||||
        # Prepare request parameters
 | 
			
		||||
        # 5) Inject tools + structured directives
 | 
			
		||||
        params = RequestHandler.process_params(
 | 
			
		||||
            params,
 | 
			
		||||
            llm_provider=self.provider,
 | 
			
		||||
| 
						 | 
				
			
@@ -176,28 +204,32 @@ class DaprChatClient(DaprInferenceClientBase, ChatClientBase):
 | 
			
		|||
            response_format=response_format,
 | 
			
		||||
            structured_mode=structured_mode,
 | 
			
		||||
        )
 | 
			
		||||
        inputs = self.convert_to_conversation_inputs(params["inputs"])
 | 
			
		||||
 | 
			
		||||
        # 6) Convert to Dapr inputs & call
 | 
			
		||||
        conv_inputs = self.convert_to_conversation_inputs(params["inputs"])
 | 
			
		||||
        try:
 | 
			
		||||
            logger.info("Invoking the Dapr Conversation API.")
 | 
			
		||||
            response = self.client.chat_completion(
 | 
			
		||||
            raw = self.client.chat_completion(
 | 
			
		||||
                llm=llm_component or self._llm_component,
 | 
			
		||||
                conversation_inputs=inputs,
 | 
			
		||||
                conversation_inputs=conv_inputs,
 | 
			
		||||
                scrub_pii=scrubPII,
 | 
			
		||||
                temperature=temperature,
 | 
			
		||||
            )
 | 
			
		||||
            transposed_response = self.translate_response(response, self._llm_component)
 | 
			
		||||
            logger.info("Chat completion retrieved successfully.")
 | 
			
		||||
 | 
			
		||||
            return ResponseHandler.process_response(
 | 
			
		||||
                transposed_response,
 | 
			
		||||
                llm_provider=self.provider,
 | 
			
		||||
                response_format=response_format,
 | 
			
		||||
                structured_mode=structured_mode,
 | 
			
		||||
                stream=params.get("stream", False),
 | 
			
		||||
            normalized = self.translate_response(
 | 
			
		||||
                raw, llm_component or self._llm_component
 | 
			
		||||
            )
 | 
			
		||||
            logger.info("Chat completion retrieved successfully.")
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error(
 | 
			
		||||
                f"An error occurred during the Dapr Conversation API call: {e}"
 | 
			
		||||
            )
 | 
			
		||||
            raise
 | 
			
		||||
 | 
			
		||||
        # 7) Hand off to our unified handler (always non‐stream)
 | 
			
		||||
        return ResponseHandler.process_response(
 | 
			
		||||
            response=normalized,
 | 
			
		||||
            llm_provider=self.provider,
 | 
			
		||||
            response_format=response_format,
 | 
			
		||||
            structured_mode=structured_mode,
 | 
			
		||||
            stream=False,
 | 
			
		||||
        )
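# --- Hedged usage sketch (added for illustration; not part of this commit) ---
# Shows a non-streaming call through the refactored DaprChatClient.generate().
# The import path, the "echo" component name, and reliance on environment defaults
# are assumptions; the generate() parameters and the LLMChatResponse shape come
# from the code above.
from dapr_agents.llm import DaprChatClient  # assumed import path

client = DaprChatClient()  # llm_component may also default from the environment
response = client.generate(
    messages=[{"role": "user", "content": "Summarize Dapr in one sentence."}],
    llm_component="echo",  # placeholder Dapr conversation component name
    scrubPII=True,         # ask the component to obfuscate sensitive output
)
# With no response_format, generate() returns a unified LLMChatResponse.
for candidate in response.results:
    print(candidate.message.content)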
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
@@ -0,0 +1,51 @@
 | 
			
		|||
import logging
 | 
			
		||||
import time
 | 
			
		||||
from typing import Any, Dict
 | 
			
		||||
 | 
			
		||||
from dapr_agents.types.message import (
 | 
			
		||||
    AssistantMessage,
 | 
			
		||||
    LLMChatCandidate,
 | 
			
		||||
    LLMChatResponse,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def process_dapr_chat_response(response: Dict[str, Any]) -> LLMChatResponse:
 | 
			
		||||
    """
 | 
			
		||||
    Convert a Dapr-normalized chat dict (with OpenAI-style 'choices') into a unified LLMChatResponse.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        response: The dict returned by `DaprChatClient.translate_response`.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        LLMChatResponse: Contains a list of candidates and metadata.
 | 
			
		||||
    """
 | 
			
		||||
    # 1) Extract each choice → build AssistantMessage + LLMChatCandidate
 | 
			
		||||
    candidates = []
 | 
			
		||||
    for choice in response.get("choices", []):
 | 
			
		||||
        msg = choice.get("message", {})
 | 
			
		||||
        assistant_message = AssistantMessage(
 | 
			
		||||
            content=msg.get("content"),
 | 
			
		||||
            # Dapr currently never returns refusals, tool_calls or function_call here
 | 
			
		||||
        )
 | 
			
		||||
        candidate = LLMChatCandidate(
 | 
			
		||||
            message=assistant_message,
 | 
			
		||||
            finish_reason=choice.get("finish_reason"),
 | 
			
		||||
            # Dapr translate_response includes index & no logprobs
 | 
			
		||||
            index=choice.get("index"),
 | 
			
		||||
            logprobs=choice.get("logprobs"),
 | 
			
		||||
        )
 | 
			
		||||
        candidates.append(candidate)
 | 
			
		||||
 | 
			
		||||
    # 2) Build metadata from the top‐level fields
 | 
			
		||||
    metadata: Dict[str, Any] = {
 | 
			
		||||
        "provider": "dapr",
 | 
			
		||||
        "id": response.get("id", None),
 | 
			
		||||
        "model": response.get("model", None),
 | 
			
		||||
        "object": response.get("object", None),
 | 
			
		||||
        "usage": response.get("usage", {}),
 | 
			
		||||
        "created": response.get("created", int(time.time())),
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    return LLMChatResponse(results=candidates, metadata=metadata)
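# --- Illustrative sketch (added; not part of this commit) ---
# Demonstrates the dict shape this helper expects, mirroring the OpenAI-style
# payload produced by DaprChatClient.translate_response. Values are placeholders.
_example = {
    "id": "resp-123",
    "model": "echo",
    "object": "chat.completion",
    "created": 1700000000,
    "usage": {},
    "choices": [
        {"index": 0, "finish_reason": "stop", "message": {"content": "Hello from Dapr"}}
    ],
}
_unified = process_dapr_chat_response(_example)
assert _unified.results[0].message.content == "Hello from Dapr"
assert _unified.metadata["provider"] == "dapr"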
 | 
			
		||||
| 
						 | 
				
			
@@ -13,34 +13,49 @@ from typing import (
 | 
			
		|||
    Union,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
from huggingface_hub import ChatCompletionOutput
 | 
			
		||||
from pydantic import BaseModel
 | 
			
		||||
from pydantic import BaseModel, Field
 | 
			
		||||
 | 
			
		||||
from dapr_agents.llm.chat import ChatClientBase
 | 
			
		||||
from dapr_agents.llm.huggingface.client import HFHubInferenceClientBase
 | 
			
		||||
from dapr_agents.llm.utils import RequestHandler, ResponseHandler
 | 
			
		||||
from dapr_agents.prompt.base import PromptTemplateBase
 | 
			
		||||
from dapr_agents.prompt.prompty import Prompty
 | 
			
		||||
from dapr_agents.tool import AgentTool
 | 
			
		||||
from dapr_agents.types.message import BaseMessage, ChatCompletion
 | 
			
		||||
from dapr_agents.types.message import (
 | 
			
		||||
    BaseMessage,
 | 
			
		||||
    LLMChatCandidateChunk,
 | 
			
		||||
    LLMChatResponse,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HFHubChatClient(HFHubInferenceClientBase, ChatClientBase):
 | 
			
		||||
    """
 | 
			
		||||
    Concrete class for the Hugging Face Hub's chat completion API using the Inference API.
 | 
			
		||||
    This class extends the ChatClientBase and provides the necessary configurations for Hugging Face models.
 | 
			
		||||
    Chat client for Hugging Face Hub's Inference API.
 | 
			
		||||
 | 
			
		||||
    Extends:
 | 
			
		||||
      - HFHubInferenceClientBase: manages HF-specific auth, endpoints, retries.
 | 
			
		||||
      - ChatClientBase: provides the `.from_prompty()` and `.generate()` contract.
 | 
			
		||||
 | 
			
		||||
    Supports only function_call-style structured responses.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    SUPPORTED_STRUCTURED_MODES: ClassVar[set] = {"function_call"}
 | 
			
		||||
    prompty: Optional[Prompty] = Field(
 | 
			
		||||
        default=None, description="Optional Prompty instance for templating."
 | 
			
		||||
    )
 | 
			
		||||
    prompt_template: Optional[PromptTemplateBase] = Field(
 | 
			
		||||
        default=None, description="Optional prompt-template to format inputs."
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    SUPPORTED_STRUCTURED_MODES: ClassVar[set[str]] = {"function_call"}
 | 
			
		||||
 | 
			
		||||
    def model_post_init(self, __context: Any) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Initializes private attributes for provider, api, config, and client after validation.
 | 
			
		||||
        After Pydantic __init__, set the internal API type to "chat".
 | 
			
		||||
        """
 | 
			
		||||
        # Set the private provider and api attributes
 | 
			
		||||
        self._api = "chat"
 | 
			
		||||
        return super().model_post_init(__context)
 | 
			
		||||
        super().model_post_init(__context)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_prompty(
 | 
			
		||||
| 
						 | 
				
			
@@ -49,34 +64,28 @@ class HFHubChatClient(HFHubInferenceClientBase, ChatClientBase):
 | 
			
		|||
        timeout: Union[int, float, Dict[str, Any]] = 1500,
 | 
			
		||||
    ) -> "HFHubChatClient":
 | 
			
		||||
        """
 | 
			
		||||
        Initializes an HFHubChatClient client using a Prompty source, which can be a file path or inline content.
 | 
			
		||||
        Load a Prompty spec (file or inline), extract model config & prompt template,
 | 
			
		||||
        and return a configured HFHubChatClient.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            prompty_source (Union[str, Path]): The source of the Prompty file, which can be a path to a file
 | 
			
		||||
                or inline Prompty content as a string.
 | 
			
		||||
            timeout (Union[int, float, Dict[str, Any]], optional): Timeout for requests, defaults to 1500 seconds.
 | 
			
		||||
            prompty_source: Path or inline text of a Prompty YAML/JSON.
 | 
			
		||||
            timeout:        Request timeout (seconds or HTTPX timeout dict).
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            HFHubChatClient: An instance of HFHubChatClient configured with the model settings from the Prompty source.
 | 
			
		||||
            HFHubChatClient: client ready for .generate() calls.
 | 
			
		||||
        """
 | 
			
		||||
        # Load the Prompty instance from the provided source
 | 
			
		||||
        prompty_instance = Prompty.load(prompty_source)
 | 
			
		||||
 | 
			
		||||
        # Generate the prompt template from the Prompty instance
 | 
			
		||||
        prompt_template = Prompty.to_prompt_template(prompty_instance)
 | 
			
		||||
        cfg = prompty_instance.model.configuration
 | 
			
		||||
 | 
			
		||||
        # Extract the model configuration from Prompty
 | 
			
		||||
        model_config = prompty_instance.model
 | 
			
		||||
 | 
			
		||||
        # Initialize the HFHubChatClient based on the Prompty model configuration
 | 
			
		||||
        return cls.model_validate(
 | 
			
		||||
            {
 | 
			
		||||
                "model": model_config.configuration.name,
 | 
			
		||||
                "api_key": model_config.configuration.api_key,
 | 
			
		||||
                "base_url": model_config.configuration.base_url,
 | 
			
		||||
                "headers": model_config.configuration.headers,
 | 
			
		||||
                "cookies": model_config.configuration.cookies,
 | 
			
		||||
                "proxies": model_config.configuration.proxies,
 | 
			
		||||
                "model": cfg.name,
 | 
			
		||||
                "api_key": cfg.api_key,
 | 
			
		||||
                "base_url": cfg.base_url,
 | 
			
		||||
                "headers": cfg.headers,
 | 
			
		||||
                "cookies": cfg.cookies,
 | 
			
		||||
                "proxies": cfg.proxies,
 | 
			
		||||
                "timeout": timeout,
 | 
			
		||||
                "prompty": prompty_instance,
 | 
			
		||||
                "prompt_template": prompt_template,
 | 
			
		||||
| 
						 | 
				
			
@@ -91,61 +100,72 @@ class HFHubChatClient(HFHubInferenceClientBase, ChatClientBase):
 | 
			
		|||
            BaseMessage,
 | 
			
		||||
            Iterable[Union[Dict[str, Any], BaseMessage]],
 | 
			
		||||
        ] = None,
 | 
			
		||||
        *,
 | 
			
		||||
        input_data: Optional[Dict[str, Any]] = None,
 | 
			
		||||
        model: Optional[str] = None,
 | 
			
		||||
        tools: Optional[List[Union[AgentTool, Dict[str, Any]]]] = None,
 | 
			
		||||
        response_format: Optional[Type[BaseModel]] = None,
 | 
			
		||||
        structured_mode: Literal["function_call"] = "function_call",
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ) -> Union[Iterator[Dict[str, Any]], ChatCompletion]:
 | 
			
		||||
        stream: bool = False,
 | 
			
		||||
        **kwargs: Any,  # accept any extra params, even if unused
 | 
			
		||||
    ) -> Union[
 | 
			
		||||
        Iterator[LLMChatCandidateChunk],
 | 
			
		||||
        LLMChatResponse,
 | 
			
		||||
        BaseModel,
 | 
			
		||||
        List[BaseModel],
 | 
			
		||||
    ]:
 | 
			
		||||
        """
 | 
			
		||||
        Generate chat completions based on provided messages or input_data for prompt templates.
 | 
			
		||||
        Issue a chat completion via Hugging Face's Inference API.
 | 
			
		||||
 | 
			
		||||
        - If `stream=True`, returns an iterator of `LLMChatCandidateChunk`.
 | 
			
		||||
        - Otherwise returns either:
 | 
			
		||||
            • raw `AssistantMessage` wrapped in `LLMChatResponse`, or
 | 
			
		||||
            • validated Pydantic model(s) per `response_format`.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            messages (Optional): Either pre-set messages or None if using input_data.
 | 
			
		||||
            input_data (Optional[Dict[str, Any]]): Input variables for prompt templates.
 | 
			
		||||
            model (str): Specific model to use for the request, overriding the default.
 | 
			
		||||
            tools (List[Union[AgentTool, Dict[str, Any]]]): List of tools for the request.
 | 
			
		||||
            response_format (Type[BaseModel]): Optional Pydantic model for structured response parsing.
 | 
			
		||||
            structured_mode (Literal["function_call"]): Mode for structured output: "function_call" (Limited Support).
 | 
			
		||||
            **kwargs: Additional parameters for the language model.
 | 
			
		||||
            messages:        Pre-built messages or None to use `input_data`.
 | 
			
		||||
            input_data:      Variables for the Prompty template.
 | 
			
		||||
            model:           Override the client's default model name.
 | 
			
		||||
            tools:           List of AgentTool or dict specs.
 | 
			
		||||
            response_format: Pydantic model (or List[Model]) for structured output.
 | 
			
		||||
            structured_mode: Must be `"function_call"` (only supported mode here).
 | 
			
		||||
            stream:          If True, return an iterator of `LLMChatCandidateChunk`.
 | 
			
		||||
            **kwargs:        Any other LLM params (temperature, top_p, etc.).
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            Union[Iterator[Dict[str, Any]], ChatCompletion]: The chat completion response(s).
 | 
			
		||||
        """
 | 
			
		||||
            • `Iterator[LLMChatCandidateChunk]` if streaming
 | 
			
		||||
            • `LLMChatResponse` or Pydantic instance(s) if non-streaming
 | 
			
		||||
 | 
			
		||||
        Raises:
 | 
			
		||||
            ValueError: on invalid `structured_mode`, missing prompts, or API errors.
 | 
			
		||||
        """
 | 
			
		||||
        # 1) Validate structured_mode
 | 
			
		||||
        if structured_mode not in self.SUPPORTED_STRUCTURED_MODES:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                f"Invalid structured_mode '{structured_mode}'. Must be one of {self.SUPPORTED_STRUCTURED_MODES}."
 | 
			
		||||
                f"structured_mode must be one of {self.SUPPORTED_STRUCTURED_MODES}"
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # If input_data is provided, check for a prompt_template
 | 
			
		||||
        # 2) If using a prompt template, build messages
 | 
			
		||||
        if input_data:
 | 
			
		||||
            if not self.prompt_template:
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    "Inputs are provided but no 'prompt_template' is set. Please set a 'prompt_template' to use the input_data."
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            logger.info("Using prompt template to generate messages.")
 | 
			
		||||
                raise ValueError("No prompt_template set for input_data usage.")
 | 
			
		||||
            logger.info("Formatting messages via prompt_template.")
 | 
			
		||||
            messages = self.prompt_template.format_prompt(**input_data)
 | 
			
		||||
 | 
			
		||||
        # Ensure we have messages at this point
 | 
			
		||||
        if not messages:
 | 
			
		||||
            raise ValueError("Either 'messages' or 'input_data' must be provided.")
 | 
			
		||||
            raise ValueError("Either messages or input_data must be provided.")
 | 
			
		||||
 | 
			
		||||
        # Process and normalize the messages
 | 
			
		||||
        # 3) Normalize messages + merge client/prompty defaults
 | 
			
		||||
        params = {"messages": RequestHandler.normalize_chat_messages(messages)}
 | 
			
		||||
 | 
			
		||||
        # Merge Prompty parameters if available, then override with any explicit kwargs
 | 
			
		||||
        if self.prompty:
 | 
			
		||||
            params = {**self.prompty.model.parameters.model_dump(), **params, **kwargs}
 | 
			
		||||
        else:
 | 
			
		||||
            params.update(kwargs)
 | 
			
		||||
 | 
			
		||||
        # If a model is provided, override the default model
 | 
			
		||||
        # 4) Override model if given
 | 
			
		||||
        params["model"] = model or self.model
 | 
			
		||||
 | 
			
		||||
        # Prepare request parameters
 | 
			
		||||
        # 5) Inject tools / response_format via RequestHandler
 | 
			
		||||
        params = RequestHandler.process_params(
 | 
			
		||||
            params,
 | 
			
		||||
            llm_provider=self.provider,
 | 
			
		||||
| 
						 | 
				
			
@@ -154,31 +174,28 @@ class HFHubChatClient(HFHubInferenceClientBase, ChatClientBase):
 | 
			
		|||
            structured_mode=structured_mode,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # 6) Call HF API + delegate parsing to ResponseHandler
 | 
			
		||||
        try:
 | 
			
		||||
            logger.info("Invoking Hugging Face ChatCompletion API.")
 | 
			
		||||
            response: ChatCompletionOutput = self.client.chat.completions.create(
 | 
			
		||||
                **params
 | 
			
		||||
            )
 | 
			
		||||
            logger.info("Chat completion retrieved successfully.")
 | 
			
		||||
            logger.info("Calling HF ChatCompletion Inference API...")
 | 
			
		||||
            logger.debug(f"HF params: {params}")
 | 
			
		||||
            response = self.client.chat.completions.create(**params, stream=stream)
 | 
			
		||||
            logger.info("HF ChatCompletion response received.")
 | 
			
		||||
 | 
			
		||||
            # Hugging Face error handling
 | 
			
		||||
            status = getattr(response, "statuscode", 200)
 | 
			
		||||
            # HF-specific error‐code handling
 | 
			
		||||
            code = getattr(response, "code", 200)
 | 
			
		||||
            if code != 200:
 | 
			
		||||
                logger.error(
 | 
			
		||||
                    f"❌ Status Code:{status} - Code:{code} Error: {getattr(response, 'message', response)}"
 | 
			
		||||
                )
 | 
			
		||||
                raise RuntimeError(
 | 
			
		||||
                    f"{status}/{code} Error: {getattr(response, 'message', response)}"
 | 
			
		||||
                )
 | 
			
		||||
                msg = getattr(response, "message", response)
 | 
			
		||||
                logger.error(f"❌ HF error {code}: {msg}")
 | 
			
		||||
                raise RuntimeError(f"HuggingFace error {code}: {msg}")
 | 
			
		||||
 | 
			
		||||
            return ResponseHandler.process_response(
 | 
			
		||||
                response,
 | 
			
		||||
                response=response,
 | 
			
		||||
                llm_provider=self.provider,
 | 
			
		||||
                response_format=response_format,
 | 
			
		||||
                structured_mode=structured_mode,
 | 
			
		||||
                stream=params.get("stream", False),
 | 
			
		||||
                stream=stream,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error(f"An error occurred during the ChatCompletion API call: {e}")
 | 
			
		||||
            raise
 | 
			
		||||
            logger.error("Hugging Face ChatCompletion API error", exc_info=True)
 | 
			
		||||
            raise ValueError("Failed to process HF chat completion") from e
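# --- Hedged usage sketch (added for illustration; not part of this commit) ---
# Streams tokens through HFHubChatClient.generate(stream=True). The import path and
# model id are assumptions; the stream parameter and chunk types come from the
# signature above. Chunk text is read defensively, since streamed objects may carry
# it directly or under a .result candidate.
from dapr_agents.llm.huggingface import HFHubChatClient  # assumed import path

hf = HFHubChatClient(model="HuggingFaceH4/zephyr-7b-beta")  # placeholder model id
for chunk in hf.generate(
    messages=[{"role": "user", "content": "Write a haiku about Dapr."}],
    stream=True,
):
    piece = getattr(chunk, "content", None) or getattr(
        getattr(chunk, "result", None), "content", None
    )
    if piece:
        print(piece, end="", flush=True)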
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
@@ -0,0 +1,260 @@
 | 
			
		|||
import dataclasses
 | 
			
		||||
import logging
 | 
			
		||||
from typing import Any, Callable, Dict, Iterator, Optional
 | 
			
		||||
 | 
			
		||||
from huggingface_hub import ChatCompletionOutput, ChatCompletionStreamOutput
 | 
			
		||||
 | 
			
		||||
from dapr_agents.types.message import (
 | 
			
		||||
    AssistantMessage,
 | 
			
		||||
    FunctionCall,
 | 
			
		||||
    LLMChatCandidate,
 | 
			
		||||
    LLMChatCandidateChunk,
 | 
			
		||||
    LLMChatResponse,
 | 
			
		||||
    LLMChatResponseChunk,
 | 
			
		||||
    ToolCall,
 | 
			
		||||
    ToolCallChunk,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Helper function to handle metadata extraction
 | 
			
		||||
def _get_packet_metadata(
 | 
			
		||||
    pkt: Dict[str, Any], enrich_metadata: Optional[Dict[str, Any]]
 | 
			
		||||
) -> Dict[str, Any]:
 | 
			
		||||
    """
 | 
			
		||||
    Extract metadata from HuggingFace packet and merge with enrich_metadata.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        pkt (Dict[str, Any]): The HuggingFace packet from which to extract metadata.
 | 
			
		||||
        enrich_metadata (Optional[Dict[str, Any]]): Additional metadata to merge with the extracted metadata.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        Dict[str, Any]: The merged metadata dictionary.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        return {
 | 
			
		||||
            "id": pkt.get("id"),
 | 
			
		||||
            "created": pkt.get("created"),
 | 
			
		||||
            "model": pkt.get("model"),
 | 
			
		||||
            "object": pkt.get("object"),
 | 
			
		||||
            "service_tier": pkt.get("service_tier"),
 | 
			
		||||
            "system_fingerprint": pkt.get("system_fingerprint"),
 | 
			
		||||
            **(enrich_metadata or {}),
 | 
			
		||||
        }
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        logger.error(f"Failed to parse packet: {e}", exc_info=True)
 | 
			
		||||
        return {}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Helper function to process each choice delta (content, function call, tool call, finish reason)
 | 
			
		||||
def _process_choice_delta(
 | 
			
		||||
    choice: Dict[str, Any],
 | 
			
		||||
    overall_meta: Dict[str, Any],
 | 
			
		||||
    on_chunk: Optional[Callable],
 | 
			
		||||
    first_chunk_flag: bool,
 | 
			
		||||
) -> Iterator[LLMChatResponseChunk]:
 | 
			
		||||
    """
 | 
			
		||||
    Process each choice delta and yield corresponding chunks.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        choice (Dict[str, Any]): The choice delta from HuggingFace response.
 | 
			
		||||
        overall_meta (Dict[str, Any]): Overall metadata to include in chunks.
 | 
			
		||||
        on_chunk (Optional[Callable]): Callback for each chunk.
 | 
			
		||||
        first_chunk_flag (bool): Flag indicating if this is the first chunk.
 | 
			
		||||
 | 
			
		||||
    Yields:
 | 
			
		||||
        LLMChatResponseChunk: The processed chunk with content, function call, tool calls, finish reason, index, and logprobs.
 | 
			
		||||
    """
 | 
			
		||||
    # Make an immutable snapshot for this single chunk
 | 
			
		||||
    meta = {**overall_meta}
 | 
			
		||||
 | 
			
		||||
    # mark first_chunk exactly once
 | 
			
		||||
    if first_chunk_flag and "first_chunk" not in meta:
 | 
			
		||||
        meta["first_chunk"] = True
 | 
			
		||||
 | 
			
		||||
    # Extract initial properties from choice
 | 
			
		||||
    delta: dict = choice.get("delta", {})
 | 
			
		||||
    idx = choice.get("index")
 | 
			
		||||
    finish_reason = choice.get("finish_reason", None)
 | 
			
		||||
    logprobs = choice.get("logprobs", None)
 | 
			
		||||
 | 
			
		||||
    # Set additional metadata
 | 
			
		||||
    if finish_reason in ("stop", "tool_calls"):
 | 
			
		||||
        meta["last_chunk"] = True
 | 
			
		||||
 | 
			
		||||
    # Process content delta
 | 
			
		||||
    content = delta.get("content", None)
 | 
			
		||||
    function_call = delta.get("function_call", None)
 | 
			
		||||
    refusal = delta.get("refusal", None)
 | 
			
		||||
    role = delta.get("role", None)
 | 
			
		||||
 | 
			
		||||
    # Process tool calls
 | 
			
		||||
    chunk_tool_calls = [ToolCallChunk(**tc) for tc in (delta.get("tool_calls") or [])]
 | 
			
		||||
 | 
			
		||||
    # Initialize LLMChatResponseChunk
 | 
			
		||||
    response_chunk = LLMChatResponseChunk(
 | 
			
		||||
        result=LLMChatCandidateChunk(
 | 
			
		||||
            content=content,
 | 
			
		||||
            function_call=function_call,
 | 
			
		||||
            refusal=refusal,
 | 
			
		||||
            role=role,
 | 
			
		||||
            tool_calls=chunk_tool_calls,
 | 
			
		||||
            finish_reason=finish_reason,
 | 
			
		||||
            index=idx,
 | 
			
		||||
            logprobs=logprobs,
 | 
			
		||||
        ),
 | 
			
		||||
        metadata=meta,
 | 
			
		||||
    )
 | 
			
		||||
    # Process chunk with on_chunk callback
 | 
			
		||||
    if on_chunk:
 | 
			
		||||
        on_chunk(response_chunk)
 | 
			
		||||
    # Yield LLMChatResponseChunk
 | 
			
		||||
    yield response_chunk
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Main function to process HuggingFace streaming response
 | 
			
		||||
def process_hf_stream(
 | 
			
		||||
    raw_stream: Iterator[ChatCompletionStreamOutput],
 | 
			
		||||
    *,
 | 
			
		||||
    enrich_metadata: Optional[Dict[str, Any]] = None,
 | 
			
		||||
    on_chunk: Optional[Callable],
 | 
			
		||||
) -> Iterator[LLMChatCandidateChunk]:
 | 
			
		||||
    """
 | 
			
		||||
    Normalize HuggingFace streaming chat into LLMChatCandidateChunk objects,
 | 
			
		||||
    accumulating buffers per choice and yielding both partial and final chunks.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        raw_stream: Iterator from client.chat.completions.create(..., stream=True)
 | 
			
		||||
        enrich_metadata: Extra key/value pairs to merge into each chunk.metadata
 | 
			
		||||
        on_chunk:   Callback fired on every partial delta (token, function, tool)
 | 
			
		||||
 | 
			
		||||
    Yields:
 | 
			
		||||
        LLMChatCandidateChunk for every partial and final piece, in stream order
 | 
			
		||||
    """
 | 
			
		||||
    enrich_metadata = enrich_metadata or {}
 | 
			
		||||
    overall_meta: Dict[str, Any] = {}
 | 
			
		||||
 | 
			
		||||
    # Track if we are in the first chunk
 | 
			
		||||
    first_chunk_flag = True
 | 
			
		||||
 | 
			
		||||
    for packet in raw_stream:
 | 
			
		||||
        # Convert Pydantic / HuggingFaceObject → plain dict
 | 
			
		||||
        if hasattr(packet, "model_dump"):
 | 
			
		||||
            pkt = packet.model_dump()
 | 
			
		||||
        elif hasattr(packet, "to_dict"):
 | 
			
		||||
            pkt = packet.to_dict()
 | 
			
		||||
        elif dataclasses.is_dataclass(packet):
 | 
			
		||||
            pkt = dataclasses.asdict(packet)
 | 
			
		||||
        else:
 | 
			
		||||
            raise TypeError(f"Cannot serialize packet of type {type(packet)}")
 | 
			
		||||
 | 
			
		||||
        # Capture overall metadata from the packet
 | 
			
		||||
        overall_meta = _get_packet_metadata(pkt, enrich_metadata)
 | 
			
		||||
 | 
			
		||||
        # Process each choice in this packet
 | 
			
		||||
        if choices := pkt.get("choices"):
 | 
			
		||||
            if len(choices) == 0:
 | 
			
		||||
                logger.warning(
 | 
			
		||||
                    "Received empty 'choices' in HuggingFace packet, skipping."
 | 
			
		||||
                )
 | 
			
		||||
                continue
 | 
			
		||||
            # Process the first choice in the packet
 | 
			
		||||
            choice = choices[0]
 | 
			
		||||
            yield from _process_choice_delta(
 | 
			
		||||
                choice, overall_meta, on_chunk, first_chunk_flag
 | 
			
		||||
            )
 | 
			
		||||
            # Set first_chunk_flag to False after processing the first choice
 | 
			
		||||
            first_chunk_flag = False
 | 
			
		||||
        else:
 | 
			
		||||
            logger.warning(f" Yielding packet without 'choices': {pkt}")
 | 
			
		||||
            # Initialize default LLMChatResponseChunk
 | 
			
		||||
            final_response_chunk = LLMChatResponseChunk(metadata=overall_meta)
 | 
			
		||||
            yield final_response_chunk
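# --- Illustrative sketch (added; not part of this commit) ---
# Drives process_hf_stream with an on_chunk callback that accumulates streamed text.
# In real use raw_stream would come from client.chat.completions.create(..., stream=True);
# an empty iterator stands in here so the snippet runs without network access.
_collected: list = []

def _collect(chunk) -> None:
    # Each partial delta is passed to the callback before it is yielded downstream.
    if chunk.result is not None and chunk.result.content:
        _collected.append(chunk.result.content)

_raw_stream = iter(())  # stand-in for a real Hugging Face streaming response
for _chunk in process_hf_stream(
    _raw_stream, enrich_metadata={"request_id": "demo"}, on_chunk=_collect
):
    pass  # chunks are also yielded for callers that prefer plain iteration

print("".join(_collected))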
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def process_hf_chat_response(response: ChatCompletionOutput) -> LLMChatResponse:
 | 
			
		||||
    """
 | 
			
		||||
    Convert a non-streaming Hugging Face ChatCompletionOutput into our unified LLMChatResponse.
 | 
			
		||||
 | 
			
		||||
    This will:
 | 
			
		||||
      1. Turn the HF dataclass into a plain dict via .model_dump() or .dict().
 | 
			
		||||
      2. Extract each `choice`, build an AssistantMessage (including any tool_calls or
 | 
			
		||||
         function_call shortcuts), wrap in LLMChatCandidate.
 | 
			
		||||
      3. Collect top-level metadata (id, model, usage, etc.) into an LLMChatResponse.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        response: The HFHubInferenceClientBase.chat.completions.create(...) output.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        An LLMChatResponse containing all chat candidates and metadata.
 | 
			
		||||
    """
 | 
			
		||||
    # 1) serialise the HF object to a primitive dict
 | 
			
		||||
    try:
 | 
			
		||||
        if hasattr(response, "model_dump"):
 | 
			
		||||
            resp: Dict[str, Any] = response.model_dump()
 | 
			
		||||
        elif hasattr(response, "dict"):
 | 
			
		||||
            resp: Dict[str, Any] = response.dict()
 | 
			
		||||
        elif dataclasses.is_dataclass(response):
 | 
			
		||||
            resp = dataclasses.asdict(response)
 | 
			
		||||
        elif hasattr(response, "to_dict"):
 | 
			
		||||
            resp = response.to_dict()
 | 
			
		||||
        else:
 | 
			
		||||
            raise TypeError(f"Cannot serialize object of type {type(response)}")
 | 
			
		||||
    except Exception:
 | 
			
		||||
        logger.exception("Failed to serialize HF chat response")
 | 
			
		||||
        resp = {}
 | 
			
		||||
 | 
			
		||||
    candidates = []
 | 
			
		||||
    for choice in resp.get("choices", []):
 | 
			
		||||
        msg = choice.get("message") or {}
 | 
			
		||||
 | 
			
		||||
        # 2) build tool_calls list if present
 | 
			
		||||
        tool_calls: Optional[list[ToolCall]] = None
 | 
			
		||||
        if msg.get("tool_calls"):
 | 
			
		||||
            tool_calls = []
 | 
			
		||||
            for tc in msg["tool_calls"]:
 | 
			
		||||
                try:
 | 
			
		||||
                    tool_calls.append(ToolCall(**tc))
 | 
			
		||||
                except Exception:
 | 
			
		||||
                    logger.exception(f"Invalid HF tool_call entry: {tc}")
 | 
			
		||||
 | 
			
		||||
        # 2b) handle the single‑ID shortcut
 | 
			
		||||
        if msg.get("tool_call_id") and not tool_calls:
 | 
			
		||||
            # HF only sent you an ID; we turn that into a zero‑arg function_call
 | 
			
		||||
            fc = FunctionCall(name=msg["tool_call_id"], arguments="")
 | 
			
		||||
            tool_calls = [
 | 
			
		||||
                ToolCall(id=msg["tool_call_id"], type="function", function=fc)
 | 
			
		||||
            ]
 | 
			
		||||
 | 
			
		||||
        # 3) promote first tool_call into function_call if desired
 | 
			
		||||
        function_call = tool_calls[0].function if tool_calls else None
 | 
			
		||||
 | 
			
		||||
        assistant = AssistantMessage(
 | 
			
		||||
            content=msg.get("content"),
 | 
			
		||||
            refusal=None,
 | 
			
		||||
            tool_calls=tool_calls,
 | 
			
		||||
            function_call=function_call,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        candidates.append(
 | 
			
		||||
            LLMChatCandidate(
 | 
			
		||||
                message=assistant,
 | 
			
		||||
                finish_reason=choice.get("finish_reason"),
 | 
			
		||||
                index=choice.get("index"),
 | 
			
		||||
                logprobs=choice.get("logprobs"),
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    # 4) collect overall metadata
 | 
			
		||||
    metadata = {
 | 
			
		||||
        "provider": "huggingface",
 | 
			
		||||
        "id": resp.get("id"),
 | 
			
		||||
        "model": resp.get("model"),
 | 
			
		||||
        "created": resp.get("created"),
 | 
			
		||||
        "system_fingerprint": resp.get("system_fingerprint"),
 | 
			
		||||
        "usage": resp.get("usage"),
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    return LLMChatResponse(results=candidates, metadata=metadata)
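# --- Illustrative sketch (added; not part of this commit) ---
# Shows the single-ID shortcut above: a message carrying only tool_call_id becomes one
# zero-argument tool call. A dataclass stands in for ChatCompletionOutput so the
# dataclasses.asdict() serialization branch is exercised; field values are placeholders.
@dataclasses.dataclass
class _FakeHFOutput:
    choices: list
    id: str = "hf-1"
    model: str = "stub-model"
    created: int = 0
    system_fingerprint: str = ""
    usage: dict = dataclasses.field(default_factory=dict)

_fake = _FakeHFOutput(
    choices=[
        {
            "index": 0,
            "finish_reason": "tool_calls",
            "message": {"content": None, "tool_call_id": "get_weather"},
        }
    ]
)
_unified = process_hf_chat_response(_fake)
_call = _unified.results[0].message.tool_calls[0]
print(_call.id, _call.function.name)  # -> get_weather get_weather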
 | 
			
		||||
| 
						 | 
				
			
@@ -13,44 +13,61 @@ from typing import (
 | 
			
		|||
    Union,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
from openai.types.chat import ChatCompletionMessage
 | 
			
		||||
from pydantic import BaseModel, Field
 | 
			
		||||
 | 
			
		||||
from dapr_agents.llm.chat import ChatClientBase
 | 
			
		||||
from dapr_agents.llm.nvidia.client import NVIDIAClientBase
 | 
			
		||||
from dapr_agents.llm.utils import RequestHandler, ResponseHandler
 | 
			
		||||
from dapr_agents.prompt.base import PromptTemplateBase
 | 
			
		||||
from dapr_agents.prompt.prompty import Prompty
 | 
			
		||||
from dapr_agents.tool import AgentTool
 | 
			
		||||
from dapr_agents.types.message import BaseMessage, ChatCompletion
 | 
			
		||||
from dapr_agents.types.message import (
 | 
			
		||||
    BaseMessage,
 | 
			
		||||
    LLMChatCandidateChunk,
 | 
			
		||||
    LLMChatResponse,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class NVIDIAChatClient(NVIDIAClientBase, ChatClientBase):
 | 
			
		||||
    """
 | 
			
		||||
    Chat client for NVIDIA chat models.
 | 
			
		||||
    Combines NVIDIA client management with Prompty-specific functionality for handling chat completions.
 | 
			
		||||
    Chat client for NVIDIA chat models, combining NVIDIA client management
 | 
			
		||||
    with Prompty-specific prompt templates and unified request/response handling.
 | 
			
		||||
 | 
			
		||||
    Inherits:
 | 
			
		||||
      - NVIDIAClientBase: manages API key, base_url, retries, etc. for NVIDIA endpoints.
 | 
			
		||||
      - ChatClientBase: provides chat-specific abstractions.
 | 
			
		||||
 | 
			
		||||
    Attributes:
 | 
			
		||||
        model: the model name to use (e.g. "meta/llama-3.1-8b-instruct").
 | 
			
		||||
        max_tokens: maximum number of tokens to generate per call.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    model: str = Field(
 | 
			
		||||
        default="meta/llama3-8b-instruct",
 | 
			
		||||
        description="Model name to use. Defaults to 'meta/llama3-8b-instruct'.",
 | 
			
		||||
        default="meta/llama-3.1-8b-instruct",
 | 
			
		||||
        description="Model name to use. Defaults to 'meta/llama-3.1-8b-instruct'.",
 | 
			
		||||
    )
 | 
			
		||||
    max_tokens: Optional[int] = Field(
 | 
			
		||||
        default=1024,
 | 
			
		||||
        description=(
 | 
			
		||||
            "The maximum number of tokens to generate in any given call. Must be an integer ≥ 1. Defaults to 1024."
 | 
			
		||||
            "Maximum number of tokens to generate in a single call. "
 | 
			
		||||
            "Must be ≥1; defaults to 1024."
 | 
			
		||||
        ),
 | 
			
		||||
    )
 | 
			
		||||
    prompty: Optional[Prompty] = Field(
 | 
			
		||||
        default=None, description="Optional Prompty instance for templating."
 | 
			
		||||
    )
 | 
			
		||||
    prompt_template: Optional[PromptTemplateBase] = Field(
 | 
			
		||||
        default=None, description="Optional prompt-template to format inputs."
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    SUPPORTED_STRUCTURED_MODES: ClassVar[set] = {"function_call"}
 | 
			
		||||
    # NVIDIA currently only supports function_call structured output
 | 
			
		||||
    SUPPORTED_STRUCTURED_MODES: ClassVar[set[str]] = {"function_call"}
 | 
			
		||||
 | 
			
		||||
    def model_post_init(self, __context: Any) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Initializes chat-specific attributes after validation.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            __context (Any): Additional context for post-initialization (not used here).
 | 
			
		||||
        After Pydantic init, configure the client for 'chat' API.
 | 
			
		||||
        """
 | 
			
		||||
        self._api = "chat"
 | 
			
		||||
        super().model_post_init(__context)
 | 
			
		||||
| 
						 | 
				
			
@@ -58,30 +75,23 @@ class NVIDIAChatClient(NVIDIAClientBase, ChatClientBase):
 | 
			
		|||
    @classmethod
 | 
			
		||||
    def from_prompty(cls, prompty_source: Union[str, Path]) -> "NVIDIAChatClient":
 | 
			
		||||
        """
 | 
			
		||||
        Initializes an NVIDIAChatClient client using a Prompty source, which can be a file path or inline content.
 | 
			
		||||
        Build an NVIDIAChatClient from a Prompty spec (file path or inline YAML/JSON).
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            prompty_source (Union[str, Path]): The source of the Prompty file, which can be a path to a file
 | 
			
		||||
                or inline Prompty content as a string.
 | 
			
		||||
            prompty_source: Path or inline content of a Prompty specification.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            NVIDIAChatClient: An instance of NVIDIAChatClient configured with the model settings from the Prompty source.
 | 
			
		||||
            Configured NVIDIAChatClient instance.
 | 
			
		||||
        """
 | 
			
		||||
        # Load the Prompty instance from the provided source
 | 
			
		||||
        prompty_instance = Prompty.load(prompty_source)
 | 
			
		||||
 | 
			
		||||
        # Generate the prompt template from the Prompty instance
 | 
			
		||||
        prompt_template = Prompty.to_prompt_template(prompty_instance)
 | 
			
		||||
        cfg = prompty_instance.model.configuration
 | 
			
		||||
 | 
			
		||||
        # Extract the model configuration from Prompty
 | 
			
		||||
        model_config = prompty_instance.model
 | 
			
		||||
 | 
			
		||||
        # Initialize the NVIDIAChatClient instance using model_validate
 | 
			
		||||
        return cls.model_validate(
 | 
			
		||||
            {
 | 
			
		||||
                "model": model_config.configuration.name,
 | 
			
		||||
                "api_key": model_config.configuration.api_key,
 | 
			
		||||
                "base_url": model_config.configuration.base_url,
 | 
			
		||||
                "model": cfg.name,
 | 
			
		||||
                "api_key": cfg.api_key,
 | 
			
		||||
                "base_url": cfg.base_url,
 | 
			
		||||
                "prompty": prompty_instance,
 | 
			
		||||
                "prompt_template": prompt_template,
 | 
			
		||||
            }
 | 
			
		||||
| 
						 | 
				
			
@@ -95,66 +105,74 @@ class NVIDIAChatClient(NVIDIAClientBase, ChatClientBase):
 | 
			
		|||
            BaseMessage,
 | 
			
		||||
            Iterable[Union[Dict[str, Any], BaseMessage]],
 | 
			
		||||
        ] = None,
 | 
			
		||||
        *,
 | 
			
		||||
        input_data: Optional[Dict[str, Any]] = None,
 | 
			
		||||
        model: Optional[str] = None,
 | 
			
		||||
        tools: Optional[List[Union[AgentTool, Dict[str, Any]]]] = None,
 | 
			
		||||
        response_format: Optional[Type[BaseModel]] = None,
 | 
			
		||||
        max_tokens: Optional[int] = None,
 | 
			
		||||
        structured_mode: Literal["function_call"] = "function_call",
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ) -> Union[Iterator[Dict[str, Any]], ChatCompletion]:
 | 
			
		||||
        stream: bool = False,
 | 
			
		||||
        **kwargs: Any,
 | 
			
		||||
    ) -> Union[
 | 
			
		||||
        Iterator[LLMChatCandidateChunk],  # streaming
 | 
			
		||||
        LLMChatResponse,  # non‑stream + no format
 | 
			
		||||
        BaseModel,  # non‑stream + single structured format
 | 
			
		||||
        List[BaseModel],  # non‑stream + list structured format
 | 
			
		||||
    ]:
 | 
			
		||||
        """
 | 
			
		||||
        Generate chat completions based on provided messages or input_data for prompt templates.
 | 
			
		||||
        Issue a chat completion to NVIDIA.
 | 
			
		||||
 | 
			
		||||
        - If `stream=True`, returns an iterator of
 | 
			
		||||
          `LLMChatCandidateChunk` via ResponseHandler.
 | 
			
		||||
        - Otherwise returns either:
 | 
			
		||||
            • a raw `LLMChatResponse`, or
 | 
			
		||||
            • validated Pydantic model(s) per `response_format`.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            messages (Optional): Either pre-set messages or None if using input_data.
 | 
			
		||||
            input_data (Optional[Dict[str, Any]]): Input variables for prompt templates.
 | 
			
		||||
            model (str): Specific model to use for the request, overriding the default.
 | 
			
		||||
            tools (List[Union[AgentTool, Dict[str, Any]]]): List of tools for the request.
 | 
			
		||||
            response_format (Type[BaseModel]): Optional Pydantic model for structured response parsing.
 | 
			
		||||
            max_tokens (Optional[int]): The maximum number of tokens to generate. Defaults to the instance setting.
 | 
			
		||||
            structured_mode (Literal["function_call"]): Mode for structured output: "function_call" (Limited Support).
 | 
			
		||||
            **kwargs: Additional parameters for the language model.
 | 
			
		||||
            messages:        Pre-built messages or None to use `input_data`.
 | 
			
		||||
            input_data:      Variables for the Prompty template.
 | 
			
		||||
            model:           Override default model name.
 | 
			
		||||
            tools:           List of AgentTool or dict specs.
 | 
			
		||||
            response_format: Pydantic model (or list thereof) for structured output.
 | 
			
		||||
            max_tokens:      Override default max_tokens.
 | 
			
		||||
            structured_mode: Must be "function_call" (only supported mode).
 | 
			
		||||
            stream:          If True, return an iterator of `LLMChatCandidateChunk`.
 | 
			
		||||
            **kwargs:        Other LLM params (temperature, top_p, etc.).
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            Union[Iterator[Dict[str, Any]], ChatCompletion]: The chat completion response(s).
 | 
			
		||||
        Raises:
 | 
			
		||||
            ValueError: for invalid structured_mode or missing inputs.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        # 1) Validate structured_mode
 | 
			
		||||
        if structured_mode not in self.SUPPORTED_STRUCTURED_MODES:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                f"Invalid structured_mode '{structured_mode}'. Must be one of {self.SUPPORTED_STRUCTURED_MODES}."
 | 
			
		||||
                f"structured_mode must be one of {self.SUPPORTED_STRUCTURED_MODES}"
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # If input_data is provided, check for a prompt_template
 | 
			
		||||
        # 2) If input_data is provided, format messages via Prompty
 | 
			
		||||
        if input_data:
 | 
			
		||||
            if not self.prompt_template:
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    "Inputs are provided but no 'prompt_template' is set. Please set a 'prompt_template' to use the input_data."
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            logger.info("Using prompt template to generate messages.")
 | 
			
		||||
                raise ValueError("input_data provided but no prompt_template is set.")
 | 
			
		||||
            logger.info("Formatting messages via prompt_template.")
 | 
			
		||||
            messages = self.prompt_template.format_prompt(**input_data)
 | 
			
		||||
 | 
			
		||||
        # Ensure we have messages at this point
 | 
			
		||||
        if not messages:
 | 
			
		||||
            raise ValueError("Either 'messages' or 'input_data' must be provided.")
 | 
			
		||||
 | 
			
		||||
        # Process and normalize the messages
 | 
			
		||||
        params = {"messages": RequestHandler.normalize_chat_messages(messages)}
 | 
			
		||||
 | 
			
		||||
        # Merge prompty parameters if available, then override with any explicit kwargs
 | 
			
		||||
        # 3) Normalize messages + merge client/prompty defaults
 | 
			
		||||
        params: Dict[str, Any] = {
 | 
			
		||||
            "messages": RequestHandler.normalize_chat_messages(messages)
 | 
			
		||||
        }
 | 
			
		||||
        if self.prompty:
 | 
			
		||||
            params = {**self.prompty.model.parameters.model_dump(), **params, **kwargs}
 | 
			
		||||
        else:
 | 
			
		||||
            params.update(kwargs)
 | 
			
		||||
 | 
			
		||||
        # If a model is provided, override the default model
 | 
			
		||||
        # 4) Override model & max_tokens if provided
 | 
			
		||||
        params["model"] = model or self.model
 | 
			
		||||
 | 
			
		||||
        # Apply max_tokens if provided
 | 
			
		||||
        params["max_tokens"] = max_tokens or self.max_tokens
 | 
			
		||||
 | 
			
		||||
        # Prepare request parameters
 | 
			
		||||
        # 5) Inject tools / response_format / structured_mode
 | 
			
		||||
        params = RequestHandler.process_params(
 | 
			
		||||
            params,
 | 
			
		||||
            llm_provider=self.provider,
 | 
			
		||||
| 
						 | 
				
			
@@ -163,21 +181,18 @@ class NVIDIAChatClient(NVIDIAClientBase, ChatClientBase):
 | 
			
		|||
            structured_mode=structured_mode,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # 6) Call NVIDIA API + dispatch to ResponseHandler
 | 
			
		||||
        try:
 | 
			
		||||
            logger.info("Invoking ChatCompletion API.")
 | 
			
		||||
            logger.debug(f"ChatCompletion API Parameters:{params}")
 | 
			
		||||
            response: ChatCompletionMessage = self.client.chat.completions.create(
 | 
			
		||||
                **params
 | 
			
		||||
            )
 | 
			
		||||
            logger.info("Chat completion retrieved successfully.")
 | 
			
		||||
 | 
			
		||||
            logger.info("Calling NVIDIA ChatCompletion API.")
 | 
			
		||||
            logger.debug(f"Parameters: {params}")
 | 
			
		||||
            resp = self.client.chat.completions.create(**params, stream=stream)
 | 
			
		||||
            return ResponseHandler.process_response(
 | 
			
		||||
                response,
 | 
			
		||||
                response=resp,
 | 
			
		||||
                llm_provider=self.provider,
 | 
			
		||||
                response_format=response_format,
 | 
			
		||||
                structured_mode=structured_mode,
 | 
			
		||||
                stream=params.get("stream", False),
 | 
			
		||||
                stream=stream,
 | 
			
		||||
            )
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error(f"An error occurred during the ChatCompletion API call: {e}")
 | 
			
		||||
        except Exception:
 | 
			
		||||
            logger.error("NVIDIA ChatCompletion error", exc_info=True)
 | 
			
		||||
            raise
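# --- Hedged usage sketch (added for illustration; not part of this commit) ---
# Requests structured output from NVIDIAChatClient via a Pydantic response_format.
# The import path is an assumption and credentials are assumed to come from the base
# client configuration; the generate() parameters come from the signature above.
from pydantic import BaseModel

from dapr_agents.llm.nvidia import NVIDIAChatClient  # assumed import path


class _Ingredient(BaseModel):
    name: str
    quantity: str


nv = NVIDIAChatClient()  # defaults to the meta/llama-3.1-8b-instruct model
result = nv.generate(
    messages=[{"role": "user", "content": "Name one pancake ingredient with a quantity."}],
    response_format=_Ingredient,  # parsed via the "function_call" structured mode
    max_tokens=256,
)
# With response_format set, generate() returns validated Pydantic instance(s).
print(result)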
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
@@ -13,44 +13,58 @@ from typing import (
 | 
			
		|||
    Union,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
from openai.types.chat import ChatCompletionMessage
 | 
			
		||||
from pydantic import BaseModel, Field, model_validator
 | 
			
		||||
 | 
			
		||||
from dapr_agents.llm.chat import ChatClientBase
 | 
			
		||||
from dapr_agents.llm.openai.client.base import OpenAIClientBase
 | 
			
		||||
from dapr_agents.llm.utils import RequestHandler, ResponseHandler
 | 
			
		||||
from dapr_agents.prompt.base import PromptTemplateBase
 | 
			
		||||
from dapr_agents.prompt.prompty import Prompty
 | 
			
		||||
from dapr_agents.tool import AgentTool
 | 
			
		||||
from dapr_agents.types.llm import AzureOpenAIModelConfig, OpenAIModelConfig
 | 
			
		||||
from dapr_agents.types.message import BaseMessage, ChatCompletion
 | 
			
		||||
from dapr_agents.types.message import (
 | 
			
		||||
    BaseMessage,
 | 
			
		||||
    LLMChatCandidateChunk,
 | 
			
		||||
    LLMChatResponse,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class OpenAIChatClient(OpenAIClientBase, ChatClientBase):
 | 
			
		||||
    """
 | 
			
		||||
    Chat client for OpenAI models.
 | 
			
		||||
    Combines OpenAI client management with Prompty-specific functionality.
 | 
			
		||||
    Chat client for OpenAI models, layering in Prompty-driven prompt templates
 | 
			
		||||
    and unified request/response handling.
 | 
			
		||||
 | 
			
		||||
    Inherits:
 | 
			
		||||
      - OpenAIClientBase: manages API key, base_url, retries, etc.
 | 
			
		||||
      - ChatClientBase: provides chat-specific abstractions.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    model: str = Field(default=None, description="Model name to use, e.g., 'gpt-4'.")
 | 
			
		||||
    model: Optional[str] = Field(
 | 
			
		||||
        default=None, description="Model name or Azure deployment ID."
 | 
			
		||||
    )
 | 
			
		||||
    prompty: Optional[Prompty] = Field(
 | 
			
		||||
        default=None, description="Optional Prompty instance for templating."
 | 
			
		||||
    )
 | 
			
		||||
    prompt_template: Optional[PromptTemplateBase] = Field(
 | 
			
		||||
        default=None, description="Optional prompt-template to format inputs."
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    SUPPORTED_STRUCTURED_MODES: ClassVar[set] = {"json", "function_call"}
 | 
			
		||||
    SUPPORTED_STRUCTURED_MODES: ClassVar[set[str]] = {"json", "function_call"}
 | 
			
		||||
 | 
			
		||||
    @model_validator(mode="before")
 | 
			
		||||
    def validate_and_initialize(cls, values: Dict[str, Any]) -> Dict[str, Any]:
 | 
			
		||||
        """
 | 
			
		||||
        Ensures the 'model' is set during validation.
 | 
			
		||||
        Uses 'azure_deployment' if no model is specified, defaults to 'gpt-4o'.
 | 
			
		||||
        Ensure `.model` is always set.  If unset, fall back to `azure_deployment`
 | 
			
		||||
        or default to `"gpt-4o"`.
 | 
			
		||||
        """
 | 
			
		||||
        if "model" not in values or values["model"] is None:
 | 
			
		||||
        if not values.get("model"):
 | 
			
		||||
            values["model"] = values.get("azure_deployment", "gpt-4o")
 | 
			
		||||
        return values
 | 
			
		||||
 | 
			
		||||
    def model_post_init(self, __context: Any) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Initializes chat-specific attributes after validation.
 | 
			
		||||
        """
 | 
			
		||||
        """After Pydantic init, ensure we're in the “chat” API mode."""
 | 
			
		||||
        self._api = "chat"
 | 
			
		||||
        super().model_post_init(__context)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
@@ -61,60 +75,54 @@ class OpenAIChatClient(OpenAIClientBase, ChatClientBase):
 | 
			
		|||
        timeout: Union[int, float, Dict[str, Any]] = 1500,
 | 
			
		||||
    ) -> "OpenAIChatClient":
 | 
			
		||||
        """
 | 
			
		||||
        Initializes an OpenAIChatClient client using a Prompty source, which can be a file path or inline content.
 | 
			
		||||
        Load a Prompty file (or inline YAML/JSON string), extract its
 | 
			
		||||
        model configuration and prompt template, and return a fully-wired client.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            prompty_source (Union[str, Path]): The source of the Prompty file, which can be a path to a file
 | 
			
		||||
                or inline Prompty content as a string.
 | 
			
		||||
            timeout (Union[int, float, Dict[str, Any]], optional): Timeout for requests, defaults to 1500 seconds.
 | 
			
		||||
            prompty_source: path or inline text for a Prompty spec.
 | 
			
		||||
            timeout:        seconds or HTTPX-style timeout, defaults to 1500.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            OpenAIChatClient: An instance of OpenAIChatClient configured with the model settings from the Prompty source.
 | 
			
		||||
            Configured OpenAIChatClient.
 | 
			
		||||
        """
 | 
			
		||||
        # Load the Prompty instance from the provided source
 | 
			
		||||
        prompty_instance = Prompty.load(prompty_source)
 | 
			
		||||
 | 
			
		||||
        # Generate the prompt template from the Prompty instance
 | 
			
		||||
        prompt_template = Prompty.to_prompt_template(prompty_instance)
 | 
			
		||||
        cfg = prompty_instance.model.configuration
 | 
			
		||||
 | 
			
		||||
        # Extract the model configuration from Prompty
 | 
			
		||||
        model_config = prompty_instance.model
 | 
			
		||||
        common = {
 | 
			
		||||
            "timeout": timeout,
 | 
			
		||||
            "prompty": prompty_instance,
 | 
			
		||||
            "prompt_template": prompt_template,
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        # Initialize the OpenAIChatClient instance using model_validate
 | 
			
		||||
        if isinstance(model_config.configuration, OpenAIModelConfig):
 | 
			
		||||
        if isinstance(cfg, OpenAIModelConfig):
 | 
			
		||||
            return cls.model_validate(
 | 
			
		||||
                {
 | 
			
		||||
                    "model": model_config.configuration.name,
 | 
			
		||||
                    "api_key": model_config.configuration.api_key,
 | 
			
		||||
                    "base_url": model_config.configuration.base_url,
 | 
			
		||||
                    "organization": model_config.configuration.organization,
 | 
			
		||||
                    "project": model_config.configuration.project,
 | 
			
		||||
                    "timeout": timeout,
 | 
			
		||||
                    "prompty": prompty_instance,
 | 
			
		||||
                    "prompt_template": prompt_template,
 | 
			
		||||
                    **common,
 | 
			
		||||
                    "model": cfg.name,
 | 
			
		||||
                    "api_key": cfg.api_key,
 | 
			
		||||
                    "base_url": cfg.base_url,
 | 
			
		||||
                    "organization": cfg.organization,
 | 
			
		||||
                    "project": cfg.project,
 | 
			
		||||
                }
 | 
			
		||||
            )
 | 
			
		||||
        elif isinstance(model_config.configuration, AzureOpenAIModelConfig):
 | 
			
		||||
        elif isinstance(cfg, AzureOpenAIModelConfig):
 | 
			
		||||
            return cls.model_validate(
 | 
			
		||||
                {
 | 
			
		||||
                    "model": model_config.configuration.azure_deployment,
 | 
			
		||||
                    "api_key": model_config.configuration.api_key,
 | 
			
		||||
                    "azure_endpoint": model_config.configuration.azure_endpoint,
 | 
			
		||||
                    "azure_deployment": model_config.configuration.azure_deployment,
 | 
			
		||||
                    "api_version": model_config.configuration.api_version,
 | 
			
		||||
                    "organization": model_config.configuration.organization,
 | 
			
		||||
                    "project": model_config.configuration.project,
 | 
			
		||||
                    "azure_ad_token": model_config.configuration.azure_ad_token,
 | 
			
		||||
                    "azure_client_id": model_config.configuration.azure_client_id,
 | 
			
		||||
                    "timeout": timeout,
 | 
			
		||||
                    "prompty": prompty_instance,
 | 
			
		||||
                    "prompt_template": prompt_template,
 | 
			
		||||
                    **common,
 | 
			
		||||
                    "model": cfg.azure_deployment,
 | 
			
		||||
                    "api_key": cfg.api_key,
 | 
			
		||||
                    "azure_endpoint": cfg.azure_endpoint,
 | 
			
		||||
                    "azure_deployment": cfg.azure_deployment,
 | 
			
		||||
                    "api_version": cfg.api_version,
 | 
			
		||||
                    "organization": cfg.organization,
 | 
			
		||||
                    "project": cfg.project,
 | 
			
		||||
                    "azure_ad_token": cfg.azure_ad_token,
 | 
			
		||||
                    "azure_client_id": cfg.azure_client_id,
 | 
			
		||||
                }
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                f"Unsupported model configuration type: {type(model_config.configuration)}"
 | 
			
		||||
            )
 | 
			
		||||
            raise ValueError(f"Unsupported model config: {type(cfg)}")
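As a rough illustration, a client can be wired from a Prompty spec like this; the file name and import path are assumptions:

from dapr_agents.llm.openai import OpenAIChatClient  # import path assumed

client = OpenAIChatClient.from_prompty("prompts/basic.prompty", timeout=60)
# client.prompty, client.prompt_template, client.model and the credentials now come
# from the Prompty's model configuration (OpenAI or Azure OpenAI).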
 | 
			
		||||
 | 
			
		||||
    def generate(
 | 
			
		||||
        self,
 | 
			
		||||
			@ -124,61 +132,72 @@ class OpenAIChatClient(OpenAIClientBase, ChatClientBase):
 | 
			
		|||
            BaseMessage,
 | 
			
		||||
            Iterable[Union[Dict[str, Any], BaseMessage]],
 | 
			
		||||
        ] = None,
 | 
			
		||||
        *,
 | 
			
		||||
        input_data: Optional[Dict[str, Any]] = None,
 | 
			
		||||
        model: Optional[str] = None,
 | 
			
		||||
        tools: Optional[List[Union[AgentTool, Dict[str, Any]]]] = None,
 | 
			
		||||
        response_format: Optional[Type[BaseModel]] = None,
 | 
			
		||||
        structured_mode: Literal["json", "function_call"] = "json",
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ) -> Union[Iterator[Dict[str, Any]], ChatCompletion]:
 | 
			
		||||
        stream: bool = False,
 | 
			
		||||
        **kwargs: Any,
 | 
			
		||||
    ) -> Union[
 | 
			
		||||
        Iterator[LLMChatCandidateChunk],
 | 
			
		||||
        LLMChatResponse,
 | 
			
		||||
        BaseModel,
 | 
			
		||||
        List[BaseModel],
 | 
			
		||||
    ]:
 | 
			
		||||
        """
 | 
			
		||||
        Generate chat completions based on provided messages or input_data for prompt templates.
 | 
			
		||||
        Issue a chat completion.
 | 
			
		||||
 | 
			
		||||
        - If `stream=True` in params, returns an iterator of `LLMChatCandidateChunk`.
 | 
			
		||||
        - Otherwise returns either:
 | 
			
		||||
            • raw `AssistantMessage` wrapped in `LLMChatResponse`, or
 | 
			
		||||
            • validated Pydantic model(s) per `response_format`.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            messages (Optional): Either pre-set messages or None if using input_data.
 | 
			
		||||
            input_data (Optional[Dict[str, Any]]): Input variables for prompt templates.
 | 
			
		||||
            model (str): Specific model to use for the request, overriding the default.
 | 
			
		||||
            tools (List[Union[AgentTool, Dict[str, Any]]]): List of tools for the request.
 | 
			
		||||
            response_format (Type[BaseModel]): Optional Pydantic model for structured response parsing.
 | 
			
		||||
            structured_mode (Literal["json", "function_call"]): Mode for structured output: "json" or "function_call".
 | 
			
		||||
            **kwargs: Additional parameters for the language model.
 | 
			
		||||
            messages:        pre-built messages or None to use `input_data`.
 | 
			
		||||
            input_data:      variables for the Prompty template.
 | 
			
		||||
            model:           override client's default model.
 | 
			
		||||
            tools:           list of AgentTool or dict specs.
 | 
			
		||||
            response_format: Pydantic model (or list thereof) for structured output.
 | 
			
		||||
            structured_mode: “json” or “function_call” (non-stream only).
 | 
			
		||||
            stream:          if True, return an iterator of `LLMChatCandidateChunk`.
 | 
			
		||||
            **kwargs:        any other LLM params (temperature, top_p, stream, etc.).
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            Union[Iterator[Dict[str, Any]], ChatCompletion]: The chat completion response(s).
 | 
			
		||||
        """
 | 
			
		||||
            • `Iterator[LLMChatCandidateChunk]` if streaming
 | 
			
		||||
            • `LLMChatResponse` or Pydantic instance(s) if non-streaming
 | 
			
		||||
 | 
			
		||||
        Raises:
 | 
			
		||||
            ValueError: on invalid `structured_mode`, missing prompts, etc.
 | 
			
		||||
        """
 | 
			
		||||
        # 1) Validate structured_mode
 | 
			
		||||
        if structured_mode not in self.SUPPORTED_STRUCTURED_MODES:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                f"Invalid structured_mode '{structured_mode}'. Must be one of {self.SUPPORTED_STRUCTURED_MODES}."
 | 
			
		||||
                f"structured_mode must be one of {self.SUPPORTED_STRUCTURED_MODES}"
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # If input_data is provided, check for a prompt_template
 | 
			
		||||
        # 2) If using a prompt template, build messages
 | 
			
		||||
        if input_data:
 | 
			
		||||
            if not self.prompt_template:
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    "Inputs are provided but no 'prompt_template' is set. Please set a 'prompt_template' to use the input_data."
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            logger.info("Using prompt template to generate messages.")
 | 
			
		||||
                raise ValueError("No prompt_template set for input_data usage.")
 | 
			
		||||
            logger.info("Formatting messages via prompt_template.")
 | 
			
		||||
            messages = self.prompt_template.format_prompt(**input_data)
 | 
			
		||||
 | 
			
		||||
        # Ensure we have messages at this point
 | 
			
		||||
        if not messages:
 | 
			
		||||
            raise ValueError("Either 'messages' or 'input_data' must be provided.")
 | 
			
		||||
            raise ValueError("Either messages or input_data must be provided.")
 | 
			
		||||
 | 
			
		||||
        # Process and normalize the messages
 | 
			
		||||
        # 3) Normalize messages + merge client/prompty defaults
 | 
			
		||||
        params = {"messages": RequestHandler.normalize_chat_messages(messages)}
 | 
			
		||||
 | 
			
		||||
        # Merge prompty parameters if available, then override with any explicit kwargs
 | 
			
		||||
        if self.prompty:
 | 
			
		||||
            params = {**self.prompty.model.parameters.model_dump(), **params, **kwargs}
 | 
			
		||||
        else:
 | 
			
		||||
            params.update(kwargs)
 | 
			
		||||
 | 
			
		||||
        # If a model is provided, override the default model
 | 
			
		||||
        # 4) Override model if given
 | 
			
		||||
        params["model"] = model or self.model
 | 
			
		||||
 | 
			
		||||
        # Prepare request parameters
 | 
			
		||||
        # 5) Let RequestHandler inject tools / response_format / structured_mode
 | 
			
		||||
        params = RequestHandler.process_params(
 | 
			
		||||
            params,
 | 
			
		||||
            llm_provider=self.provider,
 | 
			
		||||
			@ -187,21 +206,21 @@ class OpenAIChatClient(OpenAIClientBase, ChatClientBase):
 | 
			
		|||
            structured_mode=structured_mode,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # 6) Call API + hand off to ResponseHandler
 | 
			
		||||
        try:
 | 
			
		||||
            logger.info("Invoking ChatCompletion API.")
 | 
			
		||||
            logger.debug(f"ChatCompletion API Parameters: {params}")
 | 
			
		||||
            response: ChatCompletionMessage = self.client.chat.completions.create(
 | 
			
		||||
                **params, timeout=self.timeout
 | 
			
		||||
            logger.info("Calling OpenAI ChatCompletion...")
 | 
			
		||||
            logger.debug(f"ChatCompletion params: {params}")
 | 
			
		||||
            resp = self.client.chat.completions.create(
 | 
			
		||||
                **params, stream=stream, timeout=self.timeout
 | 
			
		||||
            )
 | 
			
		||||
            logger.info("Chat completion retrieved successfully.")
 | 
			
		||||
 | 
			
		||||
            logger.info("ChatCompletion response received.")
 | 
			
		||||
            return ResponseHandler.process_response(
 | 
			
		||||
                response,
 | 
			
		||||
                response=resp,
 | 
			
		||||
                llm_provider=self.provider,
 | 
			
		||||
                response_format=response_format,
 | 
			
		||||
                structured_mode=structured_mode,
 | 
			
		||||
                stream=params.get("stream", False),
 | 
			
		||||
                stream=stream,
 | 
			
		||||
            )
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error(f"An error occurred during the ChatCompletion API call: {e}")
 | 
			
		||||
            raise
 | 
			
		||||
            logger.error("ChatCompletion API error", exc_info=True)
 | 
			
		||||
            raise ValueError("Failed to process chat completion") from e
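A hedged sketch of the two call shapes exposed by generate(); the client construction, import path, and prompt are assumptions:

from dapr_agents.llm.openai import OpenAIChatClient  # import path assumed
from dapr_agents.types.message import LLMChatResponse

client = OpenAIChatClient(api_key="sk-...")  # hypothetical key

# Non-streaming: a unified LLMChatResponse carrying one or more candidates.
resp = client.generate(messages=[{"role": "user", "content": "Say hello"}])
if isinstance(resp, LLMChatResponse):
    print(resp.results[0].message.content)

# Streaming: an iterator of chunk objects, yielded as tokens arrive.
for chunk in client.generate(
    messages=[{"role": "user", "content": "Say hello"}], stream=True
):
    # Depending on the chunk type, the token text sits on the chunk itself or under .result.
    piece = getattr(chunk, "content", None) or getattr(
        getattr(chunk, "result", None), "content", None
    )
    if piece:
        print(piece, end="", flush=True)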
			@ -0,0 +1,266 @@
 | 
			
		|||
import dataclasses
 | 
			
		||||
import logging
 | 
			
		||||
from typing import Any, Callable, Dict, Iterator, Optional
 | 
			
		||||
 | 
			
		||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
 | 
			
		||||
 | 
			
		||||
from dapr_agents.types.message import (
 | 
			
		||||
    AssistantMessage,
 | 
			
		||||
    FunctionCall,
 | 
			
		||||
    LLMChatCandidate,
 | 
			
		||||
    LLMChatCandidateChunk,
 | 
			
		||||
    LLMChatResponseChunk,
 | 
			
		||||
    LLMChatResponse,
 | 
			
		||||
    ToolCall,
 | 
			
		||||
    ToolCallChunk,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Helper function to handle metadata extraction
 | 
			
		||||
def _get_packet_metadata(
 | 
			
		||||
    pkt: Dict[str, Any], enrich_metadata: Optional[Dict[str, Any]]
 | 
			
		||||
) -> Dict[str, Any]:
 | 
			
		||||
    """
 | 
			
		||||
    Extract metadata from OpenAI packet and merge with enrich_metadata.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        pkt (Dict[str, Any]): The OpenAI packet from which to extract metadata.
 | 
			
		||||
        enrich_metadata (Optional[Dict[str, Any]]): Additional metadata to merge with the extracted metadata.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        Dict[str, Any]: The merged metadata dictionary.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        return {
 | 
			
		||||
            "id": pkt.get("id"),
 | 
			
		||||
            "created": pkt.get("created"),
 | 
			
		||||
            "model": pkt.get("model"),
 | 
			
		||||
            "object": pkt.get("object"),
 | 
			
		||||
            "service_tier": pkt.get("service_tier"),
 | 
			
		||||
            "system_fingerprint": pkt.get("system_fingerprint"),
 | 
			
		||||
            **(enrich_metadata or {}),
 | 
			
		||||
        }
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        logger.error(f"Failed to parse packet: {e}", exc_info=True)
 | 
			
		||||
        return {}
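For illustration, the merge keeps the packet's envelope fields and layers the caller's extras on top (all values below are hypothetical):

pkt = {
    "id": "chatcmpl-abc",
    "created": 1710000000,
    "model": "gpt-4o",
    "object": "chat.completion.chunk",
}
meta = _get_packet_metadata(pkt, enrich_metadata={"provider": "openai"})
# meta carries the packet fields (missing ones as None, e.g. "service_tier")
# plus the caller-supplied "provider" key.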
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Helper function to process each choice delta (content, function call, tool call, finish reason)
 | 
			
		||||
def _process_choice_delta(
 | 
			
		||||
    choice: Dict[str, Any],
 | 
			
		||||
    overall_meta: Dict[str, Any],
 | 
			
		||||
    on_chunk: Optional[Callable],
 | 
			
		||||
    first_chunk_flag: bool,
 | 
			
		||||
) -> Iterator[LLMChatResponseChunk]:
 | 
			
		||||
    """
 | 
			
		||||
    Process each choice delta and yield corresponding chunks.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        choice (Dict[str, Any]): The choice delta from OpenAI response.
 | 
			
		||||
        overall_meta (Dict[str, Any]): Overall metadata to include in chunks.
 | 
			
		||||
        on_chunk (Optional[Callable]): Callback for each chunk.
 | 
			
		||||
        first_chunk_flag (bool): Flag indicating if this is the first chunk.
 | 
			
		||||
 | 
			
		||||
    Yields:
 | 
			
		||||
        LLMChatResponseChunk: The processed chunk with content, function call, tool calls, and finish reason.
 | 
			
		||||
    """
 | 
			
		||||
    # Make an immutable snapshot for this single chunk
 | 
			
		||||
    meta = {**overall_meta}
 | 
			
		||||
 | 
			
		||||
    # mark first_chunk exactly once
 | 
			
		||||
    if first_chunk_flag and "first_chunk" not in meta:
 | 
			
		||||
        meta["first_chunk"] = True
 | 
			
		||||
 | 
			
		||||
    # Extract initial properties from choice
 | 
			
		||||
    delta: dict = choice.get("delta", {})
 | 
			
		||||
    idx = choice.get("index")
 | 
			
		||||
    finish_reason = choice.get("finish_reason", None)
 | 
			
		||||
    logprobs = choice.get("logprobs", None)
 | 
			
		||||
 | 
			
		||||
    # Set additional metadata
 | 
			
		||||
    if finish_reason in ("stop", "tool_calls"):
 | 
			
		||||
        meta["last_chunk"] = True
 | 
			
		||||
 | 
			
		||||
    # Process content delta
 | 
			
		||||
    content = delta.get("content", None)
 | 
			
		||||
    function_call = delta.get("function_call", None)
 | 
			
		||||
    refusal = delta.get("refusal", None)
 | 
			
		||||
    role = delta.get("role", None)
 | 
			
		||||
 | 
			
		||||
    # Process tool calls
 | 
			
		||||
    chunk_tool_calls = [ToolCallChunk(**tc) for tc in (delta.get("tool_calls") or [])]
 | 
			
		||||
 | 
			
		||||
    # Initialize LLMChatResponseChunk
 | 
			
		||||
    response_chunk = LLMChatResponseChunk(
 | 
			
		||||
        result=LLMChatCandidateChunk(
 | 
			
		||||
            content=content,
 | 
			
		||||
            function_call=function_call,
 | 
			
		||||
            refusal=refusal,
 | 
			
		||||
            role=role,
 | 
			
		||||
            tool_calls=chunk_tool_calls,
 | 
			
		||||
            finish_reason=finish_reason,
 | 
			
		||||
            index=idx,
 | 
			
		||||
            logprobs=logprobs,
 | 
			
		||||
        ),
 | 
			
		||||
        metadata=meta,
 | 
			
		||||
    )
 | 
			
		||||
    # Process chunk with on_chunk callback
 | 
			
		||||
    if on_chunk:
 | 
			
		||||
        on_chunk(response_chunk)
 | 
			
		||||
    # Yield LLMChatResponseChunk
 | 
			
		||||
    yield response_chunk
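A worked example of a single streamed token passing through this helper; the delta values are hypothetical:

choice = {"index": 0, "delta": {"role": "assistant", "content": "Hel"}, "finish_reason": None}
chunk = next(
    _process_choice_delta(choice, {"model": "gpt-4o"}, on_chunk=None, first_chunk_flag=True)
)
# chunk.result.content == "Hel"
# chunk.metadata includes {"model": "gpt-4o", "first_chunk": True}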
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Main function to process OpenAI streaming response
 | 
			
		||||
def process_openai_stream(
 | 
			
		||||
    raw_stream: Iterator[ChatCompletionChunk],
 | 
			
		||||
    *,
 | 
			
		||||
    enrich_metadata: Optional[Dict[str, Any]] = None,
 | 
			
		||||
    on_chunk: Optional[Callable],
 | 
			
		||||
) -> Iterator[LLMChatCandidateChunk]:
 | 
			
		||||
    """
 | 
			
		||||
    Normalize OpenAI streaming chat into LLMChatCandidateChunk objects,
 | 
			
		||||
    accumulating buffers per choice and yielding both partial and final chunks.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        raw_stream: Iterator from client.chat.completions.create(..., stream=True)
 | 
			
		||||
        enrich_metadata: Extra key/value pairs to merge into each chunk.metadata
 | 
			
		||||
        on_chunk:   Callback fired on every partial delta (token, function, tool)
 | 
			
		||||
 | 
			
		||||
    Yields:
 | 
			
		||||
        LLMChatCandidateChunk for every partial and final piece, in stream order
 | 
			
		||||
    """
 | 
			
		||||
    enrich_metadata = enrich_metadata or {}
 | 
			
		||||
    overall_meta: Dict[str, Any] = {}
 | 
			
		||||
 | 
			
		||||
    # Track if we are in the first chunk
 | 
			
		||||
    first_chunk_flag = True
 | 
			
		||||
 | 
			
		||||
    for packet in raw_stream:
 | 
			
		||||
        # Convert Pydantic / OpenAIObject → plain dict
 | 
			
		||||
        if hasattr(packet, "model_dump"):
 | 
			
		||||
            pkt = packet.model_dump()
 | 
			
		||||
        elif hasattr(packet, "to_dict"):
 | 
			
		||||
            pkt = packet.to_dict()
 | 
			
		||||
        elif dataclasses.is_dataclass(packet):
 | 
			
		||||
            pkt = dataclasses.asdict(packet)
 | 
			
		||||
        else:
 | 
			
		||||
            raise TypeError(f"Cannot serialize packet of type {type(packet)}")
 | 
			
		||||
 | 
			
		||||
        # Capture overall metadata from the packet
 | 
			
		||||
        overall_meta = _get_packet_metadata(pkt, enrich_metadata)
 | 
			
		||||
 | 
			
		||||
        # Process each choice in this packet
 | 
			
		||||
        if choices := pkt.get("choices"):
 | 
			
		||||
            if len(choices) == 0:
 | 
			
		||||
                logger.warning("Received empty 'choices' in OpenAI packet, skipping.")
 | 
			
		||||
                continue
 | 
			
		||||
            # Process the first choice in the packet
 | 
			
		||||
            choice = choices[0]
 | 
			
		||||
            yield from _process_choice_delta(
 | 
			
		||||
                choice, overall_meta, on_chunk, first_chunk_flag
 | 
			
		||||
            )
 | 
			
		||||
            # Set first_chunk_flag to False after processing the first choice
 | 
			
		||||
            first_chunk_flag = False
 | 
			
		||||
        else:
 | 
			
		||||
            logger.warning(f"Yielding packet without 'choices': {pkt}")
 | 
			
		||||
            # Initialize default LLMChatResponseChunk
 | 
			
		||||
            final_response_chunk = LLMChatResponseChunk(metadata=overall_meta)
 | 
			
		||||
            yield final_response_chunk
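A hedged sketch of driving this normalizer straight from the OpenAI SDK stream; the model and prompt are placeholders:

from openai import OpenAI

raw = OpenAI().chat.completions.create(
    model="gpt-4o",
    messages=[{"role": "user", "content": "Stream a short reply"}],
    stream=True,
)
for chunk in process_openai_stream(raw, enrich_metadata={"request_id": "demo"}, on_chunk=None):
    # Each chunk wraps an LLMChatCandidateChunk under .result (see the helper above).
    if chunk.result and chunk.result.content:
        print(chunk.result.content, end="", flush=True)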
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def process_openai_chat_response(openai_response: ChatCompletion) -> LLMChatResponse:
 | 
			
		||||
    """
 | 
			
		||||
    Convert an OpenAI ChatCompletion into our unified LLMChatResponse.
 | 
			
		||||
 | 
			
		||||
    This function:
 | 
			
		||||
      - Safely extracts each choice (skipping malformed ones)
 | 
			
		||||
      - Builds an AssistantMessage with content/refusal/tool_calls/function_call
 | 
			
		||||
      - Wraps into LLMChatCandidate (including index & logprobs)
 | 
			
		||||
      - Collects provider metadata
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        openai_response: A Pydantic ChatCompletion from the OpenAI SDK.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        LLMChatResponse: Contains a list of candidates and a metadata dict.
 | 
			
		||||
    """
 | 
			
		||||
    # 1) Turn into plain dict
 | 
			
		||||
    try:
 | 
			
		||||
        if hasattr(openai_response, "model_dump"):
 | 
			
		||||
            resp = openai_response.model_dump()
 | 
			
		||||
        elif hasattr(openai_response, "to_dict"):
 | 
			
		||||
            resp = openai_response.to_dict()
 | 
			
		||||
        elif dataclasses.is_dataclass(openai_response):
 | 
			
		||||
            resp = dataclasses.asdict(openai_response)
 | 
			
		||||
        else:
 | 
			
		||||
            resp = dict(openai_response)
 | 
			
		||||
    except Exception:
 | 
			
		||||
        logger.exception("Failed to serialize OpenAI chat response")
 | 
			
		||||
        resp = {}
 | 
			
		||||
 | 
			
		||||
    candidates = []
 | 
			
		||||
    for choice in resp.get("choices", []):
 | 
			
		||||
        if "message" not in choice:
 | 
			
		||||
            logger.warning(f"Skipping choice missing 'message': {choice}")
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        msg = choice["message"]
 | 
			
		||||
        # 2) Build tool_calls list if present
 | 
			
		||||
        tool_calls = None
 | 
			
		||||
        if msg.get("tool_calls"):
 | 
			
		||||
            tool_calls = []
 | 
			
		||||
            for tc in msg["tool_calls"]:
 | 
			
		||||
                try:
 | 
			
		||||
                    tool_calls.append(
 | 
			
		||||
                        ToolCall(
 | 
			
		||||
                            id=tc["id"],
 | 
			
		||||
                            type=tc["type"],
 | 
			
		||||
                            function=FunctionCall(
 | 
			
		||||
                                name=tc["function"]["name"],
 | 
			
		||||
                                arguments=tc["function"]["arguments"],
 | 
			
		||||
                            ),
 | 
			
		||||
                        )
 | 
			
		||||
                    )
 | 
			
		||||
                except Exception as e:
 | 
			
		||||
                    logger.warning(f"Invalid tool_call entry {tc}: {e}")
 | 
			
		||||
 | 
			
		||||
        # 3) Build function_call if present
 | 
			
		||||
        function_call = None
 | 
			
		||||
        if fc := msg.get("function_call"):
 | 
			
		||||
            function_call = FunctionCall(
 | 
			
		||||
                name=fc.get("name", ""),
 | 
			
		||||
                arguments=fc.get("arguments", ""),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # 4) Assemble AssistantMessage
 | 
			
		||||
        assistant_message = AssistantMessage(
 | 
			
		||||
            content=msg.get("content"),
 | 
			
		||||
            refusal=msg.get("refusal"),
 | 
			
		||||
            tool_calls=tool_calls,
 | 
			
		||||
            function_call=function_call,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # 5) Build candidate, including index & logprobs
 | 
			
		||||
        candidate = LLMChatCandidate(
 | 
			
		||||
            message=assistant_message,
 | 
			
		||||
            finish_reason=choice.get("finish_reason"),
 | 
			
		||||
            index=choice.get("index"),
 | 
			
		||||
            logprobs=choice.get("logprobs"),
 | 
			
		||||
        )
 | 
			
		||||
        candidates.append(candidate)
 | 
			
		||||
 | 
			
		||||
    # 6) Metadata: include provider tag
 | 
			
		||||
    metadata: Dict[str, Any] = {
 | 
			
		||||
        "provider": "openai",
 | 
			
		||||
        "id": resp.get("id"),
 | 
			
		||||
        "model": resp.get("model"),
 | 
			
		||||
        "object": resp.get("object"),
 | 
			
		||||
        "usage": resp.get("usage"),
 | 
			
		||||
        "created": resp.get("created"),
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    return LLMChatResponse(results=candidates, metadata=metadata)
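A minimal sketch of normalizing a non-streaming SDK response; the request itself is only illustrative:

from openai import OpenAI

completion = OpenAI().chat.completions.create(
    model="gpt-4o", messages=[{"role": "user", "content": "Hi"}]
)
llm_resp = process_openai_chat_response(completion)
first = llm_resp.results[0]
print(first.finish_reason, first.message.content)
print(llm_resp.metadata["provider"])  # -> "openai"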
			@ -1,89 +1,128 @@
 | 
			
		|||
import logging
 | 
			
		||||
from dataclasses import asdict, is_dataclass
 | 
			
		||||
from typing import Any, Dict, Iterator, Literal, Optional, Type, Union
 | 
			
		||||
from typing import (
 | 
			
		||||
    Any,
 | 
			
		||||
    Callable,
 | 
			
		||||
    Iterator,
 | 
			
		||||
    Literal,
 | 
			
		||||
    Optional,
 | 
			
		||||
    Type,
 | 
			
		||||
    TypeVar,
 | 
			
		||||
    Union,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
from pydantic import BaseModel
 | 
			
		||||
 | 
			
		||||
from dapr_agents.llm.utils.stream import StreamHandler
 | 
			
		||||
from dapr_agents.llm.utils.structure import StructureHandler
 | 
			
		||||
from dapr_agents.types import ChatCompletion
 | 
			
		||||
from dapr_agents.types.message import (
 | 
			
		||||
    LLMChatCandidateChunk,
 | 
			
		||||
    LLMChatResponse,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
T = TypeVar("T", bound=BaseModel)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ResponseHandler:
 | 
			
		||||
    """
 | 
			
		||||
    Handles the processing of responses from language models.
 | 
			
		||||
    Handles both streaming and non-streaming chat completions from various LLM providers.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def process_response(
 | 
			
		||||
        response: Any,
 | 
			
		||||
        llm_provider: str,
 | 
			
		||||
        response_format: Optional[Type[BaseModel]] = None,
 | 
			
		||||
        response_format: Optional[Type[T]] = None,
 | 
			
		||||
        structured_mode: Literal["json", "function_call"] = "json",
 | 
			
		||||
        stream: bool = False,
 | 
			
		||||
    ) -> Union[Iterator[Dict[str, Any]], Dict[str, Any]]:
 | 
			
		||||
        on_chunk: Optional[Callable[[LLMChatCandidateChunk], None]] = None,
 | 
			
		||||
    ) -> Union[
 | 
			
		||||
        Iterator[LLMChatCandidateChunk],  # when streaming
 | 
			
		||||
        LLMChatResponse,  # non‑stream + no format
 | 
			
		||||
        T,  # non‑stream + single structured format
 | 
			
		||||
        list[T],  # non‑stream + list structured format
 | 
			
		||||
    ]:
 | 
			
		||||
        """
 | 
			
		||||
        Process the response from the language model.
 | 
			
		||||
        Process a chat completion.
 | 
			
		||||
 | 
			
		||||
        - **Streaming** (`stream=True`):
 | 
			
		||||
          Yields `LLMChatCandidateChunk` via `StreamHandler`, honoring the `on_chunk` callback.
 | 
			
		||||
        - **Non-streaming** (`stream=False`):
 | 
			
		||||
          1. Normalize provider envelope → `LLMChatResponse`.
 | 
			
		||||
          2. If no `response_format` requested, return that `LLMChatResponse`.
 | 
			
		||||
          3. Otherwise, extract the first assistant message, parse & validate it
 | 
			
		||||
             against your Pydantic `response_format`, and return the model (or list).
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            response: The response object from the language model.
 | 
			
		||||
            llm_provider: The LLM provider (e.g., 'openai').
 | 
			
		||||
            response_format: A pydantic model to parse and validate the structured response.
 | 
			
		||||
            structured_mode: The mode of the structured response: 'json' or 'function_call'.
 | 
			
		||||
            stream: Whether the response is a stream.
 | 
			
		||||
            response:         Raw API return (stream iterator or full response object).
 | 
			
		||||
            llm_provider:     e.g. `"openai"`.
 | 
			
		||||
            response_format:  Optional Pydantic model (or `List[Model]`) for structured output.
 | 
			
		||||
            structured_mode:  `"json"` or `"function_call"` (only non-stream).
 | 
			
		||||
            stream:           Whether this is a streaming call.
 | 
			
		||||
            on_chunk:         Callback on every partial `LLMChatCandidateChunk`.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            Union[Iterator[Dict[str, Any]], Dict[str, Any]]: The processed response.
 | 
			
		||||
            • **streaming**: `Iterator[LLMChatCandidateChunk]`
 | 
			
		||||
            • **non-stream + no format**: full `LLMChatResponse`
 | 
			
		||||
            • **non-stream + format**: validated Pydantic model instance or `List[...]`
 | 
			
		||||
        """
 | 
			
		||||
        provider = llm_provider.lower()
 | 
			
		||||
 | 
			
		||||
        # ─── Streaming ─────────────────────────────────────────────────────────
 | 
			
		||||
        if stream:
 | 
			
		||||
            return StreamHandler.process_stream(
 | 
			
		||||
                stream=response,
 | 
			
		||||
                llm_provider=llm_provider,
 | 
			
		||||
                response_format=response_format,
 | 
			
		||||
                llm_provider=provider,
 | 
			
		||||
                on_chunk=on_chunk,
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            if response_format:
 | 
			
		||||
                structured_response_json = StructureHandler.extract_structured_response(
 | 
			
		||||
                    response=response,
 | 
			
		||||
                    llm_provider=llm_provider,
 | 
			
		||||
                    structured_mode=structured_mode,
 | 
			
		||||
                )
 | 
			
		||||
            # ─── Non‑streaming ─────────────────────────────────────────────────────
 | 
			
		||||
            # 1) Normalize full response → LLMChatResponse
 | 
			
		||||
            if provider in ("openai", "nvidia"):
 | 
			
		||||
                from dapr_agents.llm.openai.utils import process_openai_chat_response
 | 
			
		||||
 | 
			
		||||
                # Normalize format and resolve actual model class
 | 
			
		||||
                normalized_format = StructureHandler.normalize_iterable_format(
 | 
			
		||||
                    response_format
 | 
			
		||||
                )
 | 
			
		||||
                model_cls = StructureHandler.resolve_response_model(normalized_format)
 | 
			
		||||
                llm_resp: LLMChatResponse = process_openai_chat_response(response)
 | 
			
		||||
            elif provider == "huggingface":
 | 
			
		||||
                from dapr_agents.llm.huggingface.utils import process_hf_chat_response
 | 
			
		||||
 | 
			
		||||
                if not model_cls:
 | 
			
		||||
                    raise TypeError(
 | 
			
		||||
                        f"Could not resolve a valid Pydantic model from response_format: {response_format}"
 | 
			
		||||
                    )
 | 
			
		||||
                llm_resp = process_hf_chat_response(response)
 | 
			
		||||
            elif provider == "dapr":
 | 
			
		||||
                from dapr_agents.llm.dapr.utils import process_dapr_chat_response
 | 
			
		||||
 | 
			
		||||
                structured_response_instance = StructureHandler.validate_response(
 | 
			
		||||
                    structured_response_json, normalized_format
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                logger.info("Structured output was successfully validated.")
 | 
			
		||||
                if hasattr(structured_response_instance, "objects"):
 | 
			
		||||
                    return structured_response_instance.objects
 | 
			
		||||
                return structured_response_instance
 | 
			
		||||
 | 
			
		||||
            # Convert response to dictionary
 | 
			
		||||
            if isinstance(response, dict):
 | 
			
		||||
                # Already a dictionary
 | 
			
		||||
                response_dict = response
 | 
			
		||||
            elif is_dataclass(response):
 | 
			
		||||
                # Dataclass instance
 | 
			
		||||
                response_dict = asdict(response)
 | 
			
		||||
            elif isinstance(response, BaseModel):
 | 
			
		||||
                # Pydantic object
 | 
			
		||||
                response_dict = response.model_dump()
 | 
			
		||||
                llm_resp = process_dapr_chat_response(response)
 | 
			
		||||
            else:
 | 
			
		||||
                raise ValueError(f"Unsupported response type: {type(response)}")
 | 
			
		||||
                # if you add more providers, handle them here
 | 
			
		||||
                llm_resp = response  # type: ignore
 | 
			
		||||
 | 
			
		||||
            completion = ChatCompletion(**response_dict)
 | 
			
		||||
            logger.debug(f"Chat completion response: {completion}")
 | 
			
		||||
            return completion
 | 
			
		||||
            # 2) If no structured format requested, return the full response
 | 
			
		||||
            if response_format is None:
 | 
			
		||||
                return llm_resp
 | 
			
		||||
 | 
			
		||||
            # 3) They did request a Pydantic model → extract first assistant message
 | 
			
		||||
            first_candidate = next(iter(llm_resp.results), None)
 | 
			
		||||
            if not first_candidate:
 | 
			
		||||
                raise ValueError("No candidates in LLMChatResponse")
 | 
			
		||||
            assistant = first_candidate.message
 | 
			
		||||
 | 
			
		||||
            # 3a) Get the raw JSON or function‐call payload
 | 
			
		||||
            raw = StructureHandler.extract_structured_response(
 | 
			
		||||
                message=assistant,
 | 
			
		||||
                llm_provider=llm_provider,
 | 
			
		||||
                structured_mode=structured_mode,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            # 3b) Wrap List[Model] → IterableModel if needed
 | 
			
		||||
            fmt = StructureHandler.normalize_iterable_format(response_format)
 | 
			
		||||
            # 3c) Ensure exactly one Pydantic model inside
 | 
			
		||||
            model_cls = StructureHandler.resolve_response_model(fmt)
 | 
			
		||||
            if model_cls is None:
 | 
			
		||||
                raise TypeError(
 | 
			
		||||
                    f"Cannot resolve a Pydantic model from {response_format!r}"
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            # 3d) Validate JSON/dict → Pydantic
 | 
			
		||||
            validated = StructureHandler.validate_response(raw, fmt)
 | 
			
		||||
            logger.info("Structured output successfully validated.")
 | 
			
		||||
 | 
			
		||||
            # 3e) If it’s our auto‑wrapped iterable model, return its `.objects` list
 | 
			
		||||
            return getattr(validated, "objects", validated)
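A hedged sketch of the non-streaming structured path using the ResponseHandler above; the Weather model and the raw SDK call are assumptions (the SDK call merely stands in for whatever the chat client produced):

from openai import OpenAI
from pydantic import BaseModel

class Weather(BaseModel):
    city: str
    temperature_c: float

raw = OpenAI().chat.completions.create(
    model="gpt-4o",
    messages=[{"role": "user", "content": "Return Paris weather as JSON"}],
)
weather = ResponseHandler.process_response(
    response=raw,
    llm_provider="openai",
    response_format=Weather,
    structured_mode="json",
)
# weather is a validated Weather instance (or a list when response_format is List[Weather]).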
 | 
			
		||||
			@ -1,241 +1,58 @@
 | 
			
		|||
from typing import (
 | 
			
		||||
    Dict,
 | 
			
		||||
    Any,
 | 
			
		||||
    Callable,
 | 
			
		||||
    Iterator,
 | 
			
		||||
    Type,
 | 
			
		||||
    TypeVar,
 | 
			
		||||
    Union,
 | 
			
		||||
    Optional,
 | 
			
		||||
    Iterable,
 | 
			
		||||
    get_args,
 | 
			
		||||
    TypeVar,
 | 
			
		||||
)
 | 
			
		||||
from dapr_agents.llm.utils.structure import StructureHandler
 | 
			
		||||
from dapr_agents.types import ToolCall
 | 
			
		||||
from openai.types.chat import ChatCompletionChunk
 | 
			
		||||
from pydantic import BaseModel, ValidationError
 | 
			
		||||
import logging
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
from openai.types.chat import ChatCompletionChunk
 | 
			
		||||
from pydantic import BaseModel
 | 
			
		||||
 | 
			
		||||
from dapr_agents.types.message import LLMChatCandidateChunk
 | 
			
		||||
 | 
			
		||||
T = TypeVar("T", bound=BaseModel)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class StreamHandler:
 | 
			
		||||
    """
 | 
			
		||||
    Handles streaming of chat completion responses, processing tool calls and content responses.
 | 
			
		||||
    Handles streaming of chat completion responses, delegating to the
 | 
			
		||||
    provider-specific stream processor and optionally validating output
 | 
			
		||||
    against Pydantic models.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def process_stream(
 | 
			
		||||
        stream: Iterator[Dict[str, Any]],
 | 
			
		||||
        stream: Iterator[ChatCompletionChunk],
 | 
			
		||||
        llm_provider: str,
 | 
			
		||||
        response_format: Optional[Union[Type[T], Type[Iterable[T]]]] = None,
 | 
			
		||||
    ) -> Iterator[Dict[str, Any]]:
 | 
			
		||||
        on_chunk: Optional[Callable],
 | 
			
		||||
    ) -> Iterator[LLMChatCandidateChunk]:
 | 
			
		||||
        """
 | 
			
		||||
        Stream chat completion responses.
 | 
			
		||||
        Process a streaming chat completion.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            stream: The response stream from the API.
 | 
			
		||||
            llm_provider: The LLM provider to use (e.g., 'openai').
 | 
			
		||||
            response_format: The optional Pydantic model or iterable model for validating the response.
 | 
			
		||||
            stream:           Iterator of ChatCompletionChunk from OpenAI SDK.
 | 
			
		||||
            llm_provider:     Name of the LLM provider (e.g., "openai").
 | 
			
		||||
            on_chunk:         Callback fired on every partial LLMChatCandidateChunk.
 | 
			
		||||
 | 
			
		||||
        Yields:
 | 
			
		||||
            dict: Each processed and validated chunk from the chat completion response.
 | 
			
		||||
            LLMChatCandidateChunk: fully-typed chunks, partial and final.
 | 
			
		||||
        """
 | 
			
		||||
        logger.info("Streaming response enabled.")
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            if llm_provider == "openai":
 | 
			
		||||
                yield from StreamHandler._process_openai_stream(stream, response_format)
 | 
			
		||||
            else:
 | 
			
		||||
                yield from stream
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error(f"An error occurred during streaming: {e}")
 | 
			
		||||
            raise
 | 
			
		||||
        if llm_provider in ("openai", "nvidia"):
 | 
			
		||||
            from dapr_agents.llm.openai.utils import process_openai_stream
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _process_openai_stream(
 | 
			
		||||
        stream: Iterator[Dict[str, Any]],
 | 
			
		||||
        response_format: Optional[Union[Type[T], Type[Iterable[T]]]] = None,
 | 
			
		||||
    ) -> Iterator[Dict[str, Any]]:
 | 
			
		||||
        """
 | 
			
		||||
        Process OpenAI stream for chat completion.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            stream: The response stream from the OpenAI API.
 | 
			
		||||
            response_format: The optional Pydantic model or iterable model for validating the response.
 | 
			
		||||
 | 
			
		||||
        Yields:
 | 
			
		||||
            dict: Each processed and validated chunk from the chat completion response.
 | 
			
		||||
        """
 | 
			
		||||
        content_accumulator = ""
 | 
			
		||||
        json_extraction_active = False
 | 
			
		||||
        json_brace_level = 0
 | 
			
		||||
        json_string_buffer = ""
 | 
			
		||||
        tool_calls = {}
 | 
			
		||||
 | 
			
		||||
        for chunk in stream:
 | 
			
		||||
            processed_chunk = StreamHandler._process_openai_chunk(chunk)
 | 
			
		||||
            chunk_type = processed_chunk["type"]
 | 
			
		||||
            chunk_data = processed_chunk["data"]
 | 
			
		||||
 | 
			
		||||
            if chunk_type == "content":
 | 
			
		||||
                content_accumulator += chunk_data
 | 
			
		||||
                yield processed_chunk
 | 
			
		||||
            elif chunk_type in ["tool_calls", "function_call"]:
 | 
			
		||||
                for tool_chunk in chunk_data:
 | 
			
		||||
                    tool_call_index = tool_chunk["index"]
 | 
			
		||||
                    tool_call_id = tool_chunk["id"]
 | 
			
		||||
                    tool_call_function = tool_chunk["function"]
 | 
			
		||||
                    tool_call_arguments = tool_call_function["arguments"]
 | 
			
		||||
 | 
			
		||||
                    if tool_call_id is not None:
 | 
			
		||||
                        tool_calls.setdefault(
 | 
			
		||||
                            tool_call_index,
 | 
			
		||||
                            {
 | 
			
		||||
                                "id": tool_call_id,
 | 
			
		||||
                                "type": tool_chunk["type"],
 | 
			
		||||
                                "function": {
 | 
			
		||||
                                    "name": tool_call_function["name"],
 | 
			
		||||
                                    "arguments": tool_call_arguments,
 | 
			
		||||
                                },
 | 
			
		||||
                            },
 | 
			
		||||
                        )
 | 
			
		||||
 | 
			
		||||
                    # Add tool call arguments to current tool calls
 | 
			
		||||
                    tool_calls[tool_call_index]["function"][
 | 
			
		||||
                        "arguments"
 | 
			
		||||
                    ] += tool_call_arguments
 | 
			
		||||
 | 
			
		||||
                    # Process Iterable model if provided
 | 
			
		||||
                    if (
 | 
			
		||||
                        response_format
 | 
			
		||||
                        and isinstance(response_format, Iterable) is True
 | 
			
		||||
                    ):
 | 
			
		||||
                        trimmed_character = tool_call_arguments.strip()
 | 
			
		||||
                        # Check beginning of List
 | 
			
		||||
                        if trimmed_character == "[" and json_extraction_active is False:
 | 
			
		||||
                            json_extraction_active = True
 | 
			
		||||
                        # Check beginning of a JSON object
 | 
			
		||||
                        elif (
 | 
			
		||||
                            trimmed_character == "{" and json_extraction_active is True
 | 
			
		||||
                        ):
 | 
			
		||||
                            json_brace_level += 1
 | 
			
		||||
                            json_string_buffer += trimmed_character
 | 
			
		||||
                        # Check the end of a JSON object
 | 
			
		||||
                        elif (
 | 
			
		||||
                            "}" in trimmed_character and json_extraction_active is True
 | 
			
		||||
                        ):
 | 
			
		||||
                            json_brace_level -= 1
 | 
			
		||||
                            json_string_buffer += trimmed_character.rstrip(",")
 | 
			
		||||
                            if json_brace_level == 0:
 | 
			
		||||
                                yield from StreamHandler._validate_json_object(
 | 
			
		||||
                                    response_format, json_string_buffer
 | 
			
		||||
                                )
 | 
			
		||||
                                # Reset buffers and counts
 | 
			
		||||
                                json_string_buffer = ""
 | 
			
		||||
                        elif json_extraction_active is True:
 | 
			
		||||
                            json_string_buffer += tool_call_arguments
 | 
			
		||||
 | 
			
		||||
        if content_accumulator:
 | 
			
		||||
            yield {"type": "final_content", "data": content_accumulator}
 | 
			
		||||
 | 
			
		||||
        if tool_calls:
 | 
			
		||||
            yield from StreamHandler._get_final_tool_calls(tool_calls, response_format)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _process_openai_chunk(chunk: ChatCompletionChunk) -> Dict[str, Any]:
 | 
			
		||||
        """
 | 
			
		||||
        Process OpenAI chat completion chunk.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            chunk: The chunk from the OpenAI API.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            dict: Processed chunk.
 | 
			
		||||
        """
 | 
			
		||||
        try:
 | 
			
		||||
            chunk_dict = chunk.model_dump()
 | 
			
		||||
 | 
			
		||||
            if chunk_dict.get("choices") and len(chunk_dict["choices"]) > 0:
 | 
			
		||||
                choice: Dict = chunk_dict["choices"][0]
 | 
			
		||||
                delta: Dict = choice.get("delta", {})
 | 
			
		||||
 | 
			
		||||
                # Process content
 | 
			
		||||
                if delta.get("content") is not None:
 | 
			
		||||
                    return {"type": "content", "data": delta["content"], "chunk": chunk}
 | 
			
		||||
 | 
			
		||||
                # Process tool calls
 | 
			
		||||
                if delta.get("tool_calls"):
 | 
			
		||||
                    return {
 | 
			
		||||
                        "type": "tool_calls",
 | 
			
		||||
                        "data": delta["tool_calls"],
 | 
			
		||||
                        "chunk": chunk,
 | 
			
		||||
                    }
 | 
			
		||||
 | 
			
		||||
                # Process function calls
 | 
			
		||||
                if delta.get("function_call"):
 | 
			
		||||
                    return {
 | 
			
		||||
                        "type": "function_call",
 | 
			
		||||
                        "data": delta["function_call"],
 | 
			
		||||
                        "chunk": chunk,
 | 
			
		||||
                    }
 | 
			
		||||
 | 
			
		||||
                # Process finish reason
 | 
			
		||||
                if choice.get("finish_reason"):
 | 
			
		||||
                    return {
 | 
			
		||||
                        "type": "finish",
 | 
			
		||||
                        "data": choice["finish_reason"],
 | 
			
		||||
                        "chunk": chunk,
 | 
			
		||||
                    }
 | 
			
		||||
 | 
			
		||||
            return {}
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error(f"Error handling OpenAI chat completion chunk: {e}")
 | 
			
		||||
            raise
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _validate_json_object(
 | 
			
		||||
        response_format: Optional[Union[Type[T], Type[Iterable[T]]]],
 | 
			
		||||
        json_string_buffer: str,
 | 
			
		||||
    ):
 | 
			
		||||
        try:
 | 
			
		||||
            model_class = get_args(response_format)[0]
 | 
			
		||||
            # Return current tool call
 | 
			
		||||
            structured_output = StructureHandler.validate_response(
 | 
			
		||||
                json_string_buffer, model_class
 | 
			
		||||
            yield from process_openai_stream(
 | 
			
		||||
                raw_stream=stream,
 | 
			
		||||
                enrich_metadata={"provider": llm_provider},
 | 
			
		||||
                on_chunk=on_chunk,
 | 
			
		||||
            )
 | 
			
		||||
            if isinstance(structured_output, model_class):
 | 
			
		||||
                logger.info("Structured output was successfully validated.")
 | 
			
		||||
                yield {"type": "structured_output", "data": structured_output}
 | 
			
		||||
        except ValidationError as validation_error:
 | 
			
		||||
            logger.error(
 | 
			
		||||
                f"Validation error: {validation_error} with JSON: {json_string_buffer}"
 | 
			
		||||
        elif llm_provider == "huggingface":
 | 
			
		||||
            from dapr_agents.llm.huggingface.utils import process_hf_stream
 | 
			
		||||
 | 
			
		||||
            yield from process_hf_stream(
 | 
			
		||||
                raw_stream=stream,
 | 
			
		||||
                enrich_metadata={"provider": llm_provider},
 | 
			
		||||
                on_chunk=on_chunk,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _get_final_tool_calls(
 | 
			
		||||
        tool_calls: Dict[int, Any],
 | 
			
		||||
        response_format: Optional[Union[Type[T], Type[Iterable[T]]]],
 | 
			
		||||
    ) -> Iterator[Dict[str, Any]]:
 | 
			
		||||
        """
 | 
			
		||||
        Yield final tool calls after processing.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            tool_calls: The dictionary of accumulated tool calls.
 | 
			
		||||
            response_format: The response model for validation.
 | 
			
		||||
 | 
			
		||||
        Yields:
 | 
			
		||||
            dict: Each processed and validated tool call.
 | 
			
		||||
        """
 | 
			
		||||
        for tool in tool_calls.values():
 | 
			
		||||
            if response_format and isinstance(response_format, Iterable) is False:
 | 
			
		||||
                structured_output = StructureHandler.validate_response(
 | 
			
		||||
                    tool["function"]["arguments"], response_format
 | 
			
		||||
                )
 | 
			
		||||
                if isinstance(structured_output, response_format):
 | 
			
		||||
                    logger.info("Structured output was successfully validated.")
 | 
			
		||||
                    yield {"type": "structured_output", "data": structured_output}
 | 
			
		||||
            else:
 | 
			
		||||
                tool_call = ToolCall(**tool)
 | 
			
		||||
                yield {"type": "final_tool_call", "data": tool_call}
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError(f"Streaming not supported for provider: {llm_provider}")
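A short sketch wiring an on_chunk callback through the handler above; the SDK call and callback are illustrative only:

from openai import OpenAI

raw = OpenAI().chat.completions.create(
    model="gpt-4o",
    messages=[{"role": "user", "content": "Stream a haiku"}],
    stream=True,
)

def log_chunk(chunk):
    # Fired on every partial chunk produced by the provider-specific processor.
    print("partial:", chunk)

for chunk in StreamHandler.process_stream(stream=raw, llm_provider="openai", on_chunk=log_chunk):
    pass  # the same chunks are also yielded here, in stream order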
 | 
			
		||||
			@ -18,7 +18,12 @@ from typing import (
 | 
			
		|||
from pydantic import BaseModel, Field, TypeAdapter, ValidationError, create_model
 | 
			
		||||
 | 
			
		||||
from dapr_agents.tool.utils.function_calling import to_function_call_definition
 | 
			
		||||
from dapr_agents.types import OAIJSONSchema, OAIResponseFormatSchema, StructureError
 | 
			
		||||
from dapr_agents.types import (
 | 
			
		||||
    AssistantMessage,
 | 
			
		||||
    OAIJSONSchema,
 | 
			
		||||
    OAIResponseFormatSchema,
 | 
			
		||||
    StructureError,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
			@ -208,7 +213,7 @@ class StructureHandler:
 | 
			
		|||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def extract_structured_response(
 | 
			
		||||
        response: Any,
 | 
			
		||||
        message: AssistantMessage,
 | 
			
		||||
        llm_provider: str,
 | 
			
		||||
        structured_mode: Literal["json", "function_call"] = "json",
 | 
			
		||||
    ) -> Union[str, Dict[str, Any]]:
 | 
			
		||||
			@ -216,7 +221,7 @@ class StructureHandler:
 | 
			
		|||
        Extracts the structured JSON string or content from the response.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            response (Any): The API response data to extract.
 | 
			
		||||
            message (AssistantMessage): The API response data to extract.
 | 
			
		||||
            llm_provider (str): The LLM provider (e.g., 'openai').
 | 
			
		||||
            structured_mode (Literal["json", "function_call"]): The structured response mode.
 | 
			
		||||
 | 
			
		||||
			@ -228,17 +233,7 @@ class StructureHandler:
 | 
			
		|||
        """
 | 
			
		||||
        try:
 | 
			
		||||
            logger.debug(f"Processing structured response for mode: {structured_mode}")
 | 
			
		||||
            if llm_provider in ("openai", "nvidia"):
 | 
			
		||||
                # Extract the `choices` list from the response
 | 
			
		||||
                choices = getattr(response, "choices", None)
 | 
			
		||||
                if not choices or not isinstance(choices, list):
 | 
			
		||||
                    raise StructureError("Response does not contain valid 'choices'.")
 | 
			
		||||
 | 
			
		||||
                # Extract the message object
 | 
			
		||||
                message = getattr(choices[0], "message", None)
 | 
			
		||||
                if not message:
 | 
			
		||||
                    raise StructureError("Response message is missing.")
 | 
			
		||||
 | 
			
		||||
            if llm_provider in ("openai", "nvidia", "huggingface"):
 | 
			
		||||
                if structured_mode == "function_call":
 | 
			
		||||
                    tool_calls = getattr(message, "tool_calls", None)
 | 
			
		||||
                    if tool_calls:
 | 
			
		||||
			@ -55,9 +55,9 @@ class DaprHTTPClient(BaseModel):
 | 
			
		|||
            otel_enabled = False
 | 
			
		||||
 | 
			
		||||
        if otel_enabled:
 | 
			
		||||
            from dapr_agents.agents.telemetry.otel import DaprAgentsOTel  # type: ignore[import-not-found]
 | 
			
		||||
            from dapr_agents.agents.telemetry.otel import DaprAgentsOtel  # type: ignore[import-not-found]
 | 
			
		||||
 | 
			
		||||
            otel_client = DaprAgentsOTel(
 | 
			
		||||
            otel_client = DaprAgentsOtel(
 | 
			
		||||
                service_name=os.getenv("OTEL_SERVICE_NAME", "dapr-http-client"),
 | 
			
		||||
                otlp_endpoint=os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT", ""),
 | 
			
		||||
            )
 | 
			
		||||
			@ -1,13 +1,14 @@
 | 
			
		|||
import logging
 | 
			
		||||
from typing import Any, Dict, Optional
 | 
			
		||||
 | 
			
		||||
from pydantic import BaseModel, ValidationError
 | 
			
		||||
 | 
			
		||||
from dapr_agents.types import (
 | 
			
		||||
    ClaudeToolDefinition,
 | 
			
		||||
    OAIFunctionDefinition,
 | 
			
		||||
    OAIToolDefinition,
 | 
			
		||||
    ClaudeToolDefinition,
 | 
			
		||||
)
 | 
			
		||||
from dapr_agents.types.exceptions import FunCallBuilderError
 | 
			
		||||
from pydantic import BaseModel, ValidationError
 | 
			
		||||
from typing import Dict, Any, Optional
 | 
			
		||||
 | 
			
		||||
import logging
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
@ -130,72 +131,97 @@ def to_function_call_definition(
    args_schema: BaseModel,
    format_type: str = "openai",
    use_deprecated: bool = False,
) -> Dict:
) -> Dict[str, Any]:
    """
    Generates a dictionary representing a function call specification, supporting various API formats.
    The 'use_deprecated' flag is applicable only for the 'openai' format and is ignored for others.

    - For format_type in ("openai", "nvidia", "huggingface"), produces an OpenAI-style
      tool definition (type="function", function={…}).
    - For "claude", produces a Claude-style {name, description, input_schema}.
    - (Gemini omitted here—call to_gemini_function_call_definition if you need it.)

    The 'use_deprecated' flag is only applicable for OpenAI-style definitions.

    Args:
        name (str): The name of the function.
        description (str): A brief description of what the function does.
        args_schema (BaseModel): The Pydantic schema representing the function's parameters.
        format_type (str, optional): The API format to convert to ('openai', 'claude', or 'gemini'). Defaults to 'openai'.
        use_deprecated (bool): Flag to use the deprecated function format, only effective for 'openai'.
        args_schema (BaseModel): The Pydantic model describing the function's parameters.
        format_type (str, optional): Which API flavor to target:
            - "openai", "nvidia", or "huggingface" all share the same OpenAI-style schema.
            - "claude" uses Anthropic's format.
          Defaults to "openai".
        use_deprecated (bool): If True and format_type is OpenAI,
            returns the old function-only schema rather than a tool wrapper.

    Returns:
        Dict: A dictionary containing the function definition in the specified format.
        Dict[str, Any]: The serialized function/tool definition.

    Raises:
        FunCallBuilderError: If an unsupported format type is specified.
        FunCallBuilderError: If an unsupported format_type is provided.
    """
    if format_type.lower() in ("openai", "nvidia"):
    fmt = format_type.lower()

    # OpenAI‑style wrapper schema:
    if fmt in ("openai", "nvidia", "huggingface"):
        return to_openai_function_call_definition(
            name, description, args_schema, use_deprecated
        )
    elif format_type.lower() == "claude":

    # Anthropic Claude needs its own input_schema property
    if fmt == "claude":
        if use_deprecated:
            logger.warning(
                f"'use_deprecated' flag is ignored for the '{format_type}' format."
            )
        return to_claude_function_call_definition(name, description, args_schema)
    else:
        logger.error(f"Unsupported format type: {format_type}")
        raise FunCallBuilderError(f"Unsupported format type: {format_type}")

    # Unsupported provider
    logger.error(f"Unsupported format type: {format_type}")
    raise FunCallBuilderError(f"Unsupported format type: {format_type}")

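For reference, a minimal usage sketch of the unified builder above. The import path and the GetWeather schema are assumptions, and the output shapes follow the docstring rather than a verified run.

# Sketch only: module path and GetWeather schema are illustrative assumptions.
from pydantic import BaseModel, Field

from dapr_agents.llm.utils import to_function_call_definition  # path assumed


class GetWeather(BaseModel):
    city: str = Field(..., description="City to look up")


# "openai", "nvidia", and "huggingface" all yield the OpenAI-style tool wrapper:
# {"type": "function", "function": {"name": ..., "description": ..., "parameters": ...}}
openai_tool = to_function_call_definition(
    "get_weather", "Look up the current weather", GetWeather, format_type="openai"
)

# "claude" yields Anthropic's {name, description, input_schema} shape instead;
# use_deprecated is ignored (with a warning) for this format.
claude_tool = to_function_call_definition(
    "get_weather", "Look up the current weather", GetWeather, format_type="claude"
)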
def validate_and_format_tool(
    tool: Dict[str, Any], tool_format: str = "openai", use_deprecated: bool = False
) -> dict:
) -> Dict[str, Any]:
    """
    Validates and formats a tool (provided as a dictionary) based on the specified API request format.
    Validates and formats a tool definition dict for the specified API style.

    - For tool_format in ("openai", "azure_openai", "nvidia", "huggingface"),
      uses OAIToolDefinition (or OAIFunctionDefinition if use_deprecated=True).
    - For "claude", uses ClaudeToolDefinition.
    - For "llama", treats as an OAIFunctionDefinition.

    Args:
        tool: The tool to validate and format.
        tool_format: The API format to convert to ('openai', 'azure_openai', 'claude', 'llama').
        use_deprecated: Whether to use deprecated functions format for OpenAI. Defaults to False.
        tool (Dict[str, Any]): The raw tool definition.
        tool_format (str): Which API schema to validate against:
            "openai", "azure_openai", "nvidia", "huggingface", "claude", "llama".
        use_deprecated (bool): If True and using OpenAI-style, expects an OAIFunctionDefinition.

    Returns:
        dict: The formatted tool dictionary.
        Dict[str, Any]: The validated, serialized tool definition.

    Raises:
        ValueError: If the tool definition format is invalid.
        ValidationError: If the tool doesn't pass validation.
        ValueError: If the format is unsupported or validation fails.
    """
    fmt = tool_format.lower()

    try:
        if tool_format in ["openai", "azure_openai", "nvidia"]:
            validated_tool = (
        if fmt in ("openai", "azure_openai", "nvidia", "huggingface"):
            validated = (
                OAIFunctionDefinition(**tool)
                if use_deprecated
                else OAIToolDefinition(**tool)
            )
        elif tool_format == "claude":
            validated_tool = ClaudeToolDefinition(**tool)
        elif tool_format == "llama":
            validated_tool = OAIFunctionDefinition(**tool)
        elif fmt == "claude":
            validated = ClaudeToolDefinition(**tool)
        elif fmt == "llama":
            validated = OAIFunctionDefinition(**tool)
        else:
            logger.error(f"Unsupported tool format: {tool_format}")
            raise ValueError(f"Unsupported tool format: {tool_format}")
        return validated_tool.model_dump()

        return validated.model_dump()

    except ValidationError as e:
        logger.error(f"Validation error for {tool_format} tool definition: {e}")
        raise ValueError(f"Invalid tool definition format: {tool}")

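A short sketch of validating a hand-written tool dict with the helper above. The import path is assumed, and the dict mirrors the OpenAI tool wrapper; the exact OAIToolDefinition fields may differ.

# Sketch only: import path assumed; the tool dict is a hand-written example and
# may need adjusting to the actual OAIToolDefinition fields.
from dapr_agents.llm.utils import validate_and_format_tool  # path assumed

raw_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up the current weather",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}

# Validates against OAIToolDefinition and returns the serialized dict;
# an unsupported format or failed validation raises ValueError.
formatted = validate_and_format_tool(raw_tool, tool_format="openai")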
@ -12,8 +12,6 @@ from .message import (
    AssistantFinalMessage,
    AssistantMessage,
    BaseMessage,
    ChatCompletion,
    Choice,
    EventMessageMetadata,
    FunctionCall,
    MessageContent,

@ -22,6 +20,8 @@ from .message import (
    ToolCall,
    ToolMessage,
    UserMessage,
    LLMChatResponse,
    LLMChatCandidate,
)
from .schemas import OAIJSONSchema, OAIResponseFormatSchema
from .tools import (

@ -47,8 +47,8 @@ __all__ = [
    "AssistantFinalMessage",
    "AssistantMessage",
    "BaseMessage",
    "ChatCompletion",
    "Choice",
    "LLMChatResponse",
    "LLMChatCandidate",
    "EventMessageMetadata",
    "FunctionCall",
    "MessageContent",

@ -5,7 +5,7 @@ from pydantic import (
    model_validator,
    ConfigDict,
)
from typing import List, Optional, Dict
from typing import List, Optional, Dict, Any
import json

@ -118,6 +118,36 @@ class ToolCall(BaseModel):
    function: FunctionCall


class FunctionCallChunk(BaseModel):
    """
    Represents a function call chunk in a streaming response, containing the function name and arguments.

    Attributes:
        name (Optional[str]): The name of the function being called, if present in the chunk.
        arguments (Optional[str]): The JSON string fragment of the function's arguments, if present.
    """

    name: Optional[str] = None
    arguments: Optional[str] = None


class ToolCallChunk(BaseModel):
    """
    Represents a tool call chunk in a streaming response, containing the index, ID, type, and function call details.

    Attributes:
        index (int): The index of the tool call in the response.
        id (Optional[str]): Unique identifier for the tool call, if provided.
        type (Optional[str]): The type of the tool call, if provided.
        function (FunctionCallChunk): The function call details associated with the tool call.
    """

    index: int
    id: Optional[str] = None
    type: Optional[str] = None
    function: FunctionCallChunk


class MessageContent(BaseMessage):
    """
    Extends BaseMessage to include dynamic optional fields for tool calls, function calls, and tool call IDs.
			@ -148,73 +178,6 @@ class MessageContent(BaseMessage):
 | 
			
		|||
        return self
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Choice(BaseModel):
 | 
			
		||||
    """
 | 
			
		||||
    Represents a choice made by the model, detailing the reason for completion, its index, and the message content.
 | 
			
		||||
 | 
			
		||||
    Attributes:
 | 
			
		||||
        finish_reason (str): Reason why the model stopped generating text.
 | 
			
		||||
        index (int): Index of the choice in a list of potential choices.
 | 
			
		||||
        message (MessageContent): Content of the message chosen by the model.
 | 
			
		||||
        logprobs (Optional[dict]): Log probabilities associated with the choice.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    finish_reason: str
 | 
			
		||||
    index: int
 | 
			
		||||
    message: MessageContent
 | 
			
		||||
    logprobs: Optional[dict]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ChatCompletion(BaseModel):
 | 
			
		||||
    """
 | 
			
		||||
    Represents the full response from the chat API, including all choices, metadata, and usage information.
 | 
			
		||||
 | 
			
		||||
    Attributes:
 | 
			
		||||
        choices (List[Choice]): List of choices provided by the model.
 | 
			
		||||
        created (int): Timestamp when the response was created.
 | 
			
		||||
        id (str): Unique identifier for the response.
 | 
			
		||||
        model (str): Model used for generating the response.
 | 
			
		||||
        object (str): Type of object returned.
 | 
			
		||||
        usage (dict): Information about API usage for this request.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    choices: List[Choice]
 | 
			
		||||
    created: int
 | 
			
		||||
    id: Optional[str] = None
 | 
			
		||||
    model: str
 | 
			
		||||
    object: Optional[str] = None
 | 
			
		||||
    usage: dict
 | 
			
		||||
 | 
			
		||||
    def get_message(self) -> Optional[dict]:
 | 
			
		||||
        """
 | 
			
		||||
        Retrieve the main message content from the first choice.
 | 
			
		||||
        """
 | 
			
		||||
        return self.choices[0].message.model_dump() if self.choices else None
 | 
			
		||||
 | 
			
		||||
    def get_reason(self) -> Optional[str]:
 | 
			
		||||
        """
 | 
			
		||||
        Retrieve the reason for completion from the first choice.
 | 
			
		||||
        """
 | 
			
		||||
        return self.choices[0].finish_reason if self.choices else None
 | 
			
		||||
 | 
			
		||||
    def get_tool_calls(self) -> Optional[List[ToolCall]]:
 | 
			
		||||
        """
 | 
			
		||||
        Retrieve tool calls from the first choice, if available.
 | 
			
		||||
        """
 | 
			
		||||
        return (
 | 
			
		||||
            self.choices[0].message.tool_calls
 | 
			
		||||
            if self.choices and self.choices[0].message.tool_calls
 | 
			
		||||
            else None
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def get_content(self) -> Optional[str]:
 | 
			
		||||
        """
 | 
			
		||||
        Retrieve the content from the first choice's message.
 | 
			
		||||
        """
 | 
			
		||||
        message = self.get_message()
 | 
			
		||||
        return message.get("content") if message else None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SystemMessage(BaseMessage):
 | 
			
		||||
    """
 | 
			
		||||
    Represents a system message, automatically assigning the role to 'system'.
 | 
			
		||||
| 
						 | 
				
			
			@ -249,6 +212,7 @@ class AssistantMessage(BaseMessage):
 | 
			
		|||
    """
 | 
			
		||||
 | 
			
		||||
    role: str = "assistant"
 | 
			
		||||
    refusal: Optional[str] = None
 | 
			
		||||
    tool_calls: Optional[List[ToolCall]] = None
 | 
			
		||||
    function_call: Optional[FunctionCall] = None
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
@ -256,7 +220,7 @@ class AssistantMessage(BaseMessage):
    def remove_empty_calls(self):
        attrList = []
        for attribute in self.__dict__:
            if attribute in ("tool_calls", "function_call"):
            if attribute in ("tool_calls", "function_call", "refusal"):
                if self.__dict__[attribute] is None:
                    attrList.append(attribute)

@ -265,6 +229,28 @@ class AssistantMessage(BaseMessage):

        return self

    def get_tool_calls(self) -> Optional[List[ToolCall]]:
        """
        Retrieve tool calls from the message if available.
        """
        if getattr(self, "tool_calls", None) is None:
            return None
        if isinstance(self.tool_calls, list):
            return self.tool_calls
        if isinstance(self.tool_calls, ToolCall):
            return [self.tool_calls]
        # Fall back to None for any unexpected shape.
        return None

    def has_tool_calls(self) -> bool:
        """
        Check if the message has tool calls.
        """
        if not hasattr(self, "tool_calls"):
            return False
        # A single ToolCall or a non-empty list both count as tool calls.
        if isinstance(self.tool_calls, ToolCall):
            return True
        return self.tool_calls is not None

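A quick sketch of how the new helpers are meant to be consumed by the loop-based agents. The ToolCall and FunctionCall field names other than `function` are assumptions, and the payload values are illustrative.

# Sketch only: ToolCall/FunctionCall field names besides `function` are assumed.
from dapr_agents.types import AssistantMessage, FunctionCall, ToolCall

msg = AssistantMessage(
    content="",
    tool_calls=[
        ToolCall(
            id="call_1",
            type="function",
            function=FunctionCall(name="get_weather", arguments='{"city": "Seattle"}'),
        )
    ],
)

if msg.has_tool_calls():
    for call in msg.get_tool_calls() or []:
        # Dispatch each requested tool by name with its JSON-encoded arguments.
        print(call.function.name, call.function.arguments)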
class ToolMessage(BaseMessage):
    """

@ -279,6 +265,67 @@ class ToolMessage(BaseMessage):
    tool_call_id: str


class LLMChatCandidate(BaseModel):
    """
    Represents a single candidate (output) from an LLM chat response.
    Allows provider-specific extra fields (e.g., index, logprobs, etc.).

    Attributes:
        message (AssistantMessage): The assistant's message for this candidate.
        finish_reason (Optional[str]): Why the model stopped generating text.
        [Any other provider-specific fields, e.g., index, logprobs, etc.]
    """

    message: AssistantMessage
    finish_reason: Optional[str] = None

    class Config:
        extra = "allow"


class LLMChatResponse(BaseModel):
    """
    Unified response for LLM chat completions, supporting multiple providers.

    Attributes:
        results (List[LLMChatCandidate]): List of candidate outputs.
        metadata (dict): Provider/model metadata (id, model, usage, etc.).
    """

    results: List[LLMChatCandidate]
    metadata: dict = {}

    def get_message(self) -> Optional[AssistantMessage]:
        """
        Retrieves the first message from the results if available.
        """
        return self.results[0].message if self.results else None

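A compact sketch of how a provider handler could wrap its output in the unified types above; the metadata keys and message content are illustrative, not a fixed contract.

# Sketch only: metadata keys and message content are illustrative.
from dapr_agents.types import AssistantMessage, LLMChatCandidate, LLMChatResponse

response = LLMChatResponse(
    results=[
        LLMChatCandidate(
            message=AssistantMessage(content="Hello! How can I help?"),
            finish_reason="stop",
            index=0,  # provider-specific extra field, permitted by extra = "allow"
        )
    ],
    metadata={"model": "gpt-4o-mini", "usage": {"total_tokens": 12}},
)

message = response.get_message()  # first candidate's AssistantMessage, or None
if message is not None:
    print(message.content)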
class LLMChatCandidateChunk(BaseModel):
    """
    Represents a partial (streamed) candidate from an LLM provider, for real-time streaming.
    """

    content: Optional[str] = None
    function_call: Optional[Dict[str, Any]] = None
    refusal: Optional[str] = None
    role: Optional[str] = None
    tool_calls: Optional[List["ToolCallChunk"]] = None
    finish_reason: Optional[str] = None
    index: Optional[int] = None
    logprobs: Optional[dict] = None


class LLMChatResponseChunk(BaseModel):
    """
    Represents a partial (streamed) response from an LLM provider, for real-time streaming.
    """

    result: LLMChatCandidateChunk
    metadata: Optional[dict] = None

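And a small sketch of the streaming counterparts, accumulating text deltas from a handful of hand-built chunks; in practice these would come from a provider's streaming handler, and the import path is assumed.

# Sketch only: chunks would normally come from a provider streaming handler.
from dapr_agents.types.message import (  # module path assumed
    LLMChatCandidateChunk,
    LLMChatResponseChunk,
)

chunks = [
    LLMChatResponseChunk(result=LLMChatCandidateChunk(role="assistant", content="Hel")),
    LLMChatResponseChunk(result=LLMChatCandidateChunk(content="lo!")),
    LLMChatResponseChunk(result=LLMChatCandidateChunk(finish_reason="stop")),
]

# Accumulate text deltas and pick up the final finish_reason, as a streaming
# consumer would while printing tokens in real time.
text = "".join(chunk.result.content or "" for chunk in chunks)
finish = next((c.result.finish_reason for c in chunks if c.result.finish_reason), None)
print(text, finish)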
class AssistantFinalMessage(BaseModel):
    """
    Represents a custom final message from the assistant, encapsulating a conclusive response to the user.

@ -16,7 +16,7 @@ from dapr.ext.workflow.workflow_state import WorkflowState
from durabletask import task as dtask
from pydantic import BaseModel, ConfigDict, Field

from dapr_agents.agents.base import ChatClientType
from dapr_agents.agents.base import ChatClientBase
from dapr_agents.llm.openai import OpenAIChatClient
from dapr_agents.types.workflow import DaprWorkflowStatus
from dapr_agents.workflow.task import WorkflowTask

@ -32,7 +32,7 @@ class WorkflowApp(BaseModel):
    A Pydantic-based class to encapsulate a Dapr Workflow runtime and manage workflows and tasks.
    """

    llm: ChatClientType = Field(
    llm: ChatClientBase = Field(
        default_factory=OpenAIChatClient,
        description="The default LLM client for all LLM-based tasks.",
    )

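Since `llm` is now typed as `ChatClientBase`, any refactored chat client can be injected; a minimal sketch, assuming this runs where WorkflowApp and OpenAIChatClient are imported as above and client configuration comes from the environment:

# Sketch only: OpenAIChatClient is assumed to pick up its API key and model from
# the environment; any client subclassing ChatClientBase can be injected instead.
app = WorkflowApp(llm=OpenAIChatClient())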
@ -85,7 +85,7 @@ class WorkflowApp(BaseModel):

        super().model_post_init(__context)

    def _choose_llm_for(self, method: Callable) -> Optional[ChatClientType]:
    def _choose_llm_for(self, method: Callable) -> Optional[ChatClientBase]:
        """
        Encapsulate LLM selection logic.
          1. Use per-task override if provided on decorator.

			@ -3,34 +3,35 @@ from datetime import datetime, timedelta
 | 
			
		|||
from typing import Any, Dict, List, Optional
 | 
			
		||||
 | 
			
		||||
from dapr.ext.workflow import DaprWorkflowContext
 | 
			
		||||
from dapr_agents.workflow.decorators import task, workflow, message_router
 | 
			
		||||
 | 
			
		||||
from dapr_agents.workflow.decorators import message_router, task, workflow
 | 
			
		||||
from dapr_agents.workflow.orchestrators.base import OrchestratorWorkflowBase
 | 
			
		||||
from dapr_agents.workflow.orchestrators.llm.schemas import (
 | 
			
		||||
    BroadcastMessage,
 | 
			
		||||
    TriggerAction,
 | 
			
		||||
    NextStep,
 | 
			
		||||
    AgentTaskResponse,
 | 
			
		||||
    ProgressCheckOutput,
 | 
			
		||||
    schemas,
 | 
			
		||||
)
 | 
			
		||||
from dapr_agents.workflow.orchestrators.llm.prompts import (
 | 
			
		||||
    TASK_INITIAL_PROMPT,
 | 
			
		||||
    TASK_PLANNING_PROMPT,
 | 
			
		||||
    NEXT_STEP_PROMPT,
 | 
			
		||||
    PROGRESS_CHECK_PROMPT,
 | 
			
		||||
    SUMMARY_GENERATION_PROMPT,
 | 
			
		||||
    TASK_INITIAL_PROMPT,
 | 
			
		||||
    TASK_PLANNING_PROMPT,
 | 
			
		||||
)
 | 
			
		||||
from dapr_agents.workflow.orchestrators.llm.schemas import (
 | 
			
		||||
    AgentTaskResponse,
 | 
			
		||||
    BroadcastMessage,
 | 
			
		||||
    NextStep,
 | 
			
		||||
    ProgressCheckOutput,
 | 
			
		||||
    TriggerAction,
 | 
			
		||||
    schemas,
 | 
			
		||||
)
 | 
			
		||||
from dapr_agents.workflow.orchestrators.llm.state import (
 | 
			
		||||
    LLMWorkflowState,
 | 
			
		||||
    LLMWorkflowEntry,
 | 
			
		||||
    LLMWorkflowMessage,
 | 
			
		||||
    LLMWorkflowState,
 | 
			
		||||
    PlanStep,
 | 
			
		||||
    TaskResult,
 | 
			
		||||
)
 | 
			
		||||
from dapr_agents.workflow.orchestrators.llm.utils import (
 | 
			
		||||
    update_step_statuses,
 | 
			
		||||
    restructure_plan,
 | 
			
		||||
    find_step_in_plan,
 | 
			
		||||
    restructure_plan,
 | 
			
		||||
    update_step_statuses,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
| 
						 | 
				
			
@ -65,244 +66,234 @@ class LLMOrchestrator(OrchestratorWorkflowBase):
    def main_workflow(self, ctx: DaprWorkflowContext, message: TriggerAction):
        """
        Executes an LLM-driven agentic workflow where the next agent is dynamically selected
        based on task progress. The workflow iterates through execution cycles, updating state,
        handling agent responses, and determining task completion.
        based on task progress. Runs for up to `self.max_iterations` turns, then summarizes.

        Args:
            ctx (DaprWorkflowContext): The workflow execution context.
            message (TriggerAction): The current workflow state containing `message`, `iteration`, and `verdict`.
            message (TriggerAction): Contains the current `task`.

        Returns:
            str: The final processed message when the workflow terminates.
            str: The final summary when the workflow terminates.

        Raises:
            RuntimeError: If the LLM determines the task is `failed`.
            RuntimeError: If the workflow ends unexpectedly without a final summary.
        """
        # Step 0: Retrieve iteration messages
 | 
			
		||||
        # Step 1: Retrieve initial task and ensure state entry exists
 | 
			
		||||
        task = message.get("task")
 | 
			
		||||
        iteration = message.get("iteration", 0)
 | 
			
		||||
 | 
			
		||||
        # Step 1:
 | 
			
		||||
        # Ensure 'instances' and the instance_id entry exist
 | 
			
		||||
        instance_id = ctx.instance_id
 | 
			
		||||
        self.state.setdefault("instances", {}).setdefault(
 | 
			
		||||
            instance_id, LLMWorkflowEntry(input=task).model_dump(mode="json")
 | 
			
		||||
        )
 | 
			
		||||
        # Retrieve the plan (will always exist after initialization)
 | 
			
		||||
        plan = self.state["instances"][instance_id].get("plan", [])
 | 
			
		||||
        final_summary: Optional[str] = None
 | 
			
		||||
 | 
			
		||||
        if not ctx.is_replaying:
 | 
			
		||||
            logger.info(
 | 
			
		||||
                f"Workflow iteration {iteration + 1} started (Instance ID: {instance_id})."
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # Step 2: Retrieve available agents
 | 
			
		||||
        agents = yield ctx.call_activity(self.get_agents_metadata_as_string)
 | 
			
		||||
 | 
			
		||||
        # Step 3: First iteration setup
 | 
			
		||||
        if iteration == 0:
 | 
			
		||||
        # Single loop from turn 1 to max_iterations
 | 
			
		||||
        for turn in range(1, self.max_iterations + 1):
 | 
			
		||||
            if not ctx.is_replaying:
 | 
			
		||||
                logger.info(f"Initial message from User -> {self.name}")
 | 
			
		||||
                logger.info(
 | 
			
		||||
                    f"Workflow turn {turn}/{self.max_iterations} (Instance ID: {instance_id})"
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            # Generate the plan using a language model
 | 
			
		||||
            plan = yield ctx.call_activity(
 | 
			
		||||
                self.generate_plan,
 | 
			
		||||
                input={"task": task, "agents": agents, "plan_schema": schemas.plan},
 | 
			
		||||
            )
 | 
			
		||||
            # Step 2: Get available agents
 | 
			
		||||
            agents = yield ctx.call_activity(self.get_agents_metadata_as_string)
 | 
			
		||||
 | 
			
		||||
            # Prepare initial message with task, agents and plan context
 | 
			
		||||
            initial_message = yield ctx.call_activity(
 | 
			
		||||
                self.prepare_initial_message,
 | 
			
		||||
            # Step 3: On turn 1, generate plan and broadcast task
 | 
			
		||||
            if turn == 1:
 | 
			
		||||
                if not ctx.is_replaying:
 | 
			
		||||
                    logger.info(f"Initial message from User -> {self.name}")
 | 
			
		||||
 | 
			
		||||
                plan = yield ctx.call_activity(
 | 
			
		||||
                    self.generate_plan,
 | 
			
		||||
                    input={"task": task, "agents": agents, "plan_schema": schemas.plan},
 | 
			
		||||
                )
 | 
			
		||||
                initial_message = yield ctx.call_activity(
 | 
			
		||||
                    self.prepare_initial_message,
 | 
			
		||||
                    input={
 | 
			
		||||
                        "instance_id": instance_id,
 | 
			
		||||
                        "task": task,
 | 
			
		||||
                        "agents": agents,
 | 
			
		||||
                        "plan": plan,
 | 
			
		||||
                    },
 | 
			
		||||
                )
 | 
			
		||||
                yield ctx.call_activity(
 | 
			
		||||
                    self.broadcast_message_to_agents,
 | 
			
		||||
                    input={"instance_id": instance_id, "task": initial_message},
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            # Step 4: Determine next step and dispatch
 | 
			
		||||
            next_step = yield ctx.call_activity(
 | 
			
		||||
                self.generate_next_step,
 | 
			
		||||
                input={
 | 
			
		||||
                    "instance_id": instance_id,
 | 
			
		||||
                    "task": task,
 | 
			
		||||
                    "agents": agents,
 | 
			
		||||
                    "plan": plan,
 | 
			
		||||
                    "next_step_schema": schemas.next_step,
 | 
			
		||||
                },
 | 
			
		||||
            )
 | 
			
		||||
            # Additional Properties from NextStep
 | 
			
		||||
            next_agent = next_step["next_agent"]
 | 
			
		||||
            instruction = next_step["instruction"]
 | 
			
		||||
            step_id = next_step.get("step", None)
 | 
			
		||||
            substep_id = next_step.get("substep", None)
 | 
			
		||||
 | 
			
		||||
            # broadcast initial message to all agents
 | 
			
		||||
            yield ctx.call_activity(
 | 
			
		||||
                self.broadcast_message_to_agents,
 | 
			
		||||
                input={"instance_id": instance_id, "task": initial_message},
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # Step 4: Identify agent and instruction for the next step
 | 
			
		||||
        next_step = yield ctx.call_activity(
 | 
			
		||||
            self.generate_next_step,
 | 
			
		||||
            input={
 | 
			
		||||
                "task": task,
 | 
			
		||||
                "agents": agents,
 | 
			
		||||
                "plan": plan,
 | 
			
		||||
                "next_step_schema": schemas.next_step,
 | 
			
		||||
            },
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Extract Additional Properties from NextStep
 | 
			
		||||
        next_agent = next_step["next_agent"]
 | 
			
		||||
        instruction = next_step["instruction"]
 | 
			
		||||
        step_id = next_step.get("step", None)
 | 
			
		||||
        substep_id = next_step.get("substep", None)
 | 
			
		||||
 | 
			
		||||
        # Step 5: Validate Step Before Proceeding
 | 
			
		||||
        valid_step = yield ctx.call_activity(
 | 
			
		||||
            self.validate_next_step,
 | 
			
		||||
            input={
 | 
			
		||||
                "instance_id": instance_id,
 | 
			
		||||
                "plan": plan,
 | 
			
		||||
                "step": step_id,
 | 
			
		||||
                "substep": substep_id,
 | 
			
		||||
            },
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if valid_step:
 | 
			
		||||
            # Step 6: Broadcast Task to all Agents
 | 
			
		||||
            yield ctx.call_activity(
 | 
			
		||||
                self.broadcast_message_to_agents,
 | 
			
		||||
                input={"instance_id": instance_id, "task": instruction},
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            # Step 7: Trigger next agent
 | 
			
		||||
            plan = yield ctx.call_activity(
 | 
			
		||||
                self.trigger_agent,
 | 
			
		||||
            # Step 5: Validate Step Before Proceeding
 | 
			
		||||
            valid_step = yield ctx.call_activity(
 | 
			
		||||
                self.validate_next_step,
 | 
			
		||||
                input={
 | 
			
		||||
                    "instance_id": instance_id,
 | 
			
		||||
                    "name": next_agent,
 | 
			
		||||
                    "step": step_id,
 | 
			
		||||
                    "substep": substep_id,
 | 
			
		||||
                },
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            # Step 8: Wait for agent response or timeout
 | 
			
		||||
            if not ctx.is_replaying:
 | 
			
		||||
                logger.info(f"Waiting for {next_agent}'s response...")
 | 
			
		||||
 | 
			
		||||
            event_data = ctx.wait_for_external_event("AgentTaskResponse")
 | 
			
		||||
            timeout_task = ctx.create_timer(timedelta(seconds=self.timeout))
 | 
			
		||||
            any_results = yield self.when_any([event_data, timeout_task])
 | 
			
		||||
 | 
			
		||||
            if any_results == timeout_task:
 | 
			
		||||
                logger.warning(
 | 
			
		||||
                    f"Agent response timed out (Iteration: {iteration + 1}, Instance ID: {instance_id})."
 | 
			
		||||
                )
 | 
			
		||||
                task_results = {
 | 
			
		||||
                    "name": self.name,
 | 
			
		||||
                    "role": "user",
 | 
			
		||||
                    "content": f"Timeout occurred. {next_agent} did not respond on time. We need to try again...",
 | 
			
		||||
                }
 | 
			
		||||
            else:
 | 
			
		||||
                task_results = yield event_data
 | 
			
		||||
                if not ctx.is_replaying:
 | 
			
		||||
                    logger.info(f"{task_results['name']} sent a response.")
 | 
			
		||||
 | 
			
		||||
            # Step 9: Save the task execution results to chat and task history
 | 
			
		||||
            yield ctx.call_activity(
 | 
			
		||||
                self.update_task_history,
 | 
			
		||||
                input={
 | 
			
		||||
                    "instance_id": instance_id,
 | 
			
		||||
                    "agent": next_agent,
 | 
			
		||||
                    "step": step_id,
 | 
			
		||||
                    "substep": substep_id,
 | 
			
		||||
                    "results": task_results,
 | 
			
		||||
                },
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            # Step 10: Check progress
 | 
			
		||||
            progress = yield ctx.call_activity(
 | 
			
		||||
                self.check_progress,
 | 
			
		||||
                input={
 | 
			
		||||
                    "task": task,
 | 
			
		||||
                    "plan": plan,
 | 
			
		||||
                    "step": step_id,
 | 
			
		||||
                    "substep": substep_id,
 | 
			
		||||
                    "results": task_results["content"],
 | 
			
		||||
                    "progress_check_schema": schemas.progress_check,
 | 
			
		||||
                },
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            if not ctx.is_replaying:
 | 
			
		||||
                logger.info(f"Tracking Progress: {progress}")
 | 
			
		||||
 | 
			
		||||
            verdict = progress["verdict"]
 | 
			
		||||
            status_updates = progress.get("plan_status_update", [])
 | 
			
		||||
            plan_updates = progress.get("plan_restructure", [])
 | 
			
		||||
 | 
			
		||||
            # Step 11: Handle verdict and updates
 | 
			
		||||
            if status_updates or plan_updates:
 | 
			
		||||
            if valid_step:
 | 
			
		||||
                # Step 6: Broadcast Task to all Agents
 | 
			
		||||
                yield ctx.call_activity(
 | 
			
		||||
                    self.update_plan,
 | 
			
		||||
                    self.broadcast_message_to_agents,
 | 
			
		||||
                    input={"instance_id": instance_id, "task": instruction},
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                # Step 7: Trigger next agent
 | 
			
		||||
                plan = yield ctx.call_activity(
 | 
			
		||||
                    self.trigger_agent,
 | 
			
		||||
                    input={
 | 
			
		||||
                        "instance_id": instance_id,
 | 
			
		||||
                        "plan": plan,
 | 
			
		||||
                        "status_updates": status_updates,
 | 
			
		||||
                        "plan_updates": plan_updates,
 | 
			
		||||
                        "name": next_agent,
 | 
			
		||||
                        "step": step_id,
 | 
			
		||||
                        "substep": substep_id,
 | 
			
		||||
                    },
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        else:
 | 
			
		||||
            logger.warning(
 | 
			
		||||
                f"Step {step_id}, Substep {substep_id} not found in plan for instance {instance_id}. Recovering..."
 | 
			
		||||
            )
 | 
			
		||||
                # Step 8: Wait for agent response or timeout
 | 
			
		||||
                if not ctx.is_replaying:
 | 
			
		||||
                    logger.debug(f"Waiting for {next_agent}'s response...")
 | 
			
		||||
 | 
			
		||||
            # Recovery Task: No updates, just iterate again
 | 
			
		||||
            verdict = "continue"
 | 
			
		||||
            status_updates = []
 | 
			
		||||
            plan_updates = []
 | 
			
		||||
            task_results = {
 | 
			
		||||
                "name": "orchestrator",
 | 
			
		||||
                "role": "user",
 | 
			
		||||
                "content": f"Step {step_id}, Substep {substep_id} does not exist in the plan. Adjusting workflow...",
 | 
			
		||||
            }
 | 
			
		||||
                event_data = ctx.wait_for_external_event("AgentTaskResponse")
 | 
			
		||||
                timeout_task = ctx.create_timer(timedelta(seconds=self.timeout))
 | 
			
		||||
                any_results = yield self.when_any([event_data, timeout_task])
 | 
			
		||||
 | 
			
		||||
        # Step 12: Process progress suggestions and next iteration count
 | 
			
		||||
        next_iteration_count = iteration + 1
 | 
			
		||||
        if verdict != "continue" or next_iteration_count > self.max_iterations:
 | 
			
		||||
            if next_iteration_count >= self.max_iterations:
 | 
			
		||||
                verdict = "max_iterations_reached"
 | 
			
		||||
                # Step 9: Handle Agent Response or Timeout
 | 
			
		||||
                if any_results == timeout_task:
 | 
			
		||||
                    logger.warning(
 | 
			
		||||
                        f"Agent response timed out (Iteration: {turn}, Instance ID: {instance_id})."
 | 
			
		||||
                    )
 | 
			
		||||
                    task_results = {
 | 
			
		||||
                        "name": self.name,
 | 
			
		||||
                        "role": "user",
 | 
			
		||||
                        "content": f"Timeout occurred. {next_agent} did not respond on time. We need to try again...",
 | 
			
		||||
                    }
 | 
			
		||||
                else:
 | 
			
		||||
                    task_results = yield event_data
 | 
			
		||||
                    if not ctx.is_replaying:
 | 
			
		||||
                        logger.info(f"{task_results['name']} sent a response.")
 | 
			
		||||
 | 
			
		||||
            if not ctx.is_replaying:
 | 
			
		||||
                logger.info(f"Workflow ending with verdict: {verdict}")
 | 
			
		||||
 | 
			
		||||
            # Generate final summary based on execution
 | 
			
		||||
            summary = yield ctx.call_activity(
 | 
			
		||||
                self.generate_summary,
 | 
			
		||||
                input={
 | 
			
		||||
                    "task": task,
 | 
			
		||||
                    "verdict": verdict,
 | 
			
		||||
                    "plan": plan,
 | 
			
		||||
                    "step": step_id,
 | 
			
		||||
                    "substep": substep_id,
 | 
			
		||||
                    "agent": next_agent,
 | 
			
		||||
                    "result": task_results["content"],
 | 
			
		||||
                },
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            # Finalize the workflow properly
 | 
			
		||||
            yield ctx.call_activity(
 | 
			
		||||
                self.finish_workflow,
 | 
			
		||||
                input={
 | 
			
		||||
                    "instance_id": instance_id,
 | 
			
		||||
                    "plan": plan,
 | 
			
		||||
                    "step": step_id,
 | 
			
		||||
                    "substep": substep_id,
 | 
			
		||||
                    "verdict": verdict,
 | 
			
		||||
                    "summary": summary,
 | 
			
		||||
                },
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            if not ctx.is_replaying:
 | 
			
		||||
                logger.info(
 | 
			
		||||
                    f"Workflow {instance_id} has been finalized with verdict: {verdict}"
 | 
			
		||||
                # Step 10: Save the task execution results to chat and task history
 | 
			
		||||
                yield ctx.call_activity(
 | 
			
		||||
                    self.update_task_history,
 | 
			
		||||
                    input={
 | 
			
		||||
                        "instance_id": instance_id,
 | 
			
		||||
                        "agent": next_agent,
 | 
			
		||||
                        "step": step_id,
 | 
			
		||||
                        "substep": substep_id,
 | 
			
		||||
                        "results": task_results,
 | 
			
		||||
                    },
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            return summary
 | 
			
		||||
                # Step 11: Check progress
 | 
			
		||||
                progress = yield ctx.call_activity(
 | 
			
		||||
                    self.check_progress,
 | 
			
		||||
                    input={
 | 
			
		||||
                        "task": task,
 | 
			
		||||
                        "plan": plan,
 | 
			
		||||
                        "step": step_id,
 | 
			
		||||
                        "substep": substep_id,
 | 
			
		||||
                        "results": task_results["content"],
 | 
			
		||||
                        "progress_check_schema": schemas.progress_check,
 | 
			
		||||
                    },
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        # Step 13: Update TriggerAction state and continue workflow
 | 
			
		||||
        message["task"] = task_results["content"]
 | 
			
		||||
        message["iteration"] = next_iteration_count
 | 
			
		||||
                # Update verdict and plan based on progress
 | 
			
		||||
                verdict = progress["verdict"]
 | 
			
		||||
                if not ctx.is_replaying:
 | 
			
		||||
                    logger.debug(f"Progress verdict: {verdict}")
 | 
			
		||||
                    logger.debug(f"Tracking Progress: {progress}")
 | 
			
		||||
                status_updates = progress.get("plan_status_update", [])
 | 
			
		||||
                plan_updates = progress.get("plan_restructure", [])
 | 
			
		||||
 | 
			
		||||
        # Restart workflow with updated TriggerAction state
 | 
			
		||||
        ctx.continue_as_new(message)
 | 
			
		||||
                # Step 12: Handle verdict and updates
 | 
			
		||||
                if status_updates or plan_updates:
 | 
			
		||||
                    yield ctx.call_activity(
 | 
			
		||||
                        self.update_plan,
 | 
			
		||||
                        input={
 | 
			
		||||
                            "instance_id": instance_id,
 | 
			
		||||
                            "plan": plan,
 | 
			
		||||
                            "status_updates": status_updates,
 | 
			
		||||
                            "plan_updates": plan_updates,
 | 
			
		||||
                        },
 | 
			
		||||
                    )
 | 
			
		||||
            else:
 | 
			
		||||
                if not ctx.is_replaying:
 | 
			
		||||
                    logger.warning(
 | 
			
		||||
                        f"Invalid step {step_id}/{substep_id} in plan for instance {instance_id}. Retrying..."
 | 
			
		||||
                    )
 | 
			
		||||
                # Recovery Task: No updates, just iterate again
 | 
			
		||||
                verdict = "continue"
 | 
			
		||||
                status_updates = []
 | 
			
		||||
                plan_updates = []
 | 
			
		||||
                task_results = {
 | 
			
		||||
                    "name": self.name,
 | 
			
		||||
                    "role": "user",
 | 
			
		||||
                    "content": f"Step {step_id}, Substep {substep_id} does not exist in the plan. Adjusting workflow...",
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
            # Step 12: Process progress suggestions and next iteration count
 | 
			
		||||
            if verdict != "continue" or turn == self.max_iterations:
 | 
			
		||||
                if not ctx.is_replaying:
 | 
			
		||||
                    finale = (
 | 
			
		||||
                        "max_iterations_reached"
 | 
			
		||||
                        if turn == self.max_iterations
 | 
			
		||||
                        else verdict
 | 
			
		||||
                    )
 | 
			
		||||
                    logger.info(f"Ending workflow with verdict: {finale}")
 | 
			
		||||
 | 
			
		||||
                # Generate summary & finish
 | 
			
		||||
                final_summary = yield ctx.call_activity(
 | 
			
		||||
                    self.generate_summary,
 | 
			
		||||
                    input={
 | 
			
		||||
                        "task": task,
 | 
			
		||||
                        "verdict": verdict,
 | 
			
		||||
                        "plan": plan,
 | 
			
		||||
                        "step": step_id,
 | 
			
		||||
                        "substep": substep_id,
 | 
			
		||||
                        "agent": next_agent,
 | 
			
		||||
                        "result": task_results["content"],
 | 
			
		||||
                    },
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                # Finalize the workflow properly
 | 
			
		||||
                yield ctx.call_activity(
 | 
			
		||||
                    self.finish_workflow,
 | 
			
		||||
                    input={
 | 
			
		||||
                        "instance_id": instance_id,
 | 
			
		||||
                        "plan": plan,
 | 
			
		||||
                        "step": step_id,
 | 
			
		||||
                        "substep": substep_id,
 | 
			
		||||
                        "verdict": verdict,
 | 
			
		||||
                        "summary": final_summary,
 | 
			
		||||
                    },
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                # Return the final summary
 | 
			
		||||
                if not ctx.is_replaying:
 | 
			
		||||
                    logger.info(f"Workflow {instance_id} finalized.")
 | 
			
		||||
                return final_summary
 | 
			
		||||
 | 
			
		||||
            # --- PREPARE NEXT TURN ---
 | 
			
		||||
            task = task_results["content"]
 | 
			
		||||
 | 
			
		||||
        # Should never reach here
 | 
			
		||||
        raise RuntimeError(f"LLMWorkflow {instance_id} exited without summary")
 | 
			
		||||
 | 
			
		||||
    @task
 | 
			
		||||
    def get_agents_metadata_as_string(self) -> str:
 | 
			
		||||
| 
						 | 
				
			
			@ -778,11 +769,11 @@ class LLMOrchestrator(OrchestratorWorkflowBase):
 | 
			
		|||
                    f"{self.name} received an agent response without a valid workflow_instance_id. Ignoring."
 | 
			
		||||
                )
 | 
			
		||||
                return
 | 
			
		||||
 | 
			
		||||
            logger.info(
 | 
			
		||||
                f"{self.name} processing agent response for workflow instance '{workflow_instance_id}'."
 | 
			
		||||
            # Log the received response
 | 
			
		||||
            logger.debug(
 | 
			
		||||
                f"{self.name} received response for workflow {workflow_instance_id}"
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            logger.debug(f"Full response: {message}")
 | 
			
		||||
            # Raise a workflow event with the Agent's Task Response
 | 
			
		||||
            self.raise_workflow_event(
 | 
			
		||||
                instance_id=workflow_instance_id,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
@ -33,9 +33,6 @@ class TriggerAction(BaseModel):
        None,
        description="The specific task to execute. If not provided, the agent can act based on its memory or predefined behavior.",
    )
    iteration: Optional[int] = Field(
        default=0, description="The current iteration of the workflow loop."
    )
    workflow_instance_id: Optional[str] = Field(
        default=None, description="Dapr workflow instance id from source if available"
    )

def update_step_statuses(plan: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
 | 
			
		||||
    """
 | 
			
		||||
    Ensures step and sub-step statuses follow logical progression:
 | 
			
		||||
    - A step is marked "completed" if all substeps are "completed".
 | 
			
		||||
    - If any sub-step is "in_progress", the parent step must also be "in_progress".
 | 
			
		||||
    - If a sub-step is "completed" but the parent step is "not_started", update it to "in_progress".
 | 
			
		||||
    - If a parent step is "completed" but a substep is still "in_progress", downgrade it to "in_progress".
 | 
			
		||||
    - Steps without substeps should still progress logically.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        plan (List[Dict[str, Any]]): The current execution plan.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        List[Dict[str, Any]]: The updated execution plan with correct statuses.
 | 
			
		||||
      • Parent completes if all substeps complete.
 | 
			
		||||
      • Parent goes in_progress if any substep is in_progress.
 | 
			
		||||
      • If substeps start completing, parent moves in_progress.
 | 
			
		||||
      • If parent was completed but a substep reverts to in_progress, parent downgrades.
 | 
			
		||||
      • Standalone steps (no substeps) are only updated via explicit status_updates.
 | 
			
		||||
    """
 | 
			
		||||
    for step in plan:
 | 
			
		||||
        # Case 0: Handle steps that have NO substeps
 | 
			
		||||
        if "substeps" not in step or not step["substeps"]:
 | 
			
		||||
            if step["status"] == "not_started":
 | 
			
		||||
                step[
 | 
			
		||||
                    "status"
 | 
			
		||||
                ] = "in_progress"  # Independent steps should start when execution begins
 | 
			
		||||
            continue  # Skip further processing if no substeps exist
 | 
			
		||||
        subs = step.get("substeps", None)
 | 
			
		||||
 | 
			
		||||
        substep_statuses = {ss["status"] for ss in step["substeps"]}
 | 
			
		||||
        # --- NO substeps: do nothing here (explicit updates only) ---
 | 
			
		||||
        if subs is None:
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        # Case 1: If ALL substeps are "completed", parent step must be "completed".
 | 
			
		||||
        if all(status == "completed" for status in substep_statuses):
 | 
			
		||||
        # If substeps is not a list or is an empty list, treat as no‐substeps too:
 | 
			
		||||
        if not isinstance(subs, list) or len(subs) == 0:
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        # Collect child statuses
 | 
			
		||||
        statuses = {ss["status"] for ss in subs}
 | 
			
		||||
 | 
			
		||||
        # 1. All done → parent done
 | 
			
		||||
        if statuses == {"completed"}:
 | 
			
		||||
            step["status"] = "completed"
 | 
			
		||||
 | 
			
		||||
        # Case 2: If ANY substep is "in_progress", parent step must also be "in_progress".
 | 
			
		||||
        elif "in_progress" in substep_statuses:
 | 
			
		||||
        # 2. Any in_progress → parent in_progress
 | 
			
		||||
        elif "in_progress" in statuses:
 | 
			
		||||
            step["status"] = "in_progress"
 | 
			
		||||
 | 
			
		||||
        # Case 3: If a sub-step was completed but the step is still "not_started", update it.
 | 
			
		||||
        elif "completed" in substep_statuses and step["status"] == "not_started":
 | 
			
		||||
        # 3. Some done, parent not yet started → bump to in_progress
 | 
			
		||||
        elif "completed" in statuses and step["status"] == "not_started":
 | 
			
		||||
            step["status"] = "in_progress"
 | 
			
		||||
 | 
			
		||||
        # Case 4: If the step is already marked as "completed" but a substep is still "in_progress", downgrade it.
 | 
			
		||||
        elif step["status"] == "completed" and "in_progress" in substep_statuses:
 | 
			
		||||
        # 4. If parent was completed but a child is in_progress, downgrade
 | 
			
		||||
        elif step["status"] == "completed" and any(s != "completed" for s in statuses):
 | 
			
		||||
            step["status"] = "in_progress"
 | 
			
		||||
 | 
			
		||||
    return plan
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
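A worked sketch of the revised status propagation, following the new branch logic above; the key names "status" and "substeps" match the function, the other keys are illustrative.

# Sketch only: a minimal plan with one parent step, two substeps, and one standalone step.
plan = [
    {
        "step": 1,
        "status": "not_started",
        "substeps": [
            {"substep": 1.1, "status": "completed"},
            {"substep": 1.2, "status": "not_started"},
        ],
    },
    {"step": 2, "status": "not_started"},  # standalone step: left untouched
]

updated = update_step_statuses(plan)
assert updated[0]["status"] == "in_progress"   # a substep completed, so the parent advances
assert updated[1]["status"] == "not_started"   # no substeps, so no implicit change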
			@ -1,13 +1,14 @@
 | 
			
		|||
from dapr_agents.workflow.orchestrators.base import OrchestratorWorkflowBase
 | 
			
		||||
from dapr.ext.workflow import DaprWorkflowContext
 | 
			
		||||
from dapr_agents.types import BaseMessage
 | 
			
		||||
from dapr_agents.workflow.decorators import task, workflow, message_router
 | 
			
		||||
from typing import Any, Optional, Dict
 | 
			
		||||
from datetime import timedelta
 | 
			
		||||
from pydantic import BaseModel, Field
 | 
			
		||||
import random
 | 
			
		||||
import logging
 | 
			
		||||
import random
 | 
			
		||||
from datetime import timedelta
 | 
			
		||||
from typing import Any, Dict, Optional
 | 
			
		||||
 | 
			
		||||
from dapr.ext.workflow import DaprWorkflowContext
 | 
			
		||||
from pydantic import BaseModel, Field
 | 
			
		||||
 | 
			
		||||
from dapr_agents.types import BaseMessage
 | 
			
		||||
from dapr_agents.workflow.decorators import message_router, task, workflow
 | 
			
		||||
from dapr_agents.workflow.orchestrators.base import OrchestratorWorkflowBase
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -35,9 +36,9 @@ class TriggerAction(BaseModel):
 | 
			
		|||
 | 
			
		||||
    task: Optional[str] = Field(
 | 
			
		||||
        None,
 | 
			
		||||
        description="The specific task to execute. If not provided, the agent will act based on its memory or predefined behavior.",
 | 
			
		||||
        description="The specific task to execute. If not provided, the agent will act "
 | 
			
		||||
        "based on its memory or predefined behavior.",
 | 
			
		||||
    )
 | 
			
		||||
    iteration: Optional[int] = Field(0, description="")
 | 
			
		||||
    workflow_instance_id: Optional[str] = Field(
 | 
			
		||||
        default=None, description="Dapr workflow instance id from source if available"
 | 
			
		||||
    )
 | 
			
		||||
| 
						 | 
				
			
			@ -48,177 +49,153 @@ class RandomOrchestrator(OrchestratorWorkflowBase):
 | 
			
		|||
    Implements a random workflow where agents are selected randomly to perform tasks.
 | 
			
		||||
    The workflow iterates through conversations, selecting a random agent at each step.
 | 
			
		||||
 | 
			
		||||
    Uses `continue_as_new` to persist iteration state.
 | 
			
		||||
    Runs in a single for-loop, breaking when max_iterations is reached.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    current_speaker: Optional[str] = Field(
 | 
			
		||||
        default=None, init=False, description="Current speaker in the conversation."
 | 
			
		||||
        default=None,
 | 
			
		||||
        init=False,
 | 
			
		||||
        description="Current speaker in the conversation, to avoid immediate repeats when possible.",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    def model_post_init(self, __context: Any) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Initializes and configures the random workflow service.
 | 
			
		||||
        Registers tasks and workflows, then starts the workflow runtime.
 | 
			
		||||
        """
 | 
			
		||||
        self._workflow_name = "RandomWorkflow"
 | 
			
		||||
 | 
			
		||||
        super().model_post_init(__context)
 | 
			
		||||
 | 
			
		||||
    @workflow(name="RandomWorkflow")
 | 
			
		||||
    # TODO: add retry policies on activities.
 | 
			
		||||
    def main_workflow(self, ctx: DaprWorkflowContext, input: TriggerAction):
 | 
			
		||||
        """
 | 
			
		||||
        Executes a random workflow where agents are selected randomly for interactions.
 | 
			
		||||
        Uses `continue_as_new` to persist iteration state.
 | 
			
		||||
        Executes the random workflow in up to `self.max_iterations` turns, selecting
 | 
			
		||||
        a different (or same) agent at random each turn.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            ctx (DaprWorkflowContext): The workflow execution context.
 | 
			
		||||
            input (TriggerAction): The current workflow state containing `message` and `iteration`.
 | 
			
		||||
            ctx (DaprWorkflowContext): Workflow context.
 | 
			
		||||
            input (TriggerAction): Contains `task`.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            str: The last processed message when the workflow terminates.
 | 
			
		||||
            str: The final message content when the workflow terminates.
 | 
			
		||||
        """
 | 
			
		||||
        # Step 0: Retrieving Loop Context
 | 
			
		||||
        # Step 1: Gather initial task and instance ID
 | 
			
		||||
        task = input.get("task")
 | 
			
		||||
        iteration = input.get("iteration", 0)
 | 
			
		||||
        instance_id = ctx.instance_id
 | 
			
		||||
        final_output: Optional[str] = None
 | 
			
		||||
 | 
			
		||||
        if not ctx.is_replaying:
 | 
			
		||||
            logger.info(
 | 
			
		||||
                f"Random workflow iteration {iteration + 1} started (Instance ID: {instance_id})."
 | 
			
		||||
            )
 | 
			
		||||
        # Single loop from turn 1 to max_iterations inclusive
 | 
			
		||||
        for turn in range(1, self.max_iterations + 1):
 | 
			
		||||
            if not ctx.is_replaying:
 | 
			
		||||
                logger.info(
 | 
			
		||||
                    f"Random workflow turn {turn}/{self.max_iterations} "
 | 
			
		||||
                    f"(Instance ID: {instance_id})"
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        # First iteration: Process input and broadcast
 | 
			
		||||
        if iteration == 0:
 | 
			
		||||
            message = yield ctx.call_activity(self.process_input, input={"task": task})
 | 
			
		||||
            logger.info(f"Initial message from {message['role']} -> {self.name}")
 | 
			
		||||
            # Step 2: On turn 1, process initial task and broadcast
 | 
			
		||||
            if turn == 1:
 | 
			
		||||
                message = yield ctx.call_activity(
 | 
			
		||||
                    self.process_input, input={"task": task}
 | 
			
		||||
                )
 | 
			
		||||
                logger.info(f"Initial message from {message['role']} -> {self.name}")
 | 
			
		||||
                yield ctx.call_activity(
 | 
			
		||||
                    self.broadcast_message_to_agents, input={"message": message}
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            # Step 1: Broadcast initial message
 | 
			
		||||
            # Step 3: Select a random speaker
 | 
			
		||||
            random_speaker = yield ctx.call_activity(self.select_random_speaker)
 | 
			
		||||
            if not ctx.is_replaying:
 | 
			
		||||
                logger.info(f"{self.name} selected {random_speaker} (Turn {turn}).")
 | 
			
		||||
 | 
			
		||||
            # Step 4: Trigger the agent
 | 
			
		||||
            yield ctx.call_activity(
 | 
			
		||||
                self.broadcast_message_to_agents, input={"message": message}
 | 
			
		||||
                self.trigger_agent,
 | 
			
		||||
                input={"name": random_speaker, "instance_id": instance_id},
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # Step 2: Select a random speaker
 | 
			
		||||
        random_speaker = yield ctx.call_activity(
 | 
			
		||||
            self.select_random_speaker, input={"iteration": iteration}
 | 
			
		||||
        )
 | 
			
		||||
            # Step 5: Await for agent response or timeout
 | 
			
		||||
            if not ctx.is_replaying:
 | 
			
		||||
                logger.debug("Waiting for agent response...")
 | 
			
		||||
            event_data = ctx.wait_for_external_event("AgentTaskResponse")
 | 
			
		||||
            timeout_task = ctx.create_timer(timedelta(seconds=self.timeout))
 | 
			
		||||
            any_results = yield self.when_any([event_data, timeout_task])
 | 
			
		||||
 | 
			
		||||
        # Step 3: Trigger agent
 | 
			
		||||
        yield ctx.call_activity(
 | 
			
		||||
            self.trigger_agent,
 | 
			
		||||
            input={"name": random_speaker, "instance_id": instance_id},
 | 
			
		||||
        )
 | 
			
		||||
            # Step 6: Handle response or timeout
 | 
			
		||||
            if any_results == timeout_task:
 | 
			
		||||
                if not ctx.is_replaying:
 | 
			
		||||
                    logger.warning(
 | 
			
		||||
                        f"Turn {turn}: agent response timed out (Instance ID: {instance_id})."
 | 
			
		||||
                    )
 | 
			
		||||
                result = {
 | 
			
		||||
                    "name": "timeout",
 | 
			
		||||
                    "content": "⏰ Timeout occurred. Continuing...",
 | 
			
		||||
                }
 | 
			
		||||
            else:
 | 
			
		||||
                result = yield event_data
 | 
			
		||||
                if not ctx.is_replaying:
 | 
			
		||||
                    logger.info(f"{result['name']} -> {self.name}")
 | 
			
		||||
 | 
			
		||||
        # Step 4: Wait for response or timeout
 | 
			
		||||
        logger.info("Waiting for agent response...")
 | 
			
		||||
        event_data = ctx.wait_for_external_event("AgentTaskResponse")
 | 
			
		||||
        timeout_task = ctx.create_timer(timedelta(seconds=self.timeout))
 | 
			
		||||
        any_results = yield self.when_any([event_data, timeout_task])
 | 
			
		||||
            # Step 7: If this is the last allowed turn, mark final_output and break
 | 
			
		||||
            if turn == self.max_iterations:
 | 
			
		||||
                if not ctx.is_replaying:
 | 
			
		||||
                    logger.info(
 | 
			
		||||
                        f"Turn {turn}: max iterations reached (Instance ID: {instance_id})."
 | 
			
		||||
                    )
 | 
			
		||||
                final_output = result["content"]
 | 
			
		||||
                break
 | 
			
		||||
 | 
			
		||||
        if any_results == timeout_task:
 | 
			
		||||
            logger.warning(
 | 
			
		||||
                f"Agent response timed out (Iteration: {iteration + 1}, Instance ID: {instance_id})."
 | 
			
		||||
            # Otherwise, feed into next turn
 | 
			
		||||
            task = result["content"]
 | 
			
		||||
 | 
			
		||||
        # Sanity check (should never happen)
 | 
			
		||||
        if final_output is None:
 | 
			
		||||
            raise RuntimeError(
 | 
			
		||||
                "RandomWorkflow completed without producing a final_output"
 | 
			
		||||
            )
 | 
			
		||||
            task_results = {
 | 
			
		||||
                "name": "timeout",
 | 
			
		||||
                "content": "Timeout occurred. Continuing...",
 | 
			
		||||
            }
 | 
			
		||||
        else:
 | 
			
		||||
            task_results = yield event_data
 | 
			
		||||
            logger.info(f"{task_results['name']} -> {self.name}")
 | 
			
		||||
 | 
			
		||||
        # Step 5: Check Iteration
 | 
			
		||||
        next_iteration_count = iteration + 1
 | 
			
		||||
        if next_iteration_count > self.max_iterations:
 | 
			
		||||
            logger.info(
 | 
			
		||||
                f"Max iterations reached. Ending random workflow (Instance ID: {instance_id})."
 | 
			
		||||
            )
 | 
			
		||||
            return task_results["content"]
 | 
			
		||||
 | 
			
		||||
        # Update ChatLoop for next iteration
 | 
			
		||||
        input["task"] = task_results["content"]
 | 
			
		||||
        input["iteration"] = next_iteration_count
 | 
			
		||||
 | 
			
		||||
        # Restart workflow with updated state
 | 
			
		||||
        # TODO: would we want this updated to preserve agent state between iterations?
 | 
			
		||||
        ctx.continue_as_new(input)
 | 
			
		||||
        # Return the final message content
 | 
			
		||||
        return final_output
 | 
			
		||||
 | 
			
		||||
    @task
 | 
			
		||||
    async def process_input(self, task: str):
 | 
			
		||||
    async def process_input(self, task: str) -> Dict[str, Any]:
 | 
			
		||||
        """
 | 
			
		||||
        Processes the input message for the workflow.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            task (str): The user-provided input task.
 | 
			
		||||
        Returns:
 | 
			
		||||
            dict: Serialized UserMessage with the content.
 | 
			
		||||
        Wraps the raw task into a UserMessage dict.
 | 
			
		||||
        """
 | 
			
		||||
        return {"role": "user", "name": self.name, "content": task}
 | 
			
		||||
 | 
			
		||||
    @task
 | 
			
		||||
    async def broadcast_message_to_agents(self, message: Dict[str, Any]):
 | 
			
		||||
        """
 | 
			
		||||
        Broadcasts a message to all agents.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            message (Dict[str, Any]): The message content and additional metadata.
 | 
			
		||||
        Broadcasts a message to all agents (excluding orchestrator).
 | 
			
		||||
        """
 | 
			
		||||
        # Format message for broadcasting
 | 
			
		||||
        task_message = BroadcastMessage(**message)
 | 
			
		||||
 | 
			
		||||
        # Send broadcast message
 | 
			
		||||
        await self.broadcast_message(message=task_message, exclude_orchestrator=True)
 | 
			
		||||
 | 
			
		||||
    @task
 | 
			
		||||
    def select_random_speaker(self, iteration: int) -> str:
 | 
			
		||||
    def select_random_speaker(self) -> str:
 | 
			
		||||
        """
 | 
			
		||||
        Selects a random speaker, ensuring that a different agent is chosen if possible.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            iteration (int): The current iteration number.
 | 
			
		||||
        Returns:
 | 
			
		||||
            str: The name of the randomly selected agent.
 | 
			
		||||
        Selects a random speaker, avoiding repeats when possible.
 | 
			
		||||
        """
 | 
			
		||||
        agents_metadata = self.get_agents_metadata(exclude_orchestrator=True)
 | 
			
		||||
        if not agents_metadata:
 | 
			
		||||
            logger.warning("No agents available for selection.")
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                "Agents metadata is empty. Cannot select a random speaker."
 | 
			
		||||
            )
 | 
			
		||||
        agents = self.get_agents_metadata(exclude_orchestrator=True)
 | 
			
		||||
        if not agents:
 | 
			
		||||
            logger.error("No agents available for selection.")
 | 
			
		||||
            raise ValueError("Agents list is empty.")
 | 
			
		||||
 | 
			
		||||
        agent_names = list(agents_metadata.keys())
 | 
			
		||||
        names = list(agents.keys())
 | 
			
		||||
        # Avoid repeating previous speaker if more than one agent
 | 
			
		||||
        if len(names) > 1 and self.current_speaker in names:
 | 
			
		||||
            names.remove(self.current_speaker)
 | 
			
		||||
 | 
			
		||||
        # Handle single-agent scenarios
 | 
			
		||||
        if len(agent_names) == 1:
 | 
			
		||||
            random_speaker = agent_names[0]
 | 
			
		||||
            logger.info(
 | 
			
		||||
                f"Only one agent available: {random_speaker}. Using the same agent."
 | 
			
		||||
            )
 | 
			
		||||
            return random_speaker
 | 
			
		||||
 | 
			
		||||
        # Select a random speaker, avoiding repeating the previous speaker when possible
 | 
			
		||||
        previous_speaker = getattr(self, "current_speaker", None)
 | 
			
		||||
        if previous_speaker in agent_names and len(agent_names) > 1:
 | 
			
		||||
            agent_names.remove(previous_speaker)
 | 
			
		||||
 | 
			
		||||
        random_speaker = random.choice(agent_names)
 | 
			
		||||
        self.current_speaker = random_speaker
 | 
			
		||||
        logger.info(
 | 
			
		||||
            f"{self.name} randomly selected agent {random_speaker} (Iteration: {iteration})."
 | 
			
		||||
        )
 | 
			
		||||
        return random_speaker
 | 
			
		||||
        choice = random.choice(names)
 | 
			
		||||
        self.current_speaker = choice
 | 
			
		||||
        return choice
 | 
			
		||||
 | 
			
		||||
    @task
 | 
			
		||||
    async def trigger_agent(self, name: str, instance_id: str) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Triggers the specified agent to perform its activity.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            name (str): Name of the agent to trigger.
 | 
			
		||||
            instance_id (str): Workflow instance ID for context.
 | 
			
		||||
        Sends a TriggerAction to the selected agent.
 | 
			
		||||
        """
 | 
			
		||||
        logger.info(f"Triggering agent {name} (Instance ID: {instance_id})")
 | 
			
		||||
 | 
			
		||||
        await self.send_message_to_agent(
 | 
			
		||||
            name=name,
 | 
			
		||||
            message=TriggerAction(workflow_instance_id=instance_id),
 | 
			
		||||
| 
						 | 
				
			
			@ -227,33 +204,20 @@ class RandomOrchestrator(OrchestratorWorkflowBase):
 | 
			
		|||
    @message_router
 | 
			
		||||
    async def process_agent_response(self, message: AgentTaskResponse):
 | 
			
		||||
        """
 | 
			
		||||
        Processes agent response messages sent directly to the agent's topic.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            message (AgentTaskResponse): The agent's response containing task results.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            None: The function raises a workflow event with the agent's response.
 | 
			
		||||
        Handles incoming AgentTaskResponse events and re-raises them into the workflow.
 | 
			
		||||
        """
 | 
			
		||||
        try:
 | 
			
		||||
            workflow_instance_id = getattr(message, "workflow_instance_id", None)
 | 
			
		||||
 | 
			
		||||
            if not workflow_instance_id:
 | 
			
		||||
                logger.error(
 | 
			
		||||
                    f"{self.name} received an agent response without a valid workflow_instance_id. Ignoring."
 | 
			
		||||
                )
 | 
			
		||||
                return
 | 
			
		||||
 | 
			
		||||
            logger.info(
 | 
			
		||||
                f"{self.name} processing agent response for workflow instance '{workflow_instance_id}'."
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            # Raise a workflow event with the Agent's Task Response
 | 
			
		||||
            self.raise_workflow_event(
 | 
			
		||||
                instance_id=workflow_instance_id,
 | 
			
		||||
                event_name="AgentTaskResponse",
 | 
			
		||||
                data=message,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error(f"Error processing agent response: {e}", exc_info=True)
 | 
			
		||||
        workflow_instance_id = getattr(message, "workflow_instance_id", None)
 | 
			
		||||
        if not workflow_instance_id:
 | 
			
		||||
            logger.error("Missing workflow_instance_id on AgentTaskResponse; ignoring.")
 | 
			
		||||
            return
 | 
			
		||||
        # Log the received response
 | 
			
		||||
        logger.debug(
 | 
			
		||||
            f"{self.name} received response for workflow {workflow_instance_id}"
 | 
			
		||||
        )
 | 
			
		||||
        logger.debug(f"Full response: {message}")
 | 
			
		||||
        # Raise a workflow event with the Agent's Task Response
 | 
			
		||||
        self.raise_workflow_event(
 | 
			
		||||
            instance_id=workflow_instance_id,
 | 
			
		||||
            event_name="AgentTaskResponse",
 | 
			
		||||
            data=message,
 | 
			
		||||
        )
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,11 +1,13 @@
 | 
			
		|||
from dapr_agents.workflow.orchestrators.base import OrchestratorWorkflowBase
 | 
			
		||||
from dapr.ext.workflow import DaprWorkflowContext
 | 
			
		||||
from dapr_agents.types import BaseMessage
 | 
			
		||||
from dapr_agents.workflow.decorators import task, workflow, message_router
 | 
			
		||||
from typing import Any, Optional, Dict
 | 
			
		||||
from pydantic import BaseModel, Field
 | 
			
		||||
from datetime import timedelta
 | 
			
		||||
import logging
 | 
			
		||||
from datetime import timedelta
 | 
			
		||||
from typing import Any, Dict, Optional
 | 
			
		||||
 | 
			
		||||
from dapr.ext.workflow import DaprWorkflowContext
 | 
			
		||||
from pydantic import BaseModel, Field
 | 
			
		||||
 | 
			
		||||
from dapr_agents.types import BaseMessage
 | 
			
		||||
from dapr_agents.workflow.decorators import message_router, task, workflow
 | 
			
		||||
from dapr_agents.workflow.orchestrators.base import OrchestratorWorkflowBase
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -33,9 +35,9 @@ class TriggerAction(BaseModel):
 | 
			
		|||
 | 
			
		||||
    task: Optional[str] = Field(
 | 
			
		||||
        None,
 | 
			
		||||
        description="The specific task to execute. If not provided, the agent will act based on its memory or predefined behavior.",
 | 
			
		||||
        description="The specific task to execute. If not provided, the agent will act "
 | 
			
		||||
        "based on its memory or predefined behavior.",
 | 
			
		||||
    )
 | 
			
		||||
    iteration: Optional[int] = Field(0, description="")
 | 
			
		||||
    workflow_instance_id: Optional[str] = Field(
 | 
			
		||||
        default=None, description="Dapr workflow instance id from source if available"
 | 
			
		||||
    )
 | 
			
		||||
| 
						 | 
				
			
			@ -44,110 +46,110 @@ class TriggerAction(BaseModel):
 | 
			
		|||
class RoundRobinOrchestrator(OrchestratorWorkflowBase):
 | 
			
		||||
    """
 | 
			
		||||
    Implements a round-robin workflow where agents take turns performing tasks.
 | 
			
		||||
    The workflow iterates through conversations by selecting agents in a circular order.
 | 
			
		||||
 | 
			
		||||
    Uses `continue_as_new` to persist iteration state.
 | 
			
		||||
    Iterates for up to `self.max_iterations` turns, then returns the last reply.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def model_post_init(self, __context: Any) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Initializes and configures the round-robin workflow service.
 | 
			
		||||
        Registers tasks and workflows, then starts the workflow runtime.
 | 
			
		||||
        Initializes and configures the round-robin workflow.
 | 
			
		||||
        """
 | 
			
		||||
        self._workflow_name = "RoundRobinWorkflow"
 | 
			
		||||
        super().model_post_init(__context)
 | 
			
		||||
 | 
			
		||||
    @workflow(name="RoundRobinWorkflow")
 | 
			
		||||
    # TODO: add retry policies on activities.
 | 
			
		||||
    def main_workflow(self, ctx: DaprWorkflowContext, input: TriggerAction):
 | 
			
		||||
    def main_workflow(self, ctx: DaprWorkflowContext, input: TriggerAction) -> str:
 | 
			
		||||
        """
 | 
			
		||||
        Executes a round-robin workflow where agents interact iteratively.
 | 
			
		||||
 | 
			
		||||
        Steps:
 | 
			
		||||
        1. Processes input and broadcasts the initial message.
 | 
			
		||||
        2. Iterates through agents, selecting a speaker each round.
 | 
			
		||||
        3. Waits for agent responses or handles timeouts.
 | 
			
		||||
        4. Updates the workflow state and continues the loop.
 | 
			
		||||
        5. Terminates when max iterations are reached.
 | 
			
		||||
 | 
			
		||||
        Uses `continue_as_new` to persist iteration state.
 | 
			
		||||
        Drives the round-robin loop in up to `max_iterations` turns.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            ctx (DaprWorkflowContext): The workflow execution context.
 | 
			
		||||
            input (TriggerAction): The current workflow state containing task and iteration.
 | 
			
		||||
            ctx (DaprWorkflowContext): Workflow context.
 | 
			
		||||
            input (TriggerAction): Contains the initial `task`.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            str: The last processed message when the workflow terminates.
 | 
			
		||||
            str: The final message content when the workflow terminates.
 | 
			
		||||
        """
 | 
			
		||||
        # Step 1: Extract task and instance ID from input
 | 
			
		||||
        task = input.get("task")
 | 
			
		||||
        iteration = input.get("iteration", 0)
 | 
			
		||||
        instance_id = ctx.instance_id
 | 
			
		||||
        final_output: Optional[str] = None
 | 
			
		||||
 | 
			
		||||
        if not ctx.is_replaying:
 | 
			
		||||
            logger.info(
 | 
			
		||||
                f"Round-robin iteration {iteration + 1} started (Instance ID: {instance_id})."
 | 
			
		||||
        # Loop from 1..max_iterations
 | 
			
		||||
        for turn in range(1, self.max_iterations + 1):
 | 
			
		||||
            if not ctx.is_replaying:
 | 
			
		||||
                logger.info(
 | 
			
		||||
                    f"Round-robin turn {turn}/{self.max_iterations} "
 | 
			
		||||
                    f"(Instance ID: {instance_id})"
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            # Step 2: On turn 1, process input and broadcast message
 | 
			
		||||
            if turn == 1:
 | 
			
		||||
                message = yield ctx.call_activity(
 | 
			
		||||
                    self.process_input, input={"task": task}
 | 
			
		||||
                )
 | 
			
		||||
                if not ctx.is_replaying:
 | 
			
		||||
                    logger.info(
 | 
			
		||||
                        f"Initial message from {message['role']} -> {self.name}"
 | 
			
		||||
                    )
 | 
			
		||||
                yield ctx.call_activity(
 | 
			
		||||
                    self.broadcast_message_to_agents, input={"message": message}
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            # Step 3: Select next speaker in round-robin order
 | 
			
		||||
            speaker = yield ctx.call_activity(
 | 
			
		||||
                self.select_next_speaker, input={"turn": turn}
 | 
			
		||||
            )
 | 
			
		||||
            if not ctx.is_replaying:
 | 
			
		||||
                logger.info(f"Selected agent {speaker} for turn {turn}")
 | 
			
		||||
 | 
			
		||||
        # Check Termination Condition
 | 
			
		||||
        if iteration >= self.max_iterations:
 | 
			
		||||
            logger.info(
 | 
			
		||||
                f"Max iterations reached. Ending round-robin workflow (Instance ID: {instance_id})."
 | 
			
		||||
            )
 | 
			
		||||
            return task
 | 
			
		||||
 | 
			
		||||
        # First iteration: Process input and broadcast
 | 
			
		||||
        if iteration == 0:
 | 
			
		||||
            message = yield ctx.call_activity(self.process_input, input={"task": task})
 | 
			
		||||
            logger.info(f"Initial message from {message['role']} -> {self.name}")
 | 
			
		||||
 | 
			
		||||
            # Broadcast initial message
 | 
			
		||||
            # Step 4: Trigger that agent
 | 
			
		||||
            yield ctx.call_activity(
 | 
			
		||||
                self.broadcast_message_to_agents, input={"message": message}
 | 
			
		||||
                self.trigger_agent,
 | 
			
		||||
                input={"name": speaker, "instance_id": instance_id},
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # Select next speaker
 | 
			
		||||
        next_speaker = yield ctx.call_activity(
 | 
			
		||||
            self.select_next_speaker, input={"iteration": iteration}
 | 
			
		||||
        )
 | 
			
		||||
            # Step 5: Wait for agent response or timeout
 | 
			
		||||
            if not ctx.is_replaying:
 | 
			
		||||
                logger.debug("Waiting for agent response...")
 | 
			
		||||
            event_data = ctx.wait_for_external_event("AgentTaskResponse")
 | 
			
		||||
            timeout_task = ctx.create_timer(timedelta(seconds=self.timeout))
 | 
			
		||||
            any_results = yield self.when_any([event_data, timeout_task])
 | 
			
		||||
 | 
			
		||||
        # Trigger agent
 | 
			
		||||
        yield ctx.call_activity(
 | 
			
		||||
            self.trigger_agent, input={"name": next_speaker, "instance_id": instance_id}
 | 
			
		||||
        )
 | 
			
		||||
            # Step 6: Handle result or timeout
 | 
			
		||||
            if any_results == timeout_task:
 | 
			
		||||
                if not ctx.is_replaying:
 | 
			
		||||
                    logger.warning(
 | 
			
		||||
                        f"Turn {turn}: response timed out "
 | 
			
		||||
                        f"(Instance ID: {instance_id})"
 | 
			
		||||
                    )
 | 
			
		||||
                result = {
 | 
			
		||||
                    "name": "timeout",
 | 
			
		||||
                    "content": "Timeout occurred. Continuing...",
 | 
			
		||||
                }
 | 
			
		||||
            else:
 | 
			
		||||
                result = yield event_data
 | 
			
		||||
                if not ctx.is_replaying:
 | 
			
		||||
                    logger.info(f"{result['name']} -> {self.name}")
 | 
			
		||||
 | 
			
		||||
        # Wait for response or timeout
 | 
			
		||||
        logger.info("Waiting for agent response...")
 | 
			
		||||
        event_data = ctx.wait_for_external_event("AgentTaskResponse")
 | 
			
		||||
        timeout_task = ctx.create_timer(timedelta(seconds=self.timeout))
 | 
			
		||||
        any_results = yield self.when_any([event_data, timeout_task])
 | 
			
		||||
            # Step 7: If this is the last allowed turn, capture and break
 | 
			
		||||
            if turn == self.max_iterations:
 | 
			
		||||
                if not ctx.is_replaying:
 | 
			
		||||
                    logger.info(
 | 
			
		||||
                        f"Turn {turn}: max iterations reached (Instance ID: {instance_id})."
 | 
			
		||||
                    )
 | 
			
		||||
                final_output = result["content"]
 | 
			
		||||
                break
 | 
			
		||||
 | 
			
		||||
        if any_results == timeout_task:
 | 
			
		||||
            logger.warning(
 | 
			
		||||
                f"Agent response timed out (Iteration: {iteration + 1}, Instance ID: {instance_id})."
 | 
			
		||||
            # Otherwise, feed into next iteration
 | 
			
		||||
            task = result["content"]
 | 
			
		||||
 | 
			
		||||
        # Sanity check: final_output must be set
 | 
			
		||||
        if final_output is None:
 | 
			
		||||
            raise RuntimeError(
 | 
			
		||||
                "RoundRobinWorkflow completed without producing final_output"
 | 
			
		||||
            )
 | 
			
		||||
            task_results = {
 | 
			
		||||
                "name": "timeout",
 | 
			
		||||
                "content": "Timeout occurred. Continuing...",
 | 
			
		||||
            }
 | 
			
		||||
        else:
 | 
			
		||||
            task_results = yield event_data
 | 
			
		||||
            logger.info(f"{task_results['name']} -> {self.name}")
 | 
			
		||||
 | 
			
		||||
        # Check Iteration
 | 
			
		||||
        next_iteration_count = iteration + 1
 | 
			
		||||
        if next_iteration_count > self.max_iterations:
 | 
			
		||||
            logger.info(
 | 
			
		||||
                f"Max iterations reached. Ending round-robin workflow (Instance ID: {instance_id})."
 | 
			
		||||
            )
 | 
			
		||||
            return task_results["content"]
 | 
			
		||||
 | 
			
		||||
        # Update for next iteration
 | 
			
		||||
        input["task"] = task_results["content"]
 | 
			
		||||
        input["iteration"] = next_iteration_count
 | 
			
		||||
 | 
			
		||||
        # Restart workflow with updated state
 | 
			
		||||
        # TODO: would we want this updated to preserve agent state between iterations?
 | 
			
		||||
        ctx.continue_as_new(input)
 | 
			
		||||
        return final_output
 | 
			
		||||
 | 
			
		||||
    @task
 | 
			
		||||
    async def process_input(self, task: str) -> Dict[str, Any]:
 | 
			
		||||
| 
						 | 
				
			
			@ -171,17 +173,16 @@ class RoundRobinOrchestrator(OrchestratorWorkflowBase):
 | 
			
		|||
        """
 | 
			
		||||
        # Format message for broadcasting
 | 
			
		||||
        task_message = BroadcastMessage(**message)
 | 
			
		||||
 | 
			
		||||
        # Send broadcast message
 | 
			
		||||
        await self.broadcast_message(message=task_message, exclude_orchestrator=True)
 | 
			
		||||
 | 
			
		||||
    @task
 | 
			
		||||
    async def select_next_speaker(self, iteration: int) -> str:
 | 
			
		||||
    async def select_next_speaker(self, turn: int) -> str:
 | 
			
		||||
        """
 | 
			
		||||
        Selects the next speaker in round-robin order.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            iteration (int): The current iteration number.
 | 
			
		||||
            turn (int): The current turn number (1-based).
 | 
			
		||||
        Returns:
 | 
			
		||||
            str: The name of the selected agent.
 | 
			
		||||
        """
 | 
			
		||||
| 
						 | 
				
			
			@ -191,12 +192,7 @@ class RoundRobinOrchestrator(OrchestratorWorkflowBase):
 | 
			
		|||
            raise ValueError("Agents metadata is empty. Cannot select next speaker.")
 | 
			
		||||
 | 
			
		||||
        agent_names = list(agents_metadata.keys())
 | 
			
		||||
 | 
			
		||||
        # Determine the next agent in the round-robin order
 | 
			
		||||
        next_speaker = agent_names[iteration % len(agent_names)]
 | 
			
		||||
        logger.info(
 | 
			
		||||
            f"{self.name} selected agent {next_speaker} for iteration {iteration}."
 | 
			
		||||
        )
 | 
			
		||||
        next_speaker = agent_names[(turn - 1) % len(agent_names)]
 | 
			
		||||
        return next_speaker
 | 
			
		||||
 | 
			
		||||
    @task
 | 
			
		||||
| 
						 | 
				
			
			@ -208,6 +204,7 @@ class RoundRobinOrchestrator(OrchestratorWorkflowBase):
 | 
			
		|||
            name (str): Name of the agent to trigger.
 | 
			
		||||
            instance_id (str): Workflow instance ID for context.
 | 
			
		||||
        """
 | 
			
		||||
        logger.info(f"Triggering agent {name} (Instance ID: {instance_id})")
 | 
			
		||||
        await self.send_message_to_agent(
 | 
			
		||||
            name=name,
 | 
			
		||||
            message=TriggerAction(workflow_instance_id=instance_id),
 | 
			
		||||
| 
						 | 
				
			
			@ -224,25 +221,21 @@ class RoundRobinOrchestrator(OrchestratorWorkflowBase):
 | 
			
		|||
        Returns:
 | 
			
		||||
            None: The function raises a workflow event with the agent's response.
 | 
			
		||||
        """
 | 
			
		||||
        try:
 | 
			
		||||
            workflow_instance_id = getattr(message, "workflow_instance_id", None)
 | 
			
		||||
        workflow_instance_id = getattr(message, "workflow_instance_id", None)
 | 
			
		||||
 | 
			
		||||
            if not workflow_instance_id:
 | 
			
		||||
                logger.error(
 | 
			
		||||
                    f"{self.name} received an agent response without a valid workflow_instance_id. Ignoring."
 | 
			
		||||
                )
 | 
			
		||||
                return
 | 
			
		||||
 | 
			
		||||
            logger.info(
 | 
			
		||||
                f"{self.name} processing agent response for workflow instance '{workflow_instance_id}'."
 | 
			
		||||
        if not workflow_instance_id:
 | 
			
		||||
            logger.error(
 | 
			
		||||
                f"{self.name} received an agent response without a valid workflow_instance_id. Ignoring."
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            # Raise a workflow event with the Agent's Task Response
 | 
			
		||||
            self.raise_workflow_event(
 | 
			
		||||
                instance_id=workflow_instance_id,
 | 
			
		||||
                event_name="AgentTaskResponse",
 | 
			
		||||
                data=message,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error(f"Error processing agent response: {e}", exc_info=True)
 | 
			
		||||
            return
 | 
			
		||||
        # Log the received response
 | 
			
		||||
        logger.debug(
 | 
			
		||||
            f"{self.name} received response for workflow {workflow_instance_id}"
 | 
			
		||||
        )
 | 
			
		||||
        logger.debug(f"Full response: {message}")
 | 
			
		||||
        # Raise a workflow event with the Agent's Task Response
 | 
			
		||||
        self.raise_workflow_event(
 | 
			
		||||
            instance_id=workflow_instance_id,
 | 
			
		||||
            event_name="AgentTaskResponse",
 | 
			
		||||
            data=message,
 | 
			
		||||
        )
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -14,7 +14,7 @@ from dapr_agents.llm.chat import ChatClientBase
 | 
			
		|||
from dapr_agents.llm.openai import OpenAIChatClient
 | 
			
		||||
from dapr_agents.llm.utils import StructureHandler
 | 
			
		||||
from dapr_agents.prompt.utils.chat import ChatPromptHelper
 | 
			
		||||
from dapr_agents.types import BaseMessage, ChatCompletion, UserMessage
 | 
			
		||||
from dapr_agents.types import BaseMessage, UserMessage, LLMChatResponse
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -261,23 +261,35 @@ class WorkflowTask(BaseModel):
 | 
			
		|||
        Unwrap AI return types into plain Python.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            result: ChatCompletion, BaseModel, or list of BaseModel.
 | 
			
		||||
            result: One of:
 | 
			
		||||
                - LLMChatResponse
 | 
			
		||||
                - BaseModel (Pydantic)
 | 
			
		||||
                - List[BaseModel]
 | 
			
		||||
                - primitive (str/int/etc) or dict
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            A primitive, dict, or list of dicts.
 | 
			
		||||
            • str (assistant content) when `LLMChatResponse`
 | 
			
		||||
            • dict when a single BaseModel
 | 
			
		||||
            • List[dict] when a list of BaseModels
 | 
			
		||||
            • otherwise, the raw `result`
 | 
			
		||||
        """
 | 
			
		||||
        # Unwrap ChatCompletion
 | 
			
		||||
        if isinstance(result, ChatCompletion):
 | 
			
		||||
            logger.debug("Extracted message content from ChatCompletion.")
 | 
			
		||||
            return result.get_content()
 | 
			
		||||
        # Pydantic → dict
 | 
			
		||||
        # 1) Unwrap our unified LLMChatResponse → return the assistant's text
 | 
			
		||||
        if isinstance(result, LLMChatResponse):
 | 
			
		||||
            logger.debug("Extracted message content from LLMChatResponse.")
 | 
			
		||||
            msg = result.get_message()
 | 
			
		||||
            return getattr(msg, "content", None)
 | 
			
		||||
 | 
			
		||||
        # 2) Single Pydantic model → dict
 | 
			
		||||
        if isinstance(result, BaseModel):
 | 
			
		||||
            logger.debug("Converting Pydantic model to dictionary.")
 | 
			
		||||
            return result.model_dump()
 | 
			
		||||
 | 
			
		||||
        # 3) List of Pydantic models → list of dicts
 | 
			
		||||
        if isinstance(result, list) and all(isinstance(x, BaseModel) for x in result):
 | 
			
		||||
            logger.debug("Converting list of Pydantic models to list of dictionaries.")
 | 
			
		||||
            return [x.model_dump() for x in result]
 | 
			
		||||
        # If no specific conversion is necessary, return as-is
 | 
			
		||||
 | 
			
		||||
        # 4) Fallback: primitive, dict, etc.
 | 
			
		||||
        logger.info("Returning final task result.")
 | 
			
		||||
        return result
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -137,8 +137,7 @@ durable_agent = DurableAgent(
 | 
			
		|||
    state_key="workflow_state",
 | 
			
		||||
    agents_registry_store_name="agentstatestore",
 | 
			
		||||
    agents_registry_key="agents_registry",
 | 
			
		||||
    ),
 | 
			
		||||
)
 | 
			
		||||
),
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## Agent Patterns
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,8 +1,18 @@
 | 
			
		|||
from dapr_agents import OpenAIChatClient
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
 | 
			
		||||
from dapr_agents import OpenAIChatClient
 | 
			
		||||
from dapr_agents.types.message import LLMChatResponse
 | 
			
		||||
 | 
			
		||||
# load environment variables from .env file
 | 
			
		||||
load_dotenv()
 | 
			
		||||
 | 
			
		||||
# Initialize the OpenAI chat client
 | 
			
		||||
llm = OpenAIChatClient()
 | 
			
		||||
response = llm.generate("Tell me a joke")
 | 
			
		||||
if len(response.get_content()) > 0:
 | 
			
		||||
    print("Got response:", response.get_content())
 | 
			
		||||
 | 
			
		||||
# Generate a response from the LLM
 | 
			
		||||
response: LLMChatResponse = llm.generate("Tell me a joke")
 | 
			
		||||
 | 
			
		||||
# Print the Message content if it exists
 | 
			
		||||
if response.get_message() is not None:
 | 
			
		||||
    content = response.get_message().content
 | 
			
		||||
    print("Got response:", content)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,11 +1,15 @@
 | 
			
		|||
import logging
 | 
			
		||||
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
 | 
			
		||||
from dapr_agents import Agent
 | 
			
		||||
from dapr_agents.document.embedder.sentence import SentenceTransformerEmbedder
 | 
			
		||||
from dapr_agents.storage.vectorstores import ChromaVectorStore
 | 
			
		||||
from dapr_agents.tool import tool
 | 
			
		||||
from dapr_agents.types.document import Document
 | 
			
		||||
 | 
			
		||||
load_dotenv()
 | 
			
		||||
 | 
			
		||||
logging.basicConfig(level=logging.INFO)
 | 
			
		||||
 | 
			
		||||
embedding_function = SentenceTransformerEmbedder(model="all-MiniLM-L6-v2")
 | 
			
		||||
| 
						 | 
				
			
			@ -75,8 +79,6 @@ def add_machine_learning_doc() -> str:
 | 
			
		|||
 | 
			
		||||
async def main():
 | 
			
		||||
    # Seed the vector store with initial documents using Document class
 | 
			
		||||
    from dapr_agents.types.document import Document
 | 
			
		||||
 | 
			
		||||
    documents = [
 | 
			
		||||
        Document(
 | 
			
		||||
            text="Gandalf: A wizard is never late, Frodo Baggins. Nor is he early; he arrives precisely when he means to.",
 | 
			
		||||
| 
						 | 
				
			
			@ -127,5 +129,3 @@ if __name__ == "__main__":
 | 
			
		|||
        print("\nInterrupted by user. Exiting gracefully...")
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        print(f"\nError occurred: {e}")
 | 
			
		||||
        print("Make sure you have the required dependencies installed:")
 | 
			
		||||
        print(" pip install sentence-transformers chromadb")
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -69,13 +69,24 @@ python 01_ask_llm.py
 | 
			
		|||
This example demonstrates the simplest way to use Dapr Agents' OpenAIChatClient:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
from dapr_agents import OpenAIChatClient
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
 | 
			
		||||
from dapr_agents import OpenAIChatClient
 | 
			
		||||
from dapr_agents.types.message import LLMChatResponse
 | 
			
		||||
 | 
			
		||||
# load environment variables from .env file
 | 
			
		||||
load_dotenv()
 | 
			
		||||
 | 
			
		||||
# Initialize the OpenAI chat client
 | 
			
		||||
llm = OpenAIChatClient()
 | 
			
		||||
response = llm.generate("Tell me a joke")
 | 
			
		||||
print("Got response:", response.get_content())
 | 
			
		||||
 | 
			
		||||
# Generate a response from the LLM
 | 
			
		||||
response: LLMChatResponse = llm.generate("Tell me a joke")
 | 
			
		||||
 | 
			
		||||
# Print the Message content if it exists
 | 
			
		||||
if response.get_message() is not None:
 | 
			
		||||
    content = response.get_message().content
 | 
			
		||||
    print("Got response:", content)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
**Expected output:** The LLM will respond with a joke.
 | 
			
		||||
| 
						 | 
				
			
			@ -356,6 +367,9 @@ This example demonstrates how to create an agent with vector store capabilities,
 | 
			
		|||
 | 
			
		||||
```python
 | 
			
		||||
import logging
 | 
			
		||||
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
 | 
			
		||||
from dapr_agents import Agent
 | 
			
		||||
from dapr_agents.document.embedder.sentence import SentenceTransformerEmbedder
 | 
			
		||||
from dapr_agents.storage.vectorstores import ChromaVectorStore
 | 
			
		||||
| 
						 | 
				
			
			@ -480,8 +494,6 @@ if __name__ == "__main__":
 | 
			
		|||
        print("\nInterrupted by user. Exiting gracefully...")
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        print(f"\nError occurred: {e}")
 | 
			
		||||
        print("Make sure you have the required dependencies installed:")
 | 
			
		||||
        print(" pip install sentence-transformers chromadb")
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## Key Concepts
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -69,27 +69,34 @@ dapr run --app-id dapr-llm --resources-path ./components -- python text_completi
 | 
			
		|||
The script uses the `DaprChatClient` which connects to Dapr's `echo` LLM component:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
from dapr_agents.llm import DaprChatClient
 | 
			
		||||
from dapr_agents.types import UserMessage
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
 | 
			
		||||
from dapr_agents.llm import DaprChatClient
 | 
			
		||||
from dapr_agents.types import LLMChatResponse, UserMessage
 | 
			
		||||
 | 
			
		||||
# Load environment variables from .env
 | 
			
		||||
load_dotenv()
 | 
			
		||||
 | 
			
		||||
# Basic chat completion
 | 
			
		||||
llm = DaprChatClient()
 | 
			
		||||
response = llm.generate("Name a famous dog!")
 | 
			
		||||
print("Response: ", response.get_content())
 | 
			
		||||
response: LLMChatResponse = llm.generate("Name a famous dog!")
 | 
			
		||||
 | 
			
		||||
if response.get_message() is not None:
 | 
			
		||||
    print("Response: ", response.get_message().content)
 | 
			
		||||
 | 
			
		||||
# Chat completion using a prompty file for context
 | 
			
		||||
llm = DaprChatClient.from_prompty('basic.prompty')
 | 
			
		||||
response = llm.generate(input_data={"question":"What is your name?"})
 | 
			
		||||
print("Response with prompty: ", response.get_content())
 | 
			
		||||
llm = DaprChatClient.from_prompty("basic.prompty")
 | 
			
		||||
response: LLMChatResponse = llm.generate(input_data={"question": "What is your name?"})
 | 
			
		||||
 | 
			
		||||
if response.get_message() is not None:
 | 
			
		||||
    print("Response with prompty: ", response.get_message().content)
 | 
			
		||||
 | 
			
		||||
# Chat completion with user input
 | 
			
		||||
llm = DaprChatClient()
 | 
			
		||||
response = llm.generate(messages=[UserMessage("hello")])
 | 
			
		||||
print("Response with user input: ", response.get_content())
 | 
			
		||||
response: LLMChatResponse = llm.generate(messages=[UserMessage("hello")])
 | 
			
		||||
 | 
			
		||||
if response.get_message() is not None and "hello" in response.get_message().content.lower():
 | 
			
		||||
    print("Response with user input: ", response.get_message().content)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
**Expected output:** The echo component will simply return the prompts that were sent to it.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -5,7 +5,7 @@ model:
 | 
			
		|||
    api: chat
 | 
			
		||||
    configuration:
 | 
			
		||||
        type: nvidia
 | 
			
		||||
        name: meta/llama3-8b-instruct
 | 
			
		||||
        name: HuggingFaceTB/SmolLM2-1.7B-Instruct
 | 
			
		||||
    parameters:
 | 
			
		||||
        max_tokens: 128
 | 
			
		||||
        temperature: 0.2
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,28 +1,31 @@
 | 
			
		|||
from dapr_agents.llm import DaprChatClient
 | 
			
		||||
from dapr_agents.types import UserMessage
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
 | 
			
		||||
from dapr_agents.llm import DaprChatClient
 | 
			
		||||
from dapr_agents.types import LLMChatResponse, UserMessage
 | 
			
		||||
 | 
			
		||||
# Load environment variables from .env
 | 
			
		||||
load_dotenv()
 | 
			
		||||
 | 
			
		||||
# Basic chat completion
 | 
			
		||||
llm = DaprChatClient()
 | 
			
		||||
response = llm.generate("Name a famous dog!")
 | 
			
		||||
response: LLMChatResponse = llm.generate("Name a famous dog!")
 | 
			
		||||
 | 
			
		||||
if len(response.get_content()) > 0:
 | 
			
		||||
    print("Response: ", response.get_content())
 | 
			
		||||
if response.get_message() is not None:
 | 
			
		||||
    print("Response: ", response.get_message().content)
 | 
			
		||||
 | 
			
		||||
# Chat completion using a prompty file for context
 | 
			
		||||
llm = DaprChatClient.from_prompty("basic.prompty")
 | 
			
		||||
response = llm.generate(input_data={"question": "What is your name?"})
 | 
			
		||||
response: LLMChatResponse = llm.generate(input_data={"question": "What is your name?"})
 | 
			
		||||
 | 
			
		||||
if len(response.get_content()) > 0:
 | 
			
		||||
    print("Response with prompty: ", response.get_content())
 | 
			
		||||
if response.get_message() is not None:
 | 
			
		||||
    print("Response with prompty: ", response.get_message().content)
 | 
			
		||||
 | 
			
		||||
# Chat completion with user input
 | 
			
		||||
llm = DaprChatClient()
 | 
			
		||||
response = llm.generate(messages=[UserMessage("hello")])
 | 
			
		||||
response: LLMChatResponse = llm.generate(messages=[UserMessage("hello")])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if len(response.get_content()) > 0 and "hello" in response.get_content().lower():
 | 
			
		||||
    print("Response with user input: ", response.get_content())
 | 
			
		||||
if (
 | 
			
		||||
    response.get_message() is not None
 | 
			
		||||
    and "hello" in response.get_message().content.lower()
 | 
			
		||||
):
 | 
			
		||||
    print("Response with user input: ", response.get_message().content)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -46,32 +46,184 @@ python text_completion.py
 | 
			
		|||
The script demonstrates basic usage of the DaprChatClient for text generation:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
from dapr_agents.llm import HFHubChatClient
 | 
			
		||||
from dapr_agents.types import UserMessage
 | 
			
		||||
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
 | 
			
		||||
from dapr_agents.llm import HFHubChatClient
 | 
			
		||||
from dapr_agents.types import LLMChatResponse, UserMessage
 | 
			
		||||
 | 
			
		||||
load_dotenv()
 | 
			
		||||
 | 
			
		||||
# Basic chat completion
 | 
			
		||||
llm = HFHubChatClient(
 | 
			
		||||
    model="microsoft/Phi-3-mini-4k-instruct"
 | 
			
		||||
)
 | 
			
		||||
response = llm.generate("Name a famous dog!")
 | 
			
		||||
llm = HFHubChatClient(model="HuggingFaceTB/SmolLM3-3B")
 | 
			
		||||
response: LLMChatResponse = llm.generate("Name a famous dog!")
 | 
			
		||||
 | 
			
		||||
if len(response.get_content()) > 0:
 | 
			
		||||
    print("Response: ", response.get_content())
 | 
			
		||||
if response.get_message() is not None:
 | 
			
		||||
    print("Response: ", response.get_message().content)
 | 
			
		||||
 | 
			
		||||
# Chat completion using a prompty file for context
 | 
			
		||||
llm = HFHubChatClient.from_prompty('basic.prompty')
 | 
			
		||||
response = llm.generate(input_data={"question":"What is your name?"})
 | 
			
		||||
llm = HFHubChatClient.from_prompty("basic.prompty")
 | 
			
		||||
response: LLMChatResponse = llm.generate(input_data={"question": "What is your name?"})
 | 
			
		||||
 | 
			
		||||
if len(response.get_content()) > 0:
 | 
			
		||||
    print("Response with prompty: ", response.get_content())
 | 
			
		||||
if response.get_message() is not None:
 | 
			
		||||
    print("Response with prompty: ", response.get_message().content)
 | 
			
		||||
 | 
			
		||||
# Chat completion with user input
 | 
			
		||||
llm = HFHubChatClient(model="microsoft/Phi-3-mini-4k-instruct")
 | 
			
		||||
response = llm.generate(messages=[UserMessage("hello")])
 | 
			
		||||
llm = HFHubChatClient(model="HuggingFaceTB/SmolLM3-3B")
 | 
			
		||||
response: LLMChatResponse = llm.generate(messages=[UserMessage("hello")])
 | 
			
		||||
 | 
			
		||||
print("Response with user input: ", response.get_content())
 | 
			
		||||
```
 | 
			
		||||
if response.get_message() is not None and "hello" in response.get_message().content.lower():
 | 
			
		||||
    print("Response with user input: ", response.get_message().content)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
**2. Expected output:** The LLM will respond with the name of a famous dog (e.g., "Lassie", "Hachiko", etc.).
 | 
			
		||||
 | 
			
		||||
**Run the structured text completion example:**
 | 
			
		||||
 | 
			
		||||
<!-- STEP
 | 
			
		||||
name: Run text completion example
 | 
			
		||||
expected_stdout_lines:
 | 
			
		||||
  - '"name":'
 | 
			
		||||
  - '"breed":'
 | 
			
		||||
  - '"reason":'
 | 
			
		||||
timeout_seconds: 30
 | 
			
		||||
output_match_mode: substring
 | 
			
		||||
-->
 | 
			
		||||
```bash
 | 
			
		||||
python structured_completion.py
 | 
			
		||||
```
 | 
			
		||||
<!-- END_STEP -->
 | 
			
		||||
 | 
			
		||||
This example shows how to use Pydantic models to get structured data from LLMs:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
import json
 | 
			
		||||
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
from pydantic import BaseModel
 | 
			
		||||
 | 
			
		||||
from dapr_agents import HFHubChatClient
 | 
			
		||||
from dapr_agents.types import UserMessage
 | 
			
		||||
 | 
			
		||||
# Load environment variables from .env
 | 
			
		||||
load_dotenv()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Define our data model
 | 
			
		||||
class Dog(BaseModel):
 | 
			
		||||
    name: str
 | 
			
		||||
    breed: str
 | 
			
		||||
    reason: str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Initialize the chat client
 | 
			
		||||
llm = HFHubChatClient(model="HuggingFaceTB/SmolLM3-3B")
 | 
			
		||||
 | 
			
		||||
# Get structured response
 | 
			
		||||
response: Dog = llm.generate(
 | 
			
		||||
    messages=[UserMessage("One famous dog in history.")], response_format=Dog
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
print(json.dumps(response.model_dump(), indent=2))
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
**Expected output:** A JSON object with name, breed, and reason fields
 | 
			
		||||
 | 
			
		||||
```
 | 
			
		||||
{
 | 
			
		||||
  "name": "Dog",
 | 
			
		||||
  "breed": "Siberian Husky",
 | 
			
		||||
  "reason": "Known for its endurance, intelligence, and loyalty, Siberian Huskies have played crucial roles in dog sledding and have been beloved companions for many."
 | 
			
		||||
}
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Streaming
 | 
			
		||||
 | 
			
		||||
Our Hugging Face chat client also support streaming responses, where you can process partial results as they arrive. Below are two examples:
 | 
			
		||||
 | 
			
		||||
**1. Basic Streaming Example**
 | 
			
		||||
 | 
			
		||||
Run the `text_completion_stream.py` script to see token‐by‐token output:
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
python text_completion_stream.py
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
The scripts:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
from dapr_agents import HFHubChatClient
 | 
			
		||||
from dapr_agents.types.message import LLMChatResponseChunk
 | 
			
		||||
from typing import Iterator
 | 
			
		||||
import logging
 | 
			
		||||
 | 
			
		||||
logging.basicConfig(level=logging.INFO)
 | 
			
		||||
load_dotenv()
 | 
			
		||||
 | 
			
		||||
llm = HFHubChatClient(model="HuggingFaceTB/SmolLM3-3B")
 | 
			
		||||
response: Iterator[LLMChatResponseChunk] = llm.generate("Name a famous dog!", stream=True)
 | 
			
		||||
 | 
			
		||||
for chunk in response:
 | 
			
		||||
    if chunk.result.content:
 | 
			
		||||
        print(chunk.result.content, end="", flush=True)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
This will print each partial chunk as it arrives, so you can build up the full answer in real time.
 | 
			
		||||
 | 
			
		||||
**2. Streaming with Tool Calls:**
 | 
			
		||||
 | 
			
		||||
Use `text_completion_with_tools.py` to combine streaming with function‐call “tools”:
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
python text_completion_with_tools.py
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
from dapr_agents import HFHubChatClient
 | 
			
		||||
from dapr_agents.types.message import LLMChatResponseChunk
 | 
			
		||||
from typing import Iterator
 | 
			
		||||
import logging
 | 
			
		||||
 | 
			
		||||
logging.basicConfig(level=logging.INFO)
 | 
			
		||||
load_dotenv()
 | 
			
		||||
 | 
			
		||||
# Initialize client
 | 
			
		||||
llm = HFHubChatClient(model="HuggingFaceTB/SmolLM3-3B", hf_provider="auto")
 | 
			
		||||
 | 
			
		||||
# Define a simple addition tool
 | 
			
		||||
def add_numbers(a: int, b: int) -> int:
 | 
			
		||||
    return a + b
 | 
			
		||||
 | 
			
		||||
add_tool = {
 | 
			
		||||
    "type": "function",
 | 
			
		||||
    "function": {
 | 
			
		||||
        "name": "add_numbers",
 | 
			
		||||
        "description": "Add two numbers together.",
 | 
			
		||||
        "parameters": {
 | 
			
		||||
            "type": "object",
 | 
			
		||||
            "properties": {
 | 
			
		||||
                "a": {"type": "integer", "description": "The first number."},
 | 
			
		||||
                "b": {"type": "integer", "description": "The second number."}
 | 
			
		||||
            },
 | 
			
		||||
            "required": ["a", "b"]
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
messages = [
 | 
			
		||||
    {"role": "system", "content": "You are a helpful assistant."},
 | 
			
		||||
    {"role": "user", "content": "Add 5 and 7 and 2 and 2."}
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
response: Iterator[LLMChatResponseChunk] = llm.generate(
 | 
			
		||||
    messages=messages,
 | 
			
		||||
    tools=[add_tool],
 | 
			
		||||
    stream=True
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
for chunk in response:
 | 
			
		||||
    print(chunk.result)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Here, the model can decide to call your add_numbers function mid‐stream, and you’ll see those calls (and their results) as they come in.
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,28 @@
 | 
			
		|||
import json
 | 
			
		||||
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
from pydantic import BaseModel
 | 
			
		||||
 | 
			
		||||
from dapr_agents import HFHubChatClient
 | 
			
		||||
from dapr_agents.types import UserMessage
 | 
			
		||||
 | 
			
		||||
# Load environment variables from .env
 | 
			
		||||
load_dotenv()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Define our data model
 | 
			
		||||
class Dog(BaseModel):
 | 
			
		||||
    name: str
 | 
			
		||||
    breed: str
 | 
			
		||||
    reason: str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Initialize the chat client
 | 
			
		||||
llm = HFHubChatClient(model="HuggingFaceTB/SmolLM3-3B")
 | 
			
		||||
 | 
			
		||||
# Get structured response
 | 
			
		||||
response: Dog = llm.generate(
 | 
			
		||||
    messages=[UserMessage("One famous dog in history.")], response_format=Dog
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
print(json.dumps(response.model_dump(), indent=2))
 | 
			
		||||
| 
						 | 
				
			
			@ -1,28 +1,30 @@
 | 
			
		|||
from dapr_agents.llm import HFHubChatClient
 | 
			
		||||
from dapr_agents.types import UserMessage
 | 
			
		||||
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
 | 
			
		||||
from dapr_agents.llm import HFHubChatClient
 | 
			
		||||
from dapr_agents.types import LLMChatResponse, UserMessage
 | 
			
		||||
 | 
			
		||||
load_dotenv()
 | 
			
		||||
 | 
			
		||||
# Basic chat completion
 | 
			
		||||
llm = HFHubChatClient(model="HuggingFaceTB/SmolLM3-3B")
 | 
			
		||||
response = llm.generate("Name a famous dog!")
 | 
			
		||||
response: LLMChatResponse = llm.generate("Name a famous dog!")
 | 
			
		||||
 | 
			
		||||
if len(response.get_content()) > 0:
 | 
			
		||||
    print("Response: ", response.get_content())
 | 
			
		||||
if response.get_message() is not None:
 | 
			
		||||
    print("Response: ", response.get_message().content)
 | 
			
		||||
 | 
			
		||||
# Chat completion using a prompty file for context
 | 
			
		||||
llm = HFHubChatClient.from_prompty("basic.prompty")
 | 
			
		||||
response = llm.generate(input_data={"question": "What is your name?"})
 | 
			
		||||
response: LLMChatResponse = llm.generate(input_data={"question": "What is your name?"})
 | 
			
		||||
 | 
			
		||||
if len(response.get_content()) > 0:
 | 
			
		||||
    print("Response with prompty: ", response.get_content())
 | 
			
		||||
if response.get_message() is not None:
 | 
			
		||||
    print("Response with prompty: ", response.get_message().content)
 | 
			
		||||
 | 
			
		||||
# Chat completion with user input
 | 
			
		||||
llm = HFHubChatClient(model="HuggingFaceTB/SmolLM3-3B")
 | 
			
		||||
response = llm.generate(messages=[UserMessage("hello")])
 | 
			
		||||
response: LLMChatResponse = llm.generate(messages=[UserMessage("hello")])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if len(response.get_content()) > 0 and "hello" in response.get_content().lower():
 | 
			
		||||
    print("Response with user input: ", response.get_content())
 | 
			
		||||
if (
 | 
			
		||||
    response.get_message() is not None
 | 
			
		||||
    and "hello" in response.get_message().content.lower()
 | 
			
		||||
):
 | 
			
		||||
    print("Response with user input: ", response.get_message().content)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,22 @@
 | 
			
		|||
from dotenv import load_dotenv
 | 
			
		||||
 | 
			
		||||
from dapr_agents import HFHubChatClient
 | 
			
		||||
from dapr_agents.types.message import LLMChatResponseChunk
 | 
			
		||||
from typing import Iterator
 | 
			
		||||
import logging
 | 
			
		||||
 | 
			
		||||
logging.basicConfig(level=logging.INFO)
 | 
			
		||||
 | 
			
		||||
# Load environment variables from .env
 | 
			
		||||
load_dotenv()
 | 
			
		||||
 | 
			
		||||
# Basic chat completion
 | 
			
		||||
llm = HFHubChatClient(model="HuggingFaceTB/SmolLM3-3B")
 | 
			
		||||
 | 
			
		||||
response: Iterator[LLMChatResponseChunk] = llm.generate(
 | 
			
		||||
    "Name a famous dog!", stream=True
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
for chunk in response:
 | 
			
		||||
    if chunk.result.content:
 | 
			
		||||
        print(chunk.result.content, end="", flush=True)
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,50 @@
 | 
			
		|||
from dotenv import load_dotenv
 | 
			
		||||
 | 
			
		||||
from dapr_agents import HFHubChatClient
 | 
			
		||||
from dapr_agents.types.message import LLMChatResponseChunk
 | 
			
		||||
from typing import Iterator
 | 
			
		||||
import logging
 | 
			
		||||
 | 
			
		||||
logging.basicConfig(level=logging.INFO)
 | 
			
		||||
 | 
			
		||||
# Load environment variables from .env
 | 
			
		||||
load_dotenv()
 | 
			
		||||
 | 
			
		||||
# Basic chat completion
 | 
			
		||||
llm = HFHubChatClient(model="HuggingFaceTB/SmolLM3-3B", hf_provider="auto")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Define a tool for addition
 | 
			
		||||
def add_numbers(a: int, b: int) -> int:
 | 
			
		||||
    return a + b
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Define the tool function call schema
 | 
			
		||||
add_tool = {
 | 
			
		||||
    "type": "function",
 | 
			
		||||
    "function": {
 | 
			
		||||
        "name": "add_numbers",
 | 
			
		||||
        "description": "Add two numbers together.",
 | 
			
		||||
        "parameters": {
 | 
			
		||||
            "type": "object",
 | 
			
		||||
            "properties": {
 | 
			
		||||
                "a": {"type": "integer", "description": "The first number."},
 | 
			
		||||
                "b": {"type": "integer", "description": "The second number."},
 | 
			
		||||
            },
 | 
			
		||||
            "required": ["a", "b"],
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
# Define messages for the chat
 | 
			
		||||
messages = [
 | 
			
		||||
    {"role": "system", "content": "You are a helpful assistant."},
 | 
			
		||||
    {"role": "user", "content": "Add 5 and 7 and 2 and 2."},
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
response: Iterator[LLMChatResponseChunk] = llm.generate(
 | 
			
		||||
    messages=messages, tools=[add_tool], stream=True
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
for chunk in response:
 | 
			
		||||
    print(chunk.result)
 | 
			
		||||
| 
						 | 
				
			
			@ -57,34 +57,34 @@ python text_completion.py
 | 
			
		|||
The script demonstrates basic usage of Dapr Agents' NVIDIAChatClient for text generation:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
from dapr_agents import NVIDIAChatClient
 | 
			
		||||
from dapr_agents.types import UserMessage
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
 | 
			
		||||
from dapr_agents import NVIDIAChatClient
 | 
			
		||||
from dapr_agents.types import LLMChatResponse, UserMessage
 | 
			
		||||
 | 
			
		||||
# Load environment variables from .env
 | 
			
		||||
load_dotenv()
 | 
			
		||||
 | 
			
		||||
# Basic chat completion
 | 
			
		||||
llm = NVIDIAChatClient()
 | 
			
		||||
response = llm.generate("Name a famous dog!")
 | 
			
		||||
response: LLMChatResponse = llm.generate("Name a famous dog!")
 | 
			
		||||
 | 
			
		||||
if len(response.get_content()) > 0:
 | 
			
		||||
    print("Response: ", response.get_content())
 | 
			
		||||
if response.get_message() is not None:
 | 
			
		||||
    print("Response: ", response.get_message().content)
 | 
			
		||||
 | 
			
		||||
# Chat completion using a prompty file for context
 | 
			
		||||
llm = NVIDIAChatClient.from_prompty('basic.prompty')
 | 
			
		||||
response = llm.generate(input_data={"question":"What is your name?"})
 | 
			
		||||
llm = NVIDIAChatClient.from_prompty("basic.prompty")
 | 
			
		||||
response: LLMChatResponse = llm.generate(input_data={"question": "What is your name?"})
 | 
			
		||||
 | 
			
		||||
if len(response.get_content()) > 0:
 | 
			
		||||
    print("Response with prompty: ", response.get_content())
 | 
			
		||||
if response.get_message() is not None:
 | 
			
		||||
    print("Response with prompty: ", response.get_message().content)
 | 
			
		||||
 | 
			
		||||
# Chat completion with user input
 | 
			
		||||
llm = NVIDIAChatClient()
 | 
			
		||||
response = llm.generate(messages=[UserMessage("hello")])
 | 
			
		||||
response: LLMChatResponse = llm.generate(messages=[UserMessage("hello")])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if len(response.get_content()) > 0 and "hello" in response.get_content().lower():
 | 
			
		||||
    print("Response with user input: ", response.get_content())
 | 
			
		||||
if response.get_message() is not None and "hello" in response.get_message().content.lower():
 | 
			
		||||
    print("Response with user input: ", response.get_message().content)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
**2. Expected output:** The LLM will respond with the name of a famous dog (e.g., "Lassie", "Hachiko", etc.).
 | 
			
		||||
| 
						 | 
				
			
			@ -110,32 +110,30 @@ This example shows how to use Pydantic models to get structured data from LLMs:
 | 
			
		|||
```python
 | 
			
		||||
import json
 | 
			
		||||
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
from pydantic import BaseModel
 | 
			
		||||
 | 
			
		||||
from dapr_agents import NVIDIAChatClient
 | 
			
		||||
from dapr_agents.types import UserMessage
 | 
			
		||||
from pydantic import BaseModel
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
 | 
			
		||||
# Load environment variables from .env
 | 
			
		||||
load_dotenv()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Define our data model
 | 
			
		||||
class Dog(BaseModel):
 | 
			
		||||
    name: str
 | 
			
		||||
    breed: str
 | 
			
		||||
    reason: str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Initialize the chat client
 | 
			
		||||
llm = NVIDIAChatClient(
 | 
			
		||||
    model="meta/llama-3.1-8b-instruct"
 | 
			
		||||
)
 | 
			
		||||
llm = NVIDIAChatClient(model="meta/llama-3.1-8b-instruct")
 | 
			
		||||
 | 
			
		||||
# Get structured response
 | 
			
		||||
response = llm.generate(
 | 
			
		||||
    messages=[UserMessage("One famous dog in history.")],
 | 
			
		||||
    response_format=Dog
 | 
			
		||||
response: Dog = llm.generate(
 | 
			
		||||
    messages=[UserMessage("One famous dog in history.")], response_format=Dog
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
print(json.dumps(response.model_dump(), indent=2))
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
**Expected output:** A structured Dog object with name, breed, and reason fields (e.g., `Dog(name='Hachiko', breed='Akita', reason='Known for his remarkable loyalty...')`)
 | 
			
		||||
| 
						 | 
				
			
			@ -155,4 +153,102 @@ output_match_mode: substring
 | 
			
		|||
```bash
 | 
			
		||||
python embeddings.py
 | 
			
		||||
```
 | 
			
		||||
<!-- END_STEP -->
 | 
			
		||||
<!-- END_STEP -->
 | 
			
		||||
 | 
			
		||||
### Streaming

Our NVIDIA chat client also supports streaming responses, where you can process partial results as they arrive. Below are two examples:

**1. Basic Streaming Example**

Run the `text_completion_stream.py` script to see token‐by‐token output:

```bash
python text_completion_stream.py
```

The script:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
 | 
			
		||||
from dapr_agents import NVIDIAChatClient
 | 
			
		||||
from dapr_agents.types.message import LLMChatResponseChunk
 | 
			
		||||
from typing import Iterator
 | 
			
		||||
import logging
 | 
			
		||||
 | 
			
		||||
logging.basicConfig(level=logging.INFO)
 | 
			
		||||
 | 
			
		||||
# Load environment variables from .env
 | 
			
		||||
load_dotenv()
 | 
			
		||||
 | 
			
		||||
# Basic chat completion
 | 
			
		||||
llm = NVIDIAChatClient()
 | 
			
		||||
 | 
			
		||||
response: Iterator[LLMChatResponseChunk] = llm.generate("Name a famous dog!", stream=True)
 | 
			
		||||
 | 
			
		||||
for chunk in response:
 | 
			
		||||
    if chunk.result.content:
 | 
			
		||||
        print(chunk.result.content, end="", flush=True)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
This will print each partial chunk as it arrives, so you can build up the full answer in real time.
 | 
			
		||||
 | 
			
		||||
**2. Streaming with Tool Calls:**
 | 
			
		||||
 | 
			
		||||
Use `text_completion_with_tools.py` to combine streaming with function‐call “tools”:
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
python text_completion_with_tools.py
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
 | 
			
		||||
from dapr_agents import NVIDIAChatClient
 | 
			
		||||
from dapr_agents.types.message import LLMChatResponseChunk
 | 
			
		||||
from typing import Iterator
 | 
			
		||||
import logging
 | 
			
		||||
 | 
			
		||||
logging.basicConfig(level=logging.INFO)
 | 
			
		||||
 | 
			
		||||
# Load environment variables from .env
 | 
			
		||||
load_dotenv()
 | 
			
		||||
 | 
			
		||||
# Basic chat completion
 | 
			
		||||
llm = NVIDIAChatClient()
 | 
			
		||||
 | 
			
		||||
# Define a tool for addition
 | 
			
		||||
def add_numbers(a: int, b: int) -> int:
 | 
			
		||||
    return a + b
 | 
			
		||||
 | 
			
		||||
# Define the tool function call schema
 | 
			
		||||
add_tool = {
 | 
			
		||||
    "type": "function",
 | 
			
		||||
    "function": {
 | 
			
		||||
        "name": "add_numbers",
 | 
			
		||||
        "description": "Add two numbers together.",
 | 
			
		||||
        "parameters": {
 | 
			
		||||
            "type": "object",
 | 
			
		||||
            "properties": {
 | 
			
		||||
                "a": {"type": "integer", "description": "The first number."},
 | 
			
		||||
                "b": {"type": "integer", "description": "The second number."}
 | 
			
		||||
            },
 | 
			
		||||
            "required": ["a", "b"]
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
# Define messages for the chat
 | 
			
		||||
messages = [
 | 
			
		||||
    {"role": "system", "content": "You are a helpful assistant."},
 | 
			
		||||
    {"role": "user", "content": "Add 5 and 7 and 2 and 2."}
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
response: Iterator[LLMChatResponseChunk] = llm.generate(messages=messages, tools=[add_tool], stream=True)
 | 
			
		||||
 | 
			
		||||
for chunk in response:
 | 
			
		||||
    print(chunk.result)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Here, the model can decide to call your `add_numbers` function mid‐stream, and you’ll see those calls (and their results) as they come in.
 | 
			
		||||
| 
						 | 
				
			
			@ -1,9 +1,10 @@
 | 
			
		|||
import json
 | 
			
		||||
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
from pydantic import BaseModel
 | 
			
		||||
 | 
			
		||||
from dapr_agents import NVIDIAChatClient
 | 
			
		||||
from dapr_agents.types import UserMessage
 | 
			
		||||
from pydantic import BaseModel
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
 | 
			
		||||
# Load environment variables from .env
 | 
			
		||||
load_dotenv()
 | 
			
		||||
| 
						 | 
				
			
			@ -20,7 +21,7 @@ class Dog(BaseModel):
 | 
			
		|||
llm = NVIDIAChatClient(model="meta/llama-3.1-8b-instruct")
 | 
			
		||||
 | 
			
		||||
# Get structured response
 | 
			
		||||
response = llm.generate(
 | 
			
		||||
response: Dog = llm.generate(
 | 
			
		||||
    messages=[UserMessage("One famous dog in history.")], response_format=Dog
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,28 +1,31 @@
 | 
			
		|||
from dapr_agents import NVIDIAChatClient
 | 
			
		||||
from dapr_agents.types import UserMessage
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
 | 
			
		||||
from dapr_agents import NVIDIAChatClient
 | 
			
		||||
from dapr_agents.types import LLMChatResponse, UserMessage
 | 
			
		||||
 | 
			
		||||
# Load environment variables from .env
 | 
			
		||||
load_dotenv()
 | 
			
		||||
 | 
			
		||||
# Basic chat completion
 | 
			
		||||
llm = NVIDIAChatClient()
 | 
			
		||||
response = llm.generate("Name a famous dog!")
 | 
			
		||||
response: LLMChatResponse = llm.generate("Name a famous dog!")
 | 
			
		||||
 | 
			
		||||
if len(response.get_content()) > 0:
 | 
			
		||||
    print("Response: ", response.get_content())
 | 
			
		||||
if response.get_message() is not None:
 | 
			
		||||
    print("Response: ", response.get_message().content)
 | 
			
		||||
 | 
			
		||||
# Chat completion using a prompty file for context
 | 
			
		||||
llm = NVIDIAChatClient.from_prompty("basic.prompty")
 | 
			
		||||
response = llm.generate(input_data={"question": "What is your name?"})
 | 
			
		||||
response: LLMChatResponse = llm.generate(input_data={"question": "What is your name?"})
 | 
			
		||||
 | 
			
		||||
if len(response.get_content()) > 0:
 | 
			
		||||
    print("Response with prompty: ", response.get_content())
 | 
			
		||||
if response.get_message() is not None:
 | 
			
		||||
    print("Response with prompty: ", response.get_message().content)
 | 
			
		||||
 | 
			
		||||
# Chat completion with user input
 | 
			
		||||
llm = NVIDIAChatClient()
 | 
			
		||||
response = llm.generate(messages=[UserMessage("hello")])
 | 
			
		||||
response: LLMChatResponse = llm.generate(messages=[UserMessage("hello")])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if len(response.get_content()) > 0 and "hello" in response.get_content().lower():
 | 
			
		||||
    print("Response with user input: ", response.get_content())
 | 
			
		||||
if (
 | 
			
		||||
    response.get_message() is not None
 | 
			
		||||
    and "hello" in response.get_message().content.lower()
 | 
			
		||||
):
 | 
			
		||||
    print("Response with user input: ", response.get_message().content)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,22 @@
 | 
			
		|||
from dotenv import load_dotenv
 | 
			
		||||
 | 
			
		||||
from dapr_agents import NVIDIAChatClient
 | 
			
		||||
from dapr_agents.types.message import LLMChatResponseChunk
 | 
			
		||||
from typing import Iterator
 | 
			
		||||
import logging
 | 
			
		||||
 | 
			
		||||
logging.basicConfig(level=logging.INFO)
 | 
			
		||||
 | 
			
		||||
# Load environment variables from .env
 | 
			
		||||
load_dotenv()
 | 
			
		||||
 | 
			
		||||
# Basic chat completion
 | 
			
		||||
llm = NVIDIAChatClient()
 | 
			
		||||
 | 
			
		||||
response: Iterator[LLMChatResponseChunk] = llm.generate(
 | 
			
		||||
    "Name a famous dog!", stream=True
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
for chunk in response:
 | 
			
		||||
    if chunk.result.content:
 | 
			
		||||
        print(chunk.result.content, end="", flush=True)
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,50 @@
 | 
			
		|||
from dotenv import load_dotenv
 | 
			
		||||
 | 
			
		||||
from dapr_agents import NVIDIAChatClient
 | 
			
		||||
from dapr_agents.types.message import LLMChatResponseChunk
 | 
			
		||||
from typing import Iterator
 | 
			
		||||
import logging
 | 
			
		||||
 | 
			
		||||
logging.basicConfig(level=logging.INFO)
 | 
			
		||||
 | 
			
		||||
# Load environment variables from .env
 | 
			
		||||
load_dotenv()
 | 
			
		||||
 | 
			
		||||
# Basic chat completion
 | 
			
		||||
llm = NVIDIAChatClient()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Define a tool for addition
 | 
			
		||||
def add_numbers(a: int, b: int) -> int:
 | 
			
		||||
    return a + b
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Define the tool function call schema
 | 
			
		||||
add_tool = {
 | 
			
		||||
    "type": "function",
 | 
			
		||||
    "function": {
 | 
			
		||||
        "name": "add_numbers",
 | 
			
		||||
        "description": "Add two numbers together.",
 | 
			
		||||
        "parameters": {
 | 
			
		||||
            "type": "object",
 | 
			
		||||
            "properties": {
 | 
			
		||||
                "a": {"type": "integer", "description": "The first number."},
 | 
			
		||||
                "b": {"type": "integer", "description": "The second number."},
 | 
			
		||||
            },
 | 
			
		||||
            "required": ["a", "b"],
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
# Define messages for the chat
 | 
			
		||||
messages = [
 | 
			
		||||
    {"role": "system", "content": "You are a helpful assistant."},
 | 
			
		||||
    {"role": "user", "content": "Add 5 and 7 and 2 and 2."},
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
response: Iterator[LLMChatResponseChunk] = llm.generate(
 | 
			
		||||
    messages=messages, tools=[add_tool], stream=True
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
for chunk in response:
 | 
			
		||||
    print(chunk.result)
 | 
			
		||||
| 
						 | 
				
			
			@ -57,34 +57,34 @@ python text_completion.py
 | 
			
		|||
The script demonstrates basic usage of Dapr Agents' OpenAIChatClient for text generation:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
from dapr_agents import OpenAIChatClient
 | 
			
		||||
from dapr_agents.types import UserMessage
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
 | 
			
		||||
from dapr_agents import OpenAIChatClient
 | 
			
		||||
from dapr_agents.types import LLMChatResponse, UserMessage
 | 
			
		||||
 | 
			
		||||
# Load environment variables from .env
 | 
			
		||||
load_dotenv()
 | 
			
		||||
 | 
			
		||||
# Basic chat completion
 | 
			
		||||
llm = OpenAIChatClient()
 | 
			
		||||
response = llm.generate("Name a famous dog!")
 | 
			
		||||
response: LLMChatResponse = llm.generate("Name a famous dog!")
 | 
			
		||||
 | 
			
		||||
if len(response.get_content()) > 0:
 | 
			
		||||
    print("Response: ", response.get_content())
 | 
			
		||||
if response.get_message() is not None:
 | 
			
		||||
    print("Response: ", response.get_message().content)
 | 
			
		||||
 | 
			
		||||
# Chat completion using a prompty file for context
 | 
			
		||||
llm = OpenAIChatClient.from_prompty('basic.prompty')
 | 
			
		||||
response = llm.generate(input_data={"question":"What is your name?"})
 | 
			
		||||
llm = OpenAIChatClient.from_prompty("basic.prompty")
 | 
			
		||||
response: LLMChatResponse = llm.generate(input_data={"question": "What is your name?"})
 | 
			
		||||
 | 
			
		||||
if len(response.get_content()) > 0:
 | 
			
		||||
    print("Response with prompty: ", response.get_content())
 | 
			
		||||
if response.get_message() is not None:
 | 
			
		||||
    print("Response with prompty: ", response.get_message().content)
 | 
			
		||||
 | 
			
		||||
# Chat completion with user input
 | 
			
		||||
llm = OpenAIChatClient()
 | 
			
		||||
response = llm.generate(messages=[UserMessage("hello")])
 | 
			
		||||
response: LLMChatResponse = llm.generate(messages=[UserMessage("hello")])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if len(response.get_content()) > 0 and "hello" in response.get_content().lower():
 | 
			
		||||
    print("Response with user input: ", response.get_content())
 | 
			
		||||
if response.get_message() is not None and "hello" in response.get_message().content.lower():
 | 
			
		||||
    print("Response with user input: ", response.get_message().content)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
**2. Expected output:** The LLM will respond with the name of a famous dog (e.g., "Lassie", "Hachiko", etc.).
 | 
			
		||||
| 
						 | 
				
			
			@ -110,27 +110,29 @@ This example shows how to use Pydantic models to get structured data from LLMs:
 | 
			
		|||
```python
 | 
			
		||||
import json
 | 
			
		||||
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
from pydantic import BaseModel
 | 
			
		||||
 | 
			
		||||
from dapr_agents import OpenAIChatClient
 | 
			
		||||
from dapr_agents.types import UserMessage
 | 
			
		||||
from pydantic import BaseModel
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
 | 
			
		||||
# Load environment variables from .env
 | 
			
		||||
load_dotenv()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Define our data model
 | 
			
		||||
class Dog(BaseModel):
 | 
			
		||||
    name: str
 | 
			
		||||
    breed: str
 | 
			
		||||
    reason: str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Initialize the chat client
 | 
			
		||||
llm = OpenAIChatClient()
 | 
			
		||||
 | 
			
		||||
# Get structured response
 | 
			
		||||
response = llm.generate(
 | 
			
		||||
    messages=[UserMessage("One famous dog in history.")],
 | 
			
		||||
    response_format=Dog
 | 
			
		||||
response: Dog = llm.generate(
 | 
			
		||||
    messages=[UserMessage("One famous dog in history.")], response_format=Dog
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
print(json.dumps(response.model_dump(), indent=2))
 | 
			
		||||
| 
						 | 
				
			
			@ -138,6 +140,104 @@ print(json.dumps(response.model_dump(), indent=2))
 | 
			
		|||
 | 
			
		||||
**Expected output:** A structured Dog object with name, breed, and reason fields (e.g., `Dog(name='Hachiko', breed='Akita', reason='Known for his remarkable loyalty...')`)
 | 
			
		||||
 | 
			
		||||
### Streaming

Our OpenAI chat client also supports streaming responses, where you can process partial results as they arrive. Below are two examples:

**1. Basic Streaming Example**

Run the `text_completion_stream.py` script to see token‐by‐token output:

```bash
python text_completion_stream.py
```

The script:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
 | 
			
		||||
from dapr_agents import OpenAIChatClient
 | 
			
		||||
from dapr_agents.types.message import LLMChatResponseChunk
 | 
			
		||||
from typing import Iterator
 | 
			
		||||
import logging
 | 
			
		||||
 | 
			
		||||
logging.basicConfig(level=logging.INFO)
 | 
			
		||||
 | 
			
		||||
# Load environment variables from .env
 | 
			
		||||
load_dotenv()
 | 
			
		||||
 | 
			
		||||
# Basic chat completion
 | 
			
		||||
llm = OpenAIChatClient()
 | 
			
		||||
 | 
			
		||||
response: Iterator[LLMChatResponseChunk] = llm.generate("Name a famous dog!", stream=True)
 | 
			
		||||
 | 
			
		||||
for chunk in response:
 | 
			
		||||
    if chunk.result.content:
 | 
			
		||||
        print(chunk.result.content, end="", flush=True)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
This will print each partial chunk as it arrives, so you can build up the full answer in real time.
 | 
			
		||||
 | 
			
		||||
**2. Streaming with Tool Calls:**
 | 
			
		||||
 | 
			
		||||
Use `text_completion_with_tools.py` to combine streaming with function‐call “tools”:
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
python text_completion_with_tools.py
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
 | 
			
		||||
from dapr_agents import OpenAIChatClient
 | 
			
		||||
from dapr_agents.types.message import LLMChatResponseChunk
 | 
			
		||||
from typing import Iterator
 | 
			
		||||
import logging
 | 
			
		||||
 | 
			
		||||
logging.basicConfig(level=logging.INFO)
 | 
			
		||||
 | 
			
		||||
# Load environment variables from .env
 | 
			
		||||
load_dotenv()
 | 
			
		||||
 | 
			
		||||
# Basic chat completion
 | 
			
		||||
llm = OpenAIChatClient()
 | 
			
		||||
 | 
			
		||||
# Define a tool for addition
 | 
			
		||||
def add_numbers(a: int, b: int) -> int:
 | 
			
		||||
    return a + b
 | 
			
		||||
 | 
			
		||||
# Define the tool function call schema
 | 
			
		||||
add_tool = {
 | 
			
		||||
    "type": "function",
 | 
			
		||||
    "function": {
 | 
			
		||||
        "name": "add_numbers",
 | 
			
		||||
        "description": "Add two numbers together.",
 | 
			
		||||
        "parameters": {
 | 
			
		||||
            "type": "object",
 | 
			
		||||
            "properties": {
 | 
			
		||||
                "a": {"type": "integer", "description": "The first number."},
 | 
			
		||||
                "b": {"type": "integer", "description": "The second number."}
 | 
			
		||||
            },
 | 
			
		||||
            "required": ["a", "b"]
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
# Define messages for the chat
 | 
			
		||||
messages = [
 | 
			
		||||
    {"role": "system", "content": "You are a helpful assistant."},
 | 
			
		||||
    {"role": "user", "content": "Add 5 and 7 and 2 and 2."}
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
response: Iterator[LLMChatResponseChunk] = llm.generate(messages=messages, tools=[add_tool], stream=True)
 | 
			
		||||
 | 
			
		||||
for chunk in response:
 | 
			
		||||
    print(chunk.result)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Here, the model can decide to call your `add_numbers` function mid‐stream, and you’ll see those calls (and their results) as they come in.
 | 
			
		||||
 | 
			
		||||
### Audio
 | 
			
		||||
You can use the `OpenAIAudioClient` in `dapr-agents` for basic tasks with the OpenAI Audio API. We will explore:
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -151,8 +251,8 @@ You can use the OpenAIAudioClient in `dapr-agents` for basic tasks with the Open
 | 
			
		|||
<!-- STEP
 | 
			
		||||
name: Run audio generation example
 | 
			
		||||
expected_stdout_lines:
 | 
			
		||||
  - "Audio saved to output_speech.mp3"
 | 
			
		||||
  - "File output_speech.mp3 has been deleted."
 | 
			
		||||
  - "Audio saved to speech.mp3"
 | 
			
		||||
  - "File speech.mp3 has been deleted."
 | 
			
		||||
-->
 | 
			
		||||
```bash
 | 
			
		||||
python text_to_speech.py
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,9 +1,10 @@
 | 
			
		|||
import json
 | 
			
		||||
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
from pydantic import BaseModel
 | 
			
		||||
 | 
			
		||||
from dapr_agents import OpenAIChatClient
 | 
			
		||||
from dapr_agents.types import UserMessage
 | 
			
		||||
from pydantic import BaseModel
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
 | 
			
		||||
# Load environment variables from .env
 | 
			
		||||
load_dotenv()
 | 
			
		||||
| 
						 | 
				
			
			@ -20,7 +21,7 @@ class Dog(BaseModel):
 | 
			
		|||
llm = OpenAIChatClient()
 | 
			
		||||
 | 
			
		||||
# Get structured response
 | 
			
		||||
response = llm.generate(
 | 
			
		||||
response: Dog = llm.generate(
 | 
			
		||||
    messages=[UserMessage("One famous dog in history.")], response_format=Dog
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,28 +1,31 @@
 | 
			
		|||
from dapr_agents import OpenAIChatClient
 | 
			
		||||
from dapr_agents.types import UserMessage
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
 | 
			
		||||
from dapr_agents import OpenAIChatClient
 | 
			
		||||
from dapr_agents.types import LLMChatResponse, UserMessage
 | 
			
		||||
 | 
			
		||||
# Load environment variables from .env
 | 
			
		||||
load_dotenv()
 | 
			
		||||
 | 
			
		||||
# Basic chat completion
 | 
			
		||||
llm = OpenAIChatClient()
 | 
			
		||||
response = llm.generate("Name a famous dog!")
 | 
			
		||||
response: LLMChatResponse = llm.generate("Name a famous dog!")
 | 
			
		||||
 | 
			
		||||
if len(response.get_content()) > 0:
 | 
			
		||||
    print("Response: ", response.get_content())
 | 
			
		||||
if response.get_message() is not None:
 | 
			
		||||
    print("Response: ", response.get_message().content)
 | 
			
		||||
 | 
			
		||||
# Chat completion using a prompty file for context
 | 
			
		||||
llm = OpenAIChatClient.from_prompty("basic.prompty")
 | 
			
		||||
response = llm.generate(input_data={"question": "What is your name?"})
 | 
			
		||||
response: LLMChatResponse = llm.generate(input_data={"question": "What is your name?"})
 | 
			
		||||
 | 
			
		||||
if len(response.get_content()) > 0:
 | 
			
		||||
    print("Response with prompty: ", response.get_content())
 | 
			
		||||
if response.get_message() is not None:
 | 
			
		||||
    print("Response with prompty: ", response.get_message().content)
 | 
			
		||||
 | 
			
		||||
# Chat completion with user input
 | 
			
		||||
llm = OpenAIChatClient()
 | 
			
		||||
response = llm.generate(messages=[UserMessage("hello")])
 | 
			
		||||
response: LLMChatResponse = llm.generate(messages=[UserMessage("hello")])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if len(response.get_content()) > 0 and "hello" in response.get_content().lower():
 | 
			
		||||
    print("Response with user input: ", response.get_content())
 | 
			
		||||
if (
 | 
			
		||||
    response.get_message() is not None
 | 
			
		||||
    and "hello" in response.get_message().content.lower()
 | 
			
		||||
):
 | 
			
		||||
    print("Response with user input: ", response.get_message().content)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,22 @@
 | 
			
		|||
from dotenv import load_dotenv
 | 
			
		||||
 | 
			
		||||
from dapr_agents import OpenAIChatClient
 | 
			
		||||
from dapr_agents.types.message import LLMChatResponseChunk
 | 
			
		||||
from typing import Iterator
 | 
			
		||||
import logging
 | 
			
		||||
 | 
			
		||||
logging.basicConfig(level=logging.INFO)
 | 
			
		||||
 | 
			
		||||
# Load environment variables from .env
 | 
			
		||||
load_dotenv()
 | 
			
		||||
 | 
			
		||||
# Basic chat completion
 | 
			
		||||
llm = OpenAIChatClient()
 | 
			
		||||
 | 
			
		||||
response: Iterator[LLMChatResponseChunk] = llm.generate(
 | 
			
		||||
    "Name a famous dog!", stream=True
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
for chunk in response:
 | 
			
		||||
    if chunk.result.content:
 | 
			
		||||
        print(chunk.result.content, end="", flush=True)
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,50 @@
 | 
			
		|||
from dotenv import load_dotenv
 | 
			
		||||
 | 
			
		||||
from dapr_agents import OpenAIChatClient
 | 
			
		||||
from dapr_agents.types.message import LLMChatResponseChunk
 | 
			
		||||
from typing import Iterator
 | 
			
		||||
import logging
 | 
			
		||||
 | 
			
		||||
logging.basicConfig(level=logging.INFO)
 | 
			
		||||
 | 
			
		||||
# Load environment variables from .env
 | 
			
		||||
load_dotenv()
 | 
			
		||||
 | 
			
		||||
# Basic chat completion
 | 
			
		||||
llm = OpenAIChatClient()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Define a tool for addition
 | 
			
		||||
def add_numbers(a: int, b: int) -> int:
 | 
			
		||||
    return a + b
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Define the tool function call schema
 | 
			
		||||
add_tool = {
 | 
			
		||||
    "type": "function",
 | 
			
		||||
    "function": {
 | 
			
		||||
        "name": "add_numbers",
 | 
			
		||||
        "description": "Add two numbers together.",
 | 
			
		||||
        "parameters": {
 | 
			
		||||
            "type": "object",
 | 
			
		||||
            "properties": {
 | 
			
		||||
                "a": {"type": "integer", "description": "The first number."},
 | 
			
		||||
                "b": {"type": "integer", "description": "The second number."},
 | 
			
		||||
            },
 | 
			
		||||
            "required": ["a", "b"],
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
# Define messages for the chat
 | 
			
		||||
messages = [
 | 
			
		||||
    {"role": "system", "content": "You are a helpful assistant."},
 | 
			
		||||
    {"role": "user", "content": "Add 5 and 7 and 2 and 2."},
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
response: Iterator[LLMChatResponseChunk] = llm.generate(
 | 
			
		||||
    messages=messages, tools=[add_tool], stream=True
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
for chunk in response:
 | 
			
		||||
    print(chunk.result)
 | 
			
		||||
| 
						 | 
				
			
			@ -1,5 +1,3 @@
 | 
			
		|||
import os
 | 
			
		||||
 | 
			
		||||
from dapr_agents.types.llm import AudioSpeechRequest
 | 
			
		||||
from dapr_agents import OpenAIAudioClient
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
| 
						 | 
				
			
			@ -24,11 +22,8 @@ audio_bytes = client.create_speech(request=tts_request)
 | 
			
		|||
# client.create_speech(request=tts_request, file_name=output_path)
 | 
			
		||||
 | 
			
		||||
# Save the audio to an MP3 file
 | 
			
		||||
output_path = "output_speech.mp3"
 | 
			
		||||
output_path = "speech.mp3"
 | 
			
		||||
with open(output_path, "wb") as audio_file:
 | 
			
		||||
    audio_file.write(audio_bytes)
 | 
			
		||||
 | 
			
		||||
print(f"Audio saved to {output_path}")
 | 
			
		||||
 | 
			
		||||
os.remove(output_path)
 | 
			
		||||
print(f"File {output_path} has been deleted.")
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -129,6 +129,10 @@ python weather_agent.py
 | 
			
		|||
 | 
			
		||||
**Expected output:** The agent will identify the locations and use the `get_weather` tool to fetch weather information for each city.
 | 
			
		||||
 | 
			
		||||
**Other Examples:** You can also try the following agents with the same tools using the `HuggingFace Hub` and `NVIDIA` LLM chat clients. Make sure you add `HUGGINGFACE_API_KEY` and `NVIDIA_API_KEY` to the `.env` file.
 | 
			
		||||
- [HuggingFace Agent](./weather_agent_hf.py)
 | 
			
		||||
- [NVIDIA Agent](./weather_agent_nv.py)
 | 
			
		||||
 | 
			
		||||
## Key Concepts
 | 
			
		||||
 | 
			
		||||
### Tool Definition
 | 
			
		||||
| 
						 | 
				
			
			@ -244,24 +248,7 @@ Dapr Agents provides two agent implementations, each designed for different use
 | 
			
		|||
The default agent type, designed for tool execution and straightforward interactions. It receives your input, determines which tools to use, executes them directly, and provides the final answer. The reasoning process is mostly hidden from you, focusing instead on delivering concise responses.
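
For quick reference, here is a minimal sketch of constructing this default agent type; it mirrors the weather agent scripts in this quickstart and assumes the local `weather_tools` module that exports the tool list:

```python
import asyncio

from dotenv import load_dotenv

from dapr_agents import Agent
from weather_tools import tools  # assumed: the tool module shipped with this quickstart

load_dotenv()

weather_agent = Agent(
    name="Stevie",
    role="Weather Assistant",
    goal="Assist Humans with weather related tasks.",
    instructions=["Always answer the user's main weather question directly and clearly."],
    tools=tools,
)


async def main():
    # run() is async and returns an AssistantMessage; .content holds the final answer.
    result = await weather_agent.run("What is the weather in New York?")
    print(result.content)


if __name__ == "__main__":
    asyncio.run(main())
```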
 | 
			
		||||
 | 
			
		||||
### 2. DurableAgent
 | 
			
		||||
The DurableAgent class is a workflow-based agent that extends the standard Agent with Dapr Workflows for long-running, fault-tolerant, and durable execution. It provides persistent state management, automatic retry mechanisms, and deterministic execution across failures. We will see this agent in the next example.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
from dapr_agents import Agent
 | 
			
		||||
from dapr_agents.tool.utils.openapi import OpenAPISpecParser
 | 
			
		||||
from dapr_agents.storage import VectorStore
 | 
			
		||||
 | 
			
		||||
# This agent type requires additional components
 | 
			
		||||
openapi_agent = Agent(
 | 
			
		||||
    name="APIExpert",
 | 
			
		||||
    role="API Expert",
 | 
			
		||||
    pattern="openapireact",  # Specify OpenAPIReAct pattern
 | 
			
		||||
    spec_parser=OpenAPISpecParser(),
 | 
			
		||||
    api_vector_store=VectorStore(),
 | 
			
		||||
    auth_header={"Authorization": "Bearer token"}
 | 
			
		||||
)
 | 
			
		||||
```
 | 
			
		||||
The DurableAgent class is a workflow-based agent that extends the standard Agent with Dapr Workflows for long-running, fault-tolerant, and durable execution. It provides persistent state management, automatic retry mechanisms, and deterministic execution across failures. We will see this agent in the next example: [Durable Agent](../03-durable-agent-tool-call/).
 | 
			
		||||
 | 
			
		||||
## Troubleshooting
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,30 @@
 | 
			
		|||
import asyncio
 | 
			
		||||
from weather_tools import tools
 | 
			
		||||
from dapr_agents import Agent, HFHubChatClient
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
 | 
			
		||||
load_dotenv()
 | 
			
		||||
 | 
			
		||||
llm = HFHubChatClient(model="HuggingFaceTB/SmolLM3-3B")
 | 
			
		||||
 | 
			
		||||
AIAgent = Agent(
 | 
			
		||||
    name="Stevie",
 | 
			
		||||
    role="Weather Assistant",
 | 
			
		||||
    goal="Assist Humans with weather related tasks.",
 | 
			
		||||
    instructions=[
 | 
			
		||||
        "Always answer the user's main weather question directly and clearly.",
 | 
			
		||||
        "If you perform any additional actions (like jumping), summarize those actions and their results.",
 | 
			
		||||
        "At the end, provide a concise summary that combines the weather information for all requested locations and any other actions you performed.",
 | 
			
		||||
    ],
 | 
			
		||||
    llm=llm,
 | 
			
		||||
    tools=tools,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Wrap your async call
 | 
			
		||||
async def main():
 | 
			
		||||
    await AIAgent.run("What is the weather in Virginia, New York and Washington DC?")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    asyncio.run(main())
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,30 @@
 | 
			
		|||
import asyncio
 | 
			
		||||
from weather_tools import tools
 | 
			
		||||
from dapr_agents import Agent, NVIDIAChatClient
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
 | 
			
		||||
load_dotenv()
 | 
			
		||||
 | 
			
		||||
llm = NVIDIAChatClient(model="meta/llama-3.1-8b-instruct")
 | 
			
		||||
 | 
			
		||||
AIAgent = Agent(
 | 
			
		||||
    name="Stevie",
 | 
			
		||||
    role="Weather Assistant",
 | 
			
		||||
    goal="Assist Humans with weather related tasks.",
 | 
			
		||||
    instructions=[
 | 
			
		||||
        "Always answer the user's main weather question directly and clearly.",
 | 
			
		||||
        "If you perform any additional actions (like jumping), summarize those actions and their results.",
 | 
			
		||||
        "At the end, provide a concise summary that combines the weather information for all requested locations and any other actions you performed.",
 | 
			
		||||
    ],
 | 
			
		||||
    llm=llm,
 | 
			
		||||
    tools=tools,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Wrap your async call
 | 
			
		||||
async def main():
 | 
			
		||||
    await AIAgent.run("What is the weather in Virginia, New York and Washington DC?")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    asyncio.run(main())
 | 
			
		||||
| 
						 | 
				
			
			@ -87,6 +87,11 @@ start the agent with Dapr:
 | 
			
		|||
dapr run --app-id durableweatherapp --resources-path ./components -- python durable_weather_agent.py
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## Other Durable Agents
 | 
			
		||||
You can also try the following durable agents with the same tools using the `HuggingFace Hub` and `NVIDIA` LLM chat clients. Make sure you add `HUGGINGFACE_API_KEY` and `NVIDIA_API_KEY` to the `.env` file.
 | 
			
		||||
- [HuggingFace Durable Agent](./durable_weather_agent_hf.py)
 | 
			
		||||
- [NVIDIA Durable Agent](./durable_weather_agent_nv.py)
 | 
			
		||||
 | 
			
		||||
## About Durable Agents
 | 
			
		||||
 | 
			
		||||
Durable agents maintain state across runs, enabling workflows that require persistence, recovery, and coordination. This is useful for long-running tasks, multi-step workflows, and agent collaboration.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,37 @@
 | 
			
		|||
from dapr_agents import DurableAgent, HFHubChatClient
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
from weather_tools import tools
 | 
			
		||||
import asyncio
 | 
			
		||||
import logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def main():
 | 
			
		||||
    load_dotenv()
 | 
			
		||||
    logging.basicConfig(level=logging.INFO)
 | 
			
		||||
 | 
			
		||||
    # Initialize the HFHubChatClient with the desired model
 | 
			
		||||
    llm = HFHubChatClient(model="HuggingFaceTB/SmolLM3-3B")
 | 
			
		||||
 | 
			
		||||
    # Instantiate your agent (no .as_service())
 | 
			
		||||
    weather_agent = DurableAgent(
 | 
			
		||||
        role="Weather Assistant",
 | 
			
		||||
        name="Stevie",
 | 
			
		||||
        goal="Help humans get weather and location info using smart tools.",
 | 
			
		||||
        instructions=[
 | 
			
		||||
            "Respond clearly and helpfully to weather-related questions.",
 | 
			
		||||
            "Use tools when appropriate to fetch weather data.",
 | 
			
		||||
        ],
 | 
			
		||||
        llm=llm,
 | 
			
		||||
        message_bus_name="messagepubsub",
 | 
			
		||||
        state_store_name="workflowstatestore",
 | 
			
		||||
        state_key="workflow_state",
 | 
			
		||||
        agents_registry_store_name="agentstatestore",
 | 
			
		||||
        agents_registry_key="agents_registry",
 | 
			
		||||
        tools=tools,
 | 
			
		||||
    )
 | 
			
		||||
    # Start the agent service
 | 
			
		||||
    await weather_agent.run("What's the weather in Boston?")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    asyncio.run(main())
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,37 @@
 | 
			
		|||
from dapr_agents import DurableAgent, NVIDIAChatClient
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
from weather_tools import tools
 | 
			
		||||
import asyncio
 | 
			
		||||
import logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def main():
 | 
			
		||||
    load_dotenv()
 | 
			
		||||
    logging.basicConfig(level=logging.INFO)
 | 
			
		||||
 | 
			
		||||
    # Initialize the NVIDIAChatClient with the desired model
 | 
			
		||||
    llm = NVIDIAChatClient(model="meta/llama-3.1-8b-instruct")
 | 
			
		||||
 | 
			
		||||
    # Instantiate your agent (no .as_service())
 | 
			
		||||
    weather_agent = DurableAgent(
 | 
			
		||||
        role="Weather Assistant",
 | 
			
		||||
        name="Stevie",
 | 
			
		||||
        goal="Help humans get weather and location info using smart tools.",
 | 
			
		||||
        instructions=[
 | 
			
		||||
            "Respond clearly and helpfully to weather-related questions.",
 | 
			
		||||
            "Use tools when appropriate to fetch weather data.",
 | 
			
		||||
        ],
 | 
			
		||||
        llm=llm,
 | 
			
		||||
        message_bus_name="messagepubsub",
 | 
			
		||||
        state_store_name="workflowstatestore",
 | 
			
		||||
        state_key="workflow_state",
 | 
			
		||||
        agents_registry_store_name="agentstatestore",
 | 
			
		||||
        agents_registry_key="agents_registry",
 | 
			
		||||
        tools=tools,
 | 
			
		||||
    )
 | 
			
		||||
    # Start the agent service
 | 
			
		||||
    await weather_agent.run("What's the weather in Boston?")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    asyncio.run(main())
 | 
			
		||||
| 
						 | 
				
			
			@ -1,9 +1,11 @@
 | 
			
		|||
import chainlit as cl
 | 
			
		||||
from dapr_agents import Agent
 | 
			
		||||
from dapr_agents.memory import ConversationDaprStateMemory
 | 
			
		||||
from dapr.clients import DaprClient
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
from unstructured.partition.pdf import partition_pdf
 | 
			
		||||
from dapr.clients import DaprClient
 | 
			
		||||
 | 
			
		||||
from dapr_agents import Agent
 | 
			
		||||
from dapr_agents.memory import ConversationDaprStateMemory
 | 
			
		||||
from dapr_agents.types import AssistantMessage
 | 
			
		||||
 | 
			
		||||
load_dotenv()
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -55,20 +57,22 @@ async def start():
 | 
			
		|||
        upload(file_bytes, text_file.name, "upload")
 | 
			
		||||
 | 
			
		||||
    # give the model the document to learn
 | 
			
		||||
    response = await agent.run("This is a document element to learn: " + document_text)
 | 
			
		||||
    response: AssistantMessage = await agent.run(
 | 
			
		||||
        "This is a document element to learn: " + document_text
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    await cl.Message(content=f"`{text_file.name}` uploaded.").send()
 | 
			
		||||
 | 
			
		||||
    await cl.Message(content=response).send()
 | 
			
		||||
    await cl.Message(content=response.content).send()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@cl.on_message
 | 
			
		||||
async def main(message: cl.Message):
 | 
			
		||||
    # chat to the model about the document
 | 
			
		||||
    result = await agent.run(message.content)
 | 
			
		||||
    result: AssistantMessage = await agent.run(message.content)
 | 
			
		||||
 | 
			
		||||
    await cl.Message(
 | 
			
		||||
        content=result,
 | 
			
		||||
        content=result.content,
 | 
			
		||||
    ).send()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,16 +1,16 @@
 | 
			
		|||
{
 | 
			
		||||
    "instances": {
 | 
			
		||||
        "e419ad7d06f8480498543e6244b8b3a0": {
 | 
			
		||||
        "047a27d46af2482eac9e9dda8382281b": {
 | 
			
		||||
            "input": "What is the weather in New York?",
 | 
			
		||||
            "output": "The weather in New York is currently 79\u00b0F. If you need more information or have other questions, feel free to ask!",
 | 
			
		||||
            "start_time": "2025-07-23T02:35:58.877055",
 | 
			
		||||
            "end_time": "2025-07-23T06:36:00.786049+00:00",
 | 
			
		||||
            "output": "The current weather in New York is 66\u00b0F.",
 | 
			
		||||
            "start_time": "2025-07-27T02:07:07.814621",
 | 
			
		||||
            "end_time": "2025-07-27T06:07:09.682802+00:00",
 | 
			
		||||
            "messages": [
 | 
			
		||||
                {
 | 
			
		||||
                    "content": "What is the weather in New York?",
 | 
			
		||||
                    "role": "user",
 | 
			
		||||
                    "id": "b57d8dcf-a84e-4ed9-b6eb-b27360a6a466",
 | 
			
		||||
                    "timestamp": "2025-07-23T02:35:58.890010"
 | 
			
		||||
                    "id": "0133e81f-a5c6-4d02-ada6-80527292172e",
 | 
			
		||||
                    "timestamp": "2025-07-27T02:07:07.826629"
 | 
			
		||||
                },
 | 
			
		||||
                {
 | 
			
		||||
                    "content": null,
 | 
			
		||||
| 
						 | 
				
			
			@ -18,7 +18,7 @@
 | 
			
		|||
                    "name": "Stevie",
 | 
			
		||||
                    "tool_calls": [
 | 
			
		||||
                        {
 | 
			
		||||
                            "id": "call_QeSq5TArEZJ67tEzhZXmqAPw",
 | 
			
		||||
                            "id": "call_4cpTh8rI2OzoaGJaJEgmovxG",
 | 
			
		||||
                            "type": "function",
 | 
			
		||||
                            "function": {
 | 
			
		||||
                                "name": "LocalGetWeather",
 | 
			
		||||
| 
						 | 
				
			
			@ -26,42 +26,42 @@
 | 
			
		|||
                            }
 | 
			
		||||
                        }
 | 
			
		||||
                    ],
 | 
			
		||||
                    "id": "7b508dae-d86a-4cb5-9b84-4bf13f5ef3a3",
 | 
			
		||||
                    "timestamp": "2025-07-23T02:35:59.660877"
 | 
			
		||||
                    "id": "ed40c479-55e8-4026-8e73-8c7cba47e291",
 | 
			
		||||
                    "timestamp": "2025-07-27T02:07:08.651564"
 | 
			
		||||
                },
 | 
			
		||||
                {
 | 
			
		||||
                    "content": "New York: 79F.",
 | 
			
		||||
                    "content": "New York: 66F.",
 | 
			
		||||
                    "role": "tool",
 | 
			
		||||
                    "tool_call_id": "call_QeSq5TArEZJ67tEzhZXmqAPw",
 | 
			
		||||
                    "id": "4691eabc-1310-4ad7-bac6-7ecade2857a4",
 | 
			
		||||
                    "function_name": "LocalGetWeather",
 | 
			
		||||
                    "function_args": "{\"location\":\"New York\"}",
 | 
			
		||||
                    "timestamp": "2025-07-23T02:35:59.718943"
 | 
			
		||||
                    "name": "LocalGetWeather",
 | 
			
		||||
                    "tool_call_id": "call_4cpTh8rI2OzoaGJaJEgmovxG",
 | 
			
		||||
                    "id": "f89659f6-b908-435b-9017-ee6013e2f519",
 | 
			
		||||
                    "timestamp": "2025-07-27T02:07:08.696208"
 | 
			
		||||
                },
 | 
			
		||||
                {
 | 
			
		||||
                    "content": "The weather in New York is currently 79\u00b0F. If you need more information or have other questions, feel free to ask!",
 | 
			
		||||
                    "content": "The current weather in New York is 66\u00b0F.",
 | 
			
		||||
                    "role": "assistant",
 | 
			
		||||
                    "name": "Stevie",
 | 
			
		||||
                    "id": "54886dd0-108e-4e07-969c-4d04327768ee",
 | 
			
		||||
                    "timestamp": "2025-07-23T02:36:00.776750"
 | 
			
		||||
                    "id": "a8492536-32b1-433c-93bb-f5987091f5a2",
 | 
			
		||||
                    "timestamp": "2025-07-27T02:07:09.668079"
 | 
			
		||||
                }
 | 
			
		||||
            ],
 | 
			
		||||
            "last_message": {
 | 
			
		||||
                "content": "The weather in New York is currently 79\u00b0F. If you need more information or have other questions, feel free to ask!",
 | 
			
		||||
                "content": "The current weather in New York is 66\u00b0F.",
 | 
			
		||||
                "role": "assistant",
 | 
			
		||||
                "name": "Stevie",
 | 
			
		||||
                "id": "54886dd0-108e-4e07-969c-4d04327768ee",
 | 
			
		||||
                "timestamp": "2025-07-23T02:36:00.776750"
 | 
			
		||||
                "id": "a8492536-32b1-433c-93bb-f5987091f5a2",
 | 
			
		||||
                "timestamp": "2025-07-27T02:07:09.668079"
 | 
			
		||||
            },
 | 
			
		||||
            "tool_history": [
 | 
			
		||||
                {
 | 
			
		||||
                    "content": "New York: 79F.",
 | 
			
		||||
                    "role": "tool",
 | 
			
		||||
                    "tool_call_id": "call_QeSq5TArEZJ67tEzhZXmqAPw",
 | 
			
		||||
                    "id": "4691eabc-1310-4ad7-bac6-7ecade2857a4",
 | 
			
		||||
                    "function_name": "LocalGetWeather",
 | 
			
		||||
                    "function_args": "{\"location\":\"New York\"}",
 | 
			
		||||
                    "timestamp": "2025-07-23T02:35:59.718943"
 | 
			
		||||
                    "id": "7f8c100a-1d11-47ee-b517-a6397f62bcba",
 | 
			
		||||
                    "timestamp": "2025-07-27T02:07:08.696235",
 | 
			
		||||
                    "tool_call_id": "call_4cpTh8rI2OzoaGJaJEgmovxG",
 | 
			
		||||
                    "tool_name": "LocalGetWeather",
 | 
			
		||||
                    "tool_args": {
 | 
			
		||||
                        "location": "New York"
 | 
			
		||||
                    },
 | 
			
		||||
                    "execution_result": "New York: 66F."
 | 
			
		||||
                }
 | 
			
		||||
            ],
 | 
			
		||||
            "source": null,
 | 
			
		||||
| 
						 | 
				
			
			@ -72,8 +72,8 @@
 | 
			
		|||
        {
 | 
			
		||||
            "content": "What is the weather in New York?",
 | 
			
		||||
            "role": "user",
 | 
			
		||||
            "id": "b57d8dcf-a84e-4ed9-b6eb-b27360a6a466",
 | 
			
		||||
            "timestamp": "2025-07-23T02:35:58.890010"
 | 
			
		||||
            "id": "0133e81f-a5c6-4d02-ada6-80527292172e",
 | 
			
		||||
            "timestamp": "2025-07-27T02:07:07.826629"
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
            "content": null,
 | 
			
		||||
| 
						 | 
				
			
			@ -81,7 +81,7 @@
 | 
			
		|||
            "name": "Stevie",
 | 
			
		||||
            "tool_calls": [
 | 
			
		||||
                {
 | 
			
		||||
                    "id": "call_QeSq5TArEZJ67tEzhZXmqAPw",
 | 
			
		||||
                    "id": "call_4cpTh8rI2OzoaGJaJEgmovxG",
 | 
			
		||||
                    "type": "function",
 | 
			
		||||
                    "function": {
 | 
			
		||||
                        "name": "LocalGetWeather",
 | 
			
		||||
| 
						 | 
				
			
			@ -89,24 +89,23 @@
 | 
			
		|||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            ],
 | 
			
		||||
            "id": "7b508dae-d86a-4cb5-9b84-4bf13f5ef3a3",
 | 
			
		||||
            "timestamp": "2025-07-23T02:35:59.660877"
 | 
			
		||||
            "id": "ed40c479-55e8-4026-8e73-8c7cba47e291",
 | 
			
		||||
            "timestamp": "2025-07-27T02:07:08.651564"
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
            "content": "New York: 79F.",
 | 
			
		||||
            "content": "New York: 66F.",
 | 
			
		||||
            "role": "tool",
 | 
			
		||||
            "tool_call_id": "call_QeSq5TArEZJ67tEzhZXmqAPw",
 | 
			
		||||
            "id": "4691eabc-1310-4ad7-bac6-7ecade2857a4",
 | 
			
		||||
            "function_name": "LocalGetWeather",
 | 
			
		||||
            "function_args": "{\"location\":\"New York\"}",
 | 
			
		||||
            "timestamp": "2025-07-23T02:35:59.718943"
 | 
			
		||||
            "name": "LocalGetWeather",
 | 
			
		||||
            "tool_call_id": "call_4cpTh8rI2OzoaGJaJEgmovxG",
 | 
			
		||||
            "id": "f89659f6-b908-435b-9017-ee6013e2f519",
 | 
			
		||||
            "timestamp": "2025-07-27T02:07:08.696208"
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
            "content": "The weather in New York is currently 79\u00b0F. If you need more information or have other questions, feel free to ask!",
 | 
			
		||||
            "content": "The current weather in New York is 66\u00b0F.",
 | 
			
		||||
            "role": "assistant",
 | 
			
		||||
            "name": "Stevie",
 | 
			
		||||
            "id": "54886dd0-108e-4e07-969c-4d04327768ee",
 | 
			
		||||
            "timestamp": "2025-07-23T02:36:00.776750"
 | 
			
		||||
            "id": "a8492536-32b1-433c-93bb-f5987091f5a2",
 | 
			
		||||
            "timestamp": "2025-07-27T02:07:09.668079"
 | 
			
		||||
        }
 | 
			
		||||
    ]
 | 
			
		||||
}
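
For quick reference, the `tool_history` entries in the recorded agent state change shape in this commit. A minimal sketch of the before/after record layout, reconstructed only from the JSON shown above (ids, timestamps, and values are illustrative placeholders, not taken from the diff):

```python
# Old tool_history entry shape (before this change) -- illustrative values
old_entry = {
    "content": "New York: 79F.",
    "role": "tool",
    "tool_call_id": "call_...",
    "id": "...",
    "function_name": "LocalGetWeather",
    "function_args": "{\"location\":\"New York\"}",
    "timestamp": "...",
}

# New tool_history entry shape (ToolExecutionRecord-style) -- illustrative values
new_entry = {
    "id": "...",
    "timestamp": "...",
    "tool_call_id": "call_...",
    "tool_name": "LocalGetWeather",
    "tool_args": {"location": "New York"},
    "execution_result": "New York: 66F.",
}
```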
@@ -74,7 +74,11 @@ weather_agent = Agent(
    goal="Help humans get weather and location info using MCP tools.",
    instructions=["Instructions go here"],
    tools=tools,
)
)

# Run a sample query
result: AssistantMessage = await weather_agent.run("What is the weather in New York?")
print(result.content)
```

### Running the Example
@@ -4,6 +4,7 @@ from dotenv import load_dotenv

from dapr_agents import Agent
from dapr_agents.tool.mcp import MCPClient
from dapr_agents.types import AssistantMessage

load_dotenv()

@@ -37,8 +38,10 @@ async def main():
        )

        # Run a sample query
        result = await weather_agent.run("What is the weather in New York?")
        print(result)
        result: AssistantMessage = await weather_agent.run(
            "What is the weather in New York?"
        )
        print(result.content)

    finally:
        try:
@@ -74,7 +74,7 @@ mkdir docker-entrypoint-initdb.d
cp schema.sql users.sql ./docker-entrypoint-initdb.d
```

Run the database container:
Run the database container (make sure you are in the quickstart directory):

```bash
docker run --rm --name sampledb \

@@ -130,7 +130,7 @@ Change the settings below based on your Postgres configuration:
*Note: If you're running Postgres in a Docker container, change `<HOST>` to `localhost`.*

```bash
docker run -p 8000:8000 -e DATABASE_URI=postgresql://<USERNAME>:<PASSWORD>@<HOST>:5432/userdb crystaldba/postgres-mcp --access-mode=unrestricted --transport=sse
docker run --rm -ti -p 8000:8000 -e DATABASE_URI=postgresql://<USERNAME>:<PASSWORD>@<HOST>:5432/userdb crystaldba/postgres-mcp --access-mode=unrestricted --transport=sse
```

## Examples
@@ -1,9 +1,11 @@
import chainlit as cl
from dapr_agents import Agent
from dapr_agents.tool.mcp.client import MCPClient
from dotenv import load_dotenv
from get_schema import get_table_schema_as_dict

from dapr_agents import Agent
from dapr_agents.tool.mcp.client import MCPClient
from dapr_agents.types import AssistantMessage

load_dotenv()

instructions = [

@@ -52,18 +54,18 @@ async def start():
async def main(message: cl.Message):
    # generate the result set and pass back to the user
    prompt = create_prompt_for_llm(table_info, message.content)
    result = await agent.run(prompt)
    result: AssistantMessage = await agent.run(prompt)

    await cl.Message(
        content=result,
        content=result.content,
    ).send()

    result_set = await agent.run(
    result_set: AssistantMessage = await agent.run(
        "Execute the following sql query and always return a table format unless instructed otherwise. If the user asks a question regarding the data, return the result and formalize an answer based on inspecting the data: "
        + result
        + result.content
    )
    await cl.Message(
        content=result_set,
        content=result_set.content,
    ).send()

@@ -6,7 +6,7 @@ from dapr_agents.agents.agent.agent import Agent
from dapr_agents.types import (
    AgentError,
    AssistantMessage,
    ChatCompletion,
    LLMChatResponse,
    ToolExecutionRecord,
    UserMessage,
    ToolCall,

@@ -135,15 +135,15 @@ class TestAgent:
    @pytest.mark.asyncio
    async def test_run_agent_basic(self, basic_agent):
        """Test basic agent run functionality."""
        mock_response = Mock(spec=ChatCompletion)
        mock_response.get_message.return_value = AssistantMessage(content="Hello!")
        mock_response.get_reason.return_value = "stop"
        mock_response.get_content.return_value = "Hello!"
        mock_response = Mock(spec=LLMChatResponse)
        assistant_msg = AssistantMessage(content="Hello!")
        mock_response.get_message.return_value = assistant_msg
        basic_agent.llm.generate.return_value = mock_response

        result = await basic_agent._run_agent("Hello")

        assert result == "Hello!"
        assert isinstance(result, AssistantMessage)
        assert result.content == "Hello!"
        basic_agent.llm.generate.assert_called_once()

    @pytest.mark.asyncio

@@ -155,26 +155,24 @@ class TestAgent:
        tool_call = Mock(spec=ToolCall)
        tool_call.id = "call_123"
        tool_call.function = mock_function
        mock_response = Mock(spec=ChatCompletion)
        mock_response.get_message.return_value = AssistantMessage(content="Using tool")
        mock_response.get_reason.return_value = "tool_calls"
        mock_response.get_tool_calls.return_value = [tool_call]
        agent_with_tools.llm.generate.return_value = mock_response
        final_response = Mock(spec=ChatCompletion)
        final_response.get_message.return_value = AssistantMessage(
            content="Final answer"
        )
        final_response.get_reason.return_value = "stop"
        final_response.get_content.return_value = "Final answer"
        agent_with_tools.llm.generate.side_effect = [mock_response, final_response]
        # Ensure the tool is present in the agent and executor

        first_response = Mock(spec=LLMChatResponse)
        first_assistant = AssistantMessage(content="Using tool", tool_calls=[tool_call])
        first_response.get_message.return_value = first_assistant

        second_response = Mock(spec=LLMChatResponse)
        second_assistant = AssistantMessage(content="Final answer")
        second_response.get_message.return_value = second_assistant

        agent_with_tools.llm.generate.side_effect = [first_response, second_response]
        agent_with_tools.tools = [echo_tool]
        agent_with_tools._tool_executor = agent_with_tools._tool_executor.__class__(
            tools=[echo_tool]
        )

        result = await agent_with_tools._run_agent("Use the tool")
        assert result == "Final answer"
        # tool_history is cleared after a successful run, so we do not assert on its length here -> we should probably fix this later on.
        assert isinstance(result, AssistantMessage)
        assert result.content == "Final answer"

    @pytest.mark.asyncio
    async def test_process_response_success(self, agent_with_tools):

@@ -217,20 +215,18 @@ class TestAgent:

    @pytest.mark.asyncio
    async def test_process_iterations_max_reached(self, basic_agent):
        """Test that agent stops after max iterations."""
        # Mock LLM to always return tool calls,
        # and so this means the agent thinks it has a tool to call, but never actually calls a tool,
        # so agent keeps looping until max iterations is reached.
        mock_response = Mock(spec=ChatCompletion)
        mock_response.get_message.return_value = AssistantMessage(content="Using tool")
        mock_response.get_reason.return_value = "tool_calls"
        mock_response.get_tool_calls.return_value = []
        """Test that agent stops immediately when there are no tool calls."""
        mock_response = Mock(spec=LLMChatResponse)
        assistant_msg = AssistantMessage(content="Using tool", tool_calls=[])
        mock_response.get_message.return_value = assistant_msg
        basic_agent.llm.generate.return_value = mock_response

        result = await basic_agent.process_iterations([])

        assert result is None
        assert basic_agent.llm.generate.call_count == basic_agent.max_iterations
        # current logic sees no tools ===> returns on first iteration
        assert isinstance(result, AssistantMessage)
        assert result.content == "Using tool"
        assert basic_agent.llm.generate.call_count == 1

    @pytest.mark.asyncio
    async def test_process_iterations_with_llm_error(self, basic_agent):

@@ -274,15 +270,15 @@ class TestAgent:
        """Test agent using memory context when no input is provided."""
        basic_agent.memory.add_message(UserMessage(content="Previous message"))

        mock_response = Mock(spec=ChatCompletion)
        mock_response.get_message.return_value = AssistantMessage(content="Response")
        mock_response.get_reason.return_value = "stop"
        mock_response.get_content.return_value = "Response"
        mock_response = Mock(spec=LLMChatResponse)
        assistant_msg = AssistantMessage(content="Response")
        mock_response.get_message.return_value = assistant_msg
        basic_agent.llm.generate.return_value = mock_response

        result = await basic_agent._run_agent(None)

        assert result == "Response"
        assert isinstance(result, AssistantMessage)
        assert result.content == "Response"
        basic_agent.llm.generate.assert_called_once()

    def test_agent_tool_history_management(self, basic_agent):
@@ -2,11 +2,13 @@
# Right now we have to do a bunch of patching at the class-level instead of patching at the instance-level.
# In future, we should do dependency injection instead of patching at the class-level to make it easier to test.
# This applies to all areas in this file where we have with patch.object()...
import pytest
import asyncio
import os
from unittest.mock import Mock, AsyncMock, patch
from typing import Any
from unittest.mock import AsyncMock, Mock, patch

import pytest
from dapr.ext.workflow import DaprWorkflowContext

from dapr_agents.agents.durableagent.agent import DurableAgent
from dapr_agents.agents.durableagent.schemas import (

@@ -14,22 +16,21 @@ from dapr_agents.agents.durableagent.schemas import (
    BroadcastMessage,
)
from dapr_agents.agents.durableagent.state import (
    DurableAgentWorkflowEntry,
    DurableAgentWorkflowState,
)
from dapr_agents.memory import ConversationListMemory
from dapr_agents.llm import OpenAIChatClient
from dapr_agents.memory import ConversationListMemory
from dapr_agents.tool.base import AgentTool
from dapr.ext.workflow import DaprWorkflowContext
from dapr_agents.tool.executor import AgentToolExecutor


# We need this otherwise these tests all fail since they require Dapr to be available.
@pytest.fixture(autouse=True)
def patch_dapr_check(monkeypatch):
    from dapr_agents.workflow import agentic
    from dapr_agents.workflow import base
    from unittest.mock import Mock

    from dapr_agents.workflow import agentic, base

    # Mock the Dapr availability check to always return True
    monkeypatch.setattr(
        agentic.AgenticWorkflow, "_is_dapr_available", lambda self: True

@@ -272,9 +273,6 @@ class TestDurableAgent:
            "stop",
        ]

        # Manually insert a mock instance to ensure the state is populated for the assertion
        from dapr_agents.agents.durableagent.state import DurableAgentWorkflowEntry

        basic_durable_agent.state["instances"][
            "test-instance-123"
        ] = DurableAgentWorkflowEntry(

@@ -299,158 +297,46 @@ class TestDurableAgent:

    @pytest.mark.asyncio
    async def test_generate_response_activity(self, basic_durable_agent):
        """Test the generate_response activity."""
        mock_response = {
            "choices": [
                {
                    "message": {
                        "content": "Test response",
                        "tool_calls": [],
                        "finish_reason": "stop",
                    },
                    "finish_reason": "stop",
                }
            ]
        }
        basic_durable_agent.llm.generate = Mock(return_value=mock_response)
        """Test that generate_response unwraps an LLMChatResponse properly."""
        from dapr_agents.types import (
            AssistantMessage,
            LLMChatCandidate,
            LLMChatResponse,
        )

        # create a fake LLMChatResponse with one choice
        fake_response = LLMChatResponse(
            results=[
                LLMChatCandidate(
                    message=AssistantMessage(content="Test response", tool_calls=[]),
                    finish_reason="stop",
                )
            ],
            metadata={},
        )
        basic_durable_agent.llm.generate = Mock(return_value=fake_response)

        instance_id = "test-instance-123"
        workflow_entry = {
            "input": "Test task",
            "source": "test_source",
            "source_workflow_instance_id": None,
            "messages": [],
            "tool_history": [],
            "output": None,
        # set up a minimal instance record
        basic_durable_agent.state["instances"] = {
            instance_id: {
                "input": "Test task",
                "source": "test_source",
                "source_workflow_instance_id": None,
                "messages": [],
                "tool_history": [],
                "output": None,
            }
        }
        basic_durable_agent.state["instances"] = {instance_id: workflow_entry}

        result = await basic_durable_agent.generate_response(instance_id, "Test task")

        assert result == mock_response
        assistant_dict = await basic_durable_agent.generate_response(
            instance_id, "Test task"
        )
        # The dict dumped from AssistantMessage should have our content
        assert assistant_dict["content"] == "Test response"
        assert assistant_dict["tool_calls"] == []
        basic_durable_agent.llm.generate.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_response_message_activity(self, basic_durable_agent):
        """Test the get_response_message activity."""
        response = {
            "choices": [
                {
                    "message": {
                        "content": "Test response",
                        "tool_calls": [],
                        "finish_reason": "stop",
                    }
                }
            ]
        }

        result = basic_durable_agent.get_response_message(response)

        assert result["content"] == "Test response"
        assert result["tool_calls"] == []
        assert result["finish_reason"] == "stop"

    @pytest.mark.asyncio
    async def test_get_finish_reason_activity(self, basic_durable_agent):
        """Test the get_finish_reason activity."""
        response = {"choices": [{"finish_reason": "stop"}]}

        result = basic_durable_agent.get_finish_reason(response)

        assert result == "stop"

    @pytest.mark.asyncio
    async def test_get_tool_calls_activity_with_tools(self, durable_agent_with_tools):
        """Test the get_tool_calls activity when tools are present."""
        response = {
            "choices": [
                {
                    "message": {
                        "tool_calls": [
                            {
                                "id": "call_123",
                                "function": {
                                    "name": "test_tool",
                                    "arguments": '{"arg1": "value1"}',
                                },
                            }
                        ]
                    }
                }
            ]
        }

        result = durable_agent_with_tools.get_tool_calls(response)

        assert result is not None
        assert len(result) == 1
        assert result[0]["id"] == "call_123"
        assert result[0]["function"]["name"] == "test_tool"

    @pytest.mark.asyncio
    async def test_get_tool_calls_activity_no_tools(self, basic_durable_agent):
        """Test the get_tool_calls activity when no tools are present."""
        response = {"tool_calls": []}

        result = basic_durable_agent.get_tool_calls(response)

        assert result is None

    @pytest.mark.asyncio
    async def test_execute_tool_activity_success(self, durable_agent_with_tools):
        """Test successful tool execution in activity."""
        tool_call = {
            "id": "call_123",
            "function": {"name": "test_tool", "arguments": '{"arg1": "value1"}'},
        }

        instance_id = "test-instance-123"
        workflow_entry = {
            "input": "Test task",
            "source": "test_source",
            "source_workflow_instance_id": None,
            "messages": [],
            "tool_history": [],
            "output": None,
        }
        durable_agent_with_tools.state["instances"] = {instance_id: workflow_entry}

        with patch.object(
            AgentToolExecutor,
            "run_tool",
            new_callable=AsyncMock,
            return_value="test_result",
        ) as mock_run_tool:
            result = await durable_agent_with_tools.run_tool(tool_call)
            # Simulate appending to tool_history as the workflow would do
            durable_agent_with_tools.state["instances"][instance_id].setdefault(
                "tool_history", []
            ).append(result)
            instance_data = durable_agent_with_tools.state["instances"][instance_id]
            assert len(instance_data["tool_history"]) == 1
            tool_entry = instance_data["tool_history"][0]
            assert tool_entry["tool_call_id"] == "call_123"
            assert tool_entry["tool_name"] == "test_tool"
            assert tool_entry["execution_result"] == "test_result"
            mock_run_tool.assert_called_once_with("test_tool", arg1="value1")

    @pytest.mark.asyncio
    async def test_execute_tool_activity_failure(self, durable_agent_with_tools):
        """Test tool execution failure in activity."""
        tool_call = {
            "id": "call_123",
            "function": {"name": "test_tool", "arguments": '{"arg1": "value1"}'},
        }

        with patch.object(
            type(durable_agent_with_tools.tool_executor),
            "run_tool",
            side_effect=Exception("Tool failed"),
        ):
            with pytest.raises(Exception, match="Tool failed"):
                await durable_agent_with_tools.run_tool(tool_call)

    @pytest.mark.asyncio
    async def test_broadcast_message_to_agents_activity(self, basic_durable_agent):
        """Test broadcasting message to agents activity."""

@@ -486,15 +372,16 @@ class TestDurableAgent:
        """Test finishing workflow activity."""
        instance_id = "test-instance-123"
        final_output = "Final response"
        workflow_entry = {
            "input": "Test task",
            "source": "test_source",
            "source_workflow_instance_id": None,
            "messages": [],
            "tool_history": [],
            "output": None,
        basic_durable_agent.state["instances"] = {
            instance_id: {
                "input": "Test task",
                "source": "test_source",
                "source_workflow_instance_id": None,
                "messages": [],
                "tool_history": [],
                "output": None,
            }
        }
        basic_durable_agent.state["instances"] = {instance_id: workflow_entry}

        basic_durable_agent.finalize_workflow(instance_id, final_output)
        instance_data = basic_durable_agent.state["instances"][instance_id]

@@ -513,15 +400,16 @@ class TestDurableAgent:
        }
        final_output = "Final output"

        workflow_entry = {
            "input": "Test task",
            "source": "test_source",
            "source_workflow_instance_id": None,
            "messages": [],
            "tool_history": [],
            "output": None,
        basic_durable_agent.state["instances"] = {
            instance_id: {
                "input": "Test task",
                "source": "test_source",
                "source_workflow_instance_id": None,
                "messages": [],
                "tool_history": [],
                "output": None,
            }
        }
        basic_durable_agent.state["instances"] = {instance_id: workflow_entry}

        basic_durable_agent.append_assistant_message(instance_id, message)
        basic_durable_agent.append_tool_message(instance_id, tool_execution_record)
@@ -1,81 +1,37 @@
from pydantic import BaseModel
from typing import Optional, Dict, Any, Union, Iterator, Type, Iterable
from pathlib import Path
from collections import UserDict
from dapr_agents.llm import OpenAIChatClient
from dapr_agents.types.message import (
    LLMChatResponse,
    LLMChatCandidate,
    AssistantMessage,
)


class MockLLMClient(UserDict):
    """Mock LLM client for testing."""
class MockLLMClient(OpenAIChatClient):
    """Mock LLM client for testing that passes type validation."""

    def __init__(self, **kwargs):
        # Initialize UserDict properly
        super().__init__()
        # Set default values
        self.data["model"] = kwargs.get("model", "gpt-4o")
        self.data["azure_deployment"] = kwargs.get("azure_deployment", None)
        self.data["prompt_template"] = kwargs.get("prompt_template", None)
        self.data["api_key"] = kwargs.get("api_key", "mock-api-key")
        self.data["base_url"] = kwargs.get("base_url", "https://api.openai.com/v1")
        self.data["timeout"] = kwargs.get("timeout", 1500)

        # Store additional attributes that might be accessed
        object.__setattr__(
            self, "_prompt_template", kwargs.get("prompt_template", None)
        )
        object.__setattr__(self, "_model", self.data["model"])
        object.__setattr__(self, "_azure_deployment", self.data["azure_deployment"])

    def __getattr__(self, name):
        if name == "data":
            return object.__getattribute__(self, "data")
        if name in self.data:
            return self.data[name]
        return None

    def __setattr__(self, name, value):
        if name == "data" or name.startswith("_"):
            object.__setattr__(self, name, value)
        else:
            self.data[name] = value

    @classmethod
    def from_prompty(
        cls,
        prompty_source: Union[str, Path],
        timeout: Union[int, float, Dict[str, Any]] = 1500,
    ) -> "MockLLMClient":
        """Mock implementation of from_prompty method."""
        return cls(timeout=timeout)

    def get_client(self):
        """Mock implementation of get_client."""
        return None

    def get_config(self) -> Dict[str, Any]:
        """Mock implementation of get_config."""
        return {
            "model": self.data["model"],
            "azure_deployment": self.data["azure_deployment"],
            "api_key": self.data["api_key"],
            "base_url": self.data["base_url"],
            "timeout": self.data["timeout"],
        }
        super().__init__(model=kwargs.get("model", "gpt-4o"), api_key="mock-api-key")
        self.prompt_template = kwargs.get("prompt_template", None)

    def generate(
        self,
        messages: Union[
            str,
            Dict[str, Any],
            Iterable[Union[Dict[str, Any]]],
        ] = None,
        input_data: Optional[Dict[str, Any]] = None,
        model: Optional[str] = None,
        tools: Optional[list] = None,
        response_format: Optional[Type[BaseModel]] = None,
        structured_mode: str = "json",
        messages=None,
        *,
        input_data=None,
        model=None,
        tools=None,
        response_format=None,
        structured_mode="json",
        stream=False,
        **kwargs,
    ) -> Union[Iterator[Dict[str, Any]], Dict[str, Any]]:
        """Mock implementation of generate method."""
        return {
            "choices": [{"message": {"content": "Mock response", "role": "assistant"}}]
        }
    ):
        return LLMChatResponse(
            results=[
                LLMChatCandidate(
                    message=AssistantMessage(
                        content="This is a mock response from the LLM client."
                    ),
                    finish_reason="stop",
                )
            ]
        )
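
A minimal sketch (not part of this diff) of how a test might consume the new `MockLLMClient.generate()` return value. The attribute access pattern assumes `LLMChatResponse.results`, `LLMChatCandidate.message`, `LLMChatCandidate.finish_reason`, and `AssistantMessage.content` are plain model fields, as the constructor calls above suggest:

```python
# Construct the test double defined above and call the mocked generate().
client = MockLLMClient()
response = client.generate(messages="ping")

# Unwrap the first (and only) candidate the mock returns.
candidate = response.results[0]
assert candidate.finish_reason == "stop"
assert "mock response" in candidate.message.content.lower()
```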
@@ -83,6 +83,6 @@ async def test_select_random_speaker(orchestrator_config):
        mockclient.return_value = MagicMock()
        orchestrator = RandomOrchestrator(**orchestrator_config)

        speaker = orchestrator.select_random_speaker(iteration=1)
        speaker = orchestrator.select_random_speaker()
        assert speaker in ["agent1", "agent2"]
        assert orchestrator.current_speaker == speaker