import os
import re
import json
import traceback

import spaces  # import early so ZeroGPU ("Running on Zero") can set up its CUDA hooks
import gradio as gr
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from datasets import load_dataset, get_dataset_config_names
# --- Environment and Caching ---
# It's good practice to ensure the cache directory exists.
CACHE_DIR = "evaluation_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
EVAL_FILE = os.path.join(CACHE_DIR, "eval.jsonl")
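
# Each completed run appends one JSON record per line to eval.jsonl; an illustrative
# (made-up) record looks like:
#   {"model_id": "org/model", "benchmark": "MMLU", "accuracy": 67.5,
#    "subject": "ALL", "sample_count": 1425, "timestamp": "2024-01-01T12:00:00"}
# On Hugging Face Spaces this file sits on ephemeral disk, so leaderboard history is
# lost on restart unless persistent storage is enabled for the Space.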
# Cache to avoid reloading models and dataset configs
model_cache = {}
benchmark_subject_cache = {}

# Use environment variable for the Hugging Face token
HF_TOKEN = os.environ.get("HF_TOKEN")

# --- Constants for Benchmarks ---
MMLU_DATASET = "cais/mmlu"
MMLU_PRO_DATASET = "TIGER-Lab/MMLU-Pro"
BENCHMARK_MAP = {
    "MMLU": MMLU_DATASET,
    "MMLU-Pro": MMLU_PRO_DATASET
}
# --- Data Loading and Preparation ---
def get_all_benchmark_options():
    """
    Fetches and caches the available subjects (configs) for each benchmark dataset.
    Populates a module-level cache to avoid repeated API calls.
    """
    if benchmark_subject_cache:
        return benchmark_subject_cache
    print("Fetching benchmark configurations for the first time...")
    for key, dataset_id in BENCHMARK_MAP.items():
        try:
            # Fetching dataset configurations requires authentication if the dataset is gated or private
            subjects = get_dataset_config_names(dataset_id, token=HF_TOKEN)
            # cais/mmlu also exposes aggregate configs ("all", "auxiliary_train"); drop them
            # so an "ALL" run does not double count or hit configs without a test split.
            subjects = [s for s in subjects if s not in ("all", "auxiliary_train")]
            benchmark_subject_cache[key] = ["ALL"] + subjects
        except Exception as e:
            print(f"Warning: Could not load configs for {key} ({dataset_id}). It might be private or unavailable. Error: {e}")
            benchmark_subject_cache[key] = []
    print("Benchmark configurations cached.")
    return benchmark_subject_cache

# Initialize the cache on startup
ALL_BENCHMARK_SUBJECTS = get_all_benchmark_options()
def load_model(model_id):
    """
    Loads a Hugging Face model and tokenizer, creating a text-generation pipeline.
    Uses a cache to avoid reloading models.
    """
    if not model_id:
        raise ValueError("Model ID cannot be empty.")
    gr.Info(f"Attempting to load model: {model_id}...")
    if model_id in model_cache:
        gr.Info(f"Model '{model_id}' found in cache.")
        return model_cache[model_id]
    try:
        # Use bfloat16 for better performance on modern GPUs
        dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            token=HF_TOKEN,
            torch_dtype=dtype,
            trust_remote_code=True
        ).to("cuda" if torch.cuda.is_available() else "cpu")
        # Create the pipeline for text generation
        generator = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            device=0 if torch.cuda.is_available() else -1
        )
        model_cache[model_id] = generator
        gr.Info(f"Model '{model_id}' loaded successfully.")
        return generator
    except Exception as e:
        # Raise a more specific error to be caught by the main evaluation function
        raise RuntimeError(f"Failed to load model '{model_id}'. Please verify the model ID and your Hugging Face token (if required). Error: {e}")
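
# Note: model_cache keeps every loaded pipeline (weights included) in memory for the
# lifetime of the process. Evaluating many different model IDs in one session can
# therefore exhaust RAM/VRAM on a shared or ZeroGPU Space; consider evicting old
# entries if that becomes an issue.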
# --- Evaluation Logic ---
# Letter labels for the answer options (MMLU has 4 choices, MMLU-Pro up to 10).
CHOICE_LETTERS = [chr(ord('A') + i) for i in range(10)]

def format_prompt(item):
    """Formats a multiple-choice question and its options into a standardized prompt.

    Handles both the MMLU schema ('choices' + integer 'answer') and the MMLU-Pro
    schema ('options' + 'answer_index').
    """
    choices = item['options'] if 'options' in item else item['choices']
    answer_idx = item['answer_index'] if 'answer_index' in item else item['answer']
    choice_lines = "\n".join(f"{CHOICE_LETTERS[i]}. {c}" for i, c in enumerate(choices))
    prompt = f"Question: {item['question']}\n\nChoices:\n{choice_lines}\n\nAnswer:"
    return prompt, answer_idx

def get_choice_letter(index):
    """Converts a numerical choice index (0-9) to a letter (A-J)."""
    return CHOICE_LETTERS[index] if 0 <= index < len(CHOICE_LETTERS) else None

def extract_predicted_letter(output_text):
    """
    Extracts the predicted letter from the model's output.
    It looks for a letter (A-J) immediately following 'Answer:'.
    """
    match = re.search(r"Answer:\s*([A-J])", output_text, re.IGNORECASE)
    return match.group(1).upper() if match else None
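
# Scoring sketch (illustrative): the prompt ends in "Answer:", greedy decoding adds a
# few tokens (e.g. " B. Because..."), and extract_predicted_letter() pulls the first
# letter that follows "Answer:" out of the full generated text (the pipeline returns
# prompt + continuation by default), which is then compared to the gold letter.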
def evaluate_single_subject(generator, dataset_id, subject, sample_count, progress):
    """
    Evaluates a model on a specific subject (config) from a dataset.
    """
    gr.Info(f"Loading dataset: {dataset_id} ({subject})...")
    try:
        # Load the 'test' split, as it's the standard split for MMLU-style evaluation
        dataset = load_dataset(dataset_id, subject, token=HF_TOKEN, split="test")
    except Exception as e:
        raise RuntimeError(f"Failed to load dataset '{dataset_id}' for subject '{subject}'. Error: {e}")

    # Shuffle and select a subset of samples for evaluation
    num_samples = min(int(sample_count), len(dataset))
    dataset = dataset.shuffle(seed=42).select(range(num_samples))

    correct_predictions = 0
    results_details = []
    for item in progress.tqdm(dataset, desc=f"Evaluating {subject}"):
        prompt, correct_answer_idx = format_prompt(item)
        expected_letter = get_choice_letter(correct_answer_idx)
        # Generate a short continuation, aiming for a single-letter answer.
        # do_sample=False (greedy decoding) keeps the run reproducible.
        raw_output = generator(prompt, max_new_tokens=5, do_sample=False)[0]["generated_text"]
        predicted_letter = extract_predicted_letter(raw_output)
        is_correct = (predicted_letter == expected_letter)
        if is_correct:
            correct_predictions += 1
        results_details.append({
            "question": item['question'],
            "choices": item.get('choices', item.get('options')),
            "raw_output": raw_output.strip(),
            "expected_letter": expected_letter,
            "predicted_letter": predicted_letter,
            "is_correct": is_correct,
        })

    accuracy = (correct_predictions / num_samples) * 100 if num_samples > 0 else 0
    return accuracy, results_details
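
# This Space runs on ZeroGPU ("Running on Zero"), where a GPU is only attached while a
# function decorated with @spaces.GPU executes, so the whole evaluation runs under it.
# The duration below is an assumption: long "ALL" runs may need a larger value or a
# smaller per-subject sample count.
@spaces.GPU(duration=120)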
def run_evaluation(model_id, benchmark_category, subject_name, sample_count, progress=gr.Progress(track_tqdm=True)):
    """
    Main function to orchestrate the entire evaluation process.
    Handles single-subject or 'ALL'-subjects evaluation.
    Returns one update per UI component: visibility for the three panes and values
    for the summary, error fields and details table.
    """
    try:
        gr.Info("Starting evaluation...")
        generator = load_model(model_id)
        dataset_id = BENCHMARK_MAP.get(benchmark_category)
        if not dataset_id:
            raise ValueError(f"Invalid benchmark category: {benchmark_category}")

        all_results_details = []
        summary_lines = []
        total_correct = 0
        total_samples = 0

        if subject_name == "ALL":
            # Copy the cached list so dropping 'ALL' does not mutate the dropdown choices.
            subjects_to_run = [s for s in ALL_BENCHMARK_SUBJECTS.get(benchmark_category, []) if s != "ALL"]
        else:
            subjects_to_run = [subject_name]

        if not subjects_to_run:
            gr.Warning(f"No subjects found for '{benchmark_category}'.")
            return (
                gr.update(visible=False),          # result summary box
                gr.update(value=""),               # result summary text
                gr.update(visible=False),          # error box
                gr.update(value=""),               # error message
                gr.update(value=""),               # error details
                gr.update(visible=False),          # details box
                gr.update(value=pd.DataFrame()),   # details table
            )

        for i, subject in enumerate(subjects_to_run):
            gr.Info(f"Evaluating {benchmark_category} - {subject} ({i+1}/{len(subjects_to_run)})...")
            try:
                accuracy, subject_details = evaluate_single_subject(generator, dataset_id, subject, sample_count, progress)
                all_results_details.extend(subject_details)
                num_correct = sum(d['is_correct'] for d in subject_details)
                num_evaluated = len(subject_details)
                total_correct += num_correct
                total_samples += num_evaluated
                summary_lines.append(f"- **{subject}**: {accuracy:.2f}% ({num_correct}/{num_evaluated})")
            except Exception as e:
                # gr.Warning shows a toast when called; gr.Error would need to be raised to be visible.
                gr.Warning(f"Skipping {subject} due to an error: {e}")
                summary_lines.append(f"- **{subject}**: Evaluation failed.")
                continue

        overall_accuracy = (total_correct / total_samples) * 100 if total_samples > 0 else 0

        # --- Prepare Outputs ---
        if subject_name == "ALL":
            result_summary = f"### Overall Average Accuracy for {benchmark_category}: {overall_accuracy:.2f}%\n"
            result_summary += "Evaluated across {:,} total samples.\n\n---\n\n**Breakdown by Subject:**\n".format(total_samples)
            result_summary += "\n".join(summary_lines)
        else:
            result_summary = f"### Accuracy for {benchmark_category} - {subject_name}: {overall_accuracy:.2f}%\n"
            result_summary += "({:,}/{:,} correct)".format(total_correct, total_samples)

        # Create a detailed DataFrame for inspection, matching the table headers in the UI.
        df_details = pd.DataFrame(all_results_details)
        if not df_details.empty:
            df_details = df_details[["question", "is_correct", "expected_letter", "predicted_letter", "raw_output"]]
            df_details.columns = ["Question", "Correct", "Expected", "Predicted", "Raw Output"]

        # Save results for the leaderboard
        record = {
            "model_id": model_id,
            "benchmark": benchmark_category,
            "accuracy": overall_accuracy,
            "subject": subject_name,
            "sample_count": total_samples,
            "timestamp": pd.Timestamp.now().isoformat()
        }
        with open(EVAL_FILE, "a") as f:
            f.write(json.dumps(record) + "\n")

        gr.Info("Evaluation completed successfully!")
        # Return updates for the UI: show the summary and details, hide the error pane.
        return (
            gr.update(visible=True),
            gr.update(value=result_summary),
            gr.update(visible=False),
            gr.update(value=""),
            gr.update(value=""),
            gr.update(visible=True),
            gr.update(value=df_details),
        )
    except Exception as e:
        error_message = f"An unexpected error occurred: {e}"
        error_details = traceback.format_exc()
        gr.Warning(error_message)
        # Return error updates for the UI: hide the result panes, show the error pane.
        return (
            gr.update(visible=False),
            gr.update(value=""),
            gr.update(visible=True),
            gr.update(value=error_message),
            gr.update(value=error_details),
            gr.update(visible=False),
            gr.update(value=pd.DataFrame()),
        )
# --- UI Helper Functions ---
def update_subject_dropdown(benchmark_category):
    """Updates the subject dropdown choices based on the selected benchmark."""
    choices = ALL_BENCHMARK_SUBJECTS.get(benchmark_category, [])
    default_value = "ALL" if "ALL" in choices else (choices[0] if choices else None)
    return gr.update(choices=choices, value=default_value)
def load_leaderboard(benchmark_filter):
    """
    Loads and processes evaluation data to display on the leaderboard.
    Only 'ALL'-subject runs are shown, keeping the latest entry per model.
    """
    empty = pd.DataFrame(columns=["Model ID", "Avg. Accuracy (%)", "Total Samples"])
    try:
        if not os.path.exists(EVAL_FILE):
            return empty
        df = pd.read_json(EVAL_FILE, lines=True)
        if df.empty:
            return empty

        # Coerce accuracy to numeric and filter valid entries
        df['accuracy'] = pd.to_numeric(df['accuracy'], errors='coerce')
        df.dropna(subset=['accuracy'], inplace=True)

        # Filter by the selected benchmark (e.g., MMLU or MMLU-Pro)
        df_filtered = df[df['benchmark'] == benchmark_filter].copy()
        if df_filtered.empty:
            return empty

        # The main leaderboard only tracks 'ALL' subject evaluations
        df_all = df_filtered[df_filtered['subject'] == 'ALL'].copy()
        if df_all.empty:
            return empty

        # Keep only the latest evaluation for each model
        df_all['timestamp'] = pd.to_datetime(df_all['timestamp'])
        latest_evals = df_all.loc[df_all.groupby('model_id')['timestamp'].idxmax()]

        leaderboard_df = latest_evals[['model_id', 'accuracy', 'sample_count']].copy()
        leaderboard_df.columns = ["Model ID", "Avg. Accuracy (%)", "Total Samples"]

        # Sort numerically before formatting; sorting the formatted strings would be lexicographic.
        leaderboard_df = leaderboard_df.sort_values(by="Avg. Accuracy (%)", ascending=False)
        leaderboard_df["Avg. Accuracy (%)"] = leaderboard_df["Avg. Accuracy (%)"].map('{:.2f}'.format)
        return leaderboard_df
    except Exception as e:
        gr.Warning(f"Error loading leaderboard: {e}")
        traceback.print_exc()
        return empty
# --- Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Soft(), css="""
/* --- Global & Layout --- */
body { font-family: 'Inter', sans-serif; background-color: #f8f9fa; }
.gradio-container { max-width: 1280px !important; margin: auto; }
.gr-box { border-radius: 12px !important; box-shadow: 0 4px 12px rgba(0,0,0,0.05) !important; border: 1px solid #e9ecef !important; }
/* --- Typography --- */
h1 {
    text-align: center;
    font-size: 2.5rem !important;
    font-weight: 700;
    color: #212529;
    margin-bottom: 0.5rem;
    letter-spacing: -1px;
}
.subtitle {
    text-align: center; color: #6c757d; font-size: 1.1rem; margin-bottom: 2.5rem;
}
/* --- Buttons & Inputs --- */
.gr-button {
    border-radius: 8px !important;
    font-weight: 600 !important;
    padding: 10px 20px !important;
    transition: all 0.2s ease;
}
.gr-button-primary { box-shadow: 0 4px 10px rgba(0, 123, 255, 0.2); }
.gr-button-primary:hover { transform: translateY(-2px); box-shadow: 0 6px 15px rgba(0, 123, 255, 0.3); }
/* --- Custom Radio Buttons (Segmented Control) --- */
#leaderboard-toggle, #eval-benchmark-selection {
    background-color: #e9ecef;
    padding: 5px;
    border-radius: 10px;
    display: inline-flex;
    margin: auto;
}
#leaderboard-toggle div.gr-form, #eval-benchmark-selection div.gr-form {
    display: flex;
    gap: 5px;
}
#leaderboard-toggle input[type='radio'], #eval-benchmark-selection input[type='radio'] { display: none; }
#leaderboard-toggle label, #eval-benchmark-selection label {
    padding: 8px 16px;
    border-radius: 8px;
    cursor: pointer;
    transition: background-color 0.3s, color 0.3s, box-shadow 0.3s;
    font-weight: 500;
    color: #495057;
    background: transparent;
    border: none;
    box-shadow: none;
}
#leaderboard-toggle input[type='radio']:checked + label, #eval-benchmark-selection input[type='radio']:checked + label {
    background-color: white;
    color: #007bff;
    font-weight: 600;
    box-shadow: 0 2px 5px rgba(0,0,0,0.1);
}
/* --- Dataframe / Table Styling --- */
.leaderboard-table .gr-dataframe table {
    border-collapse: collapse;
    width: 100%;
}
.leaderboard-table .gr-dataframe thead th {
    background-color: #f8f9fa !important;
    color: #495057 !important;
    font-weight: 600 !important;
    text-align: left;
    padding: 12px 15px;
    border-bottom: 2px solid #dee2e6;
}
.leaderboard-table .gr-dataframe tbody tr:nth-of-type(even) {
    background-color: #f8f9fa;
}
.leaderboard-table .gr-dataframe tbody tr:hover {
    background-color: #e9ecef;
}
.leaderboard-table .gr-dataframe tbody td {
    padding: 12px 15px;
    border-bottom: 1px solid #dee2e6;
}
/* --- Error & Result Panes --- */
#error-display-box { background-color: #fff3f3; border-color: #ffc9c9; }
#error-display-box .gr-label { color: #d9480f !important; font-weight: 600; }
#result-summary-box { background-color: #f3f9ff; border-color: #cde4ff; }
""") as demo:
    gr.Markdown("<h1>🤗 Open LLM Evaluator</h1>")
    gr.Markdown("<p class='subtitle'>Benchmark leading models on MMLU and MMLU-Pro. Your results contribute to a live leaderboard.</p>")

    with gr.Tabs():
        # --- Evaluation Tab ---
        with gr.TabItem("🚀 Run Evaluation"):
            with gr.Row():
                with gr.Column(scale=2):
                    with gr.Box():
                        gr.Markdown("### 1. Configure Evaluation")
                        model_id_input = gr.Textbox(
                            label="Hugging Face Model ID",
                            placeholder="e.g., meta-llama/Meta-Llama-3-8B-Instruct",
                            interactive=True
                        )
                        with gr.Row():
                            benchmark_selection_radio = gr.Radio(
                                ["MMLU", "MMLU-Pro"],
                                label="Benchmark",
                                value="MMLU",
                                interactive=True,
                                elem_id="eval-benchmark-selection",
                                container=False
                            )
                        with gr.Row():
                            benchmark_subject_dropdown = gr.Dropdown(
                                label="Subject",
                                choices=ALL_BENCHMARK_SUBJECTS.get("MMLU", []),
                                value="ALL",
                                interactive=True
                            )
                        sample_count_slider = gr.Slider(
                            label="Samples per Subject",
                            minimum=5, maximum=100, value=25, step=5, interactive=True
                        )
                        run_button = gr.Button("Start Evaluation", variant="primary", scale=1)
                with gr.Column(scale=3):
                    gr.Markdown("### 2. View Results")
                    # Panel for displaying the summary of results
                    with gr.Box(visible=False, elem_id="result-summary-box") as result_summary_box:
                        result_summary_output = gr.Markdown()
                    # Panel for displaying errors
                    with gr.Box(visible=False, elem_id="error-display-box") as error_box:
                        error_output = gr.Textbox(label="Error Message", interactive=False)
                        error_details_output = gr.Textbox(label="Error Details (Traceback)", interactive=False, lines=8)
                    # Panel for detailed, row-by-row results
                    with gr.Box(visible=False) as details_box:
                        gr.Markdown("#### Detailed Evaluation Log")
                        detailed_results_df = gr.Dataframe(
                            headers=["Question", "Correct", "Expected", "Predicted", "Raw Output"],
                            datatype=["str", "bool", "str", "str", "str"],
                            interactive=False,
                            row_count=10,
                            col_count=5
                        )

        # --- Leaderboard Tab ---
        with gr.TabItem("🏆 Leaderboard"):
            with gr.Column():
                gr.Markdown("<div style='display: flex; justify-content: center; width: 100%; margin-bottom: 20px;'></div>", elem_id="leaderboard-toggle-container")
                leaderboard_type_toggle = gr.Radio(
                    ["MMLU", "MMLU-Pro"],
                    label="Select Benchmark",
                    value="MMLU",
                    interactive=True,
                    elem_id="leaderboard-toggle",
                    container=False
                )
                leaderboard_table_output = gr.Dataframe(
                    headers=["Model ID", "Avg. Accuracy (%)", "Total Samples"],
                    interactive=False,
                    datatype=["str", "str", "number"],
                    row_count=15,
                    elem_classes="leaderboard-table"
                )
    # --- Event Handlers & Logic ---

    # Update subject dropdown when benchmark type changes
    benchmark_selection_radio.change(
        fn=update_subject_dropdown,
        inputs=[benchmark_selection_radio],
        outputs=[benchmark_subject_dropdown]
    )

    # Main evaluation trigger. run_evaluation returns one update per component below:
    # visibility for the three panes, values for the summary, error fields and table.
    run_button.click(
        fn=run_evaluation,
        inputs=[model_id_input, benchmark_selection_radio, benchmark_subject_dropdown, sample_count_slider],
        outputs=[
            result_summary_box, result_summary_output,
            error_box, error_output, error_details_output,
            details_box, detailed_results_df
        ]
    ).then(
        # Refresh the leaderboard once the new record has been written to disk.
        fn=load_leaderboard,
        inputs=[leaderboard_type_toggle],
        outputs=[leaderboard_table_output]
    )

    # Leaderboard loading logic
    demo.load(
        fn=load_leaderboard,
        inputs=[leaderboard_type_toggle],
        outputs=[leaderboard_table_output]
    )
    leaderboard_type_toggle.change(
        fn=load_leaderboard,
        inputs=[leaderboard_type_toggle],
        outputs=[leaderboard_table_output]
    )
# Launch the Gradio app
if __name__ == "__main__":
    demo.launch(debug=True)