Fix structured response validation and guard issubclass() for Python 3.10 compatibility (#67)

This commit is contained in:
Roberto Rodriguez 2025-03-28 19:17:54 -04:00 committed by GitHub
parent cf400f189c
commit cb75e76ba1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 42 additions and 43 deletions

View File

@ -1,10 +1,11 @@
from typing import Union, Dict, Any, Type, Optional, Iterator, Literal, get_args
from dapr_agents.llm.utils import StreamHandler, StructureHandler
from dataclasses import is_dataclass, asdict
from dapr_agents.types import ChatCompletion
from collections.abc import Iterable
from pydantic import BaseModel
import logging import logging
from dataclasses import asdict, is_dataclass
from typing import Any, Dict, Iterator, Literal, Optional, Type, Union
from pydantic import BaseModel
from dapr_agents.llm.utils import StreamHandler, StructureHandler
from dapr_agents.types import ChatCompletion
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -43,30 +44,26 @@ class ResponseHandler:
else: else:
if response_format: if response_format:
structured_response_json = StructureHandler.extract_structured_response( structured_response_json = StructureHandler.extract_structured_response(
response=response, llm_provider=llm_provider, structured_mode=structured_mode response=response,
llm_provider=llm_provider,
structured_mode=structured_mode,
) )
# Processing Iterable Models # Normalize format and resolve actual model class
is_iterable = isinstance(response_format, Iterable) normalized_format = StructureHandler.normalize_iterable_format(response_format)
if is_iterable: model_cls = StructureHandler.resolve_response_model(normalized_format)
item_model = get_args(response_format)[0]
response_format = StructureHandler.create_iterable_model(item_model) if not model_cls:
raise TypeError(f"Could not resolve a valid Pydantic model from response_format: {response_format}")
# Validating Response
structured_response_instance = StructureHandler.validate_response( structured_response_instance = StructureHandler.validate_response(
structured_response_json, response_format structured_response_json, normalized_format
) )
if isinstance(structured_response_instance, response_format): logger.info("Structured output was successfully validated.")
logger.info("Structured output was successfully validated.") if hasattr(structured_response_instance, "objects"):
if is_iterable: return structured_response_instance.objects
logger.debug(f"Returning objects from an instance of {type(structured_response_instance)}.") return structured_response_instance
return structured_response_instance.objects
else:
logger.debug(f"Returning an instance of {type(structured_response_instance)}.")
return structured_response_instance
else:
logger.error("Validation failed for structured response.")
# Convert response to dictionary # Convert response to dictionary
if isinstance(response, dict): if isinstance(response, dict):

View File

@ -431,36 +431,38 @@ class StructureHandler:
@staticmethod @staticmethod
def resolve_all_pydantic_models(tp: Any) -> List[Type[BaseModel]]: def resolve_all_pydantic_models(tp: Any) -> List[Type[BaseModel]]:
"""
Extracts all Pydantic models from a typing annotation.
Handles:
- Single BaseModel
- List[BaseModel], Iterable[BaseModel]
- Union[...] with optional or multiple model types
Returns:
List[Type[BaseModel]]
"""
models = [] models = []
tp = StructureHandler.unwrap_annotated_type(tp) tp = StructureHandler.unwrap_annotated_type(tp)
origin = get_origin(tp) origin = get_origin(tp)
args = get_args(tp) args = get_args(tp)
if isinstance(tp, type) and issubclass(tp, BaseModel): if isinstance(tp, type):
return [tp] try:
if issubclass(tp, BaseModel):
return [tp]
except TypeError:
pass
if origin in (list, List, tuple, Iterable) and args: if origin in (list, List, tuple, Iterable) and args:
inner = args[0] inner = args[0]
if isinstance(inner, type) and issubclass(inner, BaseModel): if isinstance(inner, type):
return [inner] try:
if issubclass(inner, BaseModel):
return [inner]
except TypeError:
pass
else:
logger.debug(f"[resolve] Skipping non-class inner: {inner} ({type(inner)})")
if origin is Union: if origin is Union:
for arg in args: for arg in args:
if isinstance(arg, type) and issubclass(arg, BaseModel): if isinstance(arg, type):
models.append(arg) try:
if issubclass(arg, BaseModel):
models.append(arg)
except TypeError:
continue
return list(dict.fromkeys(models)) return list(dict.fromkeys(models))