Spaces:
Build error
Build error
import json
import os
import shutil
from pathlib import Path
from typing import Optional

import torch
from huggingface_hub import HfApi, upload_file
def prepare_model_for_upload(
    checkpoint_path: str,
    output_dir: str,
    model_name: str = "voice-cloning-model",
    organization: Optional[str] = None,
    model_card_path: Optional[str] = None,
) -> Path:
    """Stage a checkpoint's weights, config and model card for a Hub upload.

    Writes three files into ``output_dir`` (created, with parents, if needed):
    ``pytorch_model.bin`` (the checkpoint's ``"model_state_dict"`` entry),
    ``config.json`` (a fixed speaker-encoder configuration) and ``README.md``.

    Args:
        checkpoint_path: Path to a ``torch.save``'d dict that contains a
            ``"model_state_dict"`` entry.
        output_dir: Staging directory for the files listed above.
        model_name: Repository name; used only in a generated fallback README.
        organization: Optional namespace; used only in a generated fallback
            README.
        model_card_path: Model card to copy to ``README.md``. When omitted,
            ``model_card.md`` next to this file is used, and if that file does
            not exist a minimal README is generated instead of failing.

    Returns:
        ``output_dir`` as a :class:`pathlib.Path`.

    Raises:
        KeyError: If the checkpoint is not a dict with ``"model_state_dict"``.
        FileNotFoundError: If an explicit ``model_card_path`` does not exist.
    """
    # NOTE: torch.load unpickles the file -- only pass checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    # Validate before creating anything so a bad checkpoint leaves no debris.
    if not isinstance(checkpoint, dict) or "model_state_dict" not in checkpoint:
        found = list(checkpoint) if isinstance(checkpoint, dict) else type(checkpoint).__name__
        raise KeyError(
            f"Checkpoint {checkpoint_path} has no 'model_state_dict' entry (found: {found})"
        )

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Save only the model weights, not optimizer/training state.
    torch.save(checkpoint["model_state_dict"], output_dir / "pytorch_model.bin")

    # Fixed architecture description written alongside the weights.
    config = {
        "model_type": "speaker_encoder",
        "hidden_dim": 256,
        "embedding_dim": 512,
        "num_layers": 3,
        "dropout": 0.1,
        "version": "1.0.0",
    }
    with open(output_dir / "config.json", "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    readme_path = output_dir / "README.md"
    if model_card_path is not None:
        # An explicitly requested card must exist; shutil.copy raises otherwise.
        shutil.copy(model_card_path, readme_path)
    else:
        default_card = Path(__file__).parent / "model_card.md"
        if default_card.is_file():
            shutil.copy(default_card, readme_path)
        else:
            # The bundled card is missing: write a minimal README rather than
            # aborting after weights and config have already been written.
            repo_id = f"{organization}/{model_name}" if organization else model_name
            readme_path.write_text(
                f"# {repo_id}\n\n"
                "Speaker encoder weights (`pytorch_model.bin`) and "
                "configuration (`config.json`).\n",
                encoding="utf-8",
            )

    return output_dir
def upload_to_hub(
    model_dir: str,
    model_name: str,
    organization: Optional[str] = None,
    token: Optional[str] = None,
):
    """Upload every regular file in ``model_dir`` to a Hugging Face model repo.

    The repo ``"{organization}/{model_name}"`` (or just ``model_name`` when no
    organization is given) is created if it does not already exist. Each file
    is uploaded under its own name at the repo root, and a line is printed per
    file.

    Args:
        model_dir: Local directory whose top-level files are uploaded.
        model_name: Repository name on the Hub.
        organization: Optional namespace that owns the repository.
        token: Hugging Face access token passed to every Hub call.

    Raises:
        NotADirectoryError: If ``model_dir`` is not an existing directory.
            This is checked before any Hub call, so no empty repo is created.
    """
    model_dir = Path(model_dir)
    if not model_dir.is_dir():
        raise NotADirectoryError(f"Model directory not found: {model_dir}")

    api = HfApi()
    repo_id = f"{organization}/{model_name}" if organization else model_name
    api.create_repo(
        repo_id=repo_id,
        exist_ok=True,
        token=token,
    )

    # glob("*") would also yield subdirectories, which upload_file cannot
    # take; sort so the upload order is deterministic.
    files = sorted(p for p in model_dir.iterdir() if p.is_file())
    for file_path in files:
        # NOTE: upload_file creates one commit per file.
        upload_file(
            path_or_fileobj=str(file_path),
            path_in_repo=file_path.name,
            repo_id=repo_id,
            token=token,
        )
        print(f"Uploaded {file_path.name}")
if __name__ == "__main__":
    import argparse
    import tempfile

    parser = argparse.ArgumentParser(description="Upload model to Hugging Face Hub")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to model checkpoint")
    parser.add_argument("--model_name", type=str, required=True,
                        help="Name for the model on HuggingFace Hub")
    parser.add_argument("--organization", type=str,
                        help="Optional organization name")
    parser.add_argument("--token", type=str,
                        help="HuggingFace token (or set via HUGGING_FACE_TOKEN or HF_TOKEN env var)")
    args = parser.parse_args()

    # Resolve the token before any expensive work so a missing token fails fast.
    # HF_TOKEN is the variable huggingface_hub itself conventionally reads.
    token = (
        args.token
        or os.environ.get("HUGGING_FACE_TOKEN")
        or os.environ.get("HF_TOKEN")
    )
    if not token:
        raise ValueError("Please provide a HuggingFace token")

    # Stage in a fresh temporary directory. The previous fixed "tmp_model" path
    # in the CWD could collide with an existing user directory, which was then
    # rmtree'd. It was also leaked whenever prepare/upload raised. The context
    # manager removes the directory on both success and failure.
    with tempfile.TemporaryDirectory(prefix="tmp_model_") as output_dir:
        model_dir = prepare_model_for_upload(
            args.checkpoint,
            output_dir,
            args.model_name,
            args.organization,
        )
        upload_to_hub(
            model_dir,
            args.model_name,
            args.organization,
            token,
        )