"""Train a BurmeseGPT model on the configured dataset.

When run as a script this builds the model/training configs, loads a 90/10
train/validation split of the dataset, constructs the model and data loaders,
and hands them to ``BurmeseGPTTrainer``.
"""

import logging
import os

from torch.utils.data import DataLoader

from burmese_gpt.config import ModelConfig, TrainingConfig
from burmese_gpt.data import BurmeseDataset
from burmese_gpt.models import BurmeseGPT
from burmese_gpt.training import BurmeseGPTTrainer

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


def main():
    """Run one full training session and return the trainer's metrics.

    Returns:
        Whatever ``BurmeseGPTTrainer.train()`` returns. Its type is defined
        in ``burmese_gpt.training`` and is not visible from this script.
    """
    model_config = ModelConfig()
    training_config = TrainingConfig()

    # Ensure the checkpoint directory exists before the trainer needs it.
    os.makedirs(training_config.checkpoint_dir, exist_ok=True)

    logger.info("Loading dataset from %s", training_config.dataset_url)
    # NOTE(review): split strings look like Hugging Face `datasets` slicing
    # syntax (first 90% of "train" for training, last 10% for validation);
    # confirm against BurmeseDataset.
    train_dataset = BurmeseDataset(split="train[:90%]", config=training_config)
    val_dataset = BurmeseDataset(split="train[90%:]", config=training_config)

    # Size the model's vocabulary to match the dataset's tokenizer.
    model_config.vocab_size = train_dataset.tokenizer.vocab_size
    logger.info("Using vocab size: %s", model_config.vocab_size)

    logger.info("Initializing model...")
    model = BurmeseGPT(model_config)

    train_loader = DataLoader(
        train_dataset, batch_size=training_config.batch_size, shuffle=True
    )
    # Validation loader keeps dataset order (no shuffle).
    val_loader = DataLoader(val_dataset, batch_size=training_config.batch_size)

    logger.info("Starting training...")
    trainer = BurmeseGPTTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        config=training_config,
    )
    metrics = trainer.train()
    logger.info("Training completed!")
    # Return the metrics instead of discarding them so callers can use them.
    return metrics


if __name__ == "__main__":
    main()