# coding=utf-8
# Copyright 2020 Google and DeepMind.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function

import argparse
import csv
import json
import os
import shutil
from collections import defaultdict

from transformers import BertTokenizer, XLMTokenizer, XLMRobertaTokenizer

TOKENIZERS = {
    'bert': BertTokenizer,
    'xlm': XLMTokenizer,
    'xlmr': XLMRobertaTokenizer,
}


def panx_tokenize_preprocess(args):
    def _preprocess_one_file(infile, outfile, idxfile, tokenizer, max_len):
        if not os.path.exists(infile):
            print(f'{infile} does not exist')
            return 0
        special_tokens_count = 3 if isinstance(tokenizer, XLMRobertaTokenizer) else 2
        max_seq_len = max_len - special_tokens_count
        subword_len_counter = idx = 0
        with open(infile, "rt") as fin, open(outfile, "w") as fout, open(idxfile, "w") as fidx:
            for line in fin:
                line = line.strip()
                if not line:
                    fout.write('\n')
                    fidx.write('\n')
                    idx += 1
                    subword_len_counter = 0
                    continue
                items = line.split()
                token = items[0].strip()
                if len(items) == 2:
                    label = items[1].strip()
                else:
                    label = 'O'
                current_subwords_len = len(tokenizer.tokenize(token))
                if (current_subwords_len == 0 or current_subwords_len > max_seq_len) and len(token) != 0:
                    token = tokenizer.unk_token
                    current_subwords_len = 1
                if (subword_len_counter + current_subwords_len) > max_seq_len:
                    fout.write(f"\n{token}\t{label}\n")
                    fidx.write(f"\n{idx}\n")
                    subword_len_counter = current_subwords_len
                else:
                    fout.write(f"{token}\t{label}\n")
                    fidx.write(f"{idx}\n")
                    subword_len_counter += current_subwords_len
        return 1

    model_type = args.model_type
    tokenizer = TOKENIZERS[model_type].from_pretrained(
        args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None)
    for lang in args.languages.split(','):
        out_dir = os.path.join(args.output_dir, lang)
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        if lang == 'en':
            files = ['dev', 'test', 'train']
        else:
            files = ['dev', 'test']
        for file in files:
            infile = os.path.join(args.data_dir, f'{file}-{lang}.tsv')
            outfile = os.path.join(out_dir, "{}.{}".format(file, args.model_name_or_path))
            idxfile = os.path.join(out_dir, "{}.{}.idx".format(file, args.model_name_or_path))
            if os.path.exists(outfile) and os.path.exists(idxfile):
                print(f'{outfile} and {idxfile} exist')
            else:
                code = _preprocess_one_file(infile, outfile, idxfile, tokenizer, args.max_len)
                if code > 0:
                    print(f'finish preprocessing {outfile}')


def panx_preprocess(args):
    def _process_one_file(infile, outfile):
        lines = open(infile, 'r').readlines()
        if lines[-1].strip() == '':
            lines = lines[:-1]
        with open(outfile, 'w') as fout:
            for l in lines:
                items = l.strip().split('\t')
                if len(items) == 2:
                    label = items[1].strip()
                    idx = items[0].find(':')
                    if idx != -1:
                        token = items[0][idx + 1:].strip()
                        # if 'test' in infile:
                        #     fout.write(f'{token}\n')
                        # else:
                        #     fout.write(f'{token}\t{label}\n')
                        fout.write(f'{token}\t{label}\n')
                else:
                    fout.write('\n')

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    langs = 'ar he vi id jv ms tl eu ml ta te af nl en de el bn hi mr ur fa fr it pt es bg ru ja ka ko th sw yo my zh kk tr et fi hu'.split(' ')
    for lg in langs:
        for split in ['train', 'test', 'dev']:
            infile = os.path.join(args.data_dir, f'{lg}-{split}')
            outfile = os.path.join(args.output_dir, f'{split}-{lg}.tsv')
            _process_one_file(infile, outfile)


def udpos_tokenize_preprocess(args):
    def _preprocess_one_file(infile, outfile, idxfile, tokenizer, max_len):
        if not os.path.exists(infile):
            print(f'{infile} does not exist')
            return
        subword_len_counter = idx = 0
        special_tokens_count = 3 if isinstance(tokenizer, XLMRobertaTokenizer) else 2
        max_seq_len = max_len - special_tokens_count
        with open(infile, "rt") as fin, open(outfile, "w") as fout, open(idxfile, "w") as fidx:
            for line in fin:
                line = line.strip()
                if not line:
                    fout.write('\n')
                    fidx.write('\n')
                    idx += 1
                    subword_len_counter = 0
                    continue
                items = line.split()
                if len(items) == 2:
                    label = items[1].strip()
                else:
                    label = "X"
                token = items[0].strip()
                current_subwords_len = len(tokenizer.tokenize(token))
                if (current_subwords_len == 0 or current_subwords_len > max_seq_len) and len(token) != 0:
                    token = tokenizer.unk_token
                    current_subwords_len = 1
                if (subword_len_counter + current_subwords_len) > max_seq_len:
                    fout.write(f"\n{token}\t{label}\n")
                    fidx.write(f"\n{idx}\n")
                    subword_len_counter = current_subwords_len
                else:
                    fout.write(f"{token}\t{label}\n")
                    fidx.write(f"{idx}\n")
                    subword_len_counter += current_subwords_len

    model_type = args.model_type
    tokenizer = TOKENIZERS[model_type].from_pretrained(
        args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None)
    for lang in args.languages.split(','):
        out_dir = os.path.join(args.output_dir, lang)
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        if lang == 'en':
            files = ['dev', 'test', 'train']
        else:
            files = ['dev', 'test']
        for file in files:
            infile = os.path.join(args.data_dir, "{}-{}.tsv".format(file, lang))
            outfile = os.path.join(out_dir, "{}.{}".format(file, args.model_name_or_path))
            idxfile = os.path.join(out_dir, "{}.{}.idx".format(file, args.model_name_or_path))
            if os.path.exists(outfile) and os.path.exists(idxfile):
                print(f'{outfile} and {idxfile} exist')
            else:
                _preprocess_one_file(infile, outfile, idxfile, tokenizer, args.max_len)
                print(f'finish preprocessing {outfile}')


def udpos_preprocess(args):
    def _read_one_file(file):
        data = []
        sent, tag, lines = [], [], []
        for line in open(file, 'r'):
            items = line.strip().split('\t')
            if len(items) != 10:
                empty = all(w == '_' for w in sent)
                num_empty = sum([int(w == '_') for w in sent])
                if num_empty == 0 or num_empty < len(sent) - 1:
                    data.append((sent, tag, lines))
                sent, tag, lines = [], [], []
            else:
                sent.append(items[1].strip())
                tag.append(items[3].strip())
                lines.append(line.strip())
                assert len(sent) == int(items[0]), 'line={}, sent={}, tag={}'.format(line, sent, tag)
        return data

    def isfloat(value):
        try:
            float(value)
            return True
        except ValueError:
            return False

    def remove_empty_space(data):
        new_data = {}
        for split in data:
            new_data[split] = []
            for sent, tag, lines in data[split]:
                new_sent = [''.join(w.replace('\u200c', '').split(' ')) for w in sent]
                lines = [line.replace('\u200c', '') for line in lines]
                assert len(" ".join(new_sent).split(' ')) == len(tag)
                new_data[split].append((new_sent, tag, lines))
        return new_data

    def check_file(file):
        for i, l in enumerate(open(file)):
            items = l.strip().split('\t')
            assert len(items[0].split(' ')) == len(items[1].split(' ')), 'idx={}, line={}'.format(i, l)

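    # _write_files dumps each split in the requested format: 'tsv' writes one
    # "token<TAB>tag" pair per line with a blank line between sentences, 'mt'
    # writes one sentence per line (space-joined tokens, tab, space-joined tags),
    # and 'conll' copies the original CoNLL-U lines through unchanged.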
    def _write_files(data, output_dir, lang, suffix):
        for split in data:
            if len(data[split]) > 0:
                prefix = os.path.join(output_dir, f'{split}-{lang}')
                if suffix == 'mt':
                    with open(prefix + '.mt.tsv', 'w') as fout:
                        for idx, (sent, tag, _) in enumerate(data[split]):
                            newline = '\n' if idx != len(data[split]) - 1 else ''
                            # if split == 'test':
                            #     fout.write('{}{}'.format(' '.join(sent), newline))
                            # else:
                            #     fout.write('{}\t{}{}'.format(' '.join(sent), ' '.join(tag), newline))
                            fout.write('{}\t{}{}'.format(' '.join(sent), ' '.join(tag), newline))
                    check_file(prefix + '.mt.tsv')
                    print(' - finish checking ' + prefix + '.mt.tsv')
                elif suffix == 'tsv':
                    with open(prefix + '.tsv', 'w') as fout:
                        for sidx, (sent, tag, _) in enumerate(data[split]):
                            for widx, (w, t) in enumerate(zip(sent, tag)):
                                newline = '' if (sidx == len(data[split]) - 1) and (widx == len(sent) - 1) else '\n'
                                # if split == 'test':
                                #     fout.write('{}{}'.format(w, newline))
                                # else:
                                #     fout.write('{}\t{}{}'.format(w, t, newline))
                                fout.write('{}\t{}{}'.format(w, t, newline))
                            fout.write('\n')
                elif suffix == 'conll':
                    with open(prefix + '.conll', 'w') as fout:
                        for _, _, lines in data[split]:
                            for l in lines:
                                fout.write(l.strip() + '\n')
                            fout.write('\n')
                print(f'finish writing file to {prefix}.{suffix}')

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    languages = 'af ar bg de el en es et eu fa fi fr he hi hu id it ja kk ko mr nl pt ru ta te th tl tr ur vi yo zh'.split(' ')
    for root, dirs, files in os.walk(args.data_dir):
        lg = root.strip().split('/')[-1]
        if root == args.data_dir or lg not in languages:
            continue
        data = {k: [] for k in ['train', 'dev', 'test']}
        for f in sorted(files):
            if f.endswith('conll'):
                file = os.path.join(root, f)
                examples = _read_one_file(file)
                if 'train' in f:
                    data['train'].extend(examples)
                elif 'dev' in f:
                    data['dev'].extend(examples)
                elif 'test' in f:
                    data['test'].extend(examples)
                else:
                    print('split not found: ', file)
                print(' - finish reading {}, {}'.format(file, [(k, len(v)) for k, v in data.items()]))
        data = remove_empty_space(data)
        for sub in ['tsv']:
            _write_files(data, args.output_dir, lg, sub)


def pawsx_preprocess(args):
    def _preprocess_one_file(infile, outfile, remove_label=False):
        data = []
        for i, line in enumerate(open(infile, 'r')):
            if i == 0:
                continue
            items = line.strip().split('\t')
            sent1 = ' '.join(items[1].strip().split(' '))
            sent2 = ' '.join(items[2].strip().split(' '))
            label = items[3]
            data.append([sent1, sent2, label])

        with open(outfile, 'w') as fout:
            writer = csv.writer(fout, delimiter='\t', quoting=csv.QUOTE_NONE, quotechar='')
            for sent1, sent2, label in data:
                # if remove_label:
                #     writer.writerow([sent1, sent2])
                # else:
                #     writer.writerow([sent1, sent2, label])
                writer.writerow([sent1, sent2, label])

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    split2file = {'train': 'train', 'test': 'test_2k', 'dev': 'dev_2k'}
    for lang in ['en', 'de', 'es', 'fr', 'ja', 'ko', 'zh']:
        for split in ['train', 'test', 'dev']:
            if split == 'train' and lang != 'en':
                continue
            file = split2file[split]
            infile = os.path.join(args.data_dir, lang, "{}.tsv".format(file))
            outfile = os.path.join(args.output_dir, "{}-{}.tsv".format(split, lang))
            _preprocess_one_file(infile, outfile, remove_label=(split == 'test'))
            print(f'finish preprocessing {outfile}')


def xnli_preprocess(args):
    def _preprocess_file(infile, output_dir, split):
        all_langs = defaultdict(list)
        for i, line in enumerate(open(infile, 'r')):
            if i == 0:
                continue
            items = line.strip().split('\t')
            lang = items[0].strip()
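            # The raw XNLI files label contradictions as "contradictory"; map this
            # to "contradiction" so the label set matches the MultiNLI training data.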
            label = "contradiction" if items[1].strip() == "contradictory" else items[1].strip()
            sent1 = ' '.join(items[6].strip().split(' '))
            sent2 = ' '.join(items[7].strip().split(' '))
            all_langs[lang].append((sent1, sent2, label))
        print(f'# langs={len(all_langs)}')
        for lang, pairs in all_langs.items():
            outfile = os.path.join(output_dir, '{}-{}.tsv'.format(split, lang))
            with open(outfile, 'w') as fout:
                writer = csv.writer(fout, delimiter='\t', quoting=csv.QUOTE_NONE, quotechar='')
                for (sent1, sent2, label) in pairs:
                    # if split == 'test':
                    #     writer.writerow([sent1, sent2])
                    # else:
                    #     writer.writerow([sent1, sent2, label])
                    writer.writerow([sent1, sent2, label])
            print(f'finish preprocessing {outfile}')

    def _preprocess_train_file(infile, outfile):
        with open(outfile, 'w') as fout:
            writer = csv.writer(fout, delimiter='\t', quoting=csv.QUOTE_NONE, quotechar='')
            for i, line in enumerate(open(infile, 'r')):
                if i == 0:
                    continue
                items = line.strip().split('\t')
                sent1 = ' '.join(items[0].strip().split(' '))
                sent2 = ' '.join(items[1].strip().split(' '))
                label = "contradiction" if items[2].strip() == "contradictory" else items[2].strip()
                writer.writerow([sent1, sent2, label])
        print(f'finish preprocessing {outfile}')

    infile = os.path.join(args.data_dir, 'XNLI-MT-1.0/multinli/multinli.train.en.tsv')
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    outfile = os.path.join(args.output_dir, 'train-en.tsv')
    _preprocess_train_file(infile, outfile)
    for split in ['test', 'dev']:
        infile = os.path.join(args.data_dir, 'XNLI-1.0/xnli.{}.tsv'.format(split))
        print(f'reading file {infile}')
        _preprocess_file(infile, args.output_dir, split)


def tatoeba_preprocess(args):
    lang3_dict = {
        'afr': 'af', 'ara': 'ar', 'bul': 'bg', 'ben': 'bn',
        'deu': 'de', 'ell': 'el', 'spa': 'es', 'est': 'et',
        'eus': 'eu', 'pes': 'fa', 'fin': 'fi', 'fra': 'fr',
        'heb': 'he', 'hin': 'hi', 'hun': 'hu', 'ind': 'id',
        'ita': 'it', 'jpn': 'ja', 'jav': 'jv', 'kat': 'ka',
        'kaz': 'kk', 'kor': 'ko', 'mal': 'ml', 'mar': 'mr',
        'nld': 'nl', 'por': 'pt', 'rus': 'ru', 'swh': 'sw',
        'tam': 'ta', 'tel': 'te', 'tha': 'th', 'tgl': 'tl',
        'tur': 'tr', 'urd': 'ur', 'vie': 'vi', 'cmn': 'zh',
        'eng': 'en',
    }
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    for sl3, sl2 in lang3_dict.items():
        if sl3 != 'eng':
            src_file = f'{args.data_dir}/tatoeba.{sl3}-eng.{sl3}'
            tgt_file = f'{args.data_dir}/tatoeba.{sl3}-eng.eng'
            src_out = f'{args.output_dir}/{sl2}-en.{sl2}'
            tgt_out = f'{args.output_dir}/{sl2}-en.en'
            shutil.copy(src_file, src_out)
            tgts = [l.strip() for l in open(tgt_file)]
            idx = range(len(tgts))
            data = zip(tgts, idx)
            with open(tgt_out, 'w') as ftgt:
                for t, i in sorted(data, key=lambda x: x[0]):
                    ftgt.write(f'{t}\n')


def xquad_preprocess(args):
    # Remove the test annotations to prevent accidental cheating
    # remove_qa_test_annotations(args.data_dir)
    pass


def mlqa_preprocess(args):
    # Remove the test annotations to prevent accidental cheating
    # remove_qa_test_annotations(args.data_dir)
    pass


def tydiqa_preprocess(args):
    LANG2ISO = {'arabic': 'ar', 'bengali': 'bn', 'english': 'en', 'finnish': 'fi',
                'indonesian': 'id', 'korean': 'ko', 'russian': 'ru',
                'swahili': 'sw', 'telugu': 'te'}
    assert os.path.exists(args.data_dir)
    train_file = os.path.join(args.data_dir, 'tydiqa-goldp-v1.1-train.json')
    os.makedirs(args.output_dir, exist_ok=True)

    # Split the training file into language-specific files
    lang2data = defaultdict(list)
    with open(train_file, 'r') as f_in:
        data = json.load(f_in)
        version = data['version']
        for doc in data['data']:
            for par in doc['paragraphs']:
                context = par['context']
                for qa in par['qas']:
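                    # TyDiQA-GoldP question ids look like '<language>-<id>'; the
                    # language prefix is used to regroup training examples into
                    # per-language files below.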
                    question = qa['question']
                    question_id = qa['id']
                    example_lang = question_id.split('-')[0]
                    q_id = question_id.split('-')[-1]
                    for answer in qa['answers']:
                        a_start, a_text = answer['answer_start'], answer['text']
                        a_end = a_start + len(a_text)
                        assert context[a_start:a_end] == a_text
                    lang2data[example_lang].append({'paragraphs': [{
                        'context': context,
                        'qas': [{'answers': qa['answers'],
                                 'question': question,
                                 'id': q_id}]}]})

    for lang, data in lang2data.items():
        out_file = os.path.join(
            args.output_dir, 'tydiqa.%s.train.json' % LANG2ISO[lang])
        with open(out_file, 'w') as f:
            json.dump({'data': data, 'version': version}, f)

    # Rename the dev files
    dev_dir = os.path.join(args.data_dir, 'tydiqa-goldp-v1.1-dev')
    assert os.path.exists(dev_dir)
    for lang, iso in LANG2ISO.items():
        src_file = os.path.join(dev_dir, 'tydiqa-goldp-dev-%s.json' % lang)
        dst_file = os.path.join(dev_dir, 'tydiqa.%s.dev.json' % iso)
        os.rename(src_file, dst_file)

    # Remove the test annotations to prevent accidental cheating
    # remove_qa_test_annotations(dev_dir)


def remove_qa_test_annotations(test_dir):
    assert os.path.exists(test_dir)
    for file_name in os.listdir(test_dir):
        new_data = []
        test_file = os.path.join(test_dir, file_name)
        with open(test_file, 'r') as f:
            data = json.load(f)
            version = data['version']
            for doc in data['data']:
                for par in doc['paragraphs']:
                    context = par['context']
                    for qa in par['qas']:
                        question = qa['question']
                        question_id = qa['id']
                        for answer in qa['answers']:
                            a_start, a_text = answer['answer_start'], answer['text']
                            a_end = a_start + len(a_text)
                            assert context[a_start:a_end] == a_text
                        new_data.append({'paragraphs': [{
                            'context': context,
                            'qas': [{'answers': [{'answer_start': 0, 'text': ''}],
                                     'question': question,
                                     'id': question_id}]}]})
        with open(test_file, 'w') as f:
            json.dump({'data': new_data, 'version': version}, f)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--data_dir", default=None, type=str, required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output data dir to which processed files will be written.")
    parser.add_argument("--task", default="panx", type=str, required=True,
                        help="The task name: one of panx, panx_tokenize, udpos, udpos_tokenize, "
                             "pawsx, xnli, tatoeba, xquad, mlqa, tydiqa.")
    parser.add_argument("--model_name_or_path", default="bert-base-multilingual-cased", type=str,
                        help="The pre-trained model name or path")
    parser.add_argument("--model_type", default="bert", type=str,
                        help="Model type: one of bert, xlm, xlmr")
    parser.add_argument("--max_len", default=512, type=int,
                        help="The maximum sequence length in subword tokens")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Whether to lowercase the input when tokenizing")
    parser.add_argument("--cache_dir", default=None, type=str,
                        help="Cache directory for pre-trained tokenizers")
    parser.add_argument("--languages", default="en", type=str,
                        help="Comma-separated list of languages to process")
    parser.add_argument("--remove_last_token", action='store_true',
                        help="Whether to remove the last token")
    parser.add_argument("--remove_test_label", action='store_true',
                        help="Whether to remove test set labels")
    args = parser.parse_args()

    if args.task == 'panx_tokenize':
        panx_tokenize_preprocess(args)
    if args.task == 'panx':
        panx_preprocess(args)
    if args.task == 'udpos_tokenize':
        udpos_tokenize_preprocess(args)
    if args.task == 'udpos':
        udpos_preprocess(args)
    if args.task == 'pawsx':
        pawsx_preprocess(args)
    if args.task == 'xnli':
        xnli_preprocess(args)
    if args.task == 'tatoeba':
        tatoeba_preprocess(args)
    if args.task == 'xquad':
        xquad_preprocess(args)
    if args.task == 'mlqa':
        mlqa_preprocess(args)
    if args.task == 'tydiqa':
        tydiqa_preprocess(args)
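
# Example invocations (the script filename and paths below are illustrative; point
# --data_dir at the directories where the raw downloads live):
#   python utils_preprocess.py --data_dir download/xnli --output_dir data/xnli --task xnli
#   python utils_preprocess.py --data_dir download/panx_dataset --output_dir data/panx --task panx
#   python utils_preprocess.py --data_dir data/panx --output_dir data/panx_tokenized \
#       --task panx_tokenize --model_type bert --model_name_or_path bert-base-multilingual-cased --languages en,fr,zh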