import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms


class ResizeLongestSide:
    """Resize a PIL image so that its longest side equals ``size``.

    The aspect ratio is preserved; the shorter side is scaled by the same
    factor and truncated to an integer (never below 1 pixel).
    """

    def __init__(self, size):
        """
        Initialize the transform.

        Args:
        - size (int): Target length, in pixels, of the longest image side.

        Raises:
            ValueError: If ``size`` is smaller than 1.
        """
        if size < 1:
            raise ValueError("size must be a positive number of pixels")
        self.size = size

    def __call__(self, img):
        """
        Resize ``img`` with bilinear resampling.

        Args:
        - img (PIL.Image.Image): Image to resize.

        Returns:
            PIL.Image.Image: Resized copy whose longest side is ``self.size``.
        """
        width, height = img.size

        # Multiply before dividing so the quotient is rounded only once
        # (dividing first can land just below an integer and lose a pixel),
        # and clamp to 1 so extreme aspect ratios never yield a 0-pixel side.
        if width > height:
            new_width = self.size
            new_height = max(1, int(height * self.size / width))
        else:
            new_height = self.size
            new_width = max(1, int(width * self.size / height))

        return img.resize((new_width, new_height), Image.BILINEAR)


class ImageDataset(Dataset):
    def __init__(self, images, image_ids=None, img_size=1280):
        """
        Initialize the ImageDataset class.

        Args:
        - images (list): List of image paths (str or os.PathLike) or
          PIL.Image.Image objects.
        - image_ids (list, optional): List of corresponding image IDs.
          If None, the entries of ``images`` themselves are used as IDs.
        - img_size (int): Size to which images' longest side will be resized.
        """
        self.images = images
        self.image_ids = image_ids if image_ids is not None else images
        self.img_size = img_size
        self.transform = transforms.Compose([
            ResizeLongestSide(self.img_size),
            transforms.ToTensor(),
        ])

    def __len__(self):
        """
        Return the size of the dataset.

        Returns:
            int: Number of images in the dataset.
        """
        return len(self.images)

    def __getitem__(self, idx):
        """
        Get an image and its corresponding ID by index.

        Args:
        - idx (int): Index of the image to retrieve.

        Returns:
            tuple: Transformed image tensor and corresponding image ID.

        Raises:
            ValueError: If the entry is neither a file path nor a PIL image.
        """
        image = self.images[idx]
        image_id = self.image_ids[idx]

        if isinstance(image, (str, os.PathLike)):
            # convert() returns an independent, fully loaded image, so the
            # underlying file can be closed as soon as the block exits.
            with Image.open(image) as handle:
                image = handle.convert('RGB')
        elif isinstance(image, Image.Image):
            image = image.convert('RGB')
        else:
            raise ValueError("Image must be a file path or a PIL.Image object")

        image = self.transform(image)

        return image, image_id


class MathDataset(Dataset):
    """Dataset over image paths or already-loaded images with an optional transform."""

    def __init__(self, image_paths, transform=None):
        """
        Initialize the MathDataset class.

        Args:
        - image_paths (list): List of image paths (str or os.PathLike) or
          already-loaded image objects.
        - transform (callable, optional): Applied to each image when given.
        """
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        """
        Return the size of the dataset.

        Returns:
            int: Number of entries in ``image_paths``.
        """
        return len(self.image_paths)

    def __getitem__(self, idx):
        """
        Get the (optionally transformed) image at ``idx``.

        Args:
        - idx (int): Index of the image to retrieve.

        Returns:
            The output of ``transform`` when one was supplied, otherwise the
            image itself (opened from disk if the entry is a path).
        """
        entry = self.image_paths[idx]

        if isinstance(entry, (str, os.PathLike)):
            # Force the pixel data into memory before the file is closed so
            # the image stays usable after the with-block.
            with Image.open(entry) as handle:
                handle.load()
            image = handle
        else:
            image = entry

        # Without a transform the image is returned unchanged instead of
        # failing on an unassigned local.
        if self.transform is not None:
            image = self.transform(image)

        return image