import os
import json

import torch
import numpy as np
import soundfile as sf
from fastapi import FastAPI
from huggingface_hub import hf_hub_download

from src.sbv2 import utils
from src.sbv2.synthesizer_trn import SynthesizerTrn
from src.sbv2.text import text_to_sequence

MODEL_REPO = os.getenv("MODEL_REPO")
HF_TOKEN = os.getenv("HF_TOKEN")
CACHE_DIR = "/tmp/models"

app = FastAPI()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Populated by load_model() at startup
model = None
hps = None
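# Sketch (not in the original file): both variables come from the Space
# secrets, and a missing secret otherwise only surfaces later as an opaque
# hf_hub_download error, so failing fast here makes the problem explicit.
if not MODEL_REPO:
    raise RuntimeError("MODEL_REPO is not set; add it as a Space secret.")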

def load_model():
    global model, hps

    # Download config.json, model.safetensors, and style_vectors.npy
    config_path = hf_hub_download(repo_id=MODEL_REPO, filename="config.json", token=HF_TOKEN, cache_dir=CACHE_DIR)
    model_path = hf_hub_download(repo_id=MODEL_REPO, filename="model.safetensors", token=HF_TOKEN, cache_dir=CACHE_DIR)
    style_path = hf_hub_download(repo_id=MODEL_REPO, filename="style_vectors.npy", token=HF_TOKEN, cache_dir=CACHE_DIR)

    # Load the config
    with open(config_path, "r", encoding="utf-8") as f:
        hps = json.load(f)

    n_vocab = 77         # number of symbols for the 小春音アミ model
    segment_size = 8192  # usually a fixed value; recommended for Style-BERT-VITS2

    model = SynthesizerTrn(
        n_vocab,
        hps["model"]["p_dropout"],
        segment_size // 2,
        hps["model"]["inter_channels"],
        hps["model"]["out_channels"],
        hps["model"]["hidden_channels"],
        hps["model"]["filter_channels"],
        hps["model"]["dec_kernel_size"],
        hps["model"]["enc_channels"],
        hps["model"]["enc_out_channels"],
        hps["model"]["enc_kernel_size"],
        hps["model"]["enc_dilation_rate"],
        hps["model"]["enc_n_layers"],
        hps["model"]["flow_hidden_channels"],
        hps["model"]["flow_kernel_size"],
        hps["model"]["flow_n_layers"],
        hps["model"]["flow_n_flows"],
        hps["model"]["sdp_hidden_channels"],
        hps["model"]["sdp_kernel_size"],
        hps["model"]["sdp_n_layers"],
        hps["model"]["sdp_dropout"],
        hps["audio"]["sampling_rate"],
        hps["audio"]["filter_length"],
        hps["audio"]["hop_length"],
        hps["audio"]["win_length"],
        hps["model"]["resblock"],
        hps["model"]["resblock_kernel_sizes"],
        hps["model"]["resblock_dilation_sizes"],
        hps["model"]["upsample_rates"],
        hps["model"]["upsample_initial_channel"],
        hps["model"]["upsample_kernel_sizes"],
        hps["model"].get("gin_channels", 0),
    ).to(device)

    # Load the safetensors checkpoint
    utils.load_checkpoint(model_path, model, strict=True)
    model.eval()
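
# Sketch: nothing in this file actually calls load_model(), so as written
# the model would never be loaded. One way to trigger it before the first
# request (newer FastAPI versions prefer the lifespan API over on_event):
@app.on_event("startup")
def startup() -> None:
    load_model()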

def synthesize(text: str):
    # Convert the text to a phoneme/symbol sequence
    sequence = np.array(text_to_sequence(text, hps["data"]["text_cleaners"]), dtype=np.int64)
    sequence = torch.LongTensor(sequence).unsqueeze(0).to(device)

    # Inference
    with torch.no_grad():
        audio = model.infer(sequence, noise_scale=0.667, noise_scale_w=0.8, length_scale=1.0)[0][0, 0].data.cpu().numpy()

    # Save to a temporary WAV file
    output_path = "/tmp/output.wav"
    sf.write(output_path, audio, hps["audio"]["sampling_rate"])
    return {"audio_path": output_path}