# PyTorch: fast image style transfer

### PyTorch: fast image style transfer - residual network, fixed style and arbitrary content

Copyright: Jingmin Wei, Pattern Recognition and Intelligent System, School of Artificial Intelligence and Automation, Huazhong University of Science and Technology

This tutorial is not for commercial use. It is only for learning and reference. If you need to reprint it, please contact me.

#### Reference

Perceptual Losses for Real-Time Style Transfer and Super-Resolution

ResNet

In ordinary style transfer, the image being optimized starts out as random noise and is updated directly. In fast style transfer, the stylized image is instead the output of an image transformation network $f_W$.

Fast style transfer feeds the input image $x$ through the image transformation network $f_W$ to obtain the network output $\hat{y}$. Once the network has been trained for a fixed style, it can therefore stylize images of arbitrary content quickly.

Following the paper Perceptual Losses for Real-Time Style Transfer and Super-Resolution, the upsampling operation of the image transformation network is adjusted: in the network built here, transposed convolutions are used to upsample the feature maps.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import time

import torch
import torch.nn as nn
import torch.utils.data as Data
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision import models

import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device('cpu')
print(device)
print(torch.cuda.device_count())
print(torch.cuda.get_device_name(0))

cuda
1
GeForce MX250


#### Fast style transfer network preparation

First, $3$ convolution layers reduce the size of the image feature maps; then $5$ residual blocks learn the image style and add it to the content image; finally, $3$ transposed convolution layers increase the size of the feature maps again (analogous to a semantic segmentation network) to reconstruct the stylized image.

In the upsampling part of the transformation network, transposed convolutions replace the combination of upsampling and convolution layers used in the paper. Because the input is a standardized image, its pixel values range from about $-2.1$ to $2.7$, so no activation function is used in the last output layer of the network. Most output values will fall between $-2.1$ and $2.7$ and only a small part will lie outside this interval, so during training the output is clamped to the range $-2.1$ to $2.7$. In other words, the last layer uses no activation function, while the other layers use ReLU. In the network, the number of feature maps gradually increases from $3$ to $128$ in the downsampling layers, each residual block has $128$ feature maps, and the transposed convolution layers reduce the number of feature maps from $128$ back to $3$, matching the three channels of an image.
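As a quick check of the value range quoted above, the following sketch computes the per-channel limits that result from standardizing pixels in $[0, 1]$ with the ImageNet mean and standard deviation used by this tutorial's preprocessing. The variable names are illustrative only.

```python
# ImageNet channel statistics used by transforms.Normalize in this tutorial
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# A pixel value p in [0, 1] is mapped to (p - mean) / std, channel by channel
lows = [(0.0 - m) / s for m, s in zip(mean, std)]
highs = [(1.0 - m) / s for m, s in zip(mean, std)]

for name, low, high in zip("RGB", lows, highs):
    print(f"{name}: [{low:.3f}, {high:.3f}]")

overall_low, overall_high = min(lows), max(highs)
print(f"overall: [{overall_low:.3f}, {overall_high:.3f}]")
```

The overall interval is roughly $[-2.12, 2.64]$, which is where the rounded range $-2.1$ to $2.7$ comes from.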

##### Define the residual block

Consider a local part of a neural network with input $x$. Suppose the ideal mapping we want to learn is $f(x)$, which serves as the input to the activation function. The residual branch then only needs to fit the residual mapping $f(x) - x$ relative to the identity mapping. Residual mappings are often easier to optimize in practice. Take the identity mapping as the ideal mapping $f(x)$: we only need to learn weights and biases of $0$ for the weighted operations (such as affine layers) in the branch, and $f(x)$ becomes the identity mapping. In practice, when the ideal mapping $f(x)$ is very close to the identity mapping, the residual mapping also captures small deviations from the identity easily. In a residual block, the input can also propagate forward faster through the cross-layer skip connection.

Define the residual block with $128$ feature maps and activation size $128\times64\times64$:

# Residual block
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size = 3, stride = 1, padding = 1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size = 3, stride = 1, padding = 1)
        )
    def forward(self, x):
        # Add the input back to the output of the residual branch
        return F.relu(self.conv(x) + x)
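The claim that zero weights and biases turn the block into an identity mapping can be checked directly. The sketch below repeats the block structure so that it runs on its own, zeroes all parameters, and feeds a non-negative input (as produced by a preceding ReLU). The channel count and input size are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """Same structure as the residual block in this tutorial."""
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        return F.relu(self.conv(x) + x)

block = ResidualBlock(8)
# Set every weight and bias to zero so the residual branch outputs 0
with torch.no_grad():
    for p in block.parameters():
        p.zero_()

# Non-negative input, as it would come out of a preceding ReLU
x = torch.rand(2, 8, 16, 16)
y = block(x)
print(torch.equal(y, x))  # True when the block acts as the identity
```

For inputs that contain negative values the final ReLU would still clip them, so the identity property holds only for non-negative activations.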

##### Define the image transformation network

The network consists of a downsampling module, $5$ residual blocks and an upsampling module.

# Define the image transformation network
class ImfwNet(nn.Module):
    def __init__(self):
        super(ImfwNet, self).__init__()
        # Downsampling
        self.downsample = nn.Sequential(
            nn.ReflectionPad2d(padding = 4), # Pad so the 9x9 convolution keeps the size (see network summary)
            nn.Conv2d(3, 32, kernel_size = 9, stride = 1),
            nn.InstanceNorm2d(32, affine = True), # Normalize each feature map per image
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size = 3, stride = 2),
            nn.InstanceNorm2d(64, affine = True),
            nn.ReLU(),
            nn.ReflectionPad2d(padding = 1), # 127 -> 129, as in the network summary
            nn.Conv2d(64, 128, kernel_size = 3, stride = 2),
            nn.InstanceNorm2d(128, affine = True),
            nn.ReLU()
        )
        # 5 residual blocks
        self.res_blocks = nn.Sequential(
            ResidualBlock(128),
            ResidualBlock(128),
            ResidualBlock(128),
            ResidualBlock(128),
            ResidualBlock(128),
        )
        # Upsampling
        self.unsample = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size = 3, stride = 2, padding = 1, output_padding = 1),
            nn.InstanceNorm2d(64, affine = True),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size = 3, stride = 2, padding = 1, output_padding = 1),
            nn.InstanceNorm2d(32, affine = True),
            nn.ReLU(),
            # No activation after the last layer
            nn.ConvTranspose2d(32, 3, kernel_size = 9, stride = 1, padding = 4)
        )
    def forward(self, x):
        x = self.downsample(x) # Input pixel values lie between -2.1 and 2.7
        x = self.res_blocks(x)
        x = self.unsample(x) # Output pixel values lie mostly between -2.1 and 2.7
        return x

myfwnet = ImfwNet().to(device)

from torchsummary import summary
summary(myfwnet, input_size=(3, 256, 256))

----------------------------------------------------------------
Layer (type)               Output Shape         Param #
================================================================
ReflectionPad2d-1          [-1, 3, 264, 264]               0
Conv2d-2         [-1, 32, 256, 256]           7,808
InstanceNorm2d-3         [-1, 32, 256, 256]              64
ReLU-4         [-1, 32, 256, 256]               0
Conv2d-5         [-1, 64, 127, 127]          18,496
InstanceNorm2d-6         [-1, 64, 127, 127]             128
ReLU-7         [-1, 64, 127, 127]               0
ReflectionPad2d-8         [-1, 64, 129, 129]               0
Conv2d-9          [-1, 128, 64, 64]          73,856
InstanceNorm2d-10          [-1, 128, 64, 64]             256
ReLU-11          [-1, 128, 64, 64]               0
Conv2d-12          [-1, 128, 64, 64]         147,584
ReLU-13          [-1, 128, 64, 64]               0
Conv2d-14          [-1, 128, 64, 64]         147,584
ResidualBlock-15          [-1, 128, 64, 64]               0
Conv2d-16          [-1, 128, 64, 64]         147,584
ReLU-17          [-1, 128, 64, 64]               0
Conv2d-18          [-1, 128, 64, 64]         147,584
ResidualBlock-19          [-1, 128, 64, 64]               0
Conv2d-20          [-1, 128, 64, 64]         147,584
ReLU-21          [-1, 128, 64, 64]               0
Conv2d-22          [-1, 128, 64, 64]         147,584
ResidualBlock-23          [-1, 128, 64, 64]               0
Conv2d-24          [-1, 128, 64, 64]         147,584
ReLU-25          [-1, 128, 64, 64]               0
Conv2d-26          [-1, 128, 64, 64]         147,584
ResidualBlock-27          [-1, 128, 64, 64]               0
Conv2d-28          [-1, 128, 64, 64]         147,584
ReLU-29          [-1, 128, 64, 64]               0
Conv2d-30          [-1, 128, 64, 64]         147,584
ResidualBlock-31          [-1, 128, 64, 64]               0
ConvTranspose2d-32         [-1, 64, 128, 128]          73,792
InstanceNorm2d-33         [-1, 64, 128, 128]             128
ReLU-34         [-1, 64, 128, 128]               0
ConvTranspose2d-35         [-1, 32, 256, 256]          18,464
InstanceNorm2d-36         [-1, 32, 256, 256]              64
ReLU-37         [-1, 32, 256, 256]               0
ConvTranspose2d-38          [-1, 3, 256, 256]           7,779
================================================================
Total params: 1,676,675
Trainable params: 1,676,675
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.75
Forward/backward pass size (MB): 246.85
Params size (MB): 6.40
Estimated Total Size (MB): 253.99
----------------------------------------------------------------
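The output shapes and the parameter total in the summary can be reproduced with plain arithmetic. The sketch below uses the standard output-size formulas for convolution and transposed convolution (dilation 1). The padding amounts of the two ReflectionPad2d layers are inferred from the shapes listed in the summary (264 - 256 and 129 - 127), and symmetric padding is an assumption. The helper names are illustrative.

```python
def conv_out(n, kernel, stride=1, padding=0):
    """Spatial output size of a convolution layer (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1

def conv_transpose_out(n, kernel, stride=1, padding=0, output_padding=0):
    """Spatial output size of a transposed convolution layer (dilation 1)."""
    return (n - 1) * stride - 2 * padding + kernel + output_padding

def conv_params(c_in, c_out, kernel):
    """Weights plus biases of a (transposed) convolution layer."""
    return c_in * c_out * kernel * kernel + c_out

trace = [256]                                              # input height/width
trace.append(trace[-1] + 2 * 4)                            # ReflectionPad2d-1
trace.append(conv_out(trace[-1], 9))                       # Conv2d-2
trace.append(conv_out(trace[-1], 3, stride=2))             # Conv2d-5
trace.append(trace[-1] + 2 * 1)                            # ReflectionPad2d-8
trace.append(conv_out(trace[-1], 3, stride=2))             # Conv2d-9
trace.append(conv_out(trace[-1], 3, padding=1))            # convs in the residual blocks
trace.append(conv_transpose_out(trace[-1], 3, 2, 1, 1))    # ConvTranspose2d-32
trace.append(conv_transpose_out(trace[-1], 3, 2, 1, 1))    # ConvTranspose2d-35
trace.append(conv_transpose_out(trace[-1], 9, 1, 4))       # ConvTranspose2d-38
print(trace)

# (in_channels, out_channels, kernel) of every convolution in the network
conv_layers = ([(3, 32, 9), (32, 64, 3), (64, 128, 3)]
               + [(128, 128, 3)] * 10
               + [(128, 64, 3), (64, 32, 3), (32, 3, 9)])
# InstanceNorm2d with affine=True has one scale and one shift per channel
norm_channels = [32, 64, 128, 64, 32]
total_params = (sum(conv_params(*layer) for layer in conv_layers)
                + sum(2 * c for c in norm_channels))
print(total_params)
```

The trace ends at 256, so the stylized output has the same size as the input, which the content loss relies on.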

# Output network structure
from torchviz import make_dot

x = torch.randn(1, 3, 256, 256).requires_grad_(True)
y = myfwnet(x.to(device))
myResNet_vis = make_dot(y, params=dict(list(myfwnet.named_parameters()) + [('x', x)]))
myResNet_vis


#### Fast style transfer data preparation

The validation set of COCO2014 is used as model input.

# Define image preprocessing
data_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256), # The image size becomes 256 * 256
    transforms.ToTensor(), # Convert to a tensor with values in 0-1
    transforms.Normalize(mean = [0.485, 0.456, 0.406],
                         std = [0.229, 0.224, 0.225])
    # Pixel values now lie roughly between -2.1 and 2.7
])

# Read data from folder
dataset = ImageFolder('./data/COCO', transform = data_transform)
# Each batch uses 4 images
data_loader = Data.DataLoader(dataset, batch_size = 4, shuffle = True,
                              num_workers = 8, pin_memory = True)
dataset

Dataset ImageFolder
Number of datapoints: 40504
Root location: ./data/COCO
StandardTransform
Transform: Compose(
Resize(size=256, interpolation=bilinear)
CenterCrop(size=(256, 256))
ToTensor()
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)


Note: with pin_memory = True, the tensors produced by the DataLoader are first placed in page-locked (pinned) host memory (GPU memory itself is always page-locked). Transferring tensors from pinned host memory to GPU memory is faster, and the benefit is largest on high-performance GPUs.
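A minimal sketch of how pinned memory is typically combined with an asynchronous copy is shown below. It uses a small random TensorDataset as a hypothetical stand-in for the image folder, and enables pinning only when a CUDA device is present, so it also runs on a CPU-only machine.

```python
import torch
import torch.utils.data as Data

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Small random stand-in for the image dataset (hypothetical data)
fake_images = torch.randn(8, 3, 32, 32)
fake_labels = torch.zeros(8, dtype=torch.long)
toy_dataset = Data.TensorDataset(fake_images, fake_labels)

# Pinning is only useful when batches are later copied to a GPU
toy_loader = Data.DataLoader(toy_dataset, batch_size=4, shuffle=True,
                             num_workers=0, pin_memory=use_cuda)

for images, labels in toy_loader:
    # With pinned source memory, non_blocking=True allows an asynchronous copy
    images = images.to(device, non_blocking=True)
    print(images.shape, images.device)
```

On a CPU-only machine the transfer is a no-op, so the speed difference only shows up when a GPU is available.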

Next, load the pretrained VGG16 network. Only the layers contained in features are needed, and they are moved to the GPU device. VGG is only used to extract the feature maps of specific layers and its parameters are not trained, so it is set to eval mode.

# Read pre trained VGG16 network
vgg16 = models.vgg16(pretrained = True)
# No classifier is needed, only convolution layer and pooling layer are needed
vgg = vgg16.features.to(device).eval()


Define a method that can read style images and convert them to a four-dimensional tensor format that can be used by VGG networks.

# Define a function that reads the style image and applies the necessary conversions
def load_image(img_path, shape = None):
    image = Image.open(img_path)
    size = image.size
    if shape is not None:
        size = shape # If an image size is specified, resize to it
    # Use transforms to convert the image into a tensor and standardize it
    in_transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(), # Convert to a tensor with values in 0-1
        transforms.Normalize(mean = [0.485, 0.456, 0.406],
                             std = [0.229, 0.224, 0.225])
    ])
    # Keep the RGB channels of the image and add the batch dimension
    image = in_transform(image)[:3, :, :].unsqueeze(dim = 0)
    return image

# Define a function that converts a standardized image tensor into an array that matplotlib can display
def im_convert(tensor):
    '''
    Convert a tensor of shape [1, c, h, w] (or [c, h, w]) into an array of shape [h, w, c].
    Because the tensor was standardized, the inverse of the standardization
    has to be applied as well.
    '''
    tensor = tensor.cpu()
    image = tensor.detach().numpy().squeeze() # Remove the batch dimension
    image = image.transpose(1, 2, 0) # Permute dimensions [C, H, W] -> [H, W, C]
    # Undo the standardization
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    image = image.clip(0, 1) # Clip the image values to 0-1
    return image
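The inverse standardization can be verified with a round trip on a small random array. This is a NumPy-only sketch: the forward step imitates what ToTensor followed by Normalize does to the layout and values, and the array size is arbitrary.

```python
import numpy as np

mean = np.array((0.485, 0.456, 0.406))
std = np.array((0.229, 0.224, 0.225))

rng = np.random.default_rng(0)
image_hwc = rng.random((4, 5, 3))            # values in [0, 1), layout [H, W, C]

# Forward: standardize per channel, then move channels first ([C, H, W])
normalized_chw = ((image_hwc - mean) / std).transpose(2, 0, 1)

# Backward: channels last again, undo the standardization, clip to [0, 1]
restored = normalized_chw.transpose(1, 2, 0) * std + mean
restored = restored.clip(0, 1)

print(normalized_chw.shape, restored.shape)
print(np.allclose(restored, image_hwc))
```

Broadcasting against the last axis only works in the [H, W, C] layout, which is why the transpose has to come before the multiplication by the standard deviation.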


# Read style image
style = load_image('./data/COCO/COCO/COCO_val2014_000000000139.jpg', shape = (256, 256)).to(device)
# Visual image
plt.figure()
plt.imshow(im_convert(style))
plt.axis('off')
plt.show()


#### Fast style transfer network training and visualization

As in ordinary style transfer, first define the computation of the Gram matrix of an input tensor:

# Define the computation of the Gram matrix
def gram_matrix(tensor):
    '''
    Compute the Gram matrix of the image style features, which allows the style
    to be transferred while the content is preserved.
    tensor: the feature maps of one layer obtained by a forward pass of the images
    '''
    # Get the batch_size, channel, height and width of the tensor
    b, c, h, w = tensor.size()
    # Reshape the tensor to (batch, channel, height * width)
    tensor = tensor.view(b, c, h * w)
    tensor_t = tensor.transpose(1, 2)
    # Compute the Gram matrix of every image in the batch
    gram = tensor.bmm(tensor_t) / (c * h * w)
    return gram


Note that the input is a batch of feature maps, so the Gram matrix must be computed separately for each image when the tensor is multiplied by its transpose. The tensor.bmm method (batched matrix multiplication) performs this computation.
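To illustrate what the batched product does, the sketch below compares the bmm-based Gram matrix with a reference that processes one image at a time using an ordinary matrix product. The feature tensor is random and its size is an arbitrary choice for illustration.

```python
import torch

def gram_matrix(tensor):
    """Batched Gram matrix, following the definition in this tutorial."""
    b, c, h, w = tensor.size()
    flat = tensor.view(b, c, h * w)
    return flat.bmm(flat.transpose(1, 2)) / (c * h * w)

features = torch.randn(4, 8, 5, 5)       # hypothetical batch of feature maps
batched = gram_matrix(features)

# Reference: one image at a time with an ordinary matrix product
per_image = torch.stack([
    f.view(8, -1).mm(f.view(8, -1).t()) / (8 * 5 * 5) for f in features
])

print(batched.shape)
print(torch.allclose(batched, per_image, atol=1e-6))
```

Each of the 4 images gets its own 8 x 8 matrix of channel correlations, and each matrix is symmetric.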

Define get_features to obtain the feature maps of an image at the specified layers of a given network:

# Define a function that returns the outputs of an image at the specified layers of a network
def get_features(image, model, layers = None):
    '''
    Forward the image through the network model and
    collect the feature outputs of the layers given in layers
    '''
    # Map the layer indices to the layer names used in the paper
    if layers is None:
        layers = {'3': 'relu1_2',
                  '8': 'relu2_2',
                  '15': 'relu3_3', # Content layer representation
                  '22': 'relu4_3'} # Outputs after ReLU activation
    features = {} # The collected features of each layer are saved in a dictionary
    x = image # The image whose features are extracted
    # model._modules is a dictionary that holds each layer of the network model
    for name, layer in model._modules.items():
        # Pass the image through the layers one by one, starting from the first
        x = layer(x)
        # If the layer is listed in layers, save its output to features
        if name in layers:
            features[layers[name]] = x
    return features


Here, the feature map output by the relu3_3 layer is used to measure the similarity of image content.
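The layer-by-layer collection can be tried without downloading VGG16. The sketch below applies the same loop to a tiny sequential network with random weights. The network, the layer names relu_a and relu_b, and the input size are hypothetical, and named_children() is used as the public counterpart of iterating over _modules.items().

```python
import torch
import torch.nn as nn

# Tiny stand-in for vgg16.features (hypothetical layers, random weights)
toy_model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),   # name '0'
    nn.ReLU(),                                   # name '1'
    nn.MaxPool2d(2),                             # name '2'
    nn.Conv2d(4, 8, kernel_size=3, padding=1),   # name '3'
    nn.ReLU(),                                   # name '4'
)
# Map the index of a layer inside the Sequential to a readable name
layers = {'1': 'relu_a', '4': 'relu_b'}

features = {}
x = torch.randn(1, 3, 16, 16)
for name, layer in toy_model.named_children():
    x = layer(x)                 # run the layers one after another
    if name in layers:           # keep only the requested outputs
        features[layers[name]] = x

for key, value in features.items():
    print(key, tuple(value.shape))
```

The keys of the layers dictionary are strings because a Sequential names its children '0', '1', '2', and so on.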

Next, compute the Gram matrices of the style image on the $4$ specified layers and save them in a dictionary:

# Compute the style representation of the style image
style_layer = {'3': 'relu1_2',
               '8': 'relu2_2',
               '15': 'relu3_3',
               '22': 'relu4_3'}
content_layer = {'15': 'relu3_3'}
# The content layer uses the output after ReLU activation
style_features = get_features(style, vgg, layers = style_layer)
# Compute the Gram matrix of each style layer and save it in a dictionary
style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}


Next, start training the network. Three losses are defined during training: style loss, content loss and total variation loss, with weights $10^5$, $1$ and $10^{-5}$ respectively. The optimizer is Adam, and the code below sets its learning rate to $10^{-3}$. With more than $40000$ images and batches of $4$ images, the network is trained for $4$ epochs, i.e. about $40000$ iterations.
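Of the three losses, the total variation loss is the least familiar: it sums the absolute differences between horizontally and vertically adjacent pixels, so it penalizes noisy outputs. The sketch below evaluates it on two tiny hand-made tensors, which are hypothetical inputs chosen so the result is easy to verify by hand.

```python
import torch

def total_variation(y):
    """Sum of absolute differences between neighbouring pixels of y [B, C, H, W]."""
    horizontal = torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:]))
    vertical = torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :]))
    return horizontal + vertical

# 4 x 4 single-channel ramp with values 0..15, and a constant image
ramp = torch.arange(16, dtype=torch.float32).view(1, 1, 4, 4)
flat = torch.ones(1, 1, 4, 4)

tv_ramp = total_variation(ramp).item()
tv_flat = total_variation(flat).item()
print(tv_ramp, tv_flat)
```

In the ramp, the 12 horizontal neighbour pairs differ by 1 and the 12 vertical pairs differ by 4, giving 12 + 48 = 60, while the constant image has zero total variation.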

# Network training: define the weights of the three losses
style_weight = 1e5
content_weight = 1
tv_weight = 1e-5
# Define optimizer
optimizer = optim.Adam(myfwnet.parameters(), lr = 1e-3)

myfwnet.train()
since = time.time()
for epoch in range(4):
    print('Epoch: {}'.format(epoch + 1))
    content_loss_all = []
    style_loss_all = []
    tv_loss_all = []
    all_loss = []
    # data_loader is the DataLoader over dataset that yields batches of 4 images
    for step, batch in enumerate(data_loader):
        optimizer.zero_grad() # Clear the gradients of the previous step

        # Compute the output of the image transformation network for the content images
        content_images = batch[0].to(device)
        transformed_images = myfwnet(content_images)
        transformed_images = transformed_images.clamp(-2.1, 2.7)

        # Use VGG16 to compute the content_layer features of the original images
        content_features = get_features(content_images, vgg, layers = content_layer)

        # Use VGG16 to compute all features of the \hat{y} images
        transformed_features = get_features(transformed_images, vgg)

        # Content loss
        # Use F.mse_loss to compute the loss between the relu3_3 features
        # of transformed_images and content_images
        content_loss = F.mse_loss(transformed_features['relu3_3'], content_features['relu3_3'])
        content_loss = content_weight * content_loss

        # Total variation loss
        # Shift the image by one pixel horizontally and vertically, subtract it from
        # the original image and sum the absolute values to obtain tv_loss
        y = transformed_images # \hat{y}
        tv_loss = torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])) + torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :]))
        tv_loss = tv_weight * tv_loss

        # Style loss
        style_loss = 0
        transformed_grams = {layer: gram_matrix(transformed_features[layer]) for layer in transformed_features}
        for layer in style_grams:
            # Gram matrices of the whole batch of transformed images
            transformed_gram = transformed_grams[layer]
            # Gram matrix of the single style image, so it has to be expanded
            # before computing the loss against transformed_gram
            style_gram = style_grams[layer]
            style_loss += F.mse_loss(transformed_gram,
                                     style_gram.expand_as(transformed_gram))
        style_loss = style_weight * style_loss

        loss = style_loss + content_loss + tv_loss
        # style_grams carries its own graph through vgg, which every step reuses
        loss.backward(retain_graph = True)
        optimizer.step()

        # Record each loss
        content_loss_all.append(content_loss.item())
        style_loss_all.append(style_loss.item())
        tv_loss_all.append(tv_loss.item())
        all_loss.append(loss.item())
        if step % 5000 == 0:
            print('step: {}; content loss: {:.3f}; style loss: {:.3f}; tv loss: {:.3f}, loss: {:.3f}'.format(step, content_loss.item(), style_loss.item(), tv_loss.item(), loss.item()))
            time_use = time.time() - since
            print('Train complete in {:.0f}m {:.0f}s'.format(time_use // 60, time_use % 60))
            # Visualize one image of the current batch
            plt.figure()
            im = transformed_images[1, ...] # The ellipsis selects all remaining dimensions
            plt.axis('off')
            plt.imshow(im_convert(im))
            plt.show()

Epoch: 1
step: 0; content loss: 21.736; style loss: 679.825; tv loss: 17.357, loss: 718.918
Train complete in 0m 10s


step: 5000; content loss: 11.223; style loss: 4.921; tv loss: 1.068, loss: 17.212
Train complete in 32m 21s


step: 10000; content loss: 10.715; style loss: 3.768; tv loss: 1.101, loss: 15.584
Train complete in 64m 34s


Epoch: 2
step: 0; content loss: 12.664; style loss: 3.324; tv loss: 1.182, loss: 17.170
Train complete in 65m 40s


step: 5000; content loss: 5.582; style loss: 3.621; tv loss: 1.234, loss: 10.438
Train complete in 97m 55s


step: 10000; content loss: 5.797; style loss: 3.302; tv loss: 1.209, loss: 10.308
Train complete in 130m 11s


Epoch: 3
step: 0; content loss: 4.639; style loss: 3.312; tv loss: 1.250, loss: 9.201
Train complete in 131m 16s


step: 5000; content loss: 4.507; style loss: 3.565; tv loss: 1.291, loss: 9.364
Train complete in 163m 32s


step: 10000; content loss: 4.570; style loss: 3.609; tv loss: 1.098, loss: 9.276
Train complete in 195m 48s


Epoch: 4
step: 0; content loss: 4.425; style loss: 2.844; tv loss: 1.239, loss: 8.509
Train complete in 196m 46s


step: 5000; content loss: 6.227; style loss: 4.176; tv loss: 1.231, loss: 11.633
Train complete in 229m 2s


step: 10000; content loss: 4.537; style loss: 3.191; tv loss: 1.178, loss: 8.906
Train complete in 261m 19s


# Save the trained network myfwnet
torch.save(myfwnet.state_dict(), './model/imfwnet_dict.pkl')


To test the trained style transfer network myfwnet, the following code takes one batch of images at random from the dataset and applies style transfer to them:

myfwnet.eval()
# Fetch a batch of images from data_loader (the DataLoader over dataset)
for step, batch in enumerate(data_loader):
    content_images = batch[0].to(device)
    if step > 0:
        break
# Show the content images
plt.figure(figsize = (16, 4))
for ii in range(4):
    im = content_images[ii, ...]
    plt.subplot(1, 4, ii + 1)
    plt.axis('off')
    plt.imshow(im_convert(im))
plt.show()
# Show the stylized images
transformed_images = myfwnet(content_images)
transformed_images = transformed_images.clamp(-2.1, 2.7)
plt.figure(figsize = (16, 4))
for ii in range(4):
    im = im_convert(transformed_images[ii, ...])
    plt.subplot(1, 4, ii + 1)
    plt.axis('off')
    plt.imshow(im)
plt.show()


#### Use the GPU-trained model on the CPU

# Read content image
content = load_image('./data/COCO/COCO/COCO_val2014_000000000192.jpg', shape = (256, 256))
# Import the trained GPU network
device = torch.device('cpu')
newfwnet = ImfwNet()
newfwnet.load_state_dict(torch.load('./model/imfwnet_dict.pkl', map_location = device)) # The GPU model is mapped to a network based on CPU computing
transform_content = newfwnet(content)
# Visual image
plt.figure()
plt.subplot(1, 2, 1)
plt.imshow(im_convert(content))
plt.axis('off')
plt.subplot(1, 2, 2)
plt.imshow(im_convert(transform_content))
plt.axis('off')
plt.show()
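The save and reload pattern used above can be exercised on a small stand-in network. The sketch below writes a state dictionary to a temporary file and loads it with map_location set to the CPU. The toy network and file name are hypothetical. On a machine with a GPU, the same map_location argument remaps tensors that were saved from CUDA memory.

```python
import os
import tempfile

import torch
import torch.nn as nn

def make_net():
    """Small stand-in for the image transformation network (hypothetical)."""
    return nn.Sequential(nn.Conv2d(3, 4, kernel_size=3, padding=1), nn.ReLU())

net = make_net()

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "toy_dict.pkl")
    # Save only the parameters, not the whole module object
    torch.save(net.state_dict(), path)
    # map_location places every loaded tensor on the CPU
    state = torch.load(path, map_location=torch.device("cpu"))

restored = make_net()
restored.load_state_dict(state)
restored.eval()

x = torch.randn(1, 3, 8, 8)
with torch.no_grad():
    same = torch.allclose(net(x), restored(x))
print(same)
```

Because only the state dictionary is stored, the class definition of the network has to be available when the weights are loaded again.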


Generally speaking, ordinary style transfer takes a long time (several hours), but the quality of the stylized result is good.

Fast style transfer is very fast when it is applied (the network has already been trained, which is an offline process), but the result is less ideal.

Added by idire on Wed, 16 Feb 2022 20:03:05 +0200