目录
在训练深层神经网络时,由于模型参数较多,在数据量不足时很容易过拟合。而正则化技术主要就是用于防止过拟合,提升模型的泛化能力(对新数据表现良好)和鲁棒性(对异常数据表现良好)。
1 概念认知
1.1 过拟合
过拟合是指模型对训练数据拟合能力很强并表现很好,但在测试数据上表现较差。
过拟合常见原因有:
数据量不足:当训练数据较少时,模型可能会过度学习数据中的噪声和细节。
模型太复杂:如果模型很复杂,也会过度学习训练数据中的细节和噪声。
正则化强度不足:如果正则化强度不足,可能会导致模型过度学习训练数据中的细节和噪声。
举个例子:
1.2 欠拟合
欠拟合是由于模型学习能力不足,无法充分捕捉数据中的复杂关系。
1.3 如何判断
过拟合
训练误差低,但验证时误差高。模型在训练数据上表现很好,但在验证数据上表现不佳,说明模型可能过度拟合了训练数据中的噪声或特定模式。
欠拟合
训练误差和测试误差都高。模型在训练数据和测试数据上的表现都不好,说明模型可能太简单,无法捕捉到数据中的复杂模式。
2 解决欠拟合
欠拟合的解决思路比较直接:
增加模型复杂度:引入更多的参数、增加神经网络的层数或节点数量,使模型能够捕捉到数据中的复杂模式。
增加特征:通过特征工程添加更多有意义的特征,使模型能够更好地理解数据。
减少正则化强度:适当减小 L1、L2 正则化强度,允许模型有更多自由度来拟合数据。
训练更长时间:如果是因为训练不足导致的欠拟合,可以增加训练的轮数或时间.
3 解决过拟合
避免模型参数过大是防止过拟合的关键步骤之一。
模型的复杂度主要由权重决定,而不是偏置
。偏置只是对模型输出的平移,不会导致模型过度拟合数据。
怎么控制权重w,使w在比较小的范围内?
考虑损失函数,损失函数的目的是使预测值与真实值无限接近,如果在原来的损失函数上添加一个非0的变量
其中是关于权重w的函数,
要使L1变小,就要使L变小的同时,也要使变小。从而控制权重
在较小的范围内。
3.1 L2正则化
L2 正则化通过在损失函数中添加权重参数的平方和来实现,目标是惩罚过大的参数值。
3.1.1 数学表示
设损失函数为 L(\theta),其中 \theta 表示权重参数,加入L2正则化后的损失函数表示为:
其中:
是原始损失函数(比如均方误差、交叉熵等)。
是正则化强度,控制正则化的力度。
是模型的第
个权重参数。
是所有权重参数的平方和,称为 L2 正则化项。
L2 正则化会惩罚权重参数过大的情况,通过参数平方值对损失函数进行约束。
为什么是?
假设没有1/2,则对L2 正则化项的梯度为:
,会引入一个额外的系数 2,使梯度计算和更新公式变得复杂。
添加1/2后,对的梯度为:
。
3.1.2 梯度更新
在 L2 正则化下,梯度更新时,不仅要考虑原始损失函数的梯度,还要考虑正则化项的影响。更新规则为:
其中:
是学习率。
是损失函数关于参数
的梯度。
是 L2 正则化项的梯度,对应的是参数值本身的衰减。
很明显,参数越大惩罚力度就越大,从而让参数逐渐趋向于较小值,避免出现过大的参数。
3.1.3 作用
防止过拟合:当模型过于复杂、参数较多时,模型会倾向于记住训练数据中的噪声,导致过拟合。L2 正则化通过抑制参数的过大值,使得模型更加平滑,降低模型对训练数据噪声的敏感性。
限制模型复杂度:L2 正则化项强制权重参数尽量接近 0,避免模型中某些参数过大,从而限制模型的复杂度。通过引入平方和项,L2 正则化鼓励模型的权重均匀分布,避免单个权重的值过大。
提高模型的泛化能力:正则化项的存在使得模型在测试集上的表现更加稳健,避免在训练集上取得极高精度但在测试集上表现不佳。
平滑权重分布:L2 正则化不会将权重直接变为 0,而是将权重值缩小。这样模型就更加平滑的拟合数据,同时保留足够的表达能力。
3.1.4 代码实现
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
# 设置随机种子以保证可重复性
torch.manual_seed(42)
# 生成随机数据
n_samples = 100
n_features = 20
X = torch.randn(n_samples, n_features) # 输入数据
y = torch.randn(n_samples, 1) # 目标值
# 定义一个简单的全连接神经网络
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc1 = nn.Linear(n_features, 50)
self.fc2 = nn.Linear(50, 1)
def forward(self, x):
x = torch.relu(self.fc1(x))
return self.fc2(x)
# 训练函数
def train_model(use_l2=False, weight_decay=0.01, n_epochs=100):
# 初始化模型
model = SimpleNet()
criterion = nn.MSELoss() # 损失函数(均方误差)
# 选择优化器
if use_l2:
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=weight_decay) # 使用 L2 正则化
else:
optimizer = optim.SGD(model.parameters(), lr=0.01) # 不使用 L2 正则化
# 记录训练损失
train_losses = []
# 训练过程
for epoch in range(n_epochs):
optimizer.zero_grad() # 清空梯度
outputs = model(X) # 前向传播
loss = criterion(outputs, y) # 计算损失
loss.backward() # 反向传播
optimizer.step() # 更新参数
train_losses.append(loss.item()) # 记录损失
if (epoch + 1) % 10 == 0:
print(f'Epoch [{epoch + 1}/{n_epochs}], Loss: {loss.item():.4f}')
return train_losses
# 训练并比较两种模型
train_losses_no_l2 = train_model(use_l2=False) # 不使用 L2 正则化
train_losses_with_l2 = train_model(use_l2=True, weight_decay=0.01) # 使用 L2 正则化
# 绘制训练损失曲线
plt.plot(train_losses_no_l2, label='Without L2 Regularization')
plt.plot(train_losses_with_l2, label='With L2 Regularization')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss: L2 Regularization vs No Regularization')
plt.legend()
plt.show()
3.2 L1正则化
L1 正则化通过在损失函数中添加权重参数的绝对值之和来约束模型的复杂度。
3.2.1 数学表示
设模型的原始损失函数为,其中
表示模型权重参数,则加入 L1 正则化后的损失函数表示为:
其中:
是原始损失函数。
是正则化强度,控制正则化的力度。
是模型第i 个参数的绝对值。
是所有权重参数的绝对值之和,这个项即为 L1 正则化项。
3.2.2 梯度更新
在 L1 正则化下,梯度更新时的公式是:
其中:
是学习率。
是损失函数关于参数
的梯度。
是参数
的符号函数,表示当
为正时取值为 1,为负时取值为 -1,等于 0 时为 0。
因为 L1 正则化依赖于参数的绝对值,其梯度更新时不是简单的线性缩小,而是通过符号函数来直接调整参数的方向。这就是为什么 L1 正则化能促使某些参数完全变为 0。
3.2.3 作用
稀疏性:L1 正则化的一个显著特性是它会促使许多权重参数变为 零。这是因为 L1 正则化倾向于将权重绝对值缩小到零,使得模型只保留对结果最重要的特征,而将其他不相关的特征权重设为零,从而实现 特征选择 的功能。
防止过拟合:通过限制权重的绝对值,L1 正则化减少了模型的复杂度,使其不容易过拟合训练数据。相比于 L2 正则化,L1 正则化更倾向于将某些权重完全移除,而不是减小它们的值。
简化模型:由于 L1 正则化会将一些权重变为零,因此模型最终会变得更加简单,仅依赖于少数重要特征。这对于高维度数据特别有用,尤其是在特征数量远多于样本数量的情况下。
特征选择:因为 L1 正则化会将部分权重置零,因此它天然具有特征选择的能力,有助于自动筛选出对模型预测最重要的特征。
3.2.4 与L2对比
L1 正则化 更适合用于产生稀疏模型,会让部分权重完全为零,适合做特征选择。
L2 正则化 更适合平滑模型的参数,避免过大参数,但不会使权重变为零,适合处理高维特征较为密集的场景。
3.2.5 代码实现
l1_lambda = 0.001
# 计算 L1 正则化项并将其加入到总损失中
l1_norm = sum(p.abs().sum() for p in model.parameters())
loss = loss + l1_lambda * l1_norm
3.3 Dropout
Dropout 的工作流程如下:
在每次训练迭代中,随机选择一部分神经元(通常以概率 p丢弃,比如 p=0.5)。
被选中的神经元在当前迭代中不参与前向传播和反向传播。
在测试阶段,所有神经元都参与计算,但需要对权重进行缩放(通常乘以 1−p),以保持输出的期望值一致。
Dropout 是一种在训练过程中随机丢弃部分神经元的技术。它通过减少神经元之间的依赖来防止模型过于复杂,从而避免过拟合。
3.3.1 基本实现
import torch
import torch.nn as nn
def dropout():
dropout = nn.Dropout(p=0.5)
x = torch.randint(0, 10, (5, 6), dtype=torch.float)
print(x)
# 开始dropout
print(dropout(x))
if __name__ == "__main__":
dropout()
Dropout过程:
按照指定的概率把部分神经元的值设置为0;
为了规避该操作带来的影响,需对非 0 的元素使用缩放因子1/(1-p)进行强化。
假设某个神经元的输出为 x,Dropout 的操作可以表示为:
在训练阶段:
在测试阶段:
为什么要使用缩放因子1/(1-p)?
在训练阶段,Dropout 会以概率 p随机将某些神经元的输出设置为 0,而以概率 1−p 保留这些神经元。
假设某个神经元的原始输出是 x,那么在训练阶段,它的期望输出值为:
通过这种缩放,训练阶段的期望输出值仍然是 x,与没有 Dropout 时一致。
3.3.2 权重影响
示例:对图片进行随机丢弃
import torch
from torch import nn
from PIL import Image
from torchvision import transforms
import os
from matplotlib import pyplot as plt
torch.manual_seed(42)
def load_img(path, resize=(224, 224)):
pil_img = Image.open(path).convert('RGB')
print("Original image size:", pil_img.size) # 打印原始尺寸
transform = transforms.Compose([
transforms.Resize(resize),
transforms.ToTensor() # 转换为Tensor并自动归一化到[0,1]
])
return transform(pil_img) # 返回[C,H,W]格式的tensor
if __name__ == '__main__':
dirpath = os.path.dirname(__file__)
path = os.path.join(dirpath, 'img', '100.jpg') # 使用os.path.join更安全
# 加载图像 (已经是[0,1]范围的Tensor)
trans_img = load_img(path)
# 添加batch维度 [1, C, H, W],因为Dropout默认需要4D输入
trans_img = trans_img.unsqueeze(0)
# 创建Dropout层
dropout = nn.Dropout2d(p=0.2)
drop_img = dropout(trans_img)
# 移除batch维度并转换为[H,W,C]格式供matplotlib显示
trans_img = trans_img.squeeze(0).permute(1, 2, 0).numpy()
drop_img = drop_img.squeeze(0).permute(1, 2, 0).numpy()
# 确保数据在[0,1]范围内
drop_img = drop_img.clip(0, 1)
# 显示图像
fig = plt.figure(figsize=(10, 5))
ax1 = fig.add_subplot(1, 2, 1)
ax1.imshow(trans_img)
ax2 = fig.add_subplot(1, 2, 2)
ax2.imshow(drop_img)
plt.show()
效果:
说明:
nn.Dropout2d(p):Dropout2d
是针对二维数据设计的 Dropout 层,它在训练过程中随机将输入张量的某些通道(二维平面)置为零。
参数 | 要求格式 | 示例形状 | 说明 |
---|---|---|---|
输入 | (N, C, H, W) |
(16, 64, 32, 32) |
批大小×通道×高×宽 |
输出 | (N, C, H, W) |
(16, 64, 32, 32) |
与输入同形,部分通道归零 |
3.4 数据增强
样本数量不足(即训练数据过少)是导致过拟合(Overfitting)的常见原因之一,可以从以下角度理解:
当训练数据过少时,模型容易“记住”有限的样本(包括噪声和无关细节),而非学习通用的规律。
简单模型更可能捕捉真实规律,但数据不足时,复杂模型会倾向于拟合训练集中的偶然性模式(噪声)。
样本不足时,训练集的分布可能与真实分布偏差较大,导致模型学到错误的规律。
小数据集中,个别样本的噪声(如标注错误、异常值)会被放大,模型可能将噪声误认为规律。
数据增强(Data Augmentation)是一种通过人工生成或修改训练数据来增加数据集多样性的技术,常用于解决过拟合问题。数据增强通过“模拟”更多训练数据,迫使模型学习泛化性更强的规律,而非训练集中的偶然性模式。其本质是一种低成本的正则化手段,尤其在数据稀缺时效果显著。
在了解计算机如何处理图像之前,需要先了解图像的构成元素。
图像是由像素点组成的,每个像素点的值范围为: [0, 255], 像素值越大意味着较亮。比如一张 200x200 的图像, 则是由 40000 个像素点组成, 如果每个像素点都是 0 的话, 意味着这是一张全黑的图像。
我们看到的彩色图一般都是多通道的图像, 所谓多通道可以理解为图像由多个不同的图像层叠加而成, 例如我们看到的彩色图像一般都是由 RGB 三个通道组成的,还有一些图像具有 RGBA 四个通道,最后一个通道为透明通道,该值越小,则图像越透明。
数据增强是提高模型泛化能力(鲁棒性)的一种有效方法,尤其在图像分类、目标检测等任务中。数据增强可以模拟更多的训练样本,从而减少过拟合风险。数据增强通过torchvision.transforms模块来实现。
数据增强的好处
大幅度降低数据采集和标注成本;
模型过拟合风险降低,提高模型泛化能力;
官方地址:
transforms:Transforming and augmenting images — Torchvision 0.22 documentation
transforms:
常用变换类
transforms.Compose:将多个变换操作组合成一个流水线。
transforms.ToTensor:将 PIL 图像或 NumPy 数组转换为 PyTorch 张量,将图像数据从 uint8 类型 (0-255) 转换为 float32 类型 (0.0-1.0)。
transforms.Normalize:对张量进行标准化。
transforms.Resize:调整图像大小。
transforms.CenterCrop:从图像中心裁剪指定大小的区域。
transforms.RandomCrop:随机裁剪图像。
transforms.RandomHorizontalFlip:随机水平翻转图像。
transforms.RandomVerticalFlip:随机垂直翻转图像。
transforms.RandomRotation:随机旋转图像。
transforms.ColorJitter:随机调整图像的亮度、对比度、饱和度和色调。
transforms.RandomGrayscale:随机将图像转换为灰度图像。
transforms.RandomResizedCrop:随机裁剪图像并调整大小。
3.4.1 图片缩放
具体参考官方文档:Illustration of transforms — Torchvision 0.22 documentation
参考代码:
from PIL import Image
def test03():
img1 = plt.imread('./img/100.jpg')
plt.imshow(img1)
plt.show()
img = Image.open('./img/100.jpg')
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
r_img = transform(img)
print(r_img.shape)
r_img = r_img.permute(1, 2, 0)
plt.imshow(r_img)
plt.show()
3.4.2 随机裁剪
img = Image.open('./img/100.jpg')
transform = transforms.Compose([transforms.RandomCrop(size=(224, 224)), transforms.ToTensor()])
r_img = transform(img)
print(r_img.shape)
r_img = r_img.permute(1, 2, 0)
plt.imshow(r_img)
plt.show()
3.4.3 随机水平翻转
RandomHorizontalFlip(p):随机水平翻转图像,参数p表示翻转概率(0 ≤ p
≤ 1),p=1
表示必定翻转,p=0
表示不翻转
img = Image.open('./img/100.jpg')
transform = transforms.Compose([transforms.RandomHorizontalFlip(p=1), transforms.ToTensor()])
r_img = transform(img)
print(r_img.shape)
r_img = r_img.permute(1, 2, 0)
plt.imshow(r_img)
plt.show()
3.4.4 调整图片颜色
transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)
brightness:
亮度调整的范围。
可以
float
或(min, max)
元组:如果是
float
(如brightness=0.2
),则亮度在[max(0, 1 - 0.2), 1 + 0.2] = [0.8, 1.2]
范围内随机缩放。如果是
(min, max)
(如brightness=(0.5, 1.5)
),则亮度在[0.5, 1.5]
范围内随机缩放。
contrast:
对比度调整的范围。
格式与 brightness 相同。
saturation:
饱和度调整的范围。
格式与 brightness 相同。
hue:
色调调整的范围。
可以是一个浮点数(表示相对范围)或一个元组 (min, max)。
取值范围必须为
[-0.5, 0.5]
(因为色相在 HSV 色彩空间中是循环的,超出范围会导致颜色异常)。例如,hue=0.1 表示色调在 [-0.1, 0.1] 之间随机调整。
img = Image.open('./img/100.jpg')
transform = transforms.Compose([transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2), transforms.ToTensor()])
r_img = transform(img)
print(r_img.shape)
r_img = r_img.permute(1, 2, 0)
plt.imshow(r_img)
plt.show()
3.4.5 随机旋转
RandomRotation用于对图像进行随机旋转。
transforms.RandomRotation(
degrees,
interpolation=InterpolationMode.NEAREST,
expand=False,
center=None,
fill=0
)
degrees:
旋转角度的范围,可以是一个浮点数或元组 (min_degree, max_degree)。
例如,degrees=30 表示旋转角度在 [-30, 30] 之间随机选择。
例如,degrees=(30, 60) 表示旋转角度在 [30, 60] 之间随机选择。
interpolation:
插值方法,用于旋转图像。
默认是 InterpolationMode.NEAREST(最近邻插值)。
其他选项包括 InterpolationMode.BILINEAR(双线性插值)、InterpolationMode.BICUBIC(双三次插值)等。
expand:
是否扩展图像大小以适应旋转后的图像。如:当需要保留完整旋转后的图像时(如医学影像、文档扫描)
如果为 True,旋转后的图像可能会比原始图像大。
如果为 False,旋转后的图像大小与原始图像相同。
center:
旋转中心点的坐标,默认为图像中心。
可以是一个元组 (x, y),表示旋转中心的坐标。
fill:
旋转后图像边缘的填充值。
可以是一个浮点数(用于灰度图像)或一个元组(用于 RGB 图像)。默认填充0(黑色)
# 加载图像
image = Image.open('./img/100.jpg')
# 定义 RandomRotation 变换
transform = transforms.RandomRotation(degrees=30) # 旋转角度在 [-30, 30] 之间随机选择
# 应用变换
rotated_image = transform(image)
# 显示图像
plt.imshow(rotated_image)
plt.axis('off')
plt.show()
3.4.6 图片转Tensor
import torch
from PIL import Image
from torchvision import transforms
import os
def test001():
dir_path = os.path.dirname(__file__)
file_path = os.path.join(dir_path,'img', '1.jpg')
file_path = os.path.relpath(file_path)
print(file_path)
# 1. 读取图片
img = Image.open(file_path)
# transforms.ToTensor()用于将 PIL 图像或 NumPy 数组转换为 PyTorch 张量,并自动进行数值归一化和维度调整
# 将像素值从 [0, 255] 缩放到 [0.0, 1.0](浮点数)
# 自动将图像格式从 (H, W, C)(高度、宽度、通道)转换为 PyTorch 标准的 (C, H, W)
transform = transforms.ToTensor()
img_tensor = transform(img)
print(img_tensor)
if __name__ == "__main__":
test001()
3.4.7 Tensor转图片
import torch
from PIL import Image
from torchvision import transforms
def test002():
# 1. 随机一个数据表示图片
img_tensor = torch.randn(3, 224, 224)
# 2. 创建一个transforms
transform = transforms.ToPILImage()
# 3. 转换为图片
img = transform(img_tensor)
img.show()
# 4. 保存图片
img.save("./test.jpg")
if __name__ == "__main__":
test002()
练习:通过一个Demo加深对Torch的API理解和使用
import torch
from PIL import Image
from torchvision import transforms
import os
def test003():
# 获取文件的相对路径
dir_path = os.path.dirname(__file__)
file_path = os.path.relpath(os.path.join(dir_path, 'dog.jpg'))
# 加载图片
img = Image.open(file_path)
# 转换图片为tensor
transform = transforms.ToTensor()
t_img = transform(img)
print(t_img.shape)
# 获取GPU资源,将图片处理移交给CUDA
device = 'cuda' if torch.cuda.is_available() else 'cpu'
t_img = t_img.to(device)
t_img += 0.3
# 将图片移交给CPU进行图片保存处理,一般IO操作是基于CPU的
t_img = t_img.cpu()
transform = transforms.ToPILImage()
img = transform(t_img)
img.show()
if __name__ == "__main__":
test003()
3.4.8 归一化
标准化:将图像的像素值从原始范围(如 [0, 255] 或 [0, 1])转换为均值为 0、标准差为 1 的分布。
加速训练:标准化后的数据分布更均匀,有助于加速模型训练。
提高模型性能:标准化可以使模型更容易学习到数据的特征,提高模型的收敛性和稳定性。
img = Image.open('./img/100.jpg')
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
r_img = transform(img)
print(r_img.shape)
r_img = r_img.permute(1, 2, 0)
plt.imshow(r_img)
plt.show()
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
均值(Mean):数据集中所有图像在每个通道上的像素值的平均值。
标准差(Std):数据集中所有图像在每个通道上的像素值的标准差。
RGB 三个通道的均值和标准差 不是随便定义的,而是需要根据具体的数据集进行统计计算。这些值是 ImageNet 数据集的统计结果,已成为计算机视觉任务的默认标准。
数据集计算均值和标准差
以CIFAR10数据集为例:
# 获取数据集
train_data = datasets.CIFAR10(
root='./cifar10',
train=True,
download=True,
transform=transforms.ToTensor() # 自动将PIL图像转为[0,1]范围的张量
)
def compute_mean_std(dataset):
# 初始化累加器
mean = torch.zeros(3)
std = torch.zeros(3)
num_samples = len(dataset)
# 遍历数据集计算均值
for img, _ in dataset:
# dim=(1, 2) 表示对图像的高度(H)和宽度(W)维度求均值,保留通道维度(C)。
mean += img.mean(dim=(1, 2))
# 全局的通道均值
mean /= num_samples
print(mean)
# 遍历数据集计算标准差
for img, _ in dataset:
# 原始mean 是一个形状为 [3] 的张量,表示每个通道的均值。
# 使用 view(3, 1, 1) 将 mean 的形状从 [3] 改变为 [3, 1, 1]。
# 这样,mean 的形状变为 [3, 1, 1],其中 3 表示通道数,1 和 1 分别表示高度和宽度的维度。
# 当执行 img - mean.view(3, 1, 1) 时,PyTorch 会利用广播机制将 mean 自动扩展到与 img 相同的形状 [3, H, W]。
# 然后利用方差公式计算:var=E(x-E(x))^2
std += (img - mean.view(3, 1, 1)).pow(2).mean(dim=(1, 2))
# 计算出所有图片的方差后,计算平均方差,然后求标准差
std = torch.sqrt(std / num_samples)
return mean, std
mean, std = compute_mean_std(train_data)
print(f"Mean: {mean}") # 输出类似 [0.4914, 0.4822, 0.4465]
print(f"Std: {std}") # 输出类似 [0.2470, 0.2435, 0.2616]
3.4.9 数据增强整合
使用transforms.Compose()把要增强的操作整合到一起:
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms, datasets, utils
def test01():
# 定义数据增强和归一化
transform = transforms.Compose(
[
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.RandomRotation(10), # 随机旋转 ±10 度
transforms.RandomResizedCrop(
32, scale=(0.8, 1.0)
), # 随机裁剪到 32x32,缩放比例在0.8到1.0之间
transforms.ColorJitter(
brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1
), # 随机调整亮度、对比度、饱和度、色调
transforms.ToTensor(), # 转换为 Tensor
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # 归一化,这是一种常见的经验设置,适用于数据范围 [0, 1],使其映射到 [-1, 1]
]
)
# 加载 CIFAR-10 数据集,并应用数据增强
trainset = datasets.CIFAR10(root="./cifar10_data", train=True, download=True, transform=transform)
dataloader = DataLoader(trainset, batch_size=4, shuffle=False)
# 获取一个批次的数据
images, labels = next(iter(dataloader))
# 还原图片并显示
plt.figure(figsize=(10, 5))
for i in range(4):
# 反归一化:将像素值从 [-1, 1] 还原到 [0, 1]
img = images[i] / 2 + 0.5
# 转换为 PIL 图像
img_pil = transforms.ToPILImage()(img)
# 显示图片
plt.subplot(1, 4, i + 1)
plt.imshow(img_pil)
plt.axis('off')
plt.title(f'Label: {labels[i]}')
plt.show()
if __name__ == "__main__":
test01()
代码解释:
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
若数据分布与ImageNet差异较大(如医学影像、卫星图、MNIST等),或均值和标准差未知时,可用此简化设置。
将图片进行归一化,使数据更符合正态分布,归一化公式:
img = img / 2 + 0.5
表示反归一化,是归一化的逆运算:
img=\frac{normalized\_img+1}{2} =\frac{normalized\_img}{2}+0.5