import re
from importlib import import_module
from typing import Dict, Optional, Sequence

import torch

from skp.configs import Config

# Anchored, escaped pattern for the literal "model." key prefix. The prefix is
# presumably added by the training wrapper that holds the network as an
# attribute named `model` (e.g. a Lightning module) — verify against trainer.
# Escaping the dot keeps it from matching an arbitrary character.
_MODEL_PREFIX = re.compile(r"^model\.")


def load_weights_from_path(path: str) -> Dict[str, torch.Tensor]:
    """Load the network weights from a checkpoint, stripping the ``model.`` prefix.

    The checkpoint is deserialized onto CPU with ``weights_only=True``, which
    restricts unpickling to tensors and primitive containers. Only entries
    whose key starts with ``model.`` are kept. Entries whose key contains
    ``criterion`` are dropped, so loss-module state is not loaded into the
    network.

    Args:
        path: Path to a checkpoint file. It must deserialize to a mapping with a
            ``"state_dict"`` entry (presumably a PyTorch Lightning checkpoint —
            confirm against the training code).

    Returns:
        A new dict mapping each kept key, with the leading ``model.`` removed,
        to its tensor.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
    """
    # "cpu" is equivalent to the identity `lambda storage, loc: storage`
    # map_location: tensors are materialized on CPU regardless of the device
    # they were saved from.
    checkpoint = torch.load(path, map_location="cpu", weights_only=True)
    state_dict = checkpoint["state_dict"]
    return {
        _MODEL_PREFIX.sub("", key): value
        for key, value in state_dict.items()
        if key.startswith("model.") and "criterion" not in key
    }


def load_model_from_config(
    cfg: Config,
    weights_path: Optional[str] = None,
    device: str = "cpu",
    eval_mode: bool = True,
) -> torch.nn.Module:
    """Build the model named by ``cfg.model`` and optionally load trained weights.

    The model class is resolved dynamically as ``skp.models.<cfg.model>.Net``
    and constructed with ``cfg``.

    Args:
        cfg: Project config. ``cfg.model`` names the module under ``skp.models``.
        weights_path: Optional checkpoint path. If falsy (``None`` or ``""``), the
            model keeps its freshly initialized weights.
        device: Device the model is moved to after loading.
        eval_mode: If True, the model is put in eval mode; otherwise in train mode.

    Returns:
        The constructed model on ``device``.

    Raises:
        ModuleNotFoundError: If ``skp.models.<cfg.model>`` cannot be imported.
        RuntimeError: From ``load_state_dict`` if the checkpoint keys or shapes
            do not match the model (strict loading).
    """
    model = import_module(f"skp.models.{cfg.model}").Net(cfg)
    if weights_path:
        model.load_state_dict(load_weights_from_path(weights_path))
    # Weights are loaded on CPU first; move afterwards so one code path serves
    # any target device.
    return model.to(device).train(mode=not eval_mode)


def load_kfold_ensemble_as_list(
    cfg: Config,
    weights_paths: Sequence[str],
    device: str = "cpu",
    eval_mode: bool = True,
) -> torch.nn.ModuleList:
    """Load one copy of the same architecture per checkpoint (e.g. k-fold models).

    Every model is built from the same ``cfg``, so this does not support
    ensembling different architectures. Trained weights are expected for each
    entry: loading several randomly initialized copies of one model would be
    pointless.

    Args:
        cfg: Project config shared by all ensemble members.
        weights_paths: One checkpoint path per member, in the desired order.
        device: Device every member is moved to.
        eval_mode: If True, members are put in eval mode; otherwise train mode.

    Returns:
        A ``ModuleList`` with one model per path, in the same order as
        ``weights_paths``. It is empty if ``weights_paths`` is empty.
    """
    return torch.nn.ModuleList(
        [
            load_model_from_config(cfg, path, device, eval_mode)
            for path in weights_paths
        ]
    )