loge-dot
committed on
Commit
·
d88ff3b
1
Parent(s):
06f70ac
change1
Browse files- app.py +13 -9
- components/__pycache__/__init__.cpython-313.pyc +0 -0
- components/__pycache__/audio_player.cpython-313.pyc +0 -0
- components/__pycache__/debug_tools.cpython-313.pyc +0 -0
- components/__pycache__/visualizations.cpython-313.pyc +0 -0
- pages/__pycache__/__init__.cpython-313.pyc +0 -0
- pages/__pycache__/chatbot.cpython-313.pyc +0 -0
- pages/__pycache__/emotion_analyzer.cpython-313.pyc +0 -0
- pages/chatbot.py +1 -1
- utils/__pycache__/__init__.cpython-313.pyc +0 -0
- utils/__pycache__/audio_processing.cpython-313.pyc +0 -0
- utils/__pycache__/model_inference.cpython-313.pyc +0 -0
- utils/model_inference.py +75 -23
app.py
CHANGED
@@ -1,12 +1,21 @@
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
import os
|
3 |
-
import sys
|
4 |
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
# 确保能找到项目模块
|
7 |
sys.path.append(str(Path(__file__).parent))
|
8 |
|
9 |
-
from pages import emotion_analyzer
|
10 |
|
11 |
def main():
|
12 |
st.set_page_config(
|
@@ -18,13 +27,8 @@ def main():
|
|
18 |
st.title("Audio Emotion Recognition System")
|
19 |
st.write("This is a web application for audio emotion recognition.")
|
20 |
|
21 |
-
#
|
22 |
-
|
23 |
-
|
24 |
-
if page == "Emotion Analyzer":
|
25 |
-
emotion_analyzer.show()
|
26 |
-
elif page == "Chatbot":
|
27 |
-
chatbot.show_chatbot()
|
28 |
|
29 |
if __name__ == "__main__":
|
30 |
main()
|
|
|
1 |
+
import asyncio
|
2 |
+
import sys
|
3 |
import streamlit as st
|
4 |
import os
|
|
|
5 |
from pathlib import Path
|
6 |
+
# Install a fresh asyncio event loop when none is currently running.
# NOTE(review): presumably a workaround for libraries that expect an event
# loop to exist in the thread executing this script -- confirm it is still
# needed.
try:
    asyncio.get_running_loop()
except RuntimeError:
    asyncio.set_event_loop(asyncio.new_event_loop())

# Make sure the project modules can be found. The directory used to be
# appended twice; register it a single time and only if it is not already
# on the import path.
_PROJECT_ROOT = str(Path(__file__).parent)
if _PROJECT_ROOT not in sys.path:
    sys.path.append(_PROJECT_ROOT)
|
17 |
|
18 |
+
from pages import emotion_analyzer  # import only the emotion analysis page (the Chatbot page is disabled for now)
|
19 |
|
20 |
def main():
|
21 |
st.set_page_config(
|
|
|
27 |
st.title("Audio Emotion Recognition System")
|
28 |
st.write("This is a web application for audio emotion recognition.")
|
29 |
|
30 |
+
# 先只测试情绪分析页面
|
31 |
+
emotion_analyzer.show()
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
if __name__ == "__main__":
|
34 |
main()
|
components/__pycache__/__init__.cpython-313.pyc
CHANGED
Binary files a/components/__pycache__/__init__.cpython-313.pyc and b/components/__pycache__/__init__.cpython-313.pyc differ
|
|
components/__pycache__/audio_player.cpython-313.pyc
CHANGED
Binary files a/components/__pycache__/audio_player.cpython-313.pyc and b/components/__pycache__/audio_player.cpython-313.pyc differ
|
|
components/__pycache__/debug_tools.cpython-313.pyc
CHANGED
Binary files a/components/__pycache__/debug_tools.cpython-313.pyc and b/components/__pycache__/debug_tools.cpython-313.pyc differ
|
|
components/__pycache__/visualizations.cpython-313.pyc
CHANGED
Binary files a/components/__pycache__/visualizations.cpython-313.pyc and b/components/__pycache__/visualizations.cpython-313.pyc differ
|
|
pages/__pycache__/__init__.cpython-313.pyc
CHANGED
Binary files a/pages/__pycache__/__init__.cpython-313.pyc and b/pages/__pycache__/__init__.cpython-313.pyc differ
|
|
pages/__pycache__/chatbot.cpython-313.pyc
CHANGED
Binary files a/pages/__pycache__/chatbot.cpython-313.pyc and b/pages/__pycache__/chatbot.cpython-313.pyc differ
|
|
pages/__pycache__/emotion_analyzer.cpython-313.pyc
CHANGED
Binary files a/pages/__pycache__/emotion_analyzer.cpython-313.pyc and b/pages/__pycache__/emotion_analyzer.cpython-313.pyc differ
|
|
pages/chatbot.py
CHANGED
@@ -12,7 +12,7 @@ from utils import model_inference
|
|
12 |
import os
|
13 |
|
14 |
# 加载环境变量
|
15 |
-
load_dotenv("
|
16 |
api_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
|
17 |
api_key = os.getenv("AZURE_OPENAI_API_KEY")
|
18 |
api_version = os.getenv("AZURE_OPENAI_API_VERSION")
|
|
|
12 |
import os
|
13 |
|
14 |
# 加载环境变量
|
15 |
+
load_dotenv("Group7/.env")
|
16 |
api_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
|
17 |
api_key = os.getenv("AZURE_OPENAI_API_KEY")
|
18 |
api_version = os.getenv("AZURE_OPENAI_API_VERSION")
|
utils/__pycache__/__init__.cpython-313.pyc
CHANGED
Binary files a/utils/__pycache__/__init__.cpython-313.pyc and b/utils/__pycache__/__init__.cpython-313.pyc differ
|
|
utils/__pycache__/audio_processing.cpython-313.pyc
CHANGED
Binary files a/utils/__pycache__/audio_processing.cpython-313.pyc and b/utils/__pycache__/audio_processing.cpython-313.pyc differ
|
|
utils/__pycache__/model_inference.cpython-313.pyc
CHANGED
Binary files a/utils/__pycache__/model_inference.cpython-313.pyc and b/utils/__pycache__/model_inference.cpython-313.pyc differ
|
|
utils/model_inference.py
CHANGED
@@ -4,52 +4,104 @@ import os
|
|
4 |
from transformers import AutoTokenizer, BertModel, Wav2Vec2Model
|
5 |
from utils.audio_processing import AudioProcessor
|
6 |
import torchaudio
|
|
|
7 |
import torch.nn.functional as F
|
8 |
from huggingface_hub import hf_hub_download
|
9 |
from safetensors.torch import load_file
|
|
|
10 |
|
11 |
# 下载模型
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
)
|
25 |
|
26 |
def forward(self, text_input, audio_input):
|
|
|
27 |
text_outputs = self.bert(**text_input, output_hidden_states=True)
|
28 |
-
text_features = text_outputs.hidden_states[-1][:, 0, :]
|
|
|
|
|
29 |
audio_outputs = self.wav2vec2(audio_input, output_hidden_states=True)
|
30 |
audio_features = audio_outputs.hidden_states[-1][:, 0, :]
|
|
|
|
|
31 |
combined_features = torch.cat((text_features, audio_features), dim=-1)
|
|
|
|
|
32 |
logits = self.classifier(combined_features)
|
33 |
return logits
|
34 |
|
35 |
-
|
36 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
37 |
-
model = MultimodalClassifier().to(device)
|
38 |
|
39 |
-
#
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
46 |
|
47 |
def preprocess_text(text):
|
48 |
-
|
|
|
|
|
49 |
|
50 |
def preprocess_audio(audio_path):
|
51 |
waveform, sample_rate = torchaudio.load(audio_path)
|
52 |
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
|
|
|
53 |
return waveform.to(device)
|
54 |
|
55 |
labels = ["Neutral", "Happy", "Sad", "Angry", "Fearful", "Disgusted", "Surprised"]
|
@@ -59,7 +111,7 @@ def predict_emotion(text, audio):
|
|
59 |
audio_inputs = preprocess_audio(audio)
|
60 |
|
61 |
with torch.no_grad():
|
62 |
-
output = model(audio_inputs
|
63 |
probabilities = F.softmax(output, dim=1).squeeze().tolist() # 归一化为概率
|
64 |
|
65 |
return {labels[i]: f"{probabilities[i]*100:.2f}%" for i in range(len(labels))}
|
@@ -87,4 +139,4 @@ def save_history(audio_file, transcript, emotions, probabilities):
|
|
87 |
})
|
88 |
|
89 |
with open(history_file, 'w') as f:
|
90 |
-
json.dump(history, f, indent=4)
|
|
|
4 |
from transformers import AutoTokenizer, BertModel, Wav2Vec2Model
|
5 |
from utils.audio_processing import AudioProcessor
|
6 |
import torchaudio
|
7 |
+
import torch.nn as nn
|
8 |
import torch.nn.functional as F
|
9 |
from huggingface_hub import hf_hub_download
|
10 |
from safetensors.torch import load_file
|
11 |
+
from transformers import AutoModelForSequenceClassification, AutoConfig, Wav2Vec2ForPreTraining
|
12 |
|
13 |
# 下载模型
|
14 |
+
# huggingface_hub 仓库下载
|
15 |
+
# model_path = hf_hub_download(repo_id="liloge/Group7_model_test", filename="model.safetensors")
|
16 |
+
# 本地下载
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
from transformers import AutoModelForSequenceClassification, AutoConfig, Wav2Vec2ForPreTraining
|
21 |
+
|
22 |
+
class MultimodalClassifier(nn.Module):
    """Text + audio emotion classifier built on BERT and Wav2Vec2 encoders.

    The first-token hidden state of BERT's last layer and the first-frame
    hidden state of Wav2Vec2's last layer are concatenated and fed to a small
    MLP head that emits 7 logits.

    NOTE(review): only the two encoders receive checkpoint weights in
    ``__init__``; the fusion head ``self.classifier`` keeps its random
    initialisation unless weights are loaded into it elsewhere -- confirm
    this is intended.

    Args:
        bert_ckpt_path: Path to a ``torch.save``-d state dict for the BERT
            sequence-classification model.
        wav2vec2_config_path: Location accepted by
            ``AutoConfig.from_pretrained`` for the Wav2Vec2 configuration.
        wav2vec2_safetensors_path: Path to a safetensors file holding the
            Wav2Vec2 weights.
    """

    def __init__(self, bert_ckpt_path, wav2vec2_config_path, wav2vec2_safetensors_path):
        super().__init__()

        # Fine-tuned BERT: the head is rebuilt as Dropout + Linear before the
        # strict load, so the checkpoint must use exactly this layout.
        self.bert = AutoModelForSequenceClassification.from_pretrained(
            "bert-base-uncased", num_labels=7
        )
        self.bert.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(self.bert.config.hidden_size, self.bert.config.num_labels),
        )
        try:
            # NOTE(review): torch.load unpickles the file; only point this at
            # trusted checkpoints.
            bert_state = torch.load(bert_ckpt_path, map_location=torch.device("cpu"))
            self.bert.load_state_dict(bert_state, strict=True)
        except Exception as e:
            # Best effort: report the failure and keep the base weights.
            print(f"❌ 加载 `{bert_ckpt_path}` 失败: {e}")

        # Load Wav2Vec2 first, then attach the classification head.
        config = AutoConfig.from_pretrained(wav2vec2_config_path, num_labels=7)
        self.wav2vec2 = Wav2Vec2ForPreTraining.from_pretrained(
            "facebook/wav2vec2-base", config=config
        )
        self.wav2vec2.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(self.wav2vec2.config.hidden_size, self.wav2vec2.config.num_labels),
        )

        # Load the safetensors weights. ``load_file`` comes from the
        # module-level ``safetensors.torch`` import.
        state_dict = load_file(wav2vec2_safetensors_path)
        try:
            load_result = self.wav2vec2.load_state_dict(state_dict, strict=False)
        except Exception as e:
            print(f"❌ 加载 `{wav2vec2_safetensors_path}` 失败: {e}")
        else:
            # A non-strict load silently skips mismatched names; surface them
            # so a wrong checkpoint does not go unnoticed.
            if load_result.missing_keys or load_result.unexpected_keys:
                print(
                    f"⚠️ Non-strict load of `{wav2vec2_safetensors_path}`: "
                    f"missing keys {list(load_result.missing_keys)}, "
                    f"unexpected keys {list(load_result.unexpected_keys)}"
                )

        # Classification head applied to the concatenated text + audio features.
        self.classifier = nn.Sequential(
            nn.Linear(self.bert.config.hidden_size + self.wav2vec2.config.hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.7),
            nn.Linear(256, 7),  # 7-class task
        )

    def forward(self, text_input, audio_input):
        """Return emotion logits for a tokenised text batch and a waveform batch.

        Args:
            text_input: Mapping of tokenizer outputs; unpacked as keyword
                arguments into BERT.
            audio_input: Tensor passed positionally to Wav2Vec2 -- assumed to
                be (batch, samples) with the same batch size as the text;
                TODO confirm against callers.

        Returns:
            The output of the fusion head for the concatenated features.
        """
        # Text features: first token of the last hidden layer.
        text_outputs = self.bert(**text_input, output_hidden_states=True)
        text_features = text_outputs.hidden_states[-1][:, 0, :]

        # Audio features: first frame of the last hidden layer.
        audio_outputs = self.wav2vec2(audio_input, output_hidden_states=True)
        audio_features = audio_outputs.hidden_states[-1][:, 0, :]

        # Concatenate along the feature axis and classify.
        combined_features = torch.cat((text_features, audio_features), dim=-1)
        return self.classifier(combined_features)
|
79 |
|
80 |
+
|
81 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
82 |
|
83 |
+
# **定义路径**
|
84 |
+
bert_ckpt_path = "bert_meld_finetune_model.pth"
|
85 |
+
wav2vec2_config_path = "config.json"
|
86 |
+
wav2vec2_safetensors_path = "wav2vec2.safetensors"
|
87 |
+
|
88 |
+
# **加载模型**
|
89 |
+
model = MultimodalClassifier(bert_ckpt_path, wav2vec2_config_path, wav2vec2_safetensors_path).to(device)
|
90 |
+
model.eval()
|
91 |
+
|
92 |
+
print("✅ 微调的 BERT + Wav2Vec2 模型加载成功!")
|
93 |
|
94 |
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
95 |
|
96 |
def preprocess_text(text):
    """Tokenise *text* for BERT and move the result to the inference device.

    Input is truncated to at most 128 tokens and returned as PyTorch tensors.

    Args:
        text: The string (or batch of strings) handed to the module-level
            ``tokenizer``.

    Returns:
        The tokenizer's encoding, moved to the module-level ``device``.
    """
    text_inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    )
    # The leftover debug dump of the token tensors was removed: it wrote the
    # full encoding to stdout on every call.
    return text_inputs.to(device)
|
100 |
|
101 |
def preprocess_audio(audio_path):
    """Load an audio file as a mono 16 kHz waveform on the inference device.

    Args:
        audio_path: Location of the audio file, passed to ``torchaudio.load``.

    Returns:
        The waveform tensor moved to the module-level ``device``; for the
        usual (channels, samples) layout the result is (1, samples).
    """
    target_rate = 16000
    waveform, sample_rate = torchaudio.load(audio_path)

    # The loaded tensor is fed to the model unchanged, and the model treats
    # its first dimension as the batch. A multi-channel file would therefore
    # produce one audio row per channel, which cannot be concatenated with
    # the single-row text features. Average the channels into one.
    if waveform.dim() == 2 and waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Only build a resampler when the rate actually differs.
    if sample_rate != target_rate:
        resample = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_rate)
        waveform = resample(waveform)

    return waveform.to(device)
|
106 |
|
107 |
labels = ["Neutral", "Happy", "Sad", "Angry", "Fearful", "Disgusted", "Surprised"]
|
|
|
111 |
audio_inputs = preprocess_audio(audio)
|
112 |
|
113 |
with torch.no_grad():
|
114 |
+
output = model(text_input=text_inputs, audio_input=audio_inputs) # (1, 7) logits
|
115 |
probabilities = F.softmax(output, dim=1).squeeze().tolist() # 归一化为概率
|
116 |
|
117 |
return {labels[i]: f"{probabilities[i]*100:.2f}%" for i in range(len(labels))}
|
|
|
139 |
})
|
140 |
|
141 |
with open(history_file, 'w') as f:
|
142 |
+
json.dump(history, f, indent=4)
|