pytorch 今日小知识3——nn.MaxPool3d 、nn.AdaptiveAvgPool3d、nn.ModuleList

发布于:2024-04-17 ⋅ 阅读:(34) ⋅ 点赞:(0)

MaxPool3d — PyTorch 2.2 documentation

假设输入维度(1,2,3,4,4)

maxpool = torch.nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(1, 0, 0))

 F 维的 kernel_size 为 2,说明在 F 维的覆盖的 frame 数为 2,也就是每次有 2 个 frame 加入运算。

F 维的 stride 为 2,说明每次的跨度为 2。

  • padding

F 维的 padding 为 1,说明需要在 F 维做 padding,也就是会在输入的输入之前和之后各做 1 个 frame 的padding。

 

2。nn.AdaptiveAvgPool3d()是一个自适应的三维平均池化层,它对输入的立体数据进行降采样,并允许动态地指定输出的目标大小。
不同于nn.AvgPool3d()需要手动指定池化窗口大小,nn.AdaptiveAvgPool3d()直接指定输出的目标大小。它会根据目标输出大小自适应地调整池化窗口的大小,以保证输出的大小和目标大小一致。

import torch
import torch.nn as nn

# 创建一个 AdaptiveAvgPool3d 层,指定目标输出大小为 (2, 2, 2)
pool = nn.AdaptiveAvgPool3d((2, 2, 2))

# 输入数据,假设输入大小为 [batch_size, channels, depth, height, width]
input_data = torch.randn(1, 1, 4, 4, 4)  # 示例输入数据

# 对输入数据进行自适应池化操作
output = pool(input_data)

print(output.shape)

3.在复现代码过程中遇到了 

self.hpp = nn.ModuleList([
    GeMHPP(bin_num=[1]) for i in range(self.m)])

就学习了nn.moduleList这个函数

nn.ModuleList,它是一个储存不同 module,并自动将每个 module 的 parameters 添加到网络之中的容器。你可以把任意 nn.Module 的子类 (比如 nn.Conv2d, nn.Linear 之类的) 加到这个 list 里面,方法和 Python 自带的 list 一样,无非是 extend,append 等操作。但不同于一般的 list,加入到 nn.ModuleList 里面的 module 是会自动注册到整个网络上的,同时 module 的 parameters 也会自动添加到整个网络中

class net4(nn.Module):
    def __init__(self):
        super(net4, self).__init__()
        layers = [nn.Linear(10, 10) for i in range(5)]
        self.linears = nn.ModuleList(layers)

    def forward(self, x):
        for layer in self.linears:
            x = layer(x)
        return x

net = net4()
print(net)

 

 

nn.Sequential内部实现了forward函数,因此可以不用写forward函数。而nn.ModuleList则没有实现内部forward函数

seq = nn.Sequential(
          nn.Conv2d(1,20,5),
          nn.ReLU(),
          nn.Conv2d(20,64,5),
          nn.ReLU()
        )
print(seq)
# Sequential(
#   (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
#   (1): ReLU()
#   (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
#   (3): ReLU()
# )

#对上述seq进行输入
input = torch.randn(16, 1, 20, 20)
print(seq(input))
#torch.Size([16, 64, 12, 12])

参考

详解PyTorch中的ModuleList和Sequential - 知乎 (zhihu.com)

 3D 池化(MaxPool3D) 和 3D(Conv3d) 卷积详解-CSDN博客


网站公告

今日签到

点亮在社区的每一天
去签到