Stigall committed on
Commit
b08a6ee
·
1 Parent(s): 6dc7a9e

Upload tiny_trainer.py

Browse files
Files changed (1) hide show
  1. tiny_trainer.py +234 -0
tiny_trainer.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Standard library imports (if any)
2
+ import os
3
+ # Third-party library imports
4
+ import torch
5
+ import torch.nn as nn
6
+ from transformers import BertForSequenceClassification, BertTokenizerFast
7
+ import torch.optim as optim
8
+ from torch.nn import CrossEntropyLoss
9
+ from torch.utils.data import DataLoader, TensorDataset
10
+ # Local application/library s
11
+ from FallingPlanet.orbit.utils.Metrics import AdvancedMetrics
12
+ from FallingPlanet.orbit.utils.Metrics import TinyEmoBoard
13
+ import torchmetrics
14
+ from tqdm import tqdm
15
+ from FallingPlanet.orbit.utils.callbacks import EarlyStopping
16
+ from FallingPlanet.orbit.models import BertFineTuneTiny
17
+ from itertools import islice
18
+
19
class Classifier:
    """Training / validation / test harness for a single-task text classifier.

    Wraps a model together with a cross-entropy criterion, a TensorBoard
    writer (TinyEmoBoard) and a set of torchmetrics objects.  Each ``*_step``
    method consumes an entire DataLoader, i.e. runs one full pass/epoch.
    """

    def __init__(self, model, device, num_labels, log_dir):
        """Move `model` to `device` and build the criterion, writer and metrics.

        Args:
            model: classifier invoked as ``model(input_ids, attention_masks)``;
                assumed to return raw logits of shape (batch, num_labels) —
                TODO confirm against BertFineTuneTiny.
            device: torch device the model and metric states live on.
            num_labels (int): number of target classes.
            log_dir (str): directory for TensorBoard event files.
        """
        self.model = model.to(device)
        self.device = device
        self.loss_criterion = CrossEntropyLoss()
        self.writer = TinyEmoBoard(log_dir=log_dir)

        # Stateful torchmetrics objects.  Only their per-call (per-batch)
        # return values are used below; epoch averages are accumulated
        # manually rather than via compute()/reset().
        self.accuracy = torchmetrics.Accuracy(num_classes=num_labels, task='multiclass').to(device)
        self.precision = torchmetrics.Precision(num_classes=num_labels, task='multiclass').to(device)
        self.recall = torchmetrics.Recall(num_classes=num_labels, task='multiclass').to(device)
        self.f1 = torchmetrics.F1Score(num_classes=num_labels, task='multiclass').to(device)
        self.mcc = torchmetrics.MatthewsCorrCoef(num_classes=num_labels, task='multiclass').to(device)
        self.top2_acc = torchmetrics.Accuracy(top_k=2, num_classes=num_labels, task='multiclass').to(device)

    def compute_loss(self, logits, labels):
        """Return the cross-entropy loss between `logits` and integer `labels`."""
        loss = self.loss_criterion(logits, labels)
        return loss

    def _batch_metrics(self, outputs, labels):
        """Return this batch's accuracy/precision/recall/f1/mcc as plain floats.

        Class predictions are argmaxed once here.  torchmetrics argmaxes raw
        logits internally anyway, so this is equivalent to the mixed
        logits/preds calls it replaces — just consistent and computed once.
        """
        preds = outputs.argmax(dim=1)
        return {
            'accuracy': self.accuracy(preds, labels).item(),
            'precision': self.precision(preds, labels).item(),
            'recall': self.recall(preds, labels).item(),
            'f1': self.f1(preds, labels).item(),
            'mcc': self.mcc(preds, labels).item(),
        }

    def train_step(self, dataloader, optimizer, epoch):
        """Run one training epoch and log averaged loss/metrics to TensorBoard.

        Args:
            dataloader: yields (input_ids, attention_masks, labels) batches.
            optimizer: optimizer stepping the wrapped model's parameters.
            epoch (int): epoch index, used as the TensorBoard global step.
        """
        self.model.train()
        total_loss = 0.0
        totals = {key: 0.0 for key in ('accuracy', 'precision', 'recall', 'f1', 'mcc')}

        pbar = tqdm(dataloader, desc=f"Training Epoch {epoch}")
        for batch in pbar:
            input_ids, attention_masks, labels = [x.to(self.device) for x in batch]

            optimizer.zero_grad()
            outputs = self.model(input_ids, attention_masks)
            loss = self.compute_loss(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            for key, value in self._batch_metrics(outputs, labels).items():
                totals[key] += value

            # pbar.n counts completed iterations, so this shows the running
            # mean loss over the batches seen so far.
            pbar.set_postfix(loss=total_loss / (pbar.n + 1))

        # Guard against an empty dataloader so logging degrades to zeros
        # instead of raising ZeroDivisionError.
        num_batches = max(len(dataloader), 1)
        self.writer.log_scalar('Training/Average Loss', total_loss / num_batches, epoch)
        self.writer.log_scalar('Training/Average Accuracy', totals['accuracy'] / num_batches, epoch)
        self.writer.log_scalar('Training/Average Precision', totals['precision'] / num_batches, epoch)
        self.writer.log_scalar('Training/Average Recall', totals['recall'] / num_batches, epoch)
        self.writer.log_scalar('Training/Average F1', totals['f1'] / num_batches, epoch)
        self.writer.log_scalar('Training/Average MCC', totals['mcc'] / num_batches, epoch)

        pbar.close()

    def val_step(self, dataloader, epoch):
        """Run one validation pass, log averages, and return the mean loss.

        Args:
            dataloader: yields (input_ids, attention_masks, labels) batches.
            epoch (int): epoch index, used as the TensorBoard global step.

        Returns:
            float: average validation loss over the pass (consumed by the
            caller's early-stopping logic).
        """
        self.model.eval()
        total_loss = 0.0
        totals = {key: 0.0 for key in ('accuracy', 'precision', 'recall', 'f1', 'mcc')}

        with torch.no_grad():
            pbar = tqdm(dataloader, desc=f"Validation Epoch {epoch}")
            for batch in pbar:
                input_ids, attention_masks, labels = [x.to(self.device) for x in batch]

                outputs = self.model(input_ids, attention_masks)
                total_loss += self.compute_loss(outputs, labels).item()

                for key, value in self._batch_metrics(outputs, labels).items():
                    totals[key] += value

                pbar.set_postfix(loss=total_loss / (pbar.n + 1))

        # max(..., 1) avoids ZeroDivisionError on an empty dataloader.
        num_batches = max(len(dataloader), 1)
        avg_val_loss = total_loss / num_batches

        self.writer.log_scalar('Validation/Average Loss', avg_val_loss, epoch)
        self.writer.log_scalar('Validation/Average Accuracy', totals['accuracy'] / num_batches, epoch)
        self.writer.log_scalar('Validation/Average Precision', totals['precision'] / num_batches, epoch)
        self.writer.log_scalar('Validation/Average Recall', totals['recall'] / num_batches, epoch)
        self.writer.log_scalar('Validation/Average F1', totals['f1'] / num_batches, epoch)
        self.writer.log_scalar('Validation/Average MCC', totals['mcc'] / num_batches, epoch)

        pbar.close()
        return avg_val_loss

    def test_step(self, dataloader):
        """Evaluate on the test set and return per-batch-averaged metrics.

        Returns:
            dict: keys 'total_accuracy', 'total_precision', 'total_recall',
            'total_f1', 'total_mcc', 'total_top_2_acc' mapped to their
            averages (key names kept for caller compatibility, despite the
            values being averages rather than totals).
        """
        self.model.eval()
        aggregated_metrics = {
            'total_accuracy': 0.0,
            'total_precision': 0.0,
            'total_recall': 0.0,
            'total_f1': 0.0,
            'total_mcc': 0.0,
            'total_top_2_acc': 0.0,
        }

        with torch.no_grad():
            pbar = tqdm(dataloader, desc="Testing")
            for batch in pbar:
                input_ids, attention_masks, labels = [x.to(self.device) for x in batch]
                outputs = self.model(input_ids, attention_masks)

                batch_vals = self._batch_metrics(outputs, labels)
                aggregated_metrics['total_accuracy'] += batch_vals['accuracy']
                aggregated_metrics['total_precision'] += batch_vals['precision']
                aggregated_metrics['total_recall'] += batch_vals['recall']
                aggregated_metrics['total_f1'] += batch_vals['f1']
                aggregated_metrics['total_mcc'] += batch_vals['mcc']
                # Top-2 accuracy needs the full logit ranking, so it gets the
                # raw outputs rather than argmaxed predictions.
                aggregated_metrics['total_top_2_acc'] += self.top2_acc(outputs, labels).item()

                pbar.set_postfix({
                    'Accuracy': aggregated_metrics['total_accuracy'] / (pbar.n + 1),
                    'MCC': aggregated_metrics['total_mcc'] / (pbar.n + 1),
                })

        # max(..., 1) avoids ZeroDivisionError on an empty dataloader.
        num_batches = max(len(dataloader), 1)
        return {key: value / num_batches for key, value in aggregated_metrics.items()}
181
+
182
+
183
+
184
def main(mode="full"):
    """Train and/or evaluate the tiny emotion classifier.

    Args:
        mode (str): "train" to only train, "test" to only evaluate the saved
            checkpoint, "full" (default) to do both.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Pre-built datasets of (input_ids, attention_mask, label) tensors.
    # map_location="cpu" keeps them on the CPU regardless of the device they
    # were saved from; batches are moved to `device` inside the step loops.
    emotion_data_train = torch.load(r"E:\text_datasets\saved\train_emotion_no_batch_no_batch.pt", map_location="cpu")
    emotion_data_val = torch.load(r"E:\text_datasets\saved\val_emotion_no_batch_no_batch.pt", map_location="cpu")
    emotion_data_test = torch.load(r"E:\text_datasets\saved\test_emotion_no_batch_no_batch.pt", map_location="cpu")

    dataloader_train = DataLoader(emotion_data_train, batch_size=512, shuffle=True)
    dataloader_val = DataLoader(emotion_data_val, batch_size=512)
    dataloader_test = DataLoader(emotion_data_test, batch_size=512)

    NUM_EMOTION_LABELS = 9
    LOG_DIR = r"EmoBERTv2-tiny\logging"
    # Single definition of the checkpoint path (was repeated inline below).
    CHECKPOINT_PATH = 'EmoBERTv2-tiny.pth'

    model = BertFineTuneTiny(num_tasks=1, num_labels=[9])
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-6)
    classifier = Classifier(model, device, NUM_EMOTION_LABELS, LOG_DIR)

    if mode in ["train", "full"]:
        early_stopping = EarlyStopping(patience=50, min_delta=1e-8)
        num_epochs = 75
        for epoch in range(num_epochs):
            classifier.train_step(dataloader_train, optimizer, epoch)
            val_loss = classifier.val_step(dataloader_val, epoch)

            if early_stopping.step(val_loss, classifier.model):
                print("Early stopping triggered. Restoring best model weights.")
                classifier.model.load_state_dict(early_stopping.best_state)
                break

        # Persist the best weights observed during training, if any.
        if early_stopping.best_state is not None:
            torch.save(early_stopping.best_state, CHECKPOINT_PATH)

    if mode in ["test", "full"]:
        if os.path.exists(CHECKPOINT_PATH):
            # map_location lets a GPU-trained checkpoint load on a CPU-only
            # host (a bare torch.load would fail there).
            classifier.model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
        else:
            # Previously this fell through silently and evaluated whatever
            # weights were in memory; make that visible.
            print(f"Warning: checkpoint {CHECKPOINT_PATH!r} not found; testing current in-memory weights.")
        test_results = classifier.test_step(dataloader_test)
        print("Test Results:", test_results)
231
+
232
+
233
# Script entry point: run the full train-then-test pipeline by default.
if __name__ == "__main__":
    main(mode="full") # or "train" or "test"