|
import os |
|
import torch |
|
from datasets import load_from_disk, concatenate_datasets, Dataset |
|
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig |
|
from peft import LoraConfig, prepare_model_for_kbit_training, PeftModel |
|
from peft.tuners.lora import LoraLayer |
|
from trl import SFTTrainer, SFTConfig |
|
import logging |
|
import torch.distributed as dist |
|
from datetime import timedelta, datetime |
|
import time |
|
from transformers.trainer import TrainerCallback |
|
import gc |
|
import sys |
|
import shutil |
|
import glob |
|
import threading |
|
import multiprocessing |
|
import subprocess |
|
import tempfile |
|
import json |
|
import random |
|
import math |
|
import queue |
|
import numpy as np |
|
|
|
|
|
try: |
|
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer |
|
except ImportError: |
|
logging.warning("Could not import Qwen2DecoderLayer. FSDP wrapping might fail.") |
|
Qwen2DecoderLayer = None |
|
|
|
|
|
logging.basicConfig( |
|
level=logging.INFO, |
|
format='%(asctime)s - %(levelname)s - %(message)s', |
|
datefmt='%Y-%m-%d %H:%M:%S', |
|
stream=sys.stdout, |
|
force=True |
|
) |
|
|
|
|
|
temp_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "temp") |
|
os.makedirs(temp_dir, exist_ok=True) |
|
logging.info(f"Using temporary directory: {temp_dir}") |
|
|
|
|
|
os.environ["TMPDIR"] = temp_dir |
|
os.environ["TEMP"] = temp_dir |
|
os.environ["TMP"] = temp_dir |
|
|
|
|
|
hf_datasets_cache_path = os.path.join(temp_dir, "hf_datasets_cache") |
|
transformers_cache_path = os.path.join(temp_dir, "transformers_cache") |
|
hf_home_path = os.path.join(temp_dir, "hf_home") |
|
os.makedirs(hf_datasets_cache_path, exist_ok=True) |
|
os.makedirs(transformers_cache_path, exist_ok=True) |
|
os.makedirs(hf_home_path, exist_ok=True) |
|
|
|
os.environ["HF_DATASETS_CACHE"] = hf_datasets_cache_path |
|
os.environ["TRANSFORMERS_CACHE"] = transformers_cache_path |
|
os.environ["HF_HOME"] = hf_home_path |
|
logging.info(f"Hugging Face Datasets cache directed to: {hf_datasets_cache_path}") |
|
logging.info(f"Hugging Face Transformers cache directed to: {transformers_cache_path}") |
|
|
|
|
|
os.environ["ARROW_DEFAULT_MEMORY_POOL"] = "system" |
|
logging.info("Configured temporary directory and cache locations.") |
|
|
|
|
|
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:512" |
|
|
|
if "PYTORCH_NO_CUDA_MEMORY_CACHING" in os.environ: |
|
del os.environ["PYTORCH_NO_CUDA_MEMORY_CACHING"] |
|
|
|
os.environ["NCCL_BLOCKING_WAIT"] = "1" |
|
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1" |
|
os.environ["NCCL_TIMEOUT"] = "3600" |
|
|
|
|
|
def init_distributed(): |
|
try: |
|
|
|
if "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1: |
|
|
|
if int(os.environ.get("LOCAL_RANK", 0)) == 0: |
|
logging.info("Setting PyTorch memory optimizations for H200 GPUs") |
|
|
|
if torch.cuda.is_available(): |
|
torch.cuda.empty_cache() |
|
logging.info("CUDA cache cleared") |
|
|
|
local_rank = int(os.environ.get("LOCAL_RANK", 0)) |
|
world_size = int(os.environ.get("WORLD_SIZE", 1)) |
|
rank = int(os.environ.get("RANK", 0)) |
|
|
|
logging.info(f"Initializing distributed training for 8x H200s. Rank: {rank}, Local Rank: {local_rank}, World Size: {world_size}") |
|
|
|
|
|
torch.cuda.set_device(local_rank) |
|
logging.info(f"Setting device {local_rank} for process rank {rank}") |
|
|
|
|
|
timeout = timedelta(hours=3) |
|
|
|
|
|
dist.init_process_group( |
|
backend='nccl', |
|
init_method='env://', |
|
timeout=timeout, |
|
rank=rank, |
|
world_size=world_size |
|
) |
|
|
|
|
|
if dist.is_initialized(): |
|
logging.info(f"Successfully initialized distributed process group. Rank: {rank}, Device: {torch.cuda.current_device()}") |
|
|
|
logging.info(f"NCCL Version: {torch.cuda.nccl.version() if hasattr(torch.cuda, 'nccl') else 'unknown'}") |
|
logging.info(f"CUDA Device Count: {torch.cuda.device_count()}") |
|
logging.info(f"CUDA Device Name: {torch.cuda.get_device_name(local_rank)}") |
|
else: |
|
logging.error(f"Failed to initialize distributed process group. Rank: {rank}") |
|
|
|
|
|
try: |
|
device_ids = [local_rank] |
|
dist.barrier(device_ids=device_ids) |
|
logging.info(f"Communication test successful. Process {rank} on device {local_rank} can communicate.") |
|
except Exception as e: |
|
logging.error(f"Communication test failed. Processes cannot communicate: {str(e)}. Rank: {rank}") |
|
raise |
|
|
|
return True |
|
else: |
|
logging.info("Not running in distributed mode.") |
|
return False |
|
except Exception as e: |
|
logging.error(f"Error initializing distributed environment: {str(e)}") |
|
raise |
|
|
|
|
|
distributed_mode = init_distributed() |
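# This script expects one process per GPU. A typical launch (example command only;
# adjust the script name and paths to your environment) would be:
#
#   torchrun --nproc_per_node=8 finetune_qwq32b.py
#
# torchrun sets RANK, LOCAL_RANK and WORLD_SIZE, which init_distributed() reads above.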
|
|
|
|
|
|
|
|
|
MODEL_ID = "Qwen/QwQ-32B" |
|
|
|
|
|
DATASET_PATH = "./processed_datasets/combined_code_finetune_data" |
|
|
|
|
|
MAX_EXAMPLES = -1 |
|
|
|
|
|
LORA_R = 64 |
|
LORA_ALPHA = 128 |
|
LORA_DROPOUT = 0.05 |
|
|
|
|
|
LORA_TARGET_MODULES = [ |
|
"q_proj", |
|
"k_proj", |
|
"v_proj", |
|
"o_proj", |
|
"gate_proj", |
|
"up_proj", |
|
"down_proj", |
|
|
|
|
|
] |
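# LoRA notes: alpha/r = 128/64 = 2, a common scaling choice. The target modules above cover
# all attention projections (q/k/v/o) and all MLP projections (gate/up/down) in each decoder layer.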
|
|
|
|
|
OUTPUT_DIR = "./qwq-32b-finetuned-adapters" |
|
PER_DEVICE_TRAIN_BATCH_SIZE = 8 |
|
GRADIENT_ACCUMULATION_STEPS = 6 |
|
|
|
|
|
LEARNING_RATE = 3e-5 |
|
EPOCHS = 1 |
|
MAX_SEQ_LENGTH = 4096 |
|
LOGGING_STEPS = 50 |
|
SAVE_STEPS = 500 |
|
OPTIMIZER = "adamw_bnb_8bit" |
|
WARMUP_RATIO = 0.03 |
|
LR_SCHEDULER_TYPE = "cosine" |
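# Effective batch size: with the 8 GPUs this script targets, each optimizer step processes
# PER_DEVICE_TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS * 8 = 8 * 6 * 8 = 384 sequences,
# i.e. roughly 384 * 4096 ≈ 1.57M tokens per step at the configured sequence length.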
|
|
|
|
|
USE_FLASH_ATTN = True |
|
USE_SEQUENCE_PARALLEL = False |
|
USE_BETTER_TRANSFORMERS = True |
|
DATALOADER_NUM_WORKERS = 8 |
|
TOKENIZATION_NUM_WORKERS = 224 |
|
USE_ACTIVATION_CHECKPOINTING = True |
|
|
|
|
|
USE_FSDP = True |
|
FSDP_CONFIG = { |
|
"fsdp_offload_params": False, |
|
"fsdp_sharding_strategy": 1, |
|
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP", |
|
"fsdp_transformer_layer_cls_to_wrap": [Qwen2DecoderLayer.__name__] if Qwen2DecoderLayer else [], |
|
"fsdp_state_dict_type": "SHARDED_STATE_DICT", |
|
"fsdp_backward_prefetch": "backward_post", |
|
"fsdp_forward_prefetch": False, |
|
"fsdp_activation_checkpointing": [Qwen2DecoderLayer.__name__] if Qwen2DecoderLayer else [], |
|
} |
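# FSDP notes (hedged; exact key names vary across transformers/accelerate versions):
# - fsdp_sharding_strategy=1 corresponds to FULL_SHARD (parameters, gradients and optimizer
#   state are all sharded across ranks).
# - SHARDED_STATE_DICT keeps checkpoints sharded per rank, avoiding a full gather of the
#   32B model on a single process at save time.
# - fsdp_activation_checkpointing is given layer class names here; recent accelerate versions
#   expect a boolean "activation_checkpointing" flag instead, and enabling it together with
#   model.gradient_checkpointing_enable() below can conflict, so verify against the installed versions.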
|
|
|
|
|
REPORT_TO_WANDB = True |
|
WANDB_PROJECT_NAME = "QwQ-32B-Finetune-8xH200" |
|
WANDB_ENTITY = None |
|
|
|
|
|
report_to = "none" |
|
if REPORT_TO_WANDB: |
|
|
|
if distributed_mode and int(os.environ.get("LOCAL_RANK", 0)) != 0: |
|
logging.info(f"Rank {os.environ.get('RANK', '?')}: Disabling WandB") |
|
os.environ["WANDB_DISABLED"] = "true" |
|
report_to = "none" |
|
else: |
|
|
|
try: |
|
import wandb |
|
logging.info("Initializing WandB directly...") |
|
run_name = f"qwq-32b-finetune-{datetime.now().strftime('%Y%m%d-%H%M%S')}" |
|
if wandb.run is None: |
|
try: |
|
wandb.init( |
|
project=WANDB_PROJECT_NAME, |
|
entity=WANDB_ENTITY, |
|
name=run_name, |
|
config={ |
|
"model_name": MODEL_ID, |
|
"batch_size": PER_DEVICE_TRAIN_BATCH_SIZE, |
|
"gradient_accumulation_steps": GRADIENT_ACCUMULATION_STEPS, |
|
"learning_rate": LEARNING_RATE, |
|
"epochs": EPOCHS, |
|
"sequence_length": MAX_SEQ_LENGTH, |
|
"lora_r": LORA_R, |
|
"lora_alpha": LORA_ALPHA, |
|
} |
|
) |
|
logging.info(f"WandB initialized: {wandb.run.name} (ID: {wandb.run.id})") |
|
report_to = "wandb" |
|
except Exception as e: |
|
logging.error(f"WandB initialization error: {str(e)}") |
|
report_to = "tensorboard" |
|
else: |
|
logging.info(f"Using existing WandB run: {wandb.run.name} (ID: {wandb.run.id})") |
|
report_to = "wandb" |
|
except ImportError: |
|
logging.warning("WandB package not installed. Reporting to TensorBoard.") |
|
report_to = "tensorboard" |
|
except Exception as wandb_init_e: |
|
logging.error(f"General WandB setup error: {wandb_init_e}") |
|
report_to = "tensorboard" |
|
|
|
elif not distributed_mode: |
|
report_to = "tensorboard" |
|
logging.info("WandB reporting disabled. Reporting to TensorBoard.") |
|
else: |
|
report_to = "none" |
|
logging.info("WandB reporting disabled for this distributed rank.") |
|
|
|
|
|
USE_4BIT_QUANTIZATION = False |
|
BNB_4BIT_COMPUTE_DTYPE = "bfloat16" |
|
BNB_4BIT_QUANT_TYPE = "nf4" |
|
|
|
|
|
FLASH_ATTN_AVAILABLE = False |
|
BETTER_TRANSFORMERS_AVAILABLE = False |
|
try: |
|
import flash_attn |
|
FLASH_ATTN_AVAILABLE = True |
|
logging.info("Flash Attention available - will be used if enabled.") |
|
except ImportError: |
|
logging.warning("Flash Attention not available. Install with 'pip install flash-attn'") |
|
|
|
try: |
|
from optimum.bettertransformer import BetterTransformer |
|
BETTER_TRANSFORMERS_AVAILABLE = True |
|
logging.info("Better Transformers available - will be used if enabled.") |
|
except ImportError: |
|
logging.warning("Better Transformers not available. Install with 'pip install optimum'") |
|
|
|
|
|
if not os.path.exists(DATASET_PATH): |
|
logging.error(f"Dataset not found at {DATASET_PATH}. Run preprocess_data.py first.") |
|
exit(1) |
|
|
|
logging.info(f"Loading dataset from {DATASET_PATH}...") |
|
|
|
|
|
dataset = load_from_disk(DATASET_PATH) |
|
|
|
|
|
if MAX_EXAMPLES > 0 and len(dataset) > MAX_EXAMPLES: |
|
logging.info(f"Truncating dataset to {MAX_EXAMPLES} examples") |
|
indices = list(range(min(MAX_EXAMPLES, len(dataset)))) |
|
dataset = dataset.select(indices) |
|
|
|
logging.info(f"Dataset loaded: {dataset} with {len(dataset)} examples") |
|
|
|
|
|
logging.info(f"Loading tokenizer for {MODEL_ID}...") |
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained( |
|
MODEL_ID, |
|
use_fast=True, |
|
trust_remote_code=True, |
|
|
|
padding_side="right", |
|
) |
|
|
|
|
|
if hasattr(tokenizer, 'is_fast') and tokenizer.is_fast: |
|
logging.info(f"Successfully loaded fast tokenizer (Rust implementation): {type(tokenizer).__name__}") |
|
|
|
logging.info(f"Fast tokenizer will use parallel processing during dataset.map() with {TOKENIZATION_NUM_WORKERS} workers") |
|
else: |
|
logging.warning(f"Using Python tokenizer: {type(tokenizer).__name__}") |
|
logging.warning("Python tokenizers are slower than Rust-based fast tokenizers") |
|
|
|
|
|
|
|
EXPECTED_PAD_TOKEN = "<|endoftext|>" |
|
if tokenizer.pad_token is None or tokenizer.pad_token != EXPECTED_PAD_TOKEN: |
|
logging.warning(f"Tokenizer pad_token is missing or not '{EXPECTED_PAD_TOKEN}'. Setting pad_token='{EXPECTED_PAD_TOKEN}'.") |
|
tokenizer.pad_token = EXPECTED_PAD_TOKEN |
|
|
|
|
|
# Safety net: fall back to the EOS token if no pad token could be set above.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.padding_side = "right" |
|
|
|
|
|
logging.info(f"Tokenizer configuration:") |
|
logging.info(f" - Type: {'Fast' if hasattr(tokenizer, 'is_fast') and tokenizer.is_fast else 'Python'}") |
|
logging.info(f" - Pad token: {tokenizer.pad_token}") |
|
logging.info(f" - EOS token: {tokenizer.eos_token}") |
|
logging.info(f" - Vocab size: {tokenizer.vocab_size}") |
|
logging.info(f" - Model max length: {tokenizer.model_max_length}") |
|
logging.info(f" - Padding side: {tokenizer.padding_side}") |
|
|
|
|
|
def preprocess_function(examples): |
|
return tokenizer( |
|
examples["text"], |
|
padding="max_length", |
|
truncation=True, |
|
max_length=MAX_SEQ_LENGTH, |
|
return_tensors=None, |
|
) |
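# return_tensors=None keeps the tokenizer output as plain Python lists, which is what
# datasets.map(batched=True) expects. padding="max_length" pads every example to
# MAX_SEQ_LENGTH, which keeps batch shapes static but spends compute on pad tokens
# compared to dynamic padding in the collator.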
|
|
|
|
|
TOKENIZED_DATASET_CACHE_DIR = os.path.join(os.path.dirname(DATASET_PATH), "tokenized_cache") |
|
os.makedirs(TOKENIZED_DATASET_CACHE_DIR, exist_ok=True) |
|
tokenized_dataset_path = os.path.join(TOKENIZED_DATASET_CACHE_DIR, "tokenized_dataset") |
|
|
|
|
|
tokenization_done_file = os.path.join(TOKENIZED_DATASET_CACHE_DIR, "tokenization_complete") |
|
|
|
|
|
def delete_existing_tmp_files(): |
|
"""Find and delete any existing tmp files in dataset directory""" |
|
|
|
tmp_files = glob.glob(os.path.join(DATASET_PATH, "tmp*")) |
|
|
|
if tmp_files: |
|
logging.info(f"Found {len(tmp_files)} existing tmp files, removing...") |
|
for tmp_file in tmp_files: |
|
try: |
|
if os.path.isdir(tmp_file): |
|
shutil.rmtree(tmp_file) |
|
else: |
|
os.remove(tmp_file) |
|
logging.info(f"Removed: {tmp_file}") |
|
except Exception as e: |
|
logging.warning(f"Could not remove {tmp_file}: {str(e)}") |
|
else: |
|
logging.info("No existing tmp files found") |
|
|
|
|
|
if distributed_mode: |
|
rank = int(os.environ.get("RANK", "0")) |
|
world_size = int(os.environ.get("WORLD_SIZE", "1")) |
|
local_rank = int(os.environ.get("LOCAL_RANK", "0")) |
|
is_main_process = rank == 0 |
|
logging.info(f"Rank {rank}/{world_size}: Preparing for dataset processing") |
|
else: |
|
is_main_process = True |
|
rank = 0 |
|
world_size = 1 |
|
local_rank = 0 |
|
|
|
|
|
if is_main_process: |
|
delete_existing_tmp_files() |
|
|
|
if os.path.exists(tokenization_done_file): |
|
os.remove(tokenization_done_file) |
|
logging.info(f"Rank {rank}: Removed old tokenization completion marker") |
|
|
|
|
|
need_tokenization = False |
|
|
|
|
|
if os.path.exists(tokenized_dataset_path) and os.path.isdir(tokenized_dataset_path): |
|
|
|
logging.info(f"Rank {rank}: Found existing tokenized dataset at {tokenized_dataset_path}") |
|
path_to_load = tokenized_dataset_path |
|
need_tokenization = False |
|
|
|
|
|
if is_main_process and not os.path.exists(tokenization_done_file): |
|
total_original_examples = "unknown" |
|
try: |
|
from datasets import load_dataset_builder |
|
original_dataset_info = load_dataset_builder(DATASET_PATH).info |
|
total_original_examples = original_dataset_info.splits['train'].num_examples |
|
except Exception as info_e: |
|
logging.warning(f"Rank {rank}: Could not get original dataset info: {info_e}") |
|
try: |
|
|
|
|
|
|
|
|
|
loaded_size = "unknown (loaded existing)" |
|
with open(tokenization_done_file, "w") as f: |
|
f.write(f"Tokenization assumed complete (loaded existing) at {datetime.now().isoformat()}\n") |
|
f.write(f"Processed {loaded_size} examples out of {total_original_examples}\n") |
|
logging.info(f"Rank {rank}: Created tokenization completion marker as it was missing.") |
|
except Exception as file_e: |
|
logging.error(f"Rank {rank}: Failed to create missing completion marker: {file_e}") |
|
|
|
|
|
|
|
elif not is_main_process: |
|
logging.info(f"Rank {rank}: Waiting for main process confirmation via marker file...") |
|
max_wait_time = 300 |
|
wait_start = time.time() |
|
while not os.path.exists(tokenization_done_file): |
|
if time.time() - wait_start > max_wait_time: |
|
logging.error(f"Rank {rank}: Timed out waiting for marker file from Rank 0.") |
|
raise TimeoutError("Marker file wait timeout") |
|
time.sleep(5) |
|
logging.info(f"Rank {rank}: Marker file found.") |
|
|
|
elif is_main_process: |
|
logging.info(f"Rank {rank}: Tokenization required. Proceeding with tokenization...") |
|
need_tokenization = True |
|
path_to_load = None |
|
|
|
elif distributed_mode: |
|
logging.info(f"Rank {rank}: Tokenization required. Waiting for main process...") |
|
need_tokenization = True |
|
path_to_load = tokenized_dataset_path |
|
|
|
|
|
if need_tokenization and is_main_process: |
|
tokenized_dataset_obj = None |
|
try: |
|
|
|
start_time = time.time() |
|
|
|
|
|
logging.info(f"Rank {rank}: Starting tokenization using dataset.map with {TOKENIZATION_NUM_WORKERS} workers.") |
|
|
|
tokenized_dataset_obj = dataset.map( |
|
preprocess_function, |
|
batched=True, |
|
batch_size=1000, |
|
num_proc=TOKENIZATION_NUM_WORKERS, |
|
remove_columns=["text"], |
|
load_from_cache_file=True, |
|
desc=f"Tokenizing dataset ({TOKENIZATION_NUM_WORKERS} workers)" |
|
) |
|
|
|
elapsed = time.time() - start_time |
|
logging.info(f"Rank {rank}: Tokenization successful in {elapsed:.2f} seconds.") |
|
|
|
|
|
if tokenized_dataset_obj is not None: |
|
logging.info(f"Rank {rank}: Dataset tokenization completed.") |
|
|
|
|
|
logging.info(f"Rank {rank}: Saving tokenized dataset to {tokenized_dataset_path}...") |
|
save_start = time.time() |
|
|
|
|
|
if os.path.exists(tokenized_dataset_path): |
|
shutil.rmtree(tokenized_dataset_path) |
|
|
|
tokenized_dataset_obj.save_to_disk(tokenized_dataset_path) |
|
save_elapsed = time.time() - save_start |
|
logging.info(f"Rank {rank}: Tokenized dataset saved in {save_elapsed:.2f} seconds.") |
|
|
|
|
|
with open(tokenization_done_file, "w") as f: |
|
f.write(f"Tokenization completed and saved at {datetime.now().isoformat()}\n") |
|
logging.info(f"Rank {rank}: Created tokenization completion marker") |
|
|
|
|
|
dataset = tokenized_dataset_obj |
|
path_to_load = None |
|
|
|
except Exception as e: |
|
logging.error(f"Rank {rank}: Tokenization failed: {e}") |
|
import traceback |
|
logging.error(traceback.format_exc()) |
|
|
|
with open(tokenization_done_file, "w") as f: |
|
f.write(f"Tokenization FAILED at {datetime.now().isoformat()}\nError: {e}") |
|
raise RuntimeError("Tokenization failed.") from e |
|
|
|
|
|
|
|
dataset_for_trainer = None
if path_to_load:
    if not is_main_process and need_tokenization:
        # Rank 0 is still tokenizing: wait for its completion marker before trying to load.
        logging.info(f"Rank {rank}: Waiting for tokenization completion marker from Rank 0...")
        wait_start = time.time()
        while not os.path.exists(tokenization_done_file):
            if time.time() - wait_start > 7200:
                raise TimeoutError(f"Rank {rank}: Timed out waiting for tokenization marker from Rank 0.")
            time.sleep(10)
        logging.info(f"Rank {rank}: Tokenization completion marker found.")
|
|
|
|
|
logging.info(f"Rank {rank}: Loading dataset from {path_to_load}...") |
|
load_start_time = time.time() |
|
try: |
|
|
|
dataset_for_trainer = load_from_disk(path_to_load, keep_in_memory=False) |
|
load_elapsed = time.time() - load_start_time |
|
logging.info(f"Rank {rank}: Successfully loaded dataset in {load_elapsed:.2f}s. Length: {len(dataset_for_trainer)}") |
|
except Exception as e: |
|
logging.error(f"Rank {rank}: CRITICAL - Failed to load dataset from {path_to_load}: {e}") |
|
raise |
|
elif is_main_process and not need_tokenization:
    # Fallback path: attempt to load directly from the tokenized cache location.
    logging.info(f"Rank {rank}: Loading dataset from tokenized cache at {tokenized_dataset_path}...")
    try:
        dataset_for_trainer = load_from_disk(tokenized_dataset_path, keep_in_memory=False)
        logging.info(f"Rank {rank}: Successfully loaded dataset from tokenized cache.")
    except Exception as e:
        logging.error(f"Rank {rank}: CRITICAL - Failed to load from tokenized cache {tokenized_dataset_path}: {e}")
        raise
|
elif is_main_process and need_tokenization: |
|
|
|
logging.info(f"Rank {rank}: Using in-memory dataset from successful tokenization.") |
|
dataset_for_trainer = dataset |
|
else: |
|
|
|
logging.error(f"Rank {rank}: Dataset path logic error. path_to_load='{path_to_load}', need_tokenization={need_tokenization}") |
|
raise RuntimeError("Dataset preparation failed - logic error.") |
|
|
|
|
|
|
|
|
|
if distributed_mode: |
|
try: |
|
logging.info(f"Rank {rank}: Synchronizing after dataset preparation...") |
|
dist.barrier() |
|
logging.info(f"Rank {rank}: Synchronization complete.") |
|
except Exception as sync_e: |
|
logging.error(f"Rank {rank}: Synchronization after dataset prep failed: {sync_e}") |
|
raise |
|
|
|
|
|
def check_gpu_memory_utilization(): |
|
"""Check and report GPU memory utilization""" |
|
if not torch.cuda.is_available(): |
|
logging.info("CUDA not available, skipping GPU memory check.") |
|
return 0 |
|
|
|
logging.info("==== GPU MEMORY UTILIZATION CHECK ====") |
|
total_allocated_gb = 0 |
|
total_reserved_gb = 0 |
|
total_capacity_gb = 0 |
|
|
|
try: |
|
for i in range(torch.cuda.device_count()): |
|
free_mem, total_mem = torch.cuda.mem_get_info(i) |
|
allocated = torch.cuda.memory_allocated(i) |
|
reserved = torch.cuda.memory_reserved(i) |
|
|
|
free_gb = free_mem / (1024**3) |
|
total_gb = total_mem / (1024**3) |
|
allocated_gb = allocated / (1024**3) |
|
reserved_gb = reserved / (1024**3) |
|
utilized_pct = (1 - free_mem/total_mem) * 100 if total_mem > 0 else 0 |
|
|
|
total_allocated_gb += allocated_gb |
|
total_reserved_gb += reserved_gb |
|
total_capacity_gb += total_gb |
|
|
|
logging.info(f"GPU {i}: Allocated {allocated_gb:.1f}GB, Reserved {reserved_gb:.1f}GB, " |
|
f"Free {free_gb:.1f}GB, Total {total_gb:.1f}GB, " |
|
f"Utilization: {utilized_pct:.1f}%") |
|
|
|
avg_utilization = (total_allocated_gb / total_capacity_gb) * 100 if total_capacity_gb > 0 else 0 |
|
logging.info(f"OVERALL: Using {total_allocated_gb:.1f}GB / {total_capacity_gb:.1f}GB ({avg_utilization:.1f}% allocated)") |
|
logging.info("========================================") |
|
return avg_utilization |
|
except Exception as e: |
|
logging.error(f"Error checking GPU memory: {e}") |
|
return 0 |
|
|
|
|
|
logging.info(f"Rank {rank}: Loading model: {MODEL_ID}...") |
|
|
|
|
|
config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True) |
|
logging.info("Enabling YaRN scaling in model configuration.") |
|
config.rope_scaling = { |
|
"type": "yarn", |
|
"factor": 4.0, |
|
"original_max_position_embeddings": 32768, |
|
} |
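# YaRN note: factor 4.0 over original_max_position_embeddings=32768 targets roughly a
# 131k-token context window. With MAX_SEQ_LENGTH=4096 the scaling is not strictly required
# for training itself; it is presumably kept so the saved config matches the intended
# long-context inference setup.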
|
|
|
|
|
torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 |
|
|
|
|
|
|
|
if USE_FSDP: |
|
device_map = None |
|
logging.info("FSDP enabled: Setting device_map=None") |
|
elif distributed_mode: |
|
local_rank = int(os.environ.get("LOCAL_RANK", 0)) |
|
device_map = {"": local_rank} |
|
logging.info(f"Rank {rank}: DDP mode: Loading model on device {local_rank}") |
|
else: |
|
device_map = "auto" |
|
logging.info("Rank {rank}: Single process mode: Using automatic device mapping") |
|
|
|
|
|
use_flash_attn = USE_FLASH_ATTN and FLASH_ATTN_AVAILABLE |
|
attn_implementation = "flash_attention_2" if use_flash_attn else None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model_load_kwargs = { |
|
"config": config, |
|
"device_map": device_map, |
|
"low_cpu_mem_usage": True, |
|
"trust_remote_code": True, |
|
} |
|
if use_flash_attn: |
|
model_load_kwargs["attn_implementation"] = "flash_attention_2" |
|
|
|
|
|
|
|
model_load_kwargs["torch_dtype"] = torch_dtype |
|
|
|
|
|
|
|
|
|
|
|
model = None |
|
try: |
|
logging.info(f"Rank {rank}: Calling AutoModelForCausalLM.from_pretrained...") |
|
model = AutoModelForCausalLM.from_pretrained( |
|
MODEL_ID, |
|
**model_load_kwargs |
|
) |
|
logging.info(f"Rank {rank}: Base model loaded successfully on device: {model.device if device_map is None else 'CPU/Multi'}") |
|
|
|
|
|
if torch_dtype == torch.bfloat16: |
|
logging.info("Explicitly casting model to bfloat16...") |
|
model = model.to(torch.bfloat16) |
|
|
|
|
|
use_better_transformers_flag = USE_BETTER_TRANSFORMERS and BETTER_TRANSFORMERS_AVAILABLE |
|
if use_better_transformers_flag: |
|
try: |
|
logging.info("Applying BetterTransformer optimizations...") |
|
model = BetterTransformer.transform(model) |
|
logging.info("BetterTransformer optimizations applied successfully") |
|
except Exception as bt_e: |
|
logging.warning(f"Could not apply BetterTransformer optimizations: {str(bt_e)}") |
|
|
|
|
|
if USE_ACTIVATION_CHECKPOINTING: |
|
try: |
|
logging.info("Enabling activation checkpointing...") |
|
model.gradient_checkpointing_enable() |
|
logging.info("Activation checkpointing enabled.") |
|
except Exception as ac_e: |
|
logging.warning(f"Could not enable activation checkpointing: {str(ac_e)}") |
|
|
|
|
|
logging.info(f"Rank {rank}: Model setup complete.") |
|
check_gpu_memory_utilization() |
|
|
|
except Exception as model_load_e: |
|
logging.error(f"Rank {rank}: Failed during model loading or preparation: {model_load_e}") |
|
import traceback |
|
logging.error(traceback.format_exc()) |
|
|
|
if distributed_mode and dist.is_initialized(): |
|
try: dist.destroy_process_group() |
|
except: pass |
|
raise |
|
|
|
|
|
|
|
peft_config = LoraConfig( |
|
r=LORA_R, |
|
lora_alpha=LORA_ALPHA, |
|
lora_dropout=LORA_DROPOUT, |
|
target_modules=LORA_TARGET_MODULES, |
|
bias="none", |
|
task_type="CAUSAL_LM", |
|
) |
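# The LoRA adapter is attached by SFTTrainer via this peft_config. Since USE_4BIT_QUANTIZATION
# is False, prepare_model_for_kbit_training() (imported above) is not needed and is not called here.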
|
|
|
|
|
if distributed_mode: |
|
try: |
|
logging.info(f"Rank {rank}: Synchronizing after model loading...") |
|
dist.barrier() |
|
logging.info(f"Rank {rank}: Synchronization after model loading complete.") |
|
except Exception as sync_e: |
|
logging.error(f"Rank {rank}: Synchronization after model loading failed: {sync_e}") |
|
raise |
|
|
|
|
|
|
|
determined_run_name = None |
|
if REPORT_TO_WANDB and is_main_process: |
|
try: |
|
import wandb |
|
if wandb.run is not None: determined_run_name = wandb.run.name |
|
except Exception: pass |
|
|
|
base_training_args = { |
|
|
|
"output_dir": OUTPUT_DIR, |
|
"per_device_train_batch_size": PER_DEVICE_TRAIN_BATCH_SIZE, |
|
"gradient_accumulation_steps": GRADIENT_ACCUMULATION_STEPS, |
|
"optim": OPTIMIZER, |
|
"save_steps": SAVE_STEPS, |
|
"logging_steps": LOGGING_STEPS, |
|
"learning_rate": LEARNING_RATE, |
|
"num_train_epochs": EPOCHS, |
|
"max_steps": -1, |
|
"fp16": False, |
|
"bf16": torch_dtype == torch.bfloat16, |
|
"max_grad_norm": 0.3, |
|
"warmup_ratio": WARMUP_RATIO, |
|
"group_by_length": False, |
|
"lr_scheduler_type": LR_SCHEDULER_TYPE, |
|
"report_to": report_to, |
|
"save_total_limit": 3, |
|
"logging_first_step": True, |
|
**({"run_name": determined_run_name} if determined_run_name is not None else {}), |
|
"fsdp": "full_shard" if USE_FSDP else "", |
|
"fsdp_config": FSDP_CONFIG if USE_FSDP else {}, |
|
"dataloader_num_workers": DATALOADER_NUM_WORKERS, |
|
"resume_from_checkpoint": "auto", |
|
"save_strategy": "steps", |
|
"load_best_model_at_end": False, |
|
"metric_for_best_model": None, |
|
"dataset_text_field": "text", |
|
"packing": False, |
|
"max_seq_length": MAX_SEQ_LENGTH, |
|
|
|
"gradient_checkpointing_kwargs": {"use_reentrant": False}, |
|
"ddp_find_unused_parameters": False, |
|
"tf32": True, |
|
} |
|
training_arguments = SFTConfig(**base_training_args) |
|
logging.info(f"Rank {rank}: Training arguments (SFTConfig) created.") |
|
|
|
|
|
|
|
|
|
class MemoryMonitorCallback(TrainerCallback): |
|
def on_step_end(self, args, state, control, **kwargs): |
|
if state.global_step % 10 == 0: |
|
if torch.cuda.is_available(): |
|
gc.collect() |
|
torch.cuda.empty_cache() |
|
rank = int(os.environ.get("RANK", 0)) |
|
local_rank = int(os.environ.get("LOCAL_RANK", 0)) |
|
try: |
|
free_mem, total_mem = torch.cuda.mem_get_info(local_rank) |
|
free_gb = free_mem / (1024**3) |
|
used_gb = (total_mem - free_mem) / (1024**3) |
|
total_gb = total_mem / (1024**3) |
|
reserved = torch.cuda.memory_reserved(local_rank) / (1024**3) |
|
allocated = torch.cuda.memory_allocated(local_rank) / (1024**3) |
|
logging.info(f"Rank {rank}: Memory at step {state.global_step}: " |
|
f"{free_gb:.1f}GB free, {used_gb:.1f}GB used, {total_gb:.1f}GB total, " |
|
f"{reserved:.1f}GB reserved, {allocated:.1f}GB allocated") |
|
except Exception as mem_e: |
|
logging.warning(f"Rank {rank}: Could not get memory info: {mem_e}") |
|
return control |
|
|
|
memory_monitor = MemoryMonitorCallback() |
|
|
|
|
|
class FirstStepCallback(TrainerCallback): |
|
def __init__(self): |
|
self.first_step_start_time = None |
|
self.progress_indicators = 0 |
|
self.update_interval = 60 |
|
self.last_update_time = time.time() |
|
|
|
def on_step_begin(self, args, state, control, **kwargs): |
|
if state.global_step == 0: |
|
self.first_step_start_time = time.time() |
|
logging.info(f"FIRST STEP STARTING at {datetime.now().strftime('%H:%M:%S')}") |
|
if REPORT_TO_WANDB and 'wandb' in sys.modules: |
|
try: |
|
import wandb |
|
if wandb.run: |
|
wandb.log({"training_status": "first_step_started"}) |
|
except Exception as log_e: logging.warning(f"Wandb log error: {log_e}") |
|
return control |
|
|
|
def on_step_end(self, args, state, control, **kwargs): |
|
if state.global_step == 0: |
|
if self.first_step_start_time is None: |
|
logging.warning("First step ended but start time was not recorded.") |
|
return control |
|
duration = time.time() - self.first_step_start_time |
|
logging.info(f"FIRST STEP COMPLETED at {datetime.now().strftime('%H:%M:%S')} (took {duration:.2f} seconds)") |
|
if REPORT_TO_WANDB and 'wandb' in sys.modules: |
|
try: |
|
import wandb |
|
if wandb.run: |
|
wandb.log({ |
|
"training_status": "first_step_completed", |
|
"first_step_duration": duration |
|
}) |
|
except Exception as log_e: logging.warning(f"Wandb log error: {log_e}") |
|
return control |
|
|
|
def on_substep_end(self, args, state, control, **kwargs): |
|
|
|
current_time = time.time() |
|
|
|
if (self.first_step_start_time is not None and |
|
state.global_step == 0 and |
|
current_time - self.last_update_time >= self.update_interval and |
|
(not distributed_mode or int(os.environ.get("LOCAL_RANK", 0)) == 0)): |
|
self.progress_indicators += 1 |
|
elapsed = current_time - self.first_step_start_time |
|
logging.info(f"First step still in progress... ({elapsed:.1f}s elapsed, progress indicator {self.progress_indicators})") |
|
if REPORT_TO_WANDB and 'wandb' in sys.modules: |
|
try: |
|
import wandb |
|
if wandb.run: |
|
wandb.log({ |
|
"training_status": "first_step_in_progress", |
|
"first_step_elapsed": elapsed, |
|
"progress_indicator": self.progress_indicators |
|
}) |
|
except Exception as log_e: logging.warning(f"Wandb log error: {log_e}") |
|
self.last_update_time = current_time |
|
return control |
|
|
|
first_step_callback = FirstStepCallback() |
|
|
|
|
|
wandb_callback = None |
|
if REPORT_TO_WANDB and 'wandb' in sys.modules and (not distributed_mode or int(os.environ.get("LOCAL_RANK", 0)) == 0): |
|
try: |
|
|
|
class WandBLoggingCallback(TrainerCallback): |
|
"""Logs comprehensive training metrics and progress to Weights & Biases""" |
|
|
|
def __init__(self): |
|
self.training_start_time = None |
|
self.last_log_time = None |
|
self.total_steps = None |
|
self.samples_seen = 0 |
|
self.tokens_seen = 0 |
|
self.current_epoch = 0 |
|
self.epoch_start_time = None |
|
self.step_history = [] |
|
self.global_tokens_per_second = 0 |
|
self.progress_table = None |
|
|
|
def on_train_begin(self, args, state, control, **kwargs): |
|
"""Log hyperparameters and initialize tracking at the start of training""" |
|
if not (REPORT_TO_WANDB and 'wandb' in sys.modules): return |
|
|
|
try: |
|
import wandb |
|
if not wandb.run: |
|
logging.warning("WandBCallback: Wandb not initialized in on_train_begin.") |
|
return |
|
except ImportError: |
|
logging.warning("WandBCallback: wandb not imported, cannot log on_train_begin") |
|
return |
|
|
|
self.training_start_time = time.time() |
|
self.epoch_start_time = time.time() |
|
self.last_log_time = time.time() |
|
|
|
|
|
if args.max_steps > 0: |
|
self.total_steps = args.max_steps |
|
else: |
|
|
|
trainer_instance = kwargs.get('trainer', None) |
|
if trainer_instance is None: |
|
trainer_instance = kwargs.get('model', None) |
|
|
|
dataset_length = 0 |
|
if trainer_instance and hasattr(trainer_instance, 'train_dataset') and trainer_instance.train_dataset is not None: |
|
try: |
|
dataset_length = len(trainer_instance.train_dataset) |
|
except Exception as len_e: |
|
logging.warning(f"WandBCallback: Error getting dataset length: {len_e}") |
|
else: |
|
logging.warning("WandBCallback: Could not access train_dataset length during on_train_begin.") |
|
|
|
batch_size = args.per_device_train_batch_size |
|
accumulation = args.gradient_accumulation_steps |
|
world_size = int(os.environ.get("WORLD_SIZE", 1)) |
|
global_batch_denom = (batch_size * world_size * accumulation) |
|
if dataset_length > 0 and global_batch_denom > 0: |
|
self.total_steps = (dataset_length // global_batch_denom) * args.num_train_epochs |
|
else: |
|
self.total_steps = -1 |
|
|
|
|
|
config = { |
|
"model_name": MODEL_ID, |
|
"lora_r": LORA_R, |
|
"lora_alpha": LORA_ALPHA, |
|
"batch_size": PER_DEVICE_TRAIN_BATCH_SIZE, |
|
"grad_accum": GRADIENT_ACCUMULATION_STEPS, |
|
"effective_batch": PER_DEVICE_TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS, |
|
"global_batch": PER_DEVICE_TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS * world_size, |
|
"learning_rate": LEARNING_RATE, |
|
"seq_length": MAX_SEQ_LENGTH, |
|
"epochs": EPOCHS, |
|
"total_steps_estimated": self.total_steps, |
|
"optimizer": OPTIMIZER, |
|
"warmup_ratio": WARMUP_RATIO, |
|
"scheduler": LR_SCHEDULER_TYPE, |
|
} |
|
wandb.config.update(config) |
|
|
|
|
|
columns = ["step", "epoch", "loss", "lr", "tokens/sec", "eta", "elapsed_hrs"] |
|
self.progress_table = wandb.Table(columns=columns) |
|
|
|
|
|
wandb.log({"training_status": "started"}) |
|
logging.info(f"Training started - total estimated steps: {self.total_steps}") |
|
|
|
def on_log(self, args, state, control, logs=None, **kwargs): |
|
"""Log detailed metrics and progress after each logging step""" |
|
if not (REPORT_TO_WANDB and 'wandb' in sys.modules): return |
|
|
|
try: |
|
import wandb |
|
if not wandb.run: |
|
logging.warning("WandBCallback: Wandb run not active during on_log.") |
|
return |
|
except ImportError: |
|
logging.warning("WandBCallback: wandb not imported, cannot log on_log") |
|
return |
|
|
|
if not logs: |
|
return |
|
|
|
|
|
metrics = {} |
|
for k, v in logs.items(): |
|
if isinstance(v, (int, float)): |
|
metrics[k] = v |
|
elif hasattr(v, "item"): |
|
try: metrics[k] = v.item() |
|
except: pass |
|
|
|
if not metrics: |
|
return |
|
|
|
|
|
current_time = time.time() |
|
if self.training_start_time is None: self.training_start_time = current_time |
|
elapsed_time = current_time - self.training_start_time |
|
elapsed_hrs = elapsed_time / 3600 |
|
|
|
|
|
batch_size = args.per_device_train_batch_size |
|
grad_accum = args.gradient_accumulation_steps |
|
world_size = int(os.environ.get("WORLD_SIZE", 1)) |
|
global_batch_size = batch_size * grad_accum * world_size |
|
tokens_per_step = global_batch_size * MAX_SEQ_LENGTH |
|
|
|
|
|
steps_since_last = state.global_step - (self.step_history[-1][0] if self.step_history else -1) |
|
if steps_since_last <= 0: steps_since_last = 1 |
|
new_tokens = tokens_per_step * steps_since_last |
|
self.tokens_seen += new_tokens |
|
|
|
|
|
time_since_last = current_time - (self.last_log_time if self.last_log_time else current_time) |
|
if time_since_last <= 0: time_since_last = 1.0 |
|
tokens_per_second = new_tokens / time_since_last |
|
|
|
|
|
alpha = 0.1 |
|
self.global_tokens_per_second = alpha * tokens_per_second + (1 - alpha) * self.global_tokens_per_second |
|
|
|
|
|
if "epoch" in metrics: |
|
new_epoch = int(metrics["epoch"]) |
|
if new_epoch > self.current_epoch: |
|
epoch_time = current_time - (self.epoch_start_time if self.epoch_start_time else current_time) |
|
self.epoch_start_time = current_time |
|
self.current_epoch = new_epoch |
|
wandb.log({"epoch/duration_sec": epoch_time}, step=state.global_step) |
|
logging.info(f"Epoch {self.current_epoch-1} completed in {epoch_time:.2f} seconds") |
|
|
|
epoch_float = metrics["epoch"] |
|
epoch_progress = epoch_float - int(epoch_float) |
|
metrics["epoch_progress"] = epoch_progress * 100 |
|
|
|
|
|
eta_hours = float('nan') |
|
if self.total_steps and self.total_steps > 0 and state.global_step > 0: |
|
progress_fraction = state.global_step / self.total_steps |
|
if progress_fraction > 1e-6: |
|
eta_seconds = elapsed_time / progress_fraction - elapsed_time |
|
eta_hours = eta_seconds / 3600 |
|
metrics["eta_hours"] = eta_hours |
|
|
|
|
|
metrics.update({ |
|
"progress/elapsed_hours": elapsed_hrs, |
|
"progress/tokens_total": self.tokens_seen, |
|
"performance/tokens_per_second": tokens_per_second, |
|
"performance/tokens_per_second_avg": self.global_tokens_per_second, |
|
"performance/global_batch_size": global_batch_size, |
|
}) |
|
|
|
|
|
if torch.cuda.is_available(): |
|
try: |
|
local_rank = int(os.environ.get("LOCAL_RANK", 0)) |
|
|
|
|
|
metrics["gpu/memory_allocated_gb"] = torch.cuda.memory_allocated(local_rank) / 1e9 |
|
metrics["gpu/memory_reserved_gb"] = torch.cuda.memory_reserved(local_rank) / 1e9 |
|
except Exception as gpu_e: |
|
logging.debug(f"Could not log GPU metrics: {gpu_e}") |
|
|
|
|
|
wandb.log(metrics, step=state.global_step) |
|
|
|
|
|
if self.progress_table is not None: |
|
loss_val = metrics.get("loss", float("nan")) |
|
lr_val = metrics.get("learning_rate", float("nan")) |
|
epoch_val = metrics.get("epoch", 0) |
|
tokens_sec = metrics.get("performance/tokens_per_second_avg", 0) |
|
|
|
self.progress_table.add_data( |
|
state.global_step, |
|
f"{epoch_val:.2f}", |
|
f"{loss_val:.4f}", |
|
f"{lr_val:.2e}", |
|
f"{tokens_sec:.1f}", |
|
f"{eta_hours:.1f} hrs", |
|
f"{elapsed_hrs:.1f} hrs" |
|
) |
|
|
|
|
|
|
|
|
|
log_info = ( |
|
f"Step {state.global_step}" |
|
+ (f"/{self.total_steps} ({100 * state.global_step / self.total_steps:.1f}%)" if self.total_steps and self.total_steps > 0 else "") |
|
+ f" | Loss: {loss_val:.4f} | LR: {lr_val:.2e} | Epoch: {epoch_val:.2f}" |
|
+ f" | Tokens/sec: {tokens_sec:.1f}" |
|
+ (f" | ETA: {eta_hours:.1f}h" if not math.isnan(eta_hours) else "") |
|
) |
|
logging.info(log_info) |
|
|
|
|
|
self.last_log_time = current_time |
|
self.step_history.append((state.global_step, current_time)) |
|
if len(self.step_history) > 100: |
|
self.step_history = self.step_history[-100:] |
|
|
|
def on_train_end(self, args, state, control, **kwargs): |
|
"""Log final statistics at the end of training""" |
|
if not (REPORT_TO_WANDB and 'wandb' in sys.modules): return |
|
|
|
try: |
|
import wandb |
|
if not wandb.run: |
|
logging.warning("WandBCallback: Wandb run not active during on_train_end.") |
|
return |
|
except ImportError: |
|
logging.warning("WandBCallback: wandb not imported, cannot log on_train_end") |
|
return |
|
|
|
total_time = time.time() - (self.training_start_time if self.training_start_time else time.time()) |
|
hours = total_time / 3600 |
|
|
|
final_stats = { |
|
"training_status": "completed", |
|
"total_steps_completed": state.global_step, |
|
"total_epochs_completed": self.current_epoch, |
|
"total_training_time_hours": hours, |
|
"total_tokens_processed": self.tokens_seen, |
|
"average_tokens_per_second": self.tokens_seen / total_time if total_time > 0 else 0 |
|
} |
|
wandb.log(final_stats, step=state.global_step) |
|
|
|
wandb.run.summary.update({ |
|
"training_duration_hours": hours, |
|
"total_steps": state.global_step, |
|
"total_epochs": self.current_epoch, |
|
"total_tokens": self.tokens_seen |
|
}) |
|
logging.info(f"Training complete - {hours:.2f} hours, {state.global_step} steps, {self.tokens_seen:,} tokens processed") |
|
|
|
|
|
|
|
wandb_callback = WandBLoggingCallback() |
|
logging.info("Enhanced WandB logging callback created") |
|
except Exception as e: |
|
logging.error(f"Error creating WandB callback: {str(e)}") |
|
wandb_callback = None |
|
|
|
|
|
trainer_callbacks = [memory_monitor, first_step_callback] |
|
if wandb_callback: |
|
trainer_callbacks.append(wandb_callback) |
|
logging.info("Added WandB callback to trainer") |
|
|
|
|
|
|
|
logging.info(f"Rank {rank}: Initializing SFTTrainer...") |
|
|
|
trainer = None |
|
try: |
|
trainer = SFTTrainer( |
|
model=model, |
|
|
|
processing_class=tokenizer, |
|
args=training_arguments, |
|
train_dataset=dataset_for_trainer, |
|
peft_config=peft_config, |
|
|
|
preprocess_logits_for_metrics=None, |
|
callbacks=trainer_callbacks, |
|
) |
|
logging.info(f"Rank {rank}: SFTTrainer initialized successfully.") |
|
except Exception as e: |
|
logging.error(f"Rank {rank}: Error initializing SFTTrainer: {e}") |
|
import traceback |
|
logging.error(traceback.format_exc()) |
|
if distributed_mode and dist.is_initialized(): |
|
try: dist.destroy_process_group() |
|
except: pass |
|
raise |
|
|
|
|
|
if trainer is not None:
    # Resume automatically from the most recent checkpoint in OUTPUT_DIR, if one exists.
    from transformers.trainer_utils import get_last_checkpoint
    last_checkpoint = get_last_checkpoint(OUTPUT_DIR) if os.path.isdir(OUTPUT_DIR) else None
    logging.info(f"Beginning trainer.train() call at {datetime.now().strftime('%H:%M:%S')} (resume_from_checkpoint={last_checkpoint})")
    try:
        trainer.train(resume_from_checkpoint=last_checkpoint)
|
logging.info(f"Training finished successfully at {datetime.now().strftime('%H:%M:%S')}") |
|
except Exception as e: |
|
logging.error(f"Exception during training: {e}") |
|
import traceback |
|
logging.error(traceback.format_exc()) |
|
if distributed_mode and dist.is_initialized(): |
|
try: |
|
dist.destroy_process_group() |
|
logging.info("Destroyed process group after training error") |
|
except: |
|
pass |
|
raise |
|
|
|
|
|
logging.info("Merging adapter weights into base model...") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logging.info(f"Reloading base model ({MODEL_ID}) for merging...") |
|
base_model = AutoModelForCausalLM.from_pretrained( |
|
MODEL_ID, |
|
config=config, |
|
torch_dtype=torch.bfloat16, |
|
low_cpu_mem_usage=True, |
|
trust_remote_code=True, |
|
device_map=None, |
|
attn_implementation="flash_attention_2" |
|
) |
|
|
|
|
|
logging.info(f"Loading PEFT model from {OUTPUT_DIR}...") |
|
merged_model = PeftModel.from_pretrained( |
|
base_model, |
|
OUTPUT_DIR, |
|
device_map=None, |
|
) |
|
|
|
|
|
logging.info("Merging LoRA weights...") |
|
merged_model = merged_model.merge_and_unload() |
|
logging.info("LoRA weights merged.") |
|
|
|
|
|
full_model_save_path = os.path.join(OUTPUT_DIR, "final_merged_model") |
|
|
|
|
|
logging.info(f"Saving merged model to {full_model_save_path}...") |
|
merged_model.save_pretrained(full_model_save_path) |
|
logging.info("Merged model saved.") |
|
|
|
|
|
logging.info(f"Saving tokenizer to {full_model_save_path}...") |
|
tokenizer.save_pretrained(full_model_save_path) |
|
logging.info("Tokenizer saved.") |
|
|
|
logging.info(f"Fine-tuning and merging process complete. Full model saved to {full_model_save_path}") |
|
|
|
|
|
logging.info("Training Checkpoint Notes:") |
|
logging.info(f" • Checkpoints saved to: {OUTPUT_DIR}") |
|
logging.info(f" • To resume training from the latest checkpoint, just rerun this script") |
|
logging.info(f" (resume_from_checkpoint='auto' will automatically find the latest checkpoint)") |
|
logging.info(f" • To resume from a specific checkpoint, set resume_from_checkpoint='path/to/checkpoint'") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|