FastAPI 中间件的使用

发布于:2025-07-21 ⋅ 阅读:(14) ⋅ 点赞:(0)

FastAPI 中间件的使用

一、中间件核心概念

1. 什么是中间件

在 FastAPI 中,中间件是处理 HTTP 请求和响应的拦截器,位于客户端和路由处理函数之间:

客户端 → 中间件 → 路由处理 → 中间件 → 客户端

2. 中间件核心作用

  • 请求预处理(认证、日志、限流)
  • 响应后处理(添加头信息、修改响应)
  • 全局错误处理
  • 性能监控

3. 中间件类型

类型 执行时机 典型应用
HTTP 中间件 每次 HTTP 请求 日志、认证
ASGI 中间件 更底层处理 协议转换
数据库中间件 连接池管理 连接复用
错误处理中间件 异常发生时 统一错误格式

二、创建自定义中间件

1. 基础中间件模板

from fastapi import FastAPI, Request
import time

app = FastAPI()

@app.middleware("http")
async def custom_middleware(request: Request, call_next):
    # 请求处理前逻辑
    start_time = time.time()
    print(f"Request started for {request.url}")
    
    # 调用路由处理函数
    response = await call_next(request)
    
    # 响应处理后逻辑
    process_time = time.time() - start_time
    response.headers["X-Process-Time"] = str(process_time)
    print(f"Request completed in {process_time:.2f}s")
    
    return response

2. 完整中间件组件示例

from fastapi import FastAPI, Request, HTTPException
from starlette.types import ASGIApp, Scope, Receive, Send

class RateLimitMiddleware:
    def __init__(self, app: ASGIApp, limit: int = 10):
        self.app = app
        self.limit = limit
        self.request_counts = {}

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        # 仅处理 HTTP 请求
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return
        
        client_ip = scope["client"][0]  # 获取客户端IP
        
        # 限流逻辑
        if client_ip in self.request_counts:
            if self.request_counts[client_ip] >= self.limit:
                raise HTTPException(status_code=429, detail="Rate limit exceeded")
            self.request_counts[client_ip] += 1
        else:
            self.request_counts[client_ip] = 1
        
        # 调用下一个中间件或路由
        await self.app(scope, receive, send)

# 注册中间件
app = FastAPI()
app.add_middleware(RateLimitMiddleware, limit=5)  # IP每秒5次请求

三、内置中间件使用指南

1. CORS 跨域中间件

from fastapi.middleware.cors import CORSMiddleware

app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://example.com", "http://localhost:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

2. HTTPS 重定向中间件

from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware

app.add_middleware(HTTPSRedirectMiddleware)  # 强制所有HTTP请求重定向到HTTPS

3. GZip 压缩中间件

from fastapi.middleware.gzip import GZipMiddleware

app.add_middleware(GZipMiddleware, minimum_size=500)  # 压缩大于500字节的响应

4. 信任主机中间件

from fastapi.middleware.trustedhost import TrustedHostMiddleware

app.add_middleware(TrustedHostMiddleware, allowed_hosts=["example.com", "*.example.com"])

四、高级中间件模式

1. 中间件执行顺序

中间件按添加顺序的反向执行请求处理,按添加顺序执行响应处理:

客户端请求
中间件1
中间件2
路由处理
中间件2
中间件1
客户端响应

2. 数据库连接池中间件

from contextlib import asynccontextmanager
from fastapi import FastAPI
from databases import Database

database = Database("sqlite:///test.db")

@app.middleware("http")
async def db_session_middleware(request: Request, call_next):
    # 请求开始:获取数据库连接
    request.state.db = database
    await database.connect()
    
    try:
        response = await call_next(request)
    finally:
        # 请求结束:释放连接
        await database.disconnect()
    
    return response

# 在路由中使用,在进入路由处理之前,已经连接了数据库
@app.get("/items")
async def read_items(request: Request):
    db = request.state.db
    query = "SELECT * FROM items"
    return await db.fetch_all(query)
# 处理完成之后,返回到中间件,释放连接

3. 认证中间件

from fastapi import Request, HTTPException

API_KEYS = {"valid_key_123", "another_key_456"}

@app.middleware("http")
async def api_key_auth(request: Request, call_next):
    # 从Header获取API Key
    api_key = request.headers.get("X-API-KEY")
    
    if not api_key or api_key not in API_KEYS:
        raise HTTPException(status_code=401, detail="Invalid API Key")
    
    return await call_next(request)

4. 性能监控中间件

import time
from fastapi import Request
from prometheus_client import Counter, Histogram

REQUEST_COUNT = Counter(
    "http_requests_total",
    "Total HTTP Requests",
    ["method", "path", "status"]
)

REQUEST_LATENCY = Histogram(
    "http_request_duration_seconds",
    "HTTP request latency",
    ["method", "path"]
)

@app.middleware("http")
async def metrics_middleware(request: Request, call_next):
    start_time = time.time()
    method = request.method
    path = request.url.path
    
    try:
        response = await call_next(request)
    except Exception as e:
        status_code = 500
        raise e
    else:
        status_code = response.status_code
    finally:
        process_time = time.time() - start_time
        
        # 记录指标
        REQUEST_COUNT.labels(method, path, status_code).inc()
        REQUEST_LATENCY.labels(method, path).observe(process_time)
    
    return response

五、中间件最佳实践

1. 使用原则

  • 单一职责:一个中间件只做一件事
  • 性能优化:避免阻塞操作,优先使用异步
  • 错误处理:使用 try/finally 确保资源释放
  • 配置化:通过参数使中间件可配置

2. 调试技巧

# 在中间件中添加调试信息
@app.middleware("http")
async def debug_middleware(request: Request, call_next):
    # 打印请求信息
    print(f"Request: {request.method} {request.url}")
    print(f"Headers: {dict(request.headers)}")
    
    # 调用路由
    response = await call_next(request)
    
    # 打印响应信息
    print(f"Response Status: {response.status_code}")
    print(f"Response Headers: {dict(response.headers)}")
    
    return response

3. 中间件生命周期管理

from contextlib import asynccontextmanager
from fastapi import FastAPI

# 应用生命周期管理
@asynccontextmanager
async def lifespan(app: FastAPI):
    # 启动逻辑
    print("Starting application...")
    yield
    # 关闭逻辑
    print("Shutting down application...")

app = FastAPI(lifespan=lifespan)

# 中间件中访问应用状态
@app.middleware("http")
async def state_middleware(request: Request, call_next):
    # 访问应用状态
    if hasattr(request.app.state, "cache"):
        request.state.cache = request.app.state.cache
    
    response = await call_next(request)
    
    # 更新应用状态
    request.app.state.last_request = time.time()
    
    return response

六、中间件应用场景总结

场景 推荐中间件 实现要点
安全防护 认证中间件、CORS 请求头验证,跨域配置
性能优化 GZip压缩、缓存 响应压缩,缓存控制
可观测性 日志、监控 请求跟踪,指标收集
流量控制 限流中间件 请求计数,速率限制
数据管理 数据库连接池 连接复用,自动释放
协议处理 HTTPS重定向 安全协议强制升级

网站公告

今日签到

点亮在社区的每一天
去签到