"""Extract the ``'module'`` entry of a checkpoint and save it as a release file.

Usage::

    python <this script> path/to/checkpoint.pt

The checkpoint is loaded onto the CPU, its ``'module'`` entry is taken as-is
(no keys are filtered or converted), and the result is written next to the
input with ``_release`` inserted into the file name, e.g.
``checkpoint.pt`` -> ``checkpoint_release.pt``.
"""
import argparse
import os

import torch


def release_path(filename: str) -> str:
    """Return the output path for *filename* with ``_release`` inserted.

    Only the final path component is modified, so directory names that
    happen to contain ``.pt`` are left untouched. Within the file name the
    marker is placed before the last ``.pt`` occurrence (``model.pt`` ->
    ``model_release.pt``, ``model.pth`` -> ``model_release.pth``). If the
    name contains no ``.pt`` at all, the marker is placed before the
    extension instead, so the result never equals the input path and the
    source checkpoint cannot be overwritten.
    """
    directory, basename = os.path.split(filename)
    stem, marker, tail = basename.rpartition('.pt')
    if marker:
        released = f"{stem}_release{marker}{tail}"
    else:
        # No '.pt' in the name: fall back to inserting before the extension.
        root, ext = os.path.splitext(basename)
        released = f"{root}_release{ext}"
    return os.path.join(directory, released)


def main(argv=None):
    """Parse the command line, load the checkpoint and write the release file.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Raises:
        KeyError: If the loaded checkpoint has no ``'module'`` entry.
    """
    parser = argparse.ArgumentParser(description='Hyperparams')
    parser.add_argument('filename', nargs='?', type=str, default=None)
    args = parser.parse_args(argv)

    # The positional stays optional for CLI compatibility, but there is
    # nothing to do without it: fail with a usage message, not a traceback.
    if args.filename is None:
        parser.error('a checkpoint filename is required')

    # NOTE(security): torch.load unpickles the file and can execute arbitrary
    # code; only run this on checkpoints from a trusted source.
    checkpoint = torch.load(args.filename, map_location=torch.device('cpu'))

    try:
        model = checkpoint['module']
    except KeyError as err:
        raise KeyError(
            f"checkpoint {args.filename!r} has no 'module' entry"
        ) from err

    torch.save(model, release_path(args.filename))


if __name__ == '__main__':
    main()