from transformers import AutoProcessor, AutoConfig, AutoModelForImageTextToText
import torch
import onnx
import onnxruntime as ort
import numpy as np
import os
from tqdm import tqdm
from typing import List, Tuple
from axengine import InferenceSession
from ml_dtypes import bfloat16
from utils.infer_func import InferManager
import argparse
from PIL import Image
from torchvision.transforms import Resize, ToTensor, Normalize, Compose
from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD


def run_vision_model(
    encoder,
    pixel_values,
    patch_attention_mask=None,
):
    batch_size = pixel_values.size(0)
    if patch_attention_mask is None:
        patch_size = 16
        patch_attention_mask = torch.ones(
            (
                batch_size,
                pixel_values.size(2) // patch_size,
                pixel_values.size(3) // patch_size,
            )
        )
        patch_attention_mask = patch_attention_mask.to(dtype=torch.bool, device=pixel_values.device)
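
    # NOTE: `embeddings` and `device` are module-level globals created in __main__
    # (the SmolVLMVisionEmbeddings module loaded from ./embeds/SmolVLMVisionEmbeddings.pkl
    # and the torch device).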
    hidden_states = embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)

    patch_attention_mask = patch_attention_mask.view(batch_size, -1)
    if not torch.any(~patch_attention_mask):
        patch_attention_mask = None
    # The original modeling code would otherwise expand the mask with
    # _prepare_4d_attention_mask here; the compiled encoder only consumes the
    # embedded patches, so that step is dropped.

    encoder_outputs = encoder.run(None, {"input": hidden_states.detach().cpu().to(dtype=torch.float32).numpy()})[0]
    encoder_outputs = torch.from_numpy(encoder_outputs).to(device, dtype=hidden_states.dtype)

    return encoder_outputs


def get_image_features(encoder, pixel_values: torch.FloatTensor, pixel_attention_mask: torch.LongTensor = None):
    """
    Encodes images into continuous embeddings that can be forwarded to the language model.

    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_images, num_channels, height, width)`):
            The tensors corresponding to the input images.
        pixel_attention_mask (`torch.LongTensor`, *optional*):
            The attention mask indicating padded regions in the image.
    """
    batch_size, num_images, num_channels, height, width = pixel_values.shape
    pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
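
    # Drop images that are entirely zero padding (added by the processor when batching),
    # but keep at least one entry so the vision encoder always receives a non-empty batch.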
    nb_values_per_image = pixel_values.shape[1:].numel()
    real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
    if not any(real_images_inds):
        real_images_inds[0] = True
    pixel_values = pixel_values[real_images_inds].contiguous()

    if pixel_attention_mask is None:
        pixel_attention_mask = torch.ones(
            size=[pixel_values.shape[i] for i in (0, 2, 3)],
            dtype=torch.bool,
            device=pixel_values.device,
        )
    else:
        pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:])
        pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
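
    # Collapse the pixel-level mask into a patch-level mask: a patch is kept if any
    # pixel inside its 16x16 window is valid.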
    patch_size = 16
    patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
    patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
    patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()

    image_hidden_states = run_vision_model(encoder, pixel_values, patch_attention_mask)

    return image_hidden_states


def inputs_merger(
    input_ids: torch.LongTensor, inputs_embeds: torch.Tensor, image_hidden_states: torch.Tensor
):
    """
    Merges the token embeddings with the image hidden states into a single sequence of vectors that is fed to the transformer LM.
    The merging happens as follows:
    - The text token sequence is: `tok_1 tok_2 tok_3 <fake_token_around_image> <image> <image> ... <image> <fake_token_around_image> tok_4`.
    - We get the image hidden states for the image through the vision encoder; after a pixel-shuffle operation, they are projected into the text embedding space.
      We thus have a sequence of image hidden states of size (1, image_seq_len, hidden_dim), where 1 is the batch size for a single image and hidden_dim is the hidden_dim of the LM transformer.
    - The merging produces the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_tok_around_image vector_tok_4`. That sequence is fed to the LM.
    - To fit the format of that sequence, `input_ids`, `inputs_embeds`, and `attention_mask` are all adapted to insert the image hidden states.
    """
    _, patch_size, _ = image_hidden_states.shape
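
    # `patch_size` here is the number of hidden states produced per <image> block
    # (the per-image sequence length after pixel shuffle). 49190 is the id of the
    # <image> placeholder token in the SmolVLM2 tokenizer; it should match
    # config.image_token_id.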
    image_mask = input_ids == 49190
    num_image_tokens = image_mask.sum(dim=1)
    if not torch.all(num_image_tokens % patch_size == 0):
        raise ValueError("At least one sample has <image> tokens not divisible by patch_size.")

    blocks_per_sample = num_image_tokens // patch_size

    offsets = torch.nn.functional.pad(blocks_per_sample.cumsum(dim=0), (1, 0), value=0)
    block_offset = offsets[:-1]
    row_cum = image_mask.cumsum(dim=-1)
    chunk_idx = (row_cum - 1) // patch_size
    local_idx = (row_cum - 1) % patch_size
    block_idx = block_offset.unsqueeze(1) + chunk_idx
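
    # Scatter: every <image> position pulls row `local_idx` of block `block_idx`
    # from image_hidden_states; text positions keep their original embeddings.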
    image_embeds = torch.zeros_like(inputs_embeds)
    image_embeds[image_mask] = image_hidden_states[block_idx[image_mask], local_idx[image_mask], :]

    merged_embeds = torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds)
    return merged_embeds


if __name__ == "__main__":

    """
    python3 infer_axmodel.py -i ../assets/panda.jpg --vit_model ./vit-models/vision_model.axmodel
    """

    parser = argparse.ArgumentParser(description="Model configuration parameters")
    parser.add_argument("--hf_model", type=str, default="./SmolVLM2-500M-Video-Instruct/",
                        help="Path to the HuggingFace model directory")
    parser.add_argument("--axmodel_path", type=str, default="./SmolVLM2-500M-Video-Instruct_axmodel/",
                        help="Path to the compiled axmodel of the language model")
    parser.add_argument("--vit_model", type=str, default="./vit-models/vision_model.axmodel",
                        help="Path to the compiled axmodel of the vision encoder")
    parser.add_argument("-i", "--images", type=str, default="../assets/bee.jpg",
                        help="Path to the test image")
    parser.add_argument("-q", "--question", type=str, default="Can you describe this image?",
                        help="The question you want to ask the model")
    args = parser.parse_args()

    hf_model_path = args.hf_model
    axmodel_path = args.axmodel_path
    images = args.images
    prompt = args.question
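
    # Artifacts loaded from disk:
    #   - ./embeds/SmolVLMVisionEmbeddings.pkl: the vision patch-embedding module, still executed in torch
    #   - model.embed_tokens.weight.npy: the LM token-embedding table, used for the manual embedding lookup below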
    device = "cuda" if torch.cuda.is_available() else "cpu"
    embeddings = torch.load("./embeds/SmolVLMVisionEmbeddings.pkl", map_location=device, weights_only=False)
    embeds = np.load(os.path.join(axmodel_path, "model.embed_tokens.weight.npy"))

    encoder = InferenceSession(args.vit_model)

    processor = AutoProcessor.from_pretrained(hf_model_path)
    config = AutoConfig.from_pretrained(hf_model_path, trust_remote_code=True)
    tokenizer = processor.tokenizer

    TARGET_IMAGE_SIZE = (512, 512)
    image = Image.open(images).convert('RGB')
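
    # Only a resize to TARGET_IMAGE_SIZE is applied here; the rest of the image
    # preprocessing (rescaling/normalization) is handled by the processor when the
    # chat template is applied below.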
    preprocess = Compose([
        Resize(TARGET_IMAGE_SIZE),
    ])

    preprocessed_image = preprocess(image)

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": preprocessed_image},
                {"type": "text", "text": prompt},
            ]
        },
    ]

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(device, dtype=torch.bfloat16)

    pixel_values = inputs["pixel_values"]
    pixel_attention_mask = inputs["pixel_attention_mask"]
    input_ids = inputs["input_ids"]
    input_ids_length = input_ids.shape[1]
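
    # Token embeddings are looked up manually from the exported embedding table
    # (np.take over `embeds`); the prefill/decode pipeline below works on these
    # embeddings rather than on raw token ids.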
    inputs_embeds = np.take(embeds, input_ids[0].cpu().numpy().tolist(), axis=0)[None, ...]
    inputs_embeds = torch.from_numpy(inputs_embeds).to(device, dtype=torch.bfloat16)

    # Reference: transformers/models/smolvlm/modeling_smolvlm.py, get_image_features()
    image_hidden_states = get_image_features(encoder, pixel_values, pixel_attention_mask)

    inputs_embeds = inputs_merger(
        input_ids=input_ids,
        inputs_embeds=inputs_embeds,
        image_hidden_states=image_hidden_states,
    ).to(dtype=torch.float32).cpu().numpy()

    prefill_data = inputs_embeds
    prefill_data = prefill_data.astype(bfloat16)
    token_ids = input_ids[0].cpu().numpy().tolist()
    token_len = len(token_ids)
    cfg = config.text_config

    imer = InferManager(cfg, axmodel_path)
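
    # Prefill runs the prompt embeddings through the compiled LM in slices of 128
    # tokens to build the KV cache; decode then generates tokens autoregressively,
    # fetching each new token's embedding from `embeds`. (slice_len=128 is assumed
    # to match the prefill length the axmodel was compiled with.)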
    token_ids = imer.prefill(tokenizer, token_ids, prefill_data[0], slice_len=128)
    imer.decode(tokenizer, token_ids, embeds, slice_len=128)
    print("\n")