batch inference scales linearly with batch size when input is long

#243
by platypus1989 - opened

Hi, I am noticing that when running batch inference with Mixtral-8x7B-Instruct-v0.1, the model seems to scale nicely (sublinearly) with batch size when the inputs are short, but once the inputs get long (more than 400 tokens), inference time starts to grow linearly with batch size.

Here is some sample code to reproduce what I am seeing:

import torch
from time import time
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = 'mistralai/Mixtral-8x7B-Instruct-v0.1'

# Load the model in bfloat16, sharded across the available devices.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map='auto',
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

def inference_time(input_size, batch_size):
    # Batch of identical prompts, so every row tokenizes to the same length
    # and no padding is needed.
    prompt = "how are you? " * input_size
    prompts = [prompt] * batch_size
    encoded = tokenizer(prompts)
    input_ids = torch.tensor(encoded['input_ids']).to(model.device)
    attention_mask = torch.tensor(encoded['attention_mask']).to(model.device)
    # Time a single forward pass (no generation) over the whole batch.
    tic = time()
    with torch.no_grad():
        output = model(input_ids=input_ids, attention_mask=attention_mask)
    return time() - tic


# Sweep input sizes (number of prompt repetitions) and batch sizes,
# recording the wall-clock time of each combination.
input_sizes = []
batch_sizes = []
wall_time = []
for i in [1, 5, 10, 20, 50, 100]:
    for j in [1, 5, 10, 20, 50, 100]:
        input_sizes.append(i)
        batch_sizes.append(j)
        wall_time.append(inference_time(i, j))
pd.DataFrame({
    'input_size': input_sizes,
    'batch_size': batch_sizes,
    'inference_time': wall_time,
})
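
One caveat about these wall-clock numbers: CUDA kernels launch asynchronously, so time() can return before the forward pass has actually finished on the GPU. Below is a minimal sketch (the name inference_time_synced is mine, not part of the run above) of the same measurement with a warm-up pass and explicit synchronization, assuming a single-device CUDA setup like the one above:

def inference_time_synced(input_size, batch_size):
    prompt = "how are you? " * input_size
    prompts = [prompt] * batch_size
    encoded = tokenizer(prompts, return_tensors='pt').to(model.device)
    # Warm-up pass so one-time costs are not attributed to the measured run.
    with torch.no_grad():
        model(**encoded)
    torch.cuda.synchronize()
    tic = time()
    with torch.no_grad():
        model(**encoded)
    # Wait for all launched kernels to finish before stopping the clock.
    torch.cuda.synchronize()
    return time() - tic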

[Image: table of the measured inference times for each input_size / batch_size combination]
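
For reference, the scaling is easier to read per sample than in total: dividing each time by its batch size gives seconds per sample, which shrinks as batch size grows under sublinear scaling and stays roughly flat under linear scaling. A hypothetical follow-up (not part of the run shown above) using the same lists:

# Assemble the sweep results and pivot the per-sample time.
df = pd.DataFrame({
    'input_size': input_sizes,
    'batch_size': batch_sizes,
    'inference_time': wall_time,
})
df['time_per_sample'] = df['inference_time'] / df['batch_size']
# Rows: batch_size, columns: input_size, values: seconds per sample.
print(df.pivot(index='batch_size', columns='input_size', values='time_per_sample'))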
