DDT / src /utils /model_loader.py
wangshuai6
init space
9e426da
raw
history blame contribute delete
832 Bytes
from typing import Dict, Any, Optional
import torch
import torch.nn as nn
from lightning.fabric.utilities.types import _PATH
import logging
logger = logging.getLogger(__name__)


class ModelLoader:
    """Copy pretrained (optionally EMA) weights from a checkpoint into a denoiser."""

    def __init__(self):
        super().__init__()

    def load(self, denoiser, prefix=""):
        """Load checkpoint weights into ``denoiser`` in place, best-effort per entry.

        Args:
            denoiser: module exposing ``weight_path`` (falsy -> nothing is loaded),
                ``load_ema`` (bool) and ``state_dict()``.
            prefix: extra key prefix inserted after ``"ema_denoiser."`` /
                ``"denoiser."`` when looking entries up in the checkpoint.

        Returns:
            The same ``denoiser`` object, with every matching entry overwritten.

        Entries that are missing from the checkpoint, are not tensors, or whose
        shape differs from the denoiser's are skipped with a warning. Errors from
        ``torch.load`` itself (e.g. a missing file) propagate to the caller.
        """
        if not denoiser.weight_path:
            return denoiser

        # NOTE(review): torch>=2.6 defaults to weights_only=True; full training
        # checkpoints holding non-tensor objects may need weights_only=False —
        # confirm for the checkpoints in use (and only for trusted files).
        checkpoint = torch.load(denoiser.weight_path, map_location=torch.device('cpu'))

        # Checkpoint keys are namespaced by the attribute the module was stored
        # under (presumably the training wrapper's ``denoiser``/``ema_denoiser``).
        root = "ema_denoiser." if denoiser.load_ema else "denoiser."
        full_prefix = root + prefix

        # state_dict() tensors share storage with the module's parameters and
        # buffers, so copy_ below updates the denoiser in place.
        for name, target in denoiser.state_dict().items():
            key = full_prefix + name
            try:
                source = checkpoint["state_dict"][key]
            except (KeyError, TypeError):
                logger.warning("Failed to copy %s to denoiser weight", key)
                continue
            if not isinstance(source, torch.Tensor) or source.shape != target.shape:
                # copy_ broadcasts, so a compatible-but-different shape would be
                # copied silently; reject anything that is not an exact match.
                found = (
                    tuple(source.shape)
                    if isinstance(source, torch.Tensor)
                    else type(source).__name__
                )
                logger.warning(
                    "Failed to copy %s to denoiser weight: expected shape %s, got %s",
                    key,
                    tuple(target.shape),
                    found,
                )
                continue
            target.copy_(source)
        return denoiser