AppleSwing commited on
Commit
b689423
Β·
verified Β·
1 Parent(s): 5efa2a7
Files changed (2) hide show
  1. app.py +255 -13
  2. requirements.txt +3 -1
app.py CHANGED
@@ -14,6 +14,7 @@ if not RESULT_DIR:
14
  import gradio as gr
15
  import pandas as pd
16
  from datasets import load_dataset
 
17
 
18
 
19
  def f2(x):
@@ -23,6 +24,166 @@ def f2(x):
23
  return x
24
 
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  def json_to_row(path: str, metrics: dict) -> dict:
27
  model_name = metrics.get("model_name")
28
  if not model_name:
@@ -63,7 +224,6 @@ def json_to_row(path: str, metrics: dict) -> dict:
63
  "Model type": model_type,
64
  "Precision": precision,
65
  "E2E(s)": f2(e2e_s),
66
- "Batch size": batch_size,
67
  "GPU": gpu_type,
68
  "Accuracy(%)": pct(acc),
69
  "Cost($)": cost,
@@ -75,6 +235,7 @@ def json_to_row(path: str, metrics: dict) -> dict:
75
  "Decoding<br>S-MFU(%)": pct(metrics.get("decoding_smfu")),
76
  "TTFT(s)": f2(metrics.get("ttft")),
77
  "TPOT(s)": f2(metrics.get("tpot")),
 
78
  }
79
  return row
80
 
@@ -219,7 +380,7 @@ def load_from_dir(
219
 
220
  if df.empty:
221
  empty_html = "<p>No records found.</p>"
222
- return empty_html
223
 
224
  df = df.fillna("-")
225
  raw_models = set()
@@ -244,8 +405,14 @@ def load_from_dir(
244
  links.append(str(name))
245
  models_str = ", ".join(links)
246
 
 
 
 
 
247
  table_html = f'<div class="table-container">{df.to_html(escape=False, index=False, classes="metrics-table")}</div>'
248
- return table_html
 
 
249
 
250
 
251
  def auto_refresh_from_dir(
@@ -267,6 +434,38 @@ def auto_refresh_from_dir(
267
  )
268
 
269
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  # Gradio UI
271
 
272
  def build_app() -> gr.Blocks:
@@ -275,6 +474,16 @@ def build_app() -> gr.Blocks:
275
  body {
276
  background-color: #f5f7fa !important;
277
  }
 
 
 
 
 
 
 
 
 
 
278
 
279
  /* The outer Group container */
280
  .search-box {
@@ -571,7 +780,7 @@ def build_app() -> gr.Blocks:
571
  value=["bfloat16", "fp8"],
572
  )
573
 
574
- with gr.Accordion("πŸ“– About Tasks & Metrics", open=False):
575
  gr.Markdown(
576
  "### Tasks\n"
577
  "- **GSM8K** β€” Mathematics Problem-Solving ([paper](https://arxiv.org/abs/2110-14168))\n"
@@ -591,48 +800,81 @@ def build_app() -> gr.Blocks:
591
  elem_classes="info-section"
592
  )
593
 
594
- # Right side - Table (wider)
595
  with gr.Column(scale=5):
596
  leaderboard_output = gr.HTML(label="πŸ“ˆ Results")
597
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
598
  demo.load(
599
  fn=auto_refresh_from_dir,
600
  inputs=[dir_path, task_filter, framework_filter, model_type_filter, precision_filter, search_input],
601
- outputs=[leaderboard_output],
602
  )
603
 
604
  search_input.change(
605
  fn=load_from_dir,
606
  inputs=[dir_path, task_filter, framework_filter, model_type_filter, precision_filter, search_input],
607
- outputs=[leaderboard_output],
608
  )
609
 
610
  task_filter.change(
611
  fn=load_from_dir,
612
  inputs=[dir_path, task_filter, framework_filter, model_type_filter, precision_filter, search_input],
613
- outputs=[leaderboard_output],
614
  )
615
  framework_filter.change(
616
  fn=load_from_dir,
617
  inputs=[dir_path, task_filter, framework_filter, model_type_filter, precision_filter, search_input],
618
- outputs=[leaderboard_output],
619
  )
620
  model_type_filter.change(
621
  fn=load_from_dir,
622
  inputs=[dir_path, task_filter, framework_filter, model_type_filter, precision_filter, search_input],
623
- outputs=[leaderboard_output],
624
  )
625
  precision_filter.change(
626
  fn=load_from_dir,
627
  inputs=[dir_path, task_filter, framework_filter, model_type_filter, precision_filter, search_input],
628
- outputs=[leaderboard_output],
629
  )
630
 
 
 
 
 
 
 
 
631
  timer = gr.Timer(60.0)
632
  timer.tick(
633
  fn=auto_refresh_from_dir,
634
  inputs=[dir_path, task_filter, framework_filter, model_type_filter, precision_filter, search_input],
635
- outputs=[leaderboard_output],
636
  )
637
 
638
  return demo
 
14
  import gradio as gr
15
  import pandas as pd
16
  from datasets import load_dataset
17
+ import plotly.graph_objects as go
18
 
19
 
20
  def f2(x):
 
24
  return x
25
 
26
 
27
+ def normalize(val, vmin, vmax, baseline=20):
28
+ """Normalize value to baseline-100 range."""
29
+ if vmax == vmin:
30
+ return baseline + 40
31
+ return baseline + (val - vmin) / (vmax - vmin) * (100 - baseline)
32
+
33
+
34
+ def normalize_reversed(val, vmin, vmax, baseline=20):
35
+ """Normalize value (reversed - lower is better) to baseline-100 range."""
36
+ if vmax == vmin:
37
+ return baseline + 40
38
+ return baseline + (vmax - val) / (vmax - vmin) * (100 - baseline)
39
+
40
+
41
+ def normalize_cost(val, max_tick, baseline=20):
42
+ """Normalize cost (lower is better)."""
43
+ if max_tick == 0:
44
+ return baseline + 40
45
+ return baseline + (max_tick - min(val, max_tick)) / max_tick * (100 - baseline)
46
+
47
+
48
+ def generate_radar_plot(selected_rows_data: List[dict]) -> go.Figure:
49
+ """Generate a CAP radar plot from selected rows."""
50
+ # Validation: max 3 rows, all same dataset
51
+ if not selected_rows_data or len(selected_rows_data) == 0:
52
+ fig = go.Figure()
53
+ fig.add_annotation(
54
+ text="Please select 1-3 rows from the table to generate radar plot",
55
+ xref="paper", yref="paper",
56
+ x=0.05, y=0.5, showarrow=False,
57
+ font=dict(size=16)
58
+ )
59
+ fig.update_layout(height=600, width=900)
60
+ return fig
61
+ if len(selected_rows_data) > 3:
62
+ fig = go.Figure()
63
+ fig.add_annotation(
64
+ text="Error: Please select no more than 3 rows!",
65
+ xref="paper", yref="paper",
66
+ x=0.5, y=0.5, showarrow=False,
67
+ font=dict(size=18, color="red")
68
+ )
69
+ fig.update_layout(height=600, width=900)
70
+ return fig
71
+ datasets = [row.get('Dataset', '') for row in selected_rows_data]
72
+ unique_datasets = set(datasets)
73
+ if len(unique_datasets) > 1:
74
+ fig = go.Figure()
75
+ fig.add_annotation(
76
+ text="Error: Please select rows from the same dataset!",
77
+ xref="paper", yref="paper",
78
+ x=0.5, y=0.5, showarrow=False,
79
+ font=dict(size=18, color="red")
80
+ )
81
+ fig.update_layout(height=600, width=900)
82
+ return fig
83
+ dataset_name = datasets[0] if datasets else "Unknown"
84
+
85
+ # Extract metrics from selected rows
86
+ data = {}
87
+ for row in selected_rows_data:
88
+ # Extract model name from HTML or use as-is
89
+ model_name = row.get('Model', 'Unknown')
90
+ if isinstance(model_name, str) and 'href' in model_name:
91
+ try:
92
+ model_name = model_name.split('>', 1)[1].split('<', 1)[0]
93
+ except:
94
+ pass
95
+
96
+ # Format legend name: extract name after "/" and add method
97
+ method = row.get('Method', '')
98
+ if isinstance(model_name, str) and '/' in model_name:
99
+ legend_name = model_name.split('/')[-1] # Get part after last /
100
+ else:
101
+ legend_name = str(model_name)
102
+
103
+ # Add method suffix
104
+ if method and method not in ['Unknown', '-', '']:
105
+ legend_name = f"{legend_name}-{method}"
106
+
107
+ # Get metrics
108
+ acc = row.get('Accuracy(%)', 0)
109
+ cost = row.get('Cost($)', 0)
110
+ throughput = row.get('Decoding T/s', 0)
111
+
112
+ # Convert to float if needed
113
+ try:
114
+ acc = float(acc) if acc not in [None, '-', ''] else 0
115
+ cost = float(cost) if cost not in [None, '-', ''] else 0
116
+ throughput = float(throughput) if throughput not in [None, '-', ''] else 0
117
+ except:
118
+ acc, cost, throughput = 0, 0, 0
119
+
120
+ data[legend_name] = {
121
+ 'accuracy': acc / 100.0 if acc > 1 else acc, # Normalize to 0-1
122
+ 'cost': cost,
123
+ 'throughput': throughput
124
+ }
125
+
126
+ # Get min/max for normalization
127
+ throughputs = [v['throughput'] for v in data.values()]
128
+ costs = [v['cost'] for v in data.values()]
129
+ accs = [v['accuracy'] for v in data.values()]
130
+
131
+ tp_min, tp_max = (min(throughputs), max(throughputs)) if throughputs else (0, 1)
132
+ cost_max = max(costs) if costs else 1
133
+ acc_min, acc_max = (min(accs), 1.0) if accs else (0, 1)
134
+
135
+ baseline = 20
136
+ categories = ['Throughput (T/s)', 'Cost ($)', 'Accuracy', 'Throughput (T/s)']
137
+
138
+ fig = go.Figure()
139
+
140
+ for system, values in data.items():
141
+ raw_vals = [values['throughput'], values['cost'], values['accuracy']]
142
+ norm_vals = [
143
+ normalize(values['throughput'], tp_min, tp_max, baseline),
144
+ normalize_cost(values['cost'], cost_max, baseline),
145
+ normalize(values['accuracy'], acc_min, acc_max, baseline)
146
+ ]
147
+ norm_vals += [norm_vals[0]] # Close the loop
148
+
149
+ hovertext = [
150
+ f"Throughput: {raw_vals[0]:.2f} T/s",
151
+ f"Cost: ${raw_vals[1]:.2f}",
152
+ f"Accuracy: {raw_vals[2]*100:.2f}%",
153
+ f"Throughput: {raw_vals[0]:.2f} T/s"
154
+ ]
155
+
156
+ fig.add_trace(go.Scatterpolar(
157
+ r=norm_vals,
158
+ theta=categories,
159
+ fill='toself',
160
+ name=system,
161
+ text=hovertext,
162
+ hoverinfo='text+name',
163
+ line=dict(width=2)
164
+ ))
165
+
166
+ fig.update_layout(
167
+ title=f"CAP Radar Plot: {dataset_name}",
168
+ polar=dict(
169
+ radialaxis=dict(visible=True, range=[0, 100], tickfont=dict(size=10)),
170
+ angularaxis=dict(
171
+ tickfont=dict(size=12),
172
+ rotation=30,
173
+ direction='clockwise'
174
+ ),
175
+ ),
176
+ legend=dict(orientation='h', yanchor='bottom', y=-0.2, xanchor='center', x=0.5),
177
+ margin=dict(t=100, b=120, l=100, r=1000),
178
+ height=700,
179
+ width=1500,
180
+ paper_bgcolor='white',
181
+ plot_bgcolor='white'
182
+ )
183
+
184
+ return fig
185
+
186
+
187
  def json_to_row(path: str, metrics: dict) -> dict:
188
  model_name = metrics.get("model_name")
189
  if not model_name:
 
224
  "Model type": model_type,
225
  "Precision": precision,
226
  "E2E(s)": f2(e2e_s),
 
227
  "GPU": gpu_type,
228
  "Accuracy(%)": pct(acc),
229
  "Cost($)": cost,
 
235
  "Decoding<br>S-MFU(%)": pct(metrics.get("decoding_smfu")),
236
  "TTFT(s)": f2(metrics.get("ttft")),
237
  "TPOT(s)": f2(metrics.get("tpot")),
238
+ "Batch size": batch_size, # moved to tail
239
  }
240
  return row
241
 
 
380
 
381
  if df.empty:
382
  empty_html = "<p>No records found.</p>"
383
+ return empty_html, []
384
 
385
  df = df.fillna("-")
386
  raw_models = set()
 
405
  links.append(str(name))
406
  models_str = ", ".join(links)
407
 
408
+ # Insert row number column at the beginning for easy reference
409
+ df.insert(0, 'Row #', range(len(df)))
410
+
411
+ # Create HTML table
412
  table_html = f'<div class="table-container">{df.to_html(escape=False, index=False, classes="metrics-table")}</div>'
413
+ df_without_rownum = df.drop('Row #', axis=1)
414
+ df_dict = df_without_rownum.to_dict('records')
415
+ return table_html, df_dict
416
 
417
 
418
  def auto_refresh_from_dir(
 
434
  )
435
 
436
 
437
+ def update_radar_plot(df_data: list, selected_indices: list):
438
+ """Update radar plot based on selected row indices."""
439
+ if not selected_indices or not df_data:
440
+ return generate_radar_plot([])
441
+
442
+ # Get selected rows (limit to 3)
443
+ selected_rows = [df_data[i] for i in selected_indices[:3] if i < len(df_data)]
444
+ return generate_radar_plot(selected_rows)
445
+
446
+
447
+ def parse_and_generate_plot(df_data: list, indices_str: str):
448
+ """Parse comma-separated indices and generate radar plot."""
449
+ if not indices_str or not indices_str.strip():
450
+ return generate_radar_plot([])
451
+
452
+ try:
453
+ # Parse comma-separated indices
454
+ indices = [int(idx.strip()) for idx in indices_str.split(',') if idx.strip()]
455
+ # Limit to 3 rows
456
+ indices = indices[:3]
457
+ # Get selected rows
458
+ selected_rows = [df_data[i] for i in indices if 0 <= i < len(df_data)]
459
+ return generate_radar_plot(selected_rows)
460
+ except (ValueError, IndexError):
461
+ return generate_radar_plot([])
462
+
463
+
464
+ def on_table_select(df, evt: gr.SelectData):
465
+ """Handle table row selection."""
466
+ return evt.index
467
+
468
+
469
  # Gradio UI
470
 
471
  def build_app() -> gr.Blocks:
 
474
  body {
475
  background-color: #f5f7fa !important;
476
  }
477
+
478
+ /* Row number column styling */
479
+ .metrics-table th:first-child,
480
+ .metrics-table td:first-child {
481
+ width: 60px !important;
482
+ text-align: center !important;
483
+ padding: 8px !important;
484
+ font-weight: 600 !important;
485
+ background-color: #f0f0f0 !important;
486
+ }
487
 
488
  /* The outer Group container */
489
  .search-box {
 
780
  value=["bfloat16", "fp8"],
781
  )
782
 
783
+ with gr.Accordion("πŸ“– About Tasks & Metrics", open=True):
784
  gr.Markdown(
785
  "### Tasks\n"
786
  "- **GSM8K** β€” Mathematics Problem-Solving ([paper](https://arxiv.org/abs/2110-14168))\n"
 
800
  elem_classes="info-section"
801
  )
802
 
803
+ # Right side - Table with selection and Radar Plot below
804
  with gr.Column(scale=5):
805
  leaderboard_output = gr.HTML(label="πŸ“ˆ Results")
806
+
807
+ with gr.Group(elem_classes="filter-section"):
808
+ gr.Markdown("### πŸ“Š CAP Radar Plot")
809
+ gr.Markdown(
810
+ "**How to use:** Look at the 'Row #' column in the table above. "
811
+ "Enter up to 3 row numbers below (separated by commas) and click Generate."
812
+ )
813
+
814
+ with gr.Row():
815
+ row_indices_input = gr.Textbox(
816
+ label="Row Numbers to Compare",
817
+ placeholder="Example: 0,1,2",
818
+ elem_id="row_indices_input",
819
+ scale=3
820
+ )
821
+ generate_btn = gr.Button("🎯 Generate", variant="primary", scale=1, size="lg")
822
+
823
+ with gr.Row():
824
+ with gr.Column(scale=1):
825
+ pass
826
+ with gr.Column(scale=5):
827
+ radar_plot = gr.Plot(label="", value=generate_radar_plot([]))
828
+ with gr.Column(scale=1):
829
+ pass
830
+
831
+ df_data_state = gr.State([])
832
+
833
  demo.load(
834
  fn=auto_refresh_from_dir,
835
  inputs=[dir_path, task_filter, framework_filter, model_type_filter, precision_filter, search_input],
836
+ outputs=[leaderboard_output, df_data_state],
837
  )
838
 
839
  search_input.change(
840
  fn=load_from_dir,
841
  inputs=[dir_path, task_filter, framework_filter, model_type_filter, precision_filter, search_input],
842
+ outputs=[leaderboard_output, df_data_state],
843
  )
844
 
845
  task_filter.change(
846
  fn=load_from_dir,
847
  inputs=[dir_path, task_filter, framework_filter, model_type_filter, precision_filter, search_input],
848
+ outputs=[leaderboard_output, df_data_state],
849
  )
850
  framework_filter.change(
851
  fn=load_from_dir,
852
  inputs=[dir_path, task_filter, framework_filter, model_type_filter, precision_filter, search_input],
853
+ outputs=[leaderboard_output, df_data_state],
854
  )
855
  model_type_filter.change(
856
  fn=load_from_dir,
857
  inputs=[dir_path, task_filter, framework_filter, model_type_filter, precision_filter, search_input],
858
+ outputs=[leaderboard_output, df_data_state],
859
  )
860
  precision_filter.change(
861
  fn=load_from_dir,
862
  inputs=[dir_path, task_filter, framework_filter, model_type_filter, precision_filter, search_input],
863
+ outputs=[leaderboard_output, df_data_state],
864
  )
865
 
866
+ # Generate plot on button click
867
+ generate_btn.click(
868
+ fn=parse_and_generate_plot,
869
+ inputs=[df_data_state, row_indices_input],
870
+ outputs=[radar_plot]
871
+ )
872
+
873
  timer = gr.Timer(60.0)
874
  timer.tick(
875
  fn=auto_refresh_from_dir,
876
  inputs=[dir_path, task_filter, framework_filter, model_type_filter, precision_filter, search_input],
877
+ outputs=[leaderboard_output, df_data_state],
878
  )
879
 
880
  return demo
requirements.txt CHANGED
@@ -1,4 +1,6 @@
1
  gradio>=4.44.0
2
  pandas
3
  datasets
4
- huggingface_hub<0.25.0
 
 
 
1
  gradio>=4.44.0
2
  pandas
3
  datasets
4
+ huggingface_hub<0.25.0
5
+ plotly>=5.0.0
6
+ kaleido>=0.2.1