File size: 4,665 Bytes
7cc4b41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import math
from typing import List

import torch


def get_betas(name: str, num_steps: int = 1000, shift_snr: float = 1, terminal_pure_noise: bool = False):
    """
    Build a beta schedule by name, optionally shifting its SNR and enforcing pure noise at the terminal step.

    Args:
        name (str): one of "squared_linear", "cosine", "alphas_cumprod_linear", "sigmoid".
        num_steps (int): number of diffusion timesteps.
        shift_snr (float): multiplicative factor applied to the schedule's SNR (1 = no shift).
        terminal_pure_noise (bool): if True, force alpha_bar to reach 0 at the last step.

    Returns:
        torch.Tensor: betas of shape (num_steps,).

    Raises:
        NotImplementedError: if `name` is not a known schedule.
    """
    # Cosine-style schedules can reach pure noise simply by allowing beta to hit 1.
    beta_cap = 1 if terminal_pure_noise else 0.999
    builders = {
        "squared_linear": lambda: torch.linspace(0.00085**0.5, 0.012**0.5, num_steps) ** 2,
        "cosine": lambda: get_cosine_betas(num_steps, max_beta=beta_cap),
        "alphas_cumprod_linear": lambda: get_alphas_cumprod_linear_betas(num_steps, max_beta=beta_cap),
        "sigmoid": lambda: get_sigmoid_betas(num_steps, max_beta=beta_cap, square=True, slop=0.7),
    }
    if name not in builders:
        raise NotImplementedError
    betas = builders[name]()

    # Rescale the whole schedule so SNR is multiplied by shift_snr.
    betas = shift_betas_by_snr_factor(betas, shift_snr)

    # The squared_linear schedule cannot reach pure noise via max_beta alone, so it
    # needs an explicit rescale; the alpha-bar-based schedules set max_beta=1 instead.
    if name == "squared_linear" and terminal_pure_noise:
        betas = rescale_betas_to_ensure_terminal_pure_noise(betas)

    return betas


def validate_betas(betas: List[float]) -> bool:
    """
    Validate betas is monotonic and within 0 to 1 range, i.e. 0 < beta_{t-1} < beta_{t} <= 1

    Args:
        betas (List[float]): betas

    Returns:
        bool: True if betas is a valid schedule; False otherwise (including empty input,
        which the previous implementation crashed on with an IndexError).
    """
    # Guard: an empty schedule has no valid endpoints.
    if len(betas) == 0:
        return False
    # Check the cheap endpoint bounds before the O(n) monotonicity scan.
    return betas[0] > 0 and betas[-1] <= 1 and all(b1 < b2 for b1, b2 in zip(betas, betas[1:]))


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar_fn, max_beta=0.999):
    """
    Derive per-step betas from a cumulative alpha_bar function.

    beta_t = 1 - alpha_bar(t+1) / alpha_bar(t), clipped to max_beta.

    Args:
        num_diffusion_timesteps (int): number of timesteps.
        alpha_bar_fn (Callable[[float], float]): alpha_bar as a function of t in [0, 1].
        max_beta (float): upper clip for each beta.

    Returns:
        torch.Tensor: betas of shape (num_diffusion_timesteps,), float32.
    """
    betas = []
    # Each step's alpha_bar(t2) is the next step's alpha_bar(t1); cache it so
    # alpha_bar_fn is evaluated once per timestep instead of twice.
    prev_alpha_bar = alpha_bar_fn(0.0)
    for i in range(num_diffusion_timesteps):
        curr_alpha_bar = alpha_bar_fn((i + 1) / num_diffusion_timesteps)
        betas.append(min(1 - curr_alpha_bar / prev_alpha_bar, max_beta))
        prev_alpha_bar = curr_alpha_bar
    if not validate_betas(betas):
        # Best-effort warning only: callers may still use a clipped/non-monotonic schedule.
        logging.warning("No feasible betas for given alpha bar")
    return torch.tensor(betas, dtype=torch.float32)


def get_cosine_betas(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
    """Cosine alpha_bar schedule converted to per-step betas (clipped at max_beta)."""
    def alpha_bar_fn(t):
        # Squared cosine with a small 0.008 offset so alpha_bar(0) is slightly below 1.
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    return betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar_fn, max_beta)


def get_sigmoid_betas(num_diffusion_timesteps, max_beta, square=False, slop=1):
    """Sigmoid-shaped alpha_bar schedule (optionally squared) converted to betas."""
    span = 6  # sigmoid evaluated over (-6, 6) from geodiff

    def alpha_bar_fn(t):
        def sigmoid(x):
            return 1 / (1 + math.exp(-x * slop))

        # Normalize the sigmoid so alpha_bar runs from 1 down to 0 over t in [0, 1].
        lo = sigmoid(-span)
        hi = sigmoid(span)
        normalized = (sigmoid(span - t * 2 * span) - lo) / (hi - lo)
        return normalized ** (2 if square else 1)

    return betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar_fn, max_beta)


def get_alphas_cumprod_linear_betas(num_diffusion_timesteps, max_beta):
    """Schedule whose alpha_bar decreases linearly from 1 to 0, converted to betas."""
    return betas_for_alpha_bar(num_diffusion_timesteps, lambda t: 1 - t, max_beta=max_beta)


def shift_betas_by_snr_factor(betas: torch.Tensor, factor: float) -> torch.Tensor:
    """
    Scale the schedule's signal-to-noise ratio by `factor` and return the equivalent betas.

    Args:
        betas (torch.Tensor): input beta schedule.
        factor (float): multiplicative SNR factor; 1.0 returns the input unchanged.

    Returns:
        torch.Tensor: betas whose alpha_bar-derived SNR equals factor * original SNR.
    """
    if factor == 1.0:
        # Identity shift: hand back the original tensor untouched.
        return betas

    # betas -> alpha_bar -> SNR.
    alphas_cumprod = (1 - betas).cumprod(dim=0)
    snr = factor * (alphas_cumprod / (1 - alphas_cumprod))

    # Shifted SNR -> alpha_bar -> per-step alphas (ratio of consecutive alpha_bars) -> betas.
    shifted_cumprod = snr / (1 + snr)
    step_alphas = torch.cat([shifted_cumprod[0:1], shifted_cumprod[1:] / shifted_cumprod[:-1]])
    return 1 - step_alphas


def rescale_betas_to_ensure_terminal_pure_noise(betas: torch.Tensor) -> torch.Tensor:
    """
    Rescale a beta schedule so the final timestep is pure noise (alpha_bar[-1] == 0)
    while keeping alpha_bar[0] unchanged.

    Args:
        betas (torch.Tensor): input beta schedule.

    Returns:
        torch.Tensor: rescaled betas; the last beta becomes exactly 1.
    """
    # betas -> sqrt of cumulative alpha product.
    sqrt_cumprod = (1 - betas).cumprod(0).sqrt()

    # Affine rescale in sqrt-space: first value is preserved, last value maps to 0.
    first = sqrt_cumprod[0]
    last = sqrt_cumprod[-1]
    sqrt_cumprod = (sqrt_cumprod - last) / (first - last) * first

    # sqrt(alpha_bar) -> alpha_bar -> per-step alphas (consecutive ratios) -> betas.
    rescaled_cumprod = sqrt_cumprod ** 2
    step_alphas = torch.cat([rescaled_cumprod[0:1], rescaled_cumprod[1:] / rescaled_cumprod[:-1]])
    return 1 - step_alphas