文章目录
FastAPI 中间件的使用
一、中间件核心概念
1. 什么是中间件
在 FastAPI 中,中间件是处理 HTTP 请求和响应的拦截器,位于客户端和路由处理函数之间:
客户端 → 中间件 → 路由处理 → 中间件 → 客户端
2. 中间件核心作用
- 请求预处理(认证、日志、限流)
- 响应后处理(添加头信息、修改响应)
- 全局错误处理
- 性能监控
3. 中间件类型
类型 | 执行时机 | 典型应用 |
---|---|---|
HTTP 中间件 | 每次 HTTP 请求 | 日志、认证 |
ASGI 中间件 | 更底层处理 | 协议转换 |
数据库中间件 | 连接池管理 | 连接复用 |
错误处理中间件 | 异常发生时 | 统一错误格式 |
二、创建自定义中间件
1. 基础中间件模板
from fastapi import FastAPI, Request
import time
app = FastAPI()
@app.middleware("http")
async def custom_middleware(request: Request, call_next):
# 请求处理前逻辑
start_time = time.time()
print(f"Request started for {request.url}")
# 调用路由处理函数
response = await call_next(request)
# 响应处理后逻辑
process_time = time.time() - start_time
response.headers["X-Process-Time"] = str(process_time)
print(f"Request completed in {process_time:.2f}s")
return response
2. 完整中间件组件示例
from fastapi import FastAPI, Request, HTTPException
from starlette.types import ASGIApp, Scope, Receive, Send
class RateLimitMiddleware:
def __init__(self, app: ASGIApp, limit: int = 10):
self.app = app
self.limit = limit
self.request_counts = {}
async def __call__(self, scope: Scope, receive: Receive, send: Send):
# 仅处理 HTTP 请求
if scope["type"] != "http":
await self.app(scope, receive, send)
return
client_ip = scope["client"][0] # 获取客户端IP
# 限流逻辑
if client_ip in self.request_counts:
if self.request_counts[client_ip] >= self.limit:
raise HTTPException(status_code=429, detail="Rate limit exceeded")
self.request_counts[client_ip] += 1
else:
self.request_counts[client_ip] = 1
# 调用下一个中间件或路由
await self.app(scope, receive, send)
# 注册中间件
app = FastAPI()
app.add_middleware(RateLimitMiddleware, limit=5) # IP每秒5次请求
三、内置中间件使用指南
1. CORS 跨域中间件
from fastapi.middleware.cors import CORSMiddleware
app.add_middleware(
CORSMiddleware,
allow_origins=["https://example.com", "http://localhost:3000"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
2. HTTPS 重定向中间件
from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware
app.add_middleware(HTTPSRedirectMiddleware) # 强制所有HTTP请求重定向到HTTPS
3. GZip 压缩中间件
from fastapi.middleware.gzip import GZipMiddleware
app.add_middleware(GZipMiddleware, minimum_size=500) # 压缩大于500字节的响应
4. 信任主机中间件
from fastapi.middleware.trustedhost import TrustedHostMiddleware
app.add_middleware(TrustedHostMiddleware, allowed_hosts=["example.com", "*.example.com"])
四、高级中间件模式
1. 中间件执行顺序
中间件按添加顺序的反向执行请求处理,按添加顺序执行响应处理:
2. 数据库连接池中间件
from contextlib import asynccontextmanager
from fastapi import FastAPI
from databases import Database
database = Database("sqlite:///test.db")
@app.middleware("http")
async def db_session_middleware(request: Request, call_next):
# 请求开始:获取数据库连接
request.state.db = database
await database.connect()
try:
response = await call_next(request)
finally:
# 请求结束:释放连接
await database.disconnect()
return response
# 在路由中使用,在进入路由处理之前,已经连接了数据库
@app.get("/items")
async def read_items(request: Request):
db = request.state.db
query = "SELECT * FROM items"
return await db.fetch_all(query)
# 处理完成之后,返回到中间件,释放连接
3. 认证中间件
from fastapi import Request, HTTPException
API_KEYS = {"valid_key_123", "another_key_456"}
@app.middleware("http")
async def api_key_auth(request: Request, call_next):
# 从Header获取API Key
api_key = request.headers.get("X-API-KEY")
if not api_key or api_key not in API_KEYS:
raise HTTPException(status_code=401, detail="Invalid API Key")
return await call_next(request)
4. 性能监控中间件
import time
from fastapi import Request
from prometheus_client import Counter, Histogram
REQUEST_COUNT = Counter(
"http_requests_total",
"Total HTTP Requests",
["method", "path", "status"]
)
REQUEST_LATENCY = Histogram(
"http_request_duration_seconds",
"HTTP request latency",
["method", "path"]
)
@app.middleware("http")
async def metrics_middleware(request: Request, call_next):
start_time = time.time()
method = request.method
path = request.url.path
try:
response = await call_next(request)
except Exception as e:
status_code = 500
raise e
else:
status_code = response.status_code
finally:
process_time = time.time() - start_time
# 记录指标
REQUEST_COUNT.labels(method, path, status_code).inc()
REQUEST_LATENCY.labels(method, path).observe(process_time)
return response
五、中间件最佳实践
1. 使用原则
- 单一职责:一个中间件只做一件事
- 性能优化:避免阻塞操作,优先使用异步
- 错误处理:使用 try/finally 确保资源释放
- 配置化:通过参数使中间件可配置
2. 调试技巧
# 在中间件中添加调试信息
@app.middleware("http")
async def debug_middleware(request: Request, call_next):
# 打印请求信息
print(f"Request: {request.method} {request.url}")
print(f"Headers: {dict(request.headers)}")
# 调用路由
response = await call_next(request)
# 打印响应信息
print(f"Response Status: {response.status_code}")
print(f"Response Headers: {dict(response.headers)}")
return response
3. 中间件生命周期管理
from contextlib import asynccontextmanager
from fastapi import FastAPI
# 应用生命周期管理
@asynccontextmanager
async def lifespan(app: FastAPI):
# 启动逻辑
print("Starting application...")
yield
# 关闭逻辑
print("Shutting down application...")
app = FastAPI(lifespan=lifespan)
# 中间件中访问应用状态
@app.middleware("http")
async def state_middleware(request: Request, call_next):
# 访问应用状态
if hasattr(request.app.state, "cache"):
request.state.cache = request.app.state.cache
response = await call_next(request)
# 更新应用状态
request.app.state.last_request = time.time()
return response
六、中间件应用场景总结
场景 | 推荐中间件 | 实现要点 |
---|---|---|
安全防护 | 认证中间件、CORS | 请求头验证,跨域配置 |
性能优化 | GZip压缩、缓存 | 响应压缩,缓存控制 |
可观测性 | 日志、监控 | 请求跟踪,指标收集 |
流量控制 | 限流中间件 | 请求计数,速率限制 |
数据管理 | 数据库连接池 | 连接复用,自动释放 |
协议处理 | HTTPS重定向 | 安全协议强制升级 |