Demystifying PyTorch DataSets: Building a Chinese Character Dataset

Introduction

DataSets in PyTorch serve as a crucial bridge between raw data and the machine learning model during training. They encapsulate the logic needed to access, transform, and even augment the data, making the model’s training loop cleaner and more manageable. In this tutorial, we’ll shed light on what PyTorch DataSets are, how you can create one from the ground up, and how it can be seamlessly incorporated into your model’s training process.

To make this journey engaging and practical, we’ll embark on a unique project: constructing a DataSet that renders Chinese characters. The choice of Chinese characters is strategic: these characters encompass both simple and intricate features, mirroring the complexity often found in real-world data. Furthermore, the vast number of unique Chinese characters, compared to Latin letters, presents a greater challenge for classification models, making this an excellent toy problem. This simple yet powerful example will illustrate the versatility of PyTorch DataSets and showcase their potential to handle even complex data types.

What is a DataSet in PyTorch?

A DataSet in PyTorch is an abstraction that represents a collection of data. It provides a standardized way to load, preprocess, and access the data in a unified manner, regardless of the type or structure of the original data source. This can include anything from images and text files to CSVs and databases. PyTorch Datasets are designed around two main methods: __getitem__ and __len__. The __len__ method returns the number of items in the dataset, while the __getitem__ method allows indexed access to individual data items, returning both input data and the corresponding target. This design allows the PyTorch DataLoader to efficiently and conveniently fetch batches of data, facilitating the implementation of large-scale machine learning models that learn from examples in an iterative manner.
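
To make this contract concrete, here is a minimal sketch of a map-style dataset; the class and its data are invented purely for illustration:

import torch
from torch.utils.data import Dataset

class SquaresDataset(Dataset):
    """A toy map-style dataset: each index maps to (x, x**2)."""

    def __len__(self):
        # Report the number of items in the dataset
        return 10

    def __getitem__(self, idx):
        # Return an (input, target) pair for the requested index
        x = torch.tensor(float(idx))
        return x, x ** 2

Anything that implements these two methods can be handed directly to a DataLoader for batching and shuffling.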

Building the Chinese Character DataSet Class

In this tutorial, we aim to construct a DataSet that generates random Chinese characters on-demand. To bring our Chinese Character DataSet to life, we’ll break down the process into several manageable steps. First, we’ll identify the Unicode range corresponding to Chinese characters. Next, we’ll calculate the appropriate font size to ensure our character fits within a specified image size. We’ll then render each character onto an image canvas using the Python Imaging Library (PIL). Following this, we’ll convert our image into a PyTorch tensor, a format suitable for machine learning models. Finally, to add a layer of realism and improve the robustness of our training, we’ll post-process our tensor by introducing random distortions, noise, and normalization.

Import statements

To bring our Chinese Character DataSet to life, we’ll be harnessing the power of several Python libraries. Let’s start by importing the necessary modules:

import torch
import unicodedata
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor, RandomAffine, GaussianBlur
  • torch: This is the PyTorch library, which we’ll use to convert our images into tensors – the standard data format for deep learning models.
  • unicodedata: This will help us identify Chinese characters within the Unicode range.
  • matplotlib: We’ll use this popular plotting library to visualize our images.
  • PIL: The Python Imaging Library will be instrumental in rendering the characters onto an image canvas.
  • torchvision: Finally, torchvision’s transformation functions will enable us to introduce random distortions and noise to our images, helping simulate real-world imperfections.

Class Definition and Initializer

With our libraries imported, we can now define our Chinese Character DataSet class. Named ChineseCharacters, this class inherits from PyTorch’s Dataset class. It is initialized with an optional image_size parameter that dictates the size of the image canvas for rendering our characters. By default, we set the image size to 64×64 pixels.

class ChineseCharacters(Dataset):
    def __init__(self, image_size=(64, 64)):
        super().__init__()

        # Generate a list of all the Chinese characters (CJK Unified Ideographs)
        start = 0x4E00
        end = 0x9FFF
        characters = [chr(i) for i in range(start, end + 1)
                      if unicodedata.category(chr(i)).startswith('Lo')]

        # Set the parameters
        self.image_size = image_size
        self.characters = characters
        self.num_characters = len(characters)
        self.font_path = "fonts/NotoSansSC-Regular.otf"

The initializer begins by generating a list of all Chinese characters. We use the unicodedata library to identify these characters within the Unicode range U+4E00–U+9FFF (the CJK Unified Ideographs block). Following this, we establish several parameters for the class: the image size, the list and count of characters, and the font file’s path. Given that many fonts cannot render Chinese characters, we opt for Noto Sans SC, part of Google’s Noto font family, which covers a broad spectrum of scripts, including Chinese; the .otf file is assumed to have been downloaded into a local fonts/ directory.
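
As a quick sanity check (run outside the class), the filter yields on the order of twenty thousand characters, though the exact count depends on the Unicode tables shipped with your Python version:

import unicodedata

start, end = 0x4E00, 0x9FFF
characters = [chr(i) for i in range(start, end + 1)
              if unicodedata.category(chr(i)).startswith('Lo')]
print(len(characters))  # roughly 21,000 code points in this block
print(characters[:5])   # the first few ideographs, starting with '一' (U+4E00)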

The __len__ Method

The __len__ method is a crucial part of any DataSet in PyTorch. It’s responsible for reporting the number of items in the dataset. For our Chinese Character DataSet, the __len__ method returns the total count of unique Chinese characters we’ve identified:

def __len__(self):
    return self.num_characters

By implementing this method, we enable PyTorch’s data loader to correctly iterate over our dataset during the training process, handling data fetching and batching.
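
In practice, Python’s built-in len() delegates to this method:

dataset = ChineseCharacters()
print(len(dataset))  # invokes ChineseCharacters.__len__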

The __getitem__ Method

The heart of the DataSet class, the __getitem__ method, is responsible for loading or creating data. It should return both the input data (in this case, an image of a Chinese character) and the corresponding target (the character itself).

To accomplish this, we need to calculate the appropriate font size to render the character onto an image canvas properly. We define an internal helper function, _get_font_size, to handle this task. In Python, it’s common to prefix an underscore to methods that are intended for internal use within a class.

def _get_font_size(self, text, font_ratio):

    # Calculate the maximum font size based on the image size
    max_font_size = int(max(self.image_size) * font_ratio)

    # Load the font file
    font = ImageFont.truetype(self.font_path, size=max_font_size)

    # Calculate the size of the text with the maximum font size
    text_bbox = ImageDraw.Draw(Image.new('RGB', (1, 1))).textbbox((0, 0), text, font=font)
    text_size = (text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1])

    # Calculate the font size that fits within the image
    font_size = int(font_ratio * max(self.image_size) / max(text_size) * max_font_size)
    return font_size

The _get_font_size function calculates the maximum possible font size given the image size, then uses the PIL library to measure the rendered text’s size. Finally, it finds an optimal font size that ensures the text fits within the image dimensions.
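
As a worked example with made-up numbers: for a 64×64 image and a font ratio of 0.8, the maximum font size is int(64 × 0.8) = 51. If the glyph then measures 53 pixels on its longer side at that size, the returned font size is int(0.8 × 64 / 53 × 51) = 49, small enough for the glyph to fit comfortably inside the canvas.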

With the font size available, we can then render the character onto an image. We’ll define another internal helper method, _print_character, for this purpose:

def _print_character(self, text, font_ratio=0.8):

    # Create a blank grayscale image with a white background
    image = Image.new('L', self.image_size, color=255)

    # Get the font at the fitted size
    font_size = self._get_font_size(text, font_ratio)
    font = ImageFont.truetype(self.font_path, size=font_size)

    # Measure the text and center it on the canvas. The bounding box
    # offsets (text_bbox[0], text_bbox[1]) account for the font's internal
    # padding, so subtracting them centers the glyph itself.
    text_bbox = ImageDraw.Draw(Image.new('RGB', (1, 1))).textbbox((0, 0), text, font=font)
    text_size = (text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1])
    x = (self.image_size[0] - text_size[0]) // 2 - text_bbox[0]
    y = (self.image_size[1] - text_size[1]) // 2 - text_bbox[1]

    # Draw the text in black
    draw = ImageDraw.Draw(image)
    draw.text((x, y), text, font=font, fill=0)

    # Convert the image to a tensor with values in [0, 1]
    image = ToTensor()(image)

    # Return the image
    return image

The _print_character function creates a new image canvas, calculates the appropriate font size and the text’s position (using the bounding-box offsets to center the glyph), and then renders the text onto the canvas. The image is then converted to a PyTorch tensor.
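
To see the helper’s raw output before any augmentation (a quick check that assumes the Noto Sans SC font file is present at the configured path; the character 中 is just an example), you can render a single character directly:

import matplotlib.pyplot as plt

dataset = ChineseCharacters()
image = dataset._print_character('中')  # a clean, centered render
print(image.shape)                      # torch.Size([1, 64, 64])
plt.imshow(image[0].numpy(), cmap='gray')
plt.show()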

Now we are equipped to define the main __getitem__ method. This function will not only return our created image and the target character but also incorporate an optional transformation parameter. This can be used for applying data augmentation to the image, enhancing the robustness of our model.

def __getitem__(self, idx, transform=True):
    # Get the character
    character = self.characters[idx]

    # Render the character as a clean grayscale tensor
    image = self._print_character(character)

    # Distort the image
    if transform:
        image = RandomAffine(
            degrees=10,
            translate=(0.1, 0.1),
            scale=(0.8, 1),
            shear=(0.1, 0.1),
            fill=1
        )(image)
        image = GaussianBlur(kernel_size=9, sigma=0.02 * max(self.image_size))(image)

    # Finalize the image: add noise, then rescale to [0, 1]
    image += torch.randn_like(image) * 0.01
    image -= image.min()
    if image.max() > 0:
        image /= image.max()

    return image, character

Here, __getitem__ retrieves the character to be rendered and creates an image of it. If the transform parameter is set to True, the method applies data augmentation to the image. Finally, it adds random noise and normalizes the image before returning it alongside the character itself.
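
Note that plain indexing, dataset[idx], always goes through the default transform=True; to skip the augmentation you have to call __getitem__ explicitly. A small comparison (noise and normalization still apply in both cases):

dataset = ChineseCharacters()

# Standard indexing applies the random affine distortion and blur
augmented, character = dataset[0]

# Calling __getitem__ directly allows disabling the augmentation
clean, _ = dataset.__getitem__(0, transform=False)

print(character, augmented.shape, clean.shape)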

Testing the DataSet

Now that we have our dataset class defined, it’s essential to show how it can be utilized. In this section, we demonstrate how to create a dataset instance and display a series of images from it.

Note: The code block below is wrapped in an if __name__ == "__main__": construct. This standard Python practice ensures the code is only executed when the script is run directly, not when it is imported as a module by another script.

if __name__ == "__main__":
    # Create a dataset
    dataset = ChineseCharacters()

    # Display the first 100 characters, one after another
    fig = plt.gcf()
    fig.clf()
    plt.ion()
    plt.show()
    for i in range(100):
        image, _ = dataset[i]  # the character label is unused here
        plt.clf()
        plt.imshow(image[0].numpy(), cmap='gray')
        plt.pause(.1)

    print("Done.")

This code creates an instance of the ChineseCharacters dataset. It then uses matplotlib to display the first 100 characters from the dataset in real time, each rendered with fresh random distortions. The loop fetches each image, clears the current plot, displays the new image, and then briefly pauses before moving on to the next one.


This demonstration shows that our dataset correctly generates and visualizes Chinese characters, confirming that our code functions as expected.
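
If an animated window is inconvenient (for example, on a headless machine), a static grid of samples works just as well; the sketch below uses torchvision’s make_grid utility:

import matplotlib.pyplot as plt
import torch
from torchvision.utils import make_grid

dataset = ChineseCharacters()
images = torch.stack([dataset[i][0] for i in range(16)])  # first 16 characters
grid = make_grid(images, nrow=4, padding=2)  # the single channel is repeated to 3

plt.imshow(grid[0].numpy(), cmap='gray')  # any one channel suffices for grayscale
plt.axis('off')
plt.show()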

Using the DataSet with a DataLoader

DataSets are often paired with DataLoaders. The DataLoader class manages various aspects of the data loading process, including batching, shuffling, and multiprocessing. In this section, we’ll briefly demonstrate how to use our ChineseCharacters dataset with a DataLoader.

Here is an example script that creates a DataLoader instance and uses it to iterate through the dataset:

# Import the DataLoader and DataSet
from torch.utils.data import DataLoader
from character_printer import ChineseCharacters

# Create an instance of the DataSet
dataset = ChineseCharacters()

# Create a DataLoader
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Loop through the DataLoader
for batch in dataloader:
    images = batch
    # Here you could perform operations such as a forward pass, backward pass, 
    # updating weights, etc.

In this loop, each iteration yields a batch of 32 images from the ChineseCharacters DataSet, along with the corresponding characters. This allows you to load and process data in manageable chunks, making the DataLoader an invaluable tool in deep learning workflows.
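
To confirm what a batch looks like, you can pull a single batch off the loader; with the default collate function the images are stacked into one tensor and the characters are gathered into a list of strings (the shapes below assume the default 64×64 image size and a batch size of 32):

images, characters = next(iter(dataloader))
print(images.shape)     # torch.Size([32, 1, 64, 64])
print(len(characters))  # 32 target characters, one per image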

Conclusion

Throughout this tutorial, we have delved into the concept of a DataSet, elucidating its importance and functionality in the realm of deep learning. We didn’t stop at theoretical understanding; instead, we rolled up our sleeves and developed a hands-on ‘toy’ dataset from scratch. This DataSet, with its emphasis on generating Chinese characters, can serve as an insightful starting point for harnessing these tools in your deep learning endeavors.

But the knowledge gained from this exercise transcends this specific application. The principles and techniques acquired here equip you to build tailored DataSets for a wide array of applications, enabling you to pave your own path in the diverse world of deep learning.

We genuinely hope this tutorial was helpful, and we invite you to share your experiences, thoughts, and questions in the comments below. Your feedback not only helps us improve, but also contributes to a vibrant learning community. Happy experimenting!
