Spaces:
Sleeping
Sleeping
File size: 3,074 Bytes
9580089 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import os
import torch
from huggingface_hub import HfApi, upload_file
from pathlib import Path
import shutil
import json
def prepare_model_for_upload(
    checkpoint_path: str,
    output_dir: str,
    model_name: str = "voice-cloning-model",
    organization: str = None,
    config: dict = None,
):
    """Prepare model files for uploading to the Hugging Face Hub.

    Writes three files into ``output_dir`` (created if missing):

    * ``pytorch_model.bin`` -- the model state dict extracted from the checkpoint;
    * ``config.json`` -- model hyper-parameters (defaults below, overridable);
    * ``README.md`` -- a copy of ``model_card.md`` located next to this script,
      or a minimal generated card when that file does not exist.

    Args:
        checkpoint_path: Path to a file saved with ``torch.save``. It may be a
            checkpoint dict containing a ``"model_state_dict"`` entry, or a bare
            state dict (mapping of parameter names to tensors).
        output_dir: Directory to write the upload files into.
        model_name: Model name; used only in the generated fallback model card.
        organization: Optional organization; used only in the fallback card.
        config: Optional key/value overrides merged over the default config.

    Returns:
        ``output_dir`` as a :class:`pathlib.Path`.

    Raises:
        KeyError: If the checkpoint is neither a dict with a
            ``"model_state_dict"`` entry nor a non-empty dict of tensors.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        # Training checkpoint: keep only the model weights.
        state_dict = checkpoint["model_state_dict"]
    elif (
        isinstance(checkpoint, dict)
        and checkpoint
        and all(isinstance(v, torch.Tensor) for v in checkpoint.values())
    ):
        # Already a bare state dict (e.g. torch.save(model.state_dict(), ...)).
        state_dict = checkpoint
    else:
        raise KeyError(
            f"Checkpoint {checkpoint_path!r} has no 'model_state_dict' entry "
            "and is not a bare state dict of tensors"
        )
    torch.save(state_dict, output_dir / "pytorch_model.bin")

    # Default hyper-parameters; callers may override any of them via `config`.
    model_config = {
        "model_type": "speaker_encoder",
        "hidden_dim": 256,
        "embedding_dim": 512,
        "num_layers": 3,
        "dropout": 0.1,
        "version": "1.0.0",
    }
    if config:
        model_config.update(config)
    with open(output_dir / "config.json", "w", encoding="utf-8") as f:
        json.dump(model_config, f, indent=2)

    # Use the hand-written model card when shipped alongside this script;
    # otherwise write a minimal README so the upload still has a card.
    card_src = Path(__file__).parent / "model_card.md"
    readme_path = output_dir / "README.md"
    if card_src.is_file():
        shutil.copy(card_src, readme_path)
    else:
        repo_id = f"{organization}/{model_name}" if organization else model_name
        print(f"{card_src} not found; writing a minimal README.md")
        readme_path.write_text(f"# {repo_id}\n", encoding="utf-8")

    return output_dir
def upload_to_hub(
    model_dir: str,
    model_name: str,
    organization: str = None,
    token: str = None,
):
    """Upload the files in ``model_dir`` to a Hugging Face Hub model repository.

    The repo ``"{organization}/{model_name}"`` (or just ``model_name`` when no
    organization is given) is created if it does not already exist. Each
    top-level regular file in ``model_dir`` is uploaded under its own file
    name, in sorted order; subdirectories are skipped.

    Args:
        model_dir: Local directory containing the files to upload.
        model_name: Repository name on the Hub.
        organization: Optional organization/namespace owning the repository.
        token: Hugging Face access token passed to every Hub call.

    Returns:
        The repo id the files were uploaded to.

    Raises:
        FileNotFoundError: If ``model_dir`` is not an existing directory.
    """
    model_dir = Path(model_dir)
    # Validate before touching the Hub so a bad path doesn't leave behind an
    # empty repository.
    if not model_dir.is_dir():
        raise FileNotFoundError(f"Model directory not found: {model_dir}")

    repo_id = f"{organization}/{model_name}" if organization else model_name

    api = HfApi()
    api.create_repo(
        repo_id=repo_id,
        exist_ok=True,
        token=token,
    )

    # upload_file expects a path to a single file, so directories are skipped;
    # sorting gives a deterministic upload order.
    for file_path in sorted(p for p in model_dir.iterdir() if p.is_file()):
        upload_file(
            path_or_fileobj=str(file_path),
            path_in_repo=file_path.name,
            repo_id=repo_id,
            token=token,
        )
        print(f"Uploaded {file_path.name}")

    return repo_id
if __name__ == "__main__":
    import argparse
    import tempfile

    parser = argparse.ArgumentParser(description="Upload model to Hugging Face Hub")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to model checkpoint")
    parser.add_argument("--model_name", type=str, required=True,
                        help="Name for the model on HuggingFace Hub")
    parser.add_argument("--organization", type=str,
                        help="Optional organization name")
    parser.add_argument("--token", type=str,
                        help="HuggingFace token (or set via HUGGING_FACE_TOKEN or HF_TOKEN env var)")
    args = parser.parse_args()

    # Resolve the token first so a missing token fails fast, before any files
    # are written to disk.
    token = (
        args.token
        or os.environ.get("HUGGING_FACE_TOKEN")
        or os.environ.get("HF_TOKEN")
    )
    if not token:
        parser.error("Please provide a HuggingFace token")

    # Stage the files in a private temporary directory: this avoids uploading
    # stale files left in a fixed ./tmp_model from an earlier run, and the
    # directory is removed even if preparation or the upload raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        model_dir = prepare_model_for_upload(
            args.checkpoint,
            Path(tmp_dir) / "model",
            args.model_name,
            args.organization,
        )
        upload_to_hub(
            model_dir,
            args.model_name,
            args.organization,
            token,
        )