CIFAR-10 Image Classifier
- Ajay Reddy
- Nov 3, 2021
Classification of the CIFAR-10 dataset using Convolutional Neural Networks
In this project, we build a neural network model using PyTorch and test it on the CIFAR-10 dataset to see what prediction accuracy it can achieve.
CIFAR-10 is a dataset of 60,000 color images in 10 classes, split into 50,000 training images and 10,000 test images. Each image has size 3x32x32 (3 color channels, 32x32 pixels). The dataset is widely used in research to benchmark machine learning models, especially for computer vision problems. Its classes are: airplane, automobile, bird, dog, horse, cat, deer, truck, frog, and ship.

Training an image classifier
To train a model that classifies the images, we perform the following steps:
Load and normalize the CIFAR10 training and test datasets using torchvision
Define a Convolutional Neural Network
Define a loss function
Train the network on the training data
Test the network on the test data
1. Import the libraries, then load and normalize the data


2. Define a Convolutional Neural Network
Neural networks can be constructed using the torch.nn package.
Following the tutorial, we define two convolutional layers with a kernel size of 5 and three fully connected layers.
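The author's code is not shown here, so the following is only a sketch of such a network. It assumes the layer sizes from the official PyTorch CIFAR-10 tutorial: 6 and 16 feature maps, 2x2 max pooling after each convolution, and fully connected layers with 120, 84, and 10 units.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)    # 3 input channels -> 6 feature maps, 5x5 kernel
        self.pool = nn.MaxPool2d(2, 2)     # 2x2 max pooling halves height and width
        self.conv2 = nn.Conv2d(6, 16, 5)   # 6 -> 16 feature maps, 5x5 kernel
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)       # one output score per CIFAR-10 class

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # 3x32x32 -> 6x28x28 -> 6x14x14
        x = self.pool(F.relu(self.conv2(x)))  # 6x14x14 -> 16x10x10 -> 16x5x5
        x = torch.flatten(x, 1)               # flatten all dims except batch -> 400 features
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)                    # raw scores (logits); no softmax here
```

For example, `Net()(torch.randn(4, 3, 32, 32))` returns a tensor of shape `(4, 10)`: one row of ten class scores per image.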

Convolutional Layer: A convolution is a linear operation that multiplies a set of weights with the input, much like a traditional neural network layer. Because the technique was designed for two-dimensional input, the weights form a small two-dimensional array, called a filter or kernel (for color images, one such array per input channel). The kernel slides across the input, and at each position it is multiplied element-wise with the patch of input underneath and the products are summed.
Pooling Layer: Pooling, also known as downsampling, reduces the height and width of the feature maps. For example, 2x2 max pooling keeps the largest value in each 2x2 window and halves both dimensions. This reduces the number of parameters and computations needed in the following layers.
Fully connected layer: This layer performs the classification based on the features extracted by the previous layers and their filters. Convolutional layers are usually followed by ReLU activations. The outputs of the final FC layer, by contrast, are typically passed through a softmax function, which turns them into class probabilities between 0 and 1. In PyTorch, the network usually returns the raw scores (logits) instead, because nn.CrossEntropyLoss applies a log-softmax internally.
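A little arithmetic makes the layer descriptions above concrete. The sketch below uses the standard output-size formula for convolution and pooling, floor((W - K + 2P) / S) + 1, where W is the input size, K the kernel size, P the padding, and S the stride. It traces a 32x32 image through two 5x5 convolutions and two 2x2 poolings, assuming the tutorial's 6 and 16 feature maps; `conv_out` is a helper name of our own. The last lines show softmax turning raw scores into probabilities.

```python
import torch
import torch.nn as nn


def conv_out(size, kernel, stride=1, padding=0):
    # Output height/width of a convolution or pooling layer:
    # floor((size - kernel + 2 * padding) / stride) + 1
    return (size - kernel + 2 * padding) // stride + 1


# Trace a 32x32 CIFAR-10 image through conv(5) -> pool(2) -> conv(5) -> pool(2).
s = conv_out(32, 5)            # 28 after the first 5x5 convolution
s = conv_out(s, 2, stride=2)   # 14 after 2x2 max pooling
s = conv_out(s, 5)             # 10 after the second 5x5 convolution
s = conv_out(s, 2, stride=2)   # 5 after the second pooling
flat_features = 16 * s * s     # 16 channels * 5 * 5 = 400 inputs to the first FC layer

# Check the arithmetic against real PyTorch layers.
layers = nn.Sequential(
    nn.Conv2d(3, 6, 5), nn.MaxPool2d(2, 2),
    nn.Conv2d(6, 16, 5), nn.MaxPool2d(2, 2),
)
assert layers(torch.randn(1, 3, 32, 32)).shape == (1, 16, s, s)

# Softmax turns raw scores (logits) into probabilities that sum to 1;
# the largest score keeps the largest probability.
logits = torch.tensor([2.0, 1.0, 0.1])
probs = torch.softmax(logits, dim=0)
```

This is why the first fully connected layer in the tutorial network takes 16 * 5 * 5 inputs.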
3. Define a loss function
For this classification model, we use cross-entropy as the loss function and stochastic gradient descent (SGD) as the optimizer.
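Here is a sketch of the loss function and optimizer. It assumes the learning rate of 0.001 and momentum of 0.9 used in the official PyTorch tutorial; the post does not confirm these values. A plain linear layer stands in for the CNN so that the example runs on its own. The last lines check that `nn.CrossEntropyLoss` on raw logits equals a log-softmax followed by the negative log-likelihood loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Stand-in model for illustration; in practice, pass the CNN's parameters instead.
model = nn.Linear(3 * 32 * 32, 10)

criterion = nn.CrossEntropyLoss()
# SGD settings assumed from the PyTorch tutorial.
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# CrossEntropyLoss expects raw logits and integer class labels.
logits = torch.randn(4, 10)
labels = torch.tensor([3, 0, 9, 1])
loss = criterion(logits, labels)

# Same value computed in two steps: log-softmax, then negative log-likelihood.
manual = F.nll_loss(F.log_softmax(logits, dim=1), labels)
```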

4. Train the network
To train the network, we loop over the training data. For each mini-batch, we feed the inputs to the model, calculate the loss, backpropagate it, and let the optimizer update the weights.
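The following is a minimal sketch of that loop, modeled on the one in the official PyTorch tutorial. The function name `train`, its parameters, and the returned `history` list are our own additions for illustration. With `log_every=2000`, the function prints the average loss over each block of 2000 mini-batches.

```python
import torch
import torch.nn as nn


def train(net: nn.Module, loader, criterion, optimizer, epochs=2, log_every=2000):
    """Train `net` and return the average losses it printed (hypothetical helper)."""
    history = []
    net.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(loader):
            optimizer.zero_grad()              # clear gradients from the previous step
            outputs = net(inputs)              # forward pass
            loss = criterion(outputs, labels)  # compare predictions with the true labels
            loss.backward()                    # backpropagate
            optimizer.step()                   # update the weights
            running_loss += loss.item()
            if (i + 1) % log_every == 0:
                # Average loss over the last `log_every` mini-batches.
                avg = running_loss / log_every
                print(f"[{epoch + 1}, {i + 1:5d}] loss: {avg:.3f}")
                history.append(avg)
                running_loss = 0.0
    return history
```

For example, `train(net, train_loader, criterion, optimizer, epochs=2)` matches the tutorial's two epochs.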

Here the training loss is printed every 2000 mini-batches, as the average loss over those 2000 mini-batches.
5. Test the network
Now we evaluate the model by measuring its accuracy on the test data.
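Here is a sketch of the evaluation step. It assumes the model returns one score per class, as in the tutorial, and `evaluate` is a hypothetical helper name. The predicted class is the one with the highest score, and accuracy is the percentage of test images whose prediction matches the label.

```python
import torch
import torch.nn as nn


def evaluate(net: nn.Module, loader) -> float:
    """Return the accuracy (in percent) of `net` on `loader` (hypothetical helper)."""
    net.eval()                         # put layers such as dropout into inference mode
    correct, total = 0, 0
    with torch.no_grad():              # gradients are not needed for evaluation
        for images, labels in loader:
            outputs = net(images)
            predicted = outputs.argmax(dim=1)  # class with the highest score
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100.0 * correct / total
```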

As we can see, the accuracy of the model is 53%. We can improve the accuracy by changing the hyperparameters of the model, as explained in the 'My Contribution' section.
Below, we calculate the accuracy of each class to find out which classes the model predicts well.
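The following is a sketch of the per-class calculation. The class names are listed in torchvision's CIFAR-10 label order (label 0 is airplane, label 9 is truck), and `per_class_accuracy` is a hypothetical helper name. Each test image counts toward the total of its true class, and toward that class's correct count when the prediction matches.

```python
import torch
import torch.nn as nn

# CIFAR-10 class names in label order (index 0 to 9).
CLASSES = ("airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck")


def per_class_accuracy(net: nn.Module, loader, classes=CLASSES) -> dict:
    """Return {class name: accuracy in percent} (hypothetical helper)."""
    correct = {name: 0 for name in classes}
    total = {name: 0 for name in classes}
    net.eval()
    with torch.no_grad():
        for images, labels in loader:
            predicted = net(images).argmax(dim=1)
            for label, pred in zip(labels, predicted):
                name = classes[label.item()]
                total[name] += 1
                if label == pred:
                    correct[name] += 1
    # Skip classes that never appear in `loader` to avoid dividing by zero.
    return {name: 100.0 * correct[name] / total[name]
            for name in classes if total[name] > 0}
```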

Challenge:
The most difficult challenge I faced in this project was increasing the accuracy of the model. After hours of research on neural networks, I learned that we can increase the accuracy by changing the hyperparameters and continuing to experiment with their values. Changing the optimizer and experimenting with it can also increase the accuracy.
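Experimenting with optimizers is easier when the optimizer is a parameter. Below is one possible sketch. `make_optimizer` is a hypothetical helper, Adam is just one common alternative to SGD, and the learning rates are illustrative values, not results from this project. A plain linear layer stands in for the CNN.

```python
import torch.nn as nn
import torch.optim as optim


def make_optimizer(name: str, params, lr: float):
    """Build an optimizer by name so different choices can be compared (hypothetical helper)."""
    if name == "sgd":
        return optim.SGD(params, lr=lr, momentum=0.9)
    if name == "adam":
        return optim.Adam(params, lr=lr)
    raise ValueError(f"unknown optimizer: {name}")


# Illustrative sweep: each run gets a fresh model so the runs do not share weights.
configs = [("sgd", 0.01), ("sgd", 0.001), ("adam", 0.001)]
for name, lr in configs:
    model = nn.Linear(3 * 32 * 32, 10)   # stand-in for the CNN
    optimizer = make_optimizer(name, model.parameters(), lr)
    # ... train with this optimizer here, then record the test accuracy ...
    print(name, lr, type(optimizer).__name__)
```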
My Contribution:
As seen above, the accuracy of the model defined as per the tutorial is roughly 53-54%. To improve it, I experimented with the values of the following hyperparameters of the model.
Convolutional Layers: I have increased the number of convolutional layers to 3 and have decreased the kernel size to 3.
Fully connected Layers:
Epoch: The tutorial trains for 2 epochs, but I increased this to 6 so that the model goes over the training data more times and ends up better trained.
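The author's modified model is not shown here, so the following is only a hypothetical sketch of a network with three convolutional layers and a kernel size of 3. The channel counts (32, 64, 128), the padding of 1, and the reuse of the tutorial's 120/84/10 fully connected sizes are all assumptions made for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net3(nn.Module):
    """Hypothetical 3-conv variant; channels (32, 64, 128) and padding=1 are assumptions."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)    # 3x3 kernel; padding keeps size
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)                 # halves height and width
        self.fc1 = nn.Linear(128 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # 3x32x32 -> 32x32x32 -> 32x16x16
        x = self.pool(F.relu(self.conv2(x)))  # -> 64x16x16 -> 64x8x8
        x = self.pool(F.relu(self.conv3(x)))  # -> 128x8x8 -> 128x4x4
        x = torch.flatten(x, 1)               # -> 2048 features
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)                    # raw class scores (logits)
```

To match the change in epochs, such a model would be trained with 6 epochs instead of the tutorial's 2.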

These changes increased the accuracy of the model to a whopping 70%.
Plotting the bar graph:
Now that we know the overall accuracy of this model, below is the accuracy of each individual class.
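Here is a minimal sketch of the bar graph. It follows the `bar()` approach from the GeeksforGeeks reference but uses Matplotlib's object-oriented `Figure` API, so it also works without a display. The function name `plot_class_accuracy` is our own. The input is assumed to be a dictionary that maps class names to accuracy percentages, and the values must come from your own evaluation.

```python
from matplotlib.figure import Figure


def plot_class_accuracy(accuracy: dict) -> Figure:
    """Draw a bar chart of per-class accuracy in percent (hypothetical helper)."""
    fig = Figure(figsize=(10, 4))
    ax = fig.subplots()
    # One bar per class, with the class names on the x-axis.
    ax.bar(list(accuracy.keys()), list(accuracy.values()), color="steelblue")
    ax.set_xlabel("Class")
    ax.set_ylabel("Accuracy (%)")
    ax.set_ylim(0, 100)
    ax.set_title("Per-class accuracy on the CIFAR-10 test set")
    return fig
```

For example, `plot_class_accuracy(acc).savefig("class_accuracy.png", bbox_inches="tight")` writes the chart to an image file.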

Here is the code of my CNN model
References:
Bar graph plot: https://www.geeksforgeeks.org/bar-plot-in-matplotlib/
CNN model: https://www.youtube.com/watch?v=pDdP0TFzsoQ


