import streamlit as st
from pathlib import Path
import pandas as pd
import altair as alt
import subprocess
import os
## Cached results paths and task definitions
COMP_CACHE = Path("competition_cache/safe-challenge")
results_path = Path("competition_cache/cached_results")
TASKS = ["video-challenge-pilot-config", "video-challenge-task-1-config"]
valid_splits = ["public", "private"]
#####################################################################
## Data loading ##
#####################################################################
@st.cache_data
def load_results(task, best_only):
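    """Load the cached per-split score tables for a task.

    With ``best_only``, keep only each team's highest balanced-accuracy
    submission per split, sorted by balanced accuracy.
    """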
if best_only:
return {
f"{s}_score": pd.read_csv(f"{results_path}/{task}_{s}_score.csv")
.sort_values(["team", "balanced_accuracy"], ascending=False)
.drop_duplicates(subset=["team"])
.sort_values("balanced_accuracy", ascending=False)
.set_index("team")
for s in valid_splits
}
else:
return {
f"{s}_score": pd.read_csv(f"{results_path}/{task}_{s}_score.csv").set_index("team") for s in valid_splits
}
@st.cache_data
def load_submission():
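    """Concatenate the cached submission tables for all tasks, tagging each row with its task."""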
out = []
for task in TASKS:
data = pd.read_csv(f"{results_path}/{task}_submissions.csv")
data["task"] = task
out.append(data)
return pd.concat(out, ignore_index=True)
def get_updated_time(file="competition_cache/updated.txt"):
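    """Return the contents of the cached last-updated timestamp file, if present."""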
    if os.path.exists(file):
        with open(file) as f:  # close the handle instead of leaking it
            return f.read()
    return "no time file found"
@st.cache_data
def get_volume():
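    """Count submissions per day and status across all tasks for the volume chart."""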
subs = pd.concat(
[pd.read_csv(f"{results_path}/{task}_submissions.csv") for task in TASKS],
ignore_index=True,
)
    subs["datetime"] = pd.to_datetime(subs["datetime"])
subs["date"] = subs["datetime"].dt.date
subs = subs.groupby(["date", "status_reason"]).size().unstack().fillna(0).reset_index()
return subs
@st.cache_data
def make_heatmap(results, label="generated", symbol="πŸ‘€"):
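    """Build an Altair heatmap of per-source accuracy (with text labels) for one label type."""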
    # ``results`` is wide-format: teams as rows, per-source scores as columns
    df_long = results.set_index("team")
    team_order = df_long.index.tolist()  # leaderboard order, used to sort the y-axis
df_long = df_long.loc[:, [c for c in df_long.columns if c.startswith(label) and "accuracy" not in c]]
df_long.columns = [c.replace(f"{label}_", "") for c in df_long.columns]
if "none" in df_long.columns:
df_long = df_long.drop(columns=["none"])
df_long = df_long.reset_index().melt(id_vars="team", var_name="source", value_name="acc")
# Base chart for rectangles
base = alt.Chart(df_long).encode(
x=alt.X("source:O", title="Source", axis=alt.Axis(orient="top", labelAngle=-60)),
y=alt.Y("team:O", title="Team", sort=team_order),
)
# Heatmap rectangles
heatmap = base.mark_rect().encode(
color=alt.Color("acc:Q", scale=alt.Scale(scheme="greens"), title=f"{label} Accuracy")
)
# Text labels
text = base.mark_text(baseline="middle", fontSize=16).encode(
text=alt.Text("acc:Q", format=".2f"),
color=alt.condition(
alt.datum.acc < 0.5, # you can tune this for readability
alt.value("black"),
alt.value("white"),
),
)
# Combine heatmap and text
    chart = (heatmap + text).properties(width=600, height=500, title=f"Accuracy on {symbol} {label} sources")
return chart
@st.cache_data
def make_roc_curves(task, submission_ids, best_only=True):
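    """Plot the cached ROC curves for a task, one line per submission, colored by team."""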
rocs = pd.read_csv(f"{results_path}/{task}_rocs.csv")
if best_only:
        rocs = rocs[rocs["submission_id"].isin(submission_ids)]
roc_chart = alt.Chart(rocs).mark_line().encode(x="fpr", y="tpr", color="team:N", detail="submission_id:N")
return roc_chart
#####################################################################
## Page definition ##
#####################################################################
## Set title
st.set_page_config(
page_title="Leaderboard",
initial_sidebar_state="collapsed",
layout="wide", # This makes the app use the full width of the screen
)
## Pull new results, or toggle between private and public scores, if you are an owner
with st.sidebar:
hf_token = os.getenv("HF_TOKEN")
password = st.text_input("Admin login:", type="password")
if password == hf_token:
        if st.button("Pull New Results"):
            with st.spinner("Pulling new results", show_time=True):
                try:
                    process = subprocess.Popen(
                        ["python3", "utils.py"],
                        text=True,  # Decode stdout/stderr as text
                    )
                    st.info(f"Refresh started with PID: {process.pid}")
                    process.wait()  # Block until the refresh script exits
                    if process.returncode != 0:
                        st.error("The process did not finish successfully.")
                    else:
                        st.success(f"PID {process.pid} finished!")
                        # Clear the cached data so the next rerun picks up the new results
                        load_results.clear()
                        get_volume.clear()
                        load_submission.clear()
                        st.rerun()
                except Exception as e:
                    st.error(f"Error running refresh: {e}")
        ## Initialize the toggle state in session_state if it doesn't exist
        if "private_view" not in st.session_state:
            st.session_state.private_view = False
        # The 'key' parameter links the widget to session_state across reruns;
        # since the key is pre-seeded above, passing 'value' as well is
        # redundant and triggers a Streamlit double-assignment warning
        toggle_value = st.toggle("Private Scores", key="private_view")
        if toggle_value:
            st.write("Showing **PRIVATE** scores.")
        else:
            st.write("Showing **PUBLIC** scores.")
        split = "private" if toggle_value else "public"
else:
split = "public"
def show_leaderboard(results, task):
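    """Render the summary table and per-source accuracy breakdowns for one task in the current split."""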
source_split_map = {}
if split == "private":
_sol_df = pd.read_csv(COMP_CACHE / task / "solution.csv")
pairs_df = _sol_df[["source_og", "split"]].drop_duplicates()
        source_split_map = dict(zip(pairs_df["source_og"], pairs_df["split"]))
cols = [
"generated_accuracy",
"real_accuracy",
# "pristine_accuracy",
"balanced_accuracy",
"auc",
"fail_rate",
"total_time",
"datetime",
]
    column_config = {
        "balanced_accuracy": st.column_config.NumberColumn(
            "βš–οΈ Balanced Accuracy",
            format="compact",
            min_value=0,
            pinned=True,
            max_value=1.0,
            # width="small",
        ),
        "generated_accuracy": st.column_config.NumberColumn(
            "πŸ‘€ True Positive Rate",
            format="compact",
            min_value=0,
            pinned=True,
            max_value=1.0,
            # width="small",
        ),
        "real_accuracy": st.column_config.NumberColumn(
            "πŸ§‘β€πŸŽ€ True Negative Rate",
            format="compact",
            min_value=0,
            pinned=True,
            max_value=1.0,
            # width="small",
        ),
        "auc": st.column_config.NumberColumn(
            "πŸ“ AUC",
            format="compact",
            min_value=0,
            pinned=True,
            max_value=1.0,
            # width="small",
        ),
        "fail_rate": st.column_config.NumberColumn(
            "❌ Fail Rate",
            format="compact",
            # width="small",
        ),
        "total_time": st.column_config.NumberColumn(
            "πŸ•’ Inference Time",
            format="compact",
            # width="small",
        ),
        "datetime": st.column_config.DatetimeColumn(
            "πŸ—“οΈ Submission Date",
            format="YYYY-MM-DD",
            # width="small",
        ),
    }
    labels = {"real": "πŸ§‘β€πŸŽ€", "generated": "πŸ‘€"}
for c in results[f"{split}_score"].columns:
if "accuracy" in c:
continue
if any(p in c for p in ["generated", "real"]):
s = c.split("_")
pred = s[0]
source = " ".join(s[1:])
column_config[c] = st.column_config.NumberColumn(
labels[pred] + " " + source,
help=c,
format="compact",
min_value=0,
max_value=1.0,
)
"#### Summary"
st.dataframe(results[f"{split}_score"].loc[:, cols], column_config=column_config)
"##### Accuracy Breakdown by Source"
accuracy_types = {
"True positive/negative rate": 0,
"Conditional balanced accuracy": 1,
"AUC": 2,
}
granularity = st.radio(
"accuracy type",
list(accuracy_types.keys()),
key=f"granularity-{task}",
horizontal=True,
label_visibility="collapsed",
index=0,
)
## Subset the dataset
cols = [
c
for c in results[f"{split}_score"].columns
if "generated_" in c and "accuracy" not in c and "conditional" not in c
]
    col_names = [
        (
            f"πŸ“’ {c.replace('generated_', '')}"
            if source_split_map.get(c.replace("generated_", ""), "public") == "public"
            else f"πŸ” {c.replace('generated_', '')}"
        )
        for c in results[f"{split}_score"].columns
        if "generated_" in c and "accuracy" not in c and "conditional" not in c
    ]
gen_tmp = results[f"{split}_score"].loc[:, cols].copy()
gen_tmp.columns = col_names
cols = [
c for c in results[f"{split}_score"].columns if "real_" in c and "accuracy" not in c and "conditional" not in c
]
    col_names = [
        (
            f"πŸ“’ {c.replace('real_', '')}"
            if source_split_map.get(c.replace("real_", ""), "public") == "public"
            else f"πŸ” {c.replace('real_', '')}"
        )
        for c in results[f"{split}_score"].columns
        if "real_" in c and "accuracy" not in c and "conditional" not in c
    ]
real_tmp = results[f"{split}_score"].loc[:, cols].copy()
real_tmp.columns = col_names
## Check cases
    if accuracy_types[granularity] == 0:
        "#### πŸ‘€ True Positive Rate | Generated Source"
        st.dataframe(gen_tmp, column_config=column_config)
        "#### πŸ§‘β€πŸŽ€ True Negative Rate | Real Source"
st.dataframe(real_tmp, column_config=column_config)
    elif accuracy_types[granularity] == 1:
        "#### πŸ‘€ Balanced Accuracy | Generated Source"
tnr = results[f"{split}_score"].loc[:, ["real_accuracy"]]
gen_tmp[:] = (gen_tmp.values + tnr.values) / 2.0
        st.dataframe(gen_tmp, column_config=column_config)
        "#### πŸ§‘β€πŸŽ€ Balanced Accuracy | Real Source"
tpr = results[f"{split}_score"].loc[:, ["generated_accuracy"]]
real_tmp[:] = (real_tmp.values + tpr.values) / 2.0
st.dataframe(real_tmp, column_config=column_config)
else:
cols = [c for c in results[f"{split}_score"].columns if "generated_conditional_auc" in c]
        col_names = [
            (
                f"πŸ“’ {c.replace('generated_conditional_auc_', '')}"
                if source_split_map.get(c.replace("generated_conditional_auc_", ""), "public") == "public"
                else f"πŸ” {c.replace('generated_conditional_auc_', '')}"
            )
            for c in results[f"{split}_score"].columns
            if "generated_conditional_auc_" in c
        ]
        gen_tmp = results[f"{split}_score"].loc[:, cols].copy()
        gen_tmp.columns = col_names
        gen_tmp = gen_tmp.dropna(axis=1)  # rename before dropping NaN columns so names stay aligned
cols = [c for c in results[f"{split}_score"].columns if "real_conditional_auc" in c]
        col_names = [
            (
                f"πŸ“’ {c.replace('real_conditional_auc_', '')}"
                if source_split_map.get(c.replace("real_conditional_auc_", ""), "public") == "public"
                else f"πŸ” {c.replace('real_conditional_auc_', '')}"
            )
            for c in results[f"{split}_score"].columns
            if "real_conditional_auc" in c
        ]
        real_tmp = results[f"{split}_score"].loc[:, cols].copy()
        real_tmp.columns = col_names
        real_tmp = real_tmp.dropna(axis=1)  # rename before dropping NaN columns so names stay aligned
        "#### πŸ‘€ Conditional AUC | Generated Source"
        st.dataframe(gen_tmp, column_config=column_config)
        "#### πŸ§‘β€πŸŽ€ Conditional AUC | Real Source"
        st.dataframe(real_tmp, column_config=column_config)
def make_roc(results):
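    """Scatter each submission's false-positive vs true-positive rate, sized by inference time."""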
    results = results.copy()  # work on a copy so the caller's frame is not mutated
    results["FA"] = 1.0 - results["real_accuracy"]
chart = (
alt.Chart(results)
.mark_circle()
.encode(
            x=alt.X("FA:Q", title="πŸ§‘β€πŸŽ€ False Positive Rate", scale=alt.Scale(domain=[0.0, 1.0])),
            y=alt.Y("generated_accuracy:Q", title="πŸ‘€ True Positive Rate", scale=alt.Scale(domain=[0.0, 1.0])),
color="team:N", # Color by categorical field
            size=alt.Size(
                "total_time:Q", title="πŸ•’ Inference Time", scale=alt.Scale(rangeMin=100)
            ),  # Size by quantitative field
)
.properties(width=400, height=400, title="Detection vs False Alarm vs Inference Time")
)
diag_line = (
alt.Chart(pd.DataFrame(dict(tpr=[0, 1], fpr=[0, 1])))
.mark_line(color="lightgray", strokeDash=[8, 4])
.encode(x="fpr", y="tpr")
)
return chart + diag_line
def make_acc(results):
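    """Scatter balanced accuracy against inference time, with a chance line at 0.5."""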
results = results.loc[results["total_time"] >= 0]
chart = (
alt.Chart(results)
.mark_circle(size=200)
.encode(
            x=alt.X("total_time:Q", title="πŸ•’ Inference Time", scale=alt.Scale(domain=[0.0, 10000])),
y=alt.Y(
"balanced_accuracy:Q",
title="Balanced Accuracy",
scale=alt.Scale(domain=[0.4, 1]),
),
            color="team:N",  # Color by categorical field
)
.properties(width=400, height=400, title="Inference Time vs Balanced Accuracy")
)
diag_line = (
alt.Chart(pd.DataFrame(dict(t=[0, results["total_time"].max()], y=[0.5, 0.5])))
.mark_line(color="lightgray", strokeDash=[8, 4])
.encode(x="t", y="y")
)
return chart + diag_line
def get_heatmaps(temp):
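    """Render heatmaps for generated and real sources, plus any 'aug' sources when present."""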
    h1 = make_heatmap(temp, "generated", symbol="πŸ‘€")
    h2 = make_heatmap(temp, "real", symbol="πŸ§‘β€πŸŽ€")
st.altair_chart(h1, use_container_width=True)
st.altair_chart(h2, use_container_width=True)
if temp.columns.str.contains("aug", case=False).any():
        h3 = make_heatmap(temp, "aug", symbol="πŸ› οΈ")
st.altair_chart(h3, use_container_width=True)
def make_plots_for_task(task, split, best_only):
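    """Render the Tables and Charts tabs for a single task."""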
results = load_results(task, best_only=best_only)
temp = results[f"{split}_score"].reset_index()
t1, t2 = st.tabs(["Tables", "Charts"])
with t1:
show_leaderboard(results, task)
with t2:
roc_scatter = make_roc(temp)
acc_vs_time = make_acc(temp)
        if split == "private" and hf_token is not None:
            full_curves = st.toggle("Full curve", value=True, key=f"all curves {task}")
            if full_curves:
                roc_scatter = make_roc_curves(task, temp["submission_id"].values.tolist(), best_only) + roc_scatter
        st.altair_chart(roc_scatter | acc_vs_time, use_container_width=False)
updated = get_updated_time()
st.markdown(updated)
best_only = True
tp, t1, volume_tab, all_submission_tab = st.tabs(
["**Pilot Task**", "**Task 1**", "**Submission Volume**", "**All Submissions**"]
)
with tp:
"*Detection of Synthetic Video Content. Video files are unmodified from the original output from the models or the real sources.*"
make_plots_for_task(TASKS[0], split, best_only)
with t1:
"*Detection of Synthetic Video Content. Video files are unmodified from the original output from the models or the real sources.*"
make_plots_for_task(TASKS[1], split, best_only)
with volume_tab:
subs = get_volume()
    status_lookup = ["QUEUED", "PROCESSING", "SUCCESS", "FAILED"]
    found_columns = subs.columns.values.tolist()
    # Keep only statuses present in the data, preserving order (set intersection would scramble it)
    status_lookup = [s for s in status_lookup if s in found_columns]
st.bar_chart(subs, x="date", y=status_lookup, stack=True)
total_submissions = int(subs.loc[:, status_lookup].fillna(0).values.sum())
st.metric("Total Submissions", value=total_submissions)
st.metric("Duration", f'{(subs["date"].max() - subs["date"].min()).days} days')
if split == "private":
with all_submission_tab:
data = load_submission()
st.dataframe(data)