WireSegHR (Segmentation Only)
This repository contains the segmentation-only implementation of the two-stage WireSegHR model, trained on the WireSegHR dataset plus the TTPLA dataset.
Quick Start
- Get the secrets needed to fetch the dataset:
You'll need a GDrive service account to fetch the WireSegHR dataset using the scripts in this repo. Get a GDrive key as described in this short README, and put it in `/secrets/drive-json.json`.
- Run:
scripts/setup.sh
This installs dependencies and merges the TTPLA dataset into the WireSegHR dataset format.
- Train and run a quick inference check:
python3 train.py --config configs/default.yaml
python3 infer.py --config configs/default.yaml --image /path/to/image.jpg
The default config `default.yaml` is suitable for a 24 GB VRAM GPU with bf16 support (e.g., RTX 3090/4090).
Project Overview
- Two-stage, global-to-local segmentation with a shared encoder and a fine decoder conditioned on the coarse stage.
- Full training loop with AMP (optional), Poly LR, periodic evaluation, checkpointing, and test visualizations (`train.py`).
- Dataset utilities under `src/wireseghr/data/` and model components under `src/wireseghr/model/`.
- Paper text and figures live in `paper-tex/` (`paper-tex/sections/` contains the Method, Results, etc.).
Notes
- This is a segmentation-only codebase. Inpainting is out of scope here.
- Defaults locked: SegFormer MiT-B3 encoder, patch size 768, MinMax 6×6, global + binary mask conditioning with a patch-cropped global map (see the sketch below).
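A minimal sketch of that conditioning, assuming the coarse stage produces a full-image foreground probability map; the exact tensor layout and threshold in this repo may differ:

```python
import torch
import torch.nn.functional as F

def build_fine_input(patch_rgb, coarse_prob_full, box, patch_size=768):
    # patch_rgb:        (3, patch_size, patch_size) high-resolution image patch
    # coarse_prob_full: (1, H, W) coarse foreground probabilities at full-image resolution
    # box:              (y0, x0, y1, x1) patch location in full-image coordinates
    y0, x0, y1, x1 = box
    crop = coarse_prob_full[:, y0:y1, x0:x1].unsqueeze(0)           # (1, 1, h, w)
    crop = F.interpolate(crop, size=(patch_size, patch_size),
                         mode="bilinear", align_corners=False)[0]   # (1, P, P)
    binary = (crop > 0.5).float()                                   # hypothetical threshold
    # Fine-stage input: RGB + patch-cropped global map + binary mask -> 5 channels
    return torch.cat([patch_rgb, crop, binary], dim=0)
```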
Backbone Source
- HuggingFace Transformers SegFormer (e.g., `nvidia/mit-b3`). We set `num_channels` to match the input channels (see the sketch below).
- Alternative: TorchVision ResNet-50 (`backbone: resnet50`). The stem is adapted to the requested `in_channels`, and we expose features from `layer1`..`layer4` at strides 1/4, 1/8, 1/16, 1/32 with channels [256, 512, 1024, 2048].
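A hedged sketch of instantiating a SegFormer encoder with a non-RGB channel count through HuggingFace Transformers; the channel count (5) and loading details here are assumptions, not this repo's exact code:

```python
import torch
from transformers import SegformerModel

# Rebuild the patch-embedding stem for 5 input channels (e.g. RGB + coarse map +
# binary mask); `ignore_mismatched_sizes` skips the pretrained RGB stem weights
# while the rest of MiT-B3 loads normally.
encoder = SegformerModel.from_pretrained(
    "nvidia/mit-b3",
    num_channels=5,
    ignore_mismatched_sizes=True,
)

x = torch.randn(1, 5, 768, 768)
features = encoder(x, output_hidden_states=True).hidden_states
for f in features:
    print(f.shape)  # four feature maps at strides 4, 8, 16, 32
```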
Dataset Convention
- Flat directories with numeric filenames; images are `.jpg`/`.jpeg`, masks are `.png`.
- Example (after an 85/5/10 split):
  `dataset/train/images/1.jpg, 2.jpg, ...` and `dataset/train/gts/1.png, 2.png, ...`
  `dataset/val/images/...` and `dataset/val/gts/...`
  `dataset/test/images/...` and `dataset/test/gts/...`
- Masks are binary: foreground = white (255), background = black (0).
- The loader strictly enforces numeric stems and a 1:1 image/mask pairing by name, and will raise on filename mismatches (a sketch of this check follows).
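A minimal sketch of that check (illustrative only; the real loader lives under `src/wireseghr/data/`):

```python
from pathlib import Path

def pair_images_and_masks(images_dir, masks_dir):
    # Collect numeric-stem images (.jpg/.jpeg) and masks (.png).
    images = {p.stem: p for p in Path(images_dir).iterdir()
              if p.suffix.lower() in {".jpg", ".jpeg"}}
    masks = {p.stem: p for p in Path(masks_dir).iterdir()
             if p.suffix.lower() == ".png"}
    for stem, path in images.items():
        if not stem.isdigit():
            raise ValueError(f"Non-numeric image filename: {path.name}")
    # Strict 1:1 pairing: any stem present on only one side is an error.
    mismatched = sorted(set(images) ^ set(masks))
    if mismatched:
        raise ValueError(f"Image/mask mismatch for stems: {mismatched}")
    return [(images[s], masks[s]) for s in sorted(images, key=int)]
```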
Update `configs/default.yaml` with your paths under `data.train_images`, `data.train_masks`, etc. Defaults point to `dataset/train/images` and `dataset/train/gts`, with validation under `dataset/val/...`.
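For reference, a sketch of reading those keys, assuming a top-level `data` section in the YAML:

```python
import yaml

with open("configs/default.yaml") as f:
    cfg = yaml.safe_load(f)

# Point these at your own dataset before training.
print(cfg["data"]["train_images"])  # e.g. dataset/train/images
print(cfg["data"]["train_masks"])   # e.g. dataset/train/gts
```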
Inference
- Single image (optionally save outputs to a directory):
python3 infer.py \
--config configs/default.yaml \
--ckpt ckpt_5000.pt \
--image dataset/test/images/123.jpg \
--out outputs/infer
- Compute metrics for a single image (requires a GT mask):
python3 infer.py \
--config configs/default.yaml \
--ckpt ckpt_5000.pt \
--image dataset/test/images/123.jpg \
--out outputs/infer \
--metrics \
--mask dataset/test/gts/123.png
- Run inference over an entire directory with metrics (`--images_dir` sets the image directory, `--masks_dir` sets the ground-truth mask directory):
python3 infer.py \
--config configs/default.yaml \
--ckpt ckpt_5000.pt \
--images_dir dataset/test/images \
--out outputs/infer \
--metrics \
--masks_dir dataset/test/gts
Notes:
- Predictions are saved as 0/255 PNGs. For metrics, predictions are binarized with `> 0` to match the training logic (see the sketch below).
- Masks are matched by filename stem: `images/123.jpg` → `gts/123.png`.
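A sketch of how the reported metrics can be reproduced from a saved prediction and its ground-truth mask; the paths below are hypothetical:

```python
import numpy as np
from PIL import Image

def binary_metrics(pred_png, gt_png):
    # Binarize as documented for metrics: any value > 0 counts as foreground.
    pred = np.asarray(Image.open(pred_png).convert("L")) > 0
    gt = np.asarray(Image.open(gt_png).convert("L")) > 0
    tp = np.logical_and(pred, gt).sum()
    fp = np.logical_and(pred, ~gt).sum()
    fn = np.logical_and(~pred, gt).sum()
    iou = tp / (tp + fp + fn + 1e-9)
    precision = tp / (tp + fp + 1e-9)
    recall = tp / (tp + fn + 1e-9)
    f1 = 2 * precision * recall / (precision + recall + 1e-9)
    return iou, f1, precision, recall

# Match prediction and GT by filename stem (hypothetical output path).
print(binary_metrics("outputs/infer/123.png", "dataset/test/gts/123.png"))
```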
Benchmarking and Metrics
Benchmark mode times the model on a directory of images and reports coarse/fine/total latency statistics. When `--metrics` is provided, it also computes IoU/F1/Precision/Recall over the benchmark set (for both fine and coarse outputs).
Example (uses `data.test_images` and `data.test_masks` from the config by default):
python3 infer.py \
--config configs/default.yaml \
--benchmark \
--ckpt ckpt_5000.pt \
--bench_warmup 2 \
--bench_limit 0 \
--bench_report_json outputs/bench_report.json \
--metrics
If your ground-truth directory differs from `data.test_masks`, override it with `--bench_masks_dir`:
python3 infer.py \
--config configs/default.yaml \
--benchmark \
--ckpt ckpt_5000.pt \
--bench_warmup 2 \
--bench_limit 0 \
--bench_report_json outputs/bench_report.json \
--metrics \
--bench_masks_dir /path/to/gts
You will see output like:
[WireSegHR][bench] Results (ms):
Coarse avg=50.16 p50=44.48 p95=76.78
Fine avg=534.38 p50=419.52 p95=1187.66
Total avg=584.54 p50=464.73 p95=1300.07
Target < 1000 ms per 3000x4000 image: YES
[WireSegHR][bench][Fine] IoU=0.6098 F1=0.7576 P=0.6418 R=0.9244
[WireSegHR][bench][Coarse] IoU=0.5315 F1=0.6941 P=0.5467 R=0.9502
*These metrics were obtained after 5000 training iterations.
Optional: you can save a JSON timing report with `--bench_report_json`. Schema:
- `summary`: `avg_ms`, `p50_ms`, `p95_ms`, `avg_coarse_ms`, `avg_fine_ms`, `images`
- `per_image`: list of objects with `path`, `H`, `W`, `t_coarse_ms`, `t_fine_ms`, `t_total_ms`
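A sketch of consuming that report, using the field names listed above:

```python
import json

with open("outputs/bench_report.json") as f:
    report = json.load(f)

s = report["summary"]
print(f"avg={s['avg_ms']:.1f} ms  p50={s['p50_ms']:.1f}  p95={s['p95_ms']:.1f}")
print(f"coarse={s['avg_coarse_ms']:.1f} ms  fine={s['avg_fine_ms']:.1f} ms")

# Slowest image in the run, with its resolution.
slowest = max(report["per_image"], key=lambda r: r["t_total_ms"])
print(slowest["path"], slowest["H"], slowest["W"], slowest["t_total_ms"])
```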
Utils:
- Export your model to inference-only weights with `scripts/strip_checkpoint.py` (rough sketch below).
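Roughly, the stripping amounts to the following (illustrative only; the `"model"` key is an assumption about the checkpoint layout, and the script's actual CLI may differ):

```python
import torch

# Keep just the weights needed for inference, dropping optimizer/scheduler state.
ckpt = torch.load("ckpt_5000.pt", map_location="cpu")
torch.save({"model": ckpt["model"]}, "ckpt_5000_infer.pt")
```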
Example: input image and the corresponding fine prediction.