# sundaram22verma's picture
# initial commit
# 9d76e23
"""Model fine-tuning implementation."""
import logging
from typing import Dict, List, Optional, Tuple
from datetime import datetime
import torch # type: ignore
from torch.utils.data import DataLoader, random_split, WeightedRandomSampler
from sentence_transformers import InputExample, CrossEncoder, losses
from pathlib import Path
import numpy as np
import shutil
from transformers import get_cosine_schedule_with_warmup, get_cosine_with_hard_restarts_schedule_with_warmup
from sklearn.model_selection import StratifiedKFold
from .config import DEFAULT_FINE_TUNING_CONFIG, MODEL_STATUS
from .utils import (
save_training_data,
load_training_data,
save_model_metadata,
load_model_metadata,
get_model_path,
cleanup_old_models,
get_latest_model_version,
load_user_feedback
)
logger = logging.getLogger(__name__)
class EarlyStopping:
    """Signal that training should halt once validation loss stops improving.

    A call counts as an improvement only when the new loss is at least
    ``min_delta`` below the best loss recorded so far. ``patience``
    consecutive non-improving calls set ``should_stop``; the flag is never
    cleared afterwards.
    """

    def __init__(self, patience=3, min_delta=1e-4):
        self.patience = patience
        self.min_delta = min_delta
        # Best (lowest) loss accepted so far; None until the first call.
        self.best_loss = None
        # Number of consecutive calls without a sufficient improvement.
        self.counter = 0
        self.should_stop = False

    def __call__(self, val_loss):
        """Record ``val_loss`` and return whether training should stop."""
        # The very first observation only establishes the baseline.
        if self.best_loss is None:
            self.best_loss = val_loss
            return self.should_stop

        stalled = val_loss > self.best_loss - self.min_delta
        if not stalled:
            # Sufficient improvement: adopt the new best and restart the count.
            self.best_loss = val_loss
            self.counter = 0
            return self.should_stop

        self.counter += 1
        if self.counter >= self.patience:
            self.should_stop = True
        return self.should_stop
class ModelTrainer:
"""Handles model fine-tuning and versioning."""
def __init__(self, base_model_name: str, device: Optional[str] = None):
    """Set up trainer state and make sure a model-metadata record exists.

    Args:
        base_model_name: Name or path later passed to ``CrossEncoder`` when a
            model is loaded.
        device: Device string to use. When falsy, "cuda" is selected if
            ``torch.cuda.is_available()`` reports True, otherwise "cpu".
    """
    self._validation_samples: List[Dict] = []
    self.base_model_name = base_model_name

    # An explicit device wins; auto-detection only runs when none was given.
    self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"ModelTrainer initialized to use device: {self.device}")

    self.ensemble_weights: Optional[List[float]] = None
    # NOTE: this is the module-level default object itself, not a copy.
    self.config = DEFAULT_FINE_TUNING_CONFIG

    # Fold models; filled by fine_tune() or lazily by predict().
    self.model_ensemble = []

    # Write an initial metadata record the first time a trainer is created.
    if load_model_metadata() is None:
        save_model_metadata({
            "current_model_status": MODEL_STATUS["BASE"],
            "current_version": "v0",
            "last_fine_tuned": None,
            "base_model_name": base_model_name,
            "fine_tuning_config": self.config,
            "base_model_updates": [],  # Track base model update history
        })
@staticmethod
def collate_fn(batch):
"""Custom collate function for InputExample objects."""
query_texts = []
candidate_texts = []
labels = []
for example in batch:
query_text, candidate_text = example.texts
query_texts.append(query_text)
candidate_texts.append(candidate_text)
labels.append(example.label)
return {
"texts": list(zip(query_texts, candidate_texts)),
"labels": torch.tensor(labels)
}
def prepare_training_data(self, training_samples: List[Dict], append: bool = True) -> None:
    """Merge user feedback into ``training_samples`` and persist them.

    Every sample is modified in place: it receives a ``weight`` (the
    feedback confidence, or 1.0 when no feedback matches) and its ``label``
    is replaced by the feedback rating when the two differ by more than 0.3.
    Feedback is looked up under the key ``"<query_text>_<candidate_text>"``.

    Any exception is logged and swallowed; nothing is raised to the caller.

    Args:
        training_samples: Dicts with ``query_text``, ``candidate_text`` and ``label``.
        append: Forwarded unchanged to ``save_training_data``.
    """
    try:
        feedback_by_pair = load_user_feedback()

        for record in training_samples:
            pair_key = f"{record['query_text']}_{record['candidate_text']}"

            if pair_key not in feedback_by_pair:
                record['weight'] = 1.0  # Default weight
                continue

            entry = feedback_by_pair[pair_key]
            # Confidence of the feedback becomes the sampling weight.
            record['weight'] = entry['confidence']
            # Only override the label on a clear disagreement.
            if abs(entry['rating'] - record['label']) > 0.3:
                record['label'] = entry['rating']

        save_training_data(training_samples, append)
        logger.info(f"Successfully saved {len(training_samples)} training samples with user feedback integration")
    except Exception as exc:
        logger.error(f"Error preparing training data: {exc}")
def check_training_conditions(self) -> bool:
    """Decide whether a fine-tuning run may start now.

    Returns False when any of the following holds, True otherwise:
      * no model metadata is available;
      * the metadata status is already TRAINING;
      * fewer than ``training_interval_hours`` hours have passed since
        ``last_fine_tuned``;
      * fewer than ``min_training_samples`` samples are stored.

    Thresholds come from the metadata's ``fine_tuning_config`` when present
    and fall back to ``self.config``.
    """
    metadata = load_model_metadata()
    if metadata is None:
        return False

    # Refuse to start a second run while one is flagged as in progress.
    if metadata.get("current_model_status") == MODEL_STATUS["TRAINING"]:
        logger.info("Fine-tuning is already in progress")
        return False

    # Enforce the minimum gap between two runs.
    previous_run = metadata.get("last_fine_tuned")
    if previous_run:
        previous_run_at = datetime.fromisoformat(previous_run)
        elapsed = datetime.now() - previous_run_at
        hours_since_last_training = elapsed.total_seconds() / 3600
        required_gap = metadata.get("fine_tuning_config", {}).get(
            "training_interval_hours", self.config["training_interval_hours"]
        )
        if hours_since_last_training < required_gap:
            logger.info(f"Not enough time elapsed since last fine-tuning ({hours_since_last_training:.1f} hours)")
            return False

    # Enforce the minimum amount of stored training data.
    training_samples = load_training_data()
    min_samples = metadata.get("fine_tuning_config", {}).get(
        "min_training_samples", self.config["min_training_samples"]
    )
    sample_count = len(training_samples)
    if sample_count < min_samples:
        logger.info(f"Not enough training samples ({sample_count} < {min_samples})")
        return False

    return True
def generate_hard_negatives(self, train_examples: List["InputExample"], model) -> Tuple[List["InputExample"], List[float]]:
    """Mine hard negatives using the current model's predictions.

    For every positive example (``label > 0.5``) the query is paired with the
    candidate text of every negative example (``label <= 0.5``); the pairs
    the model scores highest become hard negatives with label 0.0.

    A candidate is skipped for a query when that exact (query, candidate)
    pair is itself a positive example, so a known match is never relabelled
    as a negative.

    Args:
        train_examples: Labelled examples whose ``texts`` are ``[query, candidate]``.
        model: Object exposing ``predict(list_of_pairs)`` that returns one
            score per pair.

    Returns:
        ``(examples, weights)`` as parallel lists. Up to
        ``config["hard_negatives_top_k"]`` (default 3) examples are produced
        per positive example, each weighted
        ``config["hard_negatives_weight"]`` (default 1.2).
    """
    # Loop-invariant settings, read once.
    batch_size = self.config["batch_size"]
    top_k_limit = self.config.get("hard_negatives_top_k", 3)
    negative_weight = float(self.config.get("hard_negatives_weight", 1.2))

    positive_examples = [ex for ex in train_examples if ex.label > 0.5]
    negative_examples = [ex for ex in train_examples if ex.label <= 0.5]

    # Pairs known to be correct matches; these must never be emitted as negatives.
    known_positive_pairs = {(ex.texts[0], ex.texts[1]) for ex in positive_examples}

    # Several positives can share a query; score each query only once.
    mined_by_query: Dict[str, List[str]] = {}

    hard_negative_examples = []
    hard_negative_weights = []

    for pos_ex in positive_examples:
        query = pos_ex.texts[0]

        if query not in mined_by_query:
            candidates = [
                neg.texts[1]
                for neg in negative_examples
                if (query, neg.texts[1]) not in known_positive_pairs
            ]

            scores = []
            for start in range(0, len(candidates), batch_size):
                chunk = candidates[start:start + batch_size]
                with torch.no_grad():
                    scores.extend(model.predict([(query, cand) for cand in chunk]))

            # Never ask for more entries than were actually scored.
            top_k = min(top_k_limit, len(scores), len(candidates))
            if top_k <= 0:
                mined_by_query[query] = []
            else:
                # argpartition puts the top_k largest scores in the last top_k slots.
                top_indices = np.argpartition(np.asarray(scores), -top_k)[-top_k:]
                mined_by_query[query] = [candidates[idx] for idx in top_indices]

        for candidate in mined_by_query[query]:
            hard_negative_examples.append(InputExample(
                texts=[query, candidate],
                label=0.0
            ))
            hard_negative_weights.append(negative_weight)

    return hard_negative_examples, hard_negative_weights
def create_weighted_sampler(self, weights_list: List[float]) -> WeightedRandomSampler:
    """Build a with-replacement sampler that draws examples in proportion to their weight.

    Args:
        weights_list: One non-negative weight per dataset item, in dataset order.

    Returns:
        A ``WeightedRandomSampler`` that draws ``len(weights_list)`` indices
        per epoch with replacement.

    Raises:
        ValueError: If ``weights_list`` is empty. ``WeightedRandomSampler``
            itself rejects ``num_samples=0`` with a ValueError, so failing
            here keeps the exception type and gives a clearer message.
    """
    if not weights_list:
        raise ValueError(
            "create_weighted_sampler received an empty weights_list; "
            "at least one sample weight is required."
        )

    # Double precision, matching what WeightedRandomSampler stores internally.
    weights_tensor = torch.as_tensor(weights_list, dtype=torch.double)
    return WeightedRandomSampler(
        weights_tensor,
        num_samples=len(weights_tensor),
        replacement=True,
    )
def update_base_model(self, best_model_path: Path, validation_score: float) -> bool:
    """
    Replace the base model directory with the model stored at ``best_model_path``.

    The existing base model directory (``Path(self.base_model_name)``) is
    renamed to a timestamped backup next to it, the new model is moved into
    its place, and the update is appended to ``base_model_updates`` in the
    model metadata.

    On any failure the previous base model is restored from the backup, the
    new model is moved back to ``best_model_path`` when possible, and False
    is returned. No exception is propagated.

    Args:
        best_model_path: Directory of the model to promote. It is moved, not
            copied, so it no longer exists at this path after a successful update.
        validation_score: Validation score recorded with the update.

    Returns:
        bool: True if the base model was replaced, False otherwise.
    """
    base_model_dir = Path(self.base_model_name)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    backup_dir = base_model_dir.parent / f"{base_model_dir.name}_backup_{timestamp}"
    temp_dir = base_model_dir.parent / f"temp_switch_{timestamp}"
    staged_model_dir = temp_dir / "model"

    # Validate BEFORE touching the disk: otherwise the current base model
    # would already sit in the backup location when we bail out.
    if not best_model_path.exists():
        logger.error(f"Best model path does not exist: {best_model_path}")
        return False

    # Explicit progress flags drive the rollback below.
    backup_created = False
    promoted = False

    try:
        temp_dir.mkdir(parents=True, exist_ok=True)

        # Stage the new model next to the base directory first. If this is a
        # slow cross-device copy, the base model stays in place while it runs.
        shutil.move(str(best_model_path), str(staged_model_dir))

        if base_model_dir.exists():
            shutil.move(str(base_model_dir), str(backup_dir))
            backup_created = True
            logger.info(f"Created backup of base model at {backup_dir}")

        shutil.move(str(staged_model_dir), str(base_model_dir))
        promoted = True

        # Record the update. Missing metadata is not a reason to undo a swap
        # that already succeeded on disk.
        metadata = load_model_metadata()
        if metadata is None:
            logger.warning("Model metadata not found; base model update was not recorded")
        else:
            metadata.setdefault("base_model_updates", []).append({
                "timestamp": datetime.now().isoformat(),
                "validation_score": float(validation_score),
                "previous_backup": str(backup_dir)
            })
            save_model_metadata(metadata)

        logger.info(f"Successfully updated base model with new version (validation score: {validation_score:.4f})")
        return True

    except Exception as e:
        logger.error(f"Error updating base model: {e}")
        try:
            # Hand the fine-tuned model back to the caller instead of
            # deleting it together with the failed swap.
            salvage_source = base_model_dir if promoted else staged_model_dir
            if salvage_source.exists() and not best_model_path.exists():
                shutil.move(str(salvage_source), str(best_model_path))

            # Restore only a backup made by this call.
            if backup_created and backup_dir.exists():
                if base_model_dir.exists():
                    shutil.rmtree(str(base_model_dir))
                shutil.move(str(backup_dir), str(base_model_dir))
                logger.info("Restored base model from backup after failed update")
        except Exception as restore_error:
            logger.error(f"Failed to restore from backup: {restore_error}")
        return False

    finally:
        # The staging directory is removed on every exit path.
        if temp_dir.exists():
            try:
                shutil.rmtree(str(temp_dir))
            except Exception as cleanup_error:
                logger.error(f"Failed to cleanup temporary directory: {cleanup_error}")
def fine_tune(self) -> Optional[str]:
    """
    Run stratified 5-fold fine-tuning and publish the resulting ensemble.

    Steps: mark the metadata status as TRAINING, mine hard negatives with the
    base model, train one model per fold, save every fold model under the new
    version directory, record the ensemble metadata, then pass a copy of the
    lowest-validation-loss model to ``update_base_model``.

    Returns:
        The new model version if successful, None otherwise. None is returned
        when the training conditions are not met or when any exception occurs
        during the run; in the latter case the metadata status is reset to BASE.
    """
    if not self.check_training_conditions():
        return None

    # Bound before the try so the except handler never hits an unbound name.
    metadata = None
    try:
        # Update status to training
        metadata = load_model_metadata()
        metadata["current_model_status"] = MODEL_STATUS["TRAINING"]
        save_model_metadata(metadata)

        all_samples = load_training_data()

        # Scores hard negatives and provides the starting weights for every fold.
        base_model = CrossEncoder(self.base_model_name, device=self.device)

        skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
        stratify_labels = [sample["label"] > 0.5 for sample in all_samples]

        # Resolve the target version once, before anything is written, so
        # every artefact of this run lands in the same directory.
        new_version = f"v{int(get_latest_model_version()[1:]) + 1}"
        model_dir_path = get_model_path(new_version)

        # Pre-compute all hard negatives once
        hard_neg_examples, hard_neg_weights = self.generate_hard_negatives(
            self._samples_to_examples(all_samples), base_model
        )

        best_models = []
        all_fold_val_losses = []
        for fold, (train_idx, val_idx) in enumerate(skf.split(all_samples, stratify_labels)):
            logger.info(f"Training fold {fold + 1}/5")

            # Fresh model per fold, started from the base model's weights.
            model = CrossEncoder(self.base_model_name, device=self.device)
            model.load_state_dict(base_model.state_dict())

            fold_val_loss = self._train_fold(
                model,
                fold,
                [all_samples[i] for i in train_idx],
                [all_samples[i] for i in val_idx],
                hard_neg_examples,
                hard_neg_weights,
            )
            best_models.append(model)
            all_fold_val_losses.append(fold_val_loss)

        # First fold with the strictly lowest validation loss wins.
        best_val_loss = float('inf')
        best_fold = None
        for fold, fold_val_loss in enumerate(all_fold_val_losses):
            if fold_val_loss < best_val_loss:
                best_val_loss = fold_val_loss
                best_fold = fold

        # Create ensemble model
        self.model_ensemble = best_models
        self.ensemble_weights = [1.0 / len(best_models)] * len(best_models) if best_models else []

        # Save every ensemble member sequentially. A process pool cannot be
        # used here: the worker would be a local closure (not picklable) and
        # the models would have to be pickled into each child process.
        model_dir_path.mkdir(parents=True, exist_ok=True)
        for member_idx, member in enumerate(best_models):
            member.save(str(model_dir_path / f"ensemble_{member_idx}"))

        best_model_path = model_dir_path / f"ensemble_{best_fold}" if best_fold is not None else None

        # Persist metadata BEFORE promotion: update_base_model reloads the
        # metadata from disk and appends to it.
        metadata["current_model_status"] = MODEL_STATUS["FINE_TUNED"]
        metadata["last_fine_tuned"] = datetime.now().isoformat()
        metadata["ensemble_metadata"] = {
            "num_models": len(best_models),
            "fold_val_losses": all_fold_val_losses,
            "ensemble_weights": self.ensemble_weights,
            "best_model_path": str(best_model_path),
            "best_overall_val_loss": best_val_loss
        }
        save_model_metadata(metadata)

        # Promote a dedicated copy. update_base_model MOVES its input, and
        # moving the ensemble member itself would leave a hole that
        # predict() later fails to load.
        promoted = False
        if best_fold is not None:
            promotion_path = model_dir_path / "best_model_for_promotion"
            best_models[best_fold].save(str(promotion_path))
            promoted = self.update_base_model(promotion_path, -best_val_loss)
            if not promoted and promotion_path.exists():
                shutil.rmtree(str(promotion_path), ignore_errors=True)

        if promoted:
            logger.info("Successfully updated base model with best performing model")
        else:
            logger.warning("Failed to update base model")

        cleanup_old_models()
        return new_version

    except Exception as e:
        logger.error(f"Error during fine-tuning: {e}")
        if metadata is not None:
            metadata["current_model_status"] = MODEL_STATUS["BASE"]
            save_model_metadata(metadata)
        return None

@staticmethod
def _samples_to_examples(samples):
    """Convert stored sample dicts into ``InputExample`` objects with float labels."""
    return [
        InputExample(
            texts=[sample["query_text"], sample["candidate_text"]],
            label=float(sample["label"])
        )
        for sample in samples
    ]

def _train_fold(self, model, fold, train_samples, val_samples, hard_neg_examples, hard_neg_weights) -> float:
    """Train ``model`` on one fold and leave it holding its best-epoch weights.

    Returns the lowest validation loss observed for this fold.
    """
    train_examples = self._samples_to_examples(train_samples)
    train_weights = [float(sample.get("weight", 1.0)) for sample in train_samples]

    # Add hard negatives for this fold's training samples.
    # NOTE(review): only hard negatives whose exact (query, candidate) pair
    # already exists in this fold's training samples pass this filter, so
    # newly synthesised pairs are never added. Confirm this is intended.
    fold_train_pairs = {(s["query_text"], s["candidate_text"]) for s in train_samples}
    for example, weight in zip(hard_neg_examples, hard_neg_weights):
        if (example.texts[0], example.texts[1]) in fold_train_pairs:
            train_examples.append(example)
            train_weights.append(weight)

    val_examples = self._samples_to_examples(val_samples)

    # prefetch_factor is only valid with worker processes; DataLoader raises
    # a ValueError if it is set while num_workers == 0.
    loader_kwargs = {
        "num_workers": self.config["dataloader_num_workers"],
        "pin_memory": self.config["pin_memory"],
        "collate_fn": self.collate_fn,
    }
    if loader_kwargs["num_workers"] > 0:
        loader_kwargs["prefetch_factor"] = 2

    train_dataloader = DataLoader(
        train_examples,
        sampler=self.create_weighted_sampler(train_weights),
        batch_size=self.config["batch_size"],
        **loader_kwargs
    )
    val_dataloader = DataLoader(
        val_examples,
        shuffle=False,
        batch_size=self.config["batch_size"] * 2,  # Larger batch size for validation
        **loader_kwargs
    )

    # The scheduler advances once per optimizer update, not once per batch.
    accumulation_steps = max(1, int(self.config["gradient_accumulation_steps"]))
    num_batches = len(train_dataloader)
    updates_per_epoch = (num_batches + accumulation_steps - 1) // accumulation_steps
    num_training_steps = updates_per_epoch * self.config["epochs"]
    warmup_steps = min(self.config["warmup_steps"], num_training_steps // 10)

    optimizer = torch.optim.AdamW(
        model.model.parameters(),
        lr=self.config["learning_rate"],
        weight_decay=self.config["weight_decay"],
        betas=self.config["adam_betas"],
        eps=self.config["adam_epsilon"]
    )
    scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=num_training_steps,
        num_cycles=self.config["num_cycles"]
    )
    early_stopping = EarlyStopping(
        patience=self.config["early_stopping_patience"],
        min_delta=self.config["early_stopping_min_delta"]
    )

    # Mixed precision only on CUDA; startswith also accepts "cuda:0" etc.
    use_amp = bool(self.config["use_mixed_precision"]) and str(self.device).startswith("cuda")
    scaler = torch.cuda.amp.GradScaler() if use_amp else None

    fold_best_val_loss = float('inf')
    best_state = None

    for epoch in range(self.config["epochs"]):
        model.train()
        total_train_loss = 0.0
        # Gradients are cleared once here and after every optimizer update,
        # never per batch, so they really accumulate across micro-batches.
        optimizer.zero_grad()

        for batch_idx, batch in enumerate(train_dataloader):
            loss = self._batch_loss(model, batch, use_amp) / accumulation_steps
            if scaler is not None:
                scaler.scale(loss).backward()
            else:
                loss.backward()
            total_train_loss += loss.item() * accumulation_steps

            # Update at each accumulation boundary and flush the remainder
            # on the final batch of the epoch.
            is_boundary = (batch_idx + 1) % accumulation_steps == 0
            is_last_batch = batch_idx + 1 == num_batches
            if is_boundary or is_last_batch:
                if scaler is not None:
                    # Unscale first so clipping sees true gradient magnitudes.
                    scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(
                    model.model.parameters(),
                    self.config["max_grad_norm"]
                )
                if scaler is not None:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

        val_loss = self._validation_loss(model, val_dataloader, use_amp)
        logger.info(
            f"Fold {fold + 1} epoch {epoch + 1}: "
            f"train_loss={total_train_loss / max(num_batches, 1):.4f}, val_loss={val_loss:.4f}"
        )

        if val_loss < fold_best_val_loss:
            fold_best_val_loss = val_loss
            # state_dict() hands out references to the live parameters;
            # clone them, or the snapshot keeps changing as training continues.
            best_state = {
                name: value.detach().cpu().clone() if torch.is_tensor(value) else value
                for name, value in model.state_dict().items()
            }

        if early_stopping(val_loss):
            logger.info(f"Early stopping triggered in fold {fold + 1}")
            break

    # Rewind to the best epoch before the model joins the ensemble.
    if best_state is not None:
        model.load_state_dict(best_state)
    return fold_best_val_loss

def _batch_loss(self, model, batch, use_amp):
    """Tokenize one collated batch and return the model's loss tensor."""
    texts = [list(text_pair) for text_pair in batch["texts"]]
    targets = batch["labels"].to(self.device)
    tokenized_inputs = model.tokenizer(
        texts,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=model.max_length
    )
    tokenized_inputs = {key: val.to(self.device) for key, val in tokenized_inputs.items()}
    if use_amp:
        with torch.cuda.amp.autocast():
            return model.model(**tokenized_inputs, labels=targets).loss
    return model.model(**tokenized_inputs, labels=targets).loss

def _validation_loss(self, model, val_dataloader, use_amp) -> float:
    """Return the mean per-batch loss of ``model`` over ``val_dataloader``."""
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_dataloader:
            val_loss += self._batch_loss(model, batch, use_amp).item()
    return val_loss / len(val_dataloader)
def predict(self, query_text: str, candidate_texts: List[str]) -> List[float]:
    """
    Score every candidate against ``query_text`` with the model ensemble.

    When no ensemble is held in memory, one is loaded from the latest saved
    version described in the model metadata; if that yields nothing, the
    base model is used as a single-member ensemble.

    Args:
        query_text: The query side of every pair.
        candidate_texts: Candidate texts to score.

    Returns:
        One score per candidate, in input order: the weighted average of the
        members' predictions. An empty input returns an empty list without
        loading any model.
    """
    # Nothing to score: avoid loading models and averaging empty arrays.
    if len(candidate_texts) == 0:
        return []

    if not self.model_ensemble:
        logger.info("Model ensemble not loaded. Attempting to load from metadata.")
        self._load_saved_ensemble()

    # Fallback to single base model if ensemble loading failed or no ensemble metadata
    if not self.model_ensemble:
        logger.info("Falling back to base model as ensemble loading failed or no ensemble defined.")
        self.model_ensemble = [CrossEncoder(self.base_model_name, device=self.device)]
        self.ensemble_weights = [1.0]

    # Use the stored weights only when there is exactly one per member.
    member_count = len(self.model_ensemble)
    if self.ensemble_weights and len(self.ensemble_weights) == member_count:
        active_ensemble_weights = self.ensemble_weights
    else:
        logger.warning(f"Ensemble weights inconsistent or missing for {member_count} models; using equal weights.")
        active_ensemble_weights = [1.0 / member_count] * member_count

    # The pairs are identical for every member, so build them once.
    pairs = [(query_text, cand_text) for cand_text in candidate_texts]
    with torch.no_grad():
        member_predictions = [model.predict(pairs) for model in self.model_ensemble]

    # Weighted average across members (axis 0), one value per candidate.
    ensemble_predictions_np = np.average(
        np.array(member_predictions), axis=0, weights=active_ensemble_weights
    )
    return ensemble_predictions_np.tolist()

def _load_saved_ensemble(self) -> None:
    """Load the saved ensemble of the latest model version into memory.

    Reads ``ensemble_metadata`` from the model metadata and loads every
    ``ensemble_<i>`` directory that exists. Members that are missing or fail
    to load are skipped, and the weights of the remaining members are
    re-normalised to sum to 1. ``self.model_ensemble`` is left untouched
    when nothing could be loaded.
    """
    metadata = load_model_metadata()
    ensemble_meta = metadata.get("ensemble_metadata") if metadata else None
    if not ensemble_meta:
        return
    num_models = ensemble_meta.get("num_models", 0)
    if num_models <= 0:
        return

    model_dir_path = get_model_path(get_latest_model_version())
    meta_weights = ensemble_meta.get("ensemble_weights")

    loaded_models = []
    loaded_weights = []
    for i in range(num_models):
        member_model_path = model_dir_path / f"ensemble_{i}"
        if not member_model_path.exists():
            logger.warning(f"Ensemble model member path not found: {member_model_path}")
            continue
        try:
            member = CrossEncoder(str(member_model_path), device=self.device)
        except Exception as e:
            logger.error(f"Failed to load ensemble member {member_model_path}: {e}")
            continue
        loaded_models.append(member)
        # A weight is appended only together with its model, so the two
        # lists stay aligned; 1.0 is the fallback for a missing weight.
        if meta_weights and i < len(meta_weights):
            loaded_weights.append(meta_weights[i])
        else:
            loaded_weights.append(1.0)

    if not loaded_models:
        return

    self.model_ensemble = loaded_models
    weight_total = sum(loaded_weights)
    if weight_total > 0:
        self.ensemble_weights = [w / weight_total for w in loaded_weights]
    else:
        # All weights are zero: fall back to equal weighting.
        self.ensemble_weights = [1.0 / len(loaded_models)] * len(loaded_models)
    logger.info(f"Loaded {len(self.model_ensemble)} models with weights: {self.ensemble_weights}")
# Removed conflicting class method `get_model_path`.
# The imported `get_model_path` from `.utils` will be used.