KimingChen committed
Commit 1418963 · verified · 1 Parent(s): 92f5feb

Upload 3 files

Files changed (3):
  1. LaZSL-main/OP.py +3 -3
  2. LaZSL-main/load_OP.py +106 -38
  3. LaZSL-main/main_OP.py +2 -2
LaZSL-main/OP.py CHANGED
@@ -47,7 +47,7 @@ class OP():
         return sim_op
 
 class OP_d():
-    def __init__(self, max_iter, gama, alpha, constrain_type='const'):
+    def __init__(self, max_iter, gama, theta, constrain_type='const'):
         super(OP_d, self).__init__()
         self.max_iter = max_iter
         #self.M=M
@@ -56,7 +56,7 @@ class OP_d():
         self.gama = torch.tensor(gama, dtype=torch.half)
         self.zero = torch.tensor(-10, dtype=torch.half)
         self.constrain_type = constrain_type  #['patch','att','const']
-        self.alpha = alpha
+        self.theta = theta
         # self.b=b
 
     def Sinkhorn(self, K, u, v):
@@ -86,7 +86,7 @@ class OP_d():
         if is_cost_global:
             global_sim = sim[:, 0, :].unsqueeze(1)
             region_sim = sim[:, 1:, :]
-            sim_global = (1 - self.alpha) * global_sim + (self.alpha * region_sim)
+            sim_global = (1 - self.theta) * global_sim + (self.theta * region_sim)
             sim = region_sim
             self.M = sim_global.shape[1]
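For context on the rename: `theta` plays the same role the old `alpha` argument did in OP_d — it is the weight that blends the global (row-0) similarity into every region similarity before the optimal-transport step. Below is a minimal standalone sketch of that blend; `mix_global_region` is a hypothetical helper name, and the layout `(batch, 1 + num_regions, num_descriptors)` with the global token in row 0 is an assumption inferred from the slicing in the diff above.

```python
import torch

def mix_global_region(sim: torch.Tensor, theta: float) -> torch.Tensor:
    # sim: (batch, 1 + num_regions, num_descriptors); row 0 holds the global token
    global_sim = sim[:, 0, :].unsqueeze(1)   # (batch, 1, num_descriptors)
    region_sim = sim[:, 1:, :]               # (batch, num_regions, num_descriptors)
    # Convex combination; global_sim broadcasts across the region axis
    return (1 - theta) * global_sim + theta * region_sim

sim = torch.randn(2, 1 + 49, 12)                 # e.g. 49 patch regions, 12 descriptors
print(mix_global_region(sim, theta=0.8).shape)   # torch.Size([2, 49, 12])
```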
LaZSL-main/load_OP.py CHANGED
@@ -8,9 +8,9 @@ import pathlib
 
 from torch.utils.data import DataLoader, Subset
 from torchvision import transforms
-#from torchvision.datasets import ImageNet, ImageFolder, Places365
+
 from torchvision.datasets import ImageFolder
-#from imagenetv2_pytorch import ImageNetV2Dataset as ImageNetV2
+
 from datasets import _transform, CUBDataset, random_crop
 from collections import OrderedDict
 from myclip import clip
@@ -26,32 +26,113 @@ from utils import (
 
 
 hparams = {}
+
+
+def get_params(model, dataset):
+    params = {}
+
+    if model == "ViT-B/16":
+        if dataset == "imagenet":
+            params['alpha'] = 0.6
+            params['theta'] = 0.8
+            params['N'] = 90
+        elif dataset == "cub":
+            params['alpha'] = 0.6
+            params['theta'] = 0.9
+            params['N'] = 90
+        elif dataset == "pets":
+            params['alpha'] = 0.6
+            params['theta'] = 0.2
+            params['N'] = 80
+        elif dataset == "food":
+            params['alpha'] = 0.6
+            params['theta'] = 0.8
+            params['N'] = 90
+        elif dataset == "place":
+            params['alpha'] = 0.4
+            params['theta'] = 0.8
+            params['N'] = 60
+        elif dataset == "imagenetv2":
+            params['alpha'] = 0.5
+            params['theta'] = 0.8
+            params['N'] = 70
+        elif dataset == "imagenet-r":
+            params['alpha'] = 0.6
+            params['theta'] = 0.8
+            params['N'] = 90
+        elif dataset == "imagenet-a":
+            params['alpha'] = 0.5
+            params['theta'] = 0.95
+            params['N'] = 90
+        elif dataset == "imagenet-s":
+            params['alpha'] = 0.6
+            params['theta'] = 0.8
+            params['N'] = 80
+    elif model == "ViT-B/32":
+        if dataset == "imagenet":
+            params['alpha'] = 0.6
+            params['theta'] = 0.8
+            params['N'] = 90
+        elif dataset == "cub":
+            params['alpha'] = 0.5
+            params['theta'] = 0.95
+            params['N'] = 80
+        elif dataset == "pets":
+            params['alpha'] = 0.6
+            params['theta'] = 0.9
+            params['N'] = 80
+        elif dataset == "food":
+            params['alpha'] = 0.6
+            params['theta'] = 0.9
+            params['N'] = 80
+        elif dataset == "place":
+            params['alpha'] = 0.6
+            params['theta'] = 0.9
+            params['N'] = 80
+    elif model == "ViT-L/14":
+        if dataset == "imagenet":
+            params['alpha'] = 0.6
+            params['theta'] = 0.8
+            params['N'] = 70
+        elif dataset == "cub":
+            params['alpha'] = 0.5
+            params['theta'] = 0.9
+            params['N'] = 80
+        elif dataset == "pets":
+            params['alpha'] = 0.6
+            params['theta'] = 0.8
+            params['N'] = 60
+        elif dataset == "food":
+            params['alpha'] = 0.6
+            params['theta'] = 0.9
+            params['N'] = 70
+        elif dataset == "place":
+            params['alpha'] = 0.4
+            params['theta'] = 0.9
+            params['N'] = 70
+
+    return params
 # hyperparameters
 
 hparams['model_size'] = "ViT-B/16"
 # Options:
-# ['RN50',
-#  'RN101',
-#  'RN50x4',
-#  'RN50x16',
-#  'RN50x64',
-#  'ViT-B/32',
+# ['ViT-B/32',
 #  'ViT-B/16',
-#  'ViT-L/14',
-#  'ViT-L/14@336px']
+#  'ViT-L/14']
 hparams['dataset'] = 'imagenet'
+params = get_params(hparams['model_size'], hparams['dataset'])
 hparams['max_iter'] = 100
-hparams['n_samples'] = 90
+hparams['n_samples'] = params['N']
 #for mix
-hparams['alpha'] = 0.8
+hparams['theta'] = params['theta']
 #for crop
-hparams['alpha_crop'] = 0.6
+hparams['alpha'] = params['alpha']
 #for constrain
 hparams['gama'] = 0.0
 hparams['constrain_type'] = 'att' #['patch','att','const']
 
 
-hparams['batch_size'] = 1
+hparams['batch_size'] = 50
 hparams['device'] = "cuda:2"
 hparams['category_name_inclusion'] = 'prepend' #'append' 'prepend'
 
@@ -124,12 +205,11 @@ def custom_loader(path: str) -> torch.Tensor:
     img = datasets.folder.default_loader(path)
     # Process the image and generate additional augmented samples
    augmented_imgs = [processor(img)]
-    augmented_imgs.extend(processor(random_crop(img, alpha=hparams['alpha_crop'])) for _ in range(n_samples))
+    augmented_imgs.extend(processor(random_crop(img, alpha=hparams['alpha'])) for _ in range(n_samples))
     # Return a stacked tensor of all processed images
     return torch.stack(augmented_imgs)
 
 
-
 if hparams['dataset'] == 'imagenet':
     if hparams['dataset'] == 'imagenet':
         dsclass = ImageNet
@@ -146,6 +226,7 @@ if hparams['dataset'] == 'imagenet':
 elif hparams['dataset'] == 'imagenetv2':
     hparams['data_dir'] = pathlib.Path(IMAGENETV2_DIR)
     hparams['class_num'] = 1000
+
     mydataset = ImageNetV2Dataset(
         location=hparams['data_dir'],
         transform=None,
@@ -160,6 +241,7 @@ elif hparams['dataset'] == 'imagenet-r':
     hparams['data_dir'] = pathlib.Path(IMAGENETR_DIR)
     dsclass = ImageFolder
     hparams['class_num'] = 200
+
     mydataset = dsclass(
         hparams['data_dir'],
         transform=None,
@@ -172,6 +254,7 @@ elif hparams['dataset'] == 'imagenet-a':
     hparams['data_dir'] = pathlib.Path(IMAGENETA_DIR)
     dsclass = ImageFolder
     hparams['class_num'] = 200
+
     mydataset = dsclass(
         hparams['data_dir'],
         transform=None,
@@ -184,6 +267,7 @@ elif hparams['dataset'] == 'imagenet-s':
     hparams['data_dir'] = pathlib.Path(IMAGENETS_DIR)
     dsclass = ImageFolder
     hparams['class_num'] = 1000
+
     mydataset = dsclass(
         hparams['data_dir'],
         transform=None,
@@ -198,6 +282,8 @@ elif hparams['dataset'] == 'imagenet-s':
 elif hparams['dataset'] == 'cub':
     # load CUB dataset
     hparams['data_dir'] = pathlib.Path(CUB_DIR)
+
+
     mydataset = CUBDataset(hparams['data_dir'], train=False, transform=None, loader=custom_loader)
     classes_to_load = None #dataset.classes
     hparams['descriptor_fname'] = 'descriptors_cub'
@@ -205,15 +291,9 @@ elif hparams['dataset'] == 'cub':
 
 # I recommend using VISSL https://github.com/facebookresearch/vissl/blob/main/extra_scripts/README.md to download these
 
-elif hparams['dataset'] == 'eurosat':
-    from extra_datasets.patching.eurosat import EuroSATVal
-    hparams['data_dir'] = pathlib.Path(EUROSAT_DIR)
-    dataset = EuroSATVal(location=hparams['data_dir'], preprocess=tfms)
-    dataset = dataset.test_dataset
-    hparams['descriptor_fname'] = 'descriptors_eurosat'
-    classes_to_load = None
+
 
-elif hparams['dataset'] == 'places365':
+elif hparams['dataset'] == 'place':
     hparams['class_num'] = 365
     hparams['data_dir'] = pathlib.Path(PLACES_DIR)
     mydataset = Places365(hparams['data_dir'], split='val', download=False, transform=None, loader=custom_loader)
@@ -222,7 +302,7 @@ elif hparams['dataset'] == 'place':
     hparams['descriptor_fname'] = 'descriptors_places365'
     classes_to_load = None
 
-elif hparams['dataset'] == 'food101':
+elif hparams['dataset'] == 'food':
     hparams['data_dir'] = pathlib.Path(FOOD101_DIR)
     dsclass = ImageFolder
     hparams['class_num'] = 101
@@ -239,6 +319,7 @@ elif hparams['dataset'] == 'pets':
     hparams['data_dir'] = pathlib.Path(PETS_DIR)
     dsclass = ImageFolder
     hparams['class_num'] = 37
+
     mydataset = OxfordIIITPet(
         hparams['data_dir'],
         transform=None,
@@ -248,20 +329,7 @@ elif hparams['dataset'] == 'pets':
     hparams['descriptor_fname'] = 'descriptors_pets'
     classes_to_load = None
 
-elif hparams['dataset'] == 'dtd':
-    hparams['class_num'] = 47
-    hparams['data_dir'] = pathlib.Path(DTD_DIR)
-    mydataset = DTD(
-        hparams['data_dir'],
-        transform=None,
-        split="test",
-        loader=custom_loader,
-    )
 
-    hparams['descriptor_fname'] = 'descriptors_dtd'
-    classes_to_load = None
-
-
 
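The net effect in load_OP.py is that the three tuned quantities — crop parameter `alpha`, mixing weight `theta`, and sample count `N` — are now looked up per model/dataset pair instead of being hard-coded. A minimal sketch of the new wiring, assuming it runs in load_OP.py's own namespace where `get_params` (the function added above) and `hparams` are defined; the comments reflect how the diff uses each key:

```python
params = get_params("ViT-B/16", "cub")   # -> {'alpha': 0.6, 'theta': 0.9, 'N': 90}

hparams['n_samples'] = params['N']       # random crops per image in custom_loader
hparams['theta'] = params['theta']       # global/region similarity mix passed to OP_d
hparams['alpha'] = params['alpha']       # crop parameter forwarded to random_crop
```

One caveat worth noting: `get_params` returns an empty dict for model/dataset pairs it does not list (e.g. ViT-B/32 with imagenet-r), so the `params['N']` lookup would raise a KeyError for unsupported combinations.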
LaZSL-main/main_OP.py CHANGED
@@ -22,7 +22,7 @@ model.to(device)
 model.eval()
 model.requires_grad_(False)
 #op=OP(max_iter=hparams['max_iter'],M=49,N=5,n_cls=hparams['class_num'],b=bs)
-op_d = OP_d(max_iter=hparams['max_iter'], gama=hparams['gama'], constrain_type=hparams['constrain_type'], alpha=hparams['alpha'])
+op_d = OP_d(max_iter=hparams['max_iter'], gama=hparams['gama'], constrain_type=hparams['constrain_type'], theta=hparams['theta'])
 
 print("Encoding descriptions...")
 
@@ -30,7 +30,7 @@ description_encodings = compute_description_encodings(model)
 
 label_encodings = compute_label_encodings(model)
 
-print("n_samples: %d \nalpha: %f \nalpha_crop: %f" % (hparams['n_samples'], hparams['alpha'], hparams['alpha_crop']))
+print("n_samples: %d \nalpha: %f \ntheta: %f" % (hparams['n_samples'], hparams['alpha'], hparams['theta']))
 print("constrain_type: %s " % (hparams['constrain_type']))
 
 print("Evaluating...")
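With the rename threaded through, main_OP.py now constructs OP_d with the mixing weight under its new name. A minimal sketch of the call site with the ViT-B/16 imagenet defaults from load_OP.py filled in literally (assuming OP.py above is on the import path; the literal values stand in for the hparams lookups):

```python
from OP import OP_d

op_d = OP_d(
    max_iter=100,          # hparams['max_iter']: Sinkhorn iteration cap
    gama=0.0,              # hparams['gama']: constraint strength
    theta=0.8,             # hparams['theta']: global/region mixing weight
    constrain_type='att',  # hparams['constrain_type'], one of ['patch','att','const']
)
```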