LSTM token classification model code parsing (feeding batch data directly through the model to test the three functions pad_sequence, pack_padded_sequence, and pad_packed_sequence)

The code follows Che Wanxiang's plm-nlp-code/chp4/lstm_postag.py. To run it you need to copy the entire folder (it depends on vocab.py and utils.py) and install the nltk data:

import nltk
nltk.download()
# in the downloader window, select "all" to install everything
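Downloading everything takes a while; if you only need this script, downloading nltk's Penn Treebank sample corpus alone should be enough (an assumption based on load_treebank reading the treebank corpus):

import nltk
# assumption: load_treebank() only needs the 'treebank' sample corpus
nltk.download('treebank')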

1. Loading data

import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from collections import defaultdict
from vocab import Vocab
from utils import load_treebank

batch_size = 5  # kept small here purely for testing and inspecting the batches
train_data, test_data, vocab, pos_vocab = load_treebank()
train_dataset = LstmDataset(train_data)
test_dataset = LstmDataset(test_data)
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)
test_data_loader = DataLoader(test_dataset, batch_size=1, collate_fn=collate_fn, shuffle=False)

1.2 Define dataset

class LstmDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]
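Each item returned by load_treebank is a pair (token_id_list, tag_id_list), and LstmDataset simply wraps that list (see the raw samples printed in 1.3.2). A quick check with hypothetical, made-up ids:

# toy data in the same (token_ids, tag_ids) format as the real dataset (ids are made up)
toy_data = [
    ([2, 3, 4], [1, 1, 2]),              # a 3-token sentence and its 3 tag ids
    ([5, 6, 7, 8, 9], [3, 1, 2, 2, 4]),  # a 5-token sentence
]
toy_dataset = LstmDataset(toy_data)
print(len(toy_dataset))   # 2
print(toy_dataset[0])     # ([2, 3, 4], [1, 1, 2])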

1.3 Define the collate function that pads variable-length sequences

def collate_fn(examples):
    lengths = torch.tensor([len(ex[0]) for ex in examples])
    inputs = [torch.tensor(ex[0]) for ex in examples]
    targets = [torch.tensor(ex[1]) for ex in examples]
    #pad the variable length sequence to the same length
    inputs = pad_sequence(inputs, batch_first=True, padding_value=vocab["<pad>"])
    targets = pad_sequence(targets, batch_first=True, padding_value=vocab["<pad>"])
    # the last element is the mask: True at real tokens, False at pad positions
    return inputs, lengths, targets, inputs != vocab["<pad>"]
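Before running it on the real data, a minimal toy illustration of pad_sequence (using pad id 1, which is what vocab["<pad>"] turns out to be in the batch printed below):

import torch
from torch.nn.utils.rnn import pad_sequence

a = torch.tensor([2, 3, 4])
b = torch.tensor([5, 6, 7, 8, 9])
padded = pad_sequence([a, b], batch_first=True, padding_value=1)
print(padded)
# tensor([[2, 3, 4, 1, 1],
#         [5, 6, 7, 8, 9]])
print(padded != 1)   # the mask returned by collate_fn: False at the pad positions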

1.3.2 pad_sequence function test

for batch in train_dataset[:5]:
    print(batch)

([2, 3, 4, 5, 6, 7, 4, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18], [1, 1, 2, 3, 4, 5, 2, 6, 7, 8, 9, 10, 8, 5, 9, 1, 3, 11])
([19, 3, 20, 21, 22, 23, 24, 4, 10, 25, 26, 27, 18], [1, 1, 12, 9, 10, 1, 1, 2, 8, 1, 13, 9, 11])
([28, 29, 4, 30, 6, 7, 31, 32, 21, 22, 33, 34, 35, 36, 4, 37, 38, 39, 13, 14, 15, 22, 40, 41, 42, 43, 18], [1, 1, 2, 3, 4, 5, 14, 5, 9, 10, 1, 1, 1, 1, 2, 15, 16, 17, 8, 5, 9, 10, 8, 5, 5, 9, 11])
([44, 45, 22, 46, 47, 48, 49, 49, 50, 51, 52, 53, 54, 55, 56, 13, 57, 58, 22, 59, 60, 61, 13, 27, 22, 62, 63, 49, 50, 64, 65, 66, 67, 6, 68, 4, 69, 70, 71, 72, 18], [8, 9, 10, 9, 18, 16, 17, 17, 19, 7, 1, 9, 4, 12, 16, 8, 5, 9, 10, 9, 4, 10, 8, 9, 10, 4, 16, 17, 19, 20, 21, 10, 3, 4, 10, 2, 4, 15, 17, 17, 11])
([73, 46, 74, 4, 75, 4, 20, 76, 77, 47, 64, 78, 10, 79, 4, 80, 81, 82, 83, 50, 64, 84, 85, 86, 72, 87, 88, 89, 90, 4, 69, 91, 71, 92, 18], [8, 9, 9, 2, 9, 2, 12, 18, 5, 10, 20, 12, 8, 4, 2, 10, 18, 5, 4, 19, 20, 13, 4, 22, 17, 23, 24, 4, 5, 2, 4, 15, 17, 17, 11])
You can see that the raw dataset examples do vary in length. Now fetch one batch from the DataLoader:
for batch in train_data_loader:
    inputs, lengths, targets, mask = [x for x in batch]
    #output=collate_fn(batch)
    #print(output)
    break
inputs=tensor([[1815, 1041, 6262, 6229, 2383,  104, 1424,  177,  501, 1672,  503,  670,
           50,  501,  734,  503,  670,   13, 6224,   18,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1],
        [ 204, 7889,    4, 7890,  420,  159,   13, 7891, 5282,   22, 2943,    4,
         1413,   50, 6373,  420,  152, 7892, 4969,   22, 2943, 3599,  118, 1393,
           18,    1,    1,    1,    1,    1],
        [ 570, 1267, 1472,   99,    4,   22,   96, 6408,    4,   31,  105,  376,
            4, 2664,  666, 7961,    4, 1111,   22,   96, 6408,    4, 7951,   50,
           13,  501, 5074,  503, 1059,   18],
        [1209,    4,   10, 1210,  566, 1211, 1212, 1213,  480,  189,  148,   13,
         1214,  259, 1106, 1215,  589,   22,  105, 1216, 1217,   22, 1110,   22,
           96, 1093,   18,    1,    1,    1],
        [  73,  376,  794,  267, 1012, 4561,   40, 8305,  259,   39, 9309,  619,
         5722,   49,   50, 4873, 1062, 1355,    4,  449,   10, 9312, 2601, 5270,
         9313,  683,   50,   10,  624,   18]])
lengths=tensor([20, 25, 30, 27, 30])

targets= tensor([[ 8,  5,  9,  4, 15, 10,  9, 10, 35,  3, 17, 17, 19, 35,  3, 17, 17,  8,
          9, 11,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
        [10,  4,  2,  3,  9, 23,  8,  5,  9, 10, 25,  2, 16, 19,  3,  9, 31, 17,
         23, 10, 25,  9, 27,  9, 11,  1,  1,  1,  1,  1],
        [ 1,  1,  1,  1,  2, 10,  1,  1,  2, 14, 25,  9,  2,  1,  1,  1,  2, 18,
         10,  1,  1,  2, 15, 19,  8, 35,  3, 17,  9, 11],
        [18,  2,  8,  1,  1,  1,  1, 15, 24, 10, 18,  8,  9, 10,  1, 13,  9, 10,
         25,  5,  9, 10,  1, 10,  1,  1, 11,  1,  1,  1],
        [ 8,  9,  6, 18, 18,  7,  8,  9, 10, 17, 13,  4,  5, 17, 19,  7, 10,  9,
          2, 10,  8,  1,  1, 12,  5,  9, 19,  8,  9, 11]])
mask=tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True]])


The collate function returns these four values (as a plain tuple; the elements have no names):

  • the inputs after padding;
  • lengths, recording the original length of each sentence (it is later passed to pack_padded_sequence, which compresses away the pad positions);
  • the targets after padding, together with a boolean mask that is False at the pad positions.
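The mask is what lets you restrict later computations (the loss, accuracy, etc.) to the real tokens. For example, with the batch printed above:

print(mask.sum().item())    # 132 = 20 + 25 + 30 + 27 + 30, the number of real tokens
print(targets[mask].shape)  # torch.Size([132]): only the targets at non-pad positions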

1.3.3 pack_padded_sequence function test

The pack_padded_sequence source code is:

def pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True):
    ...
    if enforce_sorted:
        sorted_indices = None
    else:
        lengths, sorted_indices = torch.sort(lengths, descending=True)
        sorted_indices = sorted_indices.to(input.device)
        batch_dim = 0 if batch_first else 1
        input = input.index_select(batch_dim, sorted_indices)

    data, batch_sizes = _VF._pack_padded_sequence(input, lengths, batch_first)
    return _packed_sequence_init(data, batch_sizes, sorted_indices, None)

Look directly at the test results:

x_pack = pack_padded_sequence(inputs, lengths, batch_first=True, enforce_sorted=False)
x_pack
# In the result below, batch_sizes starts with twenty 5s: all five sequences contribute a token
# at each of the first 20 time steps, so the shortest sequence has length 20.
# It is followed by five 4s, meaning the second-shortest sequence has length 25, and so on.
PackedSequence(data=tensor([ 570,   73, 1209,  204, 1815, 1267,  376,    4, 7889, 1041, 1472,  794,
          10,    4, 6262,   99,  267, 1210, 7890, 6229,    4, 1012,  566,  420,
        2383,   22, 4561, 1211,  159,  104,   96,   40, 1212,   13, 1424, 6408,
        8305, 1213, 7891,  177,    4,  259,  480, 5282,  501,   31,   39,  189,
          22, 1672,  105, 9309,  148, 2943,  503,  376,  619,   13,    4,  670,
           4, 5722, 1214, 1413,   50, 2664,   49,  259,   50,  501,  666,   50,
        1106, 6373,  734, 7961, 4873, 1215,  420,  503,    4, 1062,  589,  152,
         670, 1111, 1355,   22, 7892,   13,   22,    4,  105, 4969, 6224,   96,
         449, 1216,   22,   18, 6408,   10, 1217, 2943,    4, 9312,   22, 3599,
        7951, 2601, 1110,  118,   50, 5270,   22, 1393,   13, 9313,   96,   18,
         501,  683, 1093, 5074,   50,   18,  503,   10, 1059,  624,   18,   18]),

 batch_sizes=tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 4, 4, 4, 4,
        4, 3, 3, 2, 2, 2]), 
sorted_indices=tensor([2, 4, 3, 1, 0]), 
unsorted_indices=tensor([4, 3, 0, 2, 1]))

You can see that none of the five packed sequences contains a pad value. The returned fields are:

  • data: the tokens of all sequences taken position by position, skipping the pad positions. For example, the token at position 1 of every sequence, then the token at position 2 of every sequence, and so on until the shortest sequence is exhausted; after that, only the remaining four sequences contribute tokens.
  • batch_sizes: the number of values taken at each time step (five sequences at the first step, but only four at the twenty-first).
  • sorted_indices=tensor([2, 4, 3, 1, 0]): the original batch index of each sequence after sorting the lengths in descending order.

As an example, the behaviour of torch.sort:

torch.sort(input, dim=-1, descending=False, stable=False, *, out=None)
1. Sorts the elements of the input tensor along the given dimension in ascending order; if dim is not given, the last dimension of input is chosen.
2. With descending=True, the elements are sorted in descending order.
3. With stable=True, the sorting routine is stable, preserving the order of equal elements.
lengths=tensor([20, 25, 30, 27, 30])
lengths, sorted_indices = torch.sort(lengths, descending=True)
print(lengths,'\n',sorted_indices )

tensor([30, 30, 27, 25, 20]) 
tensor([2, 4, 3, 1, 0])
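unsorted_indices is simply the inverse permutation of sorted_indices, i.e. the indices that put the sorted sequences back into their original batch order; it can be reproduced with torch.argsort:

sorted_indices = torch.tensor([2, 4, 3, 1, 0])
print(torch.argsort(sorted_indices))   # tensor([4, 3, 0, 2, 1]), matching unsorted_indices above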

1.3.4 LSTM output test

embedding_dim = 128
hidden_dim = 256

class LSTM(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_class):
        super(LSTM, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.output = nn.Linear(hidden_dim, num_class)
        init_weights(self)  # weight-initialization helper from the accompanying code (not shown here)

    def forward(self, inputs, lengths):
        embeddings = self.embeddings(inputs)
        # compress away the pad positions before feeding the LSTM
        x_pack = pack_padded_sequence(embeddings, lengths, batch_first=True, enforce_sorted=False)
        hidden, (hn, cn) = self.lstm(x_pack)
        # restore the padded [batch, max_len, hidden_dim] layout
        hidden, _ = pad_packed_sequence(hidden, batch_first=True)
        outputs = self.output(hidden)
        log_probs = F.log_softmax(outputs, dim=-1)
        return log_probs

num_class = len(pos_vocab)  # number of POS tags
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = LSTM(len(vocab), embedding_dim, hidden_dim, num_class)
model.to(device)  # move the model to the GPU (if one is available)
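For completeness, here is a minimal training-loop sketch in the spirit of the original lstm_postag.py (the epoch count and learning rate are assumptions). The loss is computed only at the non-pad positions via the mask, and lengths is kept on the CPU because pack_padded_sequence expects a CPU tensor of lengths:

num_epoch = 5                                          # assumed value
nll_loss = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)   # assumed learning rate

model.train()
for epoch in range(num_epoch):
    total_loss = 0
    for batch in train_data_loader:
        inputs, lengths, targets, mask = batch
        # move everything except lengths to the GPU; pack_padded_sequence wants CPU lengths
        inputs, targets, mask = inputs.to(device), targets.to(device), mask.to(device)
        log_probs = model(inputs, lengths)               # [batch, max_len, num_class]
        loss = nll_loss(log_probs[mask], targets[mask])  # only real tokens contribute
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"epoch {epoch}: total loss = {total_loss:.4f}")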

Test:
model1 (LSTM1) is a copy of the class above with the three lines after self.lstm removed, so it returns hidden, the PackedSequence that the LSTM produces when x_pack is fed in.
model2 (LSTM2) keeps the pad_packed_sequence call and returns its result.

model1=LSTM1(len(vocab), embedding_dim, hidden_dim, num_class)  # returns the raw PackedSequence
model2=LSTM2(len(vocab), embedding_dim, hidden_dim, num_class)  # returns the output of pad_packed_sequence
for batch in train_data_loader:
    inputs, lengths, targets, mask = [x for x in batch]
    hidden1= model1(inputs, lengths)
    print(inputs,inputs.shape)
    print('    1    ')
    print(hidden1[0],hidden1[1],hidden1[2],hidden1[3])
    print('    2  ++++++++++++++++++++++++++     ')
    print(hidden1[0].shape,hidden1[1].shape,hidden1[2].shape,hidden1[3].shape)
    
    print('***************************************************************')
    print('***************************************************************')
    hidden2= model2(inputs, lengths)
    print(inputs,inputs.shape)
    print('         3   #######################################              ')
    print(hidden2[0],hidden2[1],hidden2[2],hidden2[3])
    print('          4   ======================================          ')
    print(hidden2[0].shape,hidden2[1].shape,hidden2[2].shape,hidden2[3].shape)
    
    break
inputs=tensor([[1735, 4372, 4402,   39,   50, 4403,  811,  129, 4404,  149,  104,   10,
         2901,  118, 4405,   31, 4406, 4407,    4,   31, 1736, 1111,  523,  370,
          125, 1181, 1488,  104, 2943, 4408,   18,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1],
        [7110,  281, 4158,   22, 8176,    4,   64, 5679, 5260,   39,   50, 1168,
         6567,   22, 9791, 8695,  104,   10, 1424,   22,   13, 2374, 1968,  214,
          104,   40, 1386,    4,   13, 9766,   22, 2373,  214,  104,  193, 2470,
          214,  104,   40, 1386,   10,   96, 1231, 1232, 1233,   31,   10, 3962,
         7159, 2470,   18],
        [ 379, 4967,    4,   39, 4968,   86, 4955,    8, 4969,   13, 1724, 4970,
          294, 4925, 4926, 2434,    4, 4971, 1127,   88, 1225,   22, 4955, 1873,
          157,   10,  987,   18,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1],
        [  19, 7704, 1895,   19, 2113,   31, 1307, 3330, 4116, 1189,    4, 7798,
         7799,    4,  137,   39, 4027,   50, 1712,  157, 7800,   18,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1]]) 
            
torch.Size([4, 51])
hidden1=tensor([[ 0.0029,  0.0193,  0.0212,  ...,  0.0001, -0.0066,  0.0003],
        [-0.0144,  0.0252,  0.0266,  ...,  0.0096,  0.0107, -0.0115],
        [ 0.0010,  0.0102,  0.0305,  ...,  0.0034, -0.0094, -0.0114],
        ...,
        [-0.0067,  0.0104,  0.0418,  ..., -0.0184, -0.0153, -0.0887],
        [-0.0151,  0.0101,  0.0386,  ..., -0.0038, -0.0186, -0.0766],
        [-0.0219,  0.0197,  0.0579,  ...,  0.0024, -0.0250, -0.0852]],
       grad_fn=<CatBackward>)
       
        tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 3,
        3, 3, 3, 3, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1]) # hidden1[1] is batch_sizes; the original sequence lengths are [31, 51, 28, 22], which sum to 132
        
        tensor([1, 0, 2, 3]) # sorted_indices
        tensor([1, 0, 2, 3]) # unsorted_indices
        # shapes of the four elements of hidden1:
        torch.Size([132, 256]) torch.Size([51]) torch.Size([4]) torch.Size([4])

You can see that the direct LSTM output for the packed input x_pack is a PackedSequence with four elements:

  • data, with shape torch.Size([132, 256]): the output vectors of all real tokens flattened into a single dimension; 132 is the total sequence length after the pads are removed.
  • batch_sizes, the per-time-step counts produced by pack_padded_sequence, of length 51 (the maximum sequence length).
  • sorted_indices and unsorted_indices, which happen to have the same value here.
hidden2=tensor([[-0.0149,  0.0012,  0.0159,  ..., -0.0335, -0.0006, -0.0271],
        [-0.0339,  0.0047,  0.0248,  ..., -0.0422, -0.0209, -0.0569],
        [-0.0379,  0.0104,  0.0459,  ..., -0.0629, -0.0205, -0.0533],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       grad_fn=<SelectBackward>) 

tensor([[-0.0121,  0.0010,  0.0193,  ..., -0.0180, -0.0015, -0.0323],
        [-0.0154,  0.0147,  0.0470,  ..., -0.0219, -0.0197, -0.0456],
        [-0.0045,  0.0210,  0.0554,  ..., -0.0435, -0.0271, -0.0440],
        ...,
        [-0.0362, -0.0124,  0.0628,  ..., -0.0643, -0.0345, -0.0596],
        [-0.0319, -0.0056,  0.0538,  ..., -0.0695, -0.0578, -0.0691],
        [-0.0331, -0.0223,  0.0400,  ..., -0.0821, -0.0560, -0.0682]],
       grad_fn=<SelectBackward>) 

tensor([[-0.0119, -0.0006,  0.0208,  ..., -0.0235, -0.0029, -0.0189],
        [-0.0590,  0.0138,  0.0322,  ..., -0.0340, -0.0180, -0.0385],
        [-0.0382,  0.0104,  0.0454,  ..., -0.0520, -0.0409, -0.0286],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       grad_fn=<SelectBackward>) 

tensor([[-0.0165,  0.0155,  0.0077,  ..., -0.0130,  0.0058, -0.0343],
        [-0.0369,  0.0260,  0.0282,  ..., -0.0433, -0.0221, -0.0409],
        [-0.0434,  0.0133,  0.0344,  ..., -0.0546, -0.0403, -0.0254],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       grad_fn=<SelectBackward>)
          4   ======================================          
torch.Size([51, 256]) torch.Size([51, 256]) torch.Size([51, 256]) torch.Size([51, 256])
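So hidden2, the result after pad_packed_sequence, has been restored to a [batch=4, max_len=51, hidden_dim=256] tensor in the original batch order, with zero vectors at the padded time steps. The same round trip on a tiny made-up example:

emb = torch.randn(2, 3, 4)                 # [batch=2, max_len=3, emb_dim=4]
lengths = torch.tensor([3, 1])             # the second sequence has only 1 real step
packed = pack_padded_sequence(emb, lengths, batch_first=True, enforce_sorted=False)
out_packed, (hn, cn) = nn.LSTM(4, 6, batch_first=True)(packed)
out, out_lengths = pad_packed_sequence(out_packed, batch_first=True)
print(out.shape)      # torch.Size([2, 3, 6])
print(out_lengths)    # tensor([3, 1])
print(out[1, 1:])     # all zeros: steps beyond length 1 are padding again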

1.3.5 pad_packed_sequence function

def pad_packed_sequence(sequence, batch_first=False, padding_value=0.0, total_length=None):
    max_seq_length = sequence.batch_sizes.size(0)
    if total_length is not None:
        max_seq_length = total_length
    ...

total_length is the length to pad the output to. You would usually expect the output to be padded back to the same number of time steps as the original input, but a PackedSequence does not record that value, so by default pad_packed_sequence uses sequence.batch_sizes.size(0), i.e. the length of the batch_sizes tensor (which equals the longest sequence actually present in the batch). If you need a specific length, pass total_length explicitly, as in the snippet below.
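For example, inside LSTM.forward above, you could force the padded output back to the input's full time dimension (this is also the documented way to use pad_packed_sequence with nn.DataParallel, where a sub-batch may not contain the globally longest sequence):

# hypothetical variant of the pad_packed_sequence call inside LSTM.forward
hidden, _ = pad_packed_sequence(hidden, batch_first=True, total_length=inputs.shape[1])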

Summary:

  1. pad_sequence pads variable-length sequences to the same length and is used inside the collate function, which returns the padded inputs, the original lengths, the padded targets, and the mask.
  2. pack_padded_sequence removes the pad positions (it needs the lengths from the previous step) and returns a PackedSequence whose data is a one-dimensional batch that picks up the sequence values time step by time step, no longer laid out sentence by sentence. This is what is fed directly into the LSTM, and the LSTM's output is again a PackedSequence that still has to be processed.
  3. pad_packed_sequence restores the result of the previous step to the original order and pads every sequence back to the same length.
  4. So pack_padded_sequence and pad_packed_sequence are normally used as a pair, especially for token-level labelling. For sentence classification that only uses the hidden state of the last time step, there is no need to restore the original layout (see the sketch after this list).
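As a contrast to the token-classification model, here is a hypothetical sentence-classification variant (LSTMSentClassifier is not part of the original code). Because only the final hidden state hn is used, pad_packed_sequence is not needed: hn already holds the hidden state at the last real time step of every sequence and is returned in the original batch order.

class LSTMSentClassifier(nn.Module):
    # Hypothetical variant: classify the whole sentence instead of every token.
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_class):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.output = nn.Linear(hidden_dim, num_class)

    def forward(self, inputs, lengths):
        embeddings = self.embeddings(inputs)
        x_pack = pack_padded_sequence(embeddings, lengths, batch_first=True, enforce_sorted=False)
        _, (hn, cn) = self.lstm(x_pack)
        # hn: [num_layers, batch, hidden_dim]; hn[-1] is the last layer's final hidden
        # state for each sequence, already unsorted back to the original batch order.
        return F.log_softmax(self.output(hn[-1]), dim=-1)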
