# asr-inference-cs / whisper_cs.py
import spaces
from pydub import AudioSegment
import os
import torchaudio
import torch
import re
from transformers import pipeline, WhisperForConditionalGeneration, WhisperProcessor, GenerationConfig
from pyannote.audio import Pipeline as DiarizationPipeline
import whisperx
import whisper_timestamped as whisper_ts
from typing import Dict

device = 0 if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float32

MODEL_PATH_1 = "projecte-aina/whisper-large-v3-tiny-caesar"
MODEL_PATH_2 = "langtech-veu/whisper-timestamped-cs"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def clean_text(input_text):
    # Replace punctuation and special characters with spaces, collapse whitespace, lowercase.
    remove_chars = ['.', ',', ';', ':', '¿', '?', '«', '»', '-', '¡', '!', '@',
                    '*', '{', '}', '[', ']', '=', '/', '\\', '&', '#', '…']
    output_text = ''.join(char if char not in remove_chars else ' ' for char in input_text)
    return ' '.join(output_text.split()).lower()
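
# Illustrative example (not called elsewhere in this module): punctuation is replaced by
# spaces and the result is lowercased, e.g. clean_text("¡Hola, món!") -> "hola món".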

def split_stereo_channels(audio_path):
    # Split a stereo recording into two mono WAV files, one per channel.
    ext = os.path.splitext(audio_path)[1].lower()
    if ext == ".wav":
        audio = AudioSegment.from_wav(audio_path)
    elif ext == ".mp3":
        audio = AudioSegment.from_file(audio_path, format="mp3")
    else:
        raise ValueError(f"Unsupported file format: {audio_path}")

    channels = audio.split_to_mono()
    if len(channels) != 2:
        raise ValueError(f"Audio {audio_path} does not have 2 channels.")
    channels[0].export("temp_mono_speaker1.wav", format="wav")  # left channel
    channels[1].export("temp_mono_speaker2.wav", format="wav")  # right channel

def convert_to_mono(input_path):
    # Downmix any input file to a single-channel WAV stored next to the original.
    audio = AudioSegment.from_file(input_path)
    base, ext = os.path.splitext(input_path)
    output_path = f"{base}_merged.wav"
    mono = audio.set_channels(1)
    mono.export(output_path, format="wav")
    return output_path

def save_temp_audio(waveform, sample_rate, path):
    # torchaudio.save expects a (channels, samples) tensor, so add a channel dim if needed.
    waveform = waveform.unsqueeze(0) if waveform.dim() == 1 else waveform
    torchaudio.save(path, waveform, sample_rate)

def format_audio(audio_path):
    # Load audio, downmix stereo to mono, and resample to the 16 kHz rate Whisper expects.
    input_audio, sample_rate = torchaudio.load(audio_path)
    if input_audio.shape[0] == 2:
        input_audio = torch.mean(input_audio, dim=0, keepdim=True)
    resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
    input_audio = resampler(input_audio)
    return input_audio.squeeze(), 16000

def assign_timestamps(asr_segments, audio_path):
    # Distribute the total audio duration across segments proportionally to their word counts.
    waveform, sr = format_audio(audio_path)
    total_duration = waveform.shape[-1] / sr
    total_words = sum(len(seg["text"].split()) for seg in asr_segments)
    if total_words == 0:
        raise ValueError("Total number of words in ASR segments is zero. Cannot assign timestamps.")

    avg_word_duration = total_duration / total_words
    current_time = 0.0
    for segment in asr_segments:
        word_count = len(segment["text"].split())
        segment_duration = word_count * avg_word_duration
        segment["start"] = round(current_time, 3)
        segment["end"] = round(current_time + segment_duration, 3)
        current_time += segment_duration
    return asr_segments
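
# Worked example of the heuristic above, with made-up numbers: for a 10-second file and
# segments of 3 and 2 words, avg_word_duration = 10 / 5 = 2.0 s, so the segments are
# assigned (start, end) = (0.0, 6.0) and (6.0, 10.0) respectively.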

def hf_chunks_to_whisperx_segments(chunks):
    # Convert Hugging Face pipeline chunks into the {text, start, end} dicts whisperx expects.
    return [
        {
            "text": chunk["text"],
            "start": chunk["timestamp"][0],
            "end": chunk["timestamp"][1],
        }
        for chunk in chunks
        if chunk["timestamp"] and isinstance(chunk["timestamp"], (list, tuple))
    ]

def align_words_to_segments(words, segments, window=5.0):
    # Pair each word with the diarized segment whose time span contains its start,
    # only scanning segments within a +/- `window` second tolerance.
    aligned = []
    seg_idx = 0
    for word in words:
        while seg_idx < len(segments) and segments[seg_idx]["end"] < word["start"] - window:
            seg_idx += 1
        for j in range(seg_idx, len(segments)):
            seg = segments[j]
            if seg["start"] > word["end"] + window:
                break
            if seg["start"] <= word["start"] < seg["end"]:
                aligned.append((word, seg))
                break
    return aligned
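
# Illustrative example with hypothetical values: a word whose start falls inside a
# diarized turn is paired with that turn's dict:
#   words = [{"word": "hola", "start": 0.5, "end": 0.9}]
#   segments = [{"start": 0.0, "end": 2.0, "speaker": "SPEAKER_00"}]
#   align_words_to_segments(words, segments)
#   -> [({"word": "hola", "start": 0.5, "end": 0.9},
#        {"start": 0.0, "end": 2.0, "speaker": "SPEAKER_00"})]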

def post_process_transcription(transcription, max_repeats=2):
    # Collapse runs of repeated tokens (a common Whisper failure mode), keeping at most
    # `max_repeats` consecutive occurrences, and squeeze stuttered substrings inside a token.
    tokens = re.findall(r'\b\w+\'?\w*\b[.,!?]?', transcription)
    cleaned_tokens = []
    repetition_count = 0
    previous_token = None
    for token in tokens:
        reduced_token = re.sub(r"(\w{1,3})(\1{2,})", r"\1", token)
        if reduced_token == previous_token:
            repetition_count += 1
            if repetition_count <= max_repeats:
                cleaned_tokens.append(reduced_token)
        else:
            repetition_count = 1
            cleaned_tokens.append(reduced_token)
        previous_token = reduced_token
    cleaned_transcription = " ".join(cleaned_tokens)
    cleaned_transcription = re.sub(r'\s+', ' ', cleaned_transcription).strip()
    return cleaned_transcription
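
# Illustrative examples: with the default max_repeats=2,
#   post_process_transcription("hola hola hola que tal") -> "hola hola que tal"
# and a stuttered token such as "lalala" is squeezed to "la".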

def post_merge_consecutive_segments_from_text(transcription_text: str) -> str:
    # Merge consecutive lines that belong to the same [SPEAKER_XX] tag into a single line.
    segments = re.split(r'(\[SPEAKER_\d{2}\])', transcription_text)
    merged_transcription = ''
    current_speaker = None
    current_segment = []

    for i in range(1, len(segments) - 1, 2):
        speaker_tag = segments[i]
        text = segments[i + 1].strip()
        speaker = re.search(r'\d{2}', speaker_tag).group()
        if speaker == current_speaker:
            current_segment.append(text)
        else:
            if current_speaker is not None:
                merged_transcription += f'[SPEAKER_{current_speaker}] {" ".join(current_segment)}\n'
            current_speaker = speaker
            current_segment = [text]

    if current_speaker is not None:
        merged_transcription += f'[SPEAKER_{current_speaker}] {" ".join(current_segment)}\n'
    return merged_transcription.strip()
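
# Illustrative example: consecutive lines from the same speaker are merged, e.g.
#   post_merge_consecutive_segments_from_text(
#       "[SPEAKER_00] hola\n[SPEAKER_00] que tal\n[SPEAKER_01] adeu"
#   )
#   -> "[SPEAKER_00] hola que tal\n[SPEAKER_01] adeu"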

def cleanup_temp_files(*file_paths):
    # Delete any intermediate files that were actually created.
    for path in file_paths:
        if path and os.path.exists(path):
            os.remove(path)

def load_whisper_model(model_path: str):
    # Load a whisper-timestamped model on GPU if available, otherwise on CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = whisper_ts.load_model(model_path, device=device)
    return model

def transcribe_audio(model, audio_path) -> Dict:
    # Run whisper-timestamped and flatten the result into text, segments and per-word timings.
    # `audio_path` may be a file path or an in-memory 16 kHz waveform.
    try:
        result = whisper_ts.transcribe(
            model,
            audio_path,
            beam_size=5,
            best_of=5,
            temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
            vad=False,
            detect_disfluencies=True,
        )
        words = []
        for segment in result.get('segments', []):
            for word in segment.get('words', []):
                words.append({
                    'word': word.get('word', '').strip(),
                    'start': word.get('start', 0),
                    'end': word.get('end', 0),
                    'confidence': word.get('confidence', 0)
                })
        return {
            'audio_path': audio_path,
            'text': result['text'].strip(),
            'segments': result.get('segments', []),
            'words': words,
            'duration': result.get('duration', 0),
            'success': True
        }
    except Exception as e:
        return {
            'audio_path': audio_path,
            'error': str(e),
            'success': False
        }

# Models are loaded once at import time and reused across requests.
diarization_pipeline = DiarizationPipeline.from_pretrained("./pyannote/config.yaml")
align_model, metadata = whisperx.load_align_model(language_code="en", device=DEVICE)

asr_pipe = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_PATH_1,
    chunk_length_s=30,
    device=DEVICE,
    return_timestamps=True,
)

def diarization(audio_path):
    # Run speaker diarization and return (segment, track, speaker_label) tuples.
    diarization_result = diarization_pipeline(audio_path)
    diarized_segments = list(diarization_result.itertracks(yield_label=True))
    return diarized_segments

def asr(audio_path):
    # Transcribe with the HF pipeline, convert its chunks to whisperx-style segments, and
    # replace the chunk timestamps with the word-count heuristic from assign_timestamps.
    asr_result = asr_pipe(audio_path, return_timestamps=True)
    asr_segments = hf_chunks_to_whisperx_segments(asr_result['chunks'])
    asr_segments = assign_timestamps(asr_segments, audio_path)
    return asr_segments

def align_asr_to_diarization(asr_segments, diarized_segments, audio_path):
    # Word-align the ASR segments with whisperx, then group the aligned words by the
    # diarized speaker turn they fall into and emit one "[SPEAKER_XX] ..." line per turn.
    waveform, sample_rate = format_audio(audio_path)
    word_segments = whisperx.align(asr_segments, align_model, metadata, waveform, DEVICE)
    words = word_segments['word_segments']
    diarized = [
        {"start": segment.start, "end": segment.end, "speaker": speaker}
        for segment, _, speaker in diarized_segments
    ]
    aligned_pairs = align_words_to_segments(words, diarized)

    segment_map = {}
    for word, segment in aligned_pairs:
        key = (segment["start"], segment["end"], segment["speaker"])
        segment_map.setdefault(key, []).append(word["word"])

    output = []
    for (start, end, speaker), segment_words in sorted(segment_map.items()):
        output.append(f"[{speaker}] {' '.join(segment_words)}")
    return output

def generate(audio_path, use_v2):
    if use_v2:
        # v2: split the stereo file and transcribe each channel separately with the
        # whisper-timestamped model; each channel is treated as one speaker.
        model = load_whisper_model(MODEL_PATH_2)
        split_stereo_channels(audio_path)

        left_channel_path = "temp_mono_speaker1.wav"
        right_channel_path = "temp_mono_speaker2.wav"
        left_waveform, left_sr = format_audio(left_channel_path)
        right_waveform, right_sr = format_audio(right_channel_path)
        left_result = transcribe_audio(model, left_waveform)
        right_result = transcribe_audio(model, right_waveform)

        def get_segments(result, speaker_label):
            segments = result.get("segments", [])
            if not segments:
                return []
            return [
                (seg.get("start", 0.0), seg.get("end", 0.0), speaker_label,
                 post_process_transcription(seg.get("text", "").strip()))
                for seg in segments if seg.get("text")
            ]

        left_segs = get_segments(left_result, "Speaker 1")
        right_segs = get_segments(right_result, "Speaker 2")

        # Interleave both channels chronologically by segment start time.
        merged_transcript = sorted(
            left_segs + right_segs,
            key=lambda x: float(x[0]) if x[0] is not None else float("inf")
        )
        output = ""
        for start, end, speaker, text in merged_transcript:
            output += f"[{speaker}]: {text}\n"
        clean_output = output.strip()
    else:
        # v1: downmix to mono, diarize with pyannote, transcribe with the HF pipeline,
        # then align words to speaker turns.
        mono_audio_path = convert_to_mono(audio_path)
        waveform, sr = format_audio(mono_audio_path)
        tmp_full_path = "tmp_full.wav"
        save_temp_audio(waveform, sr, tmp_full_path)

        diarized_segments = diarization(tmp_full_path)
        asr_segments = asr(tmp_full_path)
        for segment in asr_segments:
            segment["text"] = post_process_transcription(segment["text"])

        aligned_text = align_asr_to_diarization(asr_segments, diarized_segments, tmp_full_path)
        clean_output = post_merge_consecutive_segments_from_text("\n".join(aligned_text))
        cleanup_temp_files(mono_audio_path, tmp_full_path)

    cleanup_temp_files(
        "temp_mono_speaker1.wav",
        "temp_mono_speaker2.wav"
    )
    return clean_output
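

# Minimal usage sketch (an assumption, not part of the original Space code): the UI layer
# is expected to call generate() with an uploaded file path. The CLI wrapper and the
# --use-v2 flag below are hypothetical; running it requires the local
# ./pyannote/config.yaml and the models referenced above to be available.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Transcribe and diarize a call recording.")
    parser.add_argument("audio_path", help="Path to a stereo WAV or MP3 recording.")
    parser.add_argument("--use-v2", action="store_true",
                        help="Use the per-channel whisper-timestamped pipeline (v2).")
    args = parser.parse_args()

    print(generate(args.audio_path, args.use_v2))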