# bilegentile's picture
# Upload folder using huggingface_hub
# c19ca42 verified
import os
from typing import Tuple
import torch
from nodes.impl.pytorch.model_loading import load_state_dict
from nodes.impl.pytorch.types import PyTorchModel
from nodes.utils.unpickler import RestrictedUnpickle
def parse_ckpt_state_dict(checkpoint: dict):
    """Build a cleaned state dict from a ``.ckpt`` checkpoint mapping.

    Only entries whose key contains ``"netG."`` or ``"module."`` are kept;
    every other entry is dropped. ``"netG."`` takes precedence: when a key
    contains it, only that tag is stripped, even if ``"module."`` is also
    present. Every occurrence of the chosen tag is removed from the key.

    Args:
        checkpoint: Mapping of parameter names to values.

    Returns:
        A new dict with the cleaned keys mapped to the original values, in
        the iteration order of ``checkpoint``.
    """
    # Order matters: the first tag found in a key is the one that is removed.
    wrapper_tags = ("netG.", "module.")

    cleaned = {}
    for name, value in checkpoint.items():
        tag = next((t for t in wrapper_tags if t in name), None)
        if tag is None:
            # Neither tag present: this entry is not carried over.
            continue
        cleaned[name.replace(tag, "")] = value
    return cleaned
def load_model(path: str, device, fp16: bool = False) -> PyTorchModel:
    """Load a model file from disk and return it prepared for inference.

    The loader is picked from the (lower-cased) file extension:

    * ``.pt``   -- opened with ``torch.jit.load``; its ``state_dict()`` is used.
    * ``.pth``  -- opened with ``torch.load`` using ``RestrictedUnpickle``.
    * ``.ckpt`` -- opened like ``.pth``; if the result has a ``"state_dict"``
      entry that entry is used, and the keys are then cleaned with
      ``parse_ckpt_state_dict``.

    The state dict is handed to ``load_state_dict`` to build the model. All
    parameters then get ``requires_grad = False``, the model is switched to
    eval mode, moved to ``device`` and cast to half or full precision.

    Args:
        path: Location of the model file.
        device: Passed as ``map_location`` when reading the file and to
            ``model.to`` afterwards.
        fp16: When true the model is converted with ``.half()``, otherwise
            with ``.float()``.

    Returns:
        The loaded model (only the model -- no path components).

    Raises:
        AssertionError: If ``path`` does not exist or is not a regular file.
        ValueError: If the extension is unsupported or any loading step
            fails. The original error is attached as ``__cause__``.
    """
    # Kept as asserts so the exception type seen by existing callers does not
    # change. Note that these checks are skipped when Python runs with -O.
    assert os.path.exists(path), f"Model file at location {path} does not exist"
    assert os.path.isfile(path), f"Path {path} is not a file"

    try:
        extension = os.path.splitext(path)[1].lower()

        if extension == ".pt":
            # NOTE(review): this branch does not go through RestrictedUnpickle;
            # confirm that only trusted TorchScript archives reach it.
            state_dict = torch.jit.load(  # type: ignore
                path, map_location=device
            ).state_dict()
        elif extension == ".pth":
            state_dict = torch.load(
                path,
                map_location=device,
                pickle_module=RestrictedUnpickle,  # type: ignore
            )
        elif extension == ".ckpt":
            checkpoint = torch.load(
                path,
                map_location=device,
                pickle_module=RestrictedUnpickle,  # type: ignore
            )
            # Some checkpoints nest the weights under a "state_dict" entry.
            if "state_dict" in checkpoint:
                checkpoint = checkpoint["state_dict"]
            state_dict = parse_ckpt_state_dict(checkpoint)
        else:
            # Raised inside the try on purpose: it is re-wrapped below so the
            # caller always sees the same "unsupported by chaiNNer" message.
            raise ValueError(
                f"Unsupported model file extension {extension}. Please try a supported model type."
            )

        model = load_state_dict(state_dict)

        # Inference only: no gradients are needed for any parameter.
        for param in model.parameters():
            param.requires_grad = False
        model.eval()
        model = model.to(device)
        model = model.half() if fp16 else model.float()
    except Exception as e:
        raise ValueError(
            f"Model {os.path.basename(path)} is unsupported by chaiNNer. Please try"
            " another."
        ) from e

    return model