Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,858 Bytes
9a6dac6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md
import os
import subprocess
import time
import json
import torch
from tqdm import tqdm
import soundfile as sf
from models import AudioDiffusion, DDPMScheduler
from audioldm.audio.stft import TacotronSTFT
from audioldm.variational_autoencoder import AutoencoderKL
from cog import BasePredictor, Input, Path
# Remote archive holding the TANGO checkpoints; setup() passes it to
# download_weights() when MODEL_CACHE is missing locally.
MODEL_URL: str = "https://weights.replicate.delivery/default/declare-lab/tango.tar"
# Local directory the weights are downloaded into. Tango reads
# "{MODEL_CACHE}/{name}/*.json" and "*.bin" from here. setup() only checks that
# the directory exists, not that it is complete.
MODEL_CACHE: str = "tango_weights"
def download_weights(url, dest):
    """Fetch ``url`` into ``dest`` using the ``pget`` CLI and report how long it took.

    Args:
        url: Remote location of the weights archive.
        dest: Local destination path handed to ``pget`` (``-x`` is passed;
            presumably this extracts the archive - confirm against pget docs).

    Raises:
        subprocess.CalledProcessError: if ``pget`` exits with a non-zero status.
    """
    t0 = time.time()
    for label, value in (("downloading url: ", url), ("downloading to: ", dest)):
        print(label, value)
    command = ["pget", "-x", url, dest]
    subprocess.check_call(command, close_fds=False)
    elapsed = time.time() - t0
    print("downloading took: ", elapsed)
class Predictor(BasePredictor):
    """Cog predictor serving the "tango2" and "tango2-full" checkpoints."""

    def setup(self):
        """Ensure weights are on disk, then keep one Tango instance per checkpoint in memory."""
        weights_present = os.path.exists(MODEL_CACHE)
        if not weights_present:
            download_weights(MODEL_URL, MODEL_CACHE)
        self.models = {}
        for checkpoint in ("tango2", "tango2-full"):
            self.models[checkpoint] = Tango(name=checkpoint)

    def predict(
        self,
        prompt: str = Input(
            description="Input prompt",
            default="Quiet speech and then and airplane flying away",
        ),
        model: str = Input(
            description="choose a model",
            choices=[
                "tango2",
                "tango2-full",
            ],
            default="tango2",
        ),
        steps: int = Input(description="inference steps", default=100),
        guidance: float = Input(description="guidance scale", default=3),
    ) -> Path:
        """Synthesize audio for ``prompt`` with the chosen checkpoint and return the WAV path.

        The output is always written to the same fixed path, so each call
        overwrites the previous result.
        """
        waveform = self.models[model].generate(prompt, steps, guidance)
        output_path = "/tmp/output.wav"
        sf.write(output_path, waveform, samplerate=16000)
        return Path(output_path)
class Tango:
    """VAE + STFT + latent-diffusion model for one TANGO checkpoint.

    The checkpoint directory ``{path}/{name}`` must contain ``vae_config.json``,
    ``stft_config.json`` and ``main_config.json``, plus the matching
    ``pytorch_model_vae.bin``, ``pytorch_model_stft.bin`` and
    ``pytorch_model_main.bin`` state dicts.
    """

    def __init__(self, name="tango2", path=MODEL_CACHE, device="cuda:0"):
        """Build the three sub-models from their JSON configs and load their weights.

        Args:
            name: Checkpoint sub-directory inside ``path`` (e.g. "tango2").
            path: Root directory holding the downloaded checkpoints.
            device: Torch device the sub-models and their weights are placed on.
        """
        # weights are downloaded from f"https://huggingface.co/declare-lab/{name}/tree/main" and saved to MODEL_CACHE
        root = f"{path}/{name}"
        vae_config = self._load_json(f"{root}/vae_config.json")
        stft_config = self._load_json(f"{root}/stft_config.json")
        main_config = self._load_json(f"{root}/main_config.json")

        self.vae = AutoencoderKL(**vae_config).to(device)
        self.stft = TacotronSTFT(**stft_config).to(device)
        self.model = AudioDiffusion(**main_config).to(device)

        # NOTE(review): torch.load unpickles arbitrary objects; only use it on
        # trusted weight files such as the ones fetched from MODEL_URL.
        vae_weights = torch.load(f"{root}/pytorch_model_vae.bin", map_location=device)
        stft_weights = torch.load(f"{root}/pytorch_model_stft.bin", map_location=device)
        main_weights = torch.load(f"{root}/pytorch_model_main.bin", map_location=device)

        self.vae.load_state_dict(vae_weights)
        self.stft.load_state_dict(stft_weights)
        self.model.load_state_dict(main_weights)

        # Inference only: put every sub-module in eval mode.
        self.vae.eval()
        self.stft.eval()
        self.model.eval()

        # The scheduler is resolved by name from main_config rather than from
        # the local checkpoint directory.
        self.scheduler = DDPMScheduler.from_pretrained(
            main_config["scheduler_name"], subfolder="scheduler"
        )

    @staticmethod
    def _load_json(file_path):
        """Parse a JSON file and close its handle deterministically."""
        with open(file_path, encoding="utf-8") as fh:
            return json.load(fh)

    def chunks(self, lst, n):
        """Yield successive n-sized chunks from a list (the last one may be shorter)."""
        for i in range(0, len(lst), n):
            yield lst[i : i + n]

    def generate(self, prompt, steps=100, guidance=3, samples=1, disable_progress=True):
        """Generate audio for a single prompt string.

        Args:
            prompt: Text description of the audio to synthesize.
            steps: Number of diffusion inference steps.
            guidance: Classifier-free guidance scale.
            samples: Number of samples requested from the model.
            disable_progress: Forwarded to the model's inference loop.

        Returns:
            The first decoded waveform (``wave[0]``). Only the first sample is
            returned even when ``samples > 1``.
        """
        with torch.no_grad():
            latents = self.model.inference(
                [prompt],
                self.scheduler,
                steps,
                guidance,
                samples,
                disable_progress=disable_progress,
            )
            mel = self.vae.decode_first_stage(latents)
            wave = self.vae.decode_to_waveform(mel)
        return wave[0]

    def generate_for_batch(
        self,
        prompts,
        steps=100,
        guidance=3,
        samples=1,
        batch_size=8,
        disable_progress=True,
    ):
        """Generate audio for a list of prompt strings, ``batch_size`` prompts at a time.

        Args:
            prompts: Sliceable sequence of text prompts.
            steps: Number of diffusion inference steps.
            guidance: Classifier-free guidance scale.
            samples: Number of samples requested per prompt.
            batch_size: Number of prompts sent to the model per inference call.
            disable_progress: Forwarded to the model's inner inference loop; the
                outer per-batch tqdm bar is always shown.

        Returns:
            A flat list of waveforms when ``samples == 1``. Otherwise the
            waveforms are grouped into consecutive lists of ``samples`` items.
            Presumably each group holds one prompt's samples, assuming the model
            emits samples grouped per prompt - TODO confirm.
        """
        outputs = []
        for k in tqdm(range(0, len(prompts), batch_size)):
            batch = prompts[k : k + batch_size]
            with torch.no_grad():
                latents = self.model.inference(
                    batch,
                    self.scheduler,
                    steps,
                    guidance,
                    samples,
                    disable_progress=disable_progress,
                )
                mel = self.vae.decode_first_stage(latents)
                wave = self.vae.decode_to_waveform(mel)
                outputs.extend(wave)
        if samples == 1:
            return outputs
        return list(self.chunks(outputs, samples))
|