import torch
import gradio as gr
import plotly.express as px
from transformers import AutoModel, AutoTokenizer
########################################
# Load Transformer (DistilBERT) with attention
########################################
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_attentions=True)
model.eval()
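# output_attentions=True makes every forward pass return, alongside the hidden
# states, a tuple with one attention tensor per layer, each shaped
# (batch, num_heads, seq_len, seq_len). eval() disables dropout so the
# attention weights are deterministic across calls.
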
def visualize_attention(text, layer=5):
    """
    1. Tokenize the input text.
    2. Run a DistilBERT forward pass to get the attention matrices.
    3. Pick a layer (0..5) and average across attention heads.
    4. Generate a Plotly heatmap of shape (seq_len x seq_len).
    5. Label the axes with tokens (Query vs. Key).
    """
    layer = int(layer)  # Gradio sliders can deliver floats; tuple indices must be ints
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt")
        outputs = model(**inputs)
        all_attentions = outputs.attentions
    # DistilBERT has 6 layers => valid layer indices: 0..5
    attn_layer = all_attentions[layer].mean(dim=1)  # average over heads: (1, seq_len, seq_len)
    # Convert to numpy for plotting
    attn_matrix = attn_layer[0].cpu().numpy()
    # Get tokens (including special tokens like [CLS] and [SEP]); prefix each with
    # its position so repeated tokens don't collapse into one categorical axis entry
    input_ids = inputs["input_ids"][0]
    tokens = [f"{i}: {tok}" for i, tok in enumerate(tokenizer.convert_ids_to_tokens(input_ids))]
    # Build a Plotly heatmap
    fig = px.imshow(
        attn_matrix,
        x=tokens,
        y=tokens,
        labels={"x": "Key (Being Attended to)", "y": "Query (Focusing)"},
        color_continuous_scale="Blues",
        title=f"DistilBERT Attention (Layer {layer})"
    )
    fig.update_xaxes(side="top")
    # Add a tooltip; Plotly uses <br> for line breaks inside hovertemplates,
    # and <extra></extra> hides the secondary trace-name box
    fig.update_traces(
        hovertemplate="Query: %{y}<br>Key: %{x}<br>Attention Weight: %{z:.3f}<extra></extra>"
    )
    fig.update_layout(coloraxis_colorbar=dict(title="Attention Weight"))
    return fig
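# Standalone usage sketch (not wired into the UI; the sentence is just an example):
#   fig = visualize_attention("The cat sat on the mat.", layer=3)
#   fig.write_html("attention.html")  # Plotly figures can be saved as standalone HTML
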
def interpret_token_attention(text, token_index=0, layer=5):
    """
    Provides a textual explanation of why a particular token (the Query) attends
    to other tokens in the input, highlighting the top three tokens it focuses on.
    """
    layer = int(layer)
    token_index = int(token_index)  # gr.Number passes floats by default
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt")
        outputs = model(**inputs)
        all_attentions = outputs.attentions
    attn_layer = all_attentions[layer].mean(dim=1)  # average over heads: (1, seq_len, seq_len)
    # Get tokens
    input_ids = inputs["input_ids"][0]
    tokens = tokenizer.convert_ids_to_tokens(input_ids)
    # Safety check for token_index
    if token_index < 0 or token_index >= len(tokens):
        return f"Invalid token index. Please choose an index between 0 and {len(tokens) - 1}."
    # Extract the row corresponding to our Query token
    query_attn = attn_layer[0, token_index, :].cpu().numpy()  # shape: (seq_len,)
    # Sort tokens by attention weight (descending) and grab the top 3
    sorted_indices = query_attn.argsort()[::-1]
    top_indices = sorted_indices[:3]
    top_tokens = [tokens[i] for i in top_indices]
    top_weights = [query_attn[i] for i in top_indices]
    # Build an explanation
    query_token_str = tokens[token_index]
    explanation = (
        f"**You chose token index {token_index}, which is '{query_token_str}'.**\n\n"
        "In Transformers, each token is converted into Query, Key, and Value vectors:\n"
        "- **Query** = what this token is looking for\n"
        "- **Key** = what another token has to offer\n"
        "- **Value** = the actual information that token carries\n\n"
        f"As a Query, '{query_token_str}' attends most strongly to:\n"
    )
    for t, w in zip(top_tokens, top_weights):
        explanation += f"- **{t}** with attention weight ~{w:.3f}\n"
    explanation += (
        "\nA higher attention weight means this Query token is 'looking at' that "
        "Key token more strongly, likely because it finds the Key token relevant "
        "to its meaning or context."
    )
    return explanation
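# Standalone usage sketch (not wired into the UI):
#   print(interpret_token_attention("The cat sat on the mat.", token_index=2))
#   # explains what the token at index 2 ("cat") attends to
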
# Short explanation text for the UI
description_text = """
## Understanding Transformer Self-Attention
- **Rows = Query token** (the token doing the 'looking').
- **Columns = Key token** (the token being 'looked at').
- **Darker color** = stronger attention weight.

**Transformers** process all tokens in **parallel**, allowing any token to attend
to any other token in the sentence. This makes it easier for the model to capture
long-distance relationships.
"""
########################################
# Gradio Interface
########################################
with gr.Blocks(css="footer{display:none !important}") as demo:
    gr.Markdown("# Transformer Self-Attention Visualization (DistilBERT)")
    gr.Markdown(description_text)
    with gr.Row():
        text_input = gr.Textbox(
            label="Enter a sentence",
            value="Transformers handle long-range context in parallel."
        )
        layer_slider = gr.Slider(
            minimum=0, maximum=5, step=1, value=5,
            label="DistilBERT Layer (0=lowest, 5=highest)"
        )
    output_plot = gr.Plot(label="Attention Heatmap")
    # Visualization button
    visualize_button = gr.Button("Visualize Attention")
    visualize_button.click(
        fn=visualize_attention,
        inputs=[text_input, layer_slider],
        outputs=output_plot
    )
    # Number field to choose a token index for interpretation
    token_index = gr.Number(
        label="Choose a token index to interpret (0-based)",
        value=0,
        precision=0  # round input to whole token indices
    )
    interpretation_output = gr.Markdown(label="Interpretation")
    # Interpretation button
    interpret_button = gr.Button("Explain This Token's Attention")
    interpret_button.click(
        fn=interpret_token_attention,
        inputs=[text_input, token_index, layer_slider],
        outputs=interpretation_output
    )
demo.launch()
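# Tip: demo.launch(share=True) would additionally expose a temporary public URL.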