import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModel
import torch
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns
import io
from PIL import Image


class TransformerVisualizer:
    def __init__(self, model_name):
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
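        # Note: from_pretrained() downloads the tokenizer and model weights from the
        # Hugging Face Hub on first use and reuses the local cache on later runs.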

    def tokenize(self, sentence):
        # Get tokens without special tokens
        tokens = self.tokenizer.tokenize(sentence)
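        # These are subword tokens: rare words may be split into several pieces
        # (BERT's WordPiece, for example, marks continuation pieces with a leading "##").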
| return tokens, f"Original sentence: '{sentence}'\nTokenized: {tokens}" | |

    def add_special_tokens(self, tokens):
        # Add special tokens manually to show the process; models such as GPT-2
        # define no CLS/SEP tokens, so those are skipped when absent.
        cls = [self.tokenizer.cls_token] if self.tokenizer.cls_token else []
        sep = [self.tokenizer.sep_token] if self.tokenizer.sep_token else []
        tokens_with_special = cls + tokens + sep
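        # For reference: BERT-style tokenizers use [CLS]/[SEP], while RoBERTa-style
        # tokenizers use <s> and </s> as their sentence-boundary special tokens.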
        return tokens_with_special, f"With special tokens: {tokens_with_special}"

    def get_token_ids(self, sentence):
        # Get token IDs with special tokens included
        inputs = self.tokenizer(sentence, return_tensors="pt")
        token_ids = inputs["input_ids"][0].tolist()
        tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
        result = "Token ID Mapping:\n"
        for token, token_id in zip(tokens, token_ids):
            result += f"Token: '{token}', ID: {token_id}\n"
        return token_ids, tokens, result

    def get_embeddings(self, sentence):
        # Run the model to get token embeddings
        inputs = self.tokenizer(sentence, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Get the contextual embeddings from the final hidden layer
        embeddings = outputs.last_hidden_state[0].numpy()
        tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
        result = f"Embedding shape: {embeddings.shape}\n"
        result += f"Each token is represented by a {embeddings.shape[1]}-dimensional vector"
        # Create embedding heatmap
        fig = plt.figure(figsize=(12, len(tokens) * 0.5))
        # Only show first few dimensions to make it readable
        dims = 10
        embedding_subset = embeddings[:, :dims]
        # Create a custom colormap
        cmap = LinearSegmentedColormap.from_list("custom_cmap", ["#2596be", "#ffffff", "#e74c3c"])
        # Plot heatmap
        sns.heatmap(embedding_subset,
                    cmap=cmap,
                    center=0,
                    xticklabels=[f"Dim {i+1}" for i in range(dims)],
                    yticklabels=tokens,
                    annot=False)
        plt.title(f"Word Embeddings (first {dims} dimensions)")
        plt.tight_layout()
        # Convert plot to image
        buf = io.BytesIO()
        plt.savefig(buf, format='png')
        plt.close(fig)
        buf.seek(0)
        embedding_img = Image.open(buf)
        return embeddings, tokens, result, embedding_img
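
    def get_static_embeddings(self, sentence):
        # Hypothetical helper, not part of the original app: last_hidden_state above is a
        # *contextual* representation (after all transformer layers). This sketch shows how
        # the static input-embedding table could be inspected instead, using the standard
        # transformers get_input_embeddings() accessor.
        inputs = self.tokenizer(sentence, return_tensors="pt")
        with torch.no_grad():
            static_embeddings = self.model.get_input_embeddings()(inputs["input_ids"])[0].numpy()
        tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
        return static_embeddings, tokens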

    def get_positional_encoding(self, seq_length, d_model=768):
        # Create positional encodings
        position = np.arange(seq_length)[:, np.newaxis]
        div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
        pos_encoding = np.zeros((seq_length, d_model))
        pos_encoding[:, 0::2] = np.sin(position * div_term)
        pos_encoding[:, 1::2] = np.cos(position * div_term)
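        # This implements the sinusoidal scheme from "Attention Is All You Need":
        #   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
        #   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
        # div_term above is the 1 / 10000^(2i / d_model) factor, computed in log space.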
| result = f"Positional encoding shape: {pos_encoding.shape}\n" | |
| result += f"Generated for sequence length: {seq_length}" | |
| # Visualize positional encodings | |
| fig1 = plt.figure(figsize=(12, 6)) | |
| # Only show first 20 dimensions to make it readable | |
| dims_to_show = min(20, d_model) | |
| # Create a custom colormap | |
| cmap = LinearSegmentedColormap.from_list("custom_cmap", ["#2596be", "#ffffff", "#e74c3c"]) | |
| sns.heatmap(pos_encoding[:, :dims_to_show], | |
| cmap=cmap, | |
| center=0, | |
| xticklabels=[f"Dim {i+1}" for i in range(dims_to_show)], | |
| yticklabels=[f"Pos {i+1}" for i in range(seq_length)]) | |
| plt.title(f"Positional Encodings (first {dims_to_show} dimensions)") | |
| plt.xlabel("Embedding Dimension") | |
| plt.ylabel("Position in Sequence") | |
| plt.tight_layout() | |
| # Convert plot to image | |
| buf1 = io.BytesIO() | |
| plt.savefig(buf1, format='png') | |
| plt.close(fig1) | |
| buf1.seek(0) | |
| pos_encoding_img = Image.open(buf1) | |
| # Plot sine waves for a few dimensions | |
| fig2 = plt.figure(figsize=(12, 6)) | |
| dims_to_plot = [0, 2, 4, 20, 100] | |
| for i, dim in enumerate(dims_to_plot): | |
| if dim < pos_encoding.shape[1]: | |
| plt.plot(pos_encoding[:, dim], label=f"Dim {dim} (sin)") | |
| plt.title("Positional Encoding Sine Waves") | |
| plt.xlabel("Position") | |
| plt.ylabel("Value") | |
| plt.legend() | |
| plt.grid(True) | |
| plt.tight_layout() | |
| # Convert plot to image | |
| buf2 = io.BytesIO() | |
| plt.savefig(buf2, format='png') | |
| plt.close(fig2) | |
| buf2.seek(0) | |
| pos_waves_img = Image.open(buf2) | |
| return result, pos_encoding_img, pos_waves_img | |
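
# The class can also be exercised directly, without the Gradio UI. A minimal sketch
# (model weights are fetched from the Hub on first use):
#
#     viz = TransformerVisualizer("bert-base-uncased")
#     tokens, summary = viz.tokenize("The cat sat on the mat.")
#     token_ids, tokens, mapping = viz.get_token_ids("The cat sat on the mat.")
#     info, heatmap, waves = viz.get_positional_encoding(len(token_ids))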


def process_text(sentence, model_name):
    visualizer = TransformerVisualizer(model_name)
    # 1. Tokenization
    tokens, tokenization_text = visualizer.tokenize(sentence)
    # 2. Special Tokens
    tokens_with_special, special_tokens_text = visualizer.add_special_tokens(tokens)
    # 3. Token IDs
    token_ids, tokens, token_ids_text = visualizer.get_token_ids(sentence)
    # 4. Word Embeddings
    embeddings, tokens, embeddings_text, embedding_img = visualizer.get_embeddings(sentence)
    # 5. Positional Encoding
    pos_encoding_text, pos_encoding_img, pos_waves_img = visualizer.get_positional_encoding(len(token_ids))
    return (tokenization_text, special_tokens_text, token_ids_text,
            embeddings_text, embedding_img, pos_encoding_text,
            pos_encoding_img, pos_waves_img)
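

# Note: process_text builds a fresh TransformerVisualizer (reloading tokenizer and model)
# on every click. A minimal caching sketch, not part of the original app, that process_text
# could call instead of constructing the visualizer directly:
_VISUALIZER_CACHE = {}


def get_cached_visualizer(model_name):
    # Reuse an already-loaded model when the same checkpoint is requested again
    if model_name not in _VISUALIZER_CACHE:
        _VISUALIZER_CACHE[model_name] = TransformerVisualizer(model_name)
    return _VISUALIZER_CACHE[model_name]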


# Create Gradio interface
models = [
    "bert-base-uncased",
    "roberta-base",
    "distilbert-base-uncased",
    "gpt2",
    "albert-base-v2",
    "xlm-roberta-base"
]
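# All of the checkpoints above are base-size models with a hidden size of 768, which is
# why the d_model=768 default in get_positional_encoding matches their embedding width.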

with gr.Blocks(title="Transformer Process Visualizer") as demo:
    gr.Markdown("# Transformer Process Visualizer")
    gr.Markdown("This app visualizes the key processes in transformer models: tokenization, special tokens, token IDs, word embeddings, and positional encoding.")
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Input Sentence",
                placeholder="Enter a sentence to visualize transformer processes",
                value="The transformer architecture revolutionized natural language processing."
            )
            model_dropdown = gr.Dropdown(
                label="Select Model",
                choices=models,
                value="bert-base-uncased"
            )
            submit_btn = gr.Button("Visualize")
    with gr.Tabs():
        with gr.TabItem("Tokenization"):
            tokenization_output = gr.Textbox(label="Tokenization")
        with gr.TabItem("Special Tokens"):
            special_tokens_output = gr.Textbox(label="Special Tokens")
        with gr.TabItem("Token IDs"):
            token_ids_output = gr.Textbox(label="Token IDs")
        with gr.TabItem("Word Embeddings"):
            embeddings_output = gr.Textbox(label="Embeddings Info")
            embedding_plot = gr.Image(label="Embedding Visualization")
        with gr.TabItem("Positional Encoding"):
            pos_encoding_output = gr.Textbox(label="Positional Encoding Info")
            pos_encoding_plot = gr.Image(label="Positional Encoding Heatmap")
            pos_waves_plot = gr.Image(label="Positional Encoding Waves")
    submit_btn.click(
        process_text,
        inputs=[input_text, model_dropdown],
        outputs=[
            tokenization_output,
            special_tokens_output,
            token_ids_output,
            embeddings_output,
            embedding_plot,
            pos_encoding_output,
            pos_encoding_plot,
            pos_waves_plot
        ]
    )

demo.launch()
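
# Alternatively, launching with share=True (a standard Gradio launch flag) creates a
# temporary public link when running locally:
#     demo.launch(share=True)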