|
import gradio as gr |
|
import os |
|
import tempfile |
|
import shutil |
|
from pathlib import Path |
|
from model_manager import ModelManager |
|
|
|
|
|
try: |
|
import filetype |
|
print("✅ filetype module imported successfully") |
|
except ImportError as e: |
|
print(f"⚠️ filetype import failed: {e}") |
|
print("Using fallback file type detection") |
|
|
|
|
|
StreamSDK = None |
|
run = None |
|
seed_everything = None |
|
|
|
|
|
print("=== モデルの初期化開始 ===") |
|
|
|
|
|
USE_PYTORCH = True |
|
|
|
model_manager = ModelManager(cache_dir="/tmp/ditto_models", use_pytorch=USE_PYTORCH) |
|
if not model_manager.setup_models(): |
|
raise RuntimeError("モデルのセットアップに失敗しました。") |
|
|
|
|
|
if USE_PYTORCH: |
|
data_root = "./checkpoints/ditto_pytorch" |
|
cfg_pkl = "./checkpoints/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl" |
|
else: |
|
data_root = "./checkpoints/ditto_trt_Ampere_Plus" |
|
cfg_pkl = "./checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl" |
|
|
|
try: |
|
|
|
global StreamSDK, run, seed_everything |
|
from stream_pipeline_offline import StreamSDK |
|
from inference import run, seed_everything |
|
|
|
SDK = StreamSDK(cfg_pkl, data_root) |
|
print("✅ SDK初期化成功") |
|
except Exception as e: |
|
print(f"❌ SDK初期化エラー: {e}") |
|
import traceback |
|
traceback.print_exc() |
|
raise |
|
|
|
def process_talking_head(audio_file, source_image): |
|
"""音声とソース画像からTalking Headビデオを生成""" |
|
|
|
if audio_file is None: |
|
return None, "音声ファイルをアップロードしてください。" |
|
|
|
if source_image is None: |
|
return None, "ソース画像をアップロードしてください。" |
|
|
|
try: |
|
|
|
with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_output: |
|
output_path = tmp_output.name |
|
|
|
|
|
print(f"処理開始: audio={audio_file}, image={source_image}") |
|
seed_everything(1024) |
|
run(SDK, audio_file, source_image, output_path) |
|
|
|
|
|
if os.path.exists(output_path) and os.path.getsize(output_path) > 0: |
|
return output_path, "✅ 処理が完了しました!" |
|
else: |
|
return None, "❌ 処理に失敗しました。出力ファイルが生成されませんでした。" |
|
|
|
except Exception as e: |
|
import traceback |
|
error_msg = f"❌ エラーが発生しました: {str(e)}\n{traceback.format_exc()}" |
|
print(error_msg) |
|
return None, error_msg |
|
|
|
|
|
with gr.Blocks(title="DittoTalkingHead") as demo: |
|
gr.Markdown(""" |
|
# DittoTalkingHead - Talking Head Generation |
|
|
|
音声とソース画像から、リアルなTalking Headビデオを生成します。 |
|
|
|
## 使い方 |
|
1. **音声ファイル**(WAV形式)をアップロード |
|
2. **ソース画像**(PNG/JPG形式)をアップロード |
|
3. **生成**ボタンをクリック |
|
|
|
⚠️ 初回実行時は、モデルのダウンロードのため時間がかかります(約2.5GB)。 |
|
|
|
### 技術仕様 |
|
- **モデル**: DittoTalkingHead (PyTorch版) |
|
- **GPU**: NVIDIA A100推奨 |
|
- **モデル提供**: [digital-avatar/ditto-talkinghead](https://huggingface.co/digital-avatar/ditto-talkinghead) |
|
""") |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
audio_input = gr.Audio( |
|
label="音声ファイル (WAV)", |
|
type="filepath" |
|
) |
|
image_input = gr.Image( |
|
label="ソース画像", |
|
type="filepath" |
|
) |
|
generate_btn = gr.Button("生成", variant="primary") |
|
|
|
with gr.Column(): |
|
video_output = gr.Video( |
|
label="生成されたビデオ" |
|
) |
|
status_output = gr.Textbox( |
|
label="ステータス", |
|
lines=3 |
|
) |
|
|
|
|
|
gr.Examples( |
|
examples=[ |
|
["example/audio.wav", "example/image.png"] |
|
], |
|
inputs=[audio_input, image_input], |
|
outputs=[video_output, status_output], |
|
fn=process_talking_head, |
|
cache_examples=True |
|
) |
|
|
|
|
|
generate_btn.click( |
|
fn=process_talking_head, |
|
inputs=[audio_input, image_input], |
|
outputs=[video_output, status_output] |
|
) |
|
|
|
if __name__ == "__main__": |
|
demo.launch( |
|
server_name="0.0.0.0", |
|
server_port=7860, |
|
share=False |
|
) |