fix-attention-implementation-argument (#15)
- fix: attn arg (c16d5ecc6f527a6cfe77371ce05fecc89ff8b32a)
- config.json +2 -1
- modeling_jina_embeddings_v4.py +1 -4
config.json CHANGED
@@ -56,5 +56,6 @@
   "vocab_size": 151936,
   "truncate_dim": null,
   "task_names": ["retrieval", "text-matching", "code"],
-  "matryoshka_dims": [128, 256, 512, 1024]
+  "matryoshka_dims": [128, 256, 512, 1024],
+  "_attn_implementation": "flash_attention_2"
 }
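For context: "_attn_implementation" in config.json is the field transformers consults when picking the model's default attention backend, so FlashAttention 2 becomes the default without the modeling code having to force it at load time. A minimal loading sketch, assuming the repo id jinaai/jina-embeddings-v4 (the id itself is not part of this diff):

    from transformers import AutoModel

    # Picks up the config default (_attn_implementation = "flash_attention_2")
    # when the flash-attn package is installed and a CUDA device is present.
    model = AutoModel.from_pretrained(
        "jinaai/jina-embeddings-v4",
        trust_remote_code=True,
    )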
modeling_jina_embeddings_v4.py CHANGED
@@ -519,10 +519,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         """
         if "torch_dtype" not in kwargs:
             kwargs["torch_dtype"] = "auto"
-
-        if torch.cuda.is_available() and "attn_implementation" not in kwargs:
-            kwargs["attn_implementation"] = "flash_attention_2"
-
+
         kwargs["key_mapping"] = super()._checkpoint_conversion_mapping
 
         base_model = super().from_pretrained(
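With the CUDA auto-detection removed, from_pretrained no longer forces flash_attention_2 on GPU machines; the default now comes from config.json, and callers can still override it through the standard attn_implementation keyword. A hedged usage sketch (the sdpa backend and the repo id are illustrative choices, not part of this commit):

    from transformers import AutoModel

    # Override the config default, e.g. on a machine where flash-attn is not installed.
    model = AutoModel.from_pretrained(
        "jinaai/jina-embeddings-v4",
        trust_remote_code=True,
        attn_implementation="sdpa",  # PyTorch scaled-dot-product attention
    )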