import torch
import math

from diffusers_helper.k_diffusion.uni_pc_fm import sample_unipc
from diffusers_helper.k_diffusion.wrapper import fm_wrapper
from diffusers_helper.utils import repeat_to_batch_size


def flux_time_shift(t, mu=1.15, sigma=1.0):
    # Flux-style time shift: warps a uniform t in (0, 1] toward higher noise
    # levels as mu grows. For sigma=1 this is shift / (shift + 1/t - 1),
    # where shift = exp(mu).
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def calculate_flux_mu(context_length, x1=256, y1=0.5, x2=4096, y2=1.15, exp_max=7.0):
    # Linearly interpolate mu as a function of token count through the anchor
    # points (x1, y1) and (x2, y2), then clamp so that exp(mu) <= exp_max.
    k = (y2 - y1) / (x2 - x1)
    b = y1 - k * x1
    mu = k * context_length + b
    mu = min(mu, math.log(exp_max))
    return mu


def get_flux_sigmas_from_mu(n, mu):
    # n + 1 sigmas descending from 1 to 0, warped by the time shift above.
    sigmas = torch.linspace(1, 0, steps=n + 1)
    sigmas = flux_time_shift(sigmas, mu=mu)
    return sigmas


@torch.inference_mode()
def sample_hunyuan(
        transformer,
        sampler='unipc',
        initial_latent=None,
        concat_latent=None,
        strength=1.0,
        width=512,
        height=512,
        frames=16,
        real_guidance_scale=1.0,
        distilled_guidance_scale=6.0,
        guidance_rescale=0.0,
        shift=None,
        num_inference_steps=25,
        batch_size=None,
        generator=None,
        prompt_embeds=None,
        prompt_embeds_mask=None,
        prompt_poolers=None,
        negative_prompt_embeds=None,
        negative_prompt_embeds_mask=None,
        negative_prompt_poolers=None,
        dtype=torch.bfloat16,
        device=None,
        negative_kwargs=None,
        callback=None,
        **kwargs,
):
    device = device or transformer.device

    if batch_size is None:
        batch_size = int(prompt_embeds.shape[0])

    # Gaussian init noise in latent space: 16 channels, (frames + 3) // 4 latent
    # frames (4x causal temporal compression), 8x spatial downscale by the VAE.
    latents = torch.randn(
        (batch_size, 16, (frames + 3) // 4, height // 8, width // 8),
        generator=generator, device=generator.device
    ).to(device=device, dtype=torch.float32)

    B, C, T, H, W = latents.shape
    seq_length = T * H * W // 4  # transformer token count (2x2 spatial patches)

    # Either derive mu from the sequence length, or take it from an explicit shift.
    if shift is None:
        mu = calculate_flux_mu(seq_length, exp_max=7.0)
    else:
        mu = math.log(shift)

    sigmas = get_flux_sigmas_from_mu(num_inference_steps, mu).to(device)

    k_model = fm_wrapper(transformer)

    # Image-to-video style start: scale the schedule by strength and blend the
    # initial latent with noise at the first sigma.
    if initial_latent is not None:
        sigmas = sigmas * strength
        first_sigma = sigmas[0].to(device=device, dtype=torch.float32)
        initial_latent = initial_latent.to(device=device, dtype=torch.float32)
        latents = initial_latent.float() * (1.0 - first_sigma) + latents.float() * first_sigma

    if concat_latent is not None:
        concat_latent = concat_latent.to(latents)

    # Distilled CFG is passed to the model as a guidance embedding of scale * 1000.
    distilled_guidance = torch.tensor([distilled_guidance_scale * 1000.0] * batch_size).to(device=device, dtype=dtype)

    # Broadcast all conditioning tensors to the batch size.
    prompt_embeds = repeat_to_batch_size(prompt_embeds, batch_size)
    prompt_embeds_mask = repeat_to_batch_size(prompt_embeds_mask, batch_size)
    prompt_poolers = repeat_to_batch_size(prompt_poolers, batch_size)
    negative_prompt_embeds = repeat_to_batch_size(negative_prompt_embeds, batch_size)
    negative_prompt_embeds_mask = repeat_to_batch_size(negative_prompt_embeds_mask, batch_size)
    negative_prompt_poolers = repeat_to_batch_size(negative_prompt_poolers, batch_size)
    concat_latent = repeat_to_batch_size(concat_latent, batch_size)

    sampler_kwargs = dict(
        dtype=dtype,
        cfg_scale=real_guidance_scale,
        cfg_rescale=guidance_rescale,
        concat_latent=concat_latent,
        positive=dict(
            pooled_projections=prompt_poolers,
            encoder_hidden_states=prompt_embeds,
            encoder_attention_mask=prompt_embeds_mask,
            guidance=distilled_guidance,
            **kwargs,
        ),
        negative=dict(
            pooled_projections=negative_prompt_poolers,
            encoder_hidden_states=negative_prompt_embeds,
            encoder_attention_mask=negative_prompt_embeds_mask,
            guidance=distilled_guidance,
            # negative_kwargs override the shared kwargs for the negative branch.
            **(kwargs if negative_kwargs is None else {**kwargs, **negative_kwargs}),
        )
    )

    if sampler == 'unipc':
        results = sample_unipc(k_model, latents, sigmas, extra_args=sampler_kwargs, disable=False, callback=callback)
    else:
        raise NotImplementedError(f'Sampler {sampler} is not supported.')

    return results
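

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the pipeline itself): exercises only the
# schedule helpers above, so it runs without a transformer, text encoder, or
# GPU. The resolution / frame-count values are illustrative assumptions, not
# defaults taken from any caller.
if __name__ == '__main__':
    frames, height, width = 16, 512, 512
    T, H, W = (frames + 3) // 4, height // 8, width // 8
    seq_length = T * H * W // 4  # same token count sample_hunyuan derives

    mu = calculate_flux_mu(seq_length, exp_max=7.0)
    sigmas = get_flux_sigmas_from_mu(25, mu)

    print(f'seq_length={seq_length}, mu={mu:.4f}, shift=exp(mu)={math.exp(mu):.2f}')
    print(f'sigmas: {sigmas[0].item():.3f} -> {sigmas[1].item():.3f} -> ... '
          f'-> {sigmas[-2].item():.3f} -> {sigmas[-1].item():.3f}')

    # An explicit shift bypasses the length-dependent interpolation, mirroring
    # the `shift is not None` branch in sample_hunyuan.
    sigmas_shifted = get_flux_sigmas_from_mu(25, math.log(3.0))
    print(f'with shift=3.0, second sigma: {sigmas_shifted[1].item():.3f} '
          f'(vs {sigmas[1].item():.3f} for the derived mu)')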