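1. Introduction to SQLAlchemy
1.1 What is SQLAlchemy
SQLAlchemy is a SQL toolkit and object-relational mapping (ORM) library for Python. It lets developers work with relational databases from Python code without writing SQL statements by hand.
Object-relational mapping (ORM) is a programming technique for converting data between the type system of an object-oriented language and that of a relational database. Put simply, each database table is mapped to a Python class, and each row in that table is mapped to an instance of the class.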
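To make the mapping idea concrete, here is a minimal sketch, assuming SQLAlchemy 1.4 or later and an in-memory SQLite database. The Book class and its columns are hypothetical. One class becomes one table, and each object added to the session becomes one row.

```python
# Minimal sketch: a class mapped to a table, rows mapped to instances.
# Assumes SQLAlchemy 1.4+; Book is a hypothetical example model.
from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Book(Base):  # class Book <-> table "books"
    __tablename__ = "books"
    id = Column(Integer, primary_key=True)
    title = Column(String(100), nullable=False)


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)  # emits CREATE TABLE books (...)

with Session(engine) as session:
    session.add(Book(title="Dune"))  # one instance -> one INSERTed row
    session.commit()
    # Reading rows back yields Book instances, not raw tuples
    books = session.execute(select(Book)).scalars().all()
    titles = [b.title for b in books]

print(Book.__table__.name, titles)  # books ['Dune']
```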
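1.2 SQLAlchemy Architecture
SQLAlchemy uses a layered architecture:
┌─────────────────────────────────────────┐
│ ORM                                     │
│ (object-relational mapping, high level) │
├─────────────────────────────────────────┤
│ Core                                    │
│ (SQL Expression Language, lower level)  │
├─────────────────────────────────────────┤
│ Engine                                  │
│ (connections and connection pooling)    │
├─────────────────────────────────────────┤
│ DBAPI                                   │
│ (database driver, e.g. pymysql)         │
└─────────────────────────────────────────┘
Strictly speaking, the Engine and its connection pool belong to Core; they are drawn as a separate layer here to show that the SQL Expression Language is built on top of them, and that the DBAPI driver does the actual talking to the database.
1.3 Main Use Cases
SQLAlchemy is mainly used for:
- Database access: a high-level abstraction over database operations that avoids hand-written SQL
- ORM mapping: mapping Python classes to database tables to simplify working with data models
- Complex queries: filtering, grouping, joins, subqueries and more
- Asynchronous access: since 1.4, the asyncio extension (which uses greenlet internally) runs Core and the ORM over async drivers such as asyncpg; this improves concurrency in I/O-bound applications rather than making individual queries faster
- Transaction control: database sessions and transactions are managed through the Session
- Multiple databases: PostgreSQL, MySQL, Oracle, SQLite and other mainstream relational databases are supported
- Web framework integration: commonly used with Flask (via the Flask-SQLAlchemy extension) and FastAPI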
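2. Environment Setup and Installation
2.1 Installing SQLAlchemy
# Install the SQLAlchemy library
pip install sqlalchemy
# Install a database driver (pick the one you need)
pip install pymysql           # MySQL driver
pip install psycopg2-binary   # PostgreSQL driver
pip install cx_Oracle         # Oracle driver (its successor is python-oracledb: pip install oracledb)
# SQLite needs no extra driver: the sqlite3 module ships with Python's standard library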
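2.2 Database Driver Reference
Type | Database | Driver package | Example connection URL |
---|---|---|---|
Relational | MySQL | pymysql | mysql+pymysql://username:password@localhost:3306/database_name |
Relational | PostgreSQL | psycopg2 | postgresql://username:password@localhost:5432/database_name |
Relational | SQLite | built in (sqlite3) | sqlite:///example.db |
Relational | Oracle | cx_Oracle | oracle://username:password@localhost:1521/orcl |
NoSQL | MongoDB | pymongo | mongodb://username:password@localhost:27017/database_name |
NoSQL | Redis | redis | redis://localhost:6379/0 |
SQLAlchemy only supports relational databases. The MongoDB and Redis rows are for reference only: those URLs are used with their own client libraries (pymongo, redis-py), not with create_engine().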
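Connection URLs become error-prone when a password contains characters such as "@" or ":". A minimal sketch, assuming SQLAlchemy 1.4+ (the credentials are placeholders): URL.create() escapes the parts for you, and make_url() parses an existing string. No driver needs to be installed just to build or parse a URL.

```python
# Sketch: building and parsing connection URLs with SQLAlchemy's URL helpers.
# Assumes SQLAlchemy 1.4+; the credentials below are hypothetical placeholders.
from sqlalchemy.engine import URL, make_url

# URL.create() keeps special characters in the password safe
url = URL.create(
    drivername="mysql+pymysql",
    username="root",
    password="p@ss:word",  # contains URL-special characters
    host="localhost",
    port=3306,
    database="test",
)

# make_url() parses a URL string into the same kind of URL object
parsed = make_url("postgresql://username:password@localhost:5432/database_name")

print(url.get_backend_name(), url.get_driver_name())  # mysql pymysql
print(parsed.host, parsed.port, parsed.database)      # localhost 5432 database_name
```

The resulting URL object can be passed straight to create_engine() in place of a string.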
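2.3 Verifying the Installation
# Check the installed SQLAlchemy version
import sqlalchemy
print(f"SQLAlchemy version: {sqlalchemy.__version__}")
# Test a database connection
from sqlalchemy import create_engine, text
# Use an in-memory SQLite database for the test
engine = create_engine('sqlite:///:memory:', echo=True)
# create_engine() is lazy and does not connect yet, so open a connection and run a trivial query
with engine.connect() as conn:
    assert conn.execute(text("SELECT 1")).scalar() == 1
print("Database connection test succeeded!")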
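3. Database Connections
3.1 Creating a Database Engine
The engine is the central component of SQLAlchemy: it manages database connections through a connection pool and knows which SQL dialect to generate for the target database.
from sqlalchemy import create_engine
# SQLite connection (file-based database)
engine = create_engine('sqlite:///example.db', echo=True)
# MySQL connection
db_url = 'mysql+pymysql://root:password@127.0.0.1:3306/test'
engine = create_engine(
    db_url,
    echo=True,           # log every SQL statement
    pool_size=10,        # number of connections kept in the pool
    max_overflow=20,     # extra connections allowed beyond pool_size
    pool_pre_ping=True,  # test each connection on checkout and replace it if it is dead
    pool_recycle=3600    # recycle connections older than this many seconds
)
# PostgreSQL connection
engine = create_engine(
    'postgresql://username:password@localhost:5432/database',
    echo=False,
    pool_size=5,
    max_overflow=10
)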
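3.2 Engine Parameters
Parameter | Description | Default |
---|---|---|
echo | Log the SQL statements that are executed | False |
pool_size | Number of connections the pool keeps open | 5 |
max_overflow | Maximum number of connections allowed beyond pool_size | 10 |
pool_timeout | Seconds to wait for a free connection before raising an error | 30 |
pool_recycle | Maximum age in seconds of a pooled connection before it is replaced (-1 means never) | -1 |
pool_pre_ping | Check that a connection is alive before handing it out | False |
pool_size, max_overflow and pool_timeout apply to the default QueuePool.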
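3.3 Connection Pool Management
from sqlalchemy import create_engine
from sqlalchemy.pool import QueuePool
# Custom connection pool configuration
engine = create_engine(
    'mysql+pymysql://user:password@localhost/dbname',
    poolclass=QueuePool,
    pool_size=20,       # pool size
    max_overflow=30,    # maximum number of overflow connections
    pool_timeout=60,    # seconds to wait for a free connection
    pool_recycle=7200,  # recycle connections after 2 hours
    echo=True
)
# Inspect the pool's current state
print(f"Pool size: {engine.pool.size()}")
print(f"Checked-out connections: {engine.pool.checkedout()}")
# Current overflow; this is negative while fewer than pool_size connections have been opened
print(f"Overflow connections: {engine.pool.overflow()}")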
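The pool limits are easier to grasp by exhausting a tiny pool. This sketch assumes SQLAlchemy 1.4+ and swaps MySQL for a temporary SQLite file so it runs anywhere: with pool_size=2 and max_overflow=1, at most three connections can be checked out, and a fourth request waits pool_timeout seconds and then fails.

```python
# Sketch: watching QueuePool limits in action on a temporary SQLite file
# (used instead of MySQL so the example is self-contained). Assumes SQLAlchemy 1.4+.
import os
import tempfile

from sqlalchemy import create_engine
from sqlalchemy.exc import TimeoutError as PoolTimeout
from sqlalchemy.pool import QueuePool

db_path = os.path.join(tempfile.mkdtemp(), "pool_demo.db")
engine = create_engine(
    f"sqlite:///{db_path}",
    poolclass=QueuePool,
    pool_size=2,       # two pooled connections...
    max_overflow=1,    # ...plus one overflow connection = at most 3 at once
    pool_timeout=0.5,  # wait at most 0.5 s for a free connection
)

conns = [engine.connect() for _ in range(3)]  # uses up the whole pool
print(engine.pool.checkedout(), engine.pool.overflow())  # 3 1

try:
    engine.connect()  # a 4th checkout has to wait, then times out
    timed_out = False
except PoolTimeout:
    timed_out = True

for c in conns:
    c.close()  # returning connections frees them; the overflow one is discarded
print(timed_out, engine.pool.checkedout())  # True 0
```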
4. Defining Data Models
4.1 The Declarative Base Class
# declarative_base lives in sqlalchemy.orm (the sqlalchemy.ext.declarative version is deprecated since 1.4)
from sqlalchemy.orm import declarative_base
from sqlalchemy import Column, Integer, String, Float, DateTime, Boolean, Text
from datetime import datetime
# Create the declarative base class
Base = declarative_base()
4.2 Defining a Basic Model
class User(Base):
    """User model"""
    __tablename__ = 'users'  # table name
    __table_args__ = {'comment': 'User information table'}  # table comment
    # Columns
    id = Column(Integer, primary_key=True, autoincrement=True, comment='User ID')
    username = Column(String(50), nullable=False, unique=True, comment='Username')
    email = Column(String(100), nullable=False, index=True, comment='Email')
    password_hash = Column(String(128), nullable=False, comment='Password hash')
    full_name = Column(String(100), comment='Full name')
    is_active = Column(Boolean, default=True, comment='Whether the account is active')
    # Pass the function datetime.now (not datetime.now()) so it is evaluated for each row
    created_at = Column(DateTime, default=datetime.now, comment='Created at')
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now, comment='Updated at')

    def __repr__(self):
        return f"<User(id={self.id}, username='{self.username}')>"

    def __str__(self):
        return f"User: {self.username} ({self.email})"
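4.3 Common Column Parameters
Parameter | Description | Example |
---|---|---|
primary_key | Whether the column is (part of) the primary key | primary_key=True |
nullable | Whether NULL values are allowed | nullable=False |
unique | Whether values must be unique (adds a UNIQUE constraint) | unique=True |
index | Whether to create an index on the column | index=True |
default | Value or callable used on INSERT when no value is given | default=0 |
onupdate | Value or callable used on UPDATE when no value is given | onupdate=datetime.now |
autoincrement | Whether an integer primary key auto-increments | autoincrement=True |
comment | Column comment stored in the database schema | comment='User ID' |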
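A quick way to see what default, onupdate, unique and nullable actually do is to trip each one. This is a sketch under stated assumptions: SQLAlchemy 1.4+, in-memory SQLite, a hypothetical Account model, and a hypothetical counter _tick standing in for datetime.now so the results are deterministic.

```python
# Sketch: observing default, onupdate, unique and nullable in practice.
# Assumes SQLAlchemy 1.4+; Account and _tick are hypothetical.
import itertools

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()
_tick = itertools.count(1)  # deterministic stand-in for datetime.now


class Account(Base):
    __tablename__ = "accounts"
    id = Column(Integer, primary_key=True)
    name = Column(String(50), nullable=False, unique=True)
    balance = Column(Integer, default=0)                                 # used on INSERT
    revision = Column(Integer, default=0, onupdate=lambda: next(_tick))  # used on UPDATE


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    acct = Account(name="alice")
    session.add(acct)
    session.commit()
    after_insert = (acct.balance, acct.revision)  # defaults applied

    acct.balance = 100
    session.commit()                              # the UPDATE triggers onupdate
    after_update = (acct.balance, acct.revision)

    violations = 0
    for bad in (Account(name="alice"), Account(name=None)):  # duplicate name, then NULL name
        session.add(bad)
        try:
            session.commit()
        except IntegrityError:
            session.rollback()  # also discards the failed pending object
            violations += 1

print(after_insert, after_update, violations)  # (0, 0) (100, 1) 2
```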
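4.4 Complex Data Type Examples
from sqlalchemy import JSON, DECIMAL, Enum
from sqlalchemy.dialects.mysql import LONGTEXT  # MySQL-specific type
import enum
class UserStatus(enum.Enum):
    """User status enum"""
    ACTIVE = "active"
    INACTIVE = "inactive"
    SUSPENDED = "suspended"
class Product(Base):
    """Product model"""
    # Section 4.5 defines Product again with a category relationship. In a real project,
    # define each model only once per Base: declaring the 'products' table twice raises an error.
    __tablename__ = 'products'
    id = Column(Integer, primary_key=True)
    name = Column(String(200), nullable=False, comment='Product name')
    description = Column(Text, comment='Product description')
    price = Column(DECIMAL(10, 2), nullable=False, comment='Price')
    stock = Column(Integer, default=0, comment='Quantity in stock')
    # JSON column for extra attributes
    attributes = Column(JSON, comment='Product attributes')
    # Enum column (stores the enum member names, e.g. 'ACTIVE', by default)
    status = Column(Enum(UserStatus), default=UserStatus.ACTIVE, comment='Status')
    # Long text column (LONGTEXT exists only in MySQL; use Text for portable models)
    content = Column(LONGTEXT, comment='Detailed content')
    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)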
4.5 Defining Table Relationships
from sqlalchemy import ForeignKey
from sqlalchemy.orm import relationship
class Category(Base):
    """Product category model"""
    __tablename__ = 'categories'
    id = Column(Integer, primary_key=True)
    name = Column(String(100), nullable=False, comment='Category name')
    description = Column(Text, comment='Category description')
    # One-to-many: a category has many products
    products = relationship("Product", back_populates="category")
# This Product replaces the one from 4.4 (define it only once in real code)
class Product(Base):
    """Product model (with relationship)"""
    __tablename__ = 'products'
    id = Column(Integer, primary_key=True)
    name = Column(String(200), nullable=False)
    category_id = Column(Integer, ForeignKey('categories.id'), comment='Category ID')
    # Many-to-one: many products belong to one category
    category = relationship("Category", back_populates="products")
# Many-to-many example
from sqlalchemy import Table
# Association table (the composite primary key prevents duplicate user/role pairs)
user_role_association = Table(
    'user_roles',
    Base.metadata,
    Column('user_id', Integer, ForeignKey('users.id'), primary_key=True),
    Column('role_id', Integer, ForeignKey('roles.id'), primary_key=True)
)
class Role(Base):
    """Role model"""
    __tablename__ = 'roles'
    id = Column(Integer, primary_key=True)
    name = Column(String(50), nullable=False, unique=True)
    description = Column(String(200))
    # Many-to-many: roles and users
    users = relationship("User", secondary=user_role_association, back_populates="roles")
# Add the roles relationship to the existing User model
# (declarative classes accept new attributes after creation; defining it inside User is clearer)
User.roles = relationship("Role", secondary=user_role_association, back_populates="users")
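back_populates keeps both sides of a relationship in sync in Python, before anything is written to the database. The sketch below assumes SQLAlchemy 1.4+ and uses trimmed copies of the models above on in-memory SQLite; it also shows that the association row for the many-to-many link is written automatically.

```python
# Sketch: back_populates syncing for one-to-many and many-to-many.
# Assumes SQLAlchemy 1.4+; models are trimmed copies of the ones above.
from sqlalchemy import Column, ForeignKey, Integer, String, Table, create_engine, select
from sqlalchemy.orm import Session, declarative_base, relationship

Base = declarative_base()

user_roles = Table(
    "user_roles", Base.metadata,
    Column("user_id", Integer, ForeignKey("users.id"), primary_key=True),
    Column("role_id", Integer, ForeignKey("roles.id"), primary_key=True),
)

class Category(Base):
    __tablename__ = "categories"
    id = Column(Integer, primary_key=True)
    name = Column(String(100))
    products = relationship("Product", back_populates="category")

class Product(Base):
    __tablename__ = "products"
    id = Column(Integer, primary_key=True)
    name = Column(String(200))
    category_id = Column(Integer, ForeignKey("categories.id"))
    category = relationship("Category", back_populates="products")

class User(Base):
    __tablename__ = "users"
    id = Column(Integer, primary_key=True)
    username = Column(String(50))
    roles = relationship("Role", secondary=user_roles, back_populates="users")

class Role(Base):
    __tablename__ = "roles"
    id = Column(Integer, primary_key=True)
    name = Column(String(50))
    users = relationship("User", secondary=user_roles, back_populates="roles")

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    books = Category(name="Books")
    novel = Product(name="Novel", category=books)  # set the many-to-one side...
    in_sync = novel in books.products               # ...and the one-to-many side follows
    alice = User(username="alice", roles=[Role(name="admin")])
    session.add_all([books, alice])                 # related objects are cascaded in
    session.commit()
    # The user_roles association row was inserted for us
    role_users = [u.username for u in session.execute(select(Role)).scalar_one().users]

print(in_sync, role_users)  # True ['alice']
```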
5. Basic Data Operations
5.1 Creating Tables
from sqlalchemy import create_engine
# Create the engine
engine = create_engine('sqlite:///example.db', echo=True)
# Create every table registered on Base.metadata (tables that already exist are skipped)
Base.metadata.create_all(engine)
# Drop all tables (use with care)
# Base.metadata.drop_all(engine)
# Create a single table
# User.__table__.create(engine, checkfirst=True)
# Drop a single table
# User.__table__.drop(engine, checkfirst=True)
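5.2 Session Management
from sqlalchemy.orm import sessionmaker, scoped_session
# Create a session factory
Session = sessionmaker(bind=engine)
# Option 1: manage the session manually
session = Session()
try:
    # Database operations (password_hash is NOT NULL in the User model, so it must be set)
    user = User(username='john', email='john@example.com', password_hash='hashed_pw')
    session.add(user)
    session.commit()
except Exception as e:
    session.rollback()
    print(f"Operation failed: {e}")
finally:
    session.close()
# Option 2: use a context manager (recommended)
with Session() as session:
    user = User(username='jane', email='jane@example.com', password_hash='hashed_pw')
    session.add(user)
    session.commit()
# The session is closed automatically on exit; committing is still up to you
# Option 3: scoped_session (one session per thread)
SessionLocal = scoped_session(sessionmaker(bind=engine))
def get_session():
    """Return the current thread's session"""
    return SessionLocal()
def close_session():
    """Close and discard the current thread's session"""
    SessionLocal.remove()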
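Option 2 still needs an explicit commit(). A further pattern, sketched here assuming SQLAlchemy 1.4+ and a hypothetical Note model, is sessionmaker.begin(): it yields a session inside a transaction that commits when the block ends normally and rolls back if the block raises.

```python
# Sketch: Session.begin() commits on success and rolls back on exceptions.
# Assumes SQLAlchemy 1.4+; Note is a hypothetical model.
from sqlalchemy import Column, Integer, String, create_engine, func, select
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()

class Note(Base):
    __tablename__ = "notes"
    id = Column(Integer, primary_key=True)
    text = Column(String(100), nullable=False)

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
Session = sessionmaker(bind=engine)

# Commits automatically when the block exits normally
with Session.begin() as session:
    session.add(Note(text="kept"))

# Rolls back automatically when the block raises
try:
    with Session.begin() as session:
        session.add(Note(text="discarded"))
        raise RuntimeError("something went wrong")
except RuntimeError:
    pass

with Session() as session:
    count = session.execute(select(func.count(Note.id))).scalar()
print(count)  # 1
```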
5.3 Inserting Data
# Insert a single row
with Session() as session:
    # Create a user object
    new_user = User(
        username='alice',
        email='alice@example.com',
        full_name='Alice Smith',
        password_hash='hashed_password_here'
    )
    # Add it to the session
    session.add(new_user)
    # Commit the transaction
    session.commit()
    # The generated primary key is available after the commit
    print(f"New user ID: {new_user.id}")
# Insert several objects
with Session() as session:
    users = [
        User(username='bob', email='bob@example.com', full_name='Bob Johnson', password_hash='hashed_pw'),
        User(username='charlie', email='charlie@example.com', full_name='Charlie Brown', password_hash='hashed_pw'),
        User(username='diana', email='diana@example.com', full_name='Diana Prince', password_hash='hashed_pw')
    ]
    # Add them all at once
    session.add_all(users)
    session.commit()
    print(f"Inserted {len(users)} users")
# bulk_insert_mappings (faster bulk insert from dicts; it is a legacy API in SQLAlchemy 2.0,
# where session.execute(insert(User), user_data) is preferred)
with Session() as session:
    user_data = [
        {'username': 'user1', 'email': 'user1@example.com', 'full_name': 'User One', 'password_hash': 'hashed_pw'},
        {'username': 'user2', 'email': 'user2@example.com', 'full_name': 'User Two', 'password_hash': 'hashed_pw'},
        {'username': 'user3', 'email': 'user3@example.com', 'full_name': 'User Three', 'password_hash': 'hashed_pw'}
    ]
    session.bulk_insert_mappings(User, user_data)
    session.commit()
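The 2.0-style replacement for bulk_insert_mappings passes a list of plain dicts to session.execute() together with an insert() construct. A sketch, assuming SQLAlchemy 1.4+ (in 1.4 this runs as a plain executemany) and a hypothetical Item model:

```python
# Sketch: bulk INSERT from a list of dicts via session.execute(insert(...), rows).
# Assumes SQLAlchemy 1.4+; Item is a hypothetical model.
from sqlalchemy import Column, Integer, String, create_engine, func, insert, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Item(Base):
    __tablename__ = "items"
    id = Column(Integer, primary_key=True)
    name = Column(String(50), nullable=False)

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

rows = [{"name": f"item{i}"} for i in range(1, 1001)]  # 1000 dicts, no ORM objects built

with Session(engine) as session:
    # The rows are sent in batches rather than flushed one ORM object at a time
    session.execute(insert(Item), rows)
    session.commit()
    total = session.execute(select(func.count(Item.id))).scalar()

print(total)  # 1000
```

Because no ORM objects are created, this path skips identity-map bookkeeping, which is where most of the speed-up comes from.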
5.4 Querying Data
from sqlalchemy.exc import NoResultFound, MultipleResultsFound
with Session() as session:
    # All users
    all_users = session.query(User).all()
    print(f"Total users: {len(all_users)}")
    # First user
    first_user = session.query(User).first()
    print(f"First user: {first_user}")
    # Look up by primary key (Query.get() is legacy in 2.0; Session.get() replaces it)
    user_by_id = session.get(User, 1)
    if user_by_id:
        print(f"User with ID 1: {user_by_id.username}")
    # Filtered query
    active_users = session.query(User).filter(User.is_active == True).all()
    print(f"Active users: {len(active_users)}")
    # Exactly one row (raises if there are zero or several matches)
    try:
        unique_user = session.query(User).filter(User.username == 'alice').one()
        print(f"Found user: {unique_user.email}")
    except (NoResultFound, MultipleResultsFound) as e:
        print(f"Query failed: {e}")
    # At most one row (returns None when nothing matches)
    maybe_user = session.query(User).filter(User.username == 'nonexistent').one_or_none()
    if maybe_user:
        print(f"Found user: {maybe_user.username}")
    else:
        print("User does not exist")
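The session.query() style above is the legacy Query API. The same lookups in the select() style that SQLAlchemy recommends from 2.0 onward look like this; a sketch assuming SQLAlchemy 1.4+ and a trimmed User model, with the legacy equivalent noted in each comment.

```python
# Sketch: 2.0-style select() equivalents of the Query calls above.
# Assumes SQLAlchemy 1.4+; trimmed User model on in-memory SQLite.
from sqlalchemy import Boolean, Column, Integer, String, create_engine, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class User(Base):
    __tablename__ = "users"
    id = Column(Integer, primary_key=True)
    username = Column(String(50), unique=True, nullable=False)
    is_active = Column(Boolean, default=True)

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([User(username="alice"), User(username="bob", is_active=False)])
    session.commit()

    # query(User).all()
    everyone = session.execute(select(User).order_by(User.id)).scalars().all()
    # query(User).filter(...).all()
    active = session.execute(select(User).where(User.is_active == True)).scalars().all()
    # query(User).filter(...).one() / .one_or_none()
    alice = session.execute(select(User).filter_by(username="alice")).scalar_one()
    nobody = session.execute(select(User).filter_by(username="zoe")).scalar_one_or_none()
    # query(User).get(pk)
    by_pk = session.get(User, alice.id)

print([u.username for u in everyone], [u.username for u in active], nobody)
# ['alice', 'bob'] ['alice'] None
```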
5.5 Updating Data
with Session() as session:
    # Option 1: load the object, then change it
    user = session.query(User).filter(User.username == 'alice').first()
    if user:
        user.full_name = 'Alice Johnson'  # change an attribute
        # updated_at is refreshed by onupdate=datetime.now, so there is no need to set it here
        session.commit()
        print(f"User {user.username} updated")
    # Option 2: bulk UPDATE in a single statement (returns the number of matched rows)
    updated_count = session.query(User).filter(
        User.is_active == True
    ).update({
        User.updated_at: datetime.now()
    })
    session.commit()
    print(f"Updated {updated_count} users")
    # Option 3: conditional bulk update
    session.query(User).filter(
        User.created_at < datetime(2023, 1, 1)
    ).update({
        User.is_active: False,
        User.updated_at: datetime.now()
    })
    session.commit()
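5.6 Deleting Data
with Session() as session:
    # Option 1: load the object, then delete it
    user_to_delete = session.query(User).filter(User.username == 'bob').first()
    if user_to_delete:
        username = user_to_delete.username  # read the name before the row is deleted
        session.delete(user_to_delete)
        session.commit()
        print(f"User {username} deleted")
    # Option 2: bulk DELETE in a single statement (returns the matched row count;
    # ORM relationship cascades are not applied)
    deleted_count = session.query(User).filter(
        User.is_active == False
    ).delete()
    session.commit()
    print(f"Deleted {deleted_count} users")
    # Option 3: conditional delete
    session.query(User).filter(
        User.created_at < datetime(2022, 1, 1)
    ).delete()
    session.commit()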
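Bulk updates and deletes also have 2.0-style forms, update() and delete() executed through the session, whose result reports the matched row count. A sketch assuming SQLAlchemy 1.4+ and a trimmed User model; the statements run in a fresh session so no loaded objects need to be synchronized.

```python
# Sketch: 2.0-style UPDATE and DELETE statements run through the Session.
# Assumes SQLAlchemy 1.4+; trimmed User model on in-memory SQLite.
from sqlalchemy import Boolean, Column, Integer, String, create_engine, delete, select, update
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class User(Base):
    __tablename__ = "users"
    id = Column(Integer, primary_key=True)
    username = Column(String(50), nullable=False)
    is_active = Column(Boolean, default=True)

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([User(username=n) for n in ("alice", "bob", "carol")])
    session.commit()

with Session(engine) as session:  # fresh session: nothing loaded in the identity map
    # UPDATE users SET is_active = 0 WHERE username != 'alice'
    result = session.execute(
        update(User).where(User.username != "alice").values(is_active=False)
    )
    deactivated = result.rowcount  # rows matched by the WHERE clause

    # DELETE FROM users WHERE is_active = 0
    removed = session.execute(delete(User).where(User.is_active == False)).rowcount
    session.commit()

    remaining = session.execute(select(User.username)).scalars().all()

print(deactivated, removed, remaining)  # 2 2 ['alice']
```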
6. Complex Queries
6.1 Conditional Queries
from sqlalchemy import and_, or_, not_, func, text
with Session() as session:
    # Basic filter
    users = session.query(User).filter(User.is_active == True).all()
    # Several conditions combined with AND (passing several arguments to filter() is equivalent)
    users = session.query(User).filter(
        and_(
            User.is_active == True,
            User.created_at > datetime(2023, 1, 1)
        )
    ).all()
    # Conditions combined with OR
    users = session.query(User).filter(
        or_(
            User.username.like('%admin%'),
            User.email.like('%@admin.com')
        )
    ).all()
    # Negation (NOT)
    users = session.query(User).filter(
        not_(User.is_active == False)
    ).all()
    # IN
    user_ids = [1, 2, 3, 4, 5]
    users = session.query(User).filter(User.id.in_(user_ids)).all()
    # NOT IN (~ negates the IN; User.id.not_in(user_ids) is equivalent in 1.4+)
    users = session.query(User).filter(~User.id.in_(user_ids)).all()
    # LIKE pattern match
    users = session.query(User).filter(User.username.like('%john%')).all()
    # ILIKE: case-insensitive match (native on PostgreSQL, emulated with lower() elsewhere)
    users = session.query(User).filter(User.username.ilike('%JOHN%')).all()
    # BETWEEN range
    users = session.query(User).filter(
        User.created_at.between(datetime(2023, 1, 1), datetime(2023, 12, 31))
    ).all()
    # IS NULL
    users = session.query(User).filter(User.full_name.is_(None)).all()
    # IS NOT NULL (is_not() is the 1.4+ name for isnot())
    users = session.query(User).filter(User.full_name.is_not(None)).all()
    # Raw SQL fragment with a bound parameter
    users = session.query(User).filter(
        text("username LIKE :pattern")
    ).params(pattern='%admin%').all()
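When a filter does not behave as expected, printing the SQL it compiles to is the fastest check. A sketch assuming SQLAlchemy 1.4+ with a trimmed User model; literal_binds inlines the values for display only and should not be used to build queries for execution.

```python
# Sketch: printing the SQL generated by filter expressions, per dialect.
# Assumes SQLAlchemy 1.4+; literal_binds is for display only.
from sqlalchemy import Column, Integer, String, and_, or_, select
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class User(Base):
    __tablename__ = "users"
    id = Column(Integer, primary_key=True)
    username = Column(String(50))
    email = Column(String(100))
    full_name = Column(String(100))

stmt = select(User.id).where(
    and_(
        or_(User.username.like("%admin%"), User.email.like("%@admin.com")),
        User.full_name.is_not(None),
        User.id.in_([1, 2, 3]),
    )
)
sql = str(stmt.compile(compile_kwargs={"literal_binds": True}))
print(sql)

# ilike() renders as a native ILIKE when compiled for PostgreSQL
pg_sql = str(
    select(User.id).where(User.username.ilike("%john%")).compile(dialect=postgresql.dialect())
)
print(pg_sql)
```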
6.2 Sorting and Pagination
with Session() as session:
    # Ascending order
    users = session.query(User).order_by(User.created_at).all()
    # Descending order
    users = session.query(User).order_by(User.created_at.desc()).all()
    # Sort by several columns
    users = session.query(User).order_by(
        User.is_active.desc(),  # active users first
        User.created_at.asc()   # then by creation time, oldest first
    ).all()
    # Limit the number of rows
    recent_users = session.query(User).order_by(
        User.created_at.desc()
    ).limit(10).all()
    # Skip a number of rows
    users = session.query(User).offset(20).limit(10).all()
    # Page-based pagination (a stable ORDER BY keeps pages from overlapping)
    page = 1
    per_page = 10
    users = session.query(User).order_by(User.id).offset(
        (page - 1) * per_page
    ).limit(per_page).all()
    # Slicing also works (more Pythonic): [20:30] becomes OFFSET 20 LIMIT 10, i.e. rows 21 to 30
    users = session.query(User)[20:30]
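Paginated APIs usually need the total row count alongside the page. A sketch of a helper that returns both, assuming SQLAlchemy 1.4+; paginate() is a hypothetical helper, not a SQLAlchemy API.

```python
# Sketch: pagination returning (items, total).
# paginate() is a hypothetical helper; assumes SQLAlchemy 1.4+.
from sqlalchemy import Column, Integer, String, create_engine, func, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class User(Base):
    __tablename__ = "users"
    id = Column(Integer, primary_key=True)
    username = Column(String(50))

def paginate(session, stmt, page, per_page):
    """Return (items, total) for a 1-based page number."""
    # Count the rows of the unpaginated query via a subquery
    total = session.execute(select(func.count()).select_from(stmt.subquery())).scalar()
    items = session.execute(stmt.offset((page - 1) * per_page).limit(per_page)).scalars().all()
    return items, total

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([User(username=f"user{i}") for i in range(1, 26)])  # 25 users
    session.commit()
    items, total = paginate(session, select(User).order_by(User.id), page=3, per_page=10)
    names = [u.username for u in items]

print(total, names)  # 25 ['user21', 'user22', 'user23', 'user24', 'user25']
```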
6.3 Aggregate Queries
from sqlalchemy import func, distinct
with Session() as session:
    # Count rows
    total_users = session.query(func.count(User.id)).scalar()
    print(f"Total users: {total_users}")
    # Count distinct values
    unique_emails = session.query(func.count(distinct(User.email))).scalar()
    print(f"Distinct emails: {unique_emails}")
    # Maximum and minimum
    latest_created_at = session.query(func.max(User.created_at)).scalar()
    earliest_created_at = session.query(func.min(User.created_at)).scalar()
    # Average (assuming User had an age column)
    # avg_age = session.query(func.avg(User.age)).scalar()
    # Sum
    # total_age = session.query(func.sum(User.age)).scalar()
    # GROUP BY
    user_counts_by_status = session.query(
        User.is_active,
        func.count(User.id).label('count')
    ).group_by(User.is_active).all()
    for is_active, count in user_counts_by_status:
        status = "Active" if is_active else "Inactive"
        print(f"{status} users: {count}")
    # HAVING filters groups after aggregation
    # (instr() exists in SQLite and MySQL; PostgreSQL uses strpos() instead)
    active_email_domains = session.query(
        func.substr(User.email, func.instr(User.email, '@') + 1).label('domain'),
        func.count(User.id).label('count')
    ).filter(
        User.is_active == True
    ).group_by(
        func.substr(User.email, func.instr(User.email, '@') + 1)
    ).having(
        func.count(User.id) > 1  # only domains with more than one user
    ).all()
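Here is the GROUP BY / HAVING query above run against sample data. This is a sketch assuming SQLAlchemy 1.4+ and SQLite, where substr() and instr() are available, and it uses a trimmed User model. Only domains with more than one active user survive the HAVING clause.

```python
# Sketch: GROUP BY email domain with HAVING, on sample data.
# Assumes SQLAlchemy 1.4+ and SQLite (substr/instr); trimmed User model.
from sqlalchemy import Boolean, Column, Integer, String, create_engine, func, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class User(Base):
    __tablename__ = "users"
    id = Column(Integer, primary_key=True)
    email = Column(String(100))
    is_active = Column(Boolean, default=True)

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

domain = func.substr(User.email, func.instr(User.email, "@") + 1).label("domain")
stmt = (
    select(domain, func.count(User.id).label("count"))
    .where(User.is_active == True)
    .group_by(domain)
    .having(func.count(User.id) > 1)  # keep domains with more than one active user
    .order_by(domain)
)

with Session(engine) as session:
    session.add_all([
        User(email="a@x.com"), User(email="b@x.com"),
        User(email="c@y.com"),
        User(email="d@z.com"), User(email="e@z.com", is_active=False),
    ])
    session.commit()
    rows = [tuple(r) for r in session.execute(stmt)]

print(rows)  # [('x.com', 2)]
```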
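6.4 Join Queries (JOIN)
from sqlalchemy.orm import aliased
# These examples assume Product has a user_id foreign key to users.id and that User has a
# matching products relationship; neither is part of the models defined in section 4
with Session() as session:
    # INNER JOIN
    results = session.query(User, Product).join(
        Product, User.id == Product.user_id
    ).all()
    # LEFT OUTER JOIN (users without products come back with Product = None)
    results = session.query(User, Product).outerjoin(
        Product, User.id == Product.user_id
    ).all()
    # Join along a relationship
    results = session.query(User).join(User.products).all()
    # Join plus a filter (the ON clause is inferred from the foreign key)
    results = session.query(User, Product).join(
        Product
    ).filter(
        Product.price > 100
    ).all()
    # Joining several tables
    results = session.query(User, Product, Category).join(
        Product
    ).join(
        Category
    ).filter(
        Category.name == 'Electronics'
    ).all()
    # Self-join (related rows in the same table)
    # Assumes User has a manager_id column
    manager_alias = aliased(User)
    results = session.query(User, manager_alias).join(
        manager_alias, User.manager_id == manager_alias.id
    ).all()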
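The join examples depend on columns the earlier models lack. This sketch adds them explicitly: Product.user_id, User.products and User.manager_id are hypothetical additions. It assumes SQLAlchemy 1.4+ and runs inner, outer and self joins on in-memory SQLite.

```python
# Sketch: INNER, LEFT OUTER and self joins on a schema that has the assumed columns.
# Product.user_id, User.products and User.manager_id are hypothetical; SQLAlchemy 1.4+.
from sqlalchemy import Column, ForeignKey, Integer, String, create_engine, select
from sqlalchemy.orm import Session, aliased, declarative_base, relationship

Base = declarative_base()

class User(Base):
    __tablename__ = "users"
    id = Column(Integer, primary_key=True)
    username = Column(String(50))
    manager_id = Column(Integer, ForeignKey("users.id"))
    products = relationship("Product", back_populates="owner")

class Product(Base):
    __tablename__ = "products"
    id = Column(Integer, primary_key=True)
    name = Column(String(50))
    user_id = Column(Integer, ForeignKey("users.id"))
    owner = relationship("User", back_populates="products")

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    boss = User(username="boss")
    dev = User(username="dev", products=[Product(name="laptop")])
    session.add_all([boss, dev])
    session.flush()          # assigns ids so manager_id can reference boss
    dev.manager_id = boss.id
    session.commit()

    inner = [tuple(r) for r in session.execute(
        select(User.username, Product.name).join(User.products))]
    outer = [tuple(r) for r in session.execute(
        select(User.username, Product.name).outerjoin(User.products).order_by(User.username))]
    manager = aliased(User)  # second "copy" of users for the self-join
    pairs = [tuple(r) for r in session.execute(
        select(User.username, manager.username).join(manager, User.manager_id == manager.id))]

print(inner, outer, pairs)
# [('dev', 'laptop')] [('boss', None), ('dev', 'laptop')] [('dev', 'boss')]
```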
6.5 Subqueries
from sqlalchemy import select
from sqlalchemy.sql import exists
# As in 6.4, these examples assume Product has a user_id column
with Session() as session:
    # Scalar subquery
    avg_price_subquery = session.query(
        func.avg(Product.price)
    ).scalar_subquery()
    expensive_products = session.query(Product).filter(
        Product.price > avg_price_subquery
    ).all()
    # Table subquery (derived table)
    user_product_count = session.query(
        Product.user_id,
        func.count(Product.id).label('product_count')
    ).group_by(Product.user_id).subquery()
    # Use the subquery's columns through .c
    productive_users = session.query(User).join(
        user_product_count, User.id == user_product_count.c.user_id
    ).filter(
        user_product_count.c.product_count > 5
    ).all()
    # EXISTS
    users_with_products = session.query(User).filter(
        exists().where(Product.user_id == User.id)
    ).all()
    # NOT EXISTS
    users_without_products = session.query(User).filter(
        ~exists().where(Product.user_id == User.id)
    ).all()
    # IN with a subquery: pass a select() (passing a .subquery() to in_() triggers a warning in 1.4+)
    active_user_ids = select(User.id).where(User.is_active == True)
    products_by_active_users = session.query(Product).filter(
        Product.user_id.in_(active_user_ids)
    ).all()
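To see what each subquery form returns, here they are on a small data set. This is a sketch assuming SQLAlchemy 1.4+ and in-memory SQLite; Product.user_id is the same hypothetical column used above.

```python
# Sketch: scalar subquery, EXISTS / NOT EXISTS and IN (SELECT ...) on sample data.
# Assumes SQLAlchemy 1.4+; Product.user_id is a hypothetical column.
from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, create_engine, exists, func, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class User(Base):
    __tablename__ = "users"
    id = Column(Integer, primary_key=True)
    username = Column(String(50))
    is_active = Column(Boolean, default=True)

class Product(Base):
    __tablename__ = "products"
    id = Column(Integer, primary_key=True)
    name = Column(String(50))
    price = Column(Integer)
    user_id = Column(Integer, ForeignKey("users.id"))

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([User(id=1, username="ann"), User(id=2, username="ben", is_active=False),
                     User(id=3, username="cat")])
    session.add_all([Product(name="pen", price=10, user_id=1),
                     Product(name="tv", price=500, user_id=2)])
    session.commit()

    avg_price = select(func.avg(Product.price)).scalar_subquery()  # (SELECT avg(price) ...)
    pricey = session.execute(select(Product.name).where(Product.price > avg_price)).scalars().all()

    has_products = exists().where(Product.user_id == User.id)  # correlated to the outer row
    sellers = session.execute(
        select(User.username).where(has_products).order_by(User.id)).scalars().all()
    idle = session.execute(select(User.username).where(~has_products)).scalars().all()

    active_ids = select(User.id).where(User.is_active == True)  # IN (SELECT users.id ...)
    from_active = session.execute(
        select(Product.name).where(Product.user_id.in_(active_ids))).scalars().all()

print(pricey, sellers, idle, from_active)  # ['tv'] ['ann', 'ben'] ['cat'] ['pen']
```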
6.6 Window Functions (Advanced Queries)
from sqlalchemy import func
# Window functions need database support (e.g. SQLite 3.25+, MySQL 8.0+, PostgreSQL)
with Session() as session:
    # ROW_NUMBER()
    results = session.query(
        User.id,
        User.username,
        func.row_number().over(
            order_by=User.created_at.desc()
        ).label('row_num')
    ).all()
    # RANK() (ties share a rank and leave a gap after them)
    results = session.query(
        Product.id,
        Product.name,
        Product.price,
        func.rank().over(
            order_by=Product.price.desc()
        ).label('price_rank')
    ).all()
    # Partitioned window: rank products by price within each category
    results = session.query(
        Product.id,
        Product.name,
        Product.category_id,
        Product.price,
        func.row_number().over(
            partition_by=Product.category_id,
            order_by=Product.price.desc()
        ).label('category_rank')
    ).all()
    # LAG/LEAD: the value from the previous/next row
    results = session.query(
        User.id,
        User.username,
        User.created_at,
        func.lag(User.created_at, 1).over(
            order_by=User.created_at
        ).label('prev_created_at')
    ).all()
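7. Advanced Features
7.1 Transaction Management
from sqlalchemy.exc import SQLAlchemyError
# Basic transaction handling
with Session() as session:
    try:
        # A transaction begins automatically on first use
        user1 = User(username='user1', email='user1@example.com', password_hash='hashed_pw')
        user2 = User(username='user2', email='user2@example.com', password_hash='hashed_pw')
        session.add(user1)
        session.add(user2)
        # Commit the transaction
        session.commit()
        print("Transaction committed")
    except SQLAlchemyError as e:
        # Roll the transaction back
        session.rollback()
        print(f"Transaction rolled back: {e}")
        raise
# Nested transaction (SAVEPOINT)
# Note: with SQLite's default pysqlite driver, SAVEPOINT needs the event-hook workaround
# described in SQLAlchemy's SQLite dialect documentation
with Session() as session:
    try:
        user = User(username='main_user', email='main@example.com', password_hash='hashed_pw')
        session.add(user)
        # Create a savepoint (pending objects are flushed first)
        savepoint = session.begin_nested()
        try:
            # An operation that may fail: this duplicate username violates the UNIQUE constraint
            risky_user = User(username='main_user', email='risky@example.com', password_hash='hashed_pw')
            session.add(risky_user)
            session.flush()  # emit the SQL now without committing
            # On success, release the savepoint
            savepoint.commit()
        except Exception as e:
            # Roll back to the savepoint only; main_user is kept
            savepoint.rollback()
            print(f"Rolled back to savepoint: {e}")
        # Commit the outer transaction
        session.commit()
    except Exception as e:
        session.rollback()
        print(f"Outer transaction rolled back: {e}")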
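A SAVEPOINT lets one failed insert be discarded while the rest of the transaction commits. The sketch below assumes SQLAlchemy 1.4+ and uses the context-manager form of begin_nested(). The two event hooks follow the workaround SQLAlchemy's documentation gives for pysqlite's SAVEPOINT handling, and they are only needed on SQLite.

```python
# Sketch: SAVEPOINT discarding one failed insert while the outer transaction commits.
# Assumes SQLAlchemy 1.4+; the event hooks are the documented pysqlite workaround.
from sqlalchemy import Column, Integer, String, create_engine, event, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class User(Base):
    __tablename__ = "users"
    id = Column(Integer, primary_key=True)
    username = Column(String(50), unique=True, nullable=False)

engine = create_engine("sqlite:///:memory:")

@event.listens_for(engine, "connect")
def _no_implicit_begin(dbapi_conn, conn_record):
    dbapi_conn.isolation_level = None  # stop pysqlite from issuing its own BEGIN

@event.listens_for(engine, "begin")
def _explicit_begin(conn):
    conn.exec_driver_sql("BEGIN")  # let SQLAlchemy emit BEGIN itself

Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(User(username="main_user"))
    try:
        with session.begin_nested():  # SAVEPOINT, released if the block succeeds
            session.add(User(username="main_user"))  # duplicate -> IntegrityError on flush
    except IntegrityError:
        pass  # the context manager already rolled back to the savepoint
    session.commit()  # the outer transaction still commits main_user
    names = session.execute(select(User.username)).scalars().all()

print(names)  # ['main_user']
```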
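7.2 Lazy Loading and Eager Loading
from sqlalchemy.orm import joinedload, selectinload, subqueryload
# These examples assume the User.products relationship from 6.4
# (and, in the last one, a hypothetical Product.reviews relationship)
# Lazy loading (the default)
with Session() as session:
    user = session.query(User).first()
    # Related objects are queried only when accessed
    products = user.products  # this line triggers an extra SQL query
    # In a loop over N users this becomes the "N+1 query" problem
# Eager loading with joinedload (uses a JOIN)
with Session() as session:
    users = session.query(User).options(
        joinedload(User.products)
    ).all()
    # One query returns both the users and their products
    for user in users:
        print(f"User {user.username} has {len(user.products)} products")
# Eager loading with selectinload (uses an IN query)
with Session() as session:
    users = session.query(User).options(
        selectinload(User.products)
    ).all()
    # Two queries: first the users, then their products via IN (...)
# Eager loading with subqueryload (uses a subquery)
with Session() as session:
    users = session.query(User).options(
        subqueryload(User.products)
    ).all()
    # The related rows are loaded by a second query that wraps the original one as a subquery
# Multi-level eager loading
with Session() as session:
    users = session.query(User).options(
        joinedload(User.products).joinedload(Product.category)
    ).all()
    # One query fetches users, products and categories
# Chained eager loading combined with a filter
with Session() as session:
    users = session.query(User).options(
        selectinload(User.products).selectinload(Product.reviews)
    ).filter(User.is_active == True).all()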
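You can measure the difference between lazy and eager loading directly by counting the statements sent to the driver. This is a sketch assuming SQLAlchemy 1.4+ with trimmed User/Product models; count_queries() is a hypothetical helper. With three users, lazy loading issues 1 + 3 queries, while selectinload issues 2.

```python
# Sketch: counting SQL statements for lazy loading (N+1) vs selectinload (2).
# Assumes SQLAlchemy 1.4+; count_queries() is a hypothetical helper.
from sqlalchemy import Column, ForeignKey, Integer, String, create_engine, event, select
from sqlalchemy.orm import Session, declarative_base, relationship, selectinload

Base = declarative_base()

class User(Base):
    __tablename__ = "users"
    id = Column(Integer, primary_key=True)
    username = Column(String(50))
    products = relationship("Product")

class Product(Base):
    __tablename__ = "products"
    id = Column(Integer, primary_key=True)
    user_id = Column(Integer, ForeignKey("users.id"))

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add_all([User(username=f"u{i}", products=[Product(), Product()]) for i in range(3)])
    session.commit()

statements = []

@event.listens_for(engine, "before_cursor_execute")
def _record(conn, cursor, statement, parameters, context, executemany):
    statements.append(statement)  # every SQL statement sent to the driver

def count_queries(options=()):
    statements.clear()
    with Session(engine) as session:
        users = session.execute(select(User).options(*options)).scalars().all()
        for user in users:
            len(user.products)  # touch the collection, as a template or API would
    return len(statements)

lazy = count_queries()                                # 1 for users + 1 per user
eager = count_queries([selectinload(User.products)])  # 1 for users + 1 IN query
print(lazy, eager)  # 4 2
```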
7.3 Caching
# First-level cache: the Session's identity map
with Session() as session:
    # First lookup: loaded from the database
    user1 = session.get(User, 1)
    # Second lookup: served from the identity map without another query
    user2 = session.get(User, 1)
    # Both names refer to the same instance
    assert user1 is user2
    print("Identity map hit")
# A session.query(...).filter(...) still runs SQL every time; the identity map only
# guarantees that the same row comes back as the same object
# "Second-level" cache: SQLAlchemy has no built-in one, so an application-level cache
# (like the simple one below) or an external cache such as Redis is used instead
from functools import wraps
# Simple in-memory cache
class SimpleCache:
    def __init__(self):
        self._cache = {}
    def get(self, key):
        return self._cache.get(key)
    def set(self, key, value):
        self._cache[key] = value
    def delete(self, key):
        self._cache.pop(key, None)
cache = SimpleCache()
# Query-caching decorator (call cache.delete(key) whenever the underlying data changes)
def cached_query(cache_key):
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            # Try the cache first
            result = cache.get(cache_key)
            if result is not None:
                print(f"Cache hit: {cache_key}")
                return result
            # Run the query
            result = fn(*args, **kwargs)
            # Store the result
            cache.set(cache_key, result)
            print(f"Cached: {cache_key}")
            return result
        return wrapper
    return decorator
@cached_query('all_active_users')
def get_active_users(session):
    # Cache plain data rather than ORM objects: cached instances outlive their session,
    # and expired attributes can no longer be loaded once it is closed
    return [
        {'id': u.id, 'username': u.username}
        for u in session.query(User).filter(User.is_active == True).all()
    ]
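The difference between Session.get() and a select() with respect to the identity map can be confirmed by counting statements. A sketch assuming SQLAlchemy 1.4+ and a trimmed User model: both calls return the same Python object, but only the select() goes back to the database.

```python
# Sketch: the identity map saves a query for Session.get() but not for select().
# Assumes SQLAlchemy 1.4+; trimmed User model on in-memory SQLite.
from sqlalchemy import Column, Integer, String, create_engine, event, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class User(Base):
    __tablename__ = "users"
    id = Column(Integer, primary_key=True)
    username = Column(String(50))

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add(User(id=1, username="alice"))
    session.commit()

statements = []

@event.listens_for(engine, "before_cursor_execute")
def _record(conn, cursor, statement, parameters, context, executemany):
    statements.append(statement)

with Session(engine) as session:
    a = session.get(User, 1)  # SELECT ... WHERE users.id = ?
    b = session.get(User, 1)  # identity map hit: no SQL
    after_get = len(statements)
    c = session.execute(select(User).where(User.id == 1)).scalar_one()  # always runs SQL
    after_select = len(statements)

print(a is b is c, after_get, after_select)  # True 1 2
```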
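7.4 Database Migrations
# Database migrations with Alembic
# Install it first: pip install alembic
# Initialize the migration environment
# alembic init migrations
# Autogenerate a migration script (requires target_metadata = Base.metadata in migrations/env.py)
# alembic revision --autogenerate -m "create users table"
# Apply all pending migrations
# alembic upgrade head
# Roll back one migration
# alembic downgrade -1
# Running migrations from Python
from alembic import command
from alembic.config import Config
def run_migrations():
    """Upgrade the database to the latest revision"""
    alembic_cfg = Config("alembic.ini")
    command.upgrade(alembic_cfg, "head")
def create_migration(message):
    """Autogenerate a new migration file"""
    alembic_cfg = Config("alembic.ini")
    command.revision(alembic_cfg, autogenerate=True, message=message)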
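7.5 Async Support
# asyncio support exists since SQLAlchemy 1.4; async_sessionmaker requires 2.0
# (on 1.4, use sessionmaker(async_engine, class_=AsyncSession) instead)
# An async driver is required, e.g. pip install asyncpg
from sqlalchemy import select
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.ext.asyncio import async_sessionmaker
import asyncio
# Create the async engine
async_engine = create_async_engine(
    "postgresql+asyncpg://user:password@localhost/dbname",
    echo=True
)
# Create the async session factory
AsyncSessionLocal = async_sessionmaker(
    async_engine,
    class_=AsyncSession,
    expire_on_commit=False  # keep loaded attributes usable after commit (implicit refresh does not work under asyncio)
)
# Async data access
async def create_user_async(username: str, email: str):
    """Create a user asynchronously"""
    async with AsyncSessionLocal() as session:
        user = User(username=username, email=email, password_hash='hashed_pw')
        session.add(user)
        await session.commit()
        return user
async def get_users_async():
    """Fetch active users asynchronously"""
    async with AsyncSessionLocal() as session:
        result = await session.execute(
            select(User).filter(User.is_active == True)
        )
        return result.scalars().all()
# Entry point
async def main():
    # Create the tables (run_sync runs the synchronous create_all on the async connection)
    async with async_engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    # Create a user
    user = await create_user_async("async_user", "async@example.com")
    print(f"Created user: {user.username}")
    # Query users
    users = await get_users_async()
    print(f"Found {len(users)} users")
# Run the async code
# asyncio.run(main())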
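8. Hands-on Project Example
8.1 Project Structure
blog_project/
├── main.py            # application entry point
├── database.py        # database configuration
├── models.py          # data models
├── schemas.py         # data validation schemas
├── crud.py            # database operations
├── api.py             # API endpoints
└── requirements.txt   # dependencies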
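8.2 Database Configuration (database.py)
from sqlalchemy import create_engine
from sqlalchemy.orm import declarative_base, sessionmaker
import os
# Database URL (can be overridden with the DATABASE_URL environment variable)
DATABASE_URL = os.getenv(
    "DATABASE_URL",
    "sqlite:///./blog.db"
)
# SQLite connections are tied to the creating thread by default; FastAPI runs sync
# endpoints in a thread pool, so that check must be disabled for SQLite
connect_args = {"check_same_thread": False} if DATABASE_URL.startswith("sqlite") else {}
# Create the engine
engine = create_engine(
    DATABASE_URL,
    echo=True,  # log SQL (development only)
    pool_pre_ping=True,
    pool_recycle=3600,
    connect_args=connect_args
)
# Session factory
SessionLocal = sessionmaker(
    autocommit=False,
    autoflush=False,
    bind=engine
)
# Declarative base
Base = declarative_base()
# Dependency injection: one session per request, always closed afterwards
def get_db():
    """Provide a database session"""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
# Create all tables
def create_tables():
    """Create the database tables"""
    # The models module must be imported before this runs, so its tables are registered on Base.metadata
    Base.metadata.create_all(bind=engine)
# Drop all tables
def drop_tables():
    """Drop the database tables"""
    Base.metadata.drop_all(bind=engine)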
8.3 Data Models (models.py)
from sqlalchemy import Column, Integer, String, Text, DateTime, Boolean, ForeignKey, Table
from sqlalchemy.orm import relationship
from datetime import datetime
from database import Base
# Association table for the many-to-many article/tag relationship
article_tags = Table(
    'article_tags',
    Base.metadata,
    Column('article_id', Integer, ForeignKey('articles.id')),
    Column('tag_id', Integer, ForeignKey('tags.id'))
)
class User(Base):
    """User model"""
    __tablename__ = 'users'
    id = Column(Integer, primary_key=True, index=True)
    username = Column(String(50), unique=True, index=True, nullable=False)
    email = Column(String(100), unique=True, index=True, nullable=False)
    hashed_password = Column(String(100), nullable=False)
    full_name = Column(String(100))
    is_active = Column(Boolean, default=True)
    is_superuser = Column(Boolean, default=False)
    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
    # Relationships: a user has many articles and comments; deleting the user deletes them too
    articles = relationship("Article", back_populates="author", cascade="all, delete-orphan")
    comments = relationship("Comment", back_populates="author", cascade="all, delete-orphan")

    def __repr__(self):
        return f"<User(id={self.id}, username='{self.username}')>"
class Category(Base):
    """Category model"""
    __tablename__ = 'categories'
    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(50), unique=True, nullable=False)
    description = Column(Text)
    created_at = Column(DateTime, default=datetime.now)
    # Relationship: a category has many articles
    articles = relationship("Article", back_populates="category")

    def __repr__(self):
        return f"<Category(id={self.id}, name='{self.name}')>"
class Tag(Base):
    """Tag model"""
    __tablename__ = 'tags'
    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(30), unique=True, nullable=False)
    created_at = Column(DateTime, default=datetime.now)
    # Many-to-many: tags and articles
    articles = relationship("Article", secondary=article_tags, back_populates="tags")

    def __repr__(self):
        return f"<Tag(id={self.id}, name='{self.name}')>"
class Article(Base):
    """Article model"""
    __tablename__ = 'articles'
    id = Column(Integer, primary_key=True, index=True)
    title = Column(String(200), nullable=False, index=True)
    content = Column(Text, nullable=False)
    summary = Column(String(500))
    is_published = Column(Boolean, default=False)
    view_count = Column(Integer, default=0)
    created_at = Column(DateTime, default=datetime.now, index=True)
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
    published_at = Column(DateTime)
    # Foreign keys
    author_id = Column(Integer, ForeignKey('users.id'), nullable=False)
    category_id = Column(Integer, ForeignKey('categories.id'))
    # Relationships
    author = relationship("User", back_populates="articles")
    category = relationship("Category", back_populates="articles")
    tags = relationship("Tag", secondary=article_tags, back_populates="articles")
    comments = relationship("Comment", back_populates="article", cascade="all, delete-orphan")

    def __repr__(self):
        return f"<Article(id={self.id}, title='{self.title[:30]}...')>"
class Comment(Base):
    """Comment model"""
    __tablename__ = 'comments'
    id = Column(Integer, primary_key=True, index=True)
    content = Column(Text, nullable=False)
    is_approved = Column(Boolean, default=False)
    created_at = Column(DateTime, default=datetime.now)
    # Foreign keys
    article_id = Column(Integer, ForeignKey('articles.id'), nullable=False)
    author_id = Column(Integer, ForeignKey('users.id'), nullable=False)
    parent_id = Column(Integer, ForeignKey('comments.id'))  # the comment this one replies to
    # Relationships
    article = relationship("Article", back_populates="comments")
    author = relationship("User", back_populates="comments")
    parent = relationship("Comment", remote_side=[id])  # self-referential many-to-one

    def __repr__(self):
        return f"<Comment(id={self.id}, content='{self.content[:30]}...')>"
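The remote_side=[id] argument is what turns Comment.parent into a many-to-one self-reference. This sketch assumes SQLAlchemy 1.4+ and a trimmed Comment model. It adds a hypothetical replies collection, which is not in the model above, to show the reverse one-to-many direction.

```python
# Sketch: adjacency-list self-reference with remote_side, plus a hypothetical
# replies collection for the reverse direction. Assumes SQLAlchemy 1.4+.
from sqlalchemy import Column, ForeignKey, Integer, Text, create_engine, select
from sqlalchemy.orm import Session, declarative_base, relationship

Base = declarative_base()

class Comment(Base):
    __tablename__ = "comments"
    id = Column(Integer, primary_key=True)
    content = Column(Text, nullable=False)
    parent_id = Column(Integer, ForeignKey("comments.id"))
    # remote_side=[id]: the *other* row's id is the "one" side, so parent is many-to-one
    parent = relationship("Comment", remote_side=[id], back_populates="replies")
    replies = relationship("Comment", back_populates="parent")  # hypothetical addition

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    root = Comment(content="Great post")
    session.add(Comment(content="Thanks!", parent=root))  # root is saved via the cascade
    session.commit()
    top_level = session.execute(
        select(Comment).where(Comment.parent_id.is_(None))).scalar_one()
    reply_texts = [c.content for c in top_level.replies]

print(top_level.content, reply_texts)  # Great post ['Thanks!']
```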
8.4 Data Access Layer (crud.py)
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_, func, desc, extract
from models import User, Article, Category, Tag, Comment
from typing import List, Optional
from datetime import datetime
class UserCRUD:
    """User data operations"""
    @staticmethod
    def create_user(db: Session, username: str, email: str, password: str, full_name: Optional[str] = None) -> User:
        """Create a user"""
        user = User(
            username=username,
            email=email,
            hashed_password=password,  # real applications must store a salted hash, never the plain password
            full_name=full_name
        )
        db.add(user)
        db.commit()
        db.refresh(user)
        return user
    @staticmethod
    def get_user_by_id(db: Session, user_id: int) -> Optional[User]:
        """Get a user by ID"""
        return db.query(User).filter(User.id == user_id).first()
    @staticmethod
    def get_user_by_username(db: Session, username: str) -> Optional[User]:
        """Get a user by username"""
        return db.query(User).filter(User.username == username).first()
    @staticmethod
    def get_users(db: Session, skip: int = 0, limit: int = 100) -> List[User]:
        """List users"""
        return db.query(User).offset(skip).limit(limit).all()
    @staticmethod
    def update_user(db: Session, user_id: int, **kwargs) -> Optional[User]:
        """Update user fields"""
        user = db.query(User).filter(User.id == user_id).first()
        if user:
            for key, value in kwargs.items():
                if hasattr(user, key):  # ignore unknown field names
                    setattr(user, key, value)
            user.updated_at = datetime.now()
            db.commit()
            db.refresh(user)
        return user
    @staticmethod
    def delete_user(db: Session, user_id: int) -> bool:
        """Delete a user (the cascade also deletes the user's articles and comments)"""
        user = db.query(User).filter(User.id == user_id).first()
        if user:
            db.delete(user)
            db.commit()
            return True
        return False
class ArticleCRUD:
    """Article data operations"""
    @staticmethod
    def create_article(db: Session, title: str, content: str, author_id: int,
                       category_id: Optional[int] = None, tag_names: Optional[List[str]] = None) -> Article:
        """Create an article"""
        article = Article(
            title=title,
            content=content,
            author_id=author_id,
            category_id=category_id,
            summary=content[:200] + "..." if len(content) > 200 else content
        )
        # Attach tags, creating the ones that do not exist yet
        if tag_names:
            for tag_name in dict.fromkeys(tag_names):  # drop duplicate names, keep their order
                tag = db.query(Tag).filter(Tag.name == tag_name).first()
                if not tag:
                    tag = Tag(name=tag_name)
                    db.add(tag)
                article.tags.append(tag)
        db.add(article)
        db.commit()
        db.refresh(article)
        return article
    @staticmethod
    def get_article_by_id(db: Session, article_id: int) -> Optional[Article]:
        """Get an article by ID"""
        return db.query(Article).filter(Article.id == article_id).first()
    @staticmethod
    def get_articles(db: Session, skip: int = 0, limit: int = 20,
                     published_only: bool = True) -> List[Article]:
        """List articles, newest first"""
        query = db.query(Article)
        if published_only:
            query = query.filter(Article.is_published == True)
        return query.order_by(desc(Article.created_at)).offset(skip).limit(limit).all()
    @staticmethod
    def get_articles_by_category(db: Session, category_id: int,
                                 skip: int = 0, limit: int = 20) -> List[Article]:
        """List published articles in a category"""
        return db.query(Article).filter(
            and_(Article.category_id == category_id, Article.is_published == True)
        ).order_by(desc(Article.created_at)).offset(skip).limit(limit).all()
    @staticmethod
    def get_articles_by_tag(db: Session, tag_name: str,
                            skip: int = 0, limit: int = 20) -> List[Article]:
        """List published articles with a given tag"""
        return db.query(Article).join(Article.tags).filter(
            and_(Tag.name == tag_name, Article.is_published == True)
        ).order_by(desc(Article.created_at)).offset(skip).limit(limit).all()
    @staticmethod
    def search_articles(db: Session, keyword: str, skip: int = 0, limit: int = 20) -> List[Article]:
        """Search published articles by title or content (LIKE '%keyword%')"""
        return db.query(Article).filter(
            and_(
                or_(
                    Article.title.contains(keyword),
                    Article.content.contains(keyword)
                ),
                Article.is_published == True
            )
        ).order_by(desc(Article.created_at)).offset(skip).limit(limit).all()
    @staticmethod
    def publish_article(db: Session, article_id: int) -> Optional[Article]:
        """Publish an article"""
        article = db.query(Article).filter(Article.id == article_id).first()
        if article:
            article.is_published = True
            article.published_at = datetime.now()
            db.commit()
            db.refresh(article)
        return article
    @staticmethod
    def increment_view_count(db: Session, article_id: int) -> Optional[Article]:
        """Increment an article's view count"""
        article = db.query(Article).filter(Article.id == article_id).first()
        if article:
            # Let the database compute view_count + 1 so concurrent requests do not overwrite each other
            article.view_count = Article.view_count + 1
            db.commit()
            db.refresh(article)
        return article
class StatisticsCRUD:
    """Statistics queries"""
    @staticmethod
    def get_article_stats(db: Session) -> dict:
        """Overall article statistics"""
        total_articles = db.query(func.count(Article.id)).scalar()
        published_articles = db.query(func.count(Article.id)).filter(
            Article.is_published == True
        ).scalar()
        total_views = db.query(func.sum(Article.view_count)).scalar() or 0
        return {
            "total_articles": total_articles,
            "published_articles": published_articles,
            "draft_articles": total_articles - published_articles,
            "total_views": total_views
        }
    @staticmethod
    def get_popular_articles(db: Session, limit: int = 10) -> List[Article]:
        """Most viewed published articles"""
        return db.query(Article).filter(
            Article.is_published == True
        ).order_by(desc(Article.view_count)).limit(limit).all()
    @staticmethod
    def get_category_stats(db: Session) -> List[dict]:
        """Article count per category (categories without articles count as 0)"""
        results = db.query(
            Category.name,
            func.count(Article.id).label('article_count')
        ).outerjoin(Article).group_by(Category.id, Category.name).all()
        return [
            {"category": name, "count": count}
            for name, count in results
        ]
    @staticmethod
    def get_monthly_article_stats(db: Session, year: int) -> List[dict]:
        """Published articles per month for a given year"""
        # extract() renders EXTRACT(... FROM ...) or the dialect's equivalent (strftime on SQLite)
        results = db.query(
            extract('month', Article.created_at).label('month'),
            func.count(Article.id).label('count')
        ).filter(
            and_(
                extract('year', Article.created_at) == year,
                Article.is_published == True
            )
        ).group_by(extract('month', Article.created_at)).all()
        return [
            {"month": int(month), "count": count}
            for month, count in results
        ]
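create_user() stores the password as given and leaves hashing to "real applications". Below is a sketch of that missing step using only the standard library (PBKDF2-HMAC-SHA256). hash_password and verify_password are hypothetical helpers, the iteration count is an assumed work factor, and a maintained library such as passlib or bcrypt is a common alternative.

```python
# Sketch: salted password hashing with the standard library (PBKDF2-HMAC-SHA256).
# hash_password/verify_password are hypothetical helpers; ITERATIONS is an assumed work factor.
import hashlib
import hmac
import secrets

ITERATIONS = 600_000

def hash_password(password: str) -> str:
    """Return 'salt$hash' as hex strings (97 characters, fits String(100))."""
    salt = secrets.token_bytes(16)
    digest = hashlib.pbkdf2_hmac("sha256", password.encode(), salt, ITERATIONS)
    return f"{salt.hex()}${digest.hex()}"

def verify_password(password: str, stored: str) -> bool:
    """Recompute the hash with the stored salt and compare in constant time."""
    salt_hex, digest_hex = stored.split("$")
    digest = hashlib.pbkdf2_hmac("sha256", password.encode(), bytes.fromhex(salt_hex), ITERATIONS)
    return hmac.compare_digest(digest.hex(), digest_hex)

stored = hash_password("s3cret")
print(len(stored), verify_password("s3cret", stored), verify_password("wrong", stored))
# 97 True False
```

In UserCRUD.create_user this would become hashed_password=hash_password(password).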
8.5 API接口层 (api.py)
from fastapi import FastAPI, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from database import get_db, create_tables
from crud import UserCRUD, ArticleCRUD, StatisticsCRUD
from typing import List, Optional
import uvicorn
# Create the FastAPI application
app = FastAPI(
    title="Blog System API",
    description="A blog system built on SQLAlchemy",
    version="1.0.0"
)
# Create tables at startup
# (recent FastAPI versions deprecate on_event in favor of a lifespan handler, but it still works)
@app.on_event("startup")
def startup_event():
    create_tables()
# User endpoints
@app.post("/users/", summary="Create user")
def create_user(
    username: str,
    email: str,
    password: str,  # demo only: a real API should receive credentials in a request body, not the query string
    full_name: Optional[str] = None,
    db: Session = Depends(get_db)
):
    """Create a new user"""
    # Reject duplicate usernames
    existing_user = UserCRUD.get_user_by_username(db, username)
    if existing_user:
        raise HTTPException(status_code=400, detail="Username already exists")
    user = UserCRUD.create_user(db, username, email, password, full_name)
    return {"message": "User created", "user_id": user.id}
@app.get("/users/{user_id}", summary="Get user")
def get_user(user_id: int, db: Session = Depends(get_db)):
    """Get a user by ID"""
    user = UserCRUD.get_user_by_id(db, user_id)
    if not user:
        raise HTTPException(status_code=404, detail="User not found")
    return {
        "id": user.id,
        "username": user.username,
        "email": user.email,
        "full_name": user.full_name,
        "is_active": user.is_active,
        "created_at": user.created_at
    }
@app.get("/users/", summary="List users")
def get_users(
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    db: Session = Depends(get_db)
):
    """List users"""
    users = UserCRUD.get_users(db, skip, limit)
    return {
        "users": [
            {
                "id": user.id,
                "username": user.username,
                "email": user.email,
                "full_name": user.full_name,
                "is_active": user.is_active
            }
            for user in users
        ],
        "total": len(users)  # number of items on this page, not the total row count
    }
# Article endpoints
@app.post("/articles/", summary="Create article")
def create_article(
    title: str,
    content: str,
    author_id: int,
    category_id: Optional[int] = None,
    # List parameters need Query(); otherwise FastAPI expects them in the request body
    tag_names: Optional[List[str]] = Query(None),
    db: Session = Depends(get_db)
):
    """Create a new article"""
    # Make sure the author exists
    author = UserCRUD.get_user_by_id(db, author_id)
    if not author:
        raise HTTPException(status_code=404, detail="Author not found")
    article = ArticleCRUD.create_article(
        db, title, content, author_id, category_id, tag_names or []
    )
    return {"message": "Article created", "article_id": article.id}
@app.get("/articles/{article_id}", summary="Get article")
def get_article(article_id: int, db: Session = Depends(get_db)):
    """Get article details"""
    article = ArticleCRUD.get_article_by_id(db, article_id)
    if not article:
        raise HTTPException(status_code=404, detail="Article not found")
    # Increment the view count
    ArticleCRUD.increment_view_count(db, article_id)
    return {
        "id": article.id,
        "title": article.title,
        "content": article.content,
        "summary": article.summary,
        "is_published": article.is_published,
        "view_count": article.view_count,
        "created_at": article.created_at,
        "author": {
            "id": article.author.id,
            "username": article.author.username
        },
        "category": {
            "id": article.category.id,
            "name": article.category.name
        } if article.category else None,
        "tags": [
            {"id": tag.id, "name": tag.name}
            for tag in article.tags
        ]
    }
@app.get("/articles/", summary="List articles")
def get_articles(
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    published_only: bool = Query(True),
    db: Session = Depends(get_db)
):
    """List articles"""
    articles = ArticleCRUD.get_articles(db, skip, limit, published_only)
    return {
        "articles": [
            {
                "id": article.id,
                "title": article.title,
                "summary": article.summary,
                "view_count": article.view_count,
                "created_at": article.created_at,
                "author": article.author.username
            }
            for article in articles
        ],
        "total": len(articles)
    }
@app.get("/search/articles/", summary="Search articles")
def search_articles(
    keyword: str = Query(..., min_length=1),
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    db: Session = Depends(get_db)
):
    """Search articles"""
    articles = ArticleCRUD.search_articles(db, keyword, skip, limit)
    return {
        "keyword": keyword,
        "articles": [
            {
                "id": article.id,
                "title": article.title,
                "summary": article.summary,
                "view_count": article.view_count,
                "created_at": article.created_at
            }
            for article in articles
        ],
        "total": len(articles)
    }
# Statistics endpoints
@app.get("/statistics/overview", summary="Statistics overview")
def get_statistics_overview(db: Session = Depends(get_db)):
    """Get a statistics overview"""
    stats = StatisticsCRUD.get_article_stats(db)
    category_stats = StatisticsCRUD.get_category_stats(db)
    popular_articles = StatisticsCRUD.get_popular_articles(db, 5)
    return {
        "article_stats": stats,
        "category_stats": category_stats,
        "popular_articles": [
            {
                "id": article.id,
                "title": article.title,
                "view_count": article.view_count
            }
            for article in popular_articles
        ]
    }
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
8.6 Application Entry Point (main.py)
from database import create_tables, get_db
from crud import UserCRUD, ArticleCRUD
from models import User, Article, Category, Tag
from datetime import datetime
def init_sample_data():
    """Load sample data"""
    # Take one session from the get_db() generator; it is closed explicitly in finally
    db = next(get_db())
    try:
        # Create sample users
        admin_user = UserCRUD.create_user(
            db, "admin", "admin@blog.com", "hashed_password", "Administrator"
        )
        author_user = UserCRUD.create_user(
            db, "author", "author@blog.com", "hashed_password", "Author"
        )
        # Create sample categories
        tech_category = Category(name="Tech", description="Technical articles")
        life_category = Category(name="Life", description="Notes on everyday life")
        db.add_all([tech_category, life_category])
        db.commit()
        # Create sample articles
        article1 = ArticleCRUD.create_article(
            db,
            title="Getting Started with SQLAlchemy",
            content="A detailed tutorial on SQLAlchemy...",
            author_id=author_user.id,
            category_id=tech_category.id,
            tag_names=["Python", "Database", "ORM"]
        )
        article2 = ArticleCRUD.create_article(
            db,
            title="My Programming Journey",
            content="What I learned while learning to program...",
            author_id=author_user.id,
            category_id=life_category.id,
            tag_names=["Programming", "Reflections"]
        )
        # Publish the articles
        ArticleCRUD.publish_article(db, article1.id)
        ArticleCRUD.publish_article(db, article2.id)
        print("Sample data loaded!")
    except Exception as e:
        print(f"Failed to load sample data: {e}")
        # Undo any partial writes from this attempt
        db.rollback()
    finally:
        db.close()
def main():
    """Entry point"""
    print("=== SQLAlchemy Blog System ===")
    # Create the database tables
    print("Creating database tables...")
    create_tables()
    # Load sample data
    print("Loading sample data...")
    init_sample_data()
    print("\nSetup complete!")
    print("API docs: http://localhost:8000/docs")
    print("Start the API server with: python api.py")
if __name__ == "__main__":
    main()
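Because init_sample_data() creates rows unconditionally, running main.py a second time will typically fail on existing data and roll back. A common remedy is a "get or create" lookup before inserting. The sketch below is a minimal illustration, assuming a simplified stand-in Category model and an in-memory SQLite database; get_or_create and seed are hypothetical helpers, not part of SQLAlchemy or the blog project.

```python
from sqlalchemy import Column, Integer, String, create_engine, func, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Category(Base):
    # Simplified stand-in for the blog's Category model (assumption)
    __tablename__ = "categories"
    id = Column(Integer, primary_key=True)
    name = Column(String(50), unique=True, nullable=False)
    description = Column(String(200))

def get_or_create(session, model, defaults=None, **lookup):
    """Hypothetical helper: return (instance, created) for the row matching lookup."""
    instance = session.execute(select(model).filter_by(**lookup)).scalar_one_or_none()
    if instance is not None:
        return instance, False
    instance = model(**lookup, **(defaults or {}))
    session.add(instance)
    return instance, True

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

def seed():
    """Hypothetical seeding step that is safe to run repeatedly."""
    with Session(engine) as session:
        _, created = get_or_create(
            session, Category, name="Tech", defaults={"description": "Technical articles"}
        )
        session.commit()
        return created

first_run = seed()   # inserts the row
second_run = seed()  # finds the existing row, inserts nothing

with Session(engine) as session:
    category_count = session.execute(select(func.count()).select_from(Category)).scalar_one()
```

Note that a check-then-insert like this is not safe against two processes seeding concurrently; the unique constraint remains the final safeguard.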
9. Performance Optimization and Best Practices
9.1 Query Optimization
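# 1. Use indexes
from sqlalchemy import Boolean, Column, DateTime, Index, Integer, String
class User(Base):
    __tablename__ = 'users'
    id = Column(Integer, primary_key=True)
    username = Column(String(50), index=True)  # single-column index
    email = Column(String(100), index=True)
    is_active = Column(Boolean, default=True)  # required by idx_active_created below
    created_at = Column(DateTime, index=True)
    # Composite indexes (a composite index led by username also serves
    # username-only lookups, so the single-column username index is largely redundant)
    __table_args__ = (
        Index('idx_username_email', 'username', 'email'),
        Index('idx_active_created', 'is_active', 'created_at'),
    )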
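# 2. Eager-load related data
# joinedload (one JOIN) suits many-to-one; selectinload (a second IN query) suits collections
from sqlalchemy.orm import joinedload, selectinload
# Avoid the N+1 query problem
with Session() as session:
    # Wrong: produces N+1 queries
    users = session.query(User).all()
    for user in users:
        print(f"{user.username}: {len(user.articles)} articles")  # lazy load: one query per user
    # Right: eager-load the collection
    users = session.query(User).options(
        selectinload(User.articles)
    ).all()
    for user in users:
        print(f"{user.username}: {len(user.articles)} articles")  # no extra queries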
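To see the N+1 effect concretely, one can count the statements SQLAlchemy sends to the driver. This is a sketch, assuming simplified stand-in User and Article models, an in-memory SQLite database, and a hypothetical count_statements helper; the counting relies on the documented before_cursor_execute engine event.

```python
from sqlalchemy import Column, ForeignKey, Integer, String, create_engine, event, select
from sqlalchemy.orm import Session, declarative_base, relationship, selectinload

Base = declarative_base()

class User(Base):
    # Simplified stand-in models (assumption)
    __tablename__ = "users"
    id = Column(Integer, primary_key=True)
    username = Column(String(50), nullable=False)
    articles = relationship("Article", back_populates="author")

class Article(Base):
    __tablename__ = "articles"
    id = Column(Integer, primary_key=True)
    title = Column(String(100), nullable=False)
    author_id = Column(Integer, ForeignKey("users.id"), nullable=False)
    author = relationship("User", back_populates="articles")

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add_all(
        [User(username=f"user{i}", articles=[Article(title=f"post {i}")]) for i in range(3)]
    )
    session.commit()

statements = []

@event.listens_for(engine, "before_cursor_execute")
def record_statement(conn, cursor, statement, parameters, context, executemany):
    statements.append(statement)

def count_statements(eager):
    """Hypothetical helper: load users, touch each user's articles, count SQL statements."""
    statements.clear()
    with Session(engine) as session:
        stmt = select(User)
        if eager:
            stmt = stmt.options(selectinload(User.articles))
        users = session.execute(stmt).scalars().all()
        for user in users:
            len(user.articles)
    return len(statements)

lazy_count = count_statements(eager=False)   # 1 users query + 1 lazy load per user
eager_count = count_statements(eager=True)   # 1 users query + 1 IN query for all articles
```

With three users the lazy version issues four statements and the eager version two; the gap grows linearly with the number of parent rows.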
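# 3. Use bulk operations
from datetime import datetime
with Session() as session:
    # Bulk insert (a legacy API in SQLAlchemy 2.0, which prefers session.execute(insert(User), users_data))
    users_data = [
        {'username': f'user{i}', 'email': f'user{i}@example.com'}
        for i in range(1000)
    ]
    session.bulk_insert_mappings(User, users_data)
    # Bulk update: one UPDATE statement, no objects loaded into the session
    session.query(User).filter(
        User.created_at < datetime(2023, 1, 1)
    ).update({'is_active': False})
    session.commit()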
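In SQLAlchemy 2.0 style, the same bulk insert and bulk update can be written with the insert() and update() constructs passed to session.execute(). The sketch below assumes a simplified stand-in User model and an in-memory SQLite database; the row data is made up for illustration.

```python
from datetime import datetime

from sqlalchemy import Boolean, Column, DateTime, Integer, String, create_engine, func, insert, select, update
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class User(Base):
    # Simplified stand-in model (assumption)
    __tablename__ = "users"
    id = Column(Integer, primary_key=True)
    username = Column(String(50), nullable=False)
    email = Column(String(100))
    is_active = Column(Boolean, default=True, nullable=False)
    created_at = Column(DateTime, nullable=False)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

# Illustrative rows: even i -> created in 2022, odd i -> created in 2023
rows = [
    {"username": f"user{i}", "email": f"user{i}@example.com", "created_at": datetime(2022 + i % 2, 6, 1)}
    for i in range(1000)
]

with Session(engine) as session:
    # A list of dicts as the second argument runs the INSERT as an executemany
    session.execute(insert(User), rows)
    # One UPDATE statement; no User objects are loaded
    session.execute(
        update(User).where(User.created_at < datetime(2023, 1, 1)).values(is_active=False)
    )
    session.commit()
    inactive = session.execute(
        select(func.count()).select_from(User).where(User.is_active == False)  # noqa: E712
    ).scalar_one()
```

Half of the rows fall before 2023, so 500 users end up inactive.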
# 4. Use raw SQL for complex queries
from sqlalchemy import text
with Session() as session:
    # Complex statistics via raw SQL
    # The is_published filter belongs in the ON clause: in WHERE it would discard categories
    # without published articles and silently turn the LEFT JOIN into an inner join
    result = session.execute(text("""
        SELECT
            c.name AS category_name,
            COUNT(a.id) AS article_count,
            AVG(a.view_count) AS avg_views
        FROM categories c
        LEFT JOIN articles a ON c.id = a.category_id AND a.is_published = true
        GROUP BY c.id, c.name
        ORDER BY article_count DESC
    """))
    for row in result:
        print(f"Category: {row.category_name}, articles: {row.article_count}, average views: {row.avg_views}")
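The difference between filtering in WHERE and in ON is easy to check directly. This sketch uses only the standard-library sqlite3 module with a made-up three-category dataset (an assumption for illustration), so it runs without SQLAlchemy.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE categories (id INTEGER PRIMARY KEY, name TEXT);
CREATE TABLE articles (id INTEGER PRIMARY KEY, category_id INTEGER,
                       is_published INTEGER, view_count INTEGER);
INSERT INTO categories VALUES (1, 'Tech'), (2, 'Life'), (3, 'Empty');
-- Tech: one published + one draft; Life: one draft; Empty: no articles
INSERT INTO articles VALUES (1, 1, 1, 100), (2, 1, 0, 50), (3, 2, 0, 10);
""")

# Filter in WHERE: rows with NULL article columns are discarded after the join
where_version = conn.execute("""
    SELECT c.name, COUNT(a.id) FROM categories c
    LEFT JOIN articles a ON c.id = a.category_id
    WHERE a.is_published = 1
    GROUP BY c.id, c.name ORDER BY c.id
""").fetchall()

# Filter in ON: every category survives, with a count of 0 where nothing matches
on_version = conn.execute("""
    SELECT c.name, COUNT(a.id) FROM categories c
    LEFT JOIN articles a ON c.id = a.category_id AND a.is_published = 1
    GROUP BY c.id, c.name ORDER BY c.id
""").fetchall()
```

Only the ON version reports "Life" and "Empty" with zero published articles.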
9.2 Connection Pool Tuning
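from sqlalchemy import create_engine
from sqlalchemy.pool import QueuePool
# Placeholder URL; replace with your own connection string
DATABASE_URL = "mysql+pymysql://username:password@localhost:3306/database_name"
# Production pool configuration
engine = create_engine(
    DATABASE_URL,
    # Pool settings
    poolclass=QueuePool,  # already the default for MySQL and PostgreSQL
    pool_size=20,  # connections kept open in the pool
    max_overflow=30,  # extra connections allowed beyond pool_size
    pool_timeout=60,  # seconds to wait for a free connection
    pool_recycle=3600,  # replace connections older than this many seconds
    pool_pre_ping=True,  # test each connection when it is checked out
    # Performance
    echo=False,  # turn off SQL logging in production
    future=True,  # 2.0-style API on SQLAlchemy 1.4; the only mode in 2.0
    # Driver (PyMySQL) arguments
    connect_args={
        "charset": "utf8mb4",
        "autocommit": False
        # Do not add "check_same_thread" here: it is a SQLite-only argument and PyMySQL rejects it
    }
)
# Monitor pool status
def monitor_pool_status(engine):
    """Print connection pool status (QueuePool)"""
    pool = engine.pool
    print(f"Pool size: {pool.size()}")
    print(f"Checked-out connections: {pool.checkedout()}")
    print(f"Idle (checked-in) connections: {pool.checkedin()}")
    print(f"Overflow counter: {pool.overflow()}")  # can be negative until pool_size connections exist
    print(pool.status())  # one-line summary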
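The pool counters can be observed without a MySQL server by forcing a QueuePool onto an in-memory SQLite engine. This is only a sketch for watching checkout and checkin; pool_snapshot is a hypothetical helper, and SQLite would not normally be run with these pool settings.

```python
from sqlalchemy import create_engine, text
from sqlalchemy.pool import QueuePool

# QueuePool forced onto SQLite purely to observe the counters (assumption for the demo)
engine = create_engine("sqlite://", poolclass=QueuePool, pool_size=2, max_overflow=1)

def pool_snapshot(engine):
    """Hypothetical helper: collect QueuePool counters into a dict."""
    pool = engine.pool
    return {"size": pool.size(), "checked_out": pool.checkedout(), "checked_in": pool.checkedin()}

with engine.connect() as conn:
    conn.execute(text("SELECT 1"))
    during = pool_snapshot(engine)  # the connection is in use here

after = pool_snapshot(engine)  # leaving the with block returned it to the pool
```

While the connection is held it counts as checked out; once the block exits it moves to the idle (checked-in) queue rather than being closed.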
9.3 Caching Strategy
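import redis
import json
import hashlib
from functools import wraps
from typing import Any, List, Optional
from sqlalchemy.orm import Session
# Redis cache configuration
redis_client = redis.Redis(
    host='localhost',
    port=6379,
    db=0,
    decode_responses=True
)
def cache_result(key_prefix: str, expire_time: int = 3600):
    """Decorator that caches a function's JSON-serializable result in Redis"""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Build the cache key. Session arguments are skipped (their repr differs per request),
            # and a SHA-256 digest replaces hash(), which is randomized per Python process
            key_args = [a for a in args if not isinstance(a, Session)]
            key_kwargs = {k: v for k, v in kwargs.items() if not isinstance(v, Session)}
            raw_key = json.dumps([key_args, key_kwargs], sort_keys=True, default=str)
            cache_key = f"{key_prefix}:{hashlib.sha256(raw_key.encode('utf-8')).hexdigest()}"
            # Try the cache first
            cached_result = redis_client.get(cache_key)
            if cached_result is not None:
                return json.loads(cached_result)
            # Cache miss: call the function
            result = func(*args, **kwargs)
            # Store the result with a TTL
            redis_client.setex(
                cache_key,
                expire_time,
                json.dumps(result, default=str)
            )
            return result
        return wrapper
    return decorator
# Using the cache
class CachedArticleCRUD(ArticleCRUD):
    """Article operations with caching"""
    @staticmethod
    @cache_result("hot_articles", 1800)  # cache for 30 minutes
    def get_hot_articles(db: Session, limit: int = 10) -> List[dict]:
        """Get hot articles (cached); returns plain dicts because ORM objects are not JSON-serializable"""
        articles = db.query(Article).filter(
            Article.is_published == True
        ).order_by(Article.view_count.desc()).limit(limit).all()
        return [
            {
                "id": article.id,
                "title": article.title,
                "view_count": article.view_count
            }
            for article in articles
        ]
    @staticmethod
    def invalidate_article_cache(article_id: int):
        """Delete the cache entries related to an article"""
        patterns = [
            f"article:{article_id}:*",
            "hot_articles:*",
            "recent_articles:*"
        ]
        for pattern in patterns:
            # SCAN instead of KEYS: KEYS blocks Redis while it walks the entire keyspace
            keys = list(redis_client.scan_iter(match=pattern))
            if keys:
                redis_client.delete(*keys)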
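The key-building step is worth isolating, because a key that changes between calls or between processes silently disables the cache. The sketch below is a hypothetical make_cache_key helper using only the standard library; it assumes the arguments are JSON-friendly (anything else falls back to str()).

```python
import hashlib
import json

def make_cache_key(prefix, *args, **kwargs):
    """Hypothetical helper: build a cache key that is stable across calls and processes.

    Built-in hash() on strings is salted per process, so two workers would compute
    different keys for the same call; a SHA-256 digest of canonical JSON does not.
    """
    payload = json.dumps({"args": args, "kwargs": kwargs}, sort_keys=True, default=str)
    return f"{prefix}:{hashlib.sha256(payload.encode('utf-8')).hexdigest()}"

key_a = make_cache_key("hot_articles", limit=10, category="tech")
key_b = make_cache_key("hot_articles", category="tech", limit=10)  # same call, different kwarg order
key_c = make_cache_key("hot_articles", limit=5)
```

sort_keys=True makes keyword order irrelevant, so key_a and key_b match while a different limit yields a different key. Callers should still leave per-request objects such as a Session out of the arguments.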
9.4 Pagination Optimization
from sqlalchemy import func
from typing import Tuple, List
class PaginationHelper:
    """Pagination helpers"""
    @staticmethod
    def paginate_query(query, page: int, per_page: int) -> Tuple[List, dict]:
        """Offset/limit pagination"""
        # Total row count (costs an extra COUNT query)
        total = query.count()
        # Page metadata
        total_pages = (total + per_page - 1) // per_page
        has_prev = page > 1
        has_next = page < total_pages
        # Rows for the current page; deep OFFSETs get slower because skipped rows are still scanned
        items = query.offset((page - 1) * per_page).limit(per_page).all()
        pagination_info = {
            "page": page,
            "per_page": per_page,
            "total": total,
            "total_pages": total_pages,
            "has_prev": has_prev,
            "has_next": has_next,
            "prev_page": page - 1 if has_prev else None,
            "next_page": page + 1 if has_next else None
        }
        return items, pagination_info
    @staticmethod
    def cursor_paginate(query, cursor_field, cursor_value=None, limit: int = 20, desc: bool = True):
        """Cursor (keyset) pagination, suited to large tables"""
        # cursor_field should be unique: rows that share the boundary value
        # can otherwise be skipped between pages
        if cursor_value is not None:
            if desc:
                query = query.filter(cursor_field < cursor_value)
            else:
                query = query.filter(cursor_field > cursor_value)
        if desc:
            query = query.order_by(cursor_field.desc())
        else:
            query = query.order_by(cursor_field.asc())
        # Fetch one extra row to learn whether another page exists
        items = query.limit(limit + 1).all()
        has_more = len(items) > limit
        if has_more:
            items = items[:-1]
        next_cursor = None
        if has_more and items:
            # .key is the mapped attribute name, which may differ from the column name
            next_cursor = getattr(items[-1], cursor_field.key)
        return items, {
            "has_more": has_more,
            "next_cursor": next_cursor,
            "limit": limit
        }
# Usage
with Session() as session:
    # Offset pagination
    query = session.query(Article).filter(Article.is_published == True)
    articles, pagination = PaginationHelper.paginate_query(query, page=1, per_page=20)
    # Cursor pagination (suits feeds where new rows keep arriving)
    articles, cursor_info = PaginationHelper.cursor_paginate(
        query, Article.created_at, limit=20
    )
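When the cursor column is not unique (created_at often is not), a common fix is a composite cursor made of the sort column plus the primary key as a tie-breaker. The sketch below assumes a minimal stand-in Article model and an in-memory SQLite database; keyset_page is a hypothetical helper, not a SQLAlchemy API.

```python
from datetime import datetime

from sqlalchemy import Column, DateTime, Integer, and_, create_engine, or_, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Article(Base):
    # Minimal stand-in model (assumption)
    __tablename__ = "articles"
    id = Column(Integer, primary_key=True)
    created_at = Column(DateTime, nullable=False)

def keyset_page(session, cursor=None, limit=2):
    """Hypothetical helper: one page ordered by (created_at DESC, id DESC).

    cursor is the (created_at, id) pair of the last row on the previous page.
    """
    stmt = select(Article).order_by(Article.created_at.desc(), Article.id.desc())
    if cursor is not None:
        created, last_id = cursor
        # Strictly "after" the cursor in this ordering, with id breaking ties
        stmt = stmt.where(or_(
            Article.created_at < created,
            and_(Article.created_at == created, Article.id < last_id),
        ))
    rows = session.execute(stmt.limit(limit + 1)).scalars().all()
    items = rows[:limit]
    next_cursor = (items[-1].created_at, items[-1].id) if len(rows) > limit else None
    return items, next_cursor

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
tie = datetime(2024, 1, 1, 12, 0)
with Session(engine) as session:
    # Four articles share one timestamp, the kind of tie a single-column cursor mishandles
    session.add_all([Article(id=i, created_at=tie) for i in range(1, 5)]
                    + [Article(id=5, created_at=datetime(2024, 1, 2))])
    session.commit()
    seen, cursor = [], None
    while True:
        items, cursor = keyset_page(session, cursor)
        seen.extend(a.id for a in items)
        if cursor is None:
            break
```

Every row is visited exactly once, in the order 5, 4, 3, 2, 1, even though four rows share the same created_at. A composite index on (created_at, id) keeps each page lookup cheap.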
10. Common Problems and Solutions
10.1 Common Errors and Fixes
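# 1. Fixing DetachedInstanceError
from sqlalchemy.orm import joinedload
with Session() as session:
    user = session.query(User).first()
# The session is now closed; lazy-loading a relationship here raises DetachedInstanceError:
# user.articles
# Fix 1: load everything you need while the session is still open
with Session() as session:
    user = session.query(User).options(
        joinedload(User.articles)
    ).first()
    # Access the related data inside the session
    articles = user.articles
# Fix 2: detach with expunge() (or expunge_all()), then reattach with merge()
with Session() as session:
    user = session.query(User).first()
    session.expunge(user)  # remove the object from the session
# Reattach in a new session
with Session() as session:
    user = session.merge(user)  # returns an instance attached to this session
    articles = user.articles
# 2. Handling IntegrityError (constraint violations)
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
try:
    with Session() as session:
        user = User(username="existing_user", email="test@example.com")
        session.add(user)
        session.commit()
except IntegrityError as e:
    print(f"Constraint violation: {e}")
    # Handle duplicate data or other constraint violations here
# 3. Avoiding PendingRollbackError
# After a failed flush or commit, the session must be rolled back before it is used again
with Session() as session:
    try:
        user = User(username=None)  # violates a NOT NULL constraint
        session.add(user)
        session.commit()
    except SQLAlchemyError as e:
        session.rollback()  # required: without it, the next use of this session raises PendingRollbackError
        print(f"Operation failed and was rolled back: {e}")
# 4. Comparing with NULL
with Session() as session:
    # User.id == None also compiles to IS NULL, but linters flag it (E711);
    # is_(None) and is_not(None) state the intent explicitly
    session.query(User).filter(User.id.is_(None))
    session.query(User).filter(User.id.is_not(None))  # spelled isnot() before SQLAlchemy 1.4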
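The rollback rule from point 3 can be verified end to end: after a constraint violation, an explicit rollback() makes the same session usable again. This sketch assumes a simplified stand-in User model with a unique username and an in-memory SQLite database.

```python
from sqlalchemy import Column, Integer, String, create_engine, func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class User(Base):
    # Simplified stand-in model (assumption)
    __tablename__ = "users"
    id = Column(Integer, primary_key=True)
    username = Column(String(50), unique=True, nullable=False)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(User(username="alice"))
    session.commit()
    try:
        session.add(User(username="alice"))  # duplicate username violates the UNIQUE constraint
        session.commit()
        duplicate_rejected = False
    except IntegrityError:
        session.rollback()  # resets the session; skipping this leads to PendingRollbackError
        duplicate_rejected = True
    # The same session works again after the rollback
    user_count = session.execute(select(func.count()).select_from(User)).scalar_one()
```

The duplicate is rejected, the first row survives, and the follow-up COUNT runs normally on the same session.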
10.2 Diagnosing Performance Problems
import time
from sqlalchemy import event
from sqlalchemy.engine import Engine
# Time every SQL statement (these listeners apply to all Engine instances)
@event.listens_for(Engine, "before_cursor_execute")
def receive_before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
    context._query_start_time = time.perf_counter()
@event.listens_for(Engine, "after_cursor_execute")
def receive_after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
    total = time.perf_counter() - context._query_start_time
    if total > 0.1:  # report queries slower than 100 ms
        print(f"Slow query ({total:.3f}s): {statement[:100]}...")
# Query analyzer
class QueryAnalyzer:
    """Records every statement executed on the session's engine inside a with block"""
    def __init__(self, session):
        self.session = session
        self.queries = []
    def __enter__(self):
        # Start listening
        event.listen(self.session.bind, "before_cursor_execute", self._before_execute)
        event.listen(self.session.bind, "after_cursor_execute", self._after_execute)
        return self
    def __exit__(self, exc_type, exc_val, exc_tb):
        # Stop listening
        event.remove(self.session.bind, "before_cursor_execute", self._before_execute)
        event.remove(self.session.bind, "after_cursor_execute", self._after_execute)
        # Print the results
        self.print_analysis()
    def _before_execute(self, conn, cursor, statement, parameters, context, executemany):
        context._start_time = time.perf_counter()
    def _after_execute(self, conn, cursor, statement, parameters, context, executemany):
        duration = time.perf_counter() - context._start_time
        self.queries.append({
            'statement': statement,
            'duration': duration,
            'parameters': parameters
        })
    def print_analysis(self):
        """Print the analysis"""
        total_queries = len(self.queries)
        total_time = sum(q['duration'] for q in self.queries)
        avg_time = total_time / total_queries if total_queries > 0 else 0
        print(f"\n=== Query analysis ===")
        print(f"Total queries: {total_queries}")
        print(f"Total time: {total_time:.3f}s")
        print(f"Average time: {avg_time:.3f}s")
        # Show the slowest queries
        slow_queries = sorted(self.queries, key=lambda x: x['duration'], reverse=True)[:5]
        print(f"\nTop 5 slowest queries:")
        for i, query in enumerate(slow_queries, 1):
            print(f"{i}. {query['duration']:.3f}s - {query['statement'][:100]}...")
# Using the analyzer
with Session() as session:
    with QueryAnalyzer(session) as analyzer:
        # Run the queries to analyze
        users = session.query(User).all()
        for user in users:
            articles = user.articles  # lazy loading here can cause N+1 queries
10.3 Ensuring Data Consistency
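import re
from datetime import datetime
from sqlalchemy import Boolean, Column, DateTime, String, event
from sqlalchemy.orm import Session, validates, with_loader_criteria
# Maintain a timestamp automatically
@event.listens_for(User, 'before_update')
def update_timestamp(mapper, connection, target):
    """Set updated_at when a modified User is flushed (assumes User has an updated_at column)"""
    target.updated_at = datetime.now()
# Soft delete
class SoftDeleteMixin:
    """Soft-delete mixin"""
    deleted_at = Column(DateTime, nullable=True)
    is_deleted = Column(Boolean, default=False)
    def soft_delete(self):
        """Mark the row as deleted"""
        self.is_deleted = True
        self.deleted_at = datetime.now()
    def restore(self):
        """Undo a soft delete"""
        self.is_deleted = False
        self.deleted_at = None
class User(SoftDeleteMixin, Base):
    __tablename__ = 'users'
    # ... other columns
# Filter out soft-deleted rows on every ORM SELECT automatically (SQLAlchemy 1.4+);
# opt out for a single query with .execution_options(include_deleted=True)
@event.listens_for(Session, 'do_orm_execute')
def auto_filter_deleted(execute_state):
    """Add is_deleted == False to ORM SELECTs on all SoftDeleteMixin classes"""
    if (
        execute_state.is_select
        and not execute_state.is_column_load
        and not execute_state.is_relationship_load
        and not execute_state.execution_options.get("include_deleted", False)
    ):
        execute_state.statement = execute_state.statement.options(
            with_loader_criteria(
                SoftDeleteMixin,
                lambda cls: cls.is_deleted == False,
                include_aliases=True
            )
        )
# Data validation
class User(Base):
    __tablename__ = 'users'
    # ... other columns (shown standalone; in practice add the validator to your existing User model)
    email = Column(String(100), nullable=False)
    @validates('email')
    def validate_email(self, key, address):
        """Validate the email format"""
        pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
        if not re.match(pattern, address):
            raise ValueError("Invalid email format")
        return address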
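The do_orm_execute plus with_loader_criteria pattern for soft delete can be tried in isolation. This sketch assumes a hypothetical Note model using the mixin and an in-memory SQLite database; include_deleted is an application-chosen execution option name, not a built-in SQLAlchemy flag.

```python
from sqlalchemy import Boolean, Column, Integer, String, create_engine, event, select
from sqlalchemy.orm import Session, declarative_base, with_loader_criteria

Base = declarative_base()

class SoftDeleteMixin:
    is_deleted = Column(Boolean, default=False, nullable=False)

class Note(SoftDeleteMixin, Base):
    # Hypothetical model used only for this demo
    __tablename__ = "notes"
    id = Column(Integer, primary_key=True)
    title = Column(String(50))

@event.listens_for(Session, "do_orm_execute")
def hide_soft_deleted(execute_state):
    # Skip attribute refreshes and relationship loads; the criteria propagate to those anyway
    if (
        execute_state.is_select
        and not execute_state.is_column_load
        and not execute_state.is_relationship_load
        and not execute_state.execution_options.get("include_deleted", False)
    ):
        execute_state.statement = execute_state.statement.options(
            with_loader_criteria(
                SoftDeleteMixin, lambda cls: cls.is_deleted == False, include_aliases=True  # noqa: E712
            )
        )

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add_all([Note(title="keep"), Note(title="gone", is_deleted=True)])
    session.commit()
    visible = [n.title for n in session.execute(select(Note)).scalars()]
    everything = sorted(
        n.title
        for n in session.execute(select(Note).execution_options(include_deleted=True)).scalars()
    )
```

A plain select(Note) returns only the live row, while the opt-out option returns both.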
10.4 Debugging Tips
# 1. Enable SQL logging
from sqlalchemy import create_engine, text
engine = create_engine('sqlite:///example.db', echo=True)
# 2. Configure logging through the standard logging module
import logging
logging.basicConfig()
logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)
# 3. Inspect the generated SQL
from sqlalchemy.dialects import mysql, postgresql, sqlite
with Session() as session:
    query = session.query(User).filter(User.is_active == True)
    # Compile the same query for different databases
    print("MySQL SQL:", query.statement.compile(dialect=mysql.dialect()))
    print("PostgreSQL SQL:", query.statement.compile(dialect=postgresql.dialect()))
    print("SQLite SQL:", query.statement.compile(dialect=sqlite.dialect()))
# 4. Analyze the query plan
def explain_query(session, query):
    """Print a query's execution plan (SQLite syntax; MySQL and PostgreSQL use plain EXPLAIN)"""
    # literal_binds inlines parameter values: use it for debugging only, never with untrusted input
    sql = str(query.statement.compile(
        dialect=session.get_bind().dialect,
        compile_kwargs={"literal_binds": True}
    ))
    # SQLAlchemy 2.0 requires raw SQL strings to be wrapped in text()
    result = session.execute(text(f"EXPLAIN QUERY PLAN {sql}"))
    print("Query plan:")
    for row in result:
        print(row)
# Usage
with Session() as session:
    query = session.query(User).filter(User.is_active == True)
    explain_query(session, query)
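Combining the last two tips makes it possible to check whether a query actually uses an index. The sketch below assumes a simplified stand-in User model with an indexed username column and an in-memory SQLite database; the plan text format comes from SQLite and varies slightly between versions, but it names the index it uses.

```python
from sqlalchemy import Column, Integer, String, create_engine, select, text
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class User(Base):
    # Simplified stand-in model; index=True creates an index named ix_users_username
    __tablename__ = "users"
    id = Column(Integer, primary_key=True)
    username = Column(String(50), index=True)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

stmt = select(User).where(User.username == "alice")
# Render the statement for this engine's dialect with the parameter inlined (debugging only)
sql = str(stmt.compile(dialect=engine.dialect, compile_kwargs={"literal_binds": True}))

with Session(engine) as session:
    plan = session.execute(text(f"EXPLAIN QUERY PLAN {sql}")).all()

for row in plan:
    print(row)
```

If the detail column shows a full SCAN of the table instead of a SEARCH using an index, the filter is not covered by any index.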
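Summary
This tutorial covered SQLAlchemy's core concepts, basic usage, and advanced features. After working through it, you should be able to:
- Understand the SQLAlchemy architecture: the roles of the ORM, Core, and Engine layers and how they relate
- Use the basic features confidently: model definitions, CRUD operations, and query syntax
- Apply advanced features: complex queries, relationship mapping, transaction management, and performance optimization
- Follow best practices: connection pool configuration, caching, error handling, and debugging
- Build a real project: the blog system example walks through a complete development workflow
Study Tips
- Go step by step: start with the basic concepts and move on to advanced features gradually
- Practice: run the example code and change parameters to see what happens
- Read the documentation: use the official docs to understand the details
- Build projects: apply what you learn in real work
- Keep learning: follow SQLAlchemy releases and pick up new features
Further Directions
- Async programming: SQLAlchemy's asyncio support
- Microservices: using SQLAlchemy in distributed systems
- Database optimization: deeper database performance tuning
- Framework integration: Flask, FastAPI, Django, and other frameworks
- Schema migrations: database versioning with Alembic
We hope this tutorial helps you master SQLAlchemy and handle database work in Python with confidence!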