ssolito committed
Commit 9dde820 · verified · 1 Parent(s): 88dfed1

Create whisper_cs.py

Files changed (1)
  1. whisper_cs.py +326 -0
whisper_cs.py ADDED
@@ -0,0 +1,326 @@
+ import spaces
+ from pydub import AudioSegment
+ import os
+ import torchaudio
+ import torch
+ import re
+ from transformers import pipeline, WhisperForConditionalGeneration, WhisperProcessor, GenerationConfig
+ from pyannote.audio import Pipeline as DiarizationPipeline
+ import whisperx
+ import whisper_timestamped as whisper_ts
+ from typing import Dict
+
+ device = 0 if torch.cuda.is_available() else "cpu"
+ torch_dtype = torch.float32
+
+ MODEL_PATH_1 = "projecte-aina/whisper-large-v3-tiny-caesar"
+ MODEL_PATH_2 = "langtech-veu/whisper-timestamped-cs"
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
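+ # Helper: replace punctuation with spaces, collapse whitespace and lowercase the text.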
+ def clean_text(input_text):
+     remove_chars = ['.', ',', ';', ':', '¿', '?', '«', '»', '-', '¡', '!', '@',
+                     '*', '{', '}', '[', ']', '=', '/', '\\', '&', '#', '…']
+     output_text = ''.join(char if char not in remove_chars else ' ' for char in input_text)
+     return ' '.join(output_text.split()).lower()
+
+
+ def split_stereo_channels(audio_path):
+     ext = os.path.splitext(audio_path)[1].lower()
+
+     if ext == ".wav":
+         audio = AudioSegment.from_wav(audio_path)
+     elif ext == ".mp3":
+         audio = AudioSegment.from_file(audio_path, format="mp3")
+     else:
+         raise ValueError(f"Unsupported file format: {audio_path}")
+
+     channels = audio.split_to_mono()
+     if len(channels) != 2:
+         raise ValueError(f"Audio {audio_path} does not have 2 channels.")
+
+     channels[0].export("temp_mono_speaker1.wav", format="wav")  # Right
+     channels[1].export("temp_mono_speaker2.wav", format="wav")  # Left
+
+
+ def convert_to_mono(input_path):
+     audio = AudioSegment.from_file(input_path)
+     base, ext = os.path.splitext(input_path)
+     output_path = f"{base}_merged.wav"
+     print('output_path', output_path)
+     mono = audio.set_channels(1)
+     mono.export(output_path, format="wav")
+     return output_path
+
+ def save_temp_audio(waveform, sample_rate, path):
+     waveform = waveform.unsqueeze(0) if waveform.dim() == 1 else waveform
+     torchaudio.save(path, waveform, sample_rate)
+
+ def format_audio(audio_path):
+     input_audio, sample_rate = torchaudio.load(audio_path)
+     if input_audio.shape[0] == 2:
+         input_audio = torch.mean(input_audio, dim=0, keepdim=True)
+     resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
+     input_audio = resampler(input_audio)
+     print('resampled')
+     return input_audio.squeeze(), 16000
+
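+ # Assign approximate timestamps to ASR segments by spreading the total audio duration
+ # over all words, assuming a constant speaking rate.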
+ def assign_timestamps(asr_segments, audio_path):
+     waveform, sr = format_audio(audio_path)
+     total_duration = waveform.shape[-1] / sr
+
+     total_words = sum(len(seg["text"].split()) for seg in asr_segments)
+     if total_words == 0:
+         raise ValueError("Total number of words in ASR segments is zero. Cannot assign timestamps.")
+
+     avg_word_duration = total_duration / total_words
+
+     current_time = 0.0
+     for segment in asr_segments:
+         word_count = len(segment["text"].split())
+         segment_duration = word_count * avg_word_duration
+         segment["start"] = round(current_time, 3)
+         segment["end"] = round(current_time + segment_duration, 3)
+         current_time += segment_duration
+
+     return asr_segments
+
+ def hf_chunks_to_whisperx_segments(chunks):
+     return [
+         {
+             "text": chunk["text"],
+             "start": chunk["timestamp"][0],
+             "end": chunk["timestamp"][1],
+         }
+         for chunk in chunks
+         if chunk["timestamp"] and isinstance(chunk["timestamp"], (list, tuple))
+     ]
+
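+ # Pair each word with the diarized segment whose time span contains the word's start,
+ # only considering segments within `window` seconds of the word.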
+ def align_words_to_segments(words, segments, window=5.0):
+     aligned = []
+     seg_idx = 0
+     for word in words:
+         while seg_idx < len(segments) and segments[seg_idx]["end"] < word["start"] - window:
+             seg_idx += 1
+         for j in range(seg_idx, len(segments)):
+             seg = segments[j]
+             if seg["start"] > word["end"] + window:
+                 break
+             if seg["start"] <= word["start"] < seg["end"]:
+                 aligned.append((word, seg))
+                 break
+     return aligned
+
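+ # Reduce common ASR hallucinations: collapse repeated short character patterns inside a
+ # token and drop a token once it repeats more than `max_repeats` times consecutively.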
+ def post_process_transcription(transcription, max_repeats=2):
+     tokens = re.findall(r'\b\w+\'?\w*\b[.,!?]?', transcription)
+
+     cleaned_tokens = []
+     repetition_count = 0
+     previous_token = None
+
+     for token in tokens:
+         # Collapse a run of a repeated 1-3 character pattern to a single occurrence.
+         reduced_token = re.sub(r"(\w{1,3})(\1{2,})", r"\1", token)
+
+         if reduced_token == previous_token:
+             repetition_count += 1
+             if repetition_count <= max_repeats:
+                 cleaned_tokens.append(reduced_token)
+         else:
+             repetition_count = 1
+             cleaned_tokens.append(reduced_token)
+
+         previous_token = reduced_token
+
+     cleaned_transcription = " ".join(cleaned_tokens)
+     cleaned_transcription = re.sub(r'\s+', ' ', cleaned_transcription).strip()
+
+     return cleaned_transcription
+
+
+ def post_merge_consecutive_segments_from_text(transcription_text: str) -> str:
+     segments = re.split(r'(\[SPEAKER_\d{2}\])', transcription_text)
+     merged_transcription = ''
+     current_speaker = None
+     current_segment = []
+
+     for i in range(1, len(segments) - 1, 2):
+         speaker_tag = segments[i]
+         text = segments[i + 1].strip()
+
+         speaker = re.search(r'\d{2}', speaker_tag).group()
+
+         if speaker == current_speaker:
+             current_segment.append(text)
+         else:
+             if current_speaker is not None:
+                 merged_transcription += f'[SPEAKER_{current_speaker}] {" ".join(current_segment)}\n'
+             current_speaker = speaker
+             current_segment = [text]
+
+     if current_speaker is not None:
+         merged_transcription += f'[SPEAKER_{current_speaker}] {" ".join(current_segment)}\n'
+
+     return merged_transcription.strip()
+
+ def cleanup_temp_files(*file_paths):
+     for path in file_paths:
+         if path and os.path.exists(path):
+             os.remove(path)
+
+
+
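+ # --- whisper-timestamped models, used by the use_v2 branch of generate() ---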
+ def load_whisper_model(model_path: str):
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     model = whisper_ts.load_model(model_path, device=device)
+     return model
+
+ def transcribe_audio(model, audio_path: str) -> Dict:
+     try:
+         result = whisper_ts.transcribe(
+             model,
+             audio_path,
+             beam_size=5,
+             best_of=5,
+             temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
+             vad=False,
+             detect_disfluencies=True,
+         )
+
+         words = []
+         for segment in result.get('segments', []):
+             for word in segment.get('words', []):
+                 word_text = word.get('word', '').strip()
+                 if word_text.startswith(' '):
+                     word_text = word_text[1:]
+
+                 words.append({
+                     'word': word_text,
+                     'start': word.get('start', 0),
+                     'end': word.get('end', 0),
+                     'confidence': word.get('confidence', 0)
+                 })
+
+         return {
+             'audio_path': audio_path,
+             'text': result['text'].strip(),
+             'segments': result.get('segments', []),
+             'words': words,
+             'duration': result.get('duration', 0),
+             'success': True
+         }
+
+     except Exception as e:
+         return {
+             'audio_path': audio_path,
+             'error': str(e),
+             'success': False
+         }
+
+
+
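+ # Global models loaded at import time: pyannote diarization, whisperx alignment (English)
+ # and the Hugging Face ASR pipeline built on MODEL_PATH_1.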
+ diarization_pipeline = DiarizationPipeline.from_pretrained("./pyannote/config.yaml")
+ align_model, metadata = whisperx.load_align_model(language_code="en", device=DEVICE)
+
+ asr_pipe = pipeline(
+     task="automatic-speech-recognition",
+     model=MODEL_PATH_1,
+     chunk_length_s=30,
+     device=DEVICE,
+     return_timestamps=True)
+
+ def diarization(audio_path):
+     diarization_result = diarization_pipeline(audio_path)
+     diarized_segments = list(diarization_result.itertracks(yield_label=True))
+     print('diarized_segments', diarized_segments)
+     return diarized_segments
+
+ def asr(audio_path):
+     print(f"[DEBUG] Starting ASR on audio: {audio_path}")
+     asr_result = asr_pipe(audio_path, return_timestamps=True)
+     print(f"[DEBUG] Raw ASR result: {asr_result}")
+     asr_segments = hf_chunks_to_whisperx_segments(asr_result['chunks'])
+     asr_segments = assign_timestamps(asr_segments, audio_path)
+     return asr_segments
+
+ def align_asr_to_diarization(asr_segments, diarized_segments, audio_path):
+     waveform, sample_rate = format_audio(audio_path)
+
+     word_segments = whisperx.align(asr_segments, align_model, metadata, waveform, DEVICE)
+     words = word_segments['word_segments']
+
+     diarized = [{"start": segment.start, "end": segment.end, "speaker": speaker} for segment, _, speaker in diarized_segments]
+
+     aligned_pairs = align_words_to_segments(words, diarized)
+
+     output = []
+     segment_map = {}
+     for word, segment in aligned_pairs:
+         key = (segment["start"], segment["end"], segment["speaker"])
+         if key not in segment_map:
+             segment_map[key] = []
+         segment_map[key].append(word["word"])
+
+     for (start, end, speaker), words in sorted(segment_map.items()):
+         output.append(f"[{speaker}] {' '.join(words)}")
+
+     return output
+
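+ # Main entry point. With use_v2, each stereo channel is transcribed separately with the
+ # whisper-timestamped model and the channels are interleaved by start time; otherwise the
+ # audio is downmixed to mono and run through ASR + diarization + word/speaker alignment.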
+ def generate(audio_path, use_v2):
+
+     if use_v2:
+         model = load_whisper_model(MODEL_PATH_2)
+         split_stereo_channels(audio_path)
+
+         left_channel_path = "temp_mono_speaker2.wav"
+         right_channel_path = "temp_mono_speaker1.wav"
+
+         left_waveform, left_sr = format_audio(left_channel_path)
+         right_waveform, right_sr = format_audio(right_channel_path)
+         left_result = transcribe_audio(model, left_waveform)
+         right_result = transcribe_audio(model, right_waveform)
+
+         def get_segments(result, speaker_label):
+             segments = result.get("segments", [])
+             if not segments:
+                 return []
+             return [
+                 (seg.get("start", 0.0), seg.get("end", 0.0), speaker_label, post_process_transcription(seg.get("text", "").strip()))
+                 for seg in segments if seg.get("text")
+             ]
+
+         left_segs = get_segments(left_result, "Speaker 1")
+         right_segs = get_segments(right_result, "Speaker 2")
+
+         merged_transcript = sorted(
+             left_segs + right_segs,
+             key=lambda x: float(x[0]) if x[0] is not None else float("inf")
+         )
+
+         output = ""
+         for start, end, speaker, text in merged_transcript:
+             output += f"[{speaker}]: {text}\n"
+
+         clean_output = output.strip()
+
+     else:
+         mono_audio_path = convert_to_mono(audio_path)
+         waveform, sr = format_audio(mono_audio_path)
+         tmp_full_path = "tmp_full.wav"
+         save_temp_audio(waveform, sr, tmp_full_path)
+         diarized_segments = diarization(tmp_full_path)
+         asr_segments = asr(tmp_full_path)
+         for segment in asr_segments:
+             segment["text"] = post_process_transcription(segment["text"])
+         aligned_text = align_asr_to_diarization(asr_segments, diarized_segments, tmp_full_path)
+
+         clean_output = ""
+         for line in aligned_text:
+             clean_output += f"{line}\n"
+         clean_output = post_merge_consecutive_segments_from_text(clean_output)
+         cleanup_temp_files(mono_audio_path, tmp_full_path)
+
+     cleanup_temp_files(
+         "temp_mono_speaker1.wav",
+         "temp_mono_speaker2.wav"
+     )
+
+     return clean_output
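
A minimal usage sketch, assuming the pyannote config at ./pyannote/config.yaml and the model weights are available locally; the file name and flag values below are illustrative, not part of the commit:

    from whisper_cs import generate

    # Stereo call recording, one speaker per channel: per-channel whisper-timestamped path.
    print(generate("example_call.wav", use_v2=True))

    # Any recording: downmix to mono, then ASR + diarization + word/speaker alignment.
    print(generate("example_call.wav", use_v2=False))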