|
import random |
|
|
|
import numpy as np |
|
import torch |
|
|
|
from TTS.vocoder.configs import WavernnConfig |
|
from TTS.vocoder.models.wavernn import Wavernn, WavernnArgs |
|
|
|
|
|
def test_wavernn(): |
|
config = WavernnConfig() |
|
config.model_args = WavernnArgs( |
|
rnn_dims=512, |
|
fc_dims=512, |
|
mode="mold", |
|
mulaw=False, |
|
pad=2, |
|
use_aux_net=True, |
|
use_upsample_net=True, |
|
upsample_factors=[4, 8, 8], |
|
feat_dims=80, |
|
compute_dims=128, |
|
res_out_dims=128, |
|
num_res_blocks=10, |
|
) |
|
config.audio.hop_length = 256 |
|
config.audio.sample_rate = 2048 |
|
|
|
dummy_x = torch.rand((2, 1280)) |
|
dummy_m = torch.rand((2, 80, 9)) |
|
y_size = random.randrange(20, 60) |
|
dummy_y = torch.rand((80, y_size)) |
|
|
|
|
|
model = Wavernn(config) |
|
output = model(dummy_x, dummy_m) |
|
assert np.all(output.shape == (2, 1280, 30)), output.shape |
|
|
|
|
|
config.model_args.mode = "gauss" |
|
model = Wavernn(config) |
|
output = model(dummy_x, dummy_m) |
|
assert np.all(output.shape == (2, 1280, 2)), output.shape |
|
|
|
|
|
config.model_args.mode = 4 |
|
model = Wavernn(config) |
|
output = model(dummy_x, dummy_m) |
|
assert np.all(output.shape == (2, 1280, 2**4)), output.shape |
|
output = model.inference(dummy_y, True, 5500, 550) |
|
assert np.all(output.shape == (256 * (y_size - 1),)) |
|
|