File size: 2,576 Bytes
a3e05e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
from abc import abstractmethod, ABC
try:
    from torchdyn.core import NeuralODE
    NEURALODE_INSTALLED = True
except ImportError:
    NEURALODE_INSTALLED = False

class SchedulerBase(ABC):
    """Abstract scheduler interface.

    Concrete schedulers must implement ``set_timesteps``, ``step`` and
    ``add_noise``; the class cannot be instantiated until all three exist.
    """

    def __init__(self) -> None:
        """Hold no shared state; exists so subclasses can call ``super().__init__()``."""

    @abstractmethod
    def set_timesteps(self):
        """Configure the timestep schedule (defined by the subclass)."""

    @abstractmethod
    def step(self):
        """Advance the sample by one solver step (defined by the subclass)."""

    @abstractmethod
    def add_noise(self):
        """Corrupt clean samples with noise (defined by the subclass)."""


class StreamingFlowMatchingScheduler(SchedulerBase):
    """Flow-matching scheduler with a fixed-step Euler sampler.

    Training (``add_noise``): with ``t = timesteps / self.timesteps``, builds
    ``x_t = t * x1 + (1 - (1 - sigma_min) * t) * noise`` and returns the
    target velocity ``u_t = x1 - (1 - sigma_min) * noise``.

    Sampling (``step`` / ``sample`` / ``sample_by_neuralode``): integrates a
    velocity field from ``t_min = 0`` towards ``t_max = 1 - sigma_min`` using
    a uniform step ``(t_max - t_min) / self.timesteps``.

    NOTE(review): ``self.timesteps`` is shared by both sides. ``add_noise``
    normalises integer timesteps by it, while ``set_timesteps`` overwrites it
    for sampling, so calling ``set_timesteps`` before ``add_noise`` changes
    the training schedule. Confirm this is intended.
    """

    def __init__(self, timesteps=1000, sigma_min=1e-4) -> None:
        """
        Args:
            timesteps: number of discrete steps. Used to normalise timesteps
                in ``add_noise`` and as the Euler step count when sampling.
            sigma_min: noise scale left at ``t = 1`` in ``add_noise``; the
                sampler integrates up to ``t_max = 1 - sigma_min``.
        """
        super().__init__()

        self.sigma_min = sigma_min
        self.timesteps = timesteps
        self.t_min = 0
        self.t_max = 1 - self.sigma_min

        # Lazily built torchdyn solver, plus the vector field it was built for
        # (so a different field triggers a rebuild instead of stale reuse).
        self.neural_ode = None
        self._neural_ode_wrapper = None

    def _step_size(self):
        """Return the uniform Euler step ``(t_max - t_min) / timesteps`` as a float."""
        return (self.t_max - self.t_min) / self.timesteps

    def set_timesteps(self, timesteps=15):
        """Set the number of Euler steps used by the samplers (also affects ``add_noise``)."""
        self.timesteps = timesteps

    def step(self, xt, predicted_v):
        """Advance ``xt`` by one explicit Euler step along ``predicted_v``.

        Args:
            xt: current state tensor.
            predicted_v: velocity, presumably the same shape as ``xt``.

        Returns:
            ``xt + h * predicted_v`` with ``h = (t_max - t_min) / timesteps``.
        """
        # The step is the same for every sample, so a scalar is exact and
        # broadcasts against any shape. The previous (B,)-shaped tensor was
        # broadcast against the *last* dim and raised when B != xt.shape[-1].
        return xt + self._step_size() * predicted_v

    def sample(self, ode_wrapper, time_steps, xt, verbose=False, x0=None):
        """Integrate ``ode_wrapper`` with explicit Euler steps over ``time_steps``.

        Args:
            ode_wrapper: callable ``(t, x) -> velocity``.
            time_steps: iterable of time values passed to ``ode_wrapper``.
                NOTE(review): the step size comes from ``self.timesteps``, not
                ``len(time_steps)``; they should agree for the integration to
                end at ``t_max``.
            xt: initial state.
            verbose: if True, print the L1 distance between each predicted
                velocity and ``x0 - xt`` (computed from the initial ``xt``).
            x0: reference sample; required when ``verbose`` is True.

        Returns:
            The state after the last Euler step.

        Raises:
            ValueError: if ``verbose`` is True and ``x0`` is None.
        """
        h = self._step_size()

        if verbose:
            if x0 is None:
                raise ValueError("x0 must be provided when verbose=True")
            # Reference velocity taken once from the initial state.
            gt_v = x0 - xt

        for t in time_steps:
            predicted_v = ode_wrapper(t, xt)
            if verbose:
                dist = torch.mean(torch.nn.functional.l1_loss(gt_v, predicted_v))
                print("Time: {}, Distance: {}".format(t, dist))
            xt = xt + h * predicted_v
        return xt

    def sample_by_neuralode(self, ode_wrapper, time_steps, xt, verbose=False, x0=None):
        """Integrate with torchdyn's ``NeuralODE`` (Euler solver) and return the final state.

        The solver is cached and rebuilt only when a different ``ode_wrapper``
        object is passed.

        Args:
            ode_wrapper: vector field handed to ``NeuralODE``.
            time_steps: time span passed to the solver.
            xt: initial state.
            verbose: accepted for signature parity with ``sample``; unused.
            x0: accepted for signature parity with ``sample``; unused.

        Returns:
            The last element of the solver's trajectory.

        Raises:
            ImportError: if torchdyn is not installed.
        """
        if not NEURALODE_INSTALLED:
            raise ImportError("NeuralODE is not installed, please install it first.")

        # getattr keeps instances created without __init__ working.
        cached_wrapper = getattr(self, "_neural_ode_wrapper", None)
        if self.neural_ode is None or cached_wrapper is not ode_wrapper:
            self.neural_ode = NeuralODE(ode_wrapper, solver='euler', sensitivity="adjoint",
                                        atol=self.sigma_min, rtol=self.sigma_min)
            self._neural_ode_wrapper = ode_wrapper

        _, traj = self.neural_ode(xt, time_steps)
        return traj[-1]

    def add_noise(self, original_samples: torch.FloatTensor,
                  noise: torch.FloatTensor,
                  timesteps: torch.IntTensor):
        """Build the noisy interpolant ``x_t`` and the flow-matching target ``u_t``.

        With ``t = timesteps / self.timesteps`` per sample:
        ``x_t = t * x1 + (1 - (1 - sigma_min) * t) * noise`` and
        ``u_t = x1 - (1 - sigma_min) * noise``.

        Args:
            original_samples: clean samples ``x1``; assumes the first dim is
                the batch, aligned with ``timesteps``.
            noise: noise tensor, presumably the same shape as ``original_samples``.
            timesteps: one integer timestep per batch element.

        Returns:
            ``(x_noisy, ut)``.
        """
        # Target velocity; it does not depend on t.
        # (Original note, translated: "unrelated to ut's gradient".)
        ut = original_samples - (1 - self.sigma_min) * noise
        # Shape t as (B, 1, ..., 1) so it broadcasts over every non-batch dim.
        # The old unsqueeze(1).unsqueeze(1) hard-coded 3-D samples and gave a
        # (B, B, D) result for 2-D inputs; for 3-D inputs this is identical.
        t_shape = (-1,) + (1,) * (original_samples.dim() - 1)
        t_unsqueeze = timesteps.reshape(t_shape).float() / self.timesteps
        x_noisy = t_unsqueeze * original_samples + (1. - (1 - self.sigma_min) * t_unsqueeze) * noise
        return x_noisy, ut