#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Split a large file into a train and valid set while respecting document
boundaries. Documents should be separated by a single empty line.

A uniform random sample of exactly ``k`` documents (or lines, with
``--lines``) is written to ``sample_output``; every other document is written
to ``remainder_output``.
"""

import argparse
import random
import sys


def _read_docs(handle, split_lines):
    """Yield documents from *handle* as lists of newline-terminated lines.

    Args:
        handle: iterable of text lines (e.g. an open text file).
        split_lines: when true, every non-blank line is its own document.

    Blank (whitespace-only) lines end the current document. Empty documents
    (leading blank line, consecutive blank lines, or blank lines in line mode)
    are skipped so they can never occupy a slot in the sample.
    Progress is reported on stderr: the line index every 1,000,000 lines and a
    dot every 100,000 lines.
    """
    doc = []
    for i, line in enumerate(handle):
        if line.strip() == "":  # empty line indicates new document
            if doc:
                yield doc
                doc = []
        else:
            # The last line of a file may lack a terminator; normalise it so
            # it cannot fuse with whatever is written after it.
            if not line.endswith("\n"):
                line += "\n"
            doc.append(line)
            if split_lines:
                yield doc
                doc = []
        if i % 1000000 == 0:
            print(i, file=sys.stderr, end="", flush=True)
        elif i % 100000 == 0:
            print(".", file=sys.stderr, end="", flush=True)
    if doc:
        yield doc


def _reservoir_split(docs, k):
    """Split *docs* into a uniform random sample of size k and the remainder.

    Uses single-pass reservoir sampling: the first k items fill the sample;
    item number ``seen`` (0-based) afterwards replaces a random sample slot
    with probability ``k / (seen + 1)``, and whichever item loses goes to the
    remainder.

    Returns:
        A ``(sample, remainder)`` tuple of lists. ``sample`` holds fewer than
        k items only when *docs* yields fewer than k items.
    """
    sample = []
    remainder = []
    for seen, doc in enumerate(docs):
        if len(sample) < k:
            sample.append(doc)
            continue
        slot = random.randrange(seen + 1)
        if slot < k:
            remainder.append(sample[slot])
            sample[slot] = doc
        else:
            remainder.append(doc)
    return sample, remainder


def _write_docs(path, docs, split_lines):
    """Write *docs* to *path*, blank-line separated unless in line mode."""
    with open(path, 'w', encoding='utf-8') as out:
        for index, doc in enumerate(docs):
            if index > 0 and not split_lines:
                out.write("\n")
            out.writelines(doc)


def main():
    """Parse command-line arguments and split the input file.

    Exits with an argparse error (status 2) when ``-k`` is missing or
    negative, or when the input holds fewer than ``k`` documents. In the
    latter case no output file is written.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('input')
    parser.add_argument('sample_output', help='train output file')
    parser.add_argument('remainder_output', help='valid output file')
    parser.add_argument('-k', type=int, required=True,
                        help="sample size: number of documents (or lines "
                             "with --lines) written to sample_output")
    parser.add_argument('--lines', action='store_true',
                        help='split lines instead of docs')
    args = parser.parse_args()

    # Explicit checks instead of assert, which is stripped under `python -O`.
    if args.k < 0:
        parser.error("-k must be non-negative")

    with open(args.input, 'r', encoding='utf-8') as h:
        sample, remainder = _reservoir_split(
            _read_docs(h, args.lines), args.k
        )
    print(file=sys.stderr, flush=True)  # terminate the progress line

    if len(sample) != args.k:
        parser.error(
            "input contains only {} documents, fewer than -k {}".format(
                len(sample), args.k
            )
        )

    _write_docs(args.sample_output, sample, args.lines)
    _write_docs(args.remainder_output, remainder, args.lines)


if __name__ == '__main__':
    main()