import torch
import numpy as np

from ....utils.general_utils import dict_foreach
from ....pipelines import samplers


class ClassifierFreeGuidanceMixin:
    def __init__(self, *args, p_uncond: float = 0.1, **kwargs):
        super().__init__(*args, **kwargs)
        # Probability of dropping the condition (replacing it with neg_cond) during training.
        self.p_uncond = p_uncond

    def get_cond(self, cond, neg_cond=None, **kwargs):
        """
        Get the conditioning data.
        """
        assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance"

        if self.p_uncond > 0:
            # Randomly drop the condition with probability p_uncond.
            def get_batch_size(cond):
                if isinstance(cond, torch.Tensor):
                    return cond.shape[0]
                elif isinstance(cond, list):
                    return len(cond)
                else:
                    raise ValueError(f"Unsupported type of cond: {type(cond)}")

            # If cond is a dict, use any one of its entries to infer the batch size.
            ref_cond = cond if not isinstance(cond, dict) else cond[list(cond.keys())[0]]
            B = get_batch_size(ref_cond)

            def select(cond, neg_cond, mask):
                # Swap cond for neg_cond wherever mask is True.
                if isinstance(cond, torch.Tensor):
                    mask = torch.tensor(mask, device=cond.device).reshape(-1, *[1] * (cond.ndim - 1))
                    return torch.where(mask, neg_cond, cond)
                elif isinstance(cond, list):
                    return [nc if m else c for c, nc, m in zip(cond, neg_cond, mask)]
                else:
                    raise ValueError(f"Unsupported type of cond: {type(cond)}")

            # Per-sample Bernoulli(p_uncond) drop mask.
            mask = list(np.random.rand(B) < self.p_uncond)
            if not isinstance(cond, dict):
                cond = select(cond, neg_cond, mask)
            else:
                cond = dict_foreach([cond, neg_cond], lambda x: select(x[0], x[1], mask))

        return cond

    def get_inference_cond(self, cond, neg_cond=None, **kwargs):
        """
        Get the conditioning data for inference.
        """
        assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance"
        return {'cond': cond, 'neg_cond': neg_cond, **kwargs}

    def get_sampler(self, **kwargs) -> samplers.FlowEulerCfgSampler:
        """
        Get the sampler for the diffusion process.
        """
        return samplers.FlowEulerCfgSampler(self.sigma_min)
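

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal demonstration of how `get_cond` randomly replaces conditioning
# rows with their negative (unconditional) counterparts with probability
# `p_uncond`. The bare `_Demo` class and the dummy tensors are hypothetical
# and exist only so the mixin can be exercised standalone; in practice the
# mixin is combined with a trainer class that also defines `sigma_min` for
# `get_sampler`.
if __name__ == "__main__":
    class _Demo(ClassifierFreeGuidanceMixin):
        pass

    demo = _Demo(p_uncond=0.5)
    cond = torch.ones(4, 8)       # dummy conditioning embeddings, batch of 4
    neg_cond = torch.zeros(4, 8)  # dummy negative (unconditional) embeddings
    mixed = demo.get_cond(cond, neg_cond=neg_cond)
    # Roughly half of the rows are expected to have been dropped to neg_cond (all zeros).
    print((mixed.sum(dim=1) == 0).tolist())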