MRiabov committed
Commit dae69c0 · Parent: befda65

Safetensors export and infer.

Files changed (3)
  1. infer.py +18 -5
  2. requirements.txt +1 -0
  3. scripts/strip_checkpoint.py +11 -4
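In short: scripts/strip_checkpoint.py can now export a pure state_dict as .safetensors, and infer.py loads either that or the legacy {'model': state_dict} .pt. A minimal round-trip sketch of the same pattern, using a stand-in torch.nn.Linear and hypothetical file names rather than the repository's WireSegHR model and real checkpoints:

# Round-trip sketch; the module and paths below are placeholders, not the repo's real model/checkpoints.
import torch
from safetensors.torch import save_file, load_file

model = torch.nn.Linear(4, 2)  # stand-in for WireSegHR

# Export side: safetensors stores a flat {name: tensor} mapping, i.e. a pure state_dict (no pickle).
save_file(model.state_dict(), "weights.safetensors")

# Inference side: load_file returns the same flat dict, ready for load_state_dict.
state_dict = load_file("weights.safetensors")
model.load_state_dict(state_dict)
model.eval()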
infer.py CHANGED
@@ -3,14 +3,13 @@ import os
 import pprint
 import time
 from typing import List, Tuple, Optional, Dict, Any
-import yaml
-
 import numpy as np
 import cv2
 import torch
 import torch.nn.functional as F
 from torch.amp import autocast
 from tqdm import tqdm
+from safetensors.torch import load_file as safe_load_file
 
 from src.wireseghr.model import WireSegHR
 from pathlib import Path
@@ -256,7 +255,7 @@ def main():
         "--ckpt",
         type=str,
         default="",
-        help="Optional checkpoint (.pt) with model state",
+        help="Optional checkpoint (.pt with {'model': state_dict} or .safetensors with pure state_dict)",
     )
     parser.add_argument(
         "--save_prob", action="store_true", help="Also save probability .npy"
@@ -356,8 +355,22 @@ def main():
     if ckpt_path:
         assert Path(ckpt_path).is_file(), f"Checkpoint not found: {ckpt_path}"
         print(f"[WireSegHR][infer] Loading checkpoint: {ckpt_path}")
-        state = torch.load(ckpt_path, map_location=device)
-        model.load_state_dict(state["model"])
+        suffix = Path(ckpt_path).suffix.lower()
+        if suffix == ".safetensors":
+            # Safetensors exports contain a pure state_dict
+            state_dict = safe_load_file(ckpt_path)
+            model.load_state_dict(state_dict)
+        else:
+            print(
+                "[WireSegHR][infer][WARN] Loading a PyTorch checkpoint. Prefer .safetensors for inference-only weights."
+            )
+            # PyTorch .pt/.pth checkpoints expected to have {'model': state_dict}
+            state = torch.load(ckpt_path, map_location=device)
+            assert "model" in state, (
+                "Expected a dict with key 'model' for PyTorch checkpoint. "
+                "Use scripts/strip_checkpoint.py or provide a .safetensors file."
+            )
+            model.load_state_dict(state["model"])
     model.eval()
 
     # Benchmark mode
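One detail to be aware of: the .pt branch maps tensors with map_location=device, while safe_load_file loads to CPU by default and relies on load_state_dict to copy weights into the model's existing parameters. If loading straight onto the inference device is preferred, safetensors' load_file also accepts a device argument; a one-line variant of the new branch (behaviour otherwise unchanged) could be:

state_dict = safe_load_file(ckpt_path, device=str(device))  # e.g. "cuda:0"; loads tensors directly on that device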
requirements.txt CHANGED
@@ -8,3 +8,4 @@ PyYAML>=6.0.1
 tqdm>=4.65.0
 gdown>=5.1.0
 pydrive2
+safetensors
scripts/strip_checkpoint.py CHANGED
@@ -3,6 +3,7 @@
 import argparse
 from pathlib import Path
 import torch
+from safetensors.torch import save_file as safetensors_save_file
 
 
 def main():
@@ -10,7 +11,8 @@ def main():
         description="Strip training checkpoint to inference-only weights (FP32)."
     )
     parser.add_argument("--in", dest="inp", type=str, required=True, help="Path to training checkpoint .pt")
-    parser.add_argument("--out", dest="out", type=str, required=True, help="Path to save weights-only .pt")
+    parser.add_argument("--out", dest="out", type=str, required=True, help="Path to save weights-only .pt or .safetensors")
+    # Output format is inferred from --out extension
     args = parser.parse_args()
 
     in_path = Path(args.inp)
@@ -38,9 +40,14 @@ def main():
     #in the future, can cast to bfloat if necessary.
     # state_dict = {k: (v.float() if torch.is_floating_point(v) else v) for k, v in state_dict.items()}
 
-    to_save = {"model": state_dict}
-    torch.save(to_save, str(out_path))
-    print(f"[strip_checkpoint] Saved dict with only 'model' to: {out_path}")
+    suffix = out_path.suffix.lower()
+    if suffix == ".safetensors":
+        safetensors_save_file(state_dict, str(out_path))
+        print(f"[strip_checkpoint] Saved safetensors (pure state_dict) to: {out_path}")
+    else:
+        to_save = {"model": state_dict}
+        torch.save(to_save, str(out_path))
+        print(f"[strip_checkpoint] Saved dict with only 'model' to: {out_path}")
 
 
 if __name__ == "__main__":
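The output format now follows the --out extension, e.g. python scripts/strip_checkpoint.py --in <training_ckpt.pt> --out <weights.safetensors> (both paths are placeholders). A quick way to inspect the exported file and confirm it is a flat tensor dict:

# Inspection sketch; the path is a placeholder for whatever --out produced.
from safetensors.torch import load_file

sd = load_file("weights.safetensors")
print(f"{len(sd)} tensors, {sum(t.numel() for t in sd.values())} parameters")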