Spaces:
				
			
			
	
			
			
		Paused
		
	
	
	
			
			
	
	
	
	
		
		
		Paused
		
	| # This module is from [WeNet](https://github.com/wenet-e2e/wenet). | |
| # ## Citations | |
| # ```bibtex | |
| # @inproceedings{yao2021wenet, | |
| # title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit}, | |
| # author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin}, | |
| # booktitle={Proc. Interspeech}, | |
| # year={2021}, | |
| # address={Brno, Czech Republic }, | |
| # organization={IEEE} | |
| # } | |
| # @article{zhang2022wenet, | |
| # title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit}, | |
| # author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei}, | |
| # journal={arXiv preprint arXiv:2203.15455}, | |
| # year={2022} | |
| # } | |
| # | |
| import logging | |
| import os | |
| import re | |
| import yaml | |
| import torch | |
| from collections import OrderedDict | |
| import datetime | |
def load_checkpoint(model: torch.nn.Module, path: str) -> dict:
    """Load weights from ``path`` into ``model`` and return its saved metadata.

    The state dict is applied with ``strict=False``, so missing or unexpected
    keys do not raise. Metadata is read from a YAML file that sits next to
    the checkpoint: ``path`` with a trailing ``.pt`` replaced by ``.yaml``,
    or ``path + ".yaml"`` when ``path`` does not end in ``.pt``.

    Args:
        model: Module whose parameters are overwritten in place.
        path: Location of the serialized state dict.

    Returns:
        The parsed metadata, or an empty dict when the metadata file does not
        exist or is empty.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints that
    # come from a trusted source.
    if torch.cuda.is_available():
        logging.info("Checkpoint: loading from checkpoint %s for GPU", path)
        checkpoint = torch.load(path)
    else:
        logging.info("Checkpoint: loading from checkpoint %s for CPU", path)
        checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint, strict=False)

    # Match the literal ".pt" suffix only. The previous pattern ".pt$" let "."
    # match any character, and for a path without the suffix it produced the
    # checkpoint path itself, so the binary checkpoint was parsed as YAML.
    if path.endswith(".pt"):
        info_path = path[: -len(".pt")] + ".yaml"
    else:
        info_path = path + ".yaml"

    if not os.path.exists(info_path):
        return {}

    with open(info_path, "r") as fin:
        # NOTE(review): FullLoader is kept for compatibility with existing
        # metadata files; consider yaml.safe_load if files may be untrusted.
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    # An empty YAML document parses to None; keep the declared dict contract.
    return {} if configs is None else configs
def save_checkpoint(model: torch.nn.Module, path: str, infos=None):
    """Save the model's state dict to ``path`` plus a YAML metadata file.

    ``DataParallel`` and ``DistributedDataParallel`` wrappers are unwrapped
    so that the saved keys carry no ``module.`` prefix. The metadata file is
    ``path`` with a trailing ``.pt`` replaced by ``.yaml``, or
    ``path + ".yaml"`` when ``path`` does not end in ``.pt``.

    Args:
        model: Module (optionally wrapped) whose weights are saved.
        path: Destination of the serialized state dict.
        infos (dict or None): any info you want to save. The mapping is
            copied; a ``save_time`` entry is added to the copy only.
    """
    logging.info("Checkpoint: save to checkpoint %s", path)
    parallel_wrappers = (
        torch.nn.DataParallel,
        torch.nn.parallel.DistributedDataParallel,
    )
    bare_model = model.module if isinstance(model, parallel_wrappers) else model
    torch.save(bare_model.state_dict(), path)

    # Match the literal ".pt" suffix only. The previous pattern ".pt$" left a
    # path without the suffix unchanged, so the metadata was written over the
    # checkpoint that had just been saved.
    if path.endswith(".pt"):
        info_path = path[: -len(".pt")] + ".yaml"
    else:
        info_path = path + ".yaml"

    # Work on a copy so the caller's dict is not mutated.
    infos = {} if infos is None else dict(infos)
    infos["save_time"] = datetime.datetime.now().strftime("%d/%m/%Y %H:%M:%S")
    with open(info_path, "w") as fout:
        fout.write(yaml.dump(infos))
def filter_modules(model_state_dict, modules):
    """Keep the requested module prefixes that occur in the state dict.

    A prefix is kept when at least one key of ``model_state_dict`` starts
    with it. Prefixes that match no key are reported through warning logs
    together with the available keys.

    Args:
        model_state_dict: Mapping whose keys are checked against the prefixes.
        modules: Iterable of prefix strings to look for.

    Returns:
        List of the prefixes that matched, in their original order.
    """
    available_keys = model_state_dict.keys()
    matched, unmatched = [], []

    for prefix in modules:
        found = any(name.startswith(prefix) for name in available_keys)
        (matched if found else unmatched).append(prefix)

    if unmatched:
        logging.warning(
            "module(s) %s don't match or (partially match) "
            "available modules in model.",
            unmatched,
        )
        logging.warning("for information, the existing modules in model are:")
        logging.warning("%s", available_keys)

    return matched
def load_trained_modules(model: torch.nn.Module, args: None):
    """Initialise selected parts of ``model`` from a pre-trained checkpoint.

    Reads the checkpoint path from ``args.enc_init`` and the key prefixes to
    transfer from ``args.enc_init_mods``. Entries of the checkpoint whose key
    starts with one of the prefixes accepted by ``filter_modules`` overwrite
    the matching entries of the model's own state dict, which is then loaded
    back into the model. When the checkpoint file does not exist a warning
    is logged and the model's unchanged state dict is reloaded.

    Args:
        model: Module to initialise in place.
        args: Object exposing ``enc_init`` and ``enc_init_mods`` attributes
            (presumably an argparse namespace; verify against the caller).

    Returns:
        Always an empty dict.
    """
    # Load encoder modules with pre-trained model(s).
    pretrained_path = args.enc_init
    wanted_prefixes = args.enc_init_mods
    merged_state = model.state_dict()
    logging.warning("model(s) found for pre-initialization")

    if not os.path.isfile(pretrained_path):
        logging.warning("model was not found : %s", pretrained_path)
    else:
        logging.info(
            "Checkpoint: loading from checkpoint %s for CPU" % pretrained_path
        )
        pretrained_state = torch.load(pretrained_path, map_location="cpu")
        accepted = filter_modules(pretrained_state, wanted_prefixes)
        # Keep only the pre-trained entries that fall under an accepted prefix.
        transferred = OrderedDict(
            (name, tensor)
            for name, tensor in pretrained_state.items()
            if any(name.startswith(prefix) for prefix in accepted)
        )
        merged_state.update(transferred)

    model.load_state_dict(merged_state)
    return {}
 
			
