Upload 3 files

Changed files:
- LaZSL-main/OP.py (+3 -3)
- LaZSL-main/load_OP.py (+106 -38)
- LaZSL-main/main_OP.py (+2 -2)
LaZSL-main/OP.py
CHANGED
@@ -47,7 +47,7 @@ class OP():
         return sim_op
 
 class OP_d():
-    def __init__(self, max_iter,gama,
+    def __init__(self, max_iter,gama,theta,constrain_type='const'):
        super(OP_d, self).__init__()
        self.max_iter= max_iter
        #self.M=M
@@ -56,7 +56,7 @@
        self.gama= torch.tensor(gama,dtype=torch.half)
        self.zero= torch.tensor(-10,dtype=torch.half)
        self.constrain_type=constrain_type #['patch','att','const']
-       self.
+       self.theta=theta
        # self.b=b
 
    def Sinkhorn(self, K, u, v):
@@ -86,7 +86,7 @@
        if is_cost_global:
            global_sim=sim[:,0,:].unsqueeze(1)
            region_sim=sim[:,1:,:]
-           sim_global=(1-self.
+           sim_global=(1-self.theta)*global_sim + (self.theta * region_sim)
            sim=region_sim
            self.M = sim_global.shape[1]
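The substance of this change: OP_d now takes a theta hyperparameter, stores it, and uses it to mix the global (whole-image) similarity with the per-region similarities as a convex combination. A minimal sketch of the blend, assuming sim has shape (batch, 1 + num_regions, num_classes) with the global similarity at index 0 (the shape is an assumption; the diff only shows the indexing):

import torch

def blend_global_region(sim: torch.Tensor, theta: float) -> torch.Tensor:
    # Split the whole-image similarity (index 0) from the region similarities.
    global_sim = sim[:, 0, :].unsqueeze(1)   # (batch, 1, num_classes)
    region_sim = sim[:, 1:, :]               # (batch, num_regions, num_classes)
    # Convex combination; global_sim broadcasts over the region axis, so the
    # result has shape (batch, num_regions, num_classes). theta = 0 keeps only
    # the global score, theta = 1 keeps only the region scores.
    return (1 - theta) * global_sim + theta * region_sim

With the per-pair settings introduced in load_OP.py below (theta between 0.8 and 0.95 for most pairs), the region term dominates the mix.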
LaZSL-main/load_OP.py
CHANGED
@@ -8,9 +8,9 @@ import pathlib
 
 from torch.utils.data import DataLoader, Subset
 from torchvision import transforms
-
+
 from torchvision.datasets import ImageFolder
-
+
 from datasets import _transform, CUBDataset, random_crop
 from collections import OrderedDict
 from myclip import clip
@@ -26,32 +26,113 @@ from utils import (
 
 
 hparams = {}
+
+
+def get_params(model, dataset):
+    params = {}
+
+    if model == "ViT-B/16":
+        if dataset == "imagenet":
+            params['alpha'] = 0.6
+            params['theta'] = 0.8
+            params['N'] = 90
+        elif dataset == "cub":
+            params['alpha'] = 0.6
+            params['theta'] = 0.9
+            params['N'] = 90
+        elif dataset == "pets":
+            params['alpha'] = 0.6
+            params['theta'] = 0.2
+            params['N'] = 80
+        elif dataset == "food":
+            params['alpha'] = 0.6
+            params['theta'] = 0.8
+            params['N'] = 90
+        elif dataset == "place":
+            params['alpha'] = 0.4
+            params['theta'] = 0.8
+            params['N'] = 60
+        elif dataset == "imagenetv2":
+            params['alpha'] = 0.5
+            params['theta'] = 0.8
+            params['N'] = 70
+        elif dataset == "imagenet-r":
+            params['alpha'] = 0.6
+            params['theta'] = 0.8
+            params['N'] = 90
+        elif dataset == "imagenet-a":
+            params['alpha'] = 0.5
+            params['theta'] = 0.95
+            params['N'] = 90
+        elif dataset == "imagenet-s":
+            params['alpha'] = 0.6
+            params['theta'] = 0.8
+            params['N'] = 80
+    elif model == "ViT-B/32":
+        if dataset == "imagenet":
+            params['alpha'] = 0.6
+            params['theta'] = 0.8
+            params['N'] = 90
+        elif dataset == "cub":
+            params['alpha'] = 0.5
+            params['theta'] = 0.95
+            params['N'] = 80
+        elif dataset == "pets":
+            params['alpha'] = 0.6
+            params['theta'] = 0.9
+            params['N'] = 80
+        elif dataset == "food":
+            params['alpha'] = 0.6
+            params['theta'] = 0.9
+            params['N'] = 80
+        elif dataset == "place":
+            params['alpha'] = 0.6
+            params['theta'] = 0.9
+            params['N'] = 80
+    elif model == "ViT-L/14":
+        if dataset == "imagenet":
+            params['alpha'] = 0.6
+            params['theta'] = 0.8
+            params['N'] = 70
+        elif dataset == "cub":
+            params['alpha'] = 0.5
+            params['theta'] = 0.9
+            params['N'] = 80
+        elif dataset == "pets":
+            params['alpha'] = 0.6
+            params['theta'] = 0.8
+            params['N'] = 60
+        elif dataset == "food":
+            params['alpha'] = 0.6
+            params['theta'] = 0.9
+            params['N'] = 70
+        elif dataset == "place":
+            params['alpha'] = 0.4
+            params['theta'] = 0.9
+            params['N'] = 70
+
+    return params
 # hyperparameters
 
 hparams['model_size'] = "ViT-B/16"
 # Options:
-# ['
-# 'RN101',
-# 'RN50x4',
-# 'RN50x16',
-# 'RN50x64',
-# 'ViT-B/32',
+# ['ViT-B/32',
 # 'ViT-B/16',
-# 'ViT-L/14'
-# 'ViT-L/14@336px']
+# 'ViT-L/14']
 hparams['dataset'] = 'imagenet'
+params = get_params(hparams['model_size'], hparams['dataset'])
 hparams['max_iter'] = 100
-hparams['n_samples'] =
+hparams['n_samples'] = params['N']
 #for mix
-hparams['
+hparams['theta'] = params['theta']
 #for crop
-hparams['
+hparams['alpha'] = params['alpha']
 #for constrain
 hparams['gama'] = 0.0
 hparams['constrain_type'] = 'att' #['patch','att','const']
 
 
-hparams['batch_size'] =
+hparams['batch_size'] = 50
 hparams['device'] = "cuda:2"
 hparams['category_name_inclusion'] = 'prepend' #'append' 'prepend'
 
@@ -124,12 +205,11 @@ def custom_loader(path: str) -> torch.Tensor:
     img = datasets.folder.default_loader(path)
     # Process the image and generate additional augmented samples
     augmented_imgs = [processor(img)]
-    augmented_imgs.extend(processor(random_crop(img,alpha=hparams['
+    augmented_imgs.extend(processor(random_crop(img,alpha=hparams['alpha'])) for _ in range(n_samples))
     # Return a stacked tensor of all processed images
     return torch.stack(augmented_imgs)
 
 
-
 if hparams['dataset'] == 'imagenet':
     if hparams['dataset'] == 'imagenet':
         dsclass = ImageNet
@@ -146,6 +226,7 @@ if hparams['dataset'] == 'imagenet':
 elif hparams['dataset'] == 'imagenetv2':
     hparams['data_dir'] = pathlib.Path(IMAGENETV2_DIR)
     hparams['class_num'] = 1000
+
     mydataset = ImageNetV2Dataset(
         location=hparams['data_dir'],
         transform=None,
@@ -160,6 +241,7 @@ elif hparams['dataset'] == 'imagenet-r':
     hparams['data_dir'] = pathlib.Path(IMAGENETR_DIR)
     dsclass = ImageFolder
     hparams['class_num'] = 200
+
     mydataset = dsclass(
         hparams['data_dir'],
         transform=None,
@@ -172,6 +254,7 @@ elif hparams['dataset'] == 'imagenet-a':
     hparams['data_dir'] = pathlib.Path(IMAGENETA_DIR)
     dsclass = ImageFolder
     hparams['class_num'] = 200
+
     mydataset = dsclass(
         hparams['data_dir'],
         transform=None,
@@ -184,6 +267,7 @@ elif hparams['dataset'] == 'imagenet-s':
     hparams['data_dir'] = pathlib.Path(IMAGENETS_DIR)
     dsclass = ImageFolder
     hparams['class_num'] = 1000
+
     mydataset = dsclass(
         hparams['data_dir'],
         transform=None,
@@ -198,6 +282,8 @@ elif hparams['dataset'] == 'imagenet-s':
 elif hparams['dataset'] == 'cub':
     # load CUB dataset
     hparams['data_dir'] = pathlib.Path(CUB_DIR)
+
+
     mydataset = CUBDataset(hparams['data_dir'], train=False, transform=None, loader=custom_loader)
     classes_to_load = None #dataset.classes
     hparams['descriptor_fname'] = 'descriptors_cub'
@@ -205,15 +291,9 @@ elif hparams['dataset'] == 'cub':
 
 # I recommend using VISSL https://github.com/facebookresearch/vissl/blob/main/extra_scripts/README.md to download these
 
-
-    from extra_datasets.patching.eurosat import EuroSATVal
-    hparams['data_dir'] = pathlib.Path(EUROSAT_DIR)
-    dataset = EuroSATVal(location=hparams['data_dir'], preprocess=tfms)
-    dataset = dataset.test_dataset
-    hparams['descriptor_fname'] = 'descriptors_eurosat'
-    classes_to_load = None
+
 
-elif hparams['dataset'] == 'places365':
+elif hparams['dataset'] == 'place':
     hparams['class_num'] = 365
     hparams['data_dir'] = pathlib.Path(PLACES_DIR)
     mydataset = Places365(hparams['data_dir'], split='val', download=False, transform=None, loader=custom_loader)
@@ -222,7 +302,7 @@ elif hparams['dataset'] == 'places365':
     hparams['descriptor_fname'] = 'descriptors_places365'
     classes_to_load = None
 
-elif hparams['dataset'] == 'food101':
+elif hparams['dataset'] == 'food':
     hparams['data_dir'] = pathlib.Path(FOOD101_DIR)
     dsclass = ImageFolder
     hparams['class_num'] = 101
@@ -239,6 +319,7 @@ elif hparams['dataset'] == 'pets':
     hparams['data_dir'] = pathlib.Path(PETS_DIR)
     dsclass = ImageFolder
     hparams['class_num'] = 37
+
     mydataset = OxfordIIITPet(
         hparams['data_dir'],
         transform=None,
@@ -248,20 +329,7 @@ elif hparams['dataset'] == 'pets':
     hparams['descriptor_fname'] = 'descriptors_pets'
     classes_to_load = None
 
-elif hparams['dataset'] == 'dtd':
-    hparams['class_num'] = 47
-    hparams['data_dir'] = pathlib.Path(DTD_DIR)
-    mydataset = DTD(
-        hparams['data_dir'],
-        transform=None,
-        split="test",
-        loader=custom_loader,
-    )
 
-    hparams['descriptor_fname'] = 'descriptors_dtd'
-    classes_to_load = None
-
-
 
 
 
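Two notes on the substance of this diff. First, get_params replaces hand-edited constants with the tuned values of alpha (random-crop scale), theta (global/region mix) and N (crops per image) for each backbone/dataset pair; for a pair it does not list it returns an empty dict, so the later params['N'] lookup fails with a KeyError. A table-driven sketch of the same lookup (hypothetical refactor; only a few of the values from the diff are copied in):

PARAMS = {
    # (model, dataset) -> tuned hyperparameters, values taken from get_params above
    ("ViT-B/16", "imagenet"): {'alpha': 0.6, 'theta': 0.8, 'N': 90},
    ("ViT-B/16", "cub"):      {'alpha': 0.6, 'theta': 0.9, 'N': 90},
    ("ViT-B/16", "pets"):     {'alpha': 0.6, 'theta': 0.2, 'N': 80},
    ("ViT-L/14", "imagenet"): {'alpha': 0.6, 'theta': 0.8, 'N': 70},
}

def get_params(model, dataset):
    # An untuned pair now fails here, at lookup time, rather than later at params['N'].
    return PARAMS[(model, dataset)]

Second, custom_loader now threads the tuned alpha into random_crop and returns a (1 + n_samples, C, H, W) stack per image: the full processed image first, then n_samples random crops. A self-contained sketch of that behavior, assuming processor maps a PIL image to a fixed-size (C, H, W) tensor and random_crop(img, alpha=...) returns a cropped PIL image (both are the repo's own helpers; their signatures are inferred from the diff):

import torch

def load_views(img, processor, random_crop, alpha, n_samples):
    views = [processor(img)]                               # full image first
    views.extend(processor(random_crop(img, alpha=alpha))  # then n_samples random crops
                 for _ in range(n_samples))
    return torch.stack(views)                              # (1 + n_samples, C, H, W)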
LaZSL-main/main_OP.py
CHANGED
@@ -22,7 +22,7 @@ model.to(device)
 model.eval()
 model.requires_grad_(False)
 #op=OP(max_iter=hparams['max_iter'],M=49,N=5,n_cls=hparams['class_num'],b=bs)
-op_d=OP_d(max_iter=hparams['max_iter'], gama=hparams['gama'],constrain_type=hparams['constrain_type'],
+op_d=OP_d(max_iter=hparams['max_iter'], gama=hparams['gama'],constrain_type=hparams['constrain_type'],theta=hparams['theta'])
 
 print("Encoding descriptions...")
 
@@ -30,7 +30,7 @@ description_encodings = compute_description_encodings(model)
 
 label_encodings = compute_label_encodings(model)
 
-print("n_samples: %d \nalpha: %f \
+print("n_samples: %d \nalpha: %f \ntheta: %f" %(hparams['n_samples'],hparams['alpha'],hparams['theta']))
 print("constrain_type: %s " %(hparams['constrain_type']))
 
 print("Evaluating...")
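With all three files applied, theta flows from get_params through hparams into OP_d, and the repaired startup banner prints every tuned value. A quick check of the completed format string with the ViT-B/16 / imagenet values from load_OP.py (the hparams contents here are hypothetical, copied from get_params):

hparams = {'n_samples': 90, 'alpha': 0.6, 'theta': 0.8, 'constrain_type': 'att'}
print("n_samples: %d \nalpha: %f \ntheta: %f" % (hparams['n_samples'], hparams['alpha'], hparams['theta']))
print("constrain_type: %s " % (hparams['constrain_type']))
# Prints (some lines carry a trailing space from the format string):
# n_samples: 90
# alpha: 0.600000
# theta: 0.800000
# constrain_type: att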