"""Model fine-tuning implementation.""" | |
import logging | |
from typing import Dict, List, Optional, Tuple | |
from datetime import datetime | |
import torch # type: ignore | |
from torch.utils.data import DataLoader, random_split, WeightedRandomSampler | |
from sentence_transformers import InputExample, CrossEncoder, losses | |
from pathlib import Path | |
import numpy as np | |
import shutil | |
from transformers import get_cosine_schedule_with_warmup, get_cosine_with_hard_restarts_schedule_with_warmup | |
from sklearn.model_selection import StratifiedKFold | |
from .config import DEFAULT_FINE_TUNING_CONFIG, MODEL_STATUS | |
from .utils import ( | |
save_training_data, | |
load_training_data, | |
save_model_metadata, | |
load_model_metadata, | |
get_model_path, | |
cleanup_old_models, | |
get_latest_model_version, | |
load_user_feedback | |
) | |
logger = logging.getLogger(__name__) | |


class EarlyStopping:
    """Early stopping to stop training when validation loss doesn't improve."""

    def __init__(self, patience=3, min_delta=1e-4):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.should_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0
        return self.should_stop
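
# Illustrative call pattern for EarlyStopping (a sketch; fine_tune() below constructs it
# from the fine-tuning config). The instance is called once per epoch with the validation
# loss and returns True once `patience` consecutive epochs fail to improve the best loss
# by at least `min_delta`:
#
#     stopper = EarlyStopping(patience=2, min_delta=1e-4)
#     for val_loss in (0.90, 0.88, 0.88, 0.88):  # hypothetical per-epoch losses
#         if stopper(val_loss):
#             break  # triggers on the fourth epoch, after two epochs without improvement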


class ModelTrainer:
    """Handles model fine-tuning and versioning."""

    def __init__(self, base_model_name: str, device: Optional[str] = None):
        self._validation_samples: List[Dict] = []
        self.base_model_name = base_model_name
        if device:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"ModelTrainer initialized to use device: {self.device}")
        self.ensemble_weights: Optional[List[float]] = None
        self.config = DEFAULT_FINE_TUNING_CONFIG
        # Initialize model ensemble
        self.model_ensemble = []
        # Load or initialize metadata
        metadata = load_model_metadata()
        if metadata is None:
            metadata = {
                "current_model_status": MODEL_STATUS["BASE"],
                "current_version": "v0",
                "last_fine_tuned": None,
                "base_model_name": base_model_name,
                "fine_tuning_config": self.config,
                "base_model_updates": []  # Track base model update history
            }
            save_model_metadata(metadata)

    def collate_fn(self, batch):
        """Custom collate function for InputExample objects."""
        # Bound as collate_fn=self.collate_fn on the DataLoaders below, so it takes self.
        query_texts = []
        candidate_texts = []
        labels = []
        for example in batch:
            query_text, candidate_text = example.texts
            query_texts.append(query_text)
            candidate_texts.append(candidate_text)
            labels.append(example.label)
        return {
            "texts": list(zip(query_texts, candidate_texts)),
            "labels": torch.tensor(labels)
        }
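
    # Shape of a collated batch (illustrative, for two InputExample objects):
    #     {"texts": [("query a", "candidate a"), ("query b", "candidate b")],
    #      "labels": tensor([1.0, 0.0])}
    # This is the structure consumed by the training and validation loops in fine_tune().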

    def prepare_training_data(self, training_samples: List[Dict], append: bool = True) -> None:
        """Save training samples for future fine-tuning with user feedback integration."""
        try:
            # Load user feedback data
            user_feedback = load_user_feedback()
            # Integrate user feedback with training samples
            for sample in training_samples:
                feedback_key = f"{sample['query_text']}_{sample['candidate_text']}"
                if feedback_key in user_feedback:
                    feedback = user_feedback[feedback_key]
                    # Update sample weight based on user feedback confidence
                    sample['weight'] = feedback['confidence']
                    # Update label if user feedback disagrees with original label
                    if abs(feedback['rating'] - sample['label']) > 0.3:
                        sample['label'] = feedback['rating']
                else:
                    sample['weight'] = 1.0  # Default weight
            save_training_data(training_samples, append)
            logger.info(f"Successfully saved {len(training_samples)} training samples with user feedback integration")
        except Exception as e:
            logger.error(f"Error preparing training data: {e}")

    def check_training_conditions(self) -> bool:
        """Check if conditions are met for fine-tuning."""
        metadata = load_model_metadata()
        if metadata is None:
            return False
        # Check if already training
        if metadata.get("current_model_status") == MODEL_STATUS["TRAINING"]:
            logger.info("Fine-tuning is already in progress")
            return False
        # Check training interval
        last_fine_tuned = metadata.get("last_fine_tuned")
        if last_fine_tuned:
            last_fine_tuned = datetime.fromisoformat(last_fine_tuned)
            hours_since_last_training = (datetime.now() - last_fine_tuned).total_seconds() / 3600
            training_interval = metadata.get("fine_tuning_config", {}).get("training_interval_hours", self.config["training_interval_hours"])
            if hours_since_last_training < training_interval:
                logger.info(f"Not enough time elapsed since last fine-tuning ({hours_since_last_training:.1f} hours)")
                return False
        # Check training data
        training_samples = load_training_data()
        min_samples = metadata.get("fine_tuning_config", {}).get("min_training_samples", self.config["min_training_samples"])
        if len(training_samples) < min_samples:
            logger.info(f"Not enough training samples ({len(training_samples)} < {min_samples})")
            return False
        return True

    def generate_hard_negatives(self, train_examples: List[InputExample], model) -> Tuple[List[InputExample], List[float]]:
        """Generate hard negative examples and their weights using current model predictions."""
        hard_negative_examples = []
        hard_negative_weights = []
        batch_size = self.config["batch_size"]
        # Create pairs of all positive examples with potential negatives
        positive_examples = [ex for ex in train_examples if ex.label > 0.5]
        # These are InputExample objects; we'll use their candidate texts
        potential_negative_candidates = [ex for ex in train_examples if ex.label <= 0.5]
        for pos_ex in positive_examples:
            # Get predictions for potential negatives
            neg_scores = []
            for i in range(0, len(potential_negative_candidates), batch_size):
                batch = potential_negative_candidates[i:i + batch_size]
                texts = [(pos_ex.texts[0], neg.texts[1]) for neg in batch]
                if not texts:
                    continue
                with torch.no_grad():
                    scores = model.predict(texts)
                neg_scores.extend(scores)
            if not neg_scores:
                continue
            # Select top K highest scoring negatives as hard negatives
            top_k = min(self.config.get("hard_negatives_top_k", 3), len(potential_negative_candidates))
            if top_k == 0 or len(neg_scores) < top_k:  # Ensure we have enough scores for argpartition
                continue
            top_indices = np.argpartition(neg_scores, -top_k)[-top_k:]
            for idx in top_indices:
                hard_neg = potential_negative_candidates[idx]
                hard_negative_examples.append(InputExample(
                    texts=[pos_ex.texts[0], hard_neg.texts[1]],
                    label=0.0
                ))
                hard_negative_weights.append(float(self.config.get("hard_negatives_weight", 1.2)))
        return hard_negative_examples, hard_negative_weights
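
    # Note: the two lists returned above are parallel (hard negative i is weighted by
    # weights[i]). With the fallback hard-negative weight of 1.2 versus the default
    # sample weight of 1.0, the WeightedRandomSampler built in fine_tune() draws these
    # mined negatives slightly more often than ordinary training examples.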

    def create_weighted_sampler(self, weights_list: List[float]) -> WeightedRandomSampler:
        """Create a weighted sampler based on a list of weights."""
        if not weights_list:
            # Fallback or raise error if weights_list is empty,
            # though DataLoader would also fail with an empty dataset.
            # For now, let it proceed; PyTorch will handle an empty tensor if it occurs.
            logger.warning("create_weighted_sampler received an empty weights_list.")
        weights_tensor = torch.DoubleTensor(weights_list)
        # num_samples must be > 0 if weights_list is not empty.
        sampler = WeightedRandomSampler(weights_tensor, num_samples=len(weights_tensor), replacement=True)
        return sampler
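
    # Sampling sketch (illustrative numbers, assuming the weights built in fine_tune()):
    # with weights [1.0, 1.0, 1.2] the sampler draws len(weights) indices per epoch with
    # replacement; index 2 is picked with probability 1.2 / 3.2 ≈ 0.375 per draw versus
    # 1.0 / 3.2 ≈ 0.3125 for the others, so higher-weight examples simply appear more often.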

    def update_base_model(self, best_model_path: Path, validation_score: float) -> bool:
        """
        Update the base model with the best performing model from fine-tuning.
        Optimized for faster switching and reduced I/O operations.

        Args:
            best_model_path: Path to the best performing model
            validation_score: Validation score of the best model

        Returns:
            bool: True if update was successful, False otherwise
        """
        try:
            base_model_dir = Path(self.base_model_name)
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            backup_dir = base_model_dir.parent / f"{base_model_dir.name}_backup_{timestamp}"
            # Create a temporary directory for the switch
            temp_dir = base_model_dir.parent / f"temp_switch_{timestamp}"
            temp_dir.mkdir(parents=True, exist_ok=True)
            if base_model_dir.exists():
                # Move current model to backup using atomic operation
                shutil.move(str(base_model_dir), str(backup_dir))
                logger.info(f"Created backup of base model at {backup_dir}")
            # Update base model with best performing model using atomic operation
            if best_model_path.exists():
                # First move the best model to temp directory
                shutil.move(str(best_model_path), str(temp_dir / "model"))
                # Then move it to the final location
                shutil.move(str(temp_dir / "model"), str(base_model_dir))
                # Update metadata
                metadata = load_model_metadata()
                update_info = {
                    "timestamp": datetime.now().isoformat(),
                    "validation_score": float(validation_score),
                    "previous_backup": str(backup_dir)
                }
                if "base_model_updates" not in metadata:
                    metadata["base_model_updates"] = []
                metadata["base_model_updates"].append(update_info)
                save_model_metadata(metadata)
                logger.info(f"Successfully updated base model with new version (validation score: {validation_score:.4f})")
                # Cleanup temporary directory
                if temp_dir.exists():
                    shutil.rmtree(str(temp_dir))
                return True
            else:
                logger.error(f"Best model path does not exist: {best_model_path}")
                return False
        except Exception as e:
            logger.error(f"Error updating base model: {e}")
            # Attempt to restore from backup if available
            if 'backup_dir' in locals() and backup_dir.exists():
                try:
                    if base_model_dir.exists():
                        shutil.rmtree(str(base_model_dir))
                    shutil.move(str(backup_dir), str(base_model_dir))
                    logger.info("Restored base model from backup after failed update")
                except Exception as restore_error:
                    logger.error(f"Failed to restore from backup: {restore_error}")
            # Cleanup temporary directory if it exists
            if 'temp_dir' in locals() and temp_dir.exists():
                try:
                    shutil.rmtree(str(temp_dir))
                except Exception as cleanup_error:
                    logger.error(f"Failed to cleanup temporary directory: {cleanup_error}")
            return False

    def fine_tune(self) -> Optional[str]:
        """
        Run the fine-tuning process with improved accuracy and reliability.
        Returns the new model version if successful, None otherwise.
        """
        if not self.check_training_conditions():
            return None
        try:
            # Update status to training
            metadata = load_model_metadata()
            metadata["current_model_status"] = MODEL_STATUS["TRAINING"]
            save_model_metadata(metadata)
            # Load training data
            all_samples = load_training_data()
            # Initialize a single model instance that will be reused across folds
            base_model = CrossEncoder(self.base_model_name, device=self.device)
            # Stratified K-fold cross validation
            skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
            labels = [sample["label"] > 0.5 for sample in all_samples]
            best_models = []
            best_val_loss = float('inf')
            all_fold_val_losses = []
            best_model_path = None
            # Pre-compute all hard negatives once
            all_train_examples = [
                InputExample(
                    texts=[sample["query_text"], sample["candidate_text"]],
                    label=float(sample["label"])
                )
                for sample in all_samples
            ]
            hard_neg_examples, hard_neg_weights = self.generate_hard_negatives(all_train_examples, base_model)
            for fold, (train_idx, val_idx) in enumerate(skf.split(all_samples, labels)):
                logger.info(f"Training fold {fold + 1}/5")
                # Create a copy of the base model for this fold
                model = CrossEncoder(self.base_model_name, device=self.device)
                # CrossEncoder wraps the underlying transformers model in `.model`,
                # so weights are copied at that level.
                model.model.load_state_dict(base_model.model.state_dict())
                train_samples = [all_samples[i] for i in train_idx]
                val_samples = [all_samples[i] for i in val_idx]
                # Prepare training examples with pre-computed hard negatives
                train_input_examples = []
                train_example_weights = []
                # Add regular training examples
                for sample in train_samples:
                    train_input_examples.append(InputExample(
                        texts=[sample["query_text"], sample["candidate_text"]],
                        label=float(sample["label"])
                    ))
                    train_example_weights.append(float(sample.get("weight", 1.0)))
                # Add hard negatives for this fold's training samples
                fold_train_texts = set(f"{s['query_text']}_{s['candidate_text']}" for s in train_samples)
                for ex, weight in zip(hard_neg_examples, hard_neg_weights):
                    if f"{ex.texts[0]}_{ex.texts[1]}" in fold_train_texts:
                        train_input_examples.append(ex)
                        train_example_weights.append(weight)
                val_examples = [
                    InputExample(
                        texts=[sample["query_text"], sample["candidate_text"]],
                        label=float(sample["label"])
                    )
                    for sample in val_samples
                ]
                # Create weighted sampler
                sampler = self.create_weighted_sampler(train_example_weights)
                # Create data loaders with optimized batch size
                train_dataloader = DataLoader(
                    train_input_examples,
                    sampler=sampler,
                    batch_size=self.config["batch_size"],
                    num_workers=self.config["dataloader_num_workers"],
                    pin_memory=self.config["pin_memory"],
                    collate_fn=self.collate_fn,
                    prefetch_factor=2  # Prefetch batches for faster loading
                )
                val_dataloader = DataLoader(
                    val_examples,
                    shuffle=False,
                    batch_size=self.config["batch_size"] * 2,  # Larger batch size for validation
                    num_workers=self.config["dataloader_num_workers"],
                    pin_memory=self.config["pin_memory"],
                    collate_fn=self.collate_fn,
                    prefetch_factor=2
                )
                # Calculate steps for warmup and scheduler
                num_training_steps = len(train_dataloader) * self.config["epochs"]
                warmup_steps = min(self.config["warmup_steps"], num_training_steps // 10)
                # Initialize optimizer with weight decay
                optimizer = torch.optim.AdamW(
                    model.model.parameters(),
                    lr=self.config["learning_rate"],
                    weight_decay=self.config["weight_decay"],
                    betas=self.config["adam_betas"],
                    eps=self.config["adam_epsilon"]
                )
                # Initialize scheduler
                scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
                    optimizer,
                    num_warmup_steps=warmup_steps,
                    num_training_steps=num_training_steps,
                    num_cycles=self.config["num_cycles"]
                )
                # Initialize early stopping
                early_stopping = EarlyStopping(
                    patience=self.config["early_stopping_patience"],
                    min_delta=self.config["early_stopping_min_delta"]
                )
                # Training loop with validation
                fold_best_val_loss = float('inf')
                best_model_state_for_fold = None
                # Enable automatic mixed precision if available
                scaler = torch.cuda.amp.GradScaler() if self.config["use_mixed_precision"] and self.device == "cuda" else None
                for epoch in range(self.config["epochs"]):
                    # Put the wrapped transformers model into training mode
                    model.model.train()
                    total_train_loss = 0
                    for batch_idx, batch in enumerate(train_dataloader):
                        optimizer.zero_grad()
                        texts = [list(text_pair) for text_pair in batch["texts"]]
                        labels = batch["labels"].to(self.device)
                        # Tokenize texts using the model's tokenizer
                        tokenized_inputs = model.tokenizer(
                            texts,
                            padding=True,
                            truncation=True,
                            return_tensors="pt",
                            max_length=model.max_length
                        )
                        tokenized_inputs = {key: val.to(self.device) for key, val in tokenized_inputs.items()}
                        # Use mixed precision if enabled
                        if scaler is not None:
                            with torch.cuda.amp.autocast():
                                outputs = model.model(**tokenized_inputs, labels=labels)
                                loss = outputs.loss / self.config["gradient_accumulation_steps"]
                            scaler.scale(loss).backward()
                            if (batch_idx + 1) % self.config["gradient_accumulation_steps"] == 0:
                                scaler.unscale_(optimizer)
                                torch.nn.utils.clip_grad_norm_(
                                    model.model.parameters(),
                                    self.config["max_grad_norm"]
                                )
                                scaler.step(optimizer)
                                scaler.update()
                                scheduler.step()
                                optimizer.zero_grad()
                        else:
                            outputs = model.model(**tokenized_inputs, labels=labels)
                            loss = outputs.loss / self.config["gradient_accumulation_steps"]
                            loss.backward()
                            if (batch_idx + 1) % self.config["gradient_accumulation_steps"] == 0:
                                torch.nn.utils.clip_grad_norm_(
                                    model.model.parameters(),
                                    self.config["max_grad_norm"]
                                )
                                optimizer.step()
                                scheduler.step()
                                optimizer.zero_grad()
                        total_train_loss += loss.item() * self.config["gradient_accumulation_steps"]
                    # Validation phase
                    # Switch the wrapped transformers model to eval mode for validation
                    model.model.eval()
                    val_loss = 0
                    with torch.no_grad():
                        for batch in val_dataloader:
                            texts = [list(text_pair) for text_pair in batch["texts"]]
                            labels = batch["labels"].to(self.device)
                            tokenized_inputs = model.tokenizer(
                                texts,
                                padding=True,
                                truncation=True,
                                return_tensors="pt",
                                max_length=model.max_length
                            )
                            tokenized_inputs = {key: val.to(self.device) for key, val in tokenized_inputs.items()}
                            if scaler is not None:
                                with torch.cuda.amp.autocast():
                                    outputs = model.model(**tokenized_inputs, labels=labels)
                                    current_loss = outputs.loss
                            else:
                                outputs = model.model(**tokenized_inputs, labels=labels)
                                current_loss = outputs.loss
                            val_loss += current_loss.item()
                    val_loss /= len(val_dataloader)
                    # Save best model for this fold
                    if val_loss < fold_best_val_loss:
                        fold_best_val_loss = val_loss
                        # Snapshot (copy) the wrapped model's weights so later optimizer steps don't overwrite the best state
                        best_model_state_for_fold = {k: v.detach().clone() for k, v in model.model.state_dict().items()}
                    if early_stopping(val_loss):
                        logger.info(f"Early stopping triggered in fold {fold + 1}")
                        break
                # Save best model from this fold
                if best_model_state_for_fold is not None:
                    model.model.load_state_dict(best_model_state_for_fold)
                    best_models.append(model)
                    all_fold_val_losses.append(fold_best_val_loss)
                    if fold_best_val_loss < best_val_loss:
                        best_val_loss = fold_best_val_loss
                        new_version = f"v{int(get_latest_model_version()[1:]) + 1}"
                        best_model_path = get_model_path(new_version) / f"ensemble_{fold}"
                        best_model_path.parent.mkdir(parents=True, exist_ok=True)
                        model.save(str(best_model_path))
            # Create ensemble model
            self.model_ensemble = best_models
            self.ensemble_weights = [1.0 / len(best_models)] * len(best_models) if best_models else []
            # Save ensemble metadata
            ensemble_metadata = {
                "num_models": len(best_models),
                "fold_val_losses": all_fold_val_losses,
                "ensemble_weights": self.ensemble_weights,
                "best_model_path": str(best_model_path),
                "best_overall_val_loss": best_val_loss
            }
            # Update metadata
            metadata["current_model_status"] = MODEL_STATUS["FINE_TUNED"]
            metadata["last_fine_tuned"] = datetime.now().isoformat()
            metadata["ensemble_metadata"] = ensemble_metadata
            save_model_metadata(metadata)
            # Save all ensemble models
            new_version = f"v{int(get_latest_model_version()[1:]) + 1}" if not best_model_path else best_model_path.parent.name.split('_')[-1]
            model_dir_path = get_model_path(new_version)
            # Save models in parallel. Threads are used rather than processes because a
            # locally defined function and GPU-backed models cannot be pickled for a process pool.
            from concurrent.futures import ThreadPoolExecutor

            def save_ensemble_member(model_idx_model):
                model_idx, member = model_idx_model
                member.save(str(model_dir_path / f"ensemble_{model_idx}"))

            with ThreadPoolExecutor() as executor:
                # list() drains the iterator so any save error surfaces here
                list(executor.map(save_ensemble_member, enumerate(best_models)))
            # Update base model with best performing model
            if best_model_path and self.update_base_model(best_model_path, -best_val_loss):
                logger.info("Successfully updated base model with best performing model")
            else:
                logger.warning("Failed to update base model")
            cleanup_old_models()
            return new_version
        except Exception as e:
            logger.error(f"Error during fine-tuning: {e}")
            metadata["current_model_status"] = MODEL_STATUS["BASE"]
            save_model_metadata(metadata)
            return None

    def predict(self, query_text: str, candidate_texts: List[str]) -> List[float]:
        """
        Make predictions using the ensemble model.
        """
        active_ensemble_weights: Optional[List[float]] = None
        if not self.model_ensemble:
            logger.info("Model ensemble not loaded. Attempting to load from metadata.")
            metadata = load_model_metadata()
            if metadata and "ensemble_metadata" in metadata and metadata["ensemble_metadata"].get("num_models", 0) > 0:
                latest_version = get_latest_model_version()
                # Uses get_model_path from .utils
                model_dir_path = get_model_path(latest_version)
                ensemble_meta = metadata["ensemble_metadata"]
                num_models = ensemble_meta.get("num_models", 0)
                meta_weights = ensemble_meta.get("ensemble_weights")
                loaded_models_temp = []
                loaded_weights_temp = []
                for i in range(num_models):
                    member_model_path = model_dir_path / f"ensemble_{i}"
                    if member_model_path.exists():
                        try:
                            model = CrossEncoder(str(member_model_path), device=self.device)
                            loaded_models_temp.append(model)
                            if meta_weights and i < len(meta_weights):
                                loaded_weights_temp.append(meta_weights[i])
                            else:  # Fallback for this model's weight
                                loaded_weights_temp.append(1.0)
                        except Exception as e:
                            logger.error(f"Failed to load ensemble member {member_model_path}: {e}")
                    else:
                        logger.warning(f"Ensemble model member path not found: {member_model_path}")
                if loaded_models_temp:
                    self.model_ensemble = loaded_models_temp
                    # Normalize weights for the successfully loaded models
                    if loaded_weights_temp and sum(loaded_weights_temp) > 0:
                        self.ensemble_weights = [w / sum(loaded_weights_temp) for w in loaded_weights_temp]
                    else:  # All weights zero, sum is zero, or no weights loaded; use equal weights
                        self.ensemble_weights = [1.0 / len(self.model_ensemble)] * len(self.model_ensemble)
                    logger.info(f"Loaded {len(self.model_ensemble)} models with weights: {self.ensemble_weights}")
        # Fall back to the single base model if ensemble loading failed or no ensemble is defined
        if not self.model_ensemble:
            logger.info("Falling back to base model as ensemble loading failed or no ensemble defined.")
            model = CrossEncoder(self.base_model_name, device=self.device)
            self.model_ensemble = [model]
            self.ensemble_weights = [1.0]
        # Determine weights to use for this prediction call
        if self.model_ensemble:
            if self.ensemble_weights and len(self.ensemble_weights) == len(self.model_ensemble):
                active_ensemble_weights = self.ensemble_weights
            else:
                logger.warning(f"Ensemble weights inconsistent or missing for {len(self.model_ensemble)} models; using equal weights.")
                active_ensemble_weights = [1.0 / len(self.model_ensemble)] * len(self.model_ensemble)
        else:  # No models loaded
            logger.error("No models available for prediction.")
            return [0.0] * len(candidate_texts)
        all_predictions_np = []
        for model in self.model_ensemble:
            texts = [(query_text, cand_text) for cand_text in candidate_texts]
            with torch.no_grad():
                predictions = model.predict(texts)  # Returns a NumPy array
            all_predictions_np.append(predictions)
        if not all_predictions_np:
            logger.error("No predictions were generated by the ensemble.")
            return [0.0] * len(candidate_texts)
        # Weighted average of predictions
        ensemble_predictions_np = np.average(np.array(all_predictions_np), axis=0, weights=active_ensemble_weights)
        return ensemble_predictions_np.tolist()


# Removed conflicting class method `get_model_path`.
# The imported `get_model_path` from `.utils` will be used.
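

# Minimal usage sketch (illustrative only; not invoked by the application). The model name
# and the query/candidate strings below are assumptions for demonstration; training samples
# must carry "query_text", "candidate_text" and "label" keys as expected by
# prepare_training_data() and fine_tune().
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    trainer = ModelTrainer("cross-encoder/ms-marco-MiniLM-L-6-v2")
    samples = [
        {"query_text": "reset my password", "candidate_text": "How do I change my password?", "label": 1.0},
        {"query_text": "reset my password", "candidate_text": "What is the refund policy?", "label": 0.0},
    ]
    trainer.prepare_training_data(samples, append=True)
    # Returns the new model version (e.g. "v1") or None if the conditions in
    # check_training_conditions() are not yet met.
    new_version = trainer.fine_tune()
    logger.info(f"Fine-tuning produced version: {new_version}")
    scores = trainer.predict("reset my password", ["How do I change my password?", "What is the refund policy?"])
    logger.info(f"Ensemble relevance scores: {scores}")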