diff --git a/app/config.py b/app/config.py index c370680..402fee1 100644 --- a/app/config.py +++ b/app/config.py @@ -25,6 +25,11 @@ class Settings(BaseSettings): dd_env: str = Field(default="development", alias="DD_ENV") dd_version: str = Field(default="1.0.0", alias="DD_VERSION") + # Rate limiting settings + rate_limit_enabled: bool = Field(default=True, alias="RATE_LIMIT_ENABLED") + rate_limit_requests: int = Field(default=100, alias="RATE_LIMIT_REQUESTS") + rate_limit_period: str = Field(default="1/minute", alias="RATE_LIMIT_PERIOD") + class Config: env_file = ".env" case_sensitive = False diff --git a/app/middleware.py b/app/middleware.py index 0ab17e2..ec12d00 100644 --- a/app/middleware.py +++ b/app/middleware.py @@ -12,7 +12,10 @@ from app.logging import get_logger logger = get_logger(__name__) # Rate limiter -limiter = Limiter(key_func=get_remote_address) +limiter = Limiter( + key_func=get_remote_address, + default_limits=[f"{settings.rate_limit_requests}/{settings.rate_limit_period}"] if settings.rate_limit_enabled else [] +) async def logging_middleware(request: Request, call_next: Callable) -> Response: diff --git a/main.py b/main.py index cba66d7..e730670 100644 --- a/main.py +++ b/main.py @@ -5,10 +5,13 @@ import sentry_sdk from ddtrace import patch_all from fastapi import APIRouter, FastAPI from sentry_sdk.integrations.fastapi import FastApiIntegration +from slowapi import _rate_limit_exceeded_handler +from slowapi.errors import RateLimitExceeded +from slowapi.middleware import SlowAPIMiddleware from app.config import settings from app.logging import configure_logging, get_logger -from app.middleware import logging_middleware +from app.middleware import logging_middleware, limiter from app.resources import health @@ -41,6 +44,15 @@ def create_app() -> FastAPI: redoc_url="/redoc", ) + # Add rate limiting middleware if enabled + if settings.rate_limit_enabled: + app.state.limiter = limiter + app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) + app.add_middleware(SlowAPIMiddleware) + + # Add logging middleware + app.middleware("http")(logging_middleware) + # Include all endpoint routers app.include_router(health.router, tags=["health"])