TimurHromek committed on
Commit 9561098 · verified · 1 Parent(s): d55534a

Updated to HROM-V1.5 trainer.

Files changed (2)
  1. HROM-V1.5_Trainer.py +1311 -0
  2. HROM_Trainer.py +0 -384
HROM-V1.5_Trainer.py ADDED
@@ -0,0 +1,1311 @@
1
+ import os
2
+ # Set parallelism env var *before* importing tokenizers
3
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.utils.data import Dataset, DataLoader
8
+ # Import necessary dataset functions, including concatenate_datasets if needed later
9
+ from datasets import load_dataset, disable_caching, concatenate_datasets
10
+ from tokenizers import Tokenizer, models, trainers, pre_tokenizers, processors, decoders
11
+ import math
12
+ import re
13
+ from datetime import datetime
14
+ from contextlib import nullcontext
15
+ from collections import defaultdict
16
+ import logging
17
+ import random # For shuffling combined data
18
+
19
+ # Disable caching for datasets if needed, helps ensure reprocessing
20
+ # disable_caching()
21
+
22
+ # Setup logging
23
+ logging.basicConfig(
24
+ level=logging.INFO,
25
+ format='%(asctime)s - %(levelname)s - %(message)s',
26
+ force=True # Override any previously configured root handlers
27
+ )
28
+
29
+ # Configuration
30
+ CONFIG = {
31
+ "dim": 768,
32
+ "n_layers": 8,
33
+ "n_heads": 8,
34
+ "ff_dim": 2048,
35
+ "dropout": 0.1,
36
+ "max_seq_len": 512,
37
+ "batch_size": 16, # Keep batch size reasonable
38
+ "checkpoint_interval": 2000,
39
+ "debug_interval": 400,
40
+ # Reverted to training on all four datasets, using correct persona_chat identifier
41
+ "datasets": ["daily_dialog", "empathetic_dialogues", "blended_skill_talk", "AlekseyKorshuk/persona-chat"],
42
+ # Reverted to combined tokenizer name
43
+ "tokenizer_name": "hrom_tokenizer.json",
44
+ # Reverted to combined checkpoint dir
45
+ "checkpoint_dir": "checkpoints",
46
+ "vocab_size": 32000,
47
+ # Adjusted samples per dataset: with 4 datasets, 50k each gives 200k total samples
48
+ "tokenizer_train_samples_per_dataset": 50000,
49
+ "learning_rate": 2e-5,
50
+ "warmup_steps": 1000,
51
+ "max_turns": 8, # Max turns applied per dialogue
52
+ "max_checkpoints": 5,
53
+ "num_epochs": 30,
54
+ "grad_accum_steps": 8 # Keep grad accum reasonable
55
+ }
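+ # Quick sanity check on the settings above: with batch_size=16, grad_accum_steps=8 and
+ # max_seq_len=512, one optimizer step covers at most 16 * 8 * 512 = 65,536 tokens
+ # (fewer in practice, since padded positions are ignored by the loss).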
56
+
57
+ # --- Model Definition (HROM, HROMBlock, HROMAttention, SwiGLU, RoPE) ---
58
+ # (These classes remain unchanged from the previous version)
59
+
60
+ class RotaryEmbedding(nn.Module):
61
+ def __init__(self, dim):
62
+ super().__init__()
63
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
64
+ self.register_buffer("inv_freq", inv_freq)
65
+
66
+ def forward(self, seq_len):
67
+ t = torch.arange(seq_len, device=self.inv_freq.device).type_as(self.inv_freq)
68
+ freqs = torch.einsum("i, j -> i j", t, self.inv_freq)
69
+ if seq_len == 0:
+ return torch.empty((0, self.inv_freq.shape[0] * 2), device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ # Defensive reshape only if necessary
+ if freqs.shape[0] != seq_len:
+ freqs = freqs.reshape(seq_len, -1)
76
+
77
+ return torch.cat((freqs, freqs), dim=-1)
78
+
79
+ def rotate_half(x):
80
+ x1, x2 = x.chunk(2, dim=-1)
81
+ return torch.cat((-x2, x1), dim=-1)
82
+
83
+ def apply_rotary_pos_emb(pos, t):
84
+ # pos: (T, dim_rotary), t: (B, H, T, Head_Dim)
85
+ pos = pos.to(t.device, dtype=t.dtype)
86
+ pos = pos.unsqueeze(0).unsqueeze(1) # Shape: (1, 1, T, dim_rotary)
87
+ tensor_seq_len = t.shape[2]
88
+ pos_seq_len = pos.shape[2]
89
+
90
+ if pos_seq_len < tensor_seq_len:
91
+ logging.warning(f"RoPE Warning: pos sequence length ({pos_seq_len}) is shorter than tensor sequence length ({tensor_seq_len}). Using truncated tensor length for RoPE.")
92
+ # This case is tricky, maybe only apply to the length of pos?
93
+ # Or indicates an issue upstream. Let's slice t for now, though it's unusual.
94
+ t_rotated = t[:, :, :pos_seq_len, :]
95
+ pos = pos[:, :, :pos_seq_len, :] # Ensure pos matches the sliced tensor length
96
+
97
+ # Apply rotation only to the slice
98
+ cos_pos = pos.cos()
99
+ sin_pos = pos.sin()
100
+ t_rotated = (t_rotated * cos_pos) + (rotate_half(t_rotated) * sin_pos)
101
+
102
+ # Concatenate the rotated part with the un-rotated part
103
+ t_unrotated = t[:, :, pos_seq_len:, :]
104
+ return torch.cat([t_rotated, t_unrotated], dim=2)
105
+
106
+ elif pos_seq_len > tensor_seq_len:
107
+ pos = pos[:, :, :tensor_seq_len, :] # Slice pos to match tensor
108
+
109
+ # Check dimension match after potential slicing
110
+ if pos.shape[-1] != t.shape[-1]:
111
+ logging.error(f"Mismatched dimensions for RoPE: pos ({pos.shape[-1]}) vs t ({t.shape[-1]})")
112
+ raise ValueError("Rotary embedding dimension must match head dimension.")
113
+
114
+ cos_pos = pos.cos()
115
+ sin_pos = pos.sin()
116
+ rotated_t = (t * cos_pos) + (rotate_half(t) * sin_pos)
117
+ return rotated_t
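+ # Shape walk-through for the RoPE helpers above, using the default CONFIG: head_dim = 768 // 8 = 96,
+ # so RotaryEmbedding(96) caches 48 inverse frequencies and rotary(T) returns a (T, 96) angle table
+ # (cos/sin are taken inside apply_rotary_pos_emb). That table is broadcast against q/k of shape
+ # (B, n_heads, T, 96) via the (1, 1, T, 96) unsqueeze before the element-wise rotation.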
118
+
119
+
120
+ class SwiGLU(nn.Module):
121
+ def forward(self, x):
122
+ x, gate = x.chunk(2, dim=-1)
123
+ # Note: the gate uses GELU here, so this is effectively GeGLU despite the class name
+ return x * nn.functional.gelu(gate)
124
+
125
+ class HROMAttention(nn.Module):
126
+ def __init__(self):
127
+ super().__init__()
128
+ self.dim = CONFIG["dim"]
129
+ self.n_heads = CONFIG["n_heads"]
130
+ self.head_dim = self.dim // self.n_heads
131
+ if self.dim % self.n_heads != 0:
132
+ raise ValueError("dim must be divisible by n_heads")
133
+ self.qkv = nn.Linear(self.dim, 3 * self.dim)
134
+ self.proj = nn.Linear(self.dim, self.dim)
135
+ self.rotary = RotaryEmbedding(self.head_dim)
136
+ self.dropout = nn.Dropout(CONFIG["dropout"])
137
+
138
+ def forward(self, x, mask=None):
139
+ B, T, C = x.shape
140
+ qkv = self.qkv(x)
141
+ qkv = qkv.reshape(B, T, 3, self.n_heads, self.head_dim)
142
+ q, k, v = qkv.unbind(2)
143
+ q = q.transpose(1, 2)
144
+ k = k.transpose(1, 2)
145
+ v = v.transpose(1, 2)
146
+ # Generate RoPE embeddings for the current sequence length T
147
+ pos = self.rotary(T) # Shape (T, Head_Dim)
148
+ # Apply RoPE
149
+ q = apply_rotary_pos_emb(pos, q)
150
+ k = apply_rotary_pos_emb(pos, k)
151
+ # Attention calculation
152
+ attn_scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
153
+ if mask is not None:
154
+ # Ensure mask is broadcastable (B, 1, T, T)
155
+ if mask.dim() == 2: # (B, T) -> (B, 1, 1, T) -> add with causal = (B, 1, T, T)
156
+ mask = mask.unsqueeze(1).unsqueeze(2)
157
+ elif mask.dim() == 3: # (B, T, T)
158
+ mask = mask.unsqueeze(1)
159
+ # Add mask AFTER scaling scores
160
+ attn_scores = attn_scores + mask # Add large negative values for masked positions
161
+ # Softmax and dropout
162
+ attn_probs = torch.softmax(attn_scores.float(), dim=-1).to(dtype=x.dtype) # Use float for stability
163
+ attn_probs = self.dropout(attn_probs)
164
+ # Output projection
165
+ output = attn_probs @ v
166
+ output = output.transpose(1, 2).reshape(B, T, self.dim)
167
+ return self.proj(output)
168
+
169
+
170
+ class HROMBlock(nn.Module):
171
+ def __init__(self):
172
+ super().__init__()
173
+ self.attn = HROMAttention()
174
+ self.ff = nn.Sequential(
175
+ nn.Linear(CONFIG["dim"], 2 * CONFIG["ff_dim"]),
176
+ SwiGLU(),
177
+ nn.Linear(CONFIG["ff_dim"], CONFIG["dim"])
178
+ )
179
+ self.norm1 = nn.LayerNorm(CONFIG["dim"])
180
+ self.norm2 = nn.LayerNorm(CONFIG["dim"])
181
+ self.dropout = nn.Dropout(CONFIG["dropout"])
182
+
183
+ def forward(self, x, mask=None):
184
+ # Pre-Normalization
185
+ normed_x = self.norm1(x)
186
+ attn_output = self.attn(normed_x, mask)
187
+ x = x + self.dropout(attn_output)
188
+
189
+ normed_x = self.norm2(x)
190
+ ff_output = self.ff(normed_x)
191
+ x = x + self.dropout(ff_output)
192
+ return x
193
+
194
+ class HROM(nn.Module):
195
+ def __init__(self):
196
+ super().__init__()
197
+ self.embed = nn.Embedding(CONFIG["vocab_size"], CONFIG["dim"])
198
+ self.blocks = nn.ModuleList([HROMBlock() for _ in range(CONFIG["n_layers"])])
199
+ self.norm = nn.LayerNorm(CONFIG["dim"])
200
+ self.head = nn.Linear(CONFIG["dim"], CONFIG["vocab_size"])
201
+ self.dropout = nn.Dropout(CONFIG["dropout"]) # Add dropout after embedding
202
+ self.apply(self._init_weights)
203
+
204
+ def _init_weights(self, module):
205
+ if isinstance(module, nn.Linear):
206
+ torch.nn.init.xavier_uniform_(module.weight)
207
+ if module.bias is not None:
208
+ torch.nn.init.zeros_(module.bias)
209
+ elif isinstance(module, nn.Embedding):
210
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
211
+ elif isinstance(module, nn.LayerNorm):
212
+ torch.nn.init.zeros_(module.bias)
213
+ torch.nn.init.ones_(module.weight)
214
+
215
+ def forward(self, input_ids, attention_mask=None):
216
+ B, T = input_ids.shape
217
+ x = self.embed(input_ids)
218
+ x = self.dropout(x) # Apply dropout after embedding
219
+
220
+ # Create the combined mask for attention
221
+ combined_mask = None
222
+ # Start with causal mask valid for all sequences in batch
223
+ causal_mask = torch.triu(torch.ones(T, T, device=input_ids.device) * float('-inf'), diagonal=1)
224
+ combined_mask = causal_mask.unsqueeze(0).unsqueeze(1) # (1, 1, T, T)
225
+
226
+ if attention_mask is not None:
227
+ # Process padding mask from attention_mask (0 = pad, 1 = real)
228
+ # Convert 0s to -inf, 1s to 0
229
+ pad_mask = (1.0 - attention_mask.to(torch.float32)) * torch.finfo(torch.float32).min
230
+ pad_mask = pad_mask.unsqueeze(1).unsqueeze(2) # (B, 1, 1, T)
231
+ # Add padding mask to causal mask. Broadcasting ensures (B, 1, T, T)
232
+ # Where pad_mask is -inf, the result is -inf. Otherwise, it's the causal value.
233
+ combined_mask = combined_mask + pad_mask
234
+
235
+ # Ensure mask dtype matches data dtype (esp. for AMP)
236
+ combined_mask = combined_mask.to(dtype=x.dtype)
237
+
238
+ for block in self.blocks:
239
+ x = block(x, combined_mask) # Pass the combined mask to each block
240
+
241
+ x = self.norm(x)
242
+ logits = self.head(x)
243
+ return logits
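+ # Mask semantics used in forward(), on a toy length-3 example: with attention_mask = [1, 1, 0]
+ # the padding term puts a large negative value at key position 2 for every query, and the causal
+ # triu adds -inf above the diagonal, so query 1 may attend to keys {0, 1} and the pad position
+ # is never attended to.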
244
+
245
+ # --- Tokenizer Training ---
246
+
247
+ class TokenizerTrainer:
248
+ def __init__(self):
249
+ self.tokenizer = Tokenizer(models.BPE())
250
+ self.tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
251
+ self.tokenizer.decoder = decoders.ByteLevel()
252
+ self.special_tokens = ["<pad>", "<s>", "</s>", "<unk>", "<user>", "<assistant>"]
253
+ # Use the updated tokenizer name from CONFIG
254
+ self.tokenizer_path = os.path.join("tokenizer", CONFIG["tokenizer_name"])
255
+ self.tokenizer_dir = os.path.dirname(self.tokenizer_path)
256
+
257
+ def _clean_text(self, text):
258
+ text = str(text) # Ensure text is string
259
+ text = re.sub(r'_comma_', ',', text)
260
+ # Allow alphanumeric, whitespace, and basic punctuation including quotes
261
+ text = re.sub(r'[^\w\s.,!?\'\-:;<>"]', '', text)
262
+ text = re.sub(r'\s+', ' ', text).strip()
263
+ return text
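+ # Example of the cleaning above: "_comma_" placeholders (from empathetic_dialogues) become real
+ # commas, characters outside the whitelist are dropped, and whitespace is collapsed, e.g.
+ # _clean_text("I'm fine_comma_   thanks! 😊") -> "I'm fine, thanks!"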
264
+
265
+ def train(self, dataset_names):
266
+ logging.info("Starting tokenizer training...")
267
+ text_samples = []
268
+ samples_per_dataset = CONFIG['tokenizer_train_samples_per_dataset']
269
+
270
+ # --- Process DailyDialog ---
271
+ if "daily_dialog" in dataset_names:
272
+ logging.info(f"Loading daily_dialog for tokenizer training (max {samples_per_dataset} dialogues)...")
273
+ try:
274
+ # Limit dialogues loaded directly using slicing
275
+ dd_dataset = load_dataset("daily_dialog", split=f"train[:{samples_per_dataset}]", trust_remote_code=True) # Add trust_remote_code=True
276
+ logging.info("Processing daily_dialog...")
277
+ for entry in dd_dataset:
278
+ formatted_dialogue = []
279
+ dialogue = entry['dialog'][:CONFIG["max_turns"]]
280
+ for i, utterance in enumerate(dialogue):
281
+ role = "<user>" if i % 2 == 0 else "<assistant>"
282
+ cleaned_utterance = self._clean_text(utterance)
283
+ if cleaned_utterance: # Only add non-empty turns
284
+ formatted_dialogue.append(f"{role} {cleaned_utterance}")
285
+ if formatted_dialogue: # Only add if dialogue is not empty after cleaning
286
+ text_samples.append(" </s> ".join(formatted_dialogue))
287
+ except Exception as e:
288
+ logging.error(f"Failed to load or process daily_dialog for tokenizer: {e}")
289
+
290
+ # --- Process EmpatheticDialogues ---
291
+ if "empathetic_dialogues" in dataset_names:
292
+ logging.info(f"Loading empathetic_dialogues for tokenizer training (max {samples_per_dataset} dialogues)...")
293
+ try:
294
+ # Load more initially to ensure we get enough unique conversations (adjust multiplier if needed)
295
+ ed_dataset = load_dataset("empathetic_dialogues", split=f"train[:{samples_per_dataset * 3}]", trust_remote_code=True) # Add trust_remote_code=True
296
+ logging.info("Processing empathetic_dialogues...")
297
+ conversations = defaultdict(list)
298
+ processed_conv_count = 0
299
+ # Group utterances by conv_id first
300
+ grouped_by_conv = defaultdict(list)
301
+ for entry in ed_dataset:
302
+ grouped_by_conv[entry['conv_id']].append(entry)
303
+
304
+ # Process conversations ensuring max samples limit
305
+ for conv_id, entries in grouped_by_conv.items():
306
+ if processed_conv_count >= samples_per_dataset:
307
+ break
308
+ # Sort by utterance_idx to maintain order
309
+ sorted_entries = sorted(entries, key=lambda x: x['utterance_idx'])
310
+ formatted_dialogue = []
311
+ # Handle context and first utterance
312
+ if sorted_entries[0]['context']:
313
+ cleaned_context = self._clean_text(sorted_entries[0]['context'])
314
+ if cleaned_context:
315
+ formatted_dialogue.append(f"<user> {cleaned_context}") # Assume context is user start
316
+ # Process subsequent utterances
317
+ last_role = '<user>' if formatted_dialogue else None # Set initial last role based on context
318
+ for entry in sorted_entries:
319
+ cleaned_utterance = self._clean_text(entry['utterance'])
320
+ if cleaned_utterance:
321
+ # Determine role based on alternation
322
+ current_role = '<assistant>' if last_role == '<user>' else '<user>'
323
+ formatted_dialogue.append(f"{current_role} {cleaned_utterance}")
324
+ last_role = current_role # Update last role
325
+ # Apply max turns limit to the formatted turns
326
+ formatted_dialogue = formatted_dialogue[:CONFIG["max_turns"]]
327
+ if formatted_dialogue:
328
+ text_samples.append(" </s> ".join(formatted_dialogue))
329
+ processed_conv_count += 1 # Count processed unique conversations
330
+
331
+ except Exception as e:
332
+ logging.error(f"Failed to load or process empathetic_dialogues for tokenizer: {e}")
333
+
334
+
335
+ # --- Process BlendedSkillTalk ---
336
+ if "blended_skill_talk" in dataset_names:
337
+ logging.info(f"Loading blended_skill_talk for tokenizer training (max {samples_per_dataset} dialogues)...")
338
+ try:
339
+ # Load dialogues - BST is structured differently, slice directly
340
+ bst_dataset = load_dataset("blended_skill_talk", split=f"train[:{samples_per_dataset}]", trust_remote_code=True) # Add trust_remote_code=True
341
+ logging.info("Processing blended_skill_talk...")
342
+ for entry in bst_dataset:
343
+ formatted_dialogue = []
344
+ # Combine the dialogue history and the final two turns
345
+ dialogue_turns_raw = entry['previous_utterance']
346
+ # Add final utterances if they exist and are not empty strings
347
+ if entry.get('free_turker_utterance'):
348
+ dialogue_turns_raw.append(entry['free_turker_utterance'])
349
+ if entry.get('guided_turker_utterance'):
350
+ dialogue_turns_raw.append(entry['guided_turker_utterance'])
351
+
352
+ turns_to_process = dialogue_turns_raw[:CONFIG["max_turns"]] # Apply max turns limit
353
+ for i, utterance in enumerate(turns_to_process):
354
+ role = "<user>" if i % 2 == 0 else "<assistant>" # Assume simple alternation
355
+ cleaned_utterance = self._clean_text(utterance)
356
+ if cleaned_utterance:
357
+ formatted_dialogue.append(f"{role} {cleaned_utterance}")
358
+ if formatted_dialogue:
359
+ text_samples.append(" </s> ".join(formatted_dialogue))
360
+ except Exception as e:
361
+ logging.error(f"Failed to load or process blended_skill_talk for tokenizer: {e}")
362
+
363
+ # --- Process PersonaChat ---
364
+ if "AlekseyKorshuk/persona-chat" in dataset_names: # Correct dataset identifier
365
+ pc_dataset_name = "AlekseyKorshuk/persona-chat"
366
+ logging.info(f"Loading {pc_dataset_name} for tokenizer training (max {samples_per_dataset} dialogues)...")
367
+ try:
368
+ pc_dataset = load_dataset(pc_dataset_name, split=f"train[:{samples_per_dataset}]", trust_remote_code=True) # Add trust_remote_code=True, Correct dataset identifier
369
+ logging.info(f"Processing {pc_dataset_name}...")
370
+ for entry in pc_dataset:
371
+ # PersonaChat often has 'utterances' containing 'history'
372
+ if 'utterances' in entry and entry['utterances']:
373
+ # Get the history from the last item in utterances for the full dialogue
374
+ history = entry['utterances'][-1]['history']
375
+ history = history[:CONFIG["max_turns"]] # Apply max turns
376
+ formatted_dialogue = []
377
+ for i, utterance in enumerate(history):
378
+ role = "<user>" if i % 2 == 0 else "<assistant>" # Assume simple alternation
379
+ cleaned_utterance = self._clean_text(utterance)
380
+ if cleaned_utterance:
381
+ formatted_dialogue.append(f"{role} {cleaned_utterance}")
382
+ if formatted_dialogue:
383
+ text_samples.append(" </s> ".join(formatted_dialogue))
384
+ else:
385
+ logging.warning(f"Skipping {pc_dataset_name} entry due to unexpected structure: {entry}")
386
+
387
+ except Exception as e:
388
+ logging.error(f"Failed to load or process {pc_dataset_name} for tokenizer: {e}")
389
+
390
+
391
+ logging.info(f"Total text samples for tokenizer training: {len(text_samples)}")
392
+ if not text_samples:
393
+ raise ValueError("No text samples collected for tokenizer training. Check dataset loading and paths.")
394
+
395
+ # Ensure tokenizer directory exists before training
396
+ os.makedirs(self.tokenizer_dir, exist_ok=True)
397
+
398
+ logging.info(f"Training BPE tokenizer with vocab size {CONFIG['vocab_size']}...")
399
+ trainer = trainers.BpeTrainer(
400
+ vocab_size=CONFIG["vocab_size"],
401
+ special_tokens=self.special_tokens,
402
+ min_frequency=2, # Keep min_frequency low with more data
403
+ show_progress=True
404
+ )
405
+ # Make sure text_samples is an iterator or list of strings
406
+ def text_iterator():
407
+ for sample in text_samples:
408
+ yield sample
409
+
410
+ self.tokenizer.train_from_iterator(text_iterator(), trainer=trainer, length=len(text_samples))
411
+
412
+ eos_token_id = self.tokenizer.token_to_id("</s>")
413
+ if eos_token_id is None:
414
+ logging.warning("</s> token not found in trained tokenizer vocab! Using <pad> as fallback for post-processor.")
415
+ eos_token_id = self.tokenizer.token_to_id("<pad>") or 0 # Fallback needed
416
+
417
+ # Configure post-processor (adjust if needed based on how you structure input/output)
418
+ self.tokenizer.post_processor = processors.TemplateProcessing(
419
+ single="$A </s>",
420
+ pair="$A </s> $B </s>", # How to handle pairs - maybe just use single always?
421
+ special_tokens=[("</s>", eos_token_id)],
422
+ )
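+ # With this template, a plain tokenizer.encode("hello there") appends the </s> id automatically;
+ # the dataset code below sidesteps it by encoding with add_special_tokens=False and inserting
+ # role/EOS tokens manually.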
423
+
424
+ logging.info(f"Saving tokenizer to {self.tokenizer_path}")
425
+ self.tokenizer.save(self.tokenizer_path)
426
+ logging.info("Tokenizer training complete.")
427
+
428
+ def get_tokenizer(self):
429
+ if not os.path.exists(self.tokenizer_path):
430
+ raise FileNotFoundError(f"Tokenizer file not found at {self.tokenizer_path}. Train tokenizer first.")
431
+ tokenizer = Tokenizer.from_file(self.tokenizer_path)
432
+ # Verify special tokens crucial for processing exist
433
+ required_tokens = ["<pad>", "<s>", "</s>", "<unk>", "<user>", "<assistant>"]
434
+ for token in required_tokens:
435
+ if tokenizer.token_to_id(token) is None:
436
+ raise ValueError(f"Crucial special token '{token}' not found in loaded tokenizer '{self.tokenizer_path}'!")
437
+ return tokenizer
438
+
439
+ # --- Dataset Loading and Processing ---
440
+
441
+ class CombinedChatDataset(Dataset):
442
+ def __init__(self, tokenizer):
443
+ self.tokenizer = tokenizer
444
+ self.pad_id = self.tokenizer.token_to_id("<pad>")
445
+ self.eos_id = self.tokenizer.token_to_id("</s>")
446
+ self.bos_id = self.tokenizer.token_to_id("<s>")
447
+ self.user_id = self.tokenizer.token_to_id("<user>")
448
+ self.assistant_id = self.tokenizer.token_to_id("<assistant>")
449
+ self.max_length = CONFIG["max_seq_len"]
450
+ # Reuse cleaning function from TokenizerTrainer instance
451
+ self._clean_text = TokenizerTrainer()._clean_text
452
+
453
+ self.all_processed_conversations = []
454
+
455
+ # --- Process DailyDialog ---
456
+ if "daily_dialog" in CONFIG["datasets"]:
457
+ logging.info("Loading and processing daily_dialog dataset...")
458
+ try:
459
+ dd_dataset = load_dataset("daily_dialog", split="train", trust_remote_code=True) # Add trust_remote_code=True
460
+ logging.info(f"Processing {len(dd_dataset)} daily_dialog conversations...")
461
+ for entry in dd_dataset:
462
+ conversation = []
463
+ dialogue = entry['dialog'][:CONFIG["max_turns"]]
464
+ if not dialogue: continue
465
+ for i, utterance in enumerate(dialogue):
466
+ role = "<user>" if i % 2 == 0 else "<assistant>"
467
+ cleaned_text = self._clean_text(utterance)
468
+ if cleaned_text:
469
+ conversation.append({'role': role, 'text': cleaned_text})
470
+ if conversation:
471
+ self.all_processed_conversations.append(conversation)
472
+ except Exception as e:
473
+ logging.error(f"Failed to load or process daily_dialog for training: {e}")
474
+
475
+ # --- Process EmpatheticDialogues ---
476
+ if "empathetic_dialogues" in CONFIG["datasets"]:
477
+ logging.info("Loading and processing empathetic_dialogues dataset...")
478
+ try:
479
+ ed_dataset = load_dataset("empathetic_dialogues", split="train", trust_remote_code=True) # Add trust_remote_code=True
480
+ logging.info("Grouping empathetic_dialogues by conversation ID...")
481
+ conversations_grouped = defaultdict(list)
482
+ for entry in ed_dataset:
483
+ conversations_grouped[entry['conv_id']].append(entry)
484
+
485
+ logging.info(f"Processing {len(conversations_grouped)} empathetic_dialogues conversations...")
486
+ for conv_id, entries in conversations_grouped.items():
487
+ conversation = []
488
+ sorted_entries = sorted(entries, key=lambda x: x['utterance_idx'])
489
+ # Handle context as first user turn if present
490
+ if sorted_entries[0]['context']:
491
+ context_text = self._clean_text(sorted_entries[0]['context'])
492
+ if context_text:
493
+ conversation.append({'role': '<user>', 'text': context_text})
494
+ # Process utterances, assuming alternation
495
+ last_role = conversation[-1]['role'] if conversation else None # Role of the last added turn
496
+ for entry in sorted_entries:
497
+ text = self._clean_text(entry['utterance'])
498
+ if not text: continue
499
+ # Determine role based on the *last added* role
500
+ current_role = '<assistant>' if last_role == '<user>' else '<user>'
501
+ conversation.append({'role': current_role, 'text': text})
502
+ last_role = current_role # Update for next iteration
503
+
504
+ # Apply max turns limit *after* forming the full sequence
505
+ conversation = conversation[:CONFIG["max_turns"]]
506
+ if conversation:
507
+ self.all_processed_conversations.append(conversation)
508
+
509
+ except Exception as e:
510
+ logging.error(f"Failed to load or process empathetic_dialogues for training: {e}")
511
+
512
+ # --- Process BlendedSkillTalk ---
513
+ if "blended_skill_talk" in CONFIG["datasets"]:
514
+ logging.info("Loading and processing blended_skill_talk dataset...")
515
+ try:
516
+ bst_dataset = load_dataset("blended_skill_talk", split="train", trust_remote_code=True) # Add trust_remote_code=True
517
+ logging.info(f"Processing {len(bst_dataset)} blended_skill_talk conversations...")
518
+ for entry in bst_dataset:
519
+ conversation = []
520
+ # Reconstruct dialogue: history + final two turns (if they exist)
521
+ dialogue_turns_raw = entry['previous_utterance']
522
+ if entry.get('free_turker_utterance'):
523
+ dialogue_turns_raw.append(entry['free_turker_utterance'])
524
+ if entry.get('guided_turker_utterance'):
525
+ dialogue_turns_raw.append(entry['guided_turker_utterance'])
526
+
527
+ if not dialogue_turns_raw: continue # Skip if no turns found
528
+
529
+ turns_to_process = dialogue_turns_raw[:CONFIG["max_turns"]] # Apply max turns limit
530
+
531
+ for i, utterance in enumerate(turns_to_process):
532
+ role = "<user>" if i % 2 == 0 else "<assistant>" # Assume simple alternation
533
+ cleaned_text = self._clean_text(utterance)
534
+ if cleaned_text:
535
+ conversation.append({'role': role, 'text': cleaned_text})
536
+ if conversation: # Only add if not empty after cleaning/truncation
537
+ self.all_processed_conversations.append(conversation)
538
+ except Exception as e:
539
+ logging.error(f"Failed to load or process blended_skill_talk for training: {e}")
540
+
541
+ # --- Process PersonaChat ---
542
+ if "AlekseyKorshuk/persona-chat" in CONFIG["datasets"]: # Correct dataset identifier
543
+ pc_dataset_name = "AlekseyKorshuk/persona-chat"
544
+ logging.info(f"Loading and processing {pc_dataset_name} dataset...")
545
+ try:
546
+ pc_dataset = load_dataset(pc_dataset_name, split="train", trust_remote_code=True) # Add trust_remote_code=True, Correct dataset identifier
547
+ logging.info(f"Processing {len(pc_dataset)} {pc_dataset_name} conversations...")
548
+ for entry in pc_dataset:
549
+ conversation = []
550
+ if 'utterances' in entry and entry['utterances']:
551
+ # Extract the dialogue history
552
+ history = entry['utterances'][-1]['history']
553
+ history = history[:CONFIG["max_turns"]] # Apply max turns limit
554
+
555
+ for i, utterance in enumerate(history):
556
+ role = "<user>" if i % 2 == 0 else "<assistant>" # Simple alternation
557
+ cleaned_text = self._clean_text(utterance)
558
+ if cleaned_text:
559
+ conversation.append({'role': role, 'text': cleaned_text})
560
+
561
+ if conversation: # Only add if not empty
562
+ self.all_processed_conversations.append(conversation)
563
+ else:
564
+ logging.warning(f"Skipping {pc_dataset_name} entry due to unexpected structure: {entry.keys()}")
565
+
566
+ except Exception as e:
567
+ logging.error(f"Failed to load or process {pc_dataset_name} for training: {e}")
568
+
569
+
570
+ logging.info(f"Total processed conversations from all datasets: {len(self.all_processed_conversations)}")
571
+ if not self.all_processed_conversations:
572
+ raise ValueError("No processed conversations were created from any dataset. Check loading logic and dataset availability.")
573
+
574
+ logging.info("Shuffling combined dataset...")
575
+ random.shuffle(self.all_processed_conversations)
576
+
577
+
578
+ def __len__(self):
579
+ return len(self.all_processed_conversations)
580
+
581
+ def __getitem__(self, idx):
582
+ conversation = self.all_processed_conversations[idx]
583
+ formatted_ids = [self.bos_id]
584
+ for turn in conversation:
585
+ role_id = self.user_id if turn['role'] == '<user>' else self.assistant_id
586
+ # Encode without adding special tokens automatically by tokenizer
587
+ try:
588
+ utterance_ids = self.tokenizer.encode(turn['text'], add_special_tokens=False).ids
589
+ except Exception as e:
590
+ logging.error(f"Error encoding text at index {idx}, turn '{turn}': {e}")
591
+ utterance_ids = [] # Skip this utterance on error
592
+
593
+ # Check length: Current + Role + Utterance + EOS <= MaxLength
594
+ # Need +1 for role, +len(utterance), +1 for potential EOS
595
+ if len(formatted_ids) + 1 + len(utterance_ids) + 1 > self.max_length:
596
+ # Attempt to add just the role and EOS if utterance is too long
597
+ if len(formatted_ids) + 1 + 1 <= self.max_length:
598
+ formatted_ids.append(role_id)
599
+ formatted_ids.append(self.eos_id)
600
+ break # Stop adding turns
601
+
602
+ formatted_ids.append(role_id)
603
+ formatted_ids.extend(utterance_ids)
604
+ formatted_ids.append(self.eos_id)
605
+
606
+ # Final safety truncate (should be rare if logic above is correct)
607
+ if len(formatted_ids) > self.max_length:
608
+ formatted_ids = formatted_ids[:self.max_length]
609
+ # Ensure last token isn't partial (though unlikely with BPE)
610
+ # If the truncated sequence ends with a role ID, it's probably bad, remove it.
611
+ if formatted_ids and (formatted_ids[-1] == self.user_id or formatted_ids[-1] == self.assistant_id):
612
+ formatted_ids.pop()
613
+ # If after popping the role ID, it's still too long (unlikely), truncate again
614
+ if len(formatted_ids) > self.max_length:
615
+ formatted_ids = formatted_ids[:self.max_length]
616
+
617
+
618
+ # Handle case of extremely short sequences after processing
619
+ if len(formatted_ids) < 2: # Need at least BOS and one other token for input/label pair
620
+ logging.warning(f"Sequence at index {idx} is too short after processing (<2 tokens). Skipping. Original length: {len(conversation)}")
621
+ # Return None to be filtered by collate_fn
622
+ return None
623
+
624
+ input_ids = formatted_ids[:-1]
625
+ labels = formatted_ids[1:]
626
+
627
+ # Final check before returning
628
+ if len(input_ids) == 0:
629
+ logging.warning(f"Sequence at index {idx} resulted in empty input_ids after slicing. Skipping.")
630
+ return None
631
+
632
+
633
+ return {"input_ids": input_ids, "labels": labels}
634
+
635
+ @staticmethod
636
+ def collate_fn(batch):
637
+ # Filter out None items from __getitem__
638
+ batch = [item for item in batch if item is not None]
639
+ if not batch:
640
+ return None # Return None if the whole batch was invalid
641
+
642
+ max_len = max(len(item["input_ids"]) for item in batch)
643
+
644
+ # Load tokenizer once to get pad_id - ensure path matches CONFIG
645
+ try:
646
+ # Correctly reference the tokenizer path from CONFIG within the static method
647
+ tokenizer_path = os.path.join("tokenizer", CONFIG["tokenizer_name"])
648
+ # TODO: Consider passing tokenizer/pad_id if this becomes a bottleneck
649
+ tokenizer = Tokenizer.from_file(tokenizer_path)
650
+ pad_id = tokenizer.token_to_id("<pad>")
651
+ if pad_id is None: raise ValueError("<pad> token not found")
652
+ except Exception as e:
653
+ logging.error(f"Collate Error: Failed to load tokenizer or get pad_id ('{CONFIG['tokenizer_name']}'): {e}")
654
+ pad_id = 0 # Risky fallback
655
+
656
+ inputs, labels, masks = [], [], []
657
+ for item in batch:
658
+ input_len = len(item["input_ids"])
659
+ pad_len = max_len - input_len
660
+ inputs.append(item["input_ids"] + [pad_id] * pad_len)
661
+ # Pad labels with pad_id (or any ID to be ignored by CrossEntropyLoss)
662
+ labels.append(item["labels"] + [pad_id] * pad_len)
663
+ masks.append([1] * input_len + [0] * pad_len)
664
+
665
+ return {
666
+ "input_ids": torch.tensor(inputs, dtype=torch.long),
667
+ "labels": torch.tensor(labels, dtype=torch.long),
668
+ "attention_mask": torch.tensor(masks, dtype=torch.long) # Or bool
669
+ }
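+ # Padding example for collate_fn: a batch with sequences of length 5 and 3 is padded to
+ # max_len 5, the shorter item gets two <pad> ids appended to both input_ids and labels,
+ # and its attention_mask becomes [1, 1, 1, 0, 0].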
670
+
671
+ # --- Trainer, Safety Manager, Checkpoint Manager ---
672
+
673
+ class HROMTrainer:
674
+ def __init__(self, model, tokenizer):
675
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
676
+ logging.info(f"Using device: {self.device}")
677
+ self.model = model.to(self.device)
678
+
679
+ self.use_amp = (self.device.type == "cuda" and hasattr(torch.cuda.amp, "GradScaler"))
680
+ self.scaler = torch.cuda.amp.GradScaler() if self.use_amp else None
681
+ logging.info(f"Automatic Mixed Precision (AMP): {'Enabled' if self.use_amp else 'Disabled'}")
682
+
683
+ self.optimizer = torch.optim.AdamW(
684
+ self.model.parameters(),
685
+ lr=CONFIG["learning_rate"], # Base LR
686
+ betas=(0.9, 0.95),
687
+ weight_decay=0.1,
688
+ fused= (self.device.type == "cuda")
689
+ )
690
+ self.tokenizer = tokenizer
691
+ self.pad_id = self.tokenizer.token_to_id("<pad>")
692
+ if self.pad_id is None:
693
+ # Attempt to get from config if available or fallback
694
+ self.pad_id = CONFIG.get("pad_token_id", 0)
695
+ logging.warning(f"<pad> token ID not found in tokenizer, using fallback ID: {self.pad_id}")
696
+
697
+
698
+ # Make sure ignore_index uses the determined pad_id
699
+ self.criterion = nn.CrossEntropyLoss(ignore_index=self.pad_id)
700
+ self.base_lr = CONFIG["learning_rate"]
701
+ self.warmup_steps = CONFIG["warmup_steps"]
702
+
703
+ def _adjust_learning_rate(self, step):
704
+ if self.warmup_steps > 0 and step < self.warmup_steps:
705
+ lr = self.base_lr * (step + 1) / self.warmup_steps
706
+ else:
707
+ # Optional: Add LR decay (e.g., cosine) after warmup
708
+ # Example: lr = self.base_lr * 0.5 * (1 + math.cos(math.pi * (step - self.warmup_steps) / (total_steps - self.warmup_steps)))
709
+ lr = self.base_lr # Keep base LR after warmup for now
710
+ for param_group in self.optimizer.param_groups:
711
+ param_group['lr'] = lr
712
+ return lr
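+ # Warmup example with the defaults above (learning_rate=2e-5, warmup_steps=1000):
+ # optimizer step 0 -> lr = 2e-8, step 499 -> 1e-5, step 999 -> 2e-5, after which lr stays at 2e-5.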
713
+
714
+ def train_step(self, batch):
715
+ # Determine precision for autocast
716
+ if self.use_amp:
+ amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+ autocast_context = torch.cuda.amp.autocast(dtype=amp_dtype)
+ else:
+ autocast_context = nullcontext()
719
+
720
+ with autocast_context:
721
+ input_ids = batch["input_ids"].to(self.device)
722
+ attention_mask = batch["attention_mask"].to(self.device)
723
+ labels = batch["labels"].to(self.device)
724
+
725
+ outputs = self.model(input_ids, attention_mask=attention_mask)
726
+
727
+ # Reshape for loss calculation
728
+ logits_flat = outputs.view(-1, outputs.size(-1)) # Shape: (B * T, vocab_size)
729
+ labels_flat = labels.view(-1) # Shape: (B * T)
730
+
731
+ # Calculate loss - ensure logits are float32 for stability esp. with AMP
732
+ loss = self.criterion(logits_flat.float(), labels_flat)
733
+
734
+ # Scale loss for gradient accumulation
735
+ scaled_loss = loss / CONFIG["grad_accum_steps"]
736
+
737
+ # Backward pass
738
+ if self.use_amp and self.scaler:
739
+ self.scaler.scale(scaled_loss).backward()
740
+ else:
741
+ scaled_loss.backward()
742
+
743
+ return loss.item() # Return the unscaled loss for logging
744
+
745
+ def clip_and_step(self, current_optimizer_step):
746
+ current_lr = self._adjust_learning_rate(current_optimizer_step)
747
+ # Gradient Clipping *before* optimizer step
748
+ if self.use_amp and self.scaler:
749
+ # Unscale first - important before clipping
750
+ self.scaler.unscale_(self.optimizer)
751
+ # Clip grad norm
752
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
753
+ # Optimizer step (with scaler)
754
+ self.scaler.step(self.optimizer)
755
+ # Update scaler for next iteration
756
+ self.scaler.update()
757
+ else:
758
+ # Clip grad norm
759
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
760
+ # Optimizer step
761
+ self.optimizer.step()
762
+
763
+ # Zero gradients *after* stepping
764
+ self.optimizer.zero_grad(set_to_none=True)
765
+ return current_lr
766
+
767
+
768
+ class SafetyManager:
769
+ # (No changes needed in SafetyManager implementation itself)
770
+ def __init__(self, model, tokenizer):
771
+ self.model = model
772
+ self.tokenizer = tokenizer
773
+ # More conservative list
774
+ self.bad_words = ["kill", "murder", "suicide", "hate", "abuse", "violence", "illegal", "harm", "die", "attack", "rape", "molest", "exploit", "terror"]
775
+ self.bad_word_ids = []
776
+ logging.info("Initializing safety manager...")
777
+ # Pre-encode bad word sequences
778
+ for word in self.bad_words:
779
+ # Encode potentially multi-token words carefully
780
+ ids = tokenizer.encode(f" {word}", add_special_tokens=False).ids # Add prefix space for BPE
781
+ if ids:
782
+ self.bad_word_ids.append(ids)
783
+ logging.debug(f"Encoded bad word '{word}' (with space) to IDs: {ids}")
784
+ # Try without space too
785
+ ids_no_space = tokenizer.encode(word, add_special_tokens=False).ids
786
+ if ids_no_space and ids_no_space != ids:
787
+ self.bad_word_ids.append(ids_no_space)
788
+ logging.debug(f"Encoded bad word '{word}' (no space) to IDs: {ids_no_space}")
789
+
790
+ if not ids and not ids_no_space:
791
+ logging.warning(f"Could not encode bad word '{word}' - skipping.")
792
+
793
+ # Pre-get special IDs
794
+ self.eos_id = self.tokenizer.token_to_id("</s>")
795
+ self.bos_id = self.tokenizer.token_to_id("<s>")
796
+ self.user_id = self.tokenizer.token_to_id("<user>")
797
+ self.assistant_id = self.tokenizer.token_to_id("<assistant>")
798
+ self.pad_id = self.tokenizer.token_to_id("<pad>")
799
+
800
+ if self.eos_id is None: logging.error("</s> token ID not found for SafetyManager!"); self.eos_id = 0
801
+ if self.bos_id is None: logging.error("<s> token ID not found for SafetyManager!"); self.bos_id = 0
802
+ if self.user_id is None: logging.error("<user> token ID not found for SafetyManager!")
803
+ if self.assistant_id is None: logging.error("<assistant> token ID not found for SafetyManager!")
804
+ if self.pad_id is None: logging.error("<pad> token ID not found for SafetyManager!"); self.pad_id = 0
805
+
806
+
807
+ def contains_sequence(self, tokens, seq):
808
+ """Checks if the list `tokens` contains the sublist `seq`."""
809
+ if not seq or not tokens or len(tokens) < len(seq):
810
+ return False
811
+ seq_len = len(seq)
812
+ for i in range(len(tokens) - seq_len + 1):
813
+ if tokens[i : i + seq_len] == seq:
814
+ return True
815
+ return False
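+ # e.g. contains_sequence([4, 17, 9, 2], [17, 9]) -> True; contains_sequence([4, 17], [17, 9]) -> False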
816
+
817
+ def content_filter(self, text_ids):
818
+ """Checks if a list of token IDs contains any bad word sequences."""
819
+ if not isinstance(text_ids, list):
820
+ logging.warning("Content filter received non-list input.")
821
+ return True # Default to safe if input is weird
822
+ for bad_ids in self.bad_word_ids:
823
+ if self.contains_sequence(text_ids, bad_ids):
824
+ # Log the detected sequence for debugging
825
+ detected_word = self.tokenizer.decode(bad_ids)
826
+ logging.warning(f"Unsafe content detected: Found sequence corresponding to '{detected_word}' (IDs: {bad_ids}).")
827
+ return False # Unsafe
828
+ return True # Safe
829
+
830
+ def generate_safely(self, prompt, max_new_tokens=50, temperature=0.7, top_k=50):
831
+ self.model.eval()
832
+ device = next(self.model.parameters()).device
833
+
834
+ # Encode prompt, ensure it ends appropriately (e.g., with role token + EOS?)
835
+ # Let's assume the prompt ends like "<user> blah blah </s>" and we need to add "<assistant>"
836
+ prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False).ids
837
+
838
+ # Start generation sequence with BOS, prompt, and assistant token
839
+ # Ensure prompt doesn't already include BOS
840
+ if prompt_ids and prompt_ids[0] == self.bos_id:
841
+ input_ids = list(prompt_ids)
842
+ else:
843
+ input_ids = [self.bos_id] + list(prompt_ids)
844
+
845
+ # Add the assistant token to signal the model to generate the response
846
+ if self.assistant_id is not None:
847
+ input_ids.append(self.assistant_id)
848
+ else:
849
+ logging.error("Assistant token ID is None, cannot properly start generation.")
850
+ return "Error: Assistant token not found."
851
+
852
+
853
+ generated_ids = list(input_ids) # Start with the prepared input sequence
854
+ logging.debug(f"Starting safe generation with initial IDs: {generated_ids}")
855
+
856
+ with torch.no_grad():
857
+ for step in range(max_new_tokens):
858
+ # Prepare input tensor for this step - only use up to max_seq_len
859
+ current_input_ids = generated_ids[-CONFIG["max_seq_len"]:]
860
+ current_input_tensor = torch.tensor([current_input_ids]).to(device)
861
+ # Create attention mask for the current length
862
+ attention_mask = torch.ones_like(current_input_tensor)
863
+
864
+ # Model forward pass
865
+ try:
866
+ outputs = self.model(current_input_tensor, attention_mask=attention_mask)
867
+ next_token_logits = outputs[:, -1, :] # Logits for the next token
868
+ except Exception as e:
869
+ logging.error(f"Model forward pass failed during generation: {e}")
870
+ break # Stop generation on error
871
+
872
+ # --- Safety Check BEFORE sampling ---
873
+ # Apply penalties to bad word starting tokens if possible
874
+ # For now, we filter *after* sampling the token
875
+
876
+ # Sampling (Temperature, Top-K)
877
+ if temperature > 0 and temperature != 1.0:
878
+ next_token_logits = next_token_logits / temperature
879
+ if top_k > 0 and top_k < next_token_logits.size(-1): # Ensure top_k is valid
880
+ v, _ = torch.topk(next_token_logits, top_k)
881
+ # Handle potential NaN/Inf in logits before comparison
882
+ safe_logits = torch.nan_to_num(next_token_logits, nan=-float('inf'), posinf=float('inf'), neginf=-float('inf'))
883
+ threshold = v[:, [-1]]
884
+ safe_logits[safe_logits < threshold] = -float('Inf')
885
+ next_token_logits = safe_logits # Use the filtered logits
886
+
887
+ probs = torch.softmax(next_token_logits, dim=-1)
888
+ # Handle potential NaNs in probabilities before sampling
889
+ if torch.isnan(probs).any():
890
+ logging.warning("NaN detected in probabilities before sampling. Replacing with uniform distribution.")
891
+ probs = torch.ones_like(probs) / probs.size(-1) # Fallback to uniform
892
+
893
+ next_token_id = torch.multinomial(probs, num_samples=1).item()
894
+
895
+ # --- Safety Check AFTER sampling token ---
896
+ # Check if adding this token creates a bad sequence
897
+ potential_sequence_ids = generated_ids + [next_token_id]
898
+ # Check only the newly formed part for bad words for efficiency?
899
+ # Let's check the whole sequence for simplicity/robustness for now.
900
+ if not self.content_filter(potential_sequence_ids):
901
+ logging.warning(f"Potential unsafe token ({next_token_id}, '{self.tokenizer.decode([next_token_id])}') blocked POST-sampling. Stopping generation.")
902
+ # Optionally try sampling a different token? For now, just stop.
903
+ break
904
+
905
+ # Add the safe token
906
+ generated_ids.append(next_token_id)
907
+
908
+ # Check for EOS token
909
+ if next_token_id == self.eos_id:
910
+ logging.debug(f"EOS token generated at step {step+1}. Stopping generation.")
911
+ break
912
+
913
+ # Prevent infinite loops if max tokens reached
914
+ if step == max_new_tokens - 1:
915
+ logging.debug("Max new tokens reached. Stopping generation.")
916
+ # Ensure the sequence ends with EOS if it didn't naturally
917
+ if generated_ids[-1] != self.eos_id and self.eos_id is not None:
918
+ generated_ids.append(self.eos_id)
919
+
920
+ self.model.train() # Set model back to training mode
921
+
922
+ # Decode the generated part (excluding the initial prompt + assistant token)
923
+ start_index = len(input_ids)
924
+ response_ids = generated_ids[start_index:]
925
+
926
+ # Decode, skipping special tokens like EOS, BOS, PAD but potentially keeping USER/ASSISTANT
927
+ # Let's skip all special tokens for the final output text for clarity.
928
+ decoded_text = self.tokenizer.decode(response_ids, skip_special_tokens=True).strip()
929
+
930
+ return decoded_text
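+ # Usage sketch (assumes a trained model and the combined tokenizer are already in scope):
+ #   safety = SafetyManager(model, tokenizer)
+ #   reply = safety.generate_safely("<user> How was your day? </s>", max_new_tokens=40)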
931
+
932
+
933
+ def debug_generation(self, prompt="<user> Tell me about your hobbies."): # Example prompt
934
+ logging.info(f"\n--- Debug Generation & Safety Check ---")
935
+ # Ensure prompt ends logically for the model (e.g., with user token and EOS)
936
+ if not prompt.strip().endswith("</s>"):
+ prompt = prompt.strip() + " </s>" # Append EOS whenever the prompt does not already end with it
941
+
942
+ # Ensure the prompt starts appropriately (e.g., no BOS needed here as generate_safely adds it)
943
+ if prompt.startswith("<s>"):
944
+ prompt = prompt[len("<s>"):].strip()
945
+
946
+
947
+ generated_response = self.generate_safely(prompt, max_new_tokens=60, temperature=0.7, top_k=50)
948
+
949
+ logging.info(f"Prompt Sent: '{prompt}'")
950
+ logging.info(f"Generated Response: '{generated_response}'")
951
+ logging.info("\n--- End Debug Generation ---\n")
952
+
953
+
954
+ class CheckpointManager:
955
+ def __init__(self):
956
+ # Use checkpoint directory from CONFIG
957
+ self.checkpoint_dir = CONFIG["checkpoint_dir"]
958
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
959
+ logging.info(f"Checkpoint directory set to: {self.checkpoint_dir}")
960
+
961
+ def save(self, model, optimizer, step):
962
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
963
+ # Use a consistent naming scheme based on the directory name if desired
964
+ prefix = os.path.basename(self.checkpoint_dir).replace("checkpoints_", "")
965
+ # Ensure step is converted to string if it's passed as something else (e.g., 'final')
966
+ step_str = str(step)
967
+ filename = f"hrom_{prefix}_step{step_str}_{timestamp}.pt"
968
+ path = os.path.join(self.checkpoint_dir, filename)
969
+ state = {
970
+ "model": model.state_dict(),
971
+ "optimizer": optimizer.state_dict(),
972
+ "step": step if isinstance(step, int) else -1, # Store step number or -1 for non-numeric steps
973
+ "config": CONFIG # Save config with checkpoint
974
+ }
975
+ logging.info(f"Saving checkpoint to {path}...")
976
+ try:
977
+ torch.save(state, path)
978
+ logging.info(f"Checkpoint saved successfully at step {step_str}.")
979
+ self._cleanup_old_checkpoints()
980
+ except Exception as e:
981
+ logging.error(f"Failed to save checkpoint '{path}': {e}")
982
+
983
+ def _cleanup_old_checkpoints(self):
984
+ max_checkpoints = CONFIG.get("max_checkpoints", 5) # Get from config, default 5
985
+ if max_checkpoints <= 0:
986
+ return # Keep all checkpoints if max_checkpoints is non-positive
987
+
988
+ try:
989
+ # Filter only files matching the expected pattern (avoid deleting other files)
990
+ prefix = os.path.basename(self.checkpoint_dir).replace("checkpoints_", "")
991
+ pattern = re.compile(rf"hrom_{prefix}_step(\d+|.+)_(\d{{8}}_\d{{6}})\.pt")
992
+
993
+ checkpoints = []
994
+ for f in os.listdir(self.checkpoint_dir):
995
+ match = pattern.match(f)
996
+ if match:
997
+ filepath = os.path.join(self.checkpoint_dir, f)
998
+ checkpoints.append((filepath, os.path.getmtime(filepath)))
999
+
1000
+ # Sort by modification time (oldest first)
1001
+ checkpoints.sort(key=lambda x: x[1])
1002
+
1003
+ num_to_delete = len(checkpoints) - max_checkpoints
1004
+ if num_to_delete > 0:
1005
+ #logging.info(f"Max checkpoints ({max_checkpoints}) reached. Removing {num_to_delete} oldest checkpoints.")
1006
+ for i in range(num_to_delete):
1007
+ file_to_remove, _ = checkpoints[i]
1008
+ try:
1009
+ os.remove(file_to_remove)
1010
+ #logging.info(f"Removed old checkpoint: {os.path.basename(file_to_remove)}")
1011
+ except OSError as e:
1012
+ logging.error(f"Error removing checkpoint {file_to_remove}: {e}")
1013
+ except Exception as e:
1014
+ logging.error(f"Error during checkpoint cleanup: {e}")
1015
+
1016
+
1017
+ def load_latest(self, model, optimizer):
1018
+ try:
1019
+ # Filter files based on pattern and sort by time
1020
+ prefix = os.path.basename(self.checkpoint_dir).replace("checkpoints_", "")
1021
+ pattern = re.compile(rf"hrom_{prefix}_step(\d+|.+)_(\d{{8}}_\d{{6}})\.pt")
1022
+ checkpoints = []
1023
+ for f in os.listdir(self.checkpoint_dir):
1024
+ match = pattern.match(f)
1025
+ if match:
1026
+ filepath = os.path.join(self.checkpoint_dir, f)
1027
+ checkpoints.append((filepath, os.path.getmtime(filepath)))
1028
+
1029
+ if not checkpoints:
1030
+ logging.info("No valid checkpoints found to load.")
1031
+ return 0 # Start from step 0
1032
+
1033
+ # Sort by modification time (newest first)
1034
+ checkpoints.sort(key=lambda x: x[1], reverse=True)
1035
+
1036
+ latest_checkpoint_path, _ = checkpoints[0]
1037
+ logging.info(f"Loading latest checkpoint from: {latest_checkpoint_path}")
1038
+ map_location = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
1039
+ checkpoint = torch.load(latest_checkpoint_path, map_location=map_location)
1040
+
1041
+ # --- Config Compatibility Check (Optional but Recommended) ---
1042
+ loaded_config = checkpoint.get("config", {})
1043
+ # Compare key parameters that affect model architecture or data processing
1044
+ critical_keys = ["dim", "n_layers", "n_heads", "ff_dim", "vocab_size", "max_seq_len", "tokenizer_name"]
1045
+ mismatched_keys = []
1046
+ if loaded_config:
1047
+ for key in critical_keys:
1048
+ # Check if key exists in both and if they differ
1049
+ if key in loaded_config and key in CONFIG and loaded_config[key] != CONFIG[key]:
1050
+ mismatched_keys.append((key, loaded_config[key], CONFIG[key]))
1051
+ # Check if key missing in current config but present in checkpoint
1052
+ elif key in loaded_config and key not in CONFIG:
1053
+ mismatched_keys.append((key, loaded_config[key], "Not in current CONFIG"))
1054
+ # Check if key missing in checkpoint config but present in current
1055
+ elif key not in loaded_config and key in CONFIG:
1056
+ mismatched_keys.append((key, "Not in loaded CONFIG", CONFIG[key]))
1057
+
1058
+
1059
+ if mismatched_keys:
1060
+ logging.warning("--- CONFIG MISMATCH DETECTED ---")
1061
+ logging.warning(f"Checkpoint '{os.path.basename(latest_checkpoint_path)}' was saved with different critical parameters:")
1062
+ for key, loaded_val, current_val in mismatched_keys:
1063
+ logging.warning(f" - {key}: Checkpoint='{loaded_val}', Current='{current_val}'")
1064
+ # Decide whether to proceed: raise error, warn, or try anyway
1065
+ # For now, just warn strongly. Loading might fail or lead to issues.
1066
+ logging.warning("Proceeding with loading, but results may be unexpected or errors may occur.")
1067
+ else:
1068
+ logging.warning("Checkpoint does not contain configuration info. Cannot check compatibility.")
1069
+ # --- End Config Check ---
1070
+
1071
+
1072
+ try:
1073
+ # Strict=False can sometimes help load partially, but hides potential issues
1074
+ model.load_state_dict(checkpoint['model'], strict=True)
1075
+ except RuntimeError as e:
1076
+ logging.error(f"Failed to load model state_dict: {e}")
1077
+ logging.error("This often happens due to architecture mismatch (check CONFIG) or corrupted checkpoint.")
1078
+ logging.error("Starting training from scratch.")
1079
+ return 0 # Cannot resume if model loading fails
1080
+
1081
+ try:
1082
+ optimizer.load_state_dict(checkpoint['optimizer'])
1083
+ except ValueError as e:
1084
+ logging.warning(f"Could not load optimizer state_dict: {e}. Optimizer state will be reset.")
1085
+ # Reinitialize optimizer if state doesn't match? Or just proceed with current state.
1086
+ # Resetting optimizer state is safer if parameters changed.
1087
+ optimizer.state = defaultdict(dict) # Reset state
1088
+ logging.warning("Optimizer state reset.")
1089
+ except Exception as e:
1090
+ logging.error(f"Unexpected error loading optimizer state: {e}. Starting training from scratch.")
1091
+ return 0
1092
+
1093
+ start_step = checkpoint.get('step', 0)
1094
+ # Ensure step is non-negative, resume from next step
1095
+ start_step = (max(0, start_step) + 1) if isinstance(start_step, int) else 0
1096
+
1097
+
1098
+ logging.info(f"Checkpoint loaded successfully. Resuming from optimizer step {start_step}.")
1099
+ # Move optimizer state tensors to the correct device
1100
+ for state in optimizer.state.values():
1101
+ for k, v in state.items():
1102
+ if isinstance(v, torch.Tensor):
1103
+ try:
1104
+ state[k] = v.to(map_location)
1105
+ except Exception as e:
1106
+ logging.error(f"Failed to move optimizer tensor '{k}' to device '{map_location}': {e}")
1107
+ return start_step
1108
+
1109
+ except FileNotFoundError:
1110
+ logging.info(f"No checkpoint directory '{self.checkpoint_dir}' or files found. Starting training from scratch.")
1111
+ return 0
1112
+ except Exception as e:
1113
+ logging.error(f"Error loading checkpoint from '{self.checkpoint_dir}': {e}. Starting training from scratch.")
1114
+ # Clean up potentially partially loaded model/optimizer?
1115
+ # Re-initializing might be safer depending on where the error occurred.
1116
+ # For simplicity, we just return 0 here.
1117
+ return 0
1118
+
1119
+
1120
+ # --- Training Function ---
1121
+
1122
+ def train():
1123
+ logging.info("Starting HROM training process on combined datasets (daily_dialog, empathetic_dialogues, blended_skill_talk, AlekseyKorshuk/persona-chat)...") # Corrected log message
1124
+ logging.info(f"Configuration: {CONFIG}")
1125
+
1126
+ # --- Tokenizer Setup ---
1127
+ tokenizer_trainer = TokenizerTrainer()
1128
+ tokenizer_path = tokenizer_trainer.tokenizer_path
1129
+ if not os.path.exists(tokenizer_path):
1130
+ logging.info(f"Combined tokenizer '{CONFIG['tokenizer_name']}' not found. Training tokenizer...")
1131
+ try:
1132
+ # Pass trust_remote_code=True to load_dataset calls inside tokenizer training
1133
+ tokenizer_trainer.train(CONFIG["datasets"])
1134
+ except Exception as e:
1135
+ logging.error(f"Failed during tokenizer training: {e}", exc_info=True)
1136
+ return # Cannot proceed without a tokenizer
1137
+ else:
1138
+ logging.info(f"Loading existing combined tokenizer from {tokenizer_path}")
1139
+ # Load the tokenizer instance *once* here for shared use
1140
+ try:
1141
+ tokenizer = tokenizer_trainer.get_tokenizer()
1142
+ # Update CONFIG with actual token IDs (useful for downstream)
1143
+ CONFIG['pad_token_id'] = tokenizer.token_to_id("<pad>")
1144
+ CONFIG['bos_token_id'] = tokenizer.token_to_id("<s>")
1145
+ CONFIG['eos_token_id'] = tokenizer.token_to_id("</s>")
1146
+ logging.info(f"Loaded tokenizer. Vocab size: {tokenizer.get_vocab_size()}. Special IDs: PAD={CONFIG['pad_token_id']}, BOS={CONFIG['bos_token_id']}, EOS={CONFIG['eos_token_id']}")
1147
+ except (FileNotFoundError, ValueError) as e:
1148
+ logging.error(f"Failed to load tokenizer: {e}. Cannot continue.")
1149
+ return
1150
+
1151
+ # --- Model Initialization ---
1152
+ logging.info("Initializing HROM model...")
1153
+ # Ensure vocab_size in config matches tokenizer
1154
+ if CONFIG['vocab_size'] != tokenizer.get_vocab_size():
1155
+ logging.warning(f"Config vocab_size ({CONFIG['vocab_size']}) differs from tokenizer vocab size ({tokenizer.get_vocab_size()}). Using tokenizer's size.")
1156
+ CONFIG['vocab_size'] = tokenizer.get_vocab_size()
1157
+ model = HROM()
1158
+
1159
+ # --- Calculate and Log Model Parameters ---
1160
+ total_params = sum(p.numel() for p in model.parameters())
1161
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
1162
+ logging.info(f"Model initialized. Total parameters: {total_params:,}")
1163
+ logging.info(f"Trainable parameters: {trainable_params:,}")
1164
+ logging.info(f"Parameters (Millions): Total={total_params/1e6:.2f}M, Trainable={trainable_params/1e6:.2f}M")
1165
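+
+ # Optional diagnostic (editor sketch, commented out): a per-module parameter breakdown
+ # via the standard PyTorch API, should the totals above need to be attributed to the
+ # embedding, the transformer blocks and the output head:
+ #   from collections import Counter
+ #   per_module = Counter()
+ #   for name, p in model.named_parameters():
+ #       per_module[name.split('.')[0]] += p.numel()
+ #   for module_name, count in per_module.items():
+ #       logging.info(f"  {module_name}: {count/1e6:.2f}M parameters")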
+
1166
+
1167
+ # --- Dataset and DataLoader ---
1168
+ logging.info("Setting up combined dataset and dataloader...")
1169
+ try:
1170
+ logging.info("Pre-loading/caching datasets...")
1171
+ for ds_name in CONFIG["datasets"]:
1172
+ logging.info(f"Checking cache for '{ds_name}'...")
1173
+ try:
1174
+ # Load just the first example to trigger download/cache check
1175
+ _ = load_dataset(ds_name, split="train[:1]", download_mode="reuse_cache_if_exists", trust_remote_code=True) # Add trust_remote_code
1176
+ except Exception as e:
1177
+ # Log error but try to continue, main dataset loading will handle final error
1178
+ logging.error(f"Could not pre-check dataset '{ds_name}': {e}")
1179
+ logging.info("Dataset download/cache check presumed complete.")
1180
+
1181
+ # Pass the already loaded tokenizer instance
1182
+ dataset = CombinedChatDataset(tokenizer)
1183
+
1184
+ # Check if dataset is empty after processing
1185
+ if len(dataset) == 0:
1186
+ logging.error("Dataset is empty after processing all sources. Cannot train.")
1187
+ return
1188
+
1189
+ dataloader = DataLoader(
1190
+ dataset,
1191
+ batch_size=CONFIG["batch_size"],
1192
+ collate_fn=CombinedChatDataset.collate_fn, # Use static method
1193
+ shuffle=True,
1194
+ # Adjust num_workers based on available cores, be conservative
1195
+ num_workers=min(4, os.cpu_count() // 2 if (os.cpu_count() and os.cpu_count() > 1) else 1),
1196
+ pin_memory=torch.cuda.is_available(),
1197
+ prefetch_factor=2 if torch.cuda.is_available() and os.cpu_count() and os.cpu_count() > 1 else None,
1198
+ drop_last=False # Keep last batch even if smaller
1199
+ )
1200
+ except Exception as e:
1201
+ logging.error(f"Failed to initialize dataset/dataloader: {e}", exc_info=True)
1202
+ return
1203
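+
+ # Editor note (illustrative, assuming CombinedChatDataset.collate_fn keeps the contract of
+ # the original ChatDataset.collate_fn shown further down in this diff): each batch is a dict
+ # of equally padded tensors, e.g. for batch size B and padded length T:
+ #   batch["input_ids"]      -> (B, T), right-padded with the <pad> id
+ #   batch["labels"]         -> (B, T), the input sequence shifted left by one token
+ #   batch["attention_mask"] -> (B, T), 1 for real tokens, 0 for padding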
+
1204
+ # --- Trainer, Checkpoint, Safety ---
1205
+ logging.info("Initializing Trainer, Checkpoint Manager, and Safety Manager...")
1206
+ # Pass the loaded tokenizer instance
1207
+ trainer_obj = HROMTrainer(model, tokenizer)
1208
+ checkpoint_manager = CheckpointManager() # Uses CONFIG["checkpoint_dir"]
1209
+ safety = SafetyManager(model, tokenizer) # Pass the loaded tokenizer instance
1210
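+
+ # Editor note (illustrative, assuming the V1.5 SafetyManager keeps the interface of the
+ # original version shown further down): filtered greedy generation can be exercised as
+ #   reply = safety.generate_safely("<user> Hello!", max_length=50)
+ #   safety.debug_generation("<user> Hello!")
+ # where generation stops at </s> or as soon as the content filter rejects the partial output.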
+
1211
+ # --- Load Checkpoint ---
1212
+ start_optimizer_step = checkpoint_manager.load_latest(model, trainer_obj.optimizer)
1213
+ # Ensure model is on correct device after loading
1214
+ model.to(trainer_obj.device)
1215
+
1216
+ # --- Training Loop ---
1217
+ logging.info(f"Starting training from optimizer step {start_optimizer_step}")
1218
+ optimizer_step = start_optimizer_step
1219
+ total_loss_accum = 0.0
1220
+ # Calculate starting batch step based on loaded optimizer step and grad accum
1221
+ batch_step = optimizer_step * CONFIG["grad_accum_steps"]
1222
+ epochs_completed = batch_step // len(dataloader) if len(dataloader) > 0 else 0
1223
+ start_epoch = epochs_completed # Start from the epoch corresponding to the loaded step
1224
+
1225
+ # Estimate total steps (can be useful for LR scheduling if implementing decay)
1226
+ try:
1227
+ if len(dataloader) == 0:
1228
+ raise ValueError("DataLoader has zero length. Cannot estimate total steps.")
1229
+ total_optimizer_steps = (len(dataloader) * CONFIG["num_epochs"]) // CONFIG["grad_accum_steps"]
1230
+ logging.info(f"Estimated dataset size: {len(dataset)}")
1231
+ logging.info(f"Estimated batches per epoch: {len(dataloader)}")
1232
+ logging.info(f"Gradient Accumulation Steps: {CONFIG['grad_accum_steps']}")
1233
+ logging.info(f"Effective Batch Size: {CONFIG['batch_size'] * CONFIG['grad_accum_steps']}")
1234
+ logging.info(f"Target Epochs: {CONFIG['num_epochs']}")
1235
+ logging.info(f"Estimated total optimizer steps for {CONFIG['num_epochs']} epochs: {total_optimizer_steps}")
1236
+ except Exception as e:
1237
+ logging.warning(f"Could not accurately estimate dataloader length or total steps: {e}")
1238
+ total_optimizer_steps = -1 # Indicate unknown total steps
1239
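+
+ # Worked example (editor note; the numbers are made up, not measured from the real datasets):
+ # with ~100,000 combined samples, batch_size=16, grad_accum_steps=8, num_epochs=30:
+ #   batches per epoch     = 100,000 / 16      = 6,250
+ #   effective batch size  = 16 * 8            = 128
+ #   total optimizer steps = (6,250 * 30) // 8 = 23,437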
+
1240
+
1241
+ model.train() # Ensure model is in training mode
1242
+
1243
+ for epoch in range(start_epoch, CONFIG["num_epochs"]):
1244
+ logging.info(f"--- Starting Epoch {epoch+1}/{CONFIG['num_epochs']} ---")
1245
+ epoch_loss = 0.0
1246
+ num_batches_in_epoch = 0
1247
+
1248
+ # Use enumerate starting from 1 for batch count if preferred
1249
+ for i, batch in enumerate(dataloader):
1250
+ # Check if batch is valid (collate_fn might return None)
1251
+ if batch is None:
1252
+ logging.warning(f"Skipping empty batch at step {i} in epoch {epoch+1}")
1253
+ continue
1254
+
1255
+ # Forward and backward pass (scaled loss)
1256
+ loss = trainer_obj.train_step(batch)
1257
+ if loss is None or math.isnan(loss) or math.isinf(loss):
1258
+ logging.error(f"NaN, Inf, or None loss detected: {loss}. Epoch {epoch+1}, Batch {i}, Opt Step {optimizer_step}. Stopping.")
1259
+ # Try saving an error-tagged checkpoint (suffix "_error") before exiting
1260
+ checkpoint_manager.save(model, trainer_obj.optimizer, f"{optimizer_step}_error")
1261
+ return
1262
+
1263
+ total_loss_accum += loss
1264
+ epoch_loss += loss
1265
+ num_batches_in_epoch += 1
1266
+ batch_step += 1 # Increment global batch counter (tracks batches processed)
1267
+
1268
+ # Gradient Accumulation Check & Optimizer Step
1269
+ # Check if it's time to perform an optimizer step
1270
+ if batch_step % CONFIG["grad_accum_steps"] == 0:
1271
+ current_lr = trainer_obj.clip_and_step(optimizer_step) # Pass current opt step for LR schedule
1272
+
1273
+ # Calculate average loss over accumulation steps for logging
1274
+ avg_loss = total_loss_accum / CONFIG["grad_accum_steps"]
1275
+ total_loss_accum = 0.0 # Reset loss accumulator
1276
+
1277
+ # Logging
1278
+ if optimizer_step % CONFIG["debug_interval"] == 0:
1279
+ logging.info(f"Epoch {epoch+1} | Opt Step {optimizer_step} | Batch Step {batch_step} | Avg Loss: {avg_loss:.4f} | LR: {current_lr:.2e}")
1280
+ # Trigger debug generation less frequently or based on condition
1281
+ if optimizer_step % (CONFIG["debug_interval"] * 5) == 0: # e.g., every 5 debug intervals
1282
+ safety.debug_generation("<user> Hi there! How are you doing today?") # Use a generic debug prompt
1283
+
1284
+ # Checkpointing
1285
+ if optimizer_step > 0 and optimizer_step % CONFIG["checkpoint_interval"] == 0:
1286
+ logging.info(f"Checkpoint interval reached at optimizer step {optimizer_step}.")
1287
+ checkpoint_manager.save(model, trainer_obj.optimizer, optimizer_step)
1288
+ # Optional: Run a generation check after saving checkpoint
1289
+ safety.debug_generation("<user> Hi! How are you?")
1290
+
1291
+ optimizer_step += 1 # Increment optimizer step count *after* performing the step
1292
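+
+ # Editor note (illustrative, using grad_accum_steps=4 as an example value): the counters
+ # relate as follows -- batches 1..4 feed optimizer step 0, batches 5..8 feed step 1, etc.
+ # One optimizer step (and one logged avg_loss over the accumulated batches) happens every
+ # grad_accum_steps batches, which is why a resumed run reconstructs
+ # batch_step = optimizer_step * grad_accum_steps above.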
+
1293
+ # --- End of Epoch ---
1294
+ avg_epoch_loss = epoch_loss / num_batches_in_epoch if num_batches_in_epoch > 0 else 0
1295
+ logging.info(f"--- Finished Epoch {epoch+1}/{CONFIG['num_epochs']} | Average Epoch Loss: {avg_epoch_loss:.4f} ---")
1296
+
1297
+ # Save checkpoint at the end of each epoch
1298
+ checkpoint_manager.save(model, trainer_obj.optimizer, f"epoch{epoch+1}_step{optimizer_step}")
1299
+ # Optionally run debug generation at end of epoch
1300
+ safety.debug_generation("<user> Hi! Whats up?")
1301
+
1302
+
1303
+ logging.info(f"Training finished after {CONFIG['num_epochs']} target epochs.")
1304
+ # Final save
1305
+ logging.info("Saving final model state...")
1306
+ checkpoint_manager.save(model, trainer_obj.optimizer, f"final_step{optimizer_step}")
1307
+
1308
+
1309
+ if __name__ == "__main__":
1310
+ # Entry point: TOKENIZERS_PARALLELISM was already set before the tokenizers import at the top of this file
1311
+ train()
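+
+ # Usage note (editor addition; paths come from the defaults assumed elsewhere in this script):
+ #   python HROM-V1.5_Trainer.py
+ # The first run trains the combined tokenizer and starts from optimizer step 0; later runs
+ # reload the saved tokenizer and resume from the newest checkpoint in CONFIG["checkpoint_dir"],
+ # continuing until CONFIG["num_epochs"] epochs are reached.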
HROM_Trainer.py DELETED
@@ -1,384 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from torch.utils.data import Dataset, DataLoader
4
- from datasets import load_dataset
5
- from tokenizers import Tokenizer, models, trainers, pre_tokenizers, processors, decoders
6
- import math
7
- import os
8
- import re
9
- from datetime import datetime
10
- from contextlib import nullcontext
11
-
12
- # Configuration
13
- CONFIG = {
14
- "dim": 512,
15
- "n_layers": 6,
16
- "n_heads": 8,
17
- "ff_dim": 2048,
18
- "dropout": 0.1,
19
- "max_seq_len": 1024,
20
- "batch_size": 32,
21
- "checkpoint_interval": 1000,
22
- "debug_interval": 500,
23
- "dataset": "daily_dialog",
24
- "vocab_size": 32000,
25
- "tokenizer_train_samples": 100000,
26
- "learning_rate": 1e-4, # Lowered learning rate
27
- "max_turns": 6,
28
- "max_checkpoints": 5,
29
- "num_epochs": 100, # Increased number of epochs
30
- "grad_accum_steps": 4 # Gradient accumulation steps
31
- }
32
-
33
- class RotaryEmbedding(nn.Module):
34
- def __init__(self, dim):
35
- super().__init__()
36
- inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
37
- self.register_buffer("inv_freq", inv_freq)
38
-
39
- def forward(self, seq_len):
40
- t = torch.arange(seq_len, device=self.inv_freq.device).type_as(self.inv_freq)
41
- freqs = torch.einsum("i, j -> i j", t, self.inv_freq)
42
- return torch.cat((freqs, freqs), dim=-1)
43
-
44
- def rotate_half(x):
45
- x1, x2 = x.chunk(2, dim=-1)
46
- return torch.cat((-x2, x1), dim=-1)
47
-
48
- def apply_rotary_pos_emb(pos, t):
49
- pos = pos.unsqueeze(0).unsqueeze(1)
50
- return (t * pos.cos()) + (rotate_half(t) * pos.sin())
51
-
52
- class SwiGLU(nn.Module):
53
- def forward(self, x):
54
- x, gate = x.chunk(2, dim=-1)
55
- return x * torch.sigmoid(gate)
56
-
57
- class HROMAttention(nn.Module):
58
- def __init__(self):
59
- super().__init__()
60
- self.dim = CONFIG["dim"]
61
- self.n_heads = CONFIG["n_heads"]
62
- self.head_dim = self.dim // self.n_heads
63
- self.qkv = nn.Linear(self.dim, 3 * self.dim)
64
- self.proj = nn.Linear(self.dim, self.dim)
65
- self.rotary = RotaryEmbedding(self.head_dim)
66
- self.dropout = nn.Dropout(CONFIG["dropout"])
67
-
68
- def forward(self, x, mask=None):
69
- B, T, _ = x.shape
70
- qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.head_dim)
71
- q, k, v = qkv.unbind(2)
72
- q = q.transpose(1, 2)
73
- k = k.transpose(1, 2)
74
- v = v.transpose(1, 2)
75
- pos = self.rotary(T)
76
- q = apply_rotary_pos_emb(pos, q)
77
- k = apply_rotary_pos_emb(pos, k)
78
- attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
79
- if mask is not None:
80
- mask = mask.unsqueeze(1)
81
- attn = attn + mask
82
- attn = torch.softmax(attn, dim=-1)
83
- attn = self.dropout(attn)
84
- out = attn @ v
85
- out = out.transpose(1, 2).reshape(B, T, self.dim)
86
- return self.proj(out)
87
-
88
- class HROMBlock(nn.Module):
89
- def __init__(self):
90
- super().__init__()
91
- self.attn = HROMAttention()
92
- self.ff = nn.Sequential(
93
- nn.Linear(CONFIG["dim"], 2 * CONFIG["ff_dim"]),
94
- SwiGLU(),
95
- nn.Linear(CONFIG["ff_dim"], CONFIG["dim"])
96
- )
97
- self.norm1 = nn.LayerNorm(CONFIG["dim"])
98
- self.norm2 = nn.LayerNorm(CONFIG["dim"])
99
- self.dropout = nn.Dropout(CONFIG["dropout"])
100
-
101
- def forward(self, x, mask=None):
102
- x = x + self.dropout(self.attn(self.norm1(x), mask))
103
- x = x + self.dropout(self.ff(self.norm2(x)))
104
- return x
105
-
106
- class HROM(nn.Module):
107
- def __init__(self):
108
- super().__init__()
109
- self.embed = nn.Embedding(CONFIG["vocab_size"], CONFIG["dim"])
110
- self.blocks = nn.ModuleList([HROMBlock() for _ in range(CONFIG["n_layers"])])
111
- self.norm = nn.LayerNorm(CONFIG["dim"])
112
- self.head = nn.Linear(CONFIG["dim"], CONFIG["vocab_size"])
113
- self.apply(self._init_weights)
114
-
115
- def _init_weights(self, module):
116
- if isinstance(module, nn.Linear):
117
- torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
118
- if module.bias is not None:
119
- torch.nn.init.zeros_(module.bias)
120
-
121
- def forward(self, x, attention_mask=None):
122
- x = self.embed(x)
123
- if attention_mask is not None:
124
- B, T = attention_mask.shape
125
- causal_mask = torch.triu(torch.ones(T, T) * float('-inf'), diagonal=1)
126
- causal_mask = causal_mask.to(x.device)
127
- pad_mask = attention_mask.unsqueeze(1).unsqueeze(2).to(dtype=torch.float32)
128
- pad_mask = (1.0 - pad_mask) * torch.finfo(torch.float32).min
129
- mask = causal_mask + pad_mask.squeeze(1)
130
- else:
131
- B, T = x.shape[:2]
132
- mask = torch.triu(torch.ones(T, T) * float('-inf'), diagonal=1)
133
- mask = mask.to(x.device)
134
- mask = mask.unsqueeze(0).expand(B, -1, -1)
135
- for block in self.blocks:
136
- x = block(x, mask)
137
- return self.head(self.norm(x))
138
-
139
- class TokenizerTrainer:
140
- def __init__(self):
141
- self.tokenizer = Tokenizer(models.BPE())
142
- self.tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=True)
143
- self.tokenizer.decoder = decoders.ByteLevel()
144
- self.special_tokens = ["<pad>", "<s>", "</s>", "<unk>", "<user>", "<assistant>"]
145
-
146
- def train(self, dataset_name):
147
- dataset = load_dataset(dataset_name, split=f"train[:{CONFIG['tokenizer_train_samples']}]")
148
- text_samples = []
149
- for entry in dataset:
150
- if "dialog" in entry:
151
- for i, utterance in enumerate(entry["dialog"][:CONFIG["max_turns"]]):
152
- role = "<user>" if i % 2 == 0 else "<assistant>"
153
- text_samples.append(f"{role} {utterance}")
154
- else:
155
- text_samples.append(self._clean_text(entry.get("text", "")))
156
- trainer = trainers.BpeTrainer(
157
- vocab_size=CONFIG["vocab_size"],
158
- special_tokens=self.special_tokens,
159
- min_frequency=2,
160
- show_progress=True
161
- )
162
- self.tokenizer.train_from_iterator(text_samples, trainer=trainer, length=len(text_samples))
163
- self.tokenizer.post_processor = processors.TemplateProcessing(
164
- single="$A </s>",
165
- pair="$A $B </s>",
166
- special_tokens=[("</s>", self.tokenizer.token_to_id("</s>"))],
167
- )
168
- os.makedirs("tokenizer", exist_ok=True)
169
- self.tokenizer.save("tokenizer/hrom_tokenizer.json")
170
-
171
- def _clean_text(self, text):
172
- text = re.sub(r'[^\w\s.,!?\'\-:;<>]', '', text)
173
- text = re.sub(r'\s+', ' ', text).strip()
174
- return text
175
-
176
- class ChatDataset(Dataset):
177
- def __init__(self, tokenizer):
178
- full_dataset = load_dataset(CONFIG["dataset"], split="train")
179
- num_samples = min(len(full_dataset), CONFIG["tokenizer_train_samples"])
180
- self.dataset = full_dataset.shuffle(seed=42).select(range(num_samples))
181
- self.tokenizer = tokenizer
182
- self.max_length = CONFIG["max_seq_len"]
183
- self.turn_sep = self.tokenizer.token_to_id("</s>")
184
-
185
- def __len__(self):
186
- return len(self.dataset)
187
-
188
- def __getitem__(self, idx):
189
- entry = self.dataset[idx]
190
- formatted = []
191
- if "dialog" in entry:
192
- dialog = entry["dialog"][:CONFIG["max_turns"]]
193
- for i, utterance in enumerate(dialog):
194
- role_token = "<user>" if i % 2 == 0 else "<assistant>"
195
- formatted.extend([
196
- self.tokenizer.token_to_id(role_token),
197
- *self.tokenizer.encode(utterance).ids,
198
- self.turn_sep
199
- ])
200
- else:
201
- text = entry.get("text", "")
202
- formatted.extend([
203
- self.tokenizer.token_to_id("<user>"),
204
- *self.tokenizer.encode(text).ids,
205
- self.turn_sep
206
- ])
207
- formatted = formatted[:self.max_length-2]
208
- formatted = [self.tokenizer.token_to_id("<s>"), *formatted, self.tokenizer.token_to_id("</s>")]
209
- return {
210
- "input_ids": formatted[:-1],
211
- "labels": formatted[1:]
212
- }
213
-
214
- @staticmethod
215
- def collate_fn(batch):
216
- max_len = max(len(item["input_ids"]) for item in batch)
217
- pad_id = Tokenizer.from_file("tokenizer/hrom_tokenizer.json").token_to_id("<pad>")
218
- inputs, labels, masks = [], [], []
219
- for item in batch:
220
- pad_len = max_len - len(item["input_ids"])
221
- inputs.append(item["input_ids"] + [pad_id] * pad_len)
222
- labels.append(item["labels"] + [pad_id] * pad_len)
223
- masks.append([1] * len(item["input_ids"]) + [0] * pad_len)
224
- return {
225
- "input_ids": torch.tensor(inputs),
226
- "labels": torch.tensor(labels),
227
- "attention_mask": torch.tensor(masks)
228
- }
229
-
230
- class HROMTrainer:
231
- def __init__(self, model, tokenizer):
232
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
233
- self.model = model.to(self.device)
234
- if self.device.type == "cuda":
235
- self.scaler = torch.cuda.amp.GradScaler()
236
- else:
237
- self.scaler = None
238
- self.optimizer = torch.optim.AdamW(
239
- self.model.parameters(),
240
- lr=CONFIG["learning_rate"],
241
- fused=True if self.device.type == "cuda" else False
242
- )
243
- self.tokenizer = tokenizer
244
-
245
- def train_step(self, batch):
246
- autocast = torch.cuda.amp.autocast if self.device.type == "cuda" else nullcontext
247
- with autocast():
248
- outputs = self.model(
249
- batch["input_ids"].to(self.device),
250
- attention_mask=batch["attention_mask"].to(self.device)
251
- )
252
- original_loss = nn.CrossEntropyLoss(ignore_index=self.tokenizer.token_to_id("<pad>"))(
253
- outputs.view(-1, CONFIG["vocab_size"]),
254
- batch["labels"].view(-1).to(self.device)
255
- )
256
- scaled_loss = original_loss / CONFIG["grad_accum_steps"]
257
-
258
- if self.scaler is not None:
259
- self.scaler.scale(scaled_loss).backward()
260
- else:
261
- scaled_loss.backward()
262
-
263
- return original_loss.item()
264
-
265
- def clip_and_step(self):
266
- if self.scaler is not None:
267
- self.scaler.unscale_(self.optimizer)
268
- torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
269
-
270
- if self.scaler is not None:
271
- self.scaler.step(self.optimizer)
272
- self.scaler.update()
273
- else:
274
- self.optimizer.step()
275
-
276
- self.optimizer.zero_grad()
277
-
278
- class SafetyManager:
279
- def __init__(self, model, tokenizer):
280
- self.model = model
281
- self.tokenizer = tokenizer
282
- self.bad_words = ["hate", "kill", "harm"]
283
- self.bad_word_ids = [tokenizer.encode(w).ids for w in self.bad_words]
284
-
285
- def content_filter(self, text):
286
- tokens = self.tokenizer.encode(text).ids
287
- for bad_ids in self.bad_word_ids:
288
- if any(tokens[i:i+len(bad_ids)] == bad_ids for i in range(len(tokens))):
289
- return False
290
- return True
291
-
292
- def generate_safely(self, prompt, max_length=50):
293
- input_ids = self.tokenizer.encode(prompt).ids
294
- device = next(self.model.parameters()).device
295
- for _ in range(max_length):
296
- with torch.no_grad():
297
- logits = self.model(torch.tensor([input_ids]).to(device))
298
- next_token = logits.argmax(-1)[:, -1].item()
299
- if next_token == self.tokenizer.token_to_id("</s>"):
300
- break
301
- generated = self.tokenizer.decode(input_ids + [next_token])
302
- if not self.content_filter(generated):
303
- break
304
- input_ids.append(next_token)
305
- return self.tokenizer.decode(input_ids)
306
-
307
- def debug_generation(self, prompt="Hello!"):
308
- print(f"\nSafety Check Generation:")
309
- response = self.generate_safely(prompt)
310
- print(f"Prompt: {prompt}\nResponse: {response}")
311
-
312
- class CheckpointManager:
313
- def __init__(self):
314
- self.checkpoint_dir = "checkpoints"
315
- os.makedirs(self.checkpoint_dir, exist_ok=True)
316
-
317
- def save(self, model, optimizer, step):
318
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
319
- path = f"{self.checkpoint_dir}/hrom_{timestamp}_step{step}.pt"
320
- torch.save({
321
- "model": model.state_dict(),
322
- "optimizer": optimizer.state_dict(),
323
- "step": step,
324
- "config": CONFIG
325
- }, path)
326
- self._cleanup_old_checkpoints()
327
-
328
- def _cleanup_old_checkpoints(self):
329
- checkpoints = sorted(os.listdir(self.checkpoint_dir),
330
- key=lambda x: os.path.getmtime(os.path.join(self.checkpoint_dir, x)))
331
- while len(checkpoints) > CONFIG["max_checkpoints"]:
332
- os.remove(os.path.join(self.checkpoint_dir, checkpoints[0]))
333
- checkpoints = checkpoints[1:]
334
-
335
- def train():
336
- checkpoint_manager = CheckpointManager()
337
- if not os.path.exists("tokenizer/hrom_tokenizer.json"):
338
- print("Training tokenizer...")
339
- tokenizer_trainer = TokenizerTrainer()
340
- tokenizer_trainer.train(CONFIG["dataset"])
341
-
342
- tokenizer = Tokenizer.from_file("tokenizer/hrom_tokenizer.json")
343
- model = HROM()
344
- print("Downloading and caching the dataset...")
345
- _ = load_dataset(CONFIG["dataset"], split="train", download_mode="reuse_cache_if_exists")
346
-
347
- dataset = ChatDataset(tokenizer)
348
- dataloader = DataLoader(
349
- dataset,
350
- batch_size=CONFIG["batch_size"],
351
- collate_fn=ChatDataset.collate_fn
352
- )
353
-
354
- trainer_obj = HROMTrainer(model, tokenizer)
355
- safety = SafetyManager(model, tokenizer)
356
-
357
- step = 0
358
- optimizer_step = 0
359
- total_loss = 0.0
360
- model.train()
361
-
362
- for epoch in range(CONFIG["num_epochs"]):
363
- for batch in dataloader:
364
- loss = trainer_obj.train_step(batch)
365
- total_loss += loss
366
- step += 1
367
-
368
- if step % CONFIG["grad_accum_steps"] == 0:
369
- trainer_obj.clip_and_step()
370
- avg_loss = total_loss / CONFIG["grad_accum_steps"]
371
- total_loss = 0.0
372
-
373
- if optimizer_step % CONFIG["checkpoint_interval"] == 0:
374
- checkpoint_manager.save(model, trainer_obj.optimizer, optimizer_step)
375
- safety.debug_generation()
376
-
377
- if optimizer_step % CONFIG["debug_interval"] == 0:
378
- print(f"Optimizer Step {optimizer_step} | Loss: {avg_loss:.4f}")
379
- safety.debug_generation("What's the meaning of life?")
380
-
381
- optimizer_step += 1
382
-
383
- if __name__ == "__main__":
384
- train()