README and cleanup
Browse files- .windsurf/rules/defensive-logic.md +1 -1
- README.md +115 -15
- SEGMENTATION_PLAN.md +0 -136
- WireSegHR.pdf +0 -3
- scripts/drive-viewer-key-readme.md +16 -0
- scripts/export_onnx_trt.py +0 -167
- scripts/trt_infer.py +0 -488
- train.py +0 -24
.windsurf/rules/defensive-logic.md
CHANGED
@@ -5,4 +5,4 @@ trigger: always_on
|
|
5 |
When deciding if to write defensive logic, e.g. dimensionality handling: `tensor1=tensor1.unsqueeze() if tensor1.ndim==1 else tensor1`, or None handling: `var1=var1 if var1 else torch.zeros(...)`, just don't write these things. In my code, shapes are always static, and there is one execution path for all code. I prefer `assert` over defensive logic. If you are writing something to fix the tests and this seems necessary, it's likely that the tests are setup incorrectly.
|
6 |
The reason why I don't want it is because defensive logic leads to silent failures, and these are bad for debugging.
|
7 |
|
8 |
-
In addition, writing "int()" type casting is also a piece of defensive logic that slows down the application. Don't write it unless really necessary e.g. putting int to string. In most cases, indexing with a tensor should be better.
|
|
|
5 |
When deciding if to write defensive logic, e.g. dimensionality handling: `tensor1=tensor1.unsqueeze() if tensor1.ndim==1 else tensor1`, or None handling: `var1=var1 if var1 else torch.zeros(...)`, just don't write these things. In my code, shapes are always static, and there is one execution path for all code. I prefer `assert` over defensive logic. If you are writing something to fix the tests and this seems necessary, it's likely that the tests are setup incorrectly.
|
6 |
The reason why I don't want it is because defensive logic leads to silent failures, and these are bad for debugging.
|
7 |
|
8 |
+
In addition, writing "int()", "float()" or "bool()" type casting is also a piece of defensive logic that slows down the application. Don't write it unless really necessary e.g. putting int to string. In most cases, indexing with a tensor should be better.
|
README.md
CHANGED
@@ -1,30 +1,35 @@
|
|
1 |
# WireSegHR (Segmentation Only)
|
2 |
|
3 |
-
This repository contains the segmentation-only implementation
|
4 |
|
5 |
-
|
6 |
-
- Long-term navigation plan: `SEGMENTATION_PLAN.md`.
|
7 |
|
8 |
-
|
9 |
|
10 |
-
|
|
|
|
|
11 |
|
12 |
```bash
|
13 |
-
|
14 |
-
source .venv/bin/activate
|
15 |
-
pip install -r requirements.txt
|
16 |
```
|
|
|
17 |
|
18 |
-
|
19 |
|
20 |
```bash
|
21 |
-
|
22 |
-
|
23 |
```
|
24 |
|
25 |
-
|
26 |
-
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
## Notes
|
30 |
- This is a segmentation-only codebase. Inpainting is out of scope here.
|
@@ -41,6 +46,101 @@ python src/wireseghr/infer.py --config configs/default.yaml --image /path/to/ima
|
|
41 |
- `dataset/val/images/...` and `dataset/val/gts/...`
|
42 |
- `dataset/test/images/...` and `dataset/test/gts/...`
|
43 |
- Masks are binary: foreground = white (255), background = black (0).
|
44 |
-
- The loader strictly enforces numeric stems and 1:1 pairing and will
|
45 |
|
46 |
Update `configs/default.yaml` with your paths under `data.train_images`, `data.train_masks`, etc. Defaults point to `dataset/train/images`, `dataset/train/gts`, and validation to `dataset/val/...`.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
# WireSegHR (Segmentation Only)
|
2 |
|
3 |
+
This repository contains the segmentation-only implementation of the two-stage [WireSegHR model](https://arxiv.org/abs/2304.00221), training on the WireSegHR dataset plus the [TTPLA dataset](https://github.com/R3ab/ttpla_dataset).
|
4 |
|
5 |
+
## Quick Start
|
|
|
6 |
|
7 |
+
1) Get secrets necessary for fetching of the dataset:
|
8 |
|
9 |
+
You'll need a GDrive service account to fetch the WireSegHR dataset using scripts in this repo. Get a GDrive key as described [in this short README](scripts/drive-viewer-key-readme.md), and put it in `/secrets/drive-json.json`
|
10 |
+
|
11 |
+
2) Run:
|
12 |
|
13 |
```bash
|
14 |
+
scripts/setup.sh
|
|
|
|
|
15 |
```
|
16 |
+
This installs dependencies and merges the TTPLA dataset into the WireSegHR dataset format.
|
17 |
|
18 |
+
3) Train and run a quick inference check:
|
19 |
|
20 |
```bash
|
21 |
+
python3 train.py --config configs/default.yaml
|
22 |
+
python3 infer.py --config configs/default.yaml --image /path/to/image.jpg
|
23 |
```
|
24 |
|
25 |
+
The default config `default.yaml` is suitable for a 24GB VRAM GPU with support for bf16 (e.g., RTX 3090/4090).
|
26 |
+
<!-- For a quick RTX GPU setup, I recommend [vast.ai](https://cloud.vast.ai/?ref_id=162850) -->
|
27 |
+
|
28 |
+
## Project Overview
|
29 |
+
- Two-stage, global-to-local segmentation with a shared encoder and a fine decoder conditioned on the coarse stage.
|
30 |
+
- Full training loop with AMP (optional), Poly LR, periodic evaluation, checkpointing, and test visualizations (`train.py`).
|
31 |
+
- Dataset utilities under `src/wireseghr/data/` and model components under `src/wireseghr/model/`.
|
32 |
+
- Paper text and figures live in `paper-tex/` (`paper-tex/sections/` contains the Method, Results, etc.).
|
33 |
|
34 |
## Notes
|
35 |
- This is a segmentation-only codebase. Inpainting is out of scope here.
|
|
|
46 |
- `dataset/val/images/...` and `dataset/val/gts/...`
|
47 |
- `dataset/test/images/...` and `dataset/test/gts/...`
|
48 |
- Masks are binary: foreground = white (255), background = black (0).
|
49 |
+
- The loader strictly enforces numeric stems and 1:1 pairing of naming and will raise on file name mismatches.
|
50 |
|
51 |
Update `configs/default.yaml` with your paths under `data.train_images`, `data.train_masks`, etc. Defaults point to `dataset/train/images`, `dataset/train/gts`, and validation to `dataset/val/...`.
|
52 |
+
|
53 |
+
## Inference
|
54 |
+
|
55 |
+
- Single image (optionally save outputs to a directory):
|
56 |
+
|
57 |
+
```bash
|
58 |
+
python3 infer.py \
|
59 |
+
--config configs/default.yaml \
|
60 |
+
--ckpt ckpt_5000.pt \
|
61 |
+
--image dataset/test/images/123.jpg \
|
62 |
+
--out outputs/infer
|
63 |
+
```
|
64 |
+
|
65 |
+
- Compute metrics for a single image (requires a GT mask):
|
66 |
+
|
67 |
+
```bash
|
68 |
+
python3 infer.py \
|
69 |
+
--config configs/default.yaml \
|
70 |
+
--ckpt ckpt_5000.pt \
|
71 |
+
--image dataset/test/images/123.jpg \
|
72 |
+
--out outputs/infer \
|
73 |
+
--metrics \
|
74 |
+
--mask dataset/test/gts/123.png
|
75 |
+
```
|
76 |
+
|
77 |
+
- Run inference over the entire directory with metrics (images_dir sets the image directory, masks_dir sets the ground truth mask directory):
|
78 |
+
|
79 |
+
```bash
|
80 |
+
python3 infer.py \
|
81 |
+
--config configs/default.yaml \
|
82 |
+
--ckpt ckpt_5000.pt \
|
83 |
+
--images_dir dataset/test/images \
|
84 |
+
--out outputs/infer \
|
85 |
+
--metrics \
|
86 |
+
--masks_dir dataset/test/gts
|
87 |
+
```
|
88 |
+
|
89 |
+
Notes:
|
90 |
+
- Predictions are saved as 0/255 PNGs. For metrics, predictions are binarized with `> 0` to match training logic.
|
91 |
+
- Masks are matched by filename stem: `images/123.jpg` ↔ `gts/123.png`.
|
92 |
+
|
93 |
+
## Benchmarking and Metrics
|
94 |
+
|
95 |
+
Benchmark mode times the model on a directory of images and reports coarse/fine/total latency statistics. When `--metrics` is provided, it also computes IoU/F1/Precision/Recall over the benchmark set (both fine and coarse outputs).
|
96 |
+
|
97 |
+
Example (uses `data.test_images` and `data.test_masks` from the config by default):
|
98 |
+
|
99 |
+
```bash
|
100 |
+
python3 infer.py \
|
101 |
+
--config configs/default.yaml \
|
102 |
+
--benchmark \
|
103 |
+
--ckpt ckpt_5000.pt \
|
104 |
+
--bench_warmup 2 \
|
105 |
+
--bench_limit 0 \
|
106 |
+
--bench_report_json outputs/bench_report.json \
|
107 |
+
--metrics
|
108 |
+
```
|
109 |
+
|
110 |
+
If your ground truth directory is different from `data.test_masks`, please override it with `--bench_masks_dir`:
|
111 |
+
|
112 |
+
```bash
|
113 |
+
python3 infer.py \
|
114 |
+
--config configs/default.yaml \
|
115 |
+
--benchmark \
|
116 |
+
--ckpt ckpt_5000.pt \
|
117 |
+
--bench_warmup 2 \
|
118 |
+
--bench_limit 0 \
|
119 |
+
--bench_report_json outputs/bench_report.json \
|
120 |
+
--metrics \
|
121 |
+
--bench_masks_dir /path/to/gts
|
122 |
+
```
|
123 |
+
|
124 |
+
You will see output like:
|
125 |
+
|
126 |
+
```
|
127 |
+
[WireSegHR][bench] Results (ms):
|
128 |
+
Coarse avg=50.16 p50=44.48 p95=76.78
|
129 |
+
Fine avg=534.38 p50=419.52 p95=1187.66
|
130 |
+
Total avg=584.54 p50=464.73 p95=1300.07
|
131 |
+
Target < 1000 ms per 3000x4000 image: YES
|
132 |
+
[WireSegHR][bench][Fine] IoU=0.6098 F1=0.7576 P=0.6418 R=0.9244
|
133 |
+
[WireSegHR][bench][Coarse] IoU=0.5315 F1=0.6941 P=0.5467 R=0.9502
|
134 |
+
```
|
135 |
+
**These metrics were obtained after 5000 iterations*
|
136 |
+
|
137 |
+
Optional: you can save a JSON timing report with `--bench_report_json`. Schema:
|
138 |
+
- `summary`
|
139 |
+
- `avg_ms`, `p50_ms`, `p95_ms`
|
140 |
+
- `avg_coarse_ms`, `avg_fine_ms`
|
141 |
+
- `images`
|
142 |
+
- `per_image`: list of objects with
|
143 |
+
- `path`, `H`, `W`, `t_coarse_ms`, `t_fine_ms`, `t_total_ms`
|
144 |
+
|
145 |
+
Utils:
|
146 |
+
- Export your model to inference-only weights by scripts/strip_checkpoint.py
|
SEGMENTATION_PLAN.md
DELETED
@@ -1,136 +0,0 @@
|
|
1 |
-
# WireSegHR Segmentation-Only Implementation Plan
|
2 |
-
|
3 |
-
This plan distills the model and pipeline described in the paper sources:
|
4 |
-
- `paper-tex/sections/method.tex`
|
5 |
-
- `paper-tex/sections/method_yq.tex`
|
6 |
-
- `paper-tex/figure_tex/pipeline.tex`
|
7 |
-
- `paper-tex/tables/{component,logit,thresholds}.tex`
|
8 |
-
|
9 |
-
Focus: segmentation only (no dataset collection or inpainting).
|
10 |
-
|
11 |
-
## Decisions and Defaults (locked)
|
12 |
-
- Backbone: SegFormer MiT-B3 via HuggingFace Transformers.
|
13 |
-
- Fine/local patch size p: 768.
|
14 |
-
- Conditioning: global map + binary location mask by default (Table `tables/logit.tex`).
|
15 |
-
- Conditioning map scope: patch-cropped from the global map per `paper-tex/sections/method_yq.tex` (no full-image concatenation variant).
|
16 |
-
- MinMax feature augmentation: luminance min and max with a fixed 6×6 window; channels concatenated to inputs (Figure `figure_tex/pipeline.tex`, Sec. “Wire Feature Preservation” in `method_yq.tex`).
|
17 |
-
- Loss: CE on both branches, λ = 1 (`method_yq.tex`).
|
18 |
-
- α-threshold for refining windows: default 0.01 (Table `tables/thresholds.tex`).
|
19 |
-
- Coarse input size: train 512×512; test 1024×1024 (`method.tex`).
|
20 |
-
- Optim: AdamW (lr=6e-5, wd=0.01, poly schedule with power=1), ~40k iters, batch size ~8 (`method.tex`).
|
21 |
-
|
22 |
-
## Project Structure
|
23 |
-
- `configs/`
|
24 |
-
- `default.yaml` (backbone=mit_b2, p=768, coarse_train=512, coarse_test=1024, alpha=0.01, minmax=true, kernel=6, maxpool_label=true, cond_variant=global)
|
25 |
-
- `src/wireseghr/`
|
26 |
-
- `model/`
|
27 |
-
- `encoder.py` (SegFormer MiT-B3, N_in channels expansion)
|
28 |
-
- `decoder.py` (two MLP decoders `D_C`, `D_F` for 2 classes)
|
29 |
-
- `condition.py` (1×1 conv to collapse coarse 2-ch logits → 1-ch cond)
|
30 |
-
- `minmax.py` (6×6 luminance min/max filtering)
|
31 |
-
- `label_downsample.py` (MaxPool-based coarse GT downsampling)
|
32 |
-
- `data/`
|
33 |
-
- `dataset.py` (image/mask loading, full-res to coarse/fine inputs)
|
34 |
-
- `sampler.py` (balanced patch sampling with ≥1% wire pixels)
|
35 |
-
- `transforms.py` (scaling, rotation, flip, photometric distortion)
|
36 |
-
- `train.py` (end-to-end two-branch training)
|
37 |
-
- `infer.py` (coarse-to-fine sliding-window inference + stitching)
|
38 |
-
- `metrics.py` (IoU, F1, Precision, Recall)
|
39 |
-
- `utils.py` (misc: overlap blending, seeding, logging)
|
40 |
-
- `tests/` (unit tests for channel wiring, cond alignment, stitching)
|
41 |
-
- `README.md` (segmentation-only usage)
|
42 |
-
|
43 |
-
## Model Specification
|
44 |
-
- Shared encoder `E`: SegFormer MiT-B3 (HF Transformers preferred).
|
45 |
-
- Input channels (default): 3 (RGB) + 2 (MinMax) + 1 (global cond) + 1 (binary location) = 7.
|
46 |
-
- For the coarse pass, the cond and location channels are zeros to keep channel count consistent (`method_yq.tex`).
|
47 |
-
- Weight init for extra channels: copy mean of RGB conv weights or zero-init.
|
48 |
-
- Decoders: two SegFormer MLP decoders
|
49 |
-
- `D_C`: coarse logits (2 channels) at coarse resolution.
|
50 |
-
- `D_F`: fine logits (2 channels) at patch resolution p×p.
|
51 |
-
- Conditioning to fine branch (default):
|
52 |
-
- Take coarse pre-softmax logits (2-ch), apply 1×1 conv → 1-ch cond map (`method.tex`).
|
53 |
-
- Binary location mask: 1 inside current patch region (in full-image coordinates), 0 elsewhere.
|
54 |
-
- Pass patch-aligned cond crop and binary mask as channels to the fine branch input.
|
55 |
-
- Notes:
|
56 |
-
- We follow the published version (`paper-tex/sections/method_yq.tex`) and use patch-cropped conditioning exclusively; no full-image conditioning variant will be implemented.
|
57 |
-
|
58 |
-
## Data and Preprocessing
|
59 |
-
- MinMax luminance features (both branches):
|
60 |
-
- Y = 0.299R + 0.587G + 0.114B.
|
61 |
-
- Y_min = min filter (6×6), Y_max = max filter (6×6).
|
62 |
-
- Concat [Y_min, Y_max] to the input image channels.
|
63 |
-
- Coarse GT label generation (MaxPool):
|
64 |
-
- Downsample full-res mask to coarse size with max-pooling to prevent wire vanishing (`method_yq.tex`).
|
65 |
-
- Normalization: standard mean/std per backbone; apply consistently across channels (new channels can be mean=0, std=1 by convention, or min-max scaled).
|
66 |
-
|
67 |
-
### Dataset Convention (project-specific)
|
68 |
-
- Flat directories with numeric filenames; images are `.jpg`/`.jpeg`, masks are `.png`.
|
69 |
-
- Example:
|
70 |
-
- `dataset/images/1.jpg, 2.jpg, ..., N.jpg` (or `.jpeg`)
|
71 |
-
- `dataset/gts/1.png, 2.png, ..., N.png`
|
72 |
-
- Masks are binary: foreground = white (255), background = black (0).
|
73 |
-
- The loader (`data/dataset.py`) strictly enforces numeric stems and 1:1 pairing and will assert on mismatch.
|
74 |
-
|
75 |
-
## Training Pipeline
|
76 |
-
- Augment the full-res image (scaling, rotation, horizontal flip, photometric distortion) before constructing coarse/fine inputs (`method.tex`).
|
77 |
-
- Coarse input: downsample augmented full image to 512×512; build channels [RGB+MinMax+zeros(2)] → `E` → `D_C`.
|
78 |
-
- Fine input (per iteration select 1–k patches):
|
79 |
-
- Sample p×p patch (p=768) with ≥1% wire pixels (`method.tex`, `method_yq.tex`).
|
80 |
-
- Build cond map from coarse logits via 1×1 conv; crop cond to patch region.
|
81 |
-
- Build binary location mask for patch region.
|
82 |
-
- Build channels [RGB + MinMax + cond + location] → `E` → `D_F`.
|
83 |
-
- Losses:
|
84 |
-
- L_glo = CE(Softmax(`D_C(E(coarse))`), G_glo), where G_glo uses MaxPool downsample.
|
85 |
-
- L_loc = CE(Softmax(`D_F(E(fine))`), G_loc).
|
86 |
-
- L = L_glo + λ L_loc, λ=1 (`method_yq.tex`).
|
87 |
-
- Optimization:
|
88 |
-
- AdamW (lr=6e-5, wd=0.01), poly schedule (power=1.0), ~40k iterations, batch ≈8 (tune by memory).
|
89 |
-
- AMP and grad accumulation recommended for stability/memory.
|
90 |
-
|
91 |
-
## Inference Pipeline
|
92 |
-
- Coarse pass:
|
93 |
-
- Downsample to 1024×1024; predict coarse probability/logits.
|
94 |
-
- Window proposal (sliding window on full-res):
|
95 |
-
- Tile with patch size p=768. Overlap ~128px (configurable). Compute wire fraction within each window from coarse prediction (prob>0.5).
|
96 |
-
- If fraction ≥ α (default 0.01), run fine refinement on that patch; else skip (Table `tables/thresholds.tex`).
|
97 |
-
- Fine refinement + stitching:
|
98 |
-
- For selected windows, build fine input with cond crop + location mask; predict logits.
|
99 |
-
- Stitch logits into full-res canvas; average in overlaps; final argmax over classes.
|
100 |
-
- Outputs: full-res binary mask, plus optional probability map.
|
101 |
-
|
102 |
-
## Metrics and Reporting
|
103 |
-
- Implement: IoU, F1, Precision, Recall (global, and optionally per-size bins if available) matching `tables/component.tex`.
|
104 |
-
- Validate α trade-offs following `tables/thresholds.tex`.
|
105 |
-
|
106 |
-
## Configuration Surface (key)
|
107 |
-
- Backbone/weights: `mit_b2` (pretrained ImageNet-1K).
|
108 |
-
- Sizes: `p=768`, `coarse_train=512`, `coarse_test=1024`, `overlap=128`.
|
109 |
-
- Conditioning: `cond_from='coarse_logits_1x1'`, `cond_crop='patch'`.
|
110 |
-
- MinMax: `enable=true`, `kernel=6`.
|
111 |
-
- Label: `coarse_label_downsample='maxpool'`.
|
112 |
-
- Training: `iters=40000`, `batch=8`, `lr=6e-5`, `wd=0.01`, `schedule='poly'`, `power=1.0`.
|
113 |
-
- Inference: `alpha=0.01`, `prob_threshold=0.5` for wire fraction, `stitch='avg_logits'`.
|
114 |
-
|
115 |
-
## Risks / Gotchas
|
116 |
-
- Channel expansion requires careful initialization; confirm no NaNs and stable early training.
|
117 |
-
- Precise spatial alignment of cond and location mask with the patch is critical. Add assertions/tests.
|
118 |
-
- Even-sized MinMax window (6×6) requires careful padding to maintain alignment.
|
119 |
-
- Memory with p=768 and MiT-B3 may need tuning (AMP, batch size, overlap).
|
120 |
-
|
121 |
-
## Milestones
|
122 |
-
1) Skeleton + configs + metrics.
|
123 |
-
2) Encoder channel expansion + two decoders + 1×1 cond.
|
124 |
-
3) MinMax (6×6) + MaxPool label downsampling.
|
125 |
-
4) Training loop with ≥1% wire patch sampling.
|
126 |
-
5) Inference α-threshold + stitching.
|
127 |
-
6) Ablations toggles + scripts + README.
|
128 |
-
7) Tests (channel wiring, cond/mask alignment, stitching correctness).
|
129 |
-
|
130 |
-
## References (paper sources)
|
131 |
-
- `paper-tex/sections/method.tex`: Two-stage design, shared encoder, 1×1 cond, training/inference sizes, optimizer/schedule.
|
132 |
-
- `paper-tex/sections/method_yq.tex`: CE losses, λ, sliding-window with α, MinMax & MaxPool rationale.
|
133 |
-
- `paper-tex/figure_tex/pipeline.tex`: System overview; MinMax concatenation.
|
134 |
-
- `paper-tex/tables/component.tex`: Ablation of MinMax/MaxPool/coarse.
|
135 |
-
- `paper-tex/tables/logit.tex`: Conditioning variants.
|
136 |
-
- `paper-tex/tables/thresholds.tex`: α vs speed/quality.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
WireSegHR.pdf
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:7f6db9a06575398aeb0903c8d19e68f27d983223ca128ff3a3ae12a8aeb8f4a9
|
3 |
-
size 17039360
|
|
|
|
|
|
|
|
scripts/drive-viewer-key-readme.md
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Setting Up Google Drive Access via Service Account (PyDrive2)
|
2 |
+
|
3 |
+
You'll need a service account to fetch the WireSegHR dataset using a script in this folder. Follow the steps below:
|
4 |
+
|
5 |
+
1. **Create a Service Account**
|
6 |
+
- Navigate to the Google Cloud Console. Under **IAM & Admin → Service Accounts**, create a new service account.
|
7 |
+
- Assign it a Viewer role.
|
8 |
+
|
9 |
+
2. **Generate and Download JSON Key**
|
10 |
+
- In the service account details, go to **Keys → Add Key → Create new key**, choose **JSON**, and download the key file.
|
11 |
+
- Save this file locally to this repo as `secrets/drive-json.json`.
|
12 |
+
|
13 |
+
3. **Share the Drive Folder or Files**
|
14 |
+
- Grant the service account access to the target Drive folder - https://drive.google.com/drive/folders/1fgy3wn_yuHEeMNbfiHNVl1-jEdYOfu6p - using its service account email. Go to the folder in Google Drive, click on the share button, and add the service account email with Viewer permissions.
|
15 |
+
|
16 |
+
|
scripts/export_onnx_trt.py
DELETED
@@ -1,167 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
import os
|
3 |
-
import pprint
|
4 |
-
import shutil
|
5 |
-
import subprocess
|
6 |
-
from typing import Tuple
|
7 |
-
|
8 |
-
import torch
|
9 |
-
import tensorrt as trt
|
10 |
-
|
11 |
-
from src.wireseghr.model import WireSegHR
|
12 |
-
from pathlib import Path
|
13 |
-
|
14 |
-
|
15 |
-
class CoarseModule(torch.nn.Module):
|
16 |
-
def __init__(self, core: WireSegHR):
|
17 |
-
super().__init__()
|
18 |
-
self.core = core
|
19 |
-
|
20 |
-
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
21 |
-
logits, cond = self.core.forward_coarse(x)
|
22 |
-
return logits, cond
|
23 |
-
|
24 |
-
|
25 |
-
class FineModule(torch.nn.Module):
|
26 |
-
def __init__(self, core: WireSegHR):
|
27 |
-
super().__init__()
|
28 |
-
self.core = core
|
29 |
-
|
30 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
31 |
-
logits = self.core.forward_fine(x)
|
32 |
-
return logits
|
33 |
-
|
34 |
-
|
35 |
-
def build_model(cfg: dict, device: torch.device) -> WireSegHR:
|
36 |
-
pretrained_flag = bool(cfg.get("pretrained", False))
|
37 |
-
model = WireSegHR(backbone=cfg["backbone"], in_channels=6, pretrained=pretrained_flag)
|
38 |
-
model = model.to(device)
|
39 |
-
return model
|
40 |
-
|
41 |
-
|
42 |
-
def main():
|
43 |
-
parser = argparse.ArgumentParser(description="Export WireSegHR to ONNX and TensorRT")
|
44 |
-
parser.add_argument("--config", type=str, default="configs/default.yaml")
|
45 |
-
parser.add_argument("--ckpt", type=str, default="", help="Path to checkpoint .pt")
|
46 |
-
parser.add_argument("--out_dir", type=str, default="exports")
|
47 |
-
parser.add_argument("--coarse_size", type=int, default=1024)
|
48 |
-
parser.add_argument("--fine_patch_size", type=int, default=1024)
|
49 |
-
parser.add_argument("--opset", type=int, default=17)
|
50 |
-
parser.add_argument("--trtexec", type=str, default="", help="Optional path to trtexec to build TRT engines")
|
51 |
-
parser.add_argument("--build_trt", action="store_true", help="Build TensorRT engines after ONNX export")
|
52 |
-
|
53 |
-
args = parser.parse_args()
|
54 |
-
|
55 |
-
import yaml
|
56 |
-
|
57 |
-
with open(args.config, "r") as f:
|
58 |
-
cfg = yaml.safe_load(f)
|
59 |
-
print("[export] Loaded config:")
|
60 |
-
pprint.pprint(cfg)
|
61 |
-
|
62 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
63 |
-
model = build_model(cfg, device)
|
64 |
-
|
65 |
-
ckpt_path = args.ckpt if args.ckpt else cfg.get("resume", "")
|
66 |
-
if ckpt_path:
|
67 |
-
assert Path(ckpt_path).is_file(), f"Checkpoint not found: {ckpt_path}"
|
68 |
-
print(f"[export] Loading checkpoint: {ckpt_path}")
|
69 |
-
state = torch.load(ckpt_path, map_location=device)
|
70 |
-
model.load_state_dict(state["model"]) # expects dict with key 'model'
|
71 |
-
model.eval()
|
72 |
-
|
73 |
-
Path(args.out_dir).mkdir(parents=True, exist_ok=True)
|
74 |
-
|
75 |
-
# Prepare dummy inputs (static shapes for best TRT performance)
|
76 |
-
coarse_in = torch.randn(1, 6, args.coarse_size, args.coarse_size, device=device)
|
77 |
-
fine_in = torch.randn(1, 6, args.fine_patch_size, args.fine_patch_size, device=device)
|
78 |
-
|
79 |
-
# Coarse export
|
80 |
-
coarse_wrapper = CoarseModule(model).to(device).eval()
|
81 |
-
coarse_onnx = Path(args.out_dir) / f"wireseghr_coarse_{args.coarse_size}.onnx"
|
82 |
-
print(f"[export] Exporting COARSE to {coarse_onnx}")
|
83 |
-
torch.onnx.export(
|
84 |
-
coarse_wrapper,
|
85 |
-
coarse_in,
|
86 |
-
str(coarse_onnx),
|
87 |
-
export_params=True,
|
88 |
-
opset_version=args.opset,
|
89 |
-
do_constant_folding=True,
|
90 |
-
input_names=["x_coarse"],
|
91 |
-
output_names=["logits", "cond"],
|
92 |
-
dynamic_axes=None,
|
93 |
-
dynamo=True
|
94 |
-
)
|
95 |
-
|
96 |
-
# Fine export
|
97 |
-
fine_wrapper = FineModule(model).to(device).eval()
|
98 |
-
fine_onnx = Path(args.out_dir) / f"wireseghr_fine_{args.fine_patch_size}.onnx"
|
99 |
-
print(f"[export] Exporting FINE to {fine_onnx}")
|
100 |
-
torch.onnx.export(
|
101 |
-
fine_wrapper,
|
102 |
-
fine_in,
|
103 |
-
str(fine_onnx),
|
104 |
-
export_params=True,
|
105 |
-
opset_version=args.opset,
|
106 |
-
do_constant_folding=True,
|
107 |
-
input_names=["x_fine"],
|
108 |
-
output_names=["logits"],
|
109 |
-
dynamic_axes=None,
|
110 |
-
)
|
111 |
-
|
112 |
-
# Optional TensorRT building via trtexec; fallback to Python API if unavailable
|
113 |
-
if args.build_trt:
|
114 |
-
trtexec_path = args.trtexec if args.trtexec else shutil.which("trtexec")
|
115 |
-
coarse_engine = Path(args.out_dir) / f"wireseghr_coarse_{args.coarse_size}.engine"
|
116 |
-
fine_engine = Path(args.out_dir) / f"wireseghr_fine_{args.fine_patch_size}.engine"
|
117 |
-
if trtexec_path:
|
118 |
-
def build_engine_cli(onnx_path: str, engine_path: str):
|
119 |
-
print(f"[export] Building TRT engine (trtexec): {engine_path}")
|
120 |
-
cmd = [
|
121 |
-
trtexec_path,
|
122 |
-
f"--onnx={onnx_path}",
|
123 |
-
f"--saveEngine={engine_path}",
|
124 |
-
"--explicitBatch",
|
125 |
-
"--fp16",
|
126 |
-
]
|
127 |
-
subprocess.run(cmd, check=True)
|
128 |
-
|
129 |
-
build_engine_cli(str(coarse_onnx), str(coarse_engine))
|
130 |
-
build_engine_cli(str(fine_onnx), str(fine_engine))
|
131 |
-
else:
|
132 |
-
print("[export] trtexec not found; building engines via TensorRT Python API")
|
133 |
-
|
134 |
-
def build_engine_py(onnx_path: str, engine_path: str):
|
135 |
-
logger = trt.Logger(trt.Logger.WARNING)
|
136 |
-
builder = trt.Builder(logger)
|
137 |
-
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
|
138 |
-
parser = trt.OnnxParser(network, logger)
|
139 |
-
with open(str(onnx_path), "rb") as f:
|
140 |
-
data = f.read()
|
141 |
-
ok = parser.parse(data)
|
142 |
-
if not ok:
|
143 |
-
for i in range(parser.num_errors):
|
144 |
-
print(f"[TRT][parser] {parser.get_error(i)}")
|
145 |
-
raise RuntimeError("ONNX parse failed")
|
146 |
-
|
147 |
-
config = builder.create_builder_config()
|
148 |
-
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)
|
149 |
-
if builder.platform_has_fast_fp16:
|
150 |
-
config.set_flag(trt.BuilderFlag.FP16)
|
151 |
-
|
152 |
-
print(f"[export] Building TRT engine (Python): {engine_path}")
|
153 |
-
serialized = builder.build_serialized_network(network, config)
|
154 |
-
assert serialized is not None, "Failed to build TensorRT engine"
|
155 |
-
with open(str(engine_path), "wb") as f:
|
156 |
-
f.write(serialized)
|
157 |
-
|
158 |
-
build_engine_py(coarse_onnx, coarse_engine)
|
159 |
-
build_engine_py(fine_onnx, fine_engine)
|
160 |
-
else:
|
161 |
-
print("[export] Skipping TensorRT engine build (use --build_trt to enable)")
|
162 |
-
|
163 |
-
print("[export] Done.")
|
164 |
-
|
165 |
-
|
166 |
-
if __name__ == "__main__":
|
167 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/trt_infer.py
DELETED
@@ -1,488 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
import os
|
3 |
-
import pprint
|
4 |
-
import time
|
5 |
-
from typing import List, Tuple, Optional, Dict, Any
|
6 |
-
|
7 |
-
import numpy as np
|
8 |
-
import cv2
|
9 |
-
|
10 |
-
# TensorRT + CUDA
|
11 |
-
try:
|
12 |
-
import tensorrt as trt
|
13 |
-
import pycuda.autoinit # noqa: F401 # initializes CUDA driver context
|
14 |
-
import pycuda.driver as cuda
|
15 |
-
except Exception as e: # pragma: no cover
|
16 |
-
raise RuntimeError(
|
17 |
-
"TensorRT runner requires 'tensorrt' and 'pycuda' Python packages and a valid CUDA/TensorRT install"
|
18 |
-
) from e
|
19 |
-
|
20 |
-
import yaml
|
21 |
-
from pathlib import Path
|
22 |
-
|
23 |
-
|
24 |
-
# ---- Utility: TRT engine wrapper ----
|
25 |
-
class TrtEngine:
|
26 |
-
def __init__(self, engine_path: str):
|
27 |
-
assert Path(engine_path).is_file(), f"Engine not found: {engine_path}"
|
28 |
-
logger = trt.Logger(trt.Logger.ERROR)
|
29 |
-
with open(engine_path, 'rb') as f, trt.Runtime(logger) as runtime:
|
30 |
-
self.engine = runtime.deserialize_cuda_engine(f.read())
|
31 |
-
assert self.engine is not None, f"Failed to load engine: {engine_path}"
|
32 |
-
self.context = self.engine.create_execution_context()
|
33 |
-
# Default to profile 0
|
34 |
-
try:
|
35 |
-
self.context.active_optimization_profile = 0
|
36 |
-
except Exception:
|
37 |
-
pass
|
38 |
-
self.stream = cuda.Stream()
|
39 |
-
self.bindings: List[int] = [0] * self.engine.num_bindings
|
40 |
-
self.host_mem: Dict[int, Any] = {}
|
41 |
-
self.device_mem: Dict[int, Any] = {}
|
42 |
-
|
43 |
-
def _allocate_binding(self, idx: int, shape: Tuple[int, ...]):
|
44 |
-
dtype = trt.nptype(self.engine.get_binding_dtype(idx))
|
45 |
-
nbytes = int(np.prod(shape)) * np.dtype(dtype).itemsize
|
46 |
-
if idx in self.device_mem:
|
47 |
-
# Reuse if same size; else reallocate
|
48 |
-
old = self.device_mem[idx]
|
49 |
-
if old.size >= nbytes:
|
50 |
-
self.host_mem[idx] = np.empty(shape, dtype=dtype)
|
51 |
-
return
|
52 |
-
# free old and reallocate
|
53 |
-
del old
|
54 |
-
self.host_mem[idx] = np.empty(shape, dtype=dtype)
|
55 |
-
self.device_mem[idx] = cuda.mem_alloc(nbytes)
|
56 |
-
|
57 |
-
def infer(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
|
58 |
-
# Map names -> indices
|
59 |
-
name_to_idx = {self.engine.get_binding_name(i): i for i in range(self.engine.num_bindings)}
|
60 |
-
# Set input shapes (handle dynamic batch if present)
|
61 |
-
for name, arr in inputs.items():
|
62 |
-
idx = name_to_idx[name]
|
63 |
-
assert self.engine.binding_is_input(idx)
|
64 |
-
# Set shape if dynamic
|
65 |
-
shape = tuple(arr.shape)
|
66 |
-
try:
|
67 |
-
self.context.set_binding_shape(idx, shape)
|
68 |
-
except Exception:
|
69 |
-
# Static shape engines won't allow setting; assert it matches
|
70 |
-
eshape = tuple(self.engine.get_binding_shape(idx))
|
71 |
-
assert eshape == shape, f"Static engine expects {eshape}, got {shape} for input {name}"
|
72 |
-
self._allocate_binding(idx, shape)
|
73 |
-
|
74 |
-
# Allocate outputs for resolved shapes
|
75 |
-
for i in range(self.engine.num_bindings):
|
76 |
-
if not self.engine.binding_is_input(i):
|
77 |
-
shape = tuple(self.context.get_binding_shape(i))
|
78 |
-
assert all(s > 0 for s in shape), f"Unresolved output shape at binding {i}: {shape}"
|
79 |
-
self._allocate_binding(i, shape)
|
80 |
-
|
81 |
-
# Copy inputs H2D
|
82 |
-
for name, arr in inputs.items():
|
83 |
-
idx = name_to_idx[name]
|
84 |
-
h_arr = self.host_mem[idx]
|
85 |
-
assert h_arr.shape == arr.shape and h_arr.dtype == arr.dtype
|
86 |
-
cuda.memcpy_htod_async(self.device_mem[idx], arr, self.stream)
|
87 |
-
self.bindings[idx] = int(self.device_mem[idx])
|
88 |
-
|
89 |
-
# Set output bindings
|
90 |
-
for i in range(self.engine.num_bindings):
|
91 |
-
if not self.engine.binding_is_input(i):
|
92 |
-
self.bindings[i] = int(self.device_mem[i])
|
93 |
-
|
94 |
-
# Execute
|
95 |
-
self.context.execute_async_v2(self.bindings, self.stream.handle)
|
96 |
-
|
97 |
-
# D2H outputs
|
98 |
-
outputs: Dict[str, np.ndarray] = {}
|
99 |
-
for i in range(self.engine.num_bindings):
|
100 |
-
if not self.engine.binding_is_input(i):
|
101 |
-
name = self.engine.get_binding_name(i)
|
102 |
-
h_arr = self.host_mem[i]
|
103 |
-
cuda.memcpy_dtoh_async(h_arr, self.device_mem[i], self.stream)
|
104 |
-
outputs[name] = h_arr
|
105 |
-
|
106 |
-
self.stream.synchronize()
|
107 |
-
# Return copies to detach from internal buffers
|
108 |
-
return {k: np.array(v) for k, v in outputs.items()}
|
109 |
-
|
110 |
-
|
111 |
-
# ---- Pre/post processing consistent with infer.py ----
|
112 |
-
|
113 |
-
def _pad_for_minmax(kernel: int) -> Tuple[int, int, int, int]:
|
114 |
-
if (kernel % 2) == 0:
|
115 |
-
return (kernel // 2 - 1, kernel // 2, kernel // 2 - 1, kernel // 2)
|
116 |
-
else:
|
117 |
-
return (kernel // 2, kernel // 2, kernel // 2, kernel // 2)
|
118 |
-
|
119 |
-
|
120 |
-
def _build_6ch_coarse(rgb: np.ndarray, coarse_size: int, minmax_enable: bool, minmax_kernel: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
121 |
-
# rgb: HxWx3 float32 [0,1]
|
122 |
-
H, W = int(rgb.shape[0]), int(rgb.shape[1])
|
123 |
-
# To match training/minmax in torch, we replicate exact pad + pool logic via torch CPU
|
124 |
-
import torch
|
125 |
-
import torch.nn.functional as F
|
126 |
-
|
127 |
-
t_img = torch.from_numpy(rgb.transpose(2, 0, 1)).unsqueeze(0).float() # 1x3xHxW
|
128 |
-
y_t = 0.299 * t_img[:, 0:1] + 0.587 * t_img[:, 1:2] + 0.114 * t_img[:, 2:3]
|
129 |
-
if minmax_enable:
|
130 |
-
pad = _pad_for_minmax(minmax_kernel)
|
131 |
-
y_p = F.pad(y_t, pad, mode="replicate")
|
132 |
-
y_max_full = F.max_pool2d(y_p, kernel_size=minmax_kernel, stride=1)
|
133 |
-
y_min_full = -F.max_pool2d(-y_p, kernel_size=minmax_kernel, stride=1)
|
134 |
-
else:
|
135 |
-
y_min_full = y_t
|
136 |
-
y_max_full = y_t
|
137 |
-
|
138 |
-
# Resize for coarse
|
139 |
-
rgb_c = cv2.resize(rgb, (coarse_size, coarse_size), interpolation=cv2.INTER_LINEAR)
|
140 |
-
y_min_c = cv2.resize(y_min_full[0, 0].numpy(), (coarse_size, coarse_size), interpolation=cv2.INTER_LINEAR)
|
141 |
-
y_max_c = cv2.resize(y_max_full[0, 0].numpy(), (coarse_size, coarse_size), interpolation=cv2.INTER_LINEAR)
|
142 |
-
|
143 |
-
zeros_c = np.zeros((coarse_size, coarse_size), dtype=np.float32)
|
144 |
-
x6 = np.stack([
|
145 |
-
rgb_c[:, :, 0], rgb_c[:, :, 1], rgb_c[:, :, 2], y_min_c, y_max_c, zeros_c
|
146 |
-
], axis=0) # 6xHc x Wc
|
147 |
-
return x6.astype(np.float32), y_min_full[0, 0].numpy().astype(np.float32), y_max_full[0, 0].numpy().astype(np.float32), t_img.numpy().astype(np.float32)
|
148 |
-
|
149 |
-
|
150 |
-
def _softmax_channel(x: np.ndarray, axis: int = 1) -> np.ndarray:
|
151 |
-
x_max = np.max(x, axis=axis, keepdims=True)
|
152 |
-
e = np.exp(x - x_max)
|
153 |
-
return e / np.sum(e, axis=axis, keepdims=True)
|
154 |
-
|
155 |
-
|
156 |
-
def _tiled_fine_trt(
|
157 |
-
fine: TrtEngine,
|
158 |
-
t_img: np.ndarray, # 1x3xHxW float32
|
159 |
-
cond_map: np.ndarray, # 1x1xhxw float32
|
160 |
-
y_min_full: np.ndarray, # HxW float32
|
161 |
-
y_max_full: np.ndarray, # HxW float32
|
162 |
-
patch_size: int,
|
163 |
-
overlap: int,
|
164 |
-
fine_batch: int,
|
165 |
-
) -> np.ndarray:
|
166 |
-
H, W = int(t_img.shape[2]), int(t_img.shape[3])
|
167 |
-
P = patch_size
|
168 |
-
stride = P - overlap
|
169 |
-
assert stride > 0
|
170 |
-
assert H >= P and W >= P
|
171 |
-
|
172 |
-
prob_sum = np.zeros((H, W), dtype=np.float32)
|
173 |
-
weight = np.zeros((H, W), dtype=np.float32)
|
174 |
-
|
175 |
-
hc4, wc4 = int(cond_map.shape[2]), int(cond_map.shape[3])
|
176 |
-
|
177 |
-
ys = list(range(0, H - P + 1, stride))
|
178 |
-
if ys[-1] != (H - P):
|
179 |
-
ys.append(H - P)
|
180 |
-
xs = list(range(0, W - P + 1, stride))
|
181 |
-
if xs[-1] != (W - P):
|
182 |
-
xs.append(W - P)
|
183 |
-
|
184 |
-
coords: List[Tuple[int, int]] = [(y0, x0) for y0 in ys for x0 in xs]
|
185 |
-
|
186 |
-
# Run with batches supported by engine if dynamic; otherwise enforce 1
|
187 |
-
input_name = None
|
188 |
-
for i in range(fine.engine.num_bindings):
|
189 |
-
if fine.engine.binding_is_input(i):
|
190 |
-
input_name = fine.engine.get_binding_name(i)
|
191 |
-
shape_decl = fine.engine.get_binding_shape(i)
|
192 |
-
break
|
193 |
-
assert input_name is not None
|
194 |
-
|
195 |
-
dynamic_batch = -1 in list(shape_decl)
|
196 |
-
batch_allowed = fine_batch if dynamic_batch else 1
|
197 |
-
|
198 |
-
for i0 in range(0, len(coords), batch_allowed):
|
199 |
-
batch_coords = coords[i0 : i0 + batch_allowed]
|
200 |
-
B = len(batch_coords)
|
201 |
-
xs_list: List[np.ndarray] = []
|
202 |
-
for (y0, x0) in batch_coords:
|
203 |
-
y1, x1 = y0 + P, x0 + P
|
204 |
-
y0c = (y0 * hc4) // H
|
205 |
-
y1c = ((y1 * hc4) + H - 1) // H
|
206 |
-
x0c = (x0 * wc4) // W
|
207 |
-
x1c = ((x1 * wc4) + W - 1) // W
|
208 |
-
cond_sub = cond_map[:, :, y0c:y1c, x0c:x1c][0, 0]
|
209 |
-
cond_patch = cv2.resize(cond_sub, (P, P), interpolation=cv2.INTER_LINEAR)
|
210 |
-
|
211 |
-
rgb_patch = t_img[0, :, y0:y1, x0:x1] # 3xPxP
|
212 |
-
ymin_patch = y_min_full[y0:y1, x0:x1][None, ...] # 1xPxP
|
213 |
-
ymax_patch = y_max_full[y0:y1, x0:x1][None, ...] # 1xPxP
|
214 |
-
x6 = np.concatenate([rgb_patch, ymin_patch, ymax_patch, cond_patch[None, ...]], axis=0)
|
215 |
-
xs_list.append(x6)
|
216 |
-
|
217 |
-
x_batch = np.stack(xs_list, axis=0).astype(np.float32) # Bx6xPxP
|
218 |
-
outputs = fine.infer({input_name: x_batch})
|
219 |
-
|
220 |
-
# Assume single output named 'logits' or similar; take the first one
|
221 |
-
out_name = [n for n in outputs.keys()][0]
|
222 |
-
logits = outputs[out_name] # Bx2xPxP
|
223 |
-
prob = _softmax_channel(logits, axis=1)[:, 1, :, :] # BxPxP
|
224 |
-
|
225 |
-
for bi, (y0, x0) in enumerate(batch_coords):
|
226 |
-
y1, x1 = y0 + P, x0 + P
|
227 |
-
prob_sum[y0:y1, x0:x1] += prob[bi]
|
228 |
-
weight[y0:y1, x0:x1] += 1.0
|
229 |
-
|
230 |
-
prob_full = prob_sum / weight
|
231 |
-
return prob_full.astype(np.float32)
|
232 |
-
|
233 |
-
|
234 |
-
def _coarse_trt(coarse: TrtEngine, rgb: np.ndarray, coarse_size: int, minmax_enable: bool, minmax_kernel: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
235 |
-
x6, y_min_full, y_max_full, t_img = _build_6ch_coarse(rgb, coarse_size, minmax_enable, minmax_kernel)
|
236 |
-
# Engine input name
|
237 |
-
input_name = None
|
238 |
-
for i in range(coarse.engine.num_bindings):
|
239 |
-
if coarse.engine.binding_is_input(i):
|
240 |
-
input_name = coarse.engine.get_binding_name(i)
|
241 |
-
break
|
242 |
-
assert input_name is not None
|
243 |
-
x = x6[None, ...].astype(np.float32) # 1x6xHc x Wc
|
244 |
-
outputs = coarse.infer({input_name: x})
|
245 |
-
# Identify outputs: we expect 2 outputs (logits 1x2xHc x Wc, cond 1x1xHc x Wc)
|
246 |
-
assert len(outputs) == 2, f"Coarse engine must have 2 outputs, got {list(outputs.keys())}"
|
247 |
-
# Determine which is cond by channel dim =1
|
248 |
-
names = list(outputs.keys())
|
249 |
-
a, b = outputs[names[0]], outputs[names[1]]
|
250 |
-
if a.shape[1] == 1:
|
251 |
-
cond = a
|
252 |
-
logits = b
|
253 |
-
else:
|
254 |
-
cond = b
|
255 |
-
logits = a
|
256 |
-
# Coarse prob upsampled to full HxW (optional)
|
257 |
-
prob_c = _softmax_channel(logits, axis=1)[:, 1:2]
|
258 |
-
H, W = int(t_img.shape[2]), int(t_img.shape[3])
|
259 |
-
prob_up = cv2.resize(prob_c[0, 0], (W, H), interpolation=cv2.INTER_LINEAR)
|
260 |
-
return prob_up.astype(np.float32), cond.astype(np.float32), t_img.astype(np.float32), y_min_full, y_max_full
|
261 |
-
|
262 |
-
|
263 |
-
# ---- Inference API ----
|
264 |
-
|
265 |
-
def infer_image_trt(
|
266 |
-
coarse: TrtEngine,
|
267 |
-
fine: TrtEngine,
|
268 |
-
img_path: str,
|
269 |
-
cfg: dict,
|
270 |
-
out_dir: Optional[str] = None,
|
271 |
-
save_prob: bool = False,
|
272 |
-
prob_thresh: Optional[float] = None,
|
273 |
-
) -> Tuple[np.ndarray, np.ndarray]:
|
274 |
-
assert Path(img_path).is_file(), f"Image not found: {img_path}"
|
275 |
-
bgr = cv2.imread(img_path, cv2.IMREAD_COLOR)
|
276 |
-
assert bgr is not None, f"Failed to read {img_path}"
|
277 |
-
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
|
278 |
-
|
279 |
-
coarse_size = int(cfg["coarse"]["test_size"])
|
280 |
-
patch_size = int(cfg["inference"]["fine_patch_size"]) # 1024 for inference
|
281 |
-
overlap = int(cfg["fine"]["overlap"])
|
282 |
-
minmax_enable = bool(cfg["minmax"]["enable"])
|
283 |
-
minmax_kernel = int(cfg["minmax"]["kernel"])
|
284 |
-
if prob_thresh is None:
|
285 |
-
prob_thresh = float(cfg["inference"]["prob_threshold"])
|
286 |
-
|
287 |
-
prob_c, cond_map, t_img, y_min_full, y_max_full = _coarse_trt(
|
288 |
-
coarse, rgb, coarse_size, minmax_enable, minmax_kernel
|
289 |
-
)
|
290 |
-
|
291 |
-
prob_f = _tiled_fine_trt(
|
292 |
-
fine,
|
293 |
-
t_img,
|
294 |
-
cond_map,
|
295 |
-
y_min_full,
|
296 |
-
y_max_full,
|
297 |
-
patch_size,
|
298 |
-
overlap,
|
299 |
-
int(cfg.get("eval", {}).get("fine_batch", 16)),
|
300 |
-
)
|
301 |
-
|
302 |
-
pred = (prob_f > prob_thresh).astype(np.uint8) * 255
|
303 |
-
|
304 |
-
if out_dir is not None:
|
305 |
-
Path(out_dir).mkdir(parents=True, exist_ok=True)
|
306 |
-
stem = Path(img_path).stem
|
307 |
-
out_mask = Path(out_dir) / f"{stem}_pred.png"
|
308 |
-
cv2.imwrite(str(out_mask), pred)
|
309 |
-
if save_prob:
|
310 |
-
out_prob = Path(out_dir) / f"{stem}_prob.npy"
|
311 |
-
np.save(str(out_prob), prob_f.astype(np.float32))
|
312 |
-
|
313 |
-
return pred, prob_f
|
314 |
-
|
315 |
-
|
316 |
-
def main():
|
317 |
-
parser = argparse.ArgumentParser(description="WireSegHR TensorRT Inference")
|
318 |
-
parser.add_argument("--config", type=str, default="configs/default.yaml")
|
319 |
-
parser.add_argument("--coarse_engine", type=str, required=True)
|
320 |
-
parser.add_argument("--fine_engine", type=str, required=True)
|
321 |
-
parser.add_argument("--image", type=str, default="", help="Path to single image")
|
322 |
-
parser.add_argument("--images_dir", type=str, default="", help="Directory with images")
|
323 |
-
parser.add_argument("--out", type=str, default="outputs/trt_infer")
|
324 |
-
parser.add_argument("--save_prob", action="store_true")
|
325 |
-
# Benchmarking
|
326 |
-
parser.add_argument("--benchmark", action="store_true")
|
327 |
-
parser.add_argument("--bench_images_dir", type=str, default="")
|
328 |
-
parser.add_argument("--bench_limit", type=int, default=0)
|
329 |
-
parser.add_argument("--bench_warmup", type=int, default=2)
|
330 |
-
parser.add_argument("--bench_size_filter", type=str, default="")
|
331 |
-
parser.add_argument("--bench_report_json", type=str, default="")
|
332 |
-
|
333 |
-
args = parser.parse_args()
|
334 |
-
|
335 |
-
with open(args.config, "r") as f:
|
336 |
-
cfg = yaml.safe_load(f)
|
337 |
-
print("[TRT][infer] Loaded config:")
|
338 |
-
pprint.pprint(cfg)
|
339 |
-
|
340 |
-
coarse = TrtEngine(args.coarse_engine)
|
341 |
-
fine = TrtEngine(args.fine_engine)
|
342 |
-
|
343 |
-
if args.benchmark:
|
344 |
-
bench_dir = args.bench_images_dir or cfg["data"]["test_images"]
|
345 |
-
assert Path(bench_dir).is_dir(), f"Not a directory: {bench_dir}"
|
346 |
-
size_filter: Optional[Tuple[int, int]] = None
|
347 |
-
if args.bench_size_filter:
|
348 |
-
try:
|
349 |
-
h_str, w_str = args.bench_size_filter.lower().split("x")
|
350 |
-
size_filter = (int(h_str), int(w_str))
|
351 |
-
except Exception:
|
352 |
-
raise AssertionError(
|
353 |
-
f"Invalid --bench_size_filter format: {args.bench_size_filter} (use HxW)"
|
354 |
-
)
|
355 |
-
img_files = sorted(
|
356 |
-
[
|
357 |
-
str(Path(bench_dir) / p)
|
358 |
-
for p in os.listdir(bench_dir)
|
359 |
-
if p.lower().endswith((".jpg", ".jpeg"))
|
360 |
-
]
|
361 |
-
)
|
362 |
-
assert len(img_files) > 0, f"No .jpg/.jpeg in {bench_dir}"
|
363 |
-
|
364 |
-
if size_filter is not None:
|
365 |
-
sel: List[str] = []
|
366 |
-
for p in img_files:
|
367 |
-
im = cv2.imread(p, cv2.IMREAD_COLOR)
|
368 |
-
assert im is not None
|
369 |
-
if im.shape[0] == size_filter[0] and im.shape[1] == size_filter[1]:
|
370 |
-
sel.append(p)
|
371 |
-
img_files = sel
|
372 |
-
assert len(img_files) > 0, (
|
373 |
-
f"No images matching {size_filter[0]}x{size_filter[1]} in {bench_dir}"
|
374 |
-
)
|
375 |
-
|
376 |
-
if args.bench_limit > 0:
|
377 |
-
img_files = img_files[: args.bench_limit]
|
378 |
-
|
379 |
-
print(f"[TRT][bench] Images: {len(img_files)} from {bench_dir}")
|
380 |
-
print(f"[TRT][bench] Warmup: {args.bench_warmup}")
|
381 |
-
|
382 |
-
timings: List[Dict[str, Any]] = []
|
383 |
-
# Warmup
|
384 |
-
for i in range(min(args.bench_warmup, len(img_files))):
|
385 |
-
infer_image_trt(coarse, fine, img_files[i], cfg, out_dir=None, save_prob=False)
|
386 |
-
|
387 |
-
# Timed runs
|
388 |
-
for p in img_files[args.bench_warmup :]:
|
389 |
-
t0 = time.perf_counter()
|
390 |
-
bgr = cv2.imread(p, cv2.IMREAD_COLOR)
|
391 |
-
assert bgr is not None
|
392 |
-
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
|
393 |
-
|
394 |
-
coarse_size = int(cfg["coarse"]["test_size"])
|
395 |
-
minmax_enable = bool(cfg["minmax"]["enable"])
|
396 |
-
minmax_kernel = int(cfg["minmax"]["kernel"])
|
397 |
-
|
398 |
-
c0 = time.perf_counter()
|
399 |
-
prob_c, cond_map, t_img, y_min_full, y_max_full = _coarse_trt(
|
400 |
-
coarse, rgb, coarse_size, minmax_enable, minmax_kernel
|
401 |
-
)
|
402 |
-
c1 = time.perf_counter()
|
403 |
-
|
404 |
-
patch_size = int(cfg["inference"]["fine_patch_size"]) # 1024
|
405 |
-
overlap = int(cfg["fine"]["overlap"])
|
406 |
-
|
407 |
-
prob_f = _tiled_fine_trt(
|
408 |
-
fine,
|
409 |
-
t_img,
|
410 |
-
cond_map,
|
411 |
-
y_min_full,
|
412 |
-
y_max_full,
|
413 |
-
patch_size,
|
414 |
-
overlap,
|
415 |
-
int(cfg.get("eval", {}).get("fine_batch", 16)),
|
416 |
-
)
|
417 |
-
c2 = time.perf_counter()
|
418 |
-
|
419 |
-
timings.append(
|
420 |
-
{
|
421 |
-
"path": p,
|
422 |
-
"H": int(t_img.shape[2]),
|
423 |
-
"W": int(t_img.shape[3]),
|
424 |
-
"t_coarse_ms": (c1 - c0) * 1000.0,
|
425 |
-
"t_fine_ms": (c2 - c1) * 1000.0,
|
426 |
-
"t_total_ms": (c2 - t0) * 1000.0,
|
427 |
-
}
|
428 |
-
)
|
429 |
-
|
430 |
-
if len(timings) == 0:
|
431 |
-
print("[TRT][bench] Nothing to benchmark after warmup.")
|
432 |
-
return
|
433 |
-
|
434 |
-
def _agg(key: str) -> Tuple[float, float, float]:
|
435 |
-
vals = sorted([t[key] for t in timings])
|
436 |
-
n = len(vals)
|
437 |
-
p50 = vals[n // 2]
|
438 |
-
p95 = vals[min(n - 1, int(0.95 * (n - 1)))]
|
439 |
-
avg = sum(vals) / n
|
440 |
-
return avg, p50, p95
|
441 |
-
|
442 |
-
avg_c, p50_c, p95_c = _agg("t_coarse_ms")
|
443 |
-
avg_f, p50_f, p95_f = _agg("t_fine_ms")
|
444 |
-
avg_t, p50_t, p95_t = _agg("t_total_ms")
|
445 |
-
|
446 |
-
print("[TRT][bench] Results (ms):")
|
447 |
-
print(f" Coarse avg={avg_c:.2f} p50={p50_c:.2f} p95={p95_c:.2f}")
|
448 |
-
print(f" Fine avg={avg_f:.2f} p50={p50_f:.2f} p95={p95_f:.2f}")
|
449 |
-
print(f" Total avg={avg_t:.2f} p50={p50_t:.2f} p95={p95_t:.2f}")
|
450 |
-
print(f" Target < 1000 ms per 3000x4000 image: {'YES' if p50_t < 1000.0 else 'NO'}")
|
451 |
-
|
452 |
-
if args.bench_report_json:
|
453 |
-
import json
|
454 |
-
report = {
|
455 |
-
"summary": {
|
456 |
-
"avg_ms": avg_t,
|
457 |
-
"p50_ms": p50_t,
|
458 |
-
"p95_ms": p95_t,
|
459 |
-
"avg_coarse_ms": avg_c,
|
460 |
-
"avg_fine_ms": avg_f,
|
461 |
-
"images": len(timings),
|
462 |
-
},
|
463 |
-
"per_image": timings,
|
464 |
-
}
|
465 |
-
with open(args.bench_report_json, "w") as f:
|
466 |
-
json.dump(report, f, indent=2)
|
467 |
-
return
|
468 |
-
|
469 |
-
# Non-benchmark single/directory
|
470 |
-
assert (args.image != "") ^ (args.images_dir != ""), "Provide exactly one of --image or --images_dir"
|
471 |
-
if args.image:
|
472 |
-
infer_image_trt(coarse, fine, args.image, cfg, out_dir=args.out, save_prob=args.save_prob)
|
473 |
-
print("[TRT][infer] Done.")
|
474 |
-
return
|
475 |
-
|
476 |
-
img_dir = args.images_dir
|
477 |
-
assert Path(img_dir).is_dir()
|
478 |
-
Path(args.out).mkdir(parents=True, exist_ok=True)
|
479 |
-
img_files = sorted([p for p in os.listdir(img_dir) if p.lower().endswith((".jpg", ".jpeg"))])
|
480 |
-
assert len(img_files) > 0
|
481 |
-
for name in img_files:
|
482 |
-
p = str(Path(img_dir) / name)
|
483 |
-
infer_image_trt(coarse, fine, p, cfg, out_dir=args.out, save_prob=args.save_prob)
|
484 |
-
print("[TRT][infer] Done.")
|
485 |
-
|
486 |
-
|
487 |
-
if __name__ == "__main__":
|
488 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
train.py
CHANGED
@@ -401,30 +401,6 @@ def main():
|
|
401 |
print("[WireSegHR][train] Done.")
|
402 |
|
403 |
|
404 |
-
def _sample_batch_same_size(
|
405 |
-
dset: WireSegDataset, batch_size: int
|
406 |
-
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
|
407 |
-
# Use precomputed size bins to sample a batch from a single (H, W) bin
|
408 |
-
assert len(dset) > 0
|
409 |
-
bins = dset.size_bins
|
410 |
-
keys = list(bins.keys())
|
411 |
-
random.shuffle(keys)
|
412 |
-
chosen_key = None
|
413 |
-
for hw in keys:
|
414 |
-
if len(bins[hw]) >= batch_size:
|
415 |
-
chosen_key = hw
|
416 |
-
break
|
417 |
-
assert chosen_key is not None, f"No size bin with at least {batch_size} samples"
|
418 |
-
pool = bins[chosen_key]
|
419 |
-
idxs = np.random.choice(pool, size=batch_size, replace=False)
|
420 |
-
imgs: List[np.ndarray] = []
|
421 |
-
masks: List[np.ndarray] = []
|
422 |
-
for idx in idxs:
|
423 |
-
item = dset[int(idx)]
|
424 |
-
imgs.append(item["image"])
|
425 |
-
masks.append(item["mask"])
|
426 |
-
return imgs, masks
|
427 |
-
|
428 |
|
429 |
def _prepare_batch(
|
430 |
imgs: List[np.ndarray],
|
|
|
401 |
print("[WireSegHR][train] Done.")
|
402 |
|
403 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
404 |
|
405 |
def _prepare_batch(
|
406 |
imgs: List[np.ndarray],
|