# File: mix-bt/ssl-sota/datasets/tiny_in.py
# wgcban: "Upload 98 files" (commit 803ef9e), 883 bytes.
from torchvision.datasets import ImageFolder
import torchvision.transforms as T
from .transforms import MultiSample, aug_transform
from .base import BaseDataset
def base_transform():
    """Build the augmentation-free preprocessing pipeline.

    Returns:
        A ``T.Compose`` that applies ``T.ToTensor()`` and then
        ``T.Normalize`` with fixed three-channel mean/std constants.
    """
    # NOTE(review): presumably the Tiny ImageNet channel statistics -- confirm.
    mean = (0.480, 0.448, 0.398)
    std = (0.277, 0.269, 0.282)
    steps = [T.ToTensor(), T.Normalize(mean, std)]
    return T.Compose(steps)
class TinyImageNet(BaseDataset):
    """Tiny ImageNet splits loaded through ``ImageFolder``.

    Provides three dataset factories: an augmented training set
    (``ds_train``), the training images with only the base transform
    (``ds_clf``), and the validation images with only the base transform
    (``ds_test``).
    """

    # Directory containing the ``train`` and ``val`` sub-folders. Kept in one
    # place (instead of being repeated in every method) so a subclass or a
    # caller can point the dataset at a different location.
    data_root = "/mnt/store/wbandar1/datasets/tiny-imagenet-200"

    def _folder(self, split, transform):
        """Return an ``ImageFolder`` for ``<data_root>/<split>``."""
        return ImageFolder(root=f"{self.data_root}/{split}", transform=transform)

    def ds_train(self):
        """Return the ``train`` split wrapped in the augmentation transform.

        The transform is ``aug_transform(64, base_transform, self.aug_cfg)``
        wrapped in ``MultiSample`` with ``n=self.aug_cfg.num_samples``.
        """
        # NOTE(review): 64 is presumably the output image size, and
        # MultiSample presumably yields `n` augmented views per image --
        # confirm against .transforms.
        t = MultiSample(
            aug_transform(64, base_transform, self.aug_cfg),
            n=self.aug_cfg.num_samples,
        )
        return self._folder("train", t)

    def ds_clf(self):
        """Return the ``train`` split with only the base transform applied."""
        # NOTE(review): presumably used to fit the evaluation classifier --
        # confirm against BaseDataset.
        return self._folder("train", base_transform())

    def ds_test(self):
        """Return the ``val`` split with only the base transform applied."""
        return self._folder("val", base_transform())