Repro code (commit aee6a1a, verified):
import os
import torch
from datasets import load_from_disk, concatenate_datasets, Dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig
from peft import LoraConfig, prepare_model_for_kbit_training, PeftModel
from peft.tuners.lora import LoraLayer
from trl import SFTTrainer, SFTConfig
import logging
import torch.distributed as dist
from datetime import timedelta, datetime
import time
from transformers.trainer import TrainerCallback
import gc
import sys
import shutil # For handling file operations
import glob # For file pattern matching
import threading # For background cleanup
import multiprocessing
import subprocess
import tempfile
import json
import random
import math
import queue
import numpy as np
# Import the specific layer class for FSDP wrapping
try:
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer
except ImportError:
logging.warning("Could not import Qwen2DecoderLayer. FSDP wrapping might fail.")
Qwen2DecoderLayer = None
# Configure more detailed logging with timestamps
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
stream=sys.stdout, # Ensure logs go to stdout for immediate visibility
force=True
)
# Set up temporary directory for cache files
temp_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "temp")
os.makedirs(temp_dir, exist_ok=True)
logging.info(f"Using temporary directory: {temp_dir}")
# Set environment variables to control temporary file creation
os.environ["TMPDIR"] = temp_dir # Unix
os.environ["TEMP"] = temp_dir # Windows
os.environ["TMP"] = temp_dir # Windows alternative
# Set default cache locations
hf_datasets_cache_path = os.path.join(temp_dir, "hf_datasets_cache")
transformers_cache_path = os.path.join(temp_dir, "transformers_cache")
hf_home_path = os.path.join(temp_dir, "hf_home")
os.makedirs(hf_datasets_cache_path, exist_ok=True)
os.makedirs(transformers_cache_path, exist_ok=True)
os.makedirs(hf_home_path, exist_ok=True)
os.environ["HF_DATASETS_CACHE"] = hf_datasets_cache_path
os.environ["TRANSFORMERS_CACHE"] = transformers_cache_path
os.environ["HF_HOME"] = hf_home_path
logging.info(f"Hugging Face Datasets cache directed to: {hf_datasets_cache_path}")
logging.info(f"Hugging Face Transformers cache directed to: {transformers_cache_path}")
# Keep forcing Arrow to use system memory pool if possible
os.environ["ARROW_DEFAULT_MEMORY_POOL"] = "system"
logging.info("Configured temporary directory and cache locations.")
# Set environment variable to control PyTorch's memory allocator
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:512"
# Unset PYTORCH_NO_CUDA_MEMORY_CACHING (if present) so CUDA memory caching stays enabled for better performance
if "PYTORCH_NO_CUDA_MEMORY_CACHING" in os.environ:
del os.environ["PYTORCH_NO_CUDA_MEMORY_CACHING"]
# Set a longer timeout for NCCL operations
os.environ["NCCL_BLOCKING_WAIT"] = "1"
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
os.environ["NCCL_TIMEOUT"] = "3600" # 1 hour timeout for NCCL operations
# Initialize distributed environment with better error handling
def init_distributed():
try:
# Check if we're in a distributed training environment
if "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1:
# Set memory optimization environment variables
if int(os.environ.get("LOCAL_RANK", 0)) == 0:
logging.info("Setting PyTorch memory optimizations for H200 GPUs")
# Empty CUDA cache before initializing process group
if torch.cuda.is_available():
torch.cuda.empty_cache()
logging.info("CUDA cache cleared")
local_rank = int(os.environ.get("LOCAL_RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
rank = int(os.environ.get("RANK", 0))
logging.info(f"Initializing distributed training for 8x H200s. Rank: {rank}, Local Rank: {local_rank}, World Size: {world_size}")
# Set the device for this process explicitly before initializing
torch.cuda.set_device(local_rank)
logging.info(f"Setting device {local_rank} for process rank {rank}")
# Set a longer timeout to handle long operations (3 hours)
timeout = timedelta(hours=3)
# Initialize the distributed process group
dist.init_process_group(
backend='nccl',
init_method='env://',
timeout=timeout,
rank=rank,
world_size=world_size
)
# Verify initialization was successful
if dist.is_initialized():
logging.info(f"Successfully initialized distributed process group. Rank: {rank}, Device: {torch.cuda.current_device()}")
# Log NCCL environment
logging.info(f"NCCL Version: {torch.cuda.nccl.version() if hasattr(torch.cuda, 'nccl') else 'unknown'}")
logging.info(f"CUDA Device Count: {torch.cuda.device_count()}")
logging.info(f"CUDA Device Name: {torch.cuda.get_device_name(local_rank)}")
else:
logging.error(f"Failed to initialize distributed process group. Rank: {rank}")
# Ensure all processes can communicate with specified device
try:
device_ids = [local_rank]
dist.barrier(device_ids=device_ids)
logging.info(f"Communication test successful. Process {rank} on device {local_rank} can communicate.")
except Exception as e:
logging.error(f"Communication test failed. Processes cannot communicate: {str(e)}. Rank: {rank}")
raise
return True
else:
logging.info("Not running in distributed mode.")
return False
except Exception as e:
logging.error(f"Error initializing distributed environment: {str(e)}")
raise
# Initialize distributed environment
distributed_mode = init_distributed()
# --- Configuration ---
# Model ID updated based on user input
MODEL_ID = "Qwen/QwQ-32B"
# Path to the processed dataset created by preprocess_data.py
DATASET_PATH = "./processed_datasets/combined_code_finetune_data"
# Number of examples to use (set to -1 for all)
MAX_EXAMPLES = -1 # Use all examples by default
# LoRA configuration (Optimized for 8x H200 GPUs)
LORA_R = 64 # Doubled to increase parameter count significantly
LORA_ALPHA = 128 # Increased alpha to match r
LORA_DROPOUT = 0.05 # Dropout probability for LoRA layers
# Target modules might need verification for QwQ-32B specifically.
# Common targets for Qwen models:
LORA_TARGET_MODULES = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
# "embed_tokens", # Removed to reduce overhead/complexity
# "lm_head", # Removed to reduce overhead/complexity
]
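# Optional sanity check (a sketch, not part of the original repro): once the model is loaded
# further below, the target names above can be verified against its actual Linear modules:
#   linear_names = {name.split(".")[-1] for name, mod in model.named_modules()
#                   if isinstance(mod, torch.nn.Linear)}
#   logging.info(f"Linear module names in model: {sorted(linear_names)}")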
# Training arguments optimized for 8x H200 GPUs with memory constraints
OUTPUT_DIR = "./qwq-32b-finetuned-adapters"
PER_DEVICE_TRAIN_BATCH_SIZE = 8 # Increase BS after halving seq length again
GRADIENT_ACCUMULATION_STEPS = 6 # Decrease accumulation (8*8*6 = 384)
# Global batch size = PER_DEVICE_TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS * NumGPUs
# Example: 8 * 6 * 8 = 384
LEARNING_RATE = 3e-5 # Slightly higher LR for larger batch size
EPOCHS = 1 # Start with 1 epoch, increase cautiously
MAX_SEQ_LENGTH = 4096 # Halved sequence length again
LOGGING_STEPS = 50 # Log metrics every 50 steps
SAVE_STEPS = 500 # Save a checkpoint every 500 steps
OPTIMIZER = "adamw_bnb_8bit" # Use 8-bit optimizer to save significant memory
WARMUP_RATIO = 0.03
LR_SCHEDULER_TYPE = "cosine"
# H200-specific optimizations (8x setup)
USE_FLASH_ATTN = True # Enable Flash Attention 2 for H200s
USE_SEQUENCE_PARALLEL = False # Disable when using FSDP
USE_BETTER_TRANSFORMERS = True # Use better transformers for optimized kernels
DATALOADER_NUM_WORKERS = 8 # Reduced workers to avoid CPU contention
TOKENIZATION_NUM_WORKERS = 224 # Maximum worker count for tokenization
USE_ACTIVATION_CHECKPOINTING = True # Enable activation checkpointing to save memory with long sequences
# Advanced distributed training options for 8x GPUs
USE_FSDP = True # Enable FSDP
FSDP_CONFIG = {
"fsdp_offload_params": False, # Disable CPU Offload
"fsdp_sharding_strategy": 1, # 1 = FULL_SHARD
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"fsdp_transformer_layer_cls_to_wrap": [Qwen2DecoderLayer.__name__] if Qwen2DecoderLayer else [],
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
"fsdp_backward_prefetch": "backward_post", # Changed from backward_pre
"fsdp_forward_prefetch": False, # Disabled forward prefetch
"fsdp_activation_checkpointing": [Qwen2DecoderLayer.__name__] if Qwen2DecoderLayer else [], # Use FSDP activation checkpointing
}
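# If Qwen2DecoderLayer could not be imported, a size-based auto-wrap policy is a possible
# fallback (a sketch only; key names may differ across transformers versions):
#   FSDP_CONFIG = {
#       "fsdp_sharding_strategy": 1,
#       "fsdp_auto_wrap_policy": "SIZE_BASED_WRAP",
#       "fsdp_min_num_params": 100_000_000,
#       "fsdp_state_dict_type": "SHARDED_STATE_DICT",
#   }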
# WandB Integration
REPORT_TO_WANDB = True # Set to False to disable WandB reporting
WANDB_PROJECT_NAME = "QwQ-32B-Finetune-8xH200" # Updated for 8x GPUs
WANDB_ENTITY = None # Set to your username or team name if needed
# Determine report_to destination
report_to = "none"
if REPORT_TO_WANDB:
# Disable WandB in all processes except rank 0 in distributed mode
if distributed_mode and int(os.environ.get("LOCAL_RANK", 0)) != 0:
logging.info(f"Rank {os.environ.get('RANK', '?')}: Disabling WandB")
os.environ["WANDB_DISABLED"] = "true"
report_to = "none" # Explicitly set to none for non-main processes
else:
# Main process or non-distributed mode, attempt WandB initialization
try:
import wandb
logging.info("Initializing WandB directly...")
run_name = f"qwq-32b-finetune-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
if wandb.run is None:
try:
wandb.init(
project=WANDB_PROJECT_NAME,
entity=WANDB_ENTITY,
name=run_name,
config={
"model_name": MODEL_ID,
"batch_size": PER_DEVICE_TRAIN_BATCH_SIZE,
"gradient_accumulation_steps": GRADIENT_ACCUMULATION_STEPS,
"learning_rate": LEARNING_RATE,
"epochs": EPOCHS,
"sequence_length": MAX_SEQ_LENGTH,
"lora_r": LORA_R,
"lora_alpha": LORA_ALPHA,
}
)
logging.info(f"WandB initialized: {wandb.run.name} (ID: {wandb.run.id})")
report_to = "wandb"
except Exception as e:
logging.error(f"WandB initialization error: {str(e)}")
report_to = "tensorboard"
else:
logging.info(f"Using existing WandB run: {wandb.run.name} (ID: {wandb.run.id})")
report_to = "wandb"
except ImportError:
logging.warning("WandB package not installed. Reporting to TensorBoard.")
report_to = "tensorboard"
except Exception as wandb_init_e:
logging.error(f"General WandB setup error: {wandb_init_e}")
report_to = "tensorboard"
# If WandB reporting is disabled, set report_to accordingly
elif not distributed_mode:
report_to = "tensorboard"
logging.info("WandB reporting disabled. Reporting to TensorBoard.")
else: # If WandB is disabled and it IS distributed
report_to = "none"
logging.info("WandB reporting disabled for this distributed rank.")
# Quantization (QLoRA)
USE_4BIT_QUANTIZATION = False # Disable QLoRA due to FSDP incompatibility
BNB_4BIT_COMPUTE_DTYPE = "bfloat16" # Use bfloat16 if supported, else float16
BNB_4BIT_QUANT_TYPE = "nf4"
# --- Check Optional Dependencies (Define flags globally) ---
FLASH_ATTN_AVAILABLE = False
BETTER_TRANSFORMERS_AVAILABLE = False
try:
import flash_attn
FLASH_ATTN_AVAILABLE = True
logging.info("Flash Attention available - will be used if enabled.")
except ImportError:
logging.warning("Flash Attention not available. Install with 'pip install flash-attn'")
try:
from optimum.bettertransformer import BetterTransformer
BETTER_TRANSFORMERS_AVAILABLE = True
logging.info("Better Transformers available - will be used if enabled.")
except ImportError:
logging.warning("Better Transformers not available. Install with 'pip install optimum'")
# --- Check Dataset ---
if not os.path.exists(DATASET_PATH):
logging.error(f"Dataset not found at {DATASET_PATH}. Run preprocess_data.py first.")
exit(1)
logging.info(f"Loading dataset from {DATASET_PATH}...")
# Load dataset normally
dataset = load_from_disk(DATASET_PATH)
# Apply truncation if needed
if MAX_EXAMPLES > 0 and len(dataset) > MAX_EXAMPLES:
logging.info(f"Truncating dataset to {MAX_EXAMPLES} examples")
indices = list(range(min(MAX_EXAMPLES, len(dataset))))
dataset = dataset.select(indices)
logging.info(f"Dataset loaded: {dataset} with {len(dataset)} examples")
# --- Tokenizer ---
logging.info(f"Loading tokenizer for {MODEL_ID}...")
# Enable fast tokenizer and optimizations
tokenizer = AutoTokenizer.from_pretrained(
MODEL_ID,
use_fast=True, # Explicitly request the fast Rust-based tokenizer
trust_remote_code=True,
# model_max_length=MAX_SEQ_LENGTH,
padding_side="right",
)
# Log tokenizer type for verification
if hasattr(tokenizer, 'is_fast') and tokenizer.is_fast:
logging.info(f"Successfully loaded fast tokenizer (Rust implementation): {type(tokenizer).__name__}")
# Fast tokenizers are automatically parallel in dataset.map() when num_proc > 1
logging.info(f"Fast tokenizer will use parallel processing during dataset.map() with {TOKENIZATION_NUM_WORKERS} workers")
else:
logging.warning(f"Using Python tokenizer: {type(tokenizer).__name__}")
logging.warning("Python tokenizers are slower than Rust-based fast tokenizers")
# Check and set pad token based on Qwen documentation (<|endoftext|>)
# Qwen models might have this set correctly, but we verify.
EXPECTED_PAD_TOKEN = "<|endoftext|>"
if tokenizer.pad_token is None or tokenizer.pad_token != EXPECTED_PAD_TOKEN:
logging.warning(f"Tokenizer pad_token is missing or not '{EXPECTED_PAD_TOKEN}'. Setting pad_token='{EXPECTED_PAD_TOKEN}'.")
tokenizer.pad_token = EXPECTED_PAD_TOKEN
# Enable padding and truncation defaults for batch processing
tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
tokenizer.padding_side = "right" # Typically "right" for decoder-only models like Qwen
# Log tokenizer configuration
logging.info(f"Tokenizer configuration:")
logging.info(f" - Type: {'Fast' if hasattr(tokenizer, 'is_fast') and tokenizer.is_fast else 'Python'}")
logging.info(f" - Pad token: {tokenizer.pad_token}")
logging.info(f" - EOS token: {tokenizer.eos_token}") # Should be <|im_end|>
logging.info(f" - Vocab size: {tokenizer.vocab_size}")
logging.info(f" - Model max length: {tokenizer.model_max_length}")
logging.info(f" - Padding side: {tokenizer.padding_side}")
# Define parallel preprocessing function for the dataset
def preprocess_function(examples):
return tokenizer(
examples["text"],
padding="max_length",
truncation=True,
max_length=MAX_SEQ_LENGTH,
return_tensors=None, # Return Python lists for dataset
)
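# Quick tokenization sanity check (optional sketch; assumes the dataset exposes a "text"
# column, which the map() call below also relies on):
#   sample = preprocess_function({"text": [dataset[0]["text"]]})
#   logging.info(f"Sample tokenized length: {len(sample['input_ids'][0])} (should equal MAX_SEQ_LENGTH with padding)")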
# Create a cache directory for tokenized datasets
TOKENIZED_DATASET_CACHE_DIR = os.path.join(os.path.dirname(DATASET_PATH), "tokenized_cache")
os.makedirs(TOKENIZED_DATASET_CACHE_DIR, exist_ok=True)
tokenized_dataset_path = os.path.join(TOKENIZED_DATASET_CACHE_DIR, "tokenized_dataset")
# Create a file to signal tokenization completion
tokenization_done_file = os.path.join(TOKENIZED_DATASET_CACHE_DIR, "tokenization_complete")
# Function to clean up temporary files in dataset directory
def delete_existing_tmp_files():
"""Find and delete any existing tmp files in dataset directory"""
# Look for tmp files in dataset directory
tmp_files = glob.glob(os.path.join(DATASET_PATH, "tmp*"))
if tmp_files:
logging.info(f"Found {len(tmp_files)} existing tmp files, removing...")
for tmp_file in tmp_files:
try:
if os.path.isdir(tmp_file):
shutil.rmtree(tmp_file)
else:
os.remove(tmp_file)
logging.info(f"Removed: {tmp_file}")
except Exception as e:
logging.warning(f"Could not remove {tmp_file}: {str(e)}")
else:
logging.info("No existing tmp files found")
# Check if we're in distributed mode and get rank
if distributed_mode:
rank = int(os.environ.get("RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
is_main_process = rank == 0
logging.info(f"Rank {rank}/{world_size}: Preparing for dataset processing")
else:
is_main_process = True
rank = 0
world_size = 1
local_rank = 0
# Clean up temp files - only on main process to avoid conflicts
if is_main_process:
delete_existing_tmp_files()
# Also remove the tokenization_done_file if it exists
if os.path.exists(tokenization_done_file):
os.remove(tokenization_done_file)
logging.info(f"Rank {rank}: Removed old tokenization completion marker")
# Only tokenize on main process (rank 0) to avoid redundant work
need_tokenization = False
# Check if tokenized dataset already exists
if os.path.exists(tokenized_dataset_path) and os.path.isdir(tokenized_dataset_path):
# --- Dataset Exists ---
logging.info(f"Rank {rank}: Found existing tokenized dataset at {tokenized_dataset_path}")
path_to_load = tokenized_dataset_path # All ranks will load from the persistent path
need_tokenization = False
# Rank 0 ensures completion marker exists
if is_main_process and not os.path.exists(tokenization_done_file):
total_original_examples = "unknown"
try:
from datasets import load_dataset_builder # Local import
original_dataset_info = load_dataset_builder(DATASET_PATH).info
total_original_examples = original_dataset_info.splits['train'].num_examples
except Exception as info_e:
logging.warning(f"Rank {rank}: Could not get original dataset info: {info_e}")
try:
# Get size of existing loaded dataset (approximate if needed)
# This requires loading a small part or metadata, might be slow
# For now, let's just mark it as existing
# loaded_size = len(load_from_disk(tokenized_dataset_path, keep_in_memory=False))
loaded_size = "unknown (loaded existing)"
with open(tokenization_done_file, "w") as f:
f.write(f"Tokenization assumed complete (loaded existing) at {datetime.now().isoformat()}\n")
f.write(f"Processed {loaded_size} examples out of {total_original_examples}\n")
logging.info(f"Rank {rank}: Created tokenization completion marker as it was missing.")
except Exception as file_e:
logging.error(f"Rank {rank}: Failed to create missing completion marker: {file_e}")
# Proceeding anyway, but other ranks might hang if they rely solely on the file
# Non-main ranks still need to wait for the marker to be sure Rank 0 checked/created it
elif not is_main_process:
logging.info(f"Rank {rank}: Waiting for main process confirmation via marker file...")
max_wait_time = 300 # Shorter wait, just confirming file exists
wait_start = time.time()
while not os.path.exists(tokenization_done_file):
if time.time() - wait_start > max_wait_time:
logging.error(f"Rank {rank}: Timed out waiting for marker file from Rank 0.")
raise TimeoutError("Marker file wait timeout")
time.sleep(5)
logging.info(f"Rank {rank}: Marker file found.")
elif is_main_process: # Tokenized doesn't exist, Rank 0 needs to create it
logging.info(f"Rank {rank}: Tokenization required. Proceeding with tokenization...")
need_tokenization = True
path_to_load = None
elif distributed_mode: # Tokenized doesn't exist, non-main ranks need to wait
logging.info(f"Rank {rank}: Tokenization required. Waiting for main process...")
need_tokenization = True
path_to_load = tokenized_dataset_path
# --- Perform Tokenization (if needed by Rank 0) ---
if need_tokenization and is_main_process:
tokenized_dataset_obj = None # Use a distinct name for the object returned by map
try:
# Process the dataset using dataset.map with internal parallelism
start_time = time.time() # Define start_time here
# Standard tokenization with caching enabled
logging.info(f"Rank {rank}: Starting tokenization using dataset.map with {TOKENIZATION_NUM_WORKERS} workers.")
tokenized_dataset_obj = dataset.map(
preprocess_function,
batched=True,
batch_size=1000,
num_proc=TOKENIZATION_NUM_WORKERS,
remove_columns=["text"],
load_from_cache_file=True, # Allow using cache file if it exists
desc=f"Tokenizing dataset ({TOKENIZATION_NUM_WORKERS} workers)"
)
elapsed = time.time() - start_time
logging.info(f"Rank {rank}: Tokenization successful in {elapsed:.2f} seconds.")
# If tokenization was successful:
if tokenized_dataset_obj is not None:
logging.info(f"Rank {rank}: Dataset tokenization completed.")
# Save directly to final path
logging.info(f"Rank {rank}: Saving tokenized dataset to {tokenized_dataset_path}...")
save_start = time.time()
# Ensure target directory doesn't exist (needed for clean save)
if os.path.exists(tokenized_dataset_path):
shutil.rmtree(tokenized_dataset_path)
tokenized_dataset_obj.save_to_disk(tokenized_dataset_path)
save_elapsed = time.time() - save_start
logging.info(f"Rank {rank}: Tokenized dataset saved in {save_elapsed:.2f} seconds.")
# Create completion marker file ONLY after successful save
with open(tokenization_done_file, "w") as f:
f.write(f"Tokenization completed and saved at {datetime.now().isoformat()}\n")
logging.info(f"Rank {rank}: Created tokenization completion marker")
# Keep the result in memory for Rank 0 for immediate use
dataset = tokenized_dataset_obj
path_to_load = None # Rank 0 uses the in-memory object directly
except Exception as e:
logging.error(f"Rank {rank}: Tokenization failed: {e}")
import traceback
logging.error(traceback.format_exc())
# Create done file indicating failure
with open(tokenization_done_file, "w") as f:
f.write(f"Tokenization FAILED at {datetime.now().isoformat()}\nError: {e}")
raise RuntimeError("Tokenization failed.") from e
# --- Load Dataset (All Ranks) ---
# This block runs on all ranks *after* rank 0 has either tokenized the data or confirmed an existing tokenized dataset
dataset_for_trainer = None # Use a distinct variable name for clarity
if path_to_load: # If path_to_load is set (means rank 0 copied or non-main rank needs to load)
if not is_main_process and need_tokenization:
# Non-main ranks wait for the done file if tokenization was required
logging.info(f"Rank {rank}: Waiting for tokenization completion signal (already checked for existence)...")
# Wait logic already happened if we got here and path_to_load is set
pass
# All ranks with a path_to_load proceed to load
logging.info(f"Rank {rank}: Loading dataset from {path_to_load}...")
load_start_time = time.time()
try:
# Load without forcing into memory initially
dataset_for_trainer = load_from_disk(path_to_load, keep_in_memory=False)
load_elapsed = time.time() - load_start_time
logging.info(f"Rank {rank}: Successfully loaded dataset in {load_elapsed:.2f}s. Length: {len(dataset_for_trainer)}")
except Exception as e:
logging.error(f"Rank {rank}: CRITICAL - Failed to load dataset from {path_to_load}: {e}")
raise
elif is_main_process and not need_tokenization:
# Rank 0 found an existing tokenized dataset and path_to_load points at it
# (in practice the `if path_to_load:` branch above already covers this case)
logging.info(f"Rank {rank}: Loading existing tokenized dataset from {path_to_load}...")
try:
dataset_for_trainer = load_from_disk(path_to_load, keep_in_memory=False)
logging.info(f"Rank {rank}: Successfully loaded existing tokenized dataset.")
except Exception as e:
logging.error(f"Rank {rank}: CRITICAL - Failed to load existing tokenized dataset from {path_to_load}: {e}")
raise
elif is_main_process and need_tokenization:
# Rank 0 just tokenized, 'dataset' variable already holds the result in memory
logging.info(f"Rank {rank}: Using in-memory dataset from successful tokenization.")
dataset_for_trainer = dataset # Use the object directly
else:
# Should not happen
logging.error(f"Rank {rank}: Dataset path logic error. path_to_load='{path_to_load}', need_tokenization={need_tokenization}")
raise RuntimeError("Dataset preparation failed - logic error.")
# At this point, 'dataset' on all ranks should hold the ready-to-use data.
# Synchronize processes after dataset is ready on all ranks
if distributed_mode:
try:
logging.info(f"Rank {rank}: Synchronizing after dataset preparation...")
dist.barrier()
logging.info(f"Rank {rank}: Synchronization complete.")
except Exception as sync_e:
logging.error(f"Rank {rank}: Synchronization after dataset prep failed: {sync_e}")
raise
# --- Helper Function for Memory Check ---
def check_gpu_memory_utilization():
"""Check and report GPU memory utilization"""
if not torch.cuda.is_available():
logging.info("CUDA not available, skipping GPU memory check.")
return 0 # Return 0 utilization if no GPU
logging.info("==== GPU MEMORY UTILIZATION CHECK ====")
total_allocated_gb = 0
total_reserved_gb = 0
total_capacity_gb = 0
try:
for i in range(torch.cuda.device_count()):
free_mem, total_mem = torch.cuda.mem_get_info(i)
allocated = torch.cuda.memory_allocated(i)
reserved = torch.cuda.memory_reserved(i)
free_gb = free_mem / (1024**3)
total_gb = total_mem / (1024**3)
allocated_gb = allocated / (1024**3)
reserved_gb = reserved / (1024**3)
utilized_pct = (1 - free_mem/total_mem) * 100 if total_mem > 0 else 0
total_allocated_gb += allocated_gb
total_reserved_gb += reserved_gb
total_capacity_gb += total_gb
logging.info(f"GPU {i}: Allocated {allocated_gb:.1f}GB, Reserved {reserved_gb:.1f}GB, "
f"Free {free_gb:.1f}GB, Total {total_gb:.1f}GB, "
f"Utilization: {utilized_pct:.1f}%")
avg_utilization = (total_allocated_gb / total_capacity_gb) * 100 if total_capacity_gb > 0 else 0
logging.info(f"OVERALL: Using {total_allocated_gb:.1f}GB / {total_capacity_gb:.1f}GB ({avg_utilization:.1f}% allocated)")
logging.info("========================================")
return avg_utilization
except Exception as e:
logging.error(f"Error checking GPU memory: {e}")
return 0 # Return 0 on error
# --- Model Loading & Preparation (Runs on ALL ranks) ---
logging.info(f"Rank {rank}: Loading model: {MODEL_ID}...")
# 1. Load Model Configuration
config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
logging.info("Enabling YaRN scaling in model configuration.")
config.rope_scaling = {
"type": "yarn",
"factor": 4.0,
"original_max_position_embeddings": 32768,
}
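# With YaRN, the effective context window is roughly factor * original_max_position_embeddings,
# i.e. 4.0 * 32768 = 131072 tokens here, though the training data below is truncated to MAX_SEQ_LENGTH = 4096.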
# Determine torch dtype
torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
# Set device_map based on distributed mode
# When using FSDP, device_map should typically be None or "auto", FSDP handles placement.
if USE_FSDP:
device_map = None
logging.info("FSDP enabled: Setting device_map=None")
elif distributed_mode:
local_rank = int(os.environ.get("LOCAL_RANK", 0))
device_map = {"": local_rank}
logging.info(f"Rank {rank}: DDP mode: Loading model on device {local_rank}")
else:
device_map = "auto"
logging.info("Rank {rank}: Single process mode: Using automatic device mapping")
# Configure Flash Attention and other optimizations
use_flash_attn = USE_FLASH_ATTN and FLASH_ATTN_AVAILABLE
attn_implementation = "flash_attention_2" if use_flash_attn else None
# Configure Quantization if enabled
# quantization_config = None
# if USE_4BIT_QUANTIZATION:
# logging.info("Configuring 4-bit quantization (QLoRA)...")
# compute_dtype = getattr(torch, BNB_4BIT_COMPUTE_DTYPE)
# quantization_config = BitsAndBytesConfig(
# load_in_4bit=True,
# bnb_4bit_quant_type=BNB_4BIT_QUANT_TYPE,
# bnb_4bit_compute_dtype=compute_dtype,
# bnb_4bit_use_double_quant=True, # Qwen models often benefit from double quant
# )
# # Override torch_dtype when using quantization as recommended
# # torch_dtype = None
# logging.info(f"4-bit quantization config created: type={BNB_4BIT_QUANT_TYPE}, compute={BNB_4BIT_COMPUTE_DTYPE}")
# Configure model loading kwargs
model_load_kwargs = {
"config": config,
"device_map": device_map,
"low_cpu_mem_usage": True,
"trust_remote_code": True,
}
if use_flash_attn:
model_load_kwargs["attn_implementation"] = "flash_attention_2"
# if quantization_config:
# model_load_kwargs["quantization_config"] = quantization_config
# Always set torch_dtype when not using quantization
model_load_kwargs["torch_dtype"] = torch_dtype
# Log memory before loading
# ... (memory logging logic - keep as is) ...
# Load the model
model = None # Initialize model variable
try:
logging.info(f"Rank {rank}: Calling AutoModelForCausalLM.from_pretrained...")
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
**model_load_kwargs
)
logging.info(f"Rank {rank}: Base model loaded successfully on device: {model.device if device_map is None else 'CPU/Multi'}")
# Ensure consistent dtype before FSDP wrapping (which happens in trainer.train)
if torch_dtype == torch.bfloat16:
logging.info("Explicitly casting model to bfloat16...")
model = model.to(torch.bfloat16)
# Apply Better Transformers optimization
use_better_transformers_flag = USE_BETTER_TRANSFORMERS and BETTER_TRANSFORMERS_AVAILABLE
if use_better_transformers_flag:
try:
logging.info("Applying BetterTransformer optimizations...")
model = BetterTransformer.transform(model)
logging.info("BetterTransformer optimizations applied successfully")
except Exception as bt_e:
logging.warning(f"Could not apply BetterTransformer optimizations: {str(bt_e)}")
# Apply activation checkpointing
if USE_ACTIVATION_CHECKPOINTING:
try:
logging.info("Enabling activation checkpointing...")
model.gradient_checkpointing_enable()
logging.info("Activation checkpointing enabled.")
except Exception as ac_e:
logging.warning(f"Could not enable activation checkpointing: {str(ac_e)}")
# Log model config and check memory utilization
logging.info(f"Rank {rank}: Model setup complete.")
check_gpu_memory_utilization() # Defined earlier in this script
except Exception as model_load_e: # Correct indentation for except
logging.error(f"Rank {rank}: Failed during model loading or preparation: {model_load_e}")
import traceback
logging.error(traceback.format_exc())
# Attempt to clean up distributed env before raising
if distributed_mode and dist.is_initialized():
try: dist.destroy_process_group()
except: pass
raise # Re-raise error
# --- LoRA Configuration ---
# ... (LoRA config - keep as is) ...
peft_config = LoraConfig(
r=LORA_R,
lora_alpha=LORA_ALPHA,
lora_dropout=LORA_DROPOUT,
target_modules=LORA_TARGET_MODULES,
bias="none",
task_type="CAUSAL_LM",
)
# --- Synchronize AFTER model loading & PEFT config ---
if distributed_mode:
try:
logging.info(f"Rank {rank}: Synchronizing after model loading...")
dist.barrier()
logging.info(f"Rank {rank}: Synchronization after model loading complete.")
except Exception as sync_e:
logging.error(f"Rank {rank}: Synchronization after model loading failed: {sync_e}")
raise
# --- Define Training Arguments ---
# (Determine determined_run_name logic here as before)
determined_run_name = None
if REPORT_TO_WANDB and is_main_process:
try:
import wandb
if wandb.run is not None: determined_run_name = wandb.run.name
except Exception: pass # Ignore errors here, handled by report_to
base_training_args = {
# ... (all base args, including max_seq_length) ...
"output_dir": OUTPUT_DIR,
"per_device_train_batch_size": PER_DEVICE_TRAIN_BATCH_SIZE,
"gradient_accumulation_steps": GRADIENT_ACCUMULATION_STEPS,
"optim": OPTIMIZER,
"save_steps": SAVE_STEPS,
"logging_steps": LOGGING_STEPS,
"learning_rate": LEARNING_RATE,
"num_train_epochs": EPOCHS,
"max_steps": -1,
"fp16": False,
"bf16": torch_dtype == torch.bfloat16, # Use previously determined dtype
"max_grad_norm": 0.3,
"warmup_ratio": WARMUP_RATIO,
"group_by_length": False, # Explicitly disable to prevent pre-computation hang
"lr_scheduler_type": LR_SCHEDULER_TYPE,
"report_to": report_to,
"save_total_limit": 3,
"logging_first_step": True,
**({"run_name": determined_run_name} if determined_run_name is not None else {}),
"fsdp": "full_shard" if USE_FSDP else "", # Pass FSDP strategy string (removed offload)
"fsdp_config": FSDP_CONFIG if USE_FSDP else {}, # Pass FSDP config dict
"dataloader_num_workers": DATALOADER_NUM_WORKERS,
"resume_from_checkpoint": "auto",
"save_strategy": "steps",
"load_best_model_at_end": False,
"metric_for_best_model": None,
"dataset_text_field": "text",
"packing": False,
"max_seq_length": MAX_SEQ_LENGTH,
# Memory/Performance Optimizations
"gradient_checkpointing_kwargs": {"use_reentrant": False}, # More stable checkpointing for FSDP activation checkpointing
"ddp_find_unused_parameters": False, # Should be False for FSDP
"tf32": True, # Enable TF32 for faster compute on compatible GPUs
}
training_arguments = SFTConfig(**base_training_args)
logging.info(f"Rank {rank}: Training arguments (SFTConfig) created.")
# --- Define Callbacks ---
# Create memory monitoring callback
class MemoryMonitorCallback(TrainerCallback):
def on_step_end(self, args, state, control, **kwargs):
if state.global_step % 10 == 0: # Log every 10 steps
if torch.cuda.is_available():
gc.collect()
torch.cuda.empty_cache()
rank = int(os.environ.get("RANK", 0))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
try:
free_mem, total_mem = torch.cuda.mem_get_info(local_rank)
free_gb = free_mem / (1024**3)
used_gb = (total_mem - free_mem) / (1024**3)
total_gb = total_mem / (1024**3)
reserved = torch.cuda.memory_reserved(local_rank) / (1024**3)
allocated = torch.cuda.memory_allocated(local_rank) / (1024**3)
logging.info(f"Rank {rank}: Memory at step {state.global_step}: "
f"{free_gb:.1f}GB free, {used_gb:.1f}GB used, {total_gb:.1f}GB total, "
f"{reserved:.1f}GB reserved, {allocated:.1f}GB allocated")
except Exception as mem_e:
logging.warning(f"Rank {rank}: Could not get memory info: {mem_e}")
return control
memory_monitor = MemoryMonitorCallback()
# Create a special first step callback with WandB support
class FirstStepCallback(TrainerCallback):
def __init__(self):
self.first_step_start_time = None
self.progress_indicators = 0
self.update_interval = 60 # Check every minute
self.last_update_time = time.time()
def on_step_begin(self, args, state, control, **kwargs):
if state.global_step == 0:
self.first_step_start_time = time.time()
logging.info(f"FIRST STEP STARTING at {datetime.now().strftime('%H:%M:%S')}")
if REPORT_TO_WANDB and 'wandb' in sys.modules:
try:
import wandb # Import locally
if wandb.run:
wandb.log({"training_status": "first_step_started"})
except Exception as log_e: logging.warning(f"Wandb log error: {log_e}")
return control
def on_step_end(self, args, state, control, **kwargs):
if state.global_step == 0:
if self.first_step_start_time is None: # Should not happen, but safeguard
logging.warning("First step ended but start time was not recorded.")
return control
duration = time.time() - self.first_step_start_time
logging.info(f"FIRST STEP COMPLETED at {datetime.now().strftime('%H:%M:%S')} (took {duration:.2f} seconds)")
if REPORT_TO_WANDB and 'wandb' in sys.modules:
try:
import wandb # Import locally
if wandb.run:
wandb.log({
"training_status": "first_step_completed",
"first_step_duration": duration
})
except Exception as log_e: logging.warning(f"Wandb log error: {log_e}")
return control
def on_substep_end(self, args, state, control, **kwargs):
# This tracks progress within a step (during gradient accumulation)
current_time = time.time()
# Only report for the first step/substep and only from rank 0
if (self.first_step_start_time is not None and
state.global_step == 0 and
current_time - self.last_update_time >= self.update_interval and
(not distributed_mode or int(os.environ.get("LOCAL_RANK", 0)) == 0)):
self.progress_indicators += 1
elapsed = current_time - self.first_step_start_time
logging.info(f"First step still in progress... ({elapsed:.1f}s elapsed, progress indicator {self.progress_indicators})")
if REPORT_TO_WANDB and 'wandb' in sys.modules:
try:
import wandb # Import locally
if wandb.run:
wandb.log({
"training_status": "first_step_in_progress",
"first_step_elapsed": elapsed,
"progress_indicator": self.progress_indicators
})
except Exception as log_e: logging.warning(f"Wandb log error: {log_e}")
self.last_update_time = current_time
return control
first_step_callback = FirstStepCallback()
# Add WandB logging callback if WandB is enabled
wandb_callback = None # Initialize
if REPORT_TO_WANDB and 'wandb' in sys.modules and (not distributed_mode or int(os.environ.get("LOCAL_RANK", 0)) == 0):
try:
# **** FULL WandBLoggingCallback Class Definition ****
class WandBLoggingCallback(TrainerCallback):
"""Logs comprehensive training metrics and progress to Weights & Biases"""
def __init__(self):
self.training_start_time = None
self.last_log_time = None
self.total_steps = None
self.samples_seen = 0
self.tokens_seen = 0
self.current_epoch = 0
self.epoch_start_time = None
self.step_history = [] # For tracking steps/second
self.global_tokens_per_second = 0
self.progress_table = None # Initialize table to None
def on_train_begin(self, args, state, control, **kwargs):
"""Log hyperparameters and initialize tracking at the start of training"""
if not (REPORT_TO_WANDB and 'wandb' in sys.modules): return # Check if WandB should be used
try:
import wandb # Import locally
if not wandb.run:
logging.warning("WandBCallback: Wandb not initialized in on_train_begin.")
return
except ImportError:
logging.warning("WandBCallback: wandb not imported, cannot log on_train_begin")
return
self.training_start_time = time.time()
self.epoch_start_time = time.time()
self.last_log_time = time.time()
# Calculate total expected steps
if args.max_steps > 0:
self.total_steps = args.max_steps
else:
# Use trainer passed in kwargs if available (prioritize 'trainer' key)
trainer_instance = kwargs.get('trainer', None)
if trainer_instance is None:
trainer_instance = kwargs.get('model', None) # Fallback to 'model' key
dataset_length = 0
if trainer_instance and hasattr(trainer_instance, 'train_dataset') and trainer_instance.train_dataset is not None:
try:
dataset_length = len(trainer_instance.train_dataset)
except Exception as len_e:
logging.warning(f"WandBCallback: Error getting dataset length: {len_e}")
else:
logging.warning("WandBCallback: Could not access train_dataset length during on_train_begin.")
batch_size = args.per_device_train_batch_size
accumulation = args.gradient_accumulation_steps
world_size = int(os.environ.get("WORLD_SIZE", 1))
global_batch_denom = (batch_size * world_size * accumulation)
if dataset_length > 0 and global_batch_denom > 0:
self.total_steps = (dataset_length // global_batch_denom) * args.num_train_epochs
else:
self.total_steps = -1 # Indicate unknown total steps
# Log key hyperparameters
config = {
"model_name": MODEL_ID,
"lora_r": LORA_R,
"lora_alpha": LORA_ALPHA,
"batch_size": PER_DEVICE_TRAIN_BATCH_SIZE,
"grad_accum": GRADIENT_ACCUMULATION_STEPS,
"effective_batch": PER_DEVICE_TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS,
"global_batch": PER_DEVICE_TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS * world_size,
"learning_rate": LEARNING_RATE,
"seq_length": MAX_SEQ_LENGTH,
"epochs": EPOCHS,
"total_steps_estimated": self.total_steps,
"optimizer": OPTIMIZER,
"warmup_ratio": WARMUP_RATIO,
"scheduler": LR_SCHEDULER_TYPE,
}
wandb.config.update(config)
# Initialize training progress table
columns = ["step", "epoch", "loss", "lr", "tokens/sec", "eta", "elapsed_hrs"]
self.progress_table = wandb.Table(columns=columns)
# Log training start
wandb.log({"training_status": "started"})
logging.info(f"Training started - total estimated steps: {self.total_steps}")
def on_log(self, args, state, control, logs=None, **kwargs):
"""Log detailed metrics and progress after each logging step"""
if not (REPORT_TO_WANDB and 'wandb' in sys.modules): return # Check if WandB should be used
try:
import wandb # Import locally
if not wandb.run:
logging.warning("WandBCallback: Wandb run not active during on_log.")
return
except ImportError:
logging.warning("WandBCallback: wandb not imported, cannot log on_log")
return
if not logs:
return
# Format metrics for logging
metrics = {}
for k, v in logs.items():
if isinstance(v, (int, float)):
metrics[k] = v
elif hasattr(v, "item"): # Handle tensors
try: metrics[k] = v.item()
except: pass
if not metrics:
return
# Calculate time-based metrics
current_time = time.time()
if self.training_start_time is None: self.training_start_time = current_time # Safeguard
elapsed_time = current_time - self.training_start_time
elapsed_hrs = elapsed_time / 3600
# Estimate tokens processed
batch_size = args.per_device_train_batch_size
grad_accum = args.gradient_accumulation_steps
world_size = int(os.environ.get("WORLD_SIZE", 1))
global_batch_size = batch_size * grad_accum * world_size
tokens_per_step = global_batch_size * MAX_SEQ_LENGTH # Use MAX_SEQ_LENGTH from outer scope
# Update tokens seen
steps_since_last = state.global_step - (self.step_history[-1][0] if self.step_history else -1)
if steps_since_last <= 0: steps_since_last = 1 # Avoid issues on first log
new_tokens = tokens_per_step * steps_since_last
self.tokens_seen += new_tokens
# Calculate throughput
time_since_last = current_time - (self.last_log_time if self.last_log_time else current_time)
if time_since_last <= 0: time_since_last = 1.0 # Avoid division by zero
tokens_per_second = new_tokens / time_since_last
# Update rolling average of tokens/sec
alpha = 0.1
self.global_tokens_per_second = alpha * tokens_per_second + (1 - alpha) * self.global_tokens_per_second
# Track epoch progress
if "epoch" in metrics:
new_epoch = int(metrics["epoch"])
if new_epoch > self.current_epoch:
epoch_time = current_time - (self.epoch_start_time if self.epoch_start_time else current_time)
self.epoch_start_time = current_time
self.current_epoch = new_epoch
wandb.log({"epoch/duration_sec": epoch_time}, step=state.global_step)
logging.info(f"Epoch {self.current_epoch-1} completed in {epoch_time:.2f} seconds")
epoch_float = metrics["epoch"]
epoch_progress = epoch_float - int(epoch_float)
metrics["epoch_progress"] = epoch_progress * 100
# Estimate time remaining
eta_hours = float('nan')
if self.total_steps and self.total_steps > 0 and state.global_step > 0:
progress_fraction = state.global_step / self.total_steps
if progress_fraction > 1e-6: # Avoid division by zero early on
eta_seconds = elapsed_time / progress_fraction - elapsed_time
eta_hours = eta_seconds / 3600
metrics["eta_hours"] = eta_hours
# Add additional calculated metrics
metrics.update({
"progress/elapsed_hours": elapsed_hrs,
"progress/tokens_total": self.tokens_seen,
"performance/tokens_per_second": tokens_per_second,
"performance/tokens_per_second_avg": self.global_tokens_per_second,
"performance/global_batch_size": global_batch_size,
})
# Add GPU utilization if available
if torch.cuda.is_available():
try:
local_rank = int(os.environ.get("LOCAL_RANK", 0))
# Note: torch.cuda.utilization might not be available/reliable
# metrics["gpu/utilization"] = torch.cuda.utilization(local_rank)
metrics["gpu/memory_allocated_gb"] = torch.cuda.memory_allocated(local_rank) / 1e9
metrics["gpu/memory_reserved_gb"] = torch.cuda.memory_reserved(local_rank) / 1e9
except Exception as gpu_e:
logging.debug(f"Could not log GPU metrics: {gpu_e}")
# Log all metrics to wandb
wandb.log(metrics, step=state.global_step)
# Add row to progress table
if self.progress_table is not None:
loss_val = metrics.get("loss", float("nan"))
lr_val = metrics.get("learning_rate", float("nan"))
epoch_val = metrics.get("epoch", 0)
tokens_sec = metrics.get("performance/tokens_per_second_avg", 0)
self.progress_table.add_data(
state.global_step,
f"{epoch_val:.2f}",
f"{loss_val:.4f}",
f"{lr_val:.2e}",
f"{tokens_sec:.1f}",
f"{eta_hours:.1f} hrs",
f"{elapsed_hrs:.1f} hrs"
)
# Log the updated progress table (might be verbose, consider less frequent logging)
# wandb.log({"training_progress": self.progress_table}, step=state.global_step)
# Print concise metrics to console
log_info = (
f"Step {state.global_step}"
+ (f"/{self.total_steps} ({100 * state.global_step / self.total_steps:.1f}%)" if self.total_steps and self.total_steps > 0 else "")
+ f" | Loss: {loss_val:.4f} | LR: {lr_val:.2e} | Epoch: {epoch_val:.2f}"
+ f" | Tokens/sec: {tokens_sec:.1f}"
+ (f" | ETA: {eta_hours:.1f}h" if not math.isnan(eta_hours) else "")
)
logging.info(log_info)
# Update time tracking
self.last_log_time = current_time
self.step_history.append((state.global_step, current_time))
if len(self.step_history) > 100: # Keep only recent history
self.step_history = self.step_history[-100:]
def on_train_end(self, args, state, control, **kwargs):
"""Log final statistics at the end of training"""
if not (REPORT_TO_WANDB and 'wandb' in sys.modules): return # Check if WandB should be used
try:
import wandb # Import locally
if not wandb.run:
logging.warning("WandBCallback: Wandb run not active during on_train_end.")
return
except ImportError:
logging.warning("WandBCallback: wandb not imported, cannot log on_train_end")
return
total_time = time.time() - (self.training_start_time if self.training_start_time else time.time())
hours = total_time / 3600
final_stats = {
"training_status": "completed",
"total_steps_completed": state.global_step,
"total_epochs_completed": self.current_epoch,
"total_training_time_hours": hours,
"total_tokens_processed": self.tokens_seen,
"average_tokens_per_second": self.tokens_seen / total_time if total_time > 0 else 0
}
wandb.log(final_stats, step=state.global_step) # Log at final step
wandb.run.summary.update({
"training_duration_hours": hours,
"total_steps": state.global_step,
"total_epochs": self.current_epoch,
"total_tokens": self.tokens_seen
})
logging.info(f"Training complete - {hours:.2f} hours, {state.global_step} steps, {self.tokens_seen:,} tokens processed")
# **** End of WandBLoggingCallback Definition ****
# Create callback instance
wandb_callback = WandBLoggingCallback()
logging.info("Enhanced WandB logging callback created")
except Exception as e:
logging.error(f"Error creating WandB callback: {str(e)}")
wandb_callback = None
# Create the list of callbacks
trainer_callbacks = [memory_monitor, first_step_callback] # Use the instance names
if wandb_callback:
trainer_callbacks.append(wandb_callback)
logging.info("Added WandB callback to trainer")
# trainer_callbacks = [] # Temporarily disable all callbacks
# --- Initialize Trainer ---
logging.info(f"Rank {rank}: Initializing SFTTrainer...")
trainer = None
try:
trainer = SFTTrainer(
model=model,
# Using processing_class as per user confirmation
processing_class=tokenizer,
args=training_arguments,
train_dataset=dataset_for_trainer,
peft_config=peft_config,
# Ensure this matches whether the collator is defined/needed
preprocess_logits_for_metrics=None,
callbacks=trainer_callbacks, # Pass the list here
)
logging.info(f"Rank {rank}: SFTTrainer initialized successfully.")
except Exception as e:
logging.error(f"Rank {rank}: Error initializing SFTTrainer: {e}")
import traceback
logging.error(traceback.format_exc())
if distributed_mode and dist.is_initialized():
try: dist.destroy_process_group()
except: pass
raise
# --- Train ---
if trainer is not None:
logging.info(f"Beginning trainer.train() call at {datetime.now().strftime('%H:%M:%S')}")
try:
trainer.train()
logging.info(f"Training finished successfully at {datetime.now().strftime('%H:%M:%S')}")
except Exception as e:
logging.error(f"Exception during training: {e}")
import traceback
logging.error(traceback.format_exc())
if distributed_mode and dist.is_initialized():
try:
dist.destroy_process_group()
logging.info("Destroyed process group after training error")
except:
pass
raise
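# Note: the merge step below loads adapters directly from OUTPUT_DIR. Depending on the
# trl/transformers versions, the final adapter may only exist inside a checkpoint-* subfolder;
# an explicit save (a sketch, not part of the original repro) would guarantee it is there:
#   if trainer is not None:
#       trainer.save_model(OUTPUT_DIR)  # writes the adapter weights + config into OUTPUT_DIR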
# --- Merge Model and Save Full Model ---
logging.info("Merging adapter weights into base model...")
# Clear some memory first if needed (especially if not using massive GPUs)
# del model
# del trainer
# torch.cuda.empty_cache()
# Reload the base model (consider lower precision to save VRAM during merge)
logging.info(f"Reloading base model ({MODEL_ID}) for merging...")
base_model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
config=config, # Ensure YaRN config is used if applied during training
torch_dtype=torch.bfloat16, # Or torch.float16, adjust as needed
low_cpu_mem_usage=True, # Helps with large models
trust_remote_code=True,
device_map=None, # Load onto CPU first to potentially save GPU VRAM if needed
attn_implementation="flash_attention_2"
)
# Load the PEFT model with adapters
logging.info(f"Loading PEFT model from {OUTPUT_DIR}...")
merged_model = PeftModel.from_pretrained(
base_model,
OUTPUT_DIR,
device_map=None, # Load onto CPU first
)
# Merge the adapter weights
logging.info("Merging LoRA weights...")
merged_model = merged_model.merge_and_unload()
logging.info("LoRA weights merged.")
# Define path for the full model save
full_model_save_path = os.path.join(OUTPUT_DIR, "final_merged_model")
# Save the merged model
logging.info(f"Saving merged model to {full_model_save_path}...")
merged_model.save_pretrained(full_model_save_path)
logging.info("Merged model saved.")
# Save the tokenizer associated with the merged model
logging.info(f"Saving tokenizer to {full_model_save_path}...")
tokenizer.save_pretrained(full_model_save_path)
logging.info("Tokenizer saved.")
logging.info(f"Fine-tuning and merging process complete. Full model saved to {full_model_save_path}")
# --- Notes on Inference and Resuming Training ---
logging.info("Training Checkpoint Notes:")
logging.info(f" • Checkpoints saved to: {OUTPUT_DIR}")
logging.info(f" • To resume training from the latest checkpoint, just rerun this script")
logging.info(f" (resume_from_checkpoint='auto' will automatically find the latest checkpoint)")
logging.info(f" • To resume from a specific checkpoint, set resume_from_checkpoint='path/to/checkpoint'")
# --- Notes on Inference ---
# To use the trained adapters:
# from peft import PeftModel
# base_model = AutoModelForCausalLM.from_pretrained(MODEL_ID, ...)
# model = PeftModel.from_pretrained(base_model, final_adapter_path)
# model = model.merge_and_unload() # Optional: merge adapters for faster inference
# Then use model and tokenizer for generation.
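# Example generation with the merged model (a rough sketch; prompt and settings are placeholders):
# model = AutoModelForCausalLM.from_pretrained(full_model_save_path, torch_dtype=torch.bfloat16, device_map="auto")
# tokenizer = AutoTokenizer.from_pretrained(full_model_save_path)
# inputs = tokenizer("Write a Python function that reverses a string.", return_tensors="pt").to(model.device)
# outputs = model.generate(**inputs, max_new_tokens=256)
# print(tokenizer.decode(outputs[0], skip_special_tokens=True))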