# SuperBench-Eval / app.py
import os
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from datasets import load_dataset
import torch
import re
# Cache to avoid reloading the model
model_cache = {}
HF_TOKEN = os.environ.get("HF_TOKEN")
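
# Load a causal LM and wrap it in a text-generation pipeline, caching the
# result so the same model is not reloaded on repeated evaluations.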
def load_model(model_id):
    if model_id in model_cache:
        return model_cache[model_id]
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(model_id, token=HF_TOKEN).to("cuda" if torch.cuda.is_available() else "cpu")
    generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1)
    model_cache[model_id] = generator
    return generator
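
# Build the multiple-choice prompt for one MMLU item and return it together
# with the expected answer letter.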
def format_prompt(item):
    system_instruction = "\nOnly answer with a single letter: A, B, C, or D."
    prompt = (
        f"{item['question']}\n"
        f"A. {item['choices'][0]}\n"
        f"B. {item['choices'][1]}\n"
        f"C. {item['choices'][2]}\n"
        f"D. {item['choices'][3]}\n"
        f"Answer:{system_instruction}"
    )
    # cais/mmlu stores the answer as an integer index into `choices`,
    # so convert it to the corresponding letter for comparison.
    return prompt, "ABCD"[item['answer']]
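
# Pull the first standalone A/B/C/D letter out of the model's output.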
def extract_choice_letter(output):
    match = re.search(r"\b([ABCD])\b", output.strip())
    return match.group(1) if match else None
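
# Run the chosen model over a shuffled sample of the selected MMLU subject
# and score its answers.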
def evaluate(model_id, sample_count, config_name):
    gen = load_model(model_id)
    dataset = load_dataset("cais/mmlu", config_name, token=HF_TOKEN)["test"]
    dataset = dataset.shuffle(seed=42).select(range(min(sample_count, len(dataset))))
    correct = 0
    results = []
    for item in dataset:
        prompt, answer = format_prompt(item)
        # return_full_text=False keeps only the generated continuation, so the
        # letter extraction below does not match the choice labels in the prompt.
        output = gen(prompt, max_new_tokens=20, do_sample=False, return_full_text=False)[0]["generated_text"]
        output_letter = extract_choice_letter(output)
        is_correct = output_letter == answer
        correct += is_correct
        results.append((prompt, output.strip(), answer, output_letter, is_correct))
    accuracy = correct / len(dataset) * 100
    return f"Accuracy: {accuracy:.2f}% on {len(dataset)} samples", results
def run(model_id, sample_count, config_name):
    score, details = evaluate(model_id, sample_count, config_name)
    formatted = "\n\n".join([
        f"### Question:\n{q}\n\n**Model Answer:** {o}\n**Expected:** {a}\n**Predicted:** {g}\n**Correct:** {c}"
        for q, o, a, g, c in details
    ])
    return score, formatted
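
# Write the evaluation details to disk so Gradio can serve the file for download.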
def save_text(text):
    path = "evaluation_results.txt"
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)
    return path
with gr.Blocks(css="body {font-family: Inter, sans-serif; padding: 1em; max-width: 900px; margin: auto;}", analytics_enabled=False) as demo:
    gr.Markdown("""
    # 🤖 LLM Benchmark Evaluator
    Currently, only **MMLU** (`cais/mmlu`) is available for evaluation.
    **MMLU-Pro** and **Humanity's Last Exam** are coming soon.
    Enter your model ID, choose an MMLU subject, and hit evaluate.
    """)
    with gr.Row():
        model_id = gr.Textbox(label="Your Hugging Face Model ID", placeholder="e.g., your-org/your-model")
        config_name = gr.Dropdown(
            label="Choose MMLU Subject",
            choices=[
                "abstract_algebra", "anatomy", "astronomy", "business_ethics", "college_biology",
                "college_chemistry", "college_computer_science", "college_mathematics", "college_medicine",
                "college_physics", "computer_security", "econometrics", "electrical_engineering",
                "elementary_mathematics", "formal_logic", "global_facts", "high_school_biology",
                "high_school_chemistry", "high_school_computer_science", "high_school_european_history",
                "high_school_geography", "high_school_government_and_politics", "high_school_macroeconomics",
                "high_school_microeconomics", "high_school_physics", "high_school_psychology",
                "high_school_statistics", "high_school_us_history", "high_school_world_history", "human_aging",
                "human_sexuality", "international_law", "jurisprudence", "logical_fallacies", "machine_learning",
                "management", "marketing", "medical_genetics", "miscellaneous", "moral_disputes",
                "moral_scenarios", "nutrition", "philosophy", "prehistory", "professional_accounting",
                "professional_law", "professional_medicine", "professional_psychology", "public_relations",
                "security_studies", "sociology", "us_foreign_policy", "virology", "world_religions"
            ],
            value="college_mathematics"
        )
    sample_count = gr.Slider(label="Number of Samples", minimum=1, maximum=100, value=10, step=1)
    run_button = gr.Button("🚀 Run Evaluation")
    acc_output = gr.Textbox(label="Benchmark Accuracy", interactive=False)
    detail_output = gr.Textbox(label="Evaluation Details", lines=20, interactive=False)
    download_button = gr.Button("📥 Download Full Evaluation")
    download_file = gr.File(label="Evaluation Results")
    run_button.click(run, inputs=[model_id, sample_count, config_name], outputs=[acc_output, detail_output])
    download_button.click(save_text, inputs=detail_output, outputs=download_file)
demo.launch()