import logging

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)

class ModelLoader:
    def __init__(self):
        super().__init__()

    def load(self, denoiser: nn.Module, prefix: str = "") -> nn.Module:
        """Copy weights from the checkpoint at ``denoiser.weight_path`` into ``denoiser``.

        If ``denoiser.load_ema`` is set, the EMA weights (``ema_denoiser.*``) are
        loaded; otherwise the regular weights (``denoiser.*``) are used.
        """
        if denoiser.weight_path:
            # Load the checkpoint on CPU so loading does not allocate GPU memory.
            weight = torch.load(denoiser.weight_path, map_location=torch.device("cpu"))

            # Checkpoint keys are namespaced under the EMA or the regular denoiser.
            if denoiser.load_ema:
                prefix = "ema_denoiser." + prefix
            else:
                prefix = "denoiser." + prefix

            # Copy parameters and buffers in place, warning on keys that are
            # missing from the checkpoint or whose shapes do not match.
            for k, v in denoiser.state_dict().items():
                try:
                    v.copy_(weight["state_dict"][prefix + k])
                except (KeyError, RuntimeError):
                    logger.warning(f"Failed to copy {prefix + k} to denoiser weight")
        return denoiser
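

# Usage sketch (illustrative only): assumes a denoiser module that carries
# `weight_path` and `load_ema` attributes, as `ModelLoader.load` expects;
# `build_denoiser` and the checkpoint path below are hypothetical.
#
#   denoiser = build_denoiser(...)
#   denoiser.weight_path = "checkpoints/last.ckpt"
#   denoiser.load_ema = True
#   denoiser = ModelLoader().load(denoiser)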