Introduction
DataSets in PyTorch serve as a crucial bridge between raw data and the machine learning model during training. They encapsulate the logic needed to access, transform, and even augment the data, making the model’s training loop cleaner and more manageable. In this tutorial, we’ll shed light on what PyTorch DataSets are, how you can create one from the ground up, and how it can be seamlessly incorporated into your model’s training process.
To make this journey engaging and practical, we’ll embark on a unique project: constructing a DataSet that renders Chinese characters. The choice of Chinese characters is strategic: these characters encompass both simple and intricate features, mirroring the complexity often found in real-world data. Furthermore, the vast number of unique Chinese characters, compared to Latin letters, presents a greater challenge for classification models, making this an excellent toy problem. This simple, yet powerful example will illustrate the versatility of PyTorch DataSets and showcase their potential to handle even complex data types.
What is a DataSet in PyTorch?
A DataSet in PyTorch is an abstraction that represents a collection of data. It provides a standardized way to load, preprocess, and access the data in a unified manner, regardless of the type or structure of the original data source. This can include anything from images and text files to CSVs and databases. PyTorch Datasets are designed around two main methods: __getitem__ and __len__. The __len__ method returns the number of items in the dataset, while the __getitem__ method allows indexed access to individual data items, returning both the input data and the corresponding target. This design allows the PyTorch DataLoader to efficiently and conveniently fetch batches of data, facilitating the implementation of large-scale machine learning models that learn from examples in an iterative manner.
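To make this contract concrete, here is a minimal sketch of a custom Dataset; the class name and the toy data are purely illustrative:

```python
import torch
from torch.utils.data import Dataset

class SquaresDataset(Dataset):
    """Toy dataset: input i maps to target i squared."""

    def __init__(self, n=10):
        self.n = n

    def __len__(self):
        # Report how many items the dataset holds
        return self.n

    def __getitem__(self, idx):
        # Return an (input, target) pair as tensors
        x = torch.tensor(float(idx))
        y = x ** 2
        return x, y

dataset = SquaresDataset()
print(len(dataset))        # 10
x, y = dataset[3]
print(y.item())            # 9.0
```

Any class that implements these two methods can be handed directly to a DataLoader.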
Building the Chinese Character DataSet Class
In this tutorial, we aim to construct a DataSet that renders Chinese characters on demand. To bring our Chinese Character DataSet to life, we’ll break down the process into several manageable steps. First, we’ll identify the Unicode range corresponding to Chinese characters. Next, we’ll calculate the appropriate font size to ensure each character fits within a specified image size. We’ll then render each character onto an image canvas using the Python Imaging Library (PIL). Following this, we’ll convert our image into a PyTorch tensor, a format suitable for machine learning models. Finally, to add a layer of realism and improve the robustness of our training, we’ll post-process our tensor by introducing random distortions, noise, and normalization.
Import statements
To bring our Chinese Character DataSet to life, we’ll be harnessing the power of several Python libraries. Let’s start by importing the necessary modules:
import torch
import unicodedata
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor, RandomAffine, GaussianBlur
- torch: This is the PyTorch library, which we’ll use to convert our images into tensors – the standard data format for deep learning models.
- unicodedata: This will help us identify Chinese characters within the Unicode range.
- matplotlib: We’ll use this popular plotting library to visualize our images.
- PIL: The Python Imaging Library (PIL), available today as the Pillow package, will be instrumental in rendering the characters onto an image canvas.
- torchvision: Finally, torchvision’s transformation functions will enable us to introduce random distortions and noise to our images, helping simulate real-world imperfections.
Class Definition and Initializer
With our libraries imported, we can now define our Chinese Character DataSet class. Named ChineseCharacters, this class inherits from PyTorch’s Dataset class. It is initialized with an optional image_size parameter, dictating the size of the image canvas for rendering our characters. By default, we set the image size to 64×64 pixels.
class ChineseCharacters(Dataset):
    def __init__(self, image_size=(64, 64)):
        super().__init__()
        # Generate a list of all the Chinese characters
        start = int("4E00", 16)
        end = int("9FFF", 16)
        characters = [chr(i) for i in range(start, end + 1) if unicodedata.category(chr(i)).startswith('Lo')]
        # Set the parameters
        self.image_size = image_size
        self.characters = characters
        self.num_characters = len(characters)
        self.font_path = "fonts/NotoSansSC-Regular.otf"
The initializer begins by generating a list of all Chinese characters. We use the unicodedata library to identify these characters within the Unicode range 4E00-9FFF. Following this, we establish several parameters for the class: the image size, the list and count of characters, and the font file’s path. Given that many fonts cannot render Chinese characters, we opt for the versatile Noto Sans font, which supports a broad spectrum of languages, including Chinese.
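As a quick sanity check, you can count how many code points in the 4E00–9FFF range pass the 'Lo' (letter, other) filter on its own. The exact number depends on the Unicode tables bundled with your Python build, but it is close to the full 20,992 code points in the CJK Unified Ideographs block:

```python
import unicodedata

start, end = 0x4E00, 0x9FFF
characters = [chr(i) for i in range(start, end + 1)
              if unicodedata.category(chr(i)).startswith('Lo')]
print(len(characters))  # roughly 20,000+ CJK Unified Ideographs
```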
The __len__ Method
The __len__ method is a crucial part of any DataSet in PyTorch. It is responsible for reporting the number of items in the dataset. For our Chinese Character DataSet, it returns the total count of unique Chinese characters we’ve identified:
    def __len__(self):
        return self.num_characters
By implementing this method, we enable PyTorch’s data loader to correctly iterate over our dataset during the training process, handling data fetching and batching.
The __getitem__ Method
The heart of the DataSet class, the __getitem__ method, is responsible for loading or creating data. It should return the input data (in this case, an image of a Chinese character) along with the corresponding target (the character itself).
To accomplish this, we need to calculate the appropriate font size to render the character onto an image canvas properly. We define an internal helper function, _get_font_size, to handle this task. In Python, it’s common to prefix an underscore to methods that are intended for internal use within a class.
    def _get_font_size(self, text, font_ratio):
        # Calculate the maximum font size based on the image size
        max_font_size = int(max(self.image_size) * font_ratio)
        # Load the font file
        font = ImageFont.truetype(self.font_path, size=max_font_size)
        # Calculate the size of the text with the maximum font size
        text_bbox = ImageDraw.Draw(Image.new('RGB', (1, 1))).textbbox((0, 0), text, font)
        text_size = text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1]
        # Calculate the font size that fits within the image
        font_size = int(font_ratio * max(self.image_size) / max(text_size) * max_font_size)
        return font_size
The _get_font_size function calculates the maximum possible font size given the image size, then uses the PIL library to measure the rendered text’s size. Finally, it finds an optimal font size that ensures the text fits within the image dimensions.
With the font size available, we can then render the character onto an image. We’ll define another internal helper function, print_character, for this purpose:
    def print_character(self, text, font_ratio=0.8):
        # Create an image
        image = Image.new('L', self.image_size, color=255)
        # Get the font
        font_size = self._get_font_size(text, font_ratio)
        font = ImageFont.truetype(self.font_path, size=font_size)
        # Get the text size and location
        text_bbox = ImageDraw.Draw(Image.new('RGB', (1, 1))).textbbox((0, 0), text, font)
        text_size = text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1]
        ascent = int(font.getmetrics()[1] * font_size / text_size[1])
        x = (self.image_size[0] - text_size[0]) // 2
        y = (self.image_size[1] - text_size[1]) // 2 - ascent
        # Draw the text
        draw = ImageDraw.Draw(image)
        draw.text((x, y), text, font=font, fill=0)
        # Convert the image to a tensor
        image = ToTensor()(image)
        # Return the image
        return image
The print_character function creates a new image canvas, calculates the appropriate font size and the text’s position, and then renders the text onto the canvas. The image is then converted to a PyTorch tensor.

Now we are equipped to define the main __getitem__ method. This method creates an image of the requested character and incorporates an optional transform parameter, which can be used to apply data augmentation to the image, enhancing the robustness of our model. For simplicity, only the image is returned; the index idx itself serves as the class label.
    def __getitem__(self, idx, transform=True):
        # Get the character
        character = self.characters[idx]
        # Get the image
        image = self.print_character(character)
        # Distort the image
        if transform:
            image = RandomAffine(
                degrees=10,
                translate=(0.1, 0.1),
                scale=(0.8, 1),
                shear=(0.1, 0.1),
                fill=1
            )(image)
            image = GaussianBlur(kernel_size=9, sigma=.02 * max(self.image_size))(image)
        # Finalize the image
        image += torch.randn(*image.shape) * .01
        image -= image.min()
        if image.max() > 0:
            image /= image.max()
        return image
Here, __getitem__ retrieves the character to be rendered and creates an image of it. If the transform parameter is set to True, the method applies data augmentation to the image. Finally, it adds random noise and normalizes the image before returning it.
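The noise-and-normalization step at the end can be checked in isolation: after subtracting the minimum and dividing by the maximum, the tensor is guaranteed to span exactly [0, 1]. Here is a small self-contained sketch, using a random tensor as a stand-in for a rendered character:

```python
import torch

torch.manual_seed(0)
image = torch.rand(1, 64, 64)          # stand-in for a rendered character

# Add a small amount of Gaussian noise
image += torch.randn(*image.shape) * .01

# Min-max normalize to the [0, 1] range
image -= image.min()
if image.max() > 0:
    image /= image.max()

print(image.min().item(), image.max().item())  # 0.0 1.0
```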
Testing the DataSet
Now that we have our dataset class defined, it’s essential to show how it can be utilized. In this section, we demonstrate how to create a dataset instance and display a series of images from it.
Note: The code block below is wrapped in an if __name__ == "__main__": construct. This standard Python practice ensures the code is only executed when the script is run directly, not when it is imported as a module by another script.
if __name__ == "__main__":
    # Create a dataset
    dataset = ChineseCharacters()
    # Display the first 100 characters, one after another
    fig = plt.gcf()
    fig.clf()
    plt.ion()
    plt.show()
    for i in range(100):
        image = dataset[i]
        plt.clf()
        plt.imshow(image[0, :, :].detach().numpy(), cmap='gray')
        plt.pause(.1)
    print("Done.")
This code creates an instance of the ChineseCharacters dataset. It then uses matplotlib to display the first 100 images from the dataset in real time. The loop fetches each image, clears the current plot (if any), displays the new image, and then briefly pauses before moving on to the next one.
Running the script displays one rendered character after another, confirming that our dataset correctly generates and visualizes Chinese characters and that our code functions as expected.
Using the DataSet with a DataLoader
Often times, DataSets are paired with DataLoaders. The DataLoader class is responsible for managing various aspects of the data loading process, including batching, shuffling, and multiprocessing. In this section, we’ll briefly demonstrate how to use our ChineseCharacters
dataset with a DataLoader.
Here is an example script that creates a DataLoader instance and uses it to iterate through the dataset:
# Import the DataLoader and DataSet
from torch.utils.data import DataLoader
from character_printer import ChineseCharacters

# Create an instance of the DataSet
dataset = ChineseCharacters()

# Create a DataLoader
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Loop through the DataLoader
for batch in dataloader:
    images = batch
    # Here you could perform operations such as a forward pass, backward pass,
    # updating weights, etc.
In this loop, each batch contains up to 32 images from the ChineseCharacters DataSet (the final batch may be smaller if the dataset size is not a multiple of 32). This allows you to load and process data in manageable chunks, making the DataLoader an invaluable tool in deep learning workflows.
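The same batching behavior can be checked without the character dataset at all. The sketch below substitutes a TensorDataset of random stand-in “images” to show how a DataLoader slices 100 items into batches of 32:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 100 fake single-channel 64x64 "images" in place of rendered characters
images = torch.rand(100, 1, 64, 64)
dataset = TensorDataset(images)

dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

batch_sizes = [batch[0].shape[0] for batch in dataloader]
print(batch_sizes)  # [32, 32, 32, 4] - the last batch is smaller
```

Note that shuffling changes which items land in which batch, but not the batch sizes themselves.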
Conclusion
Throughout this tutorial, we have delved into the concept of a DataSet, elucidating its importance and functionality in the realm of deep learning. We didn’t stop at theoretical understanding; instead, we rolled up our sleeves and developed a hands-on ‘toy’ dataset from scratch. This DataSet, with its emphasis on generating Chinese characters, can serve as an insightful starting point for harnessing these tools in your deep learning endeavors.
But, the knowledge gained from this exercise transcends this specific application. The principles and techniques acquired here equip you to build tailored DataSets for a wide array of applications, enabling you to pave your unique path in the diverse world of deep learning.
We genuinely hope this tutorial was helpful, and we invite you to share your experiences, thoughts, and questions in the comments below. Your feedback not only helps us improve, but also contributes to a vibrant learning community. Happy experimenting!