PromptNet / modules /dataset.py
fenglinliu's picture
Upload 55 files
6e32a75 verified
raw
history blame
2.82 kB
import os
from PIL import Image
import json
from torch.utils.data import Dataset
import numpy as np
import torch
class BaseDataset(Dataset):
def __init__(self, args, tokenizer, split, processor):
self.image_dir = args.image_dir
self.ann_path = args.json_path
self.max_seq_length = args.max_seq_length
self.split = split
self.tokenizer = tokenizer
self.ann = json.loads(open(self.ann_path, 'r').read())
self.examples = self.ann[self.split]
self.processor = processor
def preprocess_text(self, text):
ids = self.tokenizer(text)[:self.max_seq_length]
mask = [1] * len(ids)
text_inputs = self.processor(text=text, return_tensors="pt",truncation=True, padding=False, max_length=self.max_seq_length)
processor_ids = text_inputs['input_ids'].squeeze(0).tolist()
processor_mask = text_inputs['attention_mask'].squeeze(0).tolist()
return ids, mask, processor_ids, processor_mask
def __len__(self):
return len(self.examples)
class IuxrayMultiImageDataset(BaseDataset):
def __getitem__(self, idx):
example = self.examples[idx]
report = example['report']
report_ids, report_masks, processor_ids, processor_mask = self.preprocess_text(report)
image_id = example['id']
image_path = example['image_path']
image_1 = Image.open(os.path.join(self.image_dir, image_path[0])).convert('RGB')
image_2 = Image.open(os.path.join(self.image_dir, image_path[1])).convert('RGB')
# MedCLIP processing
image_inputs_1 = self.processor(images=image_1, return_tensors="pt")
image_inputs_2 = self.processor(images=image_2, return_tensors="pt")
image = torch.stack((image_inputs_1.pixel_values[0], image_inputs_2.pixel_values[0]), 0)
seq_length = len(report_ids)
processor_length = len(processor_ids)
sample = (image_id, image, report_ids, report_masks, processor_ids, processor_mask, seq_length, processor_length)
return sample
class MimiccxrSingleImageDataset(BaseDataset):
def __getitem__(self, idx):
example = self.examples[idx]
report = example['report']
report_ids, report_masks, processor_ids, processor_mask = self.preprocess_text(report)
image_id = example['id']
image_path = example['image_path']
image = Image.open(os.path.join(self.image_dir, image_path[0])).convert('RGB')
image_inputs = self.processor(images=image, return_tensors="pt")
image = image_inputs.pixel_values[0]
seq_length = len(report_ids)
processor_length = len(processor_ids)
sample = (image_id, image, report_ids, report_masks, processor_ids, processor_mask, seq_length, processor_length)
return sample