I tried to explain the PyTorch Dataset


Some people in my lab use PyTorch, so I wrote this explanation for them.

Main topic: "Why a Dataset is needed"

Simple code and description


import torch
from sklearn.datasets import load_iris

class Dataset(torch.utils.data.Dataset):
    def __init__(self, transform=None):
        self.iris = load_iris()  # Load the iris dataset into memory
        self.data = self.iris['data']  # Feature array, one row per sample
        self.label = self.iris['target']  # Class label for each sample
        self.datanum = len(self.label)  # Total number of samples
        self.transform = transform  # Optional function applied to each sample

    def __len__(self):
        return self.datanum

    def __getitem__(self, index):
        data = self.data[index]
        label = self.label[index]

        if self.transform:
            data = self.transform(data)

        return data, label

if __name__ == "__main__":
    batch_size = 20
    dataset = Dataset()
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    for data in dataloader:
        print("Number of data: " + str(len(data[0])))
        print("data: {}".format(data[0]))
        print("Number of labels: " + str(len(data[1])))
        print("label: {}".format(data[1]) + "\n")

That's the whole code. I've really kept it to a minimum for clarity.

__init__ runs when an instance of the class is created. The data is very small this time, so I load everything in __init__. If you want to iterate over very large data from storage, keep only the file paths and similar information here, and load each item on demand in __getitem__.
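The large-data case described above could look roughly like the following. This is only a minimal sketch: the class name LazyFileDataset, the helper make_demo_files, and the .pt file layout are assumptions made for illustration, and the files are tiny tensors written to a temporary directory so that the example can run on its own.

```python
import os
import tempfile

import torch


class LazyFileDataset(torch.utils.data.Dataset):
    """Hypothetical dataset that keeps only file paths in memory."""

    def __init__(self, paths, labels):
        # Only lightweight metadata is stored here, not the data itself.
        self.paths = list(paths)
        self.labels = list(labels)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        # The file is read from storage only when this index is requested.
        data = torch.load(self.paths[index])
        return data, self.labels[index]


def make_demo_files(directory, num_files=6):
    """Write small tensors to disk so the sketch can run end to end (demo only)."""
    paths = []
    for i in range(num_files):
        path = os.path.join(directory, f"sample_{i}.pt")
        torch.save(torch.full((4,), float(i)), path)
        paths.append(path)
    return paths


with tempfile.TemporaryDirectory() as tmp_dir:
    demo_paths = make_demo_files(tmp_dir)
    demo_labels = [i % 3 for i in range(len(demo_paths))]
    lazy_dataset = LazyFileDataset(demo_paths, demo_labels)
    lazy_loader = torch.utils.data.DataLoader(lazy_dataset, batch_size=4, shuffle=False)
    # Each batch is loaded from disk only while the loop is running.
    lazy_batch_shapes = [tuple(data.shape) for data, _ in lazy_loader]
    print(lazy_batch_shapes)
```

With this layout, memory use depends on the batch size rather than on the total size of the data on disk.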

__len__ returns the total number of samples.

__getitem__ returns the sample specified by index. When you iterate with a DataLoader, it calls __getitem__ once per index and collects as many samples as the batch size (described later) into one batch.
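To make the two methods concrete, here is a small sketch that calls them directly and then builds one batch by hand. TinyDataset and to_float_tensor are hypothetical names used only for this illustration, and the manual stacking is a simplified picture of what a DataLoader does, not its actual implementation.

```python
import torch


class TinyDataset(torch.utils.data.Dataset):
    """Hypothetical four-sample dataset, used only to show the protocol."""

    def __init__(self, transform=None):
        self.data = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
        self.label = [0, 1, 0, 1]
        self.transform = transform

    def __len__(self):
        return len(self.label)

    def __getitem__(self, index):
        data = self.data[index]
        if self.transform:
            data = self.transform(data)
        return data, self.label[index]


def to_float_tensor(sample):
    # Example transform: convert a Python list into a float32 tensor.
    return torch.tensor(sample, dtype=torch.float32)


tiny = TinyDataset(transform=to_float_tensor)

print(len(tiny))  # len() calls TinyDataset.__len__
print(tiny[2])    # indexing calls TinyDataset.__getitem__(2)

# Simplified view of one batch: call __getitem__ for each chosen index,
# then stack the results into tensors.
indices = [0, 3]
samples = [tiny[i] for i in indices]
batch_data = torch.stack([data for data, _ in samples])
batch_label = torch.tensor([label for _, label in samples])
print(batch_data.shape, batch_label)
```

The transform is applied inside __getitem__, so it runs once per sample every time that sample is requested.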

About DataLoader

Pass the dataset, which is an instance of the Dataset class, as the first argument. After that, pass batch_size (an integer) and shuffle (True / False). Set shuffle to True unless you have a specific reason not to.

There are other ways to iterate over the DataLoader besides for, but here we use for. The output of running the above program is shown below.
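The effect of batch_size and shuffle can be checked with a small sketch like the one below. It assumes a toy TensorDataset of ten samples whose labels are 0 to 9, chosen only so that the order of the samples is easy to read; the variable names are illustrative.

```python
import torch

# Ten samples with one feature each, labelled 0..9, so the order is easy to see.
toy_features = torch.arange(10, dtype=torch.float32).unsqueeze(1)
toy_labels = torch.arange(10)
toy_dataset = torch.utils.data.TensorDataset(toy_features, toy_labels)

# shuffle=False: samples come out in index order.
ordered_loader = torch.utils.data.DataLoader(toy_dataset, batch_size=4, shuffle=False)
ordered_labels = [batch_labels.tolist() for _, batch_labels in ordered_loader]
print(ordered_labels)

# shuffle=True: the same samples, but in a new random order on every pass.
shuffled_loader = torch.utils.data.DataLoader(toy_dataset, batch_size=4, shuffle=True)
shuffled_labels = [batch_labels.tolist() for _, batch_labels in shuffled_loader]
print(shuffled_labels)
```

In both cases every sample appears exactly once per pass; only the order differs.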
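As a rough illustration of the alternatives to a plain for loop, the sketch below takes a single batch with iter and next, and then uses enumerate to get the batch index as well. The dataset is a placeholder TensorDataset assumed only for this example.

```python
import torch

# Placeholder data: six samples with four features each, labels 0..5.
demo_dataset = torch.utils.data.TensorDataset(torch.zeros(6, 4), torch.arange(6))
demo_loader = torch.utils.data.DataLoader(demo_dataset, batch_size=4, shuffle=False)

# Pull a single batch without writing a loop (handy for a quick shape check).
first_data, first_label = next(iter(demo_loader))
print(first_data.shape, first_label)

# enumerate gives the batch index along with the batch itself.
batch_sizes = {}
for batch_index, (data, label) in enumerate(demo_loader):
    batch_sizes[batch_index] = len(label)
print(batch_sizes)
```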

Number of data: 20
data: tensor([[6.8000, 3.0000, 5.5000, 2.1000],
        [6.7000, 3.1000, 5.6000, 2.4000],
        [5.4000, 3.9000, 1.3000, 0.4000],
        [5.5000, 2.4000, 3.7000, 1.0000],
        [5.1000, 3.7000, 1.5000, 0.4000],
        [4.5000, 2.3000, 1.3000, 0.3000],
        [6.6000, 2.9000, 4.6000, 1.3000],
        [6.5000, 3.0000, 5.8000, 2.2000],
        [7.0000, 3.2000, 4.7000, 1.4000],
        [4.4000, 3.2000, 1.3000, 0.2000],
        [5.0000, 3.4000, 1.5000, 0.2000],
        [5.4000, 3.4000, 1.5000, 0.4000],
        [4.9000, 2.4000, 3.3000, 1.0000],
        [6.3000, 3.4000, 5.6000, 2.4000],
        [7.7000, 2.6000, 6.9000, 2.3000],
        [6.2000, 2.8000, 4.8000, 1.8000],
        [6.2000, 3.4000, 5.4000, 2.3000],
        [5.6000, 2.7000, 4.2000, 1.3000],
        [6.1000, 3.0000, 4.9000, 1.8000],
        [6.7000, 3.0000, 5.0000, 1.7000]], dtype=torch.float64)
Number of labels: 20
label: tensor([2, 2, 0, 1, 0, 0, 1, 2, 1, 0, 0, 0, 1, 2, 2, 2, 2, 1, 2, 1],

~~~~~ (middle of the output omitted) ~~~~~

Number of data: 10
data: tensor([[4.8000, 3.4000, 1.6000, 0.2000],
        [6.1000, 2.8000, 4.7000, 1.2000],
        [5.1000, 3.8000, 1.9000, 0.4000],
        [6.7000, 3.3000, 5.7000, 2.1000],
        [6.4000, 2.9000, 4.3000, 1.3000],
        [7.4000, 2.8000, 6.1000, 1.9000],
        [6.4000, 3.2000, 5.3000, 2.3000],
        [5.0000, 3.3000, 1.4000, 0.2000],
        [5.0000, 3.2000, 1.2000, 0.2000],
        [5.8000, 2.7000, 4.1000, 1.0000]], dtype=torch.float64)
Number of labels: 10
label: tensor([0, 1, 0, 2, 1, 2, 2, 0, 0, 1], dtype=torch.int32)

Both the data and the labels are output correctly in batches of 20, the batch size we set. 150 is not evenly divisible by 20, but the DataLoader simply outputs the remaining 10 samples as the last batch without any error. It is convenient that this is handled without any special processing.
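The batch arithmetic can be checked with a short sketch. It uses placeholder zero tensors in place of the iris data, and it also shows the drop_last option of DataLoader, which is not used in the program above but discards the incomplete final batch when that is preferred.

```python
import math

import torch

num_samples, batch_size = 150, 20

print(math.ceil(num_samples / batch_size))  # number of batches per pass
print(num_samples % batch_size)             # size of the final, smaller batch

# Placeholder data with the same number of samples as the iris dataset.
placeholder = torch.utils.data.TensorDataset(torch.zeros(num_samples, 4))

keep_last_loader = torch.utils.data.DataLoader(placeholder, batch_size=batch_size)
drop_last_loader = torch.utils.data.DataLoader(
    placeholder, batch_size=batch_size, drop_last=True
)

keep_sizes = [len(batch[0]) for batch in keep_last_loader]
drop_sizes = [len(batch[0]) for batch in drop_last_loader]
print(keep_sizes)  # seven batches of 20, then one of 10
print(drop_sizes)  # seven batches of 20 only
```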


This is a simple example, so please comment if you have any questions. If I have made a mistake, please point it out.
