#!/usr/bin/env python3
"""
OmniAvatar-14B Inference Script
Enhanced implementation for avatar video generation with adaptive body animation
"""
import os
import sys
import argparse
import yaml
import torch
import logging
import time
from pathlib import Path
from typing import Dict, Any

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def load_config(config_path: str) -> Dict[str, Any]:
    """Load configuration from YAML file"""
    try:
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)
        logger.info(f"Configuration loaded from {config_path}")
        return config
    except Exception as e:
        logger.error(f"Failed to load config: {e}")
        raise
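
# Expected configuration layout, inferred from the keys this script reads
# (model.*_path in validate_models, inference.* in main/mock_inference).
# The path values below are placeholders for illustration, not the actual
# model locations:
#
#   model:
#     base_model_path: ./pretrained_models/base        # placeholder path
#     omni_model_path: ./pretrained_models/omniavatar  # placeholder path
#     wav2vec_path: ./pretrained_models/wav2vec2       # placeholder path
#   inference:
#     output_dir: ./outputs
#     max_tokens: 30000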
def parse_input_file(input_file: str) -> list:
    """
    Parse the input file with format:
    [prompt]@@[img_path]@@[audio_path]
    """
    try:
        with open(input_file, 'r') as f:
            lines = f.readlines()

        samples = []
        for line_num, line in enumerate(lines, 1):
            line = line.strip()
            if not line or line.startswith('#'):
                continue

            parts = line.split('@@')
            if len(parts) != 3:
                logger.warning(f"Line {line_num} has invalid format, skipping: {line}")
                continue

            prompt, img_path, audio_path = parts

            # Validate paths
            if img_path and not os.path.exists(img_path):
                logger.warning(f"Image not found: {img_path}")
                img_path = None

            if not os.path.exists(audio_path):
                logger.error(f"Audio file not found: {audio_path}")
                continue

            samples.append({
                'prompt': prompt,
                'image_path': img_path if img_path else None,
                'audio_path': audio_path,
                'line_number': line_num
            })

        logger.info(f"Parsed {len(samples)} valid samples from {input_file}")
        return samples
    except Exception as e:
        logger.error(f"Failed to parse input file: {e}")
        raise
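
# Example input file accepted by parse_input_file, one sample per line in the
# [prompt]@@[img_path]@@[audio_path] format (leave img_path empty for
# audio-only driving). The file names below are illustrative placeholders:
#
#   A person giving an enthusiastic presentation@@examples/images/speaker.png@@examples/audio/speech.wav
#   A calm narrator facing the camera@@@@examples/audio/narration.wav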
def validate_models(config: Dict[str, Any]) -> bool:
    """Validate that all required models are available"""
    model_paths = [
        config['model']['base_model_path'],
        config['model']['omni_model_path'],
        config['model']['wav2vec_path']
    ]

    missing_models = []
    for path in model_paths:
        if not os.path.exists(path):
            missing_models.append(path)
        elif not any(Path(path).iterdir()):
            missing_models.append(f"{path} (empty directory)")

    if missing_models:
        logger.error("Missing required models:")
        for model in missing_models:
            logger.error(f"  - {model}")
        logger.info("Run 'python setup_omniavatar.py' to download models")
        return False

    logger.info("All required models found")
    return True
def setup_output_directory(output_dir: str) -> str:
    """Setup output directory and return path"""
    os.makedirs(output_dir, exist_ok=True)

    # Create unique subdirectory for this run
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    run_dir = os.path.join(output_dir, f"run_{timestamp}")
    os.makedirs(run_dir, exist_ok=True)

    logger.info(f"Output directory: {run_dir}")
    return run_dir
def mock_inference(sample: Dict[str, Any], config: Dict[str, Any],
                   output_dir: str, args: argparse.Namespace) -> str:
    """
    Mock inference implementation

    In a real implementation, this would:
    1. Load the OmniAvatar models
    2. Process the audio with wav2vec2
    3. Generate video frames using the text-to-video model
    4. Apply audio-driven animation
    5. Render final video
    """
    logger.info(f"Processing sample {sample['line_number']}")
    logger.info(f"Prompt: {sample['prompt']}")
    logger.info(f"Audio: {sample['audio_path']}")
    if sample['image_path']:
        logger.info(f"Image: {sample['image_path']}")

    # Configuration
    logger.info("Configuration:")
    logger.info(f"  - Guidance Scale: {args.guidance_scale}")
    logger.info(f"  - Audio Scale: {args.audio_scale}")
    logger.info(f"  - Steps: {args.num_steps}")
    logger.info(f"  - Max Tokens: {config.get('inference', {}).get('max_tokens', 30000)}")
    if args.tea_cache_l1_thresh:
        logger.info(f"  - TeaCache Threshold: {args.tea_cache_l1_thresh}")

    # Simulate processing time
    logger.info("Generating avatar video...")
    time.sleep(2)  # Mock processing

    # Create mock output file
    output_filename = f"avatar_sample_{sample['line_number']:03d}.mp4"
    output_path = os.path.join(output_dir, output_filename)

    # Create a simple text file as placeholder for the video
    with open(output_path.replace('.mp4', '_info.txt'), 'w') as f:
        f.write("OmniAvatar-14B Output Information\n")
        f.write(f"Generated: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Prompt: {sample['prompt']}\n")
        f.write(f"Audio: {sample['audio_path']}\n")
        f.write(f"Image: {sample['image_path'] or 'None'}\n")
        f.write(f"Configuration: {args.__dict__}\n")

    logger.info(f"Mock output created: {output_path}")
    return output_path
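
# --- Illustrative sketch (not part of the upstream OmniAvatar code) ----------
# The mock_inference docstring lists "process the audio with wav2vec2" as step
# 2 of a real pipeline. The helper below is one plausible way to obtain
# frame-level audio features with the Hugging Face `transformers` library; the
# function name and the use of librosa are assumptions for illustration, and
# the actual OmniAvatar integration may wire this step differently.
def extract_audio_features_sketch(audio_path: str, wav2vec_path: str) -> torch.Tensor:
    """Return frame-level wav2vec2 features for a mono 16 kHz audio file."""
    import librosa  # assumed dependency; any 16 kHz mono loader would do
    from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model

    waveform, _ = librosa.load(audio_path, sr=16000, mono=True)
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_path)
    model = Wav2Vec2Model.from_pretrained(wav2vec_path)
    model.eval()

    inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Shape: (1, num_audio_frames, hidden_size); a real pipeline would feed
    # these embeddings to the audio-driven animation stage.
    return outputs.last_hidden_state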
def main():
    parser = argparse.ArgumentParser(
        description="OmniAvatar-14B Inference - Avatar Video Generation with Adaptive Body Animation"
    )
    parser.add_argument("--config", type=str, required=True,
                        help="Configuration file path")
    parser.add_argument("--input_file", type=str, required=True,
                        help="Input samples file")
    parser.add_argument("--guidance_scale", type=float, default=4.5,
                        help="Guidance scale (4-6 recommended)")
    parser.add_argument("--audio_scale", type=float, default=3.0,
                        help="Audio scale for lip-sync consistency")
    parser.add_argument("--num_steps", type=int, default=25,
                        help="Number of inference steps (20-50 recommended)")
    parser.add_argument("--tea_cache_l1_thresh", type=float, default=None,
                        help="TeaCache L1 threshold (0.05-0.15 recommended)")
    parser.add_argument("--sp_size", type=int, default=1,
                        help="Sequence parallel size (number of GPUs)")
    parser.add_argument("--hp", type=str, default="",
                        help="Additional hyperparameters (comma-separated)")

    args = parser.parse_args()

    logger.info("OmniAvatar-14B Inference Starting")
    logger.info(f"Config: {args.config}")
    logger.info(f"Input: {args.input_file}")
    logger.info(f"Parameters: guidance_scale={args.guidance_scale}, audio_scale={args.audio_scale}, steps={args.num_steps}")

    try:
        # Load configuration
        config = load_config(args.config)

        # Validate models
        if not validate_models(config):
            return 1

        # Parse input samples
        samples = parse_input_file(args.input_file)
        if not samples:
            logger.error("No valid samples found in input file")
            return 1

        # Setup output directory
        output_dir = setup_output_directory(config.get('inference', {}).get('output_dir', './outputs'))

        # Process each sample
        total_samples = len(samples)
        successful_outputs = []

        for i, sample in enumerate(samples, 1):
            logger.info(f"Processing sample {i}/{total_samples}")
            try:
                output_path = mock_inference(sample, config, output_dir, args)
                successful_outputs.append(output_path)
            except Exception as e:
                logger.error(f"Failed to process sample {sample['line_number']}: {e}")
                continue

        # Summary
        logger.info("Inference completed!")
        logger.info(f"Successfully processed: {len(successful_outputs)}/{total_samples} samples")
        logger.info(f"Output directory: {output_dir}")

        if successful_outputs:
            logger.info("Generated videos:")
            for output in successful_outputs:
                logger.info(f"  - {output}")

        # Implementation note
        logger.info("NOTE: This is a mock implementation.")
        logger.info("For full OmniAvatar functionality, integrate with:")
        logger.info("    https://github.com/Omni-Avatar/OmniAvatar")

        return 0

    except Exception as e:
        logger.error(f"Inference failed: {e}")
        return 1


if __name__ == "__main__":
    sys.exit(main())
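
# Example invocation of this script (the script and file names below are
# illustrative placeholders; adjust them to your setup):
#
#   python inference.py --config configs/inference.yaml \
#       --input_file examples/infer_samples.txt \
#       --guidance_scale 5.0 --audio_scale 3.0 --num_steps 30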