|
|
|
|
|
|
|
import torch |
|
from torchvision import datasets |
|
from torchvision.transforms import v2 |
|
from torch.utils.data import DataLoader |
|
|
|
import utils |
|
from typing import Tuple |
|
|
|
import io |
|
import base64 |
|
from PIL import Image |
|
import numpy as np |
|
|
|
|
|
def _to_normalized_tensor() -> list:
    '''
    Return fresh instances of the steps shared by both MNIST pipelines:
    convert the input to an image tensor, scale it to float32 in [0, 1],
    then standardize it with the MNIST mean (0.1307) and std (0.3081).
    '''
    return [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale = True),
        v2.Normalize(mean = [0.1307], std = [0.3081]),
    ]


# Deterministic pipeline: tensor conversion + normalization only.
BASE_TRANSFORMS = v2.Compose(_to_normalized_tensor())


# Training pipeline: a random affine augmentation applied to the raw input,
# followed by the same tensor conversion + normalization as BASE_TRANSFORMS.
TRAIN_TRANSFORMS = v2.Compose([
    v2.RandomAffine(degrees = 15,
                    scale = (0.8, 1.2),
                    translate = (0.08, 0.08),
                    shear = 10),
    *_to_normalized_tensor(),
])
|
|
|
|
|
|
|
|
|
|
|
def get_dataloaders(root: str,
                    batch_size: int,
                    num_workers: int = 0) -> Tuple[DataLoader, DataLoader]:
    '''
    Build the MNIST training and test DataLoaders.

    Both splits are created with ``download = True``. The training split uses
    TRAIN_TRANSFORMS and is shuffled; the test split uses BASE_TRANSFORMS and
    is not shuffled.

    Args:
        root (str): Directory where the MNIST data is stored / downloaded.
        batch_size (int): Batch size used by both loaders.
        num_workers (int): Number of worker processes per loader. When positive,
            both loaders use ``utils.MP_CONTEXT`` as their multiprocessing context
            and enable ``persistent_workers``. Default is 0 (no worker processes).

    Returns:
        Tuple[DataLoader, DataLoader]: ``(train_dl, test_dl)``.
    '''
    mnist_train = datasets.MNIST(root, download = True, train = True,
                                 transform = TRAIN_TRANSFORMS)
    mnist_test = datasets.MNIST(root, download = True, train = False,
                                transform = BASE_TRANSFORMS)

    use_workers = num_workers > 0

    # Settings common to both loaders; only the dataset and shuffling differ.
    # utils.MP_CONTEXT is only looked up when worker processes are requested.
    shared = {
        'batch_size': batch_size,
        'num_workers': num_workers,
        'multiprocessing_context': utils.MP_CONTEXT if use_workers else None,
        'pin_memory': utils.PIN_MEM,
        'persistent_workers': use_workers,
    }

    train_dl = DataLoader(dataset = mnist_train, shuffle = True, **shared)
    test_dl = DataLoader(dataset = mnist_test, shuffle = False, **shared)

    return train_dl, test_dl
|
|
|
def _flatten_to_grayscale(src: Image.Image) -> Image.Image:
    '''Return an 8-bit grayscale ('L') copy of `src`, compositing any transparency onto white first.'''
    # convert('L') on its own ignores alpha, so fully transparent pixels would keep
    # whatever RGB they carry (which may be black, e.g. (0, 0, 0, 0)) and would be
    # read as ink once the image is inverted below.
    if 'A' in src.getbands() or 'transparency' in src.info:
        rgba = src.convert('RGBA')
        background = Image.new('RGBA', rgba.size, (255, 255, 255, 255))
        return Image.alpha_composite(background, rgba).convert('L')
    return src.convert('L')


def _center_by_mass(img: np.ndarray, size: int = 28) -> np.ndarray:
    '''
    Paste a 2-D uint8 image into a ``size`` x ``size`` zero canvas so that its
    intensity-weighted center of mass (COM) lands on pixel (size // 2, size // 2).

    Parts of `img` that would fall outside the canvas are cropped. An image whose
    total mass is zero has no COM; in that case the empty canvas is returned.
    '''
    canvas = np.zeros((size, size), dtype = np.uint8)

    tot_mass = img.sum()
    if tot_mass == 0:
        # Avoid 0 / 0 -> NaN, whose integer cast would produce garbage slice bounds.
        return canvas

    rows, cols = np.indices(img.shape)
    com_y = int(np.round((rows * img).sum() / tot_mass))
    com_x = int(np.round((cols * img).sum() / tot_mass))

    center = size // 2

    # Number of pixels copied before the COM and from the COM onwards, along each
    # axis, limited by whichever of the source image / canvas runs out first.
    before_x = min(center, com_x)
    after_x = min(size - center, img.shape[1] - com_x)
    before_y = min(center, com_y)
    after_y = min(size - center, img.shape[0] - com_y)

    canvas[center - before_y:center + after_y, center - before_x:center + after_x] = \
        img[com_y - before_y:com_y + after_y, com_x - before_x:com_x + after_x]

    return canvas


def mnist_preprocess(uri: str) -> torch.Tensor:
    '''
    Preprocesses a data URI representing a handwritten digit image according to the pipeline used in the MNIST dataset.

    The pipeline includes:

    1. Decoding the base64 payload and converting the image to grayscale. Images with an
       alpha channel (or palette/tRNS transparency) are first composited onto a white
       background, so transparent pixels count as background rather than ink.

    2. Shrinking the image to fit within 20x20, preserving the aspect ratio, using Lanczos
       resampling. ``Image.thumbnail`` only shrinks; images already within 20x20 are kept as is.

    3. Inverting intensities and centering the result in a 28x28 image based on the center
       of mass (COM). The input is assumed to be dark ink on a light background, so after
       inversion the digit is bright on black, as in MNIST. Pixels that would fall outside
       the 28x28 canvas are cropped.

    4. Converting the image to a tensor (pixel values between 0 and 1) and normalizing it using MNIST statistics.

    Reference: https://paperswithcode.com/dataset/mnist

    Args:
        uri (str): A full data URI (``data:image/...;base64,<payload>``). A bare base64
            payload without the ``...,`` header is also accepted.

    Returns:
        Tensor: A tensor of shape (1, 28, 28) representing the preprocessed image, normalized
        using MNIST statistics. A blank image (no ink after inversion) yields a uniform
        background tensor instead of NaN-derived garbage.
    '''
    # Everything after the first comma is the payload; a string without a comma is
    # treated as a bare base64 payload.
    encoded_img = uri.split(',', 1)[-1]
    image_bytes = io.BytesIO(base64.b64decode(encoded_img))

    # convert() returns an independent image, so the source can be closed right away.
    with Image.open(image_bytes) as src:
        pil_img = _flatten_to_grayscale(src)

    pil_img.thumbnail((20, 20), Image.Resampling.LANCZOS)

    # Invert so ink carries the mass used for the COM computation.
    img = 255 - np.asarray(pil_img, dtype = np.uint8)

    return BASE_TRANSFORMS(_center_by_mass(img))
|
|