# Author: Daksh0505
# Last commit: "Update app.py" (ebd8345, verified)
import streamlit as st
import joblib
import pandas as pd
import numpy as np
import os
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from PIL import Image
import plotly.express as px
import plotly.graph_objects as go
import io
# =============================
# Data Loading (cached properly)
# =============================
@st.cache_resource
def load_data():
    """Load characters and embeddings from disk and precompute cosine similarities.

    Decorated with ``st.cache_resource`` so the files are read and the matrices
    computed once, then reused across reruns.

    Returns:
        tuple: ``(characters_df, character_names, sbert_embeddings,
        tfidf_embeddings, sbert_sim, tfidf_sim)`` where ``characters_df`` has a
        ``character`` column plus a lower-cased, stripped ``normalized`` column,
        ``character_names`` is the alphabetically sorted list of names, and
        ``sbert_sim`` / ``tfidf_sim`` are the pairwise cosine-similarity
        matrices of the two embedding sets.
    """
    raw_names = joblib.load('characters_list_got.joblib')
    characters_df = pd.DataFrame(raw_names, columns=['character'])
    # Case/whitespace-insensitive key used for lookups by name.
    characters_df['normalized'] = characters_df['character'].str.lower().str.strip()
    character_names = sorted(characters_df['character'].tolist())

    sbert_embeddings = joblib.load('embeddings_got.joblib')
    tfidf_embeddings = joblib.load('tfidf_embeddings_got.joblib')

    # Precompute similarity matrices (row/column order follows characters_df).
    sbert_sim, tfidf_sim = (
        cosine_similarity(np.array(matrix))
        for matrix in (sbert_embeddings, tfidf_embeddings)
    )
    return characters_df, character_names, sbert_embeddings, tfidf_embeddings, sbert_sim, tfidf_sim
def name_to_folder(name):
    """Map a character name to its image folder name.

    The name is lower-cased and every single space becomes an underscore,
    e.g. ``"Jon Snow"`` -> ``"jon_snow"``.
    """
    lowered = name.lower()
    return "_".join(lowered.split(" "))
def get_image_path(name):
    """Return the path of the first existing image for ``name``, or a fallback.

    Looks for ``images/<folder>/000001.<ext>`` (relative to the current working
    directory), where ``<folder>`` comes from ``name_to_folder`` applied to the
    lower-cased, stripped name. The extensions jpg, jpeg, png, gif and bmp are
    tried in that order. If none exists, returns ``"images/placeholder.jpg"``
    when that file exists, otherwise ``None``.
    """
    folder_name = name_to_folder(name.lower().strip())
    # Lazily generated so os.path.exists stops at the first hit.
    candidates = (
        os.path.join("images", folder_name, f"000001.{ext}")
        for ext in ('jpg', 'jpeg', 'png', 'gif', 'bmp')
    )
    found = next((path for path in candidates if os.path.exists(path)), None)
    if found is not None:
        return found
    fallback = "images/placeholder.jpg"
    if os.path.exists(fallback):
        return fallback
    return None
# =============================
# Recommendation Function
# =============================
def recommend_characters(model_type, input_character, characters_df, sbert_sim, tfidf_sim, top_n=5, weight=0.7):
    """Return the ``top_n`` characters most similar to ``input_character``.

    Args:
        model_type: "SBERT" or "Hybrid"; any other value uses the TF-IDF matrix.
        input_character: Character name, matched after lower-casing and
            stripping against ``characters_df['normalized']``.
        characters_df: DataFrame with ``character`` and ``normalized`` columns
            whose row order matches the rows/columns of the similarity matrices.
        sbert_sim: Square SBERT similarity matrix.
        tfidf_sim: Square TF-IDF similarity matrix.
        top_n: Maximum number of matches to return.
        weight: SBERT weight for the "Hybrid" model; TF-IDF gets ``1 - weight``.

    Returns:
        List of ``(title-cased name, image path or None, similarity)`` tuples in
        descending similarity order, never including the input character itself.
        Empty when the character is unknown or ``top_n`` is not positive.
    """
    key = input_character.lower().strip()
    # Positional lookup, so the row lines up with the similarity matrices even
    # if the DataFrame's index labels are not 0..N-1.
    try:
        character_index = characters_df['normalized'].tolist().index(key)
    except ValueError:
        return []
    if top_n <= 0:
        return []

    # Only the query's row is needed, so blend rows rather than building an
    # entire N x N hybrid matrix on every call.
    if model_type == "Hybrid":
        distances = (weight * np.asarray(sbert_sim[character_index])
                     + (1 - weight) * np.asarray(tfidf_sim[character_index]))
    elif model_type == "SBERT":
        distances = np.asarray(sbert_sim[character_index])
    else:
        distances = np.asarray(tfidf_sim[character_index])

    # Exclude the query explicitly instead of assuming it ranks first: ties
    # (duplicate embeddings) or a zero vector (self-similarity 0) break that.
    # sorted() is stable, so equal scores keep their original row order.
    ranked = sorted(
        ((i, score) for i, score in enumerate(distances) if i != character_index),
        key=lambda item: item[1],
        reverse=True,
    )[:top_n]

    results = []
    for i, similarity_score in ranked:
        name = characters_df.iloc[i]['character']
        results.append((name.title(), get_image_path(name), similarity_score))
    return results
# =============================
# Visualization Functions
# =============================
@st.cache_data
def compute_tsne_2d(embeddings, perplexity=30, random_state=42):
    """Return a 2-D t-SNE projection of ``embeddings``.

    Wrapped in ``st.cache_data`` so identical arguments reuse the stored
    result instead of refitting.
    """
    model = TSNE(n_components=2, perplexity=perplexity, random_state=random_state)
    projected = model.fit_transform(embeddings)
    return projected
@st.cache_data
def compute_tsne_3d(embeddings, perplexity=30, random_state=42):
    """Return a 3-D t-SNE projection of ``embeddings``.

    Wrapped in ``st.cache_data`` so identical arguments reuse the stored
    result instead of refitting.
    """
    model = TSNE(n_components=3, perplexity=perplexity, random_state=random_state)
    projected = model.fit_transform(embeddings)
    return projected
@st.cache_data
def compute_pca_2d(embeddings):
    """Return the first two principal-component scores of ``embeddings``.

    Wrapped in ``st.cache_data`` so identical inputs reuse the stored result.
    """
    model = PCA(n_components=2)
    projected = model.fit_transform(embeddings)
    return projected
@st.cache_data
def compute_pca_3d(embeddings):
    """Return the first three principal-component scores of ``embeddings``.

    Wrapped in ``st.cache_data`` so identical inputs reuse the stored result.
    """
    model = PCA(n_components=3)
    projected = model.fit_transform(embeddings)
    return projected
def create_2d_plot(coords, characters, title, method):
    """Build a labelled 2-D Plotly scatter plot of reduced embeddings.

    Args:
        coords: Array whose first two columns are plotted as x and y.
        characters: Text labels, one per row of ``coords``.
        title: Title prefix; the figure title is ``"<title> - <method>"``.
        method: Reduction method name used in the title and axis labels.

    Returns:
        A Plotly figure, 600 px tall, with each point labelled by character.
    """
    frame = pd.DataFrame({'x': coords[:, 0], 'y': coords[:, 1], 'character': characters})
    figure = px.scatter(
        frame,
        x='x',
        y='y',
        text='character',
        title=f"{title} - {method}",
        hover_data={'character': True, 'x': ':.3f', 'y': ':.3f'},
    )
    figure.update_traces(
        textposition="top center",
        textfont_size=8,
        marker={'size': 8, 'opacity': 0.7},
    )
    figure.update_layout(
        height=600,
        showlegend=False,
        xaxis_title=f"{method} Component 1",
        yaxis_title=f"{method} Component 2",
    )
    return figure
def create_3d_plot(coords, characters, title, method):
    """Build a labelled 3-D Plotly scatter plot of reduced embeddings.

    Args:
        coords: Array whose first three columns are the x, y and z coordinates.
        characters: Text labels, one per row of ``coords``.
        title: Title prefix; the figure title is ``"<title> - <method>"``.
        method: Reduction method name used in the title, hover text and axes.

    Returns:
        A ``plotly.graph_objects.Figure`` whose markers are coloured by the
        first component on the Viridis scale.
    """
    hover = (
        '<b>%{text}</b><br>'
        + f'{method} 1: %{{x:.3f}}<br>'
        + f'{method} 2: %{{y:.3f}}<br>'
        + f'{method} 3: %{{z:.3f}}<br>'
        + '<extra></extra>'
    )
    marker_style = {
        'size': 6,
        'opacity': 0.7,
        'color': coords[:, 0],
        'colorscale': 'Viridis',
        'showscale': True,
    }
    trace = go.Scatter3d(
        x=coords[:, 0],
        y=coords[:, 1],
        z=coords[:, 2],
        mode='markers+text',
        text=characters,
        textposition="top center",
        textfont_size=8,
        marker=marker_style,
        hovertemplate=hover,
    )
    figure = go.Figure(data=[trace])
    figure.update_layout(
        title=f"{title} - {method}",
        scene={
            'xaxis_title': f"{method} Component 1",
            'yaxis_title': f"{method} Component 2",
            'zaxis_title': f"{method} Component 3",
        },
        height=600,
    )
    return figure
# =============================
# Streamlit App
# =============================
def main():
    """Run the Streamlit app: page setup, data loading, and the two tabs.

    Tab contents are rendered by ``_render_similarity_tab`` and
    ``_render_visualization_tab``; this function only wires them together.
    """
    st.set_page_config(
        page_title="GoT Character Similarity Explorer",
        page_icon="⚔️",
        layout="wide"
    )
    st.title("⚔️ Game of Thrones Character Similarity Explorer")
    # Load data
    characters_df, character_names, sbert_embeddings, tfidf_embeddings, sbert_sim, tfidf_sim = load_data()
    # Tabs
    tab1, tab2 = st.tabs(["🔍 Character Similarity", "📊 Dimensionality Reduction"])
    with tab1:
        _render_similarity_tab(characters_df, character_names, sbert_sim, tfidf_sim)
    with tab2:
        _render_visualization_tab(characters_df, sbert_embeddings, tfidf_embeddings)


def _render_similarity_tab(characters_df, character_names, sbert_sim, tfidf_sim):
    """Render tab 1: sidebar query controls, then saved results or a landing view."""
    st.markdown("Select a model and character to view top semantic matches!")
    with st.sidebar:
        st.header("Settings")
        model_type = st.radio("Select Embedding Model:", ["SBERT", "TFIDF", "Hybrid"])
        selected_character = st.selectbox("Choose Character:", character_names)
        # Number of similar characters
        top_n = st.slider("How many similar characters?", 1, 20, 5)
        # The SBERT/TF-IDF blend weight is only configurable for the Hybrid model.
        weight = 0.7
        if model_type == "Hybrid":
            weight = st.slider("Weight for SBERT (TF-IDF = 1 - weight)", 0.0, 1.0, 0.7, 0.1)
        if st.button("Find Similar Characters", type="primary", key="search_button"):
            # st.button is only True on the run where it was clicked, so the query
            # is saved to session state to keep results visible on later reruns.
            st.session_state.selected_character = selected_character
            st.session_state.model_type = model_type
            st.session_state.top_n = top_n
            st.session_state.weight = weight

    result_placeholder = st.empty()
    if "selected_character" in st.session_state:
        with result_placeholder.container():
            _render_similarity_results(characters_df, sbert_sim, tfidf_sim)
    else:
        _render_similarity_landing(len(character_names))


def _render_similarity_results(characters_df, sbert_sim, tfidf_sim):
    """Show matches for the query saved in session state, plus a CSV download."""
    state = st.session_state
    results = recommend_characters(
        state.model_type,
        state.selected_character,
        characters_df,
        sbert_sim,
        tfidf_sim,
        top_n=state.top_n,
        weight=state.weight
    )
    st.subheader(
        f"Characters similar to **{state.selected_character}** "
        f"(using {state.model_type})"
    )
    if not results:
        st.error("No similar characters found.")
        return

    # At most five columns; further results cycle through the same columns.
    cols = st.columns(min(5, len(results)))
    for idx, (name, image_path, similarity) in enumerate(results):
        with cols[idx % len(cols)]:
            if image_path and os.path.exists(image_path):
                try:
                    st.image(image_path, use_container_width=True, caption=name)
                except Exception:
                    # Best effort: an unreadable image falls back to the notice below.
                    st.info("No image available")
            else:
                st.info("No image available")
            st.caption(f"Similarity: {similarity:.3f}")

    # Download CSV
    df_results = pd.DataFrame(
        [{"Character": name, "Similarity": similarity} for name, _, similarity in results]
    )
    csv = df_results.to_csv(index=False).encode("utf-8")
    st.download_button("📥 Download Results as CSV", csv, "similar_characters.csv", "text/csv")


def _render_similarity_landing(total_characters):
    """Show the getting-started hint and summary stats before any search."""
    st.info("👈 Select a character from the sidebar and click 'Find Similar Characters' to get started!")
    col1, col2, col3 = st.columns(3)
    with col1:
        st.markdown("<h3 style='text-align:center;'>Total Characters</h3>", unsafe_allow_html=True)
        st.markdown(f"<h1 style='text-align:center;'>{total_characters}</h1>", unsafe_allow_html=True)
    with col2:
        st.markdown("<h3 style='text-align:center;'>Embedding Models</h3>", unsafe_allow_html=True)
        st.markdown("<h2 style='text-align:center;'>SBERT, TF-IDF, Hybrid</h2>", unsafe_allow_html=True)
    with col3:
        st.markdown("<h3 style='text-align:center;'>Similarity Algorithm</h3>", unsafe_allow_html=True)
        st.markdown("<h1 style='text-align:center;'>Cosine</h1>", unsafe_allow_html=True)


def _render_visualization_tab(characters_df, sbert_embeddings, tfidf_embeddings):
    """Render tab 2: reduction controls, the persisted plot, and method notes."""
    st.markdown("### Interactive Dimensionality Reduction Visualizations")
    st.markdown("Explore character embeddings in 2D and 3D space using t-SNE and PCA")
    col1, col2, col3 = st.columns(3)
    with col1:
        viz_model = st.selectbox("Embedding Model:", ["SBERT", "TFIDF"], key="viz_model")
    with col2:
        viz_method = st.selectbox("Reduction Method:", ["t-SNE", "PCA"], key="viz_method")
    with col3:
        viz_dims = st.selectbox("Dimensions:", ["2D", "3D"], key="viz_dims")
    perplexity = 30
    if viz_method == "t-SNE":
        perplexity = st.slider(
            "Perplexity (t-SNE parameter):",
            min_value=5,
            max_value=50,
            value=30,
            help="Lower values focus on local structure, higher values on global structure"
        )
    if st.button("Generate Visualization", type="primary", key="viz_button"):
        # Persist the request (as tab 1 does) so the plot survives later reruns,
        # e.g. the one triggered by the HTML download button. The key must differ
        # from the widget keys above, which Streamlit already owns.
        st.session_state.viz_request = (viz_model, viz_method, viz_dims, perplexity)
    if "viz_request" in st.session_state:
        _render_visualization(characters_df, sbert_embeddings, tfidf_embeddings,
                              *st.session_state.viz_request)

    with st.expander("ℹ️ About Dimensionality Reduction Methods"):
        # Blank lines separate sections so each bold heading is not absorbed
        # into the preceding Markdown list item.
        st.markdown("\n".join([
            "**t-SNE (t-Distributed Stochastic Neighbor Embedding):**",
            "- Great for visualizing clusters and local neighborhoods",
            "- Non-linear method that preserves local structure",
            "- Good for finding groups of similar characters",
            "- Perplexity controls local vs global structure focus",
            "",
            "**PCA (Principal Component Analysis):**",
            "- Linear method that preserves global variance",
            "- Shows the main directions of variation in the data",
            "- Faster computation than t-SNE",
            "- Components have interpretable meaning",
            "",
            "**2D vs 3D:**",
            "- 2D is easier to interpret and interact with",
            "- 3D can reveal additional structure but may be harder to read",
        ]))


def _render_visualization(characters_df, sbert_embeddings, tfidf_embeddings,
                          viz_model, viz_method, viz_dims, perplexity):
    """Compute and display one reduction plot with an HTML download and a summary."""
    with st.spinner(f"Computing {viz_method} {viz_dims} for {viz_model} embeddings..."):
        embeddings = np.array(sbert_embeddings) if viz_model == "SBERT" else np.array(tfidf_embeddings)
        characters = characters_df['character'].tolist()
        try:
            fig = _build_reduction_figure(embeddings, characters, viz_model, viz_method, viz_dims, perplexity)
            st.plotly_chart(fig, use_container_width=True)
            # Download options
            try:
                fig_html = fig.to_html()
                st.download_button(
                    "📥 Download Plot as HTML",
                    fig_html,
                    file_name=f"{viz_model}_{viz_method}_{viz_dims}.html",
                    mime="text/html"
                )
            except Exception as e:
                st.error(f"Error in saving HTML File because of {e} error")
            # Built line by line so the Markdown does not depend on source indentation.
            info_lines = [
                "**Visualization Info:**",
                f"- Model: {viz_model}",
                f"- Method: {viz_method} {viz_dims}",
                f"- Characters: {len(characters)}",
                f"- Original dimensions: {embeddings.shape[1]}",
            ]
            if viz_method == "t-SNE":
                info_lines.append(f"- Perplexity: {perplexity}")
            st.info("\n".join(info_lines))
        except Exception as e:
            # UI boundary: report any reduction/plotting failure instead of crashing the page.
            st.error(f"Error generating visualization: {e}")


def _build_reduction_figure(embeddings, characters, viz_model, viz_method, viz_dims, perplexity):
    """Reduce ``embeddings`` with the chosen method/dimensions and return the figure."""
    title = f"{viz_model} Embeddings"
    is_2d = viz_dims == "2D"
    if viz_method == "t-SNE":
        reduce = compute_tsne_2d if is_2d else compute_tsne_3d
        coords = reduce(embeddings, perplexity=perplexity)
    elif viz_method == "PCA":
        reduce = compute_pca_2d if is_2d else compute_pca_3d
        coords = reduce(embeddings)
    else:
        raise ValueError(f"Unknown reduction method: {viz_method}")
    plot = create_2d_plot if is_2d else create_3d_plot
    return plot(coords, characters, title, viz_method)
if __name__ == "__main__":
    # Start the app only when this file is executed as the main script.
    main()