
Image Segmentation with U-Net: A Step-by-Step Tutorial Using the Oxford-IIIT Pet Dataset and Keras/TensorFlow

Image segmentation plays a crucial role in computer vision, enabling machines to understand and interpret the contents of an image at a pixel level. Whether it’s for medical imaging, autonomous driving, or satellite image analysis, segmentation allows us to classify each pixel into different categories, providing valuable insights and precise data. Among various image segmentation architectures, the U-Net model stands out for its simplicity, efficiency, and remarkable performance, especially in biomedical applications. We will implement the project using Keras and TensorFlow.

Image Segmentation tutorial

In this tutorial, we’ll explore how to implement image segmentation using the U-Net architecture with TensorFlow. U-Net’s distinctive design, characterized by its “U”-shaped structure with symmetric contraction and expansion paths, enables it to capture both local and global context, making it highly effective for accurate segmentation. By the end of this guide, you will have a solid understanding of how to preprocess image data, build a U-Net model from scratch, and train it to perform pixel-wise segmentation on your own datasets. Let's dive in!


You can watch the video-based tutorial with a step-by-step explanation down below.


Import Modules


import tensorflow as tf
from tensorflow.keras import layers, models
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
  • import tensorflow as tf : This imports the TensorFlow library and assigns it the alias tf. TensorFlow is a deep learning framework used to build, train, and deploy machine learning models. In your code, it is used to construct neural networks and handle tasks related to deep learning.

  • from tensorflow.keras import layers, models : This imports two important submodules from Keras (which is included in TensorFlow)

    • layers: This contains pre-built neural network layers, such as Conv2D, MaxPooling2D, Dense, and Dropout. These layers are the building blocks used to construct a neural network.

    • models: This provides tools to create and manage neural networks. Specifically, it includes two key components:

      • Model: Used to instantiate and compile your neural network.

      • Sequential: Allows you to stack layers in a simple, linear order.

  • import tensorflow_datasets as tfds : TensorFlow Datasets (tfds) is a library of ready-to-use datasets for machine learning. These datasets are often used for tasks like image classification, object detection, or image segmentation. This import allows you to load datasets, such as oxford_iiit_pet, directly from the TensorFlow Datasets catalog.

    • For example, the command tfds.load() is used to download and load a dataset along with its associated metadata (such as class labels).

  • import numpy as np : NumPy is a powerful library for numerical computing in Python. It provides support for arrays, mathematical functions, and random number generation, making it essential for handling and manipulating data before feeding it into a neural network. Here, it might be used to preprocess or augment data (e.g., converting arrays into specific shapes).

  • import matplotlib.pyplot as plt : This imports Matplotlib, a plotting library, and assigns its pyplot module the alias plt. It is primarily used to create visualizations like charts, plots, or images. Here, it's useful for visualizing input images and their corresponding segmentation masks.

    • For instance, plt.imshow() allows you to display images or data on a grid, while plt.show() renders the plotted figures.

  • %matplotlib inline : This is a magic command specific to Jupyter Notebooks that allows the output of Matplotlib plots to be displayed directly below the code cells. It’s not a standard Python command, but is commonly used in notebook environments to ensure that visualizations are rendered inline.


Load the Dataset


Next we will load the Oxford-IIIT Pet dataset using the tensorflow_datasets (tfds) library.

dataset, info = tfds.load('oxford_iiit_pet', with_info=True)
  • tfds.load(): This function loads datasets from TensorFlow Datasets (tfds), which provides a wide variety of pre-processed datasets that can be easily integrated into machine learning models. The function can return the dataset in various formats, such as a tf.data.Dataset object, and can also provide additional information about the dataset.

  • 'oxford_iiit_pet': This is the name of the dataset you are loading. The Oxford-IIIT Pet dataset is a popular dataset containing images of 37 pet breeds, along with corresponding segmentation masks. Each image has pixel-level labels, making it ideal for segmentation tasks.

  • with_info=True: This flag indicates that you also want to retrieve additional metadata about the dataset. It returns information such as:

    • Number of samples (training, testing).

    • Image dimensions.

    • Number of classes or labels.

    • Details about the dataset's structure.

  • dataset: This variable contains the loaded dataset, typically returned as a dictionary with keys like 'train' and 'test', each containing the training and testing sets, respectively. Each element of the dataset is typically a dictionary itself with fields like 'image' and 'segmentation_mask'.

  • info: This contains metadata about the dataset, such as:

    • The input/output shapes.

    • The number of classes.

    • The dataset splits (train/test).

    • Description and citation information.

  • Oxford-IIIT Pet Dataset:

    • Images: Contains images of 37 different pet breeds (cats and dogs).

    • Segmentation Masks: Pixel-wise trimap labels indicating whether each pixel belongs to the pet, the background, or the border region around the pet.

    • Purpose: Used for image classification and segmentation tasks.


Next we will see the metadata information about the dataset.

info
Dataset information
  • After loading the dataset with with_info=True, the info object contains valuable metadata about the dataset.

  • info.splits:

    • This provides information about the dataset splits, such as train, test, or validation. For each split, it shows the number of examples.

  • info.features:

    • This provides details about the features of the dataset, including the structure of the input data (image, label, segmentation_mask), the type of data (int, float, string), and the shape of each feature.

  • info.description:

    • A textual description of the dataset, explaining what it contains and how it can be used. For the Oxford-IIIT Pet dataset, this typically includes information about the pet breeds, image resolution, and segmentation mask usage.

  • info.citation:

    • A citation reference for the dataset, so you can give proper credit if you use it in a research project or paper. This is often the original paper that introduced the dataset.

  • info.supervised_keys:

    • This shows the input-output pair used for training. For example, it tells you that image is used as input and segmentation_mask or label as output.

  • info.version:

    • The version of the dataset. TensorFlow Datasets occasionally updates the dataset to fix bugs or introduce improvements. This ensures you know which version you are using.

  • info.homepage:

    • The official webpage for the dataset, if available. This can be helpful to check for any extra information or related resources.

  • info.builder:

    • Provides the name and configuration details of the dataset builder that created the dataset.
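
As a quick check of how these fields are read (a minimal sketch; all attributes below are standard tfds DatasetInfo properties):

# peek at a few DatasetInfo fields
print(info.splits['train'].num_examples)    # number of training examples
print(info.splits['test'].num_examples)     # number of test examples
print(info.features['label'].num_classes)   # 37 pet breeds
print(info.supervised_keys)                 # input/output pair used when as_supervised=True
print(info.version)                         # dataset version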


Next we will display the information about the dataset.

dataset

{'train': <_PrefetchDataset element_spec={'file_name': TensorSpec(shape=(), dtype=tf.string, name=None), 'image': TensorSpec(shape=(None, None, 3), dtype=tf.uint8, name=None), 'label': TensorSpec(shape=(), dtype=tf.int64, name=None), 'segmentation_mask': TensorSpec(shape=(None, None, 1), dtype=tf.uint8, name=None), 'species': TensorSpec(shape=(), dtype=tf.int64, name=None)}>, 'test': <_PrefetchDataset element_spec={'file_name': TensorSpec(shape=(), dtype=tf.string, name=None), 'image': TensorSpec(shape=(None, None, 3), dtype=tf.uint8, name=None), 'label': TensorSpec(shape=(), dtype=tf.int64, name=None), 'segmentation_mask': TensorSpec(shape=(None, None, 1), dtype=tf.uint8, name=None), 'species': TensorSpec(shape=(), dtype=tf.int64, name=None)}>}

  • The dataset object returned by tfds.load() contains the actual data, typically organized as a dictionary of TensorFlow Datasets (tf.data.Dataset objects), with each key representing a specific data split such as train, test, or validation. In the case of the Oxford-IIIT Pet dataset, the dataset is split into two parts: train and test.

  • image: A tensor representing the input image (typically in shape (height, width, channels)).

  • segmentation_mask: A tensor representing the pixel-wise segmentation labels (for pets and background).

  • label: A numerical label (integer) representing the breed of the pet in the image.

  • dataset['train']: The training dataset.

  • dataset['test']: The testing dataset.
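
To confirm this structure before any preprocessing, one raw element can be pulled out and inspected; a minimal sketch:

# inspect a single raw sample before preprocessing
for sample in dataset['train'].take(1):
    print(sample['image'].shape, sample['image'].dtype)        # (height, width, 3) uint8, size varies per image
    print(sample['segmentation_mask'].shape)                   # (height, width, 1)
    print(sample['label'].numpy(), sample['species'].numpy())  # breed id (0-36) and species id
    print(sample['file_name'].numpy())                         # original file name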


Preprocessing Steps


First we will define the functions for preprocessing the dataset.

def normalize(input_image, input_mask):
    input_image = tf.cast(input_image, tf.float32) / 255.0
    input_mask = input_mask - 1 # convert to zero based indexing
    return input_image, input_mask

def load_train_images(sample):
    # resize the image
    input_image = tf.image.resize(sample['image'], (128, 128))
    input_mask = tf.image.resize(sample['segmentation_mask'], (128, 128))
    # data augmentation
    if tf.random.uniform(()) > 0.5:
        input_image = tf.image.flip_left_right(input_image)
        input_mask = tf.image.flip_left_right(input_mask)
    # normalize the images
    input_image, input_mask = normalize(input_image, input_mask)
    return input_image, input_mask

def load_test_images(sample):
    # resize the image
    input_image = tf.image.resize(sample['image'], (128, 128))
    input_mask = tf.image.resize(sample['segmentation_mask'], (128, 128))
    # normalize the images
    input_image, input_mask = normalize(input_image, input_mask)
    return input_image, input_mask

Normalization Function:

  • tf.cast(input_image, tf.float32): Converts the image to a float32 tensor. This is important because images are typically represented as integers (0-255 for pixel values), but neural networks often perform better when working with floating-point numbers.

  • / 255.0: Normalizes the pixel values of the image from a range of [0, 255] to [0, 1]. Neural networks typically converge faster when the input values are small and normalized.

  • input_mask - 1: The segmentation masks in this dataset are 1-based (the trimap values are 1 for the pet, 2 for the background, and 3 for the border region). Subtracting 1 converts them to 0-based labels (0, 1, and 2), which is what TensorFlow's sparse categorical loss expects for segmentation tasks.

Training Image Preprocessing: This function is used to preprocess the training images, including data augmentation. Here's what it does:

  • Resizing:

    • tf.image.resize(): Resizes both the input image and the segmentation mask to the shape (128, 128). This ensures that all the images have the same dimensions before feeding them into the neural network.

  • Data Augmentation:

    • tf.random.uniform(()) > 0.5: Generates a random number between 0 and 1. If this number is greater than 0.5, the image and mask are flipped horizontally.

    • tf.image.flip_left_right(): Flips the image and mask along the vertical axis. This introduces variability in the training data to help prevent overfitting and improve the model’s generalization.

  • Normalization:

    • After resizing and augmentation, the normalize function is called to scale the pixel values of the image and adjust the mask indexing, as explained earlier.

Test Image Preprocessing: This function is used to preprocess the test images without any data augmentation. It only resizes the images and masks and normalizes them, ensuring consistency during evaluation.

  • No Augmentation: Since test data should not be modified, there is no horizontal flipping or other augmentations applied here.

  • Resizing and Normalization: The resizing and normalization steps are the same as in the training function to ensure that the model receives test images in the same format as the training data.
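
As a small sanity check (a sketch, not part of the original pipeline), load_train_images can be applied to one raw sample to verify the output shapes and value ranges:

# verify the preprocessing output on a single raw sample
for sample in dataset['train'].take(1):
    img, msk = load_train_images(sample)
    print(img.shape, img.dtype)  # (128, 128, 3) float32
    print(msk.shape, msk.dtype)  # (128, 128, 1) float32
    print(tf.reduce_min(img).numpy(), tf.reduce_max(img).numpy())  # pixel values within [0, 1]
    print(tf.reduce_min(msk).numpy(), tf.reduce_max(msk).numpy())  # mask labels within [0, 2] (fractional at object borders due to bilinear resizing)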


Next we will refine our data pipeline by incorporating parallel processing to improve performance during data loading and preprocessing.

train_dataset = dataset['train'].map(load_train_images, num_parallel_calls=tf.data.experimental.AUTOTUNE)
test_dataset = dataset['test'].map(load_test_images, num_parallel_calls=tf.data.experimental.AUTOTUNE)
Data pipeline optimization
  • map() with num_parallel_calls:

    The map() function in TensorFlow’s tf.data API applies a function to each element of the dataset. In your case, you're applying the load_train_images and load_test_images functions to preprocess the images and masks for training and testing, respectively.

    • map(): Applies a function to transform elements of the dataset.

      • In this case, load_train_images is applied to each training sample, and load_test_images to each test sample.

  • num_parallel_calls=tf.data.experimental.AUTOTUNE:

    This argument allows for parallel processing of the data using multiple CPU threads to speed up the data loading and transformation process. Here’s what it does:

    • num_parallel_calls=tf.data.experimental.AUTOTUNE: This automatically determines the optimal number of parallel threads to use for the map() function, based on the system's CPU capacity and the complexity of the transformations. By using AUTOTUNE, TensorFlow dynamically adjusts the number of threads to balance the data preprocessing load and ensure that your GPU/TPU is not waiting for data, thereby improving the overall training speed.

  • train_dataset.map(load_train_images): This applies the load_train_images function to each sample in the training dataset, resizing, augmenting, and normalizing the data. By using num_parallel_calls, multiple images can be preprocessed at once, speeding up the data pipeline.

  • test_dataset.map(load_test_images): This applies the load_test_images function to each sample in the test dataset, resizing and normalizing them. AUTOTUNE ensures this process is done efficiently in parallel.
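
The later steps index into image batches (images[0], masks[0]) and the training section references a BATCH_SIZE of 32, so a batching step is assumed at this point even though it is not shown explicitly. A minimal sketch of that step (the shuffle buffer size of 1000 is an arbitrary choice, and .repeat() is added so the training loop with steps_per_epoch does not run out of data):

# batch the datasets; assumed step, since later code indexes into batches and uses BATCH_SIZE
BATCH_SIZE = 32
train_dataset = train_dataset.shuffle(1000).batch(BATCH_SIZE).repeat().prefetch(tf.data.experimental.AUTOTUNE)
test_dataset = test_dataset.batch(BATCH_SIZE)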


Exploratory Data Analysis


First we will define a helper that displays a set of images side by side for comparison: the input image, the true segmentation mask, and (optionally) the predicted mask.

def display_sample(image_list):
    plt.figure(figsize=(10,10))
    title = ['Input Image', 'True Mask', 'Predicted Mask']

    for i in range(len(image_list)):
        plt.subplot(1, len(image_list), i+1)
        plt.title(title[i])
        plt.imshow(tf.keras.utils.array_to_img(image_list[i]))
        plt.axis('off')

    plt.show()
  • plt.figure(figsize=(10,10)):

    • This creates a new figure for plotting with a specified size of 10x10 inches. The larger figure size ensures the images are displayed clearly and aren’t too small.

  • title = ['Input Image', 'True Mask', 'Predicted Mask']:

    • This list contains the titles for the subplots. Each image in image_list is given a corresponding title:

      • "Input Image": The raw input image from the dataset.

      • "True Mask": The actual segmentation mask (ground truth).

      • "Predicted Mask": The mask predicted by the model.

  • for i in range(len(image_list)):

    • The loop iterates over the images in image_list. The number of iterations is determined by the number of items in the list (typically 3: the input image, the true mask, and the predicted mask).

  • plt.subplot(1, len(image_list), i+1):

    • This function defines a grid of subplots. In this case, 1, len(image_list), i+1 means that the function will create a single row of subplots with as many columns as there are images in image_list.

      • i+1 specifies the position of the current subplot (starting from 1).

  • plt.title(title[i]):

    • Sets the title for the current subplot. It retrieves the title from the title list according to the index i.

  • plt.imshow(tf.keras.utils.array_to_img(image_list[i])):

    • This function displays the image in the i-th position of image_list. The tf.keras.utils.array_to_img() function converts the image (which is typically a TensorFlow tensor or a NumPy array) to a PIL image format, which is necessary for plt.imshow() to display it.

      • image_list[i] is expected to be an image in tensor form (e.g., (128, 128, 3) for an RGB image, or (128, 128, 1) for a grayscale mask).

  • plt.axis('off'):

    • This turns off the axis, so no ticks or labels are shown around the image, making the output look cleaner.

  • plt.show():

    • Finally, this renders and displays the plot with all the subplots in the figure.


Next we will use TensorFlow's data pipeline to take a few batches of images and masks from the train_dataset, then display a sample (one image and its corresponding mask) from each batch using the display_sample() function.

for images, masks in train_dataset.take(3):
    sample_image, sample_mask = images[0], masks[0]
    display_sample([sample_image, sample_mask])

Images from train dataset
  • train_dataset.take(3):

    • train_dataset is the dataset you created earlier, which has been preprocessed (including resizing, augmentation, and normalization) and batched.

    • .take(3): This function limits the dataset to only the first 3 batches. It’s useful if you want to display only a few samples from the dataset for inspection or debugging.

    • This means that the loop will iterate 3 times, once for each of the first 3 batches of the train_dataset.

  • for images, masks in train_dataset.take(3): In each iteration of the loop, images and masks are the image and mask batches returned from the dataset.

    • images is a batch of images. Its shape would typically be something like (batch_size, height, width, channels). In your case, since you batch the data with .batch(32), the shape could be (32, 128, 128, 3) for a batch of 32 RGB images of size 128x128.

    • masks is a batch of segmentation masks. The shape is similar to the image batch, but typically with 1 channel for the mask (e.g., (32, 128, 128, 1)).

  • sample_image, sample_mask = images[0], masks[0]:

    • This selects the first image and mask from the current batch. Since each batch contains multiple images and masks, you are accessing the first one with images[0] and masks[0].

      • images[0]: This extracts the first image from the batch.

      • masks[0]: This extracts the first mask from the batch.

  • display_sample([sample_image, sample_mask]):

    • This function displays the selected image and its corresponding mask.

    • display_sample([sample_image, sample_mask]): Passes the sample_image and sample_mask as a list to the display_sample() function. Since the function expects an image list, this allows you to display the input image and its true mask side by side for visual inspection.


Define U-Net Model


Here we will define three key building blocks used to create a U-Net architecture for image segmentation: a double convolution block, a downsampling block, and an upsampling block. U-Net is commonly used for tasks like biomedical image segmentation, and these blocks are the core components of the model.

def double_conv_block(x, n_filters):
    x = layers.Conv2D(n_filters, 3, padding='same', activation='relu', kernel_initializer='he_normal')(x)
    x = layers.Conv2D(n_filters, 3, padding='same', activation='relu', kernel_initializer='he_normal')(x)
    return x

def downsample_block(x, n_filters):
    f = double_conv_block(x, n_filters)
    p = layers.MaxPool2D(2)(f)
    p = layers.Dropout(0.3)(p)
    return f, p

def upsample_block(x, conv_features, n_filters):
    x = layers.Conv2DTranspose(n_filters, 3, 2, padding='same')(x)
    x = layers.concatenate([x, conv_features])
    x = layers.Dropout(0.3)(x)
    x = double_conv_block(x, n_filters)
    return x
  • double_conv_block(x, n_filters):

    This block applies two consecutive convolutional layers to the input tensor x. It's a typical pattern in U-Net, where each convolution is followed by a ReLU activation function to introduce non-linearity.

    • First Convolution:

      • Applies a 2D convolution with n_filters filters, a kernel size of 3x3, and padding set to 'same' to keep the spatial dimensions unchanged.

      • The activation function used is ReLU, which adds non-linearity.

      • kernel_initializer='he_normal': Initializes the weights of the convolution using the He Normal initializer, which is well-suited for layers with ReLU activation.

    • Second Convolution:

      • Another 3x3 convolution is applied with the same settings as the first one (number of filters, activation, padding, and initialization).

    • The output is the result of two convolution operations, which helps the network learn more complex features.

  • downsample_block(x, n_filters):

    This block is part of the encoder (contracting path) of the U-Net. It applies the double convolution block, then downsamples the feature maps using max pooling. This process reduces the spatial dimensions, capturing coarse features while preserving the learned representations.

    • Feature extraction:

      • Calls the double_conv_block to apply two consecutive convolutions, producing a feature map f.

    • Downsampling:

      • layers.MaxPool2D(2): Applies max pooling with a pool size of 2x2 to reduce the spatial dimensions by half, allowing the network to learn hierarchical features at different scales.

    • Dropout:

      • layers.Dropout(0.3): Applies a dropout of 30% to prevent overfitting. It randomly drops some of the neurons in this layer during training.

  • upsample_block(x, conv_features, n_filters):

    This block is part of the decoder (expansive path) of the U-Net. It upsamples the input to restore the original spatial dimensions and concatenates it with the feature map from the corresponding downsampling block (using skip connections). This combination of upsampling and skip connections allows the network to recover spatial information lost during downsampling.

    • Upsampling:

      • layers.Conv2DTranspose(n_filters, 3, 2, padding='same'): A transposed convolution (or "deconvolution") upsamples the input by a factor of 2. This increases the spatial resolution of the feature maps.

    • Skip Connection:

      • layers.concatenate([x, conv_features]): The upsampled feature map x is concatenated with the corresponding feature map from the downsampling path (conv_features). This skip connection helps retain spatial information and fine details that were lost during downsampling.

    • Dropout:

      • layers.Dropout(0.3): A 30% dropout is applied to prevent overfitting.

    • Double Convolution:

      • The concatenated feature map is passed through the double_conv_block to further process the combined features.

    • The output is the result of the upsampling, concatenation, and convolution operations, which helps restore the original spatial resolution while preserving important features from the encoder.
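
To make the shape changes concrete, here is a small illustrative check (a sketch with a dummy input, not part of the final model) that traces one downsample block and one matching upsample block:

# trace tensor shapes through one downsample/upsample pair (illustrative only)
dummy = layers.Input(shape=(128, 128, 3))
f, p = downsample_block(dummy, 64)  # f: (None, 128, 128, 64), p: (None, 64, 64, 64)
u = upsample_block(p, f, 64)        # back to (None, 128, 128, 64) after concatenation and convolutions
print(f.shape, p.shape, u.shape)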


Next we will construct a U-Net model for image segmentation tasks. U-Net is a fully convolutional network designed to predict a pixel-wise classification (segmentation) map for images. In this case, the function uses the encoder-decoder structure with skip connections, making it highly effective in tasks that require preserving both local (high-level) and global (low-level) features.

def build_unet_model(output_channels):
    # input layer
    inputs = layers.Input(shape=(128, 128, 3))

    # encoder - downsample
    f1, p1 = downsample_block(inputs, 64)
    f2, p2 = downsample_block(p1, 128)
    f3, p3 = downsample_block(p2, 256)
    f4, p4 = downsample_block(p3, 512)

    # intermediate block
    intermediate_block = double_conv_block(p4, 1024)

    # decoder - upsample
    u6 = upsample_block(intermediate_block, f4, 512)
    u7 = upsample_block(u6, f3, 256)
    u8 = upsample_block(u7, f2, 128)
    u9 = upsample_block(u8, f1, 64)

    # output layer
    outputs = layers.Conv2D(output_channels, 1, padding='same', activation='softmax')(u9)

    # unet model
    unet_model = tf.keras.Model(inputs, outputs, name='U-Net')

    return unet_model
  • Input Layer:

    • inputs: This defines the input layer of the model. It expects images of size 128x128 with 3 channels (RGB). The shape (128, 128, 3) indicates the height, width, and number of channels, respectively.

    • This is where the input images are fed into the network.

  • Encoder (Downsampling Path): The encoder uses the downsample_block to reduce the spatial dimensions and extract hierarchical features at different scales. Each block includes two convolutional layers, followed by max pooling and dropout.

    • Downsampling blocks progressively increase the number of filters (n_filters), which allows the model to capture more complex features.

      • f1, f2, f3, f4: Feature maps that retain important spatial information and are passed to the decoder for skip connections.

      • p1, p2, p3, p4: Pooled feature maps passed to the next layer for downsampling.

    • The number of filters in the convolution layers increases as the spatial dimensions are reduced. This helps the model capture high-level semantic information while reducing spatial resolution.

  • Bottleneck (Intermediate Block): The bottleneck is the part of the U-Net where the model reaches the smallest spatial dimensions but captures the most abstracted features.

    • Intermediate block: Two convolutional layers with 1024 filters are applied to the downsampled feature map p4.

    • This is the deepest point in the U-Net, where the network learns highly abstracted features.

  • Decoder (Upsampling Path): The decoder uses the upsample_block to increase the spatial dimensions and combine coarse, high-level features with fine details from the corresponding encoder blocks (skip connections).

    • Upsampling blocks progressively reduce the number of filters while increasing the spatial resolution.

    • Skip connections: The feature maps (f1, f2, f3, f4) from the corresponding encoder layers are concatenated with the upsampled feature maps, allowing the model to recover spatial details lost during downsampling.

      • For example, in u6 = upsample_block(intermediate_block, f4, 512), the upsampled feature map from the bottleneck is concatenated with the corresponding encoder feature map f4 (from the deepest layer of the encoder).

      • This helps the network to refine its predictions by combining high-level information from the decoder with spatial details from the encoder.

  •  Output Layer:

    • Conv2D layer: This is the final convolutional layer with a kernel size of 1x1, which reduces the number of channels in the output to output_channels.

      • output_channels: Specifies the number of classes for segmentation. For example, if you are performing binary segmentation, this will be 2 (background vs. object). For multi-class segmentation, output_channels would be the number of classes.

      • padding='same' ensures that the spatial dimensions remain the same after the convolution.

      • activation='softmax': The softmax activation function is used to produce a probability distribution for each pixel, assigning each pixel to a class.

  • Model Construction:

    • tf.keras.Model(inputs, outputs): Combines the input and output layers to define the U-Net model.

    • name='U-Net': Gives the model a name ("U-Net"), which helps in identifying the model when saving or loading it.


Next we will compile the U-Net model that you built earlier for a multi-class image segmentation task.

output_channels = 3
model = build_unet_model(output_channels)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
  • output_channels = 3: This indicates that the model will output a segmentation mask with 3 classes. Each pixel in the input image will be classified into one of three categories (pet, background, or the border region around the pet).

  • You are calling the build_unet_model function to create the U-Net architecture, passing in output_channels = 3.

  • The model will use 1x1 convolution in the final layer with 3 output channels and the softmax activation function to assign each pixel a class.

  • 'adam': Adam (Adaptive Moment Estimation) is a popular optimization algorithm that adjusts the learning rate throughout training. It combines the advantages of two other optimizers: AdaGrad and RMSProp.

  • Adam is well-suited for tasks like image segmentation because it typically converges faster and requires less hyperparameter tuning.

  • sparse_categorical_crossentropy: This loss function is used for multi-class classification problems where the target class labels are integers (e.g., 0, 1, 2) rather than one-hot encoded vectors.

    • It’s called "sparse" because you pass the integer class labels directly instead of one-hot encoding them. Each pixel in the ground truth segmentation mask will have a class label (0, 1, or 2 in this case).

    • This is appropriate for multi-class segmentation where the goal is to predict one class for each pixel.

  • accuracy: This metric calculates the pixel-wise accuracy during training and evaluation. It measures how many pixels are correctly classified out of the total number of pixels.
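
The string shortcuts passed to compile() map to Keras objects; an equivalent, more explicit form (a sketch using the default settings) would be:

# equivalent compile call with explicit Keras objects instead of string shortcuts
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),  # Adam's default learning rate
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),    # integer (sparse) pixel labels, softmax outputs
    metrics=['accuracy'])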


Next we will generate a plot of the U-Net model architecture, showing the layers and their connections, as well as the shapes of the tensors that flow between the layers.

# plot the model
tf.keras.utils.plot_model(model, show_shapes=True, expand_nested=False, dpi=64)
U-Net Model
  • model: This is the U-Net model you built earlier. It's being passed into plot_model to visualize its structure.

  • show_shapes=True:

    • This ensures that the shapes of the inputs and outputs of each layer will be displayed in the plot. It's useful for understanding how the dimensions change across layers, especially in U-Net where you have downsampling and upsampling.

  • expand_nested=False:

    • This argument controls whether nested models or layers (like Sequential or functional API layers inside other layers) should be expanded in the plot. Setting this to False keeps the model plot simpler.

  • dpi=64:

    • This sets the resolution of the plot to 64 dots per inch (DPI). A lower DPI value means a lower-resolution image, which is suitable for quick overviews. You can increase this value for higher resolution.


Train the Model


Next we will initiate the training process for your U-Net model.

EPOCHS = 20
steps_per_epoch = info.splits['train'].num_examples // BATCH_SIZE
validation_steps = info.splits['test'].num_examples // BATCH_SIZE

history = model.fit(train_dataset, epochs=EPOCHS, steps_per_epoch=steps_per_epoch, validation_steps=validation_steps, validation_data=test_dataset)
Logs generated during model training
  • EPOCHS: This specifies the number of times the model will iterate through the entire training dataset. In this case, the model will train for 20 epochs, which means the entire dataset will be passed through the model 20 times.

  • steps_per_epoch: This calculates how many batches will be processed in one epoch during training. It is determined by dividing the total number of training examples (info.splits['train'].num_examples) by the batch size (BATCH_SIZE). This ensures that all training samples are processed once in each epoch.

    • For example, if the training dataset contains 1000 images and the batch size is 32, steps_per_epoch would be 1000 // 32 = 31 steps.

  • validation_steps: Similar to steps_per_epoch, this calculates how many batches will be processed for validation in each epoch.

  • train_dataset: This is your training dataset, which contains the preprocessed images and corresponding masks. It has been mapped using the load_train_images() function.

  • epochs=EPOCHS: Specifies that the model will train for 20 epochs.

  • steps_per_epoch=steps_per_epoch: This controls how many batches will be processed in each epoch, ensuring all training samples are covered.

  • validation_data=test_dataset: This is your test dataset, containing images and masks that the model will use to evaluate its performance after each epoch.

  • validation_steps=validation_steps: This controls how many batches will be processed for validation after each epoch.


Visualize the Results


Next we will plot the training and validation accuracy as well as the training and validation loss over the course of training.

# plot train & val accuracy
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend(['Train', 'Val'], loc='upper left')

# plot train & val loss
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(['Train', 'Val'], loc='upper left')

plt.tight_layout()
plt.show()
Training vs Validation accuracy plot
  • plt.figure(figsize=(12, 4)): This creates a figure with a width of 12 inches and a height of 4 inches, giving enough space to display two subplots side by side (one for accuracy and one for loss).

  •  Accuracy Plot:

    • plt.subplot(1, 2, 1): This specifies that the next plot will be the first of two subplots (1 row, 2 columns, and this is the first plot).

    • history.history['accuracy']: This contains the training accuracy for each epoch.

    • history.history['val_accuracy']: This contains the validation accuracy for each epoch.

  • Loss Plot:

    • plt.subplot(1, 2, 2): This sets up the second subplot for the loss plot (the second of two plots in the same row).

    • history.history['loss']: This contains the training loss for each epoch.

    • history.history['val_loss']: This contains the validation loss for each epoch.

  • plt.tight_layout(): This adjusts the spacing between subplots to prevent overlap, ensuring the plot looks clean and well-organized.

  • plt.show(): This displays the plot.


Test Predictions


Next we will generate and display predictions from our trained U-Net model, comparing them to the ground truth masks.

def create_mask(pred_mask):
    pred_mask = tf.argmax(pred_mask, axis=-1)
    pred_mask = pred_mask[..., tf.newaxis]
    return pred_mask[0]

def show_predictions(dataset=None, num=1):
    if dataset:
        for image, mask in dataset.take(num):
            pred_mask = model.predict(image)
            display_sample([image[0], mask[0], create_mask(pred_mask)])
  • create_mask(pred_mask): This function converts the predicted mask (which is a multi-class probability distribution) into a segmentation mask with class labels.

    • tf.argmax(pred_mask, axis=-1):

      • The model's output is a probability distribution across all classes for each pixel. tf.argmax finds the class with the highest probability for each pixel.

      • axis=-1: Refers to the last dimension (the class probabilities), where the model has output multiple channels (e.g., 3 channels for 3 classes). This function returns the index of the class with the highest probability, effectively turning the probabilities into a single predicted class label for each pixel.

    • pred_mask[..., tf.newaxis]:

      • Adds a new axis to the prediction mask to ensure it has the same dimensionality as the input image (height, width, 1).

    • return pred_mask[0]:

      • Returns the mask for the first image in the batch (as model predictions are typically batched). This allows you to handle a single image at a time.

  • show_predictions(dataset=None, num=1): This function shows the model's predictions on a given dataset by displaying the input image, the ground truth mask, and the predicted mask.
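
As written, the function does nothing when no dataset is passed. If you also want to spot-check the model on the sample_image and sample_mask cached during exploration, a variant with an else branch could be used (an assumed extension, not part of the original code):

# optional variant: fall back to the cached sample_image / sample_mask when no dataset is given
def show_predictions(dataset=None, num=1):
    if dataset:
        for image, mask in dataset.take(num):
            pred_mask = model.predict(image)
            display_sample([image[0], mask[0], create_mask(pred_mask)])
    else:
        # add a batch dimension before predicting on the single cached image
        pred_mask = model.predict(sample_image[tf.newaxis, ...])
        display_sample([sample_image, sample_mask, create_mask(pred_mask)])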


Next we will call the show_predictions() function.

show_predictions(test_dataset, 10)
  • Inside show_predictions():

    • For each image-mask pair in the test dataset:

      • The model generates a predicted mask using model.predict(image).

      • The predicted mask is processed by create_mask() to convert the probabilities into class labels.

      • The input image, true mask, and predicted mask are displayed side by side using display_sample().


Final Thoughts

  • We explored the powerful U-Net architecture for image segmentation using TensorFlow. We began by understanding the fundamentals of U-Net, which is specifically designed to handle pixel-level segmentation tasks, making it highly effective for applications like medical imaging, satellite imagery analysis, and more.

  • We walked through the key steps of building a U-Net model from scratch, including data preprocessing, creating a multi-step encoder-decoder architecture, and applying augmentation to enhance the model's robustness.

  • We also demonstrated how to train the model on the Oxford-IIIT Pet Dataset and visualize the performance by comparing predicted segmentation masks with ground truth masks.


By leveraging TensorFlow’s deep learning capabilities and U-Net's specialized structure, this approach provides an efficient and scalable solution to a variety of image segmentation tasks. As you continue to refine the model, consider tuning hyperparameters, adding more advanced augmentations, or experimenting with more complex datasets for even better results.


Get the project notebook from here


Thanks for reading the article!!!


Check out more project videos from the YouTube channel Hackers Realm
