Spaces:
Build error
Build error
File size: 3,056 Bytes
230c9a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
class ResizeLongestSide:
    """Callable transform that rescales an image so its longest side equals ``size``.

    The aspect ratio is preserved (up to integer truncation of the shorter
    side) and the image is resampled with bilinear interpolation.
    """

    def __init__(self, size):
        """
        Store the target length of the longest side.
        Args:
        - size (int): Length, in pixels, that the longest side is resized to.
        """
        self.size = size

    def _scaled_size(self, width, height):
        """
        Compute the output dimensions for an input of ``width`` x ``height``.
        Args:
        - width (int): Original image width.
        - height (int): Original image height.
        Returns:
        tuple: ``(new_width, new_height)`` with the longest side equal to ``self.size``.
        """
        if width > height:
            new_width = self.size
            # Truncation can reach 0 for very elongated images; clamp to one
            # pixel so the resize target is never an empty image.
            new_height = max(1, int(height * (self.size / float(width))))
        else:
            # Square images also take this branch; both sides become ``size``.
            new_height = self.size
            new_width = max(1, int(width * (self.size / float(height))))
        return new_width, new_height

    def __call__(self, img):
        """
        Resize ``img`` so that its longest side is ``self.size`` pixels.
        Args:
        - img: Object exposing ``size`` as ``(width, height)`` and a ``resize`` method (a PIL image in this file's usage).
        Returns:
        The value returned by ``img.resize`` for the computed dimensions.
        """
        # Get the original dimensions
        width, height = img.size
        # Resize the image
        return img.resize(self._scaled_size(width, height), Image.BILINEAR)
class ImageDataset(Dataset):
    """Dataset that yields ``(image_tensor, image_id)`` pairs.

    Every image is converted to RGB, passed through ``ResizeLongestSide`` with
    ``img_size`` and then through ``transforms.ToTensor``.
    """

    def __init__(self, images, image_ids=None, img_size=1280):
        """
        Initialize the ImageDataset class.
        Args:
        - images (list): List of image paths or PIL.Image.Image objects.
        - image_ids (list, optional): List of corresponding image IDs. If None, the entries of ``images`` themselves are used as IDs.
        - img_size (int): Size to which images' longest side will be resized.
        """
        self.images = images
        self.image_ids = image_ids if image_ids is not None else images
        self.img_size = img_size
        self.transform = transforms.Compose([
            ResizeLongestSide(self.img_size),
            transforms.ToTensor(),
        ])

    def __len__(self):
        """
        Return the size of the dataset.
        Returns:
        int: Number of images in the dataset.
        """
        return len(self.images)

    def __getitem__(self, idx):
        """
        Get an image and its corresponding ID by index.
        Args:
        - idx (int): Index of the image to retrieve.
        Returns:
        tuple: Transformed image tensor and corresponding image ID.
        Raises:
        ValueError: If the entry is neither a ``str`` path nor a PIL.Image.Image.
        """
        image = self.images[idx]
        image_id = self.image_ids[idx]
        # Check if the image is a path or a PIL.Image object
        if isinstance(image, str):
            # Open inside a context manager so the file handle is released
            # right away; convert() hands back an in-memory image, so nothing
            # reads from the file after the block exits.
            with Image.open(image) as opened:
                image = opened.convert('RGB')
        elif isinstance(image, Image.Image):
            image = image.convert('RGB')
        else:
            raise ValueError("Image must be a file path or a PIL.Image object")
        # Apply transformations
        image = self.transform(image)
        return image, image_id
class MathDataset(Dataset):
    """Dataset over images given either as file paths or as already-loaded objects."""

    def __init__(self, image_paths, transform=None):
        """
        Initialize the MathDataset class.
        Args:
        - image_paths (list): Entries are either ``str`` file paths or objects used as images directly.
        - transform (callable, optional): Applied to each image before it is returned. If None, images are returned untransformed.
        """
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        """
        Return the size of the dataset.
        Returns:
        int: Number of entries in ``image_paths``.
        """
        return len(self.image_paths)

    def __getitem__(self, idx):
        """
        Get the (optionally transformed) image at ``idx``.
        Args:
        - idx (int): Index of the image to retrieve.
        Returns:
        The result of ``transform(image)``, or the image itself when no transform was given.
        """
        entry = self.image_paths[idx]
        # if not pil image, then convert to pil image
        if isinstance(entry, str):
            raw_image = Image.open(entry)
            # Image.open is lazy; load the pixel data now so the underlying
            # file does not stay open for the lifetime of the returned image.
            raw_image.load()
        else:
            raw_image = entry
        # Without a transform the original code fell through to an unbound
        # local; return the image unchanged instead.
        if self.transform is None:
            return raw_image
        return self.transform(raw_image)
|