# Mirror of https://github.com/dapr/dapr-agents.git (612 lines, 24 KiB, Python)
import copy
import json
import logging
import types
from collections.abc import Iterable
from typing import (
    Annotated,
    Any,
    Dict,
    List,
    Literal,
    Optional,
    Type,
    TypeVar,
    Union,
    get_args,
    get_origin,
)

from pydantic import BaseModel, Field, TypeAdapter, ValidationError, create_model

from dapr_agents.tool.utils.function_calling import to_function_call_definition
from dapr_agents.types import (
    AssistantMessage,
    OAIJSONSchema,
    OAIResponseFormatSchema,
    StructureError,
)

# Module-level logger; StructureHandler emits its debug/warning/error messages through it.
logger = logging.getLogger(__name__)

# Type variable bound to Pydantic model classes. It is used in StructureHandler's
# signatures, e.g. validate_response returns an instance of the model class it is given.
T = TypeVar("T", bound=BaseModel)

class StructureHandler:
    """
    Helpers for structured (schema-constrained) LLM output.

    The static methods in this class:
    - build provider request parameters from a Pydantic model or a JSON schema dict;
    - extract the structured payload from an ``AssistantMessage``;
    - validate payloads or function results against an expected type annotation.
    """

    # Values that ``typing.get_origin`` returns for union annotations:
    # - ``typing.Union`` for ``Union[A, B]`` / ``Optional[A]``;
    # - on Python 3.10+, ``types.UnionType`` for PEP 604 ``A | B`` annotations.
    _UNION_ORIGINS = tuple(
        origin
        for origin in (Union, getattr(types, "UnionType", None))
        if origin is not None
    )

    @staticmethod
    def _is_model_class(candidate: Any) -> bool:
        """Return True if ``candidate`` is a Pydantic model class; never raises."""
        if not isinstance(candidate, type):
            return False
        try:
            return issubclass(candidate, BaseModel)
        except TypeError:
            # Some generic aliases pass isinstance(..., type) on certain Python
            # versions but are rejected by issubclass().
            return False

    @staticmethod
    def is_json_string(input_string: str) -> bool:
        """
        Check if the given input is a valid JSON string.

        Args:
            input_string (str): The string to check.

        Returns:
            bool: True if ``json.loads`` accepts the input. False otherwise,
            including for non-string inputs such as ``None``.
        """
        try:
            json.loads(input_string)
        except (ValueError, TypeError):
            # JSONDecodeError (and UnicodeDecodeError for bad bytes) subclass
            # ValueError; TypeError is raised for non-str/bytes inputs.
            return False
        return True

    @staticmethod
    def normalize_iterable_format(tp: Any) -> Any:
        """
        Wrap an iterable-of-model annotation in a generated container model.

        For ``List[M]``, ``list[M]``, ``Tuple[M, ...]`` or ``Iterable[M]``, where
        ``M`` is a Pydantic model, return ``create_iterable_model(M)``. Only the
        first type argument is inspected. Anything else is returned unchanged.

        Args:
            tp (Any): A type annotation, a model class, or a schema dict.

        Returns:
            Any: The generated wrapper model class, or ``tp`` itself.
        """
        origin = get_origin(tp)
        args = get_args(tp)

        if origin in (list, List, tuple, Iterable) and args:
            item_type = args[0]
            if StructureHandler._is_model_class(item_type):
                logger.debug(
                    "Detected iterable of BaseModel. Wrapping in generated Pydantic model."
                )
                return StructureHandler.create_iterable_model(item_type)

        return tp

    @staticmethod
    def generate_request(
        response_format: Union[Type[T], Dict[str, Any], Iterable[Type[T]]],
        llm_provider: str,
        structured_mode: Literal["json", "function_call"] = "json",
        **params,
    ) -> Dict[str, Any]:
        """
        Generates a structured request that conforms to a specified API format.

        The request configuration either registers the model as a forced tool
        ("function_call") or attaches a strict JSON-schema response format ("json").

        Args:
            response_format (Union[Type[T], Dict[str, Any], Iterable[Type[T]]]): Defines the response structure.
                - If `structured_mode="json"`: Can be a Pydantic model (converted to JSON schema)
                  or a JSON schema dictionary. A dictionary is not modified.
                - If `structured_mode="function_call"`: Must be a Pydantic model or an iterable of models.
                - If an iterable of models is provided in either mode, it is wrapped in a generated
                  model with an ``objects`` list field.
            llm_provider (str): The LLM provider (e.g., "openai", "claude").
            structured_mode (Literal["json", "function_call"]): Determines the response structure.
                - "json": Generates and enforces a strict JSON schema.
                - "function_call": Converts a Pydantic model to a function-call definition.
            **params: Additional request parameters.

        Returns:
            Dict[str, Any]: Updated request parameters, including tools or response format.

        Raises:
            ValueError: If an unsupported `structured_mode` is provided, or if the
                generated schema objects fail validation.
            TypeError: If `response_format` is invalid for the selected mode.
        """
        logger.debug(f"Structured response mode: {structured_mode}")

        # Handle iterable models in both modes: List[Model] etc. become a wrapper model.
        response_format = StructureHandler.normalize_iterable_format(response_format)

        if structured_mode == "function_call":
            model_cls = StructureHandler.resolve_response_model(response_format)
            if not model_cls:
                raise TypeError(
                    "function_call mode requires a single, unambiguous Pydantic model."
                )

            name = model_cls.__name__
            description = model_cls.__doc__ or ""
            model_tool_format = to_function_call_definition(
                name, description, model_cls, llm_provider
            )

            params["tools"] = [model_tool_format]
            # Point tool_choice at this single tool; extract_structured_response
            # later reads the structured output from the tool call's arguments.
            params["tool_choice"] = {
                "type": "function",
                "function": {"name": model_tool_format["function"]["name"]},
            }
            return params

        if structured_mode == "json":
            try:
                logger.debug(
                    f"generate_request called with type={type(response_format)}, mode={structured_mode}, provider={llm_provider}"
                )
                # If it's a dict, assume it's already a JSON schema; otherwise, try to create from model
                if isinstance(response_format, dict):
                    raw_schema = response_format
                    name = response_format.get("name", "custom_schema")
                    description = response_format.get("description")
                elif StructureHandler._is_model_class(response_format):
                    raw_schema = response_format.model_json_schema()
                    name = response_format.__name__
                    description = response_format.__doc__
                else:
                    raise TypeError("json mode requires a dict or a Pydantic model.")

                # Enforce strict JSON schema (process $refs, $defs, etc.). This
                # works on a copy, so a caller-supplied dict is left untouched.
                logger.debug(f"Raw Schema: {raw_schema}")
                strict_schema = StructureHandler.enforce_strict_json_schema(raw_schema)

                # Construct the JSON schema object using OAIJSONSchema
                json_schema_obj = OAIJSONSchema(
                    name=name,
                    description=description,
                    schema_=strict_schema,
                    strict=True,
                )

                # Wrap it in the top-level response format object
                response_format_obj = OAIResponseFormatSchema(
                    json_schema=json_schema_obj
                )

                logger.debug(
                    f"Generated JSON schema: {response_format_obj.model_dump()}"
                )

                # by_alias=True so aliased fields (e.g. ``schema_``) serialize under their alias.
                params["response_format"] = response_format_obj.model_dump(
                    by_alias=True
                )
            except ValidationError as e:
                logger.error(f"Validation error in JSON schema: {e}")
                raise ValueError(f"Invalid response_format provided: {e}") from e

            return params

        raise ValueError(
            f"Unsupported structured_mode: {structured_mode}. Must be 'json' or 'function_call'."
        )

    @staticmethod
    def create_iterable_model(
        model: Type[BaseModel],
        model_name: Optional[str] = None,
        model_description: Optional[str] = None,
    ) -> Type[BaseModel]:
        """
        Constructs an iterable Pydantic model for a given Pydantic model.

        The generated model is named ``Iterable<name>`` and has a single required
        field ``objects: List[model]``.

        Args:
            model (Type[BaseModel]): The original Pydantic model to capture a list of objects of the original model type.
            model_name (Optional[str]): Base name for the new model; defaults to ``model.__name__``.
            model_description (Optional[str]): Docstring for the new model; a generic one is used when None.

        Returns:
            Type[BaseModel]: A new Pydantic model class representing a list of the original Pydantic model.
        """
        model_name = model.__name__ if model_name is None else model_name
        iterable_model_name = f"Iterable{model_name}"

        objects_field = (
            List[model],
            Field(..., description=f"A list of `{model_name}` objects"),
        )

        iterable_model = create_model(
            iterable_model_name, objects=objects_field, __base__=(BaseModel,)
        )

        iterable_model.__doc__ = (
            f"A Pydantic model to capture `{iterable_model_name}` objects"
            if model_description is None
            else model_description
        )

        return iterable_model

    @staticmethod
    def extract_structured_response(
        message: AssistantMessage,
        llm_provider: str,
        structured_mode: Literal["json", "function_call"] = "json",
    ) -> Union[str, Dict[str, Any]]:
        """
        Extracts the structured JSON string or content from the response.

        Args:
            message (AssistantMessage): The API response message to extract from.
            llm_provider (str): The LLM provider; only "openai", "nvidia" and
                "huggingface" are supported.
            structured_mode (Literal["json", "function_call"]): The structured response mode.
                - "function_call": returns the ``arguments`` of the first tool call.
                - "json": returns the message ``content``.

        Returns:
            Union[str, Dict[str, Any]]: The extracted structured response.

        Raises:
            StructureError: If the provider or mode is unsupported, the model refused,
                or no structured payload is found. StructureErrors raised here propagate
                unchanged. Any other error is wrapped as "Extraction failed: ..." and
                chained to its cause.
        """
        try:
            logger.debug(f"Processing structured response for mode: {structured_mode}")
            if llm_provider not in ("openai", "nvidia", "huggingface"):
                raise StructureError(f"Unsupported LLM provider: {llm_provider}")

            if structured_mode == "function_call":
                tool_calls = getattr(message, "tool_calls", None)
                if tool_calls:
                    function = getattr(tool_calls[0], "function", None)
                    if function and hasattr(function, "arguments"):
                        extracted_response = function.arguments
                        logger.debug(
                            f"Extracted function-call response: {extracted_response}"
                        )
                        return extracted_response
                raise StructureError("No tool_calls found for function_call mode.")

            if structured_mode == "json":
                content = getattr(message, "content", None)
                refusal = getattr(message, "refusal", None)

                if refusal:
                    logger.warning(f"Model refused to fulfill the request: {refusal}")
                    raise StructureError(f"Request refused by the model: {refusal}")

                if not content:
                    raise StructureError("No content found for JSON mode.")

                logger.debug(f"Extracted JSON content: {content}")
                return content

            raise ValueError(
                f"Unsupported structured_mode: {structured_mode}. Must be 'json' or 'function_call'."
            )
        except StructureError as e:
            # Already descriptive: log it, but do not wrap it in a second StructureError.
            logger.error(f"Error while extracting structured response: {e}")
            raise
        except Exception as e:
            logger.error(f"Error while extracting structured response: {e}")
            raise StructureError(f"Extraction failed: {e}") from e

    @staticmethod
    def validate_response(response: Union[str, dict, BaseModel], model: Type[T]) -> T:
        """
        Validates a JSON string, a dictionary, or a model instance using a Pydantic model.

        Validation method by input type:
        - JSON string: ``model_validate_json``.
        - Dictionary or BaseModel instance: ``model_validate``.

        Args:
            response (Union[str, dict, BaseModel]): The payload to validate.
            model (Type[T]): The Pydantic model that defines the expected structure of the response.

        Returns:
            T: An instance of the Pydantic model populated with the validated data.

        Raises:
            StructureError: If Pydantic validation fails.
            ValueError: If ``response`` is neither a JSON string, a dict, nor a model instance.
        """
        try:
            if isinstance(response, str) and StructureHandler.is_json_string(response):
                return model.model_validate_json(response)
            if isinstance(response, (dict, BaseModel)):
                return model.model_validate(response)
        except ValidationError as e:
            logger.error(f"Validation error while parsing structured response: {e}")
            raise StructureError(f"Validation failed for structured response: {e}") from e

        raise ValueError("Response must be a JSON string or a dictionary.")

    @staticmethod
    def expand_local_refs(part: Dict[str, Any], root: Dict[str, Any]) -> Dict[str, Any]:
        """
        Recursively expand all local $refs in the schema, including nested references.

        Only ``#/$defs/<name>`` references are supported. They are looked up in
        ``root["$defs"]``. Keys on the referencing schema override keys of the
        referenced definition.

        Expansion descends into ``properties`` (objects), ``items`` (arrays),
        ``anyOf``, ``allOf`` and ``oneOf``. ``part`` is modified in place: ``$ref`` keys
        are popped and nested entries are replaced.

        Args:
            part (Dict[str, Any]): The schema part to process.
            root (Dict[str, Any]): The root schema for resolving $refs.

        Returns:
            Dict[str, Any]: The schema part with all $refs expanded.

        Raises:
            ValueError: On a non-local ``$ref``, a reference missing from ``$defs``,
                or a self-referencing definition, which cannot be inlined.
        """
        return StructureHandler._expand_local_refs(part, root, frozenset())

    @staticmethod
    def _expand_local_refs(
        part: Dict[str, Any], root: Dict[str, Any], active_refs: frozenset
    ) -> Dict[str, Any]:
        """Recursive worker for expand_local_refs; ``active_refs`` names the definitions being expanded on this path."""
        ref = part.pop("$ref", None)
        if ref:
            logger.debug(f"Found $ref: {ref}")
            if not ref.startswith("#/$defs/"):
                raise ValueError(f"Unexpected $ref format: {ref}")

            ref_name = ref.split("/")[-1]
            defs_section = root.get("$defs", {})
            if ref_name not in defs_section:
                raise ValueError(f"Reference '{ref_name}' not found in $defs.")
            if ref_name in active_refs:
                # A definition that (indirectly) contains itself would expand forever.
                raise ValueError(
                    f"Recursive $ref '{ref}' cannot be expanded inline."
                )

            # Merge the referenced schema with the current part, resolving nested $refs.
            merged = {**defs_section[ref_name], **part}
            return StructureHandler._expand_local_refs(
                merged, root, active_refs | {ref_name}
            )

        # Process objects and their properties
        if part.get("type") == "object" and "properties" in part:
            properties = part["properties"]
            for key, value in properties.items():
                properties[key] = StructureHandler._expand_local_refs(
                    value, root, active_refs
                )

        # Process arrays and their items
        if part.get("type") == "array" and "items" in part:
            part["items"] = StructureHandler._expand_local_refs(
                part["items"], root, active_refs
            )

        # Process schema combinators
        for key in ("anyOf", "allOf", "oneOf"):
            if isinstance(part.get(key), list):
                part[key] = [
                    StructureHandler._expand_local_refs(subschema, root, active_refs)
                    for subschema in part[key]
                ]

        return part

    @staticmethod
    def enforce_strict_json_schema(schema: Dict[str, Any]) -> Dict[str, Any]:
        """
        Enforces strict JSON schema constraints while making it OpenAI-compatible.

        The input is deep-copied first, so the caller's dict is never modified.

        - Expands all local ``$ref``s once against the top-level ``$defs``, then drops ``$defs``.
        - Sets ``"additionalProperties": false`` on object schemas that don't define it.
        - Removes ``default`` values from properties.
        - Rewrites non-required ``string``/``integer``/``number`` properties as
          ``anyOf: [{"type": T}, {"type": "null"}]``.
        - Lists every property in ``required``, in property order, so output is deterministic.
        - Collapses a property whose ``anyOf`` contains an array to
          ``anyOf: [{"type": "array", "items": T}, {"type": "null"}]``, keeping its ``description``.
        - Requires every ``array`` schema to define ``items``.
        - Recurses into ``properties``, ``items``, ``anyOf``, ``allOf`` and ``oneOf``.

        Args:
            schema (Dict[str, Any]): The JSON schema dictionary to process.

        Returns:
            Dict[str, Any]: A new schema with strict constraints applied.

        Raises:
            ValueError: If an array schema lacks ``items``, or a ``$ref`` is invalid,
                unresolvable, or recursive.
        """
        working = copy.deepcopy(schema)
        # Resolve refs once against the real root; nested levels are then ref-free.
        working = StructureHandler.expand_local_refs(working, working)
        return StructureHandler._apply_strict_rules(working)

    @staticmethod
    def _apply_strict_rules(schema: Dict[str, Any]) -> Dict[str, Any]:
        """Recursively apply the strict-mode rules to an already ref-expanded schema, in place."""
        if schema.get("type") == "object":
            schema.setdefault("additionalProperties", False)

            properties = schema.get("properties", {})
            declared_required = list(schema.get("required", []))
            declared_required_set = set(declared_required)

            for key in list(properties):
                prop = StructureHandler._apply_strict_rules(properties[key])

                # Remove default values (not allowed by OpenAI)
                prop.pop("default", None)

                # Optional scalar fields become nullable instead of optional.
                if key not in declared_required_set:
                    field_type = prop.get("type")
                    if field_type in ("string", "integer", "number"):
                        prop["anyOf"] = [{"type": field_type}, {"type": "null"}]
                        prop.pop("type", None)

                properties[key] = StructureHandler._normalize_optional_array(prop)

            # Every property is required (nullable ones via anyOf). Ordering comes from
            # the dict, not a set, so the emitted schema is stable across runs.
            schema["required"] = list(dict.fromkeys([*properties, *declared_required]))

        # Process arrays and enforce strictness
        if schema.get("type") == "array":
            if "items" not in schema:
                raise ValueError(f"Array schema missing 'items': {schema}")

            schema["items"] = StructureHandler._apply_strict_rules(schema["items"])

            any_of = schema.get("anyOf")
            if isinstance(any_of, list) and any(
                isinstance(subschema, dict) and subschema.get("type") == "array"
                for subschema in any_of
            ):
                schema["anyOf"] = [
                    {"type": "array", "items": schema["items"]},
                    {"type": "null"},
                ]
                schema.pop("type", None)  # Remove direct "type" field
                schema.pop("minItems", None)  # Remove `minItems`, not needed with null

        # Definitions were inlined by expand_local_refs; the section is no longer needed.
        schema.pop("$defs", None)

        # Process schema combinators recursively
        for key in ("anyOf", "allOf", "oneOf"):
            if isinstance(schema.get(key), list):
                schema[key] = [
                    StructureHandler._apply_strict_rules(subschema)
                    for subschema in schema[key]
                ]

        return schema

    @staticmethod
    def _normalize_optional_array(prop: Dict[str, Any]) -> Dict[str, Any]:
        """Collapse a property whose ``anyOf`` contains an array variant to ``anyOf: [array, null]``."""
        any_of = prop.get("anyOf")
        if not any_of or not isinstance(any_of, list):
            return prop

        array_variants = [
            subschema
            for subschema in any_of
            if isinstance(subschema, dict) and subschema.get("type") == "array"
        ]
        if not array_variants:
            return prop

        # If several array variants exist, the last one's items are used.
        normalized: Dict[str, Any] = {
            "anyOf": [
                {"type": "array", "items": array_variants[-1].get("items", {})},
                {"type": "null"},
            ]
        }
        # Keep the field description so the LLM still sees the field's guidance.
        if "description" in prop:
            normalized["description"] = prop["description"]
        return normalized

    @staticmethod
    def unwrap_annotated_type(tp: Any) -> Any:
        """
        Strip ``Annotated[...]`` wrappers and ``NewType`` aliases, recursively.

        Args:
            tp (Any): A type annotation.

        Returns:
            Any: The underlying type.
        """
        if get_origin(tp) is Annotated:
            return StructureHandler.unwrap_annotated_type(get_args(tp)[0])
        if hasattr(tp, "__supertype__"):  # for NewType
            return StructureHandler.unwrap_annotated_type(tp.__supertype__)
        return tp

    @staticmethod
    def resolve_all_pydantic_models(tp: Any) -> List[Type[BaseModel]]:
        """
        Collect the Pydantic model classes directly referenced by an annotation.

        Recognised shapes:
        - a model class;
        - ``list``/``List``/``tuple``/``Iterable`` whose first argument is a model;
        - a union (``Union[...]``, ``Optional[...]`` or PEP 604 ``A | B``) whose
          members are model classes.

        Nested containers inside unions are not searched.

        Args:
            tp (Any): The type annotation to inspect.

        Returns:
            List[Type[BaseModel]]: Distinct models in annotation order (possibly empty).
        """
        tp = StructureHandler.unwrap_annotated_type(tp)
        if StructureHandler._is_model_class(tp):
            return [tp]

        origin = get_origin(tp)
        args = get_args(tp)

        if origin in (list, List, tuple, Iterable) and args:
            inner = args[0]
            if StructureHandler._is_model_class(inner):
                return [inner]
            if not isinstance(inner, type):
                logger.debug(
                    f"[resolve] Skipping non-class inner: {inner} ({type(inner)})"
                )

        if origin in StructureHandler._UNION_ORIGINS:
            models = [arg for arg in args if StructureHandler._is_model_class(arg)]
            return list(dict.fromkeys(models))

        return []

    @staticmethod
    def resolve_response_model(tp: Any) -> Optional[Type[BaseModel]]:
        """
        Resolves a single Pydantic model from a type annotation if available.

        This method attempts to extract exactly one BaseModel from the given type. It is used
        to determine if structured output formatting (e.g., JSON schema or function call) should be applied.

        - If the annotation is a BaseModel or a container of one (e.g., List[BaseModel]), it returns the model.
        - If no Pydantic models are found (e.g., str, int), it returns None without logging.
        - If multiple Pydantic models are found (e.g., Union[ModelA, ModelB]), it logs a warning and returns None.

        Args:
            tp (Any): The return type annotation to analyze.

        Returns:
            Optional[Type[BaseModel]]: The resolved model class, or None if not applicable or ambiguous.
        """
        tp = StructureHandler.unwrap_annotated_type(tp)
        models = StructureHandler.resolve_all_pydantic_models(tp)

        if len(models) == 1:
            return models[0]
        if not models:
            return None  # No model = primitive or unsupported type → silently skip

        logger.warning(
            f"Ambiguous model resolution: found multiple models in {tp}. Returning None."
        )
        return None

    @staticmethod
    def validate_against_signature(result: Any, expected_type: Any) -> Any:
        """
        Validates a result against an expected return annotation type.

        Supports:
        - Single BaseModel
        - List[BaseModel]
        - Union[BaseModel, ...], including ``Optional[...]`` and ``A | B``
        - Primitives (int, str, bool, etc.)
        - Dict[str, Any], List[Dict[str, Any]]

        Candidate Pydantic models are tried in annotation order. A model that
        rejects the result is skipped, so the next candidate, or a non-model union
        member such as ``None`` or ``str``, can still match.

        Returns:
            Any: The validated and possibly transformed result. A model match is
            returned as ``model_dump()`` output, or a list of them for list results.

        Raises:
            TypeError: If the final TypeAdapter validation fails.
        """
        expected_type = StructureHandler.unwrap_annotated_type(expected_type)
        origin = get_origin(expected_type)
        args = get_args(expected_type)

        # Handle one or more BaseModels. validate_response signals rejection via
        # StructureError (failed validation) or ValueError (unsupported input);
        # either means "try the next candidate".
        for model_cls in StructureHandler.resolve_all_pydantic_models(expected_type):
            try:
                if isinstance(result, list):
                    return [
                        StructureHandler.validate_response(item, model_cls).model_dump()
                        for item in result
                    ]
                return StructureHandler.validate_response(result, model_cls).model_dump()
            except (StructureError, ValidationError, ValueError):
                continue

        # Handle Union[str, dict, None, etc.]
        if origin in StructureHandler._UNION_ORIGINS:
            for variant in args:
                if isinstance(variant, type) and isinstance(result, variant):
                    return result

        # Handle Dict[str, Any]
        if origin is dict and isinstance(result, dict):
            return result

        # Handle List[dict]
        if origin is list and args and args[0] is dict and isinstance(result, list):
            return result

        # Fallback for primitives via TypeAdapter
        try:
            logger.debug(f"Falling back to TypeAdapter for type: {expected_type}")
            adapter = TypeAdapter(expected_type)
            return adapter.validate_python(result)
        except ValidationError as e:
            raise TypeError(f"Validation failed for type {expected_type}: {e}") from e