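"""Gradio app that visualizes the key preprocessing steps of transformer models:
tokenization, special tokens, token IDs, word embeddings, and positional encoding."""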
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModel
import torch
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns
import io
from PIL import Image

class TransformerVisualizer:
    def __init__(self, model_name):
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def tokenize(self, sentence):
        # Get tokens without special tokens
        tokens = self.tokenizer.tokenize(sentence)
        return tokens, f"Original sentence: '{sentence}'\nTokenized: {tokens}"

    def add_special_tokens(self, tokens):
        # Add special tokens manually to show the process. Some tokenizers in the
        # model list (e.g. gpt2) define no CLS/SEP token, so skip them instead of inserting None.
        cls = [self.tokenizer.cls_token] if self.tokenizer.cls_token else []
        sep = [self.tokenizer.sep_token] if self.tokenizer.sep_token else []
        tokens_with_special = cls + tokens + sep
        return tokens_with_special, f"With special tokens: {tokens_with_special}"

    def get_token_ids(self, sentence):
        # Get token IDs with special tokens included
        inputs = self.tokenizer(sentence, return_tensors="pt")
        token_ids = inputs["input_ids"][0].tolist()
        tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
        result = "Token ID Mapping:\n"
        for token, token_id in zip(tokens, token_ids):
            result += f"Token: '{token}', ID: {token_id}\n"
        return token_ids, tokens, result

    def get_embeddings(self, sentence):
        # Get contextual embeddings from the model
        inputs = self.tokenizer(sentence, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Take the hidden states of the final layer, one vector per token
        embeddings = outputs.last_hidden_state[0].numpy()
        tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
        result = f"Embedding shape: {embeddings.shape}\n"
        result += f"Each token is represented by a {embeddings.shape[1]}-dimensional vector"
        # Create embedding heatmap
        fig = plt.figure(figsize=(12, len(tokens) * 0.5))
        # Only show the first few dimensions to keep the plot readable
        dims = 10
        embedding_subset = embeddings[:, :dims]
        # Create a custom colormap
        cmap = LinearSegmentedColormap.from_list("custom_cmap", ["#2596be", "#ffffff", "#e74c3c"])
        # Plot heatmap
        sns.heatmap(embedding_subset,
                    cmap=cmap,
                    center=0,
                    xticklabels=[f"Dim {i+1}" for i in range(dims)],
                    yticklabels=tokens,
                    annot=False)
        plt.title(f"Word Embeddings (first {dims} dimensions)")
        plt.tight_layout()
        # Convert plot to image
        buf = io.BytesIO()
        plt.savefig(buf, format='png')
        plt.close(fig)
        buf.seek(0)
        embedding_img = Image.open(buf)
        return embeddings, tokens, result, embedding_img

    def get_positional_encoding(self, seq_length, d_model=768):
        # Create sinusoidal positional encodings
        position = np.arange(seq_length)[:, np.newaxis]
        div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
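        # This is the standard sinusoidal encoding:
        #   PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        #   PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
        # where div_term[i] = 10000^(-2i / d_model).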
        pos_encoding = np.zeros((seq_length, d_model))
        pos_encoding[:, 0::2] = np.sin(position * div_term)
        pos_encoding[:, 1::2] = np.cos(position * div_term)
        result = f"Positional encoding shape: {pos_encoding.shape}\n"
        result += f"Generated for sequence length: {seq_length}"
        # Visualize positional encodings
        fig1 = plt.figure(figsize=(12, 6))
        # Only show the first 20 dimensions to keep the plot readable
        dims_to_show = min(20, d_model)
        # Create a custom colormap
        cmap = LinearSegmentedColormap.from_list("custom_cmap", ["#2596be", "#ffffff", "#e74c3c"])
        sns.heatmap(pos_encoding[:, :dims_to_show],
                    cmap=cmap,
                    center=0,
                    xticklabels=[f"Dim {i+1}" for i in range(dims_to_show)],
                    yticklabels=[f"Pos {i+1}" for i in range(seq_length)])
        plt.title(f"Positional Encodings (first {dims_to_show} dimensions)")
        plt.xlabel("Embedding Dimension")
        plt.ylabel("Position in Sequence")
        plt.tight_layout()
        # Convert plot to image
        buf1 = io.BytesIO()
        plt.savefig(buf1, format='png')
        plt.close(fig1)
        buf1.seek(0)
        pos_encoding_img = Image.open(buf1)
        # Plot sine waves for a few dimensions
        fig2 = plt.figure(figsize=(12, 6))
        dims_to_plot = [0, 2, 4, 20, 100]
        for dim in dims_to_plot:
            if dim < pos_encoding.shape[1]:
                plt.plot(pos_encoding[:, dim], label=f"Dim {dim} (sin)")
        plt.title("Positional Encoding Sine Waves")
        plt.xlabel("Position")
        plt.ylabel("Value")
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        # Convert plot to image
        buf2 = io.BytesIO()
        plt.savefig(buf2, format='png')
        plt.close(fig2)
        buf2.seek(0)
        pos_waves_img = Image.open(buf2)
        return result, pos_encoding_img, pos_waves_img


def process_text(sentence, model_name):
    visualizer = TransformerVisualizer(model_name)
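    # Note: a new visualizer (tokenizer + model) is loaded on every request;
    # caching instances by model_name would avoid repeated loads.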
    # 1. Tokenization
    tokens, tokenization_text = visualizer.tokenize(sentence)
    # 2. Special Tokens
    tokens_with_special, special_tokens_text = visualizer.add_special_tokens(tokens)
    # 3. Token IDs
    token_ids, tokens, token_ids_text = visualizer.get_token_ids(sentence)
    # 4. Word Embeddings
    embeddings, tokens, embeddings_text, embedding_img = visualizer.get_embeddings(sentence)
    # 5. Positional Encoding
    pos_encoding_text, pos_encoding_img, pos_waves_img = visualizer.get_positional_encoding(len(token_ids))
    return (tokenization_text, special_tokens_text, token_ids_text,
            embeddings_text, embedding_img, pos_encoding_text,
            pos_encoding_img, pos_waves_img)


# Create Gradio interface
models = [
    "bert-base-uncased",
    "roberta-base",
    "distilbert-base-uncased",
    "gpt2",
    "albert-base-v2",
    "xlm-roberta-base"
]

with gr.Blocks(title="Transformer Process Visualizer") as demo:
    gr.Markdown("# Transformer Process Visualizer")
    gr.Markdown("This app visualizes the key processes in transformer models: tokenization, special tokens, token IDs, word embeddings, and positional encoding.")
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Input Sentence",
                placeholder="Enter a sentence to visualize transformer processes",
                value="The transformer architecture revolutionized natural language processing."
            )
            model_dropdown = gr.Dropdown(
                label="Select Model",
                choices=models,
                value="bert-base-uncased"
            )
            submit_btn = gr.Button("Visualize")
    with gr.Tabs():
        with gr.TabItem("Tokenization"):
            tokenization_output = gr.Textbox(label="Tokenization")
        with gr.TabItem("Special Tokens"):
            special_tokens_output = gr.Textbox(label="Special Tokens")
        with gr.TabItem("Token IDs"):
            token_ids_output = gr.Textbox(label="Token IDs")
        with gr.TabItem("Word Embeddings"):
            embeddings_output = gr.Textbox(label="Embeddings Info")
            embedding_plot = gr.Image(label="Embedding Visualization")
        with gr.TabItem("Positional Encoding"):
            pos_encoding_output = gr.Textbox(label="Positional Encoding Info")
            pos_encoding_plot = gr.Image(label="Positional Encoding Heatmap")
            pos_waves_plot = gr.Image(label="Positional Encoding Waves")
    submit_btn.click(
        process_text,
        inputs=[input_text, model_dropdown],
        outputs=[
            tokenization_output,
            special_tokens_output,
            token_ids_output,
            embeddings_output,
            embedding_plot,
            pos_encoding_output,
            pos_encoding_plot,
            pos_waves_plot
        ]
    )

demo.launch()