PyTorch: Fast Image Style Transfer - Residual Network, Fixed Style, Arbitrary Content

Copyright: Jingmin Wei, Pattern Recognition and Intelligent System, School of Artificial Intelligence, Huazhong University of Science and Technology

PyTorch tutorial column link

This tutorial is not for commercial use. It is only for learning and reference. If you need to reprint it, please contact me.

Reference

Perceptual Losses for Real-Time Style Transfer and Super-Resolution

ResNet

Unlike ordinary style transfer, where the image fed to the loss network is initialized as random noise and optimized directly, in fast style transfer the image fed to the loss network is the output of an image transformation network f_w.

Fast style transfer passes an input image x through the image transformation network f_w to obtain the network output \hat{y}. Because only the network weights are trained (for a fixed style), it can perform fast style transfer on images with arbitrary content.

Following the article Perceptual Losses for Real-Time Style Transfer and Super-Resolution, the upsampling stage of the image transformation network is adjusted accordingly: in the network built here, transposed convolutions are used to upsample the feature maps.

import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt 
from PIL import Image
import time

import torch
import torch.nn as nn 
import torch.utils.data as Data 
import torch.nn.functional as F 
import torch.optim as optim
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision import models
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

# Select GPU for model loading
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device('cpu')
print(device)
print(torch.cuda.device_count())
print(torch.cuda.get_device_name(0))
cuda
1
GeForce MX250

Fast style transfer network preparation

Three convolution layers first downsample the image feature maps; five residual blocks then learn the image style and add it to the content representation; finally, three transposed convolution layers upsample the feature maps (analogous to a semantic segmentation network) to reconstruct the stylized image.

In the upsampling stage of the transformation network, transposed convolutions replace the upsample-plus-convolution combination used in the article. Because the input is a normalized image whose pixel values lie roughly in the range [-2.1, 2.7], the last output layer of the network uses no activation function: most of the network's outputs already fall in [-2.1, 2.7], and the few values outside this interval are clipped back into it during training. All other layers use the ReLU activation. Along the network, the number of feature maps grows from 3 to 128, each residual block keeps 128 feature maps, and the transposed convolution layers reduce the number of feature maps from 128 back down to 3, matching the three channels of the image.
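
As a quick sanity check (a small sketch, not part of the tutorial code), a transposed convolution with stride 2, padding 1 and output_padding 1 doubles the spatial size of a feature map, which is how the upsampling stage below restores the original resolution:

# Minimal sketch: a stride-2 transposed convolution doubles the spatial resolution
up = nn.ConvTranspose2d(128, 64, kernel_size = 3, stride = 2, padding = 1, output_padding = 1)
feat = torch.randn(1, 128, 64, 64)  # a 64x64 feature map with 128 channels
print(up(feat).shape)               # torch.Size([1, 64, 128, 128])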

Defining the residual block

Focus on a local part of the neural network. Let the input be x and suppose the ideal mapping we want to learn is f(x), which is fed into the activation function. Instead of fitting f(x) directly, the block only needs to fit the residual mapping f(x) - x. The residual mapping is often easier to optimize in practice: if the ideal mapping f(x) is the identity, we only need to drive the weights and biases of the weighted operations (such as the affine layers) to zero, and f(x) becomes an identity map. When the ideal mapping f(x) is close to the identity, the residual formulation also makes it easier to capture small deviations from the identity. In addition, the skip connection lets the input propagate forward through the block more directly.
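
In symbols, the residual block defined below computes

$$y = \mathrm{ReLU}\bigl(g(x) + x\bigr),$$

where g(x) is the output of the two stacked 3 x 3 convolutions, so g only has to fit the residual f(x) - x rather than the full mapping f(x).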

Define the residual block with 128 feature maps; the activation size is 128 × 64 × 64.

# ResidualBlock residual block
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size = 3, stride = 1, padding = 1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size = 3, stride = 1, padding = 1)
        )
    def forward(self, x):
        return F.relu(self.conv(x) + x)

Define the image transformation network

It consists of a downsampling module, five residual blocks, and an upsampling module.

# Define the image transformation network
class ImfwNet(nn.Module):
    def __init__(self):
        super(ImfwNet, self).__init__()
        # Down sampling
        self.downsample = nn.Sequential(
            nn.ReflectionPad2d(padding = 4), # Use boundary reflection fill
            nn.Conv2d(3, 32, kernel_size = 9, stride = 1),
            nn.InstanceNorm2d(32, affine = True), # Normalize pixel values
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size = 3, stride = 2),
            nn.InstanceNorm2d(64, affine = True),
            nn.ReLU(),
            nn.ReflectionPad2d(padding = 1),
            nn.Conv2d(64, 128, kernel_size = 3, stride = 2),
            nn.InstanceNorm2d(128, affine = True),
            nn.ReLU()
        )
        # 5 residual connections
        self.res_blocks = nn.Sequential(
            ResidualBlock(128),
            ResidualBlock(128),
            ResidualBlock(128),
            ResidualBlock(128),
            ResidualBlock(128),
        )
        # Up sampling
        self.unsample = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size = 3, stride = 2, padding = 1, output_padding = 1),
            nn.InstanceNorm2d(64, affine = True),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size = 3, stride = 2, padding = 1, output_padding = 1),
            nn.InstanceNorm2d(32, affine = True),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size = 9, stride = 1, padding = 4)
        )
    def forward(self, x):
        x = self.downsample(x) # Input pixel values lie roughly in [-2.1, 2.7]
        x = self.res_blocks(x)
        x = self.unsample(x) # Output values should also lie roughly in [-2.1, 2.7]
        return x
myfwnet = ImfwNet().to(device)
from torchsummary import summary
summary(myfwnet, input_size=(3, 256, 256))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
   ReflectionPad2d-1          [-1, 3, 264, 264]               0
            Conv2d-2         [-1, 32, 256, 256]           7,808
    InstanceNorm2d-3         [-1, 32, 256, 256]              64
              ReLU-4         [-1, 32, 256, 256]               0
            Conv2d-5         [-1, 64, 127, 127]          18,496
    InstanceNorm2d-6         [-1, 64, 127, 127]             128
              ReLU-7         [-1, 64, 127, 127]               0
   ReflectionPad2d-8         [-1, 64, 129, 129]               0
            Conv2d-9          [-1, 128, 64, 64]          73,856
   InstanceNorm2d-10          [-1, 128, 64, 64]             256
             ReLU-11          [-1, 128, 64, 64]               0
           Conv2d-12          [-1, 128, 64, 64]         147,584
             ReLU-13          [-1, 128, 64, 64]               0
           Conv2d-14          [-1, 128, 64, 64]         147,584
    ResidualBlock-15          [-1, 128, 64, 64]               0
           Conv2d-16          [-1, 128, 64, 64]         147,584
             ReLU-17          [-1, 128, 64, 64]               0
           Conv2d-18          [-1, 128, 64, 64]         147,584
    ResidualBlock-19          [-1, 128, 64, 64]               0
           Conv2d-20          [-1, 128, 64, 64]         147,584
             ReLU-21          [-1, 128, 64, 64]               0
           Conv2d-22          [-1, 128, 64, 64]         147,584
    ResidualBlock-23          [-1, 128, 64, 64]               0
           Conv2d-24          [-1, 128, 64, 64]         147,584
             ReLU-25          [-1, 128, 64, 64]               0
           Conv2d-26          [-1, 128, 64, 64]         147,584
    ResidualBlock-27          [-1, 128, 64, 64]               0
           Conv2d-28          [-1, 128, 64, 64]         147,584
             ReLU-29          [-1, 128, 64, 64]               0
           Conv2d-30          [-1, 128, 64, 64]         147,584
    ResidualBlock-31          [-1, 128, 64, 64]               0
  ConvTranspose2d-32         [-1, 64, 128, 128]          73,792
   InstanceNorm2d-33         [-1, 64, 128, 128]             128
             ReLU-34         [-1, 64, 128, 128]               0
  ConvTranspose2d-35         [-1, 32, 256, 256]          18,464
   InstanceNorm2d-36         [-1, 32, 256, 256]              64
             ReLU-37         [-1, 32, 256, 256]               0
  ConvTranspose2d-38          [-1, 3, 256, 256]           7,779
================================================================
Total params: 1,676,675
Trainable params: 1,676,675
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.75
Forward/backward pass size (MB): 246.85
Params size (MB): 6.40
Estimated Total Size (MB): 253.99
----------------------------------------------------------------
# Output network structure
from torchviz import make_dot

x = torch.randn(1, 3, 256, 256).requires_grad_(True)
y = myfwnet(x.to(device))
myResNet_vis = make_dot(y, params=dict(list(myfwnet.named_parameters()) + [('x', x)]))
myResNet_vis


Fast style transfer data preparation

Download address: https://cocodataset.org/#home

The COCO 2014 validation set is used as the content images for training.

# Define image preprocessing
data_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256), # The image size is 256 * 256
    transforms.ToTensor(), # Convert to a tensor with values in [0, 1]
    transforms.Normalize(mean = [0.485, 0.456, 0.406],
                         std = [0.229, 0.224, 0.225]) 
                         # Normalized pixel values lie roughly in [-2.1, 2.7]
])
# Read data from folder
dataset = ImageFolder('./data/COCO', transform = data_transform)
# Each batch uses 4 images
data_loader = Data.DataLoader(dataset, batch_size = 4, shuffle = True,
                              num_workers = 8, pin_memory = True)
dataset
Dataset ImageFolder
    Number of datapoints: 40504
    Root location: ./data/COCO
    StandardTransform
Transform: Compose(
               Resize(size=256, interpolation=bilinear)
               CenterCrop(size=(256, 256))
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )

Description: the pin_memory argument tells the DataLoader to place the generated tensors in page-locked (pinned) host memory. GPU memory is always page-locked, so copying tensors from pinned host memory to GPU memory is faster, which helps when feeding a fast GPU.
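
As a small illustrative sketch (not part of the training code below), pinned-memory batches also allow asynchronous host-to-GPU copies when non_blocking=True is passed to .to():

# Sketch: with pin_memory = True, the host-to-device copy can be made asynchronous
for images, _ in data_loader:
    images = images.to(device, non_blocking = True)  # asynchronous copy from pinned memory
    break  # only demonstrate the transfer for a single batch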

Next, load the pretrained VGG16 network. Only the layers in its features module are needed, and they are moved to the GPU device. Since the VGG network is only used to extract feature maps from specific layers and its parameters are not trained, it is set to eval mode.

# Read pre trained VGG16 network
vgg16 = models.vgg16(pretrained = True)
# No classifier is needed, only convolution layer and pooling layer are needed
vgg = vgg16.features.to(device).eval()
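
Since the VGG parameters are never updated, one can additionally (this is an optional sketch, not in the original code) turn off their gradients to save a little memory during backpropagation:

# Optional: disable gradients for the fixed VGG feature extractor
for param in vgg.parameters():
    param.requires_grad_(False)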

Define a method that can read style images and convert them to a four-dimensional tensor format that can be used by VGG networks.

# Define a reading style image function and convert the image as necessary
def load_image(img_path, shape = None):
    image = Image.open(img_path)
    size = image.size
    if shape is not None:
        size = shape # If the image size is specified, it is converted to the specified size
    # Use transforms to convert the image into tensor and standardize it
    in_transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(), # Tensor converted to 0-1
        transforms.Normalize(mean = [0.485, 0.456, 0.406],
                            std = [0.229, 0.224, 0.225])
    ])
    # Use the RGB channel of the image and add the batch dimension
    image = in_transform(image)[:3, :, :].unsqueeze(dim = 0)
    return image
# Define a function that converts a normalized image tensor into an array matplotlib can display
def im_convert(tensor):
    '''
    Convert a tensor of shape [1, c, h, w] into an array of shape [h, w, c].
    Since the tensor was normalized, the normalization also has to be inverted.
    '''
    tensor = tensor.cpu()
    image = tensor.data.numpy().squeeze() # Remove the batch dimension
    image = image.transpose(1, 2, 0) # Permute dimensions [c, h, w] -> [h, w, c]
    # Invert the normalization
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    image = image.clip(0, 1) # Clip the image values to [0, 1]
    return image

Read style image and visualize

# Read style image
style = load_image('./data/COCO/COCO/COCO_val2014_000000000139.jpg', shape = (256, 256)).to(device)
# Visual image
plt.figure()
plt.imshow(im_convert(style))
plt.axis('off')
plt.show()


Fast style transfer network training and visualization

As with ordinary style transfer, the style is represented by Gram matrices of feature maps. First, define a function that computes the Gram matrix of an input tensor.
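
For a feature map reshaped into a matrix F with c rows and h·w columns, the normalized Gram matrix computed by this function is

$$G = \frac{1}{c\,h\,w}\,F F^{\top},\qquad G_{ij} = \frac{1}{c\,h\,w}\sum_{k} F_{ik} F_{jk}.$$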

# Define the Gram matrix computation
def gram_matrix(tensor):
    '''
    Compute the Gram matrix that represents the style features of an image.
    tensor: a feature map from one layer of the forward pass (for a batch of images)
    '''
    # Obtain the batch size, channels, height and width of the tensor
    b, c, h, w = tensor.size()
    # Reshape the tensor to (batch, channels, height * width)
    tensor = tensor.view(b, c, h * w)
    tensor_t = tensor.transpose(1, 2)
    # Compute the Gram matrix of each image in the batch
    gram = tensor.bmm(tensor_t) / (c * h * w)
    return gram

Note that because the input is a batch of feature maps, the Gram matrix has to be computed for each image separately when multiplying the tensor by its transpose, so the tensor's bmm (batch matrix multiplication) method is used for the matrix products.
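
As a quick consistency check (a sketch with made-up shapes matching the network's 128 × 64 × 64 feature maps), the batched result agrees with computing the Gram matrix of each image on its own:

# Sketch: the batched bmm-based Gram matrix matches a per-image computation
feats = torch.randn(4, 128, 64, 64)               # a batch of 4 feature maps
gram_batched = gram_matrix(feats)                 # shape [4, 128, 128]
gram_first = gram_matrix(feats[0:1])              # Gram matrix of the first image only
print(torch.allclose(gram_batched[0:1], gram_first))  # True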

Define get_features to obtain the feature maps of an image at the specified layers of a given network:

# Defines a method for obtaining the output of an image at a specified layer on the network
def get_features(image, model, layers = None):
    '''
    Propagate an image forward through the network model
    and collect the feature outputs of the layers specified in layers
    '''
    # Match the name of the mapping layer with the name in the paper
    if layers is None:
        layers = {'3': 'relu1_2',
                  '8': 'relu2_2',
                  '15': 'relu3_3', # Content layer representation
                  '22': 'relu4_3'} # Output after ReLU activation
    features = {} # The obtained features of each layer are saved in the dictionary
    x = image # The image of the feature needs to be acquired
    # model._modules is a dictionary that holds the information of each layer of the network model
    for name, layer in model._modules.items():
        # The features of the image are obtained from the first layer
        x = layer(x)
        # If it is the feature specified by the layers parameter, it is saved to features
        if name in layers:
            features[layers[name]] = x
    return features

Among these, the feature map output by the relu3_3 layer is used to measure the similarity of image content.

Next, compute the style image's four Gram matrices on the specified layers and save them in a dictionary.

# Style representation of computational style image
style_layer = {'3': 'relu1_2',
               '8': 'relu2_2',
               '15': 'relu3_3',
               '22': 'relu4_3'}
content_layer = {'15': 'relu3_3'}
# The layers represented by content use the output after relu activation
style_features = get_features(style, vgg, layers = style_layer)
# Calculate the Gram matrix of each layer for our style representation and save it in a dictionary
style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}

Next, train the network. Three losses are defined during training: style loss, content loss, and total variation loss, weighted by 10^5, 1, and 10^-5 respectively. The optimizer is Adam with a learning rate of 1e-3. With roughly 40,000 images and a batch size of 4, training runs for 4 epochs, i.e. about 40,000 iterations.
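
Written out with these weights, the objective minimized in the code below is

$$\mathcal{L} = 10^{5}\,\mathcal{L}_{\text{style}} + \mathcal{L}_{\text{content}} + 10^{-5}\,\mathcal{L}_{\text{TV}},\qquad \mathcal{L}_{\text{TV}}(\hat{y}) = \sum\bigl|\hat{y}_{i,j+1}-\hat{y}_{i,j}\bigr| + \sum\bigl|\hat{y}_{i+1,j}-\hat{y}_{i,j}\bigr|,$$

where the TV sums run over all channels and pixels, the content loss is the MSE between the relu3_3 features of \hat{y} and the content image, and the style loss sums the MSE between the Gram matrices of \hat{y} and the style image over the four style layers.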

# Network training, define the weight of three kinds of losses
style_weight = 1e5
content_weight = 1
tv_weight = 1e-5
# Define optimizer
optimizer = optim.Adam(myfwnet.parameters(), lr = 1e-3)
myfwnet.train()
since = time.time()
for epoch in range(4):
    print('Epoch: {}'.format(epoch + 1))
    content_loss_all = []
    style_loss_all = []
    tv_loss_all = []
    all_loss = []
    for step, batch in enumerate(data_loader):
        optimizer.zero_grad()

        # Pass the content images through the image transformation network
        content_images = batch[0].to(device)
        transformed_images = myfwnet(content_images)
        transformed_images = transformed_images.clamp(-2.1, 2.7)

        # Use VGG16 to compute the content-layer features of the original images
        content_features = get_features(content_images, vgg, layers = content_layer)

        # Use VGG16 to compute all layer features of the transformed images \hat{y}
        transformed_features = get_features(transformed_images, vgg)

        # Content loss
        # F.mse_loss between the relu3_3 features of the transformed and content images
        content_loss = F.mse_loss(transformed_features['relu3_3'], content_features['relu3_3'])
        content_loss = content_weight * content_loss

        # Total variation loss
        # The image is horizontally and vertically shifted by one pixel and subtracted from the original image
        # Then calculate the sum of the absolute values, which is tv_loss
        y = transformed_images # \hat{y}
        tv_loss = torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])) + torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :]))
        tv_loss = tv_weight * tv_loss

        # Style loss
        style_loss = 0
        transformed_grams = {layer: gram_matrix(transformed_features[layer]) for layer in transformed_features}
        for layer in style_grams:
            transformed_gram = transformed_grams[layer]
            # transformed_gram holds the Gram matrices for the whole batch of images
            style_gram = style_grams[layer]
            # style_gram is for a single image, so it is expanded to the batch size
            # before computing the loss between transformed_gram and style_gram
            style_loss += F.mse_loss(transformed_gram,
                                style_gram.expand_as(transformed_gram))
        style_loss = style_weight * style_loss

        # The three losses add up and the gradient decreases
        loss = style_loss + content_loss + tv_loss
        loss.backward(retain_graph = True)
        optimizer.step()

        # Count the changes of each loss
        content_loss_all.append(content_loss.item())
        style_loss_all.append(style_loss.item())
        tv_loss_all.append(tv_loss.item())
        all_loss.append(loss.item())
        if step % 5000 == 0:
            print('step: {}; content loss: {:.3f}; style loss: {:.3f}; tv loss: {:.3f}, loss: {:.3f}'.format(step, content_loss.item(), style_loss.item(), tv_loss.item(), loss.item()))
            time_use = time.time() - since
            print('Train complete in {:.0f}m {:.0f}s'.format(time_use // 60, time_use % 60))
            # Visualize an image
            plt.figure()
            im = transformed_images[1, ...] # The ellipsis selects all remaining dimensions (channel, height, width)
            plt.axis('off')
            plt.imshow(im_convert(im))
            plt.show()
Epoch: 1
step: 0; content loss: 21.736; style loss: 679.825; tv loss: 17.357, loss: 718.918
Train complete in 0m 10s

step: 5000; content loss: 11.223; style loss: 4.921; tv loss: 1.068, loss: 17.212
Train complete in 32m 21s

step: 10000; content loss: 10.715; style loss: 3.768; tv loss: 1.101, loss: 15.584
Train complete in 64m 34s

Epoch: 2
step: 0; content loss: 12.664; style loss: 3.324; tv loss: 1.182, loss: 17.170
Train complete in 65m 40s

step: 5000; content loss: 5.582; style loss: 3.621; tv loss: 1.234, loss: 10.438
Train complete in 97m 55s

step: 10000; content loss: 5.797; style loss: 3.302; tv loss: 1.209, loss: 10.308
Train complete in 130m 11s

Epoch: 3
step: 0; content loss: 4.639; style loss: 3.312; tv loss: 1.250, loss: 9.201
Train complete in 131m 16s

step: 5000; content loss: 4.507; style loss: 3.565; tv loss: 1.291, loss: 9.364
Train complete in 163m 32s

step: 10000; content loss: 4.570; style loss: 3.609; tv loss: 1.098, loss: 9.276
Train complete in 195m 48s

Epoch: 4
step: 0; content loss: 4.425; style loss: 2.844; tv loss: 1.239, loss: 8.509
Train complete in 196m 46s

step: 5000; content loss: 6.227; style loss: 4.176; tv loss: 1.231, loss: 11.633
Train complete in 229m 2s

step: 10000; content loss: 4.537; style loss: 3.191; tv loss: 1.178, loss: 8.906
Train complete in 261m 19s

# Save the trained network myfwnet
torch.save(myfwnet.state_dict(), './model/imfwnet_dict.pkl')

To test the trained style transfer network myfwnet, a random batch of images is taken from the dataset below and passed through it for style transfer:

myfwnet.eval()
for step, batch in enumerate(data_loader):
    content_images = batch[0].to(device)
    if step > 0:
        break
plt.figure(figsize = (16, 4))
for ii in range(4):
    im = content_images[ii, ...]
    plt.subplot(1, 4, ii + 1)
    plt.axis('off')
    plt.imshow(im_convert(im))
plt.show()
transformed_images = myfwnet(content_images)
transformed_images = transformed_images.clamp(-2.1, 2.7)
plt.figure(figsize = (16, 4))
for ii in range(4):
    im = im_convert(transformed_images[ii, ...])
    plt.subplot(1, 4, ii + 1)
    plt.axis('off')
    plt.imshow(im)
plt.show()



Using the GPU-trained model on the CPU

# Read content image
content = load_image('./data/COCO/COCO/COCO_val2014_000000000192.jpg', shape = (256, 256))
# Import the trained GPU network
device = torch.device('cpu')
newfwnet = ImfwNet()
newfwnet.load_state_dict(torch.load('./model/imfwnet_dict.pkl', map_location = device)) # The GPU model is mapped to a network based on CPU computing
transform_content = newfwnet(content)
# Visual image
plt.figure()
plt.subplot(1, 2, 1)
plt.imshow(im_convert(content))
plt.axis('off')
plt.subplot(1, 2, 2)
plt.imshow(im_convert(transform_content))
plt.axis('off')
plt.show()


Generally speaking, ordinary style transfer takes a long time for each image (up to several hours), but the quality of the transferred style is good.

Fast style transfer is very fast at inference time (the network is trained beforehand, which is an offline process), but the results are somewhat less ideal.

Keywords: AI neural networks PyTorch Computer Vision Deep Learning
