Upload 26 files
Browse files- .env +4 -0
- .streamlit/config.toml +5 -0
- app.py +30 -0
- components/__init__.py +1 -0
- components/__pycache__/__init__.cpython-313.pyc +0 -0
- components/__pycache__/audio_player.cpython-313.pyc +0 -0
- components/__pycache__/debug_tools.cpython-313.pyc +0 -0
- components/__pycache__/visualizations.cpython-313.pyc +0 -0
- components/audio_player.py +10 -0
- components/debug_tools.py +56 -0
- components/progress_bar.py +7 -0
- components/visualizations.py +35 -0
- pages/__init__.py +1 -0
- pages/__pycache__/__init__.cpython-313.pyc +0 -0
- pages/__pycache__/chatbot.cpython-313.pyc +0 -0
- pages/__pycache__/emotion_analyzer.cpython-313.pyc +0 -0
- pages/chatbot.py +97 -0
- pages/emotion_analyzer.py +80 -0
- requirements.txt +14 -0
- utils/__init__.py +1 -0
- utils/__pycache__/__init__.cpython-313.pyc +0 -0
- utils/__pycache__/audio_processing.cpython-313.pyc +0 -0
- utils/__pycache__/model_inference.cpython-313.pyc +0 -0
- utils/audio_processing.py +31 -0
- utils/logger.py +37 -0
- utils/model_inference.py +89 -0
.env
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
AZURE_OPENAI_ENDPOINT=https://test111222333.openai.azure.com/
|
2 |
+
AZURE_OPENAI_API_KEY=<your-azure-openai-api-key>
|
3 |
+
AZURE_OPENAI_API_VERSION=2024-05-01-preview
|
4 |
+
AZURE_OPENAI_DEPLOYMENT=gpt-4o
|
.streamlit/config.toml
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[server]
|
2 |
+
fileWatcherType = "none"
|
3 |
+
|
4 |
+
[logger]
|
5 |
+
level = "info"
|
app.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
# 确保能找到项目模块
|
7 |
+
sys.path.append(str(Path(__file__).parent))
|
8 |
+
|
9 |
+
from pages import emotion_analyzer, chatbot # 导入情绪分析页面和 Chatbot 页面
|
10 |
+
|
11 |
+
def main():
    """Configure the Streamlit page and render the page picked in the sidebar.

    The sidebar offers "Emotion Analyzer" and "Chatbot"; the matching page
    module's entry function is called for the selected option.
    """
    app_title = "Audio Emotion Recognition System"

    st.set_page_config(page_title=app_title, page_icon="🎵", layout="wide")
    st.title(app_title)
    st.write("This is a web application for audio emotion recognition.")

    # Page selection lives in the sidebar.
    page_names = ["Emotion Analyzer", "Chatbot"]
    selected = st.sidebar.selectbox("Select a page", page_names)

    if selected == "Chatbot":
        chatbot.show_chatbot()
    elif selected == "Emotion Analyzer":
        emotion_analyzer.show()


if __name__ == "__main__":
    main()
|
components/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# 空文件,使components成为一个Python包
|
components/__pycache__/__init__.cpython-313.pyc
ADDED
Binary file (190 Bytes). View file
|
|
components/__pycache__/audio_player.cpython-313.pyc
ADDED
Binary file (480 Bytes). View file
|
|
components/__pycache__/debug_tools.cpython-313.pyc
ADDED
Binary file (4.19 kB). View file
|
|
components/__pycache__/visualizations.cpython-313.pyc
ADDED
Binary file (1.03 kB). View file
|
|
components/audio_player.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
def play_audio(audio_file):
    """Render a Streamlit audio player for the given audio.

    Args:
        audio_file: The uploaded audio file, passed straight to ``st.audio``.
    """
    player_source = audio_file
    st.audio(player_source)
|
components/debug_tools.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import sys
|
3 |
+
import os
|
4 |
+
import platform
|
5 |
+
from datetime import datetime
|
6 |
+
|
7 |
+
class DebugTools:
    """Streamlit helpers for showing diagnostics, errors and upload details."""

    @staticmethod
    def show_debug_info():
        """Render a collapsed expander with system, memory and GPU details."""

        def as_megabytes(byte_count):
            # Same two-step division the readouts have always used.
            return byte_count / 1024 / 1024

        with st.expander("Debug Information", expanded=False):
            # System information
            st.subheader("System Information")
            st.text(f"System: {platform.system()} {platform.version()}")
            st.text(f"Python Version: {sys.version}")

            # Memory usage of the current process (psutil is optional)
            try:
                import psutil
                current_process = psutil.Process(os.getpid())
                resident_mb = as_megabytes(current_process.memory_info().rss)
                st.text(f"Memory Usage: {resident_mb:.2f} MB")
            except ImportError:
                st.text("Memory Usage: Unable to get (requires psutil)")

            # GPU information; any failure collapses to one generic line
            try:
                import torch
                if not torch.cuda.is_available():
                    st.text("GPU: Not Available")
                else:
                    st.text(f"GPU: {torch.cuda.get_device_name(0)}")
                    allocated_mb = as_megabytes(torch.cuda.memory_allocated(0))
                    reserved_mb = as_megabytes(torch.cuda.memory_reserved(0))
                    st.text(f"GPU Memory: {allocated_mb:.2f}MB / {reserved_mb:.2f}MB")
            except Exception:
                st.text("GPU Information Retrieval Failed")

    @staticmethod
    def log_error(error, context=None):
        """Show a timestamped error in the UI and echo it to stderr.

        Args:
            error: The exception (or any object) to report; rendered with ``str``.
            context: Optional extra detail appended on its own line when truthy.
        """
        stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        message_lines = [f"[{stamp}] Error: {str(error)}"]
        if context:
            message_lines.append(f"Context: {context}")
        error_msg = "\n".join(message_lines)

        st.error(error_msg)
        # A log-file sink could be added here; for now stderr is the only record.
        print(error_msg, file=sys.stderr)

    @staticmethod
    def show_audio_info(audio_file):
        """Display name, size (KB) and type of an uploaded file; no-op for None."""
        if audio_file is None:
            return

        size_kb = audio_file.size / 1024
        st.write("Audio File Information:")
        st.text(f"File Name: {audio_file.name}")
        st.text(f"File Size: {size_kb:.2f} KB")
        st.text(f"File Type: {audio_file.type}")
|
components/progress_bar.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
def show_progress():
    """Create a Streamlit progress bar starting at 0 and return it."""
    progress_widget = st.progress(0)
    return progress_widget
|
components/visualizations.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import plotly.graph_objects as go
|
3 |
+
import plotly.express as px
|
4 |
+
|
5 |
+
def plot_emotion_distribution(emotion_dict):
    """Draw the emotion probabilities as a radar (polar) chart.

    Args:
        emotion_dict: Mapping of emotion label to a percentage string such as
            ``"12.34%"``; each value is stripped of ``%`` and divided by 100.
    """
    emotions = list(emotion_dict.keys())

    # Convert each percentage string into a 0..1 fraction.
    probabilities = []
    for label in emotions:
        percent_text = emotion_dict[label].strip('%')
        probabilities.append(float(percent_text) / 100)

    # Repeat the first point at the end so the polygon closes.
    radar_trace = go.Scatterpolar(
        r=probabilities + [probabilities[0]],
        theta=emotions + [emotions[0]],
        fill='toself',
        name='Emotion Distribution',
    )

    fig = go.Figure()
    fig.add_trace(radar_trace)
    fig.update_layout(
        title="Emotion Distribution",
        # Fix the radial axis to the full probability range.
        polar={"radialaxis": {"visible": True, "range": [0, 1]}},
        showlegend=False,
    )

    st.plotly_chart(fig, use_container_width=True)
|
pages/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# 空文件,使pages成为一个Python包
|
pages/__pycache__/__init__.cpython-313.pyc
ADDED
Binary file (185 Bytes). View file
|
|
pages/__pycache__/chatbot.cpython-313.pyc
ADDED
Binary file (5.55 kB). View file
|
|
pages/__pycache__/emotion_analyzer.cpython-313.pyc
ADDED
Binary file (2.42 kB). View file
|
|
pages/chatbot.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import torch
|
3 |
+
import torchaudio
|
4 |
+
import json
|
5 |
+
from openai import AzureOpenAI
|
6 |
+
from openai.types.beta.threads import Message
|
7 |
+
from safetensors.torch import load_file
|
8 |
+
from transformers import AutoTokenizer, Wav2Vec2Processor, BertModel, Wav2Vec2Model
|
9 |
+
from huggingface_hub import hf_hub_download
|
10 |
+
from dotenv import load_dotenv
|
11 |
+
from utils import model_inference
|
12 |
+
import os
|
13 |
+
|
14 |
+
# 加载环境变量
|
15 |
+
# Load the Azure OpenAI settings from the local .env file.
load_dotenv(".env")
api_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
api_key = os.getenv("AZURE_OPENAI_API_KEY")
api_version = os.getenv("AZURE_OPENAI_API_VERSION")
api_deployment_name = os.getenv("AZURE_OPENAI_DEPLOYMENT")

# Initialise the Azure OpenAI client.
client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_endpoint)

# System instructions that define the chatbot's role.
instruction = (
    "You are a psychiatrist talking to a patient who may be depressed. "
    "You'll receive their emotional state and conversation text. "
    "Your goal is to help them open up and guide them to a positive path. "
    "Be friendly, professional, empathetic, and supportive."
)

# One conversation thread per Streamlit session.
if "thread" not in st.session_state:
    st.session_state.thread = client.beta.threads.create()

# Reuse the known assistant when it can be retrieved; otherwise create one.
if "assistant" not in st.session_state:
    assistant_id = "asst_Sb1W9jVTeL1iyzu6N5MilgA1"
    try:
        st.session_state.assistant = client.beta.assistants.retrieve(assistant_id=assistant_id)
    except Exception:
        # `except Exception` (not a bare `except`) so KeyboardInterrupt and
        # SystemExit still propagate instead of triggering assistant creation.
        st.session_state.assistant = client.beta.assistants.create(
            name="Depression Chatbot",
            instructions=instruction,
            model=api_deployment_name,
        )
|
46 |
+
|
47 |
+
# 发送消息到 Azure Chatbot
|
48 |
+
def send_message_to_chatbot(user_input, emotion):
    """Post the user's message to the session thread and return the reply text.

    The message is sent as ``"Emotion: <emotion>. <user_input>"``, a run is
    started with the session's assistant, and the run is polled until it leaves
    the ``queued`` / ``in_progress`` states.

    Args:
        user_input: Text typed by the user.
        emotion: Label of the detected emotion to prepend to the message.

    Returns:
        The text of the newest assistant message, or ``"No response."`` when
        the run did not complete or produced no text.
    """
    import time  # local import so this block does not depend on file-level imports

    thread_id = st.session_state.thread.id

    client.beta.threads.messages.create(
        thread_id=thread_id,
        role="user",
        content=f"Emotion: {emotion}. {user_input}",
    )

    run = client.beta.threads.runs.create(
        thread_id=thread_id,
        assistant_id=st.session_state.assistant.id,
    )

    # Poll with a short pause; retrieve() needs the thread id as well as the run id.
    while run.status in ("queued", "in_progress"):
        time.sleep(0.5)
        run = client.beta.threads.runs.retrieve(thread_id=thread_id, run_id=run.id)

    if run.status != "completed":
        return "No response."

    # Ask for newest-first ordering explicitly and take the first assistant message.
    response_messages = client.beta.threads.messages.list(thread_id=thread_id, order="desc")
    for message in response_messages.data:
        if message.role != "assistant":
            continue
        text_parts = [
            part.text.value
            for part in message.content
            if getattr(part, "type", None) == "text"
        ]
        if text_parts:
            return "\n".join(text_parts)

    return "No response."
|
70 |
+
|
71 |
+
# Streamlit 界面
|
72 |
+
def show_chatbot():
    """Render the chatbot page: inputs, emotion detection, reply and history.

    Wrapped in a function because the main app calls ``chatbot.show_chatbot()``
    and because drawing widgets at import time would run before the app's
    ``st.set_page_config`` call.
    """
    st.title("🧠 AI Depression Chatbot")

    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    # User inputs
    user_input = st.text_input("Enter your message:")
    audio_file = st.file_uploader("Upload audio file", type=["wav", "mp3"])

    if not st.button("Send"):
        return

    # predict_emotion is called with both text and audio, so require both here.
    if not (user_input and audio_file):
        st.warning("Please enter a message and upload an audio file.")
        return

    emotion_probabilities = model_inference.predict_emotion(user_input, audio_file)
    # Values are percentage strings like "12.34%"; compare them numerically,
    # not lexicographically (otherwise "9.50%" would beat "85.00%").
    dominant_emotion = max(
        emotion_probabilities,
        key=lambda label: float(str(emotion_probabilities[label]).rstrip("%")),
    )

    chatbot_response = send_message_to_chatbot(user_input, dominant_emotion)

    # Save the exchange in the session's chat history.
    st.session_state.chat_history.append({"role": "user", "content": user_input})
    st.session_state.chat_history.append({"role": "assistant", "content": chatbot_response})

    # Show the whole conversation so far.
    for chat in st.session_state.chat_history:
        st.write(f"**{chat['role'].capitalize()}**: {chat['content']}")


# NOTE(review): Streamlit is expected to execute page scripts as "__main__", so
# this keeps the page working when opened directly from the pages/ directory
# while staying silent on `from pages import chatbot` — confirm on the target
# Streamlit version.
if __name__ == "__main__":
    show_chatbot()
|
pages/emotion_analyzer.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from components.visualizations import plot_emotion_distribution
|
3 |
+
from utils import model_inference
|
4 |
+
from components.audio_player import play_audio
|
5 |
+
from components.debug_tools import DebugTools
|
6 |
+
import json
|
7 |
+
import os
|
8 |
+
|
9 |
+
def show_history():
    """Display the saved analysis records from ``history.json``, if any."""
    history_file = "history.json"

    if not os.path.exists(history_file):
        st.write("No history file.")
        return

    with open(history_file, 'r') as f:
        history = json.load(f)

    if not history:
        st.write("No history records.")
        return

    # (caption, record key) pairs, shown in this order for every record.
    fields = (
        ("Audio File", 'audio_file'),
        ("Transcript", 'transcript'),
        ("Emotions", 'emotions'),
        ("Probabilities", 'probabilities'),
    )

    st.subheader("History")
    for record in history:
        for caption, key in fields:
            st.write(f"{caption}: {record[key]}")
        st.write("---")
|
29 |
+
|
30 |
+
def show():
    """Render the Emotion Analyzer page.

    Shows saved history and debug information, then — once both an audio file
    and text are provided and the button is pressed — runs emotion prediction,
    displays the result and its radar chart, and appends the run to history.
    Any exception during analysis is reported through ``DebugTools.log_error``.
    """
    st.header("Emotion Analyzer")

    # Previously saved analyses
    show_history()

    # Debug helpers and system information
    debug = DebugTools()
    debug.show_debug_info()

    # Inputs
    audio_file = st.file_uploader("Upload audio file", type=['wav', 'mp3'])
    text_input = st.text_input("Enter text input")

    if audio_file is None or not text_input:
        return

    debug.show_audio_info(audio_file)
    play_audio(audio_file)

    if not st.button("Analyse Your Emotion!😊"):
        return

    progress_bar = st.progress(0)

    try:
        # The user's text is used directly as the transcript.
        transcript = text_input
        st.write("Audio transcript:", transcript)

        with st.spinner("Analysing emotion..."):
            progress_bar.progress(30)
            # NOTE(review): rewind defensively in case the player consumed the
            # upload buffer — presumably a no-op on current Streamlit; confirm.
            audio_file.seek(0)
            emotions = model_inference.predict_emotion(text_input, audio_file)

        # Advance past the inference step (previously stuck at 30).
        progress_bar.progress(70)

        st.success(f"Predict: {emotions}")
        plot_emotion_distribution(emotions)

        # The per-class values already live in `emotions`; nothing separate to store.
        model_inference.save_history(audio_file, transcript, emotions, None)

        progress_bar.progress(100)

    except Exception as e:
        debug.log_error(e, context=f"Processing file: {audio_file.name}")
|
requirements.txt
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit>=1.0.0
|
2 |
+
plotly>=5.0.0
|
3 |
+
librosa>=0.9.0
|
4 |
+
scipy>=1.7.0
|
5 |
+
numpy>=1.21.0
|
6 |
+
torch>=1.9.0
|
7 |
+
transformers>=4.11.0
|
8 |
+
soundfile>=0.10.3
|
9 |
+
psutil>=5.8.0
|
10 |
+
huggingface-hub>=0.0.12
|
11 |
+
safetensors>=0.0.3
|
12 |
+
torchaudio>=0.9.0
|
13 |
+
openai
|
14 |
+
python-dotenv
|
utils/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# 空文件,使utils成为一个Python包
|
utils/__pycache__/__init__.cpython-313.pyc
ADDED
Binary file (185 Bytes). View file
|
|
utils/__pycache__/audio_processing.cpython-313.pyc
ADDED
Binary file (1.69 kB). View file
|
|
utils/__pycache__/model_inference.cpython-313.pyc
ADDED
Binary file (6.21 kB). View file
|
|
utils/audio_processing.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import librosa
|
2 |
+
import numpy as np
|
3 |
+
from scipy.signal import butter, filtfilt
|
4 |
+
|
5 |
+
class AudioProcessor:
    """Static helpers for loading and conditioning audio with librosa."""

    @staticmethod
    def load_and_process_audio(file_path, target_sr=16000):
        """Load audio at ``target_sr`` and return it after librosa normalisation."""
        samples, _loaded_sr = librosa.load(file_path, sr=target_sr)
        return librosa.util.normalize(samples)

    @staticmethod
    def resample(audio_data, orig_sr, target_sr=16000):
        """Resample ``audio_data`` from ``orig_sr`` to ``target_sr``."""
        resampled = librosa.resample(audio_data, orig_sr=orig_sr, target_sr=target_sr)
        return resampled

    @staticmethod
    def denoise(audio_data, sr):
        """Placeholder for noise reduction; currently returns the input unchanged."""
        # No denoising is implemented yet.
        return audio_data

    @staticmethod
    def normalize(audio_data):
        """Return ``audio_data`` normalised via ``librosa.util.normalize``."""
        normalized = librosa.util.normalize(audio_data)
        return normalized
|
utils/logger.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
from datetime import datetime
|
4 |
+
|
5 |
+
class Logger:
    """Thin wrapper around the ``emotion_recognition`` logger.

    On construction it makes sure a ``logs`` directory exists and, if the root
    logger has no handlers yet, configures it to write to a timestamped file in
    that directory and to the console.
    """

    def __init__(self):
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        log_dir = "logs"
        os.makedirs(log_dir, exist_ok=True)

        # basicConfig does nothing once the root logger has handlers, so only
        # build the handlers when they will be installed. Otherwise every
        # Logger() would open (and leak) a new, unused log file.
        if not logging.getLogger().handlers:
            log_file = os.path.join(
                log_dir,
                f"app_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
            )
            logging.basicConfig(
                level=logging.INFO,
                format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                handlers=[
                    logging.FileHandler(log_file, encoding="utf-8"),
                    logging.StreamHandler()
                ]
            )

        self.logger = logging.getLogger('emotion_recognition')

    def info(self, message):
        """Log ``message`` at INFO level."""
        self.logger.info(message)

    def error(self, message, exc_info=True):
        """Log ``message`` at ERROR level, with traceback info unless disabled."""
        self.logger.error(message, exc_info=exc_info)

    def debug(self, message):
        """Log ``message`` at DEBUG level."""
        self.logger.debug(message)
|
utils/model_inference.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
from transformers import AutoTokenizer, BertModel, Wav2Vec2Model
|
5 |
+
from utils.audio_processing import AudioProcessor
|
6 |
+
import torchaudio
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from huggingface_hub import hf_hub_download
|
9 |
+
from safetensors.torch import load_file
|
10 |
+
|
11 |
+
# 下载模型
|
12 |
+
# Fetch the model weights file from the Hugging Face Hub when this module is
# imported; the result is passed to load_file() further down.
model_path = hf_hub_download(repo_id="liloge/Group7_model_test", filename="model.safetensors")
|
13 |
+
|
14 |
+
class MultimodalClassifier(torch.nn.Module):
    """Text + audio classifier: BERT and Wav2Vec2 features feed a small MLP head.

    The head maps the concatenated BERT and Wav2Vec2 hidden sizes to 256 units
    and then to 7 output classes.
    """

    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        self.wav2vec2 = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")

        fused_width = self.bert.config.hidden_size + self.wav2vec2.config.hidden_size
        head_layers = [
            torch.nn.Linear(fused_width, 256),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.7),
            torch.nn.Linear(256, 7),  # 7-way classification
        ]
        self.classifier = torch.nn.Sequential(*head_layers)

    def forward(self, text_input, audio_input):
        """Return class logits for a tokenised text batch and an audio batch.

        Args:
            text_input: Mapping of tokenizer outputs, unpacked into BERT.
            audio_input: Waveform tensor passed to Wav2Vec2
                (assumes shape (batch, samples) — TODO confirm).
        """
        bert_out = self.bert(**text_input, output_hidden_states=True)
        # Position 0 of the last hidden layer, i.e. the [CLS] token.
        cls_vector = bert_out.hidden_states[-1][:, 0, :]

        w2v_out = self.wav2vec2(audio_input, output_hidden_states=True)
        # Position 0 of the last hidden layer of the audio encoder.
        first_frame = w2v_out.hidden_states[-1][:, 0, :]

        fused = torch.cat((cls_vector, first_frame), dim=-1)
        return self.classifier(fused)
|
34 |
+
|
35 |
+
# 加载模型
|
36 |
+
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = MultimodalClassifier()
model = model.to(device)

# Load the SafeTensors weights into the model and switch to evaluation mode.
state_dict = load_file(model_path)
model.load_state_dict(state_dict)
model.eval()

# Tokenizer matching the BERT checkpoint used by the model.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
45 |
+
|
46 |
+
def preprocess_text(text):
    """Tokenise ``text`` (padded, truncated to 128 tokens) and move it to ``device``."""
    encoded = tokenizer(
        text,
        max_length=128,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    return encoded.to(device)
|
48 |
+
|
49 |
+
def preprocess_audio(audio_path):
    """Load audio, mix it down to mono, resample to 16 kHz and move it to ``device``.

    Args:
        audio_path: Path or file-like object accepted by ``torchaudio.load``.

    Returns:
        Waveform tensor with a single leading channel row, on ``device``.
    """
    target_rate = 16000
    waveform, sample_rate = torchaudio.load(audio_path)

    # The leading dimension of the loaded waveform is the channel count. The
    # model treats that dimension as the batch, so a stereo file would become a
    # batch of two and no longer match the single text input.
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    if sample_rate != target_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_rate)
        waveform = resampler(waveform)

    return waveform.to(device)
|
53 |
+
|
54 |
+
# Emotion names by class index: predict_emotion pairs labels[i] with the i-th
# probability. Presumably this order must match the training labels — TODO confirm.
labels = ["Neutral", "Happy", "Sad", "Angry", "Fearful", "Disgusted", "Surprised"]
|
55 |
+
|
56 |
+
def predict_emotion(text, audio):
    """Predict emotion probabilities for a text / audio pair.

    Args:
        text: Input text, tokenised by ``preprocess_text``.
        audio: Audio source, loaded by ``preprocess_audio``.

    Returns:
        Dict mapping each entry of ``labels`` to a percentage string
        formatted with two decimals, e.g. ``"12.34%"``.
    """
    text_inputs = preprocess_text(text)
    audio_inputs = preprocess_audio(audio)

    with torch.no_grad():
        logits = model(text_inputs, audio_inputs)
        # Softmax over the class dimension turns logits into probabilities.
        probabilities = F.softmax(logits, dim=1).squeeze().tolist()

    result = {}
    for idx, label in enumerate(labels):
        result[label] = f"{probabilities[idx]*100:.2f}%"
    return result
|
65 |
+
|
66 |
+
def generate_transcript(audio_file):
    """Return the stand-in transcript for an audio file: its ``name`` attribute."""
    # No speech-to-text is performed; the file name is used as the transcript.
    file_name = audio_file.name
    return file_name
|
69 |
+
|
70 |
+
def save_history(audio_file, transcript, emotions, probabilities):
    """Append one analysis record to ``history.json`` in the working directory.

    Args:
        audio_file: Object whose ``name`` attribute is stored as the file name.
        transcript: Transcript text for the record.
        emotions: Emotion result to store (must be JSON-serialisable).
        probabilities: Extra probability data to store, or ``None``.

    A missing, empty or unparsable history file, or one whose top-level value
    is not a list, is treated as an empty history so that saving keeps working.
    """
    history_file = "history.json"

    try:
        with open(history_file, 'r') as f:
            history = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        # Nothing saved yet, or the file was left empty or corrupt.
        history = []

    if not isinstance(history, list):
        history = []

    history.append({
        "audio_file": audio_file.name,
        "transcript": transcript,
        "emotions": emotions,
        "probabilities": probabilities
    })

    with open(history_file, 'w') as f:
        json.dump(history, f, indent=4)
|