Update initialization of LLM client for agent base

Signed-off-by: Roberto Rodriguez <9653181+Cyb3rWard0g@users.noreply.github.com>
This commit is contained in:
Roberto Rodriguez 2025-08-01 23:47:58 -04:00
parent 2465fb8207
commit 3a997f7461
1 changed files with 31 additions and 11 deletions

View File

@ -27,7 +27,6 @@ from typing import (
) )
from pydantic import BaseModel, Field, PrivateAttr, model_validator, ConfigDict from pydantic import BaseModel, Field, PrivateAttr, model_validator, ConfigDict
from dapr_agents.llm.chat import ChatClientBase from dapr_agents.llm.chat import ChatClientBase
from dapr_agents.llm.openai import OpenAIChatClient
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -66,8 +65,8 @@ class AgentBase(BaseModel, ABC):
default=None, default=None,
description="A custom system prompt, overriding name, role, goal, and instructions.", description="A custom system prompt, overriding name, role, goal, and instructions.",
) )
llm: ChatClientBase = Field( llm: Optional[ChatClientBase] = Field(
default_factory=OpenAIChatClient, default=None,
description="Language model client for generating responses.", description="Language model client for generating responses.",
) )
prompt_template: Optional[PromptTemplateBase] = Field( prompt_template: Optional[PromptTemplateBase] = Field(
@ -136,12 +135,16 @@ Your role is {role}.
@model_validator(mode="after") @model_validator(mode="after")
def validate_llm(cls, values): def validate_llm(cls, values):
"""Validate that LLM is properly configured.""" """Validate that LLM is properly configured."""
if hasattr(values, "llm") and values.llm: if hasattr(values, "llm"):
try: if values.llm is None:
# Validate LLM is properly configured by accessing it as this is required to be set. logger.warning("LLM client is None, some functionality may be limited.")
_ = values.llm else:
except Exception as e: try:
raise ValueError(f"Failed to initialize LLM: {e}") from e # Validate LLM is properly configured by accessing it as this is required to be set.
_ = values.llm
except Exception as e:
logger.error(f"Failed to initialize LLM: {e}")
values.llm = None
return values return values
@ -160,10 +163,15 @@ Your role is {role}.
if self.tool_choice is None: if self.tool_choice is None:
self.tool_choice = "auto" if self.tools else None self.tool_choice = "auto" if self.tools else None
# Initialize LLM if not provided
if self.llm is None:
self.llm = self._create_default_llm()
# Centralize prompt template selection logic # Centralize prompt template selection logic
self.prompt_template = self._initialize_prompt_template() self.prompt_template = self._initialize_prompt_template()
# Ensure LLM client and agent both reference the same template # Ensure LLM client and agent both reference the same template
self.llm.prompt_template = self.prompt_template if self.llm is not None:
self.llm.prompt_template = self.prompt_template
self._validate_prompt_template() self._validate_prompt_template()
self.prefill_agent_attributes() self.prefill_agent_attributes()
@ -174,6 +182,18 @@ Your role is {role}.
super().model_post_init(__context) super().model_post_init(__context)
def _create_default_llm(self) -> Optional[ChatClientBase]:
"""
Creates a default LLM client when none is provided.
Returns None if the default LLM cannot be created due to missing configuration.
"""
try:
from dapr_agents.llm.openai import OpenAIChatClient
return OpenAIChatClient()
except Exception as e:
logger.warning(f"Failed to create default OpenAI client: {e}. LLM will be None.")
return None
def _initialize_prompt_template(self) -> PromptTemplateBase: def _initialize_prompt_template(self) -> PromptTemplateBase:
""" """
Determines which prompt template to use for the agent: Determines which prompt template to use for the agent:
@ -190,7 +210,7 @@ Your role is {role}.
return self.prompt_template return self.prompt_template
# 2) LLM client has one? # 2) LLM client has one?
if self.llm.prompt_template: if self.llm and hasattr(self.llm, 'prompt_template') and self.llm.prompt_template:
logger.debug("🔄 Syncing from llm.prompt_template") logger.debug("🔄 Syncing from llm.prompt_template")
return self.llm.prompt_template return self.llm.prompt_template