import torch
import functools
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from collections import defaultdict
import random
import matplotlib.pyplot as plt
import math
import os
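

# This script probes how rotary position embeddings (RoPE) shape query-key
# interactions. It captures the pre-RoPE q_proj/k_proj outputs of a Llama-style
# model via forward hooks, re-applies RoPE with the query fixed at
# REFERENCE_Q_POS and the key at REFERENCE_Q_POS - d for each hypothetical
# relative distance d, and plots the resulting dot products per layer and per head.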

MODEL_NAME = "/jizhicfs/bojoli/vicuna-7b-1.5"

DATASET_NAME = "../../dataset/wikitext"
DATASET_CONFIG_NAME = "wikitext-2-raw-v1"
TEXT_COLUMN_NAME = "text"
NUM_TEXT_SAMPLES = 500
SAMPLED_QK_PAIRS_PER_LAYER_PER_TEXT = 10
MAX_SEQ_LENGTH = 1024
MAX_HYPOTHETICAL_RELATIVE_POS = 1024
REFERENCE_Q_POS = MAX_HYPOTHETICAL_RELATIVE_POS
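
# Global accumulators:
#   hook_captured_raw_q / hook_captured_raw_k: layer_idx -> raw projection output
#   all_sampled_pre_rope_qk_vec_pairs: sampled pre-RoPE (Q, K) vector pairs
#   dot_products_by_hypo_rel_pos: layer_idx -> rel_pos -> head-averaged dot products
#   dot_products_per_head_by_hypo_rel_pos: layer_idx -> head_idx -> rel_pos -> dot products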
hook_captured_raw_q = {}
hook_captured_raw_k = {}

all_sampled_pre_rope_qk_vec_pairs = []

dot_products_by_hypo_rel_pos = defaultdict(lambda: defaultdict(list))
dot_products_per_head_by_hypo_rel_pos = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
hook_handles = []


def rotate_half(x):
    """Rotate half the hidden dims of the input (same convention as Llama's RoPE)."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def raw_projection_capture_hook(storage_dict, module, input_args, output, layer_idx):
    """Forward hook that stores the raw q_proj/k_proj output of a layer on the CPU."""
    storage_dict[layer_idx] = output.detach().cpu()


def get_texts_from_dataset(dataset_name, dataset_config_name, text_column, num_samples):
    print(f"Loading dataset: {dataset_name}, config: {dataset_config_name}")
    try:
        dataset = load_dataset(dataset_name, dataset_config_name, split="train", streaming=False)
    except Exception as e:
        print(f"Could not load train split. Error: {e}. Trying first available split.")
        try:
            dataset_info = load_dataset(dataset_name, dataset_config_name)
            split_name = list(dataset_info.keys())[0]
            dataset = load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=False)
        except Exception as e2:
            print(f"Error loading dataset: {e2}")
            return None

    if text_column not in dataset.column_names:
        print(f"Error: Text column '{text_column}' not found. Available: {dataset.column_names}")
        return None
    dataset_size = len(dataset)
    actual_num_samples = min(num_samples, dataset_size)
    # Hard cap to keep the runtime manageable, regardless of num_samples.
    if actual_num_samples > 100:
        actual_num_samples = 100
    dataset = dataset.shuffle(seed=42).select(range(actual_num_samples))

    # Keep only non-empty texts (WikiText contains many blank lines).
    sampled_texts = [item[text_column].strip() for item in dataset if item[text_column].strip()]
    print(f"Successfully sampled {len(sampled_texts)} non-empty text examples.")
    return sampled_texts


def main():
    texts_to_process = get_texts_from_dataset(DATASET_NAME, DATASET_CONFIG_NAME, TEXT_COLUMN_NAME, NUM_TEXT_SAMPLES)
    if not texts_to_process:
        print("No texts to process. Exiting.")
        return

    print(f"Loading model: {MODEL_NAME}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    except Exception as e:
        print(f"Error loading model or tokenizer: {e}")
        return

    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f"Model loaded on {device}")

    if hasattr(model, 'model') and hasattr(model.model, 'layers'):
        num_layers = len(model.model.layers)
    else:
        print("Model structure does not match expected Llama-like structure.")
        return
    print("Registering hooks on q_proj and k_proj layers...")
    for i, layer_module in enumerate(model.model.layers):
        if hasattr(layer_module, 'self_attn') and \
           hasattr(layer_module.self_attn, 'q_proj') and \
           hasattr(layer_module.self_attn, 'k_proj'):

            q_hook_fn = functools.partial(raw_projection_capture_hook, hook_captured_raw_q, layer_idx=i)
            q_handle = layer_module.self_attn.q_proj.register_forward_hook(q_hook_fn)
            hook_handles.append(q_handle)

            k_hook_fn = functools.partial(raw_projection_capture_hook, hook_captured_raw_k, layer_idx=i)
            k_handle = layer_module.self_attn.k_proj.register_forward_hook(k_hook_fn)
            hook_handles.append(k_handle)
    print(f"Registered {len(hook_handles)} hooks.")
    print("\n--- Starting Data Collection and QK Pair Sampling ---")
    for text_idx, current_text in enumerate(texts_to_process):
        print(f"Processing text sample {text_idx + 1}/{len(texts_to_process)}")

        hook_captured_raw_q.clear()
        hook_captured_raw_k.clear()

        inputs = tokenizer(current_text, return_tensors="pt", padding="longest", truncation=True, max_length=MAX_SEQ_LENGTH).to(device)

        with torch.no_grad():
            try:
                model(**inputs)
            except Exception as e:
                print(f"  Error during model forward pass: {e}")
                continue

        for layer_idx in range(num_layers):
            if layer_idx not in hook_captured_raw_q or layer_idx not in hook_captured_raw_k:
                continue

            raw_q_proj = hook_captured_raw_q[layer_idx].to(device)
            raw_k_proj = hook_captured_raw_k[layer_idx].to(device)

            attn_module = model.model.layers[layer_idx].self_attn
            bsz, seq_len, _ = raw_q_proj.shape

            q_reshaped = raw_q_proj.view(bsz, seq_len, attn_module.num_heads, attn_module.head_dim).transpose(1, 2)
            k_reshaped = raw_k_proj.view(bsz, seq_len, attn_module.num_key_value_heads, attn_module.head_dim).transpose(1, 2)
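
            # Enumerate all candidate (query_pos, key_pos) pairs for this sequence.
            # This is O(seq_len^2) pairs; only a small random subset is kept below.
            # The original token positions are deliberately discarded: only the
            # pre-RoPE vectors matter, since RoPE is re-applied later at hypothetical positions.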
            possible_indices = []
            for b_idx in range(bsz):
                for i_idx in range(seq_len):
                    for j_idx in range(seq_len):
                        if i_idx == j_idx and seq_len == 1:
                            possible_indices.append((b_idx, i_idx, j_idx))
                        elif i_idx != j_idx:
                            possible_indices.append((b_idx, i_idx, j_idx))

            num_to_sample = min(len(possible_indices), SAMPLED_QK_PAIRS_PER_LAYER_PER_TEXT)
            if not possible_indices or num_to_sample == 0:
                continue

            selected_indices = random.sample(possible_indices, num_to_sample)

            for b_sel, i_sel, j_sel in selected_indices:
                q_vec = q_reshaped[b_sel, :, i_sel, :].detach().clone()
                k_vec = k_reshaped[b_sel, :, j_sel, :].detach().clone()
                all_sampled_pre_rope_qk_vec_pairs.append({
                    'layer_idx': layer_idx,
                    'q_vec': q_vec,
                    'k_vec': k_vec,
                    'num_heads': attn_module.num_heads,
                    'head_dim': attn_module.head_dim,
                    'num_kv_heads': attn_module.num_key_value_heads
                })
    print(f"\nCollected {len(all_sampled_pre_rope_qk_vec_pairs)} (Q,K) vector pairs for RoPE simulation.")
    print("\n--- Starting RoPE Simulation and Dot Product Calculation ---")
    if not all_sampled_pre_rope_qk_vec_pairs:
        print("No QK pairs were sampled. Exiting simulation phase.")
    else:
        layer_rope_caches = {}
        model_dtype = model.dtype

        for layer_idx in range(num_layers):
            attn_module = model.model.layers[layer_idx].self_attn
            rotary_emb = attn_module.rotary_emb

            max_pos_for_cache = REFERENCE_Q_POS + 1

            dummy_x_for_rope = torch.zeros(
                (1, 1, max_pos_for_cache, attn_module.head_dim),
                device=device,
                dtype=model_dtype
            )
            position_ids_for_cache = torch.arange(
                max_pos_for_cache,
                device=device,
                dtype=torch.long
            ).unsqueeze(0)

            cos_cached_for_layer, sin_cached_for_layer = rotary_emb(dummy_x_for_rope, position_ids_for_cache)
            layer_rope_caches[layer_idx] = (cos_cached_for_layer, sin_cached_for_layer)
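
        # For each sampled (Q, K) pair, place the query at REFERENCE_Q_POS and the key
        # at REFERENCE_Q_POS - d, apply RoPE to both, and record the per-head dot
        # products. Because RoPE encodes relative position, the result depends on the
        # offset d rather than on the absolute positions chosen here.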
        for pair_info in all_sampled_pre_rope_qk_vec_pairs:
            layer_idx = pair_info['layer_idx']
            q_vec_pre_rope = pair_info['q_vec']
            k_vec_pre_rope = pair_info['k_vec']
            num_h, head_dim = pair_info['num_heads'], pair_info['head_dim']
            num_kv_h = pair_info['num_kv_heads']
            num_key_value_groups = num_h // num_kv_h

            cos_cache_full, sin_cache_full = layer_rope_caches[layer_idx]

            q_for_rope = q_vec_pre_rope.unsqueeze(0).unsqueeze(2)
            k_for_rope = k_vec_pre_rope.unsqueeze(0).unsqueeze(2)

            for hypo_rel_pos in range(1, MAX_HYPOTHETICAL_RELATIVE_POS + 1):
                pos_q_prime = REFERENCE_Q_POS
                pos_k_prime = REFERENCE_Q_POS - hypo_rel_pos

                if pos_k_prime < 0:
                    continue
                if pos_q_prime >= cos_cache_full.shape[1] or pos_k_prime >= cos_cache_full.shape[1]:
                    print(f"Warning: Position out of bounds for cached RoPE. pos_q_prime={pos_q_prime}, pos_k_prime={pos_k_prime}, cache_len={cos_cache_full.shape[1]}")
                    continue

                cos_q = cos_cache_full[:, pos_q_prime:pos_q_prime + 1, :]
                sin_q = sin_cache_full[:, pos_q_prime:pos_q_prime + 1, :]
                q_rope = (q_for_rope * cos_q) + (rotate_half(q_for_rope) * sin_q)

                cos_k = cos_cache_full[:, pos_k_prime:pos_k_prime + 1, :]
                sin_k = sin_cache_full[:, pos_k_prime:pos_k_prime + 1, :]
                k_rope = (k_for_rope * cos_k) + (rotate_half(k_for_rope) * sin_k)

                q_final = q_rope.squeeze(2).squeeze(0)
                k_final = k_rope.squeeze(2).squeeze(0)
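
                # With grouped-query attention, repeat each K head across its query group
                # so shapes match; for Vicuna-7B (standard multi-head attention) the group
                # size should be 1 and this is a no-op.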
                if num_key_value_groups > 1:
                    k_final_expanded = k_final.repeat_interleave(num_key_value_groups, dim=0)
                else:
                    k_final_expanded = k_final

                dot_products_all_heads = torch.sum(q_final * k_final_expanded, dim=-1)
                avg_dot_product = dot_products_all_heads.mean().item()
                dot_products_by_hypo_rel_pos[layer_idx][hypo_rel_pos].append(avg_dot_product)

                for head_idx in range(num_h):
                    dot_product_for_head = dot_products_all_heads[head_idx].item()
                    dot_products_per_head_by_hypo_rel_pos[layer_idx][head_idx][hypo_rel_pos].append(dot_product_for_head)
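
    # Combined plot: one curve per layer, averaging the dot products over heads and
    # samples for each hypothetical relative position.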
    print("\n--- Plotting Combined Results ---")
    plt.figure(figsize=(15, 8))
    layers_plotted_combined = 0
    for layer_idx in range(num_layers):
        if not dot_products_by_hypo_rel_pos[layer_idx]:
            continue
        layers_plotted_combined += 1
        avg_dot_products_for_layer = {}
        rel_positions_sorted = sorted(dot_products_by_hypo_rel_pos[layer_idx].keys())
        if not rel_positions_sorted:
            continue
        for rel_pos in rel_positions_sorted:
            if dot_products_by_hypo_rel_pos[layer_idx][rel_pos]:
                avg_dot_products_for_layer[rel_pos] = sum(dot_products_by_hypo_rel_pos[layer_idx][rel_pos]) / len(dot_products_by_hypo_rel_pos[layer_idx][rel_pos])
        if not avg_dot_products_for_layer:
            continue
        x_vals = list(avg_dot_products_for_layer.keys())
        y_vals = list(avg_dot_products_for_layer.values())
        plt.plot(x_vals, y_vals, marker='.', linestyle='-', label=f'Layer {layer_idx}', alpha=0.7)

    if layers_plotted_combined > 0:
        plt.title('Avg QK Dot Product (Simulated RoPE) vs. Hypothetical Relative Position (All Layers)')
        plt.xlabel('Hypothetical Relative Position (d = Q_pos - K_pos)')
        plt.ylabel('Average Dot Product (across heads & samples)')
        plt.grid(True)
        plt.xlim(0, MAX_HYPOTHETICAL_RELATIVE_POS + 1)

        num_legend_cols = math.ceil(layers_plotted_combined / 20)
        plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), ncol=num_legend_cols)
        plt.tight_layout(rect=[0, 0, 0.85 if num_legend_cols <= 2 else 0.75, 1])

        svg_filename = "all_layers_simulated_rope_qk_dot_product_corrected.svg"
        pdf_filename = "all_layers_simulated_rope_qk_dot_product_corrected.pdf"
        plt.savefig(svg_filename, format='svg')
        plt.savefig(pdf_filename, format='pdf')
        print(f"Saved combined plot to {svg_filename} and {pdf_filename}")
        plt.close()
    else:
        print("No data collected for any layer to plot (combined).")
    print("\n--- Plotting Per-Head Results ---")
    per_head_output_dir = "per_head"
    os.makedirs(per_head_output_dir, exist_ok=True)

    for layer_idx in range(num_layers):
        if not dot_products_per_head_by_hypo_rel_pos[layer_idx]:
            print(f"No per-head data for layer {layer_idx} to plot.")
            continue

        plt.figure(figsize=(15, 8))

        attn_module = model.model.layers[layer_idx].self_attn
        num_heads_in_layer = attn_module.num_heads

        heads_plotted_for_layer = 0
        for head_idx in range(num_heads_in_layer):
            if not dot_products_per_head_by_hypo_rel_pos[layer_idx][head_idx]:
                continue

            avg_dot_products_for_head = {}
            rel_positions_sorted = sorted(dot_products_per_head_by_hypo_rel_pos[layer_idx][head_idx].keys())
            if not rel_positions_sorted:
                continue

            for rel_pos in rel_positions_sorted:
                if dot_products_per_head_by_hypo_rel_pos[layer_idx][head_idx][rel_pos]:
                    avg_dot_products_for_head[rel_pos] = sum(dot_products_per_head_by_hypo_rel_pos[layer_idx][head_idx][rel_pos]) / len(dot_products_per_head_by_hypo_rel_pos[layer_idx][head_idx][rel_pos])

            if not avg_dot_products_for_head:
                continue

            x_vals = list(avg_dot_products_for_head.keys())
            y_vals = list(avg_dot_products_for_head.values())
            plt.plot(x_vals, y_vals, marker='.', linestyle='-', label=f'Head {head_idx}', alpha=0.6)
            heads_plotted_for_layer += 1

        if heads_plotted_for_layer > 0:
            plt.title(f'Layer {layer_idx}: Avg QK Dot Product vs. Hypothetical Relative Position (Per Head)')
            plt.xlabel('Hypothetical Relative Position (d = Q_pos - K_pos)')
            plt.ylabel('Average Dot Product (across samples)')
            plt.grid(True)
            plt.xlim(0, MAX_HYPOTHETICAL_RELATIVE_POS + 1)

            num_legend_cols_per_head = math.ceil(heads_plotted_for_layer / 20)
            plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), ncol=num_legend_cols_per_head)
            plt.tight_layout(rect=[0, 0, 0.85 if num_legend_cols_per_head <= 2 else 0.75, 1])

            svg_filename_per_head = os.path.join(per_head_output_dir, f"layer_{layer_idx}_per_head_simulated_rope_qk_dot_product.svg")
            pdf_filename_per_head = os.path.join(per_head_output_dir, f"layer_{layer_idx}_per_head_simulated_rope_qk_dot_product.pdf")
            plt.savefig(svg_filename_per_head, format='svg')
            plt.savefig(pdf_filename_per_head, format='pdf')
            print(f"Saved per-head plot for layer {layer_idx} to {svg_filename_per_head} and {pdf_filename_per_head}")
            plt.close()
        else:
            print(f"No data plotted for any head in layer {layer_idx}.")
            plt.close()
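
    # Clean up: detach the forward hooks so the model can be reused without side effects.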
    print("\nRemoving hooks...")
    for handle in hook_handles:
        handle.remove()
    print("Hooks removed. Script finished.")


if __name__ == "__main__":
    main()