Generating MNIST digits with a GAN

1. What is a GAN

A GAN (Generative Adversarial Network) is a deep learning model in which (at least) two modules in the framework learn by playing a game against each other: the generative model G and the discriminative model D. In the original GAN theory, G and D are not required to be neural networks; any functions that can fit the corresponding generation and discrimination mappings will do. In practice, however, deep neural networks are generally used as G and D.

GAN paper: https://arxiv.org/abs/1406.2661

2. Principles of GANs

@Basic principles

A GAN consists of a Discriminator (D for short) and a Generator (G for short). Put simply, G and D are two multilayer perceptrons or convolutional neural networks, and the basic idea is the adversarial game played between them.

G is a network that generates pictures. It receives random noise z and generates a picture from it, denoted G(z).

D is a discrimination network that judges whether a picture is real or fake. Given a real picture, D should output 1; given a fake (generated) picture, it should output 0. In practice D outputs a probability between 0 and 1 that its input is real.

During training, the goal of the generator G is to produce pictures realistic enough to fool the discriminator D into believing they are real; the goal of D is to tell G's pictures apart from real ones as reliably as possible. G and D thus form a dynamic game.

So what is the final result of this game? Ideally, G generates pictures G(z) that pass for real ones, and D can no longer tell whether a picture produced by G is real or fake. Therefore D(G(z)) = 0.5.
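For reference, the original GAN paper linked above writes this game as a two-player minimax problem over a value function V(D, G). Here p_data is the distribution of real images and p_z is the noise distribution. For a fixed generator whose outputs follow a distribution p_g, the optimal discriminator is the ratio in the second line. When p_g = p_data, that ratio equals 1/2 everywhere, which is exactly the D(G(z)) = 0.5 equilibrium described above.

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]

D^*_G(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)}
\quad\Longrightarrow\quad
D^*_G(x) = \tfrac{1}{2} \ \text{ when } p_g = p_{\text{data}}
```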

The specific process is shown in the figure.

First, a first-generation G produces some poor pictures. A first-generation D then learns to separate these fakes from the real pictures accurately and labels them 0. In fact, D is just a binary classifier: it outputs 0 for generated images and 1 for real images.

After training, a second-generation G appears. It produces slightly better pictures, good enough to make the first-generation D believe they are real. A second-generation D is then trained that can again tell which pictures come from G and which are real.

This continues through the third, fourth, ..., n-th generations of G (generator) and D (discriminator). Eventually D can no longer distinguish generated pictures from real ones, and at that point training has converged.
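It helps to know what the losses look like at this end point. The following is a small arithmetic sketch that assumes the ideal case, in which D outputs exactly 0.5 for every input. It uses the same binary cross-entropy (BCE) as the training code in section 3, where the discriminator loss is the sum of the real-image term and the fake-image term. The helper name `bce` is hypothetical. Real training rarely sits exactly at this point, so printed losses near these values are only a rough sign of balance.

```python
import math

def bce(p: float, y: float) -> float:
    """Binary cross-entropy for one prediction p in (0, 1) and label y in {0, 1}."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

# Ideal equilibrium: D outputs 0.5 for every input, real or generated.
d_loss_real = bce(0.5, 1.0)   # real image, label 1
d_loss_fake = bce(0.5, 0.0)   # generated image, label 0
d_loss = d_loss_real + d_loss_fake  # summed, as in the training loop
g_loss = bce(0.5, 1.0)        # the generator wants its fakes labelled 1

print(f"d_loss = {d_loss:.4f}, g_loss = {g_loss:.4f}")  # d_loss = 1.3863, g_loss = 0.6931
```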

What are these two networks?

@Discriminator Network

Let's start with the discriminator network, because it is the simpler of the two.

The discriminator simply judges whether an input is real or fake, which is a binary classification problem. For a real picture we want it to output 1, and for a fake picture we want it to output 0. This has nothing to do with the class of the original picture. Whatever digit a real picture shows, it is labeled real (1), and every generated picture is labeled fake (0).

When training D, we want it to distinguish real pictures from fake ones accurately. Many models can solve this binary classification problem, such as logistic regression, deep fully connected networks, convolutional neural networks and recurrent neural networks.
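As a concrete illustration of the binary cross-entropy used to train D, the following sketch feeds four made-up discriminator outputs to `nn.BCELoss` and checks them against the formula written out by hand. The probabilities and labels are invented for illustration only. The labels depend only on real versus generated, not on the digit shown. They have shape (N, 1), which matches the discriminator's output.

```python
import torch
import torch.nn as nn

criterion = nn.BCELoss()  # averages over the batch by default (reduction="mean")

# Hypothetical discriminator outputs for 4 images, shape (N, 1) like D's output
probs = torch.tensor([[0.9], [0.8], [0.3], [0.1]])
# First two images are real (label 1), last two are generated (label 0)
labels = torch.tensor([[1.0], [1.0], [0.0], [0.0]])

loss = criterion(probs, labels)

# The same quantity by hand: -mean(y * log(p) + (1 - y) * log(1 - p))
manual = -(labels * torch.log(probs) + (1 - labels) * torch.log(1 - probs)).mean()
print(loss.item(), manual.item())
```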

import torch.nn as nn

# Discriminator network
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.dis = nn.Sequential(
                 nn.Linear(784, 256),
                 nn.LeakyReLU(0.2),
                 nn.Linear(256, 256),
                 nn.LeakyReLU(0.2),
                 nn.Linear(256, 1),
                 # Sigmoid squashes the output to a probability between 0 and 1 for binary classification
                 nn.Sigmoid())
 
    def forward(self, x):
        x = self.dis(x)
        return x
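A quick way to see what this network returns is to pass it a batch of random 784-dimensional vectors. The following sketch repeats the class so that it runs on its own. The random input serves only to check the output shape. The output has shape (N, 1) with values in (0, 1), so the labels passed to `nn.BCELoss` should also have shape (N, 1).

```python
import torch
import torch.nn as nn

class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.dis = nn.Sequential(
            nn.Linear(784, 256), nn.LeakyReLU(0.2),
            nn.Linear(256, 256), nn.LeakyReLU(0.2),
            nn.Linear(256, 1), nn.Sigmoid())

    def forward(self, x):
        return self.dis(x)

D = discriminator()
x = torch.randn(16, 784)  # 16 flattened 28x28 "images" (random values, shape check only)
out = D(x)
print(out.shape)  # torch.Size([16, 1])
```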

@Generative Network

How can I generate a fake picture?

First, we sample a simple high-dimensional noise vector from a normal distribution, such as the D-dimensional noise vector shown in the figure above. An affine transformation (xW + b) maps it to a higher dimension, and the result is rearranged into a rectangle so that it looks more like a picture. It then passes through a series of convolution, pooling and activation operations, which finally produce an output of exactly the same size as the input images. (The generator used in this article is simpler and consists only of fully connected layers.) How is the generator trained? Its training signal comes from the discriminator: we keep increasing the probability that the discriminator classifies generated pictures as real. In this step the discriminator's parameters are not updated; only the generator's parameters are.

# Generator network
class generator(nn.Module):
    def __init__(self, input_size):
        super(generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.ReLU(True),
            nn.Linear(256, 256),
            nn.ReLU(True),
            nn.Linear(256, 784),
            # Tanh keeps the generated pixel values between -1 and 1
            nn.Tanh()
        )
 
    def forward(self, x):
        x = self.gen(x)
        return x
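The paragraph before the generator code describes a convolutional design (affine projection, reshape into a rectangle, then convolutions), whereas the generator above uses only fully connected layers. The following is a minimal sketch of the convolutional variant in a DCGAN-style layout. The class name `ConvGenerator` and all layer sizes are illustrative assumptions, not part of the original article. The sketch upsamples with `nn.ConvTranspose2d` rather than pooling, because pooling would shrink the feature maps instead of growing them to 28x28.

```python
import torch
import torch.nn as nn

class ConvGenerator(nn.Module):
    """Hypothetical convolutional generator; not the network trained in this article."""
    def __init__(self, z_dim: int = 100):
        super().__init__()
        # Affine map z -> 128*7*7 values, i.e. xW + b into a higher-dimensional space
        self.fc = nn.Sequential(nn.Linear(z_dim, 128 * 7 * 7), nn.ReLU(True))
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 7x7 -> 14x14
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),    # 14x14 -> 28x28
            nn.Tanh(),  # pixel values in [-1, 1], like the fully connected generator
        )

    def forward(self, z):
        x = self.fc(z)
        x = x.view(-1, 128, 7, 7)  # rearrange the vector into 128 feature maps of size 7x7
        return self.deconv(x)

G = ConvGenerator()
imgs = G(torch.randn(8, 100))
print(imgs.shape)  # torch.Size([8, 1, 28, 28])
```

To pair it with the fully connected discriminator, flatten its output with `imgs.view(imgs.size(0), -1)`.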

3. Training

@Discriminator training

Training the discriminator has two parts: judging real pictures as real and judging fake pictures as fake. The generator's parameters are not updated in either part.

First we define the loss function and the optimizers. The loss is binary cross-entropy (nn.BCELoss), and both networks are optimized with Adam at a learning rate of 0.0003.

criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)
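The snippets in this section call `.cuda()` directly, which fails on a machine without a GPU. The following is a hedged sketch of a device-agnostic version of the same setup. The small `nn.Sequential` stand-ins have the same input and output sizes as the article's networks but are simplified assumptions. Any tensors created later, such as labels and noise, would then need `.to(device)` or `device=device` instead of `.cuda()`.

```python
import torch
import torch.nn as nn

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Simplified stand-ins with the article's sizes: D maps 784 -> 1, G maps 100 -> 784
D = nn.Sequential(nn.Linear(784, 256), nn.LeakyReLU(0.2), nn.Linear(256, 1), nn.Sigmoid()).to(device)
G = nn.Sequential(nn.Linear(100, 256), nn.ReLU(True), nn.Linear(256, 784), nn.Tanh()).to(device)

criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)
```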

Then we enter the training loop:

img = img.view(num_img, -1)                  # Flatten each 28x28 image into a 784-dim vector
real_img = img.cuda()                        # Move the batch to the GPU
real_label = torch.ones(num_img, 1).cuda()   # Real label is 1; shape (N, 1) matches D's output
fake_label = torch.zeros(num_img, 1).cuda()  # Fake label is 0

# Loss on real images
real_out = D(real_img)  # Feed the real images to the discriminator
d_loss_real = criterion(real_out, real_label)  # Loss on the real images
real_scores = real_out  # D's scores for real images; the closer to 1, the better

# Loss on fake images
z = torch.randn(num_img, z_dimension).cuda()  # Sample random noise
fake_img = G(z).detach()  # Generate fake images; detach so this step does not backpropagate into G
fake_out = D(fake_img)    # The discriminator scores the fake images
d_loss_fake = criterion(fake_out, fake_label)  # Loss on the fake images
fake_scores = fake_out    # D's scores for fake images; the closer to 0, the better

# Backpropagate and optimize
d_loss = d_loss_real + d_loss_fake  # Total discriminator loss
d_optimizer.zero_grad()  # Zero the gradients
d_loss.backward()        # Backpropagate
d_optimizer.step()       # Update D's parameters
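The following is a self-contained sketch of one discriminator update. It checks that the generator really stays out of this step when its output is detached. The function name `discriminator_step`, the small `nn.Sequential` stand-ins and the random "real" batch are assumptions made for illustration. After the step, every parameter of D has a gradient, while G's gradients are still `None`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
z_dimension, num_img = 100, 32

# Simplified stand-ins with the article's input/output sizes
D = nn.Sequential(nn.Linear(784, 256), nn.LeakyReLU(0.2), nn.Linear(256, 1), nn.Sigmoid())
G = nn.Sequential(nn.Linear(z_dimension, 256), nn.ReLU(True), nn.Linear(256, 784), nn.Tanh())
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)

def discriminator_step(real_img):
    n = real_img.size(0)
    real_label = torch.ones(n, 1)    # same shape as D's output
    fake_label = torch.zeros(n, 1)
    z = torch.randn(n, z_dimension)
    fake_img = G(z).detach()         # cut the graph: no gradients flow back into G
    d_loss = criterion(D(real_img), real_label) + criterion(D(fake_img), fake_label)
    d_optimizer.zero_grad()
    d_loss.backward()
    d_optimizer.step()
    return d_loss.item()

# Random stand-in for a batch of flattened real images scaled to [-1, 1]
real_img = torch.rand(num_img, 784) * 2 - 1
d_loss = discriminator_step(real_img)
print(f"d_loss = {d_loss:.4f}")
```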

@Generator training

When training the generator, we generate fake pictures but want the discriminator to classify them as real. We keep the discriminator fixed and compute the loss between its output on the fake pictures and the real label. Backpropagation then updates only the generator's parameters. This pushes the generator toward pictures that the discriminator judges to be real, which is what makes the two networks adversarial.

# Loss of the fake images, scored against the real label
z = torch.randn(num_img, z_dimension).cuda()  # Sample random noise
fake_img = G(z)       # Generate fake images (no detach here: gradients must reach G)
output = D(fake_img)  # Score them with the discriminator
g_loss = criterion(output, real_label)  # Loss between D's output on the fakes and the real label

# Backpropagate and optimize
g_optimizer.zero_grad()  # Zero the gradients
g_loss.backward()        # Backpropagate through D into G
g_optimizer.step()       # Update only the generator's parameters
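Scoring the fakes against the real label means the generator minimizes -log D(G(z)). The original GAN paper suggests this form in place of minimizing log(1 - D(G(z))), because it gives stronger gradients early in training. The following self-contained sketch, built on the same kind of simplified stand-in networks, shows that the discriminator stays fixed during this step. Gradients do flow through D and fill its `.grad` fields, but only G's parameters are registered with `g_optimizer`, so only G changes. Leftover gradients in D are why the loop calls `d_optimizer.zero_grad()` before each discriminator update.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
z_dimension, num_img = 100, 32

# Simplified stand-ins with the article's input/output sizes
D = nn.Sequential(nn.Linear(784, 256), nn.LeakyReLU(0.2), nn.Linear(256, 1), nn.Sigmoid())
G = nn.Sequential(nn.Linear(z_dimension, 256), nn.ReLU(True), nn.Linear(256, 784), nn.Tanh())
criterion = nn.BCELoss()
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)
real_label = torch.ones(num_img, 1)

# Snapshot the weights so we can see which network actually changed
d_before = [p.detach().clone() for p in D.parameters()]
g_before = [p.detach().clone() for p in G.parameters()]

z = torch.randn(num_img, z_dimension)
g_loss = criterion(D(G(z)), real_label)  # mean of -log D(G(z))
g_optimizer.zero_grad()
g_loss.backward()   # gradients flow through D into G; D's .grad fields are filled too
g_optimizer.step()  # only G's parameters are updated

d_changed = any(not torch.equal(a, b) for a, b in zip(d_before, D.parameters()))
g_changed = any(not torch.equal(a, b) for a, b in zip(g_before, G.parameters()))
print(d_changed, g_changed)  # D is untouched, G has moved
```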

4. Complete code (PyTorch implementation)

The complete program:

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import os

if not os.path.exists('img'):
    os.mkdir('img')

def to_img(x):
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)
    out = out.view(-1, 1, 28, 28)
    return out

# Hyperparameters
batch_size = 128
num_epoch = 50
z_dimension = 100
# Preprocessing: ToTensor() gives values in [0, 1]; Normalize((0.5,), (0.5,)) maps them to [-1, 1],
# the same range as the generator's Tanh output and the range to_img() expects
img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
# Download the dataset (skipped if it is already in mnist_data/)
mnist = datasets.MNIST(
    root='mnist_data', train=True, transform=img_transform, download=True)
# Load the dataset
dataloader = torch.utils.data.DataLoader(
    dataset=mnist, batch_size=batch_size, shuffle=True)

# Discriminant network
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.dis = nn.Sequential(
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid())  # Sigmoid squashes the output to a probability in (0, 1) for binary classification

    def forward(self, x):
        x = self.dis(x)
        return x


# generator 
class generator(nn.Module):
    def __init__(self):
        super(generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(True),
            nn.Linear(256, 256),
            nn.ReLU(True),
            nn.Linear(256, 784),
            nn.Tanh())  # Tanh keeps the generated pixel values between -1 and 1

    def forward(self, x):
        x = self.gen(x)
        return x


# Use the GPU if available; models and data must live on the same device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
D = discriminator().to(device)
G = generator().to(device)
# The discriminator is trained in two parts: judging real images as real and fake images as fake.
# The generator's parameters are not updated in either part.
# Binary cross-entropy loss and optimizers
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)
# Start training
for epoch in range(num_epoch):
    for i, (img, _) in enumerate(dataloader):
        num_img = img.size(0)
        # ================================ Train the discriminator ================================
        img = img.view(num_img, -1)  # Flatten each 28x28 image into a 784-dim vector
        real_img = img.to(device)
        real_label = torch.ones(num_img, 1, device=device)   # Real label is 1; shape (N, 1) like D's output
        fake_label = torch.zeros(num_img, 1, device=device)  # Fake label is 0

        # Loss on real images
        real_out = D(real_img)  # Feed the real images to the discriminator
        d_loss_real = criterion(real_out, real_label)  # Loss on the real images
        real_scores = real_out  # The closer to 1, the better

        # Loss on fake images
        z = torch.randn(num_img, z_dimension, device=device)  # Sample random noise
        fake_img = G(z).detach()  # Generate fake images; detach so G gets no gradients here
        fake_out = D(fake_img)    # The discriminator scores the fake images
        d_loss_fake = criterion(fake_out, fake_label)  # Loss on the fake images
        fake_scores = fake_out    # The closer to 0, the better

        # Backpropagate and optimize
        d_loss = d_loss_real + d_loss_fake  # Total discriminator loss
        d_optimizer.zero_grad()  # Zero the gradients
        d_loss.backward()        # Backpropagate
        d_optimizer.step()       # Update D's parameters

        # ================================= Train the generator =================================
        z = torch.randn(num_img, z_dimension, device=device)  # Sample random noise
        fake_img = G(z)       # Generate fake images
        output = D(fake_img)  # Score them with the discriminator
        g_loss = criterion(output, real_label)  # Fakes are scored against the real label

        # Backpropagate and optimize (only G's parameters are in g_optimizer)
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f}, D real: {:.6f}, D fake: {:.6f}'.format(
                epoch + 1, num_epoch, d_loss.item(), g_loss.item(),
                real_scores.mean().item(), fake_scores.mean().item()))

    if epoch == 0:
        real_images = to_img(real_img.cpu().data)
        save_image(real_images, 'img/real_images.png')

    fake_images = to_img(fake_img.cpu().data)
    save_image(fake_images, 'img/fake_images-{}.png'.format(epoch + 1))
torch.save(G.state_dict(), 'generator.pth')
torch.save(D.state_dict(), 'discriminator.pth')

5. Result

Result display:

As the number of epochs increases, the generated images contain less noise, training becomes more stable, and the digits gradually go from blurry to clear. By epoch 49 the generated pictures look just like real ones.


Added by papa on Sat, 01 Jan 2022 18:45:54 +0200