Fix inference code
modeling_minicpmv.py (+6 -5)
@@ -5,7 +5,6 @@ import torch
 import torchvision
 from copy import deepcopy
 from PIL import Image
-from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from torchvision import transforms
 from transformers import LlamaTokenizer, LlamaPreTrainedModel, LlamaForCausalLM, AutoModel, PreTrainedTokenizerFast
 from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer
@@ -13,6 +12,8 @@ from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer
 from .configuration_minicpm import MiniCPMVConfig
 from .resampler import Resampler
 
+IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_MEAN
+IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_STD
 
 class MiniCPMVPreTrainedModel(LlamaPreTrainedModel):
     config_class = MiniCPMVConfig
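The two constants added above inline the values that timm.data exports for Inception-style normalization, which is what lets the timm import be dropped without changing preprocessing. Below is an illustrative sketch (not code from this file) of how such constants are typically consumed by a torchvision transform; the exact pipeline built in modeling_minicpmv.py is not shown in this diff:

# Illustrative sketch only, under the assumptions stated above.
from torchvision import transforms

IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)  # same value as timm.data.IMAGENET_INCEPTION_MEAN
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)   # same value as timm.data.IMAGENET_INCEPTION_STD

transform = transforms.Compose([
    transforms.ToTensor(),                              # PIL image -> float tensor in [0, 1]
    transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN,  # shifts/scales channels to roughly [-1, 1]
                         std=IMAGENET_INCEPTION_STD),
])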
@@ -352,6 +353,8 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
         if image is not None and isinstance(copy_msgs[0]['content'], str):
             copy_msgs[0]['content'] = [image, copy_msgs[0]['content']]
 
+        images = []
+        tgt_sizes = []
         for i, msg in enumerate(copy_msgs):
             role = msg["role"]
             content = msg["content"]
@@ -361,8 +364,6 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
             if isinstance(content, str):
                 content = [content]
 
-            images = []
-            tgt_sizes = []
             cur_msgs = []
             for c in content:
                 if isinstance(c, Image.Image):
@@ -387,10 +388,10 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
                 elif isinstance(c, str):
                     cur_msgs.append(c)
 
-            if tgt_sizes:
-                tgt_sizes = torch.vstack(tgt_sizes)
 
             msg['content'] = '\n'.join(cur_msgs)
+        if tgt_sizes:
+            tgt_sizes = torch.vstack(tgt_sizes)
 
         input_ids = tokenizer.apply_chat_template(copy_msgs, tokenize=True, add_generation_prompt=False)
 
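The hunks in the chat method above hoist `images` and `tgt_sizes` out of the per-message loop: both lists are now initialized once before iterating over copy_msgs, and `torch.vstack` runs once after the loop, so images and target sizes collected from earlier messages are no longer discarded on each iteration. A minimal usage sketch of the patched chat path, assuming the standard MiniCPM-V remote-code interface; the repo id below is a placeholder, not taken from this commit:

# Usage sketch under stated assumptions; not part of the commit.
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer

model_id = "openbmb/MiniCPM-Llama3-V-2_5"  # placeholder repo id
model = AutoModel.from_pretrained(model_id, trust_remote_code=True,
                                  torch_dtype=torch.float16).eval().cuda()  # fp16 on GPU assumed
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

image = Image.open("example.jpg").convert("RGB")
msgs = [{"role": "user", "content": "Describe this image."}]

# chat() prepends the image to the first message, accumulates `images`/`tgt_sizes`
# across all messages (the lists patched above), then tokenizes via apply_chat_template.
answer = model.chat(image=image, msgs=msgs, tokenizer=tokenizer)
print(answer)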