Spaces:
Sleeping
Sleeping
| # Copyright (c) 2023-2024, Zexin He | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # https://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import os | |
| import argparse | |
| from omegaconf import OmegaConf | |
| import torch.nn as nn | |
| from accelerate import Accelerator | |
| import safetensors | |
| import sys | |
| sys.path.append(".") | |
| from openlrm.utils.hf_hub import wrap_model_hub | |
| from openlrm.models import model_dict | |
def auto_load_model(cfg, model: nn.Module) -> int:
    """Load checkpoint weights into ``model`` and return the checkpoint's step.

    Weights are read from
    ``<saver.checkpoint_root>/<experiment.parent>/<experiment.child>/<step>/model.safetensors``.
    Only sub-directories whose names consist solely of decimal digits are
    considered checkpoints.

    Args:
        cfg: Attribute-style config providing ``saver.checkpoint_root``,
            ``experiment.parent``, ``experiment.child`` and
            ``convert.global_step``. When ``convert.global_step`` is ``None``
            the numerically largest step directory is used; otherwise that
            step is loaded (an unpadded value such as ``100`` also matches a
            zero-padded directory such as ``"000100"``).
        model: Module whose weights are loaded from the safetensors file.

    Returns:
        The integer step of the loaded checkpoint.

    Raises:
        FileNotFoundError: If the checkpoint root does not exist, contains no
            numeric step directory, the requested step is absent, or the
            step directory has no ``model.safetensors`` file.
        ValueError: If ``convert.global_step`` cannot be parsed as an int.
    """
    ckpt_root = os.path.join(
        cfg.saver.checkpoint_root,
        cfg.experiment.parent, cfg.experiment.child,
    )
    if not os.path.exists(ckpt_root):
        raise FileNotFoundError(f"Checkpoint root not found: {ckpt_root}")

    # Map step number -> directory name. Anything that is not an all-digit
    # directory (log files, hidden files, other folders) is ignored so it can
    # neither be chosen as "latest" nor crash int(). Iterating in sorted order
    # makes the choice deterministic if two spellings of one step exist.
    step_dirs = {}
    for name in sorted(os.listdir(ckpt_root)):
        if name.isdecimal() and os.path.isdir(os.path.join(ckpt_root, name)):
            step_dirs.setdefault(int(name), name)
    if not step_dirs:
        raise FileNotFoundError(f"No checkpoint found in {ckpt_root}")

    requested = cfg.convert.global_step
    if requested is None:
        # Numeric max, so unpadded names such as "900" and "1000" order correctly.
        load_step = max(step_dirs)
        load_dir = step_dirs[load_step]
    else:
        load_dir = f"{requested}"
        load_step = int(load_dir)
        # Prefer the directory name exactly as given; otherwise accept any
        # zero-padded spelling of the same step (e.g. 100 -> "000100").
        if not os.path.isdir(os.path.join(ckpt_root, load_dir)):
            if load_step not in step_dirs:
                raise FileNotFoundError(
                    f"Checkpoint for step {requested} not found in {ckpt_root}"
                )
            load_dir = step_dirs[load_step]

    load_model_path = os.path.join(ckpt_root, load_dir, 'model.safetensors')
    if not os.path.isfile(load_model_path):
        raise FileNotFoundError(f"Checkpoint weights not found: {load_model_path}")
    print(f"Loading from {load_model_path}")
    # `import safetensors` does not import the `safetensors.torch` submodule;
    # import it explicitly instead of relying on another package having done so.
    from safetensors.torch import load_model
    load_model(model, load_model_path)
    return load_step
if __name__ == '__main__':
    # Export a trained checkpoint as a locally saved hub-style model.
    # Extra `key=value` CLI arguments are parsed by OmegaConf.from_cli and
    # merged on top of the YAML config loaded from --config.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='./assets/config.yaml')
    args, unknown = parser.parse_known_args()
    cfg = OmegaConf.load(args.config)
    cli_cfg = OmegaConf.from_cli(unknown)
    cfg = OmegaConf.merge(cfg, cli_cfg)

    # Expected [cfg.convert] keys:
    #   global_step: int | None -- step to export; None lets auto_load_model
    #                              pick the last checkpoint directory.
    #   save_dir: str           -- NOTE(review): documented but never read
    #                              here; output always goes under ./exps/releases.

    # NOTE(review): `accelerator` is not used below; kept in case constructing
    # Accelerator has side effects the export relies on — confirm before removing.
    accelerator = Accelerator()
    hf_model_cls = wrap_model_hub(model_dict[cfg.experiment.type])
    # dict() converts only the top level of cfg.model; nested values stay as
    # OmegaConf containers.
    hf_model = hf_model_cls(dict(cfg.model))
    loaded_step = auto_load_model(cfg, hf_model)

    # Output layout: ./exps/releases/<parent>/<child>/step_<6-digit step>
    dump_path = os.path.join(
        "./exps/releases",
        cfg.experiment.parent, cfg.experiment.child,
        f'step_{loaded_step:06d}',
    )
    print(f"Saving locally to {dump_path}")
    os.makedirs(dump_path, exist_ok=True)
    hf_model.save_pretrained(
        save_directory=dump_path,
        config=hf_model.config,
    )