fix strip checkpoint saving
Browse files
scripts/strip_checkpoint.py
CHANGED
@@ -35,11 +35,12 @@ def main():
|
|
35 |
)
|
36 |
state_dict = ckpt
|
37 |
|
38 |
-
#
|
39 |
-
state_dict = {k: (v.float() if torch.is_floating_point(v) else v) for k, v in state_dict.items()}
|
40 |
|
41 |
-
|
42 |
-
|
|
|
43 |
|
44 |
|
45 |
if __name__ == "__main__":
|
|
|
35 |
)
|
36 |
state_dict = ckpt
|
37 |
|
38 |
+
#in the future, can cast to bfloat if necessary.
|
39 |
+
# state_dict = {k: (v.float() if torch.is_floating_point(v) else v) for k, v in state_dict.items()}
|
40 |
|
41 |
+
to_save = {"model": state_dict}
|
42 |
+
torch.save(to_save, str(out_path))
|
43 |
+
print(f"[strip_checkpoint] Saved dict with only 'model' to: {out_path}")
|
44 |
|
45 |
|
46 |
if __name__ == "__main__":
|