Daksh0505 committed on
Commit
60d52c9
·
verified ·
1 Parent(s): 4e53b85

Upload 4 files

Browse files
app2.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import joblib
3
+ import pandas as pd
4
+ import numpy as np
5
+ import os
6
+ from sklearn.metrics.pairwise import cosine_similarity
7
+ from sklearn.manifold import TSNE
8
+ from sklearn.decomposition import PCA
9
+ from PIL import Image
10
+ import plotly.express as px
11
+ import plotly.graph_objects as go
12
+ from plotly.subplots import make_subplots
13
+
14
+ # Cache the data loading
15
@st.cache_data
def load_data():
    """Load the character list and both embedding sets from joblib files.

    Reads ``characters_list_got.joblib``, ``embeddings_got.joblib`` and
    ``tfidf_embeddings_got.joblib`` from the current working directory.
    The function is wrapped in ``st.cache_data``.

    Returns:
        A 4-tuple ``(characters_df, character_names, sbert_embeddings,
        tfidf_embeddings)`` where ``characters_df`` has a ``character``
        column plus a lower-cased, stripped ``normalized`` column, and
        ``character_names`` is the sorted list of ``character`` values.
    """
    raw_characters = joblib.load('characters_list_got.joblib')
    characters_df = pd.DataFrame(raw_characters, columns=['character'])

    # Normalised copy of each name, used for case-insensitive lookups.
    lowered = characters_df['character'].str.lower()
    characters_df['normalized'] = lowered.str.strip()

    character_names = characters_df['character'].tolist()
    character_names.sort()

    sbert_embeddings = joblib.load('embeddings_got.joblib')
    tfidf_embeddings = joblib.load('tfidf_embeddings_got.joblib')

    return characters_df, character_names, sbert_embeddings, tfidf_embeddings
23
+
24
def name_to_folder(name):
    """Convert a character name into its image folder name.

    The name is lower-cased and every space character is replaced by an
    underscore; no other characters are altered.
    """
    lowered = name.lower()
    # Splitting on a single space and re-joining swaps each space for "_".
    return "_".join(lowered.split(" "))
26
+
27
def get_image_path(name):
    """Locate the image file for a character.

    Looks for ``images/<folder>/000001.<ext>`` where ``<folder>`` comes
    from ``name_to_folder`` applied to the lower-cased, stripped name, and
    ``<ext>`` is tried in the order jpg, jpeg, png, gif, bmp.

    Returns:
        The first candidate path that exists; otherwise
        ``images/placeholder.jpg`` if that file exists; otherwise ``None``.
    """
    folder_name = name_to_folder(name.lower().strip())

    candidates = (
        os.path.join("images", folder_name, f"000001.{ext}")
        for ext in ('jpg', 'jpeg', 'png', 'gif', 'bmp')
    )
    found = next((path for path in candidates if os.path.exists(path)), None)
    if found is not None:
        return found

    # No character-specific image: fall back to the shared placeholder.
    placeholder_path = "images/placeholder.jpg"
    if os.path.exists(placeholder_path):
        return placeholder_path
    return None
40
+
41
def recommend_characters(model_type, input_character, characters_df, sbert_embeddings, tfidf_embeddings):
    """Return up to five characters most similar to ``input_character``.

    Args:
        model_type: ``"SBERT"`` selects ``sbert_embeddings``; any other
            value selects ``tfidf_embeddings``.
        input_character: Character name; it is lower-cased and stripped
            before being matched against ``characters_df['normalized']``.
        characters_df: DataFrame with ``character`` and ``normalized``
            columns. Assumes row ``i`` corresponds to embedding row ``i``
            (TODO confirm against how the joblib files were produced).
        sbert_embeddings: Embeddings convertible with ``np.array``.
        tfidf_embeddings: Embeddings convertible with ``np.array``.

    Returns:
        A list of ``(title-cased name, image path or None, similarity)``
        tuples ordered by descending cosine similarity, or ``[]`` when the
        character is not present.
    """
    query = input_character.lower().strip()

    # Resolve the row *position* of the first match so it can safely be
    # used to index both the embedding matrix and ``DataFrame.iloc``.
    matches = np.flatnonzero((characters_df['normalized'] == query).values)
    if matches.size == 0:
        return []
    character_index = int(matches[0])

    embeddings = np.array(sbert_embeddings if model_type == "SBERT" else tfidf_embeddings)

    # Only the query row is needed; avoid building the full n x n matrix.
    similarities = cosine_similarity(
        embeddings[character_index:character_index + 1], embeddings
    )[0]

    # Stable sort keeps the original row order among ties. The query row is
    # dropped explicitly instead of assuming it always ranks first, which
    # does not hold when another row has an identical similarity.
    ranked = np.argsort(-similarities, kind="stable")
    top_indices = ranked[ranked != character_index][:5]

    results = []
    for i in top_indices:
        name = characters_df.iloc[int(i)]['character']
        results.append((name.title(), get_image_path(name), similarities[i]))

    return results
62
+
63
+ # Visualization functions
64
@st.cache_data
def compute_tsne_2d(embeddings, perplexity=30, random_state=42):
    """Project ``embeddings`` onto two t-SNE components.

    Args:
        embeddings: Input passed unchanged to ``TSNE.fit_transform``.
        perplexity: Forwarded to ``TSNE``.
        random_state: Forwarded to ``TSNE``.

    Returns:
        The result of ``TSNE(n_components=2, ...).fit_transform(embeddings)``.
    """
    reducer = TSNE(
        n_components=2,
        perplexity=perplexity,
        random_state=random_state,
    )
    projected = reducer.fit_transform(embeddings)
    return projected
69
+
70
@st.cache_data
def compute_tsne_3d(embeddings, perplexity=30, random_state=42):
    """Project ``embeddings`` onto three t-SNE components.

    Args:
        embeddings: Input passed unchanged to ``TSNE.fit_transform``.
        perplexity: Forwarded to ``TSNE``.
        random_state: Forwarded to ``TSNE``.

    Returns:
        The result of ``TSNE(n_components=3, ...).fit_transform(embeddings)``.
    """
    reducer = TSNE(
        n_components=3,
        perplexity=perplexity,
        random_state=random_state,
    )
    projected = reducer.fit_transform(embeddings)
    return projected
75
+
76
@st.cache_data
def compute_pca_2d(embeddings):
    """Project ``embeddings`` onto the first two principal components.

    Returns:
        The result of ``PCA(n_components=2).fit_transform(embeddings)``.
    """
    reducer = PCA(n_components=2)
    projected = reducer.fit_transform(embeddings)
    return projected
81
+
82
@st.cache_data
def compute_pca_3d(embeddings):
    """Project ``embeddings`` onto the first three principal components.

    Returns:
        The result of ``PCA(n_components=3).fit_transform(embeddings)``.
    """
    reducer = PCA(n_components=3)
    projected = reducer.fit_transform(embeddings)
    return projected
87
+
88
def create_2d_plot(coords, characters, title, method):
    """Build a labelled 2D Plotly scatter of reduced embeddings.

    Args:
        coords: Array indexed as ``coords[:, 0]`` and ``coords[:, 1]`` for
            the x and y values.
        characters: Labels drawn next to each point, one per row.
        title: Leading part of the figure title.
        method: Reduction method name, used in the title and axis labels.

    Returns:
        The configured Plotly figure.
    """
    x_values = coords[:, 0]
    y_values = coords[:, 1]
    frame = pd.DataFrame({'x': x_values, 'y': y_values, 'character': characters})

    hover_spec = {'character': True, 'x': ':.3f', 'y': ':.3f'}
    fig = px.scatter(
        frame,
        x='x',
        y='y',
        text='character',
        title=f"{title} - {method}",
        hover_data=hover_spec,
    )

    # Small, semi-transparent markers with the name printed above each one.
    marker_style = dict(size=8, opacity=0.7)
    fig.update_traces(
        textposition="top center",
        textfont_size=8,
        marker=marker_style,
    )

    fig.update_layout(
        height=600,
        showlegend=False,
        xaxis_title=f"{method} Component 1",
        yaxis_title=f"{method} Component 2",
    )

    return fig
119
+
120
def create_3d_plot(coords, characters, title, method):
    """Build a labelled 3D Plotly scatter of reduced embeddings.

    Args:
        coords: Array indexed as ``coords[:, 0]``, ``coords[:, 1]`` and
            ``coords[:, 2]`` for the x, y and z values.
        characters: Labels drawn next to each point, one per row.
        title: Leading part of the figure title.
        method: Reduction method name, used in the title, axis labels and
            hover text.

    Returns:
        The configured Plotly figure.
    """
    first, second, third = coords[:, 0], coords[:, 1], coords[:, 2]

    # Hover box: bold name followed by the three component values.
    hover_parts = ['<b>%{text}</b><br>']
    hover_parts.append(f'{method} 1: %{{x:.3f}}<br>')
    hover_parts.append(f'{method} 2: %{{y:.3f}}<br>')
    hover_parts.append(f'{method} 3: %{{z:.3f}}<br>')
    hover_parts.append('<extra></extra>')
    hover_text = ''.join(hover_parts)

    # Points are coloured by their first-component value.
    marker_style = dict(
        size=6,
        opacity=0.7,
        color=first,
        colorscale='Viridis',
        showscale=True,
    )

    trace = go.Scatter3d(
        x=first,
        y=second,
        z=third,
        mode='markers+text',
        text=characters,
        textposition="top center",
        textfont_size=8,
        marker=marker_style,
        hovertemplate=hover_text,
    )
    fig = go.Figure(data=[trace])

    axis_titles = dict(
        xaxis_title=f"{method} Component 1",
        yaxis_title=f"{method} Component 2",
        zaxis_title=f"{method} Component 3",
    )
    fig.update_layout(
        title=f"{title} - {method}",
        scene=axis_titles,
        height=600,
    )

    return fig
155
+
156
+ # Streamlit App
157
def _render_similarity_tab(characters_df, character_names, sbert_embeddings, tfidf_embeddings):
    """Draw the sidebar controls and the similar-character results."""
    st.markdown("Select a model and character to view top semantic matches!")

    # Sidebar controls
    with st.sidebar:
        st.header("Settings")
        model_type = st.radio(
            "Select Embedding Model:",
            ["SBERT", "TFIDF"],
            help="Choose between SBERT (semantic) or TF-IDF (keyword-based) similarity"
        )

        selected_character = st.selectbox(
            "Choose Character:",
            character_names,
            help="Select a character to find similar ones"
        )

        # Once the button has been clicked the flag stays set across reruns,
        # so changing the model or character refreshes the results directly.
        if st.button("Find Similar Characters", type="primary"):
            st.session_state.search_clicked = True
        elif "search_clicked" not in st.session_state:
            st.session_state.search_clicked = False

    if not (st.session_state.search_clicked and selected_character):
        # Welcome message
        st.info("👈 Select a character from the sidebar and click 'Find Similar Characters' to get started!")

        # Show some stats
        col1, col2, col3 = st.columns(3)
        with col1:
            st.metric("Total Characters", len(character_names))
        with col2:
            st.metric("Embedding Models", "2")
        with col3:
            st.metric("Similarity Algorithm", "Cosine")
        return

    st.subheader(f"Characters similar to **{selected_character}** (using {model_type})")

    results = recommend_characters(
        model_type, selected_character, characters_df, sbert_embeddings, tfidf_embeddings
    )

    if not results:
        st.error("Character not found or no similar characters available.")
        return

    # One column per recommendation (at most five are returned).
    for column, (name, image_path, similarity) in zip(st.columns(5), results):
        with column:
            if image_path and os.path.exists(image_path):
                try:
                    image = Image.open(image_path)
                    st.image(image, use_container_width=True)
                except Exception as e:
                    # UI boundary: report the failure in place of the image
                    # rather than aborting the whole page.
                    st.error(f"Could not load image: {e}")
            else:
                st.info("No image available")

            st.markdown(f"**{name}**")
            st.caption(f"Similarity: {similarity:.3f}")


def _render_reduction_tab(characters_df, sbert_embeddings, tfidf_embeddings):
    """Draw the dimensionality-reduction controls, plot and help text."""
    st.markdown("### Interactive Dimensionality Reduction Visualizations")
    st.markdown("Explore character embeddings in 2D and 3D space using t-SNE and PCA")

    # Controls for visualization
    col1, col2, col3 = st.columns(3)

    with col1:
        viz_model = st.selectbox(
            "Embedding Model:",
            ["SBERT", "TFIDF"],
            key="viz_model"
        )

    with col2:
        viz_method = st.selectbox(
            "Reduction Method:",
            ["t-SNE", "PCA"],
            key="viz_method"
        )

    with col3:
        viz_dims = st.selectbox(
            "Dimensions:",
            ["2D", "3D"],
            key="viz_dims"
        )

    use_tsne = viz_method == "t-SNE"
    is_2d = viz_dims == "2D"

    # The perplexity slider is only shown (and only used) for t-SNE.
    perplexity = None
    if use_tsne:
        perplexity = st.slider(
            "Perplexity (t-SNE parameter):",
            min_value=5,
            max_value=50,
            value=30,
            help="Lower values focus on local structure, higher values on global structure"
        )

    if st.button("Generate Visualization", type="primary"):
        with st.spinner(f"Computing {viz_method} {viz_dims} for {viz_model} embeddings..."):
            # Get the right embeddings
            embeddings = np.array(sbert_embeddings) if viz_model == "SBERT" else np.array(tfidf_embeddings)
            characters = characters_df['character'].tolist()

            try:
                # Pick the reducer by method, then by dimensionality.
                if use_tsne:
                    reduce = compute_tsne_2d if is_2d else compute_tsne_3d
                    coords = reduce(embeddings, perplexity=perplexity)
                else:
                    reduce = compute_pca_2d if is_2d else compute_pca_3d
                    coords = reduce(embeddings)

                make_plot = create_2d_plot if is_2d else create_3d_plot
                fig = make_plot(coords, characters, f"{viz_model} Embeddings", viz_method)

                # Display the plot
                st.plotly_chart(fig, use_container_width=True)

                # Show some information about the visualization
                st.info(f"""
                **Visualization Info:**
                - Model: {viz_model}
                - Method: {viz_method} {viz_dims}
                - Characters: {len(characters)}
                - Original dimensions: {embeddings.shape[1]}
                """ + (f"- Perplexity: {perplexity}" if use_tsne else ""))

            except Exception as e:
                # UI boundary: show the failure (e.g. invalid reducer
                # parameters) instead of crashing the app.
                st.error(f"Error generating visualization: {str(e)}")

    # Information about methods
    with st.expander("ℹ️ About Dimensionality Reduction Methods"):
        st.markdown("""
        **t-SNE (t-Distributed Stochastic Neighbor Embedding):**
        - Great for visualizing clusters and local neighborhoods
        - Non-linear method that preserves local structure
        - Good for finding groups of similar characters
        - Perplexity controls local vs global structure focus

        **PCA (Principal Component Analysis):**
        - Linear method that preserves global variance
        - Shows the main directions of variation in the data
        - Faster computation than t-SNE
        - Components have interpretable meaning

        **2D vs 3D:**
        - 2D is easier to interpret and interact with
        - 3D can reveal additional structure but may be harder to read
        """)


# Streamlit App
def main():
    """Configure the page, load the data and render both application tabs."""
    st.set_page_config(
        page_title="GoT Character Similarity Explorer",
        page_icon="⚔️",
        layout="wide"
    )

    st.title("⚔️ Game of Thrones Character Similarity Explorer")

    # Load data
    characters_df, character_names, sbert_embeddings, tfidf_embeddings = load_data()

    # Create tabs
    tab1, tab2 = st.tabs(["🔍 Character Similarity", "📊 Dimensionality Reduction"])

    with tab1:
        _render_similarity_tab(characters_df, character_names, sbert_embeddings, tfidf_embeddings)

    with tab2:
        _render_reduction_tab(characters_df, sbert_embeddings, tfidf_embeddings)


if __name__ == "__main__":
    main()
characters_list_got.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:edfff5a75d926592b2f646ab7e88eece666b7ff3dcf78a599f010f88422fd0af
3
+ size 1810
embeddings_got.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dcc7af34e18c61e74630ba2446ad1773dfd47c2054b47c56382986c3d947d305
3
+ size 377714
tfidf_embeddings_got.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eb3e10c1b896b2a42d0fb774f6219122ceefa09669ab9b46e7c9c893d9c4c9aa
3
+ size 9782794