MNIST Digit Recognition with PyTorch
Introduction
In this article we will work through a practical example of handwritten digit recognition on MNIST using PyTorch and an Artificial Neural Network (a feed-forward neural network). The MNIST database (Modified National Institute of Standards and Technology database) contains 60,000 training images and 10,000 testing images of handwritten digits, each a 28x28 grayscale picture, split into 10 classes: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]. It is widely used for image-processing and machine-learning classification problems. Some randomly chosen handwritten digits are displayed in the figure below.
Table of Contents
- Import Libraries
- Download the MNIST Dataset
- Load the DataLoader
- Visualize the Dataset
- Build the ANN Model for MNIST
- Initialize Loss and Optimizer
- Train & Validate the Model
- Plot the Accuracy & Loss Graphs of the MNIST Model
Import Libraries
We are going to use some basic PyTorch libraries from the torch package (https://github.com/pytorch/pytorch/tree/master/torch). Let me first explain torchvision, which is widely used in computer vision and goes hand in hand with PyTorch. It is mainly used for image and video transformations, and it also provides pre-trained models (ResNet, InceptionNet, VGG-16/19, etc.) and datasets (MNIST, FashionMNIST, CIFAR-10/100, ImageNet, etc.). torchvision does not come bundled with PyTorch; you have to install it separately from https://pypi.org/project/torchvision/ (for example with pip install torchvision).
From torch we import nn (short for Neural Network), which contains the Linear/Dense layers, activation functions, etc. Next comes DataLoader, which I will explain in more detail when we load the training and testing datasets; then torchvision, which I explained above; and finally the matplotlib library, which we will use to visualize the MNIST dataset.
Download the MNIST Dataset
We will download the MNIST dataset through torchvision. As explained above, it contains 60,000 training images and 10,000 testing images. To download it we use the datasets module from torchvision, to which we mainly pass four parameters:
- root = the directory in which to store the MNIST dataset
- train = True for the training data and False for the testing data
- transform = the transformation(s) applied to each image (image resizing, image augmentation, conversion to a tensor, etc.)
- download = True to download the files into root if they are not already there
DataLoader
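Now we will use DataLoader to prepare the dataset for iteration during training and testing. It divides the dataset into batches and can shuffle it randomly; we shuffle the training set so that the model does not see the images in the same order every epoch, which reduces the risk of it picking up patterns from the order of the data. Before creating the DataLoader we will check the size and shape of train_set and test_set.
Now create the DataLoader.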
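A minimal sketch of this step. So that it runs offline, it assumes a stand-in for MNIST: a random `TensorDataset` with MNIST's 1x28x28 image shape and integer labels 0-9. With the real data, use the `train_set` and `test_set` from the download step instead, and the sizes become 60000 and 10000.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Offline stand-in for MNIST (assumption): random 1x28x28 "images", labels 0-9
train_set = TensorDataset(torch.rand(1000, 1, 28, 28), torch.randint(0, 10, (1000,)))
test_set = TensorDataset(torch.rand(200, 1, 28, 28), torch.randint(0, 10, (200,)))

print(len(train_set), len(test_set))  # 1000 200 (real MNIST: 60000 10000)

# shuffle=True reorders the training set every epoch; the test set keeps a
# fixed order because evaluation does not depend on it
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False)

# Number of batches = ceil(len(dataset) / batch_size)
print(len(train_loader), len(test_loader))  # 16 4
```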
Next, we check the shape of the images in each batch. Since we chose batch_size = 64, the image tensor of each batch has the dimensions
(batch_size, no_channels, height, width), which for MNIST is (64, 1, 28, 28).
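The shape check can be sketched with a `for` loop that stops after the first batch. As before, a random `TensorDataset` stands in for the real MNIST training set (an assumption so the sketch runs offline). The last lines work out what happens with the real 60,000 images, which 64 does not divide evenly.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Offline stand-in for train_set (assumption): random 1x28x28 images, labels 0-9
train_set = TensorDataset(torch.rand(1000, 1, 28, 28), torch.randint(0, 10, (1000,)))
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)

for images, labels in train_loader:
    # images: (batch_size, no_channels, height, width)
    print(images.shape)  # torch.Size([64, 1, 28, 28])
    print(labels.shape)  # torch.Size([64])
    break  # one batch is enough to inspect the shape

# Real MNIST training set: 60000 = 937 * 64 + 32, so there are 938 batches
# and (with the default drop_last=False) the last one holds only 32 images
full_batches, last_batch = divmod(60000, 64)
print(full_batches, last_batch)  # 937 32
```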
Visualize the MNIST dataset
Before visualizing the MNIST dataset, we first call iter, which returns an iterator object over the DataLoader, and then next, which fetches the next batch (here, the first one). We then check the batch's shape, just as we did above with the for loop, and plot some of the digits.
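A minimal sketch of the iter/next approach, plotting the first 16 digits of a batch with their labels. Assumptions: a random stand-in dataset replaces the real `train_loader` (so the "digits" here are noise), the off-screen `Agg` backend is used so it runs without a display, and the output file name `mnist_batch.png` is hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; drop this line in a notebook
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader, TensorDataset

# Offline stand-in for the MNIST train_loader (assumption)
train_set = TensorDataset(torch.rand(1000, 1, 28, 28), torch.randint(0, 10, (1000,)))
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)

# iter() wraps the loader in an iterator; next() pulls exactly one batch
data_iter = iter(train_loader)
images, labels = next(data_iter)
print(images.shape, labels.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64])

# Show the first 16 images of the batch in a 4x4 grid, titled with their labels
fig, axes = plt.subplots(4, 4, figsize=(6, 6))
for ax, img, label in zip(axes.flat, images, labels):
    ax.imshow(img.squeeze(0).numpy(), cmap="gray")  # (1, 28, 28) -> (28, 28)
    ax.set_title(str(label.item()))
    ax.axis("off")
fig.tight_layout()
fig.savefig("mnist_batch.png")
```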
Build ANN Model for MNIST
Now we will build a fully connected network consisting of an input layer, hidden layers and an output layer. First we create a class ANN that inherits from the parent class nn.Module. Its first method is __init__, in which we initialize all the fully connected (FC) layers, i.e. the Linear layers, as well as the activation function that introduces non-linearity into the network. We use the ReLU activation function in the hidden layers. For the output layer one would normally think of Softmax, but we do not add it to the model: PyTorch's nn.CrossEntropyLoss applies log-softmax internally, so the model should output raw scores (logits).
The next method is forward, which takes a batch of images as its parameter. The first thing it does is reshape the images from (batch_size, no_channels, height, width) to (batch_size, no_channels*height*width). The result is passed through the FC layers with ReLU as the activation function, and the method returns the output vector of the final layer.
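A minimal sketch of such a class. The two hidden layers and their sizes (128 and 64 units) are assumptions, since the article does not state the architecture; the structure (Linear layers, ReLU in the hidden layers, flattening in forward, no softmax at the output) follows the description above.

```python
import torch
import torch.nn as nn


class ANN(nn.Module):
    def __init__(self, in_features=28 * 28, hidden1=128, hidden2=64, num_classes=10):
        super().__init__()
        # Fully connected (Linear) layers: input -> hidden -> hidden -> output
        self.fc1 = nn.Linear(in_features, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.out = nn.Linear(hidden2, num_classes)
        # ReLU adds non-linearity between the linear layers
        self.relu = nn.ReLU()

    def forward(self, x):
        # (batch, 1, 28, 28) -> (batch, 784): flatten everything except the batch dim
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        # No softmax: return raw scores (logits); CrossEntropyLoss handles the rest
        return self.out(x)


model = ANN()
dummy = torch.rand(64, 1, 28, 28)
print(model(dummy).shape)  # torch.Size([64, 10])
```

With these assumed sizes the network has 784*128+128 + 128*64+64 + 64*10+10 = 109,386 trainable parameters.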
Initialize Loss & Optimizer
For the loss function we use categorical cross-entropy loss (nn.CrossEntropyLoss in PyTorch), the standard choice for multi-class classification. It compares the model's scores for the 10 digits with the true label; applying softmax to those scores gives the probability that an image is an 8, a 5, or any other digit. As the optimizer we use Adam, which reduces the loss by updating the weights with the gradients computed during back-propagation.
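A minimal sketch of the loss and optimizer setup, with a check that `nn.CrossEntropyLoss` equals log-softmax followed by `nn.NLLLoss`, which is why the model itself has no softmax layer. Assumptions: a one-layer `nn.Sequential` stands in for the ANN model, the learning rate 1e-3 is a common default rather than the article's value, and the batch is random.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in model (assumption): maps (N, 1, 28, 28) images to (N, 10) logits
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

criterion = nn.CrossEntropyLoss()  # expects raw logits and integer class labels
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # assumed learning rate

images = torch.rand(8, 1, 28, 28)
labels = torch.randint(0, 10, (8,))
logits = model(images)
loss = criterion(logits, labels)

# CrossEntropyLoss is LogSoftmax followed by NLLLoss
manual = nn.NLLLoss()(torch.log_softmax(logits, dim=1), labels)
print(torch.allclose(loss, manual))  # True

# Softmax turns the logits into per-digit probabilities
probs = torch.softmax(logits, dim=1)
print(probs.sum(dim=1))  # each row sums to 1 (up to floating-point rounding)
```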
Train & Validate the MNIST Model
Now we reach the main step: training the model. We train for 10 epochs, using categorical cross-entropy as the loss, as discussed above, and the Adam optimizer to update the weights during back-propagation. In each epoch we loop through train_loader, which yields batches of 64 images. For each batch we pass the images through the network, calculate the loss between the predicted labels and the actual labels, and then back-propagate to compute the gradient of the loss with respect to every weight. The weights are then updated. In plain gradient descent the rule is Wnew = Wold - lr * gradient; Adam refines this by scaling each step with running averages of past gradients and squared gradients. This update happens once per batch, and once we have iterated through all batches, one epoch is complete; averaging the batch losses gives the overall loss (cost) for that epoch.
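The plain update rule Wnew = Wold - lr * gradient can be checked directly with autograd. This toy sketch (a 3-element weight vector and a made-up squared-error loss, both assumptions for illustration) applies the rule by hand and compares it with one step of `torch.optim.SGD`, which with default settings (no momentum, no weight decay) implements exactly this rule; Adam, used in the article, adds the adaptive scaling described above on top of it.

```python
import torch

torch.manual_seed(0)
w = torch.randn(3, requires_grad=True)  # toy weight vector
x = torch.tensor([1.0, 2.0, 3.0])
lr = 0.1

# Toy loss: (w . x - 1)^2
loss = (w @ x - 1.0) ** 2
loss.backward()  # back-propagation fills w.grad with d(loss)/dw

# Manual update: W_new = W_old - lr * gradient
with torch.no_grad():
    w_manual = w - lr * w.grad

# The same step through torch.optim.SGD (updates w in place)
opt = torch.optim.SGD([w], lr=lr)
opt.step()

print(torch.allclose(w.detach(), w_manual))  # True
```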
Validation works the same way, except that we do not update the model's weights; we only check how the model performs on images it was not trained on. The training and validation loop implements the forward pass, the loss, and both training and validation accuracy step by step, and stores them in lists so that we can plot how the loss and accuracy change after each epoch.
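A minimal sketch of the full loop, recording training and validation loss and accuracy per epoch in four lists. Assumptions: random stand-in loaders replace the real MNIST loaders (with random labels the accuracy stays near chance, so only the mechanics are meaningful), the architecture is the assumed 784-128-64-10 network written as `nn.Sequential`, and lr = 1e-3.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Offline stand-ins for the MNIST loaders (assumption); swap in the real ones
train_loader = DataLoader(
    TensorDataset(torch.rand(1000, 1, 28, 28), torch.randint(0, 10, (1000,))),
    batch_size=64, shuffle=True)
test_loader = DataLoader(
    TensorDataset(torch.rand(200, 1, 28, 28), torch.randint(0, 10, (200,))),
    batch_size=64, shuffle=False)

# Assumed architecture: 784 -> 128 -> 64 -> 10, ReLU in the hidden layers
model = nn.Sequential(nn.Flatten(), nn.Linear(784, 128), nn.ReLU(),
                      nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

num_epochs = 10
train_losses, val_losses, train_accs, val_accs = [], [], [], []

for epoch in range(num_epochs):
    # Training: forward pass, loss, back-propagation, weight update
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    for images, labels in train_loader:
        optimizer.zero_grad()              # clear gradients from the previous batch
        outputs = model(images)            # (batch, 10) logits
        loss = criterion(outputs, labels)
        loss.backward()                    # compute gradients
        optimizer.step()                   # Adam updates the weights
        running_loss += loss.item() * images.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    train_losses.append(running_loss / total)
    train_accs.append(correct / total)

    # Validation: no gradients and no weight updates
    model.eval()
    running_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            running_loss += criterion(outputs, labels).item() * images.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    val_losses.append(running_loss / total)
    val_accs.append(correct / total)

    print(f"epoch {epoch + 1}/{num_epochs}  "
          f"train loss {train_losses[-1]:.4f} acc {train_accs[-1]:.3f}  "
          f"val loss {val_losses[-1]:.4f} acc {val_accs[-1]:.3f}")
```

Weighting each batch loss by `images.size(0)` makes the epoch loss a true per-image average even when the last batch is smaller.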
Plot the Accuracy & Loss Graphs of the MNIST Model
Now we will see how the training and validation loss decrease and the training and validation accuracy increase from epoch to epoch. Since we already stored the accuracy and loss in lists during training, we only need to plot accuracy and loss against the epoch number.
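A minimal plotting sketch. The helper name `plot_history` is hypothetical, and the off-screen `Agg` backend is an assumption so it runs without a display. It draws one metric per epoch for training and validation on the same axes.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; drop this line in a notebook
import matplotlib.pyplot as plt


def plot_history(train_values, val_values, metric_name):
    """Hypothetical helper: plot a per-epoch metric for training and validation."""
    epochs = range(1, len(train_values) + 1)
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(epochs, train_values, marker="o", label=f"Training {metric_name}")
    ax.plot(epochs, val_values, marker="o", label=f"Validation {metric_name}")
    ax.set_xlabel("Epoch")
    ax.set_ylabel(metric_name)
    ax.set_title(f"Training/Validation {metric_name}")
    ax.legend()
    fig.tight_layout()
    return fig
```

With the lists from the training loop, `plot_history(train_accs, val_accs, "Accuracy").savefig("accuracy.png")` and `plot_history(train_losses, val_losses, "Loss").savefig("loss.png")` produce the two graphs (the file names are only examples).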
Training/Validation Accuracy Graph
Training/Validation Loss Graph
End Notes
I hope this article has made the implementation of MNIST digit recognition with PyTorch clear, from the basics up to a trained model. In the next article we will use a Convolutional Neural Network (CNN) for MNIST. If you have any doubts about this article, leave a comment and we will respond as soon as possible.