ianpan's picture
update models, output, examples
455e8ef
raw
history blame
1.49 kB
import re
import torch
from skp.configs import Config
from importlib import import_module
from typing import Dict, Optional, Sequence
def load_weights_from_path(path: str) -> Dict[str, torch.Tensor]:
    """Load the model weights stored in a training checkpoint.

    The file at *path* must deserialize to a mapping with a ``"state_dict"``
    entry. From that state dict, only entries whose key starts with
    ``"model."`` are kept, and any key containing ``"criterion"`` is dropped.
    The leading ``"model."`` prefix is removed from each surviving key, so the
    result can be passed directly to ``Module.load_state_dict`` on the bare
    network.

    Args:
        path: Filesystem path to the checkpoint file.

    Returns:
        Mapping from parameter/buffer name (without the ``"model."`` prefix)
        to tensor. All tensors are loaded onto the CPU.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
    """
    prefix = "model."
    # weights_only=True restricts unpickling to tensors and plain containers;
    # map_location="cpu" loads every tensor on the CPU regardless of the
    # device it was saved from.
    checkpoint = torch.load(path, map_location="cpu", weights_only=True)
    state_dict = checkpoint["state_dict"]
    # Remove the literal prefix by slicing; the startswith() filter guarantees
    # it is present, and slicing avoids regex metacharacter pitfalls ("." in
    # an unescaped pattern would match any character).
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix) and "criterion" not in key
    }
def load_model_from_config(
    cfg: Config,
    weights_path: Optional[str] = None,
    device: str = "cpu",
    eval_mode: bool = True,
) -> torch.nn.Module:
    """Instantiate the network named by ``cfg.model`` and prepare it for use.

    The class ``Net`` is looked up in the module ``skp.models.<cfg.model>``
    and constructed with ``cfg``. If ``weights_path`` is truthy, the weights
    returned by ``load_weights_from_path`` are loaded into the network;
    otherwise it keeps its freshly constructed parameters.

    Args:
        cfg: Configuration object; ``cfg.model`` selects the model module.
        weights_path: Optional checkpoint path to load weights from.
        device: Device the model is moved to (passed to ``Module.to``).
        eval_mode: When True the model is put in eval mode, else train mode.

    Returns:
        The constructed model, moved to ``device`` with its mode set.
    """
    model_module = import_module(f"skp.models.{cfg.model}")
    net = model_module.Net(cfg)
    if weights_path:
        net.load_state_dict(load_weights_from_path(weights_path))
    # train(mode=False) is equivalent to eval(); both return the module itself.
    return net.to(device).train(mode=not eval_mode)
def load_kfold_ensemble_as_list(
    cfg: Config,
    weights_paths: Sequence[str],
    device: str = "cpu",
    eval_mode: bool = True,
) -> torch.nn.ModuleList:
    """Load one model per fold checkpoint and collect them in a ModuleList.

    All folds are built from the same ``cfg``, so this ensembles folds of a
    single model definition only; it cannot mix different model types. Each
    path in ``weights_paths`` is loaded with ``load_model_from_config`` using
    the shared ``device`` and ``eval_mode`` settings. The returned list keeps
    the same order as ``weights_paths``.

    Args:
        cfg: Configuration shared by every fold.
        weights_paths: One checkpoint path per fold.
        device: Device every fold model is moved to.
        eval_mode: Whether each fold model is put in eval mode.

    Returns:
        A ``torch.nn.ModuleList`` holding one model per entry of
        ``weights_paths``.
    """
    # Trained weights are expected for every fold: several randomly
    # initialised copies of the same model would be pointless to ensemble.
    fold_models = [
        load_model_from_config(cfg, fold_path, device, eval_mode)
        for fold_path in weights_paths
    ]
    return torch.nn.ModuleList(fold_models)