Refactor LLM Workflows and Orchestrators for Unified Response Handling and Iteration (#163)

* Refactor ChatClientBase: drop Pydantic inheritance and add typed generate() overloads

Signed-off-by: Roberto Rodriguez <9653181+Cyb3rWard0g@users.noreply.github.com>

* Align all LLM chat clients with refactored base and unified response models

Signed-off-by: Roberto Rodriguez <9653181+Cyb3rWard0g@users.noreply.github.com>

* Unify LLM utils across providers and delegate streaming/response to provider‑specific handlers

Signed-off-by: Roberto Rodriguez <9653181+Cyb3rWard0g@users.noreply.github.com>

* Refactor LLM pipeline: add HuggingFace tool calls, unify chat client/response types, and switch DurableAgent to loop‑based workflow

Signed-off-by: Roberto Rodriguez <9653181+Cyb3rWard0g@users.noreply.github.com>

* Refactor orchestrators with loops and unify LLM response handling using LLMChatResponse

Signed-off-by: Roberto Rodriguez <9653181+Cyb3rWard0g@users.noreply.github.com>

* Test remaining quickstarts after all changes

Signed-off-by: Roberto Rodriguez <9653181+Cyb3rWard0g@users.noreply.github.com>

* Run pytest after all changes

Signed-off-by: Roberto Rodriguez <9653181+Cyb3rWard0g@users.noreply.github.com>

* Run linting and formatting checks to ensure code quality

Signed-off-by: Roberto Rodriguez <9653181+Cyb3rWard0g@users.noreply.github.com>

* Update logging, Orchestrator Name and OTel module name

Signed-off-by: Roberto Rodriguez <9653181+Cyb3rWard0g@users.noreply.github.com>

---------

Signed-off-by: Roberto Rodriguez <9653181+Cyb3rWard0g@users.noreply.github.com>
Co-authored-by: Yaron Schneider <schneider.yaron@live.com>
Roberto Rodriguez 2025-07-28 15:39:56 -04:00 committed by yaron2
parent 3e767e03fb
commit 29edfc419b
67 changed files with 3135 additions and 2042 deletions
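
Illustrative sketch of the unified call pattern this commit introduces (not taken from the diff; the OpenAI client and its credential setup are assumptions): every chat client's generate() now returns an LLMChatResponse by default, and callers read the first candidate through get_message().

from dapr_agents.llm.openai import OpenAIChatClient
from dapr_agents.types import LLMChatResponse

# Assumes provider credentials (e.g. an OpenAI API key) are already configured.
client = OpenAIChatClient()
response: LLMChatResponse = client.generate(
    messages="Summarize Dapr workflows in one sentence."
)

# All providers expose the first candidate through the same accessor.
assistant = response.get_message()
if assistant is not None:
    print(assistant.content)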

View File

@ -1,30 +1,20 @@
import asyncio
import logging
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from dapr_agents.agents.base import AgentBase
from dapr_agents.types import (
AgentError,
AssistantMessage,
ChatCompletion,
ToolCall,
ToolExecutionRecord,
ToolMessage,
UserMessage,
LLMChatResponse,
)
logger = logging.getLogger(__name__)
class FinishReason(str, Enum):
STOP = "stop"
LENGTH = "length"
CONTENT_FILTER = "content_filter"
TOOL_CALLS = "tool_calls"
FUNCTION_CALL = "function_call" # deprecated
class Agent(AgentBase):
"""
Agent that manages tool calls and conversations using a language model.
@ -168,8 +158,10 @@ class Agent(AgentBase):
name=function_name,
content=result_str,
)
# Printing the tool message for visibility
# Print the tool message for visibility
self.text_formatter.print_message(tool_message)
# Add tool message to memory
self.memory.add_message(tool_message)
# Append tool message to the persistent audit log
tool_execution_record = ToolExecutionRecord(
tool_call_id=tool_id,
@ -201,70 +193,55 @@ class Agent(AgentBase):
Raises:
AgentError: On chat failure or tool issues.
"""
for iteration in range(self.max_iterations):
logger.info(f"Iteration {iteration + 1}/{self.max_iterations} started.")
final_reply = None
for turn in range(1, self.max_iterations + 1):
logger.info(f"Iteration {turn}/{self.max_iterations} started.")
try:
# Generate response using the LLM
response = self.llm.generate(
response: LLMChatResponse = self.llm.generate(
messages=messages,
tools=self.get_llm_tools(),
tool_choice=self.tool_choice,
)
# If response is a dict, convert to ChatCompletion
if isinstance(response, dict):
response = ChatCompletion(**response)
elif not isinstance(response, ChatCompletion):
# If response is an iterator (stream), raise TypeError
raise TypeError(f"Expected ChatCompletion, got {type(response)}")
# Get the response message and print it
# Get the first candidate from the response
response_message = response.get_message()
if response_message is not None:
self.text_formatter.print_message(response_message)
# Get Reason for the response
reason = FinishReason(response.get_reason())
# Check if the response contains an assistant message
if response_message is None:
raise AgentError("LLM returned no assistant message")
else:
assistant = response_message
self.text_formatter.print_message(assistant)
self.memory.add_message(assistant)
# Handle tool calls response
if reason == FinishReason.TOOL_CALLS:
tool_calls = response.get_tool_calls()
if assistant is not None and assistant.has_tool_calls():
tool_calls = assistant.get_tool_calls()
if tool_calls:
# Add the assistant message with tool calls to the conversation
if response_message is not None:
messages.append(response_message)
# Execute tools and collect results for this iteration only
tool_messages = await self.execute_tools(tool_calls)
# Add tool results to messages for the next iteration
messages.extend([tm.model_dump() for tm in tool_messages])
# Continue to next iteration to let LLM process tool results
messages.append(assistant.model_dump())
tool_msgs = await self.execute_tools(tool_calls)
messages.extend([tm.model_dump() for tm in tool_msgs])
if turn == self.max_iterations:
final_reply = assistant
logger.info("Reached max turns after tool calls; stopping.")
break
continue
# Handle stop response
elif reason == FinishReason.STOP:
# Append AssistantMessage to memory
msg = AssistantMessage(content=response.get_content() or "")
self.memory.add_message(msg)
return msg.content
# Handle Function call response
elif reason == FinishReason.FUNCTION_CALL:
logger.warning(
"LLM returned a deprecated function_call. Function calls are not processed by this agent."
)
msg = AssistantMessage(
content="Function calls are not supported or processed by this agent."
)
self.memory.add_message(msg)
return msg.content
else:
logger.error(f"Unknown finish reason: {reason}")
raise AgentError(f"Unknown finish reason: {reason}")
# No tool calls => done
final_reply = assistant
break
except Exception as e:
logger.error(f"Error during chat generation: {e}")
logger.error(f"Error on turn {turn}: {e}")
raise AgentError(f"Failed during chat generation: {e}") from e
logger.info("Max iterations reached. Agent has stopped.")
# Post-loop
if final_reply is None:
logger.warning("No reply generated; hitting max iterations.")
return None
logger.info(f"Agent conversation completed after {turn} turns.")
return final_reply
async def run_tool(self, tool_name: str, *args, **kwargs) -> Any:
"""
Executes a single registered tool by name, handling both sync and async tools.
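
Condensed sketch of the control flow introduced above (self., memory updates, printing, and the max-iteration edge case are omitted; this is a restatement of the loop in the diff, not a drop-in replacement):

for turn in range(1, max_iterations + 1):
    response = llm.generate(messages=messages, tools=tools, tool_choice=tool_choice)
    assistant = response.get_message()
    if assistant is None:
        raise AgentError("LLM returned no assistant message")
    if assistant.has_tool_calls():
        # Feed the tool results back in and let the next turn consume them.
        messages.append(assistant.model_dump())
        tool_msgs = await execute_tools(assistant.get_tool_calls())
        messages.extend(tm.model_dump() for tm in tool_msgs)
        continue
    final_reply = assistant  # no tool calls => final answer
    break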

View File

@ -26,18 +26,11 @@ from typing import (
ClassVar,
)
from pydantic import BaseModel, Field, PrivateAttr, model_validator, ConfigDict
from dapr_agents.llm.chat import ChatClientBase
from dapr_agents.llm.openai import OpenAIChatClient
from dapr_agents.llm.huggingface import HFHubChatClient
from dapr_agents.llm.nvidia import NVIDIAChatClient
from dapr_agents.llm.dapr import DaprChatClient
logger = logging.getLogger(__name__)
# Type alias for all concrete chat client implementations
ChatClientType = Union[
OpenAIChatClient, HFHubChatClient, NVIDIAChatClient, DaprChatClient
]
class AgentBase(BaseModel, ABC):
"""
@ -73,7 +66,7 @@ class AgentBase(BaseModel, ABC):
default=None,
description="A custom system prompt, overriding name, role, goal, and instructions.",
)
llm: ChatClientType = Field(
llm: ChatClientBase = Field(
default_factory=OpenAIChatClient,
description="Language model client for generating responses.",
)
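
Because `llm` is now typed against the abstract base instead of a closed union, any ChatClientBase subclass can be injected without widening AgentBase. Minimal sketch (the Agent import path and constructor fields other than `llm` are assumptions):

from dapr_agents import Agent  # import path is an assumption
from dapr_agents.llm.nvidia import NVIDIAChatClient
from dapr_agents.llm.huggingface import HFHubChatClient

nvidia_agent = Agent(name="researcher", llm=NVIDIAChatClient())
hf_agent = Agent(name="summarizer", llm=HFHubChatClient())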

View File

@ -1,16 +1,16 @@
import json
import logging
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from dapr.ext.workflow import DaprWorkflowContext # type: ignore
from pydantic import BaseModel, Field, model_validator
from pydantic import Field, model_validator
from dapr_agents.agents.base import AgentBase
from dapr_agents.types import (
AgentError,
AssistantMessage,
LLMChatResponse,
ToolExecutionRecord,
ToolMessage,
UserMessage,
@ -32,15 +32,6 @@ from .state import (
logger = logging.getLogger(__name__)
class FinishReason(str, Enum):
UNKNOWN = "unknown"
STOP = "stop"
LENGTH = "length"
CONTENT_FILTER = "content_filter"
TOOL_CALLS = "tool_calls"
FUNCTION_CALL = "function_call" # deprecated
# TODO(@Sicoyle): Clear up the lines between DurableAgent and AgentWorkflow
class DurableAgent(AgenticWorkflow, AgentBase):
"""
@ -137,32 +128,33 @@ class DurableAgent(AgenticWorkflow, AgentBase):
Returns:
Dict[str, Any]: The final response message when the workflow completes, or None if continuing to the next iteration.
"""
# Step 0: Retrieve task, iteration, and sourceworkflow instance ID from the message
# Step 1: pull out task + metadata
if isinstance(message, dict):
task = message.get("task", None)
iteration = message.get("iteration", 0)
source_workflow_instance_id = message.get("workflow_instance_id")
metadata = message.get("_message_metadata", {}) or {}
else:
task = getattr(message, "task", None)
iteration = getattr(message, "iteration", 0)
source_workflow_instance_id = getattr(message, "workflow_instance_id", None)
# This is the instance ID of the current workflow execution
metadata = getattr(message, "_message_metadata", {}) or {}
instance_id = ctx.instance_id
source = metadata.get("source")
final_message: Optional[Dict[str, Any]] = None
if not ctx.is_replaying:
logger.debug(f"Initial message from {source} -> {self.name}")
try:
# Loop up to max_iterations
for turn in range(1, self.max_iterations + 1):
if not ctx.is_replaying:
logger.info(
f"Workflow iteration {iteration + 1} started (Instance ID: {instance_id})."
f"Workflow turn {turn}/{self.max_iterations} (Instance ID: {instance_id})"
)
# Step 1: Initialize workflow entry and state if this is the first iteration
if iteration == 0:
# Get metadata from the message, if available
if isinstance(message, dict):
metadata = message.get("_message_metadata", {})
else:
metadata = getattr(message, "_message_metadata", {})
source = metadata.get("source", None)
# Use activity to record initial entry for replay safety
# Step 2: On turn 1, record the initial entry
if turn == 1:
yield ctx.call_activity(
self.record_initial_entry,
input={
@ -174,81 +166,81 @@ class DurableAgent(AgenticWorkflow, AgentBase):
},
)
if not ctx.is_replaying:
logger.info(f"Initial message from {source} -> {self.name}")
# Step 2: Retrieve workflow entry info for this instance
entry_info = yield ctx.call_activity(
# Step 3: Retrieve workflow entry info for this instance
entry_info: dict = yield ctx.call_activity(
self.get_workflow_entry_info, input={"instance_id": instance_id}
)
source = entry_info.get("source")
source_workflow_instance_id = entry_info.get("source_workflow_instance_id")
# Step 3: Generate Response via LLM
response = yield ctx.call_activity(
self.generate_response, input={"task": task, "instance_id": instance_id}
source_workflow_instance_id = entry_info.get(
"source_workflow_instance_id"
)
# Step 4: Extract Response Message from LLM Response
response_message = yield ctx.call_activity(
self.get_response_message, input={"response": response}
# Step 4: Generate Response with LLM
response_message: dict = yield ctx.call_activity(
self.generate_response,
input={"task": task, "instance_id": instance_id},
)
# Step 5: Extract Finish Reason from LLM Response
finish_reason = yield ctx.call_activity(
self.get_finish_reason, input={"response": response}
)
# Step 6:Add the assistant's response message to the chat history
# Step 5: Add the assistant's response message to the chat history
yield ctx.call_activity(
self.append_assistant_message,
input={"instance_id": instance_id, "message": response_message},
)
# Step 7: Handle tool calls response
if finish_reason == FinishReason.TOOL_CALLS:
if not ctx.is_replaying:
logger.info("Tool calls detected in LLM response.")
# Retrieve the list of tool calls extracted from the LLM response
tool_calls = yield ctx.call_activity(
self.get_tool_calls, input={"response": response}
)
# Step 6: Handle tool calls response
tool_calls = response_message.get("tool_calls") or []
if tool_calls:
if not ctx.is_replaying:
logger.debug(f"Executing {len(tool_calls)} tool call(s)..")
# Run the tool calls in parallel
logger.info(
f"Turn {turn}: executing {len(tool_calls)} tool call(s)"
)
# fanout parallel tool executions
parallel = [
ctx.call_activity(self.run_tool, input={"tool_call": tc})
for tc in tool_calls
]
tool_results = yield self.when_all(parallel)
tool_results: List[Dict[str, Any]] = yield self.when_all(parallel)
# Add tool results for the next iteration
for tr in tool_results:
yield ctx.call_activity(
self.append_tool_message,
input={"instance_id": instance_id, "tool_result": tr},
)
# Step 8: Process iteration count and finish reason
next_iteration_count = iteration + 1
max_iterations_reached = next_iteration_count > self.max_iterations
if finish_reason == FinishReason.STOP or max_iterations_reached:
# Process max iterations reached
if max_iterations_reached:
if not ctx.is_replaying:
logger.warning(
f"Workflow {instance_id} reached the max iteration limit ({self.max_iterations}) before finishing naturally."
)
# Modify the response message to indicate forced stop
response_message[
# 🔴 If this was the last turn, stop here—even though there were tool calls
if turn == self.max_iterations:
final_message = response_message
final_message[
"content"
] += "\n\nThe workflow was terminated because it reached the maximum iteration limit. The task may not be fully complete."
] += "\n\n⚠️ Stopped: reached max iterations."
break
# Broadcast the final response if a broadcast topic is set
# Otherwise, prepare for next turn: clear task so that generate_response() uses memory/history
task = None
continue # bump to next turn
# No tool calls → this is your final answer
final_message = response_message
# 🔴 If it happened to be the last turn, banner it
if turn == self.max_iterations:
final_message["content"] += "\n\n⚠️ Stopped: reached max iterations."
break # exit loop with final_message
else:
raise AgentError("Workflow ended without producing a final response")
except Exception as e:
logger.exception("Workflow error", exc_info=e)
final_message = {
"role": "assistant",
"content": f"⚠️ Unexpected error: {e}",
}
# Step 7: Broadcast the final response if a broadcast topic is set
if self.broadcast_topic_name:
yield ctx.call_activity(
self.broadcast_message_to_agents,
input={"message": response_message},
input={"message": final_message},
)
# Respond to source agent if available
@ -256,37 +248,30 @@ class DurableAgent(AgenticWorkflow, AgentBase):
yield ctx.call_activity(
self.send_response_back,
input={
"response": response_message,
"response": final_message,
"target_agent": source,
"target_instance_id": source_workflow_instance_id,
},
)
# Share Final Message
# Save final output to workflow state
yield ctx.call_activity(
self.finalize_workflow,
input={
"instance_id": instance_id,
"final_output": response_message["content"],
"final_output": final_message["content"],
},
)
# Log the finalization of the workflow
if not ctx.is_replaying:
verdict = "max_iterations_reached" if max_iterations_reached else "stop"
logger.info(
f"Workflow {instance_id} has been finalized with verdict: {verdict}"
)
return response_message
# Step 9: Continue Workflow Execution
if isinstance(message, dict):
message.update({"task": None, "iteration": next_iteration_count})
next_message = message
else:
# For Pydantic model, create a new dict with updated fields
next_message = message.model_dump()
next_message.update({"task": None, "iteration": next_iteration_count})
ctx.continue_as_new(next_message)
# Set verdict for the workflow instance
if not ctx.is_replaying:
verdict = (
"max_iterations_reached" if turn == self.max_iterations else "completed"
)
logger.info(f"Workflow {instance_id} finalized: {verdict}")
# Return the final response message
return final_message
@task
def record_initial_entry(
@ -391,96 +376,23 @@ class DurableAgent(AgenticWorkflow, AgentBase):
# Generate response using the LLM
try:
response = self.llm.generate(
response: LLMChatResponse = self.llm.generate(
messages=messages,
tools=self.get_llm_tools(),
tool_choice=self.tool_choice,
)
if isinstance(response, BaseModel):
return response.model_dump()
elif isinstance(response, dict):
return response
else:
# Defensive: raise error for unexpected type
raise AgentError(f"Unexpected response type: {type(response)}")
# Get the first candidate from the response
response_message = response.get_message()
# Check if the response contains an assistant message
if response_message is None:
raise AgentError("LLM returned no assistant message")
# Convert the response message to a dict to work with JSON serialization
assistant_message = response_message.model_dump()
return assistant_message
except Exception as e:
logger.error(f"Error during chat generation: {e}")
raise AgentError(f"Failed during chat generation: {e}") from e
@task
def get_response_message(self, response: Dict[str, Any]) -> Dict[str, Any]:
"""
Extracts the response message from the first choice in the LLM response.
Args:
response (Dict[str, Any]): The response dictionary from the LLM, expected to contain a "choices" key.
Returns:
Dict[str, Any]: The extracted response message with the agent's name added.
Raises:
AgentError: If no response message is found.
"""
choices = response.get("choices", [])
if choices:
response_message = choices[0].get("message", {})
if response_message:
return response_message
raise AgentError("No response message found in LLM response.")
@task
def get_finish_reason(self, response: Dict[str, Any]) -> str:
"""
Extracts the finish reason from the LLM response, indicating why generation stopped.
Args:
response (Dict[str, Any]): The response dictionary from the LLM, expected to contain a "choices" key.
Returns:
FinishReason: The reason the model stopped generating tokens as an enum value.
"""
try:
choices = response.get("choices", [])
if choices and len(choices) > 0:
reason_str = choices[0].get("finish_reason", FinishReason.UNKNOWN.value)
try:
return FinishReason(reason_str)
except ValueError:
logger.warning(f"Unrecognized finish reason: {reason_str}")
return FinishReason.UNKNOWN
# If choices is empty, return UNKNOWN
return FinishReason.UNKNOWN
except Exception as e:
logger.error(f"Error extracting finish reason: {e}")
return FinishReason.UNKNOWN
@task
def get_tool_calls(
self, response: Dict[str, Any]
) -> Optional[List[Dict[str, Any]]]:
"""
Extracts tool calls from the first choice in the LLM response, if available.
Args:
response (Dict[str, Any]): The response dictionary from the LLM, expected to contain "choices"
and potentially tool call information.
Returns:
Optional[List[Dict[str, Any]]]: A list of tool calls if present, otherwise None.
"""
choices = response.get("choices", [])
if not choices:
logger.warning("No choices found in LLM response.")
return None
tool_calls = choices[0].get("message", {}).get("tool_calls")
if tool_calls:
return tool_calls
# Only log if choices exist but no tool_calls
logger.info("No tool calls found in the first LLM response choice.")
return None
@task
async def run_tool(self, tool_call: Dict[str, Any]) -> Dict[str, Any]:
"""

View File

@ -28,7 +28,6 @@ class TriggerAction(BaseModel):
None,
description="The specific task to execute. If not provided, the agent will act based on its memory or predefined behavior.",
)
iteration: Optional[int] = Field(0, description="")
workflow_instance_id: Optional[str] = Field(
default=None, description="Dapr workflow instance id from source if available"
)

View File

@ -1,3 +1,3 @@
from .otel import DaprAgentsOTel
from .otel import DaprAgentsOtel
__all__ = ["DaprAgentsOTel"]
__all__ = ["DaprAgentsOtel"]

View File

@ -16,7 +16,7 @@ from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExp
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
class DaprAgentsOTel:
class DaprAgentsOtel:
"""
OpenTelemetry configuration for Dapr agents.
"""

View File

@ -1,27 +1,41 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, Iterable, Iterator, List, Optional, Type, Union
from typing import (
Any,
Dict,
Iterable,
Iterator,
List,
Literal,
Optional,
Type,
TypeVar,
Union,
overload,
)
from pydantic import BaseModel, Field
from pydantic import BaseModel
from dapr_agents.prompt.base import PromptTemplateBase
from dapr_agents.prompt.prompty import Prompty
from dapr_agents.tool.base import AgentTool
from dapr_agents.types.message import ChatCompletion
from dapr_agents.types.message import LLMChatCandidateChunk, LLMChatResponse
T = TypeVar("T", bound=BaseModel)
class ChatClientBase(BaseModel, ABC):
class ChatClientBase(ABC):
"""
Base class for chat-specific functionality.
Handles Prompty integration and provides abstract methods for chat client configuration.
Attributes:
prompty: Optional Prompty spec used to render `input_data` into messages.
prompt_template: Optional prompt template object for rendering.
"""
prompty: Optional[Prompty] = Field(
default=None, description="Instance of the Prompty object (optional)."
)
prompt_template: Optional[PromptTemplateBase] = Field(
default=None, description="Prompt template for rendering (optional)."
)
prompty: Optional[Prompty]
prompt_template: Optional[PromptTemplateBase]
@classmethod
@abstractmethod
@ -31,43 +45,97 @@ class ChatClientBase(BaseModel, ABC):
timeout: Union[int, float, Dict[str, Any]] = 1500,
) -> "ChatClientBase":
"""
Abstract method to load a Prompty source and configure the chat client.
Load a Prompty spec (path or inline), extract its model config and
prompt template, and return a configured chat client.
Args:
prompty_source (Union[str, Path]): Source of the Prompty, either a file path or inline Prompty content.
timeout (Union[int, float, Dict[str, Any]]): Timeout for requests.
prompty_source: Path or inline YAML/JSON for a Prompty spec.
timeout: HTTP timeout (seconds or HTTPX-style dict).
Returns:
ChatClientBase: Configured chat client instance.
A ready-to-use ChatClientBase subclass instance.
"""
pass
...
@overload
def generate(
self,
messages: Union[
str, Dict[str, Any], Any, Iterable[Union[Dict[str, Any], Any]]
] = None,
*,
input_data: Optional[Dict[str, Any]] = None,
model: Optional[str] = None,
tools: Optional[List[Union[AgentTool, Dict[str, Any]]]] = None,
response_format: None = None,
structured_mode: Optional[str] = None,
stream: Literal[False] = False,
**kwargs: Any,
) -> LLMChatResponse:
...
"""If `stream=False` and no `response_format`, returns raw LLMChatResponse."""
@overload
def generate(
self,
messages: Union[
str, Dict[str, Any], Any, Iterable[Union[Dict[str, Any], Any]]
] = None,
*,
input_data: Optional[Dict[str, Any]] = None,
model: Optional[str] = None,
tools: Optional[List[Union[AgentTool, Dict[str, Any]]]] = None,
response_format: Type[T],
structured_mode: Optional[str] = None,
stream: Literal[False] = False,
**kwargs: Any,
) -> Union[T, List[T]]:
...
"""If `stream=False` and `response_format=SomeModel`, returns that model or a list thereof."""
@overload
def generate(
self,
messages: Union[
str, Dict[str, Any], Any, Iterable[Union[Dict[str, Any], Any]]
] = None,
*,
input_data: Optional[Dict[str, Any]] = None,
model: Optional[str] = None,
tools: Optional[List[Union[AgentTool, Dict[str, Any]]]] = None,
response_format: Optional[Type[T]] = None,
structured_mode: Optional[str] = None,
stream: Literal[True],
**kwargs: Any,
) -> Iterator[LLMChatCandidateChunk]:
...
"""If `stream=True`, returns a streaming iterator of chunks."""
@abstractmethod
def generate(
self,
messages: Union[
str, Dict[str, Any], BaseModel, Iterable[Union[Dict[str, Any], BaseModel]]
str, Dict[str, Any], Any, Iterable[Union[Dict[str, Any], Any]]
] = None,
*,
input_data: Optional[Dict[str, Any]] = None,
model: Optional[str] = None,
tools: Optional[List[Union[AgentTool, Dict[str, Any]]]] = None,
response_format: Optional[Type[BaseModel]] = None,
response_format: Optional[Type[T]] = None,
structured_mode: Optional[str] = None,
**kwargs,
) -> Union[Iterator[Dict[str, Any]], ChatCompletion]:
stream: bool = False,
**kwargs: Any,
) -> Union[
Iterator[LLMChatCandidateChunk],
LLMChatResponse,
T,
List[T],
]:
"""
Abstract method to generate chat completions.
Args:
messages (Optional): Either pre-set messages or None if using input_data.
input_data (Optional[Dict[str, Any]]): Input variables for prompt templates.
model (Optional[str]): Specific model to use for the request, overriding the default.
tools (Optional[List[Union[AgentTool, Dict[str, Any]]]]): List of tools for the request.
response_format (Optional[Type[BaseModel]]): Optional Pydantic model for structured response parsing.
structured_mode (Optional[str]): Mode for structured output.
**kwargs: Additional parameters for the chat completion API.
Returns:
Union[Iterator[Dict[str, Any]], ChatCompletion]: The chat completion response(s).
The implementation must accept the full set of kwargs and return
the union of all possible overload returns.
"""
pass
...
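
Hedged sketch of how the new overloads surface to callers and type checkers (the OpenAI client and the Weather model are assumptions; concrete clients are expected to honor the three shapes declared on the base):

from pydantic import BaseModel
from dapr_agents.llm.openai import OpenAIChatClient
from dapr_agents.types.message import LLMChatResponse

class Weather(BaseModel):  # hypothetical structured-output model
    city: str
    forecast: str

client = OpenAIChatClient()

plain: LLMChatResponse = client.generate(messages="Hello")                       # overload 1
typed = client.generate(messages="Weather in Oslo?", response_format=Weather)    # overload 2: Weather or List[Weather]
for chunk in client.generate(messages="Tell me a story", stream=True):           # overload 3: LLMChatCandidateChunk stream
    ...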

View File

@ -1,48 +1,62 @@
from dapr_agents.llm.dapr.client import DaprInferenceClientBase
from dapr_agents.llm.utils import RequestHandler, ResponseHandler
from dapr_agents.prompt.prompty import Prompty
from dapr_agents.types.message import BaseMessage
from dapr_agents.llm.chat import ChatClientBase
from dapr_agents.tool import AgentTool
from dapr.clients.grpc._request import ConversationInput
from typing import (
Union,
Optional,
Iterable,
Dict,
Any,
List,
Iterator,
Type,
Literal,
ClassVar,
)
from pydantic import BaseModel
from pathlib import Path
import logging
import os
import time
from pathlib import Path
from typing import (
Any,
ClassVar,
Dict,
Iterable,
List,
Literal,
Optional,
Type,
Union,
)
from dapr.clients.grpc._request import ConversationInput
from pydantic import BaseModel, Field
from dapr_agents.llm.chat import ChatClientBase
from dapr_agents.llm.dapr.client import DaprInferenceClientBase
from dapr_agents.llm.utils import RequestHandler, ResponseHandler
from dapr_agents.prompt.base import PromptTemplateBase
from dapr_agents.prompt.prompty import Prompty
from dapr_agents.tool import AgentTool
from dapr_agents.types.message import (
BaseMessage,
LLMChatResponse,
)
logger = logging.getLogger(__name__)
class DaprChatClient(DaprInferenceClientBase, ChatClientBase):
"""
Concrete class for Dapr's chat completion API using the Inference API.
This class extends the ChatClientBase.
Chat client for Dapr's Inference API.
Integrates Prompty-driven prompt templates, tool injection,
PII scrubbing, and normalizes the Dapr output into our unified
LLMChatResponse schema. **Streaming is not supported.**
"""
SUPPORTED_STRUCTURED_MODES: ClassVar[set] = {"function_call"}
prompty: Optional[Prompty] = Field(
default=None, description="Optional Prompty instance for templating."
)
prompt_template: Optional[PromptTemplateBase] = Field(
default=None, description="Optional prompt-template to format inputs."
)
# Only function_call-style structured output is supported
SUPPORTED_STRUCTURED_MODES: ClassVar[set[str]] = {"function_call"}
def model_post_init(self, __context: Any) -> None:
"""
Initializes private attributes for provider, api, config, and client after validation.
After Pydantic init, set up API/type and default LLM component from env.
"""
# Set the private provider and api attributes
self._api = "chat"
self._llm_component = os.environ["DAPR_LLM_COMPONENT_DEFAULT"]
return super().model_post_init(__context)
super().model_post_init(__context)
@classmethod
def from_prompty(
@ -51,23 +65,17 @@ class DaprChatClient(DaprInferenceClientBase, ChatClientBase):
timeout: Union[int, float, Dict[str, Any]] = 1500,
) -> "DaprChatClient":
"""
Initializes an DaprChatClient client using a Prompty source, which can be a file path or inline content.
Build a DaprChatClient from a Prompty spec.
Args:
prompty_source (Union[str, Path]): The source of the Prompty file, which can be a path to a file
or inline Prompty content as a string.
timeout (Union[int, float, Dict[str, Any]], optional): Timeout for requests, defaults to 1500 seconds.
prompty_source: Path or inline Prompty YAML/JSON.
timeout: Request timeout in seconds or HTTPX-style dict.
Returns:
DaprChatClient: An instance of DaprChatClient configured with the model settings from the Prompty source.
Configured DaprChatClient.
"""
# Load the Prompty instance from the provided source
prompty_instance = Prompty.load(prompty_source)
# Generate the prompt template from the Prompty instance
prompt_template = Prompty.to_prompt_template(prompty_instance)
# Initialize the DaprChatClient based on the Prompty model configuration
return cls.model_validate(
{
"timeout": timeout,
@ -77,17 +85,18 @@ class DaprChatClient(DaprInferenceClientBase, ChatClientBase):
)
def translate_response(self, response: dict, model: str) -> dict:
"""Converts a Dapr response dict into a structure compatible with Choice and ChatCompletion."""
"""
Convert Dapr response into OpenAI-style ChatCompletion dict.
"""
choices = [
{
"finish_reason": "stop",
"index": i,
"message": {"content": output["result"], "role": "assistant"},
"index": idx,
"message": {"role": "assistant", "content": out["result"]},
"logprobs": None,
}
for i, output in enumerate(response.get("outputs", []))
for idx, out in enumerate(response.get("outputs", []))
]
return {
"choices": choices,
"created": int(time.time()),
@ -99,11 +108,14 @@ class DaprChatClient(DaprInferenceClientBase, ChatClientBase):
def convert_to_conversation_inputs(
self, inputs: List[Dict[str, Any]]
) -> List[ConversationInput]:
"""
Map normalized messages into Dapr ConversationInput objects.
"""
return [
ConversationInput(
content=item["content"],
role=item.get("role"),
scrub_pii=item.get("scrubPII") == "true",
scrub_pii=bool(item.get("scrubPII")),
)
for item in inputs
]
@ -116,59 +128,75 @@ class DaprChatClient(DaprInferenceClientBase, ChatClientBase):
BaseMessage,
Iterable[Union[Dict[str, Any], BaseMessage]],
] = None,
*,
input_data: Optional[Dict[str, Any]] = None,
llm_component: Optional[str] = None,
tools: Optional[List[Union[AgentTool, Dict[str, Any]]]] = None,
response_format: Optional[Type[BaseModel]] = None,
structured_mode: Literal["function_call"] = "function_call",
scrubPII: Optional[bool] = False,
scrubPII: bool = False,
temperature: Optional[float] = None,
**kwargs,
) -> Union[Iterator[Dict[str, Any]], Dict[str, Any]]:
**kwargs: Any,
) -> Union[
LLMChatResponse,
BaseModel,
List[BaseModel],
]:
"""
Generate chat completions based on provided messages or input_data for prompt templates.
Issue a non-streaming chat completion via Dapr.
- **Streaming is not supported** and setting `stream=True` will raise.
- Returns a unified `LLMChatResponse` (if no `response_format`), or
validated Pydantic model(s) when `response_format` is provided.
Args:
messages (Optional): Either pre-set messages or None if using input_data.
input_data (Optional[Dict[str, Any]]): Input variables for prompt templates.
llm_component (str): Name of the LLM component to use for the request.
tools (List[Union[AgentTool, Dict[str, Any]]]): List of tools for the request.
response_format (Type[BaseModel]): Optional Pydantic model for structured response parsing.
structured_mode (Literal["function_call"]): Mode for structured output: "function_call" (Limited Support).
scrubPII (Type[bool]): Optional flag to obfuscate any sensitive information coming back from the LLM.
**kwargs: Additional parameters for the language model.
messages: Prebuilt messages or None to use `input_data`.
input_data: Variables for Prompty template rendering.
llm_component: Dapr component name (defaults from env).
tools: AgentTool or dict specifications.
response_format: Pydantic model for structured output.
structured_mode: Must be "function_call".
scrubPII: Obfuscate sensitive output if True.
temperature: Sampling temperature.
**kwargs: Other Dapr API parameters.
Returns:
Union[Iterator[Dict[str, Any]], Dict[str, Any]]: The chat completion response(s).
• `LLMChatResponse` if no `response_format`
• Pydantic model (or `List[...]`) when `response_format` is set
Raises:
ValueError: on invalid `structured_mode`, missing inputs, or if `stream=True`.
"""
# 1) Validate structured_mode
if structured_mode not in self.SUPPORTED_STRUCTURED_MODES:
raise ValueError(
f"Invalid structured_mode '{structured_mode}'. Must be one of {self.SUPPORTED_STRUCTURED_MODES}."
f"structured_mode must be one of {self.SUPPORTED_STRUCTURED_MODES}"
)
# 2) Disallow response_format + streaming
if response_format is not None:
raise ValueError("`response_format` is not supported by DaprChatClient.")
if kwargs.get("stream"):
raise ValueError("Streaming is not supported by DaprChatClient.")
# If input_data is provided, check for a prompt_template
# 3) Build messages via Prompty
if input_data:
if not self.prompt_template:
raise ValueError(
"Inputs are provided but no 'prompt_template' is set. Please set a 'prompt_template' to use the input_data."
)
logger.info("Using prompt template to generate messages.")
raise ValueError("input_data provided but no prompt_template is set.")
messages = self.prompt_template.format_prompt(**input_data)
# Ensure we have messages at this point
if not messages:
raise ValueError("Either 'messages' or 'input_data' must be provided.")
# Process and normalize the messages
params = {"inputs": RequestHandler.normalize_chat_messages(messages)}
# Merge Prompty parameters if available, then override with any explicit kwargs
# 4) Normalize + merge defaults
params: Dict[str, Any] = {
"inputs": RequestHandler.normalize_chat_messages(messages)
}
if self.prompty:
params = {**self.prompty.model.parameters.model_dump(), **params, **kwargs}
else:
params.update(kwargs)
# Prepare request parameters
# 5) Inject tools + structured directives
params = RequestHandler.process_params(
params,
llm_provider=self.provider,
@ -176,28 +204,32 @@ class DaprChatClient(DaprInferenceClientBase, ChatClientBase):
response_format=response_format,
structured_mode=structured_mode,
)
inputs = self.convert_to_conversation_inputs(params["inputs"])
# 6) Convert to Dapr inputs & call
conv_inputs = self.convert_to_conversation_inputs(params["inputs"])
try:
logger.info("Invoking the Dapr Conversation API.")
response = self.client.chat_completion(
raw = self.client.chat_completion(
llm=llm_component or self._llm_component,
conversation_inputs=inputs,
conversation_inputs=conv_inputs,
scrub_pii=scrubPII,
temperature=temperature,
)
transposed_response = self.translate_response(response, self._llm_component)
logger.info("Chat completion retrieved successfully.")
return ResponseHandler.process_response(
transposed_response,
llm_provider=self.provider,
response_format=response_format,
structured_mode=structured_mode,
stream=params.get("stream", False),
normalized = self.translate_response(
raw, llm_component or self._llm_component
)
logger.info("Chat completion retrieved successfully.")
except Exception as e:
logger.error(
f"An error occurred during the Dapr Conversation API call: {e}"
)
raise
# 7) Hand off to our unified handler (always non-stream)
return ResponseHandler.process_response(
response=normalized,
llm_provider=self.provider,
response_format=response_format,
structured_mode=structured_mode,
stream=False,
)
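
Minimal, hedged usage sketch for the client above (the component name and env var value are assumptions; a running Dapr sidecar with a conversation component is required):

import os
from dapr_agents.llm.dapr import DaprChatClient
from dapr_agents.types.message import LLMChatResponse

os.environ.setdefault("DAPR_LLM_COMPONENT_DEFAULT", "echo")  # component name is an assumption

client = DaprChatClient()
response: LLMChatResponse = client.generate(
    messages="Summarize the Dapr Conversation API in one sentence.",
    scrubPII=True,       # ask the component to obfuscate sensitive output
    temperature=0.2,
)
assistant = response.get_message()
print(assistant.content if assistant else "<no candidates>")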

View File

@ -0,0 +1,51 @@
import logging
import time
from typing import Any, Dict
from dapr_agents.types.message import (
AssistantMessage,
LLMChatCandidate,
LLMChatResponse,
)
logger = logging.getLogger(__name__)
def process_dapr_chat_response(response: Dict[str, Any]) -> LLMChatResponse:
"""
Convert a Dapr-normalized chat dict (with OpenAI-style 'choices') into a unified LLMChatResponse.
Args:
response: The dict returned by `DaprChatClient.translate_response`.
Returns:
LLMChatResponse: Contains a list of candidates and metadata.
"""
# 1) Extract each choice → build AssistantMessage + LLMChatCandidate
candidates = []
for choice in response.get("choices", []):
msg = choice.get("message", {})
assistant_message = AssistantMessage(
content=msg.get("content"),
# Dapr currently never returns refusals, tool_calls or function_call here
)
candidate = LLMChatCandidate(
message=assistant_message,
finish_reason=choice.get("finish_reason"),
# Dapr translate_response includes index & no logprobs
index=choice.get("index"),
logprobs=choice.get("logprobs"),
)
candidates.append(candidate)
# 2) Build metadata from the top-level fields
metadata: Dict[str, Any] = {
"provider": "dapr",
"id": response.get("id", None),
"model": response.get("model", None),
"object": response.get("object", None),
"usage": response.get("usage", {}),
"created": response.get("created", int(time.time())),
}
return LLMChatResponse(results=candidates, metadata=metadata)
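
For reference, a small illustration of the dict shape this helper consumes (the same OpenAI-style structure translate_response emits) and the unified object it returns; the field values are examples only:

normalized = {
    "choices": [
        {
            "index": 0,
            "finish_reason": "stop",
            "logprobs": None,
            "message": {"role": "assistant", "content": "Hello from Dapr"},
        }
    ],
    "model": "echo",            # example values; real calls fill these in
    "object": "chat.completion",
}

unified = process_dapr_chat_response(normalized)
print(unified.results[0].message.content)   # -> "Hello from Dapr"
print(unified.metadata["provider"])         # -> "dapr"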

View File

@ -13,34 +13,49 @@ from typing import (
Union,
)
from huggingface_hub import ChatCompletionOutput
from pydantic import BaseModel
from pydantic import BaseModel, Field
from dapr_agents.llm.chat import ChatClientBase
from dapr_agents.llm.huggingface.client import HFHubInferenceClientBase
from dapr_agents.llm.utils import RequestHandler, ResponseHandler
from dapr_agents.prompt.base import PromptTemplateBase
from dapr_agents.prompt.prompty import Prompty
from dapr_agents.tool import AgentTool
from dapr_agents.types.message import BaseMessage, ChatCompletion
from dapr_agents.types.message import (
BaseMessage,
LLMChatCandidateChunk,
LLMChatResponse,
)
logger = logging.getLogger(__name__)
class HFHubChatClient(HFHubInferenceClientBase, ChatClientBase):
"""
Concrete class for the Hugging Face Hub's chat completion API using the Inference API.
This class extends the ChatClientBase and provides the necessary configurations for Hugging Face models.
Chat client for Hugging Face Hub's Inference API.
Extends:
- HFHubInferenceClientBase: manages HF-specific auth, endpoints, retries.
- ChatClientBase: provides the `.from_prompty()` and `.generate()` contract.
Supports only function_call-style structured responses.
"""
SUPPORTED_STRUCTURED_MODES: ClassVar[set] = {"function_call"}
prompty: Optional[Prompty] = Field(
default=None, description="Optional Prompty instance for templating."
)
prompt_template: Optional[PromptTemplateBase] = Field(
default=None, description="Optional prompt-template to format inputs."
)
SUPPORTED_STRUCTURED_MODES: ClassVar[set[str]] = {"function_call"}
def model_post_init(self, __context: Any) -> None:
"""
Initializes private attributes for provider, api, config, and client after validation.
After Pydantic __init__, set the internal API type to "chat".
"""
# Set the private provider and api attributes
self._api = "chat"
return super().model_post_init(__context)
super().model_post_init(__context)
@classmethod
def from_prompty(
@ -49,34 +64,28 @@ class HFHubChatClient(HFHubInferenceClientBase, ChatClientBase):
timeout: Union[int, float, Dict[str, Any]] = 1500,
) -> "HFHubChatClient":
"""
Initializes an HFHubChatClient client using a Prompty source, which can be a file path or inline content.
Load a Prompty spec (file or inline), extract model config & prompt template,
and return a configured HFHubChatClient.
Args:
prompty_source (Union[str, Path]): The source of the Prompty file, which can be a path to a file
or inline Prompty content as a string.
timeout (Union[int, float, Dict[str, Any]], optional): Timeout for requests, defaults to 1500 seconds.
prompty_source: Path or inline text of a Prompty YAML/JSON.
timeout: Request timeout (seconds or HTTPX timeout dict).
Returns:
HFHubChatClient: An instance of HFHubChatClient configured with the model settings from the Prompty source.
HFHubChatClient: client ready for .generate() calls.
"""
# Load the Prompty instance from the provided source
prompty_instance = Prompty.load(prompty_source)
# Generate the prompt template from the Prompty instance
prompt_template = Prompty.to_prompt_template(prompty_instance)
cfg = prompty_instance.model.configuration
# Extract the model configuration from Prompty
model_config = prompty_instance.model
# Initialize the HFHubChatClient based on the Prompty model configuration
return cls.model_validate(
{
"model": model_config.configuration.name,
"api_key": model_config.configuration.api_key,
"base_url": model_config.configuration.base_url,
"headers": model_config.configuration.headers,
"cookies": model_config.configuration.cookies,
"proxies": model_config.configuration.proxies,
"model": cfg.name,
"api_key": cfg.api_key,
"base_url": cfg.base_url,
"headers": cfg.headers,
"cookies": cfg.cookies,
"proxies": cfg.proxies,
"timeout": timeout,
"prompty": prompty_instance,
"prompt_template": prompt_template,
@ -91,61 +100,72 @@ class HFHubChatClient(HFHubInferenceClientBase, ChatClientBase):
BaseMessage,
Iterable[Union[Dict[str, Any], BaseMessage]],
] = None,
*,
input_data: Optional[Dict[str, Any]] = None,
model: Optional[str] = None,
tools: Optional[List[Union[AgentTool, Dict[str, Any]]]] = None,
response_format: Optional[Type[BaseModel]] = None,
structured_mode: Literal["function_call"] = "function_call",
**kwargs,
) -> Union[Iterator[Dict[str, Any]], ChatCompletion]:
stream: bool = False,
**kwargs: Any, # accept any extra params, even if unused
) -> Union[
Iterator[LLMChatCandidateChunk],
LLMChatResponse,
BaseModel,
List[BaseModel],
]:
"""
Generate chat completions based on provided messages or input_data for prompt templates.
Issue a chat completion via Hugging Face's Inference API.
- If `stream=True` in **kwargs**, returns an iterator of `LLMChatCandidateChunk`.
- Otherwise returns either:
• raw `AssistantMessage` wrapped in `LLMChatResponse`, or
• validated Pydantic model(s) per `response_format`.
Args:
messages (Optional): Either pre-set messages or None if using input_data.
input_data (Optional[Dict[str, Any]]): Input variables for prompt templates.
model (str): Specific model to use for the request, overriding the default.
tools (List[Union[AgentTool, Dict[str, Any]]]): List of tools for the request.
response_format (Type[BaseModel]): Optional Pydantic model for structured response parsing.
structured_mode (Literal["function_call"]): Mode for structured output: "function_call" (Limited Support).
**kwargs: Additional parameters for the language model.
messages: Pre-built messages or None to use `input_data`.
input_data: Variables for the Prompty template.
model: Override the client's default model name.
tools: List of AgentTool or dict specs.
response_format: Pydantic model (or List[Model]) for structured output.
structured_mode: Must be `"function_call"` (only supported mode here).
stream: If True, return an iterator of `LLMChatCandidateChunk`.
**kwargs: Any other LLM params (temperature, top_p, stream, etc.).
Returns:
Union[Iterator[Dict[str, Any]], ChatCompletion]: The chat completion response(s).
"""
• `Iterator[LLMChatCandidateChunk]` if streaming
• `LLMChatResponse` or Pydantic instance(s) if non-streaming
Raises:
ValueError: on invalid `structured_mode`, missing prompts, or API errors.
"""
# 1) Validate structured_mode
if structured_mode not in self.SUPPORTED_STRUCTURED_MODES:
raise ValueError(
f"Invalid structured_mode '{structured_mode}'. Must be one of {self.SUPPORTED_STRUCTURED_MODES}."
f"structured_mode must be one of {self.SUPPORTED_STRUCTURED_MODES}"
)
# If input_data is provided, check for a prompt_template
# 2) If using a prompt template, build messages
if input_data:
if not self.prompt_template:
raise ValueError(
"Inputs are provided but no 'prompt_template' is set. Please set a 'prompt_template' to use the input_data."
)
logger.info("Using prompt template to generate messages.")
raise ValueError("No prompt_template set for input_data usage.")
logger.info("Formatting messages via prompt_template.")
messages = self.prompt_template.format_prompt(**input_data)
# Ensure we have messages at this point
if not messages:
raise ValueError("Either 'messages' or 'input_data' must be provided.")
raise ValueError("Either messages or input_data must be provided.")
# Process and normalize the messages
# 3) Normalize messages + merge client/prompty defaults
params = {"messages": RequestHandler.normalize_chat_messages(messages)}
# Merge Prompty parameters if available, then override with any explicit kwargs
if self.prompty:
params = {**self.prompty.model.parameters.model_dump(), **params, **kwargs}
else:
params.update(kwargs)
# If a model is provided, override the default model
# 4) Override model if given
params["model"] = model or self.model
# Prepare request parameters
# 5) Inject tools / response_format via RequestHandler
params = RequestHandler.process_params(
params,
llm_provider=self.provider,
@ -154,31 +174,28 @@ class HFHubChatClient(HFHubInferenceClientBase, ChatClientBase):
structured_mode=structured_mode,
)
# 6) Call HF API + delegate parsing to ResponseHandler
try:
logger.info("Invoking Hugging Face ChatCompletion API.")
response: ChatCompletionOutput = self.client.chat.completions.create(
**params
)
logger.info("Chat completion retrieved successfully.")
logger.info("Calling HF ChatCompletion Inference API...")
logger.debug(f"HF params: {params}")
response = self.client.chat.completions.create(**params, stream=stream)
logger.info("HF ChatCompletion response received.")
# Hugging Face error handling
status = getattr(response, "statuscode", 200)
# HF-specific error-code handling
code = getattr(response, "code", 200)
if code != 200:
logger.error(
f"❌ Status Code:{status} - Code:{code} Error: {getattr(response, 'message', response)}"
)
raise RuntimeError(
f"{status}/{code} Error: {getattr(response, 'message', response)}"
)
msg = getattr(response, "message", response)
logger.error(f"❌ HF error {code}: {msg}")
raise RuntimeError(f"HuggingFace error {code}: {msg}")
return ResponseHandler.process_response(
response,
response=response,
llm_provider=self.provider,
response_format=response_format,
structured_mode=structured_mode,
stream=params.get("stream", False),
stream=stream,
)
except Exception as e:
logger.error(f"An error occurred during the ChatCompletion API call: {e}")
raise
logger.error("Hugging Face ChatCompletion API error", exc_info=True)
raise ValueError("Failed to process HF chat completion") from e
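
Hedged usage sketch for the client above; the model name is an assumption and credentials are expected to be resolved by HFHubInferenceClientBase as usual:

from dapr_agents.llm.huggingface import HFHubChatClient

client = HFHubChatClient(model="meta-llama/Llama-3.1-8B-Instruct")  # model is an assumption

# Non-streaming: unified LLMChatResponse
reply = client.generate(messages="Name three Dapr building blocks.")
msg = reply.get_message()
print(msg.content if msg else "<no candidates>")

# Streaming: iterator of LLMChatCandidateChunk
for chunk in client.generate(messages="Write a haiku about workflows.", stream=True):
    if chunk.content:
        print(chunk.content, end="", flush=True)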

View File

@ -0,0 +1,260 @@
import dataclasses
import logging
from typing import Any, Callable, Dict, Iterator, Optional
from huggingface_hub import ChatCompletionOutput, ChatCompletionStreamOutput
from dapr_agents.types.message import (
AssistantMessage,
FunctionCall,
LLMChatCandidate,
LLMChatCandidateChunk,
LLMChatResponse,
LLMChatResponseChunk,
ToolCall,
ToolCallChunk,
)
logger = logging.getLogger(__name__)
# Helper function to handle metadata extraction
def _get_packet_metadata(
pkt: Dict[str, Any], enrich_metadata: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
"""
Extract metadata from HuggingFace packet and merge with enrich_metadata.
Args:
pkt (Dict[str, Any]): The HuggingFace packet from which to extract metadata.
enrich_metadata (Optional[Dict[str, Any]]): Additional metadata to merge with the extracted metadata.
Returns:
Dict[str, Any]: The merged metadata dictionary.
"""
try:
return {
"id": pkt.get("id"),
"created": pkt.get("created"),
"model": pkt.get("model"),
"object": pkt.get("object"),
"service_tier": pkt.get("service_tier"),
"system_fingerprint": pkt.get("system_fingerprint"),
**(enrich_metadata or {}),
}
except Exception as e:
logger.error(f"Failed to parse packet: {e}", exc_info=True)
return {}
# Helper function to process each choice delta (content, function call, tool call, finish reason)
def _process_choice_delta(
choice: Dict[str, Any],
overall_meta: Dict[str, Any],
on_chunk: Optional[Callable],
first_chunk_flag: bool,
) -> Iterator[LLMChatResponseChunk]:
"""
Process each choice delta and yield corresponding chunks.
Args:
choice (Dict[str, Any]): The choice delta from HuggingFace response.
overall_meta (Dict[str, Any]): Overall metadata to include in chunks.
on_chunk (Optional[Callable]): Callback for each chunk.
first_chunk_flag (bool): Flag indicating if this is the first chunk.
Yields:
LLMChatResponseChunk: The processed chunk with content, function call, tool calls,
"""
# Make an immutable snapshot for this single chunk
meta = {**overall_meta}
# mark first_chunk exactly once
if first_chunk_flag and "first_chunk" not in meta:
meta["first_chunk"] = True
# Extract initial properties from choice
delta: dict = choice.get("delta", {})
idx = choice.get("index")
finish_reason = choice.get("finish_reason", None)
logprobs = choice.get("logprobs", None)
# Set additional metadata
if finish_reason in ("stop", "tool_calls"):
meta["last_chunk"] = True
# Process content delta
content = delta.get("content", None)
function_call = delta.get("function_call", None)
refusal = delta.get("refusal", None)
role = delta.get("role", None)
# Process tool calls
chunk_tool_calls = [ToolCallChunk(**tc) for tc in (delta.get("tool_calls") or [])]
# Initialize LLMChatResponseChunk
response_chunk = LLMChatResponseChunk(
result=LLMChatCandidateChunk(
content=content,
function_call=function_call,
refusal=refusal,
role=role,
tool_calls=chunk_tool_calls,
finish_reason=finish_reason,
index=idx,
logprobs=logprobs,
),
metadata=meta,
)
# Process chunk with on_chunk callback
if on_chunk:
on_chunk(response_chunk)
# Yield LLMChatResponseChunk
yield response_chunk
# Main function to process HuggingFace streaming response
def process_hf_stream(
raw_stream: Iterator[ChatCompletionStreamOutput],
*,
enrich_metadata: Optional[Dict[str, Any]] = None,
on_chunk: Optional[Callable],
) -> Iterator[LLMChatCandidateChunk]:
"""
Normalize HuggingFace streaming chat into LLMChatCandidateChunk objects,
accumulating buffers per choice and yielding both partial and final chunks.
Args:
raw_stream: Iterator from client.chat.completions.create(..., stream=True)
enrich_metadata: Extra key/value pairs to merge into each chunk.metadata
on_chunk: Callback fired on every partial delta (token, function, tool)
Yields:
LLMChatCandidateChunk for every partial and final piece, in stream order
"""
enrich_metadata = enrich_metadata or {}
overall_meta: Dict[str, Any] = {}
# Track if we are in the first chunk
first_chunk_flag = True
for packet in raw_stream:
# Convert Pydantic / HuggingFace object → plain dict
if hasattr(packet, "model_dump"):
pkt = packet.model_dump()
elif hasattr(packet, "to_dict"):
pkt = packet.to_dict()
elif dataclasses.is_dataclass(packet):
pkt = dataclasses.asdict(packet)
else:
raise TypeError(f"Cannot serialize packet of type {type(packet)}")
# Capture overall metadata from the packet
overall_meta = _get_packet_metadata(pkt, enrich_metadata)
# Process each choice in this packet
if choices := pkt.get("choices"):
if len(choices) == 0:
logger.warning(
"Received empty 'choices' in HuggingFace packet, skipping."
)
continue
# Process the first choice in the packet
choice = choices[0]
yield from _process_choice_delta(
choice, overall_meta, on_chunk, first_chunk_flag
)
# Set first_chunk_flag to False after processing the first choice
first_chunk_flag = False
else:
logger.warning(f" Yielding packet without 'choices': {pkt}")
# Initialize default LLMChatResponseChunk
final_response_chunk = LLMChatResponseChunk(metadata=overall_meta)
yield final_response_chunk
def process_hf_chat_response(response: ChatCompletionOutput) -> LLMChatResponse:
"""
Convert a non-streaming Hugging Face ChatCompletionOutput into our unified LLMChatResponse.
This will:
1. Turn the HF dataclass into a plain dict via .model_dump() or .dict().
2. Extract each `choice`, build an AssistantMessage (including any tool_calls or
function_call shortcuts), wrap in LLMChatCandidate.
3. Collect top-level metadata (id, model, usage, etc.) into an LLMChatResponse.
Args:
response: The HFHubInferenceClientBase.chat.completions.create(...) output.
Returns:
An LLMChatResponse containing all chat candidates and metadata.
"""
# 1) serialise the HF object to a primitive dict
try:
if hasattr(response, "model_dump"):
resp: Dict[str, Any] = response.model_dump()
elif hasattr(response, "dict"):
resp: Dict[str, Any] = response.dict()
elif dataclasses.is_dataclass(response):
resp = dataclasses.asdict(response)
elif hasattr(response, "to_dict"):
resp = response.to_dict()
else:
raise TypeError(f"Cannot serialize object of type {type(response)}")
except Exception:
logger.exception("Failed to serialize HF chat response")
resp = {}
candidates = []
for choice in resp.get("choices", []):
msg = choice.get("message") or {}
# 2) build tool_calls list if present
tool_calls: Optional[list[ToolCall]] = None
if msg.get("tool_calls"):
tool_calls = []
for tc in msg["tool_calls"]:
try:
tool_calls.append(ToolCall(**tc))
except Exception:
logger.exception(f"Invalid HF tool_call entry: {tc}")
# 2b) handle the single-ID shortcut
if msg.get("tool_call_id") and not tool_calls:
# HF only sent you an ID; we turn that into a zero-arg function_call
fc = FunctionCall(name=msg["tool_call_id"], arguments="")
tool_calls = [
ToolCall(id=msg["tool_call_id"], type="function", function=fc)
]
# 3) promote first tool_call into function_call if desired
function_call = tool_calls[0].function if tool_calls else None
assistant = AssistantMessage(
content=msg.get("content"),
refusal=None,
tool_calls=tool_calls,
function_call=function_call,
)
candidates.append(
LLMChatCandidate(
message=assistant,
finish_reason=choice.get("finish_reason"),
index=choice.get("index"),
logprobs=choice.get("logprobs"),
)
)
# 4) collect overall metadata
metadata = {
"provider": "huggingface",
"id": resp.get("id"),
"model": resp.get("model"),
"created": resp.get("created"),
"system_fingerprint": resp.get("system_fingerprint"),
"usage": resp.get("usage"),
}
return LLMChatResponse(results=candidates, metadata=metadata)
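
Hedged sketch of wiring the streaming normalizer above to a raw Hugging Face stream (the InferenceClient setup and model name are assumptions; the import of process_hf_stream is omitted because the new module's path is not shown in this excerpt):

from huggingface_hub import InferenceClient

hf = InferenceClient(model="meta-llama/Llama-3.1-8B-Instruct")  # model/auth are assumptions
raw_stream = hf.chat.completions.create(
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)

def print_delta(chunk):
    # Each callback receives an LLMChatResponseChunk whose .result is a candidate chunk.
    if chunk.result and chunk.result.content:
        print(chunk.result.content, end="", flush=True)

for chunk in process_hf_stream(raw_stream, enrich_metadata={"request_id": "demo"}, on_chunk=print_delta):
    pass  # partial and final chunks also arrive here, in stream order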

View File

@ -13,44 +13,61 @@ from typing import (
Union,
)
from openai.types.chat import ChatCompletionMessage
from pydantic import BaseModel, Field
from dapr_agents.llm.chat import ChatClientBase
from dapr_agents.llm.nvidia.client import NVIDIAClientBase
from dapr_agents.llm.utils import RequestHandler, ResponseHandler
from dapr_agents.prompt.base import PromptTemplateBase
from dapr_agents.prompt.prompty import Prompty
from dapr_agents.tool import AgentTool
from dapr_agents.types.message import BaseMessage, ChatCompletion
from dapr_agents.types.message import (
BaseMessage,
LLMChatCandidateChunk,
LLMChatResponse,
)
logger = logging.getLogger(__name__)
class NVIDIAChatClient(NVIDIAClientBase, ChatClientBase):
"""
Chat client for NVIDIA chat models.
Combines NVIDIA client management with Prompty-specific functionality for handling chat completions.
Chat client for NVIDIA chat models, combining NVIDIA client management
with Prompty-specific prompt templates and unified request/response handling.
Inherits:
- NVIDIAClientBase: manages API key, base_url, retries, etc. for NVIDIA endpoints.
- ChatClientBase: provides chat-specific abstractions.
Attributes:
model: the model name to use (e.g. "meta/llama-3.1-8b-instruct").
max_tokens: maximum number of tokens to generate per call.
"""
model: str = Field(
default="meta/llama3-8b-instruct",
description="Model name to use. Defaults to 'meta/llama3-8b-instruct'.",
default="meta/llama-3.1-8b-instruct",
description="Model name to use. Defaults to 'meta/llama-3.1-8b-instruct'.",
)
max_tokens: Optional[int] = Field(
default=1024,
description=(
"The maximum number of tokens to generate in any given call. Must be an integer ≥ 1. Defaults to 1024."
"Maximum number of tokens to generate in a single call. "
"Must be ≥1; defaults to 1024."
),
)
prompty: Optional[Prompty] = Field(
default=None, description="Optional Prompty instance for templating."
)
prompt_template: Optional[PromptTemplateBase] = Field(
default=None, description="Optional prompt-template to format inputs."
)
SUPPORTED_STRUCTURED_MODES: ClassVar[set] = {"function_call"}
# NVIDIA currently only supports function_call structured output
SUPPORTED_STRUCTURED_MODES: ClassVar[set[str]] = {"function_call"}
def model_post_init(self, __context: Any) -> None:
"""
Initializes chat-specific attributes after validation.
Args:
__context (Any): Additional context for post-initialization (not used here).
After Pydantic init, configure the client for 'chat' API.
"""
self._api = "chat"
super().model_post_init(__context)
@ -58,30 +75,23 @@ class NVIDIAChatClient(NVIDIAClientBase, ChatClientBase):
@classmethod
def from_prompty(cls, prompty_source: Union[str, Path]) -> "NVIDIAChatClient":
"""
Initializes an NVIDIAChatClient client using a Prompty source, which can be a file path or inline content.
Build an NVIDIAChatClient from a Prompty spec (file path or inline YAML/JSON).
Args:
prompty_source (Union[str, Path]): The source of the Prompty file, which can be a path to a file
or inline Prompty content as a string.
prompty_source: Path or inline content of a Prompty specification.
Returns:
NVIDIAChatClient: An instance of NVIDIAChatClient configured with the model settings from the Prompty source.
Configured NVIDIAChatClient instance.
"""
# Load the Prompty instance from the provided source
prompty_instance = Prompty.load(prompty_source)
# Generate the prompt template from the Prompty instance
prompt_template = Prompty.to_prompt_template(prompty_instance)
cfg = prompty_instance.model.configuration
# Extract the model configuration from Prompty
model_config = prompty_instance.model
# Initialize the NVIDIAChatClient instance using model_validate
return cls.model_validate(
{
"model": model_config.configuration.name,
"api_key": model_config.configuration.api_key,
"base_url": model_config.configuration.base_url,
"model": cfg.name,
"api_key": cfg.api_key,
"base_url": cfg.base_url,
"prompty": prompty_instance,
"prompt_template": prompt_template,
}
@ -95,66 +105,74 @@ class NVIDIAChatClient(NVIDIAClientBase, ChatClientBase):
BaseMessage,
Iterable[Union[Dict[str, Any], BaseMessage]],
] = None,
*,
input_data: Optional[Dict[str, Any]] = None,
model: Optional[str] = None,
tools: Optional[List[Union[AgentTool, Dict[str, Any]]]] = None,
response_format: Optional[Type[BaseModel]] = None,
max_tokens: Optional[int] = None,
structured_mode: Literal["function_call"] = "function_call",
**kwargs,
) -> Union[Iterator[Dict[str, Any]], ChatCompletion]:
stream: bool = False,
**kwargs: Any,
) -> Union[
Iterator[LLMChatCandidateChunk], # streaming
LLMChatResponse, # nonstream + no format
BaseModel, # nonstream + single structured format
List[BaseModel], # nonstream + list structured format
]:
"""
Generate chat completions based on provided messages or input_data for prompt templates.
Issue a chat completion to NVIDIA.
- If `stream=True` in kwargs, returns an iterator of
`LLMChatCandidateChunk` via ResponseHandler.
- Otherwise returns either:
           - a raw `LLMChatResponse`, or
           - validated Pydantic model(s) per `response_format`.
Args:
messages (Optional): Either pre-set messages or None if using input_data.
input_data (Optional[Dict[str, Any]]): Input variables for prompt templates.
model (str): Specific model to use for the request, overriding the default.
tools (List[Union[AgentTool, Dict[str, Any]]]): List of tools for the request.
response_format (Type[BaseModel]): Optional Pydantic model for structured response parsing.
max_tokens (Optional[int]): The maximum number of tokens to generate. Defaults to the instance setting.
structured_mode (Literal["function_call"]): Mode for structured output: "function_call" (Limited Support).
**kwargs: Additional parameters for the language model.
messages: Pre-built messages or None to use `input_data`.
input_data: Variables for the Prompty template.
model: Override default model name.
tools: List of AgentTool or dict specs.
response_format: Pydantic model (or list thereof) for structured output.
max_tokens: Override default max_tokens.
structured_mode: Must be "function_call" (only supported mode).
stream: If True, return an iterator of `LLMChatCandidateChunk`.
**kwargs: Other LLM params (temperature, stream, etc.).
Returns:
Union[Iterator[Dict[str, Any]], ChatCompletion]: The chat completion response(s).
Raises:
ValueError: for invalid structured_mode or missing inputs.
"""
# 1) Validate structured_mode
if structured_mode not in self.SUPPORTED_STRUCTURED_MODES:
raise ValueError(
f"Invalid structured_mode '{structured_mode}'. Must be one of {self.SUPPORTED_STRUCTURED_MODES}."
f"structured_mode must be one of {self.SUPPORTED_STRUCTURED_MODES}"
)
# If input_data is provided, check for a prompt_template
# 2) If input_data is provided, format messages via Prompty
if input_data:
if not self.prompt_template:
raise ValueError(
"Inputs are provided but no 'prompt_template' is set. Please set a 'prompt_template' to use the input_data."
)
logger.info("Using prompt template to generate messages.")
raise ValueError("input_data provided but no prompt_template is set.")
logger.info("Formatting messages via prompt_template.")
messages = self.prompt_template.format_prompt(**input_data)
# Ensure we have messages at this point
if not messages:
raise ValueError("Either 'messages' or 'input_data' must be provided.")
# Process and normalize the messages
params = {"messages": RequestHandler.normalize_chat_messages(messages)}
# Merge prompty parameters if available, then override with any explicit kwargs
# 3) Normalize messages + merge client/prompty defaults
params: Dict[str, Any] = {
"messages": RequestHandler.normalize_chat_messages(messages)
}
if self.prompty:
params = {**self.prompty.model.parameters.model_dump(), **params, **kwargs}
else:
params.update(kwargs)
# If a model is provided, override the default model
# 4) Override model & max_tokens if provided
params["model"] = model or self.model
# Apply max_tokens if provided
params["max_tokens"] = max_tokens or self.max_tokens
# Prepare request parameters
# 5) Inject tools / response_format / structured_mode
params = RequestHandler.process_params(
params,
llm_provider=self.provider,
@ -163,21 +181,18 @@ class NVIDIAChatClient(NVIDIAClientBase, ChatClientBase):
structured_mode=structured_mode,
)
# 6) Call NVIDIA API + dispatch to ResponseHandler
try:
logger.info("Invoking ChatCompletion API.")
logger.debug(f"ChatCompletion API Parameters:{params}")
response: ChatCompletionMessage = self.client.chat.completions.create(
**params
)
logger.info("Chat completion retrieved successfully.")
logger.info("Calling NVIDIA ChatCompletion API.")
logger.debug(f"Parameters: {params}")
resp = self.client.chat.completions.create(**params, stream=stream)
return ResponseHandler.process_response(
response,
response=resp,
llm_provider=self.provider,
response_format=response_format,
structured_mode=structured_mode,
stream=params.get("stream", False),
stream=stream,
)
except Exception as e:
logger.error(f"An error occurred during the ChatCompletion API call: {e}")
except Exception:
logger.error("NVIDIA ChatCompletion error", exc_info=True)
raise
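For orientation, a minimal usage sketch of the refactored contract (the import path and credential setup are assumptions; the model name is the new default shown above):

```python
# Illustrative sketch only: the import path is assumed and NVIDIA credentials
# are expected to come from the environment.
from dapr_agents.llm.nvidia import NVIDIAChatClient  # assumed module path
from dapr_agents.types.message import LLMChatResponse

client = NVIDIAChatClient(model="meta/llama-3.1-8b-instruct")

# Non-streaming call: every provider now hands back a unified LLMChatResponse.
resp: LLMChatResponse = client.generate(
    messages=[{"role": "user", "content": "Say hello in one word."}]
)
print(resp.get_message().content)
# With stream=True the same call would instead yield chunk objects as they arrive.
```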

View File

@ -13,44 +13,58 @@ from typing import (
Union,
)
from openai.types.chat import ChatCompletionMessage
from pydantic import BaseModel, Field, model_validator
from dapr_agents.llm.chat import ChatClientBase
from dapr_agents.llm.openai.client.base import OpenAIClientBase
from dapr_agents.llm.utils import RequestHandler, ResponseHandler
from dapr_agents.prompt.base import PromptTemplateBase
from dapr_agents.prompt.prompty import Prompty
from dapr_agents.tool import AgentTool
from dapr_agents.types.llm import AzureOpenAIModelConfig, OpenAIModelConfig
from dapr_agents.types.message import BaseMessage, ChatCompletion
from dapr_agents.types.message import (
BaseMessage,
LLMChatCandidateChunk,
LLMChatResponse,
)
logger = logging.getLogger(__name__)
class OpenAIChatClient(OpenAIClientBase, ChatClientBase):
"""
Chat client for OpenAI models.
Combines OpenAI client management with Prompty-specific functionality.
Chat client for OpenAI models, layering in Prompty-driven prompt templates
and unified request/response handling.
Inherits:
- OpenAIClientBase: manages API key, base_url, retries, etc.
- ChatClientBase: provides chat-specific abstractions.
"""
model: str = Field(default=None, description="Model name to use, e.g., 'gpt-4'.")
model: Optional[str] = Field(
default=None, description="Model name or Azure deployment ID."
)
prompty: Optional[Prompty] = Field(
default=None, description="Optional Prompty instance for templating."
)
prompt_template: Optional[PromptTemplateBase] = Field(
default=None, description="Optional prompt-template to format inputs."
)
SUPPORTED_STRUCTURED_MODES: ClassVar[set] = {"json", "function_call"}
SUPPORTED_STRUCTURED_MODES: ClassVar[set[str]] = {"json", "function_call"}
@model_validator(mode="before")
def validate_and_initialize(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""
Ensures the 'model' is set during validation.
Uses 'azure_deployment' if no model is specified, defaults to 'gpt-4o'.
Ensure `.model` is always set. If unset, fall back to `azure_deployment`
or default to `"gpt-4o"`.
"""
if "model" not in values or values["model"] is None:
if not values.get("model"):
values["model"] = values.get("azure_deployment", "gpt-4o")
return values
def model_post_init(self, __context: Any) -> None:
"""
Initializes chat-specific attributes after validation.
"""
"""After Pydantic init, ensure we're in the “chat” API mode."""
self._api = "chat"
super().model_post_init(__context)
@ -61,60 +75,54 @@ class OpenAIChatClient(OpenAIClientBase, ChatClientBase):
timeout: Union[int, float, Dict[str, Any]] = 1500,
) -> "OpenAIChatClient":
"""
Initializes an OpenAIChatClient client using a Prompty source, which can be a file path or inline content.
Load a Prompty file (or inline YAML/JSON string), extract its
model configuration and prompt template, and return a fully-wired client.
Args:
prompty_source (Union[str, Path]): The source of the Prompty file, which can be a path to a file
or inline Prompty content as a string.
timeout (Union[int, float, Dict[str, Any]], optional): Timeout for requests, defaults to 1500 seconds.
prompty_source: path or inline text for a Prompty spec.
timeout: seconds or HTTPX-style timeout, defaults to 1500.
Returns:
OpenAIChatClient: An instance of OpenAIChatClient configured with the model settings from the Prompty source.
Configured OpenAIChatClient.
"""
# Load the Prompty instance from the provided source
prompty_instance = Prompty.load(prompty_source)
# Generate the prompt template from the Prompty instance
prompt_template = Prompty.to_prompt_template(prompty_instance)
cfg = prompty_instance.model.configuration
# Extract the model configuration from Prompty
model_config = prompty_instance.model
# Initialize the OpenAIChatClient instance using model_validate
if isinstance(model_config.configuration, OpenAIModelConfig):
return cls.model_validate(
{
"model": model_config.configuration.name,
"api_key": model_config.configuration.api_key,
"base_url": model_config.configuration.base_url,
"organization": model_config.configuration.organization,
"project": model_config.configuration.project,
common = {
"timeout": timeout,
"prompty": prompty_instance,
"prompt_template": prompt_template,
}
)
elif isinstance(model_config.configuration, AzureOpenAIModelConfig):
if isinstance(cfg, OpenAIModelConfig):
return cls.model_validate(
{
"model": model_config.configuration.azure_deployment,
"api_key": model_config.configuration.api_key,
"azure_endpoint": model_config.configuration.azure_endpoint,
"azure_deployment": model_config.configuration.azure_deployment,
"api_version": model_config.configuration.api_version,
"organization": model_config.configuration.organization,
"project": model_config.configuration.project,
"azure_ad_token": model_config.configuration.azure_ad_token,
"azure_client_id": model_config.configuration.azure_client_id,
"timeout": timeout,
"prompty": prompty_instance,
"prompt_template": prompt_template,
**common,
"model": cfg.name,
"api_key": cfg.api_key,
"base_url": cfg.base_url,
"organization": cfg.organization,
"project": cfg.project,
}
)
elif isinstance(cfg, AzureOpenAIModelConfig):
return cls.model_validate(
{
**common,
"model": cfg.azure_deployment,
"api_key": cfg.api_key,
"azure_endpoint": cfg.azure_endpoint,
"azure_deployment": cfg.azure_deployment,
"api_version": cfg.api_version,
"organization": cfg.organization,
"project": cfg.project,
"azure_ad_token": cfg.azure_ad_token,
"azure_client_id": cfg.azure_client_id,
}
)
else:
raise ValueError(
f"Unsupported model configuration type: {type(model_config.configuration)}"
)
raise ValueError(f"Unsupported model config: {type(cfg)}")
def generate(
self,
@ -124,61 +132,72 @@ class OpenAIChatClient(OpenAIClientBase, ChatClientBase):
BaseMessage,
Iterable[Union[Dict[str, Any], BaseMessage]],
] = None,
*,
input_data: Optional[Dict[str, Any]] = None,
model: Optional[str] = None,
tools: Optional[List[Union[AgentTool, Dict[str, Any]]]] = None,
response_format: Optional[Type[BaseModel]] = None,
structured_mode: Literal["json", "function_call"] = "json",
**kwargs,
) -> Union[Iterator[Dict[str, Any]], ChatCompletion]:
stream: bool = False,
**kwargs: Any,
) -> Union[
Iterator[LLMChatCandidateChunk],
LLMChatResponse,
BaseModel,
List[BaseModel],
]:
"""
Generate chat completions based on provided messages or input_data for prompt templates.
Issue a chat completion.
- If `stream=True` in params, returns an iterator of `LLMChatCandidateChunk`.
- Otherwise returns either:
           - raw `AssistantMessage` wrapped in `LLMChatResponse`, or
           - validated Pydantic model(s) per `response_format`.
Args:
messages (Optional): Either pre-set messages or None if using input_data.
input_data (Optional[Dict[str, Any]]): Input variables for prompt templates.
model (str): Specific model to use for the request, overriding the default.
tools (List[Union[AgentTool, Dict[str, Any]]]): List of tools for the request.
response_format (Type[BaseModel]): Optional Pydantic model for structured response parsing.
structured_mode (Literal["json", "function_call"]): Mode for structured output: "json" or "function_call".
**kwargs: Additional parameters for the language model.
messages: pre-built messages or None to use `input_data`.
input_data: variables for the Prompty template.
model: override client's default model.
tools: list of AgentTool or dict specs.
response_format: Pydantic model (or list thereof) for structured output.
structured_mode: json or function_call (non-stream only).
stream: if True, return an iterator of `LLMChatCandidateChunk`.
**kwargs: any other LLM params (temperature, top_p, stream, etc.).
Returns:
Union[Iterator[Dict[str, Any]], ChatCompletion]: The chat completion response(s).
"""
            - `Iterator[LLMChatCandidateChunk]` if streaming
            - `LLMChatResponse` or Pydantic instance(s) if non-streaming
Raises:
ValueError: on invalid `structured_mode`, missing prompts, etc.
"""
# 1) Validate structured_mode
if structured_mode not in self.SUPPORTED_STRUCTURED_MODES:
raise ValueError(
f"Invalid structured_mode '{structured_mode}'. Must be one of {self.SUPPORTED_STRUCTURED_MODES}."
f"structured_mode must be one of {self.SUPPORTED_STRUCTURED_MODES}"
)
# If input_data is provided, check for a prompt_template
# 2) If using a prompt template, build messages
if input_data:
if not self.prompt_template:
raise ValueError(
"Inputs are provided but no 'prompt_template' is set. Please set a 'prompt_template' to use the input_data."
)
logger.info("Using prompt template to generate messages.")
raise ValueError("No prompt_template set for input_data usage.")
logger.info("Formatting messages via prompt_template.")
messages = self.prompt_template.format_prompt(**input_data)
# Ensure we have messages at this point
if not messages:
raise ValueError("Either 'messages' or 'input_data' must be provided.")
raise ValueError("Either messages or input_data must be provided.")
# Process and normalize the messages
# 3) Normalize messages + merge client/prompty defaults
params = {"messages": RequestHandler.normalize_chat_messages(messages)}
# Merge prompty parameters if available, then override with any explicit kwargs
if self.prompty:
params = {**self.prompty.model.parameters.model_dump(), **params, **kwargs}
else:
params.update(kwargs)
# If a model is provided, override the default model
# 4) Override model if given
params["model"] = model or self.model
# Prepare request parameters
# 5) Let RequestHandler inject tools / response_format / structured_mode
params = RequestHandler.process_params(
params,
llm_provider=self.provider,
@ -187,21 +206,21 @@ class OpenAIChatClient(OpenAIClientBase, ChatClientBase):
structured_mode=structured_mode,
)
# 6) Call API + hand off to ResponseHandler
try:
logger.info("Invoking ChatCompletion API.")
logger.debug(f"ChatCompletion API Parameters: {params}")
response: ChatCompletionMessage = self.client.chat.completions.create(
**params, timeout=self.timeout
logger.info("Calling OpenAI ChatCompletion...")
logger.debug(f"ChatCompletion params: {params}")
resp = self.client.chat.completions.create(
**params, stream=stream, timeout=self.timeout
)
logger.info("Chat completion retrieved successfully.")
logger.info("ChatCompletion response received.")
return ResponseHandler.process_response(
response,
response=resp,
llm_provider=self.provider,
response_format=response_format,
structured_mode=structured_mode,
stream=params.get("stream", False),
stream=stream,
)
except Exception as e:
logger.error(f"An error occurred during the ChatCompletion API call: {e}")
raise
logger.error("ChatCompletion API error", exc_info=True)
raise ValueError("Failed to process chat completion") from e

View File

@ -0,0 +1,266 @@
import dataclasses
import logging
from typing import Any, Callable, Dict, Iterator, Optional
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from dapr_agents.types.message import (
AssistantMessage,
FunctionCall,
LLMChatCandidate,
LLMChatCandidateChunk,
LLMChatResponseChunk,
LLMChatResponse,
ToolCall,
ToolCallChunk,
)
logger = logging.getLogger(__name__)
# Helper function to handle metadata extraction
def _get_packet_metadata(
pkt: Dict[str, Any], enrich_metadata: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
"""
Extract metadata from OpenAI packet and merge with enrich_metadata.
Args:
pkt (Dict[str, Any]): The OpenAI packet from which to extract metadata.
enrich_metadata (Optional[Dict[str, Any]]): Additional metadata to merge with the extracted metadata.
Returns:
Dict[str, Any]: The merged metadata dictionary.
"""
try:
return {
"id": pkt.get("id"),
"created": pkt.get("created"),
"model": pkt.get("model"),
"object": pkt.get("object"),
"service_tier": pkt.get("service_tier"),
"system_fingerprint": pkt.get("system_fingerprint"),
**(enrich_metadata or {}),
}
except Exception as e:
logger.error(f"Failed to parse packet: {e}", exc_info=True)
return {}
# Helper function to process each choice delta (content, function call, tool call, finish reason)
def _process_choice_delta(
choice: Dict[str, Any],
overall_meta: Dict[str, Any],
on_chunk: Optional[Callable],
first_chunk_flag: bool,
) -> Iterator[LLMChatResponseChunk]:
"""
Process each choice delta and yield corresponding chunks.
Args:
choice (Dict[str, Any]): The choice delta from OpenAI response.
overall_meta (Dict[str, Any]): Overall metadata to include in chunks.
on_chunk (Optional[Callable]): Callback for each chunk.
first_chunk_flag (bool): Flag indicating if this is the first chunk.
Yields:
        LLMChatResponseChunk: The processed chunk with content, function call, tool calls, and finish reason.
"""
# Make an immutable snapshot for this single chunk
meta = {**overall_meta}
# mark first_chunk exactly once
if first_chunk_flag and "first_chunk" not in meta:
meta["first_chunk"] = True
# Extract initial properties from choice
delta: dict = choice.get("delta", {})
idx = choice.get("index")
finish_reason = choice.get("finish_reason", None)
logprobs = choice.get("logprobs", None)
# Set additional metadata
if finish_reason in ("stop", "tool_calls"):
meta["last_chunk"] = True
# Process content delta
content = delta.get("content", None)
function_call = delta.get("function_call", None)
refusal = delta.get("refusal", None)
role = delta.get("role", None)
# Process tool calls
chunk_tool_calls = [ToolCallChunk(**tc) for tc in (delta.get("tool_calls") or [])]
# Initialize LLMChatResponseChunk
response_chunk = LLMChatResponseChunk(
result=LLMChatCandidateChunk(
content=content,
function_call=function_call,
refusal=refusal,
role=role,
tool_calls=chunk_tool_calls,
finish_reason=finish_reason,
index=idx,
logprobs=logprobs,
),
metadata=meta,
)
# Process chunk with on_chunk callback
if on_chunk:
on_chunk(response_chunk)
# Yield LLMChatResponseChunk
yield response_chunk
# Main function to process OpenAI streaming response
def process_openai_stream(
raw_stream: Iterator[ChatCompletionChunk],
*,
enrich_metadata: Optional[Dict[str, Any]] = None,
on_chunk: Optional[Callable],
) -> Iterator[LLMChatCandidateChunk]:
"""
Normalize OpenAI streaming chat into LLMChatCandidateChunk objects,
accumulating buffers per choice and yielding both partial and final chunks.
Args:
raw_stream: Iterator from client.chat.completions.create(..., stream=True)
enrich_metadata: Extra key/value pairs to merge into each chunk.metadata
on_chunk: Callback fired on every partial delta (token, function, tool)
Yields:
LLMChatCandidateChunk for every partial and final piece, in stream order
"""
enrich_metadata = enrich_metadata or {}
overall_meta: Dict[str, Any] = {}
# Track if we are in the first chunk
first_chunk_flag = True
for packet in raw_stream:
        # Convert Pydantic / OpenAIObject → plain dict
if hasattr(packet, "model_dump"):
pkt = packet.model_dump()
elif hasattr(packet, "to_dict"):
pkt = packet.to_dict()
elif dataclasses.is_dataclass(packet):
pkt = dataclasses.asdict(packet)
else:
raise TypeError(f"Cannot serialize packet of type {type(packet)}")
# Capture overall metadata from the packet
overall_meta = _get_packet_metadata(pkt, enrich_metadata)
# Process each choice in this packet
if choices := pkt.get("choices"):
if len(choices) == 0:
logger.warning("Received empty 'choices' in OpenAI packet, skipping.")
continue
# Process the first choice in the packet
choice = choices[0]
yield from _process_choice_delta(
choice, overall_meta, on_chunk, first_chunk_flag
)
# Set first_chunk_flag to False after processing the first choice
first_chunk_flag = False
else:
logger.warning(f" Yielding packet without 'choices': {pkt}")
# Initialize default LLMChatResponseChunk
final_response_chunk = LLMChatResponseChunk(metadata=overall_meta)
yield final_response_chunk
def process_openai_chat_response(openai_response: ChatCompletion) -> LLMChatResponse:
"""
Convert an OpenAI ChatCompletion into our unified LLMChatResponse.
This function:
- Safely extracts each choice (skipping malformed ones)
- Builds an AssistantMessage with content/refusal/tool_calls/function_call
- Wraps into LLMChatCandidate (including index & logprobs)
- Collects provider metadata
Args:
openai_response: A Pydantic ChatCompletion from the OpenAI SDK.
Returns:
LLMChatResponse: Contains a list of candidates and a metadata dict.
"""
# 1) Turn into plain dict
try:
if hasattr(openai_response, "model_dump"):
resp = openai_response.model_dump()
elif hasattr(openai_response, "to_dict"):
resp = openai_response.to_dict()
elif dataclasses.is_dataclass(openai_response):
resp = dataclasses.asdict(openai_response)
else:
resp = dict(openai_response)
except Exception:
logger.exception("Failed to serialize OpenAI chat response")
resp = {}
candidates = []
for choice in resp.get("choices", []):
if "message" not in choice:
logger.warning(f"Skipping choice missing 'message': {choice}")
continue
msg = choice["message"]
# 2) Build tool_calls list if present
tool_calls = None
if msg.get("tool_calls"):
tool_calls = []
for tc in msg["tool_calls"]:
try:
tool_calls.append(
ToolCall(
id=tc["id"],
type=tc["type"],
function=FunctionCall(
name=tc["function"]["name"],
arguments=tc["function"]["arguments"],
),
)
)
except Exception as e:
logger.warning(f"Invalid tool_call entry {tc}: {e}")
# 3) Build function_call if present
function_call = None
if fc := msg.get("function_call"):
function_call = FunctionCall(
name=fc.get("name", ""),
arguments=fc.get("arguments", ""),
)
# 4) Assemble AssistantMessage
assistant_message = AssistantMessage(
content=msg.get("content"),
refusal=msg.get("refusal"),
tool_calls=tool_calls,
function_call=function_call,
)
# 5) Build candidate, including index & logprobs
candidate = LLMChatCandidate(
message=assistant_message,
finish_reason=choice.get("finish_reason"),
index=choice.get("index"),
logprobs=choice.get("logprobs"),
)
candidates.append(candidate)
# 6) Metadata: include provider tag
metadata: Dict[str, Any] = {
"provider": "openai",
"id": resp.get("id"),
"model": resp.get("model"),
"object": resp.get("object"),
"usage": resp.get("usage"),
"created": resp.get("created"),
}
return LLMChatResponse(results=candidates, metadata=metadata)
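Because the normalizer falls back to `dict(...)` for plain mappings, it can be exercised offline with a hand-written payload; the values below are made up:

```python
# Offline sketch: in real use the OpenAI SDK object is passed in, but a
# response-shaped dict takes the same path through the normalizer.
from dapr_agents.llm.openai.utils import process_openai_chat_response

fake_response = {
    "id": "chatcmpl-123",
    "model": "gpt-4o",
    "object": "chat.completion",
    "created": 1722200000,
    "usage": {"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7},
    "choices": [
        {
            "index": 0,
            "finish_reason": "stop",
            "message": {"role": "assistant", "content": "Hello!"},
        }
    ],
}

unified = process_openai_chat_response(fake_response)  # type: ignore[arg-type]
print(unified.get_message().content)  # -> "Hello!"
print(unified.metadata["provider"])   # -> "openai"
```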

View File

@ -1,89 +1,128 @@
import logging
from dataclasses import asdict, is_dataclass
from typing import Any, Dict, Iterator, Literal, Optional, Type, Union
from typing import (
Any,
Callable,
Iterator,
Literal,
Optional,
Type,
TypeVar,
Union,
)
from pydantic import BaseModel
from dapr_agents.llm.utils.stream import StreamHandler
from dapr_agents.llm.utils.structure import StructureHandler
from dapr_agents.types import ChatCompletion
from dapr_agents.types.message import (
LLMChatCandidateChunk,
LLMChatResponse,
)
logger = logging.getLogger(__name__)
T = TypeVar("T", bound=BaseModel)
class ResponseHandler:
"""
Handles the processing of responses from language models.
Handles both streaming and non-streaming chat completions from various LLM providers.
"""
@staticmethod
def process_response(
response: Any,
llm_provider: str,
response_format: Optional[Type[BaseModel]] = None,
response_format: Optional[Type[T]] = None,
structured_mode: Literal["json", "function_call"] = "json",
stream: bool = False,
) -> Union[Iterator[Dict[str, Any]], Dict[str, Any]]:
on_chunk: Optional[Callable[[LLMChatCandidateChunk], None]] = None,
) -> Union[
Iterator[LLMChatCandidateChunk], # when streaming
LLMChatResponse, # nonstream + no format
T, # nonstream + single structured format
list[T], # nonstream + list structured format
]:
"""
Process the response from the language model.
Process a chat completion.
- **Streaming** (`stream=True`):
Yields `LLMChatCandidateChunk` via `StreamHandler`, honoring `on_chunk` / `on_final`.
- **Non-streaming** (`stream=False`):
          1. Normalize provider envelope → `LLMChatResponse`.
2. If no `response_format` requested, return that `LLMChatResponse`.
3. Otherwise, extract the first assistant message, parse & validate it
against your Pydantic `response_format`, and return the model (or list).
Args:
response: The response object from the language model.
llm_provider: The LLM provider (e.g., 'openai').
response_format: A pydantic model to parse and validate the structured response.
structured_mode: The mode of the structured response: 'json' or 'function_call'.
stream: Whether the response is a stream.
response: Raw API return (stream iterator or full response object).
llm_provider: e.g. `"openai"`.
response_format: Optional Pydantic model (or `List[Model]`) for structured output.
structured_mode: `"json"` or `"function_call"` (only non-stream).
stream: Whether this is a streaming call.
on_chunk: Callback on every partial `LLMChatCandidateChunk`.
Returns:
Union[Iterator[Dict[str, Any]], Dict[str, Any]]: The processed response.
            - **streaming**: `Iterator[LLMChatCandidateChunk]`
            - **non-stream + no format**: full `LLMChatResponse`
            - **non-stream + format**: validated Pydantic model instance or `List[...]`
"""
provider = llm_provider.lower()
# ─── Streaming ─────────────────────────────────────────────────────────
if stream:
return StreamHandler.process_stream(
stream=response,
llm_provider=llm_provider,
response_format=response_format,
llm_provider=provider,
on_chunk=on_chunk,
)
else:
if response_format:
structured_response_json = StructureHandler.extract_structured_response(
response=response,
        # ─── Non-streaming ────────────────────────────────────────────────────
# 1) Normalize full response → LLMChatResponse
if provider in ("openai", "nvidia"):
from dapr_agents.llm.openai.utils import process_openai_chat_response
llm_resp: LLMChatResponse = process_openai_chat_response(response)
elif provider == "huggingface":
from dapr_agents.llm.huggingface.utils import process_hf_chat_response
llm_resp = process_hf_chat_response(response)
elif provider == "dapr":
from dapr_agents.llm.dapr.utils import process_dapr_chat_response
llm_resp = process_dapr_chat_response(response)
else:
# if you add more providers, handle them here
llm_resp = response # type: ignore
# 2) If no structured format requested, return the full response
if response_format is None:
return llm_resp
        # 3) They did request a Pydantic model → extract first assistant message
first_candidate = next(iter(llm_resp.results), None)
if not first_candidate:
raise ValueError("No candidates in LLMChatResponse")
assistant = first_candidate.message
        # 3a) Get the raw JSON or function-call payload
raw = StructureHandler.extract_structured_response(
message=assistant,
llm_provider=llm_provider,
structured_mode=structured_mode,
)
# Normalize format and resolve actual model class
normalized_format = StructureHandler.normalize_iterable_format(
response_format
)
model_cls = StructureHandler.resolve_response_model(normalized_format)
if not model_cls:
        # 3b) Wrap List[Model] → IterableModel if needed
fmt = StructureHandler.normalize_iterable_format(response_format)
# 3c) Ensure exactly one Pydantic model inside
model_cls = StructureHandler.resolve_response_model(fmt)
if model_cls is None:
raise TypeError(
f"Could not resolve a valid Pydantic model from response_format: {response_format}"
f"Cannot resolve a Pydantic model from {response_format!r}"
)
structured_response_instance = StructureHandler.validate_response(
structured_response_json, normalized_format
)
        # 3d) Validate JSON/dict → Pydantic
validated = StructureHandler.validate_response(raw, fmt)
logger.info("Structured output successfully validated.")
logger.info("Structured output was successfully validated.")
if hasattr(structured_response_instance, "objects"):
return structured_response_instance.objects
return structured_response_instance
# Convert response to dictionary
if isinstance(response, dict):
# Already a dictionary
response_dict = response
elif is_dataclass(response):
# Dataclass instance
response_dict = asdict(response)
elif isinstance(response, BaseModel):
# Pydantic object
response_dict = response.model_dump()
else:
raise ValueError(f"Unsupported response type: {type(response)}")
completion = ChatCompletion(**response_dict)
logger.debug(f"Chat completion response: {completion}")
return completion
        # 3e) If it's our auto-wrapped iterable model, return its `.objects` list
return getattr(validated, "objects", validated)
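The same normalization is reachable through the handler itself; a small sketch of the non-streaming, no-format branch (payload values are made up). With a `response_format`, the docstring above says validated model instance(s) would come back instead:

```python
# Sketch: no response_format, so the handler returns the normalized LLMChatResponse.
from dapr_agents.llm.utils import ResponseHandler

payload = {
    "id": "chatcmpl-1",
    "model": "gpt-4o",
    "object": "chat.completion",
    "created": 1722200000,
    "usage": {},
    "choices": [
        {
            "index": 0,
            "finish_reason": "stop",
            "message": {"role": "assistant", "content": "Hi there."},
        }
    ],
}

unified = ResponseHandler.process_response(
    response=payload,
    llm_provider="openai",
    stream=False,
)
print(type(unified).__name__)         # -> "LLMChatResponse"
print(unified.get_message().content)  # -> "Hi there."
```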

View File

@ -1,241 +1,58 @@
from typing import (
Dict,
Any,
Callable,
Iterator,
Type,
TypeVar,
Union,
Optional,
Iterable,
get_args,
TypeVar,
)
from dapr_agents.llm.utils.structure import StructureHandler
from dapr_agents.types import ToolCall
from openai.types.chat import ChatCompletionChunk
from pydantic import BaseModel, ValidationError
import logging
logger = logging.getLogger(__name__)
from openai.types.chat import ChatCompletionChunk
from pydantic import BaseModel
from dapr_agents.types.message import LLMChatCandidateChunk
T = TypeVar("T", bound=BaseModel)
class StreamHandler:
"""
Handles streaming of chat completion responses, processing tool calls and content responses.
Handles streaming of chat completion responses, delegating to the
provider-specific stream processor and optionally validating output
against Pydantic models.
"""
@staticmethod
def process_stream(
stream: Iterator[Dict[str, Any]],
stream: Iterator[ChatCompletionChunk],
llm_provider: str,
response_format: Optional[Union[Type[T], Type[Iterable[T]]]] = None,
) -> Iterator[Dict[str, Any]]:
on_chunk: Optional[Callable],
) -> Iterator[LLMChatCandidateChunk]:
"""
Stream chat completion responses.
Process a streaming chat completion.
Args:
stream: The response stream from the API.
llm_provider: The LLM provider to use (e.g., 'openai').
response_format: The optional Pydantic model or iterable model for validating the response.
stream: Iterator of ChatCompletionChunk from OpenAI SDK.
llm_provider: Name of the LLM provider (e.g., "openai").
on_chunk: Callback fired on every partial LLMChatCandidateChunk.
Yields:
dict: Each processed and validated chunk from the chat completion response.
LLMChatCandidateChunk: fully-typed chunks, partial and final.
"""
logger.info("Streaming response enabled.")
try:
if llm_provider == "openai":
yield from StreamHandler._process_openai_stream(stream, response_format)
if llm_provider in ("openai", "nvidia"):
from dapr_agents.llm.openai.utils import process_openai_stream
yield from process_openai_stream(
raw_stream=stream,
enrich_metadata={"provider": llm_provider},
on_chunk=on_chunk,
)
elif llm_provider == "huggingface":
from dapr_agents.llm.huggingface.utils import process_hf_stream
yield from process_hf_stream(
raw_stream=stream,
enrich_metadata={"provider": llm_provider},
on_chunk=on_chunk,
)
else:
yield from stream
except Exception as e:
logger.error(f"An error occurred during streaming: {e}")
raise
@staticmethod
def _process_openai_stream(
stream: Iterator[Dict[str, Any]],
response_format: Optional[Union[Type[T], Type[Iterable[T]]]] = None,
) -> Iterator[Dict[str, Any]]:
"""
Process OpenAI stream for chat completion.
Args:
stream: The response stream from the OpenAI API.
response_format: The optional Pydantic model or iterable model for validating the response.
Yields:
dict: Each processed and validated chunk from the chat completion response.
"""
content_accumulator = ""
json_extraction_active = False
json_brace_level = 0
json_string_buffer = ""
tool_calls = {}
for chunk in stream:
processed_chunk = StreamHandler._process_openai_chunk(chunk)
chunk_type = processed_chunk["type"]
chunk_data = processed_chunk["data"]
if chunk_type == "content":
content_accumulator += chunk_data
yield processed_chunk
elif chunk_type in ["tool_calls", "function_call"]:
for tool_chunk in chunk_data:
tool_call_index = tool_chunk["index"]
tool_call_id = tool_chunk["id"]
tool_call_function = tool_chunk["function"]
tool_call_arguments = tool_call_function["arguments"]
if tool_call_id is not None:
tool_calls.setdefault(
tool_call_index,
{
"id": tool_call_id,
"type": tool_chunk["type"],
"function": {
"name": tool_call_function["name"],
"arguments": tool_call_arguments,
},
},
)
# Add tool call arguments to current tool calls
tool_calls[tool_call_index]["function"][
"arguments"
] += tool_call_arguments
# Process Iterable model if provided
if (
response_format
and isinstance(response_format, Iterable) is True
):
trimmed_character = tool_call_arguments.strip()
# Check beginning of List
if trimmed_character == "[" and json_extraction_active is False:
json_extraction_active = True
# Check beginning of a JSON object
elif (
trimmed_character == "{" and json_extraction_active is True
):
json_brace_level += 1
json_string_buffer += trimmed_character
# Check the end of a JSON object
elif (
"}" in trimmed_character and json_extraction_active is True
):
json_brace_level -= 1
json_string_buffer += trimmed_character.rstrip(",")
if json_brace_level == 0:
yield from StreamHandler._validate_json_object(
response_format, json_string_buffer
)
# Reset buffers and counts
json_string_buffer = ""
elif json_extraction_active is True:
json_string_buffer += tool_call_arguments
if content_accumulator:
yield {"type": "final_content", "data": content_accumulator}
if tool_calls:
yield from StreamHandler._get_final_tool_calls(tool_calls, response_format)
@staticmethod
def _process_openai_chunk(chunk: ChatCompletionChunk) -> Dict[str, Any]:
"""
Process OpenAI chat completion chunk.
Args:
chunk: The chunk from the OpenAI API.
Returns:
dict: Processed chunk.
"""
try:
chunk_dict = chunk.model_dump()
if chunk_dict.get("choices") and len(chunk_dict["choices"]) > 0:
choice: Dict = chunk_dict["choices"][0]
delta: Dict = choice.get("delta", {})
# Process content
if delta.get("content") is not None:
return {"type": "content", "data": delta["content"], "chunk": chunk}
# Process tool calls
if delta.get("tool_calls"):
return {
"type": "tool_calls",
"data": delta["tool_calls"],
"chunk": chunk,
}
# Process function calls
if delta.get("function_call"):
return {
"type": "function_call",
"data": delta["function_call"],
"chunk": chunk,
}
# Process finish reason
if choice.get("finish_reason"):
return {
"type": "finish",
"data": choice["finish_reason"],
"chunk": chunk,
}
return {}
except Exception as e:
logger.error(f"Error handling OpenAI chat completion chunk: {e}")
raise
@staticmethod
def _validate_json_object(
response_format: Optional[Union[Type[T], Type[Iterable[T]]]],
json_string_buffer: str,
):
try:
model_class = get_args(response_format)[0]
# Return current tool call
structured_output = StructureHandler.validate_response(
json_string_buffer, model_class
)
if isinstance(structured_output, model_class):
logger.info("Structured output was successfully validated.")
yield {"type": "structured_output", "data": structured_output}
except ValidationError as validation_error:
logger.error(
f"Validation error: {validation_error} with JSON: {json_string_buffer}"
)
@staticmethod
def _get_final_tool_calls(
tool_calls: Dict[int, Any],
response_format: Optional[Union[Type[T], Type[Iterable[T]]]],
) -> Iterator[Dict[str, Any]]:
"""
Yield final tool calls after processing.
Args:
tool_calls: The dictionary of accumulated tool calls.
response_format: The response model for validation.
Yields:
dict: Each processed and validated tool call.
"""
for tool in tool_calls.values():
if response_format and isinstance(response_format, Iterable) is False:
structured_output = StructureHandler.validate_response(
tool["function"]["arguments"], response_format
)
if isinstance(structured_output, response_format):
logger.info("Structured output was successfully validated.")
yield {"type": "structured_output", "data": structured_output}
else:
tool_call = ToolCall(**tool)
yield {"type": "final_tool_call", "data": tool_call}
raise ValueError(f"Streaming not supported for provider: {llm_provider}")

View File

@ -18,7 +18,12 @@ from typing import (
from pydantic import BaseModel, Field, TypeAdapter, ValidationError, create_model
from dapr_agents.tool.utils.function_calling import to_function_call_definition
from dapr_agents.types import OAIJSONSchema, OAIResponseFormatSchema, StructureError
from dapr_agents.types import (
AssistantMessage,
OAIJSONSchema,
OAIResponseFormatSchema,
StructureError,
)
logger = logging.getLogger(__name__)
@ -208,7 +213,7 @@ class StructureHandler:
@staticmethod
def extract_structured_response(
response: Any,
message: AssistantMessage,
llm_provider: str,
structured_mode: Literal["json", "function_call"] = "json",
) -> Union[str, Dict[str, Any]]:
@ -216,7 +221,7 @@ class StructureHandler:
Extracts the structured JSON string or content from the response.
Args:
response (Any): The API response data to extract.
message (AssistantMessage): The API response data to extract.
llm_provider (str): The LLM provider (e.g., 'openai').
structured_mode (Literal["json", "function_call"]): The structured response mode.
@ -228,17 +233,7 @@ class StructureHandler:
"""
try:
logger.debug(f"Processing structured response for mode: {structured_mode}")
if llm_provider in ("openai", "nvidia"):
# Extract the `choices` list from the response
choices = getattr(response, "choices", None)
if not choices or not isinstance(choices, list):
raise StructureError("Response does not contain valid 'choices'.")
# Extract the message object
message = getattr(choices[0], "message", None)
if not message:
raise StructureError("Response message is missing.")
if llm_provider in ("openai", "nvidia", "huggingface"):
if structured_mode == "function_call":
tool_calls = getattr(message, "tool_calls", None)
if tool_calls:

View File

@ -55,9 +55,9 @@ class DaprHTTPClient(BaseModel):
otel_enabled = False
if otel_enabled:
from dapr_agents.agents.telemetry.otel import DaprAgentsOTel # type: ignore[import-not-found]
from dapr_agents.agents.telemetry.otel import DaprAgentsOtel # type: ignore[import-not-found]
otel_client = DaprAgentsOTel(
otel_client = DaprAgentsOtel(
service_name=os.getenv("OTEL_SERVICE_NAME", "dapr-http-client"),
otlp_endpoint=os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT", ""),
)

View File

@ -1,13 +1,14 @@
import logging
from typing import Any, Dict, Optional
from pydantic import BaseModel, ValidationError
from dapr_agents.types import (
ClaudeToolDefinition,
OAIFunctionDefinition,
OAIToolDefinition,
ClaudeToolDefinition,
)
from dapr_agents.types.exceptions import FunCallBuilderError
from pydantic import BaseModel, ValidationError
from typing import Dict, Any, Optional
import logging
logger = logging.getLogger(__name__)
@ -130,72 +131,97 @@ def to_function_call_definition(
args_schema: BaseModel,
format_type: str = "openai",
use_deprecated: bool = False,
) -> Dict:
) -> Dict[str, Any]:
"""
Generates a dictionary representing a function call specification, supporting various API formats.
The 'use_deprecated' flag is applicable only for the 'openai' format and is ignored for others.
- For format_type in ("openai", "nvidia", "huggingface"), produces an OpenAI-style
tool definition (type="function", function={}).
- For "claude", produces a Claude-style {name, description, input_schema}.
    - (Gemini omitted here; call to_gemini_function_call_definition if you need it.)
The 'use_deprecated' flag is only applicable for OpenAI-style definitions.
Args:
name (str): The name of the function.
description (str): A brief description of what the function does.
args_schema (BaseModel): The Pydantic schema representing the function's parameters.
format_type (str, optional): The API format to convert to ('openai', 'claude', or 'gemini'). Defaults to 'openai'.
use_deprecated (bool): Flag to use the deprecated function format, only effective for 'openai'.
args_schema (BaseModel): The Pydantic model describing the function's parameters.
format_type (str, optional): Which API flavor to target:
- "openai", "nvidia", or "huggingface" all share the same OpenAI-style schema.
- "claude" uses Anthropic's format.
Defaults to "openai".
use_deprecated (bool): If True and format_type is OpenAI,
returns the old function-only schema rather than a tool wrapper.
Returns:
Dict: A dictionary containing the function definition in the specified format.
Dict[str, Any]: The serialized function/tool definition.
Raises:
FunCallBuilderError: If an unsupported format type is specified.
FunCallBuilderError: If an unsupported format_type is provided.
"""
if format_type.lower() in ("openai", "nvidia"):
fmt = format_type.lower()
    # OpenAI-style wrapper schema:
if fmt in ("openai", "nvidia", "huggingface"):
return to_openai_function_call_definition(
name, description, args_schema, use_deprecated
)
elif format_type.lower() == "claude":
# Anthropic Claude needs its own input_schema property
if fmt == "claude":
if use_deprecated:
logger.warning(
f"'use_deprecated' flag is ignored for the '{format_type}' format."
)
return to_claude_function_call_definition(name, description, args_schema)
else:
# Unsupported provider
logger.error(f"Unsupported format type: {format_type}")
raise FunCallBuilderError(f"Unsupported format type: {format_type}")
def validate_and_format_tool(
tool: Dict[str, Any], tool_format: str = "openai", use_deprecated: bool = False
) -> dict:
) -> Dict[str, Any]:
"""
Validates and formats a tool (provided as a dictionary) based on the specified API request format.
Validates and formats a tool definition dict for the specified API style.
- For tool_format in ("openai", "azure_openai", "nvidia", "huggingface"),
uses OAIToolDefinition (or OAIFunctionDefinition if use_deprecated=True).
- For "claude", uses ClaudeToolDefinition.
- For "llama", treats as an OAIFunctionDefinition.
Args:
tool: The tool to validate and format.
tool_format: The API format to convert to ('openai', 'azure_openai', 'claude', 'llama').
use_deprecated: Whether to use deprecated functions format for OpenAI. Defaults to False.
tool (Dict[str, Any]): The raw tool definition.
tool_format (str): Which API schema to validate against:
"openai", "azure_openai", "nvidia", "huggingface", "claude", "llama".
use_deprecated (bool): If True and using OpenAI-style, expects an OAIFunctionDefinition.
Returns:
dict: The formatted tool dictionary.
Dict[str, Any]: The validated, serialized tool definition.
Raises:
ValueError: If the tool definition format is invalid.
ValidationError: If the tool doesn't pass validation.
ValueError: If the format is unsupported or validation fails.
"""
fmt = tool_format.lower()
try:
if tool_format in ["openai", "azure_openai", "nvidia"]:
validated_tool = (
if fmt in ("openai", "azure_openai", "nvidia", "huggingface"):
validated = (
OAIFunctionDefinition(**tool)
if use_deprecated
else OAIToolDefinition(**tool)
)
elif tool_format == "claude":
validated_tool = ClaudeToolDefinition(**tool)
elif tool_format == "llama":
validated_tool = OAIFunctionDefinition(**tool)
elif fmt == "claude":
validated = ClaudeToolDefinition(**tool)
elif fmt == "llama":
validated = OAIFunctionDefinition(**tool)
else:
logger.error(f"Unsupported tool format: {tool_format}")
raise ValueError(f"Unsupported tool format: {tool_format}")
return validated_tool.model_dump()
return validated.model_dump()
except ValidationError as e:
logger.error(f"Validation error for {tool_format} tool definition: {e}")
raise ValueError(f"Invalid tool definition format: {tool}")

View File

@ -12,8 +12,6 @@ from .message import (
AssistantFinalMessage,
AssistantMessage,
BaseMessage,
ChatCompletion,
Choice,
EventMessageMetadata,
FunctionCall,
MessageContent,
@ -22,6 +20,8 @@ from .message import (
ToolCall,
ToolMessage,
UserMessage,
LLMChatResponse,
LLMChatCandidate,
)
from .schemas import OAIJSONSchema, OAIResponseFormatSchema
from .tools import (
@ -47,8 +47,8 @@ __all__ = [
"AssistantFinalMessage",
"AssistantMessage",
"BaseMessage",
"ChatCompletion",
"Choice",
"LLMChatResponse",
"LLMChatCandidate",
"EventMessageMetadata",
"FunctionCall",
"MessageContent",

View File

@ -5,7 +5,7 @@ from pydantic import (
model_validator,
ConfigDict,
)
from typing import List, Optional, Dict
from typing import List, Optional, Dict, Any
import json
@ -118,6 +118,36 @@ class ToolCall(BaseModel):
function: FunctionCall
class FunctionCallChunk(BaseModel):
"""
Represents a function call chunk in a streaming response, containing the function name and arguments.
Attributes:
name (str): The name of the function being called.
arguments (str): The JSON string representation of the function's arguments.
"""
name: Optional[str] = None
arguments: Optional[str] = None
class ToolCallChunk(BaseModel):
"""
Represents a tool call chunk in a streaming response, containing the index, ID, type, and function call details.
Attributes:
index (int): The index of the tool call in the response.
id (str): Unique identifier for the tool call.
type (str): The type of the tool call.
function (FunctionCallChunk): The function call details associated with the tool call.
"""
index: int
id: Optional[str] = None
type: Optional[str] = None
function: FunctionCallChunk
class MessageContent(BaseMessage):
"""
Extends BaseMessage to include dynamic optional fields for tool calls, function calls, and tool call IDs.
@ -148,73 +178,6 @@ class MessageContent(BaseMessage):
return self
class Choice(BaseModel):
"""
Represents a choice made by the model, detailing the reason for completion, its index, and the message content.
Attributes:
finish_reason (str): Reason why the model stopped generating text.
index (int): Index of the choice in a list of potential choices.
message (MessageContent): Content of the message chosen by the model.
logprobs (Optional[dict]): Log probabilities associated with the choice.
"""
finish_reason: str
index: int
message: MessageContent
logprobs: Optional[dict]
class ChatCompletion(BaseModel):
"""
Represents the full response from the chat API, including all choices, metadata, and usage information.
Attributes:
choices (List[Choice]): List of choices provided by the model.
created (int): Timestamp when the response was created.
id (str): Unique identifier for the response.
model (str): Model used for generating the response.
object (str): Type of object returned.
usage (dict): Information about API usage for this request.
"""
choices: List[Choice]
created: int
id: Optional[str] = None
model: str
object: Optional[str] = None
usage: dict
def get_message(self) -> Optional[dict]:
"""
Retrieve the main message content from the first choice.
"""
return self.choices[0].message.model_dump() if self.choices else None
def get_reason(self) -> Optional[str]:
"""
Retrieve the reason for completion from the first choice.
"""
return self.choices[0].finish_reason if self.choices else None
def get_tool_calls(self) -> Optional[List[ToolCall]]:
"""
Retrieve tool calls from the first choice, if available.
"""
return (
self.choices[0].message.tool_calls
if self.choices and self.choices[0].message.tool_calls
else None
)
def get_content(self) -> Optional[str]:
"""
Retrieve the content from the first choice's message.
"""
message = self.get_message()
return message.get("content") if message else None
class SystemMessage(BaseMessage):
"""
Represents a system message, automatically assigning the role to 'system'.
@ -249,6 +212,7 @@ class AssistantMessage(BaseMessage):
"""
role: str = "assistant"
refusal: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = None
function_call: Optional[FunctionCall] = None
@ -256,7 +220,7 @@ class AssistantMessage(BaseMessage):
def remove_empty_calls(self):
attrList = []
for attribute in self.__dict__:
if attribute in ("tool_calls", "function_call"):
if attribute in ("tool_calls", "function_call", "refusal"):
if self.__dict__[attribute] is None:
attrList.append(attribute)
@ -265,6 +229,28 @@ class AssistantMessage(BaseMessage):
return self
def get_tool_calls(self) -> Optional[List[ToolCall]]:
"""
Retrieve tool calls from the message if available.
"""
if getattr(self, "tool_calls", None) is None:
return None
if isinstance(self.tool_calls, list):
return self.tool_calls
if isinstance(self.tool_calls, ToolCall):
return [self.tool_calls]
def has_tool_calls(self) -> bool:
"""
Check if the message has tool calls.
"""
if not hasattr(self, "tool_calls"):
return False
if self.tool_calls is not None:
return True
if isinstance(self.tool_calls, ToolCall):
return True
class ToolMessage(BaseMessage):
"""
@ -279,6 +265,67 @@ class ToolMessage(BaseMessage):
tool_call_id: str
class LLMChatCandidate(BaseModel):
"""
Represents a single candidate (output) from an LLM chat response.
Allows provider-specific extra fields (e.g., index, logprobs, etc.).
Attributes:
message (AssistantMessage): The assistant's message for this candidate.
finish_reason (Optional[str]): Why the model stopped generating text.
[Any other provider-specific fields, e.g., index, logprobs, etc.]
"""
message: AssistantMessage
finish_reason: Optional[str] = None
class Config:
extra = "allow"
class LLMChatResponse(BaseModel):
"""
Unified response for LLM chat completions, supporting multiple providers.
Attributes:
results (List[LLMChatCandidate]): List of candidate outputs.
metadata (dict): Provider/model metadata (id, model, usage, etc.).
"""
results: List[LLMChatCandidate]
metadata: dict = {}
def get_message(self) -> Optional[AssistantMessage]:
"""
Retrieves the first message from the results if available.
"""
return self.results[0].message if self.results else None
class LLMChatCandidateChunk(BaseModel):
"""
Represents a partial (streamed) candidate from an LLM provider, for real-time streaming.
"""
content: Optional[str] = None
function_call: Optional[Dict[str, Any]] = None
refusal: Optional[str] = None
role: Optional[str] = None
tool_calls: Optional[List["ToolCallChunk"]] = None
finish_reason: Optional[str] = None
index: Optional[int] = None
logprobs: Optional[dict] = None
class LLMChatResponseChunk(BaseModel):
"""
Represents a partial (streamed) response from an LLM provider, for real-time streaming.
"""
result: LLMChatCandidateChunk
metadata: Optional[dict] = None
class AssistantFinalMessage(BaseModel):
"""
Represents a custom final message from the assistant, encapsulating a conclusive response to the user.
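To make the new unified types concrete, a small offline sketch that builds a response by hand (all values are made up):

```python
# Offline sketch of the unified response models introduced above.
from dapr_agents.types.message import (
    AssistantMessage,
    FunctionCall,
    LLMChatCandidate,
    LLMChatResponse,
    ToolCall,
)

msg = AssistantMessage(
    content=None,
    tool_calls=[
        ToolCall(
            id="call_1",
            type="function",
            function=FunctionCall(name="get_weather", arguments='{"city": "Sevilla"}'),
        )
    ],
)
response = LLMChatResponse(
    results=[LLMChatCandidate(message=msg, finish_reason="tool_calls")],
    metadata={"provider": "openai"},
)

assert response.get_message().has_tool_calls()
print(response.get_message().get_tool_calls()[0].function.name)  # -> "get_weather"
```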

View File

@ -16,7 +16,7 @@ from dapr.ext.workflow.workflow_state import WorkflowState
from durabletask import task as dtask
from pydantic import BaseModel, ConfigDict, Field
from dapr_agents.agents.base import ChatClientType
from dapr_agents.agents.base import ChatClientBase
from dapr_agents.llm.openai import OpenAIChatClient
from dapr_agents.types.workflow import DaprWorkflowStatus
from dapr_agents.workflow.task import WorkflowTask
@ -32,7 +32,7 @@ class WorkflowApp(BaseModel):
A Pydantic-based class to encapsulate a Dapr Workflow runtime and manage workflows and tasks.
"""
llm: ChatClientType = Field(
llm: ChatClientBase = Field(
default_factory=OpenAIChatClient,
description="The default LLM client for all LLM-based tasks.",
)
@ -85,7 +85,7 @@ class WorkflowApp(BaseModel):
super().model_post_init(__context)
def _choose_llm_for(self, method: Callable) -> Optional[ChatClientType]:
def _choose_llm_for(self, method: Callable) -> Optional[ChatClientBase]:
"""
Encapsulate LLM selection logic.
1. Use per-task override if provided on decorator.

View File

@ -3,34 +3,35 @@ from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from dapr.ext.workflow import DaprWorkflowContext
from dapr_agents.workflow.decorators import task, workflow, message_router
from dapr_agents.workflow.decorators import message_router, task, workflow
from dapr_agents.workflow.orchestrators.base import OrchestratorWorkflowBase
from dapr_agents.workflow.orchestrators.llm.schemas import (
BroadcastMessage,
TriggerAction,
NextStep,
AgentTaskResponse,
ProgressCheckOutput,
schemas,
)
from dapr_agents.workflow.orchestrators.llm.prompts import (
TASK_INITIAL_PROMPT,
TASK_PLANNING_PROMPT,
NEXT_STEP_PROMPT,
PROGRESS_CHECK_PROMPT,
SUMMARY_GENERATION_PROMPT,
TASK_INITIAL_PROMPT,
TASK_PLANNING_PROMPT,
)
from dapr_agents.workflow.orchestrators.llm.schemas import (
AgentTaskResponse,
BroadcastMessage,
NextStep,
ProgressCheckOutput,
TriggerAction,
schemas,
)
from dapr_agents.workflow.orchestrators.llm.state import (
LLMWorkflowState,
LLMWorkflowEntry,
LLMWorkflowMessage,
LLMWorkflowState,
PlanStep,
TaskResult,
)
from dapr_agents.workflow.orchestrators.llm.utils import (
update_step_statuses,
restructure_plan,
find_step_in_plan,
restructure_plan,
update_step_statuses,
)
logger = logging.getLogger(__name__)
@ -65,52 +66,46 @@ class LLMOrchestrator(OrchestratorWorkflowBase):
def main_workflow(self, ctx: DaprWorkflowContext, message: TriggerAction):
"""
Executes an LLM-driven agentic workflow where the next agent is dynamically selected
based on task progress. The workflow iterates through execution cycles, updating state,
handling agent responses, and determining task completion.
based on task progress. Runs for up to `self.max_iterations` turns, then summarizes.
Args:
ctx (DaprWorkflowContext): The workflow execution context.
message (TriggerAction): The current workflow state containing `message`, `iteration`, and `verdict`.
message (TriggerAction): Contains the current `task`.
Returns:
str: The final processed message when the workflow terminates.
str: The final summary when the workflow terminates.
Raises:
RuntimeError: If the LLM determines the task is `failed`.
RuntimeError: If the workflow ends unexpectedly without a final summary.
"""
# Step 0: Retrieve iteration messages
# Step 1: Retrieve initial task and ensure state entry exists
task = message.get("task")
iteration = message.get("iteration", 0)
# Step 1:
# Ensure 'instances' and the instance_id entry exist
instance_id = ctx.instance_id
self.state.setdefault("instances", {}).setdefault(
instance_id, LLMWorkflowEntry(input=task).model_dump(mode="json")
)
# Retrieve the plan (will always exist after initialization)
plan = self.state["instances"][instance_id].get("plan", [])
final_summary: Optional[str] = None
# Single loop from turn 1 to max_iterations
for turn in range(1, self.max_iterations + 1):
if not ctx.is_replaying:
logger.info(
f"Workflow iteration {iteration + 1} started (Instance ID: {instance_id})."
f"Workflow turn {turn}/{self.max_iterations} (Instance ID: {instance_id})"
)
# Step 2: Retrieve available agents
# Step 2: Get available agents
agents = yield ctx.call_activity(self.get_agents_metadata_as_string)
# Step 3: First iteration setup
if iteration == 0:
# Step 3: On turn 1, generate plan and broadcast task
if turn == 1:
if not ctx.is_replaying:
logger.info(f"Initial message from User -> {self.name}")
# Generate the plan using a language model
plan = yield ctx.call_activity(
self.generate_plan,
input={"task": task, "agents": agents, "plan_schema": schemas.plan},
)
# Prepare initial message with task, agents and plan context
initial_message = yield ctx.call_activity(
self.prepare_initial_message,
input={
@ -120,14 +115,12 @@ class LLMOrchestrator(OrchestratorWorkflowBase):
"plan": plan,
},
)
# broadcast initial message to all agents
yield ctx.call_activity(
self.broadcast_message_to_agents,
input={"instance_id": instance_id, "task": initial_message},
)
# Step 4: Identify agent and instruction for the next step
# Step 4: Determine next step and dispatch
next_step = yield ctx.call_activity(
self.generate_next_step,
input={
@ -137,8 +130,7 @@ class LLMOrchestrator(OrchestratorWorkflowBase):
"next_step_schema": schemas.next_step,
},
)
# Extract Additional Properties from NextStep
# Additional Properties from NextStep
next_agent = next_step["next_agent"]
instruction = next_step["instruction"]
step_id = next_step.get("step", None)
@ -175,15 +167,16 @@ class LLMOrchestrator(OrchestratorWorkflowBase):
# Step 8: Wait for agent response or timeout
if not ctx.is_replaying:
logger.info(f"Waiting for {next_agent}'s response...")
logger.debug(f"Waiting for {next_agent}'s response...")
event_data = ctx.wait_for_external_event("AgentTaskResponse")
timeout_task = ctx.create_timer(timedelta(seconds=self.timeout))
any_results = yield self.when_any([event_data, timeout_task])
# Step 9: Handle Agent Response or Timeout
if any_results == timeout_task:
logger.warning(
f"Agent response timed out (Iteration: {iteration + 1}, Instance ID: {instance_id})."
f"Agent response timed out (Iteration: {turn}, Instance ID: {instance_id})."
)
task_results = {
"name": self.name,
@ -195,7 +188,7 @@ class LLMOrchestrator(OrchestratorWorkflowBase):
if not ctx.is_replaying:
logger.info(f"{task_results['name']} sent a response.")
# Step 9: Save the task execution results to chat and task history
# Step 10: Save the task execution results to chat and task history
yield ctx.call_activity(
self.update_task_history,
input={
@ -207,7 +200,7 @@ class LLMOrchestrator(OrchestratorWorkflowBase):
},
)
# Step 10: Check progress
# Step 11: Check progress
progress = yield ctx.call_activity(
self.check_progress,
input={
@ -220,14 +213,15 @@ class LLMOrchestrator(OrchestratorWorkflowBase):
},
)
if not ctx.is_replaying:
logger.info(f"Tracking Progress: {progress}")
# Update verdict and plan based on progress
verdict = progress["verdict"]
if not ctx.is_replaying:
logger.debug(f"Progress verdict: {verdict}")
logger.debug(f"Tracking Progress: {progress}")
status_updates = progress.get("plan_status_update", [])
plan_updates = progress.get("plan_restructure", [])
# Step 11: Handle verdict and updates
# Step 12: Handle verdict and updates
if status_updates or plan_updates:
yield ctx.call_activity(
self.update_plan,
@ -238,33 +232,33 @@ class LLMOrchestrator(OrchestratorWorkflowBase):
"plan_updates": plan_updates,
},
)
else:
if not ctx.is_replaying:
logger.warning(
f"Step {step_id}, Substep {substep_id} not found in plan for instance {instance_id}. Recovering..."
f"Invalid step {step_id}/{substep_id} in plan for instance {instance_id}. Retrying..."
)
# Recovery Task: No updates, just iterate again
verdict = "continue"
status_updates = []
plan_updates = []
task_results = {
"name": "orchestrator",
"name": self.name,
"role": "user",
"content": f"Step {step_id}, Substep {substep_id} does not exist in the plan. Adjusting workflow...",
}
# Step 12: Process progress suggestions and next iteration count
next_iteration_count = iteration + 1
if verdict != "continue" or next_iteration_count > self.max_iterations:
if next_iteration_count >= self.max_iterations:
verdict = "max_iterations_reached"
if verdict != "continue" or turn == self.max_iterations:
if not ctx.is_replaying:
logger.info(f"Workflow ending with verdict: {verdict}")
finale = (
"max_iterations_reached"
if turn == self.max_iterations
else verdict
)
logger.info(f"Ending workflow with verdict: {finale}")
# Generate final summary based on execution
summary = yield ctx.call_activity(
# Generate summary & finish
final_summary = yield ctx.call_activity(
self.generate_summary,
input={
"task": task,
@ -286,23 +280,20 @@ class LLMOrchestrator(OrchestratorWorkflowBase):
"step": step_id,
"substep": substep_id,
"verdict": verdict,
"summary": summary,
"summary": final_summary,
},
)
# Return the final summary
if not ctx.is_replaying:
logger.info(
f"Workflow {instance_id} has been finalized with verdict: {verdict}"
)
logger.info(f"Workflow {instance_id} finalized.")
return final_summary
return summary
# --- PREPARE NEXT TURN ---
task = task_results["content"]
# Step 13: Update TriggerAction state and continue workflow
message["task"] = task_results["content"]
message["iteration"] = next_iteration_count
# Restart workflow with updated TriggerAction state
ctx.continue_as_new(message)
# Should never reach here
raise RuntimeError(f"LLMWorkflow {instance_id} exited without summary")
@task
def get_agents_metadata_as_string(self) -> str:
@ -778,11 +769,11 @@ class LLMOrchestrator(OrchestratorWorkflowBase):
f"{self.name} received an agent response without a valid workflow_instance_id. Ignoring."
)
return
logger.info(
f"{self.name} processing agent response for workflow instance '{workflow_instance_id}'."
# Log the received response
logger.debug(
f"{self.name} received response for workflow {workflow_instance_id}"
)
logger.debug(f"Full response: {message}")
# Raise a workflow event with the Agent's Task Response
self.raise_workflow_event(
instance_id=workflow_instance_id,

View File

@ -33,9 +33,6 @@ class TriggerAction(BaseModel):
None,
description="The specific task to execute. If not provided, the agent can act based on its memory or predefined behavior.",
)
iteration: Optional[int] = Field(
default=0, description="The current iteration of the workflow loop."
)
workflow_instance_id: Optional[str] = Field(
default=None, description="Dapr workflow instance id from source if available"
)

View File

@ -4,43 +4,40 @@ from typing import List, Dict, Any, Optional
def update_step_statuses(plan: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Ensures step and sub-step statuses follow logical progression:
- A step is marked "completed" if all substeps are "completed".
- If any sub-step is "in_progress", the parent step must also be "in_progress".
- If a sub-step is "completed" but the parent step is "not_started", update it to "in_progress".
- If a parent step is "completed" but a substep is still "in_progress", downgrade it to "in_progress".
- Steps without substeps should still progress logically.
Args:
plan (List[Dict[str, Any]]): The current execution plan.
Returns:
List[Dict[str, Any]]: The updated execution plan with correct statuses.
Parent completes if all substeps complete.
Parent goes in_progress if any substep is in_progress.
If substeps start completing, parent moves in_progress.
If parent was completed but a substep reverts to in_progress, parent downgrades.
Standalone steps (no substeps) are only updated via explicit status_updates.
"""
for step in plan:
# Case 0: Handle steps that have NO substeps
if "substeps" not in step or not step["substeps"]:
if step["status"] == "not_started":
step[
"status"
] = "in_progress" # Independent steps should start when execution begins
continue # Skip further processing if no substeps exist
subs = step.get("substeps", None)
substep_statuses = {ss["status"] for ss in step["substeps"]}
# --- NO substeps: do nothing here (explicit updates only) ---
if subs is None:
continue
# Case 1: If ALL substeps are "completed", parent step must be "completed".
if all(status == "completed" for status in substep_statuses):
# If substeps is not a list or is an empty list, treat as no substeps too:
if not isinstance(subs, list) or len(subs) == 0:
continue
# Collect child statuses
statuses = {ss["status"] for ss in subs}
# 1. All done → parent done
if statuses == {"completed"}:
step["status"] = "completed"
# Case 2: If ANY substep is "in_progress", parent step must also be "in_progress".
elif "in_progress" in substep_statuses:
# 2. Any in_progress → parent in_progress
elif "in_progress" in statuses:
step["status"] = "in_progress"
# Case 3: If a sub-step was completed but the step is still "not_started", update it.
elif "completed" in substep_statuses and step["status"] == "not_started":
# 3. Some done, parent not yet started → bump to in_progress
elif "completed" in statuses and step["status"] == "not_started":
step["status"] = "in_progress"
# Case 4: If the step is already marked as "completed" but a substep is still "in_progress", downgrade it.
elif step["status"] == "completed" and "in_progress" in substep_statuses:
# 4. If parent was completed but a child is in_progress, downgrade
elif step["status"] == "completed" and any(s != "completed" for s in statuses):
step["status"] = "in_progress"
return plan

View File

@ -1,13 +1,14 @@
from dapr_agents.workflow.orchestrators.base import OrchestratorWorkflowBase
from dapr.ext.workflow import DaprWorkflowContext
from dapr_agents.types import BaseMessage
from dapr_agents.workflow.decorators import task, workflow, message_router
from typing import Any, Optional, Dict
from datetime import timedelta
from pydantic import BaseModel, Field
import random
import logging
import random
from datetime import timedelta
from typing import Any, Dict, Optional
from dapr.ext.workflow import DaprWorkflowContext
from pydantic import BaseModel, Field
from dapr_agents.types import BaseMessage
from dapr_agents.workflow.decorators import message_router, task, workflow
from dapr_agents.workflow.orchestrators.base import OrchestratorWorkflowBase
logger = logging.getLogger(__name__)
@ -35,9 +36,9 @@ class TriggerAction(BaseModel):
task: Optional[str] = Field(
None,
description="The specific task to execute. If not provided, the agent will act based on its memory or predefined behavior.",
description="The specific task to execute. If not provided, the agent will act "
"based on its memory or predefined behavior.",
)
iteration: Optional[int] = Field(0, description="")
workflow_instance_id: Optional[str] = Field(
default=None, description="Dapr workflow instance id from source if available"
)
@ -48,177 +49,153 @@ class RandomOrchestrator(OrchestratorWorkflowBase):
Implements a random workflow where agents are selected randomly to perform tasks.
The workflow iterates through conversations, selecting a random agent at each step.
Uses `continue_as_new` to persist iteration state.
Runs in a single for-loop, breaking when max_iterations is reached.
"""
current_speaker: Optional[str] = Field(
default=None, init=False, description="Current speaker in the conversation."
default=None,
init=False,
description="Current speaker in the conversation, to avoid immediate repeats when possible.",
)
def model_post_init(self, __context: Any) -> None:
"""
Initializes and configures the random workflow service.
Registers tasks and workflows, then starts the workflow runtime.
"""
self._workflow_name = "RandomWorkflow"
super().model_post_init(__context)
@workflow(name="RandomWorkflow")
# TODO: add retry policies on activities.
def main_workflow(self, ctx: DaprWorkflowContext, input: TriggerAction):
"""
Executes a random workflow where agents are selected randomly for interactions.
Uses `continue_as_new` to persist iteration state.
Executes the random workflow in up to `self.max_iterations` turns, selecting
a different (or same) agent at random each turn.
Args:
ctx (DaprWorkflowContext): The workflow execution context.
input (TriggerAction): The current workflow state containing `message` and `iteration`.
ctx (DaprWorkflowContext): Workflow context.
input (TriggerAction): Contains `task`.
Returns:
str: The last processed message when the workflow terminates.
str: The final message content when the workflow terminates.
"""
# Step 0: Retrieving Loop Context
# Step 1: Gather initial task and instance ID
task = input.get("task")
iteration = input.get("iteration", 0)
instance_id = ctx.instance_id
final_output: Optional[str] = None
# Single loop from turn 1 to max_iterations inclusive
for turn in range(1, self.max_iterations + 1):
if not ctx.is_replaying:
logger.info(
f"Random workflow iteration {iteration + 1} started (Instance ID: {instance_id})."
f"Random workflow turn {turn}/{self.max_iterations} "
f"(Instance ID: {instance_id})"
)
# First iteration: Process input and broadcast
if iteration == 0:
message = yield ctx.call_activity(self.process_input, input={"task": task})
# Step 2: On turn 1, process initial task and broadcast
if turn == 1:
message = yield ctx.call_activity(
self.process_input, input={"task": task}
)
logger.info(f"Initial message from {message['role']} -> {self.name}")
# Step 1: Broadcast initial message
yield ctx.call_activity(
self.broadcast_message_to_agents, input={"message": message}
)
# Step 2: Select a random speaker
random_speaker = yield ctx.call_activity(
self.select_random_speaker, input={"iteration": iteration}
)
# Step 3: Select a random speaker
random_speaker = yield ctx.call_activity(self.select_random_speaker)
if not ctx.is_replaying:
logger.info(f"{self.name} selected {random_speaker} (Turn {turn}).")
# Step 3: Trigger agent
# Step 4: Trigger the agent
yield ctx.call_activity(
self.trigger_agent,
input={"name": random_speaker, "instance_id": instance_id},
)
# Step 4: Wait for response or timeout
logger.info("Waiting for agent response...")
# Step 5: Wait for agent response or timeout
if not ctx.is_replaying:
logger.debug("Waiting for agent response...")
event_data = ctx.wait_for_external_event("AgentTaskResponse")
timeout_task = ctx.create_timer(timedelta(seconds=self.timeout))
any_results = yield self.when_any([event_data, timeout_task])
# Step 6: Handle response or timeout
if any_results == timeout_task:
if not ctx.is_replaying:
logger.warning(
f"Agent response timed out (Iteration: {iteration + 1}, Instance ID: {instance_id})."
f"Turn {turn}: agent response timed out (Instance ID: {instance_id})."
)
task_results = {
result = {
"name": "timeout",
"content": "Timeout occurred. Continuing...",
"content": "Timeout occurred. Continuing...",
}
else:
task_results = yield event_data
logger.info(f"{task_results['name']} -> {self.name}")
result = yield event_data
if not ctx.is_replaying:
logger.info(f"{result['name']} -> {self.name}")
# Step 5: Check Iteration
next_iteration_count = iteration + 1
if next_iteration_count > self.max_iterations:
# Step 7: If this is the last allowed turn, mark final_output and break
if turn == self.max_iterations:
if not ctx.is_replaying:
logger.info(
f"Max iterations reached. Ending random workflow (Instance ID: {instance_id})."
f"Turn {turn}: max iterations reached (Instance ID: {instance_id})."
)
return task_results["content"]
final_output = result["content"]
break
# Update ChatLoop for next iteration
input["task"] = task_results["content"]
input["iteration"] = next_iteration_count
# Otherwise, feed into next turn
task = result["content"]
# Restart workflow with updated state
# TODO: would we want this updated to preserve agent state between iterations?
ctx.continue_as_new(input)
# Sanity check (should never happen)
if final_output is None:
raise RuntimeError(
"RandomWorkflow completed without producing a final_output"
)
# Return the final message content
return final_output
@task
async def process_input(self, task: str):
async def process_input(self, task: str) -> Dict[str, Any]:
"""
Processes the input message for the workflow.
Args:
task (str): The user-provided input task.
Returns:
dict: Serialized UserMessage with the content.
Wraps the raw task into a UserMessage dict.
"""
return {"role": "user", "name": self.name, "content": task}
@task
async def broadcast_message_to_agents(self, message: Dict[str, Any]):
"""
Broadcasts a message to all agents.
Args:
message (Dict[str, Any]): The message content and additional metadata.
Broadcasts a message to all agents (excluding orchestrator).
"""
# Format message for broadcasting
task_message = BroadcastMessage(**message)
# Send broadcast message
await self.broadcast_message(message=task_message, exclude_orchestrator=True)
@task
def select_random_speaker(self, iteration: int) -> str:
def select_random_speaker(self) -> str:
"""
Selects a random speaker, ensuring that a different agent is chosen if possible.
Args:
iteration (int): The current iteration number.
Returns:
str: The name of the randomly selected agent.
Selects a random speaker, avoiding repeats when possible.
"""
agents_metadata = self.get_agents_metadata(exclude_orchestrator=True)
if not agents_metadata:
logger.warning("No agents available for selection.")
raise ValueError(
"Agents metadata is empty. Cannot select a random speaker."
)
agents = self.get_agents_metadata(exclude_orchestrator=True)
if not agents:
logger.error("No agents available for selection.")
raise ValueError("Agents list is empty.")
agent_names = list(agents_metadata.keys())
names = list(agents.keys())
# Avoid repeating previous speaker if more than one agent
if len(names) > 1 and self.current_speaker in names:
names.remove(self.current_speaker)
# Handle single-agent scenarios
if len(agent_names) == 1:
random_speaker = agent_names[0]
logger.info(
f"Only one agent available: {random_speaker}. Using the same agent."
)
return random_speaker
# Select a random speaker, avoiding repeating the previous speaker when possible
previous_speaker = getattr(self, "current_speaker", None)
if previous_speaker in agent_names and len(agent_names) > 1:
agent_names.remove(previous_speaker)
random_speaker = random.choice(agent_names)
self.current_speaker = random_speaker
logger.info(
f"{self.name} randomly selected agent {random_speaker} (Iteration: {iteration})."
)
return random_speaker
choice = random.choice(names)
self.current_speaker = choice
return choice
@task
async def trigger_agent(self, name: str, instance_id: str) -> None:
"""
Triggers the specified agent to perform its activity.
Args:
name (str): Name of the agent to trigger.
instance_id (str): Workflow instance ID for context.
Sends a TriggerAction to the selected agent.
"""
logger.info(f"Triggering agent {name} (Instance ID: {instance_id})")
await self.send_message_to_agent(
name=name,
message=TriggerAction(workflow_instance_id=instance_id),
@ -227,33 +204,20 @@ class RandomOrchestrator(OrchestratorWorkflowBase):
@message_router
async def process_agent_response(self, message: AgentTaskResponse):
"""
Processes agent response messages sent directly to the agent's topic.
Args:
message (AgentTaskResponse): The agent's response containing task results.
Returns:
None: The function raises a workflow event with the agent's response.
Handles incoming AgentTaskResponse events and re-raises them into the workflow.
"""
try:
workflow_instance_id = getattr(message, "workflow_instance_id", None)
if not workflow_instance_id:
logger.error(
f"{self.name} received an agent response without a valid workflow_instance_id. Ignoring."
)
logger.error("Missing workflow_instance_id on AgentTaskResponse; ignoring.")
return
logger.info(
f"{self.name} processing agent response for workflow instance '{workflow_instance_id}'."
# Log the received response
logger.debug(
f"{self.name} received response for workflow {workflow_instance_id}"
)
logger.debug(f"Full response: {message}")
# Raise a workflow event with the Agent's Task Response
self.raise_workflow_event(
instance_id=workflow_instance_id,
event_name="AgentTaskResponse",
data=message,
)
except Exception as e:
logger.error(f"Error processing agent response: {e}", exc_info=True)

View File

@ -1,11 +1,13 @@
from dapr_agents.workflow.orchestrators.base import OrchestratorWorkflowBase
from dapr.ext.workflow import DaprWorkflowContext
from dapr_agents.types import BaseMessage
from dapr_agents.workflow.decorators import task, workflow, message_router
from typing import Any, Optional, Dict
from pydantic import BaseModel, Field
from datetime import timedelta
import logging
from datetime import timedelta
from typing import Any, Dict, Optional
from dapr.ext.workflow import DaprWorkflowContext
from pydantic import BaseModel, Field
from dapr_agents.types import BaseMessage
from dapr_agents.workflow.decorators import message_router, task, workflow
from dapr_agents.workflow.orchestrators.base import OrchestratorWorkflowBase
logger = logging.getLogger(__name__)
@ -33,9 +35,9 @@ class TriggerAction(BaseModel):
task: Optional[str] = Field(
None,
description="The specific task to execute. If not provided, the agent will act based on its memory or predefined behavior.",
description="The specific task to execute. If not provided, the agent will act "
"based on its memory or predefined behavior.",
)
iteration: Optional[int] = Field(0, description="")
workflow_instance_id: Optional[str] = Field(
default=None, description="Dapr workflow instance id from source if available"
)
@ -44,110 +46,110 @@ class TriggerAction(BaseModel):
class RoundRobinOrchestrator(OrchestratorWorkflowBase):
"""
Implements a round-robin workflow where agents take turns performing tasks.
The workflow iterates through conversations by selecting agents in a circular order.
Uses `continue_as_new` to persist iteration state.
Iterates for up to `self.max_iterations` turns, then returns the last reply.
"""
def model_post_init(self, __context: Any) -> None:
"""
Initializes and configures the round-robin workflow service.
Registers tasks and workflows, then starts the workflow runtime.
Initializes and configures the round-robin workflow.
"""
self._workflow_name = "RoundRobinWorkflow"
super().model_post_init(__context)
@workflow(name="RoundRobinWorkflow")
# TODO: add retry policies on activities.
def main_workflow(self, ctx: DaprWorkflowContext, input: TriggerAction):
def main_workflow(self, ctx: DaprWorkflowContext, input: TriggerAction) -> str:
"""
Executes a round-robin workflow where agents interact iteratively.
Steps:
1. Processes input and broadcasts the initial message.
2. Iterates through agents, selecting a speaker each round.
3. Waits for agent responses or handles timeouts.
4. Updates the workflow state and continues the loop.
5. Terminates when max iterations are reached.
Uses `continue_as_new` to persist iteration state.
Drives the round-robin loop in up to `max_iterations` turns.
Args:
ctx (DaprWorkflowContext): The workflow execution context.
input (TriggerAction): The current workflow state containing task and iteration.
ctx (DaprWorkflowContext): Workflow context.
input (TriggerAction): Contains the initial `task`.
Returns:
str: The last processed message when the workflow terminates.
str: The final message content when the workflow terminates.
"""
# Step 1: Extract task and instance ID from input
task = input.get("task")
iteration = input.get("iteration", 0)
instance_id = ctx.instance_id
final_output: Optional[str] = None
# Loop from 1..max_iterations
for turn in range(1, self.max_iterations + 1):
if not ctx.is_replaying:
logger.info(
f"Round-robin iteration {iteration + 1} started (Instance ID: {instance_id})."
f"Round-robin turn {turn}/{self.max_iterations} "
f"(Instance ID: {instance_id})"
)
# Check Termination Condition
if iteration >= self.max_iterations:
# Step 2: On turn 1, process input and broadcast message
if turn == 1:
message = yield ctx.call_activity(
self.process_input, input={"task": task}
)
if not ctx.is_replaying:
logger.info(
f"Max iterations reached. Ending round-robin workflow (Instance ID: {instance_id})."
f"Initial message from {message['role']} -> {self.name}"
)
return task
# First iteration: Process input and broadcast
if iteration == 0:
message = yield ctx.call_activity(self.process_input, input={"task": task})
logger.info(f"Initial message from {message['role']} -> {self.name}")
# Broadcast initial message
yield ctx.call_activity(
self.broadcast_message_to_agents, input={"message": message}
)
# Select next speaker
next_speaker = yield ctx.call_activity(
self.select_next_speaker, input={"iteration": iteration}
# Step 3: Select next speaker in round-robin order
speaker = yield ctx.call_activity(
self.select_next_speaker, input={"turn": turn}
)
if not ctx.is_replaying:
logger.info(f"Selected agent {speaker} for turn {turn}")
# Trigger agent
# Step 4: Trigger that agent
yield ctx.call_activity(
self.trigger_agent, input={"name": next_speaker, "instance_id": instance_id}
self.trigger_agent,
input={"name": speaker, "instance_id": instance_id},
)
# Wait for response or timeout
logger.info("Waiting for agent response...")
# Step 5: Wait for agent response or timeout
if not ctx.is_replaying:
logger.debug("Waiting for agent response...")
event_data = ctx.wait_for_external_event("AgentTaskResponse")
timeout_task = ctx.create_timer(timedelta(seconds=self.timeout))
any_results = yield self.when_any([event_data, timeout_task])
# Step 6: Handle result or timeout
if any_results == timeout_task:
if not ctx.is_replaying:
logger.warning(
f"Agent response timed out (Iteration: {iteration + 1}, Instance ID: {instance_id})."
f"Turn {turn}: response timed out "
f"(Instance ID: {instance_id})"
)
task_results = {
result = {
"name": "timeout",
"content": "Timeout occurred. Continuing...",
}
else:
task_results = yield event_data
logger.info(f"{task_results['name']} -> {self.name}")
result = yield event_data
if not ctx.is_replaying:
logger.info(f"{result['name']} -> {self.name}")
# Check Iteration
next_iteration_count = iteration + 1
if next_iteration_count > self.max_iterations:
# Step 7: If this is the last allowed turn, capture and break
if turn == self.max_iterations:
if not ctx.is_replaying:
logger.info(
f"Max iterations reached. Ending round-robin workflow (Instance ID: {instance_id})."
f"Turn {turn}: max iterations reached (Instance ID: {instance_id})."
)
return task_results["content"]
final_output = result["content"]
break
# Update for next iteration
input["task"] = task_results["content"]
input["iteration"] = next_iteration_count
# Otherwise, feed into next iteration
task = result["content"]
# Restart workflow with updated state
# TODO: would we want this updated to preserve agent state between iterations?
ctx.continue_as_new(input)
# Sanity check: final_output must be set
if final_output is None:
raise RuntimeError(
"RoundRobinWorkflow completed without producing final_output"
)
return final_output
@task
async def process_input(self, task: str) -> Dict[str, Any]:
@ -171,17 +173,16 @@ class RoundRobinOrchestrator(OrchestratorWorkflowBase):
"""
# Format message for broadcasting
task_message = BroadcastMessage(**message)
# Send broadcast message
await self.broadcast_message(message=task_message, exclude_orchestrator=True)
@task
async def select_next_speaker(self, iteration: int) -> str:
async def select_next_speaker(self, turn: int) -> str:
"""
Selects the next speaker in round-robin order.
Args:
iteration (int): The current iteration number.
turn (int): The current turn number (1-based).
Returns:
str: The name of the selected agent.
"""
@ -191,12 +192,7 @@ class RoundRobinOrchestrator(OrchestratorWorkflowBase):
raise ValueError("Agents metadata is empty. Cannot select next speaker.")
agent_names = list(agents_metadata.keys())
# Determine the next agent in the round-robin order
next_speaker = agent_names[iteration % len(agent_names)]
logger.info(
f"{self.name} selected agent {next_speaker} for iteration {iteration}."
)
next_speaker = agent_names[(turn - 1) % len(agent_names)]
return next_speaker
@task
@ -208,6 +204,7 @@ class RoundRobinOrchestrator(OrchestratorWorkflowBase):
name (str): Name of the agent to trigger.
instance_id (str): Workflow instance ID for context.
"""
logger.info(f"Triggering agent {name} (Instance ID: {instance_id})")
await self.send_message_to_agent(
name=name,
message=TriggerAction(workflow_instance_id=instance_id),
@ -224,7 +221,6 @@ class RoundRobinOrchestrator(OrchestratorWorkflowBase):
Returns:
None: The function raises a workflow event with the agent's response.
"""
try:
workflow_instance_id = getattr(message, "workflow_instance_id", None)
if not workflow_instance_id:
@ -232,17 +228,14 @@ class RoundRobinOrchestrator(OrchestratorWorkflowBase):
f"{self.name} received an agent response without a valid workflow_instance_id. Ignoring."
)
return
logger.info(
f"{self.name} processing agent response for workflow instance '{workflow_instance_id}'."
# Log the received response
logger.debug(
f"{self.name} received response for workflow {workflow_instance_id}"
)
logger.debug(f"Full response: {message}")
# Raise a workflow event with the Agent's Task Response
self.raise_workflow_event(
instance_id=workflow_instance_id,
event_name="AgentTaskResponse",
data=message,
)
except Exception as e:
logger.error(f"Error processing agent response: {e}", exc_info=True)

View File

@ -14,7 +14,7 @@ from dapr_agents.llm.chat import ChatClientBase
from dapr_agents.llm.openai import OpenAIChatClient
from dapr_agents.llm.utils import StructureHandler
from dapr_agents.prompt.utils.chat import ChatPromptHelper
from dapr_agents.types import BaseMessage, ChatCompletion, UserMessage
from dapr_agents.types import BaseMessage, UserMessage, LLMChatResponse
logger = logging.getLogger(__name__)
@ -261,23 +261,35 @@ class WorkflowTask(BaseModel):
Unwrap AI return types into plain Python.
Args:
result: ChatCompletion, BaseModel, or list of BaseModel.
result: One of:
- LLMChatResponse
- BaseModel (Pydantic)
- List[BaseModel]
- primitive (str/int/etc) or dict
Returns:
A primitive, dict, or list of dicts.
str (assistant content) when `LLMChatResponse`
dict when a single BaseModel
List[dict] when a list of BaseModels
otherwise, the raw `result`
"""
# Unwrap ChatCompletion
if isinstance(result, ChatCompletion):
logger.debug("Extracted message content from ChatCompletion.")
return result.get_content()
# Pydantic → dict
# 1) Unwrap our unified LLMChatResponse → return the assistant's text
if isinstance(result, LLMChatResponse):
logger.debug("Extracted message content from LLMChatResponse.")
msg = result.get_message()
return getattr(msg, "content", None)
# 2) Single Pydantic model → dict
if isinstance(result, BaseModel):
logger.debug("Converting Pydantic model to dictionary.")
return result.model_dump()
# 3) List of Pydantic models → list of dicts
if isinstance(result, list) and all(isinstance(x, BaseModel) for x in result):
logger.debug("Converting list of Pydantic models to list of dictionaries.")
return [x.model_dump() for x in result]
# If no specific conversion is necessary, return as-is
# 4) Fallback: primitive, dict, etc.
logger.info("Returning final task result.")
return result

View File

@ -137,8 +137,7 @@ durable_agent = DurableAgent(
state_key="workflow_state",
agents_registry_store_name="agentstatestore",
agents_registry_key="agents_registry",
),
)
),
```
## Agent Patterns

View File

@ -1,8 +1,18 @@
from dapr_agents import OpenAIChatClient
from dotenv import load_dotenv
from dapr_agents import OpenAIChatClient
from dapr_agents.types.message import LLMChatResponse
# load environment variables from .env file
load_dotenv()
# Initialize the OpenAI chat client
llm = OpenAIChatClient()
response = llm.generate("Tell me a joke")
if len(response.get_content()) > 0:
print("Got response:", response.get_content())
# Generate a response from the LLM
response: LLMChatResponse = llm.generate("Tell me a joke")
# Print the Message content if it exists
if response.get_message() is not None:
content = response.get_message().content
print("Got response:", content)

View File

@ -1,11 +1,15 @@
import logging
from dotenv import load_dotenv
from dapr_agents import Agent
from dapr_agents.document.embedder.sentence import SentenceTransformerEmbedder
from dapr_agents.storage.vectorstores import ChromaVectorStore
from dapr_agents.tool import tool
from dapr_agents.types.document import Document
load_dotenv()
logging.basicConfig(level=logging.INFO)
embedding_function = SentenceTransformerEmbedder(model="all-MiniLM-L6-v2")
@ -75,8 +79,6 @@ def add_machine_learning_doc() -> str:
async def main():
# Seed the vector store with initial documents using Document class
from dapr_agents.types.document import Document
documents = [
Document(
text="Gandalf: A wizard is never late, Frodo Baggins. Nor is he early; he arrives precisely when he means to.",
@ -127,5 +129,3 @@ if __name__ == "__main__":
print("\nInterrupted by user. Exiting gracefully...")
except Exception as e:
print(f"\nError occurred: {e}")
print("Make sure you have the required dependencies installed:")
print(" pip install sentence-transformers chromadb")

View File

@ -69,13 +69,24 @@ python 01_ask_llm.py
This example demonstrates the simplest way to use Dapr Agents' OpenAIChatClient:
```python
from dapr_agents import OpenAIChatClient
from dotenv import load_dotenv
from dapr_agents import OpenAIChatClient
from dapr_agents.types.message import LLMChatResponse
# load environment variables from .env file
load_dotenv()
# Initialize the OpenAI chat client
llm = OpenAIChatClient()
response = llm.generate("Tell me a joke")
print("Got response:", response.get_content())
# Generate a response from the LLM
response: LLMChatResponse = llm.generate("Tell me a joke")
# Print the Message content if it exists
if response.get_message() is not None:
content = response.get_message().content
print("Got response:", content)
```
**Expected output:** The LLM will respond with a joke.
@ -356,6 +367,9 @@ This example demonstrates how to create an agent with vector store capabilities,
```python
import logging
from dotenv import load_dotenv
from dapr_agents import Agent
from dapr_agents.document.embedder.sentence import SentenceTransformerEmbedder
from dapr_agents.storage.vectorstores import ChromaVectorStore
@ -480,8 +494,6 @@ if __name__ == "__main__":
print("\nInterrupted by user. Exiting gracefully...")
except Exception as e:
print(f"\nError occurred: {e}")
print("Make sure you have the required dependencies installed:")
print(" pip install sentence-transformers chromadb")
```
## Key Concepts

View File

@ -69,27 +69,34 @@ dapr run --app-id dapr-llm --resources-path ./components -- python text_completi
The script uses the `DaprChatClient` which connects to Dapr's `echo` LLM component:
```python
from dapr_agents.llm import DaprChatClient
from dapr_agents.types import UserMessage
from dotenv import load_dotenv
from dapr_agents.llm import DaprChatClient
from dapr_agents.types import LLMChatResponse, UserMessage
# Load environment variables from .env
load_dotenv()
# Basic chat completion
llm = DaprChatClient()
response = llm.generate("Name a famous dog!")
print("Response: ", response.get_content())
response: LLMChatResponse = llm.generate("Name a famous dog!")
if response.get_message() is not None:
print("Response: ", response.get_message().content)
# Chat completion using a prompty file for context
llm = DaprChatClient.from_prompty('basic.prompty')
response = llm.generate(input_data={"question":"What is your name?"})
print("Response with prompty: ", response.get_content())
llm = DaprChatClient.from_prompty("basic.prompty")
response: LLMChatResponse = llm.generate(input_data={"question": "What is your name?"})
if response.get_message() is not None:
print("Response with prompty: ", response.get_message().content)
# Chat completion with user input
llm = DaprChatClient()
response = llm.generate(messages=[UserMessage("hello")])
print("Response with user input: ", response.get_content())
response: LLMChatResponse = llm.generate(messages=[UserMessage("hello")])
if response.get_message() is not None and "hello" in response.get_message().content.lower():
print("Response with user input: ", response.get_message().content)
```
**Expected output:** The echo component will simply return the prompts that were sent to it.

View File

@ -5,7 +5,7 @@ model:
api: chat
configuration:
type: nvidia
name: meta/llama3-8b-instruct
name: HuggingFaceTB/SmolLM2-1.7B-Instruct
parameters:
max_tokens: 128
temperature: 0.2

View File

@ -1,28 +1,31 @@
from dapr_agents.llm import DaprChatClient
from dapr_agents.types import UserMessage
from dotenv import load_dotenv
from dapr_agents.llm import DaprChatClient
from dapr_agents.types import LLMChatResponse, UserMessage
# Load environment variables from .env
load_dotenv()
# Basic chat completion
llm = DaprChatClient()
response = llm.generate("Name a famous dog!")
response: LLMChatResponse = llm.generate("Name a famous dog!")
if len(response.get_content()) > 0:
print("Response: ", response.get_content())
if response.get_message() is not None:
print("Response: ", response.get_message().content)
# Chat completion using a prompty file for context
llm = DaprChatClient.from_prompty("basic.prompty")
response = llm.generate(input_data={"question": "What is your name?"})
response: LLMChatResponse = llm.generate(input_data={"question": "What is your name?"})
if len(response.get_content()) > 0:
print("Response with prompty: ", response.get_content())
if response.get_message() is not None:
print("Response with prompty: ", response.get_message().content)
# Chat completion with user input
llm = DaprChatClient()
response = llm.generate(messages=[UserMessage("hello")])
response: LLMChatResponse = llm.generate(messages=[UserMessage("hello")])
if len(response.get_content()) > 0 and "hello" in response.get_content().lower():
print("Response with user input: ", response.get_content())
if (
response.get_message() is not None
and "hello" in response.get_message().content.lower()
):
print("Response with user input: ", response.get_message().content)

View File

@ -46,32 +46,184 @@ python text_completion.py
The script demonstrates basic usage of the DaprChatClient for text generation:
```python
from dapr_agents.llm import HFHubChatClient
from dapr_agents.types import UserMessage
from dotenv import load_dotenv
from dapr_agents.llm import HFHubChatClient
from dapr_agents.types import LLMChatResponse, UserMessage
load_dotenv()
# Basic chat completion
llm = HFHubChatClient(
model="microsoft/Phi-3-mini-4k-instruct"
)
response = llm.generate("Name a famous dog!")
llm = HFHubChatClient(model="HuggingFaceTB/SmolLM3-3B")
response: LLMChatResponse = llm.generate("Name a famous dog!")
if len(response.get_content()) > 0:
print("Response: ", response.get_content())
if response.get_message() is not None:
print("Response: ", response.get_message().content)
# Chat completion using a prompty file for context
llm = HFHubChatClient.from_prompty('basic.prompty')
response = llm.generate(input_data={"question":"What is your name?"})
llm = HFHubChatClient.from_prompty("basic.prompty")
response: LLMChatResponse = llm.generate(input_data={"question": "What is your name?"})
if len(response.get_content()) > 0:
print("Response with prompty: ", response.get_content())
if response.get_message() is not None:
print("Response with prompty: ", response.get_message().content)
# Chat completion with user input
llm = HFHubChatClient(model="microsoft/Phi-3-mini-4k-instruct")
response = llm.generate(messages=[UserMessage("hello")])
llm = HFHubChatClient(model="HuggingFaceTB/SmolLM3-3B")
response: LLMChatResponse = llm.generate(messages=[UserMessage("hello")])
print("Response with user input: ", response.get_content())
if response.get_message() is not None and "hello" in response.get_message().content.lower():
print("Response with user input: ", response.get_message().content)
```
**2. Expected output:** The LLM will respond with the name of a famous dog (e.g., "Lassie", "Hachiko", etc.).
**Run the structured text completion example:**
<!-- STEP
name: Run text completion example
expected_stdout_lines:
- '"name":'
- '"breed":'
- '"reason":'
timeout_seconds: 30
output_match_mode: substring
-->
```bash
python structured_completion.py
```
<!-- END_STEP -->
This example shows how to use Pydantic models to get structured data from LLMs:
```python
import json
from dotenv import load_dotenv
from pydantic import BaseModel
from dapr_agents import HFHubChatClient
from dapr_agents.types import UserMessage
# Load environment variables from .env
load_dotenv()
# Define our data model
class Dog(BaseModel):
name: str
breed: str
reason: str
# Initialize the chat client
llm = HFHubChatClient(model="HuggingFaceTB/SmolLM3-3B")
# Get structured response
response: Dog = llm.generate(
messages=[UserMessage("One famous dog in history.")], response_format=Dog
)
print(json.dumps(response.model_dump(), indent=2))
```
**Expected output:** A JSON object with name, breed, and reason fields
```
{
"name": "Dog",
"breed": "Siberian Husky",
"reason": "Known for its endurance, intelligence, and loyalty, Siberian Huskies have played crucial roles in dog sledding and have been beloved companions for many."
}
```
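Since the returned `response` is already an instance of the `Dog` Pydantic model, you can also read its fields directly instead of dumping it to JSON. A minimal follow-on sketch (the printed values will vary with the model's answer):
```python
# `response` is the Dog instance returned by llm.generate(..., response_format=Dog)
dog: Dog = response
print(f"{dog.name} ({dog.breed}): {dog.reason}")
```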
### Streaming
Our Hugging Face chat client also supports streaming responses, where you can process partial results as they arrive. Below are two examples:
**1. Basic Streaming Example**
Run the `text_completion_stream.py` script to see token-by-token output:
```bash
python text_completion_stream.py
```
The script:
```python
from dotenv import load_dotenv
from dapr_agents import HFHubChatClient
from dapr_agents.types.message import LLMChatResponseChunk
from typing import Iterator
import logging
logging.basicConfig(level=logging.INFO)
load_dotenv()
llm = HFHubChatClient(model="HuggingFaceTB/SmolLM3-3B")
response: Iterator[LLMChatResponseChunk] = llm.generate("Name a famous dog!", stream=True)
for chunk in response:
if chunk.result.content:
print(chunk.result.content, end="", flush=True)
```
This will print each partial chunk as it arrives, so you can build up the full answer in real time.
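If you also want the complete answer at the end (for logging or further processing), you can collect the partial contents while streaming. A minimal sketch that reuses the same client and the `chunk.result.content` shape shown above:
```python
from typing import Iterator
from dotenv import load_dotenv
from dapr_agents import HFHubChatClient
from dapr_agents.types.message import LLMChatResponseChunk
load_dotenv()
llm = HFHubChatClient(model="HuggingFaceTB/SmolLM3-3B")
# Print chunks as they arrive and keep them so the full reply can be joined afterwards.
parts: list[str] = []
chunks: Iterator[LLMChatResponseChunk] = llm.generate("Name a famous dog!", stream=True)
for chunk in chunks:
    if chunk.result.content:
        print(chunk.result.content, end="", flush=True)
        parts.append(chunk.result.content)
print("\nFull reply:", "".join(parts))
```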
**2. Streaming with Tool Calls:**
Use `text_completion_with_tools.py` to combine streaming with function-call “tools”:
```bash
python text_completion_with_tools.py
```
```python
from dotenv import load_dotenv
from dapr_agents import HFHubChatClient
from dapr_agents.types.message import LLMChatResponseChunk
from typing import Iterator
import logging
logging.basicConfig(level=logging.INFO)
load_dotenv()
# Initialize client
llm = HFHubChatClient(model="HuggingFaceTB/SmolLM3-3B", hf_provider="auto")
# Define a simple addition tool
def add_numbers(a: int, b: int) -> int:
return a + b
add_tool = {
"type": "function",
"function": {
"name": "add_numbers",
"description": "Add two numbers together.",
"parameters": {
"type": "object",
"properties": {
"a": {"type": "integer", "description": "The first number."},
"b": {"type": "integer", "description": "The second number."}
},
"required": ["a", "b"]
}
}
}
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Add 5 and 7 and 2 and 2."}
]
response: Iterator[LLMChatResponseChunk] = llm.generate(
messages=messages,
tools=[add_tool],
stream=True
)
for chunk in response:
print(chunk.result)
```
Here, the model can decide to call your `add_numbers` function mid-stream, and you'll see those calls (and their results) as they come in.
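If you want to inspect or dispatch those tool calls yourself, one option is to separate the text chunks from whatever tool-call fragments arrive. This sketch reuses `llm`, `messages`, `add_tool`, and `add_numbers` from the script above; the `tool_calls` attribute name is an assumption (mirroring the OpenAI-style schema used for `add_tool`), so it is read defensively with `getattr` and you should check `print(chunk.result)` for the real shape:
```python
# Hedged sketch: split streamed text from tool-call fragments.
text_parts = []
tool_call_fragments = []
for chunk in llm.generate(messages=messages, tools=[add_tool], stream=True):
    result = chunk.result
    if getattr(result, "content", None):
        text_parts.append(result.content)
    # `tool_calls` is an assumed attribute name, not a documented API.
    fragments = getattr(result, "tool_calls", None)
    if fragments:
        tool_call_fragments.extend(fragments)
print("Text:", "".join(text_parts))
print("Tool-call fragments:", tool_call_fragments)
# Once a full function name and its arguments are assembled, dispatch locally,
# e.g. add_numbers(a=5, b=7).
```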

View File

@ -0,0 +1,28 @@
import json
from dotenv import load_dotenv
from pydantic import BaseModel
from dapr_agents import HFHubChatClient
from dapr_agents.types import UserMessage
# Load environment variables from .env
load_dotenv()
# Define our data model
class Dog(BaseModel):
name: str
breed: str
reason: str
# Initialize the chat client
llm = HFHubChatClient(model="HuggingFaceTB/SmolLM3-3B")
# Get structured response
response: Dog = llm.generate(
messages=[UserMessage("One famous dog in history.")], response_format=Dog
)
print(json.dumps(response.model_dump(), indent=2))

View File

@ -1,28 +1,30 @@
from dapr_agents.llm import HFHubChatClient
from dapr_agents.types import UserMessage
from dotenv import load_dotenv
from dapr_agents.llm import HFHubChatClient
from dapr_agents.types import LLMChatResponse, UserMessage
load_dotenv()
# Basic chat completion
llm = HFHubChatClient(model="HuggingFaceTB/SmolLM3-3B")
response = llm.generate("Name a famous dog!")
response: LLMChatResponse = llm.generate("Name a famous dog!")
if len(response.get_content()) > 0:
print("Response: ", response.get_content())
if response.get_message() is not None:
print("Response: ", response.get_message().content)
# Chat completion using a prompty file for context
llm = HFHubChatClient.from_prompty("basic.prompty")
response = llm.generate(input_data={"question": "What is your name?"})
response: LLMChatResponse = llm.generate(input_data={"question": "What is your name?"})
if len(response.get_content()) > 0:
print("Response with prompty: ", response.get_content())
if response.get_message() is not None:
print("Response with prompty: ", response.get_message().content)
# Chat completion with user input
llm = HFHubChatClient(model="HuggingFaceTB/SmolLM3-3B")
response = llm.generate(messages=[UserMessage("hello")])
response: LLMChatResponse = llm.generate(messages=[UserMessage("hello")])
if len(response.get_content()) > 0 and "hello" in response.get_content().lower():
print("Response with user input: ", response.get_content())
if (
response.get_message() is not None
and "hello" in response.get_message().content.lower()
):
print("Response with user input: ", response.get_message().content)

View File

@ -0,0 +1,22 @@
from dotenv import load_dotenv
from dapr_agents import HFHubChatClient
from dapr_agents.types.message import LLMChatResponseChunk
from typing import Iterator
import logging
logging.basicConfig(level=logging.INFO)
# Load environment variables from .env
load_dotenv()
# Basic chat completion
llm = HFHubChatClient(model="HuggingFaceTB/SmolLM3-3B")
response: Iterator[LLMChatResponseChunk] = llm.generate(
"Name a famous dog!", stream=True
)
for chunk in response:
if chunk.result.content:
print(chunk.result.content, end="", flush=True)

View File

@ -0,0 +1,50 @@
from dotenv import load_dotenv
from dapr_agents import HFHubChatClient
from dapr_agents.types.message import LLMChatResponseChunk
from typing import Iterator
import logging
logging.basicConfig(level=logging.INFO)
# Load environment variables from .env
load_dotenv()
# Basic chat completion
llm = HFHubChatClient(model="HuggingFaceTB/SmolLM3-3B", hf_provider="auto")
# Define a tool for addition
def add_numbers(a: int, b: int) -> int:
return a + b
# Define the tool function call schema
add_tool = {
"type": "function",
"function": {
"name": "add_numbers",
"description": "Add two numbers together.",
"parameters": {
"type": "object",
"properties": {
"a": {"type": "integer", "description": "The first number."},
"b": {"type": "integer", "description": "The second number."},
},
"required": ["a", "b"],
},
},
}
# Define messages for the chat
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Add 5 and 7 and 2 and 2."},
]
response: Iterator[LLMChatResponseChunk] = llm.generate(
messages=messages, tools=[add_tool], stream=True
)
for chunk in response:
print(chunk.result)

View File

@ -57,34 +57,34 @@ python text_completion.py
The script demonstrates basic usage of Dapr Agents' NVIDIAChatClient for text generation:
```python
from dapr_agents import NVIDIAChatClient
from dapr_agents.types import UserMessage
from dotenv import load_dotenv
from dapr_agents import NVIDIAChatClient
from dapr_agents.types import LLMChatResponse, UserMessage
# Load environment variables from .env
load_dotenv()
# Basic chat completion
llm = NVIDIAChatClient()
response = llm.generate("Name a famous dog!")
response: LLMChatResponse = llm.generate("Name a famous dog!")
if len(response.get_content()) > 0:
print("Response: ", response.get_content())
if response.get_message() is not None:
print("Response: ", response.get_message().content)
# Chat completion using a prompty file for context
llm = NVIDIAChatClient.from_prompty('basic.prompty')
response = llm.generate(input_data={"question":"What is your name?"})
llm = NVIDIAChatClient.from_prompty("basic.prompty")
response: LLMChatResponse = llm.generate(input_data={"question": "What is your name?"})
if len(response.get_content()) > 0:
print("Response with prompty: ", response.get_content())
if response.get_message() is not None:
print("Response with prompty: ", response.get_message().content)
# Chat completion with user input
llm = NVIDIAChatClient()
response = llm.generate(messages=[UserMessage("hello")])
response: LLMChatResponse = llm.generate(messages=[UserMessage("hello")])
if len(response.get_content()) > 0 and "hello" in response.get_content().lower():
print("Response with user input: ", response.get_content())
if response.get_message() is not None and "hello" in response.get_message().content.lower():
print("Response with user input: ", response.get_message().content)
```
**2. Expected output:** The LLM will respond with the name of a famous dog (e.g., "Lassie", "Hachiko", etc.).
@ -110,32 +110,30 @@ This example shows how to use Pydantic models to get structured data from LLMs:
```python
import json
from dotenv import load_dotenv
from pydantic import BaseModel
from dapr_agents import NVIDIAChatClient
from dapr_agents.types import UserMessage
from pydantic import BaseModel
from dotenv import load_dotenv
# Load environment variables from .env
load_dotenv()
# Define our data model
class Dog(BaseModel):
name: str
breed: str
reason: str
# Initialize the chat client
llm = NVIDIAChatClient(
model="meta/llama-3.1-8b-instruct"
)
llm = NVIDIAChatClient(model="meta/llama-3.1-8b-instruct")
# Get structured response
response = llm.generate(
messages=[UserMessage("One famous dog in history.")],
response_format=Dog
response: Dog = llm.generate(
messages=[UserMessage("One famous dog in history.")], response_format=Dog
)
print(json.dumps(response.model_dump(), indent=2))
```
**Expected output:** A structured Dog object with name, breed, and reason fields (e.g., `Dog(name='Hachiko', breed='Akita', reason='Known for his remarkable loyalty...')`)
@ -156,3 +154,101 @@ output_match_mode: substring
python embeddings.py
```
<!-- END_STEP -->
### Streaming
Our NVIDIA chat client also supports streaming responses, where you can process partial results as they arrive. Below are two examples:
**1. Basic Streaming Example**
Run the `text_completion_stream.py` script to see token-by-token output:
```bash
python text_completion_stream.py
```
The script:
```python
from dotenv import load_dotenv
from dapr_agents import NVIDIAChatClient
from dapr_agents.types.message import LLMChatResponseChunk
from typing import Iterator
import logging
logging.basicConfig(level=logging.INFO)
# Load environment variables from .env
load_dotenv()
# Basic chat completion
llm = NVIDIAChatClient()
response: Iterator[LLMChatResponseChunk] = llm.generate("Name a famous dog!", stream=True)
for chunk in response:
if chunk.result.content:
print(chunk.result.content, end="", flush=True)
```
This will print each partial chunk as it arrives, so you can build up the full answer in real time.
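Because the return value is an ordinary Python iterator, you can also stop consuming it early, for example to cap output length in a UI. A minimal sketch that reuses the `llm = NVIDIAChatClient()` client from the script above:
```python
# Stop reading the stream once roughly 200 characters have been collected.
collected = ""
for chunk in llm.generate("Name a famous dog!", stream=True):
    if chunk.result.content:
        print(chunk.result.content, end="", flush=True)
        collected += chunk.result.content
    if len(collected) >= 200:
        break
print()
```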
**2. Streaming with Tool Calls:**
Use `text_completion_with_tools.py` to combine streaming with function-call “tools”:
```bash
python text_completion_with_tools.py
```
```python
from dotenv import load_dotenv
from dapr_agents import NVIDIAChatClient
from dapr_agents.types.message import LLMChatResponseChunk
from typing import Iterator
import logging
logging.basicConfig(level=logging.INFO)
# Load environment variables from .env
load_dotenv()
# Basic chat completion
llm = NVIDIAChatClient()
# Define a tool for addition
def add_numbers(a: int, b: int) -> int:
return a + b
# Define the tool function call schema
add_tool = {
"type": "function",
"function": {
"name": "add_numbers",
"description": "Add two numbers together.",
"parameters": {
"type": "object",
"properties": {
"a": {"type": "integer", "description": "The first number."},
"b": {"type": "integer", "description": "The second number."}
},
"required": ["a", "b"]
}
}
}
# Define messages for the chat
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Add 5 and 7 and 2 and 2."}
]
response: Iterator[LLMChatResponseChunk] = llm.generate(messages=messages, tools=[add_tool], stream=True)
for chunk in response:
print(chunk.result)
```
Here, the model can decide to call your `add_numbers` function mid-stream, and you'll see those calls (and their results) as they come in.

View File

@ -1,9 +1,10 @@
import json
from dotenv import load_dotenv
from pydantic import BaseModel
from dapr_agents import NVIDIAChatClient
from dapr_agents.types import UserMessage
from pydantic import BaseModel
from dotenv import load_dotenv
# Load environment variables from .env
load_dotenv()
@ -20,7 +21,7 @@ class Dog(BaseModel):
llm = NVIDIAChatClient(model="meta/llama-3.1-8b-instruct")
# Get structured response
response = llm.generate(
response: Dog = llm.generate(
messages=[UserMessage("One famous dog in history.")], response_format=Dog
)

View File

@ -1,28 +1,31 @@
from dapr_agents import NVIDIAChatClient
from dapr_agents.types import UserMessage
from dotenv import load_dotenv
from dapr_agents import NVIDIAChatClient
from dapr_agents.types import LLMChatResponse, UserMessage
# Load environment variables from .env
load_dotenv()
# Basic chat completion
llm = NVIDIAChatClient()
response = llm.generate("Name a famous dog!")
response: LLMChatResponse = llm.generate("Name a famous dog!")
if len(response.get_content()) > 0:
print("Response: ", response.get_content())
if response.get_message() is not None:
print("Response: ", response.get_message().content)
# Chat completion using a prompty file for context
llm = NVIDIAChatClient.from_prompty("basic.prompty")
response = llm.generate(input_data={"question": "What is your name?"})
response: LLMChatResponse = llm.generate(input_data={"question": "What is your name?"})
if len(response.get_content()) > 0:
print("Response with prompty: ", response.get_content())
if response.get_message() is not None:
print("Response with prompty: ", response.get_message().content)
# Chat completion with user input
llm = NVIDIAChatClient()
response = llm.generate(messages=[UserMessage("hello")])
response: LLMChatResponse = llm.generate(messages=[UserMessage("hello")])
if len(response.get_content()) > 0 and "hello" in response.get_content().lower():
print("Response with user input: ", response.get_content())
if (
response.get_message() is not None
and "hello" in response.get_message().content.lower()
):
print("Response with user input: ", response.get_message().content)

View File

@ -0,0 +1,22 @@
from dotenv import load_dotenv
from dapr_agents import NVIDIAChatClient
from dapr_agents.types.message import LLMChatResponseChunk
from typing import Iterator
import logging
logging.basicConfig(level=logging.INFO)
# Load environment variables from .env
load_dotenv()
# Basic chat completion
llm = NVIDIAChatClient()
response: Iterator[LLMChatResponseChunk] = llm.generate(
"Name a famous dog!", stream=True
)
for chunk in response:
if chunk.result.content:
print(chunk.result.content, end="", flush=True)

View File

@ -0,0 +1,50 @@
from dotenv import load_dotenv
from dapr_agents import NVIDIAChatClient
from dapr_agents.types.message import LLMChatResponseChunk
from typing import Iterator
import logging
logging.basicConfig(level=logging.INFO)
# Load environment variables from .env
load_dotenv()
# Basic chat completion
llm = NVIDIAChatClient()
# Define a tool for addition
def add_numbers(a: int, b: int) -> int:
return a + b
# Define the tool function call schema
add_tool = {
"type": "function",
"function": {
"name": "add_numbers",
"description": "Add two numbers together.",
"parameters": {
"type": "object",
"properties": {
"a": {"type": "integer", "description": "The first number."},
"b": {"type": "integer", "description": "The second number."},
},
"required": ["a", "b"],
},
},
}
# Define messages for the chat
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Add 5 and 7 and 2 and 2."},
]
response: Iterator[LLMChatResponseChunk] = llm.generate(
messages=messages, tools=[add_tool], stream=True
)
for chunk in response:
print(chunk.result)

View File

@ -57,34 +57,34 @@ python text_completion.py
The script demonstrates basic usage of Dapr Agents' OpenAIChatClient for text generation:
```python
from dapr_agents import OpenAIChatClient
from dapr_agents.types import UserMessage
from dotenv import load_dotenv
from dapr_agents import OpenAIChatClient
from dapr_agents.types import LLMChatResponse, UserMessage
# Load environment variables from .env
load_dotenv()
# Basic chat completion
llm = OpenAIChatClient()
response = llm.generate("Name a famous dog!")
response: LLMChatResponse = llm.generate("Name a famous dog!")
if len(response.get_content()) > 0:
print("Response: ", response.get_content())
if response.get_message() is not None:
print("Response: ", response.get_message().content)
# Chat completion using a prompty file for context
llm = OpenAIChatClient.from_prompty('basic.prompty')
response = llm.generate(input_data={"question":"What is your name?"})
llm = OpenAIChatClient.from_prompty("basic.prompty")
response: LLMChatResponse = llm.generate(input_data={"question": "What is your name?"})
if len(response.get_content()) > 0:
print("Response with prompty: ", response.get_content())
if response.get_message() is not None:
print("Response with prompty: ", response.get_message().content)
# Chat completion with user input
llm = OpenAIChatClient()
response = llm.generate(messages=[UserMessage("hello")])
response: LLMChatResponse = llm.generate(messages=[UserMessage("hello")])
if len(response.get_content()) > 0 and "hello" in response.get_content().lower():
print("Response with user input: ", response.get_content())
if response.get_message() is not None and "hello" in response.get_message().content.lower():
print("Response with user input: ", response.get_message().content)
```
**2. Expected output:** The LLM will respond with the name of a famous dog (e.g., "Lassie", "Hachiko", etc.).
@ -110,27 +110,29 @@ This example shows how to use Pydantic models to get structured data from LLMs:
```python
import json
from dotenv import load_dotenv
from pydantic import BaseModel
from dapr_agents import OpenAIChatClient
from dapr_agents.types import UserMessage
from pydantic import BaseModel
from dotenv import load_dotenv
# Load environment variables from .env
load_dotenv()
# Define our data model
class Dog(BaseModel):
name: str
breed: str
reason: str
# Initialize the chat client
llm = OpenAIChatClient()
# Get structured response
response = llm.generate(
messages=[UserMessage("One famous dog in history.")],
response_format=Dog
response: Dog = llm.generate(
messages=[UserMessage("One famous dog in history.")], response_format=Dog
)
print(json.dumps(response.model_dump(), indent=2))
@ -138,6 +140,104 @@ print(json.dumps(response.model_dump(), indent=2))
**Expected output:** A structured Dog object with name, breed, and reason fields (e.g., `Dog(name='Hachiko', breed='Akita', reason='Known for his remarkable loyalty...')`)
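Because `response_format=Dog` yields a validated Pydantic instance, you can use its fields directly. A minimal continuation of the snippet above (it reuses the `response` object returned by `llm.generate`):

```python
# `response` is the Dog instance returned above; fields are typed attributes.
print(response.name)   # e.g. "Hachiko"
print(response.breed)  # e.g. "Akita"

# Pydantic serialization also works as usual, which is handy for logging or storage.
print(response.model_dump_json(indent=2))
```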
### Streaming
Our OpenAI chat client also supports streaming responses, so you can process partial results as they arrive. Below are two examples:
**1. Basic Streaming Example**
Run the `text_completion_stream.py` script to see token-by-token output:
```bash
python text_completion_stream.py
```
The script:
```python
from dotenv import load_dotenv
from dapr_agents import OpenAIChatClient
from dapr_agents.types.message import LLMChatResponseChunk
from typing import Iterator
import logging
logging.basicConfig(level=logging.INFO)
# Load environment variables from .env
load_dotenv()
# Basic chat completion
llm = OpenAIChatClient()
response: Iterator[LLMChatResponseChunk] = llm.generate("Name a famous dog!", stream=True)
for chunk in response:
if chunk.result.content:
print(chunk.result.content, end="", flush=True)
```
This will print each partial chunk as it arrives, so you can build up the full answer in real time.
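If you also need the complete answer at the end (for logging, memory, etc.), you can accumulate the chunks while printing them. A minimal sketch reusing the `llm` client from the snippet above:

```python
# Collect the streamed pieces while still printing them in real time.
parts = []
for chunk in llm.generate("Name a famous dog!", stream=True):
    if chunk.result.content:
        parts.append(chunk.result.content)
        print(chunk.result.content, end="", flush=True)

full_answer = "".join(parts)
print("\nFull answer:", full_answer)
```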
**2. Streaming with Tool Calls:**
Use `text_completion_with_tools.py` to combine streaming with function-call “tools”:
```bash
python text_completion_with_tools.py
```
```python
from dotenv import load_dotenv
from dapr_agents import OpenAIChatClient
from dapr_agents.types.message import LLMChatResponseChunk
from typing import Iterator
import logging
logging.basicConfig(level=logging.INFO)
# Load environment variables from .env
load_dotenv()
# Basic chat completion
llm = OpenAIChatClient()
# Define a tool for addition
def add_numbers(a: int, b: int) -> int:
return a + b
# Define the tool function call schema
add_tool = {
"type": "function",
"function": {
"name": "add_numbers",
"description": "Add two numbers together.",
"parameters": {
"type": "object",
"properties": {
"a": {"type": "integer", "description": "The first number."},
"b": {"type": "integer", "description": "The second number."}
},
"required": ["a", "b"]
}
}
}
# Define messages for the chat
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Add 5 and 7 and 2 and 2."}
]
response: Iterator[LLMChatResponseChunk] = llm.generate(messages=messages, tools=[add_tool], stream=True)
for chunk in response:
print(chunk.result)
```
Here, the model can decide to call your `add_numbers` function mid-stream, and you’ll see those calls (and their results) as they come in.
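If you want to act on those streamed tool calls, one approach is to buffer the argument fragments until the stream finishes and then run the function yourself. The sketch below is illustrative only: it reuses `llm`, `messages`, `add_tool`, and `add_numbers` from the snippet above, and it assumes `chunk.result.tool_calls` exposes OpenAI-style deltas with `index`, `function.name`, and `function.arguments`, which this example does not confirm.

```python
import json

# Hypothetical shapes: chunk.result.tool_calls / delta.index / delta.function.* are assumptions.
fragments = {}  # tool-call index -> {"name": ..., "arguments": accumulated JSON string}
for chunk in llm.generate(messages=messages, tools=[add_tool], stream=True):
    for delta in getattr(chunk.result, "tool_calls", None) or []:
        frag = fragments.setdefault(delta.index, {"name": None, "arguments": ""})
        if delta.function.name:
            frag["name"] = delta.function.name
        if delta.function.arguments:
            frag["arguments"] += delta.function.arguments

# After the stream ends, execute the assembled calls locally.
for frag in fragments.values():
    if frag["name"] == "add_numbers":
        args = json.loads(frag["arguments"])
        print("add_numbers ->", add_numbers(**args))
```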
### Audio
You can use the OpenAIAudioClient in `dapr-agents` for basic tasks with the OpenAI Audio API. We will explore:
@ -151,8 +251,8 @@ You can use the OpenAIAudioClient in `dapr-agents` for basic tasks with the Open
<!-- STEP
name: Run audio generation example
expected_stdout_lines:
- "Audio saved to output_speech.mp3"
- "File output_speech.mp3 has been deleted."
- "Audio saved to speech.mp3"
- "File speech.mp3 has been deleted."
-->
```bash
python text_to_speech.py

View File

@ -1,9 +1,10 @@
import json
from dotenv import load_dotenv
from pydantic import BaseModel
from dapr_agents import OpenAIChatClient
from dapr_agents.types import UserMessage
from pydantic import BaseModel
from dotenv import load_dotenv
# Load environment variables from .env
load_dotenv()
@ -20,7 +21,7 @@ class Dog(BaseModel):
llm = OpenAIChatClient()
# Get structured response
response = llm.generate(
response: Dog = llm.generate(
messages=[UserMessage("One famous dog in history.")], response_format=Dog
)

View File

@ -1,28 +1,31 @@
from dapr_agents import OpenAIChatClient
from dapr_agents.types import UserMessage
from dotenv import load_dotenv
from dapr_agents import OpenAIChatClient
from dapr_agents.types import LLMChatResponse, UserMessage
# Load environment variables from .env
load_dotenv()
# Basic chat completion
llm = OpenAIChatClient()
response = llm.generate("Name a famous dog!")
response: LLMChatResponse = llm.generate("Name a famous dog!")
if len(response.get_content()) > 0:
print("Response: ", response.get_content())
if response.get_message() is not None:
print("Response: ", response.get_message().content)
# Chat completion using a prompty file for context
llm = OpenAIChatClient.from_prompty("basic.prompty")
response = llm.generate(input_data={"question": "What is your name?"})
response: LLMChatResponse = llm.generate(input_data={"question": "What is your name?"})
if len(response.get_content()) > 0:
print("Response with prompty: ", response.get_content())
if response.get_message() is not None:
print("Response with prompty: ", response.get_message().content)
# Chat completion with user input
llm = OpenAIChatClient()
response = llm.generate(messages=[UserMessage("hello")])
response: LLMChatResponse = llm.generate(messages=[UserMessage("hello")])
if len(response.get_content()) > 0 and "hello" in response.get_content().lower():
print("Response with user input: ", response.get_content())
if (
response.get_message() is not None
and "hello" in response.get_message().content.lower()
):
print("Response with user input: ", response.get_message().content)

View File

@ -0,0 +1,22 @@
from dotenv import load_dotenv
from dapr_agents import OpenAIChatClient
from dapr_agents.types.message import LLMChatResponseChunk
from typing import Iterator
import logging
logging.basicConfig(level=logging.INFO)
# Load environment variables from .env
load_dotenv()
# Basic chat completion
llm = OpenAIChatClient()
response: Iterator[LLMChatResponseChunk] = llm.generate(
"Name a famous dog!", stream=True
)
for chunk in response:
if chunk.result.content:
print(chunk.result.content, end="", flush=True)

View File

@ -0,0 +1,50 @@
from dotenv import load_dotenv
from dapr_agents import OpenAIChatClient
from dapr_agents.types.message import LLMChatResponseChunk
from typing import Iterator
import logging
logging.basicConfig(level=logging.INFO)
# Load environment variables from .env
load_dotenv()
# Basic chat completion
llm = OpenAIChatClient()
# Define a tool for addition
def add_numbers(a: int, b: int) -> int:
return a + b
# Define the tool function call schema
add_tool = {
"type": "function",
"function": {
"name": "add_numbers",
"description": "Add two numbers together.",
"parameters": {
"type": "object",
"properties": {
"a": {"type": "integer", "description": "The first number."},
"b": {"type": "integer", "description": "The second number."},
},
"required": ["a", "b"],
},
},
}
# Define messages for the chat
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Add 5 and 7 and 2 and 2."},
]
response: Iterator[LLMChatResponseChunk] = llm.generate(
messages=messages, tools=[add_tool], stream=True
)
for chunk in response:
print(chunk.result)

View File

@ -1,5 +1,3 @@
import os
from dapr_agents.types.llm import AudioSpeechRequest
from dapr_agents import OpenAIAudioClient
from dotenv import load_dotenv
@ -24,11 +22,8 @@ audio_bytes = client.create_speech(request=tts_request)
# client.create_speech(request=tts_request, file_name=output_path)
# Save the audio to an MP3 file
output_path = "output_speech.mp3"
output_path = "speech.mp3"
with open(output_path, "wb") as audio_file:
audio_file.write(audio_bytes)
print(f"Audio saved to {output_path}")
os.remove(output_path)
print(f"File {output_path} has been deleted.")

View File

@ -129,6 +129,10 @@ python weather_agent.py
**Expected output:** The agent will identify the locations and use the get_weather tool to fetch weather information for each city.
**Other Examples:** You can also try the following agents with the same tools using `HuggingFace hub` and `NVIDIA` LLM chat clients. Make sure you add the `HUGGINGFACE_API_KEY` and `NVIDIA_API_KEY` to the `.env` file.
- [HuggingFace Agent](./weather_agent_hf.py)
- [NVIDIA Agent](./weather_agent_nv.py)
## Key Concepts
### Tool Definition
@ -244,24 +248,7 @@ Dapr Agents provides two agent implementations, each designed for different use
The default agent type, designed for tool execution and straightforward interactions. It receives your input, determines which tools to use, executes them directly, and provides the final answer. The reasoning process is mostly hidden from you, focusing instead on delivering concise responses.
### 2. DurableAgent
The DurableAgent class is a workflow-based agent that extends the standard Agent with Dapr Workflows for long-running, fault-tolerant, and durable execution. It provides persistent state management, automatic retry mechanisms, and deterministic execution across failures. We will see this agent in the next example.
```python
from dapr_agents import Agent
from dapr_agents.tool.utils.openapi import OpenAPISpecParser
from dapr_agents.storage import VectorStore
# This agent type requires additional components
openapi_agent = Agent(
name="APIExpert",
role="API Expert",
pattern="openapireact", # Specify OpenAPIReAct pattern
spec_parser=OpenAPISpecParser(),
api_vector_store=VectorStore(),
auth_header={"Authorization": "Bearer token"}
)
```
The DurableAgent class is a workflow-based agent that extends the standard Agent with Dapr Workflows for long-running, fault-tolerant, and durable execution. It provides persistent state management, automatic retry mechanisms, and deterministic execution across failures. We will see this agent in the next example: [Durable Agent](../03-durable-agent-tool-call/).
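For comparison, a minimal default `Agent` from this quickstart looks like the sketch below. It assumes the `weather_tools` module used elsewhere in these examples and an `OPENAI_API_KEY` in your environment:

```python
import asyncio

from dotenv import load_dotenv

from dapr_agents import Agent
from weather_tools import tools  # the tools module used throughout this quickstart

load_dotenv()


async def main():
    # The default Agent selects a tool, executes it directly, and returns an AssistantMessage.
    weather_agent = Agent(
        name="Stevie",
        role="Weather Assistant",
        goal="Assist humans with weather related tasks.",
        tools=tools,
    )
    result = await weather_agent.run("What is the weather in New York?")
    print(result.content)


if __name__ == "__main__":
    asyncio.run(main())
```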
## Troubleshooting

View File

@ -0,0 +1,30 @@
import asyncio
from weather_tools import tools
from dapr_agents import Agent, HFHubChatClient
from dotenv import load_dotenv
load_dotenv()
llm = HFHubChatClient(model="HuggingFaceTB/SmolLM3-3B")
AIAgent = Agent(
name="Stevie",
role="Weather Assistant",
goal="Assist Humans with weather related tasks.",
instructions=[
"Always answer the user's main weather question directly and clearly.",
"If you perform any additional actions (like jumping), summarize those actions and their results.",
"At the end, provide a concise summary that combines the weather information for all requested locations and any other actions you performed.",
],
llm=llm,
tools=tools,
)
# Wrap your async call
async def main():
await AIAgent.run("What is the weather in Virginia, New York and Washington DC?")
if __name__ == "__main__":
asyncio.run(main())

View File

@ -0,0 +1,30 @@
import asyncio
from weather_tools import tools
from dapr_agents import Agent, NVIDIAChatClient
from dotenv import load_dotenv
load_dotenv()
llm = NVIDIAChatClient(model="meta/llama-3.1-8b-instruct")
AIAgent = Agent(
name="Stevie",
role="Weather Assistant",
goal="Assist Humans with weather related tasks.",
instructions=[
"Always answer the user's main weather question directly and clearly.",
"If you perform any additional actions (like jumping), summarize those actions and their results.",
"At the end, provide a concise summary that combines the weather information for all requested locations and any other actions you performed.",
],
llm=llm,
tools=tools,
)
# Wrap your async call
async def main():
await AIAgent.run("What is the weather in Virginia, New York and Washington DC?")
if __name__ == "__main__":
asyncio.run(main())

View File

@ -87,6 +87,11 @@ start the agent with Dapr:
dapr run --app-id durableweatherapp --resources-path ./components -- python durable_weather_agent.py
```
## Other Durable Agents
You can also try the following Durable agents with the same tools using `HuggingFace hub` and `NVIDIA` LLM chat clients. Make sure you add the `HUGGINGFACE_API_KEY` and `NVIDIA_API_KEY` to the `.env` file.
- [HuggingFace Durable Agent](./durable_weather_agent_hf.py)
- [NVIDIA Durable Agent](./durable_weather_agent_nv.py)
## About Durable Agents
Durable agents maintain state across runs, enabling workflows that require persistence, recovery, and coordination. This is useful for long-running tasks, multi-step workflows, and agent collaboration.

View File

@ -0,0 +1,37 @@
from dapr_agents import DurableAgent, HFHubChatClient
from dotenv import load_dotenv
from weather_tools import tools
import asyncio
import logging
async def main():
load_dotenv()
logging.basicConfig(level=logging.INFO)
# Initialize the HuggingFaceChatClient with the desired model
llm = HFHubChatClient(model="HuggingFaceTB/SmolLM3-3B")
# Instantiate your agent (no .as_service())
weather_agent = DurableAgent(
role="Weather Assistant",
name="Stevie",
goal="Help humans get weather and location info using smart tools.",
instructions=[
"Respond clearly and helpfully to weather-related questions.",
"Use tools when appropriate to fetch weather data.",
],
llm=llm,
message_bus_name="messagepubsub",
state_store_name="workflowstatestore",
state_key="workflow_state",
agents_registry_store_name="agentstatestore",
agents_registry_key="agents_registry",
tools=tools,
)
# Start the agent service
await weather_agent.run("What's the weather in Boston?")
if __name__ == "__main__":
asyncio.run(main())

View File

@ -0,0 +1,37 @@
from dapr_agents import DurableAgent, NVIDIAChatClient
from dotenv import load_dotenv
from weather_tools import tools
import asyncio
import logging
async def main():
load_dotenv()
logging.basicConfig(level=logging.INFO)
# Initialize the NVIDIAChatClient with the desired model
llm = NVIDIAChatClient(model="meta/llama-3.1-8b-instruct")
# Instantiate your agent (no .as_service())
weather_agent = DurableAgent(
role="Weather Assistant",
name="Stevie",
goal="Help humans get weather and location info using smart tools.",
instructions=[
"Respond clearly and helpfully to weather-related questions.",
"Use tools when appropriate to fetch weather data.",
],
llm=llm,
message_bus_name="messagepubsub",
state_store_name="workflowstatestore",
state_key="workflow_state",
agents_registry_store_name="agentstatestore",
agents_registry_key="agents_registry",
tools=tools,
)
# Start the agent service
await weather_agent.run("What's the weather in Boston?")
if __name__ == "__main__":
asyncio.run(main())

View File

@ -1,9 +1,11 @@
import chainlit as cl
from dapr_agents import Agent
from dapr_agents.memory import ConversationDaprStateMemory
from dapr.clients import DaprClient
from dotenv import load_dotenv
from unstructured.partition.pdf import partition_pdf
from dapr.clients import DaprClient
from dapr_agents import Agent
from dapr_agents.memory import ConversationDaprStateMemory
from dapr_agents.types import AssistantMessage
load_dotenv()
@ -55,20 +57,22 @@ async def start():
upload(file_bytes, text_file.name, "upload")
# give the model the document to learn
response = await agent.run("This is a document element to learn: " + document_text)
response: AssistantMessage = await agent.run(
"This is a document element to learn: " + document_text
)
await cl.Message(content=f"`{text_file.name}` uploaded.").send()
await cl.Message(content=response).send()
await cl.Message(content=response.content).send()
@cl.on_message
async def main(message: cl.Message):
# chat to the model about the document
result = await agent.run(message.content)
result: AssistantMessage = await agent.run(message.content)
await cl.Message(
content=result,
content=result.content,
).send()

View File

@ -1,16 +1,16 @@
{
"instances": {
"e419ad7d06f8480498543e6244b8b3a0": {
"047a27d46af2482eac9e9dda8382281b": {
"input": "What is the weather in New York?",
"output": "The weather in New York is currently 79\u00b0F. If you need more information or have other questions, feel free to ask!",
"start_time": "2025-07-23T02:35:58.877055",
"end_time": "2025-07-23T06:36:00.786049+00:00",
"output": "The current weather in New York is 66\u00b0F.",
"start_time": "2025-07-27T02:07:07.814621",
"end_time": "2025-07-27T06:07:09.682802+00:00",
"messages": [
{
"content": "What is the weather in New York?",
"role": "user",
"id": "b57d8dcf-a84e-4ed9-b6eb-b27360a6a466",
"timestamp": "2025-07-23T02:35:58.890010"
"id": "0133e81f-a5c6-4d02-ada6-80527292172e",
"timestamp": "2025-07-27T02:07:07.826629"
},
{
"content": null,
@ -18,7 +18,7 @@
"name": "Stevie",
"tool_calls": [
{
"id": "call_QeSq5TArEZJ67tEzhZXmqAPw",
"id": "call_4cpTh8rI2OzoaGJaJEgmovxG",
"type": "function",
"function": {
"name": "LocalGetWeather",
@ -26,42 +26,42 @@
}
}
],
"id": "7b508dae-d86a-4cb5-9b84-4bf13f5ef3a3",
"timestamp": "2025-07-23T02:35:59.660877"
"id": "ed40c479-55e8-4026-8e73-8c7cba47e291",
"timestamp": "2025-07-27T02:07:08.651564"
},
{
"content": "New York: 79F.",
"content": "New York: 66F.",
"role": "tool",
"tool_call_id": "call_QeSq5TArEZJ67tEzhZXmqAPw",
"id": "4691eabc-1310-4ad7-bac6-7ecade2857a4",
"function_name": "LocalGetWeather",
"function_args": "{\"location\":\"New York\"}",
"timestamp": "2025-07-23T02:35:59.718943"
"name": "LocalGetWeather",
"tool_call_id": "call_4cpTh8rI2OzoaGJaJEgmovxG",
"id": "f89659f6-b908-435b-9017-ee6013e2f519",
"timestamp": "2025-07-27T02:07:08.696208"
},
{
"content": "The weather in New York is currently 79\u00b0F. If you need more information or have other questions, feel free to ask!",
"content": "The current weather in New York is 66\u00b0F.",
"role": "assistant",
"name": "Stevie",
"id": "54886dd0-108e-4e07-969c-4d04327768ee",
"timestamp": "2025-07-23T02:36:00.776750"
"id": "a8492536-32b1-433c-93bb-f5987091f5a2",
"timestamp": "2025-07-27T02:07:09.668079"
}
],
"last_message": {
"content": "The weather in New York is currently 79\u00b0F. If you need more information or have other questions, feel free to ask!",
"content": "The current weather in New York is 66\u00b0F.",
"role": "assistant",
"name": "Stevie",
"id": "54886dd0-108e-4e07-969c-4d04327768ee",
"timestamp": "2025-07-23T02:36:00.776750"
"id": "a8492536-32b1-433c-93bb-f5987091f5a2",
"timestamp": "2025-07-27T02:07:09.668079"
},
"tool_history": [
{
"content": "New York: 79F.",
"role": "tool",
"tool_call_id": "call_QeSq5TArEZJ67tEzhZXmqAPw",
"id": "4691eabc-1310-4ad7-bac6-7ecade2857a4",
"function_name": "LocalGetWeather",
"function_args": "{\"location\":\"New York\"}",
"timestamp": "2025-07-23T02:35:59.718943"
"id": "7f8c100a-1d11-47ee-b517-a6397f62bcba",
"timestamp": "2025-07-27T02:07:08.696235",
"tool_call_id": "call_4cpTh8rI2OzoaGJaJEgmovxG",
"tool_name": "LocalGetWeather",
"tool_args": {
"location": "New York"
},
"execution_result": "New York: 66F."
}
],
"source": null,
@ -72,8 +72,8 @@
{
"content": "What is the weather in New York?",
"role": "user",
"id": "b57d8dcf-a84e-4ed9-b6eb-b27360a6a466",
"timestamp": "2025-07-23T02:35:58.890010"
"id": "0133e81f-a5c6-4d02-ada6-80527292172e",
"timestamp": "2025-07-27T02:07:07.826629"
},
{
"content": null,
@ -81,7 +81,7 @@
"name": "Stevie",
"tool_calls": [
{
"id": "call_QeSq5TArEZJ67tEzhZXmqAPw",
"id": "call_4cpTh8rI2OzoaGJaJEgmovxG",
"type": "function",
"function": {
"name": "LocalGetWeather",
@ -89,24 +89,23 @@
}
}
],
"id": "7b508dae-d86a-4cb5-9b84-4bf13f5ef3a3",
"timestamp": "2025-07-23T02:35:59.660877"
"id": "ed40c479-55e8-4026-8e73-8c7cba47e291",
"timestamp": "2025-07-27T02:07:08.651564"
},
{
"content": "New York: 79F.",
"content": "New York: 66F.",
"role": "tool",
"tool_call_id": "call_QeSq5TArEZJ67tEzhZXmqAPw",
"id": "4691eabc-1310-4ad7-bac6-7ecade2857a4",
"function_name": "LocalGetWeather",
"function_args": "{\"location\":\"New York\"}",
"timestamp": "2025-07-23T02:35:59.718943"
"name": "LocalGetWeather",
"tool_call_id": "call_4cpTh8rI2OzoaGJaJEgmovxG",
"id": "f89659f6-b908-435b-9017-ee6013e2f519",
"timestamp": "2025-07-27T02:07:08.696208"
},
{
"content": "The weather in New York is currently 79\u00b0F. If you need more information or have other questions, feel free to ask!",
"content": "The current weather in New York is 66\u00b0F.",
"role": "assistant",
"name": "Stevie",
"id": "54886dd0-108e-4e07-969c-4d04327768ee",
"timestamp": "2025-07-23T02:36:00.776750"
"id": "a8492536-32b1-433c-93bb-f5987091f5a2",
"timestamp": "2025-07-27T02:07:09.668079"
}
]
}

View File

@ -75,6 +75,10 @@ weather_agent = Agent(
instructions=["Instrictions go here"],
tools=tools,
)
# Run a sample query
result: AssistantMessage = await weather_agent.run("What is the weather in New York?")
print(result.content)
```
### Running the Example

View File

@ -4,6 +4,7 @@ from dotenv import load_dotenv
from dapr_agents import Agent
from dapr_agents.tool.mcp import MCPClient
from dapr_agents.types import AssistantMessage
load_dotenv()
@ -37,8 +38,10 @@ async def main():
)
# Run a sample query
result = await weather_agent.run("What is the weather in New York?")
print(result)
result: AssistantMessage = await weather_agent.run(
"What is the weather in New York?"
)
print(result.content)
finally:
try:

View File

@ -74,7 +74,7 @@ mkdir docker-entrypoint-initdb.d
cp schema.sql users.sql ./docker-entrypoint-initdb.d
```
Run the database container:
Run the database container (make sure you are in the quickstart directory):
```bash
docker run --rm --name sampledb \
@ -130,7 +130,7 @@ Change the settings below based on your Postgres configuration:
*Note: If you're running Postgres in a Docker container, change `<HOST>` to `localhost`.*
```bash
docker run -p 8000:8000 -e DATABASE_URI=postgresql://<USERNAME>:<PASSWORD>@<HOST>:5432/userdb crystaldba/postgres-mcp --access-mode=unrestricted --transport=sse
docker run --rm -ti -p 8000:8000 -e DATABASE_URI=postgresql://<USERNAME>:<PASSWORD>@<HOST>:5432/userdb crystaldba/postgres-mcp --access-mode=unrestricted --transport=sse
```
## Examples

View File

@ -1,9 +1,11 @@
import chainlit as cl
from dapr_agents import Agent
from dapr_agents.tool.mcp.client import MCPClient
from dotenv import load_dotenv
from get_schema import get_table_schema_as_dict
from dapr_agents import Agent
from dapr_agents.tool.mcp.client import MCPClient
from dapr_agents.types import AssistantMessage
load_dotenv()
instructions = [
@ -52,18 +54,18 @@ async def start():
async def main(message: cl.Message):
# generate the result set and pass back to the user
prompt = create_prompt_for_llm(table_info, message.content)
result = await agent.run(prompt)
result: AssistantMessage = await agent.run(prompt)
await cl.Message(
content=result,
content=result.content,
).send()
result_set = await agent.run(
result_set: AssistantMessage = await agent.run(
"Execute the following sql query and always return a table format unless instructed otherwise. If the user asks a question regarding the data, return the result and formalize an answer based on inspecting the data: "
+ result
+ result.content
)
await cl.Message(
content=result_set,
content=result_set.content,
).send()

View File

@ -6,7 +6,7 @@ from dapr_agents.agents.agent.agent import Agent
from dapr_agents.types import (
AgentError,
AssistantMessage,
ChatCompletion,
LLMChatResponse,
ToolExecutionRecord,
UserMessage,
ToolCall,
@ -135,15 +135,15 @@ class TestAgent:
@pytest.mark.asyncio
async def test_run_agent_basic(self, basic_agent):
"""Test basic agent run functionality."""
mock_response = Mock(spec=ChatCompletion)
mock_response.get_message.return_value = AssistantMessage(content="Hello!")
mock_response.get_reason.return_value = "stop"
mock_response.get_content.return_value = "Hello!"
mock_response = Mock(spec=LLMChatResponse)
assistant_msg = AssistantMessage(content="Hello!")
mock_response.get_message.return_value = assistant_msg
basic_agent.llm.generate.return_value = mock_response
result = await basic_agent._run_agent("Hello")
assert result == "Hello!"
assert isinstance(result, AssistantMessage)
assert result.content == "Hello!"
basic_agent.llm.generate.assert_called_once()
@pytest.mark.asyncio
@ -155,26 +155,24 @@ class TestAgent:
tool_call = Mock(spec=ToolCall)
tool_call.id = "call_123"
tool_call.function = mock_function
mock_response = Mock(spec=ChatCompletion)
mock_response.get_message.return_value = AssistantMessage(content="Using tool")
mock_response.get_reason.return_value = "tool_calls"
mock_response.get_tool_calls.return_value = [tool_call]
agent_with_tools.llm.generate.return_value = mock_response
final_response = Mock(spec=ChatCompletion)
final_response.get_message.return_value = AssistantMessage(
content="Final answer"
)
final_response.get_reason.return_value = "stop"
final_response.get_content.return_value = "Final answer"
agent_with_tools.llm.generate.side_effect = [mock_response, final_response]
# Ensure the tool is present in the agent and executor
first_response = Mock(spec=LLMChatResponse)
first_assistant = AssistantMessage(content="Using tool", tool_calls=[tool_call])
first_response.get_message.return_value = first_assistant
second_response = Mock(spec=LLMChatResponse)
second_assistant = AssistantMessage(content="Final answer")
second_response.get_message.return_value = second_assistant
agent_with_tools.llm.generate.side_effect = [first_response, second_response]
agent_with_tools.tools = [echo_tool]
agent_with_tools._tool_executor = agent_with_tools._tool_executor.__class__(
tools=[echo_tool]
)
result = await agent_with_tools._run_agent("Use the tool")
assert result == "Final answer"
# tool_history is cleared after a successful run, so we do not assert on its length here -> we should probably fix this later on.
assert isinstance(result, AssistantMessage)
assert result.content == "Final answer"
@pytest.mark.asyncio
async def test_process_response_success(self, agent_with_tools):
@ -217,20 +215,18 @@ class TestAgent:
@pytest.mark.asyncio
async def test_process_iterations_max_reached(self, basic_agent):
"""Test that agent stops after max iterations."""
# Mock LLM to always return tool calls,
# and so this means the agent thinks it has a tool to call, but never actually calls a tool,
# so agent keeps looping until max iterations is reached.
mock_response = Mock(spec=ChatCompletion)
mock_response.get_message.return_value = AssistantMessage(content="Using tool")
mock_response.get_reason.return_value = "tool_calls"
mock_response.get_tool_calls.return_value = []
"""Test that agent stops immediately when there are no tool calls."""
mock_response = Mock(spec=LLMChatResponse)
assistant_msg = AssistantMessage(content="Using tool", tool_calls=[])
mock_response.get_message.return_value = assistant_msg
basic_agent.llm.generate.return_value = mock_response
result = await basic_agent.process_iterations([])
assert result is None
assert basic_agent.llm.generate.call_count == basic_agent.max_iterations
# current logic sees no tools ===> returns on first iteration
assert isinstance(result, AssistantMessage)
assert result.content == "Using tool"
assert basic_agent.llm.generate.call_count == 1
@pytest.mark.asyncio
async def test_process_iterations_with_llm_error(self, basic_agent):
@ -274,15 +270,15 @@ class TestAgent:
"""Test agent using memory context when no input is provided."""
basic_agent.memory.add_message(UserMessage(content="Previous message"))
mock_response = Mock(spec=ChatCompletion)
mock_response.get_message.return_value = AssistantMessage(content="Response")
mock_response.get_reason.return_value = "stop"
mock_response.get_content.return_value = "Response"
mock_response = Mock(spec=LLMChatResponse)
assistant_msg = AssistantMessage(content="Response")
mock_response.get_message.return_value = assistant_msg
basic_agent.llm.generate.return_value = mock_response
result = await basic_agent._run_agent(None)
assert result == "Response"
assert isinstance(result, AssistantMessage)
assert result.content == "Response"
basic_agent.llm.generate.assert_called_once()
def test_agent_tool_history_management(self, basic_agent):

View File

@ -2,11 +2,13 @@
# Right now we have to do a bunch of patching at the class-level instead of patching at the instance-level.
# In future, we should do dependency injection instead of patching at the class-level to make it easier to test.
# This applies to all areas in this file where we have with patch.object()...
import pytest
import asyncio
import os
from unittest.mock import Mock, AsyncMock, patch
from typing import Any
from unittest.mock import AsyncMock, Mock, patch
import pytest
from dapr.ext.workflow import DaprWorkflowContext
from dapr_agents.agents.durableagent.agent import DurableAgent
from dapr_agents.agents.durableagent.schemas import (
@ -14,22 +16,21 @@ from dapr_agents.agents.durableagent.schemas import (
BroadcastMessage,
)
from dapr_agents.agents.durableagent.state import (
DurableAgentWorkflowEntry,
DurableAgentWorkflowState,
)
from dapr_agents.memory import ConversationListMemory
from dapr_agents.llm import OpenAIChatClient
from dapr_agents.memory import ConversationListMemory
from dapr_agents.tool.base import AgentTool
from dapr.ext.workflow import DaprWorkflowContext
from dapr_agents.tool.executor import AgentToolExecutor
# We need this otherwise these tests all fail since they require Dapr to be available.
@pytest.fixture(autouse=True)
def patch_dapr_check(monkeypatch):
from dapr_agents.workflow import agentic
from dapr_agents.workflow import base
from unittest.mock import Mock
from dapr_agents.workflow import agentic, base
# Mock the Dapr availability check to always return True
monkeypatch.setattr(
agentic.AgenticWorkflow, "_is_dapr_available", lambda self: True
@ -272,9 +273,6 @@ class TestDurableAgent:
"stop",
]
# Manually insert a mock instance to ensure the state is populated for the assertion
from dapr_agents.agents.durableagent.state import DurableAgentWorkflowEntry
basic_durable_agent.state["instances"][
"test-instance-123"
] = DurableAgentWorkflowEntry(
@ -299,23 +297,29 @@ class TestDurableAgent:
@pytest.mark.asyncio
async def test_generate_response_activity(self, basic_durable_agent):
"""Test the generate_response activity."""
mock_response = {
"choices": [
{
"message": {
"content": "Test response",
"tool_calls": [],
"finish_reason": "stop",
},
"finish_reason": "stop",
}
]
}
basic_durable_agent.llm.generate = Mock(return_value=mock_response)
"""Test that generate_response unwraps an LLMChatResponse properly."""
from dapr_agents.types import (
AssistantMessage,
LLMChatCandidate,
LLMChatResponse,
)
# create a fake LLMChatResponse with one choice
fake_response = LLMChatResponse(
results=[
LLMChatCandidate(
message=AssistantMessage(content="Test response", tool_calls=[]),
finish_reason="stop",
)
],
metadata={},
)
basic_durable_agent.llm.generate = Mock(return_value=fake_response)
instance_id = "test-instance-123"
workflow_entry = {
# set up a minimal instance record
basic_durable_agent.state["instances"] = {
instance_id: {
"input": "Test task",
"source": "test_source",
"source_workflow_instance_id": None,
@ -323,134 +327,16 @@ class TestDurableAgent:
"tool_history": [],
"output": None,
}
basic_durable_agent.state["instances"] = {instance_id: workflow_entry}
}
result = await basic_durable_agent.generate_response(instance_id, "Test task")
assert result == mock_response
assistant_dict = await basic_durable_agent.generate_response(
instance_id, "Test task"
)
# The dict dumped from AssistantMessage should have our content
assert assistant_dict["content"] == "Test response"
assert assistant_dict["tool_calls"] == []
basic_durable_agent.llm.generate.assert_called_once()
@pytest.mark.asyncio
async def test_get_response_message_activity(self, basic_durable_agent):
"""Test the get_response_message activity."""
response = {
"choices": [
{
"message": {
"content": "Test response",
"tool_calls": [],
"finish_reason": "stop",
}
}
]
}
result = basic_durable_agent.get_response_message(response)
assert result["content"] == "Test response"
assert result["tool_calls"] == []
assert result["finish_reason"] == "stop"
@pytest.mark.asyncio
async def test_get_finish_reason_activity(self, basic_durable_agent):
"""Test the get_finish_reason activity."""
response = {"choices": [{"finish_reason": "stop"}]}
result = basic_durable_agent.get_finish_reason(response)
assert result == "stop"
@pytest.mark.asyncio
async def test_get_tool_calls_activity_with_tools(self, durable_agent_with_tools):
"""Test the get_tool_calls activity when tools are present."""
response = {
"choices": [
{
"message": {
"tool_calls": [
{
"id": "call_123",
"function": {
"name": "test_tool",
"arguments": '{"arg1": "value1"}',
},
}
]
}
}
]
}
result = durable_agent_with_tools.get_tool_calls(response)
assert result is not None
assert len(result) == 1
assert result[0]["id"] == "call_123"
assert result[0]["function"]["name"] == "test_tool"
@pytest.mark.asyncio
async def test_get_tool_calls_activity_no_tools(self, basic_durable_agent):
"""Test the get_tool_calls activity when no tools are present."""
response = {"tool_calls": []}
result = basic_durable_agent.get_tool_calls(response)
assert result is None
@pytest.mark.asyncio
async def test_execute_tool_activity_success(self, durable_agent_with_tools):
"""Test successful tool execution in activity."""
tool_call = {
"id": "call_123",
"function": {"name": "test_tool", "arguments": '{"arg1": "value1"}'},
}
instance_id = "test-instance-123"
workflow_entry = {
"input": "Test task",
"source": "test_source",
"source_workflow_instance_id": None,
"messages": [],
"tool_history": [],
"output": None,
}
durable_agent_with_tools.state["instances"] = {instance_id: workflow_entry}
with patch.object(
AgentToolExecutor,
"run_tool",
new_callable=AsyncMock,
return_value="test_result",
) as mock_run_tool:
result = await durable_agent_with_tools.run_tool(tool_call)
# Simulate appending to tool_history as the workflow would do
durable_agent_with_tools.state["instances"][instance_id].setdefault(
"tool_history", []
).append(result)
instance_data = durable_agent_with_tools.state["instances"][instance_id]
assert len(instance_data["tool_history"]) == 1
tool_entry = instance_data["tool_history"][0]
assert tool_entry["tool_call_id"] == "call_123"
assert tool_entry["tool_name"] == "test_tool"
assert tool_entry["execution_result"] == "test_result"
mock_run_tool.assert_called_once_with("test_tool", arg1="value1")
@pytest.mark.asyncio
async def test_execute_tool_activity_failure(self, durable_agent_with_tools):
"""Test tool execution failure in activity."""
tool_call = {
"id": "call_123",
"function": {"name": "test_tool", "arguments": '{"arg1": "value1"}'},
}
with patch.object(
type(durable_agent_with_tools.tool_executor),
"run_tool",
side_effect=Exception("Tool failed"),
):
with pytest.raises(Exception, match="Tool failed"):
await durable_agent_with_tools.run_tool(tool_call)
@pytest.mark.asyncio
async def test_broadcast_message_to_agents_activity(self, basic_durable_agent):
"""Test broadcasting message to agents activity."""
@ -486,7 +372,8 @@ class TestDurableAgent:
"""Test finishing workflow activity."""
instance_id = "test-instance-123"
final_output = "Final response"
workflow_entry = {
basic_durable_agent.state["instances"] = {
instance_id: {
"input": "Test task",
"source": "test_source",
"source_workflow_instance_id": None,
@ -494,7 +381,7 @@ class TestDurableAgent:
"tool_history": [],
"output": None,
}
basic_durable_agent.state["instances"] = {instance_id: workflow_entry}
}
basic_durable_agent.finalize_workflow(instance_id, final_output)
instance_data = basic_durable_agent.state["instances"][instance_id]
@ -513,7 +400,8 @@ class TestDurableAgent:
}
final_output = "Final output"
workflow_entry = {
basic_durable_agent.state["instances"] = {
instance_id: {
"input": "Test task",
"source": "test_source",
"source_workflow_instance_id": None,
@ -521,7 +409,7 @@ class TestDurableAgent:
"tool_history": [],
"output": None,
}
basic_durable_agent.state["instances"] = {instance_id: workflow_entry}
}
basic_durable_agent.append_assistant_message(instance_id, message)
basic_durable_agent.append_tool_message(instance_id, tool_execution_record)

View File

@ -1,81 +1,37 @@
from pydantic import BaseModel
from typing import Optional, Dict, Any, Union, Iterator, Type, Iterable
from pathlib import Path
from collections import UserDict
from dapr_agents.llm import OpenAIChatClient
from dapr_agents.types.message import (
LLMChatResponse,
LLMChatCandidate,
AssistantMessage,
)
class MockLLMClient(UserDict):
"""Mock LLM client for testing."""
class MockLLMClient(OpenAIChatClient):
"""Mock LLM client for testing that passes type validation."""
def __init__(self, **kwargs):
# Initialize UserDict properly
super().__init__()
# Set default values
self.data["model"] = kwargs.get("model", "gpt-4o")
self.data["azure_deployment"] = kwargs.get("azure_deployment", None)
self.data["prompt_template"] = kwargs.get("prompt_template", None)
self.data["api_key"] = kwargs.get("api_key", "mock-api-key")
self.data["base_url"] = kwargs.get("base_url", "https://api.openai.com/v1")
self.data["timeout"] = kwargs.get("timeout", 1500)
# Store additional attributes that might be accessed
object.__setattr__(
self, "_prompt_template", kwargs.get("prompt_template", None)
)
object.__setattr__(self, "_model", self.data["model"])
object.__setattr__(self, "_azure_deployment", self.data["azure_deployment"])
def __getattr__(self, name):
if name == "data":
return object.__getattribute__(self, "data")
if name in self.data:
return self.data[name]
return None
def __setattr__(self, name, value):
if name == "data" or name.startswith("_"):
object.__setattr__(self, name, value)
else:
self.data[name] = value
@classmethod
def from_prompty(
cls,
prompty_source: Union[str, Path],
timeout: Union[int, float, Dict[str, Any]] = 1500,
) -> "MockLLMClient":
"""Mock implementation of from_prompty method."""
return cls(timeout=timeout)
def get_client(self):
"""Mock implementation of get_client."""
return None
def get_config(self) -> Dict[str, Any]:
"""Mock implementation of get_config."""
return {
"model": self.data["model"],
"azure_deployment": self.data["azure_deployment"],
"api_key": self.data["api_key"],
"base_url": self.data["base_url"],
"timeout": self.data["timeout"],
}
super().__init__(model=kwargs.get("model", "gpt-4o"), api_key="mock-api-key")
self.prompt_template = kwargs.get("prompt_template", None)
def generate(
self,
messages: Union[
str,
Dict[str, Any],
Iterable[Union[Dict[str, Any]]],
] = None,
input_data: Optional[Dict[str, Any]] = None,
model: Optional[str] = None,
tools: Optional[list] = None,
response_format: Optional[Type[BaseModel]] = None,
structured_mode: str = "json",
messages=None,
*,
input_data=None,
model=None,
tools=None,
response_format=None,
structured_mode="json",
stream=False,
**kwargs,
) -> Union[Iterator[Dict[str, Any]], Dict[str, Any]]:
"""Mock implementation of generate method."""
return {
"choices": [{"message": {"content": "Mock response", "role": "assistant"}}]
}
):
return LLMChatResponse(
results=[
LLMChatCandidate(
message=AssistantMessage(
content="This is a mock response from the LLM client."
),
finish_reason="stop",
)
]
)

View File

@ -83,6 +83,6 @@ async def test_select_random_speaker(orchestrator_config):
mockclient.return_value = MagicMock()
orchestrator = RandomOrchestrator(**orchestrator_config)
speaker = orchestrator.select_random_speaker(iteration=1)
speaker = orchestrator.select_random_speaker()
assert speaker in ["agent1", "agent2"]
assert orchestrator.current_speaker == speaker