Roman Castagné
Initial commit
54c73d3
raw
history blame
2.64 kB
import json
from Levenshtein import distance
import streamlit as st
import numpy as np
import plotly.express as px
from sklearn.decomposition import PCA
def load_data():
    """Load the precomputed embeddings and the accompanying word list.

    Reads ``data/simplesegmentT5_embeddings.npy`` with ``np.load`` and
    ``data/words.json`` with ``json.load``. Both paths are relative to the
    current working directory.

    Returns:
        A tuple ``(embeddings, words)``: the loaded array and the decoded
        JSON content.
    """
    embeddings = np.load("data/simplesegmentT5_embeddings.npy")
    # Context manager closes the handle deterministically; JSON is UTF-8 by
    # specification, so do not depend on the platform's default encoding.
    with open("data/words.json", "r", encoding="utf-8") as words_file:
        words = json.load(words_file)
    return embeddings, words
def project_embeddings(embeddings):
    """Reduce ``embeddings`` to three principal components.

    A new ``PCA(n_components=3)`` model is fitted on the given data on every
    call, and the transformed data is returned.
    """
    return PCA(n_components=3).fit_transform(embeddings)
def filter_words(words, remove_capitalized, length):
    """Return the positions of the words that pass both filters.

    Args:
        words: Iterable of strings to filter.
        remove_capitalized: When truthy, drop every word that differs from
            its lowercased form (i.e. contains an uppercase character).
        length: Indexable pair ``(min_len, max_len)``; both bounds are
            inclusive.

    Returns:
        List of indices into ``words`` of the kept words, in original order.
    """

    def keep(word):
        # Capitalization filter first, then the inclusive length window.
        if remove_capitalized and word != word.lower():
            return False
        return length[0] <= len(word) <= length[1]

    return [position for position, word in enumerate(words) if keep(word)]
def color_length(words):
    """Return the length of each word, in order, as a list."""
    return list(map(len, words))
def color_first_letter(words):
    """Map each word to a value in [0, 1] derived from its first letter.

    The word is lowercased, the code point of its first character is offset
    by 97 (``ord("a")``) and divided by 26; results outside [0, 1] are
    clamped to that range.
    """
    values = []
    for word in words:
        offset = ord(word.lower()[0]) - 97
        values.append(min(1, max(0, offset / 26)))
    return values
def color_levenshtein(words, reference_index=4):
    """Return each word's Levenshtein distance to one reference word.

    The reference is ``words[reference_index]``. The default of 4 keeps the
    previous hard-coded choice for lists of five or more words.

    Args:
        words: Sequence of strings.
        reference_index: Position of the reference word. Indices past the
            end are clamped to the last word, so short lists (e.g. after
            aggressive filtering) no longer raise ``IndexError``.

    Returns:
        List of distances, one per word, in order; ``[]`` for empty input.
    """
    if not words:
        return []
    # Look the reference word up once instead of on every iteration.
    reference = words[min(reference_index, len(words) - 1)]
    return [distance(word, reference) for word in words]
def plot_scatter(words, embeddings, remove_capitalized, length, color_select):
    """Build the 3-D scatter figure for the words that pass the filters.

    The words are filtered with ``filter_words``, the matching embeddings are
    projected with ``project_embeddings``, and the points are colored either
    by word length or by Levenshtein distance.

    Args:
        words: Sequence of strings.
        embeddings: Indexed with the list of kept positions; presumably an
            array whose i-th row belongs to ``words[i]`` (verify against the
            data files).
        remove_capitalized: Forwarded to ``filter_words``.
        length: Forwarded to ``filter_words``.
        color_select: ``"Word length"`` selects ``color_length``; any other
            value selects ``color_levenshtein``.

    Returns:
        The figure created by ``px.scatter_3d`` after its marker traces have
        been restyled.
    """
    kept = filter_words(words, remove_capitalized, length)
    kept_embeddings = embeddings[kept]
    kept_words = [words[position] for position in kept]
    points = project_embeddings(kept_embeddings)

    colorize = color_length if color_select == "Word length" else color_levenshtein
    point_colors = colorize(kept_words)

    figure = px.scatter_3d(
        x=points[:, 0],
        y=points[:, 1],
        z=points[:, 2],
        width=800,
        height=600,
        color=point_colors,
        color_continuous_scale=px.colors.sequential.Viridis,
        hover_name=kept_words,
        title="SimpleSegmentT5 Embeddings",
    )
    figure.update_traces(
        marker=dict(size=6, line=dict(width=2)),
        selector=dict(mode="markers"),
    )
    return figure
def main():
    """Render the Streamlit page: sidebar settings and the 3-D scatter plot.

    Loads the data, reads the filter and coloring options from the sidebar
    widgets, and displays the figure produced by ``plot_scatter``.
    """
    embeddings, words = load_data()

    # The previous version also ran a PCA over the full, unfiltered data and
    # built a second figure that was never displayed; that dead work is gone
    # because ``plot_scatter`` already projects the filtered subset itself.

    st.sidebar.title("Settings")
    remove_checkbox = st.sidebar.checkbox(
        "Remove capitalized words",
        value=True,
        # Key kept unchanged so the widget's identity in session state is
        # preserved, even though its name reads opposite to the label.
        key="include_capitalized",
    )
    length_slider = st.sidebar.slider("Word length", 3, 9, (3, 9))
    color_select = st.sidebar.radio(
        "Color by",
        ["Word length", "Levenshtein distance to random word"],
    )

    st.plotly_chart(
        plot_scatter(words, embeddings, remove_checkbox, length_slider, color_select)
    )


if __name__ == "__main__":
    main()