Spaces:
Running
Running
Commit
·
0b8c16d
1
Parent(s):
ab74236
upload plot
Browse files- app.py +6 -1
- src/plt.py +53 -0
- src/utils.py +12 -0
app.py
CHANGED
|
@@ -5,6 +5,7 @@ from apscheduler.schedulers.background import BackgroundScheduler
|
|
| 5 |
from datasets import load_dataset
|
| 6 |
from src.utils import load_all_data
|
| 7 |
from src.md import ABOUT_TEXT, TOP_TEXT
|
|
|
|
| 8 |
import numpy as np
|
| 9 |
|
| 10 |
api = HfApi()
|
|
@@ -210,7 +211,11 @@ with gr.Blocks() as app:
|
|
| 210 |
sample_display = gr.Markdown("{sampled data loads here}")
|
| 211 |
|
| 212 |
button.click(fn=random_sample, inputs=[subset_selector], outputs=[sample_display])
|
| 213 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
|
| 215 |
# Load data when app starts, TODO make this used somewhere...
|
| 216 |
# def load_data_on_start():
|
|
|
|
| 5 |
from datasets import load_dataset
|
| 6 |
from src.utils import load_all_data
|
| 7 |
from src.md import ABOUT_TEXT, TOP_TEXT
|
| 8 |
+
from src.plt import plot_avg_correlation
|
| 9 |
import numpy as np
|
| 10 |
|
| 11 |
api = HfApi()
|
|
|
|
| 211 |
sample_display = gr.Markdown("{sampled data loads here}")
|
| 212 |
|
| 213 |
button.click(fn=random_sample, inputs=[subset_selector], outputs=[sample_display])
|
| 214 |
+
# removed plot because not pretty enough
|
| 215 |
+
# with gr.TabItem("Model Correlation"):
|
| 216 |
+
# with gr.Row():
|
| 217 |
+
# plot = plot_avg_correlation(herm_data_avg, prefs_data)
|
| 218 |
+
# gr.Plot(plot)
|
| 219 |
|
| 220 |
# Load data when app starts, TODO make this used somewhere...
|
| 221 |
# def load_data_on_start():
|
src/plt.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
import pandas as pd
|
| 3 |
+
from .utils import undo_hyperlink
|
| 4 |
+
|
| 5 |
+
def plot_avg_correlation(df1, df2):
    """
    Scatter-plot the "average" score of every model present in both dataframes.

    Each shared model becomes one labelled point whose x coordinate is its
    "average" in df1 and whose y coordinate is its "average" in df2.

    Parameters:
    - df1: pandas DataFrame containing columns "model" and "average"
      (plotted on the x axis, labelled 'HERM Eval. Set Avg.').
    - df2: pandas DataFrame containing columns "model" and "average"
      (plotted on the y axis, labelled 'Pref. Test Sets Avg.').

    Returns:
    - The matplotlib.pyplot module, with the newly created figure as the
      current figure.

    Note: this updates the global matplotlib rcParams (font sizes), which
    persists for the rest of the process.
    """
    # Models present in both frames, de-duplicated in df1's row order so the
    # colour assigned to each model is stable between runs (iterating a set
    # intersection directly gives an arbitrary order).
    df2_models = set(df2['model'])
    common_models = [m for m in dict.fromkeys(df1['model']) if m in df2_models]

    # Font sizes are set before the figure exists so every element of the new
    # figure is created with them.
    plt.rcParams.update({'font.size': 12, 'axes.labelsize': 14, 'axes.titlesize': 14})

    # Set up the plot
    plt.figure(figsize=(13, 6), constrained_layout=True)

    # Both axes are cropped to the same fixed window; points outside it are
    # not visible.
    plt.xlim(0.475, 0.8)
    plt.ylim(0.475, 0.8)

    for model in common_models:
        # Scores of the current model in each frame
        x_values = df1.loc[df1['model'] == model, 'average'].to_numpy()
        y_values = df2.loc[df2['model'] == model, 'average'].to_numpy()

        plt.scatter(x_values, y_values, label=model)

        # The "model" cell is passed through undo_hyperlink to get a display
        # name; its "No text found" sentinel is shown as "Random"
        # (presumably the plain-text random baseline row -- verify).
        m_name = undo_hyperlink(model)
        if m_name == "No text found":
            m_name = "Random"

        # plt.text expects scalar coordinates, so annotate point by point
        # instead of handing it whole arrays; the label sits just left of
        # its marker.
        for x, y in zip(x_values, y_values):
            plt.text(float(x) - .005, float(y), m_name, horizontalalignment='right', verticalalignment='center')

    # TODO: add a correlation line. The correlation must be computed on the
    # two "average" columns aligned by model name, not by row index.

    plt.xlabel('HERM Eval. Set Avg.', fontsize=16)
    plt.ylabel('Pref. Test Sets Avg.', fontsize=16)
    return plt
|
src/utils.py
CHANGED
|
@@ -3,6 +3,7 @@ from pathlib import Path
|
|
| 3 |
from datasets import load_dataset
|
| 4 |
import numpy as np
|
| 5 |
import os
|
|
|
|
| 6 |
|
| 7 |
# From Open LLM Leaderboard
|
| 8 |
def model_hyperlink(link, model_name):
|
|
@@ -10,6 +11,17 @@ def model_hyperlink(link, model_name):
|
|
| 10 |
return "random"
|
| 11 |
return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
# Define a function to fetch and process data
|
| 14 |
def load_all_data(data_repo, subdir:str, subsubsets=False): # use HF api to pull the git repo
|
| 15 |
dir = Path(data_repo)
|
|
|
|
| 3 |
from datasets import load_dataset
|
| 4 |
import numpy as np
|
| 5 |
import os
|
| 6 |
+
import re
|
| 7 |
|
| 8 |
# From Open LLM Leaderboard
|
| 9 |
def model_hyperlink(link, model_name):
|
|
|
|
| 11 |
return "random"
|
| 12 |
return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
|
| 13 |
|
| 14 |
+
def undo_hyperlink(html_string):
    """
    Extract the visible text from an HTML snippet such as an <a> tag.

    Parameters:
    - html_string: the HTML snippet to inspect.

    Returns:
    - The first non-empty run of characters found between a '>' and the
      following '<', or the sentinel string "No text found" when there is
      no such run or when html_string is not a string (e.g. None or NaN).
    """
    # re.search raises TypeError on non-string input; report it the same way
    # as "nothing to extract" so callers only have one sentinel to handle.
    if not isinstance(html_string, str):
        return "No text found"

    # Capture the content between '>' and '<' directly instead of slicing
    # the delimiters off the full match afterwards.
    match = re.search(r'>([^<]+)<', html_string)
    if match:
        return match.group(1)
    return "No text found"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
# Define a function to fetch and process data
|
| 26 |
def load_all_data(data_repo, subdir:str, subsubsets=False): # use HF api to pull the git repo
|
| 27 |
dir = Path(data_repo)
|