Pytorch | 从零构建ParNet/Non-Deep Networks对CIFAR10进行分类

发布于:2025-02-11 ⋅ 阅读:(81) ⋅ 点赞:(0)

前面文章我们构建了AlexNet、Vgg、GoogleNet、ResNet、MobileNet、EfficientNet对CIFAR10进行分类:
Pytorch | 从零构建AlexNet对CIFAR10进行分类
Pytorch | 从零构建Vgg对CIFAR10进行分类
Pytorch | 从零构建GoogleNet对CIFAR10进行分类
Pytorch | 从零构建ResNet对CIFAR10进行分类
Pytorch | 从零构建MobileNet对CIFAR10进行分类
Pytorch | 从零构建EfficientNet对CIFAR10进行分类
这篇文章我们来构建ParNet(Non-Deep Networks).

CIFAR10数据集

CIFAR-10数据集是由加拿大高级研究所(CIFAR)收集整理的用于图像识别研究的常用数据集,基本信息如下:

  • 数据规模:该数据集包含60,000张彩色图像,分为10个不同的类别,每个类别有6,000张图像。通常将其中50,000张作为训练集,用于模型的训练;10,000张作为测试集,用于评估模型的性能。
  • 图像尺寸:所有图像的尺寸均为32×32像素,这相对较小的尺寸使得模型在处理该数据集时能够相对快速地进行训练和推理,但也增加了图像分类的难度。
  • 类别内容:涵盖了飞机(plane)、汽车(car)、鸟(bird)、猫(cat)、鹿(deer)、狗(dog)、青蛙(frog)、马(horse)、船(ship)、卡车(truck)这10个不同的类别,这些类别都是现实世界中常见的物体,具有一定的代表性。

下面是一些示例样本:
在这里插入图片描述

ParNet

ParNet是一种高效的深度学习网络架构由谷歌研究人员于2021年提出,以下从其架构特点、优势及应用等方面进行详细介绍:

架构特点

  • 并行子结构:ParNet的核心在于其并行的子结构设计。它由多个并行的分支组成,每个分支都包含一系列的卷积层和池化层等操作。这些分支在网络中同时进行计算,就像多条并行的道路同时运输信息一样,大大提高了信息处理的效率。
  • 多尺度特征融合:不同分支在不同的尺度上对输入图像进行处理,然后将这些多尺度的特征进行融合。例如,一个分支可能专注于提取图像中的局部细节特征,而另一个分支则更擅长捕捉图像的全局上下文信息。通过融合这些不同尺度的特征,ParNet能够更全面、更准确地理解图像内容。
  • 深度可分离卷积:在网络的卷积操作中,大量使用了深度可分离卷积。这种卷积方式将传统的卷积操作分解为深度卷积和逐点卷积两个步骤,大大减少了计算量,同时提高了模型的运行速度,使其更适合在移动设备等资源受限的环境中应用。
    在这里插入图片描述

优势

  • 高效性:由于其并行结构和深度可分离卷积的使用,ParNet在计算效率上具有很大的优势。它可以在保证模型性能的前提下,大大减少模型的参数量和计算量,从而实现快速的推理和训练。
  • 灵活性:ParNet的并行子结构和多尺度特征融合方式使其具有很强的灵活性。它可以根据不同的任务和数据集进行调整和优化,轻松适应各种图像识别和处理任务。
  • 可扩展性:该网络架构具有良好的可扩展性,可以方便地增加或减少分支的数量和深度,以满足不同的性能需求。

应用

  • 图像分类:在图像分类任务中,ParNet能够快速准确地对图像中的物体进行分类。例如,在CIFAR-10和ImageNet等标准图像分类数据集上,ParNet取得了与现有先进模型相当的准确率,同时具有更快的推理速度。
  • 目标检测:在目标检测任务中,ParNet可以有效地检测出图像中的目标物体,并确定其位置和类别。通过对多尺度特征的融合和利用,ParNet能够更好地处理不同大小和形状的目标物体,提高检测的准确率和召回率。
  • 语义分割:在语义分割任务中,ParNet能够对图像中的每个像素进行分类,将图像分割成不同的语义区域。其多尺度特征融合的特点使得它在处理复杂的场景和物体边界时具有更好的效果,能够生成更准确的分割结果。

ParNet结构代码详解

结构代码

import torch
import torch.nn as nn
import torch.nn.functional as F

class SSE(nn.Module):
    def __init__(self, in_channels):
        super(SSE, self).__init__()
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(in_channels, in_channels)

    def forward(self, x):
        out = self.global_avgpool(x)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        out = torch.sigmoid(out)
        out = out.view(out.size(0), out.size(1), 1, 1)
        
        return x * out
    

class ParNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ParNetBlock, self).__init__()
        self.branch1x1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.branch3x3 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.sse = SSE(out_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)
        branch3x3 = self.branch3x3(x)
        out = branch1x1 + branch3x3
        out = self.sse(out)
        out = F.silu(out)

        return out
    

class DownsamplingBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DownsamplingBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.se = SSE(out_channels)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        out = self.se(out)

        return out
    

class FusionBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(FusionBlock, self).__init__()
        self.conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.se = SSE(out_channels)
        self.concat = nn.Conv2d(out_channels * 2, out_channels, kernel_size=1, bias=False)
    
    def forward(self, x1, x2):
        x1, x2 = self.conv1x1(x1), self.conv1x1(x2)
        x1, x2 = self.bn(x1), self.bn(x2)
        x1, x2 = self.relu(x1), self.relu(x2)
        x1, x2 = self.se(x1), self.se(x2)
        out = torch.cat([x1, x2], dim=1)
        out = self.concat(out)

        return out
    
class ParNet(nn.Module):
    def __init__(self, num_classes):
        super(ParNet, self).__init__()
        self.downsampling_blocks = nn.ModuleList([
            DownsamplingBlock(3, 64),
            DownsamplingBlock(64, 128),
            DownsamplingBlock(128, 256),
        ])

        self.streams = nn.ModuleList([
            nn.Sequential(
                ParNetBlock(64, 64),
                ParNetBlock(64, 64),
                ParNetBlock(64, 64),
                DownsamplingBlock(64, 128)
            ),
            nn.Sequential(
                ParNetBlock(128, 128),
                ParNetBlock(128, 128),
                ParNetBlock(128, 128),
                ParNetBlock(128, 128)
            ),
            nn.Sequential(
                ParNetBlock(256, 256),
                ParNetBlock(256, 256),
                ParNetBlock(256, 256),
                ParNetBlock(256, 256)
            )
        ])

        self.fusion_blocks = nn.ModuleList([
            FusionBlock(128, 256),
            FusionBlock(256, 256)
        ])

        self.final_downsampling = DownsamplingBlock(256, 1024)
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, x):
        downsampled_features = []
        for i, downsampling_block in enumerate(self.downsampling_blocks):
            x = downsampling_block(x)
            downsampled_features.append(x)

        stream_features = []
        for i, stream in enumerate(self.streams):
            stream_feature = stream(downsampled_features[i])
            stream_features.append(stream_feature)

        fused_features = stream_features[0]
        for i in range(1, len(stream_features)):
            fused_features = self.fusion_blocks[i - 1](fused_features, stream_features[i])

        x = self.final_downsampling(fused_features)
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

代码详解

以下是对上述提供的ParNet代码的详细解释,这段代码使用PyTorch框架构建了一个名为ParNet的神经网络模型,整体结构符合ParNet网络架构的特点,下面从不同模块依次进行分析:

SSE

class SSE(nn.Module):
    def __init__(self, in_channels):
        super(SSE, self).__init__()
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(in_channels, in_channels)

    def forward(self, x):
        out = self.global_avgpool(x)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        out = torch.sigmoid(out)
        out = out.view(out.size(0), out.size(1), 1, 1)
        
        return x * out
  • 功能概述
    这个类实现了类似Squeeze-and-Excitation(SE)模块的功能,旨在对输入特征进行通道维度的重加权,突出重要的通道特征,抑制相对不重要的通道特征。

  • __init__方法

    • 首先通过nn.AdaptiveAvgPool2d(1)创建了一个自适应平均池化层,它可以将输入特征图在空间维度上压缩为大小为(1, 1)的特征图,也就是将每个通道的特征进行全局平均池化,得到通道维度上的统计信息,无论输入特征图的尺寸是多少都可以自适应处理。
    • 接着创建了一个全连接层nn.Linear(in_channels, in_channels),其输入和输出维度都是in_channels,目的是学习通道维度上的变换权重。
  • forward方法

    • 先将输入x经过全局平均池化层得到压缩后的特征表示out,然后通过view操作将其维度调整为二维形式(批次大小,通道数),方便后续全连接层处理。
    • 接着将这个特征送入全连接层进行线性变换,再经过sigmoid激活函数,将输出值映射到(0, 1)区间,得到每个通道对应的权重。
    • 最后将权重的维度调整回四维(批次大小,通道数,1,1),并与原始输入x进行逐元素相乘,实现对不同通道特征的重加权。

ParNetBlock 类

class ParNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ParNetBlock, self).__init__()
        self.branch1x1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.branch3x3 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.sse = SSE(out_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)
        branch3x3 = self.branch3x3(x)
        out = branch1x1 + branch3x3
        out = self.sse(out)
        out = F.silu(out)

        return out
  • 功能概述
    该类定义了ParNet中的一个基础并行块结构,包含两个并行分支(1x1卷积分支和3x3卷积分支)以及一个SSE模块,用于提取和融合特征,并进行通道重加权和非线性激活。

  • __init__方法

    • 构建了两个并行分支,branch1x1是一个由1x1卷积层、批归一化层和ReLU激活函数组成的序列,1x1卷积主要用于调整通道维度,同时可以融合不同通道间的信息,且计算量相对较小。
    • branch3x3同样是由3x3卷积层(带有合适的填充保证特征图尺寸不变)、批归一化层和ReLU激活函数组成,3x3卷积能够捕捉局部空间特征信息。
    • 最后实例化了一个SSE模块,用于后续对融合后的特征进行通道维度的重加权。
  • forward方法

    • 首先将输入x分别送入两个并行分支进行处理,得到两个分支的输出branch1x1branch3x3,然后将它们对应元素相加进行特征融合。
    • 接着把融合后的特征送入SSE模块进行通道重加权,最后使用F.silu(也就是swish函数)激活函数对结果进行非线性激活,并返回处理后的特征。

DownsamplingBlock 类

class DownsamplingBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DownsamplingBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.se = SSE(out_channels)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        out = self.se(out)

        return out
  • 功能概述
    用于对输入特征图进行下采样操作,同时融合了批归一化、非线性激活以及类似SE的通道重加权功能,以减少特征图的空间尺寸并提取更抽象的特征。

  • __init__方法
    创建了一个3x3卷积层,其步长设置为2,配合合适的填充,在进行卷积操作时可以实现特征图在空间维度上长宽各减半的下采样效果,同时调整通道维度到out_channels。还定义了批归一化层、ReLU激活函数以及一个SSE模块。

  • forward方法
    按照顺序依次将输入x经过卷积层、批归一化层、ReLU激活函数进行处理,然后再通过SSE模块进行通道重加权,最终返回下采样并处理后的特征图。

FusionBlock 类

class FusionBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(FusionBlock, self).__init__()
        self.conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.se = SSE(out_channels)
        self.concat = nn.Conv2d(out_channels * 2, out_channels, kernel_size=1, bias=False)
    
    def forward(self, x1, x2):
        x1, x2 = self.conv1x1(x1), self.conv1x1(x2)
        x1, x2 = self.bn(x1), self.bn(x2)
        x1, x2 = self.relu(x1), self.relu(x2)
        x1, x2 = self.se(x1), self.se(x2)
        out = torch.cat([x1, x2], dim=1)
        out = self.concat(out)

        return out
  • 功能概述
    该类用于融合不同分支或不同阶段的特征,通过一系列操作包括调整通道维度、批归一化、激活以及通道重加权,然后将两个特征在通道维度上进行拼接并进一步融合。

  • __init__方法

    • 首先创建了1x1卷积层,步长设置为2,用于对输入的两个特征分别进行通道维度调整以及下采样操作(特征图空间尺寸减半)。
    • 接着定义了批归一化层、ReLU激活函数以及SSE模块,用于对下采样后的特征进行处理。还创建了一个1x1卷积层concat,用于将拼接后的特征进一步融合为指定的通道维度。
  • forward方法
    分别对输入的两个特征x1x2依次进行1x1卷积、批归一化、ReLU激活以及SSE模块的处理,然后将它们在通道维度上进行拼接(torch.cat操作,维度dim=1表示按通道维度拼接),最后通过concat卷积层将拼接后的特征融合为指定的通道维度,并返回融合后的特征。

ParNet 类

class ParNet(nn.Module):
    def __init__(self, num_classes):
        super(ParNet, self).__init__()
        self.downsampling_blocks = nn.ModuleList([
            DownsamplingBlock(3, 64),
            DownsamplingBlock(64, 128),
            DownsamplingBlock(128, 256),
        ])

        self.streams = nn.ModuleList([
            nn.Sequential(
                ParNetBlock(64, 64),
                ParNetBlock(64, 64),
                ParNetBlock(64, 64),
                DownsamplingBlock(64, 128)
            ),
            nn.Sequential(
                ParNetBlock(128, 128),
                ParNetBlock(128, 128),
                ParNetBlock(128, 128),
                ParNetBlock(128, 128)
            ),
            nn.Sequential(
                ParNetBlock(256, 256),
                ParNetBlock(256, 256),
                ParNetBlock(256, 256),
                ParNetBlock(256, 256)
            )
        ])

        self.fusion_blocks = nn.ModuleList([
            FusionBlock(128, 256),
            FusionBlock(256, 256)
        ])

        self.final_downsampling = DownsamplingBlock(256, 1024)
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, x):
        downsampled_features = []
        for i, downsampling_block in enumerate(self.downsampling_blocks):
            x = downsampling_block(x)
            downsampled_features.append(x)

        stream_features = []
        for i, stream in enumerate(self.streams):
            stream_feature = stream(downsampled_features[i])
            stream_features.append(stream_feature)

        fused_features = stream_features[0]
        for i in range(1, len(stream_features)):
            fused_features = self.fusion_blocks[i - 1](fused_features, stream_features[i])

        x = self.final_downsampling(fused_features)
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
  • 功能概述
    这是整个ParNet网络的定义类,整合了前面定义的各个模块,构建出完整的网络结构,包括下采样、并行分支处理、特征融合以及最后的分类全连接层等部分,能够接收输入图像数据并输出对应的分类预测结果。

  • __init__方法

    • downsampling_blocks:通过nn.ModuleList创建了一个包含三个下采样块的列表,用于对输入图像依次进行下采样,将图像的空间尺寸逐步缩小,同时增加通道数,从最初的3通道(对应RGB图像)逐步变为64128256通道。
    • streams:同样是nn.ModuleList,定义了三个并行的流(stream),每个流由多个ParNetBlock和一个DownsamplingBlock组成,不同流在不同的特征图尺度和通道维度上进行特征提取和处理,每个流内部的ParNetBlock用于提取和融合局部特征,最后的DownsamplingBlock用于进一步下采样。
    • fusion_blocks:也是nn.ModuleList,包含两个特征融合块,用于融合不同流的特征,将各个流提取到的不同层次的特征进行融合,以综合利用多尺度信息。
    • final_downsampling:定义了一个下采样块,用于对融合后的特征再进行一次下采样,将通道数提升到1024,进一步提取更抽象的全局特征。
    • fc:创建了一个全连接层,用于将最终提取到的特征映射到指定的类别数量num_classes,实现图像分类任务的输出。
  • forward方法

    • 首先,通过循环将输入x依次经过各个下采样块进行下采样,并将每次下采样后的特征保存到downsampled_features列表中,得到不同阶段下采样后的特征图。
    • 接着,针对每个流,将对应的下采样后的特征图送入流中进行处理,每个流内部的模块会进一步提取和融合特征,得到每个流输出的特征,并保存在stream_features列表中。
    • 然后,先取第一个流的特征作为初始的融合特征,再通过循环依次使用特征融合块将其他流的特征与已有的融合特征进行融合,不断更新融合特征。
    • 之后,将融合后的特征送入最后的下采样块进行进一步下采样处理。
    • 再通过自适应平均池化F.adaptive_avg_pool2d将特征图在空间维度上压缩为(1, 1)大小,然后使用view操作将其展平为二维向量。
    • 最后将展平后的特征送入全连接层进行分类预测,返回最终的分类结果。

总体而言,这段代码构建了一个符合ParNet架构特点的神经网络模型,通过多个模块的组合实现了高效的特征提取、融合以及分类功能,可应用于图像分类等相关任务。

训练过程和测试结果

训练过程损失函数变化曲线:
在这里插入图片描述

训练过程准确率变化曲线:
在这里插入图片描述

测试结果:
在这里插入图片描述

代码汇总

项目github地址
项目结构:

|--data
|--models
	|--__init__.py
	|-parnet.py
	|--...
|--results
|--weights
|--train.py
|--test.py

parnet.py

import torch
import torch.nn as nn
import torch.nn.functional as F

class SSE(nn.Module):
    def __init__(self, in_channels):
        super(SSE, self).__init__()
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(in_channels, in_channels)

    def forward(self, x):
        out = self.global_avgpool(x)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        out = torch.sigmoid(out)
        out = out.view(out.size(0), out.size(1), 1, 1)
        
        return x * out
    

class ParNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ParNetBlock, self).__init__()
        self.branch1x1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.branch3x3 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.sse = SSE(out_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)
        branch3x3 = self.branch3x3(x)
        out = branch1x1 + branch3x3
        out = self.sse(out)
        out = F.silu(out)

        return out
    

class DownsamplingBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DownsamplingBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.se = SSE(out_channels)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        out = self.se(out)

        return out
    

class FusionBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(FusionBlock, self).__init__()
        self.conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.se = SSE(out_channels)
        self.concat = nn.Conv2d(out_channels * 2, out_channels, kernel_size=1, bias=False)
    
    def forward(self, x1, x2):
        x1, x2 = self.conv1x1(x1), self.conv1x1(x2)
        x1, x2 = self.bn(x1), self.bn(x2)
        x1, x2 = self.relu(x1), self.relu(x2)
        x1, x2 = self.se(x1), self.se(x2)
        out = torch.cat([x1, x2], dim=1)
        out = self.concat(out)

        return out
    
class ParNet(nn.Module):
    def __init__(self, num_classes):
        super(ParNet, self).__init__()
        self.downsampling_blocks = nn.ModuleList([
            DownsamplingBlock(3, 64),
            DownsamplingBlock(64, 128),
            DownsamplingBlock(128, 256),
        ])

        self.streams = nn.ModuleList([
            nn.Sequential(
                ParNetBlock(64, 64),
                ParNetBlock(64, 64),
                ParNetBlock(64, 64),
                DownsamplingBlock(64, 128)
            ),
            nn.Sequential(
                ParNetBlock(128, 128),
                ParNetBlock(128, 128),
                ParNetBlock(128, 128),
                ParNetBlock(128, 128)
            ),
            nn.Sequential(
                ParNetBlock(256, 256),
                ParNetBlock(256, 256),
                ParNetBlock(256, 256),
                ParNetBlock(256, 256)
            )
        ])

        self.fusion_blocks = nn.ModuleList([
            FusionBlock(128, 256),
            FusionBlock(256, 256)
        ])

        self.final_downsampling = DownsamplingBlock(256, 1024)
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, x):
        downsampled_features = []
        for i, downsampling_block in enumerate(self.downsampling_blocks):
            x = downsampling_block(x)
            downsampled_features.append(x)

        stream_features = []
        for i, stream in enumerate(self.streams):
            stream_feature = stream(downsampled_features[i])
            stream_features.append(stream_feature)

        fused_features = stream_features[0]
        for i in range(1, len(stream_features)):
            fused_features = self.fusion_blocks[i - 1](fused_features, stream_features[i])

        x = self.final_downsampling(fused_features)
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

train.py

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from models import *
import matplotlib.pyplot as plt

import ssl
ssl._create_default_https_context = ssl._create_unverified_context

# 定义数据预处理操作
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])

# 加载CIFAR10训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)

# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 实例化模型
model_name = 'ParNet'
if model_name == 'AlexNet':
    model = AlexNet(num_classes=10).to(device)
elif model_name == 'Vgg_A':
    model = Vgg(cfg_vgg='A', num_classes=10).to(device)
elif model_name == 'Vgg_A-LRN':
    model = Vgg(cfg_vgg='A-LRN', num_classes=10).to(device)
elif model_name == 'Vgg_B':
    model = Vgg(cfg_vgg='B', num_classes=10).to(device)
elif model_name == 'Vgg_C':
    model = Vgg(cfg_vgg='C', num_classes=10).to(device)
elif model_name == 'Vgg_D':
    model = Vgg(cfg_vgg='D', num_classes=10).to(device)
elif model_name == 'Vgg_E':
    model = Vgg(cfg_vgg='E', num_classes=10).to(device)
elif model_name == 'GoogleNet':
    model = GoogleNet(num_classes=10).to(device)
elif model_name == 'ResNet18':
    model = ResNet18(num_classes=10).to(device)
elif model_name == 'ResNet34':
    model = ResNet34(num_classes=10).to(device)
elif model_name == 'ResNet50':
    model = ResNet50(num_classes=10).to(device)
elif model_name == 'ResNet101':
    model = ResNet101(num_classes=10).to(device)
elif model_name == 'ResNet152':
    model = ResNet152(num_classes=10).to(device)
elif model_name == 'MobileNet':
    model = MobileNet(num_classes=10).to(device)
elif model_name == 'EfficientNet':
    model = EfficientNet(num_classes=10).to(device)
elif model_name == 'ParNet':
    model = ParNet(num_classes=10).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练轮次
epochs = 15

def train(model, trainloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / len(trainloader)
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc

if __name__ == "__main__":
    loss_history, acc_history = [], []
    for epoch in range(epochs):
        train_loss, train_acc = train(model, trainloader, criterion, optimizer, device)
        print(f'Epoch {epoch + 1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        loss_history.append(train_loss)
        acc_history.append(train_acc)
        # 保存模型权重,每5轮次保存到weights文件夹下
        if (epoch + 1) % 5 == 0:
            torch.save(model.state_dict(), f'weights/{model_name}_epoch_{epoch + 1}.pth')
    
    # 绘制损失曲线
    plt.plot(range(1, epochs+1), loss_history, label='Loss', marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss Curve')
    plt.legend()
    plt.savefig(f'results\\{model_name}_train_loss_curve.png')
    plt.close()

    # 绘制准确率曲线
    plt.plot(range(1, epochs+1), acc_history, label='Accuracy', marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.title('Training Accuracy Curve')
    plt.legend()
    plt.savefig(f'results\\{model_name}_train_acc_curve.png')
    plt.close()

test.py

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from models import *

import ssl
ssl._create_default_https_context = ssl._create_unverified_context
# 定义数据预处理操作
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])

# 加载CIFAR10测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=2)

# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 实例化模型
model_name = 'ParNet'
if model_name == 'AlexNet':
    model = AlexNet(num_classes=10).to(device)
elif model_name == 'Vgg_A':
    model = Vgg(cfg_vgg='A', num_classes=10).to(device)
elif model_name == 'Vgg_A-LRN':
    model = Vgg(cfg_vgg='A-LRN', num_classes=10).to(device)
elif model_name == 'Vgg_B':
    model = Vgg(cfg_vgg='B', num_classes=10).to(device)
elif model_name == 'Vgg_C':
    model = Vgg(cfg_vgg='C', num_classes=10).to(device)
elif model_name == 'Vgg_D':
    model = Vgg(cfg_vgg='D', num_classes=10).to(device)
elif model_name == 'Vgg_E':
    model = Vgg(cfg_vgg='E', num_classes=10).to(device)
elif model_name == 'GoogleNet':
    model = GoogleNet(num_classes=10).to(device)
elif model_name == 'ResNet18':
    model = ResNet18(num_classes=10).to(device)
elif model_name == 'ResNet34':
    model = ResNet34(num_classes=10).to(device)
elif model_name == 'ResNet50':
    model = ResNet50(num_classes=10).to(device)
elif model_name == 'ResNet101':
    model = ResNet101(num_classes=10).to(device)
elif model_name == 'ResNet152':
    model = ResNet152(num_classes=10).to(device)
elif model_name == 'MobileNet':
    model = MobileNet(num_classes=10).to(device)
elif model_name == 'EfficientNet':
    model = EfficientNet(num_classes=10).to(device)
elif model_name == 'ParNet':
    model = ParNet(num_classes=10).to(device)

criterion = nn.CrossEntropyLoss()

# 加载模型权重
weights_path = f"weights/{model_name}_epoch_15.pth"  
model.load_state_dict(torch.load(weights_path, map_location=device))

def test(model, testloader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / len(testloader)
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc

if __name__ == "__main__":
    test_loss, test_acc = test(model, testloader, criterion, device)
    print(f"================{model_name} Test================")
    print(f"Load Model Weights From: {weights_path}")
    print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')

网站公告

今日签到

点亮在社区的每一天
去签到