Hackers Realm

Data Augmentation for Image Data | Keras Tensorflow | Python

Data augmentation is a fundamental technique in computer vision and machine learning, particularly when working with image data. It involves applying various transformations to the images in the original dataset to create new, altered versions of them. The goal of data augmentation is to increase the diversity of the training data, which can improve model generalization, robustness, and performance. This article shows how to do it in Python with Keras and TensorFlow.

Data Augmentation for Image Data

In Keras and TensorFlow, data augmentation involves applying a series of random transformations to your images as they are fed to your neural network during training. These transformations can include rotations, flips, shifts, zooms, and more.
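To make two of these transformations concrete, here is a minimal NumPy-only sketch (not how Keras implements them internally) that applies a horizontal flip and a one-pixel downward shift to a tiny 4x4 array standing in for an image. The helper names `flip_horizontal` and `shift_down` are hypothetical, and filling the vacated row with zeros is just one of several possible fill strategies.

```python
import numpy as np

def flip_horizontal(image: np.ndarray) -> np.ndarray:
    """Mirror an (H, W) image left to right (hypothetical helper)."""
    return image[:, ::-1]

def shift_down(image: np.ndarray, pixels: int) -> np.ndarray:
    """Shift an (H, W) image down by `pixels` rows, filling vacated top rows with zeros."""
    shifted = np.zeros_like(image)
    if pixels < image.shape[0]:
        shifted[pixels:, :] = image[: image.shape[0] - pixels, :]
    return shifted

# A tiny 4x4 "image" with one bright pixel in row 0, column 1
image = np.zeros((4, 4), dtype=np.uint8)
image[0, 1] = 255

flipped = flip_horizontal(image)   # the bright pixel moves to column 2
shifted = shift_down(image, 1)     # the bright pixel moves to row 1

print(np.argwhere(flipped == 255))
print(np.argwhere(shifted == 255))
```

Each transformed array keeps the original shape and label meaning, which is what lets augmented copies be used as extra training examples.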


You can watch the video-based tutorial with a step-by-step explanation below.


Import Modules

from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np
import warnings
warnings.filterwarnings('ignore')
  • tensorflow.keras.datasets - provides access to several commonly used datasets for machine learning and deep learning tasks.

  • matplotlib.pyplot - provides a high-level interface for creating a wide variety of plots and visualizations in Python.

  • tensorflow.keras.preprocessing.image - provides a set of utilities for working with image data, including data preprocessing, augmentation, and loading images into a format suitable for training deep learning models.

  • numpy - provides support for working with large, multi-dimensional arrays and matrices, as well as a wide range of mathematical functions to operate on these arrays efficiently.

  • warnings - provides a way to handle warnings that are generated during program execution.


Load the dataset


Next load the MNIST dataset.

(X_train, y_train), (X_test, y_test) = mnist.load_data()
  • Load the MNIST dataset using the Keras mnist.load_data() function, which returns two tuples of NumPy arrays: one for training and one for testing.

  • The MNIST dataset is a commonly used dataset in the field of machine learning and computer vision, containing a large collection of handwritten digits.

  • X_train: This variable holds the training images. It's a 3D NumPy array of unsigned 8-bit integers (pixel values 0 to 255) with the shape (num_samples, height, width), where num_samples is the number of training examples, and height and width are the dimensions of each image (28x28 pixels in the case of MNIST).

  • y_train: This variable will hold the corresponding labels for the training images. It's a 1D NumPy array with the shape (num_samples,) containing integers representing the digit labels (0 to 9).

  • X_test: Similar to X_train, this variable holds the testing images.

  • y_test: Similar to y_train, this variable holds the corresponding labels for the testing images.

  • These arrays can be used to train and test machine learning models for tasks like digit recognition and image classification.


Next check the shapes of these arrays.

X_train.shape, X_test.shape

((60000, 28, 28), (10000, 28, 28))

  • The shape attribute of a NumPy array returns a tuple representing the dimensions of the array. In this case, the shapes of X_train and X_test will indicate the number of samples, height, and width of each image.

  • For the MNIST dataset, each image is a 28x28 grayscale image.

  • This indicates that there are 60,000 training images, each of size 28x28 pixels, and 10,000 testing images of the same size.


Preprocess the Data


Next reshape and convert the data types of the X_train and X_test arrays, which are the training and testing images from the MNIST dataset.

# reshape the data
X_train = X_train.reshape((X_train.shape[0], 28, 28, 1))
X_test = X_test.reshape((X_test.shape[0], 28, 28, 1))
# change the type to float
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
  • The MNIST images are originally in a 3D shape (num_samples, height, width), where num_samples is the number of images, and height and width are the dimensions of each image. However, many machine learning models, especially convolutional neural networks (CNNs), expect input data in the form (num_samples, height, width, channels), where channels represents the color channels of the image (1 for grayscale, 3 for RGB).

  • The code reshapes the data to have an additional dimension for the channels.

  • It's common to normalize the pixel values of images before feeding them into a model. This involves converting the pixel values from their original range (0 to 255 for grayscale images) to a normalized range (usually 0 to 1). To do this, it's necessary to change the data type of the arrays from integers to floating-point numbers.

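  • After these transformations, X_train and X_test have shapes (num_samples, height, width, 1) and dtype float32. Note that the pixel values are still in the range 0 to 255 at this point; the division by 255 that brings them into the range 0 to 1 is performed later by the rescale parameter of ImageDataGenerator.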

  • This pre-processing is commonly done to ensure that the data is ready for training a machine learning model, especially a CNN designed for image data.
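Because loading MNIST requires a download, the sketch below uses a small synthetic uint8 array with the same layout to show what the reshape, the float conversion, and the division by 255 (done in this article by the generator's rescale parameter) do to the shape, dtype, and value range. The array `fake_images` is a stand-in, not real MNIST data.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for MNIST: 5 grayscale 28x28 images with integer pixel values 0..255
fake_images = rng.integers(0, 256, size=(5, 28, 28)).astype(np.uint8)

# Add the channel axis: (num_samples, 28, 28) -> (num_samples, 28, 28, 1)
reshaped = fake_images.reshape((fake_images.shape[0], 28, 28, 1))
# Convert to float32 so fractional pixel values can be stored
as_float = reshaped.astype('float32')
# Normalize to [0, 1]; this is the same scaling that rescale=1./255 applies
normalized = as_float / 255.0

print(reshaped.shape, as_float.dtype)   # (5, 28, 28, 1) float32
print(normalized.min(), normalized.max())
```

The reshape only adds a length-1 axis, so it is equivalent to `fake_images[..., np.newaxis]`; no pixel values change until the division.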


Next check the shapes of these arrays after preprocessing.

X_train.shape, X_test.shape

((60000, 28, 28, 1), (10000, 28, 28, 1))

  • After applying the reshaping and data type conversion operations to the X_train and X_test arrays, the shapes of these arrays will change.

  • In this case, num_samples is 60,000 for training and 10,000 for testing, height and width are both 28 (the size of the images), and channels is set to 1, indicating that the images are grayscale.

  • This reshaped and converted data is now suitable for training neural networks, particularly convolutional neural networks (CNNs), which commonly expect images in this format.


Generate Grid of Images without Augmentation


Next, display a grid of the original training images by passing them through an ImageDataGenerator created without any augmentation parameters, so no transformations are applied.

data_generator = ImageDataGenerator()
# flow() yields batches indefinitely; only the first batch of 16 images is used
for X_batch, y_batch in data_generator.flow(X_train, y_train, batch_size=16, shuffle=False):
    # create grid of 4x4 images
    fig, ax = plt.subplots(4, 4, figsize=(8, 8))
    for i in range(4):
        for j in range(4):
            ax[i][j].axis('off')
            ax[i][j].imshow(X_batch[i*4 + j].reshape(28, 28), cmap=plt.get_cmap('gray'))
    plt.show()
    break
Grid of images without Augmentation
  • You initialize the ImageDataGenerator without specifying any augmentation parameters, which means that the images will not be modified.

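  • The code uses a loop to generate batches of data with the flow() method. It takes X_train and y_train as input and generates batches of 16 images. Because shuffle=False, the original order is kept, so the first batch contains the first 16 training images.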

  • Within the loop, the code creates a 4x4 grid of subplots using Matplotlib. It then iterates through the grid and displays each image from the batch using imshow(). The images are reshaped to 28x28 pixels, and the grayscale colormap 'gray' is used.

  • After constructing the image grid, the plt.show() function is used to display the grid of images.

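  • The loop ends with a break statement, so its body runs only once. The break is necessary because flow() loops over the data indefinitely; without it, the loop would never terminate. As a result, only one batch of images is displayed.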

  • The above code snippet is useful for visualizing a subset of the training images as a 4x4 grid. It displays only the first batch of 16 images due to the break statement.
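The plotting loop maps grid row `i` and column `j` to batch index `i*4 + j`. As a NumPy-only sketch of the same layout (useful, for example, if you want a single image instead of 16 subplots), the batch can be tiled into one 112x112 mosaic with a reshape and a transpose. The function name `make_mosaic` is hypothetical, and random data stands in for a real batch.

```python
import numpy as np

def make_mosaic(batch: np.ndarray, rows: int = 4, cols: int = 4) -> np.ndarray:
    """Tile a batch of shape (rows*cols, H, W, 1) into one (rows*H, cols*W) image.

    Image number i*cols + j lands in grid row i, grid column j, matching the
    ax[i][j] / X_batch[i*4 + j] layout of the plotting loop.
    """
    n, h, w = batch.shape[:3]
    if n != rows * cols:
        raise ValueError("batch size must equal rows * cols")
    grid = batch.reshape(rows, cols, h, w)   # drop the channel axis, split batch into a grid
    # Reorder to (rows, h, cols, w) so each grid row becomes a horizontal strip of images
    return grid.transpose(0, 2, 1, 3).reshape(rows * h, cols * w)

rng = np.random.default_rng(42)
X_batch = rng.random((16, 28, 28, 1)).astype('float32')  # stand-in for one batch of 16 images
mosaic = make_mosaic(X_batch)
print(mosaic.shape)  # (112, 112)
```

The resulting array could then be shown with a single `plt.imshow(mosaic, cmap='gray')` call.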


Apply Data Augmentation


Next apply data augmentation using the ImageDataGenerator with various augmentation parameters.

# data augmentation
data_generator = ImageDataGenerator(
    rescale=1./255,
    rotation_range=30,
    zoom_range=0.2,
    shear_range=0.2,
    height_shift_range=0.2
)

# flow() yields augmented batches indefinitely; only the first batch of 16 images is used
for X_batch, y_batch in data_generator.flow(X_train, y_train, batch_size=16, shuffle=False):
    # create grid of 4x4 images
    fig, ax = plt.subplots(4, 4, figsize=(8, 8))
    for i in range(4):
        for j in range(4):
            ax[i][j].axis('off')
            ax[i][j].imshow(X_batch[i*4 + j].reshape(28, 28), cmap=plt.get_cmap('gray'))
    plt.show()
    break
Grid of images with Data Augmentation
  • You initialize the ImageDataGenerator with augmentation parameters like rescale, rotation_range, zoom_range, shear_range, and height_shift_range. These parameters control the types of transformations that will be applied to the images during the data generation process. Let us look at these parameters one by one:

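    1. rescale=1./255: This parameter multiplies the pixel values by 1/255, after all other transformations have been applied. This normalization step brings the pixel values into the range 0 to 1, which is typically preferred for training neural networks. Since the preprocessing step only converted the pixels to float32, this is where the values are actually normalized.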

    2. rotation_range=30: The rotation_range parameter specifies the range, in degrees, for random rotations. Here each image is rotated by a random angle between -30 and +30 degrees.

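    3. zoom_range=0.2: The zoom_range parameter controls the range for random zooming. A float value of 0.2 is interpreted as the range [0.8, 1.2], and a zoom factor is drawn from it separately for the width and the height, so images can be zoomed in or out by up to about 20% along each axis.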

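    4. shear_range=0.2: The shear_range parameter controls the range for random shear transformations, which slant the image along one axis. In Keras this value is a shear angle in degrees (counter-clockwise), so shear_range=0.2 allows shear angles of only up to 0.2 degrees in either direction, an effect that is barely visible on a 28x28 image. A noticeably slanted result requires a larger value.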

    5. height_shift_range=0.2: The height_shift_range parameter determines the range for vertically shifting the images. This can simulate variations in object position within the images. A value of 0.2 indicates that images can be shifted up or down by up to 20% of their height, which is about 5.6 pixels for the 28-pixel-high MNIST images.

  • Similar to the previous code snippet, you use a loop to generate batches of data using the flow() function. This time, the data generator applies the specified augmentation transformations to the images.

  • Inside the loop, you create and display a grid of 4x4 images, just like before.

  • These augmented images can help improve the robustness and generalization capabilities of your deep learning model by exposing it to variations commonly encountered in real-world data.

  • When training your model, the flow() function of the ImageDataGenerator will generate batches of augmented images that have undergone these transformations, providing a more diverse and representative training dataset.
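To see what ranges these settings correspond to, the sketch below draws random transform parameters in the way the ImageDataGenerator documentation describes them: a rotation angle in [-30, 30] degrees, two zoom factors (one per axis) from [0.8, 1.2], a shear angle in [-0.2, 0.2] degrees, and a vertical shift that, because the float is below 1, is a fraction of the image height. This is a simplified NumPy illustration, not Keras's own code; the function `sample_transform_params` is hypothetical and only samples the parameters without applying them.

```python
import numpy as np

def sample_transform_params(rng, image_height=28, rotation_range=30, zoom_range=0.2,
                            shear_range=0.2, height_shift_range=0.2):
    """Draw one random set of transform parameters (hypothetical helper, sketch only)."""
    rotation_deg = rng.uniform(-rotation_range, rotation_range)       # rotation angle in degrees
    zx, zy = rng.uniform(1 - zoom_range, 1 + zoom_range, size=2)      # one zoom factor per axis
    shear_deg = rng.uniform(-shear_range, shear_range)                # shear angle in degrees
    # A float below 1 is treated as a fraction of the image height
    shift_px = rng.uniform(-height_shift_range, height_shift_range) * image_height
    return {"rotation_deg": rotation_deg, "zoom": (zx, zy),
            "shear_deg": shear_deg, "shift_px": shift_px}

rng = np.random.default_rng(7)
params = [sample_transform_params(rng) for _ in range(1000)]
print(max(abs(p["shift_px"]) for p in params))   # never more than 0.2 * 28 = 5.6 pixels
print(max(abs(p["shear_deg"]) for p in params))  # never more than 0.2 degrees
```

Sampling a fresh set of parameters for every image is what makes each batch different, even when the same source images are reused.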


Final Thoughts

  • Augmented data introduces variability into the training process, making your model more resistant to overfitting and better able to generalize to unseen data.

  • By exposing your model to diverse variations of the same images, it learns to recognize patterns that are more consistent across different scenarios.

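  • Data augmentation effectively enlarges your training dataset. The augmented images are generated on the fly rather than stored, but because each epoch sees freshly transformed versions of the images, the model is exposed to far more variations than the original samples alone provide. This is particularly valuable when you have limited labeled data.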

  • Augmentation can reduce the model's dependence on incidental properties of the training images, such as the exact position or orientation of a digit, by presenting it with a broader range of plausible input variations.

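  • Not all augmentations make sense for all types of images. For example, flipping handwritten digits produces mirrored characters that do not occur in real data, and large rotations can make a 6 look like a 9. Tailor your augmentations to the domain and context of your data.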

  • Augmentations should preserve the essence of the image. Overly aggressive augmentations can lead to unrealistic or distorted data that confuses the model.

  • Apply augmentation only to the training set, leaving validation and test data unchanged apart from the same rescaling used during training. This ensures that the model's performance is evaluated on unaltered samples.

  • Experiment with different augmentation settings and parameters to find the right balance between diversity and relevance.

  • Augmentation doesn't replace the need for a well-architected model. A good model architecture is still essential for achieving optimal results.

  • Augmentation increases the time required for each training epoch, as each batch of images is generated on-the-fly with augmentations.
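The on-the-fly behaviour and the train-only rule can be illustrated without Keras. The sketch below is a simplified stand-in for flow(), not its actual implementation: the hypothetical generator `batch_generator` loops over the data indefinitely, reshuffles the training stream each epoch, rescales every batch by 1/255, and applies a random vertical shift (with zero fill, via the hypothetical helper `random_vertical_shift`) only when `augment=True`. The validation stream receives the rescaling but no augmentation.

```python
import numpy as np

def random_vertical_shift(image, max_fraction, rng):
    """Shift an (H, W, C) image up or down by up to max_fraction * H rows, filling with zeros."""
    h = image.shape[0]
    shift = int(round(rng.uniform(-max_fraction, max_fraction) * h))
    shifted = np.zeros_like(image)
    if shift > 0:
        shifted[shift:] = image[:h - shift]      # move content down
    elif shift < 0:
        shifted[:h + shift] = image[-shift:]     # move content up
    else:
        shifted[:] = image
    return shifted

def batch_generator(X, y, batch_size, augment, rng, max_fraction=0.2):
    """Yield (images, labels) batches forever, like a simplified flow()."""
    n = len(X)
    while True:                                  # loops indefinitely, so callers must stop it
        # Shuffle only the training stream; keep validation order fixed
        order = rng.permutation(n) if augment else np.arange(n)
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            batch = X[idx].astype('float32') / 255.0   # same rescaling for both streams
            if augment:
                # New random shifts are drawn every time, so each epoch differs
                batch = np.stack([random_vertical_shift(img, max_fraction, rng) for img in batch])
            yield batch, y[idx]

rng = np.random.default_rng(0)
X = rng.integers(0, 256, size=(10, 28, 28, 1)).astype(np.uint8)  # stand-in images
y = np.arange(10)                                                 # stand-in labels

train_stream = batch_generator(X, y, batch_size=4, augment=True, rng=rng)
val_stream = batch_generator(X, y, batch_size=4, augment=False, rng=rng)

first, _ = next(train_stream)
val_batch, val_labels = next(val_stream)
print(first.shape, val_batch.shape)  # (4, 28, 28, 1) (4, 28, 28, 1)
```

The per-batch work inside the loop is also where the extra time per epoch comes from: every transformation is recomputed each time a batch is requested.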

In summary, data augmentation is a valuable technique for improving the performance and generalization of deep learning models, especially when training data is limited. By selecting augmentation techniques that match your problem domain, you can improve your model's ability to handle real-world variations and, in many cases, its overall accuracy and robustness.


Get the project notebook from here


Thanks for reading the article!!!


Check out more project videos from the YouTube channel Hackers Realm
