首先情况是开始训练正常,网络也在更新,更新后网络就输出了NaN。调试过程:
1. 查看模型权重更新前后的值
print("更新前权重信息:")
print(f" 权重均值: {fc2.weight.mean().item() if not torch.isnan(fc2.weight.mean()) else 'NaN'}")
print(f" 最大值: {fc2.weight.max().item() if not torch.isnan(fc2.weight.max()) else 'NaN'}")
print(f" 最小值: {fc2.weight.min().item() if not torch.isnan(fc2.weight.min()) else 'NaN'}\n")
权重更新
print("更新后权重信息:")
print(f" 权重均值: {fc2.weight.mean().item() if not torch.isnan(fc2.weight.mean()) else 'NaN'}")
print(f" 最大值: {fc2.weight.max().item() if not torch.isnan(fc2.weight.max()) else 'NaN'}")
print(f" 最小值: {fc2.weight.min().item() if not torch.isnan(fc2.weight.min()) else 'NaN'}\n")
判断标准:
- 权重 / 偏置的绝对值如果超过
1e4
,可能导致输出过大。 - 若训练中权重突然变得极大,说明可能存在梯度爆炸。
2. 发现权重更新前正常,更新后NaN
权重在参数更新后变成了NaN
,这说明问题出在反向传播和参数更新环节(梯度计算或优化器步骤导致权重被更新为异常值)。
原因分析:
权重从正常数值突然变成NaN
,几乎可以确定是梯度爆炸导致的:
- 反向传播时计算出的梯度为
NaN
或极端大值(如1e20
),优化器用这些异常梯度更新权重,直接导致权重变成NaN
。 - 常见触发点:损失函数计算异常(如
NaN
损失)、输入数据极端值导致中间激活值爆炸、学习率过高放大梯度影响。
第一步
检查损失函数是否为NaN
如果损失本身输出是NaN
,反向传播的梯度必然是NaN
,直接导致权重更新异常。在反向传播前检查损失需要。
损失为NaN
的常见原因:
损失中包含
log(0)
(如nn.Softmax
输出接近 0 时,torch.log(prob)
会趋近于-inf
)
结果:发现正是损失函数torch.log输出了NaN。
3 解决
限制torch.log的值,NaN的问题得到了解决