"""
Empirically measure how the post-RoPE query-key dot product of a Llama-style
model (e.g. Vicuna) varies with the relative position between query and key.
Pre-RoPE Q/K vectors are sampled from real text, RoPE is applied at
hypothetical positions, and the resulting dot products are plotted per layer
and per head.
"""

import torch
import functools
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from collections import defaultdict
import random
import matplotlib.pyplot as plt
import math
import os  # Added for directory creation

# --- Configuration ---
MODEL_NAME = "/jizhicfs/bojoli/vicuna-7b-1.5"  # Or any other Vicuna model based on Llama

# Dataset configuration
DATASET_NAME = "../../dataset/wikitext"  # Example: "wikitext", "c4", "glue", "squad" etc.
DATASET_CONFIG_NAME = "wikitext-2-raw-v1"
TEXT_COLUMN_NAME = "text"
NUM_TEXT_SAMPLES = 500
SAMPLED_QK_PAIRS_PER_LAYER_PER_TEXT = 10
MAX_SEQ_LENGTH = 1024
MAX_HYPOTHETICAL_RELATIVE_POS = 1024
REFERENCE_Q_POS = MAX_HYPOTHETICAL_RELATIVE_POS  # The fixed "original" position for Q

# --- Global storage for HOOKS (cleared per forward pass) ---
hook_captured_raw_q = {}
hook_captured_raw_k = {}

# --- Global storage for SAMPLED QK PAIRS (accumulated across texts) ---
all_sampled_pre_rope_qk_vec_pairs = []

# --- Final dot products ---
dot_products_by_hypo_rel_pos = defaultdict(lambda: defaultdict(list))
# New: Storage for per-head dot products
dot_products_per_head_by_hypo_rel_pos = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))

hook_handles = []


# --- Helper Functions ---
def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# --- Hook Function to capture raw q_proj and k_proj outputs ---
def raw_projection_capture_hook(storage_dict, module, input_args, output, layer_idx):
    storage_dict[layer_idx] = output.detach().cpu()


def get_texts_from_dataset(dataset_name, dataset_config_name, text_column, num_samples):
    print(f"Loading dataset: {dataset_name}, config: {dataset_config_name}")
    try:
        dataset = load_dataset(dataset_name, dataset_config_name, split="train", streaming=False)
    except Exception as e:
        print(f"Could not load train split. Error: {e}. Trying first available split.")
        try:
            dataset_info = load_dataset(dataset_name, dataset_config_name)
            split_name = list(dataset_info.keys())[0]
            dataset = load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=False)
        except Exception as e2:
            print(f"Error loading dataset: {e2}")
            return None

    if text_column not in dataset.column_names:
        print(f"Error: Text column '{text_column}' not found. Available: {dataset.column_names}")
        return None

    dataset_size = len(dataset)
    actual_num_samples = min(num_samples, dataset_size)
    # Hard cap on the number of sampled texts, regardless of NUM_TEXT_SAMPLES.
    actual_num_samples = min(actual_num_samples, 100)
    dataset = dataset.shuffle(seed=42).select(range(actual_num_samples))
    # Drop empty strings so that only non-empty texts are returned.
    sampled_texts = [item[text_column].strip() for item in dataset if item[text_column].strip()]
    print(f"Successfully sampled {len(sampled_texts)} non-empty text examples.")
    return sampled_texts


def main():
    # --- Phase 0: Load texts ---
    texts_to_process = get_texts_from_dataset(DATASET_NAME, DATASET_CONFIG_NAME,
                                              TEXT_COLUMN_NAME, NUM_TEXT_SAMPLES)
    if not texts_to_process:
        print("No texts to process. Exiting.")
        return
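
    # The forward hooks registered below capture the raw q_proj / k_proj outputs,
    # i.e. the query/key vectors *before* rotary position embeddings are applied,
    # so that RoPE can later be applied to them at arbitrary hypothetical positions.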
    # --- Phase 1: Load Model & Tokenizer, Register Hooks ---
    print(f"Loading model: {MODEL_NAME}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    except Exception as e:
        print(f"Error loading model or tokenizer: {e}")
        return

    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f"Model loaded on {device}")

    if hasattr(model, 'model') and hasattr(model.model, 'layers'):
        num_layers = len(model.model.layers)
    else:
        print("Model structure does not match expected Llama-like structure.")
        return

    print("Registering hooks on q_proj and k_proj layers...")
    for i, layer_module in enumerate(model.model.layers):
        if hasattr(layer_module, 'self_attn') and \
           hasattr(layer_module.self_attn, 'q_proj') and \
           hasattr(layer_module.self_attn, 'k_proj'):
            q_hook_fn = functools.partial(raw_projection_capture_hook, hook_captured_raw_q, layer_idx=i)
            q_handle = layer_module.self_attn.q_proj.register_forward_hook(q_hook_fn)
            hook_handles.append(q_handle)
            k_hook_fn = functools.partial(raw_projection_capture_hook, hook_captured_raw_k, layer_idx=i)
            k_handle = layer_module.self_attn.k_proj.register_forward_hook(k_hook_fn)
            hook_handles.append(k_handle)
    print(f"Registered {len(hook_handles)} hooks.")

    # --- Phase 2: Data Collection and QK Pair Sampling ---
    print("\n--- Starting Data Collection and QK Pair Sampling ---")
    for text_idx, current_text in enumerate(texts_to_process):
        print(f"Processing text sample {text_idx + 1}/{len(texts_to_process)}")
        hook_captured_raw_q.clear()
        hook_captured_raw_k.clear()

        inputs = tokenizer(current_text, return_tensors="pt", padding="longest",
                           truncation=True, max_length=MAX_SEQ_LENGTH).to(device)
        with torch.no_grad():
            try:
                model(**inputs)  # Triggers hooks
            except Exception as e:
                print(f"  Error during model forward pass: {e}")
                continue

        for layer_idx in range(num_layers):
            if layer_idx not in hook_captured_raw_q or layer_idx not in hook_captured_raw_k:
                continue
            raw_q_proj = hook_captured_raw_q[layer_idx].to(device)
            raw_k_proj = hook_captured_raw_k[layer_idx].to(device)
            attn_module = model.model.layers[layer_idx].self_attn

            bsz, seq_len, _ = raw_q_proj.shape
            q_reshaped = raw_q_proj.view(bsz, seq_len, attn_module.num_heads, attn_module.head_dim).transpose(1, 2)
            k_reshaped = raw_k_proj.view(bsz, seq_len, attn_module.num_key_value_heads, attn_module.head_dim).transpose(1, 2)

            possible_indices = []
            for b_idx in range(bsz):
                for i_idx in range(seq_len):
                    for j_idx in range(seq_len):
                        if i_idx == j_idx and seq_len == 1:
                            possible_indices.append((b_idx, i_idx, j_idx))
                        elif i_idx != j_idx:
                            possible_indices.append((b_idx, i_idx, j_idx))

            num_to_sample = min(len(possible_indices), SAMPLED_QK_PAIRS_PER_LAYER_PER_TEXT)
            if not possible_indices or num_to_sample == 0:
                continue
            selected_indices = random.sample(possible_indices, num_to_sample)

            for b_sel, i_sel, j_sel in selected_indices:
                q_vec = q_reshaped[b_sel, :, i_sel, :].detach().clone()
                k_vec = k_reshaped[b_sel, :, j_sel, :].detach().clone()
                all_sampled_pre_rope_qk_vec_pairs.append({
                    'layer_idx': layer_idx,
                    'q_vec': q_vec,
                    'k_vec': k_vec,
                    'num_heads': attn_module.num_heads,
                    'head_dim': attn_module.head_dim,
                    'num_kv_heads': attn_module.num_key_value_heads
                })

    print(f"\nCollected {len(all_sampled_pre_rope_qk_vec_pairs)} (Q,K) vector pairs for RoPE simulation.")

    # --- Phase 3: RoPE Simulation and Dot Product Calculation ---
    print("\n--- Starting RoPE Simulation and Dot Product Calculation ---")
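
    # For each sampled pre-RoPE (Q, K) pair, Q is placed at the fixed position
    # REFERENCE_Q_POS and K at REFERENCE_Q_POS - d for every hypothetical relative
    # distance d. Because RoPE makes the QK dot product depend only on the relative
    # offset, this isolates how the attention logit varies with d.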
    if not all_sampled_pre_rope_qk_vec_pairs:
        print("No QK pairs were sampled. Exiting simulation phase.")
    else:
        # Pre-compute per-layer RoPE cos/sin caches covering positions 0..REFERENCE_Q_POS.
        layer_rope_caches = {}
        model_dtype = model.dtype
        for layer_idx in range(num_layers):
            attn_module = model.model.layers[layer_idx].self_attn
            rotary_emb = attn_module.rotary_emb
            max_pos_for_cache = REFERENCE_Q_POS + 1
            dummy_x_for_rope = torch.zeros(
                (1, 1, max_pos_for_cache, attn_module.head_dim),
                device=device, dtype=model_dtype
            )
            position_ids_for_cache = torch.arange(
                max_pos_for_cache, device=device, dtype=torch.long
            ).unsqueeze(0)
            cos_cached_for_layer, sin_cached_for_layer = rotary_emb(dummy_x_for_rope, position_ids_for_cache)
            layer_rope_caches[layer_idx] = (cos_cached_for_layer, sin_cached_for_layer)

        for pair_info in all_sampled_pre_rope_qk_vec_pairs:
            layer_idx = pair_info['layer_idx']
            q_vec_pre_rope = pair_info['q_vec']
            k_vec_pre_rope = pair_info['k_vec']
            num_h, head_dim = pair_info['num_heads'], pair_info['head_dim']
            num_kv_h = pair_info['num_kv_heads']
            num_key_value_groups = num_h // num_kv_h
            cos_cache_full, sin_cache_full = layer_rope_caches[layer_idx]

            q_for_rope = q_vec_pre_rope.unsqueeze(0).unsqueeze(2)
            k_for_rope = k_vec_pre_rope.unsqueeze(0).unsqueeze(2)

            for hypo_rel_pos in range(1, MAX_HYPOTHETICAL_RELATIVE_POS + 1):
                pos_q_prime = REFERENCE_Q_POS
                pos_k_prime = REFERENCE_Q_POS - hypo_rel_pos
                if pos_k_prime < 0:
                    continue
                if pos_q_prime >= cos_cache_full.shape[1] or pos_k_prime >= cos_cache_full.shape[1]:
                    print(f"Warning: Position out of bounds for cached RoPE. "
                          f"pos_q_prime={pos_q_prime}, pos_k_prime={pos_k_prime}, cache_len={cos_cache_full.shape[1]}")
                    continue

                cos_q = cos_cache_full[:, pos_q_prime:pos_q_prime + 1, :]
                sin_q = sin_cache_full[:, pos_q_prime:pos_q_prime + 1, :]
                q_rope = (q_for_rope * cos_q) + (rotate_half(q_for_rope) * sin_q)

                cos_k = cos_cache_full[:, pos_k_prime:pos_k_prime + 1, :]
                sin_k = sin_cache_full[:, pos_k_prime:pos_k_prime + 1, :]
                k_rope = (k_for_rope * cos_k) + (rotate_half(k_for_rope) * sin_k)

                q_final = q_rope.squeeze(2).squeeze(0)
                k_final = k_rope.squeeze(2).squeeze(0)

                if num_key_value_groups > 1:
                    k_final_expanded = k_final.repeat_interleave(num_key_value_groups, dim=0)
                else:
                    k_final_expanded = k_final

                dot_products_all_heads = torch.sum(q_final * k_final_expanded, dim=-1)  # Shape: (num_h,)
                avg_dot_product = dot_products_all_heads.mean().item()
                dot_products_by_hypo_rel_pos[layer_idx][hypo_rel_pos].append(avg_dot_product)

                # New: Store per-head dot products
                for head_idx in range(num_h):
                    dot_product_for_head = dot_products_all_heads[head_idx].item()
                    dot_products_per_head_by_hypo_rel_pos[layer_idx][head_idx][hypo_rel_pos].append(dot_product_for_head)

    # --- Phase 4a: Plotting Combined Results ---
    print("\n--- Plotting Combined Results ---")
    plt.figure(figsize=(15, 8))
    layers_plotted_combined = 0
    for layer_idx in range(num_layers):
        if not dot_products_by_hypo_rel_pos[layer_idx]:
            continue
        avg_dot_products_for_layer = {}
        rel_positions_sorted = sorted(dot_products_by_hypo_rel_pos[layer_idx].keys())
        if not rel_positions_sorted:
            continue
        for rel_pos in rel_positions_sorted:
            if dot_products_by_hypo_rel_pos[layer_idx][rel_pos]:
                avg_dot_products_for_layer[rel_pos] = (
                    sum(dot_products_by_hypo_rel_pos[layer_idx][rel_pos])
                    / len(dot_products_by_hypo_rel_pos[layer_idx][rel_pos])
                )
        if not avg_dot_products_for_layer:
            continue
        x_vals = list(avg_dot_products_for_layer.keys())
        y_vals = list(avg_dot_products_for_layer.values())
        plt.plot(x_vals, y_vals, marker='.', linestyle='-', label=f'Layer {layer_idx}', alpha=0.7)
        # Count a layer only after it has actually been plotted.
        layers_plotted_combined += 1
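
    # Finalize and save the combined figure only if at least one layer produced data.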
    if layers_plotted_combined > 0:
        plt.title('Avg QK Dot Product (Simulated RoPE) vs. Hypothetical Relative Position (All Layers)')
        plt.xlabel('Hypothetical Relative Position (d = Q_pos - K_pos)')
        plt.ylabel('Average Dot Product (across heads & samples)')
        plt.grid(True)
        plt.xlim(0, MAX_HYPOTHETICAL_RELATIVE_POS + 1)
        num_legend_cols = math.ceil(layers_plotted_combined / 20)
        plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), ncol=num_legend_cols)
        plt.tight_layout(rect=[0, 0, 0.85 if num_legend_cols <= 2 else 0.75, 1])
        svg_filename = "all_layers_simulated_rope_qk_dot_product_corrected.svg"
        pdf_filename = "all_layers_simulated_rope_qk_dot_product_corrected.pdf"
        plt.savefig(svg_filename, format='svg')
        plt.savefig(pdf_filename, format='pdf')
        print(f"Saved combined plot to {svg_filename} and {pdf_filename}")
        plt.close()
    else:
        print("No data collected for any layer to plot (combined).")
        plt.close()

    # --- Phase 4b: Plotting Per-Head Results ---
    print("\n--- Plotting Per-Head Results ---")
    per_head_output_dir = "per_head"
    os.makedirs(per_head_output_dir, exist_ok=True)

    for layer_idx in range(num_layers):
        if not dot_products_per_head_by_hypo_rel_pos[layer_idx]:
            print(f"No per-head data for layer {layer_idx} to plot.")
            continue

        plt.figure(figsize=(15, 8))
        attn_module = model.model.layers[layer_idx].self_attn
        num_heads_in_layer = attn_module.num_heads
        heads_plotted_for_layer = 0

        for head_idx in range(num_heads_in_layer):
            if not dot_products_per_head_by_hypo_rel_pos[layer_idx][head_idx]:
                continue
            avg_dot_products_for_head = {}
            rel_positions_sorted = sorted(dot_products_per_head_by_hypo_rel_pos[layer_idx][head_idx].keys())
            if not rel_positions_sorted:
                continue
            for rel_pos in rel_positions_sorted:
                if dot_products_per_head_by_hypo_rel_pos[layer_idx][head_idx][rel_pos]:
                    avg_dot_products_for_head[rel_pos] = (
                        sum(dot_products_per_head_by_hypo_rel_pos[layer_idx][head_idx][rel_pos])
                        / len(dot_products_per_head_by_hypo_rel_pos[layer_idx][head_idx][rel_pos])
                    )
            if not avg_dot_products_for_head:
                continue
            x_vals = list(avg_dot_products_for_head.keys())
            y_vals = list(avg_dot_products_for_head.values())
            plt.plot(x_vals, y_vals, marker='.', linestyle='-', label=f'Head {head_idx}', alpha=0.6)
            heads_plotted_for_layer += 1

        if heads_plotted_for_layer > 0:
            plt.title(f'Layer {layer_idx}: Avg QK Dot Product vs. Hypothetical Relative Position (Per Head)')
            plt.xlabel('Hypothetical Relative Position (d = Q_pos - K_pos)')
            plt.ylabel('Average Dot Product (across samples)')
            plt.grid(True)
            plt.xlim(0, MAX_HYPOTHETICAL_RELATIVE_POS + 1)
            num_legend_cols_per_head = math.ceil(heads_plotted_for_layer / 20)
            plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), ncol=num_legend_cols_per_head)
            plt.tight_layout(rect=[0, 0, 0.85 if num_legend_cols_per_head <= 2 else 0.75, 1])
            svg_filename_per_head = os.path.join(per_head_output_dir, f"layer_{layer_idx}_per_head_simulated_rope_qk_dot_product.svg")
            pdf_filename_per_head = os.path.join(per_head_output_dir, f"layer_{layer_idx}_per_head_simulated_rope_qk_dot_product.pdf")
            plt.savefig(svg_filename_per_head, format='svg')
            plt.savefig(pdf_filename_per_head, format='pdf')
            print(f"Saved per-head plot for layer {layer_idx} to {svg_filename_per_head} and {pdf_filename_per_head}")
            plt.close()
        else:
            print(f"No data plotted for any head in layer {layer_idx}.")
            plt.close()

    print("\nRemoving hooks...")
    for handle in hook_handles:
        handle.remove()
    print("Hooks removed. Script finished.")


if __name__ == "__main__":
    main()