|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import json |
|
from ast import literal_eval |
|
|
|
from tqdm import tqdm |
|
|
|
|
|
def load_json(filepath):
    """Read a JSON document from disk and return the parsed object.

    Args:
        filepath: Path of the JSON file to read.

    Returns:
        The Python object produced by ``json.load`` for the file's contents.
    """
    # JSON text is UTF-8 by specification. Decoding explicitly avoids depending on
    # the platform's locale encoding (e.g. cp1252 on Windows), which would garble
    # or reject the non-ASCII characters found in real-world query/passage text.
    with open(filepath, "r", encoding="utf-8") as f:
        return json.load(f)
|
|
|
|
|
def dump_json(filepath, data):
    """Serialize ``data`` as JSON and write it to ``filepath``.

    The file is opened in text write mode, so any existing contents are replaced.

    Args:
        filepath: Destination path of the JSON file.
        data: JSON-serializable object to write.
    """
    with open(filepath, mode="w") as out_file:
        json.dump(data, fp=out_file)
|
|
|
|
|
def get_context_from_passages(passages, keep_only_relevant_passages):
    """Build a single context string out of a query's passages.

    Args:
        passages: Iterable of passage dicts. Each one must provide ``"passage_text"``;
            ``"is_selected"`` is only read when filtering is requested.
        keep_only_relevant_passages: When truthy, only passages whose
            ``"is_selected"`` equals 1 contribute to the context.

    Returns:
        The chosen passage texts joined by single spaces (an empty string when
        nothing is chosen).
    """
    if keep_only_relevant_passages:
        chosen = [entry for entry in passages if entry["is_selected"] == 1]
    else:
        # No filtering: "is_selected" is deliberately not touched on this path.
        chosen = passages
    return " ".join([entry["passage_text"] for entry in chosen])
|
|
|
|
|
def format_answers_into_squad_format(answers):
    """Turn a collection of answer strings into SQuAD-style answer dicts.

    Args:
        answers: Answer strings for one query.

    Returns:
        An empty list when ``"No Answer Present."`` is found in ``answers``;
        otherwise one ``{"text": ..., "answer_start": -1}`` dict per answer.
        ``answer_start`` is always -1 because no character offset is computed here.
    """
    if "No Answer Present." in answers:
        return []
    return [{"text": answer_text, "answer_start": -1} for answer_text in answers]
|
|
|
|
|
def convert_msmarco_to_squad_format(msmarco_data, args):
    """Convert a loaded MS MARCO dataset into a SQuAD-style dictionary.

    Args:
        msmarco_data: Mapping with ``"query"``, ``"passages"``, ``"wellFormedAnswers"``
            and ``"answers"`` entries, each indexed by the same query ids.
        args: Object exposing ``keep_only_relevant_passages`` and
            ``exclude_negative_samples``.

    Returns:
        A dict with a single ``"MSMARCO"`` article whose paragraphs each hold one
        context and one question, plus ``"version": "v2.1"``. Each question's
        ``"id"`` is its position in the iteration order of ``msmarco_data["query"]``,
        not the original query id.
    """
    queries = msmarco_data["query"]
    query_ids = list(queries)
    paragraphs = []

    for position, query_id in enumerate(tqdm(query_ids)):
        context = get_context_from_passages(msmarco_data["passages"][query_id], args.keep_only_relevant_passages)
        if not context:
            # Nothing to read the answer from; drop the sample.
            continue

        question = queries[query_id]

        # The well-formed answers may arrive as a real list or as its string
        # representation; parse the latter into a list.
        well_formed = msmarco_data["wellFormedAnswers"][query_id]
        if not isinstance(well_formed, list):
            well_formed = literal_eval(well_formed)

        # Prefer well-formed answers and fall back to the raw ones when there are none.
        raw_answers = well_formed or msmarco_data["answers"][query_id]
        answers = format_answers_into_squad_format(raw_answers)
        if args.exclude_negative_samples and not answers:
            continue

        qa_entry = {
            "id": position,
            "question": question,
            "answers": answers,
            "is_impossible": not answers,
        }
        paragraphs.append({"context": context, "qas": [qa_entry]})

    return {"data": [{"title": "MSMARCO", "paragraphs": paragraphs}], "version": "v2.1"}
|
|
|
|
|
def _str_to_bool(value):
    """Parse a command-line boolean such as ``True`` or ``False`` into a ``bool``.

    ``type=bool`` must not be used with argparse: it calls ``bool("False")``, which
    is ``True`` because any non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean spelling.
    """
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value (True/False), got {!r}".format(value))


def main():
    """Command-line entry point: convert the MS MARCO train and dev files to SQuAD format.

    Reads the two input JSON files named on the command line, converts each one and
    writes the results to the corresponding save paths.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--msmarco_train_input_filepath", default=None, type=str, required=True)
    parser.add_argument("--msmarco_dev_input_filepath", default=None, type=str, required=True)
    parser.add_argument("--converted_train_save_path", default=None, type=str, required=True)
    parser.add_argument("--converted_dev_save_path", default=None, type=str, required=True)
    parser.add_argument(
        "--exclude_negative_samples",
        default=False,
        # Explicit parser so that "--exclude_negative_samples=False" really means False.
        type=_str_to_bool,
        help="if True, drop No Answer samples from the dataset",
        required=False,
    )
    parser.add_argument(
        "--keep_only_relevant_passages",
        default=False,
        type=_str_to_bool,
        help="if True, will only use passages with is_selected=True for context",
        required=False,
    )
    args = parser.parse_args()

    splits = (
        ("train", args.msmarco_train_input_filepath, args.converted_train_save_path),
        ("dev", args.msmarco_dev_input_filepath, args.converted_dev_save_path),
    )
    for split_name, input_path, save_path in splits:
        print("converting MS-MARCO {} dataset...".format(split_name))
        msmarco_data = load_json(input_path)
        squad_data = convert_msmarco_to_squad_format(msmarco_data, args)
        dump_json(save_path, squad_data)
|
|
|
|
|
# Script entry point: runs main() only when this file is executed directly.
# The bare string below is never read by the program; it only records where to
# obtain the MS MARCO data and shows an example command line.
if __name__ == "__main__":
    """
    Please agree to the Terms of Use at:
    https://microsoft.github.io/msmarco/
    Download data at:
    https://msmarco.blob.core.windows.net/msmarco/train_v2.1.json.gz
    https://msmarco.blob.core.windows.net/msmarco/dev_v2.1.json.gz

    Example usage:
    python convert_msmarco_to_squad_format.py \
    --msmarco_train_input_filepath=/path/to/msmarco_train_v2.1.json \
    --msmarco_dev_input_filepath=/path/to/msmarco_dev_v2.1.json \
    --converted_train_save_path=/path/to/msmarco_squad_format_train.json \
    --converted_dev_save_path=/path/to/msmarco_squad_format_dev.json \
    --exclude_negative_samples=False \
    --keep_only_relevant_passages=False
    """
    main()
|
|