mirror of https://github.com/dapr/dapr-agents.git
Fix structured response validation and guard issubclass() for Python 3.10 compatibility (#67)
This commit is contained in:
parent
cf400f189c
commit
cb75e76ba1
|
@ -1,10 +1,11 @@
|
|||
from typing import Union, Dict, Any, Type, Optional, Iterator, Literal, get_args
|
||||
from dapr_agents.llm.utils import StreamHandler, StructureHandler
|
||||
from dataclasses import is_dataclass, asdict
|
||||
from dapr_agents.types import ChatCompletion
|
||||
from collections.abc import Iterable
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
from dataclasses import asdict, is_dataclass
|
||||
from typing import Any, Dict, Iterator, Literal, Optional, Type, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from dapr_agents.llm.utils import StreamHandler, StructureHandler
|
||||
from dapr_agents.types import ChatCompletion
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -43,30 +44,26 @@ class ResponseHandler:
|
|||
else:
|
||||
if response_format:
|
||||
structured_response_json = StructureHandler.extract_structured_response(
|
||||
response=response, llm_provider=llm_provider, structured_mode=structured_mode
|
||||
response=response,
|
||||
llm_provider=llm_provider,
|
||||
structured_mode=structured_mode,
|
||||
)
|
||||
|
||||
# Processing Iterable Models
|
||||
is_iterable = isinstance(response_format, Iterable)
|
||||
if is_iterable:
|
||||
item_model = get_args(response_format)[0]
|
||||
response_format = StructureHandler.create_iterable_model(item_model)
|
||||
|
||||
# Validating Response
|
||||
# Normalize format and resolve actual model class
|
||||
normalized_format = StructureHandler.normalize_iterable_format(response_format)
|
||||
model_cls = StructureHandler.resolve_response_model(normalized_format)
|
||||
|
||||
if not model_cls:
|
||||
raise TypeError(f"Could not resolve a valid Pydantic model from response_format: {response_format}")
|
||||
|
||||
structured_response_instance = StructureHandler.validate_response(
|
||||
structured_response_json, response_format
|
||||
structured_response_json, normalized_format
|
||||
)
|
||||
|
||||
if isinstance(structured_response_instance, response_format):
|
||||
logger.info("Structured output was successfully validated.")
|
||||
if is_iterable:
|
||||
logger.debug(f"Returning objects from an instance of {type(structured_response_instance)}.")
|
||||
return structured_response_instance.objects
|
||||
else:
|
||||
logger.debug(f"Returning an instance of {type(structured_response_instance)}.")
|
||||
return structured_response_instance
|
||||
else:
|
||||
logger.error("Validation failed for structured response.")
|
||||
logger.info("Structured output was successfully validated.")
|
||||
if hasattr(structured_response_instance, "objects"):
|
||||
return structured_response_instance.objects
|
||||
return structured_response_instance
|
||||
|
||||
# Convert response to dictionary
|
||||
if isinstance(response, dict):
|
||||
|
|
|
@ -431,36 +431,38 @@ class StructureHandler:
|
|||
|
||||
@staticmethod
|
||||
def resolve_all_pydantic_models(tp: Any) -> List[Type[BaseModel]]:
|
||||
"""
|
||||
Extracts all Pydantic models from a typing annotation.
|
||||
|
||||
Handles:
|
||||
- Single BaseModel
|
||||
- List[BaseModel], Iterable[BaseModel]
|
||||
- Union[...] with optional or multiple model types
|
||||
|
||||
Returns:
|
||||
List[Type[BaseModel]]
|
||||
"""
|
||||
models = []
|
||||
|
||||
tp = StructureHandler.unwrap_annotated_type(tp)
|
||||
|
||||
origin = get_origin(tp)
|
||||
args = get_args(tp)
|
||||
|
||||
if isinstance(tp, type) and issubclass(tp, BaseModel):
|
||||
return [tp]
|
||||
if isinstance(tp, type):
|
||||
try:
|
||||
if issubclass(tp, BaseModel):
|
||||
return [tp]
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
if origin in (list, List, tuple, Iterable) and args:
|
||||
inner = args[0]
|
||||
if isinstance(inner, type) and issubclass(inner, BaseModel):
|
||||
return [inner]
|
||||
if isinstance(inner, type):
|
||||
try:
|
||||
if issubclass(inner, BaseModel):
|
||||
return [inner]
|
||||
except TypeError:
|
||||
pass
|
||||
else:
|
||||
logger.debug(f"[resolve] Skipping non-class inner: {inner} ({type(inner)})")
|
||||
|
||||
if origin is Union:
|
||||
for arg in args:
|
||||
if isinstance(arg, type) and issubclass(arg, BaseModel):
|
||||
models.append(arg)
|
||||
if isinstance(arg, type):
|
||||
try:
|
||||
if issubclass(arg, BaseModel):
|
||||
models.append(arg)
|
||||
except TypeError:
|
||||
continue
|
||||
|
||||
return list(dict.fromkeys(models))
|
||||
|
||||
|
|
Loading…
Reference in New Issue