import os
import torch
import tempfile
import gradio as gr
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, HttpUrl
import subprocess
import json
from pathlib import Path
import logging
import requests
from urllib.parse import urlparse
from PIL import Image
import io
from typing import Optional
import aiohttp
import asyncio
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="OmniAvatar-14B API with ElevenLabs", version="1.0.0")

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount static files for serving generated videos
os.makedirs("outputs", exist_ok=True)  # StaticFiles raises at startup if the directory is missing
app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs")

def get_video_url(output_path: str) -> str:
    """Convert local file path to accessible URL"""
    try:
        filename = Path(output_path).name
        
        # For HuggingFace Spaces, construct the public URL
        # (BASE_URL is an optional env override for other deployments)
        base_url = os.getenv("BASE_URL", "https://bravedims-ai-avatar-chat.hf.space")
        video_url = f"{base_url}/outputs/{filename}"
        logger.info(f"Generated video URL: {video_url}")
        return video_url
    except Exception as e:
        logger.error(f"Error creating video URL: {e}")
        return output_path  # Fallback to original path
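
# Illustrative example (filename is hypothetical):
#   get_video_url("./outputs/avatar_001.mp4")
#   -> "https://bravedims-ai-avatar-chat.hf.space/outputs/avatar_001.mp4"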

# Pydantic models for request/response
class GenerateRequest(BaseModel):
    prompt: str
    text_to_speech: Optional[str] = None  # Text to convert to speech
    elevenlabs_audio_url: Optional[HttpUrl] = None  # Direct audio URL
    voice_id: Optional[str] = "21m00Tcm4TlvDq8ikWAM"  # Default ElevenLabs voice
    image_url: Optional[HttpUrl] = None
    guidance_scale: float = 5.0
    audio_scale: float = 3.0
    num_steps: int = 30
    sp_size: int = 1
    tea_cache_l1_thresh: Optional[float] = None

class GenerateResponse(BaseModel):
    message: str
    output_path: str
    processing_time: float
    audio_generated: bool = False
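
# Example request body for POST /generate (illustrative values only):
# {
#     "prompt": "A friendly person explaining a concept",
#     "text_to_speech": "Hello! Welcome to the demo.",
#     "voice_id": "21m00Tcm4TlvDq8ikWAM",
#     "guidance_scale": 5.0,
#     "audio_scale": 3.0,
#     "num_steps": 30
# }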

# Import the robust TTS client as fallback
from robust_tts_client import RobustTTSClient

class ElevenLabsClient:
    def __init__(self, api_key: str = None):
        # Read the key from the environment; never hard-code secrets in source
        self.api_key = api_key or os.getenv("ELEVENLABS_API_KEY")
        self.base_url = "https://api.elevenlabs.io/v1"
        # Initialize fallback TTS client
        self.fallback_tts = RobustTTSClient()
        
    async def text_to_speech(self, text: str, voice_id: str = "21m00Tcm4TlvDq8ikWAM") -> str:
        """Convert text to speech using ElevenLabs with fallback to robust TTS"""
        logger.info(f"Generating speech from text: {text[:50]}...")
        logger.info(f"Voice ID: {voice_id}")
        
        # Try ElevenLabs first
        try:
            return await self._elevenlabs_tts(text, voice_id)
        except Exception as e:
            logger.warning(f"ElevenLabs TTS failed: {e}")
            logger.info("Falling back to robust TTS client...")
            try:
                return await self.fallback_tts.text_to_speech(text, voice_id)
            except Exception as fallback_error:
                logger.error(f"Fallback TTS also failed: {fallback_error}")
                raise HTTPException(status_code=500, detail=f"All TTS methods failed. ElevenLabs: {e}, Fallback: {fallback_error}")
    
    async def _elevenlabs_tts(self, text: str, voice_id: str) -> str:
        """Internal method for ElevenLabs API call"""
        url = f"{self.base_url}/text-to-speech/{voice_id}"
        
        headers = {
            "Accept": "audio/mpeg",
            "Content-Type": "application/json",
            "xi-api-key": self.api_key
        }
        
        data = {
            "text": text,
            "model_id": "eleven_monolingual_v1",
            "voice_settings": {
                "stability": 0.5,
                "similarity_boost": 0.5
            }
        }
        
        logger.info(f"Calling ElevenLabs API: {url}")
        logger.info(f"API Key configured: {'Yes' if self.api_key else 'No'}")
        
        timeout = aiohttp.ClientTimeout(total=30)  # 30 second timeout
        
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(url, headers=headers, json=data) as response:
                logger.info(f"ElevenLabs response status: {response.status}")
                
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"ElevenLabs API error: {response.status} - {error_text}")
                    
                    if response.status == 401:
                        raise Exception("ElevenLabs authentication failed. Please check the API key.")
                    elif response.status == 429:
                        raise Exception("ElevenLabs rate limit exceeded. Please try again later.")
                    elif response.status == 422:
                        raise Exception(f"ElevenLabs request validation failed: {error_text}")
                    else:
                        raise Exception(f"ElevenLabs API error: {response.status} - {error_text}")
                
                audio_content = await response.read()
                
                if not audio_content:
                    raise Exception("ElevenLabs returned empty audio content")
                
                logger.info(f"Received {len(audio_content)} bytes of audio from ElevenLabs")
                
                # Save to temporary file
                temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp3')
                temp_file.write(audio_content)
                temp_file.close()
                
                logger.info(f"Generated speech audio: {temp_file.name}")
                return temp_file.name
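
# Illustrative usage of the client above (assumes a valid ELEVENLABS_API_KEY
# in the environment and an async caller):
#   client = ElevenLabsClient()
#   mp3_path = await client.text_to_speech("Hello there!")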

class OmniAvatarAPI:
    def __init__(self):
        self.model_loaded = False
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.elevenlabs_client = ElevenLabsClient()
        logger.info(f"Using device: {self.device}")
        logger.info(f"ElevenLabs API Key configured: {'Yes' if self.elevenlabs_client.api_key else 'No'}")
        
    def load_model(self):
        """Load the OmniAvatar model"""
        try:
            # Check if models are downloaded
            model_paths = [
                "./pretrained_models/Wan2.1-T2V-14B",
                "./pretrained_models/OmniAvatar-14B", 
                "./pretrained_models/wav2vec2-base-960h"
            ]
            
            for path in model_paths:
                if not os.path.exists(path):
                    logger.error(f"Model path not found: {path}")
                    return False
                    
            self.model_loaded = True
            logger.info("Models loaded successfully")
            return True
            
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            return False
    
    async def download_file(self, url: str, suffix: str = "") -> str:
        """Download file from URL and save to temporary location"""
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(str(url)) as response:
                    if response.status != 200:
                        raise HTTPException(status_code=400, detail=f"Failed to download file from URL: {url}")
                    
                    content = await response.read()
                    
                    # Create temporary file
                    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
                    temp_file.write(content)
                    temp_file.close()
                    
                    return temp_file.name
                    
        except aiohttp.ClientError as e:
            logger.error(f"Network error downloading {url}: {e}")
            raise HTTPException(status_code=400, detail=f"Network error downloading file: {e}")
        except Exception as e:
            logger.error(f"Error downloading file from {url}: {e}")
            raise HTTPException(status_code=500, detail=f"Error downloading file: {e}")
    
    def validate_audio_url(self, url: str) -> bool:
        """Validate if URL is likely an audio file"""
        try:
            parsed = urlparse(url)
            # Check for common audio file extensions or ElevenLabs patterns
            audio_extensions = ['.mp3', '.wav', '.m4a', '.ogg', '.aac']
            is_audio_ext = any(parsed.path.lower().endswith(ext) for ext in audio_extensions)
            is_elevenlabs = 'elevenlabs' in parsed.netloc.lower()
            
            return is_audio_ext or is_elevenlabs or 'audio' in url.lower()
        except Exception:
            return False
    
    def validate_image_url(self, url: str) -> bool:
        """Validate if URL is likely an image file"""
        try:
            parsed = urlparse(url)
            image_extensions = ['.jpg', '.jpeg', '.png', '.webp', '.bmp', '.gif']
            return any(parsed.path.lower().endswith(ext) for ext in image_extensions)
        except Exception:
            return False
    
    async def generate_avatar(self, request: GenerateRequest) -> tuple[str, float, bool]:
        """Generate avatar video from prompt and audio/text"""
        import time
        start_time = time.time()
        audio_generated = False
        
        try:
            # Determine audio source
            audio_path = None
            
            if request.text_to_speech:
                # Generate speech from text using ElevenLabs
                logger.info(f"Generating speech from text: {request.text_to_speech[:50]}...")
                audio_path = await self.elevenlabs_client.text_to_speech(
                    request.text_to_speech, 
                    request.voice_id or "21m00Tcm4TlvDq8ikWAM"
                )
                audio_generated = True
                
            elif request.elevenlabs_audio_url:
                # Download audio from provided URL
                logger.info(f"Downloading audio from URL: {request.elevenlabs_audio_url}")
                if not self.validate_audio_url(str(request.elevenlabs_audio_url)):
                    logger.warning(f"Audio URL may not be valid: {request.elevenlabs_audio_url}")
                
                audio_path = await self.download_file(str(request.elevenlabs_audio_url), ".mp3")
            
            else:
                raise HTTPException(
                    status_code=400, 
                    detail="Either text_to_speech or elevenlabs_audio_url must be provided"
                )
            
            # Download image if provided
            image_path = None
            if request.image_url:
                logger.info(f"Downloading image from URL: {request.image_url}")
                if not self.validate_image_url(str(request.image_url)):
                    logger.warning(f"Image URL may not be valid: {request.image_url}")
                
                # Determine image extension from URL or default to .jpg
                parsed = urlparse(str(request.image_url))
                ext = os.path.splitext(parsed.path)[1] or ".jpg"
                image_path = await self.download_file(str(request.image_url), ext)
            
            # Create temporary input file for inference
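            # Each line uses the "<prompt>@@<image_path>@@<audio_path>" format
            # consumed by scripts/inference.py; the image slot is left empty
            # ("@@@@") when no reference image is provided.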
            with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
                if image_path:
                    input_line = f"{request.prompt}@@{image_path}@@{audio_path}"
                else:
                    input_line = f"{request.prompt}@@@@{audio_path}"
                f.write(input_line)
                temp_input_file = f.name
            
            # Prepare inference command
            cmd = [
                "python", "-m", "torch.distributed.run",
                "--standalone", f"--nproc_per_node={request.sp_size}",
                "scripts/inference.py",
                "--config", "configs/inference.yaml",
                "--input_file", temp_input_file,
                "--guidance_scale", str(request.guidance_scale),
                "--audio_scale", str(request.audio_scale),
                "--num_steps", str(request.num_steps)
            ]
            
            # Explicit None check so a threshold of 0.0 is still passed through
            if request.tea_cache_l1_thresh is not None:
                cmd.extend(["--tea_cache_l1_thresh", str(request.tea_cache_l1_thresh)])
            
            logger.info(f"Running inference with command: {' '.join(cmd)}")
            
            # Run inference
            result = subprocess.run(cmd, capture_output=True, text=True)
            
            # Clean up temporary files
            os.unlink(temp_input_file)
            os.unlink(audio_path)
            if image_path:
                os.unlink(image_path)
            
            if result.returncode != 0:
                logger.error(f"Inference failed: {result.stderr}")
                raise Exception(f"Inference failed: {result.stderr}")
            
            # Find output video file
            output_dir = "./outputs"
            if os.path.exists(output_dir):
                video_files = [f for f in os.listdir(output_dir) if f.endswith(('.mp4', '.avi'))]
                if video_files:
                    # Return the most recent video file
                    video_files.sort(key=lambda x: os.path.getmtime(os.path.join(output_dir, x)), reverse=True)
                    output_path = os.path.join(output_dir, video_files[0])
                    processing_time = time.time() - start_time
                    return output_path, processing_time, audio_generated
            
            raise Exception("No output video generated")
            
        except Exception as e:
            # Clean up any temporary files in case of error
            try:
                if 'audio_path' in locals() and audio_path and os.path.exists(audio_path):
                    os.unlink(audio_path)
                if 'image_path' in locals() and image_path and os.path.exists(image_path):
                    os.unlink(image_path)
                if 'temp_input_file' in locals() and os.path.exists(temp_input_file):
                    os.unlink(temp_input_file)
            except Exception:
                pass
            
            logger.error(f"Generation error: {str(e)}")
            raise HTTPException(status_code=500, detail=str(e))

# Initialize API
omni_api = OmniAvatarAPI()

@app.on_event("startup")
async def startup_event():
    """Load model on startup"""
    success = omni_api.load_model()
    if not success:
        logger.warning("Model loading failed on startup")

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "model_loaded": omni_api.model_loaded,
        "device": omni_api.device,
        "supports_elevenlabs": True,
        "supports_image_urls": True,
        "supports_text_to_speech": True,
        "elevenlabs_api_configured": bool(omni_api.elevenlabs_client.api_key),
        "fallback_tts_available": True
    }
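
# Quick check from the command line (assuming the default port configured below):
#   curl http://localhost:7860/health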

@app.post("/generate", response_model=GenerateResponse)
async def generate_avatar(request: GenerateRequest):
    """Generate avatar video from prompt, text/audio, and optional image URL"""
    
    if not omni_api.model_loaded:
        raise HTTPException(status_code=503, detail="Model not loaded")
    
    logger.info(f"Generating avatar with prompt: {request.prompt}")
    if request.text_to_speech:
        logger.info(f"Text to speech: {request.text_to_speech[:100]}...")
        logger.info(f"Voice ID: {request.voice_id}")
    if request.elevenlabs_audio_url:
        logger.info(f"Audio URL: {request.elevenlabs_audio_url}")
    if request.image_url:
        logger.info(f"Image URL: {request.image_url}")
    
    try:
        output_path, processing_time, audio_generated = await omni_api.generate_avatar(request)
        
        return GenerateResponse(
            message="Avatar generation completed successfully",
            output_path=get_video_url(output_path),
            processing_time=processing_time,
            audio_generated=audio_generated
        )
        
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=f"Unexpected error: {e}")
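
# Example request (hypothetical values; adjust host/port to your deployment):
#   curl -X POST http://localhost:7860/generate \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "A friendly presenter", "text_to_speech": "Hello everyone!"}'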

# Enhanced Gradio interface with text-to-speech option
def gradio_generate(prompt, text_to_speech, audio_url, image_url, voice_id, guidance_scale, audio_scale, num_steps):
    """Gradio interface wrapper with text-to-speech support"""
    if not omni_api.model_loaded:
        return "Error: Model not loaded"
    
    try:
        # Create request object
        request_data = {
            "prompt": prompt,
            "guidance_scale": guidance_scale,
            "audio_scale": audio_scale,
            "num_steps": int(num_steps)
        }
        
        # Add audio source
        if text_to_speech and text_to_speech.strip():
            request_data["text_to_speech"] = text_to_speech
            request_data["voice_id"] = voice_id or "21m00Tcm4TlvDq8ikWAM"
        elif audio_url and audio_url.strip():
            request_data["elevenlabs_audio_url"] = audio_url
        else:
            return "Error: Please provide either text to speech or audio URL"
        
        if image_url and image_url.strip():
            request_data["image_url"] = image_url
        
        request = GenerateRequest(**request_data)
        
        # Run the async generation from Gradio's synchronous context
        output_path, processing_time, audio_generated = asyncio.run(
            omni_api.generate_avatar(request)
        )
        
        return output_path
        
    except Exception as e:
        logger.error(f"Gradio generation error: {e}")
        return f"Error: {str(e)}"

# Updated Gradio interface with text-to-speech support
iface = gr.Interface(
    fn=gradio_generate,
    inputs=[
        gr.Textbox(
            label="Prompt", 
            placeholder="Describe the character behavior (e.g., 'A friendly person explaining a concept')",
            lines=2
        ),
        gr.Textbox(
            label="Text to Speech", 
            placeholder="Enter text to convert to speech using ElevenLabs",
            lines=3,
            info="This will be converted to speech automatically"
        ),
        gr.Textbox(
            label="OR Audio URL", 
            placeholder="https://api.elevenlabs.io/v1/text-to-speech/...",
            info="Direct URL to audio file (alternative to text-to-speech)"
        ),
        gr.Textbox(
            label="Image URL (Optional)", 
            placeholder="https://example.com/image.jpg",
            info="Direct URL to reference image (JPG, PNG, etc.)"
        ),
        gr.Dropdown(
            choices=["21m00Tcm4TlvDq8ikWAM", "pNInz6obpgDQGcFmaJgB", "EXAVITQu4vr4xnSDxMaL"],
            value="21m00Tcm4TlvDq8ikWAM",
            label="ElevenLabs Voice ID",
            info="Choose voice for text-to-speech"
        ),
        gr.Slider(minimum=1, maximum=10, value=5.0, label="Guidance Scale", info="4-6 recommended"),
        gr.Slider(minimum=1, maximum=10, value=3.0, label="Audio Scale", info="Higher values = better lip-sync"),
        gr.Slider(minimum=10, maximum=100, value=30, step=1, label="Number of Steps", info="20-50 recommended")
    ],
    outputs=gr.Video(label="Generated Avatar Video"),
    title="🎭 OmniAvatar-14B with ElevenLabs TTS (+ Fallback)",
    description="""
    Generate avatar videos with lip-sync from text prompts and speech.
    
    **Features:**
    - ✅ **Text-to-Speech**: Enter text to generate speech automatically
    - ✅ **ElevenLabs Integration**: High-quality voice synthesis
    - ✅ **Fallback TTS**: Robust backup system if ElevenLabs fails
    - ✅ **Audio URL Support**: Use pre-generated audio files
    - ✅ **Image URL Support**: Reference images for character appearance
    - ✅ **Customizable Parameters**: Fine-tune generation quality
    
    **Usage:**
    1. Enter a character description in the prompt
    2. **Either** enter text for speech generation **OR** provide an audio URL
    3. Optionally add a reference image URL
    4. Choose voice and adjust parameters
    5. Generate your avatar video!
    
    **Tips:**
    - Use guidance scale 4-6 for best prompt following
    - Increase audio scale for better lip-sync
    - Clear, descriptive prompts work best
    - If ElevenLabs fails, fallback TTS will be used automatically
    """,
    examples=[
        [
            "A professional teacher explaining a mathematical concept with clear gestures",
            "Hello students! Today we're going to learn about calculus and how derivatives work in real life.",
            "",
            "",
            "21m00Tcm4TlvDq8ikWAM",
            5.0,
            3.5,
            30
        ],
        [
            "A friendly presenter speaking confidently to an audience",
            "Welcome everyone to our presentation on artificial intelligence and its applications!",
            "",
            "",
            "pNInz6obpgDQGcFmaJgB", 
            5.5,
            4.0,
            35
        ]
    ]
)

# Mount Gradio app
app = gr.mount_gradio_app(app, iface, path="/gradio")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)