装饰器在Python中的作用
1. 装饰器是什么?为什么它很重要?
装饰器(Decorator)是Python中的一种高级语法,用于在不修改原函数代码的情况下,动态增强函数的功能。它的核心思想是**"装饰"现有函数**,类似于给手机套壳——手机本身功能不变,但多了保护或附加功能。
1.1 装饰器的核心作用
- 代码复用:避免重复写相同的逻辑(如日志、计时、权限检查)
- 非侵入式扩展:不改动原函数代码就能添加功能
- 提高可读性:通过
@decorator
语法,明确功能增强意图
2. 装饰器在PyTorch中的实战案例
2.1 案例1:函数执行计时器
在模型训练中,经常需要统计某个函数的运行时间:
import time
import torch
from functools import wraps
def timer(func):
@wraps(func) # 保留原函数的元信息
def wrapper(*args, **kwargs):
start = time.time()
result = func(*args, **kwargs)
end = time.time()
print(f"{func.__name__} executed in {end - start:.4f}s")
return result
return wrapper
# 使用装饰器统计训练耗时
@timer
def train_one_epoch(model, dataloader, optimizer):
model.train()
for data, target in dataloader:
optimizer.zero_grad()
output = model(data)
loss = torch.nn.functional.cross_entropy(output, target)
loss.backward()
optimizer.step()
# 调用时会自动打印执行时间
train_one_epoch(model, train_loader, optim.Adam(model.parameters()))
输出示例:
train_one_epoch executed in 12.3456s
2.2 案例2:自动切换模型状态
在PyTorch中,训练和评估模式需要手动切换,用装饰器可以自动化:
def set_mode(mode='train'):
def decorator(func):
@wraps(func)
def wrapper(model, *args, **kwargs):
if mode == 'train':
model.train()
else:
model.eval()
return func(model, *args, **kwargs)
return wrapper
return decorator
# 训练时自动切换为train模式
@set_mode('train')
def custom_train_step(model, data):
# ...训练逻辑
pass
# 评估时自动切换为eval模式
@set_mode('eval')
def custom_eval_step(model, data):
# ...评估逻辑
pass
3. 装饰器在MMDetection中的高级应用
MMDetection作为目标检测框架,大量使用装饰器实现模块化设计。
3.1 案例1:注册自定义模块
MMDetection通过@MODELS.register_module()
装饰器实现插件化架构:
from mmdet.models import MODELS
@MODELS.register_module() # 注册自定义Backbone
class MyBackbone(nn.Module):
def __init__(self, depth=50):
super().__init__()
# ...自定义实现
# 配置文件中可直接使用
cfg = dict(
backbone=dict(type='MyBackbone', depth=101) # 直接调用注册的类
)
3.2 案例2:Hook机制增强训练流程
MMDetection用装饰器实现训练Hook(如学习率调整):
from mmcv.runner import HOOKS, Hook
@HOOKS.register_module() # 注册自定义Hook
class MyCustomHook(Hook):
def before_train_epoch(self, runner):
print(f"Before epoch {runner.epoch}!")
# 配置中添加Hook
custom_hooks = [
dict(type='MyCustomHook', priority='NORMAL')
]
4. 装饰器的底层原理
理解装饰器需要掌握三个关键概念:
- 函数是一等公民:可以像变量一样传递
- 闭包(Closure):内层函数记住外层作用域
- 语法糖
@
:@decorator
等价于func = decorator(func)
执行流程:
@timer
def foo(): pass
# 等价于
foo = timer(foo)