PromptNet / modules /datasets.py
fenglinliu's picture
Upload 55 files
6e32a75 verified
import os
import json
import torch
from PIL import Image
from torch.utils.data import Dataset
class BaseDataset(Dataset):
    """Base dataset that loads an annotation file and pre-tokenizes reports.

    The JSON document at ``args.ann_path`` is parsed once at construction
    time and the entry stored under ``split`` becomes ``self.examples``.
    Every example is then extended in place with two keys:

    * ``'ids'``  -- ``tokenizer(example['report'])`` cut to at most
      ``args.max_seq_length`` items.
    * ``'mask'`` -- a list of ``1`` values with the same length as ``'ids'``.

    This class defines ``__len__`` only; subclasses provide ``__getitem__``.

    Args:
        args: Object exposing ``image_dir``, ``ann_path`` and
            ``max_seq_length`` attributes.
        tokenizer: Callable applied to each example's ``'report'`` value; its
            result must support slicing and ``len()``.
        split: Key used to select the examples inside the annotation JSON.
        transform: Optional callable kept on the instance for subclasses to
            apply to loaded images.

    Raises:
        OSError: If the annotation file cannot be opened.
        json.JSONDecodeError: If the annotation file is not valid JSON.
        KeyError: If ``split`` is absent from the annotations or an example
            has no ``'report'`` key.
    """

    def __init__(self, args, tokenizer, split, transform=None):
        self.image_dir = args.image_dir
        self.ann_path = args.ann_path
        self.max_seq_length = args.max_seq_length
        self.split = split
        self.tokenizer = tokenizer
        self.transform = transform

        # Use a context manager so the file handle is closed deterministically
        # instead of being left for the garbage collector.
        with open(self.ann_path, 'r') as ann_file:
            self.ann = json.load(ann_file)

        self.examples = self.ann[self.split]
        # Tokenize once up front; the examples are mutated in place, so
        # self.ann[self.split] sees the added keys as well.
        for example in self.examples:
            example['ids'] = tokenizer(example['report'])[:self.max_seq_length]
            example['mask'] = [1] * len(example['ids'])

    def __len__(self):
        """Return the number of examples in the selected split."""
        return len(self.examples)
class IuxrayMultiImageDataset(BaseDataset):
    """Dataset whose items combine two images with one tokenized report.

    Each example is read for the keys ``'id'``, ``'image_path'`` (indexed at
    positions 0 and 1), ``'ids'`` and ``'mask'``.
    """

    def __getitem__(self, idx):
        """Build the sample stored at position ``idx``.

        Args:
            idx: Index into ``self.examples``.

        Returns:
            A tuple ``(image_id, image, report_ids, report_masks, seq_length)``
            where ``image`` is the two loaded images stacked along a new
            leading dimension and ``seq_length`` is ``len(report_ids)``.
        """
        record = self.examples[idx]
        image_id = record['id']
        rel_paths = record['image_path']

        # Both files are opened and converted to RGB before any transform runs.
        views = [
            Image.open(os.path.join(self.image_dir, rel_paths[pos])).convert('RGB')
            for pos in (0, 1)
        ]
        if self.transform is not None:
            views = [self.transform(view) for view in views]

        # NOTE(review): torch.stack requires tensors, so this presumably
        # relies on the transform returning tensors -- confirm with callers.
        image = torch.stack(tuple(views), 0)

        report_ids = record['ids']
        report_masks = record['mask']
        return (image_id, image, report_ids, report_masks, len(report_ids))
class MimiccxrSingleImageDataset(BaseDataset):
    """Dataset whose items combine a single image with one tokenized report.

    Each example is read for the keys ``'id'``, ``'image_path'`` (only
    position 0 is used), ``'ids'`` and ``'mask'``.
    """

    def __getitem__(self, idx):
        """Build the sample stored at position ``idx``.

        Args:
            idx: Index into ``self.examples``.

        Returns:
            A tuple ``(image_id, image, report_ids, report_masks, seq_length)``
            where ``image`` is the RGB image after the optional transform and
            ``seq_length`` is ``len(report_ids)``.
        """
        record = self.examples[idx]
        image_id = record['id']

        # Only the first listed path is loaded for this dataset.
        full_path = os.path.join(self.image_dir, record['image_path'][0])
        image = Image.open(full_path).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        token_ids = record['ids']
        token_mask = record['mask']
        return (image_id, image, token_ids, token_mask, len(token_ids))