# Group7 / chatbot.py
# loge-dot
# last_version_upload to huggingface
# 7a66365
import json
import os
import time

import streamlit as st
import torch
import torchaudio
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download
from openai import AzureOpenAI
from openai import NotFoundError
from openai.types.beta.threads import Message
from safetensors.torch import load_file
from transformers import AutoTokenizer, Wav2Vec2Processor, BertModel, Wav2Vec2Model

from utils import model_inference
# Load Azure OpenAI settings from a .env file. The .env beside this script is
# tried first so the app works from any working directory; the original
# repo-root-relative path is kept as a fallback. load_dotenv does not override
# variables that are already set, so real environment variables (and the first
# .env file found) take precedence.
load_dotenv(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".env"))
load_dotenv(r"Group7/.env")

api_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
api_key = os.getenv("AZURE_OPENAI_API_KEY")
api_version = os.getenv("AZURE_OPENAI_API_VERSION")
# Deployment name; used as the `model` when a new assistant has to be created.
api_deployment_name = os.getenv("AZURE_OPENAI_DEPLOYMENT")

# Module-level Azure OpenAI client shared by the rest of this script.
client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_endpoint)
# System prompt describing the chatbot's role. It is only passed when a new
# assistant is created; the retrieve-existing-assistant path does not send it.
instruction = " ".join(
    [
        "You are a psychiatrist talking to a patient who may be depressed.",
        "You'll receive their emotional state and conversation text.",
        "Your goal is to help them open up and guide them to a positive path.",
        "Be friendly, professional, empathetic, and supportive.",
    ]
)
# One Assistants API thread per Streamlit session; storing it in
# st.session_state keeps the same thread across script reruns.
if "thread" not in st.session_state:
    st.session_state.thread = client.beta.threads.create()

# Reuse the pre-configured assistant when it exists; otherwise create one from
# the local `instruction`. The ID can be overridden via the environment.
if "assistant" not in st.session_state:
    assistant_id = os.getenv("AZURE_OPENAI_ASSISTANT_ID", "asst_Sb1W9jVTeL1iyzu6N5MilgA1")
    try:
        st.session_state.assistant = client.beta.assistants.retrieve(assistant_id=assistant_id)
    except NotFoundError:
        # Only a missing assistant triggers creation; auth, network and other
        # API errors propagate instead of being silently masked.
        # NOTE(review): every new session that hits this path creates another
        # assistant; consider persisting the created ID.
        st.session_state.assistant = client.beta.assistants.create(
            name="Depression Chatbot",
            instructions=instruction,
            model=api_deployment_name,
        )
# Run statuses meaning the assistant is still working; any other status is
# final (e.g. completed, failed, cancelled, expired, requires_action).
_PENDING_RUN_STATUSES = frozenset({"queued", "in_progress", "cancelling"})


def _message_text(message: "Message") -> str:
    """Join the text content blocks of an Assistants API message."""
    return "\n".join(
        block.text.value
        for block in message.content
        if getattr(block, "type", None) == "text"
    )


def _wait_for_run(thread_id, run, *, poll_interval, timeout):
    """Poll ``run`` until it leaves a pending status or ``timeout`` seconds pass."""
    deadline = time.monotonic() + timeout
    while run.status in _PENDING_RUN_STATUSES and time.monotonic() < deadline:
        # Sleep between polls rather than hammering the API in a tight loop.
        time.sleep(poll_interval)
        # runs.retrieve needs the owning thread's id in addition to the run id.
        run = client.beta.threads.runs.retrieve(run.id, thread_id=thread_id)
    return run


def send_message_to_chatbot(user_input, emotion, *, poll_interval=0.5, timeout=120.0):
    """Post a user turn to this session's thread and return the assistant's reply.

    The message is sent as ``"Emotion: {emotion}. {user_input}"`` so the
    assistant sees the detected emotional state alongside the text.

    Args:
        user_input: Text typed by the user (may be empty for audio-only turns).
        emotion: Dominant emotion label to prefix the message with.
        poll_interval: Seconds to sleep between run-status checks.
        timeout: Maximum seconds to wait for the run to finish.

    Returns:
        The text of the assistant message created by this run, or a
        "No response..." string if the run did not complete or produced no text.
    """
    thread_id = st.session_state.thread.id
    client.beta.threads.messages.create(
        thread_id=thread_id,
        role="user",
        content=f"Emotion: {emotion}. {user_input}",
    )
    run = client.beta.threads.runs.create(
        thread_id=thread_id,
        assistant_id=st.session_state.assistant.id,
    )
    run = _wait_for_run(thread_id, run, poll_interval=poll_interval, timeout=timeout)
    if run.status != "completed":
        # NOTE(review): a run still active after the timeout may block further
        # messages on this thread; consider cancelling it here.
        return f"No response (run status: {run.status})."

    # Newest messages first; pick the assistant reply created by this run so a
    # reply from an earlier turn is never returned by mistake.
    page = client.beta.threads.messages.list(thread_id=thread_id, order="desc")
    for message in page.data:
        if message.role == "assistant" and message.run_id == run.id:
            text = _message_text(message)
            if text:
                return text
    return "No response."
# ---- Streamlit UI ----
st.title("🧠 AI Depression Chatbot")

# Transcript for this session; kept in session_state so it survives reruns.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# User input: typed text and/or an uploaded audio clip.
user_input = st.text_input("Enter your message:")
audio_file = st.file_uploader("Upload audio file", type=["wav", "mp3"])

if st.button("Send"):
    if user_input or audio_file:
        # assumes predict_emotion returns a mapping of label -> score — TODO confirm
        emotion_probabilities = model_inference.predict_emotion(user_input, audio_file)
        # `default` avoids a ValueError if no emotion scores come back.
        dominant_emotion = max(
            emotion_probabilities, key=emotion_probabilities.get, default="unknown"
        )
        chatbot_response = send_message_to_chatbot(user_input, dominant_emotion)

        # Save the exchange; audio-only turns get a placeholder instead of "".
        st.session_state.chat_history.append(
            {"role": "user", "content": user_input or "(audio message)"}
        )
        st.session_state.chat_history.append({"role": "assistant", "content": chatbot_response})
    else:
        st.warning("Please enter a message or upload an audio file.")

# Render the transcript on every run, not only right after "Send". Otherwise
# any other widget interaction triggers a rerun that hides the history.
for chat in st.session_state.chat_history:
    st.write(f"**{chat['role'].capitalize()}**: {chat['content']}")