Update modeling_ovis.py
Browse files- modeling_ovis.py +32 -13
modeling_ovis.py
CHANGED
|
@@ -1,10 +1,12 @@
|
|
| 1 |
import logging
|
| 2 |
import os
|
|
|
|
| 3 |
from importlib import import_module
|
| 4 |
from typing import List, Callable, Union, Optional, Dict
|
| 5 |
|
| 6 |
import PIL.Image
|
| 7 |
import torch
|
|
|
|
| 8 |
from torch import Tensor
|
| 9 |
from torch.nn import init
|
| 10 |
from torch.nn.functional import softmax, gumbel_softmax, pad
|
|
@@ -556,25 +558,42 @@ class Ovis(OvisPreTrainedModel):
|
|
| 556 |
cache_cls = HybridCache
|
| 557 |
llm = self.get_llm()
|
| 558 |
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 565 |
|
| 566 |
if need_new_cache:
|
| 567 |
if hasattr(llm.config, "_pre_quantization_dtype"):
|
| 568 |
cache_dtype = llm.config._pre_quantization_dtype
|
| 569 |
else:
|
| 570 |
cache_dtype = llm.dtype
|
| 571 |
-
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 578 |
else:
|
| 579 |
llm._cache.reset()
|
| 580 |
return llm._cache
|
|
|
|
| 1 |
import logging
|
| 2 |
import os
|
| 3 |
+
from packaging import version
|
| 4 |
from importlib import import_module
|
| 5 |
from typing import List, Callable, Union, Optional, Dict
|
| 6 |
|
| 7 |
import PIL.Image
|
| 8 |
import torch
|
| 9 |
+
import transformers
|
| 10 |
from torch import Tensor
|
| 11 |
from torch.nn import init
|
| 12 |
from torch.nn.functional import softmax, gumbel_softmax, pad
|
|
|
|
| 558 |
cache_cls = HybridCache
|
| 559 |
llm = self.get_llm()
|
| 560 |
|
| 561 |
+
if version.parse(transformers.__version__) >= version.parse("4.46.0"):
|
| 562 |
+
need_new_cache = (
|
| 563 |
+
not hasattr(llm, "_cache")
|
| 564 |
+
or (not isinstance(llm._cache, cache_cls))
|
| 565 |
+
or llm._cache.batch_size != batch_size
|
| 566 |
+
or llm._cache.max_cache_len < max_cache_len
|
| 567 |
+
)
|
| 568 |
+
else:
|
| 569 |
+
need_new_cache = (
|
| 570 |
+
not hasattr(llm, "_cache")
|
| 571 |
+
or (not isinstance(llm._cache, cache_cls))
|
| 572 |
+
or llm._cache.max_batch_size != batch_size
|
| 573 |
+
or llm._cache.max_cache_len < max_cache_len
|
| 574 |
+
)
|
| 575 |
|
| 576 |
if need_new_cache:
|
| 577 |
if hasattr(llm.config, "_pre_quantization_dtype"):
|
| 578 |
cache_dtype = llm.config._pre_quantization_dtype
|
| 579 |
else:
|
| 580 |
cache_dtype = llm.dtype
|
| 581 |
+
if version.parse(transformers.__version__) >= version.parse("4.46.0"):
|
| 582 |
+
llm._cache = cache_cls(
|
| 583 |
+
config=llm.config,
|
| 584 |
+
batch_size=batch_size,
|
| 585 |
+
max_cache_len=max_cache_len,
|
| 586 |
+
device=llm.device,
|
| 587 |
+
dtype=cache_dtype,
|
| 588 |
+
)
|
| 589 |
+
else:
|
| 590 |
+
llm._cache = cache_cls(
|
| 591 |
+
config=llm.config,
|
| 592 |
+
max_batch_size=batch_size,
|
| 593 |
+
max_cache_len=max_cache_len,
|
| 594 |
+
device=llm.device,
|
| 595 |
+
dtype=cache_dtype,
|
| 596 |
+
)
|
| 597 |
else:
|
| 598 |
llm._cache.reset()
|
| 599 |
return llm._cache
|