import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModel
import torch
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns
import io
from PIL import Image


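# Gradio demo that walks through what happens to text on its way into a
# transformer model: tokenization, special tokens, token IDs, contextual
# embeddings, and sinusoidal positional encodings, each shown as text and/or a plot.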
class TransformerVisualizer:
    def __init__(self, model_name):
        # Load the tokenizer and pretrained weights from the Hugging Face Hub.
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def tokenize(self, sentence):
        # Break the sentence into subword tokens with the model's tokenizer.
        tokens = self.tokenizer.tokenize(sentence)
        return tokens, f"Original sentence: '{sentence}'\nTokenized: {tokens}"

    def add_special_tokens(self, tokens):
        # Wrap the tokens with the model's special tokens (e.g. [CLS]/[SEP] for BERT).
        # Some checkpoints in the dropdown (e.g. GPT-2) define no CLS/SEP token,
        # so skip missing ones instead of inserting None.
        cls, sep = self.tokenizer.cls_token, self.tokenizer.sep_token
        tokens_with_special = ([cls] if cls else []) + tokens + ([sep] if sep else [])
        return tokens_with_special, f"With special tokens: {tokens_with_special}"

    def get_token_ids(self, sentence):
        # Encode the sentence (special tokens included) and map each token to its vocabulary ID.
        inputs = self.tokenizer(sentence, return_tensors="pt")
        token_ids = inputs["input_ids"][0].tolist()
        tokens = self.tokenizer.convert_ids_to_tokens(token_ids)

        result = "Token ID Mapping:\n"
        for token, token_id in zip(tokens, token_ids):
            result += f"Token: '{token}', ID: {token_id}\n"

        return token_ids, tokens, result

    def get_embeddings(self, sentence):
        # Run a forward pass and use the final hidden states as per-token embeddings.
        inputs = self.tokenizer(sentence, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)

        embeddings = outputs.last_hidden_state[0].numpy()
        tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

        result = f"Embedding shape: {embeddings.shape}\n"
        result += f"Each token is represented by a {embeddings.shape[1]}-dimensional vector"

        # Heatmap of the first few embedding dimensions for each token.
        fig = plt.figure(figsize=(12, len(tokens) * 0.5))
        dims = 10
        embedding_subset = embeddings[:, :dims]

        cmap = LinearSegmentedColormap.from_list("custom_cmap", ["#2596be", "#ffffff", "#e74c3c"])
        sns.heatmap(embedding_subset,
                    cmap=cmap,
                    center=0,
                    xticklabels=[f"Dim {i+1}" for i in range(dims)],
                    yticklabels=tokens,
                    annot=False)

        plt.title(f"Word Embeddings (first {dims} dimensions)")
        plt.tight_layout()

        # Render the figure to a PIL image so Gradio can display it.
        buf = io.BytesIO()
        plt.savefig(buf, format='png')
        plt.close(fig)
        buf.seek(0)
        embedding_img = Image.open(buf)

        return embeddings, tokens, result, embedding_img

    def get_positional_encoding(self, seq_length, d_model=768):
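        """Build sinusoidal positional encodings (Vaswani et al., 2017):

            PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
            PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

        Even dimensions are sine waves and odd dimensions are cosine waves,
        with wavelengths growing geometrically as the dimension index increases.
        """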
        position = np.arange(seq_length)[:, np.newaxis]
        div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))

        pos_encoding = np.zeros((seq_length, d_model))
        pos_encoding[:, 0::2] = np.sin(position * div_term)
        pos_encoding[:, 1::2] = np.cos(position * div_term)

        result = f"Positional encoding shape: {pos_encoding.shape}\n"
        result += f"Generated for sequence length: {seq_length}"

        # Heatmap of the first few encoding dimensions across all positions.
        fig1 = plt.figure(figsize=(12, 6))
        dims_to_show = min(20, d_model)
        cmap = LinearSegmentedColormap.from_list("custom_cmap", ["#2596be", "#ffffff", "#e74c3c"])

        sns.heatmap(pos_encoding[:, :dims_to_show],
                    cmap=cmap,
                    center=0,
                    xticklabels=[f"Dim {i+1}" for i in range(dims_to_show)],
                    yticklabels=[f"Pos {i+1}" for i in range(seq_length)])

        plt.title(f"Positional Encodings (first {dims_to_show} dimensions)")
        plt.xlabel("Embedding Dimension")
        plt.ylabel("Position in Sequence")
        plt.tight_layout()

        buf1 = io.BytesIO()
        plt.savefig(buf1, format='png')
        plt.close(fig1)
        buf1.seek(0)
        pos_encoding_img = Image.open(buf1)

        # Line plot of a few even (sine) dimensions; the oscillation frequency
        # decreases as the dimension index grows.
        fig2 = plt.figure(figsize=(12, 6))
        dims_to_plot = [0, 2, 4, 20, 100]
        for dim in dims_to_plot:
            if dim < pos_encoding.shape[1]:
                plt.plot(pos_encoding[:, dim], label=f"Dim {dim} (sin)")

        plt.title("Positional Encoding Sine Waves")
        plt.xlabel("Position")
        plt.ylabel("Value")
        plt.legend()
        plt.grid(True)
        plt.tight_layout()

        buf2 = io.BytesIO()
        plt.savefig(buf2, format='png')
        plt.close(fig2)
        buf2.seek(0)
        pos_waves_img = Image.open(buf2)

        return result, pos_encoding_img, pos_waves_img


def process_text(sentence, model_name):
    # Run every visualization stage for the chosen model and collect the
    # outputs in the order the Gradio components below expect them.
    visualizer = TransformerVisualizer(model_name)

    tokens, tokenization_text = visualizer.tokenize(sentence)
    tokens_with_special, special_tokens_text = visualizer.add_special_tokens(tokens)
    token_ids, tokens, token_ids_text = visualizer.get_token_ids(sentence)
    embeddings, tokens, embeddings_text, embedding_img = visualizer.get_embeddings(sentence)
    pos_encoding_text, pos_encoding_img, pos_waves_img = visualizer.get_positional_encoding(len(token_ids))

    return (tokenization_text, special_tokens_text, token_ids_text,
            embeddings_text, embedding_img, pos_encoding_text,
            pos_encoding_img, pos_waves_img)

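# Hugging Face Hub checkpoints offered in the dropdown. process_text builds a
# fresh TransformerVisualizer on every click, so the selected model is reloaded
# (and downloaded on first use) each time the button is pressed.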
models = [
    "bert-base-uncased",
    "roberta-base",
    "distilbert-base-uncased",
    "gpt2",
    "albert-base-v2",
    "xlm-roberta-base"
]

with gr.Blocks(title="Transformer Process Visualizer") as demo:
    gr.Markdown("# Transformer Process Visualizer")
    gr.Markdown("This app visualizes the key processes in transformer models: tokenization, special tokens, token IDs, word embeddings, and positional encoding.")

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Input Sentence",
                placeholder="Enter a sentence to visualize transformer processes",
                value="The transformer architecture revolutionized natural language processing."
            )
            model_dropdown = gr.Dropdown(
                label="Select Model",
                choices=models,
                value="bert-base-uncased"
            )
            submit_btn = gr.Button("Visualize")

    with gr.Tabs():
        with gr.TabItem("Tokenization"):
            tokenization_output = gr.Textbox(label="Tokenization")

        with gr.TabItem("Special Tokens"):
            special_tokens_output = gr.Textbox(label="Special Tokens")

        with gr.TabItem("Token IDs"):
            token_ids_output = gr.Textbox(label="Token IDs")

        with gr.TabItem("Word Embeddings"):
            embeddings_output = gr.Textbox(label="Embeddings Info")
            embedding_plot = gr.Image(label="Embedding Visualization")

        with gr.TabItem("Positional Encoding"):
            pos_encoding_output = gr.Textbox(label="Positional Encoding Info")
            pos_encoding_plot = gr.Image(label="Positional Encoding Heatmap")
            pos_waves_plot = gr.Image(label="Positional Encoding Waves")

    submit_btn.click(
        process_text,
        inputs=[input_text, model_dropdown],
        outputs=[
            tokenization_output,
            special_tokens_output,
            token_ids_output,
            embeddings_output,
            embedding_plot,
            pos_encoding_output,
            pos_encoding_plot,
            pos_waves_plot
        ]
    )

demo.launch()
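# launch() with no arguments serves the app locally; demo.launch(share=True)
# would additionally create a temporary public Gradio link.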