""" | |
AI Avatar Chat - HF Spaces Optimized Version | |
BUILD: 2025-01-08_00-44-FORCE-REBUILD - With Model Download Controls | |
FEATURES: Real video generation, model download UI, storage optimization | |
""" | |
import os

# STORAGE OPTIMIZATION: Check if running on HF Spaces and disable model downloads
IS_HF_SPACE = any([
    os.getenv("SPACE_ID"),
    os.getenv("SPACE_AUTHOR_NAME"),
    os.getenv("SPACES_BUILDKIT_VERSION"),
    "/home/user/app" in os.getcwd()
])
if IS_HF_SPACE:
    # Force TTS-only mode to prevent exceeding the storage limit
    # (both lines below stay commented out so video generation remains ENABLED):
    # os.environ["DISABLE_MODEL_DOWNLOAD"] = "1"
    # os.environ["TTS_ONLY_MODE"] = "1"
    os.environ["HF_SPACE_STORAGE_OPTIMIZED"] = "1"
    print("STORAGE OPTIMIZATION: Detected HF Space environment")
    print("Video generation ENABLED (models need manual download)")
    print("WARNING: Use the /download-models endpoint to download ~30GB of models first")
import torch
import tempfile
import gradio as gr
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, HttpUrl
import subprocess
import json
from pathlib import Path
import logging
import requests
from urllib.parse import urlparse
from PIL import Image
import io
from typing import Optional
import aiohttp
import asyncio
# Safe dotenv import
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    print("Warning: python-dotenv not found, continuing without .env support")
    def load_dotenv():
        pass
# CRITICAL: HF Spaces compatibility fix
try:
    from hf_spaces_fix import setup_hf_spaces_environment, HFSpacesCompatible
    setup_hf_spaces_environment()
except ImportError:
    print("Warning: HF Spaces fix not available")
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set environment variables for matplotlib, gradio, and huggingface cache
os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
os.environ['GRADIO_ALLOW_FLAGGING'] = 'never'
os.environ['HF_HOME'] = '/tmp/huggingface'  # Use HF_HOME instead of deprecated TRANSFORMERS_CACHE
os.environ['HF_DATASETS_CACHE'] = '/tmp/huggingface/datasets'
os.environ['HUGGINGFACE_HUB_CACHE'] = '/tmp/huggingface/hub'
# FastAPI app will be created after lifespan is defined

# Create directories with proper permissions
os.makedirs("outputs", exist_ok=True)
os.makedirs("/tmp/matplotlib", exist_ok=True)
os.makedirs("/tmp/huggingface", exist_ok=True)
os.makedirs("/tmp/huggingface/transformers", exist_ok=True)
os.makedirs("/tmp/huggingface/datasets", exist_ok=True)
os.makedirs("/tmp/huggingface/hub", exist_ok=True)
# Helper: convert a local output file path to a publicly accessible URL
def get_video_url(output_path: str) -> str:
    """Convert local file path to accessible URL"""
    try:
        filename = Path(output_path).name
        # For HuggingFace Spaces, construct the URL from the Space's public hostname
        base_url = "https://bravedims-ai-avatar-chat.hf.space"
        video_url = f"{base_url}/outputs/{filename}"
        logger.info(f"Generated video URL: {video_url}")
        return video_url
    except Exception as e:
        logger.error(f"Error creating video URL: {e}")
        return output_path  # Fallback to original path
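
# Illustrative example (not executed): given "./outputs/avatar_001.mp4", the helper
# above returns "https://bravedims-ai-avatar-chat.hf.space/outputs/avatar_001.mp4".
# This works because "./outputs" is mounted at the "/outputs" route further below.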
# Pydantic models for request/response
class GenerateRequest(BaseModel):
    prompt: str
    text_to_speech: Optional[str] = None  # Text to convert to speech
    audio_url: Optional[HttpUrl] = None  # Direct audio URL
    voice_id: Optional[str] = "21m00Tcm4TlvDq8ikWAM"  # Voice profile ID
    image_url: Optional[HttpUrl] = None
    guidance_scale: float = 5.0
    audio_scale: float = 3.0
    num_steps: int = 30
    sp_size: int = 1
    tea_cache_l1_thresh: Optional[float] = None

class GenerateResponse(BaseModel):
    message: str
    output_path: str
    processing_time: float
    audio_generated: bool = False
    tts_method: Optional[str] = None
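
# Illustrative JSON body matching GenerateRequest (example values only; any field
# with a default can be omitted):
# {
#     "prompt": "A friendly person explaining a concept",
#     "text_to_speech": "Hello! Welcome to the demo.",
#     "voice_id": "21m00Tcm4TlvDq8ikWAM",
#     "guidance_scale": 5.0,
#     "audio_scale": 3.0,
#     "num_steps": 30
# }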
# Try to import TTS clients, but make them optional
try:
    from advanced_tts_client import AdvancedTTSClient
    ADVANCED_TTS_AVAILABLE = True
    logger.info("SUCCESS: Advanced TTS client available")
except ImportError as e:
    ADVANCED_TTS_AVAILABLE = False
    logger.warning(f"WARNING: Advanced TTS client not available: {e}")

# Always import the robust fallback
try:
    from robust_tts_client import RobustTTSClient
    ROBUST_TTS_AVAILABLE = True
    logger.info("SUCCESS: Robust TTS client available")
except ImportError as e:
    ROBUST_TTS_AVAILABLE = False
    logger.error(f"ERROR: Robust TTS client not available: {e}")
class TTSManager:
    """Manages multiple TTS clients with fallback chain"""

    def __init__(self):
        # Initialize TTS clients based on availability
        self.advanced_tts = None
        self.robust_tts = None
        self.clients_loaded = False

        if ADVANCED_TTS_AVAILABLE:
            try:
                self.advanced_tts = AdvancedTTSClient()
                logger.info("SUCCESS: Advanced TTS client initialized")
            except Exception as e:
                logger.warning(f"WARNING: Advanced TTS client initialization failed: {e}")

        if ROBUST_TTS_AVAILABLE:
            try:
                self.robust_tts = RobustTTSClient()
                logger.info("SUCCESS: Robust TTS client initialized")
            except Exception as e:
                logger.error(f"ERROR: Robust TTS client initialization failed: {e}")

        if not self.advanced_tts and not self.robust_tts:
            logger.error("ERROR: No TTS clients available!")
    async def load_models(self):
        """Load TTS models"""
        try:
            logger.info("Loading TTS models...")

            # Try to load advanced TTS first
            if self.advanced_tts:
                try:
                    logger.info("[PROCESS] Loading advanced TTS models (this may take a few minutes)...")
                    success = await self.advanced_tts.load_models()
                    if success:
                        logger.info("SUCCESS: Advanced TTS models loaded successfully")
                    else:
                        logger.warning("WARNING: Advanced TTS models failed to load")
                except Exception as e:
                    logger.warning(f"WARNING: Advanced TTS loading error: {e}")

            # Always ensure robust TTS is available
            if self.robust_tts:
                try:
                    await self.robust_tts.load_model()
                    logger.info("SUCCESS: Robust TTS fallback ready")
                except Exception as e:
                    logger.error(f"ERROR: Robust TTS loading failed: {e}")

            self.clients_loaded = True
            return True
        except Exception as e:
            logger.error(f"ERROR: TTS manager initialization failed: {e}")
            return False
    async def text_to_speech(self, text: str, voice_id: Optional[str] = None) -> tuple[str, str]:
        """
        Convert text to speech with fallback chain.
        Returns: (audio_file_path, method_used)
        """
        if not self.clients_loaded:
            logger.info("TTS models not loaded, loading now...")
            await self.load_models()

        logger.info(f"Generating speech: {text[:50]}...")
        logger.info(f"Voice ID: {voice_id}")

        # Try Advanced TTS first (Facebook VITS / SpeechT5)
        if self.advanced_tts:
            try:
                audio_path = await self.advanced_tts.text_to_speech(text, voice_id)
                return audio_path, "Facebook VITS/SpeechT5"
            except Exception as advanced_error:
                logger.warning(f"Advanced TTS failed: {advanced_error}")

        # Fall back to robust TTS
        if self.robust_tts:
            try:
                logger.info("Falling back to robust TTS...")
                audio_path = await self.robust_tts.text_to_speech(text, voice_id)
                return audio_path, "Robust TTS (Fallback)"
            except Exception as robust_error:
                logger.error(f"Robust TTS also failed: {robust_error}")

        # If we get here, all methods failed
        logger.error("All TTS methods failed!")
        raise HTTPException(
            status_code=500,
            detail="All TTS methods failed. Please check system configuration."
        )
    async def get_available_voices(self):
        """Get available voice configurations"""
        try:
            if self.advanced_tts and hasattr(self.advanced_tts, 'get_available_voices'):
                return await self.advanced_tts.get_available_voices()
        except Exception:
            pass

        # Return default voices if advanced TTS is not available
        return {
            "21m00Tcm4TlvDq8ikWAM": "Female (Neutral)",
            "pNInz6obpgDQGcFmaJgB": "Male (Professional)",
            "EXAVITQu4vr4xnSDxMaL": "Female (Sweet)",
            "ErXwobaYiN019PkySvjV": "Male (Professional)",
            "TxGEqnHWrfGW9XjX": "Male (Deep)",
            "yoZ06aMxZJJ28mfd3POQ": "Unisex (Friendly)",
            "AZnzlk1XvdvUeBnXmlld": "Female (Strong)"
        }
    def get_tts_info(self):
        """Get TTS system information"""
        info = {
            "clients_loaded": self.clients_loaded,
            "advanced_tts_available": self.advanced_tts is not None,
            "robust_tts_available": self.robust_tts is not None,
            "primary_method": "Robust TTS"
        }
        try:
            if self.advanced_tts and hasattr(self.advanced_tts, 'get_model_info'):
                advanced_info = self.advanced_tts.get_model_info()
                info.update({
                    "advanced_tts_loaded": advanced_info.get("models_loaded", False),
                    "transformers_available": advanced_info.get("transformers_available", False),
                    "primary_method": "Facebook VITS/SpeechT5" if advanced_info.get("models_loaded") else "Robust TTS",
                    "device": advanced_info.get("device", "cpu"),
                    "vits_available": advanced_info.get("vits_available", False),
                    "speecht5_available": advanced_info.get("speecht5_available", False)
                })
        except Exception as e:
            logger.debug(f"Could not get advanced TTS info: {e}")
        return info
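
# Illustrative usage sketch of the fallback chain above (not executed here; the
# call must run inside async code, e.g. an endpoint handler):
#     tts = TTSManager()
#     audio_path, method = await tts.text_to_speech("Hello!", "21m00Tcm4TlvDq8ikWAM")
#     # 'method' reports whether the advanced client or the robust fallback produced the audio.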
# Import the VIDEO-FOCUSED engine
try:
    from omniavatar_video_engine import video_engine
    VIDEO_ENGINE_AVAILABLE = True
    logger.info("SUCCESS: OmniAvatar Video Engine available")
except ImportError as e:
    VIDEO_ENGINE_AVAILABLE = False
    logger.error(f"ERROR: OmniAvatar Video Engine not available: {e}")
class OmniAvatarAPI:
    def __init__(self):
        self.model_loaded = False
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tts_manager = TTSManager()
        logger.info(f"Using device: {self.device}")
        logger.info("Initialized with robust TTS system")
    def load_model(self):
        """Check for OmniAvatar models - the app can still run without them (TTS-only mode)"""
        try:
            # Check the downloaded model paths (models are optional)
            downloaded_video = "./downloaded_models/video"
            downloaded_audio = "./downloaded_models/audio"

            def _file_count(path: str) -> int:
                if not os.path.isdir(path):
                    return 0
                return len([f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f))])

            video_files = _file_count(downloaded_video)
            audio_files = _file_count(downloaded_audio)

            if video_files > 5 and audio_files > 5:
                self.model_loaded = True
                logger.info("SUCCESS: All OmniAvatar models found - full functionality enabled")
                return True

            # Otherwise, report what is missing and fall back to TTS-only mode
            missing_models = [path for path, count in
                              [(downloaded_video, video_files), (downloaded_audio, audio_files)]
                              if count <= 5]
            logger.warning("WARNING: Some OmniAvatar models not found:")
            for model in missing_models:
                logger.warning(f"  - {model}")
            logger.info("TIP: App will run in TTS-only mode (no video generation)")
            logger.info("TIP: To enable full avatar generation, download the required models")
            self.model_loaded = False  # Video generation disabled
            return True  # But the app can still run
        except Exception as e:
            logger.error(f"Error checking models: {str(e)}")
            logger.info("TIP: Continuing in TTS-only mode")
            self.model_loaded = False
            return True  # Continue running
    async def download_file(self, url: str, suffix: str = "") -> str:
        """Download file from URL and save it to a temporary location"""
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(str(url)) as response:
                    if response.status != 200:
                        raise HTTPException(status_code=400, detail=f"Failed to download file from URL: {url}")
                    content = await response.read()

            # Create temporary file
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
            temp_file.write(content)
            temp_file.close()
            return temp_file.name
        except HTTPException:
            raise  # don't let the 400 above be re-wrapped as a 500
        except aiohttp.ClientError as e:
            logger.error(f"Network error downloading {url}: {e}")
            raise HTTPException(status_code=400, detail=f"Network error downloading file: {e}")
        except Exception as e:
            logger.error(f"Error downloading file from {url}: {e}")
            raise HTTPException(status_code=500, detail=f"Error downloading file: {e}")
    def validate_audio_url(self, url: str) -> bool:
        """Validate if URL is likely an audio file"""
        try:
            parsed = urlparse(url)
            # Check for common audio file extensions
            audio_extensions = ['.mp3', '.wav', '.m4a', '.ogg', '.aac', '.flac']
            is_audio_ext = any(parsed.path.lower().endswith(ext) for ext in audio_extensions)
            return is_audio_ext or 'audio' in url.lower()
        except Exception:
            return False

    def validate_image_url(self, url: str) -> bool:
        """Validate if URL is likely an image file"""
        try:
            parsed = urlparse(url)
            image_extensions = ['.jpg', '.jpeg', '.png', '.webp', '.bmp', '.gif']
            return any(parsed.path.lower().endswith(ext) for ext in image_extensions)
        except Exception:
            return False
    async def generate_avatar(self, request: GenerateRequest) -> tuple[str, float, bool, str]:
        """Generate avatar VIDEO - PRIMARY FUNCTIONALITY"""
        import time
        start_time = time.time()
        audio_generated = False
        method_used = "Unknown"

        logger.info("[VIDEO] STARTING AVATAR VIDEO GENERATION")
        logger.info(f"[INFO] Prompt: {request.prompt}")

        if VIDEO_ENGINE_AVAILABLE:
            try:
                # PRIORITIZE VIDEO GENERATION
                logger.info("[TARGET] Using OmniAvatar Video Engine for FULL video generation")

                # Handle audio source
                audio_path = None
                if request.text_to_speech:
                    logger.info("[MIC] Generating audio from text...")
                    audio_path, method_used = await self.tts_manager.text_to_speech(
                        request.text_to_speech,
                        request.voice_id or "21m00Tcm4TlvDq8ikWAM"
                    )
                    audio_generated = True
                elif request.audio_url:
                    logger.info("[DOWNLOAD] Downloading audio from URL...")
                    audio_path = await self.download_file(str(request.audio_url), ".mp3")
                    method_used = "External Audio"
                else:
                    raise HTTPException(status_code=400, detail="Either text_to_speech or audio_url is required for video generation")

                # Handle image if provided
                image_path = None
                if request.image_url:
                    logger.info("[IMAGE] Downloading reference image...")
                    parsed = urlparse(str(request.image_url))
                    ext = os.path.splitext(parsed.path)[1] or ".jpg"
                    image_path = await self.download_file(str(request.image_url), ext)

                # GENERATE VIDEO using the OmniAvatar engine
                logger.info("[VIDEO] Generating avatar video with adaptive body animation...")
                video_path, generation_time = video_engine.generate_avatar_video(
                    prompt=request.prompt,
                    audio_path=audio_path,
                    image_path=image_path,
                    guidance_scale=request.guidance_scale,
                    audio_scale=request.audio_scale,
                    num_steps=request.num_steps
                )
                processing_time = time.time() - start_time
                logger.info(f"SUCCESS: VIDEO GENERATED successfully in {processing_time:.1f}s")

                # Clean up temporary files
                if audio_path and os.path.exists(audio_path):
                    os.unlink(audio_path)
                if image_path and os.path.exists(image_path):
                    os.unlink(image_path)

                return video_path, processing_time, audio_generated, f"OmniAvatar Video Generation ({method_used})"
            except HTTPException:
                raise  # propagate client errors (e.g. missing audio source) unchanged
            except Exception as e:
                logger.error(f"ERROR: Video generation failed: {e}")
                # For a VIDEO generation app, we should NOT fall back to audio-only.
                # Instead, provide clear guidance.
                if "models" in str(e).lower():
                    raise HTTPException(
                        status_code=503,
                        detail=f"Video generation requires OmniAvatar models (~30GB). Please run the model download script. Error: {str(e)}"
                    )
                else:
                    raise HTTPException(status_code=500, detail=f"Video generation failed: {str(e)}")

        # If the video engine is not available, that is a critical error for a VIDEO app
        raise HTTPException(
            status_code=503,
            detail="Video generation engine not available. This application requires OmniAvatar models for video generation."
        )
    async def generate_avatar_BACKUP(self, request: GenerateRequest) -> tuple[str, float, bool, str]:
        """OLD TTS-ONLY METHOD - kept as backup reference.
        Generate avatar video from prompt and audio/text - handles missing models."""
        import time
        start_time = time.time()
        audio_generated = False
        tts_method = None

        try:
            # Check if video generation is available
            if not self.model_loaded:
                logger.info("[TTS] Running in TTS-only mode (OmniAvatar models not available)")

                # Only generate audio, no video
                if request.text_to_speech:
                    logger.info(f"Generating speech from text: {request.text_to_speech[:50]}...")
                    audio_path, tts_method = await self.tts_manager.text_to_speech(
                        request.text_to_speech,
                        request.voice_id or "21m00Tcm4TlvDq8ikWAM"
                    )
                    # Return the audio file as the "output"
                    processing_time = time.time() - start_time
                    logger.info(f"SUCCESS: TTS completed in {processing_time:.1f}s using {tts_method}")
                    return audio_path, processing_time, True, f"{tts_method} (TTS-only mode)"
                else:
                    raise HTTPException(
                        status_code=503,
                        detail="Video generation unavailable. OmniAvatar models not found. Only TTS from text is supported."
                    )
            # Original video generation logic (when models are available)
            # Determine audio source
            audio_path = None
            if request.text_to_speech:
                # Generate speech from text using the TTS manager
                logger.info(f"Generating speech from text: {request.text_to_speech[:50]}...")
                audio_path, tts_method = await self.tts_manager.text_to_speech(
                    request.text_to_speech,
                    request.voice_id or "21m00Tcm4TlvDq8ikWAM"
                )
                audio_generated = True
            elif request.audio_url:
                # Download audio from the provided URL
                logger.info(f"Downloading audio from URL: {request.audio_url}")
                if not self.validate_audio_url(str(request.audio_url)):
                    logger.warning(f"Audio URL may not be valid: {request.audio_url}")
                audio_path = await self.download_file(str(request.audio_url), ".mp3")
                tts_method = "External Audio URL"
            else:
                raise HTTPException(
                    status_code=400,
                    detail="Either text_to_speech or audio_url must be provided"
                )

            # Download image if provided
            image_path = None
            if request.image_url:
                logger.info(f"Downloading image from URL: {request.image_url}")
                if not self.validate_image_url(str(request.image_url)):
                    logger.warning(f"Image URL may not be valid: {request.image_url}")
                # Determine the image extension from the URL, defaulting to .jpg
                parsed = urlparse(str(request.image_url))
                ext = os.path.splitext(parsed.path)[1] or ".jpg"
                image_path = await self.download_file(str(request.image_url), ext)

            # Create temporary input file for inference
            with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
                if image_path:
                    input_line = f"{request.prompt}@@{image_path}@@{audio_path}"
                else:
                    input_line = f"{request.prompt}@@@@{audio_path}"
                f.write(input_line)
                temp_input_file = f.name
            # Prepare inference command
            cmd = [
                "python", "-m", "torch.distributed.run",
                "--standalone", f"--nproc_per_node={request.sp_size}",
                "scripts/inference.py",
                "--config", "configs/inference.yaml",
                "--input_file", temp_input_file,
                "--guidance_scale", str(request.guidance_scale),
                "--audio_scale", str(request.audio_scale),
                "--num_steps", str(request.num_steps)
            ]
            if request.tea_cache_l1_thresh is not None:  # explicit None check: 0.0 is a valid threshold
                cmd.extend(["--tea_cache_l1_thresh", str(request.tea_cache_l1_thresh)])

            logger.info(f"Running inference with command: {' '.join(cmd)}")
            # Run inference
            result = subprocess.run(cmd, capture_output=True, text=True)

            # Clean up temporary files
            os.unlink(temp_input_file)
            os.unlink(audio_path)
            if image_path:
                os.unlink(image_path)

            if result.returncode != 0:
                logger.error(f"Inference failed: {result.stderr}")
                raise Exception(f"Inference failed: {result.stderr}")

            # Find the output video file
            output_dir = "./outputs"
            if os.path.exists(output_dir):
                video_files = [f for f in os.listdir(output_dir) if f.endswith(('.mp4', '.avi'))]
                if video_files:
                    # Return the most recent video file
                    video_files.sort(key=lambda x: os.path.getmtime(os.path.join(output_dir, x)), reverse=True)
                    output_path = os.path.join(output_dir, video_files[0])
                    processing_time = time.time() - start_time
                    return output_path, processing_time, audio_generated, tts_method

            raise Exception("No output video generated")
        except Exception as e:
            # Clean up any temporary files in case of error
            try:
                if 'audio_path' in locals() and audio_path and os.path.exists(audio_path):
                    os.unlink(audio_path)
                if 'image_path' in locals() and image_path and os.path.exists(image_path):
                    os.unlink(image_path)
                if 'temp_input_file' in locals() and os.path.exists(temp_input_file):
                    os.unlink(temp_input_file)
            except Exception:
                pass
            logger.error(f"Generation error: {str(e)}")
            raise HTTPException(status_code=500, detail=str(e))
# Initialize API
omni_api = OmniAvatarAPI()

# Use FastAPI lifespan instead of deprecated on_event
from contextlib import asynccontextmanager

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup
    success = omni_api.load_model()
    if not success:
        logger.warning("WARNING: OmniAvatar model loading failed - running in limited mode")

    # Load TTS models
    try:
        await omni_api.tts_manager.load_models()
        logger.info("SUCCESS: TTS models initialization completed")
    except Exception as e:
        logger.error(f"ERROR: TTS initialization failed: {e}")

    yield

    # Shutdown (if needed)
    logger.info("Application shutting down...")
# Create FastAPI app WITH the lifespan parameter
app = FastAPI(
    title="OmniAvatar-14B API with Advanced TTS",
    version="1.0.0",
    lifespan=lifespan
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount static files for serving generated videos
app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs")
@app.get("/health")  # route decorator restored; the "/health" path is an assumed convention
async def health_check():
    """Health check endpoint"""
    tts_info = omni_api.tts_manager.get_tts_info()
    return {
        "status": "healthy",
        "model_loaded": omni_api.model_loaded,
        "video_generation_available": omni_api.model_loaded,
        "tts_only_mode": not omni_api.model_loaded,
        "device": omni_api.device,
        "supports_text_to_speech": True,
        "supports_image_urls": omni_api.model_loaded,
        "supports_audio_urls": omni_api.model_loaded,
        "tts_system": "Advanced TTS with Robust Fallback",
        "advanced_tts_available": ADVANCED_TTS_AVAILABLE,
        "robust_tts_available": ROBUST_TTS_AVAILABLE,
        **tts_info
    }
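
# Illustrative check (assuming the "/health" path above):
#     curl https://bravedims-ai-avatar-chat.hf.space/health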
@app.get("/voices")  # route decorator restored; path assumed
async def get_voices():
    """Get available voice configurations"""
    try:
        voices = await omni_api.tts_manager.get_available_voices()
        return {"voices": voices}
    except Exception as e:
        logger.error(f"Error getting voices: {e}")
        return {"error": str(e)}
@app.post("/generate", response_model=GenerateResponse)  # route decorator restored; path assumed
async def generate_avatar(request: GenerateRequest):
    """Generate avatar video from prompt, text/audio, and optional image URL"""
    logger.info(f"Generating avatar with prompt: {request.prompt}")
    if request.text_to_speech:
        logger.info(f"Text to speech: {request.text_to_speech[:100]}...")
        logger.info(f"Voice ID: {request.voice_id}")
    if request.audio_url:
        logger.info(f"Audio URL: {request.audio_url}")
    if request.image_url:
        logger.info(f"Image URL: {request.image_url}")

    try:
        output_path, processing_time, audio_generated, tts_method = await omni_api.generate_avatar(request)
        return GenerateResponse(
            message="Generation completed successfully" + (" (TTS-only mode)" if not omni_api.model_loaded else ""),
            output_path=get_video_url(output_path) if omni_api.model_loaded else output_path,
            processing_time=processing_time,
            audio_generated=audio_generated,
            tts_method=tts_method
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=f"Unexpected error: {e}")
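
# Illustrative request (assuming the "/generate" path above; the body fields
# match GenerateRequest):
#     curl -X POST https://bravedims-ai-avatar-chat.hf.space/generate \
#          -H "Content-Type: application/json" \
#          -d '{"prompt": "A friendly presenter", "text_to_speech": "Hello everyone!"}'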
@app.post("/download-models")  # the endpoint referenced in the startup warning above
async def download_video_models():
    """Manually trigger video model downloads"""
    logger.info("[DOWNLOAD] Manual model download requested...")
    try:
        from huggingface_hub import snapshot_download
        import shutil

        # Check storage first
        _, _, free_bytes = shutil.disk_usage(".")
        free_gb = free_bytes / (1024**3)
        logger.info(f"[STORAGE] Available storage: {free_gb:.1f}GB")

        if free_gb < 10:  # Need at least 10GB free
            return {
                "success": False,
                "message": f"Insufficient storage: {free_gb:.1f}GB available, 10GB+ required",
                "storage_gb": free_gb
            }

        # Download a small video generation model
        logger.info("[DOWNLOAD] Downloading text-to-video model...")
        model_path = snapshot_download(
            repo_id="ali-vilab/text-to-video-ms-1.7b",
            cache_dir="./downloaded_models/video",
            local_files_only=False
        )
        logger.info(f"SUCCESS: Video model downloaded: {model_path}")

        # Download the audio model
        audio_model_path = snapshot_download(
            repo_id="facebook/wav2vec2-base-960h",
            cache_dir="./downloaded_models/audio",
            local_files_only=False
        )
        logger.info(f"SUCCESS: Audio model downloaded: {audio_model_path}")

        # Check final storage usage
        _, _, free_bytes_after = shutil.disk_usage(".")
        free_gb_after = free_bytes_after / (1024**3)
        used_gb = free_gb - free_gb_after

        return {
            "success": True,
            "message": "Video generation models downloaded successfully!",
            "models_downloaded": [
                "ali-vilab/text-to-video-ms-1.7b",
                "facebook/wav2vec2-base-960h"
            ],
            "storage_used_gb": round(used_gb, 2),
            "storage_remaining_gb": round(free_gb_after, 2),
            "video_model_path": model_path,
            "audio_model_path": audio_model_path,
            "status": "READY FOR VIDEO GENERATION"
        }
    except Exception as e:
        logger.error(f"ERROR: Model download failed: {e}")
        return {
            "success": False,
            "message": f"Model download failed: {str(e)}",
            "error": str(e)
        }
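
# Illustrative trigger ("/download-models" is the path referenced in the startup warning):
#     curl -X POST https://bravedims-ai-avatar-chat.hf.space/download-models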
@app.get("/model-status")  # route decorator restored; path assumed
async def get_model_status():
    """Check the status of downloaded models"""
    try:
        models_dir = Path("./downloaded_models")
        status = {
            "models_downloaded": models_dir.exists(),
            "available_models": [],
            "storage_info": {}
        }
        if models_dir.exists():
            for model_dir in models_dir.iterdir():
                if model_dir.is_dir():
                    status["available_models"].append({
                        "name": model_dir.name,
                        "path": str(model_dir),
                        "files": len(list(model_dir.rglob("*")))
                    })

        # Storage info
        import shutil
        _, _, free_bytes = shutil.disk_usage(".")
        status["storage_info"] = {
            "free_gb": round(free_bytes / (1024**3), 2),
            "models_dir_exists": models_dir.exists()
        }
        return status
    except Exception as e:
        return {"error": str(e)}
# Enhanced Gradio interface
def gradio_generate(prompt, text_to_speech, audio_url, image_url, voice_id, guidance_scale, audio_scale, num_steps):
    """Gradio interface wrapper with robust TTS support"""
    try:
        # Create request object
        request_data = {
            "prompt": prompt,
            "guidance_scale": guidance_scale,
            "audio_scale": audio_scale,
            "num_steps": int(num_steps)
        }

        # Add audio source
        if text_to_speech and text_to_speech.strip():
            request_data["text_to_speech"] = text_to_speech
            request_data["voice_id"] = voice_id or "21m00Tcm4TlvDq8ikWAM"
        elif audio_url and audio_url.strip():
            if omni_api.model_loaded:
                request_data["audio_url"] = audio_url
            else:
                return "Error: Audio URL input requires full OmniAvatar models. Please use text-to-speech instead."
        else:
            return "Error: Please provide either text to speech or an audio URL"

        if image_url and image_url.strip():
            if omni_api.model_loaded:
                request_data["image_url"] = image_url
            else:
                return "Error: Image URL input requires full OmniAvatar models for video generation."

        request = GenerateRequest(**request_data)

        # Run the async function in a sync context
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        output_path, processing_time, audio_generated, tts_method = loop.run_until_complete(omni_api.generate_avatar(request))
        loop.close()

        success_message = f"SUCCESS: Generation completed in {processing_time:.1f}s using {tts_method}"
        print(success_message)

        if omni_api.model_loaded:
            return output_path
        else:
            return f"[TTS] Audio generated successfully using {tts_method}\nFile: {output_path}\n\nWARNING: Video generation unavailable (OmniAvatar models not found)"
    except Exception as e:
        logger.error(f"Gradio generation error: {e}")
        return f"Error: {str(e)}"
# Create Gradio interface
mode_info = " (TTS-Only Mode)" if not omni_api.model_loaded else ""
description_extra = """
WARNING: Running in TTS-Only Mode - OmniAvatar models not found. Only text-to-speech generation is available.
To enable full video generation, the required model files need to be downloaded.
""" if not omni_api.model_loaded else ""
iface = gr.Interface(
    fn=gradio_generate,
    inputs=[
        gr.Textbox(
            label="Prompt",
            placeholder="Describe the character behavior (e.g., 'A friendly person explaining a concept')",
            lines=2
        ),
        gr.Textbox(
            label="Text to Speech",
            placeholder="Enter text to convert to speech",
            lines=3,
            info="Will use the best available TTS system (Advanced or Fallback)"
        ),
        gr.Textbox(
            label="OR Audio URL",
            placeholder="https://example.com/audio.mp3",
            info="Direct URL to audio file (requires full models)" if not omni_api.model_loaded else "Direct URL to audio file"
        ),
        gr.Textbox(
            label="Image URL (Optional)",
            placeholder="https://example.com/image.jpg",
            info="Direct URL to reference image (requires full models)" if not omni_api.model_loaded else "Direct URL to reference image"
        ),
        gr.Dropdown(
            choices=[
                "21m00Tcm4TlvDq8ikWAM",
                "pNInz6obpgDQGcFmaJgB",
                "EXAVITQu4vr4xnSDxMaL",
                "ErXwobaYiN019PkySvjV",
                "TxGEqnHWrfGW9XjX",
                "yoZ06aMxZJJ28mfd3POQ",
                "AZnzlk1XvdvUeBnXmlld"
            ],
            value="21m00Tcm4TlvDq8ikWAM",
            label="Voice Profile",
            info="Choose voice characteristics for TTS generation"
        ),
        gr.Slider(minimum=1, maximum=10, value=5.0, label="Guidance Scale", info="4-6 recommended"),
        gr.Slider(minimum=1, maximum=10, value=3.0, label="Audio Scale", info="Higher values = better lip-sync"),
        gr.Slider(minimum=10, maximum=100, value=30, step=1, label="Number of Steps", info="20-50 recommended")
    ],
    outputs=gr.Video(label="Generated Avatar Video") if omni_api.model_loaded else gr.Textbox(label="TTS Output"),
    title=f"[VIDEO] OmniAvatar-14B - Avatar Video Generation with Adaptive Body Animation{mode_info}",
    description=f"""
Generate avatar videos with lip-sync from text prompts and speech using a robust TTS system.
{description_extra}

**Robust TTS Architecture**
- **Primary**: Advanced TTS (Facebook VITS & SpeechT5) if available
- **Fallback**: Robust tone generation for 100% reliability
- **Automatic**: Seamless switching between methods

**Features:**
- **Guaranteed Generation**: Always produces audio output
- **No Dependencies**: Works even without advanced models
- **High Availability**: Multiple fallback layers
- **Voice Profiles**: Multiple voice characteristics
- **Audio URL Support**: Use external audio files {"(full models required)" if not omni_api.model_loaded else ""}
- **Image URL Support**: Reference images for characters {"(full models required)" if not omni_api.model_loaded else ""}

**Usage:**
1. Enter a character description in the prompt
2. **Enter text for speech generation** (recommended in the current mode)
3. {"Optionally add reference image/audio URLs (requires full models)" if not omni_api.model_loaded else "Optionally add a reference image URL and choose an audio source"}
4. Choose a voice profile and adjust parameters
5. Generate your {"audio" if not omni_api.model_loaded else "avatar video"}!
""",
    examples=[
        [
            "A professional teacher explaining a mathematical concept with clear gestures",
            "Hello students! Today we're going to learn about calculus and derivatives.",
            "",
            "",
            "21m00Tcm4TlvDq8ikWAM",
            5.0,
            3.5,
            30
        ],
        [
            "A friendly presenter speaking confidently to an audience",
            "Welcome everyone to our presentation on artificial intelligence!",
            "",
            "",
            "pNInz6obpgDQGcFmaJgB",
            5.5,
            4.0,
            35
        ]
    ],
    allow_flagging="never",
    flagging_dir="/tmp/gradio_flagged"
)
# Mount the Gradio UI onto the FastAPI app
app = gr.mount_gradio_app(app, iface, path="/gradio")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)