import gradio as gr
import torch
from diffusers import AudioLDMPipeline
from transformers import AutoProcessor, ClapModel

# Run everything on CPU in full precision.
device = "cpu"
torch_dtype = torch.float32
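
# A hedged, optional variant (not in the original script): half-precision GPU
# inference, assuming a CUDA device is available. Uncomment to enable.
# if torch.cuda.is_available():
#     device = "cuda"
#     torch_dtype = torch.float16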

# Load the AudioLDM text-to-audio pipeline.
repo_id = "cvssp/audioldm-m-full"
pipe = AudioLDMPipeline.from_pretrained(repo_id, torch_dtype=torch_dtype).to(device)
# Compile the UNet; the first generation pays a one-off compilation cost.
pipe.unet = torch.compile(pipe.unet)
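
# Optional, hedged memory saver (a standard diffusers call, not used in the
# original script): attention slicing lowers peak memory at some speed cost.
# pipe.enable_attention_slicing()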

# CLAP model and processor, used to re-rank candidate waveforms against the prompt.
clap_model = ClapModel.from_pretrained("sanchit-gandhi/clap-htsat-unfused-m-full").to(device)
processor = AutoProcessor.from_pretrained("sanchit-gandhi/clap-htsat-unfused-m-full")

# A single generator, re-seeded on every request for reproducible sampling.
generator = torch.Generator(device)


def text2audio(text, negative_prompt, duration, guidance_scale, random_seed, n_candidates):
    if not text:
        raise gr.Error("Please provide a text input")

    # Generate one or more candidate waveforms for the prompt.
    waveforms = pipe(
        text,
        audio_length_in_s=duration,
        guidance_scale=guidance_scale,
        negative_prompt=negative_prompt,
        num_waveforms_per_prompt=n_candidates if n_candidates else 1,
        generator=generator.manual_seed(int(random_seed)),
    )["audios"]

    # With multiple candidates, keep the one CLAP scores as best matching the prompt.
    if waveforms.shape[0] > 1:
        waveform = score_waveforms(text, waveforms)
    else:
        waveform = waveforms[0]

    # AudioLDM outputs 16 kHz audio; render it as a waveform video.
    return gr.make_waveform((16000, waveform))


def score_waveforms(text, waveforms):
    # Embed the prompt and all candidate waveforms with CLAP.
    inputs = processor(text=text, audios=list(waveforms), return_tensors="pt", padding=True)
    inputs = {key: inputs[key].to(device) for key in inputs}
    with torch.no_grad():
        logits_per_text = clap_model(**inputs).logits_per_text
    # Softmax over the candidates, then keep the waveform CLAP rates closest to the text.
    probs = logits_per_text.softmax(dim=-1)
    most_probable = torch.argmax(probs)
    waveform = waveforms[most_probable]
    return waveform
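
# A hedged smoke test of the pipeline outside the UI; the prompt and parameter
# values below are illustrative only. Uncomment to run (CPU inference is slow).
# video_path = text2audio("birds singing in a forest", None, 5.0, 2.5, 45, 3)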


# Build the Gradio UI.
iface = gr.Blocks()

with iface:
    with gr.Group():
        with gr.Box():
            textbox = gr.Textbox(
                max_lines=1,
                label="Prompt",
                info="Describe the audio you want to generate",
                elem_id="prompt-in",
            )
            negative_textbox = gr.Textbox(
                max_lines=1,
                label="Negative prompt",
                info="Describe what the audio should not contain",
                elem_id="negative-prompt-in",
            )

        with gr.Accordion("More options", open=False):
            seed = gr.Number(
                value=45,
                label="Seed",
                info="Different seeds give different results; the same seed reproduces the same result",
            )
            duration = gr.Slider(2.5, 10, value=5, step=2.5, label="Duration (seconds)")
            guidance_scale = gr.Slider(
                0,
                4,
                value=2.5,
                step=0.5,
                label="Guidance scale",
                info="Larger values give better quality and closer adherence to the text; smaller values give more diversity",
            )
            n_candidates = gr.Slider(
                1,
                3,
                value=3,
                step=1,
                label="Number of candidates",
                info="How many waveforms to generate; the best CLAP-scored one is returned",
            )

    outputs = gr.Video(label="Output", elem_id="output-video")
    btn = gr.Button("Submit").style(full_width=True)

    btn.click(
        text2audio,
        inputs=[textbox, negative_textbox, duration, guidance_scale, seed, n_candidates],
        outputs=[outputs],
    )

iface.queue(max_size=10).launch(debug=True)