zhihengchen committed on
Commit
6f7189c
·
verified ·
1 Parent(s): 3a387bf

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +156 -0
app.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+ import gradio as gr
3
+ import pandas as pd
4
+ from constants import *
5
+
6
+ # ------------ 下载链接 ------------
7
def get_download_link_model(task, dataset, example):
    """Build the relative path of the zipped model weights for a selection.

    The three arguments are looked up in ``TASK_PATH_MAPPING``,
    ``DATASET_PATH_MAPPING`` and ``EXAMPLE_PATH_MAPPING`` respectively, and the
    result has the form ``data/<task>/<dataset>/weight/<example>.zip``.

    Raises:
        KeyError: if any argument is not a key of its mapping.
    """
    segments = (
        TASK_PATH_MAPPING[task],
        DATASET_PATH_MAPPING[dataset],
    )
    archive_name = f"{EXAMPLE_PATH_MAPPING[example]}.zip"
    return os.path.join("data", *segments, "weight", archive_name)
12
+
13
def get_download_link_json(task, dataset, example):
    """Build the relative path of the JSON results file for a selection.

    The three arguments are looked up in ``TASK_PATH_MAPPING``,
    ``DATASET_PATH_MAPPING`` and ``EXAMPLE_PATH_MAPPING`` respectively. The
    result has the form ``data/<task>/<dataset>/json/<example>.<ext>`` where
    ``<ext>`` is ``jsonl`` when the task path is ``"common"`` and ``json``
    otherwise.

    Raises:
        KeyError: if any argument is not a key of its mapping.
    """
    task_dir = TASK_PATH_MAPPING[task]
    dataset_dir = DATASET_PATH_MAPPING[dataset]
    stem = EXAMPLE_PATH_MAPPING[example]
    extension = "jsonl" if task_dir == "common" else "json"
    return os.path.join("data", task_dir, dataset_dir, "json", f"{stem}.{extension}")
21
+
22
+ # ------------ 数据读取 + 平均准确率 ------------
23
def get_data(task, dataset, example):
    """Load the per-prompt results for a selection and compute its average score.

    Reads ``data/<task>/<dataset>/csv/<example>.csv``, where the path segments
    come from ``TASK_PATH_MAPPING``, ``DATASET_PATH_MAPPING`` and
    ``EXAMPLE_PATH_MAPPING``.

    Returns:
        A ``(data, average_acc)`` tuple.

        * ``(None, None)`` when the CSV file does not exist.
        * For the ``"coding"`` task path: ``data`` holds the prompt plus the
          ``pass@1`` / ``pass@5`` / ``pass@10`` columns scaled to percent and
          rounded to 3 decimals, with ``Correctness`` set to ``"N/A"``.
          ``average_acc`` is the string ``"<p1> / <p5> / <p10>"`` of column
          means when ``"HumanEval"`` occurs in ``dataset``, otherwise ``None``.
        * For the ``"common"`` and ``"math"`` task paths: the pass columns are
          ``None``, ``Correctness`` is a check or cross mark depending on the
          truthiness of the ``correctness`` column, and ``average_acc`` is the
          mean of that column in percent, rounded to 3 decimals.
        * For any other task path: an empty frame with ``COLUMN_NAMES`` as
          columns and ``average_acc`` of ``None``.

    Raises:
        KeyError: if a selection is missing from its mapping, or the CSV lacks
            a column required by the task.
    """
    _task_path = TASK_PATH_MAPPING[task]
    _dataset_path = DATASET_PATH_MAPPING[dataset]
    _example_path = EXAMPLE_PATH_MAPPING[example]
    csv_file = os.path.join("data", _task_path, _dataset_path, "csv", f"{_example_path}.csv")
    if not os.path.exists(csv_file):
        return None, None

    read_data = pd.read_csv(csv_file)
    row_count = len(read_data)
    average_acc = None

    def _as_percent(column):
        # Per-value Python rounding, identical to the former row-by-row code.
        return [round(float(value) * 100, 3) for value in read_data[column]]

    if _task_path == "coding":
        # Build the whole table in one step instead of one pd.concat per row,
        # which copied the accumulated frame on every iteration.
        rows = pd.DataFrame({
            "Prompt": read_data["prompt"].tolist(),
            "Pass@1": _as_percent("pass@1"),
            "Pass@5": _as_percent("pass@5"),
            "Pass@10": _as_percent("pass@10"),
            "Correctness": ["N/A"] * row_count,
        })
        # The three-column average is only computed for HumanEval datasets.
        if "HumanEval" in dataset:
            p1_mean = round(read_data["pass@1"].mean() * 100, 3)
            p5_mean = round(read_data["pass@5"].mean() * 100, 3)
            p10_mean = round(read_data["pass@10"].mean() * 100, 3)
            average_acc = f"{p1_mean} / {p5_mean} / {p10_mean}"
    elif _task_path in ["common", "math"]:
        rows = pd.DataFrame({
            "Prompt": read_data["prompt"].tolist(),
            "Pass@1": [None] * row_count,
            "Pass@5": [None] * row_count,
            "Pass@10": [None] * row_count,
            "Correctness": ["✅" if flag else "❌" for flag in read_data["correctness"]],
        })
        average_acc = round(read_data["correctness"].mean() * 100, 3)
    else:
        return pd.DataFrame(columns=COLUMN_NAMES), average_acc

    # Keep COLUMN_NAMES first (as the old concat onto an empty COLUMN_NAMES
    # frame did), followed by any produced column that is not listed there.
    ordered_columns = list(COLUMN_NAMES) + [
        name for name in rows.columns if name not in COLUMN_NAMES
    ]
    data = rows.reindex(columns=ordered_columns)

    return data, average_acc
62
+
63
+ # ------------ Gradio UI ------------
64
def _refresh_board(t, d, e):
    """Return ``(table, average accuracy text)`` for the given selection.

    Calls ``get_data`` once so the CSV is read a single time per UI event.
    """
    table, average = get_data(t, d, e)
    return table, str(average)


with gr.Blocks() as demo_board:
    gr.HTML(DND_HEADER)
    gr.Markdown(DND_INTRODUCTION)

    task = gr.Radio(
        label="Task",
        choices=TASK_LIST,
        value=TASK_LIST[0],
        interactive=True,
    )
    dataset = gr.Radio(
        label="Dataset",
        choices=TASK_DATASET_LIST[task.value],
        value=TASK_DATASET_LIST[task.value][0],
        interactive=True
    )
    example = gr.Radio(
        label="Example",
        choices=EXAMPLE_LIST,
        value=EXAMPLE_LIST[0],
        interactive=True,
    )

    # Average accuracy (placed above the prompt table).
    average_acc_display = gr.Textbox(
        label="Average Accuracy (%)",
        value=lambda: str(get_data(task.value, dataset.value, example.value)[1]),
        interactive=False,
        visible=True,
        scale=0,
        max_lines=1,
        min_width=160
    )

    # Prompt table.
    board = gr.components.Dataframe(
        value=lambda: get_data(task.value, dataset.value, example.value)[0],
        column_widths=["60%", "10%", "10%", "10%", "10%"],
        headers=COLUMN_NAMES,
        type="pandas",
        datatype=DATA_TITLE_TYPE,
        interactive=False,
        visible=True,
        max_height=500,
    )

    # Linked update: task -> dataset choices (first dataset of the task selected).
    task.change(
        lambda t: gr.Radio(
            label="Dataset",
            choices=TASK_DATASET_LIST[t],
            value=TASK_DATASET_LIST[t][0],
            interactive=True,
        ),
        inputs=[task],
        outputs=dataset
    )

    # Linked update: task / dataset / example -> table + average accuracy.
    # NOTE(review): on a task change this handler may run while the previous
    # dataset is still selected; the subsequent dataset change event should
    # refresh the view again — confirm the ordering in the deployed Gradio version.
    for component in [task, dataset, example]:
        component.change(
            _refresh_board,
            inputs=[task, dataset, example],
            outputs=[board, average_acc_display]
        )

    # Download buttons.
    with gr.Row():
        json_downloader = gr.DownloadButton("Download JSON", visible=True)
        model_downloader = gr.DownloadButton("Download Model", visible=True)

        json_downloader.click(
            fn=get_download_link_json,
            inputs=[task, dataset, example],
            outputs=json_downloader,
        )
        model_downloader.click(
            fn=get_download_link_model,
            inputs=[task, dataset, example],
            outputs=model_downloader,
        )

    # Citation text.
    citation_button = gr.Textbox(
        value=CITATION_BUTTON_TEXT,
        label=CITATION_BUTTON_LABEL,
        elem_id="citation-button",
        lines=6,
        show_copy_button=True,
    )

# Launch the app.
demo_board.launch()