# simson_base/train.py
# ==============================================================================
# 1. IMPORTS
# ==============================================================================
import os
import warnings
import wandb
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import numpy as np
from tqdm import tqdm
from rdkit import Chem, RDLogger
from datasets import load_dataset, load_from_disk
from transformers import AutoTokenizer, BertModel, BertConfig
import pandas as pd
# ==============================================================================
# 2. INITIAL SETUP
# ==============================================================================
# Suppress RDKit console output
RDLogger.DisableLog('rdApp.*')
# Ignore warnings for cleaner output
warnings.filterwarnings("ignore")
# ==============================================================================
# 3. MODEL AND LOSS FUNCTION
# ==============================================================================
def global_average_pooling(x):
"""Global Average Pooling: from [B, max_len, hid_dim] to [B, hid_dim]"""
return torch.mean(x, dim=1)
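# Note: the mean above includes padding positions. A masked variant (sketch only,
# not wired into SimSonEncoder) would average over real tokens exclusively:
def masked_average_pooling(x, attention_mask):
    """Sketch: average only over non-padding tokens.

    Assumes attention_mask is [B, max_len] with 1 for real tokens and 0 for padding.
    """
    mask = attention_mask.unsqueeze(-1).to(x.dtype)              # [B, max_len, 1]
    return (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)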
class SimSonEncoder(nn.Module):
"""The main encoder model based on BERT."""
def __init__(self, config: BertConfig, max_len: int, dropout: float = 0.1):
super(SimSonEncoder, self).__init__()
self.bert = BertModel(config, add_pooling_layer=False)
self.linear = nn.Linear(config.hidden_size, max_len)
self.dropout = nn.Dropout(dropout)
def forward(self, input_ids, attention_mask=None):
if attention_mask is None:
attention_mask = input_ids.ne(self.bert.config.pad_token_id)
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
hidden_states = self.dropout(outputs.last_hidden_state)
pooled_output = global_average_pooling(hidden_states)
return self.linear(pooled_output)
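# Usage sketch (illustrative only; assumes a ChemBERTa-style SMILES tokenizer):
#   tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')
#   encoder = SimSonEncoder(BertConfig(vocab_size=tokenizer.vocab_size), max_len=512)
#   batch = tokenizer(["CCO", "c1ccccc1"], padding=True, return_tensors="pt")
#   emb = encoder(batch["input_ids"], batch["attention_mask"])   # -> [2, 512]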
class ContrastiveLoss(nn.Module):
"""Calculates the contrastive loss for the SimSon model."""
def __init__(self, temperature=0.2):
super(ContrastiveLoss, self).__init__()
self.temperature = temperature
self.similarity_fn = F.cosine_similarity
def forward(self, proj_1, proj_2):
batch_size = proj_1.shape[0]
device = proj_1.device
# Normalize projections
z_i = F.normalize(proj_1, p=2, dim=1)
z_j = F.normalize(proj_2, p=2, dim=1)
# Concatenate for similarity matrix calculation
representations = torch.cat([z_i, z_j], dim=0)
# Calculate cosine similarity between all pairs
similarity_matrix = self.similarity_fn(representations.unsqueeze(1), representations.unsqueeze(0), dim=2)
# Identify positive pairs (original and its augmentation)
sim_ij = torch.diag(similarity_matrix, batch_size)
sim_ji = torch.diag(similarity_matrix, -batch_size)
positives = torch.cat([sim_ij, sim_ji], dim=0)
        # Numerator: exponentiated similarity of the positive pairs
        numerator = torch.exp(positives / self.temperature)
        # Mask out self-comparisons so they do not enter the denominator
        mask = (~torch.eye(batch_size * 2, batch_size * 2, dtype=torch.bool, device=device)).float()
        denominator = mask * torch.exp(similarity_matrix / self.temperature)
        # Calculate the final NT-Xent loss, averaged over both views
        loss = -torch.log(numerator / torch.sum(denominator, dim=1))
return torch.sum(loss) / (2 * batch_size)
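# Equivalent formulation (sketch, for reference only; not used by the training loop):
# the NT-Xent loss above can also be expressed as cross-entropy over a similarity
# matrix whose diagonal (self-similarities) is masked out.
def nt_xent_cross_entropy(proj_1, proj_2, temperature=0.2):
    batch_size = proj_1.shape[0]
    z = F.normalize(torch.cat([proj_1, proj_2], dim=0), p=2, dim=1)
    sim = (z @ z.t()) / temperature                               # [2B, 2B] scaled cosine similarities
    diag = torch.eye(2 * batch_size, dtype=torch.bool, device=sim.device)
    sim = sim.masked_fill(diag, float('-inf'))                    # exclude self-comparisons
    # Row i's positive is the other augmented view of the same molecule
    targets = torch.cat([torch.arange(batch_size, 2 * batch_size),
                         torch.arange(0, batch_size)]).to(sim.device)
    return F.cross_entropy(sim, targets)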
# ==============================================================================
# 4. DATA HANDLING
# ==============================================================================
class SmilesEnumerator:
"""Generates randomized SMILES strings for data augmentation."""
def randomize_smiles(self, smiles):
        try:
            mol = Chem.MolFromSmiles(smiles)
            return Chem.MolToSmiles(mol, doRandom=True, canonical=False) if mol else smiles
        except Exception:
            # Fall back to the original string if RDKit cannot parse or randomize it
            return smiles
class ContrastiveSmilesDataset(Dataset):
"""Dataset for creating pairs of augmented SMILES for contrastive learning."""
def __init__(self, smiles_list, tokenizer, max_length=512):
self.smiles_list = smiles_list
self.tokenizer = tokenizer
self.max_length = max_length
self.enumerator = SmilesEnumerator()
def __len__(self):
return len(self.smiles_list)
def __getitem__(self, idx):
original_smiles = self.smiles_list[idx]
# Create two different augmentations of the same SMILES
smiles_1 = self.enumerator.randomize_smiles(original_smiles)
smiles_2 = self.enumerator.randomize_smiles(original_smiles)
        # Tokenize and pad to max_length so the default collate_fn can stack
        # fixed-length tensors (drop padding='max_length' here if the dynamic
        # DataCollatorWithPadding defined below is used as the collate_fn).
tokens_1 = self.tokenizer(smiles_1, max_length=self.max_length, truncation=True, padding='max_length')
tokens_2 = self.tokenizer(smiles_2, max_length=self.max_length, truncation=True, padding='max_length')
return {
'input_ids_1': torch.tensor(tokens_1['input_ids']),
'attention_mask_1': torch.tensor(tokens_1['attention_mask']),
'input_ids_2': torch.tensor(tokens_2['input_ids']),
'attention_mask_2': torch.tensor(tokens_2['attention_mask']),
}
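# Usage sketch (on-the-fly augmentation; slower than the pre-computed variants below):
#   dataset = ContrastiveSmilesDataset(train_smiles, tokenizer, max_length=128)
#   loader = DataLoader(dataset, batch_size=64, shuffle=True)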
class PrecomputedContrastiveSmilesDataset(Dataset):
"""
A Dataset class that reads pre-augmented SMILES pairs from a Parquet file.
This is significantly faster as it offloads the expensive SMILES randomization
to a one-time preprocessing step.
"""
def __init__(self, tokenizer, file_path: str, max_length: int = 512):
self.tokenizer = tokenizer
self.max_length = max_length
# Load the entire dataset from the Parquet file into memory.
# This is fast and efficient for subsequent access.
print(f"Loading pre-computed data from {file_path}...")
self.data = pd.read_parquet(file_path)
print("Data loaded successfully.")
def __len__(self):
"""Returns the total number of pairs in the dataset."""
return len(self.data)
def __getitem__(self, idx):
"""
Retrieves a pre-augmented pair, tokenizes it, and returns it
in the format expected by the DataCollator.
"""
# Retrieve the pre-augmented pair from the DataFrame
row = self.data.iloc[idx]
smiles_1 = row['smiles_1']
smiles_2 = row['smiles_2']
# Tokenize the pair. This operation is fast and remains in the data loader.
tokens_1 = self.tokenizer(smiles_1, max_length=self.max_length, truncation=True, padding='max_length')
tokens_2 = self.tokenizer(smiles_2, max_length=self.max_length, truncation=True, padding='max_length')
return {
'input_ids_1': torch.tensor(tokens_1['input_ids']),
'attention_mask_1': torch.tensor(tokens_1['attention_mask']),
'input_ids_2': torch.tensor(tokens_2['input_ids']),
'attention_mask_2': torch.tensor(tokens_2['attention_mask']),
}
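# Sketch of the assumed one-time preprocessing step that writes these Parquet
# files (column names 'smiles_1'/'smiles_2' are taken from __getitem__ above):
#   enumerator = SmilesEnumerator()
#   pairs = [{'smiles_1': enumerator.randomize_smiles(s),
#             'smiles_2': enumerator.randomize_smiles(s)} for s in train_smiles]
#   pd.DataFrame(pairs).to_parquet('data/splits/train.parquet')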
class PreTokenizedSmilesDataset(Dataset):
"""
A Dataset that loads a pre-tokenized and pre-padded dataset created
by the preprocessing script. It uses memory-mapping for instant loads
and high efficiency.
"""
def __init__(self, dataset_path: str):
# Load the dataset from disk. This is very fast due to memory-mapping.
self.dataset = load_from_disk(dataset_path)
# Set the format to PyTorch tensors for direct use in the model
self.dataset.set_format(type='torch', columns=[
'input_ids_1', 'attention_mask_1', 'input_ids_2', 'attention_mask_2'
])
print(f"Successfully loaded pre-tokenized dataset from {dataset_path}.")
def __len__(self):
"""Returns the total number of items in the dataset."""
return len(self.dataset)
def __getitem__(self, idx):
"""Retrieves a single pre-processed item."""
return self.dataset[idx]
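# Assumption: the on-disk dataset is a datasets.Dataset containing the four
# pre-tokenized columns listed above, written with save_to_disk(dataset_path) so
# that load_from_disk can memory-map it here.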
class DataCollatorWithPadding:
"""
A collate function that dynamically pads inputs to the longest sequence
across both augmented views in the batch, ensuring consistent tensor shapes.
"""
def __init__(self, tokenizer):
self.tokenizer = tokenizer
def __call__(self, features):
# Create a combined list of features for both views to find the global max length
combined_features = []
for feature in features:
combined_features.append({'input_ids': feature['input_ids_1'], 'attention_mask': feature['attention_mask_1']})
combined_features.append({'input_ids': feature['input_ids_2'], 'attention_mask': feature['attention_mask_2']})
# Pad the combined batch. This ensures all sequences are padded to the same length.
padded_combined = self.tokenizer.pad(combined_features, padding='longest', return_tensors='pt')
# Split the padded tensors back into two views
batch_size = len(features)
input_ids_1, input_ids_2 = torch.split(padded_combined['input_ids'], batch_size, dim=0)
attention_mask_1, attention_mask_2 = torch.split(padded_combined['attention_mask'], batch_size, dim=0)
return {
'input_ids_1': input_ids_1,
'attention_mask_1': attention_mask_1,
'input_ids_2': input_ids_2,
'attention_mask_2': attention_mask_2,
}
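# Usage sketch: this collator only shortens batches when the Dataset skips
# padding='max_length'; it is not wired into the DataLoaders in run_training below.
#   collator = DataCollatorWithPadding(tokenizer)
#   loader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=collator)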
# ==============================================================================
# 5. TRAINING AND EVALUATION LOOPS
# ==============================================================================
def evaluation_step(model, batch, criterion, device):
"""Performs a single evaluation step on a batch of data."""
input_ids_1 = batch['input_ids_1'].to(device)
attention_mask_1 = batch['attention_mask_1'].to(device)
input_ids_2 = batch['input_ids_2'].to(device)
attention_mask_2 = batch['attention_mask_2'].to(device)
combined_input_ids = torch.cat([input_ids_1, input_ids_2], dim=0)
combined_attention_mask = torch.cat([attention_mask_1, attention_mask_2], dim=0)
with torch.no_grad():
combined_proj = model(combined_input_ids, combined_attention_mask)
batch_size = input_ids_1.size(0)
proj_1, proj_2 = torch.split(combined_proj, batch_size, dim=0)
loss = criterion(proj_1, proj_2)
return proj_1, proj_2, loss
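# Note: callers (validate_epoch / test_model) are responsible for model.eval();
# gradients are disabled here via torch.no_grad().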
def train_epoch(model, train_loader, optimizer, criterion, device, scheduler, save_path, save_steps):
model.train()
total_loss = 0
progress_bar = tqdm(train_loader, desc="Training Batch", leave=False)
for step, batch in enumerate(progress_bar, 1):
input_ids_1 = batch['input_ids_1'].to(device)
attention_mask_1 = batch['attention_mask_1'].to(device)
input_ids_2 = batch['input_ids_2'].to(device)
attention_mask_2 = batch['attention_mask_2'].to(device)
optimizer.zero_grad()
with torch.autocast(dtype=torch.float16, device_type="cuda"):
combined_input_ids = torch.cat([input_ids_1, input_ids_2], dim=0)
combined_attention_mask = torch.cat([attention_mask_1, attention_mask_2], dim=0)
combined_proj = model(combined_input_ids, combined_attention_mask)
batch_size = input_ids_1.size(0)
proj_1, proj_2 = torch.split(combined_proj, batch_size, dim=0)
loss = criterion(proj_1, proj_2)
        loss.backward()
        # Clip gradients before the optimizer step so the clipping takes effect
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
total_loss += loss.item()
progress_bar.set_postfix(loss=f"{loss.item():.4f}")
wandb.log({
"train_batch_loss": loss.item(),
"learning_rate": scheduler.get_last_lr()[0]
})
if save_path and step % save_steps == 0:
torch.save(model.state_dict(), save_path)
progress_bar.write(f"Checkpoint saved at step {step}")
return total_loss / len(train_loader)
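# Note: the loop above autocasts to float16 without a GradScaler. A common AMP
# pattern (sketch only, not wired into run_training) would be:
#   scaler = torch.cuda.amp.GradScaler()
#   with torch.autocast(device_type="cuda", dtype=torch.float16):
#       ...forward pass and loss...
#   scaler.scale(loss).backward()
#   scaler.unscale_(optimizer)
#   torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
#   scaler.step(optimizer)
#   scaler.update()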
def validate_epoch(model, val_loader, criterion, device):
model.eval()
total_loss = 0
progress_bar = tqdm(val_loader, desc="Validating", leave=False)
for batch in progress_bar:
_, _, loss = evaluation_step(model, batch, criterion, device)
total_loss += loss.item()
print(f'Validation loss: {total_loss / len(val_loader)}')
return total_loss / len(val_loader)
def test_model(model, test_loader, criterion, device):
model.eval()
total_loss = 0
all_similarities = []
progress_bar = tqdm(test_loader, desc="Testing", leave=False)
for batch in progress_bar:
proj_1, proj_2, loss = evaluation_step(model, batch, criterion, device)
total_loss += loss.item()
proj_1_norm = F.normalize(proj_1, p=2, dim=1)
proj_2_norm = F.normalize(proj_2, p=2, dim=1)
batch_similarities = F.cosine_similarity(proj_1_norm, proj_2_norm, dim=1)
all_similarities.extend(batch_similarities.cpu().numpy())
avg_loss = total_loss / len(test_loader)
avg_sim = np.mean(all_similarities)
std_sim = np.std(all_similarities)
return avg_loss, avg_sim, std_sim
# ==============================================================================
# 6. SINGLE-GPU TRAINING
# ==============================================================================
def run_training(model_config, hparams, data_splits):
"""The main function to run the training and evaluation process."""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
wandb_key = os.getenv("WANDB_API_KEY")
if wandb_key:
wandb.login(key=wandb_key)
wandb.init(
project="simson-contrastive-learning-single-gpu",
name=f"run-{wandb.util.generate_id()}",
config=hparams
)
    # Raw SMILES splits are received here but not used directly; the pre-computed
    # Parquet splits below are what the DataLoaders actually read.
    train_smiles, val_smiles, test_smiles = data_splits
tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')
precomputed_train_path = 'data/splits/train.parquet'
precomputed_test_path = 'data/splits/test.parquet'
precomputed_val_path = 'data/splits/validation.parquet'
train_dataset = PrecomputedContrastiveSmilesDataset(tokenizer, file_path=precomputed_train_path, max_length=hparams['max_length'])
test_dataset = PrecomputedContrastiveSmilesDataset(tokenizer, file_path=precomputed_test_path, max_length=hparams['max_length'])
val_dataset = PrecomputedContrastiveSmilesDataset(tokenizer, file_path=precomputed_val_path, max_length=hparams['max_length'])
train_loader = DataLoader(train_dataset, batch_size=hparams['batch_size'], shuffle=True, num_workers=16, prefetch_factor=128, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=hparams['batch_size'], shuffle=False, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=hparams['batch_size'], shuffle=False, num_workers=2, pin_memory=True)
print('Initialized all data. Compiling the model...')
model = SimSonEncoder(config=model_config, max_len=hparams['max_embeddings']).to(device)
model = torch.compile(model)
print(model)
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params // 1_000_000} M")
wandb.config.update({"total_params_M": total_params // 1_000_000})
criterion = ContrastiveLoss(temperature=hparams['temperature']).to(device)
optimizer = optim.AdamW(model.parameters(), lr=hparams['lr'], weight_decay=1e-5, fused=True)
    print(f"Train loader has {len(train_loader)} batches of size {hparams['batch_size']}")
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_mult=1, T_0=int(hparams['epochs'] * len(train_loader)))
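    # With T_0 equal to the total number of training steps (and scheduler.step()
    # called per batch), this is a single cosine decay over the run; no warm
    # restarts actually occur.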
print("Starting training...")
wandb.watch(model, log='all', log_freq=5000)
best_val_loss = float('inf')
epoch_iterator = tqdm(range(hparams['epochs']), desc="Epochs")
    # Resume from an existing checkpoint, if one is present, and record its validation loss
    if os.path.exists(hparams['save_path']):
        model.load_state_dict(torch.load(hparams['save_path']))
        val_loss = validate_epoch(model, val_loader, criterion, device)
for epoch in epoch_iterator:
train_loss = train_epoch(model, train_loader, optimizer, criterion, device, scheduler, hparams['save_path'], hparams['save_steps'])
val_loss = validate_epoch(model, val_loader, criterion, device)
epoch_iterator.set_postfix(train_loss=f"{train_loss:.4f}", val_loss=f"{val_loss:.4f}")
wandb.log({
"epoch": epoch + 1,
"train_epoch_loss": train_loss,
"val_epoch_loss": val_loss,
})
if val_loss < best_val_loss:
best_val_loss = val_loss
torch.save(model.state_dict(), hparams['save_path'])
epoch_iterator.write(f"Epoch {epoch + 1}: New best model saved with val loss {val_loss:.4f}")
epoch_iterator.write("Training complete. Starting final testing...")
# Load the best model for testing
model.load_state_dict(torch.load(hparams['save_path']))
test_loss, avg_sim, std_sim = test_model(model, test_loader, criterion, device)
print("\n--- Test Results ---")
print(f"Test Loss: {test_loss:.4f}")
print(f"Average Cosine Similarity: {avg_sim:.4f} \u00B1 {std_sim:.4f}")
print("--------------------")
wandb.log({
"test_loss": test_loss,
"avg_cosine_similarity": avg_sim,
"std_cosine_similarity": std_sim
})
wandb.finish()
# ==============================================================================
# 7. MAIN EXECUTION
# ==============================================================================
def main():
"""Main function to configure and run the training process."""
hparams = {
'epochs': 1,
'lr': 1e-5,
'temperature': 0.05,
'batch_size': 64,
'max_length': 128,
'save_path': "simson_checkpoints/simson_model_single_gpu.bin",
'save_steps': 100_000,
'max_embeddings': 512,
}
dataset = load_dataset('HoangHa/SMILES-250M')['train']
smiles_column_name = 'SMILES'
total_size = len(dataset)
test_size = int(0.1 * total_size)
val_size = int(0.1 * (total_size - test_size))
test_smiles = dataset.select(range(test_size))[smiles_column_name]
val_smiles = dataset.select(range(test_size, test_size + val_size))[smiles_column_name]
train_smiles = dataset.select(range(test_size + val_size, total_size))[smiles_column_name]
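    # Resulting split: 10% test, 9% validation (10% of the remaining 90%), 81% train.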
data_splits = (train_smiles, val_smiles, test_smiles)
tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')
    model_config = BertConfig(
        vocab_size=tokenizer.vocab_size,   # match the ChemBERTa SMILES tokenizer vocabulary
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=2048,
        max_position_embeddings=512
    )
save_dir = os.path.dirname(hparams['save_path'])
if not os.path.exists(save_dir):
os.makedirs(save_dir)
# Directly call the training function for a single-GPU run
run_training(model_config, hparams, data_splits)
if __name__ == '__main__':
main()