xujinheng666 committed on
Commit
65628c8
·
verified ·
1 Parent(s): 5a44a9a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -26
app.py CHANGED
@@ -1,6 +1,7 @@
 
1
  import torch
2
  import torchaudio
3
- import os
4
  import re
5
  import streamlit as st
6
  from difflib import SequenceMatcher
@@ -15,47 +16,90 @@ language = "zh"
15
  pipe = pipeline(
16
  task="automatic-speech-recognition",
17
  model=MODEL_NAME,
18
- chunk_length_s=60,
19
- device=device
 
 
 
 
 
 
 
 
 
20
  )
21
  pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language=language, task="transcribe")
22
 
23
  # Load quality rating model
24
- rating_pipe = pipeline("text-classification", model="tabularisai/multilingual-sentiment-analysis")
 
 
 
25
 
26
- # Sentiment label mapping
27
- label_map = {"Negative": "Very Poor", "Neutral": "Neutral", "Positive": "Very Good"}
 
 
 
 
 
28
 
29
  def remove_punctuation(text):
30
  return re.sub(r'[^\w\s]', '', text)
31
 
32
  def transcribe_audio(audio_path):
33
- transcript = pipe(audio_path)["text"]
34
- return remove_punctuation(transcript)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  def rate_quality(text):
37
- result = rating_pipe(text)[0]
38
- return label_map.get(result["label"], "Unknown")
 
 
 
 
 
39
 
40
  # Streamlit UI
41
- st.set_page_config(page_title="Cantonese Audio Transcription & Analysis", layout="centered")
42
- st.title("🗣️ Cantonese Audio Transcriber & Sentiment Analyzer")
43
- st.markdown("Upload your Cantonese audio file, and we will transcribe and analyze its sentiment.")
44
 
45
- uploaded_file = st.file_uploader("Upload an audio file (WAV, MP3, etc.)", type=["wav", "mp3", "m4a"])
46
  if uploaded_file is not None:
47
- with st.spinner("Processing audio..."):
48
- temp_audio_path = "temp_audio.wav"
49
- with open(temp_audio_path, "wb") as f:
50
- f.write(uploaded_file.getbuffer())
51
- transcript = transcribe_audio(temp_audio_path)
52
- sentiment = rate_quality(transcript)
53
- os.remove(temp_audio_path)
54
 
55
- st.subheader("Transcription")
56
- st.text_area("", transcript, height=150)
 
 
57
 
58
- st.subheader("Sentiment Analysis")
59
- st.markdown(f"### 🎭 Sentiment: **{sentiment}**")
 
60
 
61
- st.success("Processing complete! 🎉")
 
1
+ import os
2
  import torch
3
  import torchaudio
4
+ import numpy as np
5
  import re
6
  import streamlit as st
7
  from difflib import SequenceMatcher
 
16
  pipe = pipeline(
17
  task="automatic-speech-recognition",
18
  model=MODEL_NAME,
19
+ chunk_length_s=30,
20
+ device=device,
21
+ generate_kwargs={
22
+ "no_repeat_ngram_size": 4,
23
+ "repetition_penalty": 1.15,
24
+ "temperature": 0.5,
25
+ "top_p": 0.97,
26
+ "top_k": 40,
27
+ "max_new_tokens": 300,
28
+ "do_sample": True
29
+ }
30
  )
31
  pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language=language, task="transcribe")
32
 
33
  # Load quality rating model
34
+ rating_pipe = pipeline("text-classification", model="tabularisai/multilingual-sentiment-analysis", device=device)
35
+
36
def is_similar(a, b, threshold=0.8):
    """Return True when strings ``a`` and ``b`` are more than ``threshold`` alike.

    Similarity is ``difflib.SequenceMatcher.ratio()``, a value in [0, 1].

    Args:
        a: First string.
        b: Second string.
        threshold: Ratio that must be strictly exceeded to count as similar.

    Returns:
        bool: ``True`` if the similarity ratio is greater than ``threshold``.
    """
    # autojunk=False: the default heuristic treats frequently occurring
    # characters as junk once the second sequence reaches 200 items, which can
    # make two nearly identical long sentences score close to 0.
    matcher = SequenceMatcher(None, a, b, autojunk=False)
    return matcher.ratio() > threshold
38
 
39
def remove_repeated_phrases(text):
    """Drop sentences that nearly repeat the sentence kept just before them.

    ``text`` is split after each sentence terminator (``。``, ``！``, ``？``,
    ``!``, ``?``); the terminator stays attached to its sentence. A sentence is
    discarded when ``is_similar`` reports it as similar to the previously kept
    one. Empty fragments (for example the one produced after a trailing
    terminator) are ignored.

    Args:
        text: Transcript text, with sentence punctuation still present.

    Returns:
        str: The kept sentences, stripped and joined by single spaces.
    """
    # Zero-width look-behind split: nothing is consumed, so terminators are kept.
    fragments = re.split(r'(?<=[。！？!?])', text)
    kept = []
    for fragment in fragments:
        sentence = fragment.strip()
        if not sentence:
            # Splitting after a final terminator yields an empty tail; skipping
            # it avoids a stray trailing space and a pointless comparison.
            continue
        if kept and is_similar(sentence, kept[-1]):
            continue
        kept.append(sentence)
    return " ".join(kept)
46
 
47
  def remove_punctuation(text):
48
  return re.sub(r'[^\w\s]', '', text)
49
 
50
  def transcribe_audio(audio_path):
51
+ waveform, sample_rate = torchaudio.load(audio_path)
52
+
53
+ if waveform.shape[0] > 1:
54
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
55
+
56
+ waveform = waveform.squeeze(0).numpy()
57
+ duration = waveform.shape[0] / sample_rate
58
+
59
+ if duration > 60:
60
+ chunk_size = sample_rate * 55
61
+ step_size = sample_rate * 50
62
+ results = []
63
+
64
+ for start in range(0, waveform.shape[0], step_size):
65
+ chunk = waveform[start:start + chunk_size]
66
+ if chunk.shape[0] == 0:
67
+ break
68
+ transcript = pipe({"sampling_rate": sample_rate, "raw": chunk})["text"]
69
+ results.append(remove_punctuation(transcript))
70
+
71
+ return remove_punctuation(remove_repeated_phrases(" ".join(results)))
72
+
73
+ return remove_punctuation(remove_repeated_phrases(pipe({"sampling_rate": sample_rate, "raw": waveform})["text"]))
74
 
75
  def rate_quality(text):
76
+ chunks = [text[i:i+512] for i in range(0, len(text), 512)]
77
+ results = rating_pipe(chunks, batch_size=4)
78
+
79
+ label_map = {"Very Negative": "Very Poor", "Negative": "Poor", "Neutral": "Neutral", "Positive": "Good", "Very Positive": "Very Good"}
80
+ processed_results = [label_map.get(res["label"], "Unknown") for res in results]
81
+
82
+ return max(set(processed_results), key=processed_results.count)
83
 
84
  # Streamlit UI
85
+ st.title("Audio Transcription and Quality Rating")
86
+
87
+ uploaded_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
88
 
 
89
  if uploaded_file is not None:
90
+ st.audio(uploaded_file, format="audio/wav")
91
+
92
+ temp_audio_path = "temp_audio.wav"
93
+ with open(temp_audio_path, "wb") as f:
94
+ f.write(uploaded_file.read())
 
 
95
 
96
+ st.write("Processing audio...")
97
+ transcript = transcribe_audio(temp_audio_path)
98
+ st.subheader("Transcript")
99
+ st.write(transcript)
100
 
101
+ quality_rating = rate_quality(transcript)
102
+ st.subheader("Quality Rating")
103
+ st.write(quality_rating)
104
 
105
+ os.remove(temp_audio_path)