transforms的使用 小土堆pytorch记录

发布于:2025-08-15 ⋅ 阅读:(19) ⋅ 点赞:(0)

关键知识点总结

1. TensorBoard可视化
  • 作用:实时可视化训练过程、模型结构和数据

  • 核心组件

    • SummaryWriter:创建日志写入器

    • add_image():记录图像数据

    • add_scalar():记录标量数据(如损失函数)

  • 使用流程

writer = SummaryWriter("logs")  # 创建写入器
writer.add_image("tag", tensor, step)  # 记录数据
writer.close()  # 关闭资源
  • 查看日志:终端执行 tensorboard --logdir=logs

2. 图像预处理(transforms)
变换类型 作用 重要参数 输入/输出类型
ToTensor PIL图像 → Tensor格式 PIL → Tensor[C,H,W]
Normalize 标准化(均值&标准差) (mean), (std) Tensor → Tensor
Resize 调整图像尺寸 (h,w) 或 int(短边尺寸) PIL → PIL
RandomCrop 随机裁剪(数据增强) (size) PIL → PIL
Compose 组合多个变换 [transform1, ...] 链式执行
3. 标准化(Normalize)原理
# 计算公式
normalized = (input - mean) / std

# 示例计算(单通道):
原始值 = 0.8
mean = 0.5, std = 0.5
结果 = (0.8 - 0.5)/0.5 = 0.6
4. 维度顺序变化
  • PIL图像:(Width, Height, Channels)

  • Tensor图像:(Channels, Height, Width)

  • 转换时机ToTensor自动完成维度转换

视频代码

from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

# 创建TensorBoard写入器,日志保存在logs目录
writer = SummaryWriter("logs")

# 打开图像并转换为RGB格式(确保三通道)
img = Image.open("img/redraw.png").convert("RGB")
print(img)  # 打印原始图像信息(尺寸/格式)

# 1. ToTensor转换(核心知识点)
trans_totensor = transforms.ToTensor()  # 创建转换器
img_tensor = trans_totensor(img)  # PIL图像 -> Tensor[0,1]
writer.add_image("ToTensor", img_tensor)  # 写入TensorBoard

# 2. Normalize标准化(核心知识点)
print("原始值:", img_tensor[0][0][0])  # 查看原始像素值
trans_norm = transforms.Normalize(
    [0.5, 0.5, 0.5],  # RGB三通道均值
    [0.5, 0.5, 0.5]   # RGB三通道标准差
)
img_norm = trans_norm(img_tensor)  # 应用标准化:(input-mean)/std
print("归一化后:", img_norm[0][0][0])  # 查看标准化后的值
writer.add_image("Normalize", img_norm, 2)  # 写入TensorBoard(step=2)

# 3. Resize调整尺寸
print("原始尺寸:", img.size)  # PIL格式尺寸 (width, height)
trans_resize = transforms.Resize((512, 512))  # 目标尺寸 (height, width)
img_resize = trans_resize(img)  # 调整尺寸(返回PIL图像)
img_resize_tensor = trans_totensor(img_resize)  # 再次转换为Tensor
writer.add_image("Resize", img_resize_tensor, 0)  # 写入TensorBoard

# 4. Compose组合变换(核心知识点)
trans_resize_2 = transforms.Resize(512)  # 单参数:调整短边为512
trans_compose = transforms.Compose([
    trans_resize_2,  # 第一步:调整尺寸
    trans_totensor   # 第二步:转Tensor
])
img_compose = trans_compose(img)  # 自动顺序执行
writer.add_image("Resize_Compose", img_compose, 1)  # 写入TensorBoard

# 5. RandomCrop随机裁剪(数据增强)
trans_random = transforms.RandomCrop(512)  # 随机裁剪为512x512
trans_compose_2 = transforms.Compose([
    trans_random,    # 随机裁剪
    trans_totensor   # 转Tensor
])
for i in range(10):  # 生成10个随机裁剪样本
    img_crop = trans_compose_2(img)
    writer.add_image("RandomCrop", img_crop, i)  # 每个样本单独记录

writer.close()  # 关闭写入器

知识拓展

1. 常用transforms扩展
# 颜色空间变换
transforms.Grayscale()      # 转灰度图
transforms.ColorJitter()    # 随机调整亮度/对比度

# 几何变换
transforms.RandomRotation(30)   # 随机旋转(-30°~30°)
transforms.RandomHorizontalFlip(p=0.5)  # 50%概率水平翻转

# 组合示例(典型数据增强)
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225])
])

调试技巧

# 检查中间结果
print(type(img))          # 查看当前数据类型
print(img.shape)          # 查看Tensor维度
img.show()                # 显示PIL图像

transforms是构建数据预处理流水线的核心工具,掌握它们能大幅提升数据准备效率。TensorBoard则是模型训练过程的"望远镜",帮助我们直观理解训练动态。


网站公告

今日签到

点亮在社区的每一天
去签到