Add new SentenceTransformer model.
Browse files- README.md +197 -0
- config.json +176 -0
- config_sentence_transformers.json +10 -0
- configuration_clip.py +304 -0
- custom_st.py +174 -0
- custom_st_2.py +3 -0
- eva_model.py +764 -0
- hf_model.py +297 -0
- modeling_clip.py +570 -0
- modules.json +8 -0
- preprocessor_config.json +22 -0
- processing_clip.py +88 -0
- pytorch_model.bin +3 -0
- rope_embeddings.py +165 -0
- special_tokens_map.json +37 -0
- tokenizer.json +0 -0
- tokenizer_config.json +64 -0
- transform.py +458 -0
- vocab.txt +0 -0
README.md
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
tags:
|
3 |
+
- feature-extraction
|
4 |
+
- sentence-similarity
|
5 |
+
- mteb
|
6 |
+
- clip
|
7 |
+
- vision
|
8 |
+
- transformers.js
|
9 |
+
language: en
|
10 |
+
inference: false
|
11 |
+
license: apache-2.0
|
12 |
+
library_name: transformers
|
13 |
+
---
|
14 |
+
|
15 |
+
<br><br>
|
16 |
+
|
17 |
+
<p align="center">
|
18 |
+
<img src="https://aeiljuispo.cloudimg.io/v7/https://cdn-uploads.huggingface.co/production/uploads/603763514de52ff951d89793/AFoybzd5lpBQXEBrQHuTt.png?w=200&h=200&f=face" alt="Finetuner logo: Finetuner helps you to create experiments in order to improve embeddings on search tasks. It accompanies you to deliver the last mile of performance-tuning for neural search applications." width="150px">
|
19 |
+
</p>
|
20 |
+
|
21 |
+
|
22 |
+
<p align="center">
|
23 |
+
<b>The embedding set trained by <a href="https://jina.ai/"><b>Jina AI</b></a>.</b>
|
24 |
+
</p>
|
25 |
+
|
26 |
+
<p align="center">
|
27 |
+
<b>Jina CLIP: your CLIP model is also your text retriever!</b>
|
28 |
+
</p>
|
29 |
+
|
30 |
+
|
31 |
+
## Intended Usage & Model Info
|
32 |
+
|
33 |
+
`jina-clip-v1` is a state-of-the-art English **multimodal (text-image) embedding model**.
|
34 |
+
|
35 |
+
Traditional text embedding models, such as [jina-embeddings-v2-base-en](https://huggingface.co/jinaai/jina-embeddings-v2-base-en), excel in text-to-text retrieval but incapable of cross-modal tasks. Models like [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) effectively align image and text embeddings but are not optimized for text-to-text retrieval due to their training methodologies and context limitations.
|
36 |
+
|
37 |
+
`jina-clip-v1` bridges this gap by offering robust performance in both domains.
|
38 |
+
Its text component matches the retrieval efficiency of `jina-embeddings-v2-base-en`, while its overall architecture sets a new benchmark for cross-modal retrieval.
|
39 |
+
This dual capability makes it an excellent tool for multimodal retrieval-augmented generation (MuRAG) applications, enabling seamless text-to-text and text-to-image searches within a single model.
|
40 |
+
|
41 |
+
|
42 |
+
## Data & Parameters
|
43 |
+
|
44 |
+
[Check out our paper](https://arxiv.org/abs/2405.20204)
|
45 |
+
|
46 |
+
## Usage
|
47 |
+
|
48 |
+
1. The easiest way to starting using jina-clip-v1-en is to use Jina AI's [Embeddings API](https://jina.ai/embeddings/).
|
49 |
+
2. Alternatively, you can use Jina CLIP directly via transformers package.
|
50 |
+
|
51 |
+
```python
|
52 |
+
!pip install transformers einops timm pillow
|
53 |
+
from transformers import AutoModel
|
54 |
+
|
55 |
+
# Initialize the model
|
56 |
+
model = AutoModel.from_pretrained('jinaai/jina-clip-v1', trust_remote_code=True)
|
57 |
+
|
58 |
+
# New meaningful sentences
|
59 |
+
sentences = ['A blue cat', 'A red cat']
|
60 |
+
|
61 |
+
# Public image URLs
|
62 |
+
image_urls = [
|
63 |
+
'https://i.pinimg.com/600x315/21/48/7e/21487e8e0970dd366dafaed6ab25d8d8.jpg',
|
64 |
+
'https://i.pinimg.com/736x/c9/f2/3e/c9f23e212529f13f19bad5602d84b78b.jpg'
|
65 |
+
]
|
66 |
+
|
67 |
+
# Encode text and images
|
68 |
+
text_embeddings = model.encode_text(sentences)
|
69 |
+
image_embeddings = model.encode_image(image_urls) # also accepts PIL.image, local filenames, dataURI
|
70 |
+
|
71 |
+
# Compute similarities
|
72 |
+
print(text_embeddings[0] @ text_embeddings[1].T) # text embedding similarity
|
73 |
+
print(text_embeddings[0] @ image_embeddings[0].T) # text-image cross-modal similarity
|
74 |
+
print(text_embeddings[0] @ image_embeddings[1].T) # text-image cross-modal similarity
|
75 |
+
print(text_embeddings[1] @ image_embeddings[0].T) # text-image cross-modal similarity
|
76 |
+
print(text_embeddings[1] @ image_embeddings[1].T)# text-image cross-modal similarity
|
77 |
+
```
|
78 |
+
|
79 |
+
3. JavaScript developers can use Jina CLIP via the [Transformers.js](https://huggingface.co/docs/transformers.js) library. Note that to use this model, you need to install Transformers.js [v3](https://github.com/xenova/transformers.js/tree/v3) from source using `npm install xenova/transformers.js#v3`.
|
80 |
+
|
81 |
+
```js
|
82 |
+
import { AutoTokenizer, CLIPTextModelWithProjection, AutoProcessor, CLIPVisionModelWithProjection, RawImage, cos_sim } from '@xenova/transformers';
|
83 |
+
|
84 |
+
// Load tokenizer and text model
|
85 |
+
const tokenizer = await AutoTokenizer.from_pretrained('jinaai/jina-clip-v1');
|
86 |
+
const text_model = await CLIPTextModelWithProjection.from_pretrained('jinaai/jina-clip-v1');
|
87 |
+
|
88 |
+
// Load processor and vision model
|
89 |
+
const processor = await AutoProcessor.from_pretrained('Xenova/clip-vit-base-patch32');
|
90 |
+
const vision_model = await CLIPVisionModelWithProjection.from_pretrained('jinaai/jina-clip-v1');
|
91 |
+
|
92 |
+
// Run tokenization
|
93 |
+
const texts = ['A blue cat', 'A red cat'];
|
94 |
+
const text_inputs = tokenizer(texts, { padding: true, truncation: true });
|
95 |
+
|
96 |
+
// Compute text embeddings
|
97 |
+
const { text_embeds } = await text_model(text_inputs);
|
98 |
+
|
99 |
+
// Read images and run processor
|
100 |
+
const urls = [
|
101 |
+
'https://i.pinimg.com/600x315/21/48/7e/21487e8e0970dd366dafaed6ab25d8d8.jpg',
|
102 |
+
'https://i.pinimg.com/736x/c9/f2/3e/c9f23e212529f13f19bad5602d84b78b.jpg'
|
103 |
+
];
|
104 |
+
const image = await Promise.all(urls.map(url => RawImage.read(url)));
|
105 |
+
const image_inputs = await processor(image);
|
106 |
+
|
107 |
+
// Compute vision embeddings
|
108 |
+
const { image_embeds } = await vision_model(image_inputs);
|
109 |
+
|
110 |
+
// Compute similarities
|
111 |
+
console.log(cos_sim(text_embeds[0].data, text_embeds[1].data)) // text embedding similarity
|
112 |
+
console.log(cos_sim(text_embeds[0].data, image_embeds[0].data)) // text-image cross-modal similarity
|
113 |
+
console.log(cos_sim(text_embeds[0].data, image_embeds[1].data)) // text-image cross-modal similarity
|
114 |
+
console.log(cos_sim(text_embeds[1].data, image_embeds[0].data)) // text-image cross-modal similarity
|
115 |
+
console.log(cos_sim(text_embeds[1].data, image_embeds[1].data)) // text-image cross-modal similarity
|
116 |
+
```
|
117 |
+
|
118 |
+
## Performance
|
119 |
+
|
120 |
+
### Text-Image Retrieval
|
121 |
+
|
122 |
+
| Name | Flickr Image Retr. R@1 | Flickr Image Retr. R@5 | Flickr Text Retr. R@1 | Flickr Text Retr. R@5 |
|
123 |
+
|------------------|-------------------------|-------------------------|-----------------------|-----------------------|
|
124 |
+
| ViT-B-32 | 0.597 | 0.8398 | 0.781 | 0.938 |
|
125 |
+
| ViT-B-16 | 0.6216 | 0.8572 | 0.822 | 0.966 |
|
126 |
+
| jina-clip | 0.6748 | 0.8902 | 0.811 | 0.965 |
|
127 |
+
|
128 |
+
|
129 |
+
| Name | MSCOCO Image Retr. R@1 | MSCOCO Image Retr. R@5 | MSCOCO Text Retr. R@1 | MSCOCO Text Retr. R@5 |
|
130 |
+
|------------------|-------------------------|-------------------------|-----------------------|-----------------------|
|
131 |
+
| ViT-B-32 | 0.342 | 0.6001 | 0.5234 | 0.7634 |
|
132 |
+
| ViT-B-16 | 0.3309 | 0.5842 | 0.5242 | 0.767 |
|
133 |
+
| jina-clip | 0.4111 | 0.6644 | 0.5544 | 0.7904 |
|
134 |
+
|
135 |
+
### Text-Text Retrieval
|
136 |
+
|
137 |
+
| Name | STS12 | STS15 | STS17 | STS13 | STS14 | STS16 | STS22 | STSBenchmark | SummEval |
|
138 |
+
|-----------------------|--------|--------|--------|--------|--------|--------|--------|--------------|----------|
|
139 |
+
| jina-embeddings-v2 | 0.7427 | 0.8755 | 0.8888 | 0.833 | 0.7917 | 0.836 | 0.6346 | 0.8404 | 0.3056 |
|
140 |
+
| jina-clip | 0.7352 | 0.8746 | 0.8976 | 0.8323 | 0.7868 | 0.8377 | 0.6583 | 0.8493 | 0.3048 |
|
141 |
+
|
142 |
+
|
143 |
+
| Name | ArguAna | FiQA2018 | NFCorpus | Quora | SCIDOCS | SciFact | TRECCOVID |
|
144 |
+
|--------------------|---------|----------|----------|-------|---------|---------|-----------|
|
145 |
+
| jina-embeddings-v2 | 0.4418 | 0.4158 | 0.3245 | 0.882 | 0.1986 | 0.6668 | 0.6591 |
|
146 |
+
| jina-clip | 0.4933 | 0.3827 | 0.3352 | 0.8789| 0.2024 | 0.6734 | 0.7161 |
|
147 |
+
|
148 |
+
## Contact
|
149 |
+
|
150 |
+
Join our [Discord community](https://discord.jina.ai) and chat with other community members about ideas.
|
151 |
+
|
152 |
+
## Citation
|
153 |
+
|
154 |
+
If you find `jina-clip-v1` useful in your research, please cite the following paper:
|
155 |
+
|
156 |
+
```bibtex
|
157 |
+
@misc{2405.20204,
|
158 |
+
Author = {Andreas Koukounas and Georgios Mastrapas and Michael Günther and Bo Wang and Scott Martens and Isabelle Mohr and Saba Sturua and Mohammad Kalim Akram and Joan Fontanals Martínez and Saahil Ognawala and Susana Guzman and Maximilian Werk and Nan Wang and Han Xiao},
|
159 |
+
Title = {Jina CLIP: Your CLIP Model Is Also Your Text Retriever},
|
160 |
+
Year = {2024},
|
161 |
+
Eprint = {arXiv:2405.20204},
|
162 |
+
}
|
163 |
+
```
|
164 |
+
|
165 |
+
## FAQ
|
166 |
+
|
167 |
+
### I encounter this problem, what should I do?
|
168 |
+
|
169 |
+
```
|
170 |
+
ValueError: The model class you are passing has a `config_class` attribute that is not consistent with the config class you passed (model has <class 'transformers_modules.jinaai.jina-clip-implementation.7f069e2d54d609ef1ad2eb578c7bf07b5a51de41.configuration_clip.JinaCLIPConfig'> and you passed <class 'transformers_modules.jinaai.jina-clip-implementation.7f069e2d54d609ef1ad2eb578c7bf07b5a51de41.configuration_cli.JinaCLIPConfig'>. Fix one of those so they match!
|
171 |
+
```
|
172 |
+
|
173 |
+
There was a bug in Transformers library between 4.40.x to 4.41.1. You can update transformers to >4.41.2 or <=4.40.0
|
174 |
+
|
175 |
+
### Given one query, how can I merge its text-text and text-image cosine similarity?
|
176 |
+
|
177 |
+
Our emperical study shows that text-text cosine similarity is normally larger than text-image cosine similarity!
|
178 |
+
If you want to merge two scores, we recommended 2 ways:
|
179 |
+
|
180 |
+
1. weighted average of text-text sim and text-image sim:
|
181 |
+
|
182 |
+
```python
|
183 |
+
combined_scores = sim(text, text) + lambda * sim(text, image) # optimal lambda depends on your dataset, but in general lambda=2 can be a good choice.
|
184 |
+
```
|
185 |
+
|
186 |
+
2. apply z-score normalization before merging scores:
|
187 |
+
|
188 |
+
```python
|
189 |
+
# pseudo code
|
190 |
+
query_document_mean = np.mean(cos_sim_text_texts)
|
191 |
+
query_document_std = np.std(cos_sim_text_texts)
|
192 |
+
text_image_mean = np.mean(cos_sim_text_images)
|
193 |
+
text_image_std = np.std(cos_sim_text_images)
|
194 |
+
|
195 |
+
query_document_sim_normalized = (cos_sim_query_documents - query_document_mean) / query_document_std
|
196 |
+
text_image_sim_normalized = (cos_sim_text_images - text_image_mean) / text_image_std
|
197 |
+
```
|
config.json
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_commit_hash": null,
|
3 |
+
"_name_or_path": "jina-clip-v1",
|
4 |
+
"add_projections": false,
|
5 |
+
"architectures": [
|
6 |
+
"JinaCLIPModel"
|
7 |
+
],
|
8 |
+
"auto_map": {
|
9 |
+
"AutoConfig": "configuration_clip.JinaCLIPConfig",
|
10 |
+
"AutoModel": "modeling_clip.JinaCLIPModel"
|
11 |
+
},
|
12 |
+
"initializer_factor": 1.0,
|
13 |
+
"logit_scale_init_value": 2.6592,
|
14 |
+
"model_type": "jina_clip",
|
15 |
+
"projection_dim": 768,
|
16 |
+
"text_config": {
|
17 |
+
"_name_or_path": "",
|
18 |
+
"add_cross_attention": false,
|
19 |
+
"architectures": null,
|
20 |
+
"bad_words_ids": null,
|
21 |
+
"begin_suppress_tokens": null,
|
22 |
+
"bos_token_id": null,
|
23 |
+
"chunk_size_feed_forward": 0,
|
24 |
+
"cross_attention_hidden_size": null,
|
25 |
+
"decoder_start_token_id": null,
|
26 |
+
"diversity_penalty": 0.0,
|
27 |
+
"do_sample": false,
|
28 |
+
"early_stopping": false,
|
29 |
+
"embed_dim": 768,
|
30 |
+
"encoder_no_repeat_ngram_size": 0,
|
31 |
+
"eos_token_id": null,
|
32 |
+
"exponential_decay_length_penalty": null,
|
33 |
+
"finetuning_task": null,
|
34 |
+
"forced_bos_token_id": null,
|
35 |
+
"forced_eos_token_id": null,
|
36 |
+
"hf_model_config_kwargs": {
|
37 |
+
"use_flash_attn": false
|
38 |
+
},
|
39 |
+
"hf_model_name_or_path": "jinaai/jina-bert-flash-implementation",
|
40 |
+
"id2label": {
|
41 |
+
"0": "LABEL_0",
|
42 |
+
"1": "LABEL_1"
|
43 |
+
},
|
44 |
+
"is_decoder": false,
|
45 |
+
"is_encoder_decoder": false,
|
46 |
+
"label2id": {
|
47 |
+
"LABEL_0": 0,
|
48 |
+
"LABEL_1": 1
|
49 |
+
},
|
50 |
+
"length_penalty": 1.0,
|
51 |
+
"max_length": 20,
|
52 |
+
"min_length": 0,
|
53 |
+
"model_type": "jina_clip_text",
|
54 |
+
"no_repeat_ngram_size": 0,
|
55 |
+
"num_beam_groups": 1,
|
56 |
+
"num_beams": 1,
|
57 |
+
"num_return_sequences": 1,
|
58 |
+
"output_attentions": false,
|
59 |
+
"output_hidden_states": false,
|
60 |
+
"output_scores": false,
|
61 |
+
"pad_token_id": null,
|
62 |
+
"pooler_type": "mean_pooler",
|
63 |
+
"prefix": null,
|
64 |
+
"problem_type": null,
|
65 |
+
"proj_bias": false,
|
66 |
+
"proj_type": null,
|
67 |
+
"pruned_heads": {},
|
68 |
+
"remove_invalid_values": false,
|
69 |
+
"repetition_penalty": 1.0,
|
70 |
+
"return_dict": true,
|
71 |
+
"return_dict_in_generate": false,
|
72 |
+
"sep_token_id": null,
|
73 |
+
"suppress_tokens": null,
|
74 |
+
"task_specific_params": null,
|
75 |
+
"temperature": 1.0,
|
76 |
+
"tf_legacy_loss": false,
|
77 |
+
"tie_encoder_decoder": false,
|
78 |
+
"tie_word_embeddings": true,
|
79 |
+
"tokenizer_class": null,
|
80 |
+
"top_k": 50,
|
81 |
+
"top_p": 1.0,
|
82 |
+
"torch_dtype": null,
|
83 |
+
"torchscript": false,
|
84 |
+
"transformers_version": "4.41.2",
|
85 |
+
"typical_p": 1.0,
|
86 |
+
"use_bfloat16": false
|
87 |
+
},
|
88 |
+
"torch_dtype": "float32",
|
89 |
+
"transformers_version": null,
|
90 |
+
"use_text_flash_attn": null,
|
91 |
+
"use_vision_xformers": null,
|
92 |
+
"vision_config": {
|
93 |
+
"_name_or_path": "",
|
94 |
+
"add_cross_attention": false,
|
95 |
+
"architectures": null,
|
96 |
+
"bad_words_ids": null,
|
97 |
+
"begin_suppress_tokens": null,
|
98 |
+
"bos_token_id": null,
|
99 |
+
"chunk_size_feed_forward": 0,
|
100 |
+
"cross_attention_hidden_size": null,
|
101 |
+
"decoder_start_token_id": null,
|
102 |
+
"diversity_penalty": 0.0,
|
103 |
+
"do_sample": false,
|
104 |
+
"drop_path_rate": 0.0,
|
105 |
+
"early_stopping": false,
|
106 |
+
"embed_dim": 768,
|
107 |
+
"encoder_no_repeat_ngram_size": 0,
|
108 |
+
"eos_token_id": null,
|
109 |
+
"exponential_decay_length_penalty": null,
|
110 |
+
"finetuning_task": null,
|
111 |
+
"forced_bos_token_id": null,
|
112 |
+
"forced_eos_token_id": null,
|
113 |
+
"fused_layer_norm": false,
|
114 |
+
"head_width": 64,
|
115 |
+
"id2label": {
|
116 |
+
"0": "LABEL_0",
|
117 |
+
"1": "LABEL_1"
|
118 |
+
},
|
119 |
+
"image_size": 224,
|
120 |
+
"intp_freq": false,
|
121 |
+
"is_decoder": false,
|
122 |
+
"is_encoder_decoder": false,
|
123 |
+
"label2id": {
|
124 |
+
"LABEL_0": 0,
|
125 |
+
"LABEL_1": 1
|
126 |
+
},
|
127 |
+
"layers": 12,
|
128 |
+
"length_penalty": 1.0,
|
129 |
+
"ls_init_value": null,
|
130 |
+
"max_length": 20,
|
131 |
+
"min_length": 0,
|
132 |
+
"mlp_ratio": 2.6667,
|
133 |
+
"model_type": "jina_clip_vision",
|
134 |
+
"naive_swiglu": true,
|
135 |
+
"no_repeat_ngram_size": 0,
|
136 |
+
"num_beam_groups": 1,
|
137 |
+
"num_beams": 1,
|
138 |
+
"num_return_sequences": 1,
|
139 |
+
"output_attentions": false,
|
140 |
+
"output_hidden_states": false,
|
141 |
+
"output_scores": false,
|
142 |
+
"pad_token_id": null,
|
143 |
+
"patch_dropout": 0.1,
|
144 |
+
"patch_size": 16,
|
145 |
+
"post_norm": false,
|
146 |
+
"prefix": null,
|
147 |
+
"problem_type": null,
|
148 |
+
"proj_type": null,
|
149 |
+
"pruned_heads": {},
|
150 |
+
"pt_hw_seq_len": 14,
|
151 |
+
"qkv_bias": true,
|
152 |
+
"remove_invalid_values": false,
|
153 |
+
"repetition_penalty": 1.0,
|
154 |
+
"return_dict": true,
|
155 |
+
"return_dict_in_generate": false,
|
156 |
+
"rope_embeddings": true,
|
157 |
+
"sep_token_id": null,
|
158 |
+
"subln": true,
|
159 |
+
"suppress_tokens": null,
|
160 |
+
"task_specific_params": null,
|
161 |
+
"temperature": 1.0,
|
162 |
+
"tf_legacy_loss": false,
|
163 |
+
"tie_encoder_decoder": false,
|
164 |
+
"tie_word_embeddings": true,
|
165 |
+
"tokenizer_class": null,
|
166 |
+
"top_k": 50,
|
167 |
+
"top_p": 1.0,
|
168 |
+
"torch_dtype": null,
|
169 |
+
"torchscript": false,
|
170 |
+
"transformers_version": "4.41.2",
|
171 |
+
"typical_p": 1.0,
|
172 |
+
"use_bfloat16": false,
|
173 |
+
"width": 768,
|
174 |
+
"x_attention": false
|
175 |
+
}
|
176 |
+
}
|
config_sentence_transformers.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"__version__": {
|
3 |
+
"sentence_transformers": "3.1.0.dev0",
|
4 |
+
"transformers": "4.41.2",
|
5 |
+
"pytorch": "2.3.1+cu121"
|
6 |
+
},
|
7 |
+
"prompts": {},
|
8 |
+
"default_prompt_name": null,
|
9 |
+
"similarity_fn_name": "cosine"
|
10 |
+
}
|
configuration_clip.py
ADDED
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
#
|
3 |
+
# Code mainly copied from:
|
4 |
+
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/configuration_clip.py
|
5 |
+
# and adjusted for Jina CLIP
|
6 |
+
|
7 |
+
import os
|
8 |
+
from copy import deepcopy
|
9 |
+
from typing import Any, Dict, Optional, Union
|
10 |
+
|
11 |
+
from transformers import PretrainedConfig, logging
|
12 |
+
|
13 |
+
logger = logging.get_logger(__name__)
|
14 |
+
|
15 |
+
|
16 |
+
""" Jina CLIP model configuration """
|
17 |
+
|
18 |
+
|
19 |
+
class JinaCLIPTextConfig(PretrainedConfig):
|
20 |
+
model_type = 'jina_clip_text'
|
21 |
+
|
22 |
+
def __init__(
|
23 |
+
self,
|
24 |
+
embed_dim: int = 768,
|
25 |
+
hf_model_name_or_path: str = 'jinaai/jina-bert-flash-implementation',
|
26 |
+
hf_model_config_kwargs: Optional[Dict[str, Any]] = None,
|
27 |
+
pooler_type: Optional[str] = None,
|
28 |
+
proj_type: Optional[str] = None,
|
29 |
+
proj_bias: bool = False,
|
30 |
+
**kwargs,
|
31 |
+
):
|
32 |
+
super().__init__(**kwargs)
|
33 |
+
|
34 |
+
self.embed_dim = embed_dim
|
35 |
+
self.hf_model_name_or_path = hf_model_name_or_path
|
36 |
+
self.hf_model_config_kwargs = hf_model_config_kwargs or {}
|
37 |
+
self.pooler_type = pooler_type
|
38 |
+
self.proj_type = proj_type
|
39 |
+
self.proj_bias = proj_bias
|
40 |
+
|
41 |
+
@classmethod
|
42 |
+
def from_pretrained(
|
43 |
+
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
|
44 |
+
) -> 'PretrainedConfig':
|
45 |
+
cls._set_token_in_kwargs(kwargs)
|
46 |
+
|
47 |
+
configdict, kwargs = cls.get_config_dict(
|
48 |
+
pretrained_model_name_or_path, **kwargs
|
49 |
+
)
|
50 |
+
|
51 |
+
# get the text config dict if we are loading from JinaCLIPConfig
|
52 |
+
if configdict.get('model_type') == 'jina_clip':
|
53 |
+
configdict = configdict['text_config']
|
54 |
+
|
55 |
+
if (
|
56 |
+
'model_type' in configdict
|
57 |
+
and hasattr(cls, 'model_type')
|
58 |
+
and configdict['model_type'] != cls.model_type
|
59 |
+
):
|
60 |
+
logger.warning(
|
61 |
+
f'You are using a model of type {configdict["model_type"]} to '
|
62 |
+
f'instantiate a model of type {cls.model_type}. This is not supported '
|
63 |
+
'for all configurations of models and can yield errors.'
|
64 |
+
)
|
65 |
+
|
66 |
+
return cls.from_dict(configdict, **kwargs)
|
67 |
+
|
68 |
+
|
69 |
+
class JinaCLIPVisionConfig(PretrainedConfig):
|
70 |
+
model_type = 'jina_clip_vision'
|
71 |
+
|
72 |
+
def __init__(
|
73 |
+
self,
|
74 |
+
embed_dim: int = 768,
|
75 |
+
width: int = 768,
|
76 |
+
image_size: int = 224,
|
77 |
+
patch_size: int = 16,
|
78 |
+
layers: int = 12,
|
79 |
+
head_width: int = 64,
|
80 |
+
mlp_ratio: float = 4.0,
|
81 |
+
ls_init_value: Optional[float] = None,
|
82 |
+
patch_dropout: float = 0.0,
|
83 |
+
qkv_bias: bool = True,
|
84 |
+
fused_layer_norm: bool = False,
|
85 |
+
x_attention: bool = False,
|
86 |
+
post_norm: bool = False,
|
87 |
+
rope_embeddings: bool = False,
|
88 |
+
pt_hw_seq_len: int = 16,
|
89 |
+
intp_freq: bool = False,
|
90 |
+
naive_swiglu: bool = False,
|
91 |
+
subln: bool = False,
|
92 |
+
drop_path_rate: float = 0.0,
|
93 |
+
proj_type: Optional[str] = None,
|
94 |
+
**kwargs,
|
95 |
+
):
|
96 |
+
super().__init__(**kwargs)
|
97 |
+
|
98 |
+
self.layers = layers
|
99 |
+
self.embed_dim = embed_dim
|
100 |
+
self.width = width
|
101 |
+
self.head_width = head_width
|
102 |
+
self.mlp_ratio = mlp_ratio
|
103 |
+
self.image_size = image_size
|
104 |
+
self.patch_size = patch_size
|
105 |
+
self.ls_init_value = ls_init_value
|
106 |
+
self.patch_dropout = patch_dropout
|
107 |
+
self.qkv_bias = qkv_bias
|
108 |
+
self.fused_layer_norm = fused_layer_norm
|
109 |
+
self.x_attention = x_attention
|
110 |
+
self.post_norm = post_norm
|
111 |
+
self.rope_embeddings = rope_embeddings
|
112 |
+
self.pt_hw_seq_len = pt_hw_seq_len
|
113 |
+
self.intp_freq = intp_freq
|
114 |
+
self.naive_swiglu = naive_swiglu
|
115 |
+
self.subln = subln
|
116 |
+
self.drop_path_rate = drop_path_rate
|
117 |
+
self.proj_type = proj_type
|
118 |
+
|
119 |
+
@classmethod
|
120 |
+
def from_pretrained(
|
121 |
+
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
|
122 |
+
) -> 'PretrainedConfig':
|
123 |
+
cls._set_token_in_kwargs(kwargs)
|
124 |
+
|
125 |
+
configdict, kwargs = cls.get_config_dict(
|
126 |
+
pretrained_model_name_or_path, **kwargs
|
127 |
+
)
|
128 |
+
|
129 |
+
# get the vision config dict if we are loading from JinaCLIPConfig
|
130 |
+
if configdict.get('model_type') == 'jina_clip':
|
131 |
+
configdict = configdict['vision_config']
|
132 |
+
|
133 |
+
if (
|
134 |
+
'model_type' in configdict
|
135 |
+
and hasattr(cls, 'model_type')
|
136 |
+
and configdict['model_type'] != cls.model_type
|
137 |
+
):
|
138 |
+
logger.warning(
|
139 |
+
f'You are using a model of type {configdict["model_type"]} to '
|
140 |
+
f'instantiate a model of type {cls.model_type}. This is not supported '
|
141 |
+
'for all configurations of models and can yield errors.'
|
142 |
+
)
|
143 |
+
|
144 |
+
return cls.from_dict(configdict, **kwargs)
|
145 |
+
|
146 |
+
|
147 |
+
class JinaCLIPConfig(PretrainedConfig):
|
148 |
+
model_type = 'jina_clip'
|
149 |
+
is_composition = True
|
150 |
+
|
151 |
+
def __init__(
|
152 |
+
self,
|
153 |
+
text_config: Optional[Dict] = None,
|
154 |
+
vision_config: Optional[Dict] = None,
|
155 |
+
add_projections: bool = False,
|
156 |
+
projection_dim: int = 768,
|
157 |
+
logit_scale_init_value: float = 2.6592,
|
158 |
+
use_text_flash_attn: Optional[bool] = None,
|
159 |
+
use_vision_xformers: Optional[bool] = None,
|
160 |
+
**kwargs,
|
161 |
+
):
|
162 |
+
# If `_config_dict` exist, we use them for the backward compatibility.
|
163 |
+
# We pop out these 2 attributes before calling `super().__init__` to avoid
|
164 |
+
# them being saved (which causes a lot of confusion!).
|
165 |
+
|
166 |
+
text_config_dict: Optional[Dict] = kwargs.pop('text_config_dict', None)
|
167 |
+
vision_config_dict: Optional[Dict] = kwargs.pop('vision_config_dict', None)
|
168 |
+
self.use_text_flash_attn = use_text_flash_attn
|
169 |
+
self.use_vision_xformers = use_vision_xformers
|
170 |
+
|
171 |
+
super().__init__(**kwargs)
|
172 |
+
|
173 |
+
if text_config_dict is not None:
|
174 |
+
if text_config is None:
|
175 |
+
text_config = {}
|
176 |
+
|
177 |
+
# This is the complete result when using `text_config_dict`.
|
178 |
+
_text_config_dict = JinaCLIPTextConfig(**text_config_dict).to_dict()
|
179 |
+
|
180 |
+
# Give a warning if the values exist in both `_text_config_dict` and
|
181 |
+
# `text_config` but being different.
|
182 |
+
for key, value in _text_config_dict.items():
|
183 |
+
if (
|
184 |
+
key in text_config
|
185 |
+
and value != text_config[key]
|
186 |
+
and key not in ['transformers_version']
|
187 |
+
):
|
188 |
+
# If specified in `text_config_dict`
|
189 |
+
if key in text_config_dict:
|
190 |
+
message = (
|
191 |
+
f'`{key}` is found in both `text_config_dict` and '
|
192 |
+
f'`text_config` but with different values. '
|
193 |
+
f'The value `text_config_dict["{key}"]` will be used '
|
194 |
+
f'instead.'
|
195 |
+
)
|
196 |
+
# If inferred from default argument values (
|
197 |
+
# just to be super careful)
|
198 |
+
else:
|
199 |
+
message = (
|
200 |
+
f'`text_config_dict` is provided which will be used to '
|
201 |
+
f'initialize `JinaCLIPTextConfig`. The '
|
202 |
+
f'value `text_config["{key}"]` will be overriden.'
|
203 |
+
)
|
204 |
+
logger.info(message)
|
205 |
+
|
206 |
+
# Update all values in `text_config` with the ones in `_text_config_dict`.
|
207 |
+
text_config.update(_text_config_dict)
|
208 |
+
|
209 |
+
if vision_config_dict is not None:
|
210 |
+
if vision_config is None:
|
211 |
+
vision_config = {}
|
212 |
+
|
213 |
+
# This is the complete result when using `vision_config_dict`.
|
214 |
+
_vision_config_dict = JinaCLIPVisionConfig(**vision_config_dict).to_dict()
|
215 |
+
# convert keys to string instead of integer
|
216 |
+
if 'id2label' in _vision_config_dict:
|
217 |
+
_vision_config_dict['id2label'] = {
|
218 |
+
str(key): value
|
219 |
+
for key, value in _vision_config_dict['id2label'].items()
|
220 |
+
}
|
221 |
+
|
222 |
+
# Give a warning if the values exist in both `_vision_config_dict`
|
223 |
+
# and `vision_config` but being different.
|
224 |
+
for key, value in _vision_config_dict.items():
|
225 |
+
if (
|
226 |
+
key in vision_config
|
227 |
+
and value != vision_config[key]
|
228 |
+
and key not in ['transformers_version']
|
229 |
+
):
|
230 |
+
# If specified in `vision_config_dict`
|
231 |
+
if key in vision_config_dict:
|
232 |
+
message = (
|
233 |
+
f'`{key}` is found in both `vision_config_dict` and '
|
234 |
+
f'`vision_config` but with different '
|
235 |
+
f'values. The value `vision_config_dict["{key}"]` will '
|
236 |
+
f'be used instead.'
|
237 |
+
)
|
238 |
+
# If inferred from default argument values
|
239 |
+
# (just to be super careful)
|
240 |
+
else:
|
241 |
+
message = (
|
242 |
+
f'`vision_config_dict` is provided which will be used to '
|
243 |
+
f'initialize `JinaCLIPVisionConfig`. '
|
244 |
+
f'The value `vision_config["{key}"]` will be overriden.'
|
245 |
+
)
|
246 |
+
logger.info(message)
|
247 |
+
|
248 |
+
# Update all values in `vision_config` with the ones in
|
249 |
+
# `_vision_config_dict`.
|
250 |
+
vision_config.update(_vision_config_dict)
|
251 |
+
|
252 |
+
if text_config is None:
|
253 |
+
text_config = {}
|
254 |
+
logger.info(
|
255 |
+
'`text_config` is `None`. Initializing the `JinaCLIPTextConfig` with '
|
256 |
+
'default values.'
|
257 |
+
)
|
258 |
+
|
259 |
+
if vision_config is None:
|
260 |
+
vision_config = {}
|
261 |
+
logger.info(
|
262 |
+
'`vision_config` is `None`. initializing the `JinaCLIPVisionConfig` '
|
263 |
+
'with default values.'
|
264 |
+
)
|
265 |
+
|
266 |
+
self.text_config = JinaCLIPTextConfig(**text_config)
|
267 |
+
self.vision_config = JinaCLIPVisionConfig(**vision_config)
|
268 |
+
|
269 |
+
self.add_projections = add_projections
|
270 |
+
self.projection_dim = projection_dim
|
271 |
+
self.logit_scale_init_value = logit_scale_init_value
|
272 |
+
self.initializer_factor = 1.0
|
273 |
+
|
274 |
+
if not self.add_projections:
|
275 |
+
if self.text_config.embed_dim != self.vision_config.embed_dim:
|
276 |
+
raise ValueError(
|
277 |
+
'When projections are disabled (`add_projections=False`), text '
|
278 |
+
'and vision towers need to have the same embedding dimensionality. '
|
279 |
+
f'Currently text embedding dim is {self.text_config.embed_dim} != '
|
280 |
+
f'{self.vision_config.embed_dim} of the vision tower. '
|
281 |
+
'Either set the same output dim for both towers, or enable '
|
282 |
+
'projections with `add_projections=True`.'
|
283 |
+
)
|
284 |
+
|
285 |
+
@classmethod
|
286 |
+
def from_text_vision_configs(
|
287 |
+
cls,
|
288 |
+
text_config: JinaCLIPTextConfig,
|
289 |
+
vision_config: JinaCLIPVisionConfig,
|
290 |
+
**kwargs,
|
291 |
+
):
|
292 |
+
return cls(
|
293 |
+
text_config=text_config.to_dict(),
|
294 |
+
vision_config=vision_config.to_dict(),
|
295 |
+
projection_dim=text_config.projection_dim,
|
296 |
+
**kwargs,
|
297 |
+
)
|
298 |
+
|
299 |
+
def to_dict(self):
|
300 |
+
output = deepcopy(self.__dict__)
|
301 |
+
output['text_config'] = self.text_config.to_dict()
|
302 |
+
output['vision_config'] = self.vision_config.to_dict()
|
303 |
+
output['model_type'] = self.__class__.model_type
|
304 |
+
return output
|
custom_st.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
from io import BytesIO
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
6 |
+
|
7 |
+
from .custom_st_2 import OtherClass
|
8 |
+
import requests
|
9 |
+
import torch
|
10 |
+
from torch import nn
|
11 |
+
from transformers import AutoConfig, AutoModel, AutoTokenizer, AutoImageProcessor
|
12 |
+
from PIL import Image
|
13 |
+
|
14 |
+
OtherClass()
|
15 |
+
|
16 |
+
class Transformer(nn.Module):
|
17 |
+
"""Huggingface AutoModel to generate token embeddings.
|
18 |
+
Loads the correct class, e.g. BERT / RoBERTa etc.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
model_name_or_path: Huggingface models name
|
22 |
+
(https://huggingface.co/models)
|
23 |
+
max_seq_length: Truncate any inputs longer than max_seq_length
|
24 |
+
model_args: Keyword arguments passed to the Huggingface
|
25 |
+
Transformers model
|
26 |
+
tokenizer_args: Keyword arguments passed to the Huggingface
|
27 |
+
Transformers tokenizer
|
28 |
+
config_args: Keyword arguments passed to the Huggingface
|
29 |
+
Transformers config
|
30 |
+
cache_dir: Cache dir for Huggingface Transformers to store/load
|
31 |
+
models
|
32 |
+
do_lower_case: If true, lowercases the input (independent if the
|
33 |
+
model is cased or not)
|
34 |
+
tokenizer_name_or_path: Name or path of the tokenizer. When
|
35 |
+
None, then model_name_or_path is used
|
36 |
+
"""
|
37 |
+
|
38 |
+
def __init__(
|
39 |
+
self,
|
40 |
+
model_name_or_path: str,
|
41 |
+
max_seq_length: Optional[int] = None,
|
42 |
+
model_args: Optional[Dict[str, Any]] = None,
|
43 |
+
tokenizer_args: Optional[Dict[str, Any]] = None,
|
44 |
+
config_args: Optional[Dict[str, Any]] = None,
|
45 |
+
cache_dir: Optional[str] = None,
|
46 |
+
do_lower_case: bool = False,
|
47 |
+
tokenizer_name_or_path: str = None,
|
48 |
+
) -> None:
|
49 |
+
super(Transformer, self).__init__()
|
50 |
+
self.config_keys = ["max_seq_length", "do_lower_case"]
|
51 |
+
self.do_lower_case = do_lower_case
|
52 |
+
if model_args is None:
|
53 |
+
model_args = {}
|
54 |
+
if tokenizer_args is None:
|
55 |
+
tokenizer_args = {}
|
56 |
+
if config_args is None:
|
57 |
+
config_args = {}
|
58 |
+
|
59 |
+
config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
|
60 |
+
self.jina_clip = AutoModel.from_pretrained(
|
61 |
+
model_name_or_path, config=config, cache_dir=cache_dir, **model_args
|
62 |
+
)
|
63 |
+
|
64 |
+
if max_seq_length is not None and "model_max_length" not in tokenizer_args:
|
65 |
+
tokenizer_args["model_max_length"] = max_seq_length
|
66 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
67 |
+
tokenizer_name_or_path if tokenizer_name_or_path is not None else model_name_or_path,
|
68 |
+
cache_dir=cache_dir,
|
69 |
+
**tokenizer_args,
|
70 |
+
)
|
71 |
+
self.preprocessor = AutoImageProcessor.from_pretrained(
|
72 |
+
tokenizer_name_or_path if tokenizer_name_or_path is not None else model_name_or_path,
|
73 |
+
cache_dir=cache_dir,
|
74 |
+
**tokenizer_args,
|
75 |
+
)
|
76 |
+
|
77 |
+
# No max_seq_length set. Try to infer from model
|
78 |
+
if max_seq_length is None:
|
79 |
+
if (
|
80 |
+
hasattr(self.jina_clip, "config")
|
81 |
+
and hasattr(self.jina_clip.config, "max_position_embeddings")
|
82 |
+
and hasattr(self.tokenizer, "model_max_length")
|
83 |
+
):
|
84 |
+
max_seq_length = min(self.jina_clip.config.max_position_embeddings, self.tokenizer.model_max_length)
|
85 |
+
|
86 |
+
self.max_seq_length = max_seq_length
|
87 |
+
|
88 |
+
if tokenizer_name_or_path is not None:
|
89 |
+
self.jina_clip.config.tokenizer_class = self.tokenizer.__class__.__name__
|
90 |
+
|
91 |
+
def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
92 |
+
"""Returns token_embeddings, cls_token"""
|
93 |
+
if "input_ids" in features:
|
94 |
+
embedding = self.jina_clip.get_text_features(input_ids=features["input_ids"])
|
95 |
+
else:
|
96 |
+
embedding = self.jina_clip.get_image_features(pixel_values=features["pixel_values"])
|
97 |
+
return {"sentence_embedding": embedding}
|
98 |
+
|
99 |
+
def get_word_embedding_dimension(self) -> int:
|
100 |
+
return self.config.text_config.embed_dim
|
101 |
+
|
102 |
+
def decode_data_image(data_image_str):
|
103 |
+
header, data = data_image_str.split(',', 1)
|
104 |
+
image_data = base64.b64decode(data)
|
105 |
+
return Image.open(BytesIO(image_data))
|
106 |
+
|
107 |
+
def tokenize(
|
108 |
+
self, batch: Union[List[str]], padding: Union[str, bool] = True
|
109 |
+
) -> Dict[str, torch.Tensor]:
|
110 |
+
"""Tokenizes a text and maps tokens to token-ids"""
|
111 |
+
images = []
|
112 |
+
texts = []
|
113 |
+
for sample in batch:
|
114 |
+
if isinstance(sample, str):
|
115 |
+
if sample.startswith('http'):
|
116 |
+
response = requests.get(sample)
|
117 |
+
images.append(Image.open(BytesIO(response.content)).convert('RGB'))
|
118 |
+
elif sample.startswith('data:image/'):
|
119 |
+
images.append(self.decode_data_image(sample).convert('RGB'))
|
120 |
+
else:
|
121 |
+
# TODO: Make sure that Image.open fails for non-image files
|
122 |
+
try:
|
123 |
+
images.append(Image.open(sample).convert('RGB'))
|
124 |
+
except:
|
125 |
+
texts.append(sample)
|
126 |
+
elif isinstance(sample, Image.Image):
|
127 |
+
images.append(sample.convert('RGB'))
|
128 |
+
|
129 |
+
if images and texts:
|
130 |
+
raise ValueError('Batch must contain either images or texts, not both')
|
131 |
+
|
132 |
+
if texts:
|
133 |
+
return self.tokenizer(
|
134 |
+
texts,
|
135 |
+
padding=padding,
|
136 |
+
truncation="longest_first",
|
137 |
+
return_tensors="pt",
|
138 |
+
max_length=self.max_seq_length,
|
139 |
+
)
|
140 |
+
elif images:
|
141 |
+
return self.preprocessor(images)
|
142 |
+
return {}
|
143 |
+
|
144 |
+
def save(self, output_path: str, safe_serialization: bool = True) -> None:
|
145 |
+
self.jina_clip.save_pretrained(output_path, safe_serialization=safe_serialization)
|
146 |
+
self.tokenizer.save_pretrained(output_path)
|
147 |
+
self.preprocessor.save_pretrained(output_path)
|
148 |
+
|
149 |
+
@staticmethod
|
150 |
+
def load(input_path: str) -> "Transformer":
|
151 |
+
# Old classes used other config names than 'sentence_bert_config.json'
|
152 |
+
for config_name in [
|
153 |
+
"sentence_bert_config.json",
|
154 |
+
"sentence_roberta_config.json",
|
155 |
+
"sentence_distilbert_config.json",
|
156 |
+
"sentence_camembert_config.json",
|
157 |
+
"sentence_albert_config.json",
|
158 |
+
"sentence_xlm-roberta_config.json",
|
159 |
+
"sentence_xlnet_config.json",
|
160 |
+
]:
|
161 |
+
sbert_config_path = os.path.join(input_path, config_name)
|
162 |
+
if os.path.exists(sbert_config_path):
|
163 |
+
break
|
164 |
+
|
165 |
+
with open(sbert_config_path) as fIn:
|
166 |
+
config = json.load(fIn)
|
167 |
+
# Don't allow configs to set trust_remote_code
|
168 |
+
if "model_args" in config and "trust_remote_code" in config["model_args"]:
|
169 |
+
config["model_args"].pop("trust_remote_code")
|
170 |
+
if "tokenizer_args" in config and "trust_remote_code" in config["tokenizer_args"]:
|
171 |
+
config["tokenizer_args"].pop("trust_remote_code")
|
172 |
+
if "config_args" in config and "trust_remote_code" in config["config_args"]:
|
173 |
+
config["config_args"].pop("trust_remote_code")
|
174 |
+
return Transformer(model_name_or_path=input_path, **config)
|
custom_st_2.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
class OtherClass:
|
3 |
+
pass
|
eva_model.py
ADDED
@@ -0,0 +1,764 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# Adapted from EVA CLIP
|
3 |
+
# https://github.com/baaivision/EVA/tree/master/EVA-CLIP/rei/eva_clip
|
4 |
+
# --------------------------------------------------------
|
5 |
+
|
6 |
+
import math
|
7 |
+
import os
|
8 |
+
from functools import partial
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
|
14 |
+
try:
|
15 |
+
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
|
16 |
+
except ImportError or ModuleNotFoundError:
|
17 |
+
from timm.layers import drop_path, to_2tuple, trunc_normal_
|
18 |
+
|
19 |
+
from .rope_embeddings import VisionRotaryEmbeddingFast
|
20 |
+
|
21 |
+
if os.getenv('ENV_TYPE') == 'deepspeed':
|
22 |
+
try:
|
23 |
+
from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
|
24 |
+
except ImportError or ModuleNotFoundError:
|
25 |
+
from torch.utils.checkpoint import checkpoint
|
26 |
+
else:
|
27 |
+
from torch.utils.checkpoint import checkpoint
|
28 |
+
|
29 |
+
try:
|
30 |
+
import xformers.ops as xops
|
31 |
+
except ImportError:
|
32 |
+
xops = None
|
33 |
+
|
34 |
+
|
35 |
+
class PatchDropout(nn.Module):
|
36 |
+
"""
|
37 |
+
https://arxiv.org/abs/2212.00794
|
38 |
+
"""
|
39 |
+
|
40 |
+
def __init__(self, prob, exclude_first_token=True):
|
41 |
+
super().__init__()
|
42 |
+
assert 0 <= prob < 1.0
|
43 |
+
self.prob = prob
|
44 |
+
self.exclude_first_token = exclude_first_token # exclude CLS token
|
45 |
+
|
46 |
+
def forward(self, x):
|
47 |
+
if not self.training or self.prob == 0.0:
|
48 |
+
return x
|
49 |
+
|
50 |
+
if self.exclude_first_token:
|
51 |
+
cls_tokens, x = x[:, :1], x[:, 1:]
|
52 |
+
else:
|
53 |
+
cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
|
54 |
+
|
55 |
+
batch = x.size()[0]
|
56 |
+
num_tokens = x.size()[1]
|
57 |
+
|
58 |
+
batch_indices = torch.arange(batch)
|
59 |
+
batch_indices = batch_indices[..., None]
|
60 |
+
|
61 |
+
keep_prob = 1 - self.prob
|
62 |
+
num_patches_keep = max(1, int(num_tokens * keep_prob))
|
63 |
+
|
64 |
+
rand = torch.randn(batch, num_tokens)
|
65 |
+
patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
|
66 |
+
|
67 |
+
x = x[batch_indices, patch_indices_keep]
|
68 |
+
|
69 |
+
if self.exclude_first_token:
|
70 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
71 |
+
|
72 |
+
return x, patch_indices_keep
|
73 |
+
|
74 |
+
|
75 |
+
class DropPath(nn.Module):
|
76 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of
|
77 |
+
residual blocks)."""
|
78 |
+
|
79 |
+
def __init__(self, drop_prob=None):
|
80 |
+
super(DropPath, self).__init__()
|
81 |
+
self.drop_prob = drop_prob
|
82 |
+
|
83 |
+
def forward(self, x):
|
84 |
+
return drop_path(x, self.drop_prob, self.training)
|
85 |
+
|
86 |
+
def extra_repr(self) -> str:
|
87 |
+
return 'p={}'.format(self.drop_prob)
|
88 |
+
|
89 |
+
|
90 |
+
class Mlp(nn.Module):
|
91 |
+
def __init__(
|
92 |
+
self,
|
93 |
+
in_features,
|
94 |
+
hidden_features=None,
|
95 |
+
out_features=None,
|
96 |
+
act_layer=nn.GELU,
|
97 |
+
norm_layer=nn.LayerNorm,
|
98 |
+
drop=0.0,
|
99 |
+
subln=False,
|
100 |
+
):
|
101 |
+
super().__init__()
|
102 |
+
out_features = out_features or in_features
|
103 |
+
hidden_features = hidden_features or in_features
|
104 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
105 |
+
self.act = act_layer()
|
106 |
+
|
107 |
+
self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
|
108 |
+
|
109 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
110 |
+
self.drop = nn.Dropout(drop)
|
111 |
+
|
112 |
+
def forward(self, x):
|
113 |
+
x = self.fc1(x)
|
114 |
+
x = self.act(x)
|
115 |
+
# x = self.drop(x)
|
116 |
+
# commit this for the orignal BERT implement
|
117 |
+
x = self.ffn_ln(x)
|
118 |
+
|
119 |
+
x = self.fc2(x)
|
120 |
+
x = self.drop(x)
|
121 |
+
return x
|
122 |
+
|
123 |
+
|
124 |
+
class SwiGLU(nn.Module):
|
125 |
+
def __init__(
|
126 |
+
self,
|
127 |
+
in_features,
|
128 |
+
hidden_features=None,
|
129 |
+
out_features=None,
|
130 |
+
act_layer=nn.SiLU,
|
131 |
+
drop=0.0,
|
132 |
+
norm_layer=nn.LayerNorm,
|
133 |
+
subln=False,
|
134 |
+
):
|
135 |
+
super().__init__()
|
136 |
+
out_features = out_features or in_features
|
137 |
+
hidden_features = hidden_features or in_features
|
138 |
+
|
139 |
+
self.w1 = nn.Linear(in_features, hidden_features)
|
140 |
+
self.w2 = nn.Linear(in_features, hidden_features)
|
141 |
+
|
142 |
+
self.act = act_layer()
|
143 |
+
self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
|
144 |
+
self.w3 = nn.Linear(hidden_features, out_features)
|
145 |
+
|
146 |
+
self.drop = nn.Dropout(drop)
|
147 |
+
|
148 |
+
def forward(self, x):
|
149 |
+
x1 = self.w1(x)
|
150 |
+
x2 = self.w2(x)
|
151 |
+
hidden = self.act(x1) * x2
|
152 |
+
x = self.ffn_ln(hidden)
|
153 |
+
x = self.w3(x)
|
154 |
+
x = self.drop(x)
|
155 |
+
return x
|
156 |
+
|
157 |
+
|
158 |
+
class Attention(nn.Module):
|
159 |
+
def __init__(
|
160 |
+
self,
|
161 |
+
dim,
|
162 |
+
num_heads=8,
|
163 |
+
qkv_bias=False,
|
164 |
+
qk_scale=None,
|
165 |
+
attn_drop=0.0,
|
166 |
+
proj_drop=0.0,
|
167 |
+
window_size=None,
|
168 |
+
attn_head_dim=None,
|
169 |
+
xattn=False,
|
170 |
+
rope=None,
|
171 |
+
subln=False,
|
172 |
+
norm_layer=nn.LayerNorm,
|
173 |
+
):
|
174 |
+
super().__init__()
|
175 |
+
self.num_heads = num_heads
|
176 |
+
head_dim = dim // num_heads
|
177 |
+
if attn_head_dim is not None:
|
178 |
+
head_dim = attn_head_dim
|
179 |
+
all_head_dim = head_dim * self.num_heads
|
180 |
+
self.scale = qk_scale or head_dim**-0.5
|
181 |
+
|
182 |
+
self.subln = subln
|
183 |
+
if self.subln:
|
184 |
+
self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
|
185 |
+
self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
|
186 |
+
self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
|
187 |
+
else:
|
188 |
+
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
189 |
+
|
190 |
+
if qkv_bias:
|
191 |
+
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
192 |
+
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
193 |
+
else:
|
194 |
+
self.q_bias = None
|
195 |
+
self.v_bias = None
|
196 |
+
|
197 |
+
if window_size:
|
198 |
+
self.window_size = window_size
|
199 |
+
self.num_relative_distance = (2 * window_size[0] - 1) * (
|
200 |
+
2 * window_size[1] - 1
|
201 |
+
) + 3
|
202 |
+
self.relative_position_bias_table = nn.Parameter(
|
203 |
+
torch.zeros(self.num_relative_distance, num_heads)
|
204 |
+
) # 2*Wh-1 * 2*Ww-1, nH
|
205 |
+
# cls to token & token 2 cls & cls to cls
|
206 |
+
|
207 |
+
# get pair-wise relative position index for each token inside the window
|
208 |
+
coords_h = torch.arange(window_size[0])
|
209 |
+
coords_w = torch.arange(window_size[1])
|
210 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
211 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
212 |
+
relative_coords = (
|
213 |
+
coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
214 |
+
) # 2, Wh*Ww, Wh*Ww
|
215 |
+
relative_coords = relative_coords.permute(
|
216 |
+
1, 2, 0
|
217 |
+
).contiguous() # Wh*Ww, Wh*Ww, 2
|
218 |
+
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
219 |
+
relative_coords[:, :, 1] += window_size[1] - 1
|
220 |
+
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
221 |
+
relative_position_index = torch.zeros(
|
222 |
+
size=(window_size[0] * window_size[1] + 1,) * 2,
|
223 |
+
dtype=relative_coords.dtype,
|
224 |
+
)
|
225 |
+
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
226 |
+
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
227 |
+
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
228 |
+
relative_position_index[0, 0] = self.num_relative_distance - 1
|
229 |
+
|
230 |
+
self.register_buffer('relative_position_index', relative_position_index)
|
231 |
+
else:
|
232 |
+
self.window_size = None
|
233 |
+
self.relative_position_bias_table = None
|
234 |
+
self.relative_position_index = None
|
235 |
+
|
236 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
237 |
+
self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
|
238 |
+
# self.proj = nn.Linear(all_head_dim, all_head_dim)
|
239 |
+
self.proj = nn.Linear(all_head_dim, dim)
|
240 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
241 |
+
self.xattn = xattn
|
242 |
+
self.xattn_drop = attn_drop
|
243 |
+
|
244 |
+
self.rope = rope
|
245 |
+
|
246 |
+
def forward(self, x, rel_pos_bias=None, attn_mask=None):
|
247 |
+
B, N, C = x.shape
|
248 |
+
if self.subln:
|
249 |
+
q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
|
250 |
+
k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
|
251 |
+
v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
|
252 |
+
|
253 |
+
q = q.reshape(B, N, self.num_heads, -1).permute(
|
254 |
+
0, 2, 1, 3
|
255 |
+
) # B, num_heads, N, C
|
256 |
+
k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
|
257 |
+
v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
|
258 |
+
else:
|
259 |
+
qkv_bias = None
|
260 |
+
if self.q_bias is not None:
|
261 |
+
qkv_bias = torch.cat(
|
262 |
+
(
|
263 |
+
self.q_bias,
|
264 |
+
torch.zeros_like(self.v_bias, requires_grad=False),
|
265 |
+
self.v_bias,
|
266 |
+
)
|
267 |
+
)
|
268 |
+
|
269 |
+
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
270 |
+
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(
|
271 |
+
2, 0, 3, 1, 4
|
272 |
+
) # 3, B, num_heads, N, C
|
273 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
274 |
+
|
275 |
+
if self.rope:
|
276 |
+
# slightly fast impl
|
277 |
+
q_t = q[:, :, 1:, :]
|
278 |
+
ro_q_t = self.rope(q_t)
|
279 |
+
q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
|
280 |
+
|
281 |
+
k_t = k[:, :, 1:, :]
|
282 |
+
ro_k_t = self.rope(k_t)
|
283 |
+
k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
|
284 |
+
|
285 |
+
if self.xattn:
|
286 |
+
if xops is None:
|
287 |
+
raise ValueError(
|
288 |
+
"Can't use xattn without xformers. Please 'pip install xformers'"
|
289 |
+
)
|
290 |
+
q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
|
291 |
+
k = k.permute(0, 2, 1, 3)
|
292 |
+
v = v.permute(0, 2, 1, 3)
|
293 |
+
|
294 |
+
x = xops.memory_efficient_attention(
|
295 |
+
q,
|
296 |
+
k,
|
297 |
+
v,
|
298 |
+
p=self.xattn_drop,
|
299 |
+
scale=self.scale,
|
300 |
+
)
|
301 |
+
x = x.reshape(B, N, -1)
|
302 |
+
x = self.inner_attn_ln(x)
|
303 |
+
x = self.proj(x)
|
304 |
+
x = self.proj_drop(x)
|
305 |
+
else:
|
306 |
+
q = q * self.scale
|
307 |
+
attn = q @ k.transpose(-2, -1)
|
308 |
+
|
309 |
+
if self.relative_position_bias_table is not None:
|
310 |
+
relative_position_bias = self.relative_position_bias_table[
|
311 |
+
self.relative_position_index.view(-1)
|
312 |
+
].view(
|
313 |
+
self.window_size[0] * self.window_size[1] + 1,
|
314 |
+
self.window_size[0] * self.window_size[1] + 1,
|
315 |
+
-1,
|
316 |
+
) # Wh*Ww,Wh*Ww,nH
|
317 |
+
relative_position_bias = relative_position_bias.permute(
|
318 |
+
2, 0, 1
|
319 |
+
).contiguous() # nH, Wh*Ww, Wh*Ww
|
320 |
+
attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
|
321 |
+
|
322 |
+
if rel_pos_bias is not None:
|
323 |
+
attn = attn + rel_pos_bias.type_as(attn)
|
324 |
+
|
325 |
+
if attn_mask is not None:
|
326 |
+
attn_mask = attn_mask.bool()
|
327 |
+
attn = attn.masked_fill(~attn_mask[:, None, None, :], float('-inf'))
|
328 |
+
|
329 |
+
attn = attn.softmax(dim=-1)
|
330 |
+
attn = self.attn_drop(attn)
|
331 |
+
|
332 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
333 |
+
x = self.inner_attn_ln(x)
|
334 |
+
x = self.proj(x)
|
335 |
+
x = self.proj_drop(x)
|
336 |
+
return x
|
337 |
+
|
338 |
+
|
339 |
+
class Block(nn.Module):
|
340 |
+
def __init__(
|
341 |
+
self,
|
342 |
+
dim,
|
343 |
+
num_heads,
|
344 |
+
mlp_ratio=4.0,
|
345 |
+
qkv_bias=False,
|
346 |
+
qk_scale=None,
|
347 |
+
drop=0.0,
|
348 |
+
attn_drop=0.0,
|
349 |
+
drop_path=0.0,
|
350 |
+
init_values=None,
|
351 |
+
act_layer=nn.GELU,
|
352 |
+
norm_layer=nn.LayerNorm,
|
353 |
+
window_size=None,
|
354 |
+
attn_head_dim=None,
|
355 |
+
xattn=False,
|
356 |
+
rope=None,
|
357 |
+
postnorm=False,
|
358 |
+
subln=False,
|
359 |
+
naiveswiglu=False,
|
360 |
+
):
|
361 |
+
super().__init__()
|
362 |
+
self.norm1 = norm_layer(dim)
|
363 |
+
self.attn = Attention(
|
364 |
+
dim,
|
365 |
+
num_heads=num_heads,
|
366 |
+
qkv_bias=qkv_bias,
|
367 |
+
qk_scale=qk_scale,
|
368 |
+
attn_drop=attn_drop,
|
369 |
+
proj_drop=drop,
|
370 |
+
window_size=window_size,
|
371 |
+
attn_head_dim=attn_head_dim,
|
372 |
+
xattn=xattn,
|
373 |
+
rope=rope,
|
374 |
+
subln=subln,
|
375 |
+
norm_layer=norm_layer,
|
376 |
+
)
|
377 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better
|
378 |
+
# than dropout here
|
379 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
380 |
+
self.norm2 = norm_layer(dim)
|
381 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
382 |
+
|
383 |
+
if naiveswiglu:
|
384 |
+
self.mlp = SwiGLU(
|
385 |
+
in_features=dim,
|
386 |
+
hidden_features=mlp_hidden_dim,
|
387 |
+
subln=subln,
|
388 |
+
norm_layer=norm_layer,
|
389 |
+
)
|
390 |
+
else:
|
391 |
+
self.mlp = Mlp(
|
392 |
+
in_features=dim,
|
393 |
+
hidden_features=mlp_hidden_dim,
|
394 |
+
act_layer=act_layer,
|
395 |
+
subln=subln,
|
396 |
+
drop=drop,
|
397 |
+
)
|
398 |
+
|
399 |
+
if init_values is not None and init_values > 0:
|
400 |
+
self.gamma_1 = nn.Parameter(
|
401 |
+
init_values * torch.ones((dim,)), requires_grad=True
|
402 |
+
)
|
403 |
+
self.gamma_2 = nn.Parameter(
|
404 |
+
init_values * torch.ones((dim,)), requires_grad=True
|
405 |
+
)
|
406 |
+
else:
|
407 |
+
self.gamma_1, self.gamma_2 = None, None
|
408 |
+
|
409 |
+
self.postnorm = postnorm
|
410 |
+
|
411 |
+
def forward(self, x, rel_pos_bias=None, attn_mask=None):
|
412 |
+
if self.gamma_1 is None:
|
413 |
+
if self.postnorm:
|
414 |
+
x = x + self.drop_path(
|
415 |
+
self.norm1(
|
416 |
+
self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)
|
417 |
+
)
|
418 |
+
)
|
419 |
+
x = x + self.drop_path(self.norm2(self.mlp(x)))
|
420 |
+
else:
|
421 |
+
x = x + self.drop_path(
|
422 |
+
self.attn(
|
423 |
+
self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask
|
424 |
+
)
|
425 |
+
)
|
426 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
427 |
+
else:
|
428 |
+
if self.postnorm:
|
429 |
+
x = x + self.drop_path(
|
430 |
+
self.gamma_1
|
431 |
+
* self.norm1(
|
432 |
+
self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)
|
433 |
+
)
|
434 |
+
)
|
435 |
+
x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
|
436 |
+
else:
|
437 |
+
x = x + self.drop_path(
|
438 |
+
self.gamma_1
|
439 |
+
* self.attn(
|
440 |
+
self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask
|
441 |
+
)
|
442 |
+
)
|
443 |
+
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
|
444 |
+
return x
|
445 |
+
|
446 |
+
|
447 |
+
class PatchEmbed(nn.Module):
|
448 |
+
"""Image to Patch Embedding"""
|
449 |
+
|
450 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
451 |
+
super().__init__()
|
452 |
+
img_size = to_2tuple(img_size)
|
453 |
+
patch_size = to_2tuple(patch_size)
|
454 |
+
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
|
455 |
+
self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
456 |
+
self.img_size = img_size
|
457 |
+
self.patch_size = patch_size
|
458 |
+
self.num_patches = num_patches
|
459 |
+
|
460 |
+
self.proj = nn.Conv2d(
|
461 |
+
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
|
462 |
+
)
|
463 |
+
|
464 |
+
def forward(self, x, **kwargs):
|
465 |
+
target_dtype = self.proj.weight.dtype
|
466 |
+
B, C, H, W = x.shape
|
467 |
+
# FIXME look at relaxing size constraints
|
468 |
+
assert H == self.img_size[0] and W == self.img_size[1], (
|
469 |
+
f"Input image size ({H}*{W}) doesn't match model "
|
470 |
+
f'({self.img_size[0]}*{self.img_size[1]}).'
|
471 |
+
)
|
472 |
+
x = self.proj(x.to(dtype=target_dtype)).flatten(2).transpose(1, 2)
|
473 |
+
return x
|
474 |
+
|
475 |
+
|
476 |
+
class RelativePositionBias(nn.Module):
|
477 |
+
def __init__(self, window_size, num_heads):
|
478 |
+
super().__init__()
|
479 |
+
self.window_size = window_size
|
480 |
+
self.num_relative_distance = (2 * window_size[0] - 1) * (
|
481 |
+
2 * window_size[1] - 1
|
482 |
+
) + 3
|
483 |
+
self.relative_position_bias_table = nn.Parameter(
|
484 |
+
torch.zeros(self.num_relative_distance, num_heads)
|
485 |
+
) # 2*Wh-1 * 2*Ww-1, nH
|
486 |
+
# cls to token & token 2 cls & cls to cls
|
487 |
+
|
488 |
+
# get pair-wise relative position index for each token inside the window
|
489 |
+
coords_h = torch.arange(window_size[0])
|
490 |
+
coords_w = torch.arange(window_size[1])
|
491 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
492 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
493 |
+
relative_coords = (
|
494 |
+
coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
495 |
+
) # 2, Wh*Ww, Wh*Ww
|
496 |
+
relative_coords = relative_coords.permute(
|
497 |
+
1, 2, 0
|
498 |
+
).contiguous() # Wh*Ww, Wh*Ww, 2
|
499 |
+
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
500 |
+
relative_coords[:, :, 1] += window_size[1] - 1
|
501 |
+
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
502 |
+
relative_position_index = torch.zeros(
|
503 |
+
size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
|
504 |
+
)
|
505 |
+
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
506 |
+
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
507 |
+
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
508 |
+
relative_position_index[0, 0] = self.num_relative_distance - 1
|
509 |
+
|
510 |
+
self.register_buffer('relative_position_index', relative_position_index)
|
511 |
+
|
512 |
+
def forward(self):
|
513 |
+
relative_position_bias = self.relative_position_bias_table[
|
514 |
+
self.relative_position_index.view(-1)
|
515 |
+
].view(
|
516 |
+
self.window_size[0] * self.window_size[1] + 1,
|
517 |
+
self.window_size[0] * self.window_size[1] + 1,
|
518 |
+
-1,
|
519 |
+
) # Wh*Ww,Wh*Ww,nH
|
520 |
+
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
521 |
+
|
522 |
+
|
523 |
+
class EVAVisionTransformer(nn.Module):
|
524 |
+
"""Vision Transformer with support for patch or hybrid CNN input stage"""
|
525 |
+
|
526 |
+
def __init__(
|
527 |
+
self,
|
528 |
+
img_size=224,
|
529 |
+
patch_size=16,
|
530 |
+
in_chans=3,
|
531 |
+
num_classes=0,
|
532 |
+
embed_dim=768,
|
533 |
+
depth=12,
|
534 |
+
num_heads=12,
|
535 |
+
mlp_ratio=4.0,
|
536 |
+
qkv_bias=False,
|
537 |
+
qk_scale=None,
|
538 |
+
drop_rate=0.0,
|
539 |
+
attn_drop_rate=0.0,
|
540 |
+
drop_path_rate=0.0,
|
541 |
+
norm_layer=nn.LayerNorm,
|
542 |
+
init_values=None,
|
543 |
+
patch_dropout=0.0,
|
544 |
+
use_abs_pos_emb=True,
|
545 |
+
use_rel_pos_bias=False,
|
546 |
+
use_shared_rel_pos_bias=False,
|
547 |
+
rope=False,
|
548 |
+
use_mean_pooling=True,
|
549 |
+
init_scale=0.001,
|
550 |
+
grad_checkpointing=False,
|
551 |
+
xattn=False,
|
552 |
+
postnorm=False,
|
553 |
+
pt_hw_seq_len=16,
|
554 |
+
intp_freq=False,
|
555 |
+
naiveswiglu=False,
|
556 |
+
subln=False,
|
557 |
+
proj_type=None,
|
558 |
+
):
|
559 |
+
super().__init__()
|
560 |
+
self.image_size = img_size
|
561 |
+
self.num_classes = num_classes
|
562 |
+
self.num_features = (
|
563 |
+
self.embed_dim
|
564 |
+
) = embed_dim # num_features for consistency with other models
|
565 |
+
|
566 |
+
self.patch_embed = PatchEmbed(
|
567 |
+
img_size=img_size,
|
568 |
+
patch_size=patch_size,
|
569 |
+
in_chans=in_chans,
|
570 |
+
embed_dim=embed_dim,
|
571 |
+
)
|
572 |
+
num_patches = self.patch_embed.num_patches
|
573 |
+
|
574 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
575 |
+
# self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
576 |
+
if use_abs_pos_emb:
|
577 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
578 |
+
else:
|
579 |
+
self.pos_embed = None
|
580 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
581 |
+
|
582 |
+
if use_shared_rel_pos_bias:
|
583 |
+
self.rel_pos_bias = RelativePositionBias(
|
584 |
+
window_size=self.patch_embed.patch_shape, num_heads=num_heads
|
585 |
+
)
|
586 |
+
else:
|
587 |
+
self.rel_pos_bias = None
|
588 |
+
|
589 |
+
if rope:
|
590 |
+
half_head_dim = embed_dim // num_heads // 2
|
591 |
+
hw_seq_len = img_size // patch_size
|
592 |
+
self.rope = VisionRotaryEmbeddingFast(
|
593 |
+
dim=half_head_dim,
|
594 |
+
pt_seq_len=pt_hw_seq_len,
|
595 |
+
ft_seq_len=hw_seq_len if intp_freq else None,
|
596 |
+
patch_dropout=patch_dropout,
|
597 |
+
)
|
598 |
+
else:
|
599 |
+
self.rope = None
|
600 |
+
|
601 |
+
self.naiveswiglu = naiveswiglu
|
602 |
+
|
603 |
+
dpr = [
|
604 |
+
x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
605 |
+
] # stochastic depth decay rule
|
606 |
+
self.use_rel_pos_bias = use_rel_pos_bias
|
607 |
+
self.blocks = nn.ModuleList(
|
608 |
+
[
|
609 |
+
Block(
|
610 |
+
dim=embed_dim,
|
611 |
+
num_heads=num_heads,
|
612 |
+
mlp_ratio=mlp_ratio,
|
613 |
+
qkv_bias=qkv_bias,
|
614 |
+
qk_scale=qk_scale,
|
615 |
+
drop=drop_rate,
|
616 |
+
attn_drop=attn_drop_rate,
|
617 |
+
drop_path=dpr[i],
|
618 |
+
norm_layer=norm_layer,
|
619 |
+
init_values=init_values,
|
620 |
+
window_size=self.patch_embed.patch_shape
|
621 |
+
if use_rel_pos_bias
|
622 |
+
else None,
|
623 |
+
xattn=xattn,
|
624 |
+
rope=self.rope,
|
625 |
+
postnorm=postnorm,
|
626 |
+
subln=subln,
|
627 |
+
naiveswiglu=naiveswiglu,
|
628 |
+
)
|
629 |
+
for i in range(depth)
|
630 |
+
]
|
631 |
+
)
|
632 |
+
self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
|
633 |
+
self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
|
634 |
+
if (num_classes == embed_dim) and (proj_type is None):
|
635 |
+
self.head = nn.Identity()
|
636 |
+
elif proj_type == 'linear':
|
637 |
+
self.head = nn.Linear(embed_dim, num_classes, bias=qkv_bias)
|
638 |
+
elif proj_type == 'mlp':
|
639 |
+
hidden_size = (embed_dim + num_classes) // 2
|
640 |
+
self.proj = nn.Sequential(
|
641 |
+
nn.Linear(embed_dim, hidden_size, bias=qkv_bias),
|
642 |
+
nn.GELU(),
|
643 |
+
nn.Linear(hidden_size, num_classes, bias=qkv_bias),
|
644 |
+
)
|
645 |
+
|
646 |
+
if self.pos_embed is not None:
|
647 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
648 |
+
|
649 |
+
trunc_normal_(self.cls_token, std=0.02)
|
650 |
+
|
651 |
+
self.apply(self._init_weights)
|
652 |
+
self.fix_init_weight()
|
653 |
+
|
654 |
+
if isinstance(self.head, nn.Linear):
|
655 |
+
trunc_normal_(self.head.weight, std=0.02)
|
656 |
+
self.head.weight.data.mul_(init_scale)
|
657 |
+
if qkv_bias:
|
658 |
+
self.head.bias.data.mul_(init_scale)
|
659 |
+
|
660 |
+
# setting a patch_dropout of 0. would mean it is disabled and this function
|
661 |
+
# would be the identity fn
|
662 |
+
self.patch_dropout = (
|
663 |
+
PatchDropout(patch_dropout) if patch_dropout > 0.0 else nn.Identity()
|
664 |
+
)
|
665 |
+
|
666 |
+
self.grad_checkpointing = grad_checkpointing
|
667 |
+
|
668 |
+
def fix_init_weight(self):
|
669 |
+
def rescale(param, layer_id):
|
670 |
+
param.div_(math.sqrt(2.0 * layer_id))
|
671 |
+
|
672 |
+
for layer_id, layer in enumerate(self.blocks):
|
673 |
+
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
674 |
+
if self.naiveswiglu:
|
675 |
+
rescale(layer.mlp.w3.weight.data, layer_id + 1)
|
676 |
+
else:
|
677 |
+
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
|
678 |
+
|
679 |
+
def get_cast_dtype(self) -> torch.dtype:
|
680 |
+
return self.blocks[0].mlp.fc2.weight.dtype
|
681 |
+
|
682 |
+
def _init_weights(self, m):
|
683 |
+
if isinstance(m, nn.Linear):
|
684 |
+
trunc_normal_(m.weight, std=0.02)
|
685 |
+
if m.bias is not None:
|
686 |
+
nn.init.constant_(m.bias, 0)
|
687 |
+
elif isinstance(m, nn.LayerNorm):
|
688 |
+
nn.init.constant_(m.bias, 0)
|
689 |
+
nn.init.constant_(m.weight, 1.0)
|
690 |
+
|
691 |
+
def get_num_layers(self):
|
692 |
+
return len(self.blocks)
|
693 |
+
|
694 |
+
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
695 |
+
assert (
|
696 |
+
unlocked_groups == 0
|
697 |
+
), 'partial locking not currently supported for this model'
|
698 |
+
for param in self.parameters():
|
699 |
+
param.requires_grad = False
|
700 |
+
|
701 |
+
@torch.jit.ignore
|
702 |
+
def set_grad_checkpointing(self, enable=True):
|
703 |
+
self.grad_checkpointing = enable
|
704 |
+
|
705 |
+
@torch.jit.ignore
|
706 |
+
def no_weight_decay(self):
|
707 |
+
return {'pos_embed', 'cls_token'}
|
708 |
+
|
709 |
+
def get_classifier(self):
|
710 |
+
return self.head
|
711 |
+
|
712 |
+
def reset_classifier(self, num_classes, global_pool=''):
|
713 |
+
self.num_classes = num_classes
|
714 |
+
self.head = (
|
715 |
+
nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
716 |
+
)
|
717 |
+
|
718 |
+
def forward_features(self, x, return_all_features=False):
|
719 |
+
x = self.patch_embed(x)
|
720 |
+
batch_size, seq_len, _ = x.size()
|
721 |
+
|
722 |
+
cls_tokens = self.cls_token.expand(
|
723 |
+
batch_size, -1, -1
|
724 |
+
) # stole cls_tokens impl from Phil Wang, thanks
|
725 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
726 |
+
if self.pos_embed is not None:
|
727 |
+
x = x + self.pos_embed
|
728 |
+
x = self.pos_drop(x)
|
729 |
+
|
730 |
+
# a patch_dropout of 0. would mean it is disabled and this function would do
|
731 |
+
# nothing but return what was passed in
|
732 |
+
if self.rope is not None:
|
733 |
+
if self.training and not isinstance(self.patch_dropout, nn.Identity):
|
734 |
+
x, patch_indices_keep = self.patch_dropout(x)
|
735 |
+
self.rope.forward = partial(
|
736 |
+
self.rope.forward, patch_indices_keep=patch_indices_keep
|
737 |
+
)
|
738 |
+
else:
|
739 |
+
self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
|
740 |
+
x = self.patch_dropout(x)
|
741 |
+
else:
|
742 |
+
x = self.patch_dropout(x)
|
743 |
+
|
744 |
+
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
|
745 |
+
for blk in self.blocks:
|
746 |
+
if self.grad_checkpointing:
|
747 |
+
x = checkpoint(blk, x, (rel_pos_bias,))
|
748 |
+
else:
|
749 |
+
x = blk(x, rel_pos_bias=rel_pos_bias)
|
750 |
+
|
751 |
+
if not return_all_features:
|
752 |
+
x = self.norm(x)
|
753 |
+
if self.fc_norm is not None:
|
754 |
+
return self.fc_norm(x.mean(1))
|
755 |
+
else:
|
756 |
+
return x[:, 0]
|
757 |
+
return x
|
758 |
+
|
759 |
+
def forward(self, x, return_all_features=False):
|
760 |
+
if return_all_features:
|
761 |
+
return self.forward_features(x, return_all_features)
|
762 |
+
x = self.forward_features(x)
|
763 |
+
x = self.head(x)
|
764 |
+
return x
|
hf_model.py
ADDED
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from typing import Dict, Optional, Tuple
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from transformers import AutoConfig, AutoModel, PretrainedConfig
|
7 |
+
from transformers.modeling_outputs import (
|
8 |
+
BaseModelOutput,
|
9 |
+
BaseModelOutputWithPooling,
|
10 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
11 |
+
)
|
12 |
+
|
13 |
+
"""
|
14 |
+
HF architecture mapping
|
15 |
+
"""
|
16 |
+
|
17 |
+
_HF_ARCH_DICT = {
|
18 |
+
# https://huggingface.co/docs/transformers/model_doc/roberta#roberta
|
19 |
+
'roberta': {
|
20 |
+
'config_names': {
|
21 |
+
'context_length': 'max_position_embeddings',
|
22 |
+
'vocab_size': 'vocab_size',
|
23 |
+
'width': 'hidden_size',
|
24 |
+
'heads': 'num_attention_heads',
|
25 |
+
'layers': 'num_hidden_layers',
|
26 |
+
'layer_attr': 'layer',
|
27 |
+
'token_embeddings_attr': 'embeddings',
|
28 |
+
},
|
29 |
+
'pooler': 'mean_pooler',
|
30 |
+
},
|
31 |
+
# https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
|
32 |
+
'xlm-roberta': {
|
33 |
+
'config_names': {
|
34 |
+
'context_length': 'max_position_embeddings',
|
35 |
+
'vocab_size': 'vocab_size',
|
36 |
+
'width': 'hidden_size',
|
37 |
+
'heads': 'num_attention_heads',
|
38 |
+
'layers': 'num_hidden_layers',
|
39 |
+
'layer_attr': 'layer',
|
40 |
+
'token_embeddings_attr': 'embeddings',
|
41 |
+
},
|
42 |
+
'pooler': 'mean_pooler',
|
43 |
+
},
|
44 |
+
# https://huggingface.co/docs/transformers/model_doc/mt5#mt5
|
45 |
+
'mt5': {
|
46 |
+
'config_names': {
|
47 |
+
# unlimited seqlen
|
48 |
+
# https://github.com/google-research/text-to-text-transfer-transformer/issues/273
|
49 |
+
# https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
|
50 |
+
'context_length': '',
|
51 |
+
'vocab_size': 'vocab_size',
|
52 |
+
'width': 'd_model',
|
53 |
+
'heads': 'num_heads',
|
54 |
+
'layers': 'num_layers',
|
55 |
+
'layer_attr': 'block',
|
56 |
+
'token_embeddings_attr': 'embed_tokens',
|
57 |
+
},
|
58 |
+
'pooler': 'mean_pooler',
|
59 |
+
},
|
60 |
+
# https://huggingface.co/docs/transformers/model_doc/bert
|
61 |
+
'bert': {
|
62 |
+
'config_names': {
|
63 |
+
'context_length': 'max_position_embeddings',
|
64 |
+
'vocab_size': 'vocab_size',
|
65 |
+
'width': 'hidden_size',
|
66 |
+
'heads': 'num_attention_heads',
|
67 |
+
'layers': 'num_hidden_layers',
|
68 |
+
},
|
69 |
+
'pooler': 'cls_pooler',
|
70 |
+
},
|
71 |
+
# https://huggingface.co/docs/transformers/model_doc/m2m_100
|
72 |
+
'm2m_100': {
|
73 |
+
'config_names': {
|
74 |
+
'context_length': 'max_position_embeddings',
|
75 |
+
'vocab_size': 'vocab_size',
|
76 |
+
'width': 'd_model',
|
77 |
+
'heads': 'encoder_attention_heads',
|
78 |
+
'layers': 'encoder_layers',
|
79 |
+
},
|
80 |
+
'pooler': 'cls_pooler',
|
81 |
+
},
|
82 |
+
}
|
83 |
+
|
84 |
+
|
85 |
+
"""
|
86 |
+
Pooling functions
|
87 |
+
"""
|
88 |
+
|
89 |
+
_POOLERS = {}
|
90 |
+
|
91 |
+
|
92 |
+
def _camel2snake(s):
|
93 |
+
return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
|
94 |
+
|
95 |
+
|
96 |
+
def register_pooler(cls):
|
97 |
+
"""Decorator registering pooler class"""
|
98 |
+
_POOLERS[_camel2snake(cls.__name__)] = cls
|
99 |
+
return cls
|
100 |
+
|
101 |
+
|
102 |
+
@register_pooler
|
103 |
+
class MeanPooler(nn.Module):
|
104 |
+
"""Mean pooling"""
|
105 |
+
|
106 |
+
@staticmethod
|
107 |
+
def forward(x: BaseModelOutput, attention_mask: torch.Tensor):
|
108 |
+
masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
|
109 |
+
return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
|
110 |
+
|
111 |
+
|
112 |
+
@register_pooler
|
113 |
+
class MaxPooler(nn.Module):
|
114 |
+
"""
|
115 |
+
Max pooling
|
116 |
+
"""
|
117 |
+
|
118 |
+
@staticmethod
|
119 |
+
def forward(x: BaseModelOutput, attention_mask: torch.Tensor):
|
120 |
+
masked_output = x.last_hidden_state.masked_fill(
|
121 |
+
attention_mask.unsqueeze(-1), -torch.inf
|
122 |
+
)
|
123 |
+
return masked_output.max(1).values
|
124 |
+
|
125 |
+
|
126 |
+
@register_pooler
|
127 |
+
class ClsPooler(nn.Module):
|
128 |
+
"""
|
129 |
+
CLS token pooling
|
130 |
+
"""
|
131 |
+
|
132 |
+
def __init__(self, use_pooler_output=True):
|
133 |
+
super().__init__()
|
134 |
+
self.cls_token_position = 0
|
135 |
+
self.use_pooler_output = use_pooler_output
|
136 |
+
|
137 |
+
def forward(self, x: BaseModelOutput, _: torch.Tensor):
|
138 |
+
if (
|
139 |
+
self.use_pooler_output
|
140 |
+
and isinstance(
|
141 |
+
x,
|
142 |
+
(
|
143 |
+
BaseModelOutputWithPooling,
|
144 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
145 |
+
),
|
146 |
+
)
|
147 |
+
and (x.pooler_output is not None)
|
148 |
+
):
|
149 |
+
return x.pooler_output
|
150 |
+
|
151 |
+
return x.last_hidden_state[:, self.cls_token_position, :]
|
152 |
+
|
153 |
+
|
154 |
+
"""
|
155 |
+
HF text model
|
156 |
+
"""
|
157 |
+
|
158 |
+
|
159 |
+
class HFTextEncoder(nn.Module):
|
160 |
+
output_tokens: torch.jit.Final[bool]
|
161 |
+
|
162 |
+
def __init__(
|
163 |
+
self,
|
164 |
+
model_name_or_path: str,
|
165 |
+
output_dim: int,
|
166 |
+
config: PretrainedConfig = None,
|
167 |
+
pooler_type: str = None,
|
168 |
+
proj_type: str = None,
|
169 |
+
proj_bias: bool = False,
|
170 |
+
pretrained: bool = True,
|
171 |
+
output_tokens: bool = False,
|
172 |
+
trust_remote_code: bool = False,
|
173 |
+
revision: Optional[str] = None,
|
174 |
+
model_config_kwargs: Optional[Dict] = None,
|
175 |
+
):
|
176 |
+
super().__init__()
|
177 |
+
self.output_tokens = output_tokens
|
178 |
+
self.output_dim = output_dim
|
179 |
+
|
180 |
+
# TODO: find better way to get this information
|
181 |
+
uses_transformer_pooler = pooler_type == 'cls_pooler'
|
182 |
+
model_config_kwargs = model_config_kwargs or {}
|
183 |
+
|
184 |
+
if config is None:
|
185 |
+
self.config = AutoConfig.from_pretrained(
|
186 |
+
model_name_or_path,
|
187 |
+
trust_remote_code=trust_remote_code,
|
188 |
+
code_revision=revision,
|
189 |
+
)
|
190 |
+
self.config.update(model_config_kwargs)
|
191 |
+
create_func, model_args = (
|
192 |
+
(AutoModel.from_pretrained, model_name_or_path)
|
193 |
+
if pretrained
|
194 |
+
else (AutoModel.from_config, self.config)
|
195 |
+
)
|
196 |
+
# TODO: do all model configs have this attribute?
|
197 |
+
# PretrainedConfig does so yes??
|
198 |
+
if (
|
199 |
+
hasattr(self.config, 'is_encoder_decoder')
|
200 |
+
and self.config.is_encoder_decoder
|
201 |
+
):
|
202 |
+
self.transformer = create_func(model_args)
|
203 |
+
self.transformer = self.transformer.encoder
|
204 |
+
else:
|
205 |
+
self.transformer = create_func(
|
206 |
+
model_args,
|
207 |
+
trust_remote_code=trust_remote_code,
|
208 |
+
add_pooling_layer=uses_transformer_pooler,
|
209 |
+
code_revision=revision,
|
210 |
+
)
|
211 |
+
else:
|
212 |
+
self.config = config
|
213 |
+
self.config.update(model_config_kwargs)
|
214 |
+
self.transformer = AutoModel.from_config(self.config)
|
215 |
+
|
216 |
+
if pooler_type is None: # get default arch pooler
|
217 |
+
pooler_type = _HF_ARCH_DICT[self.config.model_type]['pooler']
|
218 |
+
|
219 |
+
# FIXME downstream users of OpenCLIP models use these attr,
|
220 |
+
# need to verify valid across all models
|
221 |
+
self.vocab_size = getattr(self.config, 'vocab_size', 0)
|
222 |
+
self.context_length = getattr(self.config, 'max_position_embeddings', 0)
|
223 |
+
|
224 |
+
self.pooler = _POOLERS[pooler_type]()
|
225 |
+
|
226 |
+
d_model = getattr(
|
227 |
+
self.config, _HF_ARCH_DICT[self.config.model_type]['config_names']['width']
|
228 |
+
)
|
229 |
+
if (d_model == output_dim) and (proj_type is None): # do we always need a proj?
|
230 |
+
self.proj = nn.Identity()
|
231 |
+
elif proj_type == 'linear':
|
232 |
+
self.proj = nn.Linear(d_model, output_dim, bias=proj_bias)
|
233 |
+
elif proj_type == 'mlp':
|
234 |
+
hidden_size = (d_model + output_dim) // 2
|
235 |
+
self.proj = nn.Sequential(
|
236 |
+
nn.Linear(d_model, hidden_size, bias=proj_bias),
|
237 |
+
nn.GELU(),
|
238 |
+
nn.Linear(hidden_size, output_dim, bias=proj_bias),
|
239 |
+
)
|
240 |
+
|
241 |
+
def forward(self, x: torch.Tensor):
|
242 |
+
attn_mask = (x != self.config.pad_token_id).long()
|
243 |
+
out = self.transformer(input_ids=x, attention_mask=attn_mask)
|
244 |
+
pooled_out = self.pooler(out, attn_mask)
|
245 |
+
projected = self.proj(pooled_out)
|
246 |
+
|
247 |
+
seq_len = out.last_hidden_state.shape[1]
|
248 |
+
tokens = (
|
249 |
+
out.last_hidden_state[
|
250 |
+
:, torch.arange(seq_len) != self.pooler.cls_token_position, :
|
251 |
+
]
|
252 |
+
if isinstance(self.pooler, ClsPooler)
|
253 |
+
else out.last_hidden_state
|
254 |
+
)
|
255 |
+
|
256 |
+
if self.output_tokens:
|
257 |
+
return projected, tokens
|
258 |
+
return projected
|
259 |
+
|
260 |
+
def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
|
261 |
+
if not unlocked_layers: # full freezing
|
262 |
+
for n, p in self.transformer.named_parameters():
|
263 |
+
p.requires_grad = (
|
264 |
+
(not freeze_layer_norm) if 'LayerNorm' in n.split('.') else False
|
265 |
+
)
|
266 |
+
return
|
267 |
+
|
268 |
+
encoder = (
|
269 |
+
self.transformer.encoder
|
270 |
+
if hasattr(self.transformer, 'encoder')
|
271 |
+
else self.transformer
|
272 |
+
)
|
273 |
+
layer_list = getattr(
|
274 |
+
encoder, _HF_ARCH_DICT[self.config.model_type]['config_names']['layer_attr']
|
275 |
+
)
|
276 |
+
print(f'Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model')
|
277 |
+
embeddings = getattr(
|
278 |
+
self.transformer,
|
279 |
+
_HF_ARCH_DICT[self.config.model_type]['config_names'][
|
280 |
+
'token_embeddings_attr'
|
281 |
+
],
|
282 |
+
)
|
283 |
+
modules = [embeddings, *layer_list][:-unlocked_layers]
|
284 |
+
# freeze layers
|
285 |
+
for module in modules:
|
286 |
+
for n, p in module.named_parameters():
|
287 |
+
p.requires_grad = (
|
288 |
+
(not freeze_layer_norm) if 'LayerNorm' in n.split('.') else False
|
289 |
+
)
|
290 |
+
|
291 |
+
@torch.jit.ignore
|
292 |
+
def set_grad_checkpointing(self, _=True):
|
293 |
+
self.transformer.gradient_checkpointing_enable()
|
294 |
+
|
295 |
+
def init_parameters(self):
|
296 |
+
pass
|
297 |
+
|
modeling_clip.py
ADDED
@@ -0,0 +1,570 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
#
|
3 |
+
# Code mainly copied from:
|
4 |
+
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py
|
5 |
+
# and adjusted for Jina CLIP
|
6 |
+
|
7 |
+
from functools import partial
|
8 |
+
from typing import List, Optional, Tuple, Union
|
9 |
+
from io import BytesIO
|
10 |
+
import requests
|
11 |
+
import base64
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import torch.nn.functional as f
|
15 |
+
import torch.utils.checkpoint
|
16 |
+
from torch import nn
|
17 |
+
from transformers import (
|
18 |
+
AutoImageProcessor,
|
19 |
+
AutoTokenizer,
|
20 |
+
BatchEncoding,
|
21 |
+
BatchFeature,
|
22 |
+
PreTrainedModel,
|
23 |
+
logging,
|
24 |
+
)
|
25 |
+
from transformers.models.clip.modeling_clip import (
|
26 |
+
CLIPOutput,
|
27 |
+
CLIPTextModelOutput,
|
28 |
+
CLIPVisionModelOutput,
|
29 |
+
clip_loss,
|
30 |
+
)
|
31 |
+
|
32 |
+
try:
|
33 |
+
from tqdm.autonotebook import trange
|
34 |
+
|
35 |
+
has_tqdm = True
|
36 |
+
except ImportError:
|
37 |
+
has_tqdm = False
|
38 |
+
|
39 |
+
from .configuration_clip import JinaCLIPConfig, JinaCLIPTextConfig, JinaCLIPVisionConfig
|
40 |
+
from .eva_model import EVAVisionTransformer
|
41 |
+
from .hf_model import HFTextEncoder
|
42 |
+
# needed for HF to correctly import in cache
|
43 |
+
from .rope_embeddings import VisionRotaryEmbeddingFast # noqa: F401
|
44 |
+
from .transform import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD, image_transform # noqa: F401
|
45 |
+
|
46 |
+
logger = logging.get_logger(__name__)
|
47 |
+
|
48 |
+
|
49 |
+
""" Jina CLIP model implementation """
|
50 |
+
|
51 |
+
|
52 |
+
class LayerNorm(nn.LayerNorm):
|
53 |
+
"""Subclass torch's LayerNorm (with cast back to input dtype)."""
|
54 |
+
|
55 |
+
def forward(self, x: torch.Tensor):
|
56 |
+
origtype = x.dtype
|
57 |
+
x = f.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
58 |
+
return x.to(origtype)
|
59 |
+
|
60 |
+
|
61 |
+
def _build_text_tower(config: JinaCLIPTextConfig) -> HFTextEncoder:
|
62 |
+
return HFTextEncoder(
|
63 |
+
model_name_or_path=config.hf_model_name_or_path,
|
64 |
+
output_dim=config.embed_dim,
|
65 |
+
pooler_type=config.pooler_type,
|
66 |
+
proj_type=config.proj_type,
|
67 |
+
proj_bias=config.proj_bias,
|
68 |
+
pretrained=False,
|
69 |
+
output_tokens=False,
|
70 |
+
trust_remote_code=True,
|
71 |
+
revision=None,
|
72 |
+
model_config_kwargs=config.hf_model_config_kwargs,
|
73 |
+
)
|
74 |
+
|
75 |
+
|
76 |
+
def _build_vision_tower(config: JinaCLIPVisionConfig) -> EVAVisionTransformer:
|
77 |
+
norm_layer = partial(LayerNorm, eps=1e-6)
|
78 |
+
|
79 |
+
if config.fused_layer_norm:
|
80 |
+
try:
|
81 |
+
from apex.normalization import FusedLayerNorm
|
82 |
+
|
83 |
+
norm_layer = partial(FusedLayerNorm, eps=1e-6)
|
84 |
+
except (ModuleNotFoundError, ImportError):
|
85 |
+
logger.warning('Please install apex to use fused layer norm, ignoring')
|
86 |
+
|
87 |
+
return EVAVisionTransformer(
|
88 |
+
img_size=config.image_size,
|
89 |
+
patch_size=config.patch_size,
|
90 |
+
num_classes=config.embed_dim,
|
91 |
+
use_mean_pooling=False,
|
92 |
+
init_values=config.ls_init_value,
|
93 |
+
patch_dropout=config.patch_dropout,
|
94 |
+
embed_dim=config.width,
|
95 |
+
depth=config.layers,
|
96 |
+
num_heads=config.width // config.head_width,
|
97 |
+
mlp_ratio=config.mlp_ratio,
|
98 |
+
qkv_bias=config.qkv_bias,
|
99 |
+
drop_path_rate=config.drop_path_rate,
|
100 |
+
norm_layer=norm_layer,
|
101 |
+
xattn=config.x_attention,
|
102 |
+
rope=config.rope_embeddings,
|
103 |
+
postnorm=config.post_norm,
|
104 |
+
pt_hw_seq_len=config.pt_hw_seq_len,
|
105 |
+
intp_freq=config.intp_freq,
|
106 |
+
naiveswiglu=config.naive_swiglu,
|
107 |
+
subln=config.subln,
|
108 |
+
proj_type=config.proj_type,
|
109 |
+
)
|
110 |
+
|
111 |
+
|
112 |
+
class JinaCLIPPreTrainedModel(PreTrainedModel):
|
113 |
+
"""
|
114 |
+
An abstract class to handle weights initialization and a simple interface for
|
115 |
+
downloading and loading pretrained models.
|
116 |
+
"""
|
117 |
+
|
118 |
+
config_class = JinaCLIPConfig
|
119 |
+
base_model_prefix = 'clip'
|
120 |
+
supports_gradient_checkpointing = True
|
121 |
+
|
122 |
+
def _init_weights(self, module):
|
123 |
+
"""Initialize the weights"""
|
124 |
+
if isinstance(module, JinaCLIPModel):
|
125 |
+
if isinstance(module.text_projection, nn.Linear):
|
126 |
+
nn.init.normal_(
|
127 |
+
module.text_projection.weight,
|
128 |
+
std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
|
129 |
+
)
|
130 |
+
if isinstance(module.text_projection, nn.Linear):
|
131 |
+
nn.init.normal_(
|
132 |
+
module.visual_projection.weight,
|
133 |
+
std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
|
134 |
+
)
|
135 |
+
if isinstance(module, nn.LayerNorm):
|
136 |
+
module.bias.data.zero_()
|
137 |
+
module.weight.data.fill_(1.0)
|
138 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
139 |
+
module.bias.data.zero_()
|
140 |
+
|
141 |
+
|
142 |
+
class JinaCLIPTextModel(JinaCLIPPreTrainedModel):
|
143 |
+
config_class = JinaCLIPTextConfig
|
144 |
+
|
145 |
+
def __init__(self, config: JinaCLIPTextConfig):
|
146 |
+
super().__init__(config)
|
147 |
+
self.text_model = _build_text_tower(config)
|
148 |
+
self.post_init()
|
149 |
+
|
150 |
+
def forward(
|
151 |
+
self,
|
152 |
+
input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
|
153 |
+
return_dict: Optional[bool] = None,
|
154 |
+
*_,
|
155 |
+
**__,
|
156 |
+
) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPTextModelOutput]:
|
157 |
+
return_dict = (
|
158 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
159 |
+
)
|
160 |
+
x = input_ids.input_ids if isinstance(input_ids, BatchEncoding) else input_ids
|
161 |
+
feats = self.text_model(x=x)
|
162 |
+
out = CLIPTextModelOutput(text_embeds=feats)
|
163 |
+
return out if return_dict else out.to_tuple()
|
164 |
+
|
165 |
+
|
166 |
+
class JinaCLIPVisionModel(JinaCLIPPreTrainedModel):
|
167 |
+
config_class = JinaCLIPVisionConfig
|
168 |
+
main_input_name = 'pixel_values'
|
169 |
+
|
170 |
+
def __init__(self, config: JinaCLIPVisionConfig):
|
171 |
+
super().__init__(config)
|
172 |
+
self.vision_model = _build_vision_tower(config)
|
173 |
+
self.post_init()
|
174 |
+
|
175 |
+
def forward(
|
176 |
+
self,
|
177 |
+
pixel_values: Union[None, torch.FloatTensor, BatchFeature] = None,
|
178 |
+
return_dict: Optional[bool] = None,
|
179 |
+
*_,
|
180 |
+
**__,
|
181 |
+
) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPVisionModelOutput]:
|
182 |
+
return_dict = (
|
183 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
184 |
+
)
|
185 |
+
x = (
|
186 |
+
pixel_values.pixel_values
|
187 |
+
if isinstance(pixel_values, BatchFeature)
|
188 |
+
else pixel_values
|
189 |
+
)
|
190 |
+
feats = self.vision_model(x=x)
|
191 |
+
out = CLIPVisionModelOutput(image_embeds=feats)
|
192 |
+
return out if return_dict else out.to_tuple()
|
193 |
+
|
194 |
+
|
195 |
+
class JinaCLIPModel(JinaCLIPPreTrainedModel):
|
196 |
+
config_class = JinaCLIPConfig
|
197 |
+
|
198 |
+
def __init__(self, config: JinaCLIPConfig):
|
199 |
+
super().__init__(config)
|
200 |
+
|
201 |
+
if not isinstance(config.text_config, JinaCLIPTextConfig):
|
202 |
+
raise ValueError(
|
203 |
+
'Attribute config.text_config is expected to be of type '
|
204 |
+
f'JinaCLIPTextConfig but is of type {type(config.text_config)}.'
|
205 |
+
)
|
206 |
+
|
207 |
+
if not isinstance(config.vision_config, JinaCLIPVisionConfig):
|
208 |
+
raise ValueError(
|
209 |
+
'Attribute config.vision_config is expected to be of type '
|
210 |
+
f'JinaCLIPVisionConfig but is of type {type(config.vision_config)}.'
|
211 |
+
)
|
212 |
+
|
213 |
+
text_config = config.text_config
|
214 |
+
vision_config = config.vision_config
|
215 |
+
|
216 |
+
if config.use_text_flash_attn is not None:
|
217 |
+
text_config.hf_model_config_kwargs['use_flash_attn'] = config.use_text_flash_attn
|
218 |
+
if config.use_vision_xformers is not None:
|
219 |
+
vision_config.x_attention = config.use_vision_xformers
|
220 |
+
|
221 |
+
self.add_projections = config.add_projections
|
222 |
+
self.projection_dim = config.projection_dim
|
223 |
+
self.text_embed_dim = text_config.embed_dim
|
224 |
+
self.vision_embed_dim = vision_config.embed_dim
|
225 |
+
|
226 |
+
self.text_model = _build_text_tower(text_config)
|
227 |
+
self.vision_model = _build_vision_tower(vision_config)
|
228 |
+
self.logit_scale = nn.Parameter(
|
229 |
+
torch.tensor(self.config.logit_scale_init_value)
|
230 |
+
)
|
231 |
+
|
232 |
+
if self.add_projections:
|
233 |
+
self.visual_projection = nn.Linear(
|
234 |
+
self.vision_embed_dim, self.projection_dim, bias=False
|
235 |
+
)
|
236 |
+
self.text_projection = nn.Linear(
|
237 |
+
self.text_embed_dim, self.projection_dim, bias=False
|
238 |
+
)
|
239 |
+
else:
|
240 |
+
self.visual_projection = nn.Identity()
|
241 |
+
self.text_projection = nn.Identity()
|
242 |
+
|
243 |
+
self.tokenizer = None
|
244 |
+
self.preprocess = None
|
245 |
+
self.post_init()
|
246 |
+
|
247 |
+
def get_tokenizer(self):
|
248 |
+
if not self.tokenizer:
|
249 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
250 |
+
self.config._name_or_path, trust_remote_code=True
|
251 |
+
)
|
252 |
+
return self.tokenizer
|
253 |
+
|
254 |
+
def get_preprocess(self):
|
255 |
+
if not self.preprocess:
|
256 |
+
self.preprocess = AutoImageProcessor.from_pretrained(
|
257 |
+
self.config._name_or_path, trust_remote_code=True
|
258 |
+
)
|
259 |
+
return self.preprocess
|
260 |
+
|
261 |
+
def get_text_features(
|
262 |
+
self,
|
263 |
+
input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
|
264 |
+
*_,
|
265 |
+
**__,
|
266 |
+
) -> torch.FloatTensor:
|
267 |
+
x = input_ids.input_ids if isinstance(input_ids, BatchEncoding) else input_ids
|
268 |
+
return self.text_projection(self.text_model(x=x))
|
269 |
+
|
270 |
+
def get_image_features(
|
271 |
+
self,
|
272 |
+
pixel_values: Union[None, torch.FloatTensor, BatchFeature] = None,
|
273 |
+
*_,
|
274 |
+
**__,
|
275 |
+
) -> torch.FloatTensor:
|
276 |
+
x = (
|
277 |
+
pixel_values.pixel_values
|
278 |
+
if isinstance(pixel_values, BatchFeature)
|
279 |
+
else pixel_values
|
280 |
+
)
|
281 |
+
return self.visual_projection(self.vision_model(x=x))
|
282 |
+
|
283 |
+
@torch.inference_mode()
|
284 |
+
def encode_text(
|
285 |
+
self,
|
286 |
+
sentences: Union[str, List[str]],
|
287 |
+
batch_size: int = 32,
|
288 |
+
show_progress_bar: Optional[bool] = None,
|
289 |
+
convert_to_numpy: bool = True,
|
290 |
+
convert_to_tensor: bool = False,
|
291 |
+
device: Optional[torch.device] = None,
|
292 |
+
normalize_embeddings: bool = True,
|
293 |
+
**tokenizer_kwargs,
|
294 |
+
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
|
295 |
+
"""
|
296 |
+
Computes sentence embeddings
|
297 |
+
Args:
|
298 |
+
sentences(`str` or `List[str]`):
|
299 |
+
Sentence or sentences to be encoded
|
300 |
+
batch_size(`int`, *optional*, defaults to 32):
|
301 |
+
Batch size for the computation
|
302 |
+
show_progress_bar(`bool`, *optional*, defaults to None):
|
303 |
+
Show a progress bar when encoding sentences.
|
304 |
+
If set to None, progress bar is only shown when
|
305 |
+
`logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
|
306 |
+
convert_to_numpy(`bool`, *optional*, defaults to True):
|
307 |
+
If true, the output is a list of numpy vectors.
|
308 |
+
Else, it is a list of pytorch tensors.
|
309 |
+
convert_to_tensor(`bool`, *optional*, defaults to False):
|
310 |
+
If true, you get one large tensor as return.
|
311 |
+
Overwrites any setting from convert_to_numpy
|
312 |
+
device(`torch.device`, *optional*, defaults to None):
|
313 |
+
Which torch.device to use for the computation
|
314 |
+
normalize_embeddings(`bool`, *optional*, defaults to False):
|
315 |
+
If set to true, returned vectors will have length 1. In that case,
|
316 |
+
the faster dot-product (util.dot_score) instead of cosine similarity
|
317 |
+
can be used.
|
318 |
+
tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
|
319 |
+
Keyword arguments for the tokenizer
|
320 |
+
Returns:
|
321 |
+
By default, a list of tensors is returned.
|
322 |
+
If convert_to_tensor, a stacked tensor is returned.
|
323 |
+
If convert_to_numpy, a numpy matrix is returned.
|
324 |
+
"""
|
325 |
+
is_training = self.training
|
326 |
+
self.eval()
|
327 |
+
all_embeddings = []
|
328 |
+
|
329 |
+
self.tokenizer = self.get_tokenizer()
|
330 |
+
|
331 |
+
if show_progress_bar is None:
|
332 |
+
show_progress_bar = (
|
333 |
+
logger.getEffectiveLevel() == logging.INFO
|
334 |
+
or logger.getEffectiveLevel() == logging.DEBUG
|
335 |
+
)
|
336 |
+
|
337 |
+
if convert_to_tensor:
|
338 |
+
convert_to_numpy = False
|
339 |
+
|
340 |
+
input_was_string = False
|
341 |
+
if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
|
342 |
+
sentences = [sentences]
|
343 |
+
input_was_string = True
|
344 |
+
|
345 |
+
if device is not None:
|
346 |
+
self.to(device)
|
347 |
+
|
348 |
+
permutation = np.argsort([-len(i) for i in sentences])
|
349 |
+
inverse_permutation = np.argsort(permutation)
|
350 |
+
sentences = [sentences[idx] for idx in permutation]
|
351 |
+
|
352 |
+
tokenizer_kwargs['padding'] = tokenizer_kwargs.get('padding', True)
|
353 |
+
tokenizer_kwargs['max_length'] = tokenizer_kwargs.get('max_length', 512)
|
354 |
+
tokenizer_kwargs['truncation'] = tokenizer_kwargs.get('truncation', True)
|
355 |
+
|
356 |
+
if has_tqdm:
|
357 |
+
range_iter = trange(
|
358 |
+
0,
|
359 |
+
len(sentences),
|
360 |
+
batch_size,
|
361 |
+
desc='Encoding',
|
362 |
+
disable=not show_progress_bar,
|
363 |
+
)
|
364 |
+
else:
|
365 |
+
range_iter = range(0, len(sentences), batch_size)
|
366 |
+
|
367 |
+
for i in range_iter:
|
368 |
+
encoded_input = self.tokenizer(
|
369 |
+
sentences[i : i + batch_size],
|
370 |
+
return_tensors='pt',
|
371 |
+
**tokenizer_kwargs,
|
372 |
+
).to(self.device)
|
373 |
+
|
374 |
+
embeddings = self.get_text_features(input_ids=encoded_input)
|
375 |
+
if normalize_embeddings:
|
376 |
+
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
|
377 |
+
if convert_to_numpy:
|
378 |
+
embeddings = embeddings.cpu()
|
379 |
+
all_embeddings.extend(embeddings)
|
380 |
+
|
381 |
+
all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
|
382 |
+
|
383 |
+
if convert_to_tensor:
|
384 |
+
all_embeddings = torch.stack(all_embeddings)
|
385 |
+
elif convert_to_numpy:
|
386 |
+
all_embeddings = np.asarray([emb.to(torch.float32).numpy() for emb in all_embeddings])
|
387 |
+
|
388 |
+
if input_was_string:
|
389 |
+
all_embeddings = all_embeddings[0]
|
390 |
+
|
391 |
+
self.train(is_training)
|
392 |
+
return all_embeddings
|
393 |
+
|
394 |
+
def decode_data_image(data_image_str):
|
395 |
+
header, data = data_image_str.split(',', 1)
|
396 |
+
image_data = base64.b64decode(data)
|
397 |
+
return Image.open(BytesIO(image_data))
|
398 |
+
|
399 |
+
@torch.inference_mode()
|
400 |
+
def encode_image(
|
401 |
+
self,
|
402 |
+
images: Union[str, List[Union[str, "Image.Image"]]],
|
403 |
+
batch_size: int = 32,
|
404 |
+
show_progress_bar: Optional[bool] = None,
|
405 |
+
convert_to_numpy: bool = True,
|
406 |
+
convert_to_tensor: bool = False,
|
407 |
+
device: Optional[torch.device] = None,
|
408 |
+
normalize_embeddings: bool = True,
|
409 |
+
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
|
410 |
+
"""
|
411 |
+
Computes image embeddings.
|
412 |
+
|
413 |
+
Args:
|
414 |
+
images(`str` or `List[Union[str, Image.Image]]`):
|
415 |
+
image paths, URLs, PIL images, or data:image/ strings to be encoded
|
416 |
+
batch_size(`int`, *optional*, defaults to 32):
|
417 |
+
Batch size for the computation
|
418 |
+
show_progress_bar(`bool`, *optional*, defaults to None):
|
419 |
+
Show a progress bar when encoding images.
|
420 |
+
If set to None, progress bar is only shown when
|
421 |
+
`logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
|
422 |
+
convert_to_numpy(`bool`, *optional*, defaults to True):
|
423 |
+
If true, the output is a list of numpy vectors.
|
424 |
+
Else, it is a list of pytorch tensors.
|
425 |
+
convert_to_tensor(`bool`, *optional*, defaults to False):
|
426 |
+
If true, you get one large tensor as return.
|
427 |
+
Overwrites any setting from convert_to_numpy
|
428 |
+
device(`torch.device`, *optional*, defaults to None):
|
429 |
+
Which torch.device to use for the computation
|
430 |
+
normalize_embeddings(`bool`, *optional*, defaults to False):
|
431 |
+
If set to true, returned vectors will have length 1. In that case,
|
432 |
+
the faster dot-product (util.dot_score) instead of cosine similarity
|
433 |
+
can be used.
|
434 |
+
Returns:
|
435 |
+
By default, a list of tensors is returned.
|
436 |
+
If convert_to_tensor, a stacked tensor is returned.
|
437 |
+
If convert_to_numpy, a numpy matrix is returned.
|
438 |
+
"""
|
439 |
+
|
440 |
+
is_training = self.training
|
441 |
+
self.eval()
|
442 |
+
|
443 |
+
self.preprocess = self.get_preprocess()
|
444 |
+
all_embeddings = []
|
445 |
+
|
446 |
+
if show_progress_bar is None:
|
447 |
+
show_progress_bar = (
|
448 |
+
logger.getEffectiveLevel() == logging.INFO
|
449 |
+
or logger.getEffectiveLevel() == logging.DEBUG
|
450 |
+
)
|
451 |
+
|
452 |
+
if convert_to_tensor:
|
453 |
+
convert_to_numpy = False
|
454 |
+
|
455 |
+
input_was_single_img = False
|
456 |
+
if isinstance(images, str) or not hasattr(images, '__len__'):
|
457 |
+
images = [images]
|
458 |
+
input_was_single_img = True
|
459 |
+
|
460 |
+
if device is not None:
|
461 |
+
self.to(device)
|
462 |
+
|
463 |
+
permutation = np.argsort([-len(str(i)) for i in images])
|
464 |
+
inverse_permutation = np.argsort(permutation)
|
465 |
+
images = [images[idx] for idx in permutation]
|
466 |
+
|
467 |
+
if has_tqdm:
|
468 |
+
range_iter = trange(
|
469 |
+
0,
|
470 |
+
len(images),
|
471 |
+
batch_size,
|
472 |
+
desc='Encoding',
|
473 |
+
disable=not show_progress_bar,
|
474 |
+
)
|
475 |
+
else:
|
476 |
+
range_iter = range(0, len(images), batch_size)
|
477 |
+
|
478 |
+
from PIL import Image
|
479 |
+
|
480 |
+
for i in range_iter:
|
481 |
+
batch_images = images[i:i+batch_size]
|
482 |
+
processed_inputs = []
|
483 |
+
|
484 |
+
for img in batch_images:
|
485 |
+
if isinstance(img, str):
|
486 |
+
if img.startswith('http'):
|
487 |
+
response = requests.get(img)
|
488 |
+
image = Image.open(BytesIO(response.content)).convert('RGB')
|
489 |
+
elif img.startswith('data:image/'):
|
490 |
+
image = decode_data_image(img).convert('RGB')
|
491 |
+
else:
|
492 |
+
image = Image.open(img).convert('RGB')
|
493 |
+
elif isinstance(img, Image.Image):
|
494 |
+
image = img.convert('RGB')
|
495 |
+
else:
|
496 |
+
raise ValueError("Unsupported image format")
|
497 |
+
|
498 |
+
processed_inputs.append(image)
|
499 |
+
|
500 |
+
processed_inputs = self.preprocess(processed_inputs)
|
501 |
+
processed_inputs = processed_inputs.to(self.device)
|
502 |
+
embeddings = self.get_image_features(processed_inputs)
|
503 |
+
|
504 |
+
if normalize_embeddings:
|
505 |
+
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
|
506 |
+
if convert_to_numpy:
|
507 |
+
embeddings = embeddings.cpu()
|
508 |
+
all_embeddings.extend(embeddings)
|
509 |
+
|
510 |
+
all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
|
511 |
+
|
512 |
+
if convert_to_tensor:
|
513 |
+
all_embeddings = torch.stack(all_embeddings)
|
514 |
+
elif convert_to_numpy:
|
515 |
+
all_embeddings = np.asarray([emb.to(torch.float32).numpy() for emb in all_embeddings])
|
516 |
+
|
517 |
+
if input_was_single_img:
|
518 |
+
all_embeddings = all_embeddings[0]
|
519 |
+
|
520 |
+
self.train(is_training)
|
521 |
+
return all_embeddings
|
522 |
+
|
523 |
+
def forward(
|
524 |
+
self,
|
525 |
+
input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
|
526 |
+
pixel_values: Union[None, torch.FloatTensor, BatchFeature] = None,
|
527 |
+
return_dict: Optional[bool] = None,
|
528 |
+
return_loss: Optional[bool] = None,
|
529 |
+
*_,
|
530 |
+
**__,
|
531 |
+
) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPOutput]:
|
532 |
+
return_dict = (
|
533 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
534 |
+
)
|
535 |
+
image_embeds = self.get_image_features(pixel_values=pixel_values)
|
536 |
+
text_embeds = self.get_text_features(input_ids=input_ids)
|
537 |
+
|
538 |
+
# normalized features
|
539 |
+
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
|
540 |
+
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
|
541 |
+
|
542 |
+
# cosine similarity as logits
|
543 |
+
logit_scale = self.logit_scale.exp()
|
544 |
+
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
|
545 |
+
logits_per_image = logits_per_text.t()
|
546 |
+
|
547 |
+
loss = None
|
548 |
+
if return_loss:
|
549 |
+
loss = clip_loss(logits_per_text)
|
550 |
+
|
551 |
+
if not return_dict:
|
552 |
+
output = (
|
553 |
+
logits_per_image,
|
554 |
+
logits_per_text,
|
555 |
+
text_embeds,
|
556 |
+
image_embeds,
|
557 |
+
None,
|
558 |
+
None,
|
559 |
+
)
|
560 |
+
return ((loss,) + output) if loss is not None else output
|
561 |
+
|
562 |
+
return CLIPOutput(
|
563 |
+
loss=loss,
|
564 |
+
logits_per_image=logits_per_image,
|
565 |
+
logits_per_text=logits_per_text,
|
566 |
+
text_embeds=text_embeds,
|
567 |
+
image_embeds=image_embeds,
|
568 |
+
text_model_output=None,
|
569 |
+
vision_model_output=None,
|
570 |
+
)
|
modules.json
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{
|
3 |
+
"idx": 0,
|
4 |
+
"name": "0",
|
5 |
+
"path": "",
|
6 |
+
"type": "custom_st.Transformer"
|
7 |
+
}
|
8 |
+
]
|
preprocessor_config.json
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"auto_map": {
|
3 |
+
"AutoImageProcessor": "processing_clip.JinaCLIPImageProcessor",
|
4 |
+
"AutoProcessor": "jinaai/jina-clip-implementation--processing_clip.JinaCLIPProcessor"
|
5 |
+
},
|
6 |
+
"fill_color": 0,
|
7 |
+
"image_processor_type": "JinaCLIPImageProcessor",
|
8 |
+
"interpolation": "bicubic",
|
9 |
+
"mean": [
|
10 |
+
0.48145466,
|
11 |
+
0.4578275,
|
12 |
+
0.40821073
|
13 |
+
],
|
14 |
+
"processor_class": "JinaCLIPProcessor",
|
15 |
+
"resize_mode": "shortest",
|
16 |
+
"size": 224,
|
17 |
+
"std": [
|
18 |
+
0.26862954,
|
19 |
+
0.26130258,
|
20 |
+
0.27577711
|
21 |
+
]
|
22 |
+
}
|
processing_clip.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
#
|
3 |
+
# Code mainly copied from:
|
4 |
+
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/image_processing_clip.py
|
5 |
+
# and adjusted for Jina CLIP
|
6 |
+
|
7 |
+
from typing import Tuple, Union
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
|
11 |
+
from transformers.image_utils import ImageInput, make_list_of_images
|
12 |
+
from transformers.models.clip import CLIPProcessor
|
13 |
+
|
14 |
+
from .transform import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD, image_transform
|
15 |
+
|
16 |
+
""" Jina CLIP processor implementation """
|
17 |
+
|
18 |
+
|
19 |
+
class JinaCLIPProcessor(CLIPProcessor):
|
20 |
+
image_processor_class = 'AutoImageProcessor'
|
21 |
+
tokenizer_class = 'AutoTokenizer'
|
22 |
+
|
23 |
+
|
24 |
+
""" Jina CLIP image processor implementation """
|
25 |
+
|
26 |
+
|
27 |
+
class JinaCLIPImageProcessor(BaseImageProcessor):
|
28 |
+
model_input_names = ['pixel_values']
|
29 |
+
_valid_processor_keys = [
|
30 |
+
'size',
|
31 |
+
'mean',
|
32 |
+
'std',
|
33 |
+
'resize_mode',
|
34 |
+
'interpolation',
|
35 |
+
'fill_color',
|
36 |
+
]
|
37 |
+
|
38 |
+
def __init__(
|
39 |
+
self,
|
40 |
+
size: Union[int, Tuple[int, int]] = 224,
|
41 |
+
mean: Union[float, Tuple[float]] = OPENAI_DATASET_MEAN,
|
42 |
+
std: Union[float, Tuple[float]] = OPENAI_DATASET_STD,
|
43 |
+
resize_mode: str = 'shortest',
|
44 |
+
interpolation: str = 'bicubic',
|
45 |
+
fill_color: int = 0,
|
46 |
+
**kwargs,
|
47 |
+
) -> None:
|
48 |
+
super().__init__(**kwargs)
|
49 |
+
self.size = size
|
50 |
+
self.mean = mean
|
51 |
+
self.std = std
|
52 |
+
self.resize_mode = resize_mode
|
53 |
+
self.interpolation = interpolation
|
54 |
+
self.fill_color = fill_color
|
55 |
+
self.transform = self._build_transform()
|
56 |
+
|
57 |
+
def _build_transform(self):
|
58 |
+
return image_transform(
|
59 |
+
image_size=self.size,
|
60 |
+
is_train=False,
|
61 |
+
mean=self.mean,
|
62 |
+
std=self.std,
|
63 |
+
resize_mode=self.resize_mode,
|
64 |
+
interpolation=self.interpolation,
|
65 |
+
fill_color=self.fill_color,
|
66 |
+
aug_cfg=None,
|
67 |
+
)
|
68 |
+
|
69 |
+
def to_dict(self):
|
70 |
+
output = super().to_dict()
|
71 |
+
output.pop('transform')
|
72 |
+
return output
|
73 |
+
|
74 |
+
def preprocess(self, images: ImageInput, **kwargs) -> BatchFeature:
|
75 |
+
|
76 |
+
_transform_needs_rebuild = False
|
77 |
+
for k, v in kwargs.items():
|
78 |
+
if k in self._valid_processor_keys:
|
79 |
+
if v != getattr(self, k):
|
80 |
+
setattr(self, k, v)
|
81 |
+
_transform_needs_rebuild = True
|
82 |
+
|
83 |
+
if _transform_needs_rebuild:
|
84 |
+
self.transform = self._build_transform()
|
85 |
+
|
86 |
+
images = make_list_of_images(images)
|
87 |
+
out = torch.stack([self.transform(image) for image in images], dim=0)
|
88 |
+
return BatchFeature(data={'pixel_values': out})
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5af329d790c12cf109dabb4e31bf20e24dc07f8aab26509fb39004998cd9674e
|
3 |
+
size 890826430
|
rope_embeddings.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# Adapted from EVA CLIP
|
3 |
+
# https://github.com/baaivision/EVA/tree/master/EVA-CLIP/rei/eva_clip
|
4 |
+
# --------------------------------------------------------
|
5 |
+
|
6 |
+
import logging
|
7 |
+
from math import pi
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from einops import rearrange, repeat
|
11 |
+
from torch import nn
|
12 |
+
|
13 |
+
|
14 |
+
def broadcast(tensors, dim=-1):
|
15 |
+
num_tensors = len(tensors)
|
16 |
+
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
|
17 |
+
assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
|
18 |
+
shape_len = list(shape_lens)[0]
|
19 |
+
dim = (dim + shape_len) if dim < 0 else dim
|
20 |
+
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
|
21 |
+
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
22 |
+
assert all(
|
23 |
+
[*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
|
24 |
+
), 'invalid dimensions for broadcastable concatentation'
|
25 |
+
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
|
26 |
+
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
|
27 |
+
expanded_dims.insert(dim, (dim, dims[dim]))
|
28 |
+
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
|
29 |
+
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
|
30 |
+
return torch.cat(tensors, dim=dim)
|
31 |
+
|
32 |
+
|
33 |
+
def rotate_half(x):
|
34 |
+
x = rearrange(x, '... (d r) -> ... d r', r=2)
|
35 |
+
x1, x2 = x.unbind(dim=-1)
|
36 |
+
x = torch.stack((-x2, x1), dim=-1)
|
37 |
+
return rearrange(x, '... d r -> ... (d r)')
|
38 |
+
|
39 |
+
|
40 |
+
class VisionRotaryEmbedding(nn.Module):
|
41 |
+
def __init__(
|
42 |
+
self,
|
43 |
+
dim,
|
44 |
+
pt_seq_len,
|
45 |
+
ft_seq_len=None,
|
46 |
+
custom_freqs=None,
|
47 |
+
freqs_for='lang',
|
48 |
+
theta=10000,
|
49 |
+
max_freq=10,
|
50 |
+
num_freqs=1,
|
51 |
+
):
|
52 |
+
super().__init__()
|
53 |
+
if custom_freqs:
|
54 |
+
freqs = custom_freqs
|
55 |
+
elif freqs_for == 'lang':
|
56 |
+
freqs = 1.0 / (
|
57 |
+
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
|
58 |
+
)
|
59 |
+
elif freqs_for == 'pixel':
|
60 |
+
freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
|
61 |
+
elif freqs_for == 'constant':
|
62 |
+
freqs = torch.ones(num_freqs).float()
|
63 |
+
else:
|
64 |
+
raise ValueError(f'unknown modality {freqs_for}')
|
65 |
+
|
66 |
+
if ft_seq_len is None:
|
67 |
+
ft_seq_len = pt_seq_len
|
68 |
+
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
|
69 |
+
|
70 |
+
freqs_h = torch.einsum('..., f -> ... f', t, freqs)
|
71 |
+
freqs_h = repeat(freqs_h, '... n -> ... (n r)', r=2)
|
72 |
+
|
73 |
+
freqs_w = torch.einsum('..., f -> ... f', t, freqs)
|
74 |
+
freqs_w = repeat(freqs_w, '... n -> ... (n r)', r=2)
|
75 |
+
|
76 |
+
freqs = broadcast((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)
|
77 |
+
|
78 |
+
self.register_buffer('freqs_cos', freqs.cos())
|
79 |
+
self.register_buffer('freqs_sin', freqs.sin())
|
80 |
+
|
81 |
+
logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
|
82 |
+
|
83 |
+
def forward(self, t, start_index=0):
|
84 |
+
rot_dim = self.freqs_cos.shape[-1]
|
85 |
+
end_index = start_index + rot_dim
|
86 |
+
assert rot_dim <= t.shape[-1], (
|
87 |
+
f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in '
|
88 |
+
f'all the positions {rot_dim}'
|
89 |
+
)
|
90 |
+
t_left, t, t_right = (
|
91 |
+
t[..., :start_index],
|
92 |
+
t[..., start_index:end_index],
|
93 |
+
t[..., end_index:],
|
94 |
+
)
|
95 |
+
t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
|
96 |
+
|
97 |
+
return torch.cat((t_left, t, t_right), dim=-1)
|
98 |
+
|
99 |
+
|
100 |
+
class VisionRotaryEmbeddingFast(nn.Module):
|
101 |
+
def __init__(
|
102 |
+
self,
|
103 |
+
dim,
|
104 |
+
pt_seq_len,
|
105 |
+
ft_seq_len=None,
|
106 |
+
custom_freqs=None,
|
107 |
+
freqs_for='lang',
|
108 |
+
theta=10000,
|
109 |
+
max_freq=10,
|
110 |
+
num_freqs=1,
|
111 |
+
patch_dropout=0.0,
|
112 |
+
):
|
113 |
+
super().__init__()
|
114 |
+
if custom_freqs:
|
115 |
+
freqs = custom_freqs
|
116 |
+
elif freqs_for == 'lang':
|
117 |
+
freqs = 1.0 / (
|
118 |
+
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
|
119 |
+
)
|
120 |
+
elif freqs_for == 'pixel':
|
121 |
+
freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
|
122 |
+
elif freqs_for == 'constant':
|
123 |
+
freqs = torch.ones(num_freqs).float()
|
124 |
+
else:
|
125 |
+
raise ValueError(f'unknown modality {freqs_for}')
|
126 |
+
|
127 |
+
if ft_seq_len is None:
|
128 |
+
ft_seq_len = pt_seq_len
|
129 |
+
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
|
130 |
+
|
131 |
+
freqs = torch.einsum('..., f -> ... f', t, freqs)
|
132 |
+
freqs = repeat(freqs, '... n -> ... (n r)', r=2)
|
133 |
+
freqs = broadcast((freqs[:, None, :], freqs[None, :, :]), dim=-1)
|
134 |
+
|
135 |
+
freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
|
136 |
+
freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
|
137 |
+
|
138 |
+
self.patch_dropout = patch_dropout
|
139 |
+
|
140 |
+
self.register_buffer('freqs_cos', freqs_cos)
|
141 |
+
self.register_buffer('freqs_sin', freqs_sin)
|
142 |
+
|
143 |
+
logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
|
144 |
+
|
145 |
+
def forward(self, t, patch_indices_keep=None):
|
146 |
+
if patch_indices_keep is not None:
|
147 |
+
batch = t.size()[0]
|
148 |
+
batch_indices = torch.arange(batch)
|
149 |
+
batch_indices = batch_indices[..., None]
|
150 |
+
|
151 |
+
freqs_cos = repeat(
|
152 |
+
self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1]
|
153 |
+
)
|
154 |
+
freqs_sin = repeat(
|
155 |
+
self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1]
|
156 |
+
)
|
157 |
+
|
158 |
+
freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
|
159 |
+
freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j')
|
160 |
+
freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
|
161 |
+
freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j')
|
162 |
+
|
163 |
+
return t * freqs_cos + rotate_half(t) * freqs_sin
|
164 |
+
|
165 |
+
return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
|
special_tokens_map.json
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cls_token": {
|
3 |
+
"content": "[CLS]",
|
4 |
+
"lstrip": false,
|
5 |
+
"normalized": false,
|
6 |
+
"rstrip": false,
|
7 |
+
"single_word": false
|
8 |
+
},
|
9 |
+
"mask_token": {
|
10 |
+
"content": "[MASK]",
|
11 |
+
"lstrip": false,
|
12 |
+
"normalized": false,
|
13 |
+
"rstrip": false,
|
14 |
+
"single_word": false
|
15 |
+
},
|
16 |
+
"pad_token": {
|
17 |
+
"content": "[PAD]",
|
18 |
+
"lstrip": false,
|
19 |
+
"normalized": false,
|
20 |
+
"rstrip": false,
|
21 |
+
"single_word": false
|
22 |
+
},
|
23 |
+
"sep_token": {
|
24 |
+
"content": "[SEP]",
|
25 |
+
"lstrip": false,
|
26 |
+
"normalized": false,
|
27 |
+
"rstrip": false,
|
28 |
+
"single_word": false
|
29 |
+
},
|
30 |
+
"unk_token": {
|
31 |
+
"content": "[UNK]",
|
32 |
+
"lstrip": false,
|
33 |
+
"normalized": false,
|
34 |
+
"rstrip": false,
|
35 |
+
"single_word": false
|
36 |
+
}
|
37 |
+
}
|
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tokenizer_config.json
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"added_tokens_decoder": {
|
3 |
+
"0": {
|
4 |
+
"content": "[PAD]",
|
5 |
+
"lstrip": false,
|
6 |
+
"normalized": false,
|
7 |
+
"rstrip": false,
|
8 |
+
"single_word": false,
|
9 |
+
"special": true
|
10 |
+
},
|
11 |
+
"100": {
|
12 |
+
"content": "[UNK]",
|
13 |
+
"lstrip": false,
|
14 |
+
"normalized": false,
|
15 |
+
"rstrip": false,
|
16 |
+
"single_word": false,
|
17 |
+
"special": true
|
18 |
+
},
|
19 |
+
"101": {
|
20 |
+
"content": "[CLS]",
|
21 |
+
"lstrip": false,
|
22 |
+
"normalized": false,
|
23 |
+
"rstrip": false,
|
24 |
+
"single_word": false,
|
25 |
+
"special": true
|
26 |
+
},
|
27 |
+
"102": {
|
28 |
+
"content": "[SEP]",
|
29 |
+
"lstrip": false,
|
30 |
+
"normalized": false,
|
31 |
+
"rstrip": false,
|
32 |
+
"single_word": false,
|
33 |
+
"special": true
|
34 |
+
},
|
35 |
+
"103": {
|
36 |
+
"content": "[MASK]",
|
37 |
+
"lstrip": false,
|
38 |
+
"normalized": false,
|
39 |
+
"rstrip": false,
|
40 |
+
"single_word": false,
|
41 |
+
"special": true
|
42 |
+
}
|
43 |
+
},
|
44 |
+
"clean_up_tokenization_spaces": true,
|
45 |
+
"cls_token": "[CLS]",
|
46 |
+
"do_basic_tokenize": true,
|
47 |
+
"do_lower_case": true,
|
48 |
+
"mask_token": "[MASK]",
|
49 |
+
"max_length": 8192,
|
50 |
+
"model_max_length": 8192,
|
51 |
+
"never_split": null,
|
52 |
+
"pad_to_multiple_of": null,
|
53 |
+
"pad_token": "[PAD]",
|
54 |
+
"pad_token_type_id": 0,
|
55 |
+
"padding_side": "right",
|
56 |
+
"sep_token": "[SEP]",
|
57 |
+
"stride": 0,
|
58 |
+
"strip_accents": null,
|
59 |
+
"tokenize_chinese_chars": true,
|
60 |
+
"tokenizer_class": "BertTokenizer",
|
61 |
+
"truncation_side": "right",
|
62 |
+
"truncation_strategy": "longest_first",
|
63 |
+
"unk_token": "[UNK]"
|
64 |
+
}
|
transform.py
ADDED
@@ -0,0 +1,458 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numbers
|
2 |
+
import random
|
3 |
+
import warnings
|
4 |
+
from dataclasses import asdict, dataclass
|
5 |
+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torchvision.transforms.functional as F
|
9 |
+
from torchvision.transforms import (
|
10 |
+
CenterCrop,
|
11 |
+
ColorJitter,
|
12 |
+
Compose,
|
13 |
+
Grayscale,
|
14 |
+
InterpolationMode,
|
15 |
+
Normalize,
|
16 |
+
RandomResizedCrop,
|
17 |
+
Resize,
|
18 |
+
ToTensor,
|
19 |
+
)
|
20 |
+
from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
21 |
+
|
22 |
+
OPENAI_DATASET_MEAN = tuple(OPENAI_CLIP_MEAN)
|
23 |
+
OPENAI_DATASET_STD = tuple(OPENAI_CLIP_STD)
|
24 |
+
|
25 |
+
|
26 |
+
@dataclass
|
27 |
+
class PreprocessCfg:
|
28 |
+
size: Union[int, Tuple[int, int]] = 224
|
29 |
+
mode: str = 'RGB'
|
30 |
+
mean: Tuple[float, ...] = OPENAI_DATASET_MEAN
|
31 |
+
std: Tuple[float, ...] = OPENAI_DATASET_STD
|
32 |
+
interpolation: str = 'bicubic'
|
33 |
+
resize_mode: str = 'shortest'
|
34 |
+
fill_color: int = 0
|
35 |
+
|
36 |
+
def __post_init__(self):
|
37 |
+
assert self.mode in ('RGB',)
|
38 |
+
|
39 |
+
@property
|
40 |
+
def num_channels(self):
|
41 |
+
return 3
|
42 |
+
|
43 |
+
@property
|
44 |
+
def input_size(self):
|
45 |
+
return (self.num_channels,) + (self.size, self.size)
|
46 |
+
|
47 |
+
|
48 |
+
_PREPROCESS_KEYS = set(asdict(PreprocessCfg()).keys())
|
49 |
+
|
50 |
+
|
51 |
+
def merge_preprocess_dict(
|
52 |
+
base: Union[PreprocessCfg, Dict],
|
53 |
+
overlay: Dict,
|
54 |
+
):
|
55 |
+
"""Merge overlay key-value pairs on top of base preprocess cfg or dict.
|
56 |
+
Input dicts are filtered based on PreprocessCfg fields.
|
57 |
+
"""
|
58 |
+
if isinstance(base, PreprocessCfg):
|
59 |
+
base_clean = asdict(base)
|
60 |
+
else:
|
61 |
+
base_clean = {k: v for k, v in base.items() if k in _PREPROCESS_KEYS}
|
62 |
+
if overlay:
|
63 |
+
overlay_clean = {
|
64 |
+
k: v for k, v in overlay.items() if k in _PREPROCESS_KEYS and v is not None
|
65 |
+
}
|
66 |
+
base_clean.update(overlay_clean)
|
67 |
+
return base_clean
|
68 |
+
|
69 |
+
|
70 |
+
def merge_preprocess_kwargs(base: Union[PreprocessCfg, Dict], **kwargs):
|
71 |
+
return merge_preprocess_dict(base, kwargs)
|
72 |
+
|
73 |
+
|
74 |
+
@dataclass
|
75 |
+
class AugmentationCfg:
|
76 |
+
scale: Tuple[float, float] = (0.9, 1.0)
|
77 |
+
ratio: Optional[Tuple[float, float]] = None
|
78 |
+
color_jitter: Optional[
|
79 |
+
Union[float, Tuple[float, float, float], Tuple[float, float, float, float]]
|
80 |
+
] = None
|
81 |
+
re_prob: Optional[float] = None
|
82 |
+
re_count: Optional[int] = None
|
83 |
+
use_timm: bool = False
|
84 |
+
|
85 |
+
# params for simclr_jitter_gray
|
86 |
+
color_jitter_prob: float = None
|
87 |
+
gray_scale_prob: float = None
|
88 |
+
|
89 |
+
|
90 |
+
def _setup_size(size, error_msg):
|
91 |
+
if isinstance(size, numbers.Number):
|
92 |
+
return int(size), int(size)
|
93 |
+
|
94 |
+
if isinstance(size, Sequence) and len(size) == 1:
|
95 |
+
return size[0], size[0]
|
96 |
+
|
97 |
+
if len(size) != 2:
|
98 |
+
raise ValueError(error_msg)
|
99 |
+
|
100 |
+
return size
|
101 |
+
|
102 |
+
|
103 |
+
class ResizeKeepRatio:
|
104 |
+
"""Resize and Keep Ratio
|
105 |
+
|
106 |
+
Copy & paste from `timm`
|
107 |
+
"""
|
108 |
+
|
109 |
+
def __init__(
|
110 |
+
self,
|
111 |
+
size,
|
112 |
+
longest=0.0,
|
113 |
+
interpolation=InterpolationMode.BICUBIC,
|
114 |
+
random_scale_prob=0.0,
|
115 |
+
random_scale_range=(0.85, 1.05),
|
116 |
+
random_aspect_prob=0.0,
|
117 |
+
random_aspect_range=(0.9, 1.11),
|
118 |
+
):
|
119 |
+
if isinstance(size, (list, tuple)):
|
120 |
+
self.size = tuple(size)
|
121 |
+
else:
|
122 |
+
self.size = (size, size)
|
123 |
+
self.interpolation = interpolation
|
124 |
+
self.longest = float(longest) # [0, 1] where 0 == shortest edge, 1 == longest
|
125 |
+
self.random_scale_prob = random_scale_prob
|
126 |
+
self.random_scale_range = random_scale_range
|
127 |
+
self.random_aspect_prob = random_aspect_prob
|
128 |
+
self.random_aspect_range = random_aspect_range
|
129 |
+
|
130 |
+
@staticmethod
|
131 |
+
def get_params(
|
132 |
+
img,
|
133 |
+
target_size,
|
134 |
+
longest,
|
135 |
+
random_scale_prob=0.0,
|
136 |
+
random_scale_range=(0.85, 1.05),
|
137 |
+
random_aspect_prob=0.0,
|
138 |
+
random_aspect_range=(0.9, 1.11),
|
139 |
+
):
|
140 |
+
"""Get parameters"""
|
141 |
+
source_size = img.size[::-1] # h, w
|
142 |
+
h, w = source_size
|
143 |
+
target_h, target_w = target_size
|
144 |
+
ratio_h = h / target_h
|
145 |
+
ratio_w = w / target_w
|
146 |
+
ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (
|
147 |
+
1.0 - longest
|
148 |
+
)
|
149 |
+
if random_scale_prob > 0 and random.random() < random_scale_prob:
|
150 |
+
ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
|
151 |
+
ratio_factor = (ratio_factor, ratio_factor)
|
152 |
+
else:
|
153 |
+
ratio_factor = (1.0, 1.0)
|
154 |
+
if random_aspect_prob > 0 and random.random() < random_aspect_prob:
|
155 |
+
aspect_factor = random.uniform(
|
156 |
+
random_aspect_range[0], random_aspect_range[1]
|
157 |
+
)
|
158 |
+
ratio_factor = (
|
159 |
+
ratio_factor[0] / aspect_factor,
|
160 |
+
ratio_factor[1] * aspect_factor,
|
161 |
+
)
|
162 |
+
size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)]
|
163 |
+
return size
|
164 |
+
|
165 |
+
def __call__(self, img):
|
166 |
+
"""
|
167 |
+
Args:
|
168 |
+
img (PIL Image): Image to be cropped and resized.
|
169 |
+
|
170 |
+
Returns:
|
171 |
+
PIL Image: Resized, padded to at least target size, possibly
|
172 |
+
cropped to exactly target size
|
173 |
+
"""
|
174 |
+
size = self.get_params(
|
175 |
+
img,
|
176 |
+
self.size,
|
177 |
+
self.longest,
|
178 |
+
self.random_scale_prob,
|
179 |
+
self.random_scale_range,
|
180 |
+
self.random_aspect_prob,
|
181 |
+
self.random_aspect_range,
|
182 |
+
)
|
183 |
+
img = F.resize(img, size, self.interpolation)
|
184 |
+
return img
|
185 |
+
|
186 |
+
def __repr__(self):
|
187 |
+
format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
|
188 |
+
format_string += f', interpolation={self.interpolation})'
|
189 |
+
format_string += f', longest={self.longest:.3f})'
|
190 |
+
return format_string
|
191 |
+
|
192 |
+
|
193 |
+
def center_crop_or_pad(
|
194 |
+
img: torch.Tensor, output_size: List[int], fill=0
|
195 |
+
) -> torch.Tensor:
|
196 |
+
"""Center crops and/or pads the given image.
|
197 |
+
If the image is torch Tensor, it is expected
|
198 |
+
to have [..., H, W] shape, where ... means an arbitrary number of leading
|
199 |
+
dimensions. If image size is smaller than output size along any edge, image is
|
200 |
+
padded with 0 and then center cropped.
|
201 |
+
|
202 |
+
Args:
|
203 |
+
img (PIL Image or Tensor): Image to be cropped.
|
204 |
+
output_size (sequence or int): (height, width) of the crop box. If int or
|
205 |
+
sequence with single int, it is used for both directions.
|
206 |
+
fill (int, Tuple[int]): Padding color
|
207 |
+
|
208 |
+
Returns:
|
209 |
+
PIL Image or Tensor: Cropped image.
|
210 |
+
"""
|
211 |
+
if isinstance(output_size, numbers.Number):
|
212 |
+
output_size = (int(output_size), int(output_size))
|
213 |
+
elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
|
214 |
+
output_size = (output_size[0], output_size[0])
|
215 |
+
|
216 |
+
_, image_height, image_width = F.get_dimensions(img)
|
217 |
+
crop_height, crop_width = output_size
|
218 |
+
|
219 |
+
if crop_width > image_width or crop_height > image_height:
|
220 |
+
padding_ltrb = [
|
221 |
+
(crop_width - image_width) // 2 if crop_width > image_width else 0,
|
222 |
+
(crop_height - image_height) // 2 if crop_height > image_height else 0,
|
223 |
+
(crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
|
224 |
+
(crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
|
225 |
+
]
|
226 |
+
img = F.pad(img, padding_ltrb, fill=fill)
|
227 |
+
_, image_height, image_width = F.get_dimensions(img)
|
228 |
+
if crop_width == image_width and crop_height == image_height:
|
229 |
+
return img
|
230 |
+
|
231 |
+
crop_top = int(round((image_height - crop_height) / 2.0))
|
232 |
+
crop_left = int(round((image_width - crop_width) / 2.0))
|
233 |
+
return F.crop(img, crop_top, crop_left, crop_height, crop_width)
|
234 |
+
|
235 |
+
|
236 |
+
class CenterCropOrPad(torch.nn.Module):
|
237 |
+
"""Crops the given image at the center.
|
238 |
+
If the image is torch Tensor, it is expected
|
239 |
+
to have [..., H, W] shape, where ... means an arbitrary number of leading
|
240 |
+
dimensions. If image size is smaller than output size along any edge, image is
|
241 |
+
padded with 0 and then center cropped.
|
242 |
+
|
243 |
+
Args:
|
244 |
+
size (sequence or int): Desired output size of the crop. If size is an
|
245 |
+
int instead of sequence like (h, w), a square crop (size, size) is
|
246 |
+
made. If provided a sequence of length 1, it will be interpreted as
|
247 |
+
(size[0], size[0]).
|
248 |
+
"""
|
249 |
+
|
250 |
+
def __init__(self, size, fill=0):
|
251 |
+
super().__init__()
|
252 |
+
self.size = _setup_size(
|
253 |
+
size, error_msg='Please provide only two dimensions (h, w) for size.'
|
254 |
+
)
|
255 |
+
self.fill = fill
|
256 |
+
|
257 |
+
def forward(self, img):
|
258 |
+
"""
|
259 |
+
Args:
|
260 |
+
img (PIL Image or Tensor): Image to be cropped.
|
261 |
+
|
262 |
+
Returns:
|
263 |
+
PIL Image or Tensor: Cropped image.
|
264 |
+
"""
|
265 |
+
return center_crop_or_pad(img, self.size, fill=self.fill)
|
266 |
+
|
267 |
+
def __repr__(self) -> str:
|
268 |
+
return f'{self.__class__.__name__}(size={self.size})'
|
269 |
+
|
270 |
+
|
271 |
+
def _convert_to_rgb(image):
|
272 |
+
return image.convert('RGB')
|
273 |
+
|
274 |
+
|
275 |
+
class _ColorJitter(object):
|
276 |
+
"""
|
277 |
+
Apply Color Jitter to the PIL image with a specified probability.
|
278 |
+
"""
|
279 |
+
|
280 |
+
def __init__(self, brightness=0.0, contrast=0.0, saturation=0.0, hue=0.0, p=0.8):
|
281 |
+
assert 0.0 <= p <= 1.0
|
282 |
+
self.p = p
|
283 |
+
self.transf = ColorJitter(
|
284 |
+
brightness=brightness, contrast=contrast, saturation=saturation, hue=hue
|
285 |
+
)
|
286 |
+
|
287 |
+
def __call__(self, img):
|
288 |
+
if random.random() < self.p:
|
289 |
+
return self.transf(img)
|
290 |
+
else:
|
291 |
+
return img
|
292 |
+
|
293 |
+
|
294 |
+
class _GrayScale(object):
|
295 |
+
"""
|
296 |
+
Apply Gray Scale to the PIL image with a specified probability.
|
297 |
+
"""
|
298 |
+
|
299 |
+
def __init__(self, p=0.2):
|
300 |
+
assert 0.0 <= p <= 1.0
|
301 |
+
self.p = p
|
302 |
+
self.transf = Grayscale(num_output_channels=3)
|
303 |
+
|
304 |
+
def __call__(self, img):
|
305 |
+
if random.random() < self.p:
|
306 |
+
return self.transf(img)
|
307 |
+
else:
|
308 |
+
return img
|
309 |
+
|
310 |
+
|
311 |
+
def image_transform(
|
312 |
+
image_size: Union[int, Tuple[int, int]],
|
313 |
+
is_train: bool,
|
314 |
+
mean: Optional[Tuple[float, ...]] = None,
|
315 |
+
std: Optional[Tuple[float, ...]] = None,
|
316 |
+
resize_mode: Optional[str] = None,
|
317 |
+
interpolation: Optional[str] = None,
|
318 |
+
fill_color: int = 0,
|
319 |
+
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
|
320 |
+
):
|
321 |
+
mean = mean or OPENAI_DATASET_MEAN
|
322 |
+
if not isinstance(mean, (list, tuple)):
|
323 |
+
mean = (mean,) * 3
|
324 |
+
|
325 |
+
std = std or OPENAI_DATASET_STD
|
326 |
+
if not isinstance(std, (list, tuple)):
|
327 |
+
std = (std,) * 3
|
328 |
+
|
329 |
+
interpolation = interpolation or 'bicubic'
|
330 |
+
assert interpolation in ['bicubic', 'bilinear', 'random']
|
331 |
+
# NOTE random is ignored for interpolation_mode, so defaults to BICUBIC for
|
332 |
+
# inference if set
|
333 |
+
interpolation_mode = (
|
334 |
+
InterpolationMode.BILINEAR
|
335 |
+
if interpolation == 'bilinear'
|
336 |
+
else InterpolationMode.BICUBIC
|
337 |
+
)
|
338 |
+
|
339 |
+
resize_mode = resize_mode or 'shortest'
|
340 |
+
assert resize_mode in ('shortest', 'longest', 'squash')
|
341 |
+
|
342 |
+
if isinstance(aug_cfg, dict):
|
343 |
+
aug_cfg = AugmentationCfg(**aug_cfg)
|
344 |
+
else:
|
345 |
+
aug_cfg = aug_cfg or AugmentationCfg()
|
346 |
+
|
347 |
+
normalize = Normalize(mean=mean, std=std)
|
348 |
+
|
349 |
+
if is_train:
|
350 |
+
aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
|
351 |
+
use_timm = aug_cfg_dict.pop('use_timm', False)
|
352 |
+
if use_timm:
|
353 |
+
from timm.data import create_transform # timm can still be optional
|
354 |
+
|
355 |
+
if isinstance(image_size, (tuple, list)):
|
356 |
+
assert len(image_size) >= 2
|
357 |
+
input_size = (3,) + image_size[-2:]
|
358 |
+
else:
|
359 |
+
input_size = (3, image_size, image_size)
|
360 |
+
|
361 |
+
aug_cfg_dict.setdefault('color_jitter', None) # disable by default
|
362 |
+
# drop extra non-timm items
|
363 |
+
aug_cfg_dict.pop('color_jitter_prob', None)
|
364 |
+
aug_cfg_dict.pop('gray_scale_prob', None)
|
365 |
+
|
366 |
+
train_transform = create_transform(
|
367 |
+
input_size=input_size,
|
368 |
+
is_training=True,
|
369 |
+
hflip=0.0,
|
370 |
+
mean=mean,
|
371 |
+
std=std,
|
372 |
+
re_mode='pixel',
|
373 |
+
interpolation=interpolation,
|
374 |
+
**aug_cfg_dict,
|
375 |
+
)
|
376 |
+
else:
|
377 |
+
train_transform = [
|
378 |
+
RandomResizedCrop(
|
379 |
+
image_size,
|
380 |
+
scale=aug_cfg_dict.pop('scale'),
|
381 |
+
interpolation=InterpolationMode.BICUBIC,
|
382 |
+
),
|
383 |
+
_convert_to_rgb,
|
384 |
+
]
|
385 |
+
if aug_cfg.color_jitter_prob:
|
386 |
+
assert (
|
387 |
+
aug_cfg.color_jitter is not None and len(aug_cfg.color_jitter) == 4
|
388 |
+
)
|
389 |
+
train_transform.extend(
|
390 |
+
[_ColorJitter(*aug_cfg.color_jitter, p=aug_cfg.color_jitter_prob)]
|
391 |
+
)
|
392 |
+
if aug_cfg.gray_scale_prob:
|
393 |
+
train_transform.extend([_GrayScale(aug_cfg.gray_scale_prob)])
|
394 |
+
train_transform.extend(
|
395 |
+
[
|
396 |
+
ToTensor(),
|
397 |
+
normalize,
|
398 |
+
]
|
399 |
+
)
|
400 |
+
train_transform = Compose(train_transform)
|
401 |
+
if aug_cfg_dict:
|
402 |
+
warnings.warn(
|
403 |
+
f'Unused augmentation cfg items, specify `use_timm` to use '
|
404 |
+
f'({list(aug_cfg_dict.keys())}).'
|
405 |
+
)
|
406 |
+
return train_transform
|
407 |
+
else:
|
408 |
+
if resize_mode == 'longest':
|
409 |
+
transforms = [
|
410 |
+
ResizeKeepRatio(
|
411 |
+
image_size, interpolation=interpolation_mode, longest=1
|
412 |
+
),
|
413 |
+
CenterCropOrPad(image_size, fill=fill_color),
|
414 |
+
]
|
415 |
+
elif resize_mode == 'squash':
|
416 |
+
if isinstance(image_size, int):
|
417 |
+
image_size = (image_size, image_size)
|
418 |
+
transforms = [
|
419 |
+
Resize(image_size, interpolation=interpolation_mode),
|
420 |
+
]
|
421 |
+
else:
|
422 |
+
assert resize_mode == 'shortest'
|
423 |
+
if not isinstance(image_size, (tuple, list)):
|
424 |
+
image_size = (image_size, image_size)
|
425 |
+
if image_size[0] == image_size[1]:
|
426 |
+
# simple case, use torchvision built-in Resize w/ shortest edge mode
|
427 |
+
# (scalar size arg)
|
428 |
+
transforms = [Resize(image_size[0], interpolation=interpolation_mode)]
|
429 |
+
else:
|
430 |
+
# resize shortest edge to matching target dim for non-square target
|
431 |
+
transforms = [ResizeKeepRatio(image_size)]
|
432 |
+
transforms += [CenterCrop(image_size)]
|
433 |
+
|
434 |
+
transforms.extend(
|
435 |
+
[
|
436 |
+
_convert_to_rgb,
|
437 |
+
ToTensor(),
|
438 |
+
normalize,
|
439 |
+
]
|
440 |
+
)
|
441 |
+
return Compose(transforms)
|
442 |
+
|
443 |
+
|
444 |
+
def image_transform_v2(
|
445 |
+
cfg: PreprocessCfg,
|
446 |
+
is_train: bool,
|
447 |
+
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
|
448 |
+
):
|
449 |
+
return image_transform(
|
450 |
+
image_size=cfg.size,
|
451 |
+
is_train=is_train,
|
452 |
+
mean=cfg.mean,
|
453 |
+
std=cfg.std,
|
454 |
+
interpolation=cfg.interpolation,
|
455 |
+
resize_mode=cfg.resize_mode,
|
456 |
+
fill_color=cfg.fill_color,
|
457 |
+
aug_cfg=aug_cfg,
|
458 |
+
)
|
vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|