import spaces
from pydub import AudioSegment
import os
import torchaudio
import torch
import re
from transformers import pipeline, WhisperForConditionalGeneration, WhisperProcessor, GenerationConfig
from pyannote.audio import Pipeline as DiarizationPipeline
import whisperx
import whisper_timestamped as whisper_ts
from typing import Dict

device = 0 if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float32

MODEL_PATH_1 = "projecte-aina/whisper-large-v3-tiny-caesar"
MODEL_PATH_2 = "langtech-veu/whisper-timestamped-cs"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Characters that clean_text() replaces with a space before collapsing whitespace.
_CLEAN_TEXT_REMOVE_CHARS = ['.', ',', ';', ':', '¿', '?', '«', '»', '-', '¡', '!', '@', '*', '{', '}', '[', ']', '=', '/', '\\', '&', '#', '…']
_CLEAN_TEXT_TABLE = str.maketrans({char: ' ' for char in _CLEAN_TEXT_REMOVE_CHARS})

# Patterns used by post_process_transcription(), compiled once at import time.
# A "token" is a word (optionally with one apostrophe) plus an optional trailing .,!?
_TOKEN_RE = re.compile(r'\b\w+\'?\w*\b[.,!?]?')
# A 1-3 character unit repeated three or more times in a row (e.g. "hahaha").
_REPEATED_UNIT_RE = re.compile(r"(\w{1,3})(\1{2,})")


def clean_text(input_text):
    """Normalize text: listed punctuation -> spaces, collapse whitespace, lowercase."""
    output_text = input_text.translate(_CLEAN_TEXT_TABLE)
    return ' '.join(output_text.split()).lower()


def split_stereo_channels(audio_path):
    """Split a 2-channel WAV/MP3 file into two mono WAV files in the working directory.

    Writes channel 0 to ``temp_mono_speaker1.wav`` and channel 1 to
    ``temp_mono_speaker2.wav``. pydub documents ``split_to_mono()`` as returning
    the left channel at index 0 and the right channel at index 1.

    Raises:
        ValueError: if the extension is not .wav/.mp3, or the audio is not 2-channel.
    """
    ext = os.path.splitext(audio_path)[1].lower()
    if ext == ".wav":
        audio = AudioSegment.from_wav(audio_path)
    elif ext == ".mp3":
        audio = AudioSegment.from_file(audio_path, format="mp3")
    else:
        raise ValueError(f"Unsupported file format: {audio_path}")

    channels = audio.split_to_mono()
    if len(channels) != 2:
        raise ValueError(f"Audio {audio_path} does not have 2 channels.")

    channels[0].export("temp_mono_speaker1.wav", format="wav")  # Left
    channels[1].export("temp_mono_speaker2.wav", format="wav")  # Right


def convert_to_mono(input_path):
    """Write a mono copy of ``input_path`` as ``<base>_merged.wav`` and return that path."""
    audio = AudioSegment.from_file(input_path)
    base, ext = os.path.splitext(input_path)
    output_path = f"{base}_merged.wav"
    print('output_path', output_path)
    mono = audio.set_channels(1)
    mono.export(output_path, format="wav")
    return output_path


def save_temp_audio(waveform, sample_rate, path):
    """Save ``waveform`` to ``path``; a 1-D tensor is promoted to a single channel first."""
    waveform = waveform.unsqueeze(0) if waveform.dim() == 1 else waveform
    torchaudio.save(path, waveform, sample_rate)


def format_audio(audio_path):
    """Load an audio file, downmix it to mono and resample it to 16 kHz.

    Returns:
        (waveform, 16000): the squeezed waveform tensor and its sample rate.
    """
    input_audio, sample_rate = torchaudio.load(audio_path)
    # torchaudio.load returns (channels, frames). Average every multi-channel
    # input, not only stereo, so that squeeze() below yields a single channel.
    if input_audio.shape[0] > 1:
        input_audio = torch.mean(input_audio, dim=0, keepdim=True)
    resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
    input_audio = resampler(input_audio)
    print('resampled')
    return input_audio.squeeze(), 16000


def assign_timestamps(asr_segments, audio_path):
    """Overwrite each segment's start/end with word-count-proportional estimates.

    The file's total duration is spread uniformly over all words. Segments are
    laid end to end in list order; any timestamps already present are
    discarded. The list is mutated in place and also returned.

    Raises:
        ValueError: if the segments contain no words.
    """
    waveform, sr = format_audio(audio_path)
    total_duration = waveform.shape[-1] / sr

    total_words = sum(len(seg["text"].split()) for seg in asr_segments)
    if total_words == 0:
        raise ValueError("Total number of words in ASR segments is zero. Cannot assign timestamps.")

    avg_word_duration = total_duration / total_words
    current_time = 0.0
    for segment in asr_segments:
        segment_duration = len(segment["text"].split()) * avg_word_duration
        segment["start"] = round(current_time, 3)
        segment["end"] = round(current_time + segment_duration, 3)
        current_time += segment_duration
    return asr_segments


def hf_chunks_to_whisperx_segments(chunks):
    """Convert HF ASR-pipeline chunks into ``{"text", "start", "end"}`` dicts.

    Chunks whose "timestamp" is empty or not a list/tuple are dropped.
    """
    return [
        {
            "text": chunk["text"],
            "start": chunk["timestamp"][0],
            "end": chunk["timestamp"][1],
        }
        for chunk in chunks
        if chunk["timestamp"] and isinstance(chunk["timestamp"], (list, tuple))
    ]


def align_words_to_segments(words, segments, window=5.0):
    """Pair each timed word with the first segment whose [start, end) contains its start.

    Assumes ``words`` and ``segments`` are both ordered by time. The scan
    pointer only moves forward, and ``window`` seconds bound how far around
    a word segments are searched. Words that lack a "start"/"end" value are
    skipped, as are words whose start falls in no segment. Unaligned words
    from whisperx have no such keys.

    Returns:
        list of (word, segment) tuples.
    """
    aligned = []
    seg_idx = 0
    for word in words:
        word_start = word.get("start")
        word_end = word.get("end")
        if word_start is None or word_end is None:
            continue

        # Drop segments that end well before this word; later words cannot need them.
        while seg_idx < len(segments) and segments[seg_idx]["end"] < word_start - window:
            seg_idx += 1

        for j in range(seg_idx, len(segments)):
            seg = segments[j]
            if seg["start"] > word_end + window:
                break
            if seg["start"] <= word_start < seg["end"]:
                aligned.append((word, seg))
                break
    return aligned


def post_process_transcription(transcription, max_repeats=2):
    """Reduce ASR repetition artifacts in ``transcription``.

    - The text is re-tokenized with ``_TOKEN_RE``, which drops characters
      outside word tokens and trailing .,!?.
    - Within each token, runs of a 1-3 character unit repeated 3+ times are
      deleted.
    - Identical consecutive tokens are kept at most ``max_repeats`` times
      (for max_repeats >= 1).

    Returns a single-spaced, stripped string.
    """
    cleaned_tokens = []
    repetition_count = 0
    previous_token = None

    for token in _TOKEN_RE.findall(transcription):
        reduced_token = _REPEATED_UNIT_RE.sub("", token)
        if reduced_token == previous_token:
            repetition_count += 1
            if repetition_count <= max_repeats:
                cleaned_tokens.append(reduced_token)
        else:
            repetition_count = 1
            cleaned_tokens.append(reduced_token)
        previous_token = reduced_token

    # Emptied tokens leave extra spaces; collapse them.
    return re.sub(r'\s+', ' ', " ".join(cleaned_tokens)).strip()


def post_merge_consecutive_segments_from_text(transcription_text: str) -> str:
    """Merge consecutive lines tagged with the same ``[SPEAKER_NN]`` label.

    Only two-digit tags are recognized. Text before the first tag is dropped.
    Each merged run becomes one ``[SPEAKER_NN] text`` line, and the lines are
    joined with newlines.
    """
    parts = re.split(r'(\[SPEAKER_\d{2}\])', transcription_text)
    merged_lines = []
    current_speaker = None
    current_segment = []

    # A capturing split yields [prefix, tag, text, tag, text, ...].
    for i in range(1, len(parts) - 1, 2):
        speaker = re.search(r'\d{2}', parts[i]).group()
        text = parts[i + 1].strip()
        if speaker == current_speaker:
            current_segment.append(text)
        else:
            if current_speaker is not None:
                merged_lines.append(f'[SPEAKER_{current_speaker}] {" ".join(current_segment)}')
            current_speaker = speaker
            current_segment = [text]

    if current_speaker is not None:
        merged_lines.append(f'[SPEAKER_{current_speaker}] {" ".join(current_segment)}')

    return '\n'.join(merged_lines).strip()


def cleanup_temp_files(*file_paths):
    """Delete each given path that is truthy and exists; other entries are ignored."""
    for path in file_paths:
        if path and os.path.exists(path):
            os.remove(path)


def load_whisper_model(model_path: str):
    """Load a whisper_timestamped model onto the module-level DEVICE."""
    return whisper_ts.load_model(model_path, device=DEVICE)


def transcribe_audio(model, audio_path: str) -> Dict:
    """Transcribe audio with whisper_timestamped and flatten its word timings.

    ``audio_path`` is forwarded unchanged to ``whisper_ts.transcribe``. The
    use_v2 path in this module passes a waveform tensor rather than a path.

    Returns:
        On success, a dict with keys audio_path, text, segments, words,
        duration and success=True. On any exception, a dict with keys
        audio_path, error and success=False. The error is deliberately
        reported rather than raised.
    """
    try:
        result = whisper_ts.transcribe(
            model,
            audio_path,
            beam_size=5,
            best_of=5,
            temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
            vad=False,
            detect_disfluencies=True,
        )

        words = [
            {
                'word': word.get('word', '').strip(),
                'start': word.get('start', 0),
                'end': word.get('end', 0),
                'confidence': word.get('confidence', 0),
            }
            for segment in result.get('segments', [])
            for word in segment.get('words', [])
        ]

        return {
            'audio_path': audio_path,
            'text': result['text'].strip(),
            'segments': result.get('segments', []),
            'words': words,
            'duration': result.get('duration', 0),
            'success': True,
        }

    except Exception as e:
        return {
            'audio_path': audio_path,
            'error': str(e),
            'success': False,
        }
# Models are loaded once, at import time, and shared by the functions below.
diarization_pipeline = DiarizationPipeline.from_pretrained("./pyannote/config.yaml")
# NOTE(review): the word-alignment model is loaded for English (language_code="en");
# confirm this matches the language MODEL_PATH_1 transcribes.
align_model, metadata = whisperx.load_align_model(language_code="en", device=DEVICE)
asr_pipe = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_PATH_1,
    chunk_length_s=30,
    device=DEVICE,
    return_timestamps=True)

# Fixed temp filenames. The first two are written by split_stereo_channels().
# NOTE(review): shared names mean concurrent generate() calls can clobber each other.
_SPEAKER1_TMP = "temp_mono_speaker1.wav"
_SPEAKER2_TMP = "temp_mono_speaker2.wav"
_FULL_TMP = "tmp_full.wav"


def diarization(audio_path):
    """Run the pyannote pipeline and return its tracks as (segment, track, label) tuples."""
    diarization_result = diarization_pipeline(audio_path)
    diarized_segments = list(diarization_result.itertracks(yield_label=True))
    print('diarized_segments', diarized_segments)
    return diarized_segments


def asr(audio_path):
    """Transcribe ``audio_path`` with the HF pipeline and return timed segment dicts.

    Segment times come from assign_timestamps(), which replaces the pipeline's
    chunk timestamps with word-count-based estimates.
    """
    print(f"[DEBUG] Starting ASR on audio: {audio_path}")
    asr_result = asr_pipe(audio_path, return_timestamps=True)
    print(f"[DEBUG] Raw ASR result: {asr_result}")
    asr_segments = hf_chunks_to_whisperx_segments(asr_result['chunks'])
    return assign_timestamps(asr_segments, audio_path)


def align_asr_to_diarization(asr_segments, diarized_segments, audio_path):
    """Label ASR words with diarization speakers.

    Word timings are computed with whisperx.align. Each word is then assigned
    to the diarization segment containing its start.

    Returns:
        list of ``"[SPEAKER] word word ..."`` strings, one per diarization
        segment that received words, ordered by (start, end, speaker).
    """
    waveform, _ = format_audio(audio_path)
    word_segments = whisperx.align(asr_segments, align_model, metadata, waveform, DEVICE)
    words = word_segments['word_segments']

    diarized = [
        {"start": segment.start, "end": segment.end, "speaker": speaker}
        for segment, _track, speaker in diarized_segments
    ]
    aligned_pairs = align_words_to_segments(words, diarized)

    segment_map = {}
    for word, segment in aligned_pairs:
        key = (segment["start"], segment["end"], segment["speaker"])
        segment_map.setdefault(key, []).append(word["word"])

    return [
        f"[{speaker}] {' '.join(segment_words)}"
        for (_start, _end, speaker), segment_words in sorted(segment_map.items())
    ]


def _result_to_segments(result, speaker_label):
    """Turn a transcribe_audio() result into (start, end, speaker, text) tuples."""
    if not result.get("success", False):
        # transcribe_audio reports failures instead of raising; surface them
        # so that a missing speaker in the output is explainable.
        print(f"[WARN] Transcription for {speaker_label} failed: {result.get('error')}")
    segments = result.get("segments") or []
    return [
        (seg.get("start", 0.0), seg.get("end", 0.0), speaker_label,
         post_process_transcription(seg.get("text", "").strip()))
        for seg in segments
        if seg.get("text")
    ]


def _generate_stereo(audio_path):
    """use_v2 path: transcribe each stereo channel separately, then interleave by start time."""
    model = load_whisper_model(MODEL_PATH_2)
    split_stereo_channels(audio_path)
    # NOTE(review): split_stereo_channels writes channel index 1 to
    # temp_mono_speaker2.wav, and pydub documents index 1 as the RIGHT channel.
    # It is nevertheless treated as "left" / "Speaker 1" here. The labels are
    # kept unchanged; confirm which speaker is intended.
    left_channel_path = _SPEAKER2_TMP
    right_channel_path = _SPEAKER1_TMP

    left_waveform, _ = format_audio(left_channel_path)
    right_waveform, _ = format_audio(right_channel_path)

    left_result = transcribe_audio(model, left_waveform)
    right_result = transcribe_audio(model, right_waveform)

    merged_transcript = sorted(
        _result_to_segments(left_result, "Speaker 1") + _result_to_segments(right_result, "Speaker 2"),
        key=lambda seg: float(seg[0]) if seg[0] is not None else float("inf"),
    )
    return "".join(
        f"[{speaker}]: {text}\n" for _start, _end, speaker, text in merged_transcript
    ).strip()


def _generate_mono(audio_path):
    """Default path: downmix, diarize, transcribe, align words to speakers, merge turns."""
    mono_audio_path = None
    try:
        mono_audio_path = convert_to_mono(audio_path)
        waveform, sr = format_audio(mono_audio_path)
        save_temp_audio(waveform, sr, _FULL_TMP)

        diarized_segments = diarization(_FULL_TMP)
        asr_segments = asr(_FULL_TMP)
        for segment in asr_segments:
            segment["text"] = post_process_transcription(segment["text"])

        aligned_text = align_asr_to_diarization(asr_segments, diarized_segments, _FULL_TMP)
        return post_merge_consecutive_segments_from_text(
            "".join(f"{line}\n" for line in aligned_text)
        )
    finally:
        # Clean up on failure too, so intermediate WAVs are not left behind.
        cleanup_temp_files(mono_audio_path, _FULL_TMP)


def generate(audio_path, use_v2):
    """Produce a speaker-labelled transcript of ``audio_path``.

    When ``use_v2`` is true, each stereo channel is treated as one speaker and
    the output uses ``[Speaker N]: text`` lines. Otherwise pyannote
    diarization is used and the output uses ``[SPEAKER_NN] text`` lines.
    Temporary files are removed even if processing raises.
    """
    try:
        if use_v2:
            return _generate_stereo(audio_path)
        return _generate_mono(audio_path)
    finally:
        cleanup_temp_files(_SPEAKER1_TMP, _SPEAKER2_TMP)