While Longformer extends the maximum sequence length, its implementation contains a number of structural tricks that are hard to follow; they are analyzed step by step here.
Structural transformation in the _sliding_chunks_query_key_matmul function
The hardest thing to understand here is the following snippet:
query_size = list(query.size())
query_size[1] = query_size[1]*2-1
query_stride = list(query.stride())
query_stride[1] = query_stride[1]//2
query = query.as_strided(size=query_size, stride=query_stride)
First, the sequence length behind query_size and key_size here is an integer multiple of 512, because the maximum length of Longformer is 4096; inputs whose length is not an integer multiple are automatically padded up to one.

The query_size passed in here is (24, 1, 512, 64), (24, 2, 512, 64), ..., (24, n, 512, 64), and so on. With batch_size = 2, the first dimension is 24 = batch_size * num_heads; the second dimension counts how many 512-length chunks there are; the last two dimensions are essentially fixed at 512 and 64, where 512 is the fixed chunk length of one Longformer cycle and 64 is size_per_head, the size of a single attention head.
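As a rough sketch of what this padding and chunking produce, the snippet below pads a sequence up to a multiple of 512 and reshapes it into (batch_size * num_heads, n, 512, 64); the helper pad_and_chunk and its exact steps are illustrative assumptions for these shapes, not Longformer's actual code.

import torch
import torch.nn.functional as F

def pad_and_chunk(hidden, window=512, num_heads=12, head_dim=64):
    # Illustrative only: pad seq_len up to a multiple of `window`,
    # then reshape to (batch*num_heads, n_chunks, window, head_dim).
    batch, seq_len, hidden_dim = hidden.shape           # hidden_dim = num_heads * head_dim
    pad_len = (window - seq_len % window) % window      # 0 if seq_len is already a multiple of 512
    hidden = F.pad(hidden, (0, 0, 0, pad_len))          # zero-pad the sequence dimension
    seq_len = seq_len + pad_len
    x = hidden.view(batch, seq_len, num_heads, head_dim).transpose(1, 2)
    return x.reshape(batch * num_heads, seq_len // window, window, head_dim)

q = pad_and_chunk(torch.rand(2, 1500, 768))   # batch_size = 2, 12 heads * 64 = 768
print(q.shape)                                # torch.Size([24, 3, 512, 64]): 1500 is padded up to 1536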
So the as_strided snippet above essentially builds a new tensor by setting query_size[1] to query_size[1] * 2 - 1 and query_stride[1] to query_stride[1] // 2.

Because query and key are far too large for us to inspect the transformed tensor directly, we can simplify the problem: first look at how a small tensor is transformed, then work out the rule.
import torch
import numpy as np
import random

def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

# Set the random number seed
setup_seed(20)

data = torch.rand(5, 2, 4, 3)
print('data = ')
print(data)

data_size = list(data.size())
data_size[1] = data_size[1]*2-1
data_stride = list(data.stride())
data_stride[1] = data_stride[1]//2
data = data.as_strided(size=data_size, stride=data_stride)
print('data = ')
print(data)
1. When size[1] = 1, the data remains unchanged
2. When size[1] = 2, the second dimension becomes 2 * 2 - 1 = 3, and the corresponding data changes as follows:
Original data:
data = tensor([[[[0.5615, 0.1774, 0.8147],
                 [0.3295, 0.2319, 0.7832],
                 [0.8544, 0.1012, 0.1877],
                 [0.9310, 0.0899, 0.3156]],

                [[0.9423, 0.2536, 0.7388],
                 [0.5404, 0.4356, 0.4430],
                 [0.6257, 0.0379, 0.7130],
                 [0.3229, 0.9631, 0.2284]]]])
Current data:

data = tensor([[[[0.5615, 0.1774, 0.8147],
                 [0.3295, 0.2319, 0.7832],
                 [0.8544, 0.1012, 0.1877],
                 [0.9310, 0.0899, 0.3156]],

                [[0.8544, 0.1012, 0.1877],
                 [0.9310, 0.0899, 0.3156],
                 [0.9423, 0.2536, 0.7388],
                 [0.5404, 0.4356, 0.4430]],

                [[0.9423, 0.2536, 0.7388],
                 [0.5404, 0.4356, 0.4430],
                 [0.6257, 0.0379, 0.7130],
                 [0.3229, 0.9631, 0.2284]]]])
It can be seen that, because the chunk length 4 is an even number, the new middle block is spliced from the lower half of the first chunk and the upper half of the second chunk:
[[0.8544, 0.1012, 0.1877],
 [0.9310, 0.0899, 0.3156],
 [0.9423, 0.2536, 0.7388],
 [0.5404, 0.4356, 0.4430]]
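To double-check this splicing claim, here is a minimal sketch on a fresh (1, 2, 4, 3) random tensor (the values differ from the printout above, but the equality holds for any input):

import torch

data = torch.rand(1, 2, 4, 3)                       # 2 chunks of 4 rows, as above
size = list(data.size()); size[1] = size[1] * 2 - 1
stride = list(data.stride()); stride[1] = stride[1] // 2
strided = data.as_strided(size=size, stride=stride)

# The new middle block = lower half of chunk 0 followed by upper half of chunk 1.
middle = torch.cat([data[0, 0, 2:], data[0, 1, :2]], dim=0)
print(torch.equal(strided[0, 1], middle))           # True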
If the chunk length is odd, let us test that case as well. The original data is
data = tensor([[[[0.5615, 0.1774, 0.8147],
                 [0.3295, 0.2319, 0.7832],
                 [0.8544, 0.1012, 0.1877],
                 [0.9310, 0.0899, 0.3156],
                 [0.9423, 0.2536, 0.7388]],

                [[0.5404, 0.4356, 0.4430],
                 [0.6257, 0.0379, 0.7130],
                 [0.3229, 0.9631, 0.2284],
                 [0.4489, 0.2113, 0.6839],
                 [0.7478, 0.4627, 0.7742]]]])
After the as_strided function, the new data is
data = tensor([[[[0.5615, 0.1774, 0.8147],
                 [0.3295, 0.2319, 0.7832],
                 [0.8544, 0.1012, 0.1877],
                 [0.9310, 0.0899, 0.3156],
                 [0.9423, 0.2536, 0.7388]],

                [[0.1012, 0.1877, 0.9310],
                 [0.0899, 0.3156, 0.9423],
                 [0.2536, 0.7388, 0.5404],
                 [0.4356, 0.4430, 0.6257],
                 [0.0379, 0.7130, 0.3229]],

                [[0.7388, 0.5404, 0.4356],
                 [0.4430, 0.6257, 0.0379],
                 [0.7130, 0.3229, 0.9631],
                 [0.2284, 0.4489, 0.2113],
                 [0.6839, 0.7478, 0.4627]]]])
It can be seen that the new middle block is read straight out of the flattened data, spanning

0.1012, 0.1877], [0.9310, 0.0899, 0.3156], [0.9423, 0.2536, 0.7388]],
[[0.5404, 0.4356, 0.4430], [0.6257, 0.0379, 0.7130], [0.3229,

and the new last block is read from the span

0.7388]], [[0.5404, 0.4356, 0.4430], [0.6257, 0.0379, 0.7130],
[0.3229, 0.9631, 0.2284], [0.4489, 0.2113, 0.6839], [0.7478, 0.4627,

These two overlapping 15-element windows make up the new tensor content. Because the halved stride 15 // 2 = 7 is no longer a multiple of the row stride 3, the new rows do not line up with the original rows.
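To confirm that these middle and last blocks really are raw 15-element windows of the flattened data starting at offsets 7 and 14, here is a minimal sketch on a fresh (1, 2, 5, 3) random tensor (independent of the particular values above):

import torch

data = torch.rand(1, 2, 5, 3)                       # odd chunk length 5, as above
size = list(data.size()); size[1] = size[1] * 2 - 1
stride = list(data.stride()); stride[1] = stride[1] // 2    # 15 // 2 = 7, not a row boundary
strided = data.as_strided(size=size, stride=stride)

flat = data.flatten()
print(torch.equal(strided[0, 1], flat[7:22].reshape(5, 3)))     # True
print(torch.equal(strided[0, 2], flat[14:29].reshape(5, 3)))    # True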
3. When size[1]=3, the corresponding data changes as follows:
data = tensor([[[[0.5615, 0.1774, 0.8147],
                 [0.3295, 0.2319, 0.7832],
                 [0.8544, 0.1012, 0.1877],
                 [0.9310, 0.0899, 0.3156]],

                [[0.9423, 0.2536, 0.7388],
                 [0.5404, 0.4356, 0.4430],
                 [0.6257, 0.0379, 0.7130],
                 [0.3229, 0.9631, 0.2284]],

                [[0.4489, 0.2113, 0.6839],
                 [0.7478, 0.4627, 0.7742],
                 [0.3861, 0.0727, 0.8736],
                 [0.3510, 0.3279, 0.3254]]]])
The content after the change is as follows (analogous to the size[1] = 2 case):
data = tensor([[[[0.5615, 0.1774, 0.8147],
                 [0.3295, 0.2319, 0.7832],
                 [0.8544, 0.1012, 0.1877],
                 [0.9310, 0.0899, 0.3156]],

                [[0.8544, 0.1012, 0.1877],
                 [0.9310, 0.0899, 0.3156],
                 [0.9423, 0.2536, 0.7388],
                 [0.5404, 0.4356, 0.4430]],

                [[0.9423, 0.2536, 0.7388],
                 [0.5404, 0.4356, 0.4430],
                 [0.6257, 0.0379, 0.7130],
                 [0.3229, 0.9631, 0.2284]],

                [[0.6257, 0.0379, 0.7130],
                 [0.3229, 0.9631, 0.2284],
                 [0.4489, 0.2113, 0.6839],
                 [0.7478, 0.4627, 0.7742]],

                [[0.4489, 0.2113, 0.6839],
                 [0.7478, 0.4627, 0.7742],
                 [0.3861, 0.0727, 0.8736],
                 [0.3510, 0.3279, 0.3254]]]])
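The general rule for an even chunk length w is therefore: the new tensor has 2n - 1 chunks, and chunk i is a window of w consecutive rows starting at row i * w / 2 of the flattened sequence, so neighbouring chunks overlap by half a chunk. A minimal sketch verifying this (the small sizes B, n, w, d are arbitrary stand-ins):

import torch

B, n, w, d = 5, 3, 4, 3                  # batch, chunks, chunk length (even), feature dim
data = torch.rand(B, n, w, d)

size = list(data.size()); size[1] = size[1] * 2 - 1
stride = list(data.stride()); stride[1] = stride[1] // 2
strided = data.as_strided(size=size, stride=stride)

rows = data.reshape(B, n * w, d)         # the sequence laid out as n*w rows
windows = torch.stack(
    [rows[:, i * w // 2 : i * w // 2 + w] for i in range(2 * n - 1)], dim=1)
print(torch.equal(strided, windows))     # True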
Shape change of attention_scores
Here,

attentions = torch.matmul(query, key.transpose(-1, -2))

gives

attentions = (24, 5, 512, 512)    (the incoming key = (24, 5, 512, 64) and value = (24, 5, 512, 64))
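The same matmul can be checked on stand-in shapes, with the window 512 shrunk to 4 and the head size 64 shrunk to 3, and assuming key has been chunked the same way as query:

import torch

q = torch.rand(24, 5, 4, 3)    # stand-in for (24, 5, 512, 64)
k = torch.rand(24, 5, 4, 3)
scores = torch.matmul(q, k.transpose(-1, -2))
print(scores.shape)            # torch.Size([24, 5, 4, 4]), i.e. (24, 5, 512, 512) at full size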
Next, a row of zeros is padded onto the end of the second-to-last dimension:

attention_scores = nn.functional.pad(attention_scores, (0, 0, 0, 1))

This gives attention_scores = (24, 5, 513, 512).
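On a small stand-in tensor, this pad call simply appends one row of zeros along the second-to-last dimension (here (24, 5, 512, 512) is shrunk to (1, 1, 3, 4)):

import torch
import torch.nn as nn

x = torch.ones(1, 1, 3, 4)                     # stand-in for (24, 5, 512, 512)
padded = nn.functional.pad(x, (0, 0, 0, 1))    # (0, 0) on the last dim, (0, 1) on dim -2
print(padded.shape)                            # torch.Size([1, 1, 4, 4])
print(padded[0, 0])                            # the last row is all zeros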
Then the view function is called:

attention_scores = attention_scores.view(*attention_scores.size()[:-2], attention_scores.size(-1), attention_scores.size(-2))
Here the shape goes from (24, 5, 513, 512) to (24, 5, 512, 513). Originally, the bottom row of each (513, 512) block was 512 zeros (there are 24 * 5 such blocks; this has nothing to do with the 513, which is just the one extra padded row). The view does not move any memory, so those 512 zeros still sit at the very end of each block; reading the same memory as (512, 513) gives each row one additional non-zero number borrowed from the start of the next original row, shifting every row by one position, and the block still ends with the 512 zeros.
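To see this row shift concretely, here is the pad + view trick on a single small score block, with the window 512 replaced by w = 4 and integer values so the shift is easy to read (a sketch of the mechanism, not the original code):

import torch
import torch.nn as nn

w = 4
scores = torch.arange(1, w * w + 1).float().view(1, 1, w, w)
padded = nn.functional.pad(scores, (0, 0, 0, 1))    # (1, 1, w+1, w): one zero row at the bottom
shifted = padded.view(1, 1, w, w + 1)               # same memory, read back as (w, w+1)
print(scores[0, 0])
# tensor([[ 1.,  2.,  3.,  4.],
#         [ 5.,  6.,  7.,  8.],
#         [ 9., 10., 11., 12.],
#         [13., 14., 15., 16.]])
print(shifted[0, 0])
# tensor([[ 1.,  2.,  3.,  4.,  5.],
#         [ 6.,  7.,  8.,  9., 10.],
#         [11., 12., 13., 14., 15.],
#         [16.,  0.,  0.,  0.,  0.]])
# Each row picks up one element from the next original row; the block still ends with w zeros.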