# Mirror of https://github.com/dapr/dapr-agents.git
# Python source, 170 lines (5.7 KiB).
import os
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from distutils.util import strtobool
|
|
from contextlib import asynccontextmanager
|
|
from fastapi import FastAPI
|
|
from pydantic import Field, ConfigDict, PrivateAttr
|
|
from typing import List, Optional, Any
|
|
from dapr_agents.service import APIServerBase
|
|
import uvicorn
|
|
import asyncio
|
|
import signal
|
|
import logging
|
|
|
|
# Module-level logger used by the server class in this module; when OpenTelemetry
# is enabled it is also handed to the OTel logging provider.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class FastAPIServerBase(APIServerBase):
    """
    Abstract base class for FastAPI-based API server services.

    Provides core FastAPI functionality, with support for CORS, optional
    OpenTelemetry instrumentation, lifecycle management, and graceful shutdown.
    """

    description: Optional[str] = Field(
        None, description="Description of the API service."
    )
    cors_origins: Optional[List[str]] = Field(
        default_factory=lambda: ["*"], description="Allowed CORS origins."
    )
    cors_credentials: bool = Field(
        True, description="Whether to allow credentials in CORS requests."
    )
    cors_methods: Optional[List[str]] = Field(
        default_factory=lambda: ["*"], description="Allowed HTTP methods for CORS."
    )
    cors_headers: Optional[List[str]] = Field(
        default_factory=lambda: ["*"], description="Allowed HTTP headers for CORS."
    )

    # Fields initialized in model_post_init (app) and start() (server)
    app: Optional[FastAPI] = Field(
        default=None, init=False, description="The FastAPI application instance."
    )
    server: Optional[Any] = Field(
        default=None,
        init=False,
        description="Server handle for running the FastAPI app.",
    )

    # Resolved from the DAPR_AGENTS_OTEL_ENABLED env var in model_post_init.
    _otel_enabled: Optional[bool] = PrivateAttr(default=True)

    model_config = ConfigDict(arbitrary_types_allowed=True)

    @staticmethod
    def _parse_env_flag(value: str) -> bool:
        """
        Interpret a boolean-like string with the same vocabulary as
        ``distutils.util.strtobool`` (distutils was removed in Python 3.12).

        Accepts y/yes/t/true/on/1 and n/no/f/false/off/0, case-insensitively
        and ignoring surrounding whitespace.

        Raises:
            ValueError: if ``value`` is not one of the recognized spellings.
        """
        normalized = value.strip().lower()
        if normalized in ("y", "yes", "t", "true", "on", "1"):
            return True
        if normalized in ("n", "no", "f", "false", "off", "0"):
            return False
        raise ValueError(f"invalid truth value {value!r}")

    def model_post_init(self, __context: Any) -> None:
        """
        Post-initialization to configure the core FastAPI app, CORS settings
        and, when enabled, OpenTelemetry tracing/logging.

        OpenTelemetry is enabled unless DAPR_AGENTS_OTEL_ENABLED holds a false
        value; an unrecognized value disables it and logs a warning.
        """
        raw_otel_flag = os.getenv("DAPR_AGENTS_OTEL_ENABLED", "True")
        try:
            self._otel_enabled = self._parse_env_flag(raw_otel_flag)
        except ValueError:
            logger.warning(
                f"Invalid DAPR_AGENTS_OTEL_ENABLED value {raw_otel_flag!r}; "
                "OpenTelemetry disabled."
            )
            self._otel_enabled = False

        # Initialize FastAPI app with title and description. This must happen
        # before OTel instrumentation, which needs the real app object.
        self.app = FastAPI(
            title=f"{self.service_name} API Server",
            description=self.description or self.service_name,
            lifespan=self.lifespan,
        )

        # Configure CORS settings
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=self.cors_origins,
            allow_credentials=self.cors_credentials,
            allow_methods=self.cors_methods,
            allow_headers=self.cors_headers,
        )

        if self._otel_enabled:
            self._configure_otel()

        logger.info(
            f"{self.service_name} FastAPI server initialized on port {self.service_port} with CORS settings."
        )

        # Call the base post-initialization
        super().model_post_init(__context)

    def _configure_otel(self) -> None:
        """Install OpenTelemetry tracer/logger providers and instrument ``self.app``."""
        # Imported here so these modules are only loaded when OTel is enabled.
        from dapr_agents.agent import DaprAgentsOTel
        from opentelemetry import trace
        from opentelemetry._logs import set_logger_provider
        from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor

        otel_client = DaprAgentsOTel(
            service_name="FastAPI Server",
            otlp_endpoint=os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT", ""),
        )

        tracer_provider = otel_client.create_and_instrument_tracer_provider()
        trace.set_tracer_provider(tracer_provider)

        logger_provider = otel_client.create_and_instrument_logging_provider(
            logger=logger,
        )
        set_logger_provider(logger_provider)

        # self.app defaults to None; callers must have created it before this
        # point so the instrumentor receives the actual FastAPI application.
        FastAPIInstrumentor.instrument_app(self.app)

    @asynccontextmanager
    async def lifespan(self, app: FastAPI):
        """
        Default lifespan function to manage startup and shutdown processes.

        Can be overridden by subclasses to add setup and teardown tasks such as
        handling agent metadata. On shutdown it always calls ``stop()``.
        """
        try:
            yield
        finally:
            await self.stop()

    def _install_signal_handlers(self) -> None:
        """Route SIGINT/SIGTERM to ``stop()`` when the running loop allows it."""
        loop = asyncio.get_running_loop()
        for sig in (signal.SIGINT, signal.SIGTERM):
            try:
                loop.add_signal_handler(sig, lambda: asyncio.create_task(self.stop()))
            except (NotImplementedError, RuntimeError):
                # NotImplementedError: the loop has no signal support (e.g. on
                # Windows). RuntimeError: the handler could not be set up
                # (e.g. when not running in the main thread).
                logger.debug(
                    f"{self.service_name} could not install a handler for {sig!r}"
                )

    async def start(self, log_level=None):
        """
        Start the FastAPI app server on the running event loop with a specified
        logging level, and ensure that shutdown is handled gracefully with
        SIGINT and SIGTERM signals where the platform supports it.

        If ``service_port`` is falsy, port 0 is requested so the OS picks a free
        port; the port actually bound is then written back to ``service_port``.

        Args:
            log_level: uvicorn log level. Defaults to the lower-cased name of
                this module logger's effective level.
        """
        if log_level is None:
            log_level = logging.getLevelName(logger.getEffectiveLevel()).lower()

        # Set port to 0 if we want a random port
        requested_port = self.service_port or 0

        config = uvicorn.Config(
            self.app,
            host=self.service_host,
            port=requested_port,
            log_level=log_level,
        )
        self.server = uvicorn.Server(config)

        self._install_signal_handlers()

        # Start in background so we can inspect the actual port
        server_task = asyncio.create_task(self.server.serve())

        # Wait for startup to complete, but stop waiting if serve() finishes
        # first (e.g. a startup failure) -- otherwise this would spin forever.
        while not self.server.started:
            if server_task.done():
                logger.warning(
                    f"{self.service_name} server exited before startup completed"
                )
                await server_task  # re-raises any exception raised by serve()
                return
            await asyncio.sleep(0.1)

        # Extract the real port from the bound socket
        servers = list(self.server.servers)
        if servers and servers[0].sockets:
            self.service_port = servers[0].sockets[0].getsockname()[1]
        else:
            logger.warning(f"{self.service_name} could not determine bound port")

        await server_task

    async def stop(self):
        """
        Stop the FastAPI server gracefully by asking it to exit.

        Safe to call before ``start()``: it does nothing while no server exists.
        """
        if self.server:
            logger.info(
                f"Stopping {self.service_name} server on port {self.service_port}."
            )
            self.server.should_exit = True
|