jzq11111's picture
Upload folder using huggingface_hub
a3e05e8 verified
raw
history blame contribute delete
2.58 kB
import torch
from abc import abstractmethod, ABC
try:
from torchdyn.core import NeuralODE
NEURALODE_INSTALLED = True
except ImportError:
NEURALODE_INSTALLED = False
class SchedulerBase(ABC):
    """Abstract interface shared by the schedulers in this module.

    A concrete scheduler must override ``set_timesteps``, ``step`` and
    ``add_noise``; until all three are implemented the class cannot be
    instantiated.
    """

    def __init__(self) -> None:
        """Initialise the base class; it keeps no state of its own."""

    @abstractmethod
    def set_timesteps(self):
        """Configure the scheduler's step count (defined by subclasses).

        The base implementation does nothing and returns ``None``.
        """

    @abstractmethod
    def step(self):
        """Advance the sample by one update (defined by subclasses).

        The base implementation does nothing and returns ``None``.
        """

    @abstractmethod
    def add_noise(self):
        """Produce a noised sample (defined by subclasses).

        The base implementation does nothing and returns ``None``.
        """
class StreamingFlowMatchingScheduler(SchedulerBase):
    """Fixed-step Euler scheduler for flow matching.

    Sampling integrates a predicted velocity from ``t_min = 0`` to
    ``t_max = 1 - sigma_min`` using ``timesteps`` equal steps.  ``add_noise``
    builds the noisy input and regression target from clean samples and noise.

    NOTE(review): ``self.timesteps`` is used both as the number of sampling
    steps (``step`` / ``sample``) and as the divisor that normalises integer
    timesteps in ``add_noise``; calling ``set_timesteps`` therefore affects
    both. Confirm callers rely on this coupling.
    """

    def __init__(self, timesteps=1000, sigma_min=1e-4) -> None:
        """Create the scheduler.

        Args:
            timesteps: Number of steps the interval ``[t_min, t_max]`` is
                divided into; also the normaliser used by ``add_noise``.
            sigma_min: Small constant; the integration stops at
                ``1 - sigma_min`` and it also scales the noise term.
        """
        super().__init__()
        self.sigma_min = sigma_min
        self.timesteps = timesteps
        self.t_min = 0
        self.t_max = 1 - self.sigma_min
        # Lazily created torchdyn solver (see ``sample_by_neuralode``).
        self.neural_ode = None
        # Vector field the cached solver was built from, so that a different
        # field passed later triggers a rebuild instead of being ignored.
        self._neural_ode_fn = None

    def set_timesteps(self, timesteps=15):
        """Set the number of steps used by subsequent calls."""
        self.timesteps = timesteps

    def _step_size(self, xt):
        """Return the Euler step size as a tensor that broadcasts over ``xt``.

        The tensor has shape ``(B, 1, ..., 1)`` with as many axes as ``xt``,
        so multiplying it with a velocity of ``xt``'s shape scales every batch
        element instead of aligning with the last axis.
        """
        step = (self.t_max - self.t_min) / self.timesteps
        shape = (xt.shape[0],) + (1,) * (xt.dim() - 1)
        return step * torch.ones(shape, dtype=xt.dtype, device=xt.device)

    def step(self, xt, predicted_v):
        """Apply one Euler update ``xt + h * predicted_v``.

        Args:
            xt: Current sample; its first axis is treated as the batch axis.
            predicted_v: Velocity to integrate, broadcastable against ``xt``.

        Returns:
            The updated sample.
        """
        return xt + self._step_size(xt) * predicted_v

    def sample(self, ode_wrapper, time_steps, xt, verbose=False, x0=None):
        """Integrate ``ode_wrapper`` with Euler steps over ``time_steps``.

        Args:
            ode_wrapper: Callable ``(t, xt) -> velocity``.
            time_steps: Iterable of time values passed to ``ode_wrapper``,
                one Euler update per element.
            xt: Starting sample; its first axis is treated as the batch axis.
            verbose: If true and ``x0`` is given, print the L1 distance
                between ``x0 - xt`` (taken at the starting ``xt``) and each
                predicted velocity.
            x0: Reference tensor for the verbose distance; without it the
                distance cannot be computed and logging is skipped.

        Returns:
            The sample after the last update.

        NOTE(review): the step size comes from ``self.timesteps``, not from
        ``len(time_steps)``; presumably callers pass exactly that many time
        values -- confirm.
        """
        h = self._step_size(xt)
        # Only log when there is a reference to compare against; previously
        # ``verbose=True`` with the default ``x0=None`` raised a TypeError.
        log_distance = verbose and x0 is not None
        if log_distance:
            gt_v = x0 - xt
        for t in time_steps:
            predicted_v = ode_wrapper(t, xt)
            if log_distance:
                # l1_loss already reduces to the mean by default.
                dist = torch.nn.functional.l1_loss(gt_v, predicted_v)
                print("Time: {}, Distance: {}".format(t, dist))
            xt = xt + h * predicted_v
        return xt

    def sample_by_neuralode(self, ode_wrapper, time_steps, xt, verbose=False, x0=None):
        """Integrate ``ode_wrapper`` with torchdyn's ``NeuralODE`` (Euler).

        Args:
            ode_wrapper: Vector field handed to ``NeuralODE``.
            time_steps: Time span passed to the solver.
            xt: Starting sample.
            verbose: Unused; kept for signature parity with ``sample``.
            x0: Unused; kept for signature parity with ``sample``.

        Returns:
            The last state of the solver trajectory.

        Raises:
            ImportError: If ``torchdyn`` could not be imported.
        """
        if not NEURALODE_INSTALLED:
            raise ImportError("NeuralODE is not installed, please install it first.")
        cached_fn = getattr(self, "_neural_ode_fn", None)
        # Build the solver on first use, and rebuild it when a solver this
        # method created earlier was made for a different vector field.  A
        # solver assigned to ``self.neural_ode`` from outside is left alone.
        stale = cached_fn is not None and cached_fn is not ode_wrapper
        if self.neural_ode is None or stale:
            self.neural_ode = NeuralODE(ode_wrapper, solver='euler', sensitivity="adjoint", atol=self.sigma_min, rtol=self.sigma_min)
            self._neural_ode_fn = ode_wrapper
        _, traj = self.neural_ode(xt, time_steps)
        return traj[-1]

    def add_noise(self, original_samples: torch.FloatTensor,
                  noise: torch.FloatTensor,
                  timesteps: torch.IntTensor,):
        """Mix clean samples with noise and return the regression target.

        With ``t = timesteps / self.timesteps`` the method computes
        ``x_noisy = t * x + (1 - (1 - sigma_min) * t) * noise`` and
        ``ut = x - (1 - sigma_min) * noise``.

        Args:
            original_samples: Clean samples ``x``; first axis is the batch.
            noise: Noise tensor, broadcastable against ``original_samples``.
            timesteps: Per-sample timestep values.  A 1-D tensor is expanded
                with trailing singleton axes to the rank of
                ``original_samples``; any other rank gets two axes inserted
                at position 1, as before.

        Returns:
            Tuple ``(x_noisy, ut)``.
        """
        # Original note (translated): "has nothing to do with the gradient of ut".
        ut = original_samples - (1 - self.sigma_min) * noise
        t = timesteps.float() / self.timesteps
        if t.dim() == 1:
            # (B,) -> (B, 1, ..., 1) so each sample gets its own scalar t
            # whatever the rank of the samples (identical to the former
            # double unsqueeze for 3-D samples).
            t = t.reshape((-1,) + (1,) * (original_samples.dim() - 1))
        else:
            t = t.unsqueeze(1).unsqueeze(1)
        x_noisy = t * original_samples + (1. - (1 - self.sigma_min) * t) * noise
        return x_noisy, ut