AppleSwing's picture
Update app.py
6eeb754 verified
raw
history blame
30.9 kB
#!/usr/bin/env python
import os
import json
from typing import List, Tuple
# Pin the Gradio UI language to English. This is set before the
# `import gradio` further down — presumably so Gradio picks it up at
# import time (TODO confirm against the Gradio version in use).
os.environ["GRADIO_LANGUAGE"] = "en"
# Hugging Face dataset repo ID holding the result JSON files; it is later
# interpolated into an `hf://datasets/<RESULT_DIR>/**/*.json` pattern.
RESULT_DIR = os.environ.get("MOECAP_RESULT_DIR")
if not RESULT_DIR:
    # Fail fast at import time: the dashboard has nothing to show without it.
    raise RuntimeError(
        "MOECAP_RESULT_DIR is not set. Please set MOECAP_RESULT_DIR (HF Repo ID) before running app.py"
    )
import gradio as gr
import pandas as pd
from datasets import load_dataset
import plotly.graph_objects as go
def f2(x):
    """Round numeric input to two decimals; pass anything else through.

    Ints and floats (bools included, since ``bool`` is an ``int``) are
    converted to ``float`` and rounded. Every other value, such as ``None``
    or a string, is returned unchanged.
    """
    is_numeric = isinstance(x, (int, float))
    return round(float(x), 2) if is_numeric else x
def normalize(val, vmin, vmax, baseline=20):
    """Map ``val`` linearly from [vmin, vmax] onto [baseline, 100].

    Higher inputs give higher outputs. When the range is degenerate
    (``vmax == vmin``) a fixed ``baseline + 40`` is returned instead of
    dividing by zero.
    """
    if vmax != vmin:
        fraction = (val - vmin) / (vmax - vmin)
        return baseline + fraction * (100 - baseline)
    return baseline + 40
def normalize_reversed(val, vmin, vmax, baseline=20):
    """Map ``val`` from [vmin, vmax] onto [baseline, 100], lower input scoring higher.

    ``vmin`` maps to 100 and ``vmax`` maps to ``baseline``. A degenerate
    range (``vmax == vmin``) yields the fixed value ``baseline + 40``.
    """
    if vmax != vmin:
        fraction = (vmax - val) / (vmax - vmin)
        return baseline + fraction * (100 - baseline)
    return baseline + 40
def normalize_cost(val, max_tick, baseline=20):
    """Score a cost on the [baseline, 100] scale, where cheaper is better.

    ``val`` is capped at ``max_tick``. A cost of 0 scores 100 and a cost at
    or above ``max_tick`` scores ``baseline``. If ``max_tick`` is 0 the fixed
    value ``baseline + 40`` is returned to avoid dividing by zero.
    """
    if max_tick != 0:
        capped = min(val, max_tick)
        headroom = (max_tick - capped) / max_tick
        return baseline + headroom * (100 - baseline)
    return baseline + 40
def _message_figure(text: str, font: dict) -> go.Figure:
    """Build a blank white 800x600 figure showing ``text`` centred on the canvas."""
    fig = go.Figure()
    fig.add_annotation(
        text=text,
        xref="paper", yref="paper",
        x=0.5, y=0.5, showarrow=False,
        font=font,
        xanchor='center',
        yanchor='middle'
    )
    fig.update_layout(
        height=600,
        width=800,
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        paper_bgcolor='white',
        plot_bgcolor='white',
        margin=dict(t=40, b=40, l=40, r=40)
    )
    return fig


def _metric_as_float(value) -> float:
    """Convert one table cell to a finite float; blanks, '-' and junk become 0.0."""
    import math  # local import keeps this block self-contained

    if value in (None, '-', ''):
        return 0.0
    try:
        number = float(value)
    except (TypeError, ValueError):
        return 0.0
    # NaN/inf would poison the min/max used for normalisation below.
    return number if math.isfinite(number) else 0.0


def generate_radar_plot(selected_rows_data: List[dict]) -> go.Figure:
    """Generate a CAP radar plot (throughput, cost, accuracy) from selected rows.

    Args:
        selected_rows_data: One to three table rows as dicts. The keys read
            are 'Model', 'Dataset', 'Method', 'Accuracy(%)', 'Cost($)' and
            'Decoding T/s'. All rows must share the same 'Dataset'.

    Returns:
        A plotly figure. If the selection is empty, has more than three rows,
        or mixes datasets, the figure contains only an explanatory message.

    Each axis is normalised onto a 20-100 scale. Throughput is min-max
    scaled across the selected rows. Cost is scored relative to the most
    expensive selected row, so lower cost plots further out. Accuracy is
    scaled between the lowest selected accuracy and 100%.
    """
    if not selected_rows_data:
        return _message_figure(
            "Please select 1-3 rows from the table to generate radar plot",
            dict(size=16),
        )
    if len(selected_rows_data) > 3:
        return _message_figure(
            "Error: Please select no more than 3 rows!",
            dict(size=18, color="red"),
        )
    datasets = [row.get('Dataset', '') for row in selected_rows_data]
    if len(set(datasets)) > 1:
        return _message_figure(
            "Error: Please select rows from the same dataset!",
            dict(size=18, color="red"),
        )
    dataset_name = datasets[0]

    # Extract metrics from the selected rows, keyed by legend label.
    data = {}
    for row in selected_rows_data:
        # The Model cell may be an "<a href=...>name</a>" anchor; keep the text.
        model_name = row.get('Model', 'Unknown')
        if isinstance(model_name, str) and 'href' in model_name:
            try:
                model_name = model_name.split('>', 1)[1].split('<', 1)[0]
            except IndexError:
                # No '>' after all: fall back to the raw cell text.
                pass

        # Legend label: the part after the last "/" plus the method, if any.
        method = row.get('Method', '')
        if isinstance(model_name, str) and '/' in model_name:
            legend_name = model_name.split('/')[-1]
        else:
            legend_name = str(model_name)
        if method and method not in ['Unknown', '-', '']:
            legend_name = f"{legend_name}-{method}"

        # Two rows can share model and method (e.g. differ only in precision).
        # Give each its own label so one trace does not overwrite the other.
        unique_name = legend_name
        suffix = 2
        while unique_name in data:
            unique_name = f"{legend_name} ({suffix})"
            suffix += 1

        # Convert each metric on its own so one bad cell does not zero the rest.
        acc = _metric_as_float(row.get('Accuracy(%)', 0))
        cost = _metric_as_float(row.get('Cost($)', 0))
        throughput = _metric_as_float(row.get('Decoding T/s', 0))

        data[unique_name] = {
            # The column is a percentage, so always rescale to a 0-1 fraction.
            # Guessing from the magnitude would plot 0.8% as 80%.
            'accuracy': acc / 100.0,
            'cost': cost,
            'throughput': throughput
        }

    # Min/max across the selection drive the normalisation.
    throughputs = [v['throughput'] for v in data.values()]
    costs = [v['cost'] for v in data.values()]
    accs = [v['accuracy'] for v in data.values()]
    tp_min, tp_max = min(throughputs), max(throughputs)
    cost_max = max(costs)
    acc_min, acc_max = min(accs), 1.0
    baseline = 20
    # The first category is repeated so the polygon closes.
    categories = ['Throughput (T/s)', 'Cost ($)', 'Accuracy', 'Throughput (T/s)']

    fig = go.Figure()
    for system, values in data.items():
        raw_vals = [values['throughput'], values['cost'], values['accuracy']]
        norm_vals = [
            normalize(values['throughput'], tp_min, tp_max, baseline),
            normalize_cost(values['cost'], cost_max, baseline),
            normalize(values['accuracy'], acc_min, acc_max, baseline)
        ]
        norm_vals += [norm_vals[0]]  # Close the loop
        # Hover shows the raw values, not the normalised radii.
        hovertext = [
            f"Throughput: {raw_vals[0]:.2f} T/s",
            f"Cost: ${raw_vals[1]:.2f}",
            f"Accuracy: {raw_vals[2]*100:.2f}%",
            f"Throughput: {raw_vals[0]:.2f} T/s"
        ]
        fig.add_trace(go.Scatterpolar(
            r=norm_vals,
            theta=categories,
            fill='toself',
            name=system,
            text=hovertext,
            hoverinfo='text+name',
            line=dict(width=2)
        ))

    fig.update_layout(
        title=dict(
            text=f"CAP Radar Plot: {dataset_name}",
            x=0.5,
            xanchor='center',
            font=dict(size=18)
        ),
        polar=dict(
            radialaxis=dict(visible=True, range=[0, 100], tickfont=dict(size=11)),
            angularaxis=dict(
                tickfont=dict(size=13),
                rotation=30,
                direction='clockwise'
            ),
        ),
        legend=dict(
            orientation='h',
            yanchor='bottom',
            y=-0.15,
            xanchor='center',
            x=0.5,
            font=dict(size=12)
        ),
        margin=dict(t=80, b=100, l=80, r=80),
        height=650,
        width=800,
        paper_bgcolor='white',
        plot_bgcolor='white'
    )
    return fig
def json_to_row(path: str, metrics: dict) -> dict:
    """Convert one result-metrics dict into a leaderboard table row.

    Args:
        path: Identifier of the source record. It is accepted for the
            callers' convenience but not used in the row.
        metrics: Parsed result JSON. The keys read are 'model_name',
            'dataset', 'method', 'precision', 'model_type', 'e2e_s',
            'batch_size', 'gpu_type', 'cost', 'exact_match', 'correct',
            'total', 'decoding_throughput', 'prefill_tp', 'prefill_smbu',
            'prefill_smfu', 'decoding_smbu', 'decoding_smfu', 'ttft' and
            'tpot'. Missing keys give 'Unknown', '' or None.

    Returns:
        A dict keyed by the table's column headers. Accuracy and the
        S-MBU/S-MFU values are multiplied by 100 and rounded to two
        decimals. Model names containing '/' become a Hugging Face link.
    """
    import html  # local import: escapes the model name placed into raw HTML

    model_name = metrics.get("model_name")
    if not model_name:
        model_name = "unknown-model"
    dataset = metrics.get("dataset", "Unknown")
    method = metrics.get("method", "Unknown")
    precision = metrics.get("precision", "Unknown")
    model_type = metrics.get("model_type", "Unknown")
    e2e_s = metrics.get("e2e_s", None)
    batch_size = metrics.get("batch_size", None)
    gpu_type = metrics.get("gpu_type", "")
    cost = metrics.get("cost", None)

    # Prefer correct/total when both are usable; otherwise use exact_match.
    em = metrics.get("exact_match")
    correct = metrics.get("correct")
    total = metrics.get("total")
    if isinstance(correct, (int, float)) and isinstance(total, (int, float)) and total > 0:
        acc = correct / total
    else:
        acc = em

    def pct(x):
        """Turn a fraction into a percentage rounded to 2 decimals, or None."""
        return round(x * 100, 2) if isinstance(x, (int, float)) else None

    # The table is rendered with escape=False, and the name comes from an
    # uploaded or remote JSON file, so escape it to keep it from injecting markup.
    # NOTE(review): the other text cells (dataset, method, GPU, ...) are still
    # emitted as-is; confirm whether they need the same treatment.
    if isinstance(model_name, str):
        safe_name = html.escape(model_name, quote=True)
        if "/" in model_name:
            hf_url = f"https://huggingface.co/{safe_name}"
            model_cell = f"<a href='{hf_url}' target='_blank'>{safe_name}</a>"
        else:
            model_cell = safe_name
    else:
        model_cell = model_name

    row = {
        "Model": model_cell,
        "Dataset": dataset,
        "Method": method,
        "Model type": model_type,
        "Precision": precision,
        "E2E(s)": f2(e2e_s),
        "GPU": gpu_type,
        "Accuracy(%)": pct(acc),
        "Cost($)": cost,
        "Decoding T/s": f2(metrics.get("decoding_throughput")),
        "Prefill T/s": f2(metrics.get("prefill_tp")),
        "Prefill<br>S-MBU(%)": pct(metrics.get("prefill_smbu")),
        "Prefill<br>S-MFU(%)": pct(metrics.get("prefill_smfu")),
        "Decoding<br>S-MBU(%)": pct(metrics.get("decoding_smbu")),
        "Decoding<br>S-MFU(%)": pct(metrics.get("decoding_smfu")),
        "TTFT(s)": f2(metrics.get("ttft")),
        "TPOT(s)": f2(metrics.get("tpot")),
        "Batch size": batch_size,  # kept as the last column
    }
    return row
def _plain_model_name(cell):
    """Return the link text of an ``<a href=...>`` Model cell, or the cell itself."""
    if isinstance(cell, str) and "href" in cell:
        try:
            return cell.split(">", 1)[1].split("<", 1)[0]
        except IndexError:
            # No '>' in the cell: treat the whole cell as the name.
            return cell
    return cell


def build_leaderboard_from_files(files: List[gr.File], prev_rows: list | None = None):
    """Build the leaderboard summary and HTML table from uploaded JSON files.

    Args:
        files: Uploaded result files. Each item is either an object with a
            ``.name`` path attribute or a plain path string. May be empty
            or None.
        prev_rows: Rows accumulated from earlier uploads. New rows are
            appended after them.

    Returns:
        A ``(summary_markdown, table_html, all_rows)`` tuple. If there are
        no rows at all, a placeholder message, placeholder HTML and an empty
        list are returned.

    Files that cannot be read or parsed are skipped with a printed warning,
    so one bad upload does not discard the rest.
    """
    if prev_rows is None:
        prev_rows = []

    new_rows = []
    for f in files or []:
        # Gradio may pass tempfile-like wrappers or bare path strings.
        path = getattr(f, "name", f)
        try:
            with open(path, "r", encoding="utf-8") as fp:
                metrics = json.load(fp)
            new_rows.append(json_to_row(path, metrics))
        except (OSError, ValueError, TypeError, AttributeError) as exc:
            # ValueError covers JSON and unicode decode errors. TypeError and
            # AttributeError cover JSON whose top level is not an object.
            print(f"Skipping {path!r}: {exc}")
            continue

    all_rows = prev_rows + new_rows
    if not all_rows:
        empty_html = "<p>No files loaded.</p>"
        return "No files uploaded.", empty_html, []

    df = pd.DataFrame(all_rows)

    # Distinct model names, linked to the Hub when they look like "org/name".
    raw_models = {_plain_model_name(cell) for cell in df["Model"].tolist()}
    links = []
    for name in sorted(raw_models, key=str):
        if isinstance(name, str) and "/" in name:
            hf_url = f"https://huggingface.co/{name}"
            links.append(f"[{name}]({hf_url})")
        else:
            links.append(str(name))
    models_str = ", ".join(links)

    summary_md = f"**Loaded {len(all_rows)} result files.** \n**Models:** {models_str}"
    # escape=False keeps the anchor in the Model column and the <br> in the
    # headers as real markup.
    table_html = f'<div class="table-container">{df.to_html(escape=False, index=False, classes="metrics-table")}</div>'
    return summary_md, table_html, all_rows
def load_from_dir(
    dir_path: str,
    selected_tasks: List[str] | None = None,
    selected_frameworks: List[str] | None = None,
    selected_model_types: List[str] | None = None,
    selected_precisions: List[str] | None = None,
    search_keyword: str = "",
    force_refresh: bool = False,
):
    """Load result JSON files from a Hugging Face dataset repo and render them.

    Args:
        dir_path: Dataset repo ID. Files are fetched from
            ``hf://datasets/<dir_path>/**/*.json``.
        selected_tasks: Allowed 'Dataset' values (case-insensitive).
            ``None`` disables this filter; an empty list matches nothing.
        selected_frameworks: Allowed 'Method' values, same rules.
        selected_model_types: Allowed 'Model type' values, same rules.
        selected_precisions: Allowed 'Precision' values, same rules.
        search_keyword: Case-insensitive literal substring. A row is kept
            if any of its cells, rendered as text, contains it.
        force_refresh: When true, the dataset is re-downloaded.

    Returns:
        A ``(table_html, records)`` tuple on every path. ``records`` is the
        list of filtered rows as dicts, without the 'Row #' column, and is
        empty when nothing could be loaded or nothing matched. The event
        handlers bind this function to two outputs, so a single return
        value would make them fail.
    """
    try:
        pattern = f"hf://datasets/{dir_path}/**/*.json"
        dl_mode = "force_redownload" if force_refresh else None
        print(f"Fetching from {pattern} (mode={dl_mode})...")
        ds = load_dataset(
            "json",
            data_files={"train": pattern},
            split="train",
            download_mode=dl_mode,
        )
    except Exception as e:
        # UI boundary: any fetch or parse failure becomes a message, not a crash.
        print(f"Failed to load results from {dir_path}: {e}")
        return "<p>No files loaded or Dataset not found.</p>", []

    rows = []
    for i, example in enumerate(ds):
        if isinstance(example, dict):
            # Records may wrap the metrics under "metrics" or "json".
            metrics = example.get("metrics") or example.get("json") or example
        else:
            metrics = example
        rows.append(json_to_row(f"{dir_path}#{i}", metrics))
    if not rows:
        return "<p>No records found.</p>", []

    df = pd.DataFrame(rows)

    # Checkbox filters: Dataset, Method (framework), Model type, Precision.
    column_filters = (
        ("Dataset", selected_tasks),
        ("Method", selected_frameworks),
        ("Model type", selected_model_types),
        ("Precision", selected_precisions),
    )
    for column, selected in column_filters:
        if selected is not None:
            lower_selected = [str(x).lower() for x in selected]
            df = df[df[column].astype(str).str.lower().isin(lower_selected)]

    # Keyword search across all columns. regex=False matches the text
    # literally, so input such as "(" or "c++" cannot raise or mis-match.
    if search_keyword and search_keyword.strip():
        keyword_lower = search_keyword.strip().lower()
        text = df.astype(str)
        mask = pd.Series(False, index=df.index)
        for column in text.columns:
            mask |= text[column].str.lower().str.contains(keyword_lower, regex=False)
        df = df[mask]

    if df.empty:
        return "<p>No records found.</p>", []

    df = df.fillna("-")
    # Take the records before the display-only 'Row #' column is added.
    df_dict = df.to_dict('records')
    # Row numbers give users a handle for the radar-plot input box.
    df.insert(0, 'Row #', range(len(df)))
    table_html = f'<div class="table-container">{df.to_html(escape=False, index=False, classes="metrics-table")}</div>'
    return table_html, df_dict
def auto_refresh_from_dir(
    dir_path: str,
    selected_tasks: List[str] | None = None,
    selected_frameworks: List[str] | None = None,
    selected_model_types: List[str] | None = None,
    selected_precisions: List[str] | None = None,
    search_keyword: str = "",
):
    """Reload the leaderboard with a forced re-download.

    Takes the same filter arguments as ``load_from_dir`` and forwards them
    unchanged, always passing ``force_refresh=True``. Returns whatever
    ``load_from_dir`` returns.
    """
    filters = dict(
        selected_tasks=selected_tasks,
        selected_frameworks=selected_frameworks,
        selected_model_types=selected_model_types,
        selected_precisions=selected_precisions,
        search_keyword=search_keyword,
    )
    return load_from_dir(dir_path, force_refresh=True, **filters)
def update_radar_plot(df_data: list, selected_indices: list):
    """Update the radar plot from a list of selected row indices.

    Args:
        df_data: Table rows as dicts, in display order.
        selected_indices: Row positions to plot. Only the first three are
            considered.

    Returns:
        The figure from ``generate_radar_plot``. This is the placeholder
        figure when nothing is selected or no data is loaded.
    """
    if not selected_indices or not df_data:
        return generate_radar_plot([])
    # Reject negative positions as well as too-large ones. A negative index
    # would otherwise wrap around and plot a row counted from the end.
    selected_rows = [
        df_data[i]
        for i in selected_indices[:3]
        if 0 <= i < len(df_data)
    ]
    return generate_radar_plot(selected_rows)
def parse_and_generate_plot(df_data: list, indices_str: str):
    """Turn a comma-separated list of row numbers into a radar plot.

    Blank input gives the placeholder figure. Otherwise every non-empty
    token is parsed as an int, only the first three are kept, and positions
    outside ``df_data`` are dropped before plotting. A token that is not an
    integer also gives the placeholder figure.
    """
    if not (indices_str and indices_str.strip()):
        return generate_radar_plot([])
    try:
        # Parse all tokens first, so a bad token anywhere rejects the input.
        requested = []
        for token in indices_str.split(','):
            if token.strip():
                requested.append(int(token.strip()))
        chosen = [
            df_data[pos]
            for pos in requested[:3]
            if 0 <= pos < len(df_data)
        ]
        return generate_radar_plot(chosen)
    except (ValueError, IndexError):
        return generate_radar_plot([])
def on_table_select(df, evt: gr.SelectData):
    """Return the index carried by a table selection event.

    ``df`` is accepted because the handler signature requires it, but it is
    not used.
    """
    selected_index = evt.index
    return selected_index
# Gradio UI
def build_app() -> gr.Blocks:
    """Assemble the MoE-CAP dashboard and return it without launching.

    Layout: the left column holds the search box, four checkbox filters and
    an "About" accordion. The right column holds the results table and the
    radar-plot controls.

    Wiring:
        * On page load and every 60 seconds the table is rebuilt with a
          forced re-download (``auto_refresh_from_dir``).
        * Any change to the search box or a filter rebuilds the table from
          the cached dataset (``load_from_dir``).
        * The Generate button plots the rows named in the text box.

    Returns:
        The configured ``gr.Blocks`` instance.
    """
    row_css = """
    /* Force light theme everywhere */
    body {
        background-color: #f5f7fa !important;
    }
    /* Row number column styling */
    .metrics-table th:first-child,
    .metrics-table td:first-child {
        width: 60px !important;
        text-align: center !important;
        padding: 8px !important;
        font-weight: 600 !important;
        background-color: #f0f0f0 !important;
    }
    /* The outer Group container */
    .search-box {
        background: white !important;
        padding: 16px !important;
        border-radius: 6px;
        border: 2px solid #e1e4e8 !important;
        box-shadow: 0 1px 3px rgba(0, 0, 0, 0.06);
        margin-bottom: 16px;
    }
    /* Reset the internal Gradio Block styles so they fit in the white box */
    .search-box .block {
        background: transparent !important;
        border: none !important;
        padding: 0 !important;
    }
    /* Style the Label Text (🔍 Search) */
    .search-box label span {
        color: #24292e !important;
        font-weight: 600;
        font-size: 14px;
        margin-bottom: 8px;
        background: transparent !important;
    }
    /* Style the actual Input Field */
    .search-box input.scroll-hide {
        background-color: white !important;
        color: #24292e !important;
        border: 1.5px solid #e1e4e8 !important;
        border-radius: 4px !important;
        padding: 10px !important;
        box-shadow: none !important;
    }
    /* Fix focus state */
    .search-box input.scroll-hide:focus {
        border-color: #0366d6 !important;
        ring: 0 !important;
        outline: none !important;
    }
    .gradio-container {
        max-width: 100% !important;
        padding: 20px !important;
        background-color: #f5f7fa !important;
    }
    /* Override all dark backgrounds */
    .gradio-container .block,
    .gradio-container .form,
    .gradio-container fieldset,
    .gradio-container .input-block,
    .gradio-container .wrap,
    .gradio-container .gr-box,
    .gradio-container .gr-form,
    .gradio-container .gr-input {
        background-color: white !important;
        border-color: #e1e4e8 !important;
    }
    .gradio-container label {
        background-color: transparent !important;
        color: #24292e !important;
    }
    /* Remove any potential dark wrappers */
    .gradio-container > div,
    .gradio-container .container {
        background-color: transparent !important;
    }
    /* Force all text to be dark */
    .gradio-container,
    .gradio-container label,
    .gradio-container p,
    .gradio-container span,
    .gradio-container div {
        color: #24292e !important;
    }
    /* Table styling */
    .gradio-container table.metrics-table th,
    .gradio-container table.metrics-table td {
        padding: 10px 14px;
        border: 1.5px solid #e1e4e8;
        white-space: nowrap;
        font-size: 13px;
        text-align: left;
        color: #24292e !important;
    }
    .gradio-container table.metrics-table th {
        background: linear-gradient(to bottom, #fafbfc, #f6f8fa);
        font-weight: 600;
        color: #24292e !important;
        position: sticky;
        top: 0;
        z-index: 10;
        border-bottom: 2px solid #d1d5da;
    }
    .gradio-container table.metrics-table tbody tr:nth-child(even) {
        background-color: #f6f8fa;
    }
    .gradio-container table.metrics-table tbody tr:hover {
        background-color: #e1e4e8;
    }
    .gradio-container table.metrics-table {
        border-collapse: collapse;
        width: 100%;
        background: white;
    }
    .gradio-container table.metrics-table a {
        color: #0366d6 !important;
        text-decoration: none;
    }
    .gradio-container table.metrics-table a:hover {
        color: #0366d6 !important;
        text-decoration: underline;
    }
    /* Scrollable table container */
    .table-container {
        overflow-x: auto;
        overflow-y: auto;
        max-height: 75vh;
        border: 2px solid #e1e4e8;
        border-radius: 6px;
        background: white;
        box-shadow: 0 1px 3px rgba(0, 0, 0, 0.06);
    }
    /* Filter section styling */
    .filter-section {
        background: white !important;
        padding: 0 !important;
        border-radius: 6px;
        border: 2px solid #e1e4e8 !important;
        box-shadow: 0 1px 3px rgba(0, 0, 0, 0.06);
    }
    .filter-section * {
        color: #24292e !important;
    }
    .filter-section .wrap,
    .filter-section .block,
    .filter-section .container,
    .filter-section .group,
    .filter-section > div,
    .filter-section > div > div {
        background: transparent !important;
    }
    .filter-section .wrap {
        padding: 20px !important;
    }
    .filter-section label {
        background: transparent !important;
        color: #24292e !important;
    }
    .filter-section fieldset {
        background: transparent !important;
        border-color: #e1e4e8 !important;
    }
    /* Accordion styling */
    .gradio-container .accordion {
        background: white !important;
        border: 2px solid #e1e4e8 !important;
        border-radius: 6px !important;
        box-shadow: 0 1px 3px rgba(0, 0, 0, 0.06);
    }
    .gradio-container .accordion * {
        color: #24292e !important;
    }
    .gradio-container .accordion label {
        background: transparent !important;
        color: #24292e !important;
    }
    .gradio-container .accordion button {
        background: transparent !important;
        color: #24292e !important;
    }
    /* Info section */
    .info-section {
        padding: 16px;
        background: white !important;
    }
    /* Make text in info section dark */
    .info-section p,
    .info-section li,
    .info-section ul,
    .info-section h3,
    .info-section strong,
    .info-section * {
        color: #24292e !important;
    }
    .info-section a {
        color: #0366d6 !important;
    }
    /* Override any dark backgrounds in groups and accordions */
    .gradio-container .group,
    .gradio-container .accordion,
    .gradio-container .panel {
        background-color: white !important;
    }
    /* Heading styling */
    .gradio-container h1 {
        color: #24292e !important;
        font-weight: 700;
        margin-bottom: 24px;
    }
    .gradio-container h3 {
        color: #24292e !important;
        font-weight: 600;
        margin-bottom: 16px;
    }
    /* Checkbox styling */
    .gradio-container input[type="checkbox"] {
        accent-color: #0366d6 !important;
    }
    """
    # Use Gradio's default (light) theme explicitly
    with gr.Blocks(title="MoE-CAP Dashboard", css=row_css, theme=gr.themes.Default()) as demo:
        gr.Markdown("# MoE-CAP Dashboard")
        with gr.Row():
            # Left side - Filters (narrower)
            with gr.Column(scale=2):
                with gr.Group(elem_classes="search-box"):
                    search_input = gr.Textbox(
                        label="🔍 Search",
                        placeholder="Search across all columns...",
                        lines=1
                    )
                with gr.Group(elem_classes="filter-section"):
                    gr.Markdown("### 🎛️ Filters")
                    # Repo ID is carried as state so handlers receive it as input.
                    dir_path = gr.State(RESULT_DIR)
                    # 1) Tasks filter: (display label, value matched against 'Dataset')
                    task_filter = gr.CheckboxGroup(
                        label="📊 Tasks",
                        choices=[
                            ("GSM8K", "gsm8k"),
                            ("LongBench", "longbench"),
                            ("MMLU", "mmlu"),
                            ("NuminaMath", "numinamath"),
                            ("RULER", "ruler")
                        ],
                        value=["gsm8k", "longbench", "mmlu", "numinamath", "ruler"]
                    )
                    # 2) Inference frameworks filter
                    framework_filter = gr.CheckboxGroup(
                        label="⚙️ Inference Frameworks",
                        choices=["sglang", "vllm"],
                        value=["sglang", "vllm"],
                    )
                    # 3) Model types filter
                    model_type_filter = gr.CheckboxGroup(
                        label="🤖 Model Types",
                        choices=["instruct", "thinking"],
                        value=["instruct", "thinking"],
                    )
                    # 4) Precision filter
                    precision_filter = gr.CheckboxGroup(
                        label="🎯 Precision",
                        choices=["bfloat16", "fp8"],
                        value=["bfloat16", "fp8"],
                    )
                with gr.Accordion("📖 About Tasks & Metrics", open=True):
                    gr.Markdown(
                        "### Tasks\n"
                        "- **GSM8K** — Mathematics Problem-Solving ([paper](https://arxiv.org/abs/2110.14168))\n"
                        "- **LongBench** — Long-Context Understanding ([paper](https://arxiv.org/abs/2412.15204))\n"
                        "- **MMLU** — Multitask Language Understanding ([paper](https://arxiv.org/abs/2009.03300))\n"
                        "- **NuminaMath** — Mathematical Reasoning ([paper](http://faculty.bicmr.pku.edu.cn/~dongbin/Publications/numina_dataset.pdf))\n"
                        "- **RULER** — Extreme Long-Context Eval ([paper](https://arxiv.org/abs/2404.06654))\n\n"
                        "### Metrics\n"
                        "- **E2E(s)** — End-to-End Latency\n"
                        "- **Accuracy(%)** — Task Accuracy\n"
                        "- **Cost($)** — Inference Cost\n"
                        "- **Decoding/Prefill T/s** — Throughput\n"
                        "- **S-MBU/MFU(%)** — Hardware Utilization\n"
                        "- **TTFT(s)** — Time To First Token\n"
                        "- **TPOT(s)** — Time Per Output Token",
                        elem_classes="info-section"
                    )
            # Right side - Table with selection and Radar Plot below
            with gr.Column(scale=5):
                leaderboard_output = gr.HTML(label="📈 Results")
                with gr.Group(elem_classes="filter-section"):
                    gr.Markdown("### 📊 CAP Radar Plot")
                    gr.Markdown(
                        "**How to use:** Look at the 'Row #' column in the table above. "
                        "Enter up to 3 row numbers below (separated by commas) and click Generate."
                    )
                    with gr.Row():
                        row_indices_input = gr.Textbox(
                            label="Row Numbers to Compare",
                            placeholder="Example: 0,1,2",
                            elem_id="row_indices_input",
                            scale=3
                        )
                        generate_btn = gr.Button("🎯 Generate", variant="primary", scale=1, size="lg")
                    with gr.Row():
                        # Empty side columns centre the plot.
                        with gr.Column(scale=1):
                            pass
                        with gr.Column(scale=5):
                            radar_plot = gr.Plot(label="", value=generate_radar_plot([]))
                        with gr.Column(scale=1):
                            pass
        # Holds the currently displayed rows so the plot can look them up by number.
        df_data_state = gr.State([])

        # Every table refresh uses the same inputs and outputs.
        refresh_inputs = [dir_path, task_filter, framework_filter, model_type_filter, precision_filter, search_input]
        refresh_outputs = [leaderboard_output, df_data_state]

        # Initial load forces a fresh download.
        demo.load(
            fn=auto_refresh_from_dir,
            inputs=refresh_inputs,
            outputs=refresh_outputs,
        )
        # Search box and filters re-render without forcing a re-download.
        for control in (search_input, task_filter, framework_filter, model_type_filter, precision_filter):
            control.change(
                fn=load_from_dir,
                inputs=refresh_inputs,
                outputs=refresh_outputs,
            )
        # Generate plot on button click
        generate_btn.click(
            fn=parse_and_generate_plot,
            inputs=[df_data_state, row_indices_input],
            outputs=[radar_plot]
        )
        # Periodic refresh every 60 seconds, again forcing a re-download.
        timer = gr.Timer(60.0)
        timer.tick(
            fn=auto_refresh_from_dir,
            inputs=refresh_inputs,
            outputs=refresh_outputs,
        )
    return demo
# Script entry point: build the dashboard and start the Gradio server.
# Nothing is launched when this file is imported as a module.
if __name__ == "__main__":
    app = build_app()
    app.launch()