DAY 42 Grad-CAM与Hook函数
1.回调函数
def handle_result(result):
    """Print a computed result on its own line.

    Args:
        result: Value to display; rendered with its default string form.
    """
    # The replacement field is written without padding so that no stray
    # trailing space ends up in the printed text.
    print(f'计算结果是: {result}')
def calculate(a, b, callback):
    """Add ``a`` and ``b`` and hand the sum to ``callback``.

    Args:
        a: Left operand.
        b: Right operand.
        callback: Callable invoked once with the sum as its only argument.

    Returns:
        None. The sum is delivered only through ``callback``.
    """
    callback(a + b)
# Pass the function object itself (no parentheses) so calculate can invoke it
# with the sum once the addition is done.
calculate( 3 , 5 , handle_result)
计算结果是: 8
def handle_result(result):
    """Report a result by printing it.

    Args:
        result: Value to show after the fixed label.
    """
    # No whitespace inside or after the braces: padding there would leak a
    # trailing space into the output line.
    print(f'计算结果是: {result}')
def with_callback(callback):
    """Build a decorator that reports each return value to ``callback``.

    Args:
        callback: Callable invoked with the decorated function's return
            value after every call.

    Returns:
        A decorator. The function it produces forwards all positional and
        keyword arguments to the original function, passes the result to
        ``callback``, and then returns that same result.
    """
    # Imported locally so this block does not depend on a file-level import.
    from functools import wraps

    def decorator(func):
        # wraps() keeps the wrapped function's __name__ and __doc__ intact.
        @wraps(func)
        def wrapper(*args, **kwargs):
            result = func(*args, **kwargs)
            callback(result)
            return result

        return wrapper

    return decorator
# Decorator form of the callback pattern; equivalent to
# calculate = with_callback(handle_result)(calculate).
@with_callback(handle_result)
def calculate(a, b):
    """Return the sum of ``a`` and ``b``."""
    total = a + b
    return total


calculate(3, 5)
计算结果是: 8
8
2.lambda函数
# A lambda is an anonymous single-expression function; binding it to a name
# lets it be called like a regular function.
square = lambda x: x ** 2
squared_five = square(5)
print(squared_five)
25
3.hook函数的模块钩子和张量钩子
import torch
import torch. nn as nn
import numpy as np
import matplotlib. pyplot as plt
# Seed the PyTorch and NumPy random number generators with a fixed value.
torch. manual_seed( 42 )
np. random. seed( 42 )
import torch
import torch. nn as nn
class SimpleModel(nn.Module):
    """Minimal CNN: one 3x3 convolution, a ReLU, then a linear layer.

    The convolution maps 1 channel to 2 channels with padding 1, and the
    linear layer takes ``2 * 4 * 4`` features, so the forward pass flattens
    each sample to 32 values before producing 10 outputs.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 2, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(2 * 4 * 4, 10)

    def forward(self, x):
        """Run conv -> ReLU -> flatten -> linear and return the 10 outputs."""
        feature_maps = self.relu(self.conv(x))
        # Flatten to rows of 2 * 4 * 4 values for the linear layer.
        flat = feature_maps.view(-1, 2 * 4 * 4)
        return self.fc(flat)
model = SimpleModel()

# Outputs captured by the hook; one entry is appended per hook call.
conv_outputs = []


def forward_hook(module, input, output):
    """Forward hook: log the shapes and keep a detached copy of the output.

    Args:
        module: The module the hook is attached to.
        input: Indexable collection of the module's inputs; only the first
            entry is inspected.
        output: The module's output tensor.
    """
    print(f'钩子被调用!模块类型: {type(module)}')
    print(f'输入形状: {input[0].shape}')
    print(f'输出形状: {output.shape}')
    # detach() stores the values without keeping the autograd graph alive.
    conv_outputs.append(output.detach())


hook_handle = model.conv.register_forward_hook(forward_hook)
x = torch.randn(1, 1, 4, 4)
try:
    output = model(x)
finally:
    # Always unregister, even if the forward pass fails, so the hook does not
    # linger on the layer and fire during later passes.
    hook_handle.remove()
钩子被调用!模块类型: <class 'torch.nn.modules.conv.Conv2d'>
输入形状: torch.Size([1, 1, 4, 4])
输出形状: torch.Size([1, 2, 4, 4])
import warnings
import matplotlib. pyplot as plt
# Silence warnings and configure matplotlib for the Chinese titles below
# (SimHei font; plain ASCII minus sign on the axes).
warnings.filterwarnings('ignore')
plt.rcParams.update({'font.family': ['SimHei'], 'axes.unicode_minus': False})

# Only plot when the forward hook actually captured something.
if conv_outputs:
    captured = conv_outputs[0]
    panels = [
        ('输入图像', x[0, 0]),
        ('卷积核1输出', captured[0, 0]),
        ('卷积核2输出', captured[0, 1]),
    ]
    plt.figure(figsize=(10, 5))
    for position, (panel_title, panel_data) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.title(panel_title)
        plt.imshow(panel_data.detach().numpy(), cmap='gray')
    plt.tight_layout()
    plt.show()
# Each entry is the (grad_input, grad_output) pair seen by one hook call.
conv_gradients = []


def backward_hook(module, grad_input, grad_output):
    """Backward hook: log how many gradients arrive and record them.

    Args:
        module: The module the hook is attached to.
        grad_input: Sized collection of gradients reported for the inputs.
        grad_output: Sized collection of gradients reported for the outputs.
    """
    print(f'反向钩子被调用!模块类型: {type(module)}')
    print(f'输入梯度数量: {len(grad_input)}')
    print(f'输出梯度数量: {len(grad_output)}')
    conv_gradients.append((grad_input, grad_output))


# NOTE(review): PyTorch documents register_backward_hook as deprecated in
# favour of register_full_backward_hook. The two may report different
# grad_input tuples, so confirm the expected counts before migrating.
hook_handle = model.conv.register_backward_hook(backward_hook)
x = torch.randn(1, 1, 4, 4, requires_grad=True)
try:
    output = model(x)
    loss = output.sum()
    loss.backward()
finally:
    # Unregister even when the forward or backward pass raises.
    hook_handle.remove()
反向钩子被调用!模块类型: <class 'torch.nn.modules.conv.Conv2d'>
输入梯度数量: 3
输出梯度数量: 1
# Chain: x -> y = x^2 -> z = y^3, so the gradient reaching y is 3 * y^2.
x = torch.tensor([2.0], requires_grad=True)
y = x ** 2
z = y ** 3


def tensor_hook(grad):
    """Tensor hook: print the incoming gradient and return half of it.

    Args:
        grad: Gradient arriving at the tensor the hook is registered on.

    Returns:
        ``grad / 2``. A value returned from a tensor hook replaces the
        gradient that continues to propagate backwards.
    """
    print(f'原始梯度: {grad}')
    return grad / 2


hook_handle = y.register_hook(tensor_hook)
try:
    z.backward()
    print(f'x的梯度: {x.grad}')
finally:
    # Unregister the hook whether or not the backward pass succeeded.
    hook_handle.remove()
原始梯度: tensor([48.])
x的梯度: tensor([96.])
4.Grad-CAM的示例
import torch
import torch. nn as nn
import torch. nn. functional as F
import torchvision
import torchvision. transforms as transforms
import numpy as np
import matplotlib. pyplot as plt
from PIL import Image
# Seed the PyTorch and NumPy random number generators with a fixed value.
torch. manual_seed( 42 )
np. random. seed( 42 )
# Preprocessing: convert to a tensor, then normalise every channel with
# mean 0.5 and standard deviation 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# CIFAR-10 test split, downloaded into ./data when it is not already there.
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)

# Display names indexed by the integer class label.
classes = ('飞机', '汽车', '鸟', '猫', '鹿', '狗', '青蛙', '马', '船', '卡车')
class SimpleCNN(nn.Module):
    """Three conv/ReLU/max-pool stages followed by two linear layers.

    Channels grow 3 -> 32 -> 64 -> 128 and every stage halves the spatial
    size with a 2x2 max-pool. The first linear layer takes ``128 * 4 * 4``
    features and the second produces 10 outputs.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Return the 10 raw (unnormalised) scores for each sample in ``x``."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(conv(x)))
        # Flatten to rows of 128 * 4 * 4 values for the classifier head.
        flat = x.view(-1, 128 * 4 * 4)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)
model = SimpleCNN()
print('模型已创建')

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
model = model.to(device)
def train_model(model, epochs=1):
    """Train ``model`` on the CIFAR-10 training split with Adam.

    Uses the module-level ``transform`` for preprocessing and the
    module-level ``device`` for placing each batch. The average loss of the
    last 100 batches is printed every 100 batches.

    Args:
        model: Network to optimise in place; expected to already be on
            ``device``.
        epochs: Number of full passes over the training set.
    """
    trainset = torchvision.datasets.CIFAR10(
        root='./data',
        train=True,
        download=True,
        transform=transform,
    )
    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=64,
        shuffle=True,
        num_workers=2,
    )
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Make sure layers with train/eval-dependent behaviour are in training
    # mode, in case the caller switched the model to eval() earlier.
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 100 == 99:
                # The format spec must be exactly ".3f"; padding around it
                # makes the f-string raise ValueError at runtime.
                print(f'[{epoch + 1}, {i + 1}] 损失: {running_loss / 100:.3f}')
                running_loss = 0.0
    print('训练完成')
try:
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    model.load_state_dict(torch.load('cifar10_cnn.pth', map_location=device))
    print('已加载预训练模型')
except Exception:
    # Best-effort fallback: any loading failure (missing file, mismatched
    # keys, corrupt data) leads to training from scratch. Catching Exception
    # rather than using a bare except keeps Ctrl-C and SystemExit working.
    print('无法加载预训练模型, 使用未训练模型或训练新模型')
    train_model(model, epochs=1)
    torch.save(model.state_dict(), 'cifar10_cnn.pth')

# Switch to inference mode for the Grad-CAM analysis.
model.eval()
class GradCAM:
    """Grad-CAM heat-map generator bound to one layer of a model.

    A forward hook stores the target layer's output and a backward hook
    stores the gradient flowing into that output. ``generate_cam`` combines
    the two into a class-activation map.

    Attributes set by the hooks:
        activations: Detached output of ``target_layer`` from the last
            forward pass, or ``None``.
        gradients: Detached gradient w.r.t. that output from the last
            backward pass, or ``None``.
    """

    def __init__(self, model, target_layer):
        """Store the model and target layer, then attach the hooks.

        Args:
            model: Network whose prediction is explained.
            target_layer: Module inside ``model`` whose 4-D output is used
                to build the map.
        """
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        # Handles are kept so the hooks can be detached later.
        self._hook_handles = []
        self.register_hooks()

    def register_hooks(self):
        """Attach the forward and backward hooks to the target layer.

        Hooks registered by an earlier call are removed first, so calling
        this more than once does not stack duplicates.
        """
        def forward_hook(module, input, output):
            self.activations = output.detach()

        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0].detach()

        self.remove_hooks()
        self._hook_handles = [
            self.target_layer.register_forward_hook(forward_hook),
            # register_backward_hook is deprecated; the full variant is the
            # documented replacement and still provides grad_output.
            self.target_layer.register_full_backward_hook(backward_hook),
        ]

    def remove_hooks(self):
        """Detach every hook this object registered on the target layer."""
        for handle in getattr(self, '_hook_handles', []):
            handle.remove()
        self._hook_handles = []

    def generate_cam(self, input_image, target_class=None):
        """Compute the Grad-CAM map for the first sample of ``input_image``.

        Args:
            input_image: Input batch of shape (N, C, H, W). Only sample 0
                receives a gradient, and class auto-selection requires N == 1
                because the arg-max is converted with ``.item()``.
            target_class: Class index to explain. When ``None`` the class
                with the highest model output is used.

        Returns:
            Tuple ``(cam, target_class)`` where ``cam`` is a NumPy array
            resized to the input's (H, W), shifted to start at 0 and scaled
            to peak at 1 (left as all zeros when there is no positive
            evidence).

        Raises:
            RuntimeError: If the hooks captured nothing, e.g. because the
                target layer did not take part in the forward pass.
        """
        # Clear captures from earlier calls so stale data is never reused.
        self.gradients = None
        self.activations = None

        model_output = self.model(input_image)
        if target_class is None:
            target_class = torch.argmax(model_output, dim=1).item()

        self.model.zero_grad()
        # Back-propagate only the score of the chosen class for sample 0.
        one_hot = torch.zeros_like(model_output)
        one_hot[0, target_class] = 1
        model_output.backward(gradient=one_hot)

        if self.gradients is None or self.activations is None:
            raise RuntimeError(
                'Grad-CAM hooks captured no activations/gradients; check that '
                'target_layer is used in the forward pass and hooks are attached.'
            )

        # Channel weights: gradient averaged over the spatial dimensions.
        weights = torch.mean(self.gradients, dim=(2, 3), keepdim=True)
        cam = torch.sum(weights * self.activations, dim=1, keepdim=True)
        # Keep only locations that push the class score up.
        cam = F.relu(cam)
        # Resize to the input resolution instead of a hard-coded 32x32.
        cam = F.interpolate(
            cam,
            size=tuple(input_image.shape[-2:]),
            mode='bilinear',
            align_corners=False,
        )
        cam = cam - cam.min()
        peak = cam.max()
        if peak > 0:
            cam = cam / peak
        return cam.cpu().squeeze().numpy(), target_class
Files already downloaded and verified
模型已创建
已加载预训练模型
import warnings
import matplotlib. pyplot as plt
# Silence warnings and configure matplotlib for the Chinese labels used in
# the figure (SimHei font; plain ASCII minus sign on the axes).
warnings.filterwarnings('ignore')
plt.rcParams['font.family'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

# Pick one test sample; the dataset returns (image tensor, integer label).
idx = 102
image, label = testset[idx]
# Replacement field written without padding so no trailing space is printed.
print(f'选择的图像类别: {classes[label]}')
def tensor_to_np(tensor):
    """Turn a normalised 3-channel tensor into a displayable image array.

    Args:
        tensor: Tensor laid out as (3, H, W) that was normalised with
            mean 0.5 and standard deviation 0.5 per channel.

    Returns:
        NumPy array of shape (H, W, 3) with the normalisation undone and
        values clipped to [0, 1].
    """
    # Move channels last, the layout matplotlib expects.
    channels_last = tensor.cpu().numpy().transpose(1, 2, 0)
    mean = np.array([0.5, 0.5, 0.5])
    std = np.array([0.5, 0.5, 0.5])
    # Invert (v - mean) / std, then clamp to the valid display range.
    return np.clip(std * channels_last + mean, 0, 1)
# Add a batch dimension and move the sample to the model's device.
input_tensor = image.unsqueeze(0).to(device)
grad_cam = GradCAM(model, model.conv3)
heatmap, pred_class = grad_cam.generate_cam(input_tensor)

# De-normalise once and reuse the array for both image panels.
img = tensor_to_np(image)

plt.figure(figsize=(12, 4))

# Panel 1: original image with its ground-truth label.
plt.subplot(1, 3, 1)
plt.imshow(img)
plt.title(f'原始图像: {classes[label]}')
plt.axis('off')

# Panel 2: raw heat-map, titled with the class the model predicted.
plt.subplot(1, 3, 2)
plt.imshow(heatmap, cmap='jet')
plt.title(f'Grad-CAM热力图: {classes[pred_class]}')
plt.axis('off')

# Panel 3: heat-map blended over the image.
plt.subplot(1, 3, 3)
# Quantise to 0..255 so the colormap is indexed by integer; this array is
# not actually resized despite its name.
heatmap_resized = np.uint8(255 * heatmap)
# Drop the alpha channel so the colours can be blended with the RGB image.
heatmap_colored = plt.cm.jet(heatmap_resized)[:, :, :3]
superimposed_img = heatmap_colored * 0.4 + img * 0.6
plt.imshow(superimposed_img)
plt.title('叠加热力图')
plt.axis('off')

plt.tight_layout()
# Save before show(): the figure may be released once the window is closed.
plt.savefig('grad_cam_result.png')
plt.show()
选择的图像类别: 青蛙
作业:理解下今天的代码即可
@浙大疏锦行