# AudioX / app.py
import spaces
import gradio as gr
import torch
import torchaudio
import os
from einops import rearrange
import gc
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond
from stable_audio_tools.data.utils import read_video, merge_video_audio, load_and_process_audio
from aeiou.viz import audio_spectrogram_image  # used for sampling previews (assumed available, as in stable-audio-tools)
import stat
import platform
import logging
from transformers import logging as transformers_logging
transformers_logging.set_verbosity_error()
logging.getLogger("transformers").setLevel(logging.ERROR)
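
# Load the pretrained AudioX model and its sampling configuration.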
model, model_config = get_pretrained_model('HKUSTAudio/AudioX')
sample_rate = model_config["sample_rate"]
sample_size = model_config["sample_size"]
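
# World-writable temp directories for Gradio uploads and extracted video clips.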
TEMP_DIR = "tmp/gradio"
os.makedirs(TEMP_DIR, exist_ok=True)
os.chmod(TEMP_DIR, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
VIDEO_TEMP_DIR = os.path.join(TEMP_DIR, "videos")
os.makedirs(VIDEO_TEMP_DIR, exist_ok=True)
os.chmod(VIDEO_TEMP_DIR, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
@spaces.GPU(duration=10)
def generate_cond(
    prompt,
    negative_prompt=None,
    video_file=None,
    audio_prompt_file=None,
    audio_prompt_path=None,
    seconds_start=0,
    seconds_total=10,
    cfg_scale=7.0,
    steps=100,
    preview_every=0,
    seed=-1,
    sampler_type="dpmpp-3m-sde",
    sigma_min=0.03,
    sigma_max=500,
    cfg_rescale=0.0,
    use_init=False,
    init_audio=None,
    init_noise_level=0.1,
    mask_cropfrom=None,
    mask_pastefrom=None,
    mask_pasteto=None,
    mask_maskstart=None,
    mask_maskend=None,
    mask_softnessL=None,
    mask_softnessR=None,
    mask_marination=None,
    batch_size=1
):
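    """Generate audio conditioned on text, video, and/or an audio prompt; returns (video_path, audio_path)."""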
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()

    print(f"Prompt: {prompt}")

    preview_images = []
    if preview_every == 0:
        preview_every = None
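
    # Pick the best available device: Apple MPS, then CUDA, then CPU.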
    try:
        has_mps = platform.system() == "Darwin" and torch.backends.mps.is_available()
    except Exception:
        has_mps = False

    if has_mps:
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    global model
    model = model.to(device)

    target_fps = model_config.get("video_fps", 5)
    model_type = model_config.get("model_type", "diffusion_cond")
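
    # Resolve the uploaded video / audio prompt objects to concrete file paths.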
    if video_file is not None:
        actual_video_path = video_file['name'] if isinstance(video_file, dict) else video_file.name
    else:
        actual_video_path = None

    if audio_prompt_file is not None:
        audio_path = audio_prompt_file.name
    elif audio_prompt_path:
        audio_path = audio_prompt_path.strip()
    else:
        audio_path = None
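
    # Read video frames at the model's target FPS and load the optional audio prompt.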
    Video_tensors = read_video(actual_video_path, seek_time=seconds_start, duration=seconds_total, target_fps=target_fps)
    audio_tensor = load_and_process_audio(audio_path, sample_rate, seconds_start, seconds_total)
    audio_tensor = audio_tensor.to(device)

    seconds_input = sample_size / sample_rate

    if not prompt:
        prompt = ""
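
    # Assemble the multi-modal conditioning (video, text, audio, timing) expected by the model.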
    conditioning = [{
        "video_prompt": [Video_tensors.unsqueeze(0)],
        "text_prompt": prompt,
        "audio_prompt": audio_tensor.unsqueeze(0),
        "seconds_start": seconds_start,
        "seconds_total": seconds_input
    }]

    if negative_prompt:
        negative_conditioning = [{
            "video_prompt": [Video_tensors.unsqueeze(0)],
            "text_prompt": negative_prompt,
            "audio_prompt": audio_tensor.unsqueeze(0),
            "seconds_start": seconds_start,
            "seconds_total": seconds_total
        }]
    else:
        negative_conditioning = None

    seed = int(seed)

    if not use_init:
        init_audio = None

    input_sample_size = sample_size
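
    # Optional callback: decode intermediate latents into spectrogram previews every `preview_every` steps.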
    def progress_callback(callback_info):
        nonlocal preview_images
        denoised = callback_info["denoised"]
        current_step = callback_info["i"]
        sigma = callback_info["sigma"]

        if (current_step - 1) % preview_every == 0:
            if model.pretransform is not None:
                denoised = model.pretransform.decode(denoised)
            denoised = rearrange(denoised, "b d n -> d (b n)")
            denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
            audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate)
            preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f}"))
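
    # Run conditional diffusion sampling.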
    if model_type == "diffusion_cond":
        audio = generate_diffusion_cond(
            model,
            conditioning=conditioning,
            negative_conditioning=negative_conditioning,
            steps=steps,
            cfg_scale=cfg_scale,
            batch_size=batch_size,
            sample_size=input_sample_size,
            sample_rate=sample_rate,
            seed=seed,
            device=device,
            sampler_type=sampler_type,
            sigma_min=sigma_min,
            sigma_max=sigma_max,
            init_audio=init_audio,
            init_noise_level=init_noise_level,
            mask_args=None,
            callback=progress_callback if preview_every is not None else None,
            scale_phi=cfg_rescale
        )
    else:
        raise ValueError(f"Unsupported model type: {model_type}")
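
    # Fold the batch into channels, trim to 10 seconds, peak-normalize, and convert to 16-bit PCM.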
    audio = rearrange(audio, "b d n -> d (b n)")

    samples_10s = 10 * sample_rate
    audio = audio[:, :samples_10s]

    audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()

    output_dir = "demo_result"
    os.makedirs(output_dir, exist_ok=True)
    output_audio_path = f"{output_dir}/output.wav"
    torchaudio.save(output_audio_path, audio, sample_rate)
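
    # If a video was provided, mux the generated audio back onto the source clip.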
    if actual_video_path:
        output_video_path = f"{output_dir}/{os.path.basename(actual_video_path)}"
        merge_video_audio(
            actual_video_path,
            output_audio_path,
            output_video_path,
            seconds_start,
            seconds_total
        )
    else:
        output_video_path = None

    del actual_video_path
    torch.cuda.empty_cache()
    gc.collect()

    return output_video_path, output_audio_path
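

# Gradio UI: prompt inputs, sampler controls, outputs, and example presets.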
with gr.Blocks() as interface:
    gr.Markdown(
        """
        # 🎧AudioX: Diffusion Transformer for Anything-to-Audio Generation
        **[Paper](https://arxiv.org/abs/2503.10522) · [Project Page](https://zeyuet.github.io/AudioX/) · [Huggingface](https://huggingface.co/HKUSTAudio/AudioX) · [GitHub](https://github.com/ZeyueT/AudioX)**
        """
    )
    with gr.Tab("Generation"):
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(
                    show_label=False,
                    placeholder="Enter your prompt"
                )
                negative_prompt = gr.Textbox(
                    show_label=False,
                    placeholder="Negative prompt",
                    visible=False
                )
                video_file = gr.File(label="Upload Video File")
                audio_prompt_file = gr.File(
                    label="Upload Audio Prompt File",
                    visible=False
                )
                audio_prompt_path = gr.Textbox(
                    label="Audio Prompt Path",
                    placeholder="Enter audio file path",
                    visible=False
                )
        with gr.Row():
            with gr.Column(scale=6):
                with gr.Accordion("Video Params", open=False):
                    seconds_start = gr.Slider(
                        minimum=0,
                        maximum=512,
                        step=1,
                        value=0,
                        label="Video Seconds Start"
                    )
                    seconds_total = gr.Slider(
                        minimum=0,
                        maximum=10,
                        step=1,
                        value=10,
                        label="Seconds Total",
                        interactive=False
                    )
        with gr.Row():
            with gr.Column(scale=4):
                with gr.Accordion("Sampler Params", open=False):
                    steps = gr.Slider(
                        minimum=1,
                        maximum=500,
                        step=1,
                        value=100,
                        label="Steps"
                    )
                    preview_every = gr.Slider(
                        minimum=0,
                        maximum=100,
                        step=1,
                        value=0,
                        label="Preview Every"
                    )
                    cfg_scale = gr.Slider(
                        minimum=0.0,
                        maximum=25.0,
                        step=0.1,
                        value=7.0,
                        label="CFG Scale"
                    )
                    seed = gr.Textbox(
                        label="Seed (set to -1 for random seed)",
                        value="-1"
                    )
                    sampler_type = gr.Dropdown(
                        choices=[
                            "dpmpp-2m-sde",
                            "dpmpp-3m-sde",
                            "k-heun",
                            "k-lms",
                            "k-dpmpp-2s-ancestral",
                            "k-dpm-2",
                            "k-dpm-fast"
                        ],
                        label="Sampler Type",
                        value="dpmpp-3m-sde"
                    )
                    sigma_min = gr.Slider(
                        minimum=0.0,
                        maximum=2.0,
                        step=0.01,
                        value=0.03,
                        label="Sigma Min"
                    )
                    sigma_max = gr.Slider(
                        minimum=0.0,
                        maximum=1000.0,
                        step=0.1,
                        value=500,
                        label="Sigma Max"
                    )
                    cfg_rescale = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        step=0.01,
                        value=0.0,
                        label="CFG Rescale Amount"
                    )
        with gr.Row():
            with gr.Column(scale=4):
                with gr.Accordion("Init Audio", open=False, visible=False):
                    init_audio_checkbox = gr.Checkbox(label="Use Init Audio")
                    init_audio_input = gr.Audio(label="Init Audio")
                    init_noise_level = gr.Slider(
                        minimum=0.1,
                        maximum=100.0,
                        step=0.01,
                        value=0.1,
                        label="Init Noise Level"
                    )

        with gr.Row():
            generate_button = gr.Button("Generate", variant="primary")

        with gr.Row():
            with gr.Column(scale=6):
                video_output = gr.Video(label="Output Video", interactive=False)
                audio_output = gr.Audio(label="Output Audio", interactive=False)
        inputs = [
            prompt,
            negative_prompt,
            video_file,
            audio_prompt_file,
            audio_prompt_path,
            seconds_start,
            seconds_total,
            cfg_scale,
            steps,
            preview_every,
            seed,
            sampler_type,
            sigma_min,
            sigma_max,
            cfg_rescale,
            init_audio_checkbox,
            init_audio_input,
            init_noise_level
        ]

        generate_button.click(
            fn=generate_cond,
            inputs=inputs,
            outputs=[video_output, audio_output]
        )
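
        # Example presets: each button repopulates every field in `inputs` above.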
gr.Markdown("## Examples")
with gr.Accordion("Click to show examples", open=False):
with gr.Row():
gr.Markdown("**πŸ“ Task: Text-to-Audio**")
with gr.Column(scale=1.2):
gr.Markdown("Prompt: *Typing on a keyboard*")
ex1 = gr.Button("Load Example")
with gr.Column(scale=1.2):
gr.Markdown("Prompt: *Ocean waves crashing*")
ex2 = gr.Button("Load Example")
with gr.Column(scale=1.2):
gr.Markdown("Prompt: *Footsteps in snow*")
ex3 = gr.Button("Load Example")
            with gr.Row():
                gr.Markdown("**🎶 Task: Text-to-Music**")
                with gr.Column(scale=1.2):
                    gr.Markdown("Prompt: *An orchestral music piece for a fantasy world.*")
                    ex4 = gr.Button("Load Example")
                with gr.Column(scale=1.2):
                    gr.Markdown("Prompt: *Produce upbeat electronic music for a dance party*")
                    ex5 = gr.Button("Load Example")
                with gr.Column(scale=1.2):
                    gr.Markdown("Prompt: *A dreamy lo-fi beat with vinyl crackle*")
                    ex6 = gr.Button("Load Example")
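
        # Each example returns one value per component in `inputs` (prompt, files, timing, sampler settings, fixed seed).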
        ex1.click(lambda: ["Typing on a keyboard", None, None, None, None, 0, 10, 7.0, 100, 0, "1225575558", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
        ex2.click(lambda: ["Ocean waves crashing", None, None, None, None, 0, 10, 7.0, 100, 0, "3615819170", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
        ex3.click(lambda: ["Footsteps in snow", None, None, None, None, 0, 10, 7.0, 100, 0, "1703896811", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
        ex4.click(lambda: ["An orchestral music piece for a fantasy world.", None, None, None, None, 0, 10, 7.0, 100, 0, "1561898939", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
        ex5.click(lambda: ["Produce upbeat electronic music for a dance party", None, None, None, None, 0, 10, 7.0, 100, 0, "406022999", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
        ex6.click(lambda: ["A dreamy lo-fi beat with vinyl crackle", None, None, None, None, 0, 10, 7.0, 100, 0, "807934770", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
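
# Queue requests and launch the demo server.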
interface.queue(5).launch(server_name="0.0.0.0", server_port=7860, share=True)