Official explanation: DataLoader combines a dataset and a sampler and provides an iterable over the given dataset.
Main parameters:
1. dataset: the dataset to load from, usually an instance of torch.utils.data.Dataset or of a class inherited from it.
Its main method is __getitem__(self, index), which returns the sample at position index.
2. batch_size: how many samples each batch contains (default 1).
3. shuffle: whether to reshuffle the data at every epoch. The default is False.
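4. sampler: the sampling strategy, i.e. which indices are drawn and in what order. It is usually torch.utils.data.sampler.Sampler itself or a class inherited from it. When a sampler is given, shuffle should not be set: any randomization is then the sampler's job, and PyTorch raises an error if both are specified.
Its main method is __iter__(self), which yields sample indices. DataLoader wraps the sampler in a BatchSampler built from batch_size and drop_last, and each step of the BatchSampler's __iter__ collects batch_size indices, i.e. one batch: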
def __iter__(self):
    batch = []
    for idx in self.sampler:
        batch.append(idx)
        if len(batch) == self.batch_size:
            yield batch
            batch = []
    if len(batch) > 0 and not self.drop_last:
        yield batch
5. ... (the remaining parameters are not covered here)
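The grouping performed by the BatchSampler loop above can be tried out without PyTorch. The following standalone sketch mirrors that loop; the function name batch_indices is ours (hypothetical), not part of PyTorch. It also shows what drop_last changes when the number of indices is not a multiple of batch_size.

```python
from typing import Iterable, Iterator, List


def batch_indices(sampler: Iterable[int], batch_size: int, drop_last: bool) -> Iterator[List[int]]:
    """Group the indices produced by `sampler` into lists of `batch_size`.

    Hypothetical helper that mirrors the BatchSampler.__iter__ loop shown above.
    """
    batch: List[int] = []
    for idx in sampler:
        batch.append(idx)
        if len(batch) == batch_size:
            yield batch          # a full batch is handed out ...
            batch = []           # ... and collection starts over
    if len(batch) > 0 and not drop_last:
        yield batch              # leftover partial batch, unless drop_last=True


if __name__ == "__main__":
    print(list(batch_indices(range(10), batch_size=3, drop_last=False)))
    # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
    print(list(batch_indices(range(10), batch_size=3, drop_last=True)))
    # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
```

With PyTorch installed, list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)) from torch.utils.data should give the same [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]].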
Here is my code:
trainloader = DataLoader(
    ImageDataset(self.dataset.train, transform=self.transform_train),
    # Select config.k samples for each id in the incoming data.
    # The second element of each sample is its class (pid), so class_position=1
    sampler=ClassUniformlySampler(self.dataset.train, class_position=1, k=config.k),
    batch_size=config.p * config.k,
    num_workers=config.workers,
    # shuffle=True,  # leave unset: ClassUniformlySampler already shuffles, and sampler + shuffle=True raises an error
    pin_memory=pin_memory,
    drop_last=False
)

for batch_idx, (imgs, pids, _) in enumerate(trainloader):
    print("batch_idx: ", batch_idx)
    for i in range(len(pids)):
        print(pids[i], imgs[i].shape)
At first I didn't understand how it works internally: why does simply running the enumerate loop keep returning exactly the data I need? Later I stepped through the whole code to figure it out (if you have the same question, read on; if not, you can stop here).
When the statement trainloader = DataLoader(...) is executed, DataLoader, ImageDataset and ClassUniformlySampler do nothing special: each is only constructed through its __init__.
The ClassUniformlySampler used here is a subclass of Sampler. It keeps exactly k samples for every ID (class) in the data: a class with at least k samples gets k of them drawn at random, and a class with fewer has its indices repeated until there are k. To do this, __init__ builds a dictionary whose keys are the classes and whose values are the indices of all samples belonging to each class. (This is only for explanation; no need to study it further.)
This sampler code was taken from elsewhere, and its original source is hard to trace; I am noting that here.
import random

from torch.utils.data import Sampler


class ClassUniformlySampler(Sampler):
    '''
    random sample according to class label
    Arguments:
        data_source (Dataset): data_loader to sample from
        class_position (int): which one is used as class
        k (int): sample k images of each class
    '''
    def __init__(self, data_source, class_position, k):
        self.class_position = class_position
        self.k = k
        self.samples = data_source
        # Returns a dictionary. key is the category and value is the index of all data belonging to the category
        self.class_dict = self._tuple2dict(self.samples)

    def __iter__(self):
        # A new sample list is drawn every time a new epoch starts
        self.sample_list = self._generate_list(self.class_dict)
        return iter(self.sample_list)

    def __len__(self):
        # Every class contributes exactly k indices. Computing this here (instead of
        # len(self.sample_list)) avoids an AttributeError when len() is called before __iter__()
        return len(self.class_dict) * self.k

    def _tuple2dict(self, inputs):
        '''
        :param inputs: list with tuple elements, [(image_path1, class_index_1), (image_path_2, class_index_2), ...]
        :return: dict, {class_index_i: [samples_index1, samples_index2, ...]}
        '''
        class_dict = {}
        for index, each_input in enumerate(inputs):
            class_index = each_input[self.class_position]
            class_dict.setdefault(class_index, []).append(index)
        return class_dict

    def _generate_list(self, class_dict):
        '''
        :param class_dict: dict, whose values are lists of sample indices
        :return: list with k indices per class, classes in random order
        '''
        sample_list = []
        keys = list(class_dict.keys())
        random.shuffle(keys)
        for key in keys:
            value = list(class_dict[key])  # copy, so the stored lists are not shuffled in place
            if len(value) < self.k:
                value = value * self.k  # too few samples: repeat them so at least k exist
            random.shuffle(value)
            sample_list.extend(value[0: self.k])
        return sample_list
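This also explains batch_size=config.p * config.k in my code: _generate_list lays the indices out as one contiguous block of k per class, with the classes in random order, so any p*k consecutive indices contain exactly p classes with k samples each. The sketch below re-implements the two helper methods in plain Python to check this; the function names build_class_dict and class_uniform_list, the toy labels, and the seeded random.Random are our own illustrative choices, not part of the original code.

```python
import random
from collections import Counter
from typing import Dict, List, Sequence


def build_class_dict(samples: Sequence[tuple], class_position: int) -> Dict[int, List[int]]:
    """{class label: [indices of all samples with that label]}, like _tuple2dict."""
    class_dict: Dict[int, List[int]] = {}
    for index, sample in enumerate(samples):
        class_dict.setdefault(sample[class_position], []).append(index)
    return class_dict


def class_uniform_list(class_dict: Dict[int, List[int]], k: int, rng: random.Random) -> List[int]:
    """k indices per class, classes in random order, like _generate_list."""
    sample_list: List[int] = []
    keys = list(class_dict)
    rng.shuffle(keys)                      # random class order
    for key in keys:
        value = list(class_dict[key])      # copy, keep the dict intact
        if len(value) < k:
            value = value * k              # pad small classes by repetition
        rng.shuffle(value)
        sample_list.extend(value[:k])      # one contiguous block of k per class
    return sample_list


if __name__ == "__main__":
    # Toy (image_path, pid) pairs; pid 3 has a single image, so it will be repeated
    pids = [0, 0, 0, 1, 1, 1, 2, 2, 3]
    samples = [(f"img_{i}.jpg", pid) for i, pid in enumerate(pids)]
    p, k = 2, 2
    class_dict = build_class_dict(samples, class_position=1)
    sample_list = class_uniform_list(class_dict, k, random.Random(0))
    for start in range(0, len(sample_list), p * k):
        batch = sample_list[start:start + p * k]
        # each batch holds p distinct pids with k samples each
        print(batch, Counter(samples[i][1] for i in batch))
```

The exact indices depend on the random seed, but the p-classes-times-k-samples structure of each batch does not.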
The first time for batch_idx, (imgs, pids, _) in enumerate(trainloader) runs, the first thing called is the sampler's __iter__() method. It samples the whole dataset in one go, stores the selected indices in sample_list, and returns iter(sample_list). The BatchSampler __iter__ shown at the beginning then draws indices from this iterator and hands them out batch_size at a time.
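Then the Dataset comes on stage: it simply fetches the samples one by one through __getitem__, using the indices of the current batch taken from sample_list, and the collate function combines them into batched tensors (imgs, pids, ...). Once a full batch has been produced, the generator pauses at yield, so no further indices are taken until the next batch is requested.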
After that, each further iteration of the for loop resumes from the index where the previous batch stopped and fetches the next batch_size samples, until every index in sample_list has been used, which ends the epoch. The next epoch calls the sampler's __iter__() again and therefore gets a freshly sampled list.
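The call sequence described above (nothing happens at construction; the sampler's __iter__ runs once at the start of an epoch; __getitem__ runs only when a batch is requested; the next batch resumes where the previous one stopped) can be reproduced with a stripped-down, single-process loader. This is a simplified model under our own assumptions, not PyTorch's actual DataLoader, which additionally handles worker processes, memory pinning and, with num_workers > 0, prefetches a few batches ahead. The class names MiniLoader, TracingDataset and TracingSampler are hypothetical.

```python
from typing import Callable, Iterator, List


class TracingDataset:
    """Map-style dataset that records which indices were fetched."""
    def __init__(self, n: int):
        self.data = [i * 10 for i in range(n)]
        self.fetched: List[int] = []

    def __getitem__(self, index: int) -> int:
        self.fetched.append(index)
        return self.data[index]

    def __len__(self) -> int:
        return len(self.data)


class TracingSampler:
    """Sampler that counts how often __iter__ is called (once per epoch)."""
    def __init__(self, order: List[int]):
        self.order = order
        self.iter_calls = 0

    def __iter__(self) -> Iterator[int]:
        self.iter_calls += 1
        return iter(self.order)


class MiniLoader:
    """Single-process sketch: sampler -> batches of indices -> dataset -> collate."""
    def __init__(self, dataset, sampler, batch_size: int, drop_last: bool = False,
                 collate_fn: Callable[[list], list] = list):
        # Like DataLoader.__init__, this only stores its arguments
        self.dataset, self.sampler = dataset, sampler
        self.batch_size, self.drop_last, self.collate_fn = batch_size, drop_last, collate_fn

    def __iter__(self):
        batch = []
        for idx in self.sampler:              # sampler.__iter__ runs here, once per epoch
            batch.append(idx)
            if len(batch) == self.batch_size:
                # samples are fetched only now; collate_fn stands in for default_collate
                yield self.collate_fn([self.dataset[i] for i in batch])
                batch = []                    # the next request resumes from here
        if batch and not self.drop_last:
            yield self.collate_fn([self.dataset[i] for i in batch])


if __name__ == "__main__":
    ds, sp = TracingDataset(7), TracingSampler([6, 2, 4, 0, 1, 5, 3])
    loader = MiniLoader(ds, sp, batch_size=3)
    print(sp.iter_calls, ds.fetched)   # 0 []  (construction did nothing)
    it = enumerate(loader)
    print(next(it))                    # (0, [60, 20, 40])
    print(sp.iter_calls, ds.fetched)   # 1 [6, 2, 4]
    print(list(it))                    # [(1, [0, 10, 50]), (2, [30])]
```

Writing __iter__ as a generator is what makes the loop "remember" its position: each next() resumes right after the last yield.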
Note: because sampling already shuffles the order (both the classes and the samples within each class), reading sample_list sequentially still gives a random order. In addition, since each entry of sample_list is consumed exactly once, no entry is drawn twice within an epoch (the only repeats are indices duplicated on purpose to pad classes with fewer than k samples), and the epoch ends as soon as the list is exhausted.