2D卷积核处理3D(时序)数据

发布于:2024-06-27 ⋅ 阅读:(156) ⋅ 点赞:(0)

Conv2D一般用于处理image,dim一般是4,(batch,channel,high,width)。对于多帧问题,例如时间序列,会多一个frame,也就是dim=5,(batch,frame,channel,high,width)。
此时需要做一些处理来适应Conv2D,同时提取时序特征。

一、Make A Video的处理方法(PseudoConv3d)

PseudoConv3d 选择通过分离空间卷积和时间卷积的方式来处理视频数据。先进行二维空间卷积,再通过一维时间卷积处理时间维度。

class PseudoConv3d(nn.Module):
    def __init__(
        self,
        dim,
        dim_out = None,
        kernel_size = 3,
        *,
        temporal_kernel_size = None,
        **kwargs
    ):
        super().__init__()
        dim_out = default(dim_out, dim)
        temporal_kernel_size = default(temporal_kernel_size, kernel_size)
 
 		# 在2d卷积后,再加1d卷积来处理时间序列。
        self.spatial_conv = nn.Conv2d(dim, dim_out, kernel_size = kernel_size, padding = kernel_size // 2)
        self.temporal_conv = nn.Conv1d(dim_out, dim_out, kernel_size = temporal_kernel_size, padding = temporal_kernel_size // 2) if kernel_size > 1 else None
 
        if exists(self.temporal_conv):
            nn.init.dirac_(self.temporal_conv.weight.data) # initialized to be identity
            nn.init.zeros_(self.temporal_conv.bias.data)
 
    def forward(
        self,
        x,
        enable_time = True
    ):
        b, c, *_, h, w = x.shape
 
        is_video = x.ndim == 5
        enable_time &= is_video
 
        if is_video:
            x = rearrange(x, 'b c f h w -> (b f) c h w')
 
        x = self.spatial_conv(x)
 
        if is_video:
            x = rearrange(x, '(b f) c h w -> b c f h w', b = b)
 
        if not enable_time or not exists(self.temporal_conv):
            return x
 
        x = rearrange(x, 'b c f h w -> (b h w) c f')
 
        x = self.temporal_conv(x)
 
        x = rearrange(x, '(b h w) c f -> b c f h w', h = h, w = w)
 
        return x

二、Tune A Video的处理方法(InflatedConv3d)

InflatedConv3d 选择展平法,使二维卷积适用于视频数据。具体步骤是先将视频数据展平成二维数据,进行二维卷积操作,然后再恢复回三维数据。

class InflatedConv3d(nn.Conv2d):
    def forward(self, x):
        video_length = x.shape[2]

        # 将输入的三维张量重排为二维形式以适配 Conv2d
        x = rearrange(x, "b c f h w -> (b f) c h w")
        
        # 调用父类的 forward 方法进行二维卷积操作
        x = super().forward(x)
        
        # 将卷积后的张量重新排列回三维形式
        x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)

        return x

InflatedConv3d源码链接

比较与分析

相似点

  1. 处理方式:
    两者都使用二维卷积处理空间维度上的特征。
    两者都通过数据形状重排来适应二维卷积的输入要求。

不同点

  1. 时间维度处理:
  • InflatedConv3d:不单独处理时间维度,而是通过二维卷积直接在展平后的数据上进行卷积,这种方式没有专门的时间卷积操作。
  • PseudoConv3d:先进行二维空间卷积,然后通过一维时间卷积处理时间维度。这样做能够明确分离空间和时间特征的提取。
  1. 卷积核初始化:
  • InflatedConv3d:直接使用父类 nn.Conv2d 的权重和偏置,没有特别的初始化。
  • PseudoConv3d:时间卷积核初始化为 Dirac 分布,使其初始状态下相当于恒等映射。
  1. 等价性
    严格来说,InflatedConv3d 和 PseudoConv3d 并不完全等价,因为它们处理时间维度的方式不同:
  • InflatedConv3d 是通过展平数据和二维卷积来“间接”处理时间维度。
  • PseudoConv3d 明确地在空间卷积后,使用一维卷积处理时间维度。

由于 PseudoConv3d 进行了显式的一维时间卷积,它在时间特征的提取上可能更灵活和强大,而 InflatedConv3d 则更简洁,但可能在处理复杂时间依赖时不如 PseudoConv3d。

结论

InflatedConv3d 和 PseudoConv3d 在一些情况下可以产生类似的效果,尤其是当时间维度的变化相对简单时。然而,PseudoConv3d 由于其明确的时间卷积操作,在处理复杂时间动态时可能更有效。因此,它们并不是严格等价的,但都在一定程度上解决了将二维卷积扩展到三维视频数据的问题。


网站公告

今日签到

点亮在社区的每一天
去签到