|
|
import re |
|
|
import torch |
|
|
|
|
|
from skp.configs import Config |
|
|
from importlib import import_module |
|
|
from typing import Dict, Optional, Sequence |
|
|
|
|
|
|
|
|
def load_weights_from_path(path: str) -> Dict[str, torch.Tensor]:
    """Load a checkpoint and return its model weights with the ``model.`` prefix removed.

    The file is deserialized with ``weights_only=True``. The ``map_location``
    callable keeps storages on the CPU, where they were deserialized. Only the
    checkpoint's ``"state_dict"`` entry is used. This layout is presumably a
    training-framework checkpoint that wraps the network as ``self.model``;
    confirm against the trainer that writes these files.

    An entry is kept only if its key starts with ``"model."`` and does not
    contain ``"criterion"``. Criterion entries are presumably loss-module
    parameters or buffers that the bare network does not own. The leading
    ``"model."`` is stripped from every kept key.

    Args:
        path: Path to a checkpoint file containing a ``"state_dict"`` mapping.

    Returns:
        A dict mapping the stripped parameter/buffer names to tensors.

    Raises:
        KeyError: If the loaded checkpoint has no ``"state_dict"`` entry.
    """
    prefix = "model."
    checkpoint = torch.load(path, map_location=lambda storage, loc: storage, weights_only=True)
    state_dict = checkpoint["state_dict"]
    # Filter and strip with the same literal prefix. Plain slicing avoids the
    # regex-metacharacter pitfall: in r"^model." the "." matches any character.
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix) and "criterion" not in key
    }
|
|
|
|
|
|
|
|
def load_model_from_config(
    cfg: Config,
    weights_path: Optional[str] = None,
    device: str = "cpu",
    eval_mode: bool = True,
) -> torch.nn.Module:
    """Build the network named by ``cfg.model`` and prepare it for use.

    The class ``Net`` is taken from the module ``skp.models.<cfg.model>`` and
    constructed with ``cfg``. If ``weights_path`` is truthy, the weights
    returned by ``load_weights_from_path`` are loaded with
    ``load_state_dict``. PyTorch's default is strict, so mismatched keys raise.

    Args:
        cfg: Project config. Its ``model`` attribute selects the module under
            ``skp.models``.
        weights_path: Optional checkpoint path. When falsy (``None`` or ``""``),
            no weights are loaded and the freshly initialised model is used.
        device: Target passed to ``Module.to``.
        eval_mode: If True (default), the model is switched to eval mode via
            ``train(mode=False)``. Otherwise it is left in training mode.

    Returns:
        The model after moving it to ``device`` and setting its mode.
    """
    model_module = import_module(f"skp.models.{cfg.model}")
    net = model_module.Net(cfg)

    if weights_path:
        state = load_weights_from_path(weights_path)
        net.load_state_dict(state)

    net = net.to(device)
    # train(False) is equivalent to eval(); the flag is inverted here.
    net = net.train(mode=not eval_mode)
    return net
|
|
|
|
|
|
|
|
def load_kfold_ensemble_as_list(
    cfg: Config,
    weights_paths: Sequence[str],
    device: str = "cpu",
    eval_mode: bool = True,
) -> torch.nn.ModuleList:
    """Load one model per checkpoint and collect them in a ``ModuleList``.

    Every model is built by ``load_model_from_config(cfg, path, device,
    eval_mode)``. All members therefore share the same config, device and
    train/eval mode, and differ only in their loaded weights.

    Args:
        cfg: Project config shared by every member of the ensemble.
        weights_paths: Checkpoint paths, presumably one per fold. A single
            path passed as a bare ``str`` is treated as a one-element sequence.
        device: Device each model is moved to.
        eval_mode: Whether each model is put in eval mode (default True).

    Returns:
        A ``torch.nn.ModuleList`` with the loaded models in the same order as
        ``weights_paths``. It is empty when ``weights_paths`` is empty.
    """
    # A bare str is itself a Sequence[str]. Iterating it would treat each
    # character as a checkpoint path, so wrap it as a single path instead.
    if isinstance(weights_paths, str):
        weights_paths = (weights_paths,)

    return torch.nn.ModuleList(
        [load_model_from_config(cfg, path, device, eval_mode) for path in weights_paths]
    )
|
|
|