#####################################
# Packages & Dependencies
#####################################
import torch
from torchvision import datasets
from torchvision.transforms import v2
from torch.utils.data import DataLoader
import utils
from typing import Tuple
import io
import base64
from PIL import Image
import numpy as np

# Transformations applied to each image
BASE_TRANSFORMS = v2.Compose([
    v2.ToImage(),                                   # Convert to tensor
    v2.ToDtype(torch.float32, scale = True),        # Rescale pixel values to within [0, 1]
    v2.Normalize(mean = [0.1307], std = [0.3081])   # Normalize with MNIST stats
])

TRAIN_TRANSFORMS = v2.Compose([
    v2.RandomAffine(degrees = 15,                   # Rotate up to -/+ 15 degrees
                    scale = (0.8, 1.2),             # Scale between 80 and 120 percent
                    translate = (0.08, 0.08),       # Translate up to -/+ 8 percent in both x and y
                    shear = 10),                    # Shear up to -/+ 10 degrees
    v2.ToImage(),                                   # Convert to tensor
    v2.ToDtype(torch.float32, scale = True),        # Rescale pixel values to within [0, 1]
    v2.Normalize(mean = [0.1307], std = [0.3081]),  # Normalize with MNIST stats
])

#####################################
# Functions
#####################################
def get_dataloaders(root: str, batch_size: int, num_workers: int = 0) -> Tuple[DataLoader, DataLoader]:
    '''
    Creates training and testing dataloaders for the MNIST dataset.

    Args:
        root (str): Path to download MNIST data.
        batch_size (int): Size used to split the training and testing datasets into batches.
        num_workers (int): Number of workers to use for multiprocessing. Default is 0.

    Returns:
        Tuple[DataLoader, DataLoader]: The training and testing dataloaders, in that order.
    '''
    # Get training and testing MNIST data
    mnist_train = datasets.MNIST(root, download = True, train = True, transform = TRAIN_TRANSFORMS)
    mnist_test = datasets.MNIST(root, download = True, train = False, transform = BASE_TRANSFORMS)

    # Create dataloaders; multiprocessing settings only apply when workers are used
    if num_workers > 0:
        mp_context = utils.MP_CONTEXT
        persistent_workers = True
    else:
        mp_context = None
        persistent_workers = False

    train_dl = DataLoader(
        dataset = mnist_train,
        batch_size = batch_size,
        shuffle = True,
        num_workers = num_workers,
        multiprocessing_context = mp_context,
        pin_memory = utils.PIN_MEM,
        persistent_workers = persistent_workers
    )
    test_dl = DataLoader(
        dataset = mnist_test,
        batch_size = batch_size,
        shuffle = False,
        num_workers = num_workers,
        multiprocessing_context = mp_context,
        pin_memory = utils.PIN_MEM,
        persistent_workers = persistent_workers
    )

    return train_dl, test_dl
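
# Example usage (illustrative sketch, not part of the original pipeline): builds the
# dataloaders and checks the shape of one training batch. The helper name and the
# './data' root are assumptions; it also relies on the `utils` module above defining
# MP_CONTEXT and PIN_MEM.
def _example_dataloader_check(root: str = './data') -> None:
    '''Sanity-check sketch: fetch one training batch and print dataset/batch shapes.'''
    train_dl, test_dl = get_dataloaders(root = root, batch_size = 64, num_workers = 0)
    images, labels = next(iter(train_dl))
    print(images.shape)            # Expected: torch.Size([64, 1, 28, 28])
    print(labels.shape)            # Expected: torch.Size([64])
    print(len(test_dl.dataset))    # Expected: 10000 test images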
def mnist_preprocess(uri: str):
    '''
    Preprocesses a data URI representing a handwritten digit image according to the pipeline
    used to build the MNIST dataset. The pipeline includes:
        1. Converting the image to grayscale.
        2. Resizing the image to fit within 20x20, preserving the aspect ratio and using
           anti-aliasing.
        3. Centering the resized image in a 28x28 image based on its center of mass (COM).
        4. Converting the image to a tensor (pixel values between 0 and 1) and normalizing it
           using MNIST statistics.

    Reference: https://paperswithcode.com/dataset/mnist

    Args:
        uri (str): A string representing the full data URI.

    Returns:
        Tensor: A tensor of shape (1, 28, 28) representing the preprocessed image, normalized
            using MNIST statistics.
    '''
    encoded_img = uri.split(',', 1)[1]
    image_bytes = io.BytesIO(base64.b64decode(encoded_img))
    pil_img = Image.open(image_bytes).convert('L')  # Grayscale

    # Resize to fit within 20x20, preserving aspect ratio and using anti-aliasing
    pil_img.thumbnail((20, 20), Image.Resampling.LANCZOS)

    # Convert to numpy and invert the image (MNIST digits are white on a black background)
    img = 255 - np.array(pil_img)

    # Get image indices for the y-axis (rows) and x-axis (columns)
    img_idxs = np.indices(img.shape)
    tot_mass = img.sum()

    # Guard against a blank image, which would otherwise divide by zero below
    if tot_mass == 0:
        raise ValueError('Blank image: center of mass is undefined')

    # Indices of the center of mass (COM) along each axis
    com_x = np.round((img_idxs[1] * img).sum() / tot_mass).astype(int)
    com_y = np.round((img_idxs[0] * img).sum() / tot_mass).astype(int)

    dist_com_end_x = img.shape[1] - com_x  # Number of columns from com_x to the right edge
    dist_com_end_y = img.shape[0] - com_y  # Number of rows from com_y to the bottom edge

    new_img = np.zeros((28, 28), dtype = np.uint8)
    new_com_x, new_com_y = 14, 14  # COM indices along each axis for the new 28x28 image

    valid_start_x = min(new_com_x, com_x)
    valid_end_x = min(14, dist_com_end_x)  # At most 14 columns fit between the new COM and the right edge
    valid_start_y = min(new_com_y, com_y)
    valid_end_y = min(14, dist_com_end_y)  # At most 14 rows fit between the new COM and the bottom edge

    old_slice_x = slice(com_x - valid_start_x, com_x + valid_end_x)
    old_slice_y = slice(com_y - valid_start_y, com_y + valid_end_y)
    new_slice_x = slice(new_com_x - valid_start_x, new_com_x + valid_end_x)
    new_slice_y = slice(new_com_y - valid_start_y, new_com_y + valid_end_y)

    # Paste the cropped image into the 28x28 field so the old COM (com_y, com_x) lands at the center (14, 14)
    new_img[new_slice_y, new_slice_x] = img[old_slice_y, old_slice_x]

    # Return the transformed tensor of the new image. This includes normalizing with MNIST stats
    return BASE_TRANSFORMS(new_img)
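
# Minimal smoke test (illustrative sketch): assumes the function is fed a browser-style
# 'data:image/png;base64,...' URI. It draws a dark block on a white canvas, encodes it as
# a data URI, and checks the preprocessed tensor shape. The canvas size and block position
# are arbitrary choices for illustration.
if __name__ == '__main__':
    # Synthetic "digit": a dark block on a white 280x280 canvas, like a drawing-canvas export
    canvas = np.full((280, 280), 255, dtype = np.uint8)
    canvas[100:180, 120:200] = 0

    buffer = io.BytesIO()
    Image.fromarray(canvas).save(buffer, format = 'PNG')
    uri = 'data:image/png;base64,' + base64.b64encode(buffer.getvalue()).decode('ascii')

    out = mnist_preprocess(uri)
    print(out.shape)  # Expected: torch.Size([1, 28, 28])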