import json
import os

import torch
from PIL import Image
from torch.utils.data import Dataset


class BaseDataset(Dataset):
    def __init__(self, args, tokenizer, split, processor):
        self.image_dir = args.image_dir
        self.ann_path = args.json_path
        self.max_seq_length = args.max_seq_length
        self.split = split
        self.tokenizer = tokenizer
        with open(self.ann_path, 'r') as f:
            self.ann = json.load(f)
        self.examples = self.ann[self.split]
        self.processor = processor
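
    # The annotation file is expected to look roughly like the sketch below
    # (inferred from the access patterns in this module; real files may carry
    # additional keys per example):
    # {
    #   "train": [{"id": "...", "report": "...", "image_path": ["...", "..."]}, ...],
    #   "val":   [...],
    #   "test":  [...]
    # }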

    def preprocess_text(self, text):
        # Ids/mask from the project tokenizer, truncated to max_seq_length.
        ids = self.tokenizer(text)[:self.max_seq_length]
        mask = [1] * len(ids)
        # Parallel encoding of the same report with the CLIP-style
        # processor's own tokenizer.
        text_inputs = self.processor(text=text, return_tensors="pt",
                                     truncation=True, padding=False,
                                     max_length=self.max_seq_length)
        processor_ids = text_inputs['input_ids'].squeeze(0).tolist()
        processor_mask = text_inputs['attention_mask'].squeeze(0).tolist()
        return ids, mask, processor_ids, processor_mask

    def __len__(self):
        return len(self.examples)


class IuxrayMultiImageDataset(BaseDataset):
    def __getitem__(self, idx):
        example = self.examples[idx]
        report = example['report']
        report_ids, report_masks, processor_ids, processor_mask = self.preprocess_text(report)

        image_id = example['id']
        image_path = example['image_path']
        image_1 = Image.open(os.path.join(self.image_dir, image_path[0])).convert('RGB')
        image_2 = Image.open(os.path.join(self.image_dir, image_path[1])).convert('RGB')
        # MedCLIP processing: encode both views, drop the batch dimension,
        # and stack them along a new leading axis -> (2, C, H, W).
        image_inputs_1 = self.processor(images=image_1, return_tensors="pt")
        image_inputs_2 = self.processor(images=image_2, return_tensors="pt")
        image = torch.stack((image_inputs_1.pixel_values[0], image_inputs_2.pixel_values[0]), 0)

        seq_length = len(report_ids)
        processor_length = len(processor_ids)
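        # sample layout: (id, stacked images, tokenizer ids/mask,
        # processor ids/mask, and the two sequence lengths).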
        sample = (image_id, image, report_ids, report_masks, processor_ids, processor_mask, seq_length, processor_length)
        return sample


class MimiccxrSingleImageDataset(BaseDataset):
    def __getitem__(self, idx):
        example = self.examples[idx]
        report = example['report']
        report_ids, report_masks, processor_ids, processor_mask = self.preprocess_text(report)

        image_id = example['id']
        image_path = example['image_path']
        image = Image.open(os.path.join(self.image_dir, image_path[0])).convert('RGB')
        image_inputs = self.processor(images=image, return_tensors="pt")
        image = image_inputs.pixel_values[0]

        seq_length = len(report_ids)
        processor_length = len(processor_ids)
        sample = (image_id, image, report_ids, report_masks, processor_ids, processor_mask, seq_length, processor_length)
        return sample
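

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module). Assumptions: the
# annotation JSON sketched above exists at `json_path`, a CLIP-style Hugging
# Face processor stands in for the MedCLIP processor, and the project's own
# report tokenizer is approximated by a whitespace splitter. All paths and
# the checkpoint name are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from argparse import Namespace
    from transformers import CLIPProcessor

    args = Namespace(image_dir='data/iu_xray/images',           # placeholder
                     json_path='data/iu_xray/annotation.json',  # placeholder
                     max_seq_length=60)

    def toy_tokenizer(text):
        # Stand-in for the project tokenizer: one dummy id per word, so the
        # dataset can be smoke-tested without the real vocabulary.
        return [1] * len(text.split())

    processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')
    dataset = IuxrayMultiImageDataset(args, toy_tokenizer, 'train', processor)
    image_id, image, *rest = dataset[0]
    print(image_id, image.shape)  # two stacked views: (2, 3, 224, 224)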