Florence-2 TorchScript (JIT) Inference

#100
by vivekkalyanarangan - opened

I am using the Florence-2 model and trying to speed up inference with `torch.jit.trace()`.

Original model:

import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image

# Device and precision are used by every later step, so define them up front.
# fp16 halves memory and is well supported on CUDA; CPU half-precision support
# is limited, so fall back to fp32 there.
# NOTE(review): a tracing wrapper that hard-codes `.half()` input casts only
# matches the fp16 (CUDA) configuration — confirm before running on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-large",
    torch_dtype=torch_dtype,
    trust_remote_code=True,
    # Required for TorchScript export: Transformers unties/clones tied
    # embedding weights, which TorchScript cannot serialize.
    torchscript=True,
).to(device)
# model = torch.compile(model)  # alternative speed-up path, not used here
model.eval()
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)

Inference example:

import requests

prompt = "<MORE_DETAILED_CAPTION>"

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
# Fail fast on network problems instead of hanging forever, or handing PIL an
# HTML error page as if it were an image.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
# Image.open is lazy; convert("RGB") decodes the stream now, while the
# response is still open, and normalizes the image mode for the processor.
image = Image.open(response.raw).convert("RGB")

# The dtype argument of BatchFeature.to() presumably casts only floating-point
# tensors (pixel_values) and leaves token ids as integers — verify against the
# installed transformers version.
inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)

import time

JIT tracing:

import torch
import types
from torch import nn
import copy

# === 1. Define a clean wrapper module ===
class CleanFlorenceWrapper(torch.nn.Module):
    """Trace-friendly wrapper: tensors in, logits out.

    ``torch.jit.trace`` needs a module whose ``forward`` takes and returns
    plain tensors. This wrapper therefore fixes the keyword arguments passed
    to the wrapped model and returns only ``outputs.logits``.

    Args:
        model: The loaded Florence-2 model, or any module that accepts the
            keyword arguments used in :meth:`forward` and returns an object
            with a ``logits`` attribute.
        copy_model: If True (the default, matching the original behavior),
            wrap a deep copy so the caller's instance is untouched. Pass False
            to avoid holding two copies of the weights in memory.
    """

    def __init__(self, model, copy_model=True):
        super().__init__()
        self.inner_model = copy.deepcopy(model) if copy_model else model
        self.config = self.inner_model.config

        # Rename the module path only for classes whose __module__ is not a
        # valid dotted Python identifier (e.g. remote-code paths containing
        # "Florence-2-large"), which TorchScript cannot use in qualified names.
        # __module__ lives on the *class*, so the assignment is process-wide and
        # deepcopy does not isolate it. Restricting it leaves torch.nn /
        # transformers classes (Linear, LayerNorm, ...) untouched.
        for submodule in self.inner_model.modules():
            cls = type(submodule)
            if not self._is_dotted_identifier(cls.__module__):
                cls.__module__ = "clean_module"

    @staticmethod
    def _is_dotted_identifier(name):
        """Return True if every dot-separated part of *name* is a Python identifier."""
        return all(part.isidentifier() for part in name.split("."))

    def _float_dtype(self):
        """Return the floating dtype of the wrapped model's weights (fp16 if none)."""
        for param in self.inner_model.parameters():
            if param.is_floating_point():
                return param.dtype
        return torch.float16

    def forward(self, pixel_values, input_ids, attention_mask, decoder_input_ids=None):
        """Run one full (uncached) forward pass and return the logits.

        Args:
            pixel_values: Image tensor from the processor. It is cast to the
                model's floating dtype rather than a hard-coded fp16, so fp32
                models work too.
            input_ids: Prompt token ids; cast to int64.
            attention_mask: Prompt attention mask; cast to the model's
                floating dtype (identical to the old ``.half()`` for fp16 models).
            decoder_input_ids: Token ids fed to the decoder. Defaults to
                ``input_ids``, for compatibility with the original
                three-argument trace. Florence-2 is an encoder-decoder model, so
                for real generation this should presumably be the growing
                decoder sequence rather than the prompt — TODO confirm the
                start token against the model config.

        Returns:
            The ``logits`` tensor of the wrapped model's output.
        """
        dtype = self._float_dtype()
        pixel_values = pixel_values.to(dtype)
        input_ids = input_ids.long()
        attention_mask = attention_mask.to(dtype)
        if decoder_input_ids is None:
            decoder_input_ids = input_ids
        else:
            decoder_input_ids = decoder_input_ids.long()

        with torch.no_grad():
            outputs = self.inner_model(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                # Only logits are returned, so a KV cache would be built and
                # thrown away on every call.
                use_cache=False,
                return_dict=True,
            )
        return outputs.logits

# === 2. Clean up the model reference ===
CleanFlorenceWrapper.__module__ = "clean_module"

# === 3. Wrap and trace ===
clean_model = CleanFlorenceWrapper(model).eval()

# The example tensors live in the processor output; take them by key so the
# trace sees exactly the tensors the model gets at inference time.
pixel_values = inputs["pixel_values"]
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
example_inputs = (pixel_values, input_ids, attention_mask)

with torch.no_grad():
    traced_model = torch.jit.trace(clean_model, example_inputs)

torch.jit.save(traced_model, "florence2_traced.pt")
print("✅ Traced and saved clean model.")

Inference with the traced (TorchScript) model:

import torch
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Deserialize the traced module directly onto the chosen device, then put it
# in eval mode.
loaded_model = torch.jit.load("florence2_traced.pt", map_location=device)
loaded_model.eval()

def generate_tracer(input_ids, pixel_values, attention_mask,
                    loaded_model, max_new_tokens=100, tokenizer=None):
    """Greedy-decode with a traced model that has no KV cache.

    Each step re-runs the full model on the prompt plus every token generated
    so far, takes the argmax of the last position's logits, and appends it.
    Nothing is cached between steps, so per-step cost grows with sequence
    length.

    Args:
        input_ids: Prompt token ids, rank 2 (batch, seq).
        pixel_values: Image tensor, passed through unchanged to ``loaded_model``.
        attention_mask: Mask for ``input_ids``; it grows by one column per step
            so it stays aligned with the generated ids.
        loaded_model: Callable ``(pixel_values, ids, attention_mask) -> logits``,
            where logits are (batch, seq, vocab).
        max_new_tokens: Upper bound on the number of decoding steps.
        tokenizer: Object providing ``eos_token_id`` and ``batch_decode``.
            Defaults to the module-level ``processor.tokenizer``.

    Returns:
        The decoded text of the first sequence in the batch, with special
        tokens removed.
    """
    if tokenizer is None:
        tokenizer = processor.tokenizer
    eos_token_id = tokenizer.eos_token_id
    generated_ids = input_ids.clone()

    # One no_grad context for the whole loop; each step is a single forward
    # pass (no separate priming call whose output would be discarded).
    with torch.no_grad():
        for step in range(max_new_tokens):
            logits = loaded_model(pixel_values, generated_ids, attention_mask)

            # Greedy pick from the last position; keepdim gives (batch, 1) for cat.
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            generated_ids = torch.cat([generated_ids, next_token], dim=-1)

            # new_ones keeps the mask's dtype and device.
            attention_mask = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.size(0), 1))],
                dim=1,
            )

            # Stop only when every sequence in the batch emitted EOS this step.
            if torch.all(next_token.squeeze(-1) == eos_token_id):
                print(f"✅ Stopping at step {step + 1} due to EOS.")
                break

    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

# Take the example tensors from the processor output by key, so this cell does
# not depend on loose globals.
input_ids = inputs["input_ids"]
pixel_values = inputs["pixel_values"]
attention_mask = inputs["attention_mask"]

# Warmup: the first calls to a freshly loaded TorchScript module are typically
# slower (CUDA init / graph optimization), so keep them out of the timing.
for _ in range(3):
    generate_tracer(input_ids, pixel_values, attention_mask,
                    loaded_model, max_new_tokens=100)

n_runs = 10
# perf_counter is monotonic and high-resolution, unlike time.time().
start = time.perf_counter()
for _ in range(n_runs):
    parsed_answer = generate_tracer(input_ids, pixel_values, attention_mask,
                                    loaded_model, max_new_tokens=100)
elapsed = time.perf_counter() - start

print(f"{elapsed} seconds")
print(f"{elapsed / n_runs:.3f} seconds per run")
print(parsed_answer)

My problem is that this inference is about 10× slower, and the reason seems clear: the traced model has no way to use `past_key_values`, so the whole (ever-growing) input is reprocessed at every decoding step.

Question: how can we get around this? In an LLM context, how is efficient inference possible at all if a feature like `past_key_values` cannot be used?

You need to confirm your account before you can post a new comment.

Sign up or log in to comment