iqramukhtiar committed on
Commit 86493fa · verified · 1 Parent(s): 5452837

Upload 2 files

Files changed (2)
  1. app.py +229 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,229 @@
+ import gradio as gr
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from transformers import AutoTokenizer, AutoModel
+ import torch
+ from matplotlib.colors import LinearSegmentedColormap
+ import seaborn as sns
+ import io
+ from PIL import Image
+
+ class TransformerVisualizer:
+     def __init__(self, model_name):
+         self.model_name = model_name
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+         self.model = AutoModel.from_pretrained(model_name)
+
+     def tokenize(self, sentence):
+         # Get tokens without special tokens
+         tokens = self.tokenizer.tokenize(sentence)
+         return tokens, f"Original sentence: '{sentence}'\nTokenized: {tokens}"
+
+     def add_special_tokens(self, tokens):
+         # Add special tokens manually to show the process; fall back to an
+         # empty list when a tokenizer (e.g. gpt2's) defines no CLS/SEP token
+         cls = [self.tokenizer.cls_token] if self.tokenizer.cls_token else []
+         sep = [self.tokenizer.sep_token] if self.tokenizer.sep_token else []
+         tokens_with_special = cls + tokens + sep
+         return tokens_with_special, f"With special tokens: {tokens_with_special}"
+
+     def get_token_ids(self, sentence):
+         # Get token IDs with special tokens included
+         inputs = self.tokenizer(sentence, return_tensors="pt")
+         token_ids = inputs["input_ids"][0].tolist()
+         tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
+
+         result = "Token ID Mapping:\n"
+         for token, token_id in zip(tokens, token_ids):
+             result += f"Token: '{token}', ID: {token_id}\n"
+
+         return token_ids, tokens, result
+
+     def get_embeddings(self, sentence):
+         # Run the model to get contextual embeddings
+         inputs = self.tokenizer(sentence, return_tensors="pt")
+         with torch.no_grad():
+             outputs = self.model(**inputs)
+
+         # Hidden states from the final layer (one vector per token)
+         embeddings = outputs.last_hidden_state[0].numpy()
+         tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
+
+         result = f"Embedding shape: {embeddings.shape}\n"
+         result += f"Each token is represented by a {embeddings.shape[1]}-dimensional vector"
+
+         # Create embedding heatmap
+         fig = plt.figure(figsize=(12, len(tokens) * 0.5))
+
+         # Only show first few dimensions to make it readable
+         dims = 10
+         embedding_subset = embeddings[:, :dims]
+
+         # Create a custom colormap
+         cmap = LinearSegmentedColormap.from_list("custom_cmap", ["#2596be", "#ffffff", "#e74c3c"])
+
+         # Plot heatmap
+         sns.heatmap(embedding_subset,
+                     cmap=cmap,
+                     center=0,
+                     xticklabels=[f"Dim {i+1}" for i in range(dims)],
+                     yticklabels=tokens,
+                     annot=False)
+
+         plt.title(f"Word Embeddings (first {dims} dimensions)")
+         plt.tight_layout()
+
+         # Convert plot to image
+         buf = io.BytesIO()
+         plt.savefig(buf, format='png')
+         plt.close(fig)
+         buf.seek(0)
+         embedding_img = Image.open(buf)
+
+         return embeddings, tokens, result, embedding_img
+
+     def get_positional_encoding(self, seq_length, d_model=768):
+         # Sinusoidal positional encodings (Vaswani et al., 2017):
+         # PE[pos, 2i] = sin(pos / 10000^(2i/d_model)), PE[pos, 2i+1] = cos(pos / 10000^(2i/d_model))
+         position = np.arange(seq_length)[:, np.newaxis]
+         div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
+
+         pos_encoding = np.zeros((seq_length, d_model))
+         pos_encoding[:, 0::2] = np.sin(position * div_term)
+         pos_encoding[:, 1::2] = np.cos(position * div_term)
+
+         result = f"Positional encoding shape: {pos_encoding.shape}\n"
+         result += f"Generated for sequence length: {seq_length}"
+
+         # Visualize positional encodings
+         fig1 = plt.figure(figsize=(12, 6))
+
+         # Only show first 20 dimensions to make it readable
+         dims_to_show = min(20, d_model)
+
+         # Create a custom colormap
+         cmap = LinearSegmentedColormap.from_list("custom_cmap", ["#2596be", "#ffffff", "#e74c3c"])
+
+         sns.heatmap(pos_encoding[:, :dims_to_show],
+                     cmap=cmap,
+                     center=0,
+                     xticklabels=[f"Dim {i+1}" for i in range(dims_to_show)],
+                     yticklabels=[f"Pos {i+1}" for i in range(seq_length)])
+
+         plt.title(f"Positional Encodings (first {dims_to_show} dimensions)")
+         plt.xlabel("Embedding Dimension")
+         plt.ylabel("Position in Sequence")
+         plt.tight_layout()
+
+         # Convert plot to image
+         buf1 = io.BytesIO()
+         plt.savefig(buf1, format='png')
+         plt.close(fig1)
+         buf1.seek(0)
+         pos_encoding_img = Image.open(buf1)
+
+         # Plot sine waves for a few (even, hence sine) dimensions
+         fig2 = plt.figure(figsize=(12, 6))
+
+         dims_to_plot = [0, 2, 4, 20, 100]
+         for dim in dims_to_plot:
+             if dim < pos_encoding.shape[1]:
+                 plt.plot(pos_encoding[:, dim], label=f"Dim {dim} (sin)")
+
+         plt.title("Positional Encoding Sine Waves")
+         plt.xlabel("Position")
+         plt.ylabel("Value")
+         plt.legend()
+         plt.grid(True)
+         plt.tight_layout()
+
+         # Convert plot to image
+         buf2 = io.BytesIO()
+         plt.savefig(buf2, format='png')
+         plt.close(fig2)
+         buf2.seek(0)
+         pos_waves_img = Image.open(buf2)
+
+         return result, pos_encoding_img, pos_waves_img
+
+ def process_text(sentence, model_name):
+     # Note: instantiating the visualizer reloads the model on every click
+     visualizer = TransformerVisualizer(model_name)
+
+     # 1. Tokenization
+     tokens, tokenization_text = visualizer.tokenize(sentence)
+
+     # 2. Special Tokens
+     tokens_with_special, special_tokens_text = visualizer.add_special_tokens(tokens)
+
+     # 3. Token IDs
+     token_ids, tokens, token_ids_text = visualizer.get_token_ids(sentence)
+
+     # 4. Word Embeddings
+     embeddings, tokens, embeddings_text, embedding_img = visualizer.get_embeddings(sentence)
+
+     # 5. Positional Encoding
+     pos_encoding_text, pos_encoding_img, pos_waves_img = visualizer.get_positional_encoding(len(token_ids))
+
+     return (tokenization_text, special_tokens_text, token_ids_text,
+             embeddings_text, embedding_img, pos_encoding_text,
+             pos_encoding_img, pos_waves_img)
+
+ # Create Gradio interface
+ models = [
+     "bert-base-uncased",
+     "roberta-base",
+     "distilbert-base-uncased",
+     "gpt2",
+     "albert-base-v2",
+     "xlm-roberta-base"
+ ]
+
+ with gr.Blocks(title="Transformer Process Visualizer") as demo:
+     gr.Markdown("# Transformer Process Visualizer")
+     gr.Markdown("This app visualizes the key processes in transformer models: tokenization, special tokens, token IDs, word embeddings, and positional encoding.")
+
+     with gr.Row():
+         with gr.Column():
+             input_text = gr.Textbox(
+                 label="Input Sentence",
+                 placeholder="Enter a sentence to visualize transformer processes",
+                 value="The transformer architecture revolutionized natural language processing."
+             )
+             model_dropdown = gr.Dropdown(
+                 label="Select Model",
+                 choices=models,
+                 value="bert-base-uncased"
+             )
+             submit_btn = gr.Button("Visualize")
+
+     with gr.Tabs():
+         with gr.TabItem("Tokenization"):
+             tokenization_output = gr.Textbox(label="Tokenization")
+
+         with gr.TabItem("Special Tokens"):
+             special_tokens_output = gr.Textbox(label="Special Tokens")
+
+         with gr.TabItem("Token IDs"):
+             token_ids_output = gr.Textbox(label="Token IDs")
+
+         with gr.TabItem("Word Embeddings"):
+             embeddings_output = gr.Textbox(label="Embeddings Info")
+             embedding_plot = gr.Image(label="Embedding Visualization")
+
+         with gr.TabItem("Positional Encoding"):
+             pos_encoding_output = gr.Textbox(label="Positional Encoding Info")
+             pos_encoding_plot = gr.Image(label="Positional Encoding Heatmap")
+             pos_waves_plot = gr.Image(label="Positional Encoding Waves")
+
+     submit_btn.click(
+         process_text,
+         inputs=[input_text, model_dropdown],
+         outputs=[
+             tokenization_output,
+             special_tokens_output,
+             token_ids_output,
+             embeddings_output,
+             embedding_plot,
+             pos_encoding_output,
+             pos_encoding_plot,
+             pos_waves_plot
+         ]
+     )
+
+ demo.launch()
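
A quick sanity check on the sinusoidal encoding implemented by get_positional_encoding above: this standalone sketch (the sinusoidal_encoding helper is named here for illustration and is not part of app.py) reproduces the same formula and confirms that dimension 0 carries sin(pos) and dimension 1 carries cos(pos).

import numpy as np

def sinusoidal_encoding(seq_length, d_model=768):
    # PE[pos, 2i] = sin(pos / 10000^(2i/d_model)); PE[pos, 2i+1] = cos(pos / 10000^(2i/d_model))
    position = np.arange(seq_length)[:, np.newaxis]
    div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
    pe = np.zeros((seq_length, d_model))
    pe[:, 0::2] = np.sin(position * div_term)
    pe[:, 1::2] = np.cos(position * div_term)
    return pe

pe = sinusoidal_encoding(10)
assert pe.shape == (10, 768)
assert np.allclose(pe[:, 0], np.sin(np.arange(10)))  # dim 0 is sin(pos)
assert np.allclose(pe[:, 1], np.cos(np.arange(10)))  # dim 1 is cos(pos)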
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ transformers
+ torch
+ numpy
+ matplotlib
+ seaborn
+ gradio
+ pillow
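
To try the Space locally (assuming both files are saved as uploaded): install the dependencies with pip install -r requirements.txt, then run python app.py; the demo.launch() call at the bottom of app.py starts the Gradio server. The first visualization with each model downloads its weights from the Hugging Face Hub, after which they are cached locally.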