LLM text generation Python examples

  • run_llm.py
from transformers import AutoConfig, AutoTokenizer
import onnxruntime
import numpy as np

# 1. Load config, tokenizer, and model
path_to_model = "./llm/model"
path_to_tokenizer = "./llm/tokenizer"
config = AutoConfig.from_pretrained(path_to_model)
tokenizer = AutoTokenizer.from_pretrained(path_to_tokenizer)
decoder_session = onnxruntime.InferenceSession(f"{path_to_model}/q4f16.onnx")

## Set config values
num_key_value_heads = config.num_key_value_heads
head_dim = config.head_dim
num_hidden_layers = config.num_hidden_layers
eos_token_id = 106  # 106 is for <end_of_turn>

## Use the KV-cache dtype the exported graph declares for its past_key_values
## inputs. A "q4f16" export presumably takes float16 here (TODO confirm);
## reading it from the session keeps the script correct for float32 exports too.
kv_input_types = {
    node.type
    for node in decoder_session.get_inputs()
    if node.name.startswith("past_key_values.")
}
kv_dtype = np.float16 if "tensor(float16)" in kv_input_types else np.float32

# 2. Prepare inputs
## Create input messages
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Write me a short poem about Machine Learning."},
]

## Apply tokenizer
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="np")

## Prepare decoder inputs
batch_size = inputs['input_ids'].shape[0]
## Empty (length-0 along the sequence axis) cache for every layer's key and value.
past_key_values = {
    f'past_key_values.{layer}.{kv}': np.zeros([batch_size, num_key_value_heads, 0, head_dim], dtype=kv_dtype)
    for layer in range(num_hidden_layers)
    for kv in ('key', 'value')
}
input_ids = inputs['input_ids']
## Positions are numbered from 1 here; presumably what this export expects (TODO confirm).
## dtype is explicit so it does not depend on the platform's default integer type.
position_ids = np.tile(np.arange(1, input_ids.shape[-1] + 1, dtype=np.int64), (batch_size, 1))

# 3. Generation loop (greedy decoding)
max_new_tokens = 128
## One row per batch item, so the per-step concatenation works for any batch size.
generated_tokens = np.empty((batch_size, 0), dtype=np.int64)
for _ in range(max_new_tokens):
    logits, *present_key_values = decoder_session.run(None, dict(
        input_ids=input_ids,
        position_ids=position_ids,
        **past_key_values,
    ))

    ## Update values for next generation loop: only the newest token is fed next.
    input_ids = logits[:, -1].argmax(-1, keepdims=True)
    position_ids = position_ids[:, -1:] + 1
    ## Assumes the session's present outputs come in the same order as the
    ## past_key_values inputs (TODO confirm against the model's output names).
    for j, key in enumerate(past_key_values):
        past_key_values[key] = present_key_values[j]

    ## Stop before recording <end_of_turn>, so it does not end up in the output.
    if (input_ids == eos_token_id).all():
        break
    generated_tokens = np.concatenate([generated_tokens, input_ids], axis=-1)

    ## (Optional) Streaming
    print(tokenizer.decode(input_ids[0]), end='', flush=True)
print()

# 4. Output result
print(tokenizer.batch_decode(generated_tokens))

VLM text generation Python examples

  • run_vlm.py
import argparse
import requests
import onnxruntime
from transformers import AutoTokenizer
import numpy as np
import time
from PIL import Image

# Token id expected for the "<image>" placeholder. main() registers "<image>"
# with tokenizer.add_tokens; this value must match the id the tokenizer
# actually assigns to it (TODO confirm when changing tokenizers).
IMAGE_TOKEN_INDEX = 151646
# Maximum number of decode-loop iterations after the prefill token.
MAX_GEN_LEN = 128
# True: top-p (nucleus) sampling; False: greedy argmax.
USE_SAMPLING = True

print("Loading inference sessions...")
# perf_counter is monotonic and meant for measuring elapsed time, unlike
# time.time(), which follows the (adjustable) wall clock.
load_start = time.perf_counter()

image_emb_session = onnxruntime.InferenceSession("vlm/model/vision_encoder.onnx")
text_emb_session = onnxruntime.InferenceSession("vlm/model/token_embedding.onnx")
decoding_session = onnxruntime.InferenceSession("vlm/model/decoder.onnx")


load_end = time.perf_counter()
print(f"Inference sessions are loaded. Loading takes {load_end-load_start:0.2f} sec")


def main(args):
    """Answer ``args.input_text`` about the image given by ``args.image_path``.

    Builds a single-turn chat prompt containing an ``<image>`` placeholder,
    runs the prefill pass, then hands its cache and first token to decoding.
    """
    vlm_tokenizer = AutoTokenizer.from_pretrained("./vlm/tokenizer")
    # Register "<image>" as a special token so the placeholder is not split.
    vlm_tokenizer.add_tokens(["<image>"], special_tokens=True)

    user_query = args.input_text
    chat_prompt = f"<|im_start|>user\n<image>\n{user_query}<|im_end|>\n<|im_start|>assistant\n"

    kv_cache, token0, prompt_len = prefill(args, vlm_tokenizer, chat_prompt)
    decode(args, vlm_tokenizer, kv_cache, token0, prompt_len)


def process_image(image_path):
    """Load an image and preprocess it for the vision encoder.

    Steps (each gated by the flag of the same name below): RGB conversion,
    bicubic resize so the shortest edge is 224 px, 224x224 center crop,
    rescale to [0, 1], per-channel mean/std normalization, HWC -> CHW, and a
    leading batch dimension.

    Args:
        image_path: Local file path, or an ``http://`` / ``https://`` URL.

    Returns:
        float32 array of shape (1, 3, 224, 224).

    Raises:
        requests.HTTPError: if downloading a URL returns an error status.
    """
    import io  # local import: only needed to wrap downloaded bytes

    crop_size = (224, 224)
    do_center_crop = True
    do_convert_rgb = True
    do_normalize = True
    do_rescale = True
    do_resize = True
    image_mean = [0.48145466, 0.4578275, 0.40821073]
    image_std = [0.26862954, 0.26130258, 0.27577711]
    rescale_factor = 0.00392156862745098  # 1/255
    size = {"shortest_edge": 224}
    resample = Image.BICUBIC  # resample = 3

    # Load image. Test the URL scheme as a prefix: a substring test would send
    # local paths that merely contain "https" to requests, and would not
    # accept plain "http://" URLs.
    if image_path.startswith(("http://", "https://")):
        response = requests.get(image_path, timeout=30)
        response.raise_for_status()
        source = io.BytesIO(response.content)
    else:
        source = image_path
    with Image.open(source) as opened:
        # Image.open is lazy; materialize the pixels here so the file is closed.
        image = opened.convert("RGB") if do_convert_rgb else opened.copy()

    # Resize image
    if do_resize:
        # Pin the short edge to exactly size["shortest_edge"] and derive the
        # long edge from the integer ratio. Multiplying both edges by a float
        # scale factor and truncating can leave the short edge one pixel below
        # the target, so the center crop would reach outside the image.
        target = size["shortest_edge"]
        width, height = image.size
        if width <= height:
            new_size = (target, int(target * height / width))
        else:
            new_size = (int(target * width / height), target)
        image = image.resize(new_size, resample=resample)

    # Center Crop
    if do_center_crop:
        left = (image.width - crop_size[0]) / 2
        top = (image.height - crop_size[1]) / 2
        right = (image.width + crop_size[0]) / 2
        bottom = (image.height + crop_size[1]) / 2
        image = image.crop((left, top, right, bottom))

    # Convert to image array
    image_array = np.array(image).astype(np.float32)

    # Rescale (0-255 to 0-1)
    if do_rescale:
        image_array = image_array * rescale_factor

    # Normalize (broadcasts over the trailing channel axis)
    if do_normalize:
        image_array = (image_array - image_mean) / image_std

    # (H, W, C) -> (C, H, W)
    image_array = np.transpose(image_array, (2, 0, 1))

    # add batch dim (1, C, H, W)
    image_array = np.expand_dims(image_array, axis=0)

    return image_array.astype(np.float32)


def top_p_sampling(last_logits, top_p=0.99, rng=None):
    """Sample one token id from ``last_logits`` with nucleus (top-p) sampling.

    Tokens are ranked by logit. The ranking is cut at the first position whose
    cumulative softmax probability is strictly greater than ``top_p``; that
    position and everything before it are kept, renormalized, and sampled.

    Args:
        last_logits: 1-D array of unnormalized scores, one per vocabulary id.
        top_p: Probability-mass threshold.
        rng: Optional ``numpy.random.Generator`` for reproducible sampling.
            When omitted, the global ``np.random`` state is used.

    Returns:
        The sampled vocabulary index (a NumPy integer).
    """
    # Work in float64: logits may arrive as float16/float32, and a cumulative
    # sum over a large vocabulary loses precision at low precision.
    logits = np.asarray(last_logits, dtype=np.float64)
    sorted_indices = np.argsort(-logits)
    sorted_logits = logits[sorted_indices]

    # Unnormalized softmax weights; sorted_logits[0] is the maximum, and
    # subtracting it keeps exp() from overflowing.
    weights = np.exp(sorted_logits - sorted_logits[0])
    cumulative_probs = np.cumsum(weights)
    cumulative_probs /= cumulative_probs[-1]

    # side="right": index of the first entry strictly greater than top_p.
    cutoff_index = np.searchsorted(cumulative_probs, top_p, side="right")
    keep = cutoff_index + 1

    candidate_weights = weights[:keep]
    candidate_probs = candidate_weights / candidate_weights.sum()

    chooser = np.random if rng is None else rng
    return chooser.choice(sorted_indices[:keep], p=candidate_probs)


# Prefill step
# Inputs
## /model/embed_tokens/Gather_output_0: merged text+image embeddings [1, seq_len, hidden]
## attention_mask, position_ids: [1, seq_len] (int64)
## past_key_values: each layer needs key[1, 2, 0, kv_dim], value[1, 2, 0, kv_dim] => 24 layers, total 48 kv
# Outputs
## logits: [1, seq_len, 151936]
## present: each layer returns key[1, 2, seq_len, kv_dim], value[1, 2, seq_len, kv_dim] => total 48 kv
def prefill(args, tokenizer, input_prompt):
    """Run the whole prompt (text + image) through the decoder in one pass.

    The ``<image>`` placeholder's text embedding is replaced by the vision
    encoder's output. The merged sequence is fed with an empty KV cache, and
    the first response token is picked from the last position's logits.

    Args:
        args: Parsed CLI arguments; ``args.image_path`` is read.
        tokenizer: Tokenizer with the ``<image>`` token registered.
        input_prompt: Chat-formatted prompt containing an ``<image>`` token.

    Returns:
        Tuple ``(past_kv_values, first_token, input_token_len)``: the decoder's
        present KV outputs, the first generated token id, and the length of
        the merged (text + image) sequence.

    Raises:
        ValueError: if the tokenized prompt has no IMAGE_TOKEN_INDEX token.
    """
    print("Running prefill step...")
    prefill_start = time.perf_counter()

    input_ids = tokenizer(input_prompt)["input_ids"]
    try:
        image_token_pos = input_ids.index(IMAGE_TOKEN_INDEX)
    except ValueError as err:
        raise ValueError(
            f"Tokenized prompt contains no image token (id {IMAGE_TOKEN_INDEX}); "
            "check that '<image>' is in the prompt and that the tokenizer assigned it that id."
        ) from err

    pixel_value = process_image(args.image_path)

    # Get image embedding, already projected into the text-embedding space.
    image_features_proj = image_emb_session.run(None, {"pixel_values": pixel_value})[0]

    # Get text embedding. The dtype is explicit so it does not depend on the
    # platform's default integer type.
    text_ids = np.array([input_ids], dtype=np.int64)
    input_features = text_emb_session.run(None, {"input_ids": text_ids})[0]

    # Splice the image embeddings in place of the single <image> token embedding.
    hidden_states = np.concatenate(
        (
            input_features[:, :image_token_pos, :],
            image_features_proj,
            input_features[:, image_token_pos + 1 :, :],
        ),
        axis=1,
    )
    input_token_len = hidden_states.shape[1]

    prefill_input = {
        "/model/embed_tokens/Gather_output_0": hidden_states,
        "attention_mask": np.ones((1, input_token_len), dtype=np.int64),
        "position_ids": np.arange(input_token_len, dtype=np.int64)[np.newaxis, :],
    }
    # Empty (length-0) cache for each of the 24 layers: 2 KV heads, head dim 64.
    for layer in range(24):
        for entity in ("key", "value"):
            prefill_input[f"past_key_values.{layer}.{entity}"] = np.zeros((1, 2, 0, 64), dtype=np.float32)

    # Run prefill
    prefill_outputs = decoding_session.run(None, prefill_input)
    logits = prefill_outputs[0]

    # Get past kv values for decode step
    past_kv_values = prefill_outputs[1:]

    # Only the last position's logits decide the first new token.
    last_logits = logits[0, -1]
    if USE_SAMPLING:
        next_token = top_p_sampling(last_logits)
    else:
        next_token = last_logits.argmax()

    prefill_done = time.perf_counter()
    print(f"Prefill step done. Throughtput: {input_token_len/(prefill_done - prefill_start):0.2f} token/sec")

    return past_kv_values, next_token, input_token_len


# Generation step
# Inputs
## /model/embed_tokens/Gather_output_0: embedding of the previous token [1, 1, hidden]
## attention_mask: [1, 1]; position_ids: [1, 1] (int64)
## past_key_values: each layer needs key[1, 2, past_seq_len, kv_dim], value[1, 2, past_seq_len, kv_dim] => total 48 kv
# Outputs
## logits: [1, 1, 151936]
## present: each layer returns key[1, 2, seq_len, kv_dim], value[1, 2, seq_len, kv_dim] => total 48 kv
def decode(args, tokenizer, past_kv_values, first_token, input_token_len):
    """Generate the response token by token, print it, and save it to ``args.output_path``.

    Args:
        args: Parsed CLI arguments; ``args.output_path`` is read.
        tokenizer: Tokenizer used for EOS detection and detokenization.
        past_kv_values: Present KV outputs from ``prefill``, ordered key/value
            per layer (layer ``j`` at indices ``2*j`` and ``2*j + 1``).
        first_token: Token id chosen at the end of prefill.
        input_token_len: Number of positions already held in the KV cache.
    """
    print("Runing decode step...", end="\n\n")
    decode_start = time.perf_counter()

    eos_token_id = tokenizer.eos_token_id
    # Prefill may already have produced EOS; then there is nothing to decode.
    generated_ids = [] if first_token == eos_token_id else [first_token]
    max_steps = MAX_GEN_LEN if generated_ids else 0
    next_token = first_token
    num_steps = 0

    for _ in range(max_steps):
        # Get new token's embedding
        token_ids = np.array([[next_token]], dtype=np.int64)
        hidden_states = text_emb_session.run(None, {"input_ids": token_ids})[0]

        # Prepare inputs for decoding step
        decoding_input = {
            "/model/embed_tokens/Gather_output_0": hidden_states.astype(np.float32, copy=False),
            # NOTE(review): a single-element mask is passed although the cache
            # holds input_token_len earlier positions; presumably the graph
            # broadcasts it. Confirm before changing its shape.
            "attention_mask": np.ones((1, 1), dtype=np.int64),
            # The new token sits right after every position already cached.
            "position_ids": np.array([[input_token_len]], dtype=np.int64),
        }
        input_token_len += 1
        for layer in range(24):
            for offset, entity in enumerate(("key", "value")):
                # copy=False: skip a full copy of every cache tensor when it is
                # already float32 (the cache grows each step).
                decoding_input[f"past_key_values.{layer}.{entity}"] = past_kv_values[2 * layer + offset].astype(
                    np.float32, copy=False
                )

        # Run decoding
        decoding_outputs = decoding_session.run(None, decoding_input)
        num_steps += 1

        # Save kv values for next step
        past_kv_values = decoding_outputs[1:]

        # Only the last (and only) position's logits are needed.
        last_logits = decoding_outputs[0][0, -1]
        if USE_SAMPLING:
            next_token = top_p_sampling(last_logits)
        else:
            next_token = last_logits.argmax()

        if next_token == eos_token_id:
            break

        # Save generated token
        generated_ids.append(next_token)

    decode_done = time.perf_counter()
    response = tokenizer.decode(generated_ids)

    print(f"Response: {response}")
    # Explicit UTF-8, so non-ASCII responses do not depend on the locale encoding.
    with open(args.output_path, 'w', encoding="utf-8") as f:
        f.write(response)
    elapsed = decode_done - decode_start
    # num_steps counts decoder runs actually executed (the old loop index was off by one).
    throughput = num_steps / elapsed if elapsed > 0 else 0.0
    print(f"\nDecode step done. Throughtput: {throughput:0.2f} token/sec")


if __name__ == "__main__":
    # Command-line options as (flag, help text, default); every option is a string.
    cli_options = (
        ("--input_text", "Input query for inference", "Where was this photo taken?"),
        ("--image_path", "Local image path or image url", "assets/test_image.png"),
        ("--output_path", "Output path to save the response", "output.txt"),
    )
    arg_parser = argparse.ArgumentParser()
    for flag, help_text, default_value in cli_options:
        arg_parser.add_argument(flag, type=str, help=help_text, default=default_value)

    main(arg_parser.parse_args())
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support