|
import argparse
|
|
import json
|
|
import shutil
|
|
from pathlib import Path
|
|
|
|
import yaml
|
|
from huggingface_hub import hf_hub_download
|
|
|
|
from style_bert_vits2.logging import logger
|
|
|
|
|
|
def download_bert_models():
    """Fetch every BERT model file listed in ``bert/bert_models.json``.

    The manifest maps a model name to a dict with a ``repo_id`` and a list of
    ``files``. Each listed file that is not already present under
    ``bert/<model name>/`` is downloaded from that Hugging Face repository.
    """
    manifest_text = Path("bert/bert_models.json").read_text(encoding="utf-8")
    manifest = json.loads(manifest_text)

    for model_name, spec in manifest.items():
        target_dir = Path("bert") / model_name
        for filename in spec["files"]:
            if (target_dir / filename).exists():
                # Already downloaded on a previous run.
                continue
            logger.info(f"Downloading {model_name} {filename}")
            hf_hub_download(spec["repo_id"], filename, local_dir=target_dir)
|
|
|
|
|
|
def download_slm_model():
    """Download the WavLM weights into ``slm/wavlm-base-plus/`` if absent."""
    target_dir = Path("slm/wavlm-base-plus/")
    weights_name = "pytorch_model.bin"

    if (target_dir / weights_name).exists():
        # Nothing to do: the weights file is already in place.
        return

    logger.info(f"Downloading wavlm-base-plus {weights_name}")
    hf_hub_download("microsoft/wavlm-base-plus", weights_name, local_dir=target_dir)
|
|
|
|
|
|
def download_pretrained_models():
    """Download the base pretrained checkpoints into ``pretrained/``.

    Only files that do not already exist locally are fetched.
    """
    target_dir = Path("pretrained")
    checkpoint_names = ["G_0.safetensors", "D_0.safetensors", "DUR_0.safetensors"]

    for checkpoint in checkpoint_names:
        if (target_dir / checkpoint).exists():
            continue
        logger.info(f"Downloading pretrained {checkpoint}")
        hf_hub_download(
            "litagin/Style-Bert-VITS2-1.0-base", checkpoint, local_dir=target_dir
        )
|
|
|
|
|
|
def download_jp_extra_pretrained_models():
    """Download the JP-Extra pretrained checkpoints into ``pretrained_jp_extra/``.

    Only files that do not already exist locally are fetched.
    """
    target_dir = Path("pretrained_jp_extra")
    checkpoint_names = ["G_0.safetensors", "D_0.safetensors", "WD_0.safetensors"]

    for checkpoint in checkpoint_names:
        if (target_dir / checkpoint).exists():
            continue
        logger.info(f"Downloading JP-Extra pretrained {checkpoint}")
        hf_hub_download(
            "litagin/Style-Bert-VITS2-2.0-base-JP-Extra",
            checkpoint,
            local_dir=target_dir,
        )
|
|
|
|
|
|
def download_default_models():
    """Download the bundled default voice models into ``model_assets/``.

    Each repository below is visited in order, and every listed file that is
    not already present under ``model_assets/`` is downloaded from it.
    """
    # Repository id -> files to fetch. Dict order fixes the download order:
    # the JVNV voices first, then the additional single-speaker models.
    repo_files = {
        "litagin/style_bert_vits2_jvnv": [
            "jvnv-F1-jp/config.json",
            "jvnv-F1-jp/jvnv-F1-jp_e160_s14000.safetensors",
            "jvnv-F1-jp/style_vectors.npy",
            "jvnv-F2-jp/config.json",
            "jvnv-F2-jp/jvnv-F2_e166_s20000.safetensors",
            "jvnv-F2-jp/style_vectors.npy",
            "jvnv-M1-jp/config.json",
            "jvnv-M1-jp/jvnv-M1-jp_e158_s14000.safetensors",
            "jvnv-M1-jp/style_vectors.npy",
            "jvnv-M2-jp/config.json",
            "jvnv-M2-jp/jvnv-M2-jp_e159_s17000.safetensors",
            "jvnv-M2-jp/style_vectors.npy",
        ],
        "litagin/sbv2_koharune_ami": [
            "koharune-ami/config.json",
            "koharune-ami/style_vectors.npy",
            "koharune-ami/koharune-ami.safetensors",
        ],
        "litagin/sbv2_amitaro": [
            "amitaro/config.json",
            "amitaro/style_vectors.npy",
            "amitaro/amitaro.safetensors",
        ],
    }

    assets_dir = Path("model_assets")
    for repo, asset_names in repo_files.items():
        for asset in asset_names:
            if (assets_dir / asset).exists():
                continue
            logger.info(f"Downloading {asset}")
            hf_hub_download(repo, asset, local_dir="model_assets")
|
|
|
|
|
|
def main():
    """Download required model files and set up ``configs/paths.yml``.

    Command-line flags:
        --skip_default_models: do not download the default voice models.
        --only_infer: skip the SLM and pretrained checkpoints.
        --dataset_root: value to store under ``dataset_root`` in paths.yml.
        --assets_root: value to store under ``assets_root`` in paths.yml.

    The BERT models are always downloaded. ``configs/paths.yml`` is created
    from ``configs/default_paths.yml`` when it does not exist, and is only
    rewritten when at least one of the two root options is given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--skip_default_models", action="store_true")
    parser.add_argument("--only_infer", action="store_true")
    parser.add_argument(
        "--dataset_root",
        type=str,
        help="Dataset root path (default: Data)",
        default=None,
    )
    parser.add_argument(
        "--assets_root",
        type=str,
        help="Assets root path (default: model_assets)",
        default=None,
    )
    args = parser.parse_args()

    download_bert_models()

    if not args.skip_default_models:
        download_default_models()
    if not args.only_infer:
        download_slm_model()
        download_pretrained_models()
        download_jp_extra_pretrained_models()

    default_paths_yml = Path("configs/default_paths.yml")
    paths_yml = Path("configs/paths.yml")
    if not paths_yml.exists():
        shutil.copy(default_paths_yml, paths_yml)

    if args.dataset_root is None and args.assets_root is None:
        # No overrides requested: leave paths.yml exactly as it is.
        return

    with open(paths_yml, encoding="utf-8") as f:
        # safe_load returns None for an empty (or comment-only) file; fall
        # back to an empty mapping so the assignments below do not fail.
        yml_data = yaml.safe_load(f) or {}
    if args.assets_root is not None:
        yml_data["assets_root"] = args.assets_root
    if args.dataset_root is not None:
        yml_data["dataset_root"] = args.dataset_root
    with open(paths_yml, "w", encoding="utf-8") as f:
        yaml.dump(yml_data, f, allow_unicode=True)
|
|
|
|
|
|
# Run the setup only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|