File size: 1,815 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from typing import Union, List
import numpy as np
import torch


def extract_generator_seed(generator: Union[torch.Generator, List[torch.Generator]]) -> List[int]:
    if isinstance(generator, list):
        generator = [g.seed() for g in generator]
    else:
        generator = [generator.seed()]
    return generator


def randn_tensor(shape, dtype: np.dtype, generator: Union[torch.Generator, List[torch.Generator], int, List[int]]):
    if hasattr(generator, "seed") or (isinstance(generator, list) and hasattr(generator[0], "seed")):
        generator = extract_generator_seed(generator)
        if len(generator) == 1:
            generator = generator[0]
    return np.random.default_rng(generator).standard_normal(shape).astype(dtype)


def prepare_latents(
    init_noise_sigma: float,
    batch_size: int,
    height: int,
    width: int,
    dtype: np.dtype,
    generator: Union[torch.Generator, List[torch.Generator]],
    latents: Union[np.ndarray, None]=None,
    num_channels_latents=4,
    vae_scale_factor=8,
):
    shape = (batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        latents = randn_tensor(shape, dtype, generator)
    elif latents.shape != shape:
        raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")

    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * np.float64(init_noise_sigma)

    return latents