Create main.py
HROM-V1.6/main.py
ADDED
@@ -0,0 +1,1428 @@
import os
# Set parallelism env var *before* importing tokenizers
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
# Import necessary dataset functions, including concatenate_datasets if needed later
from datasets import load_dataset, disable_caching, concatenate_datasets
from tokenizers import Tokenizer, models, trainers, pre_tokenizers, processors, decoders
import math
import re
from datetime import datetime
from contextlib import nullcontext
from collections import defaultdict
import logging
import random  # For shuffling combined data

# Disable caching for datasets if needed, helps ensure reprocessing
# disable_caching()

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    force=True  # Override any handlers already configured by imported libraries
)

# Configuration
CONFIG = {
    # --- Scaled Parameters ---
    "dim": 960,
    "n_layers": 32,
    "n_heads": 32,
    "ff_dim": 3608,  # Feed-forward hidden size (note: not exactly 4 * dim, which would be 3840)

    # --- Kept Parameters ---
    "dropout": 0.1,
    "max_seq_len": 512,
    "vocab_size": 32000,  # Fixed by tokenizer

    # --- Training/Dataset Parameters ---
    "batch_size": 2,  # Increased batch size for 2 GPUs (adjust if needed)
    "checkpoint_interval": 2000,
    "debug_interval": 400,
    # --- Added Universal-Transformers-RawConvo and papahawk/conversational-01 ---
    "datasets": ["daily_dialog", "empathetic_dialogues", "blended_skill_talk", "AlekseyKorshuk/persona-chat", "elapt1c/Universal-Transformers-RawConvo", "papahawk/conversational-01"],
    "tokenizer_name": "hrom_tokenizer.json",  # New name for expanded tokenizer
    "checkpoint_dir": "checkpoints",  # Separate directory for expanded data model
    # --- Increased samples per dataset slightly for tokenizer ---
    "tokenizer_train_samples_per_dataset": 100000,  # Use same limit for all, incl. new ones
    "learning_rate": 1e-5,
    "warmup_steps": 1000,
    "max_turns": 8,  # Keep max_turns limit for Q&A datasets too
    "max_checkpoints": 5,
    "num_epochs": 30,
    "grad_accum_steps": 4,
    "num_gpus": 2  # Specify the number of GPUs to use
}
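
# --- Illustrative addition (not part of the original file): a rough weight count implied by
# CONFIG above, ignoring biases and LayerNorm parameters. Per block: attention (4 * dim^2 for
# qkv + output projection) plus the gated feed-forward (3 * dim * ff_dim); the embedding and
# LM head add 2 * vocab_size * dim. The helper below is a sketch and is not used by the script.
def _estimate_param_count(cfg=CONFIG):
    d, ff, layers, vocab = cfg["dim"], cfg["ff_dim"], cfg["n_layers"], cfg["vocab_size"]
    per_block = 4 * d * d + 3 * d * ff           # attention + gated feed-forward weights
    return layers * per_block + 2 * vocab * d    # all blocks + embedding + LM head
# For dim=960, ff_dim=3608, n_layers=32, vocab_size=32000 this comes to roughly 0.5B weights.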

# --- Model Definition (HROM, HROMBlock, HROMAttention, SwiGLU, RoPE) ---
# (These classes remain unchanged from the previous version)
class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, seq_len):
        # Handle the empty-sequence edge case up front
        if seq_len == 0:
            return torch.empty((0, self.inv_freq.shape[0] * 2), device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        t = torch.arange(seq_len, device=self.inv_freq.device).type_as(self.inv_freq)
        freqs = torch.einsum("i, j -> i j", t, self.inv_freq)
        # Defensive reshape only if necessary
        if freqs.shape[0] != seq_len:
            freqs = freqs.reshape(seq_len, -1)
        return torch.cat((freqs, freqs), dim=-1)


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(pos, t):
    # pos: (T, dim_rotary), t: (B, H, T, Head_Dim)
    pos = pos.to(t.device, dtype=t.dtype)
    pos = pos.unsqueeze(0).unsqueeze(1)  # Shape: (1, 1, T, dim_rotary)
    tensor_seq_len = t.shape[2]
    pos_seq_len = pos.shape[2]

    if pos_seq_len < tensor_seq_len:
        logging.warning(f"RoPE Warning: pos sequence length ({pos_seq_len}) is shorter than tensor sequence length ({tensor_seq_len}). Using truncated tensor length for RoPE.")
        # This case usually indicates an issue upstream; as a fallback, rotate only the
        # slice that has positions and leave the remaining positions unrotated.
        t_rotated = t[:, :, :pos_seq_len, :]
        pos = pos[:, :, :pos_seq_len, :]  # Ensure pos matches the sliced tensor length

        # Apply rotation only to the slice
        cos_pos = pos.cos()
        sin_pos = pos.sin()
        t_rotated = (t_rotated * cos_pos) + (rotate_half(t_rotated) * sin_pos)

        # Concatenate the rotated part with the un-rotated part
        t_unrotated = t[:, :, pos_seq_len:, :]
        return torch.cat([t_rotated, t_unrotated], dim=2)

    elif pos_seq_len > tensor_seq_len:
        pos = pos[:, :, :tensor_seq_len, :]  # Slice pos to match tensor

    # Check dimension match after potential slicing
    if pos.shape[-1] != t.shape[-1]:
        logging.error(f"Mismatched dimensions for RoPE: pos ({pos.shape[-1]}) vs t ({t.shape[-1]})")
        raise ValueError("Rotary embedding dimension must match head dimension.")

    cos_pos = pos.cos()
    sin_pos = pos.sin()
    rotated_t = (t * cos_pos) + (rotate_half(t) * sin_pos)
    return rotated_t


class SwiGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return x * nn.functional.gelu(gate)
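
# Note (added, not in the original file): despite the class name, the gate above uses GELU,
# i.e. this is the GEGLU variant; a literal SwiGLU would gate with nn.functional.silu(gate).
# The layer halves the last dimension: an input of shape (B, T, 2 * ff_dim) comes out as
# (B, T, ff_dim), which is why HROMBlock's first Linear below projects to 2 * ff_dim.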

class HROMAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.dim = CONFIG["dim"]
        self.n_heads = CONFIG["n_heads"]
        self.head_dim = self.dim // self.n_heads
        if self.dim % self.n_heads != 0:
            raise ValueError("dim must be divisible by n_heads")
        self.qkv = nn.Linear(self.dim, 3 * self.dim)
        self.proj = nn.Linear(self.dim, self.dim)
        # Initialize RotaryEmbedding with head_dim to ensure match
        self.rotary = RotaryEmbedding(self.head_dim)
        self.dropout = nn.Dropout(CONFIG["dropout"])

        # --- Assertion to verify dimension match ---
        assert self.rotary.inv_freq.shape[0] * 2 == self.head_dim, f"Rotary Embedding dimension ({self.rotary.inv_freq.shape[0] * 2}) must match head dimension ({self.head_dim}). Check CONFIG parameters 'dim' and 'n_heads'."

    def forward(self, x, mask=None):
        B, T, C = x.shape
        qkv = self.qkv(x)
        qkv = qkv.reshape(B, T, 3, self.n_heads, self.head_dim)
        q, k, v = qkv.unbind(2)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        # Generate RoPE embeddings for the current sequence length T
        pos = self.rotary(T)  # Shape (T, Head_Dim)
        # Apply RoPE
        q = apply_rotary_pos_emb(pos, q)
        k = apply_rotary_pos_emb(pos, k)
        # Attention calculation
        attn_scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
        if mask is not None:
            # Ensure mask is broadcastable (B, 1, T, T)
            if mask.dim() == 2:  # (B, T) -> (B, 1, 1, T) -> add with causal = (B, 1, T, T)
                mask = mask.unsqueeze(1).unsqueeze(2)
            elif mask.dim() == 3:  # (B, T, T)
                mask = mask.unsqueeze(1)
            # Add mask AFTER scaling scores
            attn_scores = attn_scores + mask  # Add large negative values for masked positions
        # Softmax and dropout
        attn_probs = torch.softmax(attn_scores.float(), dim=-1).to(dtype=x.dtype)  # Use float for stability
        attn_probs = self.dropout(attn_probs)
        # Output projection
        output = attn_probs @ v
        output = output.transpose(1, 2).reshape(B, T, self.dim)
        return self.proj(output)


class HROMBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = HROMAttention()
        self.ff = nn.Sequential(
            nn.Linear(CONFIG["dim"], 2 * CONFIG["ff_dim"]),
            SwiGLU(),
            nn.Linear(CONFIG["ff_dim"], CONFIG["dim"])
        )
        self.norm1 = nn.LayerNorm(CONFIG["dim"])
        self.norm2 = nn.LayerNorm(CONFIG["dim"])
        self.dropout = nn.Dropout(CONFIG["dropout"])

    def forward(self, x, mask=None):
        # Pre-Normalization
        normed_x = self.norm1(x)
        attn_output = self.attn(normed_x, mask)
        x = x + self.dropout(attn_output)

        normed_x = self.norm2(x)
        ff_output = self.ff(normed_x)
        x = x + self.dropout(ff_output)
        return x


class HROM(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(CONFIG["vocab_size"], CONFIG["dim"])
        self.blocks = nn.ModuleList([HROMBlock() for _ in range(CONFIG["n_layers"])])
        self.norm = nn.LayerNorm(CONFIG["dim"])
        self.head = nn.Linear(CONFIG["dim"], CONFIG["vocab_size"])
        self.dropout = nn.Dropout(CONFIG["dropout"])  # Add dropout after embedding
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def forward(self, input_ids, attention_mask=None):
        B, T = input_ids.shape
        x = self.embed(input_ids)
        x = self.dropout(x)  # Apply dropout after embedding

        # Create the combined mask for attention
        combined_mask = None
        # Start with causal mask valid for all sequences in batch
        causal_mask = torch.triu(torch.ones(T, T, device=input_ids.device) * float('-inf'), diagonal=1)
        combined_mask = causal_mask.unsqueeze(0).unsqueeze(1)  # (1, 1, T, T)

        if attention_mask is not None:
            # Process padding mask from attention_mask (0 = pad, 1 = real)
            # Convert 0s to -inf, 1s to 0
            pad_mask = (1.0 - attention_mask.to(torch.float32)) * torch.finfo(torch.float32).min
            pad_mask = pad_mask.unsqueeze(1).unsqueeze(2)  # (B, 1, 1, T)
            # Add padding mask to causal mask. Broadcasting ensures (B, 1, T, T)
            # Where pad_mask is -inf, the result is -inf. Otherwise, it's the causal value.
            combined_mask = combined_mask + pad_mask

        # Ensure mask dtype matches data dtype (esp. for AMP)
        combined_mask = combined_mask.to(dtype=x.dtype)

        for block in self.blocks:
            x = block(x, combined_mask)  # Pass the combined mask to each block

        x = self.norm(x)
        logits = self.head(x)
        return logits
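
# --- Illustrative addition (not part of the original file): a minimal shape check for the model
# above. It is only a sketch (building the full ~0.5B-parameter model needs a few GB of memory)
# and is not called anywhere in this script.
def _smoke_test_forward():
    model = HROM()
    input_ids = torch.randint(0, CONFIG["vocab_size"], (2, 16))
    attention_mask = torch.ones_like(input_ids)
    logits = model(input_ids, attention_mask=attention_mask)
    assert logits.shape == (2, 16, CONFIG["vocab_size"])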

# --- Tokenizer Training ---

class TokenizerTrainer:
    def __init__(self):
        self.tokenizer = Tokenizer(models.BPE())
        self.tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
        self.tokenizer.decoder = decoders.ByteLevel()
        self.special_tokens = ["<pad>", "<s>", "</s>", "<unk>", "<user>", "<assistant>"]
        # Use the updated tokenizer name from CONFIG
        self.tokenizer_path = os.path.join("tokenizer", CONFIG["tokenizer_name"])
        self.tokenizer_dir = os.path.dirname(self.tokenizer_path)

    def _clean_text(self, text):
        text = str(text)  # Ensure text is string
        text = re.sub(r'_comma_', ',', text)
        # Allow alphanumeric, whitespace, and basic punctuation including quotes
        text = re.sub(r'[^\w\s.,!?\'\-:;<>"]', '', text)
        text = re.sub(r'\s+', ' ', text).strip()
        return text

    def train(self, dataset_names):
        logging.info("Starting tokenizer training...")
        text_samples = []
        samples_per_dataset = CONFIG['tokenizer_train_samples_per_dataset']

        # --- Process DailyDialog ---
        if "daily_dialog" in dataset_names:
            logging.info(f"Loading daily_dialog for tokenizer training (max {samples_per_dataset} dialogues)...")
            try:
                # Limit dialogues loaded directly using slicing
                dd_dataset = load_dataset("daily_dialog", split=f"train[:{samples_per_dataset}]", trust_remote_code=True)
                logging.info("Processing daily_dialog...")
                for entry in dd_dataset:
                    formatted_dialogue = []
                    dialogue = entry['dialog'][:CONFIG["max_turns"]]
                    for i, utterance in enumerate(dialogue):
                        role = "<user>" if i % 2 == 0 else "<assistant>"
                        cleaned_utterance = self._clean_text(utterance)
                        if cleaned_utterance:  # Only add non-empty turns
                            formatted_dialogue.append(f"{role} {cleaned_utterance}")
                    if formatted_dialogue:  # Only add if dialogue is not empty after cleaning
                        text_samples.append(" </s> ".join(formatted_dialogue))
            except Exception as e:
                logging.error(f"Failed to load or process daily_dialog for tokenizer: {e}")

        # --- Process EmpatheticDialogues ---
        if "empathetic_dialogues" in dataset_names:
            logging.info(f"Loading empathetic_dialogues for tokenizer training (max {samples_per_dataset} dialogues)...")
            try:
                # Load more initially to ensure we get enough unique conversations (adjust multiplier if needed)
                ed_dataset = load_dataset("empathetic_dialogues", split=f"train[:{samples_per_dataset * 3}]", trust_remote_code=True)
                logging.info("Processing empathetic_dialogues...")
                conversations = defaultdict(list)
                processed_conv_count = 0
                # Group utterances by conv_id first
                grouped_by_conv = defaultdict(list)
                for entry in ed_dataset:
                    grouped_by_conv[entry['conv_id']].append(entry)

                # Process conversations ensuring max samples limit
                for conv_id, entries in grouped_by_conv.items():
                    if processed_conv_count >= samples_per_dataset:
                        break
                    # Sort by utterance_idx to maintain order
                    sorted_entries = sorted(entries, key=lambda x: x['utterance_idx'])
                    formatted_dialogue = []
                    # Handle context and first utterance
                    if sorted_entries[0]['context']:
                        cleaned_context = self._clean_text(sorted_entries[0]['context'])
                        if cleaned_context:
                            formatted_dialogue.append(f"<user> {cleaned_context}")  # Assume context is user start
                    # Process subsequent utterances
                    last_role = '<user>' if formatted_dialogue else None  # Set initial last role based on context
                    for entry in sorted_entries:
                        cleaned_utterance = self._clean_text(entry['utterance'])
                        if cleaned_utterance:
                            # Determine role based on alternation
                            current_role = '<assistant>' if last_role == '<user>' else '<user>'
                            formatted_dialogue.append(f"{current_role} {cleaned_utterance}")
                            last_role = current_role  # Update last role
                    # Apply max turns limit to the formatted turns
                    formatted_dialogue = formatted_dialogue[:CONFIG["max_turns"]]
                    if formatted_dialogue:
                        text_samples.append(" </s> ".join(formatted_dialogue))
                        processed_conv_count += 1  # Count processed unique conversations

            except Exception as e:
                logging.error(f"Failed to load or process empathetic_dialogues for tokenizer: {e}")

        # --- Process BlendedSkillTalk ---
        if "blended_skill_talk" in dataset_names:
            logging.info(f"Loading blended_skill_talk for tokenizer training (max {samples_per_dataset} dialogues)...")
            try:
                # Load dialogues - BST is structured differently, slice directly
                bst_dataset = load_dataset("blended_skill_talk", split=f"train[:{samples_per_dataset}]", trust_remote_code=True)
                logging.info("Processing blended_skill_talk...")
                for entry in bst_dataset:
                    formatted_dialogue = []
                    # Combine the dialogue history and the final two turns
                    dialogue_turns_raw = entry['previous_utterance']
                    # Add final utterances if they exist and are not empty strings
                    if entry.get('free_turker_utterance'):
                        dialogue_turns_raw.append(entry['free_turker_utterance'])
                    if entry.get('guided_turker_utterance'):
                        dialogue_turns_raw.append(entry['guided_turker_utterance'])

                    turns_to_process = dialogue_turns_raw[:CONFIG["max_turns"]]  # Apply max turns limit
                    for i, utterance in enumerate(turns_to_process):
                        role = "<user>" if i % 2 == 0 else "<assistant>"  # Assume simple alternation
                        cleaned_utterance = self._clean_text(utterance)
                        if cleaned_utterance:
                            formatted_dialogue.append(f"{role} {cleaned_utterance}")
                    if formatted_dialogue:
                        text_samples.append(" </s> ".join(formatted_dialogue))
            except Exception as e:
                logging.error(f"Failed to load or process blended_skill_talk for tokenizer: {e}")

        # --- Process PersonaChat ---
        if "AlekseyKorshuk/persona-chat" in dataset_names:  # Correct dataset identifier
            pc_dataset_name = "AlekseyKorshuk/persona-chat"
            logging.info(f"Loading {pc_dataset_name} for tokenizer training (max {samples_per_dataset} dialogues)...")
            try:
                pc_dataset = load_dataset(pc_dataset_name, split=f"train[:{samples_per_dataset}]", trust_remote_code=True)
                logging.info(f"Processing {pc_dataset_name}...")
                for entry in pc_dataset:
                    # PersonaChat often has 'utterances' containing 'history'
                    if 'utterances' in entry and entry['utterances']:
                        # Get the history from the last item in utterances for the full dialogue
                        history = entry['utterances'][-1]['history']
                        history = history[:CONFIG["max_turns"]]  # Apply max turns
                        formatted_dialogue = []
                        for i, utterance in enumerate(history):
                            role = "<user>" if i % 2 == 0 else "<assistant>"  # Assume simple alternation
                            cleaned_utterance = self._clean_text(utterance)
                            if cleaned_utterance:
                                formatted_dialogue.append(f"{role} {cleaned_utterance}")
                        if formatted_dialogue:
                            text_samples.append(" </s> ".join(formatted_dialogue))
                    else:
                        logging.warning(f"Skipping {pc_dataset_name} entry due to unexpected structure: {entry}")

            except Exception as e:
                logging.error(f"Failed to load or process {pc_dataset_name} for tokenizer: {e}")

        # --- Process Universal-Transformers-RawConvo ---
        if "elapt1c/Universal-Transformers-RawConvo" in dataset_names:
            utrc_dataset_name = "elapt1c/Universal-Transformers-RawConvo"
            logging.info(f"Loading {utrc_dataset_name} for tokenizer training (max {samples_per_dataset} dialogues)...")
            try:
                utrc_dataset = load_dataset(utrc_dataset_name, split=f"train[:{samples_per_dataset}]", trust_remote_code=True)
                logging.info(f"Processing {utrc_dataset_name}...")
                for entry in utrc_dataset:
                    formatted_dialogue = []
                    dialogue = entry['dialogue']  # Assuming 'dialogue' column contains the conversation turns
                    if isinstance(dialogue, list):  # Check if dialogue is a list of strings
                        dialogue = dialogue[:CONFIG["max_turns"]]
                        for i, utterance in enumerate(dialogue):
                            role = "<user>" if i % 2 == 0 else "<assistant>"
                            cleaned_utterance = self._clean_text(utterance)
                            if cleaned_utterance:
                                formatted_dialogue.append(f"{role} {cleaned_utterance}")
                        if formatted_dialogue:
                            text_samples.append(" </s> ".join(formatted_dialogue))
                    else:
                        logging.warning(f"Skipping {utrc_dataset_name} entry due to unexpected 'dialogue' structure: {type(dialogue)}")

            except Exception as e:
                logging.error(f"Failed to load or process {utrc_dataset_name} for tokenizer: {e}")

        # --- Process papahawk/conversational-01 ---
        if "papahawk/conversational-01" in dataset_names:
            pco_dataset_name = "papahawk/conversational-01"
            logging.info(f"Loading {pco_dataset_name} for tokenizer training (max {samples_per_dataset} samples)...")
            try:
                pco_dataset = load_dataset(pco_dataset_name, split=f"train[:{samples_per_dataset}]", trust_remote_code=True)
                logging.info(f"Processing {pco_dataset_name}...")
                for entry in pco_dataset:
                    prompt = entry['prompt']
                    response = entry['response']
                    cleaned_prompt = self._clean_text(prompt)
                    cleaned_response = self._clean_text(response)
                    if cleaned_prompt and cleaned_response:
                        formatted_dialogue = [f"<user> {cleaned_prompt}", f"<assistant> {cleaned_response}"]
                        text_samples.append(" </s> ".join(formatted_dialogue))

            except Exception as e:
                logging.error(f"Failed to load or process {pco_dataset_name} for tokenizer: {e}")

        logging.info(f"Total text samples for tokenizer training: {len(text_samples)}")
        if not text_samples:
            raise ValueError("No text samples collected for tokenizer training. Check dataset loading and paths.")

        # Ensure tokenizer directory exists before training
        os.makedirs(self.tokenizer_dir, exist_ok=True)

        logging.info(f"Training BPE tokenizer with vocab size {CONFIG['vocab_size']}...")
        trainer = trainers.BpeTrainer(
            vocab_size=CONFIG["vocab_size"],
            special_tokens=self.special_tokens,
            min_frequency=2,  # Keep min_frequency low with more data
            show_progress=True
        )
        # Make sure text_samples is an iterator or list of strings
        def text_iterator():
            for sample in text_samples:
                yield sample

        self.tokenizer.train_from_iterator(text_iterator(), trainer=trainer, length=len(text_samples))

        eos_token_id = self.tokenizer.token_to_id("</s>")
        if eos_token_id is None:
            logging.warning("</s> token not found in trained tokenizer vocab! Using <pad> as fallback for post-processor.")
            eos_token_id = self.tokenizer.token_to_id("<pad>") or 0  # Fallback needed

        # Configure post-processor (adjust if needed based on how you structure input/output)
        self.tokenizer.post_processor = processors.TemplateProcessing(
            single="$A </s>",
            pair="$A </s> $B </s>",  # How to handle pairs - maybe just use single always?
            special_tokens=[("</s>", eos_token_id)],
        )

        logging.info(f"Saving tokenizer to {self.tokenizer_path}")
        self.tokenizer.save(self.tokenizer_path)
        logging.info("Tokenizer training complete.")

    def get_tokenizer(self):
        if not os.path.exists(self.tokenizer_path):
            raise FileNotFoundError(f"Tokenizer file not found at {self.tokenizer_path}. Train tokenizer first.")
        tokenizer = Tokenizer.from_file(self.tokenizer_path)
        # Verify special tokens crucial for processing exist
        required_tokens = ["<pad>", "<s>", "</s>", "<unk>", "<user>", "<assistant>"]
        for token in required_tokens:
            if tokenizer.token_to_id(token) is None:
                raise ValueError(f"Crucial special token '{token}' not found in loaded tokenizer '{self.tokenizer_path}'!")
        return tokenizer
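
# --- Illustrative addition (not part of the original file): the formatting above turns each
# dialogue into one training string with role tokens and " </s> " between turns, e.g.
#   "<user> Hi, how are you? </s> <assistant> Good, thanks! </s> <user> Glad to hear it."
# join() only separates turns; the TemplateProcessing post-processor appends the final </s>
# when the saved tokenizer is later used with its default special-token handling.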

# --- Dataset Loading and Processing ---

class CombinedChatDataset(Dataset):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.pad_id = self.tokenizer.token_to_id("<pad>")
        self.eos_id = self.tokenizer.token_to_id("</s>")
        self.bos_id = self.tokenizer.token_to_id("<s>")
        self.user_id = self.tokenizer.token_to_id("<user>")
        self.assistant_id = self.tokenizer.token_to_id("<assistant>")
        self.max_length = CONFIG["max_seq_len"]
        # Reuse cleaning function from TokenizerTrainer instance
        self._clean_text = TokenizerTrainer()._clean_text

        self.all_processed_conversations = []

        # --- Process DailyDialog ---
        if "daily_dialog" in CONFIG["datasets"]:
            logging.info("Loading and processing daily_dialog dataset...")
            try:
                dd_dataset = load_dataset("daily_dialog", split="train", trust_remote_code=True)
                logging.info(f"Processing {len(dd_dataset)} daily_dialog conversations...")
                for entry in dd_dataset:
                    conversation = []
                    dialogue = entry['dialog'][:CONFIG["max_turns"]]
                    if not dialogue: continue
                    for i, utterance in enumerate(dialogue):
                        role = "<user>" if i % 2 == 0 else "<assistant>"
                        cleaned_text = self._clean_text(utterance)
                        if cleaned_text:
                            conversation.append({'role': role, 'text': cleaned_text})
                    if conversation:
                        self.all_processed_conversations.append(conversation)
            except Exception as e:
                logging.error(f"Failed to load or process daily_dialog for training: {e}")

        # --- Process EmpatheticDialogues ---
        if "empathetic_dialogues" in CONFIG["datasets"]:
            logging.info("Loading and processing empathetic_dialogues dataset...")
            try:
                ed_dataset = load_dataset("empathetic_dialogues", split="train", trust_remote_code=True)
                logging.info("Grouping empathetic_dialogues by conversation ID...")
                conversations_grouped = defaultdict(list)
                for entry in ed_dataset:
                    conversations_grouped[entry['conv_id']].append(entry)

                logging.info(f"Processing {len(conversations_grouped)} empathetic_dialogues conversations...")
                for conv_id, entries in conversations_grouped.items():
                    conversation = []
                    sorted_entries = sorted(entries, key=lambda x: x['utterance_idx'])
                    # Handle context as first user turn if present
                    if sorted_entries[0]['context']:
                        context_text = self._clean_text(sorted_entries[0]['context'])
                        if context_text:
                            conversation.append({'role': '<user>', 'text': context_text})
                    # Process utterances, assuming alternation
                    last_role = conversation[-1]['role'] if conversation else None  # Role of the last added turn
                    for entry in sorted_entries:
                        text = self._clean_text(entry['utterance'])
                        if not text: continue
                        # Determine role based on the *last added* role
                        current_role = '<assistant>' if last_role == '<user>' else '<user>'
                        conversation.append({'role': current_role, 'text': text})
                        last_role = current_role  # Update for next iteration

                    # Apply max turns limit *after* forming the full sequence
                    conversation = conversation[:CONFIG["max_turns"]]
                    if conversation:
                        self.all_processed_conversations.append(conversation)

            except Exception as e:
                logging.error(f"Failed to load or process empathetic_dialogues for training: {e}")

        # --- Process BlendedSkillTalk ---
        if "blended_skill_talk" in CONFIG["datasets"]:
            logging.info("Loading and processing blended_skill_talk dataset...")
            try:
                bst_dataset = load_dataset("blended_skill_talk", split="train", trust_remote_code=True)
                logging.info(f"Processing {len(bst_dataset)} blended_skill_talk conversations...")
                for entry in bst_dataset:
                    conversation = []
                    # Reconstruct dialogue: history + final two turns (if they exist)
                    dialogue_turns_raw = entry['previous_utterance']
                    if entry.get('free_turker_utterance'):
                        dialogue_turns_raw.append(entry['free_turker_utterance'])
                    if entry.get('guided_turker_utterance'):
                        dialogue_turns_raw.append(entry['guided_turker_utterance'])

                    if not dialogue_turns_raw: continue  # Skip if no turns found

                    turns_to_process = dialogue_turns_raw[:CONFIG["max_turns"]]  # Apply max turns limit

                    for i, utterance in enumerate(turns_to_process):
                        role = "<user>" if i % 2 == 0 else "<assistant>"  # Assume simple alternation
                        cleaned_text = self._clean_text(utterance)
                        if cleaned_text:
                            conversation.append({'role': role, 'text': cleaned_text})
                    if conversation:  # Only add if not empty after cleaning/truncation
                        self.all_processed_conversations.append(conversation)
            except Exception as e:
                logging.error(f"Failed to load or process blended_skill_talk for training: {e}")

        # --- Process PersonaChat ---
        if "AlekseyKorshuk/persona-chat" in CONFIG["datasets"]:  # Correct dataset identifier
            pc_dataset_name = "AlekseyKorshuk/persona-chat"
            logging.info(f"Loading and processing {pc_dataset_name} dataset...")
            try:
                pc_dataset = load_dataset(pc_dataset_name, split="train", trust_remote_code=True)
                logging.info(f"Processing {len(pc_dataset)} {pc_dataset_name} conversations...")
                for entry in pc_dataset:
                    conversation = []
                    if 'utterances' in entry and entry['utterances']:
                        # Extract the dialogue history
                        history = entry['utterances'][-1]['history']
                        history = history[:CONFIG["max_turns"]]  # Apply max turns limit

                        for i, utterance in enumerate(history):
                            role = "<user>" if i % 2 == 0 else "<assistant>"  # Simple alternation
                            cleaned_text = self._clean_text(utterance)
                            if cleaned_text:
                                conversation.append({'role': role, 'text': cleaned_text})

                        if conversation:  # Only add if not empty
                            self.all_processed_conversations.append(conversation)
                    else:
                        logging.warning(f"Skipping {pc_dataset_name} entry due to unexpected structure: {entry.keys()}")

            except Exception as e:
                logging.error(f"Failed to load or process {pc_dataset_name} for training: {e}")

        # --- Process Universal-Transformers-RawConvo ---
        if "elapt1c/Universal-Transformers-RawConvo" in CONFIG["datasets"]:
            utrc_dataset_name = "elapt1c/Universal-Transformers-RawConvo"
            logging.info(f"Loading and processing {utrc_dataset_name} dataset...")
            try:
                utrc_dataset = load_dataset(utrc_dataset_name, split="train", trust_remote_code=True)
                logging.info(f"Processing {len(utrc_dataset)} {utrc_dataset_name} conversations...")
                for entry in utrc_dataset:
                    conversation = []
                    dialogue = entry['dialogue']  # Assuming 'dialogue' column contains the conversation turns
                    if isinstance(dialogue, list):  # Check if dialogue is a list of strings
                        dialogue = dialogue[:CONFIG["max_turns"]]
                        for i, utterance in enumerate(dialogue):
                            role = "<user>" if i % 2 == 0 else "<assistant>"
                            cleaned_text = self._clean_text(utterance)
                            if cleaned_text:
                                conversation.append({'role': role, 'text': cleaned_text})
                        if conversation:
                            self.all_processed_conversations.append(conversation)
                    else:
                        logging.warning(f"Skipping {utrc_dataset_name} entry due to unexpected 'dialogue' structure: {type(dialogue)}")

            except Exception as e:
                logging.error(f"Failed to load or process {utrc_dataset_name} for training: {e}")

        # --- Process papahawk/conversational-01 ---
        if "papahawk/conversational-01" in CONFIG["datasets"]:
            pco_dataset_name = "papahawk/conversational-01"
            logging.info(f"Loading and processing {pco_dataset_name} dataset...")
            try:
                pco_dataset = load_dataset(pco_dataset_name, split="train", trust_remote_code=True)
                logging.info(f"Processing {len(pco_dataset)} {pco_dataset_name} conversations...")
                for entry in pco_dataset:
                    prompt = entry['prompt']
                    response = entry['response']
                    cleaned_prompt = self._clean_text(prompt)
                    cleaned_response = self._clean_text(response)
                    if cleaned_prompt and cleaned_response:
                        conversation = [
                            {'role': '<user>', 'text': cleaned_prompt},
                            {'role': '<assistant>', 'text': cleaned_response}
                        ]
                        self.all_processed_conversations.append(conversation)

            except Exception as e:
                logging.error(f"Failed to load or process {pco_dataset_name} for training: {e}")

        logging.info(f"Total processed conversations from all datasets: {len(self.all_processed_conversations)}")
        if not self.all_processed_conversations:
            raise ValueError("No processed conversations were created from any dataset. Check loading logic and dataset availability.")

        logging.info("Shuffling combined dataset...")
        random.shuffle(self.all_processed_conversations)

    def __len__(self):
        return len(self.all_processed_conversations)

    def __getitem__(self, idx):
        conversation = self.all_processed_conversations[idx]
        formatted_ids = [self.bos_id]
        for turn in conversation:
            role_id = self.user_id if turn['role'] == '<user>' else self.assistant_id
            # Encode without adding special tokens automatically by tokenizer
            try:
                utterance_ids = self.tokenizer.encode(turn['text'], add_special_tokens=False).ids
            except Exception as e:
                logging.error(f"Error encoding text at index {idx}, turn '{turn}': {e}")
                utterance_ids = []  # Skip this utterance on error

            # Check length: Current + Role + Utterance + EOS <= MaxLength
            # Need +1 for role, +len(utterance), +1 for potential EOS
            if len(formatted_ids) + 1 + len(utterance_ids) + 1 > self.max_length:
                # Attempt to add just the role and EOS if utterance is too long
                if len(formatted_ids) + 1 + 1 <= self.max_length:
                    formatted_ids.append(role_id)
                    formatted_ids.append(self.eos_id)
                break  # Stop adding turns

            formatted_ids.append(role_id)
            formatted_ids.extend(utterance_ids)
            formatted_ids.append(self.eos_id)

        # Final safety truncate (should be rare if logic above is correct)
        if len(formatted_ids) > self.max_length:
            formatted_ids = formatted_ids[:self.max_length]
            # Ensure last token isn't partial (though unlikely with BPE)
            # If the truncated sequence ends with a role ID, it's probably bad, remove it.
            if formatted_ids and (formatted_ids[-1] == self.user_id or formatted_ids[-1] == self.assistant_id):
                formatted_ids.pop()
            # If after popping the role ID, it's still too long (unlikely), truncate again
            if len(formatted_ids) > self.max_length:
                formatted_ids = formatted_ids[:self.max_length]

        # Handle case of extremely short sequences after processing
        if len(formatted_ids) < 2:  # Need at least BOS and one other token for input/label pair
            logging.warning(f"Sequence at index {idx} is too short after processing (<2 tokens). Skipping. Original length: {len(conversation)}")
            # Return None to be filtered by collate_fn
            return None

        input_ids = formatted_ids[:-1]
        labels = formatted_ids[1:]

        # Final check before returning
        if len(input_ids) == 0:
            logging.warning(f"Sequence at index {idx} resulted in empty input_ids after slicing. Skipping.")
            return None

        return {"input_ids": input_ids, "labels": labels}

    @staticmethod
    def collate_fn(batch):
        # Filter out None items from __getitem__
        batch = [item for item in batch if item is not None]
        if not batch:
            return None  # Return None if the whole batch was invalid

        max_len = max(len(item["input_ids"]) for item in batch)

        # Load tokenizer once to get pad_id - ensure path matches CONFIG
        try:
            # Correctly reference the tokenizer path from CONFIG within the static method
            tokenizer_path = os.path.join("tokenizer", CONFIG["tokenizer_name"])
            # TODO: Consider passing tokenizer/pad_id if this becomes a bottleneck
            tokenizer = Tokenizer.from_file(tokenizer_path)
            pad_id = tokenizer.token_to_id("<pad>")
            if pad_id is None: raise ValueError("<pad> token not found")
        except Exception as e:
            logging.error(f"Collate Error: Failed to load tokenizer or get pad_id ('{CONFIG['tokenizer_name']}'): {e}")
            pad_id = 0  # Risky fallback

        inputs, labels, masks = [], [], []
        for item in batch:
            input_len = len(item["input_ids"])
            pad_len = max_len - input_len
            inputs.append(item["input_ids"] + [pad_id] * pad_len)
            # Pad labels with pad_id (or any ID to be ignored by CrossEntropyLoss)
            labels.append(item["labels"] + [pad_id] * pad_len)
            masks.append([1] * input_len + [0] * pad_len)

        return {
            "input_ids": torch.tensor(inputs, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
            "attention_mask": torch.tensor(masks, dtype=torch.long)  # Or bool
        }
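
# --- Illustrative addition (not part of the original file): padding behaviour of collate_fn.
# Two items of lengths 5 and 3 are both padded to max_len = 5; the attention mask marks real
# tokens, e.g. masks = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]]. Since the trainer's loss uses
# ignore_index=pad_id, padded label positions contribute nothing to the loss.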

# --- Trainer, Safety Manager, Checkpoint Manager ---

class HROMTrainer:
    def __init__(self, model, tokenizer):
        # Determine device based on number of GPUs
        if CONFIG["num_gpus"] > 1 and torch.cuda.device_count() >= CONFIG["num_gpus"]:
            self.device = [torch.device(f"cuda:{i}") for i in range(CONFIG["num_gpus"])]
            logging.info(f"Using DataParallel on devices: {self.device}")
            model = nn.DataParallel(model, device_ids=self.device)  # Wrap model for DP
            self.main_device = self.device[0]  # Main device for logging etc.
        elif torch.cuda.is_available():
            self.device = [torch.device("cuda:0")]  # Single GPU list for consistency
            self.main_device = self.device[0]
            logging.info(f"Using single GPU: {self.main_device}")
        else:
            self.device = [torch.device("cpu")]
            self.main_device = self.device[0]
            logging.info(f"Using CPU: {self.main_device}")

        self.model = model.to(self.main_device)  # Move model to main device

        self.use_amp = (self.main_device.type == "cuda" and hasattr(torch.cuda.amp, "GradScaler"))
        self.scaler = torch.cuda.amp.GradScaler() if self.use_amp else None
        logging.info(f"Automatic Mixed Precision (AMP): {'Enabled' if self.use_amp else 'Disabled'}")

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=CONFIG["learning_rate"],  # Base LR
            betas=(0.9, 0.95),
            weight_decay=0.1,
            fused=(self.main_device.type == "cuda")
        )
        self.tokenizer = tokenizer
        self.pad_id = self.tokenizer.token_to_id("<pad>")
        if self.pad_id is None:
            # Attempt to get from config if available or fallback
            self.pad_id = CONFIG.get("pad_token_id", 0)
            logging.warning(f"<pad> token ID not found in tokenizer, using fallback ID: {self.pad_id}")

        # Make sure ignore_index uses the determined pad_id
        self.criterion = nn.CrossEntropyLoss(ignore_index=self.pad_id)
        self.base_lr = CONFIG["learning_rate"]
        self.warmup_steps = CONFIG["warmup_steps"]

    def _adjust_learning_rate(self, step):
        if self.warmup_steps > 0 and step < self.warmup_steps:
            lr = self.base_lr * (step + 1) / self.warmup_steps
        else:
            # Optional: Add LR decay (e.g., cosine) after warmup
            # Example: lr = self.base_lr * 0.5 * (1 + math.cos(math.pi * (step - self.warmup_steps) / (total_steps - self.warmup_steps)))
            lr = self.base_lr  # Keep base LR after warmup for now
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return lr

    def train_step(self, batch):
        # Determine precision for autocast
        if self.use_amp:
            # Use torch.float16 (FP16) for mixed precision training
            amp_dtype = torch.float16  # FP16 training explicitly requested
        autocast_context = torch.cuda.amp.autocast(dtype=amp_dtype, enabled=self.use_amp) if self.use_amp else nullcontext()

        with autocast_context:
            input_ids = batch["input_ids"].to(self.main_device)  # Move to main device
            attention_mask = batch["attention_mask"].to(self.main_device)  # Move to main device
            labels = batch["labels"].to(self.main_device)  # Move to main device

            outputs = self.model(input_ids, attention_mask=attention_mask)

            # Reshape for loss calculation
            logits_flat = outputs.view(-1, outputs.size(-1))  # Shape: (B * T, vocab_size)
            labels_flat = labels.view(-1)  # Shape: (B * T)

            # Calculate loss - ensure logits are float32 for stability esp. with AMP
            loss = self.criterion(logits_flat.float(), labels_flat)

        # Scale loss for gradient accumulation
        scaled_loss = loss / CONFIG["grad_accum_steps"]

        # Backward pass
        if self.use_amp and self.scaler:
            self.scaler.scale(scaled_loss).backward()
        else:
            scaled_loss.backward()

        return loss.item()  # Return the unscaled loss for logging

    def clip_and_step(self, current_optimizer_step):
        current_lr = self._adjust_learning_rate(current_optimizer_step)
        # Gradient Clipping *before* optimizer step
        if self.use_amp and self.scaler:
            # Unscale first - important before clipping
            self.scaler.unscale_(self.optimizer)
            # Clip grad norm
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            # Optimizer step (with scaler)
            self.scaler.step(self.optimizer)
            # Update scaler for next iteration
            self.scaler.update()
        else:
            # Clip grad norm
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            # Optimizer step
            self.optimizer.step()

        # Zero gradients *after* stepping
        self.optimizer.zero_grad(set_to_none=True)
        return current_lr
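
    # --- Illustrative addition (not part of the original file): a hypothetical sketch of how
    # train_step and clip_and_step fit together for gradient accumulation (the real training
    # loop lives elsewhere; `dataloader` and `optimizer_step` are assumed names):
    #   for i, batch in enumerate(dataloader):
    #       if batch is None:                          # collate_fn may return None
    #           continue
    #       loss = trainer.train_step(batch)           # backward() on loss / grad_accum_steps
    #       if (i + 1) % CONFIG["grad_accum_steps"] == 0:
    #           lr = trainer.clip_and_step(optimizer_step)   # clip, step, zero_grad
    #           optimizer_step += 1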


class SafetyManager:
    def __init__(self, model, tokenizer, main_device):  # Added main_device argument
        self.model = model
        self.tokenizer = tokenizer
        self.main_device = main_device  # Store main_device
        # More conservative list
        self.bad_words = ["kill", "murder", "suicide", "hate", "abuse", "violence", "illegal", "harm", "die", "attack", "rape", "molest", "exploit", "terror"]
        self.bad_word_ids = []
        logging.info("Initializing safety manager...")
        # Pre-encode bad word sequences
        for word in self.bad_words:
            # Encode potentially multi-token words carefully
            ids = tokenizer.encode(f" {word}", add_special_tokens=False).ids  # Add prefix space for BPE
            if ids:
                self.bad_word_ids.append(ids)
                logging.debug(f"Encoded bad word '{word}' (with space) to IDs: {ids}")
            # Try without space too
            ids_no_space = tokenizer.encode(word, add_special_tokens=False).ids
            if ids_no_space and ids_no_space != ids:
                self.bad_word_ids.append(ids_no_space)
                logging.debug(f"Encoded bad word '{word}' (no space) to IDs: {ids_no_space}")

            if not ids and not ids_no_space:
                logging.warning(f"Could not encode bad word '{word}' - skipping.")

        # Pre-get special IDs
        self.eos_id = self.tokenizer.token_to_id("</s>")
        self.bos_id = self.tokenizer.token_to_id("<s>")
        self.user_id = self.tokenizer.token_to_id("<user>")
        self.assistant_id = self.tokenizer.token_to_id("<assistant>")
        self.pad_id = self.tokenizer.token_to_id("<pad>")

        if self.eos_id is None: logging.error("</s> token ID not found for SafetyManager!"); self.eos_id = 0
        if self.bos_id is None: logging.error("<s> token ID not found for SafetyManager!"); self.bos_id = 0
        if self.user_id is None: logging.error("<user> token ID not found for SafetyManager!")
        if self.assistant_id is None: logging.error("<assistant> token ID not found for SafetyManager!")
        if self.pad_id is None: logging.error("<pad> token ID not found for SafetyManager!"); self.pad_id = 0

    def contains_sequence(self, tokens, seq):
        """Checks if the list `tokens` contains the sublist `seq`."""
        if not seq or not tokens or len(tokens) < len(seq):
            return False
        seq_len = len(seq)
        for i in range(len(tokens) - seq_len + 1):
            if tokens[i : i + seq_len] == seq:
                return True
        return False

    def content_filter(self, text_ids):
        """Checks if a list of token IDs contains any bad word sequences."""
        if not isinstance(text_ids, list):
            logging.warning("Content filter received non-list input.")
            return True  # Default to safe if input is weird
        for bad_ids in self.bad_word_ids:
            if self.contains_sequence(text_ids, bad_ids):
                # Log the detected sequence for debugging
                detected_word = self.tokenizer.decode(bad_ids)
                logging.warning(f"Unsafe content detected: Found sequence corresponding to '{detected_word}' (IDs: {bad_ids}).")
                return False  # Unsafe
        return True  # Safe

    def generate_safely(self, prompt, max_new_tokens=50, temperature=0.5, top_k=50):
        self.model.eval()
        device = self.main_device  # Use stored main_device for generation
        if isinstance(self.main_device, list):  # Handle DataParallel case for generation device - SHOULD BE self.main_device now
            device = self.main_device[0]  # Ensure generation happens on the main GPU

        # Encode prompt, ensure it ends appropriately (e.g., with role token + EOS?)
        # Let's assume the prompt ends like "<user> blah blah </s>" and we need to add "<assistant>"
        prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False).ids

        # Start generation sequence with BOS, prompt, and assistant token
        # Ensure prompt doesn't already include BOS
        if prompt_ids and prompt_ids[0] == self.bos_id:
            input_ids = list(prompt_ids)
        else:
            input_ids = [self.bos_id] + list(prompt_ids)

        # Add the assistant token to signal the model to generate the response
        if self.assistant_id is not None:
            input_ids.append(self.assistant_id)
        else:
            logging.error("Assistant token ID is None, cannot properly start generation.")
            return "Error: Assistant token not found."

        generated_ids = list(input_ids)  # Start with the prepared input sequence
        logging.debug(f"Starting safe generation with initial IDs: {generated_ids}")

        with torch.no_grad():
            for step in range(max_new_tokens):
                # Prepare input tensor for this step - only use up to max_seq_len
                current_input_ids = generated_ids[-CONFIG["max_seq_len"]:]
                current_input_tensor = torch.tensor([current_input_ids]).to(device)
                # Create attention mask for the current length
                attention_mask = torch.ones_like(current_input_tensor)

                # Model forward pass
                try:
                    outputs = self.model(current_input_tensor, attention_mask=attention_mask)
                    next_token_logits = outputs[:, -1, :]  # Logits for the next token
                    if isinstance(self.model, nn.DataParallel):  # Handle DP output
                        next_token_logits = next_token_logits.mean(dim=0, keepdim=True)  # Average logits from DP
                except Exception as e:
                    logging.error(f"Model forward pass failed during generation: {e}")
                    break  # Stop generation on error

                # --- Safety Check BEFORE sampling ---
                # Apply penalties to bad word starting tokens if possible
                # For now, we filter *after* sampling the token

                # Sampling (Temperature, Top-K)
                if temperature > 0 and temperature != 1.0:
                    next_token_logits = next_token_logits / temperature
                if top_k > 0 and top_k < next_token_logits.size(-1):  # Ensure top_k is valid
                    v, _ = torch.topk(next_token_logits, top_k)
                    # Handle potential NaN/Inf in logits before comparison
                    safe_logits = torch.nan_to_num(next_token_logits, nan=-float('inf'), posinf=float('inf'), neginf=-float('inf'))
                    threshold = v[:, [-1]]
                    safe_logits[safe_logits < threshold] = -float('Inf')
                    next_token_logits = safe_logits  # Use the filtered logits

                probs = torch.softmax(next_token_logits, dim=-1)
                # Handle potential NaNs in probabilities before sampling
                if torch.isnan(probs).any():
                    logging.warning("NaN detected in probabilities before sampling. Replacing with uniform distribution.")
                    probs = torch.ones_like(probs) / probs.size(-1)  # Fallback to uniform

                next_token_id = torch.multinomial(probs, num_samples=1).item()

                # --- Safety Check AFTER sampling token ---
                # Check if adding this token creates a bad sequence
                potential_sequence_ids = generated_ids + [next_token_id]
                # Check only the newly formed part for bad words for efficiency?
                # Let's check the whole sequence for simplicity/robustness for now.
                if not self.content_filter(potential_sequence_ids):
                    logging.warning(f"Potential unsafe token ({next_token_id}, '{self.tokenizer.decode([next_token_id])}') blocked POST-sampling. Stopping generation.")
                    # Optionally try sampling a different token? For now, just stop.
                    break

                # Add the safe token
                generated_ids.append(next_token_id)

                # Check for EOS token
                if next_token_id == self.eos_id:
                    logging.debug(f"EOS token generated at step {step+1}. Stopping generation.")
                    break

                # Prevent infinite loops if max tokens reached
                if step == max_new_tokens - 1:
                    logging.debug("Max new tokens reached. Stopping generation.")
                    # Ensure the sequence ends with EOS if it didn't naturally
                    if generated_ids[-1] != self.eos_id and self.eos_id is not None:
                        generated_ids.append(self.eos_id)

        self.model.train()  # Set model back to training mode

        # Decode the generated part (excluding the initial prompt + assistant token)
        start_index = len(input_ids)
        response_ids = generated_ids[start_index:]

        # Decode, skipping special tokens like EOS, BOS, PAD but potentially keeping USER/ASSISTANT
        # Let's skip all special tokens for the final output text for clarity.
        decoded_text = self.tokenizer.decode(response_ids, skip_special_tokens=True).strip()

        return decoded_text
|
1049 |
+
|
1050 |
+
|
1051 |
+
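    # Hedged usage sketch (not part of the training loop): given a SafetyManager built with a
    # trained model and tokenizer, a call such as
    #   reply = safety.generate_safely("<user> How are you? </s>", max_new_tokens=40,
    #                                  temperature=0.7, top_k=50)
    # prepends <s> and <assistant>, samples token by token with content_filter applied to every
    # candidate continuation, and returns the decoded reply with special tokens stripped.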
    def debug_generation(self, prompt="<user> Tell me about your hobbies."):  # Example prompt
        logging.info("\n--- Debug Generation & Safety Check ---")
        # Ensure the prompt ends logically for the model (with an EOS after the user turn);
        # generate_safely adds <s> itself, so any leading BOS is stripped here.
        if not prompt.strip().endswith("</s>"):
            prompt = prompt.strip() + " </s>"
        if prompt.startswith("<s>"):
            prompt = prompt[len("<s>"):].strip()

        generated_response = self.generate_safely(prompt, max_new_tokens=60, temperature=0.7, top_k=50)

        logging.info(f"Prompt Sent: '{prompt}'")
        logging.info(f"Generated Response: '{generated_response}'")
        logging.info("\n--- End Debug Generation ---\n")


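# Hedged usage sketch for the manager below: CheckpointManager().save(model, optimizer, 1000)
# writes a file like "hrom_<prefix>_step1000_<YYYYMMDD_HHMMSS>.pt" into CONFIG["checkpoint_dir"]
# (the prefix is derived from the directory name), while load_latest(model, optimizer) restores
# the newest matching file and returns the optimizer step to resume from.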
class CheckpointManager:
    def __init__(self):
        # Use checkpoint directory from CONFIG
        self.checkpoint_dir = CONFIG["checkpoint_dir"]
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        logging.info(f"Checkpoint directory set to: {self.checkpoint_dir}")

    def save(self, model, optimizer, step):
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        # Use a consistent naming scheme based on the directory name
        prefix = os.path.basename(self.checkpoint_dir).replace("checkpoints_", "")
        # Step may be a non-numeric label (e.g. 'final'), so convert it to a string for the filename
        step_str = str(step)
        filename = f"hrom_{prefix}_step{step_str}_{timestamp}.pt"
        path = os.path.join(self.checkpoint_dir, filename)
        state = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "step": step if isinstance(step, int) else -1,  # Store step number or -1 for non-numeric steps
            "config": CONFIG  # Save config with checkpoint
        }
        logging.info(f"Saving checkpoint to {path}...")
        try:
            torch.save(state, path)
            logging.info(f"Checkpoint saved successfully at step {step_str}.")
            self._cleanup_old_checkpoints()
        except Exception as e:
            logging.error(f"Failed to save checkpoint '{path}': {e}")

    def _cleanup_old_checkpoints(self):
        max_checkpoints = CONFIG.get("max_checkpoints", 5)  # Get from config, default 5
        if max_checkpoints <= 0:
            return  # Keep all checkpoints if max_checkpoints is non-positive

        try:
            # Only consider files matching the expected pattern (avoid deleting other files)
            prefix = os.path.basename(self.checkpoint_dir).replace("checkpoints_", "")
            pattern = re.compile(rf"hrom_{prefix}_step(\d+|.+)_(\d{{8}}_\d{{6}})\.pt")

            checkpoints = []
            for f in os.listdir(self.checkpoint_dir):
                match = pattern.match(f)
                if match:
                    filepath = os.path.join(self.checkpoint_dir, f)
                    checkpoints.append((filepath, os.path.getmtime(filepath)))

            # Sort by modification time (oldest first)
            checkpoints.sort(key=lambda x: x[1])

            num_to_delete = len(checkpoints) - max_checkpoints
            if num_to_delete > 0:
                # logging.info(f"Max checkpoints ({max_checkpoints}) reached. Removing {num_to_delete} oldest checkpoints.")
                for i in range(num_to_delete):
                    file_to_remove, _ = checkpoints[i]
                    try:
                        os.remove(file_to_remove)
                        # logging.info(f"Removed old checkpoint: {os.path.basename(file_to_remove)}")
                    except OSError as e:
                        logging.error(f"Error removing checkpoint {file_to_remove}: {e}")
        except Exception as e:
            logging.error(f"Error during checkpoint cleanup: {e}")


    def load_latest(self, model, optimizer):
        try:
            # Filter files based on the naming pattern and sort by time
            prefix = os.path.basename(self.checkpoint_dir).replace("checkpoints_", "")
            pattern = re.compile(rf"hrom_{prefix}_step(\d+|.+)_(\d{{8}}_\d{{6}})\.pt")
            checkpoints = []
            for f in os.listdir(self.checkpoint_dir):
                match = pattern.match(f)
                if match:
                    filepath = os.path.join(self.checkpoint_dir, f)
                    checkpoints.append((filepath, os.path.getmtime(filepath)))

            if not checkpoints:
                logging.info("No valid checkpoints found to load.")
                return 0  # Start from step 0

            # Sort by modification time (newest first)
            checkpoints.sort(key=lambda x: x[1], reverse=True)

            latest_checkpoint_path, _ = checkpoints[0]
            logging.info(f"Loading latest checkpoint from: {latest_checkpoint_path}")
            map_location = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            checkpoint = torch.load(latest_checkpoint_path, map_location=map_location)

            # --- Config Compatibility Check (Optional but Recommended) ---
            loaded_config = checkpoint.get("config", {})
            # Compare key parameters that affect model architecture or data processing
            critical_keys = ["dim", "n_layers", "n_heads", "ff_dim", "vocab_size", "max_seq_len", "tokenizer_name"]
            mismatched_keys = []
            if loaded_config:
                for key in critical_keys:
                    # Key present in both configs but with different values
                    if key in loaded_config and key in CONFIG and loaded_config[key] != CONFIG[key]:
                        mismatched_keys.append((key, loaded_config[key], CONFIG[key]))
                    # Key present in the checkpoint but missing from the current config
                    elif key in loaded_config and key not in CONFIG:
                        mismatched_keys.append((key, loaded_config[key], "Not in current CONFIG"))
                    # Key missing from the checkpoint config but present in the current one
                    elif key not in loaded_config and key in CONFIG:
                        mismatched_keys.append((key, "Not in loaded CONFIG", CONFIG[key]))

                if mismatched_keys:
                    logging.warning("--- CONFIG MISMATCH DETECTED ---")
                    logging.warning(f"Checkpoint '{os.path.basename(latest_checkpoint_path)}' was saved with different critical parameters:")
                    for key, loaded_val, current_val in mismatched_keys:
                        logging.warning(f"  - {key}: Checkpoint='{loaded_val}', Current='{current_val}'")
                    # Decide whether to proceed: raise an error, warn, or try anyway.
                    # For now, just warn strongly; loading might fail or lead to issues.
                    logging.warning("Proceeding with loading, but results may be unexpected or errors may occur.")
            else:
                logging.warning("Checkpoint does not contain configuration info. Cannot check compatibility.")
            # --- End Config Check ---

            try:
                # strict=False can sometimes help load partially, but hides potential issues
                model.load_state_dict(checkpoint['model'], strict=True)
            except RuntimeError as e:
                logging.error(f"Failed to load model state_dict: {e}")
                logging.error("This often happens due to architecture mismatch (check CONFIG) or corrupted checkpoint.")
                logging.error("Starting training from scratch.")
                return 0  # Cannot resume if model loading fails

            try:
                optimizer.load_state_dict(checkpoint['optimizer'])
            except ValueError as e:
                logging.warning(f"Could not load optimizer state_dict: {e}. Optimizer state will be reset.")
                # Resetting the optimizer state is safer than a partial load if parameters changed
                optimizer.state = defaultdict(dict)  # Reset state
                logging.warning("Optimizer state reset.")
            except Exception as e:
                logging.error(f"Unexpected error loading optimizer state: {e}. Starting training from scratch.")
                return 0

            start_step = checkpoint.get('step', 0)
            # Ensure the step is non-negative and resume from the next step
            start_step = max(0, start_step) + 1 if isinstance(start_step, int) else 0

            logging.info(f"Checkpoint loaded successfully. Resuming from optimizer step {start_step}.")
            # Move optimizer state tensors to the correct device
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        try:
                            state[k] = v.to(map_location)
                        except Exception as e:
                            logging.error(f"Failed to move optimizer tensor '{k}' to device '{map_location}': {e}")
            return start_step

        except FileNotFoundError:
            logging.info(f"No checkpoint directory '{self.checkpoint_dir}' or files found. Starting training from scratch.")
            return 0
        except Exception as e:
            logging.error(f"Error loading checkpoint from '{self.checkpoint_dir}': {e}. Starting training from scratch.")
            # Re-initializing model/optimizer might be safer depending on where the error occurred;
            # for simplicity, just return 0 here.
            return 0


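# Note on resume semantics (derived from load_latest above): the returned value is the next
# optimizer step to run (saved step + 1), or 0 when nothing usable was found; train() below
# converts it back to a batch position via CONFIG["grad_accum_steps"].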
# --- Training Function ---

def train():
    logging.info(f"Starting HROM training process on combined datasets (daily_dialog, empathetic_dialogues, blended_skill_talk, AlekseyKorshuk/persona-chat, elapt1c/Universal-Transformers-RawConvo, papahawk/conversational-01) using {CONFIG['num_gpus']} GPUs...")
    logging.info(f"Configuration: {CONFIG}")

    # --- Tokenizer Setup ---
    tokenizer_trainer = TokenizerTrainer()
    tokenizer_path = tokenizer_trainer.tokenizer_path
    if not os.path.exists(tokenizer_path):
        logging.info(f"Combined tokenizer '{CONFIG['tokenizer_name']}' not found. Training tokenizer...")
        try:
            # load_dataset calls inside tokenizer training pass trust_remote_code=True
            tokenizer_trainer.train(CONFIG["datasets"])
        except Exception as e:
            logging.error(f"Failed during tokenizer training: {e}", exc_info=True)
            return  # Cannot proceed without a tokenizer
    else:
        logging.info(f"Loading existing combined tokenizer from {tokenizer_path}")

    # Load the tokenizer instance *once* here for shared use
    try:
        tokenizer = tokenizer_trainer.get_tokenizer()
        # Update CONFIG with actual token IDs (useful for downstream code)
        CONFIG['pad_token_id'] = tokenizer.token_to_id("<pad>")
        CONFIG['bos_token_id'] = tokenizer.token_to_id("<s>")
        CONFIG['eos_token_id'] = tokenizer.token_to_id("</s>")
        logging.info(f"Loaded tokenizer. Vocab size: {tokenizer.get_vocab_size()}. Special IDs: PAD={CONFIG['pad_token_id']}, BOS={CONFIG['bos_token_id']}, EOS={CONFIG['eos_token_id']}")
    except (FileNotFoundError, ValueError) as e:
        logging.error(f"Failed to load tokenizer: {e}. Cannot continue.")
        return

    # --- Model Initialization ---
    logging.info("Initializing HROM model...")
    # Ensure vocab_size in config matches the tokenizer
    if CONFIG['vocab_size'] != tokenizer.get_vocab_size():
        logging.warning(f"Config vocab_size ({CONFIG['vocab_size']}) differs from tokenizer vocab size ({tokenizer.get_vocab_size()}). Using tokenizer's size.")
        CONFIG['vocab_size'] = tokenizer.get_vocab_size()
    model = HROM()

    # --- Calculate and Log Model Parameters ---
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logging.info(f"Model initialized. Total parameters: {total_params:,}")
    logging.info(f"Trainable parameters: {trainable_params:,}")
    logging.info(f"Parameters (Millions): Total={total_params/1e6:.2f}M, Trainable={trainable_params/1e6:.2f}M")

    # --- Dataset and DataLoader ---
    logging.info("Setting up combined dataset and dataloader...")
    try:
        logging.info("Pre-loading/caching datasets...")
        for ds_name in CONFIG["datasets"]:
            logging.info(f"Checking cache for '{ds_name}'...")
            try:
                # Load just the first example to trigger the download/cache check
                _ = load_dataset(ds_name, split="train[:1]", download_mode="reuse_cache_if_exists", trust_remote_code=True)
            except Exception as e:
                # Log the error but try to continue; main dataset loading will handle the final error
                logging.error(f"Could not pre-check dataset '{ds_name}': {e}")
        logging.info("Dataset download/cache check presumed complete.")

        # Pass the already loaded tokenizer instance
        dataset = CombinedChatDataset(tokenizer)

        # Check if the dataset is empty after processing
        if len(dataset) == 0:
            logging.error("Dataset is empty after processing all sources. Cannot train.")
            return

        dataloader = DataLoader(
            dataset,
            batch_size=CONFIG["batch_size"],
            collate_fn=CombinedChatDataset.collate_fn,  # Use the static method
            shuffle=True,
            # Adjust num_workers based on available cores, but stay conservative
            num_workers=min(4 * CONFIG["num_gpus"], os.cpu_count() // 2 if (os.cpu_count() and os.cpu_count() > 1) else 1),
            pin_memory=torch.cuda.is_available(),
            prefetch_factor=2 if torch.cuda.is_available() and os.cpu_count() and os.cpu_count() > 1 else None,
            drop_last=False  # Keep the last batch even if smaller
        )
    except Exception as e:
        logging.error(f"Failed to initialize dataset/dataloader: {e}", exc_info=True)
        return

    # --- Trainer, Checkpoint, Safety ---
    logging.info("Initializing Trainer, Checkpoint Manager, and Safety Manager...")
    # Pass the loaded tokenizer instance and main_device to the SafetyManager
    trainer_obj = HROMTrainer(model, tokenizer)
    checkpoint_manager = CheckpointManager()  # Uses CONFIG["checkpoint_dir"]
    safety = SafetyManager(trainer_obj.model, tokenizer, trainer_obj.main_device)

    # --- Load Checkpoint ---
    start_optimizer_step = checkpoint_manager.load_latest(trainer_obj.model, trainer_obj.optimizer)  # Pass the (potentially DP-wrapped) model
    # The model is already on the correct device after loading (handled in trainer init)

    # --- Training Loop ---
    logging.info(f"Starting training from optimizer step {start_optimizer_step}")
    optimizer_step = start_optimizer_step
    total_loss_accum = 0.0
    # Calculate the starting batch step from the loaded optimizer step and grad accumulation
    batch_step = optimizer_step * CONFIG["grad_accum_steps"]
    epochs_completed = batch_step // len(dataloader) if len(dataloader) > 0 else 0
    start_epoch = epochs_completed  # Start from the epoch corresponding to the loaded step

    # Estimate total steps (useful for LR scheduling if implementing decay)
    try:
        if len(dataloader) == 0:
            raise ValueError("DataLoader has zero length. Cannot estimate total steps.")
        total_optimizer_steps = (len(dataloader) * CONFIG["num_epochs"]) // CONFIG["grad_accum_steps"]
        logging.info(f"Estimated dataset size: {len(dataset)}")
        logging.info(f"Estimated batches per epoch: {len(dataloader)}")
        logging.info(f"Gradient Accumulation Steps: {CONFIG['grad_accum_steps']}")
        logging.info(f"Effective Batch Size: {CONFIG['batch_size'] * CONFIG['grad_accum_steps']}")
        logging.info(f"Target Epochs: {CONFIG['num_epochs']}")
        logging.info(f"Estimated total optimizer steps for {CONFIG['num_epochs']} epochs: {total_optimizer_steps}")
    except Exception as e:
        logging.warning(f"Could not accurately estimate dataloader length or total steps: {e}")
        total_optimizer_steps = -1  # Indicate unknown total steps
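    # Illustrative arithmetic (numbers made up, not from CONFIG): with 10,000 batches per epoch,
    # 2 epochs, and grad_accum_steps=8, this gives (10000 * 2) // 8 = 2500 total optimizer steps,
    # and the effective batch size is batch_size * 8.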

    trainer_obj.model.train()  # Ensure the model is in training mode

    for epoch in range(start_epoch, CONFIG["num_epochs"]):
        logging.info(f"--- Starting Epoch {epoch+1}/{CONFIG['num_epochs']} ---")
        epoch_loss = 0.0
        num_batches_in_epoch = 0

        for i, batch in enumerate(dataloader):
            # Check if the batch is valid (collate_fn might return None)
            if batch is None:
                logging.warning(f"Skipping empty batch at step {i} in epoch {epoch+1}")
                continue

            # Forward and backward pass (scaled loss)
            loss = trainer_obj.train_step(batch)
            if loss is None or torch.isnan(torch.tensor(loss)) or torch.isinf(torch.tensor(loss)):
                logging.error(f"NaN, Inf, or None loss detected: {loss}. Epoch {epoch+1}, Batch {i}, Opt Step {optimizer_step}. Stopping.")
                # Try saving an error checkpoint before exiting
                checkpoint_manager.save(trainer_obj.model, trainer_obj.optimizer, f"{optimizer_step}_error")
                return

            total_loss_accum += loss
            epoch_loss += loss
            num_batches_in_epoch += 1
            batch_step += 1  # Increment the global batch counter (tracks batches processed)

            # Gradient accumulation check: perform an optimizer step every grad_accum_steps batches
            if batch_step % CONFIG["grad_accum_steps"] == 0:
                current_lr = trainer_obj.clip_and_step(optimizer_step)  # Pass the current opt step for the LR schedule

                # Average loss over the accumulation window for logging
                avg_loss = total_loss_accum / CONFIG["grad_accum_steps"]
                total_loss_accum = 0.0  # Reset the loss accumulator

                # Logging
                if optimizer_step % CONFIG["debug_interval"] == 0:
                    logging.info(f"Epoch {epoch+1} | Opt Step {optimizer_step} | Batch Step {batch_step} | Avg Loss: {avg_loss:.4f} | LR: {current_lr:.2e} | Device: {trainer_obj.main_device}")
                    # Trigger debug generation less frequently
                    if optimizer_step % (CONFIG["debug_interval"] * 5) == 0:  # e.g., every 5 debug intervals
                        safety.debug_generation("<user> Hi there! How are you doing today?")  # Generic debug prompt

                # Checkpointing
                if optimizer_step > 0 and optimizer_step % CONFIG["checkpoint_interval"] == 0:
                    logging.info(f"Checkpoint interval reached at optimizer step {optimizer_step}.")
                    checkpoint_manager.save(trainer_obj.model, trainer_obj.optimizer, optimizer_step)
                    # Optional: run a generation check after saving the checkpoint
                    safety.debug_generation("<user> Hi! How are you?")

                optimizer_step += 1  # Increment the optimizer step count *after* performing the step
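                # Bookkeeping note: batch_step counts every batch seen, while optimizer_step only
                # advances once per grad_accum_steps batches, so checkpoints and the LR schedule
                # are keyed to optimizer steps rather than raw batches.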

        # --- End of Epoch ---
        avg_epoch_loss = epoch_loss / num_batches_in_epoch if num_batches_in_epoch > 0 else 0
        logging.info(f"--- Finished Epoch {epoch+1}/{CONFIG['num_epochs']} | Average Epoch Loss: {avg_epoch_loss:.4f} ---")

        # Save a checkpoint at the end of each epoch
        checkpoint_manager.save(trainer_obj.model, trainer_obj.optimizer, f"epoch{epoch+1}_step{optimizer_step}")
        # Optionally run debug generation at the end of the epoch
        safety.debug_generation("<user> Hi! What's up?")


    logging.info(f"Training finished after {CONFIG['num_epochs']} target epochs.")
    # Final save
    logging.info("Saving final model state...")
    checkpoint_manager.save(trainer_obj.model, trainer_obj.optimizer, f"final_step{optimizer_step}")


if __name__ == "__main__":
    # Ensures imports happen after the environment variable is set when the script is run directly
    train()