FQiao's picture
Upload 70 files
3324de2 verified
raw
history blame contribute delete
9.7 kB
import os
import logging
import json
import random
import torch
import torchaudio
import re
from diffusers import AutoencoderOobleck, FluxTransformer2DModel
from huggingface_hub import snapshot_download
from comfy.utils import load_torch_file, ProgressBar
import folder_paths
from tangoflux.model import TangoFlux
from .teacache import teacache_forward
# Module-level logger for all TangoFlux nodes.
log = logging.getLogger("TangoFlux")
# Default download/storage location for TangoFlux weights inside ComfyUI's
# models directory.
TANGOFLUX_DIR = os.path.join(folder_paths.models_dir, "tangoflux")
# Register the "tangoflux" model folder with ComfyUI's path system, keeping
# any paths the user has already configured for it.
if "tangoflux" not in folder_paths.folder_names_and_paths:
    current_paths = [TANGOFLUX_DIR]
else:
    current_paths, _ = folder_paths.folder_names_and_paths["tangoflux"]
folder_paths.folder_names_and_paths["tangoflux"] = (
    current_paths,
    folder_paths.supported_pt_extensions,
)
# Parent directory for text-encoder snapshots (e.g. flan-t5-large).
TEXT_ENCODER_DIR = os.path.join(folder_paths.models_dir, "text_encoders")
class TangoFluxLoader:
    """ComfyUI node that downloads and loads the TangoFlux model and VAE.

    Weights are fetched from the Hugging Face Hub on first use and cached on
    the node instance: the model is only reloaded when the TeaCache toggle
    changes, and the VAE is loaded once per node lifetime.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "enable_teacache": ("BOOLEAN", {"default": False}),
                "rel_l1_thresh": (
                    "FLOAT",
                    {"default": 0.25, "min": 0.0, "max": 10.0, "step": 0.01},
                ),
            },
        }

    RETURN_TYPES = ("TANGOFLUX_MODEL", "TANGOFLUX_VAE")
    RETURN_NAMES = ("model", "vae")
    OUTPUT_TOOLTIPS = ("TangoFlux Model", "TangoFlux Vae")
    CATEGORY = "TangoFlux"
    FUNCTION = "load_tangoflux"
    DESCRIPTION = "Load TangoFlux model"

    def __init__(self):
        self.model = None
        self.vae = None
        self.enable_teacache = False
        self.rel_l1_thresh = 0.25
        # Keep the unpatched forward so TeaCache can be toggled off again
        # after teacache_forward has been monkey-patched in.
        self.original_forward = FluxTransformer2DModel.forward

    def load_tangoflux(
        self,
        enable_teacache=False,
        rel_l1_thresh=0.25,
        tangoflux_path=TANGOFLUX_DIR,
        text_encoder_path=TEXT_ENCODER_DIR,
        device="cuda",
    ):
        """Load (downloading if necessary) the TangoFlux model and VAE.

        Args:
            enable_teacache: Patch ``FluxTransformer2DModel.forward`` with
                the TeaCache variant to skip redundant transformer steps.
            rel_l1_thresh: TeaCache relative-L1 threshold controlling how
                aggressively steps are skipped.
            tangoflux_path: Local directory for the TangoFlux snapshot.
            text_encoder_path: Parent directory for text-encoder snapshots.
            device: Device the model and VAE are moved to.

        Returns:
            Tuple of ``(model, vae)``.
        """
        # Reload only when nothing is cached yet or the TeaCache toggle
        # changed; a threshold-only change is handled cheaply below.
        if self.model is None or self.enable_teacache != enable_teacache:
            pbar = ProgressBar(6)
            snapshot_download(
                repo_id="declare-lab/TangoFlux",
                allow_patterns=["*.json", "*.safetensors"],
                local_dir=tangoflux_path,
                local_dir_use_symlinks=False,
            )
            pbar.update(1)
            log.info("Loading config")
            with open(os.path.join(tangoflux_path, "config.json"), "r") as f:
                config = json.load(f)
            pbar.update(1)
            # Sanitize the repo id into a valid directory name on every
            # platform (e.g. "google/flan-t5-large" -> "google-flan-t5-large").
            text_encoder = re.sub(
                r'[<>:"/\\|?*]',
                "-",
                config.get("text_encoder_name", "google/flan-t5-large"),
            )
            text_encoder_path = os.path.join(text_encoder_path, text_encoder)
            snapshot_download(
                repo_id=config.get("text_encoder_name", "google/flan-t5-large"),
                allow_patterns=["*.json", "*.safetensors", "*.model"],
                local_dir=text_encoder_path,
                local_dir_use_symlinks=False,
            )
            pbar.update(1)
            log.info("Loading TangoFlux models")
            # Drop the previous model before loading new weights to reduce
            # peak memory during a reload.
            self.model = None
            model_weights = load_torch_file(
                os.path.join(tangoflux_path, "tangoflux.safetensors"),
                device=torch.device(device),
            )
            pbar.update(1)
            if enable_teacache:
                log.info("Enabling TeaCache")
                FluxTransformer2DModel.forward = teacache_forward
            else:
                log.info("Disabling TeaCache")
                FluxTransformer2DModel.forward = self.original_forward
            model = TangoFlux(config=config, text_encoder_dir=text_encoder_path)
            model.load_state_dict(model_weights, strict=False)
            model.to(device)
            if enable_teacache:
                # TeaCache state lives on the transformer *class*, matching
                # where teacache_forward reads it from.
                model.transformer.__class__.enable_teacache = True
                model.transformer.__class__.cnt = 0
                model.transformer.__class__.rel_l1_thresh = rel_l1_thresh
                model.transformer.__class__.accumulated_rel_l1_distance = 0
                model.transformer.__class__.previous_modulated_input = None
                model.transformer.__class__.previous_residual = None
            else:
                # Clear any stale flag left on the class by a previous
                # TeaCache-enabled load, so downstream checks see it off.
                model.transformer.__class__.enable_teacache = False
            pbar.update(1)
            self.model = model
            self.enable_teacache = enable_teacache
            self.rel_l1_thresh = rel_l1_thresh
        if self.vae is None:
            # NOTE(review): relies on `pbar` from the branch above; the VAE
            # is only ever None on the first load, when that branch runs.
            log.info("Loading TangoFlux VAE")
            vae_weights = load_torch_file(
                os.path.join(tangoflux_path, "vae.safetensors")
            )
            self.vae = AutoencoderOobleck()
            self.vae.load_state_dict(vae_weights)
            self.vae.to(device)
            pbar.update(1)
        # Threshold changed without a TeaCache toggle: update it in place
        # instead of reloading the whole model.
        if self.enable_teacache and self.rel_l1_thresh != rel_l1_thresh:
            self.model.transformer.__class__.rel_l1_thresh = rel_l1_thresh
            self.rel_l1_thresh = rel_l1_thresh
        return (self.model, self.vae)
class TangoFluxSampler:
    """ComfyUI node that generates TangoFlux latents from a text prompt."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("TANGOFLUX_MODEL",),
                "prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
                "steps": ("INT", {"default": 50, "min": 1, "max": 10000, "step": 1}),
                "guidance_scale": (
                    "FLOAT",
                    {"default": 3, "min": 1, "max": 100, "step": 1},
                ),
                "duration": ("INT", {"default": 10, "min": 1, "max": 30, "step": 1}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF}),
                "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
            },
        }

    RETURN_TYPES = ("TANGOFLUX_LATENTS",)
    RETURN_NAMES = ("latents",)
    OUTPUT_TOOLTIPS = "TangoFlux Sample"
    CATEGORY = "TangoFlux"
    FUNCTION = "sample"
    DESCRIPTION = "Sampler for TangoFlux"

    def sample(
        self,
        model,
        prompt,
        steps=50,
        guidance_scale=3,
        duration=10,
        seed=0,
        batch_size=1,
        device="cuda",
    ):
        """Run TangoFlux inference and return the resulting latents.

        Args:
            model: TangoFlux model produced by TangoFluxLoader.
            prompt: Text prompt describing the audio to generate.
            steps: Number of inference steps.
            guidance_scale: Classifier-free guidance scale.
            duration: Requested audio duration in seconds.
            seed: RNG seed for reproducible sampling.
            batch_size: Number of samples to generate for the prompt.
            device: Device the model is moved to before inference.

        Returns:
            One-tuple holding ``{"latents": ..., "duration": duration}``.
        """
        pbar = ProgressBar(steps)
        with torch.no_grad():
            model.to(device)
            # TeaCache needs the planned step count to budget its skipping.
            # The flag only exists when the loader enabled TeaCache, so use
            # getattr with a default instead of the previous bare `except:`
            # (which also swallowed KeyboardInterrupt/SystemExit).
            if getattr(model.transformer.__class__, "enable_teacache", False):
                model.transformer.__class__.num_steps = steps
            log.info("Generating latents with TangoFlux")
            latents = model.inference_flow(
                prompt,
                duration=duration,
                num_inference_steps=steps,
                guidance_scale=guidance_scale,
                seed=seed,
                num_samples_per_prompt=batch_size,
                callback_on_step_end=lambda: pbar.update(1),
            )
        return ({"latents": latents, "duration": duration},)
class TangoFluxVAEDecodeAndPlay:
    """ComfyUI output node: decode TangoFlux latents to audio files.

    Decodes latents with the Oobleck VAE, trims each waveform to the
    requested duration, writes audio files to the output (or temp) folder,
    and returns UI metadata so ComfyUI can play them.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "vae": ("TANGOFLUX_VAE",),
                "latents": ("TANGOFLUX_LATENTS",),
                "filename_prefix": ("STRING", {"default": "TangoFlux"}),
                "format": (
                    ["wav", "mp3", "flac", "aac", "wma"],
                    {"default": "wav"},
                ),
                "save_output": ("BOOLEAN", {"default": True}),
            },
        }

    RETURN_TYPES = ()
    OUTPUT_NODE = True
    CATEGORY = "TangoFlux"
    FUNCTION = "play"
    DESCRIPTION = "Decoder and Player for TangoFlux"

    def decode(self, vae, latents):
        """Decode a batch of latents into waveforms on the CPU.

        Latents are decoded one at a time to limit peak VAE memory, then
        concatenated along the batch dimension.
        """
        results = []
        for latent in latents:
            # The VAE expects (batch, channels, length); each latent arrives
            # without the batch dim and with the last two axes swapped —
            # TODO confirm against inference_flow's output layout.
            decoded = vae.decode(latent.unsqueeze(0).transpose(2, 1)).sample.cpu()
            results.append(decoded)
        return torch.cat(results, dim=0)

    def play(
        self,
        vae,
        latents,
        filename_prefix="TangoFlux",
        format="wav",
        save_output=True,
        device="cuda",
    ):
        """Decode latents and save one audio file per generated sample.

        Args:
            vae: Oobleck VAE produced by TangoFluxLoader.
            latents: Dict with "latents" (tensor batch) and "duration" (s).
            filename_prefix: Prefix for the saved files.
            format: Audio container/extension to save with.
            save_output: True -> output folder; False -> temp (preview) folder.
            device: Device the VAE is moved to for decoding.

        Returns:
            ComfyUI UI payload listing the saved audio files.
        """
        audios = []
        pbar = ProgressBar(len(latents) + 2)
        if save_output:
            output_dir = folder_paths.get_output_directory()
            prefix_append = ""
            file_type = "output"
        else:
            output_dir = folder_paths.get_temp_directory()
            # Random suffix keeps preview (temp) files from colliding.
            # Fixed: the alphabet previously read "...qrstupvxyz" (missing
            # 'w', duplicated 'p').
            prefix_append = "_temp_" + "".join(
                random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)
            )
            file_type = "temp"
        filename_prefix += prefix_append
        full_output_folder, filename, counter, subfolder, _ = (
            folder_paths.get_save_image_path(filename_prefix, output_dir)
        )
        os.makedirs(full_output_folder, exist_ok=True)
        pbar.update(1)
        duration = latents["duration"]
        latents = latents["latents"]
        vae.to(device)
        log.info("Decoding Tangoflux latents")
        waves = self.decode(vae, latents)
        pbar.update(1)
        for wave in waves:
            # Trim decoder padding beyond the requested duration.
            waveform_end = int(duration * vae.config.sampling_rate)
            wave = wave[:, :waveform_end]
            file = f"{filename}_{counter:05}_.{format}"
            # Save at the VAE's own sampling rate (previously hardcoded to
            # 44100, which only matched the default Oobleck config).
            torchaudio.save(
                os.path.join(full_output_folder, file),
                wave,
                sample_rate=vae.config.sampling_rate,
            )
            counter += 1
            audios.append(
                {"filename": file, "subfolder": subfolder, "type": file_type}
            )
            pbar.update(1)
        return {
            "ui": {"audios": audios},
        }
# Registration table mapping ComfyUI node identifiers to their classes;
# ComfyUI reads this at import time to discover the nodes.
NODE_CLASS_MAPPINGS = {
    "TangoFluxLoader": TangoFluxLoader,
    "TangoFluxSampler": TangoFluxSampler,
    "TangoFluxVAEDecodeAndPlay": TangoFluxVAEDecodeAndPlay,
}