File size: 6,254 Bytes
7934b29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json

from tqdm import tqdm


"""
Dataset preprocessing script for the SQuAD dataset: https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json
Converts the dataset into a jsonl format that can be used for p-tuning/prompt tuning in NeMo. 

Inputs:
    data-dir: (str) The directory where the squad dataset was downloaded, files will be saved here
    train-file: (str) Name of train set file, either train-v1.1.json or train-v2.0.json
    dev-file: (str) Name of dev set file, either dev-v1.1.json or dev-v2.0.json
    save-name-base: (str) The base name for each of the train, val, and test files. If save-name-base were 'squad' for
                    example, the files would be saved as squad_train.jsonl, squad_val.jsonl, and squad_test.jsonl
    include-topic-name: Whether to include the topic name for the paragraph in the data json. See the squad explanation
                        below for more context on what is meant by 'topic name'.
    random-seed: (int) Random seed intended for repeatable shuffling of train/val/test splits. The option is parsed
                 but not currently used by this script; no shuffling is performed.

Saves train, val, and test files for the SQuAD dataset. The val and test splits are the same data, because the given test
split lacks ground truth answers. 

An example of the processed output written to file:
    
    {
        "taskname": "squad", 
        "context": "Red is the traditional color of warning and danger. In the Middle Ages, a red flag announced that the defenders of a town or castle would fight to defend it, and a red flag hoisted by a warship meant they would show no mercy to their enemy. In Britain, in the early days of motoring, motor cars had to follow a man with a red flag who would warn horse-drawn vehicles, before the Locomotives on Highways Act 1896 abolished this law. In automobile races, the red flag is raised if there is danger to the drivers. In international football, a player who has made a serious violation of the rules is shown a red penalty card and ejected from the game.", 
        "question": "What did a red flag signal in the Middle Ages?", 
        "answer": " defenders of a town or castle would fight to defend it"
    },


"""


def main():
    """Parse command-line options, load the SQuAD json files, and write the jsonl splits.

    Reads ``<data-dir>/<train-file>`` and ``<data-dir>/<dev-file>``, takes the
    ``'data'`` entry of each loaded json document, and passes both to
    ``process_data`` together with the output base path
    ``<data-dir>/<save-name-base>`` and the ``--include-topic-name`` flag.

    Command-line options:
        --data-dir: directory holding the input files; outputs are written here too.
        --train-file: name of the train json file (default ``train-v1.1.json``).
        --dev-file: name of the dev json file (default ``dev-v1.1.json``).
        --save-name-base: base name used for the output files (default ``squad``).
        --include-topic-name: when set, the topic title is added to every example.
        --random-seed: accepted for compatibility; it is not used by this function.

    Raises:
        FileNotFoundError: if either input file does not exist.
        KeyError: if a loaded json document has no ``'data'`` entry.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-dir", type=str, default=".")
    parser.add_argument("--train-file", type=str, default="train-v1.1.json")
    parser.add_argument("--dev-file", type=str, default="dev-v1.1.json")
    parser.add_argument("--save-name-base", type=str, default="squad")
    parser.add_argument("--include-topic-name", action='store_true')
    parser.add_argument("--random-seed", type=int, default=1234)
    args = parser.parse_args()

    # Use context managers so the input files are closed as soon as they are
    # parsed, and read them as UTF-8 instead of the platform-default encoding.
    with open(f"{args.data_dir}/{args.train_file}", encoding="utf-8") as train_file:
        train_data = json.load(train_file)['data']

    with open(f"{args.data_dir}/{args.dev_file}", encoding="utf-8") as dev_file:
        val_data = json.load(dev_file)['data']

    save_name_base = f"{args.data_dir}/{args.save_name_base}"

    process_data(train_data, val_data, save_name_base, args.include_topic_name)


def process_data(train_data, val_data, save_name_base, include_topic):
    """Build the train/val/test example lists and write each one to disk.

    The train split is extracted from ``train_data``; the val and test splits
    are both extracted from ``val_data``. The test split is written twice:
    once with answers (ground truth file) and once through the unlabeled path.

    Args:
        train_data: the ``'data'`` list of the train json document.
        val_data: the ``'data'`` list of the dev json document.
        save_name_base: path prefix shared by all output files.
        include_topic: whether each example should carry its topic title.
    """
    sources = (("train", train_data), ("val", val_data), ("test", val_data))
    extracted = {name: extract_questions(raw, include_topic, split=name) for name, raw in sources}

    for name in ("train", "val"):
        gen_file(extracted[name], save_name_base, name)

    # Order matters if gen_file strips answers in place: the labelled
    # ground-truth file has to be written before the unlabeled one.
    for with_answers in (True, False):
        gen_file(extracted["test"], save_name_base, "test", make_ground_truth=with_answers)


def extract_questions(data, include_topic, split):
    """Flatten SQuAD-style topic/paragraph/question nesting into a list of examples.

    Args:
        data: list of topic dicts, each with a ``'title'`` and a ``'paragraphs'``
            list; every paragraph has a ``'context'`` and a ``'qas'`` list whose
            items carry a ``'question'`` and an ``'answers'`` list of dicts with
            a ``'text'`` entry.
        include_topic: when true, each example also gets a ``"topic"`` key
            holding the topic title.
        split: ``"test"`` keeps every answer text as a list; any other value
            keeps only the first answer text as a string.

    Returns:
        A list of dicts with the keys ``"taskname"``, ``"context"``,
        ``"question"``, ``"answer"`` and, optionally, ``"topic"``, in the order
        the questions appear in ``data``. For non-test splits, questions with an
        empty answer list are skipped; for the test split they are kept with an
        empty list as the answer.
    """
    keep_all_answers = split == "test"
    examples = []

    # Topics are walked in order, so all examples of one topic stay contiguous.
    for article in data:
        title = article['title']
        paragraphs = article['paragraphs']

        for paragraph in paragraphs:
            passage = paragraph['context']
            question_entries = paragraph['qas']

            for entry in question_entries:
                question_text = entry['question']
                candidates = entry['answers']

                if keep_all_answers:
                    # The dev set can list several valid answers; the test
                    # ground truth keeps all of them.
                    answer = [candidate['text'] for candidate in candidates]
                else:
                    # Train/val examples use a single answer; a question with
                    # no answers at all is dropped.
                    try:
                        first_candidate = candidates[0]
                    except IndexError:
                        continue
                    answer = first_candidate["text"]

                example = {
                    "taskname": "squad",
                    "context": passage,
                    "question": question_text,
                    "answer": answer,
                }

                if include_topic:
                    example["topic"] = title

                examples.append(example)

    return examples


def gen_file(data, save_name_base, split_type, make_ground_truth=False):
    """Write the examples of one split to a jsonl file, one json object per line.

    The output path is ``<save_name_base>_<split_type>.jsonl``, or
    ``<save_name_base>_<split_type>_ground_truth.jsonl`` when
    ``make_ground_truth`` is true. For the ``"test"`` split without ground
    truth, the ``"answer"`` key is left out of every written line.

    The dicts in ``data`` are never modified: unlabeled test lines are written
    from a filtered copy, so the same list can be passed to this function
    again in any order and still contain its answers.

    Args:
        data: iterable of json-serializable example dicts.
        save_name_base: path prefix of the output file.
        split_type: split name used in the file name (``"train"``, ``"val"``
            or ``"test"``).
        make_ground_truth: when true, keep the answers and add the
            ``_ground_truth`` suffix to the file name.
    """
    suffix = "_ground_truth" if make_ground_truth else ""
    save_path = f"{save_name_base}_{split_type}{suffix}.jsonl"

    # Labels are withheld only from the plain test file.
    strip_answers = split_type == "test" and not make_ground_truth

    print(f"Saving {split_type} split to {save_path}")

    with open(save_path, 'w', encoding="utf-8") as save_file:
        for example_json in tqdm(data):
            if strip_answers:
                # Build a copy rather than deleting in place: the caller's
                # examples stay intact and a missing key cannot raise KeyError.
                example_json = {key: value for key, value in example_json.items() if key != "answer"}

            save_file.write(json.dumps(example_json) + '\n')


# Run the conversion only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()