File size: 948 Bytes
c9a1cd9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
import torch.onnx as onnx
from basicsr.archs.rrdbnet_arch import RRDBNet

# Export the Real-ESRGAN x2plus generator (RRDBNet) from a PyTorch checkpoint
# to an ONNX file.

# Input checkpoint and output ONNX file paths.
CHECKPOINT_PATH = 'Real-ESRGAN_x2plus.pth'
ONNX_PATH = 'Real-ESRGAN_x2plus.onnx'

# Build the network on the CPU. The hyper-parameters must match the ones the
# checkpoint was trained with, otherwise load_state_dict below fails.
device = torch.device('cpu')
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)

# Load the checkpoint onto the CPU.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints from a trusted source. Newer torch versions
# accept weights_only=True to restrict this; confirm the minimum torch version
# before enabling it.
state_dict = torch.load(CHECKPOINT_PATH, map_location=device)

# Prefer the EMA weights ('params_ema'), then the plain weights ('params').
# Otherwise treat the file as a bare state dict. A wrong guess still fails
# loudly because load_state_dict is strict by default.
if 'params_ema' in state_dict:
    weights = state_dict['params_ema']
elif 'params' in state_dict:
    weights = state_dict['params']
else:
    weights = state_dict
model.load_state_dict(weights)

# Put the model in evaluation (inference) mode before tracing it for export.
model.eval()

# Shape of the example input used for tracing: (batch_size, channels, height, width).
# Only the batch dimension is declared dynamic below, so height/width in the
# exported graph are taken from this example.
input_shape = (1, 3, 64, 64)

# Random example input; only its shape and dtype matter for tracing.
dummy_input = torch.randn(input_shape, device=device)

# Trace the model and write the ONNX graph.
onnx.export(model,
            dummy_input,
            ONNX_PATH,
            opset_version=11,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})