YOLOv8改进 | 图像去噪篇 | 一种基于注意力机制的图像去噪网络ADNet融合YOLOv8(全网独家首发)

发布于:2024-05-17 ⋅ 阅读:(164) ⋅ 点赞:(0)

一、本文介绍

本文给大家带来的改进机制是Attention-guided Denoising Convolutional Neural Network (ADNet) 是一种专为图像去噪设计的深度学习模型,旨在解决合成噪声图像、真实噪声图像和盲去噪的挑战。它通过注意力机制提升性能,聚焦于相关特征,抑制无关噪声。其主要由四个模块组稀疏块(Sparse Block, SB)、特征增强块(Feature Enhancement Block, FEB)注意力块(Attention Block, AB)重建块(Reconstruction Block, RB)。本文内容为包含代码加解释加添加教程以及运行记录!

欢迎大家订阅我的专栏一起学习YOLO!  


目录

一、本文介绍

二、ADNet网络介绍

三、ADNet核心代码

四、添加方式

4.1 修改一

4.2 修改二 

4.3 修改三 

4.4 修改四 

五、ADNet的yaml文件和运行记录

5.1 ADNet的yaml文件

5.2 训练代码 

5.3 ADNet的训练过程截图 

五、本文总结


二、ADNet网络介绍

官方论文地址: 官方论文地址点击此处即可跳转

官方代码地址: 官方代码地址点击此处即可跳转


### ADNet 介绍

Attention-guided Denoising Convolutional Neural Network (ADNet) 是一种专为图像去噪设计的深度学习模型,旨在解决合成噪声图像、真实噪声图像和盲去噪的挑战。它通过注意力机制提升性能,聚焦于相关特征,抑制无关噪声。

ADNet 的架构

ADNet 的架构由四个主要模块组成:

1. 稀疏块(Sparse Block, SB)
2. 特征增强块(Feature Enhancement Block, FEB)
3. 注意力块(Attention Block, AB)
4. 重建块(Reconstruction Block, RB)

1. 稀疏块(SB)

稀疏块旨在平衡性能和效率。它结合了膨胀卷积和普通卷积,有效去除噪声的同时保持计算效率。膨胀卷积有助于扩展感受野,不显著增加计算负载,这对于捕捉更多上下文信息至关重要。

2. 特征增强块(FEB)

特征增强块通过长路径整合全局和局部特征,增强模型的表达能力。通过捕捉局部细节和全局上下文,FEB 提高了模型从噪声输入中重建干净图像的能力。这种整合使得模型能够更好地理解和处理图像结构,从而提高去噪性能。

3. 注意力块(AB)

注意力块是 ADNet 处理复杂噪声图像的核心。它细致地提取复杂背景中隐藏的噪声信息,使其特别适用于真实噪声图像和盲去噪任务。注意力机制使得模型能够聚焦于图像的最相关部分,减少噪声的影响,提高去噪图像的整体质量。

4. 重建块(RB)

重建块旨在通过结合获取的噪声映射和给定的噪声图像来构建干净图像。它通过利用前几个模块增强的特征,生成去噪输出。此模块确保最终输出在去除噪声的同时保持原始图像的结构。

性能与评估

ADNet 在三项主要任务中进行了评估:合成噪声图像、真实噪声图像和盲去噪。综合实验结果表明,ADNet 在定量和定性评估中表现出色。

1. 合成噪声图像:ADNet 在去除人工添加噪声方面表现出色,通常优于传统方法和其他基于 CNN 的模型。
2. 真实噪声图像:该模型在处理真实场景中的复杂噪声模式方面特别有效。
3. 盲去噪:ADNet 处理盲去噪的能力使其在各种噪声特性未知的应用中表现出色。

ADNet 的主要优势

增强的表达能力:通过 FEB 整合全局和局部特征,提升模型捕捉重要细节和上下文的能力。
注意力机制:通过使用注意力块,模型能够聚焦于相关特征,显著提高去噪性能,特别是在复杂场景中。
效率与性能平衡:稀疏块确保模型在保持去噪效果的同时,保持计算效率。
稳健的重建能力:重建块确保最终输出为高质量的干净图像,噪声伪影最小。


三、ADNet核心代码

ADNet的使用方式看章节四!

import torch
import torch.nn as nn


class ADNet(nn.Module):
    def __init__(self, channels, num_of_layers=16):
        super(ADNet, self).__init__()
        kernel_size = 3
        padding = 1
        features = num_of_layers
        groups =1
        layers = []
        kernel_size1 = 1
        self.conv1_1 = nn.Sequential(nn.Conv2d(in_channels=channels,out_channels=features,kernel_size=kernel_size,padding=padding,groups=groups,bias=False),nn.BatchNorm2d(features),nn.ReLU(inplace=True))
        self.conv1_2 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=2,groups=groups,bias=False,dilation=2),nn.BatchNorm2d(features),nn.ReLU(inplace=True))
        self.conv1_3 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=1,groups=groups,bias=False),nn.BatchNorm2d(features),nn.ReLU(inplace=True))
        self.conv1_4 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=1,groups=groups,bias=False),nn.BatchNorm2d(features),nn.ReLU(inplace=True))
        self.conv1_5 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=2,groups=groups,bias=False,dilation=2),nn.BatchNorm2d(features),nn.ReLU(inplace=True))
        self.conv1_6 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=1,groups=groups,bias=False),nn.BatchNorm2d(features),nn.ReLU(inplace=True))
        self.conv1_7 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=groups,bias=False),nn.BatchNorm2d(features),nn.ReLU(inplace=True))
        self.conv1_8 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=1,groups=groups,bias=False),nn.BatchNorm2d(features),nn.ReLU(inplace=True))
        self.conv1_9 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=2,groups=groups,bias=False,dilation=2),nn.BatchNorm2d(features),nn.ReLU(inplace=True))
        self.conv1_10 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=1,groups=groups,bias=False),nn.BatchNorm2d(features),nn.ReLU(inplace=True))
        self.conv1_11 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=1,groups=groups,bias=False),nn.BatchNorm2d(features),nn.ReLU(inplace=True))
        self.conv1_12 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=2,groups=groups,bias=False,dilation=2),nn.BatchNorm2d(features),nn.ReLU(inplace=True))
        self.conv1_13 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=groups,bias=False),nn.BatchNorm2d(features),nn.ReLU(inplace=True))
        self.conv1_14 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=groups,bias=False),nn.BatchNorm2d(features),nn.ReLU(inplace=True))
        self.conv1_15 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=1,groups=groups,bias=False),nn.BatchNorm2d(features),nn.ReLU(inplace=True))
        self.conv1_16 = nn.Conv2d(in_channels=features,out_channels=3,kernel_size=kernel_size,padding=1,groups=groups,bias=False)
        self.conv3 = nn.Conv2d(in_channels=6,out_channels=3,kernel_size=1,stride=1,padding=0,groups=1,bias=True)
        self.ReLU = nn.ReLU(inplace=True)
        self.Tanh= nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, (2 / (9.0 * 64)) ** 0.5)
            if isinstance(m, nn.BatchNorm2d):
                m.weight.data.normal_(0, (2 / (9.0 * 64)) ** 0.5)
                clip_b = 0.025
                w = m.weight.data.shape[0]
                for j in range(w):
                    if m.weight.data[j] >= 0 and m.weight.data[j] < clip_b:
                        m.weight.data[j] = clip_b
                    elif m.weight.data[j] > -clip_b and m.weight.data[j] < 0:
                        m.weight.data[j] = -clip_b
                m.running_var.fill_(0.01)

    def _make_layers(self, block,features, kernel_size, num_of_layers, padding=1, groups=1, bias=False):
        layers = []
        for _ in range(num_of_layers):
            layers.append(block(in_channels=features, out_channels=features, kernel_size=kernel_size, padding=padding, groups=groups, bias=bias))
        return nn.Sequential(*layers)

    def forward(self, x):
        input = x
        x1 = self.conv1_1(x)
        x1 = self.conv1_2(x1)
        x1 = self.conv1_3(x1)
        x1 = self.conv1_4(x1)
        x1 = self.conv1_5(x1)
        x1 = self.conv1_6(x1)
        x1 = self.conv1_7(x1)
        x1t = self.conv1_8(x1)
        x1 = self.conv1_9(x1t)
        x1 = self.conv1_10(x1)
        x1 = self.conv1_11(x1)
        x1 = self.conv1_12(x1)
        x1 = self.conv1_13(x1)
        x1 = self.conv1_14(x1)
        x1 = self.conv1_15(x1)
        x1 = self.conv1_16(x1)
        out = torch.cat([x,x1],1)
        out= self.Tanh(out)
        out = self.conv3(out)
        out = out*x1
        out2 = x - out
        return out2


if __name__ == "__main__":
    # Generating Sample image
    image_size = (1, 3, 640, 640)
    image = torch.rand(*image_size)

    # Model
    model = ADNet(3)

    out = model(image)
    print(out.size())


四、添加方式

4.1 修改一

第一还是建立文件,我们找到如下ultralytics/nn文件夹下建立一个目录名字呢就是'Addmodules'文件夹(用群内的文件的话已经有了无需新建)!然后在其内部建立一个新的py文件将核心代码复制粘贴进去即可。


4.2 修改二 

第二步我们在该目录下创建一个新的py文件名字为'__init__.py'(用群内的文件的话已经有了无需新建),然后在其内部导入我们的检测头如下图所示。


4.3 修改三 

第三步我门中到如下文件'ultralytics/nn/tasks.py'进行导入和注册我们的模块(用群内的文件的话已经有了无需重新导入直接开始第四步即可)

从今天开始以后的教程就都统一成这个样子了,因为我默认大家用了我群内的文件来进行修改!!


4.4 修改四 

按照我的添加在parse_model里添加即可。

到此就修改完成了,大家可以复制下面的yaml文件运行。


五、ADNet的yaml文件和运行记录

5.1 ADNet的yaml文件

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, ADNet, []]  # 0-P1/2
  - [-1, 1, Conv, [64, 3, 2]]  # 1-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 2-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 4-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 6-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 8-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 10

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 7], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 13

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 5], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 16 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 19 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 22 (P5/32-large)

  - [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)


5.2 训练代码 

大家可以创建一个py文件将我给的代码复制粘贴进去,配置好自己的文件路径即可运行。

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO

if __name__ == '__main__':
    model = YOLO('ultralytics/cfg/models/v8/yolov8-C2f-FasterBlock.yaml')
    # model.load('yolov8n.pt') # loading pretrain weights
    model.train(data=r'替换数据集yaml文件地址',
                # 如果大家任务是其它的'ultralytics/cfg/default.yaml'找到这里修改task可以改成detect, segment, classify, pose
                cache=False,
                imgsz=640,
                epochs=150,
                single_cls=False,  # 是否是单类别检测
                batch=4,
                close_mosaic=10,
                workers=0,
                device='0',
                optimizer='SGD', # using SGD
                # resume='', # 如过想续训就设置last.pt的地址
                amp=False,  # 如果出现训练损失为Nan可以关闭amp
                project='runs/train',
                name='exp',
                )


5.3 ADNet的训练过程截图 

 


五、本文总结

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv8改进有效涨点专栏,本专栏目前为新开的平均质量分98分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~

专栏目录:


网站公告

今日签到

点亮在社区的每一天
去签到