# NOTE: the "Spaces: Runtime error" header below this file's original URL is
# Hugging Face Spaces page residue from scraping — it is not part of the
# module's source code.
| import warnings | |
| from collections import namedtuple | |
| from functools import partial | |
| from pathlib import Path | |
| from typing import List, Optional, Union | |
| import numpy as np | |
| import onnxruntime | |
| try: | |
| import tensorrt as trt | |
| except Exception: | |
| trt = None | |
| import torch | |
| warnings.filterwarnings(action='ignore', category=DeprecationWarning) | |
class TRTWrapper(torch.nn.Module):
    """Run a serialized TensorRT engine (``.engine``/``.plan``) as a torch Module.

    The engine is deserialized once at construction time; ``forward`` feeds
    torch CUDA tensors to the engine via raw device pointers and returns the
    outputs as torch tensors on ``self.device``.  Execution happens on a
    dedicated CUDA stream which is synchronized before returning.

    NOTE(review): this uses the TensorRT 8.x binding API
    (``get_binding_name`` / ``binding_is_input`` / ``execute_async_v2``),
    which was removed in TensorRT 10 — confirm the targeted TRT version.
    """

    # TensorRT dtype -> torch dtype lookup table.  It is a class-level dict
    # filled lazily by __update_mapping, because ``trt`` may be None at
    # import time (the top-level import is wrapped in try/except).
    dtype_mapping = {}

    def __init__(self, weight: Union[str, Path],
                 device: Optional[torch.device]):
        """
        Args:
            weight: path to a serialized engine; must exist and end in
                ``.engine`` or ``.plan``.
            device: target CUDA device.  A ``torch.device``, a device
                string (e.g. ``'cuda:0'``) or a bare int index are accepted.
        """
        super().__init__()
        weight = Path(weight) if isinstance(weight, str) else weight
        assert weight.exists() and weight.suffix in ('.engine', '.plan')
        if isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device(f'cuda:{device}')
        self.weight = weight
        self.device = device
        # Private stream so inference does not serialize with the default
        # stream's work.
        self.stream = torch.cuda.Stream(device=device)
        self.__update_mapping()
        self.__init_engine()
        self.__init_bindings()

    def __update_mapping(self):
        """Populate the trt->torch dtype table (requires tensorrt)."""
        self.dtype_mapping.update({
            trt.bool: torch.bool,
            trt.int8: torch.int8,
            trt.int32: torch.int32,
            trt.float16: torch.float16,
            trt.float32: torch.float32
        })

    def __init_engine(self):
        """Deserialize the engine and cache binding bookkeeping."""
        logger = trt.Logger(trt.Logger.ERROR)
        self.log = partial(logger.log, trt.Logger.ERROR)
        trt.init_libnvinfer_plugins(logger, namespace='')
        self.logger = logger
        with trt.Runtime(logger) as runtime:
            model = runtime.deserialize_cuda_engine(self.weight.read_bytes())
        context = model.create_execution_context()
        names = [model.get_binding_name(i) for i in range(model.num_bindings)]
        num_inputs, num_outputs = 0, 0
        for i in range(model.num_bindings):
            if model.binding_is_input(i):
                num_inputs += 1
            else:
                num_outputs += 1
        # BUGFIX: the engine is dynamic if ANY input binding contains a -1
        # dimension; the original only inspected binding 0, misclassifying
        # multi-input engines whose first input happens to be static.
        # (Input bindings occupy the first ``num_inputs`` slots, which
        # __init_bindings asserts.)
        self.is_dynamic = any(
            -1 in tuple(model.get_binding_shape(i))
            for i in range(num_inputs))
        self.model = model
        self.context = context
        self.input_names = names[:num_inputs]
        self.output_names = names[num_inputs:]
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.num_bindings = num_inputs + num_outputs
        # Device pointers handed to execute_async_v2; refreshed every call.
        self.bindings: List[int] = [0] * self.num_bindings

    def __init_bindings(self):
        """Record (name, dtype, shape) per binding; pre-allocate outputs
        when every shape is static so forward() can reuse the buffers."""
        Binding = namedtuple('Binding', ('name', 'dtype', 'shape'))
        inputs_info = []
        outputs_info = []
        for i, name in enumerate(self.input_names):
            assert self.model.get_binding_name(i) == name
            dtype = self.dtype_mapping[self.model.get_binding_dtype(i)]
            shape = tuple(self.model.get_binding_shape(i))
            inputs_info.append(Binding(name, dtype, shape))
        for i, name in enumerate(self.output_names):
            i += self.num_inputs
            assert self.model.get_binding_name(i) == name
            dtype = self.dtype_mapping[self.model.get_binding_dtype(i)]
            shape = tuple(self.model.get_binding_shape(i))
            outputs_info.append(Binding(name, dtype, shape))
        self.inputs_info = inputs_info
        self.outputs_info = outputs_info
        if not self.is_dynamic:
            # Static shapes: one reusable buffer per output.  Note the
            # buffers are shared across calls, so callers must consume or
            # copy results before the next forward().
            self.output_tensor = [
                torch.empty(o.shape, dtype=o.dtype, device=self.device)
                for o in outputs_info
            ]

    def forward(self, *inputs):
        """Run inference.

        Args:
            *inputs: one torch tensor per engine input binding, in binding
                order; made contiguous before their pointers are bound.

        Returns:
            tuple of output tensors on ``self.device``, in binding order.
        """
        assert len(inputs) == self.num_inputs
        contiguous_inputs: List[torch.Tensor] = [
            i.contiguous() for i in inputs
        ]
        for i in range(self.num_inputs):
            self.bindings[i] = contiguous_inputs[i].data_ptr()
            if self.is_dynamic:
                # Dynamic engines need the concrete shape before output
                # shapes can be queried from the context.
                self.context.set_binding_shape(
                    i, tuple(contiguous_inputs[i].shape))
        # create output tensors
        outputs: List[torch.Tensor] = []
        for i in range(self.num_outputs):
            j = i + self.num_inputs
            if self.is_dynamic:
                shape = tuple(self.context.get_binding_shape(j))
                # BUGFIX: the original read ``self.output_dtypes[i]``, an
                # attribute that is never defined anywhere in the class, so
                # every dynamic-shape engine raised AttributeError here.
                # The dtype was recorded per output in __init_bindings.
                output = torch.empty(
                    size=shape,
                    dtype=self.outputs_info[i].dtype,
                    device=self.device)
            else:
                output = self.output_tensor[i]
            outputs.append(output)
            self.bindings[j] = output.data_ptr()
        self.context.execute_async_v2(self.bindings, self.stream.cuda_stream)
        # Block until the engine has finished writing the output buffers.
        self.stream.synchronize()
        return tuple(outputs)
class ORTWrapper(torch.nn.Module):
    """torch.nn.Module facade over an onnxruntime ``InferenceSession``.

    Accepts torch tensors in ``forward``, moves them to numpy on the CPU,
    runs the ONNX session, and returns torch tensors on ``self.device``.
    """

    def __init__(self, weight: Union[str, Path],
                 device: Optional[torch.device]):
        """
        Args:
            weight: path to an ``.onnx`` model file; must exist.
            device: device the returned tensors are placed on.  A
                ``torch.device``, a device string or an int index work.
        """
        super().__init__()
        if isinstance(weight, str):
            weight = Path(weight)
        assert weight.exists() and weight.suffix == '.onnx'
        if isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device(f'cuda:{device}')
        self.weight = weight
        self.device = device
        self.__init_session()
        self.__init_bindings()

    def __init_session(self):
        """Create the InferenceSession, preferring CUDA on GPU devices."""
        if 'cuda' in self.device.type:
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        else:
            providers = ['CPUExecutionProvider']
        self.session = onnxruntime.InferenceSession(
            str(self.weight), providers=providers)

    def __init_bindings(self):
        """Cache (name, dtype, shape) metadata for all model I/O and flag
        whether any input dimension is symbolic (dynamic)."""
        Binding = namedtuple('Binding', ('name', 'dtype', 'shape'))
        self.is_dynamic = False
        inputs_info = []
        for meta in self.session.get_inputs():
            shape = tuple(meta.shape)
            # Symbolic dims show up as strings/None rather than ints.
            if not all(isinstance(dim, int) for dim in shape):
                self.is_dynamic = True
            inputs_info.append(Binding(meta.name, meta.type, shape))
        outputs_info = [
            Binding(meta.name, meta.type, tuple(meta.shape))
            for meta in self.session.get_outputs()
        ]
        self.inputs_info = inputs_info
        self.outputs_info = outputs_info
        self.num_inputs = len(inputs_info)

    def forward(self, *inputs):
        """Run the ONNX model.

        Args:
            *inputs: one torch tensor per model input, in declaration order.

        Returns:
            tuple of torch tensors on ``self.device``, in output order.
        """
        assert len(inputs) == self.num_inputs
        arrays: List[np.ndarray] = [
            tensor.contiguous().cpu().numpy() for tensor in inputs
        ]
        if not self.is_dynamic:
            # make sure input shape is right for static input shape
            for arr, info in zip(arrays, self.inputs_info):
                assert arr.shape == info.shape
        feeds = {
            info.name: arr
            for info, arr in zip(self.inputs_info, arrays)
        }
        fetch_names = [info.name for info in self.outputs_info]
        raw_outputs = self.session.run(fetch_names, feeds)
        return tuple(
            torch.from_numpy(arr).to(self.device) for arr in raw_outputs)