|
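'''

Strip unnecessary information from a weight file (*.pth): keep only the
model state dict and remove the "_orig_mod." prefix that appears on parameter
names when a torch.compile-wrapped model is saved.

'''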
|
import argparse

import torch
|
|
|
|
|
# Default locations used when the script is run without arguments.
DEFAULT_WEIGHT_PATH = "saved_models/esrgan_best_generator.pth"
DEFAULT_STORE_PATH = "1x_APISR_RRDB_GAN_generator.pth"

# The only top-level checkpoint entry that is kept; all others are dropped.
KEEP_KEY = "model_state_dict"

# Prefix carried by parameter names saved from a torch.compile-wrapped model
# (the wrapper stores the real module as `_orig_mod`). Stripping it lets the
# weights load into a plain, uncompiled model.
COMPILE_PREFIX = "_orig_mod."


def strip_checkpoint(checkpoint, keep_key=KEEP_KEY, prefix=COMPILE_PREFIX):
    """Return a checkpoint reduced to ``{keep_key: state_dict}``.

    Every top-level entry other than ``keep_key`` is discarded. Inside the
    kept state dict, every key starting with ``prefix`` is renamed without
    it. The state dict is modified in place (renamed keys move to the end,
    as in the original script), and the same object is returned inside a
    new one-entry dict.

    Args:
        checkpoint: Mapping loaded from a training checkpoint.
        keep_key: Top-level entry to keep.
        prefix: Parameter-name prefix to strip.

    Returns:
        A new dict whose only key is ``keep_key``.

    Raises:
        KeyError: If ``checkpoint`` has no ``keep_key`` entry.
    """
    if keep_key not in checkpoint:
        raise KeyError(
            f"checkpoint has no {keep_key!r} entry; found keys: {list(checkpoint)}"
        )

    state_dict = checkpoint[keep_key]
    # Snapshot the keys first: the dict is mutated while we rename.
    for old_key in list(state_dict):
        if old_key.startswith(prefix):
            state_dict[old_key[len(prefix):]] = state_dict.pop(old_key)

    return {keep_key: state_dict}


def main(argv=None):
    """Load a training checkpoint, strip it, and save the result.

    Args:
        argv: Command-line arguments (defaults to ``sys.argv[1:]``).
    """
    parser = argparse.ArgumentParser(
        description="Keep only the model state dict of a checkpoint and "
        "remove the torch.compile '_orig_mod.' prefix from its keys."
    )
    parser.add_argument(
        "weight_path", nargs="?", default=DEFAULT_WEIGHT_PATH,
        help="input checkpoint (default: %(default)s)",
    )
    parser.add_argument(
        "store_path", nargs="?", default=DEFAULT_STORE_PATH,
        help="output weight file (default: %(default)s)",
    )
    args = parser.parse_args(argv)

    # map_location="cpu" lets GPU-saved checkpoints load on CPU-only machines
    # and makes the released weights device-agnostic.
    # NOTE: torch.load unpickles the file; only run this on trusted checkpoints.
    checkpoint_g = torch.load(args.weight_path, map_location="cpu")

    # List the top-level entries so the user can see what is being discarded.
    for key in checkpoint_g:
        print(key)

    torch.save(strip_checkpoint(checkpoint_g), args.store_path)


if __name__ == "__main__":
    main()
|
|
|
|
|
|
|
|
|
|