Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Split a large file into a train and valid set while respecting document
boundaries. Documents should be separated by a single empty line.
"""
import argparse | |
import random | |
import sys | |
def _write_docs(path, docs, lines_mode):
    """Write ``docs`` (a list of lists of lines) to ``path`` as UTF-8 text.

    In document mode (``lines_mode`` false) one empty line is written between
    consecutive documents; in line mode the lines are written back to back.
    """
    with open(path, 'w', encoding='utf-8') as out:
        for idx, doc in enumerate(docs):
            if idx > 0 and not lines_mode:
                out.write("\n")
            out.writelines(doc)


def main():
    """Split the input file into a random sample of ``-k`` items and the rest.

    Command-line arguments (read from ``sys.argv``):
        input             path of the UTF-8 text file to split
        sample_output     receives exactly ``k`` documents (or lines)
        remainder_output  receives every other document (or line)
        -k                required, non-negative sample size
        --lines           treat every non-blank line as its own document

    The sample is drawn with reservoir sampling in a single pass over the
    input; all documents are kept in memory until the outputs are written.

    Raises:
        SystemExit: if ``-k`` is missing or negative (reported by argparse).
        ValueError: if the input holds fewer than ``k`` documents.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('input')
    parser.add_argument('sample_output', help='train output file')
    parser.add_argument('remainder_output', help='valid output file')
    # required=True replaces the old `assert args.k is not None`, which is
    # stripped when Python runs with -O.
    parser.add_argument('-k', type=int, required=True,
                        help="sample size: number of docs (or lines) "
                             "written to sample_output")
    parser.add_argument('--lines', action='store_true',
                        help='split lines instead of docs')
    args = parser.parse_args()
    if args.k < 0:
        parser.error('-k must be non-negative')

    sample = []
    remainder = []
    num_docs = 0

    def update_sample(doc):
        """Route a finished document to the sample or remainder, then reset it."""
        nonlocal num_docs
        if not doc:
            # Leading/consecutive blank lines (or blank lines in --lines
            # mode) produce no content; do not count them as documents.
            return
        if len(sample) < args.k:
            # The first k documents fill the reservoir directly.
            sample.append(doc.copy())
        else:
            # num_docs is the 0-based index of this document, so it lands in
            # the reservoir with probability k / (num_docs + 1).
            j = random.randrange(num_docs + 1)
            if j < args.k:
                remainder.append(sample[j])
                sample[j] = doc.copy()
            else:
                remainder.append(doc.copy())
        num_docs += 1
        doc.clear()

    with open(args.input, 'r', encoding='utf-8') as h:
        doc = []
        for i, line in enumerate(h):
            if line.strip() == "":  # empty line indicates new document
                update_sample(doc)
            else:
                doc.append(line)
                if args.lines:
                    update_sample(doc)
            # Progress indicator on stderr: a number every 1M lines, a dot
            # every 100k lines.
            if i % 1000000 == 0:
                print(i, file=sys.stderr, end="", flush=True)
            elif i % 100000 == 0:
                print(".", file=sys.stderr, end="", flush=True)
        # Flush the final document when the file does not end in a blank line.
        update_sample(doc)
    print(file=sys.stderr, flush=True)

    if len(sample) != args.k:
        raise ValueError(
            "requested a sample of {} documents but the input only "
            "contains {}".format(args.k, num_docs)
        )

    _write_docs(args.sample_output, sample, args.lines)
    _write_docs(args.remainder_output, remainder, args.lines)
# Script entry point: run the splitter only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()