PointNet++ Upsampling (Feature Propagation)

When handling a segmentation task, PointNet++ needs to restore the downsampled points to the same number of points as the input, so that a prediction can be made for every point. The paper only gives a brief description and a formula, which is not easy to follow, so I record my understanding of the process here.

1. Purpose of the FP module

PointNet++ reduces the number of sampled points layer by layer as the network deepens, which lets the network gather sufficient global information, but by itself this makes the segmentation task impossible: segmentation predicts a label for every input point, so the output must contain the same points as the input.
One way to complete the segmentation task is to always push all the points through the network without ever dropping any, but that costs a great deal of computation. Another common method is interpolation, which uses known points to reconstruct the required points.
The purpose of the FP module is to interpolate the known feature points so that the network outputs one feature per input point. The next section covers the concrete procedure.

2. How to interpolate

Consider how interpolation works in a 2D image: the value at an interpolated location is determined from the pixels adjacent to it, with nearer pixels contributing more. The same idea carries over to point clouds: a point's feature can likewise be interpolated from its neighbors, weighted by distance.


Specifically, see the author's description. Take the deep features at level N_l: assume it has 128 points, each carrying d+C values (d coordinate dimensions plus C feature channels). Its previous level, N_{l-1}, is assumed to have 512 points. N_{l-1} has more points than N_l because downsampling was performed on the way down, compressing the point set.

Now we know the features and coordinates of both level N_{l-1} and level N_l. The goal is to upsample the 128 points to 512 points. How? The answer is to interpolate by distance (pay special attention to distinguishing the roles of features and coordinates).
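For reference, this is the interpolation formula from the paper (its Eq. (2); by default the k = 3 nearest neighbors are used with power p = 2):

$$f^{(j)}(x) = \frac{\sum_{i=1}^{k} w_i(x)\, f_i^{(j)}}{\sum_{i=1}^{k} w_i(x)}, \qquad w_i(x) = \frac{1}{d(x, x_i)^p}$$

The specific steps are as follows: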

1) Using the coordinates of level N_{l-1} and level N_l, compute the distance between every pair of points across the two levels. This yields a distance matrix of size 512x128: the distance from each point in N_{l-1} (the denser, shallower level) to each point in N_l (the sparser, deeper level, whose features have higher dimension).

2) Sort each row of the distance matrix and, for every point in N_{l-1}, find the three closest points in N_l. Record their distances and indices as dist and idx. Both are now 512x3; each row of idx holds the indices (into the 128 points of N_l) of that point's three nearest neighbors.

3) Use idx to look up the features of level N_l, giving a feature matrix of size 512x3x(C+d). It may be hard to see how 128 points can be upsampled to 512 in this step. The trick is that the sampling repeats: idx is 512x3, and the three entries in each row index into the same 128 points, so for example the first row might be (2, 3, 4) and the second (3, 4, 5), with source points reused across rows.

4) The features are now upsampled, but each entry is still an unmodified copy of one of the original 128 points' features; without a transformation, the upsampling would be meaningless. So we interpolate using the recorded distances: take their reciprocals, normalize them so they sum to one, expand the weight tensor by one dimension (the feature dimension is larger than the distance dimension), and multiply it with the gathered features. This amounts to a distance-weighted average over the three neighbors; a minimal sketch of all four steps follows this list.
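To make the four steps concrete, here is a minimal, self-contained PyTorch sketch of the interpolation (toy tensors with the shapes used above; the variable names are mine, not from any repo):

import torch

# Toy tensors with the shapes used above: 512 denser points, and 128 sparser
# points carrying 256-dim features.
xyz_dense   = torch.rand(1, 512, 3)
xyz_sparse  = torch.rand(1, 128, 3)
feat_sparse = torch.rand(1, 128, 256)

# 1) pairwise squared distances: (1, 512, 128)
dists = torch.cdist(xyz_dense, xyz_sparse) ** 2

# 2) three nearest sparse neighbors per dense point: both (1, 512, 3)
dists, idx = dists.sort(dim=-1)
dists, idx = dists[:, :, :3], idx[:, :, :3]

# 3) gather the neighbors' features -> (1, 512, 3, 256); indices repeat across rows
gathered = feat_sparse[0][idx[0]].unsqueeze(0)  # simple gather for batch size 1

# 4) normalized inverse-distance weights, then a weighted sum over the neighbors
weight = 1.0 / (dists + 1e-8)
weight = weight / weight.sum(dim=2, keepdim=True)          # (1, 512, 3)
feat_dense = (gathered * weight.unsqueeze(-1)).sum(dim=2)  # (1, 512, 256)
print(feat_dense.shape)  # torch.Size([1, 512, 256])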

That is everything the FP module does. If it is still unclear, walk through the whole process once more with the annotated debug code:

import torch
import torch.nn as nn
import torch.nn.functional as F


class PointNetFeaturePropagation(nn.Module):
    # Example instantiation: in_channel=1280, mlp=[256, 256]
    def __init__(self, in_channel, mlp):
        super(PointNetFeaturePropagation, self).__init__()
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            # a stack of pointwise (kernel size 1) convolutions with batch norm
            self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm1d(out_channel))
            last_channel = out_channel

    def forward(self, xyz1, xyz2, points1, points2):
        # xyz1/points1: coordinates and features from the shallower (denser) layer;
        # xyz2/points2: coordinates and features from the deeper (sparser) layer.
        """
        Input:
            xyz1: input points position data, [B, C, N]           (e.g. the l2 layer's xyz)
            xyz2: sampled input points position data, [B, C, S]   (e.g. the l3 layer's xyz)
            points1: input points data, [B, D, N]                 (e.g. the l2 layer's features)
            points2: input points data, [B, D, S]                 (e.g. the l3 layer's features)

        Return:
            new_points: upsampled points data, [B, D', N]
        """
        # Convert [B, C, N] to [B, N, C], then interpolate from the S points of the
        # deeper layer up to the N points of the denser layer (N > S).
        # xyz1 belongs to the denser layer with N points; xyz2 to the sparser layer with S points.
        xyz1 = xyz1.permute(0, 2, 1)  # 1st call: (2,3,128) -> (2,128,3) | 2nd call: (2,3,512) -> (2,512,3)
        xyz2 = xyz2.permute(0, 2, 1)  # 1st call: (2,3,1) -> (2,1,3)     | 2nd call: (2,3,128) -> (2,128,3)

        points2 = points2.permute(0, 2, 1)  # 1st call: (2,1024,1) -> (2,1,1024); the deepest layer has
                                            # compressed everything into a single point with 1024 features.
                                            # 2nd call: (2,256,128) -> (2,128,256)
        B, N, C = xyz1.shape  # N: points in the denser layer (128 on the 1st call, 512 on the 2nd)
        _, S, _ = xyz2.shape  # S: points in the sparser layer (1 on the 1st call, 128 on the 2nd)

        if S == 1:
            # If the deeper layer has collapsed to a single point, replicate it N times
            # instead of interpolating; it is concatenated with the skip features below.
            interpolated_points = points2.repeat(1, N, 1)  # (2,128,1024) on the 1st call
        else:
            # Otherwise interpolate, enlarging e.g. 128 points -> 512 points.
            # The distance matrix is 512x128: distances between the 512 denser points
            # and the 128 sparser points.
            dists = square_distance(xyz1, xyz2)  # (2,512,128) on the 2nd call
            dists, idx = dists.sort(dim=-1)  # ascending sort along the last dim: nearest points of xyz2 first
            # Keep the three nearest neighbors. idx is (2,512,3): for each of the 512 points,
            # the indices (into the 128 points) of its three closest neighbors; e.g. the first
            # row says which three of the 128 points are closest to the first of the 512.
            dists, idx = dists[:, :, :3], idx[:, :, :3]  # [B, N, 3] = (2,512,3)

            dist_recip = 1.0 / (dists + 1e-8)  # reciprocal distances, (2,512,3); the w_i(x) of the paper
            # Sum the reciprocals over the three neighbors; keepdim=True keeps the shape (2,512,1).
            norm = torch.sum(dist_recip, dim=2, keepdim=True)  # the denominator of the paper's formula
            weight = dist_recip / norm  # (2,512,3)
            # weight is the interpolation weight: dist_recip holds the reciprocal distances of
            # the three neighbors, norm holds their sum, and dividing the two gives each
            # neighbor's share of the total.
            t = index_points(points2, idx)  # (2,512,3,256): gathered neighbor features
            interpolated_points = torch.sum(t * weight.view(B, N, 3, 1), dim=2)
            # points2 is (2,128,256) (128 points with 256 features); idx is (2,512,3), the indices
            # of the three points (out of 128) closest to each of the 512 points.
            # index_points(points2, idx) looks up those neighbors' features, giving (2,512,3,256):
            # every one of the 512 points is assembled from (possibly repeated) points of the
            # sparser layer, e.g. the first from points 1, 2, 3 of the 128 and the second from 2, 3, 4.
            # weight.view(B, N, 3, 1) has shape (2,512,3,1); elementwise multiplication broadcasts
            # each scalar weight over the neighbor's 256-dim feature vector, and torch.sum over
            # dim=2 adds the three weighted neighbors, completing the feature upsampling.

        if points1 is not None:
            points1 = points1.permute(0, 2, 1)  # (2,256,128) -> (2,128,256)
            new_points = torch.cat([points1, interpolated_points], dim=-1)  # (2,128,1280): skip + interpolated features
        else:
            new_points = interpolated_points

        new_points = new_points.permute(0, 2, 1)  # back to (B, C, N) for Conv1d
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points = F.relu(bn(conv(new_points)))  # pointwise MLP: 1x1 conv + BN + ReLU
        return new_points
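The code above calls two helpers, square_distance and index_points, that are not defined in the snippet. In the common PyTorch PointNet++ implementations this walkthrough follows, they look roughly like the sketch below; it is reproduced here only so the section is self-contained, and should be treated as an approximation of that repo's utilities:

import torch


def square_distance(src, dst):
    """Pairwise squared Euclidean distance between two point sets.
    src: [B, N, C], dst: [B, M, C] -> [B, N, M]
    Expands ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2.
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, -1).view(B, N, 1)
    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
    return dist


def index_points(points, idx):
    """Gather point features by index.
    points: [B, N, C], idx: [B, S] or [B, S, K] -> [B, S, C] or [B, S, K, C]
    """
    B = points.shape[0]
    view_shape = list(idx.shape)
    view_shape[1:] = [1] * (len(view_shape) - 1)
    repeat_shape = list(idx.shape)
    repeat_shape[0] = 1
    batch_indices = torch.arange(B, dtype=torch.long, device=points.device) \
        .view(view_shape).repeat(repeat_shape)
    return points[batch_indices, idx, :]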

We only need to focus on the upsampling part here; the convolution part will not be analyzed in detail.
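As a quick sanity check, the module can be exercised with dummy tensors shaped like the first interpolation in the comments above (a sketch assuming the class and the two helper functions are defined as shown; the shapes and in_channel value come from the walkthrough):

import torch

fp = PointNetFeaturePropagation(in_channel=1024 + 256, mlp=[256, 256])
xyz1 = torch.rand(2, 3, 128)       # l2 coordinates, [B, C, N]
xyz2 = torch.rand(2, 3, 1)         # l3 coordinates, [B, C, S]
points1 = torch.rand(2, 256, 128)  # l2 features, [B, D, N]
points2 = torch.rand(2, 1024, 1)   # l3 features, [B, D, S]

out = fp(xyz1, xyz2, points1, points2)
print(out.shape)  # torch.Size([2, 256, 128]): one 256-dim feature per upsampled point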
The overall flow matches the formula described by the author. The part worth dwelling on is the weight variable: at first I thought it never lined up with the paper's formula, but it aligns once you take it apart.

  The weight here is the interpolation weight: dist_recip stores the reciprocal distances of the three neighbors, and norm stores their sum. Dividing the two gives each neighbor's share of the total, i.e. weight.
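As a quick numeric check of that alignment (the three squared distances are toy values of my choosing):

import torch

d2 = torch.tensor([1.0, 4.0, 9.0])  # squared distances to the three neighbors
w = 1.0 / (d2 + 1e-8)               # w_i(x) = 1 / d(x, x_i)^2 from the paper
print(w / w.sum())                   # tensor([0.7347, 0.1837, 0.0816]) -- identical to dist_recip / norm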
