Add the `docs/` folder to `.gitignore` and update the ignored files
Files changed:

- .gitignore +2 -0
- README_hf_space.md +50 -0
- app.py +130 -0
- model_manager.py +267 -0
- requirements.txt +46 -0
- setup.sh +18 -0
.gitignore CHANGED

@@ -37,6 +37,8 @@ log/*
 # Folders to ignore
 example/
 ToDo/
+docs/
+
 
 !example/audio.wav
 !example/image.png
README_hf_space.md ADDED

@@ -0,0 +1,50 @@
---
title: DittoTalkingHead
emoji: 🗣️
colorFrom: blue
colorTo: purple
sdk: gradio
sdk_version: 4.19.0
app_file: app.py
pinned: false
license: apache-2.0
hardware: a100-large
---

# DittoTalkingHead - Talking Head Generation

Generates a realistic talking-head video from an audio clip and a source image.

## Features

- High-quality lip sync
- Natural facial expressions and head motion
- Fast inference with TensorRT
- Automatic model download

## Usage

1. Upload an **audio file** (WAV)
2. Upload a **source image** (PNG/JPG)
3. Click the **Generate** button

## Technical specifications

- **GPU**: NVIDIA A100 (recommended)
- **Framework**: PyTorch
- **Model**: DittoTalkingHead (PyTorch build)
- **Model size**: about 2.5 GB

## Notes

- The first run takes extra time (roughly 10-15 minutes) because the models are downloaded automatically
- Running on a GPU (A100) is recommended
- 16 kHz WAV audio is recommended

## Model source

The models are downloaded automatically from [digital-avatar/ditto-talkinghead](https://huggingface.co/digital-avatar/ditto-talkinghead).

## License

Apache License 2.0
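Once the Space is running, the flow described in the README can also be driven programmatically. A minimal sketch using `gradio_client`; the Space id and the `api_name` below are placeholders inferred from the app code in this commit, not confirmed by it:

```python
# Sketch only: call the Space's Gradio API remotely.
# Assumptions: "your-username/DittoTalkingHead" is a placeholder Space id, and
# the endpoint name is guessed from the handler function defined in app.py.
from gradio_client import Client

client = Client("your-username/DittoTalkingHead")
video_path, status = client.predict(
    "example/audio.wav",              # audio file (WAV)
    "example/image.png",              # source image (PNG/JPG)
    api_name="/process_talking_head", # assumed default endpoint name
)
print(status, video_path)
```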
app.py ADDED

@@ -0,0 +1,130 @@
import gradio as gr
import os
import tempfile
import shutil
from pathlib import Path
from model_manager import ModelManager
from stream_pipeline_offline import StreamSDK
from inference import run, seed_everything

# Initialize the models
print("=== Starting model initialization ===")

# Use the PyTorch models (the TensorRT models are very large)
USE_PYTORCH = True

model_manager = ModelManager(cache_dir="/tmp/ditto_models", use_pytorch=USE_PYTORCH)
if not model_manager.setup_models():
    raise RuntimeError("Model setup failed.")

# Initialize the SDK
if USE_PYTORCH:
    data_root = "./checkpoints/ditto_pytorch"
    # Must match the config file that ModelManager places in checkpoints/ditto_cfg
    cfg_pkl = "./checkpoints/ditto_cfg/v0.4_hubert_cfg.pkl"
else:
    data_root = "./checkpoints/ditto_trt_Ampere_Plus"
    cfg_pkl = "./checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl"

try:
    SDK = StreamSDK(cfg_pkl, data_root)
    print("✅ SDK initialized")
except Exception as e:
    print(f"❌ SDK initialization error: {e}")
    raise

def process_talking_head(audio_file, source_image):
    """Generate a talking-head video from an audio clip and a source image."""

    if audio_file is None:
        return None, "Please upload an audio file."

    if source_image is None:
        return None, "Please upload a source image."

    try:
        # Create a temporary output file
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_output:
            output_path = tmp_output.name

        # Run inference
        print(f"Processing started: audio={audio_file}, image={source_image}")
        seed_everything(1024)
        run(SDK, audio_file, source_image, output_path)

        # Check the result
        if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
            return output_path, "✅ Processing finished!"
        else:
            return None, "❌ Processing failed: no output file was generated."

    except Exception as e:
        import traceback
        error_msg = f"❌ An error occurred: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return None, error_msg

# Gradio UI
with gr.Blocks(title="DittoTalkingHead") as demo:
    gr.Markdown("""
    # DittoTalkingHead - Talking Head Generation

    Generates a realistic talking-head video from an audio clip and a source image.

    ## Usage
    1. Upload an **audio file** (WAV)
    2. Upload a **source image** (PNG/JPG)
    3. Click the **Generate** button

    ⚠️ The first run takes extra time while the models (about 2.5 GB) are downloaded.

    ### Technical specifications
    - **Model**: DittoTalkingHead (PyTorch build)
    - **GPU**: NVIDIA A100 recommended
    - **Models provided by**: [digital-avatar/ditto-talkinghead](https://huggingface.co/digital-avatar/ditto-talkinghead)
    """)

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                label="Audio file (WAV)",
                type="filepath"
            )
            image_input = gr.Image(
                label="Source image",
                type="filepath"
            )
            generate_btn = gr.Button("Generate", variant="primary")

        with gr.Column():
            video_output = gr.Video(
                label="Generated video"
            )
            status_output = gr.Textbox(
                label="Status",
                lines=3
            )

    # Examples
    gr.Examples(
        examples=[
            ["example/audio.wav", "example/image.png"]
        ],
        inputs=[audio_input, image_input],
        outputs=[video_output, status_output],
        fn=process_talking_head,
        cache_examples=True
    )

    # Event handlers
    generate_btn.click(
        fn=process_talking_head,
        inputs=[audio_input, image_input],
        outputs=[video_output, status_output]
    )

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )
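For local debugging without the Gradio UI, the same calls that `process_talking_head` makes can be run directly. A minimal sketch, assuming the checkpoints have already been placed by `ModelManager` and the bundled `example/` files and `output/` directory (created by setup.sh) exist:

```python
# Sketch: run the pipeline once, outside Gradio, using the same objects app.py uses.
from model_manager import ModelManager
from stream_pipeline_offline import StreamSDK
from inference import run, seed_everything

# Download the PyTorch checkpoints if they are not present yet
ModelManager(cache_dir="/tmp/ditto_models", use_pytorch=True).setup_models()

sdk = StreamSDK("./checkpoints/ditto_cfg/v0.4_hubert_cfg.pkl", "./checkpoints/ditto_pytorch")
seed_everything(1024)
run(sdk, "example/audio.wav", "example/image.png", "output/result.mp4")
```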
model_manager.py ADDED

@@ -0,0 +1,267 @@
import os
import shutil
import requests
from tqdm import tqdm
from pathlib import Path
import hashlib
import json
import time

class ModelManager:
    def __init__(self, cache_dir="/tmp/models", use_pytorch=False):
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.use_pytorch = use_pytorch

        # Fetch the models from the official Hugging Face repository
        base_url = "https://huggingface.co/digital-avatar/ditto-talkinghead/resolve/main"

        if use_pytorch:
            # PyTorch model configuration
            self.model_configs = [
                {
                    "name": "appearance_extractor.pth",
                    "url": f"{base_url}/checkpoints/ditto_pytorch/models/appearance_extractor.pth",
                    "dest_dir": "checkpoints/ditto_pytorch/models",
                    "dest_file": "appearance_extractor.pth",
                    "type": "file"
                },
                {
                    "name": "decoder.pth",
                    "url": f"{base_url}/checkpoints/ditto_pytorch/models/decoder.pth",
                    "dest_dir": "checkpoints/ditto_pytorch/models",
                    "dest_file": "decoder.pth",
                    "type": "file"
                },
                {
                    "name": "lmdm_v0.4_hubert.pth",
                    "url": f"{base_url}/checkpoints/ditto_pytorch/models/lmdm_v0.4_hubert.pth",
                    "dest_dir": "checkpoints/ditto_pytorch/models",
                    "dest_file": "lmdm_v0.4_hubert.pth",
                    "type": "file"
                },
                {
                    "name": "motion_extractor.pth",
                    "url": f"{base_url}/checkpoints/ditto_pytorch/models/motion_extractor.pth",
                    "dest_dir": "checkpoints/ditto_pytorch/models",
                    "dest_file": "motion_extractor.pth",
                    "type": "file"
                },
                {
                    "name": "stitch_network.pth",
                    "url": f"{base_url}/checkpoints/ditto_pytorch/models/stitch_network.pth",
                    "dest_dir": "checkpoints/ditto_pytorch/models",
                    "dest_file": "stitch_network.pth",
                    "type": "file"
                },
                {
                    "name": "warp_network.pth",
                    "url": f"{base_url}/checkpoints/ditto_pytorch/models/warp_network.pth",
                    "dest_dir": "checkpoints/ditto_pytorch/models",
                    "dest_file": "warp_network.pth",
                    "type": "file"
                },
                {
                    "name": "v0.4_hubert_cfg.pkl",
                    "url": f"{base_url}/checkpoints/ditto_cfg/v0.4_hubert_cfg.pkl",
                    "dest_dir": "checkpoints/ditto_cfg",
                    "dest_file": "v0.4_hubert_cfg.pkl",
                    "type": "file"
                }
            ]
        else:
            # TensorRT model configuration
            self.model_configs = [
                {
                    "name": "ditto_trt_models",
                    "url": os.environ.get("DITTO_TRT_URL", f"{base_url}/checkpoints/ditto_trt_Ampere_Plus.tar.gz"),
                    "dest_dir": "checkpoints",
                    "type": "archive",
                    "extract_subdir": "ditto_trt_Ampere_Plus"
                },
                {
                    "name": "v0.4_hubert_cfg_trt.pkl",
                    "url": f"{base_url}/checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl",
                    "dest_dir": "checkpoints/ditto_cfg",
                    "dest_file": "v0.4_hubert_cfg_trt.pkl",
                    "type": "file"
                }
            ]

        self.progress_file = self.cache_dir / "download_progress.json"
        self.download_progress = self.load_progress()

    def load_progress(self):
        """Load the download progress record."""
        if self.progress_file.exists():
            with open(self.progress_file, 'r') as f:
                return json.load(f)
        return {}

    def save_progress(self):
        """Save the download progress record."""
        with open(self.progress_file, 'w') as f:
            json.dump(self.download_progress, f)

    def get_file_hash(self, filepath):
        """Compute the SHA-256 hash of a file."""
        sha256_hash = hashlib.sha256()
        with open(filepath, "rb") as f:
            for byte_block in iter(lambda: f.read(4096), b""):
                sha256_hash.update(byte_block)
        return sha256_hash.hexdigest()

    def download_file(self, url, dest_path, retries=3):
        """Download a file, resuming a partial download if one exists."""
        dest_path = Path(dest_path)
        dest_path.parent.mkdir(parents=True, exist_ok=True)

        for attempt in range(retries):
            # Recompute the resume offset on every attempt so a partially
            # written file from a failed attempt is continued, not duplicated.
            headers = {}
            mode = 'wb'
            resume_pos = 0
            if dest_path.exists():
                resume_pos = dest_path.stat().st_size
                headers['Range'] = f'bytes={resume_pos}-'
                mode = 'ab'

            try:
                response = requests.get(url, headers=headers, stream=True, timeout=30)
                response.raise_for_status()

                total_size = int(response.headers.get('content-length', 0))
                if resume_pos > 0:
                    total_size += resume_pos

                with open(dest_path, mode) as f:
                    with tqdm(total=total_size, initial=resume_pos, unit='B', unit_scale=True, desc=dest_path.name) as pbar:
                        for chunk in response.iter_content(chunk_size=8192):
                            if chunk:
                                f.write(chunk)
                                pbar.update(len(chunk))

                return True

            except Exception as e:
                print(f"Download error (attempt {attempt + 1}/{retries}): {e}")
                if attempt < retries - 1:
                    time.sleep(5)  # wait before retrying
                else:
                    raise

        return False

    def extract_archive(self, archive_path, dest_dir, extract_subdir=None):
        """Extract an archive."""
        import tarfile
        import zipfile

        archive_path = Path(archive_path)
        dest_dir = Path(dest_dir)
        temp_dir = dest_dir / "temp_extract"

        try:
            if archive_path.suffix == '.gz' or archive_path.suffix == '.tar' or str(archive_path).endswith('.tar.gz'):
                with tarfile.open(archive_path, 'r:*') as tar:
                    if extract_subdir:
                        # Extract into a temporary directory, then move the subdirectory
                        temp_dir.mkdir(exist_ok=True)
                        tar.extractall(temp_dir)
                        # Move the requested subdirectory into place
                        src_dir = temp_dir / extract_subdir
                        if src_dir.exists():
                            shutil.move(str(src_dir), str(dest_dir / extract_subdir))
                        shutil.rmtree(temp_dir)
                    else:
                        tar.extractall(dest_dir)
            elif archive_path.suffix == '.zip':
                with zipfile.ZipFile(archive_path, 'r') as zip_ref:
                    zip_ref.extractall(dest_dir)
            else:
                raise ValueError(f"Unsupported archive format: {archive_path.suffix}")
        except Exception as e:
            if temp_dir.exists():
                shutil.rmtree(temp_dir)
            raise e

    def check_models_exist(self):
        """Return the configs whose model files are missing."""
        missing_models = []
        for config in self.model_configs:
            if config['type'] == 'file':
                dest_path = Path(config['dest_dir']) / config['dest_file']
                if not dest_path.exists():
                    missing_models.append(config)
            else:  # archive
                # Check the extracted subdirectory itself when one is configured;
                # the parent (checkpoints/) may already contain unrelated files.
                dest_dir = Path(config['dest_dir'])
                if config.get('extract_subdir'):
                    dest_dir = dest_dir / config['extract_subdir']
                if not dest_dir.exists() or not any(dest_dir.iterdir()):
                    missing_models.append(config)
        return missing_models

    def download_models(self):
        """Download any missing models."""
        missing_models = self.check_models_exist()

        if not missing_models:
            print("All models are already present.")
            return True

        print(f"Downloading {len(missing_models)} model(s)...")

        for config in missing_models:
            size_info = config.get('size', 'unknown')
            print(f"\nDownloading {config['name']}... (size: {size_info})")

            # Cache path for the raw download
            cache_filename = f"{config['name']}.download"
            cache_path = self.cache_dir / cache_filename

            try:
                # Download
                if not cache_path.exists() or self.download_progress.get(config['name'], {}).get('status') != 'completed':
                    self.download_file(config['url'], cache_path)
                    self.download_progress[config['name']] = {'status': 'completed'}
                    self.save_progress()

                # Extract or copy into place
                if config['type'] == 'file':
                    dest_dir = Path(config['dest_dir'])
                    dest_dir.mkdir(parents=True, exist_ok=True)
                    dest_path = dest_dir / config['dest_file']
                    shutil.copy2(cache_path, dest_path)
                else:  # archive
                    dest_dir = Path(config['dest_dir'])
                    dest_dir.mkdir(parents=True, exist_ok=True)
                    print(f"Extracting {config['name']}...")
                    extract_subdir = config.get('extract_subdir')
                    self.extract_archive(cache_path, dest_dir, extract_subdir)

                print(f"{config['name']} set up successfully")

            except Exception as e:
                print(f"Error: failed while downloading {config['name']}: {e}")
                return False

        return True

    def setup_models(self):
        """Set up the models (main entry point)."""
        print("=== DittoTalkingHead model setup ===")
        print(f"Cache directory: {self.cache_dir}")

        success = self.download_models()

        if success:
            print("\n✅ All models were set up successfully!")
        else:
            print("\n❌ An error occurred during model setup.")

        return success


if __name__ == "__main__":
    # Test run
    manager = ModelManager()
    manager.setup_models()
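The TensorRT branch of `ModelManager` is used the same way as the PyTorch branch. A short sketch, assuming a reachable archive is supplied through the `DITTO_TRT_URL` environment variable read above (the URL below is only a placeholder):

```python
# Sketch: download the TensorRT checkpoints instead of the PyTorch ones.
import os
from model_manager import ModelManager

os.environ["DITTO_TRT_URL"] = "https://example.com/ditto_trt_Ampere_Plus.tar.gz"  # placeholder URL
manager = ModelManager(cache_dir="/tmp/ditto_models", use_pytorch=False)
if manager.setup_models():
    print("TensorRT checkpoints ready under ./checkpoints/ditto_trt_Ampere_Plus")
```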
requirements.txt ADDED

@@ -0,0 +1,46 @@
# Core dependencies
torch==2.5.1
torchvision==0.20.1
torchaudio==2.5.1
numpy==2.0.1
pillow==11.0.0

# Audio processing
librosa==0.10.2.post1
soundfile==0.13.0
audioread==3.0.1
soxr==0.5.0.post1

# Video/Image processing
opencv-python-headless==4.10.0.84
imageio==2.36.1
imageio-ffmpeg==0.5.1
scikit-image==0.25.0

# Machine learning
scikit-learn==1.6.0
scipy==1.15.0
numba==0.60.0

# TensorRT (GPU acceleration)
tensorrt==8.6.1
tensorrt-bindings==8.6.1
tensorrt-libs==8.6.1
polygraphy
colored

# Web interface
gradio==4.19.0

# Utilities
tqdm==4.67.1
requests==2.32.3
pyyaml==6.0.2
joblib==1.4.2
cython==3.0.11

# CUDA dependencies
cuda-python==12.6.2.post1
nvidia-cublas-cu12==12.6.4.1
nvidia-cuda-runtime-cu12==12.6.77
nvidia-cudnn-cu12==9.6.0.74
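A quick sanity check against these pins after installation; a sketch that only confirms the core packages import and that a GPU is visible:

```python
# Sketch: verify the core pinned dependencies import and CUDA is usable.
import torch
import gradio
import tensorrt

print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("gradio", gradio.__version__, "| tensorrt", tensorrt.__version__)
```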
setup.sh ADDED

@@ -0,0 +1,18 @@
#!/bin/bash

# Setup script for Hugging Face Space
echo "=== DittoTalkingHead Setup Script ==="

# Create necessary directories
mkdir -p checkpoints/ditto_cfg
mkdir -p tmp
mkdir -p output

# Install system dependencies if needed
# apt-get update && apt-get install -y ffmpeg

# Run model download (PyTorch models)
echo "Starting model download (PyTorch models)..."
python -c "from model_manager import ModelManager; manager = ModelManager(use_pytorch=True); manager.setup_models()"

echo "Setup complete!"