• 问答
  • 技术
  • 实践
  • 资源
搞懂 Vision Transformer 原理和代码,看这篇技术综述就够了(十六)
技术讨论

​作者丨科技猛兽
编辑丨极市平台

本文目录

35 Swin Transformer: 屠榜各大CV任务的视觉Transformer模型
(来自 微软亚研院,中科大)
35.1 Swin Transformer原理分析
35.2 Swin Transformer代码解读

36 SwinIR: 用于图像复原的 Swin Transformer
(来自 ETH Zurich)
36.1 SwinIR原理分析
36.2 SwinIR代码解读

Transformer 是 Google 的团队在 2017 年提出的一种 NLP 经典模型,现在比较火热的 Bert 也是基于 Transformer。Transformer 模型使用了 Self-Attention 机制,不采用 RNN 的顺序结构,使得模型可以并行化训练,而且能够拥有全局信息。

本文介绍2篇文章是 Swin Transformer 系列及其变体。Swin Transformer 是屠榜各大CV任务的通用视觉Transformer模型,它在图像分类、目标检测、分割上全面超越 SOTA,在语义分割任务中在 ADE20K 上刷到 53.5 mIoU,超过之前 SOTA 大概 4.5 mIoU!可能是CNN的完美替代方案。

35 Swin Transformer: 屠榜各大CV任务的视觉Transformer模型

论文名称:Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

论文地址:

https://arxiv.org/abs/2103.14030

35.1 Swin Transformer原理分析:

Swin Transformer 提出了一种针对视觉任务的通用的 Transformer 架构,Transformer 架构在 NLP 任务中已经算得上一种通用的架构,但是如果想迁移到视觉任务中有一个比较大的困难就是处理数据的尺寸不一样。作者分析表明,Transformer 从 NLP 迁移到 CV 上没有大放异彩主要有两点原因:

1. 最主要的原因是两个领域涉及的scale不同,NLP 任务以 token 为单位,scale 是标准固定的,而 CV 中基本元素的 scale 变化范围非常大。

2. CV 比起 NLP 需要更大的分辨率,而且 CV 中使用 Transformer 的计算复杂度是图像尺度的平方,这会导致计算量过于庞大, 例如语义分割,需要像素级的密集预测,这对于高分辨率图像上的Transformer来说是难以处理的。

Swin Transformer 就是为了解决这两个问题所提出的一种通用的视觉架构。Swin Transformer 引入 CNN 中常用的层次化构建方式。

接下来我们直观地梳理下 Swin Transformer 的前向传播过程,从一张输入图片开始经过以下6步:

1 图片预处理:分块和降维 (Patch Partition)

Swin Transformer 首先把 $\text{x}\in H\times W\times 3$ 的图片,变成一个 $\text{x}_p\in N\times (P^2\cdot C)$ 的2维的image patches。它可以看做是一系列的展平的2D块的序列,这个序列中一共有 $N=HW/P^2$ 个展平的2D块,每个块的维度是 $(P^2\cdot 3)$ 。其中 $P$ 是块大小。

在 Swin Transformer 中,块的大小 $P=4$ ,所以得到的 $\text{x}_p\in N\times 48$ ,这里的 $N=HW/16=\frac{H}{4}\times \frac{W}{4}$ 。

所以经过了这一步的分块操作,一张 $\text{x}\in H\times W\times 3$ 的图片就变成了 $\frac{H}{4}\times \frac{W}{4}\times 48$ 的张量,可以理解成是 $\frac{H}{4}\times \frac{W}{4}$ 个图片块,每个块是一个 $48$ 维的 token,如下图1所示。

2 Stage 1:线性变换 (Linear Embedding)

现在得到的向量维度是: $\frac{H}{4}\times \frac{W}{4}\times 48$ ,还需要做一步叫做Linear Embedding的步骤,对每个向量都做一个线性变换(即全连接层),变换后的维度为 $C$ ,这里我们称其为 Linear Embedding。这一步之后得到的张量维度是: $\frac{H}{4}\times \frac{W}{4}\times C$ 。

图1:Swin Transformer 结构

3 Stage 1:Swin Transformer Block

接下来 $\frac{H}{4}\times \frac{W}{4}\times C$ 这个张量进入2个连续的 Swin Transformer Block 中,这被称作 Stage 1,在整个的 Stage 1 里面 token 的数量一直维持 $\frac{H}{4}\times \frac{W}{4}$ 不变。

Swin Transformer Block 具体是如何操作的呢?

图2:2个连续的 Swin Transformer Block

Swin Transformer Block 的结构如上图2所示。上图是2个连续的 Swin Transformer Block。其中一个 Swin Transformer Block 由一个带两层 MLP 的 Shifted Window-based MSA 组成,另一个 Swin Transformer Block 由一个带两层 MLP 的 Window-based MSA 组成。在每个 MSA 模块和每个 MLP 之前使用 LayerNorm(LN) 层,并在每个 MSA 和 MLP之后使用残差连接。

可以看到 Swin Transformer Block 和 ViT Block 的区别就在于将 ViT 的多头注意力机制 MSA 替换为了 Shifted Window-based MSA 和 Window-based MSA。

4 Stage 1:Swin Transformer Block:Window-based MSA

标准 ViT 的多头注意力机制 MSA 采用的是全局自注意力机制,即:计算每个 token 和所有其他 token 的 attention map。全局自注意力机制的计算复杂度是 $O(N^2d)$ ,其中, $N$ 是 token的数量, $d$ 是 Embedding dimension。全局自注意力机制的计算复杂度与序列长度 $N$ 成平方关系。当图片分辨率较高或是密集预测任务中计算量会过大。

Window-based MSA 不同于普通的 MSA,它在一个个 window 里面去计算 self-attention。假设每个 window 里面包括 $M\times M$ 个 image patches,则 Window-based MSA 和普通的 MSA 的计算量分别为:

由于 Window 的 patch 数量 $M$ 远小于图片patch数量 $hw$ ,Window-based MSA 的计算量与序列长度 $N=hw$ 成线性关系。

5 Stage 1:Swin Transformer Block:Shifted Window-based MSA

Window-based MSA 虽然大幅节约了计算量,但是牺牲了 windows 之间关系的建模,不重合的 Window 之间缺乏信息交流影响了模型的表征能力。Shifted Window-based MSA 就是为了解决这个问题,如下图3所示。在两个连续的Swin Transformer Block中交替使用W-MSA 和 SW-MSA。以上图为例,将前一层 Swin Transformer Block 的 8x8 尺寸feature map划分成 2x2 个patch,每个 patch 尺寸为 4x4,然后将下一层 Swin Transformer Block的 Window 位置进行移动,得到 3x3 个不重合的 patch。移动 window 的划分方式使上一层相邻的不重合 window 之间引入连接,大大的增加了感受野。

图3中表示连续的2个 Blocks,其中第1个 Block 有4个windows,每个 window 中是 $M\times M=4×4$ 的patch。第2个 Block 也有4个windows,每个 window 中也是 $M\times M=4×4$ 的patch,但是window的位置发生了偏移,偏移的距离是 $\frac{M}{2}=2$ 。

这样一来,在新的 window 里面做 self-attention 操作,就可以包括原有的 windows 的边界,实现 windows 之间关系的建模。

图3:Shifted Window-based MSA

所以 2个连续的 Swin Transformer Block 的表达式为:

$$
\begin{align} &{{\hat{\bf{z}}}^{l}} = \text{W-MSA}\left( {\text{LN}\left( {{{\bf{z}}^{l - 1}}} \right)} \right) + {\bf{z}}^{l - 1},\nonumber\ &{{\bf{z}}^l} = \text{MLP}\left( {\text{LN}\left( {{{\hat{\bf{z}}}^{l}}} \right)} \right) + {{\hat{\bf{z}}}^{l}},\nonumber\ &{{\hat{\bf{z}}}^{l+1}} = \text{SW-MSA}\left( {\text{LN}\left( {{{\bf{z}}^{l}}} \right)} \right) + {\bf{z}}^{l}, \nonumber\ &{{\bf{z}}^{l+1}} = \text{MLP}\left( {\text{LN}\left( {{{\hat{\bf{z}}}^{l+1}}} \right)} \right) + {{\hat{\bf{z}}}^{l+1}}, \label{eq.swin} \end{align} \tag{2}
$$

但是引入 Shifted Window 会带来另一个问题就是会造成 window 数发生改变,而且有的 window 大,有的 window 小,比如下面图4。

图4:Cycle Shift 操作

一种简单的解决办法是把所有 window 都做 padding 操作,使之达到相同的大小。但是这会因为 window 数量的增加 (从 $\lceil \frac{h}{M}\rceil \times \lceil\frac{w}{M}\rceil$ 增加到 $(\lceil \frac{h}{M}\rceil +1) \times (\lceil\frac{w}{M}\rceil + 1)$ ) 而增加计算量。所以作者在这里提出了一种更加高效的 batch computation 计算方法,通过 cycle shift 的方法,合并小的 windows,仔细看图4,将 A,B,C 这3个小的 windows 进行循环移位,使之合并小的 windows。

经过了 cycle shift 的方法,一个 window 可能会包括来自不同 window 的内容。比如图4右下角的 window,来自4个不同的 sub-window。因此,要采用 masked MSA 机制将 self-attention 的计算限制在每个子窗口内。最后通过 reverse cycle shift 的方法将每个 window 的 self-attention 结果返回。

这里进行下简单的图解,下图5代表 cycle shift 的过程,这9个 window 通过移位从左边移动到右侧的位置。

图5: cycle shift 的过程

这样再按照之前的 window 划分,就能够得到 window 5 的attention 的结果了。但是这样操作会使得 window 6 和 4 的 attention 混在一起,window 1,3,7 和 9 的 attention 混在一起。所以需要采用 masked MSA 机制将 self-attention 的计算限制在每个子窗口内。具体怎么做呢?

按照 Swin Transformer 的代码实现 (下面会有讲解),还是做正常的 self-attention (在 window_size 上做),之后要进行一次 mask 操作,把不需要的 attention 值给它置为0。

例1: 比如右上角这个 window,如下图6所示。它由4个 patch 组成,所以应该计算出的 attention map是4×4的。但是6和4是2个不同的 sub-window,我们又不想让它们的 attention 发生交叠。所以我们希望的 attention map 应该是图7这个样子。

图6:右上角这个 window

图7:右上角这个 window 希望的 attention map 的样子

因此我们就需要如下图8所示的 mask。

图8:右上角这个 window 希望的 mask 的样子

例2: 比如右下角这个 window,如下图9所示。它由4个 patch 组成,所以应该计算出的 attention map是4×4的。但是1,3,7和9是4个不同的 sub-window,我们又不想让它们的 attention 发生交叠。所以我们希望的 mask 应该是图10这个样子。

图9:右下角这个 window

图10:右下角这个 window 希望的 mask 的样子

6 Stage 2/3/4

Stage 2 的输入是维度是 $\frac{H}{4}\times \frac{W}{4}\times C$ 的张量。从 Stage 2 到 Stage 4 的每个 stage 的初始阶段都会先做一步 Patch Merging 操作,Patch Merging 操作的目的是为了减少 tokens 的数量,它会把相邻的 2×2 个 tokens 给合并到一起,得到的 token 的维度是 $4C$ 。Patch Merging 操作再通过一次线性变换把维度降为 $2C$ 。至此,维度是 $\frac{H}{4}\times \frac{W}{4}\times C$ 的张量经过Patch Merging 操作变成了维度是 $\frac{H}{8}\times \frac{W}{8}\times 2C$ 的张量。

同理,Stage 3 的Patch Merging 操作会把维度是 $\frac{H}{8}\times \frac{W}{8}\times 2C$ 的张量变成维度是 $\frac{H}{16}\times \frac{W}{16}\times 4C$ 的张量。Stage 4 的Patch Merging 操作会把维度是 $\frac{H}{16}\times \frac{W}{16}\times 4C$ 的张量变成维度是 $\frac{H}{32}\times \frac{W}{32}\times 8C$ 的张量。

每个 Stage 都会改变张量的维度,形成一种层次化的表征。 因此,这种层次化的表征可以方便地替换为各种视觉任务的骨干网络。

Swin Transformer 的结构

Swin Transformer 分为 Swin-T,Swin-S,Swin-B,Swin-L 这四种结构。使用的 window 的大小统一为 $M=7$ ,每个 head 的embedding dimension 都是 32,每个 stage 的层数如下:

  • Swin-T: $C=96$ ,layer number: $\left{ {2,2,6,2} \right}$ 。

  • Swin-S: $C=96$ ,layer number: $\left{ {2,2,18,2} \right}$ 。

  • Swin-B: $C=128$ ,layer number: $\left{ {2,2,18,2} \right}$ 。

  • Swin-L: $C=192$ ,layer number: $\left{ {2,2,18,2} \right}$ 。

Experiments:

1 图像分类:

数据集:ImageNet

(a)表是直接在 ImageNet-1k 上训练,(b)表是先在 ImageNet-22k 上预训练,再在 ImageNet-1k 上微调。

对标 88M 参数的 DeiT-B 模型,它在 ImageNet-1k 上训练的结果是83.1\% Top1 Accuracy,Swin-B 模型的参数是80M,它在 ImageNet-1k 上训练的结果是83.5\% Top1 Accuracy,优于DeiT-B 模型。

图11:图像分类实验结果

图像分类上比 ViT、DeiT等 Transformer 类型的网络效果更好,但是比不过 CNN 类型的EfficientNet,猜测 Swin Transformer 还是更加适用于更加复杂、尺度变化更多的任务。

2 目标检测:

数据集:COCO 2017 (118k Training, 5k validation, 20k test)

(a) 表是在 Cascade Mask R-CNN, ATSS, RepPoints v2, 和 Sparse RCNN 上对比 Swin-T 和 ResNet-50 作为 Backbone 的性能。

(b) 表是使用 Cascade Mask R-CNN 模型的不同 Backbone 的性能对比。

(c) 表是整体的目标检测系统的对比,在 COCO test-dev 上达到了 58.7 box AP 和 51.1 mask AP。

图12:目标检测实验结果

3 语义分割:

数据集:ADE20K (20k Training, 2k validation, 3k test)

下图13列出了不同方法/Backbone的mIoU、模型大小(#param)、FLOPs和FPS。从这些结果可以看出,Swin-S 比具有相似计算成本的 DeiT-S 高出+5.3 mIoU (49.3 vs . 44.0)。也比ResNet-101 高+4.4 mIoU,比 ResNeSt-101 高 +2.4 mIoU。

图13:语义分割实验结果

35.2 Swin Transformer代码解读:

代码来自:

https://github.com/microsoft/Swin-Transformer

把张量 (B, H, W, C) 分成 window (B×H/M×W/M, M, M, C),其中M是 window_size。这一步相当于得到 B×H/M×W/M 个大小为 (M, M, C) 的 window。

def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

把window (B×H/M×W/M, M, M, C) 变回张量 (B, H, W, C)。

def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

W-MSA 模块

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

这里我们着重分析下 WindowAttention 和 Attention 在代码实现上面的不同之处。

attention的实现过程是一致的,只是这里的B_代表 B×H/M×W/M,这里的N代表 window size M。

定义一个相对位置编码表,维度是[(2M-1)*(2M-1), num_heads]。
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
attn = attn + relative_position_bias.unsqueeze(0) 代表给 attention map 添加相对位置编码。

一个 Swin Transformer Block

class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resulotion.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

代码实现对标一个 ViT 的 Transformer Block,由一次 Window Attention 和一个 MLP 组成。最关键的是 attn_mask 的计算。代码如下:

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

slice() 函数返回 slice 对象 (切片)。slice 对象用于指定如何对序列进行裁切。应该 mask 住的地方的 attn_mask 值不为0,所以填入-100;不应该 mask 住的地方的 attn_mask 值为0,所以填入0,之后的softmax操作会把-100值置为极小值。

Patch Merging 操作

class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

Patch Merging 操作把相邻的 2×2 个 tokens 给合并到一起,得到的 token 的维度是 $4C$ 。Patch Merging 操作再通过一次线性变换把维度降为 $2C$ 。

一个基本的 Stage

class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

由 depth 个 SwinTransformerBlock 组成,相邻的2个 SwinTransformerBlock 要进行一次 Shift window 操作。

整体的 Swin Transformer

class SwinTransformer(nn.Module):
    r""" Swin Transformer
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
          https://arxiv.org/pdf/2103.14030

    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x

由4个 Stage 组成,每个 Stage 由 BasicLayer 实现。
传入的 depths 代表每个 Stage 的层数,比如 Swin-T 就是:[2, 2, 6, 2]。

36 SwinIR: 用于图像复原的 Swin Transformer

论文名称:SwinIR: Image Restoration Using Swin Transformer

论文地址:

https://arxiv.org/abs/2108.10257

36.1 SwinIR原理分析:

Swin Transformer 提出了一种针对视觉任务的通用的 Transformer 架构,在分类,检测,分割任务上均达到了最优。SwinIR 继承了 Swin Transformer 的结构,是一个用于图像恢复的强基线模型,在图像超分辨率、去噪等任务上表现SOTA。

图像复原 (Image restoration) 是一个长期存在的 low-level 视觉问题,旨在从低质量退化的图像 (例如,缩小、嘈杂和压缩图像) 中恢复高质量干净的图像。虽然最先进的图像恢复方法基于卷积神经网络,但很少有人尝试使用 Transformer,它们在high-level视觉任务中表现出令人印象深刻的性能。

基于卷积网络的图像复原模型有两个缺点:

  1. 图像和卷积核之间的交互是与内容无关的。使用相同的卷积核来恢复不同的图像区域可能不是最佳选择。
  2. 收到卷积核感受野大小的限制,卷积无法建模长距离的相关性。

针对第二个问题,目前已有一些基于 Transformer 网络的图像复原模型,比较有效的比如:

1) IPT:

https://zhuanlan.zhihu.com/p/342261872

2) U-Transformer:

https://zhuanlan.zhihu.com/p/380391088

3) Video super-resolution transformer:

https://link.zhihu.com/?target=https%3A//arxiv.org/abs/2106.06847

这些模型的做法都是把图片分成了固定大小的 Patch (比如48×48)。但是,基于 Transformer 网络的图像复原模型会产生2个缺点:

  1. 边界像素不能利用 Patch 之外的相邻像素进行图像恢复。
  2. 恢复的图像可能会在每个 Patch 周围引入边界伪影。虽然这个问题可以通过 Patch 的重叠来缓解,但它会带来额外的计算负担。

上一节介绍的 Swin Transformer 在多种视觉任务上展示了巨大的前景,因为它集成了 CNN和 Transformer 的优势。一方面,由于局部注意机制 (local attention mechanism),它具有CNN 处理大尺寸图像的优势。另一方面,它具有Transformer的优势,可以用 shift window方案来建模远程依赖关系。再一方面,这种基于注意力机制的模型图像和卷积核之间的交互是与内容有关的,可以理解成一种 spatially varying 的卷积操作。所以在一定程度上解决了 CNN 和 Transformer 模型的缺点。

如下图14所示,作者提出了一个基于 Swin Transformer 的用于图像恢复的强基线模型 SwinIR。 SwinIR 由三部分组成:浅层特征提取 (shallow feature extraction)、深层特征提取 (deep feature extraction) 和高质量图像重建 (high-quality image reconstruction)。无论是经典图像超分(即退化方式为bicubic),还是真实场景图像超分,亦或图像降噪与JPEG压缩伪影移除,所提SwinIR均取得了显著优于已有方案的性能

图14:SwinIR 模型结构

接下来我们和上一节一样直观地梳理下 SwinIR 的前向传播过程,从一张输入图片开始经过以下3步:

1 浅层特征提取模块 (shallow feature extraction)

给定输入的低质量图片 $I{\textit{LQ}}\in\mathbb{R}^{H\times W\times C{in}}$ ,使用一个3×3卷积 $H_{\textit{SF}}(\cdot)$ 来提取它的浅层特征 $F_0\in\mathbb{R}^{H\times W\times C}$ :

2 深层特征提取模块 (deep feature extraction)

给定上一阶段的输出特征 $F0$,使用深层特征提取模块 $H{\textit{DF}}$ 进行深层特征提取:

深层特征提取模块由 $K$ 个 Residual Swin Transformer Blocks (RSTB) 和一个卷积操作组成。每个 RSTB 模块的输出 $F_{1},F_2, \ldots, F_K$ 和最终的输出特征是:

式中, $H{\textit{RSTB}{i}}(\cdot)$ 表示第 $i$ 个 RSTB 模块, $H_{\textit{CONV}}$ 表示最终的卷积层,使用它的目的是 将卷积网络的归纳偏差 (inductive bias) 融入基于 Transformer 的网络,并为后来的浅层和深度特征奠定了更好的基础。

每个 Residual Swin Transformer Block 的内部结构如上图14所示。它由一堆 STL (Swin Transformer Layer) 和一个卷积操作,外加残差链接构成。写成表达式就是:

式中, $H{\textit{STL}{i,j}}(\cdot)$ 代表第 $i$ 个 RSTB 的第 $j$ 个 STL (Swin Transformer Layer), $H_{\textit{CONV}i}(\cdot)$ 代表第 $i$ 个 RSTB 的卷积操作, $F{i,0}$ 代表残差连接。

每个 RSTB 的残差链接使得模型便于融合不同级别的特征,卷积操作有利于增强平移不变性。

Swin Transformer Layer 已经在上一节介绍过,2个连续的 Swin Transformer Layer 就如上面2式所示:

代码实现就是上一节的 class SwinTransformerBlock(nn.Module),这里不再赘述。

3 高质量图像重建模块 (high-quality image reconstruction)

对于图像超分任务,通过浅层特征 $F0$ 和深层特征 $F{\textit{DF}}$ 重建高质量图像 $I_{\textit{RHQ}}$ :

式中, $H_{\textit{REC}}(\cdot)$ 代表高质量图像重建模块。

浅层特征 $F0$ 主要含有低频信息,而深层特征 $F{\textit{DF}}$ 专注于恢复丢失的高频信息。通过长距离的跳变连接,SwinIR 可以直接将低频道信息直接传输到重建模块,这可以帮助深层特征 $F_{\textit{DF}}$ 专注于提取高频信息并稳定训练。作者使用 sub-pixel 的卷积层实现高质量图像重建模块。

对于图像去噪任务和压缩图像任务,仅仅使用一个带有残差的卷积操作作为高质量图像重建模块:

损失函数

对于超分任务,直接去优化生成的高质量图片和GT的 $L_1$ 距离:

对于图像去噪任务和压缩图像任务,使用 Charbonnier Loss:

0

式中, $\epsilon$ 通常取 $10^{-3}$ 。

Experiments:

模型参数:

  • RSTB 模块数: $K=6$ (轻量级超分模型 $K=4$ )。
  • 每个 RSTB 模块的 STL 层数: $L=6$ 。
  • Window size: $M=8$ (在图片压缩任务中 $M=7$ )。
  • Attention 的 head 数:6。
  • channel 数:180 (轻量级超分模型为60)。

1 经典图像超分

图15:SwinIR 经典图像超分结果

图16:SwinIR 经典图像超分可视化结果

仅仅使用 DIV2K 数据集训练时,仅具有 11.8M 参数的 SwinIR 超越了诸多 CNN 模型,当使用的数据集再加上 Flickr2K 时,性能得到再次提升,并超越了使用 ImageNet 训练的,具有 115.5M 参数的 IPT 模型。

2 轻量级图像超分

图17:SwinIR 轻量级图像超分结果

作者还对比了几个轻量级的超分网络模型,如图17所示。在相似的计算量和参数量的前提下,SwinIR 超越了诸多轻量级的 CNN 模型。

3 真实世界图像超分

图18:SwinIR 真实世界图像超分可视化结果

真实世界图像超分任务没有 GT 图像,作者对比了几种真实世界图像超分模型和 SwinIR 的可视化结果。SwinIR 的可视化结果令人满意,并且能够产生锐度高的清晰图像。

4 图像去噪

图19:SwinIR 图片去噪结果

图20:SwinIR 图片去噪可视化结果

噪声等级包括15,25和50,作者对比了多个CNN去噪模型,SwinIR 都超过了它们的性能。值得注意的是,SwinIR 的参数量只有12M,但是 DRUNet 的参数量达到了32.7M。 这表明SwinIR 架构在学习特征表示中具有高效的恢复能力。

5 图片压缩

图21:SwinIR 图片压缩结果

JPEG quality factors 取10,20,30,40,从结果发现,SwinIR 至少有0.11dB 和0.11dB的提升。SwinIR 的参数量只有12M,但是 DRUNet 的参数量达到了32.7M。

36.2 SwinIR代码解读:

代码来自:

https://github.com/JingyunLiang/SwinIR

代码中的 window_partition,window_reverse,WindowAttention,SwinTransformerBlock,PatchMerging,BasicLayer 都继承自 Swin Transformer 的源代码,不再赘述。

BasicLayer

继承自 Swin Transformer 的源代码, Swin Transformer 中代表一个 Stage,SwinIR里面指的是一个 RSTB 模块的内容。

class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, x_size):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x, x_size)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

Residual Swin Transformer Blocks (RSTB模块)

每个 RSTB 模块有 $L=6$ 个 STL 层。

RSTB 模块包含1个 BasicLayer 层,一个 conv 层和残差链接。

class RSTB(nn.Module):
    """Residual Swin Transformer Block (RSTB).

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        img_size: Input image size.
        patch_size: Patch size.
        resi_connection: The convolutional block before residual connection.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
                 img_size=224, patch_size=4, resi_connection='1conv'):
        super(RSTB, self).__init__()

        self.dim = dim
        self.input_resolution = input_resolution

        self.residual_group = BasicLayer(dim=dim,
                                         input_resolution=input_resolution,
                                         depth=depth,
                                         num_heads=num_heads,
                                         window_size=window_size,
                                         mlp_ratio=mlp_ratio,
                                         qkv_bias=qkv_bias, qk_scale=qk_scale,
                                         drop=drop, attn_drop=attn_drop,
                                         drop_path=drop_path,
                                         norm_layer=norm_layer,
                                         downsample=downsample,
                                         use_checkpoint=use_checkpoint)

        if resi_connection == '1conv':
            self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
        elif resi_connection == '3conv':
            # to save parameters and memory
            self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                      nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
                                      nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                      nn.Conv2d(dim // 4, dim, 3, 1, 1))

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
            norm_layer=None)

        self.patch_unembed = PatchUnEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
            norm_layer=None)

    def forward(self, x, x_size):
        return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x

PatchEmbed 操作就是把维度为 (B, C, H, W)三维图片格式图像转化成二维张量格式 (B, H×W, C)
PatchUnEmbed 操作就是把二维张量格式 (B, H×W, C) 转化成维度为 (B, C, H, W)三维图片格式

SwinIR 整体模型

调用:

if __name__ == '__main__':
    upscale = 4
    window_size = 8
    height = (1024 // upscale // window_size + 1) * window_size
    width = (720 // upscale // window_size + 1) * window_size
    model = SwinIR(upscale=2, img_size=(height, width),
                   window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
                   embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
    print(model)
    print(height, width, model.flops() / 1e9)

    x = torch.randn((1, 3, height, width))
    x = model(x)
    print(x.shape)

第1阶段:浅层特征提取模块 (shallow feature extraction)

        #####################################################################################################
        ################################### 1, shallow feature extraction ###################################
        self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)

第2阶段:深层特征提取模块 (deep feature extraction)
包括4个 RSTB,每个 RSTB 中包括6个 Swin Transformer Layer。

        #####################################################################################################
        ################################### 2, deep feature extraction ######################################
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = embed_dim
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # merge non-overlapping patches into image
        self.patch_unembed = PatchUnEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build Residual Swin Transformer blocks (RSTB)
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = RSTB(dim=embed_dim,
                         input_resolution=(patches_resolution[0],
                                           patches_resolution[1]),
                         depth=depths[i_layer],
                         num_heads=num_heads[i_layer],
                         window_size=window_size,
                         mlp_ratio=self.mlp_ratio,
                         qkv_bias=qkv_bias, qk_scale=qk_scale,
                         drop=drop_rate, attn_drop=attn_drop_rate,
                         drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # no impact on SR results
                         norm_layer=norm_layer,
                         downsample=None,
                         use_checkpoint=use_checkpoint,
                         img_size=img_size,
                         patch_size=patch_size,
                         resi_connection=resi_connection

                         )
            self.layers.append(layer)
        self.norm = norm_layer(self.num_features)

        # build the last conv layer in deep feature extraction
        if resi_connection == '1conv':
            self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
        elif resi_connection == '3conv':
            # to save parameters and memory
            self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
                                                 nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                                 nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
                                                 nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                                 nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))

第3阶段:高质量图像重建模块 (high-quality image reconstruction)

        #####################################################################################################        ################################ 3, high quality image reconstruction ################################        if self.upsampler == 'pixelshuffle':            # for classical SR            self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),                                                      nn.LeakyReLU(inplace=True))            self.upsample = Upsample(upscale, num_feat)            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)        elif self.upsampler == 'pixelshuffledirect':            # for lightweight SR (to save parameters)            self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,                                            (patches_resolution[0], patches_resolution[1]))        elif self.upsampler == 'nearest+conv':            # for real-world SR (less artifacts)            assert self.upscale == 4, 'only support x4 now.'            self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),                                                      nn.LeakyReLU(inplace=True))            self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)            self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)            self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)            self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)        else:            # for image denoising and JPEG compression artifact reduction            self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)

把它们连接在一起的前向传播过程

    def forward_features(self, x):
        x_size = (x.shape[2], x.shape[3])
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x, x_size)

        x = self.norm(x)  # B L C
        x = self.patch_unembed(x, x_size)

        return x

    def forward(self, x):
        self.mean = self.mean.type_as(x)
        x = (x - self.mean) * self.img_range

        if self.upsampler == 'pixelshuffle':
            # for classical SR
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.conv_before_upsample(x)
            x = self.conv_last(self.upsample(x))
        elif self.upsampler == 'pixelshuffledirect':
            # for lightweight SR
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.upsample(x)
        elif self.upsampler == 'nearest+conv':
            # for real-world SR
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.conv_before_upsample(x)
            x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
            x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
            x = self.conv_last(self.lrelu(self.conv_hr(x)))
        else:
            # for image denoising and JPEG compression artifact reduction
            x_first = self.conv_first(x)
            res = self.conv_after_body(self.forward_features(x_first)) + x_first
            x = x + self.conv_last(res)

        x = x / self.img_range + self.mean

        return x

forward_features 函数代表第2阶段深层特征提取模块:先通过 x = self.patch_embed(x) 把图片变成二维张量格式(B, H×W, C),再通过一系列 RSTB 模块。
forward 函数先通过第1阶段浅层特征提取模块 conv_first,再通过 x = self.conv_after_body(self.forward_features(x)) + x 完成深层特征提取模块和结尾卷积+残差链接。最后通过第3阶段高质量图像重建模块。

总结

本文介绍2篇文章是 Swin Transformer 系列及其变体。Swin Transformer 是屠榜各大CV任务的通用视觉Transformer模型,它在图像分类、目标检测、分割上全面超越 SOTA,在语义分割任务中在 ADE20K 上刷到 53.5 mIoU,超过之前 SOTA 大概 4.5 mIoU!可能是CNN的完美替代方案。SwinIR 是 Swin Transformer 在底层视觉任务的尝试,是一个基于Swin Transformer的用于图像恢复的强基线模型,在图像超分辨率、去噪等任务上表现SOTA!性能优于IPT、DRUNet等网络,无论是经典图像超分(即退化方式为bicubic),还是真实场景图像超分,亦或图像降噪与JPEG压缩伪影移除,所提SwinIR均取得了显著优于已有方案的性能

参考:

https://zhuanlan.zhihu.com/p/360513527

  • 0
  • 0
  • 120
收藏
暂无评论