#!/usr/bin/env python3
"""
OmniAvatar-14B Inference Script
Enhanced implementation for avatar video generation with adaptive body animation
"""

import os
import sys
import argparse
import yaml
import torch  # NOTE(review): not used by this mock; presumably needed once real inference is wired in
import logging
import time
from pathlib import Path
from typing import Dict, Any

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Keys under the ``model`` config section that must point at downloaded model files/dirs.
REQUIRED_MODEL_KEYS = ('base_model_path', 'omni_model_path', 'wav2vec_path')


def _config_section(config: Dict[str, Any], name: str) -> Dict[str, Any]:
    """Return ``config[name]`` as a dict, treating a missing or null section as empty."""
    section = config.get(name)
    return section if isinstance(section, dict) else {}


def load_config(config_path: str) -> Dict[str, Any]:
    """Load configuration from a YAML file.

    Args:
        config_path: Path to the YAML configuration file (read as UTF-8).

    Returns:
        The parsed top-level mapping.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
        ValueError: If the document's top level is not a mapping (for example an
            empty file, which ``safe_load`` returns as ``None``). Every consumer
            indexes the config by key, so failing here gives a clear message.
    """
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            # safe_load never constructs arbitrary Python objects from the file.
            config = yaml.safe_load(f)
        if not isinstance(config, dict):
            raise ValueError(
                f"expected a YAML mapping at the top level of {config_path}, "
                f"got {type(config).__name__}"
            )
        logger.info(f"✅ Configuration loaded from {config_path}")
        return config
    except Exception as e:
        logger.error(f"❌ Failed to load config: {e}")
        raise


def parse_input_file(input_file: str) -> list:
    """
    Parse the input file with format: [prompt]@@[img_path]@@[audio_path]

    Blank lines and lines starting with '#' are ignored. Whitespace around each
    '@@'-separated field is stripped. The image is optional: an empty field or a
    path that is not an existing file is recorded as None (with a warning). Lines
    with the wrong number of fields, or whose audio path is not an existing file,
    are logged and skipped.

    Args:
        input_file: Path to the samples file (read as UTF-8).

    Returns:
        A list of dicts with keys 'prompt', 'image_path' (str or None),
        'audio_path' and 'line_number' (1-based line within input_file).

    Raises:
        OSError: If input_file cannot be read.
    """
    try:
        samples = []
        with open(input_file, 'r', encoding='utf-8') as f:
            for line_num, raw_line in enumerate(f, 1):
                line = raw_line.strip()
                if not line or line.startswith('#'):
                    continue

                # Strip each field so "prompt @@ img @@ audio" yields usable paths.
                parts = [part.strip() for part in line.split('@@')]
                if len(parts) != 3:
                    logger.warning(f"⚠️ Line {line_num} has invalid format, skipping: {line}")
                    continue

                prompt, img_path, audio_path = parts

                # Validate paths: a bad image degrades to "no image"; a bad audio drops the sample.
                if img_path and not os.path.isfile(img_path):
                    logger.warning(f"⚠️ Image not found: {img_path}")
                    img_path = None

                if not os.path.isfile(audio_path):
                    logger.error(f"❌ Audio file not found: {audio_path}")
                    continue

                samples.append({
                    'prompt': prompt,
                    'image_path': img_path or None,
                    'audio_path': audio_path,
                    'line_number': line_num,
                })

        logger.info(f"📝 Parsed {len(samples)} valid samples from {input_file}")
        return samples
    except Exception as e:
        logger.error(f"❌ Failed to parse input file: {e}")
        raise


def validate_models(config: Dict[str, Any]) -> bool:
    """Validate that all required models are available.

    For every key in REQUIRED_MODEL_KEYS, ``config['model'][key]`` must be set and
    exist on disk. A directory must also be non-empty; a regular file counts as
    present. Missing entries are logged rather than raised.

    Args:
        config: Parsed configuration mapping.

    Returns:
        True if every required model is present, otherwise False.
    """
    model_config = _config_section(config, 'model')

    missing_models = []
    for key in REQUIRED_MODEL_KEYS:
        path = model_config.get(key)
        if not path:
            missing_models.append(f"model.{key} (not set in config)")
            continue
        model_path = Path(path)
        if not model_path.exists():
            missing_models.append(f"{path}")
        # Only directories can be "empty"; calling iterdir() on a file would raise.
        elif model_path.is_dir() and not any(model_path.iterdir()):
            missing_models.append(f"{path} (empty directory)")

    if missing_models:
        logger.error("❌ Missing required models:")
        for model in missing_models:
            logger.error(f" - {model}")
        logger.info("💡 Run 'python setup_omniavatar.py' to download models")
        return False

    logger.info("✅ All required models found")
    return True


def setup_output_directory(output_dir: str) -> str:
    """Create a fresh, uniquely named run directory inside ``output_dir``.

    The run directory is named ``run_<YYYYmmdd_HHMMSS>``. If that name is already
    taken (e.g. two runs started within the same second), a numeric suffix
    (``_1``, ``_2``, ...) is appended so runs never share a directory.

    Args:
        output_dir: Parent directory; created if it does not exist.

    Returns:
        Path of the newly created run directory.
    """
    os.makedirs(output_dir, exist_ok=True)

    timestamp = time.strftime("%Y%m%d_%H%M%S")
    base_dir = os.path.join(output_dir, f"run_{timestamp}")
    run_dir = base_dir
    suffix = 0
    while True:
        try:
            # exist_ok=False: creation itself is the uniqueness check.
            os.makedirs(run_dir)
            break
        except FileExistsError:
            suffix += 1
            run_dir = f"{base_dir}_{suffix}"

    logger.info(f"📁 Output directory: {run_dir}")
    return run_dir


def mock_inference(sample: Dict[str, Any], config: Dict[str, Any],
                   output_dir: str, args: argparse.Namespace) -> str:
    """
    Mock inference implementation

    In a real implementation, this would:
    1. Load the OmniAvatar models
    2. Process the audio with wav2vec2
    3. Generate video frames using the text-to-video model
    4. Apply audio-driven animation
    5. Render final video

    This mock only logs the settings, sleeps for 2 seconds, and writes a text
    file ``avatar_sample_NNN_info.txt`` describing the request.

    Args:
        sample: One entry produced by parse_input_file.
        config: Parsed configuration mapping.
        output_dir: Run directory to write into.
        args: Parsed command-line arguments.

    Returns:
        The ``avatar_sample_NNN.mp4`` path a real implementation would write.
        The mock does NOT create that file.
    """
    logger.info(f"🎬 Processing sample {sample['line_number']}")
    logger.info(f"📝 Prompt: {sample['prompt']}")
    logger.info(f"🎵 Audio: {sample['audio_path']}")
    if sample['image_path']:
        logger.info(f"🖼️ Image: {sample['image_path']}")

    # Configuration
    inference_config = _config_section(config, 'inference')
    logger.info("⚙️ Configuration:")
    logger.info(f" - Guidance Scale: {args.guidance_scale}")
    logger.info(f" - Audio Scale: {args.audio_scale}")
    logger.info(f" - Steps: {args.num_steps}")
    logger.info(f" - Max Tokens: {inference_config.get('max_tokens', 30000)}")
    # Compare against None so an explicit threshold of 0.0 is still reported.
    if args.tea_cache_l1_thresh is not None:
        logger.info(f" - TeaCache Threshold: {args.tea_cache_l1_thresh}")

    # Simulate processing time
    logger.info("🔄 Generating avatar video...")
    time.sleep(2)  # Mock processing

    output_filename = f"avatar_sample_{sample['line_number']:03d}.mp4"
    output_path = os.path.join(output_dir, output_filename)
    # Derive the placeholder name from the extension only; str.replace on the
    # whole path would also rewrite any '.mp4' occurring in output_dir.
    info_path = os.path.splitext(output_path)[0] + '_info.txt'

    # Create a simple text file as placeholder for the video
    with open(info_path, 'w', encoding='utf-8') as f:
        f.write("OmniAvatar-14B Output Information\n")
        f.write(f"Generated: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Prompt: {sample['prompt']}\n")
        f.write(f"Audio: {sample['audio_path']}\n")
        f.write(f"Image: {sample['image_path'] or 'None'}\n")
        f.write(f"Configuration: {vars(args)}\n")

    logger.info(f"✅ Mock output created: {info_path} (placeholder for {output_path})")
    return output_path


def main():
    """Command-line entry point.

    Returns:
        Process exit code: 0 if at least one sample was processed, 1 if the
        config, models or input are invalid, or if every sample failed.
    """
    parser = argparse.ArgumentParser(
        description="OmniAvatar-14B Inference - Avatar Video Generation with Adaptive Body Animation"
    )
    parser.add_argument("--config", type=str, required=True, help="Configuration file path")
    parser.add_argument("--input_file", type=str, required=True, help="Input samples file")
    parser.add_argument("--guidance_scale", type=float, default=4.5, help="Guidance scale (4-6 recommended)")
    parser.add_argument("--audio_scale", type=float, default=3.0, help="Audio scale for lip-sync consistency")
    parser.add_argument("--num_steps", type=int, default=25, help="Number of inference steps (20-50 recommended)")
    parser.add_argument("--tea_cache_l1_thresh", type=float, default=None, help="TeaCache L1 threshold (0.05-0.15 recommended)")
    parser.add_argument("--sp_size", type=int, default=1, help="Sequence parallel size (number of GPUs)")
    parser.add_argument("--hp", type=str, default="", help="Additional hyperparameters (comma-separated)")

    args = parser.parse_args()

    logger.info("🚀 OmniAvatar-14B Inference Starting")
    logger.info(f"📄 Config: {args.config}")
    logger.info(f"📝 Input: {args.input_file}")
    logger.info(f"🎯 Parameters: guidance_scale={args.guidance_scale}, audio_scale={args.audio_scale}, steps={args.num_steps}")

    try:
        # Load configuration
        config = load_config(args.config)

        # Validate models
        if not validate_models(config):
            return 1

        # Parse input samples
        samples = parse_input_file(args.input_file)
        if not samples:
            logger.error("❌ No valid samples found in input file")
            return 1

        # Setup output directory; a missing or null setting falls back to ./outputs.
        inference_config = _config_section(config, 'inference')
        output_dir = setup_output_directory(inference_config.get('output_dir') or './outputs')

        # Process each sample; one failure must not abort the batch.
        total_samples = len(samples)
        successful_outputs = []

        for i, sample in enumerate(samples, 1):
            logger.info(f"📊 Processing sample {i}/{total_samples}")
            try:
                output_path = mock_inference(sample, config, output_dir, args)
                successful_outputs.append(output_path)
            except Exception as e:
                logger.error(f"❌ Failed to process sample {sample['line_number']}: {e}")
                continue

        # Summary
        logger.info("🎉 Inference completed!")
        logger.info(f"✅ Successfully processed: {len(successful_outputs)}/{total_samples} samples")
        logger.info(f"📁 Output directory: {output_dir}")

        if successful_outputs:
            logger.info("📹 Generated videos:")
            for output in successful_outputs:
                logger.info(f" - {output}")

        # Implementation note
        logger.info("💡 NOTE: This is a mock implementation.")
        logger.info("🔗 For full OmniAvatar functionality, integrate with:")
        logger.info(" https://github.com/Omni-Avatar/OmniAvatar")

        # Signal failure to callers/CI when nothing at all was produced.
        return 0 if successful_outputs else 1

    except Exception as e:
        # Top-level boundary: keep the traceback for diagnosis.
        logger.exception(f"❌ Inference failed: {e}")
        return 1


if __name__ == "__main__":
    sys.exit(main())