import argparse import requests import onnxruntime from transformers import AutoTokenizer import numpy as np import time from PIL import Image IMAGE_TOKEN_INDEX = 151646 MAX_GEN_LEN = 128 USE_SAMPLING = True print("Loading inference sessions...") load_start = time.time() image_emb_session = onnxruntime.InferenceSession("vlm/model/vision_encoder.onnx") text_emb_session = onnxruntime.InferenceSession("vlm/model/token_embedding.onnx") decoding_session = onnxruntime.InferenceSession("vlm/model/decoder.onnx") load_end = time.time() print(f"Inference sessions are loaded. Loading takes {load_end-load_start:0.2f} sec") def main(args): tokenizer = AutoTokenizer.from_pretrained("./vlm/tokenizer") tokenizer.add_tokens([""], special_tokens=True) query = args.input_text prompt = f"<|im_start|>user\n\n{query}<|im_end|>\n<|im_start|>assistant\n" past_kv_values, first_token, input_token_len = prefill(args, tokenizer, prompt) decode(args, tokenizer, past_kv_values, first_token, input_token_len) def process_image(image_path): # Load image if "https" in image_path: image = Image.open(requests.get(image_path, stream=True).raw) else: image = Image.open(image_path) crop_size = (224, 224) do_center_crop = True do_convert_rgb = True do_normalize = True do_rescale = True do_resize = True image_mean = [0.48145466, 0.4578275, 0.40821073] image_std = [0.26862954, 0.26130258, 0.27577711] rescale_factor = 0.00392156862745098 # 1/255 size = {"shortest_edge": 224} resample = Image.BICUBIC # resample = 3 # Convert to rgb if do_convert_rgb: image = image.convert("RGB") # Resize image if do_resize: shortest_edge = min(image.size) scale_factor = size["shortest_edge"] / shortest_edge new_size = (int(image.width * scale_factor), int(image.height * scale_factor)) image = image.resize(new_size, resample=resample) # Center Crop if do_center_crop: left = (image.width - crop_size[0]) / 2 top = (image.height - crop_size[1]) / 2 right = (image.width + crop_size[0]) / 2 bottom = (image.height + crop_size[1]) / 2 image = image.crop((left, top, right, bottom)) # Convert to image array image_array = np.array(image).astype(np.float32) # Rescale (0-255 to 0-1) if do_rescale: image_array = image_array * rescale_factor # Normalize if do_normalize: image_array = (image_array - image_mean) / image_std # (H, W, C) -> (C, H, W) image_array = np.transpose(image_array, (2, 0, 1)) # add batch dim (1, C, H, W) image_array = np.expand_dims(image_array, axis=0) return image_array.astype(np.float32) def top_p_sampling(last_logits, top_p=0.99): sorted_indices = np.argsort(-last_logits) sorted_logits = last_logits[sorted_indices] cumulative_probs = np.cumsum(np.exp(sorted_logits - np.max(sorted_logits))) cumulative_probs /= cumulative_probs[-1] cutoff_index = np.searchsorted(cumulative_probs, top_p, side="right") probs = np.exp(sorted_logits[: cutoff_index + 1] - np.max(sorted_logits[: cutoff_index + 1])) probs /= np.sum(probs) next_token = np.random.choice(sorted_indices[: cutoff_index + 1], p=probs) return next_token # Prefill step # Inputs ## input_ids: [1, seq_len] ## past_key_values: each layer needs key[1, 2, 0, kv_dim], value[1, 2, 0, kv_dim] => total 56 kv # Outputs ## logits: [1, seq_len, 151936] ## present: each layer returns key[1, 2, seq_len, kv_dim], value[1, 2, seq_len, kv_dim] => total 56 kv def prefill(args, tokenizer, input_prompt): print("Running prefill step...") prefill_start = time.time() input_ids = tokenizer(input_prompt)["input_ids"] image_token_pos = input_ids.index(IMAGE_TOKEN_INDEX) pixel_value = process_image(args.image_path) # Get image embedding & Project image embedding to text embedding space image_emb_output = image_emb_session.run(None, {"pixel_values": pixel_value}) image_features_proj = image_emb_output[0] # Get text embedding text_emb_output = text_emb_session.run(None, {"input_ids": [input_ids]}) input_features = text_emb_output[0] # Split text embedding pre_image_text_emb = input_features[:, :image_token_pos, :] post_image_text_emb = input_features[:, image_token_pos + 1 :, :] # Merge text embedding and image embedding hidden_states = np.concatenate((pre_image_text_emb, image_features_proj, post_image_text_emb), axis=1) input_token_len = hidden_states.shape[1] # Prepare inputs used in prefill step with dummy input for initial past kv value prefill_input = { "/model/embed_tokens/Gather_output_0": hidden_states, "attention_mask": np.expand_dims(np.ones(input_token_len).astype(np.int64), axis=0), "position_ids": np.expand_dims(np.arange(input_token_len), axis=0), } for i in range(24): entities = ["key", "value"] for entity in entities: input_name = f"past_key_values.{i}.{entity}" prefill_input[input_name] = np.random.rand(1, 2, 0, 64).astype(np.float32) # Run prefill prefill_outputs = decoding_session.run(None, prefill_input) # Get past kv values for decode step past_kv_values = prefill_outputs[1:] # Get first token with top-p sampling if USE_SAMPLING: last_logits = prefill_outputs[0][0][-1] next_token = top_p_sampling(last_logits) else: next_token = prefill_outputs[0].argmax(-1)[0][-1] prefill_done = time.time() print(f"Prefill step done. Throughtput: {input_token_len/(prefill_done - prefill_start):0.2f} token/sec") return past_kv_values, next_token, input_token_len # Generation step # Inputs ## input_ids: [1, 1] ## past_key_values: each layer needs key[1, 2, past_seq_len, kv_dim], value[1, 2, past_seq_len, kv_dim] => total 56 kv # Outputs ## logits: [1, 1, 151936] ## present: each layer returns key[1, 2, seq_len, kv_dim], value[1, 2, seq_len, kv_dim] => total 56 kv def decode(args, tokenizer, past_kv_values, first_token, input_token_len): print("Runing decode step...", end="\n\n") decode_start = time.time() generated_ids = [first_token] next_token = first_token for last_token_id in range(MAX_GEN_LEN): embedding_output = text_emb_session.run(None, {"input_ids": [[next_token]]}) # Get new token's embedding hidden_states = embedding_output[0] # Prepare inputs for decoding step decoding_input = { "/model/embed_tokens/Gather_output_0": hidden_states.astype(np.float32), "attention_mask": [[1]], "position_ids": [[input_token_len]], } input_token_len += 1 for j in range(24): for k in range(2): if k == 0: input_name = f"past_key_values.{j}.key" else: input_name = f"past_key_values.{j}.value" decoding_input[input_name] = past_kv_values[2 * j + k].astype(np.float32) # Run decoding decoding_outputs = decoding_session.run(None, decoding_input) # Save kv values for next step past_kv_values = decoding_outputs[1:] # Get next token with top_p sampling last_logits = decoding_outputs[0][0][-1] if USE_SAMPLING: next_token = top_p_sampling(last_logits) else: next_token = decoding_outputs[0].argmax(-1)[0][-1] if next_token == tokenizer.eos_token_id: break # Save generated token generated_ids.append(next_token) decode_done = time.time() response = tokenizer.decode(generated_ids) print(f"Response: {response}") with open(args.output_path, 'w') as f: f.write(response) print(f"\nDecode step done. Throughtput: {last_token_id/(decode_done - decode_start):0.2f} token/sec") if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--input_text", type=str, help="Input query for inference", default="Where was this photo taken?") parser.add_argument("--image_path", type=str, help="Local image path or image url", default="assets/test_image.png") parser.add_argument("--output_path", type=str, help="Output path to save the response", default="output.txt") args = parser.parse_args() main(args)