dhrumii committed on
Commit
9e816ce
·
verified ·
1 Parent(s): 513cd85

Upload model.py

Browse files
Files changed (1) hide show
  1. model.py +373 -0
model.py ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ from datasets import load_dataset
5
+ from sklearn.model_selection import train_test_split
6
+ from sklearn.metrics import classification_report, confusion_matrix
7
+ from sklearn.utils.class_weight import compute_class_weight
8
+ import seaborn as sns
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.optim as optim
12
+ from torch.utils.data import Dataset, DataLoader
13
+ from torchvision import transforms, models
14
+ from PIL import Image
15
+ from tqdm import tqdm
16
+ import warnings
17
+ warnings.filterwarnings('ignore')
18
+
19
# Fix both RNG seeds so repeated runs draw the same random numbers
np.random.seed(42)
torch.manual_seed(42)

# Every tunable read by the rest of the script lives in this one mapping
CONFIG = dict(
    img_size=224,
    batch_size=16,  # reduced batch size
    num_epochs=30,
    learning_rate=0.0001,
    patience=7,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    num_workers=0,  # 0 avoids multiprocessing issues
    model_save_path='best_trash_classifier.pth',
)

print(f"Using device: {CONFIG['device']}")
36
+
37
+ # Memory-Efficient Dataset Class
38
class TrashDatasetLazy(Dataset):
    """Map-style dataset that reads one sample at a time from a wrapped dataset.

    Only the rows listed in ``indices`` are exposed. Nothing is loaded up
    front: each sample is fetched from ``dataset`` inside ``__getitem__``.
    """

    def __init__(self, dataset, indices, transform=None):
        """Store the backing dataset, the row subset and an optional transform.

        Args:
            dataset: Indexable object whose items are mappings with
                ``'image'`` and ``'label'`` keys.
            indices: Sequence of row positions in ``dataset`` that make up
                this subset.
            transform: Optional callable applied to the RGB PIL image.
        """
        self.dataset = dataset
        self.indices = indices
        self.transform = transform

    def __len__(self):
        """Return the number of rows in this subset."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th row of the subset."""
        # The split indices are NumPy integers; cast to a builtin int so the
        # lookup also works with backing datasets that only accept `int` keys.
        actual_idx = int(self.indices[idx])
        item = self.dataset[actual_idx]

        image = item['image']
        label = item['label']

        # Convert to PIL Image if needed
        if not isinstance(image, Image.Image):
            image = Image.fromarray(np.array(image))

        # Force three channels (e.g. grayscale or RGBA inputs)
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Explicit None check: a transform object may define its own truthiness
        if self.transform is not None:
            image = self.transform(image)

        return image, label
66
+
67
# Preprocessing pipelines: training adds random augmentation, validation/test
# only resizes and normalizes.
_target_size = (CONFIG['img_size'], CONFIG['img_size'])
# Per-channel statistics (presumably the ImageNet values matching the
# pretrained backbone -- confirm)
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]

# Random perturbations applied to training images only
_augmentations = [
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
]

train_transform = transforms.Compose(
    [transforms.Resize(_target_size)]
    + _augmentations
    + [
        transforms.ToTensor(),
        transforms.Normalize(mean=_norm_mean, std=_norm_std),
    ]
)

val_transform = transforms.Compose([
    transforms.Resize(_target_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=_norm_mean, std=_norm_std),
])
83
+
84
# Load Dataset (regular map-style load: `len(ds)` and integer indexing are
# used below, so this is not streaming mode)
print("\n" + "="*60)
print("LOADING DATASET")
print("="*60)

ds = load_dataset("rootstrap-org/waste-classifier", split="train")
print("Dataset loaded successfully!")
print(f"Total samples: {len(ds)}")

# Get class names
class_names = ds.features['label'].names
num_classes = len(class_names)
print(f"\nNumber of classes: {num_classes}")
print(f"Classes: {class_names}")

# Extract only labels for splitting. Read the label column directly:
# iterating `ds` row by row would materialise every full example (image
# included) just to pick out one integer.
labels = list(ds['label'])

# Check class distribution
unique, counts = np.unique(labels, return_counts=True)
print("\nClass Distribution:")
for cls_idx, count in zip(unique, counts):
    print(f" {class_names[cls_idx]}: {count} samples ({count/len(labels)*100:.2f}%)")

# Split dataset: 70% train, 15% val, 15% test
print("\n" + "="*60)
print("SPLITTING DATASET")
print("="*60)

indices = np.arange(len(ds))
# First cut: 70% train vs. 30% held out, keeping class proportions
train_idx, temp_idx, y_train, y_temp = train_test_split(
    indices, labels, test_size=0.3, random_state=42, stratify=labels
)
# Second cut: halve the held-out part into validation and test
val_idx, test_idx, y_val, y_test = train_test_split(
    temp_idx, y_temp, test_size=0.5, random_state=42, stratify=y_temp
)

print(f"Train set: {len(train_idx)} samples")
print(f"Validation set: {len(val_idx)} samples")
print(f"Test set: {len(test_idx)} samples")
124
+
125
# Per-class loss weights that counter class imbalance in the training split
_balanced = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(y_train),
    y=y_train,
)
class_weights = torch.as_tensor(_balanced, dtype=torch.float32).to(CONFIG['device'])
print(f"\nClass weights (for imbalance): {class_weights.cpu().numpy()}")

# One lazy dataset per split; only the training split gets augmentation
train_dataset, val_dataset, test_dataset = (
    TrashDatasetLazy(ds, split_idx, transform=split_transform)
    for split_idx, split_transform in (
        (train_idx, train_transform),
        (val_idx, val_transform),
        (test_idx, val_transform),
    )
)

# Settings shared by all three loaders; only `shuffle` differs per split
_loader_kwargs = dict(
    batch_size=CONFIG['batch_size'],
    num_workers=CONFIG['num_workers'],
    pin_memory=True,
)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)
151
+
152
# Build the classifier on top of a pretrained EfficientNetV2-S
print("\n" + "="*60)
print("BUILDING MODEL")
print("="*60)

model = models.efficientnet_v2_s(weights='IMAGENET1K_V1')

# Freeze everything except the last 20 parameter tensors of the loaded network.
# This runs before the head is replaced, so the new head stays trainable.
for param in list(model.parameters())[:-20]:
    param.requires_grad_(False)

# Replace the stock head with a two-layer classifier sized to our label set
num_features = model.classifier[1].in_features
model.classifier = nn.Sequential(
    nn.Dropout(p=0.3, inplace=True),
    nn.Linear(num_features, 512),
    nn.ReLU(),
    nn.Dropout(p=0.3),
    nn.Linear(512, num_classes),
)

model = model.to(CONFIG['device'])

_total_params = sum(p.numel() for p in model.parameters())
_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Model: EfficientNetV2-S (pretrained)")
print(f"Total parameters: {_total_params:,}")
print(f"Trainable parameters: {_trainable_params:,}")
177
+
178
# Loss function with class weights and optimizer
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = optim.AdamW(model.parameters(), lr=CONFIG['learning_rate'], weight_decay=0.01)
# Halve the learning rate after 3 scheduler steps without a lower monitored
# value. The `verbose` argument is deprecated in PyTorch (since 2.2), so it is
# omitted to keep this call valid if the argument is removed; read the current
# rate from `optimizer.param_groups` instead.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)
184
+
185
+ # Training and Validation Functions
186
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and report its statistics.

    Args:
        model: Network to train; switched to training mode here.
        loader: Iterable yielding ``(images, labels)`` tensor batches.
        criterion: Callable computing the loss from ``(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(epoch_loss, epoch_acc)``: the batch losses weighted by batch size
        and divided by the number of samples seen, and the accuracy in percent.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    progress = tqdm(loader, desc='Training')
    for batch_images, batch_labels in progress:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        # Standard step: clear old gradients, forward, backward, update
        optimizer.zero_grad()
        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item() * batch_images.size(0)
        predicted = logits.max(1)[1]  # index of the largest logit per row
        n_seen += batch_labels.size(0)
        n_correct += predicted.eq(batch_labels).sum().item()

        # Show the latest batch loss and the running accuracy on the bar
        progress.set_postfix({
            'loss': f'{batch_loss.item():.4f}',
            'acc': f'{100.*n_correct/n_seen:.2f}%',
        })

    return loss_sum / n_seen, 100. * n_correct / n_seen
212
+
213
def validate_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    Args:
        model: Network to evaluate; switched to eval mode here.
        loader: Iterable yielding ``(images, labels)`` tensor batches.
        criterion: Callable computing the loss from ``(outputs, labels)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(epoch_loss, epoch_acc)``: the batch losses weighted by batch size
        and divided by the number of samples seen, and the accuracy in percent.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    # Gradients are not needed for evaluation
    with torch.no_grad():
        progress = tqdm(loader, desc='Validation')
        for batch_images, batch_labels in progress:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images)
            batch_loss = criterion(logits, batch_labels)

            loss_sum += batch_loss.item() * batch_images.size(0)
            predicted = logits.max(1)[1]  # index of the largest logit per row
            n_seen += batch_labels.size(0)
            n_correct += predicted.eq(batch_labels).sum().item()

            # Show the latest batch loss and the running accuracy on the bar
            progress.set_postfix({
                'loss': f'{batch_loss.item():.4f}',
                'acc': f'{100.*n_correct/n_seen:.2f}%',
            })

    return loss_sum / n_seen, 100. * n_correct / n_seen
237
+
238
# Training Loop
print("\n" + "="*60)
print("TRAINING MODEL")
print("="*60)

# Per-epoch metrics, plotted after training
history = {
    'train_loss': [], 'train_acc': [],
    'val_loss': [], 'val_acc': []
}

# Start below any reachable accuracy so the first epoch always writes a
# checkpoint. With 0.0 and a strict `>` comparison, a run whose validation
# accuracy stays at 0% would never save, and the later `torch.load` would
# fail or pick up a stale file from a previous run.
best_val_acc = float('-inf')
patience_counter = 0

for epoch in range(CONFIG['num_epochs']):
    print(f"\nEpoch {epoch+1}/{CONFIG['num_epochs']}")
    print("-" * 60)

    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, CONFIG['device'])
    val_loss, val_acc = validate_epoch(model, val_loader, criterion, CONFIG['device'])

    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    print(f"\nTrain Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

    # The scheduler monitors validation loss; report any change it makes to
    # the learning rate explicitly instead of relying on scheduler output.
    lr_before = optimizer.param_groups[0]['lr']
    scheduler.step(val_loss)
    lr_after = optimizer.param_groups[0]['lr']
    if lr_after != lr_before:
        print(f"Learning rate changed: {lr_before:.2e} -> {lr_after:.2e}")

    # Save best model (selected on validation accuracy)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'class_names': class_names
        }, CONFIG['model_save_path'])
        print(f"✓ Model saved! (Val Acc: {val_acc:.2f}%)")
        patience_counter = 0
    else:
        patience_counter += 1
        print(f"No improvement ({patience_counter}/{CONFIG['patience']})")

    # Stop once accuracy has failed to improve for `patience` epochs in a row
    if patience_counter >= CONFIG['patience']:
        print("\nEarly stopping triggered!")
        break
287
+
288
# Plot the per-epoch loss and accuracy curves side by side and save them
print("\n" + "="*60)
print("SAVING TRAINING GRAPHS")
print("="*60)

fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# One entry per panel: (history key suffix, legend suffix, y label, title)
_panels = (
    ('loss', 'Loss', 'Loss', 'Training and Validation Loss'),
    ('acc', 'Acc', 'Accuracy (%)', 'Training and Validation Accuracy'),
)
for _ax, (_key, _short, _ylabel, _title) in zip(axes, _panels):
    _ax.plot(history[f'train_{_key}'], label=f'Train {_short}', marker='o')
    _ax.plot(history[f'val_{_key}'], label=f'Val {_short}', marker='s')
    _ax.set_xlabel('Epoch')
    _ax.set_ylabel(_ylabel)
    _ax.set_title(_title)
    _ax.legend()
    _ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_history.png', dpi=300, bbox_inches='tight')
print("✓ Training graphs saved as 'training_history.png'")
316
+
317
# Load Best Model and Test
print("\n" + "="*60)
print("LOADING BEST MODEL AND TESTING")
print("="*60)

# map_location places the stored tensors on the configured device, so the
# checkpoint also loads on a machine without the device it was saved from.
checkpoint = torch.load(CONFIG['model_save_path'], map_location=CONFIG['device'])
model.load_state_dict(checkpoint['model_state_dict'])
print(f"✓ Loaded best model from epoch {checkpoint['epoch']+1}")
print(f" Best validation accuracy: {checkpoint['val_acc']:.2f}%")

# Test the model: collect predictions and ground truth for the whole test split
model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in tqdm(test_loader, desc='Testing'):
        images = images.to(CONFIG['device'])
        outputs = model(images)
        _, predicted = outputs.max(1)

        # Labels were never moved off the CPU; predictions are copied back
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.numpy())

# Calculate test accuracy (percentage of exact matches)
test_acc = 100. * np.sum(np.array(all_preds) == np.array(all_labels)) / len(all_labels)
print(f"\n{'='*60}")
print(f"TEST SET ACCURACY: {test_acc:.2f}%")
print(f"{'='*60}")
346
+
347
# Classification Report
print("\n" + "="*60)
print("CLASSIFICATION REPORT")
print("="*60)
# Pin the label set to every known class id. Without it the metrics infer the
# classes from the data, so a class missing from both labels and predictions
# would no longer line up with `class_names` (report error, shifted ticks).
class_ids = list(range(num_classes))
print(classification_report(
    all_labels, all_preds,
    labels=class_ids, target_names=class_names, digits=4,
))

# Confusion Matrix (rows: true class, columns: predicted class)
cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names)
plt.title('Confusion Matrix - Test Set')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig('confusion_matrix.png', dpi=300, bbox_inches='tight')
print("\n✓ Confusion matrix saved as 'confusion_matrix.png'")

# Final summary of the produced artifacts
print("\n" + "="*60)
print("TRAINING COMPLETE!")
print("="*60)
print(f"✓ Best model saved: {CONFIG['model_save_path']}")
print("✓ Training history: training_history.png")
print("✓ Confusion matrix: confusion_matrix.png")
print(f"✓ Final test accuracy: {test_acc:.2f}%")