# PromptNet/modules/tokenizers.py

import json
import re
from collections import Counter


class Tokenizer(object):
    def __init__(self, args):
        self.ann_path = args.ann_path
        self.threshold = args.threshold  # minimum token frequency for a word to enter the vocabulary
        self.dataset_name = args.dataset_name
        # Each dataset uses its own report-cleaning routine.
        if self.dataset_name == 'iu_xray':
            self.clean_report = self.clean_report_iu_xray
        else:
            self.clean_report = self.clean_report_mimic_cxr
        # The annotation file is a JSON dict whose 'train' split holds examples
        # carrying a 'report' string.
        with open(self.ann_path, 'r') as f:
            self.ann = json.load(f)
        self.token2idx, self.idx2token = self.create_vocabulary()
    def create_vocabulary(self):
        # Collect every token appearing in the training reports.
        total_tokens = []
        for example in self.ann['train']:
            tokens = self.clean_report(example['report']).split()
            for token in tokens:
                total_tokens.append(token)

        # Keep tokens that occur at least `threshold` times; everything else maps to '<unk>'.
        counter = Counter(total_tokens)
        vocab = [k for k, v in counter.items() if v >= self.threshold] + ['<unk>']
        vocab.sort()

        # Ids start at 1; id 0 is reserved as the begin/end-of-sequence marker.
        token2idx, idx2token = {}, {}
        for idx, token in enumerate(vocab):
            token2idx[token] = idx + 1
            idx2token[idx + 1] = token
        return token2idx, idx2token
    def clean_report_iu_xray(self, report):
        # Collapse repeated periods, drop the leading "1." list marker, and turn the
        # "2."-"5." markers into sentence boundaries before splitting into sentences.
        report_cleaner = lambda t: t.replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '') \
            .replace('. 2. ', '. ').replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ') \
            .replace(' 2. ', '. ').replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \
            .strip().lower().split('. ')
        # Strip punctuation and quote characters from each sentence.
        sent_cleaner = lambda t: re.sub(r'[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', '')
                                        .replace('\\', '').replace("'", '').strip().lower())
        # Re-join the cleaned, non-empty sentences with an explicit ' . ' separator.
        tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != '']
        report = ' . '.join(tokens) + ' .'
        return report
    def clean_report_mimic_cxr(self, report):
        # Normalize whitespace, underscores, and repeated periods (the '__' -> '_', '  ' -> ' ',
        # and '..' -> '.' replacements are applied repeatedly to collapse longer runs), drop
        # numbered-list markers, and split the report into sentences.
        report_cleaner = lambda t: t.replace('\n', ' ').replace('__', '_').replace('__', '_').replace('__', '_') \
            .replace('__', '_').replace('__', '_').replace('__', '_').replace('__', '_').replace('  ', ' ') \
            .replace('  ', ' ').replace('  ', ' ').replace('  ', ' ').replace('  ', ' ').replace('  ', ' ') \
            .replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.') \
            .replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '').replace('. 2. ', '. ') \
            .replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ').replace(' 2. ', '. ') \
            .replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \
            .strip().lower().split('. ')
        # Strip punctuation and quote characters from each sentence.
        sent_cleaner = lambda t: re.sub(r'[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', '')
                                        .replace('\\', '').replace("'", '').strip().lower())
        # Re-join the cleaned, non-empty sentences with an explicit ' . ' separator.
        tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != '']
        report = ' . '.join(tokens) + ' .'
        return report
    def get_token_by_id(self, id):
        return self.idx2token[id]

    def get_id_by_token(self, token):
        # Out-of-vocabulary words fall back to the '<unk>' id.
        if token not in self.token2idx:
            return self.token2idx['<unk>']
        return self.token2idx[token]

    def get_vocab_size(self):
        return len(self.token2idx)
    def __call__(self, report):
        # Clean the report, map each token to its id, and wrap the sequence
        # with id 0 as the begin/end-of-sequence marker.
        tokens = self.clean_report(report).split()
        ids = []
        for token in tokens:
            ids.append(self.get_id_by_token(token))
        ids = [0] + ids + [0]
        return ids
    def decode(self, ids):
        # Convert ids back to text, stopping at the first 0 (end-of-sequence).
        txt = ''
        for i, idx in enumerate(ids):
            if idx > 0:
                if i >= 1:
                    txt += ' '
                txt += self.idx2token[idx]
            else:
                break
        return txt

    def decode_batch(self, ids_batch):
        out = []
        for ids in ids_batch:
            out.append(self.decode(ids))
        return out
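

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal, hedged example of
# how this Tokenizer is typically driven. The argparse.Namespace fields and the
# 'example_annotation.json' path below are assumptions for illustration only;
# in the project they come from the training script and the dataset annotation
# file, which must contain at least a 'train' split whose items have a
# 'report' string.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import argparse

    args = argparse.Namespace(
        ann_path='example_annotation.json',  # hypothetical annotation file
        threshold=3,                         # keep tokens seen at least 3 times
        dataset_name='iu_xray',              # selects clean_report_iu_xray
    )
    tokenizer = Tokenizer(args)

    report = 'The heart size is normal. No focal consolidation.'
    ids = tokenizer(report)            # [0, ...token ids..., 0]
    # decode() stops at the first 0, so skip the leading begin-of-sequence id.
    print(tokenizer.decode(ids[1:]))   # cleaned, lower-cased report text
    print(tokenizer.get_vocab_size())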