farbverlauf committed on
Commit
960b1a0
·
1 Parent(s): 4702e13
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50)
  1. .gitignore +1 -0
  2. LICENSE +21 -0
  3. Phi-4-mini-instruct_emotions_union.csv +0 -0
  4. Qwen3-4B_emotions_meld.csv +0 -0
  5. Qwen3-4B_emotions_resd.csv +0 -0
  6. Qwen3-4B_emotions_union.csv +0 -0
  7. analysis.ipynb +0 -0
  8. app.py +315 -0
  9. best_audio_model.pt +3 -0
  10. best_audio_model_2.pt +3 -0
  11. best_model_dev_0_5895_epoch_8.pt +3 -0
  12. best_text_model.pth +3 -0
  13. check.py +230 -0
  14. config.toml +133 -0
  15. data_loading/__pycache__/feature_extractor.cpython-310.pyc +0 -0
  16. data_loading/__pycache__/pretrained_extractors.cpython-310.pyc +0 -0
  17. data_loading/dataset_multimodal.py +898 -0
  18. data_loading/feature_extractor.py +410 -0
  19. data_loading/pretrained_extractors.py +221 -0
  20. emotion_templates/anger.json +196 -0
  21. emotion_templates/disgust.json +174 -0
  22. emotion_templates/fear.json +178 -0
  23. emotion_templates/happy.json +187 -0
  24. emotion_templates/neutral.json +97 -0
  25. emotion_templates/sad.json +183 -0
  26. emotion_templates/surprise.json +198 -0
  27. generate_emotion_texts_dataset.py +137 -0
  28. generate_synthetic_dataset.py +71 -0
  29. main.py +119 -0
  30. models/__init__.py +0 -0
  31. models/__pycache__/__init__.cpython-310.pyc +0 -0
  32. models/__pycache__/help_layers.cpython-310.pyc +0 -0
  33. models/__pycache__/models.cpython-310.pyc +0 -0
  34. models/help_layers.py +528 -0
  35. models/models.py +1700 -0
  36. requirements.txt +0 -0
  37. run_generation.py +32 -0
  38. search_params.toml +22 -0
  39. synthetic_utils/__pycache__/dia_tts_wrapper.cpython-310.pyc +0 -0
  40. synthetic_utils/dia_tts_wrapper.py +77 -0
  41. synthetic_utils/parler_tts_wrapper.py +60 -0
  42. synthetic_utils/text_generation.py +91 -0
  43. test.py +28 -0
  44. training/train_utils.py +585 -0
  45. training/train_utils_old.py +379 -0
  46. utils/__pycache__/config_loader.cpython-310.pyc +0 -0
  47. utils/config_loader.py +204 -0
  48. utils/logger_setup.py +47 -0
  49. utils/losses.py +33 -0
  50. utils/measures.py +41 -0
.gitignore ADDED
@@ -0,0 +1 @@
1
+ env/
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 LEYA Lab for Natural Language Processing
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
Phi-4-mini-instruct_emotions_union.csv ADDED
The diff for this file is too large to render. See raw diff
 
Qwen3-4B_emotions_meld.csv ADDED
The diff for this file is too large to render. See raw diff
 
Qwen3-4B_emotions_resd.csv ADDED
The diff for this file is too large to render. See raw diff
 
Qwen3-4B_emotions_union.csv ADDED
The diff for this file is too large to render. See raw diff
 
analysis.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
app.py ADDED
@@ -0,0 +1,315 @@
1
+ import gradio as gr
2
+ import torchaudio
3
+ import pandas as pd
4
+ import torch.nn.functional as F
5
+ import whisper
6
+ import logging
7
+ import plotly.express as px
8
+ from utils.config_loader import ConfigLoader
9
+ from data_loading.feature_extractor import (
10
+ PretrainedAudioEmbeddingExtractor,
11
+ PretrainedTextEmbeddingExtractor
12
+ )
13
+ import chardet
14
+ import torch
15
+ from models.models import BiFormer
16
+
17
+
18
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
+ # DEVICE = torch.device('cpu')
20
+
21
+ # Configure logging
22
+ logging.basicConfig(level=logging.INFO)
23
+
24
+ # Constants with emojis and colors
25
+ LABEL_TO_EMOTION = {
26
+ 0: '😠 Anger',
27
+ 1: '🤢 Disgust',
28
+ 2: '😨 Fear',
29
+ 3: '😄 Joy/Happiness',
30
+ 4: '😐 Neutral',
31
+ 5: '😢 Sadness',
32
+ 6: '😲 Surprise/Enthusiasm'
33
+ }
34
+ EMOTION_COLORS = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEEAD', '#FF9999', '#D4A5A5']
35
+ emotion_color_map = {emotion: color for emotion, color in zip(LABEL_TO_EMOTION.values(), EMOTION_COLORS)}
36
+ TARGET_SAMPLE_RATE = 16000
37
+
38
+
39
+ def initialize_components(config_path='config.toml'):
40
+ """Initialize configuration and models."""
41
+ config = ConfigLoader(config_path)
42
+ config.show_config()
43
+ model = BiFormer(
44
+ audio_dim=256,
45
+ text_dim=1024,
46
+ seg_len=95,
47
+ hidden_dim=256,
48
+ hidden_dim_gated=256,
49
+ num_transformer_heads=8,
50
+ num_graph_heads=2,
51
+ positional_encoding=False,
52
+ dropout=0.15,
53
+ mode='mean',
54
+ # device="cuda",
55
+ tr_layer_number=5,
56
+ out_features=256,
57
+ num_classes=7
58
+ )
59
+ checkpoint_path = "best_model_dev_0_5895_epoch_8.pt"
60
+ state = torch.load(checkpoint_path, map_location="cpu")
61
+ model.load_state_dict(state)
62
+ model = model.to(DEVICE)
63
+ model.eval()
64
+ return (
65
+ PretrainedAudioEmbeddingExtractor(config),
66
+ PretrainedTextEmbeddingExtractor(config),
67
+ whisper.load_model("base"),
68
+ model
69
+ )
70
+
71
+
72
+ audio_extractor, text_extractor, whisper_model, bimodal_model = initialize_components()
73
+
74
+
75
+ def load_and_preprocess_audio(audio_path):
76
+ """Load and preprocess audio to mono 16kHz format."""
77
+ try:
78
+ waveform, orig_sr = torchaudio.load(audio_path)
79
+ waveform = waveform.mean(dim=0, keepdim=False)
80
+
81
+ if orig_sr != TARGET_SAMPLE_RATE:
82
+ resampler = torchaudio.transforms.Resample(
83
+ orig_freq=orig_sr,
84
+ new_freq=TARGET_SAMPLE_RATE
85
+ )
86
+ waveform = resampler(waveform)
87
+
88
+ return waveform, TARGET_SAMPLE_RATE
89
+ except Exception as e:
90
+ logging.error(f"Audio loading failed: {e}")
91
+ raise
92
+
93
+
94
+ def transcribe_audio(audio_path):
95
+ """Convert speech to text using Whisper."""
96
+ try:
97
+ result = whisper_model.transcribe(audio_path, fp16=False)
98
+ return result.get('text', '')
99
+ except Exception as e:
100
+ logging.error(f"Transcription failed: {e}")
101
+ return ""
102
+
103
+
104
+ def get_predictions(input_data, extractor, is_audio=False):
105
+ """Generic prediction function for audio/text."""
106
+ try:
107
+ if is_audio:
108
+ pred, emb = extractor.extract(input_data, TARGET_SAMPLE_RATE)
109
+ else:
110
+ pred, emb = extractor.extract(input_data)
111
+
112
+ return F.softmax(pred, dim=-1)[0].tolist(), emb
113
+ except Exception as e:
114
+ logging.error(f"Prediction failed: {e}")
115
+ return [0.0] * len(LABEL_TO_EMOTION), None
116
+
117
+
118
+ def create_emotion_df(probabilities):
119
+ """Create sorted emotion probability dataframe with percentages."""
120
+ df = pd.DataFrame({
121
+ 'Emotion': list(LABEL_TO_EMOTION.values()),
122
+ 'Probability': [round(p*100, 2) for p in probabilities]
123
+ })
124
+ return df
125
+
126
+
127
+ def create_plot(df, title):
128
+ """Create Plotly bar chart with proper formatting."""
129
+ fig = px.bar(
130
+ df,
131
+ x='Emotion',
132
+ y='Probability',
133
+ title=title,
134
+ color='Emotion',
135
+ color_discrete_map=emotion_color_map
136
+ )
137
+ fig.update_layout(
138
+ xaxis=dict(tickangle=-45, tickfont=dict(size=12)),
139
+ yaxis=dict(title='Probability (%)'),
140
+ margin=dict(l=20, r=20, t=60, b=100),
141
+ height=400,
142
+ showlegend=False
143
+ )
144
+ return fig
145
+
146
+
147
+ def get_top_emotion(probabilities):
148
+ """Return formatted top emotion with percentage."""
149
+ max_idx = probabilities.index(max(probabilities))
150
+ return f"{LABEL_TO_EMOTION[max_idx]} ({max(probabilities)*100:.1f}%)"
151
+
152
+
153
+ def process_audio(audio_path):
154
+ """Main processing pipeline."""
155
+ try:
156
+ if not audio_path:
157
+ empty = create_emotion_df([0]*len(LABEL_TO_EMOTION))
158
+ return (
159
+ create_plot(empty, "🎧 Audio Analysis"),
160
+ "No audio detected",
161
+ create_plot(empty, "📝 Text Analysis"),
162
+ create_plot(empty, "🤝 Combined Analysis"),
163
+ "🔇 Please provide audio input"
164
+ )
165
+
166
+ # Audio processing
167
+ waveform, sr = load_and_preprocess_audio(audio_path)
168
+ audio_probs, audio_features = get_predictions(waveform, audio_extractor, is_audio=True)
169
+ audio_df = create_emotion_df(audio_probs)
170
+
171
+ # Text processing
172
+ text = transcribe_audio(audio_path)
173
+ text_probs, text_features = get_predictions(text, text_extractor) if text.strip() else ([0.0] * len(LABEL_TO_EMOTION), None)
174
+ text_df = create_emotion_df(text_probs)
175
+
176
+ # Combined results
177
+ combined_probs = bimodal_model(audio_features, text_features)
178
+ combined_probs = F.softmax(combined_probs, dim=-1)[0].detach().cpu().numpy().tolist()
179
+ combined_df = create_emotion_df(combined_probs)
180
+ top_emotion = get_top_emotion(combined_probs)
181
+
182
+ return (
183
+ create_plot(audio_df, "🎧 Audio Analysis"),
184
+ f"🗣️ Transcription:\n{text}",
185
+ create_plot(text_df, "📝 Text Analysis"),
186
+ create_plot(combined_df, "🤝 Combined Analysis"),
187
+ f"## 🏆 Dominant Emotion: {top_emotion}"
188
+ )
189
+
190
+ except Exception as e:
191
+ logging.error(f"Processing failed: {e}")
192
+ error_df = create_emotion_df([0]*len(LABEL_TO_EMOTION))
193
+ return (
194
+ create_plot(error_df, "🎧 Audio Analysis"),
195
+ "❌ Error processing audio",
196
+ create_plot(error_df, "📝 Text Analysis"),
197
+ create_plot(error_df, "🤝 Combined Analysis"),
198
+ "⚠️ Processing Error"
199
+ )
200
+
201
+
202
+ def create_app():
203
+ """Build enhanced Gradio interface."""
204
+ with gr.Blocks(theme=gr.themes.Soft(), title="Emotion Detection from Speech") as demo:
205
+ gr.Markdown("# 🎙️ Bimodal Emotion Recognition")
206
+ gr.Markdown("Analyze emotions in speech through both audio characteristics and spoken content")
207
+
208
+ with gr.Row():
209
+ audio_input = gr.Audio(
210
+ sources=["upload", "microphone"],
211
+ type="filepath",
212
+ label="Record or Upload Audio",
213
+ format="wav",
214
+ interactive=True
215
+ )
216
+
217
+ with gr.Row():
218
+ top_emotion = gr.Markdown("## 🏆 Dominant Emotion: Waiting for input...",
219
+ elem_classes="dominant-emotion")
220
+
221
+ with gr.Row():
222
+ with gr.Column():
223
+ audio_plot = gr.Plot(label="Audio Analysis")
224
+ with gr.Column():
225
+ text_plot = gr.Plot(label="Text Analysis")
226
+ with gr.Column():
227
+ combined_plot = gr.Plot(label="Combined Analysis")
228
+
229
+ transcription = gr.Textbox(
230
+ label="📜 Transcription Results",
231
+ placeholder="Transcribed text will appear here...",
232
+ lines=3,
233
+ max_lines=6
234
+ )
235
+
236
+ audio_input.change(
237
+ process_audio,
238
+ inputs=audio_input,
239
+ outputs=[audio_plot, transcription, text_plot, combined_plot, top_emotion]
240
+ )
241
+
242
+ return demo
243
+
244
+
245
+ def create_authors():
246
+ df = pd.DataFrame({
247
+ "Name": ["Author", "Author"]
248
+ })
249
+ with gr.Blocks() as demo:
250
+ gr.Dataframe(df)
251
+ return demo
252
+
253
+
254
+ def create_reqs():
255
+ """Create requirements tab with formatted data and explanations."""
256
+ # 1️⃣ Detect file encoding
257
+ with open('requirements.txt', 'rb') as f:
258
+ raw_data = f.read()
259
+ encoding = chardet.detect(raw_data)['encoding']
260
+
261
+ # 2️⃣ Parse requirements into library-version pairs
262
+ def parse_requirements(lines):
263
+ requirements = []
264
+ for line in lines:
265
+ line = line.strip()
266
+ if not line or line.startswith('#'):
267
+ continue # Skip empty lines and comments
268
+ parts = line.split('==')
269
+ library = parts[0].strip()
270
+ version = parts[1].strip() if len(parts) > 1 else 'latest'
271
+ requirements.append((library, version))
272
+ return requirements
273
+
274
+ # 3️⃣ Load and process requirements
275
+ with open('requirements.txt', 'r', encoding=encoding) as f:
276
+ requirements = parse_requirements(f.readlines())
277
+
278
+ # 4️⃣ Create structured data for display
279
+ df = pd.DataFrame({
280
+ "📦 Library": [lib for lib, _ in requirements],
281
+ "🚀 Recommended Version": [ver for _, ver in requirements]
282
+ })
283
+
284
+ # 5️⃣ Build interactive components
285
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
286
+ gr.Markdown("# 📦 Dependency Requirements")
287
+ gr.Markdown("""
288
+ ## Essential Packages for Operation
289
+ These are the core libraries and versions needed to run the application successfully:
290
+ """)
291
+ gr.Dataframe(
292
+ df,
293
+ interactive=True,
294
+ wrap=True,
295
+ elem_id="requirements-table"
296
+ )
297
+ gr.Markdown("_Note: Versions marked 'latest' can use any compatible version_")
298
+
299
+ return demo
300
+
301
+
302
+ def create_demo():
303
+ app = create_app()
304
+ authors = create_authors()
305
+ reqs = create_reqs()
306
+ demo = gr.TabbedInterface(
307
+ [app, authors, reqs],
308
+ tab_names=["🎙️ Speech Analysis", "👥 Project Team", "📦 Dependencies"]
309
+ )
310
+ return demo
311
+
312
+
313
+ if __name__ == "__main__":
314
+ demo = create_demo()
315
+ demo.launch()
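
The pipeline wired into the UI above can also be exercised headlessly. The following is a minimal sketch, assuming the checkpoint and config.toml from this commit are present locally, that reuses the helpers defined in app.py (importing app loads the extractors, Whisper, and the BiFormer checkpoint at import time); "example.wav" is a hypothetical input file.

import torch.nn.functional as F
from app import (
    load_and_preprocess_audio, transcribe_audio, get_predictions,
    audio_extractor, text_extractor, bimodal_model, LABEL_TO_EMOTION,
)

# Audio branch: waveform -> audio embedding + per-class probabilities.
waveform, sr = load_and_preprocess_audio("example.wav")
_, audio_features = get_predictions(waveform, audio_extractor, is_audio=True)

# Text branch: Whisper transcript -> text embedding + per-class probabilities.
text = transcribe_audio("example.wav")
_, text_features = get_predictions(text, text_extractor)

# Fusion: BiFormer consumes both embeddings and returns class logits.
logits = bimodal_model(audio_features, text_features)
probs = F.softmax(logits, dim=-1)[0].detach().cpu().tolist()
print(LABEL_TO_EMOTION[probs.index(max(probs))])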
best_audio_model.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1cf27208d1d99448dd075ae4b499f97aa5fffd0ef7290ebe82555bc09a350469
3
+ size 3174946
best_audio_model_2.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ed1705e0cf0e5bc7bb1755f7fb09bfc440b4c92c739489d107d5e0e7a29707bc
3
+ size 3174674
best_model_dev_0_5895_epoch_8.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f7eeba90564d5c301aa59f487aa59f71923b9d3f293327a99b88c34b71d4226f
3
+ size 17670490
best_text_model.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e25ae864742da9a8f4ba6b85b6161c503cb47dc6197dd0eb82b5250180a171cf
3
+ size 39340722
check.py ADDED
@@ -0,0 +1,230 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Sanity checks for the synthetic MELD-S corpus:
5
+ • does the WAV file exist;
6
+ • are the audio and text embedding dimensions correct;
7
+ • does the final feature-vector size match the expected value.
8
+
9
+ Output:
10
+ GOOD / BAD counts in the console + CSV bad_synth_meld.csv (if problems were found).
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import csv
16
+ import logging
17
+ import sys
18
+ import traceback
19
+ from pathlib import Path
20
+ from types import SimpleNamespace
21
+ from typing import Dict, List, Optional
22
+
23
+ import pandas as pd
24
+ import torch
25
+ import torchaudio
26
+ from tqdm import tqdm
27
+
28
+ # ----------------------------------------------------------------------
29
+ # >>>>>>>>> USER SETTINGS (check the paths!) <<<<<<<<<<<
30
+ # ----------------------------------------------------------------------
31
+ USER_CONFIG = {
32
+ # paths to the synthetic data
33
+ "synthetic_path": r"E:/MELD_S",
34
+ "csv_name": "meld_s_train_labels.csv",
35
+ "wav_subdir": "wavs",
36
+
37
+ # models / checkpoints, same as in your config.toml
38
+ "audio_model_name": "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim",
39
+ "audio_ckpt": "best_audio_model_2.pt",
40
+ "text_model_name": "jinaai/jina-embeddings-v3",
41
+ "text_ckpt": "best_text_model.pth",
42
+
43
+ # general parameters
44
+ "device": "cuda" if torch.cuda.is_available() else "cpu",
45
+ "sample_rate": 16000,
46
+ "num_emotions": 7, # anger, disgust, fear, happy, neutral, sad, surprise
47
+ }
48
+
49
+ # ----------------------------------------------------------------------
50
+ # import the project's own extractors
51
+ # ----------------------------------------------------------------------
52
+ try:
53
+ from feature_extractor import (
54
+ PretrainedAudioEmbeddingExtractor,
55
+ PretrainedTextEmbeddingExtractor,
56
+ )
57
+ except ModuleNotFoundError:
58
+ try:
59
+ # in case the file lives in data_loading/
60
+ from data_loading.feature_extractor import (
61
+ PretrainedAudioEmbeddingExtractor,
62
+ PretrainedTextEmbeddingExtractor,
63
+ )
64
+ except ModuleNotFoundError as e:
65
+ sys.exit(
66
+ "❌ Не найден feature_extractor.py. "
67
+ "Убедитесь, что он в PYTHONPATH или лежит рядом со скриптом."
68
+ )
69
+
70
+ # ----------------------------------------------------------------------
71
+ # helper functions
72
+ # ----------------------------------------------------------------------
73
+ def build_audio_cfg() -> SimpleNamespace:
74
+ """Готовим config-объект для PretrainedAudioEmbeddingExtractor."""
75
+ return SimpleNamespace(
76
+ audio_model_name=USER_CONFIG["audio_model_name"],
77
+ emb_device=USER_CONFIG["device"],
78
+ audio_pooling="mean", # same as during training
79
+ emb_normalize=False,
80
+ max_audio_frames=0,
81
+ audio_classifier_checkpoint=USER_CONFIG["audio_ckpt"],
82
+ sample_rate=USER_CONFIG["sample_rate"],
83
+ wav_length=4,
84
+ )
85
+
86
+
87
+ def build_text_cfg() -> SimpleNamespace:
88
+ """Config для PretrainedTextEmbeddingExtractor."""
89
+ return SimpleNamespace(
90
+ text_model_name=USER_CONFIG["text_model_name"],
91
+ emb_device=USER_CONFIG["device"],
92
+ text_pooling="mean",
93
+ emb_normalize=False,
94
+ max_tokens=95,
95
+ text_classifier_checkpoint=USER_CONFIG["text_ckpt"],
96
+ )
97
+
98
+
99
+ def get_dims(audio_extractor, text_extractor) -> Dict[str, int]:
100
+ """Возвращает фактические размеры эмбеддингов (audio_dim, text_dim)."""
101
+ sr = USER_CONFIG["sample_rate"]
102
+ with torch.no_grad():
103
+ dummy_wav = torch.zeros(1, sr)
104
+ _, a_emb = audio_extractor.extract(dummy_wav[0], sr)
105
+ audio_dim = a_emb[0].shape[-1]
106
+
107
+ _, t_emb = text_extractor.extract("dummy text")
108
+ text_dim = t_emb[0].shape[-1]
109
+
110
+ return {"audio_dim": audio_dim, "text_dim": text_dim}
111
+
112
+
113
+ def check_row(
114
+ row: pd.Series,
115
+ feats: Dict[str, object],
116
+ dims: Dict[str, int],
117
+ wav_dir: Path,
118
+ ) -> Optional[str]:
119
+ """
120
+ Returns None if the sample is valid, otherwise a string with the reason.
121
+ """
122
+ video = row["video_name"]
123
+ wav_path = wav_dir / f"{video}.wav"
124
+ text = row.get("text", "")
125
+
126
+ try:
127
+ if not wav_path.exists():
128
+ return "file_missing"
129
+
130
+ # ---------- audio ----------
131
+ wf, sr = torchaudio.load(str(wav_path))
132
+ if sr != USER_CONFIG["sample_rate"]:
133
+ wf = torchaudio.transforms.Resample(sr, USER_CONFIG["sample_rate"])(wf)
134
+
135
+ a_pred, a_emb = feats["audio"].extract(wf[0], USER_CONFIG["sample_rate"])
136
+ a_emb = a_emb[0]
137
+ if a_emb.shape[-1] != dims["audio_dim"]:
138
+ return f"audio_dim_{a_emb.shape[-1]}"
139
+
140
+ # ---------- text ----------
141
+ t_pred, t_emb = feats["text"].extract(text)
142
+ t_emb = t_emb[0]
143
+ if t_emb.shape[-1] != dims["text_dim"]:
144
+ return f"text_dim_{t_emb.shape[-1]}"
145
+
146
+ # ---------- concatenation ----------
147
+ full_vec = torch.cat(
148
+ [a_emb, t_emb, a_pred[0], t_pred[0]],
149
+ dim=-1,
150
+ )
151
+ expected_all = (
152
+ dims["audio_dim"]
153
+ + dims["text_dim"]
154
+ + 2 * USER_CONFIG["num_emotions"]
155
+ )
156
+ if full_vec.shape[-1] != expected_all:
157
+ return f"concat_dim_{full_vec.shape[-1]}"
158
+
159
+ except Exception as e:
160
+ logging.error(f"{video}: {traceback.format_exc(limit=2)}")
161
+ return "exception_" + e.__class__.__name__
162
+
163
+ return None
164
+
165
+
166
+ # ----------------------------------------------------------------------
167
+ # main script
168
+ # ----------------------------------------------------------------------
169
+ def main() -> None:
170
+ syn_root = Path(USER_CONFIG["synthetic_path"])
171
+ csv_path = syn_root / USER_CONFIG["csv_name"]
172
+ wav_dir = syn_root / USER_CONFIG["wav_subdir"]
173
+
174
+ if not csv_path.exists():
175
+ sys.exit(f"CSV не найден: {csv_path}")
176
+ if not wav_dir.exists():
177
+ sys.exit(f"WAV-директория не найдена: {wav_dir}")
178
+
179
+ # 1. extractors
180
+ audio_feat = PretrainedAudioEmbeddingExtractor(build_audio_cfg())
181
+ text_feat = PretrainedTextEmbeddingExtractor(build_text_cfg())
182
+ feats = {"audio": audio_feat, "text": text_feat}
183
+
184
+ # 2. actual dimensions
185
+ dims = get_dims(audio_feat, text_feat)
186
+ expected_total = (
187
+ dims["audio_dim"] + dims["text_dim"] + 2 * USER_CONFIG["num_emotions"]
188
+ )
189
+ print(
190
+ f"Audio dim = {dims['audio_dim']}, "
191
+ f"Text dim = {dims['text_dim']}, "
192
+ f"Expected concat = {expected_total}"
193
+ )
194
+
195
+ # 3. check the CSV rows
196
+ df = pd.read_csv(csv_path)
197
+ bad_rows: List[Dict[str, str]] = []
198
+ good_cnt = 0
199
+
200
+ for _, row in tqdm(df.iterrows(), total=len(df), desc="Checking"):
201
+ reason = check_row(row, feats, dims, wav_dir)
202
+ if reason:
203
+ bad_rows.append(
204
+ {
205
+ "video_name": row["video_name"],
206
+ "reason": reason,
207
+ "wav_path": str(wav_dir / f"{row['video_name']}.wav"),
208
+ }
209
+ )
210
+ else:
211
+ good_cnt += 1
212
+
213
+ # 4. report
214
+ print("\n========== SUMMARY ==========")
215
+ print(f"✅ GOOD : {good_cnt}")
216
+ print(f"❌ BAD : {len(bad_rows)}")
217
+
218
+ if bad_rows:
219
+ out_csv = Path(__file__).with_name("bad_synth_meld.csv")
220
+ with open(out_csv, "w", newline="", encoding="utf-8") as f:
221
+ writer = csv.DictWriter(
222
+ f, fieldnames=["video_name", "reason", "wav_path"]
223
+ )
224
+ writer.writeheader()
225
+ writer.writerows(bad_rows)
226
+ print(f"\nСписок проблемных примеров сохранён: {out_csv.resolve()}")
227
+
228
+
229
+ if __name__ == "__main__":
230
+ main()
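
As a quick cross-check of the dimensions this script verifies: assuming the embedding sizes from config.toml below (audio_embedding_dim = 256, text_embedding_dim = 1024) and the 7 emotion classes, the expected concatenated feature size works out as follows.

# Expected concat size = audio embedding + text embedding + audio logits + text logits,
# using the dimensions assumed from config.toml in this commit.
audio_dim, text_dim, num_emotions = 256, 1024, 7
expected = audio_dim + text_dim + 2 * num_emotions
print(expected)  # 1294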
config.toml ADDED
@@ -0,0 +1,133 @@
1
+ # ---------------------------
2
+ # Dataset corpus settings
3
+ # ---------------------------
4
+
5
+ [datasets.meld]
6
+ base_dir = "E:/MELD"
7
+ csv_path = "{base_dir}/meld_{split}_labels.csv"
8
+ wav_dir = "{base_dir}/wavs/{split}"
9
+
10
+ [datasets.resd]
11
+ base_dir = "E:/RESD"
12
+ csv_path = "{base_dir}/resd_{split}_labels.csv"
13
+ wav_dir = "{base_dir}/wavs/{split}"
14
+
15
+ [synthetic_data]
16
+ use_synthetic_data = false
17
+ synthetic_path = "E:/MELD_S"
18
+ synthetic_ratio = 0.005
19
+
20
+ # ---------------------------
21
+ # Modalities and emotion list
22
+ # ---------------------------
23
+ modalities = ["audio"]
24
+ # emotion_columns = ["neutral", "happy", "sad", "anger", "surprise", "disgust", "fear"]
25
+ emotion_columns = ["anger", "disgust", "fear", "happy", "neutral", "sad", "surprise"]
26
+
27
+ # ---------------------------
28
+ # DataLoader parameters
29
+ # ---------------------------
30
+ [dataloader]
31
+ num_workers = 0
32
+ shuffle = true
33
+ prepare_only = false
34
+
35
+ # ---------------------------
36
+ # Audio
37
+ # ---------------------------
38
+ [audio]
39
+ sample_rate = 16000 # target sample rate
40
+ wav_length = 4 # target audio length in seconds
41
+ save_merged_audio = true
42
+ merged_audio_base_path = "saved_merges"
43
+ merged_audio_suffix = "_merged"
44
+ force_remerge = false
45
+
46
+ # ---------------------------
47
+ # Whisper and text
48
+ # ---------------------------
49
+ [text]
50
+ # If "csv", we try to take the text from the CSV when it is present
51
+ # (the text_column field). Otherwise Whisper is used (when needed).
52
+ source = "csv"
53
+ text_column = "text"
54
+ whisper_model = "base"
55
+
56
+ # Where to run Whisper: "cuda" (GPU) or "cpu"
57
+ whisper_device = "cuda"
58
+
59
+ # If the CSV has no text for dev/test, should Whisper still be called?
60
+ use_whisper_for_nontrain_if_no_text = true
61
+
62
+ # ---------------------------
63
+ # General training parameters
64
+ # ---------------------------
65
+ [train.general]
66
+ random_seed = 42 # fix the random seed for reproducibility (0 = different every run)
67
+ subset_size = 100 # cap on the number of samples (0 = use the whole dataset)
68
+ merge_probability = 0 # fraction of short files to merge
69
+ batch_size = 8 # batch size
70
+ num_epochs = 75 # number of training epochs
71
+ max_patience = 10 # maximum number of epochs without improvement (early stopping)
72
+ save_best_model = false
73
+ save_prepared_data = true # save the extracted features (embeddings)
74
+ save_feature_path = './features/' # path where the embeddings are saved
75
+ search_type = "none" # search strategy: "greedy", "exhaustive" or "none"
76
+ path_to_df_ls = 'Phi-4-mini-instruct_emotions_union.csv' # path to the dataframe with softened labels: Qwen3-4B_emotions_union or Phi-4-mini-instruct_emotions_union
77
+ smoothing_probability = 0.0 # fraction of samples that use the softened labels
78
+
79
+ # ---------------------------
80
+ # Model parameters
81
+ # ---------------------------
82
+ [train.model]
83
+ model_name = "BiFormer" # model name (BiGraphFormer, BiFormer, BiGatedGraphFormer, BiGatedFormer, BiMamba, PredictionsFusion, BiFormerWithProb, BiMambaWithProb, BiGraphFormerWithProb, BiGatedGraphFormerWithProb)
84
+ hidden_dim = 256 # hidden state size
85
+ hidden_dim_gated = 128 # hidden state size for the gated mechanisms
86
+ num_transformer_heads = 16 # number of attention heads in the transformer
87
+ num_graph_heads = 2 # number of heads in the graph mechanism
88
+ tr_layer_number = 5 # number of transformer layers
89
+ mamba_d_state = 16 # state size in Mamba
90
+ mamba_ker_size = 6 # kernel size in Mamba
91
+ mamba_layer_number = 5 # number of Mamba layers
92
+ positional_encoding = false # whether to use positional encoding
93
+ dropout = 0.15 # dropout between layers
94
+ out_features = 256 # size of the final features before classification
95
+ mode = 'mean' # feature aggregation mode (e.g. "mean", "max", etc.)
96
+
97
+ # ---------------------------
98
+ # Optimizer parameters
99
+ # ---------------------------
100
+ [train.optimizer]
101
+ optimizer = "adam" # optimizer type: "adam", "adamw", "lion", "sgd", "rmsprop"
102
+ lr = 1e-4 # initial learning rate
103
+ weight_decay = 0.0 # weight decay for regularization
104
+ momentum = 0.9 # momentum (used only with SGD)
105
+
106
+ # ---------------------------
107
+ # Scheduler parameters
108
+ # ---------------------------
109
+ [train.scheduler]
110
+ scheduler_type = "plateau" # scheduler type: "none", "plateau", "cosine", "onecycle", or a HuggingFace-style one ("huggingface_linear", "huggingface_cosine", "huggingface_cosine_with_restarts", etc.)
111
+ warmup_ratio = 0.1 # ratio of warmup iterations to the total number of steps (0.1 = 10%)
112
+
113
+ [embeddings]
114
+ # audio_model = "amiriparian/ExHuBERT" # Hugging Face model name for audio
115
+ audio_model = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim" # Hugging Face model name for audio
116
+ audio_classifier_checkpoint = "best_audio_model_2.pt"
117
+ text_classifier_checkpoint = "best_text_model.pth"
118
+ text_model = "jinaai/jina-embeddings-v3" # Hugging Face model name for text
119
+ audio_embedding_dim = 256 # audio embedding dimension
120
+ text_embedding_dim = 1024 # text embedding dimension
121
+ emb_normalize = false # whether to L2-normalize the vector
122
+ max_tokens = 95 # cap on the text length (in tokens) during tokenization
123
+ device = "cuda" # "cuda" or "cpu": where to load the model
124
+
125
+ # audio_pooling = "mean" # "mean", "cls", "max", "min", "last", "attention"
126
+ # text_pooling = "cls" # "mean", "cls", "max", "min", "last", "sum", "attention"
127
+
128
+ [textgen]
129
+ # model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" # deepseek-ai/deepseek-llm-1.3b-base or any other model
130
+ # model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" # deepseek-ai/deepseek-llm-1.3b-base or any other model
131
+ max_new_tokens = 50
132
+ temperature = 1.0
133
+ top_p = 0.95
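
A note on the [datasets.*] entries: csv_path and wav_dir are templates with {base_dir} and {split} placeholders. The expansion is handled by utils/config_loader.py, which is not rendered in this view, so the snippet below is only a sketch of the convention using the standard library's tomllib.

import tomllib  # Python 3.11+

with open("config.toml", "rb") as f:
    cfg = tomllib.load(f)

meld = cfg["datasets"]["meld"]
split = "train"  # "train", "dev" or "test"
csv_path = meld["csv_path"].format(base_dir=meld["base_dir"], split=split)
wav_dir = meld["wav_dir"].format(base_dir=meld["base_dir"], split=split)
print(csv_path)  # E:/MELD/meld_train_labels.csv
print(wav_dir)   # E:/MELD/wavs/train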
data_loading/__pycache__/feature_extractor.cpython-310.pyc ADDED
Binary file (10.9 kB). View file
 
data_loading/__pycache__/pretrained_extractors.cpython-310.pyc ADDED
Binary file (9.4 kB). View file
 
data_loading/dataset_multimodal.py ADDED
@@ -0,0 +1,898 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import os
4
+ import random
5
+ import logging
6
+ import torch
7
+ import torchaudio
8
+ import whisper
9
+ import numpy as np
10
+ import pandas as pd
11
+ from torch.utils.data import Dataset
12
+ import pickle
13
+ from tqdm import tqdm
14
+ # from data_loading.feature_extractor import PretrainedAudioEmbeddingExtractor, PretrainedTextEmbeddingExtractor
15
+
16
+ class DatasetMultiModalWithPretrainedExtractors(Dataset):
17
+ """
18
+ Multimodal dataset for audio, text, and emotions (on-the-fly version).
19
+
20
+ On every __getitem__ call:
21
+ - Loads the WAV for video_name from the CSV.
22
+ - For the training split (split="train"):
23
+ if the audio is shorter than target_samples, checks whether this file was selected for merging
24
+ (based on merge_probability). If so, a "chain merge" is performed:
25
+ one or more extra files of the same class are picked, even if a single candidate is longer,
26
+ and the resulting audio is then trimmed to the exact length.
27
+ - If the resulting audio is still shorter than target_samples, it is zero-padded.
28
+ - The text is chosen as follows:
29
+ • if the audio was merged, Whisper is called to produce a new transcript;
30
+ • if no merge happened and the CSV text is non-empty, the CSV text is used;
31
+ • if the CSV text is empty, Whisper is called for train (and, optionally, for dev/test).
32
+ - Returns a dict { "audio": waveform, "label": label_vector, "text": text_final }.
33
+ """
34
+
35
+ def __init__(
36
+ self,
37
+ csv_path,
38
+ wav_dir,
39
+ emotion_columns,
40
+ config,
41
+ split,
42
+ audio_feature_extractor,
43
+ text_feature_extractor,
44
+ whisper_model,
45
+ dataset_name
46
+ ):
47
+ """
48
+ :param csv_path: Path to the CSV file (with columns video_name, emotion_columns, optionally text).
49
+ :param wav_dir: Directory with the audio files (file name: video_name.wav).
50
+ :param emotion_columns: List of emotion columns, e.g. ["neutral", "happy", "sad", ...].
51
+ :param split: "train", "dev", or "test".
52
+ :param audio_feature_extractor: Audio feature extractor.
53
+ :param text_feature_extractor: Text feature extractor.
54
+ :param sample_rate: Target sample rate (e.g. 16000).
55
+ :param wav_length: Target audio length in seconds.
56
+ :param whisper_model: Whisper model ("tiny", "base", "small", ...).
57
+ :param max_text_tokens: (Unused) cap on the number of tokens.
58
+ :param text_column: Name of the text column in the CSV.
59
+ :param use_whisper_for_nontrain_if_no_text: If True, Whisper is called for dev/test when the CSV has no text.
60
+ :param whisper_device: "cuda" or "cpu", device for the Whisper model.
61
+ :param subset_size: If > 0, only the first N CSV rows are used (for debugging).
62
+ :param merge_probability: Fraction (0..1) of all files that will be merged if they are too short.
63
+ :param dataset_name: Name of the corpus.
64
+ """
65
+ super().__init__()
66
+ self.split = split
67
+ self.sample_rate = config.sample_rate
68
+ self.target_samples = int(config.wav_length * self.sample_rate)
69
+ self.emotion_columns = emotion_columns
70
+ self.whisper_model = whisper_model
71
+ self.text_column = config.text_column
72
+ self.use_whisper_for_nontrain_if_no_text = config.use_whisper_for_nontrain_if_no_text
73
+ self.whisper_device = config.whisper_device
74
+ self.merge_probability = config.merge_probability
75
+ self.audio_feature_extractor = audio_feature_extractor
76
+ self.text_feature_extractor = text_feature_extractor
77
+ self.subset_size = config.subset_size
78
+ self.save_prepared_data = config.save_prepared_data
79
+ self.seed = config.random_seed
80
+ self.dataset_name = dataset_name
81
+ self.save_feature_path = config.save_feature_path
82
+ self.use_synthetic_data = config.use_synthetic_data
83
+ self.synthetic_path = config.synthetic_path
84
+ self.synthetic_ratio = config.synthetic_ratio
85
+
86
+ # Загружаем CSV
87
+ if not os.path.exists(csv_path):
88
+ raise ValueError(f"Ошибка: файл CSV не найден: {csv_path}")
89
+ df = pd.read_csv(csv_path)
90
+ if self.subset_size > 0:
91
+ df = df.head(self.subset_size)
92
+ logging.info(f"[DatasetMultiModal] Используем только первые {len(df)} записей (subset_size={self.subset_size}).")
93
+
94
+ #копия для сохранения текста Wisper
95
+ self.original_df = df.copy()
96
+ self.whisper_csv_update_log = []
97
+
98
+ # Проверяем наличие всех колонок эмоций
99
+ missing = [c for c in emotion_columns if c not in df.columns]
100
+ if missing:
101
+ raise ValueError(f"В CSV отсутствуют необходимые колонки эмоций: {missing}")
102
+
103
+ # Проверяем существование папки с аудио
104
+ if not os.path.exists(wav_dir):
105
+ raise ValueError(f"Ошибка: директория с аудио {wav_dir} не существует!")
106
+ self.wav_dir = wav_dir
107
+
108
+ # Собираем список строк: для каждой записи получаем путь к аудио, label и CSV-текст (если есть)
109
+ self.rows = []
110
+ for i, rowi in df.iterrows():
111
+ audio_path = os.path.join(wav_dir, f"{rowi['video_name']}.wav")
112
+ if not os.path.exists(audio_path):
113
+ continue
114
+ # Определяем доминирующую эмоцию (максимальное значение)
115
+ # print(self.emotion_columns)
116
+ emotion_values = rowi[self.emotion_columns].values.astype(float)
117
+ max_idx = np.argmax(emotion_values)
118
+ emotion_label = self.emotion_columns[max_idx]
119
+
120
+ # Извлекаем текст из CSV (если есть)
121
+ csv_text = ""
122
+ if self.text_column in rowi and isinstance(rowi[self.text_column], str):
123
+ csv_text = rowi[self.text_column]
124
+
125
+ self.rows.append({
126
+ "audio_path": audio_path,
127
+ "label": emotion_label,
128
+ "csv_text": csv_text
129
+ })
130
+
131
+ if self.use_synthetic_data and self.split == "train" and self.dataset_name.lower() == "meld":
132
+ logging.info(f"🧪 Включена синтетика для датасета '{self.dataset_name}' — добавляем примеры из: {self.synthetic_path}")
133
+ self._add_synthetic_data(self.synthetic_ratio)
134
+
135
+ # Создаем карту для поиска файлов по эмоции
136
+ self.audio_class_map = {entry["audio_path"]: entry["label"] for entry in self.rows}
137
+
138
+ logging.info("📊 Анализ распределения файлов по эмоциям:")
139
+ emotion_counts = {emotion: 0 for emotion in set(self.audio_class_map.values())}
140
+ for path, emotion in self.audio_class_map.items():
141
+ emotion_counts[emotion] += 1
142
+ for emotion, count in emotion_counts.items():
143
+ logging.info(f"🎭 Эмоция '{emotion}': {count} файлов.")
144
+
145
+ logging.info(f"[DatasetMultiModal] Сплит={split}, всего строк: {len(self.rows)}")
146
+
147
+ # === Процентное семплирование ===
148
+ total_files = len(self.rows)
149
+ num_to_merge = int(total_files * self.merge_probability)
150
+
151
+ # <<< NEW: Кешируем длины (eq_len) для всех файлов >>>
152
+ self.path_info = {}
153
+ for row in self.rows:
154
+ p = row["audio_path"]
155
+ try:
156
+ info = torchaudio.info(p)
157
+ length = info.num_frames
158
+ sr_ = info.sample_rate
159
+ # переводим длину в "эквивалент self.sample_rate"
160
+ if sr_ != self.sample_rate:
161
+ ratio = sr_ / self.sample_rate
162
+ eq_len = int(length / ratio)
163
+ else:
164
+ eq_len = length
165
+ self.path_info[p] = eq_len
166
+ except Exception as e:
167
+ logging.warning(f"⚠️ Ошибка чтения {p}: {e}")
168
+ self.path_info[p] = 0 # Если не смогли прочитать, ставим 0
169
+
170
+ # Определим, какие файлы "короткие" (могут нуждаться в склейке) - используем кэш вместо старого _is_too_short
171
+ self.mergable_files = [
172
+ row["audio_path"] # вместо целого dict берём строку
173
+ for row in self.rows
174
+ if self._is_too_short_cached(row["audio_path"]) # <<< теперь тут используем новую функцию
175
+ ]
176
+ short_count = len(self.mergable_files)
177
+
178
+ # Если коротких файлов больше нужного числа, выберем случайные. Иначе все короткие.
179
+ if short_count > num_to_merge:
180
+ self.files_to_merge = set(random.sample(self.mergable_files, num_to_merge))
181
+ else:
182
+ self.files_to_merge = set(self.mergable_files)
183
+
184
+ logging.info(f"🔗 Всего файлов: {total_files}, нужно склеить: {num_to_merge} ({self.merge_probability*100:.0f}%)")
185
+ logging.info(f"🔗 Коротких файлов: {short_count}, выбрано для склейки: {len(self.files_to_merge)}")
186
+
187
+ if self.save_prepared_data:
188
+ self.meta = []
189
+
190
+ if self.use_synthetic_data:
191
+ meta_filename = '{}_{}_seed_{}_subset_size_{}_audio_model_{}_feature_norm_{}_synthetic_true_pct_{}_pred.pickle'.format(
192
+ self.dataset_name,
193
+ self.split,
194
+ config.audio_classifier_checkpoint[-4:-3],
195
+ self.seed,
196
+ self.subset_size,
197
+ config.emb_normalize,
198
+ int(self.synthetic_ratio * 100)
199
+ )
200
+
201
+ else:
202
+ meta_filename = '{}_{}_seed_{}_subset_size_{}_audio_model_{}_feature_norm_{}_merge_prob_{}_pred.pickle'.format(
203
+ self.dataset_name,
204
+ self.split,
205
+ config.audio_classifier_checkpoint[-4:-3],
206
+ self.seed,
207
+ self.subset_size,
208
+ config.emb_normalize,
209
+ self.merge_probability
210
+ )
211
+
212
+ pickle_path = os.path.join(self.save_feature_path, meta_filename)
213
+ self.load_data(pickle_path)
214
+
215
+ if not self.meta:
216
+ self.prepare_data()
217
+ os.makedirs(self.save_feature_path, exist_ok=True)
218
+ self.save_data(pickle_path)
219
+
220
+ def save_data(self, filename):
221
+ with open(filename, 'wb') as handle:
222
+ pickle.dump(self.meta, handle, protocol=pickle.HIGHEST_PROTOCOL)
223
+
224
+ def load_data(self, filename):
225
+ if os.path.exists(filename):
226
+ with open(filename, 'rb') as handle:
227
+ self.meta = pickle.load(handle)
228
+ else:
229
+ self.meta = []
230
+
231
+ def _is_too_short(self, audio_path):
232
+ """
233
+ (Original) Checks whether the file is shorter than target_samples.
234
+ Uses torchaudio.info(audio_path).
235
+ No longer used, since the lengths are now cached.
236
+ """
237
+ try:
238
+ info = torchaudio.info(audio_path)
239
+ length = info.num_frames
240
+ sr_ = info.sample_rate
241
+ # переводим длину в "эквивалент self.sample_rate"
242
+ if sr_ != self.sample_rate:
243
+ ratio = sr_ / self.sample_rate
244
+ eq_len = int(length / ratio)
245
+ else:
246
+ eq_len = length
247
+ return eq_len < self.target_samples
248
+ except Exception as e:
249
+ logging.warning(f"Ошибка _is_too_short({audio_path}): {e}")
250
+ return False
251
+
252
+ def _is_too_short_cached(self, audio_path):
253
+ """
254
+ (New) Checks whether the file is shorter than target_samples, using the length cached in self.path_info.
255
+ """
256
+ eq_len = self.path_info.get(audio_path, 0)
257
+ return eq_len < self.target_samples
258
+
259
+ def __len__(self):
260
+ if self.save_prepared_data:
261
+ return len(self.meta)
262
+ else:
263
+ return len(self.rows)
264
+
265
+ def get_data(self, row):
266
+ audio_path = row["audio_path"]
267
+ label_name = row["label"]
268
+ csv_text = row["csv_text"]
269
+
270
+ # Преобразуем label в one-hot вектор
271
+ label_vec = self.emotion_to_vector(label_name)
272
+
273
+ # Шаг 1. Загружаем аудио
274
+ waveform, sr = self.load_audio(audio_path)
275
+ if waveform is None:
276
+ return None
277
+
278
+ orig_len = waveform.shape[1]
279
+ logging.debug(f"Исходная длина {os.path.basename(audio_path)}: {orig_len/sr:.2f} сек")
280
+
281
+ was_merged = False
282
+ merged_texts = [csv_text] # Тексты исходного файла + добавленных
283
+
284
+ # Шаг 2. Для train, если аудио короче target_samples, проверяем:
285
+ # попал ли данный row в files_to_merge?
286
+ if self.split == "train" and row["audio_path"] in self.files_to_merge:
287
+ # chain merge
288
+ current_length = orig_len
289
+ used_candidates = set()
290
+
291
+ while current_length < self.target_samples:
292
+ needed = self.target_samples - current_length
293
+ candidate = self.get_suitable_audio(label_name, exclude_path=audio_path, min_needed=needed, top_k=10)
294
+ if candidate is None or candidate in used_candidates:
295
+ break
296
+ used_candidates.add(candidate)
297
+ add_wf, add_sr = self.load_audio(candidate)
298
+ if add_wf is None:
299
+ break
300
+ logging.debug(f"Склейка: добавляем {os.path.basename(candidate)} (необходимых сэмплов: {needed})")
301
+ waveform = torch.cat((waveform, add_wf), dim=1)
302
+ current_length = waveform.shape[1]
303
+ was_merged = True
304
+
305
+ # Получаем текст второго файла (если есть в CSV)
306
+ add_csv_text = next((r["csv_text"] for r in self.rows if r["audio_path"] == candidate), "")
307
+ merged_texts.append(add_csv_text)
308
+
309
+ logging.debug(f"📜 Текст первого файла: {csv_text}")
310
+ logging.debug(f"📜 Текст добавленного файла: {add_csv_text}")
311
+ else:
312
+ # Если файл не в списке "должны склеить" или сплит не train, пропускаем chain-merge
313
+ logging.debug("Файл не выбран для склейки (или не train), пропускаем chain merge.")
314
+
315
+ if was_merged:
316
+ logging.debug("📝 Текст: аудио было merged – вызываем Whisper.")
317
+ text_final = self.run_whisper(waveform)
318
+ logging.debug(f"🆕 Whisper предсказал: {text_final}")
319
+
320
+ merge_components = [os.path.splitext(os.path.basename(audio_path))[0]]
321
+ merge_components += [os.path.splitext(os.path.basename(p))[0] for p in used_candidates]
322
+
323
+ self.whisper_csv_update_log.append({
324
+ "video_name": os.path.splitext(os.path.basename(audio_path))[0],
325
+ "text_new": text_final,
326
+ "text_old": csv_text,
327
+ "was_merged": True,
328
+ "merge_components": merge_components
329
+ })
330
+
331
+ else:
332
+ if csv_text.strip():
333
+ logging.debug("Текст: используем CSV-текст (не пуст).")
334
+ text_final = csv_text
335
+ else:
336
+ if self.split == "train" or self.use_whisper_for_nontrain_if_no_text:
337
+ logging.debug("Текст: CSV пустой – вызываем Whisper.")
338
+ text_final = self.run_whisper(waveform)
339
+ else:
340
+ logging.debug("Текст: CSV пустой и не вызываем Whisper для dev/test.")
341
+ text_final = ""
342
+
343
+ audio_pred, audion_emb = self.audio_feature_extractor.extract(waveform[0], self.sample_rate)
344
+ text_pred, text_emb = self.text_feature_extractor.extract(text_final)
345
+
346
+ return {
347
+ "audio_path": os.path.basename(audio_path),
348
+ "audio": audion_emb[0],
349
+ "label": label_vec,
350
+ "text": text_emb[0],
351
+ "audio_pred": audio_pred[0],
352
+ "text_pred": text_pred[0]
353
+ }
354
+
355
+ def prepare_data(self):
356
+ """
357
+ Loads and processes each dataset item,
358
+ saving the embeddings and the updated text (if the audio was merged).
359
+ """
360
+ for idx, row in enumerate(tqdm(self.rows)):
361
+ curr_dict = self.get_data(row)
362
+ if curr_dict is not None:
363
+ self.meta.append(curr_dict)
364
+
365
+ # === Сохраняем CSV с обновлёнными текстами (только если был merge) ===
366
+ if self.whisper_csv_update_log:
367
+ df_log = pd.DataFrame(self.whisper_csv_update_log)
368
+
369
+ # Копия исходного CSV
370
+ df_out = self.original_df.copy()
371
+
372
+ # Мержим по video_name
373
+ df_out = df_out.merge(df_log, on="video_name", how="left")
374
+
375
+ # Обновляем текст: заменяем только если Whisper сгенерировал
376
+ df_out["text_final"] = df_out["text_new"].combine_first(df_out["text"])
377
+ df_out["text_old"] = df_out["text"]
378
+ df_out["text"] = df_out["text_final"]
379
+ df_out["was_merged"] = df_out["was_merged"].fillna(False).astype(bool)
380
+
381
+ # Преобразуем merge_components в строку
382
+ df_out["merge_components"] = df_out["merge_components"].apply(
383
+ lambda x: ";".join(x) if isinstance(x, list) else ""
384
+ )
385
+
386
+ # Чистим временные колонки
387
+ df_out = df_out.drop(columns=["text_new", "text_final"])
388
+
389
+ # Сохраняем как CSV
390
+ output_path = os.path.join(self.save_feature_path, f"{self.dataset_name}_{self.split}_merged_whisper_{self.merge_probability *100}.csv")
391
+ os.makedirs(self.save_feature_path, exist_ok=True)
392
+ df_out.to_csv(output_path, index=False, encoding="utf-8")
393
+ logging.info(f"📄 Обновлённый merged CSV сохранён: {output_path}")
394
+
395
+ def __getitem__(self, index):
396
+ if self.save_prepared_data:
397
+ return self.meta[index]
398
+ else:
399
+ return self.get_data(self.rows[index])
400
+
401
+ def load_audio(self, path):
402
+ """
403
+ Loads the audio at the given path and resamples it to self.sample_rate if necessary.
404
+ """
405
+ if not os.path.exists(path):
406
+ logging.warning(f"Файл отсутствует: {path}")
407
+ return None, None
408
+ try:
409
+ wf, sr = torchaudio.load(path)
410
+ if sr != self.sample_rate:
411
+ resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
412
+ wf = resampler(wf)
413
+ sr = self.sample_rate
414
+ return wf, sr
415
+ except Exception as e:
416
+ logging.error(f"Ошибка загрузки {path}: {e}")
417
+ return None, None
418
+
419
+ def get_suitable_audio(self, label_name, exclude_path, min_needed, top_k=5):
420
+ """
421
+ Looks for an audio file with the same emotion.
422
+ 1) If there are files >= min_needed, pick one of them at random.
423
+ 2) Otherwise, take the top-K longest files and pick one of those at random.
424
+ """
425
+
426
+ candidates = [p for p, lbl in self.audio_class_map.items()
427
+ if lbl == label_name and p != exclude_path]
428
+ logging.debug(f"🔍 Найдено {len(candidates)} кандидатов для эмоции '{label_name}'")
429
+
430
+ # Сохраним: (eq_len, path) для всех кандидатов, но БЕЗ повторного чтения torchaudio.info
431
+ all_info = []
432
+ for path in candidates:
433
+ # <<< NEW: вместо info = torchaudio.info(path) ...
434
+ eq_len = self.path_info.get(path, 0) # Получаем из кэша
435
+ all_info.append((eq_len, path))
436
+
437
+ valid = [(l, p) for l, p in all_info if l >= min_needed]
438
+ logging.debug(f"✅ Подходящих (>= {min_needed}): {len(valid)} (из {len(all_info)})")
439
+
440
+ if valid:
441
+ # Если есть идеальные — берём случайно из них
442
+ random.shuffle(valid)
443
+ chosen = random.choice(valid)[1]
444
+ return chosen
445
+ else:
446
+ # 2) Если идеальных нет — берём топ-K по длине
447
+ sorted_by_len = sorted(all_info, key=lambda x: x[0], reverse=True)
448
+ top_k_list = sorted_by_len[:top_k]
449
+ if not top_k_list:
450
+ logging.debug("Нет доступных кандидатов вообще.")
451
+ return None # вообще нет кандидатов
452
+
453
+ random.shuffle(top_k_list)
454
+ chosen = top_k_list[0][1]
455
+ logging.info(f"Из топ-{top_k} выбран кандидат: {chosen}")
456
+ return chosen
457
+
458
+ def run_whisper(self, waveform):
459
+ """
460
+ Runs Whisper on the audio signal and returns the full transcript (no word-count limit).
461
+ """
462
+ arr = waveform.squeeze().cpu().numpy()
463
+ try:
464
+ with torch.no_grad():
465
+ result = self.whisper_model.transcribe(arr, fp16=False)
466
+ text = result["text"].strip()
467
+ return text
468
+ except Exception as e:
469
+ logging.error(f"Whisper ошибка: {e}")
470
+ return ""
471
+
472
+ def _add_synthetic_data(self, synthetic_ratio):
473
+ """
474
+ Adds synthetic_ratio (0..1) of the available synthetic files for each emotion.
475
+ """
476
+ if not self.synthetic_path:
477
+ logging.warning("⚠ Путь к синтетическим данным не указан.")
478
+ return
479
+
480
+ random.seed(self.seed)
481
+
482
+ synth_csv_path = os.path.join(self.synthetic_path, "meld_s_train_labels.csv")
483
+ synth_wav_dir = os.path.join(self.synthetic_path, "wavs")
484
+
485
+ if not (os.path.exists(synth_csv_path) and os.path.exists(synth_wav_dir)):
486
+ logging.warning("⚠ Синтетические данные не найдены.")
487
+ return
488
+
489
+ df_synth = pd.read_csv(synth_csv_path)
490
+ rows_by_label = {emotion: [] for emotion in self.emotion_columns}
491
+
492
+ for _, row in df_synth.iterrows():
493
+ audio_path = os.path.join(synth_wav_dir, f"{row['video_name']}.wav")
494
+ if not os.path.exists(audio_path):
495
+ continue
496
+ emotion_values = row[self.emotion_columns].values.astype(float)
497
+ max_idx = np.argmax(emotion_values)
498
+ label = self.emotion_columns[max_idx]
499
+ csv_text = row[self.text_column] if self.text_column in row and isinstance(row[self.text_column], str) else ""
500
+ rows_by_label[label].append({
501
+ "audio_path": audio_path,
502
+ "label": label,
503
+ "csv_text": csv_text
504
+ })
505
+
506
+ added = 0
507
+ for label in self.emotion_columns:
508
+ candidates = rows_by_label[label]
509
+ if not candidates:
510
+ continue
511
+ count_synth = int(len(candidates) * synthetic_ratio)
512
+ if count_synth <= 0:
513
+ continue
514
+ selected = random.sample(candidates, count_synth)
515
+ self.rows.extend(selected)
516
+ added += len(selected)
517
+ logging.info(f"➕ Добавлено {len(selected)} синтетических примеров для эмоции '{label}'")
518
+
519
+ logging.info(f"📦 Всего добавлено {added} синтетических примеров из MELD_S")
520
+
521
+ def emotion_to_vector(self, label_name):
522
+ """
523
+ Converts an emotion name into a one-hot vector (torch.tensor).
524
+ """
525
+ v = np.zeros(len(self.emotion_columns), dtype=np.float32)
526
+ if label_name in self.emotion_columns:
527
+ idx = self.emotion_columns.index(label_name)
528
+ v[idx] = 1.0
529
+ return torch.tensor(v, dtype=torch.float32)
530
+
531
+ class DatasetMultiModal(Dataset):
532
+ """
533
+ Multimodal dataset for audio, text, and emotions (on-the-fly version).
534
+
535
+ On every __getitem__ call:
536
+ - Loads the WAV for video_name from the CSV.
537
+ - For the training split (split="train"):
538
+ if the audio is shorter than target_samples, checks whether this file was selected for merging
539
+ (based on merge_probability). If so, a "chain merge" is performed:
540
+ one or more extra files of the same class are picked, even if a single candidate is longer,
541
+ and the resulting audio is then trimmed to the exact length.
542
+ - If the resulting audio is still shorter than target_samples, it is zero-padded.
543
+ - The text is chosen as follows:
544
+ • if the audio was merged, Whisper is called to produce a new transcript;
545
+ • if no merge happened and the CSV text is non-empty, the CSV text is used;
546
+ • if the CSV text is empty, Whisper is called for train (and, optionally, for dev/test).
547
+ - Returns a dict { "audio": waveform, "label": label_vector, "text": text_final }.
548
+ """
549
+
550
+ def __init__(
551
+ self,
552
+ csv_path,
553
+ wav_dir,
554
+ emotion_columns,
555
+ split="train",
556
+ sample_rate=16000,
557
+ wav_length=4,
558
+ whisper_model="tiny",
559
+ text_column="text",
560
+ use_whisper_for_nontrain_if_no_text=True,
561
+ whisper_device="cuda",
562
+ subset_size=0,
563
+ merge_probability=1.0 # <-- New parameter: fraction of the TOTAL number of files
564
+ ):
565
+ """
566
+ :param csv_path: Path to the CSV file (with columns video_name, emotion_columns, optionally text).
567
+ :param wav_dir: Directory with the audio files (file name: video_name.wav).
568
+ :param emotion_columns: List of emotion columns, e.g. ["neutral", "happy", "sad", ...].
569
+ :param split: "train", "dev", or "test".
570
+ :param sample_rate: Target sample rate (e.g. 16000).
571
+ :param wav_length: Target audio length in seconds.
572
+ :param whisper_model: Name of the Whisper model ("tiny", "base", "small", ...).
573
+ :param max_text_tokens: (Unused) cap on the number of tokens.
574
+ :param text_column: Name of the text column in the CSV.
575
+ :param use_whisper_for_nontrain_if_no_text: If True, Whisper is called for dev/test when the CSV has no text.
576
+ :param whisper_device: "cuda" or "cpu", device for the Whisper model.
577
+ :param subset_size: If > 0, only the first N CSV rows are used (for debugging).
578
+ :param merge_probability: Fraction (0..1) of all files that will be merged if they are too short.
579
+ """
580
+ super().__init__()
581
+ self.split = split
582
+ self.sample_rate = sample_rate
583
+ self.target_samples = int(wav_length * sample_rate)
584
+ self.emotion_columns = emotion_columns
585
+ self.whisper_model_name = whisper_model
586
+ self.text_column = text_column
587
+ self.use_whisper_for_nontrain_if_no_text = use_whisper_for_nontrain_if_no_text
588
+ self.whisper_device = whisper_device
589
+ self.merge_probability = merge_probability
590
+
591
+ # Загружаем CSV
592
+ if not os.path.exists(csv_path):
593
+ raise ValueError(f"Ошибка: файл CSV не найден: {csv_path}")
594
+ df = pd.read_csv(csv_path)
595
+ if subset_size > 0:
596
+ df = df.head(subset_size)
597
+ logging.info(f"[DatasetMultiModal] Используем только первые {len(df)} записей (subset_size={subset_size}).")
598
+
599
+ # Проверяем наличие всех колонок эмоций
600
+ missing = [c for c in emotion_columns if c not in df.columns]
601
+ if missing:
602
+ raise ValueError(f"В CSV отсутствуют необходимые колонки эмоций: {missing}")
603
+
604
+ # Проверяем существование папки с аудио
605
+ if not os.path.exists(wav_dir):
606
+ raise ValueError(f"Ошибка: директория с аудио {wav_dir} не существует!")
607
+ self.wav_dir = wav_dir
608
+
609
+ # Собираем список строк: для каждой записи получаем путь к аудио, label и CSV-текст (если есть)
610
+ self.rows = []
611
+ for i, rowi in df.iterrows():
612
+ audio_path = os.path.join(wav_dir, f"{rowi['video_name']}.wav")
613
+ if not os.path.exists(audio_path):
614
+ continue
615
+ # Определяем доминирующую эмоцию (максимальное значение)
616
+ emotion_values = rowi[self.emotion_columns].values.astype(float)
617
+ max_idx = np.argmax(emotion_values)
618
+ emotion_label = self.emotion_columns[max_idx]
619
+
620
+ # Извлекаем текст из CSV (если есть)
621
+ csv_text = ""
622
+ if self.text_column in rowi and isinstance(rowi[self.text_column], str):
623
+ csv_text = rowi[self.text_column]
624
+
625
+ self.rows.append({
626
+ "audio_path": audio_path,
627
+ "label": emotion_label,
628
+ "csv_text": csv_text
629
+ })
630
+
631
+ # Создаем карту для поиска файлов по эмоции
632
+ self.audio_class_map = {entry["audio_path"]: entry["label"] for entry in self.rows}
633
+
634
+ logging.info("📊 Анализ распределения файлов по эмоциям:")
635
+ emotion_counts = {emotion: 0 for emotion in set(self.audio_class_map.values())}
636
+ for path, emotion in self.audio_class_map.items():
637
+ emotion_counts[emotion] += 1
638
+ for emotion, count in emotion_counts.items():
639
+ logging.info(f"🎭 Эмоция '{emotion}': {count} файлов.")
640
+
641
+ logging.info(f"[DatasetMultiModal] Сплит={split}, всего строк: {len(self.rows)}")
642
+
643
+ # === Процентное семплирование ===
644
+ total_files = len(self.rows)
645
+ num_to_merge = int(total_files * self.merge_probability)
646
+
647
+ # <<< NEW: Cache the lengths (eq_len) of all files >>>
648
+ self.path_info = {}
649
+ for row in self.rows:
650
+ p = row["audio_path"]
651
+ try:
652
+ info = torchaudio.info(p)
653
+ length = info.num_frames
654
+ sr_ = info.sample_rate
655
+ # переводим длину в "эквивалент self.sample_rate"
656
+ if sr_ != self.sample_rate:
657
+ ratio = sr_ / self.sample_rate
658
+ eq_len = int(length / ratio)
659
+ else:
660
+ eq_len = length
661
+ self.path_info[p] = eq_len
662
+ except Exception as e:
663
+ logging.warning(f"⚠️ Ошибка чтения {p}: {e}")
664
+ self.path_info[p] = 0 # Если не смогли прочитать, ставим 0
665
+
666
+ # Determine which files are "short" (may need merging) - use the cache instead of the old _is_too_short
667
+ self.mergable_files = [
668
+ row["audio_path"] # вместо целого dict берём строку
669
+ for row in self.rows
670
+ if self._is_too_short_cached(row["audio_path"]) # <<< теперь тут используем новую функцию
671
+ ]
672
+ short_count = len(self.mergable_files)
673
+
674
+ # Если коротких файлов больше нужного числа, выберем случайные. Иначе все короткие.
675
+ if short_count > num_to_merge:
676
+ self.files_to_merge = set(random.sample(self.mergable_files, num_to_merge))
677
+ else:
678
+ self.files_to_merge = set(self.mergable_files)
679
+
680
+ logging.info(f"🔗 Всего файлов: {total_files}, нужно склеить: {num_to_merge} ({self.merge_probability*100:.0f}%)")
681
+ logging.info(f"🔗 Коротких файлов: {short_count}, выбрано для склейки: {len(self.files_to_merge)}")
682
+
683
+ # Инициализируем Whisper-модель один раз
684
+ logging.info(f"Инициализация Whisper: модель={whisper_model}, устройство={whisper_device}")
685
+ self.whisper_model = whisper.load_model(whisper_model, device=whisper_device).eval()
686
+ # print(f"📦 Whisper работает на устройстве: {self.whisper_model.device}")
687
+
688
+ def _is_too_short(self, audio_path):
689
+ """
690
+ (Original) Checks whether the file is shorter than target_samples.
691
+ Uses torchaudio.info(audio_path).
692
+ No longer used, since lengths are now cached.
693
+ """
694
+ try:
695
+ info = torchaudio.info(audio_path)
696
+ length = info.num_frames
697
+ sr_ = info.sample_rate
698
+ # переводим длину в "эквивалент self.sample_rate"
699
+ if sr_ != self.sample_rate:
700
+ ratio = sr_ / self.sample_rate
701
+ eq_len = int(length / ratio)
702
+ else:
703
+ eq_len = length
704
+ return eq_len < self.target_samples
705
+ except Exception as e:
706
+ logging.warning(f"Ошибка _is_too_short({audio_path}): {e}")
707
+ return False
708
+
709
+ def _is_too_short_cached(self, audio_path):
710
+ """
711
+ (New) Checks whether the file is shorter than target_samples, using the length cached in self.path_info.
712
+ """
713
+ eq_len = self.path_info.get(audio_path, 0)
714
+ return eq_len < self.target_samples
715
+
716
+ def __len__(self):
717
+ return len(self.rows)
718
+
719
+ def __getitem__(self, index):
720
+ """
721
+ Loads and processes one dataset item (on-the-fly).
722
+ """
723
+ row = self.rows[index]
724
+ audio_path = row["audio_path"]
725
+ label_name = row["label"]
726
+ csv_text = row["csv_text"]
727
+
728
+ # Convert the label to a one-hot vector
729
+ label_vec = self.emotion_to_vector(label_name)
730
+
731
+ # Step 1. Load the audio
732
+ waveform, sr = self.load_audio(audio_path)
733
+ if waveform is None:
734
+ return None
735
+
736
+ orig_len = waveform.shape[1]
737
+ logging.debug(f"Исходная длина {os.path.basename(audio_path)}: {orig_len/sr:.2f} сек")
738
+
739
+ was_merged = False
740
+ merged_texts = [csv_text] # Тексты исходного файла + добавленных
741
+
742
+ # Step 2. For train, if the audio is shorter than target_samples, check:
743
+ # did this row end up in files_to_merge?
744
+ if self.split == "train" and row["audio_path"] in self.files_to_merge:
745
+ # chain merge
746
+ current_length = orig_len
747
+ used_candidates = set()
748
+
749
+ while current_length < self.target_samples:
750
+ needed = self.target_samples - current_length
751
+ candidate = self.get_suitable_audio(label_name, exclude_path=audio_path, min_needed=needed, top_k=10)
752
+ if candidate is None or candidate in used_candidates:
753
+ break
754
+ used_candidates.add(candidate)
755
+ add_wf, add_sr = self.load_audio(candidate)
756
+ if add_wf is None:
757
+ break
758
+ logging.debug(f"Склейка: добавляем {os.path.basename(candidate)} (необходимых сэмплов: {needed})")
759
+ waveform = torch.cat((waveform, add_wf), dim=1)
760
+ current_length = waveform.shape[1]
761
+ was_merged = True
762
+
763
+ # Получаем текст второго файла (если есть в CSV)
764
+ add_csv_text = next((r["csv_text"] for r in self.rows if r["audio_path"] == candidate), "")
765
+ merged_texts.append(add_csv_text)
766
+
767
+ logging.debug(f"📜 Текст первого файла: {csv_text}")
768
+ logging.debug(f"📜 Текст добавленного файла: {add_csv_text}")
769
+ else:
770
+ # Если файл не в списке "должны склеить" или сплит не train, пропускаем chain-merge
771
+ logging.debug("Файл не выбран для склейки (или не train), пропускаем chain merge.")
772
+
773
+ # Step 3. If the resulting length is below target_samples, zero-pad
774
+ curr_len = waveform.shape[1]
775
+ if curr_len < self.target_samples:
776
+ pad_size = self.target_samples - curr_len
777
+ logging.debug(f"Паддинг {os.path.basename(audio_path)}: +{pad_size} сэмплов")
778
+ waveform = torch.nn.functional.pad(waveform, (0, pad_size))
779
+
780
+ # Step 4. Trim the audio to target_samples (in case it came out longer)
781
+ waveform = waveform[:, :self.target_samples]
782
+ logging.debug(f"Финальная длина {os.path.basename(audio_path)}: {waveform.shape[1]/sr:.2f} сек; was_merged={was_merged}")
783
+
784
+ # Step 5. Obtain the text
785
+ if was_merged:
786
+ logging.debug("📝 Текст: аудио было merged – вызываем Whisper.")
787
+ text_final = self.run_whisper(waveform)
788
+ logging.debug(f"🆕 Whisper предсказал: {text_final}")
789
+ else:
790
+ if csv_text.strip():
791
+ logging.debug("Текст: используем CSV-текст (не пуст).")
792
+ text_final = csv_text
793
+ else:
794
+ if self.split == "train" or self.use_whisper_for_nontrain_if_no_text:
795
+ logging.debug("Текст: CSV пустой – вызываем Whisper.")
796
+ text_final = self.run_whisper(waveform)
797
+ else:
798
+ logging.debug("Текст: CSV пустой и не вызываем Whisper для dev/test.")
799
+ text_final = ""
800
+
801
+ return {
802
+ "audio_path": os.path.basename(audio_path), # new
803
+ "audio": waveform,
804
+ "label": label_vec,
805
+ "text": text_final
806
+ }
807
+
808
+ def load_audio(self, path):
809
+ """
810
+ Loads the audio at the given path and resamples it to self.sample_rate if necessary.
811
+ """
812
+ if not os.path.exists(path):
813
+ logging.warning(f"Файл отсутствует: {path}")
814
+ return None, None
815
+ try:
816
+ wf, sr = torchaudio.load(path)
817
+ if sr != self.sample_rate:
818
+ resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
819
+ wf = resampler(wf)
820
+ sr = self.sample_rate
821
+ return wf, sr
822
+ except Exception as e:
823
+ logging.error(f"Ошибка загрузки {path}: {e}")
824
+ return None, None
825
+
826
+ def get_suitable_audio(self, label_name, exclude_path, min_needed, top_k=5):
827
+ """
828
+ Looks for an audio file with the same emotion.
829
+ 1) If there are files >= min_needed, pick one of them at random.
830
+ 2) Otherwise, take the top-K longest files and pick one of those at random.
831
+ """
832
+
833
+ candidates = [p for p, lbl in self.audio_class_map.items()
834
+ if lbl == label_name and p != exclude_path]
835
+ logging.debug(f"🔍 Найдено {len(candidates)} кандидатов для эмоции '{label_name}'")
836
+
837
+ # Сохраним: (eq_len, path) для всех кандидатов, но БЕЗ повторного чтения torchaudio.info
838
+ all_info = []
839
+ for path in candidates:
840
+ # <<< NEW: вместо info = torchaudio.info(path) ...
841
+ eq_len = self.path_info.get(path, 0) # Получаем из кэша
842
+ all_info.append((eq_len, path))
843
+
844
+ # --- The old code below is kept for reference:
845
+ # for path in candidates:
846
+ # try:
847
+ # info = torchaudio.info(path)
848
+ # length = info.num_frames
849
+ # sr_ = info.sample_rate
850
+ # eq_len = int(length / (sr_ / self.sample_rate)) if sr_ != self.sample_rate else length
851
+ # all_info.append((eq_len, path))
852
+ # except Exception as e:
853
+ # logging.warning(f"⚠ Ошибка чтения {path}: {e}")
854
+
855
+ # 1) Фильтруем только >= min_needed
856
+ valid = [(l, p) for l, p in all_info if l >= min_needed]
857
+ logging.debug(f"✅ Подходящих (>= {min_needed}): {len(valid)} (из {len(all_info)})")
858
+
859
+ if valid:
860
+ # Если есть идеальные — берём случайно из них
861
+ random.shuffle(valid)
862
+ chosen = random.choice(valid)[1]
863
+ return chosen
864
+ else:
865
+ # 2) Если идеальных нет — берём топ-K по длине
866
+ sorted_by_len = sorted(all_info, key=lambda x: x[0], reverse=True)
867
+ top_k_list = sorted_by_len[:top_k]
868
+ if not top_k_list:
869
+ logging.debug("Нет доступных кандидатов вообще.")
870
+ return None # вообще нет кандидатов
871
+
872
+ random.shuffle(top_k_list)
873
+ chosen = top_k_list[0][1]
874
+ logging.info(f"Из топ-{top_k} выбран кандидат: {chosen}")
875
+ return chosen
876
+
877
+ def run_whisper(self, waveform):
878
+ """
879
+ Runs Whisper on the audio signal and returns the full text (no word-count limit).
880
+ """
881
+ arr = waveform.squeeze().cpu().numpy()
882
+ try:
883
+ result = self.whisper_model.transcribe(arr, fp16=False)
884
+ text = result["text"].strip()
885
+ return text
886
+ except Exception as e:
887
+ logging.error(f"Whisper ошибка: {e}")
888
+ return ""
889
+
890
+ def emotion_to_vector(self, label_name):
891
+ """
892
+ Converts an emotion name into a one-hot vector (torch.tensor).
893
+ """
894
+ v = np.zeros(len(self.emotion_columns), dtype=np.float32)
895
+ if label_name in self.emotion_columns:
896
+ idx = self.emotion_columns.index(label_name)
897
+ v[idx] = 1.0
898
+ return torch.tensor(v, dtype=torch.float32)
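A minimal usage sketch for the dataset above (assumptions: the class is named DatasetMultiModal and lives in data_loading/dataset_multimodal.py, as the log tags and file list suggest; the CSV path, wav directory and emotion columns are placeholders):

    from torch.utils.data import DataLoader
    from data_loading.dataset_multimodal import DatasetMultiModal  # assumed import path

    ds = DatasetMultiModal(
        csv_path="train.csv",        # placeholder
        wav_dir="wavs/",             # placeholder
        emotion_columns=["neutral", "happy", "sad", "anger", "surprise", "disgust", "fear"],
        split="train",
        sample_rate=16000,
        wav_length=4,
        whisper_model="tiny",
        merge_probability=0.5,       # merge roughly half of all files if they are too short
    )
    # __getitem__ can return None on load errors, so filter those out in the collate_fn.
    loader = DataLoader(ds, batch_size=8,
                        collate_fn=lambda batch: [x for x in batch if x is not None])
    items = next(iter(loader))
    # each item: {"audio_path": str, "audio": (1, 64000) tensor, "label": one-hot tensor, "text": str}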
data_loading/feature_extractor.py ADDED
@@ -0,0 +1,410 @@
1
+ # data_loading/feature_extractor.py
2
+
3
+ import torch
4
+ import logging
5
+ import numpy as np
6
+ import torch.nn.functional as F
7
+ from transformers import (
8
+ AutoFeatureExtractor,
9
+ AutoModel,
10
+ AutoTokenizer,
11
+ AutoModelForAudioClassification,
12
+ Wav2Vec2Processor
13
+ )
14
+ from data_loading.pretrained_extractors import EmotionModel, get_model_mamba, Mamba
15
+
16
+
17
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18
+ # DEVICE = torch.device('cpu')
19
+
20
+
21
+ class PretrainedAudioEmbeddingExtractor:
22
+ """
23
+ Extracts embeddings from audio using a model (e.g. 'amiriparian/ExHuBERT'),
24
+ with support for pooling, normalization, etc.
25
+ """
26
+
27
+ def __init__(self, config):
28
+ """
29
+ The config is expected to have the fields:
30
+ - audio_model_name (str) : model name (ExHuBERT, etc.)
31
+ - emb_device (str) : "cpu" or "cuda"
32
+ - audio_pooling (str | None) : "mean", "cls", "max", "min", "last" or None (skip pooling)
33
+ - emb_normalize (bool) : whether to L2-normalize the output
34
+ - max_audio_frames (int) : limit along the time axis (0 means no limit)
35
+ """
36
+ self.config = config
37
+ self.device = config.emb_device
38
+ self.model_name = config.audio_model_name
39
+ self.pooling = config.audio_pooling # может быть None
40
+ self.normalize_output = config.emb_normalize
41
+ self.max_audio_frames = getattr(config, "max_audio_frames", 0)
42
+ self.audio_classifier_checkpoint = config.audio_classifier_checkpoint
43
+
44
+ # Инициализируем processor и audio_embedder
45
+ self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
46
+ self.audio_embedder = EmotionModel.from_pretrained(self.model_name).to(self.device)
47
+
48
+ # Загружаем модель
49
+ self.classifier_model = self.load_classifier_model_from_checkpoint(self.audio_classifier_checkpoint)
50
+
51
+
52
+ def extract(self, waveform: torch.Tensor, sample_rate=16000):
53
+ """
54
+ Extracts embeddings from audio data.
55
+
56
+ :param waveform: Tensor of shape (T).
57
+ :param sample_rate: Sampling rate (int).
58
+ :return: Tensors:
59
+ (B, classes) and (B, sequence_length, hidden_dim) are returned.
60
+ """
61
+
62
+ embeddings = self.process_audio(waveform, sample_rate)
63
+ tensor_emb = torch.tensor(embeddings, dtype=torch.float32).to(self.device)
64
+ lengths = [tensor_emb.shape[1]]
65
+
66
+ with torch.no_grad():
67
+ logits, hidden = self.classifier_model(tensor_emb, lengths, with_features=True)
68
+
69
+ # Если pooling=None => вернём (B, seq_len, hidden_dim)
70
+ if hidden.dim() == 3:
71
+ if self.pooling is None:
72
+ emb = hidden
73
+ else:
74
+ if self.pooling == "mean":
75
+ emb = hidden.mean(dim=1)
76
+ elif self.pooling == "cls":
77
+ emb = hidden[:, 0, :]
78
+ elif self.pooling == "max":
79
+ emb, _ = hidden.max(dim=1)
80
+ elif self.pooling == "min":
81
+ emb, _ = hidden.min(dim=1)
82
+ elif self.pooling == "last":
83
+ emb = hidden[:, -1, :]
84
+ elif self.pooling == "sum":
85
+ emb = hidden.sum(dim=1)
86
+ else:
87
+ emb = hidden.mean(dim=1)
88
+ else:
89
+ # На всякий случай, если получилось (B, hidden_dim)
90
+ emb = hidden
91
+
92
+ if self.normalize_output and emb.dim() == 2:
93
+ emb = F.normalize(emb, p=2, dim=1)
94
+
95
+ return logits, emb
96
+
97
+ def process_audio(self, signal: np.ndarray, sampling_rate: int) -> np.ndarray:
98
+ inputs = self.processor(signal, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
99
+ input_values = inputs["input_values"].to(self.device)
100
+
101
+ with torch.no_grad():
102
+ outputs = self.audio_embedder(input_values)
103
+ embeddings = outputs
104
+
105
+ return embeddings.detach().cpu().numpy()
106
+
107
+ def load_classifier_model_from_checkpoint(self, checkpoint_path):
108
+ if checkpoint_path == "best_audio_model.pt":
109
+ checkpoint = torch.load(checkpoint_path, map_location=self.device)
110
+ exp_params = checkpoint['exp_params']
111
+ classifier_model = get_model_mamba(exp_params).to(self.device)
112
+ classifier_model.load_state_dict(checkpoint['model_state_dict'])
113
+ elif checkpoint_path == "best_audio_model_2.pt":
114
+ model_params = {
115
+ "input_size": 1024,
116
+ "d_model": 256,
117
+ "num_layers": 2,
118
+ "num_classes": 7,
119
+ "dropout": 0.2
120
+ }
121
+ classifier_model = get_model_mamba(model_params).to(self.device)
122
+ classifier_model.load_state_dict(torch.load(checkpoint_path, map_location=self.device))
123
+ classifier_model.eval()
124
+ return classifier_model
125
+
126
+ class AudioEmbeddingExtractor:
127
+ """
128
+ Extracts embeddings from audio using a model (e.g. 'amiriparian/ExHuBERT'),
129
+ with support for pooling, normalization, etc.
130
+ """
131
+
132
+ def __init__(self, config):
133
+ """
134
+ The config is expected to have the fields:
135
+ - audio_model_name (str) : model name (ExHuBERT, etc.)
136
+ - emb_device (str) : "cpu" or "cuda"
137
+ - audio_pooling (str | None) : "mean", "cls", "max", "min", "last" or None (skip pooling)
138
+ - emb_normalize (bool) : whether to L2-normalize the output
139
+ - max_audio_frames (int) : limit along the time axis (0 means no limit)
140
+ """
141
+ self.config = config
142
+ self.device = config.emb_device
143
+ self.model_name = config.audio_model_name
144
+ self.pooling = config.audio_pooling # может быть None
145
+ self.normalize_output = config.emb_normalize
146
+ # self.max_audio_frames = getattr(config, "max_audio_frames", 0)
147
+ self.max_audio_frames = config.sample_rate * config.wav_length
148
+
149
+
150
+ # Попробуем загрузить feature_extractor (не у всех моделей доступен)
151
+ try:
152
+ self.feature_extractor = AutoFeatureExtractor.from_pretrained(self.model_name)
153
+ logging.info(f"[Audio] Using AutoFeatureExtractor for '{self.model_name}'")
154
+ except Exception as e:
155
+ self.feature_extractor = None
156
+ logging.warning(f"[Audio] No built-in FeatureExtractor found. Model={self.model_name}. Error: {e}")
157
+
158
+ # Загружаем модель
159
+ # Если у модели нет head-классификации, бывает достаточно AutoModel
160
+ try:
161
+ self.model = AutoModel.from_pretrained(
162
+ self.model_name,
163
+ output_hidden_states=True # чтобы точно был last_hidden_state
164
+ ).to(self.device)
165
+ logging.info(f"[Audio] Loaded AutoModel with output_hidden_states=True: {self.model_name}")
166
+ except Exception as e:
167
+ logging.warning(f"[Audio] Fallback to AudioClassification model. Reason: {e}")
168
+ self.model = AutoModelForAudioClassification.from_pretrained(
169
+ self.model_name,
170
+ output_hidden_states=True
171
+ ).to(self.device)
172
+
173
+ def extract(self, waveform_batch: torch.Tensor, sample_rate=16000):
174
+ """
175
+ Extracts embeddings from audio data.
176
+
177
+ :param waveform_batch: Tensor of shape (B, T) or (B, 1, T).
178
+ :param sample_rate: Sampling rate (int).
179
+ :return: Tensor:
180
+ - if pooling != None, it will be (B, hidden_dim)
181
+ - if pooling == None and last_hidden_state had shape (B, seq_len, hidden_dim),
182
+ (B, seq_len, hidden_dim) is returned.
183
+ """
184
+
185
+ # Если пришло (B, 1, T), уберём ось "1"
186
+ if waveform_batch.dim() == 3 and waveform_batch.shape[1] == 1:
187
+ waveform_batch = waveform_batch.squeeze(1) # -> (B, T)
188
+
189
+ # Усечение по времени, если нужно
190
+ if self.max_audio_frames > 0 and waveform_batch.shape[1] > self.max_audio_frames:
191
+ waveform_batch = waveform_batch[:, :self.max_audio_frames]
192
+
193
+ # Если есть feature_extractor - используем
194
+ if self.feature_extractor is not None:
195
+ inputs = self.feature_extractor(
196
+ waveform_batch,
197
+ sampling_rate=sample_rate,
198
+ return_tensors="pt",
199
+ truncation=True,
200
+ max_length=self.max_audio_frames if self.max_audio_frames > 0 else None
201
+ )
202
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
203
+
204
+ outputs = self.model(input_values=inputs["input_values"])
205
+ else:
206
+ # Иначе подадим напрямую "input_values" на модель
207
+ inputs = {"input_values": waveform_batch.to(self.device)}
208
+ outputs = self.model(**inputs)
209
+
210
+ # Теперь outputs может быть BaseModelOutput (с last_hidden_state, hidden_states, etc.)
211
+ # Или SequenceClassifierOutput (с logits), если это модель-классификатор
212
+ if hasattr(outputs, "last_hidden_state"):
213
+ # (B, seq_len, hidden_dim)
214
+ hidden = outputs.last_hidden_state
215
+ # logging.debug(f"[Audio] last_hidden_state shape: {hidden.shape}")
216
+ elif hasattr(outputs, "logits"):
217
+ # logits: (B, num_labels)
218
+ # Для пуллинга по "seq_len" притворимся, что seq_len=1
219
+ hidden = outputs.logits.unsqueeze(1) # (B,1,num_labels)
220
+ logging.debug(f"[Audio] Found logits shape: {outputs.logits.shape} => hidden={hidden.shape}")
221
+ else:
222
+ # Модель может сразу возвращать тензор
223
+ hidden = outputs
224
+
225
+ # Если у нас 2D-тензор (B, hidden_dim), значит всё уже спулено
226
+ if hidden.dim() == 2:
227
+ emb = hidden
228
+ elif hidden.dim() == 3:
229
+ # (B, seq_len, hidden_dim)
230
+ if self.pooling is None:
231
+ # Возвращаем как есть
232
+ emb = hidden
233
+ else:
234
+ # Выполним пуллинг
235
+ if self.pooling == "mean":
236
+ emb = hidden.mean(dim=1)
237
+ elif self.pooling == "cls":
238
+ emb = hidden[:, 0, :] # [B, hidden_dim]
239
+ elif self.pooling == "max":
240
+ emb, _ = hidden.max(dim=1)
241
+ elif self.pooling == "min":
242
+ emb, _ = hidden.min(dim=1)
243
+ elif self.pooling == "last":
244
+ emb = hidden[:, -1, :]
245
+ else:
246
+ emb = hidden.mean(dim=1) # на всякий случай fallback
247
+ else:
248
+ # На всякий: если ещё какая-то форма
249
+ raise ValueError(f"[Audio] Unexpected hidden shape={hidden.shape}, pooling={self.pooling}")
250
+
251
+ if self.normalize_output and emb.dim() == 2:
252
+ emb = F.normalize(emb, p=2, dim=1)
253
+
254
+ return emb
255
+
256
+
257
+ class TextEmbeddingExtractor:
258
+ """
259
+ Extracts embeddings from text (e.g. 'jinaai/jina-embeddings-v3'),
260
+ with support for pooling (None, mean, cls, etc.), normalization and truncation.
261
+ """
262
+
263
+ def __init__(self, config):
264
+ """
265
+ Parameters in config:
266
+ - text_model_name (str)
267
+ - emb_device (str)
268
+ - text_pooling (str | None)
269
+ - emb_normalize (bool)
270
+ - max_tokens (int)
271
+ """
272
+ self.config = config
273
+ self.device = config.emb_device
274
+ self.model_name = config.text_model_name
275
+ self.pooling = config.text_pooling # может быть None
276
+ self.normalize_output = config.emb_normalize
277
+ self.max_tokens = config.max_tokens
278
+
279
+ # trust_remote_code=True нужно для моделей вроде jina
280
+ logging.info(f"[Text] Loading tokenizer for {self.model_name} with trust_remote_code=True")
281
+ self.tokenizer = AutoTokenizer.from_pretrained(
282
+ self.model_name,
283
+ trust_remote_code=True
284
+ )
285
+
286
+ logging.info(f"[Text] Loading model for {self.model_name} with trust_remote_code=True")
287
+ self.model = AutoModel.from_pretrained(
288
+ self.model_name,
289
+ trust_remote_code=True,
290
+ output_hidden_states=True, # хотим иметь last_hidden_state
291
+ force_download=False
292
+ ).to(self.device)
293
+
294
+ def extract(self, text_list):
295
+ """
296
+ :param text_list: list of strings (or a single string)
297
+ :return: tensor (B, hidden_dim), or (B, seq_len, hidden_dim) if pooling=None
298
+ """
299
+
300
+ if isinstance(text_list, str):
301
+ text_list = [text_list]
302
+
303
+ inputs = self.tokenizer(
304
+ text_list,
305
+ padding="max_length",
306
+ truncation=True,
307
+ max_length=self.max_tokens,
308
+ return_tensors="pt"
309
+ )
310
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
311
+
312
+ with torch.no_grad():
313
+ outputs = self.model(**inputs)
314
+ # Обычно у AutoModel last_hidden_state.shape = (B, seq_len, hidden_dim)
315
+ hidden = outputs.last_hidden_state
316
+ # logging.debug(f"[Text] last_hidden_state shape: {hidden.shape}")
317
+
318
+ # Если pooling=None => вернём (B, seq_len, hidden_dim)
319
+ if hidden.dim() == 3:
320
+ if self.pooling is None:
321
+ emb = hidden
322
+ else:
323
+ if self.pooling == "mean":
324
+ emb = hidden.mean(dim=1)
325
+ elif self.pooling == "cls":
326
+ emb = hidden[:, 0, :]
327
+ elif self.pooling == "max":
328
+ emb, _ = hidden.max(dim=1)
329
+ elif self.pooling == "min":
330
+ emb, _ = hidden.min(dim=1)
331
+ elif self.pooling == "last":
332
+ emb = hidden[:, -1, :]
333
+ elif self.pooling == "sum":
334
+ emb = hidden.sum(dim=1)
335
+ else:
336
+ emb = hidden.mean(dim=1)
337
+ else:
338
+ # На всякий случай, если получилось (B, hidden_dim)
339
+ emb = hidden
340
+
341
+ if self.normalize_output and emb.dim() == 2:
342
+ emb = F.normalize(emb, p=2, dim=1)
343
+
344
+ return emb
345
+
346
+ class PretrainedTextEmbeddingExtractor:
347
+ """
348
+ Extracts embeddings from text (e.g. 'jinaai/jina-embeddings-v3'),
349
+ with support for pooling (None, mean, cls, etc.), normalization and truncation.
350
+ """
351
+
352
+ def __init__(self, config):
353
+ """
354
+ Parameters in config:
355
+ - text_model_name (str)
356
+ - emb_device (str)
357
+ - text_pooling (str | None)
358
+ - emb_normalize (bool)
359
+ - max_tokens (int)
360
+ """
361
+ self.config = config
362
+ self.device = config.emb_device
363
+ self.model_name = config.text_model_name
364
+ self.pooling = config.text_pooling # может быть None
365
+ self.normalize_output = config.emb_normalize
366
+ self.max_tokens = config.max_tokens
367
+ self.text_classifier_checkpoint = config.text_classifier_checkpoint
368
+
369
+ self.model = Mamba(num_layers = 2, d_input = 1024, d_model = 512, num_classes=7, model_name=self.model_name, max_tokens=self.max_tokens, pooling=None).to(self.device)
370
+ checkpoint = torch.load(self.text_classifier_checkpoint, map_location=DEVICE)
371
+ self.model.load_state_dict(checkpoint['model_state_dict'])
372
+ self.model.eval()
373
+
374
+ def extract(self, text_list):
375
+ """
376
+ :param text_list: list of strings (or a single string)
377
+ :return: (logits, emb) – logits of shape (B, num_classes) and embeddings (B, hidden_dim), or (B, seq_len, hidden_dim) if pooling=None
378
+ """
379
+
380
+ if isinstance(text_list, str):
381
+ text_list = [text_list]
382
+
383
+ with torch.no_grad():
384
+ logits, hidden = self.model(text_list, with_features=True)
385
+
386
+ if hidden.dim() == 3:
387
+ if self.pooling is None:
388
+ emb = hidden
389
+ else:
390
+ if self.pooling == "mean":
391
+ emb = hidden.mean(dim=1)
392
+ elif self.pooling == "cls":
393
+ emb = hidden[:, 0, :]
394
+ elif self.pooling == "max":
395
+ emb, _ = hidden.max(dim=1)
396
+ elif self.pooling == "min":
397
+ emb, _ = hidden.min(dim=1)
398
+ elif self.pooling == "last":
399
+ emb = hidden[:, -1, :]
400
+ elif self.pooling == "sum":
401
+ emb = hidden.sum(dim=1)
402
+ else:
403
+ emb = hidden.mean(dim=1)
404
+ else:
405
+ # На всякий случай, если получилось (B, hidden_dim)
406
+ emb = hidden
407
+
408
+ if self.normalize_output and emb.dim() == 2:
409
+ emb = F.normalize(emb, p=2, dim=1)
410
+ return logits, emb
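The extractors above only read plain attributes from their config object, so a simple namespace is enough to exercise them. A rough sketch (in the project the config presumably comes from config.toml via utils/config_loader.py; the model names and sizes below are placeholders, not the repository's actual settings):

    from types import SimpleNamespace
    import torch

    cfg = SimpleNamespace(
        emb_device="cuda" if torch.cuda.is_available() else "cpu",
        emb_normalize=True,
        # audio side
        audio_model_name="facebook/wav2vec2-base",   # placeholder HF checkpoint
        audio_pooling="mean",
        sample_rate=16000,
        wav_length=4,
        # text side
        text_model_name="jinaai/jina-embeddings-v3",
        text_pooling="mean",
        max_tokens=95,
    )

    audio_ext = AudioEmbeddingExtractor(cfg)
    text_ext = TextEmbeddingExtractor(cfg)

    wav = torch.randn(2, 64000)                                     # (B, T), two 4-second clips
    audio_emb = audio_ext.extract(wav, sample_rate=16000)           # (2, hidden_dim) with "mean" pooling
    text_emb = text_ext.extract(["I am fine.", "Leave me alone!"])  # (2, hidden_dim)

The Pretrained* variants additionally expect audio_classifier_checkpoint / text_classifier_checkpoint fields pointing at the .pt / .pth checkpoints shipped in this commit.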
data_loading/pretrained_extractors.py ADDED
@@ -0,0 +1,221 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from transformers import AutoTokenizer, AutoModel
7
+ from transformers.models.wav2vec2.modeling_wav2vec2 import (
8
+ Wav2Vec2Model,
9
+ Wav2Vec2PreTrainedModel,
10
+ )
11
+ from torch.nn.functional import silu
12
+ from torch.nn.functional import softplus
13
+ from einops import rearrange, einsum
14
+ from torch import Tensor
15
+ from einops import rearrange
16
+
17
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18
+ # DEVICE = torch.device('cpu')
19
+
20
+ ## Audio models
21
+
22
+ class CustomMambaBlock(nn.Module):
23
+ def __init__(self, d_input, d_model, dropout=0.1):
24
+ super().__init__()
25
+ self.in_proj = nn.Linear(d_input, d_model)
26
+ self.s_B = nn.Linear(d_model, d_model)
27
+ self.s_C = nn.Linear(d_model, d_model)
28
+ self.out_proj = nn.Linear(d_model, d_input)
29
+ self.norm = nn.LayerNorm(d_input)
30
+ self.dropout = nn.Dropout(dropout)
31
+ self.activation = nn.ReLU()
32
+
33
+ def forward(self, x):
34
+ x_in = x # keep the input for the residual connection
35
+ x = self.in_proj(x)
36
+ B = self.s_B(x)
37
+ C = self.s_C(x)
38
+ x = x + B + C
39
+ x = self.activation(x)
40
+ x = self.out_proj(x)
41
+ x = self.dropout(x)
42
+ x = self.norm(x + x_in) # residual + norm
43
+ return x
44
+
45
+ class CustomMambaClassifier(nn.Module):
46
+ def __init__(self, input_size=1024, d_model=256, num_layers=2, num_classes=7, dropout=0.1):
47
+ super().__init__()
48
+ self.input_proj = nn.Linear(input_size, d_model)
49
+ self.blocks = nn.ModuleList([
50
+ CustomMambaBlock(d_model, d_model, dropout=dropout)
51
+ for _ in range(num_layers)
52
+ ])
53
+ self.fc = nn.Linear(d_model, num_classes)
54
+
55
+ def forward(self, x, lengths, with_features=False):
56
+ # x: (batch, seq_length, input_size)
57
+ x = self.input_proj(x)
58
+ for block in self.blocks:
59
+ x = block(x)
60
+ pooled = []
61
+ for i, l in enumerate(lengths):
62
+ if l > 0:
63
+ pooled.append(x[i, :l, :].mean(dim=0))
64
+ else:
65
+ pooled.append(torch.zeros(x.size(2), device=x.device))
66
+ pooled = torch.stack(pooled, dim=0)
67
+ if with_features:
68
+ return self.fc(pooled), x
69
+ else:
70
+ return self.fc(pooled)
71
+
72
+ def get_model_mamba(params):
73
+ return CustomMambaClassifier(
74
+ input_size=params.get("input_size", 1024),
75
+ d_model=params.get("d_model", 256),
76
+ num_layers=params.get("num_layers", 2),
77
+ num_classes=params.get("num_classes", 7),
78
+ dropout=params.get("dropout", 0.1)
79
+ )
80
+
81
+ class EmotionModel(Wav2Vec2PreTrainedModel):
82
+
83
+ def __init__(self, config):
84
+ super().__init__(config)
85
+ self.config = config
86
+ self.wav2vec2 = Wav2Vec2Model(config)
87
+ self.init_weights()
88
+
89
+ def forward(self, input_values):
90
+ outputs = self.wav2vec2(input_values)
91
+ hidden_states = outputs[0] # (batch_size, sequence_length, hidden_size)
92
+ return hidden_states
93
+
94
+ ## Text models
95
+
96
+ class Embedding():
97
+ def __init__(self, model_name='jinaai/jina-embeddings-v3', pooling=None):
98
+ self.model_name = model_name
99
+ self.pooling = pooling
100
+ self.device = DEVICE
101
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name, code_revision='da863dd04a4e5dce6814c6625adfba87b83838aa', trust_remote_code=True)
102
+ self.model = AutoModel.from_pretrained(model_name, code_revision='da863dd04a4e5dce6814c6625adfba87b83838aa', trust_remote_code=True).to(self.device)
103
+ self.model.eval()
104
+
105
+ def _mean_pooling(self, X):
106
+ def mean_pooling(model_output, attention_mask):
107
+ token_embeddings = model_output[0]
108
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
109
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
110
+ encoded_input = self.tokenizer(X, padding=True, truncation=True, return_tensors='pt').to(self.device)
111
+ with torch.no_grad():
112
+ model_output = self.model(**encoded_input)
113
+ sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
114
+ sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
115
+ return sentence_embeddings.unsqueeze(1)
116
+
117
+ def get_embeddings(self, X, max_len):
118
+ encoded_input = self.tokenizer(X, padding=True, truncation=True, return_tensors='pt').to(self.device)
119
+ with torch.no_grad():
120
+ features = self.model(**encoded_input)[0].detach().cpu().float().numpy()
121
+ res = np.pad(features[:, :max_len, :], ((0, 0), (0, max(0, max_len - features.shape[1])), (0, 0)), "constant")
122
+ return torch.tensor(res)
123
+
124
+ class RMSNorm(nn.Module):
125
+ def __init__(self, d_model: int, eps: float = 1e-8) -> None:
126
+ super().__init__()
127
+ self.eps = eps
128
+ self.weight = nn.Parameter(torch.ones(d_model))
129
+
130
+ def forward(self, x: Tensor) -> Tensor:
131
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim = True) + self.eps) * self.weight
132
+
133
+ class Mamba(nn.Module):
134
+ def __init__(self, num_layers, d_input, d_model, d_state=16, d_discr=None, ker_size=4, num_classes=7, max_tokens=95, model_name='jina', pooling=None):
135
+ super().__init__()
136
+ mamba_par = {
137
+ 'd_input' : d_input,
138
+ 'd_model' : d_model,
139
+ 'd_state' : d_state,
140
+ 'd_discr' : d_discr,
141
+ 'ker_size': ker_size
142
+ }
143
+ self.model_name = model_name
144
+ self.max_tokens = max_tokens
145
+ embed = Embedding(model_name, pooling)
146
+ self.embedding = embed.get_embeddings
147
+ self.layers = nn.ModuleList([nn.ModuleList([MambaBlock(**mamba_par), RMSNorm(d_input)]) for _ in range(num_layers)])
148
+ self.fc_out = nn.Linear(d_input, num_classes)
149
+ self.device = DEVICE
150
+
151
+ def forward(self, seq, cache=None, with_features=True):
152
+ seq = self.embedding(seq, self.max_tokens).to(self.device)
153
+ for mamba, norm in self.layers:
154
+ out, cache = mamba(norm(seq), cache)
155
+ seq = out + seq
156
+ if with_features:
157
+ return self.fc_out(seq.mean(dim = 1)), seq
158
+ else:
159
+ return self.fc_out(seq.mean(dim = 1))
160
+
161
+ class MambaBlock(nn.Module):
162
+ def __init__(self, d_input, d_model, d_state=16, d_discr=None, ker_size=4):
163
+ super().__init__()
164
+ d_discr = d_discr if d_discr is not None else d_model // 16
165
+ self.in_proj = nn.Linear(d_input, 2 * d_model, bias=False)
166
+ self.out_proj = nn.Linear(d_model, d_input, bias=False)
167
+ self.s_B = nn.Linear(d_model, d_state, bias=False)
168
+ self.s_C = nn.Linear(d_model, d_state, bias=False)
169
+ self.s_D = nn.Sequential(nn.Linear(d_model, d_discr, bias=False), nn.Linear(d_discr, d_model, bias=False),)
170
+ self.conv = nn.Conv1d(
171
+ in_channels=d_model,
172
+ out_channels=d_model,
173
+ kernel_size=ker_size,
174
+ padding=ker_size - 1,
175
+ groups=d_model,
176
+ bias=True,
177
+ )
178
+ self.A = nn.Parameter(torch.arange(1, d_state + 1, dtype=torch.float).repeat(d_model, 1))
179
+ self.D = nn.Parameter(torch.ones(d_model, dtype=torch.float))
180
+ self.device = DEVICE
181
+
182
+ def forward(self, seq, cache=None):
183
+ b, l, d = seq.shape
184
+ (prev_hid, prev_inp) = cache if cache is not None else (None, None)
185
+ a, b = self.in_proj(seq).chunk(2, dim=-1)
186
+ x = rearrange(a, 'b l d -> b d l')
187
+ x = x if prev_inp is None else torch.cat((prev_inp, x), dim=-1)
188
+ a = self.conv(x)[..., :l]
189
+ a = rearrange(a, 'b d l -> b l d')
190
+ a = silu(a)
191
+ a, hid = self.ssm(a, prev_hid=prev_hid)
192
+ b = silu(b)
193
+ out = a * b
194
+ out = self.out_proj(out)
195
+ if cache:
196
+ cache = (hid.squeeze(), x[..., 1:])
197
+ return out, cache
198
+
199
+ def ssm(self, seq, prev_hid):
200
+ A = -self.A
201
+ D = +self.D
202
+ B = self.s_B(seq)
203
+ C = self.s_C(seq)
204
+ s = softplus(D + self.s_D(seq))
205
+ A_bar = einsum(torch.exp(A), s, 'd s, b l d -> b l d s')
206
+ B_bar = einsum( B, s, 'b l s, b l d -> b l d s')
207
+ X_bar = einsum(B_bar, seq, 'b l d s, b l d -> b l d s')
208
+ hid = self._hid_states(A_bar, X_bar, prev_hid=prev_hid)
209
+ out = einsum(hid, C, 'b l d s, b l s -> b l d')
210
+ out = out + D * seq
211
+ return out, hid
212
+
213
+ def _hid_states(self, A, X, prev_hid=None):
214
+ b, l, d, s = A.shape
215
+ A = rearrange(A, 'b l d s -> l b d s')
216
+ X = rearrange(X, 'b l d s -> l b d s')
217
+ if prev_hid is not None:
218
+ return rearrange(A * prev_hid + X, 'l b d s -> b l d s')
219
+ h = torch.zeros(b, d, s, device=self.device)
220
+ return torch.stack([h := A_t * h + X_t for A_t, X_t in zip(A, X)], dim=1)
221
+
emotion_templates/anger.json ADDED
@@ -0,0 +1,196 @@
1
+ {
2
+ "subjects": [
3
+ "That lie", "This situation", "His attitude", "Her silence", "The way they ignored me",
4
+ "That decision", "This whole mess", "His response", "The way they talked to me",
5
+ "Their tone", "The look in their eyes", "The unfairness", "The betrayal",
6
+ "This conversation", "The rules they made", "That smug smile", "The injustice",
7
+ "His arrogance", "Her hypocrisy", "That fake apology", "Their excuse",
8
+ "What happened", "That accusation", "Their assumption", "His coldness",
9
+ "The way they dismissed me", "The way she walked away", "Their comment",
10
+ "The lack of respect", "The silence", "The manipulation", "The fake concern",
11
+ "This nonsense", "Their reaction", "The way they handled it", "The favoritism",
12
+ "That email", "The way they gaslit me", "This double standard", "The broken promise",
13
+ "Her message", "His ignorance", "The constant interruptions", "The yelling",
14
+ "The lies they spread", "Their decision", "This treatment", "His sarcasm",
15
+ "The rolled eyes", "This chaos", "The way they brushed it off", "The crossed line",
16
+ "Their behavior", "The gaslighting", "That moment", "Her denial", "The lack of accountability",
17
+ "His manipulation", "Their fake smile", "The passive aggression", "The backhanded comment",
18
+ "Her voice", "His condescension", "The disrespect", "The way I was ignored",
19
+ "The raised voice", "The rules for them", "The expectations they set", "That stupid joke",
20
+ "The fake empathy", "The unfair treatment", "That meeting", "That statement",
21
+ "The constant blame", "The way they twisted my words", "The tone they used",
22
+ "That decision they made", "The control", "Their agenda", "The things they said",
23
+ "That smug look", "The whispering", "Their manipulation", "The criticism",
24
+ "Her smugness", "His blank stare", "The gaslight attempt", "The guilt trip",
25
+ "The two-faced attitude", "The pressure", "The invalidation", "The dismissal",
26
+ "This tension", "The accusations", "The disrespectful comment", "The dishonesty",
27
+ "The fake kindness", "This nonsense policy", "Their entitlement", "The red flags",
28
+ "The toxic vibe", "The constant judgment", "The unnecessary drama", "His defensiveness"
29
+ ],
30
+ "verbs": [
31
+ "set me off", "boiled my blood", "pushed me over the edge", "got under my skin",
32
+ "made me snap", "triggered me instantly", "infuriated me", "was unbearable",
33
+ "sparked pure rage", "ticked me off", "made my fists clench", "was too much",
34
+ "ignited everything", "set a fire in me", "boiled over", "made my blood rush",
35
+ "shook me with anger", "cut too deep", "threw me into rage", "was beyond frustrating",
36
+ "hit a nerve", "was infuriating", "drove me mad", "burned inside me", "got to me",
37
+ "turned me red", "filled me with heat", "was outrageous", "was completely unacceptable",
38
+ "ripped through me", "kept poking at me", "got louder and louder", "spun me out",
39
+ "made me want to scream", "felt like an attack", "built up fast", "couldn’t be ignored",
40
+ "turned everything red", "sent me into a spiral", "clashed with everything in me",
41
+ "pressed all my buttons", "erased my patience", "had me yelling inside", "spilled over fast",
42
+ "was a slap in the face", "put me in fight mode", "snapped something in me", "took it too far",
43
+ "crossed a line", "felt like fuel", "burned through me", "got personal", "hit my limits",
44
+ "killed my calm", "shot adrenaline through me", "woke the worst in me", "soured the whole day",
45
+ "made me want to throw something", "burned me out", "was like fire in my gut", "dismantled my tolerance",
46
+ "overwhelmed me", "shattered my cool", "tore through my brain", "clanged in my chest",
47
+ "stabbed at my peace", "chased my breath", "spit in my face", "made me want to walk out",
48
+ "boomed inside", "grew louder", "flared without warning", "grated every nerve",
49
+ "felt like betrayal", "hammered on my mind", "flipped my mood instantly", "exploded internally",
50
+ "twisted my stomach", "simmered too long", "kept repeating in my head", "built up in silence",
51
+ "felt relentless", "crept up slowly", "left me livid", "killed all reason", "tore open my restraint",
52
+ "blocked out everything else", "just wouldn't stop", "sparked a storm", "dragged me into rage",
53
+ "stabbed at my patience", "flooded me with fury", "split my brain", "left me shaking",
54
+ "was like fire to paper", "ripped up my peace", "wrecked my day", "tipped me over",
55
+ "woke something primal", "restarted the fire", "slammed into my chest", "blew my fuse"
56
+ ],
57
+ "interjections": [
58
+ "You’ve got to be kidding me (groans)!",
59
+ "(inhales) This is unreal!",
60
+ "Seriously (clears throat)?",
61
+ "I can’t even begin (inhales)...",
62
+ "Enough is enough (groans)!",
63
+ "What the hell (exhales)!",
64
+ "Are you serious (groans)?",
65
+ "This has to stop (inhales)!",
66
+ "I’ve had it (groans)!",
67
+ "I’m not holding back (clears throat).",
68
+ "Now I’m done (inhales).",
69
+ "Unbelievable (groans).",
70
+ "Not again (exhales)...",
71
+ "That’s it (groans)!",
72
+ "I swear (clears throat).",
73
+ "How dare they (inhales)!",
74
+ "Absolutely not (groans)!",
75
+ "Get out of here (screams)!",
76
+ "I’m done pretending (inhales).",
77
+ "No way (groans)."
78
+ ],
79
+
80
+ "contexts": [
81
+ "and I couldn’t stay quiet (groans)",
82
+ "and I snapped (exhales) without warning",
83
+ "and I couldn’t control my voice (inhales)",
84
+ "and I felt the heat rise (exhales)",
85
+ "and I wanted to punch a wall (inhales)",
86
+ "and I yelled (screams) before I could stop",
87
+ "and it shattered my calm (groans)",
88
+ "and I lost it right there (groans)",
89
+ "and I (inhales) slammed the door",
90
+ "and I could feel myself shaking (groans)",
91
+ "and my jaw locked tight (exhales)",
92
+ "and I couldn’t hold it in (groans)",
93
+ "and I was done pretending (clears throat)",
94
+ "and I couldn’t even look at them (inhales)",
95
+ "and I exploded (groans)",
96
+ "and I couldn’t take it anymore (exhales)",
97
+ "and I had to say something (groans)",
98
+ "and I (inhales) stood my ground",
99
+ "and I shouted back (screams)",
100
+ "and I let it all out (exhales)",
101
+ "and my voice (clears throat) cracked",
102
+ "and I hit the table (groans)",
103
+ "and my chest was pounding (inhales)",
104
+ "and I was ready to walk out (groans)",
105
+ "and I didn’t care anymore (exhales)",
106
+ "and I needed to be heard (clears throat)",
107
+ "and I (inhales) raised my voice",
108
+ "and I threw the paper (groans)",
109
+ "and I stormed out (screams)",
110
+ "and I couldn’t stay still (inhales)",
111
+ "and I lost all patience (groans)",
112
+ "and I refused to let it slide (groans)",
113
+ "and I called them out (clears throat)",
114
+ "and I (inhales) slammed my hand down",
115
+ "and I saw red (groans)",
116
+ "and I swore (exhales) under my breath",
117
+ "and I (inhales) pushed back",
118
+ "and I made it clear (clears throat)",
119
+ "and I couldn’t smile through it (exhales)",
120
+ "and I wasn’t going to take it (groans)",
121
+ "and I felt fury (inhales) in my chest",
122
+ "and I almost screamed (groans)",
123
+ "and I nearly broke something (exhales)",
124
+ "and I didn’t care who heard (groans)",
125
+ "and I just (inhales) let it rip",
126
+ "and I couldn’t stop (groans)",
127
+ "and I raised hell (screams)",
128
+ "and I felt every nerve snap (inhales)",
129
+ "and I threw my phone (groans)",
130
+ "and I stood up fast (exhales)",
131
+ "and I couldn’t think straight (groans)",
132
+ "and I almost walked out (inhales)",
133
+ "and I hit the wall (groans)",
134
+ "and I didn’t hold back (clears throat)",
135
+ "and I wasn’t sorry (groans)",
136
+ "and I didn’t fake calm (inhales)",
137
+ "and I pushed through the rage (groans)",
138
+ "and I didn’t even blink (exhales)",
139
+ "and I said (inhales) exactly what I felt",
140
+ "and I glared back (groans)",
141
+ "and I lost all control (exhales)",
142
+ "and I roared inside (groans)",
143
+ "and I hit my breaking point (inhales)",
144
+ "and I wasn’t about to be silent (groans)",
145
+ "and I couldn’t take another second (exhales)",
146
+ "and I let the fury speak (groans)",
147
+ "and I (inhales) dropped the act",
148
+ "and I watched myself boil (groans)",
149
+ "and I nearly lost it all (exhales)",
150
+ "and I refused to back down (groans)",
151
+ "and I didn’t sugarcoat it (clears throat)",
152
+ "and I let the anger talk (groans)",
153
+ "and I (inhales) cut them off",
154
+ "and I burned the bridge (groans)",
155
+ "and I pulled no punches (inhales)",
156
+ "and I growled back (groans)",
157
+ "and I lashed out (exhales)",
158
+ "and I let it burn (groans)",
159
+ "and I needed the release (inhales)",
160
+ "and I was done holding it in (groans)",
161
+ "and I let the room know (inhales)",
162
+ "and I banged the table (groans)",
163
+ "and I spit the words out (inhales)",
164
+ "and I slammed my fist (groans)",
165
+ "and I (inhales) snapped hard",
166
+ "and I said it louder (groans)",
167
+ "and I faced it head‑on (exhales)",
168
+ "and I let the silence burn (groans)",
169
+ "and I felt the power rise (inhales)",
170
+ "and I walked away fuming (groans)",
171
+ "and I cracked wide open (exhales)",
172
+ "and I gave them the truth (groans)",
173
+ "and I looked them dead in the eyes (inhales)",
174
+ "and I didn’t regret it (exhales)"
175
+ ],
176
+ "templates": [
177
+ "{s} {v}, {c}.",
178
+ "{i} {s} {v}.",
179
+ "{s} {v}. {c}.",
180
+ "{s} {v} — {c}.",
181
+ "{c}. {s} {v}.",
182
+ "{s}. It {v}, {c}.",
183
+ "{s} just... {v}. {c}.",
184
+ "{i} — {s} {v}, {c}.",
185
+ "What set me off? {s} {v}. {c}.",
186
+ "{s} {v}. I had enough! {c}.",
187
+ "{s} {v}, and I didn’t hold back. {c}.",
188
+ "I lost it when {s} {v}, {c}.",
189
+ "You want to know why I snapped? {s} {v}.",
190
+ "And then it happened: {s} {v}, {c}.",
191
+ "{s} {v}. I couldn’t take one more second! {c}.",
192
+ "{s} {v}. That was the last straw!",
193
+ "{i} I saw red when {s} {v}. {c}.",
194
+ "{s} {v}. I slammed the damn door. {c}."
195
+ ]
196
+ }
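Each of the emotion_templates/*.json files follows the same schema: pools of "subjects", "verbs", "interjections" and "contexts", plus "templates" that combine them through the {s}, {v}, {i} and {c} placeholders. A rough illustration of the substitution idea (the repository's own generator is generate_emotion_texts_dataset.py; this sketch does not reproduce its actual logic):

    import json, random

    with open("emotion_templates/anger.json", encoding="utf-8") as f:
        tpl = json.load(f)

    template = random.choice(tpl["templates"])      # e.g. "{s} {v}, {c}."
    sentence = template.format(                     # unused keyword args are simply ignored
        s=random.choice(tpl["subjects"]),
        v=random.choice(tpl["verbs"]),
        i=random.choice(tpl["interjections"]),
        c=random.choice(tpl["contexts"]),
    )
    print(sentence)   # one synthetic 'anger' utterance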
emotion_templates/disgust.json ADDED
@@ -0,0 +1,174 @@
1
+ {
2
+ "subjects": [
3
+ "That smell", "This mess", "The way they chew", "The bathroom", "That habit",
4
+ "The leftover food", "The sink", "The way he talks", "The mold", "Her fake laugh",
5
+ "The dirty plate", "His comment", "The old sponge", "The floor", "That photo",
6
+ "The rotten food", "The spoiled milk", "The sweat smell", "That video",
7
+ "The way she looked at me", "The chewing sound", "That creepy grin", "The fridge",
8
+ "The slime", "The trash", "His breath", "The sound of slurping", "The bug",
9
+ "The pus", "The stain", "The toilet", "The crusted towel", "The unwashed dish",
10
+ "The clump of hair", "The spit", "That burp", "The toenail clipping",
11
+ "The used tissue", "The crust in the corner", "The food stuck to the table",
12
+ "The pile of clothes", "The smudge", "That smell in the microwave", "The grime",
13
+ "The grease", "The way they eat with their mouth open", "The way he scratched",
14
+ "The loud chewing", "The mucus", "That nasty cough", "The dirty keyboard",
15
+ "The open sore", "The bathroom floor", "That weird stain", "The gum under the desk",
16
+ "The way she licked her fingers", "That slimy texture", "The hair in the drain",
17
+ "The oil smell", "The weird texture", "That flaky skin", "The cough in my face",
18
+ "The dried ketchup", "The fake smile", "The old food container", "The spit bubble",
19
+ "The breath in my face", "The noisy swallowing", "The crusty fork", "The plate of mush",
20
+ "The gurgling sound", "The smelly shoe", "The grease trap", "The used napkin",
21
+ "The close talker", "The public bathroom", "The noise he made", "The dirty towel",
22
+ "The slime on the floor", "The gunk", "The mix of smells", "The empty bottle with milk in it",
23
+ "That gesture", "The broken nail", "The eye gunk", "That noise with the throat",
24
+ "The mess in the sink", "The weird breathing", "The sticky hand", "The stinky feet",
25
+ "The trash bag leak", "The way they touched the food", "The finger licking",
26
+ "The bathroom stall", "That cracked lip", "The greasy spoon", "The way they smacked their lips"
27
+ ],
28
+ "verbs": [
29
+ "turned my stomach", "made me gag", "grossed me out", "made me shiver",
30
+ "sent a wave of nausea", "was disgusting", "made me flinch", "twisted my gut",
31
+ "churned my stomach", "repulsed me", "triggered my gag reflex",
32
+ "made my skin crawl", "was revolting", "sickened me", "hit me wrong",
33
+ "burned my nose", "felt wrong", "made me dry heave", "pushed me back",
34
+ "curdled my insides", "stung my senses", "brought bile up", "left a taste in my mouth",
35
+ "smelled rancid", "tasted foul", "looked putrid", "reeked", "felt like vomit",
36
+ "left me nauseated", "knocked my appetite", "made me lose it",
37
+ "felt filthy", "stank horribly", "coated my throat", "overwhelmed my nose",
38
+ "burned into memory", "made me cough", "was rotten", "was unbearable",
39
+ "left me disgusted", "sent shivers", "twisted my face", "gave me goosebumps",
40
+ "triggered disgust", "was sickening", "was grotesque", "reeked of old filth",
41
+ "smelled like death", "was unspeakable", "looked infected", "was slimy and warm",
42
+ "was moist in the worst way", "left a greasy film", "was sticky and gross",
43
+ "stuck to everything", "oozed", "dripped in the worst way", "was decaying",
44
+ "looked chewed", "felt rancid", "smelled spoiled", "squished in my hand",
45
+ "glistened weirdly", "looked alive", "wiggled slightly", "felt damp and wrong",
46
+ "left a slime trail", "made me cover my mouth", "caused a dry retch",
47
+ "spoiled the whole room", "felt like rot", "reeked of sweat", "oozed a bit",
48
+ "smeared across the surface", "clung to the air", "infected the vibe",
49
+ "spread filth", "made everything worse", "crawled under my skin",
50
+ "was so wrong", "ruined my day", "tasted like trash", "bubbled weirdly",
51
+ "had texture like decay", "looked chewed up", "smelled like feet",
52
+ "smelled fermented", "smeared all over", "cracked and oozed", "sent a gag reflex",
53
+ "looked diseased", "felt like warm spit", "reeked of mold", "coated my senses"
54
+ ],
55
+ "interjections": [
56
+ "(groans) Ew!",
57
+ "(coughs) That’s disgusting!",
58
+ "Ugh! (sniffs)",
59
+ "(gasps) I can’t believe I saw that",
60
+ "Seriously? (sighs)",
61
+ "(groans) What the hell?!",
62
+ "No no no! (clears throat)",
63
+ "Gross! (exhales)",
64
+ "(mumbles) Disgusting!",
65
+ "That made me gag! (coughs)",
66
+ "(sniffs) Yuck!",
67
+ "Oh god (groans)",
68
+ "That’s so wrong! (exhales)",
69
+ "Why would you do that?! (clears throat)",
70
+ "(inhales) Please stop",
71
+ "That smell though (sniffs)",
72
+ "No thanks (sighs)",
73
+ "I’m done (groans)",
74
+ "Absolutely vile! (coughs)",
75
+ "I could puke! (exhales)"
76
+ ],
77
+
78
+ "contexts": [
79
+ "(groans) and I couldn’t unsee it",
80
+ "and I nearly gagged (coughs)",
81
+ "and I felt my throat close (exhales)",
82
+ "and I turned away (groans)",
83
+ "and I had to leave the room (exhales)",
84
+ "and my face twisted (sniffs)",
85
+ "and I backed off immediately (groans)",
86
+ "and I clenched my jaw (sighs)",
87
+ "and I wanted to scrub my brain (groans)",
88
+ "and my appetite vanished (clears throat)",
89
+ "and I wiped my hands (sniffs)",
90
+ "and I needed mouthwash (groans)",
91
+ "and I felt queasy (coughs)",
92
+ "(sniffs) and I covered my nose",
93
+ "and I said ‘gross!’ (groans)",
94
+ "and I still smell it (sniffs)",
95
+ "and I wanted to burn it all (exhales)",
96
+ "and I nearly lost it (groans)",
97
+ "and I shook my head (sighs)",
98
+ "and I whispered ‘what the hell?!’ (groans)",
99
+ "and I flinched hard (inhales)",
100
+ "and I squinted like it would go away (sighs)",
101
+ "(inhales) and I held my breath",
102
+ "and I gagged a little (coughs)",
103
+ "and I pulled back fast (groans)",
104
+ "and I rubbed my eyes (sniffs)",
105
+ "and I dry heaved (coughs)",
106
+ "and I bit my tongue to avoid screaming (clears throat)",
107
+ "and I washed my hands three times (sniffs)",
108
+ "and I didn’t look again (groans)",
109
+ "and I covered my face (exhales)",
110
+ "and I had to breathe through my mouth (inhales)",
111
+ "and I said ‘absolutely not!’ (groans)",
112
+ "and I scrubbed everything (sniffs)",
113
+ "and I needed a shower (groans)",
114
+ "and I said ‘ew!’ out loud (coughs)",
115
+ "and I sprayed the whole area (coughs)",
116
+ "and I glared at them (clears throat)",
117
+ "and I wiped the table like five times (sniffs)",
118
+ "and I wanted to bleach it (groans)",
119
+ "and I swore never again (exhales)",
120
+ "and I asked them to stop (clears throat)",
121
+ "and I backed into a corner (inhales)",
122
+ "and I wanted to scream (groans)",
123
+ "and I asked them to throw it away (clears throat)",
124
+ "and I almost dropped it (gasps)",
125
+ "and I cringed (groans)",
126
+ "and I winced hard (inhales)",
127
+ "and I had to spit (coughs)",
128
+ "and I cleaned everything I touched (sniffs)",
129
+ "and I looked away fast (groans)",
130
+ "and I felt like vomiting (coughs)",
131
+ "and I dry swallowed (sighs)",
132
+ "and I left the room immediately (exhales)",
133
+ "and I didn’t know what to do (sighs)",
134
+ "and I needed to rinse my mouth (clears throat)",
135
+ "and I still feel it on my hands (sniffs)",
136
+ "and I wiped my phone off (sniffs)",
137
+ "and I stood there stunned (gasps)",
138
+ "and I whispered ‘nope!’ (groans)",
139
+ "and I tensed up (inhales)",
140
+ "and I wiped down the table again (sniffs)",
141
+ "(coughs) and I sprayed disinfectant everywhere",
142
+ "and I held the air (inhales)",
143
+ "and I tried not to vomit (coughs)",
144
+ "and I shook it off quickly (exhales)",
145
+ "and I walked away fast (groans)",
146
+ "and I stared in horror (gasps)",
147
+ "and I flinched again (inhales)",
148
+ "and I cleaned the whole area (sniffs)",
149
+ "and I sanitized everything (clears throat)",
150
+ "and I stepped away (exhales)"
151
+ ],
152
+ "templates": [
153
+ "{s} {v}, {c}.",
154
+ "{i}, {s} {v}.",
155
+ "{s} {v}. {c}.",
156
+ "{s} {v} — {c}.",
157
+ "{c}. {s} {v}.",
158
+ "{s}. It {v}, {c}.",
159
+ "{s} just... {v}. {c}.",
160
+ "{i} — {s} {v}, {c}.",
161
+ "Honestly, {s} {v}. {c}.",
162
+ "I couldn’t take it — {s} {v}, {c}.",
163
+ "It was foul. {s} {v}. {c}.",
164
+ "Just thinking about it — {s} {v}. {c}.",
165
+ "‘{s} {v}’ — yeah, I’m out. {c}.",
166
+ "{s} made me gag, no question. {c}.",
167
+ "I still feel sick. {s} {v}. {c}.",
168
+ "It turned my gut — {s} {v}, {c}.",
169
+ "{s}? Just vile. {v}, {c}.",
170
+ "Disgust doesn’t even cover it — {s} {v}, {c}.",
171
+ "{i}. Seriously, {s} {v}. {c}.",
172
+ "Every time I remember — {s} {v}. {c}."
173
+ ]
174
+ }
emotion_templates/fear.json ADDED
@@ -0,0 +1,178 @@
1
+ {
2
+ "subjects": [
3
+ "That noise", "The silence", "This feeling", "His stare", "Her whisper", "That look",
4
+ "The shadow", "The knock", "That voice", "The unknown", "The hallway", "The message",
5
+ "The basement", "The flicker", "That sound behind me", "The reflection", "The darkness",
6
+ "The quiet", "That face", "The empty street", "The open door", "The figure in the distance",
7
+ "The footsteps", "The breath I heard", "The locked room", "The sudden cold", "The static",
8
+ "The glitch", "That motion", "The silence after", "That sudden stop", "The blank screen",
9
+ "The basement light", "The empty chair", "The way it stopped", "The timing", "The phone ringing",
10
+ "The strange call", "The echo", "The scream", "The eye contact", "The ticking", "The delay",
11
+ "That message at night", "The wrong number", "The unexpected knock", "The closed curtain",
12
+ "The way it moved", "The window", "The mirror", "The shape", "The silence on the line",
13
+ "The dim hallway", "That heavy breath", "The door creaking", "The chill", "That shadow on the wall",
14
+ "The open browser tab", "The stillness", "The unanswered call", "The flickering light",
15
+ "The power going out", "That smell", "The voice in my head", "The basement stairs",
16
+ "The unknown number", "The sudden pause", "The air around me", "The shape behind me",
17
+ "The clicking sound", "The lack of sound", "That dream", "The phone buzz", "The siren",
18
+ "That change in tone", "The knock I didn’t expect", "The door left open", "That breath I didn’t take",
19
+ "The space under the bed", "The open closet", "The camera turning on", "The power surge",
20
+ "That sudden silence", "The hallway light", "The alarm", "The code not working",
21
+ "That sudden breeze", "The voice that stopped", "The quiet tap", "The screen glitch",
22
+ "That blinking cursor", "The slow footsteps", "The eye in the dark", "The frozen time",
23
+ "The disconnect tone", "The feeling of being watched", "The stop in motion", "The email I didn’t send"
24
+ ],
25
+ "verbs": [
26
+ "froze me", "sent a chill down my spine", "made my breath catch", "stopped my heart",
27
+ "tied my stomach in knots", "paralyzed me", "had me holding my breath", "rushed my pulse",
28
+ "tensed every muscle", "squeezed my chest", "made me dizzy", "triggered my panic",
29
+ "gripped me", "shook me", "dried my mouth", "shrank my world", "knocked me still",
30
+ "pressed on my lungs", "made my skin crawl", "trembled through me", "dug into my mind",
31
+ "raced through my thoughts", "tightened my throat", "blanked my head", "flooded me with fear",
32
+ "spiked my heart rate", "stole my breath", "made me freeze in place", "clouded my focus",
33
+ "hit like dread", "echoed in my chest", "spun my senses", "stung like ice", "gripped my spine",
34
+ "whitened my face", "stopped everything", "grew louder in my head", "seeped into me",
35
+ "dragged down my breath", "set my nerves on edge", "crawled over my skin", "burned in my gut",
36
+ "snapped my focus", "locked my body", "flooded my system", "choked my breath",
37
+ "tightened my jaw", "made my heartbeat deafening", "pressed me into stillness",
38
+ "caged my breath", "made my ears ring", "left me lightheaded", "slowed down time",
39
+ "reduced me to silence", "took over", "made my hands shake", "bent my knees",
40
+ "tugged at my ribs", "dimmed my sight", "pulled me inward", "buried my voice",
41
+ "unraveled me", "set alarms off inside", "curled my fingers", "hit my core",
42
+ "trembled in my jaw", "scrambled my thoughts", "turned me to stone", "squeezed my spine",
43
+ "widened my eyes", "flushed my skin", "knocked the air out", "rattled my head",
44
+ "sat heavy in my chest", "clutched at my throat", "sent panic through me",
45
+ "knotted my back", "held me hostage", "blinded my thinking", "muffled my hearing",
46
+ "rose like a wave", "climbed through my limbs", "forced my jaw shut", "twisted in my gut",
47
+ "burned through my nerves", "flooded my eyes", "froze my voice", "slammed my chest",
48
+ "disoriented me", "hollowed me out", "rushed like a scream", "scraped my bones",
49
+ "throbbed in my neck", "tightened like a noose", "spun in my head", "felt like doom"
50
+ ],
51
+ "contexts": [
52
+ "(gasps) and I couldn’t move",
53
+ "and I held my breath (inhales)",
54
+ "and my body locked up (exhales)",
55
+ "and I didn’t know what to do (sighs)",
56
+ "and I felt eyes on me (inhales)",
57
+ "and I just listened (sighs)",
58
+ "and I froze in place (gasps)",
59
+ "and everything felt too quiet (exhales)",
60
+ "and I scanned the room (inhales)",
61
+ "and I backed away (gasps)",
62
+ "and I reached for my phone (inhales)",
63
+ "and I whispered nothing (sighs)",
64
+ "and I stayed absolutely still (exhales)",
65
+ "and I clenched my jaw",
66
+ "and I waited in silence (exhales)",
67
+ "and I couldn’t see clearly (exhales)",
68
+ "and my heart was pounding (inhales)",
69
+ "and the air felt wrong (sighs)",
70
+ "and I tiptoed (inhales)",
71
+ "and my hands were shaking (exhales)",
72
+ "and I hoped it wasn’t real (sighs)",
73
+ "and I didn’t dare speak (inhales)",
74
+ "and I tried to stay calm (exhales)",
75
+ "and the walls felt closer (gasps)",
76
+ "and the room spun (exhales)",
77
+ "and I avoided the mirror (sighs)",
78
+ "and I wished I hadn’t heard that (exhales)",
79
+ "and the lights flickered (gasps)",
80
+ "and the sound repeated (inhales)",
81
+ "and my fingers curled (exhales)",
82
+ "and I stood frozen (gasps)",
83
+ "and I looked over my shoulder (inhales)",
84
+ "and I barely breathed (exhales)",
85
+ "and I felt like I was watched (gasps)",
86
+ "and I clutched my chest (inhales)",
87
+ "and my breath stopped (exhales)",
88
+ "and I whispered ‘hello?’ (sighs)",
89
+ "and I stepped back slowly (exhales)",
90
+ "and I avoided eye contact (sighs)",
91
+ "and I didn’t turn around (inhales)",
92
+ "and I closed my eyes tight (exhales)",
93
+ "and I mouthed ‘please no’ (sighs)",
94
+ "and I turned the light on (inhales)",
95
+ "and I checked the door (exhales)",
96
+ "and I hoped it was nothing (sighs)",
97
+ "and I bit my lip (inhales)",
98
+ "and I looked again (gasps)",
99
+ "and I grabbed the handle (inhales)",
100
+ "and I watched the hallway (exhales)",
101
+ "and I held the wall (sighs)",
102
+ "and I tried not to blink (inhales)",
103
+ "and I prayed I was wrong (exhales)",
104
+ "and I counted my breaths (sighs)",
105
+ "and I waited for a sound (inhales)",
106
+ "and I pressed my back to the wall (exhales)",
107
+ "and I closed the tab (sighs)",
108
+ "and I locked my phone (inhales)",
109
+ "and I shook (exhales)",
110
+ "and I tried to steady myself (sighs)",
111
+ "and I looked away (exhales)",
112
+ "and I froze mid‑step (gasps)",
113
+ "and I turned down the volume (sighs)",
114
+ "and I stood in the dark (inhales)",
115
+ "and I checked the lock (exhales)",
116
+ "and I peeked through the curtain (gasps)",
117
+ "and I watched the shadow move (inhales)",
118
+ "and I saw it again (exhales)",
119
+ "and I hoped I imagined it (sighs)",
120
+ "and I wanted to scream (gasps)",
121
+ "and I felt something shift (inhales)",
122
+ "and I barely exhaled (exhales)",
123
+ "and I pressed mute (sighs)",
124
+ "and I whispered ‘stop’ (inhales)",
125
+ "and I stared into the dark (exhales)",
126
+ "and I held my breath again (inhales)",
127
+ "and I tried to wake up (sighs)",
128
+ "and I reached for the light (inhales)",
129
+ "and I stayed under the blanket (exhales)",
130
+ "and I whispered to myself (sighs)",
131
+ "and I looked at the screen again (inhales)",
132
+ "and I locked the window (exhales)",
133
+ "and I flinched (gasps)",
134
+ "and I felt it behind me (inhales)",
135
+ "and I turned off the TV (exhales)",
136
+ "and I pulled the covers tight (sighs)"
137
+ ],
138
+
139
+ "interjections": [
140
+ "(gasps) I felt a chill!",
141
+ "(sighs) Something’s off.",
142
+ "(exhales) This can’t be good.",
143
+ "I don’t like this (inhales).",
144
+ "(inhales) What if it’s real?",
145
+ "(gasps) No way!",
146
+ "I swear I saw something (exhales).",
147
+ "(inhales) Did you hear that?",
148
+ "(sighs) Why is it so quiet?",
149
+ "That wasn’t right (exhales).",
150
+ "(inhales) Is someone there?",
151
+ "Not again (exhales)...",
152
+ "(exhales) That gave me chills.",
153
+ "(gasps) What the...",
154
+ "It felt wrong (sighs).",
155
+ "(gasps) I froze!",
156
+ "(gasps) Something moved!",
157
+ "(sighs) It’s too quiet.",
158
+ "(exhales) This isn’t normal.",
159
+ "(inhales) I’m not alone."
160
+ ],
161
+ "templates": [
162
+ "{s} {v}, {c}.",
163
+ "{i}, {s} {v}.",
164
+ "{s} {v}. {c}.",
165
+ "{s} {v} — {c}.",
166
+ "{c}. {s} {v}.",
167
+ "{s}. It {v}, {c}.",
168
+ "{s} just... {v}. {c}.",
169
+ "{i} — {s} {v}, {c}.",
170
+ "{s} {v}. I couldn’t blink. {c}.",
171
+ "And then it happened: {s} {v}, {c}.",
172
+ "‘{s} {v}’ — and I froze. {c}.",
173
+ "It was quiet, too quiet. Then {s} {v}. {c}.",
174
+ "I stopped breathing when {s} {v}. {c}.",
175
+ "No one else noticed, but {s} {v}, {c}.",
176
+ "{s} {v}. My pulse spiked. {c}."
177
+ ]
178
+ }
emotion_templates/happy.json ADDED
@@ -0,0 +1,187 @@
+ {
2
+ "subjects": [
3
+ "That moment", "This message", "Her smile", "His laugh", "The sunshine", "The good news",
4
+ "That call", "The unexpected text", "The gift", "The surprise", "That compliment",
5
+ "The way they looked at me", "This day", "The win", "That hug", "The way it happened",
6
+ "Their reaction", "That unexpected turn", "This feeling", "That morning",
7
+ "The celebration", "The results", "That gesture", "The quiet joy", "The party",
8
+ "The sound of laughter", "The light", "That one word", "The surprise visit",
9
+ "The smell of coffee", "The fresh air", "The freedom", "That one dance",
10
+ "The sunrise", "The music", "The rhythm", "The warm breeze", "That smile across the room",
11
+ "The perfect timing", "The peace", "The joke", "The message I didn’t expect",
12
+ "That moment I opened the door", "The reaction on their face", "That memory",
13
+ "The sparkle in their eyes", "The unexpected break", "The lucky moment",
14
+ "The right words", "The song that played", "The cake", "The light in the room",
15
+ "The kids laughing", "The compliments", "The photo", "The way they cheered",
16
+ "The shared silence", "That warm feeling", "The burst of joy", "The freedom I felt",
17
+ "The high five", "The small win", "The huge success", "The laughter we shared",
18
+ "The trip", "The story", "The goofy face", "The text with all caps",
19
+ "The silly moment", "The happy tear", "The peace in my heart", "The sparkle",
20
+ "The good vibes", "The cozy night", "The unexpected yes", "The sudden dance",
21
+ "The applause", "The open road", "The sunshine on my face", "The joke that landed",
22
+ "The moment they said it", "The sparkle of hope", "The confetti moment", "That yes",
23
+ "The bounce in their step", "The rush of relief", "The random kindness",
24
+ "The shared look", "The clink of glasses", "The moment of pride", "The nod of approval",
25
+ "The lightness in the air", "The smell of cookies", "The favorite song",
26
+ "The surprise package", "The right time", "That perfect moment", "The cheers",
27
+ "The warmth in the room", "The silly dance", "The peace in the pause", "The bright light"
28
+ ],
29
+ "verbs": [
30
+ "lit me up", "made me laugh out loud", "warmed my chest", "brought me peace",
31
+ "lifted my mood", "was pure joy", "felt like flying", "was exactly what I needed",
32
+ "made me grin", "sparked something bright", "filled me with light",
33
+ "was everything", "cheered me up", "set the tone", "made the whole day",
34
+ "was perfect", "was a blessing", "felt magical", "recharged me", "was unforgettable",
35
+ "made my heart dance", "put a smile on my face", "made my eyes shine",
36
+ "woke up my soul", "filled the room", "shimmered through me", "set me free",
37
+ "bubbled up inside", "sparked so much joy", "brought out my laughter",
38
+ "was heartwarming", "had me laughing instantly", "soothed my spirit",
39
+ "gave me goosebumps", "sparkled", "changed everything", "was so genuine",
40
+ "had me beaming", "was sunshine", "shined bright", "was golden",
41
+ "made the room feel lighter", "brought instant joy", "was beautiful",
42
+ "echoed with joy", "brightened the moment", "lifted the weight",
43
+ "refreshed my spirit", "wrapped me in warmth", "was the highlight",
44
+ "gave me hope", "touched my heart", "made me giggle", "felt light and free",
45
+ "restored my smile", "made me blush", "sparkled in me", "woke something kind",
46
+ "was soft and strong", "wrapped the moment", "was like magic", "felt like home",
47
+ "was full of grace", "was the best surprise", "was the perfect fit", "bloomed inside me",
48
+ "brought clarity and calm", "was so full of color", "set off a spark",
49
+ "played like a melody", "made everything brighter", "was like spring inside me",
50
+ "was sweetness", "glowed in me", "reminded me of good things", "restored my joy",
51
+ "hugged me from within", "was full of wonder", "sparkled in my chest", "was so soft",
52
+ "helped me breathe again", "was a true gift", "felt so human", "rang like laughter",
53
+ "set me dancing", "was pure love", "made me giggle-snort", "sent joy down my spine",
54
+ "was completely silly", "was just right", "was light as air", "was like a hug",
55
+ "fizzed like soda", "made my heart sing", "was total delight", "filled the space with light",
56
+ "was cozy and bright", "carried me", "felt so real", "just made sense", "bounced off the walls"
57
+ ],
58
+
59
+ "interjections": [
60
+ "(laughs) I couldn’t believe it!",
61
+ "No way! (claps)",
62
+ "That made my day, (laughs).",
63
+ "(chuckle)... You have no idea!",
64
+ "(whistles) Honestly, amazing!",
65
+ "So good! (laughs)",
66
+ "It was perfect — (applause).",
67
+ "(laughs) Pure joy!",
68
+ "Absolutely loved it! (claps)",
69
+ "Couldn’t stop grinning! (chuckle)",
70
+ "Wow! (whistles)",
71
+ "(laughs) This lit me up!",
72
+ "(whistles)... Just yes!",
73
+ "Loved every second! (laughs)",
74
+ "(chuckle) Total happiness!",
75
+ "It hit perfectly! (whistles)",
76
+ "(laughs) I’m smiling again.",
77
+ "Everything clicked — (claps)!",
78
+ "I was glowing... (humming)"
79
+ ],
80
+
81
+ "contexts": [
82
+ "(laughs)... and I couldn’t stop smiling!",
83
+ "and everything felt lighter (exhales).",
84
+ "and I didn’t want it to end — (sighs).",
85
+ "(whistles) and I wanted to dance!",
86
+ "(laughs) and I couldn’t stop laughing!",
87
+ "and the room lit up — (claps)!",
88
+ "and I held onto it all day... (humming)",
89
+ "and it made my week! (laughs)",
90
+ "and I smiled without even realizing (chuckle).",
91
+ "and I was so full of energy (humming)!",
92
+ "and I felt so free... (exhales)",
93
+ "and nothing else mattered.",
94
+ "and I just wanted to share it (laughs)!",
95
+ "(laughs) and I laughed until my sides hurt!",
96
+ "and I couldn’t stop replaying it (whistles)...",
97
+ "and I felt ten pounds lighter (exhales)!",
98
+ "and the moment sparkled (humming)...",
99
+ "and I forgot all the stress — (sighs)",
100
+ "and I hugged them tight!",
101
+ "and I just soaked it in... (exhales)",
102
+ "and it made everything worth it! (laughs)",
103
+ "(inhales) and I breathed it in.",
104
+ "and it stayed with me (whistles)...",
105
+ "and it was everything I needed (sighs).",
106
+ "and I danced a little inside (whistles)!",
107
+ "and I caught myself smiling later (chuckle).",
108
+ "and I said “finally!” (laughs)",
109
+ "and my heart felt full (humming).",
110
+ "and I couldn’t stop giggling (laughs)!",
111
+ "and I felt alive again (exhales).",
112
+ "and I whispered “yes!” (whistles).",
113
+ "and I high‑fived the air (claps)!",
114
+ "and I had the biggest grin (chuckle)!",
115
+ "and it lifted the whole moment (laughs).",
116
+ "and the day turned golden (whistles)...",
117
+ "and the silence was perfect — (sighs).",
118
+ "and I was glowing (humming)!",
119
+ "and I bounced around (laughs)!",
120
+ "and the sun seemed brighter (whistles)!",
121
+ "and even coffee tasted better (chuckle)...",
122
+ "and it made everything okay (sighs).",
123
+ "and I had happy tears (laughs)...",
124
+ "and I just twirled (whistles)!",
125
+ "and I hugged myself (humming).",
126
+ "(laughs) and I laughed out loud in public!",
127
+ "and it gave me wings (whistles)...",
128
+ "and I wanted to bottle that feeling (humming).",
129
+ "(humming)... and I hummed without thinking!",
130
+ "and I sent five heart emojis (claps).",
131
+ "and I didn’t need words (exhales)...",
132
+ "and I knew it was real (sighs).",
133
+ "and I just sat in the joy (exhales).",
134
+ "and it filled the room (laughs)!",
135
+ "and I felt carried (humming)...",
136
+ "and I danced in my chair (whistles)!",
137
+ "and it echoed in my bones (humming).",
138
+ "and I forgot why I was sad (exhales)...",
139
+ "and I thanked the moment (laughs).",
140
+ "and I could finally breathe (exhales)!",
141
+ "and I turned up the music (singing)!",
142
+ "and I felt hugged (humming).",
143
+ "and I danced barefoot (whistles)...",
144
+ "and I made a goofy face (laughs).",
145
+ "and I felt so warm (exhales).",
146
+ "and I cried happy tears (laughs)!",
147
+ "and I replayed it in my head (chuckle).",
148
+ "and I couldn’t believe it (gasps)!",
149
+ "and I wanted to freeze time (sighs)...",
150
+ "and I smiled at a stranger (laughs).",
151
+ "and I whispered “thank you.” (sighs)",
152
+ "and I let myself laugh (laughs)!",
153
+ "and I felt like a kid again (whistles).",
154
+ "and I laughed‑snorted (laughs)!",
155
+ "and I leaned into it (exhales)...",
156
+ "and I closed my eyes and smiled (humming).",
157
+ "and I sang out loud (singing)!",
158
+ "and it gave me joy to spare (laughs)...",
159
+ "and I did a little spin (whistles)!",
160
+ "and I raised my hands — (claps)!",
161
+ "(claps)... and I clapped again!",
162
+ "and I wanted to shout it (laughs)!",
163
+ "and I texted five people (chuckle).",
164
+ "and I smiled like an idiot (laughs)!",
165
+ "and I started dancing (whistles)!",
166
+ "and I breathed deep and smiled (exhales).",
167
+ "and I grinned the rest of the day (laughs)!",
168
+ "and I fist‑pumped the air (claps)!"
169
+ ],
170
+ "templates": [
171
+ "{s} {v}, {c}.",
172
+ "{i}, {s} {v}.",
173
+ "{s} {v}. {c}.",
174
+ "{s} {v} — {c}.",
175
+ "{c}. {s} {v}.",
176
+ "{s}. It {v}, {c}.",
177
+ "{s} just... {v}. {c}.",
178
+ "{i} — {s} {v}, {c}.",
179
+ "{s} {v}. I smiled so hard. {c}.",
180
+ "It was simple: {s} {v}. {c}.",
181
+ "And there it was: {s} {v}. {c}.",
182
+ "I couldn’t stop smiling — {s} {v}, {c}.",
183
+ "{s} {v}, and I laughed like a child. {c}.",
184
+ "It all made sense when {s} {v}. {c}.",
185
+ "You wouldn’t believe it: {s} {v}, {c}."
186
+ ]
187
+ }
emotion_templates/neutral.json ADDED
@@ -0,0 +1,97 @@
+ {
2
+ "subjects": [
3
+ "The meeting", "This email", "The weather", "My schedule", "The report",
4
+ "The file", "The documents", "The update", "The list", "The package",
5
+ "That call", "The deadline", "My calendar", "The appointment", "The client",
6
+ "The system", "The form", "The issue", "The request", "This link",
7
+ "The procedure", "That number", "The login", "The ticket", "The reminder",
8
+ "The message", "The time", "The place", "The project", "The plan",
9
+ "The item", "The summary", "The spreadsheet", "The announcement",
10
+ "The room", "The schedule", "The response", "The folder", "The draft",
11
+ "The setup", "The chart", "The template", "This section", "The address",
12
+ "The status", "The link", "The note", "The user", "The change", "The version",
13
+ "The page", "The setting", "The system log", "The data", "The record",
14
+ "The feedback", "The form submission", "The result", "The agenda", "The channel",
15
+ "The section", "The backend", "The field", "The login screen", "The main page",
16
+ "The recent edit", "The input", "The draft proposal", "The video", "The request form",
17
+ "The calendar update", "The hour", "The alert", "The setting change", "The comment",
18
+ "The text box", "The dashboard", "The timezone", "The notification", "The next step",
19
+ "The change log", "The ping", "The tool", "The auto-reply", "The file name",
20
+ "The browser tab", "The screen", "The online form", "The number of users",
21
+ "The archive", "The post", "The backend tool", "The session", "The queue",
22
+ "The timestamp", "The log", "The script", "The recent call", "The API endpoint",
23
+ "The shortcut", "The current state", "The setting value", "The message ID"
24
+ ],
25
+ "verbs": [
26
+ "was updated", "was received", "was sent", "is scheduled", "was delivered",
27
+ "was attached", "is listed", "is included", "is available", "was confirmed",
28
+ "is submitted", "was opened", "is set", "was approved", "was completed",
29
+ "was reviewed", "was filed", "is noted", "was mentioned", "was discussed",
30
+ "was forwarded", "was generated", "was uploaded", "was located", "was exported",
31
+ "was restarted", "was removed", "was closed", "was fixed", "is resolved",
32
+ "is visible", "was highlighted", "was added", "was changed", "was renamed",
33
+ "was stored", "was marked", "was tracked", "was logged", "was selected",
34
+ "was unchecked", "is prefilled", "is functional", "was muted", "was flagged",
35
+ "was logged again", "was copied", "was moved", "was parsed", "was captured",
36
+ "was identified", "was tested", "was analyzed", "was displayed", "was found",
37
+ "was active", "was paused", "was finalized", "was shared", "was cleared",
38
+ "was replaced", "was validated", "was reloaded", "was timed out", "was confirmed again",
39
+ "was repeated", "was archived", "was restarted", "was merged", "was commented on",
40
+ "was clicked", "was created", "was auto-saved", "was viewed", "was hidden",
41
+ "was unchecked", "was dismissed", "was shortened", "was synced", "was noted again",
42
+ "was converted", "was reapproved", "was minimized", "was restored", "was synced again",
43
+ "was edited", "was translated", "was resized", "was expanded", "was called in",
44
+ "was reentered", "was backed up", "was counted", "was referenced", "was linked"
45
+ ],
46
+ "contexts": [
47
+ "and noted accordingly", "as expected", "without any issues", "for review",
48
+ "and is ready", "based on the plan", "according to the request",
49
+ "and marked complete", "for reference", "with no delay",
50
+ "and synced properly", "as part of the update", "and stored correctly",
51
+ "and reviewed again", "without error", "as per instructions",
52
+ "and appears consistent", "with the rest of the files", "in the system",
53
+ "as mentioned earlier", "without further input", "at the same time",
54
+ "on the main screen", "in the report summary", "in the log", "as scheduled",
55
+ "and is linked correctly", "with updated fields", "as noted", "with default settings",
56
+ "on the dashboard", "in the latest build", "with minor differences",
57
+ "with unchanged parameters", "based on previous input", "in the final draft",
58
+ "in the archive", "and visible now", "within limits", "with expected output",
59
+ "as per the instructions", "without manual entry", "without formatting issues",
60
+ "with default layout", "with no duplicates", "and updated again",
61
+ "with no change required", "on the same page", "with standard formatting",
62
+ "in the initial pass", "in the preview", "on request", "under normal load",
63
+ "with corrected values", "in current use", "in the test environment",
64
+ "with current permissions", "after the last restart", "on page load",
65
+ "in the dropdown list", "after refresh", "with no visible difference",
66
+ "under the given conditions", "within tolerance", "for later reference",
67
+ "during the last session", "on confirmation", "as of now", "as reviewed",
68
+ "without alert", "and is part of the bundle", "without recent changes",
69
+ "on repeated use", "in line with policy", "and saved successfully",
70
+ "in its original format", "with logged history", "from default mode",
71
+ "and loaded correctly", "with stable results", "with current context",
72
+ "during the trial", "in admin view", "in preview mode", "from last session"
73
+ ],
74
+ "interjections": [
75
+ "Okay", "Sure", "All right", "Understood", "Makes sense",
76
+ "Got it", "Thanks", "No problem", "Noted", "All good",
77
+ "Sounds fine", "Looks okay", "That works", "Cool", "Done",
78
+ "Let me check", "Confirmed", "Updated", "Logged", "As expected"
79
+ ],
80
+ "templates": [
81
+ "{s} {v}, {c}.",
82
+ "{i}, {s} {v}.",
83
+ "{s} {v}. {c}.",
84
+ "{s} {v} — {c}.",
85
+ "{c}. {s} {v}.",
86
+ "{s} {v}.",
87
+ "{s}. It {v}, {c}.",
88
+ "{i}. {s} {v}, {c}.",
89
+ "Just FYI: {s} {v}, {c}.",
90
+ "Note: {s} {v}, {c}.",
91
+ "{s} {v}. No further action required.",
92
+ "Update: {s} {v}, {c}.",
93
+ "For reference, {s} {v}.",
94
+ "Also, {s} {v}.",
95
+ "{s} {v} as planned. {c}."
96
+ ]
97
+ }
emotion_templates/sad.json ADDED
@@ -0,0 +1,183 @@
+ {
2
+ "subjects": [
3
+ "That memory", "This feeling", "Her silence", "His voice", "The goodbye",
4
+ "That call", "This day", "My past", "That old photo", "The message",
5
+ "What she said", "What he didn’t say", "The look in their eyes", "That time of year",
6
+ "The room we used to share", "That place", "The last conversation", "That empty chair",
7
+ "That letter", "The moment I knew", "That one decision", "The song we both liked",
8
+ "My mistake", "Her absence", "The words I never said", "The things I lost",
9
+ "The silence after", "His last text", "My regret", "That broken promise",
10
+ "That question", "The goodbye hug", "The final message", "The void", "The shadow of it all",
11
+ "My own thoughts", "The long nights", "That walk alone", "Her last smile",
12
+ "That feeling of loss", "The unread message", "The sound of her name", "The forgotten laugh",
13
+ "The unspoken words", "This moment", "The fading photo", "The conversation that never came",
14
+ "The look I still remember", "The missed call", "This quiet", "What we never had",
15
+ "His empty side", "The empty inbox", "The things left unsaid", "The cracked frame",
16
+ "That empty promise", "The silence in the car", "The distance", "The space between us",
17
+ "The moment everything changed", "The unmade bed", "This weight", "The last gift",
18
+ "The things I should’ve done", "This cold air", "That long pause", "The truth I denied",
19
+ "The tears I hid", "The smile I faked", "The echo in my chest", "The memories flooding back",
20
+ "The half-empty cup", "That slow fade", "That goodbye text", "The waiting",
21
+ "That unanswered question", "This morning", "His coat", "The words on the page",
22
+ "That familiar place", "The night we stopped talking", "The unread letters",
23
+ "The unsent messages", "That photo on my phone", "That song I can’t skip",
24
+ "That cold seat", "The quiet kitchen", "The dim light", "That empty room",
25
+ "The thing I never said", "The end", "That hallway", "The couch we sat on",
26
+ "The words that hurt", "The moment we broke", "That one night", "The feeling of being left",
27
+ "What I carry", "That date", "The thing I remember most", "The after"
28
+ ],
29
+ "verbs": [
30
+ "still hurts", "breaks me", "keeps me awake", "never fades", "hits me hard",
31
+ "leaves me numb", "crushed me", "stays with me", "burns inside", "tears me apart",
32
+ "weighs on me", "makes me cry", "won’t let go", "haunts me", "hurts deeply",
33
+ "shatters me", "lingers", "cuts deep", "took a part of me", "sits heavy in my chest",
34
+ "aches quietly", "feels endless", "never really left", "drowns me", "pulls me down",
35
+ "brings tears", "holds me still", "makes my chest heavy", "chills me",
36
+ "drains me", "fills me with sorrow", "leaves a mark", "empties me", "blurs everything",
37
+ "mutes my world", "dims my light", "echoes loudly", "steals my breath",
38
+ "closes my throat", "breaks the silence", "pushes me down", "holds me back",
39
+ "reminds me", "sticks like glue", "crawls under my skin", "hangs over me",
40
+ "wraps around me", "dampens everything", "stings", "feels like drowning",
41
+ "pulls at my heart", "sinks deep", "won’t leave", "burns behind my eyes",
42
+ "won’t let me forget", "drags on", "never ends", "feels permanent",
43
+ "presses on my chest", "washes over me", "never gets easier", "finds me in the quiet",
44
+ "catches me off guard", "unravels me", "pours in like rain", "sits in my lungs",
45
+ "breaks the morning", "bleeds into my days", "clouds my thoughts", "chokes me softly",
46
+ "grips my stomach", "twists my mind", "sinks into my bones", "grays out the world",
47
+ "softens my voice", "shadows my smile", "brings a wave", "holds my breath",
48
+ "dims the sun", "keeps repeating", "echoes through", "shuts me down",
49
+ "darkens the room", "muffles everything", "screams inside", "keeps me from moving",
50
+ "follows me", "closes in", "blurs my days", "fogs my memory", "empties the joy",
51
+ "clutches my chest", "keeps showing up", "presses into me", "sits beside me",
52
+ "floods my eyes", "weighs down my steps", "never softens", "never truly heals"
53
+ ],
54
+ "interjections": [
55
+ "(sighs)... I still can't believe it.",
56
+ "Some things never leave (sniffs).",
57
+ "(mumbles) Honestly...",
58
+ "(exhales) It just hurts.",
59
+ "I wish I could say it (sighs)...",
60
+ "If only they knew (inhales)...",
61
+ "(clears throat) No one sees it.",
62
+ "Deep down (sniffs)...",
63
+ "Truth is — (sighs) I’m still there.",
64
+ "Even now... (exhales)",
65
+ "(mumbles) I never said it.",
66
+ "(sniffs)... It stays with me.",
67
+ "Sometimes I wonder (sighs)...",
68
+ "I still carry it (exhales).",
69
+ "To this day — (sniffs) it aches.",
70
+ "(sighs) I remember everything.",
71
+ "Some pain doesn’t show (mumbles)...",
72
+ "Even in silence (exhales)...",
73
+ "That moment still plays (sniffs).",
74
+ "And yet I smile (sighs)..."
75
+ ],
76
+
77
+ "contexts": [
78
+ "(sighs)... and I can’t move on.",
79
+ "like nothing will ever be the same (exhales).",
80
+ "and I keep holding it in (sniffs)...",
81
+ "though no one sees it (sighs).",
82
+ "and I hide it so well (inhales)...",
83
+ "and I still feel it all (exhales).",
84
+ "and I smile like I’m fine (mumbles)...",
85
+ "though I’m breaking inside (sighs).",
86
+ "but I never talk about it (sniffs).",
87
+ "and it follows me everywhere (exhales)...",
88
+ "and I still carry it (sighs)...",
89
+ "like it just happened (inhales).",
90
+ "and I don’t know how to forget (exhales)...",
91
+ "and no one knows the weight (sighs).",
92
+ "even when I laugh (sniffs)...",
93
+ "and it feels so unfair (exhales).",
94
+ "and it still ruins me (sighs)...",
95
+ "and I feel it every day (inhales).",
96
+ "but I try to act normal (mumbles)...",
97
+ "like I’m supposed to be okay (sighs).",
98
+ "and I never got closure (exhales)...",
99
+ "and I miss them every day (sniffs).",
100
+ "but the pain stays the same (sighs)...",
101
+ "and it’s always right there (exhales).",
102
+ "and it floods me suddenly (sniffs)...",
103
+ "and it breaks me in pieces (sighs).",
104
+ "and I hear the echo of it (exhales)...",
105
+ "and I wish I said more (sighs).",
106
+ "and I regret staying quiet (mumbles)...",
107
+ "and I feel so small (sniffs).",
108
+ "and I can’t stop thinking about it (exhales)...",
109
+ "and I just collapse (sighs).",
110
+ "and I keep blaming myself (inhales)...",
111
+ "but it wasn’t even my fault (exhales).",
112
+ "and I replay it over and over (sighs)...",
113
+ "and it plays like a movie (sniffs).",
114
+ "but I never stop hurting (exhales)...",
115
+ "and I keep pretending (sighs).",
116
+ "and no one asks (mumbles)...",
117
+ "and I hate that it’s still here (exhales).",
118
+ "but I can’t let it go (sighs)...",
119
+ "even when I want to (sniffs).",
120
+ "and I can’t breathe sometimes (exhales)...",
121
+ "and my chest aches quietly (sighs).",
122
+ "and it keeps resurfacing (sniffs)...",
123
+ "and I’m so tired of it (exhales).",
124
+ "and I don’t know how to heal (sighs)...",
125
+ "and it knocks the wind out of me (exhales).",
126
+ "and I feel like a ghost (sighs)...",
127
+ "and I smile while breaking (mumbles).",
128
+ "and I don’t talk about it anymore (sniffs)...",
129
+ "and I wake up with it (exhales).",
130
+ "and it seeps into everything (sighs)...",
131
+ "and I lost myself there (exhales).",
132
+ "but no one notices (sighs)...",
133
+ "and I keep it locked up (sniffs).",
134
+ "and I carry it alone (exhales)...",
135
+ "and it’s always just under the surface (sighs).",
136
+ "but I’m still hurting (exhales)...",
137
+ "and it’s louder in the silence (sighs).",
138
+ "and I fall apart slowly (sniffs)...",
139
+ "and I can’t fix it (exhales).",
140
+ "and I still wait for them (sighs)...",
141
+ "and it all feels empty (exhales).",
142
+ "and it doesn’t make sense (sighs)...",
143
+ "but I can’t forget (sniffs).",
144
+ "and it makes me feel hollow (exhales)...",
145
+ "but I carry on anyway (sighs).",
146
+ "and I whisper it to no one (mumbles)...",
147
+ "and I miss who I was (exhales).",
148
+ "and I don’t think it’ll ever stop (sighs)...",
149
+ "and it comes back at night (sniffs).",
150
+ "and it makes me feel cold (exhales)...",
151
+ "but I still try to be okay (sighs).",
152
+ "and I hate talking about it (inhales)...",
153
+ "and I still cry when I’m alone (sniffs).",
154
+ "and I lose words (sighs)...",
155
+ "and I feel fragile (exhales).",
156
+ "and it’s exhausting (sighs)...",
157
+ "but I just nod and smile (mumbles).",
158
+ "and it feels like I’m fading (exhales)...",
159
+ "and it eats away at me (sighs).",
160
+ "but I say nothing (sniffs)...",
161
+ "and I hope they never see (sighs).",
162
+ "and I fake my way through (exhales)...",
163
+ "but I’m breaking slowly (sighs)."
164
+ ],
165
+
166
+ "templates": [
167
+ "{s} {v}, {c}.",
168
+ "{i}, {s} {v}.",
169
+ "{s} {v}. {c}.",
170
+ "{s} {v} — {c}.",
171
+ "{c}. {s} {v}.",
172
+ "{s}. It {v}, {c}.",
173
+ "{s} still... {v}. {c}.",
174
+ "{s} {v}. And it never ends. {c}.",
175
+ "I try to forget, but {s} {v}, {c}.",
176
+ "{i} — {s} {v}, {c}.",
177
+ "{s} {v}. {i}. {c}.",
178
+ "The truth is, {s} {v}.",
179
+ "‘{s} {v}’ — and no one knew. {c}.",
180
+ "{s} just doesn’t go away. {c}.",
181
+ "Even now, {s} {v}, {c}."
182
+ ]
183
+ }
emotion_templates/surprise.json ADDED
@@ -0,0 +1,198 @@
+ {
2
+ "subjects": [
3
+ "That message", "This call", "The result", "What she said", "The announcement",
4
+ "The coincidence", "The gift", "His answer", "Her reaction", "The timing",
5
+ "That moment", "The photo", "This discovery", "The unexpected reply",
6
+ "The look on his face", "The email", "That visit", "The headline",
7
+ "The glitch", "This outcome", "The news", "That thing", "The realization",
8
+ "Her arrival", "His silence", "That sound", "That noise", "The blink",
9
+ "This miracle", "The coincidence of dates", "The surprise guest",
10
+ "The reversed roles", "This twist", "The window", "The echo", "The light",
11
+ "The turning point", "The expression", "The instant reply", "The notification",
12
+ "This random thing", "The unexpected gift", "The open door", "That trigger",
13
+ "The unlocked phone", "The old message", "That letter", "The voicemail",
14
+ "The opportunity", "The moment of silence", "The call from nowhere",
15
+ "The found photo", "That chance meeting", "What I saw", "The change in tone",
16
+ "The open drawer", "The sound in the hallway", "The empty room",
17
+ "The random response", "Her sudden shift", "The blinking lights",
18
+ "What he did", "That conversation", "The unplanned moment", "The whisper",
19
+ "That sudden feeling", "That name", "This vibe", "That expression",
20
+ "The door creaking", "The unexpected question", "The win", "This strange timing",
21
+ "The unplugged cable", "The room's silence", "The laugh", "The pause",
22
+ "The sudden music", "The uninvited guest", "The news headline", "The page turn",
23
+ "The camera flash", "The unknown voice", "The stray animal", "The forgotten date",
24
+ "That coincidence", "The twisted sentence", "That eye contact", "This strange sound",
25
+ "The falling object", "The flicker", "The joke", "That mood shift", "This noise",
26
+ "The silent phone", "That little thing", "The hidden truth", "The strange smile"
27
+ ],
28
+ "verbs": [
29
+ "shocked me", "made me freeze", "caught me off guard", "left me speechless",
30
+ "blew my mind", "came out of nowhere", "was totally unexpected", "made my jaw drop",
31
+ "felt surreal", "threw me off", "turned everything around", "changed the whole vibe",
32
+ "reset the tone", "hit me like lightning", "flipped the moment", "was unreal",
33
+ "was bizarre", "seemed impossible", "was both funny and scary", "made me pause",
34
+ "rewired my brain", "was absurd", "challenged everything I thought", "was chaotic",
35
+ "made no sense", "was random", "came with no warning", "felt like a dream",
36
+ "sparked confusion", "was like fiction", "shattered expectations", "reversed my mood",
37
+ "surprised everyone", "felt impossible", "disrupted everything", "popped up suddenly",
38
+ "turned heads", "was pure surprise", "hit like a flash", "flipped the whole thing",
39
+ "triggered laughter", "was hilariously off", "felt like a glitch", "reset the energy",
40
+ "dropped like a bomb", "silenced the room", "was a twist", "stood out completely",
41
+ "just happened", "was nonsense", "was oddly fitting", "was genius", "twisted my logic",
42
+ "shook everyone", "sparked wild reactions", "caught all attention", "stunned me",
43
+ "was beautifully weird", "was madness", "reversed all logic", "was insane",
44
+ "left us reeling", "had no explanation", "felt electric", "hit mid-sentence",
45
+ "made me sit up", "just flipped it", "was comedic", "was magical", "reset my brain",
46
+ "was silly and scary", "snapped everything", "twisted expectations", "felt improvised",
47
+ "landed unexpectedly", "snuck up on us", "was so weird", "broke the silence",
48
+ "sparked awe", "felt foreign", "felt glitched", "confused everyone", "was strangely real",
49
+ "made me react instantly", "was unexpected chaos", "took a wild turn", "made no sense at all",
50
+ "came suddenly", "hit without logic", "felt reversed", "cut through everything",
51
+ "changed the pace", "felt timed perfectly", "felt unscripted", "just crashed in",
52
+ "was the last thing I thought", "bent the moment", "came sharply", "felt like a prank"
53
+ ],
54
+ "interjections": [
55
+ "(gasps) Honestly?!",
56
+ "No way — (laughs)!",
57
+ "Wait... what? (gasps)",
58
+ "(inhales) Seriously—",
59
+ "Oh wow! (whistles)",
60
+ "You're kidding! (laughs)",
61
+ "(gasps) I can't believe it!",
62
+ "What just happened? (chuckle)",
63
+ "Believe it or not — (gasps)",
64
+ "Suddenly... (laughs)",
65
+ "(whistles) Guess what!",
66
+ "Right then — (gasps) boom!",
67
+ "Out of nowhere (laughs)!",
68
+ "Just like that... (gasps)",
69
+ "Then — bam! (chuckle)",
70
+ "(gasps) Without warning!",
71
+ "For real? (laughs)",
72
+ "You won't believe this! (gasps)",
73
+ "Hold up— (whistles)!",
74
+ "(claps) That just happened!"
75
+ ],
76
+
77
+ "contexts": [
78
+ "(gasps)... and I didn’t know what to say!",
79
+ "and everyone froze — (exhales).",
80
+ "and I had to laugh (laughs)!",
81
+ "and I still don’t get it... (chuckle)",
82
+ "and it felt surreal (whistles).",
83
+ "and no one moved — (inhales).",
84
+ "(gasps) and I just... gasped again.",
85
+ "and I was speechless (exhales)...",
86
+ "and I blinked like five times (gasps).",
87
+ "and it echoed in my head — (whistles)",
88
+ "and everyone just stared (inhales)...",
89
+ "and I sat there stunned (gasps)!",
90
+ "and the room went silent — (exhales).",
91
+ "and we all looked around (chuckle)...",
92
+ "and I looked twice (gasps)!",
93
+ "and I was like “what?!” (laughs)",
94
+ "and I couldn’t believe it (exhales)...",
95
+ "and it stuck with me (whistles).",
96
+ "and I still think about it (gasps)...",
97
+ "and I just shook my head (laughs).",
98
+ "and it broke the tension — (chuckle)!",
99
+ "and it threw off everything (exhales)...",
100
+ "and I said “no way!” (gasps)",
101
+ "and we all cracked up (laughs)!",
102
+ "and the vibe was gone — (inhales).",
103
+ "and it changed the day (whistles)...",
104
+ "and I had to rewind (gasps)!",
105
+ "and my jaw dropped — (exhales).",
106
+ "and I froze (gasps)...",
107
+ "and we all lost it (laughs)!",
108
+ "and it messed me up (chuckle)...",
109
+ "and I just stood still (inhales).",
110
+ "and I had no words (gasps)...",
111
+ "and we laughed for a minute (laughs).",
112
+ "and it felt like a dream — (whistles)...",
113
+ "and I blinked in shock (gasps)!",
114
+ "and we all paused (exhales)...",
115
+ "and I just stared (gasps)...",
116
+ "and I couldn’t speak (inhales).",
117
+ "and the shift was instant (gasps)!",
118
+ "and it was hilarious (laughs).",
119
+ "and my brain glitched (chuckle)...",
120
+ "and everyone panicked slightly (gasps).",
121
+ "and I had chills (inhales)...",
122
+ "and it just... happened (exhales).",
123
+ "and it triggered everything (gasps)!",
124
+ "and I remember the silence (whistles)...",
125
+ "and someone whispered “what just happened?” (gasps)",
126
+ "and we all froze (inhales)...",
127
+ "and we burst out laughing (laughs)!",
128
+ "and it flipped the room — (whistles)!",
129
+ "and I had to sit (gasps)...",
130
+ "and it made no sense (chuckle).",
131
+ "and I gasped loudly (gasps)!",
132
+ "and everyone blinked (inhales)...",
133
+ "and I had to walk out (exhales).",
134
+ "and the energy just changed (whistles)...",
135
+ "and it felt like fiction (gasps).",
136
+ "and the moment stuck — (exhales)...",
137
+ "and I tried to process it (chuckle).",
138
+ "and I said nothing (gasps)...",
139
+ "and the response was just shock (inhales).",
140
+ "and I heard my own heartbeat (exhales)...",
141
+ "and we just blinked (gasps).",
142
+ "and it rewrote the mood (whistles)...",
143
+ "and it still plays in my head (gasps)...",
144
+ "and it made the silence louder (exhales).",
145
+ "and I had to sit down (inhales)...",
146
+ "and it just hit me (gasps)!",
147
+ "and the air shifted — (whistles)...",
148
+ "and it startled me (gasps).",
149
+ "and everyone had the same face (chuckle)...",
150
+ "and I was out of words (exhales).",
151
+ "and someone gasped (gasps)!",
152
+ "and the tension cracked (laughs)...",
153
+ "and it made me laugh hard (laughs)!",
154
+ "and we all paused and stared (inhales)...",
155
+ "and I just looked down (exhales).",
156
+ "and we replayed it in our heads (whistles)...",
157
+ "and no one said a word (gasps).",
158
+ "and the timing was wild — (laughs)!",
159
+ "and it was too real (inhales)...",
160
+ "and we couldn’t stop laughing (laughs)!",
161
+ "and I shook my head slowly (chuckle)...",
162
+ "and it crashed everything (gasps).",
163
+ "and I wanted to rewind it (exhales)...",
164
+ "and it threw the moment off (inhales).",
165
+ "and we sat in silence (gasps)...",
166
+ "and my eyes went wide (exhales).",
167
+ "and it stunned everyone (gasps)!",
168
+ "and it landed like a bomb (laughs).",
169
+ "and I burst out mid‑thought (chuckle)...",
170
+ "and it shattered logic (gasps).",
171
+ "and I couldn’t unsee it (exhales)...",
172
+ "and the shift was visible (whistles).",
173
+ "and we blinked hard (gasps)..."
174
+ ],
175
+
176
+ "templates": [
177
+ "{s} {v}, {c}.",
178
+ "{i}! {s} {v}.",
179
+ "{s} {v}. {c}.",
180
+ "{s} {v} — {c}.",
181
+ "{c}. {s} {v}.",
182
+ "{s}. It {v}, {c}.",
183
+ "{s} just... {v}. {c}.",
184
+ "You won’t believe this: {s} {v}, {c}!",
185
+ "And then, {s} {v}. {c}.",
186
+ "‘{s} {v}?’ That’s what happened. {c}.",
187
+ "{i}! — {s} {v}, {c}!",
188
+ "Everyone was quiet. Then {s} {v}, {c}!",
189
+ "{s} {v}. {i}! {c}.",
190
+ "{s} {v}. Can you imagine?! {c}.",
191
+ "So random: {s} {v}, {c}.",
192
+ "Boom! {s} {v}, {c}!",
193
+ "Suddenly, {s} {v}! {c}.",
194
+ "{s} {v} like a lightning bolt! {c}.",
195
+ "I blinked and {s} {v}! {c}.",
196
+ "Out of thin air — {s} {v}! {c}."
197
+ ]
198
+ }
generate_emotion_texts_dataset.py ADDED
@@ -0,0 +1,137 @@
+ import json
+ import random
+ import pandas as pd
+ import re
+ from datetime import timedelta
+ from pathlib import Path
+ 
+ # === Load the phrase templates ===
+ def load_templates_json(templates_dir, emotion):
+     path = Path(templates_dir) / f"{emotion}.json"
+     if not path.exists():
+         raise FileNotFoundError(f"Template for emotion '{emotion}' not found: {path}")
+     with open(path, "r", encoding="utf-8") as f:
+         return json.load(f)
+ 
+ # === Generate texts with a fixed seed and duplicate protection ===
+ def generate_emotion_batch(n, template_data, seed=None):
+     if seed is not None:
+         random.seed(seed)
+ 
+     subjects = template_data["subjects"]
+     verbs = template_data["verbs"]
+     contexts = template_data["contexts"]
+     interjections = template_data.get("interjections", [""])
+     templates = template_data["templates"]
+ 
+     # Sound tags accepted by DIA-TTS
+     dia_tags = {
+         "(laughs)", "(clears throat)", "(sighs)", "(gasps)", "(coughs)",
+         "(singing)", "(sings)", "(mumbles)", "(beep)", "(groans)", "(sniffs)",
+         "(claps)", "(screams)", "(inhales)", "(exhales)", "(applause)",
+         "(burps)", "(humming)", "(sneezes)", "(chuckle)", "(whistles)"
+     }
+ 
+     def has_tag(text): return any(tag in text for tag in dia_tags)
+     def remove_tags(text):
+         for tag in dia_tags:
+             text = text.replace(tag, "")
+         return text.strip()
+ 
+     phrases, attempts = set(), 0
+     max_attempts = n * 50
+ 
+     while len(phrases) < n and attempts < max_attempts:
+         s, v = random.choice(subjects), random.choice(verbs)
+         c, i = random.choice(contexts), random.choice(interjections)
+         t = random.choice(templates)
+ 
+         # ▸ Allow at most one sound tag per phrase
+         if has_tag(i) and has_tag(c):
+             if random.random() < .5:
+                 c = remove_tags(c)
+             else:
+                 i = remove_tags(i)
+ 
+         phrase = t.format(s=s, v=v, c=c, i=i)
+ 
+         # --- Cleanup that keeps ellipses intact ---------------------------
+         # 1) drop spaces before punctuation
+         phrase = re.sub(r"\s+([,.!?])", r"\1", phrase)
+         # 2) collapse a double period that is NOT part of an ellipsis
+         phrase = re.sub(r"(?<!\.)\.\.(?!\.)", ".", phrase)
+         # 3) insert a space when a word directly follows a sound tag
+         phrase = re.sub(r"\)(?=\w)", ") ", phrase)
+         # 4) collapse repeated spaces and trim the edges
+         phrase = re.sub(r"\s{2,}", " ", phrase).strip()
+         # ------------------------------------------------------------------
+ 
+         if phrase not in phrases:
+             phrases.add(phrase)
+         attempts += 1
+ 
+     if len(phrases) < n:
+         print(f"⚠️ Only {len(phrases)} unique phrases out of {n} requested — the template pool may be exhausted.")
+ 
+     return list(phrases)
+ 
+ # === Generate dummy timestamps ===
+ def generate_dummy_timestamps(n):
+     base_time, result = timedelta(), []
+     for idx in range(n):
+         start = base_time + timedelta(seconds=idx * 6)
+         end = start + timedelta(seconds=5)
+         result.append((
+             str(start).split(".")[0] + ",000",
+             str(end).split(".")[0] + ",000"
+         ))
+     return result
+ 
+ # === Assemble and save the final CSV ===
+ def create_emotion_csv(template_path, emotion_label, out_file, n=1000, seed=None):
+     data = load_templates_json(template_path, emotion_label)
+     phrases = generate_emotion_batch(n, data, seed)
+     timeline = generate_dummy_timestamps(n)
+ 
+     emotions = ["neutral", "happy", "sad", "anger", "surprise", "disgust", "fear"]
+     label_mask = {e: float(e == emotion_label) for e in emotions}
+ 
+     df = pd.DataFrame({
+         "video_name": [f"dia_{emotion_label}_utt{i}_synt" for i in range(n)],
+         "start_time": [s for s, _ in timeline],
+         "end_time"  : [e for _, e in timeline],
+         "sentiment" : [0] * n,
+         **{e: [label_mask[e]] * n for e in emotions},
+         "text"      : phrases
+     })
+ 
+     df.to_csv(out_file, index=False)
+     print(f"✅ Saved {len(df)} rows → {out_file}")
+ 
+     # --- Duplicate check ---
+     dupes = df[df.duplicated("text", keep=False)]
+     if not dupes.empty:
+         dupe_file = Path(out_file).with_name(f"duplicates_{emotion_label}.csv")
+         dupes.to_csv(dupe_file, index=False)
+         print(f"⚠️ Found {len(dupes)} duplicate rows → {dupe_file}")
+     else:
+         print("✅ No duplicates.")
+ 
+ # === Entry point ===
+ if __name__ == "__main__":
+     emotion_config = {
+         "anger": 3600,
+         "disgust": 4438,
+         "fear": 4441,
+         "happy": 2966,
+         "sad": 4026,
+         "surprise": 3504
+     }
+ 
+     seed, template_path, out_dir = 42, "emotion_templates", "synthetic_data"
+     Path(out_dir).mkdir(parents=True, exist_ok=True)
+ 
+     for emotion, n in emotion_config.items():
+         out_csv = Path(out_dir) / f"meld_synthetic_{emotion}_{n}.csv"
+         print(f"\n🔄 Generating: {emotion} ({n} phrases)")
+         create_emotion_csv(template_path, emotion, str(out_csv), n, seed)
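The script writes one CSV per emotion into `synthetic_data/`, with a one-hot mask over the seven emotion columns next to the generated `text`. A small sanity-check sketch for one of the files it produces (the anger file name matches the counts in `emotion_config`; the assertions are assumptions about what a healthy output should satisfy):

    import pandas as pd

    df = pd.read_csv("synthetic_data/meld_synthetic_anger_3600.csv")
    emotions = ["neutral", "happy", "sad", "anger", "surprise", "disgust", "fear"]

    assert df["text"].is_unique                      # duplicate filter held
    assert (df[emotions].sum(axis=1) == 1.0).all()   # exactly one active emotion per row
    print(df[["video_name", "start_time", "end_time", "text"]].head())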
generate_synthetic_dataset.py ADDED
@@ -0,0 +1,71 @@
+ import os
+ import logging
+ import time
+ import pandas as pd
+ from multiprocessing import Process
+ from synthetic_utils.dia_tts_wrapper import DiaTTSWrapper
+ 
+ 
+ def process_chunk(chunk_df, emotion, wav_dir, device, chunk_id):
+     tts = DiaTTSWrapper(device=device)
+     for idx, row in chunk_df.iterrows():
+         text = row["text"]
+         video_name = row.get("video_name", f"{emotion}_{chunk_id}_{idx}")
+         filename_prefix = video_name
+ 
+         try:
+             result = tts.generate_and_save_audio(
+                 text=text,
+                 out_dir=wav_dir,
+                 filename_prefix=filename_prefix,
+                 use_timestamp=False,
+                 skip_if_exists=True,
+                 max_trim_duration=10.0
+             )
+             if result is None:
+                 logging.info(f"[{emotion}] ⏭️ Skipped: {filename_prefix}.wav")
+             else:
+                 logging.info(f"[{emotion}] ✔ {filename_prefix}.wav")
+         except Exception as e:
+             logging.error(f"[{emotion}] ❌ Error: {filename_prefix} — {e}")
+ 
+ 
+ def generate_from_emotion_csv(
+     csv_path: str,
+     emotion: str,
+     output_dir: str,
+     device: str = "cuda",
+     max_samples: int = None,
+     num_processes: int = 1
+ ):
+     out_dir = os.path.join(output_dir, emotion)
+     wav_dir = os.path.join(out_dir, "wavs")
+     os.makedirs(wav_dir, exist_ok=True)
+ 
+     logging.info(f"🎙️ Emotion: '{emotion}' | CSV: {csv_path}")
+     logging.info(f"📥 Saving to: {wav_dir}")
+ 
+     df = pd.read_csv(csv_path)
+     if max_samples is not None:
+         df = df.sample(n=max_samples)
+ 
+     chunk_size = len(df) // num_processes
+     chunks = [df.iloc[i*chunk_size : (i+1)*chunk_size] for i in range(num_processes)]
+ 
+     remainder = len(df) % num_processes
+     if remainder > 0:
+         chunks[-1] = pd.concat([chunks[-1], df.iloc[-remainder:]])
+ 
+     total_start = time.time()
+ 
+     processes = []
+     for i, chunk in enumerate(chunks):
+         p = Process(target=process_chunk, args=(chunk, emotion, wav_dir, device, i))
+         p.start()
+         processes.append(p)
+ 
+     for p in processes:
+         p.join()
+ 
+     total_elapsed = time.time() - total_start
+     logging.info(f"✅ Emotion '{emotion}' finished | chunks: {num_processes} | ⏱️ {total_elapsed:.1f} s\n")
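This module only defines the worker and the dispatcher; a hedged usage sketch follows. The CSV path, output directory and sample count are placeholders, and because the dispatcher spawns `multiprocessing.Process` workers, the call has to sit under a `__main__` guard:

    import logging
    from generate_synthetic_dataset import generate_from_emotion_csv

    if __name__ == "__main__":
        logging.basicConfig(level=logging.INFO)
        generate_from_emotion_csv(
            csv_path="synthetic_data/meld_synthetic_anger_3600.csv",  # placeholder path
            emotion="anger",
            output_dir="synthetic_audio",                             # placeholder output root
            device="cuda",
            max_samples=100,    # small smoke-test subset
            num_processes=2
        )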
main.py ADDED
@@ -0,0 +1,119 @@
+ # train.py
2
+ # coding: utf-8
3
+ import logging
4
+ import os
5
+ import shutil
6
+ import datetime
7
+ import whisper
8
+ import toml
9
+ # os.environ["HF_HOME"] = "models"
10
+
11
+ from utils.config_loader import ConfigLoader
12
+ from utils.logger_setup import setup_logger
13
+ from utils.search_utils import greedy_search, exhaustive_search
14
+ from training.train_utils import (
15
+ make_dataset_and_loader,
16
+ train_once
17
+ )
18
+ from data_loading.feature_extractor import PretrainedAudioEmbeddingExtractor, PretrainedTextEmbeddingExtractor
19
+
20
+ def main():
21
+
22
+ # Грузим конфиг
23
+ base_config = ConfigLoader("config.toml")
24
+
25
+ model_name = base_config.model_name.replace("/", "_").replace(" ", "_").lower()
26
+ timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
27
+ results_dir = f"results_{model_name}_{timestamp}"
28
+ os.makedirs(results_dir, exist_ok=True)
29
+
30
+ epochlog_dir = os.path.join(results_dir, "metrics_by_epoch")
31
+ os.makedirs(epochlog_dir, exist_ok=True)
32
+
33
+ # Настраиваем logging
34
+ log_file = os.path.join(results_dir, "session_log.txt")
35
+ setup_logger(logging.DEBUG, log_file=log_file)
36
+
37
+ # Грузим конфиг
38
+ base_config.show_config()
39
+
40
+ shutil.copy("config.toml", os.path.join(results_dir, "config_copy.toml"))
41
+ # Файл, куда будет писать наш жадный поиск
42
+ overrides_file = os.path.join(results_dir, "overrides.txt")
43
+ csv_prefix = os.path.join(epochlog_dir, "metrics_epochlog")
44
+
45
+ audio_feature_extractor= PretrainedAudioEmbeddingExtractor(base_config)
46
+ text_feature_extractor = PretrainedTextEmbeddingExtractor(base_config)
47
+
48
+ # Инициализируем Whisper-модель один раз
49
+ logging.info(f"Инициализация Whisper: модель={base_config.whisper_model}, устройство={base_config.whisper_device}")
50
+ whisper_model = whisper.load_model(base_config.whisper_model, device=base_config.whisper_device)
51
+
52
+ # Делаем датасеты/лоадеры
53
+ # Общий train_loader
54
+ _, train_loader = make_dataset_and_loader(base_config, "train", audio_feature_extractor, text_feature_extractor, whisper_model)
55
+
56
+ # Раздельные dev/test
57
+ dev_loaders = []
58
+ test_loaders = []
59
+
60
+ for dataset_name in base_config.datasets:
61
+ _, dev_loader = make_dataset_and_loader(base_config, "dev", audio_feature_extractor, text_feature_extractor, whisper_model, only_dataset=dataset_name)
62
+ _, test_loader = make_dataset_and_loader(base_config, "test", audio_feature_extractor, text_feature_extractor, whisper_model, only_dataset=dataset_name)
63
+
64
+ dev_loaders.append((dataset_name, dev_loader))
65
+ test_loaders.append((dataset_name, test_loader))
66
+
67
+ if base_config.prepare_only:
68
+ logging.info("== Режим prepare_only: только подготовка данных, без обучения ==")
69
+ return
70
+
71
+ search_config = toml.load("search_params.toml")
72
+ param_grid = dict(search_config["grid"])
73
+ default_values = dict(search_config["defaults"])
74
+
75
+ if base_config.search_type == "greedy":
76
+ greedy_search(
77
+ base_config = base_config,
78
+ train_loader = train_loader,
79
+ dev_loader = dev_loaders,
80
+ test_loader = test_loaders,
81
+ train_fn = train_once,
82
+ overrides_file = overrides_file,
83
+ param_grid = param_grid,
84
+ default_values = default_values,
85
+ csv_prefix = csv_prefix
86
+ )
87
+
88
+ elif base_config.search_type == "exhaustive":
89
+ exhaustive_search(
90
+ base_config = base_config,
91
+ train_loader = train_loader,
92
+ dev_loader = dev_loaders,
93
+ test_loader = test_loaders,
94
+ train_fn = train_once,
95
+ overrides_file = overrides_file,
96
+ param_grid = param_grid,
97
+ csv_prefix = csv_prefix
98
+ )
99
+
100
+ elif base_config.search_type == "none":
101
+ logging.info("== Режим одиночной тренировки (без поиска параметров) ==")
102
+
103
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
104
+ csv_file_path = f"{csv_prefix}_single_{timestamp}.csv"
105
+
106
+ train_once(
107
+ config = base_config,
108
+ train_loader = train_loader,
109
+ dev_loaders = dev_loaders,
110
+ test_loaders = test_loaders,
111
+ metrics_csv_path = csv_file_path
112
+ )
113
+
114
+ else:
115
+ raise ValueError(f"⛔️ Неверное значение search_type в конфиге: '{base_config.search_type}'. Используй 'greedy', 'exhaustive' или 'none'.")
116
+
117
+
118
+ if __name__ == "__main__":
119
+ main()
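A minimal sketch of the structures that greedy_search/exhaustive_search above consume. The section names [grid] and [defaults] match the toml.load call in main(); the parameter names and values below are purely illustrative and are not taken from the repository's actual search_params.toml:

    # search_params.toml (illustrative content only)
    # [grid]
    # hidden_dim = [256, 512]
    # dropout = [0.1, 0.2]
    # [defaults]
    # hidden_dim = 512
    # dropout = 0.1
    #
    # which toml.load(...) plus dict(...) turn into plain Python dicts:
    param_grid = {"hidden_dim": [256, 512], "dropout": [0.1, 0.2]}   # dict(search_config["grid"])
    default_values = {"hidden_dim": 512, "dropout": 0.1}             # dict(search_config["defaults"])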
models/__init__.py ADDED
File without changes
models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (134 Bytes). View file
 
models/__pycache__/help_layers.cpython-310.pyc ADDED
Binary file (15.3 kB). View file
 
models/__pycache__/models.cpython-310.pyc ADDED
Binary file (28.9 kB). View file
 
models/help_layers.py ADDED
@@ -0,0 +1,528 @@
1
+ # coding: utf-8
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.nn.init as init
6
+ import numpy as np
7
+ import math
8
+ from torch.nn.functional import silu
9
+ from torch.nn.functional import softplus
10
+ from einops import rearrange, einsum
11
+ from torch import Tensor
12
+ from torch_geometric.nn import GATConv, RGCNConv, TransformerConv
13
+
14
+ class PositionWiseFeedForward(nn.Module):
15
+ def __init__(self, input_dim, hidden_dim, dropout=0.1):
16
+ super().__init__()
17
+ self.layer_1 = nn.Linear(input_dim, hidden_dim)
18
+ self.layer_2 = nn.Linear(hidden_dim, input_dim)
19
+ self.dropout = nn.Dropout(dropout)
20
+
21
+ def forward(self, x):
22
+ x = self.layer_1(x)
23
+ x = F.gelu(x)  # smoother activation than ReLU
24
+ x = self.dropout(x)
25
+ return self.layer_2(x)
26
+
27
+
28
+ class AddAndNorm(nn.Module):
29
+ def __init__(self, input_dim, dropout=0.1):
30
+ super().__init__()
31
+ self.norm = nn.LayerNorm(input_dim)
32
+ self.dropout = nn.Dropout(dropout)
33
+
34
+ def forward(self, x, residual):
35
+ return self.norm(x + self.dropout(residual))
36
+
37
+
38
+ class PositionalEncoding(nn.Module):
39
+ def __init__(self, d_model, dropout=0.1, max_len=5000):
40
+ super().__init__()
41
+ self.dropout = nn.Dropout(p=dropout)
42
+
43
+ position = torch.arange(max_len).unsqueeze(1)
44
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
45
+ pe = torch.zeros(max_len, d_model)
46
+ pe[:, 0::2] = torch.sin(position * div_term)
47
+ pe[:, 1::2] = torch.cos(position * div_term)
48
+
49
+ self.register_buffer("pe", pe)
50
+
51
+ def forward(self, x):
52
+ x = x + self.pe[: x.size(1)].detach()  # positional table is fixed, no gradients
53
+ return self.dropout(x)
54
+
55
+
56
+ class TransformerEncoderLayer(nn.Module):
57
+ def __init__(self, input_dim, num_heads, dropout=0.1, positional_encoding=False):
58
+ super().__init__()
59
+ self.input_dim = input_dim
60
+ self.self_attention = nn.MultiheadAttention(input_dim, num_heads, dropout=dropout, batch_first=True)
61
+ # self.self_attention = MHA(
62
+ # embed_dim=input_dim,
63
+ # num_heads=num_heads,
64
+ # dropout=dropout,
65
+ # # bias=True,
66
+ # use_flash_attn=True
67
+ # )
68
+ self.feed_forward = PositionWiseFeedForward(input_dim, input_dim, dropout=dropout)
69
+ self.add_norm_after_attention = AddAndNorm(input_dim, dropout=dropout)
70
+ self.add_norm_after_ff = AddAndNorm(input_dim, dropout=dropout)
71
+ self.positional_encoding = PositionalEncoding(input_dim) if positional_encoding else None
72
+
73
+ def forward(self, key, value, query):
74
+ if self.positional_encoding:
75
+ key = self.positional_encoding(key)
76
+ value = self.positional_encoding(value)
77
+ query = self.positional_encoding(query)
78
+
79
+ attn_output, _ = self.self_attention(query, key, value, need_weights=False)
80
+ # attn_output = self.self_attention(query, key, value)
81
+
82
+ x = self.add_norm_after_attention(attn_output, query)
83
+
84
+ ff_output = self.feed_forward(x)
85
+ x = self.add_norm_after_ff(ff_output, x)
86
+
87
+ return x
88
+
89
+ class GAL(nn.Module):
90
+ def __init__(self, input_dim_F1, input_dim_F2, gated_dim, dropout_rate):
91
+ super(GAL, self).__init__()
92
+
93
+ self.WF1 = nn.Parameter(torch.Tensor(input_dim_F1, gated_dim))
94
+ self.WF2 = nn.Parameter(torch.Tensor(input_dim_F2, gated_dim))
95
+
96
+ init.xavier_uniform_(self.WF1)
97
+ init.xavier_uniform_(self.WF2)
98
+
99
+ dim_size_f = input_dim_F1 + input_dim_F2
100
+
101
+ self.WF = nn.Parameter(torch.Tensor(dim_size_f, gated_dim))
102
+
103
+ init.xavier_uniform_(self.WF)
104
+
105
+ self.dropout = nn.Dropout(dropout_rate)
106
+
107
+ def forward(self, f1, f2):
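+ # Gated fusion: z_f is a softmax gate computed from the concatenated inputs; output is z_f * h_f1 + (1 - z_f) * h_f2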
108
+
109
+ h_f1 = self.dropout(torch.tanh(torch.matmul(f1, self.WF1)))
110
+ h_f2 = self.dropout(torch.tanh(torch.matmul(f2, self.WF2)))
111
+ # print(h_f1.shape, h_f2.shape, self.WF.shape, torch.cat([f1, f2], dim=1).shape)
112
+ z_f = torch.softmax(self.dropout(torch.matmul(torch.cat([f1, f2], dim=1), self.WF)), dim=1)
113
+ h_f = z_f*h_f1 + (1 - z_f)*h_f2
114
+ return h_f
115
+
116
+ class GraphFusionLayer(nn.Module):
117
+ def __init__(self, hidden_dim, dropout=0.0, heads=2, out_mean=True):
118
+ super().__init__()
119
+ self.out_mean = out_mean
120
+ # Projection layers for the input features
121
+ self.proj_audio = nn.Sequential(
122
+ nn.Linear(hidden_dim, hidden_dim),
123
+ nn.LayerNorm(hidden_dim),
124
+ nn.Dropout(dropout)
125
+ )
126
+ self.proj_text = nn.Sequential(
127
+ nn.Linear(hidden_dim, hidden_dim),
128
+ nn.LayerNorm(hidden_dim),
129
+ nn.Dropout(dropout)
130
+ )
131
+
132
+ # Graph attention layers
133
+ self.gat1 = GATConv(hidden_dim, hidden_dim, heads=heads)
134
+ self.gat2 = GATConv(hidden_dim*heads, hidden_dim)
135
+
136
+ # Final projection
137
+ self.fc = nn.Sequential(
138
+ nn.Linear(hidden_dim, hidden_dim),
139
+ nn.LayerNorm(hidden_dim),
140
+ nn.Dropout(dropout)
141
+ )
142
+
143
+ def build_complete_graph(self, num_nodes):
144
+ # Build a complete graph (every node connected to every other node)
145
+ edge_index = []
146
+ for i in range(num_nodes):
147
+ for j in range(num_nodes):
148
+ if i != j:
149
+ edge_index.append([i, j])
150
+ return torch.tensor(edge_index).t().contiguous()
151
+
152
+ def forward(self, audio_stats, text_stats):
153
+ """
154
+ audio_stats: [batch_size, hidden_dim]
155
+ text_stats: [batch_size, hidden_dim]
156
+ """
157
+ batch_size = audio_stats.size(0)
158
+
159
+ # Project the features
160
+ x_audio = F.relu(self.proj_audio(audio_stats)) # [batch_size, hidden_dim]
161
+ x_text = F.relu(self.proj_text(text_stats)) # [batch_size, hidden_dim]
162
+
163
+ # Stack the nodes (audio and text interleaved)
164
+ nodes = torch.stack([x_audio, x_text], dim=1) # [batch_size, 2, hidden_dim]
165
+ nodes = nodes.view(-1, nodes.size(-1)) # [batch_size*2, hidden_dim]
166
+
167
+ # Build the graph (a complete graph for each audio-text pair)
168
+ edge_index = self.build_complete_graph(2)  # graph for one audio-text pair
169
+ edge_index = edge_index.to(audio_stats.device)
170
+
171
+ # Apply the GAT layers
172
+ x = F.relu(self.gat1(nodes, edge_index))
173
+ x = self.gat2(x, edge_index)
174
+
175
+ # Split audio and text back apart
176
+ x = x.view(batch_size, 2, -1) # [batch_size, 2, hidden_dim]
177
+
178
+ if self.out_mean:
179
+ # Average over the modalities
180
+ fused = torch.mean(x, dim=1) # [batch_size, hidden_dim]
181
+
182
+ return self.fc(fused)
183
+ else:
184
+ return x
185
+
186
+ class GraphFusionLayerAtt(nn.Module):
187
+ def __init__(self, hidden_dim, heads=2):
188
+ super().__init__()
189
+ # Projection layers for the input features
190
+ self.proj_audio = nn.Linear(hidden_dim, hidden_dim)
191
+ self.proj_text = nn.Linear(hidden_dim, hidden_dim)
192
+
193
+ # Graph attention layers
194
+ self.gat1 = GATConv(hidden_dim, hidden_dim, heads=heads)
195
+ self.gat2 = GATConv(hidden_dim*heads, hidden_dim)
196
+
197
+ self.attention_fusion = nn.Linear(hidden_dim, 1)
198
+
199
+ # Final projection
200
+ self.fc = nn.Linear(hidden_dim, hidden_dim)
201
+
202
+ def build_complete_graph(self, num_nodes):
203
+ # Build a complete graph (every node connected to every other node)
204
+ edge_index = []
205
+ for i in range(num_nodes):
206
+ for j in range(num_nodes):
207
+ if i != j:
208
+ edge_index.append([i, j])
209
+ return torch.tensor(edge_index).t().contiguous()
210
+
211
+ def forward(self, audio_stats, text_stats):
212
+ """
213
+ audio_stats: [batch_size, hidden_dim]
214
+ text_stats: [batch_size, hidden_dim]
215
+ """
216
+ batch_size = audio_stats.size(0)
217
+
218
+ # Project the features
219
+ x_audio = F.relu(self.proj_audio(audio_stats)) # [batch_size, hidden_dim]
220
+ x_text = F.relu(self.proj_text(text_stats)) # [batch_size, hidden_dim]
221
+
222
+ # Stack the nodes (audio and text interleaved)
223
+ nodes = torch.stack([x_audio, x_text], dim=1) # [batch_size, 2, hidden_dim]
224
+ nodes = nodes.view(-1, nodes.size(-1)) # [batch_size*2, hidden_dim]
225
+
226
+ # Build the graph (a complete graph for each audio-text pair)
227
+ edge_index = self.build_complete_graph(2)  # graph for one audio-text pair
228
+ edge_index = edge_index.to(audio_stats.device)
229
+
230
+ # Apply the GAT layers
231
+ x = F.relu(self.gat1(nodes, edge_index))
232
+ x = self.gat2(x, edge_index)
233
+
234
+ # Split audio and text back apart
235
+ x = x.view(batch_size, 2, -1) # [batch_size, 2, hidden_dim]
236
+
237
+ # Attention-weighted fusion over the modalities
238
+ # fused = torch.mean(x, dim=1) # [batch_size, hidden_dim]
239
+
240
+ weights = F.softmax(self.attention_fusion(x), dim=1)
241
+ fused = torch.sum(weights * x, dim=1) # [batch_size, hidden_dim]
242
+
243
+ return self.fc(fused)
244
+
245
+ # Full code see https://github.com/leson502/CORECT_EMNLP2023/tree/master/corect/model
246
+
247
+ class GNN(nn.Module):
248
+ def __init__(self, g_dim, h1_dim, h2_dim, num_relations, num_modals, gcn_conv, use_graph_transformer, graph_transformer_nheads):
249
+ super(GNN, self).__init__()
250
+ self.gcn_conv = gcn_conv
251
+ self.use_graph_transformer=use_graph_transformer
252
+
253
+ self.num_modals = num_modals
254
+
255
+ if self.gcn_conv == "rgcn":
256
+ print("GNN --> Use RGCN")
257
+ self.conv1 = RGCNConv(g_dim, h1_dim, num_relations)
258
+
259
+ if self.use_graph_transformer:
260
+ print("GNN --> Use Graph Transformer")
261
+
262
+ in_dim = h1_dim
263
+
264
+ self.conv2 = TransformerConv(in_dim, h2_dim, heads=graph_transformer_nheads, concat=True)
265
+ self.bn = nn.BatchNorm1d(h2_dim * graph_transformer_nheads)
266
+
267
+
268
+ def forward(self, node_features, node_type, edge_index, edge_type):
269
+ print(node_features.shape, edge_index.shape, edge_type.shape)
270
+
271
+ if self.gcn_conv == "rgcn":
272
+ x = self.conv1(node_features, edge_index, edge_type)
273
+
274
+ if self.use_graph_transformer:
275
+ x = nn.functional.leaky_relu(self.bn(self.conv2(x, edge_index)))
276
+
277
+ return x
278
+
279
+ class GraphModel(nn.Module):
280
+ def __init__(self, g_dim, h1_dim, h2_dim, device, modalities, wp, wf, edge_type, gcn_conv, use_graph_transformer, graph_transformer_nheads):
281
+ super(GraphModel, self).__init__()
282
+
283
+ self.n_modals = len(modalities)
284
+ self.wp = wp
285
+ self.wf = wf
286
+ self.device = device
287
+ self.gcn_conv=gcn_conv
288
+ self.use_graph_transformer=use_graph_transformer
289
+
290
+ print(f"GraphModel --> Edge type: {edge_type}")
291
+ print(f"GraphModel --> Window past: {wp}")
292
+ print(f"GraphModel --> Window future: {wf}")
293
+ edge_temp = "temp" in edge_type
294
+ edge_multi = "multi" in edge_type
295
+
296
+ edge_type_to_idx = {}
297
+
298
+ if edge_temp:
299
+ temporal = [-1, 1, 0]
300
+ for j in temporal:
301
+ for k in range(self.n_modals):
302
+ edge_type_to_idx[str(j) + str(k) + str(k)] = len(edge_type_to_idx)
303
+ else:
304
+ for j in range(self.n_modals):
305
+ edge_type_to_idx['0' + str(j) + str(j)] = len(edge_type_to_idx)
306
+
307
+ if edge_multi:
308
+ for j in range(self.n_modals):
309
+ for k in range(self.n_modals):
310
+ if (j != k):
311
+ edge_type_to_idx['0' + str(j) + str(k)] = len(edge_type_to_idx)
312
+
313
+ self.edge_type_to_idx = edge_type_to_idx
314
+ self.num_relations = len(edge_type_to_idx)
315
+ self.edge_multi = edge_multi
316
+ self.edge_temp = edge_temp
317
+
318
+ self.gnn = GNN(g_dim, h1_dim, h2_dim, self.num_relations, self.n_modals, self.gcn_conv, self.use_graph_transformer, graph_transformer_nheads)
319
+
320
+
321
+ def forward(self, x, lengths):
322
+ # print(f"x shape: {x.shape}, lengths: {lengths}, lengths.shape: {lengths.shape}")
323
+
324
+ node_features = feature_packing(x, lengths)
325
+
326
+ node_type, edge_index, edge_type, edge_index_lengths = \
327
+ self.batch_graphify(lengths)
328
+
329
+ out_gnn = self.gnn(node_features, node_type, edge_index, edge_type)
330
+ out_gnn = multi_concat(out_gnn, lengths, self.n_modals)
331
+
332
+ return out_gnn
333
+
334
+ def batch_graphify(self, lengths):
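+ # Builds one batched graph: per-modality node types, edge_index offsets into the flattened node list, and edge types encoding temporal direction (past/same/future) plus the modality pair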
335
+
336
+ node_type, edge_index, edge_type, edge_index_lengths = [], [], [], []
337
+ edge_type_lengths = [0] * len(self.edge_type_to_idx)
338
+
339
+ lengths = lengths.tolist()
340
+
341
+ sum_length = 0
342
+ total_length = sum(lengths)
343
+ batch_size = len(lengths)
344
+
345
+ for k in range(self.n_modals):
346
+ for j in range(batch_size):
347
+ cur_len = lengths[j]
348
+ node_type.extend([k] * cur_len)
349
+
350
+ for j in range(batch_size):
351
+ cur_len = lengths[j]
352
+
353
+ perms = self.edge_perms(cur_len, total_length)
354
+ edge_index_lengths.append(len(perms))
355
+
356
+ for item in perms:
357
+ vertices = item[0]
358
+ neighbor = item[1]
359
+ edge_index.append(torch.tensor([vertices + sum_length, neighbor + sum_length]))
360
+
361
+ if vertices % total_length > neighbor % total_length:
362
+ temporal_type = 1
363
+ elif vertices % total_length < neighbor % total_length:
364
+ temporal_type = -1
365
+ else:
366
+ temporal_type = 0
367
+ edge_type.append(self.edge_type_to_idx[str(temporal_type)
368
+ + str(node_type[vertices + sum_length])
369
+ + str(node_type[neighbor + sum_length])])
370
+
371
+ sum_length += cur_len
372
+
373
+ node_type = torch.tensor(node_type).long().to(self.device)
374
+ edge_index = torch.stack(edge_index).t().contiguous().to(self.device) # [2, E]
375
+ edge_type = torch.tensor(edge_type).long().to(self.device) # [E]
376
+ edge_index_lengths = torch.tensor(edge_index_lengths).long().to(self.device) # [B]
377
+
378
+ return node_type, edge_index, edge_type, edge_index_lengths
379
+
380
+ def edge_perms(self, length, total_lengths):
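+ # For each position, take neighbours inside the past (wp) / future (wf) windows (-1 = unbounded) as same-modality temporal edges; with multimodal edges enabled, also connect the node to the same time step in the other modalities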
381
+
382
+ all_perms = set()
383
+ array = np.arange(length)
384
+ for j in range(length):
385
+ if self.wp == -1 and self.wf == -1:
386
+ eff_array = array
387
+ elif self.wp == -1: # use all past context
388
+ eff_array = array[: min(length, j + self.wf)]
389
+ elif self.wf == -1: # use all future context
390
+ eff_array = array[max(0, j - self.wp) :]
391
+ else:
392
+ eff_array = array[
393
+ max(0, j - self.wp) : min(length, j + self.wf)
394
+ ]
395
+ perms = set()
396
+
397
+
398
+ for k in range(self.n_modals):
399
+ node_index = j + k * total_lengths
400
+ if self.edge_temp == True:
401
+ for item in eff_array:
402
+ perms.add((node_index, item + k * total_lengths))
403
+ else:
404
+ perms.add((node_index, node_index))
405
+ if self.edge_multi == True:
406
+ for l in range(self.n_modals):
407
+ if l != k:
408
+ perms.add((node_index, j + l * total_lengths))
409
+
410
+ all_perms = all_perms.union(perms)
411
+
412
+ return list(all_perms)
413
+
414
+ def feature_packing(multimodal_feature, lengths):
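+ # Flattens padded [batch, seq, dim] features of every modality into one node-feature matrix, keeping only the first lengths[j] valid steps of each sample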
415
+ batch_size = lengths.size(0)
416
+ # print(multimodal_feature.shape, batch_size, lengths.shape)
417
+ node_features = []
418
+
419
+ for feature in multimodal_feature:
420
+ for j in range(batch_size):
421
+ cur_len = lengths[j].item()
422
+ # print(f"feature.shape: {feature.shape}, j: {j}, cur_len: {cur_len}")
423
+ node_features.append(feature[j,:cur_len])
424
+
425
+ node_features = torch.cat(node_features, dim=0)
426
+
427
+ return node_features
428
+
429
+ def multi_concat(nodes_feature, lengths, n_modals):
430
+ sum_length = lengths.sum().item()
431
+ feature = []
432
+ for j in range(n_modals):
433
+ feature.append(nodes_feature[j * sum_length : (j + 1) * sum_length])
434
+
435
+ feature = torch.cat(feature, dim=-1)
436
+
437
+ return feature
438
+
439
+ class RMSNorm(nn.Module):
440
+ def __init__(self, d_model: int, eps: float = 1e-8) -> None:
441
+ super().__init__()
442
+ self.eps = eps
443
+ self.weight = nn.Parameter(torch.ones(d_model))
444
+
445
+ def forward(self, x: Tensor) -> Tensor:
446
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim = True) + self.eps) * self.weight
447
+
448
+ class Mamba(nn.Module):
449
+ def __init__(self, num_layers, d_input, d_model, d_state=16, d_discr=None, ker_size=4, num_classes=7, pooling=None):
450
+ super().__init__()
451
+ mamba_par = {
452
+ 'd_input' : d_input,
453
+ 'd_model' : d_model,
454
+ 'd_state' : d_state,
455
+ 'd_discr' : d_discr,
456
+ 'ker_size': ker_size
457
+ }
458
+ self.layers = nn.ModuleList([nn.ModuleList([MambaBlock(**mamba_par), RMSNorm(d_input)]) for _ in range(num_layers)])
459
+ self.fc_out = nn.Linear(d_input, num_classes)
460
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
461
+
462
+ def forward(self, seq, cache=None):
463
+ seq = seq.to(self.device)  # input is already a continuous feature tensor; just move it to the model device
464
+ for mamba, norm in self.layers:
465
+ out, cache = mamba(norm(seq), cache)
466
+ seq = out + seq
467
+ return self.fc_out(seq.mean(dim = 1))
468
+
469
+ class MambaBlock(nn.Module):
470
+ def __init__(self, d_input, d_model, d_state=16, d_discr=None, ker_size=4):
471
+ super().__init__()
472
+ d_discr = d_discr if d_discr is not None else d_model // 16
473
+ self.in_proj = nn.Linear(d_input, 2 * d_model, bias=False)
474
+ self.out_proj = nn.Linear(d_model, d_input, bias=False)
475
+ self.s_B = nn.Linear(d_model, d_state, bias=False)
476
+ self.s_C = nn.Linear(d_model, d_state, bias=False)
477
+ self.s_D = nn.Sequential(nn.Linear(d_model, d_discr, bias=False), nn.Linear(d_discr, d_model, bias=False),)
478
+ self.conv = nn.Conv1d(
479
+ in_channels=d_model,
480
+ out_channels=d_model,
481
+ kernel_size=ker_size,
482
+ padding=ker_size - 1,
483
+ groups=d_model,
484
+ bias=True,
485
+ )
486
+ self.A = nn.Parameter(torch.arange(1, d_state + 1, dtype=torch.float).repeat(d_model, 1))
487
+ self.D = nn.Parameter(torch.ones(d_model, dtype=torch.float))
488
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
489
+
490
+ def forward(self, seq, cache=None):
491
+ b, l, d = seq.shape
492
+ (prev_hid, prev_inp) = cache if cache is not None else (None, None)
493
+ a, b = self.in_proj(seq).chunk(2, dim=-1)
494
+ x = rearrange(a, 'b l d -> b d l')
495
+ x = x if prev_inp is None else torch.cat((prev_inp, x), dim=-1)
496
+ a = self.conv(x)[..., :l]
497
+ a = rearrange(a, 'b d l -> b l d')
498
+ a = silu(a)
499
+ a, hid = self.ssm(a, prev_hid=prev_hid)
500
+ b = silu(b)
501
+ out = a * b
502
+ out = self.out_proj(out)
503
+ if cache:
504
+ cache = (hid.squeeze(), x[..., 1:])
505
+ return out, cache
506
+
507
+ def ssm(self, seq, prev_hid):
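+ # Selective SSM scan: discretise A and B with the input-dependent step s = softplus(D + s_D(seq)), recur the hidden state over time, read it out with C and add the skip term D * seq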
508
+ A = -self.A
509
+ D = +self.D
510
+ B = self.s_B(seq)
511
+ C = self.s_C(seq)
512
+ s = softplus(D + self.s_D(seq))
513
+ A_bar = einsum(torch.exp(A), s, 'd s, b l d -> b l d s')
514
+ B_bar = einsum( B, s, 'b l s, b l d -> b l d s')
515
+ X_bar = einsum(B_bar, seq, 'b l d s, b l d -> b l d s')
516
+ hid = self._hid_states(A_bar, X_bar, prev_hid=prev_hid)
517
+ out = einsum(hid, C, 'b l d s, b l s -> b l d')
518
+ out = out + D * seq
519
+ return out, hid
520
+
521
+ def _hid_states(self, A, X, prev_hid=None):
522
+ b, l, d, s = A.shape
523
+ A = rearrange(A, 'b l d s -> l b d s')
524
+ X = rearrange(X, 'b l d s -> l b d s')
525
+ if prev_hid is not None:
526
+ return rearrange(A * prev_hid + X, 'l b d s -> b l d s')
527
+ h = torch.zeros(b, d, s, device=self.device)
528
+ return torch.stack([h := A_t * h + X_t for A_t, X_t in zip(A, X)], dim=1)
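A minimal smoke test for the layers defined above (requires torch and torch_geometric, and assumes the repository root is on PYTHONPATH); batch size, sequence length and hidden size here are arbitrary:

    import torch
    from models.help_layers import TransformerEncoderLayer, GraphFusionLayer

    batch, seq, hidden = 4, 10, 64
    audio = torch.randn(batch, seq, hidden)
    text = torch.randn(batch, seq, hidden)

    # Cross-modal attention block: forward(key, value, query) -> [batch, seq, hidden]
    enc = TransformerEncoderLayer(input_dim=hidden, num_heads=2, positional_encoding=True)
    attn_audio = enc(text, audio, audio)

    # Graph fusion over per-sample mean vectors -> [batch, hidden]
    fusion = GraphFusionLayer(hidden)
    fused = fusion(audio.mean(dim=1), text.mean(dim=1))
    print(attn_audio.shape, fused.shape)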
models/models.py ADDED
@@ -0,0 +1,1700 @@
1
+ # coding: utf-8
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from .help_layers import TransformerEncoderLayer, GAL, GraphFusionLayer, GraphFusionLayerAtt, MambaBlock, RMSNorm
6
+
7
+ class PredictionsFusion(nn.Module):
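+ # Learns per-class weights, softmax-normalised across the input matrices, to average several prediction tensors into one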
8
+ def __init__(self, num_matrices=2, num_classes=7):
9
+ super(PredictionsFusion, self).__init__()
10
+ self.weights = nn.Parameter(torch.rand(num_matrices, num_classes))
11
+
12
+ def forward(self, pred):
13
+ normalized_weights = torch.softmax(self.weights, dim=0)
14
+ weighted_matrix = sum(mat * normalized_weights[i] for i, mat in enumerate(pred))
15
+ return weighted_matrix
16
+
17
+ class MultiModalTransformer_v3(nn.Module):
18
+ def __init__(self, audio_dim=1024, text_dim=1024, hidden_dim=512, hidden_dim_gated=512, num_transformer_heads=2, num_graph_heads=2, seg_len=44, positional_encoding=True, dropout=0, mode='mean', device="cuda", tr_layer_number=1, out_features=128, num_classes=7):
19
+ super(MultiModalTransformer_v3, self).__init__()
20
+
21
+ self.mode = mode
22
+
23
+ self.hidden_dim = hidden_dim
24
+
25
+ # Projection layers
26
+ # self.audio_proj = nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity()
27
+
28
+ # self.audio_proj = nn.Sequential(
29
+ # nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity(),
30
+ # nn.LayerNorm(hidden_dim),
31
+ # nn.Dropout(dropout)
32
+ # )
33
+
34
+ self.audio_proj = nn.Sequential(
35
+ nn.Conv1d(audio_dim, hidden_dim, 1),
36
+ nn.GELU(),
37
+ )
38
+
39
+ self.text_proj = nn.Sequential(
40
+ nn.Conv1d(text_dim, hidden_dim, 1),
41
+ nn.GELU(),
42
+ )
43
+ # self.text_proj = nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity()
44
+
45
+ # self.text_proj = nn.Sequential(
46
+ # nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity(),
47
+ # nn.LayerNorm(hidden_dim),
48
+ # nn.Dropout(dropout)
49
+ # )
50
+
51
+ # Cross-modal attention blocks
52
+ self.audio_to_text_attn = nn.ModuleList([TransformerEncoderLayer(input_dim=hidden_dim, num_heads=num_transformer_heads, positional_encoding=positional_encoding, dropout=dropout) for i in range(tr_layer_number)
53
+ ])
54
+ self.text_to_audio_attn = nn.ModuleList([TransformerEncoderLayer(input_dim=hidden_dim, num_heads=num_transformer_heads, positional_encoding=positional_encoding, dropout=dropout) for i in range(tr_layer_number)
55
+ ])
56
+
57
+ # Classifier
58
+ # self.classifier = nn.Sequential(
59
+ # nn.Linear(hidden_dim*2, out_features) if self.mode == 'mean' else nn.Linear(hidden_dim*4, out_features),
60
+ # nn.ReLU(),
61
+ # nn.Linear(out_features, num_classes)
62
+ # )
63
+
64
+ self.classifier = nn.Sequential(
65
+ nn.Linear(hidden_dim*2, out_features) if self.mode == 'mean' else nn.Linear(hidden_dim*4, out_features),
66
+ # nn.LayerNorm(out_features),
67
+ # nn.GELU(),
68
+ # nn.Dropout(dropout),
69
+ nn.Linear(out_features, num_classes)
70
+ )
71
+
72
+ # self._init_weights()
73
+
74
+ def forward(self, audio_features, text_features):
75
+ # Cast inputs to float and project dimensions
76
+ audio_features = audio_features.float()
77
+ text_features = text_features.float()
78
+
79
+ # audio_features = self.audio_proj(audio_features)
80
+ # text_features = self.text_proj(text_features)
81
+ audio_features = self.audio_proj(audio_features.permute(0,2,1)).permute(0,2,1)
82
+ text_features = self.text_proj(text_features.permute(0,2,1)).permute(0,2,1)
83
+
84
+ # Adaptive pooling down to the shorter sequence length
85
+ min_seq_len = min(audio_features.size(1), text_features.size(1))
86
+ audio_features = F.adaptive_avg_pool1d(audio_features.permute(0,2,1), min_seq_len).permute(0,2,1)
87
+ text_features = F.adaptive_avg_pool1d(text_features.permute(0,2,1), min_seq_len).permute(0,2,1)
88
+
89
+ # Transformer blocks
90
+ for i in range(len(self.audio_to_text_attn)):
91
+ attn_audio = self.audio_to_text_attn[i](text_features, audio_features, audio_features)
92
+ attn_text = self.text_to_audio_attn[i](audio_features, text_features, text_features)
93
+ audio_features += attn_audio
94
+ text_features += attn_text
95
+
96
+ # Sequence statistics (std/mean over time)
97
+ std_audio, mean_audio = torch.std_mean(attn_audio, dim=1)
98
+ std_text, mean_text = torch.std_mean(attn_text, dim=1)
99
+
100
+ # Classification
101
+ if self.mode == 'mean':
102
+ return self.classifier(torch.cat([mean_audio, mean_text], dim=1))  # fuse audio and text means (hidden_dim*2 input)
103
+ else:
104
+ return self.classifier(torch.cat([mean_audio, std_audio, mean_text, std_text], dim=1))
105
+
106
+ def _init_weights(self):
107
+ for m in self.modules():
108
+ if isinstance(m, nn.Linear):
109
+ nn.init.xavier_uniform_(m.weight)
110
+ if m.bias is not None:
111
+ nn.init.constant_(m.bias, 0)
112
+ elif isinstance(m, nn.LayerNorm):
113
+ nn.init.constant_(m.weight, 1)
114
+ nn.init.constant_(m.bias, 0)
115
+
116
+ class MultiModalTransformer_v4(nn.Module):
117
+ def __init__(self, audio_dim=1024, text_dim=1024, hidden_dim=512, hidden_dim_gated=512, num_transformer_heads=2, num_graph_heads=2, seg_len=44, positional_encoding=True, dropout=0, mode='mean', device="cuda", tr_layer_number=1, out_features=128, num_classes=7):
118
+ super(MultiModalTransformer_v4, self).__init__()
119
+
120
+ self.mode = mode
121
+
122
+ self.hidden_dim = hidden_dim
123
+
124
+ # Проекционные слои
125
+ self.audio_proj = nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity()
126
+ self.text_proj = nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity()
127
+
128
+ # Механизмы внимания
129
+ self.audio_to_text_attn = nn.ModuleList([TransformerEncoderLayer(input_dim=hidden_dim, num_heads=num_transformer_heads, positional_encoding=positional_encoding, dropout=dropout) for i in range(tr_layer_number)
130
+ ])
131
+ self.text_to_audio_attn = nn.ModuleList([TransformerEncoderLayer(input_dim=hidden_dim, num_heads=num_transformer_heads, positional_encoding=positional_encoding, dropout=dropout) for i in range(tr_layer_number)
132
+ ])
133
+
134
+ # Графовое слияние вместо GAL
135
+ if self.mode == 'mean':
136
+ self.graph_fusion = GraphFusionLayer(hidden_dim, heads=num_graph_heads)
137
+ else:
138
+ self.graph_fusion = GraphFusionLayer(hidden_dim*2, heads=num_graph_heads)
139
+
140
+ # Классификатор
141
+ self.classifier = nn.Sequential(
142
+ nn.Linear(hidden_dim, out_features) if self.mode == 'mean' else nn.Linear(hidden_dim*2, out_features),
143
+ nn.ReLU(),
144
+ nn.Linear(out_features, num_classes)
145
+ )
146
+
147
+ def forward(self, audio_features, text_features):
148
+ # Преобразование размерностей
149
+ audio_features = audio_features.float()
150
+ text_features = text_features.float()
151
+
152
+ audio_features = self.audio_proj(audio_features)
153
+ text_features = self.text_proj(text_features)
154
+
155
+ # Адаптивная пуллинг до минимальной длины
156
+ min_seq_len = min(audio_features.size(1), text_features.size(1))
157
+ audio_features = F.adaptive_avg_pool1d(audio_features.permute(0,2,1), min_seq_len).permute(0,2,1)
158
+ text_features = F.adaptive_avg_pool1d(text_features.permute(0,2,1), min_seq_len).permute(0,2,1)
159
+
160
+ # Трансформерные блоки
161
+ for i in range(len(self.audio_to_text_attn)):
162
+ attn_audio = self.audio_to_text_attn[i](text_features, audio_features, audio_features)
163
+ attn_text = self.text_to_audio_attn[i](audio_features, text_features, text_features)
164
+ audio_features += attn_audio
165
+ text_features += attn_text
166
+
167
+ # Статистики
168
+ std_audio, mean_audio = torch.std_mean(attn_audio, dim=1)
169
+ std_text, mean_text = torch.std_mean(attn_text, dim=1)
170
+
171
+ # Графовое слияние статистик
172
+ if self.mode == 'mean':
173
+ h_ta = self.graph_fusion(mean_audio, mean_text)
174
+ else:
175
+ h_ta = self.graph_fusion(torch.cat([mean_audio, std_audio], dim=1), torch.cat([mean_text, std_text], dim=1))
176
+
177
+ # Классификация
178
+ return self.classifier(h_ta)
179
+
180
+ class MultiModalTransformer_v5(nn.Module):
181
+ def __init__(self, audio_dim=1024, text_dim=1024, hidden_dim=512, hidden_dim_gated=512, num_transformer_heads=2, num_graph_heads=2, seg_len=44, tr_layer_number=1, positional_encoding=True, dropout=0, mode='mean', device="cuda", out_features=128, num_classes=7):
182
+ super(MultiModalTransformer_v5, self).__init__()
183
+
184
+ self.hidden_dim = hidden_dim
185
+ self.mode = mode
186
+
187
+ # Приведение к общей размерности (адаптивные проекции)
188
+ self.audio_proj = nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity()
189
+ self.text_proj = nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity()
190
+
191
+ # Механизмы внимания
192
+
193
+ self.audio_to_text_attn = nn.ModuleList([TransformerEncoderLayer(input_dim=hidden_dim, num_heads=num_transformer_heads, positional_encoding=positional_encoding, dropout=dropout) for i in range(tr_layer_number)
194
+ ])
195
+ self.text_to_audio_attn = nn.ModuleList([TransformerEncoderLayer(input_dim=hidden_dim, num_heads=num_transformer_heads, positional_encoding=positional_encoding, dropout=dropout) for i in range(tr_layer_number)
196
+ ])
197
+
198
+ # Гейтед аттеншн
199
+ if self.mode == 'mean':
200
+ self.gal = GAL(hidden_dim, hidden_dim, hidden_dim_gated, dropout_rate=dropout)
201
+ else:
202
+ self.gal = GAL(hidden_dim*2, hidden_dim*2, hidden_dim_gated, dropout_rate=dropout)
203
+
204
+ # Классификатор
205
+ self.classifier = nn.Sequential(
206
+ nn.Linear(hidden_dim, out_features),
207
+ nn.ReLU(),
208
+ nn.Linear(out_features, num_classes)
209
+ )
210
+
211
+ def forward(self, audio_features, text_features):
212
+ bs, seq_audio, audio_feat_dim = audio_features.shape
213
+ bs, seq_text, text_feat_dim = text_features.shape
214
+
215
+ text_features = text_features.to(torch.float32)
216
+ audio_features = audio_features.to(torch.float32)
217
+
218
+ # Приведение размерности
219
+ audio_features = self.audio_proj(audio_features) # (bs, seq_audio, hidden_dim)
220
+ text_features = self.text_proj(text_features) # (bs, seq_text, hidden_dim)
221
+
222
+ # Определяем минимальную длину последовательности
223
+ min_seq_len = min(seq_audio, seq_text)
224
+
225
+ # Усреднение до минимальной длины
226
+ audio_features = F.adaptive_avg_pool2d(audio_features.permute(0, 2, 1), (self.hidden_dim, min_seq_len)).permute(0, 2, 1)
227
+ text_features = F.adaptive_avg_pool2d(text_features.permute(0, 2, 1), (self.hidden_dim, min_seq_len)).permute(0, 2, 1)
228
+
229
+ # Трансформерные блоки
230
+ for i in range(len(self.audio_to_text_attn)):
231
+ attn_audio = self.audio_to_text_attn[i](text_features, audio_features, audio_features)
232
+ attn_text = self.text_to_audio_attn[i](audio_features, text_features, text_features)
233
+ audio_features += attn_audio
234
+ text_features += attn_text
235
+
236
+ # Статистики
237
+ std_audio, mean_audio = torch.std_mean(attn_audio, dim=1)
238
+ std_text, mean_text = torch.std_mean(attn_text, dim=1)
239
+
240
+ # # Гейтед аттеншн
241
+ # h_audio = torch.tanh(self.Wa(torch.cat([min_audio, std_audio], dim=1)))
242
+ # h_text = torch.tanh(self.Wt(torch.cat([min_text, std_text], dim=1)))
243
+ # z_ta = torch.sigmoid(self.W_at(torch.cat([min_audio, std_audio, min_text, std_text], dim=1)))
244
+ # h_ta = z_ta * h_text + (1 - z_ta) * h_audio
245
+ if self.mode == 'mean':
246
+ h_ta = self.gal(mean_audio, mean_text)
247
+ else:
248
+ h_ta = self.gal(torch.cat([mean_audio, std_audio], dim=1), torch.cat([mean_text, std_text], dim=1))
249
+
250
+ # Классификация
251
+ output = self.classifier(h_ta)
252
+ return output
253
+
254
+ class MultiModalTransformer_v7(nn.Module):
255
+ def __init__(self, audio_dim=1024, text_dim=1024, hidden_dim=512, num_heads=2, positional_encoding=True, dropout=0, mode='mean', device="cuda", tr_layer_number=1, out_features=128, num_classes=7):
256
+ super(MultiModalTransformer_v7, self).__init__()
257
+
258
+ self.mode = mode
259
+
260
+ self.hidden_dim = hidden_dim
261
+
262
+ # Проекционные слои
263
+ self.audio_proj = nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity()
264
+ self.text_proj = nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity()
265
+
266
+ # Механизмы внимания
267
+ self.audio_to_text_attn = nn.ModuleList([TransformerEncoderLayer(input_dim=hidden_dim, num_heads=num_heads, positional_encoding=positional_encoding, dropout=dropout) for i in range(tr_layer_number)
268
+ ])
269
+ self.text_to_audio_attn = nn.ModuleList([TransformerEncoderLayer(input_dim=hidden_dim, num_heads=num_heads, positional_encoding=positional_encoding, dropout=dropout) for i in range(tr_layer_number)
270
+ ])
271
+
272
+ # Графовое слияние вместо GAL
273
+ if self.mode == 'mean':
274
+ self.graph_fusion = GraphFusionLayerAtt(hidden_dim, heads=num_heads)
275
+ else:
276
+ self.graph_fusion = GraphFusionLayerAtt(hidden_dim*2, heads=num_heads)
277
+
278
+ # Классификатор
279
+ self.classifier = nn.Sequential(
280
+ nn.Linear(hidden_dim, out_features) if self.mode == 'mean' else nn.Linear(hidden_dim*2, out_features),
281
+ nn.ReLU(),
282
+ nn.Linear(out_features, num_classes)
283
+ )
284
+
285
+ def forward(self, audio_features, text_features):
286
+ # Преобразование размерностей
287
+ audio_features = audio_features.float()
288
+ text_features = text_features.float()
289
+
290
+ audio_features = self.audio_proj(audio_features)
291
+ text_features = self.text_proj(text_features)
292
+
293
+ # Адаптивная пуллинг до минимальной длины
294
+ min_seq_len = min(audio_features.size(1), text_features.size(1))
295
+ audio_features = F.adaptive_avg_pool1d(audio_features.permute(0,2,1), min_seq_len).permute(0,2,1)
296
+ text_features = F.adaptive_avg_pool1d(text_features.permute(0,2,1), min_seq_len).permute(0,2,1)
297
+
298
+ # Трансформерные блоки
299
+ for i in range(len(self.audio_to_text_attn)):
300
+ attn_audio = self.audio_to_text_attn[i](text_features, audio_features, audio_features)
301
+ attn_text = self.text_to_audio_attn[i](audio_features, text_features, text_features)
302
+ audio_features += attn_audio
303
+ text_features += attn_text
304
+
305
+ # Статистики
306
+ std_audio, mean_audio = torch.std_mean(attn_audio, dim=1)
307
+ std_text, mean_text = torch.std_mean(attn_text, dim=1)
308
+
309
+ # Графовое слияние статистик
310
+ if self.mode == 'mean':
311
+ h_ta = self.graph_fusion(mean_audio, mean_text)
312
+ else:
313
+ h_ta = self.graph_fusion(torch.cat([mean_audio, std_audio], dim=1), torch.cat([mean_text, std_text], dim=1))
314
+
315
+ # Классификация
316
+ return self.classifier(h_ta)
317
+
318
+ class BiFormer(nn.Module):
319
+ def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512, hidden_dim_gated=128,
320
+ num_transformer_heads=2, num_graph_heads=2, positional_encoding=True, dropout=0.1, mode='mean',
321
+ device="cuda", tr_layer_number=1, out_features=128, num_classes=7):
322
+ super(BiFormer, self).__init__()
323
+
324
+ self.mode = mode
325
+ self.hidden_dim = hidden_dim
326
+ self.seg_len = seg_len
327
+ self.tr_layer_number = tr_layer_number
328
+
329
+ # Проекционные слои с нормализацией
330
+ self.audio_proj = nn.Sequential(
331
+ nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity(),
332
+ nn.LayerNorm(hidden_dim),
333
+ nn.Dropout(dropout)
334
+ )
335
+
336
+ self.text_proj = nn.Sequential(
337
+ nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity(),
338
+ nn.LayerNorm(hidden_dim),
339
+ nn.Dropout(dropout)
340
+ )
341
+ # self.audio_proj = nn.Sequential(
342
+ # nn.Conv1d(audio_dim, hidden_dim, 1),
343
+ # nn.GELU(),
344
+ # )
345
+
346
+ # self.text_proj = nn.Sequential(
347
+ # nn.Conv1d(text_dim, hidden_dim, 1),
348
+ # nn.GELU(),
349
+ # )
350
+
351
+ # Transformer layers
352
+ self.audio_to_text_attn = nn.ModuleList([
353
+ TransformerEncoderLayer(
354
+ input_dim=hidden_dim,
355
+ num_heads=num_transformer_heads,
356
+ dropout=dropout,
357
+ positional_encoding=positional_encoding
358
+ ) for _ in range(tr_layer_number)
359
+ ])
360
+
361
+ self.text_to_audio_attn = nn.ModuleList([
362
+ TransformerEncoderLayer(
363
+ input_dim=hidden_dim,
364
+ num_heads=num_transformer_heads,
365
+ dropout=dropout,
366
+ positional_encoding=positional_encoding
367
+ ) for _ in range(tr_layer_number)
368
+ ])
369
+
370
+ # Автоматический расчёт размерности для классификатора
371
+ self._calculate_classifier_input_dim()
372
+
373
+ # Классификатор
374
+ self.classifier = nn.Sequential(
375
+ nn.Linear(self.classifier_input_dim, out_features),
376
+ nn.LayerNorm(out_features),
377
+ nn.GELU(),
378
+ nn.Dropout(dropout),
379
+ nn.Linear(out_features, num_classes)
380
+ )
381
+
382
+ self._init_weights()
383
+
384
+ def _calculate_classifier_input_dim(self):
385
+ """Вычисляет размер входных признаков для классификатора"""
386
+ # Тестовый проход через пулинг с dummy-данными
387
+ dummy_audio = torch.randn(1, self.seg_len, self.hidden_dim)
388
+ dummy_text = torch.randn(1, self.seg_len, self.hidden_dim)
389
+
390
+ audio_pool = self._pool_features(dummy_audio)
391
+ text_pool = self._pool_features(dummy_text)
392
+
393
+ combined = torch.cat([audio_pool, text_pool], dim=1)
394
+ self.classifier_input_dim = combined.size(1)
395
+
396
+ def _pool_features(self, x):
397
+ # Статистики по временной оси (seq_len)
398
+ mean_temp = x.mean(dim=1) # [batch, hidden_dim]
399
+
400
+ # Статистики по feature оси (hidden_dim)
401
+ mean_feat = x.mean(dim=-1) # [batch, seq_len]
402
+
403
+ return torch.cat([mean_temp, mean_feat], dim=1)
404
+
405
+ def forward(self, audio_features, text_features):
406
+ # Проекция признаков
407
+ # audio = self.audio_proj(audio_features.permute(0,2,1)).permute(0,2,1)
408
+ # text = self.text_proj(text_features.permute(0,2,1)).permute(0,2,1)
409
+ audio = self.audio_proj(audio_features.float())
410
+ text = self.text_proj(text_features.float())
411
+
412
+ # Адаптивный пулинг
413
+ min_len = min(audio.size(1), text.size(1))
414
+ audio = self.adaptive_temporal_pool(audio, min_len)
415
+ text = self.adaptive_temporal_pool(text, min_len)
416
+
417
+ # Кросс-модальное взаимодействие
418
+ for i in range(self.tr_layer_number):
419
+ attn_audio = self.audio_to_text_attn[i](text, audio, audio)
420
+ attn_text = self.text_to_audio_attn[i](audio, text, text)
421
+ audio = audio + attn_audio
422
+ text = text + attn_text
423
+
424
+ # Агрегация признаков
425
+ audio_pool = self._pool_features(audio)
426
+ text_pool = self._pool_features(text)
427
+
428
+ # Классификация
429
+ features = torch.cat([audio_pool, text_pool], dim=1)
430
+ return self.classifier(features)
431
+
432
+ def adaptive_temporal_pool(self, x, target_len):
433
+ """Адаптивное изменение временной длины"""
434
+ if x.size(1) == target_len:
435
+ return x
436
+
437
+ return F.interpolate(
438
+ x.permute(0, 2, 1),
439
+ size=target_len,
440
+ mode='linear',
441
+ align_corners=False
442
+ ).permute(0, 2, 1)
443
+
444
+ def _init_weights(self):
445
+ for m in self.modules():
446
+ if isinstance(m, nn.Linear):
447
+ nn.init.xavier_uniform_(m.weight)
448
+ if m.bias is not None:
449
+ nn.init.constant_(m.bias, 0)
450
+ elif isinstance(m, nn.LayerNorm):
451
+ nn.init.constant_(m.weight, 1)
452
+ nn.init.constant_(m.bias, 0)
453
+
454
+ class BiGraphFormer(nn.Module):
455
+ def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512, hidden_dim_gated=128,
456
+ num_transformer_heads=2, num_graph_heads = 2, positional_encoding=True, dropout=0.1, mode='mean',
457
+ device="cuda", tr_layer_number=1, out_features=128, num_classes=7):
458
+ super(BiGraphFormer, self).__init__()
459
+
460
+ self.mode = mode
461
+ self.hidden_dim = hidden_dim
462
+ self.seg_len = seg_len
463
+ self.tr_layer_number = tr_layer_number
464
+
465
+ # Проекционные слои с нормализацией
466
+ self.audio_proj = nn.Sequential(
467
+ nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity(),
468
+ nn.LayerNorm(hidden_dim),
469
+ nn.Dropout(dropout)
470
+ )
471
+
472
+ self.text_proj = nn.Sequential(
473
+ nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity(),
474
+ nn.LayerNorm(hidden_dim),
475
+ nn.Dropout(dropout)
476
+ )
477
+
478
+ # Transformer layers
479
+ self.audio_to_text_attn = nn.ModuleList([
480
+ TransformerEncoderLayer(
481
+ input_dim=hidden_dim,
482
+ num_heads=num_transformer_heads,
483
+ dropout=dropout,
484
+ positional_encoding=positional_encoding
485
+ ) for _ in range(tr_layer_number)
486
+ ])
487
+
488
+ self.text_to_audio_attn = nn.ModuleList([
489
+ TransformerEncoderLayer(
490
+ input_dim=hidden_dim,
491
+ num_heads=num_transformer_heads,
492
+ dropout=dropout,
493
+ positional_encoding=positional_encoding
494
+ ) for _ in range(tr_layer_number)
495
+ ])
496
+
497
+ self.graph_fusion_feat = GraphFusionLayer(self.seg_len, heads=num_graph_heads)
498
+ self.graph_fusion_temp = GraphFusionLayer(hidden_dim, heads=num_graph_heads)
499
+
500
+ # Автоматический расчёт размерности для классификатора
501
+ self._calculate_classifier_input_dim()
502
+
503
+ # Классификатор
504
+ self.classifier = nn.Sequential(
505
+ nn.Linear(self.classifier_input_dim, out_features),
506
+ nn.LayerNorm(out_features),
507
+ nn.GELU(),
508
+ nn.Dropout(dropout),
509
+ nn.Linear(out_features, num_classes)
510
+ )
511
+
512
+ # Финальная проекция графов
513
+ self.fc_feat = nn.Sequential(
514
+ nn.Linear(self.seg_len, self.seg_len),
515
+ nn.LayerNorm(self.seg_len),
516
+ nn.Dropout(dropout)
517
+ )
518
+
519
+ self.fc_temp = nn.Sequential(
520
+ nn.Linear(hidden_dim, hidden_dim),
521
+ nn.LayerNorm(hidden_dim),
522
+ nn.Dropout(dropout)
523
+ )
524
+
525
+ self._init_weights()
526
+
527
+ def _calculate_classifier_input_dim(self):
528
+ """Вычисляет размер входных признаков для классификатора"""
529
+ # Тестовый проход через пулинг с dummy-данными
530
+ dummy_audio = torch.randn(1, self.seg_len, self.hidden_dim)
531
+ dummy_text = torch.randn(1, self.seg_len, self.hidden_dim)
532
+
533
+ audio_pool_temp, audio_pool_feat = self._pool_features(dummy_audio)
534
+ # text_pool_temp, _ = self._pool_features(dummy_text)
535
+
536
+ combined = torch.cat([audio_pool_temp, audio_pool_feat], dim=1)
537
+ self.classifier_input_dim = combined.size(1)
538
+
539
+ def _pool_features(self, x):
540
+ # Статистики по временной оси (seq_len)
541
+ mean_temp = x.mean(dim=1) # [batch, hidden_dim]
542
+
543
+ # Статистики по feature оси (hidden_dim)
544
+ mean_feat = x.mean(dim=-1) # [batch, seq_len]
545
+
546
+ return mean_temp, mean_feat
547
+
548
+ def forward(self, audio_features, text_features):
549
+ # Проекция признаков
550
+ audio = self.audio_proj(audio_features.float())
551
+ text = self.text_proj(text_features.float())
552
+
553
+ # Адаптивный пулинг
554
+ min_len = min(audio.size(1), text.size(1))
555
+ audio = self.adaptive_temporal_pool(audio, min_len)
556
+ text = self.adaptive_temporal_pool(text, min_len)
557
+
558
+ # Кросс-модальное взаимодействие
559
+ for i in range(self.tr_layer_number):
560
+ attn_audio = self.audio_to_text_attn[i](text, audio, audio)
561
+ attn_text = self.text_to_audio_attn[i](audio, text, text)
562
+
563
+ audio = audio + attn_audio
564
+ text = text + attn_text
565
+
566
+ # Агрегация признаков
567
+ audio_pool_temp, audio_pool_feat = self._pool_features(audio)
568
+ text_pool_temp, text_pool_feat = self._pool_features(text)
569
+
570
+ # print(audio_pool_temp.shape, audio_pool_feat.shape, text_pool_temp.shape, text_pool_feat.shape)
571
+
572
+ graph_feat = self.graph_fusion_feat(audio_pool_feat, text_pool_feat)
573
+ graph_temp = self.graph_fusion_temp(audio_pool_temp, text_pool_temp)
574
+
575
+ # print(graph_feat.shape, graph_temp.shape)
576
+ # print(torch.mean(graph_feat, dim=1).shape, torch.mean(graph_temp, dim=1).shape)
577
+
578
+ # graph_feat = self.fc_feat(graph_feat)
579
+ # graph_temp = self.fc_temp(graph_temp)
580
+
581
+ # Классификация
582
+ features = torch.cat([graph_feat, graph_temp], dim=1)
583
+
584
+ # print(graph_feat.shape, graph_temp.shape, features.shape)
585
+ return self.classifier(features)
586
+
587
+ def adaptive_temporal_pool(self, x, target_len):
588
+ """Адаптивное изменение временной длины"""
589
+ if x.size(1) == target_len:
590
+ return x
591
+
592
+ return F.interpolate(
593
+ x.permute(0, 2, 1),
594
+ size=target_len,
595
+ mode='linear',
596
+ align_corners=False
597
+ ).permute(0, 2, 1)
598
+
599
+ def _init_weights(self):
600
+ for m in self.modules():
601
+ if isinstance(m, nn.Linear):
602
+ nn.init.xavier_uniform_(m.weight)
603
+ if m.bias is not None:
604
+ nn.init.constant_(m.bias, 0)
605
+ elif isinstance(m, nn.LayerNorm):
606
+ nn.init.constant_(m.weight, 1)
607
+ nn.init.constant_(m.bias, 0)
608
+
609
+
610
+ class BiGatedGraphFormer(nn.Module):
611
+ def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512, hidden_dim_gated=128,
612
+ num_transformer_heads=2, num_graph_heads = 2, positional_encoding=True, dropout=0.1, mode='mean',
613
+ device="cuda", tr_layer_number=1, out_features=128, num_classes=7):
614
+ super(BiGatedGraphFormer, self).__init__()
615
+
616
+ self.mode = mode
617
+ self.hidden_dim = hidden_dim
618
+ self.seg_len = seg_len
619
+ self.tr_layer_number = tr_layer_number
620
+
621
+ # Проекционные слои с нормализацией
622
+ self.audio_proj = nn.Sequential(
623
+ nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity(),
624
+ nn.LayerNorm(hidden_dim),
625
+ nn.Dropout(dropout)
626
+ )
627
+
628
+ self.text_proj = nn.Sequential(
629
+ nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity(),
630
+ nn.LayerNorm(hidden_dim),
631
+ nn.Dropout(dropout)
632
+ )
633
+
634
+ # Transformer layers
635
+ self.audio_to_text_attn = nn.ModuleList([
636
+ TransformerEncoderLayer(
637
+ input_dim=hidden_dim,
638
+ num_heads=num_transformer_heads,
639
+ dropout=dropout,
640
+ positional_encoding=positional_encoding
641
+ ) for _ in range(tr_layer_number)
642
+ ])
643
+
644
+ self.text_to_audio_attn = nn.ModuleList([
645
+ TransformerEncoderLayer(
646
+ input_dim=hidden_dim,
647
+ num_heads=num_transformer_heads,
648
+ dropout=dropout,
649
+ positional_encoding=positional_encoding
650
+ ) for _ in range(tr_layer_number)
651
+ ])
652
+
653
+ self.graph_fusion_feat = GraphFusionLayer(self.seg_len, heads=num_graph_heads, out_mean=False)
654
+ self.graph_fusion_temp = GraphFusionLayer(hidden_dim, heads=num_graph_heads, out_mean=False)
655
+
656
+ self.gated_feat = GAL(self.seg_len, self.seg_len, hidden_dim_gated, dropout_rate=dropout)
657
+ self.gated_temp = GAL(hidden_dim, hidden_dim, hidden_dim_gated, dropout_rate=dropout)
658
+
659
+ # Автоматический расчёт размерности для классификатора
660
+ self._calculate_classifier_input_dim()
661
+
662
+ # Классификатор
663
+ self.classifier = nn.Sequential(
664
+ nn.Linear(hidden_dim_gated*2, out_features),
665
+ nn.LayerNorm(out_features),
666
+ nn.GELU(),
667
+ nn.Dropout(dropout),
668
+ nn.Linear(out_features, num_classes)
669
+ )
670
+
671
+ # Final projection of the graph outputs
672
+ self.fc_graph_feat = nn.Sequential(
673
+ nn.Linear(self.seg_len, hidden_dim_gated),
674
+ nn.LayerNorm(hidden_dim_gated),
675
+ nn.Dropout(dropout)
676
+ )
677
+
678
+ self.fc_graph_temp = nn.Sequential(
679
+ nn.Linear(hidden_dim, hidden_dim_gated),
680
+ nn.LayerNorm(hidden_dim_gated),
681
+ nn.Dropout(dropout)
682
+ )
683
+
684
+ # Финальная проекция gated
685
+ self.fc_gated_feat = nn.Sequential(
686
+ nn.Linear(hidden_dim_gated, hidden_dim_gated),
687
+ nn.LayerNorm(hidden_dim_gated),
688
+ nn.Dropout(dropout)
689
+ )
690
+
691
+ self.fc_gated_temp = nn.Sequential(
692
+ nn.Linear(hidden_dim_gated, hidden_dim_gated),
693
+ nn.LayerNorm(hidden_dim_gated),
694
+ nn.Dropout(dropout)
695
+ )
696
+
697
+ self._init_weights()
698
+
699
+ def _calculate_classifier_input_dim(self):
700
+ """Вычисляет размер входных признаков для классификатора"""
701
+ # Тестовый проход через пулинг с dummy-данными
702
+ dummy_audio = torch.randn(1, self.seg_len, self.hidden_dim)
703
+ dummy_text = torch.randn(1, self.seg_len, self.hidden_dim)
704
+
705
+ audio_pool_temp, audio_pool_feat = self._pool_features(dummy_audio)
706
+ # text_pool_temp, _ = self._pool_features(dummy_text)
707
+
708
+ combined = torch.cat([audio_pool_temp, audio_pool_feat], dim=1)
709
+ self.classifier_input_dim = combined.size(1)
710
+
711
+ def _pool_features(self, x):
712
+ # Статистики по временной оси (seq_len)
713
+ mean_temp = x.mean(dim=1) # [batch, hidden_dim]
714
+
715
+ # Статистики по feature оси (hidden_dim)
716
+ mean_feat = x.mean(dim=-1) # [batch, seq_len]
717
+
718
+ return mean_temp, mean_feat
719
+
720
+ def forward(self, audio_features, text_features):
721
+ # Проекция признаков
722
+ audio = self.audio_proj(audio_features.float())
723
+ text = self.text_proj(text_features.float())
724
+
725
+ # Адаптивный пулинг
726
+ min_len = min(audio.size(1), text.size(1))
727
+ audio = self.adaptive_temporal_pool(audio, min_len)
728
+ text = self.adaptive_temporal_pool(text, min_len)
729
+
730
+ # Кросс-модальное взаимодействие
731
+ for i in range(self.tr_layer_number):
732
+ attn_audio = self.audio_to_text_attn[i](text, audio, audio)
733
+ attn_text = self.text_to_audio_attn[i](audio, text, text)
734
+
735
+ audio = audio + attn_audio
736
+ text = text + attn_text
737
+
738
+ # Агрегация признаков
739
+ audio_pool_temp, audio_pool_feat = self._pool_features(audio)
740
+ text_pool_temp, text_pool_feat = self._pool_features(text)
741
+
742
+ # print(audio_pool_temp.shape, audio_pool_feat.shape, text_pool_temp.shape, text_pool_feat.shape)
743
+
744
+ graph_feat = self.graph_fusion_feat(audio_pool_feat, text_pool_feat)
745
+ graph_temp = self.graph_fusion_temp(audio_pool_temp, text_pool_temp)
746
+
747
+ gated_feat = self.gated_feat(graph_feat[:, 0, :], graph_feat[:, 1, :])
748
+ gated_temp = self.gated_temp(graph_temp[:, 0, :], graph_temp[:, 1, :])
749
+
750
+ fused_feat = self.fc_graph_feat(torch.mean(graph_feat, dim=1)) + self.fc_gated_feat(gated_feat)
751
+ fused_temp = self.fc_graph_temp(torch.mean(graph_temp, dim=1)) + self.fc_gated_temp(gated_temp)
752
+
753
+ # print(graph_feat.shape, graph_temp.shape)
754
+ # print(torch.mean(graph_feat, dim=1).shape, torch.mean(graph_temp, dim=1).shape)
755
+
756
+ # graph_feat = self.fc_feat(graph_feat)
757
+ # graph_temp = self.fc_temp(graph_temp)
758
+
759
+ # Классификация
760
+ features = torch.cat([fused_feat, fused_temp], dim=1)
761
+
762
+ # print(graph_feat.shape, graph_temp.shape, features.shape)
763
+ return self.classifier(features)
764
+
765
+ def adaptive_temporal_pool(self, x, target_len):
766
+ """Адаптивное изменение временной длины"""
767
+ if x.size(1) == target_len:
768
+ return x
769
+
770
+ return F.interpolate(
771
+ x.permute(0, 2, 1),
772
+ size=target_len,
773
+ mode='linear',
774
+ align_corners=False
775
+ ).permute(0, 2, 1)
776
+
777
+ def _init_weights(self):
778
+ for m in self.modules():
779
+ if isinstance(m, nn.Linear):
780
+ nn.init.xavier_uniform_(m.weight)
781
+ if m.bias is not None:
782
+ nn.init.constant_(m.bias, 0)
783
+ elif isinstance(m, nn.LayerNorm):
784
+ nn.init.constant_(m.weight, 1)
785
+ nn.init.constant_(m.bias, 0)
786
+
787
+
788
+ class BiFormerWithProb(nn.Module):
789
+ def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512, hidden_dim_gated=128,
790
+ num_transformer_heads=2, num_graph_heads=2, positional_encoding=True, dropout=0.1, mode='mean',
791
+ device="cuda", tr_layer_number=1, out_features=128, num_classes=7):
792
+ super(BiFormerWithProb, self).__init__()
793
+
794
+ self.mode = mode
795
+ self.hidden_dim = hidden_dim
796
+ self.seg_len = seg_len
797
+ self.tr_layer_number = tr_layer_number
798
+
799
+ # Проекционные слои с нормализацией
800
+ self.audio_proj = nn.Sequential(
801
+ nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity(),
802
+ nn.LayerNorm(hidden_dim),
803
+ nn.Dropout(dropout)
804
+ )
805
+
806
+ self.text_proj = nn.Sequential(
807
+ nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity(),
808
+ nn.LayerNorm(hidden_dim),
809
+ nn.Dropout(dropout)
810
+ )
811
+ # self.audio_proj = nn.Sequential(
812
+ # nn.Conv1d(audio_dim, hidden_dim, 1),
813
+ # nn.GELU(),
814
+ # )
815
+
816
+ # self.text_proj = nn.Sequential(
817
+ # nn.Conv1d(text_dim, hidden_dim, 1),
818
+ # nn.GELU(),
819
+ # )
820
+
821
+ # Transformer layers
822
+ self.audio_to_text_attn = nn.ModuleList([
823
+ TransformerEncoderLayer(
824
+ input_dim=hidden_dim,
825
+ num_heads=num_transformer_heads,
826
+ dropout=dropout,
827
+ positional_encoding=positional_encoding
828
+ ) for _ in range(tr_layer_number)
829
+ ])
830
+
831
+ self.text_to_audio_attn = nn.ModuleList([
832
+ TransformerEncoderLayer(
833
+ input_dim=hidden_dim,
834
+ num_heads=num_transformer_heads,
835
+ dropout=dropout,
836
+ positional_encoding=positional_encoding
837
+ ) for _ in range(tr_layer_number)
838
+ ])
839
+
840
+ # Автоматический расчёт размерности для классификатора
841
+ self._calculate_classifier_input_dim()
842
+
843
+ # Классификатор
844
+ self.classifier = nn.Sequential(
845
+ nn.Linear(self.classifier_input_dim, out_features),
846
+ nn.LayerNorm(out_features),
847
+ nn.GELU(),
848
+ nn.Dropout(dropout),
849
+ nn.Linear(out_features, num_classes)
850
+ )
851
+
852
+ self.pred_fusion = PredictionsFusion(num_matrices=3, num_classes=num_classes)
853
+
854
+ self._init_weights()
855
+
856
+ def _calculate_classifier_input_dim(self):
857
+ """Вычисляет размер входных признаков для классификатора"""
858
+ # Тестовый проход через пулинг с dummy-данными
859
+ dummy_audio = torch.randn(1, self.seg_len, self.hidden_dim)
860
+ dummy_text = torch.randn(1, self.seg_len, self.hidden_dim)
861
+
862
+ audio_pool = self._pool_features(dummy_audio)
863
+ text_pool = self._pool_features(dummy_text)
864
+
865
+ combined = torch.cat([audio_pool, text_pool], dim=1)
866
+ self.classifier_input_dim = combined.size(1)
867
+
868
+ def _pool_features(self, x):
869
+ # Статистики по временной оси (seq_len)
870
+ mean_temp = x.mean(dim=1) # [batch, hidden_dim]
871
+
872
+ # Статистики по feature оси (hidden_dim)
873
+ mean_feat = x.mean(dim=-1) # [batch, seq_len]
874
+
875
+ return torch.cat([mean_temp, mean_feat], dim=1)
876
+
877
+ def forward(self, audio_features, text_features, audio_pred, text_pred):
878
+ # Проекция признаков
879
+ # audio = self.audio_proj(audio_features.permute(0,2,1)).permute(0,2,1)
880
+ # text = self.text_proj(text_features.permute(0,2,1)).permute(0,2,1)
881
+ audio = self.audio_proj(audio_features.float())
882
+ text = self.text_proj(text_features.float())
883
+
884
+ # Адаптивный пулинг
885
+ min_len = min(audio.size(1), text.size(1))
886
+ audio = self.adaptive_temporal_pool(audio, min_len)
887
+ text = self.adaptive_temporal_pool(text, min_len)
888
+
889
+ # Кросс-модальное взаимодействие
890
+ for i in range(self.tr_layer_number):
891
+ attn_audio = self.audio_to_text_attn[i](text, audio, audio)
892
+ attn_text = self.text_to_audio_attn[i](audio, text, text)
893
+ audio = audio + attn_audio
894
+ text = text + attn_text
895
+
896
+ # Агрегация признаков
897
+ audio_pool = self._pool_features(audio)
898
+ text_pool = self._pool_features(text)
899
+
900
+ # Классификация
901
+ features = torch.cat([audio_pool, text_pool], dim=1)
902
+
903
+ out = self.classifier(features)
904
+
905
+ w_out = self.pred_fusion([audio_pred, text_pred, out])
906
+ return w_out
907
+
908
+ def adaptive_temporal_pool(self, x, target_len):
909
+ """Адаптивное изменение временной длины"""
910
+ if x.size(1) == target_len:
911
+ return x
912
+
913
+ return F.interpolate(
914
+ x.permute(0, 2, 1),
915
+ size=target_len,
916
+ mode='linear',
917
+ align_corners=False
918
+ ).permute(0, 2, 1)
919
+
920
+ def _init_weights(self):
921
+ for m in self.modules():
922
+ if isinstance(m, nn.Linear):
923
+ nn.init.xavier_uniform_(m.weight)
924
+ if m.bias is not None:
925
+ nn.init.constant_(m.bias, 0)
926
+ elif isinstance(m, nn.LayerNorm):
927
+ nn.init.constant_(m.weight, 1)
928
+ nn.init.constant_(m.bias, 0)
929
+
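Note: a minimal smoke-test sketch for the *WithProb fusion pattern above, assuming the import path used elsewhere in this repo (models.models) and that TransformerEncoderLayer / PredictionsFusion behave exactly as they are called here; shapes and values are illustrative only.

# Sketch: dummy forward pass through BiFormerWithProb
import torch
from models.models import BiFormerWithProb

model = BiFormerWithProb(audio_dim=1024, text_dim=1024, seg_len=44,
                         hidden_dim=512, num_classes=7, device="cpu")
audio_feats = torch.randn(2, 44, 1024)                 # [batch, frames, audio_dim]
text_feats = torch.randn(2, 44, 1024)                  # [batch, tokens, text_dim]
audio_pred = torch.softmax(torch.randn(2, 7), dim=1)   # per-modality class probabilities
text_pred = torch.softmax(torch.randn(2, 7), dim=1)
logits = model(audio_feats, text_feats, audio_pred, text_pred)   # [2, 7] after PredictionsFusion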
930
+ class BiGraphFormerWithProb(nn.Module):
931
+ def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512, hidden_dim_gated=128,
932
+ num_transformer_heads=2, num_graph_heads = 2, positional_encoding=True, dropout=0.1, mode='mean',
933
+ device="cuda", tr_layer_number=1, out_features=128, num_classes=7):
934
+ super(BiGraphFormerWithProb, self).__init__()
935
+
936
+ self.mode = mode
937
+ self.hidden_dim = hidden_dim
938
+ self.seg_len = seg_len
939
+ self.tr_layer_number = tr_layer_number
940
+
941
+ # Проекционные слои с нормализацией
942
+ self.audio_proj = nn.Sequential(
943
+ nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity(),
944
+ nn.LayerNorm(hidden_dim),
945
+ nn.Dropout(dropout)
946
+ )
947
+
948
+ self.text_proj = nn.Sequential(
949
+ nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity(),
950
+ nn.LayerNorm(hidden_dim),
951
+ nn.Dropout(dropout)
952
+ )
953
+
954
+ # Трансформерные слои (сохраняем вашу реализацию)
955
+ self.audio_to_text_attn = nn.ModuleList([
956
+ TransformerEncoderLayer(
957
+ input_dim=hidden_dim,
958
+ num_heads=num_transformer_heads,
959
+ dropout=dropout,
960
+ positional_encoding=positional_encoding
961
+ ) for _ in range(tr_layer_number)
962
+ ])
963
+
964
+ self.text_to_audio_attn = nn.ModuleList([
965
+ TransformerEncoderLayer(
966
+ input_dim=hidden_dim,
967
+ num_heads=num_transformer_heads,
968
+ dropout=dropout,
969
+ positional_encoding=positional_encoding
970
+ ) for _ in range(tr_layer_number)
971
+ ])
972
+
973
+ self.graph_fusion_feat = GraphFusionLayer(self.seg_len, heads=num_graph_heads)
974
+ self.graph_fusion_temp = GraphFusionLayer(hidden_dim, heads=num_graph_heads)
975
+
976
+ # Автоматический расчёт размерности для классификатора
977
+ self._calculate_classifier_input_dim()
978
+
979
+ # Классификатор
980
+ self.classifier = nn.Sequential(
981
+ nn.Linear(self.classifier_input_dim, out_features),
982
+ nn.LayerNorm(out_features),
983
+ nn.GELU(),
984
+ nn.Dropout(dropout),
985
+ nn.Linear(out_features, num_classes)
986
+ )
987
+
988
+ # Финальная проекция графов
989
+ self.fc_feat = nn.Sequential(
990
+ nn.Linear(self.seg_len, self.seg_len),
991
+ nn.LayerNorm(self.seg_len),
992
+ nn.Dropout(dropout)
993
+ )
994
+
995
+ self.fc_temp = nn.Sequential(
996
+ nn.Linear(hidden_dim, hidden_dim),
997
+ nn.LayerNorm(hidden_dim),
998
+ nn.Dropout(dropout)
999
+ )
1000
+
1001
+ self.pred_fusion = PredictionsFusion(num_matrices=3, num_classes=num_classes)
1002
+
1003
+ self._init_weights()
1004
+
1005
+ def _calculate_classifier_input_dim(self):
1006
+ """Вычисляет размер входных признаков для классификатора"""
1007
+ # Тестовый проход через пулинг с dummy-данными
1008
+ dummy_audio = torch.randn(1, self.seg_len, self.hidden_dim)
1009
+ dummy_text = torch.randn(1, self.seg_len, self.hidden_dim)
1010
+
1011
+ audio_pool_temp, audio_pool_feat = self._pool_features(dummy_audio)
1012
+ # text_pool_temp, _ = self._pool_features(dummy_text)
1013
+
1014
+ combined = torch.cat([audio_pool_temp, audio_pool_feat], dim=1)
1015
+ self.classifier_input_dim = combined.size(1)
1016
+
1017
+ def _pool_features(self, x):
1018
+ # Статистики по временной оси (seq_len)
1019
+ mean_temp = x.mean(dim=1) # [batch, hidden_dim]
1020
+
1021
+ # Статистики по feature оси (hidden_dim)
1022
+ mean_feat = x.mean(dim=-1) # [batch, seq_len]
1023
+
1024
+ return mean_temp, mean_feat
1025
+
1026
+ def forward(self, audio_features, text_features, audio_pred, text_pred):
1027
+ # Проекция признаков
1028
+ audio = self.audio_proj(audio_features.float())
1029
+ text = self.text_proj(text_features.float())
1030
+
1031
+ # Адаптивный пулинг
1032
+ min_len = min(audio.size(1), text.size(1))
1033
+ audio = self.adaptive_temporal_pool(audio, min_len)
1034
+ text = self.adaptive_temporal_pool(text, min_len)
1035
+
1036
+ # Кросс-модальное взаимодействие
1037
+ for i in range(self.tr_layer_number):
1038
+ attn_audio = self.audio_to_text_attn[i](text, audio, audio)
1039
+ attn_text = self.text_to_audio_attn[i](audio, text, text)
1040
+
1041
+ audio = audio + attn_audio
1042
+ text = text + attn_text
1043
+
1044
+ # Агрегация признаков
1045
+ audio_pool_temp, audio_pool_feat = self._pool_features(audio)
1046
+ text_pool_temp, text_pool_feat = self._pool_features(text)
1047
+
1048
+ # print(audio_pool_temp.shape, audio_pool_feat.shape, text_pool_temp.shape, text_pool_feat.shape)
1049
+
1050
+ graph_feat = self.graph_fusion_feat(audio_pool_feat, text_pool_feat)
1051
+ graph_temp = self.graph_fusion_temp(audio_pool_temp, text_pool_temp)
1052
+
1053
+ # print(graph_feat.shape, graph_temp.shape)
1054
+ # print(torch.mean(graph_feat, dim=1).shape, torch.mean(graph_temp, dim=1).shape)
1055
+
1056
+ # graph_feat = self.fc_feat(graph_feat)
1057
+ # graph_temp = self.fc_temp(graph_temp)
1058
+
1059
+ # Классификация
1060
+ features = torch.cat([graph_feat, graph_temp], dim=1)
1061
+
1062
+ # print(graph_feat.shape, graph_temp.shape, features.shape)
1063
+ out = self.classifier(features)
1064
+
1065
+ w_out = self.pred_fusion([audio_pred, text_pred, out])
1066
+ return w_out
1067
+
1068
+ def adaptive_temporal_pool(self, x, target_len):
1069
+ """Адаптивное изменение временной длины"""
1070
+ if x.size(1) == target_len:
1071
+ return x
1072
+
1073
+ return F.interpolate(
1074
+ x.permute(0, 2, 1),
1075
+ size=target_len,
1076
+ mode='linear',
1077
+ align_corners=False
1078
+ ).permute(0, 2, 1)
1079
+
1080
+ def _init_weights(self):
1081
+ for m in self.modules():
1082
+ if isinstance(m, nn.Linear):
1083
+ nn.init.xavier_uniform_(m.weight)
1084
+ if m.bias is not None:
1085
+ nn.init.constant_(m.bias, 0)
1086
+ elif isinstance(m, nn.LayerNorm):
1087
+ nn.init.constant_(m.weight, 1)
1088
+ nn.init.constant_(m.bias, 0)
1089
+
1090
+ class BiGatedGraphFormerWithProb(nn.Module):
1091
+ def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512, hidden_dim_gated=128,
1092
+ num_transformer_heads=2, num_graph_heads = 2, positional_encoding=True, dropout=0.1, mode='mean',
1093
+ device="cuda", tr_layer_number=1, out_features=128, num_classes=7):
1094
+ super(BiGatedGraphFormerWithProb, self).__init__()
1095
+
1096
+ self.mode = mode
1097
+ self.hidden_dim = hidden_dim
1098
+ self.seg_len = seg_len
1099
+ self.tr_layer_number = tr_layer_number
1100
+
1101
+ # Проекционные слои с нормализацией
1102
+ self.audio_proj = nn.Sequential(
1103
+ nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity(),
1104
+ nn.LayerNorm(hidden_dim),
1105
+ nn.Dropout(dropout)
1106
+ )
1107
+
1108
+ self.text_proj = nn.Sequential(
1109
+ nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity(),
1110
+ nn.LayerNorm(hidden_dim),
1111
+ nn.Dropout(dropout)
1112
+ )
1113
+
1114
+ # Трансформерные слои (сохраняем вашу реализацию)
1115
+ self.audio_to_text_attn = nn.ModuleList([
1116
+ TransformerEncoderLayer(
1117
+ input_dim=hidden_dim,
1118
+ num_heads=num_transformer_heads,
1119
+ dropout=dropout,
1120
+ positional_encoding=positional_encoding
1121
+ ) for _ in range(tr_layer_number)
1122
+ ])
1123
+
1124
+ self.text_to_audio_attn = nn.ModuleList([
1125
+ TransformerEncoderLayer(
1126
+ input_dim=hidden_dim,
1127
+ num_heads=num_transformer_heads,
1128
+ dropout=dropout,
1129
+ positional_encoding=positional_encoding
1130
+ ) for _ in range(tr_layer_number)
1131
+ ])
1132
+
1133
+ self.graph_fusion_feat = GraphFusionLayer(self.seg_len, heads=num_graph_heads, out_mean=False)
1134
+ self.graph_fusion_temp = GraphFusionLayer(hidden_dim, heads=num_graph_heads, out_mean=False)
1135
+
1136
+ self.gated_feat = GAL(self.seg_len, self.seg_len, hidden_dim_gated, dropout_rate=dropout)
1137
+ self.gated_temp = GAL(hidden_dim, hidden_dim, hidden_dim_gated, dropout_rate=dropout)
1138
+
1139
+ # Автоматический расчёт размерности для классификатора
1140
+ self._calculate_classifier_input_dim()
1141
+
1142
+ # Классификатор
1143
+ self.classifier = nn.Sequential(
1144
+ nn.Linear(hidden_dim_gated*2, out_features),
1145
+ nn.LayerNorm(out_features),
1146
+ nn.GELU(),
1147
+ nn.Dropout(dropout),
1148
+ nn.Linear(out_features, num_classes)
1149
+ )
1150
+
1151
+ # Финальная проекция графов
1152
+ self.fc_graph_feat = nn.Sequential(
1153
+ nn.Linear(self.seg_len, hidden_dim_gated),
1154
+ nn.LayerNorm(hidden_dim_gated),
1155
+ nn.Dropout(dropout)
1156
+ )
1157
+
1158
+ self.fc_graph_temp = nn.Sequential(
1159
+ nn.Linear(hidden_dim, hidden_dim_gated),
1160
+ nn.LayerNorm(hidden_dim_gated),
1161
+ nn.Dropout(dropout)
1162
+ )
1163
+
1164
+ # Финальная проекция gated
1165
+ self.fc_gated_feat = nn.Sequential(
1166
+ nn.Linear(hidden_dim_gated, hidden_dim_gated),
1167
+ nn.LayerNorm(hidden_dim_gated),
1168
+ nn.Dropout(dropout)
1169
+ )
1170
+
1171
+ self.fc_gated_temp = nn.Sequential(
1172
+ nn.Linear(hidden_dim_gated, hidden_dim_gated),
1173
+ nn.LayerNorm(hidden_dim_gated),
1174
+ nn.Dropout(dropout)
1175
+ )
1176
+
1177
+ self.pred_fusion = PredictionsFusion(num_matrices=3, num_classes=num_classes)
1178
+
1179
+ self._init_weights()
1180
+
1181
+ def _calculate_classifier_input_dim(self):
1182
+ """Вычисляет размер входных признаков для классификатора"""
1183
+ # Тестовый проход через пулинг с dummy-данными
1184
+ dummy_audio = torch.randn(1, self.seg_len, self.hidden_dim)
1185
+ dummy_text = torch.randn(1, self.seg_len, self.hidden_dim)
1186
+
1187
+ audio_pool_temp, audio_pool_feat = self._pool_features(dummy_audio)
1188
+ # text_pool_temp, _ = self._pool_features(dummy_text)
1189
+
1190
+ combined = torch.cat([audio_pool_temp, audio_pool_feat], dim=1)
1191
+ self.classifier_input_dim = combined.size(1)
1192
+
1193
+ def _pool_features(self, x):
1194
+ # Статистики по временной оси (seq_len)
1195
+ mean_temp = x.mean(dim=1) # [batch, hidden_dim]
1196
+
1197
+ # Статистики по feature оси (hidden_dim)
1198
+ mean_feat = x.mean(dim=-1) # [batch, seq_len]
1199
+
1200
+ return mean_temp, mean_feat
1201
+
1202
+ def forward(self, audio_features, text_features, audio_pred, text_pred):
1203
+ # Проекция признаков
1204
+ audio = self.audio_proj(audio_features.float())
1205
+ text = self.text_proj(text_features.float())
1206
+
1207
+ # Адаптивный пулинг
1208
+ min_len = min(audio.size(1), text.size(1))
1209
+ audio = self.adaptive_temporal_pool(audio, min_len)
1210
+ text = self.adaptive_temporal_pool(text, min_len)
1211
+
1212
+ # Кросс-модальное взаимодействие
1213
+ for i in range(self.tr_layer_number):
1214
+ attn_audio = self.audio_to_text_attn[i](text, audio, audio)
1215
+ attn_text = self.text_to_audio_attn[i](audio, text, text)
1216
+
1217
+ audio = audio + attn_audio
1218
+ text = text + attn_text
1219
+
1220
+ # Агрегация признаков
1221
+ audio_pool_temp, audio_pool_feat = self._pool_features(audio)
1222
+ text_pool_temp, text_pool_feat = self._pool_features(text)
1223
+
1224
+ # print(audio_pool_temp.shape, audio_pool_feat.shape, text_pool_temp.shape, text_pool_feat.shape)
1225
+
1226
+ graph_feat = self.graph_fusion_feat(audio_pool_feat, text_pool_feat)
1227
+ graph_temp = self.graph_fusion_temp(audio_pool_temp, text_pool_temp)
1228
+
1229
+ gated_feat = self.gated_feat(graph_feat[:, 0, :], graph_feat[:, 1, :])
1230
+ gated_temp = self.gated_temp(graph_temp[:, 0, :], graph_temp[:, 1, :])
1231
+
1232
+ fused_feat = self.fc_graph_feat(torch.mean(graph_feat, dim=1)) + self.fc_gated_feat(gated_feat)
1233
+ fused_temp = self.fc_graph_temp(torch.mean(graph_temp, dim=1)) + self.fc_gated_temp(gated_temp)
1234
+
1235
+ # print(graph_feat.shape, graph_temp.shape)
1236
+ # print(torch.mean(graph_feat, dim=1).shape, torch.mean(graph_temp, dim=1).shape)
1237
+
1238
+ # graph_feat = self.fc_feat(graph_feat)
1239
+ # graph_temp = self.fc_temp(graph_temp)
1240
+
1241
+ # Классификация
1242
+ features = torch.cat([fused_feat, fused_temp], dim=1)
1243
+
1244
+ # print(graph_feat.shape, graph_temp.shape, features.shape)
1245
+ out = self.classifier(features)
1246
+
1247
+ w_out = self.pred_fusion([audio_pred, text_pred, out])
1248
+ return w_out
1249
+
1250
+ def adaptive_temporal_pool(self, x, target_len):
1251
+ """Адаптивное изменение временной длины"""
1252
+ if x.size(1) == target_len:
1253
+ return x
1254
+
1255
+ return F.interpolate(
1256
+ x.permute(0, 2, 1),
1257
+ size=target_len,
1258
+ mode='linear',
1259
+ align_corners=False
1260
+ ).permute(0, 2, 1)
1261
+
1262
+ def _init_weights(self):
1263
+ for m in self.modules():
1264
+ if isinstance(m, nn.Linear):
1265
+ nn.init.xavier_uniform_(m.weight)
1266
+ if m.bias is not None:
1267
+ nn.init.constant_(m.bias, 0)
1268
+ elif isinstance(m, nn.LayerNorm):
1269
+ nn.init.constant_(m.weight, 1)
1270
+ nn.init.constant_(m.bias, 0)
1271
+
1272
+ class BiGatedFormer(nn.Module):
1273
+ def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512, hidden_dim_gated=128,
1274
+ num_transformer_heads=2, num_graph_heads = 2, positional_encoding=True, dropout=0.1, mode='mean',
1275
+ device="cuda", tr_layer_number=1, out_features=128, num_classes=7):
1276
+ super(BiGatedFormer, self).__init__()
1277
+
1278
+ self.mode = mode
1279
+ self.hidden_dim = hidden_dim
1280
+ self.seg_len = seg_len
1281
+ self.tr_layer_number = tr_layer_number
1282
+
1283
+ # Проекционные слои с нормализацией
1284
+ self.audio_proj = nn.Sequential(
1285
+ nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity(),
1286
+ nn.LayerNorm(hidden_dim),
1287
+ nn.Dropout(dropout)
1288
+ )
1289
+
1290
+ self.text_proj = nn.Sequential(
1291
+ nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity(),
1292
+ nn.LayerNorm(hidden_dim),
1293
+ nn.Dropout(dropout)
1294
+ )
1295
+
1296
+ # Трансформерные слои (сохраняем вашу реализацию)
1297
+ self.audio_to_text_attn = nn.ModuleList([
1298
+ TransformerEncoderLayer(
1299
+ input_dim=hidden_dim,
1300
+ num_heads=num_transformer_heads,
1301
+ dropout=dropout,
1302
+ positional_encoding=positional_encoding
1303
+ ) for _ in range(tr_layer_number)
1304
+ ])
1305
+
1306
+ self.text_to_audio_attn = nn.ModuleList([
1307
+ TransformerEncoderLayer(
1308
+ input_dim=hidden_dim,
1309
+ num_heads=num_transformer_heads,
1310
+ dropout=dropout,
1311
+ positional_encoding=positional_encoding
1312
+ ) for _ in range(tr_layer_number)
1313
+ ])
1314
+
1315
+ # self.graph_fusion_feat = GraphFusionLayer(self.seg_len, heads=num_graph_heads, out_mean=False)
1316
+ # self.graph_fusion_temp = GraphFusionLayer(hidden_dim, heads=num_graph_heads, out_mean=False)
1317
+
1318
+ self.gated_feat = GAL(self.seg_len, self.seg_len, hidden_dim_gated, dropout_rate=dropout)
1319
+ self.gated_temp = GAL(hidden_dim, hidden_dim, hidden_dim_gated, dropout_rate=dropout)
1320
+
1321
+ # Автоматический расчёт размерности для классификатора
1322
+ self._calculate_classifier_input_dim()
1323
+
1324
+ # Классификатор
1325
+ self.classifier = nn.Sequential(
1326
+ nn.Linear(hidden_dim_gated*2, out_features),
1327
+ nn.LayerNorm(out_features),
1328
+ nn.GELU(),
1329
+ nn.Dropout(dropout),
1330
+ nn.Linear(out_features, num_classes)
1331
+ )
1332
+
1333
+ # Финальная проекция графов
1334
+ # self.fc_graph_feat = nn.Sequential(
1335
+ # nn.Linear(self.seg_len, hidden_dim_gated),
1336
+ # nn.LayerNorm(hidden_dim_gated),
1337
+ # nn.Dropout(dropout)
1338
+ # )
1339
+
1340
+ # self.fc_graph_temp = nn.Sequential(
1341
+ # nn.Linear(hidden_dim, hidden_dim_gated),
1342
+ # nn.LayerNorm(hidden_dim_gated),
1343
+ # nn.Dropout(dropout)
1344
+ # )
1345
+
1346
+ # Финальная проекция gated
1347
+ self.fc_gated_feat = nn.Sequential(
1348
+ nn.Linear(hidden_dim_gated, hidden_dim_gated),
1349
+ nn.LayerNorm(hidden_dim_gated),
1350
+ nn.Dropout(dropout)
1351
+ )
1352
+
1353
+ self.fc_gated_temp = nn.Sequential(
1354
+ nn.Linear(hidden_dim_gated, hidden_dim_gated),
1355
+ nn.LayerNorm(hidden_dim_gated),
1356
+ nn.Dropout(dropout)
1357
+ )
1358
+
1359
+ self._init_weights()
1360
+
1361
+ def _calculate_classifier_input_dim(self):
1362
+ """Вычисляет размер входных признаков для классификатора"""
1363
+ # Тестовый проход через пулинг с dummy-данными
1364
+ dummy_audio = torch.randn(1, self.seg_len, self.hidden_dim)
1365
+ dummy_text = torch.randn(1, self.seg_len, self.hidden_dim)
1366
+
1367
+ audio_pool_temp, audio_pool_feat = self._pool_features(dummy_audio)
1368
+ # text_pool_temp, _ = self._pool_features(dummy_text)
1369
+
1370
+ combined = torch.cat([audio_pool_temp, audio_pool_feat], dim=1)
1371
+ self.classifier_input_dim = combined.size(1)
1372
+
1373
+ def _pool_features(self, x):
1374
+ # Статистики по временной оси (seq_len)
1375
+ mean_temp = x.mean(dim=1) # [batch, hidden_dim]
1376
+
1377
+ # Статистики по feature оси (hidden_dim)
1378
+ mean_feat = x.mean(dim=-1) # [batch, seq_len]
1379
+
1380
+ return mean_temp, mean_feat
1381
+
1382
+ def forward(self, audio_features, text_features):
1383
+ # Проекция признаков
1384
+ audio = self.audio_proj(audio_features.float())
1385
+ text = self.text_proj(text_features.float())
1386
+
1387
+ # Адаптивный пулинг
1388
+ min_len = min(audio.size(1), text.size(1))
1389
+ audio = self.adaptive_temporal_pool(audio, min_len)
1390
+ text = self.adaptive_temporal_pool(text, min_len)
1391
+
1392
+ # Кросс-модальное взаимодействие
1393
+ for i in range(self.tr_layer_number):
1394
+ attn_audio = self.audio_to_text_attn[i](text, audio, audio)
1395
+ attn_text = self.text_to_audio_attn[i](audio, text, text)
1396
+
1397
+ audio = audio + attn_audio
1398
+ text = text + attn_text
1399
+
1400
+ # Агрегация признаков
1401
+ audio_pool_temp, audio_pool_feat = self._pool_features(audio)
1402
+ text_pool_temp, text_pool_feat = self._pool_features(text)
1403
+
1404
+ # print(audio_pool_temp.shape, audio_pool_feat.shape, text_pool_temp.shape, text_pool_feat.shape)
1405
+
1406
+ # graph_feat = self.graph_fusion_feat(audio_pool_feat, text_pool_feat)
1407
+ # graph_temp = self.graph_fusion_temp(audio_pool_temp, text_pool_temp)
1408
+
1409
+ gated_feat = self.gated_feat(audio_pool_feat, text_pool_feat)
1410
+ gated_temp = self.gated_temp(audio_pool_temp, text_pool_temp)
1411
+
1412
+ # fused_feat = self.fc_graph_feat(torch.mean(graph_feat, dim=1)) + self.fc_gated_feat(gated_feat)
1413
+ # fused_temp = self.fc_graph_temp(torch.mean(graph_temp, dim=1)) + self.fc_gated_feat(gated_temp)
1414
+
1415
+ # print(graph_feat.shape, graph_temp.shape)
1416
+ # print(torch.mean(graph_feat, dim=1).shape, torch.mean(graph_temp, dim=1).shape)
1417
+
1418
+ # graph_feat = self.fc_feat(graph_feat)
1419
+ # graph_temp = self.fc_temp(graph_temp)
1420
+
1421
+ # Классификация
1422
+ features = torch.cat([gated_feat, gated_temp], dim=1)
1423
+
1424
+ # print(graph_feat.shape, graph_temp.shape, features.shape)
1425
+ return self.classifier(features)
1426
+
1427
+ def adaptive_temporal_pool(self, x, target_len):
1428
+ """Адаптивное изменение временной длины"""
1429
+ if x.size(1) == target_len:
1430
+ return x
1431
+
1432
+ return F.interpolate(
1433
+ x.permute(0, 2, 1),
1434
+ size=target_len,
1435
+ mode='linear',
1436
+ align_corners=False
1437
+ ).permute(0, 2, 1)
1438
+
1439
+ def _init_weights(self):
1440
+ for m in self.modules():
1441
+ if isinstance(m, nn.Linear):
1442
+ nn.init.xavier_uniform_(m.weight)
1443
+ if m.bias is not None:
1444
+ nn.init.constant_(m.bias, 0)
1445
+ elif isinstance(m, nn.LayerNorm):
1446
+ nn.init.constant_(m.weight, 1)
1447
+ nn.init.constant_(m.bias, 0)
1448
+
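Note: the _pool_features helpers above reduce each modality to two mean vectors taken over different axes; a short sketch of the shape bookkeeping (illustrative sizes), which is why the gated feature branch operates on seg_len-dimensional inputs and the gated temporal branch on hidden_dim-dimensional inputs.

# Sketch: temporal vs. feature pooling used by _pool_features
import torch

x = torch.randn(8, 44, 512)    # [batch, seg_len, hidden_dim]
mean_temp = x.mean(dim=1)      # [8, 512] - average over time
mean_feat = x.mean(dim=-1)     # [8, 44]  - average over features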
1449
+ class BiMamba(nn.Module):
1450
+ def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512, mamba_d_state=16,
1451
+ d_discr=None, mamba_ker_size=4, mamba_layer_number=2, dropout=0.1, mode='', positional_encoding=False,
1452
+ out_features=128, num_classes=7, device="cuda"):
1453
+ super(BiMamba, self).__init__()
1454
+
1455
+ self.hidden_dim = hidden_dim
1456
+ self.seg_len = seg_len
1457
+ self.num_mamba_layers = mamba_layer_number
1458
+ self.device = device
1459
+
1460
+ # Проекционные слои для каждой модальности
1461
+ self.audio_proj = nn.Sequential(
1462
+ nn.Linear(audio_dim, hidden_dim),
1463
+ nn.LayerNorm(hidden_dim),
1464
+ nn.Dropout(dropout)
1465
+ )
1466
+
1467
+ self.text_proj = nn.Sequential(
1468
+ nn.Linear(text_dim, hidden_dim),
1469
+ nn.LayerNorm(hidden_dim),
1470
+ nn.Dropout(dropout)
1471
+ )
1472
+
1473
+ # Слой для объединения модальностей
1474
+ self.fusion_proj = nn.Sequential(
1475
+ nn.Linear(2 * hidden_dim, hidden_dim),
1476
+ nn.LayerNorm(hidden_dim),
1477
+ nn.Dropout(dropout)
1478
+ )
1479
+
1480
+ # Mamba блоки для обработки объединенных признаков
1481
+ mamba_params = {
1482
+ 'd_input': hidden_dim,
1483
+ 'd_model': hidden_dim,
1484
+ 'd_state': mamba_d_state,
1485
+ 'd_discr': d_discr,
1486
+ 'ker_size': mamba_ker_size
1487
+ }
1488
+
1489
+ self.mamba_blocks = nn.ModuleList([
1490
+ nn.Sequential(
1491
+ MambaBlock(**mamba_params),
1492
+ RMSNorm(hidden_dim)
1493
+ )
1494
+ for _ in range(self.num_mamba_layers)
1495
+ ])
1496
+
1497
+ # Автоматический расчет размерности классификатора
1498
+ # self._calculate_classifier_input_dim()
1499
+
1500
+ # Классификатор
1501
+ self.classifier = nn.Sequential(
1502
+ nn.Linear(self.seg_len + self.hidden_dim, out_features),
1503
+ nn.LayerNorm(out_features),
1504
+ nn.GELU(),
1505
+ nn.Dropout(dropout),
1506
+ nn.Linear(out_features, num_classes)
1507
+ )
1508
+
1509
+ self._init_weights()
1510
+
1511
+ # def _calculate_classifier_input_dim(self):
1512
+ # """Вычисляет размер входных признаков для классификатора"""
1513
+ # dummy = torch.randn(1, self.seg_len, self.hidden_dim)
1514
+ # pooled = self._pool_features(dummy)
1515
+ # self.classifier_input_dim = pooled.size(1)
1516
+
1517
+ def _pool_features(self, x):
1518
+ """Объединение временных и feature статистик"""
1519
+ mean_temp = x.mean(dim=1) # [batch, hidden_dim]
1520
+ mean_feat = x.mean(dim=-1) # [batch, seq_len]
1521
+ full_feature = torch.cat([mean_temp, mean_feat], dim=1)
1522
+ if full_feature.shape[-1] == self.seg_len+self.hidden_dim:
1523
+ return torch.cat([mean_temp, mean_feat], dim=1)
1524
+ else:
1525
+ pad_size = self.seg_len+self.hidden_dim - full_feature.shape[-1]
1526
+ return F.pad(full_feature, (0, pad_size), mode="constant", value=0)
1527
+
1528
+ def forward(self, audio_features, text_features):
1529
+ # Проекция признаков
1530
+ audio = self.audio_proj(audio_features.float()) # [B, T, D]
1531
+ text = self.text_proj(text_features.float()) # [B, T, D]
1532
+
1533
+ # Адаптивный пулинг к минимальной длине
1534
+ min_len = min(audio.size(1), text.size(1))
1535
+ audio = self._adaptive_pool(audio, min_len)
1536
+ text = self._adaptive_pool(text, min_len)
1537
+
1538
+ # Объединение модальностей
1539
+ fused = torch.cat([audio, text], dim=-1) # [B, T, 2*D]
1540
+ fused = self.fusion_proj(fused) # [B, T, D]
1541
+
1542
+ # Обработка объединенных признаков через Mamba
1543
+ for mamba_block in self.mamba_blocks:
1544
+ out, _ = mamba_block[0](fused, None)
1545
+ out = mamba_block[1](out)
1546
+ fused = fused + out # Residual connection
1547
+
1548
+ # Агрегация признаков и классификация
1549
+ pooled = self._pool_features(fused)
1550
+ return self.classifier(pooled)
1551
+
1552
+ def _adaptive_pool(self, x, target_len):
1553
+ """Адаптивное изменение временной длины"""
1554
+ if x.size(1) == target_len:
1555
+ return x
1556
+ return F.interpolate(
1557
+ x.permute(0, 2, 1),
1558
+ size=target_len,
1559
+ mode='linear',
1560
+ align_corners=False
1561
+ ).permute(0, 2, 1)
1562
+
1563
+ def _init_weights(self):
1564
+ for m in self.modules():
1565
+ if isinstance(m, nn.Linear):
1566
+ nn.init.xavier_uniform_(m.weight)
1567
+ if m.bias is not None:
1568
+ nn.init.constant_(m.bias, 0)
1569
+ elif isinstance(m, nn.LayerNorm):
1570
+ nn.init.constant_(m.weight, 1)
1571
+ nn.init.constant_(m.bias, 0)
1572
+
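Note: because the fused sequence can end up shorter than seg_len after length alignment, BiMamba pads its pooled vector to a fixed seg_len + hidden_dim so the classifier input size never changes; a toy sketch of that padding with illustrative sizes.

# Sketch: fixed-size pooling as in BiMamba._pool_features
import torch
import torch.nn.functional as F

seg_len, hidden_dim = 44, 512
x = torch.randn(4, 30, hidden_dim)                           # fused sequence shorter than seg_len
pooled = torch.cat([x.mean(dim=1), x.mean(dim=-1)], dim=1)   # [4, 512 + 30]
pad = seg_len + hidden_dim - pooled.shape[-1]                # 14 missing positions
pooled = F.pad(pooled, (0, pad))                             # [4, 556], constant regardless of true length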
1573
+ class BiMambaWithProb(nn.Module):
1574
+ def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512, mamba_d_state=16,
1575
+ d_discr=None, mamba_ker_size=4, mamba_layer_number=2, dropout=0.1, mode='',positional_encoding=False,
1576
+ out_features=128, num_classes=7, device="cuda"):
1577
+ super(BiMambaWithProb, self).__init__()
1578
+
1579
+ self.hidden_dim = hidden_dim
1580
+ self.seg_len = seg_len
1581
+ self.num_mamba_layers = mamba_layer_number
1582
+ self.device = device
1583
+
1584
+ # Проекционные слои для каждой модальности
1585
+ self.audio_proj = nn.Sequential(
1586
+ nn.Linear(audio_dim, hidden_dim),
1587
+ nn.LayerNorm(hidden_dim),
1588
+ nn.Dropout(dropout)
1589
+ )
1590
+
1591
+ self.text_proj = nn.Sequential(
1592
+ nn.Linear(text_dim, hidden_dim),
1593
+ nn.LayerNorm(hidden_dim),
1594
+ nn.Dropout(dropout)
1595
+ )
1596
+
1597
+ # Слой для объединения модальностей
1598
+ self.fusion_proj = nn.Sequential(
1599
+ nn.Linear(2 * hidden_dim, hidden_dim),
1600
+ nn.LayerNorm(hidden_dim),
1601
+ nn.Dropout(dropout)
1602
+ )
1603
+
1604
+ # Mamba блоки для обработки объединенных признаков
1605
+ mamba_params = {
1606
+ 'd_input': hidden_dim,
1607
+ 'd_model': hidden_dim,
1608
+ 'd_state': mamba_d_state,
1609
+ 'd_discr': d_discr,
1610
+ 'ker_size': mamba_ker_size
1611
+ }
1612
+
1613
+ self.mamba_blocks = nn.ModuleList([
1614
+ nn.Sequential(
1615
+ MambaBlock(**mamba_params),
1616
+ RMSNorm(hidden_dim)
1617
+ )
1618
+ for _ in range(self.num_mamba_layers)
1619
+ ])
1620
+
1621
+ # Автоматический расчет размерности классификатора
1622
+ # self._calculate_classifier_input_dim()
1623
+
1624
+ # Классификатор
1625
+ self.classifier = nn.Sequential(
1626
+ nn.Linear(self.seg_len + self.hidden_dim, out_features),
1627
+ nn.LayerNorm(out_features),
1628
+ nn.GELU(),
1629
+ nn.Dropout(dropout),
1630
+ nn.Linear(out_features, num_classes)
1631
+ )
1632
+
1633
+ self.pred_fusion = PredictionsFusion(num_matrices=3, num_classes=num_classes)
1634
+
1635
+ self._init_weights()
1636
+
1637
+ # def _calculate_classifier_input_dim(self):
1638
+ # """Вычисляет размер входных признаков для классификатора"""
1639
+ # dummy = torch.randn(1, self.seg_len, self.hidden_dim)
1640
+ # pooled = self._pool_features(dummy)
1641
+ # self.classifier_input_dim = pooled.size(1)
1642
+
1643
+ def _pool_features(self, x):
1644
+ """Объединение временных и feature статистик"""
1645
+ mean_temp = x.mean(dim=1) # [batch, hidden_dim]
1646
+ mean_feat = x.mean(dim=-1) # [batch, seq_len]
1647
+ full_feature = torch.cat([mean_temp, mean_feat], dim=1)
1648
+ if full_feature.shape[-1] == self.seg_len+self.hidden_dim:
1649
+ return torch.cat([mean_temp, mean_feat], dim=1)
1650
+ else:
1651
+ pad_size = self.seg_len+self.hidden_dim - full_feature.shape[-1]
1652
+ return F.pad(full_feature, (0, pad_size), mode="constant", value=0)
1653
+
1654
+ def forward(self, audio_features, text_features, audio_pred, text_pred):
1655
+ # Проекция признаков
1656
+ audio = self.audio_proj(audio_features.float()) # [B, T, D]
1657
+ text = self.text_proj(text_features.float()) # [B, T, D]
1658
+
1659
+ # Адаптивный пулинг к минимальной длине
1660
+ min_len = min(audio.size(1), text.size(1))
1661
+ audio = self._adaptive_pool(audio, min_len)
1662
+ text = self._adaptive_pool(text, min_len)
1663
+
1664
+ # Объединение модальностей
1665
+ fused = torch.cat([audio, text], dim=-1) # [B, T, 2*D]
1666
+ fused = self.fusion_proj(fused) # [B, T, D]
1667
+
1668
+ # Обработка объединенных признаков через Mamba
1669
+ for mamba_block in self.mamba_blocks:
1670
+ out, _ = mamba_block[0](fused, None)
1671
+ out = mamba_block[1](out)
1672
+ fused = fused + out # Residual connection
1673
+
1674
+ # Агрегация признаков и классификация
1675
+ pooled = self._pool_features(fused)
1676
+ out = self.classifier(pooled)
1677
+
1678
+ w_out = self.pred_fusion([audio_pred, text_pred, out])
1679
+ return w_out
1680
+
1681
+ def _adaptive_pool(self, x, target_len):
1682
+ """Адаптивное изменение временной длины"""
1683
+ if x.size(1) == target_len:
1684
+ return x
1685
+ return F.interpolate(
1686
+ x.permute(0, 2, 1),
1687
+ size=target_len,
1688
+ mode='linear',
1689
+ align_corners=False
1690
+ ).permute(0, 2, 1)
1691
+
1692
+ def _init_weights(self):
1693
+ for m in self.modules():
1694
+ if isinstance(m, nn.Linear):
1695
+ nn.init.xavier_uniform_(m.weight)
1696
+ if m.bias is not None:
1697
+ nn.init.constant_(m.bias, 0)
1698
+ elif isinstance(m, nn.LayerNorm):
1699
+ nn.init.constant_(m.weight, 1)
1700
+ nn.init.constant_(m.bias, 0)
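Note: every fusion model in this file aligns the audio and text sequences to a common length with linear interpolation before cross-attention; a standalone sketch of that alignment step (illustrative shapes only).

# Sketch: temporal alignment as in adaptive_temporal_pool / _adaptive_pool
import torch
import torch.nn.functional as F

audio = torch.randn(2, 96, 512)   # audio frames
text = torch.randn(2, 44, 512)    # text tokens
target_len = min(audio.size(1), text.size(1))
audio = F.interpolate(audio.permute(0, 2, 1), size=target_len,
                      mode='linear', align_corners=False).permute(0, 2, 1)
# audio is now [2, 44, 512], frame-aligned with the text sequence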
requirements.txt ADDED
Binary file (2.25 kB). View file
 
run_generation.py ADDED
@@ -0,0 +1,32 @@
1
+ import sys
2
+ import os
3
+ from generate_synthetic_dataset import generate_from_emotion_csv
4
+
5
+ if __name__ == "__main__":
6
+ if len(sys.argv) < 2:
7
+ print("❌ Использование: python run_generation.py path/to/file.csv [num_processes] [device]")
8
+ sys.exit(1)
9
+
10
+ csv_path = sys.argv[1]
11
+ num_processes = int(sys.argv[2]) if len(sys.argv) > 2 else int(os.environ.get("NUM_DIA_PROCESSES", 1))
12
+ device = sys.argv[3] if len(sys.argv) > 3 else "cuda"
13
+
14
+ filename = os.path.basename(csv_path)
15
+ try:
16
+ emotion = filename.split("_")[2]
17
+ except IndexError:
18
+ emotion = "unknown"
19
+
20
+ print(f"🧪 CSV: {csv_path}")
21
+ print(f"💻 Устройство: {device}")
22
+ print(f"🔧 Процессов: {num_processes}")
23
+ print(f"🎭 Эмоция: {emotion}")
24
+
25
+ generate_from_emotion_csv(
26
+ csv_path=csv_path,
27
+ emotion=emotion,
28
+ output_dir="tts_synthetic_final",
29
+ device=device,
30
+ max_samples=None,
31
+ num_processes=num_processes
32
+ )
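Note: the same generation can be started programmatically instead of through the CLI; a hedged sketch (the CSV name below is hypothetical, and the emotion is normally taken from position 2 of the underscore-split file name).

# Sketch: calling the generator directly
from generate_synthetic_dataset import generate_from_emotion_csv

generate_from_emotion_csv(
    csv_path="Model_texts_happy.csv",   # hypothetical file name
    emotion="happy",
    output_dir="tts_synthetic_final",
    device="cuda",
    max_samples=None,
    num_processes=1,
)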
search_params.toml ADDED
@@ -0,0 +1,22 @@
1
+ [grid]
2
+ scheduler_type = ["huggingface_cosine_with_restarts", "huggingface_linear", "cosine", "onecycle"]
3
+ smoothing_probability = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
4
+ # mamba_ker_size = [3, 4]
5
+ # mamba_layer_number = [3, 4, 5]
6
+ # hidden_dim_gated = [128]
7
+ # num_transformer_heads = [2, 4, 8]
8
+ # tr_layer_number = [1, 2]
9
+ # out_features = [128, 256]
10
+ # num_graph_heads = [2, 4]
11
+ # dropout = [0.0, 0.1, 0.2]
12
+ # positional_encoding = [true, false]
13
+
14
+ [defaults]
15
+ hidden_dim = 128
16
+ # # hidden_dim_gated = 128
17
+ # num_transformer_heads = 2
18
+ # tr_layer_number = 1
19
+ # out_features = 128
20
+ # # num_graph_heads = 2
21
+ # dropout = 0
22
+ # positional_encoding = false
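Note: a minimal sketch of how the [grid] section could be expanded into individual run configurations (assumes Python 3.11+ tomllib; the actual search driver in this repo may differ).

# Sketch: expanding [grid] into parameter combinations on top of [defaults]
import itertools
import tomllib

with open("search_params.toml", "rb") as f:
    params = tomllib.load(f)

grid = params.get("grid", {})
defaults = params.get("defaults", {})
keys = list(grid)
for values in itertools.product(*(grid[k] for k in keys)):
    run_cfg = {**defaults, **dict(zip(keys, values))}
    print(run_cfg)   # e.g. {'hidden_dim': 128, 'scheduler_type': 'cosine', 'smoothing_probability': 0.3}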
synthetic_utils/__pycache__/dia_tts_wrapper.cpython-310.pyc ADDED
Binary file (2.78 kB). View file
 
synthetic_utils/dia_tts_wrapper.py ADDED
@@ -0,0 +1,77 @@
1
+ import os
2
+ import time
3
+ import logging
4
+ import torch
5
+ import soundfile as sf
6
+ import numpy as np
7
+ from dia.model import Dia
8
+
9
+
10
+ class DiaTTSWrapper:
11
+ def __init__(self, model_name="nari-labs/Dia-1.6B", device="cuda", dtype="float16"):
12
+ self.device = device
13
+ self.sr = 44100
14
+ logging.info(f"[DiaTTS] Загрузка модели {model_name} на {device} (dtype={dtype})")
15
+ self.model = Dia.from_pretrained(
16
+ model_name,
17
+ device=device,
18
+ compute_dtype=dtype
19
+ )
20
+
21
+ def generate_audio_from_text(self, text: str, paralinguistic: str = "", max_duration: float = None) -> torch.Tensor:
22
+ try:
23
+ if paralinguistic:
24
+ clean = paralinguistic.strip("()").lower()
25
+ text = f"{text} ({clean})"
26
+ audio_np = self.model.generate(
27
+ text,
28
+ use_torch_compile=False,
29
+ verbose=False
30
+ )
31
+ wf = torch.from_numpy(audio_np).float().unsqueeze(0)
32
+ if max_duration:
33
+ max_samples = int(self.sr * max_duration)
34
+ wf = wf[:, :max_samples]
35
+ return wf
36
+ except Exception as e:
37
+ logging.error(f"[DiaTTS] Ошибка генерации аудио: {e}")
38
+ return torch.zeros(1, self.sr)
39
+
40
+ def generate_and_save_audio(
41
+ self,
42
+ text: str,
43
+ paralinguistic: str = "",
44
+ out_dir="tts_outputs",
45
+ filename_prefix="tts",
46
+ max_duration: float = None,
47
+ use_timestamp=True,
48
+ skip_if_exists=True,
49
+ max_trim_duration: float = None
50
+ ) -> torch.Tensor:
51
+ os.makedirs(out_dir, exist_ok=True)
52
+ if use_timestamp:
53
+ timestr = time.strftime("%Y%m%d_%H%M%S")
54
+ filename = f"{filename_prefix}_{timestr}.wav"
55
+ else:
56
+ filename = f"{filename_prefix}.wav"
57
+ out_path = os.path.join(out_dir, filename)
58
+
59
+ if skip_if_exists and os.path.exists(out_path):
60
+ logging.info(f"[DiaTTS] ⏭️ Пропущено — уже существует: {out_path}")
61
+ return None
62
+
63
+ wf = self.generate_audio_from_text(text, paralinguistic, max_duration)
64
+ np_wf = wf.squeeze().cpu().numpy()
65
+
66
+ if max_trim_duration is not None:
67
+ max_len = int(self.sr * max_trim_duration)
68
+ if len(np_wf) > max_len:
69
+ logging.info(f"[DiaTTS] ✂️ Обрезка аудио до {max_trim_duration} сек.")
70
+ np_wf = np_wf[:max_len]
71
+
72
+ sf.write(out_path, np_wf, self.sr)
73
+ logging.info(f"[DiaTTS] 💾 Сохранено аудио: {out_path}")
74
+ return wf
75
+
76
+ def get_sample_rate(self):
77
+ return self.sr
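Note: usage sketch for the wrapper (values are illustrative; the paralinguistic marker is appended to the prompt in parentheses, e.g. "(laughs)").

# Sketch: generating one synthetic utterance with DiaTTSWrapper
from synthetic_utils.dia_tts_wrapper import DiaTTSWrapper

tts = DiaTTSWrapper(device="cuda", dtype="float16")
wf = tts.generate_and_save_audio(
    text="I can't believe we actually won!",
    paralinguistic="(laughs)",
    out_dir="tts_outputs",
    filename_prefix="happy_demo",
    max_trim_duration=10.0,   # clip the saved file to 10 s
)
# wf is a (1, num_samples) tensor at 44.1 kHz, or None when skip_if_exists hits an existing file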
synthetic_utils/parler_tts_wrapper.py ADDED
@@ -0,0 +1,60 @@
1
+ # parler_tts_wrapper.py
2
+
3
+ import torch
4
+ import soundfile as sf
5
+ import time
6
+ import os
7
+ import logging
8
+ from parler_tts import ParlerTTSForConditionalGeneration
9
+ from transformers import AutoTokenizer
10
+
11
+ class ParlerTTS:
12
+ def __init__(self, model_name="parler-tts/parler-tts-mini-v1", device="cuda"):
13
+ self.device = device
14
+ logging.info(f"[ParlerTTS] Загрузка модели {model_name} на {device} ...")
15
+
16
+ self.model = ParlerTTSForConditionalGeneration.from_pretrained(model_name).to(device)
17
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
18
+ self.sr = self.model.config.sampling_rate
19
+
20
+ def generate_audio_from_text(self, text: str, description: str) -> torch.Tensor:
21
+ """
22
+ Generates audio (without saving to disk).
23
+ Returns a PyTorch tensor of shape (1, num_samples).
24
+ """
25
+ input_ids = self.tokenizer(description, return_tensors="pt").input_ids.to(self.device)
26
+ prompt_input_ids = self.tokenizer(text, return_tensors="pt").input_ids.to(self.device)
27
+
28
+ with torch.no_grad():
29
+ generation = self.model.generate(
30
+ input_ids=input_ids,
31
+ prompt_input_ids=prompt_input_ids
32
+ )
33
+
34
+ audio_arr = generation.cpu().numpy().squeeze() # (samples,)
35
+ wf = torch.from_numpy(audio_arr).unsqueeze(0) # -> (1, samples)
36
+ return wf
37
+
38
+ def generate_and_save_audio(self, text: str, description: str, out_dir="tts_outputs", filename_prefix="tts") -> torch.Tensor:
39
+ """
40
+ Generates audio AND saves the result to a WAV file (for debugging/inspection).
41
+ Returns a PyTorch tensor (1, num_samples).
42
+ """
43
+ os.makedirs(out_dir, exist_ok=True)
44
+
45
+ wf = self.generate_audio_from_text(text, description)
46
+ np_wf = wf.squeeze().cpu().numpy()
47
+
48
+ # Build the output file name
49
+ timestr = time.strftime("%Y%m%d_%H%M%S")
50
+ filename = f"{filename_prefix}_{timestr}.wav"
51
+ out_path = os.path.join(out_dir, filename)
52
+
53
+ # Save to disk
54
+ sf.write(out_path, np_wf, self.sr)
55
+ logging.info(f"[ParlerTTS] Сохранено аудио: {out_path}")
56
+
57
+ return wf
58
+
59
+ def get_sample_rate(self):
60
+ return self.sr
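Note: usage sketch (the description string controls the Parler-TTS speaking style; all values are illustrative).

# Sketch: generating audio with the ParlerTTS wrapper
from synthetic_utils.parler_tts_wrapper import ParlerTTS

tts = ParlerTTS(device="cuda")
wf = tts.generate_and_save_audio(
    text="Why would you do that? You had no right!",
    description="An angry female speaker, fast pace, slightly raised pitch.",
    out_dir="tts_outputs",
    filename_prefix="anger_demo",
)
print(wf.shape, tts.get_sample_rate())   # (1, num_samples) and the model sampling rate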
synthetic_utils/text_generation.py ADDED
@@ -0,0 +1,91 @@
1
+ import logging
2
+ import torch
3
+ import random
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
5
+
6
+ class TextGenerator:
7
+ def __init__(
8
+ self,
9
+ model_name="gpt2",
10
+ device="cuda",
11
+ max_new_tokens=50,
12
+ temperature=1.0,
13
+ top_p=0.95,
14
+ seed=None
15
+ ):
16
+ self.model_name = model_name
17
+ self.device = device
18
+ self.max_new_tokens = max_new_tokens
19
+ self.temperature = temperature
20
+ self.top_p = top_p
21
+ self.seed = seed
22
+
23
+ logging.info(f"[TextGenerator] Загрузка модели {model_name} на {device} ...")
24
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
25
+ self.model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
26
+
27
+ if seed is not None:
28
+ set_seed(seed)
29
+ logging.info(f"[TextGenerator] Сид генерации установлен через transformers.set_seed({seed})")
30
+ else:
31
+ logging.info("[TextGenerator] Сид генерации не установлен (seed=None)")
32
+
33
+ # --- Few-shot examples ---
34
+ self.fewshot_examples = [
35
+ ("happy", "We finally made it!", "We finally made it! I’ve never felt so alive and proud of what we accomplished."),
36
+ ("sad", "He didn't come back.", "He didn't come back. I waited all night, hoping to see him again."),
37
+ ("anger", "Why would you do that?", "Why would you do that? You had no right to interfere!"),
38
+ ("fear", "Did you hear that?", "Did you hear that? Something’s moving outside the window..."),
39
+ ("surprise", "Oh wow, really?", "Oh wow, really? I didn’t see that coming at all!"),
40
+ ("disgust", "That smell is awful.", "That smell is awful. I feel like I’m going to be sick."),
41
+ ("neutral", "Let's meet at noon.", "Let's meet at noon. We’ll have plenty of time to talk then.")
42
+ ]
43
+
44
+ def build_prompt(self, emotion: str, partial_text: str) -> str:
45
+ few_shot = random.sample(self.fewshot_examples, 2)
46
+ examples_str = ""
47
+ for emo, text, cont in few_shot:
48
+ examples_str += (
49
+ f"Example:\n"
50
+ f"Emotion: {emo}\n"
51
+ f"Text: {text}\n"
52
+ f"Continuation: {cont}\n\n"
53
+ )
54
+
55
+ prompt = (
56
+ "You are a helpful assistant that generates emotionally-aligned sentence continuations.\n"
57
+ "You must include the original sentence in the output, and then continue it in a fluent and emotionally appropriate way.\n\n"
58
+ f"{examples_str}"
59
+ f"Now try:\n"
60
+ f"Emotion: {emotion}\n"
61
+ f"Text: {partial_text}\n"
62
+ f"Continuation:"
63
+ )
64
+ return prompt
65
+
66
+ def generate_text(self, emotion: str, partial_text: str = "") -> str:
67
+ prompt = self.build_prompt(emotion, partial_text)
68
+ logging.debug(f"[TextGenerator] prompt:\n{prompt}")
69
+
70
+ inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
71
+
72
+ output_ids = self.model.generate(
73
+ **inputs,
74
+ max_new_tokens=self.max_new_tokens,
75
+ do_sample=True,
76
+ top_p=self.top_p,
77
+ temperature=self.temperature,
78
+ pad_token_id=self.tokenizer.eos_token_id
79
+ )
80
+
81
+ full_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
82
+ logging.debug(f"[TextGenerator] decoded:\n{full_text}")
83
+
84
+ # Take whatever comes after the last "Continuation:"
85
+ if "Continuation:" in full_text:
86
+ result = full_text.split("Continuation:")[-1].strip()
87
+ else:
88
+ result = full_text.strip()
89
+
90
+ result = result.split("\n")[0].strip()
91
+ return result
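Note: usage sketch for the few-shot continuation generator (gpt2 is only the default; any causal LM name accepted by transformers should work).

# Sketch: continuing a sentence in a target emotion
from synthetic_utils.text_generation import TextGenerator

gen = TextGenerator(model_name="gpt2", device="cuda", max_new_tokens=40, seed=42)
continuation = gen.generate_text(emotion="sad", partial_text="He didn't call today.")
print(continuation)   # single line, intended to start from the original sentence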
test.py ADDED
@@ -0,0 +1,28 @@
1
+ # tts_test_run.py
2
+
3
+ import os
4
+ from utils.config_loader import ConfigLoader
5
+ from synthetic_utils.dia_tts_wrapper import DiaTTSWrapper
6
+ from generate_synthetic_dataset import PARALINGUISTIC_MARKERS
7
+
8
+ # Load the config
9
+ config = ConfigLoader("config.toml")
10
+
11
+ # Set up TTS
12
+ tts = DiaTTSWrapper(device=config.whisper_device)
13
+
14
+ # Example text and emotion
15
+ text = "I'm just testing how this emotional voice sounds."
16
+ emotion = "neutral" # options: neutral, happy, sad, anger, fear, surprise, disgust
17
+ marker = PARALINGUISTIC_MARKERS.get(emotion, "")
18
+
19
+ # Generate and save
20
+ tts.generate_and_save_audio(
21
+ text=text,
22
+ paralinguistic=marker,
23
+ out_dir="tts_test_outputs",
24
+ filename_prefix=f"test_{emotion}",
25
+ max_duration=5.0
26
+ )
27
+
28
+ print(f"✅ Аудио для эмоции '{emotion}' сохранено.")
training/train_utils.py ADDED
@@ -0,0 +1,585 @@
1
+ # coding: utf-8
2
+ # train_utils.py
3
+
4
+ import torch
5
+ import logging
6
+ import random
7
+ import numpy as np
8
+ import csv
9
+ import pandas as pd
10
+ from tqdm import tqdm
11
+ from typing import Type
12
+ import os
13
+ import datetime
14
+
15
+ from torch.utils.data import DataLoader, ConcatDataset, WeightedRandomSampler
16
+ from torch.nn.utils.rnn import pad_sequence
17
+
18
+ from utils.losses import WeightedCrossEntropyLoss
19
+ from utils.measures import uar, war, mf1, wf1
20
+ from models.models import (
21
+ BiFormer, BiGraphFormer, BiGatedGraphFormer,
22
+ PredictionsFusion, BiFormerWithProb, BiGatedFormer,
23
+ BiMamba, BiMambaWithProb,BiGraphFormerWithProb, BiGatedGraphFormerWithProb
24
+ )
25
+ from utils.schedulers import SmartScheduler
26
+ from data_loading.dataset_multimodal import DatasetMultiModalWithPretrainedExtractors
27
+ from sklearn.utils.class_weight import compute_class_weight
28
+ from lion_pytorch import Lion
29
+
30
+
31
+ def get_smoothed_labels(audio_paths, original_labels, smooth_labels_df, smooth_mask, emotion_columns, device):
32
+ """
33
+ audio_paths: list of paths to audio files
34
+ smooth_mask: boolean tensor marking which samples to smooth
35
+ Returns a tensor of smoothed labels only for the marked samples
36
+ """
37
+
38
+ # Indices selected for smoothing
39
+ smooth_indices = torch.where(smooth_mask)[0]
40
+
41
+ # Result tensor (same size as the original labels)
42
+ smoothed_labels = torch.zeros_like(original_labels)
43
+
44
+ # print(smooth_labels_df, audio_paths)
45
+
46
+ for idx in smooth_indices:
47
+ audio_path = audio_paths[idx]
48
+ # Look up the smoothed label from the provided DataFrame
49
+ smoothed_label = smooth_labels_df.loc[
50
+ smooth_labels_df['video_name'] == audio_path[:-4],
51
+ emotion_columns
52
+ ].values[0]
53
+
54
+ smoothed_labels[idx] = torch.tensor(smoothed_label, device=device)
55
+
56
+ return smoothed_labels
57
+
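Note: the partial smoothing replaces the one-hot label of a random subset of the batch with a soft label looked up by file name; the masking itself is a plain torch.where, e.g.:

# Sketch: mixing one-hot and smoothed labels with a boolean mask
import torch

original = torch.tensor([[1., 0., 0.], [0., 1., 0.]])
smoothed = torch.tensor([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]])
mask = torch.tensor([True, False])                      # smooth only the first sample
labels = torch.where(mask.unsqueeze(1), smoothed, original)
# -> [[0.7, 0.2, 0.1], [0.0, 1.0, 0.0]]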
58
+ def custom_collate_fn(batch):
59
+ """Собирает список образцов в единый батч, отбрасывая None (невалидные)."""
60
+ batch = [x for x in batch if x is not None]
61
+ # print(batch[0].keys())
62
+ if not batch:
63
+ return None
64
+
65
+ audios = [b["audio"] for b in batch]
66
+ # audio_tensor = torch.stack(audios)
67
+ audio_tensor = pad_sequence(audios, batch_first=True)
68
+
69
+ labels = [b["label"] for b in batch]
70
+ label_tensor = torch.stack(labels)
71
+
72
+ texts = [b["text"] for b in batch]
73
+ text_tensor = torch.stack(texts)
74
+
75
+ audio_pred = [b["audio_pred"] for b in batch]
76
+ audio_pred = torch.stack(audio_pred)
77
+
78
+ text_pred = [b["text_pred"] for b in batch]
79
+ text_pred = torch.stack(text_pred)
80
+
81
+ return {
82
+ "audio_paths": [b["audio_path"] for b in batch], # new
83
+ "audio": audio_tensor,
84
+ "label": label_tensor,
85
+ "text": text_tensor,
86
+ "audio_pred": audio_pred,
87
+ "text_pred": text_pred,
88
+ }
89
+
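Note: audio clips vary in length, so the collate function pads each batch to its longest clip; a toy illustration of what pad_sequence produces.

# Sketch: padding variable-length audio features in a batch
import torch
from torch.nn.utils.rnn import pad_sequence

a = torch.randn(30, 1024)   # 30 frames
b = torch.randn(44, 1024)   # 44 frames
batch_audio = pad_sequence([a, b], batch_first=True)    # [2, 44, 1024], zero-padded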
90
+ def get_class_weights_from_loader(train_loader, num_classes):
91
+ """
92
+ Computes class weights from train_loader, robust to missing classes.
93
+ If a class is absent from the sample, it is assigned a weight of 0.0.
94
+
95
+ :param train_loader: DataLoader with one-hot labels
96
+ :param num_classes: total number of classes
97
+ :return: np.ndarray of weights with length num_classes
98
+ """
99
+ all_labels = []
100
+ for batch in train_loader:
101
+ if batch is None:
102
+ continue
103
+ all_labels.extend(batch["label"].argmax(dim=1).tolist())
104
+
105
+ if not all_labels:
106
+ raise ValueError("Нет ни одной метки в train_loader для вычисления весов классов.")
107
+
108
+ present_classes = np.unique(all_labels)
109
+
110
+ if len(present_classes) < num_classes:
111
+ missing = set(range(num_classes)) - set(present_classes)
112
+ logging.info(f"[!] Отсутствуют метки для классов: {sorted(missing)}")
113
+
114
+ # Compute weights only for the classes that are present
115
+ weights_partial = compute_class_weight(
116
+ class_weight="balanced",
117
+ classes=present_classes,
118
+ y=all_labels
119
+ )
120
+
121
+ # Assemble the full weight vector
122
+ full_weights = np.zeros(num_classes, dtype=np.float32)
123
+ for cls, w in zip(present_classes, weights_partial):
124
+ full_weights[cls] = w
125
+
126
+ return full_weights
127
+
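Note: a toy check of the balanced-weight behaviour, including a class that never occurs (it receives weight 0.0 instead of raising).

# Sketch: balanced class weights with class 2 absent
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

labels = [0, 0, 0, 1]                        # class 2 never appears
present = np.unique(labels)                  # [0, 1]
partial = compute_class_weight(class_weight="balanced", classes=present, y=labels)
full = np.zeros(3, dtype=np.float32)
full[present] = partial                      # -> [0.667, 2.0, 0.0]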
128
+ def make_dataset_and_loader(config, split: str, audio_feature_extractor: Type = None, text_feature_extractor: Type = None, whisper_model: Type = None, only_dataset: str = None):
129
+ """
130
+ Universal helper: concatenates the datasets, or returns a single one when only_dataset is set.
131
+ When combining train datasets, a WeightedRandomSampler is used for balancing.
132
+ """
133
+ datasets = []
134
+
135
+ if not hasattr(config, "datasets") or not config.datasets:
136
+ raise ValueError("⛔ В конфиге не указана секция [datasets].")
137
+
138
+ for dataset_name, dataset_cfg in config.datasets.items():
139
+ if only_dataset and dataset_name != only_dataset:
140
+ continue
141
+
142
+ csv_path = dataset_cfg["csv_path"].format(base_dir=dataset_cfg["base_dir"], split=split)
143
+ wav_dir = dataset_cfg["wav_dir"].format(base_dir=dataset_cfg["base_dir"], split=split)
144
+
145
+ logging.info(f"[{dataset_name.upper()}] Split={split}: CSV={csv_path}, WAV_DIR={wav_dir}")
146
+
147
+ dataset = DatasetMultiModalWithPretrainedExtractors(
148
+ csv_path = csv_path,
149
+ wav_dir = wav_dir,
150
+ emotion_columns = config.emotion_columns,
151
+ split = split,
152
+ config = config,
153
+ audio_feature_extractor = audio_feature_extractor,
154
+ text_feature_extractor = text_feature_extractor,
155
+ whisper_model = whisper_model,
156
+ dataset_name = dataset_name
157
+ )
158
+
159
+ datasets.append(dataset)
160
+
161
+ if not datasets:
162
+ raise ValueError(f"⚠️ Для split='{split}' не найдено ни одного подходящего датасета.")
163
+
164
+ if len(datasets) == 1:
165
+
166
+ full_dataset = datasets[0]
167
+ loader = DataLoader(
168
+ full_dataset,
169
+ batch_size=config.batch_size,
170
+ shuffle=(split == "train"),
171
+ num_workers=config.num_workers,
172
+ collate_fn=custom_collate_fn
173
+ )
174
+ else:
175
+ # Several datasets: build per-sample weights
176
+ lengths = [len(d) for d in datasets]
177
+ total = sum(lengths)
178
+
179
+ logging.info(f"[!] Объединяем {len(datasets)} датасетов: {lengths} (total={total})")
180
+
181
+ weights = []
182
+ for d_len in lengths:
183
+ w = 1.0 / d_len
184
+ weights += [w] * d_len
185
+ logging.info(f" ➜ Сэмплы из датасета с {d_len} примерами получают вес {w:.6f}")
186
+
187
+ full_dataset = ConcatDataset(datasets)
188
+
189
+ if split == "train":
190
+ sampler = WeightedRandomSampler(weights, num_samples=total, replacement=True)
191
+ loader = DataLoader(
192
+ full_dataset,
193
+ batch_size=config.batch_size,
194
+ sampler=sampler,
195
+ num_workers=config.num_workers,
196
+ collate_fn=custom_collate_fn
197
+ )
198
+ else:
199
+ loader = DataLoader(
200
+ full_dataset,
201
+ batch_size=config.batch_size,
202
+ shuffle=False,
203
+ num_workers=config.num_workers,
204
+ collate_fn=custom_collate_fn
205
+ )
206
+
207
+ return full_dataset, loader
208
+
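Note: when several training sets are concatenated, every sample is weighted by the inverse size of its source dataset, so a small dataset is drawn as often as a large one; a toy sketch with illustrative sizes.

# Sketch: inverse-size sample weights for two concatenated datasets
from torch.utils.data import WeightedRandomSampler

lengths = [1000, 100]                        # e.g. a large and a small training set
weights = []
for d_len in lengths:
    weights += [1.0 / d_len] * d_len         # samples from the small set are drawn 10x more often
sampler = WeightedRandomSampler(weights, num_samples=sum(lengths), replacement=True)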
209
+ def run_eval(model, loader, criterion, model_name, device="cuda"):
210
+ """
211
+ Evaluates the model on the loader. Returns (loss, uar, war, mf1, wf1).
212
+ """
213
+ model.eval()
214
+ total_loss = 0.0
215
+ total_preds = []
216
+ total_targets = []
217
+ total = 0
218
+
219
+ with torch.no_grad():
220
+ for batch in tqdm(loader):
221
+ if batch is None:
222
+ continue
223
+
224
+ audio = batch["audio"].to(device)
225
+ labels = batch["label"].to(device)
226
+ texts = batch["text"]
227
+ audio_pred = batch["audio_pred"].to(device)
228
+ text_pred = batch["text_pred"].to(device)
229
+
230
+ if "fusion" in model_name:
231
+ logits = model((audio_pred, text_pred))
232
+ elif "withprob" in model_name:
233
+ logits = model(audio, texts, audio_pred, text_pred)
234
+ else:
235
+ logits = model(audio, texts)
236
+ target = labels.argmax(dim=1)
237
+
238
+ loss = criterion(logits, target)
239
+ bs = audio.shape[0]
240
+ total_loss += loss.item() * bs
241
+ total += bs
242
+
243
+ preds = logits.argmax(dim=1)
244
+ total_preds.extend(preds.cpu().numpy().tolist())
245
+ total_targets.extend(target.cpu().numpy().tolist())
246
+
247
+ avg_loss = total_loss / total
248
+
249
+ uar_m = uar(total_targets, total_preds)
250
+ war_m = war(total_targets, total_preds)
251
+ mf1_m = mf1(total_targets, total_preds)
252
+ wf1_m = wf1(total_targets, total_preds)
253
+
254
+ return avg_loss, uar_m, war_m, mf1_m, wf1_m
255
+
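Note: run_eval dispatches on substrings of the lower-cased model name: names containing "fusion" receive only the per-modality predictions, names containing "withprob" receive features plus predictions, and everything else receives features only; condensed:

# Sketch: how the forward call is chosen from the lower-cased model name
def call_model(model, model_name, audio, texts, audio_pred, text_pred):
    if "fusion" in model_name:         # e.g. "predictionsfusion"
        return model((audio_pred, text_pred))
    elif "withprob" in model_name:     # e.g. "bigraphformerwithprob"
        return model(audio, texts, audio_pred, text_pred)
    else:                              # e.g. "biformer", "bimamba"
        return model(audio, texts)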
256
+ def train_once(config, train_loader, dev_loaders, test_loaders, metrics_csv_path=None):
257
+ """
258
+ Training logic (train/dev/test).
259
+ Returns the best dev metric and a dictionary of metrics.
260
+ """
261
+
262
+ logging.info("== Запуск тренировки (train/dev/test) ==")
263
+
264
+ checkpoint_dir = None
265
+ if config.save_best_model:
266
+ timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
267
+ checkpoint_dir = os.path.join("checkpoints", f"{config.model_name}_{timestamp}")
268
+ os.makedirs(checkpoint_dir, exist_ok=True)
269
+
270
+ csv_writer = None
271
+ csv_file = None
272
+
273
+ if config.path_to_df_ls:
274
+ df_ls = pd.read_csv(config.path_to_df_ls)
275
+
276
+ if metrics_csv_path:
277
+ csv_file = open(metrics_csv_path, mode="w", newline="", encoding="utf-8")
278
+ csv_writer = csv.writer(csv_file)
279
+ csv_writer.writerow(["split", "epoch", "dataset", "loss", "uar", "war", "mf1", "wf1", "mean"])
280
+
281
+
282
+ # Seed
283
+ if config.random_seed > 0:
284
+ random.seed(config.random_seed)
285
+ torch.manual_seed(config.random_seed)
286
+ torch.cuda.manual_seed_all(config.random_seed)
287
+ torch.backends.cudnn.deterministic = True
288
+ torch.backends.cudnn.benchmark = False
289
+ os.environ['PYTHONHASHSEED'] = str(config.random_seed)
290
+ logging.info(f"== Фиксируем random seed: {config.random_seed}")
291
+ else:
292
+ logging.info("== Random seed не фиксирован (0).")
293
+
294
+ device = "cuda" if torch.cuda.is_available() else "cpu"
295
+
296
+ # Extractors
297
+ # audio_extractor = AudioEmbeddingExtractor(config)
298
+ # text_extractor = TextEmbeddingExtractor(config)
299
+
300
+ # Parameters
301
+ hidden_dim = config.hidden_dim
302
+ num_classes = len(config.emotion_columns)
303
+ num_transformer_heads = config.num_transformer_heads
304
+ num_graph_heads = config.num_graph_heads
305
+ hidden_dim_gated = config.hidden_dim_gated
306
+ mamba_d_state = config.mamba_d_state
307
+ mamba_ker_size = config.mamba_ker_size
308
+ mamba_layer_number = config.mamba_layer_number
309
+ mode = config.mode
310
+ weight_decay = config.weight_decay
311
+ momentum = config.momentum
312
+ positional_encoding = config.positional_encoding
313
+ dropout = config.dropout
314
+ out_features = config.out_features
315
+ lr = config.lr
316
+ num_epochs = config.num_epochs
317
+ tr_layer_number = config.tr_layer_number
318
+ max_patience = config.max_patience
319
+ scheduler_type = config.scheduler_type
320
+
321
+ dict_models = {
322
+ 'BiFormer': BiFormer, # inputs: audio, texts
323
+ 'BiGraphFormer': BiGraphFormer, # inputs: audio, texts
324
+ 'BiGatedGraphFormer': BiGatedGraphFormer, # inputs: audio, texts
325
+ "BiGatedFormer": BiGatedFormer, # inputs: audio, texts
326
+ "BiMamba": BiMamba, # inputs: audio, texts
327
+ "PredictionsFusion": PredictionsFusion, # inputs: audio_pred, text_pred
328
+ "BiFormerWithProb": BiFormerWithProb, # inputs: audio, texts, audio_pred, text_pred
329
+ "BiMambaWithProb": BiMambaWithProb, # inputs: audio, texts, audio_pred, text_pred
330
+ "BiGraphFormerWithProb": BiGraphFormerWithProb, # inputs: audio, texts, audio_pred, text_pred
331
+ "BiGatedGraphFormerWithProb": BiGatedGraphFormerWithProb,
332
+ }
333
+
334
+ model_cls = dict_models[config.model_name]
335
+ model_name = config.model_name.lower()
336
+
337
+ if model_name == 'predictionsfusion':
338
+ model = model_cls().to(device)
339
+
340
+ elif 'mamba' in model_name:
341
+ # Mamba-family specific parameters
342
+ model = model_cls(
343
+ audio_dim = config.audio_embedding_dim,
344
+ text_dim = config.text_embedding_dim,
345
+ hidden_dim = hidden_dim,
346
+ mamba_d_state = mamba_d_state,
347
+ mamba_ker_size = mamba_ker_size,
348
+ mamba_layer_number = mamba_layer_number,
349
+ seg_len = config.max_tokens,
350
+ mode = mode,
351
+ dropout = dropout,
352
+ positional_encoding = positional_encoding,
353
+ out_features = out_features,
354
+ device = device,
355
+ num_classes = num_classes
356
+ ).to(device)
357
+
358
+ else:
359
+ # Standard (non-Mamba) models
360
+ model = model_cls(
361
+ audio_dim = config.audio_embedding_dim,
362
+ text_dim = config.text_embedding_dim,
363
+ hidden_dim = hidden_dim,
364
+ hidden_dim_gated = hidden_dim_gated,
365
+ num_transformer_heads = num_transformer_heads,
366
+ num_graph_heads = num_graph_heads,
367
+ seg_len = config.max_tokens,
368
+ mode = mode,
369
+ dropout = dropout,
370
+ positional_encoding = positional_encoding,
371
+ out_features = out_features,
372
+ tr_layer_number = tr_layer_number,
373
+ device = device,
374
+ num_classes = num_classes
375
+ ).to(device)
376
+
377
+ # Optimizer and loss
378
+ if config.optimizer == "adam":
379
+ optimizer = torch.optim.Adam(
380
+ model.parameters(), lr=lr, weight_decay=weight_decay
381
+ )
382
+ elif config.optimizer == "adamw":
383
+ optimizer = torch.optim.AdamW(
384
+ model.parameters(), lr=lr, weight_decay=weight_decay
385
+ )
386
+ elif config.optimizer == "lion":
387
+ optimizer = Lion(
388
+ model.parameters(), lr=lr, weight_decay=weight_decay
389
+ )
390
+ elif config.optimizer == "sgd":
391
+ optimizer = torch.optim.SGD(
392
+ model.parameters(), lr=lr, momentum=momentum
393
+ )
394
+ elif config.optimizer == "rmsprop":
395
+ optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)
396
+ else:
397
+ raise ValueError(f"⛔ Unknown optimizer: {config.optimizer}")
398
+
399
+ logging.info(f"Using optimizer: {config.optimizer}, learning rate: {lr}")
400
+
401
+ class_weights = get_class_weights_from_loader(train_loader, num_classes)
402
+ criterion = WeightedCrossEntropyLoss(class_weights)
403
+
404
+ logging.info("Class weights: " + ", ".join(f"{name}={weight:.4f}" for name, weight in zip(config.emotion_columns, class_weights)))
405
+
406
+ # LR Scheduler
407
+ steps_per_epoch = sum(1 for batch in train_loader if batch is not None)
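+ # note: counting batches here iterates the loader once, i.e. an extra full pass over the training data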
408
+ scheduler = SmartScheduler(
409
+ scheduler_type=scheduler_type,
410
+ optimizer=optimizer,
411
+ config=config,
412
+ steps_per_epoch=steps_per_epoch
413
+ )
414
+
415
+ # Early stopping on the dev metric
416
+ best_dev_mean = float("-inf")
417
+ best_dev_metrics = {}
418
+ patience_counter = 0
419
+
420
+ for epoch in range(num_epochs):
421
+ logging.info(f"\n=== Epoch {epoch} ===")
422
+ model.train()
423
+
424
+ total_loss = 0.0
425
+ total_samples = 0
426
+ total_preds = []
427
+ total_targets = []
428
+
429
+ for batch in tqdm(train_loader):
430
+ if batch is None:
431
+ continue
432
+
433
+ audio_paths = batch["audio_paths"] # needed to look up smoothed labels in df_ls
434
+ audio = batch["audio"].to(device)
435
+
436
+ # Label handling with partial smoothing
437
+ if config.smoothing_probability == 0:
438
+ labels = batch["label"].to(device)
439
+ else:
440
+ # Original one-hot labels
441
+ original_labels = batch["label"].to(device)
442
+
443
+ # Build a smoothing mask (randomly selected samples)
444
+ batch_size = original_labels.size(0)
445
+ smooth_mask = torch.rand(batch_size, device=device) < config.smoothing_probability
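+ # True entries mark samples whose one-hot label will be replaced by a soft distribution looked up in df_ls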
446
+
447
+ # Smoothed labels for the selected samples
448
+ smoothed_labels = get_smoothed_labels(audio_paths, original_labels, df_ls, smooth_mask, config.emotion_columns, device)
449
+
450
+ # Combine smoothed and original labels
451
+ labels = torch.where(
452
+ smooth_mask.unsqueeze(1), # add a dimension for broadcasting
453
+ smoothed_labels.to(device),
454
+ original_labels
455
+ )
456
+ # print(labels)
457
+ texts = batch["text"]
458
+ audio_pred = batch["audio_pred"].to(device)
459
+ text_pred = batch["text_pred"].to(device)
460
+
461
+ if "fusion" in model_name:
462
+ logits = model((audio_pred, text_pred))
463
+ elif "withprob" in model_name:
464
+ logits = model(audio, texts, audio_pred, text_pred)
465
+ else:
466
+ logits = model(audio, texts)
467
+
468
+ target = labels.argmax(dim=1)
469
+ loss = criterion(logits, target)
470
+
471
+ optimizer.zero_grad()
472
+ loss.backward()
473
+ optimizer.step()
474
+
475
+ # Per-batch step for OneCycle or Hugging Face-style schedulers
476
+ scheduler.step(batch_level=True)
477
+
478
+ bs = audio.shape[0]
479
+ total_loss += loss.item() * bs
480
+
481
+ preds = logits.argmax(dim=1)
482
+ total_preds.extend(preds.cpu().numpy().tolist())
483
+ total_targets.extend(target.cpu().numpy().tolist())
484
+ total_samples += bs
485
+
486
+ train_loss = total_loss / total_samples
487
+ uar_m = uar(total_targets, total_preds)
488
+ war_m = war(total_targets, total_preds)
489
+ mf1_m = mf1(total_targets, total_preds)
490
+ wf1_m = wf1(total_targets, total_preds)
491
+ mean_train = np.mean([uar_m, war_m, mf1_m, wf1_m])
492
+
493
+ logging.info(
494
+ f"[TRAIN] Loss={train_loss:.4f}, UAR={uar_m:.4f}, WAR={war_m:.4f}, "
495
+ f"MF1={mf1_m:.4f}, WF1={wf1_m:.4f}, MEAN={mean_train:.4f}"
496
+ )
497
+
498
+ # --- DEV ---
499
+ dev_means = []
500
+ dev_metrics_by_dataset = []
501
+
502
+ for name, loader in dev_loaders:
503
+ d_loss, d_uar, d_war, d_mf1, d_wf1 = run_eval(
504
+ model, loader, criterion, model_name, device
505
+ )
506
+ d_mean = np.mean([d_uar, d_war, d_mf1, d_wf1])
507
+ dev_means.append(d_mean)
508
+
509
+ if csv_writer:
510
+ csv_writer.writerow(["dev", epoch, name, d_loss, d_uar, d_war, d_mf1, d_wf1, d_mean])
511
+
512
+ logging.info(
513
+ f"[DEV:{name}] Loss={d_loss:.4f}, UAR={d_uar:.4f}, WAR={d_war:.4f}, "
514
+ f"MF1={d_mf1:.4f}, WF1={d_wf1:.4f}, MEAN={d_mean:.4f}"
515
+ )
516
+
517
+ dev_metrics_by_dataset.append({
518
+ "name": name,
519
+ "loss": d_loss,
520
+ "uar": d_uar,
521
+ "war": d_war,
522
+ "mf1": d_mf1,
523
+ "wf1": d_wf1,
524
+ "mean": d_mean,
525
+ })
526
+
527
+ mean_dev = np.mean(dev_means)
528
+ scheduler.step(mean_dev)
529
+
530
+ # --- TEST ---
531
+ test_metrics_by_dataset = []
532
+ for name, loader in test_loaders:
533
+ t_loss, t_uar, t_war, t_mf1, t_wf1 = run_eval(
534
+ model, loader, criterion, model_name, device
535
+ )
536
+ t_mean = np.mean([t_uar, t_war, t_mf1, t_wf1])
537
+ logging.info(
538
+ f"[TEST:{name}] Loss={t_loss:.4f}, UAR={t_uar:.4f}, WAR={t_war:.4f}, "
539
+ f"MF1={t_mf1:.4f}, WF1={t_wf1:.4f}, MEAN={t_mean:.4f}"
540
+ )
541
+
542
+ test_metrics_by_dataset.append({
543
+ "name": name,
544
+ "loss": t_loss,
545
+ "uar": t_uar,
546
+ "war": t_war,
547
+ "mf1": t_mf1,
548
+ "wf1": t_wf1,
549
+ "mean": t_mean,
550
+ })
551
+
552
+ if csv_writer:
553
+ csv_writer.writerow(["test", epoch, name, t_loss, t_uar, t_war, t_mf1, t_wf1, t_mean])
554
+
555
+
556
+ if mean_dev > best_dev_mean:
557
+ best_dev_mean = mean_dev
558
+ patience_counter = 0
559
+ best_dev_metrics = {
560
+ "mean": mean_dev,
561
+ "by_dataset": dev_metrics_by_dataset
562
+ }
563
+ best_test_metrics = {
564
+ "mean": np.mean([ds["mean"] for ds in test_metrics_by_dataset]),
565
+ "by_dataset": test_metrics_by_dataset
566
+ }
567
+
568
+ if config.save_best_model:
569
+ dev_str = f"{mean_dev:.4f}".replace(".", "_")
570
+ model_path = os.path.join(checkpoint_dir, f"best_model_dev_{dev_str}_epoch_{epoch}.pt")
571
+ torch.save(model.state_dict(), model_path)
572
+ logging.info(f"💾 Model saved at the best dev score (epoch {epoch}): {model_path}")
573
+
574
+ else:
575
+ patience_counter += 1
576
+ if patience_counter >= max_patience:
577
+ logging.info(f"Early stopping: {max_patience} epochs without improvement.")
578
+ break
579
+
580
+ logging.info("Training finished. All splits processed!")
581
+
582
+ if csv_file:
583
+ csv_file.close()
584
+
585
+ return best_dev_mean, best_dev_metrics, best_test_metrics
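
For orientation, a minimal usage sketch of train_once (assuming the current module still exposes make_dataset_and_loader as in the older version below; the project's actual entry point does the real wiring):

from utils.config_loader import ConfigLoader
from training.train_utils import make_dataset_and_loader, train_once

config = ConfigLoader("config.toml")

# One combined train loader, plus per-dataset dev/test loaders as (name, loader) tuples
_, train_loader = make_dataset_and_loader(config, split="train")
dev_loaders = [(name, make_dataset_and_loader(config, "dev", only_dataset=name)[1])
               for name in config.datasets]
test_loaders = [(name, make_dataset_and_loader(config, "test", only_dataset=name)[1])
                for name in config.datasets]

best_dev, dev_metrics, test_metrics = train_once(
    config, train_loader, dev_loaders, test_loaders, metrics_csv_path="metrics.csv"
)
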
training/train_utils_old.py ADDED
@@ -0,0 +1,379 @@
1
+ # coding: utf-8
2
+ # train_utils_old.py
3
+
4
+ import os
5
+ import torch
6
+ import logging
7
+ import random
8
+ import datetime
9
+ import numpy as np
10
+ from tqdm import tqdm
11
+ import csv
12
+
13
+ from torch.utils.data import DataLoader, ConcatDataset
14
+ from utils.losses import WeightedCrossEntropyLoss
15
+ from utils.measures import uar, war, mf1, wf1
16
+ from models.models import BiFormer, BiGraphFormer, BiGatedGraphFormer
17
+ from data_loading.dataset_multimodal import DatasetMultiModal
18
+ from data_loading.feature_extractor import AudioEmbeddingExtractor, TextEmbeddingExtractor
19
+ from sklearn.utils.class_weight import compute_class_weight
20
+
21
+ def custom_collate_fn(batch):
22
+ """Collates a list of samples into a single batch, dropping None (invalid) entries."""
23
+ batch = [x for x in batch if x is not None]
24
+ if not batch:
25
+ return None
26
+
27
+ audios = [b["audio"] for b in batch]
28
+ audio_tensor = torch.stack(audios)
29
+
30
+ labels = [b["label"] for b in batch]
31
+ label_tensor = torch.stack(labels)
32
+
33
+ texts = [b["text"] for b in batch]
34
+
35
+ return {
36
+ "audio": audio_tensor,
37
+ "label": label_tensor,
38
+ "text": texts
39
+ }
40
+
41
+ def get_class_weights_from_loader(train_loader, num_classes):
42
+ """
43
+ Computes class weights from train_loader, robust to missing classes.
44
+ If a class is absent from the data, it is assigned a weight of 0.0.
45
+
46
+ :param train_loader: DataLoader with one-hot labels
47
+ :param num_classes: Total number of classes
48
+ :return: np.ndarray of weights of length num_classes
49
+ """
50
+ all_labels = []
51
+ for batch in train_loader:
52
+ if batch is None:
53
+ continue
54
+ all_labels.extend(batch["label"].argmax(dim=1).tolist())
55
+
56
+ if not all_labels:
57
+ raise ValueError("No labels found in train_loader to compute class weights from.")
58
+
59
+ present_classes = np.unique(all_labels)
60
+
61
+ if len(present_classes) < num_classes:
62
+ missing = set(range(num_classes)) - set(present_classes)
63
+ logging.info(f"[!] No labels present for classes: {sorted(missing)}")
64
+
65
+ # Compute weights only for the classes that are present
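+ # sklearn's "balanced" heuristic: weight_c = n_samples / (n_present_classes * count_c)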
66
+ weights_partial = compute_class_weight(
67
+ class_weight="balanced",
68
+ classes=present_classes,
69
+ y=all_labels
70
+ )
71
+
72
+ # Assemble the full weight vector
73
+ full_weights = np.zeros(num_classes, dtype=np.float32)
74
+ for cls, w in zip(present_classes, weights_partial):
75
+ full_weights[cls] = w
76
+
77
+ return full_weights
78
+
79
+ def make_dataset_and_loader(config, split: str, only_dataset: str = None):
80
+ """
81
+ Universal helper: concatenates all datasets, or returns a single one when only_dataset is given.
82
+ """
83
+ datasets = []
84
+
85
+ if not hasattr(config, "datasets") or not config.datasets:
86
+ raise ValueError("⛔ The [datasets] section is missing from the config.")
87
+
88
+ for dataset_name, dataset_cfg in config.datasets.items():
89
+ if only_dataset and dataset_name != only_dataset:
90
+ continue
91
+
92
+ csv_path = dataset_cfg["csv_path"].format(base_dir=dataset_cfg["base_dir"], split=split)
93
+ wav_dir = dataset_cfg["wav_dir"].format(base_dir=dataset_cfg["base_dir"], split=split)
94
+
95
+ logging.info(f"[{dataset_name.upper()}] Split={split}: CSV={csv_path}, WAV_DIR={wav_dir}")
96
+
97
+ dataset = DatasetMultiModal(
98
+ csv_path = csv_path,
99
+ wav_dir = wav_dir,
100
+ emotion_columns = config.emotion_columns,
101
+ split = split,
102
+ sample_rate = config.sample_rate,
103
+ wav_length = config.wav_length,
104
+ whisper_model = config.whisper_model,
105
+ text_column = config.text_column,
106
+ use_whisper_for_nontrain_if_no_text = config.use_whisper_for_nontrain_if_no_text,
107
+ whisper_device = config.whisper_device,
108
+ subset_size = config.subset_size,
109
+ merge_probability = config.merge_probability
110
+ )
111
+
112
+ datasets.append(dataset)
113
+
114
+ if not datasets:
115
+ raise ValueError(f"⚠️ No matching dataset found for split='{split}'.")
116
+
117
+ # Concatenate only when there is more than one dataset
118
+ full_dataset = datasets[0] if len(datasets) == 1 else ConcatDataset(datasets)
119
+
120
+ loader = DataLoader(
121
+ full_dataset,
122
+ batch_size=config.batch_size,
123
+ shuffle=(split == "train"),
124
+ num_workers=config.num_workers,
125
+ collate_fn=custom_collate_fn
126
+ )
127
+
128
+ return full_dataset, loader
129
+
130
+ def run_eval(model, loader, audio_extractor, text_extractor, criterion, device="cuda"):
131
+ """
132
+ Evaluates the model on a loader. Returns (loss, uar, war, mf1, wf1).
133
+ """
134
+ model.eval()
135
+ total_loss = 0.0
136
+ total_preds = []
137
+ total_targets = []
138
+ total = 0
139
+
140
+ with torch.no_grad():
141
+ for batch in tqdm(loader):
142
+ if batch is None:
143
+ continue
144
+
145
+ audio = batch["audio"].to(device)
146
+ labels = batch["label"].to(device)
147
+ texts = batch["text"]
148
+
149
+ audio_emb = audio_extractor.extract(audio)
150
+ text_emb = text_extractor.extract(texts)
151
+
152
+ logits = model(audio_emb, text_emb)
153
+ target = labels.argmax(dim=1)
154
+
155
+ loss = criterion(logits, target)
156
+ bs = audio.shape[0]
157
+ total_loss += loss.item() * bs
158
+ total += bs
159
+
160
+ preds = logits.argmax(dim=1)
161
+ total_preds.extend(preds.cpu().numpy().tolist())
162
+ total_targets.extend(target.cpu().numpy().tolist())
163
+
164
+ avg_loss = total_loss / total
165
+
166
+ uar_m = uar(total_targets, total_preds)
167
+ war_m = war(total_targets, total_preds)
168
+ mf1_m = mf1(total_targets, total_preds)
169
+ wf1_m = wf1(total_targets, total_preds)
170
+
171
+ return avg_loss, uar_m, war_m, mf1_m, wf1_m
172
+
173
+ def train_once(config, train_loader, dev_loaders, test_loaders, metrics_csv_path=None):
174
+ """
175
+ Training loop logic (train/dev/test).
176
+ Returns the best dev metric and a dictionary of metrics.
177
+ """
178
+
179
+ logging.info("== Starting training (train/dev/test) ==")
180
+
181
+ csv_writer = None
182
+ csv_file = None
183
+
184
+ if metrics_csv_path:
185
+ csv_file = open(metrics_csv_path, mode="w", newline="", encoding="utf-8")
186
+ csv_writer = csv.writer(csv_file)
187
+ csv_writer.writerow(["split", "epoch", "dataset", "loss", "uar", "war", "mf1", "wf1", "mean"])
188
+
189
+
190
+ # Seed
191
+ if config.random_seed > 0:
192
+ random.seed(config.random_seed)
193
+ torch.manual_seed(config.random_seed)
194
+ logging.info(f"== Fixing random seed: {config.random_seed}")
195
+ else:
196
+ logging.info("== Random seed is not fixed (0).")
197
+
198
+ device = "cuda" if torch.cuda.is_available() else "cpu"
199
+
200
+ # Feature extractors
201
+ audio_extractor = AudioEmbeddingExtractor(config)
202
+ text_extractor = TextEmbeddingExtractor(config)
203
+
204
+ # Hyperparameters
205
+ hidden_dim = config.hidden_dim
206
+ num_classes = len(config.emotion_columns)
207
+ num_transformer_heads = config.num_transformer_heads
208
+ num_graph_heads = config.num_graph_heads
209
+ hidden_dim_gated = config.hidden_dim_gated
210
+ mode = config.mode
211
+ positional_encoding = config.positional_encoding
212
+ dropout = config.dropout
213
+ out_features = config.out_features
214
+ lr = config.lr
215
+ num_epochs = config.num_epochs
216
+ tr_layer_number = config.tr_layer_number
217
+ max_patience = config.max_patience
218
+
219
+ dict_models = {
220
+ 'BiFormer': BiFormer,
221
+ 'BiGraphFormer': BiGraphFormer,
222
+ 'BiGatedGraphFormer': BiGatedGraphFormer,
223
+ # 'MultiModalTransformer_v5': MultiModalTransformer_v5,
224
+ # 'MultiModalTransformer_v4': MultiModalTransformer_v4,
225
+ # 'MultiModalTransformer_v3': MultiModalTransformer_v3
226
+ }
227
+
228
+ model_cls = dict_models[config.model_name]
229
+ model = model_cls(
230
+ audio_dim = config.audio_embedding_dim,
231
+ text_dim = config.text_embedding_dim,
232
+ hidden_dim = hidden_dim,
233
+ hidden_dim_gated = hidden_dim_gated,
234
+ num_transformer_heads = num_transformer_heads,
235
+ num_graph_heads = num_graph_heads,
236
+ seg_len = config.max_tokens,
237
+ mode = mode,
238
+ dropout = dropout,
239
+ positional_encoding = positional_encoding,
240
+ out_features = out_features,
241
+ tr_layer_number = tr_layer_number,
242
+ device = device,
243
+ num_classes = num_classes
244
+ ).to(device)
245
+
246
+ # Optimizer and loss
247
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr)
248
+
249
+ class_weights = get_class_weights_from_loader(train_loader, num_classes)
250
+ criterion = WeightedCrossEntropyLoss(class_weights)
251
+
252
+ logging.info("Class weights: " + ", ".join(f"{name}={weight:.4f}" for name, weight in zip(config.emotion_columns, class_weights)))
253
+
254
+ # LR Scheduler
255
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
256
+ optimizer,
257
+ mode="max",
258
+ factor=0.5,
259
+ patience=2,
260
+ min_lr=1e-7
261
+ )
262
+
263
+ # Early stopping on the dev metric
264
+ best_dev_mean = float("-inf")
265
+ best_dev_metrics = {}
266
+ patience_counter = 0
267
+
268
+ for epoch in range(num_epochs):
269
+ logging.info(f"\n=== Epoch {epoch} ===")
270
+ model.train()
271
+
272
+ total_loss = 0.0
273
+ total_samples = 0
274
+ total_preds = []
275
+ total_targets = []
276
+
277
+ for batch in tqdm(train_loader):
278
+ if batch is None:
279
+ continue
280
+
281
+ audio = batch["audio"].to(device)
282
+ labels = batch["label"].to(device)
283
+ texts = batch["text"]
284
+
285
+ audio_emb = audio_extractor.extract(audio)
286
+ text_emb = text_extractor.extract(texts)
287
+
288
+ logits = model(audio_emb, text_emb)
289
+ target = labels.argmax(dim=1)
290
+ loss = criterion(logits, target)
291
+
292
+ optimizer.zero_grad()
293
+ loss.backward()
294
+ optimizer.step()
295
+
296
+ bs = audio.shape[0]
297
+ total_loss += loss.item() * bs
298
+
299
+ preds = logits.argmax(dim=1)
300
+ total_preds.extend(preds.cpu().numpy().tolist())
301
+ total_targets.extend(target.cpu().numpy().tolist())
302
+ total_samples += bs
303
+
304
+ train_loss = total_loss / total_samples
305
+ uar_m = uar(total_targets, total_preds)
306
+ war_m = war(total_targets, total_preds)
307
+ mf1_m = mf1(total_targets, total_preds)
308
+ wf1_m = wf1(total_targets, total_preds)
309
+ mean_train = np.mean([uar_m, war_m, mf1_m, wf1_m])
310
+
311
+ logging.info(
312
+ f"[TRAIN] Loss={train_loss:.4f}, UAR={uar_m:.4f}, WAR={war_m:.4f}, "
313
+ f"MF1={mf1_m:.4f}, WF1={wf1_m:.4f}, MEAN={mean_train:.4f}"
314
+ )
315
+
316
+ # --- DEV ---
317
+ dev_means = []
318
+ dev_metrics_by_dataset = []
319
+
320
+ for name, loader in dev_loaders:
321
+ d_loss, d_uar, d_war, d_mf1, d_wf1 = run_eval(
322
+ model, loader, audio_extractor, text_extractor, criterion, device
323
+ )
324
+ d_mean = np.mean([d_uar, d_war, d_mf1, d_wf1])
325
+ dev_means.append(d_mean)
326
+
327
+ if csv_writer:
328
+ csv_writer.writerow(["dev", epoch, name, d_loss, d_uar, d_war, d_mf1, d_wf1, d_mean])
329
+
330
+ logging.info(
331
+ f"[DEV:{name}] Loss={d_loss:.4f}, UAR={d_uar:.4f}, WAR={d_war:.4f}, "
332
+ f"MF1={d_mf1:.4f}, WF1={d_wf1:.4f}, MEAN={d_mean:.4f}"
333
+ )
334
+
335
+ dev_metrics_by_dataset.append({
336
+ "name": name,
337
+ "loss": d_loss,
338
+ "uar": d_uar,
339
+ "war": d_war,
340
+ "mf1": d_mf1,
341
+ "wf1": d_wf1,
342
+ "mean": d_mean,
343
+ })
344
+
345
+ mean_dev = np.mean(dev_means)
346
+ scheduler.step(mean_dev)
347
+
348
+ if mean_dev > best_dev_mean:
349
+ best_dev_mean = mean_dev
350
+ patience_counter = 0
351
+ best_dev_metrics = {
352
+ "mean": mean_dev
353
+ }
354
+ best_dev_metrics["by_dataset"] = dev_metrics_by_dataset
355
+ else:
356
+ patience_counter += 1
357
+ if patience_counter >= max_patience:
358
+ logging.info(f"Early stopping: {max_patience} epochs without improvement.")
359
+ break
360
+
361
+ # --- TEST ---
362
+ for name, loader in test_loaders:
363
+ t_loss, t_uar, t_war, t_mf1, t_wf1 = run_eval(
364
+ model, loader, audio_extractor, text_extractor, criterion, device
365
+ )
366
+ t_mean = np.mean([t_uar, t_war, t_mf1, t_wf1])
367
+ logging.info(
368
+ f"[TEST:{name}] Loss={t_loss:.4f}, UAR={t_uar:.4f}, WAR={t_war:.4f}, "
369
+ f"MF1={t_mf1:.4f}, WF1={t_wf1:.4f}, MEAN={t_mean:.4f}"
370
+ )
371
+
372
+ if csv_writer:
373
+ csv_writer.writerow(["test", epoch, name, t_loss, t_uar, t_war, t_mf1, t_wf1, t_mean])
374
+
375
+ if csv_file:
376
+ csv_file.close()
377
+
378
+ logging.info("Training finished. All splits processed!")
379
+ return best_dev_mean, best_dev_metrics
utils/__pycache__/config_loader.cpython-310.pyc ADDED
Binary file (6.53 kB).
 
utils/config_loader.py ADDED
@@ -0,0 +1,204 @@
1
+ # utils/config_loader.py
2
+
3
+ import os
4
+ import toml
5
+ import logging
6
+
7
+ class ConfigLoader:
8
+ """
9
+ Class for loading and processing the configuration from `config.toml`.
10
+ """
11
+
12
+ def __init__(self, config_path="config.toml"):
13
+ if not os.path.exists(config_path):
14
+ raise FileNotFoundError(f"Configuration file `{config_path}` not found!")
15
+
16
+ self.config = toml.load(config_path)
17
+
18
+ # ---------------------------
19
+ # General parameters
20
+ # ---------------------------
21
+ self.split = self.config.get("split", "train")
22
+
23
+ # ---------------------------
24
+ # Data paths
25
+ # ---------------------------
26
+ self.datasets = self.config.get("datasets", {})
27
+
28
+ # ---------------------------
29
+ # Synthetic data paths
30
+ # ---------------------------
31
+ synthetic_data_cfg = self.config.get("synthetic_data", {})
32
+ self.use_synthetic_data = synthetic_data_cfg.get("use_synthetic_data", False)
33
+ self.synthetic_path = synthetic_data_cfg.get("synthetic_path", "E:/MELD_S")
34
+ self.synthetic_ratio = synthetic_data_cfg.get("synthetic_ratio", 0.0)
35
+
36
+ # ---------------------------
37
+ # Modalities and emotions
38
+ # ---------------------------
39
+ self.modalities = self.config.get("modalities", ["audio"])
40
+ self.emotion_columns = self.config.get("emotion_columns", ["anger", "disgust", "fear", "happy", "neutral", "sad", "surprise"])
41
+
42
+ # ---------------------------
43
+ # DataLoader
44
+ # ---------------------------
45
+ dataloader_cfg = self.config.get("dataloader", {})
46
+ self.num_workers = dataloader_cfg.get("num_workers", 0)
47
+ self.shuffle = dataloader_cfg.get("shuffle", True)
48
+ self.prepare_only = dataloader_cfg.get("prepare_only", False)
49
+
50
+ # ---------------------------
51
+ # Audio
52
+ # ---------------------------
53
+ audio_cfg = self.config.get("audio", {})
54
+ self.sample_rate = audio_cfg.get("sample_rate", 16000)
55
+ self.wav_length = audio_cfg.get("wav_length", 2)
56
+ self.save_merged_audio = audio_cfg.get("save_merged_audio", True)
57
+ self.merged_audio_base_path = audio_cfg.get("merged_audio_base_path", "saved_merges")
58
+ self.merged_audio_suffix = audio_cfg.get("merged_audio_suffix", "_merged")
59
+ self.force_remerge = audio_cfg.get("force_remerge", False)
60
+
61
+ # ---------------------------
62
+ # Whisper / Text
63
+ # ---------------------------
64
+ text_cfg = self.config.get("text", {})
65
+ self.text_source = text_cfg.get("source", "csv")
66
+ self.text_column = text_cfg.get("text_column", "text")
67
+ self.whisper_model = text_cfg.get("whisper_model", "tiny")
68
+ self.max_text_tokens = text_cfg.get("max_tokens", 15)
69
+ self.whisper_device = text_cfg.get("whisper_device", "cuda")
70
+ self.use_whisper_for_nontrain_if_no_text = text_cfg.get("use_whisper_for_nontrain_if_no_text", True)
71
+
72
+ # ---------------------------
73
+ # Training: general
74
+ # ---------------------------
75
+ train_general = self.config.get("train", {}).get("general", {})
76
+ self.random_seed = train_general.get("random_seed", 42)
77
+ self.subset_size = train_general.get("subset_size", 0)
78
+ self.merge_probability = train_general.get("merge_probability", 0)
79
+ self.batch_size = train_general.get("batch_size", 8)
80
+ self.num_epochs = train_general.get("num_epochs", 100)
81
+ self.max_patience = train_general.get("max_patience", 10)
82
+ self.save_best_model = train_general.get("save_best_model", False)
83
+ self.save_prepared_data = train_general.get("save_prepared_data", True)
84
+ self.save_feature_path = train_general.get("save_feature_path", "./features/")
85
+ self.search_type = train_general.get("search_type", "none")
86
+ self.smoothing_probability = train_general.get("smoothing_probability", 0)
87
+ self.path_to_df_ls = train_general.get("path_to_df_ls", None)
88
+
89
+ # ---------------------------
90
+ # Training: model parameters
91
+ # ---------------------------
92
+ train_model = self.config.get("train", {}).get("model", {})
93
+ self.model_name = train_model.get("model_name", "BiFormer")
94
+ self.hidden_dim = train_model.get("hidden_dim", 256)
95
+ self.hidden_dim_gated = train_model.get("hidden_dim_gated", 256)
96
+ self.num_transformer_heads = train_model.get("num_transformer_heads", 8)
97
+ self.num_graph_heads = train_model.get("num_graph_heads", 8)
98
+ self.tr_layer_number = train_model.get("tr_layer_number", 1)
99
+ self.mamba_d_state = train_model.get("mamba_d_state", 16)
100
+ self.mamba_ker_size = train_model.get("mamba_ker_size", 4)
101
+ self.mamba_layer_number = train_model.get("mamba_layer_number", 3)
102
+ self.positional_encoding = train_model.get("positional_encoding", True)
103
+ self.dropout = train_model.get("dropout", 0.0)
104
+ self.out_features = train_model.get("out_features", 128)
105
+ self.mode = train_model.get("mode", "mean")
106
+
107
+ # ---------------------------
108
+ # Training: optimizer
109
+ # ---------------------------
110
+ train_optimizer = self.config.get("train", {}).get("optimizer", {})
111
+ self.optimizer = train_optimizer.get("optimizer", "adam")
112
+ self.lr = train_optimizer.get("lr", 1e-4)
113
+ self.weight_decay = train_optimizer.get("weight_decay", 0.0)
114
+ self.momentum = train_optimizer.get("momentum", 0.9)
115
+
116
+ # ---------------------------
117
+ # Training: scheduler
118
+ # ---------------------------
119
+ train_scheduler = self.config.get("train", {}).get("scheduler", {})
120
+ self.scheduler_type = train_scheduler.get("scheduler_type", "plateau")
121
+ self.warmup_ratio = train_scheduler.get("warmup_ratio", 0.1)
122
+
123
+ # ---------------------------
124
+ # Embeddings
125
+ # ---------------------------
126
+ emb_cfg = self.config.get("embeddings", {})
127
+ self.audio_model_name = emb_cfg.get("audio_model", "amiriparian/ExHuBERT")
128
+ self.text_model_name = emb_cfg.get("text_model", "jinaai/jina-embeddings-v3")
129
+ self.audio_classifier_checkpoint = emb_cfg.get("audio_classifier_checkpoint", "best_audio_model.pt")
130
+ self.text_classifier_checkpoint = emb_cfg.get("text_classifier_checkpoint", "best_text_model.pth")
131
+ self.audio_embedding_dim = emb_cfg.get("audio_embedding_dim", 1024)
132
+ self.text_embedding_dim = emb_cfg.get("text_embedding_dim", 1024)
133
+ self.emb_normalize = emb_cfg.get("emb_normalize", True)
134
+ self.audio_pooling = emb_cfg.get("audio_pooling", None)
135
+ self.text_pooling = emb_cfg.get("text_pooling", None)
136
+ self.max_tokens = emb_cfg.get("max_tokens", 256)
137
+ self.emb_device = emb_cfg.get("device", "cuda")
138
+
139
+ # ---------------------------
140
+ # Synthetic data (text generation)
141
+ # ---------------------------
142
+ # textgen_cfg = self.config.get("textgen", {})
143
+ # self.model_name = textgen_cfg.get("model_name", "deepseek-ai/DeepSeek-R1-Distill-Llama-8B")
144
+ # self.max_new_tokens = textgen_cfg.get("max_new_tokens", 50)
145
+ # self.temperature = textgen_cfg.get("temperature", 1.0)
146
+ # self.top_p = textgen_cfg.get("top_p", 0.95)
147
+
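+ # note: this only triggers when utils/config_loader.py is executed directly, not on a normal import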
148
+ if __name__ == "__main__":
149
+ self.log_config()
150
+
151
+ def log_config(self):
152
+ logging.info("=== CONFIGURATION ===")
153
+ logging.info(f"Split: {self.split}")
154
+ logging.info(f"Datasets loaded: {list(self.datasets.keys())}")
155
+ for name, ds in self.datasets.items():
156
+ logging.info(f"[Dataset: {name}]")
157
+ logging.info(f" Base Dir: {ds.get('base_dir', 'N/A')}")
158
+ logging.info(f" CSV Path: {ds.get('csv_path', '')}")
159
+ logging.info(f" WAV Dir: {ds.get('wav_dir', '')}")
160
+ logging.info(f"Emotion columns: {self.emotion_columns}")
161
+
162
+ # Log training parameters
163
+ logging.info("--- Training Config ---")
164
+ logging.info(f"Sample Rate={self.sample_rate}, Wav Length={self.wav_length}s")
165
+ logging.info(f"Whisper Model={self.whisper_model}, Device={self.whisper_device}, MaxTokens={self.max_text_tokens}")
166
+ logging.info(f"use_whisper_for_nontrain_if_no_text={self.use_whisper_for_nontrain_if_no_text}")
167
+ logging.info(f"DataLoader: batch_size={self.batch_size}, num_workers={self.num_workers}, shuffle={self.shuffle}")
168
+ logging.info(f"Model Name: {self.model_name}")
169
+ logging.info(f"Random Seed: {self.random_seed}")
170
+ logging.info(f"Hidden Dim: {self.hidden_dim}")
171
+ logging.info(f"Hidden Dim in Gated: {self.hidden_dim_gated}")
172
+ logging.info(f"Num Heads in Transformer: {self.num_transformer_heads}")
173
+ logging.info(f"Num Heads in Graph: {self.num_graph_heads}")
174
+ logging.info(f"Mode stat pooling: {self.mode}")
175
+ logging.info(f"Optimizer: {self.optimizer}")
176
+ logging.info(f"Scheduler Type: {self.scheduler_type}")
177
+ logging.info(f"Warmup Ratio: {self.warmup_ratio}")
178
+ logging.info(f"Weight Decay for Adam: {self.weight_decay}")
179
+ logging.info(f"Momentum (SGD): {self.momentum}")
180
+ logging.info(f"Positional Encoding: {self.positional_encoding}")
181
+ logging.info(f"Number of Transformer Layers: {self.tr_layer_number}")
182
+ logging.info(f"Mamba D State: {self.mamba_d_state}")
183
+ logging.info(f"Mamba Kernel Size: {self.mamba_ker_size}")
184
+ logging.info(f"Mamba Layer Number: {self.mamba_layer_number}")
185
+ logging.info(f"Dropout: {self.dropout}")
186
+ logging.info(f"Out Features: {self.out_features}")
187
+ logging.info(f"LR: {self.lr}")
188
+ logging.info(f"Num Epochs: {self.num_epochs}")
189
+ logging.info(f"Merge Probability={self.merge_probability}")
190
+ logging.info(f"Smoothing Probability={self.smoothing_probability}")
191
+ logging.info(f"Max Patience={self.max_patience}")
192
+ logging.info(f"Save Prepared Data={self.save_prepared_data}")
193
+ logging.info(f"Path to Save Features={self.save_feature_path}")
194
+ logging.info(f"Search Type={self.search_type}")
195
+
196
+ # Log embedding settings
197
+ logging.info("--- Embeddings Config ---")
198
+ logging.info(f"Audio Model: {self.audio_model_name}, Text Model: {self.text_model_name}")
199
+ logging.info(f"Audio dim={self.audio_embedding_dim}, Text dim={self.text_embedding_dim}")
200
+ logging.info(f"Audio pooling={self.audio_pooling}, Text pooling={self.text_pooling}")
201
+ logging.info(f"Emb device={self.emb_device}, Normalize={self.emb_normalize}")
202
+
203
+ def show_config(self):
204
+ self.log_config()
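
A minimal usage sketch of the loader (assuming a config.toml in the working directory):

import logging
from utils.config_loader import ConfigLoader

logging.basicConfig(level=logging.INFO)
config = ConfigLoader("config.toml")
config.show_config()                                   # logs the full configuration
print(config.model_name, config.lr, config.emotion_columns)
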
utils/logger_setup.py ADDED
@@ -0,0 +1,47 @@
1
+ # utils/logger_setup.py
2
+
3
+ import logging
4
+ from colorlog import ColoredFormatter
5
+
6
+ def setup_logger(level=logging.INFO, log_file=None):
7
+ """
8
+ Configures the root logger for colored console output and,
9
+ optionally, for writing to a file.
10
+
11
+ :param level: Logging level (e.g., logging.DEBUG)
12
+ :param log_file: Path to the log file (if not None, logs are also written to this file)
13
+ """
14
+ logger = logging.getLogger()
15
+ if logger.hasHandlers():
16
+ logger.handlers.clear()
17
+
18
+ # Console handler with colorlog
19
+ console_handler = logging.StreamHandler()
20
+ log_format = (
21
+ "%(log_color)s%(asctime)s [%(levelname)s]%(reset)s %(blue)s%(message)s"
22
+ )
23
+ console_formatter = ColoredFormatter(
24
+ log_format,
25
+ datefmt="%Y-%m-%d %H:%M:%S",
26
+ reset=True,
27
+ log_colors={
28
+ "DEBUG": "cyan",
29
+ "INFO": "green",
30
+ "WARNING": "yellow",
31
+ "ERROR": "red",
32
+ "CRITICAL": "bold_red"
33
+ }
34
+ )
35
+ console_handler.setFormatter(console_formatter)
36
+ logger.addHandler(console_handler)
37
+
38
+ # If log_file is given, add a file handler
39
+ if log_file is not None:
40
+ file_handler = logging.FileHandler(log_file, mode="w", encoding="utf-8")
41
+ file_format = "%(asctime)s [%(levelname)s] %(message)s"
42
+ file_formatter = logging.Formatter(file_format, datefmt="%Y-%m-%d %H:%M:%S")
43
+ file_handler.setFormatter(file_formatter)
44
+ logger.addHandler(file_handler)
45
+
46
+ logger.setLevel(level)
47
+ return logger
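
A minimal usage sketch:

import logging
from utils.logger_setup import setup_logger

logger = setup_logger(level=logging.DEBUG, log_file="train.log")
logger.info("colored console output, plus a plain-text copy in train.log")
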
utils/losses.py ADDED
@@ -0,0 +1,33 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class WeightedCrossEntropyLoss(nn.Module):
6
+ def __init__(self, class_weights=None):
7
+ """
8
+ Initializes the cross-entropy loss with optional class weighting.
9
+
10
+ :param class_weights: Vector of class weights (optional)
11
+ """
12
+ super(WeightedCrossEntropyLoss, self).__init__()
13
+ self.class_weights = class_weights
14
+
15
+ def forward(self, y_pred, y_true):
16
+ """
17
+ Computes the cross-entropy loss with (or without) class weighting.
18
+
19
+ :param y_true: Ground-truth class labels (a vector or a single label)
20
+ :param y_pred: Predicted logits (unnormalized scores)
21
+ :return: Loss value
22
+ """
23
+
24
+ y_true = y_true.to(torch.long) # cast labels to long
25
+ y_pred = y_pred.to(torch.float32) # cast predictions to float32
26
+
27
+ if self.class_weights is not None:
28
+ class_weights = torch.tensor(self.class_weights).float().to(y_true.device)
29
+ loss = F.cross_entropy(y_pred, y_true, weight=class_weights)
30
+ else:
31
+ loss = F.cross_entropy(y_pred, y_true)
32
+
33
+ return loss
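
A minimal usage sketch with illustrative values (in training, class_weights comes from get_class_weights_from_loader):

import torch
from utils.losses import WeightedCrossEntropyLoss

logits = torch.randn(4, 7)               # batch of 4, 7 emotion classes
targets = torch.tensor([0, 3, 6, 2])     # integer class indices
criterion = WeightedCrossEntropyLoss(class_weights=[1.0] * 7)
loss = criterion(logits, targets)        # forward(y_pred, y_true)
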
utils/measures.py ADDED
@@ -0,0 +1,41 @@
1
+ from sklearn.metrics import recall_score, f1_score
2
+
3
+ def uar(y_true, y_pred):
4
+ """
5
+ Computes UAR (Unweighted Average Recall).
6
+
7
+ :param y_true: Ground-truth labels
8
+ :param y_pred: Predicted labels
9
+ :return: UAR (recall averaged over all classes without weighting)
10
+ """
11
+ return recall_score(y_true, y_pred, average='macro', zero_division=0)
12
+
13
+ def war(y_true, y_pred):
14
+ """
15
+ Computes WAR (Weighted Average Recall).
16
+
17
+ :param y_true: Ground-truth labels
18
+ :param y_pred: Predicted labels
19
+ :return: WAR (recall weighted by class support)
20
+ """
21
+ return recall_score(y_true, y_pred, average='weighted', zero_division=0)
22
+
23
+ def mf1(y_true, y_pred):
24
+ """
25
+ Computes MF1 (Macro F1 Score).
26
+
27
+ :param y_true: Ground-truth labels
28
+ :param y_pred: Predicted labels
29
+ :return: MF1 (F1 averaged over all classes)
30
+ """
31
+ return f1_score(y_true, y_pred, average='macro', zero_division=0)
32
+
33
+ def wf1(y_true, y_pred):
34
+ """
35
+ Computes WF1 (Weighted F1 Score).
36
+
37
+ :param y_true: Ground-truth labels
38
+ :param y_pred: Predicted labels
39
+ :return: WF1 (F1 weighted by class support)
40
+ """
41
+ return f1_score(y_true, y_pred, average='weighted', zero_division=0)
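
A minimal usage sketch with illustrative labels, mirroring how train_utils.py averages the four scores into the model-selection metric:

import numpy as np
from utils.measures import uar, war, mf1, wf1

y_true = [0, 1, 2, 2, 1, 0]
y_pred = [0, 1, 1, 2, 1, 0]
scores = [f(y_true, y_pred) for f in (uar, war, mf1, wf1)]
print(np.mean(scores))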