Spaces:
Sleeping
Sleeping
import json | |
from Levenshtein import distance | |
import streamlit as st | |
import numpy as np | |
import plotly.express as px | |
from sklearn.decomposition import PCA | |
def load_data(
    embeddings_path="data/simplesegmentT5_embeddings.npy",
    words_path="data/words.json",
):
    """Load the embedding matrix and the word list from disk.

    Args:
        embeddings_path: Path to a ``.npy`` file read with ``np.load``.
        words_path: Path to a JSON file holding the vocabulary.

    Returns:
        A ``(embeddings, words)`` tuple: the array returned by ``np.load``
        and the object decoded from the JSON file.
    """
    embeddings = np.load(embeddings_path)
    # Close the handle deterministically and decode as UTF-8 (the JSON
    # standard encoding) instead of the platform's locale default.
    with open(words_path, "r", encoding="utf-8") as f:
        words = json.load(f)
    return embeddings, words
def project_embeddings(embeddings, n_components=3):
    """Project embeddings onto their leading principal components.

    Args:
        embeddings: 2-D array of shape ``(n_samples, n_features)``.
        n_components: Number of output columns (default 3, for 3-D plots).

    Returns:
        Array of shape ``(n_samples, n_components)``. When the data has
        fewer samples or features than ``n_components``, PCA is fitted with
        as many components as the data allows and the remaining columns are
        zero-filled, so callers can always index columns
        ``0..n_components-1``.
    """
    embeddings = np.asarray(embeddings)
    n_samples = embeddings.shape[0]
    # With zero or one sample the mean-centred data is all zeros, so every
    # projection is 0. PCA rejects an empty array and is degenerate for a
    # single sample, so short-circuit both cases.
    if n_samples < 2:
        return np.zeros((n_samples, n_components))
    k = min(n_components, n_samples, embeddings.shape[1])
    proj = PCA(n_components=k).fit_transform(embeddings)
    if k < n_components:
        # Pad the missing components with zeros so the output width is stable.
        proj = np.hstack([proj, np.zeros((n_samples, n_components - k))])
    return proj
def filter_words(words, remove_capitalized, length):
    """Return the positions of the words that pass the sidebar filters.

    A word is kept when, if ``remove_capitalized`` is true, lowercasing
    leaves it unchanged, and its length lies in the inclusive range
    ``length[0]``..``length[1]``.
    """
    return [
        position
        for position, word in enumerate(words)
        if not (remove_capitalized and word.lower() != word)
        and length[0] <= len(word) <= length[1]
    ]
def color_length(words):
    """Use each word's character count as its colour value."""
    return list(map(len, words))
def color_first_letter(words):
    """Map each word's first letter to a colour value in ``[0, 1]``.

    The first character (after lowercasing) is mapped as
    ``(ord(c) - ord('a')) / 26`` and clamped to ``[0, 1]``, so 'a' -> 0 and
    'z' -> 25/26. Characters below 'a' clamp to 0 and those far above 'z'
    clamp to 1. Empty strings map to 0.0 instead of raising ``IndexError``.
    """
    def _letter_value(word):
        # Slice instead of index so "" yields "" rather than raising.
        first = word.lower()[:1]
        if not first:
            return 0.0
        return min(1, max(0, (ord(first) - 97) / 26))

    return [_letter_value(w) for w in words]
def color_levenshtein(words, reference_index=4):
    """Colour each word by its Levenshtein distance to a reference word.

    The reference is ``words[reference_index]`` (index 4 by default, the
    previously fixed choice). If the list is too short for that index, the
    last word is used instead, so small filtered selections do not raise
    ``IndexError``.

    Args:
        words: Sequence of strings.
        reference_index: Position of the reference word within ``words``.

    Returns:
        A list with one edit distance per word; empty for empty input.
    """
    if not words:
        return []
    reference = words[min(reference_index, len(words) - 1)]
    return [distance(word, reference) for word in words]
def plot_scatter(words, embeddings, remove_capitalized, length, color_select):
    """Build a 3-D PCA scatter plot of the words that pass the filters.

    Args:
        words: Vocabulary list; ``embeddings`` rows are indexed by the same
            positions.
        embeddings: Array whose rows are selected by word position.
        remove_capitalized: Forwarded to ``filter_words``.
        length: Inclusive ``(min, max)`` length range for ``filter_words``.
        color_select: ``"Word length"`` colours by length; any other value
            colours by Levenshtein distance.

    Returns:
        The plotly figure created by ``px.scatter_3d``.
    """
    keep = filter_words(words, remove_capitalized, length)
    points = project_embeddings(embeddings[keep])
    kept_words = [words[i] for i in keep]

    colorize = color_length if color_select == "Word length" else color_levenshtein
    values = colorize(kept_words)

    figure = px.scatter_3d(
        x=points[:, 0],
        y=points[:, 1],
        z=points[:, 2],
        color=values,
        hover_name=kept_words,
        color_continuous_scale=px.colors.sequential.Viridis,
        title="SimpleSegmentT5 Embeddings",
        width=800,
        height=600,
    )
    figure.update_traces(
        selector={"mode": "markers"},
        marker={"size": 6, "line": {"width": 2}},
    )
    return figure
def main():
    """Render the sidebar settings and the filtered 3-D embedding scatter."""
    embeddings, words = load_data()

    st.sidebar.title("Settings")
    # NOTE(review): checking this box *removes* capitalized words; the widget
    # key reads the opposite way but is left unchanged in case anything
    # relies on it via st.session_state.
    remove_capitalized = st.sidebar.checkbox(
        "Remove capitalized words",
        value=True,
        key="include_capitalized",
    )
    length_range = st.sidebar.slider("Word length", 3, 9, (3, 9))
    # NOTE(review): color_levenshtein compares against a fixed-position word
    # of the filtered list, not a random one; the label is kept as-is.
    color_select = st.sidebar.radio(
        "Color by", ["Word length", "Levenshtein distance to random word"]
    )
    st.plotly_chart(
        plot_scatter(words, embeddings, remove_capitalized, length_range, color_select)
    )


if __name__ == "__main__":
    main()