""" |
|
This script can be used to process and import IMDB, ChemProt, SST-2, and THUCnews datasets into NeMo's format. |
|
You may run it as the following: |
|
|
|
python import_datasets.py \ |
|
--dataset_name DATASET_NAME \ |
|
--target_data_dir TARGET_PATH \ |
|
--source_data_dir SOURCE_PATH |
|
|
|
The dataset should be specified by "DATASET_NAME" which can be from ["sst-2", "chemprot", "imdb", "thucnews"]. |
|
It reads the data from "SOURCE_PATH" folder, processes and converts the data into NeMo's format. |
|
Then writes the results into "TARGET_PATH" folder. |
|
|
|
""" |

import argparse
import csv
import glob
import os
from os.path import exists

import tqdm

from nemo.utils import logging


def process_imdb(infold, outfold, uncased, modes=['train', 'test']):
    """Convert the IMDB reviews dataset into NeMo's format.

    Expects the extracted layout {infold}/{train,test}/{neg,pos}/*.txt.
    """
    if not os.path.exists(infold):
        link = 'https://ai.stanford.edu/~amaas/data/sentiment/'
        raise ValueError(
            f'Data not found at {infold}. '
            f'Please download the IMDB reviews dataset from {link} and '
            f'extract it into the folder specified by the source_data_dir argument.'
        )

    logging.info(f'Processing IMDB dataset and storing the results at {outfold}')
    os.makedirs(outfold, exist_ok=True)

    outfiles = {}
    for mode in modes:
        outfiles[mode] = open(os.path.join(outfold, mode + '.tsv'), 'w')
        for sent in ['neg', 'pos']:
            # Negative reviews get label 0, positive reviews label 1.
            label = 0 if sent == 'neg' else 1
            files = glob.glob(f'{infold}/{mode}/{sent}/*.txt')
            for file in files:
                with open(file, 'r') as f:
                    review = f.read().strip()
                    if uncased:
                        review = review.lower()
                    review = review.replace("<br />", "")
                    outfiles[mode].write(f'{review}\t{label}\n')

    for mode in modes:
        outfiles[mode].close()

    # label_ids.tsv lists the class names, one per line, in label-id order.
    with open(os.path.join(outfold, 'label_ids.tsv'), 'w') as class_labels_file:
        class_labels_file.write('negative\npositive\n')
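
# Illustrative only: each line of the resulting train.tsv/test.tsv is one review
# followed by a tab and its integer label, e.g.
#
#   an unexpectedly moving film about grief\t1
#
# while label_ids.tsv maps line index to class name ('negative' -> 0, 'positive' -> 1).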


def process_sst2(infold, outfold, uncased, splits=['train', 'dev']):
    """Convert the SST-2 dataset into NeMo's format."""
    if not os.path.exists(infold):
        link = 'https://dl.fbaipublicfiles.com/glue/data/SST-2.zip'
        raise ValueError(
            f'Data not found at {infold}. Please download the SST-2 dataset from `{link}` and '
            f'extract it into the folder specified by the `source_data_dir` argument.'
        )

    logging.info('Processing SST-2 dataset')
    os.makedirs(outfold, exist_ok=True)

    def _read_tsv(input_file, quotechar=None):
        """Read a tab-separated values file."""
        with open(input_file, "r") as f:
            return list(csv.reader(f, delimiter="\t", quotechar=quotechar))

    for split in splits:
        input_file = os.path.join(infold, split + '.tsv')
        lines = _read_tsv(input_file)

        with open(os.path.join(outfold, split + '.tsv'), 'w') as outfile:
            # Skip the header row; each remaining row is "sentence<TAB>label".
            for line in lines[1:]:
                text = line[0]
                label = line[1]
                if uncased:
                    text = text.lower()
                outfile.write(f'{text}\t{label}\n')

    # label_ids.tsv lists the class names, one per line, in label-id order.
    with open(os.path.join(outfold, 'label_ids.tsv'), 'w') as class_labels_file:
        class_labels_file.write('negative\npositive\n')

    logging.info(f'Result stored at {outfold}')
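
# The expected input here is the GLUE distribution of SST-2: train.tsv and dev.tsv
# with a "sentence<TAB>label" header row (skipped above) and binary 0/1 labels.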


def process_chemprot(source_dir, target_dir, uncased, modes=['train', 'test', 'dev']):
    """Convert the ChemProt relation-extraction dataset into NeMo's format."""
    if not os.path.exists(source_dir):
        link = 'https://github.com/arwhirang/recursive_chemprot/tree/master/Demo/tree_LSTM/data'
        raise ValueError(f'Data not found at {source_dir}. Please download ChemProt from {link}.')

    logging.info(f'Processing ChemProt dataset and storing the results at {target_dir}')
    os.makedirs(target_dir, exist_ok=True)

    naming_map = {'train': 'trainingPosit_chem', 'test': 'testPosit_chem', 'dev': 'developPosit_chem'}

    def _read_tsv(input_file, quotechar=None):
        """Read a tab-separated values file."""
        with open(input_file, "r") as f:
            return list(csv.reader(f, delimiter="\t", quotechar=quotechar))

    outfiles = {}
    label_mapping = {}
    out_label_mapping = open(os.path.join(target_dir, 'label_mapping.tsv'), 'w')
    for mode in modes:
        outfiles[mode] = open(os.path.join(target_dir, mode + '.tsv'), 'w')
        input_file = os.path.join(source_dir, naming_map[mode])
        lines = _read_tsv(input_file)
        for line in lines:
            text = line[1]
            label = line[2]
            # Rows flagged "True" in column 2 are positive instances whose actual
            # relation label is stored in column 3.
            if label == "True":
                label = line[3]
            if uncased:
                text = text.lower()
            # Assign integer ids to label strings in order of first appearance and
            # record the mapping in label_mapping.tsv.
            if label not in label_mapping:
                out_label_mapping.write(f'{label}\t{len(label_mapping)}\n')
                label_mapping[label] = len(label_mapping)
            label = label_mapping[label]
            outfiles[mode].write(f'{text}\t{label}\n')
    for mode in modes:
        outfiles[mode].close()
    out_label_mapping.close()
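
# Illustrative only: label_mapping.tsv pairs each label string with its assigned
# integer id in order of first appearance, e.g. (label names hypothetical):
#
#   false\t0
#   CPR:3\t1
#   CPR:4\t2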


def process_thucnews(infold, outfold):
    """Convert the THUCNews dataset into NeMo's format with an 80/20 train/test split."""
    modes = ['train', 'test']
    train_size = 0.8
    if not os.path.exists(infold):
        link = 'http://thuctc.thunlp.org/'
        raise ValueError(f'Data not found at {infold}. Please download THUCNews from {link}.')

    logging.info(f'Processing THUCNews dataset and storing the results at {outfold}')
    os.makedirs(outfold, exist_ok=True)

    outfiles = {}
    for mode in modes:
        outfiles[mode] = open(os.path.join(outfold, mode + '.tsv'), 'w', encoding='utf-8')
    # Category folder names (sports, entertainment, home, lottery, real estate,
    # education, fashion, politics, horoscope, games, society, technology,
    # stocks, finance); the index of each category is used as its label id.
    categories = ['体育', '娱乐', '家居', '彩票', '房产', '教育', '时尚', '时政', '星座', '游戏', '社会', '科技', '股票', '财经']
    for label, category in enumerate(categories):
        category_files = glob.glob(f'{infold}/{category}/*.txt')
        # The first 20% of the files in each category go to test, the rest to train.
        test_num = int(len(category_files) * (1 - train_size))
        test_files = category_files[:test_num]
        train_files = category_files[test_num:]

        for mode in modes:
            logging.info(f'Processing {mode} data of the category {category}')
            files = test_files if mode == 'test' else train_files
            if len(files) == 0:
                logging.info(f'Skipping category {category} for {mode} mode')
                continue
            for file in tqdm.tqdm(files):
                with open(file, 'r', encoding='utf-8') as f:
                    # Flatten each article to a single line so it fits the
                    # one-example-per-line TSV format.
                    news = f.read().strip().replace('\r', '')
                    news = news.replace('\n', '').replace('\t', ' ')
                    outfiles[mode].write(f'{news}\t{label}\n')
    for mode in modes:
        outfiles[mode].close()
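
# Note: glob.glob does not guarantee a stable file order, so the train/test split
# above can differ between runs or machines. Sorting first, e.g.
#
#   category_files = sorted(glob.glob(f'{infold}/{category}/*.txt'))
#
# would make the split reproducible.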


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process and convert datasets into NeMo's format.")
    parser.add_argument(
        "--dataset_name", required=True, type=str, choices=['sst-2', 'chemprot', 'imdb', 'thucnews']
    )
    parser.add_argument(
        "--source_data_dir", required=True, type=str, help='The path to the folder containing the dataset files.'
    )
    parser.add_argument(
        "--target_data_dir", required=True, type=str, help='The path to the folder to store the processed files in.'
    )
    parser.add_argument("--do_lower_case", action='store_true', help='Lowercase the text before writing it out.')
    args = parser.parse_args()

    dataset_name = args.dataset_name
    do_lower_case = args.do_lower_case
    source_dir = args.source_data_dir
    target_dir = args.target_data_dir

    if not exists(source_dir):
        raise FileNotFoundError(f"{source_dir} does not exist.")

    if dataset_name == 'imdb':
        process_imdb(source_dir, target_dir, do_lower_case)
    elif dataset_name == 'thucnews':
        process_thucnews(source_dir, target_dir)
    elif dataset_name == 'chemprot':
        process_chemprot(source_dir, target_dir, do_lower_case)
    elif dataset_name == 'sst-2':
        process_sst2(source_dir, target_dir, do_lower_case)
    else:
        raise ValueError(
            f'Dataset {dataset_name} is not supported. '
            'Please make sure that you build the preprocessing for it. '
            "NeMo's format assumes that a data file has a header and each line of the file follows "
            'the format: text [TAB] label. The label is assumed to be an integer.'
        )