buchi-stdesign commited on
Commit
a629f10
·
verified ·
1 Parent(s): 37723db

Update src/sbv2/utils.py

Browse files
Files changed (1) hide show
  1. src/sbv2/utils.py +7 -12
src/sbv2/utils.py CHANGED
@@ -1,12 +1,7 @@
1
- import torch
2
-
3
- def load_checkpoint(filepath, model, optimizer=None, strict=True):
4
- print(f"Loading checkpoint: {filepath}")
5
- checkpoint_dict = torch.load(filepath, map_location="cpu")
6
- model.load_state_dict(checkpoint_dict["model"], strict=strict)
7
-
8
- if optimizer is not None and "optimizer" in checkpoint_dict:
9
- optimizer.load_state_dict(checkpoint_dict["optimizer"])
10
-
11
- print("Checkpoint loaded.")
12
- return model
 
1
+ from safetensors.torch import load_file
2
+
3
+ def load_checkpoint(filepath, model, strict=True):
4
+ print(f"Loading checkpoint: {filepath}")
5
+ checkpoint_dict = load_file(filepath)
6
+ model.load_state_dict(checkpoint_dict, strict=strict)
7
+ return model