MRiabov committed
Commit e05cc17 · 1 Parent(s): d79c41b

README and cleanup

.windsurf/rules/defensive-logic.md CHANGED
@@ -5,4 +5,4 @@ trigger: always_on
  When deciding whether to write defensive logic, e.g. dimensionality handling: `tensor1 = tensor1.unsqueeze(0) if tensor1.ndim == 1 else tensor1`, or None handling: `var1 = var1 if var1 is not None else torch.zeros(...)`, just don't write these things. In my code, shapes are always static, and there is one execution path for all code. I prefer `assert` over defensive logic. If you are writing something like this because it seems necessary to fix the tests, it's likely that the tests are set up incorrectly.
  The reason I don't want it: defensive logic leads to silent failures, and these are bad for debugging.
 
- In addition, writing "int()" type casting is also a piece of defensive logic that slows down the application. Don't write it unless really necessary, e.g. putting an int into a string. In most cases, indexing with a tensor should be better.
+ In addition, writing "int()", "float()" or "bool()" type casting is also a piece of defensive logic that slows down the application. Don't write it unless really necessary, e.g. putting an int into a string. In most cases, indexing with a tensor should be better.
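
To make the rule concrete, here is a minimal illustrative sketch (not code from this repo) of the preferred assert-first style:

```python
import torch

def row_normalize(x: torch.Tensor) -> torch.Tensor:
    # Preferred: one execution path; fail loudly on an unexpected shape.
    assert x.ndim == 2, f"expected (N, C), got shape {tuple(x.shape)}"
    return x / x.sum(dim=1, keepdim=True)

# Avoided: `x = x.unsqueeze(0) if x.ndim == 1 else x` would mask a
# caller-side bug and let it resurface later as a confusing failure.
```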
README.md CHANGED
@@ -1,30 +1,35 @@
  # WireSegHR (Segmentation Only)
 
- This repository contains the segmentation-only implementation plan and code skeleton for the two-stage WireSegHR model (global-to-local, shared encoder).
+ This repository contains the segmentation-only implementation of the two-stage [WireSegHR model](https://arxiv.org/abs/2304.00221), trained on the WireSegHR dataset plus the [TTPLA dataset](https://github.com/R3ab/ttpla_dataset).
 
- - Paper sources live under `paper-tex/`.
- - Long-term navigation plan: `SEGMENTATION_PLAN.md`.
-
- ## Quick Start (skeleton)
+ ## Quick Start
 
- 1) Create a virtual environment and install requirements:
+ 1) Get the secrets needed to fetch the dataset:
+
+ You'll need a Google Drive service account to fetch the WireSegHR dataset using the scripts in this repo. Get a key as described [in this short README](scripts/drive-viewer-key-readme.md) and put it in `secrets/drive-json.json`.
+
+ 2) Run:
 
  ```bash
- python -m venv .venv
- source .venv/bin/activate
- pip install -r requirements.txt
+ scripts/setup.sh
  ```
+
+ This installs dependencies and merges the TTPLA dataset into the WireSegHR dataset format.
 
- 2) Print configuration and verify the skeleton runs:
+ 3) Train and run a quick inference check:
 
  ```bash
- python src/wireseghr/train.py --config configs/default.yaml
- python src/wireseghr/infer.py --config configs/default.yaml --image /path/to/image.png
+ python3 train.py --config configs/default.yaml
+ python3 infer.py --config configs/default.yaml --image /path/to/image.jpg
  ```
 
- 3) Next steps:
- - Implement encoder/decoders/condition/minmax/label downsampling per `SEGMENTATION_PLAN.md`.
- - Implement training and inference logic, then metrics and ablations.
+ The default config `default.yaml` is suitable for a 24 GB VRAM GPU with bf16 support (e.g., an RTX 3090/4090).
+ <!-- For a quick RTX GPU setup, I recommend [vast.ai](https://cloud.vast.ai/?ref_id=162850) -->
+
+ ## Project Overview
+ - Two-stage, global-to-local segmentation with a shared encoder and a fine decoder conditioned on the coarse stage.
+ - Full training loop with optional AMP, a poly LR schedule, periodic evaluation, checkpointing, and test visualizations (`train.py`).
+ - Dataset utilities live under `src/wireseghr/data/`; model components live under `src/wireseghr/model/`.
+ - Paper text and figures live in `paper-tex/` (`paper-tex/sections/` contains the Method, Results, etc.).
 
  ## Notes
  - This is a segmentation-only codebase. Inpainting is out of scope here.
@@ -41,6 +46,101 @@ python src/wireseghr/infer.py --config configs/default.yaml --image /path/to/ima
  - `dataset/val/images/...` and `dataset/val/gts/...`
  - `dataset/test/images/...` and `dataset/test/gts/...`
  - Masks are binary: foreground = white (255), background = black (0).
- - The loader strictly enforces numeric stems and 1:1 pairing and will assert on mismatches.
+ - The loader strictly enforces numeric stems and 1:1 image/mask pairing, and will raise on file-name mismatches.
 
  Update `configs/default.yaml` with your paths under `data.train_images`, `data.train_masks`, etc. Defaults point to `dataset/train/images`, `dataset/train/gts`, and validation to `dataset/val/...`.
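
The pairing rule above can be expressed as a small check (an illustrative sketch, not the actual `src/wireseghr/data/dataset.py` code):

```python
from pathlib import Path

def check_pairing(images_dir: str, masks_dir: str) -> list[str]:
    # Numeric stems only; every image must have exactly one mask: 123.jpg <-> 123.png.
    img_stems = sorted(p.stem for p in Path(images_dir).iterdir()
                       if p.suffix.lower() in (".jpg", ".jpeg"))
    mask_stems = sorted(p.stem for p in Path(masks_dir).glob("*.png"))
    assert all(s.isdigit() for s in img_stems), "non-numeric image stem"
    assert img_stems == mask_stems, "image/mask sets differ"
    return img_stems
```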
+
+ ## Inference
+
+ - Single image (optionally save outputs to a directory):
+
+ ```bash
+ python3 infer.py \
+   --config configs/default.yaml \
+   --ckpt ckpt_5000.pt \
+   --image dataset/test/images/123.jpg \
+   --out outputs/infer
+ ```
+
+ - Compute metrics for a single image (requires a GT mask):
+
+ ```bash
+ python3 infer.py \
+   --config configs/default.yaml \
+   --ckpt ckpt_5000.pt \
+   --image dataset/test/images/123.jpg \
+   --out outputs/infer \
+   --metrics \
+   --mask dataset/test/gts/123.png
+ ```
+
+ - Run inference over an entire directory with metrics (`--images_dir` sets the image directory, `--masks_dir` sets the ground-truth mask directory):
+
+ ```bash
+ python3 infer.py \
+   --config configs/default.yaml \
+   --ckpt ckpt_5000.pt \
+   --images_dir dataset/test/images \
+   --out outputs/infer \
+   --metrics \
+   --masks_dir dataset/test/gts
+ ```
+
+ Notes:
+ - Predictions are saved as 0/255 PNGs. For metrics, predictions are binarized with `> 0` to match the training logic.
+ - Masks are matched by filename stem: `images/123.jpg` ↔ `gts/123.png`.
+
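The two conventions in these notes combine as follows when computing metrics by hand (an illustrative sketch; it assumes predictions were saved as `<stem>_pred.png`, and is not the `infer.py` implementation):

```python
import cv2
import numpy as np

def iou_for_stem(stem: str) -> float:
    pred = cv2.imread(f"outputs/infer/{stem}_pred.png", cv2.IMREAD_GRAYSCALE)
    gt = cv2.imread(f"dataset/test/gts/{stem}.png", cv2.IMREAD_GRAYSCALE)
    # Binarize with > 0, matching the training logic described above.
    p, g = pred > 0, gt > 0
    inter = float(np.logical_and(p, g).sum())
    union = float(np.logical_or(p, g).sum())
    return inter / union
```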
+ ## Benchmarking and Metrics
+
+ Benchmark mode times the model on a directory of images and reports coarse/fine/total latency statistics. When `--metrics` is provided, it also computes IoU/F1/Precision/Recall over the benchmark set (for both the fine and coarse outputs).
+
+ Example (uses `data.test_images` and `data.test_masks` from the config by default):
+
+ ```bash
+ python3 infer.py \
+   --config configs/default.yaml \
+   --benchmark \
+   --ckpt ckpt_5000.pt \
+   --bench_warmup 2 \
+   --bench_limit 0 \
+   --bench_report_json outputs/bench_report.json \
+   --metrics
+ ```
+
+ If your ground-truth directory differs from `data.test_masks`, override it with `--bench_masks_dir`:
+
+ ```bash
+ python3 infer.py \
+   --config configs/default.yaml \
+   --benchmark \
+   --ckpt ckpt_5000.pt \
+   --bench_warmup 2 \
+   --bench_limit 0 \
+   --bench_report_json outputs/bench_report.json \
+   --metrics \
+   --bench_masks_dir /path/to/gts
+ ```
+
+ You will see output like:
+
+ ```
+ [WireSegHR][bench] Results (ms):
+   Coarse avg=50.16 p50=44.48 p95=76.78
+   Fine avg=534.38 p50=419.52 p95=1187.66
+   Total avg=584.54 p50=464.73 p95=1300.07
+   Target < 1000 ms per 3000x4000 image: YES
+ [WireSegHR][bench][Fine] IoU=0.6098 F1=0.7576 P=0.6418 R=0.9244
+ [WireSegHR][bench][Coarse] IoU=0.5315 F1=0.6941 P=0.5467 R=0.9502
+ ```
+ *These metrics were obtained after 5000 iterations.*
+
+ Optionally, save a JSON timing report with `--bench_report_json`. Schema:
+ - `summary`
+   - `avg_ms`, `p50_ms`, `p95_ms`
+   - `avg_coarse_ms`, `avg_fine_ms`
+   - `images`
+ - `per_image`: a list of objects with
+   - `path`, `H`, `W`, `t_coarse_ms`, `t_fine_ms`, `t_total_ms`
+
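As a usage sketch for the schema above (assuming the report was saved to `outputs/bench_report.json`):

```python
import json

with open("outputs/bench_report.json") as f:
    report = json.load(f)

s = report["summary"]
print(f"{s['images']} images: avg {s['avg_ms']:.1f} ms "
      f"(coarse {s['avg_coarse_ms']:.1f} + fine {s['avg_fine_ms']:.1f})")

# Find the slowest image from the per-image timings.
worst = max(report["per_image"], key=lambda r: r["t_total_ms"])
print(f"slowest: {worst['path']} at {worst['H']}x{worst['W']} ({worst['t_total_ms']:.1f} ms)")
```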
+ Utils:
+ - Export your model to inference-only weights with `scripts/strip_checkpoint.py`.
SEGMENTATION_PLAN.md DELETED
@@ -1,136 +0,0 @@
- # WireSegHR Segmentation-Only Implementation Plan
-
- This plan distills the model and pipeline described in the paper sources:
- - `paper-tex/sections/method.tex`
- - `paper-tex/sections/method_yq.tex`
- - `paper-tex/figure_tex/pipeline.tex`
- - `paper-tex/tables/{component,logit,thresholds}.tex`
-
- Focus: segmentation only (no dataset collection or inpainting).
-
- ## Decisions and Defaults (locked)
- - Backbone: SegFormer MiT-B3 via HuggingFace Transformers.
- - Fine/local patch size p: 768.
- - Conditioning: global map + binary location mask by default (Table `tables/logit.tex`).
- - Conditioning map scope: patch-cropped from the global map per `paper-tex/sections/method_yq.tex` (no full-image concatenation variant).
- - MinMax feature augmentation: luminance min and max with a fixed 6×6 window; channels concatenated to inputs (Figure `figure_tex/pipeline.tex`, Sec. “Wire Feature Preservation” in `method_yq.tex`).
- - Loss: CE on both branches, λ = 1 (`method_yq.tex`).
- - α-threshold for refining windows: default 0.01 (Table `tables/thresholds.tex`).
- - Coarse input size: train 512×512; test 1024×1024 (`method.tex`).
- - Optim: AdamW (lr=6e-5, wd=0.01, poly schedule with power=1), ~40k iters, batch size ~8 (`method.tex`).
-
- ## Project Structure
- - `configs/`
-   - `default.yaml` (backbone=mit_b2, p=768, coarse_train=512, coarse_test=1024, alpha=0.01, minmax=true, kernel=6, maxpool_label=true, cond_variant=global)
- - `src/wireseghr/`
-   - `model/`
-     - `encoder.py` (SegFormer MiT-B3, N_in channels expansion)
-     - `decoder.py` (two MLP decoders `D_C`, `D_F` for 2 classes)
-     - `condition.py` (1×1 conv to collapse coarse 2-ch logits → 1-ch cond)
-     - `minmax.py` (6×6 luminance min/max filtering)
-     - `label_downsample.py` (MaxPool-based coarse GT downsampling)
-   - `data/`
-     - `dataset.py` (image/mask loading, full-res to coarse/fine inputs)
-     - `sampler.py` (balanced patch sampling with ≥1% wire pixels)
-     - `transforms.py` (scaling, rotation, flip, photometric distortion)
-   - `train.py` (end-to-end two-branch training)
-   - `infer.py` (coarse-to-fine sliding-window inference + stitching)
-   - `metrics.py` (IoU, F1, Precision, Recall)
-   - `utils.py` (misc: overlap blending, seeding, logging)
- - `tests/` (unit tests for channel wiring, cond alignment, stitching)
- - `README.md` (segmentation-only usage)
-
- ## Model Specification
- - Shared encoder `E`: SegFormer MiT-B3 (HF Transformers preferred).
-   - Input channels (default): 3 (RGB) + 2 (MinMax) + 1 (global cond) + 1 (binary location) = 7.
-   - For the coarse pass, the cond and location channels are zeros to keep the channel count consistent (`method_yq.tex`).
-   - Weight init for the extra channels: copy the mean of the RGB conv weights, or zero-init.
- - Decoders: two SegFormer MLP decoders
-   - `D_C`: coarse logits (2 channels) at coarse resolution.
-   - `D_F`: fine logits (2 channels) at patch resolution p×p.
- - Conditioning of the fine branch (default):
-   - Take the coarse pre-softmax logits (2-ch), apply a 1×1 conv → 1-ch cond map (`method.tex`).
-   - Binary location mask: 1 inside the current patch region (in full-image coordinates), 0 elsewhere.
-   - Pass the patch-aligned cond crop and binary mask as channels to the fine branch input.
- - Notes:
-   - We follow the published version (`paper-tex/sections/method_yq.tex`) and use patch-cropped conditioning exclusively; no full-image conditioning variant will be implemented.
-
- ## Data and Preprocessing
- - MinMax luminance features (both branches):
-   - Y = 0.299R + 0.587G + 0.114B.
-   - Y_min = min filter (6×6), Y_max = max filter (6×6).
-   - Concat [Y_min, Y_max] to the input image channels.
- - Coarse GT label generation (MaxPool):
-   - Downsample the full-res mask to coarse size with max-pooling to prevent wire vanishing (`method_yq.tex`).
- - Normalization: standard mean/std per backbone; apply consistently across channels (new channels can be mean=0, std=1 by convention, or min-max scaled).
-
- ### Dataset Convention (project-specific)
- - Flat directories with numeric filenames; images are `.jpg`/`.jpeg`, masks are `.png`.
- - Example:
-   - `dataset/images/1.jpg, 2.jpg, ..., N.jpg` (or `.jpeg`)
-   - `dataset/gts/1.png, 2.png, ..., N.png`
- - Masks are binary: foreground = white (255), background = black (0).
- - The loader (`data/dataset.py`) strictly enforces numeric stems and 1:1 pairing and will assert on mismatch.
-
- ## Training Pipeline
- - Augment the full-res image (scaling, rotation, horizontal flip, photometric distortion) before constructing coarse/fine inputs (`method.tex`).
- - Coarse input: downsample the augmented full image to 512×512; build channels [RGB + MinMax + zeros(2)] → `E` → `D_C`.
- - Fine input (per iteration, select 1–k patches):
-   - Sample a p×p patch (p=768) with ≥1% wire pixels (`method.tex`, `method_yq.tex`).
-   - Build the cond map from the coarse logits via a 1×1 conv; crop the cond map to the patch region.
-   - Build the binary location mask for the patch region.
-   - Build channels [RGB + MinMax + cond + location] → `E` → `D_F`.
- - Losses:
-   - L_glo = CE(Softmax(`D_C(E(coarse))`), G_glo), where G_glo uses the MaxPool downsample.
-   - L_loc = CE(Softmax(`D_F(E(fine))`), G_loc).
-   - L = L_glo + λ·L_loc, λ = 1 (`method_yq.tex`).
- - Optimization:
-   - AdamW (lr=6e-5, wd=0.01), poly schedule (power=1.0), ~40k iterations, batch ≈8 (tune by memory).
-   - AMP and grad accumulation recommended for stability/memory.
-
- ## Inference Pipeline
- - Coarse pass:
-   - Downsample to 1024×1024; predict coarse probability/logits.
- - Window proposal (sliding window on full-res):
-   - Tile with patch size p=768 and ~128 px overlap (configurable). Compute the wire fraction within each window from the coarse prediction (prob > 0.5).
-   - If the fraction is ≥ α (default 0.01), run fine refinement on that patch; else skip (Table `tables/thresholds.tex`).
- - Fine refinement + stitching:
-   - For selected windows, build the fine input with the cond crop + location mask; predict logits.
-   - Stitch logits into a full-res canvas; average in overlaps; final argmax over classes.
- - Outputs: full-res binary mask, plus an optional probability map.
-
- ## Metrics and Reporting
- - Implement IoU, F1, Precision, Recall (global, and optionally per-size bins if available), matching `tables/component.tex`.
- - Validate α trade-offs following `tables/thresholds.tex`.
-
- ## Configuration Surface (key)
- - Backbone/weights: `mit_b2` (pretrained ImageNet-1K).
- - Sizes: `p=768`, `coarse_train=512`, `coarse_test=1024`, `overlap=128`.
- - Conditioning: `cond_from='coarse_logits_1x1'`, `cond_crop='patch'`.
- - MinMax: `enable=true`, `kernel=6`.
- - Label: `coarse_label_downsample='maxpool'`.
- - Training: `iters=40000`, `batch=8`, `lr=6e-5`, `wd=0.01`, `schedule='poly'`, `power=1.0`.
- - Inference: `alpha=0.01`, `prob_threshold=0.5` for the wire fraction, `stitch='avg_logits'`.
-
- ## Risks / Gotchas
- - Channel expansion requires careful initialization; confirm no NaNs and stable early training.
- - Precise spatial alignment of the cond map and location mask with the patch is critical. Add assertions/tests.
- - The even-sized MinMax window (6×6) requires careful padding to maintain alignment.
- - Memory with p=768 and MiT-B3 may need tuning (AMP, batch size, overlap).
-
- ## Milestones
- 1) Skeleton + configs + metrics.
- 2) Encoder channel expansion + two decoders + 1×1 cond.
- 3) MinMax (6×6) + MaxPool label downsampling.
- 4) Training loop with ≥1% wire patch sampling.
- 5) Inference α-threshold + stitching.
- 6) Ablation toggles + scripts + README.
- 7) Tests (channel wiring, cond/mask alignment, stitching correctness).
-
- ## References (paper sources)
- - `paper-tex/sections/method.tex`: two-stage design, shared encoder, 1×1 cond, training/inference sizes, optimizer/schedule.
- - `paper-tex/sections/method_yq.tex`: CE losses, λ, sliding window with α, MinMax & MaxPool rationale.
- - `paper-tex/figure_tex/pipeline.tex`: system overview; MinMax concatenation.
- - `paper-tex/tables/component.tex`: ablation of MinMax/MaxPool/coarse.
- - `paper-tex/tables/logit.tex`: conditioning variants.
- - `paper-tex/tables/thresholds.tex`: α vs. speed/quality.
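
Although this plan file is deleted in this commit, its MinMax feature definition is worth illustrating. The sketch below mirrors the padding/pooling logic that also appears in the removed `scripts/trt_infer.py` (assuming PyTorch; this is not the repo's `minmax.py`):

```python
import torch
import torch.nn.functional as F

def minmax_features(rgb: torch.Tensor, kernel: int = 6) -> torch.Tensor:
    # rgb: (B, 3, H, W) in [0, 1]; luminance per Y = 0.299R + 0.587G + 0.114B.
    y = 0.299 * rgb[:, 0:1] + 0.587 * rgb[:, 1:2] + 0.114 * rgb[:, 2:3]
    # The even-sized 6x6 window needs asymmetric padding to keep HxW alignment.
    pad = (kernel // 2 - 1, kernel // 2, kernel // 2 - 1, kernel // 2)
    y_pad = F.pad(y, pad, mode="replicate")
    y_max = F.max_pool2d(y_pad, kernel_size=kernel, stride=1)
    y_min = -F.max_pool2d(-y_pad, kernel_size=kernel, stride=1)
    return torch.cat([y_min, y_max], dim=1)  # the two extra input channels
```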
WireSegHR.pdf DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:7f6db9a06575398aeb0903c8d19e68f27d983223ca128ff3a3ae12a8aeb8f4a9
- size 17039360
scripts/drive-viewer-key-readme.md ADDED
@@ -0,0 +1,16 @@
+ ## Setting Up Google Drive Access via a Service Account (PyDrive2)
+
+ You'll need a service account to fetch the WireSegHR dataset using a script in this folder. Follow the steps below:
+
+ 1. **Create a service account**
+    - Navigate to the Google Cloud Console. Under **IAM & Admin → Service Accounts**, create a new service account.
+    - Assign it a Viewer role.
+
+ 2. **Generate and download a JSON key**
+    - In the service account details, go to **Keys → Add Key → Create new key**, choose **JSON**, and download the key file.
+    - Save the file in this repo as `secrets/drive-json.json`.
+
+ 3. **Share the Drive folder or files**
+    - Grant the service account access to the target Drive folder (https://drive.google.com/drive/folders/1fgy3wn_yuHEeMNbfiHNVl1-jEdYOfu6p) using its service account email: open the folder in Google Drive, click Share, and add the service account email with Viewer permissions.
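Once the key is in place, authentication looks roughly like this (a minimal sketch of PyDrive2's documented service-account flow, not a script from this repo):

```python
from pydrive2.auth import GoogleAuth
from pydrive2.drive import GoogleDrive

# Point PyDrive2 at the downloaded service-account key.
settings = {
    "client_config_backend": "service",
    "service_config": {"client_json_file_path": "secrets/drive-json.json"},
}
gauth = GoogleAuth(settings=settings)
gauth.ServiceAuth()  # no browser flow; authenticates as the service account
drive = GoogleDrive(gauth)

# List files in the shared WireSegHR folder (ID taken from the URL above).
folder_id = "1fgy3wn_yuHEeMNbfiHNVl1-jEdYOfu6p"
for f in drive.ListFile({"q": f"'{folder_id}' in parents"}).GetList():
    print(f["title"])
```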
scripts/export_onnx_trt.py DELETED
@@ -1,167 +0,0 @@
- import argparse
- import os
- import pprint
- import shutil
- import subprocess
- from typing import Tuple
-
- import torch
- import tensorrt as trt
-
- from src.wireseghr.model import WireSegHR
- from pathlib import Path
-
-
- class CoarseModule(torch.nn.Module):
-     def __init__(self, core: WireSegHR):
-         super().__init__()
-         self.core = core
-
-     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-         logits, cond = self.core.forward_coarse(x)
-         return logits, cond
-
-
- class FineModule(torch.nn.Module):
-     def __init__(self, core: WireSegHR):
-         super().__init__()
-         self.core = core
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         logits = self.core.forward_fine(x)
-         return logits
-
-
- def build_model(cfg: dict, device: torch.device) -> WireSegHR:
-     pretrained_flag = bool(cfg.get("pretrained", False))
-     model = WireSegHR(backbone=cfg["backbone"], in_channels=6, pretrained=pretrained_flag)
-     model = model.to(device)
-     return model
-
-
- def main():
-     parser = argparse.ArgumentParser(description="Export WireSegHR to ONNX and TensorRT")
-     parser.add_argument("--config", type=str, default="configs/default.yaml")
-     parser.add_argument("--ckpt", type=str, default="", help="Path to checkpoint .pt")
-     parser.add_argument("--out_dir", type=str, default="exports")
-     parser.add_argument("--coarse_size", type=int, default=1024)
-     parser.add_argument("--fine_patch_size", type=int, default=1024)
-     parser.add_argument("--opset", type=int, default=17)
-     parser.add_argument("--trtexec", type=str, default="", help="Optional path to trtexec to build TRT engines")
-     parser.add_argument("--build_trt", action="store_true", help="Build TensorRT engines after ONNX export")
-
-     args = parser.parse_args()
-
-     import yaml
-
-     with open(args.config, "r") as f:
-         cfg = yaml.safe_load(f)
-     print("[export] Loaded config:")
-     pprint.pprint(cfg)
-
-     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-     model = build_model(cfg, device)
-
-     ckpt_path = args.ckpt if args.ckpt else cfg.get("resume", "")
-     if ckpt_path:
-         assert Path(ckpt_path).is_file(), f"Checkpoint not found: {ckpt_path}"
-         print(f"[export] Loading checkpoint: {ckpt_path}")
-         state = torch.load(ckpt_path, map_location=device)
-         model.load_state_dict(state["model"])  # expects dict with key 'model'
-     model.eval()
-
-     Path(args.out_dir).mkdir(parents=True, exist_ok=True)
-
-     # Prepare dummy inputs (static shapes for best TRT performance)
-     coarse_in = torch.randn(1, 6, args.coarse_size, args.coarse_size, device=device)
-     fine_in = torch.randn(1, 6, args.fine_patch_size, args.fine_patch_size, device=device)
-
-     # Coarse export
-     coarse_wrapper = CoarseModule(model).to(device).eval()
-     coarse_onnx = Path(args.out_dir) / f"wireseghr_coarse_{args.coarse_size}.onnx"
-     print(f"[export] Exporting COARSE to {coarse_onnx}")
-     torch.onnx.export(
-         coarse_wrapper,
-         coarse_in,
-         str(coarse_onnx),
-         export_params=True,
-         opset_version=args.opset,
-         do_constant_folding=True,
-         input_names=["x_coarse"],
-         output_names=["logits", "cond"],
-         dynamic_axes=None,
-         dynamo=True,
-     )
-
-     # Fine export
-     fine_wrapper = FineModule(model).to(device).eval()
-     fine_onnx = Path(args.out_dir) / f"wireseghr_fine_{args.fine_patch_size}.onnx"
-     print(f"[export] Exporting FINE to {fine_onnx}")
-     torch.onnx.export(
-         fine_wrapper,
-         fine_in,
-         str(fine_onnx),
-         export_params=True,
-         opset_version=args.opset,
-         do_constant_folding=True,
-         input_names=["x_fine"],
-         output_names=["logits"],
-         dynamic_axes=None,
-     )
-
-     # Optional TensorRT building via trtexec; fall back to the Python API if unavailable
-     if args.build_trt:
-         trtexec_path = args.trtexec if args.trtexec else shutil.which("trtexec")
-         coarse_engine = Path(args.out_dir) / f"wireseghr_coarse_{args.coarse_size}.engine"
-         fine_engine = Path(args.out_dir) / f"wireseghr_fine_{args.fine_patch_size}.engine"
-         if trtexec_path:
-             def build_engine_cli(onnx_path: str, engine_path: str):
-                 print(f"[export] Building TRT engine (trtexec): {engine_path}")
-                 cmd = [
-                     trtexec_path,
-                     f"--onnx={onnx_path}",
-                     f"--saveEngine={engine_path}",
-                     "--explicitBatch",
-                     "--fp16",
-                 ]
-                 subprocess.run(cmd, check=True)
-
-             build_engine_cli(str(coarse_onnx), str(coarse_engine))
-             build_engine_cli(str(fine_onnx), str(fine_engine))
-         else:
-             print("[export] trtexec not found; building engines via TensorRT Python API")
-
-             def build_engine_py(onnx_path: str, engine_path: str):
-                 logger = trt.Logger(trt.Logger.WARNING)
-                 builder = trt.Builder(logger)
-                 network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
-                 parser = trt.OnnxParser(network, logger)
-                 with open(str(onnx_path), "rb") as f:
-                     data = f.read()
-                 ok = parser.parse(data)
-                 if not ok:
-                     for i in range(parser.num_errors):
-                         print(f"[TRT][parser] {parser.get_error(i)}")
-                     raise RuntimeError("ONNX parse failed")
-
-                 config = builder.create_builder_config()
-                 config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)
-                 if builder.platform_has_fast_fp16:
-                     config.set_flag(trt.BuilderFlag.FP16)
-
-                 print(f"[export] Building TRT engine (Python): {engine_path}")
-                 serialized = builder.build_serialized_network(network, config)
-                 assert serialized is not None, "Failed to build TensorRT engine"
-                 with open(str(engine_path), "wb") as f:
-                     f.write(serialized)
-
-             build_engine_py(coarse_onnx, coarse_engine)
-             build_engine_py(fine_onnx, fine_engine)
-     else:
-         print("[export] Skipping TensorRT engine build (use --build_trt to enable)")
-
-     print("[export] Done.")
-
-
- if __name__ == "__main__":
-     main()
scripts/trt_infer.py DELETED
@@ -1,488 +0,0 @@
- import argparse
- import os
- import pprint
- import time
- from typing import List, Tuple, Optional, Dict, Any
-
- import numpy as np
- import cv2
-
- # TensorRT + CUDA
- try:
-     import tensorrt as trt
-     import pycuda.autoinit  # noqa: F401  # initializes CUDA driver context
-     import pycuda.driver as cuda
- except Exception as e:  # pragma: no cover
-     raise RuntimeError(
-         "TensorRT runner requires 'tensorrt' and 'pycuda' Python packages and a valid CUDA/TensorRT install"
-     ) from e
-
- import yaml
- from pathlib import Path
-
-
- # ---- Utility: TRT engine wrapper ----
- class TrtEngine:
-     def __init__(self, engine_path: str):
-         assert Path(engine_path).is_file(), f"Engine not found: {engine_path}"
-         logger = trt.Logger(trt.Logger.ERROR)
-         with open(engine_path, 'rb') as f, trt.Runtime(logger) as runtime:
-             self.engine = runtime.deserialize_cuda_engine(f.read())
-         assert self.engine is not None, f"Failed to load engine: {engine_path}"
-         self.context = self.engine.create_execution_context()
-         # Default to profile 0
-         try:
-             self.context.active_optimization_profile = 0
-         except Exception:
-             pass
-         self.stream = cuda.Stream()
-         self.bindings: List[int] = [0] * self.engine.num_bindings
-         self.host_mem: Dict[int, Any] = {}
-         self.device_mem: Dict[int, Any] = {}
-
-     def _allocate_binding(self, idx: int, shape: Tuple[int, ...]):
-         dtype = trt.nptype(self.engine.get_binding_dtype(idx))
-         nbytes = int(np.prod(shape)) * np.dtype(dtype).itemsize
-         if idx in self.device_mem:
-             # Reuse if same size; else reallocate
-             old = self.device_mem[idx]
-             if old.size >= nbytes:
-                 self.host_mem[idx] = np.empty(shape, dtype=dtype)
-                 return
-             # free old and reallocate
-             del old
-         self.host_mem[idx] = np.empty(shape, dtype=dtype)
-         self.device_mem[idx] = cuda.mem_alloc(nbytes)
-
-     def infer(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
-         # Map names -> indices
-         name_to_idx = {self.engine.get_binding_name(i): i for i in range(self.engine.num_bindings)}
-         # Set input shapes (handle dynamic batch if present)
-         for name, arr in inputs.items():
-             idx = name_to_idx[name]
-             assert self.engine.binding_is_input(idx)
-             # Set shape if dynamic
-             shape = tuple(arr.shape)
-             try:
-                 self.context.set_binding_shape(idx, shape)
-             except Exception:
-                 # Static shape engines won't allow setting; assert it matches
-                 eshape = tuple(self.engine.get_binding_shape(idx))
-                 assert eshape == shape, f"Static engine expects {eshape}, got {shape} for input {name}"
-             self._allocate_binding(idx, shape)
-
-         # Allocate outputs for resolved shapes
-         for i in range(self.engine.num_bindings):
-             if not self.engine.binding_is_input(i):
-                 shape = tuple(self.context.get_binding_shape(i))
-                 assert all(s > 0 for s in shape), f"Unresolved output shape at binding {i}: {shape}"
-                 self._allocate_binding(i, shape)
-
-         # Copy inputs H2D
-         for name, arr in inputs.items():
-             idx = name_to_idx[name]
-             h_arr = self.host_mem[idx]
-             assert h_arr.shape == arr.shape and h_arr.dtype == arr.dtype
-             cuda.memcpy_htod_async(self.device_mem[idx], arr, self.stream)
-             self.bindings[idx] = int(self.device_mem[idx])
-
-         # Set output bindings
-         for i in range(self.engine.num_bindings):
-             if not self.engine.binding_is_input(i):
-                 self.bindings[i] = int(self.device_mem[i])
-
-         # Execute
-         self.context.execute_async_v2(self.bindings, self.stream.handle)
-
-         # D2H outputs
-         outputs: Dict[str, np.ndarray] = {}
-         for i in range(self.engine.num_bindings):
-             if not self.engine.binding_is_input(i):
-                 name = self.engine.get_binding_name(i)
-                 h_arr = self.host_mem[i]
-                 cuda.memcpy_dtoh_async(h_arr, self.device_mem[i], self.stream)
-                 outputs[name] = h_arr
-
-         self.stream.synchronize()
-         # Return copies to detach from internal buffers
-         return {k: np.array(v) for k, v in outputs.items()}
-
-
- # ---- Pre/post processing consistent with infer.py ----
-
- def _pad_for_minmax(kernel: int) -> Tuple[int, int, int, int]:
-     if (kernel % 2) == 0:
-         return (kernel // 2 - 1, kernel // 2, kernel // 2 - 1, kernel // 2)
-     else:
-         return (kernel // 2, kernel // 2, kernel // 2, kernel // 2)
-
-
- def _build_6ch_coarse(rgb: np.ndarray, coarse_size: int, minmax_enable: bool, minmax_kernel: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
-     # rgb: HxWx3 float32 [0,1]
-     H, W = int(rgb.shape[0]), int(rgb.shape[1])
-     # To match training/minmax in torch, we replicate the exact pad + pool logic via torch CPU
-     import torch
-     import torch.nn.functional as F
-
-     t_img = torch.from_numpy(rgb.transpose(2, 0, 1)).unsqueeze(0).float()  # 1x3xHxW
-     y_t = 0.299 * t_img[:, 0:1] + 0.587 * t_img[:, 1:2] + 0.114 * t_img[:, 2:3]
-     if minmax_enable:
-         pad = _pad_for_minmax(minmax_kernel)
-         y_p = F.pad(y_t, pad, mode="replicate")
-         y_max_full = F.max_pool2d(y_p, kernel_size=minmax_kernel, stride=1)
-         y_min_full = -F.max_pool2d(-y_p, kernel_size=minmax_kernel, stride=1)
-     else:
-         y_min_full = y_t
-         y_max_full = y_t
-
-     # Resize for coarse
-     rgb_c = cv2.resize(rgb, (coarse_size, coarse_size), interpolation=cv2.INTER_LINEAR)
-     y_min_c = cv2.resize(y_min_full[0, 0].numpy(), (coarse_size, coarse_size), interpolation=cv2.INTER_LINEAR)
-     y_max_c = cv2.resize(y_max_full[0, 0].numpy(), (coarse_size, coarse_size), interpolation=cv2.INTER_LINEAR)
-
-     zeros_c = np.zeros((coarse_size, coarse_size), dtype=np.float32)
-     x6 = np.stack([
-         rgb_c[:, :, 0], rgb_c[:, :, 1], rgb_c[:, :, 2], y_min_c, y_max_c, zeros_c
-     ], axis=0)  # 6 x Hc x Wc
-     return x6.astype(np.float32), y_min_full[0, 0].numpy().astype(np.float32), y_max_full[0, 0].numpy().astype(np.float32), t_img.numpy().astype(np.float32)
-
-
- def _softmax_channel(x: np.ndarray, axis: int = 1) -> np.ndarray:
-     x_max = np.max(x, axis=axis, keepdims=True)
-     e = np.exp(x - x_max)
-     return e / np.sum(e, axis=axis, keepdims=True)
-
-
- def _tiled_fine_trt(
-     fine: TrtEngine,
-     t_img: np.ndarray,  # 1x3xHxW float32
-     cond_map: np.ndarray,  # 1x1xhxw float32
-     y_min_full: np.ndarray,  # HxW float32
-     y_max_full: np.ndarray,  # HxW float32
-     patch_size: int,
-     overlap: int,
-     fine_batch: int,
- ) -> np.ndarray:
-     H, W = int(t_img.shape[2]), int(t_img.shape[3])
-     P = patch_size
-     stride = P - overlap
-     assert stride > 0
-     assert H >= P and W >= P
-
-     prob_sum = np.zeros((H, W), dtype=np.float32)
-     weight = np.zeros((H, W), dtype=np.float32)
-
-     hc4, wc4 = int(cond_map.shape[2]), int(cond_map.shape[3])
-
-     ys = list(range(0, H - P + 1, stride))
-     if ys[-1] != (H - P):
-         ys.append(H - P)
-     xs = list(range(0, W - P + 1, stride))
-     if xs[-1] != (W - P):
-         xs.append(W - P)
-
-     coords: List[Tuple[int, int]] = [(y0, x0) for y0 in ys for x0 in xs]
-
-     # Run with batches supported by the engine if dynamic; otherwise enforce 1
-     input_name = None
-     for i in range(fine.engine.num_bindings):
-         if fine.engine.binding_is_input(i):
-             input_name = fine.engine.get_binding_name(i)
-             shape_decl = fine.engine.get_binding_shape(i)
-             break
-     assert input_name is not None
-
-     dynamic_batch = -1 in list(shape_decl)
-     batch_allowed = fine_batch if dynamic_batch else 1
-
-     for i0 in range(0, len(coords), batch_allowed):
-         batch_coords = coords[i0 : i0 + batch_allowed]
-         B = len(batch_coords)
-         xs_list: List[np.ndarray] = []
-         for (y0, x0) in batch_coords:
-             y1, x1 = y0 + P, x0 + P
-             y0c = (y0 * hc4) // H
-             y1c = ((y1 * hc4) + H - 1) // H
-             x0c = (x0 * wc4) // W
-             x1c = ((x1 * wc4) + W - 1) // W
-             cond_sub = cond_map[:, :, y0c:y1c, x0c:x1c][0, 0]
-             cond_patch = cv2.resize(cond_sub, (P, P), interpolation=cv2.INTER_LINEAR)
-
-             rgb_patch = t_img[0, :, y0:y1, x0:x1]  # 3xPxP
-             ymin_patch = y_min_full[y0:y1, x0:x1][None, ...]  # 1xPxP
-             ymax_patch = y_max_full[y0:y1, x0:x1][None, ...]  # 1xPxP
-             x6 = np.concatenate([rgb_patch, ymin_patch, ymax_patch, cond_patch[None, ...]], axis=0)
-             xs_list.append(x6)
-
-         x_batch = np.stack(xs_list, axis=0).astype(np.float32)  # Bx6xPxP
-         outputs = fine.infer({input_name: x_batch})
-
-         # Assume a single output named 'logits' or similar; take the first one
-         out_name = [n for n in outputs.keys()][0]
-         logits = outputs[out_name]  # Bx2xPxP
-         prob = _softmax_channel(logits, axis=1)[:, 1, :, :]  # BxPxP
-
-         for bi, (y0, x0) in enumerate(batch_coords):
-             y1, x1 = y0 + P, x0 + P
-             prob_sum[y0:y1, x0:x1] += prob[bi]
-             weight[y0:y1, x0:x1] += 1.0
-
-     prob_full = prob_sum / weight
-     return prob_full.astype(np.float32)
-
-
- def _coarse_trt(coarse: TrtEngine, rgb: np.ndarray, coarse_size: int, minmax_enable: bool, minmax_kernel: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
-     x6, y_min_full, y_max_full, t_img = _build_6ch_coarse(rgb, coarse_size, minmax_enable, minmax_kernel)
-     # Engine input name
-     input_name = None
-     for i in range(coarse.engine.num_bindings):
-         if coarse.engine.binding_is_input(i):
-             input_name = coarse.engine.get_binding_name(i)
-             break
-     assert input_name is not None
-     x = x6[None, ...].astype(np.float32)  # 1x6xHcxWc
-     outputs = coarse.infer({input_name: x})
-     # Identify outputs: we expect 2 outputs (logits 1x2xHcxWc, cond 1x1xHcxWc)
-     assert len(outputs) == 2, f"Coarse engine must have 2 outputs, got {list(outputs.keys())}"
-     # Determine which is cond by channel dim = 1
-     names = list(outputs.keys())
-     a, b = outputs[names[0]], outputs[names[1]]
-     if a.shape[1] == 1:
-         cond = a
-         logits = b
-     else:
-         cond = b
-         logits = a
-     # Coarse prob upsampled to full HxW (optional)
-     prob_c = _softmax_channel(logits, axis=1)[:, 1:2]
-     H, W = int(t_img.shape[2]), int(t_img.shape[3])
-     prob_up = cv2.resize(prob_c[0, 0], (W, H), interpolation=cv2.INTER_LINEAR)
-     return prob_up.astype(np.float32), cond.astype(np.float32), t_img.astype(np.float32), y_min_full, y_max_full
-
-
- # ---- Inference API ----
-
- def infer_image_trt(
-     coarse: TrtEngine,
-     fine: TrtEngine,
-     img_path: str,
-     cfg: dict,
-     out_dir: Optional[str] = None,
-     save_prob: bool = False,
-     prob_thresh: Optional[float] = None,
- ) -> Tuple[np.ndarray, np.ndarray]:
-     assert Path(img_path).is_file(), f"Image not found: {img_path}"
-     bgr = cv2.imread(img_path, cv2.IMREAD_COLOR)
-     assert bgr is not None, f"Failed to read {img_path}"
-     rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
-
-     coarse_size = int(cfg["coarse"]["test_size"])
-     patch_size = int(cfg["inference"]["fine_patch_size"])  # 1024 for inference
-     overlap = int(cfg["fine"]["overlap"])
-     minmax_enable = bool(cfg["minmax"]["enable"])
-     minmax_kernel = int(cfg["minmax"]["kernel"])
-     if prob_thresh is None:
-         prob_thresh = float(cfg["inference"]["prob_threshold"])
-
-     prob_c, cond_map, t_img, y_min_full, y_max_full = _coarse_trt(
-         coarse, rgb, coarse_size, minmax_enable, minmax_kernel
-     )
-
-     prob_f = _tiled_fine_trt(
-         fine,
-         t_img,
-         cond_map,
-         y_min_full,
-         y_max_full,
-         patch_size,
-         overlap,
-         int(cfg.get("eval", {}).get("fine_batch", 16)),
-     )
-
-     pred = (prob_f > prob_thresh).astype(np.uint8) * 255
-
-     if out_dir is not None:
-         Path(out_dir).mkdir(parents=True, exist_ok=True)
-         stem = Path(img_path).stem
-         out_mask = Path(out_dir) / f"{stem}_pred.png"
-         cv2.imwrite(str(out_mask), pred)
-         if save_prob:
-             out_prob = Path(out_dir) / f"{stem}_prob.npy"
-             np.save(str(out_prob), prob_f.astype(np.float32))
-
-     return pred, prob_f
-
-
- def main():
-     parser = argparse.ArgumentParser(description="WireSegHR TensorRT Inference")
-     parser.add_argument("--config", type=str, default="configs/default.yaml")
-     parser.add_argument("--coarse_engine", type=str, required=True)
-     parser.add_argument("--fine_engine", type=str, required=True)
-     parser.add_argument("--image", type=str, default="", help="Path to single image")
-     parser.add_argument("--images_dir", type=str, default="", help="Directory with images")
-     parser.add_argument("--out", type=str, default="outputs/trt_infer")
-     parser.add_argument("--save_prob", action="store_true")
-     # Benchmarking
-     parser.add_argument("--benchmark", action="store_true")
-     parser.add_argument("--bench_images_dir", type=str, default="")
-     parser.add_argument("--bench_limit", type=int, default=0)
-     parser.add_argument("--bench_warmup", type=int, default=2)
-     parser.add_argument("--bench_size_filter", type=str, default="")
-     parser.add_argument("--bench_report_json", type=str, default="")
-
-     args = parser.parse_args()
-
-     with open(args.config, "r") as f:
-         cfg = yaml.safe_load(f)
-     print("[TRT][infer] Loaded config:")
-     pprint.pprint(cfg)
-
-     coarse = TrtEngine(args.coarse_engine)
-     fine = TrtEngine(args.fine_engine)
-
-     if args.benchmark:
-         bench_dir = args.bench_images_dir or cfg["data"]["test_images"]
-         assert Path(bench_dir).is_dir(), f"Not a directory: {bench_dir}"
-         size_filter: Optional[Tuple[int, int]] = None
-         if args.bench_size_filter:
-             try:
-                 h_str, w_str = args.bench_size_filter.lower().split("x")
-                 size_filter = (int(h_str), int(w_str))
-             except Exception:
-                 raise AssertionError(
-                     f"Invalid --bench_size_filter format: {args.bench_size_filter} (use HxW)"
-                 )
-         img_files = sorted(
-             [
-                 str(Path(bench_dir) / p)
-                 for p in os.listdir(bench_dir)
-                 if p.lower().endswith((".jpg", ".jpeg"))
-             ]
-         )
-         assert len(img_files) > 0, f"No .jpg/.jpeg in {bench_dir}"
-
-         if size_filter is not None:
-             sel: List[str] = []
-             for p in img_files:
-                 im = cv2.imread(p, cv2.IMREAD_COLOR)
-                 assert im is not None
-                 if im.shape[0] == size_filter[0] and im.shape[1] == size_filter[1]:
-                     sel.append(p)
-             img_files = sel
-             assert len(img_files) > 0, (
-                 f"No images matching {size_filter[0]}x{size_filter[1]} in {bench_dir}"
-             )
-
-         if args.bench_limit > 0:
-             img_files = img_files[: args.bench_limit]
-
-         print(f"[TRT][bench] Images: {len(img_files)} from {bench_dir}")
-         print(f"[TRT][bench] Warmup: {args.bench_warmup}")
-
-         timings: List[Dict[str, Any]] = []
-         # Warmup
-         for i in range(min(args.bench_warmup, len(img_files))):
-             infer_image_trt(coarse, fine, img_files[i], cfg, out_dir=None, save_prob=False)
-
-         # Timed runs
-         for p in img_files[args.bench_warmup :]:
-             t0 = time.perf_counter()
-             bgr = cv2.imread(p, cv2.IMREAD_COLOR)
-             assert bgr is not None
-             rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
-
-             coarse_size = int(cfg["coarse"]["test_size"])
-             minmax_enable = bool(cfg["minmax"]["enable"])
-             minmax_kernel = int(cfg["minmax"]["kernel"])
-
-             c0 = time.perf_counter()
-             prob_c, cond_map, t_img, y_min_full, y_max_full = _coarse_trt(
-                 coarse, rgb, coarse_size, minmax_enable, minmax_kernel
-             )
-             c1 = time.perf_counter()
-
-             patch_size = int(cfg["inference"]["fine_patch_size"])  # 1024
-             overlap = int(cfg["fine"]["overlap"])
-
-             prob_f = _tiled_fine_trt(
-                 fine,
-                 t_img,
-                 cond_map,
-                 y_min_full,
-                 y_max_full,
-                 patch_size,
-                 overlap,
-                 int(cfg.get("eval", {}).get("fine_batch", 16)),
-             )
-             c2 = time.perf_counter()
-
-             timings.append(
-                 {
-                     "path": p,
-                     "H": int(t_img.shape[2]),
-                     "W": int(t_img.shape[3]),
-                     "t_coarse_ms": (c1 - c0) * 1000.0,
-                     "t_fine_ms": (c2 - c1) * 1000.0,
-                     "t_total_ms": (c2 - t0) * 1000.0,
-                 }
-             )
-
-         if len(timings) == 0:
-             print("[TRT][bench] Nothing to benchmark after warmup.")
-             return
-
-         def _agg(key: str) -> Tuple[float, float, float]:
-             vals = sorted([t[key] for t in timings])
-             n = len(vals)
-             p50 = vals[n // 2]
-             p95 = vals[min(n - 1, int(0.95 * (n - 1)))]
-             avg = sum(vals) / n
-             return avg, p50, p95
-
-         avg_c, p50_c, p95_c = _agg("t_coarse_ms")
-         avg_f, p50_f, p95_f = _agg("t_fine_ms")
-         avg_t, p50_t, p95_t = _agg("t_total_ms")
-
-         print("[TRT][bench] Results (ms):")
-         print(f"  Coarse avg={avg_c:.2f} p50={p50_c:.2f} p95={p95_c:.2f}")
-         print(f"  Fine avg={avg_f:.2f} p50={p50_f:.2f} p95={p95_f:.2f}")
-         print(f"  Total avg={avg_t:.2f} p50={p50_t:.2f} p95={p95_t:.2f}")
-         print(f"  Target < 1000 ms per 3000x4000 image: {'YES' if p50_t < 1000.0 else 'NO'}")
-
-         if args.bench_report_json:
-             import json
-             report = {
-                 "summary": {
-                     "avg_ms": avg_t,
-                     "p50_ms": p50_t,
-                     "p95_ms": p95_t,
-                     "avg_coarse_ms": avg_c,
-                     "avg_fine_ms": avg_f,
-                     "images": len(timings),
-                 },
-                 "per_image": timings,
-             }
-             with open(args.bench_report_json, "w") as f:
-                 json.dump(report, f, indent=2)
-         return
-
-     # Non-benchmark single/directory
-     assert (args.image != "") ^ (args.images_dir != ""), "Provide exactly one of --image or --images_dir"
-     if args.image:
-         infer_image_trt(coarse, fine, args.image, cfg, out_dir=args.out, save_prob=args.save_prob)
-         print("[TRT][infer] Done.")
-         return
-
-     img_dir = args.images_dir
-     assert Path(img_dir).is_dir()
-     Path(args.out).mkdir(parents=True, exist_ok=True)
-     img_files = sorted([p for p in os.listdir(img_dir) if p.lower().endswith((".jpg", ".jpeg"))])
-     assert len(img_files) > 0
-     for name in img_files:
-         p = str(Path(img_dir) / name)
-         infer_image_trt(coarse, fine, p, cfg, out_dir=args.out, save_prob=args.save_prob)
-     print("[TRT][infer] Done.")
-
-
- if __name__ == "__main__":
-     main()
train.py CHANGED
@@ -401,30 +401,6 @@ def main():
      print("[WireSegHR][train] Done.")
 
 
- def _sample_batch_same_size(
-     dset: WireSegDataset, batch_size: int
- ) -> Tuple[List[np.ndarray], List[np.ndarray]]:
-     # Use precomputed size bins to sample a batch from a single (H, W) bin
-     assert len(dset) > 0
-     bins = dset.size_bins
-     keys = list(bins.keys())
-     random.shuffle(keys)
-     chosen_key = None
-     for hw in keys:
-         if len(bins[hw]) >= batch_size:
-             chosen_key = hw
-             break
-     assert chosen_key is not None, f"No size bin with at least {batch_size} samples"
-     pool = bins[chosen_key]
-     idxs = np.random.choice(pool, size=batch_size, replace=False)
-     imgs: List[np.ndarray] = []
-     masks: List[np.ndarray] = []
-     for idx in idxs:
-         item = dset[int(idx)]
-         imgs.append(item["image"])
-         masks.append(item["mask"])
-     return imgs, masks
-
 
  def _prepare_batch(
      imgs: List[np.ndarray],