# AppleSwing's picture
# Update app.py
# 2decc9c verified
#!/usr/bin/env python
import os
import json
from typing import List, Tuple
# Set before gradio is imported further down in this file.
os.environ.update(GRADIO_LANGUAGE="en")

# Read from the MOECAP_RESULT_DIR environment variable; None when unset.
# Nothing is raised here when it is missing: load_from_dir() returns a
# "Result Directory not set." message for a falsy directory.
# For testing purposes you can assign it by hand instead, e.g.
#     RESULT_DIR = "generic_result_dir"
RESULT_DIR = os.getenv("MOECAP_RESULT_DIR")
import gradio as gr
import pandas as pd
from datasets import load_dataset
import plotly.graph_objects as go
def f2(x):
    """Round ints/floats to two decimals (as a float); pass anything else through."""
    return round(float(x), 2) if isinstance(x, (int, float)) else x
def normalize(val, vmin, vmax, baseline=20):
    """Map ``val`` linearly so that ``vmin`` -> ``baseline`` and ``vmax`` -> 100.

    When the range is degenerate (``vmax == vmin``) a fixed ``baseline + 40``
    is returned instead of dividing by zero.
    """
    span = vmax - vmin
    if span == 0:
        return baseline + 40
    fraction = (val - vmin) / span
    return baseline + fraction * (100 - baseline)
def normalize_cost(val, max_tick, baseline=20):
    """Score a cost on a ``baseline``-100 scale where cheaper is higher.

    A cost of 0 maps to 100 and any cost at or above ``max_tick`` maps to
    ``baseline``. ``max_tick == 0`` yields the fixed value ``baseline + 40``.
    """
    if max_tick == 0:
        return baseline + 40
    capped = min(val, max_tick)
    saved_share = (max_tick - capped) / max_tick
    return baseline + saved_share * (100 - baseline)
def _message_figure(text, layout_settings, size=16, color="black"):
    """Return an axis-less figure that shows only ``text`` centered on the canvas."""
    fig = go.Figure()
    fig.add_annotation(
        text=text,
        xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False,
        font=dict(size=size, color=color),
        xanchor='center', yanchor='middle'
    )
    fig.update_layout(xaxis=dict(visible=False), yaxis=dict(visible=False), **layout_settings)
    return fig


def _cell_to_float(value):
    """Convert one table cell to float; placeholders and unparsable cells become 0.0."""
    if value in (None, '-', ''):
        return 0.0
    try:
        return float(value)
    except (TypeError, ValueError):
        return 0.0


def generate_radar_plot(selected_rows_data: List[dict]) -> go.Figure:
    """Generate a CAP (cost / accuracy / throughput) radar plot from selected rows.

    Args:
        selected_rows_data: Row dicts using the table's column names. The keys
            read are 'Model', 'Method', 'Dataset', 'Accuracy(%)', 'Cost($)'
            and 'Decoding T/s'.

    Returns:
        A radar figure with one trace per row. A message-only figure is
        returned instead when no rows are given, when more than 3 rows are
        given, or when the rows come from more than one dataset.
    """
    layout_settings = dict(
        height=750,
        autosize=True,
        margin=dict(t=80, b=100, l=80, r=80),
        paper_bgcolor='white',
        plot_bgcolor='white',
    )

    if not selected_rows_data:
        return _message_figure(
            "Please select 1-3 rows from the table to generate radar plot",
            layout_settings,
        )
    if len(selected_rows_data) > 3:
        return _message_figure(
            "Error: Please select no more than 3 rows!",
            layout_settings, size=18, color="red",
        )

    datasets = [row.get('Dataset', '') for row in selected_rows_data]
    if len(set(datasets)) > 1:
        return _message_figure(
            "Error: Please select rows from the same dataset!",
            layout_settings, size=18, color="red",
        )
    dataset_name = datasets[0]

    data = {}
    for row in selected_rows_data:
        model_name = row.get('Model', 'Unknown')
        # The table stores linked models as "<a href=...>name</a>"; keep the text only.
        if isinstance(model_name, str) and 'href' in model_name:
            try:
                model_name = model_name.split('>', 1)[1].split('<', 1)[0]
            except IndexError:
                pass  # not the expected anchor markup; use the raw value

        if isinstance(model_name, str) and '/' in model_name:
            legend_name = model_name.split('/')[-1]
        else:
            legend_name = str(model_name)

        method = row.get('Method', '')
        if method and method not in ['Unknown', '-', '']:
            legend_name = f"{legend_name}-{method}"

        # Rows that differ only in columns not used for the legend (e.g.
        # precision) would otherwise overwrite each other in ``data``.
        base_name, suffix = legend_name, 2
        while legend_name in data:
            legend_name = f"{base_name} ({suffix})"
            suffix += 1

        # Each metric is converted on its own so that one bad cell does not
        # zero out the other two.
        data[legend_name] = {
            # The column holds a percentage, so it is always scaled to 0-1;
            # guessing from the magnitude would read 0.8% as 80%.
            'accuracy': _cell_to_float(row.get('Accuracy(%)', 0)) / 100.0,
            'cost': _cell_to_float(row.get('Cost($)', 0)),
            'throughput': _cell_to_float(row.get('Decoding T/s', 0)),
        }

    throughputs = [v['throughput'] for v in data.values()]
    costs = [v['cost'] for v in data.values()]
    accs = [v['accuracy'] for v in data.values()]

    tp_min, tp_max = min(throughputs), max(throughputs)
    cost_max = max(costs)
    # The accuracy axis always tops out at 100%, not at the best selected row.
    acc_min, acc_max = min(accs), 1.0

    baseline = 20
    # The first category is repeated so each polygon closes on itself.
    categories = ['Throughput (T/s)', 'Cost ($)', 'Accuracy', 'Throughput (T/s)']

    fig = go.Figure()
    for system, values in data.items():
        raw_vals = [values['throughput'], values['cost'], values['accuracy']]
        norm_vals = [
            normalize(values['throughput'], tp_min, tp_max, baseline),
            normalize_cost(values['cost'], cost_max, baseline),
            normalize(values['accuracy'], acc_min, acc_max, baseline)
        ]
        norm_vals += [norm_vals[0]]
        # Hover shows the raw values, not the normalized radii.
        hovertext = [
            f"Throughput: {raw_vals[0]:.2f} T/s",
            f"Cost: ${raw_vals[1]:.2f}",
            f"Accuracy: {raw_vals[2]*100:.2f}%",
            f"Throughput: {raw_vals[0]:.2f} T/s"
        ]
        fig.add_trace(go.Scatterpolar(
            r=norm_vals,
            theta=categories,
            fill='toself',
            name=system,
            text=hovertext,
            hoverinfo='text+name',
            line=dict(width=2)
        ))

    fig.update_layout(
        title=dict(text=f"CAP Radar Plot: {dataset_name}", x=0.5, xanchor='center', font=dict(size=20, color="black")),
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 100],
                tickfont=dict(size=12, color="black"),
                gridcolor='lightgray',
                linecolor='gray',
                showline=True
            ),
            angularaxis=dict(
                tickfont=dict(size=14, color="black"),
                rotation=90,
                direction='clockwise',
                gridcolor='lightgray',
                linecolor='gray',
                showline=True
            ),
            bgcolor="white"
        ),
        legend=dict(orientation='h', yanchor='bottom', y=-0.15, xanchor='center', x=0.5, font=dict(size=13, color="black")),
        **layout_settings
    )
    return fig
def json_to_row(path: str, metrics: dict) -> dict:
    """Turn one metrics record into a leaderboard table row.

    Args:
        path: Identifier of the record's origin. Accepted for interface
            compatibility; it is not used when building the row.
        metrics: Mapping with the raw result fields (read via ``.get``, so
            every field is optional).

    Returns:
        A dict keyed by the table's column headers. Accuracy is
        ``correct / total`` when both are numeric and ``total > 0``, otherwise
        the ``exact_match`` value; ratio fields are shown as percentages.
    """
    import html  # local import: only this function needs it

    model_name = metrics.get("model_name") or "unknown-model"
    dataset = metrics.get("dataset", "Unknown")
    method = metrics.get("method", "Unknown")
    precision = metrics.get("precision", "Unknown")
    model_type = metrics.get("model_type", "Unknown")
    e2e_s = metrics.get("e2e_s", None)
    batch_size = metrics.get("batch_size", None)
    gpu_type = metrics.get("gpu_type", "")
    cost = metrics.get("cost", None)

    em = metrics.get("exact_match")
    correct = metrics.get("correct")
    total = metrics.get("total")
    if isinstance(correct, (int, float)) and isinstance(total, (int, float)) and total > 0:
        acc = correct / total
    else:
        acc = em

    def pct(x):
        """Convert a 0-1 ratio to a percentage with two decimals; None if not numeric."""
        return round(x * 100, 2) if isinstance(x, (int, float)) else None

    # The table is rendered with HTML escaping disabled so that the model link
    # works, which means the name must be escaped here: it comes from the
    # result files and could otherwise inject markup or break out of the
    # href attribute. Other string cells are not escaped by this function.
    if isinstance(model_name, str):
        safe_name = html.escape(model_name, quote=True)
        if "/" in model_name:
            hf_url = f"https://huggingface.co/{safe_name}"
            model_cell = f"<a href='{hf_url}' target='_blank' style='color: #0366d6; text-decoration: none;'>{safe_name}</a>"
        else:
            model_cell = safe_name
    else:
        model_cell = model_name

    return {
        "Model": model_cell,
        "Dataset": dataset,
        "Method": method,
        "Model type": model_type,
        "Precision": precision,
        "E2E(s)": f2(e2e_s),
        "GPU": gpu_type,
        "Accuracy(%)": pct(acc),
        "Cost($)": cost,
        "Decoding T/s": f2(metrics.get("decoding_throughput")),
        "Prefill T/s": f2(metrics.get("prefill_tp")),
        "Prefill<br>S-MBU(%)": pct(metrics.get("prefill_smbu")),
        "Prefill<br>S-MFU(%)": pct(metrics.get("prefill_smfu")),
        "Decoding<br>S-MBU(%)": pct(metrics.get("decoding_smbu")),
        "Decoding<br>S-MFU(%)": pct(metrics.get("decoding_smfu")),
        "TTFT(s)": f2(metrics.get("ttft")),
        "TPOT(s)": f2(metrics.get("tpot")),
        "Batch size": batch_size,
    }
def load_from_dir(dir_path: str, selected_tasks=None, selected_frameworks=None, selected_model_types=None, selected_precisions=None, search_keyword="", force_refresh=False):
    """Load result JSON files, filter them, and render the leaderboard table.

    Args:
        dir_path: Dataset repository id used to build the
            ``hf://datasets/<dir_path>/**/*.json`` file pattern.
        selected_tasks: Values to keep in the "Dataset" column.
        selected_frameworks: Values to keep in the "Method" column.
        selected_model_types: Values to keep in the "Model type" column.
        selected_precisions: Values to keep in the "Precision" column.
        search_keyword: Literal, case-insensitive substring that must occur
            in at least one cell of a row.
        force_refresh: When true, the files are re-downloaded instead of
            being served from the local cache.

    Returns:
        ``(html, records)``: the table (or a short message) as an HTML string
        and the visible rows as a list of dicts in table order. ``records``
        is empty whenever a message is returned. A filter that is ``None`` or
        empty is not applied at all.
    """
    if not dir_path:
        return "<p style='color:black'>Result Directory not set.</p>", []

    try:
        pattern = f"hf://datasets/{dir_path}/**/*.json"
        dl_mode = "force_redownload" if force_refresh else None
        print(f"Fetching from {pattern} (mode={dl_mode})...")
        ds = load_dataset("json", data_files={"train": pattern}, split="train", download_mode=dl_mode)
    except Exception as e:
        # UI callback boundary: report the failure in the page instead of raising.
        print(f"Error loading dataset: {e}")
        return "<p style='color:black'>No files loaded or Dataset not found.</p>", []

    rows = []
    for i, example in enumerate(ds):
        # The metrics may be nested under "metrics" or "json", or be the record itself.
        metrics = example.get("metrics") or example.get("json") or example
        rows.append(json_to_row(f"{dir_path}#{i}", metrics))
    if not rows:
        return "<p style='color:black'>No records found.</p>", []

    df = pd.DataFrame(rows)

    # Each provided filter keeps only rows whose column value (compared
    # case-insensitively, as text) is in the selected list.
    column_filters = (
        ("Dataset", selected_tasks),
        ("Method", selected_frameworks),
        ("Model type", selected_model_types),
        ("Precision", selected_precisions),
    )
    for column, selected in column_filters:
        if selected:
            wanted = [str(x).lower() for x in selected]
            df = df[df[column].astype(str).str.lower().isin(wanted)]

    keyword = search_keyword.strip().lower() if search_keyword else ""
    if keyword:
        # regex=False: the keyword is user input, so "c++" or "(" must be
        # matched literally rather than raising a regex error.
        as_text = df.astype(str)
        matches = as_text.apply(lambda col: col.str.lower().str.contains(keyword, regex=False))
        df = df[matches.any(axis=1)]

    if df.empty:
        return "<p style='color:black'>No records found.</p>", []

    df = df.fillna("-")
    # "Row #" numbers the rows as displayed; it is what the radar plot input refers to.
    df.insert(0, 'Row #', range(len(df)))
    # escape=False keeps the Model column's anchor markup intact.
    table_html = f'<div class="table-container">{df.to_html(escape=False, index=False, classes="metrics-table")}</div>'
    df_without_rownum = df.drop('Row #', axis=1)
    return table_html, df_without_rownum.to_dict('records')
def auto_refresh_from_dir(dir_path, tasks, frameworks, types, precisions, search):
    """Run load_from_dir with the given filters and ``force_refresh=True``."""
    return load_from_dir(
        dir_path,
        selected_tasks=tasks,
        selected_frameworks=frameworks,
        selected_model_types=types,
        selected_precisions=precisions,
        search_keyword=search,
        force_refresh=True,
    )
def parse_and_generate_plot(df_data, indices_str):
    """Build the radar plot for the row numbers typed by the user.

    Args:
        df_data: List of row dicts currently shown in the table (may be
            ``None`` or empty).
        indices_str: Comma-separated row numbers, e.g. ``"0,1,2"``.

    Returns:
        The figure from generate_radar_plot for the selected rows. Only the
        first three numbers are used and out-of-range numbers are skipped.
        Blank or non-numeric input, or a failure while plotting, yields the
        same figure as an empty selection.
    """
    if not indices_str or not indices_str.strip():
        return generate_radar_plot([])

    try:
        indices = [int(tok.strip()) for tok in indices_str.split(',') if tok.strip()][:3]
    except ValueError:
        # Something other than an integer was typed.
        return generate_radar_plot([])

    rows = df_data or []
    selected_rows = [rows[i] for i in indices if 0 <= i < len(rows)]

    try:
        return generate_radar_plot(selected_rows)
    except Exception as e:
        # UI callback boundary: keep the best-effort fallback, but leave a
        # trace in the log instead of hiding the failure completely.
        print(f"Error generating radar plot: {e}")
        return generate_radar_plot([])
def build_app() -> gr.Blocks:
    """Assemble the MoE-CAP dashboard UI and wire up its events.

    Layout: a left column with the search box, the filter checkboxes and an
    "about" accordion; a right column with the results table and the radar
    plot controls.

    Events: the table is reloaded with a forced refresh on page load and on a
    60-second timer, and reloaded without one whenever the search box or a
    filter changes. The Generate button redraws the radar plot from the rows
    currently held in state.

    Returns:
        The configured ``gr.Blocks`` app (not yet launched).
    """
    # CSS that overrides Gradio's theme variables to force a light appearance.
    row_css = """
/* 1. FORCE LIGHT VARIABLES GLOBALLY */
:root, .gradio-container, body {
--body-background-fill: #f5f7fa !important;
--body-text-color: #374151 !important;
--background-fill-primary: #ffffff !important;
--background-fill-secondary: #f3f4f6 !important;
--border-color-primary: #e5e7eb !important;
--block-background-fill: #ffffff !important;
--block-label-text-color: #374151 !important;
--block-title-text-color: #1f2937 !important;
--input-background-fill: #ffffff !important;
--color-accent: #0366d6 !important;
/* Reset dark mode specific variables to light values */
--neutral-50: #f9fafb; --neutral-100: #f3f4f6; --neutral-200: #e5e7eb;
--neutral-300: #d1d5da; --neutral-400: #9ca3af; --neutral-500: #6b7280;
--neutral-600: #4b5563; --neutral-700: #374151; --neutral-800: #1f2937;
}
/* 2. RESET STANDARD CONTAINERS */
.gradio-container .block,
.gradio-container .panel,
.gradio-container .form {
background-color: white !important;
border-color: #e1e4e8 !important;
}
/* 3. SPECIFIC FIX FOR THE DARK "FILTERS" and "RADAR" SECTIONS */
.filter-section {
background-color: #ffffff !important;
border: 2px solid #e1e4e8 !important;
border-radius: 8px !important;
padding: 16px !important;
box-shadow: 0 2px 4px rgba(0,0,0,0.05) !important;
color: #24292e !important; /* Set default text color for the section */
}
/* Remove background color from text elements to prevent "dark blocks" */
.filter-section label,
.filter-section span,
.filter-section p {
background-color: transparent !important;
}
/* 4. BUTTON FIXES - TARGET BY ID FOR SPECIFICITY */
#gen_btn {
background-color: #0366d6 !important;
color: white !important;
border: none !important;
}
#gen_btn:hover {
opacity: 0.9;
}
/* 5. INPUTS & CHECKBOXES */
/* Re-apply white background to inputs specifically */
.filter-section input,
.filter-section textarea,
.filter-section select {
background-color: #ffffff !important;
border: 1px solid #d1d5da !important;
color: #24292e !important;
}
/* --- FIX FOR CHECKBOXES --- */
/* Use explicit styling for the checked state to ensure visibility */
.filter-section input[type="checkbox"] {
appearance: none !important;
-webkit-appearance: none !important;
width: 16px !important;
height: 16px !important;
background-color: white !important;
border: 1px solid #d1d5da !important;
border-radius: 3px !important;
position: relative !important;
cursor: pointer !important;
}
.filter-section input[type="checkbox"]:checked {
background-color: #0366d6 !important;
border-color: #0366d6 !important;
/* Draw the checkmark using an SVG data URI */
background-image: url("data:image/svg+xml,%3csvg viewBox='0 0 16 16' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3cpath d='M12.207 4.793a1 1 0 010 1.414l-5 5a1 1 0 01-1.414 0l-2-2a1 1 0 011.414-1.414L6.5 9.086l4.293-4.293a1 1 0 011.414 0z'/%3e%3c/svg%3e") !important;
background-size: 100% 100% !important;
background-position: center !important;
background-repeat: no-repeat !important;
}
.filter-section label span {
color: #24292e !important;
}
/* 6. SEARCH BOX */
.search-box {
background: white !important;
padding: 16px !important;
border-radius: 6px;
border: 2px solid #e1e4e8 !important;
margin-bottom: 16px;
}
/* 7. TABLE STYLING */
.table-container {
overflow-x: auto;
max-height: 75vh;
border: 2px solid #e1e4e8;
border-radius: 6px;
background: white !important;
}
table.metrics-table {
width: 100%; border-collapse: collapse; background: white !important;
}
table.metrics-table th, table.metrics-table td {
padding: 10px 14px; border: 1px solid #e1e4e8;
white-space: nowrap; font-size: 13px; color: #24292e !important;
}
table.metrics-table th {
background: #f6f8fa !important; font-weight: 600; position: sticky; top: 0;
}
.metrics-table th:first-child, .metrics-table td:first-child {
background-color: #f0f0f0 !important; text-align: center;
}
/* 8. PLOT CONTAINER - FORCE WHITE BACKGROUND */
.plot-container {
width: 100% !important;
background-color: white !important;
}
.plot-container > div, .plot-container .plotly {
background-color: white !important;
}
/* 9. LINKS */
a { color: #0366d6 !important; text-decoration: none; }
a:hover { text-decoration: underline; }
"""
    with gr.Blocks(title="MoE-CAP Dashboard", css=row_css, theme=gr.themes.Default()) as demo:
        gr.Markdown("# MoE-CAP Dashboard")
        with gr.Row():
            # Left Sidebar
            with gr.Column(scale=2):
                with gr.Group(elem_classes="search-box"):
                    search_input = gr.Textbox(label="🔍 Search", placeholder="Search...", lines=1)
                with gr.Group(elem_classes="filter-section"):
                    gr.Markdown("### 🎛️ Filters")
                    dir_path = gr.State(RESULT_DIR)
                    # Every option starts checked, so the initial view is unfiltered.
                    task_filter = gr.CheckboxGroup(
                        label="📊 Tasks",
                        choices=[("GSM8K", "gsm8k"), ("LongBench", "longbench"), ("MMLU", "mmlu"), ("NuminaMath", "numinamath"), ("RULER", "ruler")],
                        value=["gsm8k", "longbench", "mmlu", "numinamath", "ruler"]
                    )
                    framework_filter = gr.CheckboxGroup(label="⚙️ Frameworks", choices=["sglang", "vllm"], value=["sglang", "vllm"])
                    model_type_filter = gr.CheckboxGroup(label="🤖 Model Types", choices=["instruct", "thinking"], value=["instruct", "thinking"])
                    precision_filter = gr.CheckboxGroup(label="🎯 Precision", choices=["bfloat16", "fp8"], value=["bfloat16", "fp8"])
                with gr.Accordion("📖 About Tasks & Metrics", open=True):
                    gr.Markdown(
                        "### Tasks\n- **GSM8K**, **LongBench**, **MMLU**, **NuminaMath**, **RULER**\n\n"
                        "### Metrics\n- **E2E(s)**: Latency | **Cost($)** | **T/s**: Throughput | **S-MBU/MFU**: Utilization",
                        elem_classes="info-section"
                    )
            # Right Main Content
            with gr.Column(scale=5):
                leaderboard_output = gr.HTML(label="📈 Results")
                with gr.Group(elem_classes="filter-section"):
                    gr.Markdown("### 📊 CAP Radar Plot")
                    gr.Markdown("**How to use:** Look at the 'Row #' column in the table. Enter row numbers (e.g., 0,1,2) and click Generate.")
                    with gr.Row():
                        row_indices_input = gr.Textbox(label="Row Numbers", placeholder="0,1,2", scale=3)
                        # elem_id="gen_btn" lets the CSS above target this button by id.
                        generate_btn = gr.Button("🎯 Generate", variant="primary", scale=1, elem_id="gen_btn")
                    radar_plot = gr.Plot(value=generate_radar_plot([]), elem_classes="plot-container")

        # State & Events
        # Holds the rows currently shown in the table, for the radar plot to index into.
        df_data_state = gr.State([])
        inputs = [dir_path, task_filter, framework_filter, model_type_filter, precision_filter, search_input]
        outputs = [leaderboard_output, df_data_state]

        demo.load(fn=auto_refresh_from_dir, inputs=inputs, outputs=outputs)
        # Any change to the search box or a filter re-renders the table.
        for control in (search_input, task_filter, framework_filter, model_type_filter, precision_filter):
            control.change(fn=load_from_dir, inputs=inputs, outputs=outputs)
        generate_btn.click(fn=parse_and_generate_plot, inputs=[df_data_state, row_indices_input], outputs=[radar_plot])
        gr.Timer(60.0).tick(fn=auto_refresh_from_dir, inputs=inputs, outputs=outputs)
    return demo
# Script entry point: build the dashboard and start serving it.
if __name__ == "__main__":
    build_app().launch()