Semantic Segmentation: DeepLabV3+

Published: 2024-12-18

DeepLabV3+ is an advanced deep learning model for semantic segmentation. It adopts an encoder-decoder structure: the encoder extracts image features, and the decoder maps those features back to the original image resolution to produce pixel-level classifications. Concretely, a backbone network (the paper experiments with ResNet-101 and Xception) performs feature extraction, split into a high-level semantic branch and a low-level detail branch. The high-level features are then passed through the ASPP (Atrous Spatial Pyramid Pooling) module, which uses dilated convolutions to improve feature extraction across different scales. Finally, the decoder recovers spatial detail and produces the segmentation result. The overall pipeline is described below.

The core component is the ASPP module, i.e. the atrous spatial pyramid pooling module. The model's defining idea is to use dilated convolutions to extract features at multiple scales, concatenate those multi-scale features, fuse them with shallow features, and upsample to obtain the prediction. The concrete steps are:
  1. The input image goes through a backbone (a CNN such as ResNet or Xception) for feature extraction;
  2. Two feature maps are taken from the backbone: a relatively shallow feature x1 and a relatively deep feature x2;
  3. The deep feature x2 is fed into the ASPP module, which has five branches:
    1. The first branch applies a 1x1 convolution, keeping the spatial size, to produce a feature map;
    2. The second branch applies a 3x3 convolution with dilation rate 6 and padding equal to the dilation rate, keeping the spatial size;
    3. The third branch applies a 3x3 convolution with dilation rate 12 and padding equal to the dilation rate, keeping the spatial size;
    4. The fourth branch applies a 3x3 convolution with dilation rate 18 and padding equal to the dilation rate, keeping the spatial size;
    5. The fifth branch applies average pooling followed by a 1x1 convolution to adjust the channel count;
    6. The five branch outputs are concatenated along the channel dimension;
    7. The concatenated features go through a 1x1 convolution, yielding the final deep feature map x3;
  4. The shallow feature x1 goes through a 1x1 convolution, yielding feature map x4;
  5. The deep feature map x3 is upsampled to the same spatial size as the shallow feature x1; call the result x5;
  6. The shallow feature x4 and the deep feature x5 are concatenated along the channel dimension;
  7. A 3x3 convolution then produces the classification result;
  8. A final upsampling restores the original input size, giving the segmentation result.
Dilated convolution is well covered elsewhere online. In short, dilated convolution (also called atrous convolution) is a convolution variant designed to enlarge the receptive field.

With a dilation rate of 1 it is an ordinary convolution, and a 3x3 kernel has a 3x3 receptive field. With a dilation rate of 2, one pixel is skipped between kernel elements, so the region the kernel covers grows: the equivalent kernel becomes 5x5, and stacked on top of the previous layer the receptive field reaches 7x7. With a dilation rate of 4, three pixels are skipped between kernel elements, the equivalent kernel becomes 9x9, and the stacked receptive field grows to 15x15. In other words, dilated convolution enlarges the receptive field without increasing the number of kernel parameters. DeepLabV3+ exploits exactly this to capture features at multiple scales.
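As a quick check (a minimal sketch, not part of the original implementation), the snippet below verifies that a 3x3 convolution with padding equal to the dilation rate preserves the spatial size, which is exactly the setting used in the ASPP branches, and prints the equivalent kernel size k + (k-1)(d-1):

import torch
import torch.nn as nn

x = torch.randn(1, 64, 33, 33)
for d in (1, 6, 12, 18):
    conv = nn.Conv2d(64, 64, kernel_size=3, padding=d, dilation=d)
    # equivalent kernel size of a 3x3 kernel with dilation d: 3 + 2*(d-1)
    k_eff = 3 + 2 * (d - 1)
    print(d, k_eff, tuple(conv(x).shape))  # spatial size stays 33x33 for every rate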

A PyTorch implementation of DeepLabV3+ follows:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet101  # other backbones are possible

class ASPPModule(nn.Module):
    def __init__(self, in_channels, out_channels, dilations):
        super(ASPPModule, self).__init__()
        self.branches = nn.ModuleList()
        self.branches.append(
            # image pooling branch; note this uses a size-preserving 3x3
            # average pool rather than the paper's global image-level pooling
            nn.Sequential(nn.AvgPool2d(3, 1, 1),
                          nn.Conv2d(in_channels, out_channels, 1, 1),
                          nn.BatchNorm2d(out_channels),
                          nn.ReLU(inplace=True))
        )
        # four dilated convolution branches (dilation=1 stands in for the paper's 1x1 branch)
        for d in dilations:
            self.branches.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, 3, 1, dilation=d, padding=d),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True)
                )
            )
        # 1x1 convolution to fuse the concatenated branch outputs
        self.conv_bn_relu = nn.Sequential(
            nn.Conv2d((len(dilations)+1) * out_channels, out_channels, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5)
        )

    def forward(self, x):
        size = x.size()[2:]
        print("size: ",size)
        features = []
        # run every branch and resize each output to a common spatial size
        for i in range(len(self.branches)):
            out = self.branches[i](x)
            print("out.shape: ",out.shape)
            out = F.interpolate(out, size=size, mode='bilinear', align_corners=True)
            print("upsample out.shape: ",out.shape)
            features.append(out)
        # concatenate the five branch outputs along the channel dimension
        features = torch.cat(features, dim=1)
        return self.conv_bn_relu(features)
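
# A quick standalone check of the ASPP module (illustrative, left commented out
# so that running this file still prints exactly the log shown below):
#   aspp = ASPPModule(2048, 256, [1, 6, 12, 18])
#   print(aspp(torch.randn(1, 2048, 33, 33)).shape)  # -> torch.Size([1, 256, 33, 33])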

# Kaiming initialization for the newly added layers
def initialize_weights(*models):
    for model in models:
        for module in model.modules():
            if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight)  # in-place variant; the non-underscore form is deprecated
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

class DeepLabV3Plus(nn.Module):
    def __init__(self, n_classes=21, backbone='resnet101', output_stride=16):
        super(DeepLabV3Plus, self).__init__()
        if backbone == 'resnet101':
            # use the new weights API; the pretrained= argument is deprecated and triggers a warning
            #self.backbone = resnet101(pretrained=False)
            self.backbone = resnet101(weights="IMAGENET1K_V1")
            # adapt ResNet for DeepLabV3+:
            # drop the final average-pooling and fully connected classification layers
            self.backbone = nn.Sequential(*list(self.backbone.children())[:-2])
            self.first = self.backbone[0:3]  # conv1 + bn1 + relu; the maxpool at index 3 is skipped, keeping the overall output stride at 16
            self.layer1 = self.backbone[4]
            self.layer2 = self.backbone[5]
            self.layer3 = self.backbone[6]
            self.layer4 = self.backbone[7]
        else:
            raise ValueError('Unsupported backbone - `{}`, Use resnet101'.format(backbone))

        self.aspp = ASPPModule(2048, 256, [1, 6, 12, 18])
        self.conv1x1 = nn.Conv2d(256, 48, 1, 1)  # reduce the ASPP output to 48 channels
        # 2x upsampling via transposed convolution (stride=2)
        self.upsample2 = nn.ConvTranspose2d(48, 48, 4, stride=2, padding=1)
        self.low_level_conv = nn.Sequential(
            nn.Conv2d(256, 48, 1, 1),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True)
        )
        self.final_conv = nn.Conv2d(96, n_classes, 3, 1, 1)

        # initialize only the new layers; re-initializing self.backbone here
        # would overwrite the pretrained ImageNet weights
        initialize_weights(self.aspp, self.conv1x1, self.low_level_conv, self.final_conv)

    def forward(self, x):
        # feature maps from the backbone
        c2, c3, c4, c5 = self._forward_backbone(x)
        size0 = x.size()[2:]
        print("size:",size0)
        # ASPP module
        features = self.aspp(c5)
        print("features.shape: ", features.shape)
        features = self.conv1x1(features)
        print("features.shape: ", features.shape)
        features = self.upsample2(features)
        print("features.shape: ", features.shape)

        # fuse with the low-level features
        low_level_features = self.low_level_conv(c3)
        size = low_level_features.size()[2:]
        features = F.interpolate(features, size=size, mode='bilinear', align_corners=True)
        features = torch.cat([features, low_level_features], dim=1)

        # final classification layer
        output = self.final_conv(features)
        # upsample to the original input resolution
        output = F.interpolate(output, size=size0, mode='bilinear', align_corners=True)
        return output

    def _forward_backbone(self, x):
        c2 = self.first(x)
        c3 = self.layer1(c2)
        c4 = self.layer2(c3)
        c5 = self.layer3(c4)
        c5 = self.layer4(c5)
        print("c2.shape: {}".format(c2.shape))
        print("c3.shape: {}".format(c3.shape))
        print("c4.shape: {}".format(c4.shape))
        print("c5.shape: {}".format(c5.shape))
        return c2, c3, c4, c5

# Example usage
model = DeepLabV3Plus(n_classes=21)  # number of classes in the Pascal VOC dataset
input_tensor = torch.randn(1, 3, 513, 513)  # example input: batch size 1, 3 channels, 513x513
output = model(input_tensor)
print(output.shape)  # expected shape [1, 21, 513, 513]: a class prediction for every pixel
# Output:
c2.shape: torch.Size([1, 64, 257, 257])
c3.shape: torch.Size([1, 256, 257, 257])
c4.shape: torch.Size([1, 512, 129, 129])
c5.shape: torch.Size([1, 2048, 33, 33])
size: torch.Size([513, 513])
size:  torch.Size([33, 33])
out.shape:  torch.Size([1, 256, 33, 33])
upsample out.shape:  torch.Size([1, 256, 33, 33])
out.shape:  torch.Size([1, 256, 33, 33])
upsample out.shape:  torch.Size([1, 256, 33, 33])
out.shape:  torch.Size([1, 256, 33, 33])
upsample out.shape:  torch.Size([1, 256, 33, 33])
out.shape:  torch.Size([1, 256, 33, 33])
upsample out.shape:  torch.Size([1, 256, 33, 33])
out.shape:  torch.Size([1, 256, 33, 33])
upsample out.shape:  torch.Size([1, 256, 33, 33])
features.shape:  torch.Size([1, 256, 33, 33])
features.shape:  torch.Size([1, 48, 33, 33])
features.shape:  torch.Size([1, 48, 66, 66])

torch.Size([1, 21, 513, 513])

The model was trained on the GID remote-sensing image interpretation dataset with a learning rate of 0.01, a batch_size of 8, and 100 epochs, reaching an overall accuracy of 0.847.
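For reference, here is a minimal sketch of that training setup (assumptions not shown in the original post: a train_loader yielding (image, mask) pairs for GID, CUDA available, plain SGD with momentum, and 6 classes including background as in the log below):

import torch
import torch.nn as nn

# Minimal sketch of the stated setup: lr=0.01, batch_size=8, 100 epochs.
# `train_loader` is an assumed DataLoader over (image, mask) pairs.
model = DeepLabV3Plus(n_classes=6).cuda()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

for epoch in range(100):
    model.train()
    for images, masks in train_loader:
        images, masks = images.cuda(), masks.cuda().long()
        optimizer.zero_grad()
        loss = criterion(model(images), masks)
        loss.backward()
        optimizer.step()

The per-class accuracies from the training log: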

2024-12-12 13:22:26,051 - __main__ - DEBUG - --------------------------------------
2024-12-12 13:22:26,051 - __main__ - DEBUG - |0|background|0.84|
2024-12-12 13:22:26,051 - __main__ - DEBUG - |1|building|0.78|
2024-12-12 13:22:26,052 - __main__ - DEBUG - |2|farmland|0.83|
2024-12-12 13:22:26,052 - __main__ - DEBUG - |3|tree|0.21|
2024-12-12 13:22:26,052 - __main__ - DEBUG - |4|grass|0.37|
2024-12-12 13:22:26,052 - __main__ - DEBUG - |5|water|0.75|
2024-12-12 13:22:26,052 - __main__ - DEBUG - --------------------------------------
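Per-class pixel accuracy of this kind is typically derived from a confusion matrix; the helper below is a small illustrative sketch (hypothetical, not taken from the original training script):

import numpy as np

# Hypothetical helper: per-class pixel accuracy from flattened integer label
# arrays. Row i of the confusion matrix counts ground-truth pixels of class i;
# the diagonal counts the correctly classified ones.
def per_class_accuracy(pred, target, n_classes):
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    np.add.at(cm, (target.ravel(), pred.ravel()), 1)
    return cm.diagonal() / np.maximum(cm.sum(axis=1), 1)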

The segmentation results are shown below:

