Enderchef committed
Commit 8ea457b · verified · 1 Parent(s): 9dcd426

Update app.py

Files changed (1):
  1. app.py  +16 -4
app.py CHANGED
@@ -3,6 +3,7 @@ import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 from datasets import load_dataset
 import torch
+import re
 
 # Cache to avoid reloading the model
 model_cache = {}
@@ -19,9 +20,20 @@ def load_model(model_id):
     return generator
 
 def format_prompt(item):
-    prompt = f"{item['question']}\nA. {item['choices'][0]}\nB. {item['choices'][1]}\nC. {item['choices'][2]}\nD. {item['choices'][3]}\nAnswer:"
+    system_instruction = """
+Only answer with a single letter: A, B, C, or D."""
+    prompt = f"""{item['question']}
+A. {item['choices'][0]}
+B. {item['choices'][1]}
+C. {item['choices'][2]}
+D. {item['choices'][3]}
+Answer:{system_instruction}"""
     return prompt, item['answer']
 
+def extract_choice_letter(output):
+    match = re.search(r"\b([ABCD])\b", output.strip())
+    return match.group(1) if match else None
+
 def evaluate(model_id, sample_count, config_name):
     gen = load_model(model_id)
     dataset = load_dataset("cais/mmlu", config_name, token=HF_TOKEN)["test"]
@@ -32,8 +44,8 @@ def evaluate(model_id, sample_count, config_name):
 
     for item in dataset:
         prompt, answer = format_prompt(item)
-        output = gen(prompt, max_new_tokens=10, do_sample=False)[0]["generated_text"]
-        output_letter = next((char for char in reversed(output) if char in "ABCD"), None)
+        output = gen(prompt, max_new_tokens=20, do_sample=False)[0]["generated_text"]
+        output_letter = extract_choice_letter(output)
         is_correct = output_letter == answer
         correct += is_correct
         results.append((prompt, output.strip(), answer, output_letter, is_correct))
@@ -93,4 +105,4 @@ with gr.Blocks(css="body {font-family: Inter, sans-serif; padding: 1em; max-widt
     run_button.click(run, inputs=[model_id, sample_count, config_name], outputs=[acc_output, detail_output])
     download_button.click(save_text, inputs=detail_output, outputs=gr.File())
 
-demo.launch()
+demo.launch()
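
For reference, a minimal sketch of what the updated format_prompt produces, using a made-up MMLU-style item (the question, choices, and answer encoding below are placeholders for illustration, not actual cais/mmlu rows):

def format_prompt(item):
    system_instruction = """
Only answer with a single letter: A, B, C, or D."""
    prompt = f"""{item['question']}
A. {item['choices'][0]}
B. {item['choices'][1]}
C. {item['choices'][2]}
D. {item['choices'][3]}
Answer:{system_instruction}"""
    return prompt, item['answer']

# Placeholder item for illustration only.
sample_item = {
    "question": "Which planet is known as the Red Planet?",
    "choices": ["Venus", "Mars", "Jupiter", "Saturn"],
    "answer": "B",  # placeholder; verify how cais/mmlu encodes answers
}

prompt, answer = format_prompt(sample_item)
print(prompt)
# Which planet is known as the Red Planet?
# A. Venus
# B. Mars
# C. Jupiter
# D. Saturn
# Answer:
# Only answer with a single letter: A, B, C, or D.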
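
The new extract_choice_letter helper can also be exercised on its own; the completions below are illustrative strings, not real model output:

import re

def extract_choice_letter(output):
    # First standalone A/B/C/D token in the generated text.
    match = re.search(r"\b([ABCD])\b", output.strip())
    return match.group(1) if match else None

print(extract_choice_letter("Answer: C"))                 # -> C
print(extract_choice_letter("The correct option is B."))  # -> B
print(extract_choice_letter("I am not sure."))            # -> None

One thing worth double-checking in evaluate: by default the transformers text-generation pipeline returns the prompt together with the continuation in generated_text, so the first standalone letter may come from the echoed "A. ..." choice line rather than from the model's answer. Stripping the prompt prefix from output (or passing return_full_text=False to the pipeline call) would restrict the search to the generated continuation.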