"""Helpers for gzip-compressed model persistence and dropout adjustment."""

import gzip
import os

import torch
import torch.nn as nn


def save_model(model: torch.nn.Module, path: str) -> None:
    """Save a model using gzip compression.

    The whole module object (not just its ``state_dict``) is serialized with
    ``torch.save`` into a gzip stream written at ``path``.

    Args:
        model: Module to serialize.
        path: Destination file. Missing parent directories are created.
    """
    parent = os.path.dirname(path)
    # ``dirname`` is "" for a bare filename such as "model.pt.gz", and
    # ``os.makedirs("")`` raises FileNotFoundError, so only create the parent
    # when there actually is one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with gzip.open(path, "wb") as f:
        torch.save(model, f)


def load_model(path: str) -> torch.nn.Module:
    """Load a model saved with ``save_model``.

    Args:
        path: Gzip-compressed file produced by ``save_model``.

    Returns:
        The deserialized object, with tensors mapped to the CPU.
    """
    with gzip.open(path, "rb") as f:
        # SECURITY: ``weights_only=False`` performs a full unpickle, which can
        # execute arbitrary code embedded in the file. Only load files that
        # come from a trusted source.
        model = torch.load(f, map_location="cpu", weights_only=False)
    return model


def set_dropout(model: torch.nn.Module, p: float) -> None:
    """Set dropout probability ``p`` on every ``nn.Dropout`` in ``model``.

    Only modules that are instances of ``nn.Dropout`` are updated; other
    dropout variants (e.g. ``nn.Dropout2d``, ``nn.AlphaDropout``) are not
    subclasses of ``nn.Dropout`` and are left untouched.

    Args:
        model: Module whose sub-modules (including itself) are inspected.
        p: New dropout probability, in the closed interval [0, 1].

    Raises:
        ValueError: If ``p`` is outside [0, 1]. No module is modified in
            that case.
    """
    # Assigning ``module.p`` skips the range check done by the ``nn.Dropout``
    # constructor, so validate here rather than failing later in ``forward``.
    if p < 0 or p > 1:
        raise ValueError(
            f"dropout probability has to be between 0 and 1, but got {p}"
        )
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.p = p


__all__ = ["save_model", "load_model", "set_dropout"]