# Group7/utils/model_inference.py
# Author: loge-dot — "last_version_upload to huggingface" (commit 7a66365)
import torch
import json
import os
from transformers import AutoTokenizer, BertModel, Wav2Vec2Model
from utils.audio_processing import AudioProcessor
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import AutoModelForSequenceClassification, AutoConfig, Wav2Vec2ForPreTraining
# Model download
# Option A: download from a huggingface_hub repository:
# model_path = hf_hub_download(repo_id="liloge/Group7_model_test", filename="model.safetensors")
# Option B: load from local files (the path used below)
import torch
import torch.nn as nn
from transformers import AutoModelForSequenceClassification, AutoConfig, Wav2Vec2ForPreTraining
class MultimodalClassifier(nn.Module):
    """Text + audio emotion classifier.

    Concatenates BERT text features and Wav2Vec2 audio features and feeds
    them through a small fusion head producing 7-class logits.

    NOTE(review): fine-tuned weights are loaded externally by calling
    ``load_state_dict`` on the whole module (see module-level code); the
    per-encoder checkpoint loading that used to live here is intentionally
    not performed inside ``__init__``.
    """

    def __init__(self, wav2vec2_config_path):
        super().__init__()
        # Fine-tuned BERT backbone; its stock classification head is
        # replaced right below.
        self.bert = AutoModelForSequenceClassification.from_pretrained(
            "bert-base-uncased", num_labels=7
        )
        self.bert.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(self.bert.config.hidden_size, self.bert.config.num_labels)
        )
        # Load Wav2Vec2 first (config read from the local JSON path) ...
        config = AutoConfig.from_pretrained(wav2vec2_config_path, num_labels=7)
        self.wav2vec2 = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base", config=config)
        # ... then attach a classification head to it.
        # NOTE(review): neither ``self.bert.classifier`` nor
        # ``self.wav2vec2.classifier`` is used in ``forward`` below; they
        # presumably exist so this module's state_dict keys match the saved
        # checkpoint (loaded with strict=True) — confirm before removing.
        self.wav2vec2.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(self.wav2vec2.config.hidden_size, self.wav2vec2.config.num_labels)
        )
        # Fusion head over the concatenated text/audio features.
        self.classifier = nn.Sequential(
            nn.Linear(self.bert.config.hidden_size + self.wav2vec2.config.hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.7),
            nn.Linear(256, 7)  # 7-class emotion task
        )

    def forward(self, text_input, audio_input):
        """Return (batch, 7) logits from tokenized text and raw audio input.

        Args:
            text_input: dict of BERT tokenizer outputs (unpacked as kwargs).
            audio_input: waveform tensor passed straight to Wav2Vec2.
        """
        # Text features: [CLS] position of BERT's last hidden layer.
        text_outputs = self.bert(**text_input, output_hidden_states=True)
        text_features = text_outputs.hidden_states[-1][:, 0, :]
        # Audio features: position 0 of Wav2Vec2's last hidden layer.
        # NOTE(review): uses only the first frame (no mean pooling) —
        # presumably matches how the checkpoint was trained; confirm.
        audio_outputs = self.wav2vec2(audio_input, output_hidden_states=True)
        audio_features = audio_outputs.hidden_states[-1][:, 0, :]
        # Concatenate both modalities and classify.
        combined_features = torch.cat((text_features, audio_features), dim=-1)
        logits = self.classifier(combined_features)
        return logits
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Paths to the fine-tuned multimodal checkpoint (local files).
wav2vec2_config_path = r"models/config.json"
model_path = r"models/model.safetensors"

# Build the model, load the fine-tuned weights, and switch to eval mode.
# (The redundant second ``model.to(device)`` was removed: the model is
# already moved to ``device`` on construction.)
model = MultimodalClassifier(wav2vec2_config_path).to(device)
state_dict = load_file(model_path)
model.load_state_dict(state_dict, strict=True)
model.eval()
print("✅ 微调的 BERT + Wav2Vec2 模型加载成功!")

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
def preprocess_text(text):
    """Tokenize *text* for BERT and move the tensors to the active device.

    Truncates/pads to at most 128 tokens and returns PyTorch tensors.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    )
    return encoded.to(device)
def preprocess_audio(audio_path):
    """Load an audio file and return its waveform at 16 kHz on the device.

    Args:
        audio_path: path to an audio file readable by torchaudio.

    Returns:
        Waveform tensor (channels, samples) resampled to 16 kHz.
        NOTE(review): multi-channel audio is passed through as-is — confirm
        the model expects this shape (no mono downmix is performed).
    """
    waveform, sample_rate = torchaudio.load(audio_path)
    # Only resample when needed; the original built and applied a Resample
    # transform unconditionally, even for audio already at 16 kHz.
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    return waveform.to(device)
labels = ["Neutral", "Joy", "Sad", "Angry", "Surprised", "Fearful", "Disgusted"]

def predict_emotion(text, audio):
    """Run the multimodal model and map each emotion label to a
    percentage-formatted probability string."""
    encoded_text = preprocess_text(text)
    encoded_audio = preprocess_audio(audio)
    with torch.no_grad():
        logits = model(text_input=encoded_text, audio_input=encoded_audio)  # (1, 7)
    # Normalize logits to probabilities over the 7 classes.
    probabilities = F.softmax(logits, dim=1).squeeze().tolist()
    return {name: f"{p * 100:.2f}%" for name, p in zip(labels, probabilities)}
def generate_transcript(audio_file):
    """Placeholder "transcript": returns the uploaded file's name.

    No actual speech-to-text is performed — the caller receives
    ``audio_file.name`` verbatim.
    """
    return audio_file.name
def save_history(audio_file, transcript, emotions):
    """Append one analysis record to the JSON history file.

    Args:
        audio_file: object with a ``.name`` attribute (e.g. an uploaded file).
        transcript: transcript text associated with the audio.
        emotions: mapping of emotion label -> probability string.

    The record is appended to ``history/history.json`` (created if absent).
    """
    history_file = r"history/history.json"
    # The original crashed with FileNotFoundError when the ``history``
    # directory did not exist; create it first.
    os.makedirs(os.path.dirname(history_file), exist_ok=True)
    # EAFP: read the existing history, starting fresh when the file is
    # missing or corrupted (the original opened the file twice and crashed
    # on invalid JSON).
    try:
        with open(history_file, "r", encoding="utf-8") as f:
            history = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        history = []
    history.append({
        "audio_file": audio_file.name,
        "transcript": transcript,
        "emotions": emotions,
    })
    with open(history_file, "w", encoding="utf-8") as f:
        json.dump(history, f, indent=4)