Spaces:
Paused
Paused
| """ | |
| Inference Cache System for DittoTalkingHead | |
| Caches video generation results for faster repeated processing | |
| """ | |
| import hashlib | |
| import json | |
| import os | |
| import pickle | |
| import time | |
| from pathlib import Path | |
| from typing import Optional, Dict, Any, Tuple, Union | |
| from functools import lru_cache | |
| import shutil | |
| from datetime import datetime, timedelta | |
class InferenceCache:
    """
    Cache system for video generation results.

    Two tiers: an in-memory LRU map (cache key -> cached video path) for
    fast repeated lookups, and a file-based cache with JSON metadata that
    enforces a time-to-live and a total-size budget.
    """

    def __init__(
        self,
        cache_dir: str = "/tmp/inference_cache",
        memory_cache_size: int = 100,
        file_cache_size_gb: float = 10.0,
        ttl_hours: int = 24
    ):
        """
        Initialize inference cache.

        Args:
            cache_dir: Directory for file-based cache
            memory_cache_size: Maximum number of items in memory cache
            file_cache_size_gb: Maximum size of file cache in GB
            ttl_hours: Time to live for cache entries in hours
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        self.memory_cache_size = memory_cache_size
        self.file_cache_size_bytes = int(file_cache_size_gb * 1024 * 1024 * 1024)
        self.ttl_seconds = ttl_hours * 3600

        # JSON file describing every on-disk cache entry.
        self.metadata_file = self.cache_dir / "cache_metadata.json"
        self.metadata = self._load_metadata()

        # In-memory cache plus per-key access times for LRU eviction.
        self._memory_cache = {}
        self._access_times = {}

        # Drop anything already past its TTL on startup.
        self._cleanup_expired()

    def _load_metadata(self) -> Dict[str, Any]:
        """Load cache metadata, returning {} if missing or unreadable."""
        if self.metadata_file.exists():
            try:
                with open(self.metadata_file, 'r') as f:
                    return json.load(f)
            except (OSError, json.JSONDecodeError):
                # Corrupt or unreadable metadata: start with an empty cache
                # instead of crashing (was a bare `except`, which also
                # swallowed KeyboardInterrupt/SystemExit).
                return {}
        return {}

    def _save_metadata(self):
        """Persist cache metadata to disk."""
        with open(self.metadata_file, 'w') as f:
            json.dump(self.metadata, f, indent=2)

    def generate_cache_key(
        self,
        audio_path: str,
        image_path: str,
        **kwargs
    ) -> str:
        """
        Generate unique cache key based on input parameters.

        Args:
            audio_path: Path to audio file
            image_path: Path to image file
            **kwargs: Additional parameters affecting output
                (``resolution``, ``steps``, ``seed``)

        Returns:
            SHA-256 hash as cache key
        """
        # Hash file contents (not paths) so renamed/moved copies still hit.
        with open(audio_path, 'rb') as f:
            audio_hash = hashlib.sha256(f.read()).hexdigest()
        with open(image_path, 'rb') as f:
            image_hash = hashlib.sha256(f.read()).hexdigest()

        # Only parameters that change the generated video belong here.
        key_data = {
            'audio': audio_hash,
            'image': image_hash,
            'resolution': kwargs.get('resolution', '320x320'),
            'steps': kwargs.get('steps', 25),
            'seed': kwargs.get('seed', None)
        }

        # sort_keys makes the serialization (and therefore the key) stable.
        key_str = json.dumps(key_data, sort_keys=True)
        return hashlib.sha256(key_str.encode()).hexdigest()

    def get_from_memory(self, cache_key: str) -> Optional[str]:
        """
        Get video path from memory cache.

        Args:
            cache_key: Cache key

        Returns:
            Video file path if found, None otherwise
        """
        if cache_key in self._memory_cache:
            # Refresh recency so this entry is not the next LRU victim.
            self._access_times[cache_key] = time.time()
            return self._memory_cache[cache_key]
        return None

    def get_from_file(self, cache_key: str) -> Optional[str]:
        """
        Get video path from file cache.

        Args:
            cache_key: Cache key

        Returns:
            Video file path if found, None otherwise
        """
        if cache_key not in self.metadata:
            return None

        entry = self.metadata[cache_key]

        # Expired entries are removed lazily on lookup.
        if time.time() > entry['expires_at']:
            self._remove_cache_entry(cache_key)
            return None

        # The file may have been deleted out from under the metadata.
        video_path = self.cache_dir / entry['filename']
        if not video_path.exists():
            self._remove_cache_entry(cache_key)
            return None

        # Update access time (used by the size-based eviction order).
        self.metadata[cache_key]['last_access'] = time.time()
        self._save_metadata()

        # Promote to the memory cache for faster subsequent lookups.
        self._add_to_memory_cache(cache_key, str(video_path))
        return str(video_path)

    def get(self, cache_key: str) -> Optional[str]:
        """
        Get video from cache (memory first, then file).

        Args:
            cache_key: Cache key

        Returns:
            Video file path if found, None otherwise
        """
        result = self.get_from_memory(cache_key)
        if result:
            return result
        return self.get_from_file(cache_key)

    def put(
        self,
        cache_key: str,
        video_path: str,
        **metadata
    ) -> bool:
        """
        Store video in cache.

        Args:
            cache_key: Cache key
            video_path: Path to generated video
            **metadata: Additional metadata to store

        Returns:
            True if stored successfully
        """
        try:
            # Replacing an existing entry: drop its old file first so the
            # previous .mp4 is not orphaned on disk (the old code only
            # overwrote the metadata, leaking the file).
            if cache_key in self.metadata:
                self._remove_cache_entry(cache_key)

            # Copy video into the cache directory under a unique name.
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            cache_filename = f"{cache_key[:8]}_{timestamp}.mp4"
            cache_video_path = self.cache_dir / cache_filename
            shutil.copy2(video_path, cache_video_path)

            # One timestamp for the whole record (was three time.time() calls).
            now = time.time()
            self.metadata[cache_key] = {
                'filename': cache_filename,
                'created_at': now,
                'expires_at': now + self.ttl_seconds,
                'last_access': now,
                'size_bytes': os.path.getsize(cache_video_path),
                'metadata': metadata
            }

            # Evict least-recently-used entries if over the size budget.
            self._check_cache_size()
            self._save_metadata()

            self._add_to_memory_cache(cache_key, str(cache_video_path))
            return True

        except Exception as e:
            # Best-effort: a caching failure must never break inference.
            print(f"Error storing cache: {e}")
            return False

    def _add_to_memory_cache(self, cache_key: str, video_path: str):
        """Add item to memory cache with LRU eviction."""
        # Only evict when inserting a genuinely new key; re-inserting an
        # existing key must not push out an unrelated entry (old code
        # evicted unconditionally once the cache was full).
        if (cache_key not in self._memory_cache
                and len(self._memory_cache) >= self.memory_cache_size):
            lru_key = min(self._access_times, key=self._access_times.get)
            del self._memory_cache[lru_key]
            del self._access_times[lru_key]

        self._memory_cache[cache_key] = video_path
        self._access_times[cache_key] = time.time()

    def _check_cache_size(self):
        """Evict least-recently-accessed entries until under the size limit."""
        total_size = sum(
            entry['size_bytes']
            for entry in self.metadata.values()
        )

        if total_size > self.file_cache_size_bytes:
            # Oldest access first == LRU eviction order.
            sorted_entries = sorted(
                self.metadata.items(),
                key=lambda x: x[1]['last_access']
            )
            while total_size > self.file_cache_size_bytes and sorted_entries:
                key_to_remove, entry = sorted_entries.pop(0)
                total_size -= entry['size_bytes']
                self._remove_cache_entry(key_to_remove)

    def _cleanup_expired(self):
        """Remove expired cache entries."""
        current_time = time.time()
        # Collect first: mutating self.metadata while iterating it would raise.
        expired_keys = [
            key for key, entry in self.metadata.items()
            if current_time > entry['expires_at']
        ]
        for key in expired_keys:
            self._remove_cache_entry(key)
        if expired_keys:
            print(f"Cleaned up {len(expired_keys)} expired cache entries")

    def _remove_cache_entry(self, cache_key: str):
        """Remove a cache entry from disk, metadata, and memory cache."""
        if cache_key in self.metadata:
            video_file = self.cache_dir / self.metadata[cache_key]['filename']
            if video_file.exists():
                video_file.unlink()
            del self.metadata[cache_key]

        if cache_key in self._memory_cache:
            del self._memory_cache[cache_key]
            del self._access_times[cache_key]

    def clear_cache(self):
        """Clear all cache entries (files, metadata, and memory cache)."""
        for file in self.cache_dir.glob("*.mp4"):
            file.unlink()

        self.metadata = {}
        self._save_metadata()

        self._memory_cache.clear()
        self._access_times.clear()

        print("Inference cache cleared")

    def get_cache_stats(self) -> Dict[str, Any]:
        """Get cache statistics (entry counts, sizes, limits, location)."""
        total_size = sum(
            entry['size_bytes']
            for entry in self.metadata.values()
        )
        memory_entries = len(self._memory_cache)
        file_entries = len(self.metadata)

        return {
            'memory_cache_entries': memory_entries,
            'file_cache_entries': file_entries,
            'total_cache_size_mb': total_size / (1024 * 1024),
            'cache_size_limit_gb': self.file_cache_size_bytes / (1024 * 1024 * 1024),
            'ttl_hours': self.ttl_seconds / 3600,
            'cache_directory': str(self.cache_dir)
        }
class CachedInference:
    """
    Thin wrapper that adds cache lookups around a video-generation call.
    """

    def __init__(self, cache: InferenceCache):
        """
        Initialize cached inference.

        Args:
            cache: InferenceCache instance
        """
        self.cache = cache

    def process_with_cache(
        self,
        inference_func: callable,
        audio_path: str,
        image_path: str,
        output_path: str,
        **kwargs
    ) -> Tuple[str, bool, float]:
        """
        Run inference, serving the result from cache when possible.

        Args:
            inference_func: Function to generate video
            audio_path: Path to audio file
            image_path: Path to image file
            output_path: Desired output path
            **kwargs: Additional parameters

        Returns:
            Tuple of (output_path, cache_hit, process_time)
        """
        started = time.time()

        key = self.cache.generate_cache_key(audio_path, image_path, **kwargs)
        hit_path = self.cache.get(key)

        if not hit_path:
            # Cache miss: run the real generator, then remember the result.
            print("Cache miss - generating video...")
            inference_func(audio_path, image_path, output_path, **kwargs)
            if os.path.exists(output_path):
                self.cache.put(key, output_path, **kwargs)
            return output_path, False, time.time() - started

        # Cache hit: just copy the stored video to the requested location.
        shutil.copy2(hit_path, output_path)
        elapsed = time.time() - started
        print(f"✅ Cache hit! Retrieved in {elapsed:.2f}s")
        return output_path, True, elapsed