|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math |
|
import torch |
|
from typing import Union |
|
from torch.distributions import LogisticNormal |
|
|
|
|
|
class LogitNormalTrainingTimesteps:
    """Sampler that draws training time values from a ``LogisticNormal`` distribution.

    Parameters
    ----------
    T : float
        must be strictly positive (checked with an ``assert``); it is stored on
        the instance as ``self.T`` but is not used by ``sample``
    loc : float
        location parameter forwarded to ``torch.distributions.LogisticNormal``
    scale : float
        scale parameter forwarded to ``torch.distributions.LogisticNormal``
    """

    def __init__(self, T=1000.0, loc=0.0, scale=1.0):
        assert T > 0
        self.T = T
        self.dist = LogisticNormal(loc, scale)

    def sample(self, size, device):
        """Draw time values and move them to ``device``.

        Parameters
        ----------
        size : sample shape passed unchanged to ``self.dist.sample``
        device : target device passed to ``Tensor.to``

        Returns
        -------
        t : Tensor
            the first coordinate (index 0 of the last axis) of every draw,
            located on ``device``
        """
        draws = self.dist.sample(size)
        # Only the leading component of each draw is used as the time value.
        first_coordinate = draws[..., 0]
        return first_coordinate.to(device)
|
|
|
|
|
def pad_t_like_x(t, x):
    """Reshape the time vector t so that it broadcasts against x.

    Parameters
    ----------
    x : Tensor, shape (bs, *dim)
        represents the source minibatch; only its number of dimensions is used
    t : FloatTensor, shape (bs), or a Python float / int

    Returns
    -------
    t : Tensor, shape (bs, 1, ..., 1) with ``x.dim() - 1`` trailing singleton
        axes. A Python float or int is returned unchanged.

    Example
    -------
    x: Tensor (bs, C, W, H)
    t: Vector (bs)
    pad_t_like_x(t, x): Tensor (bs, 1, 1, 1)
    """
    if not isinstance(t, (float, int)):
        # One singleton axis for every non-batch dimension of x.
        trailing_ones = (1,) * (x.dim() - 1)
        return t.reshape(-1, *trailing_ones)
    # Plain Python scalars already broadcast against any tensor.
    return t
|
|
|
|
|
class ConditionalFlowMatcher:
    """Base class for conditional flow matching methods.

    Implements the independent conditional flow matching method from [1] and is
    meant to be subclassed by the other flow matching variants.

    It provides:
    - the mean and standard deviation of the gaussian probability path
      N(t * x1 + (1 - t) * x0, sigma)
    - sampling of xt from that path
    - the conditional vector field ut(x1|x0) = x1 - x0
    - a time sampler (``self.time_sampler``) used when no t is supplied
    - the weighting function ``compute_lambda``

    References
    ----------
    [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
    """

    def __init__(self, sigma: Union[float, int] = 0.0):
        """Store the hyper-parameter sigma and create the default time sampler.

        Parameters
        ----------
        sigma : Union[float, int]
            value returned by ``compute_sigma_t`` and used in ``compute_lambda``
        """
        self.sigma = sigma
        # Built with its default arguments; used only when t is not supplied to
        # sample_location_and_conditional_flow.
        self.time_sampler = LogitNormalTrainingTimesteps()

    def compute_mu_t(self, x0, x1, t):
        """Compute the mean of the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)

        Returns
        -------
        mean mu_t: t * x1 + (1 - t) * x0

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        weight = pad_t_like_x(t, x0)
        target_part = weight * x1
        source_part = (1 - weight) * x0
        return target_part + source_part

    def compute_sigma_t(self, t):
        """Compute the standard deviation of the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].

        Parameters
        ----------
        t : FloatTensor, shape (bs)
            accepted for interface uniformity; this implementation ignores it

        Returns
        -------
        standard deviation sigma (the value stored in ``self.sigma``)

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        return self.sigma

    def sample_xt(self, x0, x1, t, epsilon):
        """Draw a sample from the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)
        epsilon : Tensor, shape (bs, *dim)
            noise sample from N(0, 1)

        Returns
        -------
        xt : Tensor, shape (bs, *dim)
            mu_t + sigma_t * epsilon

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        mean = self.compute_mu_t(x0, x1, t)
        # sigma_t goes through pad_t_like_x so it broadcasts over the non-batch
        # axes of x0.
        std = pad_t_like_x(self.compute_sigma_t(t), x0)
        return mean + std * epsilon

    def compute_conditional_flow(self, x0, x1, t, xt):
        """Compute the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)
            not used by this implementation
        xt : Tensor, shape (bs, *dim)
            represents the samples drawn from probability path pt;
            not used by this implementation

        Returns
        -------
        ut : conditional vector field ut(x1|x0) = x1 - x0

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        return x1 - x0

    def sample_noise_like(self, x):
        """Return standard normal noise with the same shape, dtype and device as x."""
        return torch.randn_like(x)

    def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False):
        """Compute the sample xt (drawn from N(t * x1 + (1 - t) * x0, sigma))
        and the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        (optionally) t : Tensor, shape (bs)
            represents the time levels
            if None, drawn from ``self.time_sampler`` and cast with ``type_as(x0)``
        return_noise : bool
            return the noise sample epsilon

        Returns
        -------
        t : FloatTensor, shape (bs)
        xt : Tensor, shape (bs, *dim)
            represents the samples drawn from probability path pt
        ut : conditional vector field ut(x1|x0) = x1 - x0
        (optionally) eps: Tensor, shape (bs, *dim) such that xt = mu_t + sigma_t * epsilon

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        batch_size = x0.shape[0]
        if t is None:
            t = self.time_sampler.sample([batch_size], x0.device).type_as(x0)
        assert len(t) == batch_size, "t has to have batch size dimension"

        eps = self.sample_noise_like(x0)
        xt = self.sample_xt(x0, x1, t, eps)
        ut = self.compute_conditional_flow(x0, x1, t, xt)
        return (t, xt, ut, eps) if return_noise else (t, xt, ut)

    def compute_lambda(self, t):
        """Compute the lambda function 2 * sigma_t / (sigma**2 + 1e-8).

        The original docstring points to Eq.(23) of the reference below.

        Parameters
        ----------
        t : FloatTensor, shape (bs)

        Returns
        -------
        lambda : score weighting function

        References
        ----------
        Simulation-free Schrodinger bridges via score and flow matching, Preprint, Tong et al.
        """
        numerator = 2 * self.compute_sigma_t(t)
        # The 1e-8 term keeps the division defined when sigma is 0.
        return numerator / (self.sigma**2 + 1e-8)
|
|
|
|
|
class VariancePreservingConditionalFlowMatcher(ConditionalFlowMatcher):
    """Albergo et al. 2023 trigonometric interpolants class.

    This class inherits from ConditionalFlowMatcher and overrides the
    compute_mu_t and compute_conditional_flow functions in order to compute
    [3]'s trigonometric interpolants.

    References
    ----------
    [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
    """

    @staticmethod
    def _cos_sin(t, x0):
        """Return (cos(pi t / 2), sin(pi t / 2)) ready to multiply with x0-shaped tensors.

        Parameters
        ----------
        t : FloatTensor, shape (bs), or a Python float / int
        x0 : Tensor, shape (bs, *dim)
            only used to decide how t is reshaped

        Returns
        -------
        (cos, sin) : pair of Tensors when t is a Tensor, pair of Python floats
            when t is a Python scalar
        """
        angle = math.pi / 2 * pad_t_like_x(t, x0)
        if isinstance(angle, float):
            # pad_t_like_x hands Python scalars back untouched, and torch.cos /
            # torch.sin only accept tensors, so use the math module here.
            return math.cos(angle), math.sin(angle)
        return torch.cos(angle), torch.sin(angle)

    def compute_mu_t(self, x0, x1, t):
        """Compute the mean of the probability path (Eq.5) from [3].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs), or a Python float / int

        Returns
        -------
        mean mu_t: cos(pi t/2)x0 + sin(pi t/2)x1

        References
        ----------
        [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
        """
        cos_t, sin_t = self._cos_sin(t, x0)
        return cos_t * x0 + sin_t * x1

    def compute_conditional_flow(self, x0, x1, t, xt):
        """Compute the conditional vector field similar to [3].

        ut(x1|x0) = pi/2 (cos(pi*t/2) x1 - sin(pi*t/2) x0),
        see Eq.(21) [3].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs), or a Python float / int
        xt : Tensor, shape (bs, *dim)
            represents the samples drawn from probability path pt;
            not used by this implementation

        Returns
        -------
        ut : conditional vector field
            ut(x1|x0) = pi/2 (cos(pi*t/2) x1 - sin(pi*t/2) x0)

        References
        ----------
        [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
        """
        del xt
        cos_t, sin_t = self._cos_sin(t, x0)
        return math.pi / 2 * (cos_t * x1 - sin_t * x0)
|
|