Attention Is All You Need: paper notes and PyTorch code

References: Li Mu's reading of the paper, and the PyTorch code.

Things I don't fully understand yet:

  • residual network
  • Position-wise
  • Layer norm
  • Encoder attention

Parameter setting

## dimension
d_model = 512 # dimension of the embeddings and of every sub-layer output (kept identical so the residual additions work) [d_model]
d_inner_hid = 2048 # dimension of the feed-forward (MLP) inner layer [d_ff]
d_k = 64 # dimension of keys (and queries) [d_k]
d_v = 64 # dimension of values [d_v]
## other
n_head = 8 # number of attention heads [h]
n_layers = 6 # number of encoder/decoder layers [N]

[ ] gives the symbol used in the paper.

Model architecture

1. Encoder

1.1. Model structure diagram

The encoder is composed of N = 6 identical layers, each divided into two sub-layers:

1.2.Multi-Head Attention

1.2.1.Attention
  • The attention function maps a query and a set of key-value pairs to an output.
  • The output is a weighted sum of the values; the weight of each value comes from the similarity between the query and the corresponding key, computed by a compatibility function.
  • There are several compatibility functions for computing attention, e.g. additive attention (which can handle queries and keys of different lengths) and multiplicative (dot-product) attention. This paper chooses the relatively simple scaled dot-product attention: the query and key are combined by an inner product, and a larger value indicates greater similarity. Since the two vectors have the same length, a larger inner product means a larger cosine and a smaller angle between them, hence more similar; an inner product of 0 means the vectors are orthogonal.
1.2.2.Scaled Dot-Product Attention
  • $Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}})V$

  • Attention is computed with matrix multiplications, which parallelize well; this is one way it differs from CNNs and RNNs.

  • A single query does a dot product with each key, producing one score per key; passing these scores through a softmax yields weights over the values that sum to 1.

  • "Scaled" refers to the factor $\frac{1}{\sqrt{d_k}}$. When $d_k$ is large, the dot products grow large in magnitude, the softmax pushes the weights toward 0 or 1, and the gradient vanishes. This is somewhat like the temperature applied to logits in distillation.

    The reason for using $\frac{1}{\sqrt{d_k}}$: suppose the components of q and k are i.i.d. with mean 0 and variance 1; then their dot product $q\cdot k=\sum_{i=1}^{d_k}q_ik_i$ has mean 0 and variance $d_k$. The scaled operation brings the variance back to 1.
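
A quick numeric check of this variance argument (a minimal sketch; the sample size is arbitrary):

import torch

# components of q and k ~ N(0, 1), i.i.d.
d_k = 64
q = torch.randn(100000, d_k)
k = torch.randn(100000, d_k)

scores = (q * k).sum(dim=-1)            # one dot product per row
print(scores.var())                     # ~ d_k (about 64)
print((scores / d_k ** 0.5).var())      # ~ 1 after scaling by 1/sqrt(d_k)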

1.2.3.Multi-Head Attention
  • $MultiHead(Q,K,V)=Concat(head_1,...,head_h)W^O$
  • $head_i=Attention(QW_i^Q,KW_i^K,VW_i^V)$
  • This is equivalent to projecting Q, K and V into h different low-dimensional subspaces and computing attention in each, so that different heads attend to different representation subspaces, a bit like channels in a CNN.
  • $W_i^Q,W_i^K\in R^{d_{model}\times d_k}$, $W_i^V\in R^{d_{model}\times d_v}$, $W^O\in R^{hd_v\times d_{model}}$: projection matrices with learnable parameters
  • $h=8$, $d_k=d_v=d_{model}/h=64$
  • For encoder self-attention, Q, K and V are actually the same input; at the first layer it is input_embedding + positional_encoding (see the shape sketch below).
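
A minimal shape sketch of the per-head projection, using the dimensions from the parameter settings above (the tensor sizes are illustrative):

import torch
import torch.nn as nn

d_model, h = 512, 8
d_k = d_model // h                       # 64

x = torch.randn(2, 10, d_model)          # (batch, seq_len, d_model); Q = K = V = x in self-attention
w_qs = nn.Linear(d_model, h * d_k, bias=False)   # all h query projections packed into one Linear

q = w_qs(x).view(2, 10, h, d_k)          # split the last dimension into h heads of size d_k
print(q.shape)                           # torch.Size([2, 10, 8, 64])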

1.3.Feed Forward

  • A simple fully connected network applied to each position, expressed as $FFN(x)=max(0,xW_1+b_1)W_2+b_2$

    • The max is the ReLU activation
    • $W_1$ has shape $(d_{model},d_{ff})$ and $W_2$ has shape $(d_{ff},d_{model})$, with $d_{ff}=4\,d_{model}=2048$, so the dimension is expanded fourfold and then projected back
    • It is also equivalent to two convolution layers with kernel size 1 (see the sketch after this list)
  • The FFN parameters differ between the six layers, but within one layer the same FFN is applied once to every position (word)
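
The kernel-size-1 convolution claim can be checked with a small sketch (the weight copying and names below are mine, not from the repo):

import torch
import torch.nn as nn
import torch.nn.functional as F

d_model, d_ff = 512, 2048
x = torch.randn(2, 10, d_model)                          # (batch, seq_len, d_model)

w_1, w_2 = nn.Linear(d_model, d_ff), nn.Linear(d_ff, d_model)
ffn_linear = w_2(F.relu(w_1(x)))                         # FFN(x) = max(0, xW1 + b1)W2 + b2

# the same computation as two 1x1 convolutions over the sequence dimension
c_1 = nn.Conv1d(d_model, d_ff, kernel_size=1)
c_2 = nn.Conv1d(d_ff, d_model, kernel_size=1)
c_1.weight.data = w_1.weight.data.unsqueeze(-1); c_1.bias.data = w_1.bias.data
c_2.weight.data = w_2.weight.data.unsqueeze(-1); c_2.bias.data = w_2.bias.data
ffn_conv = c_2(F.relu(c_1(x.transpose(1, 2)))).transpose(1, 2)

print(torch.allclose(ffn_linear, ffn_conv, atol=1e-5))   # True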

1.4. Other parts

  • Input Embedding: the input word sequence is converted into a tensor of word embeddings

  • Positional Encoding: attention computes over the whole sequence as a set, so shuffling the positions would give the same result; relative or absolute position information must therefore be added. It is added directly to the input embedding, so its dimension is also $d_{model}$. There are many ways to compute position information; this paper uses sine and cosine functions.

    • $PE(pos,2i)=sin(pos/10000^{2i/d_{model}})$

    • $PE(pos,2i+1)=cos(pos/10000^{2i/d_{model}})$

      • pos is the position of the word; i indexes the dimension pair (2i, 2i+1), giving $d_{model}$ sinusoidal components with different wavelengths
    • Trigonometric functions are chosen because they represent relative position well: $PE_{pos+k}$ can be expressed as a linear function of $PE_{pos}$

    • By the sine angle-sum formula, $sin(pos+k)=sin(pos)cos(k)+cos(pos)sin(k)$,

      i.e. $PE(pos+k,2i)=PE(pos,2i)cos(k)+PE(pos,2i+1)sin(k)$, omitting the per-dimension frequency factor.

      When k is a constant, $sin(k)$ and $cos(k)$ are also constants, so the relation is linear (a numeric check follows below).
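
A small numeric check of this linearity for one dimension pair (the frequency factor, which the note above omits, appears inside the constants):

import numpy as np

d_model, i = 512, 3
w_i = 1.0 / 10000 ** (2 * i / d_model)     # frequency of dimension pair (2i, 2i+1)

def pe(pos):
    # (PE(pos, 2i), PE(pos, 2i+1)) for this single frequency
    return np.sin(pos * w_i), np.cos(pos * w_i)

pos, k = 17, 5
sin_pos, cos_pos = pe(pos)
sin_shifted, _ = pe(pos + k)

# PE(pos+k, 2i) as a linear combination of PE(pos, 2i) and PE(pos, 2i+1)
reconstructed = sin_pos * np.cos(k * w_i) + cos_pos * np.sin(k * w_i)
print(np.isclose(sin_shifted, reconstructed))   # True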

  • Residual connection and layer normalization: applied after each sub-layer, expressed as $LayerNorm(x+Sublayer(x))$ (a minimal sketch follows this list)

    • residual connection
    • layer normalization
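
A minimal sketch of the sub-layer wrapper $LayerNorm(x+Sublayer(x))$ (post-norm, as in the paper; the class name is mine, and the repo code below folds the residual and LayerNorm into each module instead of using a separate wrapper):

import torch
import torch.nn as nn

class SublayerConnection(nn.Module):
    ''' Residual connection followed by layer normalization. '''
    def __init__(self, d_model, dropout=0.1):
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        # residual add, then normalize over the last (d_model) dimension
        return self.layer_norm(x + self.dropout(sublayer(x)))

d_model = 512
wrapper = SublayerConnection(d_model)
x = torch.randn(2, 10, d_model)
out = wrapper(x, nn.Linear(d_model, d_model))   # any sub-layer that preserves d_model
print(out.shape)                                # torch.Size([2, 10, 512])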

2.Decoder

2.1. Model structure diagram

2.2.Masked Multi-Head Attention(self-attention)

This is the part added in the decoder: when predicting, each position may only see the previous positions, not the subsequent ones. When computing attention, the scores at the masked positions are replaced with a large negative number before the softmax, so that after the softmax their weights are very close to 0.
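
A minimal standalone sketch of such a causal mask (the mask-building line below is mine; the repo builds an equivalent mask in a helper not shown in these notes):

import torch
import torch.nn.functional as F

len_s = 5
# lower-triangular: position i may attend only to positions <= i
subsequent_mask = torch.tril(torch.ones(len_s, len_s)).bool()

scores = torch.randn(len_s, len_s)                      # raw attention scores
scores = scores.masked_fill(~subsequent_mask, -1e9)     # masked positions -> large negative number
weights = F.softmax(scores, dim=-1)
print(weights[0])   # only the first entry is non-zero for the first position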

2.3.Multi-Head Attention(encoder-attention)

  • Q: decoder_output
  • K: encoder_output
  • V: encoder_output
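
Using the MultiHeadAttention class from the code section below, the encoder-decoder attention call looks roughly like this (a sketch; the tensors and their sizes are made up):

import torch

# assumes the MultiHeadAttention class from section 3 is in scope
cross_attn = MultiHeadAttention(n_head=8, d_model=512, d_k=64, d_v=64)

enc_output = torch.randn(2, 12, 512)    # output of the encoder stack
dec_output = torch.randn(2, 7, 512)     # output of the decoder's masked self-attention

# Q from the decoder, K and V from the encoder output
dec_output, dec_enc_attn = cross_attn(dec_output, enc_output, enc_output, mask=None)
print(dec_output.shape, dec_enc_attn.shape)   # torch.Size([2, 7, 512]) torch.Size([2, 8, 7, 12])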

3. Code

3.1.Encoder architecture

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    ''' An encoder model with self attention mechanism. '''
    def __init__(
            self, n_src_vocab, d_word_vec, n_layers, n_head, d_k, d_v,
            d_model, d_inner, pad_idx, dropout=0.1, n_position=200, scale_emb=False):

        super().__init__()

        self.src_word_emb = nn.Embedding(n_src_vocab, d_word_vec, padding_idx=pad_idx) # [Input Embedding layer]
        self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position) # [Positional Encoding]
        self.dropout = nn.Dropout(p=dropout)
        self.layer_stack = nn.ModuleList([
            EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)]) # [Encoder Layer]*n_layers
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) # [layer norm]
        self.scale_emb = scale_emb # boolean: whether to scale the embedding by sqrt(d_model)
        self.d_model = d_model

    def forward(self, src_seq, src_mask, return_attns=False):

        enc_slf_attn_list = []

        # -- Forward
        enc_output = self.src_word_emb(src_seq)
        if self.scale_emb:
            enc_output *= self.d_model ** 0.5 # the paper multiplies the embedding weights by sqrt(d_model) (Sec. 3.4)
        enc_output = self.dropout(self.position_enc(enc_output))
        enc_output = self.layer_norm(enc_output)

        for enc_layer in self.layer_stack:
            enc_output, enc_slf_attn = enc_layer(enc_output, slf_attn_mask=src_mask)
            enc_slf_attn_list += [enc_slf_attn] if return_attns else []

        if return_attns:
            return enc_output, enc_slf_attn_list
        return enc_output,

3.2.Positional Encoding

class PositionalEncoding(nn.Module):

    def __init__(self, d_hid, n_position=200):
        super(PositionalEncoding, self).__init__()

        # Not a parameter
        self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        ''' Sinusoid position encoding table '''
        # TODO: make it with torch instead of numpy

        def get_position_angle_vec(position):
            return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] # 2*(hid_j//2) make 2i and 2i+1 correspond to 2i
		
        sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) # pos(0,200)
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i (0::2 = every second column, starting at 0)
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

        return torch.FloatTensor(sinusoid_table).unsqueeze(0)

    def forward(self, x):
        return x + self.pos_table[:, :x.size(1)].clone().detach() # add the positional table, truncated to the sequence length of x

3.3.EncoderLayer

Contains two sub-layers.

class EncoderLayer(nn.Module):
    ''' Compose with two layers '''

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) # Including residual network & layer norm
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, enc_input, slf_attn_mask=None):
        enc_output, enc_slf_attn = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask) # All inputs are enc_input
        enc_output = self.pos_ffn(enc_output)
        return enc_output, enc_slf_attn

MultiHeadAttention

class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module '''

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()
				
        # d_k=d_v=d_model/n_head = 512/8 = 64
        self.n_head = n_head # 8
        self.d_k = d_k # 64
        self.d_v = d_v # 64

        # The per-head projection matrices W are packed into single Linear layers; the heads are then separated with view
        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False) # concat is followed by a linear

        self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)


    def forward(self, q, k, v, mask=None):
        # for self-attention, q, k and v are the same tensor
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)

        residual = q # residual is one of the inputs

        # Pass through the pre-attention projection: b x lq x (n*dv)
        # Separate different heads: b x lq x n x dv
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        # Transpose for attention dot product: b x n x lq x dv
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        if mask is not None:
            mask = mask.unsqueeze(1)   # For head axis broadcasting.

        q, attn = self.attention(q, k, v, mask=mask)

        # Transpose to move the head dimension back: b x lq x n x dv
        # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
        q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
        q = self.dropout(self.fc(q))
        q += residual

        q = self.layer_norm(q)

        return q, attn

ScaledDotProductAttention

class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention '''

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):

        attn = torch.matmul(q / self.temperature, k.transpose(2, 3))

        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e9) # where mask == 0, replace the score with a large negative number

        attn = self.dropout(F.softmax(attn, dim=-1))
        output = torch.matmul(attn, v)

        return output, attn

PositionWiseFeedForward

class PositionwiseFeedForward(nn.Module):
    ''' A two-feed-forward-layer module '''

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_in, d_hid) # position-wise
        self.w_2 = nn.Linear(d_hid, d_in) # position-wise
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):

        residual = x

        x = self.w_2(F.relu(self.w_1(x)))
        x = self.dropout(x)
        x += residual

        x = self.layer_norm(x)

        return x
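
Putting the pieces together, a minimal smoke test of the encoder with the classes above in scope and the hyperparameters from the parameter settings (the vocabulary size, batch size, sequence length and mask construction are made up for the example):

import torch

n_src_vocab, pad_idx = 10000, 0
batch_size, seq_len = 2, 16

encoder = Encoder(
    n_src_vocab=n_src_vocab, d_word_vec=512, n_layers=6, n_head=8,
    d_k=64, d_v=64, d_model=512, d_inner=2048, pad_idx=pad_idx)

src_seq = torch.randint(1, n_src_vocab, (batch_size, seq_len))   # fake token ids (0 reserved for padding)
src_mask = (src_seq != pad_idx).unsqueeze(-2)                    # (batch, 1, seq_len) padding mask

enc_output, = encoder(src_seq, src_mask)                         # forward returns a one-element tuple
print(enc_output.shape)                                          # torch.Size([2, 16, 512])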
