CGAN implementation process

This article trains on the MNIST dataset. It uses diagrams to show how the inputs of a CGAN differ from those of a plain GAN, to help explain how a CGAN works.

1, Principle

As shown in the figure below, we feed the generator an extra condition c along with the noise z. The generator G takes z and c together and produces the generated image.
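For reference, this conditioning is usually written as the standard GAN minimax objective with every term conditioned on c (the original CGAN paper by Mirza and Osindero calls the condition y). Both G and D receive the condition:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x \mid c)\big]
  + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z \mid c) \mid c)\big)\big]
```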

2, Parameter initialization

With the principle above in mind, we can initialize the parameters. Roughly, we need the noise z, the condition c, the real image x, and the initial parameters of the generator and the discriminator. Concretely:

  • Input of G: z_ and y_vec_
  • Input of D: x and y_fill_
  • Initialization of model parameters
  • Test noise sample_z_ and the corresponding labels sample_y_

The noise dimension used here is z_dim = 62. There are of course other things to initialize, such as the optimizers, but since this article focuses on how the model is implemented, it only covers the variables.

1. Input of G

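  • Input noise z: z_: (64, 62)
  • Input condition c: y_vec_: (64, 10)

Final input of G: z and c concatenated along dimension 1 (the feature dimension), giving (64, 72). In the printout below, the last ten columns of each row hold the one-hot condition.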

G:
torch.Size([64, 72])
tensor([[0.8920, 0.9742, 0.6876,  ..., 0.0000, 0.0000, 0.0000],
        [0.5271, 0.6423, 0.7480,  ..., 0.0000, 1.0000, 0.0000],
        [0.9545, 0.6324, 0.9603,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.1931, 0.7773, 0.8154,  ..., 0.0000, 0.0000, 0.0000],
        [0.0049, 0.7129, 0.3272,  ..., 0.0000, 0.0000, 0.0000],
        [0.2902, 0.1194, 0.0020,  ..., 0.0000, 1.0000, 0.0000]])
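A minimal sketch of how this (64, 72) input can be assembled. It assumes uniform noise from torch.rand (which matches the printed values) and a random stand-in batch of labels; in training the labels come from the MNIST loader. The names labels and g_input are hypothetical. F.one_hot gives the same one-hot matrix as the scatter_ call used later in this article.

```python
import torch
import torch.nn.functional as F

batch_size, z_dim, class_num = 64, 62, 10

# Noise z_: uniform in [0, 1) (assumption, matching the printout above)
z_ = torch.rand(batch_size, z_dim)

# Hypothetical stand-in labels; in training these come from the data loader
labels = torch.randint(0, class_num, (batch_size,))

# Condition c as a one-hot vector: (64, 10)
y_vec_ = F.one_hot(labels, num_classes=class_num).float()

# Concatenate along the feature dimension: (64, 62) + (64, 10) -> (64, 72)
g_input = torch.cat([z_, y_vec_], dim=1)
print(g_input.shape)  # torch.Size([64, 72])
```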

2. Input of D

  • Input real data: x: (64, 1, 28, 28)
  • Input generation data: G(z): (64, 1, 28, 28)
  • Input condition: c: y_fill_: (64, 10, 28, 28)

Final input of D: the image (x or G(z)) and c concatenated along dimension 1 (the channel dimension), giving (64, 11, 28, 28). In other words, for each sample in the batch, the (1, 28, 28) image becomes channel 0 of an (11, 28, 28) tensor, and the remaining ten channels encode the label: if the label is 0, channel 1 is all ones and the other label channels are all zeros; if the label is 1, channel 2 is all ones and the rest are all zeros; and so on.

D:
torch.Size([64, 11, 28, 28])
tensor([[[[ 0.1099, -0.5590,  0.9668,  ...,  3.0843,  0.6788, -0.4171],
          [ 0.8949, -0.3523, -0.4086,  ..., -0.8257, -2.1445,  1.0512],
          [ 1.5333, -0.0918, -1.1146,  ..., -1.1746, -0.4689,  0.3702],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          ...,
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
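A minimal sketch of building y_fill_ from the one-hot labels and concatenating it with the image. The random image tensor x and the labels are hypothetical stand-ins for a real MNIST batch (or for G(z)); expand broadcasts each one-hot entry over a full 28x28 plane without copying until torch.cat.

```python
import torch
import torch.nn.functional as F

batch_size, class_num, img_size = 64, 10, 28

# Stand-in images: in training this is either a real batch x or G(z)
x = torch.randn(batch_size, 1, img_size, img_size)
labels = torch.randint(0, class_num, (batch_size,))
y_vec_ = F.one_hot(labels, num_classes=class_num).float()

# Spread each one-hot entry over a 28x28 plane: (64, 10) -> (64, 10, 28, 28)
y_fill_ = y_vec_[:, :, None, None].expand(-1, -1, img_size, img_size)

# Concatenate along channels: (64, 1, 28, 28) + (64, 10, 28, 28) -> (64, 11, 28, 28)
d_input = torch.cat([x, y_fill_], dim=1)
print(d_input.shape)  # torch.Size([64, 11, 28, 28])

# For sample 0 with label k, channel k + 1 is all ones
k = labels[0].item()
assert torch.all(d_input[0, k + 1] == 1)
```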

3. Initialization of model parameters

import torch.nn as nn

def initialize_weights(net):
    # DCGAN-style init: weights drawn from N(0, 0.02), biases set to zero
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            m.weight.data.normal_(0, 0.02)
            # Layers created with bias=False have m.bias == None, so guard the reset
            if m.bias is not None:
                m.bias.data.zero_()
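The same initialization can also be written with nn.Module.apply and the torch.nn.init helpers, which avoid touching .data directly. A minimal sketch; the function name init_weights_ and the toy network are hypothetical, only there to show the call:

```python
import torch
import torch.nn as nn

def init_weights_(m: nn.Module) -> None:
    # Called once per submodule by Module.apply
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

# Hypothetical toy network: (72 -> 1024 -> 784), roughly a G-shaped MLP
net = nn.Sequential(
    nn.Linear(72, 1024),
    nn.ReLU(),
    nn.Linear(1024, 784),
)

# apply() visits every submodule recursively, like iterating net.modules()
net.apply(init_weights_)
```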

4. Test noise

At test time only the input of G needs to be prepared, namely:

  • Input noise z: sample_z_: (100, 62)
  • Input condition c: sample_y_: (100, 10)

Final input of G: z and c concatenated along dimension 1, giving (100, 72).

The code and its output are given below:

import torch

# fixed test noise
sample_z_ = torch.randn((100, 62))
for i in range(10):
    sample_z_[i*10] = torch.rand(1, 62)
    for j in range(1, 10):
        sample_z_[i*10 + j] = sample_z_[i*10]
print(sample_z_)
"""
sample_z_:(100, 62)
          0-9:    same value
          10-19:  same value
          ...
          90-99:  same value
"""
temp = torch.zeros((10, 1))     # (10,1)---> 0,0,0,0,0,0,0,0,0,0
for i in range(10):
    temp[i, 0] = i                     # (10, 1) ---> 0,1,2,3,4,5,6,7,8,9
# print("temp:      ", temp)

temp_y = torch.zeros((100, 1))  #(100,1)---> 0,0,0,0,...,0,0,0,0
for i in range(10):             #(100,1)---> 0,1,2,3,...,6,7,8,9
    temp_y[i*10: (i+1)*10] = temp
# print("temp_y:    ", temp_y)           
sample_y_ = torch.zeros((100, 10)).scatter_(1, temp_y.type(torch.LongTensor), 1)
print(sample_y_)                       #(100,10)
'''
tensor([[0.3944, 0.9880, 0.4956,  ..., 0.0602, 0.9869, 0.5094],
        [0.3944, 0.9880, 0.4956,  ..., 0.0602, 0.9869, 0.5094],
        [0.3944, 0.9880, 0.4956,  ..., 0.0602, 0.9869, 0.5094],
        ...,
        [0.2845, 0.7694, 0.9878,  ..., 0.3211, 0.0242, 0.0332],
        [0.2845, 0.7694, 0.9878,  ..., 0.3211, 0.0242, 0.0332],
        [0.2845, 0.7694, 0.9878,  ..., 0.3211, 0.0242, 0.0332]])
tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])
'''

To explain: the input of G consists of noise and a condition. Here we have 100 noise vectors arranged in 10 groups of 10. Within a group the 10 noise vectors are identical, but each one is paired with a different condition, representing the digits 0 to 9 respectively.

In other words, from each shared noise vector we want to generate all ten digits 0 to 9, and we do this for ten different noise vectors (ten groups).
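The loops above can be condensed into a few vectorized calls. This is a sketch of an equivalent construction, not the article's original code: repeat_interleave repeats each of 10 noise rows 10 times in a row, and arange(10).repeat(10) produces the label pattern 0..9, 0..9, ... The name sample_input is hypothetical.

```python
import torch
import torch.nn.functional as F

z_dim, class_num = 62, 10

# 10 distinct noise vectors, each repeated 10 times: rows 0-9 identical, 10-19 identical, ...
sample_z_ = torch.rand(class_num, z_dim).repeat_interleave(class_num, dim=0)  # (100, 62)

# Labels 0..9 repeated 10 times, as one-hot vectors: (100, 10)
sample_y_ = F.one_hot(torch.arange(class_num).repeat(class_num), num_classes=class_num).float()

# Final test input of G: (100, 72)
sample_input = torch.cat([sample_z_, sample_y_], dim=1)
```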

3, Execution process

In the figures, the red line and the green line each trace one pass through the networks, and the red box marks the network that is updated by back-propagation in that step. Because the discriminator and the generator are trained separately, they are shown in two diagrams: step 1, training the discriminator, is on the left, and step 2, training the generator, is on the right.

  • Step 1: first feed the real samples to D and compute D_real_loss with BCE_loss; then feed the data generated by G and compute D_fake_loss the same way. Add the two losses and back-propagate to optimize D. Be careful not to update G in this step.
  • Step 2: feed the data generated by G to D and compute G_loss, then back-propagate to optimize G. Note that although the input in this step is generated data, after it passes through D the loss is computed against the real label.
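The two steps can be sketched as a single training iteration. This is a minimal sketch under stated assumptions: the tiny MLP networks, the Adam settings (lr=2e-4, betas=(0.5, 0.999)), and the helper names generate and discriminate are hypothetical, chosen only to make the steps runnable; the article's actual networks are not shown. detach() is one way to keep step 1 from updating G.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, z_dim, class_num, img_size = 64, 62, 10, 28

# Hypothetical tiny networks with the input shapes described above
G = nn.Sequential(nn.Linear(z_dim + class_num, 256), nn.ReLU(),
                  nn.Linear(256, img_size * img_size), nn.Sigmoid())
D = nn.Sequential(nn.Flatten(), nn.Linear((1 + class_num) * img_size * img_size, 256),
                  nn.LeakyReLU(0.2), nn.Linear(256, 1), nn.Sigmoid())
G_optimizer = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))  # assumed settings
D_optimizer = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
BCE_loss = nn.BCELoss()

def generate(z_, y_vec_):
    # G input (N, 72) -> image (N, 1, 28, 28)
    return G(torch.cat([z_, y_vec_], dim=1)).view(-1, 1, img_size, img_size)

def discriminate(img, y_vec_):
    # D input: image plus y_fill_ along channels -> (N, 11, 28, 28)
    y_fill_ = y_vec_[:, :, None, None].expand(-1, -1, img_size, img_size)
    return D(torch.cat([img, y_fill_], dim=1))

# Stand-in batch; in practice x_ and labels come from the MNIST DataLoader
x_ = torch.rand(batch_size, 1, img_size, img_size)
labels = torch.randint(0, class_num, (batch_size,))
y_vec_ = F.one_hot(labels, class_num).float()
z_ = torch.rand(batch_size, z_dim)
y_real_ = torch.ones(batch_size, 1)
y_fake_ = torch.zeros(batch_size, 1)

# Step 1: update D only; detach() stops gradients from reaching G
D_optimizer.zero_grad()
D_real_loss = BCE_loss(discriminate(x_, y_vec_), y_real_)
D_fake_loss = BCE_loss(discriminate(generate(z_, y_vec_).detach(), y_vec_), y_fake_)
D_loss = D_real_loss + D_fake_loss
D_loss.backward()
D_optimizer.step()

# Step 2: update G only; generated images are scored by D against the real label
G_optimizer.zero_grad()
G_loss = BCE_loss(discriminate(generate(z_, y_vec_), y_vec_), y_real_)
G_loss.backward()
G_optimizer.step()  # only G's parameters are stepped; D's stale grads are cleared next iteration
```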

4, Testing

Once training is finished, testing can be run directly. The images generated in the final test are shown below:
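A minimal sketch of the test pass that produces such a 10x10 grid. The untrained MLP below is a hypothetical stand-in for the trained G (same input/output shapes); in each grid row all images share one noise vector, and column c shows digit c. The grid is tiled with plain tensor reshapes so the sketch needs only torch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

z_dim, class_num, img_size = 62, 10, 28

# Hypothetical stand-in for the trained generator
G = nn.Sequential(nn.Linear(z_dim + class_num, 256), nn.ReLU(),
                  nn.Linear(256, img_size * img_size), nn.Sigmoid())

# Fixed test inputs: 10 noise vectors x 10 digit labels
sample_z_ = torch.rand(class_num, z_dim).repeat_interleave(class_num, dim=0)
sample_y_ = F.one_hot(torch.arange(class_num).repeat(class_num), class_num).float()

G.eval()
with torch.no_grad():
    samples = G(torch.cat([sample_z_, sample_y_], dim=1)).view(100, img_size, img_size)

# Tile into a (280, 280) grid: row r shares noise vector r, column c is digit c
grid = (samples.view(10, 10, img_size, img_size)
        .permute(0, 2, 1, 3)
        .reshape(10 * img_size, 10 * img_size))
```

The grid can then be saved or displayed with any image library.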


Added by AdamSnow on Sat, 23 Oct 2021 05:42:09 +0300