【论文笔记】Semantic-Aware Domain Generalized Segmentation

发布于:2022-11-09 ⋅ 阅读:(19) ⋅ 点赞:(0) ⋅ 评论:(0)

论文

论文标题:Semantic-Aware Domain Generalized Segmentation

发表于:CVPR 2022 (oral)

论文地址:https://arxiv.org/abs/2204.00822

代码:https://github.com/leolyj/SAN-SAW/

参考博客:CVPR2022 Oral-即插即用!感知语义的域泛化语义分割模型 (SAN & SAW)

CVPR2022 (ORAL) Semantic-Aware Domain Generalized Segmentation - 知乎 (zhihu.com)

本文分别基于Instance Normalization (IN)与Instance Whitening (IW) 提出了两个用于编码器与解码器之间的即插即用模块:Semantic-Aware Normalization (SAN)与Semantic-Aware Whitening (SAW),能够极大的提示模型的泛化能力。在面临各种与训练数据的分布不一致的测试数据时,SAN与SAW仍能帮助模型尽可能的维持模型的性能。

一、Motivation

语义分割中的无监督域适应(UDA)基于目标域影像数据可知但对应标签数据不可得这一前提。但是一个更为现实的前提是我们无法得知包括图像数据在内的目标域任何信息,换句话说,就是模型在测试时可能会面临各种各样数据分布的图片,如果模型的泛化能力不够,那么其性能肯定会出现很大的波动。

增强模型的泛化能力一个最常用的方式就是数据增强,即把训练数据转换为各种各样的形式使得模型在训练阶段就见过各种各样的数据分布,从而提高模型的泛化能力。但是寄希望于数据增强能够使得转换后的训练数据覆盖所有测试数据的分布是不现实的。因此,数据增强的方式来增强模型的泛化能力具有其固有的缺陷。

图 1 应用不同的方法后,同一模型的编码器所提取的来自不同数据分布(不同域)的测试图片的特征分布。

另一个方向来增强泛化能力的方式是使用Normalization 和 Whitening,该方向的方法利用实例归一化(Instance Normalization, IN)或实例白化(Instance Whitening,IW)对不同样本的特征分布进行标准化。IN分别对单个图像的每个通道的特征进行标准化归一化,以减轻由于样式变化引起的特征不匹配。IN的具体过程可以由以下公式来表示:

 

 其中,F_n,k,h,w 表示一个 mini-batch 对应的特征图 F 中第 n 个 sample 的第 k 个通道特征图上空间位置为(h,w)上的特征值。但是如图1 (a)所示,使用IN只实现了特征分布的中心对齐,但是无法对齐特征的联合分布。

而如图1 (b)所示, 由于IW 可以消除各通道特征间的线性相关性,所以使用 IW 后可以形成均匀分布的良好聚类特征。IW的具体工程可以由以下公式来表示:

 

 

 其中\Psi(F_{n})本质上表示的是一个 mini-batch 中第 n 个 sample 对应的特征图上各个通道间的相关性。更加形象化的表达如图2所示。

图2 IW的示意图

 从如图1 (b)也可观察得到,特征虽然均匀分布了,但是却也没有对齐特征的联合分布。最近有论文研究表明,如图1 (c)所示,联合IN与IW后,能够对齐来自不同域(即数据分布不同)的联合特征边缘分布。然而,图1 (c)中也可以观察到虽然特征边缘分布得到了对齐,但是条件分布却依然处于没有对齐的状态,每个类别的分布仍然混合在一起以至于难以区分。

那么由此引出了本文的出发点:既要对齐特征的全局边缘分布,也要对齐条件分布,从而使得每个类别的分布在特征空间能够被很好的区分开。具体做法便在于在IN与IW的基础上引入了类别信息。下面来看一下具体方法设计。

二、方法

整体的框架如图3所示。

图3 模型总体框架图。

 也就是说,在原有分割模型编码器-解码器的结构框架上插入了两个即插即用的模块:Semantic-Aware Normalization (SAN)与Semantic-Aware Whitening (SAW)。它们分别基于IN与IW而设计。

Semantic-Aware Normalization (SAN)

SAN的总体框架如图4所示。

图4 SAN框架图

 该模块看似很复杂,但是其本质内容却很简单。

SAN模块设置了一个归一化后的特征图真值F_obj, 它是通过对编码器提取的特征图 F 进行如下类别级别的实例归一化后得到的:

 其中,F^c_{n,k}表示编码器提取的特征图 F_{n,k} 上对应真值标签 Y 中第 c 个类别所属的空间位置的特征值集合。γ 和 β 分别表示实例归一化重点缩放和平移变量,它们是可学习的,每一个不同的类别都有不同的 γ 和 β。

但是,之所以 F_obj 能够得到类别级别的归一化是因为有真值标签 Y。然而测试阶段真值标签是不可得的。所以,SAN中添加了一个预分割的分支,受到下采样的真值标签的监督,以引入相对正确的语义类别信息。编码器提取的特征图 F 会分别与预分割分支的对应类别的通道相乘以强调特征图中对应语义类别区域的特征。

但是由于预分割分支的分割可能不那么准确,这种直接相乘后的特征图可能会错误的强调一些不属于当前类别的的区域的特征,因此,作者提出了一个类别级别的特征优化模块(CRF)来改善特征。CRF的具体操作图示很清楚,可能需要注意的是这里的Maxpool和Avgpool是在通道维度上展开的。这样能够一定程度上的平滑预分割引入的错误信息。

经过CRF优化后的特征会送入区域归一化(Regional Normalization,RN)中进行最终的归一化操作。RF与IN的区别在于不是整个原始特征图上进行归一化,而是针对每一个类对应的区域进行归一化。这里其实CRF输出的特征图上如同图示中的heatmap一般会高亮对应语义类别的区域。为了进一步细化以及最终分割出对应语义类别的区域,会先对heatmap进行k-means聚类(k是超参数,论文里设置为5),然后选择其中的第一个类别作为对应语义类别的区域。选中的区域记为\Phi ^c_{high},它实际上是SAN模块中得到的作为对Y(c)的预测。因此,我们可以得到以下的归一化结果:

 

 其中,RN对应区域归一化。原文中没有写清楚RN的具体操作,代码实际上还处于没开源的状态。我猜测RN就是对应\Phi ^c_{high}区域的特征进行实例归一化,然后剩余的空间区域特征不变(有待代码开源考证)。综合上述过程,SAN中的总体目标函数为:

 其中CE应该就是图4中的L_seg,而第二项的作用在于使得SAN的归一化结果在没有真值标签的情况下逼近F_obj.

Semantic-Aware Whitening (SAW)

SAW的通体框架如图5所示。

图5 SAW框架图示意

 SAN中归一化后的特征会送入SAW中进行进一步的处理以对齐全局边缘分布与条件分布。从图上来看,又是一个看似很复杂的模块,实际上也不复杂。

SAW基于IW的改进版本GIW(分组实例白化)而来。GIW认为直接采用IW这种严格去除所有通道之间相关性的强白化方式可能会损害语义内容,导致关键的领域不变信息的丢失。因此,如图6所示,GIW将特征图分为了几个组,只去除组内的特征通道之间的相关性。

图6 GIW的示意图,其中L_GIW形式上与图2中的L_IW类似,只是重复了M次

 然而,GIW只对相邻通道特征图进行去相关性操作,却没有考虑寻找更合适的通道组合。我们知道,每个通道的特征实际上提取的是对应某一个类别的关键语义信息。因此,我们可以以语义类别的信息来进行分组。这个语义信息就来自SAN中的Classier中的权重。它代表了特征图中每个通道的特征图对不同类别的重要程度。对这个进行排序,然后依次取出对应的特征图便可以得到分组后的特征。一共选取了K/C组,每组C个通道。分组之后计算与图6中的L_GIW形式一致的损失函数L_SAW即可。

三、解读

 

四、实验

 

关键代码

SAN

from math import ceil
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import argparse
import torch.utils.model_zoo as model_zoo
import kmeans1d
import time
import numpy as np
import matplotlib.pyplot as plt
from scipy.cluster.vq import whiten
affine_par = True

# https://github.com/leolyj/SAN-SAW/blob/main/graphs/models/SAN.py

class SAN(nn.Module):

    def __init__(self, inplanes, selected_classes=None):
        super(SAN, self).__init__()
        self.margin = 0
        self.IN = nn.InstanceNorm2d(inplanes, affine=affine_par)
        self.selected_classes = selected_classes
        self.CFR_branches = nn.ModuleList()
        for i in selected_classes:
            self.CFR_branches.append(
                nn.Conv2d(3, 1, kernel_size=7, stride=1, padding=3, bias=False))

        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
        self.mask_matrix = None

    def cosine_distance(self, obs, centers):
        obs_norm = obs / obs.norm(dim=1, keepdim=True)
        centers_norm = centers / centers.norm(dim=1, keepdim=True)
        cos = torch.matmul(obs_norm, centers_norm.transpose(1, 0))
        return 1 - cos

    def l2_distance(self, obs, centers):
        dis = ((obs.unsqueeze(dim=1) - centers.unsqueeze(dim=0)) ** 2.0).sum(dim=-1).squeeze()
        return dis

    def _kmeans_batch(self, obs: torch.Tensor, k: int, distance_function,batch_size=0, thresh=1e-5, norm_center=False):

        # k x D
        centers = obs[torch.randperm(obs.size(0))[:k]].clone()
        history_distances = [float('inf')]
        if batch_size == 0:
            batch_size = obs.shape[0]
        while True:
            # (N x D, k x D) -> N x k
            segs = torch.split(obs, batch_size)
            seg_center_dis = []
            seg_center_ids = []
            for seg in segs:
                distances = distance_function(seg, centers)
                center_dis, center_ids = distances.min(dim=1)
                seg_center_ids.append(center_ids)
                seg_center_dis.append(center_dis)

            obs_center_dis_mean = torch.cat(seg_center_dis).mean()
            obs_center_ids = torch.cat(seg_center_ids)
            history_distances.append(obs_center_dis_mean.item())
            diff = history_distances[-2] - history_distances[-1]
            if diff < thresh:
                if diff < 0:
                    warnings.warn("Distance diff < 0, distances: " + ", ".join(map(str, history_distances)))
                break
            for i in range(k):
                obs_id_in_cluster_i = obs_center_ids == i
                if obs_id_in_cluster_i.sum() == 0:
                    continue
                obs_in_cluster = obs.index_select(0, obs_id_in_cluster_i.nonzero().squeeze())
                c = obs_in_cluster.mean(dim=0)
                if norm_center:
                    c /= c.norm()
                centers[i] = c
        return centers, history_distances[-1]

    def kmeans(self, obs: torch.Tensor, k: int, distance_function=l2_distance, iter=20, batch_size=0, thresh=1e-5, norm_center=False):

        best_distance = float("inf")
        best_centers = None
        for i in range(iter):
            if batch_size == 0:
                batch_size == obs.shape[0]
            centers, distance = self._kmeans_batch(obs, k,
                                              norm_center=norm_center,
                                              distance_function=distance_function,
                                              batch_size=batch_size,
                                              thresh=thresh)
            if distance < best_distance:
                best_centers = centers
                best_distance = distance
        return best_centers, best_distance

    def product_quantization(self, data, sub_vector_size, k, **kwargs):
        centers = []
        for i in range(0, data.shape[1], sub_vector_size):
            sub_data = data[:, i:i + sub_vector_size]
            sub_centers, _ = self.kmeans(sub_data, k=k, **kwargs)
            centers.append(sub_centers)
        return centers

    def data_to_pq(self, data, centers):
        assert (len(centers) > 0)
        assert (data.shape[1] == sum([cb.shape[1] for cb in centers]))

        m = len(centers)
        sub_size = centers[0].shape[1]
        ret = torch.zeros(data.shape[0], m,
                          dtype=torch.uint8,
                          device=data.device)
        for idx, sub_vec in enumerate(torch.split(data, sub_size, dim=1)):
            dis = self.l2_distance(sub_vec, centers[idx])
            ret[:, idx] = dis.argmin(dim=1).to(dtype=torch.uint8)
        return ret

    def train_product_quantization(self, data, sub_vector_size, k, **kwargs):
        center_list = self.product_quantization(data, sub_vector_size, k, **kwargs)
        pq_data = self.data_to_pq(data, center_list)
        return pq_data, center_list

    def _gram(self, x):
        (bs, ch, h, w) = x.size()
        f = x.view(bs, ch, w * h)
        f_T = f.transpose(1, 2)
        G = f.bmm(f_T) / (ch * h * w)
        return G

    def pq_distance_book(self, pq_centers):
        assert (len(pq_centers) > 0)

        pq = torch.zeros(len(pq_centers),
                         len(pq_centers[0]),
                         len(pq_centers[0]),
                         device=pq_centers[0].device)
        for ci, center in enumerate(pq_centers):
            for i in range(len(center)):
                dis = self.l2_distance(center[i:i + 1, :], center)
                pq[ci, i] = dis
        return pq

    def Regional_Normalization(self, region_mask, x):
        masked = x*region_mask
        RN_feature_map = self.IN(masked)
        return RN_feature_map

    def asymmetric_table(self, query, centers):
        m = len(centers)
        sub_size = centers[0].shape[1]
        ret = torch.zeros(
            query.shape[0], m, centers[0].shape[0],
            device=query.device)
        assert (query.shape[1] == sum([cb.shape[1] for cb in centers]))
        for i, offset in enumerate(range(0, query.shape[1], sub_size)):
            sub_query = query[:, offset: offset + sub_size]
            ret[:, i, :] = self.l2_distance(sub_query, centers[i])
        return ret

    def asymmetric_distance_slow(self, asymmetric_tab, pq_data):
        ret = torch.zeros(asymmetric_tab.shape[0], pq_data.shape[0])
        for i in range(asymmetric_tab.shape[0]):
            for j in range(pq_data.shape[0]):
                dis = 0
                for k in range(pq_data.shape[1]):
                    sub_dis = asymmetric_tab[i, k, pq_data[j, k].item()]
                    dis += sub_dis
                ret[i, j] = dis
        return ret

    def asymmetric_distance(self, asymmetric_tab, pq_data):
        pq_db = pq_data.long()
        dd = [torch.index_select(asymmetric_tab[:, i, :], 1, pq_db[:, i]) for i in range(pq_data.shape[1])]
        return sum(dd)

    def pq_distance(self, obj, centers, pq_disbook):
        ret = torch.zeros(obj.shape[0], centers.shape[0])
        for obj_idx, o in enumerate(obj):
            for ct_idx, c in enumerate(centers):
                for i, (oi, ci) in enumerate(zip(o, c)):
                    ret[obj_idx, ct_idx] += pq_disbook[i, oi.item(), ci.item()]
        return ret

    def set_class_mask_matrix(self, normalized_map):

        b,c,h,w = normalized_map.size()
        var_flatten = torch.flatten(normalized_map)


        try:  # kmeans1d clustering setting for RN block
            clusters, centroids = kmeans1d.cluster(var_flatten,5, 3)
            num_category = var_flatten.size()[0] - clusters.count(0)  # 1: class-region, 2~5: background
            _, indices = torch.topk(var_flatten, k=int(num_category))
            mask_matrix = torch.flatten(torch.zeros(b, c, h, w).cuda())
            mask_matrix[indices] = 1
        except:
            mask_matrix = torch.ones(var_flatten.size()[0]).cuda()

        mask_matrix = mask_matrix.view(b, c, h, w)

        return mask_matrix

    def forward(self, x, masks):
        outs=[]
        idx = 0
        masks = F.softmax(masks,dim=1)
        for i in self.selected_classes:
            mask = torch.unsqueeze(masks[:,i,:,:],1)
            mid = x * mask
            avg_out = torch.mean(mid, dim=1, keepdim=True)
            max_out,_ = torch.max(mid,dim=1, keepdim=True)
            atten = torch.cat([avg_out,max_out,mask],dim=1)
            atten = self.sigmoid(self.CFR_branches[idx](atten))
            out = mid*atten
            heatmap = torch.mean(out, dim=1, keepdim=True)

            class_region = self.set_class_mask_matrix(heatmap)
            out = self.Regional_Normalization(class_region,out)
            outs.append(out)
        out_ = sum(outs)
        out_ = self.relu(out_)

        return out_

SAM

# https://github.com/leolyj/SAN-SAW/blob/main/graphs/models/SAW.py

import torch
import torch.nn as nn
import math
import torch.nn.functional as F


class SAW(nn.Module):
    def __init__(self, args, dim, relax_denom=0, classifier=None, work=False):
        super(SAW, self).__init__()
        self.work = work
        self.selected_classes = args.selected_classes
        self.C = len(args.selected_classes)
        self.dim = dim
        self.i = torch.eye(self.C, self.C).cuda()
        self.reversal_i = torch.ones(self.C, self.C).triu(diagonal=1).cuda()
        self.classify = classifier
        self.num_off_diagonal = torch.sum(self.reversal_i)
        if relax_denom == 0:
            print("Note relax_denom == 0!")
            self.margin = 0
        else:
            self.margin = self.num_off_diagonal // relax_denom


    def get_mask_matrix(self):
        return self.i, self.reversal_i, self.margin, self.num_off_diagonal

    def get_covariance_matrix(self, x, eye=None):
        eps = 1e-5
        B, C, H, W = x.shape  # i-th feature size (B X C X H X W)
        HW = H * W
        if eye is None:
            eye = torch.eye(C).cuda()
        x = x.contiguous().view(B, C, -1)  # B X C X H X W > B X C X (H X W)
        x_cor = torch.bmm(x, x.transpose(1, 2)).div(HW - 1) + (eps * eye)  # C X C / HW

        return x_cor, B

    def instance_whitening_loss(self, x, eye, mask_matrix, margin, num_remove_cov):
        x_cor, B = self.get_covariance_matrix(x, eye=eye)
        x_cor_masked = x_cor * mask_matrix

        off_diag_sum = torch.sum(torch.abs(x_cor_masked), dim=(1, 2), keepdim=True) - margin  # B X 1 X 1
        loss = torch.clamp(torch.div(off_diag_sum, num_remove_cov), min=0)  # B X 1 X 1
        loss = torch.sum(loss) / B

        return loss
    def sort_with_idx(self, x, idx,weights):
        b,c,_,_ = x.size()
        after_sort = torch.zeros_like(x)
        weights = F.sigmoid(weights)
        for i in range(b):

            for k in range(int(c / self.C)):
                for j in range(self.C):
                    channel_id = idx[self.selected_classes[j]][k]
                    wgh = weights[self.selected_classes[j]][channel_id]
                    after_sort[i][self.C*k+j][:][:] = wgh * x[i][channel_id][:][:]

        return after_sort

    def forward(self, x):
        if self.work:
            weights_keys = self.classify.state_dict().keys()

            selected_keys_classify = []

            for key in weights_keys:
                if "weight" in key:
                    selected_keys_classify.append(key)

            for key in selected_keys_classify:
                weights_t = self.classify.state_dict()[key]

            classsifier_weights = abs(weights_t.squeeze())
            _,index = torch.sort(classsifier_weights, descending=True,dim=1)
            f_map_lst = []
            B, channel_num, H, W = x.shape
            x = self.sort_with_idx(x,index,classsifier_weights)

            for i in range(int(channel_num/self.C)):
                group = x[:,self.C*i:self.C*(i+1),:,:]
                f_map_lst.append(group)

            eye, mask_matrix, margin, num_remove_cov = self.get_mask_matrix()
            SAW_loss = torch.FloatTensor([0]).cuda()

            for i in range(int(channel_num / self.C)):
                loss = self.instance_whitening_loss(f_map_lst[i], eye, mask_matrix, margin, num_remove_cov)
                SAW_loss = SAW_loss+loss
        else:
            SAW_loss = torch.FloatTensor([0]).cuda()



        return