MRiabov commited on
Commit
c2fb738
·
1 Parent(s): 8943a48

Add export to onnx and tensorRT

Browse files
Files changed (1) hide show
  1. scripts/export_onnx_trt.py +130 -0
scripts/export_onnx_trt.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import pprint
4
+ from typing import Tuple
5
+
6
+ import torch
7
+
8
+ from src.wireseghr.model import WireSegHR
9
+
10
+
11
+ class CoarseModule(torch.nn.Module):
12
+ def __init__(self, core: WireSegHR):
13
+ super().__init__()
14
+ self.core = core
15
+
16
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
17
+ logits, cond = self.core.forward_coarse(x)
18
+ return logits, cond
19
+
20
+
21
+ class FineModule(torch.nn.Module):
22
+ def __init__(self, core: WireSegHR):
23
+ super().__init__()
24
+ self.core = core
25
+
26
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
27
+ logits = self.core.forward_fine(x)
28
+ return logits
29
+
30
+
31
+ def build_model(cfg: dict, device: torch.device) -> WireSegHR:
32
+ pretrained_flag = bool(cfg.get("pretrained", False))
33
+ model = WireSegHR(backbone=cfg["backbone"], in_channels=6, pretrained=pretrained_flag)
34
+ model = model.to(device)
35
+ return model
36
+
37
+
38
+ def main():
39
+ parser = argparse.ArgumentParser(description="Export WireSegHR to ONNX and TensorRT")
40
+ parser.add_argument("--config", type=str, default="configs/default.yaml")
41
+ parser.add_argument("--ckpt", type=str, default="", help="Path to checkpoint .pt")
42
+ parser.add_argument("--out_dir", type=str, default="exports")
43
+ parser.add_argument("--coarse_size", type=int, default=1024)
44
+ parser.add_argument("--fine_patch_size", type=int, default=1024)
45
+ parser.add_argument("--opset", type=int, default=17)
46
+ parser.add_argument("--trtexec", type=str, default="", help="Optional path to trtexec to build TRT engines")
47
+
48
+ args = parser.parse_args()
49
+
50
+ import yaml
51
+
52
+ with open(args.config, "r") as f:
53
+ cfg = yaml.safe_load(f)
54
+ print("[export] Loaded config:")
55
+ pprint.pprint(cfg)
56
+
57
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
58
+ model = build_model(cfg, device)
59
+
60
+ ckpt_path = args.ckpt if args.ckpt else cfg.get("resume", "")
61
+ if ckpt_path:
62
+ assert os.path.isfile(ckpt_path), f"Checkpoint not found: {ckpt_path}"
63
+ print(f"[export] Loading checkpoint: {ckpt_path}")
64
+ state = torch.load(ckpt_path, map_location=device)
65
+ model.load_state_dict(state["model"]) # expects dict with key 'model'
66
+ model.eval()
67
+
68
+ os.makedirs(args.out_dir, exist_ok=True)
69
+
70
+ # Prepare dummy inputs (static shapes for best TRT performance)
71
+ coarse_in = torch.randn(1, 6, args.coarse_size, args.coarse_size, device=device)
72
+ fine_in = torch.randn(1, 6, args.fine_patch_size, args.fine_patch_size, device=device)
73
+
74
+ # Coarse export
75
+ coarse_wrapper = CoarseModule(model).to(device).eval()
76
+ coarse_onnx = os.path.join(args.out_dir, f"wireseghr_coarse_{args.coarse_size}.onnx")
77
+ print(f"[export] Exporting COARSE to {coarse_onnx}")
78
+ torch.onnx.export(
79
+ coarse_wrapper,
80
+ coarse_in,
81
+ coarse_onnx,
82
+ export_params=True,
83
+ opset_version=args.opset,
84
+ do_constant_folding=True,
85
+ input_names=["x_coarse"],
86
+ output_names=["logits", "cond"],
87
+ dynamic_axes=None,
88
+ )
89
+
90
+ # Fine export
91
+ fine_wrapper = FineModule(model).to(device).eval()
92
+ fine_onnx = os.path.join(args.out_dir, f"wireseghr_fine_{args.fine_patch_size}.onnx")
93
+ print(f"[export] Exporting FINE to {fine_onnx}")
94
+ torch.onnx.export(
95
+ fine_wrapper,
96
+ fine_in,
97
+ fine_onnx,
98
+ export_params=True,
99
+ opset_version=args.opset,
100
+ do_constant_folding=True,
101
+ input_names=["x_fine"],
102
+ output_names=["logits"],
103
+ dynamic_axes=None,
104
+ )
105
+
106
+ # Optional TensorRT building via trtexec
107
+ if args.trtexec:
108
+ import subprocess
109
+
110
+ def build_engine(onnx_path: str, engine_path: str):
111
+ print(f"[export] Building TRT engine: {engine_path}")
112
+ cmd = [
113
+ args.trtexec,
114
+ f"--onnx={onnx_path}",
115
+ f"--saveEngine={engine_path}",
116
+ "--explicitBatch",
117
+ "--fp16",
118
+ ]
119
+ subprocess.run(cmd, check=True)
120
+
121
+ coarse_engine = os.path.join(args.out_dir, f"wireseghr_coarse_{args.coarse_size}.engine")
122
+ fine_engine = os.path.join(args.out_dir, f"wireseghr_fine_{args.fine_patch_size}.engine")
123
+ build_engine(coarse_onnx, coarse_engine)
124
+ build_engine(fine_onnx, fine_engine)
125
+
126
+ print("[export] Done.")
127
+
128
+
129
+ if __name__ == "__main__":
130
+ main()