Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import html

import gradio as gr
import numpy as np
import pandas as pd

from chat import (
    format_chat_display,
    format_metrics_display,
    format_tool_info,
)
from data_loader import MODELS, DATASETS, SCORES, HEADER_CONTENT
def get_updated_df(df, df_output):
    """Pair dataset rows with model outputs and keep only the explorer columns.

    Rows are matched by position: the first ``len(df_output)`` rows of ``df``
    receive ``df_output``'s ``response``, ``rationale``, ``explanation`` and
    ``score`` values in order (``df_output``'s index is ignored).

    Args:
        df: Dataset rows; must provide ``conversation``, ``tools_langchain``,
            ``n_turns``, ``len_query`` and ``n_tools``.
        df_output: Model outputs providing the four result columns above.

    Returns:
        A new DataFrame containing the nine explorer columns in a fixed order.

    Raises:
        KeyError: If a required column is missing from either frame.
        ValueError: If ``df_output`` has more rows than ``df``.
    """
    merged = df.iloc[: len(df_output)].copy()
    for name in ("response", "rationale", "explanation", "score"):
        # tolist() drops df_output's index, so the assignment is positional.
        merged[name] = df_output[name].tolist()
    return merged[
        [
            "conversation",
            "tools_langchain",
            "n_turns",
            "len_query",
            "n_tools",
            "response",
            "rationale",
            "explanation",
            "score",
        ]
    ]
def get_chat_and_score_df(model, dataset):
    """Load ``dataset`` together with ``model``'s outputs on it.

    Reads ``output/<model>/<dataset>.parquet`` first and then
    ``datasets/<dataset>.parquet``, and combines them with ``get_updated_df``
    (positional pairing, fixed explorer columns).
    """
    output_path = f"output/{model}/{dataset}.parquet"
    dataset_path = f"datasets/{dataset}.parquet"
    model_outputs = pd.read_parquet(output_path)
    return get_updated_df(pd.read_parquet(dataset_path), model_outputs)
def on_filter_change(
    model,
    dataset,
    min_score,
    max_score,
    min_n_turns,
    min_len_query,
    min_n_tools,
):
    """Re-render the explorer at the first matching row after a filter change.

    Args:
        model, dataset, min_score, max_score, min_n_turns, min_len_query,
        min_n_tools: Current values of the filter components, forwarded
            unchanged to ``filter_and_display``.

    Returns:
        A 4-tuple ``(chat_html, metrics_html, tool_html, index_html)``.
    """
    try:
        # A filter change always restarts navigation at row 0; unpacking
        # enforces exactly four outputs for the four display components.
        chat_html, metrics_html, tool_html, index_html = filter_and_display(
            model,
            dataset,
            min_score,
            max_score,
            min_n_turns,
            min_len_query,
            min_n_tools,
            0,
        )
        return chat_html, metrics_html, tool_html, index_html
    except Exception as e:
        # filter_and_display already turns its own failures into HTML; this is
        # a last-resort guard. Escape the message: exception text often holds
        # '<' or quotes that would otherwise corrupt the markup.
        error_html = f"""
    <div style="padding: 1.5rem; color: var(--score-low);">
        <div style="font-weight: 600;">Filter Error</div>
        <div style="font-family: monospace; background-color: var(--surface-color-alt); padding: 0.5rem; margin-top: 0.5rem;">
            {html.escape(str(e))}
        </div>
    </div>
    """
        return (
            error_html,
            "<div style='text-align: center;'>No metrics available</div>",
            "<div style='text-align: center;'>No tool information available</div>",
            "<div style='text-align: center;'>0/0</div>",
        )
def _unwrap_input(value, default, cast):
    """Normalise one event-handler input and convert it with ``cast``.

    Inputs are accepted either raw or wrapped as a ``{"value": ...}`` dict.

    Args:
        value: The raw input, a ``{"value": ...}`` dict, or ``None``.
        default: Returned as-is when no value is present: ``None``, a dict
            without a ``"value"`` key, or a dict whose ``"value"`` is ``None``.
        cast: Conversion applied to a present value (``str``/``int``/``float``).

    Returns:
        ``cast(value)`` or ``default``.

    Raises:
        Whatever ``cast`` raises for an unconvertible value.
    """
    if isinstance(value, dict):
        value = value.get("value")
    return default if value is None else cast(value)


def _coerce_index(value):
    """Return ``value`` as an ``int`` row index, or 0 if missing or unparseable."""
    try:
        return _unwrap_input(value, 0, int)
    except (ValueError, TypeError):
        return 0


def _filter_rows(
    model,
    dataset,
    min_score,
    max_score,
    min_n_turns,
    min_len_query,
    min_n_tools,
):
    """Load one model's results on one dataset and apply the explorer filters.

    Missing inputs fall back to ``MODELS[0]``, ``DATASETS[0]``, the full
    ``SCORES`` range, and minimums of 0.

    Returns:
        A new DataFrame with only the rows whose ``score`` lies in
        ``[min_score, max_score]`` (inclusive) and whose ``n_turns``,
        ``len_query`` and ``n_tools`` are at least the given minimums.
    """
    model_str = _unwrap_input(model, MODELS[0], str)
    dataset_str = _unwrap_input(dataset, DATASETS[0], str)
    min_score_val = _unwrap_input(min_score, float(min(SCORES)), float)
    max_score_val = _unwrap_input(max_score, float(max(SCORES)), float)
    min_n_turns_val = _unwrap_input(min_n_turns, 0, int)
    min_len_query_val = _unwrap_input(min_len_query, 0, int)
    min_n_tools_val = _unwrap_input(min_n_tools, 0, int)

    df_chat = get_chat_and_score_df(model_str, dataset_str)

    # Make every filter column numeric (unparseable values -> default) and add
    # any missing one, so the comparisons below cannot raise.
    for col, default in (
        ("score", 0.0),
        ("n_turns", 0),
        ("len_query", 0),
        ("n_tools", 0),
    ):
        if col in df_chat.columns:
            df_chat[col] = pd.to_numeric(df_chat[col], errors="coerce").fillna(
                default
            )
        else:
            df_chat[col] = default

    keep = (
        (df_chat["score"] >= min_score_val)
        & (df_chat["score"] <= max_score_val)
        & (df_chat["n_turns"] >= min_n_turns_val)
        & (df_chat["len_query"] >= min_len_query_val)
        & (df_chat["n_tools"] >= min_n_tools_val)
    )
    return df_chat[keep].copy()


def _render_row(df_filtered, index_val):
    """Build the four display fragments for one row of an already-filtered frame.

    Args:
        df_filtered: Rows returned by ``_filter_rows``.
        index_val: Requested 0-based position; clamped into
            ``[0, len(df_filtered) - 1]``.

    Returns:
        ``(chat_html, metrics_html, tool_html, index_html)``. For an empty
        frame the first three are a "No Results Found" panel and the position
        reads ``0/0``.
    """
    if len(df_filtered) == 0:
        empty_message = """
    <div style="
        padding: 1.5rem;
        text-align: center;
        color: var(--text-muted);
        background-color: var(--surface-color-alt);
        border-radius: 8px;
        border: 1px dashed var(--border-color);
        margin: 1rem 0;">
        <div style="font-size: 2rem; margin-bottom: 1rem;">📭</div>
        <div style="font-weight: 500; margin-bottom: 0.5rem;">No Results Found</div>
        <div style="font-style: italic; font-size: 0.9rem;">Try adjusting your filters to see more data</div>
    </div>
    """
        return (
            empty_message,
            empty_message,
            empty_message,
            "<div style='text-align: center; color: var(--text-muted);'>0/0</div>",
        )

    valid_index = max(0, min(index_val, len(df_filtered) - 1))
    row = df_filtered.iloc[valid_index]

    chat_html = format_chat_display(row)
    metrics_html = format_metrics_display(row)

    # A tool-formatting failure should not hide the chat and metrics panes,
    # so it is reported inline instead of propagating.
    try:
        tool_html = format_tool_info(row["tools_langchain"])
    except Exception as e:
        tool_html = f"""
    <div style="padding: 1rem; background-color: var(--surface-color-alt); border-radius: 8px; color: var(--text-muted);">
        <div style="font-weight: 500; margin-bottom: 0.5rem;">Tool Information Unavailable</div>
        <div style="font-size: 0.9rem;">Error: {html.escape(str(e))}</div>
    </div>
    """

    index_html = f"""
    <div style="
        display: flex;
        align-items: center;
        justify-content: center;
        font-weight: 500;
        color: var(--primary-text);
        background-color: var(--surface-color-alt);
        padding: 0.5rem 1rem;
        border-radius: 20px;
        font-size: 0.9rem;
        width: fit-content;
        margin: 0 auto;">
        <span style="margin-right: 0.5rem;">📄</span>{valid_index + 1}/{len(df_filtered)}
    </div>
    """
    return chat_html, metrics_html, tool_html, index_html


def _error_outputs(exc):
    """Return the four display fragments shown when filtering or rendering fails.

    The message is HTML-escaped, because exception text often contains ``<``
    or quotes. It sits directly against its tags because the box uses
    ``white-space: pre-wrap``, so surrounding source whitespace would be
    rendered.
    """
    error_html = f"""
    <div style="
        padding: 1.5rem;
        color: var(--score-low);
        background-color: var(--surface-color);
        border: 1px solid var(--score-low);
        border-radius: 8px;
        margin: 1rem 0;
        display: flex;
        align-items: flex-start;">
        <div style="flex-shrink: 0; margin-right: 1rem; font-size: 1.5rem;">⚠️</div>
        <div>
            <div style="font-weight: 600; margin-bottom: 0.5rem;">Error Occurred</div>
            <div style="
                font-family: monospace;
                background-color: var(--surface-color-alt);
                padding: 1rem;
                border-radius: 4px;
                white-space: pre-wrap;
                font-size: 0.9rem;">{html.escape(str(exc))}</div>
        </div>
    </div>
    """
    return (
        error_html,
        "<div style='padding: 1.5rem; color: var(--text-muted); text-align: center;'>No metrics available</div>",
        "<div style='padding: 1.5rem; color: var(--text-muted); text-align: center;'>No tool information available</div>",
        "<div style='text-align: center; color: var(--text-muted);'>0/0</div>",
    )


def _navigation_error_outputs(exc, current_idx):
    """Return the five navigation outputs used when the filtered rows cannot be loaded.

    The stored index is normalised to an ``int`` so the state never receives
    a dict or ``None``.
    """
    error_html = f"""
    <div style="padding: 1.5rem; color: var(--score-low);">
        <div style="font-weight: 600;">Navigation Error</div>
        <div style="font-family: monospace; background-color: var(--surface-color-alt); padding: 0.5rem; margin-top: 0.5rem;">
            {html.escape(str(exc))}
        </div>
    </div>
    """
    return (
        error_html,
        "<div style='text-align: center;'>No metrics available</div>",
        "<div style='text-align: center;'>No tool information available</div>",
        "<div style='text-align: center;'>0/0</div>",
        _coerce_index(current_idx),
    )


def _navigate(
    step,
    current_idx,
    model,
    dataset,
    min_score,
    max_score,
    min_n_turns,
    min_len_query,
    min_n_tools,
):
    """Move ``step`` rows from ``current_idx`` and render the destination row.

    Returns:
        ``(chat_html, metrics_html, tool_html, index_html, new_index)``.
    """
    try:
        df_filtered = _filter_rows(
            model,
            dataset,
            min_score,
            max_score,
            min_n_turns,
            min_len_query,
            min_n_tools,
        )
        last_index = max(len(df_filtered) - 1, 0)
        # Clamp before storing, not just when rendering. Otherwise repeated
        # Next clicks on the last row push the stored index past the end, and
        # Previous then appears unresponsive for several clicks.
        position = min(max(_unwrap_input(current_idx, 0, int), 0), last_index)
        new_index = min(max(position + step, 0), last_index)
    except Exception as e:
        return _navigation_error_outputs(e, current_idx)

    # A row that fails to render still advances the index, so the user can
    # step past it.
    try:
        rendered = _render_row(df_filtered, new_index)
    except Exception as e:
        rendered = _error_outputs(e)
    return (*rendered, new_index)


def navigate_prev(
    current_idx,
    model,
    dataset,
    min_score,
    max_score,
    min_n_turns,
    min_len_query,
    min_n_tools,
):
    """Show the previous filtered row; stays on the first row at the start.

    Returns:
        ``(chat_html, metrics_html, tool_html, index_html, new_index)``.
    """
    return _navigate(
        -1,
        current_idx,
        model,
        dataset,
        min_score,
        max_score,
        min_n_turns,
        min_len_query,
        min_n_tools,
    )


def navigate_next(
    current_idx,
    model,
    dataset,
    min_score,
    max_score,
    min_n_turns,
    min_len_query,
    min_n_tools,
):
    """Show the next filtered row; stays on the last row at the end.

    Returns:
        ``(chat_html, metrics_html, tool_html, index_html, new_index)``.
    """
    return _navigate(
        1,
        current_idx,
        model,
        dataset,
        min_score,
        max_score,
        min_n_turns,
        min_len_query,
        min_n_tools,
    )


def filter_and_display(
    model,
    dataset,
    min_score,
    max_score,
    min_n_turns,
    min_len_query,
    min_n_tools,
    index=0,
):
    """Filter the selected results and render the row at ``index``.

    Each argument may be raw or a ``{"value": ...}`` dict. Missing values fall
    back to the first model/dataset, the full score range and zero minimums.
    ``index`` is clamped to the filtered rows.

    Returns:
        ``(chat_html, metrics_html, tool_html, index_html)``. Any failure is
        rendered as an error panel rather than raised.
    """
    try:
        df_filtered = _filter_rows(
            model,
            dataset,
            min_score,
            max_score,
            min_n_turns,
            min_len_query,
            min_n_tools,
        )
        return _render_row(df_filtered, _coerce_index(index))
    except Exception as e:
        return _error_outputs(e)
def create_exploration_tab(df):
    """Build the "Data Exploration" tab and wire up its event handlers.

    The tab contains model/dataset/score/size filters, Previous/Next
    navigation, and three HTML panes (chat, metrics, tools). All panes are
    rendered with the default filters at build time, so the tab is not blank
    before the first interaction.

    Args:
        df: Unused; kept so existing callers keep working. Rows are loaded
            per selection via ``get_chat_and_score_df``.

    Returns:
        ``[chat_display, metrics_display, tool_info_display, index_display]``.
    """
    default_model, default_dataset = MODELS[0], DATASETS[0]
    score_floor, score_ceiling = float(min(SCORES)), float(max(SCORES))

    def _slider_maxima(model, dataset):
        """Return ``(n_turns_max, len_query_max, n_tools_max)`` for one selection.

        Columns are coerced to numbers (unparseable -> 0). Maxima are floored
        at 1 / 10 / 1 so every slider keeps a usable range.
        """
        df_chat = get_chat_and_score_df(model, dataset)
        maxima = []
        for col, floor in (("n_turns", 1), ("len_query", 10), ("n_tools", 1)):
            peak = pd.to_numeric(df_chat[col], errors="coerce").fillna(0).max()
            # The max of an empty column is NaN, which int() cannot convert.
            maxima.append(floor if pd.isna(peak) else max(floor, int(peak)))
        return tuple(maxima)

    def _default_view():
        """Render the default-filter first row as (chat, metrics, tools, index) HTML."""
        return filter_and_display(
            default_model, default_dataset, score_floor, score_ceiling, 0, 0, 0, 0
        )

    n_turns_max, len_query_max, n_tools_max = _slider_maxima(
        default_model, default_dataset
    )
    initial_chat, initial_metrics, initial_tools, initial_index = _default_view()

    with gr.Tab("Data Exploration"):
        # Theme-aware CSS variables and layout rules for this tab.
        gr.HTML(
            """
            <style>
            /* Custom styling for the exploration tab */
            :root[data-theme="light"] {
                --surface-color: #f8f9fa;
                --surface-color-alt: #ffffff;
                --text-color: #202124;
                --text-muted: #666666;
                --primary-text: #1a73e8;
                --primary-text-light: rgba(26, 115, 232, 0.3);
                --border-color: #e9ecef;
                --border-color-light: #f1f3f5;
                --shadow-color: rgba(0,0,0,0.05);
                --message-bg-user: #E5F6FD;
                --message-bg-assistant: #F7F7F8;
                --message-bg-system: #FFF3E0;
                --response-bg: #F0F7FF;
                --score-high: #1a73e8;
                --score-med: #f4b400;
                --score-low: #ea4335;
            }
            :root[data-theme="dark"] {
                --surface-color: #1e1e1e;
                --surface-color-alt: #2d2d2d;
                --text-color: #ffffff;
                --text-muted: #a0a0a0;
                --primary-text: #60a5fa;
                --primary-text-light: rgba(96, 165, 250, 0.3);
                --border-color: #404040;
                --border-color-light: #333333;
                --shadow-color: rgba(0,0,0,0.2);
                --message-bg-user: #2d3748;
                --message-bg-assistant: #1a1a1a;
                --message-bg-system: #2c2516;
                --response-bg: #1e2a3a;
                --score-high: #60a5fa;
                --score-med: #fbbf24;
                --score-low: #ef4444;
            }
            #exploration-header {
                margin-bottom: 1.5rem;
                padding-bottom: 1rem;
                border-bottom: 1px solid var(--border-color);
            }
            .filter-container {
                background-color: var(--surface-color);
                border-radius: 10px;
                padding: 1rem;
                margin-bottom: 1.5rem;
                border: 1px solid var(--border-color);
                box-shadow: 0 2px 6px var(--shadow-color);
            }
            .navigation-buttons button {
                min-width: 120px;
                font-weight: 500;
            }
            .content-panel {
                margin-top: 1.5rem;
            }
            @media (max-width: 768px) {
                .filter-row {
                    flex-direction: column;
                }
            }
            </style>
            """
        )

        # Header
        with gr.Row(elem_id="exploration-header"):
            gr.HTML(HEADER_CONTENT)

        # Filters section
        with gr.Column(elem_classes="filter-container"):
            gr.Markdown("### 🔍 Filter Options")
            with gr.Row(equal_height=True, elem_classes="filter-row"):
                explore_model = gr.Dropdown(
                    choices=MODELS,
                    value=default_model,
                    label="Model",
                    container=True,
                    scale=1,
                    info="Select AI model",
                )
                explore_dataset = gr.Dropdown(
                    choices=DATASETS,
                    value=default_dataset,
                    label="Dataset",
                    container=True,
                    scale=1,
                    info="Select evaluation dataset",
                )
            with gr.Row(equal_height=True, elem_classes="filter-row"):
                min_score = gr.Slider(
                    minimum=score_floor,
                    maximum=score_ceiling,
                    value=score_floor,
                    step=0.1,
                    label="Minimum TSQ Score",
                    container=True,
                    scale=1,
                    info="Filter responses with scores above this threshold",
                )
                max_score = gr.Slider(
                    minimum=score_floor,
                    maximum=score_ceiling,
                    value=score_ceiling,
                    step=0.1,
                    label="Maximum TSQ Score",
                    container=True,
                    scale=1,
                    info="Filter responses with scores below this threshold",
                )
            with gr.Row(equal_height=True, elem_classes="filter-row"):
                n_turns_filter = gr.Slider(
                    minimum=0,
                    maximum=n_turns_max,
                    value=0,
                    step=1,
                    label="Minimum Turn Count",
                    container=True,
                    scale=1,
                    info="Filter by minimum number of conversation turns",
                )
                len_query_filter = gr.Slider(
                    minimum=0,
                    maximum=len_query_max,
                    value=0,
                    step=10,
                    label="Minimum Query Length",
                    container=True,
                    scale=1,
                    info="Filter by minimum length of query in characters",
                )
                n_tools_filter = gr.Slider(
                    minimum=0,
                    maximum=n_tools_max,
                    value=0,
                    step=1,
                    label="Minimum Tool Count",
                    container=True,
                    scale=1,
                    info="Filter by minimum number of tools used",
                )
            with gr.Row():
                reset_btn = gr.Button("Reset Filters", size="sm", variant="secondary")

        # Navigation row
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                prev_btn = gr.Button(
                    "← Previous",
                    size="lg",
                    variant="secondary",
                    elem_classes="navigation-buttons",
                )
            with gr.Column(scale=1, min_width=100):
                # Uses the filtered count from the rendered default view, so it
                # always agrees with what the panes show.
                index_display = gr.HTML(value=initial_index, elem_id="index-display")
            with gr.Column(scale=1):
                next_btn = gr.Button(
                    "Next →",
                    size="lg",
                    variant="secondary",
                    elem_classes="navigation-buttons",
                )

        # Content areas, pre-filled so the tab is not blank on first load.
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                chat_display = gr.HTML(value=initial_chat)
            with gr.Column(scale=1):
                metrics_display = gr.HTML(value=initial_metrics)
        with gr.Row():
            tool_info_display = gr.HTML(value=initial_tools)

        # 0-based position of the displayed row within the filtered results.
        current_index = gr.State(value=0)

        filter_controls = [
            explore_model,
            explore_dataset,
            min_score,
            max_score,
            n_turns_filter,
            len_query_filter,
            n_tools_filter,
        ]
        display_outputs = [
            chat_display,
            metrics_display,
            tool_info_display,
            index_display,
        ]

        def reset_index():
            """Reset the stored row index to the first row."""
            return 0

        # Any filter change restarts navigation at the first matching row.
        for control in filter_controls:
            control.change(reset_index, inputs=[], outputs=[current_index])
            control.change(
                on_filter_change,
                inputs=filter_controls,
                outputs=display_outputs,
            )

        def reset_filters():
            """Restore default filters and show the first row of the default view.

            The view and index are returned explicitly, so Reset behaves the
            same whether or not any filter value actually changed.
            """
            return (
                default_model,
                default_dataset,
                score_floor,
                score_ceiling,
                0,  # n_turns
                0,  # len_query
                0,  # n_tools
                *_default_view(),
                0,  # current_index
            )

        reset_btn.click(
            reset_filters,
            outputs=[*filter_controls, *display_outputs, current_index],
        )

        # Navigation needs the current index plus every filter value.
        navigation_inputs = [current_index, *filter_controls]
        navigation_outputs = [*display_outputs, current_index]
        prev_btn.click(
            navigate_prev, inputs=navigation_inputs, outputs=navigation_outputs
        )
        next_btn.click(
            navigate_next, inputs=navigation_inputs, outputs=navigation_outputs
        )

        def update_slider_ranges(model, dataset):
            """Resize the count sliders for a new model/dataset and reset them to 0."""
            n_turns_cap, len_query_cap, n_tools_cap = _slider_maxima(model, dataset)
            return (
                gr.update(maximum=n_turns_cap, value=0),
                gr.update(maximum=len_query_cap, value=0),
                gr.update(maximum=n_tools_cap, value=0),
            )

        for selector in (explore_model, explore_dataset):
            selector.change(
                update_slider_ranges,
                inputs=[explore_model, explore_dataset],
                outputs=[n_turns_filter, len_query_filter, n_tools_filter],
            )

    return list(display_outputs)