终于该学习神经网络的搭建了,开心,嘻嘻
学习神经网络离不开torch.nn,先把他印在脑子里,什么是torch.nn?他是Pytorch的一个模块,包含了大量构建神经网络需要的类和方法,就像前面学习的torch.utils,什么?忘了torch.utils是啥了?
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
想起来了吗?
1.nn.Module
先来认识一下神经网络最基本的基类:nn.Module
nn.Module是 PyTorch 中所有神经网络模块的基类。它提供了许多重要的功能和方法,简化了神经网络的构建、训练和管理。
想要搭建自己的神经网络,首先要继承他,并实现他的两个方法:__init__和forward
__init__:主要用于定义神经网络中的各种层
forward:调用上面的层,数据经过哪些层,怎么处理都是在这里定义
什么是层?
神经网络中包含了各种层,他们都有各自处理数据的方式,每一层就像一把菜刀一样,__init__就是定义这些菜刀的规格,他们有多大,多长等等。forward呢,就是定义送进来的食材需要用哪些菜刀?需不需要加佐料,撒点孜然、葱花啥的
看一下基本的代码结构:
from torch import nn class Model(nn.Module): def __init__(self): # 空的 一把菜刀也没有 super(Model, self).__init__() def forward(self, x): # 食材怎么进来就怎么出去,没有经过任何处理 return x
2.Conv2d
接下来要开始学习第一把刀的使用:Conv2d,PyTorch 中用于处理二维数据(如图像)的卷积层
背景知识:
1.名称解析
Conv2d:
- Conv:Convolution的简写,表示卷积操作
- 2d:表示卷积操作是对二维数据进行的,如:图像
2.维度(Pytorch)
在Pytorch中,图像通常是3D张量,维度是3,形状:(channel,height,width)
- channel:代表通道数,彩色图像是3,灰色图像是1
- height:图像的高度,每一行的像素数量
- width:图像的宽度,每一列的像素数量
彩色图像:
- 彩色图像:形状(3,68,68),代表这是一张68*68的彩色图像
- 灰色图像:形状(1,68,68),代表这是一张68*68的灰色图像
在Pytorch的学习中,通常将多个图像放入一个batch(批次)中,这样可以进行批量处理,因此图像张量会有一个额外的维度,就是batch_size,代表批次的大小,此时图像就是一个4D张量,形状是(batch_size,channel,height,width)
3.空间维度
此时,我们再从空间维度上理解一下图像,因为他的基本结构只有(height,width),所以我们称为2D图像,视频是由一系列连续的图像帧组成的,每个时间点(帧)是一个图像,所以他是3D视频,形状(time,height,width)
Conv2d就是专门处理2D图像的,当然还有Conv3d,这个我们以后再说,有兴趣可以去官网了解一下
4.卷积:
卷积简单来说就是将卷积核(一个矩阵)在输入数据上滑动,并与输入内容的局部区域进行点乘操作,最后输出一个新的矩阵
蓝色矩阵就叫做卷积核,绿色矩阵就是我们的输入数据(按照目前学习阶段,那应该是一个图像数据),黄色部分是卷积核移动到输入内容上的第一步,用卷积核上的每一个区域的数字和输入数据上的对应区域的数字进行相乘,再相加,比如:1*1+2*2+3*3......最后9个区域的和加起来,作为输出内容的第一行第一列的数据,(1+4+9)*3=42
然后卷积核向右平移一步,继续计算,最后得到一个3*3的矩阵,也就是棕色的那个
动图:
有条件的可以去官网看一下:https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
5.常用参数:
- in_channels :输入数据的通道数
- out_channels :输出数据的通道数,也是卷积核的个数
- kernel_size :卷积核大小
- stride:卷积核滑动的步数
- padding:填充,对输入数据进行填充
- padding_mode :对填充部分的数据进行更改,默认0
填充(绿色区域):
填充之后进行滑动,步数为2
实操部分
1.准备数据集
import torchvision.datasets
from torch.utils.data import DataLoader
from torchvision import transforms
# 定义内置数据集
dataset = torchvision.datasets.CIFAR10(root='dataset', train=False, download=True, transform=transforms.ToTensor())
# 定义数据加载器
dataloader = DataLoader(dataset, batch_size=64)
2.准备模型
from torch import nn
# 准备模型
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
# 定义一个网络层conv1 处理2D图像
self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0)
def forward(self, x):
# 调用第一个网络层conv1,用来处理输入数据x
x = self.conv1(x)
return x
model = Model()
3.将数据循环写入tensorboard
torch.reshape:接受一个张量,不改变原始张量,返回一个新的张量
参数:
- input:要改变的张量
- shape:你希望的新张量的形状(batch_size,channel,height,width)
当某一维度是-1,表示由PyTorch 自动推算,但最多只能有一个维度为-1
writer = SummaryWriter("logs")
step = 0
# 遍历数据加载器
for data in dataloader:
imgs, labels = data
# 调用模型对图片进行处理
output=model(imgs)
# 把输入内容的图片写入tensorboard
writer.add_images("input", imgs, step)
#改变输出图的形状
output = torch.reshape(output,(-1,3,30,30))
# 把输出内容的图片写入tensorboard
writer.add_images("output", output, step)
step += 1
writer.close()
4.完整代码:
import torch
import torchvision.datasets
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
# 定义内置数据集
dataset = torchvision.datasets.CIFAR10(root='dataset', train=False, download=True, transform=transforms.ToTensor())
# 定义数据加载器
dataloader = DataLoader(dataset, batch_size=64)
# 准备模型
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
# 定义一个网络层conv1 处理2D图像
self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0)
def forward(self, x):
# 调用第一个网络层conv1,用来处理输入数据x
x = self.conv1(x)
return x
model = Model()
writer = SummaryWriter("logs")
step = 0
# 遍历数据加载器
for data in dataloader:
imgs, labels = data
# 调用模型对图片进行处理
output=model(imgs)
# 把输入内容的图片写入tensorboard
writer.add_images("input", imgs, step)
#改变输出图的形状
output = torch.reshape(output,(-1,3,30,30))
# 把输出内容的图片写入tensorboard
writer.add_images("output", output, step)
step += 1
writer.close()
运行tensorboard,网页展示出来是这样的,因为模型处理后的图片通道是6,我们后面又手动改变了图片的通道,所以多余的数据就到了批次那里,input是64,output是128
好了,拜拜