DEV Community

Mostafa Gazar


Train a line segmentation model using PyTorch

Let us start by identifying the problem we want to solve, which is inspired by this project.

Given an image containing lines of text, the model should return a pixelwise labeling of that image, in which each pixel belongs either to the background or to a line of handwriting.
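To make the expected output concrete, here is a toy sketch with made-up sizes and values. The label map has the same height and width as the input image, with 0 for background pixels and 1 for pixels inside a line of handwriting.

```python
import numpy as np

# Toy 6x10 grayscale "page" (made-up values): white background with one dark band of ink.
image = np.full((6, 10), 255, dtype=np.uint8)
image[2:4, 1:9] = 40

# Target: exactly one class label per pixel, same height and width as the image.
BACKGROUND, LINE = 0, 1
labels = np.full(image.shape, BACKGROUND, dtype=np.int64)
labels[2:4, 1:9] = LINE  # pixels covered by the line region

print(labels)
```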


The project structure

It consists of five main directories: one for notebooks, one for shared Python code, one for datasets, one for Google Cloud scripts, and one for saved model weights.

In a production project, you would probably have more directories, such as web and api.

I also chose pipenv over conda and virtualenv to manage my Python environment. I only recently switched from conda to pipenv, and I have found that it consistently works as expected everywhere.

For GPU training, I used a Google Cloud instance with one NVIDIA T4 GPU. Bash scripts manage the instance lifecycle: creating it, starting it, connecting to it, and stopping it.
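The lifecycle scripts themselves are written in Bash and are not shown. As a rough Python sketch of the same four steps, the hypothetical helpers below only build gcloud command lines. The instance name, zone and machine type are placeholders (check T4 availability in your zone), and a real GPU instance also needs a boot image with NVIDIA drivers, which this sketch leaves out.

```python
import shlex

# Hypothetical placeholders; adjust to your own project.
INSTANCE = "lines-seg-gpu"
ZONE = "us-central1-a"


def create_cmd(instance: str = INSTANCE, zone: str = ZONE) -> list[str]:
    """Command to create an N1 instance with one T4 attached."""
    return [
        "gcloud", "compute", "instances", "create", instance,
        f"--zone={zone}",
        "--machine-type=n1-standard-4",
        "--accelerator=type=nvidia-tesla-t4,count=1",
        # GPU instances cannot live-migrate, so host maintenance must terminate them.
        "--maintenance-policy=TERMINATE",
    ]


def lifecycle_cmd(action: str, instance: str = INSTANCE, zone: str = ZONE) -> list[str]:
    """Command to start, stop, or SSH into an existing instance."""
    if action == "ssh":
        return ["gcloud", "compute", "ssh", instance, f"--zone={zone}"]
    if action in ("start", "stop"):
        return ["gcloud", "compute", "instances", action, instance, f"--zone={zone}"]
    raise ValueError(f"unknown action: {action}")


# Dry run: print each command instead of executing it.
for cmd in (create_cmd(), lifecycle_cmd("start"), lifecycle_cmd("ssh"), lifecycle_cmd("stop")):
    print(shlex.join(cmd))
```

Passing one of these lists to subprocess.run(cmd, check=True) would execute it; printing first keeps the sketch a dry run.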

Data

The dataset is described in a TOML file inside the raw directory; a TOML file is essentially a set of key/value pairs. The other directories under data are git-ignored because they will contain the full downloaded datasets.

Notebooks

I use notebooks for exploration and as a high-level container for the code required to construct and clean datasets and to build a basic training pipeline.

Python files

Under the src directory, I keep the code that can be shared and reused across notebooks. Following good software engineering practices is key to getting things done quickly and correctly, because finding and identifying bugs in ML code can be extremely hard. That is why you want to start small and iterate often.
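As a small example of that approach, the mask thresholding that FormsDataset performs inline in the training pipeline could live in src as a function with its own pytest-style test. binarize_mask and its test are hypothetical names, not part of the project:

```python
import numpy as np


def binarize_mask(mask: np.ndarray, threshold: float = 0.7) -> np.ndarray:
    """Scale an 8-bit mask to [0, 1] and snap every pixel to 0.0 or 1.0."""
    scaled = mask.astype(np.float32) / 255
    return (scaled > threshold).astype(np.float32)


def test_binarize_mask():
    # 178/255 is just below 0.7 and 179/255 is just above it.
    mask = np.array([[0, 178, 179, 255]], dtype=np.uint8)
    out = binarize_mask(mask)
    assert out.dtype == np.float32
    assert out.tolist() == [[0.0, 0.0, 1.0, 1.0]]


test_binarize_mask()
```

A tiny test like this is cheap to run after every change and catches off-by-one threshold or dtype mistakes long before they show up as a strange loss curve.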

The Python environment

You can install pipenv on Linux or macOS with Homebrew (Linuxbrew on Linux) using the following command:

brew install pipenv

Then you can install your dependencies by running pipenv install PACKAGE_NAME from your project directory.


The dataset

I will use this old academic dataset as a base to build a line segmentation dataset, which I then use to train a mini UNet network to detect lines of handwriting.

The original images in the dataset come with XML files that define the bounding boxes.
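The XML schema is not described in the post, so this is a sketch under an assumption: an IAM-style layout in which each line element contains word elements made of cmp components with x, y, width and height attributes. The sample XML and its coordinates are made up. Each line's bounding box is taken as the union of its component boxes; line_boxes is a hypothetical helper.

```python
import xml.etree.ElementTree as ET

# Made-up sample in an assumed IAM-style layout.
SAMPLE_XML = """
<form id="a01-000">
  <handwritten-part>
    <line id="a01-000-00">
      <word id="a01-000-00-00">
        <cmp x="408" y="768" width="27" height="51"/>
        <cmp x="507" y="766" width="213" height="48"/>
      </word>
    </line>
    <line id="a01-000-01">
      <word id="a01-000-01-00">
        <cmp x="420" y="900" width="120" height="60"/>
      </word>
    </line>
  </handwritten-part>
</form>
"""


def line_boxes(xml_text: str) -> list[tuple[int, int, int, int]]:
    """Return one (x1, y1, x2, y2) box per <line>: the union of its <cmp> boxes."""
    root = ET.fromstring(xml_text)
    boxes = []
    for line in root.iter("line"):
        cmps = list(line.iter("cmp"))
        if not cmps:
            continue  # skip lines without any components
        x1 = min(int(c.get("x")) for c in cmps)
        y1 = min(int(c.get("y")) for c in cmps)
        x2 = max(int(c.get("x")) + int(c.get("width")) for c in cmps)
        y2 = max(int(c.get("y")) + int(c.get("height")) for c in cmps)
        boxes.append((x1, y1, x2, y2))
    return boxes


print(line_boxes(SAMPLE_XML))
```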

In notebooks/01-explore-iam-dataset.ipynb, I downloaded the dataset, unzipped it, and then overlaid the bounding boxes from the XML files on some random images.
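The overlay step can be sketched with plain NumPy, assuming each box is an (x1, y1, x2, y2) tuple in pixel coordinates. overlay_boxes is a hypothetical helper; the notebook may well use a plotting library instead.

```python
import numpy as np


def overlay_boxes(gray: np.ndarray, boxes, color=(255, 0, 0), thickness: int = 2) -> np.ndarray:
    """Return an RGB copy of a grayscale image with box outlines drawn on it."""
    rgb = np.stack([gray] * 3, axis=-1).astype(np.uint8)
    h, w = gray.shape
    for x1, y1, x2, y2 in boxes:
        # Clip to the image so boxes near the border stay in range.
        x1, x2 = max(0, x1), min(w, x2)
        y1, y2 = max(0, y1), min(h, y2)
        rgb[y1:y1 + thickness, x1:x2] = color                  # top edge
        rgb[max(y1, y2 - thickness):y2, x1:x2] = color         # bottom edge
        rgb[y1:y2, x1:x1 + thickness] = color                  # left edge
        rgb[y1:y2, max(x1, x2 - thickness):x2] = color         # right edge
    return rgb
```

The returned array can be displayed with matplotlib's plt.imshow; the input image itself is left untouched.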

Next, I cropped the images and generated mask images that match the new dimensions. The mask images are the ground truth that we will use to train the final model.
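The exact cropping rule is not spelled out, so this is one plausible sketch with a hypothetical crop_and_mask helper and an arbitrary margin: crop to the area that covers all line boxes plus a margin, shift the boxes into the cropped coordinate frame, and paint them white (255) on a black mask. The 0/255 values match the mask / 255 scaling in the dataset class.

```python
import numpy as np


def crop_and_mask(gray: np.ndarray, boxes, margin: int = 10):
    """Crop to the region covering all line boxes and build a matching 0/255 mask.

    boxes are (x1, y1, x2, y2) line boxes in the original image's coordinates.
    """
    h, w = gray.shape
    left = max(0, min(b[0] for b in boxes) - margin)
    top = max(0, min(b[1] for b in boxes) - margin)
    right = min(w, max(b[2] for b in boxes) + margin)
    bottom = min(h, max(b[3] for b in boxes) + margin)

    cropped = gray[top:bottom, left:right]
    mask = np.zeros_like(cropped, dtype=np.uint8)
    for x1, y1, x2, y2 in boxes:
        # Shift each box into the cropped frame (clamped at 0) before filling it.
        mask[max(0, y1 - top):y2 - top, max(0, x1 - left):x2 - left] = 255
    return cropped, mask
```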

Finally, I split the data into training, validation, and test sets.
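The split ratios are not given in the post. A minimal sketch, assuming a hypothetical 80/10/10 split by image index and a fixed seed so the split is reproducible:

```python
import numpy as np


def split_indices(n: int, valid_frac: float = 0.1, test_frac: float = 0.1, seed: int = 42):
    """Shuffle indices 0..n-1 once and cut them into train/valid/test index arrays."""
    rng = np.random.default_rng(seed)  # fixed seed -> same split on every run
    idx = rng.permutation(n)
    n_test = int(round(n * test_frac))
    n_valid = int(round(n * valid_frac))
    test = idx[:n_test]
    valid = idx[n_test:n_test + n_valid]
    train = idx[n_test + n_valid:]
    return train, valid, test
```

With NumPy arrays, train_images = images[train] and train_masks = masks[train] then select matching image/mask pairs, because the same index array is used for both.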


The Network

Because we do not have a lot of data available for training, I used a mini version of the UNet architecture based on this Keras implementation.
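The layer-by-layer details of the network are not reproduced in the post. As a rough sketch of what a mini UNet can look like in PyTorch (two downsampling levels and a base of 16 channels, both arbitrary choices), the hypothetical MiniUNet below maps a one-channel image to per-pixel class logits at the same resolution:

```python
import torch
import torch.nn as nn


def conv_block(in_ch: int, out_ch: int) -> nn.Sequential:
    """Two 3x3 convolutions with ReLU; padding=1 keeps the spatial size."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
    )


class MiniUNet(nn.Module):
    """A two-level UNet; input height and width must be divisible by 4."""

    def __init__(self, in_channels: int = 1, num_classes: int = 2, base: int = 16):
        super().__init__()
        self.pool = nn.MaxPool2d(2)
        self.enc1 = conv_block(in_channels, base)
        self.enc2 = conv_block(base, base * 2)
        self.bottleneck = conv_block(base * 2, base * 4)
        self.up2 = nn.ConvTranspose2d(base * 4, base * 2, kernel_size=2, stride=2)
        self.dec2 = conv_block(base * 4, base * 2)  # skip concat doubles the channels
        self.up1 = nn.ConvTranspose2d(base * 2, base, kernel_size=2, stride=2)
        self.dec1 = conv_block(base * 2, base)
        self.head = nn.Conv2d(base, num_classes, kernel_size=1)  # per-pixel class scores

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        e1 = self.enc1(x)                    # (N, base,   H,   W)
        e2 = self.enc2(self.pool(e1))        # (N, 2*base, H/2, W/2)
        b = self.bottleneck(self.pool(e2))   # (N, 4*base, H/4, W/4)
        d2 = self.dec2(torch.cat([self.up2(b), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        return self.head(d1)                 # raw logits, (N, num_classes, H, W)
```

Returning raw logits fits the training loop later in the post, which applies log_softmax itself.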

Using this great library, I can visualize the network by running a forward pass with a specific input size.
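The library's name was lost along with its link, so the sketch below does not reproduce it. It shows the same idea with PyTorch's own register_forward_hook: run one forward pass on zeros of a chosen input size and record each leaf layer's output shape. print_layer_shapes is a hypothetical helper, and the tiny model is only there to make the example self-contained.

```python
import torch
import torch.nn as nn


def print_layer_shapes(model: nn.Module, input_size: tuple[int, ...]) -> list[tuple[str, tuple[int, ...]]]:
    """Run one forward pass on zeros and record each leaf module's output shape."""
    records, handles = [], []
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf layers only
            def hook(mod, inputs, output, name=name):
                records.append((name, tuple(output.shape)))
            handles.append(module.register_forward_hook(hook))
    model.eval()
    with torch.no_grad():
        model(torch.zeros(*input_size))
    for handle in handles:
        handle.remove()  # leave the model without hooks afterwards
    for layer_name, shape in records:
        print(f"{layer_name:<20} {shape}")
    return records


tiny = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(8, 2, 1))
print_layer_shapes(tiny, (1, 1, 32, 64))
```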


The Training Pipeline

Now that we have the data ready and the network that we want to train defined, it is time to build a basic training pipeline.

The first step is to define a PyTorch Dataset and iterate through it using a DataLoader:

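import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms  # used to build get_transformations() (not shown)


class FormsDataset(Dataset):
    """Pairs grayscale form images with binary line masks."""

    def __init__(self, images, masks, num_classes: int, transforms=None):
        self.images = images
        self.masks = masks
        self.num_classes = num_classes
        self.transforms = transforms

    def __getitem__(self, idx):
        # Image: uint8 (H, W) -> float32 (H, W, 1) scaled to [0, 1]
        image = self.images[idx]
        image = image.astype(np.float32)
        image = np.expand_dims(image, -1)
        image = image / 255
        if self.transforms:
            image = self.transforms(image)

        # Mask: scale to [0, 1], then snap each pixel to 0 (background) or 1 (line)
        mask = self.masks[idx]
        mask = mask.astype(np.float32)
        mask = mask / 255
        mask[mask > .7] = 1
        mask[mask <= .7] = 0
        # The transforms are applied to the mask in a separate call, so they must be
        # deterministic (e.g. ToTensor only); random augmentations would be drawn
        # independently for image and mask and misalign them.
        if self.transforms:
            mask = self.transforms(mask)

        return image, mask

    def __len__(self):
        return len(self.images)


# train_images, train_masks, number_of_classes, batch_size and get_transformations()
# are defined elsewhere in the notebook (not shown).
train_dataset = FormsDataset(train_images, train_masks, number_of_classes, get_transformations(True))
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
print(f'Train dataset has {len(train_data_loader)} batches of size {batch_size}')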
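Because FormsDataset applies self.transforms to the image and the mask in two separate calls, random augmentations there would be drawn independently and misalign the pair. One way around this, sketched here as a hypothetical helper, is to flip both arrays together in NumPy before the deterministic transforms run:

```python
import numpy as np


def paired_random_hflip(image: np.ndarray, mask: np.ndarray, p: float = 0.5, rng=None):
    """Flip image and mask left-right together so they stay pixel-aligned.

    Both arrays are assumed to be laid out as (H, W) or (H, W, C).
    """
    rng = rng if rng is not None else np.random.default_rng()
    if rng.random() < p:
        # Axis 1 is the width axis in both layouts. The copy matters because
        # torch.from_numpy rejects the negative strides a [:, ::-1] view has.
        image = np.ascontiguousarray(image[:, ::-1])
        mask = np.ascontiguousarray(mask[:, ::-1])
    return image, mask
```

In __getitem__, it would be called for the training set only, after the mask is thresholded and before self.transforms.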

Next, I define the training loop:

import torch
import torch.nn.functional as F

# model, learning_rate, epochs and train_data_loader come from earlier cells.
# Use gpu for training if available else use cpu
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = model.to(device)  # the model must live on the same device as its inputs
model.train()

# Loss and optimizer: NLLLoss on log-softmax outputs (same as CrossEntropyLoss on raw logits)
criterion = torch.nn.NLLLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# The training loop
total_steps = len(train_data_loader)
print(f"{epochs} epochs, {total_steps} total_steps per epoch")

for epoch in range(epochs):
    for i, (images, masks) in enumerate(train_data_loader, 1):
        images = images.to(device)
        # (N, 1, H, W) float masks -> (N, H, W) integer class indices, as NLLLoss expects
        masks = masks.squeeze(1).long().to(device)

        # Forward pass
        outputs = model(images)
        log_probs = F.log_softmax(outputs, dim=1)
        loss = criterion(log_probs, masks)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print(f"Epoch [{epoch + 1}/{epochs}], Step [{i}/{total_steps}], Loss: {loss.item():.4f}")
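The training loop only tracks the training loss. A hedged sketch of a matching validation pass (hypothetical evaluate helper, assuming a validation loader built the same way as train_data_loader) computes the mean loss and the intersection over union (IoU) of the line class, then puts the model back into training mode:

```python
import torch
import torch.nn.functional as F


def evaluate(model, data_loader, device):
    """Return (mean NLL loss, IoU of the line class) over a validation loader.

    Assumes the loader yields (images, masks) with masks shaped (N, 1, H, W) holding 0/1.
    """
    model.eval()
    total_loss, batches = 0.0, 0
    intersection, union = 0, 0
    with torch.no_grad():
        for images, masks in data_loader:
            images = images.to(device)
            masks = masks.squeeze(1).long().to(device)
            log_probs = F.log_softmax(model(images), dim=1)
            total_loss += F.nll_loss(log_probs, masks).item()
            batches += 1
            preds = log_probs.argmax(dim=1)  # (N, H, W) predicted class per pixel
            intersection += ((preds == 1) & (masks == 1)).sum().item()
            union += ((preds == 1) | (masks == 1)).sum().item()
    model.train()
    iou = intersection / union if union else 1.0
    return total_loss / max(batches, 1), iou
```

IoU is often more informative than pixel accuracy here, because background pixels dominate and a model that predicts background everywhere would still score a high accuracy.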

Here are the final predictions.
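To produce such predictions for a single image, a minimal sketch (hypothetical predict_mask helper) repeats the dataset's preprocessing, takes the argmax over the class dimension, and returns a 0/255 mask that can be displayed next to the input:

```python
import numpy as np
import torch


def predict_mask(model, image: np.ndarray, device=None) -> np.ndarray:
    """Predict a 0/255 line mask for one grayscale uint8 image of shape (H, W).

    Preprocessing mirrors FormsDataset: scale to [0, 1] and add channel/batch axes.
    """
    if device is None:
        device = next(model.parameters()).device
    x = torch.from_numpy(image.astype(np.float32) / 255)[None, None].to(device)  # (1, 1, H, W)
    model.eval()
    with torch.no_grad():
        labels = model(x).argmax(dim=1)[0]  # (H, W): 0 = background, 1 = line
    return (labels.cpu().numpy() * 255).astype(np.uint8)
```

For a UNet with pooling layers, the image height and width may need to be padded or cropped to a multiple of the total downsampling factor first.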


You can also check out a Keras implementation backed by TensorFlow 2 here.


Thanks for making it this far. One last thing I would like to say: unfortunately, most of the available online material either offers bad advice or is so basic that it does not add much value, and some of it is plain wrong. There are some great resources, though, such as PyTorch's own 60-minute blitz series and its excellent API docs. There is also this cheat sheet and this great GitHub repo.


If you enjoyed reading this post and found it helpful, I would love to hear from you. My Twitter DMs are open.
