Spaces:
Runtime error
Runtime error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| Utility to export a training checkpoint to a lightweight release checkpoint. | |
| """ | |
| from pathlib import Path | |
| import typing as tp | |
| from omegaconf import OmegaConf, DictConfig | |
| import torch | |
def _clean_lm_cfg(cfg: DictConfig, *, card: int = 2048, n_q: int = 4) -> DictConfig:
    """Patch an LM training config in place so it can be used for a release checkpoint.

    Forces ``transformer_lm.card`` and ``transformer_lm.n_q`` to fixed values and
    removes experimental ``transformer_lm`` parameters that are no longer supported.
    Struct mode is lifted while editing and switched back on afterwards, even if an
    edit fails.

    Args:
        cfg: Experiment config (an OmegaConf container); it is mutated in place.
        card: Value written to ``transformer_lm.card`` (2048 by default, as before).
        n_q: Value written to ``transformer_lm.n_q`` (4 by default, as before).

    Returns:
        The same ``cfg`` object, with struct mode enabled.
    """
    OmegaConf.set_struct(cfg, False)
    try:
        lm_cfg = cfg['transformer_lm']
        # This used to be set automatically in the LM solver, need a more robust
        # solution for the future.
        lm_cfg['card'] = card
        lm_cfg['n_q'] = n_q
        # Experimental params no longer supported. `pop(..., None)` instead of `del`
        # so configs that never had one of these keys are still accepted.
        bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters',
                      'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop']
        for name in bad_params:
            lm_cfg.pop(name, None)
    finally:
        # Re-enable struct mode whether or not the edits above succeeded.
        OmegaConf.set_struct(cfg, True)
    return cfg
def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]) -> Path:
    """Export a compression-model training checkpoint to a lightweight release file.

    The release package keeps only ``best_state`` (the ``model`` entry of the
    checkpoint's EMA state) and the experiment config serialized to YAML under
    ``xp.cfg``.

    Args:
        checkpoint_path: Path to the training checkpoint. The name of its parent
            folder is used as the Dora signature and must be exactly 8 characters.
        out_folder: Folder in which ``<sig>.th`` is written; created if missing.

    Returns:
        Path of the written release checkpoint.

    Raises:
        ValueError: If the parent folder name is not 8 characters long.
    """
    checkpoint_path = Path(checkpoint_path)
    sig = checkpoint_path.parent.name
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(sig) != 8:
        raise ValueError("Not a valid Dora signature")
    # The checkpoint stores the experiment config under 'xp.cfg'. It is presumably
    # an OmegaConf object (as in LM checkpoints), which the restricted loader used
    # by default since torch 2.6 (`weights_only=True`) refuses to unpickle. Full
    # unpickling can execute arbitrary code: only export checkpoints you trust.
    pkg = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    new_pkg = {
        'best_state': pkg['ema']['state']['model'],
        'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
    }
    out_folder = Path(out_folder)
    out_folder.mkdir(parents=True, exist_ok=True)
    out_file = out_folder / f'{sig}.th'
    torch.save(new_pkg, out_file)
    return out_file
def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]) -> Path:
    """Export an LM training checkpoint to a lightweight release file.

    The release package keeps only ``best_state`` (the ``model`` entry of the
    checkpoint's ``fsdp_best_state``) and the experiment config, cleaned by
    ``_clean_lm_cfg`` and serialized to YAML under ``xp.cfg``.

    Args:
        checkpoint_path: Path to the training checkpoint. The name of its parent
            folder is used as the Dora signature and must be exactly 8 characters.
        out_folder: Folder in which ``<sig>.th`` is written; created if missing.

    Returns:
        Path of the written release checkpoint.

    Raises:
        ValueError: If the parent folder name is not 8 characters long.
    """
    checkpoint_path = Path(checkpoint_path)
    sig = checkpoint_path.parent.name
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(sig) != 8:
        raise ValueError("Not a valid Dora signature")
    # 'xp.cfg' must be an OmegaConf container (`_clean_lm_cfg` sets its struct
    # flag). Such objects are not allowlisted by the restricted loader that
    # torch >= 2.6 uses by default (`weights_only=True`), so request full
    # unpickling explicitly. That can execute arbitrary code: only export
    # checkpoints you trust.
    pkg = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    new_pkg = {
        'best_state': pkg['fsdp_best_state']['model'],
        'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg'])),
    }
    out_folder = Path(out_folder)
    out_folder.mkdir(parents=True, exist_ok=True)
    out_file = out_folder / f'{sig}.th'
    torch.save(new_pkg, out_file)
    return out_file