ThanhNguyen1811 commited on
Commit
58445a7
·
verified ·
1 Parent(s): 949c8f8

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +31 -73
  2. requirements.txt +2 -2
app.py CHANGED
@@ -1,73 +1,31 @@
1
- import gradio as gr
2
- import torch
3
- import whisper
4
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
-
6
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
-
8
- # Load PhoBERT model
9
- MODEL_NAME = "vinai/phobert-base-v2"
10
- phobert = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3)
11
- phobert.load_state_dict(torch.load("best_model_state.bin", map_location=device))
12
- phobert.to(device)
13
- phobert.eval()
14
-
15
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
16
- label_map = {0: "An toàn", 1: "Tiêu cực", 2: "Nguy cơ bạo lực"}
17
-
18
- # Load Whisper model
19
- asr_model = whisper.load_model("base")
20
-
21
- def predict_emotion_from_text(text):
22
- inputs = tokenizer.encode_plus(
23
- text.lower(),
24
- return_tensors="pt",
25
- max_length=128,
26
- padding="max_length",
27
- truncation=True
28
- )
29
- input_ids = inputs["input_ids"].to(device)
30
- attention_mask = inputs["attention_mask"].to(device)
31
-
32
- with torch.no_grad():
33
- outputs = phobert(input_ids=input_ids, attention_mask=attention_mask)
34
- probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
35
- pred = torch.argmax(probs, dim=1).item()
36
-
37
- return f"{label_map[pred]} (độ tin cậy: {probs[0][pred]:.2f})"
38
-
39
- # Hàm xử lý audio
40
- def analyze_from_audio(audio):
41
- if audio is None:
42
- return "Không có âm thanh"
43
-
44
- # Whisper yêu cầu path tới file .wav
45
- import tempfile
46
- import scipy.io.wavfile
47
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
48
- scipy.io.wavfile.write(tmp.name, 16000, audio[1])
49
- result = asr_model.transcribe(tmp.name, language="vi")
50
-
51
- text = result["text"]
52
- if not text.strip():
53
- return "Không nhận diện được giọng nói"
54
-
55
- prediction = predict_emotion_from_text(text)
56
- return f"Văn bản: {text}\n\nDự đoán: {prediction}"
57
-
58
- with gr.Blocks() as demo:
59
- gr.Markdown("## 🎤 PhoBERT: Nhận diện cảm xúc học sinh từ giọng nói và văn bản")
60
-
61
- with gr.Tab("📝 Nhập văn bản"):
62
- text_input = gr.Textbox(label="Nhập câu tiếng Việt")
63
- text_output = gr.Textbox(label="Kết quả")
64
- text_btn = gr.Button("Dự đoán cảm xúc")
65
- text_btn.click(fn=predict_emotion_from_text, inputs=text_input, outputs=text_output)
66
-
67
- with gr.Tab("🎙️ Ghi âm giọng nói"):
68
- audio_input = gr.Audio(source="microphone", type="numpy", label="Ghi âm")
69
- audio_output = gr.Textbox(label="Kết quả cảm xúc")
70
- audio_btn = gr.Button("Phân tích giọng nói")
71
- audio_btn.click(fn=analyze_from_audio, inputs=audio_input, outputs=audio_output)
72
-
73
- demo.launch()
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+
5
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
6
+
7
+ MODEL_NAME = "vinai/phobert-base-v2"
8
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3)
9
+ model.load_state_dict(torch.load("best_model_state.bin", map_location=device))
10
+ model.to(device)
11
+ model.eval()
12
+
13
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
14
+ label_map = {0: "An toàn", 1: "Tiêu cực", 2: "Nguy cơ bạo lực"}
15
+
16
+ def predict(text):
17
+ inputs = tokenizer.encode_plus(
18
+ text.lower(), return_tensors="pt", max_length=128, padding="max_length", truncation=True
19
+ )
20
+ input_ids = inputs["input_ids"].to(device)
21
+ attention_mask = inputs["attention_mask"].to(device)
22
+
23
+ with torch.no_grad():
24
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask)
25
+ probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
26
+ pred = torch.argmax(probs, dim=1).item()
27
+
28
+ return f"{label_map[pred]} (độ tin cậy: {probs[0][pred]:.2f})"
29
+
30
+ demo = gr.Interface(fn=predict, inputs="text", outputs="text", title="PhoBERT - Phân tích cảm xúc học sinh")
31
+ demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
- torch
2
- transformers
3
  gradio
 
1
+ torch
2
+ transformers
3
  gradio