torch.cat()
torch.cat(tensors, dim=0, *, out=None) → Tensor
Official explanation: Concatenates the given sequence of tensors in the given dimension (cat stands for concatenate). All tensors must either have the same shape (except in the concatenating dimension) or be empty.
In other words, the tensors in the sequence are joined end to end along the specified dimension.
Parameters:
- tensors: the sequence of tensors to concatenate (a tuple or list)
- dim: the dimension along which to concatenate
- out: output tensor (rarely used; you can simply assign the return value instead)
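To illustrate the out parameter, here is a minimal sketch: out writes the result into a preallocated tensor, which yields the same values as assigning the return value. The names x, y, result and buf are only for illustration.

```python
import torch

x = torch.arange(6).reshape(2, 3)
y = torch.arange(6, 12).reshape(2, 3)

# Usual style: assign the returned tensor
result = torch.cat((x, y), dim=0)

# With out=: write into a preallocated tensor of matching shape and dtype
buf = torch.empty(4, 3, dtype=torch.int64)
torch.cat((x, y), dim=0, out=buf)

print(torch.equal(result, buf))  # True
```

In everyday code the plain assignment is simpler; out mainly matters when you want to reuse an existing buffer.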
Notes:
① tensors must be a sequence of tensors, not a single tensor;
② Apart from dimension dim, all input tensors must have the same shape.
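Both rules can be checked with a small sketch. raises is an illustrative helper, not part of PyTorch. Passing a bare tensor instead of a sequence raises a TypeError, and a size mismatch outside dim raises a RuntimeError. Error messages differ between PyTorch versions, so the sketch only checks the exception types.

```python
import torch

def raises(fn, exc_type):
    """Return True if calling fn() raises exc_type (illustrative helper)."""
    try:
        fn()
    except exc_type:
        return True
    return False

a = torch.zeros(2, 3)
b = torch.zeros(4, 3)

# Shapes differ only in dim 0, so concatenating along dim 0 works
print(torch.cat((a, b), dim=0).shape)  # torch.Size([6, 3])

# Along dim 1 the remaining dimension differs (2 vs 4), so cat refuses
print(raises(lambda: torch.cat((a, b), dim=1), RuntimeError))  # True

# A bare tensor is not a sequence of tensors
print(raises(lambda: torch.cat(a, dim=0), TypeError))  # True
```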
Example:
import torch

a = torch.arange(6).reshape(2, 3)
b = torch.arange(12)
c = torch.cat((a, b.reshape(4, 3)), dim=0)  # concatenate along dim 0, i.e., by rows (vertically)
d = torch.cat((a, b.reshape(2, 6)), dim=1)  # concatenate along dim 1, i.e., by columns (horizontally)
print(c)
print(c.shape)
print(d)
print(d.shape)
Output:
tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])
torch.Size([6, 3])
tensor([[ 0,  1,  2,  0,  1,  2,  3,  4,  5],
        [ 3,  4,  5,  6,  7,  8,  9, 10, 11]])
torch.Size([2, 9])
When torch.cat() concatenates along dim, the output size in that dimension is the sum of the input sizes in that dimension, while the sizes of all other dimensions stay unchanged. With this in mind, concatenation of higher-dimensional tensors is easy to understand.
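The size rule can be written out as a small helper. cat_shape is a hypothetical function, not part of PyTorch. It predicts the output shape by summing the sizes along dim and copying the other sizes, and the sketch compares that prediction with the shape torch.cat actually returns.

```python
import torch

def cat_shape(shapes, dim):
    """Predict torch.cat's output shape for inputs of the given shapes (hypothetical helper)."""
    first = list(shapes[0])
    dim = dim % len(first)  # allow negative dims such as -1
    for s in shapes[1:]:
        # every dimension except dim must match the first input
        assert len(s) == len(first)
        assert all(s[i] == first[i] for i in range(len(first)) if i != dim)
    first[dim] = sum(s[dim] for s in shapes)  # sizes along dim add up
    return tuple(first)

tensors = [torch.zeros(2, 3, 5), torch.zeros(2, 4, 5), torch.zeros(2, 1, 5)]
predicted = cat_shape([t.shape for t in tensors], dim=1)
actual = tuple(torch.cat(tensors, dim=1).shape)
print(predicted, actual)  # (2, 8, 5) (2, 8, 5)
```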
High-dimensional example:
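import torch

a = torch.ones(4, 256, 56, 56)  # shape (N, C, H, W) = (4, 256, 56, 56)
# dtype=torch.float32 so that all three inputs share the same dtype
b = torch.arange(4 * 128 * 56 * 56, dtype=torch.float32).reshape(4, 128, 56, 56)
c = torch.zeros(4, 64, 56, 56)
d = torch.cat((a, b, c), dim=1)  # concatenate along dim 1: 256 + 128 + 64 = 448
print(d.shape)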
Output:
torch.Size([4, 448, 56, 56])
Operations like this are common in convolutional neural networks, where feature maps are concatenated along the channel dimension (dim=1).
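As a rough illustration of that use, the sketch below concatenates the input with the outputs of two parallel convolution branches along the channel dimension (dim=1 in PyTorch's N, C, H, W layout). The module name TwoBranchBlock and its channel counts are hypothetical, chosen only to mirror the shapes in the example above.

```python
import torch
import torch.nn as nn

class TwoBranchBlock(nn.Module):
    """Hypothetical block: conv branch outputs are concatenated with the input by channel."""

    def __init__(self, in_channels):
        super().__init__()
        self.branch1 = nn.Conv2d(in_channels, 128, kernel_size=1)
        self.branch2 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)

    def forward(self, x):
        # Both branches keep H and W unchanged, so only the channel counts differ
        return torch.cat((x, self.branch1(x), self.branch2(x)), dim=1)

block = TwoBranchBlock(in_channels=256)
x = torch.randn(4, 256, 56, 56)
y = block(x)
print(y.shape)  # torch.Size([4, 448, 56, 56])
```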
torch.stack()
torch.stack(tensors, dim=0, *, out=None) → Tensor
Official explanation: Concatenates a sequence of tensors along a new dimension. All tensors need to be of the same size.
In other words, it inserts a new dimension into each of the n-dimensional tensors and then concatenates them into a single (n+1)-dimensional tensor.
Parameters:
- tensors: the sequence of tensors to stack (a tuple or list)
- dim: the index at which the new dimension is inserted. It must lie between 0 and the number of dimensions of the input tensors, inclusive
- out: output tensor (as with cat(), rarely used)
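A quick sketch of the valid range of dim for two 2-D inputs: 0, 1 and 2 are accepted, and negative values count from the end of the result, so -1 is the same as 2. A value of 3 is out of range. The sketch only checks that an exception is raised, since the exact exception type and message may vary between versions.

```python
import torch

a = torch.zeros(3, 4)
b = torch.ones(3, 4)

# For 2-D inputs the result is 3-D, so dim may be 0, 1 or 2
for dim in range(a.dim() + 1):
    print(dim, torch.stack((a, b), dim=dim).shape)
# 0 torch.Size([2, 3, 4])
# 1 torch.Size([3, 2, 4])
# 2 torch.Size([3, 4, 2])

# Negative dims count from the end of the result: -1 is the same as 2
print(torch.equal(torch.stack((a, b), dim=-1), torch.stack((a, b), dim=2)))  # True

# dim=3 is beyond the allowed range and raises an error
try:
    torch.stack((a, b), dim=3)
    out_of_range_ok = True
except (IndexError, RuntimeError):
    out_of_range_ok = False
print(out_of_range_ok)  # False
```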
Notes:
① As with cat(), the input must be a sequence of tensors, not a single tensor;
② All input tensors must have exactly the same shape (this is where stack differs from cat).
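To see the difference from cat() described in note ②, here is a small sketch with two tensors whose first dimensions differ: cat() along dim 0 accepts them, while stack() rejects them because their shapes are not identical.

```python
import torch

a = torch.zeros(3, 4)
b = torch.zeros(2, 4)

# cat only needs the non-concatenated dimensions to match
print(torch.cat((a, b), dim=0).shape)  # torch.Size([5, 4])

# stack needs every dimension to match, so (3, 4) and (2, 4) fail
try:
    torch.stack((a, b), dim=0)
    stacked = True
except RuntimeError:
    stacked = False
print(stacked)  # False
```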
Example:
import torch

a = torch.arange(12, dtype=torch.float32).reshape(3, 4)  # float32, same dtype as b
b = torch.ones(12).reshape(3, 4)
c = torch.stack((a, b), dim=0)
d = torch.stack((a, b), dim=1)
e = torch.stack((a, b), dim=2)  # dim can be as large as the number of dimensions of a and b (here 2)
print(c)
print(c.shape)
print(d)
print(d.shape)
print(e)
print(e.shape)
Output:
tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]],

        [[ 1.,  1.,  1.,  1.],
         [ 1.,  1.,  1.,  1.],
         [ 1.,  1.,  1.,  1.]]])
torch.Size([2, 3, 4])
tensor([[[ 0.,  1.,  2.,  3.],
         [ 1.,  1.,  1.,  1.]],

        [[ 4.,  5.,  6.,  7.],
         [ 1.,  1.,  1.,  1.]],

        [[ 8.,  9., 10., 11.],
         [ 1.,  1.,  1.,  1.]]])
torch.Size([3, 2, 4])
tensor([[[ 0.,  1.],
         [ 1.,  1.],
         [ 2.,  1.],
         [ 3.,  1.]],

        [[ 4.,  1.],
         [ 5.,  1.],
         [ 6.,  1.],
         [ 7.,  1.]],

        [[ 8.,  1.],
         [ 9.,  1.],
         [10.,  1.],
         [11.,  1.]]])
torch.Size([3, 4, 2])
Look closely at how the shape changes in the last example: whatever value dim takes, the result has size 2 (the number of input tensors) in dimension dim. In effect, stack first inserts a new dimension at position dim in each input and then concatenates the inputs along that new dimension.
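The "insert a dimension, then concatenate" description can be checked directly: for every valid dim, stacking gives the same tensor as calling unsqueeze(dim) on each input and concatenating along dim. This sketch is meant to aid understanding; it does not claim anything about how stack is implemented internally.

```python
import torch

a = torch.arange(12, dtype=torch.float32).reshape(3, 4)
b = torch.ones(3, 4)

results = []
for dim in range(a.dim() + 1):
    stacked = torch.stack((a, b), dim=dim)
    # unsqueeze(dim) inserts a size-1 dimension at position dim in each input
    manual = torch.cat((a.unsqueeze(dim), b.unsqueeze(dim)), dim=dim)
    results.append((stacked.shape, torch.equal(stacked, manual)))
    print(dim, stacked.shape, torch.equal(stacked, manual))
# 0 torch.Size([2, 3, 4]) True
# 1 torch.Size([3, 2, 4]) True
# 2 torch.Size([3, 4, 2]) True
```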
The difference between cat and stack
torch.cat() joins tensors along an existing dimension, so the number of dimensions does not change; torch.stack() first inserts a new dimension and then concatenates along it, so the result has one more dimension than the inputs.
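A common place where the difference shows up is batching. In this sketch the data is hypothetical (random tensors standing in for three single images of shape (C, H, W) = (3, 32, 32)): stack builds a batch by adding a new leading dimension, while cat extends an existing batch or, applied to the single images, merges them into the channel dimension.

```python
import torch

# Hypothetical data: three single images with shape (C, H, W) = (3, 32, 32)
images = [torch.rand(3, 32, 32) for _ in range(3)]

# stack adds a new leading batch dimension: 3-D inputs -> 4-D output
batch = torch.stack(images, dim=0)
print(batch.shape)  # torch.Size([3, 3, 32, 32])

# cat joins along an existing dimension: 4-D inputs -> 4-D output
more = torch.rand(5, 3, 32, 32)
combined = torch.cat((batch, more), dim=0)
print(combined.shape)  # torch.Size([8, 3, 32, 32])

# cat on the 3-D images instead merges them into the channel dimension
merged = torch.cat(images, dim=0)
print(merged.shape)  # torch.Size([9, 32, 32])
```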
Official documentation
torch.cat(): https://pytorch.org/docs/stable/generated/torch.cat.html
torch.stack(): https://pytorch.org/docs/stable/generated/torch.stack.html