# id-align/qk.py
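"""Probe how RoPE shapes query-key attention scores in a Llama-style model.

Forward hooks capture the raw q_proj / k_proj outputs of every decoder layer while the
model runs over sampled WikiText passages; random pre-RoPE (Q, K) vector pairs are then
re-rotated at a sweep of hypothetical relative positions d = Q_pos - K_pos, and the
resulting dot products are averaged and plotted per layer and per head.
"""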
import torch
import functools
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from collections import defaultdict
import random
import matplotlib.pyplot as plt
import math
import os  # used to create the plot output directories
# --- Configuration ---
MODEL_NAME = "/jizhicfs/bojoli/vicuna-7b-1.5"  # Local path to Vicuna-7B v1.5, or any other Llama-based model
# Dataset configuration
DATASET_NAME = "../../dataset/wikitext"  # Local path or Hub dataset name, e.g. "wikitext", "c4", "glue", "squad"
DATASET_CONFIG_NAME = "wikitext-2-raw-v1"
TEXT_COLUMN_NAME = "text"
NUM_TEXT_SAMPLES = 500
SAMPLED_QK_PAIRS_PER_LAYER_PER_TEXT = 10
MAX_SEQ_LENGTH = 1024
MAX_HYPOTHETICAL_RELATIVE_POS = 1024
REFERENCE_Q_POS = MAX_HYPOTHETICAL_RELATIVE_POS # The fixed "original" position for Q
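# RoPE makes the post-rotation Q.K dot product depend only on the relative offset
# d = Q_pos - K_pos, so the simulation pins Q at REFERENCE_Q_POS and places K at
# REFERENCE_Q_POS - d while d sweeps 1..MAX_HYPOTHETICAL_RELATIVE_POS.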
# --- Global storage for HOOKS (cleared per forward pass) ---
hook_captured_raw_q = {}
hook_captured_raw_k = {}
# --- Global storage for SAMPLED QK PAIRS (accumulated across texts) ---
all_sampled_pre_rope_qk_vec_pairs = []
# --- Final dot products: layer -> hypothetical relative position -> list of head-averaged values ---
dot_products_by_hypo_rel_pos = defaultdict(lambda: defaultdict(list))
# Per-head dot products: layer -> head -> hypothetical relative position -> list of values
dot_products_per_head_by_hypo_rel_pos = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
hook_handles = []
# --- Helper Functions ---
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
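# RoPE is applied below as x_rotated = x * cos + rotate_half(x) * sin, matching the
# Hugging Face Llama implementation: rotate_half swaps (and negates) the two halves of
# the head dimension so each (x_i, x_{i + dim/2}) pair is rotated by its position-dependent angle.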
# --- Hook Function to capture raw q_proj and k_proj outputs ---
def raw_projection_capture_hook(storage_dict, module, input_args, output, layer_idx):
storage_dict[layer_idx] = output.detach().cpu()
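# register_forward_hook calls the hook with (module, input, output); functools.partial
# pre-binds the storage dict and layer index so each hook files its layer's raw
# projection output (moved to CPU) under the right key.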
def get_texts_from_dataset(dataset_name, dataset_config_name, text_column, num_samples):
print(f"Loading dataset: {dataset_name}, config: {dataset_config_name}")
try:
dataset = load_dataset(dataset_name, dataset_config_name, split="train", streaming=False)
except Exception as e:
print(f"Could not load train split. Error: {e}. Trying first available split.")
try:
dataset_info = load_dataset(dataset_name, dataset_config_name)
split_name = list(dataset_info.keys())[0]
dataset = load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=False)
except Exception as e2:
print(f"Error loading dataset: {e2}")
return None
if text_column not in dataset.column_names:
print(f"Error: Text column '{text_column}' not found. Available: {dataset.column_names}")
return None
dataset_size = len(dataset)
actual_num_samples = min(num_samples, dataset_size)
    # Hard cap to keep the analysis tractable (this overrides NUM_TEXT_SAMPLES when it is larger).
    if actual_num_samples > 100:
        actual_num_samples = 100
    dataset = dataset.shuffle(seed=42).select(range(actual_num_samples))
    sampled_texts = [item[text_column].strip() for item in dataset if item[text_column].strip()]
    print(f"Successfully sampled {len(sampled_texts)} non-empty text examples.")
return sampled_texts
def main():
# --- Phase 0: Load texts ---
texts_to_process = get_texts_from_dataset(DATASET_NAME, DATASET_CONFIG_NAME, TEXT_COLUMN_NAME, NUM_TEXT_SAMPLES)
if not texts_to_process:
print("No texts to process. Exiting.")
return
# --- Phase 1: Load Model & Tokenizer, Register Hooks ---
print(f"Loading model: {MODEL_NAME}")
try:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
except Exception as e:
print(f"Error loading model or tokenizer: {e}"); return
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(f"Model loaded on {device}")
num_layers = 0
if hasattr(model, 'model') and hasattr(model.model, 'layers'):
num_layers = len(model.model.layers)
else:
print("Model structure does not match expected Llama-like structure."); return
print("Registering hooks on q_proj and k_proj layers...")
for i, layer_module in enumerate(model.model.layers):
if hasattr(layer_module, 'self_attn') and \
hasattr(layer_module.self_attn, 'q_proj') and \
hasattr(layer_module.self_attn, 'k_proj'):
q_hook_fn = functools.partial(raw_projection_capture_hook, hook_captured_raw_q, layer_idx=i)
q_handle = layer_module.self_attn.q_proj.register_forward_hook(q_hook_fn)
hook_handles.append(q_handle)
k_hook_fn = functools.partial(raw_projection_capture_hook, hook_captured_raw_k, layer_idx=i)
k_handle = layer_module.self_attn.k_proj.register_forward_hook(k_hook_fn)
hook_handles.append(k_handle)
print(f"Registered {len(hook_handles)} hooks.")
# --- Phase 2: Data Collection and QK Pair Sampling ---
print("\n--- Starting Data Collection and QK Pair Sampling ---")
for text_idx, current_text in enumerate(texts_to_process):
print(f"Processing text sample {text_idx + 1}/{len(texts_to_process)}")
hook_captured_raw_q.clear()
hook_captured_raw_k.clear()
inputs = tokenizer(current_text, return_tensors="pt", padding="longest", truncation=True, max_length=MAX_SEQ_LENGTH).to(device)
with torch.no_grad():
try:
model(**inputs) # Triggers hooks
except Exception as e:
print(f" Error during model forward pass: {e}"); continue
for layer_idx in range(num_layers):
if layer_idx not in hook_captured_raw_q or layer_idx not in hook_captured_raw_k:
continue
raw_q_proj = hook_captured_raw_q[layer_idx].to(device)
raw_k_proj = hook_captured_raw_k[layer_idx].to(device)
attn_module = model.model.layers[layer_idx].self_attn
bsz, seq_len, _ = raw_q_proj.shape
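            # Reshape the flat (bsz, seq_len, heads * head_dim) projections into
            # (bsz, heads, seq_len, head_dim). NOTE: this assumes LlamaAttention exposes
            # num_heads / num_key_value_heads / head_dim as module attributes, which holds
            # for the transformers version this script targets; newer releases keep these
            # on the model config instead.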
q_reshaped = raw_q_proj.view(bsz, seq_len, attn_module.num_heads, attn_module.head_dim).transpose(1, 2)
k_reshaped = raw_k_proj.view(bsz, seq_len, attn_module.num_key_value_heads, attn_module.head_dim).transpose(1, 2)
            # Enumerate all (query position, key position) pairs with distinct positions;
            # a single-token sequence keeps its only (0, 0) pair so it is not skipped.
            possible_indices = []
            for b_idx in range(bsz):
                for i_idx in range(seq_len):
                    for j_idx in range(seq_len):
                        if i_idx != j_idx or seq_len == 1:
                            possible_indices.append((b_idx, i_idx, j_idx))
num_to_sample = min(len(possible_indices), SAMPLED_QK_PAIRS_PER_LAYER_PER_TEXT)
if not possible_indices or num_to_sample == 0: continue
selected_indices = random.sample(possible_indices, num_to_sample)
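            # Store the raw (pre-RoPE) Q and K vectors of every head at the sampled
            # positions; because RoPE has not been applied yet, Phase 3 is free to
            # re-rotate them at arbitrary hypothetical positions.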
for b_sel, i_sel, j_sel in selected_indices:
q_vec = q_reshaped[b_sel, :, i_sel, :].detach().clone()
k_vec = k_reshaped[b_sel, :, j_sel, :].detach().clone()
all_sampled_pre_rope_qk_vec_pairs.append({
'layer_idx': layer_idx,
'q_vec': q_vec,
'k_vec': k_vec,
'num_heads': attn_module.num_heads,
'head_dim': attn_module.head_dim,
'num_kv_heads': attn_module.num_key_value_heads
})
print(f"\nCollected {len(all_sampled_pre_rope_qk_vec_pairs)} (Q,K) vector pairs for RoPE simulation.")
# --- Phase 3: RoPE Simulation and Dot Product Calculation ---
print("\n--- Starting RoPE Simulation and Dot Product Calculation ---")
if not all_sampled_pre_rope_qk_vec_pairs:
print("No QK pairs were sampled. Exiting simulation phase.")
else:
layer_rope_caches = {}
model_dtype = model.dtype
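        # Precompute per-layer cos/sin caches covering positions 0..REFERENCE_Q_POS.
        # This assumes a transformers version in which each LlamaAttention still owns a
        # rotary_emb module whose forward(x, position_ids) returns (cos, sin) of shape
        # (bsz, seq_len, head_dim); newer releases move the rotary embedding onto the
        # base model instead.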
for layer_idx in range(num_layers):
attn_module = model.model.layers[layer_idx].self_attn
rotary_emb = attn_module.rotary_emb
max_pos_for_cache = REFERENCE_Q_POS + 1
dummy_x_for_rope = torch.zeros(
(1, 1, max_pos_for_cache, attn_module.head_dim),
device=device,
dtype=model_dtype
)
position_ids_for_cache = torch.arange(
max_pos_for_cache,
device=device,
dtype=torch.long
).unsqueeze(0)
cos_cached_for_layer, sin_cached_for_layer = rotary_emb(dummy_x_for_rope, position_ids_for_cache)
layer_rope_caches[layer_idx] = (cos_cached_for_layer, sin_cached_for_layer)
for pair_info in all_sampled_pre_rope_qk_vec_pairs:
layer_idx = pair_info['layer_idx']
q_vec_pre_rope = pair_info['q_vec']
k_vec_pre_rope = pair_info['k_vec']
num_h, head_dim = pair_info['num_heads'], pair_info['head_dim']
num_kv_h = pair_info['num_kv_heads']
num_key_value_groups = num_h // num_kv_h
cos_cache_full, sin_cache_full = layer_rope_caches[layer_idx]
q_for_rope = q_vec_pre_rope.unsqueeze(0).unsqueeze(2)
k_for_rope = k_vec_pre_rope.unsqueeze(0).unsqueeze(2)
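            # q_for_rope / k_for_rope have shape (1, n_heads_or_kv_heads, 1, head_dim), so
            # they broadcast against the (1, 1, head_dim) cos/sin slices taken below.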
for hypo_rel_pos in range(1, MAX_HYPOTHETICAL_RELATIVE_POS + 1):
pos_q_prime = REFERENCE_Q_POS
pos_k_prime = REFERENCE_Q_POS - hypo_rel_pos
if pos_k_prime < 0: continue
if pos_q_prime >= cos_cache_full.shape[1] or pos_k_prime >= cos_cache_full.shape[1]:
print(f"Warning: Position out of bounds for cached RoPE. pos_q_prime={pos_q_prime}, pos_k_prime={pos_k_prime}, cache_len={cos_cache_full.shape[1]}")
continue
cos_q = cos_cache_full[:, pos_q_prime:pos_q_prime+1, :]
sin_q = sin_cache_full[:, pos_q_prime:pos_q_prime+1, :]
q_rope = (q_for_rope * cos_q) + (rotate_half(q_for_rope) * sin_q)
cos_k = cos_cache_full[:, pos_k_prime:pos_k_prime+1, :]
sin_k = sin_cache_full[:, pos_k_prime:pos_k_prime+1, :]
k_rope = (k_for_rope * cos_k) + (rotate_half(k_for_rope) * sin_k)
q_final = q_rope.squeeze(2).squeeze(0)
k_final = k_rope.squeeze(2).squeeze(0)
if num_key_value_groups > 1:
k_final_expanded = k_final.repeat_interleave(num_key_value_groups, dim=0)
else:
k_final_expanded = k_final
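                # With grouped-query attention each KV head serves num_key_value_groups
                # query heads; repeat_interleave mirrors HF's repeat_kv so every query
                # head is paired with its group's rotated key. (For Vicuna-7B the group
                # size is 1, so this is a no-op.)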
                dot_products_all_heads = torch.sum(q_final * k_final_expanded, dim=-1)  # shape: (num_h,)
avg_dot_product = dot_products_all_heads.mean().item()
dot_products_by_hypo_rel_pos[layer_idx][hypo_rel_pos].append(avg_dot_product)
                # Store the per-head dot products as well
for head_idx in range(num_h):
dot_product_for_head = dot_products_all_heads[head_idx].item()
dot_products_per_head_by_hypo_rel_pos[layer_idx][head_idx][hypo_rel_pos].append(dot_product_for_head)
# --- Phase 4a: Plotting Combined Results ---
print("\n--- Plotting Combined Results ---")
plt.figure(figsize=(15, 8))
layers_plotted_combined = 0
for layer_idx in range(num_layers):
if not dot_products_by_hypo_rel_pos[layer_idx]: continue
        layers_plotted_combined += 1
avg_dot_products_for_layer = {}
rel_positions_sorted = sorted(dot_products_by_hypo_rel_pos[layer_idx].keys())
if not rel_positions_sorted: continue
for rel_pos in rel_positions_sorted:
if dot_products_by_hypo_rel_pos[layer_idx][rel_pos]:
avg_dot_products_for_layer[rel_pos] = sum(dot_products_by_hypo_rel_pos[layer_idx][rel_pos]) / len(dot_products_by_hypo_rel_pos[layer_idx][rel_pos])
if not avg_dot_products_for_layer: continue
x_vals = list(avg_dot_products_for_layer.keys())
y_vals = list(avg_dot_products_for_layer.values())
plt.plot(x_vals, y_vals, marker='.', linestyle='-', label=f'Layer {layer_idx}', alpha=0.7)
if layers_plotted_combined > 0:
plt.title(f'Avg QK Dot Product (Simulated RoPE) vs. Hypothetical Relative Position (All Layers)')
plt.xlabel('Hypothetical Relative Position (d = Q_pos - K_pos)')
plt.ylabel('Average Dot Product (across heads & samples)')
plt.grid(True)
plt.xlim(0, MAX_HYPOTHETICAL_RELATIVE_POS + 1)
num_legend_cols = math.ceil(layers_plotted_combined / 20) if layers_plotted_combined > 0 else 1
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), ncol=num_legend_cols)
plt.tight_layout(rect=[0, 0, 0.85 if num_legend_cols <=2 else 0.75, 1])
svg_filename = f"all_layers_simulated_rope_qk_dot_product_corrected.svg"
pdf_filename = f"all_layers_simulated_rope_qk_dot_product_corrected.pdf"
plt.savefig(svg_filename, format='svg')
plt.savefig(pdf_filename, format='pdf')
print(f"Saved combined plot to {svg_filename} and {pdf_filename}")
plt.close()
else:
print("No data collected for any layer to plot (combined).")
# --- Phase 4b: Plotting Per-Head Results ---
print("\n--- Plotting Per-Head Results ---")
per_head_output_dir = "per_head"
os.makedirs(per_head_output_dir, exist_ok=True)
for layer_idx in range(num_layers):
if not dot_products_per_head_by_hypo_rel_pos[layer_idx]:
print(f"No per-head data for layer {layer_idx} to plot.")
continue
plt.figure(figsize=(15, 8))
attn_module = model.model.layers[layer_idx].self_attn
num_heads_in_layer = attn_module.num_heads
heads_plotted_for_layer = 0
for head_idx in range(num_heads_in_layer):
if not dot_products_per_head_by_hypo_rel_pos[layer_idx][head_idx]:
continue
avg_dot_products_for_head = {}
rel_positions_sorted = sorted(dot_products_per_head_by_hypo_rel_pos[layer_idx][head_idx].keys())
if not rel_positions_sorted:
continue
for rel_pos in rel_positions_sorted:
if dot_products_per_head_by_hypo_rel_pos[layer_idx][head_idx][rel_pos]:
avg_dot_products_for_head[rel_pos] = sum(dot_products_per_head_by_hypo_rel_pos[layer_idx][head_idx][rel_pos]) / len(dot_products_per_head_by_hypo_rel_pos[layer_idx][head_idx][rel_pos])
if not avg_dot_products_for_head:
continue
x_vals = list(avg_dot_products_for_head.keys())
y_vals = list(avg_dot_products_for_head.values())
plt.plot(x_vals, y_vals, marker='.', linestyle='-', label=f'Head {head_idx}', alpha=0.6)
heads_plotted_for_layer += 1
if heads_plotted_for_layer > 0:
plt.title(f'Layer {layer_idx}: Avg QK Dot Product vs. Hypothetical Relative Position (Per Head)')
plt.xlabel('Hypothetical Relative Position (d = Q_pos - K_pos)')
plt.ylabel('Average Dot Product (across samples)')
plt.grid(True)
plt.xlim(0, MAX_HYPOTHETICAL_RELATIVE_POS + 1)
num_legend_cols_per_head = math.ceil(heads_plotted_for_layer / 20) if heads_plotted_for_layer > 0 else 1
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), ncol=num_legend_cols_per_head)
plt.tight_layout(rect=[0, 0, 0.85 if num_legend_cols_per_head <= 2 else 0.75, 1])
svg_filename_per_head = os.path.join(per_head_output_dir, f"layer_{layer_idx}_per_head_simulated_rope_qk_dot_product.svg")
pdf_filename_per_head = os.path.join(per_head_output_dir, f"layer_{layer_idx}_per_head_simulated_rope_qk_dot_product.pdf")
plt.savefig(svg_filename_per_head, format='svg')
plt.savefig(pdf_filename_per_head, format='pdf')
print(f"Saved per-head plot for layer {layer_idx} to {svg_filename_per_head} and {pdf_filename_per_head}")
plt.close()
else:
print(f"No data plotted for any head in layer {layer_idx}.")
plt.close()
print("\nRemoving hooks...")
for handle in hook_handles: handle.remove()
print("Hooks removed. Script finished.")
if __name__ == "__main__":
main()