Spaces:
Paused
Paused
| import os | |
| import shutil | |
| import requests | |
| from tqdm import tqdm | |
| from pathlib import Path | |
| import hashlib | |
| import json | |
| import time | |
class ModelManager:
    """Download and install the DittoTalkingHead model checkpoints.

    Every item is first downloaded (resumably) into ``cache_dir`` and then
    copied or extracted into its ``dest_dir``. The ``dest_dir`` paths are
    relative, so they resolve against the current working directory.
    """

    def __init__(self, cache_dir="/tmp/models", use_pytorch=False):
        """Prepare the cache directory and the list of checkpoints to fetch.

        Args:
            cache_dir: Directory holding cached downloads and the
                ``download_progress.json`` bookkeeping file (created if missing).
            use_pytorch: If True, fetch the PyTorch ``.pth`` checkpoints;
                otherwise fetch the TensorRT archive, whose URL can be
                overridden with the ``DITTO_TRT_URL`` environment variable.
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.use_pytorch = use_pytorch
        # Models are fetched from the official Hugging Face repository.
        base_url = "https://huggingface.co/digital-avatar/ditto-talkinghead/resolve/main"
        if use_pytorch:
            # PyTorch model configuration: each checkpoint lives in the same
            # repository folder and is installed under the same relative path.
            pytorch_dir = "checkpoints/ditto_pytorch/models"
            self.model_configs = [
                {
                    "name": filename,
                    "url": f"{base_url}/{pytorch_dir}/{filename}",
                    "dest_dir": pytorch_dir,
                    "dest_file": filename,
                    "type": "file",
                }
                for filename in (
                    "appearance_extractor.pth",
                    "decoder.pth",
                    "lmdm_v0.4_hubert.pth",
                    "motion_extractor.pth",
                    "stitch_network.pth",
                    "warp_network.pth",
                )
            ]
            self.model_configs.append(
                {
                    "name": "v0.4_hubert_cfg.pkl",
                    "url": f"{base_url}/checkpoints/ditto_cfg/v0.4_hubert_cfg.pkl",
                    "dest_dir": "checkpoints/ditto_cfg",
                    "dest_file": "v0.4_hubert_cfg.pkl",
                    "type": "file",
                }
            )
        else:
            # TensorRT model configuration.
            self.model_configs = [
                {
                    "name": "ditto_trt_models",
                    "url": os.environ.get("DITTO_TRT_URL", f"{base_url}/checkpoints/ditto_trt_Ampere_Plus.tar.gz"),
                    "dest_dir": "checkpoints",
                    "type": "archive",
                    "extract_subdir": "ditto_trt_Ampere_Plus",
                },
                {
                    "name": "v0.4_hubert_cfg_trt.pkl",
                    "url": f"{base_url}/checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl",
                    "dest_dir": "checkpoints/ditto_cfg",
                    "dest_file": "v0.4_hubert_cfg_trt.pkl",
                    "type": "file",
                },
            ]
        self.progress_file = self.cache_dir / "download_progress.json"
        self.download_progress = self.load_progress()

    def load_progress(self):
        """Load the download-progress bookkeeping.

        Returns:
            The stored dict, or ``{}`` when the file is missing, unreadable,
            corrupt (e.g. truncated by an interrupted write) or not a JSON
            object. Losing progress only means files get re-checked/resumed,
            so it must never prevent construction.
        """
        try:
            with open(self.progress_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except FileNotFoundError:
            return {}
        except (OSError, ValueError) as e:  # JSONDecodeError is a ValueError
            print(f"Warning: ignoring unreadable progress file {self.progress_file}: {e}")
            return {}
        return data if isinstance(data, dict) else {}

    def save_progress(self):
        """Persist the download-progress bookkeeping atomically.

        The JSON is written to a sibling temp file and then moved over the
        real file with ``os.replace``, so a crash mid-write cannot leave a
        truncated progress file behind.
        """
        tmp_path = self.progress_file.with_name(self.progress_file.name + '.tmp')
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(self.download_progress, f)
        os.replace(tmp_path, self.progress_file)

    def get_file_hash(self, filepath):
        """Return the hex SHA-256 digest of the file at ``filepath``."""
        sha256_hash = hashlib.sha256()
        with open(filepath, "rb") as f:
            # 1 MiB reads: checkpoints are large, and tiny reads only add overhead.
            for byte_block in iter(lambda: f.read(1 << 20), b""):
                sha256_hash.update(byte_block)
        return sha256_hash.hexdigest()

    def download_file(self, url, dest_path, retries=3):
        """Download ``url`` to ``dest_path``, resuming a partial file if present.

        Each attempt re-reads the current size of ``dest_path`` and requests
        only the remaining bytes, so a retry after a mid-stream failure
        continues where the previous attempt stopped instead of re-appending
        data that is already on disk.

        Args:
            url: Source URL.
            dest_path: Destination file; parent directories are created.
            retries: Number of attempts (5 s pause between attempts).

        Returns:
            True on success; False only when ``retries`` is less than 1.

        Raises:
            requests.RequestException or OSError: from the final failed attempt.
        """
        dest_path = Path(dest_path)
        dest_path.parent.mkdir(parents=True, exist_ok=True)
        for attempt in range(retries):
            try:
                self._download_once(url, dest_path)
                return True
            except (requests.RequestException, OSError) as e:
                print(f"ダウンロードエラー (試行 {attempt + 1}/{retries}): {e}")
                if attempt < retries - 1:
                    time.sleep(5)  # wait before retrying
                else:
                    raise
        return False

    def _download_once(self, url, dest_path):
        """Run a single (possibly resumed) download attempt into ``dest_path``."""
        resume_pos = dest_path.stat().st_size if dest_path.exists() else 0
        headers = {'Range': f'bytes={resume_pos}-'} if resume_pos else {}
        response = requests.get(url, headers=headers, stream=True, timeout=30)
        if resume_pos and response.status_code == 416:
            # Range Not Satisfiable: the local file is not a resumable prefix of
            # the remote file (already complete, or larger/corrupt); start over.
            response.close()
            resume_pos = 0
            response = requests.get(url, stream=True, timeout=30)
        with response:
            response.raise_for_status()
            if resume_pos and response.status_code != 206:
                # The server ignored the Range header and is sending the whole
                # file; appending it to the partial file would corrupt it.
                resume_pos = 0
            mode = 'ab' if resume_pos else 'wb'
            # For a 206 response Content-Length counts only the remaining bytes.
            total_size = int(response.headers.get('content-length', 0)) + resume_pos
            with open(dest_path, mode) as f, tqdm(
                total=total_size or None,
                initial=resume_pos,
                unit='B',
                unit_scale=True,
                desc=dest_path.name,
            ) as pbar:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
                        pbar.update(len(chunk))

    def extract_archive(self, archive_path, dest_dir, extract_subdir=None):
        """Extract a tar (plain/gz/bz2/xz) or zip archive into ``dest_dir``.

        The format is detected from the file contents rather than its name,
        because downloads are cached under a ``.download`` suffix.

        Args:
            archive_path: Archive file to extract.
            dest_dir: Destination directory (created if missing).
            extract_subdir: If given, only this top-level directory of the
                archive is installed, as ``dest_dir / extract_subdir``; an
                existing directory of that name is replaced, and the rest of
                the archive is discarded.

        Raises:
            ValueError: the file is neither a tar nor a zip archive.
            FileNotFoundError: ``extract_subdir`` is not in the archive.
        """
        import shutil as _shutil  # module-level name, aliased only for clarity
        import tempfile

        archive_path = Path(archive_path)
        dest_dir = Path(dest_dir)
        dest_dir.mkdir(parents=True, exist_ok=True)

        if not extract_subdir:
            self._extract_all(archive_path, dest_dir)
            return

        # Extract into a fresh temp dir next to the destination (same
        # filesystem, so the final move is a cheap rename).
        temp_dir = Path(tempfile.mkdtemp(prefix="temp_extract_", dir=dest_dir))
        try:
            self._extract_all(archive_path, temp_dir)
            src_dir = temp_dir / extract_subdir
            if not src_dir.is_dir():
                raise FileNotFoundError(
                    f"'{extract_subdir}' not found in archive {archive_path}"
                )
            final_dir = dest_dir / extract_subdir
            # shutil.move would nest src_dir *inside* an existing directory,
            # so clear any stale/partial install first.
            if final_dir.is_dir():
                _shutil.rmtree(final_dir)
            elif final_dir.exists():
                final_dir.unlink()
            _shutil.move(str(src_dir), str(final_dir))
        finally:
            _shutil.rmtree(temp_dir, ignore_errors=True)

    def _extract_all(self, archive_path, target_dir):
        """Extract every member of a tar or zip archive into ``target_dir``."""
        import tarfile
        import zipfile

        if tarfile.is_tarfile(str(archive_path)):
            with tarfile.open(str(archive_path), 'r:*') as tar:
                # The 'data' filter rejects absolute paths, '..' traversal and
                # unsafe links; use it whenever this Python provides it.
                if hasattr(tarfile, 'data_filter'):
                    tar.extractall(target_dir, filter='data')
                else:
                    tar.extractall(target_dir)
        elif zipfile.is_zipfile(str(archive_path)):
            with zipfile.ZipFile(archive_path, 'r') as zip_ref:
                zip_ref.extractall(target_dir)
        else:
            raise ValueError(f"Unsupported archive format: {archive_path.suffix}")

    def check_models_exist(self):
        """Return the configs whose installed files are missing.

        Paths are resolved relative to the current working directory. A
        ``file`` item is present when its ``dest_file`` exists. An
        ``archive`` item is present when its install directory
        (``dest_dir / extract_subdir``, or ``dest_dir`` alone when no subdir
        is configured) is a non-empty directory. Checking the subdirectory
        matters because ``dest_dir`` is shared with other items.
        """
        missing_models = []
        for config in self.model_configs:
            dest_dir = Path(config['dest_dir'])
            if config['type'] == 'file':
                if not (dest_dir / config['dest_file']).exists():
                    missing_models.append(config)
            else:  # archive
                subdir = config.get('extract_subdir')
                install_dir = dest_dir / subdir if subdir else dest_dir
                if not install_dir.is_dir() or not any(install_dir.iterdir()):
                    missing_models.append(config)
        return missing_models

    def download_models(self):
        """Download and install every checkpoint that is currently missing.

        Each item is downloaded into the cache as ``<name>.download`` (resuming
        partial downloads), then copied (``file``) or extracted (``archive``)
        into its destination.

        Returns:
            True if every missing item was installed; False on the first
            failure (the error is printed, not raised).
        """
        missing_models = self.check_models_exist()
        if not missing_models:
            print("すべてのモデルが既に存在します。")
            return True
        print(f"{len(missing_models)}個のモデルをダウンロードします...")
        for config in missing_models:
            name = config['name']
            size_info = config.get('size', '不明')
            print(f"\n{name} をダウンロード中... (サイズ: {size_info})")
            cache_path = self.cache_dir / f"{name}.download"
            try:
                status = self.download_progress.get(name, {}).get('status')
                if not cache_path.exists() or status != 'completed':
                    # Drop any stale 'completed' flag before (re)downloading: if
                    # this run is interrupted, the partial cache file must not be
                    # treated as complete next time.
                    self.download_progress[name] = {'status': 'downloading'}
                    self.save_progress()
                    self.download_file(config['url'], cache_path)
                    self.download_progress[name] = {'status': 'completed'}
                    self.save_progress()
                dest_dir = Path(config['dest_dir'])
                dest_dir.mkdir(parents=True, exist_ok=True)
                if config['type'] == 'file':
                    shutil.copy2(cache_path, dest_dir / config['dest_file'])
                else:  # archive
                    print(f"{name} を展開中...")
                    self.extract_archive(cache_path, dest_dir, config.get('extract_subdir'))
                print(f"{name} のセットアップ完了")
            except Exception as e:  # top-level boundary: report and stop
                print(f"エラー: {name} のダウンロード中にエラーが発生しました: {e}")
                return False
        return True

    def setup_models(self):
        """Entry point: print a header, install missing models, report the result.

        Returns:
            The boolean result of ``download_models()``.
        """
        print("=== DittoTalkingHead モデルセットアップ ===")
        print(f"キャッシュディレクトリ: {self.cache_dir}")
        success = self.download_models()
        if success:
            print("\n✅ すべてのモデルのセットアップが完了しました!")
        else:
            print("\n❌ モデルのセットアップ中にエラーが発生しました。")
        return success
if __name__ == "__main__":
    # Manual smoke run: install the models into the default cache directory.
    ModelManager().setup_models()