enhance vector search output & robust URL builder for OpenAPI executor

This commit is contained in:
Roberto Rodriguez 2025-04-21 02:30:25 -04:00
parent f5dc9372e7
commit be3ce1df06
1 changed files with 107 additions and 80 deletions

View File

@ -1,107 +1,134 @@
from dapr_agents.tool.utils.openapi import OpenAPISpecParser
from dapr_agents.tool.storage import VectorToolStore
from dapr_agents.tool.base import tool
from pydantic import BaseModel ,Field, ConfigDict
from typing import Optional, Any, Dict
import json, logging, requests
from urllib.parse import urlparse
import json
import requests
from typing import Any, Dict, Optional, List
def extract_version(path: str) -> str:
from pydantic import BaseModel, Field, ConfigDict
from dapr_agents.tool.base import tool
from dapr_agents.tool.storage import VectorToolStore
from dapr_agents.tool.utils.openapi import OpenAPISpecParser
logger = logging.getLogger(__name__)
def _extract_version(path: str) -> str:
"""Extracts the version prefix from a path if it exists, assuming it starts with 'v' followed by digits."""
parts = path.strip('/').split('/')
if parts and parts[0].startswith('v') and parts[0][1:].isdigit():
return parts[0]
return ''
seg = path.lstrip("/").split("/", 1)[0]
return seg if seg.startswith("v") and seg[1:].isdigit() else ""
def generate_get_openapi_definition(tool_vector_store: VectorToolStore):
@tool
def get_openapi_definition(user_input: str):
def _join_url(base: str, path: str) -> str:
    """Join a server base URL and an operation path into one request URL.

    The result is ``scheme://netloc/`` followed by the base URL's path and
    *path*, separated by single slashes. When the base path and *path* both
    start with the same version segment (for example ``v1``), the copy in
    *path* is dropped so the version is not repeated. Any query string or
    fragment on *base* is not carried over.

    Args:
        base: Server URL, e.g. ``https://api.example.com/v1``.
        path: Operation path with placeholders already substituted.

    Returns:
        The combined URL.
    """
    parsed = urlparse(base)
    origin = f"{parsed.scheme}://{parsed.netloc}"
    base_path = parsed.path.strip("/")
    rel_path = path.lstrip("/")

    base_version = _extract_version(base_path)
    if base_version and base_version == _extract_version(rel_path):
        # Slice the already-stripped path: slicing the raw path by
        # len("/vN") would cut into the version itself whenever the path
        # has no leading slash or more than one.
        rel_path = rel_path[len(base_version):].lstrip("/")

    joined = "/".join(piece for piece in (base_path, rel_path) if piece)
    # A single replace() leaves "//" behind for runs of three or more
    # slashes, so repeat until none remain.
    while "//" in joined:
        joined = joined.replace("//", "/")
    return f"{origin}/{joined}"
def _fmt_candidate(doc: str, meta: Dict[str, Any]) -> str:
"""Return a single nice, log-friendly candidate string."""
meta_line = f"url={meta.get('url')} | method={meta.get('method', '').upper()} | name={meta.get('name')}"
return f"{doc.strip()}\n{meta_line}"
# Argument model for the `get_openapi_definition` tool; it is handed to the
# `@tool` decorator as `args_model`.
# NOTE(review): the class docstring and the Field description are presumably
# surfaced in the generated tool schema -- treat them as runtime text.
class GetDefinitionInput(BaseModel):
    """Free-form query describing *one* desired operation (e.g. "multiply two numbers")."""
    # Required field: `...` means there is no default value.
    user_input: str = Field(..., description="Natural-language description of ONE desired API operation.")
def generate_get_openapi_definition(store: VectorToolStore):
    # Factory: builds the `get_openapi_definition` tool bound to *store* and
    # returns it. The inner docstring is kept as the tool-facing text.
    @tool(args_model=GetDefinitionInput)
    def get_openapi_definition(user_input: str) -> List[str]:
        """
        Search the vector store for OpenAPI *operation IDs / paths* most relevant
        to **one** user task.

        Always call this **once per new task** *before* attempting an
        `open_api_call_executor`. Returns up to 5 candidate operations.
        """
        result = store.get_similar_tools(query_texts=[user_input], k=5)
        # One query text was sent, so index 0 selects the hits for that query.
        docs: List[str] = result["documents"][0]
        metas: List[Dict[str, Any]] = result["metadatas"][0]
        # Pair each document with its metadata and render one string per hit.
        return [_fmt_candidate(d, m) for d, m in zip(docs, metas)]

    return get_openapi_definition
def generate_api_call_executor(spec_parser: OpenAPISpecParser, auth_header: Dict = None):
base_url = spec_parser.spec.servers[0].url
# Argument model for the `open_api_call_executor` tool (single definition;
# dict fields use `default_factory` so each instance gets its own dict).
# NOTE(review): the Field descriptions are presumably surfaced in the tool
# schema -- treat them as runtime text.
class OpenAPIExecutorInput(BaseModel):
    path_template: str = Field(..., description="Path template, may contain `{placeholder}` segments.")
    method: str = Field(..., description="HTTP verb, uppercase.")
    path_params: Dict[str, Any] = Field(default_factory=dict, description="Replacements for path placeholders.")
    data: Dict[str, Any] = Field(default_factory=dict, description="JSON body for POST/PUT/PATCH.")
    headers: Optional[Dict[str, Any]] = Field(default=None, description="Extra request headers.")
    params: Optional[Dict[str, Any]] = Field(default=None, description="Query params (?key=value).")

    # Unknown keys are accepted rather than rejected.
    model_config = ConfigDict(extra="allow")
def generate_api_call_executor(spec: OpenAPISpecParser, auth_header: Optional[Dict[str, str]] = None):
    # Factory: builds the `open_api_call_executor` tool for the first server
    # listed in *spec*. `auth_header` is merged into every request's headers.
    # The inner docstring is kept as the tool-facing text.
    base_url = spec.spec.servers[0].url  # assumes at least one server entry

    @tool(args_model=OpenAPIExecutorInput)
    def open_api_call_executor(
        *,
        path_template: str,
        method: str,
        path_params: Dict[str, Any],
        data: Dict[str, Any],
        headers: Optional[Dict[str, Any]] = None,
        params: Optional[Dict[str, Any]] = None,
        **req_kwargs: Any,
    ) -> Any:
        """
        Execute **one** REST call described by an OpenAPI operation.

        Use this only *after* `get_openapi_definition` has returned a matching
        `path_template`/`method`.

        Authentication: merge `auth_header` given at agent-init time with
        any per-call `headers` argument (per-call overrides duplicates).
        """
        url = _join_url(base_url, path_template.format(**path_params))

        # Copy so the shared `auth_header` dict is never mutated by a call.
        final_headers = (auth_header or {}).copy()
        if headers:
            final_headers.update(headers)

        # Only send a body for verbs that carry one, and only when there is
        # actually data to send.
        body = None
        if method.upper() in ("POST", "PUT", "PATCH") and data:
            body = json.dumps(data)
            # A pre-serialised string body gets no Content-Type from
            # requests, so declare JSON unless the caller already set one.
            if not any(name.lower() == "content-type" for name in final_headers):
                final_headers["Content-Type"] = "application/json"

        # redact auth key in debug logs
        safe_hdrs = {
            k: ("***" if "auth" in k.lower() or "key" in k.lower() else v)
            for k, v in final_headers.items()
        }
        logger.debug(
            "%s %s | headers=%s params=%s data=%s",
            method, url, safe_hdrs, params,
            "***" if body else None,
        )

        # For debugging purposes, similar to the old implementation
        print(f"Base Url: {base_url}")
        print(f"Requested Url: {url}")
        print(f"Requested Method: {method}")
        print(f"Requested Parameters: {params}")

        # requests waits indefinitely without a timeout; apply a default
        # that callers can still override through the extra keyword args.
        req_kwargs.setdefault("timeout", 30)

        resp = requests.request(
            method, url, headers=final_headers, params=params, data=body, **req_kwargs
        )
        # Surface HTTP 4xx/5xx as exceptions instead of parsing an error body.
        resp.raise_for_status()
        return resp.json()

    return open_api_call_executor