TorchGeo
isaaccorley commited on
Commit
c5f9a84
·
verified ·
1 Parent(s): d2fdab6

Upload convert.py

Browse files
Files changed (1) hide show
  1. convert.py +17 -0
convert.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import segmentation_models_pytorch as smp
2
+ import torch
3
+
4
+ paths = [
5
+ "2_Class_CCBY_FTW_Pretrained.ckpt",
6
+ "2_Class_FULL_FTW_Pretrained.ckpt",
7
+ "3_Class_CCBY_FTW_Pretrained.ckpt",
8
+ "3_Class_FULL_FTW_Pretrained.ckpt",
9
+ ]
10
+ classes = [2, 2, 3, 3]
11
+ for num_classes, path in zip(classes, paths):
12
+ state_dict = torch.load(path, weights_only=True, map_location="cpu")["state_dict"]
13
+ state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
14
+ del state_dict["criterion.weight"]
15
+ model = smp.Unet(encoder_name="efficientnet-b3", in_channels=8, classes=num_classes, encoder_weights=None)
16
+ model.load_state_dict(state_dict)
17
+ torch.save(model.state_dict(), path.replace(".ckpt", ".pth"))