File size: 2,374 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
from typing import Tuple
import torch
from nodes.impl.pytorch.model_loading import load_state_dict
from nodes.impl.pytorch.types import PyTorchModel
from nodes.utils.unpickler import RestrictedUnpickle


def parse_ckpt_state_dict(checkpoint: dict) -> dict:
    """Extract a plain model state dict from a ``.ckpt`` checkpoint.

    Checkpoints saved by training frameworks often wrap parameter keys with a
    ``netG.`` (generator) or ``module.`` (DataParallel) prefix. This strips
    that leading prefix so the keys match what the architecture loader
    expects. Keys carrying neither prefix are dropped, matching the original
    behavior (they are assumed to belong to other training components).

    Note: the prefix is only removed when it appears at the *start* of the
    key — a naive ``str.replace`` would also mangle keys that merely contain
    ``module.`` somewhere in the middle (e.g. ``model.module.weight``).
    """
    state_dict = {}
    for key, value in checkpoint.items():
        if key.startswith("netG."):
            state_dict[key[len("netG."):]] = value
        elif key.startswith("module."):
            state_dict[key[len("module."):]] = value
    return state_dict


def load_model(path: str, device, fp16: bool = False) -> PyTorchModel:
    """Load a PyTorch model (.pt, .pth, or .ckpt) from disk.

    Reads the file at ``path``, extracts its state dict, auto-detects the
    architecture via ``load_state_dict``, and returns the model in eval mode
    on ``device`` with gradients disabled.

    Args:
        path: Filesystem path to the model file.
        device: ``torch.device`` (or device string) to map the weights onto.
        fp16: If True, convert the model to half precision; otherwise float32.

    Returns:
        The loaded ``PyTorchModel``.

    Raises:
        AssertionError: If ``path`` does not exist or is not a file.
        ValueError: If the extension is unsupported or loading fails
            (the original exception is chained as the cause).
    """
    assert os.path.exists(path), f"Model file at location {path} does not exist"
    assert os.path.isfile(path), f"Path {path} is not a file"
    try:
        extension = os.path.splitext(path)[1].lower()

        if extension == ".pt":
            # TorchScript archive: load the scripted module, then pull its
            # state dict so the arch detector sees plain tensors.
            state_dict = torch.jit.load(  # type: ignore
                path, map_location=device
            ).state_dict()
        elif extension == ".pth":
            # RestrictedUnpickle guards against arbitrary-code-execution
            # payloads that plain pickle would happily run.
            state_dict = torch.load(
                path,
                map_location=device,
                pickle_module=RestrictedUnpickle,  # type: ignore
            )
        elif extension == ".ckpt":
            checkpoint = torch.load(
                path,
                map_location=device,
                pickle_module=RestrictedUnpickle,  # type: ignore
            )
            # Lightning-style checkpoints nest the weights one level down.
            if "state_dict" in checkpoint:
                checkpoint = checkpoint["state_dict"]
            state_dict = parse_ckpt_state_dict(checkpoint)
        else:
            raise ValueError(
                f"Unsupported model file extension {extension}. Please try a supported model type."
            )

        model = load_state_dict(state_dict)

        # Inference only: freeze parameters and switch off training behavior
        # (dropout, batch-norm updates).
        for _, v in model.named_parameters():
            v.requires_grad = False
        model.eval()
        model = model.to(device)
        if fp16:
            model = model.half()
        else:
            model = model.float()
    except Exception as e:
        raise ValueError(
            f"Model {os.path.basename(path)} is unsupported by chaiNNer. Please try"
            " another."
        ) from e

    return model