tuandunghcmut committed on
Commit e496ed3 · verified · 1 Parent(s): c8efe89

Upload HAPTransReID

Files changed (4)
  1. README.md +199 -0
  2. config.json +38 -0
  3. hap_transreid.py +926 -0
  4. model.safetensors +3 -0
README.md ADDED
@@ -0,0 +1,199 @@
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
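
The "How to Get Started with the Model" section of this README is still a placeholder. A minimal loading sketch, assuming a hypothetical repo id `tuandunghcmut/HAPTransReID` and relying on the `auto_map` entries in the `config.json` added below (unverified here; `AutoBackbone` is the only model class registered in that map):

```python
import torch
from transformers import AutoConfig, AutoBackbone

# Hypothetical repo id: replace with the actual id of this upload.
repo_id = "tuandunghcmut/HAPTransReID"

# trust_remote_code is required because config.json maps AutoConfig/AutoBackbone
# to the custom hap_transreid.py shipped in this commit.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoBackbone.from_pretrained(repo_id, trust_remote_code=True).eval()

# config.json declares img_size = [256, 128] (height x width) and 3 channels.
pixel_values = torch.randn(1, 3, 256, 128)
with torch.no_grad():
    tokens = model(pixel_values)

# With local_feature=true the full token sequence is returned:
# 16 x 8 patches + 1 cls token = 129 tokens of width 768.
print(tokens.shape)  # expected: torch.Size([1, 129, 768])
```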
config.json ADDED
@@ -0,0 +1,38 @@
1
+ {
2
+ "architectures": [
3
+ "HAPTransReID"
4
+ ],
5
+ "attn_drop_rate": 0.0,
6
+ "auto_map": {
7
+ "AutoBackbone": "hap_transreid.HAPTransReID",
8
+ "AutoConfig": "hap_transreid.HAPTransReIDConfig"
9
+ },
10
+ "camera": 0,
11
+ "depth": 12,
12
+ "drop_path_rate": 0.1,
13
+ "drop_rate": 0.0,
14
+ "embed_dim": 768,
15
+ "hybrid_backbone": null,
16
+ "img_size": [
17
+ 256,
18
+ 128
19
+ ],
20
+ "in_chans": 3,
21
+ "local_feature": true,
22
+ "mlp_ratio": 4.0,
23
+ "model_type": "my-vit-b16",
24
+ "norm_layer_eps": 1e-06,
25
+ "num_classes": -1,
26
+ "num_heads": 12,
27
+ "patch_size": 16,
28
+ "qk_scale": null,
29
+ "qkv_bias": false,
30
+ "sie_xishu": 3.0,
31
+ "stride_size": [
32
+ 16,
33
+ 16
34
+ ],
35
+ "torch_dtype": "float32",
36
+ "transformers_version": "4.42.4",
37
+ "view": 0
38
+ }
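
For reference, the geometry fields above fix the backbone's sequence length. A quick sketch of the arithmetic, mirroring `PatchEmbed_overlap` in the `hap_transreid.py` file added below:

```python
# Patch-grid arithmetic implied by config.json
# (same formula as PatchEmbed_overlap.num_y / num_x).
img_h, img_w = 256, 128        # img_size
patch, stride = 16, 16         # patch_size, stride_size

num_y = (img_h - patch) // stride + 1   # 16
num_x = (img_w - patch) // stride + 1   # 8
num_patches = num_y * num_x             # 128
seq_len = num_patches + 1               # + cls token -> 129

print(num_y, num_x, num_patches, seq_len)  # 16 8 128 129
```

With `embed_dim` 768 and `local_feature` true, the backbone therefore returns a `[batch, 129, 768]` token sequence.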
hap_transreid.py ADDED
@@ -0,0 +1,926 @@
1
+ """ Vision Transformer (ViT) in PyTorch
2
+
3
+ A PyTorch implement of Vision Transformers as described in
4
+ 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929
5
+
6
+ The official jax code is released and available at https://github.com/google-research/vision_transformer
7
+
8
+ Status/TODO:
9
+ * Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights.
10
+ * Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches.
11
+ * Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code.
12
+ * Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future.
13
+
14
+ Acknowledgments:
15
+ * The paper authors for releasing code and weights, thanks!
16
+ * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
17
+ for some einops/einsum fun
18
+ * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
19
+ * Bert reference code checks against Huggingface Transformers and Tensorflow Bert
20
+
21
+ Hacked together by / Copyright 2020 Ross Wightman
22
+ """
23
+
24
+ from transformers import (
25
+ PreTrainedModel,
26
+ PretrainedConfig,
27
+ AutoConfig,
28
+ AutoModel,
29
+ AutoModelForImageClassification,
30
+ )
31
+
32
+ import math
33
+ from functools import partial
34
+ from itertools import repeat
35
+
36
+ import torch
37
+ import torch.nn as nn
38
+ import torch.nn.functional as F
39
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
40
+
41
+ TORCH_MAJOR = int(torch.__version__.split(".")[0])
42
+ TORCH_MINOR = int(torch.__version__.split(".")[1])
43
+ if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
44
+ from torch._six import container_abcs, int_classes
45
+ else:
46
+ import collections.abc as container_abcs
47
+
48
+ int_classes = int
49
+
50
+
51
+ # From PyTorch internals
52
+ def _ntuple(n):
53
+ def parse(x):
54
+ if isinstance(x, container_abcs.Iterable):
55
+ return x
56
+ return tuple(repeat(x, n))
57
+
58
+ return parse
59
+
60
+
61
+ to_2tuple = _ntuple(2)
62
+
63
+
64
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
65
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
66
+
67
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
68
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
69
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
70
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
71
+ 'survival rate' as the argument.
72
+
73
+ """
74
+ if drop_prob == 0.0 or not training:
75
+ return x
76
+ keep_prob = 1 - drop_prob
77
+ shape = (x.shape[0],) + (1,) * (
78
+ x.ndim - 1
79
+ ) # work with diff dim tensors, not just 2D ConvNets
80
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
81
+ random_tensor.floor_() # binarize
82
+ output = x.div(keep_prob) * random_tensor
83
+ return output
84
+
85
+
86
+ class DropPath(nn.Module):
87
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
88
+
89
+ def __init__(self, drop_prob=None):
90
+ super(DropPath, self).__init__()
91
+ self.drop_prob = drop_prob
92
+
93
+ def forward(self, x):
94
+ return drop_path(x, self.drop_prob, self.training)
95
+
96
+
97
+ def _cfg(url="", **kwargs):
98
+ return {
99
+ "url": url,
100
+ "num_classes": 1000,
101
+ "input_size": (3, 224, 224),
102
+ "pool_size": None,
103
+ "crop_pct": 0.9,
104
+ "interpolation": "bicubic",
105
+ "mean": IMAGENET_DEFAULT_MEAN,
106
+ "std": IMAGENET_DEFAULT_STD,
107
+ "first_conv": "patch_embed.proj",
108
+ "classifier": "head",
109
+ **kwargs,
110
+ }
111
+
112
+
113
+ default_cfgs = {
114
+ # patch models
115
+ "vit_small_patch16_224": _cfg(
116
+ url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth",
117
+ ),
118
+ "vit_base_patch16_224": _cfg(
119
+ url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth",
120
+ mean=(0.5, 0.5, 0.5),
121
+ std=(0.5, 0.5, 0.5),
122
+ ),
123
+ "vit_base_patch16_384": _cfg(
124
+ url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth",
125
+ input_size=(3, 384, 384),
126
+ mean=(0.5, 0.5, 0.5),
127
+ std=(0.5, 0.5, 0.5),
128
+ crop_pct=1.0,
129
+ ),
130
+ "vit_base_patch32_384": _cfg(
131
+ url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth",
132
+ input_size=(3, 384, 384),
133
+ mean=(0.5, 0.5, 0.5),
134
+ std=(0.5, 0.5, 0.5),
135
+ crop_pct=1.0,
136
+ ),
137
+ "vit_large_patch16_224": _cfg(
138
+ url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth",
139
+ mean=(0.5, 0.5, 0.5),
140
+ std=(0.5, 0.5, 0.5),
141
+ ),
142
+ "vit_large_patch16_384": _cfg(
143
+ url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth",
144
+ input_size=(3, 384, 384),
145
+ mean=(0.5, 0.5, 0.5),
146
+ std=(0.5, 0.5, 0.5),
147
+ crop_pct=1.0,
148
+ ),
149
+ "vit_large_patch32_384": _cfg(
150
+ url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth",
151
+ input_size=(3, 384, 384),
152
+ mean=(0.5, 0.5, 0.5),
153
+ std=(0.5, 0.5, 0.5),
154
+ crop_pct=1.0,
155
+ ),
156
+ "vit_huge_patch16_224": _cfg(),
157
+ "vit_huge_patch32_384": _cfg(input_size=(3, 384, 384)),
158
+ # hybrid models
159
+ "vit_small_resnet26d_224": _cfg(),
160
+ "vit_small_resnet50d_s3_224": _cfg(),
161
+ "vit_base_resnet26d_224": _cfg(),
162
+ "vit_base_resnet50d_224": _cfg(),
163
+ }
164
+
165
+
166
+ class Mlp(nn.Module):
167
+ def __init__(
168
+ self,
169
+ in_features,
170
+ hidden_features=None,
171
+ out_features=None,
172
+ act_layer=nn.GELU,
173
+ drop=0.0,
174
+ ):
175
+ super().__init__()
176
+ out_features = out_features or in_features
177
+ hidden_features = hidden_features or in_features
178
+ self.fc1 = nn.Linear(in_features, hidden_features)
179
+ self.act = act_layer()
180
+ self.fc2 = nn.Linear(hidden_features, out_features)
181
+ self.drop = nn.Dropout(drop)
182
+
183
+ def forward(self, x):
184
+ x = self.fc1(x)
185
+ x = self.act(x)
186
+ x = self.drop(x)
187
+ x = self.fc2(x)
188
+ x = self.drop(x)
189
+ return x
190
+
191
+
192
+ class Attention(nn.Module):
193
+ def __init__(
194
+ self,
195
+ dim,
196
+ num_heads=8,
197
+ qkv_bias=False,
198
+ qk_scale=None,
199
+ attn_drop=0.0,
200
+ proj_drop=0.0,
201
+ ):
202
+ super().__init__()
203
+ self.num_heads = num_heads
204
+ head_dim = dim // num_heads
205
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
206
+ self.scale = qk_scale or head_dim**-0.5
207
+
208
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
209
+ self.attn_drop = nn.Dropout(attn_drop)
210
+ self.proj = nn.Linear(dim, dim)
211
+ self.proj_drop = nn.Dropout(proj_drop)
212
+
213
+ def forward(self, x):
214
+ B, N, C = x.shape
215
+ qkv = (
216
+ self.qkv(x)
217
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
218
+ .permute(2, 0, 3, 1, 4)
219
+ )
220
+ q, k, v = (
221
+ qkv[0],
222
+ qkv[1],
223
+ qkv[2],
224
+ ) # make torchscript happy (cannot use tensor as tuple)
225
+
226
+ attn = (q @ k.transpose(-2, -1)) * self.scale
227
+ attn = attn.softmax(dim=-1)
228
+ attn = self.attn_drop(attn)
229
+
230
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
231
+ x = self.proj(x)
232
+ x = self.proj_drop(x)
233
+ return x
234
+
235
+
236
+ class Block(nn.Module):
237
+ def __init__(
238
+ self,
239
+ dim,
240
+ num_heads,
241
+ mlp_ratio=4.0,
242
+ qkv_bias=False,
243
+ qk_scale=None,
244
+ drop=0.0,
245
+ attn_drop=0.0,
246
+ drop_path=0.0,
247
+ act_layer=nn.GELU,
248
+ norm_layer=nn.LayerNorm,
249
+ ):
250
+ super().__init__()
251
+ self.norm1 = norm_layer(dim)
252
+ self.attn = Attention(
253
+ dim,
254
+ num_heads=num_heads,
255
+ qkv_bias=qkv_bias,
256
+ qk_scale=qk_scale,
257
+ attn_drop=attn_drop,
258
+ proj_drop=drop,
259
+ )
260
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
261
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
262
+ self.norm2 = norm_layer(dim)
263
+ mlp_hidden_dim = int(dim * mlp_ratio)
264
+ self.mlp = Mlp(
265
+ in_features=dim,
266
+ hidden_features=mlp_hidden_dim,
267
+ act_layer=act_layer,
268
+ drop=drop,
269
+ )
270
+
271
+ def forward(self, x):
272
+ x = x + self.drop_path(self.attn(self.norm1(x)))
273
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
274
+ return x
275
+
276
+
277
+ class PatchEmbed(nn.Module):
278
+ """Image to Patch Embedding"""
279
+
280
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
281
+ super().__init__()
282
+ img_size = to_2tuple(img_size)
283
+ patch_size = to_2tuple(patch_size)
284
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
285
+ self.img_size = img_size
286
+ self.patch_size = patch_size
287
+ self.num_patches = num_patches
288
+
289
+ self.proj = nn.Conv2d(
290
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
291
+ )
292
+
293
+ def forward(self, x):
294
+ B, C, H, W = x.shape
295
+ # FIXME look at relaxing size constraints
296
+ assert (
297
+ H == self.img_size[0] and W == self.img_size[1]
298
+ ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
299
+ x = self.proj(x).flatten(2).transpose(1, 2)
300
+ return x
301
+
302
+
303
+ class HybridEmbed(nn.Module):
304
+ """CNN Feature Map Embedding
305
+ Extract feature map from CNN, flatten, project to embedding dim.
306
+ """
307
+
308
+ def __init__(
309
+ self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768
310
+ ):
311
+ super().__init__()
312
+ assert isinstance(backbone, nn.Module)
313
+ img_size = to_2tuple(img_size)
314
+ self.img_size = img_size
315
+ self.backbone = backbone
316
+ if feature_size is None:
317
+ with torch.no_grad():
318
+ # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
319
+ # map for all networks, the feature metadata has reliable channel and stride info, but using
320
+ # stride to calc feature dim requires info about padding of each stage that isn't captured.
321
+ training = backbone.training
322
+ if training:
323
+ backbone.eval()
324
+ o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
325
+ if isinstance(o, (list, tuple)):
326
+ o = o[-1] # last feature if backbone outputs list/tuple of features
327
+ feature_size = o.shape[-2:]
328
+ feature_dim = o.shape[1]
329
+ backbone.train(training)
330
+ else:
331
+ feature_size = to_2tuple(feature_size)
332
+ if hasattr(self.backbone, "feature_info"):
333
+ feature_dim = self.backbone.feature_info.channels()[-1]
334
+ else:
335
+ feature_dim = self.backbone.num_features
336
+ self.num_patches = feature_size[0] * feature_size[1]
337
+ self.proj = nn.Conv2d(feature_dim, embed_dim, 1)
338
+
339
+ def forward(self, x):
340
+ x = self.backbone(x)
341
+ if isinstance(x, (list, tuple)):
342
+ x = x[-1] # last feature if backbone outputs list/tuple of features
343
+ x = self.proj(x).flatten(2).transpose(1, 2)
344
+ return x
345
+
346
+
347
+ class PatchEmbed_overlap(nn.Module):
348
+ """Image to Patch Embedding with overlapping patches"""
349
+
350
+ def __init__(
351
+ self, img_size=224, patch_size=16, stride_size=20, in_chans=3, embed_dim=768
352
+ ):
353
+ super().__init__()
354
+ img_size = to_2tuple(img_size)
355
+ patch_size = to_2tuple(patch_size)
356
+ stride_size_tuple = to_2tuple(stride_size)
357
+ self.num_x = (img_size[1] - patch_size[1]) // stride_size_tuple[1] + 1
358
+ self.num_y = (img_size[0] - patch_size[0]) // stride_size_tuple[0] + 1
359
+ print(
360
+ "using stride: {}, and patch number is num_y{} * num_x{}".format(
361
+ stride_size, self.num_y, self.num_x
362
+ )
363
+ )
364
+ num_patches = self.num_x * self.num_y
365
+ self.img_size = img_size
366
+ self.patch_size = patch_size
367
+ self.num_patches = num_patches
368
+
369
+ self.proj = nn.Conv2d(
370
+ in_chans, embed_dim, kernel_size=patch_size, stride=stride_size
371
+ )
372
+ for m in self.modules():
373
+ if isinstance(m, nn.Conv2d):
374
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
375
+ m.weight.data.normal_(0, math.sqrt(2.0 / n))
376
+ elif isinstance(m, nn.BatchNorm2d):
377
+ m.weight.data.fill_(1)
378
+ m.bias.data.zero_()
379
+ elif isinstance(m, nn.InstanceNorm2d):
380
+ m.weight.data.fill_(1)
381
+ m.bias.data.zero_()
382
+
383
+ def forward(self, x):
384
+ B, C, H, W = x.shape
385
+
386
+ # FIXME look at relaxing size constraints
387
+ assert (
388
+ H == self.img_size[0] and W == self.img_size[1]
389
+ ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
390
+ x = self.proj(x)
391
+
392
+ x = x.flatten(2).transpose(1, 2) # [64, 8, 768]
393
+ return x
394
+
395
+
396
+ class TransReID(nn.Module):
397
+ """Transformer-based Object Re-Identification"""
398
+
399
+ @classmethod
400
+ def from_config(cls, config):
401
+ return cls(
402
+ img_size=config.get("img_size", [384, 128]),
403
+ patch_size=config.get("patch_size", 16),
404
+ stride_size=config.get("stride_size", 16),
405
+ in_chans=config.get("in_chans", 3),
406
+ num_classes=config.get("num_classes", 1000),
407
+ embed_dim=config.get("embed_dim", 768),
408
+ depth=config.get("depth", 12),
409
+ num_heads=config.get("num_heads", 12),
410
+ mlp_ratio=config.get("mlp_ratio", 4.0),
411
+ qkv_bias=config.get("qkv_bias", False),
412
+ qk_scale=config.get("qk_scale", None),
413
+ drop_rate=config.get("drop_rate", 0.0),
414
+ attn_drop_rate=config.get("attn_drop_rate", 0.0),
415
+ drop_path_rate=config.get("drop_path_rate", 0.0),
416
+ camera=config.get("camera", 0),
417
+ view=config.get("view", 0),
418
+ local_feature=config.get("local_feature", False),
419
+ sie_xishu=config.get("sie_xishu", 1.0),
420
+ )
421
+
422
+ def __init__(
423
+ self,
424
+ img_size=224,
425
+ patch_size=16,
426
+ stride_size=16,
427
+ in_chans=3,
428
+ num_classes=1000,
429
+ embed_dim=768,
430
+ depth=12,
431
+ num_heads=12,
432
+ mlp_ratio=4.0,
433
+ qkv_bias=False,
434
+ qk_scale=None,
435
+ drop_rate=0.0,
436
+ attn_drop_rate=0.0,
437
+ camera=0,
438
+ view=0,
439
+ drop_path_rate=0.0,
440
+ hybrid_backbone=None,
441
+ norm_layer=nn.LayerNorm,
442
+ local_feature=False,
443
+ sie_xishu=1.0,
444
+ ):
445
+ nn.Module.__init__(self)
446
+
447
+ self.num_classes = num_classes
448
+ self.num_features = self.embed_dim = (
449
+ embed_dim # num_features for consistency with other models
450
+ )
451
+ self.local_feature = local_feature
452
+ if hybrid_backbone is not None:
453
+ self.patch_embed = HybridEmbed(
454
+ hybrid_backbone,
455
+ img_size=img_size,
456
+ in_chans=in_chans,
457
+ embed_dim=embed_dim,
458
+ )
459
+ else:
460
+ self.patch_embed = PatchEmbed_overlap(
461
+ img_size=img_size,
462
+ patch_size=patch_size,
463
+ stride_size=stride_size,
464
+ in_chans=in_chans,
465
+ embed_dim=embed_dim,
466
+ )
467
+
468
+ num_patches = self.patch_embed.num_patches
469
+
470
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
471
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
472
+ self.cam_num = camera
473
+ self.view_num = view
474
+ self.sie_xishu = sie_xishu
475
+ # Initialize SIE Embedding
476
+ if camera > 1 and view > 1:
477
+ self.sie_embed = nn.Parameter(torch.zeros(camera * view, 1, embed_dim))
478
+ trunc_normal_(self.sie_embed, std=0.02)
479
+ print(
480
+ "camera number is : {} and viewpoint number is : {}".format(
481
+ camera, view
482
+ )
483
+ )
484
+ print("using SIE_Lambda is : {}".format(sie_xishu))
485
+ elif camera > 1:
486
+ self.sie_embed = nn.Parameter(torch.zeros(camera, 1, embed_dim))
487
+ trunc_normal_(self.sie_embed, std=0.02)
488
+ print("camera number is : {}".format(camera))
489
+ print("using SIE_Lambda is : {}".format(sie_xishu))
490
+ elif view > 1:
491
+ self.sie_embed = nn.Parameter(torch.zeros(view, 1, embed_dim))
492
+ trunc_normal_(self.sie_embed, std=0.02)
493
+ print("viewpoint number is : {}".format(view))
494
+ print("using SIE_Lambda is : {}".format(sie_xishu))
495
+
496
+ print("using drop_out rate is : {}".format(drop_rate))
497
+ print("using attn_drop_out rate is : {}".format(attn_drop_rate))
498
+ print("using drop_path rate is : {}".format(drop_path_rate))
499
+
500
+ self.pos_drop = nn.Dropout(p=drop_rate)
501
+ dpr = [
502
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
503
+ ] # stochastic depth decay rule
504
+
505
+ self.blocks = nn.ModuleList(
506
+ [
507
+ Block(
508
+ dim=embed_dim,
509
+ num_heads=num_heads,
510
+ mlp_ratio=mlp_ratio,
511
+ qkv_bias=qkv_bias,
512
+ qk_scale=qk_scale,
513
+ drop=drop_rate,
514
+ attn_drop=attn_drop_rate,
515
+ drop_path=dpr[i],
516
+ norm_layer=norm_layer,
517
+ )
518
+ for i in range(depth)
519
+ ]
520
+ )
521
+
522
+ self.norm = norm_layer(embed_dim)
523
+
524
+ # # Classifier head
525
+ self.fc = (
526
+ nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
527
+ )
528
+ trunc_normal_(self.cls_token, std=0.02)
529
+ trunc_normal_(self.pos_embed, std=0.02)
530
+
531
+ self.apply(self._init_weights)
532
+
533
+ def _init_weights(self, m):
534
+ if isinstance(m, nn.Linear):
535
+ trunc_normal_(m.weight, std=0.02)
536
+ if isinstance(m, nn.Linear) and m.bias is not None:
537
+ nn.init.constant_(m.bias, 0)
538
+ elif isinstance(m, nn.LayerNorm):
539
+ nn.init.constant_(m.bias, 0)
540
+ nn.init.constant_(m.weight, 1.0)
541
+
542
+ @torch.jit.ignore
543
+ def no_weight_decay(self):
544
+ return {"pos_embed", "cls_token"}
545
+
546
+ def get_classifier(self):
547
+ return self.fc  # the classifier is stored as self.fc; this class defines no self.head
548
+
549
+ def reset_classifier(self, num_classes, global_pool=""):
550
+ self.num_classes = num_classes
551
+ self.fc = (
552
+ nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
553
+ )
554
+
555
+ def forward_features(self, x, camera_id, view_id):
556
+ B = x.shape[0]
557
+ x = self.patch_embed(x)
558
+
559
+ cls_tokens = self.cls_token.expand(
560
+ B, -1, -1
561
+ ) # stole cls_tokens impl from Phil Wang, thanks
562
+ x = torch.cat((cls_tokens, x), dim=1)
563
+
564
+ if self.cam_num > 0 and self.view_num > 0:
565
+ x = (
566
+ x
567
+ + self.pos_embed
568
+ + self.sie_xishu * self.sie_embed[camera_id * self.view_num + view_id]
569
+ )
570
+ elif self.cam_num > 0:
571
+ x = x + self.pos_embed + self.sie_xishu * self.sie_embed[camera_id]
572
+ elif self.view_num > 0:
573
+ x = x + self.pos_embed + self.sie_xishu * self.sie_embed[view_id]
574
+ else:
575
+ x = x + self.pos_embed
576
+
577
+ x = self.pos_drop(x)
578
+
579
+ if self.local_feature:
580
+ for blk in self.blocks:
581
+ x = blk(x)
582
+
583
+ x = self.norm(x)
584
+
585
+ return x
586
+
587
+ else:
588
+ for blk in self.blocks:
589
+ x = blk(x)
590
+
591
+ x = self.norm(x)
592
+
593
+ return x[:, 0]
594
+
595
+ def forward(self, x, cam_label=None, view_label=None):
596
+ x = self.forward_features(x, cam_label, view_label)
597
+ return x
598
+
599
+ def load_param(self, model_path):
600
+ param_dict = torch.load(model_path, map_location="cpu")
601
+ if "model" in param_dict:
602
+ param_dict = param_dict["model"]
603
+ if "state_dict" in param_dict:
604
+ param_dict = param_dict["state_dict"]
605
+ for k, v in param_dict.items():
606
+ # print(k)
607
+ if "head" in k or "dist" in k:
608
+ continue
609
+ if "patch_embed.proj.weight" in k and len(v.shape) < 4:
610
+ # For old models that I trained prior to conv based patchification
611
+ O, I, H, W = self.patch_embed.proj.weight.shape
612
+ v = v.reshape(O, -1, H, W)
613
+ elif k == "pos_embed" and v.shape != self.pos_embed.shape:
614
+ # To resize pos embedding when using model at different size from pretrained weights
615
+ if "distilled" in model_path:
616
+ print("distilled checkpoint: selecting the correct cls token from the pretrained weights")
617
+ v = torch.cat([v[:, 0:1], v[:, 2:]], dim=1)
618
+ v = resize_pos_embed(
619
+ v, self.pos_embed, self.patch_embed.num_y, self.patch_embed.num_x
620
+ )
621
+ try:
622
+ self.state_dict()[k].copy_(v)
623
+
624
+ except Exception:  # shape or key mismatch: skip and keep the initialized parameter
625
+ # print("===========================ERROR=========================")
626
+ # print(k)
627
+ # print('shape do not match in k :{}: param_dict{} vs self.state_dict(){}'.format(k, v.shape, self.state_dict()[k].shape))
628
+ pass
629
+
630
+
631
+ def resize_pos_embed(posemb, posemb_new, height, width):
632
+ # Rescale the grid of position embeddings when loading from state_dict. Adapted from
633
+ # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
634
+ ntok_new = posemb_new.shape[1]
635
+
636
+ posemb_token, posemb_grid = posemb[:, :1], posemb[0, 1:]
637
+ ntok_new -= 1
638
+
639
+ # the old grid side length is not recomputed here; the reshape below assumes a 16 x 8 grid
640
+ print(
641
+ "Resized position embedding from size:{} to size: {} with height:{} width: {}".format(
642
+ posemb.shape, posemb_new.shape, height, width
643
+ )
644
+ )
645
+ posemb_grid = posemb_grid.reshape(1, 16, 8, -1).permute(0, 3, 1, 2)
646
+ posemb_grid = F.interpolate(posemb_grid, size=(height, width), mode="bilinear")
647
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, height * width, -1)
648
+ posemb = torch.cat([posemb_token, posemb_grid], dim=1)
649
+ return posemb
650
+
651
+
652
+ def vit_base_patch16_224_TransReID(
653
+ img_size=(256, 128),
654
+ stride_size=16,
655
+ drop_rate=0.0,
656
+ attn_drop_rate=0.0,
657
+ drop_path_rate=0.1,
658
+ camera=0,
659
+ view=0,
660
+ local_feature=False,
661
+ sie_xishu=1.5,
662
+ **kwargs,
663
+ ):
664
+ model = TransReID(
665
+ img_size=img_size,
666
+ patch_size=16,
667
+ stride_size=stride_size,
668
+ embed_dim=768,
669
+ depth=12,
670
+ num_heads=12,
671
+ mlp_ratio=4,
672
+ qkv_bias=True,
673
+ camera=camera,
674
+ view=view,
675
+ drop_path_rate=drop_path_rate,
676
+ drop_rate=drop_rate,
677
+ attn_drop_rate=attn_drop_rate,
678
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
679
+ sie_xishu=sie_xishu,
680
+ local_feature=local_feature,
681
+ **kwargs,
682
+ )
683
+
684
+ return model
685
+
686
+
687
+ def vit_small_patch16_224_TransReID(
688
+ img_size=(256, 128),
689
+ stride_size=16,
690
+ drop_rate=0.0,
691
+ attn_drop_rate=0.0,
692
+ drop_path_rate=0.1,
693
+ camera=0,
694
+ view=0,
695
+ local_feature=False,
696
+ sie_xishu=1.5,
697
+ **kwargs,
698
+ ):
699
+ kwargs.setdefault("qk_scale", 768**-0.5)
700
+ model = TransReID(
701
+ img_size=img_size,
702
+ patch_size=16,
703
+ stride_size=stride_size,
704
+ embed_dim=768,
705
+ depth=8,
706
+ num_heads=8,
707
+ mlp_ratio=3.0,
708
+ qkv_bias=False,
709
+ drop_path_rate=drop_path_rate,
710
+ camera=camera,
711
+ view=view,
712
+ drop_rate=drop_rate,
713
+ attn_drop_rate=attn_drop_rate,
714
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
715
+ sie_xishu=sie_xishu,
716
+ local_feature=local_feature,
717
+ **kwargs,
718
+ )
719
+
720
+ return model
721
+
722
+
723
+ def deit_small_patch16_224_TransReID(
724
+ img_size=(256, 128),
725
+ stride_size=16,
726
+ drop_path_rate=0.1,
727
+ drop_rate=0.0,
728
+ attn_drop_rate=0.0,
729
+ camera=0,
730
+ view=0,
731
+ local_feature=False,
732
+ sie_xishu=1.5,
733
+ **kwargs,
734
+ ):
735
+ model = TransReID(
736
+ img_size=img_size,
737
+ patch_size=16,
738
+ stride_size=stride_size,
739
+ embed_dim=384,
740
+ depth=12,
741
+ num_heads=6,
742
+ mlp_ratio=4,
743
+ qkv_bias=True,
744
+ drop_path_rate=drop_path_rate,
745
+ drop_rate=drop_rate,
746
+ attn_drop_rate=attn_drop_rate,
747
+ camera=camera,
748
+ view=view,
749
+ sie_xishu=sie_xishu,
750
+ local_feature=local_feature,
751
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
752
+ **kwargs,
753
+ )
754
+
755
+ return model
756
+
757
+
758
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
759
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
760
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
761
+ def norm_cdf(x):
762
+ # Computes standard normal cumulative distribution function
763
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
764
+
765
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
766
+ print(
767
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
768
+ "The distribution of values may be incorrect.",
769
+ )
770
+
771
+ with torch.no_grad():
772
+ # Values are generated by using a truncated uniform distribution and
773
+ # then using the inverse CDF for the normal distribution.
774
+ # Get upper and lower cdf values
775
+ l = norm_cdf((a - mean) / std)
776
+ u = norm_cdf((b - mean) / std)
777
+
778
+ # Uniformly fill tensor with values from [l, u], then translate to
779
+ # [2l-1, 2u-1].
780
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
781
+
782
+ # Use inverse cdf transform for normal distribution to get truncated
783
+ # standard normal
784
+ tensor.erfinv_()
785
+
786
+ # Transform to proper mean, std
787
+ tensor.mul_(std * math.sqrt(2.0))
788
+ tensor.add_(mean)
789
+
790
+ # Clamp to ensure it's in the proper range
791
+ tensor.clamp_(min=a, max=b)
792
+ return tensor
793
+
794
+
795
+ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
796
+ # type: (Tensor, float, float, float, float) -> Tensor
797
+ r"""Fills the input Tensor with values drawn from a truncated
798
+ normal distribution. The values are effectively drawn from the
799
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
800
+ with values outside :math:`[a, b]` redrawn until they are within
801
+ the bounds. The method used for generating the random values works
802
+ best when :math:`a \leq \text{mean} \leq b`.
803
+ Args:
804
+ tensor: an n-dimensional `torch.Tensor`
805
+ mean: the mean of the normal distribution
806
+ std: the standard deviation of the normal distribution
807
+ a: the minimum cutoff value
808
+ b: the maximum cutoff value
809
+ Examples:
810
+ >>> w = torch.empty(3, 5)
811
+ >>> nn.init.trunc_normal_(w)
812
+ """
813
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
814
+
815
+
816
+ class HAPTransReIDConfig(PretrainedConfig):
817
+ model_type = "my-vit-b16"
818
+
819
+ def __init__(
820
+ self,
821
+ img_size=[384, 128],
822
+ stride_size=[16, 16],
823
+ drop_rate=0.0,
824
+ attn_drop_rate=0.0,
825
+ drop_path_rate=0.1,
826
+ camera=0, # not used
827
+ view=0, # not used
828
+ local_feature=True,
829
+ sie_xishu=3.0, # not used
830
+ num_classes=-1, # not used
831
+ patch_size=16,
832
+ in_chans=3,
833
+ embed_dim=768,
834
+ depth=12,
835
+ num_heads=12,
836
+ mlp_ratio=4.0,
837
+ qkv_bias=False,
838
+ qk_scale=None,
839
+ hybrid_backbone=None, # not used
840
+ norm_layer_eps=1e-6,
841
+ **kwargs,
842
+ ):
843
+
844
+ super().__init__(**kwargs)
845
+
846
+ self.img_size = img_size
847
+ self.stride_size = stride_size
848
+ self.drop_rate = drop_rate
849
+ self.attn_drop_rate = attn_drop_rate
850
+ self.drop_path_rate = drop_path_rate
851
+ self.camera = camera
852
+ self.view = view
853
+ self.local_feature = local_feature
854
+ self.sie_xishu = sie_xishu
855
+ self.num_classes = num_classes
856
+ self.patch_size = patch_size
857
+ self.in_chans = in_chans
858
+ self.embed_dim = embed_dim
859
+ self.depth = depth
860
+ self.num_heads = num_heads
861
+ self.mlp_ratio = mlp_ratio
862
+ self.qkv_bias = qkv_bias
863
+ self.qk_scale = qk_scale
864
+ self.hybrid_backbone = hybrid_backbone
865
+ self.norm_layer_eps = norm_layer_eps
866
+
867
+
868
+
869
+
870
+ class HAPTransReID(TransReID, PreTrainedModel):
871
+ config_class = HAPTransReIDConfig
872
+
873
+ def __init__(self, config):
874
+ PreTrainedModel.__init__(self, config)
875
+ self.config = config
876
+ self.model = TransReID(
877
+ img_size=config.img_size,
878
+ stride_size=config.stride_size,
879
+ drop_rate=config.drop_rate,
880
+ attn_drop_rate=config.attn_drop_rate,
881
+ drop_path_rate=config.drop_path_rate,
882
+ camera=config.camera,
883
+ view=config.view,
884
+ local_feature=config.local_feature,
885
+ sie_xishu=config.sie_xishu,
886
+ num_classes=config.num_classes,
887
+ patch_size=config.patch_size,
888
+ in_chans=config.in_chans,
889
+ embed_dim=config.embed_dim,
890
+ depth=config.depth,
891
+ num_heads=config.num_heads,
892
+ mlp_ratio=config.mlp_ratio,
893
+ qkv_bias=config.qkv_bias,
894
+ qk_scale=config.qk_scale,
895
+ norm_layer=partial(nn.LayerNorm, eps=config.norm_layer_eps),
896
+ )
897
+ self.model.hidden_size = self.model.vision_width = config.embed_dim
898
+ def forward(self, x):
899
+ return self.model(x, cam_label=None, view_label=None)
900
+
901
+ @classmethod
902
+ def from_config(cls, config={}, from_path=None, from_pretrained=None):
903
+ '''
904
+ vision_width = hidden_size = 768; kept for informational purposes only,
905
+ not used inside the model.
906
+ '''
907
+ model = vit_base_patch16_224_TransReID(
908
+ img_size=config.get("img_size", [384, 128]),
909
+ stride_size=config.get("stride_size", [16, 16]),
910
+ drop_rate=config.get("drop_rate", 0.0),
911
+ attn_drop_rate=config.get("attn_drop_rate", 0.0),
912
+ drop_path_rate=config.get("drop_path_rate", 0.1),
913
+ camera=config.get("camera", 0),
914
+ view=config.get("view", 0),
915
+ local_feature=config.get("local_feature", True),
916
+ sie_xishu=config.get("sie_xishu", 3.0),
917
+ num_classes=config.get("num_classes", -1),
918
+ # vision_width=config.get("vision_width", 768),
919
+ # hidden_size=config.get("hidden_size", 768),
920
+ )
921
+ model.vision_width = model.hidden_size = 768
922
+
923
+ if from_path is not None:
924
+ model.load_param(from_path)
925
+
926
+ return model
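
A local instantiation sketch for the classes defined above, using the values from `config.json` (a sketch only, not a snippet taken from the repository; requires `torch`, `timm`, and `transformers`):

```python
import torch

# Assumes hap_transreid.py from this commit is importable (e.g. downloaded locally).
from hap_transreid import HAPTransReIDConfig, HAPTransReID

config = HAPTransReIDConfig(
    img_size=[256, 128],   # height x width, as in config.json
    stride_size=[16, 16],
    local_feature=True,    # return the full token sequence
    qkv_bias=False,
)
model = HAPTransReID(config).eval()

x = torch.randn(2, 3, 256, 128)
with torch.no_grad():
    tokens = model(x)

# 16 * 8 patches + 1 cls token = 129 tokens of width embed_dim = 768.
print(tokens.shape)  # expected: torch.Size([2, 129, 768])
```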
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4607757202c8364791cace0fc954414335f7f0401ec4722d342ca6ff521a44a2
3
+ size 342888816
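
The LFS pointer records a 342,888,816-byte checkpoint; at 4 bytes per float32 weight that is roughly 85.7M parameters, consistent with a ViT-Base/16 backbone. A small inspection sketch (assumes the `safetensors` package and a locally downloaded `model.safetensors`):

```python
from safetensors.torch import load_file

# Assumes model.safetensors has been downloaded next to this script.
state_dict = load_file("model.safetensors")

n_params = sum(t.numel() for t in state_dict.values())
n_bytes = sum(t.numel() * t.element_size() for t in state_dict.values())
print(f"{len(state_dict)} tensors, {n_params / 1e6:.1f}M params, {n_bytes} bytes")

# n_bytes should land close to the 342888816 reported above; the file also
# carries a small JSON header, so the on-disk size is slightly larger.
```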