• 问答
  • 技术
  • 实践
  • 资源
Vision MLP 超详细解读
技术讨论

作者丨科技猛兽
​编辑丨极市平台

本文目录

4 谷歌大脑提出gMLP:请多多关注MLP
(来自谷歌大脑,Quoc V .Le 团队)
4.1 gMLP原理分析

5 港大提出CycleMLP:用于密集预测的类似 MLP 的架构
(来自港大,罗平教授团队)
5.1 CycleMLP原理分析
5.2 CycleMLP代码解读

4 谷歌大脑提出gMLP:请多多关注MLP

论文名称:Pay Attention to MLPs

论文地址:

https://arxiv.org/abs/2105.08050

4.1 gMLP原理分析

本文提出了一种 gMLP 模型,g 代表 "gating"。作者觉得这是一个仅包含 gating 的 MLP 模型,所以取名为 gMLP。

介绍 gMLP 模型之前先大致了解下它的性能如何:

在 ImageNet 标准分类实验上,gMLP 与基于 Transformer 的 DeiT 模型性能相当。因此作者认为:self-attention 结构可能不是视觉模型所必需的。

在 masked language modeling (MLM) 问题上,即作者将 gMLP 模型应用于这个最常见的 NLP 任务中时,发现它与 Transformer 模型性能也相当。 作者的实验表明,MLM 问题模型的性能只与模型大小相关,对注意力机制的存在不敏感。比如,在256-batch size 和 1M-step training 的实验设置下,gMLP 可以在 MNLI 上达到 86.4\% 的精度,在 SQuAD v1.1 上达到 89.5\% 的 F1,与 Transformer 模型的结果相当。

除此之外,gMLP 模型的诸多性质和 Transformer 模型是类似的,比如:增大模型的尺寸可以提升模型的性能,增加数据量可以提升模型的性能,这与是否使用 self-attention 机制无关。 作者的研究结果表明,self-attention 机制并不是扩大机器学习模型的必要因素。随着数据和计算量的增加,使用简单空间相互作用机制的模型,如 gMLP,可以像 Transformer 一样强大。

图1:gMLP 模型的结构和伪代码

gMLP 模型的结构和伪代码如上图1所示。

gMLP 模型包括 $L$ 个大小和结构相同的 block,假设输入是 $X \in \mathbb{R}^{n \times d}$ ,其中 $n$ 是 $n$ 个tokens, $d$ 是每个 token 的 Embedding dimension。每个 block 的定义是:

$$
\begin{align} Z = \sigma(XU), \qquad \color{red}{\tilde{Z} = s(Z), \qquad} Y = \tilde{Z}V \end{align}
$$

其中, $\sigma$ 是激活函数,比如 GeLU。

$U$ 和 $V$ 是类似于 FFN 的全连接操作,作者称为 Channel Projection。为了简写起见,省略了 Shortcuts, normalizations 和 biases。

所以最关键的就是红色部分的 $\color{red}{\tilde{Z} = s(Z)}$ ,也就是 Spatial Gating Unit 的操作。

在 MLP-Mixer 中我们提到,MLP-Mixer的输入是维度是 $X \in \mathbb{R}^{n \times d}$ 的张量,那么这个Mixer不仅混合各个channels之间的信息,也混合不同空间位置 (tokens 或 patches) 之间的信息。 所以主要分成了2种层:channel-mixing MLPs 层和 token-mixing MLPs 层。

channel-mixing MLPs层结合不同channels的信息,是一种按位置per-location的操作。

token-mixing MLPs层结合不同tokens的信息,是一种跨位置cross-location的操作。

那么对于一个维度是 $(N,D)=\text{patches×channels}$ 的输入矩阵来讲,channel-mixing MLPs层应该是作用于它的每一行,而token-mixing MLPs层作用于它的每一列。

MLP-Mixer即可以靠channel-mixing MLPs层结合不同channels的信息,也可以靠token-mixing MLPs层结合不同空间位置的信息。

gMLP 的 Channel Projection 层就是 MLP-Mixer 的 channel-mixing MLPs 层,而这个红色的 $\color{red}{\tilde{Z} = s(Z)}$ 作用就相当于 MLP-Mixer 的 token-mixing MLPs 层

如果 $\color{red}{s(\cdot)}$ 是 identity mapping,则整个结构就是个 FFN,没有任何的 cross-token 的交流。因此,重点是设计一个捕捉复杂空间相互作用的 $\color{red}{s(\cdot)}$ 。

设计 $\color{red}{s(\cdot)}$ 的一种最简单的想法是:

$\begin{equation} f_{W, b}(Z) = WZ + b \label{eq:spatial-proj} \end{equation} \tag{2}$

式中, $W \in \mathbb{R}^{n \times n}$ 矩阵的作用是融合 spatial 的信息。比如输入序列的长度是 $n=128$ ,则 $W \in \mathbb{R}^{128 \times 128}$ 。在这项工作中,作者将空间相互作用单元定义为二者的乘积:

$$
\begin{align} \color{red}{s(Z)} = Z \odot f_{W, b}(Z) \label{eq:spatial-gating} \end{align}
$$

式中, $\odot$ 代表 element-wise 的乘法。 $W$ 以接近0作初始化, $b$ 以接近1作初始化。这样初始化的目的是让 $\color{red}{s(\cdot)}$ 在一开始和 Identity mapping 接近。

这个初始化确保每个 gMLP block 在训练的初期阶段像一个常规的 FFN 模型一样运行,其中每个 token 都是独立处理的,并且在 token 之间随训练的进行逐步注入空间信息。

multiplicative gating 靠的是 $f_{W, b}(Z) $ 来调制输入信号 $Z$ ,作者进一步发现把输入信号 $Z$沿着 channel 维度分开成 $(Z_1, Z_2)$ ,并先对 $Z_2$ 进行归一化运算 norm,再对 $Z_1$ 和归一化后的 $Z_2$ 进行下列运算:

$$
\begin{align} s(Z) = Z1 \odot f{W, b}(Z_2) \label{eq:spatial-gating-independent} \end{align}
$$

这一过程的伪代码如下:

 def spatial_gating_unit(x):
   u, v = split(x, axis="channel")
   v = norm(v, axis="channel")
   n = get_dim(v, axis="spatial")
   v = proj(v, n, axis="spatial", init_bias=1)
   return u * v

norm(v, axis="channel"):
对 $Z_2$ 进行归一化运算 norm。

proj(v, n, axis="spatial", init_bias=1):
使用 $f_{W, b}(Z) $ 来调制输入信号。

return u * v:
进行 element-wise 的乘法计算输出。

$\color{red}{s(\cdot)}$ Spatial Gating Unit 与 Gated Linear Units (GLUs) 的区别是:

Spatial Gating Unit 进行空间维度的 cross-token 的计算,而 Gated Linear Units (GLUs) 进行 channel 维度的 per-token 的计算。Spatial Gating Unit 也有点像 SE 模块的元素级操作。

图像分类实验

数据集:ImageNet

模型配置如下:

图2:gMLP 模型配置

gMLP-Tiny,Small,Base 模型的参数量和 DeiT-Tiny,Small,Base 模型相当。所有的图片都会被分成大小为 16×16 的 patch。下图3是 ImageNet 分类任务的实验结果。

图3:ImageNet 分类任务的实验结果

可以看到尺寸相当的 gMLP 的 ImageNet 分类精度超过了 ResMLP,并且和 DeiT 模型的精度相当。 换句话讲,就 accuracy-parameter/FLOPs tradeoff 而言,gMLP 超过了其余两种 MLP 模型。事实上,模型的准确性似乎与参数量大小更好地相关,而不是是否存在注意力机制。attention-free 的模型也可以是 Data-Efficient 的。

下图4可视化了每一层的空间投影矩阵 (spatial projection 矩阵),即伪代码中的:

v = proj(v, n, axis="spatial", init_bias=1)

图4:spatial projection 矩阵可视化

观察可以发现, spatial weights 展示出了局部和空间的不同。换句话说,每个空间投影矩阵有效地学习到了不规则的核形状。

Masked Language Modeling 实验

模型: BERT

作者进一步在 Masked Language Modeling 任务上进行了实验,pretraining 和 finetuning 的输入和输出和 BERT 保持一致。

如下图5所示,作者对比了几种位置编码以及 gMLP 模块的组成方式,发现$s(Z)=Z{1} \odot f{W, b}\left(Z_{2}\right)$的做法性能最佳。

图5:几种位置编码以及 gMLP 模块的组成方式

下图6展示了 gMLP 随着模型增大逐渐能有与 Transformer 相当的效果,可见Transformer的效果应该主要是依赖于模型尺寸而非self-attention。

图6:随着模型尺寸增大的性能

从上面的 Case Study 可以发现 gMLP 对于需要跨句子连接的 Finetuing 任务可能不及Transformer,所以作者提出了 gMLP 的增强版 aMLP。aMLP 相较于 gMLP 仅增加了一个单头 Embedding dimension=64 (远小于 DeiT-Base 的Embedding dimension=768,12 heads) 的 self-attention 如下图7所示:

图7:aMLP 模型的结构和伪代码

从下图8结果可以发现 aMLP (64-d single head-attention) 相较于 gMLP 极大提升了效果并在所有 task 都超过了Transformer。 通过改变模型深度 (从 {0.5,1,2} 中选择) 或数据量 (从 {1,2,4,8} 中选择) 来收集数据点。可以看出,不管注意力的存在与否,gMLP 对 SST-2 的迁移效果都比 Transformer 好,而 gMLP 对 MNLI-m 的迁移效果较差,只要稍微加上 Tiny-attention 就足以缩小性能上的差距。

图8:aMLP 相较于 gMLP 极大提升了效果

下图9展示了BERT、gMLP、aMLP在不同模型参数情况下的效果比较,可以看到模型相当的情况下, gMLP 能取得接近 BERT 的效果,aMLP 更是能超过 BERT。

图9:模型参数量相当的情况下,gMLP 能取得接近 BERT 的效果,aMLP 更是能超过 BERT

小结

gMLP 这种模型架构通过一种不包含卷积的 Spatial Gating Unit 操作来实现模型对于输入特征的空间位置信息的融合,在图像分类任务和 Masked Language Modeling 实验上都能取得不错的性能。而且,只要稍微融合如一些 Attention 机制就能够在 Masked Language Modeling 任务上进一步地再提升模型性能。

5 港大提出CycleMLP:用于密集预测的类似 MLP 的架构

论文名称:CycleMLP: A MLP-like Architecture for Dense Prediction

论文地址: https://arxiv.org/abs/2107.10224

5.1 CycleMLP原理分析

本文来自港大罗平老师团队,是 MLP 模型家族的又一个新作品。针对 MLP-Mixer 等已有方案存在的分辨率相关、不便于向下游任务迁移的问题,提出了一种新颖的 CycleFC 操作,并由此构建了 CycleMLP 架构。所提 CycleMLP 在 ImageNet 分类、COCO 检测以及 ADE20K语义分割等任务上均取得了优于其他 MLP 架构的性能,同时具有与 Swin 相当甚至更佳的性能。

已有的一些经典的 MLP 模型,比如:MLP-Mixer, ResMLP, 和 gMLP 等,它们存在以下3个问题:

  1. 所有的 Block 具有相同的结构,即每个 Block 输出的特征的维度也是一致的,不能产生金字塔结构的特征。
  2. 只能适用于 input shape 不变的情况,但是不能适用于 input shape 改变的情况,如目标检测和语义分割任务。比如某一层的输入是 $x\in \mathbb{R}^{C_i\times H_i \times W_i}$ ,输出是 $y\in \mathbb{R}^{C_o\times H_o \times W_o}$ ,spatial FC 的权重是 $Ws \in \mathbb{R}^{{H{i}W{i}}\times {H{o}W_{o}}}$ 。因此,Spatial FC 的结构是通过 $H\times W$ 来配置的。因此,这些模型在训练和验证阶段都需要一个固定的输入。但是,比如在 ADE20K (分辨率为 512×512) 上做语义分割,或者一些需要多尺度训练的任务中,在训练和验证阶段具有不同的分辨率大小。因此,之前的这些 MLP 模型在此类应用场景中就会受到限制。
  3. Spatial FC 的计算复杂度与图像尺寸呈平方关系,使得现有的高分辨率图像难以处理。

图10:几种不同 FC 结构的比较

接下来我们依次来看 Cycle MLP 是如何解决以上3个问题的,如上图10所示。

针对第1个问题,作者构建了一金字塔结构的表征,模型从浅层到深层,特征的空间分辨率在逐渐减小,channel 数在逐渐增大。

针对第2和第3个问题,作者提出了一种 Cycle Fully-Connected Layer。该 Cycle FC 层能够处理不同尺度的图像,并且计算复杂度与图像尺寸成线性关系。

仔细观察图10,发现 Cycle FC 层和 FFN 的 Channel FC 层很像,我们把 $H\times W$ 这个维度视作是 token 维度,并有 $N=HW$ 。Channel FC 层就是对这 $N$ 个 tokens 中的每一个的 $C$ 个值进行映射,而 Cycle FC 层就是对来自不同 tokens 的 $C$ 个值进行映射,且这 $C$ 个值是循环往复选取的。

这样一来,Cycle FC 层就和 Channel FC 层的计算复杂度一致了,都是与序列长度 $N$ 成线性关系。这样,模型中所有的 Channel FC 层都使用 Cycle FC 层进行替换,就得到了 Cycle MLP 模型,它是第一个提供检测分割任务的 MLP 模型基线的模型。

模型的设计思路遵循金字塔结构,即随着层数的加深,特征的分辨率在逐渐减小,也就是 token 的数目在不断地减小。假设输入图片的维度是 $H\times W\times 3$ ,首先通过 patch embedding 操作,即 stride=4 的 patch embedding 操作将 $H\times W\times 3$ 投影成 $\frac{H}{4}\times\frac{W}{4}\times C$ 的特征。

接下来连续地通过几个 Cycle FC block,在每个阶段中 token 的数量保持不变。在阶段与阶段的转换中,token 的数量减少,同时增加 channel 数。 这种策略有效地降低了空间分辨率的复杂性。总的来说,每个模型都有四个阶段,最后阶段的输出特征是 $\frac{H}{32}\times\frac{W}{32}\times 4C$ 。

Cycle FC 块的具体表达式是:

$$\begin{aligned} \hat{\mathbf{z}}^{\ell} &={\text { Cycle } \mathrm{FC}}\left(\mathrm{LN}\left(\mathbf{z}^{\ell-1}\right)\right)+\mathbf{z}^{\ell-1}, & & \ell=1 \ldots L \ \mathbf{z}^{\ell} &={\text { Channel-MLP }}\left(\operatorname{LN}\left(\hat{\mathbf{z}}^{\ell}\right)\right)+\hat{\mathbf{z}}^{\ell}, & & \ell=1 \ldots L \end{aligned}\tag{5}$$

block 的设计和 ViT 相似,只是将 MHSA 换成了 Cycle FC 模块。 因此,Cycle FC 模块可以作为现有的变压器或基于 mlp 架构的现成替代品。

Cycle Fully-Connected Layer

如上图10 (a) 所示,Channel FC 层的结构与图像的尺度无关。因此,这个操作可以处理可变的输入图像尺度。此外,Channel FC 的另一个优点是它的计算复杂度与图像尺度的关系是线性的。然而,它有限的感受野,不能聚合足够的上下文信息。

为了增加感受野的大小,如上图10 (c) 所示,Cycle FC 层的也在 channel 的维度进行计算,但是,不同于 Channel FC 层的采样点位于同一空间位置的所有通道,Cycle FC 层的采样点像个梯子,采用一种循环采样的方式。

得益于这种简单而有效的设计,Cycle FC 层具有与Channel FC 层严格相等的计算复杂度。另外,感受野从一个点扩大到伪核 (Pseudo-kernel),这将在下面介绍。

伪核 (Pseudo-kernel)

图11:伪核 \(Pseudo-kernel\)

这里作者引入了伪核的概念。如上图11所示,在空间表面投影 Cycle FC 层的采样点 (橙色块) ,并定义投影区域为 pseudo-kernel size。令 $X \in \mathbb{R}^{HW \times C_i}$ 表示输入特征映射,其中 $HW$ 表示其展平的高度和宽度, $C_i$ 表示输入通道。对于 Channel FC 层,输出特征是: $Y \in \mathbb{R}^{HW \times C_o}$ 。

$$
Y{i, j}=\sum{c=0}^{C{i}-1} \mathcal{F}{j, c}^{T} \cdot X_{i, c}\tag{6}
$$

Cycle FC 层的计算范式如下式所示:

$$
Y{i, j}=\sum{c=0}^{C{i}-1} \mathcal{F}{j, c}^{T} \cdot X{i+c \% S{\mathcal{P}}, c}\tag{7}
$$

将 $X{i, c} $ 变成了 $X{i+c\%S\mathcal{P}, c}$ ,随着 $c$ 的增加,采样的位置 $i+c\%S\mathcal{P}$ 也在不断地变化。这里 $S_\mathcal{P}=K_h \times K_w$ 是伪核的尺寸,也可以看做 Cycle FC 层的感受野。

Cycle FC 层可以感知相邻的上下文信息,同时保持与 Channel FC 层相同的复杂度。因此,Cycle FC 是一个通用的即插即用操作符,用于促进空间上下文聚合。注意,当将伪内核大小配置为 $S_\mathcal{P}=1 \times 1$ 时,Cycle FC 层退化为 Channel FC。

模型及其变体

图12:Cycle MLP 模型及其变体

如上图12所示是 Cycle MLP 模型及其变体,计算量的范围从2.1GFLOPs 到 12.3GFLOPs 不等。并行的 Cycle FC 分支的数目设置为3个,伪核容量为 $S_\mathcal{P}=1 \times 3,3 \times 1,1 \times 1$ 。 超参数的定义如下:

• $S_i$:第 $i$ 个 stage 过渡层的 stride 值。

• $C_i$:第 $i$ 个 stage 的 Embedding dimension。

• $L_i$:第 $i$ 个 stage 的 block 数。

• $E_i$:第 $i$ 个 stage 的 Expansion ratio。

ImageNet 分类实验

如下图13所示为 ImageNet 分类实验结果。作者首先比较 Cycle MLP 模型和现有的 MLP-like 模型,结果总结在下图。Cycle MLP 模型的 Accuracy/FLOPs 的 trade-off 始终优于现有的 MLP-like 模型,作者把这归功于Cycle FC 模块的有效性。此外,与现有的 SOTA MLP-like 模型,即 ViP 相比,Cycle MLP-B5 只使用了一半的浮点运算 (12.3 GFLOPs) ,同时保持了相同的 Top-1 准确度(83.2\%)。

图13:ImageNet 分类实验结果

下图14进一步比较了 Cycle MLP 模型与先前的最先进的 CNN,Transformer 和混合结构的性能。有趣的是,看到 Cycle MLP 模型实现了与 Swin-transformer 相当的性能,Swin-transformer 是最先进的基于 Transformer 的模型。具体来说,Cycle MLP-B5 的准确率几乎与 Swin-B 相当 (83.2\%) ,但是参数和浮点运算量都较低。

图14:Cycle MLP 模型与先前的最先进的 CNN,Transformer 和混合结构的性能对比

对比实验:三个分支的作用

下图15进一步详细描述了 Cycle MLP 模型的对比实验。结果表明,去掉 $S\mathcal{P}=1 \times 3,3 \times 1,1 \times 1$ 这三个平行支路中的任意一个,特别是去掉 $S\mathcal{P}=1 \times 3,3 \times 1$ 后,Top-1 Accuracy 会明显下降。为了消除更少的参数和浮点运算导致性能下降的可能性,作者进一步使用两个相同的分支和一个 $S_\mathcal{P}=1 \times 1$ 分支来对齐参数和浮点运算。这进一步证明了3个独特分支的必要性。

图15:Cycle MLP 模型的对比实验

目标检测实验

数据集: COCO,118k训练集,5k验证集

Backbone: RetinaNet,Mask R-CNN

初始化权重: ImageNet pre-trained weights

结果如下图16所示:

图16:目标检测实验结果

基于 CycleMLP 模型的 Retina-Net 在相似的参数约束条件下,始终超过了基于 CNN 的 Resnet、 ResneXt 和基于 Transformer 的 PVT,这表明 CycleMLP 模型可以作为一个优秀的通用 Backbone 网络。

语义分割实验

数据集: ADE20K,20k训练集,2k验证集,3k测试集

Backbone: Semantic FPN

初始化权重: ImageNet pre-trained weights

结果如下图17所示:

图17:语义分割实验结果

CycleMLP 在类似参数下的性能优于 Resnet 和 PVT。此外,与最先进的基于 Transformer 的 Backbone Swin transformer 相比,CycleMLP 可以获得相当甚至更好的性能。

5.2 CycleMLP代码解读

代码来自:

https://github.com/ShoufaChen/CycleMLP

1 Cycle FC

我们首先看看最关键的 Cycle FC 是如何实现的。

CycleFC 的实现其实借助了 torchvision.ops.deform_conv 中的 deform_conv2d 这个 API。所以,可以从某种角度上认为:Cycle FC 操作是一种极其特殊的 Deformable Conv。但也不能完全这么说,因为毕竟 Cycle FC 操作和 Deformable Conv 的具体做法存在本质的区别

class CycleFC(nn.Module): 
    """
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size,  # re-defined kernel_size, represent the spatial area of staircase FC
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
    ):
        super(CycleFC, self).__init__()

        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        if stride != 1:
            raise ValueError('stride must be 1')
        if padding != 0:
            raise ValueError('padding must be 0')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.groups = groups

        self.weight = nn.Parameter(torch.empty(out_channels, in_channels // groups, 1, 1))  # kernel size == 1

        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)
        self.register_buffer('offset', self.gen_offset())

        self.reset_parameters()

deform_conv2d 这个操作需要传入几个参数:
input,offset,weight,bias,stride,padding,dilation

weight 参数设置为 kernel size=1×1。

offset 参数不随优化器进行更新,所以使用了:
self.register_buffer('offset', self.gen_offset())
进行定义。

2 gen_offset()

这个函数就是生成 offset 值的函数,定义如下:

    def gen_offset(self):
        """
        offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width,
            out_height, out_width]): offsets to be applied for each position in the
            convolution kernel.
        """
        offset = torch.empty(1, self.in_channels*2, 1, 1)
        start_idx = (self.kernel_size[0] * self.kernel_size[1]) // 2
        assert self.kernel_size[0] == 1 or self.kernel_size[1] == 1, self.kernel_size
        for i in range(self.in_channels):
            if self.kernel_size[0] == 1:
                offset[0, 2 * i + 0, 0, 0] = 0
                offset[0, 2 * i + 1, 0, 0] = (i + start_idx) % self.kernel_size[1] - (self.kernel_size[1] // 2)
            else:
                offset[0, 2 * i + 0, 0, 0] = (i + start_idx) % self.kernel_size[0] - (self.kernel_size[0] // 2)
                offset[0, 2 * i + 1, 0, 0] = 0
        return offset

每个 channel 都有 2 个偏移值,所以维度设置为:(1, self.in_channels*2, 1, 1)

kernel_size 是 CycleFC 这个 class 传入的参数,如果是 1×3 的话,则:
offset[0, 2 * i + 0, 0, 0] = 0,即每个 channel 的横向偏移为0。
offset[0, 2 * i + 1, 0, 0] = (i + start_idx) \% self.kernel_size[1] - (self.kernel_size[1] // 2),即每个 channel 的纵向偏移为 $c\%S_\mathcal{P}$ 。

接下来有了 offset 和 weight,就能够使用 deform_conv2d 这个 API 构建 CycleFC 的操作了。

    def forward(self, input: Tensor) -> Tensor:
        """
        Args:
            input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor
        """
        B, C, H, W = input.size()
        return deform_conv2d_tv(input, self.offset.expand(B, -1, H, W), self.weight, self.bias, stride=self.stride,
                                padding=self.padding, dilation=self.dilation)

3 Cycle MLP Block

借助 CycleFC 操作实现 Cycle MLP Block。

class CycleMLP(nn.Module):
    def __init__(self, dim, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.mlp_c = nn.Linear(dim, dim, bias=qkv_bias)

        self.sfc_h = CycleFC(dim, dim, (1, 3), 1, 0)
        self.sfc_w = CycleFC(dim, dim, (3, 1), 1, 0)

        self.reweight = Mlp(dim, dim // 4, dim * 3)

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, H, W, C = x.shape
        # h: (B, H, W, C)
        h = self.sfc_h(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        # w: (B, H, W, C)
        w = self.sfc_w(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        # c: (B, H, W, C)
        c = self.mlp_c(x)

        # a: (B, C)
        a = (h + w + c).permute(0, 3, 1, 2).flatten(2).mean(2)
        # a: (3, B, 1, 1, C)
        a = self.reweight(a).reshape(B, C, 3).permute(2, 0, 1).softmax(dim=0).unsqueeze(2).unsqueeze(2)

        # X: (B, H, W, C)
        x = h * a[0] + w * a[1] + c * a[2]

        x = self.proj(x)
        x = self.proj_drop(x)

        # X: (B, H, W, C)
        return x

每一步的维度变化都已经以注释的形式标注在了代码里面,可以清晰地看到 Cycle MLP Block 的3个分支,并且这3个分支还进行了一次:
a = self.reweight(a).reshape(B, C, 3).permute(2, 0, 1).softmax(dim=0). unsqueeze(2).unsqueeze(2)
x = h * a[0] + w * a[1] + c * a[2]
操作,它极其类似于 SE 模块的 Attention,可能是涨点的关键。

4 CycleBlock

借助 Cycle MLP Block 实现 CycleBlock。

class CycleBlock(nn.Module):

    def __init__(self, dim, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip_lam=1.0, mlp_fn=CycleMLP):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = mlp_fn(dim, qkv_bias=qkv_bias, qk_scale=None, attn_drop=attn_drop)

        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)
        self.skip_lam = skip_lam

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x))) / self.skip_lam
        x = x + self.drop_path(self.mlp(self.norm2(x))) / self.skip_lam
        return x

5 basic_blocks

借助 CycleBlock 实现 basic_blocks。

def basic_blocks(dim, index, layers, mlp_ratio=3., qkv_bias=False, qk_scale=None, attn_drop=0.,
                 drop_path_rate=0., skip_lam=1.0, mlp_fn=CycleMLP, **kwargs):
    blocks = []

    for block_idx in range(layers[index]):
        block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
        blocks.append(CycleBlock(dim, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                      attn_drop=attn_drop, drop_path=block_dpr, skip_lam=skip_lam, mlp_fn=mlp_fn))
    blocks = nn.Sequential(*blocks)

    return blocks

6 整个 CycleNet

class CycleNet(nn.Module):
    """ CycleMLP Network """
    def __init__(self, layers, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
        embed_dims=None, transitions=None, segment_dim=None, mlp_ratios=None, skip_lam=1.0,
        qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
        norm_layer=nn.LayerNorm, mlp_fn=CycleMLP, fork_feat=False):

        super().__init__()
        if not fork_feat:
            self.num_classes = num_classes
        self.fork_feat = fork_feat

        self.patch_embed = PatchEmbedOverlapping(patch_size=7, stride=4, padding=2, in_chans=3, embed_dim=embed_dims[0])

        network = []
        for i in range(len(layers)):
            stage = basic_blocks(embed_dims[i], i, layers, mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias,
                                 qk_scale=qk_scale, attn_drop=attn_drop_rate, drop_path_rate=drop_path_rate,
                                 norm_layer=norm_layer, skip_lam=skip_lam, mlp_fn=mlp_fn)
            network.append(stage)
            if i >= len(layers) - 1:
                break
            if transitions[i] or embed_dims[i] != embed_dims[i+1]:
                patch_size = 2 if transitions[i] else 1
                network.append(Downsample(embed_dims[i], embed_dims[i+1], patch_size))

        self.network = nn.ModuleList(network)

        if self.fork_feat:
            # add a norm layer for each output
            self.out_indices = [0, 2, 4, 6]
            for i_emb, i_layer in enumerate(self.out_indices):
                if i_emb == 0 and os.environ.get('FORK_LAST3', None):
                    # TODO: more elegant way
                    """For RetinaNet, `start_level=1`. The first norm layer will not used.
                    cmd: `FORK_LAST3=1 python -m torch.distributed.launch ...`
                    """
                    layer = nn.Identity()
                else:
                    layer = norm_layer(embed_dims[i_emb])
                layer_name = f'norm{i_layer}'
                self.add_module(layer_name, layer)
        else:
            # Classifier head
            self.norm = norm_layer(embed_dims[-1])
            self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self.cls_init_weights)

    def cls_init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, CycleFC):
            trunc_normal_(m.weight, std=.02)
            nn.init.constant_(m.bias, 0)

    def init_weights(self, pretrained=None):
        """ mmseg or mmdet `init_weight` """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_embeddings(self, x):
        x = self.patch_embed(x)
        # B,C,H,W-> B,H,W,C
        x = x.permute(0, 2, 3, 1)
        return x

    def forward_tokens(self, x):
        outs = []
        for idx, block in enumerate(self.network):
            x = block(x)
            if self.fork_feat and idx in self.out_indices:
                norm_layer = getattr(self, f'norm{idx}')
                x_out = norm_layer(x)
                outs.append(x_out.permute(0, 3, 1, 2).contiguous())
        if self.fork_feat:
            return outs

        B, H, W, C = x.shape
        x = x.reshape(B, -1, C)
        return x

    def forward(self, x):
        x = self.forward_embeddings(x)
        # B, H, W, C -> B, N, C
        x = self.forward_tokens(x)
        if self.fork_feat:
            return x

        x = self.norm(x)
        cls_out = self.head(x.mean(1))
        return cls_out

7 最后注册好全部的模型:

@register_model
def CycleMLP_B1(pretrained=False, kwargs):
transitions = [True, True, True, True]
layers = [2, 2, 4, 2]
mlp_ratios = [4, 4, 4, 4]
embed_dims = [64, 128, 320, 512]
model = CycleNet(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions,
mlp_ratios=mlp_ratios, mlp_fn=CycleMLP,
kwargs)
model.default_cfg = default_cfgs['cycle_S']
return model

... ...

小结

CycleMLP 的关键是 CycleFC 操作,Cycle FC 层的也在 channel 的维度进行计算,但是,不同于 Channel FC 层的采样点位于同一空间位置的所有通道,Cycle FC 层的采样点像个梯子,采用一种循环采样的方式。在代码实现上,作者巧妙地借助了 torchvision.ops.deform_conv 中的 deform_conv2d 这个 API,并且 Cycle MLP Block 中有类似于 SE 模块的注意力操作。总而言之,CycleMLP 是一个优秀的 MLP-like 模型,在分类检测分割任务上都取得了优异的性能。

  • 0
  • 0
  • 249
收藏
暂无评论