| # custom_early_stopping.py | |
| import pytorch_lightning as pl | |
| from pytorch_lightning.callbacks.early_stopping import EarlyStopping | |
class MultiMetricEarlyStopping(EarlyStopping):
    """Early stopping that watches two validation metrics at once.

    Training is stopped only when *both* ``monitor_mood`` and ``monitor_va``
    have gone ``patience`` consecutive validation epochs without improving
    by more than ``min_delta``. Both metrics share the same ``mode``.

    Args:
        monitor_mood: Key of the first metric in ``trainer.callback_metrics``.
        monitor_va: Key of the second metric in ``trainer.callback_metrics``.
        patience: Number of consecutive non-improving validation epochs
            after which a metric counts as stagnant.
        min_delta: Minimum change that counts as an improvement; the sign
            is applied according to ``mode`` by ``_check_stop``.
        mode: ``"min"`` if lower values are better, ``"max"`` if higher
            values are better.
    """

    def __init__(self, monitor_mood, monitor_va, patience, min_delta, mode="min"):
        # The parent's own single-metric check looks up
        # ``callback_metrics[monitor]``; with ``monitor=None`` and the default
        # ``strict=True`` it raises RuntimeError. ``strict=False`` turns that
        # inherited check into a silent no-op, so all stopping logic lives in
        # ``on_validation_epoch_end`` below. The parent stores ``patience``
        # and ``mode`` itself.
        super().__init__(
            monitor=None,
            patience=patience,
            min_delta=min_delta,
            mode=mode,
            strict=False,
        )
        self.monitor_mood = monitor_mood
        self.monitor_va = monitor_va
        # The parent negates ``min_delta`` for mode="min"; restore the
        # caller's value because ``_check_stop`` applies the direction itself.
        self.min_delta = min_delta
        # Per-metric tracking state: consecutive non-improving epochs and
        # best value seen so far (starting from the worst possible value).
        self.wait_mood = 0
        self.wait_va = 0
        worst = float("inf") if mode == "min" else -float("inf")
        self.best_mood = worst
        self.best_va = worst

    def _check_stop(self, current, best, wait):
        """Return the updated ``(best, wait)`` pair for one metric.

        ``current`` is an improvement only if it beats ``best`` by more than
        ``min_delta`` in the direction given by ``mode``. On improvement the
        result is ``(current, 0)``; otherwise it is ``(best, wait + 1)``.
        """
        if self.mode == "min" and current < best - self.min_delta:
            return current, 0
        elif self.mode == "max" and current > best + self.min_delta:
            return current, 0
        else:
            return best, wait + 1

    def on_validation_epoch_end(self, trainer, pl_module):
        """Update both metrics' state and request a stop once both stagnate.

        Raises:
            RuntimeError: If either monitored metric is missing from
                ``trainer.callback_metrics``.
        """
        # The sanity-check validation pass runs before training on only a few
        # batches; its metrics must not seed ``best_*`` or bump the counters.
        if getattr(trainer, "sanity_checking", False):
            return
        logs = trainer.callback_metrics
        # NOTE(review): Lightning may run callbacks' on_validation_epoch_end
        # before the LightningModule's own hook, so a metric logged only from
        # the module's on_validation_epoch_end might not be present yet here
        # — confirm where these two metrics are logged.
        if self.monitor_mood not in logs or self.monitor_va not in logs:
            raise RuntimeError(f"Metrics {self.monitor_mood} or {self.monitor_va} not available.")
        # float() accepts both scalar tensors and plain Python numbers.
        current_mood = float(logs[self.monitor_mood])
        current_va = float(logs[self.monitor_va])
        # Check stopping conditions for both metrics
        self.best_mood, self.wait_mood = self._check_stop(current_mood, self.best_mood, self.wait_mood)
        self.best_va, self.wait_va = self._check_stop(current_va, self.best_va, self.wait_va)
        # Stop once both metrics have stagnated for ``patience`` epochs
        # (same ``>=`` convention as Lightning's EarlyStopping).
        if self.wait_mood >= self.patience and self.wait_va >= self.patience:
            self.stopped_epoch = trainer.current_epoch
            trainer.should_stop = True
| # # custom_early_stopping.py | |
| # import pytorch_lightning as pl | |
| # from pytorch_lightning.callbacks.early_stopping import EarlyStopping | |
| # class MultiMetricEarlyStopping(EarlyStopping): | |
| # def __init__(self, monitor_mood: str, monitor_va: str, patience: int = 10, min_delta: float = 0.0, mode: str = "min"): | |
| # super().__init__(monitor=None, patience=patience, min_delta=min_delta, mode=mode) | |
| # self.monitor_mood = monitor_mood | |
| # self.monitor_va = monitor_va | |
| # self.wait_mood = 0 | |
| # self.wait_va = 0 | |
| # self.best_mood_score = None | |
| # self.best_va_score = None | |
| # self.patience = patience | |
| # self.stopped_epoch = 0 | |
| # def on_validation_end(self, trainer, pl_module): | |
| # current_mood = trainer.callback_metrics.get(self.monitor_mood) | |
| # current_va = trainer.callback_metrics.get(self.monitor_va) | |
| # # Check if current_mood improved | |
| # if self.best_mood_score is None or self._compare(current_mood, self.best_mood_score): | |
| # self.best_mood_score = current_mood | |
| # self.wait_mood = 0 | |
| # else: | |
| # self.wait_mood += 1 | |
| # # Check if current_va improved | |
| # if self.best_va_score is None or self._compare(current_va, self.best_va_score): | |
| # self.best_va_score = current_va | |
| # self.wait_va = 0 | |
| # else: | |
| # self.wait_va += 1 | |
| # # If both metrics are stagnant for patience epochs, stop training | |
| # if self.wait_mood >= self.patience and self.wait_va >= self.patience: | |
| # self.stopped_epoch = trainer.current_epoch | |
| # trainer.should_stop = True | |
| # def _compare(self, current, best): | |
| # if self.mode == "min": | |
| # return current < best - self.min_delta | |
| # else: | |
| # return current > best + self.min_delta |