(pt可视化)利用torch的make_grid进行张量可视化

发布于:2022-12-08 ⋅ 阅读:(356) ⋅ 点赞:(0)

在使用pytorch时,有时候需要对张量进行可视化,比如在经过一堆数据预处理后,我们从dataloader拿到了一个张量:[8,3,224,224],显然这是一个bs=8且为RGB的张量,一般来说经过ToTensor和Normalize后值范围在[-1,1],如果想看看这些张量是什么样子,一堆代码还是挺麻烦的,所以利用torch提供的make_grid和plt就能够轻松可视化张量。
在这里插入图片描述

make_graid()

def make_grid(
    tensor: Union[torch.Tensor, List[torch.Tensor]], 
    nrow: int = 8,
    padding: int = 2,
    normalize: bool = False,
    value_range: Optional[Tuple[int, int]] = None,
    scale_each: bool = False,
    pad_value: int = 0,
    **kwargs
) -> torch.Tensor:
  • tensor:要可视化的张量,比如为[8,3,224,224]
  • nrow:列数,行数=bs/列数
  • padding:不同图像之间的间隙大小
  • normalize:是否归一化,若是则按图像最大最小值归一化到[0,1]
  • value_range:指定normalize使用的最大最小值,默认使用图像本身的最大最小值
  • scale_each:是否单独为图像进行normalize。默认所有的图像都进行normalize
  • pad_value:间隙的填充值。范围在0(间隙为黑色)~1(间隙为白色)之间
  • return:返回(C,H,W)数据(多张图拼凑成了一张图)

plt
上面我们得到了makr_graid生成的图像,我们使用plt来可视化:

npimg = vis.numpy()  # plt输入需要时ndarray
plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')  # 需要将通道转到最后一维
plt.show()

关于plt显示图像,详见:7、显示图片

效果如下:
在这里插入图片描述

本文含有隐藏内容,请 开通VIP 后查看

网站公告

今日签到

点亮在社区的每一天
去签到