"""Run a vision-language model exported to ONNX: encode the input image, embed the
text prompt, and generate a response with the decoder using a key/value cache."""

import argparse
import time

import numpy as np
import onnxruntime
import requests
from PIL import Image
from transformers import AutoTokenizer

# Token id assigned to the "<image>" placeholder (added to the tokenizer in main()).
IMAGE_TOKEN_INDEX = 151646
# Maximum number of tokens to generate in the decode step.
MAX_GEN_LEN = 128
# If True, use top-p (nucleus) sampling; otherwise use greedy argmax decoding.
USE_SAMPLING = True

print("Loading inference sessions...")
load_start = time.time()

image_emb_session = onnxruntime.InferenceSession("vlm/model/vision_encoder.onnx")
text_emb_session = onnxruntime.InferenceSession("vlm/model/token_embedding.onnx")
decoding_session = onnxruntime.InferenceSession("vlm/model/decoder.onnx")

load_end = time.time()
print(f"Inference sessions are loaded. Loading took {load_end - load_start:0.2f} sec")

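# End-to-end flow: main() builds a chat-style prompt containing an "<image>"
# placeholder, prefill() splices the projected image features into the text
# embeddings and runs the decoder once over the full prompt to fill the
# key/value cache, and decode() then generates the response token by token.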
def main(args):
    tokenizer = AutoTokenizer.from_pretrained("./vlm/tokenizer")
    tokenizer.add_tokens(["<image>"], special_tokens=True)

    query = args.input_text
    # ChatML-style prompt; "<image>" marks where the image features are spliced in.
    prompt = f"<|im_start|>user\n<image>\n{query}<|im_end|>\n<|im_start|>assistant\n"
    past_kv_values, first_token, input_token_len = prefill(args, tokenizer, prompt)

    decode(args, tokenizer, past_kv_values, first_token, input_token_len)


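# process_image() mirrors a CLIP-style image processor: convert to RGB, resize so
# the shortest edge is 224, center-crop to 224x224, rescale pixels to [0, 1],
# normalize with the CLIP mean/std, and return a float32 NCHW array of shape
# (1, 3, 224, 224) for the vision encoder.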
def process_image(image_path):
    # Load the image from a URL or a local path.
    if image_path.startswith(("http://", "https://")):
        image = Image.open(requests.get(image_path, stream=True).raw)
    else:
        image = Image.open(image_path)

    # Preprocessing settings (mean/std match the CLIP image processor defaults).
    crop_size = (224, 224)
    do_center_crop = True
    do_convert_rgb = True
    do_normalize = True
    do_rescale = True
    do_resize = True
    image_mean = [0.48145466, 0.4578275, 0.40821073]
    image_std = [0.26862954, 0.26130258, 0.27577711]
    rescale_factor = 0.00392156862745098  # 1 / 255
    size = {"shortest_edge": 224}
    resample = Image.BICUBIC

    if do_convert_rgb:
        image = image.convert("RGB")

    # Resize so the shortest edge becomes size["shortest_edge"].
    if do_resize:
        shortest_edge = min(image.size)
        scale_factor = size["shortest_edge"] / shortest_edge
        new_size = (int(image.width * scale_factor), int(image.height * scale_factor))
        image = image.resize(new_size, resample=resample)

    # Center-crop to crop_size.
    if do_center_crop:
        left = (image.width - crop_size[0]) / 2
        top = (image.height - crop_size[1]) / 2
        right = (image.width + crop_size[0]) / 2
        bottom = (image.height + crop_size[1]) / 2
        image = image.crop((left, top, right, bottom))

    image_array = np.array(image).astype(np.float32)

    # Rescale pixel values from [0, 255] to [0, 1].
    if do_rescale:
        image_array = image_array * rescale_factor

    # Normalize each channel with the mean/std above.
    if do_normalize:
        image_array = (image_array - image_mean) / image_std

    # HWC -> CHW, then add the batch dimension: (1, 3, 224, 224).
    image_array = np.transpose(image_array, (2, 0, 1))
    image_array = np.expand_dims(image_array, axis=0)

    return image_array.astype(np.float32)


def top_p_sampling(last_logits, top_p=0.99):
    # Sort token ids by logit, highest first.
    sorted_indices = np.argsort(-last_logits)
    sorted_logits = last_logits[sorted_indices]

    # Cumulative softmax probabilities over the sorted logits.
    cumulative_probs = np.cumsum(np.exp(sorted_logits - np.max(sorted_logits)))
    cumulative_probs /= cumulative_probs[-1]

    # Smallest prefix of tokens whose cumulative probability reaches top_p.
    cutoff_index = np.searchsorted(cumulative_probs, top_p, side="right")

    # Renormalize the probabilities of the kept tokens and sample one of them.
    probs = np.exp(sorted_logits[: cutoff_index + 1] - np.max(sorted_logits[: cutoff_index + 1]))
    probs /= np.sum(probs)
    next_token = np.random.choice(sorted_indices[: cutoff_index + 1], p=probs)

    return next_token


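# The decoder graph takes the merged embeddings under the exported input name
# "/model/embed_tokens/Gather_output_0", plus "attention_mask", "position_ids",
# and one "past_key_values.{layer}.{key|value}" tensor per layer (24 layers with
# shape [1, 2, past_len, 64] here). Its outputs are the logits followed by the
# updated key/value cache, which prefill() returns for the decode loop.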
def prefill(args, tokenizer, input_prompt):
    print("Running prefill step...")
    prefill_start = time.time()

    input_ids = tokenizer(input_prompt)["input_ids"]
    image_token_pos = input_ids.index(IMAGE_TOKEN_INDEX)

    pixel_value = process_image(args.image_path)

    # Project the image into the decoder's embedding space.
    image_emb_output = image_emb_session.run(None, {"pixel_values": pixel_value})
    image_features_proj = image_emb_output[0]

    # Embed the text tokens.
    text_emb_output = text_emb_session.run(None, {"input_ids": np.array([input_ids], dtype=np.int64)})
    input_features = text_emb_output[0]

    # Splice the image features in place of the "<image>" token.
    pre_image_text_emb = input_features[:, :image_token_pos, :]
    post_image_text_emb = input_features[:, image_token_pos + 1 :, :]
    hidden_states = np.concatenate((pre_image_text_emb, image_features_proj, post_image_text_emb), axis=1)
    input_token_len = hidden_states.shape[1]

    prefill_input = {
        "/model/embed_tokens/Gather_output_0": hidden_states,
        "attention_mask": np.ones((1, input_token_len), dtype=np.int64),
        "position_ids": np.expand_dims(np.arange(input_token_len, dtype=np.int64), axis=0),
    }
    # Empty key/value cache for the prefill pass (past sequence length 0).
    for i in range(24):
        for entity in ("key", "value"):
            prefill_input[f"past_key_values.{i}.{entity}"] = np.zeros((1, 2, 0, 64), dtype=np.float32)

    prefill_outputs = decoding_session.run(None, prefill_input)

    # Outputs: logits first, then the updated key/value cache for every layer.
    past_kv_values = prefill_outputs[1:]

    if USE_SAMPLING:
        last_logits = prefill_outputs[0][0][-1]
        next_token = top_p_sampling(last_logits)
    else:
        next_token = prefill_outputs[0].argmax(-1)[0][-1]

    prefill_done = time.time()
    print(f"Prefill step done. Throughput: {input_token_len / (prefill_done - prefill_start):0.2f} token/sec")

    return past_kv_values, next_token, input_token_len


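# Autoregressive decode loop: each step embeds the previously generated token,
# runs the decoder with the cached keys/values, samples (or argmaxes) the next
# token, and stops after MAX_GEN_LEN steps or at the tokenizer's EOS token.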
def decode(args, tokenizer, past_kv_values, first_token, input_token_len):
    print("Running decode step...", end="\n\n")
    decode_start = time.time()

    generated_ids = [first_token]
    next_token = first_token

    for step in range(MAX_GEN_LEN):
        # Embed the single most recent token.
        embedding_output = text_emb_session.run(None, {"input_ids": np.array([[next_token]], dtype=np.int64)})
        hidden_states = embedding_output[0]

        decoding_input = {
            "/model/embed_tokens/Gather_output_0": hidden_states.astype(np.float32),
            "attention_mask": np.ones((1, 1), dtype=np.int64),
            "position_ids": np.array([[input_token_len]], dtype=np.int64),
        }
        input_token_len += 1
        # Feed the cached keys/values from the previous pass back into the decoder.
        for j in range(24):
            for k, entity in enumerate(("key", "value")):
                decoding_input[f"past_key_values.{j}.{entity}"] = past_kv_values[2 * j + k].astype(np.float32)

        decoding_outputs = decoding_session.run(None, decoding_input)
        past_kv_values = decoding_outputs[1:]

        last_logits = decoding_outputs[0][0][-1]
        if USE_SAMPLING:
            next_token = top_p_sampling(last_logits)
        else:
            next_token = decoding_outputs[0].argmax(-1)[0][-1]

        if next_token == tokenizer.eos_token_id:
            break

        generated_ids.append(next_token)

    decode_done = time.time()
    response = tokenizer.decode(generated_ids)

    print(f"Response: {response}")
    with open(args.output_path, "w") as f:
        f.write(response)
    # step + 1 is the number of decoder passes executed in this loop.
    print(f"\nDecode step done. Throughput: {(step + 1) / (decode_done - decode_start):0.2f} token/sec")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_text", type=str, help="Input query for inference", default="Where was this photo taken?")
    parser.add_argument("--image_path", type=str, help="Local image path or image url", default="assets/test_image.png")
    parser.add_argument("--output_path", type=str, help="Output path to save the response", default="output.txt")
    args = parser.parse_args()

    main(args)
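
# Example invocation (the script filename "run_vlm.py" is a placeholder; all
# three flags are optional thanks to the defaults above):
#   python run_vlm.py --image_path assets/test_image.png \
#       --input_text "Where was this photo taken?" --output_path output.txt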