import gc
import logging

import gradio as gr
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


PREDEFINED_MODELS = [
    "meta-llama/Llama-3.2-1B",
    "google/gemma-2-2b",
    "Qwen/Qwen3-0.6B",
    "Qwen/Qwen2.5-0.5B",
    "Qwen/Qwen2.5-1.5B",
    "bigscience/bloom-560m",
    "CohereForAI/aya-expanse-8b",
    "common-pile/comma-v0.1-2t",
    "google/byt5-small",
    "gsaltintas/supertoken_models-llama_gpt2",
]

model_cache = {}


def parse_dataset(text):
    """Parse the input dataset text into structured questions"""
    if not text.strip():
        return [], "Please enter your dataset"

    lines = text.strip().split('\n')
    if len(lines) < 2:
        return [], "Dataset must have at least a header and one question"

    # Detect the delimiter from the first data line (tab- or comma-separated)
    first_data_line = lines[1] if len(lines) > 1 else lines[0]
    delimiter = '\t' if '\t' in first_data_line else ','

    questions = []
    errors = []

    for i, line in enumerate(lines[1:], 2):
        line = line.strip()
        if not line:
            continue

        parts = [part.strip().strip('"') for part in line.split(delimiter)]

        if len(parts) < 5:
            errors.append(f"Line {i}: Not enough columns (need 5, got {len(parts)})")
            continue

        question = {
            'question': parts[0],
            'correct_answer': parts[1],
            'choices': [parts[2], parts[3], parts[4]]
        }

        # Make sure the correct answer is always among the scored choices
        if question['correct_answer'] not in question['choices']:
            question['choices'].append(question['correct_answer'])

        questions.append(question)

    error_msg = '\n'.join(errors) if errors else ""
    return questions, error_msg

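
# Illustrative call (comments only, not executed at import time): a header line
# plus one comma-separated data row parses into a question dict; note that the
# correct answer is appended to the choices when it is not already listed.
#
#   questions, err = parse_dataset(
#       "Question,Correct Answer,Choice1,Choice2,Choice3\n"
#       "What is 2+2?,4,3,2,5"
#   )
#   # questions == [{'question': 'What is 2+2?',
#   #                'correct_answer': '4',
#   #                'choices': ['3', '2', '5', '4']}]
#   # err == ""
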

def load_model_and_tokenizer(model_path, use_cache=True, progress_callback=None):
    """Load model and tokenizer with caching"""
    global model_cache

    if use_cache and model_path in model_cache:
        logger.info(f"Using cached model: {model_path}")
        if progress_callback:
            progress_callback(1.0, f"✅ Using cached model: {model_path}")
        return model_cache[model_path]

    try:
        if progress_callback:
            progress_callback(0.1, f"🔄 Starting to load model: {model_path}")

        logger.info(f"Loading model: {model_path}")

        device = "cuda" if torch.cuda.is_available() else "cpu"

        if progress_callback:
            progress_callback(0.2, f"📥 Loading tokenizer for {model_path}...")

        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, legacy=True)

        # Some models ship without a pad token; fall back to EOS
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        if progress_callback:
            progress_callback(0.5, f"🧠 Loading model weights for {model_path}... (this may take a while)")

        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            device_map="auto" if device == "cuda" else None,
            trust_remote_code=True,
            low_cpu_mem_usage=True
        )

        model_info = {
            'tokenizer': tokenizer,
            'model': model,
            'device': device
        }

        if use_cache:
            model_cache[model_path] = model_info

        if progress_callback:
            progress_callback(1.0, f"✅ Successfully loaded model: {model_path}")

        return model_info

    except Exception as e:
        error_msg = f"❌ Error loading model {model_path}: {str(e)}"
        logger.error(error_msg)

        if progress_callback:
            progress_callback(0.0, error_msg)
        return None

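
# Illustrative usage (comments only, not executed at import time): load a small
# model once, then score a single answer choice with the helper defined below.
# The model name here is just an example from PREDEFINED_MODELS.
#
#   info = load_model_and_tokenizer("Qwen/Qwen2.5-0.5B")
#   if info is not None:
#       score = calculate_choice_likelihood(
#           info['model'], info['tokenizer'], "What is 2+2?", "4"
#       )
#       print(score)  # summed log-probability of the answer tokens
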

def calculate_choice_likelihood(model, tokenizer, question, choice):
    """Calculate the log-likelihood of the choice given the question prompt"""
    try:
        # Score the raw question text as the prompt; a "Question: ...\nAnswer:"
        # template could be swapped in here instead.
        prompt = question
        full_text = f"{prompt} {choice}"

        input_ids = tokenizer.encode(full_text, return_tensors="pt", add_special_tokens=False).to(model.device)
        prompt_ids = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

        if input_ids.size(1) <= prompt_ids.size(1):
            logger.warning("Answer tokens are empty after tokenization.")
            return float("-inf")

        with torch.no_grad():
            outputs = model(input_ids)
            logits = outputs.logits

        # Logits at position t predict token t+1, so shift the logit window by one
        # to line up predictions with the answer tokens.
        answer_len = input_ids.size(1) - prompt_ids.size(1)
        target_ids = input_ids[:, -answer_len:]
        logits = logits[:, prompt_ids.size(1) - 1:-1, :]

        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        token_log_probs = log_probs.gather(2, target_ids.unsqueeze(-1)).squeeze(-1)

        total_log_prob = token_log_probs.sum().item()
        return total_log_prob

    except Exception as e:
        logger.error(f"Error calculating likelihood for choice '{choice}': {str(e)}")
        return float("-inf")


def evaluate_model_on_questions(model_path, questions, progress_callback=None):
    """Evaluate a single model on all questions using likelihood-based scoring"""

    model_info = load_model_and_tokenizer(model_path, progress_callback=progress_callback)

    if model_info is None:
        return [{'error': f'Failed to load model {model_path}'}] * len(questions)

    results = []
    model = model_info['model']
    tokenizer = model_info['tokenizer']

    for i, question in enumerate(questions):
        try:
            choice_likelihoods = {}

            for choice in question['choices']:
                likelihood = calculate_choice_likelihood(model, tokenizer, question['question'], choice)
                choice_likelihoods[choice] = likelihood

            # Convert log-likelihoods into a softmax over the choices (subtracting
            # the max for numerical stability); these normalized values are
            # reported as per-choice probabilities / confidence.
            max_log_prob = max(choice_likelihoods.values())
            choice_probs = {choice: torch.exp(torch.tensor(log_prob - max_log_prob)).item()
                            for choice, log_prob in choice_likelihoods.items()}

            total_prob = sum(choice_probs.values())
            if total_prob > 0:
                choice_probs = {choice: prob / total_prob for choice, prob in choice_probs.items()}

            predicted_choice = max(choice_likelihoods.keys(), key=lambda x: choice_likelihoods[x])
            is_correct = predicted_choice == question['correct_answer']

            confidence = choice_probs.get(predicted_choice, 0.0)

            results.append({
                'question_idx': i,
                'predicted': predicted_choice,
                'correct': is_correct,
                'confidence': confidence,
                'choice_likelihoods': choice_likelihoods,
                'choice_probabilities': choice_probs,
                'raw_response': f"Likelihoods: {choice_likelihoods}"
            })

            if progress_callback:
                evaluation_progress = 0.2 + (i + 1) / len(questions) * 0.8
                progress_callback(evaluation_progress, f"🔍 Evaluating {model_path}: {i+1}/{len(questions)} questions (likelihood-based)")

        except Exception as e:
            logger.error(f"Error evaluating question {i} with {model_path}: {str(e)}")
            results.append({
                'question_idx': i,
                'predicted': question['choices'][0] if question['choices'] else '',
                'correct': False,
                'confidence': 0.0,
                'choice_likelihoods': {},
                'choice_probabilities': {},
                'raw_response': f"Error: {str(e)}"
            })

    return results

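
# Illustrative shape of a single per-question result (not executed; numbers are
# made up for the example):
#   {'question_idx': 0, 'predicted': '4', 'correct': True, 'confidence': 0.98,
#    'choice_likelihoods': {'3': -9.2, '2': -8.7, '5': -9.8, '4': -4.1},
#    'choice_probabilities': {'3': 0.01, '2': 0.01, '5': 0.0, '4': 0.98},
#    'raw_response': "Likelihoods: {...}"}
# Failed questions keep the same keys but carry an error message in 'raw_response'.
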

def run_evaluation(dataset_text, selected_predefined, custom_models_text="", progress=gr.Progress()):
    """Main evaluation function"""
    if not dataset_text.strip():
        return (
            "Please enter your dataset",
            "<p>No data provided</p>",
            None,
            None,
            gr.update(visible=True)
        )

    custom_models = []
    if custom_models_text is None:
        custom_models_text = ""
    if custom_models_text.strip():
        custom_models = [model.strip() for model in custom_models_text.strip().split('\n') if model.strip()]

    all_models = []
    all_models.extend(selected_predefined)
    all_models.extend(custom_models)

    if not all_models:
        return (
            "Please select at least one model or add custom models",
            "<p>No models selected</p>",
            None,
            None,
            gr.update(visible=False)
        )

    questions, parse_error = parse_dataset(dataset_text)

    if parse_error:
        return (
            f"Dataset parsing error:\n{parse_error}",
            "<p>Failed to parse dataset</p>",
            None,
            None,
            gr.update(visible=True)
        )

    if not questions:
        return (
            "No valid questions found in dataset",
            "<p>No questions to evaluate</p>",
            None,
            None,
            gr.update(visible=True)
        )

    progress(0, "Starting evaluation...")
    results = {}
    total_steps = len(all_models) * len(questions)
    current_step = 0

    for model_idx, model_path in enumerate(all_models):
        display_name = model_path.split('/')[-1] if '/' in model_path else model_path
        try:
            def model_progress(p, msg):
                nonlocal current_step
                # Offset by the models that have already finished so the overall
                # bar advances monotonically across the whole run.
                current_step = model_idx * len(questions) + int(p * len(questions))
                overall_progress = current_step / total_steps
                progress(overall_progress, msg)

            model_results = evaluate_model_on_questions(model_path, questions, model_progress)
            results[display_name] = model_results

        except Exception as e:
            logger.error(f"Failed to evaluate {display_name}: {str(e)}")
            results[display_name] = [{'error': str(e)}] * len(questions)

        # Free memory between models
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    summary_stats = generate_summary_stats(questions, results)
    summary_md = create_summary_markdown(summary_stats)
    detailed_html = create_detailed_results_html(questions, results)
    accuracy_chart = create_accuracy_chart(summary_stats)
    confidence_chart = create_confidence_chart(results)

    return (
        summary_md,
        detailed_html,
        accuracy_chart,
        confidence_chart,
        gr.update(visible=True)
    )


def generate_summary_stats(questions, results):
    """Generate summary statistics for all models"""
    summary = {}

    for model, model_results in results.items():
        if not model_results or 'error' in model_results[0]:
            summary[model] = {
                'accuracy': 0.0,
                'correct': 0,
                'total': len(questions),
                'avg_confidence': 0.0,
                'error': model_results[0].get('error', 'Unknown error') if model_results else 'No results'
            }
            continue

        correct_count = sum(1 for r in model_results if r.get('correct', False))
        total_count = len(model_results)
        accuracy = correct_count / total_count if total_count > 0 else 0

        avg_confidence = sum(r.get('confidence', 0) for r in model_results) / total_count if total_count > 0 else 0

        summary[model] = {
            'accuracy': accuracy,
            'correct': correct_count,
            'total': total_count,
            'avg_confidence': avg_confidence
        }

    return summary

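
# Illustrative shape of the returned summary (not executed; numbers are made up
# for the example): a model that answers 3 of 4 questions correctly with an
# average confidence of 0.62 yields an entry like
#   {'accuracy': 0.75, 'correct': 3, 'total': 4, 'avg_confidence': 0.62}
# while a model that failed to load additionally carries an 'error' key and
# zeroed statistics.
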

def create_summary_markdown(summary_stats):
    """Create markdown summary of results"""
    if not summary_stats:
        return "No results available"

    sorted_models = sorted(summary_stats.items(), key=lambda x: x[1]['accuracy'], reverse=True)

    lines = ["## 🏆 Model Performance Summary\n"]

    for i, (model, stats) in enumerate(sorted_models):
        if 'error' in stats:
            lines.append(f"❌ **{model}**: Error - {stats['error']}")
            continue

        accuracy_pct = stats['accuracy'] * 100
        medal = "🥇" if i == 0 else "🥈" if i == 1 else "🥉" if i == 2 else f"{i+1}."

        lines.append(
            f"{medal} **{model}**: {accuracy_pct:.1f}% "
            f"({stats['correct']}/{stats['total']} correct, "
            f"avg confidence: {stats['avg_confidence']:.2f})"
        )

    return "\n".join(lines)


def create_detailed_results_html(questions, results):
    """Create detailed HTML results for each question"""
    if not questions or not results:
        return "<p>No detailed results available</p>"

    html_parts = ["""
    <style>
        .question-card {
            background: white;
            border-radius: 12px;
            padding: 20px;
            margin-bottom: 20px;
            box-shadow: 0 2px 8px rgba(0,0,0,0.1);
            border-left: 5px solid #667eea;
        }
        .question-header {
            display: flex;
            justify-content: space-between;
            align-items: center;
            margin-bottom: 15px;
        }
        .question-number {
            background: linear-gradient(135deg, #667eea, #764ba2);
            color: white;
            padding: 6px 12px;
            border-radius: 20px;
            font-weight: bold;
            font-size: 14px;
        }
        .question-text {
            font-weight: 600;
            font-size: 16px;
            margin: 15px 0;
            color: #2d3748;
        }
        .choices {
            background: #f8fafc;
            border-radius: 8px;
            padding: 15px;
            margin: 10px 0;
        }
        .choice {
            margin: 8px 0;
            color: #4a5568;
        }
        .correct-answer {
            background: linear-gradient(135deg, #c6f6d5, #9ae6b4);
            border-left: 4px solid #48bb78;
            border-radius: 6px;
            padding: 12px;
            margin: 10px 0;
            font-weight: 600;
            color: #22543d;
        }
        .model-results {
            display: grid;
            grid-template-columns: repeat(auto-fit, minmax(220px, 1fr));
            gap: 12px;
            margin-top: 15px;
        }
        .model-result {
            padding: 12px;
            border-radius: 8px;
            text-align: center;
            font-weight: 600;
            transition: transform 0.2s ease;
        }
        .model-result:hover {
            transform: scale(1.02);
        }
        .result-correct {
            background: linear-gradient(135deg, #c6f6d5, #9ae6b4);
            color: #22543d;
            border: 2px solid #48bb78;
        }
        .result-incorrect {
            background: linear-gradient(135deg, #fed7d7, #fca5a5);
            color: #742a2a;
            border: 2px solid #e53e3e;
        }
        .result-error {
            background: linear-gradient(135deg, #fbb6ce, #f687b3);
            color: #744210;
            border: 2px solid #d69e2e;
        }
        .raw-response {
            font-size: 10px;
            margin-top: 4px;
            opacity: 0.7;
            font-family: monospace;
        }
    </style>
    """]

    for q_idx, question in enumerate(questions):
        html_parts.append(f"""
        <div class="question-card">
            <div class="question-header">
                <span class="question-number">Q{q_idx + 1}</span>
            </div>
            <div class="question-text">{question['question']}</div>
            <div class="choices">
                <strong>Choices:</strong><br>
                {' | '.join(f'{chr(65+i)}) {choice}' for i, choice in enumerate(question['choices']))}
            </div>
            <div class="correct-answer">
                <strong>✓ Correct Answer:</strong> {question['correct_answer']}
            </div>
            <div class="model-results">
        """)

        for model, model_results in results.items():
            if q_idx < len(model_results):
                result = model_results[q_idx]

                if 'error' in result:
                    html_parts.append(f"""
                    <div class="model-result result-error">
                        <div>⚠️ {model}</div>
                        <div style="font-size: 12px; margin-top: 4px;">
                            Error occurred
                        </div>
                        <div class="raw-response">{result.get('raw_response', 'Unknown error')}</div>
                    </div>
                    """)
                else:
                    result_class = 'result-correct' if result.get('correct', False) else 'result-incorrect'
                    icon = '✅' if result.get('correct', False) else '❌'

                    html_parts.append(f"""
                    <div class="model-result {result_class}">
                        <div>{icon} {model}</div>
                        <div style="font-size: 12px; margin-top: 4px;">
                            "{result.get('predicted', 'No prediction')}"
                        </div>
                        <div class="raw-response">Raw: "{result.get('raw_response', '')}"</div>
                    </div>
                    """)

        html_parts.append("""
            </div>
        </div>
        """)

    return "".join(html_parts)


def create_accuracy_chart(summary_stats):
    """Create accuracy comparison chart"""
    if not summary_stats:
        return None

    models = []
    accuracies = []

    for model, stats in summary_stats.items():
        if 'error' not in stats:
            models.append(model)
            accuracies.append(stats['accuracy'] * 100)

    if not models:
        return None

    fig = go.Figure(data=[
        go.Bar(
            x=models,
            y=accuracies,
            marker_color='lightblue',
            text=[f'{acc:.1f}%' for acc in accuracies],
            textposition='auto',
        )
    ])

    fig.update_layout(
        title="Model Accuracy Comparison",
        xaxis_title="Models",
        yaxis_title="Accuracy (%)",
        template="plotly_white",
        showlegend=False
    )

    return fig


def create_confidence_chart(results):
    """Create confidence distribution chart"""
    if not results:
        return None

    data = []
    for model, model_results in results.items():
        for result in model_results:
            if 'error' not in result and 'confidence' in result:
                data.append({
                    'Model': model,
                    'Confidence': result['confidence'],
                    'Correct': 'Correct' if result.get('correct', False) else 'Incorrect'
                })

    if not data:
        return None

    df = pd.DataFrame(data)

    fig = px.box(
        df,
        x='Model',
        y='Confidence',
        color='Correct',
        title="Confidence Distribution by Model and Correctness",
        template="plotly_white"
    )

    return fig


SAMPLE_DATASETS = {
    "Custom (enter below)": "",

    "LP": """Question,Correct Answer,Choice1,Choice2,Choice3
In which country is Llanfairpwllgwyngyllgogerychwyrndrobwllllantysiliogogogoch located?,Wales,Germany,France,Scotland
In which country is Llanfair pwllgwyngyll located?,Wales,Germany,France,Scotland
In which country is Llanfair PG located?,Wales,Germany,France,Scotland""",

    "Simple Math": """Question,Correct Answer,Choice1,Choice2,Choice3
What is 2+2?,4,3,2,5
What is 5*3?,15,12,16,18
What is 10-7?,3,7,4,2
What is 8/2?,4,3,2,5""",

    "World Capitals": """Question,Correct Answer,Choice1,Choice2,Choice3
What is the capital of France?,Paris,London,Berlin,Rome
What is the capital of Japan?,Tokyo,Seoul,Beijing,Bangkok
What is the capital of Brazil?,Brasília,Rio de Janeiro,São Paulo,Salvador
What is the capital of Australia?,Canberra,Sydney,Melbourne,Perth""",

    "Science Quiz": """Question,Correct Answer,Choice1,Choice2,Choice3
What is the chemical symbol for gold?,Au,Ag,Ca,K
Which planet is closest to the Sun?,Mercury,Venus,Earth,Mars
What is the speed of light?,299792458 m/s,300000000 m/s,2992458 m/s,299000000 m/s
What gas do plants absorb from the atmosphere?,Carbon dioxide,Oxygen,Nitrogen,Hydrogen"""
}


css = """
.gradio-container {
    font-family: 'Inter', sans-serif;
}
.sample-text {
    font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace;
    font-size: 12px;
}
"""


with gr.Blocks(title="🤖 Model Performance Comparison", theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("""
    # 🤖 Model Performance Comparison Tool

    Compare LLM performance on multiple-choice questions using Hugging Face models.

    **Format**: Each line should have: `Question,Correct Answer,Choice1,Choice2,Choice3`

    💡 **Features**:
    - Model evaluation using HuggingFace transformers
    - Support for custom models via HF model paths
    - Detailed question-by-question results
    - Performance charts and statistics
    """)

    with gr.Row():
        with gr.Column(scale=2):
            sample_selector = gr.Dropdown(
                choices=list(SAMPLE_DATASETS.keys()),
                value="Custom (enter below)",
                label="Choose sample dataset or enter your own",
                interactive=True
            )

            dataset_input = gr.Textbox(
                label="Dataset (CSV/TSV format)",
                placeholder="""Enter your dataset here...

Example format:
Question,Correct Answer,Choice1,Choice2,Choice3
What is 2+2?,4,3,2,5
What is the capital of France?,Paris,London,Berlin,Rome""",
                lines=8,
                max_lines=15
            )

            gr.Markdown("""
            **Format Requirements**:
            - First line: header row (it is always skipped, so include one)
            - Each data line: Question, Correct Answer, Choice1, Choice2, Choice3
            - Use commas or tabs as separators
            """)

        with gr.Column(scale=1):
            with gr.Tabs():
                with gr.TabItem("🤖 Predefined Models"):
                    predefined_selector = gr.CheckboxGroup(
                        choices=PREDEFINED_MODELS,
                        value=[PREDEFINED_MODELS[0]],
                        label="Select from popular models",
                        interactive=True
                    )

                with gr.TabItem("➕ Custom Models"):
                    custom_models_input = gr.Textbox(
                        label="Custom HuggingFace Model Paths",
                        placeholder="""Enter HuggingFace model paths (one per line):

microsoft/DialoGPT-medium
bigscience/bloom-560m""",
                        lines=5,
                        info="Add any HuggingFace model path. One model per line.",
                    )

                    gr.Markdown("""
                    **Examples of valid model paths**:
                    - `microsoft/DialoGPT-medium`
                    - `bigscience/bloom-560m`
                    - `facebook/opt-350m`
                    - Your own fine-tuned models!
                    """)

            evaluate_btn = gr.Button(
                "⚡ Run Evaluation",
                variant="primary",
                scale=1
            )

            gr.Markdown("""
            **⚠️ Note**:
            - Larger models need more GPU memory; without a GPU, evaluation runs on CPU and is slow
            - First run will download models (may take time)
            - Models are cached for subsequent runs
            """)

    with gr.Column(visible=True) as results_section:
        gr.Markdown("## 📊 Results")

        summary_output = gr.Markdown(
            value="Results will appear here...",
            label="Performance Summary"
        )

        with gr.Row():
            accuracy_plot = gr.Plot(label="Accuracy Comparison")
            confidence_plot = gr.Plot(label="Confidence Analysis")

        detailed_results = gr.HTML(
            value="<p>Detailed results will appear here...</p>",
            label="Detailed Question-by-Question Results"
        )

    def update_dataset_from_sample(sample_name):
        if sample_name in SAMPLE_DATASETS:
            return gr.update(value=SAMPLE_DATASETS[sample_name])
        return gr.update()

    sample_selector.change(
        fn=update_dataset_from_sample,
        inputs=sample_selector,
        outputs=dataset_input
    )

    evaluate_btn.click(
        fn=run_evaluation,
        inputs=[dataset_input, predefined_selector, custom_models_input],
        outputs=[summary_output, detailed_results, accuracy_plot, confidence_plot, results_section]
    )

    gr.Markdown("""
    ---
    ### About Model Evaluation

    This tool loads and runs HuggingFace models for evaluation:

    **🏗️ How it works**:
    - Downloads models from HuggingFace Hub
    - Formats questions as prompts for each model
    - Runs likelihood-based evaluation: the choice with the highest log-likelihood is taken as the model's answer

    **⚡ Performance Tips**:
    - Use smaller models for testing
    - Larger models (7B+) require significant GPU memory
    - Models are cached after first load

    **🔧 Supported Models**:
    - Any HuggingFace autoregressive language model
    - Both instruction-tuned and base models
    - Custom fine-tuned models via HF paths
    """)


if __name__ == "__main__":
    demo.launch()