|
import tempfile
import uuid
from pathlib import Path

import gradio as gr
import torchaudio
from audiocraft.data.audio import audio_write
from audiocraft.models import AudioGen
|
|
|
# Load the pretrained AudioGen checkpoint once at import time so that every
# request handled by the UI reuses the same model instance.
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
|
|
def infer(prompt):
    """Generate an audio clip for a text prompt and return the WAV file path.

    Args:
        prompt: Text description passed to the model as the only item of
            the generation batch.

    Returns:
        The path (as a string) of the WAV file written by ``audio_write``.
    """
    # duration=5 is kept from the original script (seconds, per AudioCraft's
    # generation parameters).
    model.set_generation_params(duration=5)
    wav = model.generate([prompt])

    # Use a unique stem in the system temp directory: writing every result to
    # a fixed "0.wav" in the working directory lets overlapping requests
    # overwrite each other's output.
    stem = Path(tempfile.gettempdir()) / f"audiogen_{uuid.uuid4().hex}"

    # Only one description is submitted, so the batch holds a single clip.
    # audio_write appends the format suffix and returns the final path.
    out_path = audio_write(
        stem,
        wav[0].cpu(),
        model.sample_rate,
        strategy="loudness",
        loudness_compressor=True,
    )
    return str(out_path)
|
|
|
# Build the web UI: a text box feeds `infer`, and the file path it returns is
# rendered by the audio component. The interface is bound to a name so it can
# be referenced after construction; it is still launched unconditionally, as
# in the original script.
demo = gr.Interface(
    fn=infer,
    inputs=gr.Textbox(),
    outputs=gr.Audio(),
)
demo.launch()