loge-dot committed
Commit d88ff3b · 1 Parent(s): 06f70ac
app.py CHANGED
@@ -1,12 +1,21 @@
+import asyncio
+import sys
 import streamlit as st
 import os
-import sys
 from pathlib import Path
+try:
+    asyncio.get_running_loop()
+except RuntimeError:
+    asyncio.set_event_loop(asyncio.new_event_loop())
+
+sys.path.append(str(Path(__file__).parent))
+
+
 
 # Make sure the project modules can be found
 sys.path.append(str(Path(__file__).parent))
 
-from pages import emotion_analyzer, chatbot  # import the emotion analyzer and Chatbot pages
+from pages import emotion_analyzer  # import the emotion analyzer and Chatbot pages
 
 def main():
     st.set_page_config(
@@ -18,13 +27,8 @@ def main():
     st.title("Audio Emotion Recognition System")
     st.write("This is a web application for audio emotion recognition.")
 
-    # Page selection
-    page = st.sidebar.selectbox("Select a page", ["Emotion Analyzer", "Chatbot"])
-
-    if page == "Emotion Analyzer":
-        emotion_analyzer.show()
-    elif page == "Chatbot":
-        chatbot.show_chatbot()
+    # Test only the emotion analyzer page for now
+    emotion_analyzer.show()
 
 if __name__ == "__main__":
     main()
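Note on the new asyncio guard in app.py: Streamlit can run a script in a thread that has no asyncio event loop, and imports that expect one (the torch/Streamlit file-watcher interaction is a common trigger) then raise RuntimeError. The added try/except is the usual guard for that situation; a minimal standalone sketch of the same pattern, assuming it runs before any loop-dependent import:

import asyncio

# If this thread already has a running loop, there is nothing to do;
# otherwise create a fresh loop and register it for the thread.
try:
    asyncio.get_running_loop()
except RuntimeError:
    asyncio.set_event_loop(asyncio.new_event_loop())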
components/__pycache__/__init__.cpython-313.pyc CHANGED
Binary files a/components/__pycache__/__init__.cpython-313.pyc and b/components/__pycache__/__init__.cpython-313.pyc differ
 
components/__pycache__/audio_player.cpython-313.pyc CHANGED
Binary files a/components/__pycache__/audio_player.cpython-313.pyc and b/components/__pycache__/audio_player.cpython-313.pyc differ
 
components/__pycache__/debug_tools.cpython-313.pyc CHANGED
Binary files a/components/__pycache__/debug_tools.cpython-313.pyc and b/components/__pycache__/debug_tools.cpython-313.pyc differ
 
components/__pycache__/visualizations.cpython-313.pyc CHANGED
Binary files a/components/__pycache__/visualizations.cpython-313.pyc and b/components/__pycache__/visualizations.cpython-313.pyc differ
 
pages/__pycache__/__init__.cpython-313.pyc CHANGED
Binary files a/pages/__pycache__/__init__.cpython-313.pyc and b/pages/__pycache__/__init__.cpython-313.pyc differ
 
pages/__pycache__/chatbot.cpython-313.pyc CHANGED
Binary files a/pages/__pycache__/chatbot.cpython-313.pyc and b/pages/__pycache__/chatbot.cpython-313.pyc differ
 
pages/__pycache__/emotion_analyzer.cpython-313.pyc CHANGED
Binary files a/pages/__pycache__/emotion_analyzer.cpython-313.pyc and b/pages/__pycache__/emotion_analyzer.cpython-313.pyc differ
 
pages/chatbot.py CHANGED
@@ -12,7 +12,7 @@ from utils import model_inference
 import os
 
 # Load environment variables
-load_dotenv(".env")
+load_dotenv("Group7/.env")
 api_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
 api_key = os.getenv("AZURE_OPENAI_API_KEY")
 api_version = os.getenv("AZURE_OPENAI_API_VERSION")
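The new load_dotenv("Group7/.env") path is resolved against the process working directory, so it only works when the app is launched from the directory containing Group7/. A sketch of a working-directory-independent alternative, assuming the .env file sits at the project root one level above pages/ (a hypothetical layout, not confirmed by this commit):

from pathlib import Path
from dotenv import load_dotenv

# Resolve .env relative to this file rather than the CWD
# (assumed layout: <project root>/.env next to pages/).
env_path = Path(__file__).resolve().parent.parent / ".env"
load_dotenv(env_path)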
utils/__pycache__/__init__.cpython-313.pyc CHANGED
Binary files a/utils/__pycache__/__init__.cpython-313.pyc and b/utils/__pycache__/__init__.cpython-313.pyc differ
 
utils/__pycache__/audio_processing.cpython-313.pyc CHANGED
Binary files a/utils/__pycache__/audio_processing.cpython-313.pyc and b/utils/__pycache__/audio_processing.cpython-313.pyc differ
 
utils/__pycache__/model_inference.cpython-313.pyc CHANGED
Binary files a/utils/__pycache__/model_inference.cpython-313.pyc and b/utils/__pycache__/model_inference.cpython-313.pyc differ
 
utils/model_inference.py CHANGED
@@ -4,52 +4,104 @@ import os
 from transformers import AutoTokenizer, BertModel, Wav2Vec2Model
 from utils.audio_processing import AudioProcessor
 import torchaudio
+import torch.nn as nn
 import torch.nn.functional as F
 from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file
+from transformers import AutoModelForSequenceClassification, AutoConfig, Wav2Vec2ForPreTraining
 
 # Download the model
-model_path = hf_hub_download(repo_id="liloge/Group7_model_test", filename="model.safetensors")
-
-class MultimodalClassifier(torch.nn.Module):
-    def __init__(self):
-        super(MultimodalClassifier, self).__init__()
-        self.bert = BertModel.from_pretrained("bert-base-uncased")
-        self.wav2vec2 = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
-        self.classifier = torch.nn.Sequential(
-            torch.nn.Linear(self.bert.config.hidden_size + self.wav2vec2.config.hidden_size, 256),
-            torch.nn.ReLU(),
-            torch.nn.Dropout(0.7),
-            torch.nn.Linear(256, 7)  # 7-class task
+# Download from the huggingface_hub repo
+# model_path = hf_hub_download(repo_id="liloge/Group7_model_test", filename="model.safetensors")
+# Load locally
+
+import torch
+import torch.nn as nn
+from transformers import AutoModelForSequenceClassification, AutoConfig, Wav2Vec2ForPreTraining
+
+class MultimodalClassifier(nn.Module):
+    def __init__(self, bert_ckpt_path, wav2vec2_config_path, wav2vec2_safetensors_path):
+        super().__init__()
+
+        # **Load the fine-tuned BERT**
+        self.bert = AutoModelForSequenceClassification.from_pretrained(
+            "bert-base-uncased", num_labels=7
+        )
+        self.bert.classifier = nn.Sequential(
+            nn.Dropout(0.5),
+            nn.Linear(self.bert.config.hidden_size, self.bert.config.num_labels)
+        )
+        try:
+            self.bert.load_state_dict(torch.load(bert_ckpt_path, map_location=torch.device("cpu")), strict=True)
+        except Exception as e:
+            print(f"❌ Failed to load `{bert_ckpt_path}`: {e}")
+
+        # **Load Wav2Vec2 first**
+        config = AutoConfig.from_pretrained(wav2vec2_config_path, num_labels=7)
+        self.wav2vec2 = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base", config=config)
+
+        # **Then replace Wav2Vec2's classification head**
+        self.wav2vec2.classifier = nn.Sequential(
+            nn.Dropout(0.5),
+            nn.Linear(self.wav2vec2.config.hidden_size, self.wav2vec2.config.num_labels)
+        )
+        # **Load the safetensors weights**
+        from safetensors.torch import load_file
+        state_dict = load_file(wav2vec2_safetensors_path)
+        try:
+            self.wav2vec2.load_state_dict(state_dict, strict=False)
+        except Exception as e:
+            print(f"❌ Failed to load `{wav2vec2_safetensors_path}`: {e}")
+
+        # **Classification head over the concatenated features**
+        self.classifier = nn.Sequential(
+            nn.Linear(self.bert.config.hidden_size + self.wav2vec2.config.hidden_size, 256),
+            nn.ReLU(),
+            nn.Dropout(0.7),
+            nn.Linear(256, 7)  # 7-class task
         )
 
     def forward(self, text_input, audio_input):
+        # **Text features**
         text_outputs = self.bert(**text_input, output_hidden_states=True)
-        text_features = text_outputs.hidden_states[-1][:, 0, :]  # [CLS] token
+        text_features = text_outputs.hidden_states[-1][:, 0, :]
+
+        # **Audio features**
         audio_outputs = self.wav2vec2(audio_input, output_hidden_states=True)
         audio_features = audio_outputs.hidden_states[-1][:, 0, :]
+
+        # **Concatenate the features**
         combined_features = torch.cat((text_features, audio_features), dim=-1)
+
+        # **Classify**
         logits = self.classifier(combined_features)
         return logits
 
-# Load the model
+
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-model = MultimodalClassifier().to(device)
 
-# Load the SafeTensors weights
-state_dict = load_file(model_path)
-print("state_dict:", state_dict)
-model.load_state_dict(state_dict)
-model.eval()  # set to evaluation mode
+# **Define the paths**
+bert_ckpt_path = "bert_meld_finetune_model.pth"
+wav2vec2_config_path = "config.json"
+wav2vec2_safetensors_path = "wav2vec2.safetensors"
+
+# **Load the model**
+model = MultimodalClassifier(bert_ckpt_path, wav2vec2_config_path, wav2vec2_safetensors_path).to(device)
+model.eval()
+
+print("✅ Fine-tuned BERT + Wav2Vec2 model loaded successfully!")
 
 tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
 
 def preprocess_text(text):
-    return tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128).to(device)
+    text_inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
+    print(text_inputs)
+    return text_inputs.to(device)
 
 def preprocess_audio(audio_path):
     waveform, sample_rate = torchaudio.load(audio_path)
     waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
+    print(waveform)
     return waveform.to(device)
 
 labels = ["Neutral", "Happy", "Sad", "Angry", "Fearful", "Disgusted", "Surprised"]
@@ -59,7 +111,7 @@ def predict_emotion(text, audio):
     audio_inputs = preprocess_audio(audio)
 
     with torch.no_grad():
-        output = model(audio_inputs,text_inputs)  # (1, 7) logits
+        output = model(text_input=text_inputs, audio_input=audio_inputs)  # (1, 7) logits
         probabilities = F.softmax(output, dim=1).squeeze().tolist()  # normalize to probabilities
 
     return {labels[i]: f"{probabilities[i]*100:.2f}%" for i in range(len(labels))}
@@ -87,4 +139,4 @@ def save_history(audio_file, transcript, emotions, probabilities):
     })
 
     with open(history_file, 'w') as f:
-        json.dump(history, f, indent=4)
+        json.dump(history, f, indent=4)
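For reference, the corrected call in predict_emotion matters because forward is declared as forward(self, text_input, audio_input); the old positional call model(audio_inputs,text_inputs) fed each tensor to the wrong branch. A hedged usage sketch of the updated module, where the transcript string and "sample.wav" are placeholder inputs rather than files from this repo:

from utils import model_inference

# Placeholder inputs: any transcript string plus a path to the matching audio clip.
# preprocess_audio resamples the clip to 16 kHz before inference.
result = model_inference.predict_emotion(
    "I can't believe we actually won!",
    "sample.wav",
)
print(result)  # a dict mapping each of the 7 labels to a percentage string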