DSGAN degradation network

Unpaired degradation

1. Basic structure

  • generator

1 conv + 8 resblocks + 1 conv

Generator(
  (block_input): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): PReLU(num_parameters=1)
  )
  (res_blocks): ModuleList(
    (0-7): 8 x ResidualBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (prelu): PReLU(num_parameters=1)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (block_output): Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
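
The repr above does not show the forward pass. Below is a minimal PyTorch sketch that reproduces this printout; the residual addition inside each block is an assumption (standard for residual blocks, but invisible in the repr):

import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, channels=64):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        # skip connection assumed; it does not appear in the module repr
        return x + self.conv2(self.prelu(self.conv1(x)))

class Generator(nn.Module):
    def __init__(self, n_res_blocks=8):
        super().__init__()
        self.block_input = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.PReLU(),
        )
        self.res_blocks = nn.ModuleList(
            [ResidualBlock(64) for _ in range(n_res_blocks)])
        self.block_output = nn.Conv2d(64, 3, kernel_size=3, padding=1)

    def forward(self, x):
        h = self.block_input(x)
        for res_block in self.res_blocks:
            h = res_block(h)
        return self.block_output(h)
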
  • discriminator 

DiscriminatorBasic(
  (net): Sequential(
    (0): Conv2d(3, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Conv2d(64, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2)
    (5): Conv2d(128, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2)
  )
  (gan_net): Conv2d(256, 1, kernel_size=(1, 1), stride=(1, 1))
)
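
A matching sketch of the discriminator. The 1x1 gan_net head makes it a patch discriminator, emitting one real/fake logit per spatial position:

import torch.nn as nn

class DiscriminatorBasic(nn.Module):
    def __init__(self, n_input_channels=3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(n_input_channels, 64, kernel_size=5, padding=2),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=5, padding=2),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, kernel_size=5, padding=2),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
        )
        # 1x1 conv head: one logit per spatial location (PatchGAN-style)
        self.gan_net = nn.Conv2d(256, 1, kernel_size=1)

    def forward(self, x):
        return self.gan_net(self.net(x))
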

2. Loss function

[Figure: schematic diagram of the DSGAN loss functions]

2.1 generator
(1) Perceptual loss

  • Input: Xb (the preliminary LR image obtained by downsampling the HR image) and Xd (the fake LR image generated by the GAN)
  • Objective: keep the style of the degraded image Xd consistent with the original Xb.
  • Implementation: a feature distance computed with a pretrained VGG network, e.g. L_per = ||phi(Xd) - phi(Xb)||_1, where phi is the VGG feature extractor (a sketch follows below).
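
A minimal sketch of such a VGG perceptual loss. The cut-off layer (features[:36], i.e. up to relu5_4) and the L1 criterion are assumptions, not details taken from the paper:

import torch.nn as nn
from torchvision.models import vgg19

class VGGPerceptualLoss(nn.Module):
    def __init__(self):
        super().__init__()
        # frozen pretrained VGG-19 used as a fixed feature extractor
        self.features = vgg19(pretrained=True).features[:36].eval()
        for p in self.features.parameters():
            p.requires_grad = False
        self.criterion = nn.L1Loss()

    def forward(self, x_d, x_b):
        # inputs are assumed to be normalized with ImageNet statistics
        return self.criterion(self.features(x_d), self.features(x_b))
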

(2) Color loss

  • Input: Xb (the preliminary LR image obtained by downsampling the HR image) and Xd (the fake LR image generated by the GAN)
  • Objective: keep the color of the degraded image Xd similar to that of the initial Xb, so that the degradation cannot change the colors.
  • Implementation: apply a low-pass filter to both images, then compute an L1 loss: L_col = ||w * Xd - w * Xb||_1, where w is the low-pass filter. The author argues that the low-pass filter preserves the color information of the image; in the code it is implemented as average pooling with k=5, stride=1. The author also notes that the low-pass filter can be realized in many ways and is not limited to average pooling. A sketch follows below.
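
A minimal sketch of the color loss under the average-pooling choice described above (k=5, stride=1; the padding is added here to keep the spatial size unchanged):

import torch.nn.functional as F

def low_pass(x, kernel_size=5):
    # average pooling as the low-pass filter, k=5, stride=1
    return F.avg_pool2d(x, kernel_size, stride=1, padding=kernel_size // 2)

def color_loss(x_d, x_b):
    # L1 distance between the low-frequency (color) components
    return F.l1_loss(low_pass(x_d), low_pass(x_b))
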

(3) GAN loss

  • Input: the real LR image z and Xd (the fake LR image generated by the GAN)
  • Objective: Xb is produced by downsampling, which removes the high image frequencies and keeps the low-frequency information within the reduced number of pixels. The high-frequency characteristics of real LR images are therefore missing from Xb, while low-frequency information such as color and background is preserved. High-pass filtering is used to extract the high-frequency content of Xd and z, and the discriminator only has to distinguish the high-pass-filtered images.
  • Implementation: the high-frequency image is the original image minus its low-frequency (low-pass-filtered) version, X_H = X - w * X. A sketch of the generator-side adversarial term follows below.
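
A sketch of the generator-side adversarial term. Binary cross entropy on the patch logits is an assumption, consistent with the "standard cross-entropy loss" named in section 2.2:

import torch
import torch.nn.functional as F

def high_pass(x, kernel_size=5):
    # high-frequency image = original minus its low-pass (average-pooled) version
    return x - F.avg_pool2d(x, kernel_size, stride=1, padding=kernel_size // 2)

def generator_gan_loss(discriminator, x_d):
    # the generator tries to make the high frequencies of the fake LR image
    # look real to the discriminator
    pred_fake = discriminator(high_pass(x_d))
    return F.binary_cross_entropy_with_logits(
        pred_fake, torch.ones_like(pred_fake))
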

Total generator loss, a weighted sum of the three terms above (the weights are hyperparameters):

L_G = λ_per · L_per + λ_col · L_col + λ_adv · L_GAN

2.2 discriminator

  • Input: the real LR image z and Xd (the fake LR image generated by the GAN)
  • Implementation: both inputs are high-pass filtered (original image minus its low-frequency version) before being fed to the discriminator, which is trained with a standard cross-entropy loss. A sketch follows below.
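
A sketch of the discriminator update, reusing high_pass from the sketch in 2.1 (3):

import torch
import torch.nn.functional as F

def discriminator_loss(discriminator, z, x_d):
    # both the real LR image z and the fake LR image x_d are high-pass
    # filtered before being scored; x_d is detached so only D is updated
    pred_real = discriminator(high_pass(z))
    pred_fake = discriminator(high_pass(x_d.detach()))
    loss_real = F.binary_cross_entropy_with_logits(
        pred_real, torch.ones_like(pred_real))
    loss_fake = F.binary_cross_entropy_with_logits(
        pred_fake, torch.zeros_like(pred_fake))
    return 0.5 * (loss_real + loss_fake)
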

3. Training data and parameter setting

3.1 ID-photo super-resolution task

(1) Data

  • LR: low-definition ID photos, aligned to 128 × 128
  • HR: high-definition ID photos, aligned to 256 × 256

The whole image is used as input to train DSGAN.

(2) Parameter setting

  • Learning rate: 0.0002
  • Total iterations: 80,000

4. Evaluation metrics

For details, see the article: FID evaluation index

The Fréchet Inception Distance (FID) measures the distance between the distributions of feature vectors extracted from real images and generated images.
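
Concretely, FID fits a Gaussian to the Inception-v3 features of each image set and measures the distance between the two Gaussians. A minimal sketch of the standard formula, given precomputed feature means and covariances:

import numpy as np
from scipy import linalg

def fid(mu_r, sigma_r, mu_g, sigma_g):
    # FID = ||mu_r - mu_g||^2 + Tr(Sigma_r + Sigma_g - 2 (Sigma_r Sigma_g)^(1/2))
    covmean = linalg.sqrtm(sigma_r.dot(sigma_g))
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # numerical noise can add tiny imaginary parts
    diff = mu_r - mu_g
    return diff.dot(diff) + np.trace(sigma_r + sigma_g - 2.0 * covmean)
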

At present, the FID on face images is about 9.8.

 
