feat: trace through tool call

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>
This commit is contained in:
Casper Guldbech Nielsen 2025-05-05 06:00:43 -07:00
parent 7852aee7a3
commit 9d1a9e89ee
No known key found for this signature in database
GPG Key ID: B004583B52B9A446
3 changed files with 60 additions and 1 deletions

View File

@ -8,6 +8,15 @@ from dapr_agents.tool.utils.tool import ToolHelper
from dapr_agents.tool.utils.function_calling import to_function_call_definition from dapr_agents.tool.utils.function_calling import to_function_call_definition
from dapr_agents.types import ToolError from dapr_agents.types import ToolError
from pydantic import PrivateAttr
from dapr_agents.agent.telemetry import (
span_decorator,
async_span_decorator,
)
from opentelemetry import trace
from opentelemetry.trace import Tracer, Status, StatusCode
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -37,6 +46,7 @@ class AgentTool(BaseModel):
) )
_is_async: bool = PrivateAttr(default=False) _is_async: bool = PrivateAttr(default=False)
_tracer: Optional[Tracer] = PrivateAttr(default=None)
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
@ -76,6 +86,17 @@ class AgentTool(BaseModel):
self._initialize_from_func(self.func) self._initialize_from_func(self.func)
else: else:
self._initialize_from_run() self._initialize_from_run()
try:
provider = provider = trace.get_tracer_provider()
self._tracer = provider.get_tracer(f"{self.name}_tracer")
except Exception as e:
logger.warning(
f"OpenTelemetry initialization failed: {e}. Continuing without telemetry."
)
self._tracer = None
return super().model_post_init(__context) return super().model_post_init(__context)
def _initialize_from_func(self, func: Callable) -> None: def _initialize_from_func(self, func: Callable) -> None:
@ -88,6 +109,7 @@ class AgentTool(BaseModel):
if self.args_model is None: if self.args_model is None:
self.args_model = ToolHelper.infer_func_schema(self._run) self.args_model = ToolHelper.infer_func_schema(self._run)
@span_decorator("validate_and_prep_args")
def _validate_and_prepare_args( def _validate_and_prepare_args(
self, func: Callable, *args, **kwargs self, func: Callable, *args, **kwargs
) -> Dict[str, Any]: ) -> Dict[str, Any]:
@ -138,15 +160,19 @@ class AgentTool(BaseModel):
except Exception as e: except Exception as e:
self._log_and_raise_error(e) self._log_and_raise_error(e)
@async_span_decorator("arun_tool")
async def arun(self, *args, **kwargs) -> Any:
    """
    Run the tool from an async context, regardless of how the tool is implemented.

    The callable is ``self.func`` when set, otherwise ``self._run``. Arguments
    are passed through ``_validate_and_prepare_args`` first; the callable is
    awaited only when ``self._is_async`` is true.

    On any exception the current span is marked as errored, the exception is
    recorded on it, and handling is delegated to ``_log_and_raise_error``.
    """
    current_span = trace.get_current_span()
    try:
        target = self.func or self._run
        call_kwargs = self._validate_and_prepare_args(target, *args, **kwargs)
        if self._is_async:
            return await target(**call_kwargs)
        return target(**call_kwargs)
    except Exception as exc:
        # Flag the failure on the span before the error is logged/re-raised.
        current_span.set_status(Status(StatusCode.ERROR))
        current_span.record_exception(exc)
        self._log_and_raise_error(exc)
def _run(self, *args, **kwargs) -> Any: def _run(self, *args, **kwargs) -> Any:

View File

@ -7,6 +7,14 @@ from rich.console import Console
from dapr_agents.tool import AgentTool from dapr_agents.tool import AgentTool
from dapr_agents.types import AgentToolExecutorError, ToolError from dapr_agents.types import AgentToolExecutorError, ToolError
from pydantic import PrivateAttr
from dapr_agents.agent.telemetry import (
async_span_decorator,
)
from opentelemetry import trace
from opentelemetry.trace import Tracer, Status, StatusCode
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -22,12 +30,24 @@ class AgentToolExecutor(BaseModel):
default_factory=list, description="List of tools to register and manage." default_factory=list, description="List of tools to register and manage."
) )
_tools_map: Dict[str, AgentTool] = PrivateAttr(default_factory=dict) _tools_map: Dict[str, AgentTool] = PrivateAttr(default_factory=dict)
_tracer: Optional[Tracer] = PrivateAttr(default=None)
def model_post_init(self, __context: Any) -> None:
    """
    Populate the internal tools map and set up a tracer after model creation.

    Every entry in ``self.tools`` is passed to ``register_tool``. A tracer is
    then requested from the tracer provider returned by
    ``trace.get_tracer_provider()``. If that raises, a warning is logged and
    ``self._tracer`` is set to ``None``.

    Args:
        __context: Forwarded unchanged to ``super().model_post_init``.
    """
    for tool in self.tools:
        self.register_tool(tool)
    logger.info(f"Tool Executor initialized with {len(self._tools_map)} tool(s).")

    # Telemetry is best-effort: a failure here is logged, not raised.
    try:
        provider = trace.get_tracer_provider()
        self._tracer = provider.get_tracer("agent_tool_exec_tracer")
    except Exception as e:
        logger.warning(
            f"OpenTelemetry initialization failed: {e}. Continuing without telemetry."
        )
        self._tracer = None

    super().model_post_init(__context)
def register_tool(self, tool: AgentTool) -> None: def register_tool(self, tool: AgentTool) -> None:
@ -88,6 +108,7 @@ class AgentToolExecutor(BaseModel):
for tool in self._tools_map.values() for tool in self._tools_map.values()
) )
@async_span_decorator("run_tool")
async def run_tool(self, tool_name: str, *args, **kwargs) -> Any: async def run_tool(self, tool_name: str, *args, **kwargs) -> Any:
""" """
Executes a tool by name, automatically handling both sync and async tools. Executes a tool by name, automatically handling both sync and async tools.
@ -103,6 +124,7 @@ class AgentToolExecutor(BaseModel):
Raises: Raises:
AgentToolExecutorError: If the tool is not found or execution fails. AgentToolExecutorError: If the tool is not found or execution fails.
""" """
span = trace.get_current_span()
tool = self.get_tool(tool_name) tool = self.get_tool(tool_name)
if not tool: if not tool:
logger.error(f"Tool not found: {tool_name}") logger.error(f"Tool not found: {tool_name}")
@ -114,9 +136,13 @@ class AgentToolExecutor(BaseModel):
return tool(*args, **kwargs) return tool(*args, **kwargs)
except ToolError as e: except ToolError as e:
logger.error(f"Tool execution error in '{tool_name}': {e}") logger.error(f"Tool execution error in '{tool_name}': {e}")
span.set_status(Status(StatusCode.ERROR))
span.record_exception(e)
raise AgentToolExecutorError(str(e)) from e raise AgentToolExecutorError(str(e)) from e
except Exception as e: except Exception as e:
logger.error(f"Unexpected error in '{tool_name}': {e}") logger.error(f"Unexpected error in '{tool_name}': {e}")
span.set_status(Status(StatusCode.ERROR))
span.record_exception(e)
raise AgentToolExecutorError( raise AgentToolExecutorError(
f"Unexpected error in tool '{tool_name}': {e}" f"Unexpected error in tool '{tool_name}': {e}"
) from e ) from e

View File

@ -377,15 +377,22 @@ class AssistantAgent(AgentWorkflowBase):
Raises: Raises:
AgentError: If the tool call is malformed or execution fails. AgentError: If the tool call is malformed or execution fails.
""" """
span = trace.get_current_span()
span.set_attribute("workflow.id", instance_id)
function_details = tool_call.get("function", {}) function_details = tool_call.get("function", {})
function_name = function_details.get("name") function_name = function_details.get("name")
span.set_attribute("tool.call.name", function_name)
span.set_attribute("tool.call.details", str(function_details))
if not function_name: if not function_name:
span.set_attribute("error.type", type(e).__name__)
raise AgentError("Missing function name in tool execution request.") raise AgentError("Missing function name in tool execution request.")
try: try:
function_args = function_details.get("arguments", "") function_args = function_details.get("arguments", "")
function_args_as_dict = json.loads(function_args) if function_args else {} function_args_as_dict = json.loads(function_args) if function_args else {}
span.set_attributes("tool.call.args", str(function_args_as_dict))
# Execute tool function # Execute tool function
result = await self.tool_executor.run_tool( result = await self.tool_executor.run_tool(
@ -407,7 +414,6 @@ class AssistantAgent(AgentWorkflowBase):
except (ToolError, AgentToolExecutorError) as e: except (ToolError, AgentToolExecutorError) as e:
logger.info(e) logger.info(e)
span = trace.get_current_span()
span.set_status(Status(StatusCode.ERROR)) span.set_status(Status(StatusCode.ERROR))
span.record_exception(e) span.record_exception(e)
@ -431,6 +437,7 @@ class AssistantAgent(AgentWorkflowBase):
except Exception as e: except Exception as e:
logger.error(f"Error executing tool '{function_name}': {e}", exc_info=True) logger.error(f"Error executing tool '{function_name}': {e}", exc_info=True)
span.set_attribute("error.type", type(e).__name__)
raise AgentError(f"Error executing tool '{function_name}': {e}") from e raise AgentError(f"Error executing tool '{function_name}': {e}") from e
@task @task