Spaces:
Running
on
Zero
Running
on
Zero
File size: 832 Bytes
9e426da |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
from typing import Dict, Any, Optional
import torch
import torch.nn as nn
from lightning.fabric.utilities.types import _PATH
import logging
logger = logging.getLogger(__name__)
class ModelLoader:
    """Best-effort loader that copies checkpoint tensors into a denoiser.

    Tensors are copied in place, one state-dict entry at a time, so a
    checkpoint that is missing entries (or has mismatching ones) still
    loads partially instead of aborting.
    """

    def __init__(self,):
        super().__init__()

    def load(self, denoiser, prefix: str = ""):
        """Copy weights from ``denoiser.weight_path`` into ``denoiser``.

        Args:
            denoiser: Object exposing ``weight_path``, ``load_ema`` and
                ``state_dict()`` (presumably a ``torch.nn.Module`` -- confirm
                against callers). Nothing is loaded when ``weight_path`` is
                falsy.
            prefix: Extra key prefix appended after ``"ema_denoiser."`` (when
                ``denoiser.load_ema`` is truthy) or ``"denoiser."`` when
                looking entries up in the checkpoint.

        Returns:
            The same ``denoiser`` object, updated in place.
        """
        if not denoiser.weight_path:
            return denoiser

        # NOTE(security): torch.load unpickles the file; only point
        # ``weight_path`` at checkpoints from a trusted source.
        weight = torch.load(denoiser.weight_path, map_location=torch.device('cpu'))

        if denoiser.load_ema:
            prefix = "ema_denoiser." + prefix
        else:
            prefix = "denoiser." + prefix

        for k, v in denoiser.state_dict().items():
            key = prefix + k
            try:
                # The "state_dict" lookup stays inside the try so that a
                # checkpoint without that entry degrades to warnings, exactly
                # like a checkpoint with individual keys missing.
                v.copy_(weight["state_dict"][key])
            except Exception as exc:
                # ``Exception`` rather than a bare ``except`` so that
                # KeyboardInterrupt / SystemExit still propagate.
                logger.warning(
                    "Failed to copy %s to denoiser weight: %s", key, exc
                )
        return denoiser