[Personal Notes] SQLAlchemy in Python: A Complete Hands-On Tutorial

Published: 2025-09-11

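Contents

  1. Introduction to SQLAlchemy
  2. Environment Setup and Installation
  3. Database Connections
  4. Defining Data Models
  5. Basic Data Operations
  6. Complex Queries
  7. Advanced Features
  8. Hands-On Project Example
  9. Performance Optimization and Best Practices
  10. Common Problems and Solutions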

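1. Introduction to SQLAlchemy

1.1 What Is SQLAlchemy

SQLAlchemy is a SQL toolkit and object-relational mapping (ORM) library for Python. It lets developers work with relational databases from Python code instead of writing SQL statements by hand.

Object-relational mapping (ORM) is a programming technique for converting data between the type system of an object-oriented language and that of a relational database. Put simply, database tables are mapped to Python classes, and rows in those tables are mapped to instances of those classes.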

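1.2 SQLAlchemy Architecture

SQLAlchemy is built as a stack of layers:

┌─────────────────────────────────────┐
│             ORM layer               │
│  (object-relational mapping,        │
│   highest-level abstraction)        │
├─────────────────────────────────────┤
│             Core layer              │
│  (SQL Expression Language,          │
│   lower-level abstraction)          │
├─────────────────────────────────────┤
│            Engine layer             │
│  (connections and connection pool)  │
├─────────────────────────────────────┤
│             DBAPI layer             │
│  (PEP 249 drivers such as pymysql,  │
│   psycopg2; not part of SQLAlchemy) │
└─────────────────────────────────────┘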
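To make the layering concrete, here is a minimal sketch (assuming SQLAlchemy 1.4 or later and an in-memory SQLite database; the `notes` table and `Note` class are illustrative, not from this tutorial) that reads the same rows once through Core, with a `Table` and a `Connection`, and once through the ORM, with a mapped class and a `Session`. Both paths go through the same Engine and the same DBAPI driver (`sqlite3`).

```python
from sqlalchemy import Column, Integer, String, Table, create_engine, insert, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Note(Base):
    """ORM layer: a Python class mapped to the notes table (illustrative)."""
    __tablename__ = "notes"
    id = Column(Integer, primary_key=True)
    title = Column(String(50), nullable=False)

# Core layer: the Table object the ORM class is mapped onto
notes: Table = Note.__table__

# Engine layer: owns the connection pool and talks to the DBAPI driver
engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

# Core: build SQL expressions and run them on a Connection; rows are tuples
with engine.begin() as conn:
    conn.execute(insert(notes), [{"title": "core"}, {"title": "orm"}])
    core_titles = conn.execute(select(notes.c.title).order_by(notes.c.id)).scalars().all()

# ORM: query the mapped class through a Session; results are Note instances
with Session(engine) as session:
    orm_notes = session.execute(select(Note).order_by(Note.id)).scalars().all()
    orm_titles = [note.title for note in orm_notes]

print(core_titles)  # ['core', 'orm']
print(orm_titles)   # ['core', 'orm']
```

The ORM is a layer over Core rather than a separate engine, which is why a project can mix both styles on one Engine.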

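1.3 Main Use Cases

SQLAlchemy is mainly used for:

  • Database access: a high-level API for working with the database without hand-writing raw SQL
  • ORM mapping: mapping Python classes to database tables to simplify working with data models
  • Complex queries: filtering, grouping, joins, subqueries, window functions, and more
  • Asynchronous access: asyncio support (SQLAlchemy 1.4+, implemented internally with greenlet) so queries do not block the event loop
  • Transaction control: database sessions and transactions are managed through Session
  • Multi-database support: PostgreSQL, MySQL, Oracle, SQLite, and other mainstream databases
  • Web framework integration: works with Flask (commonly via the Flask-SQLAlchemy extension), FastAPI, and other frameworks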

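2. Environment Setup and Installation

2.1 Installing SQLAlchemy

# Install the SQLAlchemy core library
pip install sqlalchemy

# Install a database driver (pick the ones you need)
pip install pymysql          # MySQL driver
pip install psycopg2-binary  # PostgreSQL driver
pip install cx_Oracle        # Oracle driver (SQLAlchemy 2.0 also supports its successor, oracledb)
# The SQLite driver (sqlite3) ships with the Python standard library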

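2.2 Database Driver Reference

Relational databases (supported by SQLAlchemy):

| Database | Driver package | Example connection URL |
|---|---|---|
| MySQL | pymysql | mysql+pymysql://username:password@localhost:3306/database_name |
| PostgreSQL | psycopg2 | postgresql://username:password@localhost:5432/database_name (psycopg2 is the default; postgresql+psycopg2:// is the explicit form) |
| SQLite | built in (sqlite3) | sqlite:///example.db |
| Oracle | cx_Oracle | oracle://username:password@localhost:1521/orcl |

NoSQL databases (listed for comparison only: SQLAlchemy does not support them, so connect through their own client libraries):

| Database | Client package | Example connection URL |
|---|---|---|
| MongoDB | pymongo | mongodb://username:password@localhost:27017/database_name |
| Redis | redis | redis://localhost:6379/0 |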
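Hand-written URLs like the ones above break when a password contains characters such as `@` or `/`. A minimal sketch, assuming SQLAlchemy 1.4 or later: `URL.create()` builds the URL from its parts and escapes them, and `make_url()` parses a URL string back into parts. The credentials are placeholders.

```python
from sqlalchemy.engine import URL, make_url

# Build a MySQL URL from parts; "@" and "/" in the password are escaped for us
url = URL.create(
    drivername="mysql+pymysql",
    username="root",
    password="p@ss/word",  # placeholder credential
    host="127.0.0.1",
    port=3306,
    database="test",
)

# render_as_string(hide_password=False) gives the real, percent-encoded string
# (plain str(url) masks the password in SQLAlchemy 2.0)
rendered = url.render_as_string(hide_password=False)
print(rendered)

# Parsing the rendered string recovers the original parts
parsed = make_url(rendered)
print(parsed.drivername, parsed.host, parsed.port, parsed.database)  # mysql+pymysql 127.0.0.1 3306 test

# create_engine(url) accepts the URL object directly and does not connect yet
```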

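2.3 Verifying the Installation

# Check the installed SQLAlchemy version
import sqlalchemy
print(f"SQLAlchemy version: {sqlalchemy.__version__}")

# Test a database connection
from sqlalchemy import create_engine, text

# Use an in-memory SQLite database for the test.
# create_engine() is lazy: nothing connects until a connection is requested.
engine = create_engine('sqlite:///:memory:', echo=True)
with engine.connect() as conn:
    conn.execute(text("SELECT 1"))
print("Database connection test succeeded!")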

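3. Database Connections

3.1 Creating a Database Engine

The engine is SQLAlchemy's central component: it manages the connection pool and the SQL dialect for one database. create_engine() does not open a connection by itself; connections are made on first use.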

from sqlalchemy import create_engine

# SQLite (file-based database)
engine = create_engine('sqlite:///example.db', echo=True)

# MySQL
dbHost = 'mysql+pymysql://root:password@127.0.0.1:3306/test'
engine = create_engine(
    dbHost,
    echo=True,              # log the SQL statements
    pool_size=10,           # connections kept in the pool
    max_overflow=20,        # extra connections allowed beyond pool_size
    pool_pre_ping=True,     # check that a connection is alive before using it
    pool_recycle=3600       # replace connections older than this many seconds
)

# PostgreSQL
engine = create_engine(
    'postgresql://username:password@localhost:5432/database',
    echo=False,
    pool_size=5,
    max_overflow=10
)

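3.2 Engine Parameters

| Parameter | Description | Default |
|---|---|---|
| echo | Log the SQL statements being executed | False |
| pool_size | Number of connections the pool keeps open | 5 |
| max_overflow | Extra connections allowed beyond pool_size | 10 |
| pool_timeout | Seconds to wait for a free connection before raising an error | 30 |
| pool_recycle | Connections older than this many seconds are replaced on checkout (-1 disables recycling) | -1 |
| pool_pre_ping | Test each connection with a lightweight ping on checkout | False |

pool_size, max_overflow, and pool_timeout apply to QueuePool, the default pool for most dialects.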

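3.3 Connection Pool Management

from sqlalchemy import create_engine
from sqlalchemy.pool import QueuePool

# Custom pool configuration (QueuePool is already the default for most dialects)
engine = create_engine(
    'mysql+pymysql://user:password@localhost/dbname',
    poolclass=QueuePool,
    pool_size=20,           # connections kept in the pool
    max_overflow=30,        # maximum overflow connections
    pool_timeout=60,        # seconds to wait for a connection
    pool_recycle=7200,      # recycle connections after 2 hours
    echo=True
)

# Inspect the pool state
print(f"Pool size: {engine.pool.size()}")
print(f"Checked-out connections: {engine.pool.checkedout()}")
# overflow() starts at -pool_size and rises as connections are opened,
# so it stays negative until the pool itself is full
print(f"Overflow: {engine.pool.overflow()}")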
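The pool parameters interact: at most pool_size + max_overflow connections can be checked out at once, and one more request waits pool_timeout seconds and then fails. A minimal sketch, assuming SQLAlchemy 1.4 or later; it uses in-memory SQLite with an explicit QueuePool only so the counters can be observed without a database server (each SQLite memory connection is its own empty database, which does not matter when only counting connections).

```python
from sqlalchemy import create_engine, exc
from sqlalchemy.pool import QueuePool

engine = create_engine(
    "sqlite://",
    poolclass=QueuePool,
    pool_size=2,       # connections kept in the pool
    max_overflow=1,    # one extra connection beyond pool_size
    pool_timeout=0.5,  # seconds to wait before giving up
)

conns = [engine.connect() for _ in range(3)]  # pool_size + max_overflow
print(engine.pool.checkedout(), engine.pool.overflow())  # 3 1

# A fourth checkout waits pool_timeout seconds, then raises TimeoutError
try:
    engine.connect()
    timed_out = False
except exc.TimeoutError:
    timed_out = True
print("timed out:", timed_out)  # timed out: True

for conn in conns:
    conn.close()  # back to the pool; the overflow connection is discarded
print(engine.pool.checkedout(), engine.pool.checkedin())  # 0 2
```

A practical consequence: with several worker processes, each process has its own pool, so the database must allow roughly processes × (pool_size + max_overflow) connections.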

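4. Defining Data Models

4.1 The Declarative Base Class

# sqlalchemy.ext.declarative.declarative_base is deprecated since 1.4;
# import it from sqlalchemy.orm instead
from sqlalchemy.orm import declarative_base
from sqlalchemy import Column, Integer, String, Float, DateTime, Boolean, Text
from datetime import datetime

# Create the declarative base class
Base = declarative_base()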

4.2 Defining a Basic Model

class User(Base):
    """User model"""
    __tablename__ = 'users'  # table name
    __table_args__ = {'comment': 'User information table'}  # table comment

    # Columns
    id = Column(Integer, primary_key=True, autoincrement=True, comment='User ID')
    username = Column(String(50), nullable=False, unique=True, comment='Username')
    email = Column(String(100), nullable=False, index=True, comment='Email')
    password_hash = Column(String(128), nullable=False, comment='Password hash')
    full_name = Column(String(100), comment='Full name')
    is_active = Column(Boolean, default=True, comment='Whether the account is active')
    created_at = Column(DateTime, default=datetime.now, comment='Creation time')
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now, comment='Last update time')

    def __repr__(self):
        return f"<User(id={self.id}, username='{self.username}')>"

    def __str__(self):
        return f"User: {self.username} ({self.email})"

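4.3 Common Column Parameters

| Parameter | Description | Example |
|---|---|---|
| primary_key | Column is (part of) the primary key | primary_key=True |
| nullable | Whether NULL is allowed | nullable=False |
| unique | Add a UNIQUE constraint | unique=True |
| index | Create an index on the column | index=True |
| default | Value used on INSERT when none is given | default=0 |
| onupdate | Value used on UPDATE when the column is not set explicitly | onupdate=datetime.now |
| autoincrement | Whether an integer primary key auto-increments | autoincrement=True |
| comment | Column comment, emitted in DDL where the database supports it | comment='User ID' |

Pass a callable such as datetime.now (no parentheses) to default and onupdate; datetime.now() would be evaluated once, when the class is defined.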
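To see default and onupdate fire, here is a minimal sketch assuming SQLAlchemy 1.4 or later and in-memory SQLite; the `Account` model is illustrative and mirrors the timestamp columns of `User`. The INSERT fills the defaults, and a later UPDATE of any column refreshes `updated_at`.

```python
import time
from datetime import datetime

from sqlalchemy import Boolean, Column, DateTime, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Account(Base):
    """Illustrative model (not the tutorial's User) showing default and onupdate."""
    __tablename__ = "accounts"
    id = Column(Integer, primary_key=True)
    name = Column(String(50), nullable=False)
    is_active = Column(Boolean, default=True)            # constant default
    created_at = Column(DateTime, default=datetime.now)  # callable: called on each INSERT
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    account = Account(name="alice")  # no is_active/created_at/updated_at given
    session.add(account)
    session.commit()
    is_active = account.is_active       # filled in by default=True
    first_updated = account.updated_at  # filled in by default=datetime.now

    time.sleep(0.01)                    # make the timestamps distinguishable
    account.name = "alice2"             # any UPDATE of the row fires onupdate
    session.commit()
    second_updated = account.updated_at

print(is_active, second_updated > first_updated)  # True True
```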

4.4 Complex Data Types

from sqlalchemy import JSON, DECIMAL, Enum
from sqlalchemy.dialects.mysql import LONGTEXT  # MySQL-specific type
import enum

class UserStatus(enum.Enum):
    """Status enum (reused below as the product status)"""
    ACTIVE = "active"
    INACTIVE = "inactive"
    SUSPENDED = "suspended"

class Product(Base):
    """Product model"""
    __tablename__ = 'products'

    id = Column(Integer, primary_key=True)
    name = Column(String(200), nullable=False, comment='Product name')
    description = Column(Text, comment='Product description')
    price = Column(DECIMAL(10, 2), nullable=False, comment='Price')
    stock = Column(Integer, default=0, comment='Quantity in stock')

    # JSON column for extra attributes
    attributes = Column(JSON, comment='Product attributes')

    # Enum column (by default the member names, e.g. 'ACTIVE', are stored)
    status = Column(Enum(UserStatus), default=UserStatus.ACTIVE, comment='Status')

    # Long text column: LONGTEXT only exists on MySQL; elsewhere use Text,
    # or Text().with_variant(LONGTEXT, 'mysql') to support both
    content = Column(LONGTEXT, comment='Detailed content')

    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)

4.5 Defining Relationships

from sqlalchemy import ForeignKey
from sqlalchemy.orm import relationship

class Category(Base):
    """Product category model"""
    __tablename__ = 'categories'

    id = Column(Integer, primary_key=True)
    name = Column(String(100), nullable=False, comment='Category name')
    description = Column(Text, comment='Category description')

    # One-to-many: a category has many products
    products = relationship("Product", back_populates="category")

# Note: this Product only shows the relationship columns. In real code, add
# category_id and category to the Product class from 4.4 instead; declaring a
# second class for the 'products' table on the same Base raises an error.
class Product(Base):
    """Product model (with relationship)"""
    __tablename__ = 'products'

    id = Column(Integer, primary_key=True)
    name = Column(String(200), nullable=False)
    category_id = Column(Integer, ForeignKey('categories.id'), comment='Category ID')

    # Many-to-one: many products belong to one category
    category = relationship("Category", back_populates="products")

# Many-to-many example
from sqlalchemy import Table

# Association table
user_role_association = Table(
    'user_roles',
    Base.metadata,
    Column('user_id', Integer, ForeignKey('users.id')),
    Column('role_id', Integer, ForeignKey('roles.id'))
)

class Role(Base):
    """Role model"""
    __tablename__ = 'roles'

    id = Column(Integer, primary_key=True)
    name = Column(String(50), nullable=False, unique=True)
    description = Column(String(200))

    # Many-to-many: roles and users
    users = relationship("User", secondary=user_role_association, back_populates="roles")

# Add the roles relationship to the existing User model
# (declarative classes accept new mapped attributes after they are defined)
User.roles = relationship("Role", secondary=user_role_association, back_populates="users")
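back_populates keeps both sides of a relationship in sync in Python, before anything is flushed. A minimal sketch, assuming SQLAlchemy 1.4 or later and in-memory SQLite; the `Team`, `Member`, and `Role` models are illustrative stand-ins for the Category/Product and User/Role pairs above. The association table here also gets a composite primary key, which the `user_roles` table above omits, so the same link cannot be stored twice.

```python
from sqlalchemy import Column, ForeignKey, Integer, String, Table, create_engine, select
from sqlalchemy.orm import Session, declarative_base, relationship

Base = declarative_base()

# Association table; the composite primary key blocks duplicate links
member_roles = Table(
    "member_roles",
    Base.metadata,
    Column("member_id", Integer, ForeignKey("members.id"), primary_key=True),
    Column("role_id", Integer, ForeignKey("roles.id"), primary_key=True),
)

class Team(Base):
    __tablename__ = "teams"
    id = Column(Integer, primary_key=True)
    name = Column(String(50))
    members = relationship("Member", back_populates="team")  # one-to-many

class Member(Base):
    __tablename__ = "members"
    id = Column(Integer, primary_key=True)
    name = Column(String(50))
    team_id = Column(Integer, ForeignKey("teams.id"))
    team = relationship("Team", back_populates="members")  # many-to-one
    roles = relationship("Role", secondary=member_roles, back_populates="members")

class Role(Base):
    __tablename__ = "roles"
    id = Column(Integer, primary_key=True)
    name = Column(String(50))
    members = relationship("Member", secondary=member_roles, back_populates="roles")

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    team, alice, admin = Team(name="backend"), Member(name="alice"), Role(name="admin")
    alice.team = team          # back_populates also appends alice to team.members
    alice.roles.append(admin)  # ...and adds alice to admin.members
    team_members = [m.name for m in team.members]
    role_members = [m.name for m in admin.members]

    session.add(team)  # save-update cascade also saves alice and admin
    session.commit()
    link_count = len(session.execute(select(member_roles)).all())

print(team_members, role_members, link_count)  # ['alice'] ['alice'] 1
```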

5. Basic Data Operations

5.1 Creating Tables

from sqlalchemy import create_engine

# Create the engine
engine = create_engine('sqlite:///example.db', echo=True)

# Create every table registered on Base.metadata (existing tables are skipped)
Base.metadata.create_all(engine)

# Drop all tables (use with care)
# Base.metadata.drop_all(engine)

# Create a single table
# User.__table__.create(engine, checkfirst=True)

# Drop a single table
# User.__table__.drop(engine, checkfirst=True)

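5.2 Session Management

from sqlalchemy.orm import sessionmaker, scoped_session

# Create a session factory
Session = sessionmaker(bind=engine)

# Option 1: manage the session manually
session = Session()
try:
    # Database operations (password_hash is NOT NULL in User, so it must be set)
    user = User(username='john', email='john@example.com', password_hash='hashed_pw')
    session.add(user)
    session.commit()
except Exception as e:
    session.rollback()
    print(f"Operation failed: {e}")
finally:
    session.close()

# Option 2: use the session as a context manager (recommended)
with Session() as session:
    user = User(username='jane', email='jane@example.com', password_hash='hashed_pw')
    session.add(user)
    session.commit()
    # The session is closed when the block exits; uncommitted work is rolled back

# Option 3: scoped_session (one session per thread)
SessionLocal = scoped_session(sessionmaker(bind=engine))

def get_session():
    """Return the current thread's session"""
    return SessionLocal()

def close_session():
    """Close and discard the current thread's session"""
    SessionLocal.remove()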
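Option 2 still needs an explicit commit(). sessionmaker also offers begin(), which opens a session and a transaction together: it commits when the block finishes and rolls back if an exception escapes. A minimal sketch, assuming SQLAlchemy 1.4 or later and in-memory SQLite; the `Item` model is illustrative.

```python
from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()

class Item(Base):
    __tablename__ = "items"
    id = Column(Integer, primary_key=True)
    name = Column(String(50), nullable=False)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
SessionFactory = sessionmaker(bind=engine)

# Success: commit happens automatically when the block exits
with SessionFactory.begin() as session:
    session.add(Item(name="kept"))

# Failure: the exception triggers a rollback, then propagates
try:
    with SessionFactory.begin() as session:
        session.add(Item(name="discarded"))
        raise RuntimeError("simulated failure")
except RuntimeError:
    pass

with SessionFactory() as session:
    names = session.execute(select(Item.name)).scalars().all()
print(names)  # ['kept']
```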

5.3 Inserting Data

# Insert a single row
with Session() as session:
    # Create a User object
    new_user = User(
        username='alice',
        email='alice@example.com',
        full_name='Alice Smith',
        password_hash='hashed_password_here'
    )

    # Add it to the session
    session.add(new_user)

    # Commit the transaction
    session.commit()

    # The primary key is available after the commit
    print(f"New user ID: {new_user.id}")

# Insert several objects
with Session() as session:
    users = [
        User(username='bob', email='bob@example.com', full_name='Bob Johnson', password_hash='hashed_pw'),
        User(username='charlie', email='charlie@example.com', full_name='Charlie Brown', password_hash='hashed_pw'),
        User(username='diana', email='diana@example.com', full_name='Diana Prince', password_hash='hashed_pw')
    ]

    # Add them all at once
    session.add_all(users)
    session.commit()

    print(f"Inserted {len(users)} users")

# bulk_insert_mappings: faster bulk insert from dicts. It is a legacy API in
# SQLAlchemy 2.0, where session.execute(insert(User), user_data) is preferred.
with Session() as session:
    user_data = [
        {'username': 'user1', 'email': 'user1@example.com', 'full_name': 'User One', 'password_hash': 'hashed_pw'},
        {'username': 'user2', 'email': 'user2@example.com', 'full_name': 'User Two', 'password_hash': 'hashed_pw'},
        {'username': 'user3', 'email': 'user3@example.com', 'full_name': 'User Three', 'password_hash': 'hashed_pw'}
    ]

    session.bulk_insert_mappings(User, user_data)
    session.commit()
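The insert(Model) form mentioned in the comment above can be sketched as follows, assuming SQLAlchemy 1.4 or later and in-memory SQLite (the `Member` model is illustrative). Passing a list of dicts as the second argument to `session.execute()` runs one executemany-style INSERT and creates no ORM objects, which avoids the per-object bookkeeping of add_all().

```python
from sqlalchemy import Column, Integer, String, create_engine, func, insert, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Member(Base):
    __tablename__ = "members"
    id = Column(Integer, primary_key=True)
    username = Column(String(50), nullable=False, unique=True)
    email = Column(String(100), nullable=False)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

rows = [{"username": f"user{i}", "email": f"user{i}@example.com"} for i in range(1, 1001)]

with Session(engine) as session:
    # A list of parameter dicts turns the INSERT into a bulk (executemany) call
    session.execute(insert(Member), rows)
    session.commit()
    total = session.execute(select(func.count(Member.id))).scalar_one()

print(total)  # 1000
```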

5.4 Querying Data

from sqlalchemy.exc import NoResultFound, MultipleResultsFound

with Session() as session:
    # All users
    all_users = session.query(User).all()
    print(f"Total users: {len(all_users)}")

    # First user
    first_user = session.query(User).first()
    print(f"First user: {first_user}")

    # Primary-key lookup (Query.get() is deprecated since 1.4; use Session.get())
    user_by_id = session.get(User, 1)
    if user_by_id:
        print(f"User with ID 1: {user_by_id.username}")

    # Filtered query
    active_users = session.query(User).filter(User.is_active == True).all()
    print(f"Active users: {len(active_users)}")

    # Exactly one row: one() raises if there are zero or several matches
    try:
        unique_user = session.query(User).filter(User.username == 'alice').one()
        print(f"Found user: {unique_user.email}")
    except (NoResultFound, MultipleResultsFound) as e:
        print(f"Query failed: {e}")

    # Zero or one row: one_or_none() returns None when nothing matches
    maybe_user = session.query(User).filter(User.username == 'nonexistent').one_or_none()
    if maybe_user:
        print(f"Found user: {maybe_user.username}")
    else:
        print("User does not exist")
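session.query() is the legacy query API in SQLAlchemy 2.0, where select() plus session.execute() is the recommended style. A minimal sketch of the equivalents, assuming SQLAlchemy 1.4 or later and in-memory SQLite (the `Member` model is illustrative).

```python
from sqlalchemy import Boolean, Column, Integer, String, create_engine, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Member(Base):
    __tablename__ = "members"
    id = Column(Integer, primary_key=True)
    username = Column(String(50), unique=True, nullable=False)
    is_active = Column(Boolean, default=True)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Member(username="alice"), Member(username="bob", is_active=False)])
    session.commit()

    # query(Member).get(1)  ->  session.get(Member, 1)
    first_name = session.get(Member, 1).username

    # query(Member).filter(...).all()  ->  select().where() + scalars().all()
    active = session.execute(select(Member).where(Member.is_active == True)).scalars().all()
    active_names = [m.username for m in active]

    # query(...).one_or_none()  ->  scalar_one_or_none()
    missing = session.execute(
        select(Member).where(Member.username == "nobody")
    ).scalar_one_or_none()

print(first_name, active_names, missing)  # alice ['alice'] None
```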

5.5 Updating Data

with Session() as session:
    # Option 1: load the object, then change its attributes
    user = session.query(User).filter(User.username == 'alice').first()
    if user:
        user.full_name = 'Alice Johnson'  # change an attribute
        user.updated_at = datetime.now()  # optional: onupdate=datetime.now sets it anyway
        session.commit()
        print(f"User {user.username} updated")

    # Option 2: bulk UPDATE without loading objects; returns the number of rows matched
    updated_count = session.query(User).filter(
        User.is_active == True
    ).update({
        User.updated_at: datetime.now()
    })
    session.commit()
    print(f"Bulk-updated {updated_count} users")

    # Option 3: conditional bulk UPDATE
    session.query(User).filter(
        User.created_at < datetime(2023, 1, 1)
    ).update({
        User.is_active: False,
        User.updated_at: datetime.now()
    })
    session.commit()
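The 2.0-style counterpart of Options 2 and 3 is an update() statement. Because the new value can be a SQL expression, an increment runs inside the database, so two concurrent requests cannot overwrite each other's result the way a read-modify-write in Python can. A minimal sketch, assuming SQLAlchemy 1.4 or later and in-memory SQLite; the `Member` model and `login_count` column are illustrative.

```python
from sqlalchemy import Column, Integer, String, create_engine, select, update
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Member(Base):
    __tablename__ = "members"
    id = Column(Integer, primary_key=True)
    username = Column(String(50), nullable=False)
    login_count = Column(Integer, default=0)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Member(username=n) for n in ("a", "b", "c")])
    session.commit()

with Session(engine) as session:
    # UPDATE members SET login_count = login_count + 1 WHERE username != 'c'
    result = session.execute(
        update(Member)
        .where(Member.username != "c")
        .values(login_count=Member.login_count + 1)
    )
    updated = result.rowcount  # rows matched by the WHERE clause
    session.commit()
    counts = dict(session.execute(select(Member.username, Member.login_count)).all())

print(updated, counts)  # 2 {'a': 1, 'b': 1, 'c': 0}
```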

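5.6 Deleting Data

with Session() as session:
    # Option 1: load the object, then delete it
    user_to_delete = session.query(User).filter(User.username == 'bob').first()
    if user_to_delete:
        username = user_to_delete.username  # read it before deleting
        session.delete(user_to_delete)
        session.commit()
        print(f"User {username} deleted")

    # Option 2: bulk DELETE without loading objects; returns the number of rows matched
    # (bulk deletes skip ORM relationship cascades; only database ON DELETE rules apply)
    deleted_count = session.query(User).filter(
        User.is_active == False
    ).delete()
    session.commit()
    print(f"Bulk-deleted {deleted_count} users")

    # Option 3: conditional bulk DELETE
    session.query(User).filter(
        User.created_at < datetime(2022, 1, 1)
    ).delete()
    session.commit()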

6. Complex Queries

6.1 Filtering

from sqlalchemy import and_, or_, not_, func
from sqlalchemy.sql import text

with Session() as session:
    # Simple condition
    users = session.query(User).filter(User.is_active == True).all()

    # Several conditions (AND); filter(a, b) is equivalent
    users = session.query(User).filter(
        and_(
            User.is_active == True,
            User.created_at > datetime(2023, 1, 1)
        )
    ).all()

    # OR
    users = session.query(User).filter(
        or_(
            User.username.like('%admin%'),
            User.email.like('%@admin.com')
        )
    ).all()

    # NOT
    users = session.query(User).filter(
        not_(User.is_active == False)
    ).all()

    # IN
    user_ids = [1, 2, 3, 4, 5]
    users = session.query(User).filter(User.id.in_(user_ids)).all()

    # NOT IN
    users = session.query(User).filter(~User.id.in_(user_ids)).all()

    # LIKE pattern match
    users = session.query(User).filter(User.username.like('%john%')).all()

    # Case-insensitive ILIKE (native on PostgreSQL; emulated with lower() elsewhere)
    users = session.query(User).filter(User.username.ilike('%JOHN%')).all()

    # BETWEEN (inclusive; an upper bound of datetime(2023, 12, 31) means midnight,
    # so later times on December 31 are excluded)
    users = session.query(User).filter(
        User.created_at.between(datetime(2023, 1, 1), datetime(2023, 12, 31))
    ).all()

    # IS NULL
    users = session.query(User).filter(User.full_name.is_(None)).all()

    # IS NOT NULL (is_not() is the 1.4+ name for isnot())
    users = session.query(User).filter(User.full_name.is_not(None)).all()

    # Raw SQL fragment with a bound parameter (never format user input into the string)
    users = session.query(User).filter(
        text("username LIKE :pattern")
    ).params(pattern='%admin%').all()

6.2 Sorting and Pagination

with Session() as session:
    # Ascending
    users = session.query(User).order_by(User.created_at).all()

    # Descending
    users = session.query(User).order_by(User.created_at.desc()).all()

    # Multiple sort keys
    users = session.query(User).order_by(
        User.is_active.desc(),  # active users first
        User.created_at.asc()   # then by creation time, oldest first
    ).all()

    # Limit the number of rows
    recent_users = session.query(User).order_by(
        User.created_at.desc()
    ).limit(10).all()

    # Skip rows
    users = session.query(User).offset(20).limit(10).all()

    # Pagination (pages start at 1); order_by() keeps page contents stable
    page = 1
    per_page = 10
    users = session.query(User).order_by(User.id).offset(
        (page - 1) * per_page
    ).limit(per_page).all()

    # Slicing is a more Pythonic spelling of OFFSET 20 LIMIT 10
    users = session.query(User)[20:30]  # rows 21-30
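A page of rows is usually returned together with the total count so a UI can show page links. A minimal sketch of such a helper, assuming SQLAlchemy 1.4 or later and in-memory SQLite; `Post`, `Page`, and `paginate()` are hypothetical names, not part of SQLAlchemy.

```python
from dataclasses import dataclass
from typing import List

from sqlalchemy import Column, Integer, String, create_engine, func, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Post(Base):
    __tablename__ = "posts"
    id = Column(Integer, primary_key=True)
    title = Column(String(100), nullable=False)

@dataclass
class Page:
    items: List[str]
    page: int
    per_page: int
    total: int

    @property
    def pages(self) -> int:
        # Ceiling division: 25 rows at 10 per page -> 3 pages
        return -(-self.total // self.per_page)

def paginate(session: Session, page: int, per_page: int) -> Page:
    """Return one page of post titles plus the total row count (hypothetical helper)."""
    page = max(page, 1)
    total = session.execute(select(func.count()).select_from(Post)).scalar_one()
    stmt = (
        select(Post.title)
        .order_by(Post.id)  # a stable ORDER BY keeps pages from overlapping
        .offset((page - 1) * per_page)
        .limit(per_page)
    )
    items = session.execute(stmt).scalars().all()
    return Page(items=list(items), page=page, per_page=per_page, total=total)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add_all([Post(title=f"post {i}") for i in range(1, 26)])
    session.commit()
    last = paginate(session, page=3, per_page=10)

print(last.items[0], len(last.items), last.pages)  # post 21 5 3
```

OFFSET still makes the database walk past the skipped rows, so very deep pages get slower; filtering on the last seen id (keyset pagination) is the usual alternative.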

6.3 Aggregate Queries

from sqlalchemy import func, distinct

with Session() as session:
    # COUNT
    total_users = session.query(func.count(User.id)).scalar()
    print(f"Total users: {total_users}")

    # COUNT(DISTINCT ...)
    unique_emails = session.query(func.count(distinct(User.email))).scalar()
    print(f"Distinct emails: {unique_emails}")

    # MAX and MIN (these return timestamps, not User objects)
    latest_user = session.query(func.max(User.created_at)).scalar()
    earliest_user = session.query(func.min(User.created_at)).scalar()

    # AVG (assuming User had an age column)
    # avg_age = session.query(func.avg(User.age)).scalar()

    # SUM
    # total_age = session.query(func.sum(User.age)).scalar()

    # GROUP BY
    user_counts_by_status = session.query(
        User.is_active,
        func.count(User.id).label('count')
    ).group_by(User.is_active).all()

    for is_active, count in user_counts_by_status:
        status = "Active" if is_active else "Inactive"
        print(f"{status} users: {count}")

    # HAVING filters groups after aggregation
    # (instr() exists on SQLite and MySQL; PostgreSQL uses strpos() instead)
    active_email_domains = session.query(
        func.substr(User.email, func.instr(User.email, '@') + 1).label('domain'),
        func.count(User.id).label('count')
    ).filter(
        User.is_active == True
    ).group_by(
        func.substr(User.email, func.instr(User.email, '@') + 1)
    ).having(
        func.count(User.id) > 1  # only domains with more than one user
    ).all()
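The HAVING example can be checked end to end with a small data set. A minimal sketch, assuming SQLAlchemy 1.4 or later and in-memory SQLite (which provides instr()); the `Member` model is illustrative. WHERE removes the inactive row before grouping, and HAVING then drops the domain left with a single user.

```python
from sqlalchemy import Boolean, Column, Integer, String, create_engine, func, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Member(Base):
    __tablename__ = "members"
    id = Column(Integer, primary_key=True)
    email = Column(String(100), nullable=False)
    is_active = Column(Boolean, default=True)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

# Domain part of the email: everything after '@'
domain = func.substr(Member.email, func.instr(Member.email, "@") + 1).label("domain")

stmt = (
    select(domain, func.count(Member.id).label("count"))
    .where(Member.is_active == True)    # WHERE filters rows before grouping
    .group_by(domain)
    .having(func.count(Member.id) > 1)  # HAVING filters the groups
    .order_by(domain)
)

with Session(engine) as session:
    session.add_all([
        Member(email="a@example.com"),
        Member(email="b@example.com"),
        Member(email="c@test.org"),
        Member(email="d@test.org", is_active=False),
    ])
    session.commit()
    rows = [tuple(r) for r in session.execute(stmt)]

print(rows)  # [('example.com', 2)]
```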

6.4 Joins

from sqlalchemy.orm import aliased

# These examples assume Product also has a user_id foreign key to users.id and
# User has a matching products relationship; section 4 defines neither.
with Session() as session:
    # INNER JOIN with an explicit ON clause
    results = session.query(User, Product).join(
        Product, User.id == Product.user_id
    ).all()

    # LEFT OUTER JOIN
    results = session.query(User, Product).outerjoin(
        Product, User.id == Product.user_id
    ).all()

    # Join along a relationship
    results = session.query(User).join(User.products).all()

    # Join plus a filter (ON clause inferred from the foreign key)
    results = session.query(User, Product).join(
        Product
    ).filter(
        Product.price > 100
    ).all()

    # Joining several tables
    results = session.query(User, Product, Category).join(
        Product
    ).join(
        Category
    ).filter(
        Category.name == 'Electronics'
    ).all()

    # Self-join: related rows within the same table
    # (assumes User has a manager_id column referencing users.id)
    manager_alias = aliased(User)
    results = session.query(User, manager_alias).join(
        manager_alias, User.manager_id == manager_alias.id
    ).all()

6.5 Subqueries

from sqlalchemy import exists, select

# Like 6.4, these examples assume Product has a user_id column.
with Session() as session:
    # Scalar subquery: the average price
    avg_price_subquery = session.query(
        func.avg(Product.price)
    ).scalar_subquery()

    expensive_products = session.query(Product).filter(
        Product.price > avg_price_subquery
    ).all()

    # Derived-table subquery: product count per user
    user_product_count = session.query(
        Product.user_id,
        func.count(Product.id).label('product_count')
    ).group_by(Product.user_id).subquery()

    # Reference the subquery's columns through .c
    productive_users = session.query(User).join(
        user_product_count, User.id == user_product_count.c.user_id
    ).filter(
        user_product_count.c.product_count > 5
    ).all()

    # EXISTS (automatically correlated to the outer users row)
    users_with_products = session.query(User).filter(
        exists().where(Product.user_id == User.id)
    ).all()

    # NOT EXISTS
    users_without_products = session.query(User).filter(
        ~exists().where(Product.user_id == User.id)
    ).all()

    # IN (subquery): pass a select() to in_(); passing a .subquery() object
    # triggers a coercion warning in SQLAlchemy 1.4+
    active_user_ids = select(User.id).where(User.is_active == True)

    products_by_active_users = session.query(Product).filter(
        Product.user_id.in_(active_user_ids)
    ).all()
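Because section 4 never defines Product.user_id, the subquery patterns above cannot run as written. A self-contained sketch of the EXISTS, NOT EXISTS, and IN (select) forms, assuming SQLAlchemy 1.4 or later and in-memory SQLite; the `Author` and `Book` models stand in for User and Product.

```python
from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, create_engine, exists, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Author(Base):
    __tablename__ = "authors"
    id = Column(Integer, primary_key=True)
    name = Column(String(50), nullable=False)
    is_active = Column(Boolean, default=True)

class Book(Base):
    __tablename__ = "books"
    id = Column(Integer, primary_key=True)
    title = Column(String(100), nullable=False)
    author_id = Column(Integer, ForeignKey("authors.id"))

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    ann, ben, cal = Author(name="ann"), Author(name="ben"), Author(name="cal", is_active=False)
    session.add_all([ann, ben, cal])
    session.flush()  # assigns primary keys
    session.add_all([Book(title="t1", author_id=ann.id), Book(title="t2", author_id=cal.id)])
    session.commit()

    # EXISTS, correlated to the enclosing authors row
    has_books = exists().where(Book.author_id == Author.id)
    with_books = session.execute(
        select(Author.name).where(has_books).order_by(Author.name)
    ).scalars().all()
    without_books = session.execute(select(Author.name).where(~has_books)).scalars().all()

    # IN (subquery): a select() passed straight to in_()
    active_ids = select(Author.id).where(Author.is_active == True)
    active_titles = session.execute(
        select(Book.title).where(Book.author_id.in_(active_ids))
    ).scalars().all()

print(with_books, without_books, active_titles)  # ['ann', 'cal'] ['ben'] ['t1']
```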

6.6 Window Functions (Advanced Queries)

# Window functions need database support: PostgreSQL, MySQL 8.0+, SQLite 3.25+
from sqlalchemy import func

with Session() as session:
    # ROW_NUMBER(): sequential numbering over the whole result
    results = session.query(
        User.id,
        User.username,
        func.row_number().over(
            order_by=User.created_at.desc()
        ).label('row_num')
    ).all()

    # RANK(): equal prices share a rank, and the following rank is skipped
    results = session.query(
        Product.id,
        Product.name,
        Product.price,
        func.rank().over(
            order_by=Product.price.desc()
        ).label('price_rank')
    ).all()

    # Partitioned window: numbering restarts within each category
    results = session.query(
        Product.id,
        Product.name,
        Product.category_id,
        Product.price,
        func.row_number().over(
            partition_by=Product.category_id,
            order_by=Product.price.desc()
        ).label('category_rank')
    ).all()

    # LAG/LEAD: read a value from the previous/next row (LAG shown here)
    results = session.query(
        User.id,
        User.username,
        User.created_at,
        func.lag(User.created_at, 1).over(
            order_by=User.created_at
        ).label('prev_created_at')
    ).all()
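A common use of the partitioned ROW_NUMBER() above is "top N per group". Window results cannot appear in WHERE, so the ranking goes into a subquery and an outer query filters on it. A minimal sketch, assuming SQLAlchemy 1.4 or later and in-memory SQLite 3.25+; the `Item` model is illustrative and stores prices as integers to keep SQLite simple.

```python
from sqlalchemy import Column, Integer, String, create_engine, func, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Item(Base):
    __tablename__ = "items"
    id = Column(Integer, primary_key=True)
    name = Column(String(50), nullable=False)
    category_id = Column(Integer, nullable=False)
    price = Column(Integer, nullable=False)  # whole units

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

# Rank items by price within each category, most expensive first
rank_in_category = func.row_number().over(
    partition_by=Item.category_id,
    order_by=Item.price.desc(),
).label("rank_in_category")

# Filter on the window result from an outer query over a subquery
ranked = select(Item.name, Item.category_id, rank_in_category).subquery()
top_per_category = (
    select(ranked.c.category_id, ranked.c.name)
    .where(ranked.c.rank_in_category == 1)
    .order_by(ranked.c.category_id)
)

with Session(engine) as session:
    session.add_all([
        Item(name="phone", category_id=1, price=800),
        Item(name="cable", category_id=1, price=10),
        Item(name="desk", category_id=2, price=300),
        Item(name="lamp", category_id=2, price=40),
    ])
    session.commit()
    top = [tuple(r) for r in session.execute(top_per_category)]

print(top)  # [(1, 'phone'), (2, 'desk')]
```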

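7. Advanced Features

7.1 Transaction Management

from sqlalchemy.exc import IntegrityError, SQLAlchemyError

# Basic transaction handling
with Session() as session:
    try:
        # A transaction begins automatically with the first operation
        user1 = User(username='user1', email='user1@example.com', password_hash='hashed_pw')
        user2 = User(username='user2', email='user2@example.com', password_hash='hashed_pw')

        session.add(user1)
        session.add(user2)

        # Commit the transaction
        session.commit()
        print("Transaction committed")

    except SQLAlchemyError as e:
        # Roll back the transaction
        session.rollback()
        print(f"Transaction rolled back: {e}")
        raise

# Nested transaction (SAVEPOINT)
# (with SQLite's pysqlite driver, SAVEPOINT needs the workaround described
# in the SQLAlchemy SQLite dialect documentation)
with Session() as session:
    try:
        user = User(username='main_user', email='main@example.com', password_hash='hashed_pw')
        session.add(user)

        # begin_nested() emits SAVEPOINT; as a context manager it releases the
        # savepoint on success and rolls back to it if an exception escapes
        try:
            with session.begin_nested():
                # An operation that fails: password_hash is missing, which violates
                # NOT NULL (the email format itself is not validated)
                risky_user = User(username='risky', email='invalid_email')
                session.add(risky_user)
                session.flush()  # emit the INSERT now, without committing
        except IntegrityError as e:
            print(f"Rolled back to savepoint: {e}")

        # Commit the outer transaction; main_user is kept
        session.commit()

    except SQLAlchemyError as e:
        session.rollback()
        print(f"Outer transaction rolled back: {e}")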

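7.2 Lazy Loading and Eager Loading

from sqlalchemy.orm import joinedload, selectinload, subqueryload

# These examples assume User has a products relationship (via Product.user_id)
# and Product has a reviews relationship; section 4 defines neither.

# Lazy loading (the default)
with Session() as session:
    user = session.query(User).first()
    # Related rows are queried only when the attribute is accessed
    products = user.products  # this line emits an extra SQL query

# Eager loading: joinedload (LEFT OUTER JOIN)
with Session() as session:
    users = session.query(User).options(
        joinedload(User.products)
    ).all()
    # Users and their products come back from a single query
    for user in users:
        print(f"User {user.username} has {len(user.products)} products")

# Eager loading: selectinload (IN query)
with Session() as session:
    users = session.query(User).options(
        selectinload(User.products)
    ).all()
    # Two queries: users first, then products WHERE user_id IN (...)

# Eager loading: subqueryload (subquery)
with Session() as session:
    users = session.query(User).options(
        subqueryload(User.products)
    ).all()
    # The second query re-runs the original query as a subquery to fetch products

# Multi-level eager loading
with Session() as session:
    users = session.query(User).options(
        joinedload(User.products).joinedload(Product.category)
    ).all()
    # One query loads users, products, and categories

# Chained selectinload for nested collections, combined with a filter
with Session() as session:
    users = session.query(User).options(
        selectinload(User.products).selectinload(Product.reviews)
    ).filter(User.is_active == True).all()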
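The cost of lazy loading in a loop (the "N+1 query" problem) can be measured by counting statements. A minimal sketch, assuming SQLAlchemy 1.4 or later and in-memory SQLite; the `Author`/`Book` models are illustrative, and a `before_cursor_execute` event listener records each SQL statement sent to the driver.

```python
from sqlalchemy import Column, ForeignKey, Integer, String, create_engine, event, select
from sqlalchemy.orm import Session, declarative_base, relationship, selectinload

Base = declarative_base()

class Author(Base):
    __tablename__ = "authors"
    id = Column(Integer, primary_key=True)
    name = Column(String(50), nullable=False)
    books = relationship("Book", back_populates="author")  # lazy loading by default

class Book(Base):
    __tablename__ = "books"
    id = Column(Integer, primary_key=True)
    title = Column(String(100), nullable=False)
    author_id = Column(Integer, ForeignKey("authors.id"), nullable=False)
    author = relationship("Author", back_populates="books")

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

statements = []

@event.listens_for(engine, "before_cursor_execute")
def record_sql(conn, cursor, statement, parameters, context, executemany):
    """Record every SQL statement sent to the driver."""
    statements.append(statement)

with Session(engine) as session:
    session.add_all([Author(name=f"a{i}", books=[Book(title=f"b{i}")]) for i in range(3)])
    session.commit()

def count_queries(options):
    """Load all authors, touch each .books collection, return the statement count."""
    statements.clear()
    with Session(engine) as session:
        authors = session.execute(select(Author).options(*options)).scalars().all()
        for author in authors:
            len(author.books)  # triggers a lazy load unless already loaded
    return len(statements)

lazy_count = count_queries([])                              # 1 + N
eager_count = count_queries([selectinload(Author.books)])  # 2
print(lazy_count, eager_count)  # 4 2
```

With N authors the lazy version issues N + 1 queries, while selectinload stays at two however large N grows.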

7.3 Caching

# First-level "cache": the Session's identity map
with Session() as session:
    # First lookup: loaded from the database
    user1 = session.get(User, 1)

    # Second lookup by primary key: Session.get() finds the object in the
    # identity map and emits no SQL
    user2 = session.get(User, 1)

    # Both names refer to the same instance
    assert user1 is user2
    print("Identity map hit")
    # Note: session.query(...).filter(...) still runs SQL every time; the identity
    # map only guarantees one object per row within a session

# "Second-level" cache: SQLAlchemy has no built-in one, so this is an
# application-level cache
import functools

# A simple in-memory cache
class SimpleCache:
    def __init__(self):
        self._cache = {}

    def get(self, key):
        return self._cache.get(key)

    def set(self, key, value):
        self._cache[key] = value

    def delete(self, key):
        self._cache.pop(key, None)

cache = SimpleCache()

# Query-cache decorator; entries never expire on their own, so delete the key
# after writes that make it stale
def cached_query(cache_key):
    def decorator(fn):  # named fn so it does not shadow sqlalchemy.func
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # Try the cache first
            result = cache.get(cache_key)
            if result is not None:
                print(f"Cache hit: {cache_key}")
                return result

            # Run the query
            result = fn(*args, **kwargs)

            # Store the result
            cache.set(cache_key, result)
            print(f"Cached: {cache_key}")
            return result
        return wrapper
    return decorator

# Caveat: the cached User objects belong to the session that loaded them; after it
# closes they are detached, and lazy-loading their relationships raises an error
@cached_query('all_active_users')
def get_active_users(session):
    return session.query(User).filter(User.is_active == True).all()
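The decorator above uses one fixed key, so calls with different arguments would share a cache entry, and entries live forever. A minimal sketch of a variant that keys on the (positional) arguments and expires entries after a time-to-live; `ttl_cache` is a hypothetical helper, not a SQLAlchemy API, and the clock is injectable only so the expiry can be demonstrated without sleeping. It is meant to cache plain values (tuples, dicts), not ORM objects.

```python
import functools
import time

def ttl_cache(seconds: float, clock=time.monotonic):
    """Cache results per positional-argument tuple for `seconds` (hypothetical helper)."""
    def decorator(fn):
        store = {}  # args -> (expires_at, value)

        @functools.wraps(fn)
        def wrapper(*args):
            now = clock()
            hit = store.get(args)
            if hit is not None and hit[0] > now:
                return hit[1]  # fresh entry
            value = fn(*args)
            store[args] = (now + seconds, value)
            return value

        wrapper.cache_clear = store.clear  # call after writes that make entries stale
        return wrapper
    return decorator

calls = []
fake_now = [0.0]

@ttl_cache(seconds=60, clock=lambda: fake_now[0])
def active_usernames(min_id: int) -> tuple:
    # Stand-in for a database query; returns plain data, not ORM objects
    calls.append(min_id)
    return ("alice", "bob")

active_usernames(1)
active_usernames(1)   # served from the cache
fake_now[0] = 61.0
active_usernames(1)   # entry expired, so the "query" runs again
print(len(calls))  # 2
```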

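7.4 Database Migrations

# Database migrations with Alembic
# Install it first: pip install alembic

# Initialize the migration environment
# alembic init migrations

# Generate a migration script (autogenerate compares the models with the database;
# set target_metadata = Base.metadata in migrations/env.py first)
# alembic revision --autogenerate -m "create users table"

# Apply migrations
# alembic upgrade head

# Roll back one migration
# alembic downgrade -1

# Running migrations from Python
from alembic import command
from alembic.config import Config

def run_migrations():
    """Upgrade the database to the latest revision"""
    alembic_cfg = Config("alembic.ini")
    command.upgrade(alembic_cfg, "head")

def create_migration(message):
    """Generate a new migration file"""
    alembic_cfg = Config("alembic.ini")
    command.revision(alembic_cfg, autogenerate=True, message=message)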

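7.5 Async Support

# asyncio support arrived in SQLAlchemy 1.4; async_sessionmaker requires 2.0
# (on 1.4, use sessionmaker(async_engine, class_=AsyncSession) instead).
# The URL below needs the asyncpg driver: pip install asyncpg
from sqlalchemy import select
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.ext.asyncio import async_sessionmaker
import asyncio

# Create the async engine
async_engine = create_async_engine(
    "postgresql+asyncpg://user:password@localhost/dbname",
    echo=True
)

# Async session factory; expire_on_commit=False keeps attributes readable after
# commit, because expired attributes cannot be reloaded implicitly under asyncio
AsyncSessionLocal = async_sessionmaker(
    async_engine,
    class_=AsyncSession,
    expire_on_commit=False
)

# Async data access
async def create_user_async(username: str, email: str):
    """Create a user asynchronously"""
    async with AsyncSessionLocal() as session:
        user = User(username=username, email=email, password_hash='hashed_pw')
        session.add(user)
        await session.commit()
        return user

async def get_users_async():
    """Fetch active users asynchronously"""
    async with AsyncSessionLocal() as session:
        result = await session.execute(
            select(User).where(User.is_active == True)
        )
        return result.scalars().all()

# Entry point
async def main():
    # Create tables (create_all is synchronous, so run it through run_sync)
    async with async_engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)

    # Create a user
    user = await create_user_async("async_user", "async@example.com")
    print(f"Created user: {user.username}")

    # Query users
    users = await get_users_async()
    print(f"Found {len(users)} users")

# Run the async code
# asyncio.run(main())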

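8. Hands-On Project Example

8.1 Project Structure

blog_project/
├── main.py             # application entry point
├── database.py         # database configuration
├── models.py           # data models
├── schemas.py          # validation schemas
├── crud.py             # database operations
├── api.py              # API endpoints
└── requirements.txt    # dependencies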

8.2 Database Configuration (database.py)

from sqlalchemy import create_engine
from sqlalchemy.orm import declarative_base, sessionmaker
import os

# Database URL (overridable through the DATABASE_URL environment variable)
DATABASE_URL = os.getenv(
    "DATABASE_URL",
    "sqlite:///./blog.db"
)

# Create the engine
engine = create_engine(
    DATABASE_URL,
    echo=True,  # print SQL during development
    pool_pre_ping=True,
    pool_recycle=3600
)

# Session factory
SessionLocal = sessionmaker(
    autocommit=False,
    autoflush=False,
    bind=engine
)

# Declarative base
Base = declarative_base()

# Dependency injection (e.g. FastAPI's Depends): yield a session and always close it
def get_db():
    """Yield a database session"""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()

# Create all tables
def create_tables():
    """Create the database tables"""
    Base.metadata.create_all(bind=engine)

# Drop all tables
def drop_tables():
    """Drop the database tables"""
    Base.metadata.drop_all(bind=engine)
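get_db() only closes the session; committing is left to the caller. For scripts or background jobs outside a web framework, the same factory can be wrapped in a context manager that also commits on success and rolls back on error. A minimal sketch, assuming SQLAlchemy 1.4 or later and in-memory SQLite; `session_scope` and the `Note` model are hypothetical names.

```python
from contextlib import contextmanager

from sqlalchemy import Column, Integer, String, create_engine, func, select
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()

class Note(Base):
    __tablename__ = "notes"
    id = Column(Integer, primary_key=True)
    body = Column(String(100), nullable=False)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
SessionLocal = sessionmaker(autoflush=False, bind=engine)

@contextmanager
def session_scope():
    """Transactional scope for scripts and jobs (hypothetical helper)."""
    db = SessionLocal()
    try:
        yield db
        db.commit()    # reached only if the block raised nothing
    except Exception:
        db.rollback()  # undo everything the block did
        raise
    finally:
        db.close()

with session_scope() as db:
    db.add(Note(body="saved"))

try:
    with session_scope() as db:
        db.add(Note(body="lost"))
        raise ValueError("boom")
except ValueError:
    pass

with session_scope() as db:
    total = db.execute(select(func.count(Note.id))).scalar_one()
print(total)  # 1
```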

8.3 Data Models (models.py)

from sqlalchemy import Column, Integer, String, Text, DateTime, Boolean, ForeignKey, Table
from sqlalchemy.orm import relationship
from datetime import datetime
from database import Base

# Association table for the many-to-many article/tag relationship
article_tags = Table(
    'article_tags',
    Base.metadata,
    Column('article_id', Integer, ForeignKey('articles.id')),
    Column('tag_id', Integer, ForeignKey('tags.id'))
)

class User(Base):
    """User model"""
    __tablename__ = 'users'
    
    id = Column(Integer, primary_key=True, index=True)
    username = Column(String(50), unique=True, index=True, nullable=False)
    email = Column(String(100), unique=True, index=True, nullable=False)
    hashed_password = Column(String(100), nullable=False)
    full_name = Column(String(100))
    is_active = Column(Boolean, default=True)
    is_superuser = Column(Boolean, default=False)
    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
    
    # Relationships: a user can have many articles and many comments
    articles = relationship("Article", back_populates="author", cascade="all, delete-orphan")
    comments = relationship("Comment", back_populates="author", cascade="all, delete-orphan")
    
    def __repr__(self):
        return f"<User(id={self.id}, username='{self.username}')>"

class Category(Base):
    """Category model"""
    __tablename__ = 'categories'
    
    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(50), unique=True, nullable=False)
    description = Column(Text)
    created_at = Column(DateTime, default=datetime.now)
    
    # Relationship: a category can have many articles
    articles = relationship("Article", back_populates="category")
    
    def __repr__(self):
        return f"<Category(id={self.id}, name='{self.name}')>"

class Tag(Base):
    """Tag model"""
    __tablename__ = 'tags'
    
    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(30), unique=True, nullable=False)
    created_at = Column(DateTime, default=datetime.now)
    
    # Many-to-many: tags and articles
    articles = relationship("Article", secondary=article_tags, back_populates="tags")
    
    def __repr__(self):
        return f"<Tag(id={self.id}, name='{self.name}')>"

class Article(Base):
    """Article model"""
    __tablename__ = 'articles'
    
    id = Column(Integer, primary_key=True, index=True)
    title = Column(String(200), nullable=False, index=True)
    content = Column(Text, nullable=False)
    summary = Column(String(500))
    is_published = Column(Boolean, default=False)
    view_count = Column(Integer, default=0)
    created_at = Column(DateTime, default=datetime.now, index=True)
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
    published_at = Column(DateTime)
    
    # Foreign keys
    author_id = Column(Integer, ForeignKey('users.id'), nullable=False)
    category_id = Column(Integer, ForeignKey('categories.id'))
    
    # Relationships
    author = relationship("User", back_populates="articles")
    category = relationship("Category", back_populates="articles")
    tags = relationship("Tag", secondary=article_tags, back_populates="articles")
    comments = relationship("Comment", back_populates="article", cascade="all, delete-orphan")
    
    def __repr__(self):
        return f"<Article(id={self.id}, title='{self.title[:30]}...')>"

class Comment(Base):
    """Comment model"""
    __tablename__ = 'comments'
    
    id = Column(Integer, primary_key=True, index=True)
    content = Column(Text, nullable=False)
    is_approved = Column(Boolean, default=False)
    created_at = Column(DateTime, default=datetime.now)
    
    # Foreign keys
    article_id = Column(Integer, ForeignKey('articles.id'), nullable=False)
    author_id = Column(Integer, ForeignKey('users.id'), nullable=False)
    parent_id = Column(Integer, ForeignKey('comments.id'))  # parent comment when this is a reply
    
    # Relationships
    article = relationship("Article", back_populates="comments")
    author = relationship("User", back_populates="comments")
    parent = relationship("Comment", remote_side=[id])  # self-referential; remote_side=[id] makes this the many-to-one (parent) side
    
    def __repr__(self):
        return f"<Comment(id={self.id}, content='{self.content[:30]}...')>"
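The Comment model only defines the reply-to-parent direction. To list a comment's replies as well, the self-referential relationship can be paired with a one-to-many collection. A minimal sketch, assuming SQLAlchemy 1.4 or later and in-memory SQLite; the `replies` attribute is an addition for illustration, not part of the model above, and the model is trimmed to the relevant columns.

```python
from sqlalchemy import Column, ForeignKey, Integer, Text, create_engine, select
from sqlalchemy.orm import Session, declarative_base, relationship

Base = declarative_base()

class Comment(Base):
    __tablename__ = "comments"
    id = Column(Integer, primary_key=True)
    content = Column(Text, nullable=False)
    parent_id = Column(Integer, ForeignKey("comments.id"))

    # remote_side=[id] makes "parent" many-to-one (reply -> parent);
    # "replies" is the matching one-to-many collection (added for illustration)
    parent = relationship("Comment", remote_side=[id], back_populates="replies")
    replies = relationship("Comment", back_populates="parent")

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    root = Comment(content="Great article!")
    reply = Comment(content="Agreed.", parent=root)  # also appended to root.replies
    session.add(root)  # the reply is saved through the replies collection
    session.commit()

    top_level = session.execute(
        select(Comment.content).where(Comment.parent_id.is_(None))
    ).scalars().all()
    reply_parent = session.get(Comment, reply.id).parent.content
    root_replies = [c.content for c in root.replies]

print(top_level, reply_parent, root_replies)  # ['Great article!'] Great article! ['Agreed.']
```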

8.4 Data Access Layer (crud.py)

from sqlalchemy.orm import Session
from sqlalchemy import and_, or_, func, desc
from models import User, Article, Category, Tag, Comment
from typing import List, Optional
from datetime import datetime

class UserCRUD:
    """User data operations"""
    
    @staticmethod
    def create_user(db: Session, username: str, email: str, password: str, full_name: str = None) -> User:
        """Create a user"""
        user = User(
            username=username,
            email=email,
            hashed_password=password,  # placeholder: store a real password hash in production, never the plain password
            full_name=full_name
        )
        db.add(user)
        db.commit()
        db.refresh(user)
        return user
    
    @staticmethod
    def get_user_by_id(db: Session, user_id: int) -> Optional[User]:
        """Get a user by ID"""
        return db.query(User).filter(User.id == user_id).first()
    
    @staticmethod
    def get_user_by_username(db: Session, username: str) -> Optional[User]:
        """Get a user by username"""
        return db.query(User).filter(User.username == username).first()
    
    @staticmethod
    def get_users(db: Session, skip: int = 0, limit: int = 100) -> List[User]:
        """List users"""
        return db.query(User).offset(skip).limit(limit).all()
    
    @staticmethod
    def update_user(db: Session, user_id: int, **kwargs) -> Optional[User]:
        """Update a user's fields"""
        user = db.query(User).filter(User.id == user_id).first()
        if user:
            for key, value in kwargs.items():
                if hasattr(user, key):
                    setattr(user, key, value)
            user.updated_at = datetime.now()
            db.commit()
            db.refresh(user)
        return user
    
    @staticmethod
    def delete_user(db: Session, user_id: int) -> bool:
        """Delete a user"""
        user = db.query(User).filter(User.id == user_id).first()
        if user:
            db.delete(user)
            db.commit()
            return True
        return False
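create_user() stores the password argument directly in hashed_password. A minimal sketch of hashing with only the standard library, assuming PBKDF2-HMAC-SHA256 is acceptable for your project (bcrypt or argon2 via third-party packages are common alternatives); `hash_password`, `verify_password`, the `salt$hash` storage format, and the iteration count are illustrative choices, not part of the original code.

```python
import hashlib
import hmac
import os

ITERATIONS = 600_000  # PBKDF2 work factor; raise it as hardware gets faster

def hash_password(password: str) -> str:
    """Return 'salt$hash' (both hex) using PBKDF2-HMAC-SHA256 and a random salt."""
    salt = os.urandom(16)
    digest = hashlib.pbkdf2_hmac("sha256", password.encode(), salt, ITERATIONS)
    return f"{salt.hex()}${digest.hex()}"

def verify_password(password: str, stored: str) -> bool:
    """Recompute the hash with the stored salt and compare in constant time."""
    salt_hex, digest_hex = stored.split("$")
    digest = hashlib.pbkdf2_hmac("sha256", password.encode(), bytes.fromhex(salt_hex), ITERATIONS)
    return hmac.compare_digest(digest.hex(), digest_hex)

stored = hash_password("s3cret")
print(verify_password("s3cret", stored), verify_password("wrong", stored))  # True False
```

In create_user() this would mean passing hashed_password=hash_password(password), and checking a login with verify_password() against the stored value.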

class ArticleCRUD:
    """Article data operations"""
    
    @staticmethod
    def create_article(db: Session, title: str, content: str, author_id: int, 
                      category_id: int = None, tag_names: List[str] = None) -> Article:
        """Create an article"""
        article = Article(
            title=title,
            content=content,
            author_id=author_id,
            category_id=category_id,
            summary=content[:200] + "..." if len(content) > 200 else content
        )
        
        # Attach tags, creating any that do not exist yet
        if tag_names:
            for tag_name in tag_names:
                tag = db.query(Tag).filter(Tag.name == tag_name).first()
                if not tag:
                    tag = Tag(name=tag_name)
                    db.add(tag)
                article.tags.append(tag)
        
        db.add(article)
        db.commit()
        db.refresh(article)
        return article
    
    @staticmethod
    def get_article_by_id(db: Session, article_id: int) -> Optional[Article]:
        """Get an article by ID"""
        return db.query(Article).filter(Article.id == article_id).first()
    
    @staticmethod
    def get_articles(db: Session, skip: int = 0, limit: int = 20, 
                    published_only: bool = True) -> List[Article]:
        """Get a list of articles, newest first"""
        query = db.query(Article)
        if published_only:
            query = query.filter(Article.is_published == True)
        return query.order_by(desc(Article.created_at)).offset(skip).limit(limit).all()
    
    @staticmethod
    def get_articles_by_category(db: Session, category_id: int, 
                               skip: int = 0, limit: int = 20) -> List[Article]:
        """Get published articles in a category"""
        return db.query(Article).filter(
            and_(Article.category_id == category_id, Article.is_published == True)
        ).order_by(desc(Article.created_at)).offset(skip).limit(limit).all()
    
    @staticmethod
    def get_articles_by_tag(db: Session, tag_name: str, 
                          skip: int = 0, limit: int = 20) -> List[Article]:
        """Get published articles with a given tag"""
        return db.query(Article).join(Article.tags).filter(
            and_(Tag.name == tag_name, Article.is_published == True)
        ).order_by(desc(Article.created_at)).offset(skip).limit(limit).all()
    
    @staticmethod
    def search_articles(db: Session, keyword: str, skip: int = 0, limit: int = 20) -> List[Article]:
        """Search published articles whose title or content contains the keyword (LIKE '%keyword%')"""
        return db.query(Article).filter(
            and_(
                or_(
                    Article.title.contains(keyword),
                    Article.content.contains(keyword)
                ),
                Article.is_published == True
            )
        ).order_by(desc(Article.created_at)).offset(skip).limit(limit).all()
    
    @staticmethod
    def publish_article(db: Session, article_id: int) -> Optional[Article]:
        """Publish an article"""
        article = db.query(Article).filter(Article.id == article_id).first()
        if article:
            article.is_published = True
            article.published_at = datetime.now()
            db.commit()
            db.refresh(article)
        return article
    
    @staticmethod
    def increment_view_count(db: Session, article_id: int) -> Optional[Article]:
        """Increment an article's view count"""
        article = db.query(Article).filter(Article.id == article_id).first()
        if article:
            article.view_count += 1  # read-modify-write in Python: concurrent requests can lose increments
            db.commit()
            db.refresh(article)
        return article

class StatisticsCRUD:
    """Statistics queries"""
    
    @staticmethod
    def get_article_stats(db: Session) -> dict:
        """Get article statistics"""
        total_articles = db.query(func.count(Article.id)).scalar()
        published_articles = db.query(func.count(Article.id)).filter(
            Article.is_published == True
        ).scalar()
        total_views = db.query(func.sum(Article.view_count)).scalar() or 0
        
        return {
            "total_articles": total_articles,
            "published_articles": published_articles,
            "draft_articles": total_articles - published_articles,
            "total_views": total_views
        }
    
    @staticmethod
    def get_popular_articles(db: Session, limit: int = 10) -> List[Article]:
        """Get the most-viewed published articles"""
        return db.query(Article).filter(
            Article.is_published == True
        ).order_by(desc(Article.view_count)).limit(limit).all()
    
    @staticmethod
    def get_category_stats(db: Session) -> List[dict]:
        """Count articles per category (categories without articles count as 0)"""
        results = db.query(
            Category.name,
            func.count(Article.id).label('article_count')
        ).outerjoin(Article).group_by(Category.id, Category.name).all()
        
        return [
            {"category": name, "count": count}
            for name, count in results
        ]
    
    @staticmethod
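    def get_monthly_article_stats(db: Session, year: int) -> List[dict]:
        """Count published articles per month of a given year"""
        # Use the extract() construct rather than func.extract(): extract() renders
        # EXTRACT(month FROM ...) on PostgreSQL/MySQL and strftime() on SQLite
        from sqlalchemy import extract

        month_expr = extract('month', Article.created_at)
        results = db.query(
            month_expr.label('month'),
            func.count(Article.id).label('count')
        ).filter(
            and_(
                extract('year', Article.created_at) == year,
                Article.is_published == True
            )
        ).group_by(month_expr).order_by(month_expr).all()
        
        return [
            {"month": int(month), "count": count}
            for month, count in results
        ]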

8.5 API Layer (api.py)

from fastapi import FastAPI, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from database import get_db, create_tables
from crud import UserCRUD, ArticleCRUD, StatisticsCRUD
from typing import List, Optional
import uvicorn

# Create the FastAPI application
app = FastAPI(
    title="Blog System API",
    description="A blog system built on SQLAlchemy",
    version="1.0.0"
)

# Create tables at startup (recent FastAPI versions deprecate on_event in favor of lifespan handlers)
@app.on_event("startup")
def startup_event():
    create_tables()

# User endpoints
@app.post("/users/", summary="Create a user")
def create_user(
    username: str,
    email: str,
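    password: str,  # demo only: query parameters end up in URLs and logs; use a request body in production
    full_name: Optional[str] = None,
    db: Session = Depends(get_db)
):
    """Create a new user"""
    # Reject usernames that are already taken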
    existing_user = UserCRUD.get_user_by_username(db, username)
    if existing_user:
        raise HTTPException(status_code=400, detail="Username already exists")
    
    user = UserCRUD.create_user(db, username, email, password, full_name)
    return {"message": "User created", "user_id": user.id}

@app.get("/users/{user_id}", summary="Get a user")
def get_user(user_id: int, db: Session = Depends(get_db)):
    """Get user details by ID"""
    user = UserCRUD.get_user_by_id(db, user_id)
    if not user:
        raise HTTPException(status_code=404, detail="User not found")
    
    return {
        "id": user.id,
        "username": user.username,
        "email": user.email,
        "full_name": user.full_name,
        "is_active": user.is_active,
        "created_at": user.created_at
    }

@app.get("/users/", summary="List users")
def get_users(
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    db: Session = Depends(get_db)
):
    """Get a paginated list of users"""
    users = UserCRUD.get_users(db, skip, limit)
    return {
        "users": [
            {
                "id": user.id,
                "username": user.username,
                "email": user.email,
                "full_name": user.full_name,
                "is_active": user.is_active
            }
            for user in users
        ],
        "total": len(users)  # users on this page, not the total number of users
    }

# Article endpoints
@app.post("/articles/", summary="Create an article")
def create_article(
    title: str,
    content: str,
    author_id: int,
    category_id: Optional[int] = None,
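    tag_names: Optional[List[str]] = Query(None),  # list query parameters need Query(); otherwise FastAPI reads them from the request body
    db: Session = Depends(get_db)
):
    """Create a new article"""
    # Make sure the author exists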
    author = UserCRUD.get_user_by_id(db, author_id)
    if not author:
        raise HTTPException(status_code=404, detail="Author not found")
    
    article = ArticleCRUD.create_article(
        db, title, content, author_id, category_id, tag_names or []
    )
    return {"message": "Article created", "article_id": article.id}

@app.get("/articles/{article_id}", summary="Get article details")
def get_article(article_id: int, db: Session = Depends(get_db)):
    """Get article details (also increments the view count)"""
    article = ArticleCRUD.get_article_by_id(db, article_id)
    if not article:
        raise HTTPException(status_code=404, detail="Article not found")
    
    # Increment the view count
    ArticleCRUD.increment_view_count(db, article_id)
    
    return {
        "id": article.id,
        "title": article.title,
        "content": article.content,
        "summary": article.summary,
        "is_published": article.is_published,
        "view_count": article.view_count,
        "created_at": article.created_at,
        "author": {
            "id": article.author.id,
            "username": article.author.username
        },
        "category": {
            "id": article.category.id,
            "name": article.category.name
        } if article.category else None,
        "tags": [
            {"id": tag.id, "name": tag.name}
            for tag in article.tags
        ]
    }

@app.get("/articles/", summary="List articles")
def get_articles(
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    published_only: bool = Query(True),
    db: Session = Depends(get_db)
):
    """Get a paginated list of articles"""
    articles = ArticleCRUD.get_articles(db, skip, limit, published_only)
    return {
        "articles": [
            {
                "id": article.id,
                "title": article.title,
                "summary": article.summary,
                "view_count": article.view_count,
                "created_at": article.created_at,
                "author": article.author.username
            }
            for article in articles
        ],
        "total": len(articles)  # articles on this page, not the total number of articles
    }

@app.get("/search/articles/", summary="Search articles")
def search_articles(
    keyword: str = Query(..., min_length=1),
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    db: Session = Depends(get_db)
):
    """Search published articles by keyword"""
    articles = ArticleCRUD.search_articles(db, keyword, skip, limit)
    return {
        "keyword": keyword,
        "articles": [
            {
                "id": article.id,
                "title": article.title,
                "summary": article.summary,
                "view_count": article.view_count,
                "created_at": article.created_at
            }
            for article in articles
        ],
        "total": len(articles)  # matches on this page, not the total number of matches
    }

# Statistics endpoints
@app.get("/statistics/overview", summary="Statistics overview")
def get_statistics_overview(db: Session = Depends(get_db)):
    """Get an overview of site statistics"""
    stats = StatisticsCRUD.get_article_stats(db)
    category_stats = StatisticsCRUD.get_category_stats(db)
    popular_articles = StatisticsCRUD.get_popular_articles(db, 5)
    
    return {
        "article_stats": stats,
        "category_stats": category_stats,
        "popular_articles": [
            {
                "id": article.id,
                "title": article.title,
                "view_count": article.view_count
            }
            for article in popular_articles
        ]
    }

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)

8.6 Application Entry Point (main.py)

from database import create_tables, get_db
from crud import UserCRUD, ArticleCRUD
from models import User, Article, Category, Tag
from datetime import datetime

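def init_sample_data():
    """Populate the database with sample data"""
    # Take a session from the get_db() generator; it is closed explicitly in the finally block
    db = next(get_db())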
    
    try:
        # Create sample users
        admin_user = UserCRUD.create_user(
            db, "admin", "admin@blog.com", "hashed_password", "Administrator"
        )
        
        author_user = UserCRUD.create_user(
            db, "author", "author@blog.com", "hashed_password", "Author"
        )
        
        # Create sample categories
        tech_category = Category(name="Tech", description="Technology articles")
        life_category = Category(name="Life", description="Personal essays")
        
        db.add_all([tech_category, life_category])
        db.commit()
        
        # Create sample articles
        article1 = ArticleCRUD.create_article(
            db,
            title="Getting Started with SQLAlchemy",
            content="A detailed tutorial on SQLAlchemy...",
            author_id=author_user.id,
            category_id=tech_category.id,
            tag_names=["Python", "Database", "ORM"]
        )
        
        article2 = ArticleCRUD.create_article(
            db,
            title="My Journey in Programming",
            content="Sharing what I have learned while studying programming...",
            author_id=author_user.id,
            category_id=life_category.id,
            tag_names=["Programming", "Reflections"]
        )
        
        # Publish the articles
        ArticleCRUD.publish_article(db, article1.id)
        ArticleCRUD.publish_article(db, article2.id)
        
        print("Sample data initialized!")
        
    except Exception as e:
        print(f"Failed to initialize sample data: {e}")
        db.rollback()
    finally:
        db.close()

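def main():
    """Entry point"""
    print("=== SQLAlchemy Blog System ===")
    
    # Create the database tables
    print("Creating database tables...")
    create_tables()
    
    # Load sample data
    print("Initializing sample data...")
    init_sample_data()
    
    print("\nSetup complete!")
    print("API docs (once the server is running): http://localhost:8000/docs")
    print("Start the API server with: python api.py")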

if __name__ == "__main__":
    main()

9. Performance Optimization and Best Practices

9.1 Query Optimization

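# 1. Use indexes
from sqlalchemy import Boolean, Column, DateTime, Index, Integer, String

class User(Base):
    __tablename__ = 'users'
    
    id = Column(Integer, primary_key=True)
    username = Column(String(50), index=True)  # single-column index
    email = Column(String(100), index=True)
    is_active = Column(Boolean, default=True)  # needed by idx_active_created below
    created_at = Column(DateTime, index=True)
    
    # Composite indexes (also usable by queries that filter on the leading column alone)
    __table_args__ = (
        Index('idx_username_email', 'username', 'email'),
        Index('idx_active_created', 'is_active', 'created_at'),
    )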

# 2. Eager-load related data
from sqlalchemy.orm import joinedload, selectinload

# Avoid the N+1 query problem
with Session() as session:
    # Wrong: one query for the users plus one lazy-load query per user (N+1)
    users = session.query(User).all()
    for user in users:
        print(f"{user.username}: {len(user.articles)} articles")  # one query per user
    
    # Right: eager-load the relationship
    users = session.query(User).options(
        selectinload(User.articles)
    ).all()
    for user in users:
        print(f"{user.username}: {len(user.articles)} articles")  # no per-user queries

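# 3. Use bulk operations
from datetime import datetime

with Session() as session:
    # Bulk insert (bulk_insert_mappings() is a legacy API in SQLAlchemy 2.0;
    # session.execute(insert(User), users_data) is the 2.0-style equivalent)
    users_data = [
        {'username': f'user{i}', 'email': f'user{i}@example.com'}
        for i in range(1000)
    ]
    session.bulk_insert_mappings(User, users_data)
    
    # Bulk update: a single UPDATE statement, no objects are loaded
    session.query(User).filter(
        User.created_at < datetime(2023, 1, 1)
    ).update({'is_active': False})
    
    session.commit()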

# 4. Use raw SQL for complex queries
from sqlalchemy import text

with Session() as session:
    # A complex aggregate written as raw SQL.
    # The is_published filter belongs in the ON clause: in WHERE it would drop categories
    # without published articles and effectively turn the LEFT JOIN into an inner join
    result = session.execute(text("""
        SELECT 
            c.name AS category_name,
            COUNT(a.id) AS article_count,
            AVG(a.view_count) AS avg_views
        FROM categories c
        LEFT JOIN articles a ON c.id = a.category_id AND a.is_published = true
        GROUP BY c.id, c.name
        ORDER BY article_count DESC
    """))
    
    for row in result:
        print(f"Category: {row.category_name}, articles: {row.article_count}, average views: {row.avg_views}")

9.2 Connection Pool Tuning

from sqlalchemy import create_engine
from sqlalchemy.pool import QueuePool

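# Connection pool configuration for production
engine = create_engine(
    DATABASE_URL,
    # Pool settings
    poolclass=QueuePool,    # the default for most dialects; shown here for clarity
    pool_size=20,           # connections kept open in the pool
    max_overflow=30,        # extra connections allowed beyond pool_size
    pool_timeout=60,        # seconds to wait for a free connection
    pool_recycle=3600,      # replace connections older than this many seconds
    pool_pre_ping=True,     # test each connection when it is checked out
    
    # Performance
    echo=False,             # turn off SQL logging in production
    future=True,            # opts into 2.0-style usage on 1.4; already the default on 2.0
    
    # Driver-specific connection arguments (these are for PyMySQL)
    connect_args={
        "charset": "utf8mb4",
        "autocommit": False,
        # For SQLite pass {"check_same_thread": False} instead; PyMySQL does not accept it
    }
)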

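# Inspect the pool state
def monitor_pool_status(engine):
    """Print the QueuePool counters"""
    pool = engine.pool
    print(f"Pool size: {pool.size()}")
    print(f"Checked-out connections: {pool.checkedout()}")
    print(f"Overflow counter: {pool.overflow()}")
    print(f"Idle (checked-in) connections: {pool.checkedin()}")
    print(pool.status())  # one-line summary of the above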
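To see what these counters report, the sketch below (assuming SQLAlchemy 2.0) opens a `QueuePool` engine on a throwaway SQLite file and reads the counters while connections are held and after they are returned. The file location and pool sizes are arbitrary choices for the demo.

```python
import os
import tempfile

from sqlalchemy import create_engine, text
from sqlalchemy.pool import QueuePool

db_path = os.path.join(tempfile.mkdtemp(), "pool_demo.db")
engine = create_engine(f"sqlite:///{db_path}", poolclass=QueuePool, pool_size=2, max_overflow=1)

conn_a = engine.connect()             # each Connection checks out a pooled DBAPI connection
conn_b = engine.connect()
conn_a.execute(text("SELECT 1"))
held = engine.pool.checkedout()      # both connections are in use
conn_a.close()                       # close() returns the connection to the pool
conn_b.close()
released = engine.pool.checkedout()
idle = engine.pool.checkedin()       # connections now waiting in the pool

print(held, released, idle)  # 2 0 2
print(engine.pool.status())
engine.dispose()
```

A `checkedout()` value that keeps growing toward `pool_size + max_overflow` usually means sessions or connections are not being closed.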

9.3 Caching Strategy

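import hashlib
import json
from functools import wraps
from typing import Any, List, Optional

import redis
from sqlalchemy.orm import Session

# Redis cache configuration
redis_client = redis.Redis(
    host='localhost',
    port=6379,
    db=0,
    decode_responses=True
)

def cache_result(key_prefix: str, expire_time: int = 3600):
    """Decorator that caches a function's JSON-serializable result in Redis"""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Build the cache key. Session arguments are skipped because their repr differs
            # per request, and hashlib is used because the built-in hash() of a str is
            # randomized per process, so keys would not match across workers or restarts
            key_args = [a for a in args if not isinstance(a, Session)]
            key_kwargs = {k: v for k, v in kwargs.items() if not isinstance(v, Session)}
            raw_key = repr((key_args, sorted(key_kwargs.items())))
            cache_key = f"{key_prefix}:{hashlib.md5(raw_key.encode()).hexdigest()}"
            
            # Try the cache first
            cached_result = redis_client.get(cache_key)
            if cached_result is not None:
                return json.loads(cached_result)
            
            # Cache miss: call the function
            result = func(*args, **kwargs)
            
            # Store the result with an expiry
            redis_client.setex(
                cache_key,
                expire_time,
                json.dumps(result, default=str)
            )
            
            return result
        return wrapper
    return decorator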

# Using the cache
class CachedArticleCRUD(ArticleCRUD):
    """Article operations with caching"""
    
    @staticmethod
    @cache_result("hot_articles", 1800)  # cache for 30 minutes
    def get_hot_articles(db: Session, limit: int = 10) -> List[dict]:
        """Get the most-viewed articles (cached)"""
        articles = db.query(Article).filter(
            Article.is_published == True
        ).order_by(Article.view_count.desc()).limit(limit).all()
        
        return [
            {
                "id": article.id,
                "title": article.title,
                "view_count": article.view_count
            }
            for article in articles
        ]
    
    @staticmethod
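    def invalidate_article_cache(article_id: int):
        """Delete cached entries related to articles"""
        patterns = [
            f"article:{article_id}:*",
            "hot_articles:*",
            "recent_articles:*"
        ]
        
        for pattern in patterns:
            # scan_iter() walks the keyspace incrementally; KEYS blocks the server on large datasets
            keys = list(redis_client.scan_iter(match=pattern))
            if keys:
                redis_client.delete(*keys)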
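The decorator only needs `get` and `setex` from the Redis client, so its logic can be tried without a Redis server. The sketch below is built on assumptions: `FakeRedis` is a hypothetical in-process dict with expiry, `make_cache_key` is an illustrative key builder that hashes JSON-encoded arguments with SHA-256 (stable across processes, unlike the built-in `hash()` of a string), and the first positional argument is assumed to be the database session and is left out of the key.

```python
import hashlib
import json
import time
from functools import wraps
from typing import Any, Callable, Dict, Optional, Tuple


class FakeRedis:
    """Hypothetical in-process stand-in exposing the two Redis calls the decorator uses."""

    def __init__(self) -> None:
        self._data: Dict[str, Tuple[float, str]] = {}

    def get(self, key: str) -> Optional[str]:
        entry = self._data.get(key)
        if entry is None or entry[0] < time.monotonic():
            self._data.pop(key, None)  # drop missing or expired entries
            return None
        return entry[1]

    def setex(self, key: str, seconds: int, value: str) -> None:
        self._data[key] = (time.monotonic() + seconds, value)


def make_cache_key(prefix: str, args: tuple, kwargs: dict) -> str:
    """Stable key: JSON-encode the arguments and hash them with SHA-256."""
    raw = json.dumps([args, sorted(kwargs.items())], default=str)
    return f"{prefix}:{hashlib.sha256(raw.encode()).hexdigest()}"


client = FakeRedis()


def cache_result(key_prefix: str, expire_time: int = 3600) -> Callable:
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Assumption: args[0] is the db session, which must not be part of the key
            cache_key = make_cache_key(key_prefix, args[1:], kwargs)
            cached = client.get(cache_key)
            if cached is not None:
                return json.loads(cached)
            result = func(*args, **kwargs)
            client.setex(cache_key, expire_time, json.dumps(result, default=str))
            return result
        return wrapper
    return decorator


calls = []


@cache_result("hot_articles", 1800)
def get_hot_articles(db: object, limit: int = 10) -> list:
    calls.append(limit)  # records real executions
    return [{"id": i, "view_count": 100 - i} for i in range(limit)]


first = get_hot_articles(object(), limit=3)   # miss: runs the function
second = get_hot_articles(object(), limit=3)  # hit: different "session" object, same key
print(len(calls), first == second)  # 1 True
```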

9.4 Optimizing Pagination

from sqlalchemy import func
from typing import Tuple, List

class PaginationHelper:
    """Pagination helpers"""
    
    @staticmethod
    def paginate_query(query, page: int, per_page: int) -> Tuple[List, dict]:
        """Offset-based pagination"""
        # Total row count (costs an extra COUNT query)
        total = query.count()
        
        # Pagination metadata
        total_pages = (total + per_page - 1) // per_page
        has_prev = page > 1
        has_next = page < total_pages
        
        # Rows for the current page
        items = query.offset((page - 1) * per_page).limit(per_page).all()
        
        pagination_info = {
            "page": page,
            "per_page": per_page,
            "total": total,
            "total_pages": total_pages,
            "has_prev": has_prev,
            "has_next": has_next,
            "prev_page": page - 1 if has_prev else None,
            "next_page": page + 1 if has_next else None
        }
        
        return items, pagination_info
    
    @staticmethod
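    def cursor_paginate(query, cursor_field, cursor_value=None, limit: int = 20, desc: bool = True):
        """Cursor pagination (suited to large tables); cursor_field should be unique, or rows sharing a value at a page boundary are skipped"""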
        if cursor_value is not None:
            if desc:
                query = query.filter(cursor_field < cursor_value)
            else:
                query = query.filter(cursor_field > cursor_value)
        
        if desc:
            query = query.order_by(cursor_field.desc())
        else:
            query = query.order_by(cursor_field.asc())
        
        items = query.limit(limit + 1).all()
        
        has_more = len(items) > limit
        if has_more:
            items = items[:-1]
        
        next_cursor = None
        if has_more and items:
            next_cursor = getattr(items[-1], cursor_field.key)  # .key is the mapped attribute name
        
        return items, {
            "has_more": has_more,
            "next_cursor": next_cursor,
            "limit": limit
        }

# Usage
with Session() as session:
    # Offset pagination
    query = session.query(Article).filter(Article.is_published == True)
    articles, pagination = PaginationHelper.paginate_query(query, page=1, per_page=20)
    
    # Cursor pagination (pages stay consistent while new rows are being inserted)
    articles, cursor_info = PaginationHelper.cursor_paginate(
        query, Article.created_at, limit=20
    )

10. Common Problems and Solutions

10.1 Common Errors and Fixes

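# 1. Fixing "DetachedInstanceError"
from sqlalchemy.orm import joinedload

with Session() as session:
    user = session.query(User).first()

# The session is now closed, so accessing a relationship that was never loaded fails:
# user.articles  -> raises DetachedInstanceError
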
# Solution 1: load everything you need while the session is still open
with Session() as session:
    user = session.query(User).options(
        joinedload(User.articles)
    ).first()
    # Access the related data inside the session
    articles = user.articles

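# Solution 2: detach the object with expunge(), then re-attach it with merge()
with Session() as session:
    user = session.query(User).first()
    session.expunge(user)  # remove the object from the session
    
# Re-attach in a new session
with Session() as session:
    user = session.merge(user)  # returns the copy of the object that belongs to this session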
    articles = user.articles

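# 2. Handling "IntegrityError" (constraint violations)
from sqlalchemy.exc import IntegrityError

try:
    with Session() as session:
        user = User(username="existing_user", email="test@example.com")
        session.add(user)
        session.commit()
except IntegrityError as e:
    print(f"Constraint violated: {e}")
    # Handle the duplicate or otherwise invalid data here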

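# 3. Avoiding "PendingRollbackError"
# After a failed flush or commit, the session must be rolled back before it can be used again,
# so the rollback has to happen while the session is still open
with Session() as session:
    try:
        user = User(username=None)  # violates the NOT NULL constraint
        session.add(user)
        session.commit()
    except Exception as e:
        session.rollback()  # without this, the next operation raises PendingRollbackError
        print(f"Operation failed and was rolled back: {e}")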

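# 4. Comparing with None
# "== None" is not an error: SQLAlchemy overloads it and renders IS NULL.
# Linters flag it (flake8 E711), so is_(None) / is_not(None) are the idiomatic spellings.
with Session() as session:
    session.query(User).filter(User.full_name == None)       # renders IS NULL, but linters complain
    session.query(User).filter(User.full_name.is_(None))     # IS NULL
    session.query(User).filter(User.full_name.is_not(None))  # IS NOT NULL (is_not() since 1.4; isnot() before)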

10.2 Diagnosing Performance Problems

import time
from sqlalchemy import event
from sqlalchemy.engine import Engine

# Time every SQL statement and report slow ones
@event.listens_for(Engine, "before_cursor_execute")
def receive_before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
    context._query_start_time = time.time()

@event.listens_for(Engine, "after_cursor_execute")
def receive_after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
    total = time.time() - context._query_start_time
    if total > 0.1:  # report statements slower than 100 ms
        print(f"Slow query ({total:.3f}s): {statement[:100]}...")

# Query analyzer
class QueryAnalyzer:
    """Records every statement executed on the session's engine, with timings"""
    
    def __init__(self, session):
        self.session = session
        self.queries = []
    
    def __enter__(self):
        # Start listening
        event.listen(self.session.bind, "before_cursor_execute", self._before_execute)
        event.listen(self.session.bind, "after_cursor_execute", self._after_execute)
        return self
    
    def __exit__(self, exc_type, exc_val, exc_tb):
        # Stop listening
        event.remove(self.session.bind, "before_cursor_execute", self._before_execute)
        event.remove(self.session.bind, "after_cursor_execute", self._after_execute)
        
        # Print the report
        self.print_analysis()
    
    def _before_execute(self, conn, cursor, statement, parameters, context, executemany):
        context._start_time = time.time()
    
    def _after_execute(self, conn, cursor, statement, parameters, context, executemany):
        duration = time.time() - context._start_time
        self.queries.append({
            'statement': statement,
            'duration': duration,
            'parameters': parameters
        })
    
    def print_analysis(self):
        """Print the analysis report"""
        total_queries = len(self.queries)
        total_time = sum(q['duration'] for q in self.queries)
        avg_time = total_time / total_queries if total_queries > 0 else 0
        
        print(f"\n=== Query analysis ===")
        print(f"Total queries: {total_queries}")
        print(f"Total time: {total_time:.3f}s")
        print(f"Average time: {avg_time:.3f}s")
        
        # Show the slowest queries
        slow_queries = sorted(self.queries, key=lambda x: x['duration'], reverse=True)[:5]
        print(f"\nSlowest 5 queries:")
        for i, query in enumerate(slow_queries, 1):
            print(f"{i}. {query['duration']:.3f}s - {query['statement'][:100]}...")

# Using the query analyzer
with Session() as session:
    with QueryAnalyzer(session) as analyzer:
        # Queries to analyze
        users = session.query(User).all()
        for user in users:
            articles = user.articles  # lazy loading: one extra query per user (N+1)

10.3 Ensuring Data Consistency

from datetime import datetime

from sqlalchemy import Boolean, Column, DateTime, String, event
from sqlalchemy.orm import Session

# Set the update timestamp automatically
@event.listens_for(User, 'before_update')
def update_timestamp(mapper, connection, target):
    """Set updated_at right before a User row is updated"""
    target.updated_at = datetime.now()

# Soft delete
class SoftDeleteMixin:
    """Mixin that adds soft-delete columns"""
    deleted_at = Column(DateTime, nullable=True)
    is_deleted = Column(Boolean, default=False)
    
    def soft_delete(self):
        """Mark the row as deleted instead of removing it"""
        self.is_deleted = True
        self.deleted_at = datetime.now()
    
    def restore(self):
        """Undo a soft delete"""
        self.is_deleted = False
        self.deleted_at = None

class User(SoftDeleteMixin, Base):  # mixins conventionally come before Base
    __tablename__ = 'users'
    # ... other columns

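# Automatically exclude soft-deleted rows from queries.
# An 'after_attach' listener cannot do this: it only fires when an object joins a session,
# and a query built inside it is simply discarded. do_orm_execute can rewrite every ORM SELECT.
from sqlalchemy.orm import with_loader_criteria

@event.listens_for(Session, 'do_orm_execute')
def auto_filter_deleted(execute_state):
    """Add is_deleted == False to ORM SELECTs on every model that uses SoftDeleteMixin"""
    if execute_state.is_select and not execute_state.is_column_load and not execute_state.is_relationship_load:
        execute_state.statement = execute_state.statement.options(
            with_loader_criteria(SoftDeleteMixin, lambda cls: cls.is_deleted == False, include_aliases=True)
        )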

# Data validation
from sqlalchemy.orm import validates

class User(Base):
    __tablename__ = 'users'
    
    email = Column(String(100), nullable=False)
    
    @validates('email')
    def validate_email(self, key, address):
        """Validate the email format"""
        import re
        pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
        if not re.match(pattern, address):
            raise ValueError("Invalid email format")
        return address

10.4 Debugging Tips

# 1. Enable SQL logging
engine = create_engine('sqlite:///example.db', echo=True)

# 2. Or control SQL logging through the standard logging module
import logging
logging.basicConfig()
logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)

# 3. Inspect the generated SQL
from sqlalchemy import select
from sqlalchemy.dialects import mysql, postgresql, sqlite

stmt = select(User).where(User.is_active == True)

# Render the same statement for different databases (no database connection needed)
print("MySQL SQL:", stmt.compile(dialect=mysql.dialect()))
print("PostgreSQL SQL:", stmt.compile(dialect=postgresql.dialect()))
print("SQLite SQL:", stmt.compile(dialect=sqlite.dialect()))

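# 4. Analyze the query plan with EXPLAIN
#    ("EXPLAIN QUERY PLAN" is SQLite syntax; PostgreSQL and MySQL use "EXPLAIN")
from sqlalchemy import text

def explain_query(session, query):
    """Print the execution plan of a query"""
    # Inline the bound parameters using the session's own dialect
    sql = str(query.statement.compile(
        dialect=session.get_bind().dialect,
        compile_kwargs={"literal_binds": True}
    ))
    result = session.execute(text(f"EXPLAIN QUERY PLAN {sql}"))  # raw SQL strings must be wrapped in text()
    
    print("Query plan:")
    for row in result:
        print(row)

# Example
with Session() as session:
    query = session.query(User).filter(User.is_active == True)
    explain_query(session, query)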

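Summary

This tutorial has covered SQLAlchemy's core concepts, basic usage, and advanced features. After working through it, you should be able to:

  1. Understand the SQLAlchemy architecture: the roles of the ORM, Core, and Engine layers and how they relate
  2. Use the basic features fluently: model definition, CRUD operations, and query syntax
  3. Apply the advanced features: complex queries, relationship mapping, transaction management, and performance tuning
  4. Follow best practices: connection pool configuration, caching strategies, error handling, and debugging
  5. Build a real project: the blog system example walks through the complete development workflow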

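Study Tips

  1. Go step by step: start with the basic concepts and work up to the advanced features
  2. Practice hands-on: run the example code, change parameters, and observe the results
  3. Read the documentation: use the official docs to understand the details
  4. Build real projects: apply what you have learned in actual work
  5. Keep learning: follow SQLAlchemy releases and pick up new features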

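Further Directions

  • Asynchronous programming: SQLAlchemy's asyncio support
  • Microservice architecture: using SQLAlchemy in distributed systems
  • Database optimization: in-depth database performance tuning
  • Framework integration: working with Flask, FastAPI, Django, and other frameworks
  • Data migrations: database version management with Alembic

I hope this tutorial helps you master SQLAlchemy and handle database development in Python with confidence!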

