jupyterjazz commited on
Commit
debae2a
·
verified ·
1 Parent(s): da7134d

fix-attention-implementation-argument (#15)

Browse files

- fix: attn arg (c16d5ecc6f527a6cfe77371ce05fecc89ff8b32a)

Files changed (2) hide show
  1. config.json +2 -1
  2. modeling_jina_embeddings_v4.py +1 -4
config.json CHANGED
@@ -56,5 +56,6 @@
56
  "vocab_size": 151936,
57
  "truncate_dim": null,
58
  "task_names": ["retrieval", "text-matching", "code"],
59
- "matryoshka_dims": [128, 256, 512, 1024]
 
60
  }
 
56
  "vocab_size": 151936,
57
  "truncate_dim": null,
58
  "task_names": ["retrieval", "text-matching", "code"],
59
+ "matryoshka_dims": [128, 256, 512, 1024],
60
+ "_attn_implementation": "flash_attention_2"
61
  }
modeling_jina_embeddings_v4.py CHANGED
@@ -519,10 +519,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
519
  """
520
  if "torch_dtype" not in kwargs:
521
  kwargs["torch_dtype"] = "auto"
522
-
523
- if torch.cuda.is_available() and "attn_implementation" not in kwargs:
524
- kwargs["attn_implementation"] = "flash_attention_2"
525
-
526
  kwargs["key_mapping"] = super()._checkpoint_conversion_mapping
527
 
528
  base_model = super().from_pretrained(
 
519
  """
520
  if "torch_dtype" not in kwargs:
521
  kwargs["torch_dtype"] = "auto"
522
+
 
 
 
523
  kwargs["key_mapping"] = super()._checkpoint_conversion_mapping
524
 
525
  base_model = super().from_pretrained(