# X-RayDemo/dataset/data_helper.py
import os
import json
import re

import numpy as np
from PIL import Image
import torch.utils.data as data
from transformers import AutoImageProcessor


class FieldParser:
    def __init__(self, args):
        self.args = args
        self.dataset = args.dataset
        # image processor matched to the vision backbone named in args.vision_model
        self.vit_feature_extractor = AutoImageProcessor.from_pretrained(args.vision_model)

    def _parse_image(self, img):
        # preprocess with the HF image processor (resize / normalize) and take the
        # single image out of the returned batch, typically a [3, 224, 224] tensor
        pixel_values = self.vit_feature_extractor(img, return_tensors="pt").pixel_values
        return pixel_values[0]
    # from https://github.com/cuhksz-nlp/R2Gen/blob/main/modules/tokenizers.py
    def clean_report(self, report):
        # clean IU X-Ray reports
        if self.dataset == "iu_xray":
            report_cleaner = lambda t: t.replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '') \
                .replace('. 2. ', '. ').replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ') \
                .replace(' 2. ', '. ').replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \
                .strip().lower().split('. ')
            # strip punctuation; '-' sits last in the class so it matches literally
            # instead of forming a character range
            sent_cleaner = lambda t: re.sub(r'[.,?;*!%^&_+()\[\]{}:-]', '', t.replace('"', '').replace('/', '')
                                            .replace('\\', '').replace("'", '').strip().lower())
            tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent)]
            report = ' . '.join(tokens) + ' .'
        # clean MIMIC-CXR reports
        else:
            report_cleaner = lambda t: t.replace('\n', ' ').replace('__', '_').replace('__', '_').replace('__', '_') \
                .replace('__', '_').replace('__', '_').replace('__', '_').replace('__', '_').replace('  ', ' ') \
                .replace('  ', ' ').replace('  ', ' ').replace('  ', ' ').replace('  ', ' ').replace('  ', ' ') \
                .replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.') \
                .replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '').replace('. 2. ', '. ') \
                .replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ').replace(' 2. ', '. ') \
                .replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ').replace(':', ' :') \
                .strip().lower().split('. ')
            sent_cleaner = lambda t: re.sub(r'[.,?;*!%^&_+()\[\]{}]', '', t.replace('"', '').replace('/', '')
                                            .replace('\\', '').replace("'", '').strip().lower())
            tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent)]
            report = ' . '.join(tokens) + ' .'
# report = ' '.join(report.split()[:self.args.max_txt_len])
return report
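
    # Illustrative example (hypothetical report text, iu_xray branch):
    #   clean_report("1. No acute disease. 2. Heart size normal.")
    #   -> "no acute disease . heart size normal ."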

    def parse(self, features):
        to_return = {'id': features['id']}
        report = features.get("report", "")
        to_return['input_text'] = self.clean_report(report)
        # chest X-ray images (a study may contain several views)
        images = []
        for image_path in features['image_path']:
            with Image.open(os.path.join(self.args.base_dir, image_path)) as pil:
                array = np.array(pil, dtype=np.uint8)
                # force 3-channel RGB (handles grayscale and RGBA inputs)
                if array.ndim != 3 or array.shape[-1] != 3:
                    array = np.array(pil.convert("RGB"), dtype=np.uint8)
                images.append(self._parse_image(array))
        to_return["image"] = images
        return to_return

def transform_with_parse(self, inputs):
return self.parse(inputs)
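

# Each annotation record is assumed (inferred from the accesses in
# FieldParser.parse) to look like:
#   {"id": "...", "report": "...", "image_path": ["view1.png", "view2.png"]}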
class ParseDataset(data.Dataset):
    def __init__(self, args, split='train'):
        self.args = args
        # the annotation JSON maps each split name to a list of records
        with open(args.annotation, 'r') as f:
            self.meta = json.load(f)[split]
        self.parser = FieldParser(args)

    def __len__(self):
        return len(self.meta)

    def __getitem__(self, index):
        return self.parser.transform_with_parse(self.meta[index])

def create_datasets(args):
train_dataset = ParseDataset(args, 'train')
dev_dataset = ParseDataset(args, 'val')
test_dataset = ParseDataset(args, 'test')
return train_dataset, dev_dataset, test_dataset
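

# Minimal usage sketch, illustrative only: the argument names mirror the
# attributes accessed above (dataset, annotation, base_dir, vision_model);
# the file paths and the ViT checkpoint are assumptions, not fixed by this module.
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        dataset="iu_xray",
        annotation="path/to/annotation.json",  # JSON keyed by 'train'/'val'/'test'
        base_dir="path/to/images",
        vision_model="google/vit-base-patch16-224-in21k",  # assumed checkpoint
    )
    train_ds, dev_ds, test_ds = create_datasets(args)
    sample = train_ds[0]
    print(sample["id"], sample["input_text"], len(sample["image"]))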