新的网络结构超越残差网络:记忆网络与残差网络的对比实践

发布于:2024-07-29 ⋅ 阅读:(104) ⋅ 点赞:(0)

摘要

本文将带领大家了解一种新的网络结构——记忆网络,并通过与残差网络的对比实验,探讨其在MNIST数据集上的性能表现。让我们一起来看看记忆网络是否能超越残差网络。

正文:

一、引言

近年来,深度学习领域取得了飞速发展,各种新型网络结构层出不穷。残差网络(ResNet)因其卓越的性能和易于训练的特点,在众多任务中取得了优异的成绩。然而,研究者们从未停止对新网络结构的探索。本文将介绍一种新的网络结构——记忆网络(MemoryBlock),并将其与残差网络进行对比实践。

二、记忆网络简介

记忆网络是一种具有记忆功能的新型网络结构,它通过矩阵乘法实现输入与记忆的交互。记忆网络的核心思想是利用记忆矩阵存储历史信息,从而提高模型对数据的处理能力。在本实验中,我们采用了简化的记忆网络结构,如下所示:

class MemoryBlock(torch.nn.Module):
    def __init__(self, hidden_dim):
        super(MemoryBlock, self).__init__()

        # 使用Xavier初始化权重
        self.fc = torch.nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        torch.nn.init.xavier_uniform_(self.fc)

        self.mem = torch.nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        torch.nn.init.xavier_uniform_(self.mem)

        self.out = torch.nn.Linear(hidden_dim, 10, bias=True)
        self.sig = torch.nn.Sigmoid()

    def forward(self, x):
        x = x @ (self.fc + self.mem)
        x = self.out(self.sig(x))
        return x + x

三、残差网络简介

残差网络(ResNet)是一种经典的深度学习网络结构,通过引入残差单元,有效地解决了深度神经网络训练过程中的梯度消失和梯度爆炸问题。残差网络的核心思想是让网络学习残差映射,而非直接学习原始映射。在本实验中,我们采用了如下残差网络结构:

class ResidualBlock(torch.nn.Module):
    def __init__(self, hidden_dim):
        super(ResidualBlock, self).__init__()
        self.fc = torch.nn.Linear(hidden_dim, hidden_dim)
        self.fc1 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.out = torch.nn.Linear(hidden_dim, 10, bias=True)
        self.sig = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.fc(x)
        x = self.fc1(x) + x
        x = self.out(self.sig(x))
        return x

四、实验过程

  1. 数据准备:我们使用MNIST数据集进行实验,该数据集包含了手写数字0-9的灰度图像。
  2. 模型训练:我们分别对记忆网络和残差网络进行训练,使用交叉熵损失函数和Adam优化器。
  3. 性能对比:在训练过程中,我们记录了两种网络结构的损失值和准确率,并在测试集上进行性能对比。

五、实验结果

  1. 记忆网络与残差网络的损失值对比:
    通过实验,我们发现记忆网络在训练过程中的损失值逐渐下降,最终趋于稳定。同样,残差网络的损失值也呈现出相似的趋势。以下是两种网络结构的损失值对比图:
    在这里插入图片描述
  2. 记忆网络与残差网络的准确率对比:
    在测试集上,记忆网络和残差网络都取得了较高的准确率。以下是两种网络结构的准确率对比图:
    在这里插入图片描述

六、总结

通过本次实验,我们发现记忆网络在MNIST数据集上的表现与残差网络相当。虽然记忆网络在某些方面具有一定的优势,但要想完全超越残差网络,还需进一步优化和改进。未来,我们将继续探索新的网络结构,为深度学习领域的发展贡献力量。
感谢您的阅读,希望本文对您有所帮助!如有疑问,请随时留言交流。

import numpy as np
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision import datasets, transforms


# pip install torch -i https://pypi.tuna.tsinghua.edu.cn/simple
# 超越残差的网络结构
class ResidualBlock(torch.nn.Module):
    def __init__(self, hidden_dim):
        super(ResidualBlock, self).__init__()
        self.fc = torch.nn.Linear(hidden_dim, hidden_dim)
        self.fc1 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.out = torch.nn.Linear(hidden_dim, 10, bias=True)
        self.sig = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.fc(x)
        x = self.fc1(x) + x
        x = self.out(self.sig(x))
        return x


# 记忆网络结构
class MemoryBlock(torch.nn.Module):
    def __init__(self, hidden_dim):
        super(MemoryBlock, self).__init__()

        # 使用Xavier初始化权重
        self.fc = torch.nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        torch.nn.init.xavier_uniform_(self.fc)

        self.mem = torch.nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        torch.nn.init.xavier_uniform_(self.mem)

        self.out = torch.nn.Linear(hidden_dim, 10, bias=True)
        self.sig = torch.nn.Sigmoid()

    def forward(self, x):
        x = x @ (self.fc + self.mem)
        x = self.out(self.sig(x))
        return x + x


# 定义数据转换操作,将图像转换为张量并归一化
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# 加载训练数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

# 加载测试数据集
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

# 训练模型
model = MemoryBlock(784)
loss_f = torch.nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr=0.0003)
mem_loss = []
mem_acc = []
bar = tqdm(range(10))
for epoch in bar:
    for data in train_loader:
        inputs, labels = data
        opt.zero_grad()
        outputs = model(inputs.reshape([inputs.shape[0], 1, -1]))
        loss = loss_f(outputs.reshape([inputs.shape[0], -1]), labels)
        mem_loss.append(loss.item())
        mem_acc.append(np.mean((torch.argmax(outputs, -1).reshape(-1) == labels).numpy()))
        bar.set_description("epoch___{}____loss____{}____acc___{}".format(epoch, np.mean(mem_loss), np.mean(mem_acc)))
        # mem_acc.append(accuracy(outputs, labels))
        loss.backward()
        opt.step()
bar = tqdm(range(1))
mem_loss = []
mem_acc = []
for epoch in bar:
    for data in test_loader:
        inputs, labels = data
        opt.zero_grad()
        outputs = model(inputs.reshape([inputs.shape[0], 1, -1]))
        loss = loss_f(outputs.reshape([inputs.shape[0], -1]), labels)
        mem_loss.append(loss.item())
        mem_acc.append(np.mean((torch.argmax(outputs, -1).reshape(-1) == labels).numpy()))
        bar.set_description("epoch___{}____loss____{}____acc___{}".format(epoch, np.mean(mem_loss), np.mean(mem_acc)))

# 训练模型
model = ResidualBlock(784)
loss_f = torch.nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr=0.0003)
res_loss = []
res_acc = []
bar = tqdm(range(10))
for epoch in bar:
    for data in train_loader:
        inputs, labels = data
        opt.zero_grad()
        outputs = model(inputs.reshape([inputs.shape[0], 1, -1]))
        loss = loss_f(outputs.reshape([inputs.shape[0], -1]), labels)
        res_loss.append(loss.item())
        res_acc.append(np.mean((torch.argmax(outputs, -1).reshape(-1) == labels).numpy()))
        bar.set_description("epoch___{}____loss____{}____acc___{}".format(epoch, np.mean(res_loss), np.mean(res_acc)))
        # mem_acc.append(accuracy(outputs, labels))
        loss.backward()
        opt.step()
bar = tqdm(range(1))
res_loss = []
res_acc = []
for epoch in bar:
    for data in test_loader:
        inputs, labels = data
        opt.zero_grad()
        outputs = model(inputs.reshape([inputs.shape[0], 1, -1]))
        loss = loss_f(outputs.reshape([inputs.shape[0], -1]), labels)
        res_loss.append(loss.item())
        res_acc.append(np.mean((torch.argmax(outputs, -1).reshape(-1) == labels).numpy()))
        bar.set_description("epoch___{}____loss____{}____acc___{}".format(epoch, np.mean(res_loss), np.mean(res_acc)))

plt.plot(mem_loss)
plt.plot(mem_acc)

plt.plot(res_loss)
plt.plot(res_acc)
plt.legend(["ml", "ma", "rl", "ra"])
plt.show()


网站公告

今日签到

点亮在社区的每一天
去签到