import torch
import gradio as gr
import plotly.express as px
from transformers import AutoModel, AutoTokenizer

########################################
# Load Transformer (DistilBERT) with attention
########################################
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_attentions=True)
model.eval()


def visualize_attention(text, layer=5):
    """
    1. Tokenize input text.
    2. Run DistilBERT forward pass to get attention matrices.
    3. Pick a layer (0..5) and average across attention heads.
    4. Generate a heatmap (Plotly) of shape (seq_len x seq_len).
    5. Label axes with tokens (Query vs. Key).
    """
    layer = int(layer)  # Gradio sliders can deliver floats; tuple indexing needs an int
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt")
        outputs = model(**inputs)
        all_attentions = outputs.attentions

    # DistilBERT has 6 layers => valid layer indices: 0..5
    attn_layer = all_attentions[layer].mean(dim=1)  # shape: (1, seq_len, seq_len)

    # Convert to numpy for plotting
    attn_matrix = attn_layer[0].cpu().numpy()

    # Get tokens (including special tokens)
    input_ids = inputs["input_ids"][0]
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    # Prefix each token with its position: Plotly merges duplicate categorical
    # axis labels, so repeated tokens would otherwise collapse into a single
    # row/column. The index also tells users what to type into the token-index box.
    axis_labels = [f"{i}: {tok}" for i, tok in enumerate(tokens)]

    # Build a Plotly heatmap
    fig = px.imshow(
        attn_matrix,
        x=axis_labels,
        y=axis_labels,
        labels={"x": "Key (Being Attended to)", "y": "Query (Focusing)"},
        color_continuous_scale="Blues",
        title=f"DistilBERT Attention (Layer {layer})"
    )
    fig.update_xaxes(side="top")

    # Customize the hover tooltip
    fig.update_traces(
        hovertemplate="Query: %{y}<br>Key: %{x}<br>Attention Weight: %{z:.3f}"
    )
    fig.update_layout(coloraxis_colorbar=dict(title="Attention Weight"))
    return fig


def interpret_token_attention(text, token_index=0, layer=5):
    """
    Provides a textual explanation for why a particular token (Query)
    attends to other tokens in the input, highlighting the top three
    tokens it focuses on.
    """
    layer = int(layer)
    token_index = int(token_index)  # gr.Number values arrive as floats by default
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt")
        outputs = model(**inputs)
        all_attentions = outputs.attentions
        attn_layer = all_attentions[layer].mean(dim=1)  # shape: (1, seq_len, seq_len)

    # Get tokens
    input_ids = inputs["input_ids"][0]
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    # Safety check for token_index
    if token_index < 0 or token_index >= len(tokens):
        return "Invalid token index. Please choose a valid token index."

    # Extract the row corresponding to our Query token
    query_attn = attn_layer[0, token_index, :].cpu().numpy()  # shape: (seq_len,)

    # Sort tokens by attention weight (descending) and grab the top 3
    sorted_indices = query_attn.argsort()[::-1]
    top_indices = sorted_indices[:3]
    top_tokens = [tokens[i] for i in top_indices]
    top_weights = [query_attn[i] for i in top_indices]

    # Build an explanation
    query_token_str = tokens[token_index]
    explanation = (
        f"**You chose token index {token_index}, which is '{query_token_str}'.**\n\n"
        "In Transformers, each token is converted into Query, Key, and Value vectors:\n"
        "- **Query** = What this token is looking for\n"
        "- **Key** = What another token has to offer\n"
        "- **Value** = The actual information from that token\n\n"
        f"As a Query, '{query_token_str}' attends most strongly to:\n"
    )
    for t, w in zip(top_tokens, top_weights):
        explanation += f"- **{t}** with attention weight ~ {w:.3f}\n"

    explanation += (
        "\nA higher attention weight indicates that this Query token is 'looking at' or "
        "focusing on that Key token more strongly, likely because it finds the Key token "
        "relevant to its meaning or context."
    )
    return explanation


# Short explanation text for the UI
description_text = """
## Understanding Transformer Self-Attention

- **Rows = Query token** (the token doing the 'looking').
- **Columns = Key token** (the token being 'looked at').
- Darker color = stronger attention weight.

**Transformers** process all tokens in **parallel**, allowing any token to attend to
any other token in the sentence. This makes it easier for the model to capture
long-distance relationships.
"""

########################################
# Gradio Interface
########################################
with gr.Blocks(css="footer{display:none !important}") as demo:
    gr.Markdown("# Transformer Self-Attention Visualization (DistilBERT)")
    gr.Markdown(description_text)

    with gr.Row():
        text_input = gr.Textbox(
            label="Enter a sentence",
            value="Transformers handle long-range context in parallel."
        )
        layer_slider = gr.Slider(
            minimum=0,
            maximum=5,
            step=1,
            value=5,
            label="DistilBERT Layer (0=lowest, 5=highest)"
        )

    output_plot = gr.Plot(label="Attention Heatmap")

    # Visualization Button
    visualize_button = gr.Button("Visualize Attention")
    visualize_button.click(
        fn=visualize_attention,
        inputs=[text_input, layer_slider],
        outputs=output_plot
    )

    # Number input to choose a token index for interpretation
    token_index = gr.Number(
        label="Choose a token index to interpret (0-based)",
        value=0,
        precision=0  # round to whole token indices
    )
    interpretation_output = gr.Markdown(label="Interpretation")

    # Interpretation Button
    interpret_button = gr.Button("Explain This Token's Attention")
    interpret_button.click(
        fn=interpret_token_attention,
        inputs=[text_input, token_index, layer_slider],
        outputs=interpretation_output
    )

demo.launch()
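# Quick sanity check without launching the UI (a sketch; the sentence is just
# an arbitrary example):
#
#   fig = visualize_attention("The cat sat on the mat.", layer=3)
#   fig.show()  # opens the Plotly heatmap in a browser tab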
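# Optional extension (a sketch, not wired into the UI above): the heatmap
# averages over heads via `.mean(dim=1)`, but distilbert-base-uncased has
# 12 heads per layer and individual heads often specialize. Reusing the
# `model` and `tokenizer` defined above, a hypothetical helper could pull
# out a single head's (seq_len, seq_len) matrix:
#
#   def single_head_attention(text, layer=5, head=0):
#       with torch.no_grad():
#           outputs = model(**tokenizer(text, return_tensors="pt"))
#       # outputs.attentions[layer] has shape (1, num_heads, seq_len, seq_len)
#       return outputs.attentions[int(layer)][0, int(head)].cpu().numpy()
#
# Feeding that matrix into the same px.imshow call, plus a 0..11 head slider
# in the Blocks layout, would let users compare what different heads learn.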