batch inference scales linearly with batch size when input is long
#243 · opened by platypus1989
Hi, I am noticing that when running batch inference with Mixtral-8x7B-Instruct-v0.1, the model seems to scale nicely (sublinearly) with batch size when the input is short, but once the input gets long (more than 400 tokens), inference time starts to grow linearly with batch size.
Some sample code to reproduce what I am seeing:
import torch
from time import time
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = 'mistralai/Mixtral-8x7B-Instruct-v0.1'
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map='auto',
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

def inference_time(input_size, batch_size):
    # Repeat a short prompt to control input length, then duplicate it to form the batch
    prompt = "how are you? " * input_size
    prompts = [prompt] * batch_size
    inputs = tokenizer(prompts)
    input_ids = torch.tensor(inputs['input_ids']).to(model.device)
    attention_mask = torch.tensor(inputs['attention_mask']).to(model.device)
    tic = time()
    with torch.no_grad():
        output = model(input_ids=input_ids, attention_mask=attention_mask)
    return time() - tic
input_sizes = []
batch_sizes = []
wall_time = []
for i in [1, 5, 10, 20, 50, 100]:
    for j in [1, 5, 10, 20, 50, 100]:
        input_sizes.append(i)
        batch_sizes.append(j)
        wall_time.append(inference_time(i, j))

pd.DataFrame({
    'input_size': input_sizes,
    'batch_size': batch_sizes,
    'inference_time': wall_time,
})
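One caveat on the measurement itself: CUDA kernels are launched asynchronously, so the wall-clock timing above may not cover the full forward pass. A minimal variant of the helper that synchronizes before and after the timed region (a sketch assuming a CUDA device, reusing the model and tokenizer defined above) looks like this:

def inference_time_synced(input_size, batch_size):
    # Hypothetical variant of inference_time that synchronizes CUDA around the timed region
    prompt = "how are you? " * input_size
    prompts = [prompt] * batch_size
    inputs = tokenizer(prompts, return_tensors='pt').to(model.device)

    torch.cuda.synchronize()   # make sure earlier GPU work has finished
    tic = time()
    with torch.no_grad():
        model(**inputs)
    torch.cuda.synchronize()   # wait for this forward pass to finish before stopping the clock
    return time() - tic

The synchronize() calls just ensure the timed region covers the whole forward pass rather than only the kernel launches.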