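"""ComfyUI custom nodes for TangoFlux text-to-audio generation.

Provides three nodes: TangoFluxLoader (downloads and loads the TangoFlux model
and Oobleck VAE, optionally patching the transformer with TeaCache),
TangoFluxSampler (generates audio latents from a text prompt), and
TangoFluxVAEDecodeAndPlay (decodes latents to waveforms and saves audio files).
"""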
import os
import logging
import json
import random
import re

import torch
import torchaudio
from diffusers import AutoencoderOobleck, FluxTransformer2DModel
from huggingface_hub import snapshot_download

import folder_paths
from comfy.utils import load_torch_file, ProgressBar

from tangoflux.model import TangoFlux

from .teacache import teacache_forward

log = logging.getLogger("TangoFlux")

TANGOFLUX_DIR = os.path.join(folder_paths.models_dir, "tangoflux")

if "tangoflux" not in folder_paths.folder_names_and_paths:
    current_paths = [TANGOFLUX_DIR]
else:
    current_paths, _ = folder_paths.folder_names_and_paths["tangoflux"]
folder_paths.folder_names_and_paths["tangoflux"] = (
    current_paths,
    folder_paths.supported_pt_extensions,
)

TEXT_ENCODER_DIR = os.path.join(folder_paths.models_dir, "text_encoders")
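
# Loader node: downloads the TangoFlux and text encoder weights from the
# Hugging Face Hub on first use, builds the TangoFlux model and Oobleck VAE,
# and optionally patches FluxTransformer2DModel.forward with TeaCache to skip
# redundant transformer computation during sampling.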
class TangoFluxLoader:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "enable_teacache": ("BOOLEAN", {"default": False}),
                "rel_l1_thresh": (
                    "FLOAT",
                    {"default": 0.25, "min": 0.0, "max": 10.0, "step": 0.01},
                ),
            },
        }

    RETURN_TYPES = ("TANGOFLUX_MODEL", "TANGOFLUX_VAE")
    RETURN_NAMES = ("model", "vae")
    OUTPUT_TOOLTIPS = ("TangoFlux Model", "TangoFlux VAE")
    CATEGORY = "TangoFlux"
    FUNCTION = "load_tangoflux"
    DESCRIPTION = "Load TangoFlux model"

    def __init__(self):
        self.model = None
        self.vae = None
        self.enable_teacache = False
        self.rel_l1_thresh = 0.25
        # Keep a reference to the unpatched forward so TeaCache can be toggled off again.
        self.original_forward = FluxTransformer2DModel.forward
    def load_tangoflux(
        self,
        enable_teacache=False,
        rel_l1_thresh=0.25,
        tangoflux_path=TANGOFLUX_DIR,
        text_encoder_path=TEXT_ENCODER_DIR,
        device="cuda",
    ):
        if self.model is None or self.enable_teacache != enable_teacache:
            pbar = ProgressBar(6)

            snapshot_download(
                repo_id="declare-lab/TangoFlux",
                allow_patterns=["*.json", "*.safetensors"],
                local_dir=tangoflux_path,
                local_dir_use_symlinks=False,
            )
            pbar.update(1)

            log.info("Loading config")
            with open(os.path.join(tangoflux_path, "config.json"), "r") as f:
                config = json.load(f)
            pbar.update(1)

            # Sanitize the text encoder repo id so it can be used as a directory name.
            text_encoder = re.sub(
                r'[<>:"/\\|?*]',
                "-",
                config.get("text_encoder_name", "google/flan-t5-large"),
            )
            text_encoder_path = os.path.join(text_encoder_path, text_encoder)
            snapshot_download(
                repo_id=config.get("text_encoder_name", "google/flan-t5-large"),
                allow_patterns=["*.json", "*.safetensors", "*.model"],
                local_dir=text_encoder_path,
                local_dir_use_symlinks=False,
            )
            pbar.update(1)

            log.info("Loading TangoFlux models")
            del self.model
            self.model = None

            model_weights = load_torch_file(
                os.path.join(tangoflux_path, "tangoflux.safetensors"),
                device=torch.device(device),
            )
            pbar.update(1)

            if enable_teacache:
                log.info("Enabling TeaCache")
                FluxTransformer2DModel.forward = teacache_forward
            else:
                log.info("Disabling TeaCache")
                FluxTransformer2DModel.forward = self.original_forward

            model = TangoFlux(config=config, text_encoder_dir=text_encoder_path)
            model.load_state_dict(model_weights, strict=False)
            model.to(device)

            if enable_teacache:
                # TeaCache bookkeeping is stored on the transformer class itself.
                model.transformer.__class__.enable_teacache = True
                model.transformer.__class__.cnt = 0
                model.transformer.__class__.rel_l1_thresh = rel_l1_thresh
                model.transformer.__class__.accumulated_rel_l1_distance = 0
                model.transformer.__class__.previous_modulated_input = None
                model.transformer.__class__.previous_residual = None
            pbar.update(1)

            self.model = model
            del model
            self.enable_teacache = enable_teacache
            self.rel_l1_thresh = rel_l1_thresh

            # The VAE is loaded once and cached across calls.
            if self.vae is None:
                log.info("Loading TangoFlux VAE")
                vae_weights = load_torch_file(
                    os.path.join(tangoflux_path, "vae.safetensors")
                )
                self.vae = AutoencoderOobleck()
                self.vae.load_state_dict(vae_weights)
                self.vae.to(device)
            pbar.update(1)

        if self.enable_teacache and self.rel_l1_thresh != rel_l1_thresh:
            self.model.transformer.__class__.rel_l1_thresh = rel_l1_thresh
            self.rel_l1_thresh = rel_l1_thresh

        return (self.model, self.vae)
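
# Sampler node: runs TangoFlux's inference_flow to turn a text prompt into
# audio latents, reporting progress to ComfyUI after every denoising step.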
class TangoFluxSampler:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("TANGOFLUX_MODEL",),
                "prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
                "steps": ("INT", {"default": 50, "min": 1, "max": 10000, "step": 1}),
                "guidance_scale": (
                    "FLOAT",
                    {"default": 3, "min": 1, "max": 100, "step": 1},
                ),
                "duration": ("INT", {"default": 10, "min": 1, "max": 30, "step": 1}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF}),
                "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
            },
        }

    RETURN_TYPES = ("TANGOFLUX_LATENTS",)
    RETURN_NAMES = ("latents",)
    OUTPUT_TOOLTIPS = ("TangoFlux Sample",)
    CATEGORY = "TangoFlux"
    FUNCTION = "sample"
    DESCRIPTION = "Sampler for TangoFlux"

    def sample(
        self,
        model,
        prompt,
        steps=50,
        guidance_scale=3,
        duration=10,
        seed=0,
        batch_size=1,
        device="cuda",
    ):
        pbar = ProgressBar(steps)

        with torch.no_grad():
            model.to(device)

            try:
                if model.transformer.__class__.enable_teacache:
                    # TeaCache needs the total step count for its caching schedule.
                    model.transformer.__class__.num_steps = steps
            except AttributeError:
                pass

            log.info("Generating latents with TangoFlux")
            latents = model.inference_flow(
                prompt,
                duration=duration,
                num_inference_steps=steps,
                guidance_scale=guidance_scale,
                seed=seed,
                num_samples_per_prompt=batch_size,
                callback_on_step_end=lambda: pbar.update(1),
            )

        return ({"latents": latents, "duration": duration},)
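
# Output node: decodes sampled latents back into waveforms with the Oobleck
# VAE, trims them to the requested duration, and writes audio files to
# ComfyUI's output (or temp) directory so they can be previewed in the UI.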
class TangoFluxVAEDecodeAndPlay:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "vae": ("TANGOFLUX_VAE",),
                "latents": ("TANGOFLUX_LATENTS",),
                "filename_prefix": ("STRING", {"default": "TangoFlux"}),
                "format": (
                    ["wav", "mp3", "flac", "aac", "wma"],
                    {"default": "wav"},
                ),
                "save_output": ("BOOLEAN", {"default": True}),
            },
        }

    RETURN_TYPES = ()
    OUTPUT_NODE = True
    CATEGORY = "TangoFlux"
    FUNCTION = "play"
    DESCRIPTION = "Decoder and Player for TangoFlux"

    def decode(self, vae, latents):
        results = []
        for latent in latents:
            # The VAE expects (batch, channels, time), so swap the last two dims.
            decoded = vae.decode(latent.unsqueeze(0).transpose(2, 1)).sample.cpu()
            results.append(decoded)
        results = torch.cat(results, dim=0)
        return results

    def play(
        self,
        vae,
        latents,
        filename_prefix="TangoFlux",
        format="wav",
        save_output=True,
        device="cuda",
    ):
        audios = []
        pbar = ProgressBar(len(latents["latents"]) + 2)

        if save_output:
            output_dir = folder_paths.get_output_directory()
            prefix_append = ""
            output_type = "output"
        else:
            output_dir = folder_paths.get_temp_directory()
            prefix_append = "_temp_" + "".join(
                random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)
            )
            output_type = "temp"

        filename_prefix += prefix_append
        full_output_folder, filename, counter, subfolder, _ = (
            folder_paths.get_save_image_path(filename_prefix, output_dir)
        )
        os.makedirs(full_output_folder, exist_ok=True)
        pbar.update(1)

        duration = latents["duration"]
        latents = latents["latents"]

        vae.to(device)

        log.info("Decoding TangoFlux latents")
        waves = self.decode(vae, latents)
        pbar.update(1)

        for wave in waves:
            # Trim each waveform to the requested duration before saving.
            waveform_end = int(duration * vae.config.sampling_rate)
            wave = wave[:, :waveform_end]

            file = f"{filename}_{counter:05}_.{format}"
            torchaudio.save(
                os.path.join(full_output_folder, file),
                wave,
                sample_rate=vae.config.sampling_rate,
            )
            counter += 1
            audios.append(
                {"filename": file, "subfolder": subfolder, "type": output_type}
            )
            pbar.update(1)

        return {
            "ui": {"audios": audios},
        }
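
# Node registration: ComfyUI discovers these classes via NODE_CLASS_MAPPINGS.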
NODE_CLASS_MAPPINGS = {
    "TangoFluxLoader": TangoFluxLoader,
    "TangoFluxSampler": TangoFluxSampler,
    "TangoFluxVAEDecodeAndPlay": TangoFluxVAEDecodeAndPlay,
}