One-hot encoding of category tensors in PyTorch
This article is published with authorization on the 极市平台 (CVMart) official account. No second reprint is allowed without permission.
- Original document: https://www.yuque.com/lart/ugkv9f/src5w8
- Code repository: https://github.com/lartpang/CodeForArticle/tree/main/OneHotEncoding.PyTorch
Preface
One-hot encoding is very common in deep learning tasks, but it is not a natural way to store data, so in most cases we need to perform the conversion manually. Although the idea is straightforward, namely splitting the categories into one-to-one corresponding 0-1 vectors, the concrete implementation still takes some thought. In fact, PyTorch itself already provides `torch.nn.functional.one_hot` (`F.one_hot`) for exactly this purpose. But that does not stop us from thinking and practicing on our own: this article therefore collects, as comprehensively as possible, ways to implement one-hot encoding based on common PyTorch operations, hoping it will be useful.
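For reference, the built-in helper appends the class dimension at the end, which matches the `B,H,W -> B,H,W,num_classes` layout used throughout this article. A minimal usage sketch:

```python
import torch
import torch.nn.functional as F

labels = torch.randint(0, 4, size=(2, 3, 3))  # b,h,w tensor of category indices
one_hot = F.one_hot(labels, num_classes=4)    # class dimension is appended last, dtype is int64
print(one_hot.shape)  # torch.Size([2, 3, 3, 4])
```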
There are several main ways:
- for loop
- scatter
- index_select
for loop
This method is very intuitive: it simply assigns 1 to the specified positions of a blank (all-zero) tensor.
The key is how to set the index.
Below are two schemes that are essentially the same but differ slightly in which dimension holds the classes.
```python
import torch


def bhw_to_onehot_by_for(bhw_tensor: torch.Tensor, num_classes: int):
    """
    Args:
        bhw_tensor: b,h,w
        num_classes:

    Returns: b,h,w,num_classes
    """
    assert bhw_tensor.ndim == 3, bhw_tensor.shape
    assert num_classes > bhw_tensor.max(), torch.unique(bhw_tensor)

    one_hot = bhw_tensor.new_zeros(size=(num_classes, *bhw_tensor.shape))
    for i in range(num_classes):
        one_hot[i, bhw_tensor == i] = 1
    one_hot = one_hot.permute(1, 2, 3, 0)
    return one_hot


def bhw_to_onehot_by_for_V1(bhw_tensor: torch.Tensor, num_classes: int):
    """
    Args:
        bhw_tensor: b,h,w
        num_classes:

    Returns: b,h,w,num_classes
    """
    assert bhw_tensor.ndim == 3, bhw_tensor.shape
    assert num_classes > bhw_tensor.max(), torch.unique(bhw_tensor)

    one_hot = bhw_tensor.new_zeros(size=(*bhw_tensor.shape, num_classes))
    for i in range(num_classes):
        one_hot[..., i][bhw_tensor == i] = 1
    return one_hot
```
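A quick usage check of the two variants (the shapes here are just for illustration):

```python
labels = torch.randint(0, 10, size=(4, 32, 32))
out = bhw_to_onehot_by_for(labels, num_classes=10)
out_v1 = bhw_to_onehot_by_for_V1(labels, num_classes=10)
print(out.shape)                 # torch.Size([4, 32, 32, 10])
print(torch.equal(out, out_v1))  # True
```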
scatter
This is probably the most common and concise way of writing one-hot encoding found online. The main function of `scatter_` is to assign values to specified positions in a tensor.
It is more flexible because it can use a specially constructed index tensor. Of course, flexibility brings difficulty in understanding, but the explanation provided in the official documentation is very intuitive:
```python
# https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html
# Overloads:
#   (int dim, Tensor index, Tensor src)
#   (int dim, Tensor index, Tensor src, *, str reduce)
#   (int dim, Tensor index, Number value)
#   (int dim, Tensor index, Number value, *, str reduce)

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2
```
The documentation describes the in-place version and explains it in terms of a tensor `src` of replacement values. In our application we mainly rely on the in-place version with the replacement value given as a scalar floating-point `value`.

From the form above we can see that, by specifying the `index` tensor, the value of `src` at position (i,j,k) is placed at a position in the caller (`self`) obtained by replacing the coordinate at dimension `dim` of (i,j,k) with the value of `index` at (i,j,k). This also reflects a requirement on the `index` tensor: its number of dimensions must be consistent with `self` and `src` (when the scalar value 1 is used later, `src` is effectively replaced by `value`). This matches the idea of one-hot encoding perfectly, because the formal meaning of one-hot is that for data of class i, the i-th position is 1 and all other positions are 0. So it is very easy to construct a one-hot tensor by applying `scatter_` to an all-zero tensor: just place a 1 at the position corresponding to each category index.
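A tiny sketch of this rule with `dim=1` and a scalar `value` (the tensors are made up for illustration):

```python
target = torch.zeros(2, 4)
index = torch.tensor([[3], [0]])  # must have the same number of dimensions as target
target.scatter_(dim=1, index=index, value=1)
# target[i, index[i][j]] = 1, so row 0 gets its 1 at column 3, row 1 at column 0:
print(target)
# tensor([[0., 0., 0., 1.],
#         [1., 0., 0., 0.]])
```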
For our problem, the input tensor of category indices (shape B,H,W) is a natural choice for `index`. Based on this, two different strategies can be conceived:
```python
import math

import torch


def bhw_to_onehot_by_scatter(bhw_tensor: torch.Tensor, num_classes: int):
    """
    Args:
        bhw_tensor: b,h,w
        num_classes:

    Returns: b,h,w,num_classes
    """
    assert bhw_tensor.ndim == 3, bhw_tensor.shape
    assert num_classes > bhw_tensor.max(), torch.unique(bhw_tensor)

    one_hot = torch.zeros(size=(math.prod(bhw_tensor.shape), num_classes))
    one_hot.scatter_(dim=1, index=bhw_tensor.reshape(-1, 1), value=1)
    one_hot = one_hot.reshape(*bhw_tensor.shape, num_classes)
    return one_hot


def bhw_to_onehot_by_scatter_V1(bhw_tensor: torch.Tensor, num_classes: int):
    """
    Args:
        bhw_tensor: b,h,w
        num_classes:

    Returns: b,h,w,num_classes
    """
    assert bhw_tensor.ndim == 3, bhw_tensor.shape
    assert num_classes > bhw_tensor.max(), torch.unique(bhw_tensor)

    one_hot = torch.zeros(size=(*bhw_tensor.shape, num_classes))
    one_hot.scatter_(dim=-1, index=bhw_tensor[..., None], value=1)
    return one_hot
```
The difference between the two forms comes down to how the shape is handled, which leads to different ways of applying `scatter_`.

In the first form, the B, H, and W dimensions are flattened into one. The advantage is that the indexing of the channel (category) dimension becomes intuitive:
```python
one_hot = torch.zeros(size=(math.prod(bhw_tensor.shape), num_classes))
one_hot.scatter_(dim=1, index=bhw_tensor.reshape(-1, 1), value=1)
```
Here the category dimension is separated from the other dimensions and moved to the last position, and it is selected via `dim`, which gives the following correspondence:
```python
zero_tensor[abc, index[abc][d]] = value  # d == 0
```
In the second form, the first three dimensions are kept as they are, and the category dimension is again appended at the end:
```python
one_hot = torch.zeros(size=(*bhw_tensor.shape, num_classes))
one_hot.scatter_(dim=-1, index=bhw_tensor[..., None], value=1)
```
The corresponding relationship is as follows:
```python
zero_tensor[a, b, c, index[a][b][c][d]] = value  # d == 0
```
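Both variants produce the same result, which can be verified quickly (again with illustrative shapes):

```python
labels = torch.randint(0, 10, size=(4, 32, 32))
out = bhw_to_onehot_by_scatter(labels, num_classes=10)
out_v1 = bhw_to_onehot_by_scatter_V1(labels, num_classes=10)
print(torch.equal(out, out_v1))  # True
```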
In addition, a similar approach is used in the PyTorch image classification model library timm:
```python
# https://github.com/rwightman/pytorch-image-models/blob/2c33ca6d8ce5d9257edf8cab5ab7ece81780aaf7/timm/data/mixup.py#L17-L19
def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
    x = x.long().view(-1, 1)
    return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
```
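With `on_value`/`off_value` this generalizes beyond hard 0-1 codes, e.g. to the smoothed labels that timm's Mixup builds; `device='cpu'` below is only for the demo:

```python
x = torch.tensor([0, 2, 1])
smoothed = one_hot(x, num_classes=3, on_value=0.9, off_value=0.05, device='cpu')
print(smoothed)
# tensor([[0.9000, 0.0500, 0.0500],
#         [0.0500, 0.0500, 0.9000],
#         [0.0500, 0.9000, 0.0500]])
```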
index_select
```python
torch.index_select(input, dim, index, *, out=None) -> Tensor
```

- input (Tensor) – the input tensor.
- dim (int) – the dimension in which we index
- index (IntTensor or LongTensor) – the 1-D tensor containing the indices to index
As its name implies, this function uses `index` to select sub-tensors along the specified dimension of a tensor.
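For example, selecting rows of a small matrix:

```python
mat = torch.arange(12).reshape(3, 4)
rows = mat.index_select(dim=0, index=torch.tensor([2, 0, 2]))
print(rows)
# tensor([[ 8,  9, 10, 11],
#         [ 0,  1,  2,  3],
#         [ 8,  9, 10, 11]])
```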
To understand the motivation for this approach, you need to look at one-hot encoding from the perspective of the category labels.

If the original categories are numbered from small to large, the matrix formed by stacking the one-hot codes of these indices in order is exactly the identity matrix. Each category therefore corresponds to a specific row (or column) of the identity matrix. This is precisely what `index_select` handles well, so we can use it to implement one-hot encoding: we only need the category indices to select the corresponding rows (or columns). Here is an example:
```python
import torch


def bhw_to_onehot_by_index_select(bhw_tensor: torch.Tensor, num_classes: int):
    """
    Args:
        bhw_tensor: b,h,w
        num_classes:

    Returns: b,h,w,num_classes
    """
    assert bhw_tensor.ndim == 3, bhw_tensor.shape
    assert num_classes > bhw_tensor.max(), torch.unique(bhw_tensor)

    one_hot = torch.eye(num_classes).index_select(dim=0, index=bhw_tensor.reshape(-1))
    one_hot = one_hot.reshape(*bhw_tensor.shape, num_classes)
    return one_hot
```
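One caveat: `torch.eye` defaults to float32, so this version returns a float tensor, whereas `F.one_hot` returns int64; convert if you need integer codes. A quick check:

```python
import torch.nn.functional as F

labels = torch.randint(0, 10, size=(4, 32, 32))
out = bhw_to_onehot_by_index_select(labels, num_classes=10)
print(out.dtype)  # torch.float32
print(torch.equal(out.long(), F.one_hot(labels, num_classes=10)))  # True
```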
Performance comparison
The complete benchmark code is available in the GitHub repository linked at the top of this article.
The following shows the rough relative performance of the different methods (other programs were running in the background, so the numbers may not be precise; you are encouraged to run the test yourself). As can be seen, the function provided by PyTorch is not very efficient on the CPU, but it performs well on the GPU. Interestingly, the `index_select`-based form stands out on both.
```
1.10.0 GeForce RTX 2080 Ti
cpu
('bhw_to_onehot_by_for', 0.5411529541015625)
('bhw_to_onehot_by_for_V1', 0.4515676498413086)
('bhw_to_onehot_by_scatter', 0.0686192512512207)
('bhw_to_onehot_by_scatter_V1', 0.08529376983642578)
('bhw_to_onehot_by_index_select', 0.05156970024108887)
('F.one_hot', 0.07366824150085449)
gpu
('bhw_to_onehot_by_for', 0.005235433578491211)
('bhw_to_onehot_by_for_V1', 0.045584678649902344)
('bhw_to_onehot_by_scatter', 0.0025513172149658203)
('bhw_to_onehot_by_scatter_V1', 0.0024869441986083984)
('bhw_to_onehot_by_index_select', 0.002012014389038086)
('F.one_hot', 0.0024051666259765625)
```
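The exact timing script lives in the repository linked above; the sketch below only illustrates how such numbers might be collected (the tensor shape, class count, and repeat count are assumptions, not the repository's settings):

```python
import time

import torch


def time_fn(fn, labels, num_classes, repeats=10):
    if labels.is_cuda:
        torch.cuda.synchronize()  # make sure pending GPU work is done before timing
    start = time.time()
    for _ in range(repeats):
        fn(labels, num_classes)
    if labels.is_cuda:
        torch.cuda.synchronize()  # wait for the timed kernels to finish
    return (time.time() - start) / repeats


labels = torch.randint(0, 20, size=(4, 256, 256))
print(time_fn(bhw_to_onehot_by_index_select, labels, num_classes=20))
```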