Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import numpy as np | |
from ....utils.general_utils import dict_foreach | |
from ....pipelines import samplers | |
class ClassifierFreeGuidanceMixin:
    """
    Mixin adding classifier-free guidance (CFG) conditioning handling.

    ``get_cond`` replaces each sample's conditioning with the negative
    conditioning with probability ``p_uncond``; ``get_inference_cond`` packs
    both conditionings together; ``get_sampler`` builds a
    ``samplers.FlowEulerCfgSampler``.

    The class reads ``self.sigma_min`` in ``get_sampler`` but never sets it,
    so the class it is mixed into is expected to provide that attribute.
    """

    def __init__(self, *args, p_uncond: float = 0.1, **kwargs):
        """
        Store the drop probability and forward the rest up the MRO.

        Args:
            *args: Positional arguments passed through to ``super().__init__``.
            p_uncond: Per-sample probability of swapping ``cond`` for
                ``neg_cond`` in ``get_cond``.
            **kwargs: Keyword arguments passed through to ``super().__init__``.
        """
        super().__init__(*args, **kwargs)
        self.p_uncond = p_uncond

    def get_cond(self, cond, neg_cond=None, **kwargs):
        """
        Get the conditioning data, randomly swapping in the negative one.

        Args:
            cond: A ``torch.Tensor``, a ``list``, or a ``dict`` of those.
            neg_cond: Negative conditioning mirroring ``cond``; required.
            **kwargs: Accepted but not used by this method.

        Returns:
            ``cond`` unchanged when ``p_uncond`` is not positive; otherwise
            the conditioning with a random subset of samples taken from
            ``neg_cond``.

        Raises:
            AssertionError: If ``neg_cond`` is ``None``.
            ValueError: If a conditioning value is neither a tensor nor a list.
        """
        assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance"

        # Dropping disabled: hand the conditioning back untouched.
        if not self.p_uncond > 0:
            return cond

        def _batch_size_of(item):
            # Tensors are measured along dim 0, lists by their length.
            if isinstance(item, torch.Tensor):
                return item.shape[0]
            if isinstance(item, list):
                return len(item)
            raise ValueError(f"Unsupported type of cond: {type(item)}")

        def _pick(positive, negative, flags):
            # Take the negative entry wherever the matching flag is set.
            if isinstance(positive, torch.Tensor):
                # Reshape flags to (-1, 1, ..., 1) so they broadcast over
                # every non-leading dimension of the tensor.
                trailing_ones = (1,) * (positive.ndim - 1)
                flag_tensor = torch.tensor(flags, device=positive.device)
                flag_tensor = flag_tensor.reshape((-1,) + trailing_ones)
                return torch.where(flag_tensor, negative, positive)
            if isinstance(positive, list):
                return [
                    neg if drop else pos
                    for pos, neg, drop in zip(positive, negative, flags)
                ]
            raise ValueError(f"Unsupported type of cond: {type(positive)}")

        is_mapping = isinstance(cond, dict)
        # For a dict, the first entry determines the batch size.
        reference = cond[list(cond.keys())[0]] if is_mapping else cond
        batch_size = _batch_size_of(reference)

        # One independent draw per sample; the same flags are used for every
        # entry of a dict conditioning.
        drop_flags = list(np.random.rand(batch_size) < self.p_uncond)

        if is_mapping:
            # NOTE(review): presumably dict_foreach applies the callable to
            # matching entries of both dicts -- confirm in general_utils.
            return dict_foreach(
                [cond, neg_cond],
                lambda pair: _pick(pair[0], pair[1], drop_flags),
            )
        return _pick(cond, neg_cond, drop_flags)

    def get_inference_cond(self, cond, neg_cond=None, **kwargs):
        """
        Get the conditioning data for inference.

        Args:
            cond: Positive conditioning, stored under ``'cond'``.
            neg_cond: Negative conditioning, stored under ``'neg_cond'``;
                required.
            **kwargs: Extra entries appended to the returned dict.

        Returns:
            A dict with ``'cond'``, ``'neg_cond'`` and every extra keyword.

        Raises:
            AssertionError: If ``neg_cond`` is ``None``.
        """
        assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance"
        inference_args = {'cond': cond, 'neg_cond': neg_cond}
        inference_args.update(kwargs)
        return inference_args

    def get_sampler(self, **kwargs) -> samplers.FlowEulerCfgSampler:
        """
        Get the sampler for the diffusion process.

        Args:
            **kwargs: Accepted but not used by this method.

        Returns:
            A ``samplers.FlowEulerCfgSampler`` built from ``self.sigma_min``.
        """
        sigma_min = self.sigma_min
        return samplers.FlowEulerCfgSampler(sigma_min)