# Hugging Face Space app (page status when captured: Running)
import spaces  # kept first, as in the original ordering (HF Spaces / ZeroGPU helper)

import os
import re
import tempfile
from typing import Dict

import torch
import torchaudio
import whisperx
import whisper_timestamped as whisper_ts
from pydub import AudioSegment
from pyannote.audio import Pipeline as DiarizationPipeline
from transformers import pipeline, WhisperForConditionalGeneration, WhisperProcessor, GenerationConfig
# Model identifiers loaded by this app.
MODEL_PATH_1 = "projecte-aina/whisper-large-v3-tiny-caesar"  # transformers ASR pipeline (mono path)
MODEL_PATH_2 = "langtech-veu/whisper-timestamped-cs"  # whisper_timestamped model (stereo path)

torch_dtype = torch.float32

# Two spellings of the same device choice: `device` uses the integer index 0 for
# CUDA, while `DEVICE` uses the "cuda" string. NOTE(review): `device` and
# `torch_dtype` are not referenced by the code shown here; confirm before removing.
if torch.cuda.is_available():
    device = 0
    DEVICE = "cuda"
else:
    device = "cpu"
    DEVICE = "cpu"
def clean_text(input_text):
    """Normalise text for comparison.

    Each listed punctuation/symbol character becomes a space. Runs of whitespace
    are then collapsed to single spaces, and the result is lower-cased.
    """
    blank_out = str.maketrans({ch: ' ' for ch in ".,;:¿?«»-¡!@*{}[]=/\\&#…"})
    spaced = input_text.translate(blank_out)
    return ' '.join(spaced.split()).lower()
def split_stereo_channels(audio_path):
    """Split a two-channel .wav/.mp3 file into two mono WAV files in the working directory.

    Channel index 0 is written to ``temp_mono_speaker1.wav`` and channel index 1
    to ``temp_mono_speaker2.wav``.

    NOTE(review): the output names are fixed and relative to the CWD, so two
    concurrent calls overwrite each other's files.

    Args:
        audio_path: path to a ``.wav`` or ``.mp3`` file (extension is case-insensitive).

    Returns:
        ``(speaker1_path, speaker2_path)``: the two files written.

    Raises:
        ValueError: if the extension is not .wav/.mp3, or the audio does not have
            exactly two channels.
    """
    ext = os.path.splitext(audio_path)[1].lower()
    if ext == ".wav":
        audio = AudioSegment.from_wav(audio_path)
    elif ext == ".mp3":
        audio = AudioSegment.from_file(audio_path, format="mp3")
    else:
        raise ValueError(f"Unsupported file format: {audio_path}")
    channels = audio.split_to_mono()
    if len(channels) != 2:
        raise ValueError(f"Audio {audio_path} does not have 2 channels.")
    # split_to_mono keeps the interleaved channel order. In standard WAV/MP3
    # ordering the first channel is the left one, so channels[0] is left and
    # channels[1] is right. The previous inline comments had these reversed.
    speaker1_path = "temp_mono_speaker1.wav"
    speaker2_path = "temp_mono_speaker2.wav"
    channels[0].export(speaker1_path, format="wav")
    channels[1].export(speaker2_path, format="wav")
    return speaker1_path, speaker2_path
def convert_to_mono(input_path):
    """Down-mix *input_path* to a single channel and save it as WAV.

    The file is written next to the input as ``<input path without extension>_merged.wav``.
    That path is printed and returned.
    """
    source = AudioSegment.from_file(input_path)
    merged_path = f"{os.path.splitext(input_path)[0]}_merged.wav"
    print('output_path', merged_path)
    source.set_channels(1).export(merged_path, format="wav")
    return merged_path
def save_temp_audio(waveform, sample_rate, path):
    """Write *waveform* to *path* at *sample_rate* with torchaudio.

    A 1-D waveform is first given a leading channel axis, i.e. shape (1, num_samples).
    """
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)
    torchaudio.save(path, waveform, sample_rate)
def format_audio(audio_path):
    """Load *audio_path* as a single-channel 16 kHz waveform.

    Returns:
        ``(waveform, 16000)``. ``waveform`` is a 1-D tensor of samples.

    Every channel is averaged into one, so mono, stereo and multichannel inputs
    all yield 1-D output. The previous version only downmixed exactly-2-channel
    files, which left 3+ channel audio 2-D.
    """
    target_rate = 16000
    input_audio, sample_rate = torchaudio.load(audio_path)
    # torchaudio.load returns (channels, frames).
    if input_audio.shape[0] > 1:
        input_audio = torch.mean(input_audio, dim=0, keepdim=True)
    if sample_rate != target_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_rate)
        input_audio = resampler(input_audio)
        print('resampled')
    # squeeze(0) drops only the channel axis. A bare squeeze() would also collapse
    # a one-sample clip to a 0-D scalar.
    return input_audio.squeeze(0), target_rate
def assign_timestamps(asr_segments, audio_path):
    """Give each segment start/end times proportional to its word count.

    The audio duration (from ``format_audio``) is spread evenly over all
    whitespace-separated words. Segments are laid end to end from 0.0 s, and
    times are rounded to milliseconds. Segments are updated in place, and the
    same list is returned.

    NOTE(review): this overwrites any timestamps the segments already carry
    (e.g. those taken from the HF pipeline chunks). Confirm that is intended.

    Raises:
        ValueError: if the segments contain no words at all.
    """
    samples, rate = format_audio(audio_path)
    duration_s = samples.shape[-1] / rate
    word_counts = [len(seg["text"].split()) for seg in asr_segments]
    word_total = sum(word_counts)
    if word_total == 0:
        raise ValueError("Total number of words in ASR segments is zero. Cannot assign timestamps.")
    seconds_per_word = duration_s / word_total
    cursor = 0.0
    for seg, count in zip(asr_segments, word_counts):
        span = count * seconds_per_word
        seg["start"] = round(cursor, 3)
        seg["end"] = round(cursor + span, 3)
        cursor += span
    return asr_segments
def hf_chunks_to_whisperx_segments(chunks):
    """Convert transformers ASR ``chunks`` into whisperx-style segment dicts.

    Each chunk with a non-empty list/tuple ``timestamp`` becomes
    ``{"text", "start", "end"}``. Other chunks are skipped.
    """
    segments = []
    for chunk in chunks:
        stamp = chunk["timestamp"]
        if not stamp or not isinstance(stamp, (list, tuple)):
            continue
        segments.append({"text": chunk["text"], "start": stamp[0], "end": stamp[1]})
    return segments
def align_words_to_segments(words, segments, window=5.0):
    """Pair each timed word with the speaker segment it belongs to.

    Args:
        words: dicts with "start"/"end" times in seconds, in time order
            (e.g. whisperx ``word_segments``). Words without a "start" time are
            skipped. A missing/None "end" is treated as equal to "start".
        segments: dicts with "start"/"end"/"speaker". They are assumed sorted by
            "start" (TODO confirm for the diarization output that feeds this).
        window: search radius in seconds.

    A word whose start lies inside a segment ``[start, end)`` is paired with the
    first such segment. Otherwise it is paired with the segment with the
    smallest gap to it, provided that gap is at most ``window``. Previously such
    words, e.g. ones spoken between two diarization turns, were dropped.

    Returns:
        A list of ``(word, segment)`` tuples in word order. Words with no segment
        within ``window`` are omitted.
    """
    aligned = []
    seg_idx = 0
    for word in words:
        start = word.get("start")
        if start is None:
            # Alignment output may contain entries without timing; they cannot be placed.
            continue
        end = word.get("end")
        if end is None:
            end = start
        # Words arrive in time order, so segments that ended more than `window`
        # before this word can never match a later word either.
        while seg_idx < len(segments) and segments[seg_idx]["end"] < start - window:
            seg_idx += 1
        match = None
        best_gap = None
        for j in range(seg_idx, len(segments)):
            seg = segments[j]
            if seg["start"] > end + window:
                break
            if seg["start"] <= start < seg["end"]:
                match = seg  # containment always wins over a nearby segment
                break
            # Distance between [start, end] and the segment; 0 when they overlap.
            gap = max(seg["start"] - end, start - seg["end"], 0.0)
            if gap <= window and (best_gap is None or gap < best_gap):
                match, best_gap = seg, gap
        if match is not None:
            aligned.append((word, match))
    return aligned
def post_process_transcription(transcription, max_repeats=2):
    """Clean a raw ASR transcription.

    Steps:
      1. Tokenise into words. A word is a run of word characters, optionally
         joined by an apostrophe (' or ’) or the Catalan middle dot (·), plus an
         optional trailing one of ``. , ! ?``. Any other characters are discarded.
      2. Inside each token, delete any 1–3 character unit repeated three or more
         times (e.g. "hahaha" -> "").
      3. Drop tokens left with no word characters, so they neither appear in the
         output nor reset the repetition counter.
      4. Keep at most ``max_repeats`` consecutive identical tokens.

    Returns:
        The surviving tokens joined by single spaces.
    """
    # The original pattern split at "·", turning e.g. "col·laboració" into
    # "col laboració".
    tokens = re.findall(r"\b\w+(?:['’·]\w+)*\b[.,!?]?", transcription)
    cleaned_tokens = []
    repetition_count = 0
    previous_token = None
    for token in tokens:
        reduced_token = re.sub(r"(\w{1,3})(\1{2,})", "", token)
        if not re.search(r"\w", reduced_token):
            continue
        if reduced_token == previous_token:
            repetition_count += 1
            if repetition_count > max_repeats:
                continue
        else:
            previous_token = reduced_token
            repetition_count = 1
        cleaned_tokens.append(reduced_token)
    # Tokens contain no whitespace, so a plain join is already single-spaced.
    return " ".join(cleaned_tokens)
def post_merge_consecutive_segments_from_text(transcription_text: str) -> str:
    """Merge adjacent turns by the same ``[SPEAKER_NN]`` tag into one line.

    Text before the first tag is discarded. Each turn's text is stripped, and
    merged texts are joined with single spaces. The result has one
    ``[SPEAKER_NN] text`` line per merged turn and no surrounding whitespace.
    """
    pieces = re.split(r'(\[SPEAKER_\d{2}\])', transcription_text)
    # pieces == [prefix, tag1, text1, tag2, text2, ...]
    runs = []  # (speaker_digits, [texts]) in order of appearance
    for tag, body in zip(pieces[1::2], pieces[2::2]):
        speaker_id = re.search(r'\d{2}', tag).group()
        if runs and runs[-1][0] == speaker_id:
            runs[-1][1].append(body.strip())
        else:
            runs.append((speaker_id, [body.strip()]))
    lines = [f'[SPEAKER_{speaker_id}] {" ".join(texts)}' for speaker_id, texts in runs]
    return "\n".join(lines).strip()
def cleanup_temp_files(*file_paths):
    """Delete each given file, ignoring falsy entries and files that no longer exist.

    Uses try/except rather than ``os.path.exists`` followed by ``os.remove``. The
    check-then-remove version raised FileNotFoundError when another request
    deleted the same (shared, fixed-name) temp file between the two calls.
    Other OS errors (e.g. the path is a directory) still propagate.
    """
    for path in file_paths:
        if not path:
            continue
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
def load_whisper_model(model_path: str):
    """Load a whisper_timestamped model from *model_path*.

    The model goes on CUDA when available, otherwise the CPU.

    NOTE(review): each call loads the weights afresh. Caching was not added
    because the model object would then be shared across concurrent transcriptions.
    """
    target_device = "cpu"
    if torch.cuda.is_available():
        target_device = "cuda"
    return whisper_ts.load_model(model_path, device=target_device)
def transcribe_audio(model, audio_path: str) -> Dict:
    """Transcribe one input with whisper_timestamped and flatten the per-word timings.

    Args:
        model: a model from ``load_whisper_model``.
        audio_path: the audio passed to ``whisper_ts.transcribe``. Despite the
            annotation, ``generate`` passes a waveform tensor from ``format_audio``.

    Returns:
        On success, a dict with keys:
          - 'audio_path', 'success' (True)
          - 'text' (stripped)
          - 'segments' (raw whisper_timestamped segments)
          - 'words' (list of {'word', 'start', 'end', 'confidence'})
          - 'duration'
        On any exception, ``{'audio_path', 'error', 'success': False}``. The error
        is also printed, so failures are no longer silent.
    """
    try:
        result = whisper_ts.transcribe(
            model,
            audio_path,
            beam_size=5,
            best_of=5,
            temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
            vad=False,
            detect_disfluencies=True,
        )
        # strip() already removes any leading space, so the former
        # `startswith(' ')` branch could never run and has been dropped.
        words = [
            {
                'word': word.get('word', '').strip(),
                'start': word.get('start', 0),
                'end': word.get('end', 0),
                'confidence': word.get('confidence', 0),
            }
            for segment in result.get('segments', [])
            for word in segment.get('words', [])
        ]
        return {
            'audio_path': audio_path,
            'text': result['text'].strip(),
            'segments': result.get('segments', []),
            'words': words,
            'duration': result.get('duration', 0),
            'success': True
        }
    except Exception as e:
        # Best-effort contract: callers receive a failure dict instead of an exception.
        print(f"[ERROR] whisper_timestamped transcription failed: {e}")
        return {
            'audio_path': audio_path,
            'error': str(e),
            'success': False
        }
# Heavy models are created once at import time and shared by all requests.

# Speaker diarization pipeline, configured from a local pyannote config file.
diarization_pipeline = DiarizationPipeline.from_pretrained("./pyannote/config.yaml")

# whisperx word-alignment model and its metadata.
# NOTE(review): this loads the English ("en") alignment model. MODEL_PATH_1/2 do
# not look like English-only models; confirm the alignment language is intended.
align_model, metadata = whisperx.load_align_model(language_code="en", device=DEVICE)

# transformers ASR pipeline on MODEL_PATH_1. Audio is processed in 30 s chunks,
# and segment timestamps are returned.
asr_pipe = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_PATH_1,
    device=DEVICE,
    chunk_length_s=30,
    return_timestamps=True,
)
def diarization(audio_path):
    """Diarize *audio_path* and return its tracks as a list.

    The list holds the ``(segment, track, label)`` triples from
    ``itertracks(yield_label=True)``. They are printed for debugging.
    """
    annotation = diarization_pipeline(audio_path)
    tracks = list(annotation.itertracks(yield_label=True))
    print('diarized_segments', tracks)
    return tracks
def asr(audio_path):
    """Transcribe *audio_path* with the transformers pipeline.

    Returns segment dicts (text/start/end). Their start/end are replaced by
    ``assign_timestamps``' word-count-proportional estimates.
    """
    print(f"[DEBUG] Starting ASR on audio: {audio_path}")
    raw_output = asr_pipe(audio_path, return_timestamps=True)
    print(f"[DEBUG] Raw ASR result: {raw_output}")
    segments = hf_chunks_to_whisperx_segments(raw_output['chunks'])
    return assign_timestamps(segments, audio_path)
def align_asr_to_diarization(asr_segments, diarized_segments, audio_path):
    """Attribute aligned words to diarization turns.

    1. Word-align *asr_segments* against the audio with whisperx.
    2. Pair each word with a diarized ``(segment, track, speaker)`` turn.
    3. Return one ``"[speaker] words..."`` string per turn that received words,
       sorted by (start, end, speaker).
    """
    audio, _rate = format_audio(audio_path)
    alignment = whisperx.align(asr_segments, align_model, metadata, audio, DEVICE)
    turns = [
        {"start": turn.start, "end": turn.end, "speaker": label}
        for turn, _track, label in diarized_segments
    ]
    words_by_turn = {}
    for word, turn in align_words_to_segments(alignment['word_segments'], turns):
        turn_key = (turn["start"], turn["end"], turn["speaker"])
        words_by_turn.setdefault(turn_key, []).append(word["word"])
    return [
        f"[{speaker}] {' '.join(tokens)}"
        for (_start, _end, speaker), tokens in sorted(words_by_turn.items())
    ]
def generate(audio_path, use_v2):
    """Transcribe *audio_path* and return a speaker-labelled transcript.

    Args:
        audio_path: path to the input recording.
        use_v2: if true, treat the file as a two-channel recording and transcribe
            each channel separately with whisper_timestamped. Otherwise downmix to
            mono, diarize, run ASR, word-align and merge consecutive turns.

    Returns:
        The transcript as newline-separated lines, with no trailing newline.
    """
    if use_v2:
        return _generate_per_channel(audio_path)
    return _generate_diarized(audio_path)


def _channel_segments(result, speaker_label):
    """Turn one ``transcribe_audio`` result into (start, end, speaker, cleaned text) tuples."""
    if not result.get("success", True):
        # transcribe_audio swallows exceptions. Report them instead of silently
        # yielding an empty transcript for this channel.
        print(f"[WARN] Transcription failed for {speaker_label}: {result.get('error')}")
    return [
        (seg.get("start", 0.0), seg.get("end", 0.0), speaker_label,
         post_process_transcription(seg.get("text", "").strip()))
        for seg in result.get("segments") or []
        if seg.get("text")
    ]


def _generate_per_channel(audio_path):
    """Transcribe both channels of a stereo file and interleave the segments by start time."""
    speaker1_file = "temp_mono_speaker1.wav"  # channel index 0, written by split_stereo_channels
    speaker2_file = "temp_mono_speaker2.wav"  # channel index 1
    model = load_whisper_model(MODEL_PATH_2)
    try:
        split_stereo_channels(audio_path)
        # Labels kept as before: "Speaker 1" comes from channel index 1 and
        # "Speaker 2" from channel index 0. NOTE(review): confirm this mapping.
        waveform_1, _ = format_audio(speaker2_file)
        waveform_2, _ = format_audio(speaker1_file)
    finally:
        # The waveforms are in memory now. Remove the fixed-name files right away,
        # even when splitting/loading fails.
        cleanup_temp_files(speaker1_file, speaker2_file)
    result_1 = transcribe_audio(model, waveform_1)
    result_2 = transcribe_audio(model, waveform_2)
    turns = _channel_segments(result_1, "Speaker 1") + _channel_segments(result_2, "Speaker 2")
    # Stable sort by start time; segments without a start go last.
    turns.sort(key=lambda turn: float(turn[0]) if turn[0] is not None else float("inf"))
    return "\n".join(f"[{speaker}]: {text}" for _, _, speaker, text in turns).strip()


def _generate_diarized(audio_path):
    """Mono pipeline: downmix, diarize, ASR, word-align and merge consecutive turns.

    A unique temp file replaces the shared "tmp_full.wav", so concurrent requests
    cannot clobber each other. Temp files are removed even on failure. This path
    no longer deletes the stereo path's temp files, which it never creates.
    """
    mono_audio_path = convert_to_mono(audio_path)
    fd, tmp_full_path = tempfile.mkstemp(prefix="tmp_full_", suffix=".wav")
    os.close(fd)
    try:
        waveform, sr = format_audio(mono_audio_path)
        save_temp_audio(waveform, sr, tmp_full_path)
        diarized_segments = diarization(tmp_full_path)
        asr_segments = asr(tmp_full_path)
        for segment in asr_segments:
            segment["text"] = post_process_transcription(segment["text"])
        aligned_lines = align_asr_to_diarization(asr_segments, diarized_segments, tmp_full_path)
        joined = "".join(f"{line}\n" for line in aligned_lines)
        return post_merge_consecutive_segments_from_text(joined)
    finally:
        cleanup_temp_files(mono_audio_path, tmp_full_path)