• 问答
  • 技术
  • 实践
  • 资源
Vision MLP 网络架构超详细解读(三)
技术讨论

作者丨科技猛兽
编辑丨极市平台

本文目录

6 GFNet:将FFT思想用于空间信息交互
(来自清华大学)
6.1 GFNet原理分析
6.2 GFNet代码解读

6 GFNet:将FFT思想用于空间信息交互

论文名称:Global Filter Networks for Image Classification

论文地址:

https://arxiv.org/pdf/2107.00645.pdfarxiv.org/pdf/2107.00645.pdf

  • 6.1 GFNet原理分析

本文提出了一种 GFNet 模型,是清华大学周杰团队在 MLP 方面的探索。

Vision Transformer 模型使用自注意力机制 (Self-attention) 来捕获长距离的依赖关系,也就是图片块之间的相互关系。

Vision MLP 模型通过用跨空间位置的 MLPs 替换 Self-attention 层,在空域进行 tokens 间信息交换,进一步简化了空域信息的融合操作。

由于不像 CNN 那样引入了归纳偏置 (inductive bias),这两种模型有可能从原始数据中学到空间位置之间更加通用和灵活的关系。

但是,不论是 Vision Transformer 模型还是 Vision MLP 模型,都有个共同的缺点,就是模型在训练和推理时的计算复杂度随着 token (或者 patch) 的数量的增加而二次增长 $O(N^2)$ 。那么为了解决这个问题,Vision Transformer 模型和 Vision MLP 模型会首先对图片分 patch,比如一张 $224×224×3$ 的图片,分为一堆 $16×16×3$ 的 patch,那么一共就有14×14=196 个 patch,即 $N=196$ 。通过这样的手段来控制计算复杂度不至于过大。

但是,这种设计可能会限制下游密集预测任务的应用,比如检测和分割任务。

为了再解决下游任务的问题,Swin Transformer 给出了一种很有效的解决方案,就是通过人为划分 window 的方式用 Local Self-attention 来取代 Global Self-attention,详细得方案可以参考下面的解读:

尽管在实践中有效,Local Self-attention 带来了相当多的手工选择 ,例如,窗口大小 (window size),填充策略 (padding strategy) 等。并限制了每一层的感受野大小。

在本文中,作者提出了全局滤波器网络 (Global Filter Network, GFNet),一种概念简单且计算高效的架构,它在频域以log-linear复杂度学习长距离空间依赖,提出了一种全局滤波器方案在频域进行token间信息交换。GFNet 模型背后的基本思想是学习空间位置之间在频率域的相互作用。不同于 Vision Transformer 模型中的自注意机制和 Vision MLP 模型中的空域上对不同 token 作用的 MLP 层,在 GFNet 里面,token 之间的相互作用被建模为一组应用的可学习的全局滤波器。

因为全局过滤器可以覆盖所有的频率,我们的模型可以捕捉到长期 (long-term) 和短期 (short-term) 的相互作用。所提架构很大程度上是基于 Vision Transformer 模型,只有一些小的修改,具体是通过如下三个关键操作替代 ViT 中的自注意力层:

  • 2D 离散傅里叶变换 (2D discrete Fourier transform):将输入的特征从空间域转化到频域。
  • 频域特征与全局可学习滤波器的点乘操作。
  • 2D逆傅里叶变换 (2D inverse Fourier transform): 将特征从频域再转化到空间域。

由于是采用了傅里叶变换用于混合不同 token 的信息,所以全局滤波器相对于 Self-attention 和 MLP 要有效得多,这要归功于快速傅里叶变换算法 (FFT) 的 $O(N\log N)$ 理论复杂度。 受益于这一点,全局滤波器对于 patch 的数量不过与敏感了。下图1是几种不同模型基本操作的计算复杂度和参数量:

图1:几种不同模型基本操作的计算复杂度和参数量

图中, $H,W,D$ 分别是特征图的长度,宽度和通道数, $k$ 是卷积核大小。本文所提出的全局滤波器比Self-attention 和 MLP 要高效得多。

所提方案在 ImageNet 以及下游任务上表现出了非常有力的精度-复杂度均衡。相比Transformer 与 CNN 模型,所提方案在高效性、泛化性以及鲁棒性方面极具竞争力。

离散傅里叶变换 (discrete Fourier transform)

离散傅里叶变换 (DFT) 在数字信号处理领域起着非常重要的作用。首先介绍1维离散傅里叶变换:

给定一个长度为 $N$ 的序列 $x[n], 0\le n\le N-1$ ,1D DFT通过如下公式将其转换到频域:

$$
\begin{equation} X[k]=\sum{n=0}^{N-1} x[n] e^{-j(2\pi / N) k n} := \sum{n=0}^{N-1} x[n] W_N^{kn} \label{equ:dft} \end{equation} \tag{1}
$$

式中, $j$ 是虚部的符号, $W_N=e^{-j(2\pi/N)}$ 。1D DFT 的表达式可以从连续信号的傅里叶变换中通过时域和频域的采样得到。

从傅里叶变换到离散傅里叶变换

离散傅里叶变换可以通过多种方式得到。这里我们将介绍最初为连续信号设计的标准傅里叶变换 FT 到 DFT 的公式。傅里叶变换将连续信号从时域转换到频域,可以看作是傅里叶级数的延伸。

信号 $x(t)$ 的傅里叶变换是:

$$
\begin{equation} X(j\omega) = \int_{-\infty}^\infty x(t)e^{-j\omega t} dt := \mathcal{F}[x(t)]. \end{equation} \tag{2}
$$

信号 $X(j\omega) $ 的傅里叶反变换是:

$$
\begin{equation} x(t) = \frac{1}{2\pi}\int_{-\infty}^\infty X(j\omega) e^{j\omega t}d\omega. \end{equation} \tag{3}
$$

从傅立叶变换和傅立叶反变换的公式中,我们可以看到傅立叶变换在时域和频域之间的对偶性质。二重性表明时域的性质在频域中总是有对应的。傅里叶变换有各种各样的特性。举几个基本的例子,单位脉冲函数 ( $\delta$ 函数) 的傅里叶变换是:

$\begin{equation} \mathcal{F}[\delta(t)]=\int{-\infty}^\infty \delta(t)e^{-j\omega t} dt=\int{0-}^{0+} \delta(t) dt = 1, \label{equ:delta_FT} \end{equation} \tag{4}$

时移特性:

$$
\begin{equation} \mathcal{F}[\delta(t-t0)]=\int{-\infty}^\infty x(t-t_0)e^{-j\omega t} dt=e^{-j\omega t0}\int{-\infty}^\infty x(t)e^{-j\omega t} dt = e^{-j\omega t_0}X(j\omega). \label{equ:time_shift_FT} \end{equation} \tag{5}
$$

然而,在实际应用中,我们很少处理连续信号。一般的做法是对连续信号进行处理,得到一个离散信号序列。采样可以用一系列单位脉冲函数来实现:

$\begin{equation} xs(t) = x(t)\sum{n=-\infty}^\infty \delta(t - nTs) = \sum{n=-\infty}^\infty x(nT_s)\delta(t - nT_s), \end{equation} \tag{6}$

式中, $T_s$ 是采样间隔。那么现在对采样信号 $x_s(t)$ 应用傅里叶变换,有:

$$
\begin{equation} Xs(j\omega) = \sum{n=-\infty}^\infty x(nT_s) e^{-j\omega nT_s}. \end{equation} \tag{7}
$$

上式说明了 $X_s(j\omega)$ 是个周期函数,并且 $\omega=\frac{2\pi}{T_s}$ 。 通常,我们更喜欢一个标准化频率 $\omega=\omega T_s$ ,这样, $X_s(j\omega) $ 的周期正好是 $2\pi$ 。

进一步,用 $x[n]$ 来表示 $x(nT_s)$ ,代表离散时间信号,则离散时间傅里叶变换 (discrete-time Fourier transform (DTFT)) 就是:

$$
\begin{equation} X(e^{j\omega})=\sum_{n=-\infty}^\infty x[n]e^{-j\omega n}. \label{equ:DTFT} \end{equation} \tag{8}
$$

如果离散信号 $x[n]$ 具有有限长度 $N$ ,则离散时间傅里叶变换就是:

$$
\begin{equation} X(e^{j\omega})=\sum_{n=0}^{N-1} x[n]e^{-j\omega n}, \end{equation} \tag{9}
$$

注意,DTFT 是 $\omega$ 的连续函数,我们可以通过抽样得到频率 $\omega_k=2\pi k/N$ 的 $X[k]$ 序列:

$$
\begin{equation} X[k] = X(e^{j\omega})|{\omega = 2\pi k/N} = \sum{n=0}^{N-1}x[n]e^{-j(2\pi/N)kn}, \end{equation} \tag{10}
$$

这就是离散傅里叶变换 (DFT) 的公式。从一维 DFT 到二维 DFT 的扩展很简单。事实上,2D DFT 可以看作是在两个维度上交替地执行1D DFT,也就是说, $x[m,n]$ 的 2D DFT 是由下面给出的:

$\begin{equation} X[u, v] = \sum{m=0}^{M-1}\sum{n=0}^{N-1}x[m, n]e^{-j2\pi\left(\frac{um}{M}+\frac{vn}{N}\right)}. \end{equation} \tag{11}$

离散傅里叶变换的性质

1 实数信号的 DFT

给定一个实数信号 $x[n]$ ,它的 DFT 是共轭对称的,可以证明如下:

$\begin{equation} X[N - k] = \sum{n=0}^{N-1}x[n]e^{-j(2\pi / N)(N-k)n}=\sum{n=0}^{N-1}x[n]e^{j(2\pi / N)kn}=X^*[k]. \end{equation} \tag{12}$

这个性质意味着 DFT 信号的一半 ${X[k]: 0\le k\le \lceil N / 2\rceil}$ 就包含了关于 $x[n]$ 频率特性的全部信息。

对于 2D 信号,我们得到了类似的结果:

$\begin{equation} \begin{split} X[M-u, N-v]&=\sum{m=0}^{M-1}\sum{n=0}^{N-1}x[m, n]e^{-j2\pi\left(\frac{(M-u)m}{M}+\frac{(N-v)n}{N}\right)}\ &=\sum{m=0}^{M-1}\sum{n=0}^{N-1}x[m, n]e^{j2\pi\left(\frac{um}{M}+\frac{vn}{N}\right)}=X^*[u, v]. \end{split} \end{equation} \tag{13}$

2 卷积定理

傅里叶变换最重要的性质之一就是卷积定理。具体地说,对 DFT 来说,卷积定理表明频域中的多项式乘积等价于时域中的卷积。具体而言:

信号 $x[n]$ 和滤波器 $h[n]$ 的卷积是:

$$
\begin{equation} y[n] = \sum_{m=0}^{N-1}h[m]x[((n-m))_N], \end{equation} \tag{14}
$$

考虑输出信号 $y[n]$ 的离散傅里叶变换,我们有:

$$
\begin{equation} \begin{split} Y[k] &= \sum{n=0}^{N-1} \sum{m=0}^{N-1}h[m]x[((n-m))N]e^{-j(2\pi/N)kn}\ &=\sum{m=0}^{N-1}h[m]e^{-j(2\pi/N)km}\sum_{n=0}^{N-1}x[((n-m))N]e^{-j(2\pi/N)k(n-m)}\ &=H[k]\left(\sum{n=m}^{N-1}x[n-m]e^{-j(2\pi/N)k(n-m)} + \sum{n=0}^{m-1}x[n-m + N]e^{-j(2\pi/N)k(n-m)}\right)\ &=H[k]\left(\sum{n=0}^{N-m-1}x[n]e^{-j(2\pi/N)kn} + \sum{n=N-m}^{N-1}x[n]e^{-j(2\pi/N)kn}\right)\ &=H[k]\sum{n=0}^{N-1}x[n]e^{-j(2\pi/N)kn}=H[k]X[k], \end{split} \end{equation} \tag{15}
$$

其中,等式右侧正是信号和滤波器在频域中的相乘。2D 场景中的卷积定理可以用类似的方法来推导。

具体来说, $X[k]$ 表示频率 $\omega_k=2\pi k/N$ 时的序列 $x[n]$ 的谱。

同样值得注意的是 DFT 是一对一的转换。给定 DFT $X[k]$ ,我们可以通过 inverse DFT (IDFT) 恢复原始信号:

$$
\begin{equation} x[n] = \frac{1}{N} \sum_{k=0}^{N - 1}X[k]e^{j(2\pi/N)kn}.\label{equ:idft} \end{equation} \tag{16}
$$

DFT 在现代信号处理算法中得到广泛应用,主要有两个原因:

  1. DFT 的输入和输出都是离散的,因此易于计算机处理。
  2. 存在有效的 dft 计算算法:快速傅里叶变换 (Fast Fourier transform,FFT) 算法利用了信号的对称性和周期性,降低了计算 DFT 的复杂度,从 $\mathcal{O}(N^2)$ 降低到 $\mathcal{O}(N\log N)$ 。DFT 逆变换的形式与 DFT 相似,也可以通过逆快速傅里叶变换 (IFFT) 有效地计算出来。

GFNet

GFNet 的具体架构如下图1所示。

图1:GFNet 的具体架构

整体架构与 ViT 和 MLP-Mixer 非常相似。即:首先对输入图片分块,分成 $p\times p\times 3$ 大小的块。一共有 $L=HW$ 个块。再通过 PatchEmbedding 操作得到 $L\times D$ 的 Feature map,这个操作和 ViT,MLP-Mixer 是完全一致的。

然后这个 $L\times D$ 的 Feature map 将通过相同的 $N$ 个 GFNet block,每一个 GFNet block 都包括一个 Global Filter Layer 高效地融合空间信息 (复杂度是 $\mathcal{O}(L\log L)$ )。还包括一个 FFN。最后一个 Block 的输出经过全局平均池化层和线性层完成分类任务。

下面就来着重讨论一下 Global filter layer 的计算机制。

Global filter layer 作为 Self-attention layer 的替代方案,Self-attention layer 可以混合不同空间位置 token 的信息。给定 tokens ${x}\in \mathbb{R}^{H\times W\times D}$ ,我们首先在空间维度进行二维 FFT,将其转换到频域:

$\begin{equation} {X}=\mathcal{F}[{x}]\in \mathbb{C}^{H\times W\times D}, \end{equation} \tag{17}$

式中, $\mathcal{F}[\cdot]$ 代表 2D FFT。注意到 ${X}$ 是一个复数的张量,代表着 $x$ 的频谱。

然后我们可以通过将一个可学习的滤波器 ${K}\in \mathbb{C}^{H\times W\times D}$ 乘以 ${X}$ 来调整 $x$ 的频谱:

$\begin{equation} \tilde{{X}} = {K}\odot {X}, \end{equation} \tag{18}$

式中, $\odot$ 操作是 Hadamard product。 ${K}$ 是 global filter 因为它和 ${X}$ 有相同的维度。

最后,我们采用 IFFT 将调制频谱 $\tilde{{X}}$ 转换回空间域,并更新 tokens:

$\begin{equation} {x}\leftarrow \mathcal{F}^{-1}[\tilde{{X}}]. \end{equation} \tag{19}$

上述核心部分的实现伪代码如下:

X = rfft2(x, dim=(1, 2))
X_tilde = X * K
x = irfft2(X_tilde, dim=(1, 2))

Global filter layer 的产生来自数字图像处理中的 Frequency filter。 全局滤波器可以看作是一组可学习的频率滤波器。

根据上文的14式和15式,我们发现:

Global filter layer 其实等价于 Depthwise global circular convolution,卷积核大小为 $H\times W$ 。

为什么有这个结论?

答: 根据上文的14式和15式,对于单一的 channel 而言, ${K}\left( H\times W \right)\odot {X} \left( H\times W \right)$ 就相当于是 2D 频域的相乘,也就是 2D 空域的卷积。

所以,对于多个 channel 的情况,2D 频域的相乘,也就是 Depthwise 的 2D 空域的卷积。

所以 $\tilde{{X}} = {K}\odot {X}$ 这一步就相当于是空域进行了一次卷积运算,那么在频域中就是这样的点乘运算。但是这个运算又和卷积运算不完全一样,因为这种频域上的点乘运算也没有归纳偏置,相当于卷积核大小为 $H\times W$ 的卷积运算。 全局滤波器也可以解释为一个空间域操作, 注意,在频域实现的全局滤波器也比空间域高效得多,计算复杂度是 $\mathcal{O}(DL\log L)$ 。 普通 Depthwise Convolution 的计算复杂度是: $\mathcal{O}(DL^2)$ 。

值得注意的是,在实现过程中,作者利用 DFT 的性质来减少冗余的计算。因为 $x$ 是实数张量,所以其 DFT ${X}$ 是共轭对称的,即有: ${X}[H - u, W - v, :]={X}^*[H, W, :]$ 。因此,我们可以只取 ${X}$ 的一半的值,而获得全部的信息。

$$
\begin{equation} {X}_r = {X}[:, 0:\widehat{W}]:=\mathcal{F}_r[{x}], \quad, \widehat{W}=\lceil W / 2\rceil, \end{equation} \tag{20}
$$

这样一来,全局滤波器的实现就变为了 ${K}_r\in \mathbb{C}^{H\times \widehat{W}\times D}$ ,可以节约一半的参数,同时也可以保证 $\mathcal{F}^{-1}_r[{K}_r\odot {X}_r]$ 是实数,能够直接与输入张量 $x$ 相加。

GFNet 模型的优势:

1 计算高效: 复杂度是 FFT $\mathcal{O}(L\log L)$ + 点乘 $\mathcal{O}(L)$ + IFFT $\mathcal{O}(L)$ = $\mathcal{O}(L\log L)$

2 能够处理不同分辨率的输入

Architecture variants

考虑到自注意力与MLP的高计算复杂度问题,现有 ViT、MLP 采用快速降低分辨率的方式,即初始的 PatchEmbedding 尺寸非常大,比如$14\times 14$。然而,GFNet的的计算复杂度为log-linear,可以避免上述问题。因此,我们可以以更高分辨率(比如$56\times 56$)的特征作为起点,然后逐渐下采样。在这篇文章中,我们主要探索了两种形式的 GFNet,即 Transformer 风格与 CNN 风格。

Transformer 风格的 GFNet 每层的 token 数是固定的,而 CNN 风格的 GFNet 采用金字塔结构。

  • 对于 Transformer 风格,类似 DeiT 与 ResMLP-12,我们同样采用了12层模型并得到了三个尺寸的模型 GFNet-Ti,GFNet-S 以及 GFNet-B (通过调整维度、深度等信息即可得到)。
  • 对于 CNN 风格,我们同样设计了三种复杂度的模 型GFNet-H-Ti、GFNet-H-S、GFNet-H-B。相关信息见下图2。

图2:CNN 风格 GFNet

Experiments

为验证所提方案的有效性,我们在ImageNet分类以及下游任务 (语义分割) 上进行了对比分析。

ImageNet 实验

与 DeiT 的训练不同的是,我们没有使用 EMA,RandomEarse 和 Repeated augmentation,这些模型对于训练效果的好坏有轻微影响。

下图3给出了与 Transformer 风格架构的性能对比,从中可以看到:

  • 所提方法明显优于近期的 MLP 类方案与 DeiT 等方案。
  • GFNet-XS比ResMLP高2.0\%且具有稍少的计算量。
  • GFNet-S同样具有比 gMLP-S、DeiT-S 更高的精度。
  • GFNet-Ti显著优于 DeiT-Ti (+2.4\%) 与 gMLP-Ti (+2.6\%),且具有相似复杂度。

图3:与 Transformer 风格架构的性能对比

下图4给出了与 CNN 风格架构的性能对比,从中可以看到: 受益于对数线性复杂度,GFNet-H 模型显示了比 ResNet,RegNet 和 PVT 更好的性能,并实现了与 Swin 相似的性能,同时具有更简单和更通用的设计。

图4:与 CNN 风格架构的性能对比

Downstream tasks

下图5对比了 GFNet 在不同数据集上的迁移能力,从中可以看到:GFNet 具有更佳的迁移能力。比如,GFNet 显著优于 ResMLP,同时具有比 EfficientNet-B7 相当的性能。

图5:GFNet 在不同下游任务上的迁移能力

下图6对比了 GFNet 在 ADE20K 语义分割数据上的性能对比,从中可以看到:GFNet 在该任务上表现非常好,在不同复杂度方面取得了与其他模型 (比如ResNet、PVT、Swin) 相当甚至更好的性能。

图6:GFNet 在 ADE20K 语义分割数据上的性能对比

  • 6.2 GFNet代码解读

代码来自:

1 Global Filter Layer 代码实现

class GlobalFilter(nn.Module):
    def __init__(self, dim, h=14, w=8):
        super().__init__()
        self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)
        self.w = w
        self.h = h

    def forward(self, x, spatial_size=None):
        B, N, C = x.shape
        if spatial_size is None:
            a = b = int(math.sqrt(N))
        else:
            a, b = spatial_size

        x = x.view(B, a, b, C)

        x = x.to(torch.float32)

        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
        weight = torch.view_as_complex(self.complex_weight)
        x = x * weight
        x = torch.fft.irfft2(x, s=(a, b), dim=(1, 2), norm='ortho')

        x = x.reshape(B, N, C)

        return x

torch.view_as_complex(input)→ Tensor

GlobalFilter 这个类用到了 torch.view_as_complex 这个函数。
这个函数仅仅支持 torch.float64torch.float32 这2种数据类型,且数据的最后一维度必须是2。
举例:

>>> x=torch.randn(4, 2)
>>> x
tensor([[ 1.6116, -0.5772],
        [-1.4606, -0.9120],
        [ 0.0786, -1.7497],
        [-0.6561, -1.6623]])
>>> torch.view_as_complex(x)
tensor([(1.6116-0.5772j), (-1.4606-0.9120j), (0.0786-1.7497j), (-0.6561-1.6623j)])

torch.fft.rfft(input,n=None,dim=-1,norm=None,*,out=None)→ Tensor

GlobalFilter 这个类用到了 torch.fft.rfft 这个函数。
这个函数计算实值输入的一维傅里叶变换。如上文所述,实数信号的 FFT 是共轭对称 (conjugate symmetric) 的,即:X[i]= conj(X[-i]),所以输出只包含奈奎斯特频率以下的正频率。要计算完整的输出,请使用 FFT()。
参数:

  • input (Tensor) – 实数输入张量

  • n (int, optional) – 信号长度

  • dim (int, optional) – 做 FFT 变换的维度

  • norm (str, optional) – 正则化模式

    • "forward" - normalize by 1/n
    • "backward" - no normalization
    • "ortho" - normalize by 1/sqrt(n) (making the FFT orthonormal)

torch.fft.irfft(input,n=None,dim=-1,norm=None,*,out=None)→ Tensor

计算 rfft() 的逆变换。
输入被解释为傅里叶域中的共轭对称的信号,由 rfft() 产生。根据赫米特属性,输出将是实值的。
参数和torch.fft.rfft 函数一直。

torch.fft.rfft2(input,s=None,dim=(-2,-1),norm=None,*,out=None)→ Tensor

torch.fft.irfft2(input,s=None,dim=(-2,-1),norm=None,*,out=None)→ Tensor

这两个函数是 2D FFT 和 2D IFFT,参数多了一个:

  • s(Tuple[int],optional) – 变换后的信号大小。如果给定,在计算实数 FFT 之前,每个维度dim[i] 将被填充 (padding) 为零或修剪为长度 s[i] 。如果指定长度为 -1,则在该维度上不做填充。

另外需要注意的是这里 weight 的参数量:

self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)
标准设置是 w=h//2+1,比如:h=14, w=8
这是因为 rfft 的共轭对称性,x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho') 得到的 x 的维度中,宽度是长度的一半再加一,比如:

>>> t = torch.rand(10, 10)
>>> rfft2 = torch.fft.rfft2(t)
>>> rfft2.size()
torch.Size([10, 6])

2 GFNet Block

class Block(nn.Module):

    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, h=14, w=8):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.filter = GlobalFilter(dim, h=h, w=w)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.mlp(self.norm2(self.filter(self.norm1(x)))))
        return x

注意与图1的对应关系。

3 PatchEmbedding

class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x

图片分成 patch 的操作,方法就是通过一个卷积 nn.Conv2d。卷积核大小是 $p$ ,stride 也等于 $p$ 。

4 下采样 Downsample

class DownLayer(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=56, dim_in=64, dim_out=128):
        super().__init__()
        self.img_size = img_size
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.proj = nn.Conv2d(dim_in, dim_out, kernel_size=2, stride=2)
        self.num_patches = img_size * img_size // 4

    def forward(self, x):
        B, N, C = x.size()
        x = x.view(B, self.img_size, self.img_size, C).permute(0, 3, 1, 2)
        x = self.proj(x).permute(0, 2, 3, 1)
        x = x.reshape(B, -1, self.dim_out)
        return x

用于金字塔操作的下采样。

5 Transformer 风格的 GFNet ,每层的 token 数是固定的

class GFNet(nn.Module):

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 mlp_ratio=4., representation_size=None, uniform_drop=False,
                 drop_rate=0., drop_path_rate=0., norm_layer=None, 
                 dropcls=0):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module
            norm_layer: (nn.Module): normalization layer
        """
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)

        self.patch_embed = PatchEmbed(
                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        h = img_size // patch_size
        w = h // 2 + 1

        if uniform_drop:
            print('using uniform droppath with expect rate', drop_path_rate)
            dpr = [drop_path_rate for _ in range(depth)]  # stochastic depth decay rule
        else:
            print('using linear droppath with expect rate', drop_path_rate * 0.5)
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        # dpr = [drop_path_rate for _ in range(depth)]  # stochastic depth decay rule

        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, mlp_ratio=mlp_ratio,
                drop=drop_rate, drop_path=dpr[i], norm_layer=norm_layer, h=h, w=w)
            for i in range(depth)])

        self.norm = norm_layer(embed_dim)

        # Representation layer
        if representation_size:
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(embed_dim, representation_size)),
                ('act', nn.Tanh())
            ]))
        else:
            self.pre_logits = nn.Identity()

        # Classifier head
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        if dropcls > 0:
            print('dropout %.2f before classifier' % dropcls)
            self.final_dropout = nn.Dropout(p=dropcls)
        else:
            self.final_dropout = nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x).mean(1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.final_dropout(x)
        x = self.head(x)
        return x

实现方式和 ViT 一致,每层的 Block 使用 GFNet Block 实现。

6 CNN 风格的 GFNet ,每层的 token 数是不固定的,采用金字塔结构

class GFNetPyramid(nn.Module):

    def __init__(self, img_size=224, patch_size=4, num_classes=1000, embed_dim=[64, 128, 256, 512], depth=[2,2,10,4],
                 mlp_ratio=[4, 4, 4, 4],
                 drop_rate=0., drop_path_rate=0., norm_layer=None, init_values=0.001, no_layerscale=False, dropcls=0):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module
            norm_layer: (nn.Module): normalization layer
        """
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim[-1]  # num_features for consistency with other models
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)

        self.patch_embed = nn.ModuleList()

        patch_embed = PatchEmbed(
                img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim[0])
        num_patches = patch_embed.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim[0]))

        self.patch_embed.append(patch_embed)

        sizes = [56, 28, 14, 7]
        for i in range(4):
            sizes[i] = sizes[i] * img_size // 224

        for i in range(3):
            patch_embed = DownLayer(sizes[i], embed_dim[i], embed_dim[i+1])
            num_patches = patch_embed.num_patches
            self.patch_embed.append(patch_embed)

        self.pos_drop = nn.Dropout(p=drop_rate)
        self.blocks = nn.ModuleList()

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depth))]  # stochastic depth decay rule
        cur = 0
        for i in range(4):
            h = sizes[i]
            w = h // 2 + 1

            if no_layerscale:
                print('using standard block')
                blk = nn.Sequential(*[
                    Block(
                    dim=embed_dim[i], mlp_ratio=mlp_ratio[i],
                    drop=drop_rate, drop_path=dpr[cur + j], norm_layer=norm_layer, h=h, w=w)
                for j in range(depth[i])
                ])
            else:
                print('using layerscale block')
                blk = nn.Sequential(*[
                    BlockLayerScale(
                    dim=embed_dim[i], mlp_ratio=mlp_ratio[i],
                    drop=drop_rate, drop_path=dpr[cur + j], norm_layer=norm_layer, h=h, w=w, init_values=init_values)
                for j in range(depth[i])
                ])
            self.blocks.append(blk)
            cur += depth[i]

        # Classifier head
        self.norm = norm_layer(embed_dim[-1])

        self.head = nn.Linear(self.num_features, num_classes)

        if dropcls > 0:
            print('dropout %.2f before classifier' % dropcls)
            self.final_dropout = nn.Dropout(p=dropcls)
        else:
            self.final_dropout = nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        for i in range(4):
            x = self.patch_embed[i](x)
            if i == 0:
                x = x + self.pos_embed
            x = self.blocks[i](x)

        x = self.norm(x).mean(1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.final_dropout(x)
        x = self.head(x)
        return x

GFNetPyramid 一开始使用大小为4×4的 patch,所以得到56×56个 patch。
每个 stage 结束通过 DownLayer(sizes[i], embed_dim[i], embed_dim[i+1]) 实现下采样的操作。

  • 0
  • 0
  • 87
收藏
暂无评论