Several dataset sampling methods

PyTorch Sampler

When training a neural network, the dataset is usually too large to feed into the network all at once, so the data must be read in batches. This raises the question of how indices are drawn from the dataset. The PyTorch framework provides the Sampler base class together with several subclasses that implement different sampling strategies.
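
Before looking at each class, it helps to see where a sampler sits in the pipeline: a sampler yields indices, and DataLoader fetches the corresponding samples and groups them into batches. A minimal sketch (the toy TensorDataset is an assumption here, used only for illustration):

import torch
from torch.utils.data import DataLoader, TensorDataset, SequentialSampler

dataset = TensorDataset(torch.arange(10))  # toy dataset holding the values 0..9
loader = DataLoader(dataset, batch_size=4, sampler=SequentialSampler(dataset))
for batch in loader:
    print(batch)  # batches of 4 samples each; the final batch holds the remaining 2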

Base class Sampler

class Sampler(object):
    r"""Base class for all Samplers.

    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices of dataset elements, and a :meth:`__len__` method
    that returns the length of the returned iterators.

    .. note:: The :meth:`__len__` method isn't strictly required by
              :class:`~torch.utils.data.DataLoader`, but is expected in any
              calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
    """

    def __init__(self, data_source):
        pass

    def __iter__(self):
        raise NotImplementedError
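
A subclass therefore only has to honor this contract: __iter__() yields dataset indices and __len__() reports how many it will yield. As an illustration, here is a hypothetical ReverseSampler (not part of PyTorch) that traverses the dataset back to front:

from torch.utils.data import Sampler

class ReverseSampler(Sampler):
    """Hypothetical sampler: yields indices from last to first."""
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source) - 1, -1, -1))

    def __len__(self):
        return len(self.data_source)

print(list(ReverseSampler([10, 20, 30])))  # [2, 1, 0]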

Sequential Sampler

class SequentialSampler(Sampler):
    r"""Samples elements sequentially, always in the same order.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)

The sequential sampling class defines very little: its initializer only takes a Dataset as a parameter.
len() is only responsible for returning the number of items in the data source, while iter() returns an iterable object, namely the sequential index sequence produced by range(), so iteration proceeds in order.
Each epoch contains many iterations: each epoch calls iter() once, and each iteration calls next() on the resulting iterator once.

# Test: define the data and the corresponding sampler
from torch.utils.data import sampler

data = [1, 2, 3, 4, 5]
seq_sampler = sampler.SequentialSampler(data_source=data)
# Iterate over the indices produced by the sampler
for index in seq_sampler:
    print("index: {}, data: {}".format(str(index), str(data[index])))
# Output:
index: 0, data: 1
index: 1, data: 2
index: 2, data: 3
index: 3, data: 4
index: 4, data: 5
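
The epoch/iteration relationship described above can be made explicit by driving the iterator by hand; this sketch continues the example:

it = iter(seq_sampler)  # start of an "epoch": iter() is called once
print(next(it))         # 0 -- first iteration
print(next(it))         # 1 -- second iteration
it = iter(seq_sampler)  # a new epoch gets a fresh iterator
print(next(it))         # 0 -- the sequence starts over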

Random Sampler

class RandomSampler(Sampler):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
    If with replacement, then user can specify :attr:`num_samples` to draw.

    Arguments:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`. This argument
            is supposed to be specified only when `replacement` is ``True``.
        generator (Generator): Generator used in sampling.
    """

    def __init__(self, data_source, replacement=False, num_samples=None, generator=None):
        self.data_source = data_source
        # This parameter controls whether to repeat sampling
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator

        # Type check
        if not isinstance(self.replacement, bool):
            raise TypeError("replacement should be a boolean value, but got "
                            "replacement={}".format(self.replacement))

        if self._num_samples is not None and not replacement:
            raise ValueError("With replacement=False, num_samples should not be specified, "
                             "since a random permute will be performed.")

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(self.num_samples))

    @property
    def num_samples(self):
        # dataset size might change at runtime
        # if num_samples was not passed at initialization, use the length of the data source
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self):
        n = len(self.data_source)
        if self.replacement:
            rand_tensor = torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64, generator=self.generator)
            return iter(rand_tensor.tolist())
        return iter(torch.randperm(n, generator=self.generator).tolist())
    # Returns the number of samples to draw
    def __len__(self):
        return self.num_samples

The most important part is the iter() method, which defines the core index-generation behavior. The if branch chooses between two kinds of random sequence, depending on whether replacement was set at initialization. The core difference is that the sequence generated by randint() may contain repeated values, while the sequence generated by randperm() is a permutation and contains no repeats.
The following two examples test replacement=False and replacement=True respectively:

ran_sampler = sampler.RandomSampler(data_source=data)
for index in ran_sampler:
    print("index: {}, data: {}".format(str(index), str(data[index])))

index: 3, data: 4
index: 4, data: 5
index: 2, data: 3
index: 1, data: 2
index: 0, data: 1

ran_sampler = sampler.RandomSampler(data_source=data, replacement=True)
for index in ran_sampler:
    print("index: {}, data: {}".format(str(index), str(data[index])))

index: 1, data: 2
index: 2, data: 3
index: 4, data: 5
index: 3, data: 4
index: 1, data: 2
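
The contrast between the two branches can be reproduced with the underlying functions directly; a minimal sketch (the seed is arbitrary, chosen only for reproducibility):

import torch

g = torch.Generator().manual_seed(0)
print(torch.randint(high=5, size=(5,), generator=g).tolist())  # may contain repeats
print(torch.randperm(5, generator=g).tolist())                 # each of 0..4 exactly once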

Subset Random Sampler

class SubsetRandomSampler(Sampler):
    r"""Samples elements randomly from a given list of indices, without replacement.

    Arguments:
        indices (sequence): a sequence of indices
        generator (Generator): Generator used in sampling.
    """

    def __init__(self, indices, generator=None):
        # indices is typically a slice of the dataset's index range, e.g. a training or test split
        self.indices = indices
        self.generator = generator

    def __iter__(self):
        # Shuffle the given indices with randperm and yield them from a generator expression (no repeats)
        return (self.indices[i] for i in torch.randperm(len(self.indices), generator=self.generator))

    def __len__(self):
        return len(self.indices)

Here len() simply returns the number of given indices, while iter() uses the random permutation produced by randperm() to index into indices, so sampling is again without replacement. The following example uses it to split data into training and validation subsets:

sub_sampler_train = sampler.SubsetRandomSampler(indices=data[0:2])
for index in sub_sampler_train:
    print("index: {}".format(str(index)))
print('------------')
sub_sampler_val = sampler.SubsetRandomSampler(indices=data[2:])
for index in sub_sampler_val:
    print("index: {}".format(str(index)))
    
# train: 
index: 2
index: 1
# val: 
index: 3
index: 4
index: 5
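
In practice, the index lists are usually handed to two DataLoaders to build the training and validation pipelines. A sketch of that pattern (the toy dataset and the 80/20 split are assumptions for illustration):

import torch
from torch.utils.data import DataLoader, TensorDataset, SubsetRandomSampler

dataset = TensorDataset(torch.arange(10))
indices = torch.randperm(len(dataset)).tolist()  # shuffle all indices once
train_idx, val_idx = indices[:8], indices[8:]    # 80/20 split

train_loader = DataLoader(dataset, batch_size=4, sampler=SubsetRandomSampler(train_idx))
val_loader = DataLoader(dataset, batch_size=4, sampler=SubsetRandomSampler(val_idx))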

Weighted Random Sampler

class WeightedRandomSampler(Sampler):
    r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).

    Args:
        weights (sequence)   : a sequence of weights, not necessary summing up to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that when a
            sample index is drawn for a row, it cannot be drawn again for that row.
        generator (Generator): Generator used in sampling.

    Example:
        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
        [4, 4, 1, 4, 5]
        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
        [0, 1, 4, 3, 2]
    """

    def __init__(self, weights, num_samples, replacement=True, generator=None):
        # Type check
        if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or \
                num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(num_samples))
        if not isinstance(replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(replacement))
        # weights determines the probability with which each index is drawn
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.num_samples = num_samples 
        # Controls whether sampling is done with replacement
        self.replacement = replacement
        self.generator = generator

    def __iter__(self):
        # Draw random indices according to the weights
        rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator)
        return iter(rand_tensor.tolist())

    def __len__(self):
        return self.num_samples

The replacement parameter again controls whether sampling is done with replacement, and num_samples controls how many indices are generated. Note that weights gives the weight of each individual sample, not of each class. The most important part is again the iter() method, which returns a random index sequence whose distribution is determined by weights.

# Weighted random sampling
data = [1, 2, 5, 78, 6, 56]
# Position [0] has weight 0.1, position [1] has weight 0.2, and so on.
# Note: weights has 7 entries while data has only 6, so index 6 can also be
# drawn (as in the output below) -- the sampler samples from [0, len(weights) - 1].
weights = [0.1, 0.2, 0.3, 0.4, 0.8, 0.3, 5]
rsampler = sampler.WeightedRandomSampler(weights=weights, num_samples=10, replacement=True)

for index in rsampler:
    print("index: {}".format(str(index)))

index: 5
index: 4
index: 6
index: 6
index: 6
... (remaining indices omitted; num_samples=10 indices are drawn in total)
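
Because weights belong to samples rather than classes, a common use is balancing an imbalanced classification set by assigning every sample the inverse frequency of its class. A sketch under that assumption (the labels are made up for illustration):

import torch
from torch.utils.data import WeightedRandomSampler

labels = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 1])  # class 1 is rare
class_counts = torch.bincount(labels)                   # tensor([8, 2])
class_weights = 1.0 / class_counts.float()              # rarer class gets a larger weight
sample_weights = class_weights[labels]                  # one weight per sample

ws = WeightedRandomSampler(sample_weights, num_samples=10, replacement=True)
print(list(ws))  # indices 8 and 9 (class 1) now appear about as often as the class-0 indices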

Batch Sampler

class BatchSampler(Sampler):
    r"""Wraps another sampler to yield a mini-batch of indices.

    Args:
        sampler (Sampler or Iterable): Base sampler. Can be any iterable object
            with ``__len__`` implemented.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler, batch_size, drop_last):
        # Since collections.abc.Iterable does not check for `__getitem__`, which
        # is one way for an object to be an iterable, we don't do an `isinstance`
        # check here.
        # Type check
        if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        # Define which sampler to use
        self.sampler = sampler
        self.batch_size = batch_size
        # If True, drop the last batch when it holds fewer than batch_size samples
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            # Once the batch holds batch_size indices, yield it and start a new one
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        # After the loop, yield any remaining incomplete batch unless drop_last is True
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        # Number of batches: rounded down when drop_last, rounded up otherwise
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size

After the various samplers are defined, they still need to be grouped into batches. When drop_last is True, a final batch with fewer than batch_size samples is discarded. In the following example, the sampler wrapped by BatchSampler is a sequential sampler.

seq_sampler = sampler.SequentialSampler(data_source=data)
batch_sampler = sampler.BatchSampler(seq_sampler, 4, False)
print(list(batch_sampler))

[[0, 1, 2, 3], [4, 5]]
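
A BatchSampler is passed to DataLoader through the batch_sampler argument; batch_size, shuffle, sampler and drop_last must then be left at their defaults, since the batch sampler already decides all of that. A minimal sketch:

import torch
from torch.utils.data import DataLoader, TensorDataset, SequentialSampler, BatchSampler

dataset = TensorDataset(torch.tensor([1, 2, 5, 78, 6, 56]))
batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)
loader = DataLoader(dataset, batch_sampler=batch_sampler)
for batch in loader:
    print(batch)  # [tensor([ 1,  2,  5, 78])], then [tensor([ 6, 56])]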
