File size: 3,515 Bytes
dab5199
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import importlib

import omegaconf

from .models import ContrastiveModel, DiffuserSTDiT, ResNet18, SegDiTTransformer2DModel


def parse_klass_arg(value, full_config):
    """
    Resolve a config argument that may name a Python object.

    Handling, in order:

    * Non-string values are returned unchanged.
    * A string of the form ``"${key.path}"`` is treated as an OmegaConf
      interpolation and looked up in ``full_config`` with
      ``OmegaConf.select``. If the lookup raises, an error is printed and
      ``None`` is returned; a missing key also yields ``None``.
    * A dotted string whose parts are all Python identifiers (for example
      ``"torch.nn.ReLU"``) is imported: the longest importable module prefix
      is loaded and the remaining parts are resolved as nested attributes.
    * Anything else (plain strings, file paths, numbers written as strings,
      unresolvable dotted names) is returned unchanged.

    Args:
        value: Raw argument value taken from the config.
        full_config: Config used to resolve interpolations. It can be an
            OmegaConf ``DictConfig``/``ListConfig`` or any plain container
            that ``OmegaConf.create`` accepts.

    Returns:
        The imported object, the resolved interpolation value, ``None`` when
        an interpolation fails, or ``value`` itself.
    """
    if not isinstance(value, str):
        return value

    if value.startswith("${") and value.endswith("}"):
        try:
            cfg = full_config
            if not isinstance(cfg, (omegaconf.DictConfig, omegaconf.ListConfig)):
                # OmegaConf.select needs an OmegaConf container, not a plain dict.
                cfg = omegaconf.OmegaConf.create(cfg)
            # OmegaConf.resolve() resolves in place and returns None, so it
            # cannot be indexed; select() returns the (resolved) node value.
            value = omegaconf.OmegaConf.select(
                cfg, value[2:-1], throw_on_missing=True
            )
        except Exception as e:
            print(f"Error resolving OmegaConf interpolation {value}: {e}")
            return None
        if not isinstance(value, str):
            return value

    if "." not in value:
        return value

    parts = value.split(".")
    # Only identifier-only dotted names can be import paths. Strings such as
    # "./data/x.csv" or ".5" would otherwise make importlib raise TypeError /
    # ValueError (relative or empty module names) rather than ImportError.
    if not all(part.isidentifier() for part in parts):
        return value

    # Try the longest module prefix first, then walk the remaining parts as
    # nested attributes of that module.
    for i in range(len(parts) - 1, 0, -1):
        module_name = ".".join(parts[:i])
        attr = parts[i]  # bound up front so the warning below can always use it
        try:
            result = importlib.import_module(module_name)
            for attr in parts[i:]:
                result = getattr(result, attr)
            return result
        except ImportError:
            continue
        except AttributeError as e:
            print(
                f"Warning: Could not resolve attribute {attr} from {module_name}, error: {e}"
            )
    # No valid import and attribute resolution: keep the original string.
    return value


def instantiate_class_from_config(config, *args, **kwargs):
    """
    Instantiate the class named by ``config.target``.

    The class is looked up by its bare name in this module's globals first,
    so the model classes imported at the top of this module are found
    directly. If it is not there, it is imported from the dotted ``target``
    path. Each entry of ``config.args`` is passed through
    :func:`parse_klass_arg` and supplied as a keyword argument.

    Args:
        config: ``DictConfig``, or anything ``OmegaConf.create`` accepts, with
            a dotted ``target`` string and an optional ``args`` mapping.
            Interpolations are resolved before the arguments are read.
        *args: Positional arguments forwarded to the constructor.
        **kwargs: Keyword arguments forwarded to the constructor. They take
            precedence over same-named entries from ``config.args``.

    Returns:
        The constructed instance.

    Raises:
        ValueError: If ``target`` cannot be resolved to a callable.
    """
    if not isinstance(config, omegaconf.DictConfig):
        config = omegaconf.OmegaConf.create(config)

    target = config.target
    # rpartition also copes with a bare class name (no module part).
    module_name, _, class_name = target.rpartition(".")

    klass = globals().get(class_name)
    if klass is None and module_name:
        # Not imported into this module: fall back to a dynamic import.
        try:
            klass = getattr(importlib.import_module(module_name), class_name)
        except (ImportError, AttributeError) as e:
            raise ValueError(f"Could not resolve target {target!r}: {e}") from e
    if not callable(klass):
        raise ValueError(f"Target {target!r} does not resolve to a callable")

    container = omegaconf.OmegaConf.to_container(config, resolve=True)
    # ``args`` may be absent, or present but empty (None) in YAML.
    conf_kwargs = {
        key: parse_klass_arg(value, container)
        for key, value in (container.get("args") or {}).items()
    }

    # Explicitly passed kwargs override values coming from the config.
    return klass(*args, **{**conf_kwargs, **kwargs})


def unscale_latents(latents, vae_scaling=None):
    """
    Undo per-channel latent normalisation, modifying ``latents`` in place.

    When ``vae_scaling`` is given, each channel (dimension 1) is multiplied
    by ``vae_scaling["std"]`` and then ``vae_scaling["mean"]`` is added. Both
    are reshaped with ``.view`` so they broadcast over every other dimension.

    Args:
        latents: 4-D or 5-D tensor-like object supporting ``ndim`` and
            in-place ``*=`` / ``+=``.
        vae_scaling: Optional mapping with ``"std"`` and ``"mean"`` entries.
            When ``None``, ``latents`` is returned untouched.

    Returns:
        ``latents`` (the same object, updated in place).

    Raises:
        ValueError: If scaling is requested and ``latents`` is not 4-D or 5-D.
    """
    if vae_scaling is None:
        return latents

    # Keep the channel axis, broadcast over batch and spatial/temporal axes.
    channel_shapes = {4: (1, -1, 1, 1), 5: (1, -1, 1, 1, 1)}
    shape = channel_shapes.get(latents.ndim)
    if shape is None:
        raise ValueError("Latents should be 4D or 5D")

    latents *= vae_scaling["std"].view(*shape)
    latents += vae_scaling["mean"].view(*shape)
    return latents