Safetensors export and infer.
Browse files- infer.py +18 -5
- requirements.txt +1 -0
- scripts/strip_checkpoint.py +11 -4
infer.py
CHANGED
@@ -3,14 +3,13 @@ import os
|
|
3 |
import pprint
|
4 |
import time
|
5 |
from typing import List, Tuple, Optional, Dict, Any
|
6 |
-
import yaml
|
7 |
-
|
8 |
import numpy as np
|
9 |
import cv2
|
10 |
import torch
|
11 |
import torch.nn.functional as F
|
12 |
from torch.amp import autocast
|
13 |
from tqdm import tqdm
|
|
|
14 |
|
15 |
from src.wireseghr.model import WireSegHR
|
16 |
from pathlib import Path
|
@@ -256,7 +255,7 @@ def main():
|
|
256 |
"--ckpt",
|
257 |
type=str,
|
258 |
default="",
|
259 |
-
help="Optional checkpoint (.pt
|
260 |
)
|
261 |
parser.add_argument(
|
262 |
"--save_prob", action="store_true", help="Also save probability .npy"
|
@@ -356,8 +355,22 @@ def main():
|
|
356 |
if ckpt_path:
|
357 |
assert Path(ckpt_path).is_file(), f"Checkpoint not found: {ckpt_path}"
|
358 |
print(f"[WireSegHR][infer] Loading checkpoint: {ckpt_path}")
|
359 |
-
|
360 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
361 |
model.eval()
|
362 |
|
363 |
# Benchmark mode
|
|
|
3 |
import pprint
|
4 |
import time
|
5 |
from typing import List, Tuple, Optional, Dict, Any
|
|
|
|
|
6 |
import numpy as np
|
7 |
import cv2
|
8 |
import torch
|
9 |
import torch.nn.functional as F
|
10 |
from torch.amp import autocast
|
11 |
from tqdm import tqdm
|
12 |
+
from safetensors.torch import load_file as safe_load_file
|
13 |
|
14 |
from src.wireseghr.model import WireSegHR
|
15 |
from pathlib import Path
|
|
|
255 |
"--ckpt",
|
256 |
type=str,
|
257 |
default="",
|
258 |
+
help="Optional checkpoint (.pt with {'model': state_dict} or .safetensors with pure state_dict)",
|
259 |
)
|
260 |
parser.add_argument(
|
261 |
"--save_prob", action="store_true", help="Also save probability .npy"
|
|
|
355 |
if ckpt_path:
|
356 |
assert Path(ckpt_path).is_file(), f"Checkpoint not found: {ckpt_path}"
|
357 |
print(f"[WireSegHR][infer] Loading checkpoint: {ckpt_path}")
|
358 |
+
suffix = Path(ckpt_path).suffix.lower()
|
359 |
+
if suffix == ".safetensors":
|
360 |
+
# Safetensors exports contain a pure state_dict
|
361 |
+
state_dict = safe_load_file(ckpt_path)
|
362 |
+
model.load_state_dict(state_dict)
|
363 |
+
else:
|
364 |
+
print(
|
365 |
+
"[WireSegHR][infer][WARN] Loading a PyTorch checkpoint. Prefer .safetensors for inference-only weights."
|
366 |
+
)
|
367 |
+
# PyTorch .pt/.pth checkpoints expected to have {'model': state_dict}
|
368 |
+
state = torch.load(ckpt_path, map_location=device)
|
369 |
+
assert "model" in state, (
|
370 |
+
"Expected a dict with key 'model' for PyTorch checkpoint. "
|
371 |
+
"Use scripts/strip_checkpoint.py or provide a .safetensors file."
|
372 |
+
)
|
373 |
+
model.load_state_dict(state["model"])
|
374 |
model.eval()
|
375 |
|
376 |
# Benchmark mode
|
requirements.txt
CHANGED
@@ -8,3 +8,4 @@ PyYAML>=6.0.1
|
|
8 |
tqdm>=4.65.0
|
9 |
gdown>=5.1.0
|
10 |
pydrive2
|
|
|
|
8 |
tqdm>=4.65.0
|
9 |
gdown>=5.1.0
|
10 |
pydrive2
|
11 |
+
safetensors
|
scripts/strip_checkpoint.py
CHANGED
@@ -3,6 +3,7 @@
|
|
3 |
import argparse
|
4 |
from pathlib import Path
|
5 |
import torch
|
|
|
6 |
|
7 |
|
8 |
def main():
|
@@ -10,7 +11,8 @@ def main():
|
|
10 |
description="Strip training checkpoint to inference-only weights (FP32)."
|
11 |
)
|
12 |
parser.add_argument("--in", dest="inp", type=str, required=True, help="Path to training checkpoint .pt")
|
13 |
-
parser.add_argument("--out", dest="out", type=str, required=True, help="Path to save weights-only .pt")
|
|
|
14 |
args = parser.parse_args()
|
15 |
|
16 |
in_path = Path(args.inp)
|
@@ -38,9 +40,14 @@ def main():
|
|
38 |
#in the future, can cast to bfloat if necessary.
|
39 |
# state_dict = {k: (v.float() if torch.is_floating_point(v) else v) for k, v in state_dict.items()}
|
40 |
|
41 |
-
|
42 |
-
|
43 |
-
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
|
46 |
if __name__ == "__main__":
|
|
|
3 |
import argparse
|
4 |
from pathlib import Path
|
5 |
import torch
|
6 |
+
from safetensors.torch import save_file as safetensors_save_file
|
7 |
|
8 |
|
9 |
def main():
|
|
|
11 |
description="Strip training checkpoint to inference-only weights (FP32)."
|
12 |
)
|
13 |
parser.add_argument("--in", dest="inp", type=str, required=True, help="Path to training checkpoint .pt")
|
14 |
+
parser.add_argument("--out", dest="out", type=str, required=True, help="Path to save weights-only .pt or .safetensors")
|
15 |
+
# Output format is inferred from --out extension
|
16 |
args = parser.parse_args()
|
17 |
|
18 |
in_path = Path(args.inp)
|
|
|
40 |
#in the future, can cast to bfloat if necessary.
|
41 |
# state_dict = {k: (v.float() if torch.is_floating_point(v) else v) for k, v in state_dict.items()}
|
42 |
|
43 |
+
suffix = out_path.suffix.lower()
|
44 |
+
if suffix == ".safetensors":
|
45 |
+
safetensors_save_file(state_dict, str(out_path))
|
46 |
+
print(f"[strip_checkpoint] Saved safetensors (pure state_dict) to: {out_path}")
|
47 |
+
else:
|
48 |
+
to_save = {"model": state_dict}
|
49 |
+
torch.save(to_save, str(out_path))
|
50 |
+
print(f"[strip_checkpoint] Saved dict with only 'model' to: {out_path}")
|
51 |
|
52 |
|
53 |
if __name__ == "__main__":
|