Spaces:
Running
Running
File size: 39,621 Bytes
447c811 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 |
"""
AI Avatar Chat - HF Spaces Optimized Version
BUILD: 2025-01-08_00-44-FORCE-REBUILD - With Model Download Controls
FEATURES: Real video generation, model download UI, storage optimization
"""
import os
# Detect whether the process is running inside a Hugging Face Space. Any of
# the Spaces-specific environment variables, or the conventional Spaces
# working directory, is treated as a positive signal.
IS_HF_SPACE = any(
    [os.getenv(marker) for marker in ("SPACE_ID", "SPACE_AUTHOR_NAME", "SPACES_BUILDKIT_VERSION")]
    + ["/home/user/app" in os.getcwd()]
)

if IS_HF_SPACE:
    # Only the storage-optimisation flag is exported. DISABLE_MODEL_DOWNLOAD
    # and TTS_ONLY_MODE are intentionally NOT set here, so video generation
    # stays enabled; the models have to be fetched manually.
    os.environ["HF_SPACE_STORAGE_OPTIMIZED"] = "1"
    print(
        "?? STORAGE OPTIMIZATION: Detected HF Space environment",
        "?? Video generation ENABLED (models need manual download)",
        "?? WARNING: Use /download-models endpoint to download ~30GB models first",
        sep="\n",
    )
import os
import torch
import tempfile
import gradio as gr
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, HttpUrl
import subprocess
import json
from pathlib import Path
import logging
import requests
from urllib.parse import urlparse
from PIL import Image
import io
from typing import Optional
import aiohttp
import asyncio
# Safe dotenv import
# python-dotenv is optional: when it cannot be imported, a no-op
# load_dotenv() is defined so the unconditional call further down still works.
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    print("Warning: python-dotenv not found, continuing without .env support")
    def load_dotenv():
        """No-op stand-in used when python-dotenv is not installed."""
        pass
# CRITICAL: HF Spaces compatibility fix
# hf_spaces_fix is a project-local module; when it is absent the app only
# prints a warning and continues without the compatibility setup.
try:
    from hf_spaces_fix import setup_hf_spaces_environment, HFSpacesCompatible
    setup_hf_spaces_environment()
except ImportError:
    print('Warning: HF Spaces fix not available')
# Load environment variables
# (second call: runs after the HF Spaces setup above; a no-op when the
# fallback load_dotenv is in use)
load_dotenv()
# Module-wide logger; INFO level so startup diagnostics are visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Point matplotlib and the Hugging Face caches at /tmp, and switch Gradio
# flagging off. HF_HOME is used rather than the deprecated TRANSFORMERS_CACHE.
os.environ.update({
    'MPLCONFIGDIR': '/tmp/matplotlib',
    'GRADIO_ALLOW_FLAGGING': 'never',
    'HF_HOME': '/tmp/huggingface',
    'HF_DATASETS_CACHE': '/tmp/huggingface/datasets',
    'HUGGINGFACE_HUB_CACHE': '/tmp/huggingface/hub',
})

# The FastAPI app itself is created further down, once the lifespan handler
# exists. The output and cache directories are created up front.
for _required_dir in (
    "outputs",
    "/tmp/matplotlib",
    "/tmp/huggingface",
    "/tmp/huggingface/transformers",
    "/tmp/huggingface/datasets",
    "/tmp/huggingface/hub",
):
    os.makedirs(_required_dir, exist_ok=True)
# URL helper for generated videos (the /outputs static mount itself is registered after the FastAPI app is created)
def get_video_url(output_path: str) -> str:
    """Convert local file path to accessible URL

    Only the file name of ``output_path`` is used; it is appended to the
    ``/outputs`` route of the Space.

    The host is taken from the ``SPACE_HOST`` environment variable when it is
    set, so a duplicated or renamed Space produces working links. Otherwise
    the original hard-coded host is used, which keeps previous behaviour.

    Args:
        output_path: Path of the generated file on local disk.

    Returns:
        The public URL of the file, or ``output_path`` unchanged if the URL
        could not be built.
    """
    from urllib.parse import quote

    try:
        filename = Path(output_path).name
        # For HuggingFace Spaces, construct the URL
        host = os.getenv("SPACE_HOST") or "bravedims-ai-avatar-chat.hf.space"
        # Percent-encode the name so spaces or non-ASCII characters survive.
        video_url = f"https://{host}/outputs/{quote(filename)}"
        logger.info(f"Generated video URL: {video_url}")
        return video_url
    except Exception as e:
        logger.error(f"Error creating video URL: {e}")
        return output_path  # Fallback to original path
# Pydantic models for request/response
class GenerateRequest(BaseModel):
    # Request body of POST /generate. Documented with comments rather than a
    # class docstring so the generated schema is left untouched.
    # The audio source is text_to_speech if given, otherwise audio_url;
    # the generation code rejects requests that provide neither.
    prompt: str  # Text prompt passed to the video engine
    text_to_speech: Optional[str] = None  # Text to convert to speech
    audio_url: Optional[HttpUrl] = None  # Direct audio URL
    voice_id: Optional[str] = "21m00Tcm4TlvDq8ikWAM"  # Voice profile ID
    image_url: Optional[HttpUrl] = None  # Optional reference image, downloaded before generation
    guidance_scale: float = 5.0  # Forwarded unchanged to the video engine / inference script
    audio_scale: float = 3.0  # Forwarded unchanged to the video engine / inference script
    num_steps: int = 30  # Forwarded unchanged to the video engine / inference script
    sp_size: int = 1  # Only read by the backup subprocess path (--nproc_per_node)
    tea_cache_l1_thresh: Optional[float] = None  # Only read by the backup subprocess path
class GenerateResponse(BaseModel):
    # Response body of POST /generate.
    message: str  # Human-readable status line
    output_path: str  # Public URL of the video, or a local path in TTS-only mode
    processing_time: float  # Wall-clock seconds spent generating
    audio_generated: bool = False  # True when the audio came from text-to-speech
    tts_method: Optional[str] = None  # Label of the generation / TTS method used
# Both TTS backends are optional project-local modules. Each import records
# its outcome in a module-level flag that TTSManager consults later.
try:
    from advanced_tts_client import AdvancedTTSClient
except ImportError as import_error:
    ADVANCED_TTS_AVAILABLE = False
    logger.warning(f"WARNING: Advanced TTS client not available: {import_error}")
else:
    ADVANCED_TTS_AVAILABLE = True
    logger.info("SUCCESS: Advanced TTS client available")

# The robust client is the fallback, so its absence is logged as an error.
try:
    from robust_tts_client import RobustTTSClient
except ImportError as import_error:
    ROBUST_TTS_AVAILABLE = False
    logger.error(f"ERROR: Robust TTS client not available: {import_error}")
else:
    ROBUST_TTS_AVAILABLE = True
    logger.info("SUCCESS: Robust TTS client available")
class TTSManager:
    """Manages multiple TTS clients with fallback chain

    The advanced client is tried first and the robust client second. Either
    may be missing: availability is decided by the module-level
    ``ADVANCED_TTS_AVAILABLE`` / ``ROBUST_TTS_AVAILABLE`` flags and by whether
    the client's constructor succeeds.
    """

    def __init__(self):
        """Instantiate whichever TTS clients are importable.

        Constructor failures are logged and leave the corresponding attribute
        as ``None``; they never propagate.
        """
        # Initialize TTS clients based on availability
        self.advanced_tts = None
        self.robust_tts = None
        self.clients_loaded = False

        if ADVANCED_TTS_AVAILABLE:
            try:
                self.advanced_tts = AdvancedTTSClient()
                logger.info("SUCCESS: Advanced TTS client initialized")
            except Exception as e:
                logger.warning(f"WARNING: Advanced TTS client initialization failed: {e}")

        if ROBUST_TTS_AVAILABLE:
            try:
                self.robust_tts = RobustTTSClient()
                logger.info("SUCCESS: Robust TTS client initialized")
            except Exception as e:
                logger.error(f"ERROR: Robust TTS client initialization failed: {e}")

        if not self.advanced_tts and not self.robust_tts:
            logger.error("ERROR: No TTS clients available!")

    async def load_models(self):
        """Load TTS models

        Each client's loading error is logged and swallowed so that one
        failing backend does not prevent the other from being used.

        Returns:
            ``True`` once both load attempts have been made, ``False`` if an
            unexpected error escaped the per-client handling.
        """
        try:
            logger.info("Loading TTS models...")

            # Try to load advanced TTS first
            if self.advanced_tts:
                try:
                    logger.info("[PROCESS] Loading advanced TTS models (this may take a few minutes)...")
                    success = await self.advanced_tts.load_models()
                    if success:
                        logger.info("SUCCESS: Advanced TTS models loaded successfully")
                    else:
                        logger.warning("WARNING: Advanced TTS models failed to load")
                except Exception as e:
                    logger.warning(f"WARNING: Advanced TTS loading error: {e}")

            # Always ensure robust TTS is available
            if self.robust_tts:
                try:
                    await self.robust_tts.load_model()
                    logger.info("SUCCESS: Robust TTS fallback ready")
                except Exception as e:
                    logger.error(f"ERROR: Robust TTS loading failed: {e}")

            self.clients_loaded = True
            return True
        except Exception as e:
            logger.error(f"ERROR: TTS manager initialization failed: {e}")
            return False

    async def text_to_speech(self, text: str, voice_id: Optional[str] = None) -> tuple[str, str]:
        """
        Convert text to speech with fallback chain

        Models are loaded lazily on first use if ``load_models`` has not run.

        Args:
            text: Text to synthesise.
            voice_id: Voice profile identifier forwarded to the client.

        Returns: (audio_file_path, method_used)

        Raises:
            HTTPException: 500 when every available client failed.
        """
        if not self.clients_loaded:
            logger.info("TTS models not loaded, loading now...")
            await self.load_models()

        logger.info(f"Generating speech: {text[:50]}...")
        logger.info(f"Voice ID: {voice_id}")

        # Try Advanced TTS first (Facebook VITS / SpeechT5)
        if self.advanced_tts:
            try:
                audio_path = await self.advanced_tts.text_to_speech(text, voice_id)
                return audio_path, "Facebook VITS/SpeechT5"
            except Exception as advanced_error:
                logger.warning(f"Advanced TTS failed: {advanced_error}")

        # Fall back to robust TTS
        if self.robust_tts:
            try:
                logger.info("Falling back to robust TTS...")
                audio_path = await self.robust_tts.text_to_speech(text, voice_id)
                return audio_path, "Robust TTS (Fallback)"
            except Exception as robust_error:
                logger.error(f"Robust TTS also failed: {robust_error}")

        # If we get here, all methods failed
        logger.error("All TTS methods failed!")
        raise HTTPException(
            status_code=500,
            detail="All TTS methods failed. Please check system configuration."
        )

    async def get_available_voices(self):
        """Get available voice configurations

        Returns:
            The advanced client's voice mapping when it can be obtained,
            otherwise a built-in mapping of voice id to description.
        """
        if self.advanced_tts and hasattr(self.advanced_tts, 'get_available_voices'):
            try:
                return await self.advanced_tts.get_available_voices()
            except Exception as e:
                # Only ordinary errors are absorbed. A bare ``except`` here
                # would also swallow asyncio.CancelledError and stop the
                # request from being cancelled.
                logger.debug(f"Could not get advanced TTS voices: {e}")

        # Return default voices if advanced TTS not available
        return {
            "21m00Tcm4TlvDq8ikWAM": "Female (Neutral)",
            "pNInz6obpgDQGcFmaJgB": "Male (Professional)",
            "EXAVITQu4vr4xnSDxMaL": "Female (Sweet)",
            "ErXwobaYiN019PkySvjV": "Male (Professional)",
            "TxGEqnHWrfGW9XjX": "Male (Deep)",
            "yoZ06aMxZJJ28mfd3POQ": "Unisex (Friendly)",
            "AZnzlk1XvdvUeBnXmlld": "Female (Strong)"
        }

    def get_tts_info(self):
        """Get TTS system information

        Returns:
            A dict with the load state and client availability. When the
            advanced client exposes ``get_model_info`` its details are merged
            in and may change ``primary_method``.
        """
        info = {
            "clients_loaded": self.clients_loaded,
            "advanced_tts_available": self.advanced_tts is not None,
            "robust_tts_available": self.robust_tts is not None,
            "primary_method": "Robust TTS"
        }

        try:
            if self.advanced_tts and hasattr(self.advanced_tts, 'get_model_info'):
                advanced_info = self.advanced_tts.get_model_info()
                info.update({
                    "advanced_tts_loaded": advanced_info.get("models_loaded", False),
                    "transformers_available": advanced_info.get("transformers_available", False),
                    "primary_method": "Facebook VITS/SpeechT5" if advanced_info.get("models_loaded") else "Robust TTS",
                    "device": advanced_info.get("device", "cpu"),
                    "vits_available": advanced_info.get("vits_available", False),
                    "speecht5_available": advanced_info.get("speecht5_available", False)
                })
        except Exception as e:
            logger.debug(f"Could not get advanced TTS info: {e}")

        return info
# The video engine is a project-local module. Its availability is recorded
# in a flag that OmniAvatarAPI.generate_avatar checks before generating.
try:
    from omniavatar_video_engine import video_engine
except ImportError as import_error:
    VIDEO_ENGINE_AVAILABLE = False
    logger.error(f"ERROR: OmniAvatar Video Engine not available: {import_error}")
else:
    VIDEO_ENGINE_AVAILABLE = True
    logger.info("SUCCESS: OmniAvatar Video Engine available")
class OmniAvatarAPI:
    """Model discovery, input download, TTS and avatar video generation.

    ``model_loaded`` records whether the downloaded model directories look
    populated. ``generate_avatar`` does not consult it (it depends on
    ``VIDEO_ENGINE_AVAILABLE``); the flag drives the health report and the
    Gradio UI mode.
    """

    # Directories the /download-models endpoint passes to snapshot_download
    # as cache_dir.
    DOWNLOADED_MODEL_DIRS = ("./downloaded_models/video", "./downloaded_models/audio")
    # A directory must hold MORE than this many files to count as downloaded.
    MIN_MODEL_FILES = 5

    def __init__(self):
        """Pick the torch device and create the TTS manager."""
        self.model_loaded = False
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tts_manager = TTSManager()
        logger.info(f"Using device: {self.device}")
        logger.info("Initialized with robust TTS system")

    @staticmethod
    def _count_files(directory: str) -> int:
        """Return the number of files below ``directory`` (0 if it is not a directory)."""
        if not os.path.isdir(directory):
            return 0
        # Counted recursively: the download endpoint uses these directories
        # as cache_dir, so the files are expected in nested sub-directories.
        return sum(len(files) for _, _, files in os.walk(directory))

    def _missing_model_dirs(self) -> list[str]:
        """Return the model directories that are absent or not populated enough."""
        return [
            directory
            for directory in self.DOWNLOADED_MODEL_DIRS
            if self._count_files(directory) <= self.MIN_MODEL_FILES
        ]

    @staticmethod
    def _remove_temp_files(*paths) -> None:
        """Best-effort removal of temporary files; ``None`` entries are skipped."""
        for path in paths:
            if not path:
                continue
            try:
                os.unlink(path)
            except FileNotFoundError:
                pass  # Already gone - nothing to clean up
            except OSError as e:
                logger.warning(f"Could not remove temporary file {path}: {e}")

    def load_model(self):
        """Load the OmniAvatar model - now more flexible

        Checks whether the downloaded model directories are populated and
        sets ``model_loaded`` accordingly. Missing models are not fatal.

        Returns:
            Always ``True`` - the app keeps running in TTS-only mode when
            models are missing or the check itself fails.
        """
        try:
            # The earlier version referenced an undefined ``missing_models``
            # list, so the check always raised NameError and fell through to
            # TTS-only mode. The list is now computed explicitly.
            missing_models = self._missing_model_dirs()

            if missing_models:
                logger.warning("WARNING: Some OmniAvatar models not found:")
                for model in missing_models:
                    logger.warning(f" - {model}")
                logger.info("TIP: App will run in TTS-only mode (no video generation)")
                logger.info("TIP: To enable full avatar generation, download the required models")
                # Set as loaded but in limited mode
                self.model_loaded = False  # Video generation disabled
                return True  # But app can still run

            self.model_loaded = True
            logger.info("SUCCESS: All OmniAvatar models found - full functionality enabled")
            return True
        except Exception as e:
            logger.error(f"Error checking models: {str(e)}")
            logger.info("TIP: Continuing in TTS-only mode")
            self.model_loaded = False
            return True  # Continue running

    async def download_file(self, url: str, suffix: str = "") -> str:
        """Download file from URL and save to temporary location

        Args:
            url: Source URL.
            suffix: File-name suffix (extension) for the temporary file.

        Returns:
            Path of the temporary file. The caller is responsible for
            deleting it.

        Raises:
            HTTPException: 400 for a non-200 response or a network error,
                500 for any other failure.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(str(url)) as response:
                    if response.status != 200:
                        raise HTTPException(status_code=400, detail=f"Failed to download file from URL: {url}")
                    content = await response.read()

            # Create temporary file
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
                temp_file.write(content)
            return temp_file.name
        except HTTPException:
            # Keep the 400 raised above; without this clause the generic
            # handler below re-wrapped it as a 500.
            raise
        except aiohttp.ClientError as e:
            logger.error(f"Network error downloading {url}: {e}")
            raise HTTPException(status_code=400, detail=f"Network error downloading file: {e}")
        except Exception as e:
            logger.error(f"Error downloading file from {url}: {e}")
            raise HTTPException(status_code=500, detail=f"Error downloading file: {e}")

    def validate_audio_url(self, url: str) -> bool:
        """Validate if URL is likely an audio file

        Returns ``True`` when the URL path ends with a known audio extension
        or the URL contains the word "audio"; ``False`` otherwise, including
        for input that cannot be parsed.
        """
        # Check for common audio file extensions
        audio_extensions = ('.mp3', '.wav', '.m4a', '.ogg', '.aac', '.flac')
        try:
            parsed = urlparse(url)
            return parsed.path.lower().endswith(audio_extensions) or 'audio' in url.lower()
        except (ValueError, TypeError, AttributeError):
            return False

    def validate_image_url(self, url: str) -> bool:
        """Validate if URL is likely an image file

        Returns ``True`` when the URL path ends with a known image extension;
        ``False`` otherwise, including for input that cannot be parsed.
        """
        image_extensions = ('.jpg', '.jpeg', '.png', '.webp', '.bmp', '.gif')
        try:
            return urlparse(url).path.lower().endswith(image_extensions)
        except (ValueError, TypeError, AttributeError):
            return False

    async def generate_avatar(self, request: "GenerateRequest") -> tuple[str, float, bool, str]:
        """Generate avatar VIDEO - PRIMARY FUNCTIONALITY

        Audio comes from ``request.text_to_speech`` (via the TTS manager) or
        is downloaded from ``request.audio_url``. An optional reference image
        is downloaded too. Temporary audio/image files are removed whether
        generation succeeds or fails.

        Returns:
            ``(video_path, processing_time, audio_generated, method_label)``.

        Raises:
            HTTPException: 400 when no audio source is given; 503 when the
                video engine is unavailable or the failure mentions models;
                500 for other generation failures. HTTPExceptions raised by
                helpers are passed through unchanged.
        """
        import time

        start_time = time.time()
        audio_generated = False
        method_used = "Unknown"

        logger.info("[VIDEO] STARTING AVATAR VIDEO GENERATION")
        logger.info(f"[INFO] Prompt: {request.prompt}")

        if not VIDEO_ENGINE_AVAILABLE:
            # If video engine not available, this is a critical error for a VIDEO app
            raise HTTPException(
                status_code=503,
                detail="Video generation engine not available. This application requires OmniAvatar models for video generation."
            )

        audio_path = None
        image_path = None
        try:
            # PRIORITIZE VIDEO GENERATION
            logger.info("[TARGET] Using OmniAvatar Video Engine for FULL video generation")

            # Handle audio source
            if request.text_to_speech:
                logger.info("[MIC] Generating audio from text...")
                audio_path, method_used = await self.tts_manager.text_to_speech(
                    request.text_to_speech,
                    request.voice_id or "21m00Tcm4TlvDq8ikWAM"
                )
                audio_generated = True
            elif request.audio_url:
                logger.info("๐ฅ Downloading audio from URL...")
                audio_path = await self.download_file(str(request.audio_url), ".mp3")
                method_used = "External Audio"
            else:
                raise HTTPException(status_code=400, detail="Either text_to_speech or audio_url required for video generation")

            # Handle image if provided
            if request.image_url:
                logger.info("[IMAGE] Downloading reference image...")
                parsed = urlparse(str(request.image_url))
                ext = os.path.splitext(parsed.path)[1] or ".jpg"
                image_path = await self.download_file(str(request.image_url), ext)

            # GENERATE VIDEO using OmniAvatar engine
            logger.info("[VIDEO] Generating avatar video with adaptive body animation...")
            video_path, _ = video_engine.generate_avatar_video(
                prompt=request.prompt,
                audio_path=audio_path,
                image_path=image_path,
                guidance_scale=request.guidance_scale,
                audio_scale=request.audio_scale,
                num_steps=request.num_steps
            )

            processing_time = time.time() - start_time
            logger.info(f"SUCCESS: VIDEO GENERATED successfully in {processing_time:.1f}s")
            return video_path, processing_time, audio_generated, f"OmniAvatar Video Generation ({method_used})"
        except HTTPException:
            # Client errors (e.g. the 400 above) must keep their status code.
            raise
        except Exception as e:
            logger.error(f"ERROR: Video generation failed: {e}")
            # For a VIDEO generation app, we should NOT fall back to audio-only
            # Instead, provide clear guidance
            if "models" in str(e).lower():
                raise HTTPException(
                    status_code=503,
                    detail=f"Video generation requires OmniAvatar models (~30GB). Please run model download script. Error: {str(e)}"
                ) from e
            raise HTTPException(status_code=500, detail=f"Video generation failed: {str(e)}") from e
        finally:
            # Cleanup temporary files on success AND on failure
            self._remove_temp_files(audio_path, image_path)

    async def generate_avatar_BACKUP(self, request: "GenerateRequest") -> tuple[str, float, bool, str]:
        """OLD TTS-ONLY METHOD - kept as backup reference.

        Generate avatar video from prompt and audio/text - now handles missing models

        When ``model_loaded`` is false only text-to-speech is performed and
        the audio file path is returned as the output. Otherwise the
        inference script is run in a subprocess and the newest video in
        ``./outputs`` is returned.

        Returns:
            ``(output_path, processing_time, audio_generated, tts_method)``.

        Raises:
            HTTPException: 503 in TTS-only mode without text, 400 without any
                audio source, 500 for any other failure.
        """
        import time

        start_time = time.time()
        audio_generated = False
        tts_method = None
        # Files that must be removed again, whatever happens.
        temp_paths: list[str] = []

        try:
            # Check if video generation is available
            if not self.model_loaded:
                logger.info("๐๏ธ Running in TTS-only mode (OmniAvatar models not available)")
                # Only generate audio, no video
                if request.text_to_speech:
                    logger.info(f"Generating speech from text: {request.text_to_speech[:50]}...")
                    audio_path, tts_method = await self.tts_manager.text_to_speech(
                        request.text_to_speech,
                        request.voice_id or "21m00Tcm4TlvDq8ikWAM"
                    )
                    # Return the audio file as the "output" - deliberately
                    # NOT registered for cleanup.
                    processing_time = time.time() - start_time
                    logger.info(f"SUCCESS: TTS completed in {processing_time:.1f}s using {tts_method}")
                    return audio_path, processing_time, True, f"{tts_method} (TTS-only mode)"
                raise HTTPException(
                    status_code=503,
                    detail="Video generation unavailable. OmniAvatar models not found. Only TTS from text is supported."
                )

            # Original video generation logic (when models are available)
            # Determine audio source
            if request.text_to_speech:
                # Generate speech from text using TTS manager
                logger.info(f"Generating speech from text: {request.text_to_speech[:50]}...")
                audio_path, tts_method = await self.tts_manager.text_to_speech(
                    request.text_to_speech,
                    request.voice_id or "21m00Tcm4TlvDq8ikWAM"
                )
                audio_generated = True
            elif request.audio_url:
                # Download audio from provided URL
                logger.info(f"Downloading audio from URL: {request.audio_url}")
                if not self.validate_audio_url(str(request.audio_url)):
                    logger.warning(f"Audio URL may not be valid: {request.audio_url}")
                audio_path = await self.download_file(str(request.audio_url), ".mp3")
                tts_method = "External Audio URL"
            else:
                raise HTTPException(
                    status_code=400,
                    detail="Either text_to_speech or audio_url must be provided"
                )
            temp_paths.append(audio_path)

            # Download image if provided
            image_path = None
            if request.image_url:
                logger.info(f"Downloading image from URL: {request.image_url}")
                if not self.validate_image_url(str(request.image_url)):
                    logger.warning(f"Image URL may not be valid: {request.image_url}")
                # Determine image extension from URL or default to .jpg
                parsed = urlparse(str(request.image_url))
                ext = os.path.splitext(parsed.path)[1] or ".jpg"
                image_path = await self.download_file(str(request.image_url), ext)
                temp_paths.append(image_path)

            # Create temporary input file for inference
            with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
                if image_path:
                    input_line = f"{request.prompt}@@{image_path}@@{audio_path}"
                else:
                    input_line = f"{request.prompt}@@@@{audio_path}"
                f.write(input_line)
                temp_input_file = f.name
            temp_paths.append(temp_input_file)

            # Prepare inference command
            cmd = [
                "python", "-m", "torch.distributed.run",
                "--standalone", f"--nproc_per_node={request.sp_size}",
                "scripts/inference.py",
                "--config", "configs/inference.yaml",
                "--input_file", temp_input_file,
                "--guidance_scale", str(request.guidance_scale),
                "--audio_scale", str(request.audio_scale),
                "--num_steps", str(request.num_steps)
            ]
            if request.tea_cache_l1_thresh:
                cmd.extend(["--tea_cache_l1_thresh", str(request.tea_cache_l1_thresh)])

            logger.info(f"Running inference with command: {' '.join(cmd)}")

            # Run inference
            result = subprocess.run(cmd, capture_output=True, text=True)
            if result.returncode != 0:
                logger.error(f"Inference failed: {result.stderr}")
                raise Exception(f"Inference failed: {result.stderr}")

            # Find output video file
            output_dir = "./outputs"
            if os.path.exists(output_dir):
                video_files = [f for f in os.listdir(output_dir) if f.endswith(('.mp4', '.avi'))]
                if video_files:
                    # Return the most recent video file
                    newest = max(video_files, key=lambda x: os.path.getmtime(os.path.join(output_dir, x)))
                    output_path = os.path.join(output_dir, newest)
                    processing_time = time.time() - start_time
                    return output_path, processing_time, audio_generated, tts_method

            raise Exception("No output video generated")
        except HTTPException:
            # Preserve deliberate 400/503 responses instead of turning them into 500.
            raise
        except Exception as e:
            logger.error(f"Generation error: {str(e)}")
            raise HTTPException(status_code=500, detail=str(e)) from e
        finally:
            # Clean up any temporary files, also in case of error
            self._remove_temp_files(*temp_paths)
# Initialize API
# Module-level singleton shared by the FastAPI routes and the Gradio UI.
# Construction selects the torch device and creates the TTS manager; model
# discovery happens later, in the lifespan handler.
omni_api = OmniAvatarAPI()
# Use FastAPI lifespan instead of deprecated on_event
from contextlib import asynccontextmanager
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: prepare models before serving, log on shutdown.

    Startup checks for the OmniAvatar models and then loads the TTS models.
    Neither step can abort startup - failures are only logged.
    """
    # --- startup ---
    if not omni_api.load_model():
        logger.warning("WARNING: OmniAvatar model loading failed - running in limited mode")

    try:
        await omni_api.tts_manager.load_models()
    except Exception as tts_error:
        logger.error(f"ERROR: TTS initialization failed: {tts_error}")
    else:
        logger.info("SUCCESS: TTS models initialization completed")

    yield

    # --- shutdown ---
    logger.info("Application shutting down...")
# Create FastAPI app WITH lifespan parameter
# (the lifespan handler runs the model check and TTS loading at startup)
app = FastAPI(
    title="OmniAvatar-14B API with Advanced TTS",
    version="1.0.0",
    lifespan=lifespan
)
# Add CORS middleware
# NOTE(review): wildcard origins/methods/headers are combined with
# allow_credentials=True. The FastAPI CORS docs advise listing these
# explicitly when credentials are allowed - confirm whether credentialed
# cross-origin requests are actually needed here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Mount static files for serving generated videos
# (files in the local "outputs" directory become reachable under /outputs)
app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs")
@app.get("/health")
async def health_check():
    """Report service health, capability flags and TTS details.

    The TTS manager's info is merged in last, so any key it shares with the
    base report takes the manager's value.
    """
    tts_details = omni_api.tts_manager.get_tts_info()
    video_ready = omni_api.model_loaded

    report = {
        "status": "healthy",
        "model_loaded": video_ready,
        "video_generation_available": video_ready,
        "tts_only_mode": not video_ready,
        "device": omni_api.device,
        "supports_text_to_speech": True,
        "supports_image_urls": video_ready,
        "supports_audio_urls": video_ready,
        "tts_system": "Advanced TTS with Robust Fallback",
        "advanced_tts_available": ADVANCED_TTS_AVAILABLE,
        "robust_tts_available": ROBUST_TTS_AVAILABLE,
    }
    report.update(tts_details)
    return report
@app.get("/voices")
async def get_voices():
    """Return the voice profiles known to the TTS manager.

    Responds with ``{"voices": ...}`` on success and ``{"error": ...}`` when
    the lookup raises.
    """
    try:
        return {"voices": await omni_api.tts_manager.get_available_voices()}
    except Exception as lookup_error:
        logger.error(f"Error getting voices: {lookup_error}")
        return {"error": str(lookup_error)}
@app.post("/generate", response_model=GenerateResponse)
async def generate_avatar(request: GenerateRequest):
    """Generate an avatar video from a prompt, text or audio, and an optional image URL.

    HTTP errors raised during generation are passed through unchanged; any
    other exception becomes a 500 response.
    """
    # Log the request, one line per supplied input.
    logger.info(f"Generating avatar with prompt: {request.prompt}")
    if request.text_to_speech:
        logger.info(f"Text to speech: {request.text_to_speech[:100]}...")
        logger.info(f"Voice ID: {request.voice_id}")
    for label, value in (("Audio URL", request.audio_url), ("Image URL", request.image_url)):
        if value:
            logger.info(f"{label}: {value}")

    try:
        output_path, elapsed, audio_generated, tts_method = await omni_api.generate_avatar(request)

        # With full models the file is exposed as a public URL; in TTS-only
        # mode the local path is returned as-is.
        if omni_api.model_loaded:
            status_message = "Generation completed successfully"
            reported_path = get_video_url(output_path)
        else:
            status_message = "Generation completed successfully" + " (TTS-only mode)"
            reported_path = output_path

        return GenerateResponse(
            message=status_message,
            output_path=reported_path,
            processing_time=elapsed,
            audio_generated=audio_generated,
            tts_method=tts_method,
        )
    except HTTPException:
        raise
    except Exception as unexpected:
        logger.error(f"Unexpected error: {unexpected}")
        raise HTTPException(status_code=500, detail=f"Unexpected error: {unexpected}")
@app.post("/download-models")
async def download_video_models():
    """Manually trigger video model downloads

    Refuses to start when less than 10GB of disk space is free. The two
    snapshot downloads run in a worker thread: they are blocking calls, and
    running them directly in this coroutine would stall every other request
    (including /health) until the download finished.

    Returns:
        A dict with ``success`` and ``message``; on success also the model
        ids, their local paths and the storage used/remaining. Failures are
        reported in the body rather than raised.
    """
    logger.info("?? Manual model download requested...")
    try:
        from huggingface_hub import snapshot_download
        import shutil

        # Check storage first
        _, _, free_bytes = shutil.disk_usage(".")
        free_gb = free_bytes / (1024**3)
        logger.info(f"?? Available storage: {free_gb:.1f}GB")

        if free_gb < 10:  # Need at least 10GB free
            return {
                "success": False,
                "message": f"Insufficient storage: {free_gb:.1f}GB available, 10GB+ required",
                "storage_gb": free_gb
            }

        # Download small video generation model
        logger.info("?? Downloading text-to-video model...")
        model_path = await asyncio.to_thread(
            snapshot_download,
            repo_id="ali-vilab/text-to-video-ms-1.7b",
            cache_dir="./downloaded_models/video",
            local_files_only=False
        )
        logger.info(f"? Video model downloaded: {model_path}")

        # Download audio model
        audio_model_path = await asyncio.to_thread(
            snapshot_download,
            repo_id="facebook/wav2vec2-base-960h",
            cache_dir="./downloaded_models/audio",
            local_files_only=False
        )
        logger.info(f"? Audio model downloaded: {audio_model_path}")

        # Check final storage usage
        _, _, free_bytes_after = shutil.disk_usage(".")
        free_gb_after = free_bytes_after / (1024**3)
        used_gb = free_gb - free_gb_after

        return {
            "success": True,
            "message": "? Video generation models downloaded successfully!",
            "models_downloaded": [
                "ali-vilab/text-to-video-ms-1.7b",
                "facebook/wav2vec2-base-960h"
            ],
            "storage_used_gb": round(used_gb, 2),
            "storage_remaining_gb": round(free_gb_after, 2),
            "video_model_path": model_path,
            "audio_model_path": audio_model_path,
            "status": "READY FOR VIDEO GENERATION"
        }
    except Exception as e:
        logger.error(f"? Model download failed: {e}")
        return {
            "success": False,
            "message": f"Model download failed: {str(e)}",
            "error": str(e)
        }
@app.get("/model-status")
async def get_model_status():
    """Describe the contents of ./downloaded_models and the free disk space.

    Every sub-directory is listed with its name, path and the number of
    entries found recursively beneath it. Any error is returned as
    ``{"error": ...}``.
    """
    try:
        import shutil

        models_root = Path("./downloaded_models")
        report = {
            "models_downloaded": models_root.exists(),
            "available_models": [],
            "storage_info": {},
        }

        if models_root.exists():
            report["available_models"].extend(
                {
                    "name": entry.name,
                    "path": str(entry),
                    "files": len(list(entry.rglob("*"))),
                }
                for entry in models_root.iterdir()
                if entry.is_dir()
            )

        # Storage info
        free_bytes = shutil.disk_usage(".").free
        report["storage_info"] = {
            "free_gb": round(free_bytes / (1024**3), 2),
            "models_dir_exists": models_root.exists(),
        }
        return report
    except Exception as status_error:
        return {"error": str(status_error)}
# Enhanced Gradio interface
def gradio_generate(prompt, text_to_speech, audio_url, image_url, voice_id, guidance_scale, audio_scale, num_steps):
    """Gradio interface wrapper with robust TTS support

    Builds a ``GenerateRequest`` from the form values and runs the async
    generation to completion. Text-to-speech takes precedence over an audio
    URL; URL inputs are rejected while the full models are not loaded.

    Returns:
        The output path when models are loaded, a status text in TTS-only
        mode, or a string starting with ``"Error:"`` on any failure. This
        function never raises.
    """
    try:
        # Create request object
        request_data = {
            "prompt": prompt,
            "guidance_scale": guidance_scale,
            "audio_scale": audio_scale,
            "num_steps": int(num_steps)
        }

        # Add audio source
        if text_to_speech and text_to_speech.strip():
            request_data["text_to_speech"] = text_to_speech
            request_data["voice_id"] = voice_id or "21m00Tcm4TlvDq8ikWAM"
        elif audio_url and audio_url.strip():
            if not omni_api.model_loaded:
                return "Error: Audio URL input requires full OmniAvatar models. Please use text-to-speech instead."
            request_data["audio_url"] = audio_url
        else:
            return "Error: Please provide either text to speech or audio URL"

        if image_url and image_url.strip():
            if not omni_api.model_loaded:
                return "Error: Image URL input requires full OmniAvatar models for video generation."
            request_data["image_url"] = image_url

        request = GenerateRequest(**request_data)

        # Run async function in sync context. asyncio.run creates a fresh
        # event loop and always closes it; the previous hand-made loop stayed
        # open whenever generation raised.
        output_path, processing_time, _, tts_method = asyncio.run(omni_api.generate_avatar(request))

        success_message = f"SUCCESS: Generation completed in {processing_time:.1f}s using {tts_method}"
        print(success_message)

        if omni_api.model_loaded:
            return output_path
        return f"๐๏ธ TTS Audio generated successfully using {tts_method}\nFile: {output_path}\n\nWARNING: Video generation unavailable (OmniAvatar models not found)"
    except Exception as e:
        logger.error(f"Gradio generation error: {e}")
        return f"Error: {str(e)}"
# Build the Gradio interface. All mode-dependent text is derived once from
# whether the OmniAvatar models were found.
_tts_only = not omni_api.model_loaded

mode_info = " (TTS-Only Mode)" if _tts_only else ""
description_extra = "" if not _tts_only else """
WARNING: Running in TTS-Only Mode - OmniAvatar models not found. Only text-to-speech generation is available.
To enable full video generation, the required model files need to be downloaded.
"""

# Fragments interpolated into the description below.
_models_note = "(full models required)" if _tts_only else ""
_url_step = (
    "Optionally add reference image/audio URLs (requires full models)"
    if _tts_only
    else "Optionally add reference image URL and choose audio source"
)
_result_kind = "audio" if _tts_only else "avatar video"

# Voice profile ids offered in the dropdown.
_voice_choices = [
    "21m00Tcm4TlvDq8ikWAM",
    "pNInz6obpgDQGcFmaJgB",
    "EXAVITQu4vr4xnSDxMaL",
    "ErXwobaYiN019PkySvjV",
    "TxGEqnHWrfGW9XjX",
    "yoZ06aMxZJJ28mfd3POQ",
    "AZnzlk1XvdvUeBnXmlld"
]

# Input components, in the positional order gradio_generate expects.
_input_components = [
    gr.Textbox(
        label="Prompt",
        placeholder="Describe the character behavior (e.g., 'A friendly person explaining a concept')",
        lines=2
    ),
    gr.Textbox(
        label="Text to Speech",
        placeholder="Enter text to convert to speech",
        lines=3,
        info="Will use best available TTS system (Advanced or Fallback)"
    ),
    gr.Textbox(
        label="OR Audio URL",
        placeholder="https://example.com/audio.mp3",
        info="Direct URL to audio file (requires full models)" if _tts_only else "Direct URL to audio file"
    ),
    gr.Textbox(
        label="Image URL (Optional)",
        placeholder="https://example.com/image.jpg",
        info="Direct URL to reference image (requires full models)" if _tts_only else "Direct URL to reference image"
    ),
    gr.Dropdown(
        choices=_voice_choices,
        value="21m00Tcm4TlvDq8ikWAM",
        label="Voice Profile",
        info="Choose voice characteristics for TTS generation"
    ),
    gr.Slider(minimum=1, maximum=10, value=5.0, label="Guidance Scale", info="4-6 recommended"),
    gr.Slider(minimum=1, maximum=10, value=3.0, label="Audio Scale", info="Higher values = better lip-sync"),
    gr.Slider(minimum=10, maximum=100, value=30, step=1, label="Number of Steps", info="20-50 recommended")
]

# A video player when models are present, a plain text box otherwise.
_output_component = gr.Textbox(label="TTS Output") if _tts_only else gr.Video(label="Generated Avatar Video")

_example_rows = [
    [
        "A professional teacher explaining a mathematical concept with clear gestures",
        "Hello students! Today we're going to learn about calculus and derivatives.",
        "",
        "",
        "21m00Tcm4TlvDq8ikWAM",
        5.0,
        3.5,
        30
    ],
    [
        "A friendly presenter speaking confidently to an audience",
        "Welcome everyone to our presentation on artificial intelligence!",
        "",
        "",
        "pNInz6obpgDQGcFmaJgB",
        5.5,
        4.0,
        35
    ]
]

iface = gr.Interface(
    fn=gradio_generate,
    inputs=_input_components,
    outputs=_output_component,
    title="[VIDEO] OmniAvatar-14B - Avatar Video Generation with Adaptive Body Animation",
    description=f"""
Generate avatar videos with lip-sync from text prompts and speech using robust TTS system.
{description_extra}
**Robust TTS Architecture**
- **Primary**: Advanced TTS (Facebook VITS & SpeechT5) if available
- **Fallback**: Robust tone generation for 100% reliability
- **Automatic**: Seamless switching between methods
**Features:**
- **Guaranteed Generation**: Always produces audio output
- **No Dependencies**: Works even without advanced models
- **High Availability**: Multiple fallback layers
- **Voice Profiles**: Multiple voice characteristics
- **Audio URL Support**: Use external audio files {_models_note}
- **Image URL Support**: Reference images for characters {_models_note}
**Usage:**
1. Enter a character description in the prompt
2. **Enter text for speech generation** (recommended in current mode)
3. {_url_step}
4. Choose voice profile and adjust parameters
5. Generate your {_result_kind}!
""",
    examples=_example_rows,
    allow_flagging="never",
    flagging_dir="/tmp/gradio_flagged"
)
# Mount Gradio app
# The UI is served under /gradio on the same FastAPI application; `app` is
# rebound to the object returned by mount_gradio_app.
app = gr.mount_gradio_app(app, iface, path="/gradio")
# When executed directly, serve on all interfaces on port 7860.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
|