Upload 3 files

- LaZSL-main/OP.py +3 -3
- LaZSL-main/load_OP.py +106 -38
- LaZSL-main/main_OP.py +2 -2

In short, per the hunks below: the commit threads a new global/region mixing weight `theta` through `OP_d` (OP.py) and its construction in main_OP.py, replaces hand-set values in load_OP.py with a per-model, per-dataset lookup table `get_params` (returning `alpha`, `theta`, `N`), renames the `places365` dataset branch to `place`, and drops the EuroSAT and DTD branches.
LaZSL-main/OP.py
CHANGED

@@ -47,7 +47,7 @@ class OP():
         return sim_op
 
 class OP_d():
-    def __init__(self, max_iter,gama,
+    def __init__(self, max_iter,gama,theta,constrain_type='const'):
         super(OP_d, self).__init__()
         self.max_iter= max_iter
         #self.M=M
@@ -56,7 +56,7 @@ class OP_d():
         self.gama= torch.tensor(gama,dtype=torch.half)
         self.zero= torch.tensor(-10,dtype=torch.half)
         self.constrain_type=constrain_type #['patch','att','const']
-        self.
+        self.theta=theta
         # self.b=b
 
     def Sinkhorn(self, K, u, v):
@@ -86,7 +86,7 @@ class OP_d():
         if is_cost_global:
             global_sim=sim[:,0,:].unsqueeze(1)
             region_sim=sim[:,1:,:]
-            sim_global=(1-self.
+            sim_global=(1-self.theta)*global_sim + (self.theta * region_sim)
             sim=region_sim
             self.M = sim_global.shape[1]
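For readers skimming the diff: the new `theta` is a mixing weight between the global image-text similarity (the first row of `sim`) and the region-level similarities. A minimal sketch of that convex combination on dummy tensors; the shapes and variable names here are illustrative assumptions, not taken verbatim from the repo:

import torch

# Illustrative shapes: batch b, 1 global + m region rows, n descriptors.
b, m, n = 2, 4, 5
sim = torch.rand(b, 1 + m, n)           # similarity matrix incl. the global row
theta = 0.8                             # the mixing weight added to OP_d

global_sim = sim[:, 0, :].unsqueeze(1)  # (b, 1, n) global image-text similarity
region_sim = sim[:, 1:, :]              # (b, m, n) region-level similarities

# theta interpolates between purely global and purely regional evidence,
# mirroring the line added at OP.py:89; broadcasting expands the global row.
sim_global = (1 - theta) * global_sim + theta * region_sim
print(sim_global.shape)                 # torch.Size([2, 4, 5])

Note that `sim_global` keeps one row per region, so the `self.M = sim_global.shape[1]` that follows in the diff picks up the region count, not 1.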
LaZSL-main/load_OP.py
CHANGED

@@ -8,9 +8,9 @@ import pathlib
 
 from torch.utils.data import DataLoader, Subset
 from torchvision import transforms
-
+
 from torchvision.datasets import ImageFolder
-
+
 from datasets import _transform, CUBDataset, random_crop
 from collections import OrderedDict
 from myclip import clip
@@ -26,32 +26,113 @@ from utils import (
 
 
 hparams = {}
+
+
+def get_params(model, dataset):
+    params = {}
+
+    if model == "ViT-B/16":
+        if dataset == "imagenet":
+            params['alpha'] = 0.6
+            params['theta'] = 0.8
+            params['N'] = 90
+        elif dataset == "cub":
+            params['alpha'] = 0.6
+            params['theta'] = 0.9
+            params['N'] = 90
+        elif dataset == "pets":
+            params['alpha'] = 0.6
+            params['theta'] = 0.2
+            params['N'] = 80
+        elif dataset == "food":
+            params['alpha'] = 0.6
+            params['theta'] = 0.8
+            params['N'] = 90
+        elif dataset == "place":
+            params['alpha'] = 0.4
+            params['theta'] = 0.8
+            params['N'] = 60
+        elif dataset == "imagenetv2":
+            params['alpha'] = 0.5
+            params['theta'] = 0.8
+            params['N'] = 70
+        elif dataset == "imagenet-r":
+            params['alpha'] = 0.6
+            params['theta'] = 0.8
+            params['N'] = 90
+        elif dataset == "imagenet-a":
+            params['alpha'] = 0.5
+            params['theta'] = 0.95
+            params['N'] = 90
+        elif dataset == "imagenet-s":
+            params['alpha'] = 0.6
+            params['theta'] = 0.8
+            params['N'] = 80
+    elif model == "ViT-B/32":
+        if dataset == "imagenet":
+            params['alpha'] = 0.6
+            params['theta'] = 0.8
+            params['N'] = 90
+        elif dataset == "cub":
+            params['alpha'] = 0.5
+            params['theta'] = 0.95
+            params['N'] = 80
+        elif dataset == "pets":
+            params['alpha'] = 0.6
+            params['theta'] = 0.9
+            params['N'] = 80
+        elif dataset == "food":
+            params['alpha'] = 0.6
+            params['theta'] = 0.9
+            params['N'] = 80
+        elif dataset == "place":
+            params['alpha'] = 0.6
+            params['theta'] = 0.9
+            params['N'] = 80
+    elif model == "ViT-L/14":
+        if dataset == "imagenet":
+            params['alpha'] = 0.6
+            params['theta'] = 0.8
+            params['N'] = 70
+        elif dataset == "cub":
+            params['alpha'] = 0.5
+            params['theta'] = 0.9
+            params['N'] = 80
+        elif dataset == "pets":
+            params['alpha'] = 0.6
+            params['theta'] = 0.8
+            params['N'] = 60
+        elif dataset == "food":
+            params['alpha'] = 0.6
+            params['theta'] = 0.9
+            params['N'] = 70
+        elif dataset == "place":
+            params['alpha'] = 0.4
+            params['theta'] = 0.9
+            params['N'] = 70
+
+    return params
 # hyperparameters
 
 hparams['model_size'] = "ViT-B/16"
 # Options:
-# ['RN50',
-# 'RN101',
-# 'RN50x4',
-# 'RN50x16',
-# 'RN50x64',
-# 'ViT-B/32',
+# ['ViT-B/32',
 # 'ViT-B/16',
-# 'ViT-L/14'
-# 'ViT-L/14@336px']
+# 'ViT-L/14']
 hparams['dataset'] = 'imagenet'
+params = get_params(hparams['model_size'], hparams['dataset'])
 hparams['max_iter'] = 100
-hparams['n_samples'] =
+hparams['n_samples'] = params['N']
 #for mix
-hparams['
+hparams['theta'] = params['theta']
 #for crop
-hparams['
+hparams['alpha'] = params['alpha']
 #for constrain
 hparams['gama'] = 0.0
 hparams['constrain_type'] = 'att' #['patch','att','const']
 
 
-hparams['batch_size'] =
+hparams['batch_size'] = 50
 hparams['device'] = "cuda:2"
 hparams['category_name_inclusion'] = 'prepend' #'append' 'prepend'
 
@@ -124,12 +205,11 @@ def custom_loader(path: str) -> torch.Tensor:
     img = datasets.folder.default_loader(path)
     # Process the image and generate additional augmented samples
    augmented_imgs = [processor(img)]
-    augmented_imgs.extend(processor(random_crop(img,alpha=hparams['
+    augmented_imgs.extend(processor(random_crop(img,alpha=hparams['alpha'])) for _ in range(n_samples))
     # Return a stacked tensor of all processed images
     return torch.stack(augmented_imgs)
 
 
-
 if hparams['dataset'] == 'imagenet':
     if hparams['dataset'] == 'imagenet':
         dsclass = ImageNet
@@ -146,6 +226,7 @@ if hparams['dataset'] == 'imagenet':
 elif hparams['dataset'] == 'imagenetv2':
     hparams['data_dir'] = pathlib.Path(IMAGENETV2_DIR)
     hparams['class_num'] = 1000
+
     mydataset = ImageNetV2Dataset(
         location=hparams['data_dir'],
         transform=None,
@@ -160,6 +241,7 @@ elif hparams['dataset'] == 'imagenet-r':
     hparams['data_dir'] = pathlib.Path(IMAGENETR_DIR)
     dsclass = ImageFolder
     hparams['class_num'] = 200
+
     mydataset = dsclass(
         hparams['data_dir'],
         transform=None,
@@ -172,6 +254,7 @@ elif hparams['dataset'] == 'imagenet-a':
     hparams['data_dir'] = pathlib.Path(IMAGENETA_DIR)
     dsclass = ImageFolder
     hparams['class_num'] = 200
+
     mydataset = dsclass(
         hparams['data_dir'],
         transform=None,
@@ -184,6 +267,7 @@ elif hparams['dataset'] == 'imagenet-s':
     hparams['data_dir'] = pathlib.Path(IMAGENETS_DIR)
     dsclass = ImageFolder
     hparams['class_num'] = 1000
+
     mydataset = dsclass(
         hparams['data_dir'],
         transform=None,
@@ -198,6 +282,8 @@ elif hparams['dataset'] == 'imagenet-s':
 elif hparams['dataset'] == 'cub':
     # load CUB dataset
     hparams['data_dir'] = pathlib.Path(CUB_DIR)
+
+
     mydataset = CUBDataset(hparams['data_dir'], train=False, transform=None, loader=custom_loader)
     classes_to_load = None #dataset.classes
     hparams['descriptor_fname'] = 'descriptors_cub'
@@ -205,15 +291,9 @@ elif hparams['dataset'] == 'cub':
 
 # I recommend using VISSL https://github.com/facebookresearch/vissl/blob/main/extra_scripts/README.md to download these
 
-
-    from extra_datasets.patching.eurosat import EuroSATVal
-    hparams['data_dir'] = pathlib.Path(EUROSAT_DIR)
-    dataset = EuroSATVal(location=hparams['data_dir'], preprocess=tfms)
-    dataset = dataset.test_dataset
-    hparams['descriptor_fname'] = 'descriptors_eurosat'
-    classes_to_load = None
+
 
-elif hparams['dataset'] == 'places365':
+elif hparams['dataset'] == 'place':
     hparams['class_num'] = 365
     hparams['data_dir'] = pathlib.Path(PLACES_DIR)
     mydataset = Places365(hparams['data_dir'], split='val', download=False, transform=None, loader=custom_loader)
@@ -222,7 +302,7 @@ elif hparams['dataset'] == 'places365':
     hparams['descriptor_fname'] = 'descriptors_places365'
     classes_to_load = None
 
-elif hparams['dataset'] == '
+elif hparams['dataset'] == 'food':
     hparams['data_dir'] = pathlib.Path(FOOD101_DIR)
     dsclass = ImageFolder
     hparams['class_num'] = 101
@@ -239,6 +319,7 @@ elif hparams['dataset'] == 'pets':
     hparams['data_dir'] = pathlib.Path(PETS_DIR)
     dsclass = ImageFolder
     hparams['class_num'] = 37
+
     mydataset = OxfordIIITPet(
         hparams['data_dir'],
         transform=None,
@@ -248,20 +329,7 @@ elif hparams['dataset'] == 'pets':
     hparams['descriptor_fname'] = 'descriptors_pets'
     classes_to_load = None
 
-elif hparams['dataset'] == 'dtd':
-    hparams['class_num'] = 47
-    hparams['data_dir'] = pathlib.Path(DTD_DIR)
-    mydataset = DTD(
-        hparams['data_dir'],
-        transform=None,
-        split="test",
-        loader=custom_loader,
-    )
 
-    hparams['descriptor_fname'] = 'descriptors_dtd'
-    classes_to_load = None
-
-
 
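A note on the new `get_params` table: it returns an empty dict for any (model, dataset) pair not listed, so the subsequent `params['N']` lookup in load_OP.py would surface as a `KeyError`. A hedged sketch of a defensive wrapper; the fallback values and the wrapper name are illustrative assumptions, not defaults from the repo:

# Sketch only: make unsupported (model, dataset) pairs fail loudly or fall back,
# instead of raising a bare KeyError later on params['N'].
DEFAULTS = {'alpha': 0.6, 'theta': 0.8, 'N': 80}  # illustrative, not from the repo

def get_params_safe(model: str, dataset: str) -> dict:
    params = get_params(model, dataset)  # the table-driven lookup above
    missing = [k for k in ('alpha', 'theta', 'N') if k not in params]
    if missing:
        print("get_params(%r, %r) missing %s; using defaults" % (model, dataset, missing))
        params = {**DEFAULTS, **params}
    return params

params = get_params_safe("ViT-B/16", "imagenet")  # -> {'alpha': 0.6, 'theta': 0.8, 'N': 90}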
LaZSL-main/main_OP.py
CHANGED

@@ -22,7 +22,7 @@ model.to(device)
 model.eval()
 model.requires_grad_(False)
 #op=OP(max_iter=hparams['max_iter'],M=49,N=5,n_cls=hparams['class_num'],b=bs)
-op_d=OP_d(max_iter=hparams['max_iter'], gama=hparams['gama'],constrain_type=hparams['constrain_type'],
+op_d=OP_d(max_iter=hparams['max_iter'], gama=hparams['gama'],constrain_type=hparams['constrain_type'],theta=hparams['theta'])
 
 print("Encoding descriptions...")
 
@@ -30,7 +30,7 @@ description_encodings = compute_description_encodings(model)
 
 label_encodings = compute_label_encodings(model)
 
-print("n_samples: %d \nalpha: %f \
+print("n_samples: %d \nalpha: %f \ntheta: %f" %(hparams['n_samples'],hparams['alpha'],hparams['theta']))
 print("constrain_type: %s " %(hparams['constrain_type']))
 
 print("Evaluating...")
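Finally, the `max_iter` threaded from `hparams` into `OP_d` here bounds the Sinkhorn iterations used to solve the entropic optimal-transport problem; the diff only shows the `Sinkhorn(self, K, u, v)` signature in OP.py. A generic textbook sketch of such a solver, assuming the standard kernel/marginal form rather than the repo's exact implementation:

import torch

def sinkhorn(K, u, v, max_iter=100):
    # Standard Sinkhorn: find diag(a) @ K @ diag(b) whose rows sum to u
    # and whose columns sum to v, by alternating marginal rescaling.
    a = torch.ones_like(u)
    for _ in range(max_iter):
        a = u / (K @ (v / (K.t() @ a)))    # implicit column scaling, then row scaling
    b = v / (K.t() @ a)                    # final column scaling
    return a.unsqueeze(1) * K * b.unsqueeze(0)

K = torch.exp(-torch.rand(4, 5))           # dummy positive kernel, e.g. exp(-cost/eps)
plan = sinkhorn(K, torch.full((4,), 0.25), torch.full((5,), 0.2))
print(plan.sum(dim=0))                     # ~tensor([0.2, 0.2, 0.2, 0.2, 0.2])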