mirror of https://github.com/dapr/dapr-agents.git
				
				
				
			enhance vector search output & robust URL builder for OpenAPI executor
This commit is contained in:
		
							parent
							
								
									f5dc9372e7
								
							
						
					
					
						commit
						be3ce1df06
					
				| 
						 | 
				
			
			@ -1,107 +1,134 @@
 | 
			
		|||
 | 
			
		||||
from dapr_agents.tool.utils.openapi import OpenAPISpecParser
 | 
			
		||||
from dapr_agents.tool.storage import VectorToolStore
 | 
			
		||||
from dapr_agents.tool.base import tool
 | 
			
		||||
from pydantic import BaseModel ,Field, ConfigDict
 | 
			
		||||
from typing import Optional, Any, Dict
 | 
			
		||||
import json, logging, requests
 | 
			
		||||
from urllib.parse import urlparse
 | 
			
		||||
import json
 | 
			
		||||
import requests
 | 
			
		||||
from typing import Any, Dict, Optional, List
 | 
			
		||||
 | 
			
		||||
def extract_version(path: str) -> str:
 | 
			
		||||
from pydantic import BaseModel, Field, ConfigDict
 | 
			
		||||
 | 
			
		||||
from dapr_agents.tool.base import tool
 | 
			
		||||
from dapr_agents.tool.storage import VectorToolStore
 | 
			
		||||
from dapr_agents.tool.utils.openapi import OpenAPISpecParser
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _extract_version(path: str) -> str:
 | 
			
		||||
    """Extracts the version prefix from a path if it exists, assuming it starts with 'v' followed by digits."""
 | 
			
		||||
    parts = path.strip('/').split('/')
 | 
			
		||||
    if parts and parts[0].startswith('v') and parts[0][1:].isdigit():
 | 
			
		||||
        return parts[0]
 | 
			
		||||
    return ''
 | 
			
		||||
    seg = path.lstrip("/").split("/", 1)[0]
 | 
			
		||||
    return seg if seg.startswith("v") and seg[1:].isdigit() else ""
 | 
			
		||||
 | 
			
		||||
def generate_get_openapi_definition(tool_vector_store: VectorToolStore):
 | 
			
		||||
    @tool
 | 
			
		||||
    def get_openapi_definition(user_input: str):
 | 
			
		||||
 | 
			
		||||
def _join_url(base: str, path: str) -> str:
 | 
			
		||||
    """
 | 
			
		||||
    Join *base* and *path* while avoiding duplicated version segments
 | 
			
		||||
    and double slashes. Assumes base already ends at the **/servers[0].url**.
 | 
			
		||||
    """
 | 
			
		||||
    parsed = urlparse(base)
 | 
			
		||||
    origin = f"{parsed.scheme}://{parsed.netloc}"
 | 
			
		||||
    base_path = parsed.path.strip("/")
 | 
			
		||||
 | 
			
		||||
    b_ver, p_ver = _extract_version(base_path), _extract_version(path)
 | 
			
		||||
    if b_ver and b_ver == p_ver:
 | 
			
		||||
        path = path[len(f"/{p_ver}") :]
 | 
			
		||||
 | 
			
		||||
    pieces = [p for p in (base_path, path.lstrip("/")) if p]
 | 
			
		||||
    return f"{origin}/" + "/".join(pieces).replace("//", "/")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _fmt_candidate(doc: str, meta: Dict[str, Any]) -> str:
 | 
			
		||||
    """Return a single nice, log-friendly candidate string."""
 | 
			
		||||
    meta_line = f"url={meta.get('url')} | method={meta.get('method', '').upper()} | name={meta.get('name')}"
 | 
			
		||||
    return f"{doc.strip()}\n{meta_line}"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GetDefinitionInput(BaseModel):
 | 
			
		||||
    """Free-form query describing *one* desired operation (e.g. "multiply two numbers")."""
 | 
			
		||||
    user_input: str = Field(..., description="Natural-language description of ONE desired API operation.")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def generate_get_openapi_definition(store: VectorToolStore):
 | 
			
		||||
    @tool(args_model=GetDefinitionInput)
 | 
			
		||||
    def get_openapi_definition(user_input: str) -> List[str]:
 | 
			
		||||
        """
 | 
			
		||||
        Get potential APIs for the user to use to accompish task.
 | 
			
		||||
        You have to choose the right one after getting a response.
 | 
			
		||||
        This tool MUST be used before calling any APIs.
 | 
			
		||||
        Search the vector store for OpenAPI *operation IDs / paths* most relevant
 | 
			
		||||
        to **one** user task.
 | 
			
		||||
 | 
			
		||||
        Always call this **once per new task** *before* attempting an
 | 
			
		||||
        `open_api_call_executor`. Returns up to 5 candidate operations.
 | 
			
		||||
        """
 | 
			
		||||
        similatiry_result =  tool_vector_store.get_similar_tools(query_texts=[user_input], k=5)
 | 
			
		||||
        documents = similatiry_result['documents'][0]
 | 
			
		||||
        return documents
 | 
			
		||||
        result = store.get_similar_tools(query_texts=[user_input], k=5)
 | 
			
		||||
        docs: List[str] = result["documents"][0]
 | 
			
		||||
        metas: List[Dict[str, Any]] = result["metadatas"][0]
 | 
			
		||||
 | 
			
		||||
        return [_fmt_candidate(d, m) for d, m in zip(docs, metas)]
 | 
			
		||||
 | 
			
		||||
    return get_openapi_definition
 | 
			
		||||
 | 
			
		||||
def generate_api_call_executor(spec_parser: OpenAPISpecParser, auth_header: Dict = None):
 | 
			
		||||
    base_url = spec_parser.spec.servers[0].url
 | 
			
		||||
    
 | 
			
		||||
    class OpenAPIExecutorInput(BaseModel):
 | 
			
		||||
        path_template: str = Field(description="Template of the API path that may include placeholders.")
 | 
			
		||||
        method: str = Field(description="The HTTP method to be used for the API call (e.g., 'GET', 'POST').")
 | 
			
		||||
        path_params: Dict[str, Any] = Field(default={}, description="Path parameters to be replaced in the path template.")
 | 
			
		||||
        data: Dict[str, Any] = Field(default={}, description="Data to be sent in the body of the request, applicable for POST, PUT methods.")
 | 
			
		||||
        headers: Optional[Dict[str, Any]] = Field(default=None, description="HTTP headers to send with the request.")
 | 
			
		||||
        params: Optional[Dict[str, Any]] = Field(default=None, description="Query parameters to append to the URL.")
 | 
			
		||||
        
 | 
			
		||||
        model_config = ConfigDict(extra="allow")
 | 
			
		||||
 | 
			
		||||
class OpenAPIExecutorInput(BaseModel):
 | 
			
		||||
    path_template: str = Field(..., description="Path template, may contain `{placeholder}` segments.")
 | 
			
		||||
    method: str = Field(..., description="HTTP verb, upper‑case.")
 | 
			
		||||
    path_params: Dict[str, Any] = Field(default_factory=dict, description="Replacements for path placeholders.")
 | 
			
		||||
    data: Dict[str, Any] = Field(default_factory=dict, description="JSON body for POST/PUT/PATCH.")
 | 
			
		||||
    headers: Optional[Dict[str, Any]] = Field(default=None, description="Extra request headers.")
 | 
			
		||||
    params: Optional[Dict[str, Any]] = Field(default=None, description="Query params (?key=value).")
 | 
			
		||||
 | 
			
		||||
    model_config = ConfigDict(extra="allow")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def generate_api_call_executor(spec: OpenAPISpecParser, auth_header: Optional[Dict[str, str]] = None):
 | 
			
		||||
    base_url = spec.spec.servers[0].url  # assumes at least one server entry
 | 
			
		||||
 | 
			
		||||
    @tool(args_model=OpenAPIExecutorInput)
 | 
			
		||||
    def open_api_call_executor(
 | 
			
		||||
        *,
 | 
			
		||||
        path_template: str,
 | 
			
		||||
        method: str,
 | 
			
		||||
        path_params: Dict[str, Any],
 | 
			
		||||
        data: Dict[str, Any],
 | 
			
		||||
        headers: Optional[Dict[str, Any]] = None,
 | 
			
		||||
        params: Optional[Dict[str, Any]] = None,
 | 
			
		||||
        **kwargs: Any
 | 
			
		||||
        **req_kwargs: Any,
 | 
			
		||||
    ) -> Any:
 | 
			
		||||
        """
 | 
			
		||||
        Execute an API call based on provided parameters and configuration.
 | 
			
		||||
        It MUST be used after the get_openapi_definition to call APIs.
 | 
			
		||||
        Make sure to include the right header values to authenticate to the API if needed.
 | 
			
		||||
        """
 | 
			
		||||
        
 | 
			
		||||
        # Format the path with path_params
 | 
			
		||||
        formatted_path = path_template.format(**path_params)
 | 
			
		||||
        
 | 
			
		||||
        # Parse the base_url and extract the version
 | 
			
		||||
        parsed_url = urlparse(base_url)
 | 
			
		||||
        origin = f"{parsed_url.scheme}://{parsed_url.netloc}"
 | 
			
		||||
        base_path = parsed_url.path.strip('/')
 | 
			
		||||
        
 | 
			
		||||
        base_version = extract_version(base_path)
 | 
			
		||||
        path_version = extract_version(formatted_path)
 | 
			
		||||
        Execute **one** REST call described by an OpenAPI operation.
 | 
			
		||||
 | 
			
		||||
        # Avoid duplication of the version in the final URL
 | 
			
		||||
        if base_version and path_version == base_version:
 | 
			
		||||
            formatted_path = formatted_path[len(f"/{path_version}"):]
 | 
			
		||||
        
 | 
			
		||||
        # Ensure there is a single slash between origin, base_path, and formatted_path
 | 
			
		||||
        final_url = f"{origin}/{base_path}/{formatted_path}".replace('//', '/')
 | 
			
		||||
        # Fix the issue by ensuring the correct scheme with double slashes
 | 
			
		||||
        if not final_url.startswith('https://') and parsed_url.scheme == 'https':
 | 
			
		||||
            final_url = final_url.replace('https:/', 'https://')
 | 
			
		||||
        
 | 
			
		||||
        # Initialize the headers with auth_header if provided
 | 
			
		||||
        final_headers = auth_header if auth_header else {}
 | 
			
		||||
        # Update the final_headers with additional headers passed to the function
 | 
			
		||||
        Use this only *after* `get_openapi_definition` has returned a matching
 | 
			
		||||
        `path_template`/`method`.
 | 
			
		||||
 | 
			
		||||
        Authentication: merge `auth_header` given at agent-init time with
 | 
			
		||||
        any per-call `headers` argument (per-call overrides duplicates).
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        url = _join_url(base_url, path_template.format(**path_params))
 | 
			
		||||
 | 
			
		||||
        final_headers = (auth_header or {}).copy()
 | 
			
		||||
        if headers:
 | 
			
		||||
            final_headers.update(headers)
 | 
			
		||||
 | 
			
		||||
        if data:
 | 
			
		||||
            data = json.dumps(data)  # Convert data to JSON string if not empty
 | 
			
		||||
 | 
			
		||||
        request_kwargs = {
 | 
			
		||||
            "headers": final_headers,
 | 
			
		||||
            "params": params,
 | 
			
		||||
            "data": data,
 | 
			
		||||
            **kwargs
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        # redact auth key in debug logs
 | 
			
		||||
        safe_hdrs = {k: ("***" if "auth" in k.lower() or "key" in k.lower() else v)
 | 
			
		||||
                     for k, v in final_headers.items()}
 | 
			
		||||
        
 | 
			
		||||
        # Only convert data to JSON if we're doing a request that requires a body
 | 
			
		||||
        # and there's actually data to send
 | 
			
		||||
        body = None
 | 
			
		||||
        if method.upper() in ["POST", "PUT", "PATCH"] and data:
 | 
			
		||||
            body = json.dumps(data)
 | 
			
		||||
        
 | 
			
		||||
        # Add more detailed logging similar to old implementation
 | 
			
		||||
        logger.debug("→ %s %s | headers=%s params=%s data=%s", 
 | 
			
		||||
                    method, url, safe_hdrs, params, 
 | 
			
		||||
                    "***" if body else None)
 | 
			
		||||
        
 | 
			
		||||
        # For debugging purposes, similar to the old implementation
 | 
			
		||||
        print(f"Base Url: {base_url}")
 | 
			
		||||
        print(f"Requested Url: {final_url}")
 | 
			
		||||
        print(f"Requested Url: {url}")
 | 
			
		||||
        print(f"Requested Method: {method}")
 | 
			
		||||
        print(f"Requested Parameters: {params}")
 | 
			
		||||
 | 
			
		||||
        # Filter out None values to avoid sending them to requests
 | 
			
		||||
        request_kwargs = {k: v for k, v in request_kwargs.items() if v is not None}
 | 
			
		||||
 | 
			
		||||
        response = requests.request(method, final_url, **request_kwargs)
 | 
			
		||||
        return response.json()
 | 
			
		||||
        
 | 
			
		||||
        resp = requests.request(method, url, headers=final_headers,
 | 
			
		||||
                                params=params, data=body, **req_kwargs)
 | 
			
		||||
        resp.raise_for_status()
 | 
			
		||||
        return resp.json()
 | 
			
		||||
 | 
			
		||||
    return open_api_call_executor
 | 
			
		||||
		Loading…
	
		Reference in New Issue