#!/usr/bin/env python3
import os

# prevent HF tokenizers threads from hanging the process
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch.utils.data import TensorDataset, DataLoader
from transformers import AutoModelForCausalLM, get_linear_schedule_with_warmup
from peft import PeftModel
from torch.cuda.amp import GradScaler, autocast
from tqdm.auto import tqdm
from multiprocessing import freeze_support
import shutil       # for removing old checkpoints
import collections  # for deque


def main():
    # --- Config ---
    PRET_FILE = "pretokenized_queries.pt"
    MODEL_NAME = "google/gemma-3-1b-pt"
    LORA_DIR = "phase2_triplet_amp/final"  # Adapters from previous stage
    BATCH_SIZE = 200
    LR = 1e-5
    WEIGHT_DECAY = 0.01
    NUM_EPOCHS = 1  # 1 epoch is likely sufficient given fast convergence
    TEMP = 0.05
    OUTPUT_DIR = "phase3_self_contrast_wandb"
    GRAD_CLIP_NORM = 1.0
    SEED = 42
    WANDB_PROJECT = "query-encoder-phase3"

    # --- Checkpointing configuration ---
    SAVE_INTERVAL = 1000   # Save a checkpoint every N steps
    KEEP_LAST_CKPTS = 5    # Keep only the last N checkpoints (to save disk space)

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(SEED)

    # --- Initialize WandB ---
    wandb.init(
        project=WANDB_PROJECT,
        config={
            "model_name": MODEL_NAME,
            "lora_dir": LORA_DIR,
            "batch_size": BATCH_SIZE,
            "lr": LR,
            "num_epochs": NUM_EPOCHS,
            "seed": SEED,
            "save_interval_steps": SAVE_INTERVAL,
            "keep_last_checkpoints": KEEP_LAST_CKPTS,
        }
    )

    # --- Load pretokenized queries safely ---
    print(f"Loading pretokenized queries from {PRET_FILE}...")
    data = torch.load(PRET_FILE, weights_only=True)
    input_ids = data["input_ids"]
    attention_mask = data["attention_mask"]
    dataset = TensorDataset(input_ids, attention_mask)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    print(f"Loaded {len(dataset)} samples.")

    # --- Load base model + LoRA adapters from previous stage ---
    print(f"Loading base model '{MODEL_NAME}' and LoRA adapters from '{LORA_DIR}'...")
    base = AutoModelForCausalLM.from_pretrained(MODEL_NAME, attn_implementation="eager")
    peft = PeftModel.from_pretrained(base, LORA_DIR).to(device)
    print("LoRA adapters loaded.")

    # --- Projection head now outputs hidden_size ---
    class GemmaSelfContrast(nn.Module):
        def __init__(self, peft_model):
            super().__init__()
            self.peft = peft_model
            hs = peft_model.base_model.config.hidden_size
            self.proj = nn.Sequential(
                nn.Linear(hs, 512),
                nn.ReLU(),
                nn.Linear(512, hs),
            )

        def forward(self, ids, mask):
            out = self.peft.base_model(
                input_ids=ids,
                attention_mask=mask,
                output_hidden_states=True,
                return_dict=True
            )
            # Simple mean pooling over all sequence positions of the last hidden layer
            h = out.hidden_states[-1].mean(dim=1)
            h = torch.nan_to_num(h, nan=0.0, posinf=1e-6, neginf=-1e-6)
            z = self.proj(h)  # now (B, hidden_size)
            z = torch.nan_to_num(z, nan=0.0, posinf=1e-6, neginf=-1e-6)
            norm = z.norm(p=2, dim=1, keepdim=True).clamp_min(1e-6)
            return z / norm

    model = GemmaSelfContrast(peft).to(device)
    print("Encoder model (with projection head) initialized.")

    # Watch the model with wandb (optional, can be slow, but good for tracking gradients)
    # wandb.watch(model, log="all", log_freq=100)  # Commented out due to potential slowdown

    # --- Optimizer, scheduler, AMP scaler ---
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    total_steps = len(loader) * NUM_EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps
    )
    scaler = GradScaler()
    print(f"Training will run for {total_steps} steps.")

    # Deque to manage checkpoint paths and enforce keeping only the last N
    checkpoint_paths = collections.deque(maxlen=KEEP_LAST_CKPTS)

    # --- Training loop ---
    model.train()
    global_step = 0
    for epoch in range(1, NUM_EPOCHS + 1):
        total_loss = 0.0
        pbar = tqdm(loader, desc=f"Epoch {epoch}", unit="batch")
        for ids, mask in pbar:
            ids, mask = ids.to(device), mask.to(device)
            with autocast():
                # Two stochastic forward passes of the same batch form the positive pairs
                e1 = model(ids, mask)
                e2 = model(ids, mask)
                emb = torch.cat([e1, e2], dim=0)
                sim = (emb @ emb.T) / TEMP
                # mask diagonal with -inf so a sample is never its own positive
                mask_eye = torch.eye(sim.size(0), device=device, dtype=torch.bool)
                sim = sim.masked_fill(mask_eye, float('-inf'))
                B = e1.size(0)
                labels = torch.cat([
                    torch.arange(B, device=device) + B,
                    torch.arange(B, device=device)
                ], dim=0)
                loss = F.cross_entropy(sim, labels)

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)  # Unscale gradients before clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            # --- Log metrics to WandB at every step ---
            wandb.log({
                "train/loss": loss.item(),
                "train/lr": scheduler.get_last_lr()[0],
                "train/epoch": epoch,
                "train/global_step": global_step
            }, step=global_step)

            pbar.set_postfix({"loss": f"{loss.item():.4f}"})

            # --- PERIODIC SAVING BLOCK ---
            # Save a checkpoint every SAVE_INTERVAL steps
            if (global_step + 1) % SAVE_INTERVAL == 0:
                # Create a unique directory for this checkpoint
                ckpt_dir = os.path.join(OUTPUT_DIR, f"checkpoint-step-{global_step + 1}")
                os.makedirs(ckpt_dir, exist_ok=True)
                print(f"\nSaving checkpoint to {ckpt_dir}...")

                # Save the PEFT adapters
                peft.save_pretrained(ckpt_dir)
                # Save the trained projection head's state dictionary
                torch.save(model.proj.state_dict(), os.path.join(ckpt_dir, "encoder_proj.pth"))

                # Manage old checkpoints
                if len(checkpoint_paths) == KEEP_LAST_CKPTS:
                    oldest_ckpt = checkpoint_paths.popleft()  # Remove the oldest path from the deque
                    if os.path.isdir(oldest_ckpt):
                        print(f"Removing old checkpoint: {oldest_ckpt}")
                        shutil.rmtree(oldest_ckpt, ignore_errors=True)  # Delete the directory
                checkpoint_paths.append(ckpt_dir)  # Add new checkpoint path
                print("Checkpoint saved and old ones managed.")
            # --- END PERIODIC SAVING ---

            global_step += 1
            total_loss += loss.item()

        avg_loss = total_loss / len(loader)
        print(f"Epoch {epoch} training complete. Avg loss: {avg_loss:.6f}")
        # Log average epoch loss as well
        wandb.log({"train/epoch_avg_loss": avg_loss, "epoch": epoch}, step=global_step)

    # --- Final save to the 'final' directory ---
    # This ensures that even if you stop mid-epoch (after a checkpoint)
    # or run to completion, there is always a clear 'final' model.
    print("\nTraining finished. Saving final model to 'final' directory...")
    final_dir = os.path.join(OUTPUT_DIR, "final")
    os.makedirs(final_dir, exist_ok=True)

    # Save the LoRA adapters
    peft.save_pretrained(final_dir)
    # Save the trained projection head's state dictionary
    torch.save(model.proj.state_dict(), os.path.join(final_dir, "encoder_proj.pth"))
    print(f"Phase 3 complete. LoRA adapters and projection head saved to {final_dir}")

    # --- Finalize WandB run ---
    wandb.finish()


if __name__ == "__main__":
    freeze_support()
    main()
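

# --- Usage sketch (illustrative, not executed during training) ---
# A minimal sketch of how the artifacts written above (the LoRA adapters and
# "encoder_proj.pth") could be reloaded for downstream encoding. The helper
# names `load_phase3_encoder` and `encode_queries` are placeholders introduced
# here, not part of the training script's API; the projection-head layout and
# mean pooling simply mirror GemmaSelfContrast above. Treat this as an
# untested convenience sketch rather than a verified inference path.
def load_phase3_encoder(final_dir="phase3_self_contrast_wandb/final",
                        model_name="google/gemma-3-1b-pt",
                        device="cuda"):
    """Reload the base model, LoRA adapters, and projection head saved by this script."""
    base = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager")
    peft_model = PeftModel.from_pretrained(base, final_dir).to(device).eval()
    hs = peft_model.base_model.config.hidden_size
    # Must match the architecture defined in GemmaSelfContrast.__init__
    proj = nn.Sequential(nn.Linear(hs, 512), nn.ReLU(), nn.Linear(512, hs)).to(device).eval()
    proj.load_state_dict(torch.load(os.path.join(final_dir, "encoder_proj.pth"),
                                    map_location=device, weights_only=True))
    return peft_model, proj


@torch.no_grad()
def encode_queries(peft_model, proj, ids, mask):
    """Embed a pre-tokenized batch the same way the training forward pass does."""
    out = peft_model.base_model(input_ids=ids, attention_mask=mask,
                                output_hidden_states=True, return_dict=True)
    h = out.hidden_states[-1].mean(dim=1)   # mean-pool last hidden layer
    z = proj(h)
    return z / z.norm(p=2, dim=1, keepdim=True).clamp_min(1e-6)  # L2-normalized embeddings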