buchi-stdesign commited on
Commit
50ea4b8
·
verified ·
1 Parent(s): 44571ac

Upload utils.py

Browse files
Files changed (1) hide show
  1. src/sbv2/utils.py +12 -0
src/sbv2/utils.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ def load_checkpoint(filepath, model, optimizer=None, strict=True):
4
+ print(f"Loading checkpoint: {filepath}")
5
+ checkpoint_dict = torch.load(filepath, map_location="cpu")
6
+ model.load_state_dict(checkpoint_dict["model"], strict=strict)
7
+
8
+ if optimizer is not None and "optimizer" in checkpoint_dict:
9
+ optimizer.load_state_dict(checkpoint_dict["optimizer"])
10
+
11
+ print("Checkpoint loaded.")
12
+ return model