Florence-2 Torch JIT Inference
#100
by
vivekkalyanarangan
- opened
I am using the Florence-2 model and trying to speed up inference with `torch.jit.trace()`.
Original model:
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image

# `device` and `torch_dtype` are used below, so define them before loading.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision on GPU; full precision on CPU, where fp16 kernels are often
# slow or unavailable.
torch_dtype = torch.float16 if device == "cuda" else torch.float32

# torchscript=True configures the model for TorchScript export (see the
# Hugging Face TorchScript guide); trust_remote_code loads Florence-2's
# custom modeling code from the Hub.
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-large",
    torch_dtype=torch_dtype,
    trust_remote_code=True,
    torchscript=True,
).to(device)
# Alternative to tracing that was tried: model = torch.compile(model)
model.eval()
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)
Inference example:
import time

import requests

prompt = "<MORE_DETAILED_CAPTION>"
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
# Fail fast on HTTP errors or a hung connection instead of handing PIL an
# error page (or blocking forever).
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
# Normalize to 3-channel RGB so grayscale/RGBA sources give the same layout.
image = Image.open(response.raw).convert("RGB")
inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)

# The tracing and generation code below passes these tensors around by name.
pixel_values = inputs["pixel_values"]
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
JIT part:
import torch
import types
from torch import nn
import copy
# === 1. Define a clean wrapper module ===
class CleanFlorenceWrapper(torch.nn.Module):
    """Trace-friendly wrapper: one Florence-2 forward pass in, logits out.

    ``torch.jit.trace`` needs a module whose ``forward`` takes plain tensors
    and returns tensors, so this wrapper pins the keyword arguments of the
    wrapped model's call and returns only ``outputs.logits``.
    """

    def __init__(self, model):
        """Deep-copy ``model`` and prepare the copy for tracing.

        Args:
            model: The loaded model. It is copied rather than wrapped in place,
                so the wrapper does not share submodule objects with it.
        """
        super().__init__()
        # Deep copy so the traced module owns its own submodules and weights.
        # NOTE: this duplicates the weights in memory while both objects live.
        self.inner_model = copy.deepcopy(model)
        self.config = self.inner_model.config
        # Floating dtype of the wrapped weights; forward() casts float inputs
        # to it. Falls back to float16, the dtype inputs were hard-cast to.
        self._compute_dtype = next(
            (p.dtype for p in self.inner_model.parameters() if p.is_floating_point()),
            torch.float16,
        )
        # Rewrite the module path of every submodule class (presumably so that
        # torch.jit.save does not embed the remote-code module path - confirm).
        # NOTE(review): this mutates the *class objects*, which deepcopy shares
        # with the original model, so classes used elsewhere in the process
        # (including any torch.nn classes in the model) are renamed too;
        # pickling their instances afterwards may fail.
        for _, submodule in self.inner_model.named_modules():
            submodule.__class__.__module__ = "clean_module"

    def forward(self, pixel_values, input_ids, attention_mask, decoder_input_ids=None):
        """Run one forward pass of the wrapped model and return its logits.

        Args:
            pixel_values: Image tensor; cast to the wrapped model's float dtype.
            input_ids: Text token ids; cast to int64.
            attention_mask: Mask passed alongside ``input_ids``; cast to the
                model's float dtype.
            decoder_input_ids: Optional decoder token ids. Defaults to
                ``input_ids`` (the previously hard-wired behavior), so
                three-argument calls and traces are unchanged.
                NOTE(review): Hugging Face encoder-decoder generation normally
                starts the decoder from ``decoder_start_token_id`` rather than
                the prompt; feeding the prompt as decoder input is probably
                unintended - confirm against the Florence-2 modeling code.

        Returns:
            ``outputs.logits`` from the wrapped model.
        """
        pixel_values = pixel_values.to(self._compute_dtype)
        input_ids = input_ids.long()
        attention_mask = attention_mask.to(self._compute_dtype)
        if decoder_input_ids is None:
            decoder_input_ids = input_ids
        else:
            decoder_input_ids = decoder_input_ids.long()
        with torch.no_grad():
            outputs = self.inner_model(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                # Only logits are returned, so a KV cache would be built and
                # then thrown away; skip building it.
                use_cache=False,
                return_dict=True,
            )
        return outputs.logits
# === 2. Clean up the model reference ===
# Give the wrapper class the same placeholder module path as the patched
# submodule classes.
CleanFlorenceWrapper.__module__ = "clean_module"

# === 3. Wrap and trace ===
clean_model = CleanFlorenceWrapper(model).eval()
# Example inputs come straight from the processor output; their order must
# match CleanFlorenceWrapper.forward(pixel_values, input_ids, attention_mask).
# NOTE(review): torch.jit.trace records only the ops executed for these
# example inputs, so any Python-level branching on sequence length is frozen
# into the graph - check the TracerWarnings emitted here.
example_inputs = (inputs["pixel_values"], inputs["input_ids"], inputs["attention_mask"])
with torch.no_grad():
    traced_model = torch.jit.trace(clean_model, example_inputs)
torch.jit.save(traced_model, "florence2_traced.pt")
print("Traced and saved clean model.")
Inference with the traced (TorchScript) model:
import torch

# Map the serialized TorchScript module onto CUDA when present, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

loaded_model = torch.jit.load("florence2_traced.pt", map_location=device)
# Put layers such as dropout into inference behavior.
loaded_model.eval()
def generate_tracer(input_ids, pixel_values, attention_mask,
                    loaded_model, max_new_tokens=100,
                    eos_token_id=None, tokenizer=None):
    """Greedily decode up to ``max_new_tokens`` tokens with the traced model.

    Each step calls ``loaded_model(pixel_values, generated_ids, attention_mask)``
    on the whole sequence so far, appends the argmax of the last position's
    logits, and grows ``attention_mask`` by one column of ones. Generation
    stops early once every row's new token equals the EOS id.

    NOTE(review): the traced model takes no ``past_key_values``, so nothing is
    cached between steps and the full, growing sequence is reprocessed every
    time; total work grows faster than linearly with output length. A KV cache
    would require the cache tensors to be explicit inputs/outputs of the
    traced graph(s), which this loop does not have.

    Args:
        input_ids: Prompt token ids, batch-first ``(batch, seq)``.
        pixel_values: Image tensor, passed through to ``loaded_model`` unchanged.
        attention_mask: Mask aligned with ``input_ids``; one column of ones is
            appended per generated token.
        loaded_model: Callable ``(pixel_values, ids, mask) -> logits`` where
            logits are ``(batch, seq, vocab)``.
        max_new_tokens: Maximum number of tokens to append.
        eos_token_id: Token id that ends generation; defaults to
            ``tokenizer.eos_token_id``.
        tokenizer: Object providing ``batch_decode`` and ``eos_token_id``;
            defaults to the module-level ``processor.tokenizer``.

    Returns:
        Decoded text of the first sequence in the batch. The prompt tokens are
        included, and special tokens are skipped.
    """
    if tokenizer is None:
        tokenizer = processor.tokenizer
    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id

    generated_ids = input_ids.clone()
    # One no_grad scope for the whole loop. The old extra forward pass before
    # the loop was discarded unused, so it has been removed.
    with torch.no_grad():
        for step in range(max_new_tokens):
            logits = loaded_model(pixel_values, generated_ids, attention_mask)
            # Greedy pick from the last position's logits, kept as (batch, 1).
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            generated_ids = torch.cat([generated_ids, next_token], dim=-1)
            # Keep the mask as long as the sequence fed on the next step.
            new_column = torch.ones(
                (attention_mask.size(0), 1),
                dtype=attention_mask.dtype,
                device=attention_mask.device,
            )
            attention_mask = torch.cat([attention_mask, new_column], dim=1)
            if torch.all(next_token.squeeze(-1) == eos_token_id):
                print(f"Stopping at step {step+1} due to EOS.")
                break
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
import time

# Tensors produced by the processor call earlier in the script.
pixel_values = inputs["pixel_values"]
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]

# Warmup: keep first-call costs (graph optimization, CUDA initialization, ...)
# out of the measurement below.
for _ in range(3):
    generate_tracer(input_ids, pixel_values, attention_mask,
                    loaded_model, max_new_tokens=100)

n_runs = 10
# perf_counter is monotonic and high-resolution, unlike time.time().
start = time.perf_counter()
for _ in range(n_runs):
    parsed_answer = generate_tracer(input_ids, pixel_values, attention_mask,
                                    loaded_model, max_new_tokens=100)
elapsed = time.perf_counter() - start
print(f"{elapsed} seconds for {n_runs} runs ({elapsed / n_runs} seconds per run)")
print(parsed_answer)
My issue is that this inference is about 10x slower, and it is fairly clear why: there is no way to leverage `past_key_values`,
so the inputs are reprocessed cumulatively at every step.
Question: How do we get around this? In the LLM context, how can inference run efficiently at all if a `past_key_values`-like
cache can't be leveraged?