#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Interactively delete checkpoint (``*.pt``) files under directory trees.

The script walks every directory given on the command line, classifies each
checkpoint file as "dereference", "preserve" or "delete", prints the plan,
and only touches the file system after the user confirms with ``y``.
"""

import argparse
import os
import re
import shutil
import sys


# Every file name this script manages: ``checkpoint<N>.pt``,
# ``checkpoint_<A>_<B>.pt`` or ``checkpoint_<word>.pt`` (lowercase word, e.g.
# ``checkpoint_best.pt``). Files that do not match are never touched.
pt_regexp = re.compile(r'checkpoint(\d+|_\d+_\d+|_[a-z]+)\.pt')
# ``checkpoint<N>.pt``: captures N.
pt_regexp_epoch_based = re.compile(r'checkpoint(\d+)\.pt')
# ``checkpoint_<A>_<B>.pt``: captures only the second number, B.
pt_regexp_update_based = re.compile(r'checkpoint_\d+_(\d+)\.pt')


def parse_checkpoints(files):
    """Return ``(number, filename)`` pairs for the numbered checkpoints.

    Args:
        files: iterable of bare file names (no directory component).

    Returns:
        A list of ``(int, str)`` tuples in input order. The integer is the
        number captured by ``pt_regexp_epoch_based`` or, failing that, by
        ``pt_regexp_update_based``. Names matching neither pattern (including
        ``checkpoint_best.pt`` / ``checkpoint_last.pt``) are skipped.
    """
    entries = []
    for f in files:
        # Match objects are always truthy, so ``or`` selects the first hit.
        m = (
            pt_regexp_epoch_based.fullmatch(f)
            or pt_regexp_update_based.fullmatch(f)
        )
        if m is not None:
            entries.append((int(m.group(1)), m.group(0)))
    return entries


def last_n_checkpoints(files, n):
    """Return the names of the ``n`` highest-numbered checkpoints.

    The result is ordered from the highest number downwards.
    """
    entries = parse_checkpoints(files)
    return [x[1] for x in sorted(entries, reverse=True)[:n]]


def every_n_checkpoints(files, n):
    """Return every ``n``-th checkpoint name, counting back from the highest.

    The highest-numbered checkpoint is always included; the result is in
    ascending numeric order. ``n`` must be non-zero (it is used as a slice
    step); the script only calls this with ``n > 0``.
    """
    entries = parse_checkpoints(files)
    return [x[1] for x in sorted(sorted(entries)[::-n])]


def is_test_dir(path):
    """Return True if the directory named by ``path`` starts with ``test_``.

    The path is made absolute and normalised first, because
    ``os.path.basename`` returns ``''`` for ``'test_run/'`` and ``'.'`` for
    ``'.'``, which would hide the real directory name.
    """
    return os.path.basename(os.path.abspath(path)).startswith('test_')


def collect_operations(root_dirs, save_last=0, save_every=0,
                       preserve_test=False, delete_best=False,
                       delete_last=False, no_dereference=False):
    """Walk ``root_dirs`` and decide what to do with each checkpoint file.

    This function only reads the file system; it never modifies it.

    Args:
        root_dirs: directories to walk recursively.
        save_last: if > 0, keep this many highest-numbered checkpoints in
            each directory.
        save_every: if > 0, additionally keep every ``save_every``-th
            checkpoint in each directory.
        preserve_test: if False, nothing is preserved inside directories
            whose name starts with ``test_``.
        delete_best: if True, ``checkpoint_best.pt`` is not protected.
        delete_last: if True, ``checkpoint_last.pt`` is not protected.
        no_dereference: if True, preserved symlinks are left as symlinks.

    Returns:
        A tuple ``(files_to_desymlink, files_to_preserve, files_to_delete)``
        of sorted path lists.
    """
    files_to_desymlink = []
    files_to_preserve = []
    files_to_delete = []
    # Overlapping or repeated root directories would queue the same file
    # twice; the second removal would then fail, and a second dereference
    # pass would delete the freshly copied file. Visit each directory once.
    seen_dirs = set()

    for root_dir in root_dirs:
        for root, _subdirs, files in os.walk(root_dir):
            real_root = os.path.realpath(root)
            if real_root in seen_dirs:
                continue
            seen_dirs.add(real_root)

            to_save = set()
            if save_last > 0:
                to_save.update(last_n_checkpoints(files, save_last))
            if save_every > 0:
                to_save.update(every_n_checkpoints(files, save_every))

            # Preservation rules apply only outside test_ directories,
            # unless the caller asked to preserve those as well.
            may_preserve = preserve_test or not is_test_dir(root)

            for file in files:
                if not pt_regexp.fullmatch(file):
                    continue
                full_path = os.path.join(root, file)
                keep = may_preserve and (
                    (file == 'checkpoint_last.pt' and not delete_last)
                    or (file == 'checkpoint_best.pt' and not delete_best)
                    or file in to_save
                )
                if not keep:
                    files_to_delete.append(full_path)
                elif os.path.islink(full_path) and not no_dereference:
                    files_to_desymlink.append(full_path)
                else:
                    files_to_preserve.append(full_path)

    return (
        sorted(files_to_desymlink),
        sorted(files_to_preserve),
        sorted(files_to_delete),
    )


def main():
    """Parse arguments, show the plan, ask for confirmation, then execute.

    Exits with status 0 when there is nothing to do or the user answers
    ``n``, and with status 1 when standard input ends before an answer is
    given. No file is modified unless the user answers ``y``.
    """
    parser = argparse.ArgumentParser(
        description=(
            'Recursively delete checkpoint files from `root_dir`, '
            'but preserve checkpoint_best.pt and checkpoint_last.pt'
        )
    )
    parser.add_argument('root_dirs', nargs='*')
    parser.add_argument('--save-last', type=int, default=0,
                        help='number of last checkpoints to save')
    parser.add_argument('--save-every', type=int, default=0,
                        help='interval of checkpoints to save')
    parser.add_argument('--preserve-test', action='store_true',
                        help='preserve checkpoints in dirs that start with test_ prefix (default: delete them)')
    parser.add_argument('--delete-best', action='store_true',
                        help='delete checkpoint_best.pt')
    parser.add_argument('--delete-last', action='store_true',
                        help='delete checkpoint_last.pt')
    parser.add_argument('--no-dereference', action='store_true',
                        help='don\'t dereference symlinks')
    args = parser.parse_args()

    files_to_desymlink, files_to_preserve, files_to_delete = collect_operations(
        args.root_dirs,
        save_last=args.save_last,
        save_every=args.save_every,
        preserve_test=args.preserve_test,
        delete_best=args.delete_best,
        delete_last=args.delete_last,
        no_dereference=args.no_dereference,
    )

    if not files_to_desymlink and not files_to_delete:
        print('Nothing to do.')
        sys.exit(0)

    print('Operations to perform (in order):')
    for file in files_to_desymlink:
        print(' - preserve (and dereference symlink): ' + file)
    for file in files_to_preserve:
        print(' - preserve: ' + file)
    for file in files_to_delete:
        print(' - delete: ' + file)

    while True:
        try:
            resp = input('Continue? (Y/N): ')
        except EOFError:
            # Closed stdin can never confirm; abort without deleting.
            print()
            print('No confirmation received; aborting.')
            sys.exit(1)
        answer = resp.strip().lower()
        if answer == 'y':
            break
        elif answer == 'n':
            sys.exit(0)

    print('Executing...')
    # Dereference first: a symlink's target may itself be on the delete
    # list, so it must be copied before the deletions below run.
    for file in files_to_desymlink:
        realpath = os.path.realpath(file)
        print('rm ' + file)
        os.remove(file)
        print('cp {} {}'.format(realpath, file))
        shutil.copyfile(realpath, file)
    for file in files_to_delete:
        print('rm ' + file)
        os.remove(file)


if __name__ == '__main__':
    main()