import os
import torch
from datasets import load_from_disk, concatenate_datasets, Dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig
from peft import LoraConfig, prepare_model_for_kbit_training, PeftModel
from peft.tuners.lora import LoraLayer
from trl import SFTTrainer, SFTConfig
import logging
import torch.distributed as dist
from datetime import timedelta, datetime
import time
from transformers.trainer import TrainerCallback
import gc
import sys
import shutil  # For handling file operations
import glob  # For file pattern matching
import threading  # For background cleanup
import multiprocessing
import subprocess
import tempfile
import json
import random
import math
import queue
import numpy as np

# Import the specific layer class for FSDP wrapping
try:
    from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer
except ImportError:
    logging.warning("Could not import Qwen2DecoderLayer. FSDP wrapping might fail.")
    Qwen2DecoderLayer = None

# Configure more detailed logging with timestamps
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    stream=sys.stdout,  # Ensure logs go to stdout for immediate visibility
    force=True
)

# Set up temporary directory for cache files
temp_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "temp")
os.makedirs(temp_dir, exist_ok=True)
logging.info(f"Using temporary directory: {temp_dir}")

# Set environment variables to control temporary file creation
os.environ["TMPDIR"] = temp_dir  # Unix
os.environ["TEMP"] = temp_dir    # Windows
os.environ["TMP"] = temp_dir     # Windows alternative

# Set default cache locations
hf_datasets_cache_path = os.path.join(temp_dir, "hf_datasets_cache")
transformers_cache_path = os.path.join(temp_dir, "transformers_cache")
hf_home_path = os.path.join(temp_dir, "hf_home")
os.makedirs(hf_datasets_cache_path, exist_ok=True)
os.makedirs(transformers_cache_path, exist_ok=True)
os.makedirs(hf_home_path, exist_ok=True)
os.environ["HF_DATASETS_CACHE"] = hf_datasets_cache_path
os.environ["TRANSFORMERS_CACHE"] = transformers_cache_path
os.environ["HF_HOME"] = hf_home_path
logging.info(f"Hugging Face Datasets cache directed to: {hf_datasets_cache_path}")
logging.info(f"Hugging Face Transformers cache directed to: {transformers_cache_path}")

# Keep forcing Arrow to use system memory pool if possible
os.environ["ARROW_DEFAULT_MEMORY_POOL"] = "system"
logging.info("Configured temporary directory and cache locations.")

# Set environment variable to control PyTorch's memory allocator
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:512"

# Disable PYTORCH_NO_CUDA_MEMORY_CACHING for better performance
if "PYTORCH_NO_CUDA_MEMORY_CACHING" in os.environ:
    del os.environ["PYTORCH_NO_CUDA_MEMORY_CACHING"]

# Set a longer timeout for NCCL operations
os.environ["NCCL_BLOCKING_WAIT"] = "1"
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
os.environ["NCCL_TIMEOUT"] = "3600"  # 1 hour timeout for NCCL operations


# Initialize distributed environment with better error handling
def init_distributed():
    try:
        # Check if we're in a distributed training environment
        if "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1:
            # Set memory optimization environment variables
            if int(os.environ.get("LOCAL_RANK", 0)) == 0:
                logging.info("Setting PyTorch memory optimizations for H200 GPUs")

            # Empty CUDA cache before initializing process group
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                logging.info("CUDA cache cleared")

            local_rank = int(os.environ.get("LOCAL_RANK", 0))
            world_size = int(os.environ.get("WORLD_SIZE", 1))
            rank = int(os.environ.get("RANK", 0))
            logging.info(f"Initializing distributed training for 8x H200s. Rank: {rank}, Local Rank: {local_rank}, World Size: {world_size}")

            # Set the device for this process explicitly before initializing
            torch.cuda.set_device(local_rank)
            logging.info(f"Setting device {local_rank} for process rank {rank}")

            # Set a longer timeout to handle long operations (3 hours)
            timeout = timedelta(hours=3)

            # Initialize the distributed process group
            dist.init_process_group(
                backend='nccl',
                init_method='env://',
                timeout=timeout,
                rank=rank,
                world_size=world_size
            )

            # Verify initialization was successful
            if dist.is_initialized():
                logging.info(f"Successfully initialized distributed process group. Rank: {rank}, Device: {torch.cuda.current_device()}")
                # Log NCCL environment
                logging.info(f"NCCL Version: {torch.cuda.nccl.version() if hasattr(torch.cuda, 'nccl') else 'unknown'}")
                logging.info(f"CUDA Device Count: {torch.cuda.device_count()}")
                logging.info(f"CUDA Device Name: {torch.cuda.get_device_name(local_rank)}")
            else:
                logging.error(f"Failed to initialize distributed process group. Rank: {rank}")

            # Ensure all processes can communicate with specified device
            try:
                device_ids = [local_rank]
                dist.barrier(device_ids=device_ids)
                logging.info(f"Communication test successful. Process {rank} on device {local_rank} can communicate.")
            except Exception as e:
                logging.error(f"Communication test failed. Processes cannot communicate: {str(e)}. Rank: {rank}")
                raise

            return True
        else:
            logging.info("Not running in distributed mode.")
            return False
    except Exception as e:
        logging.error(f"Error initializing distributed environment: {str(e)}")
        raise


# Initialize distributed environment
distributed_mode = init_distributed()

# --- Configuration ---
# Model ID updated based on user input
MODEL_ID = "Qwen/QwQ-32B"
# Path to the processed dataset created by preprocess_data.py
DATASET_PATH = "./processed_datasets/combined_code_finetune_data"
# Number of examples to use (set to -1 for all)
MAX_EXAMPLES = -1  # Use all examples by default

# LoRA configuration (Optimized for 8x H200 GPUs)
LORA_R = 64          # Doubled to increase parameter count significantly
LORA_ALPHA = 128     # Increased alpha to match r
LORA_DROPOUT = 0.05  # Dropout probability for LoRA layers
# Target modules might need verification for QwQ-32B specifically.
# Common targets for Qwen models:
LORA_TARGET_MODULES = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
    # "embed_tokens",  # Removed to reduce overhead/complexity
    # "lm_head",       # Removed to reduce overhead/complexity
]

# Training arguments optimized for 8x H200 GPUs with memory constraints
OUTPUT_DIR = "./qwq-32b-finetuned-adapters"
PER_DEVICE_TRAIN_BATCH_SIZE = 8  # Increase BS after halving seq length again
GRADIENT_ACCUMULATION_STEPS = 6  # Decrease accumulation (8*8*6 = 384)
# Global batch size = PER_DEVICE_TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS * NumGPUs
# Example: 8 * 6 * 8 = 384
LEARNING_RATE = 3e-5   # Slightly higher LR for larger batch size
EPOCHS = 1             # Start with 1 epoch, increase cautiously
MAX_SEQ_LENGTH = 4096  # Halved sequence length again
LOGGING_STEPS = 50     # Increased logging frequency
SAVE_STEPS = 500       # Increased save frequency
OPTIMIZER = "adamw_bnb_8bit"  # Use 8-bit optimizer to save significant memory
WARMUP_RATIO = 0.03
LR_SCHEDULER_TYPE = "cosine"

# H200-specific optimizations (8x setup)
USE_FLASH_ATTN = True           # Enable Flash Attention 2 for H200s
USE_SEQUENCE_PARALLEL = False   # Disable when using FSDP
USE_BETTER_TRANSFORMERS = True  # Use better transformers for optimized kernels
DATALOADER_NUM_WORKERS = 8      # Reduced workers to avoid CPU contention
TOKENIZATION_NUM_WORKERS = 224  # Maximum worker count for tokenization
USE_ACTIVATION_CHECKPOINTING = True  # Enable activation checkpointing to save memory with long sequences

# Advanced distributed training options for 8x GPUs
USE_FSDP = True  # Enable FSDP
FSDP_CONFIG = {
    "fsdp_offload_params": False,  # Disable CPU Offload
    "fsdp_sharding_strategy": 1,   # 1 = FULL_SHARD
    "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
    "fsdp_transformer_layer_cls_to_wrap": [Qwen2DecoderLayer.__name__] if Qwen2DecoderLayer else [],
    "fsdp_state_dict_type": "SHARDED_STATE_DICT",
    "fsdp_backward_prefetch": "backward_post",  # Changed from backward_pre
    "fsdp_forward_prefetch": False,             # Disabled forward prefetch
    "fsdp_activation_checkpointing": [Qwen2DecoderLayer.__name__] if Qwen2DecoderLayer else [],  # Use FSDP activation checkpointing
}

# WandB Integration
REPORT_TO_WANDB = True  # Set to False to disable WandB reporting
WANDB_PROJECT_NAME = "QwQ-32B-Finetune-8xH200"  # Updated for 8x GPUs
WANDB_ENTITY = None  # Set to your username or team name if needed

# Determine report_to destination
report_to = "none"
if REPORT_TO_WANDB:
    # Disable WandB in all processes except rank 0 in distributed mode
    if distributed_mode and int(os.environ.get("LOCAL_RANK", 0)) != 0:
        logging.info(f"Rank {os.environ.get('RANK', '?')}: Disabling WandB")
        os.environ["WANDB_DISABLED"] = "true"
        report_to = "none"  # Explicitly set to none for non-main processes
    else:
        # Main process or non-distributed mode, attempt WandB initialization
        try:
            import wandb
            logging.info("Initializing WandB directly...")
            run_name = f"qwq-32b-finetune-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
            if wandb.run is None:
                try:
                    wandb.init(
                        project=WANDB_PROJECT_NAME,
                        entity=WANDB_ENTITY,
                        name=run_name,
                        config={
                            "model_name": MODEL_ID,
                            "batch_size": PER_DEVICE_TRAIN_BATCH_SIZE,
                            "gradient_accumulation_steps": GRADIENT_ACCUMULATION_STEPS,
                            "learning_rate": LEARNING_RATE,
                            "epochs": EPOCHS,
                            "sequence_length": MAX_SEQ_LENGTH,
                            "lora_r": LORA_R,
                            "lora_alpha": LORA_ALPHA,
                        }
                    )
                    logging.info(f"WandB initialized: {wandb.run.name} (ID: {wandb.run.id})")
                    report_to = "wandb"
                except Exception as e:
                    logging.error(f"WandB initialization error: {str(e)}")
                    report_to = "tensorboard"
            else:
                logging.info(f"Using existing WandB run: {wandb.run.name} (ID: {wandb.run.id})")
                report_to = "wandb"
        except ImportError:
            logging.warning("WandB package not installed. Reporting to TensorBoard.")
            report_to = "tensorboard"
        except Exception as wandb_init_e:
            logging.error(f"General WandB setup error: {wandb_init_e}")
            report_to = "tensorboard"
# If WandB reporting is disabled, set report_to accordingly
elif not distributed_mode:
    report_to = "tensorboard"
    logging.info("WandB reporting disabled. Reporting to TensorBoard.")
else:
    # If WandB is disabled and it IS distributed
    report_to = "none"
    logging.info("WandB reporting disabled for this distributed rank.")

# Quantization (QLoRA)
USE_4BIT_QUANTIZATION = False  # Disable QLoRA due to FSDP incompatibility
BNB_4BIT_COMPUTE_DTYPE = "bfloat16"  # Use bfloat16 if supported, else float16
BNB_4BIT_QUANT_TYPE = "nf4"

# --- Check Optional Dependencies (Define flags globally) ---
FLASH_ATTN_AVAILABLE = False
BETTER_TRANSFORMERS_AVAILABLE = False
try:
    import flash_attn
    FLASH_ATTN_AVAILABLE = True
    logging.info("Flash Attention available - will be used if enabled.")
except ImportError:
    logging.warning("Flash Attention not available. Install with 'pip install flash-attn'")

try:
    from optimum.bettertransformer import BetterTransformer
    BETTER_TRANSFORMERS_AVAILABLE = True
    logging.info("Better Transformers available - will be used if enabled.")
except ImportError:
    logging.warning("Better Transformers not available. Install with 'pip install optimum'")

# --- Check Dataset ---
if not os.path.exists(DATASET_PATH):
    logging.error(f"Dataset not found at {DATASET_PATH}. Run preprocess_data.py first.")
    exit(1)

logging.info(f"Loading dataset from {DATASET_PATH}...")

# Load dataset normally
dataset = load_from_disk(DATASET_PATH)

# Apply truncation if needed
if MAX_EXAMPLES > 0 and len(dataset) > MAX_EXAMPLES:
    logging.info(f"Truncating dataset to {MAX_EXAMPLES} examples")
    indices = list(range(min(MAX_EXAMPLES, len(dataset))))
    dataset = dataset.select(indices)

logging.info(f"Dataset loaded: {dataset} with {len(dataset)} examples")

# --- Tokenizer ---
logging.info(f"Loading tokenizer for {MODEL_ID}...")

# Enable fast tokenizer and optimizations
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    use_fast=True,  # Explicitly request the fast Rust-based tokenizer
    trust_remote_code=True,
    # model_max_length=MAX_SEQ_LENGTH,
    padding_side="right",
)

# Log tokenizer type for verification
if hasattr(tokenizer, 'is_fast') and tokenizer.is_fast:
    logging.info(f"Successfully loaded fast tokenizer (Rust implementation): {type(tokenizer).__name__}")
    # Fast tokenizers are automatically parallel in dataset.map() when num_proc > 1
    logging.info(f"Fast tokenizer will use parallel processing during dataset.map() with {TOKENIZATION_NUM_WORKERS} workers")
else:
    logging.warning(f"Using Python tokenizer: {type(tokenizer).__name__}")
    logging.warning("Python tokenizers are slower than Rust-based fast tokenizers")

# Check and set pad token based on Qwen documentation (<|endoftext|>)
# Qwen models might have this set correctly, but we verify.
EXPECTED_PAD_TOKEN = "<|endoftext|>"
if tokenizer.pad_token is None or tokenizer.pad_token != EXPECTED_PAD_TOKEN:
    logging.warning(f"Tokenizer pad_token is missing or not '{EXPECTED_PAD_TOKEN}'. Setting pad_token='{EXPECTED_PAD_TOKEN}'.")
    tokenizer.pad_token = EXPECTED_PAD_TOKEN

# Enable padding and truncation defaults for batch processing
tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
tokenizer.padding_side = "right"  # Typically "right" for decoder-only models like Qwen

# Log tokenizer configuration
logging.info(f"Tokenizer configuration:")
logging.info(f"  - Type: {'Fast' if hasattr(tokenizer, 'is_fast') and tokenizer.is_fast else 'Python'}")
logging.info(f"  - Pad token: {tokenizer.pad_token}")
logging.info(f"  - EOS token: {tokenizer.eos_token}")  # Should be <|im_end|>
logging.info(f"  - Vocab size: {tokenizer.vocab_size}")
logging.info(f"  - Model max length: {tokenizer.model_max_length}")
logging.info(f"  - Padding side: {tokenizer.padding_side}")


# Define parallel preprocessing function for the dataset
def preprocess_function(examples):
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=MAX_SEQ_LENGTH,
        return_tensors=None,  # Return Python lists for dataset
    )


# Create a cache directory for tokenized datasets
TOKENIZED_DATASET_CACHE_DIR = os.path.join(os.path.dirname(DATASET_PATH), "tokenized_cache")
os.makedirs(TOKENIZED_DATASET_CACHE_DIR, exist_ok=True)
tokenized_dataset_path = os.path.join(TOKENIZED_DATASET_CACHE_DIR, "tokenized_dataset")

# Create a file to signal tokenization completion
tokenization_done_file = os.path.join(TOKENIZED_DATASET_CACHE_DIR, "tokenization_complete")


# Function to clean up temporary files in dataset directory
def delete_existing_tmp_files():
    """Find and delete any existing tmp files in dataset directory"""
    # Look for tmp files in dataset directory
    tmp_files = glob.glob(os.path.join(DATASET_PATH, "tmp*"))
    if tmp_files:
        logging.info(f"Found {len(tmp_files)} existing tmp files, removing...")
        for tmp_file in tmp_files:
            try:
                if os.path.isdir(tmp_file):
                    shutil.rmtree(tmp_file)
                else:
                    os.remove(tmp_file)
                logging.info(f"Removed: {tmp_file}")
            except Exception as e:
                logging.warning(f"Could not remove {tmp_file}: {str(e)}")
    else:
        logging.info("No existing tmp files found")


# Check if we're in distributed mode and get rank
if distributed_mode:
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    is_main_process = rank == 0
    logging.info(f"Rank {rank}/{world_size}: Preparing for dataset processing")
else:
    is_main_process = True
    rank = 0
    world_size = 1
    local_rank = 0

# Clean up temp files - only on main process to avoid conflicts
if is_main_process:
    delete_existing_tmp_files()
    # Also remove the tokenization_done_file if it exists
    if os.path.exists(tokenization_done_file):
        os.remove(tokenization_done_file)
        logging.info(f"Rank {rank}: Removed old tokenization completion marker")

# Only tokenize on main process (rank 0) to avoid redundant work
need_tokenization = False

# Check if tokenized dataset already exists
if os.path.exists(tokenized_dataset_path) and os.path.isdir(tokenized_dataset_path):
    # --- Dataset Exists ---
    logging.info(f"Rank {rank}: Found existing tokenized dataset at {tokenized_dataset_path}")
    path_to_load = tokenized_dataset_path  # All ranks will load from the persistent path
    need_tokenization = False

    # Rank 0 ensures completion marker exists
    if is_main_process and not os.path.exists(tokenization_done_file):
        total_original_examples = "unknown"
        try:
            from datasets import load_dataset_builder  # Local import
            original_dataset_info = load_dataset_builder(DATASET_PATH).info
            total_original_examples = original_dataset_info.splits['train'].num_examples
        except Exception as info_e:
            logging.warning(f"Rank {rank}: Could not get original dataset info: {info_e}")

        try:
            # Get size of existing loaded dataset (approximate if needed)
            # This requires loading a small part or metadata, might be slow
            # For now, let's just mark it as existing
            # loaded_size = len(load_from_disk(tokenized_dataset_path, keep_in_memory=False))
            loaded_size = "unknown (loaded existing)"
            with open(tokenization_done_file, "w") as f:
                f.write(f"Tokenization assumed complete (loaded existing) at {datetime.now().isoformat()}\n")
                f.write(f"Processed {loaded_size} examples out of {total_original_examples}\n")
            logging.info(f"Rank {rank}: Created tokenization completion marker as it was missing.")
        except Exception as file_e:
            logging.error(f"Rank {rank}: Failed to create missing completion marker: {file_e}")
            # Proceeding anyway, but other ranks might hang if they rely solely on the file

    # Non-main ranks still need to wait for the marker to be sure Rank 0 checked/created it
    elif not is_main_process:
        logging.info(f"Rank {rank}: Waiting for main process confirmation via marker file...")
        max_wait_time = 300  # Shorter wait, just confirming file exists
        wait_start = time.time()
        while not os.path.exists(tokenization_done_file):
            if time.time() - wait_start > max_wait_time:
                logging.error(f"Rank {rank}: Timed out waiting for marker file from Rank 0.")
                raise TimeoutError("Marker file wait timeout")
            time.sleep(5)
        logging.info(f"Rank {rank}: Marker file found.")

elif is_main_process:
    # Tokenized dataset doesn't exist, Rank 0 needs to create it
    logging.info(f"Rank {rank}: Tokenization required. Proceeding with tokenization...")
    need_tokenization = True
    path_to_load = None

elif distributed_mode:
    # Tokenized dataset doesn't exist, non-main ranks need to wait
    logging.info(f"Rank {rank}: Tokenization required. Waiting for main process...")
    need_tokenization = True
    path_to_load = tokenized_dataset_path

# --- Perform Tokenization (if needed by Rank 0) ---
if need_tokenization and is_main_process:
    tokenized_dataset_obj = None  # Use a distinct name for the object returned by map
    try:
        # Process the dataset using dataset.map with internal parallelism
        start_time = time.time()  # Define start_time here

        # Standard tokenization with caching enabled
        logging.info(f"Rank {rank}: Starting tokenization using dataset.map with {TOKENIZATION_NUM_WORKERS} workers.")
        tokenized_dataset_obj = dataset.map(
            preprocess_function,
            batched=True,
            batch_size=1000,
            num_proc=TOKENIZATION_NUM_WORKERS,
            remove_columns=["text"],
            load_from_cache_file=True,  # Allow using cache file if it exists
            desc=f"Tokenizing dataset ({TOKENIZATION_NUM_WORKERS} workers)"
        )
        elapsed = time.time() - start_time
        logging.info(f"Rank {rank}: Tokenization successful in {elapsed:.2f} seconds.")

        # If tokenization was successful:
        if tokenized_dataset_obj is not None:
            logging.info(f"Rank {rank}: Dataset tokenization completed.")

            # Save directly to final path
            logging.info(f"Rank {rank}: Saving tokenized dataset to {tokenized_dataset_path}...")
            save_start = time.time()
            # Ensure target directory doesn't exist (needed for clean save)
            if os.path.exists(tokenized_dataset_path):
                shutil.rmtree(tokenized_dataset_path)
            tokenized_dataset_obj.save_to_disk(tokenized_dataset_path)
            save_elapsed = time.time() - save_start
            logging.info(f"Rank {rank}: Tokenized dataset saved in {save_elapsed:.2f} seconds.")

            # Create completion marker file ONLY after successful save
            with open(tokenization_done_file, "w") as f:
                f.write(f"Tokenization completed and saved at {datetime.now().isoformat()}\n")
            logging.info(f"Rank {rank}: Created tokenization completion marker")

            # Keep the result in memory for Rank 0 for immediate use
            dataset = tokenized_dataset_obj
            path_to_load = None  # Rank 0 uses the in-memory object directly
    except Exception as e:
        logging.error(f"Rank {rank}: Tokenization failed: {e}")
        import traceback
        logging.error(traceback.format_exc())
        # Create done file indicating failure
        with open(tokenization_done_file, "w") as f:
            f.write(f"Tokenization FAILED at {datetime.now().isoformat()}\nError: {e}")
        raise RuntimeError("Tokenization failed.") from e

# --- Load Dataset (All Ranks) ---
# This block now runs for all ranks *after* rank 0 has either tokenized or copied data
dataset_for_trainer = None  # Use a distinct variable name for clarity

if path_to_load:
    # If path_to_load is set (means rank 0 copied or non-main rank needs to load)
    if not is_main_process and need_tokenization:
        # Non-main ranks wait for the completion marker if tokenization was required
        logging.info(f"Rank {rank}: Waiting for tokenization completion signal from Rank 0...")
        while not os.path.exists(tokenization_done_file):
            time.sleep(5)
        logging.info(f"Rank {rank}: Tokenization completion signal received.")

    # All ranks with a path_to_load proceed to load
    logging.info(f"Rank {rank}: Loading dataset from {path_to_load}...")
    load_start_time = time.time()
    try:
        # Load without forcing into memory initially
        dataset_for_trainer = load_from_disk(path_to_load, keep_in_memory=False)
        load_elapsed = time.time() - load_start_time
        logging.info(f"Rank {rank}: Successfully loaded dataset in {load_elapsed:.2f}s. Length: {len(dataset_for_trainer)}")
    except Exception as e:
        logging.error(f"Rank {rank}: CRITICAL - Failed to load dataset from {path_to_load}: {e}")
        raise
elif is_main_process and not need_tokenization:
    # Rank 0 loaded existing, copied to RAM disk, and path_to_load points there
    # It still needs to load it for the trainer
    logging.info(f"Rank {rank}: Loading dataset from RAM disk copy {path_to_load}...")
    try:
        dataset_for_trainer = load_from_disk(path_to_load, keep_in_memory=False)
        logging.info(f"Rank {rank}: Successfully loaded dataset from RAM disk copy.")
    except Exception as e:
        logging.error(f"Rank {rank}: CRITICAL - Failed to load from RAM disk copy {path_to_load}: {e}")
        raise
elif is_main_process and need_tokenization:
    # Rank 0 just tokenized, 'dataset' variable already holds the result in memory
    logging.info(f"Rank {rank}: Using in-memory dataset from successful tokenization.")
    dataset_for_trainer = dataset  # Use the object directly
else:
    # Should not happen
    logging.error(f"Rank {rank}: Dataset path logic error. path_to_load='{path_to_load}', need_tokenization={need_tokenization}")
    raise RuntimeError("Dataset preparation failed - logic error.")

# At this point, 'dataset_for_trainer' on all ranks should hold the ready-to-use data.

# Synchronize processes after dataset is ready on all ranks
if distributed_mode:
    try:
        logging.info(f"Rank {rank}: Synchronizing after dataset preparation...")
        dist.barrier()
        logging.info(f"Rank {rank}: Synchronization complete.")
    except Exception as sync_e:
        logging.error(f"Rank {rank}: Synchronization after dataset prep failed: {sync_e}")
        raise


# --- Helper Function for Memory Check ---
def check_gpu_memory_utilization():
    """Check and report GPU memory utilization"""
    if not torch.cuda.is_available():
        logging.info("CUDA not available, skipping GPU memory check.")
        return 0  # Return 0 utilization if no GPU

    logging.info("==== GPU MEMORY UTILIZATION CHECK ====")
    total_allocated_gb = 0
    total_reserved_gb = 0
    total_capacity_gb = 0
    try:
        for i in range(torch.cuda.device_count()):
            free_mem, total_mem = torch.cuda.mem_get_info(i)
            allocated = torch.cuda.memory_allocated(i)
            reserved = torch.cuda.memory_reserved(i)
            free_gb = free_mem / (1024**3)
            total_gb = total_mem / (1024**3)
            allocated_gb = allocated / (1024**3)
            reserved_gb = reserved / (1024**3)
            utilized_pct = (1 - free_mem / total_mem) * 100 if total_mem > 0 else 0
            total_allocated_gb += allocated_gb
            total_reserved_gb += reserved_gb
            total_capacity_gb += total_gb
            logging.info(f"GPU {i}: Allocated {allocated_gb:.1f}GB, Reserved {reserved_gb:.1f}GB, "
                         f"Free {free_gb:.1f}GB, Total {total_gb:.1f}GB, "
                         f"Utilization: {utilized_pct:.1f}%")
        avg_utilization = (total_allocated_gb / total_capacity_gb) * 100 if total_capacity_gb > 0 else 0
        logging.info(f"OVERALL: Using {total_allocated_gb:.1f}GB / {total_capacity_gb:.1f}GB ({avg_utilization:.1f}% allocated)")
        logging.info("========================================")
        return avg_utilization
    except Exception as e:
        logging.error(f"Error checking GPU memory: {e}")
        return 0  # Return 0 on error


# --- Model Loading & Preparation (Runs on ALL ranks) ---
logging.info(f"Rank {rank}: Loading model: {MODEL_ID}...")

# 1. Load Model Configuration
config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
logging.info("Enabling YaRN scaling in model configuration.")
config.rope_scaling = {
    "type": "yarn",
    "factor": 4.0,
    "original_max_position_embeddings": 32768,
}

# Determine torch dtype
torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

# Set device_map based on distributed mode
# When using FSDP, device_map should typically be None or "auto", FSDP handles placement.
if USE_FSDP:
    device_map = None
    logging.info("FSDP enabled: Setting device_map=None")
elif distributed_mode:
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    device_map = {"": local_rank}
    logging.info(f"Rank {rank}: DDP mode: Loading model on device {local_rank}")
else:
    device_map = "auto"
    logging.info(f"Rank {rank}: Single process mode: Using automatic device mapping")

# Configure Flash Attention and other optimizations
use_flash_attn = USE_FLASH_ATTN and FLASH_ATTN_AVAILABLE
attn_implementation = "flash_attention_2" if use_flash_attn else None

# Configure Quantization if enabled
# quantization_config = None
# if USE_4BIT_QUANTIZATION:
#     logging.info("Configuring 4-bit quantization (QLoRA)...")
#     compute_dtype = getattr(torch, BNB_4BIT_COMPUTE_DTYPE)
#     quantization_config = BitsAndBytesConfig(
#         load_in_4bit=True,
#         bnb_4bit_quant_type=BNB_4BIT_QUANT_TYPE,
#         bnb_4bit_compute_dtype=compute_dtype,
#         bnb_4bit_use_double_quant=True,  # Qwen models often benefit from double quant
#     )
#     # Override torch_dtype when using quantization as recommended
#     # torch_dtype = None
#     logging.info(f"4-bit quantization config created: type={BNB_4BIT_QUANT_TYPE}, compute={BNB_4BIT_COMPUTE_DTYPE}")

# Configure model loading kwargs
model_load_kwargs = {
    "config": config,
    "device_map": device_map,
    "low_cpu_mem_usage": True,
    "trust_remote_code": True,
}
if use_flash_attn:
    model_load_kwargs["attn_implementation"] = "flash_attention_2"
# if quantization_config:
#     model_load_kwargs["quantization_config"] = quantization_config

# Always set torch_dtype when not using quantization
model_load_kwargs["torch_dtype"] = torch_dtype

# Log memory before loading
# ... (memory logging logic - keep as is) ...
# Load the model
model = None  # Initialize model variable
try:
    logging.info(f"Rank {rank}: Calling AutoModelForCausalLM.from_pretrained...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        **model_load_kwargs
    )
    logging.info(f"Rank {rank}: Base model loaded successfully on device: {model.device if device_map is None else 'CPU/Multi'}")

    # Ensure consistent dtype before FSDP wrapping (which happens in trainer.train)
    if torch_dtype == torch.bfloat16:
        logging.info("Explicitly casting model to bfloat16...")
        model = model.to(torch.bfloat16)

    # Apply Better Transformers optimization
    use_better_transformers_flag = USE_BETTER_TRANSFORMERS and BETTER_TRANSFORMERS_AVAILABLE
    if use_better_transformers_flag:
        try:
            logging.info("Applying BetterTransformer optimizations...")
            model = BetterTransformer.transform(model)
            logging.info("BetterTransformer optimizations applied successfully")
        except Exception as bt_e:
            logging.warning(f"Could not apply BetterTransformer optimizations: {str(bt_e)}")

    # Apply activation checkpointing
    if USE_ACTIVATION_CHECKPOINTING:
        try:
            logging.info("Enabling activation checkpointing...")
            model.gradient_checkpointing_enable()
            logging.info("Activation checkpointing enabled.")
        except Exception as ac_e:
            logging.warning(f"Could not enable activation checkpointing: {str(ac_e)}")

    # Log model config and check memory utilization
    logging.info(f"Rank {rank}: Model setup complete.")
    check_gpu_memory_utilization()

except Exception as model_load_e:
    logging.error(f"Rank {rank}: Failed during model loading or preparation: {model_load_e}")
    import traceback
    logging.error(traceback.format_exc())
    # Attempt to clean up distributed env before raising
    if distributed_mode and dist.is_initialized():
        try:
            dist.destroy_process_group()
        except:
            pass
    raise  # Re-raise error

# --- LoRA Configuration ---
peft_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    target_modules=LORA_TARGET_MODULES,
    bias="none",
    task_type="CAUSAL_LM",
)

# --- Synchronize AFTER model loading & PEFT config ---
if distributed_mode:
    try:
        logging.info(f"Rank {rank}: Synchronizing after model loading...")
        dist.barrier()
        logging.info(f"Rank {rank}: Synchronization after model loading complete.")
    except Exception as sync_e:
        logging.error(f"Rank {rank}: Synchronization after model loading failed: {sync_e}")
        raise

# --- Define Training Arguments ---
# Determine the WandB run name (if any) to pass through to the training arguments
determined_run_name = None
if REPORT_TO_WANDB and is_main_process:
    try:
        import wandb
        if wandb.run is not None:
            determined_run_name = wandb.run.name
    except Exception:
        pass  # Ignore errors here, handled by report_to

base_training_args = {
    "output_dir": OUTPUT_DIR,
    "per_device_train_batch_size": PER_DEVICE_TRAIN_BATCH_SIZE,
    "gradient_accumulation_steps": GRADIENT_ACCUMULATION_STEPS,
    "optim": OPTIMIZER,
    "save_steps": SAVE_STEPS,
    "logging_steps": LOGGING_STEPS,
    "learning_rate": LEARNING_RATE,
    "num_train_epochs": EPOCHS,
    "max_steps": -1,
    "fp16": False,
    "bf16": torch_dtype == torch.bfloat16,  # Use previously determined dtype
    "max_grad_norm": 0.3,
    "warmup_ratio": WARMUP_RATIO,
    "group_by_length": False,  # Explicitly disable to prevent pre-computation hang
    "lr_scheduler_type": LR_SCHEDULER_TYPE,
    "report_to": report_to,
    "save_total_limit": 3,
    "logging_first_step": True,
    **({"run_name": determined_run_name} if determined_run_name is not None else {}),
    "fsdp": "full_shard" if USE_FSDP else "",  # Pass FSDP strategy string (removed offload)
    "fsdp_config": FSDP_CONFIG if USE_FSDP else {},  # Pass FSDP config dict
    "dataloader_num_workers": DATALOADER_NUM_WORKERS,
    "resume_from_checkpoint": "auto",
    "save_strategy": "steps",
    "load_best_model_at_end": False,
    "metric_for_best_model": None,
    "dataset_text_field": "text",
    "packing": False,
    "max_seq_length": MAX_SEQ_LENGTH,
    # Memory/Performance Optimizations
    "gradient_checkpointing_kwargs": {"use_reentrant": False},  # More stable checkpointing for FSDP activation checkpointing
    "ddp_find_unused_parameters": False,  # Should be False for FSDP
    "tf32": True,  # Enable TF32 for faster compute on compatible GPUs
}

training_arguments = SFTConfig(**base_training_args)
logging.info(f"Rank {rank}: Training arguments (SFTConfig) created.")


# --- Define Callbacks ---
# Create memory monitoring callback
class MemoryMonitorCallback(TrainerCallback):
    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % 10 == 0:  # Log every 10 steps
            if torch.cuda.is_available():
                gc.collect()
                torch.cuda.empty_cache()
                rank = int(os.environ.get("RANK", 0))
                local_rank = int(os.environ.get("LOCAL_RANK", 0))
                try:
                    free_mem, total_mem = torch.cuda.mem_get_info(local_rank)
                    free_gb = free_mem / (1024**3)
                    used_gb = (total_mem - free_mem) / (1024**3)
                    total_gb = total_mem / (1024**3)
                    reserved = torch.cuda.memory_reserved(local_rank) / (1024**3)
                    allocated = torch.cuda.memory_allocated(local_rank) / (1024**3)
                    logging.info(f"Rank {rank}: Memory at step {state.global_step}: "
                                 f"{free_gb:.1f}GB free, {used_gb:.1f}GB used, {total_gb:.1f}GB total, "
                                 f"{reserved:.1f}GB reserved, {allocated:.1f}GB allocated")
                except Exception as mem_e:
                    logging.warning(f"Rank {rank}: Could not get memory info: {mem_e}")
        return control


memory_monitor = MemoryMonitorCallback()


# Create a special first step callback with WandB support
class FirstStepCallback(TrainerCallback):
    def __init__(self):
        self.first_step_start_time = None
        self.progress_indicators = 0
        self.update_interval = 60  # Check every minute
        self.last_update_time = time.time()

    def on_step_begin(self, args, state, control, **kwargs):
        if state.global_step == 0:
            self.first_step_start_time = time.time()
            logging.info(f"FIRST STEP STARTING at {datetime.now().strftime('%H:%M:%S')}")
            if REPORT_TO_WANDB and 'wandb' in sys.modules:
                try:
                    import wandb  # Import locally
                    if wandb.run:
                        wandb.log({"training_status": "first_step_started"})
                except Exception as log_e:
                    logging.warning(f"Wandb log error: {log_e}")
        return control

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step == 0:
            if self.first_step_start_time is None:
                # Should not happen, but safeguard
                logging.warning("First step ended but start time was not recorded.")
                return control
            duration = time.time() - self.first_step_start_time
            logging.info(f"FIRST STEP COMPLETED at {datetime.now().strftime('%H:%M:%S')} (took {duration:.2f} seconds)")
            if REPORT_TO_WANDB and 'wandb' in sys.modules:
                try:
                    import wandb  # Import locally
                    if wandb.run:
                        wandb.log({
                            "training_status": "first_step_completed",
                            "first_step_duration": duration
                        })
                except Exception as log_e:
                    logging.warning(f"Wandb log error: {log_e}")
        return control

    def on_substep_end(self, args, state, control, **kwargs):
        # This tracks progress within a step (during gradient accumulation)
        current_time = time.time()
        # Only report for the first step/substep and only from rank 0
        if (self.first_step_start_time is not None
                and state.global_step == 0
                and current_time - self.last_update_time >= self.update_interval
                and (not distributed_mode or int(os.environ.get("LOCAL_RANK", 0)) == 0)):
            self.progress_indicators += 1
            elapsed = current_time - self.first_step_start_time
            logging.info(f"First step still in progress... ({elapsed:.1f}s elapsed, progress indicator {self.progress_indicators})")
            if REPORT_TO_WANDB and 'wandb' in sys.modules:
                try:
                    import wandb  # Import locally
                    if wandb.run:
                        wandb.log({
                            "training_status": "first_step_in_progress",
                            "first_step_elapsed": elapsed,
                            "progress_indicator": self.progress_indicators
                        })
                except Exception as log_e:
                    logging.warning(f"Wandb log error: {log_e}")
            self.last_update_time = current_time
        return control


first_step_callback = FirstStepCallback()

# Add WandB logging callback if WandB is enabled
wandb_callback = None  # Initialize
if REPORT_TO_WANDB and 'wandb' in sys.modules and (not distributed_mode or int(os.environ.get("LOCAL_RANK", 0)) == 0):
    try:
        # **** FULL WandBLoggingCallback Class Definition ****
        class WandBLoggingCallback(TrainerCallback):
            """Logs comprehensive training metrics and progress to Weights & Biases"""

            def __init__(self):
                self.training_start_time = None
                self.last_log_time = None
                self.total_steps = None
                self.samples_seen = 0
                self.tokens_seen = 0
                self.current_epoch = 0
                self.epoch_start_time = None
                self.step_history = []  # For tracking steps/second
                self.global_tokens_per_second = 0
                self.progress_table = None  # Initialize table to None

            def on_train_begin(self, args, state, control, **kwargs):
                """Log hyperparameters and initialize tracking at the start of training"""
                if not (REPORT_TO_WANDB and 'wandb' in sys.modules):
                    return  # Check if WandB should be used
                try:
                    import wandb  # Import locally
                    if not wandb.run:
                        logging.warning("WandBCallback: Wandb not initialized in on_train_begin.")
                        return
                except ImportError:
                    logging.warning("WandBCallback: wandb not imported, cannot log on_train_begin")
                    return

                self.training_start_time = time.time()
                self.epoch_start_time = time.time()
                self.last_log_time = time.time()

                # Calculate total expected steps
                if args.max_steps > 0:
                    self.total_steps = args.max_steps
                else:
                    # Use trainer passed in kwargs if available (prioritize 'trainer' key)
                    trainer_instance = kwargs.get('trainer', None)
                    if trainer_instance is None:
                        trainer_instance = kwargs.get('model', None)  # Fallback to 'model' key
                    dataset_length = 0
                    if trainer_instance and hasattr(trainer_instance, 'train_dataset') and trainer_instance.train_dataset is not None:
                        try:
                            dataset_length = len(trainer_instance.train_dataset)
                        except Exception as len_e:
                            logging.warning(f"WandBCallback: Error getting dataset length: {len_e}")
                    else:
                        logging.warning("WandBCallback: Could not access train_dataset length during on_train_begin.")
                    batch_size = args.per_device_train_batch_size
                    accumulation = args.gradient_accumulation_steps
                    world_size = int(os.environ.get("WORLD_SIZE", 1))
                    global_batch_denom = (batch_size * world_size * accumulation)
                    if dataset_length > 0 and global_batch_denom > 0:
                        self.total_steps = (dataset_length // global_batch_denom) * args.num_train_epochs
                    else:
                        self.total_steps = -1  # Indicate unknown total steps

                # Log key hyperparameters
                config = {
                    "model_name": MODEL_ID,
                    "lora_r": LORA_R,
                    "lora_alpha": LORA_ALPHA,
                    "batch_size": PER_DEVICE_TRAIN_BATCH_SIZE,
                    "grad_accum": GRADIENT_ACCUMULATION_STEPS,
                    "effective_batch": PER_DEVICE_TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS,
                    "global_batch": PER_DEVICE_TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS * world_size,
                    "learning_rate": LEARNING_RATE,
                    "seq_length": MAX_SEQ_LENGTH,
                    "epochs": EPOCHS,
                    "total_steps_estimated": self.total_steps,
                    "optimizer": OPTIMIZER,
                    "warmup_ratio": WARMUP_RATIO,
                    "scheduler": LR_SCHEDULER_TYPE,
                }
                wandb.config.update(config)

                # Initialize training progress table
                columns = ["step", "epoch", "loss", "lr", "tokens/sec", "eta", "elapsed_hrs"]
                self.progress_table = wandb.Table(columns=columns)

                # Log training start
                wandb.log({"training_status": "started"})
                logging.info(f"Training started - total estimated steps: {self.total_steps}")

            def on_log(self, args, state, control, logs=None, **kwargs):
                """Log detailed metrics and progress after each logging step"""
                if not (REPORT_TO_WANDB and 'wandb' in sys.modules):
                    return  # Check if WandB should be used
                try:
                    import wandb  # Import locally
                    if not wandb.run:
                        logging.warning("WandBCallback: Wandb run not active during on_log.")
                        return
                except ImportError:
                    logging.warning("WandBCallback: wandb not imported, cannot log on_log")
                    return

                if not logs:
                    return

                # Format metrics for logging
                metrics = {}
                for k, v in logs.items():
                    if isinstance(v, (int, float)):
                        metrics[k] = v
                    elif hasattr(v, "item"):  # Handle tensors
                        try:
                            metrics[k] = v.item()
                        except:
                            pass
                if not metrics:
                    return

                # Calculate time-based metrics
                current_time = time.time()
                if self.training_start_time is None:
                    self.training_start_time = current_time  # Safeguard
                elapsed_time = current_time - self.training_start_time
                elapsed_hrs = elapsed_time / 3600

                # Estimate tokens processed
                batch_size = args.per_device_train_batch_size
                grad_accum = args.gradient_accumulation_steps
                world_size = int(os.environ.get("WORLD_SIZE", 1))
                global_batch_size = batch_size * grad_accum * world_size
                tokens_per_step = global_batch_size * MAX_SEQ_LENGTH  # Use MAX_SEQ_LENGTH from outer scope

                # Update tokens seen
                steps_since_last = state.global_step - (self.step_history[-1][0] if self.step_history else -1)
                if steps_since_last <= 0:
                    steps_since_last = 1  # Avoid issues on first log
                new_tokens = tokens_per_step * steps_since_last
                self.tokens_seen += new_tokens

                # Calculate throughput
                time_since_last = current_time - (self.last_log_time if self.last_log_time else current_time)
                if time_since_last <= 0:
                    time_since_last = 1.0  # Avoid division by zero
                tokens_per_second = new_tokens / time_since_last

                # Update rolling average of tokens/sec
                alpha = 0.1
                self.global_tokens_per_second = alpha * tokens_per_second + (1 - alpha) * self.global_tokens_per_second

                # Track epoch progress
                if "epoch" in metrics:
                    new_epoch = int(metrics["epoch"])
                    if new_epoch > self.current_epoch:
                        epoch_time = current_time - (self.epoch_start_time if self.epoch_start_time else current_time)
                        self.epoch_start_time = current_time
                        self.current_epoch = new_epoch
                        wandb.log({"epoch/duration_sec": epoch_time}, step=state.global_step)
                        logging.info(f"Epoch {self.current_epoch-1} completed in {epoch_time:.2f} seconds")
                    epoch_float = metrics["epoch"]
                    epoch_progress = epoch_float - int(epoch_float)
                    metrics["epoch_progress"] = epoch_progress * 100

                # Estimate time remaining
                eta_hours = float('nan')
                if self.total_steps and self.total_steps > 0 and state.global_step > 0:
                    progress_fraction = state.global_step / self.total_steps
                    if progress_fraction > 1e-6:  # Avoid division by zero early on
                        eta_seconds = elapsed_time / progress_fraction - elapsed_time
                        eta_hours = eta_seconds / 3600
                        metrics["eta_hours"] = eta_hours

                # Add additional calculated metrics
                metrics.update({
                    "progress/elapsed_hours": elapsed_hrs,
                    "progress/tokens_total": self.tokens_seen,
                    "performance/tokens_per_second": tokens_per_second,
                    "performance/tokens_per_second_avg": self.global_tokens_per_second,
                    "performance/global_batch_size": global_batch_size,
                })

                # Add GPU utilization if available
                if torch.cuda.is_available():
                    try:
                        local_rank = int(os.environ.get("LOCAL_RANK", 0))
                        # Note: torch.cuda.utilization might not be available/reliable
                        # metrics["gpu/utilization"] = torch.cuda.utilization(local_rank)
                        metrics["gpu/memory_allocated_gb"] = torch.cuda.memory_allocated(local_rank) / 1e9
                        metrics["gpu/memory_reserved_gb"] = torch.cuda.memory_reserved(local_rank) / 1e9
                    except Exception as gpu_e:
                        logging.debug(f"Could not log GPU metrics: {gpu_e}")

                # Log all metrics to wandb
                wandb.log(metrics, step=state.global_step)

                # Values used for both the progress table and the console summary
                loss_val = metrics.get("loss", float("nan"))
                lr_val = metrics.get("learning_rate", float("nan"))
                epoch_val = metrics.get("epoch", 0)
                tokens_sec = metrics.get("performance/tokens_per_second_avg", 0)

                # Add row to progress table
                if self.progress_table is not None:
                    self.progress_table.add_data(
                        state.global_step,
                        f"{epoch_val:.2f}",
                        f"{loss_val:.4f}",
                        f"{lr_val:.2e}",
                        f"{tokens_sec:.1f}",
                        f"{eta_hours:.1f} hrs",
                        f"{elapsed_hrs:.1f} hrs"
                    )
                    # Log the updated progress table (might be verbose, consider less frequent logging)
                    # wandb.log({"training_progress": self.progress_table}, step=state.global_step)

                # Print concise metrics to console
                log_info = (
                    f"Step {state.global_step}"
                    + (f"/{self.total_steps} ({100 * state.global_step / self.total_steps:.1f}%)" if self.total_steps and self.total_steps > 0 else "")
                    + f" | Loss: {loss_val:.4f} | LR: {lr_val:.2e} | Epoch: {epoch_val:.2f}"
                    + f" | Tokens/sec: {tokens_sec:.1f}"
                    + (f" | ETA: {eta_hours:.1f}h" if not math.isnan(eta_hours) else "")
                )
                logging.info(log_info)

                # Update time tracking
                self.last_log_time = current_time
                self.step_history.append((state.global_step, current_time))
                if len(self.step_history) > 100:  # Keep only recent history
                    self.step_history = self.step_history[-100:]

            def on_train_end(self, args, state, control, **kwargs):
                """Log final statistics at the end of training"""
                if not (REPORT_TO_WANDB and 'wandb' in sys.modules):
                    return  # Check if WandB should be used
                try:
                    import wandb  # Import locally
                    if not wandb.run:
                        logging.warning("WandBCallback: Wandb run not active during on_train_end.")
                        return
                except ImportError:
                    logging.warning("WandBCallback: wandb not imported, cannot log on_train_end")
                    return

                total_time = time.time() - (self.training_start_time if self.training_start_time else time.time())
                hours = total_time / 3600
                final_stats = {
                    "training_status": "completed",
                    "total_steps_completed": state.global_step,
                    "total_epochs_completed": self.current_epoch,
                    "total_training_time_hours": hours,
                    "total_tokens_processed": self.tokens_seen,
                    "average_tokens_per_second": self.tokens_seen / total_time if total_time > 0 else 0
                }
                wandb.log(final_stats, step=state.global_step)  # Log at final step
                wandb.run.summary.update({
                    "training_duration_hours": hours,
                    "total_steps": state.global_step,
                    "total_epochs": self.current_epoch,
                    "total_tokens": self.tokens_seen
                })
                logging.info(f"Training complete - {hours:.2f} hours, {state.global_step} steps, {self.tokens_seen:,} tokens processed")
        # **** End of WandBLoggingCallback Definition ****

        # Create callback instance
        wandb_callback = WandBLoggingCallback()
        logging.info("Enhanced WandB logging callback created")
    except Exception as e:
        logging.error(f"Error creating WandB callback: {str(e)}")
        wandb_callback = None

# Create the list of callbacks
trainer_callbacks = [memory_monitor, first_step_callback]  # Use the instance names
if wandb_callback:
    trainer_callbacks.append(wandb_callback)
    logging.info("Added WandB callback to trainer")
# trainer_callbacks = []  # Temporarily disable all callbacks

# --- Initialize Trainer ---
logging.info(f"Rank {rank}: Initializing SFTTrainer...")
trainer = None
try:
    trainer = SFTTrainer(
        model=model,
        # Using processing_class as per user confirmation
        processing_class=tokenizer,
        args=training_arguments,
        train_dataset=dataset_for_trainer,
        peft_config=peft_config,
        # Ensure this matches whether the collator is defined/needed
        preprocess_logits_for_metrics=None,
        callbacks=trainer_callbacks,  # Pass the list here
    )
    logging.info(f"Rank {rank}: SFTTrainer initialized successfully.")
except Exception as e:
    logging.error(f"Rank {rank}: Error initializing SFTTrainer: {e}")
    import traceback
    logging.error(traceback.format_exc())
    if distributed_mode and dist.is_initialized():
        try:
            dist.destroy_process_group()
        except:
            pass
    raise

# --- Train ---
if trainer is not None:
    logging.info(f"Beginning trainer.train() call at {datetime.now().strftime('%H:%M:%S')}")
    try:
        trainer.train()
        logging.info(f"Training finished successfully at {datetime.now().strftime('%H:%M:%S')}")
    except Exception as e:
        logging.error(f"Exception during training: {e}")
        import traceback
        logging.error(traceback.format_exc())
        if distributed_mode and dist.is_initialized():
            try:
                dist.destroy_process_group()
                logging.info("Destroyed process group after training error")
            except:
                pass
        raise

# --- Merge Model and Save Full Model ---
logging.info("Merging adapter weights into base model...")

# Clear some memory first if needed (especially if not using massive GPUs)
# del model
# del trainer
# torch.cuda.empty_cache()

# Reload the base model (consider lower precision to save VRAM during merge)
logging.info(f"Reloading base model ({MODEL_ID}) for merging...")
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    config=config,  # Ensure YaRN config is used if applied during training
    torch_dtype=torch.bfloat16,  # Or torch.float16, adjust as needed
    low_cpu_mem_usage=True,  # Helps with large models
    trust_remote_code=True,
    device_map=None,  # Load onto CPU first to potentially save GPU VRAM if needed
    attn_implementation="flash_attention_2"
)

# Load the PEFT model with adapters
logging.info(f"Loading PEFT model from {OUTPUT_DIR}...")
merged_model = PeftModel.from_pretrained(
    base_model,
    OUTPUT_DIR,
    device_map=None,  # Load onto CPU first
)

# Merge the adapter weights
logging.info("Merging LoRA weights...")
merged_model = merged_model.merge_and_unload()
logging.info("LoRA weights merged.")

# Define path for the full model save
full_model_save_path = os.path.join(OUTPUT_DIR, "final_merged_model")

# Save the merged model
logging.info(f"Saving merged model to {full_model_save_path}...")
merged_model.save_pretrained(full_model_save_path)
logging.info("Merged model saved.")

# Save the tokenizer associated with the merged model
logging.info(f"Saving tokenizer to {full_model_save_path}...")
tokenizer.save_pretrained(full_model_save_path)
logging.info("Tokenizer saved.")

logging.info(f"Fine-tuning and merging process complete. Full model saved to {full_model_save_path}")

# --- Notes on Inference and Resuming Training ---
logging.info("Training Checkpoint Notes:")
logging.info(f"  • Checkpoints saved to: {OUTPUT_DIR}")
logging.info(f"  • To resume training from the latest checkpoint, just rerun this script")
logging.info(f"    (resume_from_checkpoint='auto' will automatically find the latest checkpoint)")
logging.info(f"  • To resume from a specific checkpoint, set resume_from_checkpoint='path/to/checkpoint'")

# --- Notes on Inference ---
# To use the trained adapters:
#   from peft import PeftModel
#   base_model = AutoModelForCausalLM.from_pretrained(MODEL_ID, ...)
#   model = PeftModel.from_pretrained(base_model, final_adapter_path)
#   model = model.merge_and_unload()  # Optional: merge adapters for faster inference
# Then use model and tokenizer for generation.