How to define your own Dataset using the Dataset class in PyTorch


Preface

When reproducing an earlier paper, I found that the author had written the dataset processing directly into the source code, whereas most PyTorch code today uses the Dataset class together with DataLoader to read the data. I therefore tried to rewrite the source code into a structure that meets the requirements of the Dataset class. There are plenty of tutorials online; here I mainly record my own learning process.

1, What is the Dataset class?

The Dataset class is the dataset interface officially defined by PyTorch. We can build our own data interface on top of it to meet any requirement. Let's start with the official code: subclasses must override __getitem__ and __len__. In fact, once these two methods are implemented, the dataset can be used directly.

class Dataset(object):
    """An abstract class representing a Dataset.
    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """
 
    def __getitem__(self, index):
        raise NotImplementedError
 
    def __len__(self):
        raise NotImplementedError
 
    def __add__(self, other):
        return ConcatDataset([self, other])

Here, __len__ returns the length of your dataset, i.e. the number of samples used in training, while __getitem__ returns a sample pair (x, y) for a given index. The index ranges from 0 to len(dataset) - 1, i.e. an index sequence starting from 0 and determined by the length of your dataset.

So you only need to define the dataset's length and how an index maps to a sample; the rest is the preprocessing of your own data. Next, I will rewrite the specific dataset I encountered this time.
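Before moving to the real data, here is a minimal sketch (not from the original source code, just for illustration) of a custom Dataset that wraps two in-memory tensors, showing that implementing __len__ and __getitem__ is all that is needed:

import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    """Minimal illustrative dataset wrapping pre-loaded inputs and labels."""
    def __init__(self, xs, ys):
        assert len(xs) == len(ys)
        self.xs = xs
        self.ys = ys

    def __len__(self):
        #Number of samples in the dataset
        return len(self.xs)

    def __getitem__(self, index):
        #Return one (x, y) pair for the given index
        return self.xs[index], self.ys[index]

toy = ToyDataset(torch.randn(10, 3), torch.randint(0, 2, (10,)))
loader = DataLoader(toy, batch_size=4, shuffle=True)
for x_batch, y_batch in loader:
    print(x_batch.shape, y_batch.shape)  #e.g. torch.Size([4, 3]) torch.Size([4]); the last batch is smaller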

2, Rewrite steps

1. Importing the required packages

The code is as follows:

from torch.utils.data.dataset import Dataset  #Introduce the Dataset class
from torch.utils.data import DataLoader       #Used to read the defined dataset
import numpy as np
import os                                      #Used below to list the image files
import cv2                                     #Used below to read and resize images
from PIL import Image
import torch

2. Dataset introduction

Human3.6M: this dataset is a 3D human pose estimation dataset, but the experiment here is in the video prediction direction, so only the image information is used; those who need it can download it themselves. S1, S5, S6, S7 and S8 are used as the training set, and S9 and S11 as the test set.

Note that, due to the particularity of the video prediction task, a sample returned by our dataset is not of the form (image, label) but (input, output). Both input and output are tensors made up of multiple video frames, and output contains the frames to be predicted. (For example, in the code below, both input and output have shape (4, 3, 128, 128), representing four RGB images of size 128x128.)
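As a quick sanity check (hypothetical code, assuming the Human class written in the next step and a local copy of the dataset at the path used later), a single sample can be inspected like this:

#Hypothetical check; the path and arguments must match your own setup
dataset = Human('/data/datasets/Human3.6M/Human3.6M/', nt_cond=4, nt_pred=4, train=True)
inputs, targets = dataset[0]
print(inputs.shape)   #expected: torch.Size([4, 3, 128, 128])
print(targets.shape)  #expected: torch.Size([4, 3, 128, 128])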

3. Data set rewriting

First, we preprocess the dataset according to its own characteristics and turn it into the format we need.

class Human(Dataset):
    def __init__(self, data_dir, nt_cond, nt_pred, train):
        super(Human, self).__init__()
        self.data_dir = data_dir #Raw data directory
        self.pred_h = nt_pred    #Number of predicted video frames
        self.lb = nt_cond        #Number of conditional frames (given lb frames, the model predicts the subsequent pred_h frames)
        self.train = train       #Flag distinguishing the training set from the test set
        self.image_width = 128
        self.seq_len = nt_cond + nt_pred
        self.data, self.indices = self.load_data(data_dir, train) #Read and index the dataset

    def load_data(self, paths, train=True):
        data_dir = paths
        interval = 2  #Temporal sub-sampling step between frames

        frames_np = []
        scenarios = ['Walking']
        if train == True:
            subjects = ['S1', 'S5', 'S6', 'S7', 'S8']
        elif train == False:
            subjects = ['S9', 'S11']
        else:
            print ("MODE ERROR")
        _path = data_dir
        print ('load data...', _path)
        filenames = os.listdir(_path)
        filenames.sort()
        print ('data size ', len(filenames))
        frames_file_name = []
        for filename in filenames:
            fix = filename.split('.')
            fix = fix[0]
            subject = fix.split('_')
            scenario = subject[1]
            subject = subject[0]
            if subject not in subjects or scenario not in scenarios:
                continue
            file_path = os.path.join(_path, filename)
            image = cv2.cvtColor(cv2.imread(file_path), cv2.COLOR_BGR2RGB)
            #[1000,1000,3]
            image = image[image.shape[0]//4:-image.shape[0]//4, image.shape[1]//4:-image.shape[1]//4, :]
            if self.image_width != image.shape[0]:
                image = cv2.resize(image, (self.image_width, self.image_width))
            #image = cv2.resize(image[100:-100,100:-100,:], (self.image_width, self.image_width),
            #                   interpolation=cv2.INTER_LINEAR)
            #[128,128,3]
            frames_np.append(np.array(image, dtype=np.float32) / 255.0)
            frames_file_name.append(filename)
#             if len(frames_np) % 100 == 0: print len(frames_np)
            #if len(frames_np) % 1000 == 0: break
        # check whether each index is the start of a valid sequence
        indices = []
        index = 0
        print ('gen index')
        while index + interval * self.seq_len - 1 < len(frames_file_name):
            # 'S11_Discussion_1.54138969_000471.jpg'
            # ['S11_Discussion_1', '54138969_000471', 'jpg']
            start_infos = frames_file_name[index].split('.')
            end_infos = frames_file_name[index + interval * (self.seq_len - 1)].split('.')
            if start_infos[0] != end_infos[0]:
                index += 1
                continue
            start_video_id, start_frame_id = start_infos[1].split('_')
            end_video_id, end_frame_id = end_infos[1].split('_')
            if start_video_id != end_video_id:
                index += 1
                continue
            if int(end_frame_id) - int(start_frame_id) == 5 * (self.seq_len - 1) * interval:
                indices.append(index)
            if train == True:
                index += 10
            elif train == False:
                index += 5
        print("there are " + str(len(indices)) + " sequences")
        self.data = np.asarray(frames_np)  #Convert to an array once here so __getitem__ does not have to
        print("there are " + str(len(self.data)) + " pictures")
        return self.data, indices

The load_data() method processes the source video frames into the 128x128x3 format we need. In addition, since each sample of our dataset consists of consecutive frames, we need to build our own indices: a sequence start index is taken every 10 frames on the training set and every 5 frames on the test set, as illustrated in the sketch below.
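To make the index check concrete, here is a small standalone sketch (with made-up filenames following the same 'Sx_Action_cam.video_frame.jpg' pattern) of the condition a start index has to satisfy, assuming the frame number advances by 5 between consecutive saved images:

seq_len, interval = 8, 2  #4 conditional + 4 predicted frames, taking every 2nd file

start_name = 'S1_Walking_1.54138969_000005.jpg'
end_name   = 'S1_Walking_1.54138969_000075.jpg'  #hypothetical file seq_len - 1 steps later

start_scene, start_rest, _ = start_name.split('.')
end_scene, end_rest, _ = end_name.split('.')
start_video, start_frame = start_rest.split('_')
end_video, end_frame = end_rest.split('_')

same_clip = (start_scene == end_scene) and (start_video == end_video)
expected_gap = 5 * (seq_len - 1) * interval  #frame numbers step by 5 per saved image
print(same_clip and int(end_frame) - int(start_frame) == expected_gap)  #True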

After writing the initialization function, we fill in the corresponding __len__ and __getitem__ methods, and our dataset is ready to use.

    def __getitem__(self, index):
        idx_id = self.indices[index]
        #Slice lb conditional frames and pred_h target frames, then move channels first: (T, H, W, C) -> (T, C, H, W)
        inputs = self.data[idx_id:idx_id + self.lb].transpose(0, 3, 1, 2)
        targets = self.data[idx_id + self.lb + 1:idx_id + self.lb + 1 + self.pred_h].transpose(0, 3, 1, 2)
        return torch.tensor(inputs, dtype=torch.float), torch.tensor(targets, dtype=torch.float)
        

    def __len__(self):
        return len(self.indices)
        #Since one sample of the dataset is an (input, output) pair, the dataset length is the length of self.indices

4. Dataset call

After constructing the Human dataset class, we read it with a DataLoader.

import argparse  #Assuming a minimal argparse setup wrapped around the original two arguments

parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, default='/data/datasets/Human3.6M/Human3.6M/')
parser.add_argument('--batch_size', type=int, default=16, help='batch_size')
args = parser.parse_args()

human_train = Human(args.root, nt_cond=4, nt_pred=4, train=True)
train_loader = DataLoader(dataset=human_train, batch_size=args.batch_size, pin_memory=True, shuffle=True, num_workers=0)
#batch_size here is the number of samples fed in at a time; shuffle indicates whether to shuffle the dataset, i.e. whether the indices are drawn in order
human_test = Human(args.root, nt_cond=4, nt_pred=4, train=False)
test_loader = DataLoader(dataset=human_test, batch_size=args.batch_size, pin_memory=True, shuffle=False, num_workers=0)
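Once the loaders are built, the batches can be consumed in the usual way. Below is a minimal sketch of a training-style loop; the model and loss shown in the comments are placeholders, not part of the original code:

device = 'cuda' if torch.cuda.is_available() else 'cpu'

for epoch in range(1):                  #single epoch, just to illustrate the loop
    for inputs, targets in train_loader:
        inputs = inputs.to(device)      #(batch_size, 4, 3, 128, 128)
        targets = targets.to(device)    #(batch_size, 4, 3, 128, 128)
        #preds = model(inputs)          #hypothetical video prediction model
        #loss = torch.nn.functional.mse_loss(preds, targets)
        print(inputs.shape, targets.shape)
        break                           #remove this to iterate over the whole loader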

Summary

In the end, rewriting the Dataset class only requires implementing __len__ and __getitem__ correctly in the subclass; the key part is the dataset preprocessing in the initialization, which still needs further study.

