import math

import torch

from diffusers_helper.k_diffusion.uni_pc_fm import sample_unipc
from diffusers_helper.k_diffusion.wrapper import fm_wrapper
from diffusers_helper.utils import repeat_to_batch_size


def flux_time_shift(t, mu=1.15, sigma=1.0):
    # Flux-style time shift: maps uniform t in (0, 1] to a warped schedule via
    # exp(mu) / (exp(mu) + (1/t - 1) ** sigma). Larger mu keeps sigma high for
    # longer, i.e. spends more steps at high noise levels.
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
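
# Illustrative note (arithmetic only, not used by the pipeline): the shift
# fixes the schedule's endpoints, since flux_time_shift(1.0) = e^mu / e^mu = 1
# and the value tends to 0 as t -> 0, so the warped sigmas still span [0, 1].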


def calculate_flux_mu(context_length, x1=256, y1=0.5, x2=4096, y2=1.15, exp_max=7.0):
    # Linearly interpolate mu through the points (x1, y1) and (x2, y2) as a
    # function of the token context length, then cap it so that exp(mu)
    # never exceeds exp_max.
    k = (y2 - y1) / (x2 - x1)
    b = y1 - k * x1
    mu = k * context_length + b
    mu = min(mu, math.log(exp_max))
    return mu
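
# Worked example (follows directly from the defaults above): with
# context_length=4096, mu = 1.15 exactly. The cap only bites for long
# contexts: math.log(7.0) ~= 1.9459, reached near context_length ~ 8800.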


def get_flux_sigmas_from_mu(n, mu):
    # Build n + 1 sigmas running from 1 to 0, then warp them with the Flux
    # time shift so step spacing follows the shifted schedule.
    sigmas = torch.linspace(1, 0, steps=n + 1)
    sigmas = flux_time_shift(sigmas, mu=mu)
    return sigmas
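
# Usage sketch (hypothetical numbers): get_flux_sigmas_from_mu(25, 1.15)
# returns 26 monotonically decreasing sigmas from 1.0 down to 0.0, one per
# step boundary of a 25-step sampler.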


@torch.inference_mode()
def sample_hunyuan(
    transformer,
    sampler='unipc',
    initial_latent=None,
    concat_latent=None,
    strength=1.0,
    width=512,
    height=512,
    frames=16,
    real_guidance_scale=1.0,
    distilled_guidance_scale=6.0,
    guidance_rescale=0.0,
    shift=None,
    num_inference_steps=25,
    batch_size=None,
    generator=None,
    prompt_embeds=None,
    prompt_embeds_mask=None,
    prompt_poolers=None,
    negative_prompt_embeds=None,
    negative_prompt_embeds_mask=None,
    negative_prompt_poolers=None,
    dtype=torch.bfloat16,
    device=None,
    negative_kwargs=None,
    callback=None,
    **kwargs,
):
    """Run flow-matching sampling for HunyuanVideo-style latents.

    real_guidance_scale drives classifier-free guidance (and needs the
    negative prompt tensors); distilled_guidance_scale is the embedded
    guidance value conditioned into the transformer.
    """
    device = device or transformer.device

    if batch_size is None:
        batch_size = int(prompt_embeds.shape[0])

    # Noise shape: 16 latent channels, time compressed 4x (plus one base
    # frame), space compressed 8x by the VAE. Draw on the generator's own
    # device (CPU generators produce CPU noise), falling back to the compute
    # device when no generator is given, then move to float32 on `device`.
    latents = torch.randn(
        (batch_size, 16, (frames + 3) // 4, height // 8, width // 8),
        generator=generator,
        device=generator.device if generator is not None else device,
    ).to(device=device, dtype=torch.float32)

    B, C, T, H, W = latents.shape
    # // 4 matches the transformer's 2x2 spatial patchify: tokens = T * (H/2) * (W/2).
    seq_length = T * H * W // 4

    if shift is None:
        mu = calculate_flux_mu(seq_length, exp_max=7.0)
    else:
        mu = math.log(shift)

    sigmas = get_flux_sigmas_from_mu(num_inference_steps, mu).to(device)

    k_model = fm_wrapper(transformer)

    if initial_latent is not None:
        # Image-to-video style start: scale the schedule by strength, then
        # blend the provided latent with noise at the first sigma, the usual
        # flow-matching initialization x_t = (1 - sigma) * x0 + sigma * noise.
        sigmas = sigmas * strength
        first_sigma = sigmas[0].to(device=device, dtype=torch.float32)
        initial_latent = initial_latent.to(device=device, dtype=torch.float32)
        latents = initial_latent.float() * (1.0 - first_sigma) + latents.float() * first_sigma

    if concat_latent is not None:
        concat_latent = concat_latent.to(latents)

    # The distilled (embedded) guidance is conditioned in thousandths, hence * 1000.
    distilled_guidance = torch.tensor([distilled_guidance_scale * 1000.0] * batch_size).to(device=device, dtype=dtype)

    # Broadcast all conditioning tensors to the requested batch size.
    prompt_embeds = repeat_to_batch_size(prompt_embeds, batch_size)
    prompt_embeds_mask = repeat_to_batch_size(prompt_embeds_mask, batch_size)
    prompt_poolers = repeat_to_batch_size(prompt_poolers, batch_size)
    negative_prompt_embeds = repeat_to_batch_size(negative_prompt_embeds, batch_size)
    negative_prompt_embeds_mask = repeat_to_batch_size(negative_prompt_embeds_mask, batch_size)
    negative_prompt_poolers = repeat_to_batch_size(negative_prompt_poolers, batch_size)
    concat_latent = repeat_to_batch_size(concat_latent, batch_size)

    sampler_kwargs = dict(
        dtype=dtype,
        cfg_scale=real_guidance_scale,
        cfg_rescale=guidance_rescale,
        concat_latent=concat_latent,
        positive=dict(
            pooled_projections=prompt_poolers,
            encoder_hidden_states=prompt_embeds,
            encoder_attention_mask=prompt_embeds_mask,
            guidance=distilled_guidance,
            **kwargs,
        ),
        negative=dict(
            pooled_projections=negative_prompt_poolers,
            encoder_hidden_states=negative_prompt_embeds,
            encoder_attention_mask=negative_prompt_embeds_mask,
            guidance=distilled_guidance,
            **(kwargs if negative_kwargs is None else {**kwargs, **negative_kwargs}),
        ),
    )

    if sampler == 'unipc':
        results = sample_unipc(k_model, latents, sigmas, extra_args=sampler_kwargs, disable=False, callback=callback)
    else:
        raise NotImplementedError(f'Sampler {sampler} is not supported.')

    return results
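

if __name__ == '__main__':
    # Minimal standalone sketch, not part of the library: exercises only the
    # schedule helpers, since a full sample_hunyuan call needs a loaded
    # HunyuanVideo transformer (and importing this module requires the
    # diffusers_helper package). The 832x480, 33-frame shape is an
    # illustrative assumption.
    T, H, W = (33 + 3) // 4, 480 // 8, 832 // 8
    seq_length = T * H * W // 4
    mu = calculate_flux_mu(seq_length, exp_max=7.0)
    sigmas = get_flux_sigmas_from_mu(25, mu)
    print(f'seq_length={seq_length}, mu={mu:.4f}')
    print('first sigmas:', sigmas[:3].tolist())
    print('last sigmas:', sigmas[-3:].tolist())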