Commit 9cfe63d
Zai committed
1 Parent(s): 4cc4af5

Reformat code with black
Browse files
- burmese_gpt/config.py +3 -1
- burmese_gpt/data/__init__.py +1 -1
- burmese_gpt/data/dataset.py +4 -4
- burmese_gpt/models/__init__.py +1 -1
- burmese_gpt/models/model.py +6 -3
- burmese_gpt/training/__init__.py +1 -1
- burmese_gpt/training/trainer.py +19 -15
- scripts/sample.py +1 -1
- scripts/space.py +2 -4
- scripts/train.py +7 -15
- setup.py +2 -2
- tests/test_data.py +2 -2
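Almost every hunk below is a pure formatting change (the one functional exception, in scripts/train.py, is noted there). A few files change only in invisible whitespace, most likely a missing newline at end of file, which is why some hunks show an identical line removed and re-added. For reference, a minimal sketch of what black does, using its Python API (`black.format_str` and `black.Mode` are the library's public entry points; the sample source string is illustrative):

```python
import black

# Unformatted input: extra spaces inside parens, single-quoted string.
src = "print( 'hello' )\n"

# black.Mode() defaults to the 88-character line length this commit reflects.
formatted = black.format_str(src, mode=black.Mode())
print(formatted)  # -> print("hello")
```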
burmese_gpt/config.py
CHANGED
@@ -1,5 +1,6 @@
 from dataclasses import dataclass
 
+
 @dataclass
 class ModelConfig:
     vocab_size: int = 30000
@@ -8,6 +9,7 @@ class ModelConfig:
     num_layers: int = 4
     dropout: float = 0.1
 
+
 @dataclass
 class TrainingConfig:
     batch_size: int = 32
@@ -17,4 +19,4 @@ class TrainingConfig:
     log_dir: str = "logs"
     save_every: int = 1
     eval_every: int = 1
-    dataset_url: str = "zaibutcooler/wiki-burmese"
+    dataset_url: str = "zaibutcooler/wiki-burmese"
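The blank lines black inserts here enforce PEP 8's two-blank-line rule before top-level definitions. For orientation, a small usage sketch of these dataclasses (field names come from the hunks above; `num_layers=6` is just an illustrative override):

```python
from burmese_gpt.config import ModelConfig, TrainingConfig

model_cfg = ModelConfig(num_layers=6)  # dataclass fields can be overridden per instance
train_cfg = TrainingConfig()
print(train_cfg.dataset_url)  # "zaibutcooler/wiki-burmese"
```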
burmese_gpt/data/__init__.py
CHANGED
@@ -1 +1 @@
-from .dataset import BurmeseDataset
+from .dataset import BurmeseDataset
burmese_gpt/data/dataset.py
CHANGED
@@ -6,7 +6,7 @@ from burmese_gpt.config import TrainingConfig
 
 
 class BurmeseDataset(Dataset):
-    def __init__(self, split="train", max_length=128,config:TrainingConfig=None):
+    def __init__(self, split="train", max_length=128, config: TrainingConfig = None):
         self.dataset = load_dataset(config.dataset_url, split=split)
         self.tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
         if self.tokenizer.pad_token is None:
@@ -23,9 +23,9 @@ class BurmeseDataset(Dataset):
             truncation=True,
             max_length=self.max_length,
             padding="max_length",
-            return_tensors="pt"
+            return_tensors="pt",
         )
         return {
             "input_ids": encodings["input_ids"].squeeze(),
-            "attention_mask": encodings["attention_mask"].squeeze()
-        }
+            "attention_mask": encodings["attention_mask"].squeeze(),
+        }
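The trailing commas added to `return_tensors="pt"` and the returned dict are black's "magic trailing comma" style for multi-line calls. A hedged usage sketch of the class as shown above (the slice syntax follows the `datasets` split API; the output keys and 128-token padding come straight from the diff):

```python
from burmese_gpt.config import TrainingConfig
from burmese_gpt.data import BurmeseDataset

config = TrainingConfig()
ds = BurmeseDataset(split="train[:1%]", max_length=128, config=config)

item = ds[0]
print(item["input_ids"].shape)       # torch.Size([128]): padded/truncated, then squeezed
print(item["attention_mask"].shape)  # torch.Size([128])
```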
burmese_gpt/models/__init__.py
CHANGED
@@ -1 +1 @@
-from .model import BurmeseGPT
+from .model import BurmeseGPT
burmese_gpt/models/model.py
CHANGED
@@ -2,6 +2,7 @@ import torch
 from torch import nn
 from burmese_gpt.config import ModelConfig
 
+
 class BurmeseGPT(nn.Module):
     def __init__(self, config: ModelConfig):
         super(BurmeseGPT, self).__init__()
@@ -18,9 +19,11 @@ class BurmeseGPT(nn.Module):
             d_model=config.embed_dim,
             nhead=config.num_heads,
             dropout=config.dropout,
-            batch_first=True
+            batch_first=True,
+        )
+        self.transformer = nn.TransformerEncoder(
+            encoder_layer, num_layers=config.num_layers
         )
-        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=config.num_layers)
 
         # Final projection layer
         self.fc = nn.Linear(config.embed_dim, config.vocab_size)
@@ -55,4 +58,4 @@ class BurmeseGPT(nn.Module):
         x = self.transformer(x, mask)
 
         # Final projection
-        return self.fc(x)
+        return self.fc(x)
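Beyond the blank line and trailing comma, black re-wraps the over-long `nn.TransformerEncoder` construction across lines. A quick shape check of the rebuilt model; this assumes `forward` takes a `(batch, seq_len)` tensor of token ids, which is consistent with the embedding-into-encoder flow visible in the diff but not confirmed by it:

```python
import torch
from burmese_gpt.config import ModelConfig
from burmese_gpt.models import BurmeseGPT

cfg = ModelConfig()
model = BurmeseGPT(cfg)

tokens = torch.randint(0, cfg.vocab_size, (2, 16))  # batch of 2 sequences, 16 tokens each
logits = model(tokens)
print(logits.shape)  # expected torch.Size([2, 16, 30000]) from the final nn.Linear
```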
burmese_gpt/training/__init__.py
CHANGED
@@ -1 +1 @@
-from .trainer import BurmeseGPTTrainer
+from .trainer import BurmeseGPTTrainer
burmese_gpt/training/trainer.py
CHANGED
@@ -8,8 +8,9 @@ from burmese_gpt.config import TrainingConfig
 
 logger = logging.getLogger(__name__)
 
+
 class BurmeseGPTTrainer:
-    def __init__(self, model, train_loader, val_loader, config:TrainingConfig):
+    def __init__(self, model, train_loader, val_loader, config: TrainingConfig):
         """
         Trainer for BurmeseGPT model
 
@@ -32,7 +33,9 @@ class BurmeseGPTTrainer:
         self.optimizer = AdamW(
             model.parameters(),
             lr=config.learning_rate,
-            weight_decay=config.weight_decay if hasattr(config, "weight_decay") else 0.01,
+            weight_decay=(
+                config.weight_decay if hasattr(config, "weight_decay") else 0.01
+            ),
         )
 
         # Loss function (ignoring padding tokens)
@@ -59,8 +62,7 @@ class BurmeseGPTTrainer:
 
             # Calculate loss (same as original)
             loss = self.criterion(
-                outputs.reshape(-1, outputs.size(-1)),
-                targets.reshape(-1)
+                outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1)
             )
 
             # Backward pass
@@ -85,8 +87,7 @@ class BurmeseGPTTrainer:
 
                 outputs = self.model(inputs)
                 loss = self.criterion(
-                    outputs.reshape(-1, outputs.size(-1)),
-                    targets.reshape(-1)
+                    outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1)
                 )
                 total_loss += loss.item()
 
@@ -99,19 +100,19 @@ class BurmeseGPTTrainer:
         Returns:
             Dictionary with training metrics
         """
-        metrics = {'train_loss': [], 'val_loss': []}
-        best_loss = float('inf')
+        metrics = {"train_loss": [], "val_loss": []}
+        best_loss = float("inf")
 
         for epoch in range(1, self.config.num_epochs + 1):
             logger.info(f"Epoch {epoch}/{self.config.num_epochs}")
 
             # Training
             train_loss = self.train_epoch()
-            metrics['train_loss'].append(train_loss)
+            metrics["train_loss"].append(train_loss)
 
             # Validation
             val_loss = self.validate()
-            metrics['val_loss'].append(val_loss)
+            metrics["val_loss"].append(val_loss)
 
             logger.info(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
 
@@ -129,8 +130,11 @@ class BurmeseGPTTrainer:
 
     def save_checkpoint(self, filename: str):
         """Save model checkpoint"""
-        torch.save({
-            'model_state_dict': self.model.state_dict(),
-            'optimizer_state_dict': self.optimizer.state_dict(),
-            'config': self.config
-        }, f"{self.config.checkpoint_dir}/{filename}")
+        torch.save(
+            {
+                "model_state_dict": self.model.state_dict(),
+                "optimizer_state_dict": self.optimizer.state_dict(),
+                "config": self.config,
+            },
+            f"{self.config.checkpoint_dir}/{filename}",
+        )
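The rewrapped `torch.save` call stores three keys. A matching load-side sketch (the path is illustrative, since `checkpoint_dir`'s default value is not visible in this diff, and the vocab size must match whatever the training tokenizer reported):

```python
import torch
from burmese_gpt.config import ModelConfig
from burmese_gpt.models import BurmeseGPT

ckpt = torch.load("checkpoints/model.pt", map_location="cpu")  # hypothetical path

cfg = ModelConfig()
cfg.vocab_size = 119547  # bert-base-multilingual-cased's vocab size, as set in scripts/train.py
model = BurmeseGPT(cfg)
model.load_state_dict(ckpt["model_state_dict"])
train_cfg = ckpt["config"]  # the TrainingConfig saved alongside the weights
```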
scripts/sample.py
CHANGED
@@ -1,4 +1,4 @@
 # TODO: Need to sample
 
 if __name__ == "__main__":
-    print("Sampling the Burmese GPT model...")
+    print("Sampling the Burmese GPT model...")
scripts/space.py
CHANGED
@@ -2,9 +2,7 @@ import streamlit as st
 
 # Set up the page layout
 st.set_page_config(
-    page_title="Burmese GPT",
-    page_icon=":speech_balloon:",
-    layout="wide"
+    page_title="Burmese GPT", page_icon=":speech_balloon:", layout="wide"
 )
 
 # Create a sidebar with a title and a brief description
@@ -49,4 +47,4 @@ elif selected_view == "Chat Interface":
     response_area = st.text_area("Model:", height=200, disabled=True)
 
     # Add some space between the input and output areas
-    st.write("")
+    st.write("")
scripts/train.py
CHANGED
@@ -9,13 +9,12 @@ from burmese_gpt.config import ModelConfig, TrainingConfig
 from torch.utils.data import DataLoader
 
 logging.basicConfig(
-    level=logging.INFO,
-    format='%(asctime)s - %(levelname)s - %(message)s'
+    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
 )
 logger = logging.getLogger(__name__)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     model_config = ModelConfig()
     training_config = TrainingConfig()
 
@@ -23,8 +22,8 @@ if __name__ == '__main__':
 
     logger.info(f"Loading dataset from {training_config.dataset_url}")
 
-    train_dataset = BurmeseDataset(split="train[:90%]")
-    val_dataset = BurmeseDataset(split="train[90%:]")
+    train_dataset = BurmeseDataset(split="train[:90%]", config=training_config)
+    val_dataset = BurmeseDataset(split="train[90%:]", config=training_config)
 
     model_config.vocab_size = train_dataset.tokenizer.vocab_size
    logger.info(f"Using vocab size: {model_config.vocab_size}")
@@ -33,25 +32,18 @@ if __name__ == '__main__':
     model = BurmeseGPT(model_config)
 
     train_loader = DataLoader(
-        train_dataset,
-        batch_size=training_config.batch_size,
-        shuffle=True
-    )
-    val_loader = DataLoader(
-        val_dataset,
-        batch_size=training_config.batch_size
+        train_dataset, batch_size=training_config.batch_size, shuffle=True
     )
+    val_loader = DataLoader(val_dataset, batch_size=training_config.batch_size)
 
     logger.info("Starting training...")
    trainer = BurmeseGPTTrainer(
         model=model,
         train_loader=train_loader,
         val_loader=val_loader,
-        config=training_config
+        config=training_config,
     )
 
     metrics = trainer.train()
 
     logger.info("Training completed!")
-
-
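Note that the second hunk is more than formatting: `config=training_config` is now passed to both `BurmeseDataset` calls, which the `__init__` in burmese_gpt/data/dataset.py requires before it can read `config.dataset_url`. The two splits partition the same `train` split 90/10 via the `datasets` slicing syntax; a quick sanity check of that convention:

```python
from datasets import load_dataset

full = load_dataset("zaibutcooler/wiki-burmese", split="train")
head = load_dataset("zaibutcooler/wiki-burmese", split="train[:90%]")
tail = load_dataset("zaibutcooler/wiki-burmese", split="train[90%:]")
assert len(head) + len(tail) == len(full)  # the two slices cover the split exactly
```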
setup.py
CHANGED
@@ -1,7 +1,7 @@
-from setuptools import setup
+from setuptools import setup
 
 setup(
     name="burmese_gpt",
     version="0.1",
     author="Sai Ye Yint Aung",
-)
+)
tests/test_data.py
CHANGED
@@ -6,8 +6,8 @@ from burmese_gpt.config import TrainingConfig
 class TestData(unittest.TestCase):
     def test_data(self):
         training_config = TrainingConfig()
-        train_dataset = BurmeseDataset(split="train[:90%]",config=training_config)
-        val_dataset = BurmeseDataset(split="train[90%:]",config=training_config)
+        train_dataset = BurmeseDataset(split="train[:90%]", config=training_config)
+        val_dataset = BurmeseDataset(split="train[90%:]", config=training_config)
 
         self.assertIsNotNone(train_dataset)
         self.assertIsNotNone(val_dataset)
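The test exercises the same `config=` keyword the training script now uses. A sketch for running the suite programmatically (equivalent to `python -m unittest discover tests` on the command line, assuming the tests directory is importable):

```python
import unittest

suite = unittest.defaultTestLoader.discover("tests")
unittest.TextTestRunner(verbosity=2).run(suite)
```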