# Copyright 2024 LY Corporation
# LY Corporation licenses this file to you under the Apache License,
# version 2.0 (the "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at:
# https://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
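"""Synthesize reference- and prompt-conditioned waveforms with PromptTTS++.

For each utterance in the label file, this script decodes one waveform from
the ground-truth reference mel-spectrogram and one from a natural-language
style (and optionally speaker) prompt, vocodes both, and saves the results
under ``<output_dir>/<spk>/{ref,prompt}/wav``.
"""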

from pathlib import Path

import hydra
import pandas as pd
import torch
import torchaudio
from hydra.utils import instantiate
from omegaconf import OmegaConf
from promptttspp.utils import seed_everything
from promptttspp.utils.model import remove_weight_norm_
from promptttspp.vocoders import F0AwareBigVGAN
from scipy import signal
from tqdm import tqdm


def lowpass_filter(x, fs=100, cutoff=20, N=5):
    """Zero-phase Butterworth lowpass filter.

    Args:
        x (array or torch.Tensor): input signal
        fs (int): sampling rate
        cutoff (int): cutoff frequency
        N (int): filter order

    Returns:
        array or torch.Tensor: filtered signal
    """
    nyquist = fs // 2
    norm_cutoff = cutoff / nyquist
    Wn = [norm_cutoff]
    x_len = x.shape[-1]

    b, a = signal.butter(N, Wn, "lowpass")
    if x_len <= max(len(a), len(b)) * (N // 2 + 1):
        # NOTE: input signal is too short to be filtered stably
        return x

    # NOTE: use zero-phase filter to avoid introducing a phase delay
    if isinstance(x, torch.Tensor):
        from torchaudio.functional import filtfilt

        a = torch.from_numpy(a).float().to(x.device)
        b = torch.from_numpy(b).float().to(x.device)
        y = filtfilt(x, a, b, clamp=False)
    else:
        y = signal.filtfilt(b, a, x)
    return y
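
# Example usage (illustrative shapes/values): smooth a log-F0 contour sampled
# at a 10 ms frame shift (100 Hz) with a 20 Hz cutoff, as done in `main` below:
#   cf0 = torch.randn(1, 200)  # ~2 s of frames
#   smoothed = lowpass_filter(cf0, fs=100, cutoff=20)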


def read_prompt_candidate(filepath):
    """Load style-prompt candidates.

    Each line of the file has the form ``style_key|prompt1; prompt2; ...``.
    Returns a dict mapping ``style_key`` to a list of lowercased, stripped
    prompts.
    """
    df_style_prompt = pd.read_csv(
        filepath, header=None, sep="|", names=["style_key", "prompt"]
    )
    style_prompt_dict = {}
    for _, row in df_style_prompt.iterrows():
        style_key, style_prompt = row.iloc[0], row.iloc[1]
        assert isinstance(style_prompt, str)
        style_prompt_dict[style_key] = [
            s.lower().strip() for s in style_prompt.split(";")
        ]
    return style_prompt_dict
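
# Example candidate line (hypothetical values):
#   "high_pitch|Her voice is high; She speaks with a high pitch"
# parses to {"high_pitch": ["her voice is high", "she speaks with a high pitch"]}.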


def read_spk_prompt_candidate(filepath):
    """Load speaker-prompt candidates.

    Each line of the file has the form ``spk|word1,word2,...``.
    Returns a dict mapping speaker ID to a list of descriptive words.
    """
    df = pd.read_csv(filepath, sep="|", header=None, names=["spk", "words"])
    df["words"] = df["words"].map(lambda x: x.split(","))
    # dict(key: spk_id, value: words)
    spk_prompt_cand_dict = df.set_index("spk")["words"].to_dict()
    return spk_prompt_cand_dict
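
# Example speaker-prompt line (hypothetical values):
#   "spk001|bright,youthful,clear"
# parses to {"spk001": ["bright", "youthful", "clear"]}.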


def add_spk_prompt(style_prompt, words):
    """Append a speaker-identity sentence to a style prompt."""
    spk_prompt = f"The speaker identity can be described as {words}."
    prompt = f"{style_prompt}. {spk_prompt}"
    return prompt
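
# Example (hypothetical arguments):
#   add_spk_prompt("He speaks slowly", "deep, masculine")
#   -> "He speaks slowly. The speaker identity can be described as deep, masculine."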


# NOTE: the Hydra config_path/config_name below are assumptions; point them at
# this repository's actual config location.
@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(cfg):
    data_root = Path(cfg.path.data_root)
    output_dir = Path(cfg.output_dir)
    seed_everything(cfg.train.seed)

    prompt_candidate = read_prompt_candidate(cfg.path.prompt_candidate_file)
    spk_prompt_candidate = read_spk_prompt_candidate(
        cfg.path.spk_prompt_candidate_file
    )
    mel_stats = OmegaConf.load(f"{cfg.path.mel_dir}/stats.yaml")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = instantiate(cfg.model)
    model.load_state_dict(torch.load(cfg.ckpt_path, map_location="cpu")["model"])
    model = model.to(device).eval()
    model.apply(remove_weight_norm_)

    to_mel = instantiate(cfg.transforms).to(device).eval()

    vocoder = instantiate(cfg.vocoder)
    vocoder.load_state_dict(
        torch.load(cfg.vocoder_ckpt_path, map_location="cpu")["generator"]
    )
    vocoder = vocoder.to(device).eval()
    vocoder.apply(remove_weight_norm_)
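
    # Columns read from the label file. `seq` is a space-separated sequence of
    # integer token IDs and `style_prompt_key` indexes into the prompt
    # candidate table loaded above.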
    use_col = [
        "spk_id",
        "item_name",
        "gender",
        "pitch",
        "speaking_speed",
        "energy",
        "style_prompt",
        "style_prompt_key",
        "seq",
    ]
    df = pd.read_csv(cfg.label_file, usecols=use_col)
    data = df[use_col].values.tolist()
    for row in tqdm(data, total=len(data)):
        # row follows the order of `use_col`
        spk = row[0]
        utt_id = row[1]
        seq = row[-1]
        style_prompt_key = row[-2]
        style_prompt = prompt_candidate[style_prompt_key][0]
        if cfg.use_spk_prompt and spk in spk_prompt_candidate:
            words = ", ".join(spk_prompt_candidate[spk])
            prompt = add_spk_prompt(style_prompt, words)
        else:
            prompt = style_prompt
        spk_dir = output_dir / str(spk)
        ref_dir = spk_dir / "ref"
        ref_mel_dir = ref_dir / "mel"
        ref_plot_dir = ref_dir / "plot"
        ref_wav_dir = ref_dir / "wav"
        prompt_dir = spk_dir / "prompt"
        prompt_mel_dir = prompt_dir / "mel"
        prompt_plot_dir = prompt_dir / "plot"
        prompt_wav_dir = prompt_dir / "wav"
        dirs = [
            ref_mel_dir,
            ref_plot_dir,
            ref_wav_dir,
            prompt_mel_dir,
            prompt_plot_dir,
            prompt_wav_dir,
        ]
        for d in dirs:
            d.mkdir(parents=True, exist_ok=True)
        label = torch.LongTensor([int(s) for s in seq.split()])[None, :]
        label = label.to(device)

        wav, _ = torchaudio.load(data_root / f"{spk}/wav24k/{utt_id}.wav")
        wav = wav.to(device)
        mel = to_mel(wav)
        # normalize with dataset-level statistics
        mel = (mel - mel_stats["mean"]) / mel_stats["std"]
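
        # Reference-based synthesis: condition the model on the ground-truth
        # mel-spectrogram. If the vocoder is F0-aware, the model also returns
        # a log continuous-F0 contour and a voiced/unvoiced flag.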
        is_f0_aware_vocoder = isinstance(vocoder, F0AwareBigVGAN)
        with torch.no_grad():
            if is_f0_aware_vocoder:
                dec, log_cf0, vuv = model.infer(
                    label, reference_mel=mel, return_f0=True
                )
                # NOTE: hard code for 10 ms frame shift (i.e., modfs = 100 Hz)
                modfs = int(1.0 / (10 * 0.001))
                log_cf0 = lowpass_filter(log_cf0, modfs, cutoff=20)
                f0 = log_cf0.exp()
                f0[vuv < 0.5] = 0
                dec = dec * mel_stats["std"] + mel_stats["mean"]
                o_ref = vocoder(dec, f0).squeeze(1).cpu()
            else:
                dec = model.infer(label, reference_mel=mel)
                dec = dec * mel_stats["std"] + mel_stats["mean"]
                o_ref = vocoder(dec).squeeze(1).cpu()
        torchaudio.save(ref_wav_dir / f"{utt_id}.wav", o_ref, to_mel.sample_rate)
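
        # Prompt-based synthesis: same decoding path, but conditioned on the
        # natural-language style prompt instead of the reference mel.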
        with torch.no_grad():
            style_prompt = [prompt]
            if is_f0_aware_vocoder:
                dec, log_cf0, vuv = model.infer(
                    label, style_prompt=style_prompt, return_f0=True
                )
                # NOTE: hard code for 10 ms frame shift (i.e., modfs = 100 Hz)
                modfs = int(1.0 / (10 * 0.001))
                log_cf0 = lowpass_filter(log_cf0, modfs, cutoff=20)
                f0 = log_cf0.exp()
                f0[vuv < 0.5] = 0
                dec = dec * mel_stats["std"] + mel_stats["mean"]
                o_prompt = vocoder(dec, f0).squeeze(1).cpu()
            else:
                dec = model.infer(label, style_prompt=style_prompt)
                dec = dec * mel_stats["std"] + mel_stats["mean"]
                o_prompt = vocoder(dec).squeeze(1).cpu()
        torchaudio.save(prompt_wav_dir / f"{utt_id}.wav", o_prompt, to_mel.sample_rate)
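
    # Write a sentinel file so downstream steps can tell the run completed.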
    with open(output_dir / "finish", "w") as f:
        f.write("finish")


if __name__ == "__main__":
    main()