# ====================================
# 1) Install Dependencies (One-Time)
# ====================================
# Run once in a fresh environment (e.g. a Colab cell):
#   pip install torch transformers gradio matplotlib

# ====================================
# 2) Imports
# ====================================
import base64
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure

# We'll use a "tiny" BERT model to reduce loading/inference time:
from transformers import AutoTokenizer, AutoModel

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ====================================
# 3) Load Tiny BERT Model + Tokenizer
# ====================================
# This "prajjwal1/bert-tiny" model is just 2 Transformer layers
# (and ~4 million parameters), so it loads and runs faster.
MODEL_NAME = "prajjwal1/bert-tiny"
MAX_TOKENS = 50  # longer inputs are truncated to this many tokens

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(
    MODEL_NAME,
    output_attentions=True,  # so we can visualize attention
).to(device)
model.eval()


# ====================================
# 4) Helper Functions
# ====================================
def _figure_to_array(fig):
    """Rasterise a Matplotlib figure into an RGB ``uint8`` array [H, W, 3].

    ``gr.Image`` output values are NumPy arrays, PIL images or file
    paths/URLs -- a ``data:image/png;base64,...`` string is not one of them --
    so figures are handed to Gradio as pixel arrays.
    """
    buf = BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    buf.seek(0)
    # imread returns PNGs as float arrays in [0, 1]; drop any alpha channel.
    pixels = plt.imread(buf)
    return (pixels[..., :3] * 255).round().astype(np.uint8)


def plot_heatmap(matrix, tokens, title="Attention", cmap="Blues"):
    """Draw a [seq_len, seq_len] matrix as a heatmap with tokens on both axes.

    Args:
        matrix: 2-D array-like of shape [seq_len, seq_len].
        tokens: Axis labels, one per row/column of ``matrix``.
        title: Plot title.
        cmap: Matplotlib colormap name.

    Returns:
        RGB ``uint8`` NumPy array of shape [H, W, 3], suitable for ``gr.Image``.
    """
    # Use the Figure API instead of pyplot: pyplot keeps global references to
    # open figures (and saves the *current* figure), which is unsafe when a
    # web server such as Gradio runs handlers on worker threads.
    fig = Figure(figsize=(6, 5), tight_layout=True)
    ax = fig.subplots()
    image = ax.imshow(matrix, interpolation="nearest", cmap=cmap)
    ax.set_title(title)

    # Show tokens on x and y axis
    positions = range(len(tokens))
    ax.set_xticks(positions)
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticks(positions)
    ax.set_yticklabels(tokens)

    fig.colorbar(image, ax=ax)
    return _figure_to_array(fig)


def simulate_rnn_hidden_states(tokens, hidden_dim=8, seed=None):
    """Draw a heatmap of *simulated* RNN hidden states, one row per token.

    No recurrence is computed: the states are uniform random values in [0, 1)
    that stand in, for illustration, for the hidden state an RNN would
    produce while processing the tokens one by one.

    Args:
        tokens: Token sequence; only its length (seq_len) is used.
        hidden_dim: Width of the simulated hidden state.
        seed: Optional seed for reproducible plots; ``None`` draws fresh
            random states on every call.

    Returns:
        RGB ``uint8`` NumPy array of shape [H, W, 3] showing a
        [seq_len, hidden_dim] heatmap.
    """
    rng = np.random.default_rng(seed)
    states = rng.random((len(tokens), hidden_dim))

    fig = Figure(figsize=(6, 3), tight_layout=True)
    ax = fig.subplots()
    image = ax.imshow(states, interpolation="nearest", aspect="auto", cmap="viridis")
    ax.set_title("Simulated RNN Hidden States")
    ax.set_xlabel("Hidden Dim")
    ax.set_ylabel("Token Index")
    fig.colorbar(image, ax=ax)
    return _figure_to_array(fig)


# ====================================
# 5) Gradio Inference Function
# ====================================
def compare_rnn_transformer(input_text):
    """Build both visualisations for one input sentence.

    Tokenizes ``input_text`` (truncated to ``MAX_TOKENS``), plots simulated
    RNN hidden states, runs Tiny BERT and plots its last-layer attention
    averaged over heads.

    Args:
        input_text: Text from the Gradio textbox; ``None`` is treated as "".

    Returns:
        ``(rnn_image, attention_image)``: two RGB ``uint8`` arrays, one per
        ``gr.Image`` output.

    Raises:
        RuntimeError: if the model returns no attention weights.
    """
    # 1) Tokenize input text
    encoded = tokenizer(
        input_text or "",
        return_tensors="pt",
        truncation=True,
        max_length=MAX_TOKENS,
    )
    # Move to GPU if available
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    # Convert IDs to tokens (just for axis labels)
    tokens = tokenizer.convert_ids_to_tokens(encoded["input_ids"][0].tolist())

    # 2) Simulate RNN hidden states
    rnn_heatmap = simulate_rnn_hidden_states(tokens)

    # 3) Forward pass through Tiny BERT
    with torch.no_grad():
        outputs = model(**encoded)

    # outputs.attentions: one tensor per layer, each
    # [batch_size, n_heads, seq_len, seq_len]
    attentions = outputs.attentions
    if not attentions:
        raise RuntimeError(
            "The model returned no attention weights; load it with "
            "output_attentions=True (and eager attention) to visualize them."
        )

    # Last layer, first (only) batch item -> [n_heads, seq_len, seq_len],
    # then average across heads -> [seq_len, seq_len]
    avg_attention = attentions[-1][0].mean(dim=0).cpu().numpy()

    # 4) Create a heatmap for attention
    transformer_heatmap = plot_heatmap(avg_attention, tokens, title="Transformer Attention")

    return rnn_heatmap, transformer_heatmap


# ====================================
# 6) Create and Launch Gradio Interface
# ====================================
interface = gr.Interface(
    fn=compare_rnn_transformer,
    inputs=gr.Textbox(
        lines=3,
        label="Enter a sentence to see RNN vs. Transformer visualization",
    ),
    outputs=[
        gr.Image(label="RNN Hidden States"),
        gr.Image(label="Transformer Attention Map"),
    ],
    title="RNN vs. Tiny BERT Demo",
    description=(
        "Type in a sentence and see how a simulated RNN processes tokens step-by-step "
        "vs. how a real (tiny) Transformer computes attention across all tokens in parallel.\n\n"
        "For best performance, enable GPU under Runtime > Change runtime type > GPU."
    ),
)

if __name__ == "__main__":
    # __name__ is also "__main__" inside a notebook, so this still launches there.
    interface.launch()