Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -3,6 +3,7 @@ import gradio as gr
|
|
3 |
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
4 |
from datasets import load_dataset
|
5 |
import torch
|
|
|
6 |
|
7 |
# Cache to avoid reloading the model
|
8 |
model_cache = {}
|
@@ -19,9 +20,20 @@ def load_model(model_id):
|
|
19 |
return generator
|
20 |
|
21 |
def format_prompt(item):
|
22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
return prompt, item['answer']
|
24 |
|
|
|
|
|
|
|
|
|
25 |
def evaluate(model_id, sample_count, config_name):
|
26 |
gen = load_model(model_id)
|
27 |
dataset = load_dataset("cais/mmlu", config_name, token=HF_TOKEN)["test"]
|
@@ -32,8 +44,8 @@ def evaluate(model_id, sample_count, config_name):
|
|
32 |
|
33 |
for item in dataset:
|
34 |
prompt, answer = format_prompt(item)
|
35 |
-
output = gen(prompt, max_new_tokens=
|
36 |
-
output_letter =
|
37 |
is_correct = output_letter == answer
|
38 |
correct += is_correct
|
39 |
results.append((prompt, output.strip(), answer, output_letter, is_correct))
|
@@ -93,4 +105,4 @@ with gr.Blocks(css="body {font-family: Inter, sans-serif; padding: 1em; max-widt
|
|
93 |
run_button.click(run, inputs=[model_id, sample_count, config_name], outputs=[acc_output, detail_output])
|
94 |
download_button.click(save_text, inputs=detail_output, outputs=gr.File())
|
95 |
|
96 |
-
demo.launch()
|
|
|
3 |
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
4 |
from datasets import load_dataset
|
5 |
import torch
|
6 |
+
import re
|
7 |
|
8 |
# Cache to avoid reloading the model
|
9 |
model_cache = {}
|
|
|
20 |
return generator
|
21 |
|
22 |
def format_prompt(item):
|
23 |
+
system_instruction = "
|
24 |
+
Only answer with a single letter: A, B, C, or D."
|
25 |
+
prompt = f"{item['question']}
|
26 |
+
A. {item['choices'][0]}
|
27 |
+
B. {item['choices'][1]}
|
28 |
+
C. {item['choices'][2]}
|
29 |
+
D. {item['choices'][3]}
|
30 |
+
Answer:{system_instruction}"
|
31 |
return prompt, item['answer']
|
32 |
|
33 |
+
def extract_choice_letter(output):
|
34 |
+
match = re.search(r"\b([ABCD])\b", output.strip())
|
35 |
+
return match.group(1) if match else None
|
36 |
+
|
37 |
def evaluate(model_id, sample_count, config_name):
|
38 |
gen = load_model(model_id)
|
39 |
dataset = load_dataset("cais/mmlu", config_name, token=HF_TOKEN)["test"]
|
|
|
44 |
|
45 |
for item in dataset:
|
46 |
prompt, answer = format_prompt(item)
|
47 |
+
output = gen(prompt, max_new_tokens=20, do_sample=False)[0]["generated_text"]
|
48 |
+
output_letter = extract_choice_letter(output)
|
49 |
is_correct = output_letter == answer
|
50 |
correct += is_correct
|
51 |
results.append((prompt, output.strip(), answer, output_letter, is_correct))
|
|
|
105 |
run_button.click(run, inputs=[model_id, sample_count, config_name], outputs=[acc_output, detail_output])
|
106 |
download_button.click(save_text, inputs=detail_output, outputs=gr.File())
|
107 |
|
108 |
+
demo.launch()
|