import logging
import os
from typing import Dict

import torch
from torch.optim import AdamW
from tqdm import tqdm

from burmese_gpt.config import TrainingConfig

logger = logging.getLogger(__name__)

class BurmeseGPTTrainer:
    def __init__(self, model, train_loader, val_loader, config: TrainingConfig):
        """
        Trainer for the BurmeseGPT model.

        Args:
            model: Initialized BurmeseGPT model
            train_loader: Training DataLoader
            val_loader: Validation DataLoader
            config: Training configuration
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Move the model to the training device
        self.model.to(self.device)

        # AdamW optimizer; fall back to a weight decay of 0.01 if the
        # config does not define one
        self.optimizer = AdamW(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=getattr(config, "weight_decay", 0.01),
        )

        # Cross-entropy loss that ignores padding tokens
        self.criterion = torch.nn.CrossEntropyLoss(
            ignore_index=train_loader.dataset.tokenizer.pad_token_id
        )

    def train_epoch(self) -> float:
        """Run one training epoch and return the average loss."""
        self.model.train()
        total_loss = 0.0
        for batch in tqdm(self.train_loader, desc="Training"):
            # Move the batch to the training device
            input_ids = batch["input_ids"].to(self.device)

            # Shift by one position: predict token t+1 from tokens up to t
            inputs = input_ids[:, :-1]
            targets = input_ids[:, 1:]

            # Forward pass
            self.optimizer.zero_grad()
            outputs = self.model(inputs)

            # Flatten (batch, seq, vocab) logits and (batch, seq) targets
            # for token-level cross-entropy
            loss = self.criterion(
                outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1)
            )

            # Backward pass with gradient clipping
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()

            total_loss += loss.item()
        return total_loss / len(self.train_loader)

    def validate(self) -> float:
        """Run validation and return the average loss."""
        self.model.eval()
        total_loss = 0.0
        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Validating"):
                input_ids = batch["input_ids"].to(self.device)
                inputs = input_ids[:, :-1]
                targets = input_ids[:, 1:]
                outputs = self.model(inputs)
                loss = self.criterion(
                    outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1)
                )
                total_loss += loss.item()
        return total_loss / len(self.val_loader)
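
    # A small illustrative helper (not in the original code): the average
    # cross-entropy losses returned above convert to perplexity via
    # exp(loss), which is often the more interpretable language-modelling
    # metric for reporting.
    @staticmethod
    def perplexity(avg_loss: float) -> float:
        """Convert an average cross-entropy loss to perplexity."""
        import math

        return math.exp(avg_loss)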

    def train(self) -> Dict[str, list]:
        """
        Full training loop.

        Returns:
            Dictionary with per-epoch training and validation losses
        """
        metrics = {"train_loss": [], "val_loss": []}
        best_loss = float("inf")

        for epoch in range(1, self.config.num_epochs + 1):
            logger.info(f"Epoch {epoch}/{self.config.num_epochs}")

            # Training
            train_loss = self.train_epoch()
            metrics["train_loss"].append(train_loss)

            # Validation
            val_loss = self.validate()
            metrics["val_loss"].append(val_loss)

            logger.info(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

            # Keep the checkpoint with the best validation loss so far
            if val_loss < best_loss:
                best_loss = val_loss
                self.save_checkpoint("best_model.pth")
                logger.info("Saved best model")

            # Periodic checkpoint every `save_every` epochs
            if epoch % self.config.save_every == 0:
                self.save_checkpoint(f"epoch_{epoch}.pth")

        return metrics

    def save_checkpoint(self, filename: str):
        """Save a model/optimizer checkpoint under config.checkpoint_dir."""
        # Make sure the checkpoint directory exists before writing
        os.makedirs(self.config.checkpoint_dir, exist_ok=True)
        torch.save(
            {
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "config": self.config,
            },
            os.path.join(self.config.checkpoint_dir, filename),
        )
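

# ---------------------------------------------------------------------------
# Example usage: a minimal smoke-test sketch. Only TrainingConfig and
# BurmeseGPTTrainer come from this project; the tiny model, dataset, and
# tokenizer below are stand-ins invented for illustration. In real training,
# swap in the actual BurmeseGPT model and a tokenized Burmese corpus.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader, Dataset

    logging.basicConfig(level=logging.INFO)

    class StubTokenizer:
        # Matches the attribute the trainer reads for the loss ignore_index
        pad_token_id = 0

    class StubDataset(Dataset):
        """Random token sequences; exposes .tokenizer as the trainer expects."""

        tokenizer = StubTokenizer()

        def __len__(self):
            return 8

        def __getitem__(self, idx):
            return {"input_ids": torch.randint(1, 100, (16,))}

    class StubModel(torch.nn.Module):
        """Embedding + linear head producing (batch, seq, vocab) logits."""

        def __init__(self, vocab_size: int = 100, dim: int = 32):
            super().__init__()
            self.embed = torch.nn.Embedding(vocab_size, dim)
            self.head = torch.nn.Linear(dim, vocab_size)

        def forward(self, x):
            return self.head(self.embed(x))

    # Assumes TrainingConfig provides defaults for learning_rate,
    # num_epochs, save_every, and checkpoint_dir
    config = TrainingConfig()
    trainer = BurmeseGPTTrainer(
        model=StubModel(),
        train_loader=DataLoader(StubDataset(), batch_size=4),
        val_loader=DataLoader(StubDataset(), batch_size=4),
        config=config,
    )
    print(trainer.train())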