Unpaired degradation
1. Basic structure
- Generator: 1 conv + 8 residual blocks + 1 conv

    Generator(
      (block_input): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): PReLU(num_parameters=1)
      )
      (res_blocks): ModuleList(
        (0): ResidualBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (prelu): PReLU(num_parameters=1)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (1): ResidualBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (prelu): PReLU(num_parameters=1)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (2): ResidualBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (prelu): PReLU(num_parameters=1)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (3): ResidualBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (prelu): PReLU(num_parameters=1)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (4): ResidualBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (prelu): PReLU(num_parameters=1)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (5): ResidualBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (prelu): PReLU(num_parameters=1)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (6): ResidualBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (prelu): PReLU(num_parameters=1)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (7): ResidualBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (prelu): PReLU(num_parameters=1)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (block_output): Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
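The printout fixes the layer shapes but not the forward pass. Below is a minimal PyTorch sketch that rebuilds the same module tree, using the same attribute names and layer shapes. Two details are assumptions, because `forward()` is not shown: each `ResidualBlock` adds an identity skip, and the output has no final activation. Every convolution is 3 × 3 with stride 1 and padding 1, so the generated Xd has the same size as its input Xb.

```python
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """conv -> PReLU -> conv, added back onto the input (identity skip assumed)."""

    def __init__(self, channels: int = 64):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The printout does not show forward(); the skip connection is an assumption.
        return x + self.conv2(self.prelu(self.conv1(x)))


class Generator(nn.Module):
    """1 conv + 8 residual blocks + 1 conv, all 3x3 / stride 1 / padding 1."""

    def __init__(self, n_res_blocks: int = 8, channels: int = 64):
        super().__init__()
        self.block_input = nn.Sequential(
            nn.Conv2d(3, channels, kernel_size=3, padding=1),
            nn.PReLU(),
        )
        self.res_blocks = nn.ModuleList(
            [ResidualBlock(channels) for _ in range(n_res_blocks)]
        )
        self.block_output = nn.Conv2d(channels, 3, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.block_input(x)
        for block in self.res_blocks:
            h = block(h)
        # No output activation here; the original code may clamp or squash the result.
        return self.block_output(h)


g = Generator()
x_b = torch.rand(1, 3, 32, 32)  # a downsampled LR image Xb
x_d = g(x_b)                    # the generated "degraded" LR image Xd
print(x_d.shape)                              # torch.Size([1, 3, 32, 32])
print(sum(p.numel() for p in g.parameters())) # 594380
```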
- Discriminator: 3 conv (5 × 5) + 1 conv (1 × 1)

    DiscriminatorBasic(
      (net): Sequential(
        (0): Conv2d(3, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        (1): LeakyReLU(negative_slope=0.2)
        (2): Conv2d(64, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (4): LeakyReLU(negative_slope=0.2)
        (5): Conv2d(128, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (7): LeakyReLU(negative_slope=0.2)
      )
      (gan_net): Conv2d(256, 1, kernel_size=(1, 1), stride=(1, 1))
    )
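A matching PyTorch sketch of `DiscriminatorBasic`. Every convolution uses stride 1 with "same" padding, so the output is a one-channel map with the same height and width as the input. That gives one real/fake score per pixel rather than one per image. This sketch returns raw logits with no sigmoid, so that a numerically stable BCE-with-logits loss can be used. Whether the original applies a sigmoid in `forward()` is not visible in the printout.

```python
import torch
import torch.nn as nn


class DiscriminatorBasic(nn.Module):
    def __init__(self, in_channels: int = 3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=5, padding=2),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=5, padding=2),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, kernel_size=5, padding=2),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
        )
        # 1x1 conv maps the 256 feature channels to one score per pixel.
        self.gan_net = nn.Conv2d(256, 1, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Raw logits (assumption: no sigmoid here).
        return self.gan_net(self.net(x))


d = DiscriminatorBasic()
logits = d(torch.rand(2, 3, 32, 32))
print(logits.shape)                           # torch.Size([2, 1, 32, 32])
print(sum(p.numel() for p in d.parameters())) # 1030273
```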
2. Loss function
[Figure: schematic diagram of the loss functions]
2.1 Generator
(1) Perceptual loss
- Input: Xb (the preliminary LR image obtained by downsampling the HR image) and Xd (the fake LR image produced by the generator)
- Objective: keep the degraded image Xd consistent in style with the original Xb.
- Implementation: computed on VGG features of the two images.
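A minimal sketch of a feature-space perceptual loss, assuming an L1 distance between the feature maps. The notes only say VGG is used, so the distance and the VGG layer are assumptions. In practice `features` would be a frozen VGG feature extractor, for example `torchvision.models.vgg19(...).features` cut at some intermediate layer. The small convolutional network below is a hypothetical stand-in, so the sketch runs without downloading weights.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def perceptual_loss(features: nn.Module, x_d: torch.Tensor, x_b: torch.Tensor) -> torch.Tensor:
    """Mean L1 distance between the feature maps of Xd and Xb.

    `features` should be a frozen feature extractor (VGG in the notes above);
    using L1 on the features is an assumption.
    """
    return F.l1_loss(features(x_d), features(x_b))


# Hypothetical stand-in extractor so the sketch runs without VGG weights.
torch.manual_seed(0)
features = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU()).eval()
for p in features.parameters():
    p.requires_grad_(False)  # the extractor is frozen; only the generator trains

x_b = torch.rand(1, 3, 32, 32)
x_d = x_b + 0.1 * torch.randn_like(x_b)  # a "degraded" version of Xb
print(perceptual_loss(features, x_b, x_b).item())  # 0.0
print(perceptual_loss(features, x_d, x_b).item())
```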
(2) Color loss
- Input: Xb (the preliminary LR image obtained by downsampling the HR image) and Xd (the fake LR image produced by the generator)
- Objective: keep the colors of the degraded image Xd similar to those of the initial Xb, so that degradation does not change the colors.
- Implementation: first apply a low-pass filter, then compute the L1 loss:
  $$\mathcal{L}_{col} = \lVert w_L * X_d - w_L * X_b \rVert_1$$
  where w_L is the low-pass filter and * denotes convolution. The author argues that the low-pass filter preserves the color information of the image. In the code, it is implemented as average pooling with k = 5 and stride = 1. The author also notes that the low-pass filter can be realized in many ways and is not limited to average pooling.
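A minimal PyTorch sketch of the color loss, using average pooling with k = 5 and stride = 1 as described. Two settings are assumptions not stated in the notes. `padding=k // 2` keeps the output the same size as the input, and `count_include_pad=False` averages only over real pixels at the borders. The example compares two cases. A ±0.1 checkerboard is pure high frequency and is largely averaged away. A uniform +0.1 color shift passes through the filter unchanged.

```python
import torch
import torch.nn.functional as F


def low_pass(x: torch.Tensor, kernel_size: int = 5) -> torch.Tensor:
    """Average-pooling low-pass filter (k=5, stride=1), output same size as input.

    padding=k//2 and count_include_pad=False are assumptions.
    """
    return F.avg_pool2d(x, kernel_size, stride=1, padding=kernel_size // 2,
                        count_include_pad=False)


def color_loss(x_d: torch.Tensor, x_b: torch.Tensor) -> torch.Tensor:
    """L1 loss between the low-pass filtered Xd and Xb."""
    return F.l1_loss(low_pass(x_d), low_pass(x_b))


flat = torch.full((1, 3, 32, 32), 0.5)                 # flat gray image
i = torch.arange(32)
checker = (((i[:, None] + i[None, :]) % 2) * 2 - 1).float()  # +1/-1 pattern
noisy = flat + 0.1 * checker                           # high-frequency change only
shifted = flat + 0.1                                   # color (low-frequency) change

print(F.l1_loss(noisy, flat).item())    # ~0.1: plain L1 sees the pattern
print(color_loss(noisy, flat).item())   # much smaller: the filter averages it out
print(color_loss(shifted, flat).item()) # ~0.1: a color shift is still penalized
```

Because only the low frequencies are compared, Xd can gain noise and texture without being penalized by this term. Only its colors are tied to Xb.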
(3) GAN loss
- Input: real LR image (z) and Xd (the fake LR image produced by the generator)
- Objective: Xb is produced by downsampling, which removes the high image frequencies and keeps the low-frequency information within the reduced number of pixels. High-frequency features are therefore lost, while low-frequency information such as color and background remains. For this reason, Xd and z are both high-pass filtered to extract their high-frequency information, and the discriminator distinguishes the filtered images.
- Implementation: the high-frequency image is the original image minus the low-frequency image. In general form:
  $$w_H * x = x - w_L * x, \qquad \text{i.e.} \qquad w_H = \delta - w_L$$
  where w_L is the low-pass filter and δ is the identity (Dirac delta) kernel.
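A minimal sketch of the high-pass filter (original minus low-pass, with the same average-pooling low-pass as the color loss) and of the generator's GAN term. The notes do not spell out the generator's GAN loss. Here it is assumed to be the standard non-saturating form, in which the generator wants D to score the high frequencies of Xd as real. `zero_disc` is a hypothetical stand-in discriminator that always outputs logit 0, so the loss equals log 2.

```python
import torch
import torch.nn.functional as F


def low_pass(x: torch.Tensor, kernel_size: int = 5) -> torch.Tensor:
    return F.avg_pool2d(x, kernel_size, stride=1, padding=kernel_size // 2,
                        count_include_pad=False)


def high_pass(x: torch.Tensor, kernel_size: int = 5) -> torch.Tensor:
    """High-frequency image = original image minus its low-frequency image."""
    return x - low_pass(x, kernel_size)


def generator_gan_loss(discriminator, x_d: torch.Tensor) -> torch.Tensor:
    """Assumed non-saturating GAN term: D should label high_pass(Xd) as real (1)."""
    logits = discriminator(high_pass(x_d))
    return F.binary_cross_entropy_with_logits(logits, torch.ones_like(logits))


def zero_disc(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in discriminator: logit 0 (p = 0.5) at every pixel."""
    return torch.zeros(x.shape[0], 1, x.shape[2], x.shape[3])


flat = torch.full((1, 3, 32, 32), 0.5)
print(high_pass(flat).abs().max().item())  # 0.0 up to rounding: no high frequencies
print(generator_gan_loss(zero_disc, torch.rand(1, 3, 32, 32)).item())  # log 2 ~ 0.6931
```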
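Total generator loss:
$$\mathcal{L}_G = \lambda_{col}\,\mathcal{L}_{col} + \lambda_{per}\,\mathcal{L}_{per} + \lambda_{GAN}\,\mathcal{L}_{GAN}$$
where λ_col, λ_per and λ_GAN weight the color, perceptual and GAN losses.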
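A minimal sketch that combines the three generator terms into one weighted sum. The default weights of 1.0 are placeholders, not values from the notes. The choices of L1 on features for the perceptual term and the non-saturating form for the GAN term are also assumptions. The feature extractor and discriminator below are tiny hypothetical stand-ins for VGG and `DiscriminatorBasic`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def low_pass(x, kernel_size=5):
    return F.avg_pool2d(x, kernel_size, stride=1, padding=kernel_size // 2,
                        count_include_pad=False)


def generator_total_loss(x_d, x_b, features, discriminator,
                         w_col=1.0, w_per=1.0, w_gan=1.0):
    """L_G = w_col * L_col + w_per * L_per + w_gan * L_GAN (placeholder weights)."""
    l_col = F.l1_loss(low_pass(x_d), low_pass(x_b))   # color loss on low frequencies
    l_per = F.l1_loss(features(x_d), features(x_b))   # perceptual loss on features
    logits = discriminator(x_d - low_pass(x_d))       # D sees high frequencies only
    l_gan = F.binary_cross_entropy_with_logits(logits, torch.ones_like(logits))
    total = w_col * l_col + w_per * l_per + w_gan * l_gan
    return total, {"col": l_col.item(), "per": l_per.item(), "gan": l_gan.item()}


# Hypothetical stand-ins for the VGG feature extractor and the discriminator.
torch.manual_seed(0)
features = nn.Conv2d(3, 4, kernel_size=3, padding=1)
discriminator = nn.Conv2d(3, 1, kernel_size=1)

x_b = torch.rand(2, 3, 32, 32)
x_d = (x_b + 0.05 * torch.randn_like(x_b)).requires_grad_()
total, parts = generator_total_loss(x_d, x_b, features, discriminator, w_gan=0.5)
total.backward()  # in training, these gradients would reach the generator
print(parts)
```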
2.2 Discriminator
- Input: real LR image (z) and Xd (the fake LR image produced by the generator)
- Implementation: both inputs are high-pass filtered (the original image minus its low-frequency image) before they are fed to the discriminator, which is trained with the standard cross-entropy loss.
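A minimal sketch of the discriminator loss described above, assuming a discriminator that returns logits and the same average-pooling high-pass filter. Real LR images z are labeled 1 and generated images Xd are labeled 0, using binary cross-entropy. `Xd` is detached so that this step does not update the generator. The 1 × 1 convolution is a hypothetical stand-in for `DiscriminatorBasic`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def high_pass(x, kernel_size=5):
    low = F.avg_pool2d(x, kernel_size, stride=1, padding=kernel_size // 2,
                       count_include_pad=False)
    return x - low


def discriminator_loss(discriminator, z, x_d):
    """Standard BCE: high_pass(z) -> real (1), high_pass(Xd) -> fake (0)."""
    real = discriminator(high_pass(z))
    fake = discriminator(high_pass(x_d.detach()))  # detach: no gradient into the generator
    return (F.binary_cross_entropy_with_logits(real, torch.ones_like(real))
            + F.binary_cross_entropy_with_logits(fake, torch.zeros_like(fake)))


torch.manual_seed(0)
disc = nn.Conv2d(3, 1, kernel_size=1)                 # hypothetical stand-in discriminator
z = torch.rand(2, 3, 32, 32)                          # real LR images
x_d = torch.rand(2, 3, 32, 32, requires_grad=True)    # pretend generator output
loss = discriminator_loss(disc, z, x_d)
loss.backward()
print(loss.item(), disc.weight.grad is not None, x_d.grad)  # x_d.grad stays None
```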
3. Training data and parameter settings
3.1 ID photo super-resolution task
(1) Data
- LR: low-resolution ID photos, aligned to 128 × 128
- HR: high-resolution ID photos, aligned to 256 × 256
The whole image is used as input to train DSGAN.
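The HR photos (256 × 256) are twice the size of the LR photos (128 × 128). Xb is therefore obtained by downsampling HR by a factor of 2. Because the generator keeps the spatial size, Xd then matches the size of the real LR images z. A minimal sketch follows. The notes do not name the downsampling method, so bicubic interpolation is an assumption, and the random tensors stand in for real photos.

```python
import torch
import torch.nn.functional as F


def make_x_b(hr: torch.Tensor) -> torch.Tensor:
    """Downsample aligned 256x256 HR ID photos to 128x128 to obtain Xb.

    Bicubic interpolation is an assumption; the notes only say "down sampling".
    """
    x_b = F.interpolate(hr, size=(128, 128), mode="bicubic", align_corners=False)
    return x_b.clamp(0.0, 1.0)  # bicubic can overshoot the [0, 1] range


hr = torch.rand(4, 3, 256, 256)  # stand-in batch of HR ID photos in [0, 1]
z = torch.rand(4, 3, 128, 128)   # stand-in batch of unpaired real LR ID photos
x_b = make_x_b(hr)
print(x_b.shape, z.shape)  # torch.Size([4, 3, 128, 128]) torch.Size([4, 3, 128, 128])
```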
(2) Parameter settings
- Learning rate: 0.0002
- Total iterations: 80,000
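A minimal training-loop skeleton that uses these settings. The notes give only the learning rate and the iteration count. The optimizer (Adam with PyTorch's default betas) and the scheme of one D update followed by one G update per iteration are assumptions. The tiny networks, batch function and losses at the bottom are hypothetical stand-ins, run for 2 iterations instead of 80,000.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

LEARNING_RATE = 2e-4  # "Learning rate: 0.0002"
TOTAL_ITERS = 80_000  # "8w" = 8 x 10,000 iterations


def make_optimizers(generator: nn.Module, discriminator: nn.Module, lr: float = LEARNING_RATE):
    # Adam is an assumption; the notes only fix the learning rate.
    opt_g = torch.optim.Adam(generator.parameters(), lr=lr)
    opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr)
    return opt_g, opt_d


def train(generator, discriminator, next_batch, d_loss_fn, g_loss_fn, iters=TOTAL_ITERS):
    """Alternate one discriminator update and one generator update per iteration."""
    opt_g, opt_d = make_optimizers(generator, discriminator)
    for _ in range(iters):
        x_b, z = next_batch()  # (downsampled HR, unpaired real LR)
        x_d = generator(x_b)

        opt_d.zero_grad()
        d_loss_fn(discriminator, z, x_d.detach()).backward()
        opt_d.step()

        opt_g.zero_grad()
        g_loss_fn(discriminator, x_d, x_b).backward()
        opt_g.step()


# Hypothetical stand-ins so the skeleton runs end to end.
def next_batch():
    return torch.rand(2, 3, 16, 16), torch.rand(2, 3, 16, 16)


def d_loss_fn(disc, z, x_d):
    real, fake = disc(z), disc(x_d)
    return (F.binary_cross_entropy_with_logits(real, torch.ones_like(real))
            + F.binary_cross_entropy_with_logits(fake, torch.zeros_like(fake)))


def g_loss_fn(disc, x_d, x_b):
    fake = disc(x_d)
    return F.l1_loss(x_d, x_b) + F.binary_cross_entropy_with_logits(fake, torch.ones_like(fake))


torch.manual_seed(0)
gen = nn.Conv2d(3, 3, kernel_size=3, padding=1)
disc = nn.Conv2d(3, 1, kernel_size=1)
before = gen.weight.detach().clone()
train(gen, disc, next_batch, d_loss_fn, g_loss_fn, iters=2)
print((gen.weight.detach() - before).abs().max().item())
```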
4. Evaluation metric
For details, see the article on the FID evaluation metric.
The Fréchet Inception Distance (FID) measures the distance between the distributions of feature vectors of real images and generated images; lower is better.
At present, the FID on face images is about 9.8.
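FID fits a Gaussian to each set of feature vectors and compares the two Gaussians:
FID = ‖μ_r − μ_g‖² + Tr(Σ_r + Σ_g − 2 (Σ_r Σ_g)^{1/2}).
Below is a minimal NumPy sketch of that formula. It assumes the feature vectors have already been extracted, which is normally done with an Inception network. The trace of the matrix square root is computed from the eigenvalues of the symmetric matrix S Σ_g S, where S = Σ_r^{1/2}. That matrix has the same eigenvalues as Σ_r Σ_g, so no general matrix square root is needed. The random features in the example are synthetic.

```python
import numpy as np


def fid(feat_real: np.ndarray, feat_fake: np.ndarray) -> float:
    """FID between two sets of feature vectors (rows = samples, columns = features)."""
    mu_r, mu_g = feat_real.mean(axis=0), feat_fake.mean(axis=0)
    cov_r = np.cov(feat_real, rowvar=False)
    cov_g = np.cov(feat_fake, rowvar=False)

    # S = cov_r^{1/2} via the eigendecomposition of the symmetric PSD matrix cov_r.
    w, v = np.linalg.eigh(cov_r)
    s = (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T
    # S cov_g S is symmetric and shares its eigenvalues with cov_r cov_g, so
    # Tr((cov_r cov_g)^{1/2}) is the sum of the square roots of those eigenvalues.
    eig = np.linalg.eigvalsh(s @ cov_g @ s)
    tr_sqrt = np.sqrt(np.clip(eig, 0.0, None)).sum()

    diff = mu_r - mu_g
    return float(diff @ diff + np.trace(cov_r) + np.trace(cov_g) - 2.0 * tr_sqrt)


rng = np.random.default_rng(0)
a = rng.normal(size=(500, 4))  # synthetic 4-d "features"
print(fid(a, a))        # ~0: identical distributions
print(fid(a, a + 1.0))  # ~4: same covariance, means differ by 1 in each of 4 dims
```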