# Source: WCNegentropy / BitTransformerLM, commit 36c78b1 (exported from a
# HuggingFace development space; original page chrome removed so the file
# parses as Python).
import os
import gzip
import torch
import torch.nn as nn
def save_model(model: torch.nn.Module, path: str) -> None:
    """Save ``model`` to ``path`` using gzip compression.

    The entire module object is pickled (not just its ``state_dict``), so
    loading later requires the model class to be importable.  Parent
    directories of ``path`` are created as needed.

    Args:
        model: The module to serialize.
        path: Destination file path; may be a bare filename.
    """
    parent = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so only create directories
    # when the path actually has a parent component.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with gzip.open(path, 'wb') as f:
        torch.save(model, f)
def load_model(path: str) -> torch.nn.Module:
    """Load and return a model previously written by ``save_model``.

    Args:
        path: Path to a gzip-compressed file produced by ``save_model``.

    Returns:
        The deserialized module, mapped onto the CPU.
    """
    with gzip.open(path, 'rb') as f:
        # weights_only=False is required because save_model pickles the whole
        # module object; only load files from trusted sources.
        return torch.load(f, map_location="cpu", weights_only=False)
def set_dropout(model: torch.nn.Module, p: float) -> None:
    """Set the dropout probability of every ``nn.Dropout`` in ``model`` to ``p``.

    Mutates the model in place; modules other than ``nn.Dropout`` are left
    untouched.
    """
    dropouts = (m for m in model.modules() if isinstance(m, nn.Dropout))
    for layer in dropouts:
        layer.p = p
# Public API of this module: only these names are re-exported on star-import.
__all__ = ["save_model", "load_model", "set_dropout"]