"""Callback that loads model weights from the state dict.""" # Copyright (C) 2022 Intel Corporation # SPDX-License-Identifier: Apache-2.0 import logging import torch from lightning.pytorch import Callback, Trainer from anomalib.models.components import AnomalyModule logger = logging.getLogger(__name__) class LoadModelCallback(Callback): """Callback that loads the model weights from the state dict. Examples: >>> from anomalib.callbacks import LoadModelCallback >>> from anomalib.engine import Engine ... >>> callbacks = [LoadModelCallback(weights_path="path/to/weights.pt")] >>> engine = Engine(callbacks=callbacks) """ def __init__(self, weights_path: str) -> None: self.weights_path = weights_path def setup(self, trainer: Trainer, pl_module: AnomalyModule, stage: str | None = None) -> None: """Call when inference begins. Loads the model weights from ``weights_path`` into the PyTorch module. """ del trainer, stage # These variables are not used. logger.info("Loading the model from %s", self.weights_path) pl_module.load_state_dict(torch.load(self.weights_path, map_location=pl_module.device)["state_dict"])