Update modeling_ovis.py
modeling_ovis.py (+4 -4)
@@ -288,10 +288,10 @@ class Ovis(OvisPreTrainedModel):
         super().__init__(config, *inputs, **kwargs)
         attn_kwargs = dict()
         if self.config.llm_attn_implementation:
-            if self.config.llm_attn_implementation == "flash_attention_2":
-                assert (is_flash_attn_2_available() and
-                        version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.6.3")), \
-                    "Using `flash_attention_2` requires having `flash_attn>=2.6.3` installed."
+            # if self.config.llm_attn_implementation == "flash_attention_2":
+            #     assert (is_flash_attn_2_available() and
+            #             version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.6.3")), \
+            #         "Using `flash_attention_2` requires having `flash_attn>=2.6.3` installed."
             attn_kwargs["attn_implementation"] = self.config.llm_attn_implementation
         self.llm = AutoModelForCausalLM.from_config(self.config.llm_config, **attn_kwargs)
         assert self.config.hidden_size == self.llm.config.hidden_size, "hidden size mismatch"
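The net effect of the change: requesting `flash_attention_2` no longer hard-fails at model construction when `flash_attn>=2.6.3` is not installed; the version check is commented out rather than enforced. Callers who still want the old safeguard can run the same check themselves before setting `llm_attn_implementation`. Below is a minimal caller-side sketch, assuming the helpers the removed assertion used (`is_flash_attn_2_available` from `transformers.utils`, plus `packaging.version` and `importlib.metadata`); the function name and the `"sdpa"` fallback are illustrative choices, not part of this commit.

import importlib.metadata

from packaging import version
from transformers.utils import is_flash_attn_2_available

def pick_llm_attn_implementation() -> str:
    # Prefer flash_attention_2 only when a new-enough flash_attn is installed,
    # mirroring the assertion this commit comments out.
    if is_flash_attn_2_available():
        installed = version.parse(importlib.metadata.version("flash_attn"))
        if installed >= version.parse("2.6.3"):
            return "flash_attention_2"
    # Otherwise fall back to PyTorch scaled-dot-product attention (assumption:
    # the underlying LLM supports the "sdpa" implementation).
    return "sdpa"

With this helper, a caller would set `config.llm_attn_implementation = pick_llm_attn_implementation()` before constructing the model, recovering the pre-change guarantee without the hard assert inside `__init__`.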