Upload 3 files
Browse files- smalvlm-256m-instruct_q8_ekv2048.tflite +3 -0
- test_tflite.py +310 -0
- tokenizer.model +3 -0
    	
        smalvlm-256m-instruct_q8_ekv2048.tflite
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:469a85dc3ddbb4458091da5dee62df625d58595132e26eb0ce2eae7248f22e60
         | 
| 3 | 
            +
            size 288312304
         | 
    	
        test_tflite.py
    ADDED
    
    | @@ -0,0 +1,310 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from typing import Dict
         | 
| 2 | 
            +
            from ai_edge_litert import interpreter as interpreter_lib
         | 
| 3 | 
            +
            import numpy as np
         | 
| 4 | 
            +
            import sys
         | 
| 5 | 
            +
            from collections.abc import Sequence
         | 
| 6 | 
            +
            from transformers import AutoProcessor
         | 
| 7 | 
            +
            from PIL import Image
         | 
| 8 | 
            +
            import requests
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            import torch
         | 
| 12 | 
            +
            from transformers import AutoModelForVision2Seq
         | 
| 13 | 
            +
             | 
| 14 | 
            +
             | 
| 15 | 
            +
            def _get_mask(shape: Sequence[int], k: int):
         | 
| 16 | 
            +
                """Gets the mask for the input to the model.
         | 
| 17 | 
            +
             | 
| 18 | 
            +
                Args:
         | 
| 19 | 
            +
                shape: The shape of the mask input to the model.
         | 
| 20 | 
            +
                k: all elements below the k-th diagonal are set to 0.
         | 
| 21 | 
            +
             | 
| 22 | 
            +
                Returns:
         | 
| 23 | 
            +
                The mask for the input to the model. All the elements in the mask are set
         | 
| 24 | 
            +
                to -inf except that all the elements below the k-th diagonal are set to 0.
         | 
| 25 | 
            +
                """
         | 
| 26 | 
            +
                mask = np.ones(shape, dtype=np.float32) * float("-inf")
         | 
| 27 | 
            +
                mask = np.triu(mask, k=k)
         | 
| 28 | 
            +
                return mask
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            class LiteRTLlmPipeline:
         | 
| 32 | 
            +
             | 
| 33 | 
            +
              def __init__(self, interpreter, processor):
         | 
| 34 | 
            +
                """Initializes the pipeline."""
         | 
| 35 | 
            +
                self._interpreter = interpreter
         | 
| 36 | 
            +
                self._processor = processor
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                self._prefill_runner = None
         | 
| 39 | 
            +
                self._decode_runner = self._interpreter.get_signature_runner("decode")
         | 
| 40 | 
            +
             | 
| 41 | 
            +
              def _init_prefill_runner(self, num_input_tokens: int):
         | 
| 42 | 
            +
                """Initializes all the variables related to the prefill runner.
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                This method initializes the following variables:
         | 
| 45 | 
            +
                  - self._prefill_runner: The prefill runner based on the input size.
         | 
| 46 | 
            +
                  - self._max_seq_len: The maximum sequence length supported by the model.
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                Args:
         | 
| 49 | 
            +
                  num_input_tokens: The number of input tokens.
         | 
| 50 | 
            +
                """
         | 
| 51 | 
            +
                if not self._interpreter:
         | 
| 52 | 
            +
                  raise ValueError("Interpreter is not initialized.")
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                # Prefill runner related variables will be initialized in `predict_text` and
         | 
| 55 | 
            +
                # `compute_log_likelihood`.
         | 
| 56 | 
            +
                self._prefill_runner = self._get_prefill_runner(num_input_tokens)
         | 
| 57 | 
            +
                # input_token_shape has shape (batch, max_seq_len)
         | 
| 58 | 
            +
                input_token_shape = self._prefill_runner.get_input_details()["tokens"][
         | 
| 59 | 
            +
                    "shape"
         | 
| 60 | 
            +
                ]
         | 
| 61 | 
            +
                if len(input_token_shape) == 1:
         | 
| 62 | 
            +
                  self._max_seq_len = input_token_shape[0]
         | 
| 63 | 
            +
                else:
         | 
| 64 | 
            +
                  self._max_seq_len = input_token_shape[1]
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                # SmolLM: kv cache input has shape [batch=1, cache_size, num_kv_heads, head_dim].
         | 
| 67 | 
            +
                kv_cache_shape = self._prefill_runner.get_input_details()["kv_cache_k_0"][
         | 
| 68 | 
            +
                    "shape"
         | 
| 69 | 
            +
                ]
         | 
| 70 | 
            +
                self._max_kv_cache_seq_len = kv_cache_shape[1]
         | 
| 71 | 
            +
             | 
| 72 | 
            +
              def _init_kv_cache(self) -> dict[str, np.ndarray]:
         | 
| 73 | 
            +
                if self._prefill_runner is None:
         | 
| 74 | 
            +
                  raise ValueError("Prefill runner is not initialized.")
         | 
| 75 | 
            +
                kv_cache = {}
         | 
| 76 | 
            +
                for input_key in self._prefill_runner.get_input_details().keys():
         | 
| 77 | 
            +
                  if "kv_cache" in input_key:
         | 
| 78 | 
            +
                    kv_cache[input_key] = np.zeros(
         | 
| 79 | 
            +
                        self._prefill_runner.get_input_details()[input_key]["shape"],
         | 
| 80 | 
            +
                        dtype=np.float32,
         | 
| 81 | 
            +
                    )
         | 
| 82 | 
            +
                    kv_cache[input_key] = np.zeros(
         | 
| 83 | 
            +
                        self._prefill_runner.get_input_details()[input_key]["shape"],
         | 
| 84 | 
            +
                        dtype=np.float32,
         | 
| 85 | 
            +
                    )
         | 
| 86 | 
            +
                return kv_cache
         | 
| 87 | 
            +
             | 
| 88 | 
            +
              def _get_prefill_runner(self, num_input_tokens: int) :
         | 
| 89 | 
            +
                """Gets the prefill runner with the best suitable input size.
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                Args:
         | 
| 92 | 
            +
                  num_input_tokens: The number of input tokens.
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                Returns:
         | 
| 95 | 
            +
                  The prefill runner with the smallest input size.
         | 
| 96 | 
            +
                """
         | 
| 97 | 
            +
                best_signature = None
         | 
| 98 | 
            +
                delta = sys.maxsize
         | 
| 99 | 
            +
                max_prefill_len = -1
         | 
| 100 | 
            +
                for key in self._interpreter.get_signature_list().keys():
         | 
| 101 | 
            +
                  if "prefill" not in key or 'pixel' not in key:
         | 
| 102 | 
            +
                    continue
         | 
| 103 | 
            +
                  input_pos = self._interpreter.get_signature_runner(key).get_input_details()[
         | 
| 104 | 
            +
                      "input_pos"
         | 
| 105 | 
            +
                  ]
         | 
| 106 | 
            +
                  # input_pos["shape"] has shape (max_seq_len, )
         | 
| 107 | 
            +
                  seq_size = input_pos["shape"][0]
         | 
| 108 | 
            +
                  max_prefill_len = max(max_prefill_len, seq_size)
         | 
| 109 | 
            +
                  if num_input_tokens <= seq_size and seq_size - num_input_tokens < delta:
         | 
| 110 | 
            +
                    delta = seq_size - num_input_tokens
         | 
| 111 | 
            +
                    best_signature = key
         | 
| 112 | 
            +
                if best_signature is None:
         | 
| 113 | 
            +
                  raise ValueError(
         | 
| 114 | 
            +
                      "The largest prefill length supported is %d, but we have %d number of input tokens"
         | 
| 115 | 
            +
                      %(max_prefill_len, num_input_tokens)
         | 
| 116 | 
            +
                  )
         | 
| 117 | 
            +
                return self._interpreter.get_signature_runner(best_signature)
         | 
| 118 | 
            +
             | 
| 119 | 
            +
              def _run_prefill(
         | 
| 120 | 
            +
                  self, prefill_token_ids: Sequence[int], pixel_values: np.ndarray,
         | 
| 121 | 
            +
              ) -> dict[str, np.ndarray]:
         | 
| 122 | 
            +
                """Runs prefill and returns the kv cache.
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                Args:
         | 
| 125 | 
            +
                  prefill_token_ids: The token ids of the prefill input.
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                Returns:
         | 
| 128 | 
            +
                  The updated kv cache.
         | 
| 129 | 
            +
                """
         | 
| 130 | 
            +
                if not self._prefill_runner:
         | 
| 131 | 
            +
                  raise ValueError("Prefill runner is not initialized.")
         | 
| 132 | 
            +
                prefill_token_length = len(prefill_token_ids)
         | 
| 133 | 
            +
                if prefill_token_length == 0:
         | 
| 134 | 
            +
                  return self._init_kv_cache()
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                # Prepare the input to be [1, max_seq_len].
         | 
| 137 | 
            +
                input_token_ids = [0] * self._max_seq_len
         | 
| 138 | 
            +
                input_token_ids[:prefill_token_length] = prefill_token_ids
         | 
| 139 | 
            +
                input_token_ids = np.asarray(input_token_ids, dtype=np.int32)
         | 
| 140 | 
            +
                input_token_ids = np.expand_dims(input_token_ids, axis=0)
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                # Prepare the input position to be [max_seq_len].
         | 
| 143 | 
            +
                input_pos = [0] * self._max_seq_len
         | 
| 144 | 
            +
                input_pos[:prefill_token_length] = range(prefill_token_length)
         | 
| 145 | 
            +
                input_pos = np.asarray(input_pos, dtype=np.int32)
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                # Initialize kv cache.
         | 
| 148 | 
            +
                prefill_inputs = self._init_kv_cache()
         | 
| 149 | 
            +
                # Prepare the tokens and input position inputs.
         | 
| 150 | 
            +
                prefill_inputs.update({
         | 
| 151 | 
            +
                    "tokens": input_token_ids,
         | 
| 152 | 
            +
                    "input_pos": input_pos,
         | 
| 153 | 
            +
                    "pixel_values": pixel_values,
         | 
| 154 | 
            +
                })
         | 
| 155 | 
            +
                if "mask" in self._prefill_runner.get_input_details().keys():
         | 
| 156 | 
            +
                  # For prefill, mask has shape [batch=1, 1, seq_len, kv_cache_size].
         | 
| 157 | 
            +
                  # We want mask[0, 0, i, j] = 0 for j<=i and -inf otherwise.
         | 
| 158 | 
            +
                  prefill_inputs["mask"] = _get_mask(
         | 
| 159 | 
            +
                      shape=self._prefill_runner.get_input_details()["mask"]["shape"],
         | 
| 160 | 
            +
                      k=1,
         | 
| 161 | 
            +
                  )
         | 
| 162 | 
            +
                prefill_outputs = self._prefill_runner(**prefill_inputs)
         | 
| 163 | 
            +
                if "logits" in prefill_outputs:
         | 
| 164 | 
            +
                  # Prefill outputs includes logits and kv cache. We only output kv cache.
         | 
| 165 | 
            +
                  prefill_outputs.pop("logits")
         | 
| 166 | 
            +
             | 
| 167 | 
            +
                return prefill_outputs
         | 
| 168 | 
            +
             | 
| 169 | 
            +
              def _greedy_sampler(self, logits: np.ndarray) -> int:
         | 
| 170 | 
            +
                return int(np.argmax(logits))
         | 
| 171 | 
            +
             | 
| 172 | 
            +
             | 
| 173 | 
            +
              def _run_decode(
         | 
| 174 | 
            +
                  self,
         | 
| 175 | 
            +
                  start_pos: int,
         | 
| 176 | 
            +
                  start_token_id: int,
         | 
| 177 | 
            +
                  kv_cache: dict[str, np.ndarray],
         | 
| 178 | 
            +
                  max_decode_steps: int,
         | 
| 179 | 
            +
              ) -> str:
         | 
| 180 | 
            +
                """Runs decode and outputs the token ids from greedy sampler.
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                Args:
         | 
| 183 | 
            +
                  start_pos: The position of the first token of the decode input.
         | 
| 184 | 
            +
                  start_token_id: The token id of the first token of the decode input.
         | 
| 185 | 
            +
                  kv_cache: The kv cache from the prefill.
         | 
| 186 | 
            +
                  max_decode_steps: The max decode steps.
         | 
| 187 | 
            +
             | 
| 188 | 
            +
                Returns:
         | 
| 189 | 
            +
                  The token ids from the greedy sampler.
         | 
| 190 | 
            +
                """
         | 
| 191 | 
            +
                next_pos = start_pos
         | 
| 192 | 
            +
                next_token = start_token_id
         | 
| 193 | 
            +
                decode_text = []
         | 
| 194 | 
            +
                decode_inputs = kv_cache
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                for _ in range(max_decode_steps):
         | 
| 197 | 
            +
                  decode_inputs.update({
         | 
| 198 | 
            +
                      "tokens": np.array([[next_token]], dtype=np.int32),
         | 
| 199 | 
            +
                      "input_pos": np.array([next_pos], dtype=np.int32),
         | 
| 200 | 
            +
                  })
         | 
| 201 | 
            +
                  if "mask" in self._decode_runner.get_input_details().keys():
         | 
| 202 | 
            +
                    # For decode, mask has shape [batch=1, 1, 1, kv_cache_size].
         | 
| 203 | 
            +
                    # We want mask[0, 0, 0, j] = 0 for j<=next_pos and -inf otherwise.
         | 
| 204 | 
            +
                    decode_inputs["mask"] = _get_mask(
         | 
| 205 | 
            +
                        shape=self._decode_runner.get_input_details()["mask"]["shape"],
         | 
| 206 | 
            +
                        k=next_pos + 1,
         | 
| 207 | 
            +
                    )
         | 
| 208 | 
            +
                  decode_outputs = self._decode_runner(**decode_inputs)
         | 
| 209 | 
            +
                  # Output logits has shape (batch=1, 1, vocab_size). We only take the first
         | 
| 210 | 
            +
                  # element.
         | 
| 211 | 
            +
                  logits = decode_outputs.pop("logits")[0][0]
         | 
| 212 | 
            +
                  next_token = self._greedy_sampler(logits)
         | 
| 213 | 
            +
                  if next_token == self._processor.tokenizer.eos_token_id:
         | 
| 214 | 
            +
                    break
         | 
| 215 | 
            +
                  decode_text.append(self._processor.tokenizer.decode(next_token, skip_special_tokens=True))
         | 
| 216 | 
            +
                  if len(decode_text[-1]) == 0:
         | 
| 217 | 
            +
                    # Break out the loop if we hit the special token.
         | 
| 218 | 
            +
                    break
         | 
| 219 | 
            +
             | 
| 220 | 
            +
                  print(decode_text[-1], end='', flush=True)
         | 
| 221 | 
            +
                  # Decode outputs includes logits and kv cache. We already poped out
         | 
| 222 | 
            +
                  # logits, so the rest is kv cache. We pass the updated kv cache as input
         | 
| 223 | 
            +
                  # to the next decode step.
         | 
| 224 | 
            +
                  decode_inputs = decode_outputs
         | 
| 225 | 
            +
                  next_pos += 1
         | 
| 226 | 
            +
             | 
| 227 | 
            +
                print() # print a new line at the end.
         | 
| 228 | 
            +
                return ''.join(decode_text)
         | 
| 229 | 
            +
             | 
| 230 | 
            +
              def generate(self, inputs: Dict, max_decode_steps: int | None = None) -> str:
         | 
| 231 | 
            +
              
         | 
| 232 | 
            +
                token_ids = inputs["input_ids"][0]
         | 
| 233 | 
            +
                pixel_values = inputs["pixel_values"][0]
         | 
| 234 | 
            +
             | 
| 235 | 
            +
                # Initialize the prefill runner with the suitable input size.
         | 
| 236 | 
            +
                self._init_prefill_runner(len(token_ids))
         | 
| 237 | 
            +
             | 
| 238 | 
            +
                # Run prefill.
         | 
| 239 | 
            +
                # Prefill up to the seond to the last token of the prompt, because the last
         | 
| 240 | 
            +
                # token of the prompt will be used to bootstrap decode.
         | 
| 241 | 
            +
                prefill_token_length = len(token_ids) - 1
         | 
| 242 | 
            +
             | 
| 243 | 
            +
                print('Running prefill')
         | 
| 244 | 
            +
                kv_cache = self._run_prefill(token_ids[:prefill_token_length], pixel_values)
         | 
| 245 | 
            +
                # Run decode.
         | 
| 246 | 
            +
                print('Running decode')
         | 
| 247 | 
            +
                actual_max_decode_steps = self._max_kv_cache_seq_len - prefill_token_length - 1
         | 
| 248 | 
            +
                if max_decode_steps is not None:
         | 
| 249 | 
            +
                  actual_max_decode_steps = min(actual_max_decode_steps, max_decode_steps)
         | 
| 250 | 
            +
                decode_text = self._run_decode(
         | 
| 251 | 
            +
                    prefill_token_length,
         | 
| 252 | 
            +
                    token_ids[prefill_token_length],
         | 
| 253 | 
            +
                    kv_cache,
         | 
| 254 | 
            +
                    actual_max_decode_steps,
         | 
| 255 | 
            +
                )
         | 
| 256 | 
            +
                return decode_text
         | 
| 257 | 
            +
             | 
| 258 | 
            +
             | 
| 259 | 
            +
            if __name__ == "__main__":
                # Compares greedy generation of the TFLite export against the
                # reference HuggingFace model on a single image + question.
                model_id = './models/SmolVLM-256M-Instruct'
                tflite_model_path = './models/SmolVLM-256M-Instruct-tflite/smalvlm-256m-instruct_q8_ekv2048.tflite'

                interpreter = interpreter_lib.InterpreterWithCustomOps(
                    custom_op_registerers=["pywrap_genai_ops.GenAIOpsRegisterer"],
                    model_path=tflite_model_path,
                    num_threads=2,
                    experimental_default_delegate_latest_features=True,
                )

                processor = AutoProcessor.from_pretrained(model_id, do_image_splitting=True)
                image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
                # Fail fast on network problems instead of hanging or handing an HTTP
                # error page to PIL.
                response = requests.get(image_url, stream=True, timeout=30)
                response.raise_for_status()
                image = Image.open(response.raw)

                messages = [
                    {
                        "role": "user",
                        "content": [
                            {"type": "image"},
                            {"type": "text", "text": "What in the image?"},
                        ],
                    },
                ]
                prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
                print(prompt)
                inputs = processor(text=prompt, images=[image], return_tensors="pt")

                # Tflite model inference
                pipeline = LiteRTLlmPipeline(interpreter, processor)
                tflite_text = pipeline.generate(inputs, max_decode_steps=100)

                # HuggingFace model inference
                DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
                inputs = inputs.to(DEVICE)
                model = AutoModelForVision2Seq.from_pretrained(
                    model_id,
                    torch_dtype=torch.bfloat16,
                    _attn_implementation="eager",
                ).to(DEVICE)
                generated_ids = model.generate(**inputs, do_sample=False, max_new_tokens=100)
                generated_texts = processor.batch_decode(
                    generated_ids,
                    skip_special_tokens=True,
                )

                hf_text = generated_texts[0]
                print("-"*100)
                print("Tflite:\n", tflite_text)
                print("HF:\n", hf_text)
         | 
    	
        tokenizer.model
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:6682f47d3b33538490b21265ba3b2a83f8d48e09dcd7f957b46b508abb427a04
         | 
| 3 | 
            +
            size 881895
         | 
