burmese-gpt / scripts /train.py
Zai
Reformat code with black
9cfe63d
raw
history blame
1.47 kB
import os
import logging
from burmese_gpt.models import BurmeseGPT
from burmese_gpt.training import BurmeseGPTTrainer
from burmese_gpt.data import BurmeseDataset
from burmese_gpt.config import ModelConfig, TrainingConfig
from torch.utils.data import DataLoader
# Record layout shared by every log line this script emits: timestamp, level, message.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"

# Configure the root logger for the whole run: INFO and above, using the layout
# above. (basicConfig is a no-op if the root logger already has handlers.)
logging.basicConfig(
    format=_LOG_FORMAT,
    level=logging.INFO,
)

# Logger for this module's progress messages.
logger = logging.getLogger(__name__)
def main():
    """Train BurmeseGPT on a train/validation split of the configured dataset.

    Builds default ``ModelConfig`` and ``TrainingConfig`` objects and makes
    sure the checkpoint directory exists. It then loads training and
    validation datasets, sizes the model's vocabulary from the training
    dataset's tokenizer, wraps both datasets in ``DataLoader``s, and runs
    ``BurmeseGPTTrainer.train()``.

    Returns:
        Whatever ``BurmeseGPTTrainer.train()`` returns (presumably a metrics
        object/mapping -- confirm against the trainer).
    """
    model_config = ModelConfig()
    training_config = TrainingConfig()

    # Ensure the checkpoint directory exists before training starts;
    # exist_ok makes re-runs safe.
    os.makedirs(training_config.checkpoint_dir, exist_ok=True)

    logger.info("Loading dataset from %s", training_config.dataset_url)
    # Both datasets are carved out of the "train" split. The split strings look
    # like Hugging Face `datasets` slice syntax (first 90% / last 10%) --
    # confirm against BurmeseDataset.
    train_dataset = BurmeseDataset(split="train[:90%]", config=training_config)
    val_dataset = BurmeseDataset(split="train[90%:]", config=training_config)

    # Size the model's vocabulary from the tokenizer the dataset uses.
    # NOTE(review): if the tokenizer has added special tokens (e.g. a pad
    # token), `vocab_size` may not count them -- confirm whether
    # len(tokenizer) is the right value here.
    model_config.vocab_size = train_dataset.tokenizer.vocab_size
    logger.info("Using vocab size: %s", model_config.vocab_size)

    logger.info("Initializing model...")
    model = BurmeseGPT(model_config)

    # Only the training loader shuffles; validation keeps a fixed order.
    train_loader = DataLoader(
        train_dataset, batch_size=training_config.batch_size, shuffle=True
    )
    val_loader = DataLoader(val_dataset, batch_size=training_config.batch_size)

    logger.info("Starting training...")
    trainer = BurmeseGPTTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        config=training_config,
    )
    metrics = trainer.train()
    logger.info("Training completed!")

    # Surface the trainer's result instead of silently discarding it.
    if metrics is not None:
        logger.info("Final metrics: %s", metrics)
    return metrics


if __name__ == "__main__":
    main()