Color planet image generation 3: fine-tuning code details (PyTorch version)

Previous episode: Color planet image generation 2: using both a traditional GAN discriminator and a Markov discriminator (PyTorch version)

Building on the code from the previous episode, several detailed modifications are made to improve the quality of the generated pictures.

1. Modification

1.1 preprocessing scaling

The code for preprocessing training set pictures is modified to:

import cv2
import os
from PIL import Image

# Data set source
img_path = "train_images/"

# Collect the file names in the source directory
for path, dirs, files in os.walk(img_path, topdown=False):
    file_list = list(files)

for file in file_list:
    image_path = img_path + file
    img = cv2.imread(image_path, 1)
    # Crop to a centered square (assumes width >= height)
    bias = (img.shape[1] - img.shape[0]) // 2
    img = img[:, bias:bias+img.shape[0], :]
    # Reorder the channels from OpenCV's BGR to RGB
    (B, G, R) = cv2.split(img)
    img = cv2.merge([R, G, B])
    # Scale with PIL's ANTIALIAS filter (an alias of Image.LANCZOS;
    # newer Pillow releases only provide the LANCZOS name)
    img = Image.fromarray(img)
    img = img.resize((264, 264), Image.ANTIALIAS)

Improvement points: the training pictures are now scaled in advance with the ANTIALIAS filter of PIL's Image module. Scaling ahead of time removes the cost of image scaling from the training loop. In addition, this high-quality scaling algorithm preserves detail texture when a large picture is reduced to low resolution and reduces jagged edges (aliasing), which helps the model learn detailed texture features.
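The crop-and-scale step can be tried on a synthetic picture without any training data. The following is a minimal sketch that uses only Pillow. The helper name `center_crop_and_resize` and the 400x300 test picture are assumptions for illustration, and unlike the script above it also handles pictures that are taller than they are wide. `Image.LANCZOS` is used because it is the filter that `Image.ANTIALIAS` was an alias for.

```python
from PIL import Image


def center_crop_and_resize(img, size=264):
    """Crop a PIL image to a centered square, then downscale it."""
    width, height = img.size
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    square = img.crop((left, top, left + side, top + side))
    # LANCZOS is the high-quality filter that ANTIALIAS referred to
    return square.resize((size, size), Image.LANCZOS)


# Hypothetical 400x300 landscape picture standing in for a training image
demo = Image.new("RGB", (400, 300), color=(10, 20, 30))
result = center_crop_and_resize(demo)
print(result.size)
```

Because the scaled pictures are stored before training starts, the training loop no longer needs a Resize step.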

1.2 random flipping

The code of dataset construction in the training code is modified as follows:

elif config.read_from == "Memory":
    class image_dataset(Dataset):
        def __init__(self, file_list, img_path, transform):
            self.transform = transform
            self.imgs = []
            for file in file_list:
                image_path = img_path + file
                img = cv2.imread(image_path)
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                # Save all original pictures in memory
                self.imgs.append(img)

        def __getitem__(self, index):
            # The transform is now applied every time a picture is fetched
            img = self.imgs[index]
            img = self.transform(image=img)['image']
            return img

        def __len__(self):
            return len(self.imgs)

def get_transforms(img_size):
    # Flip at random, then map pixel values from 0-255 to [-1, 1]
    return Compose(
        # The Resize part is cancelled, and 0.5 probability random vertical flip and horizontal flip are added at the same time
        # Planet pictures can be flipped freely, which effectively expands the amount of information in the training set
        [HorizontalFlip(p=0.5),
         VerticalFlip(p=0.5),
         Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), max_pixel_value=255.0, p=1.0),
         # Convert to Tensor; this line is cut off in the source and is assumed
         # to be ToTensorV2 from albumentations.pytorch
         ToTensorV2(p=1.0)])

Modification point: the original pictures are kept in memory, and every time a picture is fetched it is randomly flipped horizontally and vertically before being converted to a Tensor. Each epoch therefore sees a different variant of every picture, which is equivalent to enlarging the training set.
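To see why flipping at access time enlarges the effective training set, the idea can be mimicked with NumPy alone. This is only a sketch and not the albumentations pipeline used in the training code: `random_flip` and `normalize` are hypothetical helpers, and the 2x2 picture is made up for the demonstration.

```python
import numpy as np


def random_flip(img, rng, p=0.5):
    """Flip an HxWxC image horizontally and/or vertically, each with probability p."""
    if rng.random() < p:
        img = img[:, ::-1, :]  # horizontal flip (reverse the columns)
    if rng.random() < p:
        img = img[::-1, :, :]  # vertical flip (reverse the rows)
    return img


def normalize(img):
    """Map 0-255 pixel values to [-1, 1], as Normalize does with mean=0.5 and std=0.5."""
    return (img.astype(np.float32) / 255.0 - 0.5) / 0.5


rng = np.random.default_rng(0)
# Tiny made-up 2x2 RGB picture with four distinct pixel values
picture = np.array([[0, 85], [170, 255]], dtype=np.uint8)[:, :, None].repeat(3, axis=2)

# Fetching the same stored picture many times yields several distinct variants
variants = {random_flip(picture, rng).tobytes() for _ in range(200)}
print(len(variants))
print(normalize(picture).min(), normalize(picture).max())
```

One stored picture can appear in up to four orientations (original, horizontal flip, vertical flip, both), so the model rarely sees exactly the same input in consecutive epochs.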

1.3 modify global discriminator

Modify the code of the global discriminator in the model part as follows:

# Global discriminator, traditional GAN
class D_net_global(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=5, stride=3, padding=1, bias=False),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(512, 16, kernel_size=4, stride=1, padding=0, bias=False),
            nn.LeakyReLU(0.2, True),
        )
        self.classifier = nn.Sequential(
            # The two commented-out lines below have been removed
            # nn.Linear(1024, 1024),
            # nn.ReLU(True),
            nn.Linear(1024, 1),
        )

    def forward(self, img):
        features = self.features(img)
        # Flatten to (batch, features) before the linear layer
        features = features.view(features.shape[0], -1)
        output = self.classifier(features)
        return output

Modification point: the extra fully connected layer at the end of the discriminator is removed together with its ReLU layer, leaving a single linear output layer, so that as much picture detail information as possible is retained.
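The input size of `nn.Linear(1024, 1)` can be checked by hand with the standard convolution output-size formula, floor((n + 2 * padding - kernel) / stride) + 1. The sketch below applies it to the five convolutions of the discriminator. It assumes that the 264x264 pictures from the preprocessing step reach the discriminator without further cropping or resizing (an assumption, since the training loop is not shown here), and `conv_out` is a hypothetical helper.

```python
def conv_out(n, kernel, stride, padding):
    """Spatial output size of a 2D convolution along one axis."""
    return (n + 2 * padding - kernel) // stride + 1


# (kernel_size, stride, padding) of the five Conv2d layers in D_net_global
layers = [(5, 3, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]

size = 264  # assumed input resolution, taken from the preprocessing step
sizes = [size]
for kernel, stride, padding in layers:
    size = conv_out(size, kernel, stride, padding)
    sizes.append(size)

out_channels = 16  # channels of the last convolution
flat_features = out_channels * size * size
print(sizes, flat_features)
```

Under that assumption the feature map shrinks from 264 to 8 pixels per side, and 16 channels of 8x8 give exactly the 1024 features the linear layer expects. A different input resolution would require a different `nn.Linear` input size.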

1.4 modify progress printing

The progress printing code during training is modified to:

# Print program work progress
print("\rEpoch: %2d, Batch: %4d / %4d" % (epoch + 1, index + 1, batch_num), end="")

Modification point: the newline at the end of the output is removed. Each print starts with "\r", which moves the cursor back to the beginning of the line so that the new text overwrites the previous one. The batch progress is thus refreshed in place on a single line, which is more intuitive and concise.
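The in-place refresh can be tried outside the training loop. The following is a small sketch in which the loop body is only a stand-in for a training step; `progress_line` and the value of `batch_num` are hypothetical. The explicit flush is an addition, since a line without a trailing newline may otherwise stay in the output buffer.

```python
import sys
import time


def progress_line(epoch, index, batch_num):
    """Build the one-line progress text, starting with a carriage return."""
    return "\rEpoch: %2d, Batch: %4d / %4d" % (epoch + 1, index + 1, batch_num)


batch_num = 5  # hypothetical number of batches per epoch
for index in range(batch_num):
    time.sleep(0.01)  # stand-in for one training step
    print(progress_line(0, index, batch_num), end="")
    sys.stdout.flush()  # show the partial line immediately
print()  # end the line once the epoch is finished
```

The final `print()` ends the progress line so that later output, such as a per-epoch loss summary, starts on a fresh line.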

2. Effect

After these improvements, the training set is expanded to 128 pictures and the model is trained for more epochs. After about 10000 epochs, the generated pictures look as follows:

The pictures are clearly much improved: the surface texture of the planets is sharp, the grid artifacts are greatly reduced, and there is less background noise.
Finally, an overview of the pictures generated during training [4x8 samples per picture, one picture every 100 epochs]:

Keywords: Python Pytorch Computer Vision Deep Learning

Added by gx30uk on Wed, 22 Dec 2021 22:02:50 +0200