#!/usr/bin/env python
"""MoE-CAP Dashboard.

A Gradio app that loads MoE-CAP benchmark result JSON files from a Hugging
Face dataset repository (named by the ``MOECAP_RESULT_DIR`` environment
variable), shows them as a filterable/searchable HTML leaderboard, and draws a
Cost / Accuracy / Performance ("CAP") radar plot for up to three selected rows.
"""
import os
import json
import html
import math
from typing import List, Tuple

# Set before gradio is imported below -- presumably so gradio picks the UI
# language up at import time.
os.environ["GRADIO_LANGUAGE"] = "en"

# Hugging Face dataset repo id holding the result JSON files (used as
# "hf://datasets/<RESULT_DIR>/**/*.json"). If unset, the app still starts and
# load_from_dir() shows "Result Directory not set." instead of a table.
RESULT_DIR = os.environ.get("MOECAP_RESULT_DIR")

import gradio as gr
import pandas as pd
from datasets import load_dataset
import plotly.graph_objects as go

# Maximum number of table rows that can be compared on one radar plot.
MAX_RADAR_ROWS = 3

# Cell values that mean "no data" in table records (load_from_dir fills
# missing cells with "-").
_MISSING_CELL_VALUES = (None, '-', '')


def f2(x):
    """Round a number to 2 decimal places (as float); return anything else unchanged."""
    if isinstance(x, (int, float)):
        return round(float(x), 2)
    return x


def _pct(x):
    """Convert a fraction to a percentage rounded to 2 decimals; None if not numeric."""
    return round(x * 100, 2) if isinstance(x, (int, float)) else None


def normalize(val, vmin, vmax, baseline=20):
    """Linearly map ``val`` from [vmin, vmax] onto [baseline, 100].

    A degenerate range (vmin == vmax) maps to ``baseline + 40``.
    """
    if vmax == vmin:
        return baseline + 40
    return baseline + (val - vmin) / (vmax - vmin) * (100 - baseline)


def normalize_cost(val, max_tick, baseline=20):
    """Map a cost onto [baseline, 100] where lower cost scores higher.

    A cost of 0 scores 100; costs at or above ``max_tick`` score ``baseline``.
    ``max_tick == 0`` maps to ``baseline + 40``.
    """
    if max_tick == 0:
        return baseline + 40
    return baseline + (max_tick - min(val, max_tick)) / max_tick * (100 - baseline)


def _radar_layout() -> dict:
    """Return the layout settings shared by every radar-plot figure."""
    return dict(
        height=750,
        autosize=True,
        margin=dict(t=80, b=100, l=80, r=80),
        paper_bgcolor='white',
        plot_bgcolor='white',
    )


def _message_figure(text: str, size: int = 18, color: str = "red") -> go.Figure:
    """Return an axis-less figure that only shows ``text`` centered."""
    fig = go.Figure()
    fig.add_annotation(
        text=text,
        xref="paper", yref="paper",
        x=0.5, y=0.5,
        showarrow=False,
        font=dict(size=size, color=color),
        xanchor='center', yanchor='middle',
    )
    fig.update_layout(xaxis=dict(visible=False), yaxis=dict(visible=False), **_radar_layout())
    return fig


def _to_float(value) -> float:
    """Convert a table cell to float; missing markers, junk and non-finite values become 0.0."""
    if value in _MISSING_CELL_VALUES:
        return 0.0
    try:
        number = float(value)
    except (TypeError, ValueError):
        return 0.0
    return number if math.isfinite(number) else 0.0


def _display_model_name(model_cell) -> str:
    """Extract the plain model name from a "Model" cell, which may be an HTML link."""
    name = str(model_cell)
    if 'href' in name:
        try:
            # '<a href="...">org/name</a>' -> 'org/name'
            name = name.split('>', 1)[1].split('<', 1)[0]
        except IndexError:
            pass  # not a well-formed link; fall back to the raw cell text
    # json_to_row HTML-escapes the link text; undo that for the plot legend.
    return html.unescape(name)


def _legend_name(row: dict) -> str:
    """Build a legend label "<model short name>[-<method>]" for one table row."""
    short_name = _display_model_name(row.get('Model', 'Unknown')).split('/')[-1]
    method = row.get('Method', '')
    if method and method not in ['Unknown', '-', '']:
        return f"{short_name}-{method}"
    return short_name


def _unique_label(label: str, taken) -> str:
    """Return ``label``, suffixed with " (2)", " (3)", ... if it is already in ``taken``."""
    candidate, counter = label, 2
    while candidate in taken:
        candidate = f"{label} ({counter})"
        counter += 1
    return candidate


def generate_radar_plot(selected_rows_data: List[dict]) -> go.Figure:
    """Generate a CAP (throughput / cost / accuracy) radar plot from table rows.

    Args:
        selected_rows_data: Row records as produced by load_from_dir. Reads the
            "Dataset", "Model", "Method", "Accuracy(%)", "Cost($)" and
            "Decoding T/s" keys.

    Returns:
        A plotly figure. If no rows, more than MAX_RADAR_ROWS rows, or rows
        from different datasets are given, the figure only shows a message.
        Each axis is min-max normalized across the selected rows (accuracy
        against an upper bound of 100%, cost inverted so cheaper is better);
        hover text shows the raw values.
    """
    if not selected_rows_data:
        return _message_figure(
            f"Please select 1-{MAX_RADAR_ROWS} rows from the table to generate radar plot",
            size=16, color="black",
        )
    if len(selected_rows_data) > MAX_RADAR_ROWS:
        return _message_figure(f"Error: Please select no more than {MAX_RADAR_ROWS} rows!")

    datasets = [row.get('Dataset', '') for row in selected_rows_data]
    if len(set(datasets)) > 1:
        return _message_figure("Error: Please select rows from the same dataset!")
    dataset_name = datasets[0]

    data = {}
    for row in selected_rows_data:
        # Rows can share model+method (e.g. different precision or GPU); keep
        # each one as its own trace instead of overwriting the previous one.
        legend_name = _unique_label(_legend_name(row), data)
        data[legend_name] = {
            # "Accuracy(%)" is always a percentage in the table; convert to a fraction.
            'accuracy': _to_float(row.get('Accuracy(%)', 0)) / 100.0,
            'cost': _to_float(row.get('Cost($)', 0)),
            'throughput': _to_float(row.get('Decoding T/s', 0)),
        }

    throughputs = [v['throughput'] for v in data.values()]
    tp_min, tp_max = min(throughputs), max(throughputs)
    cost_max = max(v['cost'] for v in data.values())
    acc_min, acc_max = min(v['accuracy'] for v in data.values()), 1.0
    baseline = 20
    # The first category is repeated at the end to close the polygon.
    categories = ['Throughput (T/s)', 'Cost ($)', 'Accuracy', 'Throughput (T/s)']

    fig = go.Figure()
    for system, values in data.items():
        norm_vals = [
            normalize(values['throughput'], tp_min, tp_max, baseline),
            normalize_cost(values['cost'], cost_max, baseline),
            normalize(values['accuracy'], acc_min, acc_max, baseline),
        ]
        norm_vals.append(norm_vals[0])
        hovertext = [
            f"Throughput: {values['throughput']:.2f} T/s",
            f"Cost: ${values['cost']:.2f}",
            f"Accuracy: {values['accuracy'] * 100:.2f}%",
            f"Throughput: {values['throughput']:.2f} T/s",
        ]
        fig.add_trace(go.Scatterpolar(
            r=norm_vals,
            theta=categories,
            fill='toself',
            name=system,
            text=hovertext,
            hoverinfo='text+name',
            line=dict(width=2),
        ))

    fig.update_layout(
        title=dict(text=f"CAP Radar Plot: {dataset_name}", x=0.5, xanchor='center',
                   font=dict(size=20, color="black")),
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 100],
                tickfont=dict(size=12, color="black"),
                gridcolor='lightgray',
                linecolor='gray',
                showline=True,
            ),
            angularaxis=dict(
                tickfont=dict(size=14, color="black"),
                rotation=90,
                direction='clockwise',
                gridcolor='lightgray',
                linecolor='gray',
                showline=True,
            ),
            bgcolor="white",
        ),
        legend=dict(orientation='h', yanchor='bottom', y=-0.15, xanchor='center', x=0.5,
                    font=dict(size=13, color="black")),
        **_radar_layout(),
    )
    return fig


def json_to_row(path: str, metrics: dict) -> dict:
    """Convert one result record into a leaderboard table row.

    Args:
        path: Identifier of the record's origin (load_from_dir passes
            "<dir>#<index>"). Not used by the conversion.
        metrics: Flat dict of result fields ("model_name", "dataset",
            "method", "decoding_throughput", "prefill_smbu", ...). Missing
            fields fall back to "Unknown", "" or None.

    Returns:
        Dict mapping display column names to cell values. Accuracy and
        S-MBU/S-MFU fractions become percentages. Accuracy is correct/total
        when both are numeric and total > 0, else "exact_match". The "Model"
        cell is an (HTML-escaped) link to the Hugging Face model page when the
        name contains "/". Multi-word headers use "<br>" because the table is
        rendered with escape=False.
    """
    model_name = metrics.get("model_name") or "unknown-model"
    dataset = metrics.get("dataset", "Unknown")
    method = metrics.get("method", "Unknown")
    precision = metrics.get("precision", "Unknown")
    model_type = metrics.get("model_type", "Unknown")
    e2e_s = metrics.get("e2e_s", None)
    batch_size = metrics.get("batch_size", None)
    gpu_type = metrics.get("gpu_type", "")
    cost = metrics.get("cost", None)

    correct = metrics.get("correct")
    total = metrics.get("total")
    if isinstance(correct, (int, float)) and isinstance(total, (int, float)) and total > 0:
        acc = correct / total
    else:
        acc = metrics.get("exact_match")

    if isinstance(model_name, str) and "/" in model_name:
        hf_url = f"https://huggingface.co/{model_name}"
        # Escape both the URL and the text: the table is rendered unescaped.
        model_cell = (
            f'<a href="{html.escape(hf_url, quote=True)}" target="_blank" '
            f'rel="noopener noreferrer">{html.escape(model_name)}</a>'
        )
    elif isinstance(model_name, str):
        model_cell = html.escape(model_name)
    else:
        model_cell = model_name

    return {
        "Model": model_cell,
        "Dataset": dataset,
        "Method": method,
        "Model type": model_type,
        "Precision": precision,
        "E2E(s)": f2(e2e_s),
        "GPU": gpu_type,
        "Accuracy(%)": _pct(acc),
        "Cost($)": cost,
        "Decoding T/s": f2(metrics.get("decoding_throughput")),
        "Prefill T/s": f2(metrics.get("prefill_tp")),
        "Prefill<br>S-MBU(%)": _pct(metrics.get("prefill_smbu")),
        "Prefill<br>S-MFU(%)": _pct(metrics.get("prefill_smfu")),
        "Decoding<br>S-MBU(%)": _pct(metrics.get("decoding_smbu")),
        "Decoding<br>S-MFU(%)": _pct(metrics.get("decoding_smfu")),
        "TTFT(s)": f2(metrics.get("ttft")),
        "TPOT(s)": f2(metrics.get("tpot")),
        "Batch size": batch_size,
    }


def _status_html(message: str) -> str:
    """Wrap a status/error message as HTML for the results panel."""
    return (
        '<div class="table-container" style="padding: 20px; text-align: center;">'
        f'<p>{html.escape(message)}</p></div>'
    )


def _filter_by_selection(df: pd.DataFrame, column: str, selected) -> pd.DataFrame:
    """Keep rows whose ``column`` (case-insensitive) is in ``selected``.

    An empty/None selection disables the filter, so unchecking every box in a
    group shows all rows rather than none.
    """
    if not selected:
        return df
    wanted = {str(x).lower() for x in selected}
    return df[df[column].astype(str).str.lower().isin(wanted)]


def load_from_dir(dir_path: str, selected_tasks=None, selected_frameworks=None,
                  selected_model_types=None, selected_precisions=None, search_keyword="",
                  force_refresh=False) -> Tuple[str, list]:
    """Load all result JSON files from a HF dataset repo and render the leaderboard.

    Args:
        dir_path: Hugging Face dataset repo id; files matching
            "hf://datasets/<dir_path>/**/*.json" are loaded.
        selected_tasks / selected_frameworks / selected_model_types /
        selected_precisions: Allowed values for the "Dataset", "Method",
            "Model type" and "Precision" columns (case-insensitive); empty means
            "no filter".
        search_keyword: Case-insensitive plain substring matched against every
            cell (no regex interpretation).
        force_refresh: Re-download the files instead of using the local cache.

    Returns:
        (table_html, records). ``records`` is the list of row dicts in display
        order, so records[i] is the row labelled "Row #" i. On any failure or an
        empty result, table_html is a status message and records is [].
    """
    if not dir_path:
        return _status_html("Result Directory not set."), []
    try:
        pattern = f"hf://datasets/{dir_path}/**/*.json"
        dl_mode = "force_redownload" if force_refresh else None
        print(f"Fetching from {pattern} (mode={dl_mode})...")
        ds = load_dataset("json", data_files={"train": pattern}, split="train", download_mode=dl_mode)
    except Exception as e:  # UI boundary: any loader failure becomes a message, not a crash
        print(f"Error loading dataset: {e}")
        return _status_html("No files loaded or Dataset not found."), []

    rows = []
    for i, example in enumerate(ds):
        # Metrics may be nested under "metrics"/"json" or stored at the top level.
        metrics = example.get("metrics") or example.get("json") or example
        rows.append(json_to_row(f"{dir_path}#{i}", metrics))
    if not rows:
        return _status_html("No records found."), []

    df = pd.DataFrame(rows)
    df = _filter_by_selection(df, "Dataset", selected_tasks)
    df = _filter_by_selection(df, "Method", selected_frameworks)
    df = _filter_by_selection(df, "Model type", selected_model_types)
    df = _filter_by_selection(df, "Precision", selected_precisions)

    keyword = (search_keyword or "").strip().lower()
    if keyword and not df.empty:
        # regex=False: user input such as "(" or "$" must not be parsed as a regex.
        hits = df.astype(str).apply(lambda col: col.str.lower().str.contains(keyword, regex=False))
        df = df[hits.any(axis=1)]

    if df.empty:
        return _status_html("No records found."), []

    df = df.fillna("-")
    df.insert(0, 'Row #', range(len(df)))

    # The table is rendered with escape=False so the Model links and "<br>"
    # headers work; escape every other text cell since it comes from result files.
    display_df = df.copy()
    for column in display_df.columns:
        if column != "Model":
            display_df[column] = display_df[column].map(
                lambda v: html.escape(v) if isinstance(v, str) else v)
    table_html = (
        '<div class="table-container">'
        f'{display_df.to_html(escape=False, index=False, classes="metrics-table")}'
        '</div>'
    )
    return table_html, df.drop(columns='Row #').to_dict('records')


def auto_refresh_from_dir(dir_path, tasks, frameworks, types, precisions, search):
    """Same as load_from_dir, but always re-downloads the result files."""
    return load_from_dir(dir_path, tasks, frameworks, types, precisions, search, force_refresh=True)


def parse_and_generate_plot(df_data, indices_str):
    """Parse user-entered row numbers (e.g. "0,1,2") and plot those table rows.

    Args:
        df_data: Records from load_from_dir; index i is the row labelled
            "Row #" i in the table.
        indices_str: Comma-separated row numbers typed by the user.

    Returns:
        The radar figure; the "please select rows" placeholder for blank input;
        or an error figure for non-integer or out-of-range row numbers. More
        than MAX_RADAR_ROWS numbers are passed through so generate_radar_plot
        reports the limit instead of silently dropping rows.
    """
    if not indices_str or not indices_str.strip():
        return generate_radar_plot([])
    records = df_data or []
    tokens = [tok.strip() for tok in indices_str.split(',') if tok.strip()]
    try:
        indices = [int(tok) for tok in tokens]
    except ValueError:
        return _message_figure("Error: Row numbers must be comma-separated integers (e.g. 0,1,2)!")
    out_of_range = [i for i in indices if not 0 <= i < len(records)]
    if out_of_range:
        bad = ", ".join(str(i) for i in out_of_range)
        return _message_figure(f"Error: Row number(s) {bad} not found in the table!")
    return generate_radar_plot([records[i] for i in indices])


# Forces a light theme by overriding Gradio's CSS variables and styling the
# filter panels, checkboxes, results table and plot container.
_DASHBOARD_CSS = """
/* 1. FORCE LIGHT VARIABLES GLOBALLY */
:root, .gradio-container, body {
    --body-background-fill: #f5f7fa !important;
    --body-text-color: #374151 !important;
    --background-fill-primary: #ffffff !important;
    --background-fill-secondary: #f3f4f6 !important;
    --border-color-primary: #e5e7eb !important;
    --block-background-fill: #ffffff !important;
    --block-label-text-color: #374151 !important;
    --block-title-text-color: #1f2937 !important;
    --input-background-fill: #ffffff !important;
    --color-accent: #0366d6 !important;
    /* Reset dark mode specific variables to light values */
    --neutral-50: #f9fafb;
    --neutral-100: #f3f4f6;
    --neutral-200: #e5e7eb;
    --neutral-300: #d1d5da;
    --neutral-400: #9ca3af;
    --neutral-500: #6b7280;
    --neutral-600: #4b5563;
    --neutral-700: #374151;
    --neutral-800: #1f2937;
}
/* 2. RESET STANDARD CONTAINERS */
.gradio-container .block, .gradio-container .panel, .gradio-container .form {
    background-color: white !important;
    border-color: #e1e4e8 !important;
}
/* 3. SPECIFIC FIX FOR THE DARK "FILTERS" and "RADAR" SECTIONS */
.filter-section {
    background-color: #ffffff !important;
    border: 2px solid #e1e4e8 !important;
    border-radius: 8px !important;
    padding: 16px !important;
    box-shadow: 0 2px 4px rgba(0,0,0,0.05) !important;
    color: #24292e !important; /* Set default text color for the section */
}
/* Remove background color from text elements to prevent "dark blocks" */
.filter-section label, .filter-section span, .filter-section p {
    background-color: transparent !important;
}
/* 4. BUTTON FIXES - TARGET BY ID FOR SPECIFICITY */
#gen_btn {
    background-color: #0366d6 !important;
    color: white !important;
    border: none !important;
}
#gen_btn:hover { opacity: 0.9; }
/* 5. INPUTS & CHECKBOXES */
/* Re-apply white background to inputs specifically */
.filter-section input, .filter-section textarea, .filter-section select {
    background-color: #ffffff !important;
    border: 1px solid #d1d5da !important;
    color: #24292e !important;
}
/* --- FIX FOR CHECKBOXES --- */
/* Use explicit styling for the checked state to ensure visibility */
.filter-section input[type="checkbox"] {
    appearance: none !important;
    -webkit-appearance: none !important;
    width: 16px !important;
    height: 16px !important;
    background-color: white !important;
    border: 1px solid #d1d5da !important;
    border-radius: 3px !important;
    position: relative !important;
    cursor: pointer !important;
}
.filter-section input[type="checkbox"]:checked {
    background-color: #0366d6 !important;
    border-color: #0366d6 !important;
    /* Draw the checkmark using an SVG data URI */
    background-image: url("data:image/svg+xml,%3csvg viewBox='0 0 16 16' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3cpath d='M12.207 4.793a1 1 0 010 1.414l-5 5a1 1 0 01-1.414 0l-2-2a1 1 0 011.414-1.414L6.5 9.086l4.293-4.293a1 1 0 011.414 0z'/%3e%3c/svg%3e") !important;
    background-size: 100% 100% !important;
    background-position: center !important;
    background-repeat: no-repeat !important;
}
.filter-section label span {
    color: #24292e !important;
}
/* 6. SEARCH BOX */
.search-box {
    background: white !important;
    padding: 16px !important;
    border-radius: 6px;
    border: 2px solid #e1e4e8 !important;
    margin-bottom: 16px;
}
/* 7. TABLE STYLING */
.table-container {
    overflow-x: auto;
    max-height: 75vh;
    border: 2px solid #e1e4e8;
    border-radius: 6px;
    background: white !important;
}
table.metrics-table {
    width: 100%;
    border-collapse: collapse;
    background: white !important;
}
table.metrics-table th, table.metrics-table td {
    padding: 10px 14px;
    border: 1px solid #e1e4e8;
    white-space: nowrap;
    font-size: 13px;
    color: #24292e !important;
}
table.metrics-table th {
    background: #f6f8fa !important;
    font-weight: 600;
    position: sticky;
    top: 0;
}
.metrics-table th:first-child, .metrics-table td:first-child {
    background-color: #f0f0f0 !important;
    text-align: center;
}
/* 8. PLOT CONTAINER - FORCE WHITE BACKGROUND */
.plot-container {
    width: 100% !important;
    background-color: white !important;
}
.plot-container > div, .plot-container .plotly {
    background-color: white !important;
}
/* 9. LINKS */
a { color: #0366d6 !important; text-decoration: none; }
a:hover { text-decoration: underline; }
"""


def build_app() -> gr.Blocks:
    """Build the dashboard UI and wire its events.

    The leaderboard reloads (with a forced re-download) on page load and every
    60 seconds, and reloads from cache whenever the search box or a filter
    changes. The Generate button plots the rows whose numbers were typed in.
    """
    with gr.Blocks(title="MoE-CAP Dashboard", css=_DASHBOARD_CSS, theme=gr.themes.Default()) as demo:
        gr.Markdown("# MoE-CAP Dashboard")

        with gr.Row():
            # Left sidebar: search, filters and help text.
            with gr.Column(scale=2):
                with gr.Group(elem_classes="search-box"):
                    search_input = gr.Textbox(label="🔍 Search", placeholder="Search...", lines=1)

                with gr.Group(elem_classes="filter-section"):
                    gr.Markdown("### 🎛️ Filters")
                    dir_path = gr.State(RESULT_DIR)

                    task_filter = gr.CheckboxGroup(
                        label="📊 Tasks",
                        choices=[("GSM8K", "gsm8k"), ("LongBench", "longbench"), ("MMLU", "mmlu"),
                                 ("NuminaMath", "numinamath"), ("RULER", "ruler")],
                        value=["gsm8k", "longbench", "mmlu", "numinamath", "ruler"],
                    )
                    framework_filter = gr.CheckboxGroup(
                        label="⚙️ Frameworks", choices=["sglang", "vllm"], value=["sglang", "vllm"])
                    model_type_filter = gr.CheckboxGroup(
                        label="🤖 Model Types", choices=["instruct", "thinking"], value=["instruct", "thinking"])
                    precision_filter = gr.CheckboxGroup(
                        label="🎯 Precision", choices=["bfloat16", "fp8"], value=["bfloat16", "fp8"])

                with gr.Accordion("📖 About Tasks & Metrics", open=True):
                    gr.Markdown(
                        "### Tasks\n- **GSM8K**, **LongBench**, **MMLU**, **NuminaMath**, **RULER**\n\n"
                        "### Metrics\n- **E2E(s)**: Latency | **Cost($)** | **T/s**: Throughput | **S-MBU/MFU**: Utilization",
                        elem_classes="info-section",
                    )

            # Right main content: results table and radar plot.
            with gr.Column(scale=5):
                leaderboard_output = gr.HTML(label="📈 Results")

                with gr.Group(elem_classes="filter-section"):
                    gr.Markdown("### 📊 CAP Radar Plot")
                    gr.Markdown("**How to use:** Look at the 'Row #' column in the table. "
                                "Enter row numbers (e.g., 0,1,2) and click Generate.")
                    with gr.Row():
                        row_indices_input = gr.Textbox(label="Row Numbers", placeholder="0,1,2", scale=3)
                        # elem_id lets the CSS target this button specifically (#gen_btn).
                        generate_btn = gr.Button("🎯 Generate", variant="primary", scale=1, elem_id="gen_btn")
                    radar_plot = gr.Plot(value=generate_radar_plot([]), elem_classes="plot-container")

        # Table records for the rows currently shown (index == "Row #").
        df_data_state = gr.State([])
        inputs = [dir_path, task_filter, framework_filter, model_type_filter, precision_filter, search_input]
        outputs = [leaderboard_output, df_data_state]

        demo.load(fn=auto_refresh_from_dir, inputs=inputs, outputs=outputs)
        for component in (search_input, task_filter, framework_filter, model_type_filter, precision_filter):
            component.change(fn=load_from_dir, inputs=inputs, outputs=outputs)
        generate_btn.click(fn=parse_and_generate_plot, inputs=[df_data_state, row_indices_input],
                           outputs=[radar_plot])
        gr.Timer(60.0).tick(fn=auto_refresh_from_dir, inputs=inputs, outputs=outputs)

    return demo


if __name__ == "__main__":
    app = build_app()
    app.launch()