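"""Gradio demo comparing answer quality and inference time across a base
BERT-large SQuAD v2 model, a block-pruned variant, and a pruned model
exported to ONNX with FP16 optimizations."""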
import time

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from onnxruntime import InferenceSession
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
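
# BERT-large accepts at most 512 tokens per input sequence.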
MAX_SEQUENCE_LENGTH = 512
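
# Display names mapped to Hugging Face Hub repo IDs.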
models = {
    "Base model": "madlag/bert-large-uncased-whole-word-masking-finetuned-squadv2",
    "Pruned model": "madlag/bert-large-uncased-wwm-squadv2-x2.63-f82.6-d16-hybrid-v1",
    "Pruned ONNX Optimized FP16": "tryolabs/bert-large-uncased-wwm-squadv2-optimized-f16",
}
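
# Load everything once at startup: the ONNX entry resolves to a local file
# path via hf_hub_download, while the other two are in-memory PyTorch models.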
loaded_models = {
    "Pruned ONNX Optimized FP16": hf_hub_download(
        repo_id=models["Pruned ONNX Optimized FP16"], filename="model.onnx"
    ),
    "Base model": AutoModelForQuestionAnswering.from_pretrained(models["Base model"]),
    "Pruned model": AutoModelForQuestionAnswering.from_pretrained(
        models["Pruned model"]
    ),
}

def run_ort_inference(model_name, inputs):
    # A new InferenceSession is built on every request; the reported time
    # covers only sess.run, not session construction.
    sess = InferenceSession(
        loaded_models[model_name], providers=["CPUExecutionProvider"]
    )
    start_time = time.time()
    output = sess.run(None, input_feed=inputs)
    end_time = time.time()
    # output[0] and output[1] hold the start and end logits.
    return (output[0], output[1]), end_time - start_time

def run_normal_hf(model_name, inputs):
    # Time a plain PyTorch forward pass; no_grad skips building the autograd
    # graph. The ModelOutput's values are the start and end logits.
    start_time = time.time()
    with torch.no_grad():
        output = loaded_models[model_name](**inputs).values()
    end_time = time.time()
    return output, end_time - start_time

def inference(model_name, context, question):
    tokenizer = AutoTokenizer.from_pretrained(models[model_name])
    if model_name == "Pruned ONNX Optimized FP16":
        # ONNX Runtime expects NumPy inputs; truncation=True is required for
        # max_length to take effect (the tokenizer otherwise ignores it).
        inputs = dict(
            tokenizer(
                question,
                context,
                return_tensors="np",
                truncation=True,
                max_length=MAX_SEQUENCE_LENGTH,
            )
        )
        output, inference_time = run_ort_inference(model_name, inputs)
        answer_start_scores = torch.tensor(output[0])
        answer_end_scores = torch.tensor(output[1])
    else:
        inputs = tokenizer(
            question,
            context,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_SEQUENCE_LENGTH,
        )
        output, inference_time = run_normal_hf(model_name, inputs)
        answer_start_scores, answer_end_scores = output
    # Decode the highest-scoring answer span back into text.
    input_ids = inputs["input_ids"].tolist()[0]
    answer_start = torch.argmax(answer_start_scores)
    answer_end = torch.argmax(answer_end_scores) + 1
    answer = tokenizer.convert_tokens_to_string(
        tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end])
    )
    return answer, f"{inference_time:.4f}s"
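
# Gradio components: model selector, context and question inputs, and outputs
# for the predicted answer and the measured inference time.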
model_field = gr.Dropdown(
    choices=["Base model", "Pruned model", "Pruned ONNX Optimized FP16"],
    value="Pruned ONNX Optimized FP16",
    label="Model",
)
input_text_field = gr.Textbox(placeholder="Enter the text here", label="Text")
input_question_field = gr.Textbox(placeholder="Enter the question here", label="Question")
output_model = gr.Textbox(label="Model output")
output_inference_time = gr.Textbox(label="Inference time in seconds")
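
# A preloaded example so the demo can be tried with one click.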
examples = [
    [
        "Pruned ONNX Optimized FP16",
        "The first little pig was very lazy. He didn't want to work at all and he built his house out of straw. The second little pig worked a little bit harder but he was somewhat lazy too and he built his house out of sticks. Then, they sang and danced and played together the rest of the day.",
        "Who worked a little bit harder?",
    ]
]
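
# Wire the model selector, inputs, and outputs into a single Interface.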
demo = gr.Interface(
    inference,
    inputs=[model_field, input_text_field, input_question_field],
    outputs=[output_model, output_inference_time],
    examples=examples,
)

demo.launch()