import logging
import os
from typing import Dict

import torch
from torch.optim import AdamW
from tqdm import tqdm

from burmese_gpt.config import TrainingConfig

logger = logging.getLogger(__name__)


class BurmeseGPTTrainer:
    def __init__(self, model, train_loader, val_loader, config: TrainingConfig):
        """
        Trainer for BurmeseGPT model

        Args:
            model: Initialized BurmeseGPT model
            train_loader: Training DataLoader
            val_loader: Validation DataLoader
            config: Training configuration
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
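        # Run on GPU when one is available, otherwise fall back to CPU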
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Move model to device
        self.model.to(self.device)

        # Initialize the AdamW optimizer; fall back to a default weight decay
        # when the config does not define one
        self.optimizer = AdamW(
            self.model.parameters(),
            lr=config.learning_rate,
            weight_decay=getattr(config, "weight_decay", 0.01),
        )

        # Loss function that masks padding tokens out of the loss; this relies
        # on the dataset exposing its tokenizer (and pad_token_id)
        self.criterion = torch.nn.CrossEntropyLoss(
            ignore_index=train_loader.dataset.tokenizer.pad_token_id
        )

    def train_epoch(self) -> float:
        """Run one training epoch, return average loss"""
        self.model.train()
        total_loss = 0

        for batch in tqdm(self.train_loader, desc="Training"):
            # Move batch to device
            input_ids = batch["input_ids"].to(self.device)

            # Shift by one token for next-token prediction:
            # inputs[:, t] is used to predict targets[:, t]
            inputs = input_ids[:, :-1]
            targets = input_ids[:, 1:]

            # Forward pass
            self.optimizer.zero_grad()
            outputs = self.model(inputs)

            # Flatten logits to (batch * seq_len, vocab_size) and targets to
            # (batch * seq_len,) for token-level cross-entropy
            loss = self.criterion(
                outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1)
            )

            # Backward pass
            loss.backward()
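            # Clip the gradient norm to 1.0 to keep updates stable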
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()

            total_loss += loss.item()

        return total_loss / len(self.train_loader)

    def validate(self) -> float:
        """Run validation, return average loss"""
        self.model.eval()
        total_loss = 0

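        # Gradients are not needed for evaluation, so disable tracking to save memory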
        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Validating"):
                input_ids = batch["input_ids"].to(self.device)
                inputs = input_ids[:, :-1]
                targets = input_ids[:, 1:]

                outputs = self.model(inputs)
                loss = self.criterion(
                    outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1)
                )
                total_loss += loss.item()

        return total_loss / len(self.val_loader)

    def train(self) -> Dict[str, list]:
        """
        Full training loop

        Returns:
            Dictionary with training metrics
        """
        metrics = {"train_loss": [], "val_loss": []}
        best_loss = float("inf")

        for epoch in range(1, self.config.num_epochs + 1):
            logger.info(f"Epoch {epoch}/{self.config.num_epochs}")

            # Training
            train_loss = self.train_epoch()
            metrics["train_loss"].append(train_loss)

            # Validation
            val_loss = self.validate()
            metrics["val_loss"].append(val_loss)

            logger.info(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

            # Save best model
            if val_loss < best_loss:
                best_loss = val_loss
                self.save_checkpoint("best_model.pth")
                logger.info("Saved best model")

            # Save periodic checkpoint
            if epoch % self.config.save_every == 0:
                self.save_checkpoint(f"epoch_{epoch}.pth")

        return metrics

    def save_checkpoint(self, filename: str):
        """Save model, optimizer, and config state to the checkpoint directory"""
        # Make sure the checkpoint directory exists before writing
        os.makedirs(self.config.checkpoint_dir, exist_ok=True)
        torch.save(
            {
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "config": self.config,
            },
            os.path.join(self.config.checkpoint_dir, filename),
        )
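

# Example usage (a minimal sketch, kept as comments because the exact model
# class and data-loading helpers live elsewhere in the package; `BurmeseGPT`
# and `load_datasets` below are assumed names, not confirmed APIs):
#
#     from burmese_gpt.config import TrainingConfig
#     from burmese_gpt.models import BurmeseGPT      # hypothetical import path
#     from burmese_gpt.data import load_datasets     # hypothetical helper
#
#     config = TrainingConfig()
#     train_loader, val_loader = load_datasets(config)
#     model = BurmeseGPT(config)
#
#     trainer = BurmeseGPTTrainer(model, train_loader, val_loader, config)
#     metrics = trainer.train()  # returns {"train_loss": [...], "val_loss": [...]}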