# coding=utf-8
# Copyright 2020 Google and DeepMind.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
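#
# Dataset preprocessing utilities for cross-lingual benchmark tasks: PAN-X/WikiAnn
# NER, Universal Dependencies POS tagging, PAWS-X, XNLI, Tatoeba retrieval, and
# the QA datasets (XQuAD, MLQA, TyDiQA-GoldP). The --task flag selects which
# <task>_preprocess entry point below is run.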
from __future__ import absolute_import, division, print_function

import argparse
import csv
import json
import os
import shutil
from collections import defaultdict

from transformers import BertTokenizer, XLMTokenizer, XLMRobertaTokenizer
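# Maps the --model_type argument to the corresponding Hugging Face tokenizer class.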
TOKENIZERS = {
'bert': BertTokenizer,
'xlm': XLMTokenizer,
'xlmr': XLMRobertaTokenizer,
}
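# Re-tokenizes the PAN-X NER TSVs with the pre-trained model's subword tokenizer:
# sentences whose subword count would exceed --max_len (minus the special tokens)
# are split, over-long or untokenizable tokens are replaced by the unknown token,
# and a parallel .idx file maps every output token line back to the index of its
# original sentence.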
def panx_tokenize_preprocess(args):
  def _preprocess_one_file(infile, outfile, idxfile, tokenizer, max_len):
    if not os.path.exists(infile):
      print(f'{infile} does not exist')
      return 0
    special_tokens_count = 3 if isinstance(tokenizer, XLMRobertaTokenizer) else 2
    max_seq_len = max_len - special_tokens_count
    subword_len_counter = idx = 0
    with open(infile, "rt") as fin, open(outfile, "w") as fout, open(idxfile, "w") as fidx:
      for line in fin:
        line = line.strip()
        if not line:
          fout.write('\n')
          fidx.write('\n')
          idx += 1
          subword_len_counter = 0
          continue
        items = line.split()
        token = items[0].strip()
        if len(items) == 2:
          label = items[1].strip()
        else:
          label = 'O'
        current_subwords_len = len(tokenizer.tokenize(token))
        if (current_subwords_len == 0 or current_subwords_len > max_seq_len) and len(token) != 0:
          token = tokenizer.unk_token
          current_subwords_len = 1
        if (subword_len_counter + current_subwords_len) > max_seq_len:
          fout.write(f"\n{token}\t{label}\n")
          fidx.write(f"\n{idx}\n")
          subword_len_counter = current_subwords_len
        else:
          fout.write(f"{token}\t{label}\n")
          fidx.write(f"{idx}\n")
          subword_len_counter += current_subwords_len
    return 1

  model_type = args.model_type
  tokenizer = TOKENIZERS[model_type].from_pretrained(
      args.model_name_or_path,
      do_lower_case=args.do_lower_case,
      cache_dir=args.cache_dir if args.cache_dir else None)
  for lang in args.languages.split(','):
    out_dir = os.path.join(args.output_dir, lang)
    if not os.path.exists(out_dir):
      os.makedirs(out_dir)
    if lang == 'en':
      files = ['dev', 'test', 'train']
    else:
      files = ['dev', 'test']
    for file in files:
      infile = os.path.join(args.data_dir, f'{file}-{lang}.tsv')
      outfile = os.path.join(out_dir, "{}.{}".format(file, args.model_name_or_path))
      idxfile = os.path.join(out_dir, "{}.{}.idx".format(file, args.model_name_or_path))
      if os.path.exists(outfile) and os.path.exists(idxfile):
        print(f'{outfile} and {idxfile} exist')
      else:
        code = _preprocess_one_file(infile, outfile, idxfile, tokenizer, args.max_len)
        if code > 0:
          print(f'finish preprocessing {outfile}')
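# Converts the raw PAN-X/WikiAnn files (one "lang:token<TAB>label" pair per line)
# into plain "token<TAB>label" TSVs, one file per language and split.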
def panx_preprocess(args):
  def _process_one_file(infile, outfile):
    with open(infile, 'r') as fin:
      lines = fin.readlines()
    if lines and lines[-1].strip() == '':
      lines = lines[:-1]
    with open(outfile, 'w') as fout:
      for l in lines:
        items = l.strip().split('\t')
        if len(items) == 2:
          label = items[1].strip()
          idx = items[0].find(':')
          if idx != -1:
            token = items[0][idx+1:].strip()
            # if 'test' in infile:
            #   fout.write(f'{token}\n')
            # else:
            #   fout.write(f'{token}\t{label}\n')
            fout.write(f'{token}\t{label}\n')
        else:
          fout.write('\n')

  if not os.path.exists(args.output_dir):
    os.makedirs(args.output_dir)
  langs = 'ar he vi id jv ms tl eu ml ta te af nl en de el bn hi mr ur fa fr it pt es bg ru ja ka ko th sw yo my zh kk tr et fi hu'.split(' ')
  for lg in langs:
    for split in ['train', 'test', 'dev']:
      infile = os.path.join(args.data_dir, f'{lg}-{split}')
      outfile = os.path.join(args.output_dir, f'{split}-{lg}.tsv')
      _process_one_file(infile, outfile)
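# Same subword-length splitting as panx_tokenize_preprocess, applied to the
# Universal Dependencies POS TSVs; tokens without a tag are labelled "X".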
def udpos_tokenize_preprocess(args):
  def _preprocess_one_file(infile, outfile, idxfile, tokenizer, max_len):
    if not os.path.exists(infile):
      print(f'{infile} does not exist')
      return
    subword_len_counter = idx = 0
    special_tokens_count = 3 if isinstance(tokenizer, XLMRobertaTokenizer) else 2
    max_seq_len = max_len - special_tokens_count
    with open(infile, "rt") as fin, open(outfile, "w") as fout, open(idxfile, "w") as fidx:
      for line in fin:
        line = line.strip()
        if not line:
          fout.write('\n')
          fidx.write('\n')
          idx += 1
          subword_len_counter = 0
          continue
        items = line.split()
        if len(items) == 2:
          label = items[1].strip()
        else:
          label = "X"
        token = items[0].strip()
        current_subwords_len = len(tokenizer.tokenize(token))
        if (current_subwords_len == 0 or current_subwords_len > max_seq_len) and len(token) != 0:
          token = tokenizer.unk_token
          current_subwords_len = 1
        if (subword_len_counter + current_subwords_len) > max_seq_len:
          fout.write(f"\n{token}\t{label}\n")
          fidx.write(f"\n{idx}\n")
          subword_len_counter = current_subwords_len
        else:
          fout.write(f"{token}\t{label}\n")
          fidx.write(f"{idx}\n")
          subword_len_counter += current_subwords_len

  model_type = args.model_type
  tokenizer = TOKENIZERS[model_type].from_pretrained(
      args.model_name_or_path,
      do_lower_case=args.do_lower_case,
      cache_dir=args.cache_dir if args.cache_dir else None)
  for lang in args.languages.split(','):
    out_dir = os.path.join(args.output_dir, lang)
    if not os.path.exists(out_dir):
      os.makedirs(out_dir)
    if lang == 'en':
      files = ['dev', 'test', 'train']
    else:
      files = ['dev', 'test']
    for file in files:
      infile = os.path.join(args.data_dir, "{}-{}.tsv".format(file, lang))
      outfile = os.path.join(out_dir, "{}.{}".format(file, args.model_name_or_path))
      idxfile = os.path.join(out_dir, "{}.{}.idx".format(file, args.model_name_or_path))
      if os.path.exists(outfile) and os.path.exists(idxfile):
        print(f'{outfile} and {idxfile} exist')
      else:
        _preprocess_one_file(infile, outfile, idxfile, tokenizer, args.max_len)
        print(f'finish preprocessing {outfile}')
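# Reads the per-language CoNLL-U treebanks, keeps the surface form (column 2) and
# UPOS tag (column 4) of each token, drops sentences that consist almost entirely
# of "_" placeholders, strips zero-width non-joiners and internal spaces, and
# writes one "token<TAB>tag" TSV per split.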
def udpos_preprocess(args):
def _read_one_file(file):
data = []
sent, tag, lines = [], [], []
for line in open(file, 'r'):
items = line.strip().split('\t')
if len(items) != 10:
empty = all(w == '_' for w in sent)
num_empty = sum([int(w == '_') for w in sent])
if num_empty == 0 or num_empty < len(sent) - 1:
data.append((sent, tag, lines))
sent, tag, lines = [], [], []
else:
sent.append(items[1].strip())
tag.append(items[3].strip())
lines.append(line.strip())
assert len(sent) == int(items[0]), 'line={}, sent={}, tag={}'.format(line, sent, tag)
return data
def isfloat(value):
try:
float(value)
return True
except ValueError:
return False
def remove_empty_space(data):
new_data = {}
for split in data:
new_data[split] = []
for sent, tag, lines in data[split]:
new_sent = [''.join(w.replace('\u200c', '').split(' ')) for w in sent]
lines = [line.replace('\u200c', '') for line in lines]
assert len(" ".join(new_sent).split(' ')) == len(tag)
new_data[split].append((new_sent, tag, lines))
return new_data
def check_file(file):
for i, l in enumerate(open(file)):
items = l.strip().split('\t')
assert len(items[0].split(' ')) == len(items[1].split(' ')), 'idx={}, line={}'.format(i, l)
def _write_files(data, output_dir, lang, suffix):
for split in data:
if len(data[split]) > 0:
prefix = os.path.join(output_dir, f'{split}-{lang}')
if suffix == 'mt':
with open(prefix + '.mt.tsv', 'w') as fout:
for idx, (sent, tag, _) in enumerate(data[split]):
newline = '\n' if idx != len(data[split]) - 1 else ''
# if split == 'test':
# fout.write('{}{}'.format(' '.join(sent, newline)))
# else:
# fout.write('{}\t{}{}'.format(' '.join(sent), ' '.join(tag), newline))
fout.write('{}\t{}{}'.format(' '.join(sent), ' '.join(tag), newline))
check_file(prefix + '.mt.tsv')
print(' - finish checking ' + prefix + '.mt.tsv')
elif suffix == 'tsv':
with open(prefix + '.tsv', 'w') as fout:
for sidx, (sent, tag, _) in enumerate(data[split]):
for widx, (w, t) in enumerate(zip(sent, tag)):
newline = '' if (sidx == len(data[split]) - 1) and (widx == len(sent) - 1) else '\n'
# if split == 'test':
# fout.write('{}{}'.format(w, newline))
# else:
# fout.write('{}\t{}{}'.format(w, t, newline))
fout.write('{}\t{}{}'.format(w, t, newline))
fout.write('\n')
elif suffix == 'conll':
with open(prefix + '.conll', 'w') as fout:
for _, _, lines in data[split]:
for l in lines:
fout.write(l.strip() + '\n')
fout.write('\n')
print(f'finish writing file to {prefix}.{suffix}')
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
languages = 'af ar bg de el en es et eu fa fi fr he hi hu id it ja kk ko mr nl pt ru ta te th tl tr ur vi yo zh'.split(' ')
for root, dirs, files in os.walk(args.data_dir):
lg = root.strip().split('/')[-1]
if root == args.data_dir or lg not in languages:
continue
data = {k: [] for k in ['train', 'dev', 'test']}
for f in sorted(files):
if f.endswith('conll'):
file = os.path.join(root, f)
examples = _read_one_file(file)
if 'train' in f:
data['train'].extend(examples)
elif 'dev' in f:
data['dev'].extend(examples)
elif 'test' in f:
data['test'].extend(examples)
else:
print('split not found: ', file)
print(' - finish reading {}, {}'.format(file, [(k, len(v)) for k,v in data.items()]))
data = remove_empty_space(data)
for sub in ['tsv']:
_write_files(data, args.output_dir, lg, sub)
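# PAWS-X: skips the header row of each file and writes "sent1<TAB>sent2<TAB>label"
# TSVs per language and split; only English has a training split.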
def pawsx_preprocess(args):
def _preprocess_one_file(infile, outfile, remove_label=False):
data = []
for i, line in enumerate(open(infile, 'r')):
if i == 0:
continue
items = line.strip().split('\t')
sent1 = ' '.join(items[1].strip().split(' '))
sent2 = ' '.join(items[2].strip().split(' '))
label = items[3]
data.append([sent1, sent2, label])
with open(outfile, 'w') as fout:
writer = csv.writer(fout, delimiter='\t', quoting=csv.QUOTE_NONE, quotechar='')
for sent1, sent2, label in data:
# if remove_label:
# writer.writerow([sent1, sent2])
# else:
# writer.writerow([sent1, sent2, label])
writer.writerow([sent1, sent2, label])
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
split2file = {'train': 'train', 'test': 'test_2k', 'dev': 'dev_2k'}
for lang in ['en', 'de', 'es', 'fr', 'ja', 'ko', 'zh']:
for split in ['train', 'test', 'dev']:
if split == 'train' and lang != 'en':
continue
file = split2file[split]
infile = os.path.join(args.data_dir, lang, "{}.tsv".format(file))
outfile = os.path.join(args.output_dir, "{}-{}.tsv".format(split, lang))
_preprocess_one_file(infile, outfile, remove_label=(split == 'test'))
print(f'finish preprocessing {outfile}')
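# XNLI: builds train-en.tsv from the English MultiNLI training file and one
# "{split}-{lang}.tsv" per language from the XNLI dev/test files, normalising the
# "contradictory" label to "contradiction".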
def xnli_preprocess(args):
def _preprocess_file(infile, output_dir, split):
all_langs = defaultdict(list)
for i, line in enumerate(open(infile, 'r')):
if i == 0:
continue
items = line.strip().split('\t')
lang = items[0].strip()
label = "contradiction" if items[1].strip() == "contradictory" else items[1].strip()
sent1 = ' '.join(items[6].strip().split(' '))
sent2 = ' '.join(items[7].strip().split(' '))
all_langs[lang].append((sent1, sent2, label))
print(f'# langs={len(all_langs)}')
for lang, pairs in all_langs.items():
outfile = os.path.join(output_dir, '{}-{}.tsv'.format(split, lang))
with open(outfile, 'w') as fout:
writer = csv.writer(fout, delimiter='\t', quoting=csv.QUOTE_NONE, quotechar='')
for (sent1, sent2, label) in pairs:
# if split == 'test':
# writer.writerow([sent1, sent2])
# else:
# writer.writerow([sent1, sent2, label])
writer.writerow([sent1, sent2, label])
print(f'finish preprocess {outfile}')
def _preprocess_train_file(infile, outfile):
with open(outfile, 'w') as fout:
writer = csv.writer(fout, delimiter='\t', quoting=csv.QUOTE_NONE, quotechar='')
for i, line in enumerate(open(infile, 'r')):
if i == 0:
continue
items = line.strip().split('\t')
sent1 = ' '.join(items[0].strip().split(' '))
sent2 = ' '.join(items[1].strip().split(' '))
label = "contradiction" if items[2].strip() == "contradictory" else items[2].strip()
writer.writerow([sent1, sent2, label])
print(f'finish preprocess {outfile}')
infile = os.path.join(args.data_dir, 'XNLI-MT-1.0/multinli/multinli.train.en.tsv')
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
outfile = os.path.join(args.output_dir, 'train-en.tsv')
_preprocess_train_file(infile, outfile)
for split in ['test', 'dev']:
infile = os.path.join(args.data_dir, 'XNLI-1.0/xnli.{}.tsv'.format(split))
print(f'reading file {infile}')
_preprocess_file(infile, args.output_dir, split)
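# Tatoeba: copies each source-language file under its two-letter ISO code and
# writes the corresponding English side (sorted by sentence text) next to it.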
def tatoeba_preprocess(args):
lang3_dict = {
'afr':'af', 'ara':'ar', 'bul':'bg', 'ben':'bn',
'deu':'de', 'ell':'el', 'spa':'es', 'est':'et',
'eus':'eu', 'pes':'fa', 'fin':'fi', 'fra':'fr',
'heb':'he', 'hin':'hi', 'hun':'hu', 'ind':'id',
'ita':'it', 'jpn':'ja', 'jav':'jv', 'kat':'ka',
'kaz':'kk', 'kor':'ko', 'mal':'ml', 'mar':'mr',
'nld':'nl', 'por':'pt', 'rus':'ru', 'swh':'sw',
'tam':'ta', 'tel':'te', 'tha':'th', 'tgl':'tl',
'tur':'tr', 'urd':'ur', 'vie':'vi', 'cmn':'zh',
'eng':'en',
}
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
for sl3, sl2 in lang3_dict.items():
if sl3 != 'eng':
src_file = f'{args.data_dir}/tatoeba.{sl3}-eng.{sl3}'
tgt_file = f'{args.data_dir}/tatoeba.{sl3}-eng.eng'
src_out = f'{args.output_dir}/{sl2}-en.{sl2}'
tgt_out = f'{args.output_dir}/{sl2}-en.en'
shutil.copy(src_file, src_out)
tgts = [l.strip() for l in open(tgt_file)]
idx = range(len(tgts))
data = zip(tgts, idx)
with open(tgt_out, 'w') as ftgt:
for t, i in sorted(data, key=lambda x: x[0]):
ftgt.write(f'{t}\n')
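# XQuAD and MLQA need no further preprocessing here; the two functions below are
# placeholders (the optional test-annotation stripping is commented out).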
def xquad_preprocess(args):
# Remove the test annotations to prevent accidental cheating
# remove_qa_test_annotations(args.data_dir)
pass
def mlqa_preprocess(args):
# Remove the test annotations to prevent accidental cheating
# remove_qa_test_annotations(args.data_dir)
pass
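# TyDiQA-GoldP: splits the combined training JSON into one SQuAD-format file per
# language and renames the per-language dev files to ISO language codes.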
def tydiqa_preprocess(args):
LANG2ISO = {'arabic': 'ar', 'bengali': 'bn', 'english': 'en', 'finnish': 'fi',
'indonesian': 'id', 'korean': 'ko', 'russian': 'ru',
'swahili': 'sw', 'telugu': 'te'}
assert os.path.exists(args.data_dir)
train_file = os.path.join(args.data_dir, 'tydiqa-goldp-v1.1-train.json')
os.makedirs(args.output_dir, exist_ok=True)
# Split the training file into language-specific files
lang2data = defaultdict(list)
with open(train_file, 'r') as f_in:
data = json.load(f_in)
version = data['version']
for doc in data['data']:
for par in doc['paragraphs']:
context = par['context']
for qa in par['qas']:
question = qa['question']
question_id = qa['id']
example_lang = question_id.split('-')[0]
q_id = question_id.split('-')[-1]
for answer in qa['answers']:
a_start, a_text = answer['answer_start'], answer['text']
a_end = a_start + len(a_text)
assert context[a_start:a_end] == a_text
lang2data[example_lang].append({'paragraphs': [{
'context': context,
'qas': [{'answers': qa['answers'],
'question': question,
'id': q_id}]}]})
for lang, data in lang2data.items():
out_file = os.path.join(
args.output_dir, 'tydiqa.%s.train.json' % LANG2ISO[lang])
with open(out_file, 'w') as f:
json.dump({'data': data, 'version': version}, f)
# Rename the dev files
dev_dir = os.path.join(args.data_dir, 'tydiqa-goldp-v1.1-dev')
assert os.path.exists(dev_dir)
for lang, iso in LANG2ISO.items():
src_file = os.path.join(dev_dir, 'tydiqa-goldp-dev-%s.json' % lang)
dst_file = os.path.join(dev_dir, 'tydiqa.%s.dev.json' % iso)
os.rename(src_file, dst_file)
# Remove the test annotations to prevent accidental cheating
# remove_qa_test_annotations(dev_dir)
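# Replaces every answer in the QA test JSONs with an empty span; only referenced
# from the commented-out calls above.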
def remove_qa_test_annotations(test_dir):
assert os.path.exists(test_dir)
for file_name in os.listdir(test_dir):
new_data = []
test_file = os.path.join(test_dir, file_name)
with open(test_file, 'r') as f:
data = json.load(f)
version = data['version']
for doc in data['data']:
for par in doc['paragraphs']:
context = par['context']
for qa in par['qas']:
question = qa['question']
question_id = qa['id']
for answer in qa['answers']:
a_start, a_text = answer['answer_start'], answer['text']
a_end = a_start + len(a_text)
assert context[a_start:a_end] == a_text
new_data.append({'paragraphs': [{
'context': context,
'qas': [{'answers': [{'answer_start': 0, 'text': ''}],
'question': question,
'id': question_id}]}]})
with open(test_file, 'w') as f:
json.dump({'data': new_data, 'version': version}, f)
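# Example invocations (script name and directory paths are illustrative only):
#   python utils_preprocess.py --task panx --data_dir download/panx --output_dir data/panx
#   python utils_preprocess.py --task panx_tokenize --data_dir data/panx --output_dir data/panx_tokenized \
#       --model_name_or_path bert-base-multilingual-cased --model_type bert --languages en,de,zh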
if __name__ == "__main__":
  parser = argparse.ArgumentParser()

  ## Required parameters
  parser.add_argument("--data_dir", default=None, type=str, required=True,
                      help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
  parser.add_argument("--output_dir", default=None, type=str, required=True,
                      help="The output data dir where the processed files will be written.")
  parser.add_argument("--task", default="panx", type=str, required=True,
                      help="The task to preprocess: panx, panx_tokenize, udpos, udpos_tokenize, pawsx, xnli, tatoeba, xquad, mlqa or tydiqa.")
  parser.add_argument("--model_name_or_path", default="bert-base-multilingual-cased", type=str,
                      help="The pre-trained model whose tokenizer is used by the *_tokenize tasks.")
  parser.add_argument("--model_type", default="bert", type=str,
                      help="Model type: bert, xlm or xlmr.")
  parser.add_argument("--max_len", default=512, type=int,
                      help="The maximum sequence length in subword tokens, including special tokens.")
  parser.add_argument("--do_lower_case", action='store_true',
                      help="Whether to lowercase the input when tokenizing.")
  parser.add_argument("--cache_dir", default=None, type=str,
                      help="Directory in which to cache the downloaded tokenizer files.")
  parser.add_argument("--languages", default="en", type=str,
                      help="Comma-separated list of languages to process (used by the *_tokenize tasks).")
  parser.add_argument("--remove_last_token", action='store_true',
                      help="Whether to remove the last token (currently unused).")
  parser.add_argument("--remove_test_label", action='store_true',
                      help="Whether to remove the test set labels (currently unused).")
  args = parser.parse_args()

  if args.task == 'panx_tokenize':
    panx_tokenize_preprocess(args)
  if args.task == 'panx':
    panx_preprocess(args)
  if args.task == 'udpos_tokenize':
    udpos_tokenize_preprocess(args)
  if args.task == 'udpos':
    udpos_preprocess(args)
  if args.task == 'pawsx':
    pawsx_preprocess(args)
  if args.task == 'xnli':
    xnli_preprocess(args)
  if args.task == 'tatoeba':
    tatoeba_preprocess(args)
  if args.task == 'xquad':
    xquad_preprocess(args)
  if args.task == 'mlqa':
    mlqa_preprocess(args)
  if args.task == 'tydiqa':
    tydiqa_preprocess(args)