import gc
import json
import os
import platform
import subprocess as sp
import warnings

import gradio as gr
import torch
import torchaudio
from aeiou.viz import audio_spectrogram_image
from einops import rearrange
from safetensors.torch import load_file
from torch.nn import functional as F
from torchaudio import transforms as T

from ..inference.generation import generate_diffusion_cond, generate_diffusion_uncond
from ..inference.utils import prepare_audio
from ..models.factory import create_model_from_config
from ..models.pretrained import get_pretrained_model
from ..models.utils import load_ckpt_state_dict
from ..training.utils import copy_state_dict
from ..data.utils import read_video, merge_video_audio

os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["TMPDIR"] = "./tmp"
warnings.filterwarnings("ignore", category=UserWarning)

# Default to CPU; create_ui() upgrades this to MPS or CUDA when available.
device = torch.device("cpu")
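# Cache for the most recently loaded model, so repeated generations skip reloading.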
current_model_name = None
current_model = None
current_sample_rate = None
current_sample_size = None
def load_model(model_name, model_config=None, model_ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, device="cuda", model_half=False):
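    """Load a model either from a pretrained name or from a model config plus
    checkpoint path, optionally overriding the pretransform weights, then move
    it to `device`. Returns (model, model_config, sample_rate, sample_size).
    """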
    if pretrained_name is not None:
        print(f"Loading pretrained model {pretrained_name}")
        model, model_config = get_pretrained_model(pretrained_name)
    elif model_config is not None and model_ckpt_path is not None:
        print("Creating model from config")
        model = create_model_from_config(model_config)
        print(f"Loading model checkpoint from {model_ckpt_path}")
        copy_state_dict(model, load_ckpt_state_dict(model_ckpt_path))
    else:
        raise ValueError("Must provide either pretrained_name or both model_config and model_ckpt_path")
    sample_rate = model_config["sample_rate"]
    sample_size = model_config["sample_size"]
    if pretransform_ckpt_path is not None:
        print(f"Loading pretransform checkpoint from {pretransform_ckpt_path}")
        model.pretransform.load_state_dict(load_ckpt_state_dict(pretransform_ckpt_path), strict=False)
        print("Done loading pretransform")
    model.to(device).eval().requires_grad_(False)
    if model_half:
        model.to(torch.float16)
    print("Done loading model")
    return model, model_config, sample_rate, sample_size
def load_and_process_audio(audio_path, sample_rate, seconds_start, seconds_total):
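    """Load an audio prompt and crop it to the [seconds_start, seconds_start + seconds_total)
    window at `sample_rate`, left-padding with silence when the clip is too short.
    Returns silent stereo audio when no path is given.
    """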
    if audio_path is None:
        return torch.zeros((2, int(sample_rate * seconds_total)))
    audio_tensor, sr = torchaudio.load(audio_path)
    # Resample to the model's rate so the index arithmetic below lines up.
    if sr != sample_rate:
        audio_tensor = T.Resample(sr, sample_rate)(audio_tensor)
    start_index = int(sample_rate * seconds_start)
    target_length = int(sample_rate * seconds_total)
    audio_tensor = audio_tensor[:, start_index:start_index + target_length]
    if audio_tensor.shape[1] < target_length:
        pad_length = target_length - audio_tensor.shape[1]
        audio_tensor = F.pad(audio_tensor, (pad_length, 0))  # left-pad with silence
    return audio_tensor
def generate_cond(
    prompt,
    negative_prompt=None,
    video_file=None,
    video_path=None,
    audio_prompt_file=None,
    audio_prompt_path=None,
    seconds_start=0,
    seconds_total=10,
    cfg_scale=6.0,
    steps=250,
    preview_every=None,
    seed=-1,
    sampler_type="dpmpp-3m-sde",
    sigma_min=0.03,
    sigma_max=1000,
    cfg_rescale=0.0,
    use_init=False,
    init_audio=None,
    init_noise_level=1.0,
    mask_cropfrom=None,
    mask_pastefrom=None,
    mask_pasteto=None,
    mask_maskstart=None,
    mask_maskend=None,
    mask_softnessL=None,
    mask_softnessR=None,
    mask_marination=None,
    batch_size=1
):
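    """Run one generation request from the UI.

    Builds conditioning from the text prompt plus optional video and audio
    prompts, samples with the configured diffusion sampler, writes the audio
    to demo_result/output.wav, and, when a video was given, muxes the audio
    back into it. Returns (output_video_path, output_wav_path).
    """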
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    print(f"Prompt: {prompt}")
    preview_images = []
    if preview_every == 0:
        preview_every = None
    # Prefer MPS on Apple Silicon, then CUDA, then fall back to CPU.
    try:
        has_mps = platform.system() == "Darwin" and torch.backends.mps.is_available()
    except Exception:
        has_mps = False
    if has_mps:
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    model_name = 'default'
    cfg = model_configurations[model_name]
    model_config_path = cfg.get("model_config")
    ckpt_path = cfg.get("ckpt_path")
    pretrained_name = cfg.get("pretrained_name")
    pretransform_ckpt_path = cfg.get("pretransform_ckpt_path")
    model_type = cfg.get("model_type", "diffusion_cond")
    if model_config_path:
        with open(model_config_path) as f:
            model_config = json.load(f)
    else:
        model_config = None
    # Fall back to 5 fps when the config is absent or does not specify a frame rate.
    target_fps = model_config.get("video_fps", 5) if model_config else 5
    global current_model_name, current_model, current_sample_rate, current_sample_size
    if current_model is None or model_name != current_model_name:
        current_model, model_config, sample_rate, sample_size = load_model(
            model_name=model_name,
            model_config=model_config,
            model_ckpt_path=ckpt_path,
            pretrained_name=pretrained_name,
            pretransform_ckpt_path=pretransform_ckpt_path,
            device=device,
            model_half=False
        )
        current_model_name = model_name
        model = current_model
        current_sample_rate = sample_rate
        current_sample_size = sample_size
    else:
        model = current_model
        sample_rate = current_sample_rate
        sample_size = current_sample_size
    # Resolve the video and audio prompt sources (an uploaded file takes precedence over a path).
    if video_file is not None:
        video_path = video_file.name
    elif video_path:
        video_path = video_path.strip()
    else:
        video_path = None
    if audio_prompt_file is not None:
        print(f'audio_prompt_file: {audio_prompt_file}')
        audio_path = audio_prompt_file.name
    elif audio_prompt_path:
        audio_path = audio_prompt_path.strip()
    else:
        audio_path = None
    video_tensors = read_video(video_path, seek_time=seconds_start, duration=seconds_total, target_fps=target_fps)
    audio_tensor = load_and_process_audio(audio_path, sample_rate, seconds_start, seconds_total)
    audio_tensor = audio_tensor.to(device)
    seconds_input = sample_size / sample_rate
    print(f'video_path: {video_path}')
    if not prompt:
        prompt = ""
    conditioning = [{
        "video_prompt": [video_tensors.unsqueeze(0)],
        "text_prompt": prompt,
        "audio_prompt": audio_tensor.unsqueeze(0),
        "seconds_start": seconds_start,
        "seconds_total": seconds_input
    }] * batch_size
    if negative_prompt:
        negative_conditioning = [{
            "video_prompt": [video_tensors.unsqueeze(0)],
            "text_prompt": negative_prompt,
            "audio_prompt": audio_tensor.unsqueeze(0),
            "seconds_start": seconds_start,
            "seconds_total": seconds_total
        }] * batch_size
    else:
        negative_conditioning = None
    try:
        device = next(model.parameters()).device
    except Exception:
        device = next(current_model.parameters()).device
    seed = int(seed)
    if not use_init:
        init_audio = None
    input_sample_size = sample_size
    if init_audio is not None:
        # Gradio supplies (sample_rate, int16 numpy array); convert to float in [-1, 1].
        in_sr, init_audio = init_audio
        init_audio = torch.from_numpy(init_audio).float().div(32767)
        if init_audio.dim() == 1:
            init_audio = init_audio.unsqueeze(0)  # [n] -> [1, n]
        elif init_audio.dim() == 2:
            init_audio = init_audio.transpose(0, 1)  # [n, c] -> [c, n]
        if in_sr != sample_rate:
            resample_tf = T.Resample(in_sr, sample_rate).to(init_audio.device)
            init_audio = resample_tf(init_audio)
        audio_length = init_audio.shape[-1]
        if audio_length > sample_size:
            # Round the generation window up to a multiple of the model's minimum input length.
            input_sample_size = audio_length + (model.min_input_length - (audio_length % model.min_input_length)) % model.min_input_length
        init_audio = (sample_rate, init_audio)
    def progress_callback(callback_info):
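        """Decode the current denoised latents and append a spectrogram preview every `preview_every` steps."""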
        nonlocal preview_images
        denoised = callback_info["denoised"]
        current_step = callback_info["i"]
        sigma = callback_info["sigma"]
        if (current_step - 1) % preview_every == 0:
            if model.pretransform is not None:
                denoised = model.pretransform.decode(denoised)
            denoised = rearrange(denoised, "b d n -> d (b n)")
            denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
            audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate)
            preview_images.append((audio_spectrogram, f"Step {current_step}, sigma={sigma:.3f}"))
    if mask_cropfrom is not None:
        mask_args = {
            "cropfrom": mask_cropfrom,
            "pastefrom": mask_pastefrom,
            "pasteto": mask_pasteto,
            "maskstart": mask_maskstart,
            "maskend": mask_maskend,
            "softnessL": mask_softnessL,
            "softnessR": mask_softnessR,
            "marination": mask_marination,
        }
    else:
        mask_args = None
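    # Dispatch on the configured model type: conditional generation consumes the
    # conditioning dicts built above, unconditional generation ignores them.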
if model_type == "diffusion_cond": | |
audio = generate_diffusion_cond( | |
model, | |
conditioning=conditioning, | |
negative_conditioning=negative_conditioning, | |
steps=steps, | |
cfg_scale=cfg_scale, | |
batch_size=batch_size, | |
sample_size=input_sample_size, | |
sample_rate=sample_rate, | |
seed=seed, | |
device=device, | |
sampler_type=sampler_type, | |
sigma_min=sigma_min, | |
sigma_max=sigma_max, | |
init_audio=init_audio, | |
init_noise_level=init_noise_level, | |
mask_args=mask_args, | |
callback=progress_callback if preview_every is not None else None, | |
scale_phi=cfg_rescale | |
) | |
elif model_type == "diffusion_uncond": | |
audio = generate_diffusion_uncond( | |
model, | |
steps=steps, | |
batch_size=batch_size, | |
sample_size=input_sample_size, | |
seed=seed, | |
device=device, | |
sampler_type=sampler_type, | |
sigma_min=sigma_min, | |
sigma_max=sigma_max, | |
init_audio=init_audio, | |
init_noise_level=init_noise_level, | |
callback=progress_callback if preview_every is not None else None | |
) | |
else: | |
raise ValueError(f"Unsupported model type: {model_type}") | |
    # Interleave the batch into one [channels, samples] tensor, then peak-normalize to int16.
    audio = rearrange(audio, "b d n -> d (b n)")
    audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
    file_name = os.path.basename(video_path) if video_path else "output"
    output_dir = "demo_result"
    os.makedirs(output_dir, exist_ok=True)
    output_video_path = f"{output_dir}/{file_name}"
    torchaudio.save(f"{output_dir}/output.wav", audio, sample_rate)
    if video_path:
        merge_video_audio(video_path, f"{output_dir}/output.wav", output_video_path, seconds_start, seconds_total)
    audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)
    del video_path
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    return (output_video_path, f"{output_dir}/output.wav")
def toggle_custom_model(selected_model):
    return gr.Row.update(visible=(selected_model == "Custom Model"))
def create_sampling_ui(model_config_map, inpainting=False):
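    """Build the generation tab: prompt, video, and audio inputs, sampler
    controls, click-to-load examples, and the output video/audio players."""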
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # 🎧 AudioX: Diffusion Transformer for Anything-to-Audio Generation
            **[Project Page](https://zeyuet.github.io/AudioX/) · [Huggingface](https://huggingface.co/Zeyue7/AudioX) · [GitHub](https://github.com/ZeyueT/AudioX)**
            """
        )
with gr.Tab("Generation"): | |
with gr.Row(): | |
with gr.Column(): | |
prompt = gr.Textbox(show_label=False, placeholder="Enter your prompt") | |
negative_prompt = gr.Textbox(show_label=False, placeholder="Negative prompt", visible=False) | |
video_path = gr.Textbox(label="Video Path", placeholder="Enter video file path") | |
video_file = gr.File(label="Upload Video File") | |
audio_prompt_file = gr.File(label="Upload Audio Prompt File", visible=False) | |
audio_prompt_path = gr.Textbox(label="Audio Prompt Path", placeholder="Enter audio file path", visible=False) | |
with gr.Row(): | |
with gr.Column(scale=6): | |
with gr.Accordion("Video Params", open=False): | |
seconds_start_slider = gr.Slider(minimum=0, maximum=512, step=1, value=0, label="Video Seconds Start") | |
seconds_total_slider = gr.Slider(minimum=0, maximum=10, step=1, value=10, label="Seconds Total", interactive=False) | |
with gr.Row(): | |
with gr.Column(scale=4): | |
with gr.Accordion("Sampler Params", open=False): | |
steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps") | |
preview_every_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Preview Every") | |
cfg_scale_slider = gr.Slider(minimum=0.0, maximum=25.0, step=0.1, value=7.0, label="CFG Scale") | |
seed_textbox = gr.Textbox(label="Seed (set to -1 for random seed)", value="-1") | |
sampler_type_dropdown = gr.Dropdown( | |
["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], | |
label="Sampler Type", | |
value="dpmpp-3m-sde" | |
) | |
sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma Min") | |
sigma_max_slider = gr.Slider(minimum=0.0, maximum=1000.0, step=0.1, value=500, label="Sigma Max") | |
cfg_rescale_slider = gr.Slider(minimum=0.0, maximum=1, step=0.01, value=0.0, label="CFG Rescale Amount") | |
with gr.Row(): | |
with gr.Column(scale=4): | |
with gr.Accordion("Init Audio", open=False, visible=False): | |
init_audio_checkbox = gr.Checkbox(label="Use Init Audio") | |
init_audio_input = gr.Audio(label="Init Audio") | |
init_noise_level_slider = gr.Slider(minimum=0.1, maximum=100.0, step=0.01, value=0.1, label="Init Noise Level") | |
gr.Markdown("## Examples") | |
with gr.Accordion("Click to show examples", open=False): | |
with gr.Row(): | |
gr.Markdown("**π Task: Text-to-Audio**") | |
with gr.Column(scale=1.2): | |
gr.Markdown("Prompt: *Typing on a keyboard*") | |
ex1 = gr.Button("Load Example") | |
with gr.Column(scale=1.2): | |
gr.Markdown("Prompt: *Ocean waves crashing*") | |
ex2 = gr.Button("Load Example") | |
with gr.Column(scale=1.2): | |
gr.Markdown("Prompt: *Footsteps in snow*") | |
ex3 = gr.Button("Load Example") | |
with gr.Row(): | |
gr.Markdown("**πΆ Task: Text-to-Music**") | |
with gr.Column(scale=1.2): | |
gr.Markdown("Prompt: *An orchestral music piece for a fantasy world.*") | |
ex4 = gr.Button("Load Example") | |
with gr.Column(scale=1.2): | |
gr.Markdown("Prompt: *Produce upbeat electronic music for a dance party*") | |
ex5 = gr.Button("Load Example") | |
with gr.Column(scale=1.2): | |
gr.Markdown("Prompt: *A dreamy lo-fi beat with vinyl crackle*") | |
ex6 = gr.Button("Load Example") | |
with gr.Row(): | |
gr.Markdown("**π¬ Task: Video-to-Audio**\nPrompt: *Generate general audio for the video*") | |
with gr.Column(scale=1.2): | |
gr.Video("example/V2A_sample-1.mp4") | |
ex7 = gr.Button("Load Example") | |
with gr.Column(scale=1.2): | |
gr.Video("example/V2A_sample-2.mp4") | |
ex8 = gr.Button("Load Example") | |
with gr.Column(scale=1.2): | |
gr.Video("example/V2A_sample-3.mp4") | |
ex9 = gr.Button("Load Example") | |
with gr.Row(): | |
gr.Markdown("**π΅ Task: Video-to-Music**\nPrompt: *Generate music for the video*") | |
with gr.Column(scale=1.2): | |
gr.Video("example/V2M_sample-1.mp4") | |
ex10 = gr.Button("Load Example") | |
with gr.Column(scale=1.2): | |
gr.Video("example/V2M_sample-2.mp4") | |
ex11 = gr.Button("Load Example") | |
with gr.Column(scale=1.2): | |
gr.Video("example/V2M_sample-3.mp4") | |
ex12 = gr.Button("Load Example") | |
            with gr.Row():
                generate_button = gr.Button("Generate", variant='primary', scale=1)
            with gr.Row():
                with gr.Column(scale=6):
                    video_output = gr.Video(label="Output Video", interactive=False)
                    audio_output = gr.Audio(label="Output Audio", interactive=False)
                    send_to_init_button = gr.Button("Send to Init Audio", scale=1, visible=False)
            send_to_init_button.click(
                fn=lambda audio: audio,
                inputs=[audio_output],
                outputs=[init_audio_input]
            )
            inputs = [
                prompt,
                negative_prompt,
                video_file,
                video_path,
                audio_prompt_file,
                audio_prompt_path,
                seconds_start_slider,
                seconds_total_slider,
                cfg_scale_slider,
                steps_slider,
                preview_every_slider,
                seed_textbox,
                sampler_type_dropdown,
                sigma_min_slider,
                sigma_max_slider,
                cfg_rescale_slider,
                init_audio_checkbox,
                init_audio_input,
                init_noise_level_slider
            ]
            generate_button.click(
                fn=generate_cond,
                inputs=inputs,
                outputs=[
                    video_output,
                    audio_output
                ],
                api_name="generate"
            )
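            # Each example button returns one value per component in `inputs`, in
            # the same order, so a single click fills the whole form.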
            ex1.click(lambda: ["Typing on a keyboard", None, None, None, None, None, 0, 10, 7.0, 100, 0, "1225575558", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
            ex2.click(lambda: ["Ocean waves crashing", None, None, None, None, None, 0, 10, 7.0, 100, 0, "3615819170", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
            ex3.click(lambda: ["Footsteps in snow", None, None, None, None, None, 0, 10, 7.0, 100, 0, "1703896811", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
            ex4.click(lambda: ["An orchestral music piece for a fantasy world.", None, None, None, None, None, 0, 10, 7.0, 100, 0, "1561898939", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
            ex5.click(lambda: ["Produce upbeat electronic music for a dance party", None, None, None, None, None, 0, 10, 7.0, 100, 0, "406022999", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
            ex6.click(lambda: ["A dreamy lo-fi beat with vinyl crackle", None, None, None, None, None, 0, 10, 7.0, 100, 0, "807934770", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
            ex7.click(lambda: ["Generate general audio for the video", None, None, "example/V2A_sample-1.mp4", None, None, 0, 10, 7.0, 100, 0, "3737819478", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
            ex8.click(lambda: ["Generate general audio for the video", None, None, "example/V2A_sample-2.mp4", None, None, 0, 10, 7.0, 100, 0, "1900718499", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
            ex9.click(lambda: ["Generate general audio for the video", None, None, "example/V2A_sample-3.mp4", None, None, 0, 10, 7.0, 100, 0, "2289822202", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
            ex10.click(lambda: ["Generate music for the video", None, None, "example/V2M_sample-1.mp4", None, None, 0, 10, 7.0, 100, 0, "3498087420", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
            ex11.click(lambda: ["Generate music for the video", None, None, "example/V2M_sample-2.mp4", None, None, 0, 10, 7.0, 100, 0, "3753837734", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
            ex12.click(lambda: ["Generate music for the video", None, None, "example/V2M_sample-3.mp4", None, None, 0, 10, 7.0, 100, 0, "3510832996", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
    return demo
def create_txt2audio_ui(model_config_map):
    with gr.Blocks(css=".gradio-container { max-width: 1120px; margin: auto; }") as ui:
        with gr.Tab("Generation"):
            create_sampling_ui(model_config_map)
    return ui
def create_ui(model_config_path=None, ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, model_half=False):
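    """Resolve the runtime device, register the model configuration, and build
    the top-level Gradio UI."""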
    global model_configurations
    global device
    # Prefer MPS on Apple Silicon, then CUDA, then fall back to CPU.
    try:
        has_mps = platform.system() == "Darwin" and torch.backends.mps.is_available()
    except Exception:
        has_mps = False
    if has_mps:
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print("Using device:", device)
    # Wire the caller's paths into the registry, falling back to the bundled model.
    model_configurations = {
        "default": {
            "model_config": model_config_path or "./model/config.json",
            "ckpt_path": ckpt_path or "./model/model.ckpt",
            "pretrained_name": pretrained_name,
            "pretransform_ckpt_path": pretransform_ckpt_path
        }
    }
    ui = create_txt2audio_ui(model_configurations)
    return ui
if __name__ == "__main__": | |
ui = create_ui( | |
model_config_path='./model/config.json', | |
share=True | |
) | |
ui.launch() | |