#!/usr/bin/env python
import os
import json
from typing import List, Tuple
os.environ["GRADIO_LANGUAGE"] = "en"
RESULT_DIR = os.environ.get("MOECAP_RESULT_DIR")
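# Note: judging by load_from_dir() below, MOECAP_RESULT_DIR is expected to name a
# Hugging Face dataset repo; it is interpolated into an "hf://datasets/<dir>/**/*.json" glob.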
if not RESULT_DIR:
    # For testing purposes, you can uncomment the line below:
    # RESULT_DIR = "generic_result_dir"
    # If you are running locally without this env var, either set it here
    # or expect load_from_dir() to report that the result directory is not set.
    pass
import gradio as gr
import pandas as pd
from datasets import load_dataset
import plotly.graph_objects as go
def f2(x):
"""Format to 2 decimal places if number, else return as-is."""
if isinstance(x, (int, float)):
return round(float(x), 2)
return x
def normalize(val, vmin, vmax, baseline=20):
"""Normalize value to baseline-100 range."""
if vmax == vmin:
return baseline + 40
return baseline + (val - vmin) / (vmax - vmin) * (100 - baseline)
def normalize_cost(val, max_tick, baseline=20):
"""Normalize cost (lower is better)."""
if max_tick == 0:
return baseline + 40
return baseline + (max_tick - min(val, max_tick)) / max_tick * (100 - baseline)
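# Illustrative values (not from any real run): with the default baseline=20,
#   normalize(50, 0, 100)    -> 20 + (50 - 0) / 100 * 80 = 60.0
#   normalize_cost(2.0, 4.0) -> 20 + (4 - 2) / 4 * 80    = 60.0
# so every axis of the radar plot below shares the same [20, 100] scale.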
def generate_radar_plot(selected_rows_data: List[dict]) -> go.Figure:
"""Generate a CAP radar plot from selected rows."""
layout_settings = dict(
height=750,
autosize=True,
margin=dict(t=80, b=100, l=80, r=80),
paper_bgcolor='white',
plot_bgcolor='white',
)
    if not selected_rows_data:
        fig = go.Figure()
        fig.add_annotation(
            text="Please select 1-3 rows from the table to generate the radar plot",
            xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False,
            font=dict(size=16, color="black"),  # ensure the hint text is black
            xanchor='center', yanchor='middle'
        )
        fig.update_layout(xaxis=dict(visible=False), yaxis=dict(visible=False), **layout_settings)
        return fig
    if len(selected_rows_data) > 3:
        fig = go.Figure()
        fig.add_annotation(
            text="Error: Please select no more than 3 rows!",
            xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False,
            font=dict(size=18, color="red"),
            xanchor='center', yanchor='middle'
        )
        fig.update_layout(xaxis=dict(visible=False), yaxis=dict(visible=False), **layout_settings)
        return fig
    datasets = [row.get('Dataset', '') for row in selected_rows_data]
    unique_datasets = set(datasets)
    if len(unique_datasets) > 1:
        fig = go.Figure()
        fig.add_annotation(
            text="Error: Please select rows from the same dataset!",
            xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False,
            font=dict(size=18, color="red"),
            xanchor='center', yanchor='middle'
        )
        fig.update_layout(xaxis=dict(visible=False), yaxis=dict(visible=False), **layout_settings)
        return fig
    dataset_name = datasets[0] if datasets else "Unknown"
    data = {}
    for row in selected_rows_data:
        model_name = row.get('Model', 'Unknown')
        # The Model cell may be an HTML link; extract the display text between the tags.
        if isinstance(model_name, str) and 'href' in model_name:
            try:
                model_name = model_name.split('>', 1)[1].split('<', 1)[0]
            except Exception:
                pass
        method = row.get('Method', '')
        if isinstance(model_name, str) and '/' in model_name:
            legend_name = model_name.split('/')[-1]
        else:
            legend_name = str(model_name)
        if method and method not in ['Unknown', '-', '']:
            legend_name = f"{legend_name}-{method}"
        acc = row.get('Accuracy(%)', 0)
        cost = row.get('Cost($)', 0)
        throughput = row.get('Decoding T/s', 0)
        try:
            acc = float(acc) if acc not in [None, '-', ''] else 0
            cost = float(cost) if cost not in [None, '-', ''] else 0
            throughput = float(throughput) if throughput not in [None, '-', ''] else 0
        except (TypeError, ValueError):
            acc, cost, throughput = 0, 0, 0
        data[legend_name] = {
            'accuracy': acc / 100.0 if acc > 1 else acc,  # store accuracy as a fraction
            'cost': cost,
            'throughput': throughput
        }
    throughputs = [v['throughput'] for v in data.values()]
    costs = [v['cost'] for v in data.values()]
    accs = [v['accuracy'] for v in data.values()]
    tp_min, tp_max = (min(throughputs), max(throughputs)) if throughputs else (0, 1)
    cost_max = max(costs) if costs else 1
    acc_min, acc_max = (min(accs), 1.0) if accs else (0, 1)
    baseline = 20
    # The first axis label is repeated so each polar trace closes back on itself.
    categories = ['Throughput (T/s)', 'Cost ($)', 'Accuracy', 'Throughput (T/s)']
    fig = go.Figure()
    for system, values in data.items():
        raw_vals = [values['throughput'], values['cost'], values['accuracy']]
        norm_vals = [
            normalize(values['throughput'], tp_min, tp_max, baseline),
            normalize_cost(values['cost'], cost_max, baseline),
            normalize(values['accuracy'], acc_min, acc_max, baseline)
        ]
        norm_vals += [norm_vals[0]]  # close the polygon by repeating the first vertex
        hovertext = [
            f"Throughput: {raw_vals[0]:.2f} T/s",
            f"Cost: ${raw_vals[1]:.2f}",
            f"Accuracy: {raw_vals[2]*100:.2f}%",
            f"Throughput: {raw_vals[0]:.2f} T/s"
        ]
        fig.add_trace(go.Scatterpolar(
            r=norm_vals,
            theta=categories,
            fill='toself',
            name=system,
            text=hovertext,
            hoverinfo='text+name',
            line=dict(width=2)
        ))
    fig.update_layout(
        title=dict(text=f"CAP Radar Plot: {dataset_name}", x=0.5, xanchor='center', font=dict(size=20, color="black")),
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 100],
                tickfont=dict(size=12, color="black"),
                gridcolor='lightgray',
                linecolor='gray',
                showline=True
            ),
            angularaxis=dict(
                tickfont=dict(size=14, color="black"),
                rotation=90,
                direction='clockwise',
                gridcolor='lightgray',
                linecolor='gray',
                showline=True
            ),
            bgcolor="white"
        ),
        legend=dict(orientation='h', yanchor='bottom', y=-0.15, xanchor='center', x=0.5, font=dict(size=13, color="black")),
        **layout_settings
    )
    return fig
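# Illustrative call (hypothetical row values matching the column names produced by
# json_to_row below; not real benchmark results):
#   fig = generate_radar_plot([
#       {"Model": "org/model-a", "Method": "vLLM", "Dataset": "GSM8K",
#        "Accuracy(%)": 82.5, "Cost($)": 1.30, "Decoding T/s": 95.0},
#       {"Model": "org/model-b", "Method": "SGLang", "Dataset": "GSM8K",
#        "Accuracy(%)": 79.1, "Cost($)": 0.95, "Decoding T/s": 120.0},
#   ])
#   fig.show()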
def json_to_row(path: str, metrics: dict) -> dict:
    model_name = metrics.get("model_name")
    if not model_name:
        model_name = "unknown-model"
    dataset = metrics.get("dataset", "Unknown")
    method = metrics.get("method", "Unknown")
    precision = metrics.get("precision", "Unknown")
    model_type = metrics.get("model_type", "Unknown")
    e2e_s = metrics.get("e2e_s", None)
    batch_size = metrics.get("batch_size", None)
    gpu_type = metrics.get("gpu_type", "")
    cost = metrics.get("cost", None)
    em = metrics.get("exact_match")
    correct = metrics.get("correct")
    total = metrics.get("total")
    # Prefer correct/total when both counts are present; otherwise fall back to exact_match.
    if isinstance(correct, (int, float)) and isinstance(total, (int, float)) and total > 0:
        acc = correct / total
    else:
        acc = em
    def pct(x):
        return round(x * 100, 2) if isinstance(x, (int, float)) else None
    if isinstance(model_name, str) and "/" in model_name:
        # Names of the form "org/model" are assumed to be Hugging Face Hub repos,
        # so the Model cell is rendered as a link to the model card
        # (generate_radar_plot above strips this href markup back out).
        hf_url = f"https://huggingface.co/{model_name}"
        model_cell = f'<a href="{hf_url}" target="_blank">{model_name}</a>'
    else:
        model_cell = model_name
    row = {
        "Model": model_cell,
        "Dataset": dataset,
        "Method": method,
        "Model type": model_type,
        "Precision": precision,
        "E2E(s)": f2(e2e_s),
        "GPU": gpu_type,
        "Accuracy(%)": pct(acc),
        "Cost($)": cost,
        "Decoding T/s": f2(metrics.get("decoding_throughput")),
        "Prefill T/s": f2(metrics.get("prefill_tp")),
        "Prefill\nS-MBU(%)": pct(metrics.get("prefill_smbu")),
        "Prefill\nS-MFU(%)": pct(metrics.get("prefill_smfu")),
        "Decoding\nS-MBU(%)": pct(metrics.get("decoding_smbu")),
        "Decoding\nS-MFU(%)": pct(metrics.get("decoding_smfu")),
        "TTFT(s)": f2(metrics.get("ttft")),
        "TPOT(s)": f2(metrics.get("tpot")),
        "Batch size": batch_size,
    }
    return row
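# Illustrative input/output (made-up numbers, not real benchmark data):
#   json_to_row("results/example.json",
#               {"model_name": "org/model-a", "dataset": "GSM8K", "method": "vLLM",
#                "correct": 820, "total": 1000, "cost": 1.25,
#                "decoding_throughput": 95.3})
# yields a row with Accuracy(%) = 82.0, Decoding T/s = 95.3, and the Model cell
# rendered as a link to https://huggingface.co/org/model-a.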
def load_from_dir(dir_path: str, selected_tasks=None, selected_frameworks=None, selected_model_types=None, selected_precisions=None, search_keyword="", force_refresh=False):
    if not dir_path:
        return "Result Directory not set.", []
    try:
        pattern = f"hf://datasets/{dir_path}/**/*.json"
        dl_mode = "force_redownload" if force_refresh else None
        print(f"Fetching from {pattern} (mode={dl_mode})...")
        ds = load_dataset("json", data_files={"train": pattern}, split="train", download_mode=dl_mode)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        return "No files loaded or Dataset not found.", []
    rows = []
    for i, example in enumerate(ds):
        metrics = example.get("metrics") or example.get("json") or example
        rows.append(json_to_row(f"{dir_path}#{i}", metrics))
    if not rows:
        return "No records found.", []
    df = pd.DataFrame(rows)
    # --- Filtering Logic ---
    # This logic is consistent: if a filter is provided, we ONLY keep rows
    # where the column value is inside the selected list.
    if selected_tasks:
        df = df[df["Dataset"].astype(str).str.lower().isin([x.lower() for x in selected_tasks])]
    if selected_frameworks:
        df = df[df["Method"].astype(str).str.lower().isin([str(x).lower() for x in selected_frameworks])]
    if selected_model_types:
        df = df[df["Model type"].astype(str).str.lower().isin([str(x).lower() for x in selected_model_types])]
    if selected_precisions:
        df = df[df["Precision"].astype(str).str.lower().isin([str(x).lower() for x in selected_precisions])]
    if search_keyword and search_keyword.strip():
        df = df[df.astype(str).apply(lambda row: row.str.lower().str.contains(search_keyword.strip().lower()).any(), axis=1)]
    if df.empty:
        return "No records found.", []
    df = df.fillna("-")
    df.insert(0, 'Row #', range(len(df)))
    table_html = f'