PyTorch MNIST: Load MNIST Dataset from PyTorch Torchvision

PyTorch MNIST - Load the MNIST dataset from PyTorch Torchvision and split it into a train data set and a test data set

Video Transcript


This video will show how to import the MNIST dataset from the PyTorch torchvision datasets package.

The MNIST dataset consists of 70,000 images of handwritten digits and their respective labels.

There are 60,000 training images and 10,000 test images, all of which are 28 pixels by 28 pixels.
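
As a quick sanity check on these numbers, the arithmetic can be sketched in a few lines of plain Python. The constant names below are only illustrative; they are not part of PyTorch or torchvision.

```python
# Dataset sizes and image dimensions as stated in the transcript.
NUM_TRAIN = 60_000
NUM_TEST = 10_000
HEIGHT = WIDTH = 28

total_images = NUM_TRAIN + NUM_TEST        # all images in the dataset
pixels_per_image = HEIGHT * WIDTH          # pixels in one 28 x 28 image
train_fraction = NUM_TRAIN / total_images  # share of images used for training

print(total_images, pixels_per_image, round(train_fraction, 3))
# prints: 70000 784 0.857
```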


First, we import PyTorch.

import torch


Then we print the PyTorch version we are using.

print(torch.__version__)

We are using PyTorch 0.3.1.post2.


Now that we have PyTorch available, let’s load torchvision.

import torchvision

Torchvision is a companion package to PyTorch that contains computer-vision models, datasets, and image transformations.


Since we want to get the MNIST dataset from the torchvision package, let’s next import the torchvision datasets.

import torchvision.datasets as datasets


First, let’s initialize the MNIST training set.

mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=None)

We use the root parameter to define where to save the data.

The train parameter is set to True because we are initializing the MNIST training dataset.

The download parameter is set to True because we want to download the data if it’s not already present in our data folder.

The transform parameter is set to None because we don’t want to apply any image manipulation transforms at this time.


If we switch to the folder view, we can see the data folder that was created.

Inside it are a Raw folder and a Processed folder.


Inside of the Raw folder, we see the four files that were downloaded.


Inside of the Processed folder, we see the two files that were generated after the processing.

Note that because we set the transform parameter to None, the images we get back are exactly what comes out of the raw data.
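
To check the folder contents without switching to the folder view, a small helper can walk the data directory. This is only a minimal sketch: list_data_files is a hypothetical name, not part of torchvision, and the exact folder layout may differ between torchvision versions, which is why the helper searches recursively.

```python
from pathlib import Path


def list_data_files(root="./data"):
    """Return (relative path, size in bytes) for every file under `root`."""
    root_path = Path(root)
    if not root_path.is_dir():
        # Nothing has been downloaded yet, or the path is wrong.
        return []
    files = []
    for path in sorted(root_path.rglob("*")):
        if path.is_file():
            relative = path.relative_to(root_path).as_posix()
            files.append((relative, path.stat().st_size))
    return files


# Print whatever is currently stored in the data folder.
for name, size in list_data_files("./data"):
    print(f"{name}: {size} bytes")
```

If the download step above has run, this should print the downloaded and generated files; if it has not, it prints nothing.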


Next, let’s initialize the MNIST test set.

mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=None)

This time, it’s very quick because the data has already been downloaded.

The train parameter is set to False because we want the test set, not the training set.

Then, as with the training set, we set download to True and transform to None.
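
Since the two calls differ only in the train flag, they could be wrapped in one helper. The sketch below assumes torchvision is installed and that the data can be downloaded; load_mnist_splits is a hypothetical name introduced here for illustration, not a torchvision function.

```python
def load_mnist_splits(root="./data", transform=None):
    """Return (training set, test set) for MNIST stored under `root`."""
    # Imported here so that defining the helper does not itself need torchvision.
    import torchvision.datasets as datasets

    # The two calls are identical apart from the train flag.
    trainset = datasets.MNIST(root=root, train=True, download=True, transform=transform)
    testset = datasets.MNIST(root=root, train=False, download=True, transform=transform)
    return trainset, testset
```

Calling mnist_trainset, mnist_testset = load_mnist_splits() would then give the same two objects as the separate calls above.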


Let’s do a very brief exploration of what data we’ve loaded.


We can get the number of items in the MNIST training set using the Python len function to make sure it matches what we expect.

len(mnist_trainset)

We see that it is 60,000, which is what we expect.


Let’s also check the number of items in the MNIST test set in the same way.

len(mnist_testset)

We see that it is 10,000, which is what we expect.
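
Instead of reading the two lengths by eye, the same check could be made explicit. This is a minimal sketch: check_split_sizes is a hypothetical helper, and it only assumes that both objects support len, as the two MNIST datasets above do.

```python
def check_split_sizes(trainset, testset, expected_train=60_000, expected_test=10_000):
    """Return the two lengths, or raise ValueError if either is unexpected."""
    actual = (len(trainset), len(testset))
    expected = (expected_train, expected_test)
    if actual != expected:
        raise ValueError(f"expected split sizes {expected}, got {actual}")
    return actual
```

With the datasets loaded as above, check_split_sizes(mnist_trainset, mnist_testset) should return the pair of lengths and raise an error if the download was incomplete.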


Brilliant. We were able to load the MNIST dataset from PyTorch torchvision and split it into a train dataset and a test dataset.
