Zai committed
Commit 9cfe63d · 1 Parent(s): 4cc4af5

Reformat code with black

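For anyone reproducing this commit: black rewrites files in place with `black .` from the repository root, and `black --check .` reports which files would change without touching them. Every hunk below is formatting-only, except the two `BurmeseDataset(...)` calls in scripts/train.py, which also start passing `config=training_config`.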
burmese_gpt/config.py CHANGED
@@ -1,5 +1,6 @@
 from dataclasses import dataclass
 
+
 @dataclass
 class ModelConfig:
     vocab_size: int = 30000
@@ -8,6 +9,7 @@ class ModelConfig:
     num_layers: int = 4
     dropout: float = 0.1
 
+
 @dataclass
 class TrainingConfig:
     batch_size: int = 32
@@ -17,4 +19,4 @@ class TrainingConfig:
     log_dir: str = "logs"
     save_every: int = 1
     eval_every: int = 1
-    dataset_url: str = "zaibutcooler/wiki-burmese"
+    dataset_url: str = "zaibutcooler/wiki-burmese"
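Since both configs are plain dataclasses, here is a minimal usage sketch; field defaults are taken from the hunks above, and any field can be overridden at construction:

    from burmese_gpt.config import ModelConfig, TrainingConfig

    model_config = ModelConfig()        # vocab_size=30000, num_layers=4, dropout=0.1, ...
    training_config = TrainingConfig()  # batch_size=32, dataset_url="zaibutcooler/wiki-burmese", ...

    # Dataclasses accept per-field overrides:
    tiny = ModelConfig(num_layers=2, dropout=0.0)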
burmese_gpt/data/__init__.py CHANGED
@@ -1 +1 @@
-from .dataset import BurmeseDataset
+from .dataset import BurmeseDataset
burmese_gpt/data/dataset.py CHANGED
@@ -6,7 +6,7 @@ from burmese_gpt.config import TrainingConfig
 
 
 class BurmeseDataset(Dataset):
-    def __init__(self, split="train", max_length=128,config:TrainingConfig=None):
+    def __init__(self, split="train", max_length=128, config: TrainingConfig = None):
         self.dataset = load_dataset(config.dataset_url, split=split)
         self.tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
         if self.tokenizer.pad_token is None:
@@ -23,9 +23,9 @@ class BurmeseDataset(Dataset):
             truncation=True,
             max_length=self.max_length,
             padding="max_length",
-            return_tensors="pt"
+            return_tensors="pt",
         )
         return {
             "input_ids": encodings["input_ids"].squeeze(),
-            "attention_mask": encodings["attention_mask"].squeeze()
-        }
+            "attention_mask": encodings["attention_mask"].squeeze(),
+        }
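A minimal sketch of how the reformatted class is consumed. It downloads both the dataset and the bert-base-multilingual-cased tokenizer, so it assumes network access; the shapes follow from padding="max_length" plus the .squeeze() calls above. Note that config is required in practice, since __init__ dereferences config.dataset_url:

    from burmese_gpt.config import TrainingConfig
    from burmese_gpt.data import BurmeseDataset

    dataset = BurmeseDataset(split="train[:1%]", max_length=128, config=TrainingConfig())
    sample = dataset[0]
    print(sample["input_ids"].shape)       # torch.Size([128])
    print(sample["attention_mask"].shape)  # torch.Size([128])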
burmese_gpt/models/__init__.py CHANGED
@@ -1 +1 @@
-from .model import BurmeseGPT
+from .model import BurmeseGPT
burmese_gpt/models/model.py CHANGED
@@ -2,6 +2,7 @@ import torch
 from torch import nn
 from burmese_gpt.config import ModelConfig
 
+
 class BurmeseGPT(nn.Module):
     def __init__(self, config: ModelConfig):
         super(BurmeseGPT, self).__init__()
@@ -18,9 +19,11 @@ class BurmeseGPT(nn.Module):
             d_model=config.embed_dim,
             nhead=config.num_heads,
             dropout=config.dropout,
-            batch_first=True
-        )
-        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=config.num_layers)
+            batch_first=True,
+        )
+        self.transformer = nn.TransformerEncoder(
+            encoder_layer, num_layers=config.num_layers
+        )
 
         # Final projection layer
         self.fc = nn.Linear(config.embed_dim, config.vocab_size)
@@ -55,4 +58,4 @@ class BurmeseGPT(nn.Module):
         x = self.transformer(x, mask)
 
         # Final projection
-        return self.fc(x)
+        return self.fc(x)
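A smoke-test sketch of the reformatted model. The full forward signature is not shown in these hunks, so the call below assumes forward takes a (batch, seq_len) tensor of token ids, consistent with batch_first=True and the final nn.Linear projection:

    import torch
    from burmese_gpt.config import ModelConfig
    from burmese_gpt.models import BurmeseGPT

    config = ModelConfig()
    model = BurmeseGPT(config)

    input_ids = torch.randint(0, config.vocab_size, (2, 16))  # dummy batch of token ids
    logits = model(input_ids)
    print(logits.shape)  # expected: (2, 16, config.vocab_size)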
burmese_gpt/training/__init__.py CHANGED
@@ -1 +1 @@
-from .trainer import BurmeseGPTTrainer
+from .trainer import BurmeseGPTTrainer
burmese_gpt/training/trainer.py CHANGED
@@ -8,8 +8,9 @@ from burmese_gpt.config import TrainingConfig
 
 logger = logging.getLogger(__name__)
 
+
 class BurmeseGPTTrainer:
-    def __init__(self, model, train_loader, val_loader, config:TrainingConfig):
+    def __init__(self, model, train_loader, val_loader, config: TrainingConfig):
         """
         Trainer for BurmeseGPT model
 
@@ -32,7 +33,9 @@ class BurmeseGPTTrainer:
         self.optimizer = AdamW(
             model.parameters(),
             lr=config.learning_rate,
-            weight_decay=config.weight_decay if hasattr(config, 'weight_decay') else 0.01
+            weight_decay=(
+                config.weight_decay if hasattr(config, "weight_decay") else 0.01
+            ),
         )
 
         # Loss function (ignoring padding tokens)
@@ -59,8 +62,7 @@ class BurmeseGPTTrainer:
 
         # Calculate loss (same as original)
         loss = self.criterion(
-            outputs.reshape(-1, outputs.size(-1)),
-            targets.reshape(-1)
+            outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1)
         )
 
         # Backward pass
@@ -85,8 +87,7 @@ class BurmeseGPTTrainer:
 
             outputs = self.model(inputs)
             loss = self.criterion(
-                outputs.reshape(-1, outputs.size(-1)),
-                targets.reshape(-1)
+                outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1)
             )
             total_loss += loss.item()
 
@@ -99,19 +100,19 @@ class BurmeseGPTTrainer:
         Returns:
             Dictionary with training metrics
         """
-        metrics = {'train_loss': [], 'val_loss': []}
-        best_loss = float('inf')
+        metrics = {"train_loss": [], "val_loss": []}
+        best_loss = float("inf")
 
         for epoch in range(1, self.config.num_epochs + 1):
             logger.info(f"Epoch {epoch}/{self.config.num_epochs}")
 
             # Training
             train_loss = self.train_epoch()
-            metrics['train_loss'].append(train_loss)
+            metrics["train_loss"].append(train_loss)
 
             # Validation
             val_loss = self.validate()
-            metrics['val_loss'].append(val_loss)
+            metrics["val_loss"].append(val_loss)
 
             logger.info(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
 
@@ -129,8 +130,11 @@
 
     def save_checkpoint(self, filename: str):
         """Save model checkpoint"""
-        torch.save({
-            'model_state_dict': self.model.state_dict(),
-            'optimizer_state_dict': self.optimizer.state_dict(),
-            'config': self.config
-        }, f"{self.config.checkpoint_dir}/{filename}")
+        torch.save(
+            {
+                "model_state_dict": self.model.state_dict(),
+                "optimizer_state_dict": self.optimizer.state_dict(),
+                "config": self.config,
+            },
+            f"{self.config.checkpoint_dir}/{filename}",
+        )
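A side note on the weight_decay expression that black wraps in parentheses above: the same fallback is often written with getattr's default argument, which avoids the conditional entirely. This is an equivalent alternative, not what the commit does:

    weight_decay=getattr(config, "weight_decay", 0.01),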
scripts/sample.py CHANGED
@@ -1,4 +1,4 @@
 # TODO: Need to sample
 
 if __name__ == "__main__":
-    print("Sampling the Burmese GPT model...")
+    print("Sampling the Burmese GPT model...")
scripts/space.py CHANGED
@@ -2,9 +2,7 @@ import streamlit as st
 
 # Set up the page layout
 st.set_page_config(
-    page_title="Burmese GPT",
-    page_icon=":speech_balloon:",
-    layout="wide"
+    page_title="Burmese GPT", page_icon=":speech_balloon:", layout="wide"
 )
 
 # Create a sidebar with a title and a brief description
@@ -49,4 +47,4 @@ elif selected_view == "Chat Interface":
     response_area = st.text_area("Model:", height=200, disabled=True)
 
     # Add some space between the input and output areas
-    st.write("")
+    st.write("")
scripts/train.py CHANGED
@@ -9,13 +9,12 @@ from burmese_gpt.config import ModelConfig, TrainingConfig
 from torch.utils.data import DataLoader
 
 logging.basicConfig(
-    level=logging.INFO,
-    format='%(asctime)s - %(levelname)s - %(message)s'
+    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
 )
 logger = logging.getLogger(__name__)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     model_config = ModelConfig()
     training_config = TrainingConfig()
 
@@ -23,8 +22,8 @@ if __name__ == '__main__':
 
     logger.info(f"Loading dataset from {training_config.dataset_url}")
 
-    train_dataset = BurmeseDataset(split="train[:90%]")  # First 90% for training
-    val_dataset = BurmeseDataset(split="train[90%:]")  # Last 10% for validation
+    train_dataset = BurmeseDataset(split="train[:90%]", config=training_config)
+    val_dataset = BurmeseDataset(split="train[90%:]", config=training_config)
 
     model_config.vocab_size = train_dataset.tokenizer.vocab_size
     logger.info(f"Using vocab size: {model_config.vocab_size}")
@@ -33,25 +32,18 @@ if __name__ == '__main__':
     model = BurmeseGPT(model_config)
 
     train_loader = DataLoader(
-        train_dataset,
-        batch_size=training_config.batch_size,
-        shuffle=True
-    )
-    val_loader = DataLoader(
-        val_dataset,
-        batch_size=training_config.batch_size
+        train_dataset, batch_size=training_config.batch_size, shuffle=True
     )
+    val_loader = DataLoader(val_dataset, batch_size=training_config.batch_size)
 
     logger.info("Starting training...")
    trainer = BurmeseGPTTrainer(
         model=model,
         train_loader=train_loader,
         val_loader=val_loader,
-        config=training_config
+        config=training_config,
     )
 
     metrics = trainer.train()
 
     logger.info("Training completed!")
-
-
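The train[:90%] / train[90%:] strings passed through to load_dataset are Hugging Face datasets slice specs; they carve a single split into two disjoint pieces. A standalone sketch, assuming the dataset is reachable on the Hub:

    from datasets import load_dataset

    head = load_dataset("zaibutcooler/wiki-burmese", split="train[:90%]")  # first 90%
    tail = load_dataset("zaibutcooler/wiki-burmese", split="train[90%:]")  # remaining 10%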
 
 
setup.py CHANGED
@@ -1,7 +1,7 @@
-from setuptools import setup
+from setuptools import setup
 
 setup(
     name="burmese_gpt",
     version="0.1",
     author="Sai Ye Yint Aung",
-)
+)
tests/test_data.py CHANGED
@@ -6,8 +6,8 @@ from burmese_gpt.config import TrainingConfig
 class TestData(unittest.TestCase):
     def test_data(self):
         training_config = TrainingConfig()
-        train_dataset = BurmeseDataset(split="train[:90%]",config=training_config)
-        val_dataset = BurmeseDataset(split="train[90%:]",config=training_config)
+        train_dataset = BurmeseDataset(split="train[:90%]", config=training_config)
+        val_dataset = BurmeseDataset(split="train[90%:]", config=training_config)
 
         self.assertIsNotNone(train_dataset)
         self.assertIsNotNone(val_dataset)
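The test already passes config=training_config explicitly, matching the BurmeseDataset signature above; from the repository root it runs with `python -m unittest tests.test_data` (network access is needed to download the dataset and tokenizer).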