Defetya committed (verified)
Commit 993bee6 · Parent: a584f1f

Upload train.py with huggingface_hub

Files changed (1):
  train.py (+450, −0)
train.py ADDED
@@ -0,0 +1,450 @@
# ==============================================================================
# 1. IMPORTS
# ==============================================================================
import os
import warnings
import wandb

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import numpy as np
from tqdm import tqdm
from rdkit import Chem, RDLogger
from datasets import load_dataset, load_from_disk
from transformers import AutoTokenizer, BertModel, BertConfig
import pandas as pd

# ==============================================================================
# 2. INITIAL SETUP
# ==============================================================================
# Suppress RDKit console output
RDLogger.DisableLog('rdApp.*')
# Ignore warnings for cleaner output
warnings.filterwarnings("ignore")

# ==============================================================================
# 3. MODEL AND LOSS FUNCTION
# ==============================================================================
def global_average_pooling(x):
    """Global Average Pooling: from [B, max_len, hid_dim] to [B, hid_dim]."""
    return torch.mean(x, dim=1)

class SimSonEncoder(nn.Module):
    """The main encoder model based on BERT."""
    def __init__(self, config: BertConfig, max_len: int, dropout: float = 0.1):
        super(SimSonEncoder, self).__init__()
        self.bert = BertModel(config, add_pooling_layer=False)
        self.linear = nn.Linear(config.hidden_size, max_len)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, attention_mask=None):
        if attention_mask is None:
            attention_mask = input_ids.ne(self.bert.config.pad_token_id)

        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = self.dropout(outputs.last_hidden_state)
        pooled_output = global_average_pooling(hidden_states)
        return self.linear(pooled_output)

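# Shape walkthrough: input_ids [B, L] -> BERT last_hidden_state [B, L, hidden]
# -> global average pooling [B, hidden] -> linear projection [B, max_len]
# (max_len doubles as the output embedding dimension here).
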
class ContrastiveLoss(nn.Module):
    """Calculates the NT-Xent contrastive loss for the SimSon model."""
    def __init__(self, temperature=0.2):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature
        self.similarity_fn = F.cosine_similarity

    def forward(self, proj_1, proj_2):
        batch_size = proj_1.shape[0]
        device = proj_1.device

        # Normalize projections
        z_i = F.normalize(proj_1, p=2, dim=1)
        z_j = F.normalize(proj_2, p=2, dim=1)

        # Concatenate for the similarity-matrix calculation
        representations = torch.cat([z_i, z_j], dim=0)

        # Cosine similarity between all pairs: [2B, 2B]
        similarity_matrix = self.similarity_fn(representations.unsqueeze(1), representations.unsqueeze(0), dim=2)

        # Positive pairs (an original and its augmentation) sit on the
        # +batch_size and -batch_size off-diagonals
        sim_ij = torch.diag(similarity_matrix, batch_size)
        sim_ji = torch.diag(similarity_matrix, -batch_size)
        positives = torch.cat([sim_ij, sim_ji], dim=0)

        # Mask out self-comparisons on the main diagonal
        numerator = torch.exp(positives / self.temperature)
        mask = (~torch.eye(batch_size * 2, batch_size * 2, dtype=torch.bool, device=device)).float()
        denominator = mask * torch.exp(similarity_matrix / self.temperature)

        # Average the per-sample losses over both views
        loss = -torch.log(numerator / torch.sum(denominator, dim=1))
        return torch.sum(loss) / (2 * batch_size)

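# Usage sketch (names assumed): proj_1[i] and proj_2[i] must come from two
# augmentations of the SAME molecule:
#   criterion = ContrastiveLoss(temperature=0.2)
#   loss = criterion(model(ids_1, mask_1), model(ids_2, mask_2))
# Matching pairs act as positives; the other 2*(B-1) samples in the batch act
# as negatives.
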
# ==============================================================================
# 4. DATA HANDLING
# ==============================================================================
class SmilesEnumerator:
    """Generates randomized SMILES strings for data augmentation."""
    def randomize_smiles(self, smiles):
        try:
            mol = Chem.MolFromSmiles(smiles)
            return Chem.MolToSmiles(mol, doRandom=True, canonical=False) if mol else smiles
        except Exception:
            # Fall back to the input string if RDKit cannot process it
            return smiles

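# Example (illustrative): randomize_smiles('CCO') may return 'OCC' or 'C(O)C';
# all are valid SMILES for ethanol with different atom orderings, giving the
# two "views" needed for contrastive learning.
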
class ContrastiveSmilesDataset(Dataset):
    """Dataset for creating pairs of augmented SMILES for contrastive learning."""
    def __init__(self, smiles_list, tokenizer, max_length=512):
        self.smiles_list = smiles_list
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.enumerator = SmilesEnumerator()

    def __len__(self):
        return len(self.smiles_list)

    def __getitem__(self, idx):
        original_smiles = self.smiles_list[idx]

        # Create two different augmentations of the same SMILES
        smiles_1 = self.enumerator.randomize_smiles(original_smiles)
        smiles_2 = self.enumerator.randomize_smiles(original_smiles)

        # Tokenize and pad to max_length here; with padding='max_length' the
        # collate_fn does not need to re-pad.
        tokens_1 = self.tokenizer(smiles_1, max_length=self.max_length, truncation=True, padding='max_length')
        tokens_2 = self.tokenizer(smiles_2, max_length=self.max_length, truncation=True, padding='max_length')

        return {
            'input_ids_1': torch.tensor(tokens_1['input_ids']),
            'attention_mask_1': torch.tensor(tokens_1['attention_mask']),
            'input_ids_2': torch.tensor(tokens_2['input_ids']),
            'attention_mask_2': torch.tensor(tokens_2['attention_mask']),
        }

class PrecomputedContrastiveSmilesDataset(Dataset):
    """
    A Dataset that reads pre-augmented SMILES pairs from a Parquet file.
    This is significantly faster, as it offloads the expensive SMILES
    randomization to a one-time preprocessing step.
    """
    def __init__(self, tokenizer, file_path: str, max_length: int = 512):
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Load the entire dataset from the Parquet file into memory.
        # This is fast and efficient for subsequent access.
        print(f"Loading pre-computed data from {file_path}...")
        self.data = pd.read_parquet(file_path)
        print("Data loaded successfully.")

    def __len__(self):
        """Returns the total number of pairs in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Retrieves a pre-augmented pair, tokenizes it, and returns it
        in the format expected by the DataCollator.
        """
        # Retrieve the pre-augmented pair from the DataFrame
        row = self.data.iloc[idx]
        smiles_1 = row['smiles_1']
        smiles_2 = row['smiles_2']

        # Tokenize the pair. This operation is fast and stays in the data loader.
        tokens_1 = self.tokenizer(smiles_1, max_length=self.max_length, truncation=True, padding='max_length')
        tokens_2 = self.tokenizer(smiles_2, max_length=self.max_length, truncation=True, padding='max_length')

        return {
            'input_ids_1': torch.tensor(tokens_1['input_ids']),
            'attention_mask_1': torch.tensor(tokens_1['attention_mask']),
            'input_ids_2': torch.tensor(tokens_2['input_ids']),
            'attention_mask_2': torch.tensor(tokens_2['attention_mask']),
        }

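# Assumed Parquet schema (inferred from the column access above): one row per
# pre-augmented pair, e.g.
#   smiles_1       smiles_2
#   'OCC'          'C(O)C'        # two renderings of ethanol
#   'Oc1ccccc1'    'c1ccccc1O'    # two renderings of phenol
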
class PreTokenizedSmilesDataset(Dataset):
    """
    A Dataset that loads a pre-tokenized and pre-padded dataset created
    by the preprocessing script. It uses memory-mapping for instant loads
    and high efficiency.

    NOTE: defined for completeness; run_training() below uses the
    Parquet-based PrecomputedContrastiveSmilesDataset instead.
    """
    def __init__(self, dataset_path: str):
        # Load the dataset from disk. This is very fast due to memory-mapping.
        self.dataset = load_from_disk(dataset_path)
        # Set the format to PyTorch tensors for direct use in the model
        self.dataset.set_format(type='torch', columns=[
            'input_ids_1', 'attention_mask_1', 'input_ids_2', 'attention_mask_2'
        ])
        print(f"Successfully loaded pre-tokenized dataset from {dataset_path}.")

    def __len__(self):
        """Returns the total number of items in the dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Retrieves a single pre-processed item."""
        return self.dataset[idx]

class DataCollatorWithPadding:
    """
    A collate function that dynamically pads inputs to the longest sequence
    across both augmented views in the batch, ensuring consistent tensor shapes.
    """
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, features):
        # Combine both views into one list to find the global max length
        combined_features = []
        for feature in features:
            combined_features.append({'input_ids': feature['input_ids_1'], 'attention_mask': feature['attention_mask_1']})
            combined_features.append({'input_ids': feature['input_ids_2'], 'attention_mask': feature['attention_mask_2']})

        # Pad the combined batch so all sequences share the same length
        padded_combined = self.tokenizer.pad(combined_features, padding='longest', return_tensors='pt')

        # Split the padded tensors back into the two views
        batch_size = len(features)
        input_ids_1, input_ids_2 = torch.split(padded_combined['input_ids'], batch_size, dim=0)
        attention_mask_1, attention_mask_2 = torch.split(padded_combined['attention_mask'], batch_size, dim=0)

        return {
            'input_ids_1': input_ids_1,
            'attention_mask_1': attention_mask_1,
            'input_ids_2': input_ids_2,
            'attention_mask_2': attention_mask_2,
        }

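# Usage sketch (this collator is not wired into run_training() below, which
# relies on the datasets' fixed max_length padding instead):
#   collator = DataCollatorWithPadding(tokenizer)
#   loader = DataLoader(dataset, batch_size=64, collate_fn=collator)
# For dynamic padding to pay off, drop padding='max_length' in __getitem__.
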
# ==============================================================================
# 5. TRAINING AND EVALUATION LOOPS
# ==============================================================================
def evaluation_step(model, batch, criterion, device):
    """Performs a single evaluation step on a batch of data."""
    input_ids_1 = batch['input_ids_1'].to(device)
    attention_mask_1 = batch['attention_mask_1'].to(device)
    input_ids_2 = batch['input_ids_2'].to(device)
    attention_mask_2 = batch['attention_mask_2'].to(device)

    # Run both views through the encoder in a single forward pass
    combined_input_ids = torch.cat([input_ids_1, input_ids_2], dim=0)
    combined_attention_mask = torch.cat([attention_mask_1, attention_mask_2], dim=0)

    with torch.no_grad():
        combined_proj = model(combined_input_ids, combined_attention_mask)

    batch_size = input_ids_1.size(0)
    proj_1, proj_2 = torch.split(combined_proj, batch_size, dim=0)

    loss = criterion(proj_1, proj_2)
    return proj_1, proj_2, loss

def train_epoch(model, train_loader, optimizer, criterion, device, scheduler, save_path, save_steps):
    model.train()
    total_loss = 0
    # fp16 autocast needs a GradScaler to avoid gradient underflow; created
    # here so the function signature stays unchanged.
    scaler = torch.cuda.amp.GradScaler()
    progress_bar = tqdm(train_loader, desc="Training Batch", leave=False)

    for step, batch in enumerate(progress_bar, 1):
        input_ids_1 = batch['input_ids_1'].to(device)
        attention_mask_1 = batch['attention_mask_1'].to(device)
        input_ids_2 = batch['input_ids_2'].to(device)
        attention_mask_2 = batch['attention_mask_2'].to(device)

        optimizer.zero_grad()
        with torch.autocast(dtype=torch.float16, device_type="cuda"):
            combined_input_ids = torch.cat([input_ids_1, input_ids_2], dim=0)
            combined_attention_mask = torch.cat([attention_mask_1, attention_mask_2], dim=0)

            combined_proj = model(combined_input_ids, combined_attention_mask)

            batch_size = input_ids_1.size(0)
            proj_1, proj_2 = torch.split(combined_proj, batch_size, dim=0)

            loss = criterion(proj_1, proj_2)

        scaler.scale(loss).backward()

        # Clip gradients BEFORE the optimizer step (they must be unscaled first)
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        total_loss += loss.item()

        progress_bar.set_postfix(loss=f"{loss.item():.4f}")
        wandb.log({
            "train_batch_loss": loss.item(),
            "learning_rate": scheduler.get_last_lr()[0]
        })
        if save_path and step % save_steps == 0:
            torch.save(model.state_dict(), save_path)
            progress_bar.write(f"Checkpoint saved at step {step}")

    return total_loss / len(train_loader)

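# The loop above follows the canonical PyTorch AMP recipe: forward pass under
# autocast(float16), scale the loss before backward(), unscale_() before
# clip_grad_norm_(), then scaler.step() and scaler.update().
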
def validate_epoch(model, val_loader, criterion, device):
    model.eval()
    total_loss = 0
    progress_bar = tqdm(val_loader, desc="Validating", leave=False)

    for batch in progress_bar:
        _, _, loss = evaluation_step(model, batch, criterion, device)
        total_loss += loss.item()
    print(f'Validation loss: {total_loss / len(val_loader)}')
    return total_loss / len(val_loader)

def test_model(model, test_loader, criterion, device):
    model.eval()
    total_loss = 0
    all_similarities = []
    progress_bar = tqdm(test_loader, desc="Testing", leave=False)

    for batch in progress_bar:
        proj_1, proj_2, loss = evaluation_step(model, batch, criterion, device)
        total_loss += loss.item()

        # Track the per-pair cosine similarity between the two views
        proj_1_norm = F.normalize(proj_1, p=2, dim=1)
        proj_2_norm = F.normalize(proj_2, p=2, dim=1)
        batch_similarities = F.cosine_similarity(proj_1_norm, proj_2_norm, dim=1)
        all_similarities.extend(batch_similarities.cpu().numpy())

    avg_loss = total_loss / len(test_loader)
    avg_sim = np.mean(all_similarities)
    std_sim = np.std(all_similarities)

    return avg_loss, avg_sim, std_sim

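# Interpretation note: avg_sim close to 1 means the encoder maps different
# SMILES renderings of the same molecule to nearly identical embeddings,
# which is exactly what the contrastive objective optimizes for.
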
# ==============================================================================
# 6. SINGLE-GPU TRAINING
# ==============================================================================
def run_training(model_config, hparams, data_splits):
    """The main function to run the training and evaluation process."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    wandb_key = os.getenv("WANDB_API_KEY")
    if wandb_key:
        wandb.login(key=wandb_key)
    # Without an API key, run wandb in disabled mode so later wandb.log calls are no-ops
    wandb.init(
        project="simson-contrastive-learning-single-gpu",
        name=f"run-{wandb.util.generate_id()}",
        config=hparams,
        mode="online" if wandb_key else "disabled"
    )
    # NOTE: the raw splits are unpacked for reference only; training below reads
    # the pre-computed Parquet pair files instead.
    train_smiles, val_smiles, test_smiles = data_splits

    tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')

    precomputed_train_path = 'data/splits/train.parquet'
    precomputed_test_path = 'data/splits/test.parquet'
    precomputed_val_path = 'data/splits/validation.parquet'

    train_dataset = PrecomputedContrastiveSmilesDataset(tokenizer, file_path=precomputed_train_path, max_length=hparams['max_length'])
    test_dataset = PrecomputedContrastiveSmilesDataset(tokenizer, file_path=precomputed_test_path, max_length=hparams['max_length'])
    val_dataset = PrecomputedContrastiveSmilesDataset(tokenizer, file_path=precomputed_val_path, max_length=hparams['max_length'])

    train_loader = DataLoader(train_dataset, batch_size=hparams['batch_size'], shuffle=True, num_workers=16, prefetch_factor=128, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=hparams['batch_size'], shuffle=False, num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=hparams['batch_size'], shuffle=False, num_workers=2, pin_memory=True)
    print('Initialized all data. Compiling the model...')
    model = SimSonEncoder(config=model_config, max_len=hparams['max_embeddings']).to(device)
    model = torch.compile(model)
    print(model)
    total_params = sum(p.numel() for p in model.parameters())

    print(f"Total number of parameters: {total_params // 1_000_000} M")
    wandb.config.update({"total_params_M": total_params // 1_000_000})

    criterion = ContrastiveLoss(temperature=hparams['temperature']).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=hparams['lr'], weight_decay=1e-5, fused=True)
    print(f"Train loader has {len(train_loader)} batches of size {hparams['batch_size']}")
    # T_0 spans all training steps, i.e. a single cosine cycle with no restarts
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_mult=1, T_0=int(hparams['epochs'] * len(train_loader)))
    print("Starting training...")
    wandb.watch(model, log='all', log_freq=5000)

    best_val_loss = float('inf')
    epoch_iterator = tqdm(range(hparams['epochs']), desc="Epochs")
    # Resume from an existing checkpoint if one is present
    if os.path.exists(hparams['save_path']):
        model.load_state_dict(torch.load(hparams['save_path']))
        val_loss = validate_epoch(model, val_loader, criterion, device)

    for epoch in epoch_iterator:
        train_loss = train_epoch(model, train_loader, optimizer, criterion, device, scheduler, hparams['save_path'], hparams['save_steps'])
        val_loss = validate_epoch(model, val_loader, criterion, device)
        epoch_iterator.set_postfix(train_loss=f"{train_loss:.4f}", val_loss=f"{val_loss:.4f}")
        wandb.log({
            "epoch": epoch + 1,
            "train_epoch_loss": train_loss,
            "val_epoch_loss": val_loss,
        })

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), hparams['save_path'])
            epoch_iterator.write(f"Epoch {epoch + 1}: New best model saved with val loss {val_loss:.4f}")

    epoch_iterator.write("Training complete. Starting final testing...")
    # Load the best model for testing
    model.load_state_dict(torch.load(hparams['save_path']))

    test_loss, avg_sim, std_sim = test_model(model, test_loader, criterion, device)

    print("\n--- Test Results ---")
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Average Cosine Similarity: {avg_sim:.4f} ± {std_sim:.4f}")
    print("--------------------")

    wandb.log({
        "test_loss": test_loss,
        "avg_cosine_similarity": avg_sim,
        "std_cosine_similarity": std_sim
    })

    wandb.finish()

# ==============================================================================
# 7. MAIN EXECUTION
# ==============================================================================
def main():
    """Main function to configure and run the training process."""
    hparams = {
        'epochs': 1,
        'lr': 1e-5,
        'temperature': 0.05,
        'batch_size': 64,
        'max_length': 128,
        'save_path': "simson_checkpoints/simson_model_single_gpu.bin",
        'save_steps': 100_000,
        'max_embeddings': 512,
    }

    dataset = load_dataset('HoangHa/SMILES-250M')['train']
    smiles_column_name = 'SMILES'

    # 10% test, then 10% of the remainder for validation, the rest for training
    total_size = len(dataset)
    test_size = int(0.1 * total_size)
    val_size = int(0.1 * (total_size - test_size))

    test_smiles = dataset.select(range(test_size))[smiles_column_name]
    val_smiles = dataset.select(range(test_size, test_size + val_size))[smiles_column_name]
    train_smiles = dataset.select(range(test_size + val_size, total_size))[smiles_column_name]
    data_splits = (train_smiles, val_smiles, test_smiles)
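    # Worked example (assuming the full 250M rows): test = 25M (10%),
    # val = 0.1 * 225M = 22.5M (9% of the total), train = 202.5M (81%).
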
    tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')
    model_config = BertConfig(
        vocab_size=tokenizer.vocab_size,  # reuse the SMILES tokenizer's vocabulary
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=2048,
        max_position_embeddings=512
    )
    save_dir = os.path.dirname(hparams['save_path'])
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # Directly call the training function for a single-GPU run
    run_training(model_config, hparams, data_splits)

if __name__ == '__main__':
    main()