Gül Sena Altıntaş committed
Commit 7ebe82f · Parent: b318650

- supertoken model not working [WIP]

Files changed (2)
  1. app.py +798 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,798 @@
+ import gradio as gr
+ import pandas as pd
+ import plotly.express as px
+ import plotly.graph_objects as go
+ from collections import Counter
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import re
+ import logging
+ from typing import List, Dict, Any
+ import gc
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Predefined models: Hugging Face model paths offered in the UI
+ PREDEFINED_MODELS = [
+     "meta-llama/Llama-3.2-1B",
+     "google/gemma-2-2b",
+     "Qwen/Qwen3-0.6B",
+     "Qwen/Qwen2.5-0.5B",
+     "Qwen/Qwen2.5-1.5B",
+     "bigscience/bloom-560m",
+     "CohereForAI/aya-expanse-8b",
+     "common-pile/comma-v0.1-2t",
+     "google/byt5-small",
+     "gsaltintas/supertoken_models-llama_gpt2",
+ ]
+ # Global cache for loaded models
+ model_cache = {}
+
+ def parse_dataset(text):
+     """Parse the input dataset text into structured questions"""
+     if not text.strip():
+         return [], "Please enter your dataset"
+
+     lines = text.strip().split('\n')
+     if len(lines) < 2:
+         return [], "Dataset must have at least a header and one question"
+
+     # Skip header and detect delimiter
+     first_data_line = lines[1] if len(lines) > 1 else lines[0]
+     delimiter = '\t' if '\t' in first_data_line else ','
+
+     questions = []
+     errors = []
+
+     for i, line in enumerate(lines[1:], 2):  # Start from line 2 (after header)
+         line = line.strip()
+         if not line:
+             continue
+
+         parts = [part.strip().strip('"') for part in line.split(delimiter)]
+
+         if len(parts) < 5:
+             errors.append(f"Line {i}: Not enough columns (need 5, got {len(parts)})")
+             continue
+
+         question = {
+             'question': parts[0],
+             'correct_answer': parts[1],
+             'choices': [parts[2], parts[3], parts[4]]
+         }
+
+         # Ensure correct answer is in choices
+         if question['correct_answer'] not in question['choices']:
+             question['choices'].append(question['correct_answer'])
+
+         questions.append(question)
+
+     error_msg = '\n'.join(errors) if errors else ""
+     return questions, error_msg
+
+
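+ # Illustrative sketch (not called anywhere in the app): what parse_dataset
+ # returns for a minimal comma-separated input. The values below follow directly
+ # from the parsing logic above.
+ #
+ #   qs, err = parse_dataset(
+ #       "Question,Correct Answer,Choice1,Choice2,Choice3\n"
+ #       "What is 2+2?,4,3,4,5"
+ #   )
+ #   # qs == [{'question': 'What is 2+2?', 'correct_answer': '4',
+ #   #         'choices': ['3', '4', '5']}]
+ #   # err == ""
+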
+ def load_model_and_tokenizer(model_path, use_cache=True, progress_callback=None):
+     """Load model and tokenizer with caching"""
+     global model_cache
+
+     if use_cache and model_path in model_cache:
+         logger.info(f"Using cached model: {model_path}")
+         if progress_callback:
+             progress_callback(1.0, f"✅ Using cached model: {model_path}")
+         return model_cache[model_path]
+
+     try:
+         if progress_callback:
+             progress_callback(0.1, f"🔄 Starting to load model: {model_path}")
+
+         logger.info(f"Loading model: {model_path}")
+
+         # Check if CUDA is available
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+
+         if progress_callback:
+             progress_callback(0.2, f"📥 Loading tokenizer for {model_path}...")
+
+         # Load tokenizer
+         tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, legacy=True)
+
+         # Add pad token if missing
+         if tokenizer.pad_token is None:
+             tokenizer.pad_token = tokenizer.eos_token
+
+         if progress_callback:
+             progress_callback(0.5, f"🧠 Loading model weights for {model_path}... (this may take a while)")
+
+         # Load model with appropriate settings
+         model = AutoModelForCausalLM.from_pretrained(
+             model_path,
+             torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+             device_map="auto" if device == "cuda" else None,
+             trust_remote_code=True,
+             low_cpu_mem_usage=True
+         )
+
+         model_info = {
+             'tokenizer': tokenizer,
+             'model': model,
+             'device': device
+         }
+
+         if use_cache:
+             model_cache[model_path] = model_info
+
+         if progress_callback:
+             progress_callback(1.0, f"✅ Successfully loaded model: {model_path}")
+
+         return model_info
+
+     except Exception as e:
+         error_msg = f"❌ Error loading model {model_path}: {str(e)}"
+         logger.error(error_msg)
+         if progress_callback:
+             progress_callback(0.0, error_msg)
+         return None
+
+ def calculate_choice_likelihood(model, tokenizer, question, choice):
+     """Calculate the log-likelihood of the choice given the question prompt"""
+     try:
+         # Use the raw question as the prompt (a "Question: ...\nAnswer: " template
+         # could be substituted here instead)
+         prompt = question
+         full_text = f"{prompt} {choice}"
+
+         # Tokenize full input (prompt + answer)
+         input_ids = tokenizer.encode(full_text, return_tensors="pt", add_special_tokens=False).to(model.device)
+         prompt_ids = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
+
+         if input_ids.size(1) <= prompt_ids.size(1):
+             logger.warning("Answer tokens are empty after tokenization.")
+             return float("-inf")
+
+         with torch.no_grad():
+             outputs = model(input_ids)
+             logits = outputs.logits
+
+         # Get logits for the answer tokens only
+         answer_len = input_ids.size(1) - prompt_ids.size(1)
+         target_ids = input_ids[:, -answer_len:]
+         logits = logits[:, prompt_ids.size(1) - 1:-1, :]  # shifted for next-token prediction
+
+         log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
+         token_log_probs = log_probs.gather(2, target_ids.unsqueeze(-1)).squeeze(-1)
+
+         total_log_prob = token_log_probs.sum().item()
+         return total_log_prob
+
+     except Exception as e:
+         logger.error(f"Error calculating likelihood for choice '{choice}': {str(e)}")
+         return float("-inf")
+
+
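+ # Illustrative sketch (not wired into the app): scoring one question's choices
+ # to sanity-check calculate_choice_likelihood. The model name is just one entry
+ # from PREDEFINED_MODELS; any causal LM should work.
+ #
+ #   info = load_model_and_tokenizer("Qwen/Qwen2.5-0.5B")
+ #   scores = {
+ #       c: calculate_choice_likelihood(info['model'], info['tokenizer'], "What is 2+2?", c)
+ #       for c in ["3", "4", "5"]
+ #   }
+ #   predicted = max(scores, key=scores.get)  # a reasonable model should pick "4"
+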
+ def evaluate_model_on_questions(model_path, questions, progress_callback=None):
+     """Evaluate a single model on all questions using likelihood-based scoring"""
+
+     model_info = load_model_and_tokenizer(model_path, progress_callback=progress_callback)
+
+     if model_info is None:
+         return [{'error': f'Failed to load model {model_path}'}] * len(questions)
+
+     results = []
+     model = model_info['model']
+     tokenizer = model_info['tokenizer']
+
+     for i, question in enumerate(questions):
+         try:
+             # Calculate likelihood for each choice
+             choice_likelihoods = {}
+             choice_probs = {}
+
+             for choice in question['choices']:
+                 likelihood = calculate_choice_likelihood(model, tokenizer, question['question'], choice)
+                 choice_likelihoods[choice] = likelihood
+
+             # Convert log probabilities to probabilities for confidence scoring
+             max_log_prob = max(choice_likelihoods.values())
+             choice_probs = {choice: torch.exp(torch.tensor(log_prob - max_log_prob)).item()
+                             for choice, log_prob in choice_likelihoods.items()}
+
+             # Normalize probabilities
+             total_prob = sum(choice_probs.values())
+             if total_prob > 0:
+                 choice_probs = {choice: prob / total_prob for choice, prob in choice_probs.items()}
+
+             # Select the choice with highest likelihood
+             predicted_choice = max(choice_likelihoods.keys(), key=lambda x: choice_likelihoods[x])
+             is_correct = predicted_choice == question['correct_answer']
+
+             # Confidence is the probability of the selected choice
+             confidence = choice_probs.get(predicted_choice, 0.0)
+
+             results.append({
+                 'question_idx': i,
+                 'predicted': predicted_choice,
+                 'correct': is_correct,
+                 'confidence': confidence,
+                 'choice_likelihoods': choice_likelihoods,
+                 'choice_probabilities': choice_probs,
+                 'raw_response': f"Likelihoods: {choice_likelihoods}"
+             })
+
+             if progress_callback:
+                 # Use remaining 80% for evaluation progress
+                 evaluation_progress = 0.2 + (i + 1) / len(questions) * 0.8
+                 progress_callback(evaluation_progress, f"🔍 Evaluating {model_path}: {i+1}/{len(questions)} questions (likelihood-based)")
+
+         except Exception as e:
+             logger.error(f"Error evaluating question {i} with {model_path}: {str(e)}")
+             results.append({
+                 'question_idx': i,
+                 'predicted': question['choices'][0] if question['choices'] else '',
+                 'correct': False,
+                 'confidence': 0.0,
+                 'choice_likelihoods': {},
+                 'choice_probabilities': {},
+                 'raw_response': f"Error: {str(e)}"
+             })
+
+     return results
+
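+ # Worked example of the confidence computation above (illustrative numbers):
+ # for choice log-likelihoods {-1.0, -2.0}, subtracting the max and exponentiating
+ # gives {1.0, exp(-1) ≈ 0.37}; normalizing yields ≈ {0.73, 0.27}, so the reported
+ # confidence of the top choice would be about 0.73.
+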
+ def run_evaluation(dataset_text, selected_predefined, custom_models_text, progress=gr.Progress()):
+     """Main evaluation function"""
+     if not dataset_text.strip():
+         return (
+             "Please enter your dataset",
+             "<p>No data provided</p>",
+             None,
+             None,
+             gr.update(visible=True)
+         )
+
+     # Parse custom models
+     custom_models = []
+     if custom_models_text.strip():
+         custom_models = [model.strip() for model in custom_models_text.strip().split('\n') if model.strip()]
+
+     # Combine selected predefined and custom models
+     all_models = []
+     all_models.extend(selected_predefined)
+     all_models.extend(custom_models)
+
+     if not all_models:
+         return (
+             "Please select at least one model or add custom models",
+             "<p>No models selected</p>",
+             None,
+             None,
+             gr.update(visible=False)
+         )
+
+     # Parse dataset
+     questions, parse_error = parse_dataset(dataset_text)
+
+     if parse_error:
+         return (
+             f"Dataset parsing error:\n{parse_error}",
+             "<p>Failed to parse dataset</p>",
+             None,
+             None,
+             gr.update(visible=True)
+         )
+
+     if not questions:
+         return (
+             "No valid questions found in dataset",
+             "<p>No questions to evaluate</p>",
+             None,
+             None,
+             gr.update(visible=True)
+         )
+
+     # Run evaluation
+     progress(0, "Starting evaluation...")
+     results = {}
+     total_steps = len(all_models) * len(questions)
+     current_step = 0
+
+     for model_path in all_models:
+         display_name = model_path.split('/')[-1] if '/' in model_path else model_path
+         try:
+             def model_progress(p, msg):
+                 nonlocal current_step
+                 current_step = int(p * len(questions))
+                 overall_progress = current_step / total_steps
+                 progress(overall_progress, msg)
+
+             model_results = evaluate_model_on_questions(model_path, questions, model_progress)
+             results[display_name] = model_results
+
+         except Exception as e:
+             logger.error(f"Failed to evaluate {display_name}: {str(e)}")
+             results[display_name] = [{'error': str(e)}] * len(questions)
+
+         # Clean up GPU memory
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+         gc.collect()
+
+     # Generate outputs
+     summary_stats = generate_summary_stats(questions, results)
+     summary_md = create_summary_markdown(summary_stats)
+     detailed_html = create_detailed_results_html(questions, results)
+     accuracy_chart = create_accuracy_chart(summary_stats)
+     confidence_chart = create_confidence_chart(results)
+
+     return (
+         summary_md,
+         detailed_html,
+         accuracy_chart,
+         confidence_chart,
+         gr.update(visible=True)
+     )
+
+ def generate_summary_stats(questions, results):
+     """Generate summary statistics for all models"""
+     summary = {}
+
+     for model, model_results in results.items():
+         if not model_results or 'error' in model_results[0]:
+             summary[model] = {
+                 'accuracy': 0.0,
+                 'correct': 0,
+                 'total': len(questions),
+                 'avg_confidence': 0.0,
+                 'error': model_results[0].get('error', 'Unknown error') if model_results else 'No results'
+             }
+             continue
+
+         correct_count = sum(1 for r in model_results if r.get('correct', False))
+         total_count = len(model_results)
+         accuracy = correct_count / total_count if total_count > 0 else 0
+
+         # Calculate average confidence
+         avg_confidence = sum(r.get('confidence', 0) for r in model_results) / total_count if total_count > 0 else 0
+
+         summary[model] = {
+             'accuracy': accuracy,
+             'correct': correct_count,
+             'total': total_count,
+             'avg_confidence': avg_confidence
+         }
+
+     return summary
+
+ def create_summary_markdown(summary_stats):
+     """Create markdown summary of results"""
+     if not summary_stats:
+         return "No results available"
+
+     # Sort by accuracy
+     sorted_models = sorted(summary_stats.items(), key=lambda x: x[1]['accuracy'], reverse=True)
+
+     lines = ["## 🏆 Model Performance Summary\n"]
+
+     for i, (model, stats) in enumerate(sorted_models):
+         if 'error' in stats:
+             lines.append(f"❌ **{model}**: Error - {stats['error']}")
+             continue
+
+         accuracy_pct = stats['accuracy'] * 100
+         medal = "🥇" if i == 0 else "🥈" if i == 1 else "🥉" if i == 2 else f"{i+1}."
+
+         lines.append(
+             f"{medal} **{model}**: {accuracy_pct:.1f}% "
+             f"({stats['correct']}/{stats['total']} correct, "
+             f"avg confidence: {stats['avg_confidence']:.2f})"
+         )
+
+     return "\n".join(lines)
+
+ def create_detailed_results_html(questions, results):
+     """Create detailed HTML results for each question"""
+     if not questions or not results:
+         return "<p>No detailed results available</p>"
+
+     html_parts = ["""
+     <style>
+     .question-card {
+         background: white;
+         border-radius: 12px;
+         padding: 20px;
+         margin-bottom: 20px;
+         box-shadow: 0 2px 8px rgba(0,0,0,0.1);
+         border-left: 5px solid #667eea;
+     }
+     .question-header {
+         display: flex;
+         justify-content: space-between;
+         align-items: center;
+         margin-bottom: 15px;
+     }
+     .question-number {
+         background: linear-gradient(135deg, #667eea, #764ba2);
+         color: white;
+         padding: 6px 12px;
+         border-radius: 20px;
+         font-weight: bold;
+         font-size: 14px;
+     }
+     .question-text {
+         font-weight: 600;
+         font-size: 16px;
+         margin: 15px 0;
+         color: #2d3748;
+     }
+     .choices {
+         background: #f8fafc;
+         border-radius: 8px;
+         padding: 15px;
+         margin: 10px 0;
+     }
+     .choice {
+         margin: 8px 0;
+         color: #4a5568;
+     }
+     .correct-answer {
+         background: linear-gradient(135deg, #c6f6d5, #9ae6b4);
+         border-left: 4px solid #48bb78;
+         border-radius: 6px;
+         padding: 12px;
+         margin: 10px 0;
+         font-weight: 600;
+         color: #22543d;
+     }
+     .model-results {
+         display: grid;
+         grid-template-columns: repeat(auto-fit, minmax(220px, 1fr));
+         gap: 12px;
+         margin-top: 15px;
+     }
+     .model-result {
+         padding: 12px;
+         border-radius: 8px;
+         text-align: center;
+         font-weight: 600;
+         transition: transform 0.2s ease;
+     }
+     .model-result:hover {
+         transform: scale(1.02);
+     }
+     .result-correct {
+         background: linear-gradient(135deg, #c6f6d5, #9ae6b4);
+         color: #22543d;
+         border: 2px solid #48bb78;
+     }
+     .result-incorrect {
+         background: linear-gradient(135deg, #fed7d7, #fca5a5);
+         color: #742a2a;
+         border: 2px solid #e53e3e;
+     }
+     .result-error {
+         background: linear-gradient(135deg, #fbb6ce, #f687b3);
+         color: #744210;
+         border: 2px solid #d69e2e;
+     }
+     .raw-response {
+         font-size: 10px;
+         margin-top: 4px;
+         opacity: 0.7;
+         font-family: monospace;
+     }
+     </style>
+     """]
+
+     for q_idx, question in enumerate(questions):
+         html_parts.append(f"""
+         <div class="question-card">
+             <div class="question-header">
+                 <span class="question-number">Q{q_idx + 1}</span>
+             </div>
+             <div class="question-text">{question['question']}</div>
+             <div class="choices">
+                 <strong>Choices:</strong><br>
+                 {' | '.join(f'{chr(65+i)}) {choice}' for i, choice in enumerate(question['choices']))}
+             </div>
+             <div class="correct-answer">
+                 <strong>✓ Correct Answer:</strong> {question['correct_answer']}
+             </div>
+             <div class="model-results">
+         """)
+
+         # Add results for each model
+         for model, model_results in results.items():
+             if q_idx < len(model_results):
+                 result = model_results[q_idx]
+
+                 if 'error' in result:
+                     html_parts.append(f"""
+                     <div class="model-result result-error">
+                         <div>⚠️ {model}</div>
+                         <div style="font-size: 12px; margin-top: 4px;">
+                             Error occurred
+                         </div>
+                         <div class="raw-response">{result.get('raw_response', 'Unknown error')}</div>
+                     </div>
+                     """)
+                 else:
+                     result_class = 'result-correct' if result.get('correct', False) else 'result-incorrect'
+                     icon = '✅' if result.get('correct', False) else '❌'
+
+                     html_parts.append(f"""
+                     <div class="model-result {result_class}">
+                         <div>{icon} {model}</div>
+                         <div style="font-size: 12px; margin-top: 4px;">
+                             "{result.get('predicted', 'No prediction')}"
+                         </div>
+                         <div class="raw-response">Raw: "{result.get('raw_response', '')}"</div>
+                     </div>
+                     """)
+
+         html_parts.append("""
+             </div>
+         </div>
+         """)
+
+     return "".join(html_parts)
+
+ def create_accuracy_chart(summary_stats):
+     """Create accuracy comparison chart"""
+     if not summary_stats:
+         return None
+
+     models = []
+     accuracies = []
+
+     for model, stats in summary_stats.items():
+         if 'error' not in stats:
+             models.append(model)
+             accuracies.append(stats['accuracy'] * 100)
+
+     if not models:
+         return None
+
+     fig = go.Figure(data=[
+         go.Bar(
+             x=models,
+             y=accuracies,
+             marker_color='lightblue',
+             text=[f'{acc:.1f}%' for acc in accuracies],
+             textposition='auto',
+         )
+     ])
+
+     fig.update_layout(
+         title="Model Accuracy Comparison",
+         xaxis_title="Models",
+         yaxis_title="Accuracy (%)",
+         template="plotly_white",
+         showlegend=False
+     )
+
+     return fig
+
+ def create_confidence_chart(results):
+     """Create confidence distribution chart"""
+     if not results:
+         return None
+
+     data = []
+     for model, model_results in results.items():
+         for result in model_results:
+             if 'error' not in result and 'confidence' in result:
+                 data.append({
+                     'Model': model,
+                     'Confidence': result['confidence'],
+                     'Correct': 'Correct' if result.get('correct', False) else 'Incorrect'
+                 })
+
+     if not data:
+         return None
+
+     df = pd.DataFrame(data)
+
+     fig = px.box(
+         df,
+         x='Model',
+         y='Confidence',
+         color='Correct',
+         title="Confidence Distribution by Model and Correctness",
+         template="plotly_white"
+     )
+
+     return fig
+
+ # Sample datasets for quick testing
+ SAMPLE_DATASETS = {
+     "Custom (enter below)": "",
+
+     "LP": """Question,Correct Answer,Choice1,Choice2,Choice3
+ In which country is Llanfairpwllgwyngyllgogerychwyrndrobwllllantysiliogogogoch located?,Wales,Germany,France,Scotland
+ In which country is Llanfair pwllgwyngyll located?,Wales,Germany,France,Scotland
+ In which country is Llanfair PG located?,Wales,Germany,France,Scotland""",
+
+     "Simple Math": """Question,Correct Answer,Choice1,Choice2,Choice3
+ What is 2+2?,4,3,4,5
+ What is 5*3?,15,12,15,18
+ What is 10-7?,3,3,4,2
+ What is 8/2?,4,3,4,5""",
+
+     "World Capitals": """Question,Correct Answer,Choice1,Choice2,Choice3
+ What is the capital of France?,Paris,London,Berlin,Paris
+ What is the capital of Japan?,Tokyo,Seoul,Tokyo,Bangkok
+ What is the capital of Brazil?,Brasília,Rio de Janeiro,Brasília,São Paulo
+ What is the capital of Australia?,Canberra,Sydney,Melbourne,Canberra""",
+
+     "Science Quiz": """Question,Correct Answer,Choice1,Choice2,Choice3
+ What is the chemical symbol for gold?,Au,Ag,Au,Go
+ Which planet is closest to the Sun?,Mercury,Venus,Mercury,Mars
+ What is the speed of light?,299792458 m/s,300000000 m/s,299792458 m/s,299000000 m/s
+ What gas do plants absorb from the atmosphere?,Carbon dioxide,Oxygen,Carbon dioxide,Nitrogen"""
+ }
+
+ # Custom CSS
+ css = """
+ .gradio-container {
+     font-family: 'Inter', sans-serif;
+ }
+ .sample-text {
+     font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace;
+     font-size: 12px;
+ }
+ """
+
+ # Create Gradio interface
+ with gr.Blocks(title="🤖 Model Performance Comparison", theme=gr.themes.Soft(), css=css) as demo:
+     gr.Markdown("""
+ # 🤖 Model Performance Comparison Tool
+
+ Compare LLM performance on multiple-choice questions using Hugging Face models.
+
+ **Format**: Each line should have: `Question,Correct Answer,Choice1,Choice2,Choice3`
+
+ 💡 **Features**:
+ - Model evaluation using HuggingFace transformers
+ - Support for custom models via HF model paths
+ - Detailed question-by-question results
+ - Performance charts and statistics
+     """)
+
+     with gr.Row():
+         with gr.Column(scale=2):
+             # Sample dataset selector
+             sample_selector = gr.Dropdown(
+                 choices=list(SAMPLE_DATASETS.keys()),
+                 value="Custom (enter below)",
+                 label="Choose sample dataset or enter your own",
+                 interactive=True
+             )
+
+             # Dataset input
+             dataset_input = gr.Textbox(
+                 label="Dataset (CSV/TSV format)",
+                 placeholder="""Enter your dataset here...
+
+ Example format:
+ Question,Correct Answer,Choice1,Choice2,Choice3
+ What is 2+2?,4,3,4,5
+ What is the capital of France?,Paris,London,Berlin,Paris""",
+                 lines=8,
+                 max_lines=15
+             )
+
+             gr.Markdown("""
+ **Format Requirements**:
+ - First line: header (it is always skipped); leave it empty if your data has no header
+ - Each data line: Question, Correct Answer, Choice1, Choice2, Choice3
+ - Use commas or tabs as separators
+             """)
+
+         with gr.Column(scale=1):
+             # Model selection
+             with gr.Tabs():
+                 with gr.TabItem("🤖 Predefined Models"):
+                     predefined_selector = gr.CheckboxGroup(
+                         choices=PREDEFINED_MODELS,
+                         value=[PREDEFINED_MODELS[0]],
+                         label="Select from popular models",
+                         interactive=True
+                     )
+
+                 with gr.TabItem("➕ Custom Models"):
+                     custom_models_input = gr.Textbox(
+                         label="Custom HuggingFace Model Paths",
+                         placeholder="""Enter HuggingFace model paths (one per line):
+
+ microsoft/DialoGPT-medium
+ bigscience/bloom-560m""",
+                         lines=5,
+                         info="Add any HuggingFace model path. One model per line."
+                     )
+
+                     gr.Markdown("""
+ **Examples of valid model paths**:
+ - `microsoft/DialoGPT-medium`
+ - `bigscience/bloom-560m`
+ - `facebook/opt-350m`
+ - Your own fine-tuned models!
+                     """)
+
+             # Evaluate button
+             evaluate_btn = gr.Button(
+                 "⚡ Run Evaluation",
+                 variant="primary",
+                 scale=1
+             )
+
+             gr.Markdown("""
+ **⚠️ Note**:
+ - Larger models require more GPU memory; evaluation currently runs on CPU only
+ - The first run downloads each model (this may take a while)
+ - Models are cached for subsequent runs
+             """)
+
+     # Results section
+     with gr.Column(visible=False) as results_section:
+         gr.Markdown("## 📊 Results")
+
+         summary_output = gr.Markdown(
+             value="Results will appear here...",
+             label="Performance Summary"
+         )
+
+         with gr.Row():
+             accuracy_plot = gr.Plot(label="Accuracy Comparison")
+             confidence_plot = gr.Plot(label="Confidence Analysis")
+
+         detailed_results = gr.HTML(
+             value="<p>Detailed results will appear here...</p>",
+             label="Detailed Question-by-Question Results"
+         )
+
+     # Event handlers
+     def update_dataset_from_sample(sample_name):
+         if sample_name in SAMPLE_DATASETS:
+             return gr.update(value=SAMPLE_DATASETS[sample_name])
+         return gr.update()
+
+     sample_selector.change(
+         fn=update_dataset_from_sample,
+         inputs=sample_selector,
+         outputs=dataset_input
+     )
+
+     evaluate_btn.click(
+         fn=run_evaluation,
+         inputs=[dataset_input, predefined_selector, custom_models_input],
+         outputs=[summary_output, detailed_results, accuracy_plot, confidence_plot, results_section]
+     )
+
+     gr.Markdown("""
+ ---
+ ### About Model Evaluation
+
+ This tool loads and runs HuggingFace models for evaluation:
+
+ **🏗️ How it works**:
+ - Downloads models from HuggingFace Hub
+ - Formats questions as prompts for each model
+ - Runs likelihood-based evaluation: each answer choice is scored by its log-probability given the question, and the highest-scoring choice is the model's prediction
+
+ **⚡ Performance Tips**:
+ - Use smaller models for testing
+ - Larger models (7B+) require significant GPU memory
+ - Models are cached after first load
+
+ **🔧 Supported Models**:
+ - Any HuggingFace autoregressive language model
+ - Both instruction-tuned and base models
+ - Custom fine-tuned models via HF paths
+     """)
+
+ if __name__ == "__main__":
+     demo.launch()
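+ # Usage note (assuming the packages listed in requirements.txt are installed):
+ # running `python app.py` launches the Gradio app and prints the local URL it
+ # serves on.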
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ gradio
+ tiktoken
+ transformers
+ torch
+ pandas
+ plotly