import json
import re
from collections import Counter


class Tokenizer:
    def __init__(self, args):
        self.ann_path = args.ann_path
        self.threshold = args.threshold
        self.dataset_name = args.dataset_name
        # Each dataset ships reports in a slightly different format, so pick
        # the matching cleaning routine.
        if self.dataset_name == 'iu_xray':
            self.clean_report = self.clean_report_iu_xray
        else:
            self.clean_report = self.clean_report_mimic_cxr
        with open(self.ann_path, 'r') as f:
            self.ann = json.load(f)
        self.token2idx, self.idx2token = self.create_vocabulary()

    def create_vocabulary(self):
        # Collect every token that appears in the cleaned training reports.
        total_tokens = []
        for example in self.ann['train']:
            tokens = self.clean_report(example['report']).split()
            total_tokens.extend(tokens)

        counter = Counter(total_tokens)
        # Keep tokens that occur at least `threshold` times; everything else
        # falls back to the unknown token '<unk>'.
        vocab = [k for k, v in counter.items() if v >= self.threshold] + ['<unk>']
        vocab.sort()
        # Real tokens are indexed from 1; id 0 is reserved for the sequence
        # boundary marker used in __call__ and decode.
        token2idx, idx2token = {}, {}
        for idx, token in enumerate(vocab):
            token2idx[token] = idx + 1
            idx2token[idx + 1] = token
        return token2idx, idx2token

    def clean_report_iu_xray(self, report):
        # Collapse repeated periods, drop enumeration markers ("1.", "2.", ...)
        # and split the report into sentences.
        report_cleaner = lambda t: t.replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '') \
            .replace('. 2. ', '. ').replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ') \
            .replace(' 2. ', '. ').replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \
            .strip().lower().split('. ')
        # Strip punctuation and quote characters from each sentence.
        sent_cleaner = lambda t: re.sub(r'[.,?;*!%^&_+():-\[\]{}]', '',
                                        t.replace('"', '').replace('/', '')
                                        .replace('\\', '').replace("'", '').strip().lower())
        tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != '']
        report = ' . '.join(tokens) + ' .'
        return report

    def clean_report_mimic_cxr(self, report):
        # MIMIC-CXR reports additionally contain newlines, de-identification
        # underscores and runs of spaces; normalise those before splitting.
        report_cleaner = lambda t: t.replace('\n', ' ').replace('__', '_').replace('__', '_').replace('__', '_') \
            .replace('__', '_').replace('__', '_').replace('__', '_').replace('__', '_').replace('  ', ' ') \
            .replace('  ', ' ').replace('  ', ' ').replace('  ', ' ').replace('  ', ' ').replace('  ', ' ') \
            .replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.') \
            .replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '').replace('. 2. ', '. ') \
            .replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ').replace(' 2. ', '. ') \
            .replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \
            .strip().lower().split('. ')
        sent_cleaner = lambda t: re.sub(r'[.,?;*!%^&_+():-\[\]{}]', '',
                                        t.replace('"', '').replace('/', '')
                                        .replace('\\', '').replace("'", '').strip().lower())
        tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != '']
        report = ' . '.join(tokens) + ' .'
        return report

    def get_token_by_id(self, id):
        return self.idx2token[id]

    def get_id_by_token(self, token):
        # Out-of-vocabulary tokens map to the unknown token.
        if token not in self.token2idx:
            return self.token2idx['<unk>']
        return self.token2idx[token]

    def get_vocab_size(self):
        return len(self.token2idx)

    def __call__(self, report):
        # Clean, tokenise and convert a raw report to ids, padded with the
        # boundary id 0 on both sides.
        tokens = self.clean_report(report).split()
        ids = [self.get_id_by_token(token) for token in tokens]
        ids = [0] + ids + [0]
        return ids

    def decode(self, ids):
        # Convert ids back to text, stopping at the first 0 (sequence end).
        txt = ''
        for i, idx in enumerate(ids):
            if idx > 0:
                if i >= 1:
                    txt += ' '
                txt += self.idx2token[idx]
            else:
                break
        return txt

    def decode_batch(self, ids_batch):
        return [self.decode(ids) for ids in ids_batch]
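

# Minimal usage sketch (not part of the original module): it writes a tiny
# annotation file with the structure create_vocabulary() expects (a dict with
# a 'train' split whose entries carry a 'report' field) and a stand-in args
# namespace; the real project supplies both via its argparse configuration,
# so the names and sample reports below are illustrative only.
if __name__ == '__main__':
    import tempfile
    from types import SimpleNamespace

    ann = {'train': [
        {'report': 'The heart is normal in size. The lungs are clear.'},
        {'report': 'No acute cardiopulmonary abnormality. The lungs are clear.'},
    ]}
    with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False) as f:
        json.dump(ann, f)
        ann_path = f.name

    args = SimpleNamespace(ann_path=ann_path, threshold=1, dataset_name='iu_xray')
    tokenizer = Tokenizer(args)

    ids = tokenizer('The heart is normal in size.')
    print(ids)                        # e.g. [0, ..., 0] with token ids between the boundary 0s
    print(tokenizer.decode(ids[1:]))  # skip the leading 0, otherwise decode stops immediately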