# pytorch nn.Transformer 的 mask 理解

>>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
>>> src = torch.rand((10, 32, 512))
>>> tgt = torch.rand((20, 32, 512))
>>> out = transformer_model(src, tgt) # 没有实现position embedding ，也需要自己实现mask机制。否则不是你想象的transformer

• src – the sequence to the encoder (required).
• tgt – the sequence to the decoder (required).

def forward(self, query, key, value, key_padding_mask=None,
# type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
r"""
Args:
query, key, value: map a query and a set of key-value pairs to an output.
See "Attention Is All You Need" for more details.
be ignored by the attention. When given a binary mask and a value is True,
the corresponding value on the attention layer will be ignored. When given
a byte mask and a value is non-zero, the corresponding value on the attention
layer will be ignored
need_weights: output attn_output_weights.
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
Shape:
- Inputs:
- query: :math:(L, N, E) where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key: :math:(S, N, E), where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- value: :math:(S, N, E) where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- key_padding_mask: :math:(N, S) where N is the batch size, S is the source sequence length.
If a ByteTensor is provided, the non-zero positions will be ignored while the position
with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
value of True will be ignored while the position with the value of False will be unchanged.
- attn_mask: 2D mask :math:(L, S) where L is the target sequence length, S is the source sequence length.
3D mask :math:(N*num_heads, L, S) where N is the batch size, L is the target sequence length,
S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
while the zero positions will be unchanged. If a BoolTensor is provided, positions with True
is not allowed to attend while False values will be unchanged. If a FloatTensor
is provided, it will be added to the attention weight.

- Outputs:
- attn_output: :math:(L, N, E) where L is the target sequence length, N is the batch size,
E is the embedding dimension.
- attn_output_weights: :math:(N, L, S) where N is the batch size,
L is the target sequence length, S is the source sequence length.
"""

[
[‘a’,'b','c','d'],
]

[
[False, False, False, True],
[False, False, False, False],
[False, False, True, True]
]

    def generate_square_subsequent_mask(self, sz: int) -> Tensor:
        r"""Generate a square causal (subsequent) mask for a sequence.

        Position ``(i, j)`` holds ``float('-inf')`` when ``j > i`` (future
        positions are masked) and ``float(0.0)`` otherwise, so the mask can
        be added directly to pre-softmax attention scores.

        Args:
            sz: length of the (target) sequence.

        Returns:
            A ``(sz, sz)`` float tensor: ``0.0`` on and below the diagonal,
            ``-inf`` strictly above it.
        """
        # Lower-triangular True marks the positions each step may attend to
        # (step i sees steps 0..i).
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        # Bug fix: the original returned the boolean mask directly, which
        # contradicts the docstring and cannot be added to attention weights.
        # Convert to floats -- masked (False) -> -inf, visible (True) -> 0.0 --
        # matching the upstream PyTorch nn.Transformer implementation.
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

• 'a'可以看到'a'
• 'b'可以看到'a','b'
• 'c'可以看到'a','b','c'
• 'd'理论上不应该看到什么，但是只要它头顶的监督信号是ignore_index，那就没有关系，所以让它看到'a','b','c','d'