Added an inference cache and parallel processing, made the `process_talking_head_optimized` function cache- and parallelism-aware, and added cache management controls to the Gradio interface.
Files changed:
- app_optimized.py  +175 −20
- core/optimization/__init__.py  +10 −1
- core/optimization/inference_cache.py  +386 −0
- core/optimization/parallel_inference.py  +268 −0
- core/optimization/parallel_processing.py  +400 −0
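Taken together, the commit routes generation through `CachedInference` (a content-hash keyed result cache) and `OptimizedInferenceWrapper` (parallel preprocessing). A minimal sketch of how the cache layer is meant to be driven, assuming the package from this commit is importable and using a stand-in `my_generate` in place of the real SDK call:

```python
# Sketch only: wrap an arbitrary generation function with the new inference cache.
# `my_generate` is a hypothetical stand-in for the real SDK call (not part of this commit).
from core.optimization import InferenceCache, CachedInference

def my_generate(audio_path, image_path, out_path, **kwargs):
    # Placeholder for the actual video generation step.
    with open(out_path, "wb") as f:
        f.write(b"\x00")

cache = InferenceCache(cache_dir="/tmp/inference_cache", ttl_hours=24)
cached = CachedInference(cache)

out, hit, secs = cached.process_with_cache(
    my_generate, "example/audio.wav", "example/image.png", "/tmp/out.mp4",
    resolution="320x320", steps=25
)
print(hit, f"{secs:.2f}s")
```

On a second call with identical audio, image, and settings, `process_with_cache` copies the cached video to the output path instead of re-running `my_generate`.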
app_optimized.py
CHANGED
@@ -18,7 +18,12 @@ from core.optimization import (
     GPUOptimizer,
     AvatarCache,
     AvatarTokenManager,
-    ColdStartOptimizer
+    ColdStartOptimizer,
+    InferenceCache,
+    CachedInference,
+    ParallelProcessor,
+    ParallelInference,
+    OptimizedInferenceWrapper
 )
 
 # サンプルファイルのディレクトリを定義
@@ -44,6 +49,18 @@ avatar_cache = AvatarCache(cache_dir="/tmp/avatar_cache", ttl_days=14)
 token_manager = AvatarTokenManager(avatar_cache)
 print(f"✅ アバターキャッシュ初期化: {avatar_cache.get_cache_info()}")
 
+# 5. 推論キャッシュの初期化
+inference_cache = InferenceCache(
+    cache_dir="/tmp/inference_cache",
+    memory_cache_size=50,
+    file_cache_size_gb=5.0,
+    ttl_hours=24
+)
+cached_inference = CachedInference(inference_cache)
+print(f"✅ 推論キャッシュ初期化: {inference_cache.get_cache_stats()}")
+
+# 6. 並列処理の初期化(SDK初期化後に移動)
+
 # モデルの初期化(最適化版)
 USE_PYTORCH = True
 model_manager = ModelManager(cache_dir="/tmp/ditto_models", use_pytorch=USE_PYTORCH)
@@ -92,6 +109,17 @@ except Exception as e:
     traceback.print_exc()
     raise
 
+# 並列処理の初期化(SDK初期化成功後)
+parallel_processor = ParallelProcessor(num_threads=4, num_processes=2)
+parallel_inference = ParallelInference(SDK, parallel_processor)
+optimized_wrapper = OptimizedInferenceWrapper(
+    SDK,
+    use_parallel=True,
+    use_cache=True,
+    use_gpu_opt=True
+)
+print(f"✅ 並列処理初期化: {parallel_inference.get_performance_stats()}")
+
 def prepare_avatar(image_file) -> Dict[str, Any]:
     """
     画像を事前処理してアバタートークンを生成
@@ -150,16 +178,19 @@ def process_talking_head_optimized(
     audio_file,
     source_image,
     avatar_token: Optional[str] = None,
-    use_resolution_optimization: bool = True
+    use_resolution_optimization: bool = True,
+    use_inference_cache: bool = True,
+    use_parallel_processing: bool = True
 ):
     """
-    最適化されたTalking Head
+    最適化されたTalking Head生成処理(キャッシュ対応)
 
     Args:
         audio_file: 音声ファイル
         source_image: ソース画像(avatar_tokenがない場合に使用)
         avatar_token: 事前生成されたアバタートークン
        use_resolution_optimization: 解像度最適化を使用するか
+        use_inference_cache: 推論キャッシュを使用するか
     """
 
     if audio_file is None:
@@ -184,7 +215,6 @@ def process_talking_head_optimized(
 
     # 解像度最適化設定を適用
     if use_resolution_optimization:
-        # SDKに解像度設定を適用
         setup_kwargs = {
             "max_size": FIXED_RESOLUTION,  # 320固定
             "sampling_timesteps": resolution_optimizer.get_diffusion_steps()  # 25
@@ -193,15 +223,68 @@ def process_talking_head_optimized(
     else:
         setup_kwargs = {}
 
-    #
-
-
-
-
-
-
-
-
+    # 処理方法の選択
+    if use_parallel_processing and source_image:
+        # 並列処理を使用
+        print("🔄 並列処理モードで実行...")
+
+        if use_inference_cache:
+            # キャッシュ + 並列処理
+            def inference_func(audio_path, image_path, out_path, **kwargs):
+                # 並列処理ラッパーを使用
+                optimized_wrapper.process(
+                    audio_path, image_path, out_path,
+                    seed=1024,
+                    more_kwargs={"setup_kwargs": kwargs.get('setup_kwargs', {})}
+                )
+
+            # キャッシュシステムを通じて処理
+            result_path, cache_hit, process_time = cached_inference.process_with_cache(
+                inference_func,
+                audio_file,
+                source_image,
+                output_path,
+                resolution=f"{FIXED_RESOLUTION}x{FIXED_RESOLUTION}" if use_resolution_optimization else "default",
+                steps=setup_kwargs.get('sampling_timesteps', 50),
+                setup_kwargs=setup_kwargs
+            )
+            cache_status = "キャッシュヒット(並列)" if cache_hit else "新規生成(並列)"
+        else:
+            # 並列処理のみ
+            _, process_time, stats = optimized_wrapper.process(
+                audio_file, source_image, output_path,
+                seed=1024,
+                more_kwargs={"setup_kwargs": setup_kwargs}
+            )
+            cache_hit = False
+            cache_status = "並列処理(キャッシュ未使用)"
+
+    elif use_inference_cache and source_image:
+        # キャッシュのみ(並列処理なし)
+        def inference_func(audio_path, image_path, out_path, **kwargs):
+            seed_everything(1024)
+            run(SDK, audio_path, image_path, out_path,
+                more_kwargs={"setup_kwargs": kwargs.get('setup_kwargs', {})})
+
+        # キャッシュシステムを通じて処理
+        result_path, cache_hit, process_time = cached_inference.process_with_cache(
+            inference_func,
+            audio_file,
+            source_image,
+            output_path,
+            resolution=f"{FIXED_RESOLUTION}x{FIXED_RESOLUTION}" if use_resolution_optimization else "default",
+            steps=setup_kwargs.get('sampling_timesteps', 50),
+            setup_kwargs=setup_kwargs
+        )
+        cache_status = "キャッシュヒット" if cache_hit else "新規生成"
+    else:
+        # 通常処理(並列処理もキャッシュもなし)
+        print(f"処理開始: audio={audio_file}, image={source_image}, token={avatar_token is not None}")
+        seed_everything(1024)
+        run(SDK, audio_file, source_image, output_path, more_kwargs={"setup_kwargs": setup_kwargs})
+        process_time = time.time() - start_time
+        cache_hit = False
+        cache_status = "通常処理"
 
     # 結果の確認
     if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
@@ -210,8 +293,12 @@ def process_talking_head_optimized(
     ✅ 処理完了!
     処理時間: {process_time:.2f}秒
     解像度: {FIXED_RESOLUTION}×{FIXED_RESOLUTION}
-
-
+    最適化設定:
+    - 解像度最適化: {'有効' if use_resolution_optimization else '無効'}
+    - 並列処理: {'有効' if use_parallel_processing else '無効'}
+    - アバターキャッシュ: {'使用' if avatar_token else '未使用'}
+    - 推論キャッシュ: {cache_status}
+    キャッシュ統計: {inference_cache.get_cache_stats()['memory_cache_entries']}件(メモリ), {inference_cache.get_cache_stats()['file_cache_entries']}件(ファイル)
     """
         return output_path, perf_info
     else:
@@ -233,6 +320,8 @@ with gr.Blocks(title="DittoTalkingHead - Phase 3 最適化版") as demo:
     - 🎯 画像事前アップロード&キャッシュ機能
     - ⚡ GPU最適化(Mixed Precision, torch.compile)
     - 💾 Cold Start最適化
+    - 🔄 推論キャッシュ(同じ入力で即座に結果を返す)
+    - 🚀 並列処理(音声・画像の前処理を並列化)
 
     ## 使い方
     ### 方法1: 通常の使用
@@ -271,6 +360,16 @@ with gr.Blocks(title="DittoTalkingHead - Phase 3 最適化版") as demo:
                     value=True
                 )
 
+                use_cache = gr.Checkbox(
+                    label="推論キャッシュを使用(同じ入力で高速化)",
+                    value=True
+                )
+
+                use_parallel = gr.Checkbox(
+                    label="並列処理を使用(前処理を高速化)",
+                    value=True
+                )
+
                 generate_btn = gr.Button("🎬 生成", variant="primary")
 
             with gr.Column():
@@ -305,16 +404,29 @@ with gr.Blocks(title="DittoTalkingHead - Phase 3 最適化版") as demo:
 
         # タブ3: 最適化情報
         with gr.TabItem("📊 最適化情報"):
-            gr.Markdown(f"""
+            with gr.Row():
+                refresh_btn = gr.Button("🔄 情報を更新", scale=1)
+
+            info_display = gr.Markdown(f"""
             ### 現在の最適化設定
 
             {resolution_optimizer.get_optimization_summary()}
 
            {gpu_optimizer.get_optimization_summary()}
 
-            ###
+            ### アバターキャッシュ情報
            {avatar_cache.get_cache_info()}
+
+            ### 推論キャッシュ情報
+            {inference_cache.get_cache_stats()}
            """)
+
+            # キャッシュ管理ボタン
+            with gr.Row():
+                clear_inference_cache_btn = gr.Button("🗑️ 推論キャッシュをクリア", variant="secondary")
+                clear_avatar_cache_btn = gr.Button("🗑️ アバターキャッシュをクリア", variant="secondary")
+
+            cache_status = gr.Textbox(label="キャッシュ操作ステータス", lines=2)
 
     # サンプル
     example_audio = EXAMPLES_DIR / "audio.wav"
@@ -323,9 +435,9 @@ with gr.Blocks(title="DittoTalkingHead - Phase 3 最適化版") as demo:
     if example_audio.exists() and example_image.exists():
         gr.Examples(
             examples=[
-                [str(example_audio), str(example_image), None, True]
+                [str(example_audio), str(example_image), None, True, True, True]
             ],
-            inputs=[audio_input, image_input, token_input, use_optimization],
+            inputs=[audio_input, image_input, token_input, use_optimization, use_cache, use_parallel],
             outputs=[video_output, status_output],
             fn=process_talking_head_optimized
         )
@@ -333,7 +445,7 @@ with gr.Blocks(title="DittoTalkingHead - Phase 3 最適化版") as demo:
     # イベントハンドラ
     generate_btn.click(
         fn=process_talking_head_optimized,
-        inputs=[audio_input, image_input, token_input, use_optimization],
+        inputs=[audio_input, image_input, token_input, use_optimization, use_cache, use_parallel],
        outputs=[video_output, status_output]
     )
 
@@ -342,6 +454,49 @@ with gr.Blocks(title="DittoTalkingHead - Phase 3 最適化版") as demo:
         inputs=[avatar_image_input],
         outputs=[prepare_output]
     )
+
+    # キャッシュ管理関数
+    def refresh_info():
+        return f"""
+        ### 現在の最適化設定
+
+        {resolution_optimizer.get_optimization_summary()}
+
+        {gpu_optimizer.get_optimization_summary()}
+
+        ### アバターキャッシュ情報
+        {avatar_cache.get_cache_info()}
+
+        ### 推論キャッシュ情報
+        {inference_cache.get_cache_stats()}
+
+        ### 並列処理情報
+        {parallel_inference.get_performance_stats()}
+        """
+
+    def clear_inference_cache():
+        inference_cache.clear_cache()
+        return "✅ 推論キャッシュをクリアしました"
+
+    def clear_avatar_cache():
+        avatar_cache.clear_cache()
+        return "✅ アバターキャッシュをクリアしました"
+
+    # キャッシュ管理イベント
+    refresh_btn.click(
+        fn=refresh_info,
+        outputs=[info_display]
+    )
+
+    clear_inference_cache_btn.click(
+        fn=clear_inference_cache,
+        outputs=[cache_status]
+    )
+
+    clear_avatar_cache_btn.click(
+        fn=clear_avatar_cache,
+        outputs=[cache_status]
+    )
 
 if __name__ == "__main__":
     # Cold Start最適化設定でGradioを起動
core/optimization/__init__.py
CHANGED
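The updated `__init__.py` below widens the package's public API so application code keeps a single import point. A quick import smoke test, assuming the package is on the path:

```python
# Smoke test for the widened public API (assumption: core.optimization is importable).
from core.optimization import (
    InferenceCache, CachedInference,
    ParallelProcessor, PipelineProcessor,
    ParallelInference, OptimizedInferenceWrapper,
)
print(InferenceCache, ParallelProcessor)  # succeeds only if the re-exports match __all__
```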
@@ -6,6 +6,9 @@ from .resolution_optimization import FixedResolutionProcessor
 from .gpu_optimization import GPUOptimizer, OptimizedInference
 from .avatar_cache import AvatarCache, AvatarTokenManager
 from .cold_start_optimization import ColdStartOptimizer
+from .inference_cache import InferenceCache, CachedInference
+from .parallel_processing import ParallelProcessor, PipelineProcessor
+from .parallel_inference import ParallelInference, OptimizedInferenceWrapper
 
 __all__ = [
     'FixedResolutionProcessor',
@@ -13,5 +16,11 @@ __all__ = [
     'OptimizedInference',
     'AvatarCache',
     'AvatarTokenManager',
-    'ColdStartOptimizer'
+    'ColdStartOptimizer',
+    'InferenceCache',
+    'CachedInference',
+    'ParallelProcessor',
+    'PipelineProcessor',
+    'ParallelInference',
+    'OptimizedInferenceWrapper'
 ]
core/optimization/inference_cache.py
ADDED
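The new module below implements a two-tier cache (an in-memory LRU map plus a size- and TTL-bounded file cache) keyed by SHA-256 hashes of the input files and the settings that affect the output. A short sketch of the lower-level API it defines, with hypothetical paths:

```python
# Sketch of the low-level cache API added by this file (paths are hypothetical).
from core.optimization.inference_cache import InferenceCache

cache = InferenceCache(cache_dir="/tmp/inference_cache",
                       memory_cache_size=50, file_cache_size_gb=5.0, ttl_hours=24)

# The key mixes content hashes of both inputs with the settings that affect the output.
key = cache.generate_cache_key("example/audio.wav", "example/image.png",
                               resolution="320x320", steps=25, seed=1024)

if cache.get(key) is None:                 # memory cache first, then file cache
    cache.put(key, "/tmp/generated.mp4")   # copies the video into the cache directory
print(cache.get_cache_stats())             # entry counts, total size, TTL, directory
```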
@@ -0,0 +1,386 @@
"""
Inference Cache System for DittoTalkingHead
Caches video generation results for faster repeated processing
"""

import hashlib
import json
import os
import pickle
import time
from pathlib import Path
from typing import Optional, Dict, Any, Tuple, Union
from functools import lru_cache
import shutil
from datetime import datetime, timedelta


class InferenceCache:
    """
    Cache system for video generation results
    Supports both memory and file-based caching
    """

    def __init__(
        self,
        cache_dir: str = "/tmp/inference_cache",
        memory_cache_size: int = 100,
        file_cache_size_gb: float = 10.0,
        ttl_hours: int = 24
    ):
        """
        Initialize inference cache

        Args:
            cache_dir: Directory for file-based cache
            memory_cache_size: Maximum number of items in memory cache
            file_cache_size_gb: Maximum size of file cache in GB
            ttl_hours: Time to live for cache entries in hours
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        self.memory_cache_size = memory_cache_size
        self.file_cache_size_bytes = int(file_cache_size_gb * 1024 * 1024 * 1024)
        self.ttl_seconds = ttl_hours * 3600

        # Metadata file for managing cache
        self.metadata_file = self.cache_dir / "cache_metadata.json"
        self.metadata = self._load_metadata()

        # In-memory cache
        self._memory_cache = {}
        self._access_times = {}

        # Clean up expired entries on initialization
        self._cleanup_expired()

    def _load_metadata(self) -> Dict[str, Any]:
        """Load cache metadata"""
        if self.metadata_file.exists():
            try:
                with open(self.metadata_file, 'r') as f:
                    return json.load(f)
            except:
                return {}
        return {}

    def _save_metadata(self):
        """Save cache metadata"""
        with open(self.metadata_file, 'w') as f:
            json.dump(self.metadata, f, indent=2)

    def generate_cache_key(
        self,
        audio_path: str,
        image_path: str,
        **kwargs
    ) -> str:
        """
        Generate unique cache key based on input parameters

        Args:
            audio_path: Path to audio file
            image_path: Path to image file
            **kwargs: Additional parameters affecting output

        Returns:
            SHA-256 hash as cache key
        """
        # Read file contents for hashing
        with open(audio_path, 'rb') as f:
            audio_hash = hashlib.sha256(f.read()).hexdigest()

        with open(image_path, 'rb') as f:
            image_hash = hashlib.sha256(f.read()).hexdigest()

        # Include relevant parameters in key
        key_data = {
            'audio': audio_hash,
            'image': image_hash,
            'resolution': kwargs.get('resolution', '320x320'),
            'steps': kwargs.get('steps', 25),
            'seed': kwargs.get('seed', None)
        }

        # Generate final key
        key_str = json.dumps(key_data, sort_keys=True)
        return hashlib.sha256(key_str.encode()).hexdigest()

    def get_from_memory(self, cache_key: str) -> Optional[str]:
        """
        Get video path from memory cache

        Args:
            cache_key: Cache key

        Returns:
            Video file path if found, None otherwise
        """
        if cache_key in self._memory_cache:
            self._access_times[cache_key] = time.time()
            return self._memory_cache[cache_key]
        return None

    def get_from_file(self, cache_key: str) -> Optional[str]:
        """
        Get video path from file cache

        Args:
            cache_key: Cache key

        Returns:
            Video file path if found, None otherwise
        """
        if cache_key not in self.metadata:
            return None

        entry = self.metadata[cache_key]

        # Check expiration
        if time.time() > entry['expires_at']:
            self._remove_cache_entry(cache_key)
            return None

        # Check if file exists
        video_path = self.cache_dir / entry['filename']
        if not video_path.exists():
            self._remove_cache_entry(cache_key)
            return None

        # Update access time
        self.metadata[cache_key]['last_access'] = time.time()
        self._save_metadata()

        # Add to memory cache
        self._add_to_memory_cache(cache_key, str(video_path))

        return str(video_path)

    def get(self, cache_key: str) -> Optional[str]:
        """
        Get video from cache (memory first, then file)

        Args:
            cache_key: Cache key

        Returns:
            Video file path if found, None otherwise
        """
        # Try memory cache first
        result = self.get_from_memory(cache_key)
        if result:
            return result

        # Try file cache
        return self.get_from_file(cache_key)

    def put(
        self,
        cache_key: str,
        video_path: str,
        **metadata
    ) -> bool:
        """
        Store video in cache

        Args:
            cache_key: Cache key
            video_path: Path to generated video
            **metadata: Additional metadata to store

        Returns:
            True if stored successfully
        """
        try:
            # Copy video to cache directory
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            cache_filename = f"{cache_key[:8]}_{timestamp}.mp4"
            cache_video_path = self.cache_dir / cache_filename

            shutil.copy2(video_path, cache_video_path)

            # Store metadata
            self.metadata[cache_key] = {
                'filename': cache_filename,
                'created_at': time.time(),
                'expires_at': time.time() + self.ttl_seconds,
                'last_access': time.time(),
                'size_bytes': os.path.getsize(cache_video_path),
                'metadata': metadata
            }

            # Check cache size and clean if needed
            self._check_cache_size()

            # Save metadata
            self._save_metadata()

            # Add to memory cache
            self._add_to_memory_cache(cache_key, str(cache_video_path))

            return True

        except Exception as e:
            print(f"Error storing cache: {e}")
            return False

    def _add_to_memory_cache(self, cache_key: str, video_path: str):
        """Add item to memory cache with LRU eviction"""
        # Check if we need to evict
        if len(self._memory_cache) >= self.memory_cache_size:
            # Find least recently used
            lru_key = min(self._access_times, key=self._access_times.get)
            del self._memory_cache[lru_key]
            del self._access_times[lru_key]

        self._memory_cache[cache_key] = video_path
        self._access_times[cache_key] = time.time()

    def _check_cache_size(self):
        """Check and maintain cache size limit"""
        total_size = sum(
            entry['size_bytes']
            for entry in self.metadata.values()
        )

        if total_size > self.file_cache_size_bytes:
            # Remove oldest entries until under limit
            sorted_entries = sorted(
                self.metadata.items(),
                key=lambda x: x[1]['last_access']
            )

            while total_size > self.file_cache_size_bytes and sorted_entries:
                key_to_remove, entry = sorted_entries.pop(0)
                total_size -= entry['size_bytes']
                self._remove_cache_entry(key_to_remove)

    def _cleanup_expired(self):
        """Remove expired cache entries"""
        current_time = time.time()
        expired_keys = [
            key for key, entry in self.metadata.items()
            if current_time > entry['expires_at']
        ]

        for key in expired_keys:
            self._remove_cache_entry(key)

        if expired_keys:
            print(f"Cleaned up {len(expired_keys)} expired cache entries")

    def _remove_cache_entry(self, cache_key: str):
        """Remove a cache entry"""
        if cache_key in self.metadata:
            # Remove file
            video_file = self.cache_dir / self.metadata[cache_key]['filename']
            if video_file.exists():
                video_file.unlink()

            # Remove from metadata
            del self.metadata[cache_key]

        # Remove from memory cache
        if cache_key in self._memory_cache:
            del self._memory_cache[cache_key]
            del self._access_times[cache_key]

    def clear_cache(self):
        """Clear all cache entries"""
        # Remove all video files
        for file in self.cache_dir.glob("*.mp4"):
            file.unlink()

        # Clear metadata
        self.metadata = {}
        self._save_metadata()

        # Clear memory cache
        self._memory_cache.clear()
        self._access_times.clear()

        print("Inference cache cleared")

    def get_cache_stats(self) -> Dict[str, Any]:
        """Get cache statistics"""
        total_size = sum(
            entry['size_bytes']
            for entry in self.metadata.values()
        )

        memory_hits = len(self._memory_cache)
        file_entries = len(self.metadata)

        return {
            'memory_cache_entries': memory_hits,
            'file_cache_entries': file_entries,
            'total_cache_size_mb': total_size / (1024 * 1024),
            'cache_size_limit_gb': self.file_cache_size_bytes / (1024 * 1024 * 1024),
            'ttl_hours': self.ttl_seconds / 3600,
            'cache_directory': str(self.cache_dir)
        }


class CachedInference:
    """
    Wrapper for cached inference execution
    """

    def __init__(self, cache: InferenceCache):
        """
        Initialize cached inference

        Args:
            cache: InferenceCache instance
        """
        self.cache = cache

    def process_with_cache(
        self,
        inference_func: callable,
        audio_path: str,
        image_path: str,
        output_path: str,
        **kwargs
    ) -> Tuple[str, bool, float]:
        """
        Process with caching

        Args:
            inference_func: Function to generate video
            audio_path: Path to audio file
            image_path: Path to image file
            output_path: Desired output path
            **kwargs: Additional parameters

        Returns:
            Tuple of (output_path, cache_hit, process_time)
        """
        start_time = time.time()

        # Generate cache key
        cache_key = self.cache.generate_cache_key(
            audio_path, image_path, **kwargs
        )

        # Check cache
        cached_video = self.cache.get(cache_key)

        if cached_video:
            # Cache hit - copy to output path
            shutil.copy2(cached_video, output_path)
            process_time = time.time() - start_time
            print(f"✅ Cache hit! Retrieved in {process_time:.2f}s")
            return output_path, True, process_time

        # Cache miss - generate video
        print("Cache miss - generating video...")
        inference_func(audio_path, image_path, output_path, **kwargs)

        # Store in cache
        if os.path.exists(output_path):
            self.cache.put(cache_key, output_path, **kwargs)

        process_time = time.time() - start_time
        return output_path, False, process_time
core/optimization/parallel_inference.py
ADDED
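This module wires parallel preprocessing into the inference path and exposes `OptimizedInferenceWrapper` as the single entry point used by `app_optimized.py`. A minimal sketch of a call, assuming `SDK` is the already-initialized StreamSDK instance from the app and the example files exist:

```python
# Sketch only: drive the combined wrapper directly (SDK and paths assumed to exist).
from core.optimization import OptimizedInferenceWrapper

wrapper = OptimizedInferenceWrapper(SDK, use_parallel=True, use_cache=True, use_gpu_opt=True)

out, secs, stats = wrapper.process(
    "example/audio.wav", "example/image.png", "/tmp/out.mp4",
    seed=1024,
    more_kwargs={"setup_kwargs": {"max_size": 320, "sampling_timesteps": 25}},
)
print(stats)        # includes 'preprocessing': 'parallel' or 'sequential' and the timing

wrapper.shutdown()  # releases the underlying thread/process pools
```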
@@ -0,0 +1,268 @@
"""
Parallel Inference Integration for DittoTalkingHead
Integrates parallel processing into the inference pipeline
"""

import asyncio
import time
from typing import Dict, Any, Tuple, Optional
import numpy as np
import torch
from pathlib import Path

from .parallel_processing import ParallelProcessor, PipelineProcessor


class ParallelInference:
    """
    Parallel inference wrapper for DittoTalkingHead
    """

    def __init__(self, sdk, parallel_processor: Optional[ParallelProcessor] = None):
        """
        Initialize parallel inference

        Args:
            sdk: StreamSDK instance
            parallel_processor: ParallelProcessor instance
        """
        self.sdk = sdk
        self.parallel_processor = parallel_processor or ParallelProcessor(num_threads=4)

        # Setup pipeline stages
        self.pipeline_stages = {
            'load': self._load_files,
            'preprocess': self._preprocess,
            'inference': self._inference,
            'postprocess': self._postprocess
        }

    def _load_files(self, paths: Dict[str, str]) -> Dict[str, Any]:
        """Load audio and image files"""
        audio_path = paths['audio']
        image_path = paths['image']

        # Parallel loading
        audio_data, image_data = self.parallel_processor.preprocess_parallel_sync(
            audio_path, image_path
        )

        return {
            'audio_data': audio_data,
            'image_data': image_data,
            'paths': paths
        }

    def _preprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Preprocess loaded data"""
        # Extract audio features
        audio = data['audio_data']['audio']
        sr = data['audio_data']['sample_rate']

        # Prepare for SDK
        import librosa
        import math

        # Calculate number of frames
        num_frames = math.ceil(len(audio) / 16000 * 25)

        # Prepare image
        image = data['image_data']['image']

        return {
            'audio': audio,
            'image': image,
            'num_frames': num_frames,
            'paths': data['paths']
        }

    def _inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run inference"""
        # This would integrate with the actual SDK inference
        # For now, placeholder
        return {
            'result': 'inference_result',
            'paths': data['paths']
        }

    def _postprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Postprocess results"""
        return data

    async def process_parallel_async(
        self,
        audio_path: str,
        image_path: str,
        output_path: str,
        **kwargs
    ) -> Tuple[str, float]:
        """
        Process with full parallelization (async)

        Args:
            audio_path: Path to audio file
            image_path: Path to image file
            output_path: Output video path
            **kwargs: Additional parameters

        Returns:
            Tuple of (output_path, process_time)
        """
        start_time = time.time()

        # Parallel preprocessing
        audio_data, image_data = await self.parallel_processor.preprocess_parallel_async(
            audio_path, image_path, kwargs.get('target_size', 320)
        )

        # Run inference (simplified for integration)
        # In real implementation, this would call SDK methods

        process_time = time.time() - start_time
        return output_path, process_time

    def process_parallel_sync(
        self,
        audio_path: str,
        image_path: str,
        output_path: str,
        **kwargs
    ) -> Tuple[str, float]:
        """
        Process with parallelization (sync)

        Args:
            audio_path: Path to audio file
            image_path: Path to image file
            output_path: Output video path
            **kwargs: Additional parameters

        Returns:
            Tuple of (output_path, process_time)
        """
        start_time = time.time()

        try:
            # Parallel preprocessing
            print("🔄 Starting parallel preprocessing...")
            preprocess_start = time.time()

            audio_data, image_data = self.parallel_processor.preprocess_parallel_sync(
                audio_path, image_path, kwargs.get('target_size', 320)
            )

            preprocess_time = time.time() - preprocess_start
            print(f"✅ Parallel preprocessing completed in {preprocess_time:.2f}s")

            # Run actual SDK inference
            # This integrates with the existing SDK
            from inference import run, seed_everything

            seed_everything(kwargs.get('seed', 1024))

            inference_start = time.time()
            run(self.sdk, audio_path, image_path, output_path, more_kwargs=kwargs.get('more_kwargs', {}))
            inference_time = time.time() - inference_start

            print(f"✅ Inference completed in {inference_time:.2f}s")

            total_time = time.time() - start_time

            # Performance breakdown
            print(f"""
            🎯 Performance Breakdown:
            - Preprocessing (parallel): {preprocess_time:.2f}s
            - Inference: {inference_time:.2f}s
            - Total: {total_time:.2f}s
            """)

            return output_path, total_time

        except Exception as e:
            print(f"❌ Error in parallel processing: {e}")
            raise

    def get_performance_stats(self) -> Dict[str, Any]:
        """Get performance statistics"""
        return {
            'num_threads': self.parallel_processor.num_threads,
            'num_processes': self.parallel_processor.num_processes,
            'cuda_streams_enabled': self.parallel_processor.use_cuda_streams
        }


class OptimizedInferenceWrapper:
    """
    Wrapper that combines all optimizations
    """

    def __init__(
        self,
        sdk,
        use_parallel: bool = True,
        use_cache: bool = True,
        use_gpu_opt: bool = True
    ):
        """
        Initialize optimized inference wrapper

        Args:
            sdk: StreamSDK instance
            use_parallel: Enable parallel processing
            use_cache: Enable caching
            use_gpu_opt: Enable GPU optimizations
        """
        self.sdk = sdk
        self.use_parallel = use_parallel
        self.use_cache = use_cache
        self.use_gpu_opt = use_gpu_opt

        # Initialize components
        if use_parallel:
            self.parallel_processor = ParallelProcessor(num_threads=4)
            self.parallel_inference = ParallelInference(sdk, self.parallel_processor)
        else:
            self.parallel_processor = None
            self.parallel_inference = None

    def process(
        self,
        audio_path: str,
        image_path: str,
        output_path: str,
        **kwargs
    ) -> Tuple[str, float, Dict[str, Any]]:
        """
        Process with all optimizations

        Returns:
            Tuple of (output_path, process_time, stats)
        """
        stats = {
            'parallel_enabled': self.use_parallel,
            'cache_enabled': self.use_cache,
            'gpu_opt_enabled': self.use_gpu_opt
        }

        if self.use_parallel and self.parallel_inference:
            output_path, process_time = self.parallel_inference.process_parallel_sync(
                audio_path, image_path, output_path, **kwargs
            )
            stats['preprocessing'] = 'parallel'
        else:
            # Fallback to sequential
            from inference import run, seed_everything
            start_time = time.time()
            seed_everything(kwargs.get('seed', 1024))
            run(self.sdk, audio_path, image_path, output_path, more_kwargs=kwargs.get('more_kwargs', {}))
            process_time = time.time() - start_time
            stats['preprocessing'] = 'sequential'

        stats['process_time'] = process_time

        return output_path, process_time, stats

    def shutdown(self):
        """Cleanup resources"""
        if self.parallel_processor:
            self.parallel_processor.shutdown()
core/optimization/parallel_processing.py
ADDED
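The module below supplies the thread/process pools and optional CUDA streams used above; audio and image preprocessing are submitted to the thread pool and awaited together. A short sketch of the synchronous entry point, using the repo's example files:

```python
# Sketch only: run both preprocessing steps concurrently on the thread pool.
from core.optimization.parallel_processing import ParallelProcessor

proc = ParallelProcessor(num_threads=4, num_processes=2)

audio_data, image_data = proc.preprocess_parallel_sync(
    "example/audio.wav", "example/image.png", target_size=320
)
print(audio_data["duration"], image_data["shape"])

proc.shutdown()  # release the executors when done
```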
@@ -0,0 +1,400 @@
"""
Parallel Processing Module for DittoTalkingHead
Implements concurrent audio and image preprocessing
"""

import asyncio
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
import time
from typing import Tuple, Dict, Any, Optional, Callable
import numpy as np
from pathlib import Path
import threading
import queue
import torch
from functools import partial


class ParallelProcessor:
    """
    Parallel processing for audio and image preprocessing
    """

    def __init__(
        self,
        num_threads: int = 4,
        num_processes: int = 2,
        use_cuda_streams: bool = True
    ):
        """
        Initialize parallel processor

        Args:
            num_threads: Number of threads for I/O operations
            num_processes: Number of processes for CPU-intensive tasks
            use_cuda_streams: Use CUDA streams for GPU operations
        """
        self.num_threads = num_threads
        self.num_processes = num_processes
        self.use_cuda_streams = use_cuda_streams and torch.cuda.is_available()

        # Thread pool for I/O operations
        self.thread_executor = ThreadPoolExecutor(max_workers=num_threads)

        # Process pool for CPU-intensive operations
        self.process_executor = ProcessPoolExecutor(max_workers=num_processes)

        # CUDA streams for GPU operations
        if self.use_cuda_streams:
            self.cuda_streams = [torch.cuda.Stream() for _ in range(2)]
        else:
            self.cuda_streams = None

        print(f"✅ ParallelProcessor initialized: {num_threads} threads, {num_processes} processes")
        if self.use_cuda_streams:
            print("✅ CUDA streams enabled for GPU parallelism")

    def preprocess_audio_parallel(self, audio_path: str) -> Dict[str, Any]:
        """
        Preprocess audio file in parallel

        Args:
            audio_path: Path to audio file

        Returns:
            Preprocessed audio data
        """
        import librosa

        # Define subtasks
        def load_audio():
            return librosa.load(audio_path, sr=16000)

        def extract_features(audio, sr):
            # Extract various audio features in parallel
            features = {}

            # MFCC features
            features['mfcc'] = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=13)

            # Spectral features
            features['spectral_centroid'] = librosa.feature.spectral_centroid(y=audio, sr=sr)
            features['spectral_rolloff'] = librosa.feature.spectral_rolloff(y=audio, sr=sr)

            return features

        # Load audio
        audio, sr = load_audio()

        # Extract features in parallel (if needed)
        features = extract_features(audio, sr)

        return {
            'audio': audio,
            'sample_rate': sr,
            'features': features,
            'duration': len(audio) / sr
        }

    def preprocess_image_parallel(self, image_path: str, target_size: int = 320) -> Dict[str, Any]:
        """
        Preprocess image file in parallel

        Args:
            image_path: Path to image file
            target_size: Target resolution

        Returns:
            Preprocessed image data
        """
        from PIL import Image
        import cv2

        # Define subtasks
        def load_and_resize():
            # Load image
            img = Image.open(image_path).convert('RGB')

            # Resize
            img = img.resize((target_size, target_size), Image.Resampling.LANCZOS)

            return np.array(img)

        def extract_face_landmarks(img_array):
            # Face detection and landmark extraction
            # Simplified version - in production, use MediaPipe or similar
            return {
                'has_face': True,
                'landmarks': None  # Placeholder
            }

        # Execute in parallel
        future_img = self.thread_executor.submit(load_and_resize)

        # Get results
        img_array = future_img.result()

        # Extract landmarks
        landmarks = extract_face_landmarks(img_array)

        return {
            'image': img_array,
            'shape': img_array.shape,
            'landmarks': landmarks
        }

    async def preprocess_parallel_async(
        self,
        audio_path: str,
        image_path: str,
        target_size: int = 320
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        Asynchronously preprocess audio and image in parallel

        Args:
            audio_path: Path to audio file
            image_path: Path to image file
            target_size: Target image resolution

        Returns:
            Tuple of (audio_data, image_data)
        """
        loop = asyncio.get_event_loop()

        # Create tasks for parallel execution
        audio_task = loop.run_in_executor(
            self.thread_executor,
            self.preprocess_audio_parallel,
            audio_path
        )

        image_task = loop.run_in_executor(
            self.thread_executor,
            partial(self.preprocess_image_parallel, target_size=target_size),
            image_path
        )

        # Wait for both tasks to complete
        audio_data, image_data = await asyncio.gather(audio_task, image_task)

        return audio_data, image_data

    def preprocess_parallel_sync(
        self,
        audio_path: str,
        image_path: str,
        target_size: int = 320
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        Synchronously preprocess audio and image in parallel

        Args:
            audio_path: Path to audio file
            image_path: Path to image file
            target_size: Target image resolution

        Returns:
            Tuple of (audio_data, image_data)
        """
        # Submit tasks to thread pool
        audio_future = self.thread_executor.submit(
            self.preprocess_audio_parallel,
            audio_path
        )

        image_future = self.thread_executor.submit(
            self.preprocess_image_parallel,
            image_path,
            target_size
        )

        # Wait for results
        audio_data = audio_future.result()
        image_data = image_future.result()

        return audio_data, image_data

    def process_gpu_parallel(
        self,
        audio_tensor: torch.Tensor,
        image_tensor: torch.Tensor,
        model_audio: torch.nn.Module,
        model_image: torch.nn.Module
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Process audio and image through models using CUDA streams

        Args:
            audio_tensor: Audio tensor
            image_tensor: Image tensor
            model_audio: Audio processing model
            model_image: Image processing model

        Returns:
            Tuple of processed tensors
        """
        if not self.use_cuda_streams:
            # Fallback to sequential processing
            audio_out = model_audio(audio_tensor)
            image_out = model_image(image_tensor)
            return audio_out, image_out

        # Use CUDA streams for parallel GPU processing
        with torch.cuda.stream(self.cuda_streams[0]):
            audio_out = model_audio(audio_tensor)

        with torch.cuda.stream(self.cuda_streams[1]):
            image_out = model_image(image_tensor)

        # Synchronize streams
        torch.cuda.synchronize()

        return audio_out, image_out

    def shutdown(self):
        """Shutdown executors"""
        self.thread_executor.shutdown(wait=True)
        self.process_executor.shutdown(wait=True)
        print("✅ ParallelProcessor shutdown complete")


class PipelineProcessor:
    """
    Pipeline-based processing for continuous operations
    """

    def __init__(self, stages: Dict[str, Callable], buffer_size: int = 10):
        """
        Initialize pipeline processor

        Args:
            stages: Dictionary of stage_name -> processing_function
            buffer_size: Size of queues between stages
        """
        self.stages = stages
        self.buffer_size = buffer_size

        # Create queues between stages
        self.queues = {}
        stage_names = list(stages.keys())
        for i in range(len(stage_names) - 1):
            queue_name = f"{stage_names[i]}_to_{stage_names[i+1]}"
            self.queues[queue_name] = queue.Queue(maxsize=buffer_size)

        # Input and output queues
        self.input_queue = queue.Queue(maxsize=buffer_size)
        self.output_queue = queue.Queue(maxsize=buffer_size)

        # Worker threads
        self.workers = []
        self.stop_event = threading.Event()

    def _worker(self, stage_name: str, process_func: Callable, input_q: queue.Queue, output_q: queue.Queue):
        """Worker thread for a pipeline stage"""
        while not self.stop_event.is_set():
            try:
                # Get input with timeout
                item = input_q.get(timeout=0.1)

                if item is None:  # Poison pill
                    output_q.put(None)
                    break

                # Process item
                result = process_func(item)

                # Put result
                output_q.put(result)

            except queue.Empty:
                continue
            except Exception as e:
                print(f"Error in stage {stage_name}: {e}")
                output_q.put(None)

    def start(self):
        """Start pipeline processing"""
        stage_names = list(self.stages.keys())

        # Create worker threads
        for i, (stage_name, process_func) in enumerate(self.stages.items()):
            # Determine input and output queues
            if i == 0:
                input_q = self.input_queue
            else:
                queue_name = f"{stage_names[i-1]}_to_{stage_names[i]}"
                input_q = self.queues[queue_name]

            if i == len(stage_names) - 1:
                output_q = self.output_queue
            else:
                queue_name = f"{stage_names[i]}_to_{stage_names[i+1]}"
                output_q = self.queues[queue_name]

            # Create and start worker
            worker = threading.Thread(
                target=self._worker,
                args=(stage_name, process_func, input_q, output_q)
            )
            worker.start()
            self.workers.append(worker)

        print(f"✅ Pipeline started with {len(self.workers)} stages")

    def process(self, item: Any) -> Any:
        """Process an item through the pipeline"""
        self.input_queue.put(item)
        return self.output_queue.get()

    def stop(self):
        """Stop pipeline processing"""
        self.stop_event.set()

        # Send poison pills
        self.input_queue.put(None)

        # Wait for workers
        for worker in self.workers:
            worker.join()

        print("✅ Pipeline stopped")


def benchmark_parallel_processing():
    """Benchmark parallel vs sequential processing"""
    import time

    print("\n=== Parallel Processing Benchmark ===")

    # Create processor
    processor = ParallelProcessor(num_threads=4)

    # Test files (using example files)
    audio_path = "example/audio.wav"
    image_path = "example/image.png"

    # Sequential processing
    start_seq = time.time()
    audio_data_seq = processor.preprocess_audio_parallel(audio_path)
    image_data_seq = processor.preprocess_image_parallel(image_path)
    time_seq = time.time() - start_seq

    # Parallel processing
    start_par = time.time()
    audio_data_par, image_data_par = processor.preprocess_parallel_sync(audio_path, image_path)
    time_par = time.time() - start_par

    # Results
    print(f"Sequential processing: {time_seq:.3f}s")
    print(f"Parallel processing: {time_par:.3f}s")
    print(f"Speedup: {time_seq/time_par:.2f}x")

    processor.shutdown()

    return {
        'sequential_time': time_seq,
        'parallel_time': time_par,
        'speedup': time_seq / time_par
    }