#!/usr/bin/env python3
"""
HuggingFace Space Demo for TextSyncMimi Speech Editing with Token-Level Embedding Swapping

This demo loads the model from the HuggingFace Hub and allows:
- Generating speech with different voices using OpenAI TTS
- Swapping speech embeddings at specific token positions
- Real-time speech editing

Prerequisites:
- Set OPENAI_API_KEY in the Space secrets
- The model is loaded from the HuggingFace Hub
"""

import os
import json
import tempfile
import argparse
from typing import List, Tuple, Optional
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import soundfile as sf
import gradio as gr
from openai import OpenAI
from transformers import (
    AutoModel,
    AutoFeatureExtractor,
    AutoTokenizer,
    MimiModel,
)

# Import spaces for GPU support
try:
    import spaces

    GPU_AVAILABLE = True
except ImportError:
    GPU_AVAILABLE = False

    # Stand-in for the `spaces` package so `@spaces.GPU` still works outside a
    # HuggingFace Space.
    class spaces:
        @staticmethod
        def GPU(func=None, **kwargs):
            """No-op replacement for `spaces.GPU`.

            Supports both the bare `@spaces.GPU` form and the parameterized
            `@spaces.GPU(...)` form; keyword arguments are ignored.
            """
            if func is None:
                return lambda f: f
            return func


# Constants
SAMPLE_RATE = 24000
FRAME_RATE = 12.5
# NOTE(review): confirm every voice here is accepted by "tts-1" (used when no
# style instructions are given); some voices may be gpt-4o-mini-tts-only.
TTS_VOICES = ["alloy", "ash", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer", "verse"]
# Upper bound on generated speech latents per utterance. The counter in
# ar_generate_and_decode is not reset per text token, so this caps the TOTAL
# output length (50 / FRAME_RATE = 4 s, assuming one latent per Mimi frame -- TODO confirm).
MAX_Z_TOKENS = 50
# Generation stops once the end-token classifier's sigmoid probability reaches this value.
END_TOKEN_THRESHOLD = 0.5

# Number of values each Gradio callback returns (must match the `outputs=` lists).
_PROCESS_NUM_OUTPUTS = 8
_SWAP_NUM_OUTPUTS = 2

# Global variables, populated by initialize_models()
model = None
mimi_model = None
tokenizer = None
feature_extractor = None
device = None
openai_client = None


def load_audio_to_inputs(feature_extractor, audio_path: str, sample_rate: int) -> torch.Tensor:
    """Load `audio_path` as mono audio resampled to `sample_rate` and return the
    feature extractor's `input_values` tensor."""
    import librosa  # local import: only needed when audio is actually loaded

    audio, _ = librosa.load(audio_path, sr=sample_rate, mono=True)
    audio_inputs = feature_extractor(raw_audio=audio, return_tensors="pt", sampling_rate=sample_rate)
    return audio_inputs.input_values


def initialize_models(model_id: str, tokenizer_id: str = "meta-llama/Llama-3.1-8B-Instruct", hf_token: Optional[str] = None):
    """Load TextSyncMimi, Mimi, the tokenizer, the feature extractor and the
    OpenAI client into the module-level globals.

    Args:
        model_id: Hub id of the TextSyncMimi model (loaded with trust_remote_code).
        tokenizer_id: Hub id of the text tokenizer.
        hf_token: Optional HuggingFace access token for gated/private repos.
    """
    global model, mimi_model, tokenizer, feature_extractor, device, openai_client

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    print(f"Loading TextSyncMimi model from {model_id}...")
    model = AutoModel.from_pretrained(
        model_id,
        trust_remote_code=True,
        token=hf_token,
    )
    model.to(device)
    model.eval()

    # The TextSyncMimi config may name the Mimi checkpoint it was trained against.
    mimi_model_id = getattr(model.config, "mimi_model_id", "kyutai/mimi")

    print("Loading Mimi model...")
    mimi_model = MimiModel.from_pretrained(mimi_model_id, token=hf_token)
    mimi_model.to(device)
    mimi_model.eval()

    print(f"Loading tokenizer from {tokenizer_id}...")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, token=hf_token)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("Loading feature extractor...")
    feature_extractor = AutoFeatureExtractor.from_pretrained(mimi_model_id, token=hf_token)

    print("Initializing OpenAI client...")
    openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    print("✅ All models loaded successfully!")


@torch.no_grad()
def compute_cross_attention_s(
    model,
    text_embeddings: torch.Tensor,
    input_values: torch.Tensor,
    device: str,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute projected text embeddings and cross-attended speech embeddings.

    Args:
        model: TextSyncMimi model.
        text_embeddings: Token embeddings; dim 0 is batch, dim 1 is the text
            sequence length (presumably (batch, seq, dim) -- TODO confirm).
        input_values: Feature-extractor audio inputs; the last dim is samples.
        device: Torch device string.

    Returns:
        (text_proj, cross_out, text_attention_mask): the projected text, the
        output of the cross-attention transformer (one vector per text token),
        and the all-ones boolean text mask that was used.
    """
    # No padding: every audio sample and text token is attended to.
    audio_attention_mask = torch.ones(input_values.shape[0], input_values.shape[-1], dtype=torch.bool, device=device)
    text_attention_mask = torch.ones(text_embeddings.shape[0], text_embeddings.shape[1], dtype=torch.bool, device=device)

    # Encode speech, then move the time axis to dim 1 for attention.
    speech_embeddings = model.encode_audio_to_representation(
        input_values.to(device),
        audio_attention_mask=audio_attention_mask,
    ).transpose(1, 2)

    # Project text
    text_proj = model.text_proj(text_embeddings.to(device))

    # Additive text self-attention mask: causal AND not padding -> 0, else -inf.
    batch_size, text_seq_len = text_proj.shape[:2]
    causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len, device=device, dtype=text_proj.dtype))
    causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1)
    pad_mask = text_attention_mask.view(batch_size, 1, 1, text_seq_len)
    formatted_text_attention_mask = torch.where((causal_mask * pad_mask).bool(), 0.0, float("-inf"))

    # Additive speech (encoder) mask: every speech frame is visible.
    speech_seq_len = speech_embeddings.shape[1]
    speech_mask = torch.ones(batch_size, speech_seq_len, dtype=torch.bool, device=device)
    formatted_speech_attention_mask = torch.where(
        speech_mask.view(batch_size, 1, 1, speech_seq_len), 0.0, float("-inf")
    )

    # Cross attention: text queries attend over the speech frames.
    cross_out = model.cross_attention_transformer(
        hidden_states=text_proj,
        encoder_hidden_states=speech_embeddings,
        attention_mask=formatted_text_attention_mask,
        encoder_attention_mask=formatted_speech_attention_mask,
        alignment_chunk_sizes=None,
    ).last_hidden_state

    return text_proj, cross_out, text_attention_mask


@torch.no_grad()
def ar_generate_and_decode(
    model,
    mimi_model,
    text_proj: torch.Tensor,
    s_tokens: torch.Tensor,
    text_attention_mask: Optional[torch.Tensor],
    max_z_tokens: int,
    end_token_threshold: float,
    device: str,
) -> np.ndarray:
    """Autoregressively generate speech latents and decode them to a waveform.

    The AR prefix is [latent-marker, t_0, s_0, t_1, s_1, ..., speech-start
    marker]. Latents are then appended one at a time until the end-token
    probability reaches `end_token_threshold` or `max_z_tokens` latents exist.

    NOTE: latents from all batch items are collected into one list and decoded
    as a single waveform; the callers in this file always pass batch size 1.

    Returns:
        1-D float numpy array of audio samples (1000 zeros if nothing was generated).
    """
    batch_size, text_seq_len = text_proj.shape[:2]
    marker_index = torch.zeros(1, dtype=torch.long, device=device)
    text_speech_latent_emb = model.text_speech_latent_embed(marker_index)
    time_speech_start_emb = model.time_speech_start_embed(marker_index)

    generated_z_tokens: List[torch.Tensor] = []
    for b in range(batch_size):
        if text_attention_mask is not None:
            valid_text_len = int(text_attention_mask[b].sum().item())
        else:
            valid_text_len = text_seq_len

        # Interleave projected text and speech embeddings token by token.
        sequence: List[torch.Tensor] = [text_speech_latent_emb]
        for i in range(valid_text_len):
            sequence.extend([text_proj[b, i:i + 1], s_tokens[b, i:i + 1]])
        sequence.append(time_speech_start_emb)

        for _ in range(max_z_tokens):
            # No KV cache: the full sequence is re-encoded at every step.
            current_sequence = torch.cat(sequence, dim=0).unsqueeze(0)
            ar_attention_mask = torch.ones(1, current_sequence.shape[1], dtype=torch.bool, device=device)
            ar_outputs = model.ar_transformer(
                hidden_states=current_sequence,
                attention_mask=ar_attention_mask,
            )
            last_prediction = ar_outputs.last_hidden_state[0, -1:, :]

            end_token_logit = model.end_token_classifier(last_prediction).squeeze(-1)
            if torch.sigmoid(end_token_logit).item() >= end_token_threshold:
                break

            # The predicted hidden state is both the next latent and the next AR input.
            sequence.append(last_prediction)
            generated_z_tokens.append(last_prediction.squeeze(0))

    # Decode z tokens to audio through Mimi's upsample -> decoder transformer -> decoder.
    if not generated_z_tokens:
        audio_tensor = torch.zeros(1, 1, 1000, device=device)
    else:
        z_tokens_batch = torch.stack(generated_z_tokens, dim=0).unsqueeze(0)
        embeddings_bct = z_tokens_batch.transpose(1, 2)
        embeddings_upsampled = mimi_model.upsample(embeddings_bct)
        decoder_outputs = mimi_model.decoder_transformer(embeddings_upsampled.transpose(1, 2), return_dict=True)
        embeddings_after_dec = decoder_outputs.last_hidden_state.transpose(1, 2)
        audio_tensor = mimi_model.decoder(embeddings_after_dec)

    # .float() so numpy conversion also works for half/bfloat16 models.
    audio_numpy = audio_tensor.squeeze().detach().float().cpu().numpy()
    if not np.isfinite(audio_numpy).all():
        audio_numpy = np.nan_to_num(audio_numpy)
    if audio_numpy.ndim > 1:
        audio_numpy = audio_numpy.flatten()
    return audio_numpy


def generate_tts_audio(text: str, voice: str, instructions: Optional[str] = None) -> str:
    """Synthesize `text` with OpenAI TTS and return the path of a new .mp3 file.

    When non-blank `instructions` are given, "gpt-4o-mini-tts" is used with
    those style instructions; otherwise "tts-1" is used (it does not accept an
    `instructions` parameter).

    Raises:
        RuntimeError: if the OpenAI client has not been initialized.
    """
    if openai_client is None:
        raise RuntimeError("OpenAI client not initialized")

    request = {"voice": voice, "input": text}
    if instructions and instructions.strip():
        request.update(model="gpt-4o-mini-tts", instructions=instructions.strip())
    else:
        request["model"] = "tts-1"

    # Reserve a unique persistent path, then close the handle before writing:
    # re-opening a still-open NamedTemporaryFile is not portable (e.g. Windows).
    with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_file:
        output_path = temp_file.name

    # The streaming-response API is the SDK's supported way to write audio to disk.
    with openai_client.audio.speech.with_streaming_response.create(**request) as response:
        response.stream_to_file(output_path)
    return output_path


def _status_only(message: str, num_outputs: int) -> tuple:
    """Build a Gradio output tuple holding `message` followed by `num_outputs - 1` Nones."""
    return (message,) + (None,) * (num_outputs - 1)


def _save_wav(audio: np.ndarray) -> str:
    """Write `audio` at SAMPLE_RATE to a new persistent temporary .wav file and return its path."""
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        path = tmp.name
    sf.write(path, audio, SAMPLE_RATE)
    return path


@spaces.GPU
def process_inputs(transcript_text: str, voice1: str, voice2: str, instructions1: str = "", instructions2: str = ""):
    """Synthesize the transcript in two voices, compute their cross-attended
    speech embeddings, and reconstruct a baseline from voice 1.

    Returns 8 values (status, token listing, audio 1 path, audio 2 path,
    baseline path, JSON-serialized embeddings, audio 1 path, audio 2 path);
    on a validation failure only the status is set and the rest are None.
    """
    if any(obj is None for obj in (model, mimi_model, tokenizer, feature_extractor, openai_client)):
        return _status_only("Please initialize models first!", _PROCESS_NUM_OUTPUTS)
    text = (transcript_text or "").strip()
    if not text:
        return _status_only("Please provide a transcript!", _PROCESS_NUM_OUTPUTS)
    if not voice1 or not voice2:
        return _status_only("Please select voices for both audio samples!", _PROCESS_NUM_OUTPUTS)

    # Tokenize
    tokens = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    text_token_strs = tokenizer.convert_ids_to_tokens(tokens.input_ids.squeeze(0).tolist())
    text_token_ids = tokens.input_ids.to(device)
    token_display = "".join(f"Token {i}: {tok}\n" for i, tok in enumerate(text_token_strs))

    # Generate TTS audio
    print(f"Generating TTS audio with voice '{voice1}'...")
    audio1_path = generate_tts_audio(text, voice1, instructions1)
    print(f"Generating TTS audio with voice '{voice2}'...")
    audio2_path = generate_tts_audio(text, voice2, instructions2)

    # Load audio
    input_values_utt1 = load_audio_to_inputs(feature_extractor, audio1_path, SAMPLE_RATE)
    input_values_utt2 = load_audio_to_inputs(feature_extractor, audio2_path, SAMPLE_RATE)

    # Get text embeddings using the model's built-in text_token_embedding
    with torch.no_grad():
        text_embeddings = model.text_token_embedding(text_token_ids)

    # Compute cross-attention embeddings (same text, two different voices)
    t1_proj, s1_cross, text_attention_mask = compute_cross_attention_s(
        model, text_embeddings, input_values_utt1, device
    )
    _, s2_cross, _ = compute_cross_attention_s(
        model, text_embeddings, input_values_utt2, device
    )

    # Generate baseline audio
    baseline_audio = ar_generate_and_decode(
        model=model,
        mimi_model=mimi_model,
        text_proj=t1_proj,
        s_tokens=s1_cross,
        text_attention_mask=text_attention_mask,
        max_z_tokens=MAX_Z_TOKENS,
        end_token_threshold=END_TOKEN_THRESHOLD,
        device=device,
    )
    baseline_path = _save_wav(baseline_audio)

    # Tensors are round-tripped through JSON so they can live in gr.State.
    embeddings_json = json.dumps({
        "t1_proj": t1_proj.detach().cpu().tolist(),
        "s1_cross": s1_cross.detach().cpu().tolist(),
        "s2_cross": s2_cross.detach().cpu().tolist(),
        "text_attention_mask": text_attention_mask.detach().cpu().tolist(),
        "num_tokens": len(text_token_strs),
    })

    return (
        "Processing completed successfully!",
        token_display,
        audio1_path,
        audio2_path,
        baseline_path,
        embeddings_json,
        audio1_path,
        audio2_path,
    )


def _parse_indices(spec: str) -> List[int]:
    """Parse a comma-separated list of unsigned ASCII integers.

    Entries that are not plain ASCII digits are skipped, and duplicates are
    dropped while first-seen order is kept. `str.isdigit()` alone is not
    enough: it accepts characters such as "²" that `int()` rejects.
    """
    indices = []
    for part in spec.split(","):
        part = part.strip()
        if part.isascii() and part.isdigit():
            indices.append(int(part))
    return list(dict.fromkeys(indices))


@spaces.GPU
def swap_embeddings(embeddings_json: str, swap_indices: str):
    """Replace voice-1 speech embeddings with voice-2 ones at the given token
    indices and decode the result.

    Returns:
        (status message, path to swapped .wav or None).
    """
    if not embeddings_json:
        return _status_only("Please process inputs first!", _SWAP_NUM_OUTPUTS)
    if not swap_indices or not swap_indices.strip():
        return _status_only("Please specify token indices to swap (e.g., 0,2,5)!", _SWAP_NUM_OUTPUTS)
    if model is None or mimi_model is None:
        return _status_only("Please initialize models first!", _SWAP_NUM_OUTPUTS)

    # Parse stored embeddings; restore the model's dtype, which JSON does not keep.
    embeddings_data = json.loads(embeddings_json)
    param_dtype = getattr(model, "dtype", torch.float32)
    t1_proj = torch.tensor(embeddings_data["t1_proj"], dtype=param_dtype, device=device)
    s1_cross = torch.tensor(embeddings_data["s1_cross"], dtype=param_dtype, device=device)
    s2_cross = torch.tensor(embeddings_data["s2_cross"], dtype=param_dtype, device=device)
    text_attention_mask = torch.tensor(embeddings_data["text_attention_mask"], dtype=torch.bool, device=device)
    num_tokens = embeddings_data["num_tokens"]

    # Parse indices
    parsed = _parse_indices(swap_indices)
    if not parsed:
        return _status_only("No valid indices provided! Use format: 0,2,5", _SWAP_NUM_OUTPUTS)

    valid_indices = [i for i in parsed if 0 <= i < num_tokens]
    if not valid_indices:
        return _status_only(f"All indices out of range! Valid range: 0-{num_tokens-1}", _SWAP_NUM_OUTPUTS)

    # Perform swap along the token axis (dim 1).
    s_swapped = s1_cross.clone()
    s_swapped[:, valid_indices, :] = s2_cross[:, valid_indices, :]

    # Generate swapped audio
    swapped_audio = ar_generate_and_decode(
        model=model,
        mimi_model=mimi_model,
        text_proj=t1_proj,
        s_tokens=s_swapped,
        text_attention_mask=text_attention_mask,
        max_z_tokens=MAX_Z_TOKENS,
        end_token_threshold=END_TOKEN_THRESHOLD,
        device=device,
    )
    swapped_path = _save_wav(swapped_audio)

    return f"Successfully swapped embeddings at token indices: {valid_indices}", swapped_path


def create_gradio_interface():
    """Build the Gradio Blocks UI and wire it to process_inputs / swap_embeddings."""
    with gr.Blocks(title="TextSyncMimi Demo") as interface:
        gr.Markdown("# TextSyncMimi - Standalone Demo")
        gr.Markdown("Generate two voice renditions using OpenAI TTS, then swap speech embeddings at token positions.")
        gr.Markdown("**This demo uses only the self-contained TextSyncMimi-v1 model code.**")

        with gr.Accordion("Style Instruction Examples", open=False):
            gr.Markdown("""
**Example Instructions:**
- *Emotional:* "Speak with excitement and joy", "Sound sad and melancholy"
- *Pace:* "Speak slowly and deliberately", "Talk quickly and energetically"
- *Character:* "Sound like a wise professor", "Speak like an excited child"
""")

        with gr.Row():
            with gr.Column():
                gr.Markdown("## Text-to-Speech Configuration")
                transcript_text = gr.Textbox(
                    label="Transcript Text",
                    placeholder="Enter text to synthesize...",
                    lines=3,
                )
                with gr.Row():
                    voice1 = gr.Dropdown(
                        choices=TTS_VOICES,
                        label="Voice 1",
                        value="alloy",
                    )
                    voice2 = gr.Dropdown(
                        choices=TTS_VOICES,
                        label="Voice 2",
                        value="echo",
                    )
                instructions1 = gr.Textbox(
                    label="Style Instructions for Voice 1",
                    placeholder="e.g., Speak slowly and calmly",
                    lines=2,
                )
                instructions2 = gr.Textbox(
                    label="Style Instructions for Voice 2",
                    placeholder="e.g., Speak quickly with excitement",
                    lines=2,
                )
                process_btn = gr.Button("Generate & Process", variant="primary")
                process_status = gr.Textbox(label="Status", interactive=False)

            with gr.Column():
                gr.Markdown("## Tokenization")
                tokens_display = gr.Textbox(
                    label="Tokens",
                    lines=16,
                    interactive=False,
                )

        with gr.Row():
            with gr.Column():
                gr.Markdown("## Generated TTS Audio")
                generated_audio1 = gr.Audio(label="Generated Audio 1")
                generated_audio2 = gr.Audio(label="Generated Audio 2")

            with gr.Column():
                gr.Markdown("## Model Output")
                baseline_audio = gr.Audio(label="Baseline Reconstruction")

                gr.Markdown("### Embedding Swap")
                swap_indices_input = gr.Textbox(
                    label="Token Indices to Swap",
                    placeholder="e.g., 0,2,5",
                )
                swap_btn = gr.Button("Perform Swap")
                swap_status = gr.Textbox(label="Swap Status", interactive=False)
                swapped_audio = gr.Audio(label="Swapped Result")

        # Hidden states
        embeddings_state = gr.State()
        audio1_state = gr.State()
        audio2_state = gr.State()

        # Event handlers (output order must match process_inputs / swap_embeddings returns)
        process_btn.click(
            fn=process_inputs,
            inputs=[transcript_text, voice1, voice2, instructions1, instructions2],
            outputs=[process_status, tokens_display, generated_audio1, generated_audio2,
                     baseline_audio, embeddings_state, audio1_state, audio2_state],
        )

        swap_btn.click(
            fn=swap_embeddings,
            inputs=[embeddings_state, swap_indices_input],
            outputs=[swap_status, swapped_audio],
        )

    return interface


def main():
    """Parse CLI arguments, load all models and launch the Gradio app."""
    parser = argparse.ArgumentParser(description="HuggingFace Space Demo for TextSyncMimi")
    parser.add_argument(
        "--model_id",
        type=str,
        default="potsawee/TextSyncMimi-v1",
        help="HuggingFace model ID",
    )
    parser.add_argument(
        "--tokenizer_id",
        type=str,
        default="meta-llama/Llama-3.1-8B-Instruct",
        help="HuggingFace tokenizer ID",
    )
    parser.add_argument(
        "--hf_token",
        type=str,
        default=None,
        help="Hugging Face token (or set HF_TOKEN env var)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=7860,
        help="Port for Gradio app",
    )
    parser.add_argument(
        "--share",
        action="store_true",
        help="Create public share link",
    )
    args = parser.parse_args()

    # Check OpenAI API key
    if not os.getenv("OPENAI_API_KEY"):
        print("❌ Error: OPENAI_API_KEY environment variable is required!")
        print("Set it: export OPENAI_API_KEY=your_key_here")
        return

    # Get HF token
    hf_token = args.hf_token or os.getenv("HF_TOKEN")

    # Initialize models
    print(f"🚀 Initializing TextSyncMimi from HuggingFace Hub: {args.model_id}...")
    initialize_models(args.model_id, args.tokenizer_id, hf_token)
    print("🌐 Launching Gradio interface...")

    # Launch
    interface = create_gradio_interface()
    interface.launch(server_port=args.port, share=args.share)


if __name__ == "__main__":
    main()