Spaces:
Running
Running
| #!/usr/bin/env python | |
| import os | |
| import json | |
| from typing import List, Tuple | |
# Pin the Gradio UI language to English. This is set before `gradio` is
# imported further down (presumably so Gradio sees it at import time).
os.environ.update(GRADIO_LANGUAGE="en")

# Hugging Face dataset repo ID that holds the result JSON files; the app
# refuses to start without it.
RESULT_DIR = os.getenv("MOECAP_RESULT_DIR")
if RESULT_DIR is None or RESULT_DIR == "":
    raise RuntimeError(
        "MOECAP_RESULT_DIR is not set. Please set MOECAP_RESULT_DIR (HF Repo ID) before running app.py"
    )
| import gradio as gr | |
| import pandas as pd | |
| from datasets import load_dataset | |
| import plotly.graph_objects as go | |
def f2(x):
    """Round numeric values to two decimals; pass every other value through.

    ``int`` inputs (``bool`` included, since it subclasses ``int``) are
    converted to ``float`` first, so ``f2(3)`` returns ``3.0``. Non-numeric
    inputs such as ``None`` or strings are returned unchanged.
    """
    if not isinstance(x, (int, float)):
        return x
    return round(float(x), 2)
def normalize(val, vmin, vmax, baseline=20):
    """Linearly map ``val`` from ``[vmin, vmax]`` onto ``[baseline, 100]``.

    Higher values score higher: ``vmin`` maps to ``baseline`` and ``vmax`` to
    100. For a degenerate range (``vmin == vmax``) the fixed score
    ``baseline + 40`` is returned (60 with the default baseline).
    """
    if vmax == vmin:
        return baseline + 40
    fraction = (val - vmin) / (vmax - vmin)
    return baseline + fraction * (100 - baseline)
def normalize_reversed(val, vmin, vmax, baseline=20):
    """Map ``val`` onto ``[baseline, 100]`` where *smaller* values score higher.

    ``vmin`` maps to 100 and ``vmax`` maps to ``baseline``. A degenerate
    range (``vmin == vmax``) yields the fixed score ``baseline + 40``.
    """
    if vmin == vmax:
        return baseline + 40
    headroom = (vmax - val) / (vmax - vmin)
    return baseline + headroom * (100 - baseline)
def normalize_cost(val, max_tick, baseline=20):
    """Score a cost on ``[baseline, 100]`` where cheaper is better.

    A cost of 0 scores 100. Costs at or above ``max_tick`` are clipped to
    ``max_tick`` and score ``baseline``. When ``max_tick`` is 0 the fixed
    score ``baseline + 40`` is returned.
    """
    if max_tick == 0:
        return baseline + 40
    # Same selection rule as min(val, max_tick): keep val unless max_tick is smaller.
    clipped = max_tick if max_tick < val else val
    savings = (max_tick - clipped) / max_tick
    return baseline + savings * (100 - baseline)
def _message_figure(text: str, font: dict) -> go.Figure:
    """Return an axis-less, white 800x600 figure that only shows ``text`` centred."""
    fig = go.Figure()
    fig.add_annotation(
        text=text,
        xref="paper", yref="paper",
        x=0.5, y=0.5, showarrow=False,
        font=font,
        xanchor='center',
        yanchor='middle',
    )
    fig.update_layout(
        height=600,
        width=800,
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        paper_bgcolor='white',
        plot_bgcolor='white',
        margin=dict(t=40, b=40, l=40, r=40),
    )
    return fig


def _strip_model_link(cell):
    """Return the link text of an ``<a href=...>text</a>`` cell; other values unchanged."""
    if isinstance(cell, str) and 'href' in cell:
        try:
            return cell.split('>', 1)[1].split('<', 1)[0]
        except IndexError:
            return cell
    return cell


def _cell_to_float(value) -> float:
    """Parse a numeric table cell; ``None``, ``'-'``, ``''`` or unparseable values give 0.0."""
    if value in (None, '-', ''):
        return 0.0
    try:
        return float(value)
    except (TypeError, ValueError):
        return 0.0


def generate_radar_plot(selected_rows_data: List[dict]) -> go.Figure:
    """Generate a CAP radar plot (throughput / cost / accuracy) from table rows.

    Args:
        selected_rows_data: Leaderboard row dicts. The keys read are ``'Model'``,
            ``'Method'``, ``'Dataset'``, ``'Accuracy(%)'`` (a percentage),
            ``'Cost($)'`` and ``'Decoding T/s'``. Missing or non-numeric metric
            cells are treated as 0.

    Returns:
        A Plotly figure with one closed polygon per row, each axis normalised to
        ``[20, 100]``. If no rows are given, more than 3 rows are given, or the
        rows come from different datasets, a placeholder figure with an
        explanatory message is returned instead.
    """
    if not selected_rows_data:
        return _message_figure(
            "Please select 1-3 rows from the table to generate radar plot",
            dict(size=16),
        )
    if len(selected_rows_data) > 3:
        return _message_figure(
            "Error: Please select no more than 3 rows!",
            dict(size=18, color="red"),
        )
    datasets = [row.get('Dataset', '') for row in selected_rows_data]
    if len(set(datasets)) > 1:
        return _message_figure(
            "Error: Please select rows from the same dataset!",
            dict(size=18, color="red"),
        )
    dataset_name = datasets[0]

    # Raw metrics per trace, keyed by legend label.
    data = {}
    for row in selected_rows_data:
        model_name = _strip_model_link(row.get('Model', 'Unknown'))
        # Legend label: model name without its "org/" prefix, plus "-<method>"
        # when a real method is recorded.
        if isinstance(model_name, str) and '/' in model_name:
            base_name = model_name.split('/')[-1]
        else:
            base_name = str(model_name)
        method = row.get('Method', '')
        if method and method not in ['Unknown', '-', '']:
            base_name = f"{base_name}-{method}"
        # Rows can share model and method (e.g. different precision or GPU);
        # number duplicates so one trace does not silently overwrite another.
        legend_name = base_name
        dup = 2
        while legend_name in data:
            legend_name = f"{base_name} ({dup})"
            dup += 1
        data[legend_name] = {
            # 'Accuracy(%)' always holds a percentage (0-100), so divide
            # unconditionally; a heuristic like "> 1 means percent" would read
            # a 1% accuracy as 100%.
            'accuracy': _cell_to_float(row.get('Accuracy(%)', 0)) / 100.0,
            # NOTE(review): a missing cost parses as 0 and therefore gets the
            # best cost score; confirm this is acceptable.
            'cost': _cell_to_float(row.get('Cost($)', 0)),
            'throughput': _cell_to_float(row.get('Decoding T/s', 0)),
        }

    # Per-axis ranges for normalisation (data is non-empty at this point).
    throughputs = [v['throughput'] for v in data.values()]
    tp_min, tp_max = min(throughputs), max(throughputs)
    cost_max = max(v['cost'] for v in data.values())
    # Accuracy is scaled between the worst selected row and a perfect 100%.
    acc_min, acc_max = min(v['accuracy'] for v in data.values()), 1.0
    baseline = 20

    # The first axis is repeated at the end so each polygon closes.
    categories = ['Throughput (T/s)', 'Cost ($)', 'Accuracy', 'Throughput (T/s)']
    fig = go.Figure()
    for system, values in data.items():
        norm_vals = [
            normalize(values['throughput'], tp_min, tp_max, baseline),
            normalize_cost(values['cost'], cost_max, baseline),
            normalize(values['accuracy'], acc_min, acc_max, baseline),
        ]
        norm_vals.append(norm_vals[0])  # close the loop
        hovertext = [
            f"Throughput: {values['throughput']:.2f} T/s",
            f"Cost: ${values['cost']:.2f}",
            f"Accuracy: {values['accuracy']*100:.2f}%",
            f"Throughput: {values['throughput']:.2f} T/s",
        ]
        fig.add_trace(go.Scatterpolar(
            r=norm_vals,
            theta=categories,
            fill='toself',
            name=system,
            text=hovertext,
            hoverinfo='text+name',
            line=dict(width=2),
        ))
    fig.update_layout(
        title=dict(
            text=f"CAP Radar Plot: {dataset_name}",
            x=0.5,
            xanchor='center',
            font=dict(size=18),
        ),
        polar=dict(
            radialaxis=dict(visible=True, range=[0, 100], tickfont=dict(size=11)),
            angularaxis=dict(
                tickfont=dict(size=13),
                rotation=30,
                direction='clockwise',
            ),
        ),
        legend=dict(
            orientation='h',
            yanchor='bottom',
            y=-0.15,
            xanchor='center',
            x=0.5,
            font=dict(size=12),
        ),
        margin=dict(t=80, b=100, l=80, r=80),
        height=650,
        width=800,
        paper_bgcolor='white',
        plot_bgcolor='white',
    )
    return fig
def json_to_row(path: str, metrics: dict) -> dict:
    """Convert one result-JSON ``metrics`` dict into a leaderboard table row.

    Args:
        path: Identifier of the record's source (a file path or
            ``"<repo>#<index>"``). Not used by this function; kept for
            interface compatibility.
        metrics: Parsed result JSON. Missing keys produce ``None`` or
            ``"Unknown"`` placeholders rather than errors.

    Returns:
        A dict mapping display column names to cell values.
        - Accuracy and S-MBU/S-MFU fractions become percentages rounded to 2 decimals.
        - Timing and throughput values are rounded to 2 decimals via ``f2``.
        - When the model name looks like ``org/name``, ``"Model"`` is an HTML
          link to its Hugging Face page.
    """
    import html  # local import keeps this block self-contained

    model_name = metrics.get("model_name") or "unknown-model"
    dataset = metrics.get("dataset", "Unknown")
    method = metrics.get("method", "Unknown")
    precision = metrics.get("precision", "Unknown")
    model_type = metrics.get("model_type", "Unknown")
    e2e_s = metrics.get("e2e_s", None)
    batch_size = metrics.get("batch_size", None)
    gpu_type = metrics.get("gpu_type", "")
    cost = metrics.get("cost", None)

    # Prefer correct/total when both are numeric and total > 0; otherwise fall
    # back to the reported exact_match value.
    correct = metrics.get("correct")
    total = metrics.get("total")
    if isinstance(correct, (int, float)) and isinstance(total, (int, float)) and total > 0:
        acc = correct / total
    else:
        acc = metrics.get("exact_match")

    def pct(x):
        """Fraction -> percentage rounded to 2 decimals; non-numbers become None."""
        return round(x * 100, 2) if isinstance(x, (int, float)) else None

    if isinstance(model_name, str) and "/" in model_name:
        # The name comes from result files (possibly user uploads) and the table
        # is rendered with escape=False, so escape it before building HTML.
        safe_name = html.escape(model_name, quote=True)
        hf_url = f"https://huggingface.co/{safe_name}"
        model_cell = f"<a href='{hf_url}' target='_blank'>{safe_name}</a>"
    else:
        model_cell = model_name

    row = {
        "Model": model_cell,
        "Dataset": dataset,
        "Method": method,
        "Model type": model_type,
        "Precision": precision,
        "E2E(s)": f2(e2e_s),
        "GPU": gpu_type,
        "Accuracy(%)": pct(acc),
        "Cost($)": cost,
        "Decoding T/s": f2(metrics.get("decoding_throughput")),
        "Prefill T/s": f2(metrics.get("prefill_tp")),
        "Prefill<br>S-MBU(%)": pct(metrics.get("prefill_smbu")),
        "Prefill<br>S-MFU(%)": pct(metrics.get("prefill_smfu")),
        "Decoding<br>S-MBU(%)": pct(metrics.get("decoding_smbu")),
        "Decoding<br>S-MFU(%)": pct(metrics.get("decoding_smfu")),
        "TTFT(s)": f2(metrics.get("ttft")),
        "TPOT(s)": f2(metrics.get("tpot")),
        "Batch size": batch_size,  # kept as the last column
    }
    return row
def _models_markdown(model_cells) -> str:
    """Return a comma-separated Markdown list of the distinct models in ``model_cells``.

    Cells holding an ``<a href=...>name</a>`` link are reduced to their link
    text. Names of the form ``org/model`` become Hugging Face links; other
    names are rendered as plain text.
    """
    names = set()
    for cell in model_cells:
        name = cell
        if isinstance(cell, str) and "href" in cell:
            try:
                name = cell.split(">", 1)[1].split("<", 1)[0]
            except IndexError:
                name = cell
        names.add(name)
    links = []
    for name in sorted(names):
        if isinstance(name, str) and "/" in name:
            links.append(f"[{name}](https://huggingface.co/{name})")
        else:
            links.append(str(name))
    return ", ".join(links)


def build_leaderboard_from_files(files: List[gr.File], prev_rows: list | None = None):
    """Parse uploaded result JSON files and render them together with earlier rows.

    Args:
        files: Uploaded files. Each item may be a Gradio file object exposing
            ``.name`` or a plain path string. ``None``/empty means no new files.
        prev_rows: Rows accumulated from earlier uploads; not mutated.

    Returns:
        ``(summary_markdown, table_html, all_rows)``, where ``all_rows`` is
        ``prev_rows`` followed by the newly parsed rows. Files that cannot be
        read or parsed are skipped, and a warning is printed for each.
    """
    all_rows = list(prev_rows) if prev_rows else []
    for f in files or []:
        # Gradio may hand over file objects (with .name) or plain path strings.
        path = getattr(f, "name", f)
        try:
            with open(path, "r", encoding="utf-8") as fp:
                metrics = json.load(fp)
            all_rows.append(json_to_row(path, metrics))
        except Exception as exc:
            # Best effort: one bad upload must not block the others.
            print(f"Skipping result file {path!r}: {exc}")

    if not all_rows:
        empty_html = "<p>No files loaded.</p>"
        return "No files uploaded.", empty_html, []

    df = pd.DataFrame(all_rows)
    models_str = _models_markdown(df["Model"].tolist())
    summary_md = f"**Loaded {len(all_rows)} result files.** \n**Models:** {models_str}"
    # Show missing values as "-" (as load_from_dir does) instead of None/NaN.
    display_df = df.fillna("-")
    table_html = f'<div class="table-container">{display_df.to_html(escape=False, index=False, classes="metrics-table")}</div>'
    return summary_md, table_html, all_rows
| def load_from_dir( | |
| dir_path: str, | |
| selected_tasks: List[str] | None = None, | |
| selected_frameworks: List[str] | None = None, | |
| selected_model_types: List[str] | None = None, | |
| selected_precisions: List[str] | None = None, | |
| search_keyword: str = "", | |
| force_refresh: bool = False, | |
| ): | |
| try: | |
| pattern = f"hf://datasets/{dir_path}/**/*.json" | |
| dl_mode = "force_redownload" if force_refresh else None | |
| print(f"Fetching from {pattern} (mode={dl_mode})...") | |
| ds = load_dataset( | |
| "json", | |
| data_files={"train": pattern}, | |
| split="train", | |
| download_mode=dl_mode, | |
| ) | |
| except Exception as e: | |
| empty_html = "<p>No files loaded or Dataset not found.</p>" | |
| return empty_html | |
| rows = [] | |
| for i, example in enumerate(ds): | |
| if isinstance(example, dict): | |
| metrics = example.get("metrics") or example.get("json") or example | |
| else: | |
| metrics = example | |
| rows.append(json_to_row(f"{dir_path}#{i}", metrics)) | |
| if not rows: | |
| empty_html = "<p>No records found.</p>" | |
| return empty_html | |
| df = pd.DataFrame(rows) | |
| # Dataset filter | |
| if selected_tasks is not None: | |
| lower_selected = [x.lower() for x in selected_tasks] | |
| df = df[df["Dataset"].astype(str).str.lower().isin(lower_selected)] | |
| # Inference framework filter (Method) | |
| if selected_frameworks is not None: | |
| lower_selected = [str(x).lower() for x in selected_frameworks] | |
| df = df[df["Method"].astype(str).str.lower().isin(lower_selected)] | |
| # Model type filter | |
| if selected_model_types is not None: | |
| lower_selected = [str(x).lower() for x in selected_model_types] | |
| df = df[df["Model type"].astype(str).str.lower().isin(lower_selected)] | |
| # Precision filter | |
| if selected_precisions is not None: | |
| lower_selected = [str(x).lower() for x in selected_precisions] | |
| df = df[df["Precision"].astype(str).str.lower().isin(lower_selected)] | |
| # Search keyword filter - search across all columns | |
| if search_keyword and search_keyword.strip(): | |
| keyword_lower = search_keyword.strip().lower() | |
| # Create a mask that checks if the keyword appears in any column | |
| mask = df.astype(str).apply(lambda row: row.str.lower().str.contains(keyword_lower).any(), axis=1) | |
| df = df[mask] | |
| if df.empty: | |
| empty_html = "<p>No records found.</p>" | |
| return empty_html, [] | |
| df = df.fillna("-") | |
| raw_models = set() | |
| for cell in df["Model"].tolist(): | |
| if isinstance(cell, str) and "href" in cell: | |
| try: | |
| name = cell.split(">", 1)[1].split("<", 1)[0] | |
| except Exception: | |
| name = cell | |
| else: | |
| name = cell | |
| raw_models.add(name) | |
| links = [] | |
| for name in sorted(raw_models): | |
| if isinstance(name, str) and "/" in name: | |
| hf_url = f"https://huggingface.co/{name}" | |
| links.append(f"[{name}]({hf_url})") | |
| else: | |
| links.append(str(name)) | |
| models_str = ", ".join(links) | |
| # Insert row number column at the beginning for easy reference | |
| df.insert(0, 'Row #', range(len(df))) | |
| # Create HTML table | |
| table_html = f'<div class="table-container">{df.to_html(escape=False, index=False, classes="metrics-table")}</div>' | |
| df_without_rownum = df.drop('Row #', axis=1) | |
| df_dict = df_without_rownum.to_dict('records') | |
| return table_html, df_dict | |
def auto_refresh_from_dir(
    dir_path: str,
    selected_tasks: List[str] | None = None,
    selected_frameworks: List[str] | None = None,
    selected_model_types: List[str] | None = None,
    selected_precisions: List[str] | None = None,
    search_keyword: str = "",
):
    """Reload the leaderboard with a forced re-download, keeping the current filters.

    Same arguments as ``load_from_dir`` minus ``force_refresh``, which is
    always ``True`` here. Returns whatever ``load_from_dir`` returns.
    """
    current_filters = dict(
        selected_tasks=selected_tasks,
        selected_frameworks=selected_frameworks,
        selected_model_types=selected_model_types,
        selected_precisions=selected_precisions,
        search_keyword=search_keyword,
    )
    return load_from_dir(dir_path, force_refresh=True, **current_filters)
def update_radar_plot(df_data: list, selected_indices: list):
    """Build the radar plot for the rows of ``df_data`` at ``selected_indices``.

    Only the first three indices are used. Indices outside
    ``range(len(df_data))`` are ignored, including negative ones, which would
    otherwise silently select rows from the end of the list. This matches the
    bounds check in ``parse_and_generate_plot``. An empty selection or empty
    data yields the placeholder plot.
    """
    if not selected_indices or not df_data:
        return generate_radar_plot([])
    selected_rows = [
        df_data[i] for i in selected_indices[:3] if 0 <= i < len(df_data)
    ]
    return generate_radar_plot(selected_rows)
def parse_and_generate_plot(df_data: list, indices_str: str):
    """Turn a comma-separated list of row numbers into a radar plot.

    - Blank entries are skipped and only the first three numbers are kept.
    - Numbers outside ``range(len(df_data))`` are then dropped.
    - Empty input, or any entry that is not an integer, yields the
      "select rows" placeholder plot.
    """
    if not indices_str or not indices_str.strip():
        return generate_radar_plot([])
    try:
        tokens = (piece.strip() for piece in indices_str.split(','))
        wanted = [int(tok) for tok in tokens if tok][:3]
        picked = [df_data[idx] for idx in wanted if 0 <= idx < len(df_data)]
        return generate_radar_plot(picked)
    except (ValueError, IndexError):
        return generate_radar_plot([])
def on_table_select(df, evt: gr.SelectData):
    """Return the index carried by a Gradio table-selection event (``df`` is unused)."""
    selected_index = evt.index
    return selected_index
| # Gradio UI | |
def build_app() -> gr.Blocks:
    """Build the MoE-CAP dashboard UI.

    Layout:
    - Left column: a search box, checkbox filters (tasks, inference
      frameworks, model types, precision) and an "About" accordion.
    - Right column: the results table, plus a CAP radar-plot generator driven
      by typed row numbers.

    Wiring:
    - The table is loaded with a forced re-download on page load and every
      60 seconds (``gr.Timer``).
    - It is re-filtered whenever the search box or any filter changes.
    - Every refresh also replaces the row data in ``df_data_state``, which the
      Generate button reads.

    Returns:
        The configured, not yet launched, ``gr.Blocks`` app.
    """
    # NOTE(review): several label/markdown strings below contain characters such
    # as 'π', 'βοΈ' and 'β' that look like mis-decoded emoji / dashes; confirm the
    # file's encoding before editing those strings.
    row_css = """
    /* Force light theme everywhere */
    body {
        background-color: #f5f7fa !important;
    }
    /* Row number column styling */
    .metrics-table th:first-child,
    .metrics-table td:first-child {
        width: 60px !important;
        text-align: center !important;
        padding: 8px !important;
        font-weight: 600 !important;
        background-color: #f0f0f0 !important;
    }
    /* The outer Group container */
    .search-box {
        background: white !important;
        padding: 16px !important;
        border-radius: 6px;
        border: 2px solid #e1e4e8 !important;
        box-shadow: 0 1px 3px rgba(0, 0, 0, 0.06);
        margin-bottom: 16px;
    }
    /* Reset the internal Gradio Block styles so they fit in the white box */
    .search-box .block {
        background: transparent !important;
        border: none !important;
        padding: 0 !important;
    }
    /* Style the Label Text (π Search) */
    .search-box label span {
        color: #24292e !important;
        font-weight: 600;
        font-size: 14px;
        margin-bottom: 8px;
        background: transparent !important;
    }
    /* Style the actual Input Field */
    .search-box input.scroll-hide {
        background-color: white !important;
        color: #24292e !important;
        border: 1.5px solid #e1e4e8 !important;
        border-radius: 4px !important;
        padding: 10px !important;
        box-shadow: none !important;
    }
    /* Fix focus state */
    .search-box input.scroll-hide:focus {
        border-color: #0366d6 !important;
        box-shadow: none !important;
        outline: none !important;
    }
    .gradio-container {
        max-width: 100% !important;
        padding: 20px !important;
        background-color: #f5f7fa !important;
    }
    /* Override all dark backgrounds */
    .gradio-container .block,
    .gradio-container .form,
    .gradio-container fieldset,
    .gradio-container .input-block,
    .gradio-container .wrap,
    .gradio-container .gr-box,
    .gradio-container .gr-form,
    .gradio-container .gr-input {
        background-color: white !important;
        border-color: #e1e4e8 !important;
    }
    .gradio-container label {
        background-color: transparent !important;
        color: #24292e !important;
    }
    /* Remove any potential dark wrappers */
    .gradio-container > div,
    .gradio-container .container {
        background-color: transparent !important;
    }
    /* Force all text to be dark */
    .gradio-container,
    .gradio-container label,
    .gradio-container p,
    .gradio-container span,
    .gradio-container div {
        color: #24292e !important;
    }
    /* Table styling */
    .gradio-container table.metrics-table th,
    .gradio-container table.metrics-table td {
        padding: 10px 14px;
        border: 1.5px solid #e1e4e8;
        white-space: nowrap;
        font-size: 13px;
        text-align: left;
        color: #24292e !important;
    }
    .gradio-container table.metrics-table th {
        background: linear-gradient(to bottom, #fafbfc, #f6f8fa);
        font-weight: 600;
        color: #24292e !important;
        position: sticky;
        top: 0;
        z-index: 10;
        border-bottom: 2px solid #d1d5da;
    }
    .gradio-container table.metrics-table tbody tr:nth-child(even) {
        background-color: #f6f8fa;
    }
    .gradio-container table.metrics-table tbody tr:hover {
        background-color: #e1e4e8;
    }
    .gradio-container table.metrics-table {
        border-collapse: collapse;
        width: 100%;
        background: white;
    }
    .gradio-container table.metrics-table a {
        color: #0366d6 !important;
        text-decoration: none;
    }
    .gradio-container table.metrics-table a:hover {
        color: #0366d6 !important;
        text-decoration: underline;
    }
    /* Scrollable table container */
    .table-container {
        overflow-x: auto;
        overflow-y: auto;
        max-height: 75vh;
        border: 2px solid #e1e4e8;
        border-radius: 6px;
        background: white;
        box-shadow: 0 1px 3px rgba(0, 0, 0, 0.06);
    }
    /* Filter section styling */
    .filter-section {
        background: white !important;
        padding: 0 !important;
        border-radius: 6px;
        border: 2px solid #e1e4e8 !important;
        box-shadow: 0 1px 3px rgba(0, 0, 0, 0.06);
    }
    .filter-section * {
        color: #24292e !important;
    }
    .filter-section .wrap,
    .filter-section .block,
    .filter-section .container,
    .filter-section .group,
    .filter-section > div,
    .filter-section > div > div {
        background: transparent !important;
    }
    .filter-section .wrap {
        padding: 20px !important;
    }
    .filter-section label {
        background: transparent !important;
        color: #24292e !important;
    }
    .filter-section fieldset {
        background: transparent !important;
        border-color: #e1e4e8 !important;
    }
    /* Accordion styling */
    .gradio-container .accordion {
        background: white !important;
        border: 2px solid #e1e4e8 !important;
        border-radius: 6px !important;
        box-shadow: 0 1px 3px rgba(0, 0, 0, 0.06);
    }
    .gradio-container .accordion * {
        color: #24292e !important;
    }
    .gradio-container .accordion label {
        background: transparent !important;
        color: #24292e !important;
    }
    .gradio-container .accordion button {
        background: transparent !important;
        color: #24292e !important;
    }
    /* Info section */
    .info-section {
        padding: 16px;
        background: white !important;
    }
    /* Make text in info section dark */
    .info-section p,
    .info-section li,
    .info-section ul,
    .info-section h3,
    .info-section strong,
    .info-section * {
        color: #24292e !important;
    }
    .info-section a {
        color: #0366d6 !important;
    }
    /* Override any dark backgrounds in groups and accordions */
    .gradio-container .group,
    .gradio-container .accordion,
    .gradio-container .panel {
        background-color: white !important;
    }
    /* Heading styling */
    .gradio-container h1 {
        color: #24292e !important;
        font-weight: 700;
        margin-bottom: 24px;
    }
    .gradio-container h3 {
        color: #24292e !important;
        font-weight: 600;
        margin-bottom: 16px;
    }
    /* Checkbox styling */
    .gradio-container input[type="checkbox"] {
        accent-color: #0366d6 !important;
    }
    """
    # Use Gradio's default (light) theme explicitly
    with gr.Blocks(title="MoE-CAP Dashboard", css=row_css, theme=gr.themes.Default()) as demo:
        gr.Markdown("# MoE-CAP Dashboard")
        with gr.Row():
            # Left side - Filters (narrower)
            with gr.Column(scale=2):
                with gr.Group(elem_classes="search-box"):
                    search_input = gr.Textbox(
                        label="π Search",
                        placeholder="Search across all columns...",
                        lines=1,
                    )
                with gr.Group(elem_classes="filter-section"):
                    gr.Markdown("### ποΈ Filters")
                    dir_path = gr.State(RESULT_DIR)
                    # 1) Tasks filter
                    task_filter = gr.CheckboxGroup(
                        label="π Tasks",
                        choices=[
                            ("GSM8K", "gsm8k"),
                            ("LongBench", "longbench"),
                            ("MMLU", "mmlu"),
                            ("NuminaMath", "numinamath"),
                            ("RULER", "ruler"),
                        ],
                        value=["gsm8k", "longbench", "mmlu", "numinamath", "ruler"],
                    )
                    # 2) Inference frameworks filter
                    framework_filter = gr.CheckboxGroup(
                        label="βοΈ Inference Frameworks",
                        choices=["sglang", "vllm"],
                        value=["sglang", "vllm"],
                    )
                    # 3) Model types filter
                    model_type_filter = gr.CheckboxGroup(
                        label="π€ Model Types",
                        choices=["instruct", "thinking"],
                        value=["instruct", "thinking"],
                    )
                    # 4) Precision filter
                    precision_filter = gr.CheckboxGroup(
                        label="π― Precision",
                        choices=["bfloat16", "fp8"],
                        value=["bfloat16", "fp8"],
                    )
                with gr.Accordion("π About Tasks & Metrics", open=True):
                    gr.Markdown(
                        "### Tasks\n"
                        "- **GSM8K** β Mathematics Problem-Solving ([paper](https://arxiv.org/abs/2110.14168))\n"
                        "- **LongBench** β Long-Context Understanding ([paper](https://arxiv.org/abs/2412.15204))\n"
                        "- **MMLU** β Multitask Language Understanding ([paper](https://arxiv.org/abs/2009.03300))\n"
                        "- **NuminaMath** β Mathematical Reasoning ([paper](http://faculty.bicmr.pku.edu.cn/~dongbin/Publications/numina_dataset.pdf))\n"
                        "- **RULER** β Extreme Long-Context Eval ([paper](https://arxiv.org/abs/2404.06654))\n\n"
                        "### Metrics\n"
                        "- **E2E(s)** β End-to-End Latency\n"
                        "- **Accuracy(%)** β Task Accuracy\n"
                        "- **Cost($)** β Inference Cost\n"
                        "- **Decoding/Prefill T/s** β Throughput\n"
                        "- **S-MBU/MFU(%)** β Hardware Utilization\n"
                        "- **TTFT(s)** β Time To First Token\n"
                        "- **TPOT(s)** β Time Per Output Token",
                        elem_classes="info-section",
                    )
            # Right side - Table with selection and Radar Plot below
            with gr.Column(scale=5):
                leaderboard_output = gr.HTML(label="π Results")
                with gr.Group(elem_classes="filter-section"):
                    gr.Markdown("### π CAP Radar Plot")
                    gr.Markdown(
                        "**How to use:** Look at the 'Row #' column in the table above. "
                        "Enter up to 3 row numbers below (separated by commas) and click Generate."
                    )
                    with gr.Row():
                        row_indices_input = gr.Textbox(
                            label="Row Numbers to Compare",
                            placeholder="Example: 0,1,2",
                            elem_id="row_indices_input",
                            scale=3,
                        )
                        generate_btn = gr.Button("π― Generate", variant="primary", scale=1, size="lg")
                    # Empty side columns centre the plot.
                    with gr.Row():
                        with gr.Column(scale=1):
                            pass
                        with gr.Column(scale=5):
                            radar_plot = gr.Plot(label="", value=generate_radar_plot([]))
                        with gr.Column(scale=1):
                            pass

        # Records currently shown in the table; read by the Generate button.
        df_data_state = gr.State([])

        # Every table refresh takes the same inputs and updates the same outputs.
        filter_inputs = [dir_path, task_filter, framework_filter, model_type_filter, precision_filter, search_input]
        table_outputs = [leaderboard_output, df_data_state]

        # Initial load forces a re-download so newly added result files appear.
        demo.load(fn=auto_refresh_from_dir, inputs=filter_inputs, outputs=table_outputs)
        # Re-filter whenever the search box or any filter changes.
        for component in (search_input, task_filter, framework_filter, model_type_filter, precision_filter):
            component.change(fn=load_from_dir, inputs=filter_inputs, outputs=table_outputs)

        # Generate plot on button click
        generate_btn.click(
            fn=parse_and_generate_plot,
            inputs=[df_data_state, row_indices_input],
            outputs=[radar_plot],
        )

        # Periodic forced refresh every 60 seconds.
        timer = gr.Timer(60.0)
        timer.tick(fn=auto_refresh_from_dir, inputs=filter_inputs, outputs=table_outputs)
    return demo
if __name__ == "__main__":
    # Script entry point: build the dashboard and start the Gradio server.
    build_app().launch()