Introduction to PyTorch | DATASETS & DATALOADERS

Code for processing data samples can get messy and hard to maintain; ideally, we want our dataset code to be decoupled from our model training code for better readability and modularity. PyTorch provides two data primitives, torch.utils.data.DataLoader and torch.utils.data.Dataset, that let you use preloaded datasets as well as your own data. Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset to give easy access to the samples.

The PyTorch domain libraries (such as TorchVision) provide a number of preloaded datasets, for example FashionMNIST. Each one subclasses torch.utils.data.Dataset and implements functions specific to its data. They can be used to prototype and benchmark your model. You can find them here: Image Datasets, Text Datasets, and Audio Datasets.
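
As a minimal sketch of how the two primitives fit together, the snippet below wraps a few random tensors in a TensorDataset (a ready-made Dataset from torch.utils.data) and batches them with a DataLoader. The tensor shapes, the sample count of 10, and the batch size of 4 are assumptions chosen for illustration, not values from this tutorial.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in data: 10 fake 28x28 grayscale images with labels 0-9.
images = torch.rand(10, 1, 28, 28)
labels = torch.arange(10)

# The Dataset stores the samples and their labels and can be indexed.
toy_dataset = TensorDataset(images, labels)
first_image, first_label = toy_dataset[0]

# The DataLoader wraps the Dataset in an iterable that yields batches.
toy_loader = DataLoader(toy_dataset, batch_size=4)
batch_images, batch_labels = next(iter(toy_loader))

print(len(toy_dataset))      # number of samples
print(batch_images.shape)    # shape of one batch of images
print(batch_labels)          # labels of that batch
```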

Load dataset

Here is an example of how to load the Fashion-MNIST dataset from TorchVision. Fashion-MNIST is a dataset of Zalando's article images, consisting of 60,000 training samples and 10,000 test samples. Each sample contains a 28x28 grayscale image and an associated label from one of 10 classes.

We load the FashionMNIST dataset with the following parameters:

  • root is the path where the training / test data is stored
  • train specifies whether to load the training set or the test set
  • download=True downloads the data from the Internet if it is not available at root
  • transform and target_transform specify the transformations applied to the samples and to the labels
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
    root='data',
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root='data',
    train=False,
    download=True,
    transform=ToTensor()
)

Output:

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
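
The transform and target_transform arguments accept any callable. As a hedged sketch, a target_transform that turns an integer label into a one-hot vector could look like the function below. The name one_hot_label and the choice of one-hot encoding are assumptions for illustration and are not part of the loading code above.

```python
import torch
import torch.nn.functional as F

NUM_CLASSES = 10  # Fashion-MNIST has 10 classes

def one_hot_label(label: int) -> torch.Tensor:
    """Hypothetical target_transform: integer class index -> one-hot float vector."""
    return F.one_hot(torch.tensor(label), num_classes=NUM_CLASSES).float()

encoded = one_hot_label(3)
print(encoded)
```

It could then be passed as target_transform=one_hot_label when constructing the dataset, so that indexing a sample returns the encoded vector instead of the integer label.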

Iteration and dataset visualization

We can index a Dataset manually like a list: training_data[index]. We use matplotlib to visualize some samples from the training set.

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    # squeeze() removes the size-1 channel dimension: (1, 28, 28) -> (28, 28)
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()
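
The squeeze() call matters because each image tensor carries a leading channel dimension of size 1, and plt.imshow expects a plain 2-D array for grayscale data. A small sketch with a random stand-in tensor (an assumption, so no dataset download is needed) shows the shape change:

```python
import torch

# Stand-in for one Fashion-MNIST sample after ToTensor(): (channels, height, width).
img = torch.rand(1, 28, 28)

squeezed = img.squeeze()     # drops every dimension of size 1
print(img.shape)             # torch.Size([1, 28, 28])
print(squeezed.shape)        # torch.Size([28, 28])

# squeeze(0) removes only the channel axis, which is safer if height or width could be 1.
channel_removed = img.squeeze(0)
```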

Create a custom dataset

A custom Dataset class must implement three functions: __init__, __len__, and __getitem__. Take a look at the implementation below. The FashionMNIST images are stored in the directory img_dir, and their labels are stored separately in a CSV (comma-separated values) file, annotations_file.

In the next sections, we will break down what each of these functions does.

import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        # Read the annotations CSV into a DataFrame; the file has no header row,
        # so the column names are given explicitly
        self.img_labels = pd.read_csv(annotations_file, names=['file_name', 'labels'])
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        # Column 0 holds the image file name, column 1 the label
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

__init__

The __init__ function runs once, when the Dataset object is instantiated. It initializes the directory containing the images, the annotations file, and both transforms (described in more detail in the next section).

The labels.csv file looks like this:

tshirt1.jpg, 0
tshirt2.jpg, 0
...
ankleboot999.jpg, 9

def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
    # The column names are specified here because the CSV has no header row
    self.img_labels = pd.read_csv(annotations_file, names=['file_name', 'labels'])
    self.img_dir = img_dir
    self.transform = transform
    self.target_transform = target_transform
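
Passing names matters because the sample labels.csv has no header row. As a small sketch using an in-memory CSV (the three rows are made-up stand-ins modelled on the sample above), compare reading with and without names:

```python
import io
import pandas as pd

# Hypothetical, header-less annotations in the same layout as labels.csv.
csv_text = "tshirt1.jpg,0\ntshirt2.jpg,0\nankleboot999.jpg,9\n"

# Without names, pandas treats the first row as the header and drops it from the data.
without_names = pd.read_csv(io.StringIO(csv_text))

# With names, every row is kept as a sample.
with_names = pd.read_csv(io.StringIO(csv_text), names=["file_name", "labels"])

print(len(without_names))       # 2
print(len(with_names))          # 3
print(with_names.iloc[2, 0])    # ankleboot999.jpg
print(with_names.iloc[2, 1])    # 9
```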

__len__

The __len__ function returns the number of samples in the dataset.

For example:

def __len__(self):
    return len(self.img_labels)

__getitem__

The __getitem__ function loads and returns the sample at the given index idx. Based on the index, it identifies the image's location on disk, converts the image to a tensor using read_image, retrieves the corresponding label from the CSV data in self.img_labels, calls the transform functions on them (if provided), and returns the image tensor and the corresponding label as a tuple.

def __getitem__(self, idx):
    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
    image = read_image(img_path)
    label = self.img_labels.iloc[idx, 1]
    if self.transform:
        image = self.transform(image)
    if self.target_transform:
        label = self.target_transform(label)
    return image, label
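
To exercise the three methods without image files on disk, here is a hedged, in-memory analogue of CustomImageDataset. The class name InMemoryImageDataset, the random tensors, and the two lambda transforms are all assumptions made for illustration; the real class above reads images from disk with read_image instead.

```python
import torch
from torch.utils.data import Dataset

class InMemoryImageDataset(Dataset):
    """Hypothetical analogue of CustomImageDataset that keeps images in a tensor."""

    def __init__(self, images, labels, transform=None, target_transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

# Five fake single-channel 28x28 uint8 images (read_image also returns uint8 tensors).
fake_images = torch.randint(0, 256, (5, 1, 28, 28), dtype=torch.uint8)
fake_labels = [0, 0, 3, 7, 9]

dataset = InMemoryImageDataset(
    fake_images,
    fake_labels,
    transform=lambda img: img.float() / 255.0,   # scale pixels to [0, 1]
    target_transform=lambda y: torch.tensor(y),  # int -> tensor
)

image, label = dataset[4]
print(len(dataset), image.dtype, label)
```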

Preparing your data for training with DataLoaders

The Dataset retrieves the features and label of one sample at a time. When training a model, we typically want to pass samples in "minibatches", reshuffle the data at every epoch to reduce overfitting, and use Python's multiprocessing to speed up data retrieval.

DataLoader is an iterable that handles all of this behind a simple API.

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
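
As a sketch of what shuffle=True does, the loop below runs two epochs over a toy dataset of ten integers (an assumption for illustration). Every epoch visits each sample exactly once, but the order is drawn again each time the DataLoader is iterated.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)  # only to make the run repeatable

toy_dataset = TensorDataset(torch.arange(10))
# num_workers=0 (the default) loads in the main process; a value > 0 would use worker processes.
loader = DataLoader(toy_dataset, batch_size=4, shuffle=True, num_workers=0)

epoch_orders = []
for epoch in range(2):
    order = []
    for (batch,) in loader:      # a single-tensor TensorDataset yields 1-element batches
        order.extend(batch.tolist())
    epoch_orders.append(order)
    print(f"epoch {epoch}: {order}")
```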

Iterate through the DataLoader

We have loaded the dataset into the DataLoader and can iterate through it as needed. Each iteration returns a batch of train_features and train_labels (containing batch_size=64 features and labels respectively). Because we specified shuffle=True, the data is reshuffled after we have iterated over all batches (for finer-grained control over the data loading order, see Samplers).

# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

Output:

Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: 7
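
Calling next(iter(...)) fetches a single batch. A full pass over the data is usually written as a for loop. The sketch below uses a toy dataset of 10 random samples (an assumption, standing in for training_data) and shows that the final batch may be smaller than batch_size unless drop_last=True is set.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for training_data: 10 samples, batch size 4.
features = torch.rand(10, 1, 28, 28)
labels = torch.randint(0, 10, (10,))
toy_dataset = TensorDataset(features, labels)

loader = DataLoader(toy_dataset, batch_size=4, shuffle=True)
batch_sizes = [len(y) for X, y in loader]
print(batch_sizes)        # the last batch holds the 2 leftover samples

# drop_last=True discards the incomplete final batch.
full_only = DataLoader(toy_dataset, batch_size=4, shuffle=True, drop_last=True)
full_batch_sizes = [len(y) for X, y in full_only]
print(full_batch_sizes)
```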

Added by bhavin12300 on Sat, 29 Jan 2022 05:27:48 +0200