# Provenance (residue from a Hugging Face file page): uploaded by cavargas10
# in "Upload 13 files", commit 04fa6ac (verified); original size 2.41 kB.
import torch
import numpy as np
from ....utils.general_utils import dict_foreach
from ....pipelines import samplers
class ClassifierFreeGuidanceMixin:
    """
    Mixin that adds classifier-free guidance (CFG) handling to a trainer.

    During training (``get_cond``), each sample's condition is independently
    replaced by the negative condition with probability ``p_uncond``. For
    inference (``get_inference_cond``), both the condition and the negative
    condition are passed through so the CFG sampler can use them.

    NOTE(review): ``get_sampler`` reads ``self.sigma_min``, which this mixin
    does not define; it is presumably provided by the class this mixin is
    combined with — confirm.
    """

    def __init__(self, *args, p_uncond: float = 0.1, **kwargs):
        """
        Args:
            p_uncond: Per-sample probability of replacing the condition with
                ``neg_cond`` in ``get_cond``. Values ``<= 0`` disable dropping.
            *args, **kwargs: Forwarded unchanged to the next class in the MRO.
        """
        super().__init__(*args, **kwargs)
        self.p_uncond = p_uncond

    @staticmethod
    def _batch_size(cond) -> int:
        """Return the batch size of a tensor (size of dim 0) or list (length) condition."""
        if isinstance(cond, torch.Tensor):
            return cond.shape[0]
        if isinstance(cond, list):
            return len(cond)
        raise ValueError(f"Unsupported type of cond: {type(cond)}")

    @staticmethod
    def _select(cond, neg_cond, mask):
        """Return ``cond`` with the samples where ``mask`` is true taken from ``neg_cond``."""
        if isinstance(cond, torch.Tensor):
            # Explicit bool dtype: for an empty batch, dtype inference would
            # yield a float tensor, which torch.where rejects as a condition.
            mask = torch.as_tensor(mask, dtype=torch.bool, device=cond.device)
            # Shape (B, 1, ..., 1) so the per-sample mask broadcasts over the
            # remaining dimensions of cond.
            mask = mask.reshape(-1, *[1] * (cond.ndim - 1))
            return torch.where(mask, neg_cond, cond)
        if isinstance(cond, list):
            # zip() would silently truncate on a length mismatch and drop samples.
            if len(neg_cond) != len(cond):
                raise ValueError(
                    f"neg_cond has length {len(neg_cond)} but cond has length {len(cond)}"
                )
            return [nc if m else c for c, nc, m in zip(cond, neg_cond, mask)]
        raise ValueError(f"Unsupported type of cond: {type(cond)}")

    def get_cond(self, cond, neg_cond=None, **kwargs):
        """
        Get the conditioning data for a training step, applying condition dropout.

        Args:
            cond: Condition batch: a tensor (batch along dim 0), a list (one
                entry per sample), or a dict whose values are such tensors/lists.
            neg_cond: Negative condition substituted for dropped samples.
                Required. For tensors it only needs to broadcast against
                ``cond`` in ``torch.where``; for lists it must have the same
                length as ``cond``.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            ``cond`` itself when ``p_uncond <= 0``. Otherwise a new condition in
            which each sample is, with probability ``p_uncond``, taken from
            ``neg_cond``. For dict inputs the same mask is used for every call
            ``dict_foreach`` makes, so (presumably it pairs entries by key) a
            sample is dropped consistently across all keys.

        Raises:
            ValueError: if ``neg_cond`` is None, ``cond`` is an empty dict, a
                condition has an unsupported type, or list lengths differ.
        """
        if neg_cond is None:
            raise ValueError("neg_cond must be provided for classifier-free guidance")
        if self.p_uncond <= 0:
            return cond

        if isinstance(cond, dict):
            if not cond:
                raise ValueError("cond dict must contain at least one entry")
            # The batch size is read from the first entry of the dict.
            ref_cond = next(iter(cond.values()))
        else:
            ref_cond = cond
        batch_size = self._batch_size(ref_cond)

        # One Bernoulli(p_uncond) draw per sample from NumPy's global RNG.
        # NOTE(review): torch.manual_seed does not seed this RNG.
        mask = np.random.rand(batch_size) < self.p_uncond

        if isinstance(cond, dict):
            return dict_foreach([cond, neg_cond], lambda x: self._select(x[0], x[1], mask))
        return self._select(cond, neg_cond, mask)

    def get_inference_cond(self, cond, neg_cond=None, **kwargs):
        """
        Get the conditioning data for inference.

        Returns a dict with keys ``'cond'`` and ``'neg_cond'`` plus any extra
        keyword arguments, unchanged.

        Raises:
            ValueError: if ``neg_cond`` is None.
        """
        if neg_cond is None:
            raise ValueError("neg_cond must be provided for classifier-free guidance")
        return {'cond': cond, 'neg_cond': neg_cond, **kwargs}

    def get_sampler(self, **kwargs) -> "samplers.FlowEulerCfgSampler":
        """
        Get the sampler for the diffusion process.

        Returns a ``samplers.FlowEulerCfgSampler`` constructed from
        ``self.sigma_min``; ``**kwargs`` are accepted but ignored.
        """
        return samplers.FlowEulerCfgSampler(self.sigma_min)