Update app.py
app.py CHANGED
@@ -1,172 +1,615 @@
import os
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
-from datasets import load_dataset, get_dataset_config_names
import torch
import re
import json
import pandas as pd
import matplotlib.pyplot as plt

# Cache to avoid reloading the model
model_cache = {}

HF_TOKEN = os.environ.get("HF_TOKEN")

def load_model(model_id):
    if model_id in model_cache:
        return model_cache[model_id]
-   ... (removed lines not preserved in this view) ...

def format_prompt(item):
-   ... (removed line not preserved in this view) ...
    prompt = f"""{item['question']}
A. {item['choices'][0]}
B. {item['choices'][1]}
C. {item['choices'][2]}
D. {item['choices'][3]}
-Answer:"""
-    return prompt, item['answer']

def extract_choice_letter(output):
-   ... (removed line not preserved in this view) ...
    match = re.search(r"\b([ABCD])\b", output.strip())
-   ... (removed line not preserved in this view) ...

def get_choice_letter(index):
    """Converts a numerical choice index (0-3) to a capital letter (A-D)."""
-   ... (removed lines of the old evaluation code not preserved in this view) ...
-        total_correct += correct_subject
-        total_samples += len(dataset)
-    avg_accuracy = total_correct / total_samples * 100
-    return avg_accuracy, all_results

-   ... (removed lines not preserved in this view) ...

-def save_text(text):
-    return "evaluation_results.txt", text

-   ... (removed line not preserved in this view) ...
    gr.Markdown("""
    # π€ LLM Benchmark Evaluator

-   ... (removed UI-layout lines not preserved in this view) ...
-    with gr.Row():
-        leaderboard_plot = gr.Plot(label="Leaderboard Chart")
-        leaderboard_table = gr.Dataframe(headers=["Model ID", "Average Accuracy"], interactive=False, datatype=["str", "number"], row_count=20, col_count=2)
-
-    def load_leaderboard():
-        try:
-            df = pd.read_json("eval.jsonl", lines=True)
-            df_avg = df.groupby("model_id")["accuracy"].mean().reset_index()
-            df_avg.columns = ["model_id", "average_accuracy"]
-            df_sorted = df_avg.sort_values(by="average_accuracy", ascending=False)
-            top10 = df_sorted.head(10)
-
-            fig, ax = plt.subplots()
-            ax.barh(top10['model_id'], top10['average_accuracy'])
-            ax.set_xlabel("Average Accuracy")
-            ax.set_ylabel("Model")
-            ax.set_title("Top 10 Models by Average Accuracy")
-
-            return fig, df_sorted
-        except Exception as e:
-            # Handle the case where eval.jsonl might not exist yet
-            return plt.figure(), pd.DataFrame(columns=["model_id", "average_accuracy"])
-
-    demo.load(load_leaderboard, inputs=[], outputs=[leaderboard_plot, leaderboard_table])

demo.launch()
import os
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+from datasets import load_dataset, get_dataset_config_names
import torch
import re
import json
import pandas as pd
import matplotlib.pyplot as plt
+import traceback  # Import traceback for detailed error logging

# Cache to avoid reloading the model
model_cache = {}

HF_TOKEN = os.environ.get("HF_TOKEN")

+# --- Constants for Benchmarks ---
+MMLU_DATASET = "cais/mmlu"
+MMLU_PRO_DATASET = "cais/mmlu_pro"
+# Humanity's Last Exam is a composite benchmark, not a single dataset readily available like MMLU/MMLU-Pro.
+# For this implementation, we will focus on MMLU and MMLU-Pro, which are direct datasets.
+# Integrating HLE would require evaluating across multiple specific datasets.
+
+def get_all_benchmark_options():
+    """
+    Dynamically fetches all available subjects for MMLU and MMLU-Pro.
+    Returns a dictionary mapping benchmark dataset IDs to their subjects,
+    and a flattened list suitable for a Gradio dropdown.
+    """
+    all_options = {}
+    gr_dropdown_options = []
+
+    # Get subjects for MMLU
+    try:
+        mmlu_subjects = get_dataset_config_names(MMLU_DATASET, token=HF_TOKEN)
+        all_options[MMLU_DATASET] = ["ALL"] + mmlu_subjects
+        gr_dropdown_options.extend([f"MMLU - {s}" for s in all_options[MMLU_DATASET]])
+    except Exception as e:
+        print(f"Warning: Could not load MMLU dataset configs. Error: {e}")
+        all_options[MMLU_DATASET] = []
+
+    # Get subjects for MMLU-Pro
+    try:
+        mmlu_pro_subjects = get_dataset_config_names(MMLU_PRO_DATASET, token=HF_TOKEN)
+        all_options[MMLU_PRO_DATASET] = ["ALL"] + mmlu_pro_subjects
+        gr_dropdown_options.extend([f"MMLU-Pro - {s}" for s in all_options[MMLU_PRO_DATASET]])
+    except Exception as e:
+        print(f"Warning: Could not load MMLU-Pro dataset configs. It might not be accessible or available. Error: {e}")
+        all_options[MMLU_PRO_DATASET] = []
+
+    return all_options, gr_dropdown_options
+
+# Initialize these once globally when the app starts
+ALL_BENCHMARK_SUBJECTS, GRADIO_DROPDOWN_OPTIONS = get_all_benchmark_options()
+
+
|
57 |
def load_model(model_id):
|
58 |
+
"""
|
59 |
+
Loads a Hugging Face model and its tokenizer, then creates a text-generation pipeline.
|
60 |
+
Uses a cache to avoid re-loading if the model is already in memory.
|
61 |
+
Provides Gradio Info/Error messages for user feedback.
|
62 |
+
Raises an exception if model loading fails.
|
63 |
+
"""
|
64 |
+
gr.Info(f"Attempting to load model: {model_id}...")
|
65 |
if model_id in model_cache:
|
66 |
+
gr.Info(f"Model '{model_id}' already loaded from cache.")
|
67 |
return model_cache[model_id]
|
68 |
+
try:
|
69 |
+
# Load tokenizer and model, using bfloat16 if CUDA is available for efficiency
|
70 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
|
71 |
+
model = AutoModelForCausalLM.from_pretrained(
|
72 |
+
model_id,
|
73 |
+
token=HF_TOKEN,
|
74 |
+
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
|
75 |
+
).to("cuda" if torch.cuda.is_available() else "cpu")
|
76 |
+
|
77 |
+
# Create a text-generation pipeline
|
78 |
+
generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1)
|
79 |
+
|
80 |
+
# Cache the loaded generator
|
81 |
+
model_cache[model_id] = generator
|
82 |
+
gr.Info(f"Model '{model_id}' loaded successfully.")
|
83 |
+
return generator
|
84 |
+
except Exception as e:
|
85 |
+
# Re-raise the exception to be caught by the outer run_evaluation try-except
|
86 |
+
raise ValueError(f"Failed to load model '{model_id}'. Please verify the model ID and your Hugging Face token. Error: {e}")
|
87 |
+
|
88 |
|
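As a sanity check of the single-token generation pattern this app relies on, here is a small self-contained sketch. The checkpoint name `sshleifer/tiny-gpt2` is only an example of a tiny public model; its answers are meaningless, the snippet just exercises the pipeline call shape:

```python
from transformers import pipeline

# Tiny public checkpoint on CPU; for illustrating the call, not for real scores.
gen = pipeline("text-generation", model="sshleifer/tiny-gpt2", device=-1)
prompt = "What is 2 + 2?\nA. 3\nB. 4\nC. 5\nD. 6\nAnswer:"
out = gen(prompt, max_new_tokens=1, do_sample=False)[0]["generated_text"]
print(out)  # the prompt plus one generated token, the string extract_choice_letter() consumes
```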
def format_prompt(item):
+    """
+    Formats a single MMLU/MMLU-Pro question item into a clear prompt for the LLM.
+    The prompt is designed for the model to output a single letter answer (A, B, C, D).
+    """
    prompt = f"""{item['question']}
A. {item['choices'][0]}
B. {item['choices'][1]}
C. {item['choices'][2]}
D. {item['choices'][3]}
+Answer:"""
+    return prompt, item['answer']  # Returns the prompt string and the correct choice index (0-3)

def extract_choice_letter(output):
+    """
+    Extracts the most likely choice letter (A, B, C, D) from the model's generated output.
+    It prioritizes an exact match after "Answer:", then looks for any single capital letter.
+    """
+    # Look for "Answer: X" pattern first (e.g., "Answer: A" or "Answer: B")
+    match = re.search(r"Answer:\s*([ABCD])", output, re.IGNORECASE)  # Added IGNORECASE for robustness
+    if match:
+        return match.group(1).upper()  # Ensure it's uppercase
+
+    # Fallback: look for a single capital letter A-D anywhere in the output
    match = re.search(r"\b([ABCD])\b", output.strip())
+    if match:
+        return match.group(1)
+
+    return None  # Return None if no valid choice letter is found

def get_choice_letter(index):
    """Converts a numerical choice index (0-3) to a capital letter (A-D)."""
+    if 0 <= index <= 3:
+        return chr(ord('A') + index)
+    return None  # Return None for invalid indices
+
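A quick worked example of the three helpers above, using a hand-made item in the same shape as a dataset row (illustrative; real items come from `load_dataset`), run in the same session as the code above:

```python
item = {
    "question": "Which planet is known as the Red Planet?",
    "choices": ["Venus", "Mars", "Jupiter", "Saturn"],
    "answer": 1,
}
prompt, answer_idx = format_prompt(item)
print(get_choice_letter(answer_idx))        # -> "B"
print(extract_choice_letter("Answer: b"))   # -> "B" (matched via the "Answer:" pattern, case-insensitive)
print(extract_choice_letter("I think B."))  # -> "B" (matched via the \b[ABCD]\b fallback)
```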
+def evaluate_single_subject(generator, dataset_id, subject, sample_count, progress):
+    """
+    Evaluates a given model generator on a specific subject from a specified dataset.
+
+    Args:
+        generator: The Hugging Face pipeline for text generation.
+        dataset_id (str): The ID of the dataset (e.g., "cais/mmlu", "cais/mmlu_pro").
+        subject (str): The specific subject/config name within the dataset.
+        sample_count (int): The maximum number of samples to evaluate.
+        progress (gr.Progress): Gradio progress tracker.
+
+    Returns:
+        tuple: (accuracy, list_of_detailed_results)
+    Raises:
+        Exception: If dataset loading fails.
+    """
+    gr.Info(f"Loading dataset: {dataset_id} - {subject}...")
+    try:
+        # Load the "test" split of the dataset
+        dataset = load_dataset(dataset_id, subject, token=HF_TOKEN)["test"]
+    except Exception as e:
+        # Re-raise the exception to be caught by the outer run_evaluation try-except
+        raise RuntimeError(f"Failed to load dataset '{dataset_id}' for subject '{subject}'. Error: {e}")

+    # Limit the number of samples and shuffle for consistent evaluation across runs
+    num_samples_to_evaluate = min(sample_count, len(dataset))
+    dataset = dataset.shuffle(seed=42).select(range(num_samples_to_evaluate))
+
+    correct_count = 0
+    subject_results = []
+
+    # Iterate through the selected samples with a progress bar
+    for i, item in enumerate(progress.tqdm(dataset, desc=f"Processing {subject} samples")):
+        prompt, answer_idx = format_prompt(item)
+        expected_letter = get_choice_letter(answer_idx)
+
+        # Generate only 1 new token for the answer (A, B, C, D)
+        # do_sample=False ensures deterministic output for a given prompt (greedy decoding)
+        output_raw = generator(prompt, max_new_tokens=1, do_sample=False)[0]["generated_text"]

+        # Check for potential reasoning model output
+        is_reasoning_model_output = '<' in output_raw or re.search(r"\b(because|therefore|thus|reasoning)\b", output_raw, re.IGNORECASE) is not None
+
+        # Extract the predicted letter from the model's raw output
+        predicted_letter = extract_choice_letter(output_raw)
+
+        is_correct = (predicted_letter == expected_letter)
+        correct_count += is_correct
+
+        # Store detailed results for logging and display
+        subject_results.append({
+            "question": item['question'],
+            "choices": item['choices'],
+            "model_raw_output": output_raw.strip(),
+            "expected_answer_letter": expected_letter,
+            "predicted_answer_letter": predicted_letter,
+            "is_correct": is_correct,
+            "is_reasoning_model_output": is_reasoning_model_output  # Store the flag
+        })

+    # Calculate accuracy for the current subject
+    accuracy = (correct_count / len(dataset)) * 100 if len(dataset) > 0 else 0
+    return accuracy, subject_results
+
+
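The per-subject bookkeeping that the caller performs on these detail records can be checked in isolation; a tiny sketch with fake records in the same shape (illustrative only):

```python
# Fake detail records shaped like the dicts appended by evaluate_single_subject.
fake_details = [
    {"predicted_answer_letter": "B", "expected_answer_letter": "B", "is_correct": True},
    {"predicted_answer_letter": "A", "expected_answer_letter": "C", "is_correct": False},
]
num_correct = sum(d["is_correct"] for d in fake_details)
accuracy = 100 * num_correct / len(fake_details)
print(f"{accuracy:.2f}% ({num_correct}/{len(fake_details)} samples)")  # 50.00% (1/2 samples)
```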
+def run_evaluation(model_id, selected_benchmark_subject, sample_count, progress=gr.Progress()):
+    """
+    Main function to orchestrate the evaluation process.
+    Handles single subject or 'ALL' subjects evaluation for MMLU/MMLU-Pro.
+    Returns Gradio.update objects to control UI component visibility and content.
+    """
+    gr.Info("Starting evaluation...")
+    if not model_id:
+        gr.Warning("Please enter a Hugging Face Model ID before running the evaluation.")
+        # Return updates to hide logs/debug and show empty results
+        return "", gr.update(value="", visible=False), gr.update(visible=False), \
+               gr.update(visible=False), gr.update(visible=False), gr.update(value="", visible=False)
+
+    # Parse the selected benchmark and subject from the dropdown string
+    parts = selected_benchmark_subject.split(" - ")
+    if len(parts) != 2:
+        gr.Error("Invalid benchmark selection format. Please select from the dropdown.")
+        return "", gr.update(value="", visible=False), gr.update(visible=False), \
+               gr.update(visible=False), gr.update(visible=False), gr.update(value="", visible=False)
+
+    benchmark_name = parts[0]
+    subject_name = parts[1]
+
+    dataset_id_map = {
+        "MMLU": MMLU_DATASET,
+        "MMLU-Pro": MMLU_PRO_DATASET
+    }
+    current_dataset_id = dataset_id_map.get(benchmark_name)
+
+    if not current_dataset_id:
+        gr.Error(f"Unknown benchmark selected: {benchmark_name}. This should not happen.")
+        return "", gr.update(value="", visible=False), gr.update(visible=False), \
+               gr.update(visible=False), gr.update(visible=False), gr.update(value="", visible=False)
+
+    try:
+        generator = load_model(model_id)  # This function will raise an exception on failure
+
+        all_evaluation_results = []
+        total_correct_overall = 0
+        total_samples_overall = 0
+        eval_summary_lines = []
+
+        if subject_name == "ALL":
+            subjects_to_evaluate = ALL_BENCHMARK_SUBJECTS.get(current_dataset_id, [])
+            if "ALL" in subjects_to_evaluate:
+                subjects_to_evaluate.remove("ALL")
+
+            if not subjects_to_evaluate:
+                gr.Warning(f"No subjects found to evaluate for '{benchmark_name}'.")
+                return "", gr.update(value="", visible=False), gr.update(visible=False), \
+                       gr.update(visible=False), gr.update(visible=False), gr.update(value="", visible=False)
+
+            for i, sub in enumerate(progress.tqdm(subjects_to_evaluate, desc=f"Evaluating ALL {benchmark_name} subjects")):
+                gr.Info(f"Evaluating {benchmark_name} - {sub} ({i+1}/{len(subjects_to_evaluate)})...")
+                try:
+                    accuracy, subject_details = evaluate_single_subject(generator, current_dataset_id, sub, sample_count, progress)
+                    all_evaluation_results.extend(subject_details)
+
+                    num_evaluated_samples = len(subject_details)
+                    num_correct_in_subject = sum(d['is_correct'] for d in subject_details)
+
+                    total_correct_overall += num_correct_in_subject
+                    total_samples_overall += num_evaluated_samples
+                    eval_summary_lines.append(f"- {benchmark_name} - {sub}: {accuracy:.2f}% ({num_correct_in_subject}/{num_evaluated_samples} samples)")
+                except Exception as e:
+                    gr.Error(f"Skipping {benchmark_name} - {sub} due to an error: {e}")
+                    eval_summary_lines.append(f"- {benchmark_name} - {sub}: Error during evaluation.")
+                    continue
+
+            overall_accuracy = (total_correct_overall / total_samples_overall) * 100 if total_samples_overall > 0 else 0
+            score_string = f"Overall Average Accuracy for {benchmark_name}: {overall_accuracy:.2f}% across {total_samples_overall} total samples.\n\n"
+            score_string += "Detailed breakdown:\n" + "\n".join(eval_summary_lines)
+
+        else:
+            accuracy, subject_details = evaluate_single_subject(generator, current_dataset_id, subject_name, sample_count, progress)
+            all_evaluation_results.extend(subject_details)
+            overall_accuracy = accuracy
+            num_evaluated_samples = len(subject_details)
+            score_string = f"Accuracy for {benchmark_name} - {subject_name}: {accuracy:.2f}% out of {num_evaluated_samples} samples."

+        # Format detailed results for display in the text box
+        formatted_details = "\n\n".join([
+            f"### Question:\n{item['question']}\n\n"
+            f"**Choices:**\n" + "\n".join([f"{get_choice_letter(i)}. {c}" for i, c in enumerate(item['choices'])]) + "\n\n"
+            + (f"**Note:** Reasoning models are currently not fully supported for single-letter extraction. The raw model output follows:\n" if item.get('is_reasoning_model_output') else "")
+            + f"**Model Raw Output:** {item['model_raw_output']}\n"
+            f"**Expected Answer:** {item['expected_answer_letter']}\n"
+            f"**Predicted Answer:** {item['predicted_answer_letter']}\n"
+            f"**Correct:** {'Yes' if item['is_correct'] else 'No'}"
+            for item in all_evaluation_results
+        ])

+        # Record the evaluation result to a JSONL file for the leaderboard
+        record = {
+            "model_id": model_id,
+            "benchmark": benchmark_name,
+            "subject": subject_name,
+            "accuracy": overall_accuracy,
+            "sample_count": total_samples_overall if subject_name == "ALL" else len(all_evaluation_results),
+            "timestamp": pd.Timestamp.now().isoformat()
+        }
+        with open("eval.jsonl", "a") as f:
+            f.write(json.dumps(record) + "\n")

+        gr.Info("Evaluation completed successfully!")
+        return score_string, \
+               gr.update(value="", visible=False), gr.update(visible=False), \
+               gr.update(visible=True), gr.update(visible=True), gr.update(value=formatted_details, visible=False)
+
+    except Exception as e:
+        error_message = str(e)
+        detailed_error_traceback = traceback.format_exc()
+        gr.Error("An error occurred during evaluation.")
+
+        # Return updates for error state
+        return "An error occurred during evaluation. Check the debug information below for details. If it persists, please open a discussion in the Community tab for assistance.", \
+               gr.update(value=detailed_error_traceback, visible=True), gr.update(visible=True), \
+               gr.update(visible=False), gr.update(visible=False), gr.update(value="", visible=False)
+
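Each successful run appends one JSON line to `eval.jsonl`, which the leaderboard code below reads back. A record has the following shape; the values here are illustrative only and mirror the `record` dict written above:

```python
# Shape of one line in eval.jsonl (values are examples, not real results).
example_record = {
    "model_id": "mistralai/Mistral-7B-Instruct-v0.2",  # whatever the user entered
    "benchmark": "MMLU",
    "subject": "anatomy",
    "accuracy": 62.5,
    "sample_count": 10,
    "timestamp": "2025-01-01T12:00:00.000000",
}
```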
+def save_text(text_content):
+    """Saves the provided text content to a file and returns the file path for download."""
+    if not text_content:
+        gr.Warning("No evaluation results to download.")
+        return None
+    file_path = "evaluation_results.txt"
+    try:
+        with open(file_path, "w") as f:
+            f.write(text_content)
+        return file_path
+    except Exception as e:
+        gr.Error(f"Error saving file: {e}")
+        return None
+
+def load_leaderboard():
+    """
+    Loads evaluation data from 'eval.jsonl', computes average accuracy per model,
+    and prepares data for the leaderboard plot and table.
+    """
+    try:
+        # Read the JSONL file into a pandas DataFrame
+        df = pd.read_json("eval.jsonl", lines=True)

+        # Calculate average accuracy per model across all recorded evaluations
+        df_avg = df.groupby("model_id")["accuracy"].mean().reset_index()
+        df_avg.columns = ["Model ID", "Average Accuracy (%)"]
+
+        # Sort models by average accuracy in descending order
+        df_sorted = df_avg.sort_values(by="Average Accuracy (%)", ascending=False)
+
+        # Select top 10 models for the bar chart
+        top_models = df_sorted.head(10)
+
+        # Create the matplotlib plot
+        fig, ax = plt.subplots(figsize=(10, 6))  # Adjust figure size for better readability
+        # For horizontal bars, it's often better to plot data sorted in ascending order
+        # so the highest bar appears at the top of the chart.
+        top_models_plot = top_models.sort_values(by="Average Accuracy (%)", ascending=True)

+        ax.barh(top_models_plot['Model ID'], top_models_plot['Average Accuracy (%)'], color='#007bff')  # Use a nice blue color
+        ax.set_xlabel("Average Accuracy (%)", fontsize=12)
+        ax.set_ylabel("Model ID", fontsize=12)
+        ax.set_title("Top 10 Models by Average MMLU/MMLU-Pro Accuracy", fontsize=14)
+        ax.set_xlim(0, 100)  # Ensure accuracy scale is 0-100%
+        ax.tick_params(axis='x', labelsize=10)
+        ax.tick_params(axis='y', labelsize=10)
+        ax.grid(axis='x', linestyle='--', alpha=0.7)  # Add grid lines
+        plt.tight_layout()  # Adjust layout to prevent labels overlapping

+        # Return the figure and the sorted dataframe as a list of dictionaries for Gradio Dataframe
+        return fig, df_sorted.to_dict('records')
+    except FileNotFoundError:
+        gr.Warning("No evaluation data found yet. Run an evaluation to populate the leaderboard!")
+        return plt.figure(), pd.DataFrame(columns=["Model ID", "Average Accuracy (%)"]).to_dict('records')
+    except Exception as e:
+        gr.Error(f"Error loading leaderboard: {e}")
+        # Return an empty plot and dataframe in case of any other error
+        return plt.figure(), pd.DataFrame(columns=["Model ID", "Average Accuracy (%)"]).to_dict('records')

+
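The aggregation behind the leaderboard can also be reproduced outside Gradio; a minimal sketch, assuming `eval.jsonl` already exists in the working directory with records like the one shown earlier:

```python
import pandas as pd

# Same grouping/sorting as load_leaderboard, without the plotting.
df = pd.read_json("eval.jsonl", lines=True)
leaderboard = (
    df.groupby("model_id")["accuracy"].mean()
      .sort_values(ascending=False)
      .reset_index(name="average_accuracy")
)
print(leaderboard.head(10))
```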
+# --- Gradio Interface Definition ---
+with gr.Blocks(css="""
+    /* General body and container styling */
+    body { font-family: 'Inter', sans-serif; background-color: #f0f2f5; margin: 0; padding: 20px; }
+    .gradio-container {
+        max-width: 1200px;
+        margin: 20px auto;
+        padding: 30px;
+        box-shadow: 0 8px 16px rgba(0,0,0,0.15);
+        border-radius: 12px;
+        background-color: #ffffff;
+        border: 1px solid #e0e0e0;
+    }

+    /* Headings */
+    h1 {
+        color: #2c3e50;
+        text-align: center;
+        margin-bottom: 30px;
+        font-size: 2.5em;
+        font-weight: 700;
+        letter-spacing: -0.02em;
+    }
+    h3 { color: #34495e; font-size: 1.2em; margin-bottom: 10px; }
+
+    /* Markdown text */
+    .markdown-text { text-align: center; color: #555; line-height: 1.6; }
+    .markdown-text div { font-size: 1.1em; }
+
+    /* Buttons */
+    .gr-button {
+        background-color: #007bff; /* Primary blue */
+        color: white;
+        border: none;
+        padding: 12px 25px;
+        border-radius: 8px;
+        cursor: pointer;
+        transition: background-color 0.3s ease, transform 0.2s ease;
+        font-size: 1.1em;
+        font-weight: 600;
+        box-shadow: 0 4px 8px rgba(0,0,0,0.1);
+    }
+    .gr-button:hover {
+        background-color: #0056b3; /* Darker blue on hover */
+        transform: translateY(-2px); /* Slight lift effect */
+    }
+    .gr-button:active {
+        transform: translateY(0);
+        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
+    }
+    /* Specific button styling for debug/show details */
+    #debug-button, #show-details-button {
+        background-color: #6c757d; /* Grey for secondary actions */
+    }
+    #debug-button:hover, #show-details-button:hover {
+        background-color: #5a6268;
+    }
+    #download-button {
+        background-color: #28a745; /* Green for download */
+    }
+    #download-button:hover {
+        background-color: #218838;
+    }
+

+    /* Input/Output Boxes */
+    .gr-box {
+        border: 1px solid #dee2e6;
+        border-radius: 10px;
+        padding: 20px;
+        margin-bottom: 20px;
+        background-color: #fdfdfd;
+        box-shadow: inset 0 1px 3px rgba(0,0,0,0.05);
+    }
+    .gr-output-text {
+        white-space: pre-wrap;
+        word-wrap: break-word;
+        background-color: #f9f9fb;
+        border: 1px solid #e9ecef;
+        border-radius: 8px;
+        padding: 15px;
+        min-height: 100px; /* Ensure a minimum height */
+    }
+    /* Specific error output style */
+    #error-message-output {
+        background-color: #ffe0e0;
+        border-color: #ff9999;
+        color: #cc0000;
+    }


+    /* Labels for inputs */
+    .gr-textbox label, .gr-dropdown label, .gr-slider label {
+        font-weight: 600;
+        color: #495057;
+        margin-bottom: 8px;
+        display: block;
+        font-size: 1em;
+    }
+
+    /* Tab styling */
+    .gr-tab-item { padding: 25px; } /* More padding inside tabs */
+    .gr-tabs-nav button {
+        font-weight: 600;
+        font-size: 1.1em;
+        padding: 10px 20px;
+        border-top-left-radius: 8px;
+        border-top-right-radius: 8px;
+    }
+""") as demo:
    gr.Markdown("""
    # π€ LLM Benchmark Evaluator
+    """)

+    with gr.Tabs():
+        with gr.TabItem("π Run Evaluation"):
+            gr.Markdown("""
+            <div style="text-align: center; margin-bottom: 20px; color: #666; font-size: 1.1em;">
+            Enter your Hugging Face Model ID, choose a benchmark (MMLU or MMLU-Pro),
+            select a subject (or 'ALL' for a comprehensive evaluation),
+            and specify the number of samples per subject.
+            </div>
+            """)
+
+            with gr.Column(elem_classes="gr-box"):
+                model_id_input = gr.Textbox(
+                    label="Your Hugging Face Model ID",
+                    placeholder="e.g., mistralai/Mistral-7B-Instruct-v0.2",
+                    interactive=True
+                )
+                with gr.Row():
+                    benchmark_subject_dropdown = gr.Dropdown(
+                        label="Choose Benchmark and Subject",
+                        choices=GRADIO_DROPDOWN_OPTIONS,
+                        value="MMLU - ALL",  # Default to MMLU ALL for initial load
+                        interactive=True,
+                        min_width=400  # Ensure sufficient width
+                    )
+                    sample_count_slider = gr.Slider(
+                        label="Number of Samples per Subject (1-100)",
+                        minimum=1,
+                        maximum=100,
+                        value=10,  # Default to 10 samples
+                        step=1,
+                        interactive=True,
+                        min_width=200
+                    )
+                run_button = gr.Button("π Run Evaluation", elem_classes="gr-button")

+            with gr.Column(elem_classes="gr-box"):
+                acc_output = gr.Textbox(
+                    label="Benchmark Accuracy Results",
+                    interactive=False,
+                    elem_classes="gr-output-text",
+                    lines=5,
+                    placeholder="Evaluation results will appear here."
+                )
+
+                # Container for debug info, initially hidden
+                with gr.Column(visible=False, elem_id="debug-error-column") as debug_error_column:
+                    error_message_output = gr.Textbox(
+                        label="Debug Information (Error Details)",
+                        lines=10, interactive=False, elem_classes="gr-output-text", elem_id="error-message-output",
+                        placeholder="Error details will appear here if an error occurs."
+                    )
+                    debug_button = gr.Button("π Hide Debug Info", visible=True, elem_id="debug-button", elem_classes="gr-button")
+
+                with gr.Row():
+                    show_details_button = gr.Button("π Show Detailed Logs", visible=False, elem_id="show-details-button", elem_classes="gr-button")
+                    download_button = gr.Button("π₯ Download Full Evaluation Logs", visible=False, elem_id="download-button", elem_classes="gr-button")
+
+                # Detailed output, initially hidden
+                detail_output = gr.Textbox(
+                    label="Detailed Evaluation Logs",
+                    lines=20,
+                    interactive=False,
+                    elem_classes="gr-output-text",
+                    placeholder="Detailed logs for each question will appear here upon successful evaluation.",
+                    visible=False  # Initially hidden
+                )
+
+            # Define button click actions
+            run_button.click(
+                run_evaluation,
+                inputs=[model_id_input, benchmark_subject_dropdown, sample_count_slider],
+                outputs=[
+                    acc_output,
+                    error_message_output, debug_error_column,  # For error state
+                    show_details_button, download_button, detail_output  # For success state
+                ]
+            )
+
+            # Toggle visibility of detail_output
+            show_details_button.click(
+                lambda s: gr.update(visible=not s),  # Toggle visibility
+                inputs=[detail_output],  # Pass the component itself as input
+                outputs=[detail_output]  # The component to update
+            )
+            # Change button text based on visibility
+            show_details_button.click(
+                lambda s: "π Hide Detailed Logs" if not s else "π Show Detailed Logs",
+                inputs=[detail_output],
+                outputs=[show_details_button]
+            )
+
+            # Toggle visibility of debug error column
+            debug_button.click(
+                lambda s: gr.update(visible=not s),  # Toggle visibility
+                inputs=[debug_error_column],  # Pass the component itself as input
+                outputs=[debug_error_column]  # The component to update
+            )
+            # Change debug button text based on visibility
+            debug_button.click(
+                lambda s: "π Show Debug Info" if not s else "π Hide Debug Info",
+                inputs=[debug_error_column],
+                outputs=[debug_button]
+            )
+
+            download_button.click(
+                save_text,
+                inputs=[detail_output],
+                outputs=gr.File(label="Download Evaluation Results", file_count="single", type="filepath")
+            )

|
593 |
+
with gr.TabItem("π Leaderboard"):
|
594 |
+
gr.Markdown("""
|
595 |
+
<div style="text-align: center; margin-bottom: 20px; color: #666; font-size: 1.1em;">
|
596 |
+
See how different models perform on average across all evaluated benchmarks.
|
597 |
+
This leaderboard updates with every new evaluation.
|
598 |
+
</div>
|
599 |
+
""")
|
600 |
+
with gr.Row():
|
601 |
+
leaderboard_plot_output = gr.Plot(label="Top 10 Models by Average Accuracy", scale=2) # Scale for better visibility
|
602 |
+
leaderboard_table_output = gr.Dataframe(
|
603 |
+
headers=["Model ID", "Average Accuracy (%)"],
|
604 |
+
interactive=False,
|
605 |
+
datatype=["str", "number"],
|
606 |
+
row_count=10, # Display top 10 rows initially, but can scroll
|
607 |
+
col_count=2,
|
608 |
+
label="Full Leaderboard Data"
|
609 |
+
)
|
610 |
+
|
611 |
+
# Load leaderboard when the tab is selected or when the app loads
|
612 |
+
demo.load(load_leaderboard, inputs=[], outputs=[leaderboard_plot_output, leaderboard_table_output])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
613 |
|
614 |
+
# Launch the Gradio app
|
615 |
demo.launch()
|
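To try the updated app outside of Spaces, a local run might look as follows; the dependency list and environment variable are assumptions based on the imports above, not part of this commit:

```python
# Hypothetical local setup (shell commands shown as comments):
#   pip install gradio transformers datasets torch pandas matplotlib
#   export HF_TOKEN=<your-token>   # only needed for gated models/datasets
#   python app.py
# eval.jsonl is created in the working directory after the first successful evaluation.
```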