mirror of https://github.com/dapr/dapr-agents.git
Update initialization of LLM client for agent base
Signed-off-by: Roberto Rodriguez <9653181+Cyb3rWard0g@users.noreply.github.com>
This commit is contained in:
parent
2465fb8207
commit
3a997f7461
|
|
@ -27,7 +27,6 @@ from typing import (
|
|||
)
|
||||
from pydantic import BaseModel, Field, PrivateAttr, model_validator, ConfigDict
|
||||
from dapr_agents.llm.chat import ChatClientBase
|
||||
from dapr_agents.llm.openai import OpenAIChatClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -66,8 +65,8 @@ class AgentBase(BaseModel, ABC):
|
|||
default=None,
|
||||
description="A custom system prompt, overriding name, role, goal, and instructions.",
|
||||
)
|
||||
llm: ChatClientBase = Field(
|
||||
default_factory=OpenAIChatClient,
|
||||
llm: Optional[ChatClientBase] = Field(
|
||||
default=None,
|
||||
description="Language model client for generating responses.",
|
||||
)
|
||||
prompt_template: Optional[PromptTemplateBase] = Field(
|
||||
|
|
@ -136,12 +135,16 @@ Your role is {role}.
|
|||
@model_validator(mode="after")
|
||||
def validate_llm(cls, values):
|
||||
"""Validate that LLM is properly configured."""
|
||||
if hasattr(values, "llm") and values.llm:
|
||||
try:
|
||||
# Validate LLM is properly configured by accessing it as this is required to be set.
|
||||
_ = values.llm
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to initialize LLM: {e}") from e
|
||||
if hasattr(values, "llm"):
|
||||
if values.llm is None:
|
||||
logger.warning("LLM client is None, some functionality may be limited.")
|
||||
else:
|
||||
try:
|
||||
# Validate LLM is properly configured by accessing it as this is required to be set.
|
||||
_ = values.llm
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize LLM: {e}")
|
||||
values.llm = None
|
||||
|
||||
return values
|
||||
|
||||
|
|
@ -160,10 +163,15 @@ Your role is {role}.
|
|||
if self.tool_choice is None:
|
||||
self.tool_choice = "auto" if self.tools else None
|
||||
|
||||
# Initialize LLM if not provided
|
||||
if self.llm is None:
|
||||
self.llm = self._create_default_llm()
|
||||
|
||||
# Centralize prompt template selection logic
|
||||
self.prompt_template = self._initialize_prompt_template()
|
||||
# Ensure LLM client and agent both reference the same template
|
||||
self.llm.prompt_template = self.prompt_template
|
||||
if self.llm is not None:
|
||||
self.llm.prompt_template = self.prompt_template
|
||||
|
||||
self._validate_prompt_template()
|
||||
self.prefill_agent_attributes()
|
||||
|
|
@ -174,6 +182,18 @@ Your role is {role}.
|
|||
|
||||
super().model_post_init(__context)
|
||||
|
||||
def _create_default_llm(self) -> Optional[ChatClientBase]:
|
||||
"""
|
||||
Creates a default LLM client when none is provided.
|
||||
Returns None if the default LLM cannot be created due to missing configuration.
|
||||
"""
|
||||
try:
|
||||
from dapr_agents.llm.openai import OpenAIChatClient
|
||||
return OpenAIChatClient()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create default OpenAI client: {e}. LLM will be None.")
|
||||
return None
|
||||
|
||||
def _initialize_prompt_template(self) -> PromptTemplateBase:
|
||||
"""
|
||||
Determines which prompt template to use for the agent:
|
||||
|
|
@ -190,7 +210,7 @@ Your role is {role}.
|
|||
return self.prompt_template
|
||||
|
||||
# 2) LLM client has one?
|
||||
if self.llm.prompt_template:
|
||||
if self.llm and hasattr(self.llm, 'prompt_template') and self.llm.prompt_template:
|
||||
logger.debug("🔄 Syncing from llm.prompt_template")
|
||||
return self.llm.prompt_template
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue