jupyterjazz commited on
Commit
c16d5ec
·
1 Parent(s): da7134d

fix: attn arg

Browse files

Signed-off-by: jupyterjazz <[email protected]>

Files changed (2) hide show
  1. config.json +2 -1
  2. modeling_jina_embeddings_v4.py +1 -4
config.json CHANGED
@@ -56,5 +56,6 @@
56
  "vocab_size": 151936,
57
  "truncate_dim": null,
58
  "task_names": ["retrieval", "text-matching", "code"],
59
- "matryoshka_dims": [128, 256, 512, 1024]
 
60
  }
 
56
  "vocab_size": 151936,
57
  "truncate_dim": null,
58
  "task_names": ["retrieval", "text-matching", "code"],
59
+ "matryoshka_dims": [128, 256, 512, 1024],
60
+ "_attn_implementation": "flash_attention_2"
61
  }
modeling_jina_embeddings_v4.py CHANGED
@@ -519,10 +519,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
519
  """
520
  if "torch_dtype" not in kwargs:
521
  kwargs["torch_dtype"] = "auto"
522
-
523
- if torch.cuda.is_available() and "attn_implementation" not in kwargs:
524
- kwargs["attn_implementation"] = "flash_attention_2"
525
-
526
  kwargs["key_mapping"] = super()._checkpoint_conversion_mapping
527
 
528
  base_model = super().from_pretrained(
 
519
  """
520
  if "torch_dtype" not in kwargs:
521
  kwargs["torch_dtype"] = "auto"
522
+
 
 
 
523
  kwargs["key_mapping"] = super()._checkpoint_conversion_mapping
524
 
525
  base_model = super().from_pretrained(