File size: 714 Bytes
c5f9a84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import os

import segmentation_models_pytorch as smp
import torch

# Checkpoints to convert. Each is loaded with torch.load and must hold its
# weights under a top-level "state_dict" key (presumably PyTorch Lightning
# checkpoints — confirm). Each is re-saved next to the input as a bare
# smp.Unet state dict with a .pth extension.
paths = [
    "2_Class_CCBY_FTW_Pretrained.ckpt",
    "2_Class_FULL_FTW_Pretrained.ckpt",
    "3_Class_CCBY_FTW_Pretrained.ckpt",
    "3_Class_FULL_FTW_Pretrained.ckpt",
]
# Number of output classes for each entry of `paths`, in the same order.
classes = [2, 2, 3, 3]

# Prefix carried by the U-Net's parameter names inside the checkpoint.
# Entries without it (e.g. "criterion.weight") are not part of the network.
MODEL_PREFIX = "model."


def convert_checkpoint(path, num_classes):
    """Convert the checkpoint at *path* into a plain ``smp.Unet`` state dict.

    Builds an EfficientNet-B3 U-Net with 8 input channels and *num_classes*
    output classes. Loads the checkpoint's ``"model."``-prefixed weights into
    it with the prefix stripped, using strict loading. Saves the resulting
    ``model.state_dict()`` to *path* with its extension replaced by ``.pth``.

    Args:
        path: Path to the input checkpoint file.
        num_classes: Number of output classes the checkpointed model has.

    Returns:
        The path the converted weights were written to.

    Raises:
        ValueError: If the output path would equal *path* (i.e. *path*
            already ends in ``.pth``), to avoid overwriting the input.
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
        RuntimeError: From ``load_state_dict`` if the weights do not match
            the constructed architecture.
    """
    # Swap only the real file extension. A substring replace of ".ckpt" would
    # leave extension-less or differently-named paths unchanged and overwrite
    # the input checkpoint.
    out_path = os.path.splitext(path)[0] + ".pth"
    if out_path == path:
        raise ValueError(f"refusing to overwrite input checkpoint {path!r}")

    checkpoint = torch.load(path, weights_only=True, map_location="cpu")
    # Keep only the network's weights and strip just the leading prefix.
    # str.replace would also rewrite a "model." appearing later in a key,
    # e.g. from a nested submodule named `model`.
    state_dict = {
        key[len(MODEL_PREFIX):]: value
        for key, value in checkpoint["state_dict"].items()
        if key.startswith(MODEL_PREFIX)
    }

    model = smp.Unet(
        encoder_name="efficientnet-b3",
        in_channels=8,
        classes=num_classes,
        encoder_weights=None,  # no pretrained encoder; all weights come from the checkpoint
    )
    # Strict (default) loading: any missing or unexpected key is an error.
    model.load_state_dict(state_dict)
    torch.save(model.state_dict(), out_path)
    return out_path


# zip() would silently drop unmatched entries if the lists drift apart.
if len(paths) != len(classes):
    raise ValueError(
        f"paths ({len(paths)}) and classes ({len(classes)}) must have the same length"
    )
for num_classes, path in zip(classes, paths):
    convert_checkpoint(path, num_classes)