import torch
from abc import abstractmethod, ABC
try:
    from torchdyn.core import NeuralODE
    NEURALODE_INSTALLED = True
except ImportError:
    NEURALODE_INSTALLED = False


class SchedulerBase(ABC):
    """Minimal abstract interface that concrete schedulers must implement."""

    def __init__(self) -> None:
        pass

    @abstractmethod
    def set_timesteps(self):
        pass

    @abstractmethod
    def step(self):
        pass

    @abstractmethod
    def add_noise(self):
        pass


class StreamingFlowMatchingScheduler(SchedulerBase):
    """Flow-matching scheduler that integrates a learned velocity field with Euler steps."""

    def __init__(self, timesteps=1000, sigma_min=1e-4) -> None:
        super().__init__()
        self.sigma_min = sigma_min
        self.timesteps = timesteps
        self.t_min = 0
        self.t_max = 1 - self.sigma_min
        self.neural_ode = None

    def set_timesteps(self, timesteps=15):
        self.timesteps = timesteps

    def step(self, xt, predicted_v):
        # Single explicit Euler step along the predicted velocity field.
        h = (self.t_max - self.t_min) / self.timesteps
        h = h * torch.ones(xt.shape[0], dtype=xt.dtype, device=xt.device)
        # Reshape the per-sample step size so it broadcasts over the remaining dims.
        h = h.view(-1, *([1] * (xt.dim() - 1)))
        xt = xt + h * predicted_v
        return xt

    def sample(self, ode_wrapper, time_steps, xt, verbose=False, x0=None):
        # Fixed-step Euler integration of the velocity field over `time_steps`.
        h = (self.t_max - self.t_min) / self.timesteps
        h = h * torch.ones(xt.shape[0], dtype=xt.dtype, device=xt.device)
        # Reshape the per-sample step size so it broadcasts over the remaining dims.
        h = h.view(-1, *([1] * (xt.dim() - 1)))
        if verbose:
            gt_v = x0 - xt
        for t in time_steps:
            predicted_v = ode_wrapper(t, xt)
            if verbose:
                # L1 distance between the ground-truth and predicted velocity.
                dist = torch.mean(torch.nn.functional.l1_loss(gt_v, predicted_v))
                print("Time: {}, Distance: {}".format(t, dist))
            xt = xt + h * predicted_v
        return xt

    def sample_by_neuralode(self, ode_wrapper, time_steps, xt, verbose=False, x0=None):
        # Integrate with torchdyn's NeuralODE solver instead of the manual Euler loop.
        if not NEURALODE_INSTALLED:
            raise ImportError("torchdyn is not installed; please install it to use NeuralODE sampling.")
        if self.neural_ode is None:
            self.neural_ode = NeuralODE(ode_wrapper, solver='euler', sensitivity="adjoint",
                                        atol=self.sigma_min, rtol=self.sigma_min)
        eval_points, traj = self.neural_ode(xt, time_steps)
        return traj[-1]

    def add_noise(self, original_samples: torch.FloatTensor,
                  noise: torch.FloatTensor,
                  timesteps: torch.IntTensor):
        # Target velocity ut; unrelated to the gradient of ut.
        ut = original_samples - (1 - self.sigma_min) * noise
        # Linear interpolation between noise (t=0) and data (t=1 - sigma_min).
        t_unsqueeze = timesteps.unsqueeze(1).unsqueeze(1).float() / self.timesteps
        x_noisy = t_unsqueeze * original_samples + (1. - (1 - self.sigma_min) * t_unsqueeze) * noise
        return x_noisy, ut
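

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original file).
# `dummy_ode_wrapper`, the tensor shapes, and the time grid below are
# assumptions: they stand in for a trained velocity model with the
# (t, xt) -> v signature that `sample` expects.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    scheduler = StreamingFlowMatchingScheduler(timesteps=1000)
    scheduler.set_timesteps(15)

    def dummy_ode_wrapper(t, xt):
        # Stand-in for the trained velocity predictor; a real model would
        # return the estimated velocity field at time t.
        return torch.zeros_like(xt)

    xt = torch.randn(2, 8, 4)  # (batch, frames, dim) latent initialised from noise
    time_steps = torch.linspace(scheduler.t_min, scheduler.t_max, scheduler.timesteps)
    samples = scheduler.sample(dummy_ode_wrapper, time_steps, xt)
    print(samples.shape)  # same shape as xt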