"""
HuggingFace Space Demo for TextSyncMimi
Speech Editing with Token-Level Embedding Swapping

This demo loads the model from HuggingFace Hub and allows:
- Generating speech with different voices using OpenAI TTS
- Swapping speech embeddings at specific token positions
- Real-time speech editing

Prerequisites:
- Set OPENAI_API_KEY in Space secrets
- Model will be loaded from HuggingFace Hub
"""

import os
import json
import tempfile
import argparse
from typing import List, Tuple, Optional

import numpy as np
import torch
import soundfile as sf
import gradio as gr
from openai import OpenAI
from transformers import (
    AutoModel,
    AutoFeatureExtractor,
    AutoTokenizer,
    MimiModel,
)


# Hugging Face Spaces ZeroGPU support: fall back to a no-op GPU decorator
# when the `spaces` package is unavailable (e.g., when running locally).
try:
    import spaces

    GPU_AVAILABLE = True
except ImportError:
    GPU_AVAILABLE = False

    class spaces:
        @staticmethod
        def GPU(func):
            return func


# Constants
SAMPLE_RATE = 24000  # Mimi operates on 24 kHz audio
FRAME_RATE = 12.5  # Mimi frame rate (frames per second)
TTS_VOICES = ["alloy", "ash", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer", "verse"]
MAX_Z_TOKENS = 50  # safety cap on generated latent (z) tokens
END_TOKEN_THRESHOLD = 0.5  # sigmoid threshold for the end-token classifier

# Globals populated by initialize_models()
model = None
mimi_model = None
tokenizer = None
feature_extractor = None
device = None
openai_client = None


def load_audio_to_inputs(feature_extractor, audio_path: str, sample_rate: int) -> torch.Tensor:
    """Load an audio file and convert it to Mimi feature-extractor inputs."""
    import librosa  # lazy import; librosa is only needed for audio loading

    audio, _ = librosa.load(audio_path, sr=sample_rate, mono=True)
    audio_inputs = feature_extractor(raw_audio=audio, return_tensors="pt", sampling_rate=sample_rate)
    return audio_inputs.input_values


def initialize_models(model_id: str, tokenizer_id: str = "meta-llama/Llama-3.1-8B-Instruct", hf_token: Optional[str] = None):
    """Initialize all models from HuggingFace Hub."""
    global model, mimi_model, tokenizer, feature_extractor, device, openai_client

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    print(f"Loading TextSyncMimi model from {model_id}...")
    model = AutoModel.from_pretrained(
        model_id,
        trust_remote_code=True,
        token=hf_token,
    )
    model.to(device)
    model.eval()

    # Use the Mimi checkpoint referenced by the model config, if present.
    mimi_model_id = getattr(model.config, "mimi_model_id", "kyutai/mimi")

    print("Loading Mimi model...")
    mimi_model = MimiModel.from_pretrained(mimi_model_id, token=hf_token)
    mimi_model.to(device)
    mimi_model.eval()

    print(f"Loading tokenizer from {tokenizer_id}...")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, token=hf_token)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("Loading feature extractor...")
    feature_extractor = AutoFeatureExtractor.from_pretrained(mimi_model_id, token=hf_token)

    print("Initializing OpenAI client...")
    openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    print("✅ All models loaded successfully!")
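

# The two functions below implement the model's inference path:
# (1) compute_cross_attention_s aligns each text token with the speech signal,
#     yielding one speech embedding per text token;
# (2) ar_generate_and_decode autoregressively generates Mimi latents from the
#     interleaved (text, speech) embeddings and decodes them to a waveform.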
@torch.no_grad()
def compute_cross_attention_s(
    model,
    text_embeddings: torch.Tensor,
    input_values: torch.Tensor,
    device: str,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute projected text embeddings and cross-attended speech embeddings."""
    audio_attention_mask = torch.ones(1, input_values.shape[-1], dtype=torch.bool, device=device)
    text_attention_mask = torch.ones(1, text_embeddings.shape[1], dtype=torch.bool, device=device)

    # Encode raw audio into speech embeddings, shape (B, T_speech, D) after the transpose.
    speech_embeddings = model.encode_audio_to_representation(
        input_values.to(device),
        audio_attention_mask=audio_attention_mask,
    ).transpose(1, 2)

    # Project text-token embeddings into the shared model dimension.
    text_proj = model.text_proj(text_embeddings.to(device))

    # Build a 4D additive attention mask over text positions: causal,
    # with padded positions masked to -inf.
    batch_size, text_seq_len = text_proj.shape[:2]
    causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len, device=device, dtype=text_proj.dtype))
    causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1)
    pad_mask = text_attention_mask.view(batch_size, 1, 1, text_seq_len)
    formatted_text_attention_mask = torch.where((causal_mask * pad_mask).bool(), 0.0, float("-inf"))

    # The speech side is fully visible (no padding in this demo).
    speech_seq_len = speech_embeddings.shape[1]
    speech_mask = torch.ones(batch_size, speech_seq_len, dtype=torch.bool, device=device)
    formatted_speech_attention_mask = torch.where(
        speech_mask.view(batch_size, 1, 1, speech_seq_len), 0.0, float("-inf")
    )

    # Cross-attend: text tokens (queries) attend to speech embeddings
    # (keys/values), yielding one speech-conditioned vector per text token.
    cross_out = model.cross_attention_transformer(
        hidden_states=text_proj,
        encoder_hidden_states=speech_embeddings,
        attention_mask=formatted_text_attention_mask,
        encoder_attention_mask=formatted_speech_attention_mask,
        alignment_chunk_sizes=None,
    ).last_hidden_state

    return text_proj, cross_out, text_attention_mask
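

# AR sequence layout (built token by token; batch size is 1 in this demo):
#   [latent, t_0, s_0, <start>, z..., <end>, t_1, s_1, <start>, z..., <end>, ...]
# t_i: projected text embedding, s_i: cross-attended speech embedding,
# z: generated Mimi latents, later decoded to the output waveform.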
@torch.no_grad()
def ar_generate_and_decode(
    model,
    mimi_model,
    text_proj: torch.Tensor,
    s_tokens: torch.Tensor,
    text_attention_mask: torch.Tensor,
    max_z_tokens: int,
    end_token_threshold: float,
    device: str,
) -> np.ndarray:
    """Generate audio autoregressively and decode it to a waveform."""
    batch_size, text_seq_len = text_proj.shape[:2]

    # Learned special embeddings that mark the sequence structure.
    text_speech_latent_emb = model.text_speech_latent_embed(torch.zeros(1, dtype=torch.long, device=device))
    time_speech_start_emb = model.time_speech_start_embed(torch.zeros(1, dtype=torch.long, device=device))
    time_speech_end_emb = model.time_speech_end_embed(torch.zeros(1, dtype=torch.long, device=device))

    generated_z_tokens: List[torch.Tensor] = []

    for b in range(batch_size):
        if text_attention_mask is not None:
            valid_text_len = int(text_attention_mask[b].sum().item())
        else:
            valid_text_len = text_seq_len

        sequence: List[torch.Tensor] = [text_speech_latent_emb]

        for i in range(valid_text_len):
            t_i = text_proj[b, i:i+1]
            s_i = s_tokens[b, i:i+1]

            # Append this token's text/speech pair, then generate its speech
            # segment between the start and end markers.
            sequence.extend([t_i, s_i])
            sequence.append(time_speech_start_emb)

            z_count = 0
            while z_count < max_z_tokens:
                current_sequence = torch.cat(sequence, dim=0).unsqueeze(0)
                ar_attention_mask = torch.ones(1, current_sequence.shape[1], dtype=torch.bool, device=device)

                ar_outputs = model.ar_transformer(
                    hidden_states=current_sequence,
                    attention_mask=ar_attention_mask,
                )
                last_prediction = ar_outputs.last_hidden_state[0, -1:, :]

                # Stop the segment once the end-token classifier fires.
                end_token_logit = model.end_token_classifier(last_prediction).squeeze(-1)
                end_token_prob = torch.sigmoid(end_token_logit).item()
                if end_token_prob >= end_token_threshold:
                    break

                sequence.append(last_prediction)
                generated_z_tokens.append(last_prediction.squeeze(0))
                z_count += 1

            sequence.append(time_speech_end_emb)

    # Decode the generated latents with Mimi: upsample, run the decoder
    # transformer, then the convolutional decoder.
    if len(generated_z_tokens) == 0:
        audio_tensor = torch.zeros(1, 1, 1000, device=device)  # short silence fallback
    else:
        z_tokens_batch = torch.stack(generated_z_tokens, dim=0).unsqueeze(0)
        embeddings_bct = z_tokens_batch.transpose(1, 2)  # (B, D, T) channels-first
        embeddings_upsampled = mimi_model.upsample(embeddings_bct)
        decoder_outputs = mimi_model.decoder_transformer(embeddings_upsampled.transpose(1, 2), return_dict=True)
        embeddings_after_dec = decoder_outputs.last_hidden_state.transpose(1, 2)
        audio_tensor = mimi_model.decoder(embeddings_after_dec)

    audio_numpy = audio_tensor.squeeze().detach().cpu().numpy()
    if np.isnan(audio_numpy).any() or np.isinf(audio_numpy).any():
        audio_numpy = np.nan_to_num(audio_numpy)
    if audio_numpy.ndim > 1:
        audio_numpy = audio_numpy.flatten()
    return audio_numpy
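

# OpenAI's gpt-4o-mini-tts model accepts free-form style instructions;
# the plain tts-1 model does not, so it is used when no instructions are given.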
def generate_tts_audio(text: str, voice: str, instructions: Optional[str] = None) -> str:
    """Generate TTS audio with OpenAI and return the path to the audio file."""
    if not openai_client:
        raise RuntimeError("OpenAI client not initialized")

    if instructions and instructions.strip():
        response = openai_client.audio.speech.create(
            model="gpt-4o-mini-tts",
            voice=voice,
            input=text,
            instructions=instructions.strip(),
        )
    else:
        response = openai_client.audio.speech.create(
            model="tts-1",
            voice=voice,
            input=text,
        )

    with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_file:
        response.stream_to_file(temp_file.name)
        return temp_file.name
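

# Full pipeline for one request: tokenize the transcript, synthesize it with
# two different voices, compute token-aligned speech embeddings for both, and
# reconstruct a baseline from voice 1. The embeddings are serialized to JSON
# and kept in Gradio state so swap_embeddings() can reuse them.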
@spaces.GPU
def process_inputs(transcript_text: str, voice1: str, voice2: str, instructions1: str = "", instructions2: str = ""):
    """Process inputs and generate audio."""
    if not all([model, mimi_model, tokenizer, feature_extractor, openai_client]):
        return "Please initialize models first!", None, None, None, None, None, None, None

    if not transcript_text.strip():
        return "Please provide a transcript!", None, None, None, None, None, None, None

    if not voice1 or not voice2:
        return "Please select voices for both audio samples!", None, None, None, None, None, None, None

    # Tokenize the transcript and build a per-token display string.
    tokens = tokenizer(transcript_text.strip(), return_tensors="pt", add_special_tokens=False)
    text_token_ids_cpu = tokens.input_ids.squeeze(0).tolist()
    text_token_strs = tokenizer.convert_ids_to_tokens(text_token_ids_cpu)
    text_token_ids = tokens.input_ids.to(device)

    token_display = "\n".join(f"Token {i}: {tok}" for i, tok in enumerate(text_token_strs))

    # Synthesize the same transcript with two different voices.
    print(f"Generating TTS audio with voice '{voice1}'...")
    audio1_path = generate_tts_audio(transcript_text.strip(), voice1, instructions1)
    print(f"Generating TTS audio with voice '{voice2}'...")
    audio2_path = generate_tts_audio(transcript_text.strip(), voice2, instructions2)

    input_values_utt1 = load_audio_to_inputs(feature_extractor, audio1_path, SAMPLE_RATE)
    input_values_utt2 = load_audio_to_inputs(feature_extractor, audio2_path, SAMPLE_RATE)

    with torch.no_grad():
        text_embeddings = model.text_token_embedding(text_token_ids)

    # Token-aligned speech embeddings for each rendition.
    t1_proj, s1_cross, text_attention_mask = compute_cross_attention_s(
        model, text_embeddings, input_values_utt1, device
    )
    _, s2_cross, _ = compute_cross_attention_s(
        model, text_embeddings, input_values_utt2, device
    )

    # Baseline: reconstruct voice 1 without any swapping.
    baseline_audio = ar_generate_and_decode(
        model=model,
        mimi_model=mimi_model,
        text_proj=t1_proj,
        s_tokens=s1_cross,
        text_attention_mask=text_attention_mask,
        max_z_tokens=MAX_Z_TOKENS,
        end_token_threshold=END_TOKEN_THRESHOLD,
        device=device,
    )

    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        sf.write(f.name, baseline_audio, SAMPLE_RATE)
        baseline_path = f.name

    return (
        "Processing completed successfully!",
        token_display,
        audio1_path,
        audio2_path,
        baseline_path,
        json.dumps({
            "t1_proj": t1_proj.cpu().numpy().tolist(),
            "s1_cross": s1_cross.cpu().numpy().tolist(),
            "s2_cross": s2_cross.cpu().numpy().tolist(),
            "text_attention_mask": text_attention_mask.cpu().numpy().tolist(),
            "num_tokens": len(text_token_strs),
        }),
        audio1_path,
        audio2_path,
    )
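

# Swapping replaces voice 1's speech embedding with voice 2's at the selected
# token positions before regeneration, so the edited tokens should carry
# voice 2's delivery while the rest of the utterance stays close to voice 1.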
@spaces.GPU
def swap_embeddings(embeddings_json: str, swap_indices: str):
    """Perform embedding swaps at the specified token indices."""
    if not embeddings_json:
        return "Please process inputs first!", None

    if not swap_indices.strip():
        return "Please specify token indices to swap (e.g., 0,2,5)!", None

    # Restore the tensors stashed by process_inputs().
    embeddings_data = json.loads(embeddings_json)
    t1_proj = torch.tensor(embeddings_data["t1_proj"]).to(device)
    s1_cross = torch.tensor(embeddings_data["s1_cross"]).to(device)
    s2_cross = torch.tensor(embeddings_data["s2_cross"]).to(device)
    text_attention_mask = torch.tensor(embeddings_data["text_attention_mask"]).to(device)
    num_tokens = embeddings_data["num_tokens"]

    # Parse the comma-separated index list, keeping only in-range values.
    parts = [p.strip() for p in swap_indices.split(",")]
    parsed = [int(p) for p in parts if p.isdigit()]

    if len(parsed) == 0:
        return "No valid indices provided! Use format: 0,2,5", None

    valid_indices = [i for i in parsed if 0 <= i < num_tokens]
    if len(valid_indices) == 0:
        return f"All indices out of range! Valid range: 0-{num_tokens - 1}", None

    # Replace voice 1's speech embeddings with voice 2's at the chosen positions.
    s_swapped = s1_cross.clone()
    for idx in valid_indices:
        s_swapped[:, idx:idx+1, :] = s2_cross[:, idx:idx+1, :]

    swapped_audio = ar_generate_and_decode(
        model=model,
        mimi_model=mimi_model,
        text_proj=t1_proj,
        s_tokens=s_swapped,
        text_attention_mask=text_attention_mask,
        max_z_tokens=MAX_Z_TOKENS,
        end_token_threshold=END_TOKEN_THRESHOLD,
        device=device,
    )

    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        sf.write(f.name, swapped_audio, SAMPLE_RATE)
        swapped_path = f.name

    return f"Successfully swapped embeddings at token indices: {valid_indices}", swapped_path


def create_gradio_interface():
    """Create the Gradio interface."""
    with gr.Blocks(title="TextSyncMimi Demo") as interface:
        gr.Markdown("# TextSyncMimi - Standalone Demo")
        gr.Markdown("Generate two voice renditions using OpenAI TTS, then swap speech embeddings at token positions.")
        gr.Markdown("**This demo uses only the self-contained TextSyncMimi-v1 model code.**")

        with gr.Accordion("Style Instruction Examples", open=False):
            gr.Markdown("""
            **Example Instructions:**
            - *Emotional:* "Speak with excitement and joy", "Sound sad and melancholy"
            - *Pace:* "Speak slowly and deliberately", "Talk quickly and energetically"
            - *Character:* "Sound like a wise professor", "Speak like an excited child"
            """)

        with gr.Row():
            with gr.Column():
                gr.Markdown("## Text-to-Speech Configuration")
                transcript_text = gr.Textbox(
                    label="Transcript Text",
                    placeholder="Enter text to synthesize...",
                    lines=3,
                )
                with gr.Row():
                    voice1 = gr.Dropdown(
                        choices=TTS_VOICES,
                        label="Voice 1",
                        value="alloy",
                    )
                    voice2 = gr.Dropdown(
                        choices=TTS_VOICES,
                        label="Voice 2",
                        value="echo",
                    )
                instructions1 = gr.Textbox(
                    label="Style Instructions for Voice 1",
                    placeholder="e.g., Speak slowly and calmly",
                    lines=2,
                )
                instructions2 = gr.Textbox(
                    label="Style Instructions for Voice 2",
                    placeholder="e.g., Speak quickly with excitement",
                    lines=2,
                )
                process_btn = gr.Button("Generate & Process", variant="primary")
                process_status = gr.Textbox(label="Status", interactive=False)

            with gr.Column():
                gr.Markdown("## Tokenization")
                tokens_display = gr.Textbox(
                    label="Tokens",
                    lines=16,
                    interactive=False,
                )

        with gr.Row():
            with gr.Column():
                gr.Markdown("## Generated TTS Audio")
                generated_audio1 = gr.Audio(label="Generated Audio 1")
                generated_audio2 = gr.Audio(label="Generated Audio 2")

            with gr.Column():
                gr.Markdown("## Model Output")
                baseline_audio = gr.Audio(label="Baseline Reconstruction")

                gr.Markdown("### Embedding Swap")
                swap_indices_input = gr.Textbox(
                    label="Token Indices to Swap",
                    placeholder="e.g., 0,2,5",
                )
                swap_btn = gr.Button("Perform Swap")
                swap_status = gr.Textbox(label="Swap Status", interactive=False)
                swapped_audio = gr.Audio(label="Swapped Result")

        # Hidden state carrying embeddings and audio paths between the two steps.
        embeddings_state = gr.State()
        audio1_state = gr.State()
        audio2_state = gr.State()

        process_btn.click(
            fn=process_inputs,
            inputs=[transcript_text, voice1, voice2, instructions1, instructions2],
            outputs=[process_status, tokens_display, generated_audio1, generated_audio2,
                     baseline_audio, embeddings_state, audio1_state, audio2_state],
        )

        swap_btn.click(
            fn=swap_embeddings,
            inputs=[embeddings_state, swap_indices_input],
            outputs=[swap_status, swapped_audio],
        )

    return interface


def main():
    """Main function."""
    parser = argparse.ArgumentParser(description="HuggingFace Space Demo for TextSyncMimi")
    parser.add_argument(
        "--model_id",
        type=str,
        default="potsawee/TextSyncMimi-v1",
        help="HuggingFace model ID",
    )
    parser.add_argument(
        "--tokenizer_id",
        type=str,
        default="meta-llama/Llama-3.1-8B-Instruct",
        help="HuggingFace tokenizer ID",
    )
    parser.add_argument(
        "--hf_token",
        type=str,
        default=None,
        help="Hugging Face token (or set HF_TOKEN env var)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=7860,
        help="Port for Gradio app",
    )
    parser.add_argument(
        "--share",
        action="store_true",
        help="Create public share link",
    )
    args = parser.parse_args()

    if not os.getenv("OPENAI_API_KEY"):
        print("❌ Error: OPENAI_API_KEY environment variable is required!")
        print("Set it: export OPENAI_API_KEY=your_key_here")
        return

    hf_token = args.hf_token or os.getenv("HF_TOKEN")

    print(f"🚀 Initializing TextSyncMimi from HuggingFace Hub: {args.model_id}...")
    initialize_models(args.model_id, args.tokenizer_id, hf_token)
    print("🚀 Launching Gradio interface...")

    interface = create_gradio_interface()
    interface.launch(server_port=args.port, share=args.share)


if __name__ == "__main__":
    main()