Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3 | |
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import argparse | |
import os | |
import re | |
import shutil | |
import sys | |
# Any checkpoint file this script manages: ``checkpointN.pt``,
# ``checkpoint_E_U.pt`` or ``checkpoint_<lowercase word>.pt``
# (e.g. ``checkpoint_best.pt`` / ``checkpoint_last.pt``).
pt_regexp = re.compile(
    r'checkpoint(\d+|_\d+_\d+|_[a-z]+)\.pt'
)

# ``checkpointN.pt`` -- group 1 captures N.
pt_regexp_epoch_based = re.compile(
    r'checkpoint(\d+)\.pt'
)

# ``checkpoint_E_U.pt`` -- group 1 captures only the trailing number U.
pt_regexp_update_based = re.compile(
    r'checkpoint_\d+_(\d+)\.pt'
)
def parse_checkpoints(files):
    """Extract ``(number, filename)`` pairs from numbered checkpoint names.

    Each name is tried against ``pt_regexp_epoch_based`` first and, if that
    fails, against ``pt_regexp_update_based``.  Names matching neither
    pattern (including ``checkpoint_best.pt`` / ``checkpoint_last.pt``) are
    skipped.

    Args:
        files: iterable of bare file names (no directory component).

    Returns:
        A list of ``(int, str)`` tuples in input order, where the int is the
        first capture group of whichever pattern matched and the str is the
        matched file name.
    """
    parsed = []
    for name in files:
        match = pt_regexp_epoch_based.fullmatch(name)
        if match is None:
            # Not ``checkpointN.pt``; fall back to ``checkpoint_E_U.pt``.
            match = pt_regexp_update_based.fullmatch(name)
        if match is None:
            continue
        parsed.append((int(match.group(1)), match.group(0)))
    return parsed
def last_n_checkpoints(files, n):
    """Return the names of the ``n`` highest-numbered checkpoints.

    Args:
        files: iterable of bare file names.
        n: how many names to return at most.

    Returns:
        File names ordered from the highest number downwards.
    """
    ranked = parse_checkpoints(files)
    ranked.sort(reverse=True)
    return [name for _number, name in ranked[:n]]
def every_n_checkpoints(files, n):
    """Return the name of every ``n``-th checkpoint, counted from the end.

    The numbered checkpoints are sorted ascending and then sampled with a
    stride of ``n`` starting at the highest-numbered one, so that one is
    always included.

    Args:
        files: iterable of bare file names.
        n: sampling stride; must be non-zero (it is used as a slice step).

    Returns:
        The selected file names in ascending order of their number.
    """
    ascending = sorted(parse_checkpoints(files))
    # Negative step: walk backwards from the last (highest) entry.
    picked = ascending[::-n]
    picked.sort()
    return [name for _number, name in picked]
def main():
    """Interactively prune ``checkpoint*.pt`` files under the given directories.

    Parses the command line, walks every directory in ``root_dirs``
    recursively and sorts each file whose name matches ``pt_regexp`` into
    one of three groups:

    * preserve and dereference: a kept file that is a symlink is replaced by
      a copy of its target (skipped with ``--no-dereference``);
    * preserve: the file is left untouched;
    * delete: the file is removed.

    A file is kept when it is ``checkpoint_last.pt`` / ``checkpoint_best.pt``
    (unless ``--delete-last`` / ``--delete-best`` is given) or when it is
    selected by ``--save-last`` / ``--save-every``.  Inside a directory whose
    name starts with ``test_`` nothing is kept unless ``--preserve-test`` is
    given.

    The planned operations are printed and nothing is modified until the
    user answers ``y`` at the prompt.  The function calls ``sys.exit(0)``
    when there is nothing to do or when the user answers ``n``.
    """
    parser = argparse.ArgumentParser(
        description=(
            'Recursively delete checkpoint files from `root_dir`, '
            'but preserve checkpoint_best.pt and checkpoint_last.pt'
        )
    )
    parser.add_argument('root_dirs', nargs='*')
    parser.add_argument('--save-last', type=int, default=0, help='number of last checkpoints to save')
    parser.add_argument('--save-every', type=int, default=0, help='interval of checkpoints to save')
    parser.add_argument('--preserve-test', action='store_true',
                        help='preserve checkpoints in dirs that start with test_ prefix (default: delete them)')
    parser.add_argument('--delete-best', action='store_true', help='delete checkpoint_best.pt')
    parser.add_argument('--delete-last', action='store_true', help='delete checkpoint_last.pt')
    parser.add_argument('--no-dereference', action='store_true', help='don\'t dereference symlinks')
    args = parser.parse_args()

    files_to_desymlink = []
    files_to_preserve = []
    files_to_delete = []
    for root_dir in args.root_dirs:
        for root, _subdirs, files in os.walk(root_dir):
            # Numbered checkpoints selected for keeping in this directory.
            to_save = set()
            if args.save_last > 0:
                to_save.update(last_n_checkpoints(files, args.save_last))
            if args.save_every > 0:
                to_save.update(every_n_checkpoints(files, args.save_every))

            # os.path.basename('test_run/') is '' -- normalise first so a
            # root directory given with a trailing slash is still recognised
            # as a test_ directory.
            dir_name = os.path.basename(os.path.normpath(root))
            keeping_allowed = args.preserve_test or not dir_name.startswith('test_')

            for file in files:
                if not pt_regexp.fullmatch(file):
                    continue
                full_path = os.path.join(root, file)
                keep = keeping_allowed and (
                    (file == 'checkpoint_last.pt' and not args.delete_last)
                    or (file == 'checkpoint_best.pt' and not args.delete_best)
                    or file in to_save
                )
                if not keep:
                    files_to_delete.append(full_path)
                elif os.path.islink(full_path) and not args.no_dereference:
                    files_to_desymlink.append(full_path)
                else:
                    files_to_preserve.append(full_path)

    if not files_to_desymlink and not files_to_delete:
        print('Nothing to do.')
        sys.exit(0)

    files_to_desymlink = sorted(files_to_desymlink)
    files_to_preserve = sorted(files_to_preserve)
    files_to_delete = sorted(files_to_delete)

    print('Operations to perform (in order):')
    for file in files_to_desymlink:
        print(' - preserve (and dereference symlink): ' + file)
    for file in files_to_preserve:
        print(' - preserve: ' + file)
    for file in files_to_delete:
        print(' - delete: ' + file)

    # Keep asking until the answer is an unambiguous yes or no.
    while True:
        resp = input('Continue? (Y/N): ').strip().lower()
        if resp == 'y':
            break
        if resp == 'n':
            sys.exit(0)

    print('Executing...')
    # Symlinks are materialised before any deletion so that a link target
    # which is itself scheduled for deletion still exists when it is copied.
    for file in files_to_desymlink:
        realpath = os.path.realpath(file)
        print('rm ' + file)
        os.remove(file)
        print('cp {} {}'.format(realpath, file))
        shutil.copyfile(realpath, file)
    for file in files_to_delete:
        print('rm ' + file)
        os.remove(file)
# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()