|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import importlib |
|
import inspect |
|
import os |
|
import traceback |
|
|
|
import torch |
|
import wrapt |
|
|
|
from nemo.core import Model |
|
from nemo.utils import model_utils |
|
|
|
# Domain names accepted by the `--domain` command-line option.
DOMAINS = ['asr', 'tts', 'nlp']
|
|
|
|
|
def process_args():
    """Parse the command-line arguments of this script.

    Returns:
        argparse.Namespace: Parsed arguments. ``domain`` holds one of the
        values listed in ``DOMAINS``.
    """
    parser = argparse.ArgumentParser()
    # The domain is mandatory: without it there is nothing to check, so let
    # argparse reject the call with a usage message instead of handing
    # `None` to the rest of the script.
    parser.add_argument(
        '-d',
        '--domain',
        choices=DOMAINS,
        type=str,
        required=True,
        help="Collection domain whose model imports should be checked.",
    )

    return parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
def _build_import_path(domain, subdomains: list, imp):
    """Build the dotted path ``nemo.collections.<domain>.<subdomains...>.<imp>``.

    Args:
        domain: Name placed right after ``nemo.collections``.
        subdomains: Path components inserted between the domain and ``imp``.
        imp: Final component of the path.

    Returns:
        str: All components joined with ``.``.
    """
    return ".".join(["nemo", "collections", domain, *subdomains, imp])
|
|
|
|
|
def _get_class_from_path(domain, subdomains, imp):
    """Resolve one attribute by dotted path and classify what was found.

    Args:
        domain: Domain name used to build the dotted path.
        subdomains: Sub-package names used to build the dotted path.
        imp: Attribute name to resolve.

    Returns:
        tuple: ``(class_, result, error)`` where

        * ``class_`` is the resolved object if it is a class, else ``None``;
        * ``result`` is that class if it subclasses ``Model`` or
          ``torch.nn.Module``, else ``None``;
        * ``error`` is the formatted traceback if anything raised, else
          ``None``.
    """
    path = _build_import_path(domain, subdomains, imp)

    class_ = None
    result = None
    error = None

    try:
        class_ = model_utils.import_class_by_path(path)

        if not inspect.isclass(class_):
            # Functions, constants, sub-modules, ... are not reported at all.
            class_ = None
        else:
            # A class wrapped by a wrapt decorator is unwrapped so the
            # subclass check sees the underlying class.
            if isinstance(class_, wrapt.FunctionWrapper):
                class_ = class_.__wrapped__

            if issubclass(class_, (Model, torch.nn.Module)):
                result = class_

    except Exception:
        error = traceback.format_exc()

    return class_, result, error
|
|
|
|
|
def _test_domain_module_imports(module, domain, subdomains: list):
    """Check every attribute of a (sub)module and print a report.

    When ``subdomains`` is non-empty, the sub-module named by
    ``module.__path__[0]`` plus the subdomains is imported first. Every name
    returned by ``dir()`` on the resulting module is then passed to
    ``_get_class_from_path``.

    Args:
        module: Imported package used as the starting point.
        domain: Domain name forwarded to ``_get_class_from_path``.
        subdomains: Sub-package names below ``module`` to descend into.

    Returns:
        bool: ``True`` if no traceback was collected, ``False`` otherwise.
    """
    imported_classes = []
    ignored_classes = []
    tracebacks = []

    target = module
    import_succeeded = True

    if subdomains:
        # Turn the package directory into a dotted path that starts at the
        # last occurrence of "nemo" in it.
        package_dir = module.__path__[0]
        dotted_root = package_dir[package_dir.rfind("nemo") :].replace(os.path.sep, ".")
        new_path = '.'.join([dotted_root, *subdomains])

        try:
            target = importlib.import_module(new_path)
        except Exception:
            print(f"Could not import `{new_path}` ; Traceback below :")
            tracebacks.append(traceback.format_exc())
            import_succeeded = False

    if import_succeeded:
        for imp in dir(target):
            class_, result, error = _get_class_from_path(domain, subdomains, imp)

            if result is not None:
                imported_classes.append(class_)
            elif class_ is not None:
                ignored_classes.append(class_)

            if error is not None:
                tracebacks.append(error)

    for imported in imported_classes:
        print("Module successfully imported :", imported)

    print()
    for ignored in ignored_classes:
        print("Module did not match a valid signature of NeMo Model (hence ignored):", ignored)

    print()
    if tracebacks:
        print("Imports crashed with following traceback !")

        separator = "*" * 100
        for trace in tracebacks:
            print(separator)
            print()
            print(trace)
            print()
            print(separator)
            print()

    return not tracebacks
|
|
|
|
|
|
|
|
|
|
|
def test_domain_asr(args):
    """Check the imports of the ASR collection's ``models`` package.

    Args:
        args: Parsed command-line arguments; ``args.domain`` is forwarded as
            the domain name.

    Raises:
        SystemExit: With exit status 1 if the import check reports a failure.
    """
    # Imported lazily so that only the requested collection gets loaded.
    import nemo.collections.asr as nemo_asr

    all_passed = _test_domain_module_imports(nemo_asr, domain=args.domain, subdomains=['models'])

    if not all_passed:
        # Raise SystemExit directly: the `exit` helper is installed by the
        # `site` module for interactive use and is not guaranteed to exist.
        raise SystemExit(1)
|
|
|
|
|
def test_domain_nlp(args):
    """Check the imports of the NLP collection, including its Megatron models.

    Every check is always executed, even after an earlier one has failed, so
    that the printed report covers all modules.

    Args:
        args: Parsed command-line arguments; ``args.domain`` is forwarded as
            the domain name.

    Raises:
        SystemExit: With exit status 1 if any import check reports a failure.
    """
    # Imported lazily so that only the requested collection gets loaded.
    import nemo.collections.nlp as nemo_nlp

    all_passed = _test_domain_module_imports(nemo_nlp, domain=args.domain, subdomains=['models'])

    # Modules under `models.language_modeling` that are checked one by one.
    # Each module is listed exactly once; checking the same module twice only
    # duplicates the printed report.
    megatron_modules = (
        'megatron_base_model',
        'megatron_bert_model',
        'megatron_glue_model',
        'megatron_gpt_model',
        'megatron_lm_encoder_decoder_model',
        'megatron_gpt_prompt_learning_model',
        'megatron_t5_model',
    )

    for megatron_module in megatron_modules:
        # Run the check before combining so a previous failure never
        # short-circuits the remaining checks.
        passed = _test_domain_module_imports(
            nemo_nlp, domain=args.domain, subdomains=['models', 'language_modeling', megatron_module]
        )
        all_passed = passed and all_passed

    if not all_passed:
        # Raise SystemExit directly: the `exit` helper is installed by the
        # `site` module for interactive use and is not guaranteed to exist.
        raise SystemExit(1)
|
|
|
|
|
def test_domain_tts(args):
    """Check the imports of the TTS collection's ``models`` package.

    Args:
        args: Parsed command-line arguments; ``args.domain`` is forwarded as
            the domain name.

    Raises:
        SystemExit: With exit status 1 if the import check reports a failure.
    """
    # Imported lazily so that only the requested collection gets loaded.
    import nemo.collections.tts as nemo_tts

    all_passed = _test_domain_module_imports(nemo_tts, domain=args.domain, subdomains=['models'])

    if not all_passed:
        # Raise SystemExit directly: the `exit` helper is installed by the
        # `site` module for interactive use and is not guaranteed to exist.
        raise SystemExit(1)
|
|
|
|
|
|
|
|
|
|
|
def test_domain(args):
    """Dispatch to the import check that matches ``args.domain``.

    Args:
        args: Parsed command-line arguments with a ``domain`` attribute.

    Raises:
        RuntimeError: If ``args.domain`` is not ``'asr'``, ``'nlp'`` or
            ``'tts'``.
    """
    domain = args.domain

    if domain == 'asr':
        test_domain_asr(args)
        return

    if domain == 'nlp':
        test_domain_nlp(args)
        return

    if domain == 'tts':
        test_domain_tts(args)
        return

    raise RuntimeError(f"Cannot resolve domain : {domain}")
|
|
|
|
|
def run_checks():
    """Parse the command line and run the import check for the chosen domain."""
    test_domain(process_args())
|
|
|
|
|
# Run the checks only when executed as a script, not when imported.
if __name__ == '__main__':
    run_checks()
|
|