|
|
|
|
|
import argparse |
|
from pathlib import Path |
|
import torch |
|
from safetensors.torch import save_file as safetensors_save_file |
|
|
|
|
|
def main(): |
|
parser = argparse.ArgumentParser( |
|
description="Strip training checkpoint to inference-only weights (FP32)." |
|
) |
|
parser.add_argument("--in", dest="inp", type=str, required=True, help="Path to training checkpoint .pt") |
|
parser.add_argument("--out", dest="out", type=str, required=True, help="Path to save weights-only .pt or .safetensors") |
|
|
|
args = parser.parse_args() |
|
|
|
in_path = Path(args.inp) |
|
out_path = Path(args.out) |
|
|
|
assert in_path.is_file(), f"Input file does not exist: {in_path}" |
|
out_path.parent.mkdir(parents=True, exist_ok=True) |
|
|
|
ckpt = torch.load(str(in_path), map_location="cpu") |
|
|
|
|
|
if isinstance(ckpt, dict) and "model" in ckpt: |
|
state_dict = ckpt["model"] |
|
|
|
elif isinstance(ckpt, dict) and "state_dict" in ckpt: |
|
state_dict = ckpt["state_dict"] |
|
else: |
|
|
|
assert isinstance(ckpt, dict) and all(isinstance(v, torch.Tensor) for v in ckpt.values()), ( |
|
"Checkpoint is not a recognized format: expected keys 'model' or 'state_dict', " |
|
"or a pure state_dict (name->Tensor)." |
|
) |
|
state_dict = ckpt |
|
|
|
|
|
|
|
|
|
suffix = out_path.suffix.lower() |
|
if suffix == ".safetensors": |
|
safetensors_save_file(state_dict, str(out_path)) |
|
print(f"[strip_checkpoint] Saved safetensors (pure state_dict) to: {out_path}") |
|
else: |
|
to_save = {"model": state_dict} |
|
torch.save(to_save, str(out_path)) |
|
print(f"[strip_checkpoint] Saved dict with only 'model' to: {out_path}") |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|