Spaces:
Paused
Paused
| from gradio_client import Client, handle_file | |
| from datetime import datetime | |
| import os | |
| import shutil | |
| import logging | |
| import time | |
| from typing import Tuple, Optional | |
| class TalkingHeadAPIClient: | |
| """DittoTalkingHead API クライアント""" | |
| def __init__(self, space_name: str = "O-ken5481/talkingAvater_bgk", max_retries: int = 3, retry_delay: int = 5): | |
| """ | |
| Args: | |
| space_name: Hugging Face SpaceのID(デフォルト: O-ken5481/talkingAvater_bgk) | |
| max_retries: 最大リトライ回数 | |
| retry_delay: リトライ間隔(秒) | |
| """ | |
| self.space_name = space_name | |
| self.max_retries = max_retries | |
| self.retry_delay = retry_delay | |
| self.logger = self._setup_logger() | |
| self.client = None | |
| self._connect() | |
| def _setup_logger(self) -> logging.Logger: | |
| """ロガーの設定""" | |
| logger = logging.getLogger('TalkingHeadAPIClient') | |
| logger.setLevel(logging.INFO) | |
| if not logger.handlers: | |
| handler = logging.StreamHandler() | |
| formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s', | |
| datefmt='%Y-%m-%d %H:%M:%S') | |
| handler.setFormatter(formatter) | |
| logger.addHandler(handler) | |
| return logger | |
| def _connect(self) -> None: | |
| """APIへの接続""" | |
| for attempt in range(self.max_retries): | |
| try: | |
| self.logger.info(f"接続開始: {self.space_name} (試行 {attempt + 1}/{self.max_retries})") | |
| self.client = Client(self.space_name) | |
| self.logger.info("接続成功") | |
| return | |
| except Exception as e: | |
| self.logger.error(f"接続失敗: {e}") | |
| if attempt < self.max_retries - 1: | |
| self.logger.info(f"{self.retry_delay}秒後にリトライします...") | |
| time.sleep(self.retry_delay) | |
| else: | |
| raise ConnectionError(f"APIへの接続に失敗しました: {e}") | |
| def generate_video(self, audio_path: str, image_path: str) -> Tuple[Optional[dict], str]: | |
| """ | |
| API経由で動画生成 | |
| Args: | |
| audio_path: 音声ファイルのパス | |
| image_path: 画像ファイルのパス | |
| Returns: | |
| tuple: (video_data, status_message) | |
| """ | |
| # ファイルの存在確認 | |
| if not os.path.exists(audio_path): | |
| error_msg = f"音声ファイルが見つかりません: {audio_path}" | |
| self.logger.error(error_msg) | |
| return None, error_msg | |
| if not os.path.exists(image_path): | |
| error_msg = f"画像ファイルが見つかりません: {image_path}" | |
| self.logger.error(error_msg) | |
| return None, error_msg | |
| # API呼び出し | |
| for attempt in range(self.max_retries): | |
| try: | |
| self.logger.info(f"ファイルアップロード: {audio_path}, {image_path}") | |
| self.logger.info("処理開始...") | |
| result = self.client.predict( | |
| audio_file=handle_file(audio_path), | |
| source_image=handle_file(image_path), | |
| api_name="/process_talking_head" | |
| ) | |
| self.logger.info("動画生成完了") | |
| return result | |
| except Exception as e: | |
| self.logger.error(f"処理エラー (試行 {attempt + 1}/{self.max_retries}): {e}") | |
| if attempt < self.max_retries - 1: | |
| self.logger.info(f"{self.retry_delay}秒後にリトライします...") | |
| time.sleep(self.retry_delay) | |
| else: | |
| error_msg = f"動画生成に失敗しました: {e}" | |
| return None, error_msg | |
| def save_with_timestamp(self, video_path: str, output_dir: str = "example") -> Optional[str]: | |
| """ | |
| 動画をタイムスタンプ付きで保存 | |
| Args: | |
| video_path: 生成された動画のパス | |
| output_dir: 保存先ディレクトリ | |
| Returns: | |
| str: 保存されたファイルパス(エラー時はNone) | |
| """ | |
| try: | |
| # 動画パスの確認 | |
| if not video_path or not os.path.exists(video_path): | |
| self.logger.error(f"動画ファイルが見つかりません: {video_path}") | |
| return None | |
| # 出力ディレクトリの作成 | |
| os.makedirs(output_dir, exist_ok=True) | |
| # YYYY-MM-DD_HH-MM-SS.mp4 形式で保存 | |
| timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") | |
| output_path = os.path.join(output_dir, f"{timestamp}.mp4") | |
| # ファイルをコピー | |
| shutil.copy2(video_path, output_path) | |
| # ファイルサイズの確認 | |
| file_size = os.path.getsize(output_path) | |
| self.logger.info(f"保存完了: {output_path} (サイズ: {file_size:,} bytes)") | |
| return output_path | |
| except Exception as e: | |
| self.logger.error(f"保存エラー: {e}") | |
| return None | |
| def process_with_save(self, audio_path: str, image_path: str, output_dir: str = "example") -> Tuple[Optional[str], str]: | |
| """ | |
| 動画生成と保存を一括実行 | |
| Args: | |
| audio_path: 音声ファイルのパス | |
| image_path: 画像ファイルのパス | |
| output_dir: 保存先ディレクトリ | |
| Returns: | |
| tuple: (saved_path, status_message) | |
| """ | |
| # 動画生成 | |
| result = self.generate_video(audio_path, image_path) | |
| if result[0] is None: | |
| return None, result[1] | |
| video_data, status = result | |
| # 動画の保存 | |
| if isinstance(video_data, dict) and 'video' in video_data: | |
| saved_path = self.save_with_timestamp(video_data['video'], output_dir) | |
| if saved_path: | |
| return saved_path, f"{status}\n保存先: {saved_path}" | |
| else: | |
| return None, f"{status}\n保存に失敗しました" | |
| else: | |
| return None, f"予期しないレスポンス形式: {video_data}" | |
| def main(): | |
| """テストスクリプトのメイン関数""" | |
| # ロギング設定 | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| format='%(asctime)s - %(message)s', | |
| datefmt='%Y-%m-%d %H:%M:%S' | |
| ) | |
| # クライアント初期化 | |
| try: | |
| client = TalkingHeadAPIClient() | |
| except Exception as e: | |
| logging.error(f"クライアント初期化失敗: {e}") | |
| return | |
| # サンプルファイルを使用 | |
| audio_path = "example/audio.wav" | |
| image_path = "example/image.png" | |
| # ファイルの存在確認 | |
| if not os.path.exists(audio_path): | |
| logging.error(f"音声ファイルが見つかりません: {audio_path}") | |
| return | |
| if not os.path.exists(image_path): | |
| logging.error(f"画像ファイルが見つかりません: {image_path}") | |
| return | |
| try: | |
| # 動画生成と保存 | |
| saved_path, status = client.process_with_save(audio_path, image_path) | |
| if saved_path: | |
| print(f"\n✅ 成功!") | |
| print(f"ステータス: {status}") | |
| print(f"動画を確認してください: {saved_path}") | |
| else: | |
| print(f"\n❌ 失敗") | |
| print(f"ステータス: {status}") | |
| except KeyboardInterrupt: | |
| logging.info("処理を中断しました") | |
| except Exception as e: | |
| logging.error(f"予期しないエラー: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| if __name__ == "__main__": | |
| main() |