"Online game" document shadow elimination based on user-defined training template

「 🍔 Custom training template 🍔」 An example of entering a competition: document shadow removal

I believe you have run into the same frustration I have when writing projects with Paddle:

The high-level APIs are so heavily encapsulated that it is hard to add custom functionality, while the low-level APIs are so minimal that implementing many features requires writing a lot of extra code.

It would help to have a template with reasonably complete functionality, so that you only need to modify part of the code to quickly realize your ideas.

Based on this idea, I built a training template that supports customized per-epoch output, resuming training after interruption, saving the optimal model, and more.

You can quickly modify and implement your ideas based on this template.

📖 0 project background

In daily life, when scanning documents with a mobile phone, the shooting angle and the raised hand inevitably leave annoying shadows in the photos. To help reduce the interference caused by such shadows, participants need to train a model with deep learning techniques that processes the shadowed images collected in real scenes, restores their original appearance, and finally outputs the processed result images.
(Figure: shadowed image vs. shadow-free image)

🥦 1 data introduction

  • The dataset released for this competition consists of a training set, an A-list test set, and a B-list test set: the training set contains 1410 samples (image numbers are discontinuous), the A-list test set contains 300 samples (image numbers are continuous), and the B-list test set contains 397 samples;

  • images contains the source images with shadows, and gts contains the shadow-free ground-truth data; (the GT of the test sets is not released)

  • The images in images and gts correspond one to one by file name.

🌰 2 Functions already implemented in the template

💡 1: output customized information for each epoch, with a tqdm progress bar to check the running time of each epoch in real time (works around garbled output from the Windows terminal or from tqdm being interrupted abnormally)

💡 2: the loss function and the model can be defined and modified directly, with examples provided

💡 3: choose whether to continue training from a saved model or train from scratch

💡 4: after training, the log is saved as a csv file

💡 5: customize the evaluation metric and automatically save the optimal model according to its results

💡 6: complete end-to-end code for data augmentation, model training, testing, and inference

💡 7: split the input image into patches, which can effectively improve accuracy on image tasks

⚙️ 3 code implementation and explanation

# Decompress data set
!unzip  -o data/data125945/shadowRemovalOfDoc.zip

!unzip delight_testA_dataset.zip -d data/ >>/dev/null
!unzip delight_train_dataset.zip -d data/ >>/dev/null
# Install required libraries
!pip install scikit_image==0.15.0

⚔️ 3.1 splitting images into patches

  • This part of the code splits each image into small patches and saves them, applying seven data augmentation modes that combine flipping and rotation.

prepare_data( patch_size=256, stride=200, aug_times=1)

  • patch_size: the side length of each square patch

  • stride: the step size; a patch is cut every stride pixels

  • aug_times: the number of augmentations; each patch generates this many augmented images

In prepare_data, the line scales = [1] controls random enlarging and shrinking of the data:

  • This parameter is a list used to scale the image before patching.

  • [1, 1.2] means patches are cut both from the original image and from the image enlarged by a factor of 1.2, which roughly doubles the amount of training data compared with [1]. The patch-count arithmetic is sketched below.
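
For reference, the number of patches produced per image follows directly from patch_size and stride. Below is a minimal sketch of that arithmetic (not part of the original code; the 1080 x 1440 image size is only a hypothetical example):

# Number of patches Im2Patch produces for one image (illustrative sketch)
def patches_per_image(h, w, patch_size=256, stride=200):
    n_h = (h - patch_size) // stride + 1   # patches along the height
    n_w = (w - patch_size) // stride + 1   # patches along the width
    return n_h * n_w

print(patches_per_image(1080, 1440))  # hypothetical 1080 x 1440 image -> 5 * 6 = 30 patches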

# Define offline data enhancement methods
def data_augmentation(image,label, mode):
    out = np.transpose(image, (1,2,0))
    out_label = np.transpose(label, (1,2,0))
    if mode == 0:
        # original
        out = out
        out_label = out_label
    elif mode == 1:
        # flip up and down
        out = np.flipud(out)
        out_label = np.flipud(out_label)
    elif mode == 2:
        # rotate counterclockwise 90 degrees
        out = np.rot90(out)
        out_label = np.rot90(out_label)
    elif mode == 3:
        # rotate 90 degree and flip up and down
        out = np.rot90(out)
        out = np.flipud(out)

        out_label = np.rot90(out_label)
        out_label = np.flipud(out_label)
    elif mode == 4:
        # rotate 180 degree
        out = np.rot90(out, k=2)
        out_label = np.rot90(out_label, k=2)
    elif mode == 5:
        # rotate 180 degree and flip
        out = np.rot90(out, k=2)
        out = np.flipud(out)

        out_label = np.rot90(out_label, k=2)
        out_label = np.flipud(out_label)
    elif mode == 6:
        # rotate 270 degree
        out = np.rot90(out, k=3)

        out_label = np.rot90(out_label, k=3)

    elif mode == 7:
        # rotate 270 degree and flip
        out = np.rot90(out, k=3)
        out = np.flipud(out)

        out_label = np.rot90(out_label, k=3)
        out_label = np.flipud(out_label)

    return  out,out_label
## Making block data sets
import cv2
import numpy as np
import math
import glob 
import os

def Im2Patch(img, win, stride=1):
    k = 0
    endc = img.shape[0]
    endw = img.shape[1]
    endh = img.shape[2]
    patch = img[:, 0:endw-win+0+1:stride, 0:endh-win+0+1:stride]
    TotalPatNum = patch.shape[1] * patch.shape[2]
    Y = np.zeros([endc, win*win,TotalPatNum], np.float32)
    for i in range(win):
        for j in range(win):
            patch = img[:,i:endw-win+i+1:stride,j:endh-win+j+1:stride]
            Y[:,k,:] = np.array(patch[:]).reshape(endc, TotalPatNum)
            k = k + 1
    return Y.reshape([endc, win, win, TotalPatNum])

def prepare_data(patch_size, stride, aug_times=1):
    '''
    This function cuts the images into square patches and applies data augmentation
    patch_size:  size of each image patch, 256*256 in this project
        stride:  step size; the interval between neighbouring patches
     aug_times:  number of augmentations; by default one of the seven flip/rotation modes is chosen at random for each patch
    '''
    # train
    print('process training data')
    scales = [1] # Random scaling of data
    files = glob.glob(os.path.join('data/delight_train_dataset/images', '*.jpg'))
    files.sort()

    img_folder = 'work/img_patch'
    if  not os.path.exists(img_folder):
        os.mkdir(img_folder)

    label_folder = 'work/label_patch'
    if  not os.path.exists(label_folder):
        os.mkdir(label_folder)

    train_num = 0
    for i in range(len(files)):
        img = cv2.imread(files[i])
        label = cv2.imread(files[i].replace('images','gts'))
        h, w, c = img.shape
        for k in range(len(scales)):
            Img = cv2.resize(img, (int(w*scales[k]), int(h*scales[k])), interpolation=cv2.INTER_CUBIC)  # cv2.resize expects (width, height)
            Label = cv2.resize(label, (int(w*scales[k]), int(h*scales[k])), interpolation=cv2.INTER_CUBIC)

            Img = np.transpose(Img, (2,0,1))
            Label = np.transpose(Label, (2,0,1))

            Img = np.float32(np.clip(Img,0,255))
            Label = np.float32(np.clip(Label,0,255))

            patches = Im2Patch(Img, win=patch_size, stride=stride)
            label_patches = Im2Patch(Label, win=patch_size, stride=stride)
            print("file: %s scale %.1f # samples: %d" % (files[i], scales[k], patches.shape[3]*aug_times))

            for n in range(patches.shape[3]):

                data = patches[:,:,:,n].copy()
                label_data = label_patches[:,:,:,n].copy()
            
                for m in range(aug_times):
                    data_aug,label_aug = data_augmentation(data,label_data, np.random.randint(1,8))
                    label_name = os.path.join(label_folder,str(train_num)+"_aug_%d" % (m+1)+'.jpg')
                    image_name = os.path.join(img_folder,str(train_num)+"_aug_%d" % (m+1)+'.jpg')
                    
                    cv2.imwrite(image_name, data_aug,[int( cv2.IMWRITE_JPEG_QUALITY), 100])
                    cv2.imwrite(label_name, label_aug,[int( cv2.IMWRITE_JPEG_QUALITY), 100])

                    train_num += 1

    print('training set, # samples %d\n' % train_num)
   
## Generate data

prepare_data( patch_size=256, stride=200, aug_times=1) 
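
A quick sanity check (a sketch, not in the original notebook) that image patches and label patches were generated in equal numbers:

img_patches = glob.glob('work/img_patch/*.jpg')
label_patches = glob.glob('work/label_patch/*.jpg')
print(len(img_patches), len(label_patches))  # the two counts should match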

⚔️ 3.2 rewrite data reading class

  • This part of the code uses the built-in image augmentation methods in paddle.vision.transforms.

# Rewrite data reading class
import paddle
import paddle.vision.transforms as T
import numpy as np
import glob
import cv2

# Rewrite data reading class
class DEshadowDataset(paddle.io.Dataset):

    def __init__(self,mode = 'train',is_transforms = False):
       
        label_path_ ='work/label_patch/*.jpg'
        jpg_path_ ='work/img_patch/*.jpg'
        self.label_list_ = glob.glob(label_path_)        
        self.jpg_list_ = glob.glob(jpg_path_)
        
        self.is_transforms = is_transforms
        self.mode = mode

        scale_point = 0.95
        
        self.transforms =T.Compose([
            T.Normalize(data_format='HWC',),
            T.HueTransform(0.4),
            T.SaturationTransform(0.4),
            T.HueTransform(0.4),
            T.ToTensor(),
            ])


        # Select the first 95% training and the last 5% verification
        if self.mode == 'train':
            self.jpg_list_ = self.jpg_list_[:int(scale_point*len(self.jpg_list_))]
            self.label_list_ = self.label_list_[:int(scale_point*len(self.label_list_))]

        else:
            self.jpg_list_ = self.jpg_list_[int(scale_point*len(self.jpg_list_)):]
            self.label_list_ = self.label_list_[int(scale_point*len(self.label_list_)):]

    def __getitem__(self, index):
        jpg_ = self.jpg_list_[index]
        label_ =  self.label_list_[index]

        data = cv2.imread(jpg_) # Read the shadowed input image (BGR)
        mask = cv2.imread(label_)

        data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB) # BGR 2 RGB
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB) # BGR 2 RGB
        
        data = np.uint8(data)
        mask = np.uint8(mask)

        if self.is_transforms:
            data = self.transforms(data)
            data = data/255
            mask = T.functional.to_tensor(mask) 

        return  data,mask  

    def __len__(self):
        return len(self.jpg_list_)  
  
# Data reading and enhanced visualization
import paddle.vision.transforms as T
import matplotlib.pyplot as plt
from PIL import Image

dataset = DEshadowDataset(mode='train',is_transforms = False )
print('=============train dataset=============')
img_,mask_ = dataset[3] # mask is always greater than img

img = Image.fromarray(img_)
mask = Image.fromarray(mask_)

#If the image were grayscale, its numpy shape would be [1, h, w] and would need to be reshaped to [h, w] before display

plt.figure(figsize=(12, 6))
plt.subplot(1,2,1),plt.xticks([]),plt.yticks([]),plt.imshow(img)
plt.subplot(1,2,2),plt.xticks([]),plt.yticks([]),plt.imshow(mask)
plt.show()
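
As a quick usage check (a sketch added here, not part of the original code), the dataset can also be wrapped in a DataLoader with transforms enabled; the batch size and shapes below are illustrative:

train_set_t = DEshadowDataset(mode='train', is_transforms=True)
train_loader_t = paddle.io.DataLoader(train_set_t, batch_size=4, shuffle=True)
img_batch, mask_batch = next(iter(train_loader_t))
print(img_batch.shape, mask_batch.shape)  # e.g. [4, 3, 256, 256] for both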

⚔️ 3.3 rewrite model

Please refer to the linked introductions for details of the network architectures (KiU-Net and UNet 3+).

# Define KIUnet 
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import initializer

def init_weights(init_type='kaiming'):
    if init_type == 'normal':
        return paddle.framework.ParamAttr(initializer=paddle.nn.initializer.Normal())
    elif init_type == 'xavier':
        return paddle.framework.ParamAttr(initializer=paddle.nn.initializer.XavierNormal())
    elif init_type == 'kaiming':
        return paddle.framework.ParamAttr(initializer=paddle.nn.initializer.KaimingNormal())
    else:
        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)


class kiunet(nn.Layer):
    def __init__(self ,in_channels = 3, n_classes =3):
        super(kiunet,self).__init__()

        self.in_channels = in_channels
        self.n_class = n_classes

        self.encoder1 = nn.Conv2D(self.in_channels, 16, 3, stride=1, padding=1)  # First Layer GrayScale Image , change to input channels to 3 in case of RGB 
        self.en1_bn = nn.BatchNorm(16)
        self.encoder2=   nn.Conv2D(16, 32, 3, stride=1, padding=1)  
        self.en2_bn = nn.BatchNorm(32)
        self.encoder3=   nn.Conv2D(32, 64, 3, stride=1, padding=1)
        self.en3_bn = nn.BatchNorm(64)

        self.decoder1 =   nn.Conv2D(64, 32, 3, stride=1, padding=1)   
        self.de1_bn = nn.BatchNorm(32)
        self.decoder2 =   nn.Conv2D(32,16, 3, stride=1, padding=1)
        self.de2_bn = nn.BatchNorm(16)
        self.decoder3 =   nn.Conv2D(16, 8, 3, stride=1, padding=1)
        self.de3_bn = nn.BatchNorm(8)

        self.decoderf1 =   nn.Conv2D(64, 32, 3, stride=1, padding=1)
        self.def1_bn = nn.BatchNorm(32)
        self.decoderf2=   nn.Conv2D(32, 16, 3, stride=1, padding=1)
        self.def2_bn = nn.BatchNorm(16)
        self.decoderf3 =   nn.Conv2D(16, 8, 3, stride=1, padding=1)
        self.def3_bn = nn.BatchNorm(8)

        self.encoderf1 =   nn.Conv2D(in_channels, 16, 3, stride=1, padding=1)  # First Layer GrayScale Image , change to input channels to 3 in case of RGB 
        self.enf1_bn = nn.BatchNorm(16)
        self.encoderf2=   nn.Conv2D(16, 32, 3, stride=1, padding=1)
        self.enf2_bn = nn.BatchNorm(32)
        self.encoderf3 =   nn.Conv2D(32, 64, 3, stride=1, padding=1)
        self.enf3_bn = nn.BatchNorm(64)

        self.intere1_1 = nn.Conv2D(16,16,3, stride=1, padding=1)
        self.inte1_1bn = nn.BatchNorm(16)
        self.intere2_1 = nn.Conv2D(32,32,3, stride=1, padding=1)
        self.inte2_1bn = nn.BatchNorm(32)
        self.intere3_1 = nn.Conv2D(64,64,3, stride=1, padding=1)
        self.inte3_1bn = nn.BatchNorm(64)

        self.intere1_2 = nn.Conv2D(16,16,3, stride=1, padding=1)
        self.inte1_2bn = nn.BatchNorm(16)
        self.intere2_2 = nn.Conv2D(32,32,3, stride=1, padding=1)
        self.inte2_2bn = nn.BatchNorm(32)
        self.intere3_2 = nn.Conv2D(64,64,3, stride=1, padding=1)
        self.inte3_2bn = nn.BatchNorm(64)

        self.interd1_1 = nn.Conv2D(32,32,3, stride=1, padding=1)
        self.intd1_1bn = nn.BatchNorm(32)
        self.interd2_1 = nn.Conv2D(16,16,3, stride=1, padding=1)
        self.intd2_1bn = nn.BatchNorm(16)
        self.interd3_1 = nn.Conv2D(64,64,3, stride=1, padding=1)
        self.intd3_1bn = nn.BatchNorm(64)

        self.interd1_2 = nn.Conv2D(32,32,3, stride=1, padding=1)
        self.intd1_2bn = nn.BatchNorm(32)
        self.interd2_2 = nn.Conv2D(16,16,3, stride=1, padding=1)
        self.intd2_2bn = nn.BatchNorm(16)
        self.interd3_2 = nn.Conv2D(64,64,3, stride=1, padding=1)
        self.intd3_2bn = nn.BatchNorm(64)

        self.final = nn.Sequential(
            nn.Conv2D(8,self.n_class,1,stride=1,padding=0),
            nn.AdaptiveAvgPool2D(output_size=1))

        # initialise weights
        for m in self.sublayers ():
            if isinstance(m, nn.Conv2D):
                m.weight_attr = init_weights(init_type='kaiming')
                m.bias_attr = init_weights(init_type='kaiming')
            elif isinstance(m, nn.BatchNorm):
                m.param_attr =init_weights(init_type='kaiming')
                m.bias_attr = init_weights(init_type='kaiming') 

    def forward(self, x):
        # input: c * h * w -> 16 * h/2 * w/2
        out = F.relu(self.en1_bn(F.max_pool2d(self.encoder1(x),2,2)))  #U-Net branch
        # c * h * w -> 16 * 2h * 2w
        out1 = F.relu(self.enf1_bn(F.interpolate(self.encoderf1(x),scale_factor=(2,2),mode ='bicubic'))) #Ki-Net branch
        # 16 * h/2 * w/2
        tmp = out
        # 16 * 2h * 2w -> 16 * h/2 * w/2
        out = paddle.add(out,F.interpolate(F.relu(self.inte1_1bn(self.intere1_1(out1))),scale_factor=(0.25,0.25),mode ='bicubic')) #CRFB
        # 16 * h/2 * w/2 -> 16 * 2h * 2w
        out1 = paddle.add(out1,F.interpolate(F.relu(self.inte1_2bn(self.intere1_2(tmp))),scale_factor=(4,4),mode ='bicubic')) #CRFB
        
        # 16 * h/2 * w/2
        u1 = out  #skip conn
        # 16 * 2h * 2w
        o1 = out1  #skip conn

        # 16 * h/2 * w/2 -> 32 * h/4 * w/4
        out = F.relu(self.en2_bn(F.max_pool2d(self.encoder2(out),2,2)))
        # 16 * 2h * 2w -> 32 * 4h * 4w
        out1 = F.relu(self.enf2_bn(F.interpolate(self.encoderf2(out1),scale_factor=(2,2),mode ='bicubic')))
        #  32 * h/4 * w/4
        tmp = out
        # 32 * 4h * 4w -> 32 * h/4 *w/4
        out = paddle.add(out,F.interpolate(F.relu(self.inte2_1bn(self.intere2_1(out1))),scale_factor=(0.0625,0.0625),mode ='bicubic'))
        # 32 * h/4 * w/4 -> 32 *4h *4w
        out1 = paddle.add(out1,F.interpolate(F.relu(self.inte2_2bn(self.intere2_2(tmp))),scale_factor=(16,16),mode ='bicubic'))
        
        #  32 * h/4 *w/4
        u2 = out
        #  32 *4h *4w
        o2 = out1
        
        # 32 * h/4 *w/4 -> 64 * h/8 *w/8
        out = F.relu(self.en3_bn(F.max_pool2d(self.encoder3(out),2,2)))
        # 32 *4h *4w -> 64 * 8h *8w
        out1 = F.relu(self.enf3_bn(F.interpolate(self.encoderf3(out1),scale_factor=(2,2),mode ='bicubic')))
        #  64 * h/8 *w/8 
        tmp = out
        #  64 * 8h *8w -> 64 * h/8 * w/8
        out = paddle.add(out,F.interpolate(F.relu(self.inte3_1bn(self.intere3_1(out1))),scale_factor=(0.015625,0.015625),mode ='bicubic'))
        #  64 * h/8 *w/8 -> 64 * 8h * 8w
        out1 = paddle.add(out1,F.interpolate(F.relu(self.inte3_2bn(self.intere3_2(tmp))),scale_factor=(64,64),mode ='bicubic'))
        
        ### End of encoder block

        ### Start Decoder
        
        # 64 * h/8 * w/8 -> 32 * h/4 * w/4 
        out = F.relu(self.de1_bn(F.interpolate(self.decoder1(out),scale_factor=(2,2),mode ='bicubic')))  #U-NET
        # 64 * 8h * 8w -> 32 * 4h * 4w 
        out1 = F.relu(self.def1_bn(F.max_pool2d(self.decoderf1(out1),2,2))) #Ki-NET
        # 32 * h/4 * w/4 
        tmp = out
        # 32 * 4h * 4w  -> 32 * h/4 * w/4 
        out = paddle.add(out,F.interpolate(F.relu(self.intd1_1bn(self.interd1_1(out1))),scale_factor=(0.0625,0.0625),mode ='bicubic'))
        # 32 * h/4 * w/4  -> 32 * 4h * 4w 
        out1 = paddle.add(out1,F.interpolate(F.relu(self.intd1_2bn(self.interd1_2(tmp))),scale_factor=(16,16),mode ='bicubic'))
        
        # 32 * h/4 * w/4 
        out = paddle.add(out,u2)  #skip conn
        # 32 * 4h * 4w 
        out1 = paddle.add(out1,o2)  #skip conn

        # 32 * h/4 * w/4 -> 16 * h/2 * w/2 
        out = F.relu(self.de2_bn(F.interpolate(self.decoder2(out),scale_factor=(2,2),mode ='bicubic')))
        # 32 * 4h * 4w  -> 16 * 2h * 2w
        out1 = F.relu(self.def2_bn(F.max_pool2d(self.decoderf2(out1),2,2)))
        # 16 * h/2 * w/2 
        tmp = out
        # 16 * 2h * 2w -> 16 * h/2 * w/2
        out = paddle.add(out,F.interpolate(F.relu(self.intd2_1bn(self.interd2_1(out1))),scale_factor=(0.25,0.25),mode ='bicubic'))
        # 16 * h/2 * w/2 -> 16 * 2h * 2w
        out1 = paddle.add(out1,F.interpolate(F.relu(self.intd2_2bn(self.interd2_2(tmp))),scale_factor=(4,4),mode ='bicubic'))
        
        # 16 * h/2 * w/2
        out = paddle.add(out,u1)
        # 16 * 2h * 2w
        out1 = paddle.add(out1,o1)

        # 16 * h/2 * w/2 -> 8 * h * w
        out = F.relu(self.de3_bn(F.interpolate(self.decoder3(out),scale_factor=(2,2),mode ='bicubic')))
        # 16 * 2h * 2w -> 8 * h * w
        out1 = F.relu(self.def3_bn(F.max_pool2d(self.decoderf3(out1),2,2)))

        # 8 * h * w
        out = paddle.add(out,out1) # fusion of both branches

        # The last layer activates the function with sigmoid
        out = F.sigmoid(self.final(out))  #1*1 conv
        
        return out
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import initializer

def init_weights(init_type='kaiming'):
    if init_type == 'normal':
        return paddle.framework.ParamAttr(initializer=paddle.nn.initializer.Normal())
    elif init_type == 'xavier':
        return paddle.framework.ParamAttr(initializer=paddle.nn.initializer.XavierNormal())
    elif init_type == 'kaiming':
        return paddle.framework.ParamAttr(initializer=paddle.nn.initializer.KaimingNormal())
    else:
        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)

class unetConv2(nn.Layer):
    def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1):
        super(unetConv2, self).__init__()
        self.n = n
        self.ks = ks
        self.stride = stride
        self.padding = padding
        s = stride
        p = padding
        if is_batchnorm:
            for i in range(1, n + 1):
                conv = nn.Sequential(nn.Conv2D(in_size, out_size, ks, s, p),
                                     nn.BatchNorm(out_size),
                                     nn.ReLU(), )
                setattr(self, 'conv%d' % i, conv)
                in_size = out_size
        else:
            for i in range(1, n + 1):
                conv = nn.Sequential(nn.Conv2D(in_size, out_size, ks, s, p),
                                     nn.ReLU(), )
                setattr(self, 'conv%d' % i, conv)
                in_size = out_size
        # initialise the blocks
        for m in self.children():
            m.weight_attr=init_weights(init_type='kaiming')
            m.bias_attr=init_weights(init_type='kaiming')
    def forward(self, inputs):
        x = inputs
        for i in range(1, self.n + 1):
            conv = getattr(self, 'conv%d' % i)
            x = conv(x)
        return x

'''
    UNet 3+
'''
class UNet_3Plus(nn.Layer):
    def __init__(self, in_channels=3, n_classes=1, is_deconv=True, is_batchnorm=True, end_sigmoid=True):
        super(UNet_3Plus, self).__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.end_sigmoid = end_sigmoid
        filters = [16, 32, 64, 128, 256]
        ## -------------Encoder--------------
        self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)
        self.maxpool1 = nn.MaxPool2D(kernel_size=2)
        self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)
        self.maxpool2 = nn.MaxPool2D(kernel_size=2)
        self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)
        self.maxpool3 = nn.MaxPool2D(kernel_size=2)
        self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)
        self.maxpool4 = nn.MaxPool2D(kernel_size=2)
        self.conv5 = unetConv2(filters[3], filters[4], self.is_batchnorm)
        ## -------------Decoder--------------
        self.CatChannels = filters[0]
        self.CatBlocks = 5
        self.UpChannels = self.CatChannels * self.CatBlocks
        '''stage 4d'''
        # h1->320*320, hd4->40*40, Pooling 8 times
        self.h1_PT_hd4 = nn.MaxPool2D(8, 8, ceil_mode=True)
        self.h1_PT_hd4_conv = nn.Conv2D(filters[0], self.CatChannels, 3, padding=1)
        self.h1_PT_hd4_bn = nn.BatchNorm(self.CatChannels)
        self.h1_PT_hd4_relu = nn.ReLU()
        # h2->160*160, hd4->40*40, Pooling 4 times
        self.h2_PT_hd4 = nn.MaxPool2D(4, 4, ceil_mode=True)
        self.h2_PT_hd4_conv = nn.Conv2D(filters[1], self.CatChannels, 3, padding=1)
        self.h2_PT_hd4_bn = nn.BatchNorm(self.CatChannels)
        self.h2_PT_hd4_relu = nn.ReLU()
        # h3->80*80, hd4->40*40, Pooling 2 times
        self.h3_PT_hd4 = nn.MaxPool2D(2, 2, ceil_mode=True)
        self.h3_PT_hd4_conv = nn.Conv2D(filters[2], self.CatChannels, 3, padding=1)
        self.h3_PT_hd4_bn = nn.BatchNorm(self.CatChannels)
        self.h3_PT_hd4_relu = nn.ReLU()
        # h4->40*40, hd4->40*40, Concatenation
        self.h4_Cat_hd4_conv = nn.Conv2D(filters[3], self.CatChannels, 3, padding=1)
        self.h4_Cat_hd4_bn = nn.BatchNorm(self.CatChannels)
        self.h4_Cat_hd4_relu = nn.ReLU()
        # hd5->20*20, hd4->40*40, Upsample 2 times
        self.hd5_UT_hd4 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14
        self.hd5_UT_hd4_conv = nn.Conv2D(filters[4], self.CatChannels, 3, padding=1)
        self.hd5_UT_hd4_bn = nn.BatchNorm(self.CatChannels)
        self.hd5_UT_hd4_relu = nn.ReLU()
        # fusion(h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4)
        self.conv4d_1 = nn.Conv2D(self.UpChannels, self.UpChannels, 3, padding=1)  # 16
        self.bn4d_1 = nn.BatchNorm(self.UpChannels)
        self.relu4d_1 = nn.ReLU()
        '''stage 3d'''
        # h1->320*320, hd3->80*80, Pooling 4 times
        self.h1_PT_hd3 = nn.MaxPool2D(4, 4, ceil_mode=True)
        self.h1_PT_hd3_conv = nn.Conv2D(filters[0], self.CatChannels, 3, padding=1)
        self.h1_PT_hd3_bn = nn.BatchNorm(self.CatChannels)
        self.h1_PT_hd3_relu = nn.ReLU()
        # h2->160*160, hd3->80*80, Pooling 2 times
        self.h2_PT_hd3 = nn.MaxPool2D(2, 2, ceil_mode=True)
        self.h2_PT_hd3_conv = nn.Conv2D(filters[1], self.CatChannels, 3, padding=1)
        self.h2_PT_hd3_bn = nn.BatchNorm(self.CatChannels)
        self.h2_PT_hd3_relu = nn.ReLU()
        # h3->80*80, hd3->80*80, Concatenation
        self.h3_Cat_hd3_conv = nn.Conv2D(filters[2], self.CatChannels, 3, padding=1)
        self.h3_Cat_hd3_bn = nn.BatchNorm(self.CatChannels)
        self.h3_Cat_hd3_relu = nn.ReLU()
        # hd4->40*40, hd4->80*80, Upsample 2 times
        self.hd4_UT_hd3 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14
        self.hd4_UT_hd3_conv = nn.Conv2D(self.UpChannels, self.CatChannels, 3, padding=1)
        self.hd4_UT_hd3_bn = nn.BatchNorm(self.CatChannels)
        self.hd4_UT_hd3_relu = nn.ReLU()
        # hd5->20*20, hd4->80*80, Upsample 4 times
        self.hd5_UT_hd3 = nn.Upsample(scale_factor=4, mode='bilinear')  # 14*14
        self.hd5_UT_hd3_conv = nn.Conv2D(filters[4], self.CatChannels, 3, padding=1)
        self.hd5_UT_hd3_bn = nn.BatchNorm(self.CatChannels)
        self.hd5_UT_hd3_relu = nn.ReLU()
        # fusion(h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3)
        self.conv3d_1 = nn.Conv2D(self.UpChannels, self.UpChannels, 3, padding=1)  # 16
        self.bn3d_1 = nn.BatchNorm(self.UpChannels)
        self.relu3d_1 = nn.ReLU()
        '''stage 2d '''
        # h1->320*320, hd2->160*160, Pooling 2 times
        self.h1_PT_hd2 = nn.MaxPool2D(2, 2, ceil_mode=True)
        self.h1_PT_hd2_conv = nn.Conv2D(filters[0], self.CatChannels, 3, padding=1)
        self.h1_PT_hd2_bn = nn.BatchNorm(self.CatChannels)
        self.h1_PT_hd2_relu = nn.ReLU()
        # h2->160*160, hd2->160*160, Concatenation
        self.h2_Cat_hd2_conv = nn.Conv2D(filters[1], self.CatChannels, 3, padding=1)
        self.h2_Cat_hd2_bn = nn.BatchNorm(self.CatChannels)
        self.h2_Cat_hd2_relu = nn.ReLU()
        # hd3->80*80, hd2->160*160, Upsample 2 times
        self.hd3_UT_hd2 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14
        self.hd3_UT_hd2_conv = nn.Conv2D(self.UpChannels, self.CatChannels, 3, padding=1)
        self.hd3_UT_hd2_bn = nn.BatchNorm(self.CatChannels)
        self.hd3_UT_hd2_relu = nn.ReLU()
        # hd4->40*40, hd2->160*160, Upsample 4 times
        self.hd4_UT_hd2 = nn.Upsample(scale_factor=4, mode='bilinear')  # 14*14
        self.hd4_UT_hd2_conv = nn.Conv2D(self.UpChannels, self.CatChannels, 3, padding=1)
        self.hd4_UT_hd2_bn = nn.BatchNorm(self.CatChannels)
        self.hd4_UT_hd2_relu = nn.ReLU()
        # hd5->20*20, hd2->160*160, Upsample 8 times
        self.hd5_UT_hd2 = nn.Upsample(scale_factor=8, mode='bilinear')  # 14*14
        self.hd5_UT_hd2_conv = nn.Conv2D(filters[4], self.CatChannels, 3, padding=1)
        self.hd5_UT_hd2_bn = nn.BatchNorm(self.CatChannels)
        self.hd5_UT_hd2_relu = nn.ReLU()
        # fusion(h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2)
        self.Conv2D_1 = nn.Conv2D(self.UpChannels, self.UpChannels, 3, padding=1)  # 16
        self.bn2d_1 = nn.BatchNorm(self.UpChannels)
        self.relu2d_1 = nn.ReLU()
        '''stage 1d'''
        # h1->320*320, hd1->320*320, Concatenation
        self.h1_Cat_hd1_conv = nn.Conv2D(filters[0], self.CatChannels, 3, padding=1)
        self.h1_Cat_hd1_bn = nn.BatchNorm(self.CatChannels)
        self.h1_Cat_hd1_relu = nn.ReLU()
        # hd2->160*160, hd1->320*320, Upsample 2 times
        self.hd2_UT_hd1 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14
        self.hd2_UT_hd1_conv = nn.Conv2D(self.UpChannels, self.CatChannels, 3, padding=1)
        self.hd2_UT_hd1_bn = nn.BatchNorm(self.CatChannels)
        self.hd2_UT_hd1_relu = nn.ReLU()
        # hd3->80*80, hd1->320*320, Upsample 4 times
        self.hd3_UT_hd1 = nn.Upsample(scale_factor=4, mode='bilinear')  # 14*14
        self.hd3_UT_hd1_conv = nn.Conv2D(self.UpChannels, self.CatChannels, 3, padding=1)
        self.hd3_UT_hd1_bn = nn.BatchNorm(self.CatChannels)
        self.hd3_UT_hd1_relu = nn.ReLU()
        # hd4->40*40, hd1->320*320, Upsample 8 times
        self.hd4_UT_hd1 = nn.Upsample(scale_factor=8, mode='bilinear')  # 14*14
        self.hd4_UT_hd1_conv = nn.Conv2D(self.UpChannels, self.CatChannels, 3, padding=1)
        self.hd4_UT_hd1_bn = nn.BatchNorm(self.CatChannels)
        self.hd4_UT_hd1_relu = nn.ReLU()
        # hd5->20*20, hd1->320*320, Upsample 16 times
        self.hd5_UT_hd1 = nn.Upsample(scale_factor=16, mode='bilinear')  # 14*14
        self.hd5_UT_hd1_conv = nn.Conv2D(filters[4], self.CatChannels, 3, padding=1)
        self.hd5_UT_hd1_bn = nn.BatchNorm(self.CatChannels)
        self.hd5_UT_hd1_relu = nn.ReLU()
        # fusion(h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1)
        self.conv1d_1 = nn.Conv2D(self.UpChannels, self.UpChannels, 3, padding=1)  # 16
        self.bn1d_1 = nn.BatchNorm(self.UpChannels)
        self.relu1d_1 = nn.ReLU()
        # output
        self.outconv1 = nn.Conv2D(self.UpChannels, n_classes, 3, padding=1)
        # initialise weights
        for m in self.sublayers ():
            if isinstance(m, nn.Conv2D):
                m.weight_attr = init_weights(init_type='kaiming')
                m.bias_attr = init_weights(init_type='kaiming')
            elif isinstance(m, nn.BatchNorm):
                m.param_attr =init_weights(init_type='kaiming')
                m.bias_attr = init_weights(init_type='kaiming')
    def forward(self, inputs):
        ## -------------Encoder-------------
        h1 = self.conv1(inputs)  # h1->320*320*64
        h2 = self.maxpool1(h1)
        h2 = self.conv2(h2)  # h2->160*160*128
        h3 = self.maxpool2(h2)
        h3 = self.conv3(h3)  # h3->80*80*256
        h4 = self.maxpool3(h3)
        h4 = self.conv4(h4)  # h4->40*40*512
        h5 = self.maxpool4(h4)
        hd5 = self.conv5(h5)  # h5->20*20*1024
        ## -------------Decoder-------------
        h1_PT_hd4 = self.h1_PT_hd4_relu(self.h1_PT_hd4_bn(self.h1_PT_hd4_conv(self.h1_PT_hd4(h1))))
        h2_PT_hd4 = self.h2_PT_hd4_relu(self.h2_PT_hd4_bn(self.h2_PT_hd4_conv(self.h2_PT_hd4(h2))))
        h3_PT_hd4 = self.h3_PT_hd4_relu(self.h3_PT_hd4_bn(self.h3_PT_hd4_conv(self.h3_PT_hd4(h3))))
        h4_Cat_hd4 = self.h4_Cat_hd4_relu(self.h4_Cat_hd4_bn(self.h4_Cat_hd4_conv(h4)))
        hd5_UT_hd4 = self.hd5_UT_hd4_relu(self.hd5_UT_hd4_bn(self.hd5_UT_hd4_conv(self.hd5_UT_hd4(hd5))))
        hd4 = self.relu4d_1(self.bn4d_1(self.conv4d_1(
            paddle.concat([h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4], 1)))) # hd4->40*40*UpChannels
        h1_PT_hd3 = self.h1_PT_hd3_relu(self.h1_PT_hd3_bn(self.h1_PT_hd3_conv(self.h1_PT_hd3(h1))))
        h2_PT_hd3 = self.h2_PT_hd3_relu(self.h2_PT_hd3_bn(self.h2_PT_hd3_conv(self.h2_PT_hd3(h2))))
        h3_Cat_hd3 = self.h3_Cat_hd3_relu(self.h3_Cat_hd3_bn(self.h3_Cat_hd3_conv(h3)))
        hd4_UT_hd3 = self.hd4_UT_hd3_relu(self.hd4_UT_hd3_bn(self.hd4_UT_hd3_conv(self.hd4_UT_hd3(hd4))))
        hd5_UT_hd3 = self.hd5_UT_hd3_relu(self.hd5_UT_hd3_bn(self.hd5_UT_hd3_conv(self.hd5_UT_hd3(hd5))))
        hd3 = self.relu3d_1(self.bn3d_1(self.conv3d_1(
            paddle.concat([h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3], 1)))) # hd3->80*80*UpChannels
        h1_PT_hd2 = self.h1_PT_hd2_relu(self.h1_PT_hd2_bn(self.h1_PT_hd2_conv(self.h1_PT_hd2(h1))))
        h2_Cat_hd2 = self.h2_Cat_hd2_relu(self.h2_Cat_hd2_bn(self.h2_Cat_hd2_conv(h2)))
        hd3_UT_hd2 = self.hd3_UT_hd2_relu(self.hd3_UT_hd2_bn(self.hd3_UT_hd2_conv(self.hd3_UT_hd2(hd3))))
        hd4_UT_hd2 = self.hd4_UT_hd2_relu(self.hd4_UT_hd2_bn(self.hd4_UT_hd2_conv(self.hd4_UT_hd2(hd4))))
        hd5_UT_hd2 = self.hd5_UT_hd2_relu(self.hd5_UT_hd2_bn(self.hd5_UT_hd2_conv(self.hd5_UT_hd2(hd5))))
        hd2 = self.relu2d_1(self.bn2d_1(self.Conv2D_1(
            paddle.concat([h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2], 1)))) # hd2->160*160*UpChannels
        h1_Cat_hd1 = self.h1_Cat_hd1_relu(self.h1_Cat_hd1_bn(self.h1_Cat_hd1_conv(h1)))
        hd2_UT_hd1 = self.hd2_UT_hd1_relu(self.hd2_UT_hd1_bn(self.hd2_UT_hd1_conv(self.hd2_UT_hd1(hd2))))
        hd3_UT_hd1 = self.hd3_UT_hd1_relu(self.hd3_UT_hd1_bn(self.hd3_UT_hd1_conv(self.hd3_UT_hd1(hd3))))
        hd4_UT_hd1 = self.hd4_UT_hd1_relu(self.hd4_UT_hd1_bn(self.hd4_UT_hd1_conv(self.hd4_UT_hd1(hd4))))
        hd5_UT_hd1 = self.hd5_UT_hd1_relu(self.hd5_UT_hd1_bn(self.hd5_UT_hd1_conv(self.hd5_UT_hd1(hd5))))
        hd1 = self.relu1d_1(self.bn1d_1(self.conv1d_1(
            paddle.concat([h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1], 1)))) # hd1->320*320*UpChannels
        d1 = self.outconv1(hd1)  # d1->320*320*n_classes
        if self.end_sigmoid:
            out = F.sigmoid(d1)
        else:
            out = d1
        return out
# View network structure
KIunet = kiunet(in_channels = 3, n_classes =3)
Unet3p = UNet_3Plus( in_channels=3, n_classes=3)


model = paddle.Model(Unet3p)
model.summary((2,3, 256, 256))
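
As a quick sanity check (a sketch, not part of the original code), a random tensor can be pushed through the UNet 3+ model to confirm that the output keeps the input resolution:

x = paddle.rand([1, 3, 256, 256], dtype='float32')
y = Unet3p(x)
print(y.shape)  # expected: [1, 3, 256, 256]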

⚔️ 3.4 introduction to training function

This section describes the specific use of the template:

  • Model operation
Value               | self.is_Train     | self.PATH
False               | inference only    | train from scratch
True / a model path | train the model   | load this model and continue training

Note: a "true" self.PATH means a model path is supplied. For example, when self.is_Train is False and self.PATH is a model path, no training is performed; inference runs directly with that model to generate the results. A usage sketch follows.
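
A hedged usage sketch of the two modes (the Solver class is defined further below; the model path shown is hypothetical):

solver = Solver()
# (a) train from scratch
solver.is_Train, solver.PATH = True, False
# (b) inference only, loading an existing model (hypothetical file name)
# solver.is_Train, solver.PATH = False, 'work/saveModel/Unet3p0.62.pdparams'
solver.run()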

  • loss function

The loss.py file in the work directory implements:

---- TVLoss

---- SSIMLoss

---- Perceptual loss (VGG19)

...

You can modify these or add your own as needed; a minimal sketch of one such loss follows.
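
loss.py itself is not reproduced in this article. Purely as an illustration, a minimal total-variation loss in Paddle could look like this (the class name and details are assumptions, not the actual implementation in work/loss.py):

import paddle

class TVLossSketch(paddle.nn.Layer):
    # Penalizes differences between neighbouring pixels of an N x C x H x W tensor
    def forward(self, x):
        dh = paddle.abs(x[:, :, 1:, :] - x[:, :, :-1, :]).mean()
        dw = paddle.abs(x[:, :, :, 1:] - x[:, :, :, :-1]).mean()
        return dh + dw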

  • Optimal model

The score variable defined in the Eval method records the evaluation metric on the validation set.

The score is computed on the validation set once per epoch.

Only when a higher score is achieved is the previous best model deleted and the new optimal model saved.

Note: to customize the evaluation metric, you only need to modify how score is calculated, as sketched below.
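
For example, if model selection should depend on SSIM alone, only the score line in Eval needs to change (a sketch):

# in Eval(), replace the weighted combination with, e.g.:
score = np.mean(temp_eval_ssim)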

  • Save results

The code automatically generates a series of directories under the work directory:

  • log

    Used to save the results of the training process

  • outputs

    Used to save the inference results, in a subfolder named after the model

  • saveModel

    Used to store the optimal model, named as model name + validation set score

Note: you should customize self.Modelname, which is used to distinguish the outputs of different models

## Main function definition Baseline 
# Basic assumption (denoising formulation): clean image + noise = contaminated image
# The pixels darkened by the shadow need their values increased,
# so: contaminated image + predicted compensation = clean image
"""
@author: xupeng
"""

from work.loss import *
from work.util import batch_psnr_ssim
from PIL import Image
import numpy as np
import pandas as pd
from tqdm import tqdm
import os
import matplotlib.pyplot as plt
import glob
from PIL import Image
from skimage.measure.simple_metrics import compare_psnr
import skimage
import paddle.vision.transforms as T
import cv2

class Solver(object):
    def __init__(self):
        self.model = None
        self.lr = 1e-4 # Learning rate
        self.epochs = 10  # Number of training epochs
        self.batch_size = 4 # Batch size
        self.optimizer = None
        self.scheduler = None
        self.saveSetDir = r'data/delight_testA_dataset/images/' # Test set address
        
        self.train_set = None
        self.eval_set = None
        self.train_loader = None
        self.eval_loader = None
        
        self.Modelname = 'Unet3p' # Name of the network used (generated files are distinguished by this name)
        
        self.evaTOP = 0
        self.is_Train = True # Whether to train the model. If False, provide self.PATH and run inference directly
        self.PATH = False # Path of the optimal model. When continuing from a pretrained model, this must not be empty


    def InitModeAndData(self):
        print("---------------trainInit:---------------")
        # See paddle.vision.transforms in the API docs for the available data augmentation methods

        is_transforms = True

        self.train_set =  DEshadowDataset(mode='train',is_transforms = is_transforms)
        self.eval_set  =  DEshadowDataset(mode='eval',is_transforms = is_transforms)

        # Use paddle.io.DataLoader to define the DataLoader objects that load the datasets

        self.train_loader = paddle.io.DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)
        self.eval_loader = paddle.io.DataLoader(self.eval_set, batch_size=self.batch_size, shuffle=False)

        self.model = UNet_3Plus( in_channels=3, n_classes=3)

        if self.is_Train and self.PATH: # When there are already models to be trained, conduct secondary training
            params = paddle.load(self.PATH)
            self.model.set_state_dict(params)
            self.evaTOP = float(self.PATH.split(self.Modelname)[-1].replace('.pdparams',''))

        self.scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=self.lr, T_max=30, verbose=False)
        self.optimizer = paddle.optimizer.Adam(parameters=self.model.parameters(), learning_rate=self.scheduler,beta1=0.5, beta2=0.999)

        # create folder 
        for item in ['log','outputs','saveModel']:
            make_folder = os.path.join('work',item)
            if  not os.path.exists(make_folder):
                os.mkdir(make_folder)

    def train(self,epoch):

        print("=======Epoch:{}/{}=======".format(epoch,self.epochs))
        self.model.train()
        TV_criterion = nTVLoss()    # Total-variation loss, regularizes the predicted noise
        ssim_criterion = SSIMLoss() # Structural-similarity loss against the ground truth
        lossnet = LossVGG19() # Perceptual loss (VGG19)
        l1_loss = paddle.nn.L1Loss()
        Mse_loss = nn.MSELoss()

        try: # This pattern (try/except plus ascii=True) avoids garbled output on the Windows terminal and handles abnormal tqdm interruption
            with tqdm(enumerate(self.train_loader),total=len(self.train_loader), ascii=True) as tqdmData:
                mean_loss = []
                
                for idx, (img_train,mask_train) in tqdmData:
                    tqdmData.set_description('train')

                    img_train = paddle.to_tensor(img_train,dtype="float32")
                    mask_train =  paddle.to_tensor(mask_train,dtype="float32")


                    # MODEL
                    outputs_noise = self.model(img_train) # Mask > img, so you need to increase the pixel value
                    mask_noise = mask_train - img_train # Difference between real image and shadow


                    # Recovered image
                    restore_trian =img_train + outputs_noise
                    '''
                    # De averaging
                    tensor_c = paddle.to_tensor(np.array([123.6800, 116.7790, 103.9390]).astype(np.float32).reshape((1, 3, 1, 1)))

                    # Perceived loss
                    # preceptual loss
                    loss_fake_B = lossnet(restore_trian * 255 - tensor_c)
                    loss_real_B = lossnet(mask_train * 255 - tensor_c)
                    p0 = l1_loss(restore_trian * 255 - tensor_c, mask_train * 255 - tensor_c) * 2
                    p1 = l1_loss(loss_fake_B['relu1'], loss_real_B['relu1']) / 2.6
                    p2 = l1_loss(loss_fake_B['relu2'], loss_real_B['relu2']) / 4.8

                    loss_p = p0 + p1 + p2

                    loss = loss_p  + ssim_criterion(restore_trian,mask_train) + 10*l1_loss(outputs_noise,mask_noise)
                    '''

                    loss = l1_loss(restore_trian,mask_train)

                    self.optimizer.clear_grad()
                    loss.backward()
                    self.optimizer.step()
                    
                    self.scheduler.step() ### Remember to change this when you change the optimizer

                    mean_loss.append(loss.item())

        except KeyboardInterrupt:
            tqdmData.close()
            os._exit(0)
        tqdmData.close()
        # Clear intermediate variables to free memory
        del loss,img_train,mask_train,outputs_noise,mask_noise,restore_trian
        paddle.device.cuda.empty_cache()
        return {'Mean_trainLoss':np.mean(mean_loss)}
    
    def Eval(self,modelname):
        
        self.model.eval()
        temp_eval_psnr ,temp_eval_ssim= [],[]
        with paddle.no_grad():
            try:
                with tqdm(enumerate(self.eval_loader),total=len(self.eval_loader), ascii=True) as tqdmData:
                    for idx, (img_eval,mask_eval) in tqdmData:
                        tqdmData.set_description(' eval')

                        img_eval=  paddle.to_tensor(img_eval,dtype="float32")
                        mask_eval = paddle.to_tensor(mask_eval,dtype="float32")

                        outputs_denoise = self.model(img_eval) # Model output
                        outputs_denoise = img_eval + outputs_denoise # Restored image

                        psnr_test,ssim_test = batch_psnr_ssim(outputs_denoise, mask_eval, 1.)
                        temp_eval_psnr.append(psnr_test)
                        temp_eval_ssim.append(ssim_test)
                    
            except KeyboardInterrupt:
                tqdmData.close()
                os._exit(0)
            tqdmData.close()
            paddle.device.cuda.empty_cache()

            # Print test PSNR & SSIM
        # print('eval_psnr:',np.mean(temp_eval_psnr),'eval_ssim:',np.mean(temp_eval_ssim))
        # Realize the evaluation index
        score = 0.05*np.mean(temp_eval_psnr)+0.5*np.mean(temp_eval_ssim)

        return {'eval_psnr':np.mean(temp_eval_psnr),'eval_ssim':np.mean(temp_eval_ssim),'SCORE':score}
    
    def saveModel(self,trainloss,modelname):
        
        trainLoss = trainloss['SCORE']
        if trainLoss < self.evaTOP and self.evaTOP!=0: 

            return 0
        else:
            folder = 'work/saveModel/'
            self.PATH = folder+modelname+str(trainLoss)+'.pdparams'
            removePATH = folder+modelname+str(self.evaTOP)+'.pdparams'
            paddle.save(self.model.state_dict(), self.PATH)

            if self.evaTOP!=0:
                os.remove(removePATH)
            
            self.evaTOP = trainLoss
            return 1
        
    def saveResult(self):
        print("---------------saveResult:---------------")

        self.model.set_state_dict(paddle.load(self.PATH))
        self.model.eval()

        paddle.set_grad_enabled(False)
        paddle.device.cuda.empty_cache()

        data_dir = glob.glob(self.saveSetDir+'*.jpg')

        # Create save folder
        make_save_result = os.path.join('work/outputs',self.Modelname)
        if  not os.path.exists(make_save_result):
            os.mkdir(make_save_result)

        saveSet = pd.DataFrame()
        tpsnr,tssim = [],[]
        
        for idx,ori_path in enumerate(data_dir):
            print(len(data_dir),'|',idx+1,end = '\r',flush = True)
            
            ori = cv2.imread(ori_path) # W,H,C
            ori = cv2.cvtColor(ori, cv2.COLOR_BGR2RGB) # BGR 2 RGB
            ori_w,ori_h,ori_c = ori.shape

            # normalize_test = T.Normalize( 
            # [0.610621, 0.5989216, 0.5876396], 
            # [0.1835931, 0.18701428, 0.19362564],
            # data_format='HWC')

            #ori = normalize_test(ori)

            ori = T.functional.resize(ori,(1024,1024),interpolation = 'bicubic') ### Instead of tiling, resize the whole image to 1024x1024

            # from HWC to CHW ,[0,255] to [0,1]
            ori = np.transpose(ori,(2,0,1))/255


            ori_img = paddle.to_tensor(ori,dtype="float32")
            ori_img = paddle.unsqueeze(ori_img,0) # N C W H

            out_noise = self.model(ori_img) # Feed the whole resized image, not patches
            out_noise = ori_img + out_noise #Add recovered pixels

            img_cpu = out_noise.cpu().numpy()
            #Save results
            img_cpu = np.squeeze(img_cpu)
            img_cpu = np.transpose(img_cpu,(1,2,0)) # C,W,H to W,H,C
            
            img_cpu = np.clip(img_cpu, 0., 1.)
            savepic = np.uint8(img_cpu*255)

            savepic = T.functional.resize(savepic,(ori_w,ori_h),interpolation = 'bicubic') ## Resize back to the original size

            # Save path
            savedir = os.path.join(make_save_result,ori_path.split('/')[-1])

            savepic = cv2.cvtColor(savepic, cv2.COLOR_RGB2BGR) # BGR
            cv2.imwrite(savedir, savepic, [int( cv2.IMWRITE_JPEG_QUALITY), 100])

    def run(self):
        
        self.InitModeAndData()
            
        if self.is_Train:
            modelname = self.Modelname #  Network name used

            result = pd.DataFrame()
            

            for epoch in range(1, self.epochs + 1):
                
                trainloss = self.train(epoch)
                evalloss =  self.Eval(modelname)#

                Type = self.saveModel(evalloss,modelname)
                
                type_ = {'Type':Type}
                trainloss.update(evalloss)#
                trainloss.update(type_)

                result = result.append(trainloss,ignore_index=True)
                print('Epoch:',epoch,trainloss)
                
                #self.scheduler.step()

            evalloss =  self.Eval(modelname)#

            result.to_csv('work/log/' +modelname+str(evalloss['SCORE'])+'.csv')#

            self.saveResult()

        else:

            self.saveResult()

def main():
    solver = Solver()
    solver.run()
    
if __name__ == '__main__':
    main()

Go to the output directory and create a readme.txt file with the required content:

Training framework: PaddlePaddle

Code running environment: V100

Whether to use GPU: Yes

Single picture time / s: 1

Model size: 45

Other notes: the algorithm is based on UNet 3+
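
For instance, the file can be written directly from the notebook (a sketch; the output folder matches self.Modelname and the values are those listed above):

readme_lines = [
    'Training framework: PaddlePaddle',
    'Code running environment: V100',
    'Whether to use GPU: Yes',
    'Single picture time / s: 1',
    'Model size: 45',
    'Other notes: the algorithm is based on UNet 3+',
]
with open('work/outputs/Unet3p/readme.txt', 'w') as f:
    f.write('\n'.join(readme_lines))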

# %cd /home/aistudio/work/outputs/Unet3p
# !zip result.zip *.jpg *.txt

🐅 4 project summary

  • The project implements a set of simple training templates.

  • The project was built around the document shadow removal task; the submitted result scores 0.59951.

Welcome to modify your algorithm on this basis!

Note: no parameter tuning has been done; the settings only ensure that the template runs correctly. If higher accuracy is required, leave a message under the project; if enough people are interested, it will be updated.

If you have any questions, please leave a message in the comment area.

Keywords: Machine Learning Computer Vision Deep Learning
