|  | import os | 
					
						
						|  | from PIL import Image | 
					
						
						|  | from torch.utils.data import Dataset, DataLoader | 
					
						
						|  | from torchvision import transforms | 
					
						
						|  |  | 
					
						
						|  | class TinyImageNetDataset(Dataset): | 
					
						
						|  | def __init__(self, root_dir, transform=None, train=True): | 
					
						
						|  | self.root_dir = root_dir | 
					
						
						|  | self.transform = transform | 
					
						
						|  | self.image_paths = [] | 
					
						
						|  |  | 
					
						
						|  | if train: | 
					
						
						|  |  | 
					
						
						|  | train_dir = os.path.join(root_dir, 'train') | 
					
						
						|  | for cls in os.listdir(train_dir): | 
					
						
						|  | cls_dir = os.path.join(train_dir, cls, 'images') | 
					
						
						|  | for img_name in os.listdir(cls_dir): | 
					
						
						|  | if img_name.endswith('.JPEG'): | 
					
						
						|  | self.image_paths.append(os.path.join(cls_dir, img_name)) | 
					
						
						|  | else: | 
					
						
						|  |  | 
					
						
						|  | val_dir = os.path.join(root_dir, 'val') | 
					
						
						|  | images_dir = os.path.join(val_dir, 'images') | 
					
						
						|  | for img_name in os.listdir(images_dir): | 
					
						
						|  | if img_name.endswith('.JPEG'): | 
					
						
						|  | self.image_paths.append(os.path.join(images_dir, img_name)) | 
					
						
						|  |  | 
					
						
						|  | def __len__(self): | 
					
						
						|  | return len(self.image_paths) | 
					
						
						|  |  | 
					
						
						|  | def __getitem__(self, idx): | 
					
						
						|  | img = Image.open(self.image_paths[idx]).convert('RGB') | 
					
						
						|  | if self.transform: | 
					
						
						|  | img = self.transform(img) | 
					
						
						|  | return img, 0 | 
					
						
						|  |  | 
					
						
						|  | def get_dataloaders(config): | 
					
						
						|  | transform = transforms.Compose([ | 
					
						
						|  | transforms.Resize(config.image_size), | 
					
						
						|  | transforms.RandomHorizontalFlip(), | 
					
						
						|  | transforms.ToTensor(), | 
					
						
						|  | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) | 
					
						
						|  | ]) | 
					
						
						|  |  | 
					
						
						|  | train_dataset = TinyImageNetDataset(config.dataset_path, transform=transform, train=True) | 
					
						
						|  | val_dataset = TinyImageNetDataset(config.dataset_path, transform=transform, train=False) | 
					
						
						|  |  | 
					
						
						|  | train_loader = DataLoader( | 
					
						
						|  | train_dataset, | 
					
						
						|  | batch_size=config.batch_size, | 
					
						
						|  | shuffle=True, | 
					
						
						|  | num_workers=config.num_workers, | 
					
						
						|  | pin_memory=True | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | val_loader = DataLoader( | 
					
						
						|  | val_dataset, | 
					
						
						|  | batch_size=config.batch_size, | 
					
						
						|  | shuffle=False, | 
					
						
						|  | num_workers=config.num_workers | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | return train_loader, val_loader |