Add files using upload-large-folder tool
This view is limited to 50 files because the commit contains too many changes.
- CION_ReIDZoo/CION_Finetune/FastReID/configs/Market1501/AGW_R50.yml +7 -0
- CION_ReIDZoo/CION_Finetune/FastReID/configs/VeRi/sbs_R50-ibn.yml +32 -0
- CION_ReIDZoo/CION_Finetune/FastReID/configs/VehicleID/bagtricks_R50-ibn.yml +35 -0
- CION_ReIDZoo/CION_Finetune/FastReID/tools/deploy/Caffe/ReadMe.md +21 -0
- CION_ReIDZoo/CION_Finetune/FastReID/tools/deploy/Caffe/caffe_net.py +139 -0
- CION_ReIDZoo/CION_Finetune/FastReID/tools/deploy/Caffe/layer_param.py +158 -0
- CION_ReIDZoo/CION_Finetune/FastReID/tools/deploy/README.md +160 -0
- CION_ReIDZoo/CION_Finetune/FastReID/tools/deploy/caffe_export.py +78 -0
- CION_ReIDZoo/CION_Finetune/FastReID/tools/deploy/caffe_inference.py +95 -0
- CION_ReIDZoo/CION_Finetune/FastReID/tools/deploy/onnx_export.py +146 -0
- CION_ReIDZoo/CION_Finetune/FastReID/tools/deploy/onnx_inference.py +85 -0
- CION_ReIDZoo/CION_Finetune/FastReID/tools/deploy/pytorch_to_caffe.py +747 -0
- CION_ReIDZoo/CION_Finetune/FastReID/tools/deploy/trt_export.py +82 -0
- open_clip/.github/workflows/ci.yml +121 -0
- open_clip/.github/workflows/clear-cache.yml +29 -0
- open_clip/.github/workflows/python-publish.yml +37 -0
- open_clip/docs/Interacting_with_open_clip.ipynb +0 -0
- open_clip/docs/clip_recall.png +0 -0
- open_clip/docs/openclip_retrieval_results.csv +122 -0
- open_clip/docs/script_examples/clipa/vit_b16/i50_t16_finetune.sh +27 -0
- open_clip/docs/script_examples/clipa/vit_b16/i50_t16_pretrain.sh +26 -0
- open_clip/docs/script_examples/clipa/vit_l16/i17_t16_finetune.sh +27 -0
- open_clip/docs/script_examples/clipa/vit_l16/i17_t16_pretrain.sh +26 -0
- open_clip/docs/script_examples/clipa/vit_l16/i37_t8_finetune.sh +27 -0
- open_clip/docs/script_examples/clipa/vit_l16/i37_t8_pretrain.sh +26 -0
- open_clip/docs/script_examples/clipav2/vit_h14/i257_t32_finetunex4.sh +32 -0
- open_clip/docs/script_examples/clipav2/vit_h14/i50_t8_pretrain.sh +30 -0
- open_clip/docs/script_examples/clipav2/vit_h14/i577_t32_finetunex1.sh +32 -0
- open_clip/docs/script_examples/stability_example.sh +60 -0
- open_clip/scripts/clipav1_vit_l16_i37_t8.sh +6 -0
- open_clip/scripts/clipav2_vit_h14_i84_224_336_cl32_gap_datacomp1b.sh +10 -0
- open_clip/scripts/h14_84_8_pretrain.sh +31 -0
- open_clip/src/open_clip/constants.py +11 -0
- open_clip/src/open_clip/convert.py +200 -0
- open_clip/src/open_clip/factory.py +578 -0
- open_clip/src/open_clip/hf_model.py +193 -0
- open_clip/src/open_clip/model.py +650 -0
- open_clip/src/open_clip/model_configs/EVA01-g-14-plus.json +18 -0
- open_clip/src/open_clip/model_configs/EVA02-B-16.json +18 -0
- open_clip/src/open_clip/model_configs/EVA02-E-14-plus.json +18 -0
- open_clip/src/open_clip/model_configs/EVA02-L-14-336.json +18 -0
- open_clip/src/open_clip/model_configs/RN50x16-quickgelu.json +22 -0
- open_clip/src/open_clip/model_configs/RN50x4-quickgelu.json +22 -0
- open_clip/src/open_clip/model_configs/RN50x4.json +21 -0
- open_clip/src/open_clip/model_configs/RN50x64.json +21 -0
- open_clip/src/open_clip/model_configs/ViT-B-16-SigLIP-256.json +29 -0
- open_clip/src/open_clip/model_configs/ViT-B-16-plus-240.json +16 -0
- open_clip/src/open_clip/model_configs/ViT-B-16-plus.json +16 -0
- open_clip/src/open_clip/model_configs/ViT-L-14-CLIPA-336.json +25 -0
- open_clip/src/open_clip/model_configs/ViT-L-14-CLIPA.json +25 -0
CION_ReIDZoo/CION_Finetune/FastReID/configs/Market1501/AGW_R50.yml
ADDED
@@ -0,0 +1,7 @@
+_BASE_: "../Base-AGW.yml"
+
+DATASETS:
+  NAMES: ("Market1501",)
+  TESTS: ("Market1501",)
+
+OUTPUT_DIR: "logs/market1501/agw_R50"
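Configs like the one above inherit a `_BASE_` file and then override dataset, solver, and test settings. As a rough illustration of how such a config is consumed, here is a minimal Python sketch mirroring the `setup_cfg()` helper used by the deploy scripts later in this diff; the checkpoint path passed as an override is a placeholder.

```python
# Minimal sketch (assumption: run from the FastReID root so the config path resolves).
from fastreid.config import get_cfg

cfg = get_cfg()
# The yml's `_BASE_` entry pulls in ../Base-AGW.yml before these overrides apply.
cfg.merge_from_file("configs/Market1501/AGW_R50.yml")
# Optional CLI-style "KEY VALUE" overrides; the weights path here is a placeholder.
cfg.merge_from_list(["MODEL.WEIGHTS", "logs/market1501/agw_R50/model_final.pth"])
cfg.freeze()
print(cfg.DATASETS.NAMES, cfg.OUTPUT_DIR)
```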
CION_ReIDZoo/CION_Finetune/FastReID/configs/VeRi/sbs_R50-ibn.yml
ADDED
@@ -0,0 +1,32 @@
+_BASE_: "../Base-Strongerbaseline.yml"
+
+INPUT:
+  SIZE_TRAIN: [256, 256]
+  SIZE_TEST: [256, 256]
+
+MODEL:
+  BACKBONE:
+    WITH_IBN: True
+
+SOLVER:
+  OPT: "SGD"
+  BASE_LR: 0.01
+  ETA_MIN_LR: 7.7e-5
+
+  IMS_PER_BATCH: 64
+  MAX_ITER: 60
+  DELAY_ITERS: 30
+  WARMUP_ITERS: 10
+  FREEZE_ITERS: 10
+
+  CHECKPOINT_PERIOD: 20
+
+DATASETS:
+  NAMES: ("VeRi",)
+  TESTS: ("VeRi",)
+
+TEST:
+  EVAL_PERIOD: 20
+  IMS_PER_BATCH: 128
+
+OUTPUT_DIR: "logs/veri/sbs_R50-ibn"
CION_ReIDZoo/CION_Finetune/FastReID/configs/VehicleID/bagtricks_R50-ibn.yml
ADDED
@@ -0,0 +1,35 @@
+_BASE_: "../Base-bagtricks.yml"
+
+INPUT:
+  SIZE_TRAIN: [256, 256]
+  SIZE_TEST: [256, 256]
+
+MODEL:
+  BACKBONE:
+    WITH_IBN: True
+  HEADS:
+    POOL_LAYER: gempool
+  LOSSES:
+    TRI:
+      HARD_MINING: False
+      MARGIN: 0.0
+
+DATASETS:
+  NAMES: ("VehicleID",)
+  TESTS: ("SmallVehicleID", "MediumVehicleID", "LargeVehicleID",)
+
+SOLVER:
+  BIAS_LR_FACTOR: 1.
+
+  IMS_PER_BATCH: 512
+  MAX_ITER: 60
+  STEPS: [30, 50]
+  WARMUP_ITERS: 10
+
+  CHECKPOINT_PERIOD: 20
+
+TEST:
+  EVAL_PERIOD: 20
+  IMS_PER_BATCH: 128
+
+OUTPUT_DIR: "logs/vehicleid/bagtricks_R50-ibn_4gpu"
CION_ReIDZoo/CION_Finetune/FastReID/tools/deploy/Caffe/ReadMe.md
ADDED
@@ -0,0 +1,21 @@
+# The Caffe module in nn_tools provides some convenient APIs
+If there is a problem parsing your prototxt or caffemodel, please replace
+the caffe.proto with your own version and compile it with the command
+`protoc --python_out ./ caffe.proto`
+
+## caffe_net.py
+Use `from nn_tools.Caffe import caffe_net` to import this module
+### Prototxt
++ `net=caffe_net.Prototxt(file_name)` to open a prototxt file
++ `net.init_caffemodel(caffe_cmd_path='caffe')` to generate a caffemodel file in the current working directory; \
+if your `caffe` command is not in `$PATH`, specify its path with the `caffe_cmd_path` kwarg.
+### Caffemodel
++ `net=caffe_net.Caffemodel(file_name)` to open a caffemodel
++ `net.save_prototxt(path)` to save the caffemodel to a prototxt file (not containing the weight data)
++ `net.get_layer_data(layer_name)` returns the numpy ndarray data of the layer
++ `net.set_layer_data(layer_name, datas)` sets the data of one layer in the caffemodel; `datas` is normally a list of numpy ndarrays `[weights, bias]`
++ `net.save(path)` saves the changed caffemodel
+### Functions for both Prototxt and Caffemodel
++ `net.add_layer(layer_params, before='', after='')` adds a new layer from a `Layer_param` object
++ `net.remove_layer_by_name(layer_name)`
++ `net.get_layer_by_name(layer_name)` or `net.layer(layer_name)` gets the raw Layer object defined in caffe_pb2
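As a rough illustration of the API listed in this ReadMe, here is a minimal usage sketch; the file paths and the layer name `"conv1"` are placeholders, and the unpacking into `[weights, bias]` assumes that layer stores exactly two blobs.

```python
# Rough usage sketch of the nn_tools Caffe API described above (not part of the repo).
from nn_tools.Caffe import caffe_net

net = caffe_net.Caffemodel("model.caffemodel")      # open a trained caffemodel
weights, bias = net.get_layer_data("conv1")         # list of numpy ndarrays [weights, bias]
net.set_layer_data("conv1", [weights * 0.5, bias])  # overwrite the layer's blobs
net.save_prototxt("model_structure.prototxt")       # structure only, no weight data
net.save("model_edited.caffemodel")                 # write the modified caffemodel
```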
CION_ReIDZoo/CION_Finetune/FastReID/tools/deploy/Caffe/caffe_net.py
ADDED
@@ -0,0 +1,139 @@
+from __future__ import absolute_import
+from . import caffe_pb2 as pb
+import google.protobuf.text_format as text_format
+import numpy as np
+from .layer_param import Layer_param
+
+class _Net(object):
+    def __init__(self):
+        self.net = pb.NetParameter()
+
+    def layer_index(self, layer_name):
+        # Find a layer's index by name. If the layer is found, return its position in the net, else return -1.
+        for i, layer in enumerate(self.net.layer):
+            if layer.name == layer_name:
+                return i
+
+    def add_layer(self, layer_params, before='', after=''):
+        # find the position of the `before` or `after` layer
+        index = -1
+        if after != '':
+            index = self.layer_index(after) + 1
+        if before != '':
+            index = self.layer_index(before)
+        new_layer = pb.LayerParameter()
+        new_layer.CopyFrom(layer_params.param)
+        # insert the layer into the layer proto list
+        if index != -1:
+            self.net.layer.add()
+            for i in range(len(self.net.layer) - 1, index, -1):
+                self.net.layer[i].CopyFrom(self.net.layer[i - 1])
+            self.net.layer[index].CopyFrom(new_layer)
+        else:
+            self.net.layer.extend([new_layer])
+
+    def remove_layer_by_name(self, layer_name):
+        for i, layer in enumerate(self.net.layer):
+            if layer.name == layer_name:
+                del self.net.layer[i]
+                return
+        raise AttributeError("cannot find layer %s" % str(layer_name))
+
+    def get_layer_by_name(self, layer_name):
+        # get the layer by layer_name
+        for layer in self.net.layer:
+            if layer.name == layer_name:
+                return layer
+        raise AttributeError("cannot find layer %s" % str(layer_name))
+
+    def save_prototxt(self, path):
+        prototxt = pb.NetParameter()
+        prototxt.CopyFrom(self.net)
+        for layer in prototxt.layer:
+            del layer.blobs[:]
+        with open(path, 'w') as f:
+            f.write(text_format.MessageToString(prototxt))
+
+    def layer(self, layer_name):
+        return self.get_layer_by_name(layer_name)
+
+    def layers(self):
+        return list(self.net.layer)
+
+
+class Prototxt(_Net):
+    def __init__(self, file_name=''):
+        super(Prototxt, self).__init__()
+        self.file_name = file_name
+        if file_name != '':
+            f = open(file_name, 'r')
+            text_format.Parse(f.read(), self.net)
+            pass
+
+    def init_caffemodel(self, caffe_cmd_path='caffe'):
+        """
+        :param caffe_cmd_path: the shell command of caffe, normally at <path-to-caffe>/build/tools/caffe
+        """
+        s = pb.SolverParameter()
+        s.train_net = self.file_name
+        s.max_iter = 0
+        s.base_lr = 1
+        s.solver_mode = pb.SolverParameter.CPU
+        s.snapshot_prefix = './nn'
+        with open('/tmp/nn_tools_solver.prototxt', 'w') as f:
+            f.write(str(s))
+        import os
+        os.system('%s train --solver /tmp/nn_tools_solver.prototxt' % caffe_cmd_path)
+
+
+class Caffemodel(_Net):
+    def __init__(self, file_name=''):
+        super(Caffemodel, self).__init__()
+        # caffe_model dir
+        if file_name != '':
+            f = open(file_name, 'rb')
+            self.net.ParseFromString(f.read())
+            f.close()
+
+    def save(self, path):
+        with open(path, 'wb') as f:
+            f.write(self.net.SerializeToString())
+
+    def add_layer_with_data(self, layer_params, datas, before='', after=''):
+        """
+        Args:
+            layer_params: a Layer_param object
+            datas: a fixed-dimension numpy object list
+            after: put the layer after a specified layer
+            before: put the layer before a specified layer
+        """
+        self.add_layer(layer_params, before, after)
+        new_layer = self.layer(layer_params.name)
+
+        # process blobs
+        del new_layer.blobs[:]
+        for data in datas:
+            new_blob = new_layer.blobs.add()
+            for dim in data.shape:
+                new_blob.shape.dim.append(dim)
+            new_blob.data.extend(data.flatten().astype(float))
+
+    def get_layer_data(self, layer_name):
+        layer = self.layer(layer_name)
+        datas = []
+        for blob in layer.blobs:
+            shape = list(blob.shape.dim)
+            data = np.array(blob.data).reshape(shape)
+            datas.append(data)
+        return datas
+
+    def set_layer_data(self, layer_name, datas):
+        # datas is normally a list of [weights, bias]
+        layer = self.layer(layer_name)
+        for blob, data in zip(layer.blobs, datas):
+            blob.data[:] = data.flatten()
+        pass
+
+
+class Net():
+    def __init__(self, *args, **kwargs):
+        raise TypeError('the class Net is no longer used, please use Caffemodel or Prototxt instead')
CION_ReIDZoo/CION_Finetune/FastReID/tools/deploy/Caffe/layer_param.py
ADDED
@@ -0,0 +1,158 @@
+from __future__ import absolute_import
+from . import caffe_pb2 as pb
+import numpy as np
+
+def pair_process(item, strict_one=True):
+    if hasattr(item, '__iter__'):
+        for i in item:
+            if i != item[0]:
+                if strict_one:
+                    raise ValueError("number in item {} must be the same".format(item))
+                else:
+                    print("IMPORTANT WARNING: number in item {} must be the same".format(item))
+        return item[0]
+    return item
+
+def pair_reduce(item):
+    if hasattr(item, '__iter__'):
+        for i in item:
+            if i != item[0]:
+                return item
+        return [item[0]]
+    return [item]
+
+class Layer_param():
+    def __init__(self, name='', type='', top=(), bottom=()):
+        self.param = pb.LayerParameter()
+        self.name = self.param.name = name
+        self.type = self.param.type = type
+
+        self.top = self.param.top
+        self.top.extend(top)
+        self.bottom = self.param.bottom
+        self.bottom.extend(bottom)
+
+    def fc_param(self, num_output, weight_filler='xavier', bias_filler='constant', has_bias=True):
+        if self.type != 'InnerProduct':
+            raise TypeError('the layer type must be InnerProduct if you want to set fc param')
+        fc_param = pb.InnerProductParameter()
+        fc_param.num_output = num_output
+        fc_param.weight_filler.type = weight_filler
+        fc_param.bias_term = has_bias
+        if has_bias:
+            fc_param.bias_filler.type = bias_filler
+        self.param.inner_product_param.CopyFrom(fc_param)
+
+    def conv_param(self, num_output, kernel_size, stride=(1), pad=(0,),
+                   weight_filler_type='xavier', bias_filler_type='constant',
+                   bias_term=True, dilation=None, groups=None):
+        """
+        add a conv_param layer if you spec the layer type "Convolution"
+        Args:
+            num_output: an int
+            kernel_size: int list
+            stride: an int list
+            weight_filler_type: the weight filler type
+            bias_filler_type: the bias filler type
+        Returns:
+        """
+        if self.type not in ['Convolution', 'Deconvolution']:
+            raise TypeError('the layer type must be Convolution or Deconvolution if you want to set conv param')
+        conv_param = pb.ConvolutionParameter()
+        conv_param.num_output = num_output
+        conv_param.kernel_size.extend(pair_reduce(kernel_size))
+        conv_param.stride.extend(pair_reduce(stride))
+        conv_param.pad.extend(pair_reduce(pad))
+        conv_param.bias_term = bias_term
+        conv_param.weight_filler.type = weight_filler_type
+        if bias_term:
+            conv_param.bias_filler.type = bias_filler_type
+        if dilation:
+            conv_param.dilation.extend(pair_reduce(dilation))
+        if groups:
+            conv_param.group = groups
+        self.param.convolution_param.CopyFrom(conv_param)
+
+    def pool_param(self, type='MAX', kernel_size=2, stride=2, pad=None, ceil_mode=False):
+        pool_param = pb.PoolingParameter()
+        pool_param.pool = pool_param.PoolMethod.Value(type)
+        pool_param.kernel_size = pair_process(kernel_size)
+        pool_param.stride = pair_process(stride)
+        pool_param.ceil_mode = ceil_mode
+        if pad:
+            if isinstance(pad, tuple):
+                pool_param.pad_h = pad[0]
+                pool_param.pad_w = pad[1]
+            else:
+                pool_param.pad = pad
+        self.param.pooling_param.CopyFrom(pool_param)
+
+    def batch_norm_param(self, use_global_stats=0, moving_average_fraction=None, eps=None):
+        bn_param = pb.BatchNormParameter()
+        bn_param.use_global_stats = use_global_stats
+        if moving_average_fraction:
+            bn_param.moving_average_fraction = moving_average_fraction
+        if eps:
+            bn_param.eps = eps
+        self.param.batch_norm_param.CopyFrom(bn_param)
+
+    # layer
+    # {
+    #     name: "upsample_layer"
+    #     type: "Upsample"
+    #     bottom: "some_input_feature_map"
+    #     bottom: "some_input_pool_index"
+    #     top: "some_output"
+    #     upsample_param {
+    #         upsample_h: 224
+    #         upsample_w: 224
+    #     }
+    # }
+    def upsample_param(self, size=None, scale_factor=None):
+        upsample_param = pb.UpsampleParameter()
+        if scale_factor:
+            if isinstance(scale_factor, int):
+                upsample_param.scale = scale_factor
+            else:
+                upsample_param.scale_h = scale_factor[0]
+                upsample_param.scale_w = scale_factor[1]
+
+        if size:
+            if isinstance(size, int):
+                upsample_param.upsample_h = size
+            else:
+                upsample_param.upsample_h = size[0]
+                upsample_param.upsample_w = size[1]
+                # upsample_param.upsample_h = size[0] * scale_factor
+                # upsample_param.upsample_w = size[1] * scale_factor
+        self.param.upsample_param.CopyFrom(upsample_param)
+
+    def interp_param(self, size=None, scale_factor=None):
+        interp_param = pb.InterpParameter()
+        if scale_factor:
+            if isinstance(scale_factor, int):
+                interp_param.zoom_factor = scale_factor
+
+        if size:
+            print('size:', size)
+            interp_param.height = size[0]
+            interp_param.width = size[1]
+        self.param.interp_param.CopyFrom(interp_param)
+
+    def add_data(self, *args):
+        """Args are data numpy arrays
+        """
+        del self.param.blobs[:]
+        for data in args:
+            new_blob = self.param.blobs.add()
+            for dim in data.shape:
+                new_blob.shape.dim.append(dim)
+            new_blob.data.extend(data.flatten().astype(float))
+
+    def set_params_by_dict(self, dic):
+        pass
+
+    def copy_from(self, layer_param):
+        pass
+
+def set_enum(param, key, value):
+    setattr(param, key, param.Value(value))
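A small sketch of how the `Layer_param` helpers defined above are typically combined (the same pattern `pytorch_to_caffe.py` below relies on). The layer name, blob names, shapes, and the import path are illustrative assumptions based on the `Caffe` package layout shown in this diff.

```python
# Rough sketch of building a Caffe convolution layer with the Layer_param API above.
import numpy as np
from Caffe.layer_param import Layer_param  # assumed package layout

layer = Layer_param(name="conv1", type="Convolution", bottom=["data"], top=["conv1"])
layer.conv_param(num_output=64, kernel_size=[3], stride=[1], pad=[1], bias_term=True)

weights = np.zeros((64, 3, 3, 3), dtype=np.float32)  # (out_channels, in_channels, kH, kW)
bias = np.zeros((64,), dtype=np.float32)
layer.add_data(weights, bias)  # stored as the layer's blobs in layer.param
```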
CION_ReIDZoo/CION_Finetune/FastReID/tools/deploy/README.md
ADDED
@@ -0,0 +1,160 @@
+# Model Deployment
+
+This directory contains:
+
+1. The scripts that convert a fastreid model to Caffe/ONNX/TRT format.
+
+2. The examples that load an R50 baseline model in Caffe/ONNX/TRT and run inference.
+
+## Tutorial
+
+### Caffe Convert
+
+<details>
+<summary>step-by-step pipeline for caffe convert</summary>
+
+This is a tiny example for converting the fastreid baseline in `meta_arch` to a Caffe model; if you want to convert a more complex architecture, you need to customize more things.
+
+1. Run `caffe_export.py` to get the converted Caffe model,
+
+```bash
+python caffe_export.py --config-file root-path/market1501/bagtricks_R50/config.yml --name "baseline_R50" --output outputs/caffe_model --opts MODEL.WEIGHTS root-path/logs/market1501/bagtricks_R50/model_final.pth
+```
+
+then you can check the Caffe model and prototxt in `outputs/caffe_model`.
+
+2. Change `prototxt` following the next three steps:
+
+1) Edit `max_pooling` in `baseline_R50.prototxt` like this
+
+```prototxt
+layer {
+  name: "max_pool1"
+  type: "Pooling"
+  bottom: "relu_blob1"
+  top: "max_pool_blob1"
+  pooling_param {
+    pool: MAX
+    kernel_size: 3
+    stride: 2
+    pad: 0 # 1
+    # ceil_mode: false
+  }
+}
+```
+
+2) Add `avg_pooling` at the right place in `baseline_R50.prototxt`
+
+```prototxt
+layer {
+  name: "avgpool1"
+  type: "Pooling"
+  bottom: "relu_blob49"
+  top: "avgpool_blob1"
+  pooling_param {
+    pool: AVE
+    global_pooling: true
+  }
+}
+```
+
+3) Change the last layer's `top` name to `output`
+
+```prototxt
+layer {
+  name: "bn_scale54"
+  type: "Scale"
+  bottom: "batch_norm_blob54"
+  top: "output" # bn_norm_blob54
+  scale_param {
+    bias_term: true
+  }
+}
+```
+
+3. (optional) You can open [Netscope](https://ethereon.github.io/netscope/quickstart.html), then enter your network `prototxt` to visualize it.
+
+4. Run `caffe_inference.py` to save the Caffe model features of the input images
+
+```bash
+python caffe_inference.py --model-def outputs/caffe_model/baseline_R50.prototxt \
+--model-weights outputs/caffe_model/baseline_R50.caffemodel \
+--input test_data/*.jpg --output caffe_output
+```
+
+5. Run `demo/demo.py` to get the fastreid model features of the same input images, then verify that Caffe and PyTorch compute the same values for the network.
+
+```python
+np.testing.assert_allclose(torch_out, ort_out, rtol=1e-3, atol=1e-6)
+```
+
+</details>
+
+### ONNX Convert
+
+<details>
+<summary>step-by-step pipeline for onnx convert</summary>
+
+This is a tiny example for converting the fastreid baseline in `meta_arch` to an ONNX model. ONNX supports most pytorch operators; if some operators are not supported by ONNX, you need to customize them.
+
+1. Run `onnx_export.py` to get the converted ONNX model,
+
+```bash
+python onnx_export.py --config-file root-path/bagtricks_R50/config.yml --name "baseline_R50" --output outputs/onnx_model --opts MODEL.WEIGHTS root-path/logs/market1501/bagtricks_R50/model_final.pth
+```
+
+then you can check the ONNX model in `outputs/onnx_model`.
+
+2. (optional) You can use [Netron](https://github.com/lutzroeder/netron) to visualize the network.
+
+3. Run `onnx_inference.py` to save the ONNX model features of the input images
+
+```bash
+python onnx_inference.py --model-path outputs/onnx_model/baseline_R50.onnx \
+--input test_data/*.jpg --output onnx_output
+```
+
+4. Run `demo/demo.py` to get the fastreid model features of the same input images, then verify that ONNX Runtime and PyTorch compute the same values for the network.
+
+```python
+np.testing.assert_allclose(torch_out, ort_out, rtol=1e-3, atol=1e-6)
+```
+
+</details>
+
+### TensorRT Convert
+
+<details>
+<summary>step-by-step pipeline for trt convert</summary>
+
+This is a tiny example for converting the fastreid baseline in `meta_arch` to a TRT model. We use [tiny-tensorrt](https://github.com/zerollzeng/tiny-tensorrt), a simple and easy-to-use nvidia TensorRT wrapper, to get the model converted to TensorRT.
+
+First you need to convert the pytorch model to ONNX format following [ONNX Convert](https://github.com/JDAI-CV/fast-reid/tree/master/tools/deploy#onnx-convert), and you need to remember your `output` name. Then you can convert the ONNX model to TensorRT following the instructions below.
+
+1. Run the command line below to get the converted TRT model from the ONNX model,
+
+```bash
+python trt_export.py --name "baseline_R50" --output outputs/trt_model --onnx-model outputs/onnx_model/baseline.onnx --height 256 --width 128
+```
+
+then you can check the TRT model in `outputs/trt_model`.
+
+2. Run `trt_inference.py` to save the TRT model features of the input images
+
+```bash
+python trt_inference.py --model-path outputs/trt_model/baseline.engine \
+--input test_data/*.jpg --output trt_output --output-name trt_model_outputname
+```
+
+3. Run `demo/demo.py` to get the fastreid model features of the same input images, then verify that TensorRT and PyTorch compute the same values for the network.
+
+```python
+np.testing.assert_allclose(torch_out, ort_out, rtol=1e-3, atol=1e-6)
+```
+
+</details>
+
+## Acknowledgements
+
+Thanks to [CPFLAME](https://github.com/CPFLAME), [gcong18](https://github.com/gcong18), [YuxiangJohn](https://github.com/YuxiangJohn) and [wiggin66](https://github.com/wiggin66) at the JDAI Model Acceleration Group for help with PyTorch model conversion.
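The verification step in each pipeline above compares exported-model features against the PyTorch features. Here is a minimal sketch of that comparison, assuming both `demo/demo.py` and the inference scripts dump per-image `.npy` feature files with matching filenames; the directory names are placeholders.

```python
# Hypothetical helper for the verification step above: compare features saved as
# .npy files by demo/demo.py (PyTorch) and by onnx_inference.py / caffe_inference.py.
import glob
import os

import numpy as np


def compare_features(torch_dir="pytorch_output", export_dir="onnx_output"):
    for torch_path in glob.glob(os.path.join(torch_dir, "*.npy")):
        export_path = os.path.join(export_dir, os.path.basename(torch_path))
        torch_out = np.load(torch_path)
        ort_out = np.load(export_path)
        # Same tolerance as the README's verification snippet.
        np.testing.assert_allclose(torch_out, ort_out, rtol=1e-3, atol=1e-6)
        print(f"{os.path.basename(torch_path)}: OK")


if __name__ == "__main__":
    compare_features()
```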
CION_ReIDZoo/CION_Finetune/FastReID/tools/deploy/caffe_export.py
ADDED
@@ -0,0 +1,78 @@
+# encoding: utf-8
+"""
+@author: xingyu liao
+@contact: [email protected]
+"""
+
+import argparse
+
+import torch
+import sys
+sys.path.append('../../')
+
+import pytorch_to_caffe
+from fastreid.config import get_cfg
+from fastreid.modeling.meta_arch import build_model
+from fastreid.utils.file_io import PathManager
+from fastreid.utils.checkpoint import Checkpointer
+from fastreid.utils.logger import setup_logger
+
+logger = setup_logger(name='caffe_export')
+
+
+def setup_cfg(args):
+    cfg = get_cfg()
+    cfg.merge_from_file(args.config_file)
+    cfg.merge_from_list(args.opts)
+    cfg.freeze()
+    return cfg
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(description="Convert Pytorch to Caffe model")
+
+    parser.add_argument(
+        "--config-file",
+        metavar="FILE",
+        help="path to config file",
+    )
+    parser.add_argument(
+        "--name",
+        default="baseline",
+        help="name for converted model"
+    )
+    parser.add_argument(
+        "--output",
+        default='caffe_model',
+        help='path to save converted caffe model'
+    )
+    parser.add_argument(
+        "--opts",
+        help="Modify config options using the command-line 'KEY VALUE' pairs",
+        default=[],
+        nargs=argparse.REMAINDER,
+    )
+    return parser
+
+
+if __name__ == '__main__':
+    args = get_parser().parse_args()
+    cfg = setup_cfg(args)
+
+    cfg.defrost()
+    cfg.MODEL.BACKBONE.PRETRAIN = False
+    cfg.MODEL.HEADS.POOL_LAYER = "identity"
+    cfg.MODEL.BACKBONE.WITH_NL = False
+
+    model = build_model(cfg)
+    Checkpointer(model).load(cfg.MODEL.WEIGHTS)
+    model.eval()
+    logger.info(model)
+
+    inputs = torch.randn(1, 3, cfg.INPUT.SIZE_TEST[0], cfg.INPUT.SIZE_TEST[1]).to(torch.device(cfg.MODEL.DEVICE))
+    PathManager.mkdirs(args.output)
+    pytorch_to_caffe.trans_net(model, inputs, args.name)
+    pytorch_to_caffe.save_prototxt(f"{args.output}/{args.name}.prototxt")
+    pytorch_to_caffe.save_caffemodel(f"{args.output}/{args.name}.caffemodel")
+
+    logger.info(f"Export caffe model in {args.output} successfully!")
CION_ReIDZoo/CION_Finetune/FastReID/tools/deploy/caffe_inference.py
ADDED
@@ -0,0 +1,95 @@
+# encoding: utf-8
+"""
+@author: xingyu liao
+@contact: [email protected]
+"""
+
+import caffe
+import tqdm
+import glob
+import os
+import cv2
+import numpy as np
+
+caffe.set_mode_gpu()
+
+import argparse
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(description="Caffe model inference")
+
+    parser.add_argument(
+        "--model-def",
+        default="logs/test_caffe/baseline_R50.prototxt",
+        help="caffe model prototxt"
+    )
+    parser.add_argument(
+        "--model-weights",
+        default="logs/test_caffe/baseline_R50.caffemodel",
+        help="caffe model weights"
+    )
+    parser.add_argument(
+        "--input",
+        nargs="+",
+        help="A list of space separated input images; "
+             "or a single glob pattern such as 'directory/*.jpg'",
+    )
+    parser.add_argument(
+        "--output",
+        default='caffe_output',
+        help='path to save caffe model features'
+    )
+    parser.add_argument(
+        "--height",
+        type=int,
+        default=256,
+        help="height of image"
+    )
+    parser.add_argument(
+        "--width",
+        type=int,
+        default=128,
+        help="width of image"
+    )
+    return parser
+
+
+def preprocess(image_path, image_height, image_width):
+    original_image = cv2.imread(image_path)
+    # the model expects RGB inputs
+    original_image = original_image[:, :, ::-1]
+
+    # Apply pre-processing to image.
+    image = cv2.resize(original_image, (image_width, image_height), interpolation=cv2.INTER_CUBIC)
+    image = image.astype("float32").transpose(2, 0, 1)[np.newaxis]  # (1, 3, h, w)
+    image = (image - np.array([0.485 * 255, 0.456 * 255, 0.406 * 255]).reshape((1, -1, 1, 1))) / np.array(
+        [0.229 * 255, 0.224 * 255, 0.225 * 255]).reshape((1, -1, 1, 1))
+    return image
+
+
+def normalize(nparray, order=2, axis=-1):
+    """Normalize a N-D numpy array along the specified axis."""
+    norm = np.linalg.norm(nparray, ord=order, axis=axis, keepdims=True)
+    return nparray / (norm + np.finfo(np.float32).eps)
+
+
+if __name__ == "__main__":
+    args = get_parser().parse_args()
+
+    net = caffe.Net(args.model_def, args.model_weights, caffe.TEST)
+    net.blobs['blob1'].reshape(1, 3, args.height, args.width)
+
+    if not os.path.exists(args.output): os.makedirs(args.output)
+
+    if args.input:
+        if os.path.isdir(args.input[0]):
+            args.input = glob.glob(os.path.expanduser(args.input[0]))
+            assert args.input, "The input path(s) was not found"
+        for path in tqdm.tqdm(args.input):
+            image = preprocess(path, args.height, args.width)
+            net.blobs['blob1'].data[...] = image
+            feat = net.forward()['output']
+            feat = normalize(feat[..., 0, 0], axis=1)
+            np.save(os.path.join(args.output, path.replace('.jpg', '.npy').split('/')[-1]), feat)
CION_ReIDZoo/CION_Finetune/FastReID/tools/deploy/onnx_export.py
ADDED
@@ -0,0 +1,146 @@
+# encoding: utf-8
+"""
+@author: xingyu liao
+@contact: [email protected]
+"""
+
+import argparse
+import io
+import sys
+
+import onnx
+import torch
+from onnxsim import simplify
+from torch.onnx import OperatorExportTypes
+
+sys.path.append('../../')
+
+from fastreid.config import get_cfg
+from fastreid.modeling.meta_arch import build_model
+from fastreid.utils.file_io import PathManager
+from fastreid.utils.checkpoint import Checkpointer
+from fastreid.utils.logger import setup_logger
+
+logger = setup_logger(name='onnx_export')
+
+
+def setup_cfg(args):
+    cfg = get_cfg()
+    cfg.merge_from_file(args.config_file)
+    cfg.merge_from_list(args.opts)
+    cfg.freeze()
+    return cfg
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(description="Convert Pytorch to ONNX model")
+
+    parser.add_argument(
+        "--config-file",
+        metavar="FILE",
+        help="path to config file",
+    )
+    parser.add_argument(
+        "--name",
+        default="baseline",
+        help="name for converted model"
+    )
+    parser.add_argument(
+        "--output",
+        default='onnx_model',
+        help='path to save converted onnx model'
+    )
+    parser.add_argument(
+        "--opts",
+        help="Modify config options using the command-line 'KEY VALUE' pairs",
+        default=[],
+        nargs=argparse.REMAINDER,
+    )
+    return parser
+
+
+def remove_initializer_from_input(model):
+    if model.ir_version < 4:
+        print(
+            'Model with ir_version below 4 requires to include initializer in graph input'
+        )
+        return
+
+    inputs = model.graph.input
+    name_to_input = {}
+    for input in inputs:
+        name_to_input[input.name] = input
+
+    for initializer in model.graph.initializer:
+        if initializer.name in name_to_input:
+            inputs.remove(name_to_input[initializer.name])
+
+    return model
+
+
+def export_onnx_model(model, inputs):
+    """
+    Trace and export a model to onnx format.
+    Args:
+        model (nn.Module):
+        inputs (torch.Tensor): the model will be called by `model(*inputs)`
+    Returns:
+        an onnx model
+    """
+    assert isinstance(model, torch.nn.Module)
+
+    # make sure all modules are in eval mode, onnx may change the training state
+    # of the module if the states are not consistent
+    def _check_eval(module):
+        assert not module.training
+
+    model.apply(_check_eval)
+
+    # Export the model to ONNX
+    with torch.no_grad():
+        with io.BytesIO() as f:
+            torch.onnx.export(
+                model,
+                inputs,
+                f,
+                operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
+                # verbose=True,  # NOTE: uncomment this for debugging
+                # export_params=True,
+            )
+            onnx_model = onnx.load_from_string(f.getvalue())
+
+    # Apply ONNX's Optimization
+    all_passes = onnx.optimizer.get_available_passes()
+    passes = ["extract_constant_to_initializer", "eliminate_unused_initializer", "fuse_bn_into_conv"]
+    assert all(p in all_passes for p in passes)
+    onnx_model = onnx.optimizer.optimize(onnx_model, passes)
+    return onnx_model
+
+
+if __name__ == '__main__':
+    args = get_parser().parse_args()
+    cfg = setup_cfg(args)
+
+    cfg.defrost()
+    cfg.MODEL.BACKBONE.PRETRAIN = False
+    if cfg.MODEL.HEADS.POOL_LAYER == 'fastavgpool':
+        cfg.MODEL.HEADS.POOL_LAYER = 'avgpool'
+    model = build_model(cfg)
+    Checkpointer(model).load(cfg.MODEL.WEIGHTS)
+    model.eval()
+    logger.info(model)
+
+    inputs = torch.randn(1, 3, cfg.INPUT.SIZE_TEST[0], cfg.INPUT.SIZE_TEST[1])
+    onnx_model = export_onnx_model(model, inputs)
+
+    model_simp, check = simplify(onnx_model)
+
+    model_simp = remove_initializer_from_input(model_simp)
+
+    assert check, "Simplified ONNX model could not be validated"
+
+    PathManager.mkdirs(args.output)
+
+    onnx.save_model(model_simp, f"{args.output}/{args.name}.onnx")
+
+    logger.info(f"Export onnx model in {args.output} successfully!")
CION_ReIDZoo/CION_Finetune/FastReID/tools/deploy/onnx_inference.py
ADDED
@@ -0,0 +1,85 @@
+# encoding: utf-8
+"""
+@author: xingyu liao
+@contact: [email protected]
+"""
+
+import argparse
+import glob
+import os
+
+import cv2
+import numpy as np
+import onnxruntime
+import tqdm
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(description="onnx model inference")
+
+    parser.add_argument(
+        "--model-path",
+        default="onnx_model/baseline.onnx",
+        help="onnx model path"
+    )
+    parser.add_argument(
+        "--input",
+        nargs="+",
+        help="A list of space separated input images; "
+             "or a single glob pattern such as 'directory/*.jpg'",
+    )
+    parser.add_argument(
+        "--output",
+        default='onnx_output',
+        help='path to save onnx model features'
+    )
+    parser.add_argument(
+        "--height",
+        type=int,
+        default=256,
+        help="height of image"
+    )
+    parser.add_argument(
+        "--width",
+        type=int,
+        default=128,
+        help="width of image"
+    )
+    return parser
+
+
+def preprocess(image_path, image_height, image_width):
+    original_image = cv2.imread(image_path)
+    # the model expects RGB inputs
+    original_image = original_image[:, :, ::-1]
+
+    # Apply pre-processing to image.
+    img = cv2.resize(original_image, (image_width, image_height), interpolation=cv2.INTER_CUBIC)
+    img = img.astype("float32").transpose(2, 0, 1)[np.newaxis]  # (1, 3, h, w)
+    return img
+
+
+def normalize(nparray, order=2, axis=-1):
+    """Normalize a N-D numpy array along the specified axis."""
+    norm = np.linalg.norm(nparray, ord=order, axis=axis, keepdims=True)
+    return nparray / (norm + np.finfo(np.float32).eps)
+
+
+if __name__ == "__main__":
+    args = get_parser().parse_args()
+
+    ort_sess = onnxruntime.InferenceSession(args.model_path)
+
+    input_name = ort_sess.get_inputs()[0].name
+
+    if not os.path.exists(args.output): os.makedirs(args.output)
+
+    if args.input:
+        if os.path.isdir(args.input[0]):
+            args.input = glob.glob(os.path.expanduser(args.input[0]))
+            assert args.input, "The input path(s) was not found"
+        for path in tqdm.tqdm(args.input):
+            image = preprocess(path, args.height, args.width)
+            feat = ort_sess.run(None, {input_name: image})[0]
+            feat = normalize(feat, axis=1)
+            np.save(os.path.join(args.output, path.replace('.jpg', '.npy').split('/')[-1]), feat)
CION_ReIDZoo/CION_Finetune/FastReID/tools/deploy/pytorch_to_caffe.py
ADDED
@@ -0,0 +1,747 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import torch
+import torch.nn as nn
+import traceback
+from Caffe import caffe_net
+import torch.nn.functional as F
+from torch.autograd import Variable
+from Caffe import layer_param
+from torch.nn.modules.utils import _pair
+import numpy as np
+import math
+from torch.nn.modules.utils import _list_with_default
+
+"""
+How to support a new layer type:
+    layer_name = log.add_layer(layer_type_name)
+    top_blobs = log.add_blobs(<output of that layer>)
+    layer = caffe_net.Layer_param(xxx)
+    <set layer parameters>
+    [<layer.add_data(*datas)>]
+    log.cnet.add_layer(layer)
+
+Please MUTE the inplace operations, otherwise their outputs cannot be found in the graph.
+"""
+
+
+# TODO: support the inplace output of the layers
+
+class Blob_LOG():
+    def __init__(self):
+        self.data = {}
+
+    def __setitem__(self, key, value):
+        self.data[key] = value
+
+    def __getitem__(self, key):
+        return self.data[key]
+
+    def __len__(self):
+        return len(self.data)
+
+
+NET_INITTED = False
+
+
+# How the conversion works: by recording each traced op and its blobs as the model runs.
+class TransLog(object):
+    def __init__(self):
+        """
+        doing init() with inputs Variable before using it
+        """
+        self.layers = {}
+        self.detail_layers = {}
+        self.detail_blobs = {}
+        self._blobs = Blob_LOG()
+        self._blobs_data = []
+        self.cnet = caffe_net.Caffemodel('')
+        self.debug = True
+
+    def init(self, inputs):
+        """
+        :param inputs: is a list of input variables
+        """
+        self.add_blobs(inputs)
+
+    def add_layer(self, name='layer'):
+        if name in self.layers:
+            return self.layers[name]
+        if name not in self.detail_layers.keys():
+            self.detail_layers[name] = 0
+        self.detail_layers[name] += 1
+        name = '{}{}'.format(name, self.detail_layers[name])
+        self.layers[name] = name
+        if self.debug:
+            print("{} was added to layers".format(self.layers[name]))
+        return self.layers[name]
+
+    def add_blobs(self, blobs, name='blob', with_num=True):
+        rst = []
+        for blob in blobs:
+            self._blobs_data.append(blob)  # to keep the memory address from being rewritten
+            blob_id = int(id(blob))
+            if name not in self.detail_blobs.keys():
+                self.detail_blobs[name] = 0
+            self.detail_blobs[name] += 1
+            if with_num:
+                rst.append('{}{}'.format(name, self.detail_blobs[name]))
+            else:
+                rst.append('{}'.format(name))
+            if self.debug:
+                print("{}:{} was added to blobs".format(blob_id, rst[-1]))
+                print('Add blob {} : {}'.format(rst[-1].center(21), blob.size()))
+            self._blobs[blob_id] = rst[-1]
+        return rst
+
+    def blobs(self, var):
+        var = id(var)
+        if self.debug:
+            print("{}:{} getting".format(var, self._blobs[var]))
+        try:
+            return self._blobs[var]
+        except:
+            print("WARNING: CANNOT FIND blob {}".format(var))
+            return None
+
+
+log = TransLog()
+
+layer_names = {}
+
+
+def _conv2d(raw, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
+    x = raw(input, weight, bias, stride, padding, dilation, groups)
+    name = log.add_layer(name='conv')
+    log.add_blobs([x], name='conv_blob')
+    layer = caffe_net.Layer_param(name=name, type='Convolution',
+                                  bottom=[log.blobs(input)], top=[log.blobs(x)])
+    layer.conv_param(x.size()[1], weight.size()[2:], stride=_pair(stride),
+                     pad=_pair(padding), dilation=_pair(dilation), bias_term=bias is not None, groups=groups)
+    if bias is not None:
+        layer.add_data(weight.cpu().data.numpy(), bias.cpu().data.numpy())
+        # print('conv2d weight, bias: ', weight.cpu().data.numpy(), bias.cpu().data.numpy())
+    else:
+        layer.param.convolution_param.bias_term = False
+        layer.add_data(weight.cpu().data.numpy())
+    log.cnet.add_layer(layer)
+    return x
+
+
+def _conv_transpose2d(raw, input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
+    x = raw(input, weight, bias, stride, padding, output_padding, groups, dilation)
+    name = log.add_layer(name='conv_transpose')
+    log.add_blobs([x], name='conv_transpose_blob')
+    layer = caffe_net.Layer_param(name=name, type='Deconvolution',
+                                  bottom=[log.blobs(input)], top=[log.blobs(x)])
+    layer.conv_param(x.size()[1], weight.size()[2:], stride=_pair(stride),
+                     pad=_pair(padding), dilation=_pair(dilation), bias_term=bias is not None)
+    if bias is not None:
+        layer.add_data(weight.cpu().data.numpy(), bias.cpu().data.numpy())
+    else:
+        layer.param.convolution_param.bias_term = False
+        layer.add_data(weight.cpu().data.numpy())
+    log.cnet.add_layer(layer)
+    return x
+
+
+def _linear(raw, input, weight, bias=None):
+    x = raw(input, weight, bias)
+    layer_name = log.add_layer(name='fc')
+    top_blobs = log.add_blobs([x], name='fc_blob')
+    layer = caffe_net.Layer_param(name=layer_name, type='InnerProduct',
+                                  bottom=[log.blobs(input)], top=top_blobs)
+    layer.fc_param(x.size()[1], has_bias=bias is not None)
+    if bias is not None:
+        layer.add_data(weight.cpu().data.numpy(), bias.cpu().data.numpy())
+    else:
+        layer.add_data(weight.cpu().data.numpy())
+    log.cnet.add_layer(layer)
+    return x
+
+
+def _split(raw, tensor, split_size, dim=0):
+    # split in pytorch is slice in caffe
+    x = raw(tensor, split_size, dim)
+    layer_name = log.add_layer('split')
+    top_blobs = log.add_blobs(x, name='split_blob')
+    layer = caffe_net.Layer_param(name=layer_name, type='Slice',
+                                  bottom=[log.blobs(tensor)], top=top_blobs)
+    slice_num = int(np.floor(tensor.size()[dim] / split_size))
+    slice_param = caffe_net.pb.SliceParameter(axis=dim, slice_point=[split_size * i for i in range(1, slice_num)])
+    layer.param.slice_param.CopyFrom(slice_param)
+    log.cnet.add_layer(layer)
+    return x
+
+
+def _pool(type, raw, input, x, kernel_size, stride, padding, ceil_mode):
+    # TODO dilation, ceil_mode, return indices
+    layer_name = log.add_layer(name='{}_pool'.format(type))
+    top_blobs = log.add_blobs([x], name='{}_pool_blob'.format(type))
+    layer = caffe_net.Layer_param(name=layer_name, type='Pooling', bottom=[log.blobs(input)], top=top_blobs)
+
+    # TODO w,h different kernel, stride and padding
+    # processing ceil mode
+    layer.pool_param(kernel_size=kernel_size, stride=kernel_size if stride is None else stride,
+                     pad=padding, type=type.upper())
+    log.cnet.add_layer(layer)
+    if ceil_mode == False and stride is not None:
+        oheight = (input.size()[2] - _pair(kernel_size)[0] + 2 * _pair(padding)[0]) % (_pair(stride)[0])
+        owidth = (input.size()[3] - _pair(kernel_size)[1] + 2 * _pair(padding)[1]) % (_pair(stride)[1])
+        if oheight != 0 or owidth != 0:
+            caffe_out = raw(input, kernel_size, stride, padding, ceil_mode=False)
+            print("WARNING: the output shape mismatch at {}: "
+                  "input {} output---Pytorch:{}---Caffe:{}\n"
+                  "This is caused by the different implementations: ceil mode in caffe and floor mode in pytorch.\n"
+                  "You can add a clip layer in the caffe prototxt manually if a shape mismatch error is raised in caffe.".format(
+                layer_name, input.size(), x.size(), caffe_out.size()))
+
+
+def _max_pool2d(raw, input, kernel_size, stride=None, padding=0, dilation=1,
+                ceil_mode=False, return_indices=False):
+    x = raw(input, kernel_size, stride, padding, dilation, ceil_mode, return_indices)
+    _pool('max', raw, input, x, kernel_size, stride, padding, ceil_mode)
+    return x
+
+
+def _avg_pool2d(raw, input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
+    x = raw(input, kernel_size, stride, padding, ceil_mode, count_include_pad)
+    _pool('ave', raw, input, x, kernel_size, stride, padding, ceil_mode)
+    return x
+
+
+def _max(raw, *args):
+    x = raw(*args)
+    if len(args) == 1:
+        # TODO max over one tensor
+        assert NotImplementedError
+    else:
+        bottom_blobs = []
+        for arg in args:
+            bottom_blobs.append(log.blobs(arg))
+        layer_name = log.add_layer(name='max')
+        top_blobs = log.add_blobs([x], name='max_blob')
+        layer = caffe_net.Layer_param(name=layer_name, type='Eltwise',
+                                      bottom=bottom_blobs, top=top_blobs)
+        layer.param.eltwise_param.operation = 2
+        log.cnet.add_layer(layer)
+    return x
+
+
+def _cat(raw, inputs, dimension=0):
+    x = raw(inputs, dimension)
+    bottom_blobs = []
+    for input in inputs:
+        bottom_blobs.append(log.blobs(input))
+    layer_name = log.add_layer(name='cat')
+    top_blobs = log.add_blobs([x], name='cat_blob')
+    layer = caffe_net.Layer_param(name=layer_name, type='Concat',
+                                  bottom=bottom_blobs, top=top_blobs)
+    layer.param.concat_param.axis = dimension
+    log.cnet.add_layer(layer)
+    return x
+
+
+def _dropout(raw, input, p=0.5, training=False, inplace=False):
+    x = raw(input, p, training, inplace)
+    bottom_blobs = [log.blobs(input)]
+    layer_name = log.add_layer(name='dropout')
+    top_blobs = log.add_blobs([x], name=bottom_blobs[0], with_num=False)
+    layer = caffe_net.Layer_param(name=layer_name, type='Dropout',
+                                  bottom=bottom_blobs, top=top_blobs)
+    layer.param.dropout_param.dropout_ratio = p
+    layer.param.include.extend([caffe_net.pb.NetStateRule(phase=0)])  # 1 for test, 0 for train
+    log.cnet.add_layer(layer)
+    return x
+
+
+def _threshold(raw, input, threshold, value, inplace=False):
+    # for threshold or relu
+    if threshold == 0 and value == 0:
+        x = raw(input, threshold, value, inplace)
+        bottom_blobs = [log.blobs(input)]
+        name = log.add_layer(name='relu')
+        log.add_blobs([x], name='relu_blob')
+        layer = caffe_net.Layer_param(name=name, type='ReLU',
+                                      bottom=bottom_blobs, top=[log.blobs(x)])
+        log.cnet.add_layer(layer)
+        return x
+    if value != 0:
+        raise NotImplementedError("value != 0 is not implemented in caffe")
+    x = raw(input, threshold, value, inplace)
+    bottom_blobs = [log.blobs(input)]
+    layer_name = log.add_layer(name='threshold')
+    top_blobs = log.add_blobs([x], name='threshold_blob')
+    layer = caffe_net.Layer_param(name=layer_name, type='Threshold',
+                                  bottom=bottom_blobs, top=top_blobs)
+    layer.param.threshold_param.threshold = threshold
+    log.cnet.add_layer(layer)
+    return x
+
+
+def _relu(raw, input, inplace=False):
+    # for threshold or relu
+    x = raw(input, False)
+    name = log.add_layer(name='relu')
+    log.add_blobs([x], name='relu_blob')
+    layer = caffe_net.Layer_param(name=name, type='ReLU',
+                                  bottom=[log.blobs(input)], top=[log.blobs(x)])
+    log.cnet.add_layer(layer)
+    return x
+
+
+def _prelu(raw, input, weight):
+    # for threshold or prelu
+    x = raw(input, weight)
+    bottom_blobs = [log.blobs(input)]
+    name = log.add_layer(name='prelu')
+    log.add_blobs([x], name='prelu_blob')
+    layer = caffe_net.Layer_param(name=name, type='PReLU',
+                                  bottom=bottom_blobs, top=[log.blobs(x)])
+    if weight.size()[0] == 1:
+        layer.param.prelu_param.channel_shared = True
+        layer.add_data(weight.cpu().data.numpy()[0])
+    else:
+        layer.add_data(weight.cpu().data.numpy())
+    log.cnet.add_layer(layer)
+    return x
+
+
+def _leaky_relu(raw, input, negative_slope=0.01, inplace=False):
+    x = raw(input, negative_slope)
+    name = log.add_layer(name='leaky_relu')
+    log.add_blobs([x], name='leaky_relu_blob')
+    layer = caffe_net.Layer_param(name=name, type='ReLU',
+                                  bottom=[log.blobs(input)], top=[log.blobs(x)])
+    layer.param.relu_param.negative_slope = negative_slope
+    log.cnet.add_layer(layer)
+    return x
+
+
+def _tanh(raw, input):
+    # for tanh activation
+    x = raw(input)
+    name = log.add_layer(name='tanh')
+    log.add_blobs([x], name='tanh_blob')
+    layer = caffe_net.Layer_param(name=name, type='TanH',
+                                  bottom=[log.blobs(input)], top=[log.blobs(x)])
+    log.cnet.add_layer(layer)
+    return x
+
+
+def _softmax(raw, input, dim=None, _stacklevel=3):
+    # for F.softmax
+    x = raw(input, dim=dim)
+    if dim is None:
+        dim = F._get_softmax_dim('softmax', input.dim(), _stacklevel)
+    bottom_blobs = [log.blobs(input)]
+    name = log.add_layer(name='softmax')
+    log.add_blobs([x], name='softmax_blob')
+    layer = caffe_net.Layer_param(name=name, type='Softmax',
+                                  bottom=bottom_blobs, top=[log.blobs(x)])
+    layer.param.softmax_param.axis = dim
+    log.cnet.add_layer(layer)
+    return x
+
+
+def _sigmoid(raw, input):
+    # for sigmoid activation
+    x = raw(input)
+    name = log.add_layer(name='Sigmoid')
+    log.add_blobs([x], name='Sigmoid_blob')
+    layer = caffe_net.Layer_param(name=name, type='Sigmoid',
+                                  bottom=[log.blobs(input)], top=[log.blobs(x)])
+    log.cnet.add_layer(layer)
+    return x
+
+
+def _batch_norm(raw, input, running_mean, running_var, weight=None, bias=None,
+                training=False, momentum=0.1, eps=1e-5):
+    # because running_mean and running_var will be changed by the _batch_norm operation, save the parameters first
+    x = raw(input, running_mean, running_var, weight, bias,
+            training, momentum, eps)
+    bottom_blobs = [log.blobs(input)]
|
368 |
+
layer_name1 = log.add_layer(name='batch_norm')
|
369 |
+
top_blobs = log.add_blobs([x], name='batch_norm_blob')
|
370 |
+
layer1 = caffe_net.Layer_param(name=layer_name1, type='BatchNorm',
|
371 |
+
bottom=bottom_blobs, top=top_blobs)
|
372 |
+
if running_mean is None or running_var is None:
|
373 |
+
# not use global_stats, normalization is performed over the current mini-batch
|
374 |
+
layer1.batch_norm_param(use_global_stats=0, eps=eps)
|
375 |
+
else:
|
376 |
+
layer1.batch_norm_param(use_global_stats=1, eps=eps)
|
377 |
+
running_mean_clone = running_mean.clone()
|
378 |
+
running_var_clone = running_var.clone()
|
379 |
+
layer1.add_data(running_mean_clone.cpu().numpy(), running_var_clone.cpu().numpy(), np.array([1.0]))
|
380 |
+
#print('running_mean: ',running_mean_clone.cpu().numpy())
|
381 |
+
#print('running_var: ',running_var_clone.cpu().numpy())
|
382 |
+
log.cnet.add_layer(layer1)
|
383 |
+
if weight is not None and bias is not None:
|
384 |
+
layer_name2 = log.add_layer(name='bn_scale')
|
385 |
+
layer2 = caffe_net.Layer_param(name=layer_name2, type='Scale',
|
386 |
+
bottom=top_blobs, top=top_blobs)
|
387 |
+
layer2.param.scale_param.bias_term = True
|
388 |
+
layer2.add_data(weight.cpu().data.numpy(), bias.cpu().data.numpy())
|
389 |
+
log.cnet.add_layer(layer2)
|
390 |
+
#print('scale weight: ', weight.cpu().data.numpy())
|
391 |
+
#print('scale bias: ', bias.cpu().data.numpy())
|
392 |
+
return x
|
393 |
+
|
394 |
+
|
395 |
+
def _instance_norm(raw, input, running_mean=None, running_var=None, weight=None,
|
396 |
+
bias=None, use_input_stats=True, momentum=0.1, eps=1e-5):
|
397 |
+
# TODO: the batch size!=1 view operations
|
398 |
+
print("WARNING: The Instance Normalization transfers to Caffe using BatchNorm, so the batch size should be 1")
|
399 |
+
if running_var is not None or weight is not None:
|
400 |
+
# TODO: the affine=True or track_running_stats=True case
|
401 |
+
raise NotImplementedError("not implement the affine=True or track_running_stats=True case InstanceNorm")
|
402 |
+
x = torch.batch_norm(
|
403 |
+
input, weight, bias, running_mean, running_var,
|
404 |
+
use_input_stats, momentum, eps, torch.backends.cudnn.enabled)
|
405 |
+
bottom_blobs = [log.blobs(input)]
|
406 |
+
layer_name1 = log.add_layer(name='instance_norm')
|
407 |
+
top_blobs = log.add_blobs([x], name='instance_norm_blob')
|
408 |
+
layer1 = caffe_net.Layer_param(name=layer_name1, type='BatchNorm',
|
409 |
+
bottom=bottom_blobs, top=top_blobs)
|
410 |
+
if running_mean is None or running_var is None:
|
411 |
+
# not use global_stats, normalization is performed over the current mini-batch
|
412 |
+
layer1.batch_norm_param(use_global_stats=0, eps=eps)
|
413 |
+
running_mean = torch.zeros(input.size()[1])
|
414 |
+
running_var = torch.ones(input.size()[1])
|
415 |
+
else:
|
416 |
+
layer1.batch_norm_param(use_global_stats=1, eps=eps)
|
417 |
+
running_mean_clone = running_mean.clone()
|
418 |
+
running_var_clone = running_var.clone()
|
419 |
+
layer1.add_data(running_mean_clone.cpu().numpy(), running_var_clone.cpu().numpy(), np.array([1.0]))
|
420 |
+
log.cnet.add_layer(layer1)
|
421 |
+
if weight is not None and bias is not None:
|
422 |
+
layer_name2 = log.add_layer(name='bn_scale')
|
423 |
+
layer2 = caffe_net.Layer_param(name=layer_name2, type='Scale',
|
424 |
+
bottom=top_blobs, top=top_blobs)
|
425 |
+
layer2.param.scale_param.bias_term = True
|
426 |
+
layer2.add_data(weight.cpu().data.numpy(), bias.cpu().data.numpy())
|
427 |
+
log.cnet.add_layer(layer2)
|
428 |
+
return x
|
429 |
+
|
430 |
+
|
431 |
+
# upsample layer
|
432 |
+
def _interpolate(raw, input, size=None, scale_factor=None, mode='nearest', align_corners=None):
|
433 |
+
# 定义的参数包括 scale,即输出与输入的尺寸比例,如 2;scale_h、scale_w,
|
434 |
+
# 同 scale,分别为 h、w 方向上的尺寸比例;pad_out_h、pad_out_w,仅在 scale 为 2 时
|
435 |
+
# 有用,对输出进行额外 padding 在 h、w 方向上的数值;upsample_h、upsample_w,输
|
436 |
+
# 出图像尺寸的数值。在 Upsample 的相关代码中,推荐仅仅使用 upsample_h、
|
437 |
+
# upsample_w 准确定义 Upsample 层的输出尺寸,其他所有的参数都不推荐继续使用。
|
438 |
+
'''
|
439 |
+
if mode == 'bilinear':
|
440 |
+
x = raw(input, size, scale_factor, mode)
|
441 |
+
name = log.add_layer(name='conv_transpose')
|
442 |
+
log.add_blobs([x], name='conv_transpose_blob')
|
443 |
+
layer = caffe_net.Layer_param(name=name, type='Deconvolution',
|
444 |
+
bottom=[log.blobs(input)], top=[log.blobs(x)])
|
445 |
+
print('Deconv: ', name)
|
446 |
+
print(input.shape)
|
447 |
+
print(x.size())
|
448 |
+
print(size)
|
449 |
+
factor = float(size[0]) / input.shape[2]
|
450 |
+
C = x.size()[1]
|
451 |
+
print(factor,C)
|
452 |
+
kernel_size = int(2 * factor - factor % 2)
|
453 |
+
stride = int(factor)
|
454 |
+
num_output = C
|
455 |
+
group = C
|
456 |
+
pad = math.ceil((factor-1) / 2.)
|
457 |
+
print('kernel_size, stride, num_output, group, pad')
|
458 |
+
print(kernel_size, stride, num_output, group, pad)
|
459 |
+
layer.conv_param(num_output, kernel_size, stride=stride,
|
460 |
+
pad=pad, weight_filler_type='bilinear', bias_term=False, groups=group)
|
461 |
+
|
462 |
+
layer.param.convolution_param.bias_term = False
|
463 |
+
log.cnet.add_layer(layer)
|
464 |
+
return x
|
465 |
+
'''
|
466 |
+
# transfer bilinear align_corners=True to caffe-interp
|
467 |
+
if mode == "bilinear" and align_corners == True:
|
468 |
+
x = raw(input, size, scale_factor, mode)
|
469 |
+
name = log.add_layer(name='interp')
|
470 |
+
log.add_blobs([x], name='interp_blob')
|
471 |
+
layer = caffe_net.Layer_param(name=name, type='Interp',
|
472 |
+
bottom=[log.blobs(input)], top=[log.blobs(x)])
|
473 |
+
layer.interp_param(size=size, scale_factor=scale_factor)
|
474 |
+
log.cnet.add_layer(layer)
|
475 |
+
return x
|
476 |
+
|
477 |
+
# for nearest _interpolate
|
478 |
+
if mode != "nearest" or align_corners != None:
|
479 |
+
raise NotImplementedError("not implement F.interpolate totoaly")
|
480 |
+
x = raw(input, size, scale_factor, mode)
|
481 |
+
layer_name = log.add_layer(name='upsample')
|
482 |
+
top_blobs = log.add_blobs([x], name='upsample_blob'.format(type))
|
483 |
+
layer = caffe_net.Layer_param(name=layer_name, type='Upsample',
|
484 |
+
bottom=[log.blobs(input)], top=top_blobs)
|
485 |
+
#layer.upsample_param(size=(input.size(2), input.size(3)), scale_factor=scale_factor)
|
486 |
+
#layer.upsample_param(size=size, scale_factor=scale_factor)
|
487 |
+
layer.upsample_param(size=None, scale_factor=size[0])
|
488 |
+
|
489 |
+
log.cnet.add_layer(layer)
|
490 |
+
return x
|
491 |
+
|
492 |
+
|
493 |
+
# ----- for Variable operations --------
|
494 |
+
|
495 |
+
def _view(input, *args):
|
496 |
+
x = raw_view(input, *args)
|
497 |
+
if not NET_INITTED:
|
498 |
+
return x
|
499 |
+
layer_name = log.add_layer(name='view')
|
500 |
+
top_blobs = log.add_blobs([x], name='view_blob')
|
501 |
+
|
502 |
+
# print('*'*60)
|
503 |
+
# print('input={}'.format(input))
|
504 |
+
# print('layer_name={}'.format(layer_name))
|
505 |
+
# print('top_blobs={}'.format(top_blobs))
|
506 |
+
|
507 |
+
layer = caffe_net.Layer_param(name=layer_name, type='Reshape', bottom=[log.blobs(input)], top=top_blobs)
|
508 |
+
# TODO: reshpae added to nn_tools layer
|
509 |
+
dims = list(args)
|
510 |
+
dims[0] = 0 # the first dim should be batch_size
|
511 |
+
layer.param.reshape_param.shape.CopyFrom(caffe_net.pb.BlobShape(dim=dims))
|
512 |
+
log.cnet.add_layer(layer)
|
513 |
+
return x
|
514 |
+
|
515 |
+
|
516 |
+
def _mean(input, *args, **kwargs):
|
517 |
+
x = raw_mean(input, *args, **kwargs)
|
518 |
+
if not NET_INITTED:
|
519 |
+
return x
|
520 |
+
layer_name = log.add_layer(name='mean')
|
521 |
+
top_blobs = log.add_blobs([x], name='mean_blob')
|
522 |
+
layer = caffe_net.Layer_param(name=layer_name, type='Reduction',
|
523 |
+
bottom=[log.blobs(input)], top=top_blobs)
|
524 |
+
if len(args) == 1:
|
525 |
+
dim = args[0]
|
526 |
+
elif 'dim' in kwargs:
|
527 |
+
dim = kwargs['dim']
|
528 |
+
else:
|
529 |
+
raise NotImplementedError('mean operation must specify a dim')
|
530 |
+
layer.param.reduction_param.operation = 4
|
531 |
+
layer.param.reduction_param.axis = dim
|
532 |
+
log.cnet.add_layer(layer)
|
533 |
+
return x
|
534 |
+
|
535 |
+
|
536 |
+
def _add(input, *args):
|
537 |
+
# check if add a const value
|
538 |
+
if isinstance(args[0], int):
|
539 |
+
print('value: ',args[0])
|
540 |
+
x = raw__add__(input, *args)
|
541 |
+
#x = raw(input)
|
542 |
+
layer_name = log.add_layer(name='scale')
|
543 |
+
log.add_blobs([x], name='Scale_blob')
|
544 |
+
layer = caffe_net.Layer_param(name=layer_name, type='Scale',
|
545 |
+
bottom=[log.blobs(input)], top=[log.blobs(x)])
|
546 |
+
dim = x.shape[1]
|
547 |
+
layer.param.scale_param.bias_term = True
|
548 |
+
weight = np.ones(dim, dtype=np.float32)
|
549 |
+
bias = args[0] * np.ones(dim, dtype=np.float32)
|
550 |
+
layer.add_data(weight, bias)
|
551 |
+
log.cnet.add_layer(layer)
|
552 |
+
return x
|
553 |
+
# otherwise add a tensor
|
554 |
+
x = raw__add__(input, *args)
|
555 |
+
if not NET_INITTED:
|
556 |
+
return x
|
557 |
+
layer_name = log.add_layer(name='add')
|
558 |
+
top_blobs = log.add_blobs([x], name='add_blob')
|
559 |
+
layer = caffe_net.Layer_param(name=layer_name, type='Eltwise',
|
560 |
+
bottom=[log.blobs(input), log.blobs(args[0])], top=top_blobs)
|
561 |
+
layer.param.eltwise_param.operation = 1 # sum is 1
|
562 |
+
log.cnet.add_layer(layer)
|
563 |
+
return x
|
564 |
+
|
565 |
+
|
566 |
+
def _iadd(input, *args):
|
567 |
+
x = raw__iadd__(input, *args)
|
568 |
+
if not NET_INITTED:
|
569 |
+
return x
|
570 |
+
x = x.clone()
|
571 |
+
layer_name = log.add_layer(name='add')
|
572 |
+
top_blobs = log.add_blobs([x], name='add_blob')
|
573 |
+
layer = caffe_net.Layer_param(name=layer_name, type='Eltwise',
|
574 |
+
bottom=[log.blobs(input), log.blobs(args[0])], top=top_blobs)
|
575 |
+
layer.param.eltwise_param.operation = 1 # sum is 1
|
576 |
+
log.cnet.add_layer(layer)
|
577 |
+
return x
|
578 |
+
|
579 |
+
|
580 |
+
def _sub(input, *args):
|
581 |
+
x = raw__sub__(input, *args)
|
582 |
+
if not NET_INITTED:
|
583 |
+
return x
|
584 |
+
layer_name = log.add_layer(name='sub')
|
585 |
+
top_blobs = log.add_blobs([x], name='sub_blob')
|
586 |
+
layer = caffe_net.Layer_param(name=layer_name, type='Eltwise',
|
587 |
+
bottom=[log.blobs(input), log.blobs(args[0])], top=top_blobs)
|
588 |
+
layer.param.eltwise_param.operation = 1 # sum is 1
|
589 |
+
layer.param.eltwise_param.coeff.extend([1., -1.])
|
590 |
+
log.cnet.add_layer(layer)
|
591 |
+
return x
|
592 |
+
|
593 |
+
|
594 |
+
def _isub(input, *args):
|
595 |
+
x = raw__isub__(input, *args)
|
596 |
+
if not NET_INITTED:
|
597 |
+
return x
|
598 |
+
x = x.clone()
|
599 |
+
layer_name = log.add_layer(name='sub')
|
600 |
+
top_blobs = log.add_blobs([x], name='sub_blob')
|
601 |
+
layer = caffe_net.Layer_param(name=layer_name, type='Eltwise',
|
602 |
+
bottom=[log.blobs(input), log.blobs(args[0])], top=top_blobs)
|
603 |
+
layer.param.eltwise_param.operation = 1 # sum is 1
|
604 |
+
log.cnet.add_layer(layer)
|
605 |
+
return x
|
606 |
+
|
607 |
+
|
608 |
+
def _mul(input, *args):
|
609 |
+
x = raw__sub__(input, *args)
|
610 |
+
if not NET_INITTED:
|
611 |
+
return x
|
612 |
+
layer_name = log.add_layer(name='mul')
|
613 |
+
top_blobs = log.add_blobs([x], name='mul_blob')
|
614 |
+
layer = caffe_net.Layer_param(name=layer_name, type='Eltwise',
|
615 |
+
bottom=[log.blobs(input), log.blobs(args[0])], top=top_blobs)
|
616 |
+
layer.param.eltwise_param.operation = 0 # product is 1
|
617 |
+
log.cnet.add_layer(layer)
|
618 |
+
return x
|
619 |
+
|
620 |
+
|
621 |
+
def _imul(input, *args):
|
622 |
+
x = raw__isub__(input, *args)
|
623 |
+
if not NET_INITTED:
|
624 |
+
return x
|
625 |
+
x = x.clone()
|
626 |
+
layer_name = log.add_layer(name='mul')
|
627 |
+
top_blobs = log.add_blobs([x], name='mul_blob')
|
628 |
+
layer = caffe_net.Layer_param(name=layer_name, type='Eltwise',
|
629 |
+
bottom=[log.blobs(input), log.blobs(args[0])], top=top_blobs)
|
630 |
+
layer.param.eltwise_param.operation = 0 # product is 1
|
631 |
+
layer.param.eltwise_param.coeff.extend([1., -1.])
|
632 |
+
log.cnet.add_layer(layer)
|
633 |
+
return x
|
634 |
+
|
635 |
+
|
636 |
+
def _adaptive_avg_pool2d(raw, input, output_size):
|
637 |
+
_output_size = _list_with_default(output_size, input.size())
|
638 |
+
x = raw(input, _output_size)
|
639 |
+
_pool('ave', raw, input, x, input.shape[2], input.shape[2], 0, False)
|
640 |
+
return x
|
641 |
+
|
642 |
+
|
643 |
+
# 核心组件,通过该类,实现对torch的function中的operators的输入,输出以及参数的读取
|
644 |
+
class Rp(object):
|
645 |
+
def __init__(self, raw, replace, **kwargs):
|
646 |
+
# replace the raw function to replace function
|
647 |
+
self.obj = replace
|
648 |
+
self.raw = raw
|
649 |
+
|
650 |
+
def __call__(self, *args, **kwargs):
|
651 |
+
if not NET_INITTED:
|
652 |
+
return self.raw(*args, **kwargs)
|
653 |
+
for stack in traceback.walk_stack(None):
|
654 |
+
if 'self' in stack[0].f_locals:
|
655 |
+
layer = stack[0].f_locals['self']
|
656 |
+
if layer in layer_names:
|
657 |
+
log.pytorch_layer_name = layer_names[layer]
|
658 |
+
print(layer_names[layer])
|
659 |
+
break
|
660 |
+
out = self.obj(self.raw, *args, **kwargs)
|
661 |
+
# if isinstance(out,Variable):
|
662 |
+
# out=[out]
|
663 |
+
return out
|
664 |
+
|
665 |
+
|
666 |
+
F.conv2d = Rp(F.conv2d, _conv2d)
|
667 |
+
F.linear = Rp(F.linear, _linear)
|
668 |
+
F.relu = Rp(F.relu, _relu)
|
669 |
+
|
670 |
+
F.leaky_relu = Rp(F.leaky_relu, _leaky_relu)
|
671 |
+
F.max_pool2d = Rp(F.max_pool2d, _max_pool2d)
|
672 |
+
F.avg_pool2d = Rp(F.avg_pool2d, _avg_pool2d)
|
673 |
+
F.dropout = Rp(F.dropout, _dropout)
|
674 |
+
F.threshold = Rp(F.threshold, _threshold)
|
675 |
+
F.prelu = Rp(F.prelu, _prelu)
|
676 |
+
F.batch_norm = Rp(F.batch_norm, _batch_norm)
|
677 |
+
F.instance_norm = Rp(F.instance_norm, _instance_norm)
|
678 |
+
F.softmax = Rp(F.softmax, _softmax)
|
679 |
+
F.conv_transpose2d = Rp(F.conv_transpose2d, _conv_transpose2d)
|
680 |
+
F.interpolate = Rp(F.interpolate, _interpolate)
|
681 |
+
F.adaptive_avg_pool2d = Rp(F.adaptive_avg_pool2d, _adaptive_avg_pool2d)
|
682 |
+
|
683 |
+
torch.split = Rp(torch.split, _split)
|
684 |
+
torch.max = Rp(torch.max, _max)
|
685 |
+
torch.cat = Rp(torch.cat, _cat)
|
686 |
+
torch.sigmoid = Rp(torch.sigmoid, _sigmoid)
|
687 |
+
|
688 |
+
# TODO: other types of the view function
|
689 |
+
try:
|
690 |
+
raw_view = Variable.view
|
691 |
+
Variable.view = _view
|
692 |
+
raw_mean = Variable.mean
|
693 |
+
Variable.mean = _mean
|
694 |
+
raw__add__ = Variable.__add__
|
695 |
+
Variable.__add__ = _add
|
696 |
+
raw__iadd__ = Variable.__iadd__
|
697 |
+
Variable.__iadd__ = _iadd
|
698 |
+
raw__sub__ = Variable.__sub__
|
699 |
+
Variable.__sub__ = _sub
|
700 |
+
raw__isub__ = Variable.__isub__
|
701 |
+
Variable.__isub__ = _isub
|
702 |
+
raw__mul__ = Variable.__mul__
|
703 |
+
Variable.__mul__ = _mul
|
704 |
+
raw__imul__ = Variable.__imul__
|
705 |
+
Variable.__imul__ = _imul
|
706 |
+
except:
|
707 |
+
# for new version 0.4.0 and later version
|
708 |
+
for t in [torch.Tensor]:
|
709 |
+
raw_view = t.view
|
710 |
+
t.view = _view
|
711 |
+
raw_mean = t.mean
|
712 |
+
t.mean = _mean
|
713 |
+
raw__add__ = t.__add__
|
714 |
+
t.__add__ = _add
|
715 |
+
raw__iadd__ = t.__iadd__
|
716 |
+
t.__iadd__ = _iadd
|
717 |
+
raw__sub__ = t.__sub__
|
718 |
+
t.__sub__ = _sub
|
719 |
+
raw__isub__ = t.__isub__
|
720 |
+
t.__isub__ = _isub
|
721 |
+
raw__mul__ = t.__mul__
|
722 |
+
t.__mul__ = _mul
|
723 |
+
raw__imul__ = t.__imul__
|
724 |
+
t.__imul__ = _imul
|
725 |
+
|
726 |
+
|
727 |
+
def trans_net(net, input_var, name='TransferedPytorchModel'):
|
728 |
+
print('Starting Transform, This will take a while')
|
729 |
+
log.init([input_var])
|
730 |
+
log.cnet.net.name = name
|
731 |
+
log.cnet.net.input.extend([log.blobs(input_var)])
|
732 |
+
log.cnet.net.input_dim.extend(input_var.size())
|
733 |
+
global NET_INITTED
|
734 |
+
NET_INITTED = True
|
735 |
+
for name, layer in net.named_modules():
|
736 |
+
layer_names[layer] = name
|
737 |
+
print("torch ops name:", layer_names)
|
738 |
+
out = net.forward(input_var)
|
739 |
+
print('Transform Completed')
|
740 |
+
|
741 |
+
|
742 |
+
def save_prototxt(save_name):
|
743 |
+
log.cnet.save_prototxt(save_name)
|
744 |
+
|
745 |
+
|
746 |
+
def save_caffemodel(save_name):
|
747 |
+
log.cnet.save(save_name)
|
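For reference, a minimal usage sketch of the module above (the torchvision model and the output file names are illustrative placeholders, not part of this repository):

import torch
import torchvision.models as models
import pytorch_to_caffe

# Any eval-mode PyTorch model plus a dummy input of the target shape.
net = models.resnet18(pretrained=False).eval()
dummy = torch.ones([1, 3, 224, 224])

# Tracing one forward pass records every hooked op into the Caffe net definition,
# after which the prototxt and caffemodel can be written out.
pytorch_to_caffe.trans_net(net, dummy, name='resnet18')
pytorch_to_caffe.save_prototxt('resnet18.prototxt')
pytorch_to_caffe.save_caffemodel('resnet18.caffemodel')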
CION_ReIDZoo/CION_Finetune/FastReID/tools/deploy/trt_export.py
ADDED
@@ -0,0 +1,82 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: [email protected]
"""

import argparse
import os
import numpy as np
import sys

sys.path.append('../../')
sys.path.append("/export/home/lxy/runtimelib-tensorrt-tiny/build")

import pytrt
from fastreid.utils.logger import setup_logger
from fastreid.utils.file_io import PathManager


logger = setup_logger(name='trt_export')


def get_parser():
    parser = argparse.ArgumentParser(description="Convert ONNX to TRT model")

    parser.add_argument(
        "--name",
        default="baseline",
        help="name for converted model"
    )
    parser.add_argument(
        "--output",
        default='outputs/trt_model',
        help='path to save converted trt model'
    )
    parser.add_argument(
        "--onnx-model",
        default='outputs/onnx_model/baseline.onnx',
        help='path to onnx model'
    )
    parser.add_argument(
        "--height",
        type=int,
        default=256,
        help="height of image"
    )
    parser.add_argument(
        "--width",
        type=int,
        default=128,
        help="width of image"
    )
    return parser


def export_trt_model(onnxModel, engineFile, input_numpy_array):
    r"""
    Export a model to trt format.
    """

    trt = pytrt.Trt()

    customOutput = []
    maxBatchSize = 1
    calibratorData = []
    mode = 2
    trt.CreateEngine(onnxModel, engineFile, customOutput, maxBatchSize, mode, calibratorData)
    trt.DoInference(input_numpy_array)  # slightly different from c++
    return 0


if __name__ == '__main__':
    args = get_parser().parse_args()

    inputs = np.zeros(shape=(32, args.height, args.width, 3))
    onnxModel = args.onnx_model
    engineFile = os.path.join(args.output, args.name + '.engine')

    PathManager.mkdirs(args.output)
    export_trt_model(onnxModel, engineFile, inputs)

    logger.info(f"Export trt model in {args.output} successfully!")
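A hedged usage sketch for the exporter above, calling its own export_trt_model helper directly (paths follow the script's defaults and are placeholders; pytrt must already be built and importable, as the sys.path line in the script assumes):

import numpy as np
import trt_export  # importing the script also runs its logger/sys.path setup

# Dummy batch in the script's layout: (batch, height, width, channels).
dummy_batch = np.zeros(shape=(32, 256, 128, 3))
trt_export.export_trt_model('outputs/onnx_model/baseline.onnx',
                            'outputs/trt_model/baseline.engine',
                            dummy_batch)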
open_clip/.github/workflows/ci.yml
ADDED
@@ -0,0 +1,121 @@
name: Continuous integration

on:
  push:
    branches:
      - main
    paths-ignore:
      - '**.md'
      - 'CITATION.cff'
      - 'LICENSE'
      - '.gitignore'
      - 'docs/**'
  pull_request:
    branches:
      - main
    paths-ignore:
      - '**.md'
      - 'CITATION.cff'
      - 'LICENSE'
      - '.gitignore'
      - 'docs/**'
  workflow_dispatch:
    inputs:
      manual_revision_reference:
        required: false
        type: string
      manual_revision_test:
        required: false
        type: string

env:
  REVISION_REFERENCE: v2.8.2
  #9d31b2ec4df6d8228f370ff20c8267ec6ba39383 earliest compatible v2.7.0 + pretrained_hf param

jobs:
  Tests:
    strategy:
      matrix:
        os: [ ubuntu-latest ] #, macos-latest ]
        python: [ 3.8 ]
        job_num: [ 4 ]
        job: [ 1, 2, 3, 4 ]
    runs-on: ${{ matrix.os }}
    steps:
      - uses: actions/checkout@v3
        with:
          fetch-depth: 0
          ref: ${{ inputs.manual_revision_test }}
      - name: Set up Python ${{ matrix.python }}
        id: pythonsetup
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python }}
      - name: Venv cache
        id: venv-cache
        uses: actions/cache@v3
        with:
          path: .env
          key: venv-${{ matrix.os }}-${{ steps.pythonsetup.outputs.python-version }}-${{ hashFiles('requirements*') }}
      - name: Pytest durations cache
        uses: actions/cache@v3
        with:
          path: .test_durations
          key: test_durations-${{ matrix.os }}-${{ steps.pythonsetup.outputs.python-version }}-${{ matrix.job }}-${{ github.run_id }}
          restore-keys: test_durations-0-
      - name: Setup
        if: steps.venv-cache.outputs.cache-hit != 'true'
        run: |
          python3 -m venv .env
          source .env/bin/activate
          pip install -e .[test]
      - name: Prepare test data
        run: |
          source .env/bin/activate
          python -m pytest \
            --quiet --co \
            --splitting-algorithm least_duration \
            --splits ${{ matrix.job_num }} \
            --group ${{ matrix.job }} \
            -m regression_test \
            tests \
            | head -n -2 | grep -Po 'test_inference_with_data\[\K[^]]*(?=-False]|-True])' \
            > models_gh_runner.txt
          if [ -n "${{ inputs.manual_revision_reference }}" ]; then
            REVISION_REFERENCE=${{ inputs.manual_revision_reference }}
          fi
          python tests/util_test.py \
            --save_model_list models_gh_runner.txt \
            --model_list models_gh_runner.txt \
            --git_revision $REVISION_REFERENCE
      - name: Unit tests
        run: |
          source .env/bin/activate
          if [[ -f .test_durations ]]
          then
            cp .test_durations durations_1
            mv .test_durations durations_2
          fi
          python -m pytest \
            -x -s -v \
            --splitting-algorithm least_duration \
            --splits ${{ matrix.job_num }} \
            --group ${{ matrix.job }} \
            --store-durations \
            --durations-path durations_1 \
            --clean-durations \
            -m "not regression_test" \
            tests
          OPEN_CLIP_TEST_REG_MODELS=models_gh_runner.txt python -m pytest \
            -x -s -v \
            --store-durations \
            --durations-path durations_2 \
            --clean-durations \
            -m "regression_test" \
            tests
          jq -s -S 'add' durations_* > .test_durations
      - name: Collect pytest durations
        uses: actions/upload-artifact@v3
        with:
          name: pytest_durations_${{ matrix.os }}-${{ matrix.python }}-${{ matrix.job }}
          path: .test_durations
open_clip/.github/workflows/clear-cache.yml
ADDED
@@ -0,0 +1,29 @@
name: Clear cache

on:
  workflow_dispatch:

permissions:
  actions: write

jobs:
  clear-cache:
    runs-on: ubuntu-latest
    steps:
      - name: Clear cache
        uses: actions/github-script@v6
        with:
          script: |
            const caches = await github.rest.actions.getActionsCacheList({
              owner: context.repo.owner,
              repo: context.repo.repo,
            })
            for (const cache of caches.data.actions_caches) {
              console.log(cache)
              await github.rest.actions.deleteActionsCacheById({
                owner: context.repo.owner,
                repo: context.repo.repo,
                cache_id: cache.id,
              })
            }
open_clip/.github/workflows/python-publish.yml
ADDED
@@ -0,0 +1,37 @@
name: Release

on:
  push:
    branches:
      - main
jobs:
  deploy:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - uses: actions-ecosystem/action-regex-match@v2
        id: regex-match
        with:
          text: ${{ github.event.head_commit.message }}
          regex: '^Release ([^ ]+)'
      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: '3.8'
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install setuptools wheel twine build
      - name: Release
        if: ${{ steps.regex-match.outputs.match != '' }}
        uses: softprops/action-gh-release@v1
        with:
          tag_name: v${{ steps.regex-match.outputs.group1 }}
      - name: Build and publish
        if: ${{ steps.regex-match.outputs.match != '' }}
        env:
          TWINE_USERNAME: __token__
          TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
        run: |
          python -m build
          twine upload dist/*
open_clip/docs/Interacting_with_open_clip.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
open_clip/docs/clip_recall.png
ADDED
open_clip/docs/openclip_retrieval_results.csv
ADDED
@@ -0,0 +1,122 @@
1 |
+
name,pretrained,params (M),FLOPs (B),Average score,Flickr image retr. R@1,Flickr image retr. R@5,Flickr image retr. R@10,Flickr text retr. R@1,Flickr text retr. R@5,Flickr text retr. R@10,MSCOCO image retr. R@1,MSCOCO image retr. R@5,MSCOCO image retr. R@10,MSCOCO text retr. R@1,MSCOCO text retr. R@5,MSCOCO text retr. R@10,WinoGAViL avg jaccard score,WinoGAViL jaccard score 10,WinoGAViL jaccard score 10-12,WinoGAViL jaccard score 12,WinoGAViL jaccard score 5,WinoGAViL jaccard score 5-6,WinoGAViL jaccard score 6
|
2 |
+
ViT-SO400M-14-SigLIP-384,webli,877.96,723.48,0.7721,0.8296,0.9610,0.9804,0.9430,0.9970,0.9980,0.5421,0.7678,0.8424,0.7242,0.8998,0.9448,0.6181,0.5807,0.5754,0.5701,0.6427,0.6316,0.6210
|
3 |
+
ViT-H-14-378-quickgelu,dfn5b,986.71,1054.05,0.7719,0.8202,0.9598,0.9798,0.9400,0.9920,0.9960,0.5564,0.7920,0.8626,0.7188,0.9048,0.9496,0.6123,0.5668,0.5674,0.5679,0.6409,0.6265,0.6127
|
4 |
+
ViT-L-16-SigLIP-384,webli,652.48,422.91,0.7642,0.8142,0.9536,0.9748,0.9370,0.9920,0.9990,0.5391,0.7657,0.8399,0.7190,0.9006,0.9410,0.6070,0.5562,0.5550,0.5539,0.6422,0.6234,0.6056
|
5 |
+
ViT-H-14-quickgelu,dfn5b,986.11,381.68,0.7640,0.8012,0.9526,0.9738,0.9280,0.9940,0.9990,0.5391,0.7805,0.8553,0.7230,0.9024,0.9452,0.6022,0.5568,0.5561,0.5554,0.6315,0.6167,0.6028
|
6 |
+
ViT-L-16-SigLIP-256,webli,652.15,201.62,0.7619,0.7904,0.9446,0.9684,0.9180,0.9900,0.9980,0.5228,0.7580,0.8334,0.7080,0.8870,0.9374,0.6142,0.5877,0.5748,0.5619,0.6450,0.6267,0.6092
|
7 |
+
ViT-SO400M-14-SigLIP,webli,877.36,233.54,0.7567,0.7526,0.9226,0.9554,0.9100,0.9910,0.9980,0.5176,0.7527,0.8300,0.6966,0.8908,0.9348,0.6189,0.5736,0.5665,0.5594,0.6430,0.6354,0.6282
|
8 |
+
ViT-B-16-SigLIP-512,webli,203.79,227.26,0.7554,0.7906,0.9458,0.9690,0.9250,0.9920,0.9960,0.5055,0.7421,0.8217,0.6872,0.8786,0.9266,0.6070,0.5761,0.5696,0.5633,0.6291,0.6187,0.6088
|
9 |
+
ViT-B-16-SigLIP-384,webli,203.45,123.15,0.7542,0.7848,0.9416,0.9708,0.9270,0.9910,0.9980,0.4990,0.7366,0.8204,0.6774,0.8758,0.9258,0.6077,0.5730,0.5719,0.5709,0.6314,0.6190,0.6072
|
10 |
+
ViT-B-16-SigLIP-256,webli,203.2,57.84,0.7462,0.7504,0.9242,0.9626,0.9040,0.9830,0.9930,0.4834,0.7254,0.8099,0.6614,0.8634,0.9182,0.6113,0.5754,0.5715,0.5677,0.6373,0.6238,0.6110
|
11 |
+
ViT-B-16-SigLIP,webli,203.16,46.44,0.7442,0.7468,0.9230,0.9562,0.8910,0.9800,0.9930,0.4778,0.7244,0.8100,0.6574,0.8542,0.9126,0.6134,0.5726,0.5738,0.5750,0.6335,0.6259,0.6186
|
12 |
+
coca_ViT-L-14,mscoco_finetuned_laion2b_s13b_b90k,638.45,214.52,0.7432,0.7846,0.9452,0.9712,0.8900,0.9870,0.9950,0.5374,0.7779,0.8567,0.6682,0.8762,0.9242,0.5762,0.5265,0.5146,0.5028,0.5978,0.5956,0.5934
|
13 |
+
ViT-H-14-quickgelu,metaclip_fullcc,986.11,381.68,0.7412,0.7834,0.9464,0.9692,0.9180,0.9870,0.9970,0.4882,0.7323,0.8136,0.6622,0.8618,0.9188,0.5891,0.5316,0.5312,0.5308,0.6220,0.6073,0.5934
|
14 |
+
ViT-L-14-quickgelu,dfn2b,427.62,175.33,0.7406,0.7546,0.9280,0.9608,0.8960,0.9850,0.9930,0.4856,0.7381,0.8237,0.6560,0.8584,0.9120,0.5955,0.5688,0.5519,0.5351,0.6219,0.6093,0.5973
|
15 |
+
xlm-roberta-large-ViT-H-14,frozen_laion5b_s13b_b90k,1193.01,671.01,0.7363,0.7742,0.9392,0.9670,0.9180,0.9930,0.9980,0.4921,0.7305,0.8159,0.6596,0.8618,0.9184,0.5767,0.5277,0.5206,0.5135,0.6076,0.5943,0.5818
|
16 |
+
ViT-B-16-SigLIP-i18n-256,webli,370.63,57.84,0.7347,0.7216,0.9036,0.9470,0.8960,0.9820,0.9910,0.4492,0.6948,0.7837,0.6448,0.8434,0.9060,0.6113,0.5887,0.5710,0.5534,0.6303,0.6240,0.6180
|
17 |
+
ViT-L-14-quickgelu,metaclip_fullcc,427.62,175.33,0.7325,0.7642,0.9366,0.9646,0.9010,0.9850,0.9930,0.4709,0.7141,0.8026,0.6442,0.8504,0.9130,0.5817,0.5539,0.5357,0.5176,0.5990,0.5963,0.5937
|
18 |
+
ViT-H-14-CLIPA-336,laion2b,968.64,800.88,0.7289,0.7814,0.9392,0.9684,0.9230,0.9920,0.9970,0.5013,0.7387,0.8224,0.6752,0.8802,0.9308,0.5452,0.4970,0.4878,0.4786,0.5764,0.5633,0.5508
|
19 |
+
convnext_xxlarge,laion2b_s34b_b82k_augreg_rewind,1200.58,443.03,0.7238,0.7886,0.9438,0.9714,0.9110,0.9910,0.9970,0.5003,0.7407,0.8240,0.6740,0.8738,0.9264,0.5327,0.4826,0.4741,0.4656,0.5674,0.5512,0.5357
|
20 |
+
EVA02-L-14-336,merged2b_s6b_b61k,428.08,395.16,0.7235,0.7802,0.9432,0.9684,0.8960,0.9890,0.9960,0.4794,0.7172,0.7999,0.6416,0.8522,0.9084,0.5532,0.5081,0.5053,0.5025,0.5719,0.5683,0.5649
|
21 |
+
convnext_xxlarge,laion2b_s34b_b82k_augreg_soup,1200.58,443.03,0.7232,0.7902,0.9460,0.9706,0.9140,0.9920,0.9970,0.4997,0.7403,0.8227,0.6736,0.8722,0.9260,0.5316,0.4790,0.4702,0.4615,0.5651,0.5509,0.5374
|
22 |
+
EVA01-g-14-plus,merged2b_s11b_b114k,1366.62,581.15,0.7231,0.7900,0.9448,0.9692,0.9170,0.9930,0.9970,0.5033,0.7398,0.8206,0.6816,0.8750,0.9250,0.5298,0.4787,0.4684,0.4582,0.5655,0.5491,0.5336
|
23 |
+
ViT-bigG-14-CLIPA-336,datacomp1b,2517.76,2271.58,0.7220,0.7828,0.9408,0.9674,0.9220,0.9940,0.9980,0.5042,0.7432,0.8238,0.6776,0.8792,0.9306,0.5323,0.4624,0.4487,0.4350,0.5619,0.5587,0.5556
|
24 |
+
ViT-bigG-14,laion2b_s39b_b160k,2539.57,1065.36,0.7213,0.7956,0.9498,0.9714,0.9290,0.9930,0.9980,0.5137,0.7490,0.8300,0.6738,0.8686,0.9266,0.5224,0.4714,0.4488,0.4263,0.5588,0.5456,0.5330
|
25 |
+
EVA02-L-14,merged2b_s4b_b131k,427.76,175.3,0.7207,0.7732,0.9374,0.9678,0.8970,0.9850,0.9910,0.4747,0.7123,0.7971,0.6366,0.8434,0.9038,0.5518,0.5128,0.5081,0.5034,0.5767,0.5656,0.5551
|
26 |
+
ViT-L-14,commonpool_xl_laion_s13b_b90k,427.62,175.33,0.7206,0.7490,0.9308,0.9630,0.8770,0.9900,0.9950,0.4677,0.7142,0.8042,0.6350,0.8458,0.9122,0.5635,0.5193,0.4966,0.4740,0.6023,0.5846,0.5677
|
27 |
+
convnext_xxlarge,laion2b_s34b_b82k_augreg,1200.58,443.03,0.7201,0.7902,0.9414,0.9726,0.9160,0.9920,0.9970,0.4970,0.7400,0.8199,0.6698,0.8744,0.9236,0.5290,0.4730,0.4536,0.4343,0.5622,0.5527,0.5437
|
28 |
+
ViT-H-14-CLIPA-336,datacomp1b,968.64,800.88,0.7192,0.7652,0.9350,0.9630,0.9060,0.9920,0.9950,0.4928,0.7309,0.8144,0.6716,0.8746,0.9276,0.5365,0.4801,0.4587,0.4375,0.5658,0.5610,0.5566
|
29 |
+
EVA02-E-14-plus,laion2b_s9b_b144k,5044.89,2362.19,0.7188,0.7886,0.9434,0.9698,0.9410,0.9930,0.9980,0.5110,0.7492,0.8276,0.6872,0.8760,0.9274,0.5134,0.4576,0.4403,0.4232,0.5470,0.5364,0.5263
|
30 |
+
ViT-H-14-CLIPA,datacomp1b,968.24,354.02,0.7178,0.7588,0.9330,0.9628,0.9100,0.9900,0.9960,0.4910,0.7291,0.8140,0.6698,0.8730,0.9272,0.5343,0.4793,0.4578,0.4365,0.5641,0.5584,0.5529
|
31 |
+
ViT-bigG-14-CLIPA,datacomp1b,2517.22,1007.93,0.7175,0.7786,0.9374,0.9650,0.9190,0.9930,0.9980,0.4996,0.7414,0.8214,0.6780,0.8742,0.9312,0.5247,0.4461,0.4383,0.4306,0.5552,0.5519,0.5488
|
32 |
+
ViT-B-16,dfn2b,149.62,41.09,0.7158,0.6912,0.8982,0.9406,0.8540,0.9780,0.9860,0.4339,0.6924,0.7882,0.6038,0.8310,0.8946,0.5852,0.5509,0.5417,0.5327,0.6058,0.5989,0.5923
|
33 |
+
EVA02-E-14,laion2b_s4b_b115k,4704.59,2311.42,0.7152,0.7810,0.9438,0.9700,0.9220,0.9950,0.9970,0.5037,0.7429,0.8224,0.6746,0.8738,0.9256,0.5110,0.4316,0.4429,0.4542,0.5485,0.5325,0.5173
|
34 |
+
convnext_large_d_320,laion2b_s29b_b131k_ft_soup,351.77,157.98,0.7148,0.7704,0.9394,0.9648,0.9230,0.9890,0.9980,0.4829,0.7251,0.8103,0.6502,0.8608,0.9180,0.5285,0.4574,0.4549,0.4525,0.5650,0.5517,0.5391
|
35 |
+
ViT-L-14-quickgelu,metaclip_400m,427.62,175.33,0.7137,0.7342,0.9228,0.9574,0.8620,0.9800,0.9960,0.4381,0.6862,0.7780,0.6000,0.8290,0.8934,0.5701,0.5322,0.5175,0.5029,0.5941,0.5867,0.5796
|
36 |
+
convnext_large_d_320,laion2b_s29b_b131k_ft,351.77,157.98,0.7122,0.7690,0.9364,0.9668,0.9140,0.9890,0.9980,0.4814,0.7232,0.8103,0.6384,0.8558,0.9168,0.5250,0.4727,0.4558,0.4389,0.5623,0.5468,0.5321
|
37 |
+
convnext_large_d,laion2b_s26b_b102k_augreg,351.77,107.5,0.7102,0.7588,0.9310,0.9652,0.9180,0.9850,0.9940,0.4701,0.7139,0.8013,0.6400,0.8506,0.9140,0.5252,0.4835,0.4646,0.4457,0.5571,0.5443,0.5322
|
38 |
+
ViT-L-14-CLIPA-336,datacomp1b,414.54,387.39,0.7092,0.7462,0.9240,0.9590,0.9040,0.9920,0.9980,0.4715,0.7159,0.8024,0.6564,0.8636,0.9180,0.5276,0.4601,0.4449,0.4297,0.5611,0.5537,0.5467
|
39 |
+
ViT-L-14,datacomp_xl_s13b_b90k,427.62,175.33,0.7091,0.7338,0.9174,0.9554,0.8900,0.9860,0.9970,0.4573,0.7003,0.7916,0.6330,0.8414,0.9040,0.5471,0.4836,0.4666,0.4497,0.5912,0.5724,0.5546
|
40 |
+
ViT-g-14,laion2b_s12b_b42k,1366.68,581.15,0.7085,0.7642,0.9364,0.9624,0.9090,0.9910,0.9980,0.4802,0.7238,0.8079,0.6492,0.8530,0.9152,0.5178,0.4498,0.4427,0.4357,0.5589,0.5414,0.5249
|
41 |
+
coca_ViT-L-14,laion2b_s13b_b90k,638.45,214.52,0.7085,0.7428,0.9202,0.9542,0.8840,0.9930,0.9990,0.4565,0.7042,0.7921,0.6292,0.8370,0.9038,0.5403,0.4800,0.4739,0.4677,0.5704,0.5612,0.5525
|
42 |
+
ViT-B-16-quickgelu,metaclip_fullcc,149.62,41.09,0.7077,0.7072,0.9082,0.9454,0.8550,0.9740,0.9890,0.4134,0.6718,0.7696,0.5936,0.8058,0.8784,0.5787,0.5391,0.5221,0.5052,0.6064,0.5965,0.5871
|
43 |
+
ViT-H-14,laion2b_s32b_b79k,986.11,381.68,0.7057,0.7764,0.9418,0.9660,0.9070,0.9920,0.9970,0.4948,0.7338,0.8151,0.6592,0.8606,0.9188,0.4998,0.4289,0.4247,0.4206,0.5462,0.5234,0.5018
|
44 |
+
convnext_base_w,laion_aesthetic_s13b_b82k,179.39,49.38,0.7047,0.7306,0.9160,0.9532,0.8880,0.9820,0.9930,0.4355,0.6858,0.7808,0.6120,0.8338,0.8978,0.5461,0.4937,0.4764,0.4592,0.5829,0.5681,0.5540
|
45 |
+
ViT-L-14-CLIPA,datacomp1b,414.21,167.5,0.7044,0.7356,0.9206,0.9550,0.9020,0.9860,0.9930,0.4695,0.7084,0.7981,0.6512,0.8544,0.9150,0.5244,0.4495,0.4388,0.4282,0.5583,0.5514,0.5449
|
46 |
+
ViT-g-14,laion2b_s34b_b88k,1366.68,581.15,0.7038,0.7772,0.9416,0.9690,0.9140,0.9930,0.9960,0.4878,0.7334,0.8151,0.6638,0.8598,0.9184,0.4973,0.4207,0.4104,0.4001,0.5458,0.5247,0.5047
|
47 |
+
convnext_base_w_320,laion_aesthetic_s13b_b82k,179.39,71.94,0.7030,0.7402,0.9200,0.9524,0.9070,0.9830,0.9920,0.4387,0.6920,0.7824,0.6106,0.8312,0.8964,0.5385,0.4742,0.4610,0.4478,0.5695,0.5630,0.5567
|
48 |
+
ViT-L-14,laion2b_s32b_b82k,427.62,175.33,0.7010,0.7552,0.9286,0.9594,0.8950,0.9870,0.9940,0.4651,0.7108,0.7980,0.6336,0.8398,0.9074,0.5142,0.4448,0.4385,0.4322,0.5507,0.5381,0.5261
|
49 |
+
EVA01-g-14,laion400m_s11b_b41k,1136.44,547.36,0.7003,0.7264,0.9162,0.9510,0.8810,0.9830,0.9930,0.4406,0.6849,0.7732,0.6180,0.8328,0.8998,0.5363,0.4646,0.4640,0.4634,0.5665,0.5591,0.5521
|
50 |
+
convnext_base_w,laion2b_s13b_b82k,179.39,49.38,0.6976,0.7214,0.9152,0.9538,0.8770,0.9800,0.9930,0.4285,0.6805,0.7770,0.6058,0.8324,0.8954,0.5311,0.4926,0.4706,0.4488,0.5648,0.5501,0.5362
|
51 |
+
ViT-B-16-quickgelu,metaclip_400m,149.62,41.09,0.6962,0.6766,0.8960,0.9414,0.8570,0.9720,0.9870,0.3997,0.6536,0.7530,0.5648,0.7988,0.8714,0.5675,0.4984,0.5113,0.5242,0.5916,0.5853,0.5792
|
52 |
+
ViT-L-14,commonpool_xl_clip_s13b_b90k,427.62,175.33,0.6956,0.6886,0.8996,0.9424,0.8650,0.9740,0.9920,0.4273,0.6761,0.7700,0.6040,0.8190,0.8858,0.5460,0.4833,0.4728,0.4622,0.5821,0.5691,0.5566
|
53 |
+
convnext_base_w,laion2b_s13b_b82k_augreg,179.39,49.38,0.6922,0.7240,0.9144,0.9514,0.8730,0.9830,0.9920,0.4303,0.6786,0.7750,0.6140,0.8308,0.8964,0.5229,0.4504,0.4390,0.4276,0.5567,0.5494,0.5424
|
54 |
+
EVA02-B-16,merged2b_s8b_b131k,149.69,41.09,0.6904,0.7146,0.9114,0.9468,0.8600,0.9660,0.9880,0.4215,0.6686,0.7629,0.5874,0.8056,0.8812,0.5325,0.4686,0.4722,0.4758,0.5522,0.5515,0.5509
|
55 |
+
ViT-B-16,laion2b_s34b_b88k,149.62,41.09,0.6884,0.6984,0.9038,0.9456,0.8630,0.9790,0.9940,0.4231,0.6770,0.7706,0.5944,0.8178,0.8862,0.5220,0.4628,0.4601,0.4575,0.5549,0.5415,0.5288
|
56 |
+
ViT-B-32-quickgelu,metaclip_fullcc,151.28,14.78,0.6879,0.6510,0.8766,0.9274,0.8080,0.9530,0.9730,0.3806,0.6411,0.7430,0.5518,0.7890,0.8650,0.5729,0.5351,0.5238,0.5126,0.5983,0.5883,0.5788
|
57 |
+
convnext_base_w_320,laion_aesthetic_s13b_b82k_augreg,179.39,71.94,0.6870,0.7108,0.9032,0.9452,0.8910,0.9740,0.9900,0.4243,0.6722,0.7680,0.6040,0.8276,0.8912,0.5192,0.4599,0.4293,0.3989,0.5686,0.5475,0.5274
|
58 |
+
ViT-B-16,datacomp_xl_s13b_b90k,149.62,41.09,0.6867,0.6756,0.8834,0.9300,0.8510,0.9720,0.9840,0.4016,0.6595,0.7562,0.5744,0.8086,0.8830,0.5408,0.4966,0.4832,0.4699,0.5730,0.5590,0.5457
|
59 |
+
ViT-L-14,commonpool_xl_s13b_b90k,427.62,175.33,0.6817,0.6478,0.8656,0.9210,0.8200,0.9500,0.9740,0.3859,0.6310,0.7328,0.5446,0.7712,0.8484,0.5697,0.5115,0.5072,0.5030,0.5932,0.5894,0.5858
|
60 |
+
ViT-B-32,laion2b_e16,151.28,14.78,0.6777,0.6638,0.8830,0.9322,0.8440,0.9650,0.9840,0.3913,0.6467,0.7481,0.5624,0.7956,0.8708,0.5343,0.4800,0.4602,0.4404,0.5669,0.5577,0.5490
|
61 |
+
xlm-roberta-base-ViT-B-32,laion5b_s13b_b90k,366.12,105.87,0.6759,0.6448,0.8628,0.9168,0.8270,0.9640,0.9780,0.3778,0.6344,0.7348,0.5354,0.7798,0.8640,0.5519,0.4912,0.4827,0.4742,0.5834,0.5737,0.5645
|
62 |
+
ViT-B-32,laion2b_s34b_b79k,151.28,14.78,0.6750,0.6678,0.8838,0.9310,0.8410,0.9620,0.9830,0.3934,0.6543,0.7561,0.5632,0.7984,0.8712,0.5254,0.4603,0.4479,0.4356,0.5603,0.5498,0.5398
|
63 |
+
nllb-clip-large-siglip,v1,1195.5,1804.22,0.6738,0.7276,0.9224,0.9560,0.8330,0.9710,0.9910,0.4513,0.7084,0.8016,0.5386,0.7920,0.8712,0.4871,0.4138,0.4035,0.3933,0.5230,0.5135,0.5045
|
64 |
+
ViT-L-14,laion400m_e32,427.62,175.33,0.6734,0.7022,0.9094,0.9458,0.8760,0.9780,0.9950,0.4300,0.6803,0.7740,0.5974,0.8218,0.8938,0.4815,0.4006,0.3932,0.3858,0.5245,0.5094,0.4950
|
65 |
+
ViT-L-14,laion400m_e31,427.62,175.33,0.6728,0.7050,0.9068,0.9464,0.8720,0.9760,0.9950,0.4284,0.6797,0.7731,0.5974,0.8216,0.8934,0.4805,0.4022,0.3949,0.3877,0.5262,0.5075,0.4897
|
66 |
+
ViT-B-32-256,datacomp_s34b_b86k,151.29,17.46,0.6712,0.6492,0.8722,0.9216,0.8480,0.9670,0.9850,0.3993,0.6543,0.7524,0.5792,0.8056,0.8810,0.5165,0.4554,0.4300,0.4048,0.5428,0.5437,0.5446
|
67 |
+
ViT-B-32-quickgelu,metaclip_400m,151.28,14.78,0.6710,0.6234,0.8558,0.9152,0.7780,0.9350,0.9700,0.3591,0.6179,0.7209,0.5182,0.7640,0.8466,0.5657,0.5279,0.5097,0.4916,0.5907,0.5834,0.5764
|
68 |
+
coca_ViT-B-32,laion2b_s13b_b90k,253.56,33.34,0.6689,0.6336,0.8566,0.9144,0.8160,0.9570,0.9740,0.3618,0.6136,0.7176,0.5456,0.7788,0.8556,0.5499,0.4924,0.4698,0.4472,0.5778,0.5751,0.5725
|
69 |
+
ViT-B-16-plus-240,laion400m_e32,208.38,64.03,0.6676,0.6810,0.8888,0.9406,0.8650,0.9720,0.9880,0.4100,0.6619,0.7611,0.5858,0.8108,0.8806,0.4856,0.4146,0.4069,0.3993,0.5250,0.5104,0.4964
|
70 |
+
ViT-B-16-plus-240,laion400m_e31,208.38,64.03,0.6671,0.6790,0.8878,0.9392,0.8560,0.9730,0.9860,0.4075,0.6604,0.7594,0.5834,0.8082,0.8796,0.4882,0.4206,0.4093,0.3980,0.5290,0.5131,0.4981
|
71 |
+
ViT-L-14-336,openai,427.94,395.22,0.6666,0.6690,0.8900,0.9334,0.8770,0.9850,0.9940,0.3709,0.6162,0.7147,0.5794,0.8120,0.8792,0.5041,0.4201,0.4145,0.4089,0.5403,0.5324,0.5249
|
72 |
+
roberta-ViT-B-32,laion2b_s12b_b32k,212.72,105.87,0.6641,0.6434,0.8674,0.9250,0.8170,0.9490,0.9720,0.3744,0.6309,0.7355,0.5436,0.7774,0.8568,0.5224,0.4602,0.4583,0.4564,0.5531,0.5426,0.5325
|
73 |
+
ViT-B-16,laion400m_e32,149.62,41.09,0.6638,0.6566,0.8828,0.9300,0.8350,0.9680,0.9850,0.3835,0.6362,0.7386,0.5542,0.7964,0.8688,0.5033,0.4380,0.4326,0.4272,0.5294,0.5256,0.5219
|
74 |
+
ViT-B-16,laion400m_e31,149.62,41.09,0.6621,0.6572,0.8808,0.9314,0.8330,0.9660,0.9840,0.3834,0.6366,0.7383,0.5514,0.7960,0.8688,0.5024,0.4255,0.4225,0.4195,0.5323,0.5276,0.5232
|
75 |
+
ViT-B-32,datacomp_xl_s13b_b90k,151.28,14.78,0.6594,0.6108,0.8492,0.9092,0.7900,0.9390,0.9620,0.3714,0.6235,0.7268,0.5354,0.7778,0.8604,0.5317,0.4772,0.4594,0.4418,0.5589,0.5544,0.5502
|
76 |
+
ViT-L-14,openai,427.62,175.33,0.6588,0.6496,0.8724,0.9204,0.8520,0.9740,0.9920,0.3651,0.6106,0.7113,0.5634,0.7938,0.8660,0.5047,0.4307,0.4136,0.3966,0.5457,0.5334,0.5218
|
77 |
+
RN50x64,openai,623.26,552.65,0.6563,0.6898,0.8990,0.9432,0.8690,0.9820,0.9920,0.3524,0.5992,0.7033,0.5842,0.8046,0.8786,0.4779,0.4149,0.3936,0.3725,0.5280,0.5044,0.4820
|
78 |
+
nllb-clip-base-siglip,v1,507.47,472.91,0.6545,0.6922,0.9002,0.9424,0.7920,0.9500,0.9820,0.4315,0.6913,0.7852,0.5116,0.7776,0.8628,0.4715,0.3776,0.3815,0.3854,0.5113,0.4999,0.4890
|
79 |
+
convnext_base,laion400m_s13b_b51k,151.52,36.67,0.6541,0.6496,0.8814,0.9304,0.8380,0.9710,0.9910,0.3760,0.6315,0.7337,0.5470,0.7990,0.8676,0.4811,0.4146,0.4045,0.3944,0.5145,0.5052,0.4964
|
80 |
+
RN50x16,openai,290.98,162.69,0.6518,0.6534,0.8710,0.9178,0.8570,0.9700,0.9880,0.3541,0.6002,0.7014,0.5536,0.7876,0.8670,0.4957,0.4311,0.3946,0.3584,0.5419,0.5275,0.5138
|
81 |
+
ViT-B-16,openai,149.62,41.09,0.6507,0.6216,0.8572,0.9192,0.8220,0.9660,0.9900,0.3309,0.5842,0.6899,0.5242,0.7670,0.8462,0.5171,0.4487,0.4316,0.4146,0.5550,0.5441,0.5337
|
82 |
+
ViT-B-32,laion400m_e31,151.28,14.78,0.6412,0.5970,0.8398,0.9036,0.7810,0.9380,0.9660,0.3420,0.6001,0.7059,0.5234,0.7634,0.8432,0.5059,0.4283,0.4262,0.4242,0.5477,0.5310,0.5152
|
83 |
+
ViT-B-32,laion400m_e32,151.28,14.78,0.6412,0.5962,0.8396,0.9020,0.7770,0.9410,0.9680,0.3431,0.6000,0.7054,0.5244,0.7642,0.8454,0.5055,0.4272,0.4265,0.4258,0.5476,0.5304,0.5139
|
84 |
+
ViT-B-32-quickgelu,laion400m_e32,151.28,14.78,0.6394,0.6170,0.8546,0.9086,0.7880,0.9400,0.9700,0.3533,0.6089,0.7165,0.5258,0.7672,0.8464,0.4884,0.4097,0.4072,0.4047,0.5280,0.5140,0.5006
|
85 |
+
ViT-B-32-quickgelu,laion400m_e31,151.28,14.78,0.6389,0.6174,0.8548,0.9078,0.7870,0.9400,0.9730,0.3535,0.6100,0.7177,0.5254,0.7702,0.8490,0.4860,0.4054,0.4034,0.4015,0.5273,0.5120,0.4974
|
86 |
+
RN50x4,openai,178.3,51.82,0.6373,0.6258,0.8476,0.9018,0.8210,0.9630,0.9830,0.3339,0.5812,0.6830,0.5296,0.7662,0.8490,0.4893,0.4104,0.3912,0.3720,0.5354,0.5202,0.5058
|
87 |
+
nllb-clip-large,v1,1399.22,1468.46,0.6346,0.6090,0.8576,0.9202,0.7160,0.9220,0.9600,0.3617,0.6250,0.7304,0.4392,0.7086,0.8036,0.5097,0.4334,0.4299,0.4265,0.5395,0.5348,0.5304
|
88 |
+
ViT-B-32,openai,151.28,14.78,0.6321,0.5878,0.8356,0.9002,0.7890,0.9490,0.9820,0.3044,0.5594,0.6687,0.5012,0.7500,0.8352,0.5054,0.4454,0.4125,0.3798,0.5492,0.5347,0.5210
|
89 |
+
ViT-B-32-quickgelu,openai,151.28,14.78,0.6321,0.5878,0.8356,0.9002,0.7890,0.9490,0.9820,0.3044,0.5594,0.6687,0.5012,0.7500,0.8352,0.5054,0.4454,0.4125,0.3798,0.5492,0.5347,0.5210
|
90 |
+
ViT-B-16,commonpool_l_laion_s1b_b8k,149.62,41.09,0.6274,0.5664,0.8114,0.8836,0.7230,0.9100,0.9510,0.3195,0.5756,0.6848,0.4652,0.7170,0.8088,0.5227,0.4558,0.4476,0.4395,0.5563,0.5463,0.5369
|
91 |
+
ViT-B-16,datacomp_l_s1b_b8k,149.62,41.09,0.6267,0.5536,0.8090,0.8768,0.7320,0.9170,0.9480,0.3218,0.5747,0.6858,0.4872,0.7292,0.8246,0.5113,0.4613,0.4465,0.4318,0.5404,0.5317,0.5235
|
92 |
+
RN101,openai,119.69,25.5,0.6249,0.5804,0.8228,0.8852,0.7900,0.9490,0.9740,0.3069,0.5546,0.6603,0.4982,0.7448,0.8250,0.4920,0.4347,0.4130,0.3913,0.5272,0.5170,0.5072
|
93 |
+
RN101-quickgelu,openai,119.69,25.5,0.6249,0.5804,0.8228,0.8852,0.7900,0.9490,0.9740,0.3069,0.5546,0.6603,0.4982,0.7448,0.8250,0.4920,0.4347,0.4130,0.3913,0.5272,0.5170,0.5072
|
94 |
+
ViT-B-16,commonpool_l_image_s1b_b8k,149.62,41.09,0.6164,0.5162,0.7908,0.8700,0.6890,0.8830,0.9270,0.2907,0.5449,0.6628,0.4338,0.6882,0.7902,0.5339,0.4932,0.4787,0.4643,0.5609,0.5513,0.5421
|
95 |
+
ViT-B-16,commonpool_l_basic_s1b_b8k,149.62,41.09,0.6132,0.5250,0.7880,0.8670,0.6780,0.8780,0.9330,0.2862,0.5411,0.6535,0.4304,0.6844,0.7826,0.5285,0.4872,0.4812,0.4753,0.5517,0.5434,0.5355
|
96 |
+
ViT-B-16,commonpool_l_text_s1b_b8k,149.62,41.09,0.6126,0.5336,0.7954,0.8650,0.6820,0.8880,0.9320,0.3020,0.5538,0.6674,0.4440,0.7012,0.7906,0.5147,0.4815,0.4570,0.4326,0.5394,0.5329,0.5266
|
97 |
+
RN50,openai,102.01,18.18,0.6124,0.5736,0.8318,0.9004,0.8000,0.9500,0.9790,0.2854,0.5291,0.6459,0.4884,0.7278,0.8212,0.4766,0.3682,0.3622,0.3562,0.5318,0.5127,0.4945
|
98 |
+
RN50-quickgelu,openai,102.01,18.18,0.6124,0.5736,0.8318,0.9004,0.8000,0.9500,0.9790,0.2854,0.5291,0.6459,0.4884,0.7278,0.8212,0.4766,0.3682,0.3622,0.3562,0.5318,0.5127,0.4945
|
99 |
+
ViT-B-16,commonpool_l_clip_s1b_b8k,149.62,41.09,0.6054,0.5102,0.7764,0.8564,0.6810,0.8940,0.9430,0.2880,0.5368,0.6492,0.4436,0.6938,0.7954,0.5133,0.4493,0.4359,0.4225,0.5472,0.5377,0.5287
|
100 |
+
ViT-B-16,commonpool_l_s1b_b8k,149.62,41.09,0.5635,0.4186,0.7040,0.7952,0.5710,0.8260,0.8940,0.2224,0.4565,0.5694,0.3486,0.6026,0.7096,0.5272,0.5029,0.4777,0.4526,0.5532,0.5428,0.5329
|
101 |
+
nllb-clip-base,v1,501.89,369.6,0.5562,0.4740,0.7554,0.8350,0.5660,0.8210,0.8920,0.2649,0.5150,0.6321,0.3514,0.6100,0.7204,0.4708,0.3890,0.3904,0.3919,0.5085,0.4962,0.4844
|
102 |
+
RN50-quickgelu,cc12m,102.01,18.18,0.5531,0.4696,0.7432,0.8284,0.6050,0.8600,0.9100,0.2412,0.4868,0.6024,0.3370,0.6136,0.7180,0.4644,0.4136,0.3876,0.3617,0.4970,0.4887,0.4807
|
103 |
+
RN50,cc12m,102.01,18.18,0.5475,0.4594,0.7336,0.8234,0.6120,0.8540,0.9070,0.2367,0.4787,0.5947,0.3412,0.6072,0.7154,0.4553,0.4081,0.3828,0.3577,0.4918,0.4782,0.4653
|
104 |
+
RN101-quickgelu,yfcc15m,119.69,25.5,0.4846,0.3474,0.6126,0.7192,0.5370,0.8030,0.8900,0.1779,0.3908,0.5043,0.2872,0.5456,0.6544,0.4226,0.3363,0.3147,0.2933,0.4723,0.4566,0.4416
|
105 |
+
RN101,yfcc15m,119.69,25.5,0.4796,0.3402,0.6034,0.7124,0.5450,0.8010,0.8810,0.1772,0.3895,0.5040,0.3036,0.5536,0.6652,0.4088,0.3252,0.2987,0.2724,0.4643,0.4435,0.4239
|
106 |
+
RN50-quickgelu,yfcc15m,102.01,18.18,0.4716,0.3180,0.5876,0.6968,0.4910,0.7840,0.8870,0.1645,0.3758,0.4891,0.2720,0.5260,0.6442,0.4209,0.3393,0.3130,0.2867,0.4634,0.4549,0.4469
|
107 |
+
RN50,yfcc15m,102.01,18.18,0.4709,0.3092,0.5884,0.6952,0.5120,0.8040,0.8850,0.1645,0.3747,0.4894,0.2826,0.5350,0.6546,0.4106,0.3249,0.3023,0.2799,0.4536,0.4448,0.4364
|
108 |
+
ViT-B-32,commonpool_m_image_s128m_b4k,151.28,14.78,0.3976,0.1896,0.4156,0.5312,0.2710,0.5300,0.6360,0.1054,0.2710,0.3703,0.1654,0.3758,0.4898,0.4804,0.4206,0.4026,0.3847,0.5261,0.5050,0.4849
|
109 |
+
ViT-B-32,commonpool_m_clip_s128m_b4k,151.28,14.78,0.3923,0.1890,0.4106,0.5238,0.3040,0.5540,0.6590,0.1127,0.2718,0.3719,0.1826,0.3960,0.5070,0.4519,0.3685,0.3581,0.3478,0.4947,0.4815,0.4689
|
110 |
+
ViT-B-32,commonpool_m_text_s128m_b4k,151.28,14.78,0.3917,0.1872,0.4200,0.5344,0.2900,0.5350,0.6590,0.1086,0.2712,0.3731,0.1818,0.3920,0.5008,0.4540,0.3863,0.3618,0.3373,0.5038,0.4831,0.4635
|
111 |
+
ViT-B-32,commonpool_m_basic_s128m_b4k,151.28,14.78,0.3907,0.1874,0.4172,0.5350,0.2500,0.5190,0.6530,0.1007,0.2559,0.3578,0.1658,0.3528,0.4654,0.4727,0.4245,0.4015,0.3786,0.5182,0.4951,0.4732
|
112 |
+
ViT-B-32,datacomp_m_s128m_b4k,151.28,14.78,0.3756,0.1812,0.4036,0.5188,0.2670,0.5200,0.6320,0.1100,0.2699,0.3750,0.1714,0.3770,0.4886,0.4337,0.3570,0.3287,0.3005,0.4914,0.4667,0.4433
|
113 |
+
ViT-B-32,commonpool_m_laion_s128m_b4k,151.28,14.78,0.3741,0.1792,0.3964,0.5120,0.2760,0.5330,0.6220,0.1052,0.2606,0.3637,0.1654,0.3592,0.4660,0.4409,0.3585,0.3348,0.3113,0.4894,0.4743,0.4600
|
114 |
+
ViT-B-32,commonpool_m_s128m_b4k,151.28,14.78,0.3449,0.1268,0.3164,0.4252,0.1930,0.4140,0.5360,0.0701,0.1941,0.2803,0.1246,0.3002,0.3996,0.4759,0.4354,0.3983,0.3614,0.5273,0.5003,0.4747
|
115 |
+
ViT-B-32,commonpool_s_text_s13m_b4k,151.28,14.78,0.1922,0.0278,0.0950,0.1462,0.0570,0.1300,0.2000,0.0170,0.0551,0.0906,0.0258,0.0814,0.1234,0.3964,0.3312,0.3120,0.2929,0.4487,0.4231,0.3987
|
116 |
+
ViT-B-32,commonpool_s_basic_s13m_b4k,151.28,14.78,0.1874,0.0256,0.0924,0.1466,0.0400,0.1220,0.2000,0.0144,0.0487,0.0782,0.0192,0.0646,0.1024,0.4008,0.3164,0.3033,0.2903,0.4526,0.4315,0.4115
|
117 |
+
ViT-B-32,commonpool_s_clip_s13m_b4k,151.28,14.78,0.1857,0.0286,0.0938,0.1454,0.0400,0.1370,0.2070,0.0157,0.0515,0.0847,0.0292,0.0794,0.1206,0.3801,0.3282,0.2994,0.2708,0.4299,0.4055,0.3823
|
118 |
+
ViT-B-32,commonpool_s_s13m_b4k,151.28,14.78,0.1789,0.0214,0.0690,0.1138,0.0300,0.1080,0.1760,0.0088,0.0351,0.0601,0.0176,0.0568,0.0902,0.3979,0.3338,0.3126,0.2914,0.4484,0.4249,0.4025
|
119 |
+
ViT-B-32,commonpool_s_image_s13m_b4k,151.28,14.78,0.1536,0.0142,0.0594,0.0960,0.0200,0.0740,0.1160,0.0109,0.0385,0.0641,0.0150,0.0500,0.0842,0.3552,0.2669,0.2525,0.2382,0.4123,0.3875,0.3640
|
120 |
+
ViT-B-32,datacomp_s_s13m_b4k,151.28,14.78,0.1536,0.0142,0.0594,0.0960,0.0200,0.0740,0.1160,0.0109,0.0385,0.0641,0.0150,0.0500,0.0842,0.3552,0.2669,0.2525,0.2382,0.4123,0.3875,0.3640
|
121 |
+
ViT-B-32,commonpool_s_laion_s13m_b4k,151.28,14.78,0.1527,0.0176,0.0678,0.1056,0.0270,0.0870,0.1440,0.0102,0.0363,0.0607,0.0140,0.0442,0.0770,0.3463,0.2607,0.2410,0.2215,0.4049,0.3795,0.3554
|
122 |
+
coca_ViT-B-32,mscoco_finetuned_laion2b_s13b_b90k,253.56,33.34,0.1306,0.0074,0.0214,0.0436,0.0110,0.0490,0.0990,0.0033,0.0137,0.0249,0.0088,0.0338,0.0552,0.3299,0.2484,0.2329,0.2175,0.3873,0.3604,0.3348
|
open_clip/docs/script_examples/clipa/vit_b16/i50_t16_finetune.sh
ADDED
@@ -0,0 +1,27 @@
torchrun --nproc_per_node 8 -m open_clip_train.main \
    --save-frequency 1 \
    --save-most-recent \
    --zeroshot-frequency 1 \
    --train-data '/path/to/laion-400m' \
    --dataset-type webdataset \
    --lr "2.56e-5" \
    --beta1 0.9 \
    --beta2 0.95 \
    --warmup 3072 \
    --wd 0.2 \
    --batch-size 1024 \
    --aug-cfg scale='(0.4, 1.0)' \
    --epochs 1 \
    --train-num-samples 131072000 \
    --workers 6 \
    --model ViT-B-16-CL16 \
    --pretrained '/path/to/ckpt' \
    --precision 'amp_bf16' \
    --ddp-static-graph \
    --local-loss \
    --gather-with-grad \
    --grad-checkpointing \
    --log-every-n-steps 256 \
    --seed 0 \
    --logs ./logs/ \
    --imagenet-val '/path/to/imagenet/val'
open_clip/docs/script_examples/clipa/vit_b16/i50_t16_pretrain.sh
ADDED
@@ -0,0 +1,26 @@
torchrun --nproc_per_node 8 -m open_clip_train.main \
    --save-frequency 1 \
    --save-most-recent \
    --zeroshot-frequency 1 \
    --train-data '/path/to/laion-400m' \
    --dataset-type webdataset \
    --lr "2.048e-3" \
    --beta1 0.9 \
    --beta2 0.95 \
    --warmup 782 \
    --wd 0.2 \
    --batch-size 8192 \
    --aug-cfg scale='(0.4, 1.0)' \
    --epochs 6 \
    --workers 6 \
    --model ViT-B-16-CL16 \
    --precision 'amp_bf16' \
    --ddp-static-graph \
    --local-loss \
    --gather-with-grad \
    --force-image-size 112 \
    --grad-checkpointing \
    --log-every-n-steps 32 \
    --seed 0 \
    --logs ./logs/ \
    --imagenet-val '/path/to/imagenet/val'
open_clip/docs/script_examples/clipa/vit_l16/i17_t16_finetune.sh
ADDED
@@ -0,0 +1,27 @@
torchrun --nproc_per_node 8 -m open_clip_train.main \
    --save-frequency 1 \
    --save-most-recent \
    --zeroshot-frequency 1 \
    --train-data '/path/to/laion-400m' \
    --dataset-type webdataset \
    --lr "2.24e-5" \
    --beta1 0.9 \
    --beta2 0.95 \
    --warmup 3571 \
    --wd 0.2 \
    --batch-size 896 \
    --aug-cfg scale='(0.4, 1.0)' color_jitter='(0.32, 0.32, 0.32, 0.08)' color_jitter_prob=0.8 gray_scale_prob=0.2 \
    --epochs 1 \
    --train-num-samples 131072000 \
    --workers 6 \
    --model ViT-L-16-CL16-GAP \
    --pretrained '/path/to/ckpt' \
    --precision 'amp_bf16' \
    --ddp-static-graph \
    --local-loss \
    --gather-with-grad \
    --grad-checkpointing \
    --log-every-n-steps 293 \
    --seed 0 \
    --logs ./logs/ \
    --imagenet-val '/path/to/imagenet/val'
open_clip/docs/script_examples/clipa/vit_l16/i17_t16_pretrain.sh
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torchrun --nproc_per_node 8 -m open_clip_train.main \
|
2 |
+
--save-frequency 1 \
|
3 |
+
--save-most-recent \
|
4 |
+
--zeroshot-frequency 1 \
|
5 |
+
--train-data '/path/to/laion-400m' \
|
6 |
+
--dataset-type webdataset \
|
7 |
+
--lr "1.024e-3" \
|
8 |
+
--beta1 0.9 \
|
9 |
+
--beta2 0.95 \
|
10 |
+
--warmup 1563 \
|
11 |
+
--wd 0.2 \
|
12 |
+
--batch-size 4096 \
|
13 |
+
--aug-cfg scale='(0.4, 1.0)' color_jitter='(0.32, 0.32, 0.32, 0.08)' color_jitter_prob=0.8 gray_scale_prob=0.2 \
|
14 |
+
--epochs 6 \
|
15 |
+
--workers 6 \
|
16 |
+
--model ViT-L-16-CL16-GAP \
|
17 |
+
--precision 'amp_bf16' \
|
18 |
+
--ddp-static-graph \
|
19 |
+
--local-loss \
|
20 |
+
--gather-with-grad \
|
21 |
+
--force-image-size 64 \
|
22 |
+
--grad-checkpointing \
|
23 |
+
--log-every-n-steps 64 \
|
24 |
+
--seed 0 \
|
25 |
+
--logs ./logs/ \
|
26 |
+
--imagenet-val '/path/to/imagenet/val'
|
open_clip/docs/script_examples/clipa/vit_l16/i37_t8_finetune.sh
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torchrun --nproc_per_node 8 -m open_clip_train.main \
|
2 |
+
--save-frequency 1 \
|
3 |
+
--save-most-recent \
|
4 |
+
--zeroshot-frequency 1 \
|
5 |
+
--train-data '/path/to/laion-400m' \
|
6 |
+
--dataset-type webdataset \
|
7 |
+
--lr "2.24e-5" \
|
8 |
+
--beta1 0.9 \
|
9 |
+
--beta2 0.95 \
|
10 |
+
--warmup 3571 \
|
11 |
+
--wd 0.2 \
|
12 |
+
--batch-size 896 \
|
13 |
+
--aug-cfg scale='(0.4, 1.0)' color_jitter='(0.32, 0.32, 0.32, 0.08)' color_jitter_prob=0.8 gray_scale_prob=0.2 \
|
14 |
+
--epochs 1 \
|
15 |
+
--train-num-samples 131072000 \
|
16 |
+
--workers 6 \
|
17 |
+
--model ViT-L-16-CL32-GAP \
|
18 |
+
--pretrained '/path/to/ckpt' \
|
19 |
+
--precision 'amp_bf16' \
|
20 |
+
--ddp-static-graph \
|
21 |
+
--local-loss \
|
22 |
+
--gather-with-grad \
|
23 |
+
--grad-checkpointing \
|
24 |
+
--log-every-n-steps 293 \
|
25 |
+
--seed 0 \
|
26 |
+
--logs ./logs/ \
|
27 |
+
--imagenet-val '/path/to/imagenet/val'
|
open_clip/docs/script_examples/clipa/vit_l16/i37_t8_pretrain.sh
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torchrun --nproc_per_node 8 -m open_clip_train.main \
|
2 |
+
--save-frequency 1 \
|
3 |
+
--save-most-recent \
|
4 |
+
--zeroshot-frequency 1 \
|
5 |
+
--train-data '/path/to/laion-400m' \
|
6 |
+
--dataset-type webdataset \
|
7 |
+
--lr "1.024e-3" \
|
8 |
+
--beta1 0.9 \
|
9 |
+
--beta2 0.95 \
|
10 |
+
--warmup 1563 \
|
11 |
+
--wd 0.2 \
|
12 |
+
--batch-size 4096 \
|
13 |
+
--aug-cfg scale='(0.4, 1.0)' color_jitter='(0.32, 0.32, 0.32, 0.08)' color_jitter_prob=0.8 gray_scale_prob=0.2 \
|
14 |
+
--epochs 6 \
|
15 |
+
--workers 6 \
|
16 |
+
--model ViT-L-16-CL8-Syntax-GAP \
|
17 |
+
--precision 'amp_bf16' \
|
18 |
+
--ddp-static-graph \
|
19 |
+
--local-loss \
|
20 |
+
--gather-with-grad \
|
21 |
+
--force-image-size 96 \
|
22 |
+
--grad-checkpointing \
|
23 |
+
--log-every-n-steps 64 \
|
24 |
+
--seed 0 \
|
25 |
+
--logs ./logs/ \
|
26 |
+
--imagenet-val '/path/to/imagenet/val'
|
open_clip/docs/script_examples/clipav2/vit_h14/i257_t32_finetunex4.sh
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# have not been tested. use it at your own discretion
|
2 |
+
# the original experiment was run on tpu v3-256.
|
3 |
+
# this example script assumes 8 gpus, each with huge memory. Tune batchsize, warmup, and lr accordingly if you have different machine setups.
|
4 |
+
torchrun --nproc_per_node 8 -m open_clip_train.main \
|
5 |
+
--save-frequency 1 \
|
6 |
+
--save-most-recent \
|
7 |
+
--zeroshot-frequency 1 \
|
8 |
+
--train-data '/path/to/laion2b_or_datacomp1b' \
|
9 |
+
--train-num-samples 131072000 \
|
10 |
+
--dataset-type webdataset \
|
11 |
+
--lr "5.12e-5" \
|
12 |
+
--beta1 0.9 \
|
13 |
+
--beta2 0.95 \
|
14 |
+
--warmup 800 \
|
15 |
+
--wd 0.2 \
|
16 |
+
--batch-size 4096 \
|
17 |
+
--aug-cfg scale='(0.4, 1.0)' color_jitter='(0.32, 0.32, 0.32, 0.08)' color_jitter_prob=0.8 gray_scale_prob=0.2 \
|
18 |
+
--epochs 4 \
|
19 |
+
--workers 6 \
|
20 |
+
--model ViT-H-14-CL32-GAP \
|
21 |
+
--pretrained '/path/to/pretrain84_ckpt' \
|
22 |
+
--precision 'amp_bf16' \
|
23 |
+
--ddp-static-graph \
|
24 |
+
--local-loss \
|
25 |
+
--gather-with-grad \
|
26 |
+
--force-image-size 224 \
|
27 |
+
--force-patch-dropout 0.3 \
|
28 |
+
--grad-checkpointing \
|
29 |
+
--log-every-n-steps 64 \
|
30 |
+
--seed 0 \
|
31 |
+
--logs ./logs/ \
|
32 |
+
--imagenet-val '/path/to/imagenet/val'
|
open_clip/docs/script_examples/clipav2/vit_h14/i50_t8_pretrain.sh
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# have not been tested. use it at your own discretion
|
2 |
+
# the original experiment was run on tpu v3-256.
|
3 |
+
# this example script assumes 8 gpus, each with huge memory. Tune batchsize, warmup, and lr accordingly if you have different machine setups.
|
4 |
+
torchrun --nproc_per_node 8 -m open_clip_train.main \
|
5 |
+
--save-frequency 1 \
|
6 |
+
--save-most-recent \
|
7 |
+
--zeroshot-frequency 1 \
|
8 |
+
--train-data '/path/to/laion2b_or_datacomp1b' \
|
9 |
+
--train-num-samples 4e8 \
|
10 |
+
--dataset-type webdataset \
|
11 |
+
--lr "2.048e-3" \
|
12 |
+
--beta1 0.9 \
|
13 |
+
--beta2 0.95 \
|
14 |
+
--warmup 3200 \
|
15 |
+
--wd 0.2 \
|
16 |
+
--batch-size 8192 \
|
17 |
+
--aug-cfg scale='(0.4, 1.0)' color_jitter='(0.32, 0.32, 0.32, 0.08)' color_jitter_prob=0.8 gray_scale_prob=0.2 \
|
18 |
+
--epochs 32 \
|
19 |
+
--workers 6 \
|
20 |
+
--model ViT-H-14-CL8-Syntax-GAP \
|
21 |
+
--precision 'amp_bf16' \
|
22 |
+
--ddp-static-graph \
|
23 |
+
--local-loss \
|
24 |
+
--gather-with-grad \
|
25 |
+
--force-image-size 84 \
|
26 |
+
--grad-checkpointing \
|
27 |
+
--log-every-n-steps 32 \
|
28 |
+
--seed 0 \
|
29 |
+
--logs ./logs/ \
|
30 |
+
--imagenet-val '/path/to/imagenet/val'
|
open_clip/docs/script_examples/clipav2/vit_h14/i577_t32_finetunex1.sh
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# have not been tested. use it at your own discretion
|
2 |
+
# the original experiment was run on tpu v3-256.
|
3 |
+
# this example script assumes 8 gpus, each with huge memory. Tune batchsize, warmup, and lr accordingly if you have different machine setups.
|
4 |
+
torchrun --nproc_per_node 8 -m open_clip_train.main \
|
5 |
+
--save-frequency 1 \
|
6 |
+
--save-most-recent \
|
7 |
+
--zeroshot-frequency 1 \
|
8 |
+
--train-data '/path/to/laion2b_or_datacomp1b' \
|
9 |
+
--train-num-samples 131072000 \
|
10 |
+
--dataset-type webdataset \
|
11 |
+
--lr "6.4e-6" \
|
12 |
+
--beta1 0.9 \
|
13 |
+
--beta2 0.95 \
|
14 |
+
--warmup 1600 \
|
15 |
+
--wd 0.2 \
|
16 |
+
--batch-size 2048 \
|
17 |
+
--aug-cfg scale='(0.4, 1.0)' color_jitter='(0.32, 0.32, 0.32, 0.08)' color_jitter_prob=0.8 gray_scale_prob=0.2 \
|
18 |
+
--epochs 1 \
|
19 |
+
--workers 6 \
|
20 |
+
--model ViT-H-14-CL32-GAP \
|
21 |
+
--pretrained '/path/to/finetune224_ckpt' \
|
22 |
+
--precision 'amp_bf16' \
|
23 |
+
--ddp-static-graph \
|
24 |
+
--local-loss \
|
25 |
+
--gather-with-grad \
|
26 |
+
--force-image-size 336 \
|
27 |
+
--force-patch-dropout 0.4 \
|
28 |
+
--grad-checkpointing \
|
29 |
+
--log-every-n-steps 64 \
|
30 |
+
--seed 0 \
|
31 |
+
--logs ./logs/ \
|
32 |
+
--imagenet-val '/path/to/imagenet/val'
|
open_clip/docs/script_examples/stability_example.sh
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
#SBATCH --partition=g40423
|
3 |
+
#SBATCH --job-name=testopenclip
|
4 |
+
#SBATCH --nodes 30
|
5 |
+
#SBATCH --ntasks-per-node=8
|
6 |
+
#SBATCH --cpus-per-task=12
|
7 |
+
#SBATCH --output=%x_%j.out
|
8 |
+
#SBATCH --comment=laion
|
9 |
+
#SBATCH --open-mode=append
|
10 |
+
#SBATCH --exclusive
|
11 |
+
|
12 |
+
module load openmpi
|
13 |
+
module load cuda/11.7
|
14 |
+
|
15 |
+
export MASTER_ADDR=`hostname`
|
16 |
+
export MASTER_PORT=12802
|
17 |
+
export NCCL_PROTO=simple
|
18 |
+
export FI_EFA_FORK_SAFE=1
|
19 |
+
export FI_LOG_LEVEL=1
|
20 |
+
export FI_EFA_USE_DEVICE_RDMA=1
|
21 |
+
export NCCL_DEBUG=info
|
22 |
+
|
23 |
+
export PYTHONFAULTHANDLER=1
|
24 |
+
|
25 |
+
export CUDA_LAUNCH_BLOCKING=0
|
26 |
+
export OMPI_MCA_mtl_base_verbose=1
|
27 |
+
export FI_EFA_ENABLE_SHM_TRANSFER=0
|
28 |
+
export FI_PROVIDER=efa
|
29 |
+
export FI_EFA_TX_MIN_CREDITS=64
|
30 |
+
export NCCL_TREE_THRESHOLD=0
|
31 |
+
|
32 |
+
cd /admin/home-mitchellw/open_clip/src
|
33 |
+
export PYTHONPATH="$PYTHONPATH:/admin/home-mitchellw/open_clip/src"
|
34 |
+
|
35 |
+
EXP_NAME="test-B-32-laion5b-lr1e-3-bs90k"
|
36 |
+
|
37 |
+
srun --comment laion --cpu_bind=v --accel-bind=gn python -m open_clip_train.main \
|
38 |
+
--save-frequency 1 \
|
39 |
+
--train-data="pipe:aws s3 cp s3://s-datasets/laion5b/{laion2B-data/{000000..231349}.tar,laion2B-multi-data/{000000..226687}.tar,laion1B-nolang-data/{000000..127231}.tar} -" \
|
40 |
+
--train-num-samples 135646078 \
|
41 |
+
--dataset-type webdataset \
|
42 |
+
--dataset-resampled \
|
43 |
+
--warmup 2000 \
|
44 |
+
--batch-size=375 \
|
45 |
+
--epochs=97 \
|
46 |
+
--lr 1e-3 \
|
47 |
+
--workers=8 \
|
48 |
+
--report-to wandb \
|
49 |
+
--name ${EXP_NAME} \
|
50 |
+
--logs /scratch/logs/ \
|
51 |
+
--model ViT-B-32 \
|
52 |
+
--seed 0 \
|
53 |
+
--ddp-static-graph \
|
54 |
+
--local-loss \
|
55 |
+
--gather-with-grad \
|
56 |
+
--grad-checkpointing \
|
57 |
+
--precision amp_bfloat16 \
|
58 |
+
--wandb-project-name open_clip6 \
|
59 |
+
--resume "latest" \
|
60 |
+
--remote-sync s3://s-laion/mitchellw/logs
|
open_clip/scripts/clipav1_vit_l16_i37_t8.sh
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# eval on a single gpu
|
2 |
+
CUDA_VISIBLE_DEVICES=2 TORCH_CUDNN_V8_API_ENABLED=1 TFDS_PREFETCH_SIZE=8192 python3 -m open_clip_train.main \
|
3 |
+
--model ViT-L-16-CL32-GAP \
|
4 |
+
--pretrained "/path/to/clipa_vit_l16_i37_t8.pt" \
|
5 |
+
--seed 0 \
|
6 |
+
--imagenet-val '/path/to/ImageNet/val'
|
open_clip/scripts/clipav2_vit_h14_i84_224_336_cl32_gap_datacomp1b.sh
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CUDA_VISIBLE_DEVICES=1 python3 -m open_clip_train.main \
|
2 |
+
--model ViT-H-14-CL32-GAP-BigVision \
|
3 |
+
--pretrained "/path/to/vit_h14_i84_224_336_cl32_gap_datacomp1b.pt" \
|
4 |
+
--force-image-size 336 \
|
5 |
+
--square-resize-only \
|
6 |
+
--interpolation 'bilinear' \
|
7 |
+
--image-mean 0.485 0.456 0.406 \
|
8 |
+
--image-std 0.229 0.224 0.225 \
|
9 |
+
--seed 0 \
|
10 |
+
--imagenet-val '/path/to/ImageNet/val'
|
open_clip/scripts/h14_84_8_pretrain.sh
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 64k batchsize for 2.048e-3 lr
|
2 |
+
TORCH_CUDNN_V8_API_ENABLED=1 torchrun --nproc_per_node 8 -m open_clip_train.main \
|
3 |
+
--save-frequency 1 \
|
4 |
+
--save-most-recent \
|
5 |
+
--zeroshot-frequency 1 \
|
6 |
+
--train-data '/path/to/laion' \
|
7 |
+
--dataset-type webdataset \
|
8 |
+
--lr "2.048e-3" \
|
9 |
+
--beta1 0.9 \
|
10 |
+
--beta2 0.95 \
|
11 |
+
--warmup 782 \
|
12 |
+
--wd 0.2 \
|
13 |
+
--batch-size 4096 \
|
14 |
+
--aug-cfg scale='(0.4, 1.0)' color_jitter='(0.32, 0.32, 0.32, 0.08)' color_jitter_prob=0.8 gray_scale_prob=0.2 \
|
15 |
+
--epochs=7 \
|
16 |
+
--workers=6 \
|
17 |
+
--model ViT-H-14-CL8-SyntaxMask-GAP \
|
18 |
+
--precision 'amp_bf16' \
|
19 |
+
--local-loss \
|
20 |
+
--gather-with-grad \
|
21 |
+
--force-image-size 84 \
|
22 |
+
--grad-checkpointing \
|
23 |
+
--log-every-n-steps 32 \
|
24 |
+
--seed 0 \
|
25 |
+
--logs ./logs/ \
|
26 |
+
--imagenet-val '/path/to/ImageNet/val' \
|
27 |
+
--name 'name' \
|
28 |
+
--report-to "wandb" \
|
29 |
+
--wandb-project-name "project_name"
|
30 |
+
|
31 |
+
|
open_clip/src/open_clip/constants.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
|
2 |
+
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
|
3 |
+
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
4 |
+
IMAGENET_STD = (0.229, 0.224, 0.225)
|
5 |
+
INCEPTION_MEAN = (0.5, 0.5, 0.5)
|
6 |
+
INCEPTION_STD = (0.5, 0.5, 0.5)
|
7 |
+
|
8 |
+
# Default name for a weights file hosted on the Huggingface Hub.
|
9 |
+
HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl
|
10 |
+
HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version
|
11 |
+
HF_CONFIG_NAME = 'open_clip_config.json'
|
open_clip/src/open_clip/convert.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Conversion functions for 3rd part state-dicts and non-torch native checkpoint formats.
|
2 |
+
"""
|
3 |
+
from typing import Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from .model import CLIP, CustomTextCLIP
|
9 |
+
from .transformer import TextTransformer, Transformer
|
10 |
+
|
11 |
+
|
12 |
+
@torch.no_grad()
|
13 |
+
def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str):
|
14 |
+
""" Load weights from .npz checkpoints for official Google big_vision image-text models
|
15 |
+
|
16 |
+
Currently the SigLIP source models are supported and a CustomTextCLIP destination model
|
17 |
+
w/ timm image encoder.
|
18 |
+
"""
|
19 |
+
from timm.layers import resample_patch_embed, resample_abs_pos_embed
|
20 |
+
|
21 |
+
def _n2p(w, t=True, idx=None):
|
22 |
+
if idx is not None:
|
23 |
+
w = w[idx]
|
24 |
+
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
|
25 |
+
w = w.flatten()
|
26 |
+
if t:
|
27 |
+
if w.ndim == 4:
|
28 |
+
w = w.transpose([3, 2, 0, 1])
|
29 |
+
elif w.ndim == 3:
|
30 |
+
w = w.transpose([2, 0, 1])
|
31 |
+
elif w.ndim == 2:
|
32 |
+
w = w.transpose([1, 0])
|
33 |
+
return torch.from_numpy(w)
|
34 |
+
|
35 |
+
w = np.load(checkpoint_path)
|
36 |
+
interpolation = 'bilinear'
|
37 |
+
antialias = False
|
38 |
+
|
39 |
+
def _convert_timm_img(module, prefix):
|
40 |
+
embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
|
41 |
+
if embed_conv_w.shape[-2:] != module.patch_embed.proj.weight.shape[-2:]:
|
42 |
+
embed_conv_w = resample_patch_embed(
|
43 |
+
embed_conv_w,
|
44 |
+
module.patch_embed.proj.weight.shape[-2:],
|
45 |
+
interpolation=interpolation,
|
46 |
+
antialias=antialias,
|
47 |
+
verbose=True,
|
48 |
+
)
|
49 |
+
module.patch_embed.proj.weight.copy_(embed_conv_w)
|
50 |
+
module.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
|
51 |
+
|
52 |
+
if module.cls_token is not None:
|
53 |
+
module.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
|
54 |
+
|
55 |
+
pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False)
|
56 |
+
if pos_embed_w.shape != module.pos_embed.shape:
|
57 |
+
assert False, f'{pos_embed_w.shape}, {module.pos_embed.shape}'
|
58 |
+
num_prefix_tokens = 0 if getattr(module, 'no_embed_class', False) else getattr(module, 'num_prefix_tokens', 1)
|
59 |
+
pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights
|
60 |
+
pos_embed_w,
|
61 |
+
new_size=module.patch_embed.grid_size,
|
62 |
+
num_prefix_tokens=num_prefix_tokens,
|
63 |
+
interpolation=interpolation,
|
64 |
+
antialias=antialias,
|
65 |
+
verbose=True,
|
66 |
+
)
|
67 |
+
module.pos_embed.copy_(pos_embed_w)
|
68 |
+
|
69 |
+
mha_sub, b_sub, ln1_sub = (0, 0, 1)
|
70 |
+
for i, block in enumerate(module.blocks.children()):
|
71 |
+
if f'{prefix}Transformer/encoderblock/LayerNorm_0/scale' in w:
|
72 |
+
block_prefix = f'{prefix}Transformer/encoderblock/'
|
73 |
+
idx = i
|
74 |
+
else:
|
75 |
+
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
|
76 |
+
idx = None
|
77 |
+
mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
|
78 |
+
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'], idx=idx))
|
79 |
+
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'], idx=idx))
|
80 |
+
block.attn.qkv.weight.copy_(torch.cat([
|
81 |
+
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False, idx=idx).flatten(1).T for n in ('query', 'key', 'value')]))
|
82 |
+
block.attn.qkv.bias.copy_(torch.cat([
|
83 |
+
_n2p(w[f'{mha_prefix}{n}/bias'], t=False, idx=idx).reshape(-1) for n in ('query', 'key', 'value')]))
|
84 |
+
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel'], idx=idx).flatten(1))
|
85 |
+
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'], idx=idx))
|
86 |
+
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'], idx=idx))
|
87 |
+
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'], idx=idx))
|
88 |
+
for r in range(2):
|
89 |
+
getattr(block.mlp, f'fc{r + 1}').weight.copy_(
|
90 |
+
_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'], idx=idx))
|
91 |
+
getattr(block.mlp, f'fc{r + 1}').bias.copy_(
|
92 |
+
_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'], idx=idx))
|
93 |
+
|
94 |
+
module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
|
95 |
+
module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
|
96 |
+
|
97 |
+
if module.attn_pool is not None:
|
98 |
+
block_prefix = f'{prefix}MAPHead_0/'
|
99 |
+
mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
|
100 |
+
module.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False))
|
101 |
+
module.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T)
|
102 |
+
module.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1))
|
103 |
+
module.attn_pool.kv.weight.copy_(torch.cat([
|
104 |
+
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')]))
|
105 |
+
module.attn_pool.kv.bias.copy_(torch.cat([
|
106 |
+
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')]))
|
107 |
+
module.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
|
108 |
+
module.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
|
109 |
+
module.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
|
110 |
+
module.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
|
111 |
+
for r in range(2):
|
112 |
+
getattr(module.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel']))
|
113 |
+
getattr(module.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias']))
|
114 |
+
|
115 |
+
def _convert_openclip_transformer(module: Transformer, prefix):
|
116 |
+
for i, block in enumerate(module.resblocks.children()):
|
117 |
+
block_prefix = f'{prefix}encoderblock_{i}/'
|
118 |
+
mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
|
119 |
+
block.ln_1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
|
120 |
+
block.ln_1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
|
121 |
+
block.attn.in_proj_weight.copy_(torch.cat([
|
122 |
+
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
|
123 |
+
block.attn.in_proj_bias.copy_(torch.cat([
|
124 |
+
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
|
125 |
+
block.attn.out_proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
|
126 |
+
block.attn.out_proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
|
127 |
+
block.ln_2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/scale']))
|
128 |
+
block.ln_2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/bias']))
|
129 |
+
block.mlp.c_fc.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/kernel']))
|
130 |
+
block.mlp.c_fc.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/bias']))
|
131 |
+
block.mlp.c_proj.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/kernel']))
|
132 |
+
block.mlp.c_proj.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/bias']))
|
133 |
+
|
134 |
+
def _convert_openclip_txt(module: TextTransformer, prefix):
|
135 |
+
module.token_embedding.weight.copy_(_n2p(w[f'{prefix}Embed_0/embedding'], t=False))
|
136 |
+
pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False).squeeze(0)
|
137 |
+
module.positional_embedding.copy_(pos_embed_w)
|
138 |
+
_convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/')
|
139 |
+
module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale']))
|
140 |
+
module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias']))
|
141 |
+
if module.text_projection is not None:
|
142 |
+
module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
|
143 |
+
module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias']))
|
144 |
+
|
145 |
+
_convert_timm_img(model.visual.trunk, 'img/')
|
146 |
+
_convert_openclip_txt(model.text, 'txt/')
|
147 |
+
model.logit_bias.copy_(_n2p(w['b'])[0])
|
148 |
+
model.logit_scale.copy_(_n2p(w['t'])[0])
|
149 |
+
|
150 |
+
|
151 |
+
@torch.no_grad()
|
152 |
+
def convert_mobile_clip_state_dict(model: CustomTextCLIP, state_dict, fastvit = True):
|
153 |
+
|
154 |
+
def _convert_timm_img(state_dict):
|
155 |
+
if fastvit:
|
156 |
+
from timm.models.fastvit import checkpoint_filter_fn
|
157 |
+
else:
|
158 |
+
from timm.models.vision_transformer_hybrid import checkpoint_filter_fn
|
159 |
+
timm_state_dict = checkpoint_filter_fn(state_dict, model.visual.trunk)
|
160 |
+
timm_state_dict = {'visual.trunk.' + k: v for k, v in timm_state_dict.items()}
|
161 |
+
return timm_state_dict
|
162 |
+
|
163 |
+
def _convert_openclip_txt(state_dict, prefix='text_encoder.'):
|
164 |
+
text_dict = {}
|
165 |
+
for k, v in state_dict.items():
|
166 |
+
if not k.startswith(prefix):
|
167 |
+
continue
|
168 |
+
k = k.replace(prefix, '')
|
169 |
+
k = k.replace('projection_layer', 'text_projection')
|
170 |
+
k = k.replace('embedding_layer', 'token_embedding')
|
171 |
+
if k.startswith('positional_embedding.pos_embed.pos_embed'):
|
172 |
+
k = k.replace('positional_embedding.pos_embed.pos_embed', 'positional_embedding')
|
173 |
+
v = v.squeeze()
|
174 |
+
k = k.replace('final_layer_norm', 'ln_final')
|
175 |
+
k = k.replace('pre_norm_mha.0', 'ln_1')
|
176 |
+
k = k.replace('pre_norm_mha.1', 'attn')
|
177 |
+
k = k.replace('pre_norm_ffn.0', 'ln_2')
|
178 |
+
k = k.replace('pre_norm_ffn.1', 'mlp.c_fc')
|
179 |
+
k = k.replace('pre_norm_ffn.4', 'mlp.c_proj')
|
180 |
+
k = k.replace('qkv_proj.weight', 'in_proj_weight')
|
181 |
+
k = k.replace('qkv_proj.bias', 'in_proj_bias')
|
182 |
+
k = k.replace('transformer.', 'transformer.resblocks.')
|
183 |
+
text_dict['text.' + k] = v
|
184 |
+
return text_dict
|
185 |
+
|
186 |
+
image_dict = _convert_timm_img(state_dict)
|
187 |
+
text_dict = _convert_openclip_txt(state_dict)
|
188 |
+
out_dict = {**image_dict, **text_dict}
|
189 |
+
out_dict['logit_scale'] = state_dict['logit_scale']
|
190 |
+
return out_dict
|
191 |
+
|
192 |
+
|
193 |
+
def convert_state_dict(model: Union[CustomTextCLIP, CLIP], state_dict):
|
194 |
+
if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict:
|
195 |
+
# Apple MobileCLIP s1 & s2 state_dicts (s0 and b not currently supported)
|
196 |
+
state_dict = convert_mobile_clip_state_dict(model, state_dict)
|
197 |
+
if 'image_encoder.model.patch_emb.0.block.conv.weight' in state_dict:
|
198 |
+
# convert b model
|
199 |
+
state_dict = convert_mobile_clip_state_dict(model, state_dict, fastvit=False)
|
200 |
+
return state_dict
|
open_clip/src/open_clip/factory.py
ADDED
@@ -0,0 +1,578 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
import warnings
|
6 |
+
from copy import deepcopy
|
7 |
+
from dataclasses import asdict
|
8 |
+
from pathlib import Path
|
9 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
10 |
+
|
11 |
+
import torch
|
12 |
+
|
13 |
+
from .convert import convert_state_dict
|
14 |
+
from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
|
15 |
+
resize_pos_embed, get_cast_dtype, resize_text_pos_embed, set_model_preprocess_cfg
|
16 |
+
from .coca_model import CoCa
|
17 |
+
from .loss import ClipLoss, DistillClipLoss, CoCaLoss, SigLipLoss
|
18 |
+
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
|
19 |
+
list_pretrained_tags_by_model, download_pretrained_from_hf
|
20 |
+
from .transform import image_transform_v2, AugmentationCfg, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs
|
21 |
+
from .tokenizer import HFTokenizer, SimpleTokenizer, DEFAULT_CONTEXT_LENGTH
|
22 |
+
|
23 |
+
HF_HUB_PREFIX = 'hf-hub:'
|
24 |
+
_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
|
25 |
+
_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
|
26 |
+
|
27 |
+
|
28 |
+
def _natural_key(string_):
|
29 |
+
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
|
30 |
+
|
31 |
+
|
32 |
+
def _rescan_model_configs():
|
33 |
+
global _MODEL_CONFIGS
|
34 |
+
|
35 |
+
config_ext = ('.json',)
|
36 |
+
config_files = []
|
37 |
+
for config_path in _MODEL_CONFIG_PATHS:
|
38 |
+
if config_path.is_file() and config_path.suffix in config_ext:
|
39 |
+
config_files.append(config_path)
|
40 |
+
elif config_path.is_dir():
|
41 |
+
for ext in config_ext:
|
42 |
+
config_files.extend(config_path.glob(f'*{ext}'))
|
43 |
+
|
44 |
+
for cf in config_files:
|
45 |
+
with open(cf, 'r') as f:
|
46 |
+
model_cfg = json.load(f)
|
47 |
+
if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
|
48 |
+
_MODEL_CONFIGS[cf.stem] = model_cfg
|
49 |
+
|
50 |
+
_MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}
|
51 |
+
|
52 |
+
|
53 |
+
_rescan_model_configs() # initial populate of model config registry
|
54 |
+
|
55 |
+
|
56 |
+
def list_models():
|
57 |
+
""" enumerate available model architectures based on config files """
|
58 |
+
return list(_MODEL_CONFIGS.keys())
|
59 |
+
|
60 |
+
|
61 |
+
def add_model_config(path):
|
62 |
+
""" add model config path or file and update registry """
|
63 |
+
if not isinstance(path, Path):
|
64 |
+
path = Path(path)
|
65 |
+
_MODEL_CONFIG_PATHS.append(path)
|
66 |
+
_rescan_model_configs()
|
67 |
+
|
68 |
+
|
69 |
+
def get_model_config(model_name):
|
70 |
+
""" Fetch model config from builtin (local library) configs.
|
71 |
+
"""
|
72 |
+
if model_name in _MODEL_CONFIGS:
|
73 |
+
return deepcopy(_MODEL_CONFIGS[model_name])
|
74 |
+
else:
|
75 |
+
return None
|
76 |
+
|
77 |
+
|
78 |
+
def _get_hf_config(
|
79 |
+
model_id: str,
|
80 |
+
cache_dir: Optional[str] = None,
|
81 |
+
):
|
82 |
+
""" Fetch model config from HuggingFace Hub.
|
83 |
+
"""
|
84 |
+
config_path = download_pretrained_from_hf(
|
85 |
+
model_id,
|
86 |
+
filename='open_clip_config.json',
|
87 |
+
cache_dir=cache_dir,
|
88 |
+
)
|
89 |
+
with open(config_path, 'r', encoding='utf-8') as f:
|
90 |
+
config = json.load(f)
|
91 |
+
return config
|
92 |
+
|
93 |
+
|
94 |
+
def get_tokenizer(
|
95 |
+
model_name: str = '',
|
96 |
+
context_length: Optional[int] = None,
|
97 |
+
cache_dir: Optional[str] = None,
|
98 |
+
**kwargs,
|
99 |
+
):
|
100 |
+
if model_name.startswith(HF_HUB_PREFIX):
|
101 |
+
model_name = model_name[len(HF_HUB_PREFIX):]
|
102 |
+
try:
|
103 |
+
config = _get_hf_config(model_name, cache_dir=cache_dir)['model_cfg']
|
104 |
+
except Exception:
|
105 |
+
tokenizer = HFTokenizer(
|
106 |
+
model_name,
|
107 |
+
context_length=context_length or DEFAULT_CONTEXT_LENGTH,
|
108 |
+
cache_dir=cache_dir,
|
109 |
+
**kwargs,
|
110 |
+
)
|
111 |
+
return tokenizer
|
112 |
+
else:
|
113 |
+
config = get_model_config(model_name)
|
114 |
+
assert config is not None, f"No valid model config found for {model_name}."
|
115 |
+
|
116 |
+
text_config = config.get('text_cfg', {})
|
117 |
+
if 'tokenizer_kwargs' in text_config:
|
118 |
+
tokenizer_kwargs = dict(text_config['tokenizer_kwargs'], **kwargs)
|
119 |
+
else:
|
120 |
+
tokenizer_kwargs = kwargs
|
121 |
+
|
122 |
+
if context_length is None:
|
123 |
+
context_length = text_config.get('context_length', DEFAULT_CONTEXT_LENGTH)
|
124 |
+
|
125 |
+
if 'hf_tokenizer_name' in text_config:
|
126 |
+
tokenizer = HFTokenizer(
|
127 |
+
text_config['hf_tokenizer_name'],
|
128 |
+
context_length=context_length,
|
129 |
+
cache_dir=cache_dir,
|
130 |
+
**tokenizer_kwargs,
|
131 |
+
)
|
132 |
+
else:
|
133 |
+
tokenizer = SimpleTokenizer(
|
134 |
+
context_length=context_length,
|
135 |
+
**tokenizer_kwargs,
|
136 |
+
)
|
137 |
+
|
138 |
+
return tokenizer
|
139 |
+
|
140 |
+
|
141 |
+
def load_state_dict(
|
142 |
+
checkpoint_path: str,
|
143 |
+
device='cpu',
|
144 |
+
weights_only=True,
|
145 |
+
):
|
146 |
+
# Check if safetensors or not and load weights accordingly
|
147 |
+
if str(checkpoint_path).endswith(".safetensors"):
|
148 |
+
from safetensors.torch import load_file
|
149 |
+
checkpoint = load_file(checkpoint_path, device=device)
|
150 |
+
else:
|
151 |
+
try:
|
152 |
+
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=weights_only)
|
153 |
+
except TypeError:
|
154 |
+
checkpoint = torch.load(checkpoint_path, map_location=device)
|
155 |
+
|
156 |
+
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
157 |
+
state_dict = checkpoint['state_dict']
|
158 |
+
elif isinstance(checkpoint, torch.jit.ScriptModule):
|
159 |
+
state_dict = checkpoint.state_dict()
|
160 |
+
for key in ["input_resolution", "context_length", "vocab_size"]:
|
161 |
+
state_dict.pop(key, None)
|
162 |
+
else:
|
163 |
+
state_dict = checkpoint
|
164 |
+
if next(iter(state_dict.items()))[0].startswith('module'):
|
165 |
+
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
166 |
+
return state_dict
|
167 |
+
|
168 |
+
|
169 |
+
def load_checkpoint(
|
170 |
+
model: Union[CLIP, CustomTextCLIP],
|
171 |
+
checkpoint_path: str,
|
172 |
+
strict: bool = True,
|
173 |
+
weights_only: bool = True,
|
174 |
+
device='cpu',
|
175 |
+
):
|
176 |
+
if Path(checkpoint_path).suffix in ('.npz', '.npy'):
|
177 |
+
# Separate path loading numpy big_vision (SigLIP) weights
|
178 |
+
from open_clip.convert import load_big_vision_weights
|
179 |
+
load_big_vision_weights(model, checkpoint_path)
|
180 |
+
return {}
|
181 |
+
|
182 |
+
state_dict = load_state_dict(checkpoint_path, device=device, weights_only=weights_only)
|
183 |
+
|
184 |
+
# Detect & convert 3rd party state_dicts -> open_clip
|
185 |
+
state_dict = convert_state_dict(model, state_dict)
|
186 |
+
|
187 |
+
# Detect old format and make compatible with new format
|
188 |
+
if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
|
189 |
+
state_dict = convert_to_custom_text_state_dict(state_dict)
|
190 |
+
|
191 |
+
# correct if logit_scale differs in being scaler vs 1d param
|
192 |
+
if 'logit_scale' in state_dict and model.logit_scale.ndim != state_dict['logit_scale'].ndim:
|
193 |
+
state_dict['logit_scale'] = state_dict['logit_scale'].reshape(model.logit_scale.shape)
|
194 |
+
|
195 |
+
# correct if logit_bias differs in being scaler vs 1d param
|
196 |
+
if 'logit_bias' in state_dict and model.logit_bias.ndim != state_dict['logit_bias'].ndim:
|
197 |
+
state_dict['logit_bias'] = state_dict['logit_bias'].reshape(model.logit_bias.shape)
|
198 |
+
|
199 |
+
# If loading a non-SigLIP model for SigLIP training. See https://github.com/mlfoundations/open_clip/issues/712
|
200 |
+
if 'logit_bias' not in state_dict and model.logit_bias is not None:
|
201 |
+
state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"])
|
202 |
+
|
203 |
+
# Certain text transformers no longer expect position_ids after transformers==4.31
|
204 |
+
position_id_key = 'text.transformer.embeddings.position_ids'
|
205 |
+
if position_id_key in state_dict and not hasattr(model, position_id_key):
|
206 |
+
del state_dict[position_id_key]
|
207 |
+
|
208 |
+
resize_pos_embed(state_dict, model)
|
209 |
+
resize_text_pos_embed(state_dict, model)
|
210 |
+
|
211 |
+
# Finally, load the massaged state_dict into model
|
212 |
+
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
|
213 |
+
return incompatible_keys
|
214 |
+
|
215 |
+
|
216 |
+
def create_model(
|
217 |
+
model_name: str,
|
218 |
+
pretrained: Optional[str] = None,
|
219 |
+
precision: str = 'fp32',
|
220 |
+
device: Union[str, torch.device] = 'cpu',
|
221 |
+
jit: bool = False,
|
222 |
+
force_quick_gelu: bool = False,
|
223 |
+
force_custom_text: bool = False,
|
224 |
+
force_patch_dropout: Optional[float] = None,
|
225 |
+
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
|
226 |
+
force_preprocess_cfg: Optional[Dict[str, Any]] = None,
|
227 |
+
pretrained_image: bool = False,
|
228 |
+
pretrained_hf: bool = True,
|
229 |
+
cache_dir: Optional[str] = None,
|
230 |
+
output_dict: Optional[bool] = None,
|
231 |
+
require_pretrained: bool = False,
|
232 |
+
load_weights_only: bool = True,
|
233 |
+
**model_kwargs,
|
234 |
+
):
|
235 |
+
"""Creates and configures a contrastive vision-language model.
|
236 |
+
|
237 |
+
Args:
|
238 |
+
model_name: Name of the model architecture to create. Can be a local model name
|
239 |
+
or a Hugging Face model ID prefixed with 'hf-hub:'.
|
240 |
+
pretrained: Tag/path for pretrained model weights. Can be:
|
241 |
+
- A pretrained tag name (e.g., 'openai')
|
242 |
+
- A path to local weights
|
243 |
+
- None to initialize with random weights
|
244 |
+
precision: Model precision/AMP configuration. Options:
|
245 |
+
- 'fp32': 32-bit floating point
|
246 |
+
- 'fp16'/'bf16': Mixed precision with FP32 for certain layers
|
247 |
+
- 'pure_fp16'/'pure_bf16': Pure 16-bit precision
|
248 |
+
device: Device to load the model on ('cpu', 'cuda', or torch.device object)
|
249 |
+
jit: If True, JIT compile the model
|
250 |
+
force_quick_gelu: Force use of QuickGELU activation
|
251 |
+
force_custom_text: Force use of custom text encoder
|
252 |
+
force_patch_dropout: Override default patch dropout value
|
253 |
+
force_image_size: Override default image size for vision encoder
|
254 |
+
force_preprocess_cfg: Override default preprocessing configuration
|
255 |
+
pretrained_image: Load pretrained weights for timm vision models
|
256 |
+
pretrained_hf: Load pretrained weights for HF text models when not loading CLIP weights
|
257 |
+
cache_dir: Override default cache directory for downloaded model files
|
258 |
+
output_dict: If True and model supports it, return dictionary of features
|
259 |
+
require_pretrained: Raise error if pretrained weights cannot be loaded
|
260 |
+
load_weights_only: Only deserialize model weights and unpickling torch checkpoints (for safety)
|
261 |
+
**model_kwargs: Additional keyword arguments passed to model constructor
|
262 |
+
|
263 |
+
Returns:
|
264 |
+
Created and configured model instance
|
265 |
+
|
266 |
+
Raises:
|
267 |
+
RuntimeError: If model config is not found or required pretrained weights
|
268 |
+
cannot be loaded
|
269 |
+
|
270 |
+
Examples:
|
271 |
+
# Create basic CLIP model
|
272 |
+
model = create_model('ViT-B/32')
|
273 |
+
|
274 |
+
# Create CLIP model with mixed precision on GPU
|
275 |
+
model = create_model('ViT-B/32', precision='fp16', device='cuda')
|
276 |
+
|
277 |
+
# Load pretrained OpenAI weights
|
278 |
+
model = create_model('ViT-B/32', pretrained='openai')
|
279 |
+
|
280 |
+
# Load Hugging Face model
|
281 |
+
model = create_model('hf-hub:organization/model-name')
|
282 |
+
"""
|
283 |
+
|
284 |
+
force_preprocess_cfg = force_preprocess_cfg or {}
|
285 |
+
preprocess_cfg = asdict(PreprocessCfg())
|
286 |
+
has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
|
287 |
+
if has_hf_hub_prefix:
|
288 |
+
model_id = model_name[len(HF_HUB_PREFIX):]
|
289 |
+
checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
|
290 |
+
config = _get_hf_config(model_id, cache_dir=cache_dir)
|
291 |
+
preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg'])
|
292 |
+
model_cfg = config['model_cfg']
|
293 |
+
pretrained_hf = False # override, no need to load original HF text weights
|
294 |
+
else:
|
295 |
+
model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
|
296 |
+
checkpoint_path = None
|
297 |
+
model_cfg = None
|
298 |
+
|
299 |
+
if isinstance(device, str):
|
300 |
+
device = torch.device(device)
|
301 |
+
|
302 |
+
model_cfg = model_cfg or get_model_config(model_name)
|
303 |
+
if model_cfg is not None:
|
304 |
+
logging.info(f'Loaded {model_name} model config.')
|
305 |
+
else:
|
306 |
+
logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
|
307 |
+
raise RuntimeError(f'Model config for {model_name} not found.')
|
308 |
+
|
309 |
+
if force_quick_gelu:
|
310 |
+
# override for use of QuickGELU on non-OpenAI transformer models
|
311 |
+
model_cfg["quick_gelu"] = True
|
312 |
+
|
313 |
+
if force_patch_dropout is not None:
|
314 |
+
# override the default patch dropout value
|
315 |
+
model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
|
316 |
+
|
317 |
+
if force_image_size is not None:
|
318 |
+
# override model config's image size
|
319 |
+
model_cfg["vision_cfg"]["image_size"] = force_image_size
|
320 |
+
|
321 |
+
is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})
|
322 |
+
if pretrained_image:
|
323 |
+
if is_timm_model:
|
324 |
+
# pretrained weight loading for timm models set via vision_cfg
|
325 |
+
model_cfg['vision_cfg']['timm_model_pretrained'] = True
|
326 |
+
else:
|
327 |
+
assert False, 'pretrained image towers currently only supported for timm models'
|
328 |
+
|
329 |
+
# cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes
|
330 |
+
cast_dtype = get_cast_dtype(precision)
|
331 |
+
is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
|
332 |
+
if is_hf_model:
|
333 |
+
# load pretrained weights for HF text model IFF no CLIP weights being loaded
|
334 |
+
model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained
|
335 |
+
custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model
|
336 |
+
|
337 |
+
model_cfg = dict(model_cfg, **model_kwargs) # merge cfg dict w/ kwargs (kwargs overrides cfg)
|
338 |
+
if custom_text:
|
339 |
+
if "multimodal_cfg" in model_cfg:
|
340 |
+
model = CoCa(**model_cfg, cast_dtype=cast_dtype)
|
341 |
+
else:
|
342 |
+
model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
|
343 |
+
else:
|
344 |
+
model = CLIP(**model_cfg, cast_dtype=cast_dtype)
|
345 |
+
|
346 |
+
if precision in ("fp16", "bf16"):
|
347 |
+
dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
|
348 |
+
# manual mixed precision that matches original OpenAI behaviour
|
349 |
+
if is_timm_model:
|
350 |
+
# FIXME this is a bit janky, create timm based model in low-precision and
|
351 |
+
# then cast only LayerNormFp32 instances back to float32 so they don't break.
|
352 |
+
# Why? The convert_weights_to_lp fn only works with native models.
|
353 |
+
model.to(device=device, dtype=dtype)
|
354 |
+
from .transformer import LayerNormFp32
|
355 |
+
|
356 |
+
def _convert_ln(m):
|
357 |
+
if isinstance(m, LayerNormFp32):
|
358 |
+
m.weight.data = m.weight.data.to(torch.float32)
|
359 |
+
m.bias.data = m.bias.data.to(torch.float32)
|
360 |
+
model.apply(_convert_ln)
|
361 |
+
else:
|
362 |
+
model.to(device=device)
|
363 |
+
convert_weights_to_lp(model, dtype=dtype)
|
364 |
+
elif precision in ("pure_fp16", "pure_bf16"):
|
365 |
+
dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
|
366 |
+
model.to(device=device, dtype=dtype)
|
367 |
+
else:
|
368 |
+
model.to(device=device)
|
369 |
+
|
370 |
+
pretrained_loaded = False
|
371 |
+
if pretrained:
|
372 |
+
checkpoint_path = ''
|
373 |
+
pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
|
374 |
+
if pretrained_cfg:
|
375 |
+
checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
|
376 |
+
preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg)
|
377 |
+
pretrained_quick_gelu = pretrained_cfg.get('quick_gelu', False)
|
378 |
+
model_quick_gelu = model_cfg.get('quick_gelu', False)
|
379 |
+
if pretrained_quick_gelu and not model_quick_gelu:
|
380 |
+
warnings.warn(
|
381 |
+
f'These pretrained weights were trained with QuickGELU activation but the model config does '
|
382 |
+
f'not have that enabled. Consider using a model config with a "-quickgelu" suffix or enable with a flag.')
|
383 |
+
elif not pretrained_quick_gelu and model_quick_gelu:
|
384 |
+
warnings.warn(
|
385 |
+
f'The pretrained weights were not trained with QuickGELU but this activation is enabled in the '
|
386 |
+
f'model config, consider using a model config without QuickGELU or disable override flags.')
|
387 |
+
elif os.path.exists(pretrained):
|
388 |
+
checkpoint_path = pretrained
|
389 |
+
|
390 |
+
if checkpoint_path:
|
391 |
+
logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
|
392 |
+
load_checkpoint(model, checkpoint_path, weights_only=load_weights_only)
|
393 |
+
else:
|
394 |
+
error_str = (
|
395 |
+
f'Pretrained weights ({pretrained}) not found for model {model_name}.'
|
396 |
+
f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
|
397 |
+
logging.warning(error_str)
|
398 |
+
raise RuntimeError(error_str)
|
399 |
+
pretrained_loaded = True
|
400 |
+
elif has_hf_hub_prefix:
|
401 |
+
logging.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).')
|
402 |
+
load_checkpoint(model, checkpoint_path, weights_only=load_weights_only)
|
403 |
+
pretrained_loaded = True
|
404 |
+
|
405 |
+
if require_pretrained and not pretrained_loaded:
|
406 |
+
# callers of create_model_from_pretrained always expect pretrained weights
|
407 |
+
raise RuntimeError(
|
408 |
+
f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')
|
409 |
+
|
410 |
+
if output_dict and hasattr(model, "output_dict"):
|
411 |
+
model.output_dict = True
|
412 |
+
|
413 |
+
if jit:
|
414 |
+
model = torch.jit.script(model)
|
415 |
+
|
416 |
+
# set image preprocessing configuration in model attributes for convenience
|
417 |
+
if getattr(model.visual, 'image_size', None) is not None:
|
418 |
+
# use image_size set on model creation (via config or force_image_size arg)
|
419 |
+
force_preprocess_cfg['size'] = model.visual.image_size
|
420 |
+
set_model_preprocess_cfg(model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg))
|
421 |
+
|
422 |
+
return model
|
423 |
+
|
424 |
+
|
425 |
+
def create_loss(args):
|
426 |
+
if args.distill:
|
427 |
+
return DistillClipLoss(
|
428 |
+
local_loss=args.local_loss,
|
429 |
+
gather_with_grad=args.gather_with_grad,
|
430 |
+
cache_labels=True,
|
431 |
+
rank=args.rank,
|
432 |
+
world_size=args.world_size,
|
433 |
+
use_horovod=args.horovod,
|
434 |
+
)
|
435 |
+
elif "coca" in args.model.lower():
|
436 |
+
return CoCaLoss(
|
437 |
+
caption_loss_weight=args.coca_caption_loss_weight,
|
438 |
+
clip_loss_weight=args.coca_contrastive_loss_weight,
|
439 |
+
local_loss=args.local_loss,
|
440 |
+
gather_with_grad=args.gather_with_grad,
|
441 |
+
cache_labels=True,
|
442 |
+
rank=args.rank,
|
443 |
+
world_size=args.world_size,
|
444 |
+
use_horovod=args.horovod,
|
445 |
+
)
|
446 |
+
elif args.siglip:
|
447 |
+
assert not args.horovod, "Horovod not currently supported for SigLip"
|
448 |
+
return SigLipLoss(
|
449 |
+
rank=args.rank,
|
450 |
+
world_size=args.world_size,
|
451 |
+
dist_impl=args.loss_dist_impl, # siglip has multiple distributed implementations to choose from
|
452 |
+
)
|
453 |
+
|
454 |
+
return ClipLoss(
|
455 |
+
local_loss=args.local_loss,
|
456 |
+
gather_with_grad=args.gather_with_grad,
|
457 |
+
cache_labels=True,
|
458 |
+
rank=args.rank,
|
459 |
+
world_size=args.world_size,
|
460 |
+
use_horovod=args.horovod,
|
461 |
+
)
|
462 |
+
|
463 |
+
|
464 |
+
def create_model_and_transforms(
|
465 |
+
model_name: str,
|
466 |
+
pretrained: Optional[str] = None,
|
467 |
+
precision: str = 'fp32',
|
468 |
+
device: Union[str, torch.device] = 'cpu',
|
469 |
+
jit: bool = False,
|
470 |
+
force_quick_gelu: bool = False,
|
471 |
+
force_custom_text: bool = False,
|
472 |
+
force_patch_dropout: Optional[float] = None,
|
473 |
+
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
|
474 |
+
image_mean: Optional[Tuple[float, ...]] = None,
|
475 |
+
image_std: Optional[Tuple[float, ...]] = None,
|
476 |
+
image_interpolation: Optional[str] = None,
|
477 |
+
image_resize_mode: Optional[str] = None, # only effective for inference
|
478 |
+
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
|
479 |
+
pretrained_image: bool = False,
|
480 |
+
pretrained_hf: bool = True,
|
481 |
+
cache_dir: Optional[str] = None,
|
482 |
+
output_dict: Optional[bool] = None,
|
483 |
+
load_weights_only: bool = True,
|
484 |
+
**model_kwargs,
|
485 |
+
):
|
486 |
+
force_preprocess_cfg = merge_preprocess_kwargs(
|
487 |
+
{},
|
488 |
+
mean=image_mean,
|
489 |
+
std=image_std,
|
490 |
+
interpolation=image_interpolation,
|
491 |
+
resize_mode=image_resize_mode,
|
492 |
+
)
|
493 |
+
|
494 |
+
model = create_model(
|
495 |
+
model_name,
|
496 |
+
pretrained,
|
497 |
+
precision=precision,
|
498 |
+
device=device,
|
499 |
+
jit=jit,
|
500 |
+
force_quick_gelu=force_quick_gelu,
|
501 |
+
force_custom_text=force_custom_text,
|
502 |
+
force_patch_dropout=force_patch_dropout,
|
503 |
+
force_image_size=force_image_size,
|
504 |
+
force_preprocess_cfg=force_preprocess_cfg,
|
505 |
+
pretrained_image=pretrained_image,
|
506 |
+
pretrained_hf=pretrained_hf,
|
507 |
+
cache_dir=cache_dir,
|
508 |
+
output_dict=output_dict,
|
509 |
+
load_weights_only=load_weights_only,
|
510 |
+
**model_kwargs,
|
511 |
+
)
|
512 |
+
|
513 |
+
pp_cfg = PreprocessCfg(**model.visual.preprocess_cfg)
|
514 |
+
|
515 |
+
preprocess_train = image_transform_v2(
|
516 |
+
pp_cfg,
|
517 |
+
is_train=True,
|
518 |
+
aug_cfg=aug_cfg,
|
519 |
+
)
|
520 |
+
preprocess_val = image_transform_v2(
|
521 |
+
pp_cfg,
|
522 |
+
is_train=False,
|
523 |
+
)
|
524 |
+
|
525 |
+
return model, preprocess_train, preprocess_val
|
526 |
+
|
527 |
+
|
528 |
+
def create_model_from_pretrained(
|
529 |
+
model_name: str,
|
530 |
+
pretrained: Optional[str] = None,
|
531 |
+
precision: str = 'fp32',
|
532 |
+
device: Union[str, torch.device] = 'cpu',
|
533 |
+
jit: bool = False,
|
534 |
+
force_quick_gelu: bool = False,
|
535 |
+
force_custom_text: bool = False,
|
536 |
+
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
|
537 |
+
image_mean: Optional[Tuple[float, ...]] = None,
|
538 |
+
image_std: Optional[Tuple[float, ...]] = None,
|
539 |
+
image_interpolation: Optional[str] = None,
|
540 |
+
image_resize_mode: Optional[str] = None, # only effective for inference
|
541 |
+
return_transform: bool = True,
|
542 |
+
cache_dir: Optional[str] = None,
|
543 |
+
load_weights_only: bool = True,
|
544 |
+
**model_kwargs,
|
545 |
+
):
|
546 |
+
force_preprocess_cfg = merge_preprocess_kwargs(
|
547 |
+
{},
|
548 |
+
mean=image_mean,
|
549 |
+
std=image_std,
|
550 |
+
interpolation=image_interpolation,
|
551 |
+
resize_mode=image_resize_mode,
|
552 |
+
)
|
553 |
+
|
554 |
+
model = create_model(
|
555 |
+
model_name,
|
556 |
+
pretrained,
|
557 |
+
precision=precision,
|
558 |
+
device=device,
|
559 |
+
jit=jit,
|
560 |
+
force_quick_gelu=force_quick_gelu,
|
561 |
+
force_custom_text=force_custom_text,
|
562 |
+
force_image_size=force_image_size,
|
563 |
+
force_preprocess_cfg=force_preprocess_cfg,
|
564 |
+
cache_dir=cache_dir,
|
565 |
+
require_pretrained=True,
|
566 |
+
load_weights_only=load_weights_only,
|
567 |
+
**model_kwargs,
|
568 |
+
)
|
569 |
+
|
570 |
+
if not return_transform:
|
571 |
+
return model
|
572 |
+
|
573 |
+
preprocess = image_transform_v2(
|
574 |
+
PreprocessCfg(**model.visual.preprocess_cfg),
|
575 |
+
is_train=False,
|
576 |
+
)
|
577 |
+
|
578 |
+
return model, preprocess
|
open_clip/src/open_clip/hf_model.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" huggingface model adapter
|
2 |
+
|
3 |
+
Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
|
4 |
+
"""
|
5 |
+
import re
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
from torch import TensorType
|
10 |
+
|
11 |
+
try:
|
12 |
+
import transformers
|
13 |
+
from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
|
14 |
+
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
        BaseModelOutputWithPoolingAndCrossAttentions
except ImportError as e:
    transformers = None


    class BaseModelOutput:
        pass


    class PretrainedConfig:
        pass

from .hf_configs import arch_dict


# utils
def _camel2snake(s):
    return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()


# TODO: ?last - for gpt-like models
_POOLERS = {}


def register_pooler(cls):
    """Decorator registering pooler class"""
    _POOLERS[_camel2snake(cls.__name__)] = cls
    return cls


@register_pooler
class MeanPooler(nn.Module):
    """Mean pooling"""

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
        return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)


@register_pooler
class MaxPooler(nn.Module):
    """Max pooling"""

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1), -torch.inf)
        return masked_output.max(1).values


@register_pooler
class ClsPooler(nn.Module):
    """CLS token pooling"""

    def __init__(self, use_pooler_output=True):
        super().__init__()
        self.cls_token_position = 0
        self.use_pooler_output = use_pooler_output

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        if (self.use_pooler_output and
            isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
            (x.pooler_output is not None)
        ):
            return x.pooler_output

        return x.last_hidden_state[:, self.cls_token_position, :]


@register_pooler
class ClsLastHiddenStatePooler(nn.Module):
    """CLS token pooling
    NOTE: this is equivalent to ClsPooler above with use_pooler_output=False
    """

    def __init__(self):
        super().__init__()
        self.cls_token_position = 0

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        return x.last_hidden_state[:, self.cls_token_position, :]


class HFTextEncoder(nn.Module):
    """HuggingFace model adapter"""
    output_tokens: torch.jit.Final[bool]

    def __init__(
            self,
            model_name_or_path: str,
            output_dim: int,
            config: PretrainedConfig = None,
            pooler_type: str = None,
            proj_type: str = None,
            pretrained: bool = True,
            output_tokens: bool = False,
    ):
        super().__init__()
        self.output_tokens = output_tokens
        self.output_dim = output_dim

        # TODO: find better way to get this information
        uses_transformer_pooler = (pooler_type == "cls_pooler")

        if transformers is None:
            raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
        if config is None:
            self.config = AutoConfig.from_pretrained(model_name_or_path)
            create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
                AutoModel.from_config, self.config)
            # TODO: do all model configs have this attribute? PretrainedConfig does so yes??
            if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
                self.transformer = create_func(model_args)
                self.transformer = self.transformer.encoder
            else:
                self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
        else:
            self.config = config
            self.transformer = AutoModel.from_config(config)
        if pooler_type is None:  # get default arch pooler
            pooler_type = (arch_dict[self.config.model_type]["pooler"])

        # FIXME downstream users of OpenCLIP models use these attr, need to verify valid across all models
        self.vocab_size = getattr(self.config, 'vocab_size', 0)
        self.context_length = getattr(self.config, 'max_position_embeddings', 0)

        self.pooler = _POOLERS[pooler_type]()

        d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
        if (d_model == output_dim) and (proj_type is None):  # do we always need a proj?
            self.proj = nn.Identity()
        elif proj_type == 'linear':
            self.proj = nn.Linear(d_model, output_dim, bias=False)
        elif proj_type == 'mlp':
            hidden_size = (d_model + output_dim) // 2
            self.proj = nn.Sequential(
                nn.Linear(d_model, hidden_size, bias=False),
                nn.GELU(),
                nn.Linear(hidden_size, output_dim, bias=False),
            )

    def forward(self, x: TensorType):
        attn_mask = (x != self.config.pad_token_id).long()
        out = self.transformer(input_ids=x, attention_mask=attn_mask)
        pooled_out = self.pooler(out, attn_mask)
        projected = self.proj(pooled_out)

        seq_len = out.last_hidden_state.shape[1]
        tokens = (
            out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :]
            if type(self.pooler) == ClsPooler
            else out.last_hidden_state
        )

        if self.output_tokens:
            return projected, tokens
        return projected

    def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        if not unlocked_layers:  # full freezing
            for n, p in self.transformer.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
            return

        encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
        layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
        print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
        embeddings = getattr(
            self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
        modules = [embeddings, *layer_list][:-unlocked_layers]
        # freeze layers
        for module in modules:
            for n, p in module.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.transformer.gradient_checkpointing_enable()

    def init_parameters(self):
        pass
open_clip/src/open_clip/model.py
ADDED
@@ -0,0 +1,650 @@
""" CLIP Model

Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
import copy
import logging
import math
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint
from functools import partial

from .hf_model import HFTextEncoder
from .modified_resnet import ModifiedResNet
from .timm_model import TimmModel
from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer, \
    text_global_pool
from .utils import to_2tuple


@dataclass
class CLIPVisionCfg:
    layers: Union[Tuple[int, int, int, int], int] = 12
    width: int = 768
    head_width: int = 64
    mlp_ratio: float = 4.0
    patch_size: int = 16
    image_size: Union[Tuple[int, int], int] = 224

    ls_init_value: Optional[float] = None  # layer scale initial value
    patch_dropout: float = 0.  # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
    attentional_pool: bool = False  # whether to use attentional pooler in the last embedding layer (overrides pool_type)
    attn_pooler_queries: int = 256  # n_queries for attentional pooler
    attn_pooler_heads: int = 8  # n heads for attentional_pooling
    no_ln_pre: bool = False  # disable pre transformer LayerNorm
    pos_embed_type: str = 'learnable'
    final_ln_after_pool: bool = False  # apply final LayerNorm after pooling
    pool_type: str = 'tok'
    output_tokens: bool = False
    act_kwargs: Optional[dict] = None
    norm_kwargs: Optional[dict] = None

    timm_model_name: Optional[str] = None  # a valid model name overrides layers, width, patch_size
    timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model
    timm_pool: str = 'avg'  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
    timm_proj: str = 'linear'  # linear projection for timm model output ('linear', 'mlp', '')
    timm_proj_bias: bool = False  # enable bias final projection
    timm_drop: float = 0.  # head dropout
    timm_drop_path: Optional[float] = None  # backbone stochastic depth


@dataclass
class CLIPTextCfg:
    context_length: int = 77
    vocab_size: int = 49408
    hf_tokenizer_name: Optional[str] = None
    tokenizer_kwargs: Optional[dict] = None

    width: int = 512
    heads: int = 8
    layers: int = 12
    mlp_ratio: float = 4.0
    ls_init_value: Optional[float] = None  # layer scale initial value
    embed_cls: bool = False
    pad_id: int = 0
    no_causal_mask: bool = False  # disable causal masking
    final_ln_after_pool: bool = False  # apply final LayerNorm after pooling
    pool_type: str = 'argmax'
    proj_bias: bool = False
    proj_type: str = 'linear'  # control final text projection, 'none' forces no projection
    output_tokens: bool = False
    act_kwargs: dict = None
    norm_kwargs: dict = None

    # HuggingFace specific text tower config
    hf_model_name: Optional[str] = None
    hf_model_pretrained: bool = True
    hf_proj_type: str = 'mlp'
    hf_pooler_type: str = 'mean_pooler'  # attentional pooling for HF models


def get_cast_dtype(precision: str):
    cast_dtype = None
    if precision == 'bf16':
        cast_dtype = torch.bfloat16
    elif precision == 'fp16':
        cast_dtype = torch.float16
    return cast_dtype


def get_input_dtype(precision: str):
    input_dtype = None
    if precision in ('bf16', 'pure_bf16'):
        input_dtype = torch.bfloat16
    elif precision in ('fp16', 'pure_fp16'):
        input_dtype = torch.float16
    return input_dtype


def _build_vision_tower(
        embed_dim: int,
        vision_cfg: CLIPVisionCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None
):
    if isinstance(vision_cfg, dict):
        vision_cfg = CLIPVisionCfg(**vision_cfg)

    # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
    # memory efficient in recent PyTorch releases (>= 1.10).
    # NOTE: timm models always use native GELU regardless of quick_gelu flag.
    act_layer = QuickGELU if quick_gelu else nn.GELU

    if vision_cfg.timm_model_name:
        visual = TimmModel(
            vision_cfg.timm_model_name,
            pretrained=vision_cfg.timm_model_pretrained,
            pool=vision_cfg.timm_pool,
            proj=vision_cfg.timm_proj,
            proj_bias=vision_cfg.timm_proj_bias,
            drop=vision_cfg.timm_drop,
            drop_path=vision_cfg.timm_drop_path,
            patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None,
            embed_dim=embed_dim,
            image_size=vision_cfg.image_size,
        )
    elif isinstance(vision_cfg.layers, (tuple, list)):
        vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
        visual = ModifiedResNet(
            layers=vision_cfg.layers,
            output_dim=embed_dim,
            heads=vision_heads,
            image_size=vision_cfg.image_size,
            width=vision_cfg.width,
        )
    else:
        vision_heads = vision_cfg.width // vision_cfg.head_width
        norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
        if vision_cfg.norm_kwargs:
            norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs)
        if vision_cfg.act_kwargs is not None:
            act_layer = partial(act_layer, **vision_cfg.act_kwargs)

        visual = VisionTransformer(
            image_size=vision_cfg.image_size,
            patch_size=vision_cfg.patch_size,
            width=vision_cfg.width,
            layers=vision_cfg.layers,
            heads=vision_heads,
            mlp_ratio=vision_cfg.mlp_ratio,
            ls_init_value=vision_cfg.ls_init_value,
            patch_dropout=vision_cfg.patch_dropout,
            attentional_pool=vision_cfg.attentional_pool,
            attn_pooler_queries=vision_cfg.attn_pooler_queries,
            attn_pooler_heads=vision_cfg.attn_pooler_heads,
            pos_embed_type=vision_cfg.pos_embed_type,
            no_ln_pre=vision_cfg.no_ln_pre,
            final_ln_after_pool=vision_cfg.final_ln_after_pool,
            pool_type=vision_cfg.pool_type,
            output_tokens=vision_cfg.output_tokens,
            output_dim=embed_dim,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )

    return visual


def _build_text_tower(
        embed_dim: int,
        text_cfg: CLIPTextCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    if isinstance(text_cfg, dict):
        text_cfg = CLIPTextCfg(**text_cfg)

    if text_cfg.hf_model_name:
        text = HFTextEncoder(
            text_cfg.hf_model_name,
            output_dim=embed_dim,
            proj_type=text_cfg.hf_proj_type,
            pooler_type=text_cfg.hf_pooler_type,
            pretrained=text_cfg.hf_model_pretrained,
            output_tokens=text_cfg.output_tokens,
        )
    else:
        act_layer = QuickGELU if quick_gelu else nn.GELU
        norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
        if text_cfg.norm_kwargs:
            norm_layer = partial(norm_layer, **text_cfg.norm_kwargs)
        if text_cfg.act_kwargs is not None:
            act_layer = partial(act_layer, **text_cfg.act_kwargs)

        text = TextTransformer(
            context_length=text_cfg.context_length,
            vocab_size=text_cfg.vocab_size,
            width=text_cfg.width,
            heads=text_cfg.heads,
            layers=text_cfg.layers,
            mlp_ratio=text_cfg.mlp_ratio,
            ls_init_value=text_cfg.ls_init_value,
            output_dim=embed_dim,
            embed_cls=text_cfg.embed_cls,
            no_causal_mask=text_cfg.no_causal_mask,
            pad_id=text_cfg.pad_id,
            pool_type=text_cfg.pool_type,
            proj_type=text_cfg.proj_type,
            proj_bias=text_cfg.proj_bias,
            output_tokens=text_cfg.output_tokens,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )
    return text


class CLIP(nn.Module):
    output_dict: torch.jit.Final[bool]

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            init_logit_scale: float = np.log(1 / 0.07),
            init_logit_bias: Optional[float] = None,
            nonscalar_logit_scale: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        super().__init__()
        self.output_dict = output_dict

        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)

        text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.transformer = text.transformer
        self.context_length = text.context_length
        self.vocab_size = text.vocab_size
        self.token_embedding = text.token_embedding
        self.positional_embedding = text.positional_embedding
        self.ln_final = text.ln_final
        self.text_projection = text.text_projection
        self.text_pool_type = text.pool_type
        self.register_buffer('attn_mask', text.attn_mask, persistent=False)

        lshape = [1] if nonscalar_logit_scale else []
        self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
        if init_logit_bias is not None:
            self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
        else:
            self.logit_bias = None

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.visual.set_grad_checkpointing(enable)
        self.transformer.grad_checkpointing = enable

    @torch.jit.ignore
    def no_weight_decay(self):
        # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default
        no_wd = {'positional_embedding'}
        if hasattr(self.visual, 'no_weight_decay'):
            for n in self.visual.no_weight_decay():
                no_wd.add('visual.' + n)
        return no_wd

    def encode_image(self, image, normalize: bool = False):
        features = self.visual(image)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text, normalize: bool = False):
        cast_dtype = self.transformer.get_cast_dtype()

        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.to(cast_dtype)
        x = self.transformer(x, attn_mask=self.attn_mask)
        x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
        x, _ = text_global_pool(x, text, self.text_pool_type)
        if self.text_projection is not None:
            if isinstance(self.text_projection, nn.Linear):
                x = self.text_projection(x)
            else:
                x = x @ self.text_projection

        return F.normalize(x, dim=-1) if normalize else x

    def get_logits(self, image, text):
        image_features = self.encode_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True)
        image_logits = self.logit_scale.exp() * image_features @ text_features.T
        if self.logit_bias is not None:
            image_logits += self.logit_bias
        text_logits = image_logits.T
        return image_logits, text_logits

    def forward(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
    ):
        image_features = self.encode_image(image, normalize=True) if image is not None else None
        text_features = self.encode_text(text, normalize=True) if text is not None else None

        if self.output_dict:
            out_dict = {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": self.logit_scale.exp()
            }
            if self.logit_bias is not None:
                out_dict['logit_bias'] = self.logit_bias
            return out_dict

        if self.logit_bias is not None:
            return image_features, text_features, self.logit_scale.exp(), self.logit_bias
        return image_features, text_features, self.logit_scale.exp()


class CustomTextCLIP(nn.Module):
    output_dict: torch.jit.Final[bool]

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            init_logit_scale: float = np.log(1 / 0.07),
            init_logit_bias: Optional[float] = None,
            nonscalar_logit_scale: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        super().__init__()
        self.output_dict = output_dict
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
        self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.context_length = self.text.context_length
        self.vocab_size = self.text.vocab_size

        lshape = [1] if nonscalar_logit_scale else []
        self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
        if init_logit_bias is not None:
            self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
        else:
            self.logit_bias = None

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        self.text.lock(unlocked_layers, freeze_layer_norm)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)

    @torch.jit.ignore
    def no_weight_decay(self):
        # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default
        no_wd = set()
        if hasattr(self.visual, 'no_weight_decay'):
            for n in self.visual.no_weight_decay():
                no_wd.add('visual.' + n)
        if hasattr(self.text, 'no_weight_decay'):
            for n in self.text.no_weight_decay():
                no_wd.add('text.' + n)
        return no_wd

    def encode_image(self, image, normalize: bool = False):
        features = self.visual(image)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text, normalize: bool = False):
        features = self.text(text)
        return F.normalize(features, dim=-1) if normalize else features

    def get_logits(self, image, text):
        image_features = self.encode_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True)
        image_logits = self.logit_scale.exp() * image_features @ text_features.T
        if self.logit_bias is not None:
            image_logits += self.logit_bias
        text_logits = image_logits.T
        return image_logits, text_logits

    def forward(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
    ):
        image_features = self.encode_image(image, normalize=True) if image is not None else None
        text_features = self.encode_text(text, normalize=True) if text is not None else None

        if self.output_dict:
            out_dict = {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": self.logit_scale.exp()
            }
            if self.logit_bias is not None:
                out_dict['logit_bias'] = self.logit_bias
            return out_dict

        if self.logit_bias is not None:
            return image_features, text_features, self.logit_scale.exp(), self.logit_bias
        return image_features, text_features, self.logit_scale.exp()


def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
    """Convert applicable model parameters to low-precision (bf16 or fp16)"""

    def _convert_weights(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.to(dtype)
            if l.bias is not None:
                l.bias.data = l.bias.data.to(dtype)

        if isinstance(l, (nn.MultiheadAttention, Attention)):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.to(dtype)

        if isinstance(l, (CLIP, TextTransformer)):
            # convert text nn.Parameter projections
            attr = getattr(l, "text_projection", None)
            if attr is not None:
                attr.data = attr.data.to(dtype)

        if isinstance(l, VisionTransformer):
            # convert vision nn.Parameter projections
            attr = getattr(l, "proj", None)
            if attr is not None:
                attr.data = attr.data.to(dtype)

    model.apply(_convert_weights)


convert_weights_to_fp16 = convert_weights_to_lp  # backwards compat


# used to maintain checkpoint compatibility
def convert_to_custom_text_state_dict(state_dict: dict):
    if 'text_projection' in state_dict:
        # old format state_dict, move text tower -> .text
        new_state_dict = {}
        for k, v in state_dict.items():
            if any(k.startswith(p) for p in (
                'text_projection',
                'positional_embedding',
                'token_embedding',
                'transformer',
                'ln_final',
            )):
                k = 'text.' + k
            new_state_dict[k] = v
        return new_state_dict
    return state_dict


def build_model_from_openai_state_dict(
        state_dict: dict,
        quick_gelu=True,
        cast_dtype=torch.float16,
):
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len(
            [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_size = vision_patch_size * grid_size
    else:
        counts: list = [
            len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_size = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))

    vision_cfg = CLIPVisionCfg(
        layers=vision_layers,
        width=vision_width,
        patch_size=vision_patch_size,
        image_size=image_size,
    )
    text_cfg = CLIPTextCfg(
        context_length=context_length,
        vocab_size=vocab_size,
        width=transformer_width,
        heads=transformer_heads,
        layers=transformer_layers,
    )
    model = CLIP(
        embed_dim,
        vision_cfg=vision_cfg,
        text_cfg=text_cfg,
        quick_gelu=quick_gelu,  # OpenAI models were trained with QuickGELU
        cast_dtype=cast_dtype,
    )

    for key in ["input_resolution", "context_length", "vocab_size"]:
        state_dict.pop(key, None)
    convert_weights_to_fp16(model)  # OpenAI state dicts are partially converted to float16
    model.load_state_dict(state_dict)
    return model.eval()


def trace_model(model, batch_size=256, device=torch.device('cpu')):
    model.eval()
    image_size = model.visual.image_size
    example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
    example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
    model = torch.jit.trace_module(
        model,
        inputs=dict(
            forward=(example_images, example_text),
            encode_text=(example_text,),
            encode_image=(example_images,)
        ))
    model.visual.image_size = image_size
    return model


def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
    # Rescale the grid of position embeddings when loading from state_dict
    old_pos_embed = state_dict.get('visual.positional_embedding', None)
    if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
        return
    grid_size = to_2tuple(model.visual.grid_size)
    extra_tokens = 1  # FIXME detect different token configs (ie no class token, or more)
    new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
    if new_seq_len == old_pos_embed.shape[0]:
        return

    if extra_tokens:
        pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
    else:
        pos_emb_tok, pos_emb_img = None, old_pos_embed
    old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))

    logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
    pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
    pos_emb_img = F.interpolate(
        pos_emb_img,
        size=grid_size,
        mode=interpolation,
        antialias=antialias,
        align_corners=False,
    )
    pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
    if pos_emb_tok is not None:
        new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
    else:
        new_pos_embed = pos_emb_img
    state_dict['visual.positional_embedding'] = new_pos_embed


def resize_text_pos_embed(state_dict, model, interpolation: str = 'linear', antialias: bool = False):
    old_pos_embed = state_dict.get('positional_embedding', None)
    if old_pos_embed is None:
        return
    # FIXME add support for text cls_token
    model_pos_embed = getattr(model, 'positional_embedding', None)
    if model_pos_embed is None:
        model_pos_embed = getattr(model.text, 'positional_embedding', None)

    old_num_pos = old_pos_embed.shape[0]
    old_width = old_pos_embed.shape[1]
    num_pos = model_pos_embed.shape[0]
    width = model_pos_embed.shape[1]
    assert old_width == width, 'text pos_embed width changed!'
    if old_num_pos == num_pos:
        return

    logging.info('Resizing text position embedding num_pos from %s to %s', old_num_pos, num_pos)
    old_pos_embed = old_pos_embed.reshape(1, old_num_pos, old_width).permute(0, 2, 1)
    old_pos_embed = F.interpolate(
        old_pos_embed,
        size=num_pos,
        mode=interpolation,
        antialias=antialias,
        align_corners=False,
    )
    old_pos_embed = old_pos_embed.permute(0, 2, 1)[0]
    new_pos_embed = old_pos_embed

    state_dict['positional_embedding'] = new_pos_embed


def get_model_preprocess_cfg(model):
    module = getattr(model, 'visual', model)
    preprocess_cfg = getattr(module, 'preprocess_cfg', {})
    if not preprocess_cfg:
        # use separate legacy attributes if preprocess_cfg dict not found
        size = getattr(module, 'image_size')
        if size is not None:
            preprocess_cfg['size'] = size
        mean = getattr(module, 'image_mean', None)
        if mean is not None:
            preprocess_cfg['mean'] = mean
        std = getattr(module, 'image_std', None)
        if std is not None:
            preprocess_cfg['std'] = std
    return preprocess_cfg


def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]):
    module = getattr(model, 'visual', model)
    module.image_mean = preprocess_cfg['mean']  # legacy attribute, keeping for bwd compat
    module.image_std = preprocess_cfg['std']  # legacy attribute, keeping for bwd compat
    module.preprocess_cfg = copy.deepcopy(preprocess_cfg)  # new attr, package all pp cfg as dict


def get_model_tokenize_cfg(model):
    module = getattr(model, 'text', model)
    cfg = {}
    context_length = getattr(module, 'context_length', None)
    if context_length is not None:
        cfg['context_length'] = context_length
    vocab_size = getattr(module, 'vocab_size', None)
    if vocab_size is not None:
        cfg['vocab_size'] = vocab_size
    return cfg
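As a quick orientation to the towers and logit heads defined in model.py above, here is a small, hedged sketch that builds a toy `CLIP` directly from the dataclass configs and runs `get_logits`. The tiny widths and sizes are made-up values for illustration, not any shipped configuration.

```python
# Illustrative only: a tiny CLIP built straight from CLIPVisionCfg/CLIPTextCfg.
# All sizes here are arbitrary toy values, not a real pretrained model config.
import torch
from open_clip.model import CLIP, CLIPVisionCfg, CLIPTextCfg

model = CLIP(
    embed_dim=64,
    vision_cfg=CLIPVisionCfg(layers=2, width=64, head_width=32, patch_size=16, image_size=64),
    text_cfg=CLIPTextCfg(context_length=16, vocab_size=1000, width=64, heads=4, layers=2),
)

images = torch.randn(2, 3, 64, 64)
texts = torch.randint(0, 1000, (2, 16))

with torch.no_grad():
    image_logits, text_logits = model.get_logits(images, texts)  # scaled cosine similarities
    outputs = model(image=images, text=texts)                    # tuple unless output_dict=True
print(image_logits.shape)  # torch.Size([2, 2])
```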
open_clip/src/open_clip/model_configs/EVA01-g-14-plus.json
ADDED
@@ -0,0 +1,18 @@
{
    "embed_dim": 1024,
    "vision_cfg": {
        "image_size": 224,
        "timm_model_name": "eva_giant_patch14_224",
        "timm_model_pretrained": false,
        "timm_pool": "token",
        "timm_proj": null
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1024,
        "heads": 16,
        "layers": 24
    },
    "custom_text": true
}
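The model_configs JSON files in this upload map directly onto the dataclasses in model.py: the `vision_cfg` and `text_cfg` blocks become `CLIPVisionCfg` / `CLIPTextCfg` kwargs, a `timm_model_name` routes `_build_vision_tower` to `TimmModel`, and `custom_text: true` selects `CustomTextCLIP`. A hedged sketch of that wiring follows; in practice `open_clip`'s factory performs this step (plus pretrained weight loading), which is not shown in this diff.

```python
# Sketch of how a model_configs/*.json file could be turned into a model.
# Assumes the open_clip package from this upload is importable and, for
# timm_model_name-based vision towers, that `timm` is installed.
import json
from open_clip.model import CLIP, CustomTextCLIP

with open("open_clip/src/open_clip/model_configs/EVA01-g-14-plus.json") as f:
    cfg = json.load(f)

model_cls = CustomTextCLIP if cfg.pop("custom_text", False) else CLIP
quick_gelu = cfg.pop("quick_gelu", False)

model = model_cls(
    embed_dim=cfg["embed_dim"],
    vision_cfg=cfg["vision_cfg"],   # dicts are coerced to CLIPVisionCfg/CLIPTextCfg
    text_cfg=cfg["text_cfg"],
    quick_gelu=quick_gelu,
)
```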
open_clip/src/open_clip/model_configs/EVA02-B-16.json
ADDED
@@ -0,0 +1,18 @@
{
    "embed_dim": 512,
    "vision_cfg": {
        "image_size": 224,
        "timm_model_name": "eva02_base_patch16_clip_224",
        "timm_model_pretrained": false,
        "timm_pool": "token",
        "timm_proj": null
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    },
    "custom_text": true
}
open_clip/src/open_clip/model_configs/EVA02-E-14-plus.json
ADDED
@@ -0,0 +1,18 @@
{
    "embed_dim": 1024,
    "vision_cfg": {
        "image_size": 224,
        "timm_model_name": "eva02_enormous_patch14_clip_224",
        "timm_model_pretrained": false,
        "timm_pool": "token",
        "timm_proj": null
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1280,
        "heads": 20,
        "layers": 32
    },
    "custom_text": true
}
open_clip/src/open_clip/model_configs/EVA02-L-14-336.json
ADDED
@@ -0,0 +1,18 @@
{
    "embed_dim": 768,
    "vision_cfg": {
        "image_size": 336,
        "timm_model_name": "eva02_large_patch14_clip_336",
        "timm_model_pretrained": false,
        "timm_pool": "token",
        "timm_proj": null
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 12
    },
    "custom_text": true
}
open_clip/src/open_clip/model_configs/RN50x16-quickgelu.json
ADDED
@@ -0,0 +1,22 @@
{
    "embed_dim": 768,
    "quick_gelu": true,
    "vision_cfg": {
        "image_size": 384,
        "layers": [
            6,
            8,
            18,
            8
        ],
        "width": 96,
        "patch_size": null
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 12
    }
}
open_clip/src/open_clip/model_configs/RN50x4-quickgelu.json
ADDED
@@ -0,0 +1,22 @@
{
    "embed_dim": 640,
    "quick_gelu": true,
    "vision_cfg": {
        "image_size": 288,
        "layers": [
            4,
            6,
            10,
            6
        ],
        "width": 80,
        "patch_size": null
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 640,
        "heads": 10,
        "layers": 12
    }
}
open_clip/src/open_clip/model_configs/RN50x4.json
ADDED
@@ -0,0 +1,21 @@
{
    "embed_dim": 640,
    "vision_cfg": {
        "image_size": 288,
        "layers": [
            4,
            6,
            10,
            6
        ],
        "width": 80,
        "patch_size": null
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 640,
        "heads": 10,
        "layers": 12
    }
}
open_clip/src/open_clip/model_configs/RN50x64.json
ADDED
@@ -0,0 +1,21 @@
{
    "embed_dim": 1024,
    "vision_cfg": {
        "image_size": 448,
        "layers": [
            3,
            15,
            36,
            10
        ],
        "width": 128,
        "patch_size": null
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1024,
        "heads": 16,
        "layers": 12
    }
}
open_clip/src/open_clip/model_configs/ViT-B-16-SigLIP-256.json
ADDED
@@ -0,0 +1,29 @@
{
    "embed_dim": 768,
    "init_logit_bias": -10,
    "custom_text": true,
    "vision_cfg": {
        "image_size": 256,
        "timm_model_name": "vit_base_patch16_siglip_256",
        "timm_model_pretrained": false,
        "timm_pool": "map",
        "timm_proj": "none"
    },
    "text_cfg": {
        "context_length": 64,
        "vocab_size": 32000,
        "hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
        "tokenizer_kwargs": {
            "clean": "canonicalize"
        },
        "width": 768,
        "heads": 12,
        "layers": 12,
        "no_causal_mask": true,
        "proj_bias": true,
        "pool_type": "last",
        "norm_kwargs": {
            "eps": 1e-6
        }
    }
}
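Unlike the other configs here, ViT-B-16-SigLIP-256 sets `init_logit_bias`, which `CLIP`/`CustomTextCLIP` turn into the extra `logit_bias` parameter added in `get_logits` and returned by `forward`. Below is a hedged sketch of how that bias would typically be consumed by a pairwise sigmoid (SigLIP-style) loss; the loss function itself is an assumption for illustration and is not part of this diff.

```python
# Hedged sketch: pairwise sigmoid loss using logit_scale/logit_bias as produced by
# CLIP.get_logits / forward(output_dict=True). Not the repository's loss implementation.
import torch
import torch.nn.functional as F

def sigmoid_contrastive_loss(image_features, text_features, logit_scale, logit_bias):
    # logits[i, j] = scaled similarity between image i and text j, plus the learned bias
    logits = logit_scale * image_features @ text_features.T + logit_bias
    n = logits.shape[0]
    # +1 on the diagonal (matched pairs), -1 elsewhere (mismatched pairs)
    labels = 2 * torch.eye(n, device=logits.device) - 1
    return -F.logsigmoid(labels * logits).sum() / n

# image_features / text_features are assumed L2-normalized, as encode_image/encode_text
# return when called with normalize=True.
```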
open_clip/src/open_clip/model_configs/ViT-B-16-plus-240.json
ADDED
@@ -0,0 +1,16 @@
{
    "embed_dim": 640,
    "vision_cfg": {
        "image_size": 240,
        "layers": 12,
        "width": 896,
        "patch_size": 16
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 640,
        "heads": 10,
        "layers": 12
    }
}
open_clip/src/open_clip/model_configs/ViT-B-16-plus.json
ADDED
@@ -0,0 +1,16 @@
{
    "embed_dim": 640,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 896,
        "patch_size": 16
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 640,
        "heads": 10,
        "layers": 12
    }
}
open_clip/src/open_clip/model_configs/ViT-L-14-CLIPA-336.json
ADDED
@@ -0,0 +1,25 @@
{
    "embed_dim": 768,
    "vision_cfg": {
        "image_size": 336,
        "layers": 24,
        "width": 1024,
        "patch_size": 14,
        "no_ln_pre": true,
        "pool_type": "avg",
        "final_ln_after_pool": true
    },
    "text_cfg": {
        "context_length": 32,
        "vocab_size": 32000,
        "hf_tokenizer_name": "bert-base-uncased",
        "tokenizer_kwargs": {
            "strip_sep_token": true
        },
        "width": 768,
        "heads": 12,
        "layers": 12,
        "pool_type": "last",
        "no_causal_mask": true
    }
}
open_clip/src/open_clip/model_configs/ViT-L-14-CLIPA.json
ADDED
@@ -0,0 +1,25 @@
{
    "embed_dim": 768,
    "vision_cfg": {
        "image_size": 224,
        "layers": 24,
        "width": 1024,
        "patch_size": 14,
        "no_ln_pre": true,
        "pool_type": "avg",
        "final_ln_after_pool": true
    },
    "text_cfg": {
        "context_length": 32,
        "vocab_size": 32000,
        "hf_tokenizer_name": "bert-base-uncased",
        "tokenizer_kwargs": {
            "strip_sep_token": true
        },
        "width": 768,
        "heads": 12,
        "layers": 12,
        "pool_type": "last",
        "no_causal_mask": true
    }
}