知识点回顾
- 回调函数
- lambda函数
- hook函数的模块钩子和张量钩子
- Grad-CAM的示例
作业:理解下今天的代码即可
# 定义一个回调函数
def handle_result(result):
"""处理计算结果的回调函数"""
print(f"计算结果是: {result}")
# 定义一个接受回调函数的函数
def calculate(a, b, callback): # callback是一个约定俗成的参数名
"""
这个函数接受两个数值和一个回调函数,用于处理计算结果。
执行计算并调用回调函数
"""
result = a + b
callback(result) # 在计算完成后调用回调函数
# 使用回调函数
calculate(3, 5, handle_result) # 输出: 计算结果是: 8
def handle_result(result):
"""处理计算结果的回调函数"""
print(f"计算结果是: {result}")
def with_callback(callback):
"""装饰器工厂:创建一个将计算结果传递给回调函数的装饰器"""
def decorator(func):
"""实际的装饰器,用于包装目标函数"""
def wrapper(a, b):
"""被装饰后的函数,执行计算并调用回调"""
result = func(a, b) # 执行原始计算
callback(result) # 调用回调函数处理结果
return result # 返回计算结果(可选)
return wrapper
return decorator
# 使用装饰器包装原始计算函数
@with_callback(handle_result)
def calculate(a, b):
"""执行加法计算"""
return a + b
# 直接调用被装饰后的函数
calculate(3, 5) # 输出: 计算结果是: 8
前向钩子
import torch
import torch.nn as nn
# 定义一个简单的卷积神经网络模型
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
# 定义卷积层:输入通道1,输出通道2,卷积核3x3,填充1保持尺寸不变
self.conv = nn.Conv2d(1, 2, kernel_size=3, padding=1)
# 定义ReLU激活函数
self.relu = nn.ReLU()
# 定义全连接层:输入特征2*4*4,输出10分类
self.fc = nn.Linear(2 * 4 * 4, 10)
def forward(self, x):
# 卷积操作
x = self.conv(x)
# 激活函数
x = self.relu(x)
# 展平为一维向量,准备输入全连接层
x = x.view(-1, 2 * 4 * 4)
# 全连接分类
x = self.fc(x)
return x
# 创建模型实例
model = SimpleModel()
# 创建一个列表用于存储中间层的输出
conv_outputs = []
# 定义前向钩子函数 - 用于在模型前向传播过程中获取中间层信息
def forward_hook(module, input, output):
"""
前向钩子函数,会在模块每次执行前向传播后被自动调用
参数:
module: 当前应用钩子的模块实例
input: 传递给该模块的输入张量元组
output: 该模块产生的输出张量
"""
print(f"钩子被调用!模块类型: {type(module)}")
print(f"输入形状: {input[0].shape}") # input是一个元组,对应 (image, label)
print(f"输出形状: {output.shape}")
# 保存卷积层的输出用于后续分析
# 使用detach()避免追踪梯度,防止内存泄漏
conv_outputs.append(output.detach())
# 在卷积层注册前向钩子
# register_forward_hook返回一个句柄,用于后续移除钩子
hook_handle = model.conv.register_forward_hook(forward_hook)
# 创建一个随机输入张量 (批次大小=1, 通道=1, 高度=4, 宽度=4)
x = torch.randn(1, 1, 4, 4)
# 执行前向传播 - 此时会自动触发钩子函数
output = model(x)
# 释放钩子 - 重要!防止在后续模型使用中持续调用钩子造成意外行为或内存泄漏
hook_handle.remove()
# # 打印中间层输出结果
# if conv_outputs:
# print(f"\n卷积层输出形状: {conv_outputs[0].shape}")
# print(f"卷积层输出值示例: {conv_outputs[0][0, 0, :, :]}")
# 让我们可视化卷积层的输出
if conv_outputs:
plt.figure(figsize=(10, 5))
# 原始输入图像
plt.subplot(1, 3, 1)
plt.title('输入图像')
plt.imshow(x[0, 0].detach().numpy(), cmap='gray') # 显示灰度图像
# 第一个卷积核的输出
plt.subplot(1, 3, 2)
plt.title('卷积核1输出')
plt.imshow(conv_outputs[0][0, 0].detach().numpy(), cmap='gray')
# 第二个卷积核的输出
plt.subplot(1, 3, 3)
plt.title('卷积核2输出')
plt.imshow(conv_outputs[0][0, 1].detach().numpy(), cmap='gray')
plt.tight_layout()
plt.show()