|
|
|
|
|
import torch |
|
|
|
from scepter.modules.model.registry import NOISE_SCHEDULERS |
|
from scepter.modules.model.diffusion.schedules import BaseNoiseScheduler |
|
|
|
|
|
@NOISE_SCHEDULERS.register_class()
class LinearScheduler(BaseNoiseScheduler):
    """Noise schedule whose betas are evenly spaced between two endpoints.

    The endpoints are read from the config keys ``BETA_MIN`` and
    ``BETA_MAX``; sigmas, alphas and timesteps are derived from the
    resulting betas in :meth:`get_schedule`.
    """

    para_dict = {}

    def init_params(self):
        """Run the base-class setup, then read the beta endpoints from cfg."""
        super().init_params()
        cfg = self.cfg
        self.beta_min = cfg.get('BETA_MIN', 0.00085)
        self.beta_max = cfg.get('BETA_MAX', 0.012)

    def betas_to_sigmas(self, betas):
        """Return ``sqrt(1 - cumprod(1 - betas))`` taken along dim 0."""
        running_product = torch.cumprod(1 - betas, dim=0)
        return torch.sqrt(1 - running_product)

    def get_schedule(self):
        """Build the schedule and store it on the instance.

        Sets ``_sigmas``, ``_betas``, ``_alphas`` and ``_timesteps``.
        ``self.num_timesteps`` is not assigned in this class; presumably it
        is provided by the base class -- confirm against
        ``BaseNoiseScheduler``.
        """
        beta_values = torch.linspace(
            self.beta_min,
            self.beta_max,
            steps=self.num_timesteps,
            dtype=torch.float32,
        )
        sigma_values = self.betas_to_sigmas(beta_values)

        self._sigmas = sigma_values
        self._betas = beta_values
        # alphas satisfy alpha**2 + sigma**2 == 1.
        self._alphas = (1 - sigma_values ** 2).sqrt()
        # One float-valued timestep index per schedule entry.
        self._timesteps = torch.arange(sigma_values.shape[0],
                                       dtype=torch.float32)