mirror of https://github.com/dapr/dapr-agents.git
Fix structured response validation and guard issubclass() for Python 3.10 compatibility (#67)
This commit is contained in:
parent
cf400f189c
commit
cb75e76ba1
|
@ -1,10 +1,11 @@
|
||||||
from typing import Union, Dict, Any, Type, Optional, Iterator, Literal, get_args
|
|
||||||
from dapr_agents.llm.utils import StreamHandler, StructureHandler
|
|
||||||
from dataclasses import is_dataclass, asdict
|
|
||||||
from dapr_agents.types import ChatCompletion
|
|
||||||
from collections.abc import Iterable
|
|
||||||
from pydantic import BaseModel
|
|
||||||
import logging
|
import logging
|
||||||
|
from dataclasses import asdict, is_dataclass
|
||||||
|
from typing import Any, Dict, Iterator, Literal, Optional, Type, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from dapr_agents.llm.utils import StreamHandler, StructureHandler
|
||||||
|
from dapr_agents.types import ChatCompletion
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -43,30 +44,26 @@ class ResponseHandler:
|
||||||
else:
|
else:
|
||||||
if response_format:
|
if response_format:
|
||||||
structured_response_json = StructureHandler.extract_structured_response(
|
structured_response_json = StructureHandler.extract_structured_response(
|
||||||
response=response, llm_provider=llm_provider, structured_mode=structured_mode
|
response=response,
|
||||||
|
llm_provider=llm_provider,
|
||||||
|
structured_mode=structured_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Processing Iterable Models
|
# Normalize format and resolve actual model class
|
||||||
is_iterable = isinstance(response_format, Iterable)
|
normalized_format = StructureHandler.normalize_iterable_format(response_format)
|
||||||
if is_iterable:
|
model_cls = StructureHandler.resolve_response_model(normalized_format)
|
||||||
item_model = get_args(response_format)[0]
|
|
||||||
response_format = StructureHandler.create_iterable_model(item_model)
|
if not model_cls:
|
||||||
|
raise TypeError(f"Could not resolve a valid Pydantic model from response_format: {response_format}")
|
||||||
# Validating Response
|
|
||||||
structured_response_instance = StructureHandler.validate_response(
|
structured_response_instance = StructureHandler.validate_response(
|
||||||
structured_response_json, response_format
|
structured_response_json, normalized_format
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(structured_response_instance, response_format):
|
logger.info("Structured output was successfully validated.")
|
||||||
logger.info("Structured output was successfully validated.")
|
if hasattr(structured_response_instance, "objects"):
|
||||||
if is_iterable:
|
return structured_response_instance.objects
|
||||||
logger.debug(f"Returning objects from an instance of {type(structured_response_instance)}.")
|
return structured_response_instance
|
||||||
return structured_response_instance.objects
|
|
||||||
else:
|
|
||||||
logger.debug(f"Returning an instance of {type(structured_response_instance)}.")
|
|
||||||
return structured_response_instance
|
|
||||||
else:
|
|
||||||
logger.error("Validation failed for structured response.")
|
|
||||||
|
|
||||||
# Convert response to dictionary
|
# Convert response to dictionary
|
||||||
if isinstance(response, dict):
|
if isinstance(response, dict):
|
||||||
|
|
|
@ -431,36 +431,38 @@ class StructureHandler:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def resolve_all_pydantic_models(tp: Any) -> List[Type[BaseModel]]:
|
def resolve_all_pydantic_models(tp: Any) -> List[Type[BaseModel]]:
|
||||||
"""
|
|
||||||
Extracts all Pydantic models from a typing annotation.
|
|
||||||
|
|
||||||
Handles:
|
|
||||||
- Single BaseModel
|
|
||||||
- List[BaseModel], Iterable[BaseModel]
|
|
||||||
- Union[...] with optional or multiple model types
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Type[BaseModel]]
|
|
||||||
"""
|
|
||||||
models = []
|
models = []
|
||||||
|
|
||||||
tp = StructureHandler.unwrap_annotated_type(tp)
|
tp = StructureHandler.unwrap_annotated_type(tp)
|
||||||
|
|
||||||
origin = get_origin(tp)
|
origin = get_origin(tp)
|
||||||
args = get_args(tp)
|
args = get_args(tp)
|
||||||
|
|
||||||
if isinstance(tp, type) and issubclass(tp, BaseModel):
|
if isinstance(tp, type):
|
||||||
return [tp]
|
try:
|
||||||
|
if issubclass(tp, BaseModel):
|
||||||
|
return [tp]
|
||||||
|
except TypeError:
|
||||||
|
pass
|
||||||
|
|
||||||
if origin in (list, List, tuple, Iterable) and args:
|
if origin in (list, List, tuple, Iterable) and args:
|
||||||
inner = args[0]
|
inner = args[0]
|
||||||
if isinstance(inner, type) and issubclass(inner, BaseModel):
|
if isinstance(inner, type):
|
||||||
return [inner]
|
try:
|
||||||
|
if issubclass(inner, BaseModel):
|
||||||
|
return [inner]
|
||||||
|
except TypeError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
logger.debug(f"[resolve] Skipping non-class inner: {inner} ({type(inner)})")
|
||||||
|
|
||||||
if origin is Union:
|
if origin is Union:
|
||||||
for arg in args:
|
for arg in args:
|
||||||
if isinstance(arg, type) and issubclass(arg, BaseModel):
|
if isinstance(arg, type):
|
||||||
models.append(arg)
|
try:
|
||||||
|
if issubclass(arg, BaseModel):
|
||||||
|
models.append(arg)
|
||||||
|
except TypeError:
|
||||||
|
continue
|
||||||
|
|
||||||
return list(dict.fromkeys(models))
|
return list(dict.fromkeys(models))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue