File size: 2,944 Bytes
3324de2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# Prediction interface for Cog ⚙️
# https://cog.run/python

import os
import subprocess
import time
import json
from cog import BasePredictor, Input, Path
from diffusers import AutoencoderOobleck
import soundfile as sf
from safetensors.torch import load_file
from huggingface_hub import snapshot_download
from tangoflux.model import TangoFlux
from tangoflux import TangoFluxInference

MODEL_CACHE = "model_cache"
MODEL_URL = (
    "https://weights.replicate.delivery/default/declare-lab/TangoFlux/model_cache.tar"
)


class CachedTangoFluxInference(TangoFluxInference):
    """TangoFluxInference variant that can load its weights from a local directory.

    The weights are fetched from replicate.delivery ahead of time for faster
    booting, so this constructor reads them from ``cached_paths`` when one is
    given and only falls back to ``snapshot_download`` otherwise.
    """

    def __init__(self, name="declare-lab/TangoFlux", device="cuda", cached_paths=None):
        """Build the VAE and the TangoFlux model and move both to ``device``.

        Args:
            name: Hugging Face repo id, used only when ``cached_paths`` is falsy.
            device: Target passed to ``.to()`` for both the VAE and the model.
            cached_paths: Directory holding ``vae.safetensors``,
                ``tangoflux.safetensors`` and ``config.json``. When falsy, the
                directory returned by ``snapshot_download(repo_id=name)`` is
                used instead.
        """
        # NOTE: the parent __init__ is intentionally not invoked here; this
        # constructor sets up ``self.vae`` and ``self.model`` itself.
        weights_dir = cached_paths or snapshot_download(repo_id=name)

        # Autoencoder: default-constructed, then populated from the checkpoint.
        autoencoder = AutoencoderOobleck()
        self.vae = autoencoder
        autoencoder.load_state_dict(load_file(f"{weights_dir}/vae.safetensors"))

        flux_state = load_file(f"{weights_dir}/tangoflux.safetensors")

        # The model is constructed from the JSON config stored next to the weights.
        with open(f"{weights_dir}/config.json", "r") as config_file:
            model_config = json.load(config_file)

        flux_model = TangoFlux(model_config)
        self.model = flux_model
        # strict=False is carried over unchanged; presumably some checkpoint
        # keys are expected not to line up with the module - TODO confirm.
        flux_model.load_state_dict(flux_state, strict=False)

        self.vae.to(device)
        self.model.to(device)


def download_weights(url, dest):
    """Download ``url`` to ``dest`` by shelling out to ``pget``.

    Prints the source URL, the destination and, once ``pget`` returns, the
    elapsed wall-clock time in seconds.

    Args:
        url: Location of the archive to fetch.
        dest: Path handed to ``pget`` as the download target.

    Raises:
        subprocess.CalledProcessError: If ``pget`` exits with a non-zero status.
    """
    began_at = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)

    # "-x" presumably asks pget to extract the archive into ``dest`` - confirm
    # against the pget CLI help.
    pget_command = ["pget", "-x", url, dest]
    subprocess.check_call(pget_command, close_fds=False)

    elapsed = time.time() - began_at
    print("downloading took: ", elapsed)


class Predictor(BasePredictor):
    """Cog predictor that turns a text prompt into a WAV file with TangoFlux."""

    def setup(self) -> None:
        """Fetch the weights if they are not on disk yet, then load the model once."""
        cache_present = os.path.exists(MODEL_CACHE)
        if not cache_present:
            print("downloading")
            download_weights(MODEL_URL, MODEL_CACHE)

        # The model files are read from this sub-directory of the cache.
        weights_dir = f"{MODEL_CACHE}/declare-lab/TangoFlux"
        self.model = CachedTangoFluxInference(cached_paths=weights_dir)

    def predict(
        self,
        prompt: str = Input(
            description="Input prompt", default="Hammer slowly hitting the wooden table"
        ),
        duration: int = Input(
            description="Duration of the output audio in seconds", default=10
        ),
        steps: int = Input(
            description="Number of inference steps", ge=1, le=200, default=25
        ),
        guidance_scale: float = Input(
            description="Scale for classifier-free guidance", ge=1, le=20, default=4.5
        ),
    ) -> Path:
        """Generate audio for ``prompt`` and return the path of the written WAV file.

        The result is always written to ``/tmp/out.wav``, so each call
        overwrites the file produced by the previous one.
        """
        generation_options = {
            "steps": steps,
            "guidance_scale": guidance_scale,
            "duration": duration,
        }
        waveform = self.model.generate(prompt, **generation_options)

        # The array is transposed before writing; presumably generate() yields
        # (channels, samples) while soundfile wants (samples, channels) - confirm.
        samples = waveform.numpy()
        wav_path = "/tmp/out.wav"
        sample_rate = self.model.vae.config.sampling_rate
        sf.write(wav_path, samples.T, samplerate=sample_rate)

        return Path(wav_path)