# maiko-parakeet / app.py
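# Gradio Space: place and record a Twilio call, transcribe audio with NVIDIA Parakeet-TDT 0.6B v2,
# and draft structured meeting minutes with Qwen2.5-1.5B-Instruct.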
import os
import time
import gradio as gr
import numpy as np
import librosa
import soundfile as sf
from twilio.rest import Client
from twilio.twiml.voice_response import VoiceResponse, Dial
import requests
from datetime import datetime
import tempfile
from nemo.collections.asr.models import ASRModel
import torch
import gradio.themes as gr_themes
import csv
from pathlib import Path
import shutil
import gc
import re
import threading
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from docx import Document
from pydub import AudioSegment
# ========== Twilio Functions ==========
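# Credentials are read from the TWILIO_ACCOUNT_SID, TWILIO_AUTH_TOKEN, and
# TWILIO_PHONE_NUMBER environment variables.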
def get_twilio_credentials():
    account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
    auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
    twilio_number = os.environ.get("TWILIO_PHONE_NUMBER")
    return account_sid, auth_token, twilio_number
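# Places an outbound call that announces recording, sends the optional conference code
# as DTMF digits, starts a recording, and dials into a conference named 'ConferenceRoom'.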
def make_conference_call(phone_number, conference_code, wait_time=30):
    try:
        account_sid, auth_token, twilio_number = get_twilio_credentials()
        if not all([account_sid, auth_token, twilio_number]):
            return None, "Twilio credentials not found. Please set environment variables."
        client = Client(account_sid, auth_token)
        response = VoiceResponse()
        response.say("Joining conference call. This call will be recorded for diarization.")
        response.pause(length=2)
        if conference_code:
            for digit in conference_code:
                if digit.isdigit() or digit in ['*', '#']:
                    response.play(digits=digit)
                    response.pause(length=1)
        response.record(timeout=0, transcribe=False, recording_status_callback="/recording-status")
        dial = Dial()
        dial.conference('ConferenceRoom', record='record-from-start', recording_status_callback="/recording-status")
        response.append(dial)
        call = client.calls.create(
            to=phone_number,
            from_=twilio_number,
            twiml=str(response),
            record=True
        )
        return call.sid, f"Call initiated with SID: {call.sid}. Wait for the call to complete before retrieving the recording."
    except Exception as e:
        return None, f"Error initiating call: {str(e)}"
def check_call_status(call_sid):
    try:
        account_sid, auth_token, _ = get_twilio_credentials()
        if not all([account_sid, auth_token]):
            return None, "Twilio credentials not found. Please set environment variables."
        client = Client(account_sid, auth_token)
        call = client.calls(call_sid).fetch()
        if call.status in ['in-progress', 'queued', 'ringing']:
            return None, f"Call is still {call.status}. Please check again later."
        recordings = client.recordings.list(call_sid=call_sid)
        if not recordings:
            return None, "No recordings found for this call yet. Please check again later."
        recording = recordings[0]
        recording_url = f"https://api.twilio.com/2010-04-01/Accounts/{account_sid}/Recordings/{recording.sid}.wav"
        response = requests.get(recording_url, auth=(account_sid, auth_token))
        if response.status_code != 200:
            return None, f"Failed to download recording: {response.status_code}"
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
        temp_file.write(response.content)
        temp_file.close()
        return temp_file.name, f"Recording downloaded successfully: {temp_file.name}"
    except Exception as e:
        return None, f"Error checking call status: {str(e)}"
# ========== Audio Processing ==========
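# The ASR model works on 16 kHz audio, so recordings are resampled before transcription.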
def upsample_to_16k(input_wav):
    try:
        y, sr = librosa.load(input_wav, sr=None)
        if sr != 16000:
            y = librosa.resample(y, orig_sr=sr, target_sr=16000)
        output_file = f"16k_{datetime.now().strftime('%Y%m%d_%H%M%S')}.wav"
        sf.write(output_file, y, 16000)
        return output_file, f"Audio upsampled to 16kHz: {output_file}"
    except Exception as e:
        return None, f"Error upsampling audio: {str(e)}"
# ========== ASR and Meeting Minutes Setup ==========
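# Qwen2.5-1.5B-Instruct is loaded once at startup for meeting-minutes generation
# (float16 on GPU, float32 on CPU).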
QWEN_MODEL = "Qwen/Qwen2.5-1.5B-Instruct"
qwen_tokenizer = AutoTokenizer.from_pretrained(QWEN_MODEL)
qwen_model = AutoModelForCausalLM.from_pretrained(
    QWEN_MODEL, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)
if torch.cuda.is_available():
    qwen_model = qwen_model.cuda()
qwen_pipe = pipeline(
    "text-generation",
    model=qwen_model,
    tokenizer=qwen_tokenizer,
    device=0 if torch.cuda.is_available() else -1,
    max_new_tokens=1024,
    do_sample=True,
    temperature=0.3,
)
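# Parakeet-TDT 0.6B v2 provides segment-level timestamps; model_lock serializes access
# to the shared ASR model across concurrent Gradio sessions.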
device = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_NAME = "nvidia/parakeet-tdt-0.6b-v2"
model = ASRModel.from_pretrained(model_name=MODEL_NAME)
model.eval()
model_lock = threading.Lock()
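# Each Gradio session gets its own /tmp/<session_hash> directory for transcripts and
# generated files; it is removed when the session ends.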
def start_session(request: gr.Request):
    session_hash = request.session_hash
    session_dir = Path(f'/tmp/{session_hash}')
    session_dir.mkdir(parents=True, exist_ok=True)
    print(f"Session with hash {session_hash} started.")
    return session_dir.as_posix()
def end_session(request: gr.Request):
    session_hash = request.session_hash
    session_dir = Path(f'/tmp/{session_hash}')
    if session_dir.exists():
        shutil.rmtree(session_dir)
    print(f"Session with hash {session_hash} ended.")
def get_audio_segment(audio_path, start_second, end_second):
    if not audio_path or not Path(audio_path).exists():
        print(f"Warning: Audio path '{audio_path}' not found or invalid for clipping.")
        return None
    try:
        start_ms = int(start_second * 1000)
        end_ms = int(end_second * 1000)
        start_ms = max(0, start_ms)
        if end_ms <= start_ms:
            end_ms = start_ms + 100
        audio = AudioSegment.from_file(audio_path)
        clipped_audio = audio[start_ms:end_ms]
        samples = np.array(clipped_audio.get_array_of_samples())
        if clipped_audio.channels == 2:
            samples = samples.reshape((-1, 2)).mean(axis=1).astype(samples.dtype)
        frame_rate = clipped_audio.frame_rate
        if frame_rate <= 0:
            frame_rate = audio.frame_rate
        if samples.size == 0:
            return None
        return (frame_rate, samples)
    except Exception as e:
        print(f"Error clipping audio {audio_path} from {start_second}s to {end_second}s: {e}")
        return None
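# SRT timestamps use the HH:MM:SS,mmm format.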
def format_srt_time(seconds: float) -> str:
    import datetime
    sanitized_total_seconds = max(0.0, seconds)
    delta = datetime.timedelta(seconds=sanitized_total_seconds)
    total_int_seconds = int(delta.total_seconds())
    hours = total_int_seconds // 3600
    remainder_seconds_after_hours = total_int_seconds % 3600
    minutes = remainder_seconds_after_hours // 60
    seconds_part = remainder_seconds_after_hours % 60
    milliseconds = delta.microseconds // 1000
    return f"{hours:02d}:{minutes:02d}:{seconds_part:02d},{milliseconds:03d}"
def generate_srt_content(segment_timestamps: list) -> str:
    srt_content = []
    for i, ts in enumerate(segment_timestamps):
        start_time = format_srt_time(ts['start'])
        end_time = format_srt_time(ts['end'])
        text = ts['segment']
        srt_content.append(str(i + 1))
        srt_content.append(f"{start_time} --> {end_time}")
        srt_content.append(text)
        srt_content.append("")
    return "\n".join(srt_content)
def get_transcripts_and_raw_times(audio_path, session_dir):
    import gradio as gr
    if not audio_path:
        gr.Error("No audio file path provided for transcription.", duration=None)
        return [], [], None, gr.DownloadButton(label="Download Transcript (CSV)", visible=False), gr.DownloadButton(label="Download Transcript (SRT)", visible=False)
    vis_data = [["N/A", "N/A", "Processing failed"]]
    raw_times_data = [[0.0, 0.0]]
    processed_audio_path = None
    csv_file_path = None
    srt_file_path = None
    original_path_name = Path(audio_path).name
    audio_name = Path(audio_path).stem
    csv_button_update = gr.DownloadButton(label="Download Transcript (CSV)", visible=False)
    srt_button_update = gr.DownloadButton(label="Download Transcript (SRT)", visible=False)
    try:
        gr.Info(f"Upsampling and loading audio: {original_path_name}", duration=2)
        upsampled_path, upsample_msg = upsample_to_16k(audio_path)
        if not upsampled_path:
            gr.Error(upsample_msg, duration=None)
            return [["Error", "Error", upsample_msg]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update
        audio = AudioSegment.from_file(upsampled_path)
        duration_sec = audio.duration_seconds
        info_path_name = Path(upsampled_path).name
        long_audio_settings_applied = False
        try:
            with model_lock:
                model.to(device)
                model.to(torch.float32)
                gr.Info(f"Transcribing {info_path_name} on {device}...", duration=2)
                if duration_sec > 480:  # 8 minutes
                    try:
                        gr.Info("Audio longer than 8 minutes. Applying optimized settings for long transcription.", duration=3)
                        print("Applying long audio settings: Local Attention and Chunking.")
                        model.change_attention_model("rel_pos_local_attn", [256, 256])
                        model.change_subsampling_conv_chunking_factor(1)  # 1 = auto select
                        long_audio_settings_applied = True
                    except Exception as setting_e:
                        gr.Warning(f"Could not apply long audio settings: {setting_e}", duration=5)
                        print(f"Warning: Failed to apply long audio settings: {setting_e}")
                model.to(torch.bfloat16)
                output = model.transcribe([upsampled_path], timestamps=True)
            if not output or not isinstance(output, list) or not output[0] or not hasattr(output[0], 'timestamp') or not output[0].timestamp or 'segment' not in output[0].timestamp:
                gr.Error("Transcription failed or produced unexpected output format.", duration=None)
                return [["Error", "Error", "Transcription Format Issue"]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update
            segment_timestamps = output[0].timestamp['segment']
            csv_headers = ["Start (s)", "End (s)", "Segment"]
            vis_data = [[f"{ts['start']:.2f}", f"{ts['end']:.2f}", ts['segment']] for ts in segment_timestamps]
            raw_times_data = [[ts['start'], ts['end']] for ts in segment_timestamps]
            try:
                csv_file_path = Path(session_dir, f"transcription_{audio_name}.csv")
                with open(csv_file_path, 'w', encoding="utf-8") as csvfile:
                    writer = csv.writer(csvfile)
                    writer.writerow(csv_headers)
                    writer.writerows(vis_data)
                print(f"CSV transcript saved to temporary file: {csv_file_path}")
                csv_button_update = gr.DownloadButton(value=csv_file_path, visible=True, label="Download Transcript (CSV)")
            except Exception as csv_e:
                gr.Error(f"Failed to create transcript CSV file: {csv_e}", duration=None)
                print(f"Error writing CSV: {csv_e}")
            if segment_timestamps:
                try:
                    srt_content = generate_srt_content(segment_timestamps)
                    srt_file_path = Path(session_dir, f"transcription_{audio_name}.srt")
                    with open(srt_file_path, 'w', encoding='utf-8') as f:
                        f.write(srt_content)
                    print(f"SRT transcript saved to temporary file: {srt_file_path}")
                    srt_button_update = gr.DownloadButton(value=srt_file_path, visible=True, label="Download Transcript (SRT)")
                except Exception as srt_e:
                    gr.Warning(f"Failed to create transcript SRT file: {srt_e}", duration=5)
                    print(f"Error writing SRT: {srt_e}")
            gr.Info("Transcription complete.", duration=2)
            return vis_data, raw_times_data, upsampled_path, csv_button_update, srt_button_update
        except torch.cuda.OutOfMemoryError as e:
            error_msg = 'CUDA out of memory. Please try a shorter audio or reduce GPU load.'
            print(f"CUDA OutOfMemoryError: {e}")
            gr.Error(error_msg, duration=None)
            return [["OOM", "OOM", error_msg]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update
        except FileNotFoundError:
            error_msg = f"Audio file for transcription not found: {Path(upsampled_path).name}."
            print(f"Error: Transcribe audio file not found at path: {upsampled_path}")
            gr.Error(error_msg, duration=None)
            return [["Error", "Error", "File not found for transcription"]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update
        except Exception as e:
            error_msg = f"Transcription failed: {e}"
            print(f"Error during transcription processing: {e}")
            gr.Error(error_msg, duration=None)
            vis_data = [["Error", "Error", error_msg]]
            raw_times_data = [[0.0, 0.0]]
            return vis_data, raw_times_data, audio_path, csv_button_update, srt_button_update
        finally:
            with model_lock:
                try:
                    if long_audio_settings_applied:
                        try:
                            print("Reverting long audio settings.")
                            model.change_attention_model("rel_pos")
                            model.change_subsampling_conv_chunking_factor(-1)
                            long_audio_settings_applied = False
                        except Exception as revert_e:
                            print(f"Warning: Failed to revert long audio settings: {revert_e}")
                            gr.Warning(f"Issue reverting model settings after long transcription: {revert_e}", duration=5)
                    if 'model' in locals() and hasattr(model, 'cpu'):
                        if device == 'cuda':
                            model.cpu()
                    gc.collect()
                    if device == 'cuda':
                        torch.cuda.empty_cache()
                except Exception as cleanup_e:
                    print(f"Error during model cleanup: {cleanup_e}")
                    gr.Warning(f"Issue during model cleanup: {cleanup_e}", duration=5)
    finally:
        if processed_audio_path and os.path.exists(processed_audio_path):
            try:
                os.remove(processed_audio_path)
                print(f"Temporary audio file {processed_audio_path} removed.")
            except Exception as e:
                print(f"Error removing temporary audio file {processed_audio_path}: {e}")
def strip_markdown(text):
    text = re.sub(r'(\*\*|__)(.*?)\1', r'\2', text)
    text = re.sub(r'(\*|_)(.*?)\1', r'\2', text)
    text = re.sub(r'`(.+?)`', r'\1', text)
    text = re.sub(r'!\[.*?\]\(.*?\)', '', text)
    text = re.sub(r'\[([^\]]+)\]\([^\)]+\)', r'\1', text)
    text = re.sub(r'^#+\s*', '', text, flags=re.MULTILINE)
    text = re.sub(r'^>\s*', '', text, flags=re.MULTILINE)
    text = re.sub(r'^-\s+', '', text, flags=re.MULTILINE)
    text = re.sub(r'^\d+\.\s+', '', text, flags=re.MULTILINE)
    text = re.sub(r'---', '', text)
    return text.strip()
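# Uses the most recent transcript CSV in the session directory as prompt context and
# writes the generated minutes to meeting_minutes.docx.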
def generate_meeting_minutes(session_dir):
    try:
        csv_files = sorted(Path(session_dir).glob("transcription_*.csv"), key=os.path.getmtime, reverse=True)
        if not csv_files:
            return "No transcript CSV found. Please transcribe first.", None, gr.update(visible=True)
        csv_path = csv_files[0]
        with open(csv_path, "r", encoding="utf-8") as f:
            transcript = f.read()
        prompt = (
            "You are an expert meeting minutes assistant. "
            "Given the following transcript CSV (with start and end times and segments), "
            "summarize the meeting into structured minutes. "
            "Provide the minutes only and nothing else: no intro, no outro, no comments, just the minutes. "
            "Include: Attendees (if mentioned), Topics, Discussion Points, Decisions, Action Items, and Next Steps. "
            "Be concise and use bullet points where possible.\n\n"
            "Transcript CSV:\n"
            f"{transcript}\n"
            "Structured Meeting Minutes:"
        )
        print("Sending prompt to Qwen2.5-1.5B-Instruct...")
        out = qwen_pipe(prompt)
        minutes = out[0]["generated_text"][len(prompt):].strip()
        clean_minutes = strip_markdown(minutes)
        docx_file = Path(session_dir) / "meeting_minutes.docx"
        doc = Document()
        for line in clean_minutes.splitlines():
            doc.add_paragraph(line)
        doc.save(docx_file)
        print("Minutes generated and saved to:", docx_file)
        return minutes, str(docx_file), gr.update(visible=True)
    except Exception as e:
        print("Error in generate_meeting_minutes:", e)
        return f"Error generating minutes: {e}", None, gr.update(visible=True)
def hangup_call(call_sid):
    try:
        account_sid, auth_token, _ = get_twilio_credentials()
        if not all([account_sid, auth_token]):
            return "Twilio credentials not found. Please set environment variables."
        client = Client(account_sid, auth_token)
        call = client.calls(call_sid).update(status="completed")
        return f"Call {call_sid} has been hung up."
    except Exception as e:
        return f"Error hanging up call: {str(e)}"
# ========== Gradio UI ==========
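# Theme built around the NVIDIA green (#76B900) palette.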
nvidia_theme = gr_themes.Default(
    primary_hue=gr_themes.Color(
        c50="#E6F1D9", c100="#CEE3B3", c200="#B5D58C", c300="#9CC766", c400="#84B940",
        c500="#76B900", c600="#68A600", c700="#5A9200", c800="#4C7E00", c900="#3E6A00", c950="#2F5600"
    ),
    neutral_hue="gray",
    font=[gr_themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
).set()
with gr.Blocks(theme=nvidia_theme) as demo:
    current_audio_path_state = gr.State(None)
    raw_timestamps_list_state = gr.State([])
    session_dir = gr.State()
    demo.load(start_session, outputs=[session_dir])
    # ====== Twilio Tab ======
    with gr.Tab("Twilio Call & Recording"):
        gr.Markdown("### 1. Make Twilio Call and Record")
        phone_number = gr.Textbox(label="Phone Number (E.164)", placeholder="+15551234567")
        conference_code = gr.Textbox(label="Conference Code (optional)", placeholder="123456#")
        call_btn = gr.Button("Make Call")
        call_sid = gr.Textbox(label="Call SID", interactive=False)
        call_status = gr.Textbox(label="Call Status", interactive=False)
        call_btn.click(
            make_conference_call,
            inputs=[phone_number, conference_code],
            outputs=[call_sid, call_status]
        )
        hangup_btn = gr.Button("Hangup Call")
        hangup_status = gr.Textbox(label="Hangup Status", interactive=False)
        hangup_btn.click(
            hangup_call,
            inputs=[call_sid],
            outputs=[hangup_status]
        )
        gr.Markdown("### 2. Retrieve Recording")
        sid_input = gr.Textbox(label="Call SID")
        get_recording_btn = gr.Button("Get Recording")
        recording_path = gr.Textbox(label="Recording File Path", interactive=False)
        recording_status = gr.Textbox(label="Recording Status", interactive=False)
        get_recording_btn.click(
            check_call_status,
            inputs=[sid_input],
            outputs=[recording_path, recording_status]
        )
        gr.Markdown("### 3. Transcribe and Analyze Processed Audio")
        transcribe_btn = gr.Button("Transcribe Processed Recording")
        vis_timestamps_df = gr.DataFrame(
            headers=["Start (s)", "End (s)", "Segment"],
            datatype=["number", "number", "str"],
            wrap=True,
            label="Transcription Segments"
        )
        download_btn_csv = gr.DownloadButton(label="Download Transcript (CSV)", visible=False)
        download_btn_srt = gr.DownloadButton(label="Download Transcript (SRT)", visible=False)
        transcribe_btn.click(
            get_transcripts_and_raw_times,
            inputs=[recording_path, session_dir],
            outputs=[vis_timestamps_df, raw_timestamps_list_state, current_audio_path_state, download_btn_csv, download_btn_srt],
        )
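    # The components below reuse the variable names defined in the Twilio tab, so the
    # mic/upload/minutes handlers registered after this point bind to these new instances.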
    # ====== Transcription & Meeting Minutes UI ======
    gr.Markdown("---")
    gr.Markdown("<p><strong style='color: #FF0000; font-size: 1.2em;'>Transcription Results (Click row to play segment)</strong></p>")
    with gr.Row():
        gen_minutes_btn = gr.Button("Generate Meeting Minutes", variant="primary")
        minutes_output = gr.Textbox(label="Structured Meeting Minutes", visible=False, lines=15)
        minutes_download = gr.DownloadButton(label="Download Meeting Minutes (.docx)", visible=False)
    with gr.Row():
        download_btn_csv = gr.DownloadButton(label="Download Transcript (CSV)", visible=False)
        download_btn_srt = gr.DownloadButton(label="Download Transcript (SRT)", visible=False)
    vis_timestamps_df = gr.DataFrame(
        headers=["Start (s)", "End (s)", "Segment"],
        datatype=["number", "number", "str"],
        wrap=True,
        label="Transcription Segments"
    )
    selected_segment_player = gr.Audio(label="Selected Segment", interactive=False)
    mic_input = gr.Audio(sources=["microphone"], type="filepath", label="Record Audio")
    mic_transcribe_btn = gr.Button("Transcribe Microphone Input", variant="primary")
    file_input = gr.Audio(sources=["upload"], type="filepath", label="Upload Audio File")
    file_transcribe_btn = gr.Button("Transcribe Uploaded File", variant="primary")
    mic_transcribe_btn.click(
        get_transcripts_and_raw_times,
        inputs=[mic_input, session_dir],
        outputs=[vis_timestamps_df, raw_timestamps_list_state, current_audio_path_state, download_btn_csv, download_btn_srt],
        api_name="transcribe_mic"
    )
    file_transcribe_btn.click(
        get_transcripts_and_raw_times,
        inputs=[file_input, session_dir],
        outputs=[vis_timestamps_df, raw_timestamps_list_state, current_audio_path_state, download_btn_csv, download_btn_srt],
        api_name="transcribe_file"
    )
    gen_minutes_btn.click(
        generate_meeting_minutes,
        inputs=[session_dir],
        outputs=[minutes_output, minutes_download, minutes_download],
    )
    def play_segment(evt: gr.SelectData, raw_ts_list, current_audio_path):
        if not isinstance(raw_ts_list, list):
            return gr.Audio(value=None, label="Selected Segment")
        if not current_audio_path:
            return gr.Audio(value=None, label="Selected Segment")
        selected_index = evt.index[0]
        if selected_index < 0 or selected_index >= len(raw_ts_list):
            return gr.Audio(value=None, label="Selected Segment")
        if not isinstance(raw_ts_list[selected_index], (list, tuple)) or len(raw_ts_list[selected_index]) != 2:
            return gr.Audio(value=None, label="Selected Segment")
        start_time_s, end_time_s = raw_ts_list[selected_index]
        segment_data = get_audio_segment(current_audio_path, start_time_s, end_time_s)
        if segment_data:
            return gr.Audio(value=segment_data, autoplay=True, label=f"Segment: {start_time_s:.2f}s - {end_time_s:.2f}s", interactive=False)
        else:
            return gr.Audio(value=None, label="Selected Segment")
    vis_timestamps_df.select(
        play_segment,
        inputs=[raw_timestamps_list_state, current_audio_path_state],
        outputs=[selected_segment_player],
    )
    demo.unload(end_session)
if __name__ == "__main__":
print("Launching Gradio Demo...")
demo.queue()
demo.launch()
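# Local usage (assumed setup): install the imported packages, export the TWILIO_* variables,
# then run `python app.py`; Gradio serves the app on http://localhost:7860 by default.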