# PyTorch Sampler

When training a neural network, the dataset is often too large to feed into the network in one pass, so it has to be read in batches. This raises the question of how, and in what order, samples are drawn from the dataset. PyTorch provides a Sampler base class and several subclasses that implement different sampling strategies.
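As a minimal sketch of how samplers tie into batch reading: DataLoader chooses a sampler on its own when none is passed. The small list `data` below is made up for illustration.

```python
from torch.utils.data import DataLoader

# Hypothetical toy dataset: any object with __len__ and __getitem__ works
data = [1, 2, 3, 4, 5]

# Without an explicit sampler, DataLoader picks one based on `shuffle`
ordered_loader = DataLoader(data, batch_size=2, shuffle=False)
shuffled_loader = DataLoader(data, batch_size=2, shuffle=True)

ordered_name = type(ordered_loader.sampler).__name__
shuffled_name = type(shuffled_loader.sampler).__name__
print(ordered_name)   # SequentialSampler
print(shuffled_name)  # RandomSampler
```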

## Base class Sampler

    class Sampler(object):
        r"""Base class for all Samplers.

        Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
        way to iterate over indices of dataset elements, and a :meth:`__len__` method
        that returns the length of the returned iterators.

        .. note:: The :meth:`__len__` method isn't strictly required by
                  :class:`~torch.utils.data.DataLoader`, but is expected in any
                  calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
        """

        def __init__(self, data_source):
            pass

        def __iter__(self):
            raise NotImplementedError
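To illustrate the contract of the base class, here is a minimal sketch of a custom subclass. `ReverseSampler` is a hypothetical name, not part of PyTorch; it only has to supply `__iter__` and `__len__`.

```python
from torch.utils.data import Sampler


class ReverseSampler(Sampler):
    """Hypothetical sampler that yields indices from last to first."""

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        # Yield len-1, len-2, ..., 0
        return iter(range(len(self.data_source) - 1, -1, -1))

    def __len__(self):
        return len(self.data_source)


data = [10, 20, 30, 40]
reverse_indices = list(ReverseSampler(data))
print(reverse_indices)  # [3, 2, 1, 0]
```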

## Sequential Sampler

    class SequentialSampler(Sampler):
        r"""Samples elements sequentially, always in the same order.

        Arguments:
            data_source (Dataset): dataset to sample from
        """

        def __init__(self, data_source):
            self.data_source = data_source

        def __iter__(self):
            return iter(range(len(self.data_source)))

        def __len__(self):
            return len(self.data_source)

The sequential sampler defines very few methods, and its initializer needs only a Dataset as a parameter.

The __len__() method simply returns the number of items in the data source. The __iter__() method returns an iterator over the ascending integer sequence produced by range(), so the indices are visited in order.

Each epoch consists of many iterations. Each epoch calls iter() on the sampler once, and each iteration calls next() on the resulting iterator once.
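The relationship between epochs and `__iter__` can be checked with a small sketch. `CountingSampler` is a hypothetical helper written for this illustration, and the use of DataLoader with `batch_size=2` is an assumption rather than something taken from the text above.

```python
from torch.utils.data import DataLoader, SequentialSampler


class CountingSampler(SequentialSampler):
    """Hypothetical helper: counts how often __iter__ is called."""

    def __init__(self, data_source):
        super().__init__(data_source)
        self.iter_calls = 0

    def __iter__(self):
        self.iter_calls += 1
        return super().__iter__()


data = [1, 2, 3, 4, 5]
counting_sampler = CountingSampler(data)
loader = DataLoader(data, sampler=counting_sampler, batch_size=2)

batches_per_epoch = []
for epoch in range(3):
    n_batches = 0
    for batch in loader:  # each pass over the loader is one epoch
        n_batches += 1
    batches_per_epoch.append(n_batches)

print(counting_sampler.iter_calls)  # one __iter__ call per epoch
print(batches_per_epoch)            # several iterations per epoch
```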

Test:

    from torch.utils.data import sampler

    # Define the data and the corresponding sampler
    data = [1, 2, 3, 4, 5]
    seq_sampler = sampler.SequentialSampler(data_source=data)

    # Iterate over the indices generated by the sampler
    for index in seq_sampler:
        print("index: {}, data: {}".format(str(index), str(data[index])))

Output:

    index: 0, data: 1
    index: 1, data: 2
    index: 2, data: 3
    index: 3, data: 4
    index: 4, data: 5

## Random Sampler

    class RandomSampler(Sampler):
        r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
        If with replacement, then user can specify :attr:`num_samples` to draw.

        Arguments:
            data_source (Dataset): dataset to sample from
            replacement (bool): samples are drawn with replacement if ``True``, default=``False``
            num_samples (int): number of samples to draw, default=`len(dataset)`. This argument
                is supposed to be specified only when `replacement` is ``True``.
            generator (Generator): Generator used in sampling.
        """

        def __init__(self, data_source, replacement=False, num_samples=None, generator=None):
            self.data_source = data_source
            # This parameter controls whether sampling is done with replacement
            self.replacement = replacement
            self._num_samples = num_samples
            self.generator = generator

            # Type checks
            if not isinstance(self.replacement, bool):
                raise TypeError("replacement should be a boolean value, but got "
                                "replacement={}".format(self.replacement))

            if self._num_samples is not None and not replacement:
                raise ValueError("With replacement=False, num_samples should not be specified, "
                                 "since a random permute will be performed.")

            if not isinstance(self.num_samples, int) or self.num_samples <= 0:
                raise ValueError("num_samples should be a positive integer "
                                 "value, but got num_samples={}".format(self.num_samples))

        @property
        def num_samples(self):
            # dataset size might change at runtime
            # If num_samples was not passed at initialization, use the length of the data source
            if self._num_samples is None:
                return len(self.data_source)
            return self._num_samples

        def __iter__(self):
            n = len(self.data_source)
            if self.replacement:
                rand_tensor = torch.randint(high=n, size=(self.num_samples,),
                                            dtype=torch.int64, generator=self.generator)
                return iter(rand_tensor.tolist())
            return iter(torch.randperm(n, generator=self.generator).tolist())

        # Returns the number of samples drawn (the dataset length by default)
        def __len__(self):
            return self.num_samples

The most important method is __iter__(), which defines how the indices are generated. The if statement chooses between two return values, depending on whether replacement was set at initialization. The key difference is that the sequence generated by randint() may contain repeated values, while the permutation generated by randperm() contains every index exactly once.
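The difference between the two calls can be seen in isolation with a short sketch. The seed value and the size `n = 5` are arbitrary choices for illustration.

```python
import torch

n = 5
gen = torch.Generator().manual_seed(0)

# Without replacement: a permutation of 0..n-1, every index exactly once
perm = torch.randperm(n, generator=gen).tolist()

# With replacement: n independent draws from 0..n-1, duplicates are possible
draws = torch.randint(high=n, size=(n,), dtype=torch.int64, generator=gen).tolist()

print(sorted(perm) == list(range(n)))  # True: no index is missing or repeated
print(len(set(draws)) <= n)            # True: fewer distinct values means repeats
```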

The following two examples show the behavior with replacement set to False and True respectively:

    ran_sampler = sampler.RandomSampler(data_source=data)
    for index in ran_sampler:
        print("index: {}, data: {}".format(str(index), str(data[index])))

Output:

    index: 3, data: 4
    index: 4, data: 5
    index: 2, data: 3
    index: 1, data: 2
    index: 0, data: 1

With replacement:

    ran_sampler = sampler.RandomSampler(data_source=data, replacement=True)
    for index in ran_sampler:
        print("index: {}, data: {}".format(str(index), str(data[index])))

Output:

    index: 1, data: 2
    index: 2, data: 3
    index: 4, data: 5
    index: 3, data: 4
    index: 1, data: 2
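Because the order changes from run to run, it can help to fix it with the generator argument listed in the docstring. The following is a sketch under the assumption that a seeded `torch.Generator` is passed in; the helper `draw` and the seed 42 are made up for illustration.

```python
import torch
from torch.utils.data import RandomSampler

data = [1, 2, 3, 4, 5]


def draw(seed):
    """Hypothetical helper: one pass over a RandomSampler with a seeded generator."""
    gen = torch.Generator().manual_seed(seed)
    return list(RandomSampler(data, generator=gen))


first = draw(42)
second = draw(42)
print(first == second)  # True: the same seed reproduces the same order
```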

## Subset Random Sampler

    class SubsetRandomSampler(Sampler):
        r"""Samples elements randomly from a given list of indices, without replacement.

        Arguments:
            indices (sequence): a sequence of indices
            generator (Generator): Generator used in sampling.
        """

        def __init__(self, indices, generator=None):
            # A subset of the dataset indices, e.g. those of the training set or the test set
            self.indices = indices
            self.generator = generator

        def __iter__(self):
            # Generator expression that yields the given indices in shuffled order, without repeats
            return (self.indices[i] for i in torch.randperm(len(self.indices), generator=self.generator))

        def __len__(self):
            return len(self.indices)

In the code above, __iter__() uses a random permutation to pick entries of indices, so the given indices are returned in shuffled order, and __len__() returns the number of indices. Note that sampling is still without replacement and is again implemented with randperm(). The following example shows how this can be used to split data into training and validation subsets:

    # Note: the values in the slices of data are passed directly as the index lists here
    sub_sampler_train = sampler.SubsetRandomSampler(indices=data[0:2])
    for index in sub_sampler_train:
        print("index: {}".format(str(index)))

    print('------------')

    sub_sampler_val = sampler.SubsetRandomSampler(indices=data[2:])
    for index in sub_sampler_val:
        print("index: {}".format(str(index)))

Output:

    # train:
    index: 2
    index: 1
    # val:
    index: 3
    index: 4
    index: 5
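In practice the indices are usually positions in the dataset rather than the data values themselves. A minimal sketch of such a split follows; the 3/2 split, the seed and the use of DataLoader are assumptions made for illustration.

```python
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler

data = [1, 2, 3, 4, 5]

# Shuffle the positions 0..len(data)-1 once, then cut them into two disjoint parts
gen = torch.Generator().manual_seed(0)
positions = torch.randperm(len(data), generator=gen).tolist()
train_idx, val_idx = positions[:3], positions[3:]

train_loader = DataLoader(data, batch_size=2, sampler=SubsetRandomSampler(train_idx))
val_loader = DataLoader(data, batch_size=2, sampler=SubsetRandomSampler(val_idx))

# Collect the values each loader actually serves
train_values = [int(v) for batch in train_loader for v in batch]
val_values = [int(v) for batch in val_loader for v in batch]

print(len(train_values), len(val_values))             # 3 2
print(set(train_values).isdisjoint(set(val_values)))  # True
```

Each loader only ever sees its own subset, and the order inside the subset is reshuffled on every pass.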

## Weighted Random Sampler

    class WeightedRandomSampler(Sampler):
        r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).

        Args:
            weights (sequence)   : a sequence of weights, not necessary summing up to one
            num_samples (int): number of samples to draw
            replacement (bool): if ``True``, samples are drawn with replacement.
                If not, they are drawn without replacement, which means that when a
                sample index is drawn for a row, it cannot be drawn again for that row.
            generator (Generator): Generator used in sampling.

        Example:
            >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
            [4, 4, 1, 4, 5]
            >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
            [0, 1, 4, 3, 2]
        """

        def __init__(self, weights, num_samples, replacement=True, generator=None):
            # Type checks
            if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or \
                    num_samples <= 0:
                raise ValueError("num_samples should be a positive integer "
                                 "value, but got num_samples={}".format(num_samples))
            if not isinstance(replacement, bool):
                raise ValueError("replacement should be a boolean value, but got "
                                 "replacement={}".format(replacement))
            # weights holds the sampling weight of each index
            self.weights = torch.as_tensor(weights, dtype=torch.double)
            self.num_samples = num_samples
            # Controls whether sampling is done with replacement
            self.replacement = replacement
            self.generator = generator

        def __iter__(self):
            # Draw random indices according to the weights
            rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement,
                                            generator=self.generator)
            return iter(rand_tensor.tolist())

        def __len__(self):
            return self.num_samples

The replacement parameter again controls whether sampling is done with replacement, and num_samples controls how many indices are generated. The weights parameter holds one weight per sample, not one weight per class. The most important method is __iter__(), which returns a random index sequence in which each index is drawn with a probability proportional to its weight.

    # Weighted random sampling
    data = [1, 2, 5, 78, 6, 56]
    # The weight of position [0] is 0.1, the weight of position [1] is 0.2, and so on.
    # Note that weights has seven entries, so the generated indices range from 0 to 6.
    weights = [0.1, 0.2, 0.3, 0.4, 0.8, 0.3, 5]
    rsampler = sampler.WeightedRandomSampler(weights=weights, num_samples=10, replacement=True)

    for index in rsampler:
        print("index: {}".format(str(index)))

Part of the output of one run:

    index: 5
    index: 4
    index: 6
    index: 6
    index: 6
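Since the weights are per sample, a common way to rebalance classes is to give every sample the inverse of its class frequency. The following is a minimal sketch under that assumption; the `labels` tensor, the 90/10 class ratio and the seed are made up for illustration.

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Hypothetical labels: class 0 is nine times as frequent as class 1
labels = torch.tensor([0] * 90 + [1] * 10)

# One weight per sample: the inverse of the frequency of that sample's class
class_counts = torch.bincount(labels)
sample_weights = 1.0 / class_counts[labels].double()

gen = torch.Generator().manual_seed(0)
balanced_sampler = WeightedRandomSampler(
    weights=sample_weights, num_samples=1000, replacement=True, generator=gen
)

# Look up the class of every drawn index
drawn = labels[list(balanced_sampler)]
minority_share = (drawn == 1).double().mean().item()
print(minority_share)  # close to 0.5 rather than the original 0.1
```

Both classes end up with the same total weight, so the minority class is drawn about as often as the majority class.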

## Batch Sampler

    class BatchSampler(Sampler):
        r"""Wraps another sampler to yield a mini-batch of indices.

        Args:
            sampler (Sampler or Iterable): Base sampler. Can be any iterable object
                with ``__len__`` implemented.
            batch_size (int): Size of mini-batch.
            drop_last (bool): If ``True``, the sampler will drop the last batch if
                its size would be less than ``batch_size``

        Example:
            >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
            [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
            >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
            [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
        """

        def __init__(self, sampler, batch_size, drop_last):
            # Since collections.abc.Iterable does not check for `__getitem__`, which
            # is one way for an object to be an iterable, we don't do an `isinstance`
            # check here.
            # Type checks
            if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
                    batch_size <= 0:
                raise ValueError("batch_size should be a positive integer value, "
                                 "but got batch_size={}".format(batch_size))
            if not isinstance(drop_last, bool):
                raise ValueError("drop_last should be a boolean value, but got "
                                 "drop_last={}".format(drop_last))
            # The sampler that supplies the individual indices
            self.sampler = sampler
            self.batch_size = batch_size
            # Whether to drop the last batch when it has fewer than batch_size indices
            self.drop_last = drop_last

        def __iter__(self):
            batch = []
            for idx in self.sampler:
                batch.append(idx)
                # Once the batch holds batch_size indices, yield it and start a new one
                if len(batch) == self.batch_size:
                    yield batch
                    batch = []
            # After the loop, yield the remaining incomplete batch unless drop_last is set
            if len(batch) > 0 and not self.drop_last:
                yield batch

        def __len__(self):
            # With drop_last only full batches are counted; otherwise the count is rounded up
            if self.drop_last:
                return len(self.sampler) // self.batch_size
            else:
                return (len(self.sampler) + self.batch_size - 1) // self.batch_size

The samplers above yield one index at a time; BatchSampler wraps any of them and groups the indices into mini-batches. When drop_last is True, the last batch is discarded if it contains fewer than batch_size indices. In the following example, the sampler wrapped by BatchSampler is a SequentialSampler.

    seq_sampler = sampler.SequentialSampler(data_source=data)
    batch_sampler = sampler.BatchSampler(seq_sampler, 4, False)
    print(list(batch_sampler))

Output:

    [[0, 1, 2, 3], [4, 5]]
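To show the effect of drop_last and how a batch sampler is consumed, here is a short sketch that compares both settings and then hands the batch sampler to a DataLoader through its `batch_sampler` argument. The variable names are chosen for illustration.

```python
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

data = [1, 2, 5, 78, 6, 56]
seq_sampler = SequentialSampler(data)

keep_last = BatchSampler(seq_sampler, batch_size=4, drop_last=False)
drop_last = BatchSampler(seq_sampler, batch_size=4, drop_last=True)

print(list(keep_last), len(keep_last))  # [[0, 1, 2, 3], [4, 5]] 2
print(list(drop_last), len(drop_last))  # [[0, 1, 2, 3]] 1

# batch_sampler cannot be combined with batch_size, shuffle, sampler or drop_last
loader = DataLoader(data, batch_sampler=keep_last)
batches = [batch.tolist() for batch in loader]
print(batches)  # [[1, 2, 5, 78], [6, 56]]
```

The batch sampler yields lists of indices, and the DataLoader turns each list into one batch of the corresponding data values.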