import json
import re
from collections import Counter


class Tokenizer(object):
    """Builds a word-level vocabulary from radiology reports and maps tokens to ids."""

    def __init__(self, args):
        self.ann_path = args.ann_path
        self.threshold = args.threshold
        self.dataset_name = args.dataset_name
        # Each dataset uses its own report-cleaning routine.
        if self.dataset_name == 'iu_xray':
            self.clean_report = self.clean_report_iu_xray
        else:
            self.clean_report = self.clean_report_mimic_cxr
        with open(self.ann_path, 'r') as f:
            self.ann = json.loads(f.read())
        self.token2idx, self.idx2token = self.create_vocabulary()

    def create_vocabulary(self):
        """Count tokens in the training reports and keep those above the frequency threshold."""
        total_tokens = []
        for example in self.ann['train']:
            tokens = self.clean_report(example['report']).split()
            for token in tokens:
                total_tokens.append(token)

        counter = Counter(total_tokens)
        # Keep tokens seen at least `threshold` times; everything else maps to '<unk>'.
        vocab = [k for k, v in counter.items() if v >= self.threshold] + ['<unk>']
        vocab.sort()
        token2idx, idx2token = {}, {}
        for idx, token in enumerate(vocab):
            # Ids start at 1; id 0 is reserved as the start/end/padding marker.
            token2idx[token] = idx + 1
            idx2token[idx + 1] = token
        return token2idx, idx2token

    def clean_report_iu_xray(self, report):
        """Normalize an IU X-Ray report: strip list numbering, punctuation and extra periods."""
        report_cleaner = lambda t: t.replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '') \
            .replace('. 2. ', '. ').replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ') \
            .replace(' 2. ', '. ').replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \
            .strip().lower().split('. ')
        sent_cleaner = lambda t: re.sub(r'[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', '')
                                        .replace('\\', '').replace("'", '').strip().lower())
        # Drop sentences that are empty after cleaning.
        tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != '']
        report = ' . '.join(tokens) + ' .'
        return report

    def clean_report_mimic_cxr(self, report):
        """Normalize a MIMIC-CXR report: collapse whitespace, strip list numbering and punctuation."""
        report_cleaner = lambda t: t.replace('\n', ' ').replace('__', '_').replace('__', '_').replace('__', '_') \
            .replace('__', '_').replace('__', '_').replace('__', '_').replace('__', '_').replace('  ', ' ') \
            .replace('  ', ' ').replace('  ', ' ').replace('  ', ' ').replace('  ', ' ').replace('  ', ' ') \
            .replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.') \
            .replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '').replace('. 2. ', '. ') \
            .replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ').replace(' 2. ', '. ') \
            .replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \
            .strip().lower().split('. ')
        sent_cleaner = lambda t: re.sub(r'[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', '')
                                        .replace('\\', '').replace("'", '').strip().lower())
        # Drop sentences that are empty after cleaning.
        tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != '']
        report = ' . '.join(tokens) + ' .'
        return report

    def get_token_by_id(self, id):
        return self.idx2token[id]

    def get_id_by_token(self, token):
        if token not in self.token2idx:
            return self.token2idx['<unk>']
        return self.token2idx[token]

    def get_vocab_size(self):
        return len(self.token2idx)

    def __call__(self, report):
        """Clean a report and convert it to token ids, wrapped in 0 (start/end) markers."""
        tokens = self.clean_report(report).split()
        ids = []
        for token in tokens:
            ids.append(self.get_id_by_token(token))
        ids = [0] + ids + [0]
        return ids

    def decode(self, ids):
        """Map ids back to text, stopping at the first 0 (end-of-sequence/padding) id."""
        txt = ''
        for i, idx in enumerate(ids):
            if idx > 0:
                if i >= 1:
                    txt += ' '
                txt += self.idx2token[idx]
            else:
                break
        return txt

    def decode_batch(self, ids_batch):
        out = []
        for ids in ids_batch:
            out.append(self.decode(ids))
        return out
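

# A minimal usage sketch (not part of the original module): the Tokenizer only needs an
# object exposing `ann_path`, `threshold` and `dataset_name`, so a SimpleNamespace works.
# The annotation path below is a hypothetical placeholder for a JSON file containing a
# 'train' split of {'report': ...} examples.
if __name__ == '__main__':
    from types import SimpleNamespace

    args = SimpleNamespace(ann_path='data/iu_xray/annotation.json',  # hypothetical path
                           threshold=3,
                           dataset_name='iu_xray')
    tokenizer = Tokenizer(args)
    ids = tokenizer('The heart is normal in size. No acute disease.')
    print(ids)                         # e.g. [0, id_1, id_2, ..., 0]
    print(tokenizer.decode(ids[1:]))   # cleaned, lowercased report text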