from typing import Dict, Any, Optional import torch import torch.nn as nn from lightning.fabric.utilities.types import _PATH import logging logger = logging.getLogger(__name__) class ModelLoader: def __init__(self,): super().__init__() def load(self, denoiser, prefix=""): if denoiser.weight_path: weight = torch.load(denoiser.weight_path, map_location=torch.device('cpu')) if denoiser.load_ema: prefix = "ema_denoiser." + prefix else: prefix = "denoiser." + prefix for k, v in denoiser.state_dict().items(): try: v.copy_(weight["state_dict"][prefix+k]) except: logger.warning(f"Failed to copy {prefix+k} to denoiser weight") return denoiser