MRiabov committed on
Commit 5571883 · 1 Parent(s): e05cc17

fix strip checkpoint saving

Files changed (1)
  1. scripts/strip_checkpoint.py +5 -4
scripts/strip_checkpoint.py CHANGED
@@ -35,11 +35,12 @@ def main():
     )
     state_dict = ckpt
 
-    # Ensure FP32 tensors (no casting to bf16/fp16 per request)
-    state_dict = {k: (v.float() if torch.is_floating_point(v) else v) for k, v in state_dict.items()}
+    #in the future, can cast to bfloat if necessary.
+    # state_dict = {k: (v.float() if torch.is_floating_point(v) else v) for k, v in state_dict.items()}
 
-    torch.save(state_dict, str(out_path))
-    print(f"[strip_checkpoint] Saved weights-only to: {out_path}")
+    to_save = {"model": state_dict}
+    torch.save(to_save, str(out_path))
+    print(f"[strip_checkpoint] Saved dict with only 'model' to: {out_path}")
 
 
 if __name__ == "__main__":
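
For context, after this commit the stripped file stores the weights under a "model" key rather than as a bare state_dict, so any consumer must index into that key before calling load_state_dict. Below is a minimal sketch of the loading side, assuming a PyTorch consumer; the file path and the model factory are placeholders, not part of the repository:

import torch

# Hypothetical example: load a checkpoint produced by scripts/strip_checkpoint.py
# after this change. The file now contains {"model": state_dict}, not the raw
# state_dict itself.
ckpt = torch.load("checkpoints/stripped.pt", map_location="cpu")  # path is illustrative
state_dict = ckpt["model"]  # the only key written by strip_checkpoint.py

# model = build_model(cfg)                                   # placeholder model factory
# missing, unexpected = model.load_state_dict(state_dict)    # strict=True by default

Wrapping the weights under "model" presumably keeps the stripped file compatible with loaders that expect training-style checkpoints keyed by "model"; this is an inference from the diff, not something stated in the commit.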