import os
import torch
import tempfile
import gradio as gr
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, HttpUrl
import subprocess
from pathlib import Path
import logging
from urllib.parse import urlparse
from typing import Optional
import aiohttp
import asyncio
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="OmniAvatar-14B API with ElevenLabs", version="1.0.0")

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Ensure the outputs directory exists, then mount it for serving generated videos
# (StaticFiles raises at startup if the directory is missing)
os.makedirs("outputs", exist_ok=True)
app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs")

def get_video_url(output_path: str) -> str:
    """Convert a local file path to a publicly accessible URL."""
    try:
        filename = Path(output_path).name
        # For HuggingFace Spaces, construct the URL from the Space's domain
        base_url = "https://bravedims-ai-avatar-chat.hf.space"
        video_url = f"{base_url}/outputs/{filename}"
        logger.info(f"Generated video URL: {video_url}")
        return video_url
    except Exception as e:
        logger.error(f"Error creating video URL: {e}")
        return output_path  # Fall back to the original path
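
# Example mapping (illustrative filename):
#   ./outputs/avatar_0001.mp4
#   -> https://bravedims-ai-avatar-chat.hf.space/outputs/avatar_0001.mp4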

# Pydantic models for request/response
class GenerateRequest(BaseModel):
    prompt: str
    text_to_speech: Optional[str] = None  # Text to convert to speech
    elevenlabs_audio_url: Optional[HttpUrl] = None  # Direct audio URL
    voice_id: Optional[str] = "21m00Tcm4TlvDq8ikWAM"  # Default ElevenLabs voice
    image_url: Optional[HttpUrl] = None
    guidance_scale: float = 5.0
    audio_scale: float = 3.0
    num_steps: int = 30
    sp_size: int = 1
    tea_cache_l1_thresh: Optional[float] = None

class GenerateResponse(BaseModel):
    message: str
    output_path: str
    processing_time: float
    audio_generated: bool = False
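
# A minimal JSON body matching GenerateRequest (values are illustrative):
# {
#   "prompt": "A friendly person explaining a concept",
#   "text_to_speech": "Hello and welcome!",
#   "voice_id": "21m00Tcm4TlvDq8ikWAM",
#   "guidance_scale": 5.0,
#   "audio_scale": 3.0,
#   "num_steps": 30
# }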

class ElevenLabsClient:
    def __init__(self, api_key: str = None):
        # Read the API key from the environment; never hardcode secrets in source
        self.api_key = api_key or os.getenv("ELEVENLABS_API_KEY")
        self.base_url = "https://api.elevenlabs.io/v1"

    async def text_to_speech(self, text: str, voice_id: str = "21m00Tcm4TlvDq8ikWAM") -> str:
        """Convert text to speech using ElevenLabs and return a temporary file path."""
        url = f"{self.base_url}/text-to-speech/{voice_id}"
        headers = {
            "Accept": "audio/mpeg",
            "Content-Type": "application/json",
            "xi-api-key": self.api_key
        }
        data = {
            "text": text,
            "model_id": "eleven_monolingual_v1",
            "voice_settings": {
                "stability": 0.5,
                "similarity_boost": 0.5
            }
        }

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(url, headers=headers, json=data) as response:
                    if response.status != 200:
                        error_text = await response.text()
                        raise HTTPException(
                            status_code=400,
                            detail=f"ElevenLabs API error: {response.status} - {error_text}"
                        )
                    audio_content = await response.read()

            # Save to a temporary file
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp3')
            temp_file.write(audio_content)
            temp_file.close()

            logger.info(f"Generated speech audio: {temp_file.name}")
            return temp_file.name
        except aiohttp.ClientError as e:
            logger.error(f"Network error calling ElevenLabs: {e}")
            raise HTTPException(status_code=400, detail=f"Network error calling ElevenLabs: {e}")
        except HTTPException:
            # Re-raise API errors as-is instead of wrapping them in a 500
            raise
        except Exception as e:
            logger.error(f"Error generating speech: {e}")
            raise HTTPException(status_code=500, detail=f"Error generating speech: {e}")
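
# Standalone usage sketch (assumes ELEVENLABS_API_KEY is set in the environment):
#
#   client = ElevenLabsClient()
#   audio_path = asyncio.run(client.text_to_speech("Hello!", "21m00Tcm4TlvDq8ikWAM"))
#   # audio_path now points to a temporary .mp3 file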

class OmniAvatarAPI:
    def __init__(self):
        self.model_loaded = False
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.elevenlabs_client = ElevenLabsClient()
        logger.info(f"Using device: {self.device}")
        logger.info(f"ElevenLabs API Key configured: {'Yes' if self.elevenlabs_client.api_key else 'No'}")

    def load_model(self):
        """Load the OmniAvatar model."""
        try:
            # Check that all required model checkpoints have been downloaded
            model_paths = [
                "./pretrained_models/Wan2.1-T2V-14B",
                "./pretrained_models/OmniAvatar-14B",
                "./pretrained_models/wav2vec2-base-960h"
            ]

            for path in model_paths:
                if not os.path.exists(path):
                    logger.error(f"Model path not found: {path}")
                    return False

            self.model_loaded = True
            logger.info("Models loaded successfully")
            return True
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            return False

    async def download_file(self, url: str, suffix: str = "") -> str:
        """Download a file from a URL and save it to a temporary location."""
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(str(url)) as response:
                    if response.status != 200:
                        raise HTTPException(status_code=400, detail=f"Failed to download file from URL: {url}")
                    content = await response.read()

            # Create a temporary file
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
            temp_file.write(content)
            temp_file.close()
            return temp_file.name
        except aiohttp.ClientError as e:
            logger.error(f"Network error downloading {url}: {e}")
            raise HTTPException(status_code=400, detail=f"Network error downloading file: {e}")
        except HTTPException:
            # Preserve the 400 raised above instead of converting it to a 500
            raise
        except Exception as e:
            logger.error(f"Error downloading file from {url}: {e}")
            raise HTTPException(status_code=500, detail=f"Error downloading file: {e}")

    def validate_audio_url(self, url: str) -> bool:
        """Heuristically check whether a URL points to an audio file."""
        try:
            parsed = urlparse(url)
            # Check for common audio file extensions or ElevenLabs patterns
            audio_extensions = ['.mp3', '.wav', '.m4a', '.ogg', '.aac']
            is_audio_ext = any(parsed.path.lower().endswith(ext) for ext in audio_extensions)
            is_elevenlabs = 'elevenlabs' in parsed.netloc.lower()
            return is_audio_ext or is_elevenlabs or 'audio' in url.lower()
        except Exception:
            return False

    def validate_image_url(self, url: str) -> bool:
        """Heuristically check whether a URL points to an image file."""
        try:
            parsed = urlparse(url)
            image_extensions = ['.jpg', '.jpeg', '.png', '.webp', '.bmp', '.gif']
            return any(parsed.path.lower().endswith(ext) for ext in image_extensions)
        except Exception:
            return False
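
    # Examples of the heuristics above (illustrative URLs):
    #   validate_audio_url("https://cdn.example.com/speech.mp3")  -> True
    #   validate_audio_url("https://example.com/page.html")       -> False
    #   validate_image_url("https://example.com/avatar.png")      -> True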

    async def generate_avatar(self, request: GenerateRequest) -> tuple[str, float, bool]:
        """Generate an avatar video from a prompt and audio/text."""
        import time
        start_time = time.time()
        audio_generated = False

        try:
            # Determine the audio source
            audio_path = None
            if request.text_to_speech:
                # Generate speech from text using ElevenLabs
                logger.info(f"Generating speech from text: {request.text_to_speech[:50]}...")
                audio_path = await self.elevenlabs_client.text_to_speech(
                    request.text_to_speech,
                    request.voice_id or "21m00Tcm4TlvDq8ikWAM"
                )
                audio_generated = True
            elif request.elevenlabs_audio_url:
                # Download audio from the provided URL
                logger.info(f"Downloading audio from URL: {request.elevenlabs_audio_url}")
                if not self.validate_audio_url(str(request.elevenlabs_audio_url)):
                    logger.warning(f"Audio URL may not be valid: {request.elevenlabs_audio_url}")
                audio_path = await self.download_file(str(request.elevenlabs_audio_url), ".mp3")
            else:
                raise HTTPException(
                    status_code=400,
                    detail="Either text_to_speech or elevenlabs_audio_url must be provided"
                )

            # Download the reference image if provided
            image_path = None
            if request.image_url:
                logger.info(f"Downloading image from URL: {request.image_url}")
                if not self.validate_image_url(str(request.image_url)):
                    logger.warning(f"Image URL may not be valid: {request.image_url}")
                # Determine the image extension from the URL, defaulting to .jpg
                parsed = urlparse(str(request.image_url))
                ext = os.path.splitext(parsed.path)[1] or ".jpg"
                image_path = await self.download_file(str(request.image_url), ext)

            # Create a temporary input file for inference
            # (format: prompt@@image_path@@audio_path, image field left empty if absent)
            with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
                if image_path:
                    input_line = f"{request.prompt}@@{image_path}@@{audio_path}"
                else:
                    input_line = f"{request.prompt}@@@@{audio_path}"
                f.write(input_line)
                temp_input_file = f.name

            # Prepare the inference command
            cmd = [
                "python", "-m", "torch.distributed.run",
                "--standalone", f"--nproc_per_node={request.sp_size}",
                "scripts/inference.py",
                "--config", "configs/inference.yaml",
                "--input_file", temp_input_file,
                "--guidance_scale", str(request.guidance_scale),
                "--audio_scale", str(request.audio_scale),
                "--num_steps", str(request.num_steps)
            ]

            # Compare against None explicitly so a threshold of 0.0 is not silently dropped
            if request.tea_cache_l1_thresh is not None:
                cmd.extend(["--tea_cache_l1_thresh", str(request.tea_cache_l1_thresh)])

            logger.info(f"Running inference with command: {' '.join(cmd)}")

            # Run inference
            result = subprocess.run(cmd, capture_output=True, text=True)

            # Clean up temporary files
            os.unlink(temp_input_file)
            os.unlink(audio_path)
            if image_path:
                os.unlink(image_path)

            if result.returncode != 0:
                logger.error(f"Inference failed: {result.stderr}")
                raise Exception(f"Inference failed: {result.stderr}")

            # Find the most recent output video file
            output_dir = "./outputs"
            if os.path.exists(output_dir):
                video_files = [f for f in os.listdir(output_dir) if f.endswith(('.mp4', '.avi'))]
                if video_files:
                    video_files.sort(key=lambda x: os.path.getmtime(os.path.join(output_dir, x)), reverse=True)
                    output_path = os.path.join(output_dir, video_files[0])
                    processing_time = time.time() - start_time
                    return output_path, processing_time, audio_generated

            raise Exception("No output video generated")

        except Exception as e:
            # Clean up any temporary files left behind by the failure
            try:
                if 'audio_path' in locals() and audio_path and os.path.exists(audio_path):
                    os.unlink(audio_path)
                if 'image_path' in locals() and image_path and os.path.exists(image_path):
                    os.unlink(image_path)
                if 'temp_input_file' in locals() and os.path.exists(temp_input_file):
                    os.unlink(temp_input_file)
            except Exception:
                pass

            logger.error(f"Generation error: {str(e)}")
            # Preserve HTTP status codes raised above (e.g. the 400 for a missing audio source)
            if isinstance(e, HTTPException):
                raise
            raise HTTPException(status_code=500, detail=str(e))

# Initialize API
omni_api = OmniAvatarAPI()

@app.on_event("startup")
async def startup_event():
    """Load the model on startup."""
    success = omni_api.load_model()
    if not success:
        logger.warning("Model loading failed on startup")

@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {
        "status": "healthy",
        "model_loaded": omni_api.model_loaded,
        "device": omni_api.device,
        "supports_elevenlabs": True,
        "supports_image_urls": True,
        "supports_text_to_speech": True,
        "elevenlabs_api_configured": bool(omni_api.elevenlabs_client.api_key)
    }

@app.post("/generate", response_model=GenerateResponse)
async def generate_avatar(request: GenerateRequest):
    """Generate an avatar video from a prompt, text/audio, and an optional image URL."""
    if not omni_api.model_loaded:
        raise HTTPException(status_code=503, detail="Model not loaded")

    logger.info(f"Generating avatar with prompt: {request.prompt}")
    if request.text_to_speech:
        logger.info(f"Text to speech: {request.text_to_speech[:100]}...")
        logger.info(f"Voice ID: {request.voice_id}")
    if request.elevenlabs_audio_url:
        logger.info(f"Audio URL: {request.elevenlabs_audio_url}")
    if request.image_url:
        logger.info(f"Image URL: {request.image_url}")

    try:
        output_path, processing_time, audio_generated = await omni_api.generate_avatar(request)
        return GenerateResponse(
            message="Avatar generation completed successfully",
            output_path=get_video_url(output_path),
            processing_time=processing_time,
            audio_generated=audio_generated
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=f"Unexpected error: {e}")
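
# Example calls (sketch; paths and host follow the decorators and Space URL used above):
#
#   curl https://bravedims-ai-avatar-chat.hf.space/health
#
#   curl -X POST https://bravedims-ai-avatar-chat.hf.space/generate \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "A friendly presenter", "text_to_speech": "Welcome everyone!"}'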

# Enhanced Gradio interface with text-to-speech option
def gradio_generate(prompt, text_to_speech, audio_url, image_url, voice_id, guidance_scale, audio_scale, num_steps):
    """Gradio interface wrapper with text-to-speech support."""
    if not omni_api.model_loaded:
        return "Error: Model not loaded"

    try:
        # Build the request object
        request_data = {
            "prompt": prompt,
            "guidance_scale": guidance_scale,
            "audio_scale": audio_scale,
            "num_steps": int(num_steps)
        }

        # Add the audio source
        if text_to_speech and text_to_speech.strip():
            request_data["text_to_speech"] = text_to_speech
            request_data["voice_id"] = voice_id or "21m00Tcm4TlvDq8ikWAM"
        elif audio_url and audio_url.strip():
            request_data["elevenlabs_audio_url"] = audio_url
        else:
            return "Error: Please provide either text to speech or an audio URL"

        if image_url and image_url.strip():
            request_data["image_url"] = image_url

        request = GenerateRequest(**request_data)

        # Run the async generation in this sync context
        output_path, processing_time, audio_generated = asyncio.run(omni_api.generate_avatar(request))
        return output_path
    except Exception as e:
        logger.error(f"Gradio generation error: {e}")
        return f"Error: {str(e)}"

# Updated Gradio interface with text-to-speech support
iface = gr.Interface(
    fn=gradio_generate,
    inputs=[
        gr.Textbox(
            label="Prompt",
            placeholder="Describe the character behavior (e.g., 'A friendly person explaining a concept')",
            lines=2
        ),
        gr.Textbox(
            label="Text to Speech",
            placeholder="Enter text to convert to speech using ElevenLabs",
            lines=3,
            info="This will be converted to speech automatically"
        ),
        gr.Textbox(
            label="OR Audio URL",
            placeholder="https://api.elevenlabs.io/v1/text-to-speech/...",
            info="Direct URL to audio file (alternative to text-to-speech)"
        ),
        gr.Textbox(
            label="Image URL (Optional)",
            placeholder="https://example.com/image.jpg",
            info="Direct URL to reference image (JPG, PNG, etc.)"
        ),
        gr.Dropdown(
            choices=["21m00Tcm4TlvDq8ikWAM", "pNInz6obpgDQGcFmaJgB", "EXAVITQu4vr4xnSDxMaL"],
            value="21m00Tcm4TlvDq8ikWAM",
            label="ElevenLabs Voice ID",
            info="Choose voice for text-to-speech"
        ),
        gr.Slider(minimum=1, maximum=10, value=5.0, label="Guidance Scale", info="4-6 recommended"),
        gr.Slider(minimum=1, maximum=10, value=3.0, label="Audio Scale", info="Higher values = better lip-sync"),
        gr.Slider(minimum=10, maximum=100, value=30, step=1, label="Number of Steps", info="20-50 recommended")
    ],
    outputs=gr.Video(label="Generated Avatar Video"),
    title="OmniAvatar-14B with ElevenLabs TTS",
    description="""
    Generate avatar videos with lip-sync from text prompts and speech.

    **Features:**
    - ✅ **Text-to-Speech**: Enter text to generate speech automatically
    - ✅ **ElevenLabs Integration**: High-quality voice synthesis
    - ✅ **Audio URL Support**: Use pre-generated audio files
    - ✅ **Image URL Support**: Reference images for character appearance
    - ✅ **Customizable Parameters**: Fine-tune generation quality

    **Usage:**
    1. Enter a character description in the prompt
    2. **Either** enter text for speech generation **OR** provide an audio URL
    3. Optionally add a reference image URL
    4. Choose voice and adjust parameters
    5. Generate your avatar video!

    **Tips:**
    - Use guidance scale 4-6 for best prompt following
    - Increase audio scale for better lip-sync
    - Clear, descriptive prompts work best
    """,
    examples=[
        [
            "A professional teacher explaining a mathematical concept with clear gestures",
            "Hello students! Today we're going to learn about calculus and how derivatives work in real life.",
            "",
            "https://example.com/teacher.jpg",
            "21m00Tcm4TlvDq8ikWAM",
            5.0,
            3.5,
            30
        ],
        [
            "A friendly presenter speaking confidently to an audience",
            "Welcome everyone to our presentation on artificial intelligence and its applications!",
            "",
            "",
            "pNInz6obpgDQGcFmaJgB",
            5.5,
            4.0,
            35
        ]
    ]
)

# Mount the Gradio app on the FastAPI server
app = gr.mount_gradio_app(app, iface, path="/gradio")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)