# NeMo / tests/core_ptl/check_imports.py
# camenduru's picture
# thanks to NVIDIA ❤
# 7934b29
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import importlib
import inspect
import os
import sys
import traceback

import torch
import wrapt

from nemo.core import Model
from nemo.utils import model_utils
# Collection domains accepted by the ``--domain`` command-line option.
DOMAINS = [
    'asr',
    'tts',
    'nlp',
]
def process_args(argv=None):
    """Parse the command-line arguments of the import checker.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the previous behavior.

    Returns:
        ``argparse.Namespace`` whose ``domain`` attribute is one of ``DOMAINS``.
    """
    parser = argparse.ArgumentParser()
    # The domain is mandatory: without it `test_domain` can only fail afterwards with
    # "Cannot resolve domain : None", so let argparse reject the call with a usage message.
    parser.add_argument('-d', '--domain', choices=DOMAINS, type=str, required=True)
    args = parser.parse_args(argv)
    return args
###############################
def _build_import_path(domain, subdomains: list, imp):
    """Return the dotted path ``nemo.collections.<domain>[.<subdomain>...].<imp>``."""
    return ".".join(["nemo", "collections", domain, *subdomains, imp])
def _get_class_from_path(domain, subdomains, imp):
    """Import one attribute of a NeMo collection and classify it.

    The object at ``nemo.collections.<domain>.<subdomains...>.<imp>`` is fetched with
    ``model_utils.import_class_by_path``.

    Returns:
        Tuple ``(class_, result, error)``:
            class_: the imported object if it is a class (unwrapped when it is a
                ``wrapt.FunctionWrapper``), otherwise ``None``.
            result: the class when it subclasses ``Model`` or ``torch.nn.Module``,
                otherwise ``None``.
            error: formatted traceback if importing or checking raised, otherwise ``None``.
                On error, ``class_`` holds whatever had been resolved before the exception.
    """
    target_path = _build_import_path(domain, subdomains, imp)

    candidate = None
    model_cls = None
    try:
        candidate = model_utils.import_class_by_path(target_path)
        if not inspect.isclass(candidate):
            # Non-class attributes (e.g. functions or constants) are not model candidates.
            candidate = None
        else:
            # A class decorated with a wrapt decorator at class level is a FunctionWrapper
            # proxy; unwrap it so the subclass check operates on the real class.
            if isinstance(candidate, wrapt.FunctionWrapper):
                candidate = candidate.__wrapped__
            if issubclass(candidate, (Model, torch.nn.Module)):
                model_cls = candidate
    except Exception:
        return candidate, model_cls, traceback.format_exc()
    return candidate, model_cls, None
def _resolve_submodule(module, subdomains: list):
    """Import ``<module>.<subdomains...>`` and return ``(submodule, error)``.

    With no subdomains, ``module`` itself is returned. If the import raises, the
    submodule is ``None`` and ``error`` is the formatted traceback.
    """
    if not subdomains:
        return module, None
    # Build the dotted name from the package's own ``__name__`` instead of slicing its
    # filesystem path at the last "nemo" occurrence; that slicing yields a garbage module
    # name whenever the path has no "nemo" component (rfind returns -1).
    new_path = '.'.join([module.__name__, *subdomains])
    try:
        return importlib.import_module(new_path), None
    except Exception:
        print(f"Could not import `{new_path}` ; Traceback below :")
        return None, traceback.format_exc()


def _scan_module_classes(module, domain, subdomains: list):
    """Classify every name in ``dir(module)`` with ``_get_class_from_path``.

    Returns:
        ``(module_list, failed_list, error_list)``: classes that are NeMo models / torch
        modules, classes that are not, and tracebacks of names whose check raised.
    """
    module_list, failed_list, error_list = [], [], []
    for imp in dir(module):
        class_, result, error = _get_class_from_path(domain, subdomains, imp)
        if result is not None:
            module_list.append(class_)
        elif class_ is not None:
            failed_list.append(class_)
        if error is not None:
            error_list.append(error)
    return module_list, failed_list, error_list


def _print_import_report(module_list, failed_list, error_list):
    """Print imported classes, ignored classes, and any crash tracebacks."""
    for cls in module_list:
        print("Module successfully imported :", cls)
    print()

    for cls in failed_list:
        print("Module did not match a valid signature of NeMo Model (hence ignored):", cls)
    print()

    if error_list:
        print("Imports crashed with following traceback !")
        for error in error_list:
            print("*" * 100)
            print()
            print(error)
            print()
            print("*" * 100)
            print()


def _test_domain_module_imports(module, domain, subdomains: list):
    """Import ``<module>.<subdomains...>`` and check every class it exposes.

    Args:
        module: Top-level collection package, e.g. ``nemo.collections.asr``.
        domain: Collection name used to build the per-attribute import paths.
        subdomains: Dotted-path components below ``module`` to scan (may be empty).

    Returns:
        ``True`` if neither the submodule import nor any attribute check raised,
        otherwise ``False``.
    """
    target, import_error = _resolve_submodule(module, subdomains)
    if import_error is None:
        module_list, failed_list, error_list = _scan_module_classes(target, domain, subdomains)
    else:
        # The submodule itself failed to import: nothing to scan, report that crash only.
        module_list, failed_list, error_list = [], [], [import_error]

    _print_import_report(module_list, failed_list, error_list)
    return not error_list
###############################
def test_domain_asr(args):
    """Check that the classes exposed by ``nemo.collections.asr.models`` import cleanly.

    Exits the process with status 1 if any import crashed.
    """
    import nemo.collections.asr as nemo_asr

    all_passed = _test_domain_module_imports(nemo_asr, domain=args.domain, subdomains=['models'])
    if not all_passed:
        # The builtin `exit` is injected by the `site` module and is not guaranteed to
        # exist (e.g. under `python -S`); `sys.exit` always is.
        sys.exit(1)
def test_domain_nlp(args):
    """Check that the NLP model classes, including the Megatron models, import cleanly.

    Every check runs even after an earlier one fails, so all failures are reported in a
    single pass. Exits the process with status 1 if any import crashed.
    """
    # If even this fails, just fail entirely.
    import nemo.collections.nlp as nemo_nlp

    # Megatron model modules under ``nemo.collections.nlp.models.language_modeling``.
    # (The original list checked `megatron_t5_model` twice; the duplicate is dropped.)
    megatron_modules = (
        'megatron_base_model',
        'megatron_bert_model',
        'megatron_glue_model',
        'megatron_gpt_model',
        'megatron_lm_encoder_decoder_model',
        'megatron_gpt_prompt_learning_model',
        'megatron_t5_model',
    )
    # Basic NLP test first, then each Megatron module.
    subdomain_sets = [['models']] + [['models', 'language_modeling', name] for name in megatron_modules]

    # Materialize all results before aggregating so no check is short-circuited away.
    results = [
        _test_domain_module_imports(nemo_nlp, domain=args.domain, subdomains=subdomains)
        for subdomains in subdomain_sets
    ]
    if not all(results):
        # `sys.exit` rather than the `site`-injected builtin `exit`, which may be absent.
        sys.exit(1)
def test_domain_tts(args):
    """Check that the classes exposed by ``nemo.collections.tts.models`` import cleanly.

    Exits the process with status 1 if any import crashed.
    """
    import nemo.collections.tts as nemo_tts

    all_passed = _test_domain_module_imports(nemo_tts, domain=args.domain, subdomains=['models'])
    if not all_passed:
        # The builtin `exit` is injected by the `site` module and is not guaranteed to
        # exist (e.g. under `python -S`); `sys.exit` always is.
        sys.exit(1)
###############################
def test_domain(args):
    """Dispatch to the import checks for ``args.domain``.

    Raises:
        RuntimeError: if ``args.domain`` matches none of the known collections.
    """
    domain = args.domain
    runners = (
        ('asr', test_domain_asr),
        ('nlp', test_domain_nlp),
        ('tts', test_domain_tts),
    )
    for name, runner in runners:
        if domain == name:
            runner(args)
            return
    raise RuntimeError(f"Cannot resolve domain : {domain}")
def run_checks():
    """Parse the command line and run the import checks for the requested domain."""
    test_domain(process_args())
if __name__ == "__main__":
    # Script entry point, e.g. `python check_imports.py --domain nlp`.
    run_checks()