style-bert-vits2-fastapi / inference.py
import json
import os

import numpy as np
import soundfile as sf
import torch
from fastapi import FastAPI
from huggingface_hub import hf_hub_download

from src.sbv2 import utils
from src.sbv2.synthesizer_trn import SynthesizerTrn
from src.sbv2.text import text_to_sequence
# Hub repo and access token are provided via environment variables; downloads are cached in /tmp.
MODEL_REPO = os.getenv("MODEL_REPO")
HF_TOKEN = os.getenv("HF_TOKEN")
CACHE_DIR = "/tmp/models"
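
# Example environment for running the app (placeholder values, not real identifiers or credentials):
#   export MODEL_REPO="your-username/your-sbv2-model"
#   export HF_TOKEN="hf_xxxxxxxxxxxxxxxx"
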
app = FastAPI()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_model():
    global model, hps

    # Download config.json, model.safetensors, and style_vectors.npy from the Hub
    config_path = hf_hub_download(repo_id=MODEL_REPO, filename="config.json", token=HF_TOKEN, cache_dir=CACHE_DIR)
    model_path = hf_hub_download(repo_id=MODEL_REPO, filename="model.safetensors", token=HF_TOKEN, cache_dir=CACHE_DIR)
    style_path = hf_hub_download(repo_id=MODEL_REPO, filename="style_vectors.npy", token=HF_TOKEN, cache_dir=CACHE_DIR)

    # Load the config
    with open(config_path, "r", encoding="utf-8") as f:
        hps = json.load(f)

    n_vocab = 77          # number of symbols for the Koharune Ami model
    segment_size = 8192   # usually a fixed value, as recommended for Style-BERT-VITS2

    model = SynthesizerTrn(
        n_vocab,
        hps["model"]["p_dropout"],
        segment_size // 2,
        hps["model"]["inter_channels"],
        hps["model"]["out_channels"],
        hps["model"]["hidden_channels"],
        hps["model"]["filter_channels"],
        hps["model"]["dec_kernel_size"],
        hps["model"]["enc_channels"],
        hps["model"]["enc_out_channels"],
        hps["model"]["enc_kernel_size"],
        hps["model"]["enc_dilation_rate"],
        hps["model"]["enc_n_layers"],
        hps["model"]["flow_hidden_channels"],
        hps["model"]["flow_kernel_size"],
        hps["model"]["flow_n_layers"],
        hps["model"]["flow_n_flows"],
        hps["model"]["sdp_hidden_channels"],
        hps["model"]["sdp_kernel_size"],
        hps["model"]["sdp_n_layers"],
        hps["model"]["sdp_dropout"],
        hps["audio"]["sampling_rate"],
        hps["audio"]["filter_length"],
        hps["audio"]["hop_length"],
        hps["audio"]["win_length"],
        hps["model"]["resblock"],
        hps["model"]["resblock_kernel_sizes"],
        hps["model"]["resblock_dilation_sizes"],
        hps["model"]["upsample_rates"],
        hps["model"]["upsample_initial_channel"],
        hps["model"]["upsample_kernel_sizes"],
        hps["model"].get("gin_channels", 0),
    ).to(device)

    # Load the safetensors weights into the model
    utils.load_checkpoint(model_path, model, strict=True)
    model.eval()
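
# A minimal sketch that loads the model eagerly at import time, so the /voice
# endpoint can rely on the shared globals (assumes a single worker process).
load_model()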

@app.get("/voice")
def synthesize(text: str):
    # Convert the text to a phoneme ID sequence
    sequence = np.array(text_to_sequence(text, hps["data"]["text_cleaners"]), dtype=np.int64)
    sequence = torch.LongTensor(sequence).unsqueeze(0).to(device)

    # Inference
    with torch.no_grad():
        audio = model.infer(sequence, noise_scale=0.667, noise_scale_w=0.8, length_scale=1.0)[0][0, 0].data.cpu().numpy()

    # Write the result to a temporary WAV file and return its path
    output_path = "/tmp/output.wav"
    sf.write(output_path, audio, hps["audio"]["sampling_rate"])
    return {"audio_path": output_path}
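
# A minimal way to run the server locally (assumes uvicorn is installed; the
# port is an arbitrary example, 7860 being the Hugging Face Spaces default):
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example request (hypothetical host/port):
#   curl "http://localhost:7860/voice?text=hello"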