import math
import torch
from src.diffusion.base.scheduling import *


class DDPMScheduler(BaseScheduler):
    """Discrete DDPM noise schedule with linearly spaced betas.

    ``t`` is an integer step index into precomputed tables of length
    ``num_steps``; floating-point ``t`` is truncated by the cast to
    ``torch.long``.  Results are reshaped to ``(-1, 1, 1, 1)`` so they
    broadcast against 4-D batches, and ``alpha(t)**2 + sigma(t)**2 == 1``.

    Continuous-time quantities (derivatives, SDE coefficients, loss weight)
    are deliberately unsupported here and raise ``NotImplementedError``.
    """

    def __init__(
        self,
        beta_min=0.0001,
        beta_max=0.02,
        num_steps=1000,
        device=None,
    ):
        """Precompute the per-step schedule tables.

        Args:
            beta_min: Beta at step 0.
            beta_max: Beta at step ``num_steps - 1``.
            num_steps: Number of discrete steps (length of every table).
            device: Device holding the tables.  ``None`` (default) uses
                CUDA when available and falls back to CPU, instead of the
                previous unconditional ``"cuda"`` that crashed on CPU hosts.
        """
        super().__init__()
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.num_steps = num_steps
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.betas_table = torch.linspace(
            self.beta_min, self.beta_max, self.num_steps, device=device
        )
        # Cumulative product of (1 - beta); alpha(t) returns its square root,
        # i.e. this table holds the squared signal scale at each step.
        self.alphas_table = torch.cumprod(1 - self.betas_table, dim=0)
        # Complement of alphas_table; sigma(t) returns its square root, so
        # this table holds the squared noise scale at each step.
        self.sigmas_table = 1 - self.alphas_table

    def _gather_step(self, table, t) -> Tensor:
        """Return ``table`` at integer steps ``t``, reshaped to (-1, 1, 1, 1)."""
        # Move the index onto the table's device so that ``t`` may live on
        # either CPU or CUDA regardless of where the tables were built.
        index = t.to(device=table.device, dtype=torch.long)
        return table[index].view(-1, 1, 1, 1)

    def beta(self, t) -> Tensor:
        """Return the per-step beta at integer steps ``t``."""
        return self._gather_step(self.betas_table, t)

    def alpha(self, t) -> Tensor:
        """Return the signal scale ``sqrt(alphas_table[t])``."""
        return self._gather_step(self.alphas_table, t) ** 0.5

    def sigma(self, t) -> Tensor:
        """Return the noise scale ``sqrt(sigmas_table[t])``."""
        return self._gather_step(self.sigmas_table, t) ** 0.5

    # The quantities below are not defined for this discrete schedule.

    def dsigma(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def dalpha_over_alpha(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def dsigma_mul_sigma(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def dalpha(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def drift_coefficient(self, t):
        raise NotImplementedError("wrong usage")

    def diffuse_coefficient(self, t):
        raise NotImplementedError("wrong usage")

    def w(self, t):
        raise NotImplementedError("wrong usage")


class VPScheduler(BaseScheduler):
    """Continuous-time variance-preserving schedule with a linear beta(t).

    With ``t`` clamped to ``[1e-3, 1]`` and ``beta_d = beta_max - beta_min``::

        beta(t)  = beta_min + beta_d * t
        B(t)     = beta_min * t + 0.5 * beta_d * t**2   # integral of beta on [0, t]
        alpha(t) = exp(-B(t) / 2)
        sigma(t) = sqrt(1 - exp(-B(t)))

    so ``alpha(t)**2 + sigma(t)**2 == 1``.  All results are reshaped to
    ``(-1, 1, 1, 1)``.  Derivatives and SDE coefficients are deliberately
    unsupported and raise ``NotImplementedError``.
    """

    def __init__(
        self,
        beta_min=0.1,
        beta_max=20,
    ):
        """Store the linear beta endpoints.

        Args:
            beta_min: beta(t) at t = 0 (before clamping).
            beta_max: beta(t) at t = 1.
        """
        super().__init__()
        self.beta_min = beta_min
        self.beta_d = beta_max - beta_min

    def _clamp_t(self, t) -> Tensor:
        """Clamp ``t`` into ``[1e-3, 1]`` before evaluating the schedule."""
        return torch.clamp(t, min=1e-3, max=1)

    def _integrated_beta(self, t) -> Tensor:
        """Return B(t), the integral of beta from 0 to the clamped ``t``."""
        t = self._clamp_t(t)
        return 0.5 * self.beta_d * t ** 2 + self.beta_min * t

    def beta(self, t) -> Tensor:
        """Return ``beta_min + beta_d * t`` at the clamped ``t``."""
        t = self._clamp_t(t)
        return (self.beta_min + (self.beta_d * t)).view(-1, 1, 1, 1)

    def alpha(self, t) -> Tensor:
        """Return the signal scale ``exp(-B(t) / 2)``."""
        return torch.exp(-0.5 * self._integrated_beta(t)).view(-1, 1, 1, 1)

    def sigma(self, t) -> Tensor:
        """Return the noise scale ``sqrt(1 - exp(-B(t)))``."""
        # -expm1(-B) == 1 - exp(-B) but avoids catastrophic cancellation when
        # B(t) is small (t near the 1e-3 clamp), which matters in float32.
        return (-torch.expm1(-self._integrated_beta(t))).sqrt().view(-1, 1, 1, 1)

    # The quantities below are intentionally not provided by this scheduler.

    def dsigma(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def dalpha_over_alpha(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def dsigma_mul_sigma(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def dalpha(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def drift_coefficient(self, t):
        raise NotImplementedError("wrong usage")

    def diffuse_coefficient(self, t):
        raise NotImplementedError("wrong usage")

    def w(self, t):
        """Loss weight; delegates to ``diffuse_coefficient``.

        NOTE(review): ``diffuse_coefficient`` above raises, so this currently
        raises ``NotImplementedError`` as well.
        """
        return self.diffuse_coefficient(t)