Janeka commited on
Commit
4f91b92
·
verified ·
1 Parent(s): 029f929

Upload esrganONNX.py

Browse files
Files changed (1) hide show
  1. esrganONNX.py +39 -0
esrganONNX.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ # import torch
3
+ import onnxruntime
4
+ import numpy as np
5
+
6
+
7
+ class RealESRGAN_ONNX:
8
+ def __init__(self, model_path="RealESRGAN_x2.onnx", device='cuda'):
9
+ session_options = onnxruntime.SessionOptions()
10
+ session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
11
+ providers = ["CPUExecutionProvider"]
12
+ if device == 'cuda':
13
+ providers = [("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"}),"CPUExecutionProvider"]
14
+ self.session = onnxruntime.InferenceSession(model_path, sess_options=session_options, providers=providers)
15
+
16
+ def enhance(self, img):
17
+
18
+ img = img.astype(np.float32)
19
+ img = img.transpose((2, 0, 1))
20
+ img = img /255
21
+ img = np.expand_dims(img, axis=0).astype(np.float32)
22
+ #
23
+ result = self.session.run(None, {(self.session.get_inputs()[0].name):img})[0][0]
24
+ #
25
+ result = (result.squeeze().transpose((1,2,0)) * 255).clip(0, 255).astype(np.uint8)
26
+ return result
27
+
28
+ def enhance_fp16(self, img):
29
+
30
+ img = img.astype(np.float16)
31
+ img = img.transpose((2, 0, 1))
32
+ img = img /255
33
+ img = np.expand_dims(img, axis=0).astype(np.float16)
34
+ #
35
+ result = self.session.run(None, {(self.session.get_inputs()[0].name):img})[0][0]
36
+ #
37
+ result = (result.squeeze().transpose((1,2,0)) * 255).clip(0, 255).astype(np.uint8)
38
+ return result
39
+