Spaces:
Running
Running
| #!/usr/bin/env python | |
| import os | |
| import json | |
| from typing import List, Tuple | |
# Pin the Gradio UI language to English. This runs before the `import gradio`
# further down. NOTE(review): presumably gradio reads the variable at import
# or launch time, which is why the assignment sits above the import - confirm.
os.environ["GRADIO_LANGUAGE"] = "en"

# Location of the result files, taken from the MOECAP_RESULT_DIR environment
# variable. It is later interpolated into an `hf://datasets/<RESULT_DIR>/...`
# pattern by load_from_dir. When the variable is unset the value is None and
# the dashboard shows "Result Directory not set." instead of a table.
# For local testing you can hard-code a value after this line, e.g.:
#     RESULT_DIR = "generic_result_dir"
RESULT_DIR = os.getenv("MOECAP_RESULT_DIR")
| import gradio as gr | |
| import pandas as pd | |
| from datasets import load_dataset | |
| import plotly.graph_objects as go | |
def f2(x):
    """Round numeric input to two decimals; pass anything else through.

    ints and floats (bool included, since it subclasses int) come back as a
    float rounded to 2 places. Any other value - None, strings, ... - is
    returned unchanged.
    """
    if not isinstance(x, (int, float)):
        return x
    return round(float(x), 2)
def normalize(val, vmin, vmax, baseline=20):
    """Linearly rescale ``val`` from ``[vmin, vmax]`` onto ``[baseline, 100]``.

    ``vmin`` maps to ``baseline`` and ``vmax`` maps to 100. A degenerate range
    (``vmax == vmin``) has no spread to scale by, so the fixed score
    ``baseline + 40`` is returned instead.
    """
    if vmax == vmin:
        return baseline + 40
    fraction = (val - vmin) / (vmax - vmin)
    return baseline + fraction * (100 - baseline)
def normalize_cost(val, max_tick, baseline=20):
    """Score a cost on ``[baseline, 100]`` where a lower cost scores higher.

    A cost of 0 scores 100; a cost at or above ``max_tick`` scores
    ``baseline`` (values above ``max_tick`` are capped). When ``max_tick`` is
    0 there is nothing to scale against and ``baseline + 40`` is returned.
    """
    if max_tick == 0:
        return baseline + 40
    capped = min(val, max_tick)
    headroom = (max_tick - capped) / max_tick
    return baseline + headroom * (100 - baseline)
def _radar_message_figure(text: str, font_size: int, font_color: str, layout: dict) -> go.Figure:
    """Return an axis-less figure that shows ``text`` centred in the plot area."""
    fig = go.Figure()
    fig.add_annotation(
        text=text,
        xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False,
        font=dict(size=font_size, color=font_color),
        xanchor='center', yanchor='middle',
    )
    fig.update_layout(xaxis=dict(visible=False), yaxis=dict(visible=False), **layout)
    return fig


def _radar_metric(value) -> float:
    """Coerce one table cell to float; placeholders and unparseable cells become 0.0."""
    if value is None or value in ('-', ''):
        return 0.0
    try:
        return float(value)
    except (TypeError, ValueError):
        return 0.0


def generate_radar_plot(selected_rows_data: List[dict]) -> go.Figure:
    """Generate a CAP radar plot (throughput, cost, accuracy axes) from selected rows.

    Args:
        selected_rows_data: Row dicts as produced for the results table. The
            keys read here are 'Model', 'Method', 'Dataset', 'Accuracy(%)',
            'Cost($)' and 'Decoding T/s'.

    Returns:
        A figure with one polar trace per row. If no rows are given, more
        than three rows are given, or the rows span several datasets, the
        figure contains only an explanatory message instead.
    """
    layout_settings = dict(
        height=750,
        autosize=True,
        margin=dict(t=80, b=100, l=80, r=80),
        paper_bgcolor='white',
        plot_bgcolor='white',
    )

    if not selected_rows_data:
        return _radar_message_figure(
            "Please select 1-3 rows from the table to generate radar plot",
            16, "black", layout_settings,
        )
    if len(selected_rows_data) > 3:
        return _radar_message_figure(
            "Error: Please select no more than 3 rows!", 18, "red", layout_settings
        )

    datasets = [row.get('Dataset', '') for row in selected_rows_data]
    if len(set(datasets)) > 1:
        return _radar_message_figure(
            "Error: Please select rows from the same dataset!", 18, "red", layout_settings
        )
    dataset_name = datasets[0]

    data = {}
    for row in selected_rows_data:
        model_name = row.get('Model', 'Unknown')
        # The table stores linked models as "<a href=...>name</a>"; keep only
        # the text between the first '>' and the following '<'.
        if isinstance(model_name, str) and 'href' in model_name:
            pieces = model_name.split('>', 1)
            if len(pieces) == 2:
                model_name = pieces[1].split('<', 1)[0]

        if isinstance(model_name, str) and '/' in model_name:
            legend_name = model_name.split('/')[-1]  # drop the "org/" prefix
        else:
            legend_name = str(model_name)
        method = row.get('Method', '')
        if method and method not in ['Unknown', '-', '']:
            legend_name = f"{legend_name}-{method}"

        # Two selected rows can share model and method (for example they may
        # differ only in precision). Disambiguate the legend entry so the
        # later row does not silently replace the earlier one.
        base_name, suffix = legend_name, 2
        while legend_name in data:
            legend_name = f"{base_name} ({suffix})"
            suffix += 1

        # Each metric is parsed on its own so one bad cell cannot zero the rest.
        acc = _radar_metric(row.get('Accuracy(%)', 0))
        cost = _radar_metric(row.get('Cost($)', 0))
        throughput = _radar_metric(row.get('Decoding T/s', 0))
        data[legend_name] = {
            # NOTE(review): values <= 1 are treated as already being a
            # fraction, so a genuine accuracy below 1% is plotted as if it
            # were that fraction of 100% - confirm this heuristic is wanted.
            'accuracy': acc / 100.0 if acc > 1 else acc,
            'cost': cost,
            'throughput': throughput,
        }

    throughputs = [v['throughput'] for v in data.values()]
    costs = [v['cost'] for v in data.values()]
    accs = [v['accuracy'] for v in data.values()]
    tp_min, tp_max = min(throughputs), max(throughputs)
    cost_max = max(costs)
    # The accuracy axis always tops out at 1.0 (100%), not at the best row.
    acc_min, acc_max = min(accs), 1.0

    baseline = 20
    # The first category is repeated so each polygon closes on itself.
    categories = ['Throughput (T/s)', 'Cost ($)', 'Accuracy', 'Throughput (T/s)']
    fig = go.Figure()
    for system, values in data.items():
        norm_vals = [
            normalize(values['throughput'], tp_min, tp_max, baseline),
            normalize_cost(values['cost'], cost_max, baseline),
            normalize(values['accuracy'], acc_min, acc_max, baseline),
        ]
        norm_vals.append(norm_vals[0])
        throughput_label = f"Throughput: {values['throughput']:.2f} T/s"
        hovertext = [
            throughput_label,
            f"Cost: ${values['cost']:.2f}",
            f"Accuracy: {values['accuracy']*100:.2f}%",
            throughput_label,
        ]
        fig.add_trace(go.Scatterpolar(
            r=norm_vals,
            theta=categories,
            fill='toself',
            name=system,
            text=hovertext,
            hoverinfo='text+name',
            line=dict(width=2),
        ))

    fig.update_layout(
        title=dict(text=f"CAP Radar Plot: {dataset_name}", x=0.5, xanchor='center', font=dict(size=20, color="black")),
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 100],
                tickfont=dict(size=12, color="black"),
                gridcolor='lightgray',
                linecolor='gray',
                showline=True,
            ),
            angularaxis=dict(
                tickfont=dict(size=14, color="black"),
                rotation=90,
                direction='clockwise',
                gridcolor='lightgray',
                linecolor='gray',
                showline=True,
            ),
            bgcolor="white",
        ),
        legend=dict(orientation='h', yanchor='bottom', y=-0.15, xanchor='center', x=0.5, font=dict(size=13, color="black")),
        **layout_settings
    )
    return fig
def json_to_row(path: str, metrics: dict) -> dict:
    """Convert one metrics record into a row dict for the results table.

    Args:
        path: Identifier of the record's origin. Accepted for interface
            compatibility; it is not used when building the row.
        metrics: Parsed JSON metrics. Keys read: model_name, dataset, method,
            precision, model_type, e2e_s, batch_size, gpu_type, cost,
            exact_match, correct, total, decoding_throughput, prefill_tp,
            prefill_smbu, prefill_smfu, decoding_smbu, decoding_smfu, ttft,
            tpot.

    Returns:
        A dict keyed by table column header. Accuracy is ``correct / total``
        when both are numeric and ``total > 0``, otherwise ``exact_match``.
        Accuracy and the S-MBU/S-MFU values are multiplied by 100 and rounded
        to two decimals; non-numeric values become None.
    """
    import html  # local import so this block does not depend on file-level imports

    def pct(x):
        return round(x * 100, 2) if isinstance(x, (int, float)) else None

    model_name = metrics.get("model_name") or "unknown-model"

    correct = metrics.get("correct")
    total = metrics.get("total")
    if isinstance(correct, (int, float)) and isinstance(total, (int, float)) and total > 0:
        acc = correct / total
    else:
        acc = metrics.get("exact_match")

    if isinstance(model_name, str):
        # The table is rendered with escape=False so the link below works,
        # which means the name (taken from an uploaded JSON file) must be
        # escaped here. Otherwise quotes or tags in it could break out of the
        # href attribute or inject markup into the page.
        safe_name = html.escape(model_name, quote=True)
        if "/" in model_name:
            hf_url = f"https://huggingface.co/{safe_name}"
            model_cell = f"<a href='{hf_url}' target='_blank' style='color: #0366d6; text-decoration: none;'>{safe_name}</a>"
        else:
            model_cell = safe_name
    else:
        model_cell = model_name

    return {
        "Model": model_cell,
        "Dataset": metrics.get("dataset", "Unknown"),
        "Method": metrics.get("method", "Unknown"),
        "Model type": metrics.get("model_type", "Unknown"),
        "Precision": metrics.get("precision", "Unknown"),
        "E2E(s)": f2(metrics.get("e2e_s", None)),
        "GPU": metrics.get("gpu_type", ""),
        "Accuracy(%)": pct(acc),
        "Cost($)": metrics.get("cost", None),
        "Decoding T/s": f2(metrics.get("decoding_throughput")),
        "Prefill T/s": f2(metrics.get("prefill_tp")),
        "Prefill<br>S-MBU(%)": pct(metrics.get("prefill_smbu")),
        "Prefill<br>S-MFU(%)": pct(metrics.get("prefill_smfu")),
        "Decoding<br>S-MBU(%)": pct(metrics.get("decoding_smbu")),
        "Decoding<br>S-MFU(%)": pct(metrics.get("decoding_smfu")),
        "TTFT(s)": f2(metrics.get("ttft")),
        "TPOT(s)": f2(metrics.get("tpot")),
        "Batch size": metrics.get("batch_size", None),
    }
def load_from_dir(dir_path: str, selected_tasks=None, selected_frameworks=None, selected_model_types=None, selected_precisions=None, search_keyword="", force_refresh=False):
    """Load result JSON files, filter them, and render the results table.

    Args:
        dir_path: Dataset location, interpolated into
            ``hf://datasets/<dir_path>/**/*.json``.
        selected_tasks: Values to keep in the "Dataset" column.
        selected_frameworks: Values to keep in the "Method" column.
        selected_model_types: Values to keep in the "Model type" column.
        selected_precisions: Values to keep in the "Precision" column.
        search_keyword: Case-insensitive literal substring; a row is kept if
            any of its cells contains it.
        force_refresh: When true, ``download_mode="force_redownload"`` is
            passed to ``load_dataset``.

    All column filters compare case-insensitively. An empty or None filter
    applies no restriction on that column.

    Returns:
        ``(html, records)``: the table (or a short message) as HTML, and the
        visible rows as a list of dicts without the 'Row #' column. On any
        failure or empty result ``records`` is ``[]``.
    """
    import html  # local import so this block does not depend on file-level imports

    if not dir_path:
        return "<p style='color:black'>Result Directory not set.</p>", []

    try:
        pattern = f"hf://datasets/{dir_path}/**/*.json"
        dl_mode = "force_redownload" if force_refresh else None
        print(f"Fetching from {pattern} (mode={dl_mode})...")
        ds = load_dataset("json", data_files={"train": pattern}, split="train", download_mode=dl_mode)
    except Exception as e:
        # UI boundary: any fetch or parse failure is reported as a message.
        print(f"Error loading dataset: {e}")
        return "<p style='color:black'>No files loaded or Dataset not found.</p>", []

    rows = []
    for i, example in enumerate(ds):
        # A record may nest its metrics under "metrics" or "json"; otherwise
        # the record itself is used.
        metrics = example.get("metrics") or example.get("json") or example
        rows.append(json_to_row(f"{dir_path}#{i}", metrics))
    if not rows:
        return "<p style='color:black'>No records found.</p>", []

    df = pd.DataFrame(rows)

    # If a filter is provided, keep ONLY the rows whose column value is in the
    # selected list.
    column_filters = (
        ("Dataset", selected_tasks),
        ("Method", selected_frameworks),
        ("Model type", selected_model_types),
        ("Precision", selected_precisions),
    )
    for column, selected in column_filters:
        if selected:
            wanted = {str(x).lower() for x in selected}
            df = df[df[column].astype(str).str.lower().isin(wanted)]

    keyword = (search_keyword or "").strip().lower()
    if keyword and not df.empty:
        haystack = df.astype(str)
        # Search the visible model name, not the anchor markup around it;
        # otherwise words like "href" or "style" match every linked row.
        haystack["Model"] = haystack["Model"].str.replace(r"<[^>]+>", "", regex=True)
        # regex=False: the user's text is a literal. With the default regex
        # matching, input such as "(" raises re.error.
        hits = haystack.apply(lambda col: col.str.lower().str.contains(keyword, regex=False))
        df = df[hits.any(axis=1)]

    if df.empty:
        return "<p style='color:black'>No records found.</p>", []

    df = df.fillna("-")
    df.insert(0, 'Row #', range(len(df)))

    # to_html runs with escape=False so the Model links and the <br> in the
    # headers render. Every other text cell comes from uploaded JSON, so it
    # is escaped in a display-only copy; the returned records stay raw.
    display_df = df.copy()
    for column in display_df.columns:
        if column == "Model" or pd.api.types.is_numeric_dtype(display_df[column]):
            continue
        display_df[column] = display_df[column].map(
            lambda cell: html.escape(cell) if isinstance(cell, str) else cell
        )
    table_html = f'<div class="table-container">{display_df.to_html(escape=False, index=False, classes="metrics-table")}</div>'

    df_without_rownum = df.drop('Row #', axis=1)
    return table_html, df_without_rownum.to_dict('records')
def auto_refresh_from_dir(dir_path, tasks, frameworks, types, precisions, search):
    """Reload the table with the given filters, always passing force_refresh=True.

    Returns whatever ``load_from_dir`` returns for the same filters.
    """
    return load_from_dir(
        dir_path,
        selected_tasks=tasks,
        selected_frameworks=frameworks,
        selected_model_types=types,
        selected_precisions=precisions,
        search_keyword=search,
        force_refresh=True,
    )
def parse_and_generate_plot(df_data, indices_str):
    """Turn a comma-separated list of row numbers into a radar plot.

    Args:
        df_data: The currently displayed table rows (list of dicts), indexed
            by the 'Row #' shown in the table. None is treated as empty.
        indices_str: Text such as "0,1,2". Blank items are skipped, only the
            first three numbers are used, and out-of-range numbers are
            ignored.

    Returns:
        The figure from ``generate_radar_plot`` for the selected rows. Blank
        or non-numeric input yields the figure for an empty selection.
    """
    if not indices_str or not indices_str.strip():
        return generate_radar_plot([])

    try:
        indices = [int(token.strip()) for token in indices_str.split(',') if token.strip()][:3]
    except ValueError:
        # Non-numeric entry: fall back to the "please select rows" figure.
        return generate_radar_plot([])

    rows = df_data or []
    selected_rows = [rows[i] for i in indices if 0 <= i < len(rows)]
    return generate_radar_plot(selected_rows)
def build_app() -> gr.Blocks:
    """Assemble the MoE-CAP dashboard and return the Blocks app without launching it.

    Layout: a left column with the search box, the filter checkboxes and an
    "about" accordion; a right column with the results table and the radar
    plot controls. The table is reloaded on page load, on every search or
    filter change, and on a 60-second timer.
    """
    # NOTE(review): the emoji-like prefixes in the labels below ("π", "βοΈ",
    # "π―", ...) look like mis-decoded UTF-8 emoji. They are reproduced
    # exactly as found; restore the intended characters from the original
    # file if it is available.
    # NOTE(review): this file's indentation was lost, so the nesting of the
    # accordion (placed in the left column) and of the plot (placed inside
    # the radar group) is reconstructed - confirm against the original.

    # CSS that overrides Gradio's theme variables to force a light appearance.
    light_theme_css = """
    /* 1. FORCE LIGHT VARIABLES GLOBALLY */
    :root, .gradio-container, body {
        --body-background-fill: #f5f7fa !important;
        --body-text-color: #374151 !important;
        --background-fill-primary: #ffffff !important;
        --background-fill-secondary: #f3f4f6 !important;
        --border-color-primary: #e5e7eb !important;
        --block-background-fill: #ffffff !important;
        --block-label-text-color: #374151 !important;
        --block-title-text-color: #1f2937 !important;
        --input-background-fill: #ffffff !important;
        --color-accent: #0366d6 !important;
        /* Reset dark mode specific variables to light values */
        --neutral-50: #f9fafb; --neutral-100: #f3f4f6; --neutral-200: #e5e7eb;
        --neutral-300: #d1d5da; --neutral-400: #9ca3af; --neutral-500: #6b7280;
        --neutral-600: #4b5563; --neutral-700: #374151; --neutral-800: #1f2937;
    }
    /* 2. RESET STANDARD CONTAINERS */
    .gradio-container .block,
    .gradio-container .panel,
    .gradio-container .form {
        background-color: white !important;
        border-color: #e1e4e8 !important;
    }
    /* 3. SPECIFIC FIX FOR THE DARK "FILTERS" and "RADAR" SECTIONS */
    .filter-section {
        background-color: #ffffff !important;
        border: 2px solid #e1e4e8 !important;
        border-radius: 8px !important;
        padding: 16px !important;
        box-shadow: 0 2px 4px rgba(0,0,0,0.05) !important;
        color: #24292e !important; /* Set default text color for the section */
    }
    /* Remove background color from text elements to prevent "dark blocks" */
    .filter-section label,
    .filter-section span,
    .filter-section p {
        background-color: transparent !important;
    }
    /* 4. BUTTON FIXES - TARGET BY ID FOR SPECIFICITY */
    #gen_btn {
        background-color: #0366d6 !important;
        color: white !important;
        border: none !important;
    }
    #gen_btn:hover {
        opacity: 0.9;
    }
    /* 5. INPUTS & CHECKBOXES */
    /* Re-apply white background to inputs specifically */
    .filter-section input,
    .filter-section textarea,
    .filter-section select {
        background-color: #ffffff !important;
        border: 1px solid #d1d5da !important;
        color: #24292e !important;
    }
    /* --- FIX FOR CHECKBOXES --- */
    /* Use explicit styling for the checked state to ensure visibility */
    .filter-section input[type="checkbox"] {
        appearance: none !important;
        -webkit-appearance: none !important;
        width: 16px !important;
        height: 16px !important;
        background-color: white !important;
        border: 1px solid #d1d5da !important;
        border-radius: 3px !important;
        position: relative !important;
        cursor: pointer !important;
    }
    .filter-section input[type="checkbox"]:checked {
        background-color: #0366d6 !important;
        border-color: #0366d6 !important;
        /* Draw the checkmark using an SVG data URI */
        background-image: url("data:image/svg+xml,%3csvg viewBox='0 0 16 16' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3cpath d='M12.207 4.793a1 1 0 010 1.414l-5 5a1 1 0 01-1.414 0l-2-2a1 1 0 011.414-1.414L6.5 9.086l4.293-4.293a1 1 0 011.414 0z'/%3e%3c/svg%3e") !important;
        background-size: 100% 100% !important;
        background-position: center !important;
        background-repeat: no-repeat !important;
    }
    .filter-section label span {
        color: #24292e !important;
    }
    /* 6. SEARCH BOX */
    .search-box {
        background: white !important;
        padding: 16px !important;
        border-radius: 6px;
        border: 2px solid #e1e4e8 !important;
        margin-bottom: 16px;
    }
    /* 7. TABLE STYLING */
    .table-container {
        overflow-x: auto;
        max-height: 75vh;
        border: 2px solid #e1e4e8;
        border-radius: 6px;
        background: white !important;
    }
    table.metrics-table {
        width: 100%; border-collapse: collapse; background: white !important;
    }
    table.metrics-table th, table.metrics-table td {
        padding: 10px 14px; border: 1px solid #e1e4e8;
        white-space: nowrap; font-size: 13px; color: #24292e !important;
    }
    table.metrics-table th {
        background: #f6f8fa !important; font-weight: 600; position: sticky; top: 0;
    }
    .metrics-table th:first-child, .metrics-table td:first-child {
        background-color: #f0f0f0 !important; text-align: center;
    }
    /* 8. PLOT CONTAINER - FORCE WHITE BACKGROUND */
    .plot-container {
        width: 100% !important;
        background-color: white !important;
    }
    .plot-container > div, .plot-container .plotly {
        background-color: white !important;
    }
    /* 9. LINKS */
    a { color: #0366d6 !important; text-decoration: none; }
    a:hover { text-decoration: underline; }
    """

    # (label, value) pairs; every task starts checked.
    task_choices = [("GSM8K", "gsm8k"), ("LongBench", "longbench"), ("MMLU", "mmlu"), ("NuminaMath", "numinamath"), ("RULER", "ruler")]

    with gr.Blocks(title="MoE-CAP Dashboard", css=light_theme_css, theme=gr.themes.Default()) as demo:
        gr.Markdown("# MoE-CAP Dashboard")

        with gr.Row():
            # Left sidebar: search, filters, help text.
            with gr.Column(scale=2):
                with gr.Group(elem_classes="search-box"):
                    search_input = gr.Textbox(label="π Search", placeholder="Search...", lines=1)

                with gr.Group(elem_classes="filter-section"):
                    gr.Markdown("### ποΈ Filters")
                    # Carries the module-level RESULT_DIR into the event handlers.
                    dir_path = gr.State(RESULT_DIR)
                    task_filter = gr.CheckboxGroup(
                        label="π Tasks",
                        choices=task_choices,
                        value=[value for _label, value in task_choices],
                    )
                    framework_filter = gr.CheckboxGroup(
                        label="βοΈ Frameworks",
                        choices=["sglang", "vllm"],
                        value=["sglang", "vllm"],
                    )
                    model_type_filter = gr.CheckboxGroup(
                        label="π€ Model Types",
                        choices=["instruct", "thinking"],
                        value=["instruct", "thinking"],
                    )
                    precision_filter = gr.CheckboxGroup(
                        label="π― Precision",
                        choices=["bfloat16", "fp8"],
                        value=["bfloat16", "fp8"],
                    )

                with gr.Accordion("π About Tasks & Metrics", open=True):
                    gr.Markdown(
                        "### Tasks\n- **GSM8K**, **LongBench**, **MMLU**, **NuminaMath**, **RULER**\n\n"
                        "### Metrics\n- **E2E(s)**: Latency | **Cost($)** | **T/s**: Throughput | **S-MBU/MFU**: Utilization",
                        elem_classes="info-section"
                    )

            # Right main content: results table and radar plot.
            with gr.Column(scale=5):
                leaderboard_output = gr.HTML(label="π Results")

                with gr.Group(elem_classes="filter-section"):
                    gr.Markdown("### π CAP Radar Plot")
                    gr.Markdown("**How to use:** Look at the 'Row #' column in the table. Enter row numbers (e.g., 0,1,2) and click Generate.")
                    with gr.Row():
                        row_indices_input = gr.Textbox(label="Row Numbers", placeholder="0,1,2", scale=3)
                        # elem_id lets the "#gen_btn" CSS rules target this button.
                        generate_btn = gr.Button("π― Generate", variant="primary", scale=1, elem_id="gen_btn")
                    radar_plot = gr.Plot(value=generate_radar_plot([]), elem_classes="plot-container")

        # State & events. df_data_state holds the rows currently shown, so the
        # radar plot can look rows up by their 'Row #'.
        df_data_state = gr.State([])
        filter_inputs = [dir_path, task_filter, framework_filter, model_type_filter, precision_filter, search_input]
        table_outputs = [leaderboard_output, df_data_state]

        # Initial page load uses the force-refresh variant.
        demo.load(fn=auto_refresh_from_dir, inputs=filter_inputs, outputs=table_outputs)

        # Any change to the search text or a filter re-renders the table.
        for control in (search_input, task_filter, framework_filter, model_type_filter, precision_filter):
            control.change(fn=load_from_dir, inputs=filter_inputs, outputs=table_outputs)

        generate_btn.click(fn=parse_and_generate_plot, inputs=[df_data_state, row_indices_input], outputs=[radar_plot])

        # Periodic refresh every 60 seconds.
        refresh_timer = gr.Timer(60.0)
        refresh_timer.tick(fn=auto_refresh_from_dir, inputs=filter_inputs, outputs=table_outputs)

    return demo
# Script entry point: build the dashboard and start the Gradio server with
# default launch settings. Nothing is launched when the module is imported.
if __name__ == "__main__":
    app = build_app()
    app.launch()