|
|
|
import tempfile |
|
from functools import partial |
|
|
|
import mmcv |
|
import numpy as np |
|
import pytest |
|
import torch |
|
from packaging import version |
|
|
|
from mmocr.core.deployment import (ONNXRuntimeDetector, ONNXRuntimeRecognizer, |
|
TensorRTDetector, TensorRTRecognizer) |
|
from mmocr.models import build_detector |
|
|
|
|
|
@pytest.mark.skipif(torch.__version__ == 'parrots', reason='skip parrots.') |
|
@pytest.mark.skipif( |
|
version.parse(torch.__version__) < version.parse('1.4.0'), |
|
reason='skip if torch=1.3.x') |
|
@pytest.mark.skipif( |
|
not torch.cuda.is_available(), reason='skip if on cpu device') |
|
def test_detector_wrapper(): |
|
try: |
|
import onnxruntime as ort |
|
import tensorrt as trt |
|
from mmcv.tensorrt import onnx2trt, save_trt_engine |
|
except ImportError: |
|
pytest.skip('ONNXRuntime or TensorRT is not available.') |
|
|
|
cfg = dict( |
|
model=dict( |
|
type='DBNet', |
|
backbone=dict( |
|
type='ResNet', |
|
depth=18, |
|
num_stages=4, |
|
out_indices=(0, 1, 2, 3), |
|
frozen_stages=-1, |
|
norm_cfg=dict(type='BN', requires_grad=True), |
|
init_cfg=dict( |
|
type='Pretrained', checkpoint='torchvision://resnet18'), |
|
norm_eval=False, |
|
style='caffe'), |
|
neck=dict( |
|
type='FPNC', |
|
in_channels=[64, 128, 256, 512], |
|
lateral_channels=256), |
|
bbox_head=dict( |
|
type='DBHead', |
|
text_repr_type='quad', |
|
in_channels=256, |
|
loss=dict(type='DBLoss', alpha=5.0, beta=10.0, |
|
bbce_loss=True)), |
|
train_cfg=None, |
|
test_cfg=None)) |
|
|
|
cfg = mmcv.Config(cfg) |
|
|
|
pytorch_model = build_detector(cfg.model, None, None) |
|
|
|
|
|
inputs = torch.rand(1, 3, 224, 224) |
|
img_metas = [{ |
|
'img_shape': [1, 3, 224, 224], |
|
'ori_shape': [1, 3, 224, 224], |
|
'pad_shape': [1, 3, 224, 224], |
|
'filename': None, |
|
'scale_factor': np.array([1, 1, 1, 1]) |
|
}] |
|
|
|
pytorch_model.forward = pytorch_model.forward_dummy |
|
with tempfile.TemporaryDirectory() as tmpdirname: |
|
onnx_path = f'{tmpdirname}/tmp.onnx' |
|
with torch.no_grad(): |
|
torch.onnx.export( |
|
pytorch_model, |
|
inputs, |
|
onnx_path, |
|
input_names=['input'], |
|
output_names=['output'], |
|
export_params=True, |
|
keep_initializers_as_inputs=False, |
|
verbose=False, |
|
opset_version=11) |
|
|
|
|
|
def get_GiB(x: int): |
|
"""return x GiB.""" |
|
return x * (1 << 30) |
|
|
|
trt_path = onnx_path.replace('.onnx', '.trt') |
|
min_shape = [1, 3, 224, 224] |
|
max_shape = [1, 3, 224, 224] |
|
|
|
opt_shape_dict = {'input': [min_shape, min_shape, max_shape]} |
|
max_workspace_size = get_GiB(1) |
|
trt_engine = onnx2trt( |
|
onnx_path, |
|
opt_shape_dict, |
|
log_level=trt.Logger.ERROR, |
|
fp16_mode=False, |
|
max_workspace_size=max_workspace_size) |
|
save_trt_engine(trt_engine, trt_path) |
|
print(f'Successfully created TensorRT engine: {trt_path}') |
|
|
|
wrap_onnx = ONNXRuntimeDetector(onnx_path, cfg, 0) |
|
wrap_trt = TensorRTDetector(trt_path, cfg, 0) |
|
|
|
assert isinstance(wrap_onnx, ONNXRuntimeDetector) |
|
assert isinstance(wrap_trt, TensorRTDetector) |
|
|
|
with torch.no_grad(): |
|
onnx_outputs = wrap_onnx.simple_test(inputs, img_metas, rescale=False) |
|
trt_outputs = wrap_onnx.simple_test(inputs, img_metas, rescale=False) |
|
|
|
assert isinstance(onnx_outputs[0], dict) |
|
assert isinstance(trt_outputs[0], dict) |
|
assert 'boundary_result' in onnx_outputs[0] |
|
assert 'boundary_result' in trt_outputs[0] |
|
|
|
|
|
@pytest.mark.skipif(torch.__version__ == 'parrots', reason='skip parrots.') |
|
@pytest.mark.skipif( |
|
version.parse(torch.__version__) < version.parse('1.4.0'), |
|
reason='skip if torch=1.3.x') |
|
@pytest.mark.skipif( |
|
not torch.cuda.is_available(), reason='skip if on cpu device') |
|
def test_recognizer_wrapper(): |
|
try: |
|
import onnxruntime as ort |
|
import tensorrt as trt |
|
from mmcv.tensorrt import onnx2trt, save_trt_engine |
|
except ImportError: |
|
pytest.skip('ONNXRuntime or TensorRT is not available.') |
|
|
|
cfg = dict( |
|
label_convertor=dict( |
|
type='CTCConvertor', |
|
dict_type='DICT36', |
|
with_unknown=False, |
|
lower=True), |
|
model=dict( |
|
type='CRNNNet', |
|
preprocessor=None, |
|
backbone=dict( |
|
type='VeryDeepVgg', leaky_relu=False, input_channels=1), |
|
encoder=None, |
|
decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True), |
|
loss=dict(type='CTCLoss'), |
|
label_convertor=dict( |
|
type='CTCConvertor', |
|
dict_type='DICT36', |
|
with_unknown=False, |
|
lower=True), |
|
pretrained=None), |
|
train_cfg=None, |
|
test_cfg=None) |
|
|
|
cfg = mmcv.Config(cfg) |
|
|
|
pytorch_model = build_detector(cfg.model, None, None) |
|
|
|
|
|
inputs = torch.rand(1, 1, 32, 32) |
|
img_metas = [{ |
|
'img_shape': [1, 1, 32, 32], |
|
'ori_shape': [1, 1, 32, 32], |
|
'pad_shape': [1, 1, 32, 32], |
|
'filename': None, |
|
'scale_factor': np.array([1, 1, 1, 1]) |
|
}] |
|
|
|
pytorch_model.forward = partial( |
|
pytorch_model.forward, |
|
img_metas=img_metas, |
|
return_loss=False, |
|
rescale=True) |
|
with tempfile.TemporaryDirectory() as tmpdirname: |
|
onnx_path = f'{tmpdirname}/tmp.onnx' |
|
with torch.no_grad(): |
|
torch.onnx.export( |
|
pytorch_model, |
|
inputs, |
|
onnx_path, |
|
input_names=['input'], |
|
output_names=['output'], |
|
export_params=True, |
|
keep_initializers_as_inputs=False, |
|
verbose=False, |
|
opset_version=11) |
|
|
|
|
|
def get_GiB(x: int): |
|
"""return x GiB.""" |
|
return x * (1 << 30) |
|
|
|
trt_path = onnx_path.replace('.onnx', '.trt') |
|
min_shape = [1, 1, 32, 32] |
|
max_shape = [1, 1, 32, 32] |
|
|
|
opt_shape_dict = {'input': [min_shape, min_shape, max_shape]} |
|
max_workspace_size = get_GiB(1) |
|
trt_engine = onnx2trt( |
|
onnx_path, |
|
opt_shape_dict, |
|
log_level=trt.Logger.ERROR, |
|
fp16_mode=False, |
|
max_workspace_size=max_workspace_size) |
|
save_trt_engine(trt_engine, trt_path) |
|
print(f'Successfully created TensorRT engine: {trt_path}') |
|
|
|
wrap_onnx = ONNXRuntimeRecognizer(onnx_path, cfg, 0) |
|
wrap_trt = TensorRTRecognizer(trt_path, cfg, 0) |
|
|
|
assert isinstance(wrap_onnx, ONNXRuntimeRecognizer) |
|
assert isinstance(wrap_trt, TensorRTRecognizer) |
|
|
|
with torch.no_grad(): |
|
onnx_outputs = wrap_onnx.simple_test(inputs, img_metas, rescale=False) |
|
trt_outputs = wrap_onnx.simple_test(inputs, img_metas, rescale=False) |
|
|
|
assert isinstance(onnx_outputs[0], dict) |
|
assert isinstance(trt_outputs[0], dict) |
|
assert 'text' in onnx_outputs[0] |
|
assert 'text' in trt_outputs[0] |
|
|