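"""Inference wrapper for Tango (declare-lab/tango): downloads the VAE, STFT,
and latent-diffusion checkpoints from the Hugging Face Hub and exposes
text-to-audio generation for single prompts and batches of prompts."""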
import json
import torch
from tqdm import tqdm
from huggingface_hub import snapshot_download
from models import AudioDiffusion, DDPMScheduler, DDIMScheduler
from audioldm.audio.stft import TacotronSTFT
from audioldm.variational_autoencoder import AutoencoderKL


class Tango:
    def __init__(self, name="declare-lab/tango", device="cuda:0"):
        # Download (or reuse a cached copy of) the checkpoint from the Hub.
        path = snapshot_download(repo_id=name)

        with open(f"{path}/vae_config.json") as f:
            vae_config = json.load(f)
        with open(f"{path}/stft_config.json") as f:
            stft_config = json.load(f)
        with open(f"{path}/main_config.json") as f:
            main_config = json.load(f)

        # Instantiate the VAE, STFT module, and latent diffusion model on the device.
        self.vae = AutoencoderKL(**vae_config).to(device)
        self.stft = TacotronSTFT(**stft_config).to(device)
        self.model = AudioDiffusion(**main_config).to(device)

        vae_weights = torch.load(
            "{}/pytorch_model_vae.bin".format(path), map_location=device
        )
        stft_weights = torch.load(
            "{}/pytorch_model_stft.bin".format(path), map_location=device
        )
        main_weights = torch.load(
            "{}/pytorch_model_main.bin".format(path), map_location=device
        )

        self.vae.load_state_dict(vae_weights)
        self.stft.load_state_dict(stft_weights)
        self.model.load_state_dict(main_weights)

        print("Successfully loaded checkpoint from:", name)

        # Inference only: switch all modules to eval mode.
        self.vae.eval()
        self.stft.eval()
        self.model.eval()

        # DDIM sampling by default; the DDPM scheduler can be swapped in below.
        # self.scheduler = DDPMScheduler.from_pretrained(
        #     main_config["scheduler_name"], subfolder="scheduler"
        # )
        self.scheduler = DDIMScheduler.from_pretrained(
            main_config["scheduler_name"], subfolder="scheduler"
        )

    def chunks(self, lst, n):
        """Yield successive n-sized chunks from a list."""
        for i in range(0, len(lst), n):
            yield lst[i : i + n]

    def generate(self, prompt, steps=100, guidance=3, samples=1, disable_progress=True):
        """Genrate audio for a single prompt string."""
        with torch.no_grad():
            latents = self.model.inference(
                [prompt],
                self.scheduler,
                steps,
                guidance,
                samples,
                disable_progress=disable_progress,
            )
            # Decode latents to a mel spectrogram, then vocode to a waveform.
            mel = self.vae.decode_first_stage(latents)
            wave = self.vae.decode_to_waveform(mel)
        return wave[0]

    def generate_for_batch(
        self,
        prompts,
        steps=100,
        guidance=3,
        samples=1,
        batch_size=8,
        disable_progress=True,
    ):
        """Genrate audio for a list of prompt strings."""
        outputs = []
        for k in tqdm(range(0, len(prompts), batch_size)):
            batch = prompts[k : k + batch_size]
            with torch.no_grad():
                latents = self.model.inference(
                    batch,
                    self.scheduler,
                    steps,
                    guidance,
                    samples,
                    disable_progress=disable_progress,
                )
                mel = self.vae.decode_first_stage(latents)
                wave = self.vae.decode_to_waveform(mel)
                outputs.extend(wave)
        if samples == 1:
            return outputs
        else:
            # With samples > 1, group the flat list into per-prompt chunks.
            return list(self.chunks(outputs, samples))
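

# Example usage: a minimal sketch, not part of the original module. It assumes
# a CUDA-capable GPU and the `soundfile` package; writing at 16 kHz matches the
# sample rate of the waveforms Tango's vocoder produces.
if __name__ == "__main__":
    import soundfile as sf

    tango = Tango("declare-lab/tango")
    wave = tango.generate("An audience cheering and clapping")
    sf.write("output.wav", wave, samplerate=16000)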