• 问答
  • 技术
  • 实践
  • 资源
搞懂 Vision Transformer 原理和代码,看这篇技术综述就够了(十七)
技术讨论

​作者丨科技猛兽
编辑丨极市平台

本文目录

37 只使用纯粹的注意力机制就够了吗 (ICML 2021)
(来自谷歌,EPFL)
37.1 Attention is not all you need 原理分析
37.2 Attention is not all you need 代码解读

Transformer 是 Google 的团队在 2017 年提出的一种 NLP 经典模型,现在比较火热的 Bert 也是基于 Transformer。Transformer 模型使用了 Self-Attention 机制,不采用 RNN 的顺序结构,使得模型可以并行化训练,而且能够拥有全局信息。

本文介绍这个工作来自谷歌,Attention is not all you need 这篇文章为网络架构中含有 attention 结构的模型提供了一个新的看法,这篇工作认为:Transformer 这种结构在 CV 和 NLP 任务上表现良好并不代表仅仅由 Self-attention 机制构成的网络 (即去掉 MLP 层,残差连接,Layer Normalziation 等等) 也能够表现良好。本文证明了一件事:

随着输入在网络中向前传播,深度不断加深,仅仅由 Self-attention 机制构成的网络 (即去掉 MLP 层,残差连接,Layer Normalziation 等等) 的表达能力会逐渐降低。最终,输出会退化成一个秩为1的矩阵,每一排的值变得一致。这个问题,本文把它称为 Rank Collapse。而Transformer中的其他构件 (即 MLP 层,残差连接,Layer Normalziation 等等) 可以缓解这个问题。比如,Shortcut 操作和 MLP 在缓解 Rank Collapse 问题上起了关键作用。

1 新的路径分解方法来研究 Self-attention 网络

作者提出了一种新的路径分解方法来研究 Self-attention 网络,如下图1所示。路径分解方法将一个 Self-attention 网络视为不同的通路组成。在每1层,一条路径可以通过1个head或者跳过1个layer (因为有 Shortcut 连接)。在每个 attention layer 之后,追加 MLP 块,就构成了 Transformer的架构。把一个 Self-attention,分解为一个若干 "弱相互依赖" 的路径的线性组合,其中每个路径对应一个单 head 的 Self-attention。

或者说,增加了一个 dummy head,一个路径,会选择已有的 heads + dummy head中的任何一个 head 去通过:如果选择了已有的 heads (例如 $H$ 个 head 中的某一个) ,那就是通过网关,如果选择了 "dummy head",那就是越过了当前 layer。

图1:具有H个heads和L层的Self-attention网络的随机2条路径

注释: 本文使用了一种特殊的范数,叫做 $\ell1, \ell\infty$ 混合范数。比如矩阵 $ {X}$ 满足 $| {X}|_{1,\infty} = \sqrt{| {X} |1 | {X} |\infty}$ ,注意到 $\ell_{1,\infty}$ 其实并不是一种真正的范数,因为范数的三大条件是:

  • 正定性,即 $| {X} |\geq 0$ 且当 $| {X} |=0$ 时必有 $ {X}=0$ 。
  • 齐次性,即 $| c {X} |=|c|| {X} |$ 。
  • 三角不等式,即 $| {X+Y} |\leq | {X} |+| {Y} |$ 。

所以严格来讲 $\ell{1,\infty}$ 其实并不属于一种范数,因为 $\ell{1,\infty}$ 不满足三角不等式。

其他简写符号: $[H] = (1,\cdots,H)$ 。

第1步是证明:一个只由 Self-attention 模块构成的网络的输出 (矩阵) 随深度的增加,其 Rank 呈指数衰减,最终会趋向于一个1秩矩阵,即所有 token 都一样。 当有100层,1000层,10000层。。。SANs 的时候,最终的每个词的向量表示都一样了?

假设 $\textbf{X}\in\mathbb{R}^{n\times d{in}}$ 是输入矩阵,它包含有 $n$ 个tokens,每个 token (单词) 用 $d{in}$ 维度表示。模型包含 $L$ 层 MHSA , 即多头自注意力层。每个 MHSA 包含 $H$ 个 head。则第 $h$ 个 head 的 MHSA 的输出可以写成:

$$
\begin{align} SA_h( {X}) ={P}h {X} {W}{V,h} + {1} {b}_{V,h}^T \end{align} \tag{1}
$$

式中,$\boldsymbol{W}{V, h} \in d{i n} \times d_{v}$是 Value matrix。 ${P}_h$ 是 attention map,按照下式来计算:

$$
\begin{align} {P}h &= \text{softmax}\big(d{qk}^{-\frac{1}{2}}( {X} {W}{Q,h} + {1} {b}{Q,h}^T) ( {X} {W}{K,h} + {1} {b}{K,h}^T )^T\big) \ &= \text{softmax} (d{qk}^{-\frac{1}{2}}( {X} {W}{QK,h} {X}^T + {1} {b}{Q,h}^T {W}{K,h}^T {X}^T)), \end{align} \tag{2}
$$

上式就是标准的 Self-attention 的实现方式。其中,$\boldsymbol{W}{Q K, h}=\boldsymbol{W}{Q, h}^{T} \boldsymbol{W}_{K, h}$。

每一层的输出是将所有注意力头的输出 (沿着最后一个维度) concat 起来,并线性投影到一个适当大小的子空间上:

其中, $\boldsymbol{W}{Q K, h}=\boldsymbol{W}{Q, h}^{T} \boldsymbol{W}{K, h}, \boldsymbol{b}{O}=\sum{h} \boldsymbol{b}{O, h}$ 。每一层都有相同数量的 head 。 $\boldsymbol{X}^{l}$ 是第 $l$ 层 的输出。 $\quad \boldsymbol{W}{K, h} \in \mathbb{R}^{d{i n} \times d{q k}}, \quad \boldsymbol{W}{Q, h} \in \mathbb{R}^{d{i n} \times d{q k}}$ 。

SAN 的输出可以根据3式表示为:

$$
\begin{align} {X}^L &= \sum_{h \in [H]} {P}_h^{L} {X}^{L-1} {Wh}^{L} \nonumber \ &= \sum{h \in [H]} {P}h^{L} \bigg( \sum{h' \in [H]} {P{h'}^{L-1}} {X}^{L-2} {W}{h'}^{L-1} \bigg) \, {W}h^{L} = \hspace{-3mm} \sum{hL,h{L-1} \in [H]^2} \hspace{-3mm} {P}_{hL}^{L} {P}{h{L-1}}^{L-1} {X}^{L-2} {W}{h{L-1}}^{L-1} {W}{h_L}^{L}, \nonumber \end{align} \tag{4}
$$

展开递归,得到:

$$
\begin{align} {X}^L &= \hspace{-3mm} \sum_{h_1, \ldots, hL \in [H]^L} \hspace{-3mm} ( {P}{hL}^{L} \cdots {P}{h1}^{1}) \, {X} \, ( {W}{h1}^{1} \cdots {W}{h_L}^{L}). \end{align} \tag{5}
$$

2 SAN 的路径分解

深度为 $L$ ,且有 $H$ 个 head 的 self-attention 网络每一层的输出为:

$$
\begin{align} SAN( {X}) = \sum{\textit{path} \in [H]^L} {P}\textit{path} \, {X} \, {W}_{\textit{path}} + {1} {b}^T, \end{align} \tag{T1}
$$

式中, $\quad \boldsymbol{P}{\text {path }}=\boldsymbol{P}{h{L}}^{L} \cdots \boldsymbol{P}{h{1}}^{1}$ 是个依赖输入的转移矩阵。而 $\boldsymbol{W}{\text {path }}=\boldsymbol{W}{h{1}}^{1} \cdots \boldsymbol{W}{h{L}}^{L}$ 和 $\boldsymbol{b}$ 是不依赖于输入的变量。

对于任意的转移矩阵 $ {P}$ ,有: $ {P} {1}= {1}$ 。

6式可以看成是 $H^L$ 个长度为 $L$ 的路径:

$$
\textit{path} = (h_1, \ldots, h_L), \ \text{ where } \ h_l\in (0,1,\ldots, H)\
$$

转移矩阵 $ {P}$ 可以完成 tokens 之间的混合。

下面作者证明每一条路径都会逐渐收敛到一个 1秩矩阵 (rank-1 matrix)。有趣的是,这种聚合作用在增加更多的层时是没有帮助的。尽管路径的数量是指数增加的,但每条路径都是双指数衰减的,并且导致了一个 rank-1 的输出。

3 SAN 的收敛性

接下来会介绍几个 Lemma,分别证明在:

  1. single-head, single-layer
  2. multiple-heads, single-layer
  3. single-head, multiple-layers
  4. multiple-heads, multiple-layers

这四种情况下,下面的这个残差是如何变化的。

$$
\text{res}( {X}) = {X} - {1} {x}^T, \ \text{ where } \ {x} = \text{argmin}_{ {x}} | {X} - {1} {x}^T|\
$$

Lemma A.1 (single-head, single-layer):

即,残差遵循:

$$
| \text{res}(\text{SA}( {X})) |{1,\infty} \leq \frac{4 \, | {W}{QK}|1 \, | {W}{V}|{1,\infty}}{\sqrt{d{qk}}} \, |\text{res}( {X})|_{1,\infty}^3\
$$

Lemma A.1 证明:

unscaled attention 的权重值可以这样计算:

$$
\boldsymbol{A}=\left(\boldsymbol{X} \boldsymbol{W}{Q}+\mathbf{1} \boldsymbol{b}{Q}^{T}\right)\left(\boldsymbol{X} \boldsymbol{W}{K}+\mathbf{1} \boldsymbol{b}{K}^{T}\right)^{T} \tag{P.1}
$$

根据 softmax 的平移不变性,有:

$\boldsymbol{A}=\boldsymbol{X} \boldsymbol{W}{Q K} \boldsymbol{X}^{T}+\mathbf{1} \boldsymbol{b}{Q K}^{T} \boldsymbol{X}^{T} \tag{P.2}$

式中,$\boldsymbol{W}{Q K}=\boldsymbol{W}{Q} \boldsymbol{W}{K}^{T} \in \mathbb{R}^{d{i n} \times d{i n}}, \boldsymbol{b}{Q K}=\boldsymbol{W}{K} \boldsymbol{b}{Q}$。

使用速记表示: $ {R} := \text{res}( {X}), {R'} := \text{res}( {X'})$ 。

scaled attention 的权重值就可以写成:

$$
\begin{align} A &= ( {1} {x}^T + {R}) \frac{ {W}{QK}}{\sqrt{d{qk}}} ( {1} {x}^T + {R})^T + {1} \frac{ {b}{QK}^T}{\sqrt{d{qk}}}( {1} {x}^T + {R})^T \ &= \left(\frac{ {x}^T {W}{QK} {x}}{\sqrt{d{qk}}} {1} + {R} \frac{ {W}{QK}}{\sqrt{d{qk}}} {x} + {1} \frac{ {b}{QK}^T}{\sqrt{d{qk}}} {x} \right) {1}^T + {1} {x}^T \frac{ {W}{QK}}{\sqrt{d{qk}}} {R}^T + {R} \frac{ {W}{QK}}{\sqrt{d{qk}}} {R}^T + {1} \frac{ {b}{QK}^T}{\sqrt{d{qk}}} {R}^T \end{align} \tag{P.3}
$$

因为上式的第1项的任意一行里面的每一列的值是相同的,所以满足 softmax 的平移不变性[1]。所以第1项可以安全地去掉。所以我们有:

$$
\begin{align} {P} &= \text{softmax} \left( {R} \frac{ {W}{QK}}{\sqrt{d{qk}}} {R}^T + {1} {r}^T \right), \end{align} \tag{P.4}
$$

式中, $\quad \boldsymbol{r}:=\boldsymbol{R} \frac{\boldsymbol{W}{Q K}^{T}}{\sqrt{d{q k}}} \boldsymbol{x}+\boldsymbol{R} \frac{\boldsymbol{b}{Q K}}{\sqrt{d{q k}}}$ 。

令 $\boldsymbol{E}=\boldsymbol{R} \frac{\boldsymbol{W}{Q K}}{\sqrt{d{q k}}} \boldsymbol{R}^{T}, \tilde{\boldsymbol{A}}=\mathbf{1} \boldsymbol{r}^{T}$,则注意力权重重新加权的输入为:

$$
\begin{aligned}
\boldsymbol{P} \boldsymbol{X} &=\boldsymbol{P}\left(\mathbf{1} \boldsymbol{x}^{T}+\boldsymbol{R}\right) \
&=\mathbf{1} \boldsymbol{x}^{T}+\boldsymbol{P} \boldsymbol{R} \
&=\mathbf{1} \boldsymbol{x}^{T}+\operatorname{softmax}\left(\mathbf{1} \boldsymbol{r}^{T}+\boldsymbol{E}\right) \boldsymbol{R} \
& \leq \mathbf{1} \boldsymbol{x}^{T}+(\boldsymbol{I}+2 \boldsymbol{D}) \mathbf{1} \operatorname{softmax}(\boldsymbol{r})^{T} \boldsymbol{R} \
&=\mathbf{1}\left(\boldsymbol{x}^{T}+\operatorname{softmax}(\boldsymbol{r})^{T} \boldsymbol{R}\right)+2 \boldsymbol{D} \mathbf{1} \operatorname{softmax}(\boldsymbol{r})^{T} \boldsymbol{R}
\end{aligned} \tag{P.5}
$$

其中第4行的不等式来自下面的 Lamma A.2,且根据 Lamma A.2, $ {D}$ 满足条件:

$$
\begin{align} | {D} {1}|\infty = \max{i,j,j'} | {\delta}_i^T {E} ( {\delta}j - {\delta}{j'}) | \leq 2 \max{ij} |E{ij}| &\leq 2 \, | {E}|{1} \ &= 2 \, | {R} \frac{ {W}{QK}}{\sqrt{d_{qk}}} {R}^T|1 \ &\leq \frac{2}{\sqrt{d{qk}}} \, | {R}|1 | {W}{QK}|_1 | {R}^T|1 \ &= \frac{2}{\sqrt{d{qk}}} \, | {R}|1 | {W}{QK}|1 | {R}|\infty, \end{align} \tag{P.6}
$$

同样有: $ {P} {X} \geq {1} ( {x}^T + \text{softmax}( {r})^T {R}) - {D} \, {1} \, \text{softmax}( {r})^T {R}$ 。

因为 self-attention layer 的输出是: $SA( {X}) = {P} {X} {W}_V$ ,所以有:

$$
\begin{align} | [SA( {X}) - {1} ( {r'})^T]_{ij} | &\leq 2 \, |[ {D} \, {1}\, \text{softmax}( {r})^T {R} {W}V]{ij}|, \end{align} \tag{P.7}
$$

式中, $ {r'} = ( {x} + {R}^T \text{softmax}( {r})) {W}_V$ 。

现在我们对不等式的右侧做约束,有:

$$
\begin{align} | {D} \, {1} \, \text{softmax}( {r})^T {R} {W}_V|1 &\leq | {D} {1}|\infty \, | {R}|_1 | {W}_V|_1 , \end{align} \tag{P.7}
$$

这一项的成立来自赫尔德不等式 (Holder's inequality) 以及 $| {A} {B}|_1\le | {A}|_1 | {B}|_1$ 。

把式 P.6 代入式 P.7 得:

$$
\begin{align} |SA( {X}) - {1} ( {r'})^T|1 &\leq \frac{4}{\sqrt{d{qk}}} \, | {R}|1^2 | {R}|\infty \, | {W}_{QK}|1 \, | {W}{V}|_1 . \end{align} \tag{P.8}
$$

根据 P.7 式还可以得到:

$$
\begin{align} |SA( {X}) - {1} ( {r'})^T|\infty &\leq 2 | {D} \, {1}|\infty |\text{softmax}( {r})^T {R} {W}V |\infty \ &\leq 2 | {D} \, {1}|\infty | {R}|\infty | {W}V |\infty \ &\leq \frac{4}{\sqrt{d_{qk}}} \, | {R}|1 | {R}|\infty^2 | {W}_{QK}|_1 \, | {W}V |\infty. \end{align} \tag{P.9}
$$

根据 P.8,P.9 式的这2个范数有,两式相乘后开根号有:

$$
\begin{align} | {R}'|_{1,\infty} &= \sqrt{| {R}'|1 | {R}'|\infty} \leq \frac{4\, | {W}_{QK}|1 | {W}{V}|{1, \infty}}{\sqrt{d{qk}} } \, (\sqrt{| {R}|1 | {R}|\infty})^3 \&= \frac{4\, | {W}_{QK}|1 | {W}{V}|{1, \infty}}{\sqrt{d{qk}}} \, | {R}|_{1,\infty}^3 \end{align} \tag{P.10}
$$

Lemma A.1 得证。

Lamma A.2 (A technical lemma):

假设 $ {P}$ 是与 $ {A}$ 相关联的行随机矩阵, $\tilde{ {P}}$ 是与 $\tilde{ {A}}= {A}- {E}$ 相关联的行随机矩阵,则:

$$
( {I} - {D}) \, \tilde{ {P}} \leq {P} \leq ( {I} + 2 {D}) \, \tilde{ {P}}\
$$

式中, $\quad \boldsymbol{D}$ 是对角矩阵, 且有: $D{i i}=\max {j}\left|\boldsymbol{\delta}{i}^{T} \boldsymbol{E}\left(\boldsymbol{\delta}{j}-\boldsymbol{\delta}_{j^{\prime}}\right)\right|$ 。

Lemma A.2 证明:

首先从行随机矩阵的定义开始:

$$
\begin{align} P{ij} = [\text{softmax}( {A})]{ij} = [\text{softmax}(\tilde{ {A}} + {E})]{ij} = \frac{\exp{(\tilde{A}{ij} + E{ij})}}{ \sum{t = 1}^n \exp{(\tilde{A}{it} + E{it})} } = \frac{\exp{(\tilde{A}{ij})} \, \exp{(E{ij})}}{ \sum{t = 1}^n \exp{(\tilde{A}{it})} \, \exp{(E_{it})} } \end{align} \
$$

由于 $\tilde{P{ij}}=\frac{\exp{(\tilde{A}{ij})} \, }{ \sum{t = 1}^n \exp{(\tilde{A}{it})} \, }$ ,所以 $P{ij}$ 和 $\tilde{P{ij}}$ 之间的关系为:

$$
\min{j'} \exp{(E{ij} - E{ij'})} \, \tilde{P}{ij} \leq P{ij} \leq \tilde{P}{ij} \, \max{j'} \exp{(E{ij} - E_{ij'})}\
$$

通过泰勒展开,可以进一步放缩到:

$$
(1 - \min{j'} (E{ij} - E{ij'})) \, \tilde{P}{ij} \leq P{ij} \leq \tilde{P}{ij} \, ( 1 + 2 \max{j'} (E{ij} - E_{ij'}))\
$$

Lemma A.2 得证。

Lemma A.3 (multiple-heads, single-layer):

对于任何包含 $H$ 个 heads 的 SAN 来讲有:

$$
\begin{align} | \text{res}(SAN( {X})) |{1,\infty} \leq \frac{4 H \beta}{\sqrt{d{qk}}} \, |\text{res}( {X})|_{1,\infty}^3\,, \end{align} \tag{P.11}
$$

式中, $| {W}_{QK}^l|1 | {W}{V}^{l}|_{1,\infty} \leq \beta$ 。

Lemma A.3 证明:

multi-head attention layer 的输出可以表示为:

$$
\begin{align} SAN( {X}) = \sum_{h\in [H]} {P}h {X} {W}{h}\, = \sum_{h\in [H]} SA_h( {X}), \end{align} \tag{P.12}
$$

式中,$\boldsymbol{W}{h}:=\boldsymbol{W}{V, h} \boldsymbol{W}_{O, h}$。

$$
\begin{aligned}
\sum{h \in[H]} \boldsymbol{P}{h} \boldsymbol{X} \boldsymbol{W}{h} &=\sum{h \in[H]} \boldsymbol{P}{h}\left(\mathbf{1} \boldsymbol{x}^{T}+\boldsymbol{R}\right) \boldsymbol{W}{h} \
&=\sum{h \in[H]} \boldsymbol{P}{h}\left[\mathbf{1} \boldsymbol{x}^{T}+\boldsymbol{P} \boldsymbol{R}\right] \boldsymbol{W}{h} \
&=\sum
{h \in[H]} \boldsymbol{P}{h}\left[\mathbf{1} \boldsymbol{x}^{T}+\operatorname{softmax}\left(\mathbf{1} \boldsymbol{r}{h}^{T}+\boldsymbol{E}\right) \boldsymbol{R}\right] \boldsymbol{W}{h} \
& \leq \sum
{h \in[H]} \boldsymbol{P}{h}\left[\mathbf{1} \boldsymbol{x}^{T}+(\boldsymbol{I}+2 \boldsymbol{D}) \mathbf{1} \operatorname{softmax}\left(\boldsymbol{r}{h}\right)^{T} \boldsymbol{R}\right] \boldsymbol{W}{h} \
&=\sum
{h \in[H]} \boldsymbol{P}{h}\left[\mathbf{1}\left(\boldsymbol{x}^{T}+\operatorname{softmax}\left(\boldsymbol{r}{h}\right)^{T} \boldsymbol{R}\right)+2 \boldsymbol{D} \mathbf{1} \operatorname{softmax}\left(\boldsymbol{r}{h}\right)^{T} \boldsymbol{R}\right] \boldsymbol{W}{h}
\end{aligned} \tag{P.13}
$$

因为 multi-head attention layer 的输出是: $SAN( {X}) = \sum_{h\in [H]} {P}h {X} {W}{h}$ ,所以有:

$$
\begin{align} | [SA( {X}) - {1} ( {r''})^T]_{ij} | &\leq 2 \, \left|\left[ \sum_h {D}_h \, {X} \, \text{softmax}( {r}h)^T {R} {W}{h}\right]_{ij} \right|, \end{align} \tag{P.14}
$$

式中,$\boldsymbol{r}^{\prime \prime}=\sum{h}\left(\boldsymbol{x}+\boldsymbol{R}^{T} \operatorname{softmax}\left(\boldsymbol{r}{h}\right)\right) \boldsymbol{W}_{h}$

对式 P.14 应用三角不等式,得到 $\ell1, \ell\infty$ 混合范数:

$$
\begin{align} | SA^H( {X}) - {1} ( {r''})^T |1 & \leq 2 \sum{h\in[H]} | {D}_h \, {1} \, \text{softmax}( {r}h)^T {R} {W}{h} |1 \& \leq 2H \max{h\in[H]} | {D}_h \, {1} \, \text{softmax}( {r}h)^T {R} {W}{h} |_1 \end{align} \tag{P.15}
$$

$$
\begin{align} | SA^H( {X}) - {1} ( {r''})^T |\infty & \leq 2 \sum{h\in[H]} | {D}_h \, {1} \, \text{softmax}( {r}h)^T {R} {W}{h} |\infty \& \leq 2H \max{h\in[H]} | {D}_h \, {1} \, \text{softmax}( {r}h)^T {R} {W}{h} |_\infty \end{align} \tag{P.16}
$$

根据 P.15,P.16 式的这2个范数,两式相乘后开根号,按照 Lemma A.1 的后续步骤即得到 Lemma A.3。

Lemma A.3 得证。

Lemma A.4 (single-head, multiple-layers):

对于任何包含 $L$ 层的 single-head SAN 来讲,有:

$|\operatorname{res}(S A N(\boldsymbol{X}))|{1, \infty} \leq\left(\frac{4 \beta}{\sqrt{d{q k}}}\right)^{\frac{3^{L}-1}{2}}|\operatorname{res}(\boldsymbol{X})|_{1, \infty}^{3^{L}},\tag{P.17}$

Lemma A.4 证明:

将递归从最后一层向后展开到第一层,并应用引理 Lemma A.1,得到:

$$
\begin{aligned}
\left|\operatorname{res}\left(\boldsymbol{X}^{L}\right)\right|{1, \infty} & \leq \frac{4 \beta}{\sqrt{d{q k}}}\left|\operatorname{res}\left(\boldsymbol{X}^{L-1}\right)\right|{1, \infty}^{3} \
& \leq \frac{4 \beta}{\sqrt{d
{q k}}}\left(\frac{4 \beta}{\sqrt{d{q k}}}\left|\operatorname{res}\left(\boldsymbol{X}^{L-2}\right)\right|{1, \infty}^{3}\right)^{3} \
&=\frac{4 \beta}{\sqrt{d{q k}}}\left(\frac{4 \beta}{\sqrt{d{q k}}}\right)^{3}\left|\operatorname{res}\left(\boldsymbol{X}^{L-2}\right)\right|{1, \infty}^{2^{2}} \
& \leq \cdots \
& \leq \prod
{l=1}^{L}\left(\frac{4 \beta}{\sqrt{d{q k}}}\right)^{3^{l-1}}|\operatorname{res}(\boldsymbol{X})|{1, \infty}^{L}=\left(\frac{4 \beta}{\sqrt{d{q k}}}\right)^{\frac{3^{L}-1}{2}}|\operatorname{res}(\boldsymbol{X})|{1, \infty}^{L},
\end{aligned} \tag{P.18}
$$

Lemma A.4 得证。

Lemma A.5 (multiple-heads, multiple-layers):

对于任何包含 $L$ 层和 $H$ 个 heads 的 SAN 来讲,有:

$$
\left|\operatorname{res}\left(\boldsymbol{X}^{L}\right)\right|{1, \infty} \leq\left(\frac{4 H \beta}{\sqrt{d{q k}}}\right)^{\frac{3^{L}-1}{2}}|\operatorname{res}(\boldsymbol{X})|_{1, \infty}^{L} \tag{P.19}
$$

式中, $| {W}_{QK}^l|1 | {W}{V}^{l}|_{1,\infty} \leq \beta$ 。

Lemma A.5 证明:

Lemma A.5 的证明思路与 Lemma A.2 一致,此处省略。

Lemma A.5 得证。

4 $\text{res}( {X})$ 的变化规律

通过以上 Lemma A.1-A.5,我们特别关注 $\text{res}( {X})$ 会如何变化:

对于任何包含 $L$ 层和 $H$ 个 heads 的 SAN 来讲,如果满足 $| {W}_{QK}^l|1 | {W}{V}^{l}|{1,\infty} \leq \beta$ ,则随着网络的加深,相当于 $|\text{res}( {X})|{1,\infty}$ 双指数收敛到一个1秩矩阵

我们解释一下 P.19 式的含义:

首先 $\text{res}( {X})$ 代表的是输入的 $ {X}$ 这个张量和一个1秩矩阵 $ {1} {x}^T$ 的差距有多大。
那么 $\text{res}( {X})$ 越小,则代表 $ {X}$ 这个张量和一个1秩矩阵 $ {1} {x}^T$ 的差距就越小。
当满足 $4\beta < \sqrt{d_{qk}}$ 时,括号里面的值小于1,则 $|\text{res}( {X}) |$ 会逐渐收敛,也就是说:在更深的层,$ {X}$ 这个输入张量和一个1秩矩阵 $ {1} {x}^T$ 的差距会很小,即:特征收敛到了1秩矩阵

这种特征收敛到1秩矩阵的现象,作者称之为:秩崩塌

既然 Transformer 模型会出现秩崩塌的现象,但是为什么实际应用中还可以获得良好的训练呢?作者研究了一下3个角色的作用:Skip-Connections,Multi-Layer Perceptrons (MLP),Layer Normalization。

5 抑制秩崩塌现象的方法:skip-connection 很重要

Corollary A.1 (SAN with skip-connections):

对于任何包含 $L$ 层和 $H$ 个 heads 的,带有 Skip-Connection 的 SAN 来讲,假设有:$| {W}_{QK}^l|1 | {W}{h}^{l}|_{1,\infty} \leq \beta$ ,所有的 heads $h \in [H]$ ,所有的 layers $l \in [L]$ ,输出的界限是:

$$
\begin{align} |\text{res}( {X}^L)|{1,\infty} \leq \max{0 \le l \le L} \left(\frac{8 \, \beta\, H}{\sqrt{d{qk}}}\right)^{\frac{3^l-1}{2}} \, (2H)^{3^l(L-l)}|\text{res}( {X})|{1,\infty}^{3^l}, \end{align} \tag{P.20}
$$

不会导致秩崩塌现象。

Corollary A.1 证明:

对于带有 Skip-Connection 的 SAN 来讲,Lemma A.1 中的 single-head, single-layer 现在变成了:

$\begin{align} |\text{res}(SAN( {X}))|{1,\infty} \le \frac{4\, | {W}{QK,h}|1 | {W}{V}|{1, \infty}}{\sqrt{d{qk}}} \, |\text{res}( {X})|{1,\infty}^3 + |\text{res}( {X})|{1,\infty} \end{align} \tag{P.21}$

为了得到一个 multi-layer bound,我们向后展开递归。

我们首先考虑一个单头模型,并且有:$| {W}_{QK}^l|1 | {W}{V}^{l}|_{1,\infty} \leq \beta$ ,我们得到:

$$
\begin{align} | \text{res}( {X}^L)|{1,\infty} &\leq \frac{4 \, \beta}{\sqrt{d{qk}}} \, |\text{res}( {X}^{L-1})|{1,\infty}^3 + |\text{res}( {X}^{L-1})|{1,\infty} \nonumber \ &\leq 2\max(\frac{4 \, \beta}{\sqrt{d{qk}}} \, |\text{res}( {X}^{L-1})|{1,\infty}^3, \, |\text{res}( {X}^{L-1})|_{1,\infty} ) \end{align} \tag{P.22}
$$

下面我们将递归从最后一层向后展开到第一层,在展开的第 $k$ 个 step 时, P.22 式中两项的最大值要么是 $\frac{4 \, \beta}{\sqrt{d{qk}}} \, |\text{res}( {X}^{L-k})|{1,\infty}^3$ ,要么是 $|\text{res}( {X}^{L-k})|_{1,\infty}$ 。假设网络深度为 $L$ ,有 $l$ 次是前者最大。注意选择的顺序并不重要,重要的是被选择的次数。因此有:

$$
\begin{align} | \text{res}( {X}^L)|{1,\infty} \leq \max{0 \le l \le L} \left(\frac{8 \, \beta}{\sqrt{d{qk}}}\right)^{\frac{3^l-1}{2}} \, 2^{3^l(L-l)}|\text{res}( {X})|{1,\infty}^{3^l}. \end{align} \tag{P.23}
$$

将式 P.23 拓展到有 $H$ 个 heads 的情况:

$$
\begin{align} |\text{res}(SAN( {X}))|{1,\infty} \le \frac{4\, \beta \, H}{\sqrt{d{qk}}} \, |\text{res}( {X})|{1,\infty}^3 + H |\text{res}( {X})|{1,\infty} \end{align} \tag{P.24}
$$

$$
\begin{align} |\text{res}( {X}^L)|{1,\infty} \leq \max{0 \le l \le L} \left(\frac{8 \, \beta\, H}{\sqrt{d{qk}}}\right)^{\frac{3^l-1}{2}} \, (2H)^{3^l(L-l)}|\text{res}( {X})|{1,\infty}^{3^l}, \end{align} \tag{P.25}
$$

Corollary A.1 得证。

Skip-Connection 操作使得通过路径分解,总有一条路径跳过所有层,即总存在长度为0的路径。Skip-Connection 操作使得 $|\text{res}( {X}^L)|_{1,\infty}$ 不会收敛,即不会导致秩崩塌现象。

6 抑制秩崩塌现象的方法:MLP 有一定帮助

这一节研究 MLP 的作用,MLP 可以写成下式:

$$
\boldsymbol{X}^{l+1}=f{l}\left(\sum{h \in[H]} \boldsymbol{P}{h} \boldsymbol{X}^{l} \boldsymbol{W}{h}\right) \tag{P.26}
$$

Corollary A.2 (SAN with MLPs):

对于任何包含 $L$ 层和 $H$ 个 heads 的,带有 Skip-Connection 的 SAN 来讲,假设有:$| {W}_{QK}^l|1 | {W}{h}^{l}|_{1,\infty} \leq \beta$ ,所有的 heads $h \in [H]$ ,所有的 layers $l \in [L]$ , $fl$ 的 Lipschitz constant 是 $\lambda{l, 1,\infty}$ ,则输出的界限是:

$$
\left|\operatorname{res}\left(\boldsymbol{X}^{L}\right)\right|{1, \infty} \leq\left(\frac{4 \beta H \lambda}{\sqrt{d{q k}}}\right)^{\frac{3 L{-1}}{2}}|\operatorname{res}(\boldsymbol{X})|{1, \infty}^{3^{L}}, \tag{P.27}
$$

不会导致秩崩塌现象。

Corollary A.2 证明:

证明思路和 Lemma A.3 的证明很相似,根据 P.14,P.15 式有:

$$
\left|\left[\sum{h \in[H]} \boldsymbol{P}{h} \boldsymbol{X}^{l} \boldsymbol{W}{h}-\mathbf{1}\left(\boldsymbol{r}^{\prime}\right)^{T}\right]{i j}\right| \leq 2\left|\left[\sum{h} \boldsymbol{D}{h} \mathbf{1} \operatorname{softmax}\left(\boldsymbol{r}{h}\right)^{T} \boldsymbol{R} \boldsymbol{W}{h}\right]_{i j}\right| \tag{P.28}
$$

对式 P.28 应用三角不等式,得到 $\ell1, \ell\infty$ 范数:

$$
\begin{align} \left\Vert \sum_{h \in [H]} {P}h {X}^{l} {W}{h} - {1}( {r}')^T \right\Vertp & \leq 2H \max{h\in[H]} | {D}_h \, {1} \, \text{softmax}( {r}h)^T {R} {W}{h} |_p, \end{align} \tag{P.29}
$$

因为有 MLP 层的作用, $\sum_{h \in [H]} {P}h {X}^{l} {W}{h}$ 通过 MLP 层之后的 boundary 会发生变化:

$$
\begin{align} |\text{res}(SAN( {X}) )|p &= \left\Vert f\left(\sum{h \in [H]} {P}h {X}^{l} {W}{h}\right) - {1} r''^T\right\Vertp &\ &= \left\Vert f\left(\sum{h \in [H]} {P}h {X}^{l} {W}{h}\right) - f( {1} (r')^T) \right\Vertp &\rhd\, \text{$f$ preserves constancy-across-rows.} \ &\leq \lambda{l,p} \left\Vert \sum_{h \in [H]} {P}h {X}^{l} {W}{h} - {1} (r')^T\right\Vertp &\rhd\, \text{By definition of Lipschitz constant.} \ &\leq 2\lambda{l,p}\, H \max_{h\in[H]} | {D}_h \, {1} \, \text{softmax}( {r}h)^T {R} {W}{h} |_p &\rhd\, \text{ By Eq.P.29} \end{align} \tag{P.30}
$$

类似式 P.7,对式 P.30 应用赫尔德不等式 (Holder's inequality) 和三角不等式得到:

$$
\begin{aligned}
\left|\boldsymbol{D}{h} \mathbf{1} \operatorname{softmax}\left(\boldsymbol{r}{h}\right)^{T} \boldsymbol{R} \boldsymbol{W}{h}\right|{1} & \leq\left|\boldsymbol{D}{h} \mathbf{1}\right|{\infty}|\boldsymbol{R}|{1}\left|\boldsymbol{W}{h}\right|{1} \
\left|\boldsymbol{D}
{h} \mathbf{1} \operatorname{softmax}\left(\boldsymbol{r}{h}\right)^{T} \boldsymbol{R} \boldsymbol{W}{h}\right|{\infty} & \leq\left|\boldsymbol{D}{h} \mathbf{1}\right|{\infty}|\boldsymbol{R}|{\infty}\left|\boldsymbol{W}{h}\right|{\infty}
\end{aligned}\tag{P.31}
$$

根据式 P.6 有:

$$
\begin{align} | {D} {1}|\infty = \max{i,j,j'} | {\delta}_i^T {E} ( {\delta}j - {\delta}{j'}) | \leq 2 \max{ij} |E{ij}| &\leq 2 \, | {E}|{1} \ &= 2 \, | {R} \frac{ {W}{QK}}{\sqrt{d_{qk}}} {R}^T|1 \ &\leq \frac{2}{\sqrt{d{qk}}} \, | {R}|1 | {W}{QK}|_1 | {R}^T|1 \ &= \frac{2}{\sqrt{d{qk}}} \, | {R}|1 | {W}{QK}|1 | {R}|\infty, \end{align} \tag{P.6}
$$

代入 P.31,P.30 之后,有:

$$
\begin{align} |\text{res}(SAN( {X} )) |{1,\infty} \leq \frac{4\, H\, \lambda{l, 1,\infty} | {W}_{QK,h}|1 | {W}{h}|{1, \infty}}{\sqrt{d{qk}}} \, |\text{res}( {X})|_{1,\infty}^3 \end{align} \tag{P.32}
$$

最后, 递归展开 boundary,有:

$$
\begin{align} | \text{res}( {X}^L)|{1,\infty} &\leq \left( \frac{4 \, \beta\, H\, \lambda{l, 1,\infty}} { \sqrt{d{qk}}}\right)^{\frac{3^L-1}{2}} \, |\text{res}( {X})|{1,\infty}^{3^L}, \end{align} \tag{P.33}
$$

Corollary A.2 得证。

MLP 操作也可以使得 $|\text{res}( {X}^L)|{1,\infty}$ 收敛,即也会导致秩崩塌现象。但是由于 $\lambda{l, 1,\infty}$ 的作用收敛速度会减慢。$\lambda_{l, 1,\infty}$ 越大,收敛速度就越慢。应该强调,使用 MLPs 来抵消秩崩溃并不是没有缺点:增加 Lipschitz 常数会减慢残差收敛速度,同时也会降低模型对输入扰动的敏感性和鲁棒性。更大的 Lipschitz 常数也可能对优化提出更大的挑战,因为会导致更大的梯度方差。

7 抑制秩崩塌现象的方法:Layer Normalization 没有用

Layer Normalization 是对输出特征进行 shifting 和 rescaling 操作,如下式所示:

$$
\begin{align} \text{LN}\hspace{-1pt}\left({SAN( {X})}\right) &= \text{LN}\hspace{-1pt}\left({\sum_{h \in [H]} {P}h {X} {W}{h} + {1} {b}{O}^T}\right) = \bigg(\sum{h \in [H]} {P}h {X} {W}{h} + {1} {b}{O}^T - {1} {b}{LN}^T \bigg) {D}_{LN}^{-1}, \end{align} \tag{P.34}
$$

式中, $\boldsymbol{b}{L N}$ 是 $S A N(\boldsymbol{X})$ 每一列的均值。 $\boldsymbol{D}{L N}$ 是对角矩阵, 每个值代表 $S A N(\boldsymbol{X})$ 每一列 的标准差。

令 $\tilde{ {W}}h = {W}{h} {D}{LN}^{-1},\tilde{ {b}}{O} = {b}{O} - {b}{LN}$ ,上式可以写成:

$$
\begin{align} \text{LN}\hspace{-1pt}\left({SAN( {X})}\right) &= \sum_{h \in [H]} {P}h {X} \tilde{ {W}}{h} + {1} \tilde{ {b}}_{O}^T, \end{align} \tag{P.35}
$$

虽然 $\tilde{ {W}}h$ 和 $\tilde{ {b}}{O}$ 都是依赖与输入的变量,但是式 P.35 依然能够等效成没有 LN 的形式,所以 Layer Normalization 对抑制秩崩塌现象不起作用。

总结

本文分为两部分。

第1部分通过 5个 Lemma 证明了在 single-head, single-layer,multiple-heads, single-layer,single-head, multiple-layers,multiple-heads, multiple-layers 这四种情况下,残差 $\text{res}( {X}) = {X} - {1} {x}^T, \ {x} = \text{argmin}_{ {x}} | {X} - {1} {x}^T|$ 是如何变化的。发现 $\text{res}( {X})$ 都会逐步收敛到0,即在更深的层,$ {X}$ 这个输入张量和一个1秩矩阵 $ {1} {x}^T$ 的差距会很小,即:特征收敛到了1秩矩阵。这种特征收敛到1秩矩阵的现象,作者称之为秩崩塌。

第2部分说明了既然 Transformer 模型会出现秩崩塌的现象,但是为什么实际应用中还可以获得良好的训练呢?作者研究了一下3个角色的作用:Skip-Connections,Multi-Layer Perceptrons (MLP),Layer Normalization。通过证明发现:Skip-Connections,Multi-Layer Perceptrons (MLP) 可以起到抑制秩崩塌的效果,而 Layer Normalization 对抑制秩崩塌不起作用。

参考

[1] Jean-Baptiste Cordonnier, Andreas Loukas, and Martin Jaggi. Multi-head attention: Collaborate instead of concatenate. 2020.

  • 0
  • 0
  • 391
收藏
暂无评论