Quantization basics of neural networks (1) -- Building the model and implementing basic quantization functions

Preface

This article organizes the code I worked through while reading the original blog post, in order to build a clearer picture of the basic quantization workflow.

It uses PyTorch to build a basic network model by hand and trains it on the MNIST dataset. The dataset needs no introduction, so let's go straight to the training code.

Network model construction

1. Network structure

First, define the network structure as follows:

class Net(nn.Module):

    def __init__(self, num_channels=1):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, 40, 3, 1)
        self.conv2 = nn.Conv2d(40, 40, 3, 1, groups=20)  # grouped convolution tends to amplify the quantization error
        self.fc = nn.Linear(5*5*40, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 5*5*40)
        x = self.fc(x)
        return x

The resulting model, printed below, consists of two convolutional layers and one fully connected layer.

Net(
  (conv1): Conv2d(1, 40, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), groups=20)
  (fc): Linear(in_features=1000, out_features=10, bias=True)
)
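As a quick sanity check (a minimal sketch, assuming the standard 28x28 MNIST input), the spatial size shrinks as 28 → 26 (conv1, 3x3) → 13 (max pool) → 11 (conv2, 3x3) → 5 (max pool), so the flattened feature size is 5*5*40 = 1000, which matches in_features of the fc layer:

import torch

net = Net()
dummy = torch.randn(1, 1, 28, 28)  # one fake MNIST image
print(net(dummy).shape)            # expected: torch.Size([1, 10])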

2. Network training

from model import *

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import os
import os.path as osp


def train(model, device, train_loader, optimizer, epoch):
    model.train()
    lossLayer = torch.nn.CrossEntropyLoss()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = lossLayer(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % 50 == 0:
            print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
            ))

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    lossLayer = torch.nn.CrossEntropyLoss(reduction='sum')
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        test_loss += lossLayer(output, target).item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
    
    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {:.0f}%\n'.format(
        test_loss, 100. * correct / len(test_loader.dataset)
    ))


if __name__ == "__main__":
    batch_size = 64
    test_batch_size = 64
    seed = 1
    epochs = 15
    lr = 0.01
    momentum = 0.5
    save_model = True
    using_bn = False

    torch.manual_seed(seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True, 
                       transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
    )

    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=test_batch_size, shuffle=True, num_workers=1, pin_memory=True
    )

    if using_bn:
        model = NetBN().to(device)  # network model with BN
    else:
        model = Net().to(device)    # network model without BN

    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
    
    if save_model:
        if not osp.exists('ckpt'):
            os.makedirs('ckpt')
        if using_bn:
            torch.save(model.state_dict(), 'ckpt/mnist_cnnbn.pt')
        else:
            torch.save(model.state_dict(), 'ckpt/mnist_cnn.pt')

After training, you will get a weight file: ckpt/mnist_cnn.pt (or ckpt/mnist_cnnbn.pt when using_bn is True).

3. Test the trained model

First, we load the weights and measure the accuracy of the full-precision model.

def full_inference(model, test_loader):
    correct = 0
    for i, (data, target) in enumerate(test_loader, 1):
        output = model(data)
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
    print('\nTest set: Full Model Accuracy: {:.0f}%\n'.format(100. * correct / len(test_loader.dataset)))


if __name__ == "__main__":
    batch_size = 64
    using_bn = False

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True, 
                       transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
    )

    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=batch_size, shuffle=False, num_workers=1, pin_memory=True
    )

    if using_bn:
        model = NetBN()
        model.load_state_dict(torch.load('ckpt/mnist_cnnbn.pt', map_location='cpu'))
    else:
        model = Net()
        model.load_state_dict(torch.load('ckpt/mnist_cnn.pt', map_location='cpu'))
    model.eval()
    full_inference(model, test_loader)

The next step is to quantize the model. This article uses post-training quantization.

4. Basic quantization functions

4.1 Calculating the scale factor and zero point

First, we need to implement the basic quantization formulas, which compute the scale and the zero point. Reconstructed from the code below, they are:

    S = (max - min) / (qmax - qmin)              (Formula 1: scale and zero point)
    Z = qmax - max / S

    q = round(clamp(x / S + Z, qmin, qmax))      (Formula 2: quantization)
    x ≈ S * (q - Z)                              (dequantization)

The code is as follows:

"""
Calculate quantization scale and zero

:min_val: minimum value
:scale: Maximum
:num_bits: Quantized bit position
:returns: Quantitative scale; zero point
"""
def calcScaleZeroPoint(min_val, max_val, num_bits=8):
    qmin = 0.
    qmax = 2. ** num_bits - 1.
    scale = float((max_val - min_val) / (qmax - qmin))

    zero_point = qmax - max_val / scale

    if zero_point < qmin:
        zero_point = qmin
    elif zero_point > qmax:
        zero_point = qmax
    
    zero_point = int(zero_point)

    return scale, zero_point
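For intuition, a quick worked example with arbitrarily chosen values: mapping the float range [-1, 1] onto the 8-bit range [0, 255] gives scale = 2 / 255 ≈ 0.00784 and zero_point = 255 - 1 / scale ≈ 127.

scale, zero_point = calcScaleZeroPoint(-1.0, 1.0, num_bits=8)
print(scale, zero_point)  # approximately 0.00784 and 127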
    
"""
This is right tensor Quantify

:x: Input to be quantified tensor
:scale: Quantitative scale
:zero_point: zero point
:num_bits: Quantized bit position
:signed: Is quantization signed or not
:returns: Tensor after quantization
"""

def quantize_tensor(x, scale, zero_point, num_bits=8, signed=False):
    if signed:
        qmin = - 2. ** (num_bits - 1)
        qmax = 2. ** (num_bits - 1) - 1
    else:
        qmin = 0.
        qmax = 2.**num_bits - 1.
 
    q_x = zero_point + x / scale      # map the float values into the quantized domain
    q_x.clamp_(qmin, qmax).round_()   # clamp to [qmin, qmax], then round to the nearest integer
    
    return q_x                        #Quantized value

"""
This is right tensor Inverse quantization

:x: Quantized value
:scale: Quantitative scale
:zero_point: zero point
:returns: Output tensor after inverse quantization
"""
def dequantize_tensor(q_x, scale, zero_point):
    return scale * (q_x - zero_point)
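A short round-trip sketch (arbitrary values, using the three functions above) shows the error that quantization followed by dequantization introduces; it is roughly on the order of the scale:

import torch

x = torch.tensor([-0.9, 0.0, 0.42, 0.99])
scale, zero_point = calcScaleZeroPoint(x.min(), x.max(), num_bits=8)
q_x = quantize_tensor(x, scale, zero_point)
x_hat = dequantize_tensor(q_x, scale, zero_point)
print((x - x_hat).abs().max())  # small round-trip error, bounded by roughly the scale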

4.2 Tracking the maximum and minimum values

As the formulas above show, computing the scale factor and the zero point requires the minimum and maximum of the values being quantized, along with the number of quantization bits. During post-training quantization we therefore need to collect the min and max of the sample inputs and of the intermediate layers, and we also invoke quantization and dequantization operations frequently. It is convenient to wrap all of this into a QParam class. The code is as follows:

"""
QParam  In the process of post training quantification,It is necessary to make statistics on the samples and the of the middle layer first min, max,It is also frequently involved
        To some quantification,Inverse quantization operation,So we can package all these functions into one QParam class

:num_bits: Bit
:scale: Quantitative scale
:min: minimum value
:max: Maximum
"""
class QParam:

    def __init__(self, num_bits=8):
        self.num_bits = num_bits
        self.scale = None
        self.zero_point = None
        self.min = None    # running minimum of the observed values
        self.max = None    # running maximum of the observed values

    """
    update Function is used for statistics min,max
    """
    def update(self, tensor):

        if self.max is None or self.max < tensor.max():
            self.max = tensor.max()
        self.max = 0 if self.max < 0 else self.max   # keep max >= 0 so that zero stays representable

        if self.min is None or self.min > tensor.min():
            self.min = tensor.min()
        self.min = 0 if self.min > 0 else self.min   # keep min <= 0 so that zero stays representable

        # recompute the scale and zero point from the updated min/max
        self.scale, self.zero_point = calcScaleZeroPoint(self.min, self.max, self.num_bits)

    def quantize_tensor(self, tensor):
        return quantize_tensor(tensor, self.scale, self.zero_point, num_bits=self.num_bits)

    def dequantize_tensor(self, q_x):
        return dequantize_tensor(q_x, self.scale, self.zero_point)

The class above implements quantization of data. The main steps are: first, update() tracks the minimum and maximum of the input tensor and calls calcScaleZeroPoint() to compute the scale factor and zero point (the concrete implementation of Formula 1); then quantize_tensor() quantizes the values (the concrete implementation of Formula 2). QParam also provides the inverse operation, dequantize_tensor().
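As a usage sketch (assuming the QParam class above), collecting statistics over a few tensors and then quantizing new data looks like this:

import torch

qp = QParam(num_bits=8)
for _ in range(3):                  # simulate observing several batches
    qp.update(torch.randn(4, 10))   # the running min/max only ever widen
q = qp.quantize_tensor(torch.randn(4, 10))
x_hat = qp.dequantize_tensor(q)     # approximate reconstruction of the input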

Note that except for the first conv layer, which needs the min and max of the network input, all other layers only need to track the min and max of their intermediate output features. Moreover, operations such as ReLU and max pooling simply inherit the min and max of the previous layer's output, so no extra statistics are needed for them; in other words, these layers share the same min and max as the layer that feeds them.

The following is the framework of the quantization module.

5. Network quantization module

We need to define some basic quantization modules covering the convolution, fully connected, ReLU, and pooling layers. First, we define a quantization base class, which reduces duplicated code and makes the structure of the code clearer.

class QModule(nn.Module):

    def __init__(self, qi=True, qo=True, num_bits=8):
        super(QModule, self).__init__()
        if qi:
            self.qi = QParam(num_bits=num_bits)  # quantization parameters of the input
        if qo:
            self.qo = QParam(num_bits=num_bits)  # quantization parameters of the output

    """
    freeze  The second is freeze function,This function is mainly the calculation formula (4) in M,q_w,q_b
    """
    def freeze(self):
        pass

    def quantize_inference(self, x):
        raise NotImplementedError('quantize_inference should be implemented.')

This base class defines the methods to be provided by each quantization module.

First, the __init__ function specifies not only the number of quantization bits but also whether the module maintains quantization parameters for its input (qi) and for its output (qo). As mentioned earlier, not every layer needs to track the input min and max; most middle layers use the qo of the previous layer as their own qi.

The second is the freeze function, which is called after the min and max statistics have been collected. As mentioned above, many terms in formula (4) can be computed in advance (for example, in the standard quantized-inference formulation the multiplier M = S_input * S_weight / S_output depends only on the scales); freeze fixes these terms ahead of time and converts the network's weights from floating-point real numbers to fixed-point integers.
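To make the base class concrete, here is a minimal sketch of what a ReLU subclass might look like (an illustration under the assumptions above, not the blog's exact implementation). Since ReLU inherits the quantization parameters of the previous layer, it creates no qi or qo of its own; and because the zero point is the quantized representation of the real value 0, ReLU in the quantized domain is simply a clamp at the zero point:

class QReLU(QModule):

    def __init__(self, qi=False, num_bits=None):
        # qi=False / qo=False: no QParam of its own; qi is adopted in freeze()
        super(QReLU, self).__init__(qi=qi, qo=False, num_bits=num_bits)

    def freeze(self, qi=None):
        if qi is not None:
            self.qi = qi  # adopt the previous layer's quantization parameters

    def quantize_inference(self, x):
        # relu(x) = max(x, 0) in the real domain becomes max(q, zero_point)
        # in the quantized domain
        return x.clamp(min=self.qi.zero_point)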
