# Source: PromptNet/modules/dataloaders.py
# Uploaded by fenglinliu ("Upload 55 files"), commit 6e32a75.
import torch
import numpy as np
from torchvision import transforms
from torch.utils.data import DataLoader
from .datasets import IuxrayMultiImageDataset, MimiccxrSingleImageDataset
class R2DataLoader(DataLoader):
    """DataLoader that builds its own dataset, image transform and collate function.

    The dataset class is chosen from ``args.dataset_name``: ``'iu_xray'``
    selects ``IuxrayMultiImageDataset``; any other value selects
    ``MimiccxrSingleImageDataset``. The image pipeline depends on ``split``:
    ``'train'`` uses a random crop and random horizontal flip, every other
    split uses a fixed resize.

    Args:
        args: Configuration object. This class reads ``dataset_name``,
            ``batch_size`` and ``num_workers`` from it and forwards the whole
            object to the dataset constructor.
        tokenizer: Forwarded unchanged to the dataset constructor.
        split: Split name; only ``'train'`` is treated specially here.
        shuffle: Forwarded to ``DataLoader`` as ``shuffle``.
    """

    # Per-channel normalisation constants shared by the train and eval
    # pipelines. NOTE(review): these match the ImageNet statistics commonly
    # used with torchvision pretrained backbones -- confirm against the
    # visual extractor this loader feeds.
    _NORM_MEAN = (0.485, 0.456, 0.406)
    _NORM_STD = (0.229, 0.224, 0.225)

    def __init__(self, args, tokenizer, split, shuffle):
        # These attributes are assigned before ``DataLoader.__init__`` runs,
        # exactly as in the original ordering.
        self.args = args
        self.dataset_name = args.dataset_name
        self.batch_size = args.batch_size
        self.shuffle = shuffle
        self.num_workers = args.num_workers
        self.tokenizer = tokenizer
        self.split = split

        if split == 'train':
            # Training: resize, then random 224 crop + random flip.
            self.transform = transforms.Compose([
                transforms.Resize(256),
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(self._NORM_MEAN, self._NORM_STD),
            ])
        else:
            # Any other split: deterministic resize straight to 224x224.
            self.transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(self._NORM_MEAN, self._NORM_STD),
            ])

        if self.dataset_name == 'iu_xray':
            self.dataset = IuxrayMultiImageDataset(
                self.args, self.tokenizer, self.split, transform=self.transform)
        else:
            self.dataset = MimiccxrSingleImageDataset(
                self.args, self.tokenizer, self.split, transform=self.transform)

        # Kept on the instance so the exact DataLoader arguments stay inspectable.
        self.init_kwargs = {
            'dataset': self.dataset,
            'batch_size': self.batch_size,
            'shuffle': self.shuffle,
            'collate_fn': self.collate_fn,
            'num_workers': self.num_workers,
        }
        super().__init__(**self.init_kwargs)

    @staticmethod
    def _pad_rows(rows, width):
        """Return a ``(len(rows), width)`` integer array with each row left-aligned and zero-padded."""
        padded = np.zeros((len(rows), width), dtype=int)
        for row_idx, row in enumerate(rows):
            padded[row_idx, :len(row)] = row
        return padded

    @staticmethod
    def collate_fn(data):
        """Merge a list of samples into one padded batch.

        Args:
            data: Sequence of 5-tuples
                ``(image_id, image, report_ids, report_masks, seq_length)``.
                The images must be tensors that ``torch.stack`` can combine.

        Returns:
            A 4-tuple ``(images_id, images, targets, targets_masks)`` where
            ``images_id`` is a tuple of the per-sample ids, ``images`` is the
            samples' images stacked along a new dimension 0, ``targets`` is a
            ``LongTensor`` of zero-padded report ids and ``targets_masks`` is
            a ``FloatTensor`` of zero-padded masks. Both padded tensors have
            shape ``(batch, max(seq_length))``.
        """
        images_id, images, reports_ids, reports_masks, seq_lengths = zip(*data)
        images = torch.stack(images, 0)

        # The padded width comes from the reported lengths, not from len() of
        # each sequence; a sequence longer than this width makes the slice
        # assignment in _pad_rows raise.
        max_seq_length = max(seq_lengths)
        targets = R2DataLoader._pad_rows(reports_ids, max_seq_length)
        targets_masks = R2DataLoader._pad_rows(reports_masks, max_seq_length)

        return images_id, images, torch.LongTensor(targets), torch.FloatTensor(targets_masks)