pytorch-11 神经网络的学习

发布于:2024-05-16 ⋅ 阅读:(68) ⋅ 点赞:(0)

一、梯度下降中的两个关键问题

1 找出梯度向量的方向和大小

2 让坐标点移动起来(进行一次迭代)

二、找出距离和方向:反向传播

1 反向传播的定义与价值

我们是从左向右,从输出向输入,逐渐往前求解导数的表达式,并且我们所使用的节点上的张量,也是从后向前逐渐用到,这和我们正向传播的过程完全相反。这种从左到右,不断使用正向传播中的元素对梯度向量进行计算的方式,就是反向传播。

2 PyTorch实现反向传播

# 3分类,500个样本,20个特征,共3层,第一层13个神经元,第二层8个神经元
# 第一层的激活函数时relu,第二层的激活函数是sigmoid

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F

#确定数据
torch.manual_seed(420)
X = torch.rand((500,20),dtype=torch.float32) * 100
y = torch.randint(low=0,high=3,size=(500,),dtype=torch.float32)
input_ = X.shape[1]         # 特征的数目    
output_ = len(y.unique())   # 分类的数目

# 定义神经网络的架构
# logsoftmax + NLLloss / CrossEntropyLoss
# BCE,BCEWithLogitsLoss
class Model(nn.Module):
    def __init__(self,in_features=40,out_features=2):
        super().__init__()
        self.linear1 = nn.Linear(in_features,13,bias=False)
        self.linear2 = nn.Linear(13,8,bias=False)
        self.output = nn.Linear(8,out_features,bias=True)
    
    def forward(self, x):
        sigma1 = torch.relu(self.linear1(x))
        sigma2 = torch.sigmoid(self.linear2(sigma1))
        zhat = self.output(sigma2)
        return zhat
    
torch.manual_seed(420)
net = Model(in_features=input_, out_features=output_)				# 实例化神经网络
zhat = net.forward(X)				# 向前传播

criterion = CrossEntropyLoss()      # 定义损失函数
loss = criterion(zhat, y.long())    # 计算损失函数

net.linear1.weight.grad #还没有梯度

loss.backward(retain_graph=True)    # 反向传播

net.linear1.weight.grad
net.linear1.weight.shape

三、移动坐标点

1 走出第一步

# 权重更新: w(t+1) = w(t) - 步长 * grad
lr = 0.1                         # learning_rate, 0.001,0.01,0.05
w = net.linear1.weight.data     #现有的权重,w(t)
dw = net.linear1.weight.grad    # 本轮梯度, grad

w = w - lr * dw     # 更新权重w

2 动量法Momentum:从第一步到第二步

提升梯度下降的速度。

#momentum
# v(t) = gamma * v(t-1)  - lr * dw
# w(t+1) = w(t) + v(t)
lr = 0.1
gamma = 0.9

w = net.linear1.weight.data     #现有的权重,w(t)
dw = net.linear1.weight.grad    # 本轮梯度, grad

#t = 1,走第一步,进行首次迭代的时候,需要一个v0
dw.shape        #500,20
v = torch.zeros(dw.shape