mobrown committed on
Commit
8c497f4
1 Parent(s): 669acda

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -0
app.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from sklearn.decomposition import PCA
3
+ import gensim.downloader as api
4
+ import gradio as gr
5
+ import plotly.graph_objects as go
6
+
7
+ # Load the pretrained "word2vec-google-news-300" embeddings through gensim's downloader once, at import time, so the module-level `model` is shared by every call.
8
+ model = api.load("word2vec-google-news-300")
9
+
10
+
11
def gensim_analogy(model, word1, word2, word3):
    """Solve the analogy "word1 is to word2 as word3 is to ?".

    Asks ``model.most_similar`` for the single best match when ``word2`` and
    ``word3`` contribute positively and ``word1`` negatively.

    Args:
        model: Object exposing a gensim-style
            ``most_similar(positive=..., negative=..., topn=...)`` method.
        word1: The word subtracted from the query.
        word2: First word added to the query.
        word3: Second word added to the query.

    Returns:
        The best-matching word. If the lookup raises ``KeyError`` (for
        example because a word is unknown to the model), the error message
        is returned as a plain string instead of raising.
    """
    try:
        matches = model.most_similar(
            positive=[word2, word3], negative=[word1], topn=1
        )
    except KeyError as err:
        # str(KeyError(...)) is the repr of its argument, which would wrap
        # the message in an extra pair of quotes; return the raw message.
        return str(err.args[0]) if err.args else str(err)

    # The first element of the top match is the word itself.
    return matches[0][0]
17
+
18
+
19
def plot_words_plotly(model, words):
    """Project word vectors to 2-D with PCA and draw a labelled scatter plot.

    Args:
        model: Embedding lookup supporting ``model[word]`` and exposing a
            ``key_to_index`` mapping of the words it knows.
        words: Words to plot. Words missing from ``model.key_to_index`` are
            skipped.

    Returns:
        A ``plotly.graph_objects.Figure`` with one trace per plotted word.
        When fewer than two words are known to the model, no projection is
        computed and any remaining word is drawn at the origin.
    """
    # Filter the labels first and derive the vectors from the filtered list,
    # so every point stays paired with its own word even when some of the
    # requested words are missing from the model.
    known_words = [word for word in words if word in model.key_to_index]

    if len(known_words) >= 2:
        vectors = np.array([model[word] for word in known_words])
        # Reduce dimensions to 2D for plotting.
        coords = PCA(n_components=2).fit_transform(vectors)
    else:
        # PCA with two components needs at least two samples; with zero or
        # one vector there is nothing to project, so use the origin.
        coords = np.zeros((len(known_words), 2))

    fig = go.Figure()

    # One trace per word so each word gets its own legend entry.
    for word, (x_coord, y_coord) in zip(known_words, coords):
        fig.add_trace(
            go.Scatter(
                x=[x_coord],
                y=[y_coord],
                text=[word],
                mode="markers+text",
                textposition="bottom center",
                name=word,
            )
        )

    fig.update_layout(
        title="Word Vectors Visualization",
        xaxis_title="PCA 1",
        yaxis_title="PCA 2",
        showlegend=True,
    )

    return fig
42
+
43
+
44
def _parse_triplet(text):
    """Return the three comma-separated words in *text*, or None if invalid."""
    if not text:
        return None
    # Split on bare commas and strip each piece so "a,b,c", "a, b, c" and
    # " a ,b , c" are all accepted and no word carries stray whitespace.
    words = [part.strip() for part in text.split(",")]
    if len(words) != 3 or not all(words):
        return None
    return words


def gradio_interface(choice, custom_input=None):
    """Handle one request: solve the analogy and build the three outputs.

    Args:
        choice: The dropdown selection: either a predefined comma-separated
            triplet or the literal "Custom". May be None when nothing was
            selected.
        custom_input: Comma-separated triplet typed by the user; only read
            when ``choice`` is "Custom".

    Returns:
        A ``(text, plot, json)`` tuple: the resulting fourth word, the
        figure from ``plot_words_plotly``, and a dict mapping the word to
        its vector rounded to two decimals (or an error dict when the word
        has no vector). For input that does not yield exactly three
        non-empty words, an error message, ``None`` and an error dict are
        returned instead.
    """
    raw_text = custom_input if choice == "Custom" else choice
    words = _parse_triplet(raw_text)
    if words is None:
        # Covers a missing dropdown selection, empty custom text and any
        # input with the wrong number of words.
        return (
            "Invalid input. Please enter exactly three words, separated by commas.",
            None,
            {"error": "Invalid input"},
        )

    word1, word2, word3 = words
    word4 = gensim_analogy(model, word1, word2, word3)
    plot_fig = plot_words_plotly(model, [word1, word2, word3, word4])

    # word4 may be an error message rather than a vocabulary word, so check
    # membership before looking up its vector.
    if word4 in model.key_to_index:
        vector = model[word4]
        vector_display = {word4: [round(num, 2) for num in vector.tolist()]}
    else:
        vector_display = {"error": "Vector not available for the resulting word"}

    return word4, plot_fig, vector_display
64
+
65
+
66
# Entries offered in the dropdown: predefined analogy triplets plus the
# "Custom" entry that makes the handler read the free-text box instead.
choices = [
    "man, king, woman",
    "Paris, France, London",
    "strong, stronger, weak",
    "pork, pig, beef",
    "Custom",
]

# Input components, in the order of gradio_interface's parameters.
word_dropdown = gr.Dropdown(
    choices=choices,
    label="Choose predefined words or enter custom words",
)
custom_textbox = gr.Textbox(
    label="Custom words (comma-separated, required for custom choice; use only if 'Custom' is selected)",
    placeholder="Enter 3 words separated by commas",
)

# Wire the handler to its inputs and to the text / plot / json outputs.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[word_dropdown, custom_textbox],
    outputs=["text", "plot", "json"],
    title="Word Analogy and Vector Visualization with Plotly",
    description="Select a predefined triplet of words or choose 'Custom' and enter your own (comma-separated) to find a fourth word by analogy, and see their vectors plotted with Plotly.",
)

# Start the app; share=True additionally requests a public share link.
iface.launch(share=True)