Commit 960b1a0
Parent(s): 4702e13

gpu

This view is limited to 50 files because it contains too many changes.
- .gitignore +1 -0
- LICENSE +21 -0
- Phi-4-mini-instruct_emotions_union.csv +0 -0
- Qwen3-4B_emotions_meld.csv +0 -0
- Qwen3-4B_emotions_resd.csv +0 -0
- Qwen3-4B_emotions_union.csv +0 -0
- analysis.ipynb +0 -0
- app.py +315 -0
- best_audio_model.pt +3 -0
- best_audio_model_2.pt +3 -0
- best_model_dev_0_5895_epoch_8.pt +3 -0
- best_text_model.pth +3 -0
- check.py +230 -0
- config.toml +133 -0
- data_loading/__pycache__/feature_extractor.cpython-310.pyc +0 -0
- data_loading/__pycache__/pretrained_extractors.cpython-310.pyc +0 -0
- data_loading/dataset_multimodal.py +898 -0
- data_loading/feature_extractor.py +410 -0
- data_loading/pretrained_extractors.py +221 -0
- emotion_templates/anger.json +196 -0
- emotion_templates/disgust.json +174 -0
- emotion_templates/fear.json +178 -0
- emotion_templates/happy.json +187 -0
- emotion_templates/neutral.json +97 -0
- emotion_templates/sad.json +183 -0
- emotion_templates/surprise.json +198 -0
- generate_emotion_texts_dataset.py +137 -0
- generate_synthetic_dataset.py +71 -0
- main.py +119 -0
- models/__init__.py +0 -0
- models/__pycache__/__init__.cpython-310.pyc +0 -0
- models/__pycache__/help_layers.cpython-310.pyc +0 -0
- models/__pycache__/models.cpython-310.pyc +0 -0
- models/help_layers.py +528 -0
- models/models.py +1700 -0
- requirements.txt +0 -0
- run_generation.py +32 -0
- search_params.toml +22 -0
- synthetic_utils/__pycache__/dia_tts_wrapper.cpython-310.pyc +0 -0
- synthetic_utils/dia_tts_wrapper.py +77 -0
- synthetic_utils/parler_tts_wrapper.py +60 -0
- synthetic_utils/text_generation.py +91 -0
- test.py +28 -0
- training/train_utils.py +585 -0
- training/train_utils_old.py +379 -0
- utils/__pycache__/config_loader.cpython-310.pyc +0 -0
- utils/config_loader.py +204 -0
- utils/logger_setup.py +47 -0
- utils/losses.py +33 -0
- utils/measures.py +41 -0
.gitignore
ADDED
@@ -0,0 +1 @@
env/
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 LEYA Lab for Natural Language Processing

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Phi-4-mini-instruct_emotions_union.csv
ADDED
The diff for this file is too large to render.

Qwen3-4B_emotions_meld.csv
ADDED
The diff for this file is too large to render.

Qwen3-4B_emotions_resd.csv
ADDED
The diff for this file is too large to render.

Qwen3-4B_emotions_union.csv
ADDED
The diff for this file is too large to render.

analysis.ipynb
ADDED
The diff for this file is too large to render.
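The emotion CSVs above are too large for the diff view, so their schema is not visible in this commit. A minimal local inspection sketch (not part of the commit; nothing about the columns is assumed):

import pandas as pd

# Hypothetical quick look at one of the added soft-label CSVs.
df = pd.read_csv("Qwen3-4B_emotions_union.csv")
print(df.shape)              # rows x columns
print(df.columns.tolist())   # column names as stored in the file
print(df.head())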
app.py
ADDED
@@ -0,0 +1,315 @@
import gradio as gr
import torchaudio
import pandas as pd
import torch.nn.functional as F
import whisper
import logging
import plotly.express as px
from utils.config_loader import ConfigLoader
from data_loading.feature_extractor import (
    PretrainedAudioEmbeddingExtractor,
    PretrainedTextEmbeddingExtractor
)
import chardet
import torch
from models.models import BiFormer


DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# DEVICE = torch.device('cpu')

# Configure logging
logging.basicConfig(level=logging.INFO)

# Constants with emojis and colors
LABEL_TO_EMOTION = {
    0: '😠 Anger',
    1: '🤢 Disgust',
    2: '😨 Fear',
    3: '😄 Joy/Happiness',
    4: '😐 Neutral',
    5: '😢 Sadness',
    6: '😲 Surprise/Enthusiasm'
}
EMOTION_COLORS = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEEAD', '#FF9999', '#D4A5A5']
emotion_color_map = {emotion: color for emotion, color in zip(LABEL_TO_EMOTION.values(), EMOTION_COLORS)}
TARGET_SAMPLE_RATE = 16000


def initialize_components(config_path='config.toml'):
    """Initialize configuration and models."""
    config = ConfigLoader(config_path)
    config.show_config()
    model = BiFormer(
        audio_dim=256,
        text_dim=1024,
        seg_len=95,
        hidden_dim=256,
        hidden_dim_gated=256,
        num_transformer_heads=8,
        num_graph_heads=2,
        positional_encoding=False,
        dropout=0.15,
        mode='mean',
        # device="cuda",
        tr_layer_number=5,
        out_features=256,
        num_classes=7
    )
    checkpoint_path = "best_model_dev_0_5895_epoch_8.pt"
    state = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state)
    model = model.to(DEVICE)
    model.eval()
    return (
        PretrainedAudioEmbeddingExtractor(config),
        PretrainedTextEmbeddingExtractor(config),
        whisper.load_model("base"),
        model
    )


audio_extractor, text_extractor, whisper_model, bimodal_model = initialize_components()


def load_and_preprocess_audio(audio_path):
    """Load and preprocess audio to mono 16kHz format."""
    try:
        waveform, orig_sr = torchaudio.load(audio_path)
        waveform = waveform.mean(dim=0, keepdim=False)

        if orig_sr != TARGET_SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(
                orig_freq=orig_sr,
                new_freq=TARGET_SAMPLE_RATE
            )
            waveform = resampler(waveform)

        return waveform, TARGET_SAMPLE_RATE
    except Exception as e:
        logging.error(f"Audio loading failed: {e}")
        raise


def transcribe_audio(audio_path):
    """Convert speech to text using Whisper."""
    try:
        result = whisper_model.transcribe(audio_path, fp16=False)
        return result.get('text', '')
    except Exception as e:
        logging.error(f"Transcription failed: {e}")
        return ""


def get_predictions(input_data, extractor, is_audio=False):
    """Generic prediction function for audio/text."""
    try:
        if is_audio:
            pred, emb = extractor.extract(input_data, TARGET_SAMPLE_RATE)
        else:
            pred, emb = extractor.extract(input_data)

        return F.softmax(pred, dim=-1)[0].tolist(), emb
    except Exception as e:
        logging.error(f"Prediction failed: {e}")
        return [0.0] * len(LABEL_TO_EMOTION), None


def create_emotion_df(probabilities):
    """Create sorted emotion probability dataframe with percentages."""
    df = pd.DataFrame({
        'Emotion': list(LABEL_TO_EMOTION.values()),
        'Probability': [round(p * 100, 2) for p in probabilities]
    })
    return df


def create_plot(df, title):
    """Create Plotly bar chart with proper formatting."""
    fig = px.bar(
        df,
        x='Emotion',
        y='Probability',
        title=title,
        color='Emotion',
        color_discrete_map=emotion_color_map
    )
    fig.update_layout(
        xaxis=dict(tickangle=-45, tickfont=dict(size=12)),
        yaxis=dict(title='Probability (%)'),
        margin=dict(l=20, r=20, t=60, b=100),
        height=400,
        showlegend=False
    )
    return fig


def get_top_emotion(probabilities):
    """Return formatted top emotion with percentage."""
    max_idx = probabilities.index(max(probabilities))
    return f"{LABEL_TO_EMOTION[max_idx]} ({max(probabilities)*100:.1f}%)"


def process_audio(audio_path):
    """Main processing pipeline."""
    try:
        if not audio_path:
            empty = create_emotion_df([0] * len(LABEL_TO_EMOTION))
            return (
                create_plot(empty, "🎧 Audio Analysis"),
                "No audio detected",
                create_plot(empty, "📝 Text Analysis"),
                create_plot(empty, "🤝 Combined Analysis"),
                "🔇 Please provide audio input"
            )

        # Audio processing
        waveform, sr = load_and_preprocess_audio(audio_path)
        audio_probs, audio_features = get_predictions(waveform, audio_extractor, is_audio=True)
        audio_df = create_emotion_df(audio_probs)

        # Text processing
        text = transcribe_audio(audio_path)
        # Fall back to a zero distribution (and no text features) when the transcript is empty.
        text_probs, text_features = get_predictions(text, text_extractor) if text.strip() else ([0.0] * len(LABEL_TO_EMOTION), None)
        text_df = create_emotion_df(text_probs)

        # Combined results
        combined_probs = bimodal_model(audio_features, text_features)
        combined_probs = F.softmax(combined_probs, dim=-1)[0].detach().cpu().numpy().tolist()
        combined_df = create_emotion_df(combined_probs)
        top_emotion = get_top_emotion(combined_probs)

        return (
            create_plot(audio_df, "🎧 Audio Analysis"),
            f"🗣️ Transcription:\n{text}",
            create_plot(text_df, "📝 Text Analysis"),
            create_plot(combined_df, "🤝 Combined Analysis"),
            f"## 🏆 Dominant Emotion: {top_emotion}"
        )

    except Exception as e:
        logging.error(f"Processing failed: {e}")
        error_df = create_emotion_df([0] * len(LABEL_TO_EMOTION))
        return (
            create_plot(error_df, "🎧 Audio Analysis"),
            "❌ Error processing audio",
            create_plot(error_df, "📝 Text Analysis"),
            create_plot(error_df, "🤝 Combined Analysis"),
            "⚠️ Processing Error"
        )


def create_app():
    """Build enhanced Gradio interface."""
    with gr.Blocks(theme=gr.themes.Soft(), title="Emotion Detection from Speech") as demo:
        gr.Markdown("# 🎙️ Bimodal Emotion Recognition")
        gr.Markdown("Analyze emotions in speech through both audio characteristics and spoken content")

        with gr.Row():
            audio_input = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Record or Upload Audio",
                format="wav",
                interactive=True
            )

        with gr.Row():
            top_emotion = gr.Markdown("## 🏆 Dominant Emotion: Waiting for input...",
                                      elem_classes="dominant-emotion")

        with gr.Row():
            with gr.Column():
                audio_plot = gr.Plot(label="Audio Analysis")
            with gr.Column():
                text_plot = gr.Plot(label="Text Analysis")
            with gr.Column():
                combined_plot = gr.Plot(label="Combined Analysis")

        transcription = gr.Textbox(
            label="📜 Transcription Results",
            placeholder="Transcribed text will appear here...",
            lines=3,
            max_lines=6
        )

        audio_input.change(
            process_audio,
            inputs=audio_input,
            outputs=[audio_plot, transcription, text_plot, combined_plot, top_emotion]
        )

    return demo


def create_authors():
    df = pd.DataFrame({
        "Name": ["Author", "Author"]
    })
    with gr.Blocks() as demo:
        gr.Dataframe(df)
    return demo


def create_reqs():
    """Create requirements tab with formatted data and explanations."""
    # 1️⃣ Detect file encoding
    with open('requirements.txt', 'rb') as f:
        raw_data = f.read()
        encoding = chardet.detect(raw_data)['encoding']

    # 2️⃣ Parse requirements into library-version pairs
    def parse_requirements(lines):
        requirements = []
        for line in lines:
            line = line.strip()
            if not line or line.startswith('#'):
                continue  # Skip empty lines and comments
            parts = line.split('==')
            library = parts[0].strip()
            version = parts[1].strip() if len(parts) > 1 else 'latest'
            requirements.append((library, version))
        return requirements

    # 3️⃣ Load and process requirements
    with open('requirements.txt', 'r', encoding=encoding) as f:
        requirements = parse_requirements(f.readlines())

    # 4️⃣ Create structured data for display
    df = pd.DataFrame({
        "📦 Library": [lib for lib, _ in requirements],
        "🚀 Recommended Version": [ver for _, ver in requirements]
    })

    # 5️⃣ Build interactive components
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 📦 Dependency Requirements")
        gr.Markdown("""
        ## Essential Packages for Operation
        These are the core libraries and versions needed to run the application successfully:
        """)
        gr.Dataframe(
            df,
            interactive=True,
            wrap=True,
            elem_id="requirements-table"
        )
        gr.Markdown("_Note: Versions marked 'latest' can use any compatible version_")

    return demo


def create_demo():
    app = create_app()
    authors = create_authors()
    reqs = create_reqs()
    demo = gr.TabbedInterface(
        [app, authors, reqs],
        tab_names=["🎙️ Speech Analysis", "👥 Project Team", "📦 Dependencies"]
    )
    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.launch()
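Because app.py builds its extractors, Whisper model and BiFormer checkpoint at import time, the same pipeline can also be driven without the Gradio UI. A minimal sketch (not part of the commit; "sample.wav" is a placeholder path):

# Hypothetical headless use of the pipeline defined in app.py.
from app import process_audio

audio_fig, transcription, text_fig, combined_fig, dominant = process_audio("sample.wav")
print(transcription)                          # "🗣️ Transcription:\n..."
print(dominant)                               # "## 🏆 Dominant Emotion: ..."
combined_fig.write_html("combined_analysis.html")  # Plotly figures can be saved for inspection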
best_audio_model.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1cf27208d1d99448dd075ae4b499f97aa5fffd0ef7290ebe82555bc09a350469
size 3174946

best_audio_model_2.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ed1705e0cf0e5bc7bb1755f7fb09bfc440b4c92c739489d107d5e0e7a29707bc
size 3174674

best_model_dev_0_5895_epoch_8.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f7eeba90564d5c301aa59f487aa59f71923b9d3f293327a99b88c34b71d4226f
size 17670490

best_text_model.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e25ae864742da9a8f4ba6b85b6161c503cb47dc6197dd0eb82b5250180a171cf
size 39340722
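The four checkpoints above are stored with Git LFS, so the diff only contains pointer files (spec version, sha256 oid, byte size). A minimal sketch, not part of the commit, for verifying that a locally downloaded blob matches its pointer; the pointer file name is hypothetical:

import hashlib
from pathlib import Path

def verify_lfs_pointer(pointer_path: str, blob_path: str) -> bool:
    """Compare a downloaded file against the oid/size recorded in its LFS pointer."""
    fields = dict(line.split(" ", 1) for line in Path(pointer_path).read_text().splitlines() if line)
    expected_oid = fields["oid"].split(":", 1)[1].strip()   # hex digest after "sha256:"
    expected_size = int(fields["size"])
    data = Path(blob_path).read_bytes()
    return len(data) == expected_size and hashlib.sha256(data).hexdigest() == expected_oid

# e.g. verify_lfs_pointer("best_audio_model.pt.pointer", "best_audio_model.pt")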
check.py
ADDED
@@ -0,0 +1,230 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Sanity check of the synthetic MELD-S corpus:
  - does the WAV file exist;
  - do the audio and text embeddings have the expected dimensions;
  - does the final concatenated feature vector match the expected size.

Output:
  GOOD / BAD counts in the console + a CSV bad_synth_meld.csv (if problems were found).
"""

from __future__ import annotations

import csv
import logging
import sys
import traceback
from pathlib import Path
from types import SimpleNamespace
from typing import Dict, List, Optional

import pandas as pd
import torch
import torchaudio
from tqdm import tqdm

# ----------------------------------------------------------------------
# >>>>>>>>> USER SETTINGS (check the paths!) <<<<<<<<<<<
# ----------------------------------------------------------------------
USER_CONFIG = {
    # paths to the synthetic data
    "synthetic_path": r"E:/MELD_S",
    "csv_name": "meld_s_train_labels.csv",
    "wav_subdir": "wavs",

    # models / checkpoints, same as in your config.toml
    "audio_model_name": "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim",
    "audio_ckpt": "best_audio_model_2.pt",
    "text_model_name": "jinaai/jina-embeddings-v3",
    "text_ckpt": "best_text_model.pth",

    # general parameters
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "sample_rate": 16000,
    "num_emotions": 7,  # anger, disgust, fear, happy, neutral, sad, surprise
}

# ----------------------------------------------------------------------
# import of the project's own extractors
# ----------------------------------------------------------------------
try:
    from feature_extractor import (
        PretrainedAudioEmbeddingExtractor,
        PretrainedTextEmbeddingExtractor,
    )
except ModuleNotFoundError:
    try:
        # if the file lives in data_loading/
        from data_loading.feature_extractor import (
            PretrainedAudioEmbeddingExtractor,
            PretrainedTextEmbeddingExtractor,
        )
    except ModuleNotFoundError as e:
        sys.exit(
            "❌ feature_extractor.py not found. "
            "Make sure it is on PYTHONPATH or sits next to this script."
        )

# ----------------------------------------------------------------------
# helper functions
# ----------------------------------------------------------------------
def build_audio_cfg() -> SimpleNamespace:
    """Build the config object for PretrainedAudioEmbeddingExtractor."""
    return SimpleNamespace(
        audio_model_name=USER_CONFIG["audio_model_name"],
        emb_device=USER_CONFIG["device"],
        audio_pooling="mean",  # same as during training
        emb_normalize=False,
        max_audio_frames=0,
        audio_classifier_checkpoint=USER_CONFIG["audio_ckpt"],
        sample_rate=USER_CONFIG["sample_rate"],
        wav_length=4,
    )


def build_text_cfg() -> SimpleNamespace:
    """Config for PretrainedTextEmbeddingExtractor."""
    return SimpleNamespace(
        text_model_name=USER_CONFIG["text_model_name"],
        emb_device=USER_CONFIG["device"],
        text_pooling="mean",
        emb_normalize=False,
        max_tokens=95,
        text_classifier_checkpoint=USER_CONFIG["text_ckpt"],
    )


def get_dims(audio_extractor, text_extractor) -> Dict[str, int]:
    """Return the actual embedding dimensions (audio_dim, text_dim)."""
    sr = USER_CONFIG["sample_rate"]
    with torch.no_grad():
        dummy_wav = torch.zeros(1, sr)
        _, a_emb = audio_extractor.extract(dummy_wav[0], sr)
        audio_dim = a_emb[0].shape[-1]

        _, t_emb = text_extractor.extract("dummy text")
        text_dim = t_emb[0].shape[-1]

    return {"audio_dim": audio_dim, "text_dim": text_dim}


def check_row(
    row: pd.Series,
    feats: Dict[str, object],
    dims: Dict[str, int],
    wav_dir: Path,
) -> Optional[str]:
    """
    Return None if the sample is valid, otherwise a string describing the problem.
    """
    video = row["video_name"]
    wav_path = wav_dir / f"{video}.wav"
    text = row.get("text", "")

    try:
        if not wav_path.exists():
            return "file_missing"

        # ---------- audio ----------
        wf, sr = torchaudio.load(str(wav_path))
        if sr != USER_CONFIG["sample_rate"]:
            wf = torchaudio.transforms.Resample(sr, USER_CONFIG["sample_rate"])(wf)

        a_pred, a_emb = feats["audio"].extract(wf[0], USER_CONFIG["sample_rate"])
        a_emb = a_emb[0]
        if a_emb.shape[-1] != dims["audio_dim"]:
            return f"audio_dim_{a_emb.shape[-1]}"

        # ---------- text ----------
        t_pred, t_emb = feats["text"].extract(text)
        t_emb = t_emb[0]
        if t_emb.shape[-1] != dims["text_dim"]:
            return f"text_dim_{t_emb.shape[-1]}"

        # ---------- concatenation ----------
        full_vec = torch.cat(
            [a_emb, t_emb, a_pred[0], t_pred[0]],
            dim=-1,
        )
        expected_all = (
            dims["audio_dim"]
            + dims["text_dim"]
            + 2 * USER_CONFIG["num_emotions"]
        )
        if full_vec.shape[-1] != expected_all:
            return f"concat_dim_{full_vec.shape[-1]}"

    except Exception as e:
        logging.error(f"{video}: {traceback.format_exc(limit=2)}")
        return "exception_" + e.__class__.__name__

    return None


# ----------------------------------------------------------------------
# main script
# ----------------------------------------------------------------------
def main() -> None:
    syn_root = Path(USER_CONFIG["synthetic_path"])
    csv_path = syn_root / USER_CONFIG["csv_name"]
    wav_dir = syn_root / USER_CONFIG["wav_subdir"]

    if not csv_path.exists():
        sys.exit(f"CSV not found: {csv_path}")
    if not wav_dir.exists():
        sys.exit(f"WAV directory not found: {wav_dir}")

    # 1. extractors
    audio_feat = PretrainedAudioEmbeddingExtractor(build_audio_cfg())
    text_feat = PretrainedTextEmbeddingExtractor(build_text_cfg())
    feats = {"audio": audio_feat, "text": text_feat}

    # 2. actual dimensions
    dims = get_dims(audio_feat, text_feat)
    expected_total = (
        dims["audio_dim"] + dims["text_dim"] + 2 * USER_CONFIG["num_emotions"]
    )
    print(
        f"Audio dim = {dims['audio_dim']}, "
        f"Text dim = {dims['text_dim']}, "
        f"Expected concat = {expected_total}"
    )

    # 3. go through the CSV
    df = pd.read_csv(csv_path)
    bad_rows: List[Dict[str, str]] = []
    good_cnt = 0

    for _, row in tqdm(df.iterrows(), total=len(df), desc="Checking"):
        reason = check_row(row, feats, dims, wav_dir)
        if reason:
            bad_rows.append(
                {
                    "video_name": row["video_name"],
                    "reason": reason,
                    "wav_path": str(wav_dir / f"{row['video_name']}.wav"),
                }
            )
        else:
            good_cnt += 1

    # 4. report
    print("\n========== SUMMARY ==========")
    print(f"✅ GOOD : {good_cnt}")
    print(f"❌ BAD : {len(bad_rows)}")

    if bad_rows:
        out_csv = Path(__file__).with_name("bad_synth_meld.csv")
        with open(out_csv, "w", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(
                f, fieldnames=["video_name", "reason", "wav_path"]
            )
            writer.writeheader()
            writer.writerows(bad_rows)
        print(f"\nList of problematic samples saved to: {out_csv.resolve()}")


if __name__ == "__main__":
    main()
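After running check.py, the bad_synth_meld.csv it writes has the columns video_name, reason and wav_path. A short follow-up sketch (not part of the commit) for summarizing the failure reasons:

import pandas as pd

# Hypothetical post-processing of check.py's output.
bad = pd.read_csv("bad_synth_meld.csv")
print(bad["reason"].value_counts())   # e.g. file_missing vs. dimension mismatches
print(bad.head())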
config.toml
ADDED
@@ -0,0 +1,133 @@
# ---------------------------
# Dataset corpora settings
# ---------------------------

[datasets.meld]
base_dir = "E:/MELD"
csv_path = "{base_dir}/meld_{split}_labels.csv"
wav_dir = "{base_dir}/wavs/{split}"

[datasets.resd]
base_dir = "E:/RESD"
csv_path = "{base_dir}/resd_{split}_labels.csv"
wav_dir = "{base_dir}/wavs/{split}"

[synthetic_data]
use_synthetic_data = false
synthetic_path = "E:/MELD_S"
synthetic_ratio = 0.005

# ---------------------------
# Modalities and emotion list
# ---------------------------
modalities = ["audio"]
# emotion_columns = ["neutral", "happy", "sad", "anger", "surprise", "disgust", "fear"]
emotion_columns = ["anger", "disgust", "fear", "happy", "neutral", "sad", "surprise"]

# ---------------------------
# DataLoader parameters
# ---------------------------
[dataloader]
num_workers = 0
shuffle = true
prepare_only = false

# ---------------------------
# Audio
# ---------------------------
[audio]
sample_rate = 16000                 # target sampling rate
wav_length = 4                      # target audio length (in seconds)
save_merged_audio = true
merged_audio_base_path = "saved_merges"
merged_audio_suffix = "_merged"
force_remerge = false

# ---------------------------
# Whisper and text
# ---------------------------
[text]
# If "csv", we try to take the text from the CSV when it is present
# (the text_column field). If it is missing, Whisper is used (when needed).
source = "csv"
text_column = "text"
whisper_model = "base"

# Where to run Whisper: "cuda" (GPU) or "cpu"
whisper_device = "cuda"

# If dev/test rows have no text in the CSV, should Whisper still be called?
use_whisper_for_nontrain_if_no_text = true

# ---------------------------
# General training parameters
# ---------------------------
[train.general]
random_seed = 42                    # fix the random seed for reproducibility (0 = different every run)
subset_size = 100                   # limit on the number of samples (0 = use the whole dataset)
merge_probability = 0               # fraction of short files to concatenate
batch_size = 8                      # batch size
num_epochs = 75                     # number of training epochs
max_patience = 10                   # maximum number of epochs without improvement (early stopping)
save_best_model = false
save_prepared_data = true           # save the extracted features (embeddings)
save_feature_path = './features/'   # path for saving the embeddings
search_type = "none"                # search strategy: "greedy", "exhaustive" or "none"
path_to_df_ls = 'Phi-4-mini-instruct_emotions_union.csv'  # dataframe with soft (smoothed) labels: Qwen3-4B_emotions_union or Phi-4-mini-instruct_emotions_union
smoothing_probability = 0.0         # fraction of samples that use the smoothed labels

# ---------------------------
# Model parameters
# ---------------------------
[train.model]
model_name = "BiFormer"             # model name (BiGraphFormer, BiFormer, BiGatedGraphFormer, BiGatedFormer, BiMamba, PredictionsFusion, BiFormerWithProb, BiMambaWithProb, BiGraphFormerWithProb, BiGatedGraphFormerWithProb)
hidden_dim = 256                    # hidden state size
hidden_dim_gated = 128              # hidden state size for the gated mechanisms
num_transformer_heads = 16          # number of attention heads in the transformer
num_graph_heads = 2                 # number of heads in the graph mechanism
tr_layer_number = 5                 # number of transformer layers
mamba_d_state = 16                  # state size in Mamba
mamba_ker_size = 6                  # kernel size in Mamba
mamba_layer_number = 5              # number of Mamba layers
positional_encoding = false         # whether to use positional encoding
dropout = 0.15                      # dropout between layers
out_features = 256                  # size of the final features before classification
mode = 'mean'                       # feature aggregation mode (e.g. "mean", "max", etc.)

# ---------------------------
# Optimizer parameters
# ---------------------------
[train.optimizer]
optimizer = "adam"                  # optimizer type: "adam", "adamw", "lion", "sgd", "rmsprop"
lr = 1e-4                           # initial learning rate
weight_decay = 0.0                  # weight decay for regularization
momentum = 0.9                      # momentum (used only with SGD)

# ---------------------------
# Scheduler parameters
# ---------------------------
[train.scheduler]
scheduler_type = "plateau"          # scheduler type: "none", "plateau", "cosine", "onecycle" or HuggingFace-style ("huggingface_linear", "huggingface_cosine", "huggingface_cosine_with_restarts", etc.)
warmup_ratio = 0.1                  # ratio of warmup iterations to the total number of steps (0.1 = 10%)

[embeddings]
# audio_model = "amiriparian/ExHuBERT"  # Hugging Face model name for audio
audio_model = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"  # Hugging Face model name for audio
audio_classifier_checkpoint = "best_audio_model_2.pt"
text_classifier_checkpoint = "best_text_model.pth"
text_model = "jinaai/jina-embeddings-v3"  # Hugging Face model name for text
audio_embedding_dim = 256           # audio embedding dimension
text_embedding_dim = 1024           # text embedding dimension
emb_normalize = false               # whether to L2-normalize the vectors
max_tokens = 95                     # limit on text length (tokens) during tokenization
device = "cuda"                     # "cuda" or "cpu", where to load the models

# audio_pooling = "mean"            # "mean", "cls", "max", "min", "last", "attention"
# text_pooling = "cls"              # "mean", "cls", "max", "min", "last", "sum", "attention"

[textgen]
# model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"   # deepseek-ai/deepseek-llm-1.3b-base or any other model
# model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"  # deepseek-ai/deepseek-llm-1.3b-base or any other model
max_new_tokens = 50
temperature = 1.0
top_p = 0.95
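The dataset paths in config.toml are templated with "{base_dir}" and "{split}" placeholders. The repository's own utils/config_loader.py handles this, but its internals are not shown in this commit; a minimal independent sketch of how the templates resolve:

import tomllib  # Python 3.11+; use the third-party 'toml' package on older versions

with open("config.toml", "rb") as f:
    cfg = tomllib.load(f)

meld = cfg["datasets"]["meld"]
split = "train"
# Expand the "{base_dir}"/"{split}" placeholders used in csv_path and wav_dir.
csv_path = meld["csv_path"].format(base_dir=meld["base_dir"], split=split)
wav_dir = meld["wav_dir"].format(base_dir=meld["base_dir"], split=split)
print(csv_path)  # E:/MELD/meld_train_labels.csv
print(wav_dir)   # E:/MELD/wavs/train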
data_loading/__pycache__/feature_extractor.cpython-310.pyc
ADDED
Binary file (10.9 kB).

data_loading/__pycache__/pretrained_extractors.cpython-310.pyc
ADDED
Binary file (9.4 kB).
data_loading/dataset_multimodal.py
ADDED
@@ -0,0 +1,898 @@
# -*- coding: utf-8 -*-

import os
import random
import logging
import torch
import torchaudio
import whisper
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
import pickle
from tqdm import tqdm
# from data_loading.feature_extractor import PretrainedAudioEmbeddingExtractor, PretrainedTextEmbeddingExtractor

class DatasetMultiModalWithPretrainedExtractors(Dataset):
    """
    Multimodal dataset for audio, text and emotions (on-the-fly version).

    On every __getitem__ call:
      - Loads the WAV for video_name from the CSV.
      - For the training split (split="train"):
          If the audio is shorter than target_samples, we check whether this file was selected
          for merging (via merge_probability). If so, a "chain merge" is performed:
          one or more additional files of the same class are picked, even if a single candidate is longer than needed,
          and the resulting audio is then trimmed to the exact length.
      - If the resulting audio is still shorter than target_samples, it is zero-padded.
      - The text is chosen as follows:
          * If the audio was merged, Whisper is called to obtain a new transcript.
          * If no merge happened and the CSV text is non-empty, the CSV text is used.
          * If the CSV text is empty, Whisper is called for train (and, optionally, for dev/test).
      - Returns a dictionary { "audio": waveform, "label": label_vector, "text": text_final }.
    """

    def __init__(
        self,
        csv_path,
        wav_dir,
        emotion_columns,
        config,
        split,
        audio_feature_extractor,
        text_feature_extractor,
        whisper_model,
        dataset_name
    ):
        """
        :param csv_path: Path to the CSV file (columns: video_name, emotion_columns, possibly text).
        :param wav_dir: Directory with audio files (file name: video_name.wav).
        :param emotion_columns: List of emotion columns, e.g. ["neutral", "happy", "sad", ...].
        :param split: "train", "dev" or "test".
        :param audio_feature_extractor: Audio feature extractor.
        :param text_feature_extractor: Text feature extractor.
        :param sample_rate: Target sampling rate (e.g. 16000).
        :param wav_length: Target audio length in seconds.
        :param whisper_model: Whisper model ("tiny", "base", "small", ...).
        :param max_text_tokens: (Unused) limit on the number of tokens.
        :param text_column: Name of the text column in the CSV.
        :param use_whisper_for_nontrain_if_no_text: If True, Whisper is called for dev/test when the CSV text is missing.
        :param whisper_device: "cuda" or "cpu", device for the Whisper model.
        :param subset_size: If > 0, only the first N rows of the CSV are used (for debugging).
        :param merge_probability: Fraction (0..1) of all files that will be merged if they are too short.
        :param dataset_name: Name of the corpus.
        """
        super().__init__()
        self.split = split
        self.sample_rate = config.sample_rate
        self.target_samples = int(config.wav_length * self.sample_rate)
        self.emotion_columns = emotion_columns
        self.whisper_model = whisper_model
        self.text_column = config.text_column
        self.use_whisper_for_nontrain_if_no_text = config.use_whisper_for_nontrain_if_no_text
        self.whisper_device = config.whisper_device
        self.merge_probability = config.merge_probability
        self.audio_feature_extractor = audio_feature_extractor
        self.text_feature_extractor = text_feature_extractor
        self.subset_size = config.subset_size
        self.save_prepared_data = config.save_prepared_data
        self.seed = config.random_seed
        self.dataset_name = dataset_name
        self.save_feature_path = config.save_feature_path
        self.use_synthetic_data = config.use_synthetic_data
        self.synthetic_path = config.synthetic_path
        self.synthetic_ratio = config.synthetic_ratio

        # Load the CSV
        if not os.path.exists(csv_path):
            raise ValueError(f"Error: CSV file not found: {csv_path}")
        df = pd.read_csv(csv_path)
        if self.subset_size > 0:
            df = df.head(self.subset_size)
            logging.info(f"[DatasetMultiModal] Using only the first {len(df)} rows (subset_size={self.subset_size}).")

        # copy kept so the Whisper transcripts can be written back later
        self.original_df = df.copy()
        self.whisper_csv_update_log = []

        # Check that all emotion columns are present
        missing = [c for c in emotion_columns if c not in df.columns]
        if missing:
            raise ValueError(f"The CSV is missing required emotion columns: {missing}")

        # Check that the audio directory exists
        if not os.path.exists(wav_dir):
            raise ValueError(f"Error: audio directory {wav_dir} does not exist!")
        self.wav_dir = wav_dir

        # Build the list of rows: for every record keep the audio path, label and CSV text (if any)
        self.rows = []
        for i, rowi in df.iterrows():
            audio_path = os.path.join(wav_dir, f"{rowi['video_name']}.wav")
            if not os.path.exists(audio_path):
                continue
            # Determine the dominant emotion (maximum value)
            # print(self.emotion_columns)
            emotion_values = rowi[self.emotion_columns].values.astype(float)
            max_idx = np.argmax(emotion_values)
            emotion_label = self.emotion_columns[max_idx]

            # Extract the text from the CSV (if present)
            csv_text = ""
            if self.text_column in rowi and isinstance(rowi[self.text_column], str):
                csv_text = rowi[self.text_column]

            self.rows.append({
                "audio_path": audio_path,
                "label": emotion_label,
                "csv_text": csv_text
            })

        if self.use_synthetic_data and self.split == "train" and self.dataset_name.lower() == "meld":
            logging.info(f"🧪 Synthetic data enabled for dataset '{self.dataset_name}', adding samples from: {self.synthetic_path}")
            self._add_synthetic_data(self.synthetic_ratio)

        # Build a map for looking up files by emotion
        self.audio_class_map = {entry["audio_path"]: entry["label"] for entry in self.rows}

        logging.info("📊 Distribution of files per emotion:")
        emotion_counts = {emotion: 0 for emotion in set(self.audio_class_map.values())}
        for path, emotion in self.audio_class_map.items():
            emotion_counts[emotion] += 1
        for emotion, count in emotion_counts.items():
            logging.info(f"🎭 Emotion '{emotion}': {count} files.")

        logging.info(f"[DatasetMultiModal] Split={split}, total rows: {len(self.rows)}")

        # === Percentage-based sampling ===
        total_files = len(self.rows)
        num_to_merge = int(total_files * self.merge_probability)

        # <<< NEW: cache the lengths (eq_len) of all files >>>
        self.path_info = {}
        for row in self.rows:
            p = row["audio_path"]
            try:
                info = torchaudio.info(p)
                length = info.num_frames
                sr_ = info.sample_rate
                # convert the length into its equivalent at self.sample_rate
                if sr_ != self.sample_rate:
                    ratio = sr_ / self.sample_rate
                    eq_len = int(length / ratio)
                else:
                    eq_len = length
                self.path_info[p] = eq_len
            except Exception as e:
                logging.warning(f"⚠️ Failed to read {p}: {e}")
                self.path_info[p] = 0  # if the file could not be read, store 0

        # Determine which files are "short" (may need merging); use the cache instead of the old _is_too_short
        self.mergable_files = [
            row["audio_path"]  # keep only the path instead of the whole dict
            for row in self.rows
            if self._is_too_short_cached(row["audio_path"])  # <<< now uses the cached check
        ]
        short_count = len(self.mergable_files)

        # If there are more short files than needed, pick a random subset. Otherwise take all of them.
        if short_count > num_to_merge:
            self.files_to_merge = set(random.sample(self.mergable_files, num_to_merge))
        else:
            self.files_to_merge = set(self.mergable_files)

        logging.info(f"🔗 Total files: {total_files}, to be merged: {num_to_merge} ({self.merge_probability*100:.0f}%)")
        logging.info(f"🔗 Short files: {short_count}, selected for merging: {len(self.files_to_merge)}")

        if self.save_prepared_data:
            self.meta = []

            if self.use_synthetic_data:
                meta_filename = '{}_{}_seed_{}_subset_size_{}_audio_model_{}_feature_norm_{}_synthetic_true_pct_{}_pred.pickle'.format(
                    self.dataset_name,
                    self.split,
                    config.audio_classifier_checkpoint[-4:-3],
                    self.seed,
                    self.subset_size,
                    config.emb_normalize,
                    int(self.synthetic_ratio * 100)
                )

            else:
                meta_filename = '{}_{}_seed_{}_subset_size_{}_audio_model_{}_feature_norm_{}_merge_prob_{}_pred.pickle'.format(
                    self.dataset_name,
                    self.split,
                    config.audio_classifier_checkpoint[-4:-3],
                    self.seed,
                    self.subset_size,
                    config.emb_normalize,
                    self.merge_probability
                )

            pickle_path = os.path.join(self.save_feature_path, meta_filename)
            self.load_data(pickle_path)

            if not self.meta:
                self.prepare_data()
                os.makedirs(self.save_feature_path, exist_ok=True)
                self.save_data(pickle_path)

    def save_data(self, filename):
        with open(filename, 'wb') as handle:
            pickle.dump(self.meta, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def load_data(self, filename):
        if os.path.exists(filename):
            with open(filename, 'rb') as handle:
                self.meta = pickle.load(handle)
        else:
            self.meta = []

    def _is_too_short(self, audio_path):
        """
        (Original) Check whether the file is shorter than target_samples.
        Uses torchaudio.info(audio_path).
        No longer used, since the lengths are now cached.
        """
        try:
            info = torchaudio.info(audio_path)
            length = info.num_frames
            sr_ = info.sample_rate
            # convert the length into its equivalent at self.sample_rate
            if sr_ != self.sample_rate:
                ratio = sr_ / self.sample_rate
                eq_len = int(length / ratio)
            else:
                eq_len = length
            return eq_len < self.target_samples
        except Exception as e:
            logging.warning(f"Error in _is_too_short({audio_path}): {e}")
            return False

    def _is_too_short_cached(self, audio_path):
        """
        (New) Check whether the file is shorter than target_samples, using the length cached in self.path_info.
        """
        eq_len = self.path_info.get(audio_path, 0)
        return eq_len < self.target_samples

    def __len__(self):
        if self.save_prepared_data:
            return len(self.meta)
        else:
            return len(self.rows)

    def get_data(self, row):
        audio_path = row["audio_path"]
        label_name = row["label"]
        csv_text = row["csv_text"]

        # Convert the label into a one-hot vector
        label_vec = self.emotion_to_vector(label_name)

        # Step 1. Load the audio
        waveform, sr = self.load_audio(audio_path)
        if waveform is None:
            return None

        orig_len = waveform.shape[1]
        logging.debug(f"Original length of {os.path.basename(audio_path)}: {orig_len/sr:.2f} s")

        was_merged = False
        merged_texts = [csv_text]  # texts of the original file + the appended files

        # Step 2. For train, if the audio is shorter than target_samples, check
        # whether this row was selected in files_to_merge
        if self.split == "train" and row["audio_path"] in self.files_to_merge:
            # chain merge
            current_length = orig_len
            used_candidates = set()

            while current_length < self.target_samples:
                needed = self.target_samples - current_length
                candidate = self.get_suitable_audio(label_name, exclude_path=audio_path, min_needed=needed, top_k=10)
                if candidate is None or candidate in used_candidates:
                    break
                used_candidates.add(candidate)
                add_wf, add_sr = self.load_audio(candidate)
                if add_wf is None:
                    break
                logging.debug(f"Merging: appending {os.path.basename(candidate)} (samples needed: {needed})")
                waveform = torch.cat((waveform, add_wf), dim=1)
                current_length = waveform.shape[1]
                was_merged = True

                # Get the text of the appended file (if present in the CSV)
                add_csv_text = next((r["csv_text"] for r in self.rows if r["audio_path"] == candidate), "")
                merged_texts.append(add_csv_text)

                logging.debug(f"📜 Text of the first file: {csv_text}")
                logging.debug(f"📜 Text of the appended file: {add_csv_text}")
        else:
            # If the file was not selected for merging (or the split is not train), skip the chain merge
            logging.debug("File not selected for merging (or not train), skipping chain merge.")

        if was_merged:
            logging.debug("📝 Text: audio was merged, calling Whisper.")
            text_final = self.run_whisper(waveform)
            logging.debug(f"🆕 Whisper predicted: {text_final}")

            merge_components = [os.path.splitext(os.path.basename(audio_path))[0]]
            merge_components += [os.path.splitext(os.path.basename(p))[0] for p in used_candidates]

            self.whisper_csv_update_log.append({
                "video_name": os.path.splitext(os.path.basename(audio_path))[0],
                "text_new": text_final,
                "text_old": csv_text,
                "was_merged": True,
                "merge_components": merge_components
            })

        else:
            if csv_text.strip():
                logging.debug("Text: using the CSV text (non-empty).")
                text_final = csv_text
            else:
                if self.split == "train" or self.use_whisper_for_nontrain_if_no_text:
                    logging.debug("Text: CSV is empty, calling Whisper.")
                    text_final = self.run_whisper(waveform)
                else:
                    logging.debug("Text: CSV is empty and Whisper is not called for dev/test.")
                    text_final = ""

        audio_pred, audio_emb = self.audio_feature_extractor.extract(waveform[0], self.sample_rate)
        text_pred, text_emb = self.text_feature_extractor.extract(text_final)

        return {
            "audio_path": os.path.basename(audio_path),
            "audio": audio_emb[0],
            "label": label_vec,
            "text": text_emb[0],
            "audio_pred": audio_pred[0],
            "text_pred": text_pred[0]
        }

    def prepare_data(self):
        """
        Loads and processes every item of the dataset,
        saving the embeddings and the updated text (if merging happened).
        """
        for idx, row in enumerate(tqdm(self.rows)):
            curr_dict = self.get_data(row)
            if curr_dict is not None:
                self.meta.append(curr_dict)

        # === Save a CSV with the updated texts (only if there was a merge) ===
        if self.whisper_csv_update_log:
            df_log = pd.DataFrame(self.whisper_csv_update_log)

            # Copy of the original CSV
            df_out = self.original_df.copy()

            # Merge on video_name
            df_out = df_out.merge(df_log, on="video_name", how="left")

            # Update the text: replace only where Whisper produced something
            df_out["text_final"] = df_out["text_new"].combine_first(df_out["text"])
            df_out["text_old"] = df_out["text"]
            df_out["text"] = df_out["text_final"]
            df_out["was_merged"] = df_out["was_merged"].fillna(False).astype(bool)

            # Convert merge_components into a string
            df_out["merge_components"] = df_out["merge_components"].apply(
                lambda x: ";".join(x) if isinstance(x, list) else ""
            )

            # Drop the temporary columns
            df_out = df_out.drop(columns=["text_new", "text_final"])

            # Save as CSV
            output_path = os.path.join(self.save_feature_path, f"{self.dataset_name}_{self.split}_merged_whisper_{self.merge_probability * 100}.csv")
            os.makedirs(self.save_feature_path, exist_ok=True)
            df_out.to_csv(output_path, index=False, encoding="utf-8")
            logging.info(f"📄 Updated merged CSV saved: {output_path}")

    def __getitem__(self, index):
        if self.save_prepared_data:
            return self.meta[index]
        else:
            return self.get_data(self.rows[index])

    def load_audio(self, path):
        """
        Loads the audio from the given path and resamples it to self.sample_rate if necessary.
        """
        if not os.path.exists(path):
            logging.warning(f"File missing: {path}")
            return None, None
        try:
            wf, sr = torchaudio.load(path)
            if sr != self.sample_rate:
                resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
                wf = resampler(wf)
                sr = self.sample_rate
            return wf, sr
        except Exception as e:
            logging.error(f"Failed to load {path}: {e}")
            return None, None

    def get_suitable_audio(self, label_name, exclude_path, min_needed, top_k=5):
        """
        Looks for an audio file with the same emotion.
        1) If there are files >= min_needed, pick one of them at random.
        2) If there are none, take the top-K longest and pick one of them at random.
        """

        candidates = [p for p, lbl in self.audio_class_map.items()
                      if lbl == label_name and p != exclude_path]
        logging.debug(f"🔍 Found {len(candidates)} candidates for emotion '{label_name}'")

        # Collect (eq_len, path) for all candidates WITHOUT re-reading torchaudio.info
        all_info = []
        for path in candidates:
            # <<< NEW: instead of info = torchaudio.info(path) ...
            eq_len = self.path_info.get(path, 0)  # take it from the cache
            all_info.append((eq_len, path))

        valid = [(l, p) for l, p in all_info if l >= min_needed]
        logging.debug(f"✅ Suitable (>= {min_needed}): {len(valid)} (out of {len(all_info)})")

        if valid:
            # If there are exact matches, pick one of them at random
            random.shuffle(valid)
            chosen = random.choice(valid)[1]
            return chosen
        else:
            # 2) If there are no exact matches, take the top-K by length
            sorted_by_len = sorted(all_info, key=lambda x: x[0], reverse=True)
            top_k_list = sorted_by_len[:top_k]
            if not top_k_list:
                logging.debug("No candidates available at all.")
                return None  # no candidates at all

            random.shuffle(top_k_list)
            chosen = top_k_list[0][1]
            logging.info(f"Chosen candidate from the top-{top_k}: {chosen}")
            return chosen

    def run_whisper(self, waveform):
        """
        Runs Whisper on the audio signal and returns the full text (no word-count limit).
        """
        arr = waveform.squeeze().cpu().numpy()
        try:
            with torch.no_grad():
                result = self.whisper_model.transcribe(arr, fp16=False)
                text = result["text"].strip()
                return text
        except Exception as e:
            logging.error(f"Whisper error: {e}")
            return ""

    def _add_synthetic_data(self, synthetic_ratio):
        """
        Adds synthetic_ratio (0..1) of the available synthetic files for each emotion.
        """
        if not self.synthetic_path:
            logging.warning("⚠ Path to the synthetic data is not set.")
            return

        random.seed(self.seed)

        synth_csv_path = os.path.join(self.synthetic_path, "meld_s_train_labels.csv")
        synth_wav_dir = os.path.join(self.synthetic_path, "wavs")

        if not (os.path.exists(synth_csv_path) and os.path.exists(synth_wav_dir)):
            logging.warning("⚠ Synthetic data not found.")
            return

        df_synth = pd.read_csv(synth_csv_path)
        rows_by_label = {emotion: [] for emotion in self.emotion_columns}

        for _, row in df_synth.iterrows():
            audio_path = os.path.join(synth_wav_dir, f"{row['video_name']}.wav")
            if not os.path.exists(audio_path):
                continue
            emotion_values = row[self.emotion_columns].values.astype(float)
            max_idx = np.argmax(emotion_values)
            label = self.emotion_columns[max_idx]
            csv_text = row[self.text_column] if self.text_column in row and isinstance(row[self.text_column], str) else ""
            rows_by_label[label].append({
                "audio_path": audio_path,
                "label": label,
                "csv_text": csv_text
            })

        added = 0
        for label in self.emotion_columns:
            candidates = rows_by_label[label]
            if not candidates:
                continue
            count_synth = int(len(candidates) * synthetic_ratio)
            if count_synth <= 0:
                continue
            selected = random.sample(candidates, count_synth)
            self.rows.extend(selected)
            added += len(selected)
            logging.info(f"➕ Added {len(selected)} synthetic samples for emotion '{label}'")

        logging.info(f"📦 Added a total of {added} synthetic samples from MELD_S")
|
520 |
+
|
521 |
+
def emotion_to_vector(self, label_name):
|
522 |
+
"""
|
523 |
+
Преобразует название эмоции в one-hot вектор (torch.tensor).
|
524 |
+
"""
|
525 |
+
v = np.zeros(len(self.emotion_columns), dtype=np.float32)
|
526 |
+
if label_name in self.emotion_columns:
|
527 |
+
idx = self.emotion_columns.index(label_name)
|
528 |
+
v[idx] = 1.0
|
529 |
+
return torch.tensor(v, dtype=torch.float32)
|
530 |
+
|
531 |
+
class DatasetMultiModal(Dataset):
|
532 |
+
"""
|
533 |
+
Мультимодальный датасет для аудио, текста и эмоций (он‑the‑fly версия).
|
534 |
+
|
535 |
+
При каждом вызове __getitem__:
|
536 |
+
- Загружает WAV по video_name из CSV.
|
537 |
+
- Для обучающей выборки (split="train"):
|
538 |
+
Если аудио короче target_samples, проверяем, выбрали ли мы этот файл для склейки
|
539 |
+
(по merge_probability). Если да – выполняется "chain merge":
|
540 |
+
выбирается один или несколько дополнительных файлов того же класса, даже если один кандидат длиннее,
|
541 |
+
и итоговое аудио затем обрезается до точной длины.
|
542 |
+
- Если итоговое аудио всё ещё меньше target_samples, выполняется паддинг нулями.
|
543 |
+
- Текст выбирается так:
|
544 |
+
• Если аудио было merged (склеено) – вызывается Whisper для получения нового текста.
|
545 |
+
• Если merge не происходило и CSV-текст не пуст – используется CSV-текст.
|
546 |
+
• Если CSV-текст пустой – для train (или, при условии, для dev/test) вызывается Whisper.
|
547 |
+
- Возвращает словарь { "audio": waveform, "label": label_vector, "text": text_final }.
|
548 |
+
"""
|
549 |
+
|
550 |
+
def __init__(
|
551 |
+
self,
|
552 |
+
csv_path,
|
553 |
+
wav_dir,
|
554 |
+
emotion_columns,
|
555 |
+
split="train",
|
556 |
+
sample_rate=16000,
|
557 |
+
wav_length=4,
|
558 |
+
whisper_model="tiny",
|
559 |
+
text_column="text",
|
560 |
+
use_whisper_for_nontrain_if_no_text=True,
|
561 |
+
whisper_device="cuda",
|
562 |
+
subset_size=0,
|
563 |
+
merge_probability=1.0 # <-- Новый параметр: доля от ОБЩЕГО числа файлов
|
564 |
+
):
|
565 |
+
"""
|
566 |
+
:param csv_path: Путь к CSV-файлу (с колонками video_name, emotion_columns, возможно text).
|
567 |
+
:param wav_dir: Папка с аудиофайлами (имя файла: video_name.wav).
|
568 |
+
:param emotion_columns: Список колонок эмоций, например ["neutral", "happy", "sad", ...].
|
569 |
+
:param split: "train", "dev" или "test".
|
570 |
+
:param sample_rate: Целевая частота дискретизации (например, 16000).
|
571 |
+
:param wav_length: Целевая длина аудио в секундах.
|
572 |
+
:param whisper_model: Название модели Whisper ("tiny", "base", "small", ...).
|
573 |
+
:param max_text_tokens: (Не используется) – ограничение на число токенов.
|
574 |
+
:param text_column: Название колонки с текстом в CSV.
|
575 |
+
:param use_whisper_for_nontrain_if_no_text: Если True, для dev/test при отсутствии CSV-текста вызывается Whisper.
|
576 |
+
:param whisper_device: "cuda" или "cpu" – устройство для модели Whisper.
|
577 |
+
:param subset_size: Если > 0, используется только первые N записей из CSV (для отладки).
|
578 |
+
:param merge_probability: Процент (0..1) от всего числа файлов, которые будут склеиваться, если они короче.
|
579 |
+
"""
|
580 |
+
super().__init__()
|
581 |
+
self.split = split
|
582 |
+
self.sample_rate = sample_rate
|
583 |
+
self.target_samples = int(wav_length * sample_rate)
|
584 |
+
self.emotion_columns = emotion_columns
|
585 |
+
self.whisper_model_name = whisper_model
|
586 |
+
self.text_column = text_column
|
587 |
+
self.use_whisper_for_nontrain_if_no_text = use_whisper_for_nontrain_if_no_text
|
588 |
+
self.whisper_device = whisper_device
|
589 |
+
self.merge_probability = merge_probability
|
590 |
+
|
591 |
+
# Загружаем CSV
|
592 |
+
if not os.path.exists(csv_path):
|
593 |
+
raise ValueError(f"Ошибка: файл CSV не найден: {csv_path}")
|
594 |
+
df = pd.read_csv(csv_path)
|
595 |
+
if subset_size > 0:
|
596 |
+
df = df.head(subset_size)
|
597 |
+
logging.info(f"[DatasetMultiModal] Используем только первые {len(df)} записей (subset_size={subset_size}).")
|
598 |
+
|
599 |
+
# Проверяем наличие всех колонок эмоций
|
600 |
+
missing = [c for c in emotion_columns if c not in df.columns]
|
601 |
+
if missing:
|
602 |
+
raise ValueError(f"В CSV отсутствуют необходимые колонки эмоций: {missing}")
|
603 |
+
|
604 |
+
# Проверяем существование папки с аудио
|
605 |
+
if not os.path.exists(wav_dir):
|
606 |
+
raise ValueError(f"Ошибка: директория с аудио {wav_dir} не существует!")
|
607 |
+
self.wav_dir = wav_dir
|
608 |
+
|
609 |
+
# Собираем список строк: для каждой записи получаем путь к аудио, label и CSV-текст (если есть)
|
610 |
+
self.rows = []
|
611 |
+
for i, rowi in df.iterrows():
|
612 |
+
audio_path = os.path.join(wav_dir, f"{rowi['video_name']}.wav")
|
613 |
+
if not os.path.exists(audio_path):
|
614 |
+
continue
|
615 |
+
# Определяем доминирующую эмоцию (максимальное значение)
|
616 |
+
emotion_values = rowi[self.emotion_columns].values.astype(float)
|
617 |
+
max_idx = np.argmax(emotion_values)
|
618 |
+
emotion_label = self.emotion_columns[max_idx]
|
619 |
+
|
620 |
+
# Извлекаем текст из CSV (если есть)
|
621 |
+
csv_text = ""
|
622 |
+
if self.text_column in rowi and isinstance(rowi[self.text_column], str):
|
623 |
+
csv_text = rowi[self.text_column]
|
624 |
+
|
625 |
+
self.rows.append({
|
626 |
+
"audio_path": audio_path,
|
627 |
+
"label": emotion_label,
|
628 |
+
"csv_text": csv_text
|
629 |
+
})
|
630 |
+
|
631 |
+
# Создаем карту для поиска файлов по эмоции
|
632 |
+
self.audio_class_map = {entry["audio_path"]: entry["label"] for entry in self.rows}
|
633 |
+
|
634 |
+
logging.info("📊 Анализ распределения файлов по эмоциям:")
|
635 |
+
emotion_counts = {emotion: 0 for emotion in set(self.audio_class_map.values())}
|
636 |
+
for path, emotion in self.audio_class_map.items():
|
637 |
+
emotion_counts[emotion] += 1
|
638 |
+
for emotion, count in emotion_counts.items():
|
639 |
+
logging.info(f"🎭 Эмоция '{emotion}': {count} файлов.")
|
640 |
+
|
641 |
+
logging.info(f"[DatasetMultiModal] Сплит={split}, всего строк: {len(self.rows)}")
|
642 |
+
|
643 |
+
# === Процентное семплирование ===
|
644 |
+
total_files = len(self.rows)
|
645 |
+
num_to_merge = int(total_files * self.merge_probability)
|
646 |
+
|
647 |
+
# <<< NEW: Кешируем длины (eq_len) для всех файлов >>>
|
648 |
+
self.path_info = {}
|
649 |
+
for row in self.rows:
|
650 |
+
p = row["audio_path"]
|
651 |
+
try:
|
652 |
+
info = torchaudio.info(p)
|
653 |
+
length = info.num_frames
|
654 |
+
sr_ = info.sample_rate
|
655 |
+
# переводим длину в "эквивалент self.sample_rate"
|
656 |
+
if sr_ != self.sample_rate:
|
657 |
+
ratio = sr_ / self.sample_rate
|
658 |
+
eq_len = int(length / ratio)
|
659 |
+
else:
|
660 |
+
eq_len = length
|
661 |
+
self.path_info[p] = eq_len
|
662 |
+
except Exception as e:
|
663 |
+
logging.warning(f"⚠️ Ошибка чтения {p}: {e}")
|
664 |
+
self.path_info[p] = 0 # Если не смогли прочитать, ставим 0
|
665 |
+
|
666 |
+
# Определим, какие файлы "короткие" (могут нуждаться в склейке) - используем кэш вместо старого _is_too_short
|
667 |
+
self.mergable_files = [
|
668 |
+
row["audio_path"] # вместо целого dict берём строку
|
669 |
+
for row in self.rows
|
670 |
+
if self._is_too_short_cached(row["audio_path"]) # <<< теперь тут используем новую функцию
|
671 |
+
]
|
672 |
+
short_count = len(self.mergable_files)
|
673 |
+
|
674 |
+
# Если коротких файлов больше нужного числа, выберем случайные. Иначе все короткие.
|
675 |
+
if short_count > num_to_merge:
|
676 |
+
self.files_to_merge = set(random.sample(self.mergable_files, num_to_merge))
|
677 |
+
else:
|
678 |
+
self.files_to_merge = set(self.mergable_files)
|
679 |
+
|
680 |
+
logging.info(f"🔗 Всего файлов: {total_files}, нужно склеить: {num_to_merge} ({self.merge_probability*100:.0f}%)")
|
681 |
+
logging.info(f"🔗 Коротких файлов: {short_count}, выбрано для склейки: {len(self.files_to_merge)}")
|
682 |
+
|
683 |
+
# Инициализируем Whisper-модель один раз
|
684 |
+
logging.info(f"Инициализация Whisper: модель={whisper_model}, устройство={whisper_device}")
|
685 |
+
self.whisper_model = whisper.load_model(whisper_model, device=whisper_device).eval()
|
686 |
+
# print(f"📦 Whisper работает на устройстве: {self.whisper_model.device}")
|
687 |
+
|
688 |
+
def _is_too_short(self, audio_path):
|
689 |
+
"""
|
690 |
+
(Оригинальная) Проверяем, является ли файл короче target_samples.
|
691 |
+
Использует torchaudio.info(audio_path).
|
692 |
+
Но теперь этот метод не используется, поскольку мы кешируем длины.
|
693 |
+
"""
|
694 |
+
try:
|
695 |
+
info = torchaudio.info(audio_path)
|
696 |
+
length = info.num_frames
|
697 |
+
sr_ = info.sample_rate
|
698 |
+
# переводим длину в "эквивалент self.sample_rate"
|
699 |
+
if sr_ != self.sample_rate:
|
700 |
+
ratio = sr_ / self.sample_rate
|
701 |
+
eq_len = int(length / ratio)
|
702 |
+
else:
|
703 |
+
eq_len = length
|
704 |
+
return eq_len < self.target_samples
|
705 |
+
except Exception as e:
|
706 |
+
logging.warning(f"Ошибка _is_too_short({audio_path}): {e}")
|
707 |
+
return False
|
708 |
+
|
709 |
+
def _is_too_short_cached(self, audio_path):
|
710 |
+
"""
|
711 |
+
(Новая) Проверяем, является ли файл короче target_samples, используя закешированную длину в self.path_info.
|
712 |
+
"""
|
713 |
+
eq_len = self.path_info.get(audio_path, 0)
|
714 |
+
return eq_len < self.target_samples
|
715 |
+
|
716 |
+
def __len__(self):
|
717 |
+
return len(self.rows)
|
718 |
+
|
719 |
+
def __getitem__(self, index):
|
720 |
+
"""
|
721 |
+
Загружает и обрабатывает один элемент датасета (он‑the‑fly).
|
722 |
+
"""
|
723 |
+
row = self.rows[index]
|
724 |
+
audio_path = row["audio_path"]
|
725 |
+
label_name = row["label"]
|
726 |
+
csv_text = row["csv_text"]
|
727 |
+
|
728 |
+
# Преобразуем label в one-hot вектор
|
729 |
+
label_vec = self.emotion_to_vector(label_name)
|
730 |
+
|
731 |
+
# Шаг 1. Загружаем аудио
|
732 |
+
waveform, sr = self.load_audio(audio_path)
|
733 |
+
if waveform is None:
|
734 |
+
return None
|
735 |
+
|
736 |
+
orig_len = waveform.shape[1]
|
737 |
+
logging.debug(f"Исходная длина {os.path.basename(audio_path)}: {orig_len/sr:.2f} сек")
|
738 |
+
|
739 |
+
was_merged = False
|
740 |
+
merged_texts = [csv_text] # Тексты исходного файла + добавленных
|
741 |
+
|
742 |
+
# Шаг 2. Для train, если аудио короче target_samples, проверяем:
|
743 |
+
# попал ли данный row в files_to_merge?
|
744 |
+
if self.split == "train" and row["audio_path"] in self.files_to_merge:
|
745 |
+
# chain merge
|
746 |
+
current_length = orig_len
|
747 |
+
used_candidates = set()
|
748 |
+
|
749 |
+
while current_length < self.target_samples:
|
750 |
+
needed = self.target_samples - current_length
|
751 |
+
candidate = self.get_suitable_audio(label_name, exclude_path=audio_path, min_needed=needed, top_k=10)
|
752 |
+
if candidate is None or candidate in used_candidates:
|
753 |
+
break
|
754 |
+
used_candidates.add(candidate)
|
755 |
+
add_wf, add_sr = self.load_audio(candidate)
|
756 |
+
if add_wf is None:
|
757 |
+
break
|
758 |
+
logging.debug(f"Склейка: добавляем {os.path.basename(candidate)} (необходимых сэмплов: {needed})")
|
759 |
+
waveform = torch.cat((waveform, add_wf), dim=1)
|
760 |
+
current_length = waveform.shape[1]
|
761 |
+
was_merged = True
|
762 |
+
|
763 |
+
# Получаем текст второго файла (если есть в CSV)
|
764 |
+
add_csv_text = next((r["csv_text"] for r in self.rows if r["audio_path"] == candidate), "")
|
765 |
+
merged_texts.append(add_csv_text)
|
766 |
+
|
767 |
+
logging.debug(f"📜 Текст первого файла: {csv_text}")
|
768 |
+
logging.debug(f"📜 Текст добавленного файла: {add_csv_text}")
|
769 |
+
else:
|
770 |
+
# Если файл не в списке "должны склеить" или сплит не train, пропускаем chain-merge
|
771 |
+
logging.debug("Файл не выбран для склейки (или не train), пропускаем chain merge.")
|
772 |
+
|
773 |
+
# Шаг 3. Если итоговая длина меньше target_samples, паддинг нулями
|
774 |
+
curr_len = waveform.shape[1]
|
775 |
+
if curr_len < self.target_samples:
|
776 |
+
pad_size = self.target_samples - curr_len
|
777 |
+
logging.debug(f"Паддинг {os.path.basename(audio_path)}: +{pad_size} сэмплов")
|
778 |
+
waveform = torch.nn.functional.pad(waveform, (0, pad_size))
|
779 |
+
|
780 |
+
# Шаг 4. Обрезаем аудио до target_samples (если вышло больше)
|
781 |
+
waveform = waveform[:, :self.target_samples]
|
782 |
+
logging.debug(f"Финальная длина {os.path.basename(audio_path)}: {waveform.shape[1]/sr:.2f} сек; was_merged={was_merged}")
|
783 |
+
|
784 |
+
# Шаг 5. Получаем текст
|
785 |
+
if was_merged:
|
786 |
+
logging.debug("📝 Текст: аудио было merged – вызываем Whisper.")
|
787 |
+
text_final = self.run_whisper(waveform)
|
788 |
+
logging.debug(f"🆕 Whisper предсказал: {text_final}")
|
789 |
+
else:
|
790 |
+
if csv_text.strip():
|
791 |
+
logging.debug("Текст: используем CSV-текст (не пуст).")
|
792 |
+
text_final = csv_text
|
793 |
+
else:
|
794 |
+
if self.split == "train" or self.use_whisper_for_nontrain_if_no_text:
|
795 |
+
logging.debug("Текст: CSV пустой – вызываем Whisper.")
|
796 |
+
text_final = self.run_whisper(waveform)
|
797 |
+
else:
|
798 |
+
logging.debug("Текст: CSV пустой и не вызываем Whisper для dev/test.")
|
799 |
+
text_final = ""
|
800 |
+
|
801 |
+
return {
|
802 |
+
"audio_path": os.path.basename(audio_path), # new
|
803 |
+
"audio": waveform,
|
804 |
+
"label": label_vec,
|
805 |
+
"text": text_final
|
806 |
+
}
|
807 |
+
|
808 |
+
def load_audio(self, path):
|
809 |
+
"""
|
810 |
+
Загружает аудио по указанному пути и ресэмплирует его до self.sample_rate, если необходимо.
|
811 |
+
"""
|
812 |
+
if not os.path.exists(path):
|
813 |
+
logging.warning(f"Файл отсутствует: {path}")
|
814 |
+
return None, None
|
815 |
+
try:
|
816 |
+
wf, sr = torchaudio.load(path)
|
817 |
+
if sr != self.sample_rate:
|
818 |
+
resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
|
819 |
+
wf = resampler(wf)
|
820 |
+
sr = self.sample_rate
|
821 |
+
return wf, sr
|
822 |
+
except Exception as e:
|
823 |
+
logging.error(f"Ошибка загрузки {path}: {e}")
|
824 |
+
return None, None
|
825 |
+
|
826 |
+
def get_suitable_audio(self, label_name, exclude_path, min_needed, top_k=5):
|
827 |
+
"""
|
828 |
+
Ищет аудиофайл с той же эмоцией.
|
829 |
+
1) Если есть файлы >= min_needed, выбираем случайно из них.
|
830 |
+
2) Если таких нет, берём топ-K самых длинных, потом из них берём случайный.
|
831 |
+
"""
|
832 |
+
|
833 |
+
candidates = [p for p, lbl in self.audio_class_map.items()
|
834 |
+
if lbl == label_name and p != exclude_path]
|
835 |
+
logging.debug(f"🔍 Найдено {len(candidates)} кандидатов для эмоции '{label_name}'")
|
836 |
+
|
837 |
+
# Сохраним: (eq_len, path) для всех кандидатов, но БЕЗ повторного чтения torchaudio.info
|
838 |
+
all_info = []
|
839 |
+
for path in candidates:
|
840 |
+
# <<< NEW: вместо info = torchaudio.info(path) ...
|
841 |
+
eq_len = self.path_info.get(path, 0) # Получаем из кэша
|
842 |
+
all_info.append((eq_len, path))
|
843 |
+
|
844 |
+
# --- Ниже старый код, который был:
|
845 |
+
# for path in candidates:
|
846 |
+
# try:
|
847 |
+
# info = torchaudio.info(path)
|
848 |
+
# length = info.num_frames
|
849 |
+
# sr_ = info.sample_rate
|
850 |
+
# eq_len = int(length / (sr_ / self.sample_rate)) if sr_ != self.sample_rate else length
|
851 |
+
# all_info.append((eq_len, path))
|
852 |
+
# except Exception as e:
|
853 |
+
# logging.warning(f"⚠ Ошибка чтения {path}: {e}")
|
854 |
+
|
855 |
+
# 1) Фильтруем только >= min_needed
|
856 |
+
valid = [(l, p) for l, p in all_info if l >= min_needed]
|
857 |
+
logging.debug(f"✅ Подходящих (>= {min_needed}): {len(valid)} (из {len(all_info)})")
|
858 |
+
|
859 |
+
if valid:
|
860 |
+
# Если есть идеальные — берём случайно из них
|
861 |
+
random.shuffle(valid)
|
862 |
+
chosen = random.choice(valid)[1]
|
863 |
+
return chosen
|
864 |
+
else:
|
865 |
+
# 2) Если идеальных нет — берём топ-K по длине
|
866 |
+
sorted_by_len = sorted(all_info, key=lambda x: x[0], reverse=True)
|
867 |
+
top_k_list = sorted_by_len[:top_k]
|
868 |
+
if not top_k_list:
|
869 |
+
logging.debug("Нет доступных кандидатов вообще.")
|
870 |
+
return None # вообще нет кандидатов
|
871 |
+
|
872 |
+
random.shuffle(top_k_list)
|
873 |
+
chosen = top_k_list[0][1]
|
874 |
+
logging.info(f"Из топ-{top_k} выбран кандидат: {chosen}")
|
875 |
+
return chosen
|
876 |
+
|
877 |
+
def run_whisper(self, waveform):
|
878 |
+
"""
|
879 |
+
Вызывает Whisper на аудиосигнале и возвращает полный текст (без ограничения по количеству слов).
|
880 |
+
"""
|
881 |
+
arr = waveform.squeeze().cpu().numpy()
|
882 |
+
try:
|
883 |
+
result = self.whisper_model.transcribe(arr, fp16=False)
|
884 |
+
text = result["text"].strip()
|
885 |
+
return text
|
886 |
+
except Exception as e:
|
887 |
+
logging.error(f"Whisper ошибка: {e}")
|
888 |
+
return ""
|
889 |
+
|
890 |
+
def emotion_to_vector(self, label_name):
|
891 |
+
"""
|
892 |
+
Преобразует название эмоции в one-hot вектор (torch.tensor).
|
893 |
+
"""
|
894 |
+
v = np.zeros(len(self.emotion_columns), dtype=np.float32)
|
895 |
+
if label_name in self.emotion_columns:
|
896 |
+
idx = self.emotion_columns.index(label_name)
|
897 |
+
v[idx] = 1.0
|
898 |
+
return torch.tensor(v, dtype=torch.float32)
|
data_loading/feature_extractor.py
ADDED
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# data_loading/feature_extractor.py
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import logging
|
5 |
+
import numpy as np
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from transformers import (
|
8 |
+
AutoFeatureExtractor,
|
9 |
+
AutoModel,
|
10 |
+
AutoTokenizer,
|
11 |
+
AutoModelForAudioClassification,
|
12 |
+
Wav2Vec2Processor
|
13 |
+
)
|
14 |
+
from data_loading.pretrained_extractors import EmotionModel, get_model_mamba, Mamba
|
15 |
+
|
16 |
+
|
17 |
+
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
18 |
+
# DEVICE = torch.device('cpu')
|
19 |
+
|
20 |
+
|
21 |
+
class PretrainedAudioEmbeddingExtractor:
|
22 |
+
"""
|
23 |
+
Извлекает эмбеддинги из аудио, используя модель (например 'amiriparian/ExHuBERT'),
|
24 |
+
с учётом pooling, нормализации и т.д.
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(self, config):
|
28 |
+
"""
|
29 |
+
Ожидается, что в config есть поля:
|
30 |
+
- audio_model_name (str) : название модели (ExHuBERT и т.п.)
|
31 |
+
- emb_device (str) : "cpu" или "cuda"
|
32 |
+
- audio_pooling (str | None) : "mean", "cls", "max", "min", "last" или None (пропустить пуллинг)
|
33 |
+
- emb_normalize (bool) : делать ли L2-нормализацию выхода
|
34 |
+
- max_audio_frames (int) : ограничение длины по временной оси (если 0 - не ограничивать)
|
35 |
+
"""
|
36 |
+
self.config = config
|
37 |
+
self.device = config.emb_device
|
38 |
+
self.model_name = config.audio_model_name
|
39 |
+
self.pooling = config.audio_pooling # может быть None
|
40 |
+
self.normalize_output = config.emb_normalize
|
41 |
+
self.max_audio_frames = getattr(config, "max_audio_frames", 0)
|
42 |
+
self.audio_classifier_checkpoint = config.audio_classifier_checkpoint
|
43 |
+
|
44 |
+
# Инициализируем processor и audio_embedder
|
45 |
+
self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
|
46 |
+
self.audio_embedder = EmotionModel.from_pretrained(self.model_name).to(self.device)
|
47 |
+
|
48 |
+
# Загружаем модель
|
49 |
+
self.classifier_model = self.load_classifier_model_from_checkpoint(self.audio_classifier_checkpoint)
|
50 |
+
|
51 |
+
|
52 |
+
def extract(self, waveform: torch.Tensor, sample_rate=16000):
|
53 |
+
"""
|
54 |
+
Извлекает эмбеддинги из аудиоданных.
|
55 |
+
|
56 |
+
:param waveform: Тензор формы (T).
|
57 |
+
:param sample_rate: Частота дискретизации (int).
|
58 |
+
:return: Тензоры:
|
59 |
+
вернётся (B, classes), (B, sequence_length, hidden_dim).
|
60 |
+
"""
|
61 |
+
|
62 |
+
embeddings = self.process_audio(waveform, sample_rate)
|
63 |
+
tensor_emb = torch.tensor(embeddings, dtype=torch.float32).to(self.device)
|
64 |
+
lengths = [tensor_emb.shape[1]]
|
65 |
+
|
66 |
+
with torch.no_grad():
|
67 |
+
logits, hidden = self.classifier_model(tensor_emb, lengths, with_features=True)
|
68 |
+
|
69 |
+
# Если pooling=None => вернём (B, seq_len, hidden_dim)
|
70 |
+
if hidden.dim() == 3:
|
71 |
+
if self.pooling is None:
|
72 |
+
emb = hidden
|
73 |
+
else:
|
74 |
+
if self.pooling == "mean":
|
75 |
+
emb = hidden.mean(dim=1)
|
76 |
+
elif self.pooling == "cls":
|
77 |
+
emb = hidden[:, 0, :]
|
78 |
+
elif self.pooling == "max":
|
79 |
+
emb, _ = hidden.max(dim=1)
|
80 |
+
elif self.pooling == "min":
|
81 |
+
emb, _ = hidden.min(dim=1)
|
82 |
+
elif self.pooling == "last":
|
83 |
+
emb = hidden[:, -1, :]
|
84 |
+
elif self.pooling == "sum":
|
85 |
+
emb = hidden.sum(dim=1)
|
86 |
+
else:
|
87 |
+
emb = hidden.mean(dim=1)
|
88 |
+
else:
|
89 |
+
# На всякий случай, если получилось (B, hidden_dim)
|
90 |
+
emb = hidden
|
91 |
+
|
92 |
+
if self.normalize_output and emb.dim() == 2:
|
93 |
+
emb = F.normalize(emb, p=2, dim=1)
|
94 |
+
|
95 |
+
return logits, emb
|
96 |
+
|
97 |
+
def process_audio(self, signal: np.ndarray, sampling_rate: int) -> np.ndarray:
|
98 |
+
inputs = self.processor(signal, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
|
99 |
+
input_values = inputs["input_values"].to(self.device)
|
100 |
+
|
101 |
+
with torch.no_grad():
|
102 |
+
outputs = self.audio_embedder(input_values)
|
103 |
+
embeddings = outputs
|
104 |
+
|
105 |
+
return embeddings.detach().cpu().numpy()
|
106 |
+
|
107 |
+
def load_classifier_model_from_checkpoint(self, checkpoint_path):
|
108 |
+
if checkpoint_path == "best_audio_model.pt":
|
109 |
+
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
110 |
+
exp_params = checkpoint['exp_params']
|
111 |
+
classifier_model = get_model_mamba(exp_params).to(self.device)
|
112 |
+
classifier_model.load_state_dict(checkpoint['model_state_dict'])
|
113 |
+
elif checkpoint_path == "best_audio_model_2.pt":
|
114 |
+
model_params = {
|
115 |
+
"input_size": 1024,
|
116 |
+
"d_model": 256,
|
117 |
+
"num_layers": 2,
|
118 |
+
"num_classes": 7,
|
119 |
+
"dropout": 0.2
|
120 |
+
}
|
121 |
+
classifier_model = get_model_mamba(model_params).to(self.device)
|
122 |
+
classifier_model.load_state_dict(torch.load(checkpoint_path, map_location=self.device))
|
123 |
+
classifier_model.eval()
|
124 |
+
return classifier_model
|
125 |
+
|
126 |
+
class AudioEmbeddingExtractor:
|
127 |
+
"""
|
128 |
+
Извлекает эмбеддинги из аудио, используя модель (например 'amiriparian/ExHuBERT'),
|
129 |
+
с учётом pooling, нормализации и т.д.
|
130 |
+
"""
|
131 |
+
|
132 |
+
def __init__(self, config):
|
133 |
+
"""
|
134 |
+
Ожидается, что в config есть поля:
|
135 |
+
- audio_model_name (str) : название модели (ExHuBERT и т.п.)
|
136 |
+
- emb_device (str) : "cpu" или "cuda"
|
137 |
+
- audio_pooling (str | None) : "mean", "cls", "max", "min", "last" или None (пропустить пуллинг)
|
138 |
+
- emb_normalize (bool) : делать ли L2-нормализацию выхода
|
139 |
+
- max_audio_frames (int) : ограничение длины по временной оси (если 0 - не ограничивать)
|
140 |
+
"""
|
141 |
+
self.config = config
|
142 |
+
self.device = config.emb_device
|
143 |
+
self.model_name = config.audio_model_name
|
144 |
+
self.pooling = config.audio_pooling # может быть None
|
145 |
+
self.normalize_output = config.emb_normalize
|
146 |
+
# self.max_audio_frames = getattr(config, "max_audio_frames", 0)
|
147 |
+
self.max_audio_frames = config.sample_rate * config.wav_length
|
148 |
+
|
149 |
+
|
150 |
+
# Попробуем загрузить feature_extractor (не у всех моделей доступен)
|
151 |
+
try:
|
152 |
+
self.feature_extractor = AutoFeatureExtractor.from_pretrained(self.model_name)
|
153 |
+
logging.info(f"[Audio] Using AutoFeatureExtractor for '{self.model_name}'")
|
154 |
+
except Exception as e:
|
155 |
+
self.feature_extractor = None
|
156 |
+
logging.warning(f"[Audio] No built-in FeatureExtractor found. Model={self.model_name}. Error: {e}")
|
157 |
+
|
158 |
+
# Загружаем модель
|
159 |
+
# Если у модели нет head-классификации, бывает достаточно AutoModel
|
160 |
+
try:
|
161 |
+
self.model = AutoModel.from_pretrained(
|
162 |
+
self.model_name,
|
163 |
+
output_hidden_states=True # чтобы точно был last_hidden_state
|
164 |
+
).to(self.device)
|
165 |
+
logging.info(f"[Audio] Loaded AutoModel with output_hidden_states=True: {self.model_name}")
|
166 |
+
except Exception as e:
|
167 |
+
logging.warning(f"[Audio] Fallback to AudioClassification model. Reason: {e}")
|
168 |
+
self.model = AutoModelForAudioClassification.from_pretrained(
|
169 |
+
self.model_name,
|
170 |
+
output_hidden_states=True
|
171 |
+
).to(self.device)
|
172 |
+
|
173 |
+
def extract(self, waveform_batch: torch.Tensor, sample_rate=16000):
|
174 |
+
"""
|
175 |
+
Извлекает эмбеддинги из аудиоданных.
|
176 |
+
|
177 |
+
:param waveform_batch: Тензор формы (B, T) или (B, 1, T).
|
178 |
+
:param sample_rate: Частота дискретизации (int).
|
179 |
+
:return: Тензор:
|
180 |
+
- если pooling != None, будет (B, hidden_dim)
|
181 |
+
- если pooling == None и last_hidden_state имел форму (B, seq_len, hidden_dim),
|
182 |
+
вернётся (B, seq_len, hidden_dim).
|
183 |
+
"""
|
184 |
+
|
185 |
+
# Если пришло (B, 1, T), уберём ось "1"
|
186 |
+
if waveform_batch.dim() == 3 and waveform_batch.shape[1] == 1:
|
187 |
+
waveform_batch = waveform_batch.squeeze(1) # -> (B, T)
|
188 |
+
|
189 |
+
# Усечение по времени, если нужно
|
190 |
+
if self.max_audio_frames > 0 and waveform_batch.shape[1] > self.max_audio_frames:
|
191 |
+
waveform_batch = waveform_batch[:, :self.max_audio_frames]
|
192 |
+
|
193 |
+
# Если есть feature_extractor - используем
|
194 |
+
if self.feature_extractor is not None:
|
195 |
+
inputs = self.feature_extractor(
|
196 |
+
waveform_batch,
|
197 |
+
sampling_rate=sample_rate,
|
198 |
+
return_tensors="pt",
|
199 |
+
truncation=True,
|
200 |
+
max_length=self.max_audio_frames if self.max_audio_frames > 0 else None
|
201 |
+
)
|
202 |
+
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
203 |
+
|
204 |
+
outputs = self.model(input_values=inputs["input_values"])
|
205 |
+
else:
|
206 |
+
# Иначе подадим напрямую "input_values" на модель
|
207 |
+
inputs = {"input_values": waveform_batch.to(self.device)}
|
208 |
+
outputs = self.model(**inputs)
|
209 |
+
|
210 |
+
# Теперь outputs может быть BaseModelOutput (с last_hidden_state, hidden_states, etc.)
|
211 |
+
# Или SequenceClassifierOutput (с logits), если это модель-классификатор
|
212 |
+
if hasattr(outputs, "last_hidden_state"):
|
213 |
+
# (B, seq_len, hidden_dim)
|
214 |
+
hidden = outputs.last_hidden_state
|
215 |
+
# logging.debug(f"[Audio] last_hidden_state shape: {hidden.shape}")
|
216 |
+
elif hasattr(outputs, "logits"):
|
217 |
+
# logits: (B, num_labels)
|
218 |
+
# Для пуллинга по "seq_len" притворимся, что seq_len=1
|
219 |
+
hidden = outputs.logits.unsqueeze(1) # (B,1,num_labels)
|
220 |
+
logging.debug(f"[Audio] Found logits shape: {outputs.logits.shape} => hidden={hidden.shape}")
|
221 |
+
else:
|
222 |
+
# Модель может сразу возвращать тензор
|
223 |
+
hidden = outputs
|
224 |
+
|
225 |
+
# Если у нас 2D-тензор (B, hidden_dim), значит всё уже спулено
|
226 |
+
if hidden.dim() == 2:
|
227 |
+
emb = hidden
|
228 |
+
elif hidden.dim() == 3:
|
229 |
+
# (B, seq_len, hidden_dim)
|
230 |
+
if self.pooling is None:
|
231 |
+
# Возвращаем как есть
|
232 |
+
emb = hidden
|
233 |
+
else:
|
234 |
+
# Выполним пуллинг
|
235 |
+
if self.pooling == "mean":
|
236 |
+
emb = hidden.mean(dim=1)
|
237 |
+
elif self.pooling == "cls":
|
238 |
+
emb = hidden[:, 0, :] # [B, hidden_dim]
|
239 |
+
elif self.pooling == "max":
|
240 |
+
emb, _ = hidden.max(dim=1)
|
241 |
+
elif self.pooling == "min":
|
242 |
+
emb, _ = hidden.min(dim=1)
|
243 |
+
elif self.pooling == "last":
|
244 |
+
emb = hidden[:, -1, :]
|
245 |
+
else:
|
246 |
+
emb = hidden.mean(dim=1) # на всякий случай fallback
|
247 |
+
else:
|
248 |
+
# На всякий: если ещё какая-то форма
|
249 |
+
raise ValueError(f"[Audio] Unexpected hidden shape={hidden.shape}, pooling={self.pooling}")
|
250 |
+
|
251 |
+
if self.normalize_output and emb.dim() == 2:
|
252 |
+
emb = F.normalize(emb, p=2, dim=1)
|
253 |
+
|
254 |
+
return emb
|
255 |
+
|
256 |
+
|
257 |
+
class TextEmbeddingExtractor:
|
258 |
+
"""
|
259 |
+
Извлекает эмбеддинги из текста (например 'jinaai/jina-embeddings-v3'),
|
260 |
+
с учётом pooling (None, mean, cls, и т.д.), нормализации и усечения.
|
261 |
+
"""
|
262 |
+
|
263 |
+
def __init__(self, config):
|
264 |
+
"""
|
265 |
+
Параметры в config:
|
266 |
+
- text_model_name (str)
|
267 |
+
- emb_device (str)
|
268 |
+
- text_pooling (str | None)
|
269 |
+
- emb_normalize (bool)
|
270 |
+
- max_tokens (int)
|
271 |
+
"""
|
272 |
+
self.config = config
|
273 |
+
self.device = config.emb_device
|
274 |
+
self.model_name = config.text_model_name
|
275 |
+
self.pooling = config.text_pooling # может быть None
|
276 |
+
self.normalize_output = config.emb_normalize
|
277 |
+
self.max_tokens = config.max_tokens
|
278 |
+
|
279 |
+
# trust_remote_code=True нужно для моделей вроде jina
|
280 |
+
logging.info(f"[Text] Loading tokenizer for {self.model_name} with trust_remote_code=True")
|
281 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
282 |
+
self.model_name,
|
283 |
+
trust_remote_code=True
|
284 |
+
)
|
285 |
+
|
286 |
+
logging.info(f"[Text] Loading model for {self.model_name} with trust_remote_code=True")
|
287 |
+
self.model = AutoModel.from_pretrained(
|
288 |
+
self.model_name,
|
289 |
+
trust_remote_code=True,
|
290 |
+
output_hidden_states=True, # хотим иметь last_hidden_state
|
291 |
+
force_download=False
|
292 |
+
).to(self.device)
|
293 |
+
|
294 |
+
def extract(self, text_list):
|
295 |
+
"""
|
296 |
+
:param text_list: список строк (или одна строка)
|
297 |
+
:return: тензор (B, hidden_dim) или (B, seq_len, hidden_dim), если pooling=None
|
298 |
+
"""
|
299 |
+
|
300 |
+
if isinstance(text_list, str):
|
301 |
+
text_list = [text_list]
|
302 |
+
|
303 |
+
inputs = self.tokenizer(
|
304 |
+
text_list,
|
305 |
+
padding="max_length",
|
306 |
+
truncation=True,
|
307 |
+
max_length=self.max_tokens,
|
308 |
+
return_tensors="pt"
|
309 |
+
)
|
310 |
+
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
311 |
+
|
312 |
+
with torch.no_grad():
|
313 |
+
outputs = self.model(**inputs)
|
314 |
+
# Обычно у AutoModel last_hidden_state.shape = (B, seq_len, hidden_dim)
|
315 |
+
hidden = outputs.last_hidden_state
|
316 |
+
# logging.debug(f"[Text] last_hidden_state shape: {hidden.shape}")
|
317 |
+
|
318 |
+
# Если pooling=None => вернём (B, seq_len, hidden_dim)
|
319 |
+
if hidden.dim() == 3:
|
320 |
+
if self.pooling is None:
|
321 |
+
emb = hidden
|
322 |
+
else:
|
323 |
+
if self.pooling == "mean":
|
324 |
+
emb = hidden.mean(dim=1)
|
325 |
+
elif self.pooling == "cls":
|
326 |
+
emb = hidden[:, 0, :]
|
327 |
+
elif self.pooling == "max":
|
328 |
+
emb, _ = hidden.max(dim=1)
|
329 |
+
elif self.pooling == "min":
|
330 |
+
emb, _ = hidden.min(dim=1)
|
331 |
+
elif self.pooling == "last":
|
332 |
+
emb = hidden[:, -1, :]
|
333 |
+
elif self.pooling == "sum":
|
334 |
+
emb = hidden.sum(dim=1)
|
335 |
+
else:
|
336 |
+
emb = hidden.mean(dim=1)
|
337 |
+
else:
|
338 |
+
# На всякий случай, если получилось (B, hidden_dim)
|
339 |
+
emb = hidden
|
340 |
+
|
341 |
+
if self.normalize_output and emb.dim() == 2:
|
342 |
+
emb = F.normalize(emb, p=2, dim=1)
|
343 |
+
|
344 |
+
return emb
|
345 |
+
|
346 |
+
class PretrainedTextEmbeddingExtractor:
|
347 |
+
"""
|
348 |
+
Извлекает эмбеддинги из текста (например 'jinaai/jina-embeddings-v3'),
|
349 |
+
с учётом pooling (None, mean, cls, и т.д.), нормализации и усечения.
|
350 |
+
"""
|
351 |
+
|
352 |
+
def __init__(self, config):
|
353 |
+
"""
|
354 |
+
Параметры в config:
|
355 |
+
- text_model_name (str)
|
356 |
+
- emb_device (str)
|
357 |
+
- text_pooling (str | None)
|
358 |
+
- emb_normalize (bool)
|
359 |
+
- max_tokens (int)
|
360 |
+
"""
|
361 |
+
self.config = config
|
362 |
+
self.device = config.emb_device
|
363 |
+
self.model_name = config.text_model_name
|
364 |
+
self.pooling = config.text_pooling # может быть None
|
365 |
+
self.normalize_output = config.emb_normalize
|
366 |
+
self.max_tokens = config.max_tokens
|
367 |
+
self.text_classifier_checkpoint = config.text_classifier_checkpoint
|
368 |
+
|
369 |
+
self.model = Mamba(num_layers = 2, d_input = 1024, d_model = 512, num_classes=7, model_name=self.model_name, max_tokens=self.max_tokens, pooling=None).to(self.device)
|
370 |
+
checkpoint = torch.load(self.text_classifier_checkpoint, map_location=DEVICE)
|
371 |
+
self.model.load_state_dict(checkpoint['model_state_dict'])
|
372 |
+
self.model.eval()
|
373 |
+
|
374 |
+
def extract(self, text_list):
|
375 |
+
"""
|
376 |
+
:param text_list: список строк (или одна строка)
|
377 |
+
:return: тензор (B, hidden_dim) или (B, seq_len, hidden_dim), если pooling=None
|
378 |
+
"""
|
379 |
+
|
380 |
+
if isinstance(text_list, str):
|
381 |
+
text_list = [text_list]
|
382 |
+
|
383 |
+
with torch.no_grad():
|
384 |
+
logits, hidden = self.model(text_list, with_features=True)
|
385 |
+
|
386 |
+
if hidden.dim() == 3:
|
387 |
+
if self.pooling is None:
|
388 |
+
emb = hidden
|
389 |
+
else:
|
390 |
+
if self.pooling == "mean":
|
391 |
+
emb = hidden.mean(dim=1)
|
392 |
+
elif self.pooling == "cls":
|
393 |
+
emb = hidden[:, 0, :]
|
394 |
+
elif self.pooling == "max":
|
395 |
+
emb, _ = hidden.max(dim=1)
|
396 |
+
elif self.pooling == "min":
|
397 |
+
emb, _ = hidden.min(dim=1)
|
398 |
+
elif self.pooling == "last":
|
399 |
+
emb = hidden[:, -1, :]
|
400 |
+
elif self.pooling == "sum":
|
401 |
+
emb = hidden.sum(dim=1)
|
402 |
+
else:
|
403 |
+
emb = hidden.mean(dim=1)
|
404 |
+
else:
|
405 |
+
# На всякий случай, если получилось (B, hidden_dim)
|
406 |
+
emb = hidden
|
407 |
+
|
408 |
+
if self.normalize_output and emb.dim() == 2:
|
409 |
+
emb = F.normalize(emb, p=2, dim=1)
|
410 |
+
return logits, emb
|
data_loading/pretrained_extractors.py
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
from transformers import AutoTokenizer, AutoModel
|
7 |
+
from transformers.models.wav2vec2.modeling_wav2vec2 import (
|
8 |
+
Wav2Vec2Model,
|
9 |
+
Wav2Vec2PreTrainedModel,
|
10 |
+
)
|
11 |
+
from torch.nn.functional import silu
|
12 |
+
from torch.nn.functional import softplus
|
13 |
+
from einops import rearrange, einsum
|
14 |
+
from torch import Tensor
|
15 |
+
from einops import rearrange
|
16 |
+
|
17 |
+
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
18 |
+
# DEVICE = torch.device('cpu')
|
19 |
+
|
20 |
+
## Audio models
|
21 |
+
|
22 |
+
class CustomMambaBlock(nn.Module):
|
23 |
+
def __init__(self, d_input, d_model, dropout=0.1):
|
24 |
+
super().__init__()
|
25 |
+
self.in_proj = nn.Linear(d_input, d_model)
|
26 |
+
self.s_B = nn.Linear(d_model, d_model)
|
27 |
+
self.s_C = nn.Linear(d_model, d_model)
|
28 |
+
self.out_proj = nn.Linear(d_model, d_input)
|
29 |
+
self.norm = nn.LayerNorm(d_input)
|
30 |
+
self.dropout = nn.Dropout(dropout)
|
31 |
+
self.activation = nn.ReLU()
|
32 |
+
|
33 |
+
def forward(self, x):
|
34 |
+
x_in = x # сохраняем вход
|
35 |
+
x = self.in_proj(x)
|
36 |
+
B = self.s_B(x)
|
37 |
+
C = self.s_C(x)
|
38 |
+
x = x + B + C
|
39 |
+
x = self.activation(x)
|
40 |
+
x = self.out_proj(x)
|
41 |
+
x = self.dropout(x)
|
42 |
+
x = self.norm(x + x_in) # residual + norm
|
43 |
+
return x
|
44 |
+
|
45 |
+
class CustomMambaClassifier(nn.Module):
|
46 |
+
def __init__(self, input_size=1024, d_model=256, num_layers=2, num_classes=7, dropout=0.1):
|
47 |
+
super().__init__()
|
48 |
+
self.input_proj = nn.Linear(input_size, d_model)
|
49 |
+
self.blocks = nn.ModuleList([
|
50 |
+
CustomMambaBlock(d_model, d_model, dropout=dropout)
|
51 |
+
for _ in range(num_layers)
|
52 |
+
])
|
53 |
+
self.fc = nn.Linear(d_model, num_classes)
|
54 |
+
|
55 |
+
def forward(self, x, lengths, with_features=False):
|
56 |
+
# x: (batch, seq_length, input_size)
|
57 |
+
x = self.input_proj(x)
|
58 |
+
for block in self.blocks:
|
59 |
+
x = block(x)
|
60 |
+
pooled = []
|
61 |
+
for i, l in enumerate(lengths):
|
62 |
+
if l > 0:
|
63 |
+
pooled.append(x[i, :l, :].mean(dim=0))
|
64 |
+
else:
|
65 |
+
pooled.append(torch.zeros(x.size(2), device=x.device))
|
66 |
+
pooled = torch.stack(pooled, dim=0)
|
67 |
+
if with_features:
|
68 |
+
return self.fc(pooled), x
|
69 |
+
else:
|
70 |
+
return self.fc(pooled)
|
71 |
+
|
72 |
+
def get_model_mamba(params):
|
73 |
+
return CustomMambaClassifier(
|
74 |
+
input_size=params.get("input_size", 1024),
|
75 |
+
d_model=params.get("d_model", 256),
|
76 |
+
num_layers=params.get("num_layers", 2),
|
77 |
+
num_classes=params.get("num_classes", 7),
|
78 |
+
dropout=params.get("dropout", 0.1)
|
79 |
+
)
|
80 |
+
|
81 |
+
class EmotionModel(Wav2Vec2PreTrainedModel):
|
82 |
+
|
83 |
+
def __init__(self, config):
|
84 |
+
super().__init__(config)
|
85 |
+
self.config = config
|
86 |
+
self.wav2vec2 = Wav2Vec2Model(config)
|
87 |
+
self.init_weights()
|
88 |
+
|
89 |
+
def forward(self, input_values):
|
90 |
+
outputs = self.wav2vec2(input_values)
|
91 |
+
hidden_states = outputs[0] # (batch_size, sequence_length, hidden_size)
|
92 |
+
return hidden_states
|
93 |
+
|
94 |
+
## Text models
|
95 |
+
|
96 |
+
class Embedding():
|
97 |
+
def __init__(self, model_name='jinaai/jina-embeddings-v3', pooling=None):
|
98 |
+
self.model_name = model_name
|
99 |
+
self.pooling = pooling
|
100 |
+
self.device = DEVICE
|
101 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name, code_revision='da863dd04a4e5dce6814c6625adfba87b83838aa', trust_remote_code=True)
|
102 |
+
self.model = AutoModel.from_pretrained(model_name, code_revision='da863dd04a4e5dce6814c6625adfba87b83838aa', trust_remote_code=True).to(self.device)
|
103 |
+
self.model.eval()
|
104 |
+
|
105 |
+
def _mean_pooling(self, X):
|
106 |
+
def mean_pooling(model_output, attention_mask):
|
107 |
+
token_embeddings = model_output[0]
|
108 |
+
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
109 |
+
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
110 |
+
encoded_input = self.tokenizer(X, padding=True, truncation=True, return_tensors='pt').to(self.device)
|
111 |
+
with torch.no_grad():
|
112 |
+
model_output = self.model(**encoded_input)
|
113 |
+
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
|
114 |
+
sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
|
115 |
+
return sentence_embeddings.unsqueeze(1)
|
116 |
+
|
117 |
+
def get_embeddings(self, X, max_len):
|
118 |
+
encoded_input = self.tokenizer(X, padding=True, truncation=True, return_tensors='pt').to(self.device)
|
119 |
+
with torch.no_grad():
|
120 |
+
features = self.model(**encoded_input)[0].detach().cpu().float().numpy()
|
121 |
+
res = np.pad(features[:, :max_len, :], ((0, 0), (0, max(0, max_len - features.shape[1])), (0, 0)), "constant")
|
122 |
+
return torch.tensor(res)
|
123 |
+
|
124 |
+
class RMSNorm(nn.Module):
|
125 |
+
def __init__(self, d_model: int, eps: float = 1e-8) -> None:
|
126 |
+
super().__init__()
|
127 |
+
self.eps = eps
|
128 |
+
self.weight = nn.Parameter(torch.ones(d_model))
|
129 |
+
|
130 |
+
def forward(self, x: Tensor) -> Tensor:
|
131 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim = True) + self.eps) * self.weight
|
132 |
+
|
133 |
+
class Mamba(nn.Module):
|
134 |
+
def __init__(self, num_layers, d_input, d_model, d_state=16, d_discr=None, ker_size=4, num_classes=7, max_tokens=95, model_name='jina', pooling=None):
|
135 |
+
super().__init__()
|
136 |
+
mamba_par = {
|
137 |
+
'd_input' : d_input,
|
138 |
+
'd_model' : d_model,
|
139 |
+
'd_state' : d_state,
|
140 |
+
'd_discr' : d_discr,
|
141 |
+
'ker_size': ker_size
|
142 |
+
}
|
143 |
+
self.model_name = model_name
|
144 |
+
self.max_tokens = max_tokens
|
145 |
+
embed = Embedding(model_name, pooling)
|
146 |
+
self.embedding = embed.get_embeddings
|
147 |
+
self.layers = nn.ModuleList([nn.ModuleList([MambaBlock(**mamba_par), RMSNorm(d_input)]) for _ in range(num_layers)])
|
148 |
+
self.fc_out = nn.Linear(d_input, num_classes)
|
149 |
+
self.device = DEVICE
|
150 |
+
|
151 |
+
def forward(self, seq, cache=None, with_features=True):
|
152 |
+
seq = self.embedding(seq, self.max_tokens).to(self.device)
|
153 |
+
for mamba, norm in self.layers:
|
154 |
+
out, cache = mamba(norm(seq), cache)
|
155 |
+
seq = out + seq
|
156 |
+
if with_features:
|
157 |
+
return self.fc_out(seq.mean(dim = 1)), seq
|
158 |
+
else:
|
159 |
+
return self.fc_out(seq.mean(dim = 1))
|
160 |
+
|
161 |
+
class MambaBlock(nn.Module):
|
162 |
+
def __init__(self, d_input, d_model, d_state=16, d_discr=None, ker_size=4):
|
163 |
+
super().__init__()
|
164 |
+
d_discr = d_discr if d_discr is not None else d_model // 16
|
165 |
+
self.in_proj = nn.Linear(d_input, 2 * d_model, bias=False)
|
166 |
+
self.out_proj = nn.Linear(d_model, d_input, bias=False)
|
167 |
+
self.s_B = nn.Linear(d_model, d_state, bias=False)
|
168 |
+
self.s_C = nn.Linear(d_model, d_state, bias=False)
|
169 |
+
self.s_D = nn.Sequential(nn.Linear(d_model, d_discr, bias=False), nn.Linear(d_discr, d_model, bias=False),)
|
170 |
+
self.conv = nn.Conv1d(
|
171 |
+
in_channels=d_model,
|
172 |
+
out_channels=d_model,
|
173 |
+
kernel_size=ker_size,
|
174 |
+
padding=ker_size - 1,
|
175 |
+
groups=d_model,
|
176 |
+
bias=True,
|
177 |
+
)
|
178 |
+
self.A = nn.Parameter(torch.arange(1, d_state + 1, dtype=torch.float).repeat(d_model, 1))
|
179 |
+
self.D = nn.Parameter(torch.ones(d_model, dtype=torch.float))
|
180 |
+
self.device = DEVICE
|
181 |
+
|
182 |
+
def forward(self, seq, cache=None):
|
183 |
+
b, l, d = seq.shape
|
184 |
+
(prev_hid, prev_inp) = cache if cache is not None else (None, None)
|
185 |
+
a, b = self.in_proj(seq).chunk(2, dim=-1)
|
186 |
+
x = rearrange(a, 'b l d -> b d l')
|
187 |
+
x = x if prev_inp is None else torch.cat((prev_inp, x), dim=-1)
|
188 |
+
a = self.conv(x)[..., :l]
|
189 |
+
a = rearrange(a, 'b d l -> b l d')
|
190 |
+
a = silu(a)
|
191 |
+
a, hid = self.ssm(a, prev_hid=prev_hid)
|
192 |
+
b = silu(b)
|
193 |
+
out = a * b
|
194 |
+
out = self.out_proj(out)
|
195 |
+
if cache:
|
196 |
+
cache = (hid.squeeze(), x[..., 1:])
|
197 |
+
return out, cache
|
198 |
+
|
199 |
+
def ssm(self, seq, prev_hid):
|
200 |
+
A = -self.A
|
201 |
+
D = +self.D
|
202 |
+
B = self.s_B(seq)
|
203 |
+
C = self.s_C(seq)
|
204 |
+
s = softplus(D + self.s_D(seq))
|
205 |
+
A_bar = einsum(torch.exp(A), s, 'd s, b l d -> b l d s')
|
206 |
+
B_bar = einsum( B, s, 'b l s, b l d -> b l d s')
|
207 |
+
X_bar = einsum(B_bar, seq, 'b l d s, b l d -> b l d s')
|
208 |
+
hid = self._hid_states(A_bar, X_bar, prev_hid=prev_hid)
|
209 |
+
out = einsum(hid, C, 'b l d s, b l s -> b l d')
|
210 |
+
out = out + D * seq
|
211 |
+
return out, hid
|
212 |
+
|
213 |
+
def _hid_states(self, A, X, prev_hid=None):
|
214 |
+
b, l, d, s = A.shape
|
215 |
+
A = rearrange(A, 'b l d s -> l b d s')
|
216 |
+
X = rearrange(X, 'b l d s -> l b d s')
|
217 |
+
if prev_hid is not None:
|
218 |
+
return rearrange(A * prev_hid + X, 'l b d s -> b l d s')
|
219 |
+
h = torch.zeros(b, d, s, device=self.device)
|
220 |
+
return torch.stack([h := A_t * h + X_t for A_t, X_t in zip(A, X)], dim=1)
|
221 |
+
|
emotion_templates/anger.json
ADDED
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"subjects": [
|
3 |
+
"That lie", "This situation", "His attitude", "Her silence", "The way they ignored me",
|
4 |
+
"That decision", "This whole mess", "His response", "The way they talked to me",
|
5 |
+
"Their tone", "The look in their eyes", "The unfairness", "The betrayal",
|
6 |
+
"This conversation", "The rules they made", "That smug smile", "The injustice",
|
7 |
+
"His arrogance", "Her hypocrisy", "That fake apology", "Their excuse",
|
8 |
+
"What happened", "That accusation", "Their assumption", "His coldness",
|
9 |
+
"The way they dismissed me", "The way she walked away", "Their comment",
|
10 |
+
"The lack of respect", "The silence", "The manipulation", "The fake concern",
|
11 |
+
"This nonsense", "Their reaction", "The way they handled it", "The favoritism",
|
12 |
+
"That email", "The way they gaslit me", "This double standard", "The broken promise",
|
13 |
+
"Her message", "His ignorance", "The constant interruptions", "The yelling",
|
14 |
+
"The lies they spread", "Their decision", "This treatment", "His sarcasm",
|
15 |
+
"The rolled eyes", "This chaos", "The way they brushed it off", "The crossed line",
|
16 |
+
"Their behavior", "The gaslighting", "That moment", "Her denial", "The lack of accountability",
|
17 |
+
"His manipulation", "Their fake smile", "The passive aggression", "The backhanded comment",
|
18 |
+
"Her voice", "His condescension", "The disrespect", "The way I was ignored",
|
19 |
+
"The raised voice", "The rules for them", "The expectations they set", "That stupid joke",
|
20 |
+
"The fake empathy", "The unfair treatment", "That meeting", "That statement",
|
21 |
+
"The constant blame", "The way they twisted my words", "The tone they used",
|
22 |
+
"That decision they made", "The control", "Their agenda", "The things they said",
|
23 |
+
"That smug look", "The whispering", "Their manipulation", "The criticism",
|
24 |
+
"Her smugness", "His blank stare", "The gaslight attempt", "The guilt trip",
|
25 |
+
"The two-faced attitude", "The pressure", "The invalidation", "The dismissal",
|
26 |
+
"This tension", "The accusations", "The disrespectful comment", "The dishonesty",
|
27 |
+
"The fake kindness", "This nonsense policy", "Their entitlement", "The red flags",
|
28 |
+
"The toxic vibe", "The constant judgment", "The unnecessary drama", "His defensiveness"
|
29 |
+
],
|
30 |
+
"verbs": [
|
31 |
+
"set me off", "boiled my blood", "pushed me over the edge", "got under my skin",
|
32 |
+
"made me snap", "triggered me instantly", "infuriated me", "was unbearable",
|
33 |
+
"sparked pure rage", "ticked me off", "made my fists clench", "was too much",
|
34 |
+
"ignited everything", "set a fire in me", "boiled over", "made my blood rush",
|
35 |
+
"shook me with anger", "cut too deep", "threw me into rage", "was beyond frustrating",
|
36 |
+
"hit a nerve", "was infuriating", "drove me mad", "burned inside me", "got to me",
|
37 |
+
"turned me red", "filled me with heat", "was outrageous", "was completely unacceptable",
|
38 |
+
"ripped through me", "kept poking at me", "got louder and louder", "spun me out",
|
39 |
+
"made me want to scream", "felt like an attack", "built up fast", "couldn’t be ignored",
|
40 |
+
"turned everything red", "sent me into a spiral", "clashed with everything in me",
|
41 |
+
"pressed all my buttons", "erased my patience", "had me yelling inside", "spilled over fast",
|
42 |
+
"was a slap in the face", "put me in fight mode", "snapped something in me", "took it too far",
|
43 |
+
"crossed a line", "felt like fuel", "burned through me", "got personal", "hit my limits",
|
44 |
+
"killed my calm", "shot adrenaline through me", "woke the worst in me", "soured the whole day",
|
45 |
+
"made me want to throw something", "burned me out", "was like fire in my gut", "dismantled my tolerance",
|
46 |
+
"overwhelmed me", "shattered my cool", "tore through my brain", "clanged in my chest",
|
47 |
+
"stabbed at my peace", "chased my breath", "spit in my face", "made me want to walk out",
|
48 |
+
"boomed inside", "grew louder", "flared without warning", "grated every nerve",
|
49 |
+
"felt like betrayal", "hammered on my mind", "flipped my mood instantly", "exploded internally",
|
50 |
+
"twisted my stomach", "simmered too long", "kept repeating in my head", "built up in silence",
|
51 |
+
"felt relentless", "crept up slowly", "left me livid", "killed all reason", "tore open my restraint",
|
52 |
+
"blocked out everything else", "just wouldn't stop", "sparked a storm", "dragged me into rage",
|
53 |
+
"stabbed at my patience", "flooded me with fury", "split my brain", "left me shaking",
|
54 |
+
"was like fire to paper", "ripped up my peace", "wrecked my day", "tipped me over",
|
55 |
+
"woke something primal", "restarted the fire", "slammed into my chest", "blew my fuse"
|
56 |
+
],
|
57 |
+
"interjections": [
|
58 |
+
"You’ve got to be kidding me (groans)!",
|
59 |
+
"(inhales) This is unreal!",
|
60 |
+
"Seriously (clears throat)?",
|
61 |
+
"I can’t even begin (inhales)...",
|
62 |
+
"Enough is enough (groans)!",
|
63 |
+
"What the hell (exhales)!",
|
64 |
+
"Are you serious (groans)?",
|
65 |
+
"This has to stop (inhales)!",
|
66 |
+
"I’ve had it (groans)!",
|
67 |
+
"I’m not holding back (clears throat).",
|
68 |
+
"Now I’m done (inhales).",
|
69 |
+
"Unbelievable (groans).",
|
70 |
+
"Not again (exhales)...",
|
71 |
+
"That’s it (groans)!",
|
72 |
+
"I swear (clears throat).",
|
73 |
+
"How dare they (inhales)!",
|
74 |
+
"Absolutely not (groans)!",
|
75 |
+
"Get out of here (screams)!",
|
76 |
+
"I’m done pretending (inhales).",
|
77 |
+
"No way (groans)."
|
78 |
+
],
|
79 |
+
|
80 |
+
"contexts": [
|
81 |
+
"and I couldn’t stay quiet (groans)",
|
82 |
+
"and I snapped (exhales) without warning",
|
83 |
+
"and I couldn’t control my voice (inhales)",
|
84 |
+
"and I felt the heat rise (exhales)",
|
85 |
+
"and I wanted to punch a wall (inhales)",
|
86 |
+
"and I yelled (screams) before I could stop",
|
87 |
+
"and it shattered my calm (groans)",
|
88 |
+
"and I lost it right there (groans)",
|
89 |
+
"and I (inhales) slammed the door",
|
90 |
+
"and I could feel myself shaking (groans)",
|
91 |
+
"and my jaw locked tight (exhales)",
|
92 |
+
"and I couldn’t hold it in (groans)",
|
93 |
+
"and I was done pretending (clears throat)",
|
94 |
+
"and I couldn’t even look at them (inhales)",
|
95 |
+
"and I exploded (groans)",
|
96 |
+
"and I couldn’t take it anymore (exhales)",
|
97 |
+
"and I had to say something (groans)",
|
98 |
+
"and I (inhales) stood my ground",
|
99 |
+
"and I shouted back (screams)",
|
100 |
+
"and I let it all out (exhales)",
|
101 |
+
"and my voice (clears throat) cracked",
|
102 |
+
"and I hit the table (groans)",
|
103 |
+
"and my chest was pounding (inhales)",
|
104 |
+
"and I was ready to walk out (groans)",
|
105 |
+
"and I didn’t care anymore (exhales)",
|
106 |
+
"and I needed to be heard (clears throat)",
|
107 |
+
"and I (inhales) raised my voice",
|
108 |
+
"and I threw the paper (groans)",
|
109 |
+
"and I stormed out (screams)",
|
110 |
+
"and I couldn’t stay still (inhales)",
|
111 |
+
"and I lost all patience (groans)",
|
112 |
+
"and I refused to let it slide (groans)",
|
113 |
+
"and I called them out (clears throat)",
|
114 |
+
"and I (inhales) slammed my hand down",
|
115 |
+
"and I saw red (groans)",
|
116 |
+
"and I swore (exhales) under my breath",
|
117 |
+
"and I (inhales) pushed back",
|
118 |
+
"and I made it clear (clears throat)",
|
119 |
+
"and I couldn’t smile through it (exhales)",
|
120 |
+
"and I wasn’t going to take it (groans)",
|
121 |
+
"and I felt fury (inhales) in my chest",
|
122 |
+
"and I almost screamed (groans)",
|
123 |
+
"and I nearly broke something (exhales)",
|
124 |
+
"and I didn’t care who heard (groans)",
|
125 |
+
"and I just (inhales) let it rip",
|
126 |
+
"and I couldn’t stop (groans)",
|
127 |
+
"and I raised hell (screams)",
|
128 |
+
"and I felt every nerve snap (inhales)",
|
129 |
+
"and I threw my phone (groans)",
|
130 |
+
"and I stood up fast (exhales)",
|
131 |
+
"and I couldn’t think straight (groans)",
|
132 |
+
"and I almost walked out (inhales)",
|
133 |
+
"and I hit the wall (groans)",
|
134 |
+
"and I didn’t hold back (clears throat)",
|
135 |
+
"and I wasn’t sorry (groans)",
|
136 |
+
"and I didn’t fake calm (inhales)",
|
137 |
+
"and I pushed through the rage (groans)",
|
138 |
+
"and I didn’t even blink (exhales)",
|
139 |
+
"and I said (inhales) exactly what I felt",
|
140 |
+
"and I glared back (groans)",
|
141 |
+
"and I lost all control (exhales)",
|
142 |
+
"and I roared inside (groans)",
|
143 |
+
"and I hit my breaking point (inhales)",
|
144 |
+
"and I wasn’t about to be silent (groans)",
|
145 |
+
"and I couldn’t take another second (exhales)",
|
146 |
+
"and I let the fury speak (groans)",
|
147 |
+
"and I (inhales) dropped the act",
|
148 |
+
"and I watched myself boil (groans)",
|
149 |
+
"and I nearly lost it all (exhales)",
|
150 |
+
"and I refused to back down (groans)",
|
151 |
+
"and I didn’t sugarcoat it (clears throat)",
|
152 |
+
"and I let the anger talk (groans)",
|
153 |
+
"and I (inhales) cut them off",
|
154 |
+
"and I burned the bridge (groans)",
|
155 |
+
"and I pulled no punches (inhales)",
|
156 |
+
"and I growled back (groans)",
|
157 |
+
"and I lashed out (exhales)",
|
158 |
+
"and I let it burn (groans)",
|
159 |
+
"and I needed the release (inhales)",
|
160 |
+
"and I was done holding it in (groans)",
|
161 |
+
"and I let the room know (inhales)",
|
162 |
+
"and I banged the table (groans)",
|
163 |
+
"and I spit the words out (inhales)",
|
164 |
+
"and I slammed my fist (groans)",
|
165 |
+
"and I (inhales) snapped hard",
|
166 |
+
"and I said it louder (groans)",
|
167 |
+
"and I faced it head‑on (exhales)",
|
168 |
+
"and I let the silence burn (groans)",
|
169 |
+
"and I felt the power rise (inhales)",
|
170 |
+
"and I walked away fuming (groans)",
|
171 |
+
"and I cracked wide open (exhales)",
|
172 |
+
"and I gave them the truth (groans)",
|
173 |
+
"and I looked them dead in the eyes (inhales)",
|
174 |
+
"and I didn’t regret it (exhales)"
|
175 |
+
],
|
176 |
+
"templates": [
|
177 |
+
"{s} {v}, {c}.",
|
178 |
+
"{i} {s} {v}.",
|
179 |
+
"{s} {v}. {c}.",
|
180 |
+
"{s} {v} — {c}.",
|
181 |
+
"{c}. {s} {v}.",
|
182 |
+
"{s}. It {v}, {c}.",
|
183 |
+
"{s} just... {v}. {c}.",
|
184 |
+
"{i} — {s} {v}, {c}.",
|
185 |
+
"What set me off? {s} {v}. {c}.",
|
186 |
+
"{s} {v}. I had enough! {c}.",
|
187 |
+
"{s} {v}, and I didn’t hold back. {c}.",
|
188 |
+
"I lost it when {s} {v}, {c}.",
|
189 |
+
"You want to know why I snapped? {s} {v}.",
|
190 |
+
"And then it happened: {s} {v}, {c}.",
|
191 |
+
"{s} {v}. I couldn’t take one more second! {c}.",
|
192 |
+
"{s} {v}. That was the last straw!",
|
193 |
+
"{i} I saw red when {s} {v}. {c}.",
|
194 |
+
"{s} {v}. I slammed the damn door. {c}."
|
195 |
+
]
|
196 |
+
}
|
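The arrays in these emotion_templates/*.json files share one schema: each "templates" entry carries {s}, {v}, {c}, {i} placeholders that are filled from the "subjects", "verbs", "contexts", and "interjections" lists. A minimal sketch of how such a file could be expanded into synthetic utterances follows (stdlib json/random only; the function name render_samples is illustrative, and this is not necessarily how the repository's own generator works):

import json
import random

def render_samples(path: str, n: int = 5) -> list[str]:
    # Load one emotion template file, e.g. emotion_templates/anger.json.
    with open(path, encoding="utf-8") as f:
        spec = json.load(f)
    samples = []
    for _ in range(n):
        template = random.choice(spec["templates"])
        # str.format ignores unused keyword arguments, so templates that omit
        # {i} or {c} still render correctly.
        samples.append(template.format(
            s=random.choice(spec["subjects"]),
            v=random.choice(spec["verbs"]),
            c=random.choice(spec["contexts"]),
            i=random.choice(spec["interjections"]),
        ))
    return samples

if __name__ == "__main__":
    for line in render_samples("emotion_templates/anger.json"):
        print(line)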
emotion_templates/disgust.json
ADDED
@@ -0,0 +1,174 @@
1 |
+
{
|
2 |
+
"subjects": [
|
3 |
+
"That smell", "This mess", "The way they chew", "The bathroom", "That habit",
|
4 |
+
"The leftover food", "The sink", "The way he talks", "The mold", "Her fake laugh",
|
5 |
+
"The dirty plate", "His comment", "The old sponge", "The floor", "That photo",
|
6 |
+
"The rotten food", "The spoiled milk", "The sweat smell", "That video",
|
7 |
+
"The way she looked at me", "The chewing sound", "That creepy grin", "The fridge",
|
8 |
+
"The slime", "The trash", "His breath", "The sound of slurping", "The bug",
|
9 |
+
"The pus", "The stain", "The toilet", "The crusted towel", "The unwashed dish",
|
10 |
+
"The clump of hair", "The spit", "That burp", "The toenail clipping",
|
11 |
+
"The used tissue", "The crust in the corner", "The food stuck to the table",
|
12 |
+
"The pile of clothes", "The smudge", "That smell in the microwave", "The grime",
|
13 |
+
"The grease", "The way they eat with their mouth open", "The way he scratched",
|
14 |
+
"The loud chewing", "The mucus", "That nasty cough", "The dirty keyboard",
|
15 |
+
"The open sore", "The bathroom floor", "That weird stain", "The gum under the desk",
|
16 |
+
"The way she licked her fingers", "That slimy texture", "The hair in the drain",
|
17 |
+
"The oil smell", "The weird texture", "That flaky skin", "The cough in my face",
|
18 |
+
"The dried ketchup", "The fake smile", "The old food container", "The spit bubble",
|
19 |
+
"The breath in my face", "The noisy swallowing", "The crusty fork", "The plate of mush",
|
20 |
+
"The gurgling sound", "The smelly shoe", "The grease trap", "The used napkin",
|
21 |
+
"The close talker", "The public bathroom", "The noise he made", "The dirty towel",
|
22 |
+
"The slime on the floor", "The gunk", "The mix of smells", "The empty bottle with milk in it",
|
23 |
+
"That gesture", "The broken nail", "The eye gunk", "That noise with the throat",
|
24 |
+
"The mess in the sink", "The weird breathing", "The sticky hand", "The stinky feet",
|
25 |
+
"The trash bag leak", "The way they touched the food", "The finger licking",
|
26 |
+
"The bathroom stall", "That cracked lip", "The greasy spoon", "The way they smacked their lips"
|
27 |
+
],
|
28 |
+
"verbs": [
|
29 |
+
"turned my stomach", "made me gag", "grossed me out", "made me shiver",
|
30 |
+
"sent a wave of nausea", "was disgusting", "made me flinch", "twisted my gut",
|
31 |
+
"churned my stomach", "repulsed me", "triggered my gag reflex",
|
32 |
+
"made my skin crawl", "was revolting", "sickened me", "hit me wrong",
|
33 |
+
"burned my nose", "felt wrong", "made me dry heave", "pushed me back",
|
34 |
+
"curdled my insides", "stung my senses", "brought bile up", "left a taste in my mouth",
|
35 |
+
"smelled rancid", "tasted foul", "looked putrid", "reeked", "felt like vomit",
|
36 |
+
"left me nauseated", "knocked my appetite", "made me lose it",
|
37 |
+
"felt filthy", "stank horribly", "coated my throat", "overwhelmed my nose",
|
38 |
+
"burned into memory", "made me cough", "was rotten", "was unbearable",
|
39 |
+
"left me disgusted", "sent shivers", "twisted my face", "gave me goosebumps",
|
40 |
+
"triggered disgust", "was sickening", "was grotesque", "reeked of old filth",
|
41 |
+
"smelled like death", "was unspeakable", "looked infected", "was slimy and warm",
|
42 |
+
"was moist in the worst way", "left a greasy film", "was sticky and gross",
|
43 |
+
"stuck to everything", "oozed", "dripped in the worst way", "was decaying",
|
44 |
+
"looked chewed", "felt rancid", "smelled spoiled", "squished in my hand",
|
45 |
+
"glistened weirdly", "looked alive", "wiggled slightly", "felt damp and wrong",
|
46 |
+
"left a slime trail", "made me cover my mouth", "caused a dry retch",
|
47 |
+
"spoiled the whole room", "felt like rot", "reeked of sweat", "oozed a bit",
|
48 |
+
"smeared across the surface", "clung to the air", "infected the vibe",
|
49 |
+
"spread filth", "made everything worse", "crawled under my skin",
|
50 |
+
"was so wrong", "ruined my day", "tasted like trash", "bubbled weirdly",
|
51 |
+
"had texture like decay", "looked chewed up", "smelled like feet",
|
52 |
+
"smelled fermented", "smeared all over", "cracked and oozed", "sent a gag reflex",
|
53 |
+
"looked diseased", "felt like warm spit", "reeked of mold", "coated my senses"
|
54 |
+
],
|
55 |
+
"interjections": [
|
56 |
+
"(groans) Ew!",
|
57 |
+
"(coughs) That’s disgusting!",
|
58 |
+
"Ugh! (sniffs)",
|
59 |
+
"(gasps) I can’t believe I saw that",
|
60 |
+
"Seriously? (sighs)",
|
61 |
+
"(groans) What the hell?!",
|
62 |
+
"No no no! (clears throat)",
|
63 |
+
"Gross! (exhales)",
|
64 |
+
"(mumbles) Disgusting!",
|
65 |
+
"That made me gag! (coughs)",
|
66 |
+
"(sniffs) Yuck!",
|
67 |
+
"Oh god (groans)",
|
68 |
+
"That’s so wrong! (exhales)",
|
69 |
+
"Why would you do that?! (clears throat)",
|
70 |
+
"(inhales) Please stop",
|
71 |
+
"That smell though (sniffs)",
|
72 |
+
"No thanks (sighs)",
|
73 |
+
"I’m done (groans)",
|
74 |
+
"Absolutely vile! (coughs)",
|
75 |
+
"I could puke! (exhales)"
|
76 |
+
],
|
77 |
+
|
78 |
+
"contexts": [
|
79 |
+
"(groans) and I couldn’t unsee it",
|
80 |
+
"and I nearly gagged (coughs)",
|
81 |
+
"and I felt my throat close (exhales)",
|
82 |
+
"and I turned away (groans)",
|
83 |
+
"and I had to leave the room (exhales)",
|
84 |
+
"and my face twisted (sniffs)",
|
85 |
+
"and I backed off immediately (groans)",
|
86 |
+
"and I clenched my jaw (sighs)",
|
87 |
+
"and I wanted to scrub my brain (groans)",
|
88 |
+
"and my appetite vanished (clears throat)",
|
89 |
+
"and I wiped my hands (sniffs)",
|
90 |
+
"and I needed mouthwash (groans)",
|
91 |
+
"and I felt queasy (coughs)",
|
92 |
+
"(sniffs) and I covered my nose",
|
93 |
+
"and I said ‘gross!’ (groans)",
|
94 |
+
"and I still smell it (sniffs)",
|
95 |
+
"and I wanted to burn it all (exhales)",
|
96 |
+
"and I nearly lost it (groans)",
|
97 |
+
"and I shook my head (sighs)",
|
98 |
+
"and I whispered ‘what the hell?!’ (groans)",
|
99 |
+
"and I flinched hard (inhales)",
|
100 |
+
"and I squinted like it would go away (sighs)",
|
101 |
+
"(inhales) and I held my breath",
|
102 |
+
"and I gagged a little (coughs)",
|
103 |
+
"and I pulled back fast (groans)",
|
104 |
+
"and I rubbed my eyes (sniffs)",
|
105 |
+
"and I dry heaved (coughs)",
|
106 |
+
"and I bit my tongue to avoid screaming (clears throat)",
|
107 |
+
"and I washed my hands three times (sniffs)",
|
108 |
+
"and I didn’t look again (groans)",
|
109 |
+
"and I covered my face (exhales)",
|
110 |
+
"and I had to breathe through my mouth (inhales)",
|
111 |
+
"and I said ‘absolutely not!’ (groans)",
|
112 |
+
"and I scrubbed everything (sniffs)",
|
113 |
+
"and I needed a shower (groans)",
|
114 |
+
"and I said ‘ew!’ out loud (coughs)",
|
115 |
+
"and I sprayed the whole area (coughs)",
|
116 |
+
"and I glared at them (clears throat)",
|
117 |
+
"and I wiped the table like five times (sniffs)",
|
118 |
+
"and I wanted to bleach it (groans)",
|
119 |
+
"and I swore never again (exhales)",
|
120 |
+
"and I asked them to stop (clears throat)",
|
121 |
+
"and I backed into a corner (inhales)",
|
122 |
+
"and I wanted to scream (groans)",
|
123 |
+
"and I asked them to throw it away (clears throat)",
|
124 |
+
"and I almost dropped it (gasps)",
|
125 |
+
"and I cringed (groans)",
|
126 |
+
"and I winced hard (inhales)",
|
127 |
+
"and I had to spit (coughs)",
|
128 |
+
"and I cleaned everything I touched (sniffs)",
|
129 |
+
"and I looked away fast (groans)",
|
130 |
+
"and I felt like vomiting (coughs)",
|
131 |
+
"and I dry swallowed (sighs)",
|
132 |
+
"and I left the room immediately (exhales)",
|
133 |
+
"and I didn’t know what to do (sighs)",
|
134 |
+
"and I needed to rinse my mouth (clears throat)",
|
135 |
+
"and I still feel it on my hands (sniffs)",
|
136 |
+
"and I wiped my phone off (sniffs)",
|
137 |
+
"and I stood there stunned (gasps)",
|
138 |
+
"and I whispered ‘nope!’ (groans)",
|
139 |
+
"and I tensed up (inhales)",
|
140 |
+
"and I wiped down the table again (sniffs)",
|
141 |
+
"(coughs) and I sprayed disinfectant everywhere",
|
142 |
+
"and I held the air (inhales)",
|
143 |
+
"and I tried not to vomit (coughs)",
|
144 |
+
"and I shook it off quickly (exhales)",
|
145 |
+
"and I walked away fast (groans)",
|
146 |
+
"and I stared in horror (gasps)",
|
147 |
+
"and I flinched again (inhales)",
|
148 |
+
"and I cleaned the whole area (sniffs)",
|
149 |
+
"and I sanitized everything (clears throat)",
|
150 |
+
"and I stepped away (exhales)"
|
151 |
+
],
|
152 |
+
"templates": [
|
153 |
+
"{s} {v}, {c}.",
|
154 |
+
"{i}, {s} {v}.",
|
155 |
+
"{s} {v}. {c}.",
|
156 |
+
"{s} {v} — {c}.",
|
157 |
+
"{c}. {s} {v}.",
|
158 |
+
"{s}. It {v}, {c}.",
|
159 |
+
"{s} just... {v}. {c}.",
|
160 |
+
"{i} — {s} {v}, {c}.",
|
161 |
+
"Honestly, {s} {v}. {c}.",
|
162 |
+
"I couldn’t take it — {s} {v}, {c}.",
|
163 |
+
"It was foul. {s} {v}. {c}.",
|
164 |
+
"Just thinking about it — {s} {v}. {c}.",
|
165 |
+
"‘{s} {v}’ — yeah, I’m out. {c}.",
|
166 |
+
"{s} made me gag, no question. {c}.",
|
167 |
+
"I still feel sick. {s} {v}. {c}.",
|
168 |
+
"It turned my gut — {s} {v}, {c}.",
|
169 |
+
"{s}? Just vile. {v}, {c}.",
|
170 |
+
"Disgust doesn’t even cover it — {s} {v}, {c}.",
|
171 |
+
"{i}. Seriously, {s} {v}. {c}.",
|
172 |
+
"Every time I remember — {s} {v}. {c}."
|
173 |
+
]
|
174 |
+
}
|
emotion_templates/fear.json
ADDED
@@ -0,0 +1,178 @@
1 |
+
{
|
2 |
+
"subjects": [
|
3 |
+
"That noise", "The silence", "This feeling", "His stare", "Her whisper", "That look",
|
4 |
+
"The shadow", "The knock", "That voice", "The unknown", "The hallway", "The message",
|
5 |
+
"The basement", "The flicker", "That sound behind me", "The reflection", "The darkness",
|
6 |
+
"The quiet", "That face", "The empty street", "The open door", "The figure in the distance",
|
7 |
+
"The footsteps", "The breath I heard", "The locked room", "The sudden cold", "The static",
|
8 |
+
"The glitch", "That motion", "The silence after", "That sudden stop", "The blank screen",
|
9 |
+
"The basement light", "The empty chair", "The way it stopped", "The timing", "The phone ringing",
|
10 |
+
"The strange call", "The echo", "The scream", "The eye contact", "The ticking", "The delay",
|
11 |
+
"That message at night", "The wrong number", "The unexpected knock", "The closed curtain",
|
12 |
+
"The way it moved", "The window", "The mirror", "The shape", "The silence on the line",
|
13 |
+
"The dim hallway", "That heavy breath", "The door creaking", "The chill", "That shadow on the wall",
|
14 |
+
"The open browser tab", "The stillness", "The unanswered call", "The flickering light",
|
15 |
+
"The power going out", "That smell", "The voice in my head", "The basement stairs",
|
16 |
+
"The unknown number", "The sudden pause", "The air around me", "The shape behind me",
|
17 |
+
"The clicking sound", "The lack of sound", "That dream", "The phone buzz", "The siren",
|
18 |
+
"That change in tone", "The knock I didn’t expect", "The door left open", "That breath I didn’t take",
|
19 |
+
"The space under the bed", "The open closet", "The camera turning on", "The power surge",
|
20 |
+
"That sudden silence", "The hallway light", "The alarm", "The code not working",
|
21 |
+
"That sudden breeze", "The voice that stopped", "The quiet tap", "The screen glitch",
|
22 |
+
"That blinking cursor", "The slow footsteps", "The eye in the dark", "The frozen time",
|
23 |
+
"The disconnect tone", "The feeling of being watched", "The stop in motion", "The email I didn’t send"
|
24 |
+
],
|
25 |
+
"verbs": [
|
26 |
+
"froze me", "sent a chill down my spine", "made my breath catch", "stopped my heart",
|
27 |
+
"tied my stomach in knots", "paralyzed me", "had me holding my breath", "rushed my pulse",
|
28 |
+
"tensed every muscle", "squeezed my chest", "made me dizzy", "triggered my panic",
|
29 |
+
"gripped me", "shook me", "dried my mouth", "shrank my world", "knocked me still",
|
30 |
+
"pressed on my lungs", "made my skin crawl", "trembled through me", "dug into my mind",
|
31 |
+
"raced through my thoughts", "tightened my throat", "blanked my head", "flooded me with fear",
|
32 |
+
"spiked my heart rate", "stole my breath", "made me freeze in place", "clouded my focus",
|
33 |
+
"hit like dread", "echoed in my chest", "spun my senses", "stung like ice", "gripped my spine",
|
34 |
+
"whitened my face", "stopped everything", "grew louder in my head", "seeped into me",
|
35 |
+
"dragged down my breath", "set my nerves on edge", "crawled over my skin", "burned in my gut",
|
36 |
+
"snapped my focus", "locked my body", "flooded my system", "choked my breath",
|
37 |
+
"tightened my jaw", "made my heartbeat deafening", "pressed me into stillness",
|
38 |
+
"caged my breath", "made my ears ring", "left me lightheaded", "slowed down time",
|
39 |
+
"reduced me to silence", "took over", "made my hands shake", "bent my knees",
|
40 |
+
"tugged at my ribs", "dimmed my sight", "pulled me inward", "buried my voice",
|
41 |
+
"unraveled me", "set alarms off inside", "curled my fingers", "hit my core",
|
42 |
+
"trembled in my jaw", "scrambled my thoughts", "turned me to stone", "squeezed my spine",
|
43 |
+
"widened my eyes", "flushed my skin", "knocked the air out", "rattled my head",
|
44 |
+
"sat heavy in my chest", "clutched at my throat", "sent panic through me",
|
45 |
+
"knotted my back", "held me hostage", "blinded my thinking", "muffled my hearing",
|
46 |
+
"rose like a wave", "climbed through my limbs", "forced my jaw shut", "twisted in my gut",
|
47 |
+
"burned through my nerves", "flooded my eyes", "froze my voice", "slammed my chest",
|
48 |
+
"disoriented me", "hollowed me out", "rushed like a scream", "scraped my bones",
|
49 |
+
"throbbed in my neck", "tightened like a noose", "spun in my head", "felt like doom"
|
50 |
+
],
|
51 |
+
"contexts": [
|
52 |
+
"(gasps) and I couldn’t move",
|
53 |
+
"and I held my breath (inhales)",
|
54 |
+
"and my body locked up (exhales)",
|
55 |
+
"and I didn’t know what to do (sighs)",
|
56 |
+
"and I felt eyes on me (inhales)",
|
57 |
+
"and I just listened (sighs)",
|
58 |
+
"and I froze in place (gasps)",
|
59 |
+
"and everything felt too quiet (exhales)",
|
60 |
+
"and I scanned the room (inhales)",
|
61 |
+
"and I backed away (gasps)",
|
62 |
+
"and I reached for my phone (inhales)",
|
63 |
+
"and I whispered nothing (sighs)",
|
64 |
+
"and I stayed absolutely still (exhales)",
|
65 |
+
"and I clenched my jaw",
|
66 |
+
"and I waited in silence (exhales)",
|
67 |
+
"and I couldn’t see clearly (exhales)",
|
68 |
+
"and my heart was pounding (inhales)",
|
69 |
+
"and the air felt wrong (sighs)",
|
70 |
+
"and I tiptoed (inhales)",
|
71 |
+
"and my hands were shaking (exhales)",
|
72 |
+
"and I hoped it wasn’t real (sighs)",
|
73 |
+
"and I didn’t dare speak (inhales)",
|
74 |
+
"and I tried to stay calm (exhales)",
|
75 |
+
"and the walls felt closer (gasps)",
|
76 |
+
"and the room spun (exhales)",
|
77 |
+
"and I avoided the mirror (sighs)",
|
78 |
+
"and I wished I hadn’t heard that (exhales)",
|
79 |
+
"and the lights flickered (gasps)",
|
80 |
+
"and the sound repeated (inhales)",
|
81 |
+
"and my fingers curled (exhales)",
|
82 |
+
"and I stood frozen (gasps)",
|
83 |
+
"and I looked over my shoulder (inhales)",
|
84 |
+
"and I barely breathed (exhales)",
|
85 |
+
"and I felt like I was watched (gasps)",
|
86 |
+
"and I clutched my chest (inhales)",
|
87 |
+
"and my breath stopped (exhales)",
|
88 |
+
"and I whispered ‘hello?’ (sighs)",
|
89 |
+
"and I stepped back slowly (exhales)",
|
90 |
+
"and I avoided eye contact (sighs)",
|
91 |
+
"and I didn’t turn around (inhales)",
|
92 |
+
"and I closed my eyes tight (exhales)",
|
93 |
+
"and I mouthed ‘please no’ (sighs)",
|
94 |
+
"and I turned the light on (inhales)",
|
95 |
+
"and I checked the door (exhales)",
|
96 |
+
"and I hoped it was nothing (sighs)",
|
97 |
+
"and I bit my lip (inhales)",
|
98 |
+
"and I looked again (gasps)",
|
99 |
+
"and I grabbed the handle (inhales)",
|
100 |
+
"and I watched the hallway (exhales)",
|
101 |
+
"and I held the wall (sighs)",
|
102 |
+
"and I tried not to blink (inhales)",
|
103 |
+
"and I prayed I was wrong (exhales)",
|
104 |
+
"and I counted my breaths (sighs)",
|
105 |
+
"and I waited for a sound (inhales)",
|
106 |
+
"and I pressed my back to the wall (exhales)",
|
107 |
+
"and I closed the tab (sighs)",
|
108 |
+
"and I locked my phone (inhales)",
|
109 |
+
"and I shook (exhales)",
|
110 |
+
"and I tried to steady myself (sighs)",
|
111 |
+
"and I looked away (exhales)",
|
112 |
+
"and I froze mid‑step (gasps)",
|
113 |
+
"and I turned down the volume (sighs)",
|
114 |
+
"and I stood in the dark (inhales)",
|
115 |
+
"and I checked the lock (exhales)",
|
116 |
+
"and I peeked through the curtain (gasps)",
|
117 |
+
"and I watched the shadow move (inhales)",
|
118 |
+
"and I saw it again (exhales)",
|
119 |
+
"and I hoped I imagined it (sighs)",
|
120 |
+
"and I wanted to scream (gasps)",
|
121 |
+
"and I felt something shift (inhales)",
|
122 |
+
"and I barely exhaled (exhales)",
|
123 |
+
"and I pressed mute (sighs)",
|
124 |
+
"and I whispered ‘stop’ (inhales)",
|
125 |
+
"and I stared into the dark (exhales)",
|
126 |
+
"and I held my breath again (inhales)",
|
127 |
+
"and I tried to wake up (sighs)",
|
128 |
+
"and I reached for the light (inhales)",
|
129 |
+
"and I stayed under the blanket (exhales)",
|
130 |
+
"and I whispered to myself (sighs)",
|
131 |
+
"and I looked at the screen again (inhales)",
|
132 |
+
"and I locked the window (exhales)",
|
133 |
+
"and I flinched (gasps)",
|
134 |
+
"and I felt it behind me (inhales)",
|
135 |
+
"and I turned off the TV (exhales)",
|
136 |
+
"and I pulled the covers tight (sighs)"
|
137 |
+
],
|
138 |
+
|
139 |
+
"interjections": [
|
140 |
+
"(gasps) I felt a chill!",
|
141 |
+
"(sighs) Something’s off.",
|
142 |
+
"(exhales) This can’t be good.",
|
143 |
+
"I don’t like this (inhales).",
|
144 |
+
"(inhales) What if it’s real?",
|
145 |
+
"(gasps) No way!",
|
146 |
+
"I swear I saw something (exhales).",
|
147 |
+
"(inhales) Did you hear that?",
|
148 |
+
"(sighs) Why is it so quiet?",
|
149 |
+
"That wasn’t right (exhales).",
|
150 |
+
"(inhales) Is someone there?",
|
151 |
+
"Not again (exhales)...",
|
152 |
+
"(exhales) That gave me chills.",
|
153 |
+
"(gasps) What the...",
|
154 |
+
"It felt wrong (sighs).",
|
155 |
+
"(gasps) I froze!",
|
156 |
+
"(gasps) Something moved!",
|
157 |
+
"(sighs) It’s too quiet.",
|
158 |
+
"(exhales) This isn’t normal.",
|
159 |
+
"(inhales) I’m not alone."
|
160 |
+
],
|
161 |
+
"templates": [
|
162 |
+
"{s} {v}, {c}.",
|
163 |
+
"{i}, {s} {v}.",
|
164 |
+
"{s} {v}. {c}.",
|
165 |
+
"{s} {v} — {c}.",
|
166 |
+
"{c}. {s} {v}.",
|
167 |
+
"{s}. It {v}, {c}.",
|
168 |
+
"{s} just... {v}. {c}.",
|
169 |
+
"{i} — {s} {v}, {c}.",
|
170 |
+
"{s} {v}. I couldn’t blink. {c}.",
|
171 |
+
"And then it happened: {s} {v}, {c}.",
|
172 |
+
"‘{s} {v}’ — and I froze. {c}.",
|
173 |
+
"It was quiet, too quiet. Then {s} {v}. {c}.",
|
174 |
+
"I stopped breathing when {s} {v}. {c}.",
|
175 |
+
"No one else noticed, but {s} {v}, {c}.",
|
176 |
+
"{s} {v}. My pulse spiked. {c}."
|
177 |
+
]
|
178 |
+
}
|
emotion_templates/happy.json
ADDED
@@ -0,0 +1,187 @@
1 |
+
{
|
2 |
+
"subjects": [
|
3 |
+
"That moment", "This message", "Her smile", "His laugh", "The sunshine", "The good news",
|
4 |
+
"That call", "The unexpected text", "The gift", "The surprise", "That compliment",
|
5 |
+
"The way they looked at me", "This day", "The win", "That hug", "The way it happened",
|
6 |
+
"Their reaction", "That unexpected turn", "This feeling", "That morning",
|
7 |
+
"The celebration", "The results", "That gesture", "The quiet joy", "The party",
|
8 |
+
"The sound of laughter", "The light", "That one word", "The surprise visit",
|
9 |
+
"The smell of coffee", "The fresh air", "The freedom", "That one dance",
|
10 |
+
"The sunrise", "The music", "The rhythm", "The warm breeze", "That smile across the room",
|
11 |
+
"The perfect timing", "The peace", "The joke", "The message I didn’t expect",
|
12 |
+
"That moment I opened the door", "The reaction on their face", "That memory",
|
13 |
+
"The sparkle in their eyes", "The unexpected break", "The lucky moment",
|
14 |
+
"The right words", "The song that played", "The cake", "The light in the room",
|
15 |
+
"The kids laughing", "The compliments", "The photo", "The way they cheered",
|
16 |
+
"The shared silence", "That warm feeling", "The burst of joy", "The freedom I felt",
|
17 |
+
"The high five", "The small win", "The huge success", "The laughter we shared",
|
18 |
+
"The trip", "The story", "The goofy face", "The text with all caps",
|
19 |
+
"The silly moment", "The happy tear", "The peace in my heart", "The sparkle",
|
20 |
+
"The good vibes", "The cozy night", "The unexpected yes", "The sudden dance",
|
21 |
+
"The applause", "The open road", "The sunshine on my face", "The joke that landed",
|
22 |
+
"The moment they said it", "The sparkle of hope", "The confetti moment", "That yes",
|
23 |
+
"The bounce in their step", "The rush of relief", "The random kindness",
|
24 |
+
"The shared look", "The clink of glasses", "The moment of pride", "The nod of approval",
|
25 |
+
"The lightness in the air", "The smell of cookies", "The favorite song",
|
26 |
+
"The surprise package", "The right time", "That perfect moment", "The cheers",
|
27 |
+
"The warmth in the room", "The silly dance", "The peace in the pause", "The bright light"
|
28 |
+
],
|
29 |
+
"verbs": [
|
30 |
+
"lit me up", "made me laugh out loud", "warmed my chest", "brought me peace",
|
31 |
+
"lifted my mood", "was pure joy", "felt like flying", "was exactly what I needed",
|
32 |
+
"made me grin", "sparked something bright", "filled me with light",
|
33 |
+
"was everything", "cheered me up", "set the tone", "made the whole day",
|
34 |
+
"was perfect", "was a blessing", "felt magical", "recharged me", "was unforgettable",
|
35 |
+
"made my heart dance", "put a smile on my face", "made my eyes shine",
|
36 |
+
"woke up my soul", "filled the room", "shimmered through me", "set me free",
|
37 |
+
"bubbled up inside", "sparked so much joy", "brought out my laughter",
|
38 |
+
"was heartwarming", "had me laughing instantly", "soothed my spirit",
|
39 |
+
"gave me goosebumps", "sparkled", "changed everything", "was so genuine",
|
40 |
+
"had me beaming", "was sunshine", "shined bright", "was golden",
|
41 |
+
"made the room feel lighter", "brought instant joy", "was beautiful",
|
42 |
+
"echoed with joy", "brightened the moment", "lifted the weight",
|
43 |
+
"refreshed my spirit", "wrapped me in warmth", "was the highlight",
|
44 |
+
"gave me hope", "touched my heart", "made me giggle", "felt light and free",
|
45 |
+
"restored my smile", "made me blush", "sparkled in me", "woke something kind",
|
46 |
+
"was soft and strong", "wrapped the moment", "was like magic", "felt like home",
|
47 |
+
"was full of grace", "was the best surprise", "was the perfect fit", "bloomed inside me",
|
48 |
+
"brought clarity and calm", "was so full of color", "set off a spark",
|
49 |
+
"played like a melody", "made everything brighter", "was like spring inside me",
|
50 |
+
"was sweetness", "glowed in me", "reminded me of good things", "restored my joy",
|
51 |
+
"hugged me from within", "was full of wonder", "sparkled in my chest", "was so soft",
|
52 |
+
"helped me breathe again", "was a true gift", "felt so human", "rang like laughter",
|
53 |
+
"set me dancing", "was pure love", "made me giggle-snort", "sent joy down my spine",
|
54 |
+
"was completely silly", "was just right", "was light as air", "was like a hug",
|
55 |
+
"fizzed like soda", "made my heart sing", "was total delight", "filled the space with light",
|
56 |
+
"was cozy and bright", "carried me", "felt so real", "just made sense", "bounced off the walls"
|
57 |
+
],
|
58 |
+
|
59 |
+
"interjections": [
|
60 |
+
"(laughs) I couldn’t believe it!",
|
61 |
+
"No way! (claps)",
|
62 |
+
"That made my day, (laughs).",
|
63 |
+
"(chuckle)... You have no idea!",
|
64 |
+
"(whistles) Honestly, amazing!",
|
65 |
+
"So good! (laughs)",
|
66 |
+
"It was perfect — (applause).",
|
67 |
+
"(laughs) Pure joy!",
|
68 |
+
"Absolutely loved it! (claps)",
|
69 |
+
"Couldn’t stop grinning! (chuckle)",
|
70 |
+
"Wow! (whistles)",
|
71 |
+
"(laughs) This lit me up!",
|
72 |
+
"(whistles)... Just yes!",
|
73 |
+
"Loved every second! (laughs)",
|
74 |
+
"(chuckle) Total happiness!",
|
75 |
+
"It hit perfectly! (whistles)",
|
76 |
+
"(laughs) I’m smiling again.",
|
77 |
+
"Everything clicked — (claps)!",
|
78 |
+
"I was glowing... (humming)"
|
79 |
+
],
|
80 |
+
|
81 |
+
"contexts": [
|
82 |
+
"(laughs)... and I couldn’t stop smiling!",
|
83 |
+
"and everything felt lighter (exhales).",
|
84 |
+
"and I didn’t want it to end — (sighs).",
|
85 |
+
"(whistles) and I wanted to dance!",
|
86 |
+
"(laughs) and I couldn’t stop laughing!",
|
87 |
+
"and the room lit up — (claps)!",
|
88 |
+
"and I held onto it all day... (humming)",
|
89 |
+
"and it made my week! (laughs)",
|
90 |
+
"and I smiled without even realizing (chuckle).",
|
91 |
+
"and I was so full of energy (humming)!",
|
92 |
+
"and I felt so free... (exhales)",
|
93 |
+
"and nothing else mattered.",
|
94 |
+
"and I just wanted to share it (laughs)!",
|
95 |
+
"(laughs) and I laughed until my sides hurt!",
|
96 |
+
"and I couldn’t stop replaying it (whistles)...",
|
97 |
+
"and I felt ten pounds lighter (exhales)!",
|
98 |
+
"and the moment sparkled (humming)...",
|
99 |
+
"and I forgot all the stress — (sighs)",
|
100 |
+
"and I hugged them tight!",
|
101 |
+
"and I just soaked it in... (exhales)",
|
102 |
+
"and it made everything worth it! (laughs)",
|
103 |
+
"(inhales) and I breathed it in.",
|
104 |
+
"and it stayed with me (whistles)...",
|
105 |
+
"and it was everything I needed (sighs).",
|
106 |
+
"and I danced a little inside (whistles)!",
|
107 |
+
"and I caught myself smiling later (chuckle).",
|
108 |
+
"and I said “finally!” (laughs)",
|
109 |
+
"and my heart felt full (humming).",
|
110 |
+
"and I couldn’t stop giggling (laughs)!",
|
111 |
+
"and I felt alive again (exhales).",
|
112 |
+
"and I whispered “yes!” (whistles).",
|
113 |
+
"and I high‑fived the air (claps)!",
|
114 |
+
"and I had the biggest grin (chuckle)!",
|
115 |
+
"and it lifted the whole moment (laughs).",
|
116 |
+
"and the day turned golden (whistles)...",
|
117 |
+
"and the silence was perfect — (sighs).",
|
118 |
+
"and I was glowing (humming)!",
|
119 |
+
"and I bounced around (laughs)!",
|
120 |
+
"and the sun seemed brighter (whistles)!",
|
121 |
+
"and even coffee tasted better (chuckle)...",
|
122 |
+
"and it made everything okay (sighs).",
|
123 |
+
"and I had happy tears (laughs)...",
|
124 |
+
"and I just twirled (whistles)!",
|
125 |
+
"and I hugged myself (humming).",
|
126 |
+
"(laughs) and I laughed out loud in public!",
|
127 |
+
"and it gave me wings (whistles)...",
|
128 |
+
"and I wanted to bottle that feeling (humming).",
|
129 |
+
"(humming)... and I hummed without thinking!",
|
130 |
+
"and I sent five heart emojis (claps).",
|
131 |
+
"and I didn’t need words (exhales)...",
|
132 |
+
"and I knew it was real (sighs).",
|
133 |
+
"and I just sat in the joy (exhales).",
|
134 |
+
"and it filled the room (laughs)!",
|
135 |
+
"and I felt carried (humming)...",
|
136 |
+
"and I danced in my chair (whistles)!",
|
137 |
+
"and it echoed in my bones (humming).",
|
138 |
+
"and I forgot why I was sad (exhales)...",
|
139 |
+
"and I thanked the moment (laughs).",
|
140 |
+
"and I could finally breathe (exhales)!",
|
141 |
+
"and I turned up the music (singing)!",
|
142 |
+
"and I felt hugged (humming).",
|
143 |
+
"and I danced barefoot (whistles)...",
|
144 |
+
"and I made a goofy face (laughs).",
|
145 |
+
"and I felt so warm (exhales).",
|
146 |
+
"and I cried happy tears (laughs)!",
|
147 |
+
"and I replayed it in my head (chuckle).",
|
148 |
+
"and I couldn’t believe it (gasps)!",
|
149 |
+
"and I wanted to freeze time (sighs)...",
|
150 |
+
"and I smiled at a stranger (laughs).",
|
151 |
+
"and I whispered “thank you.” (sighs)",
|
152 |
+
"and I let myself laugh (laughs)!",
|
153 |
+
"and I felt like a kid again (whistles).",
|
154 |
+
"and I laughed‑snorted (laughs)!",
|
155 |
+
"and I leaned into it (exhales)...",
|
156 |
+
"and I closed my eyes and smiled (humming).",
|
157 |
+
"and I sang out loud (singing)!",
|
158 |
+
"and it gave me joy to spare (laughs)...",
|
159 |
+
"and I did a little spin (whistles)!",
|
160 |
+
"and I raised my hands — (claps)!",
|
161 |
+
"(claps)... and I clapped again!",
|
162 |
+
"and I wanted to shout it (laughs)!",
|
163 |
+
"and I texted five people (chuckle).",
|
164 |
+
"and I smiled like an idiot (laughs)!",
|
165 |
+
"and I started dancing (whistles)!",
|
166 |
+
"and I breathed deep and smiled (exhales).",
|
167 |
+
"and I grinned the rest of the day (laughs)!",
|
168 |
+
"and I fist‑pumped the air (claps)!"
|
169 |
+
],
|
170 |
+
"templates": [
|
171 |
+
"{s} {v}, {c}.",
|
172 |
+
"{i}, {s} {v}.",
|
173 |
+
"{s} {v}. {c}.",
|
174 |
+
"{s} {v} — {c}.",
|
175 |
+
"{c}. {s} {v}.",
|
176 |
+
"{s}. It {v}, {c}.",
|
177 |
+
"{s} just... {v}. {c}.",
|
178 |
+
"{i} — {s} {v}, {c}.",
|
179 |
+
"{s} {v}. I smiled so hard. {c}.",
|
180 |
+
"It was simple: {s} {v}. {c}.",
|
181 |
+
"And there it was: {s} {v}. {c}.",
|
182 |
+
"I couldn’t stop smiling — {s} {v}, {c}.",
|
183 |
+
"{s} {v}, and I laughed like a child. {c}.",
|
184 |
+
"It all made sense when {s} {v}. {c}.",
|
185 |
+
"You wouldn’t believe it: {s} {v}, {c}."
|
186 |
+
]
|
187 |
+
}
|
emotion_templates/neutral.json
ADDED
@@ -0,0 +1,97 @@
1 |
+
{
|
2 |
+
"subjects": [
|
3 |
+
"The meeting", "This email", "The weather", "My schedule", "The report",
|
4 |
+
"The file", "The documents", "The update", "The list", "The package",
|
5 |
+
"That call", "The deadline", "My calendar", "The appointment", "The client",
|
6 |
+
"The system", "The form", "The issue", "The request", "This link",
|
7 |
+
"The procedure", "That number", "The login", "The ticket", "The reminder",
|
8 |
+
"The message", "The time", "The place", "The project", "The plan",
|
9 |
+
"The item", "The summary", "The spreadsheet", "The announcement",
|
10 |
+
"The room", "The schedule", "The response", "The folder", "The draft",
|
11 |
+
"The setup", "The chart", "The template", "This section", "The address",
|
12 |
+
"The status", "The link", "The note", "The user", "The change", "The version",
|
13 |
+
"The page", "The setting", "The system log", "The data", "The record",
|
14 |
+
"The feedback", "The form submission", "The result", "The agenda", "The channel",
|
15 |
+
"The section", "The backend", "The field", "The login screen", "The main page",
|
16 |
+
"The recent edit", "The input", "The draft proposal", "The video", "The request form",
|
17 |
+
"The calendar update", "The hour", "The alert", "The setting change", "The comment",
|
18 |
+
"The text box", "The dashboard", "The timezone", "The notification", "The next step",
|
19 |
+
"The change log", "The ping", "The tool", "The auto-reply", "The file name",
|
20 |
+
"The browser tab", "The screen", "The online form", "The number of users",
|
21 |
+
"The archive", "The post", "The backend tool", "The session", "The queue",
|
22 |
+
"The timestamp", "The log", "The script", "The recent call", "The API endpoint",
|
23 |
+
"The shortcut", "The current state", "The setting value", "The message ID"
|
24 |
+
],
|
25 |
+
"verbs": [
|
26 |
+
"was updated", "was received", "was sent", "is scheduled", "was delivered",
|
27 |
+
"was attached", "is listed", "is included", "is available", "was confirmed",
|
28 |
+
"is submitted", "was opened", "is set", "was approved", "was completed",
|
29 |
+
"was reviewed", "was filed", "is noted", "was mentioned", "was discussed",
|
30 |
+
"was forwarded", "was generated", "was uploaded", "was located", "was exported",
|
31 |
+
"was restarted", "was removed", "was closed", "was fixed", "is resolved",
|
32 |
+
"is visible", "was highlighted", "was added", "was changed", "was renamed",
|
33 |
+
"was stored", "was marked", "was tracked", "was logged", "was selected",
|
34 |
+
"was unchecked", "is prefilled", "is functional", "was muted", "was flagged",
|
35 |
+
"was logged again", "was copied", "was moved", "was parsed", "was captured",
|
36 |
+
"was identified", "was tested", "was analyzed", "was displayed", "was found",
|
37 |
+
"was active", "was paused", "was finalized", "was shared", "was cleared",
|
38 |
+
"was replaced", "was validated", "was reloaded", "was timed out", "was confirmed again",
|
39 |
+
"was repeated", "was archived", "was restarted", "was merged", "was commented on",
|
40 |
+
"was clicked", "was created", "was auto-saved", "was viewed", "was hidden",
|
41 |
+
"was unchecked", "was dismissed", "was shortened", "was synced", "was noted again",
|
42 |
+
"was converted", "was reapproved", "was minimized", "was restored", "was synced again",
|
43 |
+
"was edited", "was translated", "was resized", "was expanded", "was called in",
|
44 |
+
"was reentered", "was backed up", "was counted", "was referenced", "was linked"
|
45 |
+
],
|
46 |
+
"contexts": [
|
47 |
+
"and noted accordingly", "as expected", "without any issues", "for review",
|
48 |
+
"and is ready", "based on the plan", "according to the request",
|
49 |
+
"and marked complete", "for reference", "with no delay",
|
50 |
+
"and synced properly", "as part of the update", "and stored correctly",
|
51 |
+
"and reviewed again", "without error", "as per instructions",
|
52 |
+
"and appears consistent", "with the rest of the files", "in the system",
|
53 |
+
"as mentioned earlier", "without further input", "at the same time",
|
54 |
+
"on the main screen", "in the report summary", "in the log", "as scheduled",
|
55 |
+
"and is linked correctly", "with updated fields", "as noted", "with default settings",
|
56 |
+
"on the dashboard", "in the latest build", "with minor differences",
|
57 |
+
"with unchanged parameters", "based on previous input", "in the final draft",
|
58 |
+
"in the archive", "and visible now", "within limits", "with expected output",
|
59 |
+
"as per the instructions", "without manual entry", "without formatting issues",
|
60 |
+
"with default layout", "with no duplicates", "and updated again",
|
61 |
+
"with no change required", "on the same page", "with standard formatting",
|
62 |
+
"in the initial pass", "in the preview", "on request", "under normal load",
|
63 |
+
"with corrected values", "in current use", "in the test environment",
|
64 |
+
"with current permissions", "after the last restart", "on page load",
|
65 |
+
"in the dropdown list", "after refresh", "with no visible difference",
|
66 |
+
"under the given conditions", "within tolerance", "for later reference",
|
67 |
+
"during the last session", "on confirmation", "as of now", "as reviewed",
|
68 |
+
"without alert", "and is part of the bundle", "without recent changes",
|
69 |
+
"on repeated use", "in line with policy", "and saved successfully",
|
70 |
+
"in its original format", "with logged history", "from default mode",
|
71 |
+
"and loaded correctly", "with stable results", "with current context",
|
72 |
+
"during the trial", "in admin view", "in preview mode", "from last session"
|
73 |
+
],
|
74 |
+
"interjections": [
|
75 |
+
"Okay", "Sure", "All right", "Understood", "Makes sense",
|
76 |
+
"Got it", "Thanks", "No problem", "Noted", "All good",
|
77 |
+
"Sounds fine", "Looks okay", "That works", "Cool", "Done",
|
78 |
+
"Let me check", "Confirmed", "Updated", "Logged", "As expected"
|
79 |
+
],
|
80 |
+
"templates": [
|
81 |
+
"{s} {v}, {c}.",
|
82 |
+
"{i}, {s} {v}.",
|
83 |
+
"{s} {v}. {c}.",
|
84 |
+
"{s} {v} — {c}.",
|
85 |
+
"{c}. {s} {v}.",
|
86 |
+
"{s} {v}.",
|
87 |
+
"{s}. It {v}, {c}.",
|
88 |
+
"{i}. {s} {v}, {c}.",
|
89 |
+
"Just FYI: {s} {v}, {c}.",
|
90 |
+
"Note: {s} {v}, {c}.",
|
91 |
+
"{s} {v}. No further action required.",
|
92 |
+
"Update: {s} {v}, {c}.",
|
93 |
+
"For reference, {s} {v}.",
|
94 |
+
"Also, {s} {v}.",
|
95 |
+
"{s} {v} as planned. {c}."
|
96 |
+
]
|
97 |
+
}
|
emotion_templates/sad.json
ADDED
@@ -0,0 +1,183 @@
1 |
+
{
|
2 |
+
"subjects": [
|
3 |
+
"That memory", "This feeling", "Her silence", "His voice", "The goodbye",
|
4 |
+
"That call", "This day", "My past", "That old photo", "The message",
|
5 |
+
"What she said", "What he didn’t say", "The look in their eyes", "That time of year",
|
6 |
+
"The room we used to share", "That place", "The last conversation", "That empty chair",
|
7 |
+
"That letter", "The moment I knew", "That one decision", "The song we both liked",
|
8 |
+
"My mistake", "Her absence", "The words I never said", "The things I lost",
|
9 |
+
"The silence after", "His last text", "My regret", "That broken promise",
|
10 |
+
"That question", "The goodbye hug", "The final message", "The void", "The shadow of it all",
|
11 |
+
"My own thoughts", "The long nights", "That walk alone", "Her last smile",
|
12 |
+
"That feeling of loss", "The unread message", "The sound of her name", "The forgotten laugh",
|
13 |
+
"The unspoken words", "This moment", "The fading photo", "The conversation that never came",
|
14 |
+
"The look I still remember", "The missed call", "This quiet", "What we never had",
|
15 |
+
"His empty side", "The empty inbox", "The things left unsaid", "The cracked frame",
|
16 |
+
"That empty promise", "The silence in the car", "The distance", "The space between us",
|
17 |
+
"The moment everything changed", "The unmade bed", "This weight", "The last gift",
|
18 |
+
"The things I should’ve done", "This cold air", "That long pause", "The truth I denied",
|
19 |
+
"The tears I hid", "The smile I faked", "The echo in my chest", "The memories flooding back",
|
20 |
+
"The half-empty cup", "That slow fade", "That goodbye text", "The waiting",
|
21 |
+
"That unanswered question", "This morning", "His coat", "The words on the page",
|
22 |
+
"That familiar place", "The night we stopped talking", "The unread letters",
|
23 |
+
"The unsent messages", "That photo on my phone", "That song I can’t skip",
|
24 |
+
"That cold seat", "The quiet kitchen", "The dim light", "That empty room",
|
25 |
+
"The thing I never said", "The end", "That hallway", "The couch we sat on",
|
26 |
+
"The words that hurt", "The moment we broke", "That one night", "The feeling of being left",
|
27 |
+
"What I carry", "That date", "The thing I remember most", "The after"
|
28 |
+
],
|
29 |
+
"verbs": [
|
30 |
+
"still hurts", "breaks me", "keeps me awake", "never fades", "hits me hard",
|
31 |
+
"leaves me numb", "crushed me", "stays with me", "burns inside", "tears me apart",
|
32 |
+
"weighs on me", "makes me cry", "won’t let go", "haunts me", "hurts deeply",
|
33 |
+
"shatters me", "lingers", "cuts deep", "took a part of me", "sits heavy in my chest",
|
34 |
+
"aches quietly", "feels endless", "never really left", "drowns me", "pulls me down",
|
35 |
+
"brings tears", "holds me still", "makes my chest heavy", "chills me",
|
36 |
+
"drains me", "fills me with sorrow", "leaves a mark", "empties me", "blurs everything",
|
37 |
+
"mutes my world", "dims my light", "echoes loudly", "steals my breath",
|
38 |
+
"closes my throat", "breaks the silence", "pushes me down", "holds me back",
|
39 |
+
"reminds me", "sticks like glue", "crawls under my skin", "hangs over me",
|
40 |
+
"wraps around me", "dampens everything", "stings", "feels like drowning",
|
41 |
+
"pulls at my heart", "sinks deep", "won’t leave", "burns behind my eyes",
|
42 |
+
"won’t let me forget", "drags on", "never ends", "feels permanent",
|
43 |
+
"presses on my chest", "washes over me", "never gets easier", "finds me in the quiet",
|
44 |
+
"catches me off guard", "unravels me", "pours in like rain", "sits in my lungs",
|
45 |
+
"breaks the morning", "bleeds into my days", "clouds my thoughts", "chokes me softly",
|
46 |
+
"grips my stomach", "twists my mind", "sinks into my bones", "grays out the world",
|
47 |
+
"softens my voice", "shadows my smile", "brings a wave", "holds my breath",
|
48 |
+
"dims the sun", "keeps repeating", "echoes through", "shuts me down",
|
49 |
+
"darkens the room", "muffles everything", "screams inside", "keeps me from moving",
|
50 |
+
"follows me", "closes in", "blurs my days", "fogs my memory", "empties the joy",
|
51 |
+
"clutches my chest", "keeps showing up", "presses into me", "sits beside me",
|
52 |
+
"floods my eyes", "weighs down my steps", "never softens", "never truly heals"
|
53 |
+
],
|
54 |
+
"interjections": [
|
55 |
+
"(sighs)... I still can't believe it.",
|
56 |
+
"Some things never leave (sniffs).",
|
57 |
+
"(mumbles) Honestly...",
|
58 |
+
"(exhales) It just hurts.",
|
59 |
+
"I wish I could say it (sighs)...",
|
60 |
+
"If only they knew (inhales)...",
|
61 |
+
"(clears throat) No one sees it.",
|
62 |
+
"Deep down (sniffs)...",
|
63 |
+
"Truth is — (sighs) I’m still there.",
|
64 |
+
"Even now... (exhales)",
|
65 |
+
"(mumbles) I never said it.",
|
66 |
+
"(sniffs)... It stays with me.",
|
67 |
+
"Sometimes I wonder (sighs)...",
|
68 |
+
"I still carry it (exhales).",
|
69 |
+
"To this day — (sniffs) it aches.",
|
70 |
+
"(sighs) I remember everything.",
|
71 |
+
"Some pain doesn’t show (mumbles)...",
|
72 |
+
"Even in silence (exhales)...",
|
73 |
+
"That moment still plays (sniffs).",
|
74 |
+
"And yet I smile (sighs)..."
|
75 |
+
],
|
76 |
+
|
77 |
+
"contexts": [
|
78 |
+
"(sighs)... and I can’t move on.",
|
79 |
+
"like nothing will ever be the same (exhales).",
|
80 |
+
"and I keep holding it in (sniffs)...",
|
81 |
+
"though no one sees it (sighs).",
|
82 |
+
"and I hide it so well (inhales)...",
|
83 |
+
"and I still feel it all (exhales).",
|
84 |
+
"and I smile like I’m fine (mumbles)...",
|
85 |
+
"though I’m breaking inside (sighs).",
|
86 |
+
"but I never talk about it (sniffs).",
|
87 |
+
"and it follows me everywhere (exhales)...",
|
88 |
+
"and I still carry it (sighs)...",
|
89 |
+
"like it just happened (inhales).",
|
90 |
+
"and I don’t know how to forget (exhales)...",
|
91 |
+
"and no one knows the weight (sighs).",
|
92 |
+
"even when I laugh (sniffs)...",
|
93 |
+
"and it feels so unfair (exhales).",
|
94 |
+
"and it still ruins me (sighs)...",
|
95 |
+
"and I feel it every day (inhales).",
|
96 |
+
"but I try to act normal (mumbles)...",
|
97 |
+
"like I’m supposed to be okay (sighs).",
|
98 |
+
"and I never got closure (exhales)...",
|
99 |
+
"and I miss them every day (sniffs).",
|
100 |
+
"but the pain stays the same (sighs)...",
|
101 |
+
"and it’s always right there (exhales).",
|
102 |
+
"and it floods me suddenly (sniffs)...",
|
103 |
+
"and it breaks me in pieces (sighs).",
|
104 |
+
"and I hear the echo of it (exhales)...",
|
105 |
+
"and I wish I said more (sighs).",
|
106 |
+
"and I regret staying quiet (mumbles)...",
|
107 |
+
"and I feel so small (sniffs).",
|
108 |
+
"and I can’t stop thinking about it (exhales)...",
|
109 |
+
"and I just collapse (sighs).",
|
110 |
+
"and I keep blaming myself (inhales)...",
|
111 |
+
"but it wasn’t even my fault (exhales).",
|
112 |
+
"and I replay it over and over (sighs)...",
|
113 |
+
"and it plays like a movie (sniffs).",
|
114 |
+
"but I never stop hurting (exhales)...",
|
115 |
+
"and I keep pretending (sighs).",
|
116 |
+
"and no one asks (mumbles)...",
|
117 |
+
"and I hate that it’s still here (exhales).",
|
118 |
+
"but I can’t let it go (sighs)...",
|
119 |
+
"even when I want to (sniffs).",
|
120 |
+
"and I can’t breathe sometimes (exhales)...",
|
121 |
+
"and my chest aches quietly (sighs).",
|
122 |
+
"and it keeps resurfacing (sniffs)...",
|
123 |
+
"and I’m so tired of it (exhales).",
|
124 |
+
"and I don’t know how to heal (sighs)...",
|
125 |
+
"and it knocks the wind out of me (exhales).",
|
126 |
+
"and I feel like a ghost (sighs)...",
|
127 |
+
"and I smile while breaking (mumbles).",
|
128 |
+
"and I don’t talk about it anymore (sniffs)...",
|
129 |
+
"and I wake up with it (exhales).",
|
130 |
+
"and it seeps into everything (sighs)...",
|
131 |
+
"and I lost myself there (exhales).",
|
132 |
+
"but no one notices (sighs)...",
|
133 |
+
"and I keep it locked up (sniffs).",
|
134 |
+
"and I carry it alone (exhales)...",
|
135 |
+
"and it’s always just under the surface (sighs).",
|
136 |
+
"but I’m still hurting (exhales)...",
|
137 |
+
"and it’s louder in the silence (sighs).",
|
138 |
+
"and I fall apart slowly (sniffs)...",
|
139 |
+
"and I can’t fix it (exhales).",
|
140 |
+
"and I still wait for them (sighs)...",
|
141 |
+
"and it all feels empty (exhales).",
|
142 |
+
"and it doesn’t make sense (sighs)...",
|
143 |
+
"but I can’t forget (sniffs).",
|
144 |
+
"and it makes me feel hollow (exhales)...",
|
145 |
+
"but I carry on anyway (sighs).",
|
146 |
+
"and I whisper it to no one (mumbles)...",
|
147 |
+
"and I miss who I was (exhales).",
|
148 |
+
"and I don’t think it’ll ever stop (sighs)...",
|
149 |
+
"and it comes back at night (sniffs).",
|
150 |
+
"and it makes me feel cold (exhales)...",
|
151 |
+
"but I still try to be okay (sighs).",
|
152 |
+
"and I hate talking about it (inhales)...",
|
153 |
+
"and I still cry when I’m alone (sniffs).",
|
154 |
+
"and I lose words (sighs)...",
|
155 |
+
"and I feel fragile (exhales).",
|
156 |
+
"and it’s exhausting (sighs)...",
|
157 |
+
"but I just nod and smile (mumbles).",
|
158 |
+
"and it feels like I’m fading (exhales)...",
|
159 |
+
"and it eats away at me (sighs).",
|
160 |
+
"but I say nothing (sniffs)...",
|
161 |
+
"and I hope they never see (sighs).",
|
162 |
+
"and I fake my way through (exhales)...",
|
163 |
+
"but I’m breaking slowly (sighs)."
|
164 |
+
],
|
165 |
+
|
166 |
+
"templates": [
|
167 |
+
"{s} {v}, {c}.",
|
168 |
+
"{i}, {s} {v}.",
|
169 |
+
"{s} {v}. {c}.",
|
170 |
+
"{s} {v} — {c}.",
|
171 |
+
"{c}. {s} {v}.",
|
172 |
+
"{s}. It {v}, {c}.",
|
173 |
+
"{s} still... {v}. {c}.",
|
174 |
+
"{s} {v}. And it never ends. {c}.",
|
175 |
+
"I try to forget, but {s} {v}, {c}.",
|
176 |
+
"{i} — {s} {v}, {c}.",
|
177 |
+
"{s} {v}. {i}. {c}.",
|
178 |
+
"The truth is, {s} {v}.",
|
179 |
+
"‘{s} {v}’ — and no one knew. {c}.",
|
180 |
+
"{s} just doesn’t go away. {c}.",
|
181 |
+
"Even now, {s} {v}, {c}."
|
182 |
+
]
|
183 |
+
}
|
emotion_templates/surprise.json
ADDED
@@ -0,0 +1,198 @@
{
  "subjects": [
    "That message", "This call", "The result", "What she said", "The announcement",
    "The coincidence", "The gift", "His answer", "Her reaction", "The timing",
    "That moment", "The photo", "This discovery", "The unexpected reply",
    "The look on his face", "The email", "That visit", "The headline",
    "The glitch", "This outcome", "The news", "That thing", "The realization",
    "Her arrival", "His silence", "That sound", "That noise", "The blink",
    "This miracle", "The coincidence of dates", "The surprise guest",
    "The reversed roles", "This twist", "The window", "The echo", "The light",
    "The turning point", "The expression", "The instant reply", "The notification",
    "This random thing", "The unexpected gift", "The open door", "That trigger",
    "The unlocked phone", "The old message", "That letter", "The voicemail",
    "The opportunity", "The moment of silence", "The call from nowhere",
    "The found photo", "That chance meeting", "What I saw", "The change in tone",
    "The open drawer", "The sound in the hallway", "The empty room",
    "The random response", "Her sudden shift", "The blinking lights",
    "What he did", "That conversation", "The unplanned moment", "The whisper",
    "That sudden feeling", "That name", "This vibe", "That expression",
    "The door creaking", "The unexpected question", "The win", "This strange timing",
    "The unplugged cable", "The room's silence", "The laugh", "The pause",
    "The sudden music", "The uninvited guest", "The news headline", "The page turn",
    "The camera flash", "The unknown voice", "The stray animal", "The forgotten date",
    "That coincidence", "The twisted sentence", "That eye contact", "This strange sound",
    "The falling object", "The flicker", "The joke", "That mood shift", "This noise",
    "The silent phone", "That little thing", "The hidden truth", "The strange smile"
  ],
  "verbs": [
    "shocked me", "made me freeze", "caught me off guard", "left me speechless",
    "blew my mind", "came out of nowhere", "was totally unexpected", "made my jaw drop",
    "felt surreal", "threw me off", "turned everything around", "changed the whole vibe",
    "reset the tone", "hit me like lightning", "flipped the moment", "was unreal",
    "was bizarre", "seemed impossible", "was both funny and scary", "made me pause",
    "rewired my brain", "was absurd", "challenged everything I thought", "was chaotic",
    "made no sense", "was random", "came with no warning", "felt like a dream",
    "sparked confusion", "was like fiction", "shattered expectations", "reversed my mood",
    "surprised everyone", "felt impossible", "disrupted everything", "popped up suddenly",
    "turned heads", "was pure surprise", "hit like a flash", "flipped the whole thing",
    "triggered laughter", "was hilariously off", "felt like a glitch", "reset the energy",
    "dropped like a bomb", "silenced the room", "was a twist", "stood out completely",
    "just happened", "was nonsense", "was oddly fitting", "was genius", "twisted my logic",
    "shook everyone", "sparked wild reactions", "caught all attention", "stunned me",
    "was beautifully weird", "was madness", "reversed all logic", "was insane",
    "left us reeling", "had no explanation", "felt electric", "hit mid-sentence",
    "made me sit up", "just flipped it", "was comedic", "was magical", "reset my brain",
    "was silly and scary", "snapped everything", "twisted expectations", "felt improvised",
    "landed unexpectedly", "snuck up on us", "was so weird", "broke the silence",
    "sparked awe", "felt foreign", "felt glitched", "confused everyone", "was strangely real",
    "made me react instantly", "was unexpected chaos", "took a wild turn", "made no sense at all",
    "came suddenly", "hit without logic", "felt reversed", "cut through everything",
    "changed the pace", "felt timed perfectly", "felt unscripted", "just crashed in",
    "was the last thing I thought", "bent the moment", "came sharply", "felt like a prank"
  ],
  "interjections": [
    "(gasps) Honestly?!",
    "No way — (laughs)!",
    "Wait... what? (gasps)",
    "(inhales) Seriously—",
    "Oh wow! (whistles)",
    "You're kidding! (laughs)",
    "(gasps) I can't believe it!",
    "What just happened? (chuckle)",
    "Believe it or not — (gasps)",
    "Suddenly... (laughs)",
    "(whistles) Guess what!",
    "Right then — (gasps) boom!",
    "Out of nowhere (laughs)!",
    "Just like that... (gasps)",
    "Then — bam! (chuckle)",
    "(gasps) Without warning!",
    "For real? (laughs)",
    "You won't believe this! (gasps)",
    "Hold up— (whistles)!",
    "(claps) That just happened!"
  ],

  "contexts": [
    "(gasps)... and I didn’t know what to say!",
    "and everyone froze — (exhales).",
    "and I had to laugh (laughs)!",
    "and I still don’t get it... (chuckle)",
    "and it felt surreal (whistles).",
    "and no one moved — (inhales).",
    "(gasps) and I just... gasped again.",
    "and I was speechless (exhales)...",
    "and I blinked like five times (gasps).",
    "and it echoed in my head — (whistles)",
    "and everyone just stared (inhales)...",
    "and I sat there stunned (gasps)!",
    "and the room went silent — (exhales).",
    "and we all looked around (chuckle)...",
    "and I looked twice (gasps)!",
    "and I was like “what?!” (laughs)",
    "and I couldn’t believe it (exhales)...",
    "and it stuck with me (whistles).",
    "and I still think about it (gasps)...",
    "and I just shook my head (laughs).",
    "and it broke the tension — (chuckle)!",
    "and it threw off everything (exhales)...",
    "and I said “no way!” (gasps)",
    "and we all cracked up (laughs)!",
    "and the vibe was gone — (inhales).",
    "and it changed the day (whistles)...",
    "and I had to rewind (gasps)!",
    "and my jaw dropped — (exhales).",
    "and I froze (gasps)...",
    "and we all lost it (laughs)!",
    "and it messed me up (chuckle)...",
    "and I just stood still (inhales).",
    "and I had no words (gasps)...",
    "and we laughed for a minute (laughs).",
    "and it felt like a dream — (whistles)...",
    "and I blinked in shock (gasps)!",
    "and we all paused (exhales)...",
    "and I just stared (gasps)...",
    "and I couldn’t speak (inhales).",
    "and the shift was instant (gasps)!",
    "and it was hilarious (laughs).",
    "and my brain glitched (chuckle)...",
    "and everyone panicked slightly (gasps).",
    "and I had chills (inhales)...",
    "and it just... happened (exhales).",
    "and it triggered everything (gasps)!",
    "and I remember the silence (whistles)...",
    "and someone whispered “what just happened?” (gasps)",
    "and we all froze (inhales)...",
    "and we burst out laughing (laughs)!",
    "and it flipped the room — (whistles)!",
    "and I had to sit (gasps)...",
    "and it made no sense (chuckle).",
    "and I gasped loudly (gasps)!",
    "and everyone blinked (inhales)...",
    "and I had to walk out (exhales).",
    "and the energy just changed (whistles)...",
    "and it felt like fiction (gasps).",
    "and the moment stuck — (exhales)...",
    "and I tried to process it (chuckle).",
    "and I said nothing (gasps)...",
    "and the response was just shock (inhales).",
    "and I heard my own heartbeat (exhales)...",
    "and we just blinked (gasps).",
    "and it rewrote the mood (whistles)...",
    "and it still plays in my head (gasps)...",
    "and it made the silence louder (exhales).",
    "and I had to sit down (inhales)...",
    "and it just hit me (gasps)!",
    "and the air shifted — (whistles)...",
    "and it startled me (gasps).",
    "and everyone had the same face (chuckle)...",
    "and I was out of words (exhales).",
    "and someone gasped (gasps)!",
    "and the tension cracked (laughs)...",
    "and it made me laugh hard (laughs)!",
    "and we all paused and stared (inhales)...",
    "and I just looked down (exhales).",
    "and we replayed it in our heads (whistles)...",
    "and no one said a word (gasps).",
    "and the timing was wild — (laughs)!",
    "and it was too real (inhales)...",
    "and we couldn’t stop laughing (laughs)!",
    "and I shook my head slowly (chuckle)...",
    "and it crashed everything (gasps).",
    "and I wanted to rewind it (exhales)...",
    "and it threw the moment off (inhales).",
    "and we sat in silence (gasps)...",
    "and my eyes went wide (exhales).",
    "and it stunned everyone (gasps)!",
    "and it landed like a bomb (laughs).",
    "and I burst out mid‑thought (chuckle)...",
    "and it shattered logic (gasps).",
    "and I couldn’t unsee it (exhales)...",
    "and the shift was visible (whistles).",
    "and we blinked hard (gasps)..."
  ],

  "templates": [
    "{s} {v}, {c}.",
    "{i}! {s} {v}.",
    "{s} {v}. {c}.",
    "{s} {v} — {c}.",
    "{c}. {s} {v}.",
    "{s}. It {v}, {c}.",
    "{s} just... {v}. {c}.",
    "You won’t believe this: {s} {v}, {c}!",
    "And then, {s} {v}. {c}.",
    "‘{s} {v}?’ That’s what happened. {c}.",
    "{i}! — {s} {v}, {c}!",
    "Everyone was quiet. Then {s} {v}, {c}!",
    "{s} {v}. {i}! {c}.",
    "{s} {v}. Can you imagine?! {c}.",
    "So random: {s} {v}, {c}.",
    "Boom! {s} {v}, {c}!",
    "Suddenly, {s} {v}! {c}.",
    "{s} {v} like a lightning bolt! {c}.",
    "I blinked and {s} {v}! {c}.",
    "Out of thin air — {s} {v}! {c}."
  ]
}
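As a quick illustration of how these slot lists are consumed: each synthetic phrase is built by sampling one subject {s}, verb {v}, context {c} and interjection {i} and substituting them into a template string. A minimal sketch follows (the real generator below additionally de-duplicates phrases, limits sound tags to one per phrase and cleans up punctuation):

# Hypothetical standalone example; assumes the JSON file above is on disk.
import json
import random

with open("emotion_templates/surprise.json", encoding="utf-8") as f:
    data = json.load(f)

template = random.choice(data["templates"])      # e.g. "{i}! {s} {v}."
phrase = template.format(
    s=random.choice(data["subjects"]),           # e.g. "That message"
    v=random.choice(data["verbs"]),              # e.g. "caught me off guard"
    c=random.choice(data["contexts"]),           # e.g. "and I had to laugh (laughs)!"
    i=random.choice(data["interjections"]),      # e.g. "(gasps) Honestly?!"
)
print(phrase)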
generate_emotion_texts_dataset.py
ADDED
@@ -0,0 +1,137 @@
import json
import random
import pandas as pd
import re
from datetime import timedelta
from pathlib import Path

# === Template loading ===
def load_templates_json(templates_dir, emotion):
    path = Path(templates_dir) / f"{emotion}.json"
    if not path.exists():
        raise FileNotFoundError(f"Template for emotion '{emotion}' not found: {path}")
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

# === Text generation with seeding and duplicate protection ===
def generate_emotion_batch(n, template_data, seed=None):
    if seed is not None:
        random.seed(seed)

    subjects = template_data["subjects"]
    verbs = template_data["verbs"]
    contexts = template_data["contexts"]
    interjections = template_data.get("interjections", [""])
    templates = template_data["templates"]

    # Sound tags supported by DIA-TTS
    dia_tags = {
        "(laughs)", "(clears throat)", "(sighs)", "(gasps)", "(coughs)",
        "(singing)", "(sings)", "(mumbles)", "(beep)", "(groans)", "(sniffs)",
        "(claps)", "(screams)", "(inhales)", "(exhales)", "(applause)",
        "(burps)", "(humming)", "(sneezes)", "(chuckle)", "(whistles)"
    }

    def has_tag(text): return any(tag in text for tag in dia_tags)
    def remove_tags(text):
        for tag in dia_tags:
            text = text.replace(tag, "")
        return text.strip()

    phrases, attempts = set(), 0
    max_attempts = n * 50

    while len(phrases) < n and attempts < max_attempts:
        s, v = random.choice(subjects), random.choice(verbs)
        c, i = random.choice(contexts), random.choice(interjections)
        t = random.choice(templates)

        # ▸ Allow at most one sound tag per phrase
        if has_tag(i) and has_tag(c):
            if random.random() < .5:
                c = remove_tags(c)
            else:
                i = remove_tags(i)

        phrase = t.format(s=s, v=v, c=c, i=i)

        # --- Cleanup that keeps ellipses intact ---------------------------
        # 1) drop spaces before punctuation
        phrase = re.sub(r"\s+([,.!?])", r"\1", phrase)
        # 2) collapse a double period that is NOT part of an ellipsis
        phrase = re.sub(r"(?<!\.)\.\.(?!\.)", ".", phrase)
        # 3) insert a space when a word directly follows a sound tag
        phrase = re.sub(r"\)(?=\w)", ") ", phrase)
        # 4) collapse repeated spaces and trim the ends
        phrase = re.sub(r"\s{2,}", " ", phrase).strip()
        # ------------------------------------------------------------------

        if phrase not in phrases:
            phrases.add(phrase)
        attempts += 1

    if len(phrases) < n:
        print(f"⚠️ Only {len(phrases)} unique phrases out of the {n} requested — the template pool may be exhausted.")

    return list(phrases)

# === Dummy timestamp generation ===
def generate_dummy_timestamps(n):
    base_time, result = timedelta(), []
    for idx in range(n):
        start = base_time + timedelta(seconds=idx * 6)
        end = start + timedelta(seconds=5)
        result.append((
            str(start).split(".")[0] + ",000",
            str(end).split(".")[0] + ",000"
        ))
    return result

# === Final assembly and CSV export ===
def create_emotion_csv(template_path, emotion_label, out_file, n=1000, seed=None):
    data = load_templates_json(template_path, emotion_label)
    phrases = generate_emotion_batch(n, data, seed)
    timeline = generate_dummy_timestamps(n)

    emotions = ["neutral", "happy", "sad", "anger", "surprise", "disgust", "fear"]
    label_mask = {e: float(e == emotion_label) for e in emotions}

    df = pd.DataFrame({
        "video_name": [f"dia_{emotion_label}_utt{i}_synt" for i in range(n)],
        "start_time": [s for s, _ in timeline],
        "end_time"  : [e for _, e in timeline],
        "sentiment" : [0] * n,
        **{e: [label_mask[e]] * n for e in emotions},
        "text"      : phrases
    })

    df.to_csv(out_file, index=False)
    print(f"✅ Saved {len(df)} rows → {out_file}")

    # --- Duplicate check ---
    dupes = df[df.duplicated("text", keep=False)]
    if not dupes.empty:
        dupe_file = Path(out_file).with_name(f"duplicates_{emotion_label}.csv")
        dupes.to_csv(dupe_file, index=False)
        print(f"⚠️ Found {len(dupes)} duplicates → {dupe_file}")
    else:
        print("✅ No duplicates.")

# === Entry point ===
if __name__ == "__main__":
    emotion_config = {
        "anger": 3600,
        "disgust": 4438,
        "fear": 4441,
        "happy": 2966,
        "sad": 4026,
        "surprise": 3504
    }

    seed, template_path, out_dir = 42, "emotion_templates", "synthetic_data"
    Path(out_dir).mkdir(parents=True, exist_ok=True)

    for emotion, n in emotion_config.items():
        out_csv = Path(out_dir) / f"meld_synthetic_{emotion}_{n}.csv"
        print(f"\n🔄 Generating: {emotion} ({n} phrases)")
        create_emotion_csv(template_path, emotion, str(out_csv), n, seed)
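To make the cleanup block above concrete, here is roughly what the four substitutions do to one raw phrase; the input string is illustrative only, since the exact text depends on which slots were sampled:

import re

raw = "(gasps) Honestly?! That message shocked me , and I froze (gasps).."
step1 = re.sub(r"\s+([,.!?])", r"\1", raw)        # removes the space before ","
step2 = re.sub(r"(?<!\.)\.\.(?!\.)", ".", step1)  # ".." -> ".", real ellipses untouched
step3 = re.sub(r"\)(?=\w)", ") ", step2)          # adds a space after ")" glued to a word
clean = re.sub(r"\s{2,}", " ", step3).strip()     # collapses double spaces, trims ends
print(clean)  # "(gasps) Honestly?! That message shocked me, and I froze (gasps)."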
generate_synthetic_dataset.py
ADDED
@@ -0,0 +1,71 @@
import os
import logging
import time
import pandas as pd
from multiprocessing import Process
from synthetic_utils.dia_tts_wrapper import DiaTTSWrapper


def process_chunk(chunk_df, emotion, wav_dir, device, chunk_id):
    tts = DiaTTSWrapper(device=device)
    for idx, row in chunk_df.iterrows():
        text = row["text"]
        video_name = row.get("video_name", f"{emotion}_{chunk_id}_{idx}")
        filename_prefix = video_name

        try:
            result = tts.generate_and_save_audio(
                text=text,
                out_dir=wav_dir,
                filename_prefix=filename_prefix,
                use_timestamp=False,
                skip_if_exists=True,
                max_trim_duration=10.0
            )
            if result is None:
                logging.info(f"[{emotion}] ⏭️ Skipped: {filename_prefix}.wav")
            else:
                logging.info(f"[{emotion}] ✔ {filename_prefix}.wav")
        except Exception as e:
            logging.error(f"[{emotion}] ❌ Error: {filename_prefix} — {e}")


def generate_from_emotion_csv(
    csv_path: str,
    emotion: str,
    output_dir: str,
    device: str = "cuda",
    max_samples: int = None,
    num_processes: int = 1
):
    out_dir = os.path.join(output_dir, emotion)
    wav_dir = os.path.join(out_dir, "wavs")
    os.makedirs(wav_dir, exist_ok=True)

    logging.info(f"🎙️ Emotion: '{emotion}' | CSV: {csv_path}")
    logging.info(f"📥 Saving to: {wav_dir}")

    df = pd.read_csv(csv_path)
    if max_samples is not None:
        df = df.sample(n=max_samples)

    # Split the dataframe into one chunk per worker process; any remainder
    # rows are appended to the last chunk.
    chunk_size = len(df) // num_processes
    chunks = [df.iloc[i*chunk_size : (i+1)*chunk_size] for i in range(num_processes)]

    remainder = len(df) % num_processes
    if remainder > 0:
        chunks[-1] = pd.concat([chunks[-1], df.iloc[-remainder:]])

    total_start = time.time()

    processes = []
    for i, chunk in enumerate(chunks):
        p = Process(target=process_chunk, args=(chunk, emotion, wav_dir, device, i))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

    total_elapsed = time.time() - total_start
    logging.info(f"✅ Emotion '{emotion}' finished | chunks: {num_processes} | ⏱️ {total_elapsed:.1f} s\n")
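A typical invocation, assuming the CSVs produced by generate_emotion_texts_dataset.py are in synthetic_data/; the logging.basicConfig call and the sample cap are only part of this sketch:

import logging
from generate_synthetic_dataset import generate_from_emotion_csv

logging.basicConfig(level=logging.INFO)

# Synthesize WAVs for one emotion, split across two worker processes;
# each worker builds its own DiaTTSWrapper and handles its chunk independently.
generate_from_emotion_csv(
    csv_path="synthetic_data/meld_synthetic_surprise_3504.csv",
    emotion="surprise",
    output_dir="synthetic_data",
    device="cuda",
    max_samples=100,     # optional cap, useful for a quick smoke test
    num_processes=2,
)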
main.py
ADDED
@@ -0,0 +1,119 @@
# train.py
# coding: utf-8
import logging
import os
import shutil
import datetime
import whisper
import toml
# os.environ["HF_HOME"] = "models"

from utils.config_loader import ConfigLoader
from utils.logger_setup import setup_logger
from utils.search_utils import greedy_search, exhaustive_search
from training.train_utils import (
    make_dataset_and_loader,
    train_once
)
from data_loading.feature_extractor import PretrainedAudioEmbeddingExtractor, PretrainedTextEmbeddingExtractor

def main():

    # Load the config
    base_config = ConfigLoader("config.toml")

    model_name = base_config.model_name.replace("/", "_").replace(" ", "_").lower()
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    results_dir = f"results_{model_name}_{timestamp}"
    os.makedirs(results_dir, exist_ok=True)

    epochlog_dir = os.path.join(results_dir, "metrics_by_epoch")
    os.makedirs(epochlog_dir, exist_ok=True)

    # Set up logging
    log_file = os.path.join(results_dir, "session_log.txt")
    setup_logger(logging.DEBUG, log_file=log_file)

    # Show the loaded config
    base_config.show_config()

    shutil.copy("config.toml", os.path.join(results_dir, "config_copy.toml"))
    # File the greedy search writes its parameter overrides to
    overrides_file = os.path.join(results_dir, "overrides.txt")
    csv_prefix = os.path.join(epochlog_dir, "metrics_epochlog")

    audio_feature_extractor = PretrainedAudioEmbeddingExtractor(base_config)
    text_feature_extractor = PretrainedTextEmbeddingExtractor(base_config)

    # Initialize the Whisper model once
    logging.info(f"Initializing Whisper: model={base_config.whisper_model}, device={base_config.whisper_device}")
    whisper_model = whisper.load_model(base_config.whisper_model, device=base_config.whisper_device)

    # Build datasets/loaders
    # Shared train loader
    _, train_loader = make_dataset_and_loader(base_config, "train", audio_feature_extractor, text_feature_extractor, whisper_model)

    # Separate dev/test loaders per dataset
    dev_loaders = []
    test_loaders = []

    for dataset_name in base_config.datasets:
        _, dev_loader = make_dataset_and_loader(base_config, "dev", audio_feature_extractor, text_feature_extractor, whisper_model, only_dataset=dataset_name)
        _, test_loader = make_dataset_and_loader(base_config, "test", audio_feature_extractor, text_feature_extractor, whisper_model, only_dataset=dataset_name)

        dev_loaders.append((dataset_name, dev_loader))
        test_loaders.append((dataset_name, test_loader))

    if base_config.prepare_only:
        logging.info("== prepare_only mode: data preparation only, no training ==")
        return

    search_config = toml.load("search_params.toml")
    param_grid = dict(search_config["grid"])
    default_values = dict(search_config["defaults"])

    if base_config.search_type == "greedy":
        greedy_search(
            base_config    = base_config,
            train_loader   = train_loader,
            dev_loader     = dev_loaders,
            test_loader    = test_loaders,
            train_fn       = train_once,
            overrides_file = overrides_file,
            param_grid     = param_grid,
            default_values = default_values,
            csv_prefix     = csv_prefix
        )

    elif base_config.search_type == "exhaustive":
        exhaustive_search(
            base_config    = base_config,
            train_loader   = train_loader,
            dev_loader     = dev_loaders,
            test_loader    = test_loaders,
            train_fn       = train_once,
            overrides_file = overrides_file,
            param_grid     = param_grid,
            csv_prefix     = csv_prefix
        )

    elif base_config.search_type == "none":
        logging.info("== Single training run (no parameter search) ==")

        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        csv_file_path = f"{csv_prefix}_single_{timestamp}.csv"

        train_once(
            config           = base_config,
            train_loader     = train_loader,
            dev_loaders      = dev_loaders,
            test_loaders     = test_loaders,
            metrics_csv_path = csv_file_path
        )

    else:
        raise ValueError(f"⛔️ Invalid search_type value in the config: '{base_config.search_type}'. Use 'greedy', 'exhaustive' or 'none'.")


if __name__ == "__main__":
    main()
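main() expects search_params.toml to provide a [grid] table (candidate values per hyper-parameter) and a [defaults] table (the starting point for the greedy search). A sketch of the structure it reads, with placeholder parameter names, since the real keys must match whatever train_once and ConfigLoader consume:

# Hypothetical shape of toml.load("search_params.toml"); parameter names are examples only.
search_config = {
    "grid":     {"hidden_dim": [256, 512], "dropout": [0.0, 0.1, 0.2]},
    "defaults": {"hidden_dim": 512, "dropout": 0.1},
}
param_grid     = dict(search_config["grid"])      # one list of candidates per parameter
default_values = dict(search_config["defaults"])  # baseline used by the greedy search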
models/__init__.py
ADDED
File without changes
|
models/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (134 Bytes). View file
|
|
models/__pycache__/help_layers.cpython-310.pyc
ADDED
Binary file (15.3 kB). View file
|
|
models/__pycache__/models.cpython-310.pyc
ADDED
Binary file (28.9 kB). View file
|
|
models/help_layers.py
ADDED
@@ -0,0 +1,528 @@
1 |
+
# coding: utf-8
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import torch.nn.init as init
|
6 |
+
import numpy as np
|
7 |
+
import math
|
8 |
+
from torch.nn.functional import silu
|
9 |
+
from torch.nn.functional import softplus
|
10 |
+
from einops import rearrange, einsum
|
11 |
+
from torch import Tensor
|
12 |
+
from torch_geometric.nn import GATConv, RGCNConv, TransformerConv
|
13 |
+
|
14 |
+
class PositionWiseFeedForward(nn.Module):
|
15 |
+
def __init__(self, input_dim, hidden_dim, dropout=0.1):
|
16 |
+
super().__init__()
|
17 |
+
self.layer_1 = nn.Linear(input_dim, hidden_dim)
|
18 |
+
self.layer_2 = nn.Linear(hidden_dim, input_dim)
|
19 |
+
self.dropout = nn.Dropout(dropout)
|
20 |
+
|
21 |
+
def forward(self, x):
|
22 |
+
x = self.layer_1(x)
|
23 |
+
x = F.gelu(x) # Более плавная активация
|
24 |
+
x = self.dropout(x)
|
25 |
+
return self.layer_2(x)
|
26 |
+
|
27 |
+
|
28 |
+
class AddAndNorm(nn.Module):
|
29 |
+
def __init__(self, input_dim, dropout=0.1):
|
30 |
+
super().__init__()
|
31 |
+
self.norm = nn.LayerNorm(input_dim)
|
32 |
+
self.dropout = nn.Dropout(dropout)
|
33 |
+
|
34 |
+
def forward(self, x, residual):
|
35 |
+
return self.norm(x + self.dropout(residual))
|
36 |
+
|
37 |
+
|
38 |
+
class PositionalEncoding(nn.Module):
|
39 |
+
def __init__(self, d_model, dropout=0.1, max_len=5000):
|
40 |
+
super().__init__()
|
41 |
+
self.dropout = nn.Dropout(p=dropout)
|
42 |
+
|
43 |
+
position = torch.arange(max_len).unsqueeze(1)
|
44 |
+
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
|
45 |
+
pe = torch.zeros(max_len, d_model)
|
46 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
47 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
48 |
+
|
49 |
+
self.register_buffer("pe", pe)
|
50 |
+
|
51 |
+
def forward(self, x):
|
52 |
+
x = x + self.pe[: x.size(1)].detach() # Отключаем градиенты
|
53 |
+
return self.dropout(x)
|
54 |
+
|
55 |
+
|
56 |
+
class TransformerEncoderLayer(nn.Module):
|
57 |
+
def __init__(self, input_dim, num_heads, dropout=0.1, positional_encoding=False):
|
58 |
+
super().__init__()
|
59 |
+
self.input_dim = input_dim
|
60 |
+
self.self_attention = nn.MultiheadAttention(input_dim, num_heads, dropout=dropout, batch_first=True)
|
61 |
+
# self.self_attention = MHA(
|
62 |
+
# embed_dim=input_dim,
|
63 |
+
# num_heads=num_heads,
|
64 |
+
# dropout=dropout,
|
65 |
+
# # bias=True,
|
66 |
+
# use_flash_attn=True
|
67 |
+
# )
|
68 |
+
self.feed_forward = PositionWiseFeedForward(input_dim, input_dim, dropout=dropout)
|
69 |
+
self.add_norm_after_attention = AddAndNorm(input_dim, dropout=dropout)
|
70 |
+
self.add_norm_after_ff = AddAndNorm(input_dim, dropout=dropout)
|
71 |
+
self.positional_encoding = PositionalEncoding(input_dim) if positional_encoding else None
|
72 |
+
|
73 |
+
def forward(self, key, value, query):
|
74 |
+
if self.positional_encoding:
|
75 |
+
key = self.positional_encoding(key)
|
76 |
+
value = self.positional_encoding(value)
|
77 |
+
query = self.positional_encoding(query)
|
78 |
+
|
79 |
+
attn_output, _ = self.self_attention(query, key, value, need_weights=False)
|
80 |
+
# attn_output = self.self_attention(query, key, value)
|
81 |
+
|
82 |
+
x = self.add_norm_after_attention(attn_output, query)
|
83 |
+
|
84 |
+
ff_output = self.feed_forward(x)
|
85 |
+
x = self.add_norm_after_ff(ff_output, x)
|
86 |
+
|
87 |
+
return x
|
88 |
+
|
89 |
+
class GAL(nn.Module):
|
90 |
+
def __init__(self, input_dim_F1, input_dim_F2, gated_dim, dropout_rate):
|
91 |
+
super(GAL, self).__init__()
|
92 |
+
|
93 |
+
self.WF1 = nn.Parameter(torch.Tensor(input_dim_F1, gated_dim))
|
94 |
+
self.WF2 = nn.Parameter(torch.Tensor(input_dim_F2, gated_dim))
|
95 |
+
|
96 |
+
init.xavier_uniform_(self.WF1)
|
97 |
+
init.xavier_uniform_(self.WF2)
|
98 |
+
|
99 |
+
dim_size_f = input_dim_F1 + input_dim_F2
|
100 |
+
|
101 |
+
self.WF = nn.Parameter(torch.Tensor(dim_size_f, gated_dim))
|
102 |
+
|
103 |
+
init.xavier_uniform_(self.WF)
|
104 |
+
|
105 |
+
self.dropout = nn.Dropout(dropout_rate)
|
106 |
+
|
107 |
+
def forward(self, f1, f2):
|
108 |
+
|
109 |
+
h_f1 = self.dropout(torch.tanh(torch.matmul(f1, self.WF1)))
|
110 |
+
h_f2 = self.dropout(torch.tanh(torch.matmul(f2, self.WF2)))
|
111 |
+
# print(h_f1.shape, h_f2.shape, self.WF.shape, torch.cat([f1, f2], dim=1).shape)
|
112 |
+
z_f = torch.softmax(self.dropout(torch.matmul(torch.cat([f1, f2], dim=1), self.WF)), dim=1)
|
113 |
+
h_f = z_f*h_f1 + (1 - z_f)*h_f2
|
114 |
+
return h_f
|
115 |
+
|
116 |
+
class GraphFusionLayer(nn.Module):
|
117 |
+
def __init__(self, hidden_dim, dropout=0.0, heads=2, out_mean=True):
|
118 |
+
super().__init__()
|
119 |
+
self.out_mean = out_mean
|
120 |
+
# # Проекционные слои для признаков
|
121 |
+
self.proj_audio = nn.Sequential(
|
122 |
+
nn.Linear(hidden_dim, hidden_dim),
|
123 |
+
nn.LayerNorm(hidden_dim),
|
124 |
+
nn.Dropout(dropout)
|
125 |
+
)
|
126 |
+
self.proj_text = nn.Sequential(
|
127 |
+
nn.Linear(hidden_dim, hidden_dim),
|
128 |
+
nn.LayerNorm(hidden_dim),
|
129 |
+
nn.Dropout(dropout)
|
130 |
+
)
|
131 |
+
|
132 |
+
# Графовые слои
|
133 |
+
self.gat1 = GATConv(hidden_dim, hidden_dim, heads=heads)
|
134 |
+
self.gat2 = GATConv(hidden_dim*heads, hidden_dim)
|
135 |
+
|
136 |
+
# Финальная проекция
|
137 |
+
self.fc = nn.Sequential(
|
138 |
+
nn.Linear(hidden_dim, hidden_dim),
|
139 |
+
nn.LayerNorm(hidden_dim),
|
140 |
+
nn.Dropout(dropout)
|
141 |
+
)
|
142 |
+
|
143 |
+
def build_complete_graph(self, num_nodes):
|
144 |
+
# Создаем полный граф (каждый узел соединен со всеми)
|
145 |
+
edge_index = []
|
146 |
+
for i in range(num_nodes):
|
147 |
+
for j in range(num_nodes):
|
148 |
+
if i != j:
|
149 |
+
edge_index.append([i, j])
|
150 |
+
return torch.tensor(edge_index).t().contiguous()
|
151 |
+
|
152 |
+
def forward(self, audio_stats, text_stats):
|
153 |
+
"""
|
154 |
+
audio_stats: [batch_size, hidden_dim]
|
155 |
+
text_stats: [batch_size, hidden_dim]
|
156 |
+
"""
|
157 |
+
batch_size = audio_stats.size(0)
|
158 |
+
|
159 |
+
# Проекция признаков
|
160 |
+
x_audio = F.relu(self.proj_audio(audio_stats)) # [batch_size, hidden_dim]
|
161 |
+
x_text = F.relu(self.proj_text(text_stats)) # [batch_size, hidden_dim]
|
162 |
+
|
163 |
+
# Объединение узлов (аудио и текст попеременно)
|
164 |
+
nodes = torch.stack([x_audio, x_text], dim=1) # [batch_size, 2, hidden_dim]
|
165 |
+
nodes = nodes.view(-1, nodes.size(-1)) # [batch_size*2, hidden_dim]
|
166 |
+
|
167 |
+
# Построение графа (полный граф для каждого элемента батча)
|
168 |
+
edge_index = self.build_complete_graph(2) # Граф для одной пары аудио-текст
|
169 |
+
edge_index = edge_index.to(audio_stats.device)
|
170 |
+
|
171 |
+
# Применение GAT
|
172 |
+
x = F.relu(self.gat1(nodes, edge_index))
|
173 |
+
x = self.gat2(x, edge_index)
|
174 |
+
|
175 |
+
# Разделяем обратно аудио и текст
|
176 |
+
x = x.view(batch_size, 2, -1) # [batch_size, 2, hidden_dim]
|
177 |
+
|
178 |
+
if self.out_mean:
|
179 |
+
# Усреднение по модальностям
|
180 |
+
fused = torch.mean(x, dim=1) # [batch_size, hidden_dim]
|
181 |
+
|
182 |
+
return self.fc(fused)
|
183 |
+
else:
|
184 |
+
return x
|
185 |
+
|
186 |
+
class GraphFusionLayerAtt(nn.Module):
|
187 |
+
def __init__(self, hidden_dim, heads=2):
|
188 |
+
super().__init__()
|
189 |
+
# Проекционные слои для признаков
|
190 |
+
self.proj_audio = nn.Linear(hidden_dim, hidden_dim)
|
191 |
+
self.proj_text = nn.Linear(hidden_dim, hidden_dim)
|
192 |
+
|
193 |
+
# Графовые слои
|
194 |
+
self.gat1 = GATConv(hidden_dim, hidden_dim, heads=heads)
|
195 |
+
self.gat2 = GATConv(hidden_dim*heads, hidden_dim)
|
196 |
+
|
197 |
+
self.attention_fusion = nn.Linear(hidden_dim, 1)
|
198 |
+
|
199 |
+
# Финальная проекция
|
200 |
+
self.fc = nn.Linear(hidden_dim, hidden_dim)
|
201 |
+
|
202 |
+
def build_complete_graph(self, num_nodes):
|
203 |
+
# Создаем полный граф (каждый узел соединен со всеми)
|
204 |
+
edge_index = []
|
205 |
+
for i in range(num_nodes):
|
206 |
+
for j in range(num_nodes):
|
207 |
+
if i != j:
|
208 |
+
edge_index.append([i, j])
|
209 |
+
return torch.tensor(edge_index).t().contiguous()
|
210 |
+
|
211 |
+
def forward(self, audio_stats, text_stats):
|
212 |
+
"""
|
213 |
+
audio_stats: [batch_size, hidden_dim]
|
214 |
+
text_stats: [batch_size, hidden_dim]
|
215 |
+
"""
|
216 |
+
batch_size = audio_stats.size(0)
|
217 |
+
|
218 |
+
# Проекция признаков
|
219 |
+
x_audio = F.relu(self.proj_audio(audio_stats)) # [batch_size, hidden_dim]
|
220 |
+
x_text = F.relu(self.proj_text(text_stats)) # [batch_size, hidden_dim]
|
221 |
+
|
222 |
+
# Объединение узлов (аудио и текст попеременно)
|
223 |
+
nodes = torch.stack([x_audio, x_text], dim=1) # [batch_size, 2, hidden_dim]
|
224 |
+
nodes = nodes.view(-1, nodes.size(-1)) # [batch_size*2, hidden_dim]
|
225 |
+
|
226 |
+
# Построение графа (полный граф для каждого элемента батча)
|
227 |
+
edge_index = self.build_complete_graph(2) # Граф для одной пары аудио-текст
|
228 |
+
edge_index = edge_index.to(audio_stats.device)
|
229 |
+
|
230 |
+
# Применение GAT
|
231 |
+
x = F.relu(self.gat1(nodes, edge_index))
|
232 |
+
x = self.gat2(x, edge_index)
|
233 |
+
|
234 |
+
# Разделяем обратно аудио и текст
|
235 |
+
x = x.view(batch_size, 2, -1) # [batch_size, 2, hidden_dim]
|
236 |
+
|
237 |
+
# Усреднение по модальностям
|
238 |
+
# fused = torch.mean(x, dim=1) # [batch_size, hidden_dim]
|
239 |
+
|
240 |
+
weights = F.softmax(self.attention_fusion(x), dim=1)
|
241 |
+
fused = torch.sum(weights * x, dim=1) # [batch_size, hidden_dim]
|
242 |
+
|
243 |
+
return self.fc(fused)
|
244 |
+
|
245 |
+
# Full code see https://github.com/leson502/CORECT_EMNLP2023/tree/master/corect/model
|
246 |
+
|
247 |
+
class GNN(nn.Module):
|
248 |
+
def __init__(self, g_dim, h1_dim, h2_dim, num_relations, num_modals, gcn_conv, use_graph_transformer, graph_transformer_nheads):
|
249 |
+
super(GNN, self).__init__()
|
250 |
+
self.gcn_conv = gcn_conv
|
251 |
+
self.use_graph_transformer=use_graph_transformer
|
252 |
+
|
253 |
+
self.num_modals = num_modals
|
254 |
+
|
255 |
+
if self.gcn_conv == "rgcn":
|
256 |
+
print("GNN --> Use RGCN")
|
257 |
+
self.conv1 = RGCNConv(g_dim, h1_dim, num_relations)
|
258 |
+
|
259 |
+
if self.use_graph_transformer:
|
260 |
+
print("GNN --> Use Graph Transformer")
|
261 |
+
|
262 |
+
in_dim = h1_dim
|
263 |
+
|
264 |
+
self.conv2 = TransformerConv(in_dim, h2_dim, heads=graph_transformer_nheads, concat=True)
|
265 |
+
self.bn = nn.BatchNorm1d(h2_dim * graph_transformer_nheads)
|
266 |
+
|
267 |
+
|
268 |
+
def forward(self, node_features, node_type, edge_index, edge_type):
|
269 |
+
print(node_features.shape, edge_index.shape, edge_type.shape)
|
270 |
+
|
271 |
+
if self.gcn_conv == "rgcn":
|
272 |
+
x = self.conv1(node_features, edge_index, edge_type)
|
273 |
+
|
274 |
+
if self.use_graph_transformer:
|
275 |
+
x = nn.functional.leaky_relu(self.bn(self.conv2(x, edge_index)))
|
276 |
+
|
277 |
+
return x
|
278 |
+
|
279 |
+
class GraphModel(nn.Module):
|
280 |
+
def __init__(self, g_dim, h1_dim, h2_dim, device, modalities, wp, wf, edge_type, gcn_conv, use_graph_transformer, graph_transformer_nheads):
|
281 |
+
super(GraphModel, self).__init__()
|
282 |
+
|
283 |
+
self.n_modals = len(modalities)
|
284 |
+
self.wp = wp
|
285 |
+
self.wf = wf
|
286 |
+
self.device = device
|
287 |
+
self.gcn_conv=gcn_conv
|
288 |
+
self.use_graph_transformer=use_graph_transformer
|
289 |
+
|
290 |
+
print(f"GraphModel --> Edge type: {edge_type}")
|
291 |
+
print(f"GraphModel --> Window past: {wp}")
|
292 |
+
print(f"GraphModel --> Window future: {wf}")
|
293 |
+
edge_temp = "temp" in edge_type
|
294 |
+
edge_multi = "multi" in edge_type
|
295 |
+
|
296 |
+
edge_type_to_idx = {}
|
297 |
+
|
298 |
+
if edge_temp:
|
299 |
+
temporal = [-1, 1, 0]
|
300 |
+
for j in temporal:
|
301 |
+
for k in range(self.n_modals):
|
302 |
+
edge_type_to_idx[str(j) + str(k) + str(k)] = len(edge_type_to_idx)
|
303 |
+
else:
|
304 |
+
for j in range(self.n_modals):
|
305 |
+
edge_type_to_idx['0' + str(j) + str(j)] = len(edge_type_to_idx)
|
306 |
+
|
307 |
+
if edge_multi:
|
308 |
+
for j in range(self.n_modals):
|
309 |
+
for k in range(self.n_modals):
|
310 |
+
if (j != k):
|
311 |
+
edge_type_to_idx['0' + str(j) + str(k)] = len(edge_type_to_idx)
|
312 |
+
|
313 |
+
self.edge_type_to_idx = edge_type_to_idx
|
314 |
+
self.num_relations = len(edge_type_to_idx)
|
315 |
+
self.edge_multi = edge_multi
|
316 |
+
self.edge_temp = edge_temp
|
317 |
+
|
318 |
+
self.gnn = GNN(g_dim, h1_dim, h2_dim, self.num_relations, self.n_modals, self.gcn_conv, self.use_graph_transformer, graph_transformer_nheads)
|
319 |
+
|
320 |
+
|
321 |
+
def forward(self, x, lengths):
|
322 |
+
# print(f"x shape: {x.shape}, lengths: {lengths}, lengths.shape: {lengths.shape}")
|
323 |
+
|
324 |
+
node_features = feature_packing(x, lengths)
|
325 |
+
|
326 |
+
node_type, edge_index, edge_type, edge_index_lengths = \
|
327 |
+
self.batch_graphify(lengths)
|
328 |
+
|
329 |
+
out_gnn = self.gnn(node_features, node_type, edge_index, edge_type)
|
330 |
+
out_gnn = multi_concat(out_gnn, lengths, self.n_modals)
|
331 |
+
|
332 |
+
return out_gnn
|
333 |
+
|
334 |
+
def batch_graphify(self, lengths):
|
335 |
+
|
336 |
+
node_type, edge_index, edge_type, edge_index_lengths = [], [], [], []
|
337 |
+
edge_type_lengths = [0] * len(self.edge_type_to_idx)
|
338 |
+
|
339 |
+
lengths = lengths.tolist()
|
340 |
+
|
341 |
+
sum_length = 0
|
342 |
+
total_length = sum(lengths)
|
343 |
+
batch_size = len(lengths)
|
344 |
+
|
345 |
+
for k in range(self.n_modals):
|
346 |
+
for j in range(batch_size):
|
347 |
+
cur_len = lengths[j]
|
348 |
+
node_type.extend([k] * cur_len)
|
349 |
+
|
350 |
+
for j in range(batch_size):
|
351 |
+
cur_len = lengths[j]
|
352 |
+
|
353 |
+
perms = self.edge_perms(cur_len, total_length)
|
354 |
+
edge_index_lengths.append(len(perms))
|
355 |
+
|
356 |
+
for item in perms:
|
357 |
+
vertices = item[0]
|
358 |
+
neighbor = item[1]
|
359 |
+
edge_index.append(torch.tensor([vertices + sum_length, neighbor + sum_length]))
|
360 |
+
|
361 |
+
if vertices % total_length > neighbor % total_length:
|
362 |
+
temporal_type = 1
|
363 |
+
elif vertices % total_length < neighbor % total_length:
|
364 |
+
temporal_type = -1
|
365 |
+
else:
|
366 |
+
temporal_type = 0
|
367 |
+
edge_type.append(self.edge_type_to_idx[str(temporal_type)
|
368 |
+
+ str(node_type[vertices + sum_length])
|
369 |
+
+ str(node_type[neighbor + sum_length])])
|
370 |
+
|
371 |
+
sum_length += cur_len
|
372 |
+
|
373 |
+
node_type = torch.tensor(node_type).long().to(self.device)
|
374 |
+
edge_index = torch.stack(edge_index).t().contiguous().to(self.device) # [2, E]
|
375 |
+
edge_type = torch.tensor(edge_type).long().to(self.device) # [E]
|
376 |
+
edge_index_lengths = torch.tensor(edge_index_lengths).long().to(self.device) # [B]
|
377 |
+
|
378 |
+
return node_type, edge_index, edge_type, edge_index_lengths
|
379 |
+
|
380 |
+
def edge_perms(self, length, total_lengths):
|
381 |
+
|
382 |
+
all_perms = set()
|
383 |
+
array = np.arange(length)
|
384 |
+
for j in range(length):
|
385 |
+
if self.wp == -1 and self.wf == -1:
|
386 |
+
eff_array = array
|
387 |
+
elif self.wp == -1: # use all past context
|
388 |
+
eff_array = array[: min(length, j + self.wf)]
|
389 |
+
elif self.wf == -1: # use all future context
|
390 |
+
eff_array = array[max(0, j - self.wp) :]
|
391 |
+
else:
|
392 |
+
eff_array = array[
|
393 |
+
max(0, j - self.wp) : min(length, j + self.wf)
|
394 |
+
]
|
395 |
+
perms = set()
|
396 |
+
|
397 |
+
|
398 |
+
for k in range(self.n_modals):
|
399 |
+
node_index = j + k * total_lengths
|
400 |
+
if self.edge_temp == True:
|
401 |
+
for item in eff_array:
|
402 |
+
perms.add((node_index, item + k * total_lengths))
|
403 |
+
else:
|
404 |
+
perms.add((node_index, node_index))
|
405 |
+
if self.edge_multi == True:
|
406 |
+
for l in range(self.n_modals):
|
407 |
+
if l != k:
|
408 |
+
perms.add((node_index, j + l * total_lengths))
|
409 |
+
|
410 |
+
all_perms = all_perms.union(perms)
|
411 |
+
|
412 |
+
return list(all_perms)
|
413 |
+
|
414 |
+
def feature_packing(multimodal_feature, lengths):
|
415 |
+
batch_size = lengths.size(0)
|
416 |
+
# print(multimodal_feature.shape, batch_size, lengths.shape)
|
417 |
+
node_features = []
|
418 |
+
|
419 |
+
for feature in multimodal_feature:
|
420 |
+
for j in range(batch_size):
|
421 |
+
cur_len = lengths[j].item()
|
422 |
+
# print(f"feature.shape: {feature.shape}, j: {j}, cur_len: {cur_len}")
|
423 |
+
node_features.append(feature[j,:cur_len])
|
424 |
+
|
425 |
+
node_features = torch.cat(node_features, dim=0)
|
426 |
+
|
427 |
+
return node_features
|
428 |
+
|
429 |
+
def multi_concat(nodes_feature, lengths, n_modals):
|
430 |
+
sum_length = lengths.sum().item()
|
431 |
+
feature = []
|
432 |
+
for j in range(n_modals):
|
433 |
+
feature.append(nodes_feature[j * sum_length : (j + 1) * sum_length])
|
434 |
+
|
435 |
+
feature = torch.cat(feature, dim=-1)
|
436 |
+
|
437 |
+
return feature
|
438 |
+
|
439 |
+
class RMSNorm(nn.Module):
|
440 |
+
def __init__(self, d_model: int, eps: float = 1e-8) -> None:
|
441 |
+
super().__init__()
|
442 |
+
self.eps = eps
|
443 |
+
self.weight = nn.Parameter(torch.ones(d_model))
|
444 |
+
|
445 |
+
def forward(self, x: Tensor) -> Tensor:
|
446 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim = True) + self.eps) * self.weight
|
447 |
+
|
448 |
+
class Mamba(nn.Module):
|
449 |
+
def __init__(self, num_layers, d_input, d_model, d_state=16, d_discr=None, ker_size=4, num_classes=7, pooling=None):
|
450 |
+
super().__init__()
|
451 |
+
mamba_par = {
|
452 |
+
'd_input' : d_input,
|
453 |
+
'd_model' : d_model,
|
454 |
+
'd_state' : d_state,
|
455 |
+
'd_discr' : d_discr,
|
456 |
+
'ker_size': ker_size
|
457 |
+
}
|
458 |
+
self.layers = nn.ModuleList([nn.ModuleList([MambaBlock(**mamba_par), RMSNorm(d_input)]) for _ in range(num_layers)])
|
459 |
+
self.fc_out = nn.Linear(d_input, num_classes)
|
460 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
461 |
+
|
462 |
+
def forward(self, seq, cache=None):
|
463 |
+
seq = torch.tensor(self.embedding(seq)).to(self.device)
|
464 |
+
for mamba, norm in self.layers:
|
465 |
+
out, cache = mamba(norm(seq), cache)
|
466 |
+
seq = out + seq
|
467 |
+
return self.fc_out(seq.mean(dim = 1))
|
468 |
+
|
469 |
+
class MambaBlock(nn.Module):
|
470 |
+
def __init__(self, d_input, d_model, d_state=16, d_discr=None, ker_size=4):
|
471 |
+
super().__init__()
|
472 |
+
d_discr = d_discr if d_discr is not None else d_model // 16
|
473 |
+
self.in_proj = nn.Linear(d_input, 2 * d_model, bias=False)
|
474 |
+
self.out_proj = nn.Linear(d_model, d_input, bias=False)
|
475 |
+
self.s_B = nn.Linear(d_model, d_state, bias=False)
|
476 |
+
self.s_C = nn.Linear(d_model, d_state, bias=False)
|
477 |
+
self.s_D = nn.Sequential(nn.Linear(d_model, d_discr, bias=False), nn.Linear(d_discr, d_model, bias=False),)
|
478 |
+
self.conv = nn.Conv1d(
|
479 |
+
in_channels=d_model,
|
480 |
+
out_channels=d_model,
|
481 |
+
kernel_size=ker_size,
|
482 |
+
padding=ker_size - 1,
|
483 |
+
groups=d_model,
|
484 |
+
bias=True,
|
485 |
+
)
|
486 |
+
self.A = nn.Parameter(torch.arange(1, d_state + 1, dtype=torch.float).repeat(d_model, 1))
|
487 |
+
self.D = nn.Parameter(torch.ones(d_model, dtype=torch.float))
|
488 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
489 |
+
|
490 |
+
def forward(self, seq, cache=None):
|
491 |
+
b, l, d = seq.shape
|
492 |
+
(prev_hid, prev_inp) = cache if cache is not None else (None, None)
|
493 |
+
a, b = self.in_proj(seq).chunk(2, dim=-1)
|
494 |
+
x = rearrange(a, 'b l d -> b d l')
|
495 |
+
x = x if prev_inp is None else torch.cat((prev_inp, x), dim=-1)
|
496 |
+
a = self.conv(x)[..., :l]
|
497 |
+
a = rearrange(a, 'b d l -> b l d')
|
498 |
+
a = silu(a)
|
499 |
+
a, hid = self.ssm(a, prev_hid=prev_hid)
|
500 |
+
b = silu(b)
|
501 |
+
out = a * b
|
502 |
+
out = self.out_proj(out)
|
503 |
+
if cache:
|
504 |
+
cache = (hid.squeeze(), x[..., 1:])
|
505 |
+
return out, cache
|
506 |
+
|
507 |
+
def ssm(self, seq, prev_hid):
|
508 |
+
A = -self.A
|
509 |
+
D = +self.D
|
510 |
+
B = self.s_B(seq)
|
511 |
+
C = self.s_C(seq)
|
512 |
+
s = softplus(D + self.s_D(seq))
|
513 |
+
A_bar = einsum(torch.exp(A), s, 'd s, b l d -> b l d s')
|
514 |
+
B_bar = einsum( B, s, 'b l s, b l d -> b l d s')
|
515 |
+
X_bar = einsum(B_bar, seq, 'b l d s, b l d -> b l d s')
|
516 |
+
hid = self._hid_states(A_bar, X_bar, prev_hid=prev_hid)
|
517 |
+
out = einsum(hid, C, 'b l d s, b l s -> b l d')
|
518 |
+
out = out + D * seq
|
519 |
+
return out, hid
|
520 |
+
|
521 |
+
def _hid_states(self, A, X, prev_hid=None):
|
522 |
+
b, l, d, s = A.shape
|
523 |
+
A = rearrange(A, 'b l d s -> l b d s')
|
524 |
+
X = rearrange(X, 'b l d s -> l b d s')
|
525 |
+
if prev_hid is not None:
|
526 |
+
return rearrange(A * prev_hid + X, 'l b d s -> b l d s')
|
527 |
+
h = torch.zeros(b, d, s, device=self.device)
|
528 |
+
return torch.stack([h := A_t * h + X_t for A_t, X_t in zip(A, X)], dim=1)
|
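A minimal usage sketch of the GAL gated-fusion layer defined in help_layers.py above; the dimensions and the random tensors are illustrative stand-ins for pooled audio and text embeddings:

import torch
from models.help_layers import GAL

gal = GAL(input_dim_F1=1024, input_dim_F2=768, gated_dim=512, dropout_rate=0.1)

audio_vec = torch.randn(8, 1024)   # pooled audio embeddings, batch of 8
text_vec  = torch.randn(8, 768)    # pooled text embeddings

fused = gal(audio_vec, text_vec)   # gate z mixes the tanh-projected audio and text views
print(fused.shape)                 # torch.Size([8, 512])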
models/models.py
ADDED
@@ -0,0 +1,1700 @@
1 |
+
# coding: utf-8
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from .help_layers import TransformerEncoderLayer, GAL,GraphFusionLayer, GraphFusionLayerAtt, MambaBlock, RMSNorm
|
6 |
+
|
7 |
+
class PredictionsFusion(nn.Module):
|
8 |
+
def __init__(self, num_matrices=2, num_classes=7):
|
9 |
+
super(PredictionsFusion, self).__init__()
|
10 |
+
self.weights = nn.Parameter(torch.rand(num_matrices, num_classes))
|
11 |
+
|
12 |
+
def forward(self, pred):
|
13 |
+
normalized_weights = torch.softmax(self.weights, dim=0)
|
14 |
+
weighted_matrix = sum(mat * normalized_weights[i] for i, mat in enumerate(pred))
|
15 |
+
return weighted_matrix
|
16 |
+
|
17 |
+
class MultiModalTransformer_v3(nn.Module):
|
18 |
+
def __init__(self, audio_dim=1024, text_dim=1024, hidden_dim=512, hidden_dim_gated=512, num_transformer_heads=2, num_graph_heads=2, seg_len=44, positional_encoding=True, dropout=0, mode='mean', device="cuda", tr_layer_number=1, out_features=128, num_classes=7):
|
19 |
+
super(MultiModalTransformer_v3, self).__init__()
|
20 |
+
|
21 |
+
self.mode = mode
|
22 |
+
|
23 |
+
self.hidden_dim = hidden_dim
|
24 |
+
|
25 |
+
# Проекционные слои
|
26 |
+
# self.audio_proj = nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity()
|
27 |
+
|
28 |
+
# self.audio_proj = nn.Sequential(
|
29 |
+
# nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity(),
|
30 |
+
# nn.LayerNorm(hidden_dim),
|
31 |
+
# nn.Dropout(dropout)
|
32 |
+
# )
|
33 |
+
|
34 |
+
self.audio_proj = nn.Sequential(
|
35 |
+
nn.Conv1d(audio_dim, hidden_dim, 1),
|
36 |
+
nn.GELU(),
|
37 |
+
)
|
38 |
+
|
39 |
+
self.text_proj = nn.Sequential(
|
40 |
+
nn.Conv1d(text_dim, hidden_dim, 1),
|
41 |
+
nn.GELU(),
|
42 |
+
)
|
43 |
+
# self.text_proj = nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity()
|
44 |
+
|
45 |
+
# self.text_proj = nn.Sequential(
|
46 |
+
# nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity(),
|
47 |
+
# nn.LayerNorm(hidden_dim),
|
48 |
+
# nn.Dropout(dropout)
|
49 |
+
# )
|
50 |
+
|
51 |
+
# Механизмы внимания
|
52 |
+
self.audio_to_text_attn = nn.ModuleList([TransformerEncoderLayer(input_dim=hidden_dim, num_heads=num_transformer_heads, positional_encoding=positional_encoding, dropout=dropout) for i in range(tr_layer_number)
|
53 |
+
])
|
54 |
+
self.text_to_audio_attn = nn.ModuleList([TransformerEncoderLayer(input_dim=hidden_dim, num_heads=num_transformer_heads, positional_encoding=positional_encoding, dropout=dropout) for i in range(tr_layer_number)
|
55 |
+
])
|
56 |
+
|
57 |
+
# Классификатор
|
58 |
+
# self.classifier = nn.Sequential(
|
59 |
+
# nn.Linear(hidden_dim*2, out_features) if self.mode == 'mean' else nn.Linear(hidden_dim*4, out_features),
|
60 |
+
# nn.ReLU(),
|
61 |
+
# nn.Linear(out_features, num_classes)
|
62 |
+
# )
|
63 |
+
|
64 |
+
self.classifier = nn.Sequential(
|
65 |
+
nn.Linear(hidden_dim*2, out_features) if self.mode == 'mean' else nn.Linear(hidden_dim*4, out_features),
|
66 |
+
# nn.LayerNorm(out_features),
|
67 |
+
# nn.GELU(),
|
68 |
+
# nn.Dropout(dropout),
|
69 |
+
nn.Linear(out_features, num_classes)
|
70 |
+
)
|
71 |
+
|
72 |
+
# self._init_weights()
|
73 |
+
|
74 |
+
    def forward(self, audio_features, text_features):
        # Dimensionality conversion
        audio_features = audio_features.float()
        text_features = text_features.float()

        # audio_features = self.audio_proj(audio_features)
        # text_features = self.text_proj(text_features)
        audio_features = self.audio_proj(audio_features.permute(0,2,1)).permute(0,2,1)
        text_features = self.text_proj(text_features.permute(0,2,1)).permute(0,2,1)

        # Adaptive pooling to the minimum sequence length
        min_seq_len = min(audio_features.size(1), text_features.size(1))
        audio_features = F.adaptive_avg_pool1d(audio_features.permute(0,2,1), min_seq_len).permute(0,2,1)
        text_features = F.adaptive_avg_pool1d(text_features.permute(0,2,1), min_seq_len).permute(0,2,1)

        # Transformer blocks (cross-modal attention with residual connections)
        for i in range(len(self.audio_to_text_attn)):
            attn_audio = self.audio_to_text_attn[i](text_features, audio_features, audio_features)
            attn_text = self.text_to_audio_attn[i](audio_features, text_features, text_features)
            audio_features = audio_features + attn_audio
            text_features = text_features + attn_text

        # Statistics over the temporal axis
        std_audio, mean_audio = torch.std_mean(attn_audio, dim=1)
        std_text, mean_text = torch.std_mean(attn_text, dim=1)

        # Classification
        if self.mode == 'mean':
            return self.classifier(torch.cat([mean_audio, mean_text], dim=1))
        else:
            return self.classifier(torch.cat([mean_audio, std_audio, mean_text, std_text], dim=1))
|
105 |
+
|
106 |
+
def _init_weights(self):
|
107 |
+
for m in self.modules():
|
108 |
+
if isinstance(m, nn.Linear):
|
109 |
+
nn.init.xavier_uniform_(m.weight)
|
110 |
+
if m.bias is not None:
|
111 |
+
nn.init.constant_(m.bias, 0)
|
112 |
+
elif isinstance(m, nn.LayerNorm):
|
113 |
+
nn.init.constant_(m.weight, 1)
|
114 |
+
nn.init.constant_(m.bias, 0)
|
115 |
+
|
116 |
+
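# A quick shape check for MultiModalTransformer_v3; this is only a sketch, and the sizes
# below (1024-d audio/text embeddings, sequence lengths, 7 classes) are assumptions that
# simply mirror the constructor defaults.
def _demo_multimodal_transformer_v3():
    model = MultiModalTransformer_v3(audio_dim=1024, text_dim=1024, hidden_dim=512, mode='mean')
    audio = torch.randn(4, 96, 1024)  # (batch, audio frames, audio_dim)
    text = torch.randn(4, 32, 1024)   # (batch, tokens, text_dim)
    logits = model(audio, text)       # both streams are pooled to the shorter length
    return logits.shape               # expected: torch.Size([4, 7])
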
class MultiModalTransformer_v4(nn.Module):
|
117 |
+
def __init__(self, audio_dim=1024, text_dim=1024, hidden_dim=512, hidden_dim_gated=512, num_transformer_heads=2, num_graph_heads=2, seg_len=44, positional_encoding=True, dropout=0, mode='mean', device="cuda", tr_layer_number=1, out_features=128, num_classes=7):
|
118 |
+
super(MultiModalTransformer_v4, self).__init__()
|
119 |
+
|
120 |
+
self.mode = mode
|
121 |
+
|
122 |
+
self.hidden_dim = hidden_dim
|
123 |
+
|
124 |
+
        # Projection layers
|
125 |
+
self.audio_proj = nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity()
|
126 |
+
self.text_proj = nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity()
|
127 |
+
|
128 |
+
        # Attention mechanisms
|
129 |
+
self.audio_to_text_attn = nn.ModuleList([TransformerEncoderLayer(input_dim=hidden_dim, num_heads=num_transformer_heads, positional_encoding=positional_encoding, dropout=dropout) for i in range(tr_layer_number)
|
130 |
+
])
|
131 |
+
self.text_to_audio_attn = nn.ModuleList([TransformerEncoderLayer(input_dim=hidden_dim, num_heads=num_transformer_heads, positional_encoding=positional_encoding, dropout=dropout) for i in range(tr_layer_number)
|
132 |
+
])
|
133 |
+
|
134 |
+
        # Graph fusion instead of GAL
|
135 |
+
if self.mode == 'mean':
|
136 |
+
self.graph_fusion = GraphFusionLayer(hidden_dim, heads=num_graph_heads)
|
137 |
+
else:
|
138 |
+
self.graph_fusion = GraphFusionLayer(hidden_dim*2, heads=num_graph_heads)
|
139 |
+
|
140 |
+
        # Classifier
|
141 |
+
self.classifier = nn.Sequential(
|
142 |
+
nn.Linear(hidden_dim, out_features) if self.mode == 'mean' else nn.Linear(hidden_dim*2, out_features),
|
143 |
+
nn.ReLU(),
|
144 |
+
nn.Linear(out_features, num_classes)
|
145 |
+
)
|
146 |
+
|
147 |
+
    def forward(self, audio_features, text_features):
        # Dimensionality conversion
        audio_features = audio_features.float()
        text_features = text_features.float()

        audio_features = self.audio_proj(audio_features)
        text_features = self.text_proj(text_features)

        # Adaptive pooling to the minimum sequence length
        min_seq_len = min(audio_features.size(1), text_features.size(1))
        audio_features = F.adaptive_avg_pool1d(audio_features.permute(0,2,1), min_seq_len).permute(0,2,1)
        text_features = F.adaptive_avg_pool1d(text_features.permute(0,2,1), min_seq_len).permute(0,2,1)

        # Transformer blocks (cross-modal attention with residual connections)
        for i in range(len(self.audio_to_text_attn)):
            attn_audio = self.audio_to_text_attn[i](text_features, audio_features, audio_features)
            attn_text = self.text_to_audio_attn[i](audio_features, text_features, text_features)
            audio_features = audio_features + attn_audio
            text_features = text_features + attn_text

        # Statistics over the temporal axis
        std_audio, mean_audio = torch.std_mean(attn_audio, dim=1)
        std_text, mean_text = torch.std_mean(attn_text, dim=1)

        # Graph fusion of the statistics
        if self.mode == 'mean':
            h_ta = self.graph_fusion(mean_audio, mean_text)
        else:
            h_ta = self.graph_fusion(torch.cat([mean_audio, std_audio], dim=1), torch.cat([mean_text, std_text], dim=1))

        # Classification
        return self.classifier(h_ta)
|
179 |
+
|
180 |
+
class MultiModalTransformer_v5(nn.Module):
|
181 |
+
def __init__(self, audio_dim=1024, text_dim=1024, hidden_dim=512, hidden_dim_gated=512, num_transformer_heads=2, num_graph_heads=2, seg_len=44, tr_layer_number=1, positional_encoding=True, dropout=0, mode='mean', device="cuda", out_features=128, num_classes=7):
|
182 |
+
super(MultiModalTransformer_v5, self).__init__()
|
183 |
+
|
184 |
+
self.hidden_dim = hidden_dim
|
185 |
+
self.mode = mode
|
186 |
+
|
187 |
+
        # Projection to a common dimensionality (adaptive projections)
|
188 |
+
self.audio_proj = nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity()
|
189 |
+
self.text_proj = nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity()
|
190 |
+
|
191 |
+
        # Attention mechanisms
|
192 |
+
|
193 |
+
self.audio_to_text_attn = nn.ModuleList([TransformerEncoderLayer(input_dim=hidden_dim, num_heads=num_transformer_heads, positional_encoding=positional_encoding, dropout=dropout) for i in range(tr_layer_number)
|
194 |
+
])
|
195 |
+
self.text_to_audio_attn = nn.ModuleList([TransformerEncoderLayer(input_dim=hidden_dim, num_heads=num_transformer_heads, positional_encoding=positional_encoding, dropout=dropout) for i in range(tr_layer_number)
|
196 |
+
])
|
197 |
+
|
198 |
+
        # Gated attention
|
199 |
+
if self.mode == 'mean':
|
200 |
+
self.gal = GAL(hidden_dim, hidden_dim, hidden_dim_gated)
|
201 |
+
else:
|
202 |
+
self.gal = GAL(hidden_dim*2, hidden_dim*2, hidden_dim_gated)
|
203 |
+
|
204 |
+
        # Classifier
|
205 |
+
self.classifier = nn.Sequential(
|
206 |
+
nn.Linear(hidden_dim, out_features),
|
207 |
+
nn.ReLU(),
|
208 |
+
nn.Linear(out_features, num_classes)
|
209 |
+
)
|
210 |
+
|
211 |
+
def forward(self, audio_features, text_features):
|
212 |
+
bs, seq_audio, audio_feat_dim = audio_features.shape
|
213 |
+
bs, seq_text, text_feat_dim = text_features.shape
|
214 |
+
|
215 |
+
text_features = text_features.to(torch.float32)
|
216 |
+
audio_features = audio_features.to(torch.float32)
|
217 |
+
|
218 |
+
        # Project to the common dimensionality
        audio_features = self.audio_proj(audio_features) # (bs, seq_audio, hidden_dim)
        text_features = self.text_proj(text_features) # (bs, seq_text, hidden_dim)

        # Determine the minimum sequence length
        min_seq_len = min(seq_audio, seq_text)

        # Average both streams down to the minimum length
        audio_features = F.adaptive_avg_pool2d(audio_features.permute(0, 2, 1), (self.hidden_dim, min_seq_len)).permute(0, 2, 1)
        text_features = F.adaptive_avg_pool2d(text_features.permute(0, 2, 1), (self.hidden_dim, min_seq_len)).permute(0, 2, 1)

        # Transformer blocks (cross-modal attention with residual connections)
        for i in range(len(self.audio_to_text_attn)):
            attn_audio = self.audio_to_text_attn[i](text_features, audio_features, audio_features)
            attn_text = self.text_to_audio_attn[i](audio_features, text_features, text_features)
            audio_features = audio_features + attn_audio
            text_features = text_features + attn_text

        # Statistics over the temporal axis
        std_audio, mean_audio = torch.std_mean(attn_audio, dim=1)
        std_text, mean_text = torch.std_mean(attn_text, dim=1)

        # # Gated attention (hand-rolled variant kept for reference)
        # h_audio = torch.tanh(self.Wa(torch.cat([min_audio, std_audio], dim=1)))
        # h_text = torch.tanh(self.Wt(torch.cat([min_text, std_text], dim=1)))
        # z_ta = torch.sigmoid(self.W_at(torch.cat([min_audio, std_audio, min_text, std_text], dim=1)))
        # h_ta = z_ta * h_text + (1 - z_ta) * h_audio
        if self.mode == 'mean':
            h_ta = self.gal(mean_audio, mean_text)
        else:
            h_ta = self.gal(torch.cat([mean_audio, std_audio], dim=1), torch.cat([mean_text, std_text], dim=1))

        # Classification
        output = self.classifier(h_ta)
        return output
|
253 |
+
|
254 |
+
class MultiModalTransformer_v7(nn.Module):
|
255 |
+
def __init__(self, audio_dim=1024, text_dim=1024, hidden_dim=512, num_heads=2, positional_encoding=True, dropout=0, mode='mean', device="cuda", tr_layer_number=1, out_features=128, num_classes=7):
|
256 |
+
super(MultiModalTransformer_v7, self).__init__()
|
257 |
+
|
258 |
+
self.mode = mode
|
259 |
+
|
260 |
+
self.hidden_dim = hidden_dim
|
261 |
+
|
262 |
+
        # Projection layers
|
263 |
+
self.audio_proj = nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity()
|
264 |
+
self.text_proj = nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity()
|
265 |
+
|
266 |
+
        # Attention mechanisms
|
267 |
+
self.audio_to_text_attn = nn.ModuleList([TransformerEncoderLayer(input_dim=hidden_dim, num_heads=num_heads, positional_encoding=positional_encoding, dropout=dropout) for i in range(tr_layer_number)
|
268 |
+
])
|
269 |
+
self.text_to_audio_attn = nn.ModuleList([TransformerEncoderLayer(input_dim=hidden_dim, num_heads=num_heads, positional_encoding=positional_encoding, dropout=dropout) for i in range(tr_layer_number)
|
270 |
+
])
|
271 |
+
|
272 |
+
        # Graph fusion instead of GAL
|
273 |
+
if self.mode == 'mean':
|
274 |
+
self.graph_fusion = GraphFusionLayerAtt(hidden_dim, heads=num_heads)
|
275 |
+
else:
|
276 |
+
self.graph_fusion = GraphFusionLayerAtt(hidden_dim*2, heads=num_heads)
|
277 |
+
|
278 |
+
        # Classifier
|
279 |
+
self.classifier = nn.Sequential(
|
280 |
+
nn.Linear(hidden_dim, out_features) if self.mode == 'mean' else nn.Linear(hidden_dim*2, out_features),
|
281 |
+
nn.ReLU(),
|
282 |
+
nn.Linear(out_features, num_classes)
|
283 |
+
)
|
284 |
+
|
285 |
+
    def forward(self, audio_features, text_features):
        # Dimensionality conversion
        audio_features = audio_features.float()
        text_features = text_features.float()

        audio_features = self.audio_proj(audio_features)
        text_features = self.text_proj(text_features)

        # Adaptive pooling to the minimum sequence length
        min_seq_len = min(audio_features.size(1), text_features.size(1))
        audio_features = F.adaptive_avg_pool1d(audio_features.permute(0,2,1), min_seq_len).permute(0,2,1)
        text_features = F.adaptive_avg_pool1d(text_features.permute(0,2,1), min_seq_len).permute(0,2,1)

        # Transformer blocks (cross-modal attention with residual connections)
        for i in range(len(self.audio_to_text_attn)):
            attn_audio = self.audio_to_text_attn[i](text_features, audio_features, audio_features)
            attn_text = self.text_to_audio_attn[i](audio_features, text_features, text_features)
            audio_features = audio_features + attn_audio
            text_features = text_features + attn_text

        # Statistics over the temporal axis
        std_audio, mean_audio = torch.std_mean(attn_audio, dim=1)
        std_text, mean_text = torch.std_mean(attn_text, dim=1)

        # Graph fusion of the statistics
        if self.mode == 'mean':
            h_ta = self.graph_fusion(mean_audio, mean_text)
        else:
            h_ta = self.graph_fusion(torch.cat([mean_audio, std_audio], dim=1), torch.cat([mean_text, std_text], dim=1))

        # Classification
        return self.classifier(h_ta)
|
317 |
+
|
318 |
+
class BiFormer(nn.Module):
|
319 |
+
def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512, hidden_dim_gated=128,
|
320 |
+
num_transformer_heads=2, num_graph_heads=2, positional_encoding=True, dropout=0.1, mode='mean',
|
321 |
+
device="cuda", tr_layer_number=1, out_features=128, num_classes=7):
|
322 |
+
super(BiFormer, self).__init__()
|
323 |
+
|
324 |
+
self.mode = mode
|
325 |
+
self.hidden_dim = hidden_dim
|
326 |
+
self.seg_len = seg_len
|
327 |
+
self.tr_layer_number = tr_layer_number
|
328 |
+
|
329 |
+
        # Projection layers with normalization
|
330 |
+
self.audio_proj = nn.Sequential(
|
331 |
+
nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity(),
|
332 |
+
nn.LayerNorm(hidden_dim),
|
333 |
+
nn.Dropout(dropout)
|
334 |
+
)
|
335 |
+
|
336 |
+
self.text_proj = nn.Sequential(
|
337 |
+
nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity(),
|
338 |
+
nn.LayerNorm(hidden_dim),
|
339 |
+
nn.Dropout(dropout)
|
340 |
+
)
|
341 |
+
# self.audio_proj = nn.Sequential(
|
342 |
+
# nn.Conv1d(audio_dim, hidden_dim, 1),
|
343 |
+
# nn.GELU(),
|
344 |
+
# )
|
345 |
+
|
346 |
+
# self.text_proj = nn.Sequential(
|
347 |
+
# nn.Conv1d(text_dim, hidden_dim, 1),
|
348 |
+
# nn.GELU(),
|
349 |
+
# )
|
350 |
+
|
351 |
+
        # Transformer layers (keeping the existing implementation)
|
352 |
+
self.audio_to_text_attn = nn.ModuleList([
|
353 |
+
TransformerEncoderLayer(
|
354 |
+
input_dim=hidden_dim,
|
355 |
+
num_heads=num_transformer_heads,
|
356 |
+
dropout=dropout,
|
357 |
+
positional_encoding=positional_encoding
|
358 |
+
) for _ in range(tr_layer_number)
|
359 |
+
])
|
360 |
+
|
361 |
+
self.text_to_audio_attn = nn.ModuleList([
|
362 |
+
TransformerEncoderLayer(
|
363 |
+
input_dim=hidden_dim,
|
364 |
+
num_heads=num_transformer_heads,
|
365 |
+
dropout=dropout,
|
366 |
+
positional_encoding=positional_encoding
|
367 |
+
) for _ in range(tr_layer_number)
|
368 |
+
])
|
369 |
+
|
370 |
+
        # Automatic computation of the classifier input dimensionality
|
371 |
+
self._calculate_classifier_input_dim()
|
372 |
+
|
373 |
+
        # Classifier
|
374 |
+
self.classifier = nn.Sequential(
|
375 |
+
nn.Linear(self.classifier_input_dim, out_features),
|
376 |
+
nn.LayerNorm(out_features),
|
377 |
+
nn.GELU(),
|
378 |
+
nn.Dropout(dropout),
|
379 |
+
nn.Linear(out_features, num_classes)
|
380 |
+
)
|
381 |
+
|
382 |
+
self._init_weights()
|
383 |
+
|
384 |
+
    def _calculate_classifier_input_dim(self):
        """Computes the input feature size for the classifier"""
        # Dry run through the pooling with dummy tensors
        dummy_audio = torch.randn(1, self.seg_len, self.hidden_dim)
        dummy_text = torch.randn(1, self.seg_len, self.hidden_dim)

        audio_pool = self._pool_features(dummy_audio)
        text_pool = self._pool_features(dummy_text)

        combined = torch.cat([audio_pool, text_pool], dim=1)
        self.classifier_input_dim = combined.size(1)

    def _pool_features(self, x):
        # Statistics over the temporal axis (seq_len)
        mean_temp = x.mean(dim=1)  # [batch, hidden_dim]

        # Statistics over the feature axis (hidden_dim)
        mean_feat = x.mean(dim=-1)  # [batch, seq_len]

        return torch.cat([mean_temp, mean_feat], dim=1)
|
404 |
+
|
405 |
+
    def forward(self, audio_features, text_features):
        # Feature projection
        # audio = self.audio_proj(audio_features.permute(0,2,1)).permute(0,2,1)
        # text = self.text_proj(text_features.permute(0,2,1)).permute(0,2,1)
        audio = self.audio_proj(audio_features.float())
        text = self.text_proj(text_features.float())

        # Adaptive temporal pooling to a common length
        min_len = min(audio.size(1), text.size(1))
        audio = self.adaptive_temporal_pool(audio, min_len)
        text = self.adaptive_temporal_pool(text, min_len)

        # Cross-modal interaction
        for i in range(self.tr_layer_number):
            attn_audio = self.audio_to_text_attn[i](text, audio, audio)
            attn_text = self.text_to_audio_attn[i](audio, text, text)
            audio = audio + attn_audio
            text = text + attn_text

        # Feature aggregation
        audio_pool = self._pool_features(audio)
        text_pool = self._pool_features(text)

        # Classification
        features = torch.cat([audio_pool, text_pool], dim=1)
        return self.classifier(features)
|
431 |
+
|
432 |
+
def adaptive_temporal_pool(self, x, target_len):
|
433 |
+
"""Адаптивное изменение временной длины"""
|
434 |
+
if x.size(1) == target_len:
|
435 |
+
return x
|
436 |
+
|
437 |
+
return F.interpolate(
|
438 |
+
x.permute(0, 2, 1),
|
439 |
+
size=target_len,
|
440 |
+
mode='linear',
|
441 |
+
align_corners=False
|
442 |
+
).permute(0, 2, 1)
|
443 |
+
|
444 |
+
def _init_weights(self):
|
445 |
+
for m in self.modules():
|
446 |
+
if isinstance(m, nn.Linear):
|
447 |
+
nn.init.xavier_uniform_(m.weight)
|
448 |
+
if m.bias is not None:
|
449 |
+
nn.init.constant_(m.bias, 0)
|
450 |
+
elif isinstance(m, nn.LayerNorm):
|
451 |
+
nn.init.constant_(m.weight, 1)
|
452 |
+
nn.init.constant_(m.bias, 0)
|
453 |
+
|
454 |
+
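# A small worked example of how BiFormer sizes its classifier head: _pool_features returns
# the mean over time (hidden_dim values) concatenated with the mean over features
# (seg_len values), so with the defaults seg_len=44 and hidden_dim=512 each modality pools
# to 556 values and classifier_input_dim becomes 2 * (512 + 44) = 1112. This is a sketch
# using only the default constructor arguments.
def _demo_biformer_classifier_dim():
    model = BiFormer(seg_len=44, hidden_dim=512)
    return model.classifier_input_dim  # expected: 1112
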
class BiGraphFormer(nn.Module):
|
455 |
+
def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512, hidden_dim_gated=128,
|
456 |
+
num_transformer_heads=2, num_graph_heads = 2, positional_encoding=True, dropout=0.1, mode='mean',
|
457 |
+
device="cuda", tr_layer_number=1, out_features=128, num_classes=7):
|
458 |
+
super(BiGraphFormer, self).__init__()
|
459 |
+
|
460 |
+
self.mode = mode
|
461 |
+
self.hidden_dim = hidden_dim
|
462 |
+
self.seg_len = seg_len
|
463 |
+
self.tr_layer_number = tr_layer_number
|
464 |
+
|
465 |
+
        # Projection layers with normalization
|
466 |
+
self.audio_proj = nn.Sequential(
|
467 |
+
nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity(),
|
468 |
+
nn.LayerNorm(hidden_dim),
|
469 |
+
nn.Dropout(dropout)
|
470 |
+
)
|
471 |
+
|
472 |
+
self.text_proj = nn.Sequential(
|
473 |
+
nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity(),
|
474 |
+
nn.LayerNorm(hidden_dim),
|
475 |
+
nn.Dropout(dropout)
|
476 |
+
)
|
477 |
+
|
478 |
+
        # Transformer layers (keeping the existing implementation)
|
479 |
+
self.audio_to_text_attn = nn.ModuleList([
|
480 |
+
TransformerEncoderLayer(
|
481 |
+
input_dim=hidden_dim,
|
482 |
+
num_heads=num_transformer_heads,
|
483 |
+
dropout=dropout,
|
484 |
+
positional_encoding=positional_encoding
|
485 |
+
) for _ in range(tr_layer_number)
|
486 |
+
])
|
487 |
+
|
488 |
+
self.text_to_audio_attn = nn.ModuleList([
|
489 |
+
TransformerEncoderLayer(
|
490 |
+
input_dim=hidden_dim,
|
491 |
+
num_heads=num_transformer_heads,
|
492 |
+
dropout=dropout,
|
493 |
+
positional_encoding=positional_encoding
|
494 |
+
) for _ in range(tr_layer_number)
|
495 |
+
])
|
496 |
+
|
497 |
+
self.graph_fusion_feat = GraphFusionLayer(self.seg_len, heads=num_graph_heads)
|
498 |
+
self.graph_fusion_temp = GraphFusionLayer(hidden_dim, heads=num_graph_heads)
|
499 |
+
|
500 |
+
        # Automatic computation of the classifier input dimensionality
|
501 |
+
self._calculate_classifier_input_dim()
|
502 |
+
|
503 |
+
        # Classifier
|
504 |
+
self.classifier = nn.Sequential(
|
505 |
+
nn.Linear(self.classifier_input_dim, out_features),
|
506 |
+
nn.LayerNorm(out_features),
|
507 |
+
nn.GELU(),
|
508 |
+
nn.Dropout(dropout),
|
509 |
+
nn.Linear(out_features, num_classes)
|
510 |
+
)
|
511 |
+
|
512 |
+
        # Final projection of the graph branches
|
513 |
+
self.fc_feat = nn.Sequential(
|
514 |
+
nn.Linear(self.seg_len, self.seg_len),
|
515 |
+
nn.LayerNorm(self.seg_len),
|
516 |
+
nn.Dropout(dropout)
|
517 |
+
)
|
518 |
+
|
519 |
+
self.fc_temp = nn.Sequential(
|
520 |
+
nn.Linear(hidden_dim, hidden_dim),
|
521 |
+
nn.LayerNorm(hidden_dim),
|
522 |
+
nn.Dropout(dropout)
|
523 |
+
)
|
524 |
+
|
525 |
+
self._init_weights()
|
526 |
+
|
527 |
+
def _calculate_classifier_input_dim(self):
|
528 |
+
"""Вычисляет размер входных признаков для классификатора"""
|
529 |
+
# Тестовый проход через пулинг с dummy-данными
|
530 |
+
dummy_audio = torch.randn(1, self.seg_len, self.hidden_dim)
|
531 |
+
dummy_text = torch.randn(1, self.seg_len, self.hidden_dim)
|
532 |
+
|
533 |
+
audio_pool_temp, audio_pool_feat = self._pool_features(dummy_audio)
|
534 |
+
# text_pool_temp, _ = self._pool_features(dummy_text)
|
535 |
+
|
536 |
+
combined = torch.cat([audio_pool_temp, audio_pool_feat], dim=1)
|
537 |
+
self.classifier_input_dim = combined.size(1)
|
538 |
+
|
539 |
+
def _pool_features(self, x):
|
540 |
+
        # Statistics over the temporal axis (seq_len)
|
541 |
+
mean_temp = x.mean(dim=1) # [batch, hidden_dim]
|
542 |
+
|
543 |
+
        # Statistics over the feature axis (hidden_dim)
|
544 |
+
mean_feat = x.mean(dim=-1) # [batch, seq_len]
|
545 |
+
|
546 |
+
return mean_temp, mean_feat
|
547 |
+
|
548 |
+
def forward(self, audio_features, text_features):
|
549 |
+
        # Feature projection
|
550 |
+
audio = self.audio_proj(audio_features.float())
|
551 |
+
text = self.text_proj(text_features.float())
|
552 |
+
|
553 |
+
        # Adaptive temporal pooling to a common length
|
554 |
+
min_len = min(audio.size(1), text.size(1))
|
555 |
+
audio = self.adaptive_temporal_pool(audio, min_len)
|
556 |
+
text = self.adaptive_temporal_pool(text, min_len)
|
557 |
+
|
558 |
+
        # Cross-modal interaction
|
559 |
+
for i in range(self.tr_layer_number):
|
560 |
+
attn_audio = self.audio_to_text_attn[i](text, audio, audio)
|
561 |
+
attn_text = self.text_to_audio_attn[i](audio, text, text)
|
562 |
+
|
563 |
+
audio = audio + attn_audio
|
564 |
+
text = text + attn_text
|
565 |
+
|
566 |
+
        # Feature aggregation
|
567 |
+
audio_pool_temp, audio_pool_feat = self._pool_features(audio)
|
568 |
+
text_pool_temp, text_pool_feat = self._pool_features(text)
|
569 |
+
|
570 |
+
# print(audio_pool_temp.shape, audio_pool_feat.shape, text_pool_temp.shape, text_pool_feat.shape)
|
571 |
+
|
572 |
+
graph_feat = self.graph_fusion_feat(audio_pool_feat, text_pool_feat)
|
573 |
+
graph_temp = self.graph_fusion_temp(audio_pool_temp, text_pool_temp)
|
574 |
+
|
575 |
+
# print(graph_feat.shape, graph_temp.shape)
|
576 |
+
# print(torch.mean(graph_feat, dim=1).shape, torch.mean(graph_temp, dim=1).shape)
|
577 |
+
|
578 |
+
# graph_feat = self.fc_feat(graph_feat)
|
579 |
+
# graph_temp = self.fc_temp(graph_temp)
|
580 |
+
|
581 |
+
        # Classification
|
582 |
+
features = torch.cat([graph_feat, graph_temp], dim=1)
|
583 |
+
|
584 |
+
# print(graph_feat.shape, graph_temp.shape, features.shape)
|
585 |
+
return self.classifier(features)
|
586 |
+
|
587 |
+
def adaptive_temporal_pool(self, x, target_len):
|
588 |
+
"""Адаптивное изменение временной длины"""
|
589 |
+
if x.size(1) == target_len:
|
590 |
+
return x
|
591 |
+
|
592 |
+
return F.interpolate(
|
593 |
+
x.permute(0, 2, 1),
|
594 |
+
size=target_len,
|
595 |
+
mode='linear',
|
596 |
+
align_corners=False
|
597 |
+
).permute(0, 2, 1)
|
598 |
+
|
599 |
+
def _init_weights(self):
|
600 |
+
for m in self.modules():
|
601 |
+
if isinstance(m, nn.Linear):
|
602 |
+
nn.init.xavier_uniform_(m.weight)
|
603 |
+
if m.bias is not None:
|
604 |
+
nn.init.constant_(m.bias, 0)
|
605 |
+
elif isinstance(m, nn.LayerNorm):
|
606 |
+
nn.init.constant_(m.weight, 1)
|
607 |
+
nn.init.constant_(m.bias, 0)
|
608 |
+
|
609 |
+
|
610 |
+
class BiGatedGraphFormer(nn.Module):
|
611 |
+
def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512, hidden_dim_gated=128,
|
612 |
+
num_transformer_heads=2, num_graph_heads = 2, positional_encoding=True, dropout=0.1, mode='mean',
|
613 |
+
device="cuda", tr_layer_number=1, out_features=128, num_classes=7):
|
614 |
+
super(BiGatedGraphFormer, self).__init__()
|
615 |
+
|
616 |
+
self.mode = mode
|
617 |
+
self.hidden_dim = hidden_dim
|
618 |
+
self.seg_len = seg_len
|
619 |
+
self.tr_layer_number = tr_layer_number
|
620 |
+
|
621 |
+
        # Projection layers with normalization
|
622 |
+
self.audio_proj = nn.Sequential(
|
623 |
+
nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity(),
|
624 |
+
nn.LayerNorm(hidden_dim),
|
625 |
+
nn.Dropout(dropout)
|
626 |
+
)
|
627 |
+
|
628 |
+
self.text_proj = nn.Sequential(
|
629 |
+
nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity(),
|
630 |
+
nn.LayerNorm(hidden_dim),
|
631 |
+
nn.Dropout(dropout)
|
632 |
+
)
|
633 |
+
|
634 |
+
        # Transformer layers (keeping the existing implementation)
|
635 |
+
self.audio_to_text_attn = nn.ModuleList([
|
636 |
+
TransformerEncoderLayer(
|
637 |
+
input_dim=hidden_dim,
|
638 |
+
num_heads=num_transformer_heads,
|
639 |
+
dropout=dropout,
|
640 |
+
positional_encoding=positional_encoding
|
641 |
+
) for _ in range(tr_layer_number)
|
642 |
+
])
|
643 |
+
|
644 |
+
self.text_to_audio_attn = nn.ModuleList([
|
645 |
+
TransformerEncoderLayer(
|
646 |
+
input_dim=hidden_dim,
|
647 |
+
num_heads=num_transformer_heads,
|
648 |
+
dropout=dropout,
|
649 |
+
positional_encoding=positional_encoding
|
650 |
+
) for _ in range(tr_layer_number)
|
651 |
+
])
|
652 |
+
|
653 |
+
self.graph_fusion_feat = GraphFusionLayer(self.seg_len, heads=num_graph_heads, out_mean=False)
|
654 |
+
self.graph_fusion_temp = GraphFusionLayer(hidden_dim, heads=num_graph_heads, out_mean=False)
|
655 |
+
|
656 |
+
self.gated_feat = GAL(self.seg_len, self.seg_len, hidden_dim_gated, dropout_rate=dropout)
|
657 |
+
self.gated_temp = GAL(hidden_dim, hidden_dim, hidden_dim_gated, dropout_rate=dropout)
|
658 |
+
|
659 |
+
        # Automatic computation of the classifier input dimensionality
|
660 |
+
self._calculate_classifier_input_dim()
|
661 |
+
|
662 |
+
        # Classifier
|
663 |
+
self.classifier = nn.Sequential(
|
664 |
+
nn.Linear(hidden_dim_gated*2, out_features),
|
665 |
+
nn.LayerNorm(out_features),
|
666 |
+
nn.GELU(),
|
667 |
+
nn.Dropout(dropout),
|
668 |
+
nn.Linear(out_features, num_classes)
|
669 |
+
)
|
670 |
+
|
671 |
+
        # Final projection of the graph branches
|
672 |
+
self.fc_graph_feat = nn.Sequential(
|
673 |
+
nn.Linear(self.seg_len, hidden_dim_gated),
|
674 |
+
nn.LayerNorm(hidden_dim_gated),
|
675 |
+
nn.Dropout(dropout)
|
676 |
+
)
|
677 |
+
|
678 |
+
self.fc_graph_temp = nn.Sequential(
|
679 |
+
nn.Linear(hidden_dim, hidden_dim_gated),
|
680 |
+
nn.LayerNorm(hidden_dim_gated),
|
681 |
+
nn.Dropout(dropout)
|
682 |
+
)
|
683 |
+
|
684 |
+
        # Final projection of the gated branches
|
685 |
+
self.fc_gated_feat = nn.Sequential(
|
686 |
+
nn.Linear(hidden_dim_gated, hidden_dim_gated),
|
687 |
+
nn.LayerNorm(hidden_dim_gated),
|
688 |
+
nn.Dropout(dropout)
|
689 |
+
)
|
690 |
+
|
691 |
+
self.fc_gated_temp = nn.Sequential(
|
692 |
+
nn.Linear(hidden_dim_gated, hidden_dim_gated),
|
693 |
+
nn.LayerNorm(hidden_dim_gated),
|
694 |
+
nn.Dropout(dropout)
|
695 |
+
)
|
696 |
+
|
697 |
+
self._init_weights()
|
698 |
+
|
699 |
+
def _calculate_classifier_input_dim(self):
|
700 |
+
"""Вычисляет размер входных признаков для классификатора"""
|
701 |
+
# Тестовый проход через пулинг с dummy-данными
|
702 |
+
dummy_audio = torch.randn(1, self.seg_len, self.hidden_dim)
|
703 |
+
dummy_text = torch.randn(1, self.seg_len, self.hidden_dim)
|
704 |
+
|
705 |
+
audio_pool_temp, audio_pool_feat = self._pool_features(dummy_audio)
|
706 |
+
# text_pool_temp, _ = self._pool_features(dummy_text)
|
707 |
+
|
708 |
+
combined = torch.cat([audio_pool_temp, audio_pool_feat], dim=1)
|
709 |
+
self.classifier_input_dim = combined.size(1)
|
710 |
+
|
711 |
+
def _pool_features(self, x):
|
712 |
+
# Статистики по временной оси (seq_len)
|
713 |
+
mean_temp = x.mean(dim=1) # [batch, hidden_dim]
|
714 |
+
|
715 |
+
# Статистики по feature оси (hidden_dim)
|
716 |
+
mean_feat = x.mean(dim=-1) # [batch, seq_len]
|
717 |
+
|
718 |
+
return mean_temp, mean_feat
|
719 |
+
|
720 |
+
def forward(self, audio_features, text_features):
|
721 |
+
# Проекция признаков
|
722 |
+
audio = self.audio_proj(audio_features.float())
|
723 |
+
text = self.text_proj(text_features.float())
|
724 |
+
|
725 |
+
# Адаптивный пулинг
|
726 |
+
min_len = min(audio.size(1), text.size(1))
|
727 |
+
audio = self.adaptive_temporal_pool(audio, min_len)
|
728 |
+
text = self.adaptive_temporal_pool(text, min_len)
|
729 |
+
|
730 |
+
# Кросс-модальное взаимодействие
|
731 |
+
for i in range(self.tr_layer_number):
|
732 |
+
attn_audio = self.audio_to_text_attn[i](text, audio, audio)
|
733 |
+
attn_text = self.text_to_audio_attn[i](audio, text, text)
|
734 |
+
|
735 |
+
audio = audio + attn_audio
|
736 |
+
text = text + attn_text
|
737 |
+
|
738 |
+
# Агрегация признаков
|
739 |
+
audio_pool_temp, audio_pool_feat = self._pool_features(audio)
|
740 |
+
text_pool_temp, text_pool_feat = self._pool_features(text)
|
741 |
+
|
742 |
+
# print(audio_pool_temp.shape, audio_pool_feat.shape, text_pool_temp.shape, text_pool_feat.shape)
|
743 |
+
|
744 |
+
graph_feat = self.graph_fusion_feat(audio_pool_feat, text_pool_feat)
|
745 |
+
graph_temp = self.graph_fusion_temp(audio_pool_temp, text_pool_temp)
|
746 |
+
|
747 |
+
gated_feat = self.gated_feat(graph_feat[:, 0, :], graph_feat[:, 1, :])
|
748 |
+
gated_temp = self.gated_temp(graph_temp[:, 0, :], graph_temp[:, 1, :])
|
749 |
+
|
750 |
+
fused_feat = self.fc_graph_feat(torch.mean(graph_feat, dim=1)) + self.fc_gated_feat(gated_feat)
|
751 |
+
        fused_temp = self.fc_graph_temp(torch.mean(graph_temp, dim=1)) + self.fc_gated_temp(gated_temp)
|
752 |
+
|
753 |
+
# print(graph_feat.shape, graph_temp.shape)
|
754 |
+
# print(torch.mean(graph_feat, dim=1).shape, torch.mean(graph_temp, dim=1).shape)
|
755 |
+
|
756 |
+
# graph_feat = self.fc_feat(graph_feat)
|
757 |
+
# graph_temp = self.fc_temp(graph_temp)
|
758 |
+
|
759 |
+
        # Classification
|
760 |
+
features = torch.cat([fused_feat, fused_temp], dim=1)
|
761 |
+
|
762 |
+
# print(graph_feat.shape, graph_temp.shape, features.shape)
|
763 |
+
return self.classifier(features)
|
764 |
+
|
765 |
+
def adaptive_temporal_pool(self, x, target_len):
|
766 |
+
"""Адаптивное изменение временной длины"""
|
767 |
+
if x.size(1) == target_len:
|
768 |
+
return x
|
769 |
+
|
770 |
+
return F.interpolate(
|
771 |
+
x.permute(0, 2, 1),
|
772 |
+
size=target_len,
|
773 |
+
mode='linear',
|
774 |
+
align_corners=False
|
775 |
+
).permute(0, 2, 1)
|
776 |
+
|
777 |
+
def _init_weights(self):
|
778 |
+
for m in self.modules():
|
779 |
+
if isinstance(m, nn.Linear):
|
780 |
+
nn.init.xavier_uniform_(m.weight)
|
781 |
+
if m.bias is not None:
|
782 |
+
nn.init.constant_(m.bias, 0)
|
783 |
+
elif isinstance(m, nn.LayerNorm):
|
784 |
+
nn.init.constant_(m.weight, 1)
|
785 |
+
nn.init.constant_(m.bias, 0)
|
786 |
+
|
787 |
+
|
788 |
+
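# The models above align audio and text to a common number of time steps with
# adaptive_temporal_pool, which is linear interpolation over the temporal axis.
# A standalone sketch of that operation (the sequence lengths here are arbitrary):
def _demo_temporal_alignment():
    x = torch.randn(2, 96, 512)                      # (batch, seq_len, hidden_dim)
    target_len = 44
    aligned = F.interpolate(x.permute(0, 2, 1),      # -> (batch, hidden_dim, seq_len)
                            size=target_len, mode='linear', align_corners=False)
    return aligned.permute(0, 2, 1).shape            # torch.Size([2, 44, 512])
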
class BiFormerWithProb(nn.Module):
|
789 |
+
def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512, hidden_dim_gated=128,
|
790 |
+
num_transformer_heads=2, num_graph_heads=2, positional_encoding=True, dropout=0.1, mode='mean',
|
791 |
+
device="cuda", tr_layer_number=1, out_features=128, num_classes=7):
|
792 |
+
super(BiFormerWithProb, self).__init__()
|
793 |
+
|
794 |
+
self.mode = mode
|
795 |
+
self.hidden_dim = hidden_dim
|
796 |
+
self.seg_len = seg_len
|
797 |
+
self.tr_layer_number = tr_layer_number
|
798 |
+
|
799 |
+
# Проекционные слои с нормализацией
|
800 |
+
self.audio_proj = nn.Sequential(
|
801 |
+
nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity(),
|
802 |
+
nn.LayerNorm(hidden_dim),
|
803 |
+
nn.Dropout(dropout)
|
804 |
+
)
|
805 |
+
|
806 |
+
self.text_proj = nn.Sequential(
|
807 |
+
nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity(),
|
808 |
+
nn.LayerNorm(hidden_dim),
|
809 |
+
nn.Dropout(dropout)
|
810 |
+
)
|
811 |
+
# self.audio_proj = nn.Sequential(
|
812 |
+
# nn.Conv1d(audio_dim, hidden_dim, 1),
|
813 |
+
# nn.GELU(),
|
814 |
+
# )
|
815 |
+
|
816 |
+
# self.text_proj = nn.Sequential(
|
817 |
+
# nn.Conv1d(text_dim, hidden_dim, 1),
|
818 |
+
# nn.GELU(),
|
819 |
+
# )
|
820 |
+
|
821 |
+
# Трансформерные слои (сохраняем вашу реализацию)
|
822 |
+
self.audio_to_text_attn = nn.ModuleList([
|
823 |
+
TransformerEncoderLayer(
|
824 |
+
input_dim=hidden_dim,
|
825 |
+
num_heads=num_transformer_heads,
|
826 |
+
dropout=dropout,
|
827 |
+
positional_encoding=positional_encoding
|
828 |
+
) for _ in range(tr_layer_number)
|
829 |
+
])
|
830 |
+
|
831 |
+
self.text_to_audio_attn = nn.ModuleList([
|
832 |
+
TransformerEncoderLayer(
|
833 |
+
input_dim=hidden_dim,
|
834 |
+
num_heads=num_transformer_heads,
|
835 |
+
dropout=dropout,
|
836 |
+
positional_encoding=positional_encoding
|
837 |
+
) for _ in range(tr_layer_number)
|
838 |
+
])
|
839 |
+
|
840 |
+
# Автоматический расчёт размерности для классификатора
|
841 |
+
self._calculate_classifier_input_dim()
|
842 |
+
|
843 |
+
# Классификатор
|
844 |
+
self.classifier = nn.Sequential(
|
845 |
+
nn.Linear(self.classifier_input_dim, out_features),
|
846 |
+
nn.LayerNorm(out_features),
|
847 |
+
nn.GELU(),
|
848 |
+
nn.Dropout(dropout),
|
849 |
+
nn.Linear(out_features, num_classes)
|
850 |
+
)
|
851 |
+
|
852 |
+
self.pred_fusion = PredictionsFusion(num_matrices=3, num_classes=num_classes)
|
853 |
+
|
854 |
+
self._init_weights()
|
855 |
+
|
856 |
+
def _calculate_classifier_input_dim(self):
|
857 |
+
"""Вычисляет размер входных признаков для классификатора"""
|
858 |
+
# Тестовый проход через пулинг с dummy-данными
|
859 |
+
dummy_audio = torch.randn(1, self.seg_len, self.hidden_dim)
|
860 |
+
dummy_text = torch.randn(1, self.seg_len, self.hidden_dim)
|
861 |
+
|
862 |
+
audio_pool = self._pool_features(dummy_audio)
|
863 |
+
text_pool = self._pool_features(dummy_text)
|
864 |
+
|
865 |
+
combined = torch.cat([audio_pool, text_pool], dim=1)
|
866 |
+
self.classifier_input_dim = combined.size(1)
|
867 |
+
|
868 |
+
def _pool_features(self, x):
|
869 |
+
# Статистики по временной оси (seq_len)
|
870 |
+
mean_temp = x.mean(dim=1) # [batch, hidden_dim]
|
871 |
+
|
872 |
+
# Статистики по feature оси (hidden_dim)
|
873 |
+
mean_feat = x.mean(dim=-1) # [batch, seq_len]
|
874 |
+
|
875 |
+
return torch.cat([mean_temp, mean_feat], dim=1)
|
876 |
+
|
877 |
+
def forward(self, audio_features, text_features, audio_pred, text_pred):
|
878 |
+
# Проекция признаков
|
879 |
+
# audio = self.audio_proj(audio_features.permute(0,2,1)).permute(0,2,1)
|
880 |
+
# text = self.text_proj(text_features.permute(0,2,1)).permute(0,2,1)
|
881 |
+
audio = self.audio_proj(audio_features.float())
|
882 |
+
text = self.text_proj(text_features.float())
|
883 |
+
|
884 |
+
# Адаптивный пулинг
|
885 |
+
min_len = min(audio.size(1), text.size(1))
|
886 |
+
audio = self.adaptive_temporal_pool(audio, min_len)
|
887 |
+
text = self.adaptive_temporal_pool(text, min_len)
|
888 |
+
|
889 |
+
# Кросс-модальное взаимодействие
|
890 |
+
for i in range(self.tr_layer_number):
|
891 |
+
attn_audio = self.audio_to_text_attn[i](text, audio, audio)
|
892 |
+
attn_text = self.text_to_audio_attn[i](audio, text, text)
|
893 |
+
audio = audio + attn_audio
|
894 |
+
text = text + attn_text
|
895 |
+
|
896 |
+
# Агрегация признаков
|
897 |
+
audio_pool = self._pool_features(audio)
|
898 |
+
text_pool = self._pool_features(text)
|
899 |
+
|
900 |
+
# Классификация
|
901 |
+
features = torch.cat([audio_pool, text_pool], dim=1)
|
902 |
+
|
903 |
+
out = self.classifier(features)
|
904 |
+
|
905 |
+
w_out = self.pred_fusion([audio_pred, text_pred, out])
|
906 |
+
return w_out
|
907 |
+
|
908 |
+
def adaptive_temporal_pool(self, x, target_len):
|
909 |
+
"""Адаптивное изменение временной длины"""
|
910 |
+
if x.size(1) == target_len:
|
911 |
+
return x
|
912 |
+
|
913 |
+
return F.interpolate(
|
914 |
+
x.permute(0, 2, 1),
|
915 |
+
size=target_len,
|
916 |
+
mode='linear',
|
917 |
+
align_corners=False
|
918 |
+
).permute(0, 2, 1)
|
919 |
+
|
920 |
+
def _init_weights(self):
|
921 |
+
for m in self.modules():
|
922 |
+
if isinstance(m, nn.Linear):
|
923 |
+
nn.init.xavier_uniform_(m.weight)
|
924 |
+
if m.bias is not None:
|
925 |
+
nn.init.constant_(m.bias, 0)
|
926 |
+
elif isinstance(m, nn.LayerNorm):
|
927 |
+
nn.init.constant_(m.weight, 1)
|
928 |
+
nn.init.constant_(m.bias, 0)
|
929 |
+
|
930 |
+
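# BiFormerWithProb combines its own logits with per-modality predictions through
# PredictionsFusion (decision-level fusion). A minimal usage sketch; the batch size and
# the assumption that audio_pred/text_pred are (batch, num_classes) logits are illustrative.
def _demo_biformer_with_prob():
    model = BiFormerWithProb(seg_len=44, hidden_dim=512, num_classes=7)
    audio = torch.randn(2, 44, 1024)
    text = torch.randn(2, 44, 1024)
    audio_pred = torch.randn(2, 7)   # logits from a unimodal audio classifier (assumed)
    text_pred = torch.randn(2, 7)    # logits from a unimodal text classifier (assumed)
    return model(audio, text, audio_pred, text_pred).shape  # torch.Size([2, 7])
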
class BiGraphFormerWithProb(nn.Module):
|
931 |
+
def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512, hidden_dim_gated=128,
|
932 |
+
num_transformer_heads=2, num_graph_heads = 2, positional_encoding=True, dropout=0.1, mode='mean',
|
933 |
+
device="cuda", tr_layer_number=1, out_features=128, num_classes=7):
|
934 |
+
super(BiGraphFormerWithProb, self).__init__()
|
935 |
+
|
936 |
+
self.mode = mode
|
937 |
+
self.hidden_dim = hidden_dim
|
938 |
+
self.seg_len = seg_len
|
939 |
+
self.tr_layer_number = tr_layer_number
|
940 |
+
|
941 |
+
# Проекционные слои с нормализацией
|
942 |
+
self.audio_proj = nn.Sequential(
|
943 |
+
nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity(),
|
944 |
+
nn.LayerNorm(hidden_dim),
|
945 |
+
nn.Dropout(dropout)
|
946 |
+
)
|
947 |
+
|
948 |
+
self.text_proj = nn.Sequential(
|
949 |
+
nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity(),
|
950 |
+
nn.LayerNorm(hidden_dim),
|
951 |
+
nn.Dropout(dropout)
|
952 |
+
)
|
953 |
+
|
954 |
+
# Трансформерные слои (сохраняем вашу реализацию)
|
955 |
+
self.audio_to_text_attn = nn.ModuleList([
|
956 |
+
TransformerEncoderLayer(
|
957 |
+
input_dim=hidden_dim,
|
958 |
+
num_heads=num_transformer_heads,
|
959 |
+
dropout=dropout,
|
960 |
+
positional_encoding=positional_encoding
|
961 |
+
) for _ in range(tr_layer_number)
|
962 |
+
])
|
963 |
+
|
964 |
+
self.text_to_audio_attn = nn.ModuleList([
|
965 |
+
TransformerEncoderLayer(
|
966 |
+
input_dim=hidden_dim,
|
967 |
+
num_heads=num_transformer_heads,
|
968 |
+
dropout=dropout,
|
969 |
+
positional_encoding=positional_encoding
|
970 |
+
) for _ in range(tr_layer_number)
|
971 |
+
])
|
972 |
+
|
973 |
+
self.graph_fusion_feat = GraphFusionLayer(self.seg_len, heads=num_graph_heads)
|
974 |
+
self.graph_fusion_temp = GraphFusionLayer(hidden_dim, heads=num_graph_heads)
|
975 |
+
|
976 |
+
# Автоматический расчёт размерности для классификатора
|
977 |
+
self._calculate_classifier_input_dim()
|
978 |
+
|
979 |
+
# Классификатор
|
980 |
+
self.classifier = nn.Sequential(
|
981 |
+
nn.Linear(self.classifier_input_dim, out_features),
|
982 |
+
nn.LayerNorm(out_features),
|
983 |
+
nn.GELU(),
|
984 |
+
nn.Dropout(dropout),
|
985 |
+
nn.Linear(out_features, num_classes)
|
986 |
+
)
|
987 |
+
|
988 |
+
# Финальная проекция графов
|
989 |
+
self.fc_feat = nn.Sequential(
|
990 |
+
nn.Linear(self.seg_len, self.seg_len),
|
991 |
+
nn.LayerNorm(self.seg_len),
|
992 |
+
nn.Dropout(dropout)
|
993 |
+
)
|
994 |
+
|
995 |
+
self.fc_temp = nn.Sequential(
|
996 |
+
nn.Linear(hidden_dim, hidden_dim),
|
997 |
+
nn.LayerNorm(hidden_dim),
|
998 |
+
nn.Dropout(dropout)
|
999 |
+
)
|
1000 |
+
|
1001 |
+
self.pred_fusion = PredictionsFusion(num_matrices=3, num_classes=num_classes)
|
1002 |
+
|
1003 |
+
self._init_weights()
|
1004 |
+
|
1005 |
+
def _calculate_classifier_input_dim(self):
|
1006 |
+
"""Вычисляет размер входных признаков для классификатора"""
|
1007 |
+
# Тестовый проход через пулинг с dummy-данными
|
1008 |
+
dummy_audio = torch.randn(1, self.seg_len, self.hidden_dim)
|
1009 |
+
dummy_text = torch.randn(1, self.seg_len, self.hidden_dim)
|
1010 |
+
|
1011 |
+
audio_pool_temp, audio_pool_feat = self._pool_features(dummy_audio)
|
1012 |
+
# text_pool_temp, _ = self._pool_features(dummy_text)
|
1013 |
+
|
1014 |
+
combined = torch.cat([audio_pool_temp, audio_pool_feat], dim=1)
|
1015 |
+
self.classifier_input_dim = combined.size(1)
|
1016 |
+
|
1017 |
+
def _pool_features(self, x):
|
1018 |
+
# Статистики по временной оси (seq_len)
|
1019 |
+
mean_temp = x.mean(dim=1) # [batch, hidden_dim]
|
1020 |
+
|
1021 |
+
# Статистики по feature оси (hidden_dim)
|
1022 |
+
mean_feat = x.mean(dim=-1) # [batch, seq_len]
|
1023 |
+
|
1024 |
+
return mean_temp, mean_feat
|
1025 |
+
|
1026 |
+
def forward(self, audio_features, text_features, audio_pred, text_pred):
|
1027 |
+
# Проекция признаков
|
1028 |
+
audio = self.audio_proj(audio_features.float())
|
1029 |
+
text = self.text_proj(text_features.float())
|
1030 |
+
|
1031 |
+
# Адаптивный пулинг
|
1032 |
+
min_len = min(audio.size(1), text.size(1))
|
1033 |
+
audio = self.adaptive_temporal_pool(audio, min_len)
|
1034 |
+
text = self.adaptive_temporal_pool(text, min_len)
|
1035 |
+
|
1036 |
+
# Кросс-модальное взаимодействие
|
1037 |
+
for i in range(self.tr_layer_number):
|
1038 |
+
attn_audio = self.audio_to_text_attn[i](text, audio, audio)
|
1039 |
+
attn_text = self.text_to_audio_attn[i](audio, text, text)
|
1040 |
+
|
1041 |
+
audio = audio + attn_audio
|
1042 |
+
text = text + attn_text
|
1043 |
+
|
1044 |
+
# Агрегация признаков
|
1045 |
+
audio_pool_temp, audio_pool_feat = self._pool_features(audio)
|
1046 |
+
text_pool_temp, text_pool_feat = self._pool_features(text)
|
1047 |
+
|
1048 |
+
# print(audio_pool_temp.shape, audio_pool_feat.shape, text_pool_temp.shape, text_pool_feat.shape)
|
1049 |
+
|
1050 |
+
graph_feat = self.graph_fusion_feat(audio_pool_feat, text_pool_feat)
|
1051 |
+
graph_temp = self.graph_fusion_temp(audio_pool_temp, text_pool_temp)
|
1052 |
+
|
1053 |
+
# print(graph_feat.shape, graph_temp.shape)
|
1054 |
+
# print(torch.mean(graph_feat, dim=1).shape, torch.mean(graph_temp, dim=1).shape)
|
1055 |
+
|
1056 |
+
# graph_feat = self.fc_feat(graph_feat)
|
1057 |
+
# graph_temp = self.fc_temp(graph_temp)
|
1058 |
+
|
1059 |
+
# Классификация
|
1060 |
+
features = torch.cat([graph_feat, graph_temp], dim=1)
|
1061 |
+
|
1062 |
+
# print(graph_feat.shape, graph_temp.shape, features.shape)
|
1063 |
+
out = self.classifier(features)
|
1064 |
+
|
1065 |
+
w_out = self.pred_fusion([audio_pred, text_pred, out])
|
1066 |
+
return w_out
|
1067 |
+
|
1068 |
+
def adaptive_temporal_pool(self, x, target_len):
|
1069 |
+
"""Адаптивное изменение временной длины"""
|
1070 |
+
if x.size(1) == target_len:
|
1071 |
+
return x
|
1072 |
+
|
1073 |
+
return F.interpolate(
|
1074 |
+
x.permute(0, 2, 1),
|
1075 |
+
size=target_len,
|
1076 |
+
mode='linear',
|
1077 |
+
align_corners=False
|
1078 |
+
).permute(0, 2, 1)
|
1079 |
+
|
1080 |
+
def _init_weights(self):
|
1081 |
+
for m in self.modules():
|
1082 |
+
if isinstance(m, nn.Linear):
|
1083 |
+
nn.init.xavier_uniform_(m.weight)
|
1084 |
+
if m.bias is not None:
|
1085 |
+
nn.init.constant_(m.bias, 0)
|
1086 |
+
elif isinstance(m, nn.LayerNorm):
|
1087 |
+
nn.init.constant_(m.weight, 1)
|
1088 |
+
nn.init.constant_(m.bias, 0)
|
1089 |
+
|
1090 |
+
class BiGatedGraphFormerWithProb(nn.Module):
|
1091 |
+
def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512, hidden_dim_gated=128,
|
1092 |
+
num_transformer_heads=2, num_graph_heads = 2, positional_encoding=True, dropout=0.1, mode='mean',
|
1093 |
+
device="cuda", tr_layer_number=1, out_features=128, num_classes=7):
|
1094 |
+
super(BiGatedGraphFormerWithProb, self).__init__()
|
1095 |
+
|
1096 |
+
self.mode = mode
|
1097 |
+
self.hidden_dim = hidden_dim
|
1098 |
+
self.seg_len = seg_len
|
1099 |
+
self.tr_layer_number = tr_layer_number
|
1100 |
+
|
1101 |
+
# Проекционные слои с нормализацией
|
1102 |
+
self.audio_proj = nn.Sequential(
|
1103 |
+
nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity(),
|
1104 |
+
nn.LayerNorm(hidden_dim),
|
1105 |
+
nn.Dropout(dropout)
|
1106 |
+
)
|
1107 |
+
|
1108 |
+
self.text_proj = nn.Sequential(
|
1109 |
+
nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity(),
|
1110 |
+
nn.LayerNorm(hidden_dim),
|
1111 |
+
nn.Dropout(dropout)
|
1112 |
+
)
|
1113 |
+
|
1114 |
+
# Трансформерные слои (сохраняем вашу реализацию)
|
1115 |
+
self.audio_to_text_attn = nn.ModuleList([
|
1116 |
+
TransformerEncoderLayer(
|
1117 |
+
input_dim=hidden_dim,
|
1118 |
+
num_heads=num_transformer_heads,
|
1119 |
+
dropout=dropout,
|
1120 |
+
positional_encoding=positional_encoding
|
1121 |
+
) for _ in range(tr_layer_number)
|
1122 |
+
])
|
1123 |
+
|
1124 |
+
self.text_to_audio_attn = nn.ModuleList([
|
1125 |
+
TransformerEncoderLayer(
|
1126 |
+
input_dim=hidden_dim,
|
1127 |
+
num_heads=num_transformer_heads,
|
1128 |
+
dropout=dropout,
|
1129 |
+
positional_encoding=positional_encoding
|
1130 |
+
) for _ in range(tr_layer_number)
|
1131 |
+
])
|
1132 |
+
|
1133 |
+
self.graph_fusion_feat = GraphFusionLayer(self.seg_len, heads=num_graph_heads, out_mean=False)
|
1134 |
+
self.graph_fusion_temp = GraphFusionLayer(hidden_dim, heads=num_graph_heads, out_mean=False)
|
1135 |
+
|
1136 |
+
self.gated_feat = GAL(self.seg_len, self.seg_len, hidden_dim_gated, dropout_rate=dropout)
|
1137 |
+
self.gated_temp = GAL(hidden_dim, hidden_dim, hidden_dim_gated, dropout_rate=dropout)
|
1138 |
+
|
1139 |
+
# Автоматический расчёт размерности для классификатора
|
1140 |
+
self._calculate_classifier_input_dim()
|
1141 |
+
|
1142 |
+
# Классификатор
|
1143 |
+
self.classifier = nn.Sequential(
|
1144 |
+
nn.Linear(hidden_dim_gated*2, out_features),
|
1145 |
+
nn.LayerNorm(out_features),
|
1146 |
+
nn.GELU(),
|
1147 |
+
nn.Dropout(dropout),
|
1148 |
+
nn.Linear(out_features, num_classes)
|
1149 |
+
)
|
1150 |
+
|
1151 |
+
# Финальная проекция графов
|
1152 |
+
self.fc_graph_feat = nn.Sequential(
|
1153 |
+
nn.Linear(self.seg_len, hidden_dim_gated),
|
1154 |
+
nn.LayerNorm(hidden_dim_gated),
|
1155 |
+
nn.Dropout(dropout)
|
1156 |
+
)
|
1157 |
+
|
1158 |
+
self.fc_graph_temp = nn.Sequential(
|
1159 |
+
nn.Linear(hidden_dim, hidden_dim_gated),
|
1160 |
+
nn.LayerNorm(hidden_dim_gated),
|
1161 |
+
nn.Dropout(dropout)
|
1162 |
+
)
|
1163 |
+
|
1164 |
+
# Финальная проекция gated
|
1165 |
+
self.fc_gated_feat = nn.Sequential(
|
1166 |
+
nn.Linear(hidden_dim_gated, hidden_dim_gated),
|
1167 |
+
nn.LayerNorm(hidden_dim_gated),
|
1168 |
+
nn.Dropout(dropout)
|
1169 |
+
)
|
1170 |
+
|
1171 |
+
self.fc_gated_temp = nn.Sequential(
|
1172 |
+
nn.Linear(hidden_dim_gated, hidden_dim_gated),
|
1173 |
+
nn.LayerNorm(hidden_dim_gated),
|
1174 |
+
nn.Dropout(dropout)
|
1175 |
+
)
|
1176 |
+
|
1177 |
+
self.pred_fusion = PredictionsFusion(num_matrices=3, num_classes=num_classes)
|
1178 |
+
|
1179 |
+
self._init_weights()
|
1180 |
+
|
1181 |
+
def _calculate_classifier_input_dim(self):
|
1182 |
+
"""Вычисляет размер входных признаков для классификатора"""
|
1183 |
+
# Тестовый проход через пулинг с dummy-данными
|
1184 |
+
dummy_audio = torch.randn(1, self.seg_len, self.hidden_dim)
|
1185 |
+
dummy_text = torch.randn(1, self.seg_len, self.hidden_dim)
|
1186 |
+
|
1187 |
+
audio_pool_temp, audio_pool_feat = self._pool_features(dummy_audio)
|
1188 |
+
# text_pool_temp, _ = self._pool_features(dummy_text)
|
1189 |
+
|
1190 |
+
combined = torch.cat([audio_pool_temp, audio_pool_feat], dim=1)
|
1191 |
+
self.classifier_input_dim = combined.size(1)
|
1192 |
+
|
1193 |
+
def _pool_features(self, x):
|
1194 |
+
# Статистики по временной оси (seq_len)
|
1195 |
+
mean_temp = x.mean(dim=1) # [batch, hidden_dim]
|
1196 |
+
|
1197 |
+
# Статистики по feature оси (hidden_dim)
|
1198 |
+
mean_feat = x.mean(dim=-1) # [batch, seq_len]
|
1199 |
+
|
1200 |
+
return mean_temp, mean_feat
|
1201 |
+
|
1202 |
+
def forward(self, audio_features, text_features, audio_pred, text_pred):
|
1203 |
+
# Проекция признаков
|
1204 |
+
audio = self.audio_proj(audio_features.float())
|
1205 |
+
text = self.text_proj(text_features.float())
|
1206 |
+
|
1207 |
+
# Адаптивный пулинг
|
1208 |
+
min_len = min(audio.size(1), text.size(1))
|
1209 |
+
audio = self.adaptive_temporal_pool(audio, min_len)
|
1210 |
+
text = self.adaptive_temporal_pool(text, min_len)
|
1211 |
+
|
1212 |
+
# Кросс-модальное взаимодействие
|
1213 |
+
for i in range(self.tr_layer_number):
|
1214 |
+
attn_audio = self.audio_to_text_attn[i](text, audio, audio)
|
1215 |
+
attn_text = self.text_to_audio_attn[i](audio, text, text)
|
1216 |
+
|
1217 |
+
audio = audio + attn_audio
|
1218 |
+
text = text + attn_text
|
1219 |
+
|
1220 |
+
# Агрегация признаков
|
1221 |
+
audio_pool_temp, audio_pool_feat = self._pool_features(audio)
|
1222 |
+
text_pool_temp, text_pool_feat = self._pool_features(text)
|
1223 |
+
|
1224 |
+
# print(audio_pool_temp.shape, audio_pool_feat.shape, text_pool_temp.shape, text_pool_feat.shape)
|
1225 |
+
|
1226 |
+
graph_feat = self.graph_fusion_feat(audio_pool_feat, text_pool_feat)
|
1227 |
+
graph_temp = self.graph_fusion_temp(audio_pool_temp, text_pool_temp)
|
1228 |
+
|
1229 |
+
gated_feat = self.gated_feat(graph_feat[:, 0, :], graph_feat[:, 1, :])
|
1230 |
+
gated_temp = self.gated_temp(graph_temp[:, 0, :], graph_temp[:, 1, :])
|
1231 |
+
|
1232 |
+
fused_feat = self.fc_graph_feat(torch.mean(graph_feat, dim=1)) + self.fc_gated_feat(gated_feat)
|
1233 |
+
        fused_temp = self.fc_graph_temp(torch.mean(graph_temp, dim=1)) + self.fc_gated_temp(gated_temp)
|
1234 |
+
|
1235 |
+
# print(graph_feat.shape, graph_temp.shape)
|
1236 |
+
# print(torch.mean(graph_feat, dim=1).shape, torch.mean(graph_temp, dim=1).shape)
|
1237 |
+
|
1238 |
+
# graph_feat = self.fc_feat(graph_feat)
|
1239 |
+
# graph_temp = self.fc_temp(graph_temp)
|
1240 |
+
|
1241 |
+
# Классификация
|
1242 |
+
features = torch.cat([fused_feat, fused_temp], dim=1)
|
1243 |
+
|
1244 |
+
# print(graph_feat.shape, graph_temp.shape, features.shape)
|
1245 |
+
out = self.classifier(features)
|
1246 |
+
|
1247 |
+
w_out = self.pred_fusion([audio_pred, text_pred, out])
|
1248 |
+
return w_out
|
1249 |
+
|
1250 |
+
def adaptive_temporal_pool(self, x, target_len):
|
1251 |
+
"""Адаптивное изменение временной длины"""
|
1252 |
+
if x.size(1) == target_len:
|
1253 |
+
return x
|
1254 |
+
|
1255 |
+
return F.interpolate(
|
1256 |
+
x.permute(0, 2, 1),
|
1257 |
+
size=target_len,
|
1258 |
+
mode='linear',
|
1259 |
+
align_corners=False
|
1260 |
+
).permute(0, 2, 1)
|
1261 |
+
|
1262 |
+
def _init_weights(self):
|
1263 |
+
for m in self.modules():
|
1264 |
+
if isinstance(m, nn.Linear):
|
1265 |
+
nn.init.xavier_uniform_(m.weight)
|
1266 |
+
if m.bias is not None:
|
1267 |
+
nn.init.constant_(m.bias, 0)
|
1268 |
+
elif isinstance(m, nn.LayerNorm):
|
1269 |
+
nn.init.constant_(m.weight, 1)
|
1270 |
+
nn.init.constant_(m.bias, 0)
|
1271 |
+
|
1272 |
+
class BiGatedFormer(nn.Module):
|
1273 |
+
def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512, hidden_dim_gated=128,
|
1274 |
+
num_transformer_heads=2, num_graph_heads = 2, positional_encoding=True, dropout=0.1, mode='mean',
|
1275 |
+
device="cuda", tr_layer_number=1, out_features=128, num_classes=7):
|
1276 |
+
super(BiGatedFormer, self).__init__()
|
1277 |
+
|
1278 |
+
self.mode = mode
|
1279 |
+
self.hidden_dim = hidden_dim
|
1280 |
+
self.seg_len = seg_len
|
1281 |
+
self.tr_layer_number = tr_layer_number
|
1282 |
+
|
1283 |
+
# Проекционные слои с нормализацией
|
1284 |
+
self.audio_proj = nn.Sequential(
|
1285 |
+
nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity(),
|
1286 |
+
nn.LayerNorm(hidden_dim),
|
1287 |
+
nn.Dropout(dropout)
|
1288 |
+
)
|
1289 |
+
|
1290 |
+
self.text_proj = nn.Sequential(
|
1291 |
+
nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity(),
|
1292 |
+
nn.LayerNorm(hidden_dim),
|
1293 |
+
nn.Dropout(dropout)
|
1294 |
+
)
|
1295 |
+
|
1296 |
+
# Трансформерные слои (сохраняем вашу реализацию)
|
1297 |
+
self.audio_to_text_attn = nn.ModuleList([
|
1298 |
+
TransformerEncoderLayer(
|
1299 |
+
input_dim=hidden_dim,
|
1300 |
+
num_heads=num_transformer_heads,
|
1301 |
+
dropout=dropout,
|
1302 |
+
positional_encoding=positional_encoding
|
1303 |
+
) for _ in range(tr_layer_number)
|
1304 |
+
])
|
1305 |
+
|
1306 |
+
self.text_to_audio_attn = nn.ModuleList([
|
1307 |
+
TransformerEncoderLayer(
|
1308 |
+
input_dim=hidden_dim,
|
1309 |
+
num_heads=num_transformer_heads,
|
1310 |
+
dropout=dropout,
|
1311 |
+
positional_encoding=positional_encoding
|
1312 |
+
) for _ in range(tr_layer_number)
|
1313 |
+
])
|
1314 |
+
|
1315 |
+
# self.graph_fusion_feat = GraphFusionLayer(self.seg_len, heads=num_graph_heads, out_mean=False)
|
1316 |
+
# self.graph_fusion_temp = GraphFusionLayer(hidden_dim, heads=num_graph_heads, out_mean=False)
|
1317 |
+
|
1318 |
+
self.gated_feat = GAL(self.seg_len, self.seg_len, hidden_dim_gated, dropout_rate=dropout)
|
1319 |
+
self.gated_temp = GAL(hidden_dim, hidden_dim, hidden_dim_gated, dropout_rate=dropout)
|
1320 |
+
|
1321 |
+
# Автоматический расчёт размерности для классификатора
|
1322 |
+
self._calculate_classifier_input_dim()
|
1323 |
+
|
1324 |
+
# Классификатор
|
1325 |
+
self.classifier = nn.Sequential(
|
1326 |
+
nn.Linear(hidden_dim_gated*2, out_features),
|
1327 |
+
nn.LayerNorm(out_features),
|
1328 |
+
nn.GELU(),
|
1329 |
+
nn.Dropout(dropout),
|
1330 |
+
nn.Linear(out_features, num_classes)
|
1331 |
+
)
|
1332 |
+
|
1333 |
+
# Финальная проекция графов
|
1334 |
+
# self.fc_graph_feat = nn.Sequential(
|
1335 |
+
# nn.Linear(self.seg_len, hidden_dim_gated),
|
1336 |
+
# nn.LayerNorm(hidden_dim_gated),
|
1337 |
+
# nn.Dropout(dropout)
|
1338 |
+
# )
|
1339 |
+
|
1340 |
+
# self.fc_graph_temp = nn.Sequential(
|
1341 |
+
# nn.Linear(hidden_dim, hidden_dim_gated),
|
1342 |
+
# nn.LayerNorm(hidden_dim_gated),
|
1343 |
+
# nn.Dropout(dropout)
|
1344 |
+
# )
|
1345 |
+
|
1346 |
+
# Финальная проекция gated
|
1347 |
+
self.fc_gated_feat = nn.Sequential(
|
1348 |
+
nn.Linear(hidden_dim_gated, hidden_dim_gated),
|
1349 |
+
nn.LayerNorm(hidden_dim_gated),
|
1350 |
+
nn.Dropout(dropout)
|
1351 |
+
)
|
1352 |
+
|
1353 |
+
self.fc_gated_temp = nn.Sequential(
|
1354 |
+
nn.Linear(hidden_dim_gated, hidden_dim_gated),
|
1355 |
+
nn.LayerNorm(hidden_dim_gated),
|
1356 |
+
nn.Dropout(dropout)
|
1357 |
+
)
|
1358 |
+
|
1359 |
+
self._init_weights()
|
1360 |
+
|
1361 |
+
def _calculate_classifier_input_dim(self):
|
1362 |
+
"""Вычисляет размер входных признаков для классификатора"""
|
1363 |
+
# Тестовый проход через пулинг с dummy-данными
|
1364 |
+
dummy_audio = torch.randn(1, self.seg_len, self.hidden_dim)
|
1365 |
+
dummy_text = torch.randn(1, self.seg_len, self.hidden_dim)
|
1366 |
+
|
1367 |
+
audio_pool_temp, audio_pool_feat = self._pool_features(dummy_audio)
|
1368 |
+
# text_pool_temp, _ = self._pool_features(dummy_text)
|
1369 |
+
|
1370 |
+
combined = torch.cat([audio_pool_temp, audio_pool_feat], dim=1)
|
1371 |
+
self.classifier_input_dim = combined.size(1)
|
1372 |
+
|
1373 |
+
def _pool_features(self, x):
|
1374 |
+
# Статистики по временной оси (seq_len)
|
1375 |
+
mean_temp = x.mean(dim=1) # [batch, hidden_dim]
|
1376 |
+
|
1377 |
+
# Статистики по feature оси (hidden_dim)
|
1378 |
+
mean_feat = x.mean(dim=-1) # [batch, seq_len]
|
1379 |
+
|
1380 |
+
return mean_temp, mean_feat
|
1381 |
+
|
1382 |
+
def forward(self, audio_features, text_features):
|
1383 |
+
# Проекция признаков
|
1384 |
+
audio = self.audio_proj(audio_features.float())
|
1385 |
+
text = self.text_proj(text_features.float())
|
1386 |
+
|
1387 |
+
# Адаптивный пулинг
|
1388 |
+
min_len = min(audio.size(1), text.size(1))
|
1389 |
+
audio = self.adaptive_temporal_pool(audio, min_len)
|
1390 |
+
text = self.adaptive_temporal_pool(text, min_len)
|
1391 |
+
|
1392 |
+
# Кросс-модальное взаимодействие
|
1393 |
+
for i in range(self.tr_layer_number):
|
1394 |
+
attn_audio = self.audio_to_text_attn[i](text, audio, audio)
|
1395 |
+
attn_text = self.text_to_audio_attn[i](audio, text, text)
|
1396 |
+
|
1397 |
+
audio = audio + attn_audio
|
1398 |
+
text = text + attn_text
|
1399 |
+
|
1400 |
+
# Агрегация признаков
|
1401 |
+
audio_pool_temp, audio_pool_feat = self._pool_features(audio)
|
1402 |
+
text_pool_temp, text_pool_feat = self._pool_features(text)
|
1403 |
+
|
1404 |
+
# print(audio_pool_temp.shape, audio_pool_feat.shape, text_pool_temp.shape, text_pool_feat.shape)
|
1405 |
+
|
1406 |
+
# graph_feat = self.graph_fusion_feat(audio_pool_feat, text_pool_feat)
|
1407 |
+
# graph_temp = self.graph_fusion_temp(audio_pool_temp, text_pool_temp)
|
1408 |
+
|
1409 |
+
gated_feat = self.gated_feat(audio_pool_feat, text_pool_feat)
|
1410 |
+
gated_temp = self.gated_temp(audio_pool_temp, text_pool_temp)
|
1411 |
+
|
1412 |
+
# fused_feat = self.fc_graph_feat(torch.mean(graph_feat, dim=1)) + self.fc_gated_feat(gated_feat)
|
1413 |
+
# fused_temp = self.fc_graph_temp(torch.mean(graph_temp, dim=1)) + self.fc_gated_feat(gated_temp)
|
1414 |
+
|
1415 |
+
# print(graph_feat.shape, graph_temp.shape)
|
1416 |
+
# print(torch.mean(graph_feat, dim=1).shape, torch.mean(graph_temp, dim=1).shape)
|
1417 |
+
|
1418 |
+
# graph_feat = self.fc_feat(graph_feat)
|
1419 |
+
# graph_temp = self.fc_temp(graph_temp)
|
1420 |
+
|
1421 |
+
# Классификация
|
1422 |
+
features = torch.cat([gated_feat, gated_temp], dim=1)
|
1423 |
+
|
1424 |
+
# print(graph_feat.shape, graph_temp.shape, features.shape)
|
1425 |
+
return self.classifier(features)
|
1426 |
+
|
1427 |
+
def adaptive_temporal_pool(self, x, target_len):
|
1428 |
+
"""Адаптивное изменение временной длины"""
|
1429 |
+
if x.size(1) == target_len:
|
1430 |
+
return x
|
1431 |
+
|
1432 |
+
return F.interpolate(
|
1433 |
+
x.permute(0, 2, 1),
|
1434 |
+
size=target_len,
|
1435 |
+
mode='linear',
|
1436 |
+
align_corners=False
|
1437 |
+
).permute(0, 2, 1)
|
1438 |
+
|
1439 |
+
def _init_weights(self):
|
1440 |
+
for m in self.modules():
|
1441 |
+
if isinstance(m, nn.Linear):
|
1442 |
+
nn.init.xavier_uniform_(m.weight)
|
1443 |
+
if m.bias is not None:
|
1444 |
+
nn.init.constant_(m.bias, 0)
|
1445 |
+
elif isinstance(m, nn.LayerNorm):
|
1446 |
+
nn.init.constant_(m.weight, 1)
|
1447 |
+
nn.init.constant_(m.bias, 0)
|
1448 |
+
|
1449 |
+
class BiMamba(nn.Module):
|
1450 |
+
def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512, mamba_d_state=16,
|
1451 |
+
d_discr=None, mamba_ker_size=4, mamba_layer_number=2, dropout=0.1, mode='', positional_encoding=False,
|
1452 |
+
out_features=128, num_classes=7, device="cuda"):
|
1453 |
+
super(BiMamba, self).__init__()
|
1454 |
+
|
1455 |
+
self.hidden_dim = hidden_dim
|
1456 |
+
self.seg_len = seg_len
|
1457 |
+
self.num_mamba_layers = mamba_layer_number
|
1458 |
+
self.device = device
|
1459 |
+
|
1460 |
+
# Проекционные слои для каждой модальности
|
1461 |
+
self.audio_proj = nn.Sequential(
|
1462 |
+
nn.Linear(audio_dim, hidden_dim),
|
1463 |
+
nn.LayerNorm(hidden_dim),
|
1464 |
+
nn.Dropout(dropout)
|
1465 |
+
)
|
1466 |
+
|
1467 |
+
self.text_proj = nn.Sequential(
|
1468 |
+
nn.Linear(text_dim, hidden_dim),
|
1469 |
+
nn.LayerNorm(hidden_dim),
|
1470 |
+
nn.Dropout(dropout)
|
1471 |
+
)
|
1472 |
+
|
1473 |
+
# Слой для объединения модальностей
|
1474 |
+
self.fusion_proj = nn.Sequential(
|
1475 |
+
nn.Linear(2 * hidden_dim, hidden_dim),
|
1476 |
+
nn.LayerNorm(hidden_dim),
|
1477 |
+
nn.Dropout(dropout)
|
1478 |
+
)
|
1479 |
+
|
1480 |
+
# Mamba блоки для обработки объединенных признаков
|
1481 |
+
mamba_params = {
|
1482 |
+
'd_input': hidden_dim,
|
1483 |
+
'd_model': hidden_dim,
|
1484 |
+
'd_state': mamba_d_state,
|
1485 |
+
'd_discr': d_discr,
|
1486 |
+
'ker_size': mamba_ker_size
|
1487 |
+
}
|
1488 |
+
|
1489 |
+
self.mamba_blocks = nn.ModuleList([
|
1490 |
+
nn.Sequential(
|
1491 |
+
MambaBlock(**mamba_params),
|
1492 |
+
RMSNorm(hidden_dim)
|
1493 |
+
)
|
1494 |
+
for _ in range(self.num_mamba_layers)
|
1495 |
+
])
|
1496 |
+
|
1497 |
+
# Автоматический расчет размерности классификатора
|
1498 |
+
# self._calculate_classifier_input_dim()
|
1499 |
+
|
1500 |
+
# Классификатор
|
1501 |
+
self.classifier = nn.Sequential(
|
1502 |
+
nn.Linear(self.seg_len + self.hidden_dim, out_features),
|
1503 |
+
nn.LayerNorm(out_features),
|
1504 |
+
nn.GELU(),
|
1505 |
+
nn.Dropout(dropout),
|
1506 |
+
nn.Linear(out_features, num_classes)
|
1507 |
+
)
|
1508 |
+
|
1509 |
+
self._init_weights()
|
1510 |
+
|
1511 |
+
# def _calculate_classifier_input_dim(self):
|
1512 |
+
# """Вычисляет размер входных признаков для классификатора"""
|
1513 |
+
# dummy = torch.randn(1, self.seg_len, self.hidden_dim)
|
1514 |
+
# pooled = self._pool_features(dummy)
|
1515 |
+
# self.classifier_input_dim = pooled.size(1)
|
1516 |
+
|
1517 |
+
def _pool_features(self, x):
|
1518 |
+
"""Объединение временных и feature статистик"""
|
1519 |
+
mean_temp = x.mean(dim=1) # [batch, hidden_dim]
|
1520 |
+
mean_feat = x.mean(dim=-1) # [batch, seq_len]
|
1521 |
+
full_feature = torch.cat([mean_temp, mean_feat], dim=1)
|
1522 |
+
if full_feature.shape[-1] == self.seg_len+self.hidden_dim:
|
1523 |
+
return full_feature
|
1524 |
+
else:
|
1525 |
+
pad_size = self.seg_len+self.hidden_dim - full_feature.shape[-1]
|
1526 |
+
return F.pad(full_feature, (0, pad_size), mode="constant", value=0)
|
1527 |
+
|
1528 |
+
def forward(self, audio_features, text_features):
|
1529 |
+
# Проекция признаков
|
1530 |
+
audio = self.audio_proj(audio_features.float()) # [B, T, D]
|
1531 |
+
text = self.text_proj(text_features.float()) # [B, T, D]
|
1532 |
+
|
1533 |
+
# Адаптивный пулинг к минимальной длине
|
1534 |
+
min_len = min(audio.size(1), text.size(1))
|
1535 |
+
audio = self._adaptive_pool(audio, min_len)
|
1536 |
+
text = self._adaptive_pool(text, min_len)
|
1537 |
+
|
1538 |
+
# Объединение модальностей
|
1539 |
+
fused = torch.cat([audio, text], dim=-1) # [B, T, 2*D]
|
1540 |
+
fused = self.fusion_proj(fused) # [B, T, D]
|
1541 |
+
|
1542 |
+
# Обработка объединенных признаков через Mamba
|
1543 |
+
for mamba_block in self.mamba_blocks:
|
1544 |
+
out, _ = mamba_block[0](fused, None)
|
1545 |
+
out = mamba_block[1](out)
|
1546 |
+
fused = fused + out # Residual connection
|
1547 |
+
|
1548 |
+
# Агрегация признаков и классификация
|
1549 |
+
pooled = self._pool_features(fused)
|
1550 |
+
return self.classifier(pooled)
|
1551 |
+
|
1552 |
+
def _adaptive_pool(self, x, target_len):
|
1553 |
+
"""Адаптивное изменение временной длины"""
|
1554 |
+
if x.size(1) == target_len:
|
1555 |
+
return x
|
1556 |
+
return F.interpolate(
|
1557 |
+
x.permute(0, 2, 1),
|
1558 |
+
size=target_len,
|
1559 |
+
mode='linear',
|
1560 |
+
align_corners=False
|
1561 |
+
).permute(0, 2, 1)
|
1562 |
+
|
1563 |
+
def _init_weights(self):
|
1564 |
+
for m in self.modules():
|
1565 |
+
if isinstance(m, nn.Linear):
|
1566 |
+
nn.init.xavier_uniform_(m.weight)
|
1567 |
+
if m.bias is not None:
|
1568 |
+
nn.init.constant_(m.bias, 0)
|
1569 |
+
elif isinstance(m, nn.LayerNorm):
|
1570 |
+
nn.init.constant_(m.weight, 1)
|
1571 |
+
nn.init.constant_(m.bias, 0)
|
1572 |
+
|
1573 |
+
class BiMambaWithProb(nn.Module):
|
1574 |
+
def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512, mamba_d_state=16,
|
1575 |
+
d_discr=None, mamba_ker_size=4, mamba_layer_number=2, dropout=0.1, mode='', positional_encoding=False,
|
1576 |
+
out_features=128, num_classes=7, device="cuda"):
|
1577 |
+
super(BiMambaWithProb, self).__init__()
|
1578 |
+
|
1579 |
+
self.hidden_dim = hidden_dim
|
1580 |
+
self.seg_len = seg_len
|
1581 |
+
self.num_mamba_layers = mamba_layer_number
|
1582 |
+
self.device = device
|
1583 |
+
|
1584 |
+
# Проекционные слои для каждой модальности
|
1585 |
+
self.audio_proj = nn.Sequential(
|
1586 |
+
nn.Linear(audio_dim, hidden_dim),
|
1587 |
+
nn.LayerNorm(hidden_dim),
|
1588 |
+
nn.Dropout(dropout)
|
1589 |
+
)
|
1590 |
+
|
1591 |
+
self.text_proj = nn.Sequential(
|
1592 |
+
nn.Linear(text_dim, hidden_dim),
|
1593 |
+
nn.LayerNorm(hidden_dim),
|
1594 |
+
nn.Dropout(dropout)
|
1595 |
+
)
|
1596 |
+
|
1597 |
+
# Слой для объединения модальностей
|
1598 |
+
self.fusion_proj = nn.Sequential(
|
1599 |
+
nn.Linear(2 * hidden_dim, hidden_dim),
|
1600 |
+
nn.LayerNorm(hidden_dim),
|
1601 |
+
nn.Dropout(dropout)
|
1602 |
+
)
|
1603 |
+
|
1604 |
+
# Mamba блоки для обработки объединенных признаков
|
1605 |
+
mamba_params = {
|
1606 |
+
'd_input': hidden_dim,
|
1607 |
+
'd_model': hidden_dim,
|
1608 |
+
'd_state': mamba_d_state,
|
1609 |
+
'd_discr': d_discr,
|
1610 |
+
'ker_size': mamba_ker_size
|
1611 |
+
}
|
1612 |
+
|
1613 |
+
self.mamba_blocks = nn.ModuleList([
|
1614 |
+
nn.Sequential(
|
1615 |
+
MambaBlock(**mamba_params),
|
1616 |
+
RMSNorm(hidden_dim)
|
1617 |
+
)
|
1618 |
+
for _ in range(self.num_mamba_layers)
|
1619 |
+
])
|
1620 |
+
|
1621 |
+
# Автоматический расчет размерности классификатора
|
1622 |
+
# self._calculate_classifier_input_dim()
|
1623 |
+
|
1624 |
+
# Классификатор
|
1625 |
+
self.classifier = nn.Sequential(
|
1626 |
+
nn.Linear(self.seg_len + self.hidden_dim, out_features),
|
1627 |
+
nn.LayerNorm(out_features),
|
1628 |
+
nn.GELU(),
|
1629 |
+
nn.Dropout(dropout),
|
1630 |
+
nn.Linear(out_features, num_classes)
|
1631 |
+
)
|
1632 |
+
|
1633 |
+
self.pred_fusion = PredictionsFusion(num_matrices=3, num_classes=num_classes)
|
1634 |
+
|
1635 |
+
self._init_weights()
|
1636 |
+
|
1637 |
+
# def _calculate_classifier_input_dim(self):
|
1638 |
+
# """Вычисляет размер входных признаков для классификатора"""
|
1639 |
+
# dummy = torch.randn(1, self.seg_len, self.hidden_dim)
|
1640 |
+
# pooled = self._pool_features(dummy)
|
1641 |
+
# self.classifier_input_dim = pooled.size(1)
|
1642 |
+
|
1643 |
+
def _pool_features(self, x):
|
1644 |
+
"""Объединение временных и feature статистик"""
|
1645 |
+
mean_temp = x.mean(dim=1) # [batch, hidden_dim]
|
1646 |
+
mean_feat = x.mean(dim=-1) # [batch, seq_len]
|
1647 |
+
full_feature = torch.cat([mean_temp, mean_feat], dim=1)
|
1648 |
+
if full_feature.shape[-1] == self.seg_len+self.hidden_dim:
|
1649 |
+
return full_feature
|
1650 |
+
else:
|
1651 |
+
pad_size = self.seg_len+self.hidden_dim - full_feature.shape[-1]
|
1652 |
+
return F.pad(full_feature, (0, pad_size), mode="constant", value=0)
|
1653 |
+
|
1654 |
+
def forward(self, audio_features, text_features, audio_pred, text_pred):
|
1655 |
+
# Проекция признаков
|
1656 |
+
audio = self.audio_proj(audio_features.float()) # [B, T, D]
|
1657 |
+
text = self.text_proj(text_features.float()) # [B, T, D]
|
1658 |
+
|
1659 |
+
# Адаптивный пулинг к минимальной длине
|
1660 |
+
min_len = min(audio.size(1), text.size(1))
|
1661 |
+
audio = self._adaptive_pool(audio, min_len)
|
1662 |
+
text = self._adaptive_pool(text, min_len)
|
1663 |
+
|
1664 |
+
# Объединение модальностей
|
1665 |
+
fused = torch.cat([audio, text], dim=-1) # [B, T, 2*D]
|
1666 |
+
fused = self.fusion_proj(fused) # [B, T, D]
|
1667 |
+
|
1668 |
+
# Обработка объединенных признаков через Mamba
|
1669 |
+
for mamba_block in self.mamba_blocks:
|
1670 |
+
out, _ = mamba_block[0](fused, None)
|
1671 |
+
out = mamba_block[1](out)
|
1672 |
+
fused = fused + out # Residual connection
|
1673 |
+
|
1674 |
+
# Агрегация признаков и классификация
|
1675 |
+
pooled = self._pool_features(fused)
|
1676 |
+
out = self.classifier(pooled)
|
1677 |
+
|
1678 |
+
w_out = self.pred_fusion([audio_pred, text_pred, out])
|
1679 |
+
return w_out
|
1680 |
+
|
1681 |
+
def _adaptive_pool(self, x, target_len):
|
1682 |
+
"""Адаптивное изменение временной длины"""
|
1683 |
+
if x.size(1) == target_len:
|
1684 |
+
return x
|
1685 |
+
return F.interpolate(
|
1686 |
+
x.permute(0, 2, 1),
|
1687 |
+
size=target_len,
|
1688 |
+
mode='linear',
|
1689 |
+
align_corners=False
|
1690 |
+
).permute(0, 2, 1)
|
1691 |
+
|
1692 |
+
def _init_weights(self):
|
1693 |
+
for m in self.modules():
|
1694 |
+
if isinstance(m, nn.Linear):
|
1695 |
+
nn.init.xavier_uniform_(m.weight)
|
1696 |
+
if m.bias is not None:
|
1697 |
+
nn.init.constant_(m.bias, 0)
|
1698 |
+
elif isinstance(m, nn.LayerNorm):
|
1699 |
+
nn.init.constant_(m.weight, 1)
|
1700 |
+
nn.init.constant_(m.bias, 0)
|
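The classes above are consumed by training/train_utils.py later in this commit. As a quick orientation, here is a minimal smoke test of the BiMamba fusion head; it assumes the repository root is on PYTHONPATH and that the MambaBlock/RMSNorm layers from models/help_layers.py run on CPU, so treat it as a sketch rather than part of the committed code.

```python
# Minimal smoke test for the BiMamba fusion head (sketch, not part of the commit).
import torch
from models.models import BiMamba

model = BiMamba(
    audio_dim=1024,   # matches the defaults used by train_utils.py
    text_dim=1024,
    seg_len=44,
    hidden_dim=512,
    num_classes=7,
    device="cpu",
).eval()

# Dummy batch: 2 samples, 44 audio frames and 30 text tokens, feature dim 1024.
audio_features = torch.randn(2, 44, 1024)
text_features = torch.randn(2, 30, 1024)

with torch.no_grad():
    logits = model(audio_features, text_features)

print(logits.shape)  # expected: torch.Size([2, 7])
```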
requirements.txt
ADDED
Binary file (2.25 kB). View file
|
|
run_generation.py
ADDED
@@ -0,0 +1,32 @@
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
from generate_synthetic_dataset import generate_from_emotion_csv
|
4 |
+
|
5 |
+
if __name__ == "__main__":
|
6 |
+
if len(sys.argv) < 2:
|
7 |
+
print("❌ Использование: python run_generation.py path/to/file.csv [num_processes] [device]")
|
8 |
+
sys.exit(1)
|
9 |
+
|
10 |
+
csv_path = sys.argv[1]
|
11 |
+
num_processes = int(sys.argv[2]) if len(sys.argv) > 2 else int(os.environ.get("NUM_DIA_PROCESSES", 1))
|
12 |
+
device = sys.argv[3] if len(sys.argv) > 3 else "cuda"
|
13 |
+
|
14 |
+
filename = os.path.basename(csv_path)
|
15 |
+
try:
|
16 |
+
emotion = filename.split("_")[2]
|
17 |
+
except IndexError:
|
18 |
+
emotion = "unknown"
|
19 |
+
|
20 |
+
print(f"🧪 CSV: {csv_path}")
|
21 |
+
print(f"💻 Устройство: {device}")
|
22 |
+
print(f"🔧 Процессов: {num_processes}")
|
23 |
+
print(f"🎭 Эмоция: {emotion}")
|
24 |
+
|
25 |
+
generate_from_emotion_csv(
|
26 |
+
csv_path=csv_path,
|
27 |
+
emotion=emotion,
|
28 |
+
output_dir="tts_synthetic_final",
|
29 |
+
device=device,
|
30 |
+
max_samples=None,
|
31 |
+
num_processes=num_processes
|
32 |
+
)
|
search_params.toml
ADDED
@@ -0,0 +1,22 @@
1 |
+
[grid]
|
2 |
+
scheduler_type = ["huggingface_cosine_with_restarts", "huggingface_linear", "cosine", "onecycle"]
|
3 |
+
smoothing_probability = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
|
4 |
+
# mamba_ker_size = [3, 4]
|
5 |
+
# mamba_layer_number = [3, 4, 5]
|
6 |
+
# hidden_dim_gated = [128]
|
7 |
+
# num_transformer_heads = [2, 4, 8]
|
8 |
+
# tr_layer_number = [1, 2]
|
9 |
+
# out_features = [128, 256]
|
10 |
+
# num_graph_heads = [2, 4]
|
11 |
+
# dropout = [0.0, 0.1, 0.2]
|
12 |
+
# positional_encoding = [true, false]
|
13 |
+
|
14 |
+
[defaults]
|
15 |
+
hidden_dim = 128
|
16 |
+
# # hidden_dim_gated = 128
|
17 |
+
# num_transformer_heads = 2
|
18 |
+
# tr_layer_number = 1
|
19 |
+
# out_features = 128
|
20 |
+
# # num_graph_heads = 2
|
21 |
+
# dropout = 0
|
22 |
+
# positional_encoding = false
|
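The [grid] section lists the values to sweep and [defaults] pins everything else. The actual sweep is driven elsewhere in the repository, so the snippet below is only an illustrative expansion of this file into concrete run configurations.

```python
# Illustrative expansion of search_params.toml into run configurations (sketch).
from itertools import product
import tomllib  # Python 3.11+; older interpreters can use the `toml` package

with open("search_params.toml", "rb") as f:
    params = tomllib.load(f)

grid = params.get("grid", {})
defaults = params.get("defaults", {})

keys = list(grid.keys())
runs = []
for values in product(*(grid[k] for k in keys)):
    cfg = dict(defaults)                 # fixed values shared by every run
    cfg.update(dict(zip(keys, values)))  # one combination of swept values
    runs.append(cfg)

print(len(runs), "configurations")  # 4 scheduler types * 11 probabilities = 44
print(runs[0])
```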
synthetic_utils/__pycache__/dia_tts_wrapper.cpython-310.pyc
ADDED
Binary file (2.78 kB). View file
|
|
synthetic_utils/dia_tts_wrapper.py
ADDED
@@ -0,0 +1,77 @@
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import logging
|
4 |
+
import torch
|
5 |
+
import soundfile as sf
|
6 |
+
import numpy as np
|
7 |
+
from dia.model import Dia
|
8 |
+
|
9 |
+
|
10 |
+
class DiaTTSWrapper:
|
11 |
+
def __init__(self, model_name="nari-labs/Dia-1.6B", device="cuda", dtype="float16"):
|
12 |
+
self.device = device
|
13 |
+
self.sr = 44100
|
14 |
+
logging.info(f"[DiaTTS] Загрузка модели {model_name} на {device} (dtype={dtype})")
|
15 |
+
self.model = Dia.from_pretrained(
|
16 |
+
model_name,
|
17 |
+
device=device,
|
18 |
+
compute_dtype=dtype
|
19 |
+
)
|
20 |
+
|
21 |
+
def generate_audio_from_text(self, text: str, paralinguistic: str = "", max_duration: float = None) -> torch.Tensor:
|
22 |
+
try:
|
23 |
+
if paralinguistic:
|
24 |
+
clean = paralinguistic.strip("()").lower()
|
25 |
+
text = f"{text} ({clean})"
|
26 |
+
audio_np = self.model.generate(
|
27 |
+
text,
|
28 |
+
use_torch_compile=False,
|
29 |
+
verbose=False
|
30 |
+
)
|
31 |
+
wf = torch.from_numpy(audio_np).float().unsqueeze(0)
|
32 |
+
if max_duration:
|
33 |
+
max_samples = int(self.sr * max_duration)
|
34 |
+
wf = wf[:, :max_samples]
|
35 |
+
return wf
|
36 |
+
except Exception as e:
|
37 |
+
logging.error(f"[DiaTTS] Ошибка генерации аудио: {e}")
|
38 |
+
return torch.zeros(1, self.sr)
|
39 |
+
|
40 |
+
def generate_and_save_audio(
|
41 |
+
self,
|
42 |
+
text: str,
|
43 |
+
paralinguistic: str = "",
|
44 |
+
out_dir="tts_outputs",
|
45 |
+
filename_prefix="tts",
|
46 |
+
max_duration: float = None,
|
47 |
+
use_timestamp=True,
|
48 |
+
skip_if_exists=True,
|
49 |
+
max_trim_duration: float = None
|
50 |
+
) -> torch.Tensor:
|
51 |
+
os.makedirs(out_dir, exist_ok=True)
|
52 |
+
if use_timestamp:
|
53 |
+
timestr = time.strftime("%Y%m%d_%H%M%S")
|
54 |
+
filename = f"{filename_prefix}_{timestr}.wav"
|
55 |
+
else:
|
56 |
+
filename = f"{filename_prefix}.wav"
|
57 |
+
out_path = os.path.join(out_dir, filename)
|
58 |
+
|
59 |
+
if skip_if_exists and os.path.exists(out_path):
|
60 |
+
logging.info(f"[DiaTTS] ⏭️ Пропущено — уже существует: {out_path}")
|
61 |
+
return None
|
62 |
+
|
63 |
+
wf = self.generate_audio_from_text(text, paralinguistic, max_duration)
|
64 |
+
np_wf = wf.squeeze().cpu().numpy()
|
65 |
+
|
66 |
+
if max_trim_duration is not None:
|
67 |
+
max_len = int(self.sr * max_trim_duration)
|
68 |
+
if len(np_wf) > max_len:
|
69 |
+
logging.info(f"[DiaTTS] ✂️ Обрезка аудио до {max_trim_duration} сек.")
|
70 |
+
np_wf = np_wf[:max_len]
|
71 |
+
|
72 |
+
sf.write(out_path, np_wf, self.sr)
|
73 |
+
logging.info(f"[DiaTTS] 💾 Сохранено аудио: {out_path}")
|
74 |
+
return wf
|
75 |
+
|
76 |
+
def get_sample_rate(self):
|
77 |
+
return self.sr
|
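A minimal usage sketch for DiaTTSWrapper, assuming the dia package is installed and a CUDA device is available for the 1.6B checkpoint; output is written at the wrapper's fixed 44.1 kHz sample rate.

```python
# Usage sketch for DiaTTSWrapper (assumes the `dia` package and a CUDA device).
from synthetic_utils.dia_tts_wrapper import DiaTTSWrapper

tts = DiaTTSWrapper(model_name="nari-labs/Dia-1.6B", device="cuda", dtype="float16")

# Append a paralinguistic marker, trim the clip to 6 s and save it under tts_outputs/.
waveform = tts.generate_and_save_audio(
    text="I can't believe we actually pulled this off.",
    paralinguistic="(laughs)",
    out_dir="tts_outputs",
    filename_prefix="demo_happy",
    max_trim_duration=6.0,
)

if waveform is not None:  # None means the file already existed and was skipped
    print(waveform.shape, tts.get_sample_rate())  # e.g. torch.Size([1, N]) 44100
```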
synthetic_utils/parler_tts_wrapper.py
ADDED
@@ -0,0 +1,60 @@
1 |
+
# parler_tts_wrapper.py
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import soundfile as sf
|
5 |
+
import time
|
6 |
+
import os
|
7 |
+
import logging
|
8 |
+
from parler_tts import ParlerTTSForConditionalGeneration
|
9 |
+
from transformers import AutoTokenizer
|
10 |
+
|
11 |
+
class ParlerTTS:
|
12 |
+
def __init__(self, model_name="parler-tts/parler-tts-mini-v1", device="cuda"):
|
13 |
+
self.device = device
|
14 |
+
logging.info(f"[ParlerTTS] Загрузка модели {model_name} на {device} ...")
|
15 |
+
|
16 |
+
self.model = ParlerTTSForConditionalGeneration.from_pretrained(model_name).to(device)
|
17 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
18 |
+
self.sr = self.model.config.sampling_rate
|
19 |
+
|
20 |
+
def generate_audio_from_text(self, text: str, description: str) -> torch.Tensor:
|
21 |
+
"""
|
22 |
+
Генерирует аудио (без сохранения на диск).
|
23 |
+
Возвращает PyTorch-тензор формы (1, num_samples).
|
24 |
+
"""
|
25 |
+
input_ids = self.tokenizer(description, return_tensors="pt").input_ids.to(self.device)
|
26 |
+
prompt_input_ids = self.tokenizer(text, return_tensors="pt").input_ids.to(self.device)
|
27 |
+
|
28 |
+
with torch.no_grad():
|
29 |
+
generation = self.model.generate(
|
30 |
+
input_ids=input_ids,
|
31 |
+
prompt_input_ids=prompt_input_ids
|
32 |
+
)
|
33 |
+
|
34 |
+
audio_arr = generation.cpu().numpy().squeeze() # (samples,)
|
35 |
+
wf = torch.from_numpy(audio_arr).unsqueeze(0) # -> (1, samples)
|
36 |
+
return wf
|
37 |
+
|
38 |
+
def generate_and_save_audio(self, text: str, description: str, out_dir="tts_outputs", filename_prefix="tts") -> torch.Tensor:
|
39 |
+
"""
|
40 |
+
Генерирует аудио И сохраняет результат в WAV-файл (для отладки/проверки).
|
41 |
+
Возвращает PyTorch-тензор (1, num_samples).
|
42 |
+
"""
|
43 |
+
os.makedirs(out_dir, exist_ok=True)
|
44 |
+
|
45 |
+
wf = self.generate_audio_from_text(text, description)
|
46 |
+
np_wf = wf.squeeze().cpu().numpy()
|
47 |
+
|
48 |
+
# Формируем имя файла
|
49 |
+
timestr = time.strftime("%Y%m%d_%H%M%S")
|
50 |
+
filename = f"{filename_prefix}_{timestr}.wav"
|
51 |
+
out_path = os.path.join(out_dir, filename)
|
52 |
+
|
53 |
+
# Сохраняем
|
54 |
+
sf.write(out_path, np_wf, self.sr)
|
55 |
+
logging.info(f"[ParlerTTS] Сохранено аудио: {out_path}")
|
56 |
+
|
57 |
+
return wf
|
58 |
+
|
59 |
+
def get_sample_rate(self):
|
60 |
+
return self.sr
|
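Unlike DiaTTSWrapper, ParlerTTS is steered by a free-form voice description rather than paralinguistic markers. A short usage sketch, assuming the parler_tts package is installed:

```python
# Usage sketch for the ParlerTTS wrapper (assumes the `parler_tts` package).
from synthetic_utils.parler_tts_wrapper import ParlerTTS

tts = ParlerTTS(model_name="parler-tts/parler-tts-mini-v1", device="cuda")

wf = tts.generate_and_save_audio(
    text="I'm just testing how this emotional voice sounds.",
    description="A female speaker with a slow, slightly sad voice; close-up, clean recording.",
    out_dir="tts_outputs",
    filename_prefix="parler_sad",
)
print(wf.shape, tts.get_sample_rate())
```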
synthetic_utils/text_generation.py
ADDED
@@ -0,0 +1,91 @@
1 |
+
import logging
|
2 |
+
import torch
|
3 |
+
import random
|
4 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
|
5 |
+
|
6 |
+
class TextGenerator:
|
7 |
+
def __init__(
|
8 |
+
self,
|
9 |
+
model_name="gpt2",
|
10 |
+
device="cuda",
|
11 |
+
max_new_tokens=50,
|
12 |
+
temperature=1.0,
|
13 |
+
top_p=0.95,
|
14 |
+
seed=None
|
15 |
+
):
|
16 |
+
self.model_name = model_name
|
17 |
+
self.device = device
|
18 |
+
self.max_new_tokens = max_new_tokens
|
19 |
+
self.temperature = temperature
|
20 |
+
self.top_p = top_p
|
21 |
+
self.seed = seed
|
22 |
+
|
23 |
+
logging.info(f"[TextGenerator] Загрузка модели {model_name} на {device} ...")
|
24 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
25 |
+
self.model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
|
26 |
+
|
27 |
+
if seed is not None:
|
28 |
+
set_seed(seed)
|
29 |
+
logging.info(f"[TextGenerator] Сид генерации установлен через transformers.set_seed({seed})")
|
30 |
+
else:
|
31 |
+
logging.info("[TextGenerator] Сид генерации не установлен (seed=None)")
|
32 |
+
|
33 |
+
# --- Примеры для few-shot обучения ---
|
34 |
+
self.fewshot_examples = [
|
35 |
+
("happy", "We finally made it!", "We finally made it! I’ve never felt so alive and proud of what we accomplished."),
|
36 |
+
("sad", "He didn't come back.", "He didn't come back. I waited all night, hoping to see him again."),
|
37 |
+
("anger", "Why would you do that?", "Why would you do that? You had no right to interfere!"),
|
38 |
+
("fear", "Did you hear that?", "Did you hear that? Something’s moving outside the window..."),
|
39 |
+
("surprise", "Oh wow, really?", "Oh wow, really? I didn’t see that coming at all!"),
|
40 |
+
("disgust", "That smell is awful.", "That smell is awful. I feel like I’m going to be sick."),
|
41 |
+
("neutral", "Let's meet at noon.", "Let's meet at noon. We’ll have plenty of time to talk then.")
|
42 |
+
]
|
43 |
+
|
44 |
+
def build_prompt(self, emotion: str, partial_text: str) -> str:
|
45 |
+
few_shot = random.sample(self.fewshot_examples, 2)
|
46 |
+
examples_str = ""
|
47 |
+
for emo, text, cont in few_shot:
|
48 |
+
examples_str += (
|
49 |
+
f"Example:\n"
|
50 |
+
f"Emotion: {emo}\n"
|
51 |
+
f"Text: {text}\n"
|
52 |
+
f"Continuation: {cont}\n\n"
|
53 |
+
)
|
54 |
+
|
55 |
+
prompt = (
|
56 |
+
"You are a helpful assistant that generates emotionally-aligned sentence continuations.\n"
|
57 |
+
"You must include the original sentence in the output, and then continue it in a fluent and emotionally appropriate way.\n\n"
|
58 |
+
f"{examples_str}"
|
59 |
+
f"Now try:\n"
|
60 |
+
f"Emotion: {emotion}\n"
|
61 |
+
f"Text: {partial_text}\n"
|
62 |
+
f"Continuation:"
|
63 |
+
)
|
64 |
+
return prompt
|
65 |
+
|
66 |
+
def generate_text(self, emotion: str, partial_text: str = "") -> str:
|
67 |
+
prompt = self.build_prompt(emotion, partial_text)
|
68 |
+
logging.debug(f"[TextGenerator] prompt:\n{prompt}")
|
69 |
+
|
70 |
+
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
|
71 |
+
|
72 |
+
output_ids = self.model.generate(
|
73 |
+
**inputs,
|
74 |
+
max_new_tokens=self.max_new_tokens,
|
75 |
+
do_sample=True,
|
76 |
+
top_p=self.top_p,
|
77 |
+
temperature=self.temperature,
|
78 |
+
pad_token_id=self.tokenizer.eos_token_id
|
79 |
+
)
|
80 |
+
|
81 |
+
full_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
82 |
+
logging.debug(f"[TextGenerator] decoded:\n{full_text}")
|
83 |
+
|
84 |
+
# Вытаскиваем то, что идёт после последнего "Continuation:"
|
85 |
+
if "Continuation:" in full_text:
|
86 |
+
result = full_text.split("Continuation:")[-1].strip()
|
87 |
+
else:
|
88 |
+
result = full_text.strip()
|
89 |
+
|
90 |
+
result = result.split("\n")[0].strip()
|
91 |
+
return result
|
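TextGenerator wraps a causal LM with a two-example few-shot prompt and returns only the text after the final "Continuation:" marker. A usage sketch (gpt2 is just the lightweight default; any causal LM checkpoint name works):

```python
# Usage sketch for TextGenerator (gpt2 is only the default checkpoint).
from synthetic_utils.text_generation import TextGenerator

generator = TextGenerator(
    model_name="gpt2",
    device="cpu",        # switch to "cuda" when a GPU is available
    max_new_tokens=40,
    temperature=0.9,
    top_p=0.95,
    seed=42,             # reproducible sampling via transformers.set_seed
)

continuation = generator.generate_text(
    emotion="sad",
    partial_text="He didn't call me back.",
)
print(continuation)
```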
test.py
ADDED
@@ -0,0 +1,28 @@
1 |
+
# tts_test_run.py
|
2 |
+
|
3 |
+
import os
|
4 |
+
from utils.config_loader import ConfigLoader
|
5 |
+
from synthetic_utils.dia_tts_wrapper import DiaTTSWrapper
|
6 |
+
from generate_synthetic_dataset import PARALINGUISTIC_MARKERS
|
7 |
+
|
8 |
+
# Загружаем конфиг
|
9 |
+
config = ConfigLoader("config.toml")
|
10 |
+
|
11 |
+
# Настройка TTS
|
12 |
+
tts = DiaTTSWrapper(device=config.whisper_device)
|
13 |
+
|
14 |
+
# Пример текста и эмоции
|
15 |
+
text = "I'm just testing how this emotional voice sounds."
|
16 |
+
emotion = "neutral" # можно: neutral, happy, sad, anger, fear, surprise, disgust
|
17 |
+
marker = PARALINGUISTIC_MARKERS.get(emotion, "")
|
18 |
+
|
19 |
+
# Генерация и сохранение
|
20 |
+
tts.generate_and_save_audio(
|
21 |
+
text=text,
|
22 |
+
paralinguistic=marker,
|
23 |
+
out_dir="tts_test_outputs",
|
24 |
+
filename_prefix=f"test_{emotion}",
|
25 |
+
max_duration=5.0
|
26 |
+
)
|
27 |
+
|
28 |
+
print(f"✅ Аудио для эмоции '{emotion}' сохранено.")
|
training/train_utils.py
ADDED
@@ -0,0 +1,585 @@
1 |
+
# coding: utf-8
|
2 |
+
# train_utils.py
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import logging
|
6 |
+
import random
|
7 |
+
import numpy as np
|
8 |
+
import csv
|
9 |
+
import pandas as pd
|
10 |
+
from tqdm import tqdm
|
11 |
+
from typing import Type
|
12 |
+
import os
|
13 |
+
import datetime
|
14 |
+
|
15 |
+
from torch.utils.data import DataLoader, ConcatDataset, WeightedRandomSampler
|
16 |
+
from torch.nn.utils.rnn import pad_sequence
|
17 |
+
|
18 |
+
from utils.losses import WeightedCrossEntropyLoss
|
19 |
+
from utils.measures import uar, war, mf1, wf1
|
20 |
+
from models.models import (
|
21 |
+
BiFormer, BiGraphFormer, BiGatedGraphFormer,
|
22 |
+
PredictionsFusion, BiFormerWithProb, BiGatedFormer,
|
23 |
+
BiMamba, BiMambaWithProb, BiGraphFormerWithProb, BiGatedGraphFormerWithProb
|
24 |
+
)
|
25 |
+
from utils.schedulers import SmartScheduler
|
26 |
+
from data_loading.dataset_multimodal import DatasetMultiModalWithPretrainedExtractors
|
27 |
+
from sklearn.utils.class_weight import compute_class_weight
|
28 |
+
from lion_pytorch import Lion
|
29 |
+
|
30 |
+
|
31 |
+
def get_smoothed_labels(audio_paths, original_labels, smooth_labels_df, smooth_mask, emotion_columns, device):
|
32 |
+
"""
|
33 |
+
audio_paths: список путей к аудиофайлам
|
34 |
+
smooth_mask: тензор boolean с индексами для сглаживания
|
35 |
+
Возвращает тензор сглаженных меток только для отмеченных примеров
|
36 |
+
"""
|
37 |
+
|
38 |
+
# Получаем индексы для сглаживания
|
39 |
+
smooth_indices = torch.where(smooth_mask)[0]
|
40 |
+
|
41 |
+
# Создаем тензор для результатов (такого же размера как оригинальные метки)
|
42 |
+
smoothed_labels = torch.zeros_like(original_labels)
|
43 |
+
|
44 |
+
# print(smooth_labels_df, audio_paths)
|
45 |
+
|
46 |
+
for idx in smooth_indices:
|
47 |
+
audio_path = audio_paths[idx]
|
48 |
+
# Получаем сглаженную метку из вашего DataFrame или другого источника
|
49 |
+
smoothed_label = smooth_labels_df.loc[
|
50 |
+
smooth_labels_df['video_name'] == audio_path[:-4],
|
51 |
+
emotion_columns
|
52 |
+
].values[0]
|
53 |
+
|
54 |
+
smoothed_labels[idx] = torch.tensor(smoothed_label, device=device)
|
55 |
+
|
56 |
+
return smoothed_labels
|
57 |
+
|
58 |
+
def custom_collate_fn(batch):
|
59 |
+
"""Собирает список образцов в единый батч, отбрасывая None (невалидные)."""
|
60 |
+
batch = [x for x in batch if x is not None]
|
61 |
+
# print(batch[0].keys())
|
62 |
+
if not batch:
|
63 |
+
return None
|
64 |
+
|
65 |
+
audios = [b["audio"] for b in batch]
|
66 |
+
# audio_tensor = torch.stack(audios)
|
67 |
+
audio_tensor = pad_sequence(audios, batch_first=True)
|
68 |
+
|
69 |
+
labels = [b["label"] for b in batch]
|
70 |
+
label_tensor = torch.stack(labels)
|
71 |
+
|
72 |
+
texts = [b["text"] for b in batch]
|
73 |
+
text_tensor = torch.stack(texts)
|
74 |
+
|
75 |
+
audio_pred = [b["audio_pred"] for b in batch]
|
76 |
+
audio_pred = torch.stack(audio_pred)
|
77 |
+
|
78 |
+
text_pred = [b["text_pred"] for b in batch]
|
79 |
+
text_pred = torch.stack(text_pred)
|
80 |
+
|
81 |
+
return {
|
82 |
+
"audio_paths": [b["audio_path"] for b in batch], # new
|
83 |
+
"audio": audio_tensor,
|
84 |
+
"label": label_tensor,
|
85 |
+
"text": text_tensor,
|
86 |
+
"audio_pred": audio_pred,
|
87 |
+
"text_pred": text_pred,
|
88 |
+
}
|
89 |
+
|
90 |
+
def get_class_weights_from_loader(train_loader, num_classes):
|
91 |
+
"""
|
92 |
+
Вычисляет веса классов из train_loader, устойчиво к отсутствующим классам.
|
93 |
+
Если какой-либо класс отсутствует в выборке, ему будет присвоен вес 0.0.
|
94 |
+
|
95 |
+
:param train_loader: DataLoader с one-hot метками
|
96 |
+
:param num_classes: Общее количество классов
|
97 |
+
:return: np.ndarray весов длины num_classes
|
98 |
+
"""
|
99 |
+
all_labels = []
|
100 |
+
for batch in train_loader:
|
101 |
+
if batch is None:
|
102 |
+
continue
|
103 |
+
all_labels.extend(batch["label"].argmax(dim=1).tolist())
|
104 |
+
|
105 |
+
if not all_labels:
|
106 |
+
raise ValueError("Нет ни одной метки в train_loader для вычисления весов классов.")
|
107 |
+
|
108 |
+
present_classes = np.unique(all_labels)
|
109 |
+
|
110 |
+
if len(present_classes) < num_classes:
|
111 |
+
missing = set(range(num_classes)) - set(present_classes)
|
112 |
+
logging.info(f"[!] Отсутствуют метки для классов: {sorted(missing)}")
|
113 |
+
|
114 |
+
# Вычисляем веса только по тем классам, что есть
|
115 |
+
weights_partial = compute_class_weight(
|
116 |
+
class_weight="balanced",
|
117 |
+
classes=present_classes,
|
118 |
+
y=all_labels
|
119 |
+
)
|
120 |
+
|
121 |
+
# Собираем полный вектор весов
|
122 |
+
full_weights = np.zeros(num_classes, dtype=np.float32)
|
123 |
+
for cls, w in zip(present_classes, weights_partial):
|
124 |
+
full_weights[cls] = w
|
125 |
+
|
126 |
+
return full_weights
|
127 |
+
|
128 |
+
def make_dataset_and_loader(config, split: str, audio_feature_extractor: Type = None, text_feature_extractor: Type = None, whisper_model: Type = None, only_dataset: str = None):
|
129 |
+
"""
|
130 |
+
Универсальная функция: объединяет датасеты или возвращает один при only_dataset.
|
131 |
+
При объединении train-датасетов — использует WeightedRandomSampler для балансировки.
|
132 |
+
"""
|
133 |
+
datasets = []
|
134 |
+
|
135 |
+
if not hasattr(config, "datasets") or not config.datasets:
|
136 |
+
raise ValueError("⛔ В конфиге не указана секция [datasets].")
|
137 |
+
|
138 |
+
for dataset_name, dataset_cfg in config.datasets.items():
|
139 |
+
if only_dataset and dataset_name != only_dataset:
|
140 |
+
continue
|
141 |
+
|
142 |
+
csv_path = dataset_cfg["csv_path"].format(base_dir=dataset_cfg["base_dir"], split=split)
|
143 |
+
wav_dir = dataset_cfg["wav_dir"].format(base_dir=dataset_cfg["base_dir"], split=split)
|
144 |
+
|
145 |
+
logging.info(f"[{dataset_name.upper()}] Split={split}: CSV={csv_path}, WAV_DIR={wav_dir}")
|
146 |
+
|
147 |
+
dataset = DatasetMultiModalWithPretrainedExtractors(
|
148 |
+
csv_path = csv_path,
|
149 |
+
wav_dir = wav_dir,
|
150 |
+
emotion_columns = config.emotion_columns,
|
151 |
+
split = split,
|
152 |
+
config = config,
|
153 |
+
audio_feature_extractor = audio_feature_extractor,
|
154 |
+
text_feature_extractor = text_feature_extractor,
|
155 |
+
whisper_model = whisper_model,
|
156 |
+
dataset_name = dataset_name
|
157 |
+
)
|
158 |
+
|
159 |
+
datasets.append(dataset)
|
160 |
+
|
161 |
+
if not datasets:
|
162 |
+
raise ValueError(f"⚠️ Для split='{split}' не найдено ни одного подходящего датасета.")
|
163 |
+
|
164 |
+
if len(datasets) == 1:
|
165 |
+
|
166 |
+
full_dataset = datasets[0]
|
167 |
+
loader = DataLoader(
|
168 |
+
full_dataset,
|
169 |
+
batch_size=config.batch_size,
|
170 |
+
shuffle=(split == "train"),
|
171 |
+
num_workers=config.num_workers,
|
172 |
+
collate_fn=custom_collate_fn
|
173 |
+
)
|
174 |
+
else:
|
175 |
+
# Несколько датасетов — собираем веса
|
176 |
+
lengths = [len(d) for d in datasets]
|
177 |
+
total = sum(lengths)
|
178 |
+
|
179 |
+
logging.info(f"[!] Объединяем {len(datasets)} датасетов: {lengths} (total={total})")
|
180 |
+
|
181 |
+
weights = []
|
182 |
+
for d_len in lengths:
|
183 |
+
w = 1.0 / d_len
|
184 |
+
weights += [w] * d_len
|
185 |
+
logging.info(f" ➜ Сэмплы из датасета с {d_len} примерами получают вес {w:.6f}")
|
186 |
+
|
187 |
+
full_dataset = ConcatDataset(datasets)
|
188 |
+
|
189 |
+
if split == "train":
|
190 |
+
sampler = WeightedRandomSampler(weights, num_samples=total, replacement=True)
|
191 |
+
loader = DataLoader(
|
192 |
+
full_dataset,
|
193 |
+
batch_size=config.batch_size,
|
194 |
+
sampler=sampler,
|
195 |
+
num_workers=config.num_workers,
|
196 |
+
collate_fn=custom_collate_fn
|
197 |
+
)
|
198 |
+
else:
|
199 |
+
loader = DataLoader(
|
200 |
+
full_dataset,
|
201 |
+
batch_size=config.batch_size,
|
202 |
+
shuffle=False,
|
203 |
+
num_workers=config.num_workers,
|
204 |
+
collate_fn=custom_collate_fn
|
205 |
+
)
|
206 |
+
|
207 |
+
return full_dataset, loader
|
208 |
+
|
209 |
+
def run_eval(model, loader, criterion, model_name, device="cuda"):
|
210 |
+
"""
|
211 |
+
Оценка модели на loader'е. Возвращает (loss, uar, war, mf1, wf1).
|
212 |
+
"""
|
213 |
+
model.eval()
|
214 |
+
total_loss = 0.0
|
215 |
+
total_preds = []
|
216 |
+
total_targets = []
|
217 |
+
total = 0
|
218 |
+
|
219 |
+
with torch.no_grad():
|
220 |
+
for batch in tqdm(loader):
|
221 |
+
if batch is None:
|
222 |
+
continue
|
223 |
+
|
224 |
+
audio = batch["audio"].to(device)
|
225 |
+
labels = batch["label"].to(device)
|
226 |
+
texts = batch["text"]
|
227 |
+
audio_pred = batch["audio_pred"].to(device)
|
228 |
+
text_pred = batch["text_pred"].to(device)
|
229 |
+
|
230 |
+
if "fusion" in model_name:
|
231 |
+
logits = model((audio_pred, text_pred))
|
232 |
+
elif "withprob" in model_name:
|
233 |
+
logits = model(audio, texts, audio_pred, text_pred)
|
234 |
+
else:
|
235 |
+
logits = model(audio, texts)
|
236 |
+
target = labels.argmax(dim=1)
|
237 |
+
|
238 |
+
loss = criterion(logits, target)
|
239 |
+
bs = audio.shape[0]
|
240 |
+
total_loss += loss.item() * bs
|
241 |
+
total += bs
|
242 |
+
|
243 |
+
preds = logits.argmax(dim=1)
|
244 |
+
total_preds.extend(preds.cpu().numpy().tolist())
|
245 |
+
total_targets.extend(target.cpu().numpy().tolist())
|
246 |
+
|
247 |
+
avg_loss = total_loss / total
|
248 |
+
|
249 |
+
uar_m = uar(total_targets, total_preds)
|
250 |
+
war_m = war(total_targets, total_preds)
|
251 |
+
mf1_m = mf1(total_targets, total_preds)
|
252 |
+
wf1_m = wf1(total_targets, total_preds)
|
253 |
+
|
254 |
+
return avg_loss, uar_m, war_m, mf1_m, wf1_m
|
255 |
+
|
256 |
+
def train_once(config, train_loader, dev_loaders, test_loaders, metrics_csv_path=None):
|
257 |
+
"""
|
258 |
+
Логика обучения (train/dev/test).
|
259 |
+
Возвращает лучшую метрику на dev и словарь метрик.
|
260 |
+
"""
|
261 |
+
|
262 |
+
logging.info("== Запуск тренировки (train/dev/test) ==")
|
263 |
+
|
264 |
+
checkpoint_dir = None
|
265 |
+
if config.save_best_model:
|
266 |
+
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
267 |
+
checkpoint_dir = os.path.join("checkpoints", f"{config.model_name}_{timestamp}")
|
268 |
+
os.makedirs(checkpoint_dir, exist_ok=True)
|
269 |
+
|
270 |
+
csv_writer = None
|
271 |
+
csv_file = None
|
272 |
+
|
273 |
+
if config.path_to_df_ls:
|
274 |
+
df_ls = pd.read_csv(config.path_to_df_ls)
|
275 |
+
|
276 |
+
if metrics_csv_path:
|
277 |
+
csv_file = open(metrics_csv_path, mode="w", newline="", encoding="utf-8")
|
278 |
+
csv_writer = csv.writer(csv_file)
|
279 |
+
csv_writer.writerow(["split", "epoch", "dataset", "loss", "uar", "war", "mf1", "wf1", "mean"])
|
280 |
+
|
281 |
+
|
282 |
+
# Seed
|
283 |
+
if config.random_seed > 0:
|
284 |
+
random.seed(config.random_seed)
|
285 |
+
torch.manual_seed(config.random_seed)
|
286 |
+
torch.cuda.manual_seed_all(config.random_seed)
|
287 |
+
torch.backends.cudnn.deterministic = True
|
288 |
+
torch.backends.cudnn.benchmark = False
|
289 |
+
os.environ['PYTHONHASHSEED'] = str(config.random_seed)
|
290 |
+
logging.info(f"== Фиксируем random seed: {config.random_seed}")
|
291 |
+
else:
|
292 |
+
logging.info("== Random seed не фиксирован (0).")
|
293 |
+
|
294 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
295 |
+
|
296 |
+
# Экстракторы
|
297 |
+
# audio_extractor = AudioEmbeddingExtractor(config)
|
298 |
+
# text_extractor = TextEmbeddingExtractor(config)
|
299 |
+
|
300 |
+
# Параметры
|
301 |
+
hidden_dim = config.hidden_dim
|
302 |
+
num_classes = len(config.emotion_columns)
|
303 |
+
num_transformer_heads = config.num_transformer_heads
|
304 |
+
num_graph_heads = config.num_graph_heads
|
305 |
+
hidden_dim_gated = config.hidden_dim_gated
|
306 |
+
mamba_d_state = config.mamba_d_state
|
307 |
+
mamba_ker_size = config.mamba_ker_size
|
308 |
+
mamba_layer_number = config.mamba_layer_number
|
309 |
+
mode = config.mode
|
310 |
+
weight_decay = config.weight_decay
|
311 |
+
momentum = config.momentum
|
312 |
+
positional_encoding = config.positional_encoding
|
313 |
+
dropout = config.dropout
|
314 |
+
out_features = config.out_features
|
315 |
+
lr = config.lr
|
316 |
+
num_epochs = config.num_epochs
|
317 |
+
tr_layer_number = config.tr_layer_number
|
318 |
+
max_patience = config.max_patience
|
319 |
+
scheduler_type = config.scheduler_type
|
320 |
+
|
321 |
+
dict_models = {
|
322 |
+
'BiFormer': BiFormer, # вход audio, texts
|
323 |
+
'BiGraphFormer': BiGraphFormer, # вход audio, texts
|
324 |
+
'BiGatedGraphFormer': BiGatedGraphFormer, # вход audio, texts
|
325 |
+
"BiGatedFormer": BiGatedFormer, # вход audio, texts
|
326 |
+
"BiMamba": BiMamba, # вход audio, texts
|
327 |
+
"PredictionsFusion": PredictionsFusion, # вход audio_pred, text_pred
|
328 |
+
"BiFormerWithProb": BiFormerWithProb, # вход audio, texts, audio_pred, text_pred
|
329 |
+
"BiMambaWithProb": BiMambaWithProb, # вход audio, texts, audio_pred, text_pred
|
330 |
+
"BiGraphFormerWithProb": BiGraphFormerWithProb, # вход audio, texts, audio_pred, text_pred
|
331 |
+
"BiGatedGraphFormerWithProb": BiGatedGraphFormerWithProb,
|
332 |
+
}
|
333 |
+
|
334 |
+
model_cls = dict_models[config.model_name]
|
335 |
+
model_name = config.model_name.lower()
|
336 |
+
|
337 |
+
if model_name == 'predictionsfusion':
|
338 |
+
model = model_cls().to(device)
|
339 |
+
|
340 |
+
elif 'mamba' in model_name:
|
341 |
+
# Особые параметры для Mamba-семейства
|
342 |
+
model = model_cls(
|
343 |
+
audio_dim = config.audio_embedding_dim,
|
344 |
+
text_dim = config.text_embedding_dim,
|
345 |
+
hidden_dim = hidden_dim,
|
346 |
+
mamba_d_state = mamba_d_state,
|
347 |
+
mamba_ker_size = mamba_ker_size,
|
348 |
+
mamba_layer_number = mamba_layer_number,
|
349 |
+
seg_len = config.max_tokens,
|
350 |
+
mode = mode,
|
351 |
+
dropout = dropout,
|
352 |
+
positional_encoding = positional_encoding,
|
353 |
+
out_features = out_features,
|
354 |
+
device = device,
|
355 |
+
num_classes = num_classes
|
356 |
+
).to(device)
|
357 |
+
|
358 |
+
else:
|
359 |
+
# Обычные модели
|
360 |
+
model = model_cls(
|
361 |
+
audio_dim = config.audio_embedding_dim,
|
362 |
+
text_dim = config.text_embedding_dim,
|
363 |
+
hidden_dim = hidden_dim,
|
364 |
+
hidden_dim_gated = hidden_dim_gated,
|
365 |
+
num_transformer_heads = num_transformer_heads,
|
366 |
+
num_graph_heads = num_graph_heads,
|
367 |
+
seg_len = config.max_tokens,
|
368 |
+
mode = mode,
|
369 |
+
dropout = dropout,
|
370 |
+
positional_encoding = positional_encoding,
|
371 |
+
out_features = out_features,
|
372 |
+
tr_layer_number = tr_layer_number,
|
373 |
+
device = device,
|
374 |
+
num_classes = num_classes
|
375 |
+
).to(device)
|
376 |
+
|
377 |
+
# Оптимизатор и лосс
|
378 |
+
if config.optimizer == "adam":
|
379 |
+
optimizer = torch.optim.Adam(
|
380 |
+
model.parameters(), lr=lr, weight_decay=weight_decay
|
381 |
+
)
|
382 |
+
elif config.optimizer == "adamw":
|
383 |
+
optimizer = torch.optim.AdamW(
|
384 |
+
model.parameters(), lr=lr, weight_decay=weight_decay
|
385 |
+
)
|
386 |
+
elif config.optimizer == "lion":
|
387 |
+
optimizer = Lion(
|
388 |
+
model.parameters(), lr=lr, weight_decay=weight_decay
|
389 |
+
)
|
390 |
+
elif config.optimizer == "sgd":
|
391 |
+
optimizer = torch.optim.SGD(
|
392 |
+
model.parameters(), lr=lr, momentum=momentum
|
393 |
+
)
|
394 |
+
elif config.optimizer == "rmsprop":
|
395 |
+
optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)
|
396 |
+
else:
|
397 |
+
raise ValueError(f"⛔ Неизвестный оптимизатор: {config.optimizer}")
|
398 |
+
|
399 |
+
logging.info(f"Используем оптимизатор: {config.optimizer}, learning rate: {lr}")
|
400 |
+
|
401 |
+
class_weights = get_class_weights_from_loader(train_loader, num_classes)
|
402 |
+
criterion = WeightedCrossEntropyLoss(class_weights)
|
403 |
+
|
404 |
+
logging.info("Class weights: " + ", ".join(f"{name}={weight:.4f}" for name, weight in zip(config.emotion_columns, class_weights)))
|
405 |
+
|
406 |
+
# LR Scheduler
|
407 |
+
steps_per_epoch = sum(1 for batch in train_loader if batch is not None)
|
408 |
+
scheduler = SmartScheduler(
|
409 |
+
scheduler_type=scheduler_type,
|
410 |
+
optimizer=optimizer,
|
411 |
+
config=config,
|
412 |
+
steps_per_epoch=steps_per_epoch
|
413 |
+
)
|
414 |
+
|
415 |
+
# Early stopping по dev
|
416 |
+
best_dev_mean = float("-inf")
|
417 |
+
best_dev_metrics = {}
|
418 |
+
patience_counter = 0
|
419 |
+
|
420 |
+
for epoch in range(num_epochs):
|
421 |
+
logging.info(f"\n=== Эпоха {epoch} ===")
|
422 |
+
model.train()
|
423 |
+
|
424 |
+
total_loss = 0.0
|
425 |
+
total_samples = 0
|
426 |
+
total_preds = []
|
427 |
+
total_targets = []
|
428 |
+
|
429 |
+
for batch in tqdm(train_loader):
|
430 |
+
if batch is None:
|
431 |
+
continue
|
432 |
+
|
433 |
+
audio_paths = batch["audio_paths"] # new
|
434 |
+
audio = batch["audio"].to(device)
|
435 |
+
|
436 |
+
# Обработка меток с частичным сглаживанием
|
437 |
+
if config.smoothing_probability == 0:
|
438 |
+
labels = batch["label"].to(device)
|
439 |
+
else:
|
440 |
+
# Получаем оригинальные горячие метки
|
441 |
+
original_labels = batch["label"].to(device)
|
442 |
+
|
443 |
+
# Создаем маску для сглаживания (выбираем случайные примеры)
|
444 |
+
batch_size = original_labels.size(0)
|
445 |
+
smooth_mask = torch.rand(batch_size, device=device) < config.smoothing_probability
|
446 |
+
|
447 |
+
# Получаем сглаженные метки для выбранных примеров
|
448 |
+
smoothed_labels = get_smoothed_labels(audio_paths, original_labels, df_ls, smooth_mask, config.emotion_columns, device)
|
449 |
+
|
450 |
+
# Комбинируем метки
|
451 |
+
labels = torch.where(
|
452 |
+
smooth_mask.unsqueeze(1), # Добавляем размерность для broadcast
|
453 |
+
smoothed_labels.to(device),
|
454 |
+
original_labels
|
455 |
+
)
|
456 |
+
# print(labels)
|
457 |
+
texts = batch["text"]
|
458 |
+
audio_pred = batch["audio_pred"].to(device)
|
459 |
+
text_pred = batch["text_pred"].to(device)
|
460 |
+
|
461 |
+
if "fusion" in model_name:
|
462 |
+
logits = model((audio_pred, text_pred))
|
463 |
+
elif "withprob" in model_name:
|
464 |
+
logits = model(audio, texts, audio_pred, text_pred)
|
465 |
+
else:
|
466 |
+
logits = model(audio, texts)
|
467 |
+
|
468 |
+
target = labels.argmax(dim=1)
|
469 |
+
loss = criterion(logits, target)
|
470 |
+
|
471 |
+
optimizer.zero_grad()
|
472 |
+
loss.backward()
|
473 |
+
optimizer.step()
|
474 |
+
|
475 |
+
# Если scheduler - One cycle или с Hugging Face
|
476 |
+
scheduler.step(batch_level=True)
|
477 |
+
|
478 |
+
bs = audio.shape[0]
|
479 |
+
total_loss += loss.item() * bs
|
480 |
+
|
481 |
+
preds = logits.argmax(dim=1)
|
482 |
+
total_preds.extend(preds.cpu().numpy().tolist())
|
483 |
+
total_targets.extend(target.cpu().numpy().tolist())
|
484 |
+
total_samples += bs
|
485 |
+
|
486 |
+
train_loss = total_loss / total_samples
|
487 |
+
uar_m = uar(total_targets, total_preds)
|
488 |
+
war_m = war(total_targets, total_preds)
|
489 |
+
mf1_m = mf1(total_targets, total_preds)
|
490 |
+
wf1_m = wf1(total_targets, total_preds)
|
491 |
+
mean_train = np.mean([uar_m, war_m, mf1_m, wf1_m])
|
492 |
+
|
493 |
+
logging.info(
|
494 |
+
f"[TRAIN] Loss={train_loss:.4f}, UAR={uar_m:.4f}, WAR={war_m:.4f}, "
|
495 |
+
f"MF1={mf1_m:.4f}, WF1={wf1_m:.4f}, MEAN={mean_train:.4f}"
|
496 |
+
)
|
497 |
+
|
498 |
+
# --- DEV ---
|
499 |
+
dev_means = []
|
500 |
+
dev_metrics_by_dataset = []
|
501 |
+
|
502 |
+
for name, loader in dev_loaders:
|
503 |
+
d_loss, d_uar, d_war, d_mf1, d_wf1 = run_eval(
|
504 |
+
model, loader, criterion, model_name, device
|
505 |
+
)
|
506 |
+
d_mean = np.mean([d_uar, d_war, d_mf1, d_wf1])
|
507 |
+
dev_means.append(d_mean)
|
508 |
+
|
509 |
+
if csv_writer:
|
510 |
+
csv_writer.writerow(["dev", epoch, name, d_loss, d_uar, d_war, d_mf1, d_wf1, d_mean])
|
511 |
+
|
512 |
+
logging.info(
|
513 |
+
f"[DEV:{name}] Loss={d_loss:.4f}, UAR={d_uar:.4f}, WAR={d_war:.4f}, "
|
514 |
+
f"MF1={d_mf1:.4f}, WF1={d_wf1:.4f}, MEAN={d_mean:.4f}"
|
515 |
+
)
|
516 |
+
|
517 |
+
dev_metrics_by_dataset.append({
|
518 |
+
"name": name,
|
519 |
+
"loss": d_loss,
|
520 |
+
"uar": d_uar,
|
521 |
+
"war": d_war,
|
522 |
+
"mf1": d_mf1,
|
523 |
+
"wf1": d_wf1,
|
524 |
+
"mean": d_mean,
|
525 |
+
})
|
526 |
+
|
527 |
+
mean_dev = np.mean(dev_means)
|
528 |
+
scheduler.step(mean_dev)
|
529 |
+
|
530 |
+
# --- TEST ---
|
531 |
+
test_metrics_by_dataset = []
|
532 |
+
for name, loader in test_loaders:
|
533 |
+
t_loss, t_uar, t_war, t_mf1, t_wf1 = run_eval(
|
534 |
+
model, loader, criterion, model_name, device
|
535 |
+
)
|
536 |
+
t_mean = np.mean([t_uar, t_war, t_mf1, t_wf1])
|
537 |
+
logging.info(
|
538 |
+
f"[TEST:{name}] Loss={t_loss:.4f}, UAR={t_uar:.4f}, WAR={t_war:.4f}, "
|
539 |
+
f"MF1={t_mf1:.4f}, WF1={t_wf1:.4f}, MEAN={t_mean:.4f}"
|
540 |
+
)
|
541 |
+
|
542 |
+
test_metrics_by_dataset.append({
|
543 |
+
"name": name,
|
544 |
+
"loss": t_loss,
|
545 |
+
"uar": t_uar,
|
546 |
+
"war": t_war,
|
547 |
+
"mf1": t_mf1,
|
548 |
+
"wf1": t_wf1,
|
549 |
+
"mean": t_mean,
|
550 |
+
})
|
551 |
+
|
552 |
+
if csv_writer:
|
553 |
+
csv_writer.writerow(["test", epoch, name, t_loss, t_uar, t_war, t_mf1, t_wf1, t_mean])
|
554 |
+
|
555 |
+
|
556 |
+
if mean_dev > best_dev_mean:
|
557 |
+
best_dev_mean = mean_dev
|
558 |
+
patience_counter = 0
|
559 |
+
best_dev_metrics = {
|
560 |
+
"mean": mean_dev,
|
561 |
+
"by_dataset": dev_metrics_by_dataset
|
562 |
+
}
|
563 |
+
best_test_metrics = {
|
564 |
+
"mean": np.mean([ds["mean"] for ds in test_metrics_by_dataset]),
|
565 |
+
"by_dataset": test_metrics_by_dataset
|
566 |
+
}
|
567 |
+
|
568 |
+
if config.save_best_model:
|
569 |
+
dev_str = f"{mean_dev:.4f}".replace(".", "_")
|
570 |
+
model_path = os.path.join(checkpoint_dir, f"best_model_dev_{dev_str}_epoch_{epoch}.pt")
|
571 |
+
torch.save(model.state_dict(), model_path)
|
572 |
+
logging.info(f"💾 Модель сохранена по лучшему dev (эпоха {epoch}): {model_path}")
|
573 |
+
|
574 |
+
else:
|
575 |
+
patience_counter += 1
|
576 |
+
if patience_counter >= max_patience:
|
577 |
+
logging.info(f"Early stopping: {max_patience} эпох без улучшения.")
|
578 |
+
break
|
579 |
+
|
580 |
+
logging.info("Тренировка завершена. Все split'ы обработаны!")
|
581 |
+
|
582 |
+
if csv_file:
|
583 |
+
csv_file.close()
|
584 |
+
|
585 |
+
return best_dev_mean, best_dev_metrics, best_test_metrics
|
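For context, a sketch of how the helpers above are wired together. The real entry point is main.py; the pretrained feature extractors, Whisper model and some config fields are omitted or assumed here, so this is illustrative only.

```python
# Sketch of the training wiring around train_utils.py (illustrative; see main.py).
from utils.config_loader import ConfigLoader
from training.train_utils import make_dataset_and_loader, train_once

config = ConfigLoader("config.toml")

# One combined train loader; dev/test loaders are kept per-dataset so that
# train_once can log metrics for each corpus separately.
# (Pretrained audio/text extractors and the Whisper model are omitted here.)
_, train_loader = make_dataset_and_loader(config, split="train")
dev_loaders = [(name, make_dataset_and_loader(config, "dev", only_dataset=name)[1])
               for name in config.datasets]
test_loaders = [(name, make_dataset_and_loader(config, "test", only_dataset=name)[1])
                for name in config.datasets]

best_dev_mean, best_dev_metrics, best_test_metrics = train_once(
    config, train_loader, dev_loaders, test_loaders, metrics_csv_path="metrics.csv"
)
print(f"best dev mean: {best_dev_mean:.4f}")
```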
training/train_utils_old.py
ADDED
@@ -0,0 +1,379 @@
1 |
+
# coding: utf-8
|
2 |
+
# train_utils.py
|
3 |
+
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
import logging
|
7 |
+
import random
|
8 |
+
import datetime
|
9 |
+
import numpy as np
|
10 |
+
from tqdm import tqdm
|
11 |
+
import csv
|
12 |
+
|
13 |
+
from torch.utils.data import DataLoader, ConcatDataset
|
14 |
+
from utils.losses import WeightedCrossEntropyLoss
|
15 |
+
from utils.measures import uar, war, mf1, wf1
|
16 |
+
from models.models import BiFormer, BiGraphFormer, BiGatedGraphFormer
|
17 |
+
from data_loading.dataset_multimodal import DatasetMultiModal
|
18 |
+
from data_loading.feature_extractor import AudioEmbeddingExtractor, TextEmbeddingExtractor
|
19 |
+
from sklearn.utils.class_weight import compute_class_weight
|
20 |
+
|
21 |
+
def custom_collate_fn(batch):
|
22 |
+
"""Собирает список образцов в единый батч, отбрасывая None (невалидные)."""
|
23 |
+
batch = [x for x in batch if x is not None]
|
24 |
+
if not batch:
|
25 |
+
return None
|
26 |
+
|
27 |
+
audios = [b["audio"] for b in batch]
|
28 |
+
audio_tensor = torch.stack(audios)
|
29 |
+
|
30 |
+
labels = [b["label"] for b in batch]
|
31 |
+
label_tensor = torch.stack(labels)
|
32 |
+
|
33 |
+
texts = [b["text"] for b in batch]
|
34 |
+
|
35 |
+
return {
|
36 |
+
"audio": audio_tensor,
|
37 |
+
"label": label_tensor,
|
38 |
+
"text": texts
|
39 |
+
}
|
40 |
+
|
41 |
+
def get_class_weights_from_loader(train_loader, num_classes):
|
42 |
+
"""
|
43 |
+
Вычисляет веса классов из train_loader, устойчиво к отсутствующим классам.
|
44 |
+
Если какой-либо класс отсутствует в выборке, ему будет присвоен вес 0.0.
|
45 |
+
|
46 |
+
:param train_loader: DataLoader с one-hot метками
|
47 |
+
:param num_classes: Общее количество классов
|
48 |
+
:return: np.ndarray весов длины num_classes
|
49 |
+
"""
|
50 |
+
all_labels = []
|
51 |
+
for batch in train_loader:
|
52 |
+
if batch is None:
|
53 |
+
continue
|
54 |
+
all_labels.extend(batch["label"].argmax(dim=1).tolist())
|
55 |
+
|
56 |
+
if not all_labels:
|
57 |
+
raise ValueError("Нет ни одной метки в train_loader для вычисления весов классов.")
|
58 |
+
|
59 |
+
present_classes = np.unique(all_labels)
|
60 |
+
|
61 |
+
if len(present_classes) < num_classes:
|
62 |
+
missing = set(range(num_classes)) - set(present_classes)
|
63 |
+
logging.info(f"[!] Отсутствуют метки для классов: {sorted(missing)}")
|
64 |
+
|
65 |
+
# Вычисляем веса только по тем классам, что есть
|
66 |
+
weights_partial = compute_class_weight(
|
67 |
+
class_weight="balanced",
|
68 |
+
classes=present_classes,
|
69 |
+
y=all_labels
|
70 |
+
)
|
71 |
+
|
72 |
+
# Собираем полный вектор весов
|
73 |
+
full_weights = np.zeros(num_classes, dtype=np.float32)
|
74 |
+
for cls, w in zip(present_classes, weights_partial):
|
75 |
+
full_weights[cls] = w
|
76 |
+
|
77 |
+
return full_weights
|
78 |
+
|
79 |
+
def make_dataset_and_loader(config, split: str, only_dataset: str = None):
|
80 |
+
"""
|
81 |
+
Универсальная функция: объединяет датасеты, или возвращает один при only_dataset.
|
82 |
+
"""
|
83 |
+
datasets = []
|
84 |
+
|
85 |
+
if not hasattr(config, "datasets") or not config.datasets:
|
86 |
+
raise ValueError("⛔ В конфиге не указана секция [datasets].")
|
87 |
+
|
88 |
+
for dataset_name, dataset_cfg in config.datasets.items():
|
89 |
+
if only_dataset and dataset_name != only_dataset:
|
90 |
+
continue
|
91 |
+
|
92 |
+
csv_path = dataset_cfg["csv_path"].format(base_dir=dataset_cfg["base_dir"], split=split)
|
93 |
+
wav_dir = dataset_cfg["wav_dir"].format(base_dir=dataset_cfg["base_dir"], split=split)
|
94 |
+
|
95 |
+
logging.info(f"[{dataset_name.upper()}] Split={split}: CSV={csv_path}, WAV_DIR={wav_dir}")
|
96 |
+
|
97 |
+
dataset = DatasetMultiModal(
|
98 |
+
csv_path = csv_path,
|
99 |
+
wav_dir = wav_dir,
|
100 |
+
emotion_columns = config.emotion_columns,
|
101 |
+
split = split,
|
102 |
+
sample_rate = config.sample_rate,
|
103 |
+
wav_length = config.wav_length,
|
104 |
+
whisper_model = config.whisper_model,
|
105 |
+
text_column = config.text_column,
|
106 |
+
use_whisper_for_nontrain_if_no_text = config.use_whisper_for_nontrain_if_no_text,
|
107 |
+
whisper_device = config.whisper_device,
|
108 |
+
subset_size = config.subset_size,
|
109 |
+
merge_probability = config.merge_probability
|
110 |
+
)
|
111 |
+
|
112 |
+
datasets.append(dataset)
|
113 |
+
|
114 |
+
if not datasets:
|
115 |
+
raise ValueError(f"⚠️ Для split='{split}' не найдено ни одного подходящего датасета.")
|
116 |
+
|
117 |
+
# Объединяем только если их несколько
|
118 |
+
full_dataset = datasets[0] if len(datasets) == 1 else ConcatDataset(datasets)
|
119 |
+
|
120 |
+
loader = DataLoader(
|
121 |
+
full_dataset,
|
122 |
+
batch_size=config.batch_size,
|
123 |
+
shuffle=(split == "train"),
|
124 |
+
num_workers=config.num_workers,
|
125 |
+
collate_fn=custom_collate_fn
|
126 |
+
)
|
127 |
+
|
128 |
+
return full_dataset, loader
|
129 |
+
|
130 |
+
def run_eval(model, loader, audio_extractor, text_extractor, criterion, device="cuda"):
|
131 |
+
"""
|
132 |
+
Оценка модели на loader'е. Возвращает (loss, uar, war, mf1, wf1).
|
133 |
+
"""
|
134 |
+
model.eval()
|
135 |
+
total_loss = 0.0
|
136 |
+
total_preds = []
|
137 |
+
total_targets = []
|
138 |
+
total = 0
|
139 |
+
|
140 |
+
with torch.no_grad():
|
141 |
+
for batch in tqdm(loader):
|
142 |
+
if batch is None:
|
143 |
+
continue
|
144 |
+
|
145 |
+
audio = batch["audio"].to(device)
|
146 |
+
labels = batch["label"].to(device)
|
147 |
+
texts = batch["text"]
|
148 |
+
|
149 |
+
audio_emb = audio_extractor.extract(audio)
|
150 |
+
text_emb = text_extractor.extract(texts)
|
151 |
+
|
152 |
+
logits = model(audio_emb, text_emb)
|
153 |
+
target = labels.argmax(dim=1)
|
154 |
+
|
155 |
+
loss = criterion(logits, target)
|
156 |
+
bs = audio.shape[0]
|
157 |
+
total_loss += loss.item() * bs
|
158 |
+
total += bs
|
159 |
+
|
160 |
+
preds = logits.argmax(dim=1)
|
161 |
+
total_preds.extend(preds.cpu().numpy().tolist())
|
162 |
+
total_targets.extend(target.cpu().numpy().tolist())
|
163 |
+
|
164 |
+
avg_loss = total_loss / total
|
165 |
+
|
166 |
+
uar_m = uar(total_targets, total_preds)
|
167 |
+
war_m = war(total_targets, total_preds)
|
168 |
+
mf1_m = mf1(total_targets, total_preds)
|
169 |
+
wf1_m = wf1(total_targets, total_preds)
|
170 |
+
|
171 |
+
return avg_loss, uar_m, war_m, mf1_m, wf1_m
|
172 |
+
|
173 |
+
def train_once(config, train_loader, dev_loaders, test_loaders, metrics_csv_path=None):
|
174 |
+
"""
|
175 |
+
Логика обучения (train/dev/test).
|
176 |
+
Возвращает лучшую метрику на dev и словарь метрик.
|
177 |
+
"""
|
178 |
+
|
179 |
+
logging.info("== Запуск тренировки (train/dev/test) ==")
|
180 |
+
|
181 |
+
csv_writer = None
|
182 |
+
csv_file = None
|
183 |
+
|
184 |
+
if metrics_csv_path:
|
185 |
+
csv_file = open(metrics_csv_path, mode="w", newline="", encoding="utf-8")
|
186 |
+
csv_writer = csv.writer(csv_file)
|
187 |
+
csv_writer.writerow(["split", "epoch", "dataset", "loss", "uar", "war", "mf1", "wf1", "mean"])
|
188 |
+
|
189 |
+
|
190 |
+
# Seed
|
191 |
+
if config.random_seed > 0:
|
192 |
+
random.seed(config.random_seed)
|
193 |
+
torch.manual_seed(config.random_seed)
|
194 |
+
logging.info(f"== Фиксируем random seed: {config.random_seed}")
|
195 |
+
else:
|
196 |
+
logging.info("== Random seed не фиксирован (0).")
|
197 |
+
|
198 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
199 |
+
|
200 |
+
# Экстракторы
|
201 |
+
audio_extractor = AudioEmbeddingExtractor(config)
|
202 |
+
text_extractor = TextEmbeddingExtractor(config)
|
203 |
+
|
204 |
+
# Параметры
|
205 |
+
hidden_dim = config.hidden_dim
|
206 |
+
num_classes = len(config.emotion_columns)
|
207 |
+
num_transformer_heads = config.num_transformer_heads
|
208 |
+
num_graph_heads = config.num_graph_heads
|
209 |
+
hidden_dim_gated = config.hidden_dim_gated
|
210 |
+
mode = config.mode
|
211 |
+
positional_encoding = config.positional_encoding
|
212 |
+
dropout = config.dropout
|
213 |
+
out_features = config.out_features
|
214 |
+
lr = config.lr
|
215 |
+
num_epochs = config.num_epochs
|
216 |
+
tr_layer_number = config.tr_layer_number
|
217 |
+
max_patience = config.max_patience
|
218 |
+
|
219 |
+
dict_models = {
|
220 |
+
'BiFormer': BiFormer,
|
221 |
+
'BiGraphFormer': BiGraphFormer,
|
222 |
+
'BiGatedGraphFormer': BiGatedGraphFormer,
|
223 |
+
# 'MultiModalTransformer_v5': MultiModalTransformer_v5,
|
224 |
+
# 'MultiModalTransformer_v4': MultiModalTransformer_v4,
|
225 |
+
# 'MultiModalTransformer_v3': MultiModalTransformer_v3
|
226 |
+
}
|
227 |
+
|
228 |
+
model_cls = dict_models[config.model_name]
|
229 |
+
model = model_cls(
|
230 |
+
audio_dim = config.audio_embedding_dim,
|
231 |
+
text_dim = config.text_embedding_dim,
|
232 |
+
hidden_dim = hidden_dim,
|
233 |
+
hidden_dim_gated = hidden_dim_gated,
|
234 |
+
num_transformer_heads = num_transformer_heads,
|
235 |
+
num_graph_heads = num_graph_heads,
|
236 |
+
seg_len = config.max_tokens,
|
237 |
+
mode = mode,
|
238 |
+
dropout = dropout,
|
239 |
+
positional_encoding = positional_encoding,
|
240 |
+
out_features = out_features,
|
241 |
+
tr_layer_number = tr_layer_number,
|
242 |
+
device = device,
|
243 |
+
num_classes = num_classes
|
244 |
+
).to(device)
|
245 |
+
|
246 |
+
# Оптимизатор и лосс
|
247 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
|
248 |
+
|
249 |
+
class_weights = get_class_weights_from_loader(train_loader, num_classes)
|
250 |
+
criterion = WeightedCrossEntropyLoss(class_weights)
|
251 |
+
|
252 |
+
logging.info("Class weights: " + ", ".join(f"{name}={weight:.4f}" for name, weight in zip(config.emotion_columns, class_weights)))
|
253 |
+
|
254 |
+
# LR Scheduler
|
255 |
+
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
256 |
+
optimizer,
|
257 |
+
mode="max",
|
258 |
+
factor=0.5,
|
259 |
+
patience=2,
|
260 |
+
min_lr=1e-7
|
261 |
+
)
|
262 |
+
|
263 |
+
# Early stopping по dev
|
264 |
+
best_dev_mean = float("-inf")
|
265 |
+
best_dev_metrics = {}
|
266 |
+
patience_counter = 0
|
267 |
+
|
268 |
+
for epoch in range(num_epochs):
|
269 |
+
logging.info(f"\n=== Эпоха {epoch} ===")
|
270 |
+
model.train()
|
271 |
+
|
272 |
+
total_loss = 0.0
|
273 |
+
total_samples = 0
|
274 |
+
total_preds = []
|
275 |
+
total_targets = []
|
276 |
+
|
277 |
+
for batch in tqdm(train_loader):
|
278 |
+
if batch is None:
|
279 |
+
continue
|
280 |
+
|
281 |
+
audio = batch["audio"].to(device)
|
282 |
+
labels = batch["label"].to(device)
|
283 |
+
texts = batch["text"]
|
284 |
+
|
285 |
+
audio_emb = audio_extractor.extract(audio)
|
286 |
+
text_emb = text_extractor.extract(texts)
|
287 |
+
|
288 |
+
logits = model(audio_emb, text_emb)
|
289 |
+
target = labels.argmax(dim=1)
|
290 |
+
loss = criterion(logits, target)
|
291 |
+
|
292 |
+
optimizer.zero_grad()
|
293 |
+
loss.backward()
|
294 |
+
optimizer.step()
|
295 |
+
|
296 |
+
bs = audio.shape[0]
|
297 |
+
total_loss += loss.item() * bs
|
298 |
+
|
299 |
+
preds = logits.argmax(dim=1)
|
300 |
+
total_preds.extend(preds.cpu().numpy().tolist())
|
301 |
+
total_targets.extend(target.cpu().numpy().tolist())
|
302 |
+
total_samples += bs
|
303 |
+
|
304 |
+
train_loss = total_loss / total_samples
|
305 |
+
uar_m = uar(total_targets, total_preds)
|
306 |
+
war_m = war(total_targets, total_preds)
|
307 |
+
mf1_m = mf1(total_targets, total_preds)
|
308 |
+
wf1_m = wf1(total_targets, total_preds)
|
309 |
+
mean_train = np.mean([uar_m, war_m, mf1_m, wf1_m])
|
310 |
+
|
311 |
+
logging.info(
|
312 |
+
f"[TRAIN] Loss={train_loss:.4f}, UAR={uar_m:.4f}, WAR={war_m:.4f}, "
|
313 |
+
f"MF1={mf1_m:.4f}, WF1={wf1_m:.4f}, MEAN={mean_train:.4f}"
|
314 |
+
)
|
315 |
+
|
316 |
+
# --- DEV ---
|
317 |
+
dev_means = []
|
318 |
+
dev_metrics_by_dataset = []
|
319 |
+
|
320 |
+
for name, loader in dev_loaders:
|
321 |
+
d_loss, d_uar, d_war, d_mf1, d_wf1 = run_eval(
|
322 |
+
model, loader, audio_extractor, text_extractor, criterion, device
|
323 |
+
)
|
324 |
+
d_mean = np.mean([d_uar, d_war, d_mf1, d_wf1])
|
325 |
+
dev_means.append(d_mean)
|
326 |
+
|
327 |
+
if csv_writer:
|
328 |
+
csv_writer.writerow(["dev", epoch, name, d_loss, d_uar, d_war, d_mf1, d_wf1, d_mean])
|
329 |
+
|
330 |
+
logging.info(
|
331 |
+
f"[DEV:{name}] Loss={d_loss:.4f}, UAR={d_uar:.4f}, WAR={d_war:.4f}, "
|
332 |
+
f"MF1={d_mf1:.4f}, WF1={d_wf1:.4f}, MEAN={d_mean:.4f}"
|
333 |
+
)
|
334 |
+
|
335 |
+
dev_metrics_by_dataset.append({
|
336 |
+
"name": name,
|
337 |
+
"loss": d_loss,
|
338 |
+
"uar": d_uar,
|
339 |
+
"war": d_war,
|
340 |
+
"mf1": d_mf1,
|
341 |
+
"wf1": d_wf1,
|
342 |
+
"mean": d_mean,
|
343 |
+
})
|
344 |
+
|
345 |
+
mean_dev = np.mean(dev_means)
|
346 |
+
scheduler.step(mean_dev)
|
347 |
+
|
348 |
+
if mean_dev > best_dev_mean:
|
349 |
+
best_dev_mean = mean_dev
|
350 |
+
patience_counter = 0
|
351 |
+
best_dev_metrics = {
|
352 |
+
"mean": mean_dev
|
353 |
+
}
|
354 |
+
best_dev_metrics["by_dataset"] = dev_metrics_by_dataset
|
355 |
+
else:
|
356 |
+
patience_counter += 1
|
357 |
+
if patience_counter >= max_patience:
|
358 |
+
logging.info(f"Early stopping: {max_patience} эпох без улучшения.")
|
359 |
+
break
|
360 |
+
|
361 |
+
# --- TEST ---
|
362 |
+
for name, loader in test_loaders:
|
363 |
+
t_loss, t_uar, t_war, t_mf1, t_wf1 = run_eval(
|
364 |
+
model, loader, audio_extractor, text_extractor, criterion, device
|
365 |
+
)
|
366 |
+
t_mean = np.mean([t_uar, t_war, t_mf1, t_wf1])
|
367 |
+
logging.info(
|
368 |
+
f"[TEST:{name}] Loss={t_loss:.4f}, UAR={t_uar:.4f}, WAR={t_war:.4f}, "
|
369 |
+
f"MF1={t_mf1:.4f}, WF1={t_wf1:.4f}, MEAN={t_mean:.4f}"
|
370 |
+
)
|
371 |
+
|
372 |
+
if csv_writer:
|
373 |
+
csv_writer.writerow(["test", epoch, name, t_loss, t_uar, t_war, t_mf1, t_wf1, t_mean])
|
374 |
+
|
375 |
+
if csv_file:
|
376 |
+
csv_file.close()
|
377 |
+
|
378 |
+
logging.info("Тренировка завершена. Все split'ы обработаны!")
|
379 |
+
return best_dev_mean, best_dev_metrics
|
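For orientation, here is a minimal driver sketch showing how these pieces could be wired together. The builder name get_dataloader, its signature, and the module path are assumptions inferred from the code above; train_once expects dev_loaders and test_loaders as lists of (name, loader) pairs.

# Hypothetical driver sketch; get_dataloader is an assumed name/signature for the
# dataset-builder defined above, and training.train_utils is an assumed module path.
from utils.config_loader import ConfigLoader
from training.train_utils import get_dataloader, train_once  # assumed import path

config = ConfigLoader("config.toml")

# One concatenated loader for training, one loader per dataset for dev/test.
_, train_loader = get_dataloader(config, split="train")
dev_loaders = [(name, get_dataloader(config, split="dev", only_dataset=name)[1])
               for name in config.datasets]
test_loaders = [(name, get_dataloader(config, split="test", only_dataset=name)[1])
                for name in config.datasets]

best_dev_mean, best_metrics = train_once(
    config, train_loader, dev_loaders, test_loaders, metrics_csv_path="metrics.csv"
)
print(best_dev_mean, best_metrics.get("by_dataset"))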
utils/__pycache__/config_loader.cpython-310.pyc
ADDED
Binary file (6.53 kB)
utils/config_loader.py
ADDED
@@ -0,0 +1,204 @@
# utils/config_loader.py

import os
import toml
import logging

class ConfigLoader:
    """
    Class for loading and processing the configuration from `config.toml`.
    """

    def __init__(self, config_path="config.toml"):
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Configuration file `{config_path}` not found!")

        self.config = toml.load(config_path)

        # ---------------------------
        # General parameters
        # ---------------------------
        self.split = self.config.get("split", "train")

        # ---------------------------
        # Data paths
        # ---------------------------
        self.datasets = self.config.get("datasets", {})

        # ---------------------------
        # Synthetic data paths
        # ---------------------------
        synthetic_data_cfg = self.config.get("synthetic_data", {})
        self.use_synthetic_data = synthetic_data_cfg.get("use_synthetic_data", False)
        self.synthetic_path = synthetic_data_cfg.get("synthetic_path", "E:/MELD_S")
        self.synthetic_ratio = synthetic_data_cfg.get("synthetic_ratio", 0.0)

        # ---------------------------
        # Modalities and emotions
        # ---------------------------
        self.modalities = self.config.get("modalities", ["audio"])
        self.emotion_columns = self.config.get("emotion_columns", ["anger", "disgust", "fear", "happy", "neutral", "sad", "surprise"])

        # ---------------------------
        # DataLoader
        # ---------------------------
        dataloader_cfg = self.config.get("dataloader", {})
        self.num_workers = dataloader_cfg.get("num_workers", 0)
        self.shuffle = dataloader_cfg.get("shuffle", True)
        self.prepare_only = dataloader_cfg.get("prepare_only", False)

        # ---------------------------
        # Audio
        # ---------------------------
        audio_cfg = self.config.get("audio", {})
        self.sample_rate = audio_cfg.get("sample_rate", 16000)
        self.wav_length = audio_cfg.get("wav_length", 2)
        self.save_merged_audio = audio_cfg.get("save_merged_audio", True)
        self.merged_audio_base_path = audio_cfg.get("merged_audio_base_path", "saved_merges")
        self.merged_audio_suffix = audio_cfg.get("merged_audio_suffix", "_merged")
        self.force_remerge = audio_cfg.get("force_remerge", False)

        # ---------------------------
        # Whisper / Text
        # ---------------------------
        text_cfg = self.config.get("text", {})
        self.text_source = text_cfg.get("source", "csv")
        self.text_column = text_cfg.get("text_column", "text")
        self.whisper_model = text_cfg.get("whisper_model", "tiny")
        self.max_text_tokens = text_cfg.get("max_tokens", 15)
        self.whisper_device = text_cfg.get("whisper_device", "cuda")
        self.use_whisper_for_nontrain_if_no_text = text_cfg.get("use_whisper_for_nontrain_if_no_text", True)

        # ---------------------------
        # Training: general
        # ---------------------------
        train_general = self.config.get("train", {}).get("general", {})
        self.random_seed = train_general.get("random_seed", 42)
        self.subset_size = train_general.get("subset_size", 0)
        self.merge_probability = train_general.get("merge_probability", 0)
        self.batch_size = train_general.get("batch_size", 8)
        self.num_epochs = train_general.get("num_epochs", 100)
        self.max_patience = train_general.get("max_patience", 10)
        self.save_best_model = train_general.get("save_best_model", False)
        self.save_prepared_data = train_general.get("save_prepared_data", True)
        self.save_feature_path = train_general.get("save_feature_path", "./features/")
        self.search_type = train_general.get("search_type", "none")
        self.smoothing_probability = train_general.get("smoothing_probability", 0)
        self.path_to_df_ls = train_general.get("path_to_df_ls", None)

        # ---------------------------
        # Training: model parameters
        # ---------------------------
        train_model = self.config.get("train", {}).get("model", {})
        self.model_name = train_model.get("model_name", "BiFormer")
        self.hidden_dim = train_model.get("hidden_dim", 256)
        self.hidden_dim_gated = train_model.get("hidden_dim_gated", 256)
        self.num_transformer_heads = train_model.get("num_transformer_heads", 8)
        self.num_graph_heads = train_model.get("num_graph_heads", 8)
        self.tr_layer_number = train_model.get("tr_layer_number", 1)
        self.mamba_d_state = train_model.get("mamba_d_state", 16)
        self.mamba_ker_size = train_model.get("mamba_ker_size", 4)
        self.mamba_layer_number = train_model.get("mamba_layer_number", 3)
        self.positional_encoding = train_model.get("positional_encoding", True)
        self.dropout = train_model.get("dropout", 0.0)
        self.out_features = train_model.get("out_features", 128)
        self.mode = train_model.get("mode", "mean")

        # ---------------------------
        # Training: optimizer
        # ---------------------------
        train_optimizer = self.config.get("train", {}).get("optimizer", {})
        self.optimizer = train_optimizer.get("optimizer", "adam")
        self.lr = train_optimizer.get("lr", 1e-4)
        self.weight_decay = train_optimizer.get("weight_decay", 0.0)
        self.momentum = train_optimizer.get("momentum", 0.9)

        # ---------------------------
        # Training: scheduler
        # ---------------------------
        train_scheduler = self.config.get("train", {}).get("scheduler", {})
        self.scheduler_type = train_scheduler.get("scheduler_type", "plateau")
        self.warmup_ratio = train_scheduler.get("warmup_ratio", 0.1)

        # ---------------------------
        # Embeddings
        # ---------------------------
        emb_cfg = self.config.get("embeddings", {})
        self.audio_model_name = emb_cfg.get("audio_model", "amiriparian/ExHuBERT")
        self.text_model_name = emb_cfg.get("text_model", "jinaai/jina-embeddings-v3")
        self.audio_classifier_checkpoint = emb_cfg.get("audio_classifier_checkpoint", "best_audio_model.pt")
        self.text_classifier_checkpoint = emb_cfg.get("text_classifier_checkpoint", "best_text_model.pth")
        self.audio_embedding_dim = emb_cfg.get("audio_embedding_dim", 1024)
        self.text_embedding_dim = emb_cfg.get("text_embedding_dim", 1024)
        self.emb_normalize = emb_cfg.get("emb_normalize", True)
        self.audio_pooling = emb_cfg.get("audio_pooling", None)
        self.text_pooling = emb_cfg.get("text_pooling", None)
        self.max_tokens = emb_cfg.get("max_tokens", 256)
        self.emb_device = emb_cfg.get("device", "cuda")

        # ---------------------------
        # Synthetic text generation (currently disabled)
        # ---------------------------
        # textgen_cfg = self.config.get("textgen", {})
        # self.model_name = textgen_cfg.get("model_name", "deepseek-ai/DeepSeek-R1-Distill-Llama-8B")
        # self.max_new_tokens = textgen_cfg.get("max_new_tokens", 50)
        # self.temperature = textgen_cfg.get("temperature", 1.0)
        # self.top_p = textgen_cfg.get("top_p", 0.95)

        # Log the configuration only when this module is executed directly
        if __name__ == "__main__":
            self.log_config()

    def log_config(self):
        logging.info("=== CONFIGURATION ===")
        logging.info(f"Split: {self.split}")
        logging.info(f"Datasets loaded: {list(self.datasets.keys())}")
        for name, ds in self.datasets.items():
            logging.info(f"[Dataset: {name}]")
            logging.info(f"  Base Dir: {ds.get('base_dir', 'N/A')}")
            logging.info(f"  CSV Path: {ds.get('csv_path', '')}")
            logging.info(f"  WAV Dir: {ds.get('wav_dir', '')}")
        logging.info(f"Emotion columns: {self.emotion_columns}")

        # Log training parameters
        logging.info("--- Training Config ---")
        logging.info(f"Sample Rate={self.sample_rate}, Wav Length={self.wav_length}s")
        logging.info(f"Whisper Model={self.whisper_model}, Device={self.whisper_device}, MaxTokens={self.max_text_tokens}")
        logging.info(f"use_whisper_for_nontrain_if_no_text={self.use_whisper_for_nontrain_if_no_text}")
        logging.info(f"DataLoader: batch_size={self.batch_size}, num_workers={self.num_workers}, shuffle={self.shuffle}")
        logging.info(f"Model Name: {self.model_name}")
        logging.info(f"Random Seed: {self.random_seed}")
        logging.info(f"Hidden Dim: {self.hidden_dim}")
        logging.info(f"Hidden Dim in Gated: {self.hidden_dim_gated}")
        logging.info(f"Num Heads in Transformer: {self.num_transformer_heads}")
        logging.info(f"Num Heads in Graph: {self.num_graph_heads}")
        logging.info(f"Mode stat pooling: {self.mode}")
        logging.info(f"Optimizer: {self.optimizer}")
        logging.info(f"Scheduler Type: {self.scheduler_type}")
        logging.info(f"Warmup Ratio: {self.warmup_ratio}")
        logging.info(f"Weight Decay for Adam: {self.weight_decay}")
        logging.info(f"Momentum (SGD): {self.momentum}")
        logging.info(f"Positional Encoding: {self.positional_encoding}")
        logging.info(f"Number of Transformer Layers: {self.tr_layer_number}")
        logging.info(f"Mamba D State: {self.mamba_d_state}")
        logging.info(f"Mamba Kernel Size: {self.mamba_ker_size}")
        logging.info(f"Mamba Layer Number: {self.mamba_layer_number}")
        logging.info(f"Dropout: {self.dropout}")
        logging.info(f"Out Features: {self.out_features}")
        logging.info(f"LR: {self.lr}")
        logging.info(f"Num Epochs: {self.num_epochs}")
        logging.info(f"Merge Probability={self.merge_probability}")
        logging.info(f"Smoothing Probability={self.smoothing_probability}")
        logging.info(f"Max Patience={self.max_patience}")
        logging.info(f"Save Prepared Data={self.save_prepared_data}")
        logging.info(f"Path to Save Features={self.save_feature_path}")
        logging.info(f"Search Type={self.search_type}")

        # Log embeddings config
        logging.info("--- Embeddings Config ---")
        logging.info(f"Audio Model: {self.audio_model_name}, Text Model: {self.text_model_name}")
        logging.info(f"Audio dim={self.audio_embedding_dim}, Text dim={self.text_embedding_dim}")
        logging.info(f"Audio pooling={self.audio_pooling}, Text pooling={self.text_pooling}")
        logging.info(f"Emb device={self.emb_device}, Normalize={self.emb_normalize}")

    def show_config(self):
        self.log_config()
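A short, hypothetical usage sketch for the loader; it assumes config.toml sits in the working directory and that logging has already been configured:

from utils.config_loader import ConfigLoader
from utils.logger_setup import setup_logger

setup_logger()                      # colored console logging at INFO level
config = ConfigLoader("config.toml")
config.show_config()                # logs the full configuration
print(config.model_name, config.batch_size, config.emotion_columns)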
utils/logger_setup.py
ADDED
@@ -0,0 +1,47 @@
# utils/logger_setup.py

import logging
from colorlog import ColoredFormatter

def setup_logger(level=logging.INFO, log_file=None):
    """
    Configures the root logger to print colored logs to the console and,
    optionally, to write them to a file.

    :param level: Logging level (e.g., logging.DEBUG)
    :param log_file: Path to the log file (if not None, logs are also written to this file)
    """
    logger = logging.getLogger()
    if logger.hasHandlers():
        logger.handlers.clear()

    # Console handler with colorlog
    console_handler = logging.StreamHandler()
    log_format = (
        "%(log_color)s%(asctime)s [%(levelname)s]%(reset)s %(blue)s%(message)s"
    )
    console_formatter = ColoredFormatter(
        log_format,
        datefmt="%Y-%m-%d %H:%M:%S",
        reset=True,
        log_colors={
            "DEBUG": "cyan",
            "INFO": "green",
            "WARNING": "yellow",
            "ERROR": "red",
            "CRITICAL": "bold_red"
        }
    )
    console_handler.setFormatter(console_formatter)
    logger.addHandler(console_handler)

    # If log_file is given, add a file handler
    if log_file is not None:
        file_handler = logging.FileHandler(log_file, mode="w", encoding="utf-8")
        file_format = "%(asctime)s [%(levelname)s] %(message)s"
        file_formatter = logging.Formatter(file_format, datefmt="%Y-%m-%d %H:%M:%S")
        file_handler.setFormatter(file_formatter)
        logger.addHandler(file_handler)

    logger.setLevel(level)
    return logger
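A minimal usage sketch, assuming logs should go both to the console and to a file named train.log:

import logging
from utils.logger_setup import setup_logger

logger = setup_logger(level=logging.DEBUG, log_file="train.log")
logger.info("Colored console output plus a plain-text copy in train.log")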
utils/losses.py
ADDED
@@ -0,0 +1,33 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class WeightedCrossEntropyLoss(nn.Module):
    def __init__(self, class_weights=None):
        """
        Cross-entropy loss with optional class weighting.

        :param class_weights: Vector of class weights (optional)
        """
        super(WeightedCrossEntropyLoss, self).__init__()
        self.class_weights = class_weights

    def forward(self, y_pred, y_true):
        """
        Computes cross-entropy loss with (or without) class weighting.

        :param y_true: Ground-truth class labels (a vector or a single label)
        :param y_pred: Predicted class scores (logits)
        :return: Loss value
        """

        y_true = y_true.to(torch.long)     # Cast labels to Long
        y_pred = y_pred.to(torch.float32)  # Cast predictions to Float32

        if self.class_weights is not None:
            class_weights = torch.tensor(self.class_weights).float().to(y_true.device)
            loss = F.cross_entropy(y_pred, y_true, weight=class_weights)
        else:
            loss = F.cross_entropy(y_pred, y_true)

        return loss
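A small, self-contained sketch of the loss in isolation; the weight values here are purely illustrative (during training they come from get_class_weights_from_loader):

import torch
from utils.losses import WeightedCrossEntropyLoss

# Illustrative per-class weights for a 7-class emotion setup (hypothetical values)
criterion = WeightedCrossEntropyLoss(class_weights=[1.0, 2.0, 2.0, 1.0, 0.5, 1.5, 1.5])

logits = torch.randn(4, 7)            # batch of 4 samples, 7 emotion classes
targets = torch.tensor([0, 3, 4, 6])  # integer class indices
loss = criterion(logits, targets)
print(loss.item())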
utils/measures.py
ADDED
@@ -0,0 +1,41 @@
from sklearn.metrics import recall_score, f1_score

def uar(y_true, y_pred):
    """
    Computes UAR (Unweighted Average Recall).

    :param y_true: Ground-truth labels
    :param y_pred: Predicted labels
    :return: UAR (recall averaged over all classes without class weighting)
    """
    return recall_score(y_true, y_pred, average='macro', zero_division=0)

def war(y_true, y_pred):
    """
    Computes WAR (Weighted Average Recall).

    :param y_true: Ground-truth labels
    :param y_pred: Predicted labels
    :return: WAR (recall weighted by class support)
    """
    return recall_score(y_true, y_pred, average='weighted', zero_division=0)

def mf1(y_true, y_pred):
    """
    Computes MF1 (Macro F1 Score).

    :param y_true: Ground-truth labels
    :param y_pred: Predicted labels
    :return: MF1 (F1 averaged over all classes)
    """
    return f1_score(y_true, y_pred, average='macro', zero_division=0)

def wf1(y_true, y_pred):
    """
    Computes WF1 (Weighted F1 Score).

    :param y_true: Ground-truth labels
    :param y_pred: Predicted labels
    :return: WF1 (F1 weighted by class support)
    """
    return f1_score(y_true, y_pred, average='weighted', zero_division=0)
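A quick sanity-check sketch for the metric helpers on toy labels:

from utils.measures import uar, war, mf1, wf1

y_true = [0, 0, 1, 2, 2, 2]
y_pred = [0, 1, 1, 2, 2, 0]
print(f"UAR={uar(y_true, y_pred):.3f}, WAR={war(y_true, y_pred):.3f}, "
      f"MF1={mf1(y_true, y_pred):.3f}, WF1={wf1(y_true, y_pred):.3f}")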