Preface
When reproducing previous papers, the author wrote the dataset processing directly into the source code, but most code written with PyTorch now uses the Dataset class combined with DataLoader to read datasets. I therefore tried to rewrite the source code into a structure that meets the requirements of the Dataset class. There are plenty of tutorials on the Internet; here I mainly record my own learning process.
1, What is the Dataset class?
The Dataset class is the dataset interface officially defined by PyTorch. By following its conventions we can create our own data interface to meet any requirement. We start with the official code, in which __getitem__ and __len__ must be overridden by subclasses. In fact, with these two parts implemented, the dataset can be used directly.
class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])
Among them, __len__ returns the length of your dataset, that is, the number of samples used in training, and __getitem__ returns a sample pair (x, y) according to the index. The index here ranges from 0 to len(dataset) - 1, i.e. an index sequence starting from 0 determined by the length of your dataset.
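As an illustration (this toy class is not part of the Human3.6M code, just a minimal sketch), a dataset wrapping a list of (x, y) pairs only needs these two methods:

import torch
from torch.utils.data import Dataset

class ToyDataset(Dataset):
    """Minimal example: wraps paired tensors of inputs and labels."""
    def __init__(self, xs, ys):
        self.xs = xs
        self.ys = ys

    def __len__(self):
        return len(self.xs)                     # number of samples

    def __getitem__(self, index):
        # index runs from 0 to len(self) - 1
        return self.xs[index], self.ys[index]

toy = ToyDataset(torch.randn(10, 3), torch.arange(10))
x, y = toy[0]
print(len(toy), x.shape, y)   # 10 torch.Size([3]) tensor(0)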
In this way, you only need to define the length of the dataset and the corresponding indexing; the rest is the preprocessing of your own data. Next, I will rewrite the specific dataset encountered this time.
2, Rewrite steps
1. Importing the libraries
The code is as follows:
import os                                      # Used to traverse the raw image directory
import cv2                                     # Used to read and resize video frames
import numpy as np
import torch
from torch.utils.data.dataset import Dataset   # Introduce the Dataset class
from torch.utils.data import DataLoader        # Used to read the defined dataset
from PIL import Image
2. Dataset introduction
Human3.6M: this dataset is a 3D human pose estimation dataset, but the corresponding experiment is in the video prediction direction, so only the image information is used; readers who need it can click the hyperlink to download it themselves. S1, S5, S6, S7 and S8 are used as the training set, and S9 and S11 are used as the test set.
It should be noted that, due to the particularity of the video prediction task, a sample returned by our dataset is not (image, label) but (input, output). Both input and output are tensors composed of multiple video frames, and output contains the frames to be predicted. (For example, in the following code, both input and output have shape (4, 3, 128, 128), representing four RGB images of size 128x128.)
3. Dataset rewriting
First, we preprocess the dataset according to its own characteristics and convert it into the format we need.
class Human(Dataset):
    def __init__(self, data_dir, nt_cond, nt_pred, train):
        super(Human, self).__init__()
        self.data_dir = data_dir      # Raw data directory
        self.pred_h = nt_pred         # Number of video frames to predict
        self.lb = nt_cond             # Number of conditioning frames (given lb frames, the model predicts the following pred_h frames)
        self.train = train            # Flag distinguishing the training set from the test set
        self.image_width = 128
        self.seq_len = nt_cond + nt_pred
        self.data, self.indices = self.load_data(data_dir, train)  # Read the dataset
        # The lines above initialize the Human dataset class we define

    def load_data(self, paths, train=True):
        data_dir = paths
        interval = 2
        frames_np = []
        scenarios = ['Walking']
        if train == True:
            subjects = ['S1', 'S5', 'S6', 'S7', 'S8']
        elif train == False:
            subjects = ['S9', 'S11']
        else:
            print("MODE ERROR")
        _path = data_dir
        print('load data...', _path)
        filenames = os.listdir(_path)
        filenames.sort()
        print('data size ', len(filenames))
        frames_file_name = []
        for filename in filenames:
            fix = filename.split('.')
            fix = fix[0]
            subject = fix.split('_')
            scenario = subject[1]
            subject = subject[0]
            if subject not in subjects or scenario not in scenarios:
                continue
            file_path = os.path.join(_path, filename)
            image = cv2.cvtColor(cv2.imread(file_path), cv2.COLOR_BGR2RGB)  # [1000, 1000, 3]
            # Center-crop the frame, then resize it to 128x128
            image = image[image.shape[0]//4:-image.shape[0]//4,
                          image.shape[1]//4:-image.shape[1]//4, :]
            if self.image_width != image.shape[0]:
                image = cv2.resize(image, (self.image_width, self.image_width))  # [128, 128, 3]
            frames_np.append(np.array(image, dtype=np.float32) / 255.0)
            frames_file_name.append(filename)
            # if len(frames_np) % 100 == 0: print(len(frames_np))
            # if len(frames_np) % 1000 == 0: break

        # Decide whether each position is a valid beginning index of a sequence
        indices = []
        index = 0
        print('gen index')
        while index + interval * self.seq_len - 1 < len(frames_file_name):
            # 'S11_Discussion_1.54138969_000471.jpg'
            # ['S11_Discussion_1', '54138969_000471', 'jpg']
            start_infos = frames_file_name[index].split('.')
            end_infos = frames_file_name[index + interval * (self.seq_len - 1)].split('.')
            if start_infos[0] != end_infos[0]:
                index += 1
                continue
            start_video_id, start_frame_id = start_infos[1].split('_')
            end_video_id, end_frame_id = end_infos[1].split('_')
            if start_video_id != end_video_id:
                index += 1
                continue
            if int(end_frame_id) - int(start_frame_id) == 5 * (self.seq_len - 1) * interval:
                indices.append(index)
            if train == True:
                index += 10
            elif train == False:
                index += 5
        print("there are " + str(len(indices)) + " sequences")
        self.data = frames_np
        print("there are " + str(len(self.data)) + " pictures")
        return self.data, indices
The load_data() method processes the source video frames into the 128x128x3 format we need. In addition, since each sample of our dataset consists of consecutive frames, we need to build our own indices: a candidate starting index is taken every 10 frames on the training set and every 5 frames on the test set.
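To make the sliding-window idea clearer, here is a simplified, standalone sketch of the index generation; the filename and video-id consistency checks of the real load_data() are omitted, and frame_count, interval, seq_len and stride are illustrative parameters:

def gen_indices(frame_count, seq_len, interval=2, stride=10):
    """Simplified sketch: collect starting indices of windows that
    span interval * seq_len consecutive stored frames."""
    indices = []
    index = 0
    # A window starting at `index` uses frames index, index + interval, ...,
    # index + interval * (seq_len - 1); keep it only if it fits in the data.
    while index + interval * seq_len - 1 < frame_count:
        indices.append(index)
        index += stride  # 10 for the training set, 5 for the test set
    return indices

# e.g. 100 stored frames, sequences of 8 frames sampled every 2 frames
print(gen_indices(100, seq_len=8))  # [0, 10, 20, 30, 40, 50, 60, 70, 80]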
After writing the initialization function, we add the corresponding __len__ and __getitem__ methods to the subclass, and our dataset is ready to use.
    def __getitem__(self, index):
        idx_id = self.indices[index]
        self.data = np.array(self.data)  # Convert the stored frame list to an ndarray so it can be sliced
        # Slice out the conditioning / target frames (stored as HWC) and
        # rearrange them to (frames, C, H, W) for PyTorch
        inputs = self.data[idx_id:idx_id + self.lb].transpose(0, 3, 1, 2)
        targets = self.data[idx_id + self.lb + 1:idx_id + self.lb + 1 + self.pred_h].transpose(0, 3, 1, 2)
        return torch.tensor(inputs, dtype=torch.float), torch.tensor(targets, dtype=torch.float)

    def __len__(self):
        # Since one sample of the dataset is an (input, output) pair,
        # the dataset length is the length of self.indices
        return len(self.indices)
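As a quick sanity check (the path below is simply the default directory used in the next snippet; replace it with wherever your extracted Human3.6M frames live), we can fetch a single sample and verify its shape:

human_train = Human('/data/datasets/Human3.6M/Human3.6M/', nt_cond=4, nt_pred=4, train=True)
print(len(human_train))             # number of (input, output) samples
inputs, targets = human_train[0]    # calls __getitem__(0)
print(inputs.shape, targets.shape)  # torch.Size([4, 3, 128, 128]) for both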
4. Dataset call
After constructing the dataset class Human, we use DataLoader to read it.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, default='/data/datasets/Human3.6M/Human3.6M/')
parser.add_argument('--batch_size', type=int, default=16, help='batch_size')
args = parser.parse_args()

human_train = Human(args.root, nt_cond=4, nt_pred=4, train=True)
train_loader = DataLoader(dataset=human_train, batch_size=args.batch_size,
                          pin_memory=True, shuffle=True, num_workers=0)
# batch_size is the number of training samples fed in at a time; shuffle controls
# whether the dataset is shuffled, i.e. whether the indices are drawn in order

human_test = Human(args.root, nt_cond=4, nt_pred=4, train=False)
test_loader = DataLoader(dataset=human_test, batch_size=args.batch_size,
                         pin_memory=True, shuffle=False, num_workers=0)
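With the loaders in place, iterating over them yields batched tensor pairs. The following is only a minimal sketch of the consumption side; model and criterion are hypothetical placeholders, not part of the code above:

for inputs, targets in train_loader:
    # inputs:  (batch_size, 4, 3, 128, 128) conditioning frames
    # targets: (batch_size, 4, 3, 128, 128) frames to predict
    # preds = model(inputs)             # hypothetical video prediction model
    # loss = criterion(preds, targets)  # hypothetical reconstruction loss
    print(inputs.shape, targets.shape)
    break  # only inspect the first batch here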
Summary
Finally, for rewriting the Dataset class, correctly implementing __len__ and __getitem__ in the subclass is enough to meet the requirements; the key part is the dataset preprocessing in the initialization, which still needs further study.