MRiabov commited on
Commit
8ea2eff
·
1 Parent(s): 5cab910

Project skeleton generated

Browse files
.gitignore ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ *.egg-info/
23
+ .installed.cfg
24
+ *.egg
25
+
26
+ # PyInstaller
27
+ # Usually these files are written by a python script from a template
28
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
29
+ *.manifest
30
+ *.spec
31
+
32
+ # Installer logs
33
+ pip-log.txt
34
+ pip-delete-this-directory.txt
35
+
36
+ # Unit test / coverage reports
37
+ htmlcov/
38
+ .tox/
39
+ .nox/
40
+ .coverage
41
+ .coverage.*
42
+ .cache
43
+ nosetests.xml
44
+ coverage.xml
45
+ *.cover
46
+ .hypothesis/
47
+ .pytest_cache/
48
+ pytestdebug.log
49
+
50
+ # Translations
51
+ *.mo
52
+ *.pot
53
+
54
+ # Django stuff:
55
+ *.log
56
+ local_settings.py
57
+ db.sqlite3
58
+ db.sqlite3-journal
59
+
60
+ # Flask stuff:
61
+ instance/
62
+ .webassets-cache
63
+
64
+ # Scrapy stuff:
65
+ .scrapy
66
+
67
+ # Sphinx documentation
68
+ docs/_build/
69
+
70
+ # PyBuilder
71
+ target/
72
+
73
+ # Jupyter Notebook
74
+ .ipynb_checkpoints
75
+
76
+ # IPython
77
+ profile_default/
78
+ ipython_config.py
79
+
80
+ # pyenv
81
+ .python-version
82
+
83
+ # pipenv
84
+ Pipfile.lock
85
+
86
+ # poetry
87
+ poetry.lock
88
+
89
+ # mypy
90
+ .mypy_cache/
91
+ .dmypy.json
92
+ dmypy.json
93
+
94
+ # Pyre type checker
95
+ .pyre/
96
+
97
+ # pytype
98
+ .pytype/
99
+
100
+ # Cython debug symbols
101
+ cython_debug/
102
+
103
+ # VS Code
104
+ .vscode/
105
+
106
+ # Mac
107
+ .DS_Store
108
+
109
+ # Environment
110
+ .env
111
+ .venv
112
+ env/
113
+ venv/
114
+ ENV/
115
+ env.bak/
116
+ venv.bak/
.windsurf/rules/creating-new-class-variables.md ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ trigger: model_decision
3
+ description: When creating new class variable
4
+ ---
5
+
6
+ It is preferable to create class variable docstrings instead of comments. E.g:
7
+
8
+ ```py
9
+ class Class123:
10
+ var1: int
11
+ """Variable description"""
12
+ ```
13
+ is preferred over
14
+ ```py
15
+ class Class123:
16
+ # Variable description
17
+ var1: int
18
+ ```
.windsurf/rules/defensive-logic.md ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ trigger: always_on
3
+ ---
4
+
5
+ When deciding if to write defensive logic, e.g. dimensionality handling: `tensor1=tensor1.unsqueeze() if tensor1.ndim==1 else tensor1`, or None handling: `var1=var1 if var1 else torch.zeros(...)`, just don't write these things. In my code, shapes are always static, and there is one execution path for all code. I prefer `assert` over defensive logic. If you are writing something to fix the tests and this seems necessary, it's likely that the tests are setup incorrectly.
6
+ The reason why I don't want it is because defensive logic leads to silent failures, and these are bad for debugging.
7
+
8
+ In addition, writing "int()" type casting is also a piece of defensive logic that slows down the application. Don't write it unless really necessary e.g. putting int to string. In most cases, indexing with a tensor should be better.
.windsurf/rules/running-tests.md ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ ---
2
+ trigger: model_decision
3
+ description: When deciding which tests to run
4
+ ---
5
+
6
+ When calling `pytest`, it's better to never execute the whole test suite because it takes over 5 minutes to run. So, never run `pytest -q` without test file or `-k`. Instead, run an individual test function, class, module or a combination of them.
.windsurf/rules/when-pytest-not-found.md ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ ---
2
+ trigger: model_decision
3
+ description: When pytest is not found
4
+ ---
5
+
6
+ Sometimes when running tests, `pytest` would be not found. This means that `venv` is not activated, and you can activate it with `source ../.venv/bin/activate`. After that, it will certainly work. Do not try to run `./venv/bin/pytest` directly, activate the venv, and then run `pytest` as usual.
README.md ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # WireSegHR (Segmentation Only)
2
+
3
+ This repository contains the segmentation-only implementation plan and code skeleton for the two-stage WireSegHR model (global-to-local, shared encoder).
4
+
5
+ - Paper sources live under `paper-tex/`.
6
+ - Long-term navigation plan: `SEGMENTATION_PLAN.md`.
7
+
8
+ ## Quick Start (skeleton)
9
+
10
+ 1) Create a virtual environment and install requirements:
11
+
12
+ ```bash
13
+ python -m venv .venv
14
+ source .venv/bin/activate
15
+ pip install -r requirements.txt
16
+ ```
17
+
18
+ 2) Print configuration and verify the skeleton runs:
19
+
20
+ ```bash
21
+ python src/wireseghr/train.py --config configs/default.yaml
22
+ python src/wireseghr/infer.py --config configs/default.yaml --image /path/to/image.png
23
+ ```
24
+
25
+ 3) Next steps:
26
+ - Implement encoder/decoders/condition/minmax/label downsampling per `SEGMENTATION_PLAN.md`.
27
+ - Implement training and inference logic, then metrics and ablations.
28
+
29
+ ## Notes
30
+ - This is a segmentation-only codebase. Inpainting is out of scope here.
31
+ - Defaults locked: MiT-B3 encoder, patch size 768, MinMax 6×6, global+binary mask conditioning with patch-cropped global map.
SEGMENTATION_PLAN.md ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # WireSegHR Segmentation-Only Implementation Plan
2
+
3
+ This plan distills the model and pipeline described in the paper sources:
4
+ - `paper-tex/sections/method.tex`
5
+ - `paper-tex/sections/method_yq.tex`
6
+ - `paper-tex/figure_tex/pipeline.tex`
7
+ - `paper-tex/tables/{component,logit,thresholds}.tex`
8
+
9
+ Focus: segmentation only (no dataset collection or inpainting).
10
+
11
+ ## Decisions and Defaults (locked)
12
+ - Backbone: SegFormer MiT-B3 (shared encoder `E`).
13
+ - Fine/local patch size p: 768.
14
+ - Conditioning: global map + binary location mask by default (Table `tables/logit.tex`).
15
+ - Conditioning map scope: patch-cropped from the global map per `paper-tex/sections/method_yq.tex` (no full-image concatenation variant).
16
+ - MinMax feature augmentation: luminance min and max with a fixed 6×6 window; channels concatenated to inputs (Figure `figure_tex/pipeline.tex`, Sec. “Wire Feature Preservation” in `method_yq.tex`).
17
+ - Loss: CE on both branches, λ = 1 (`method_yq.tex`).
18
+ - α-threshold for refining windows: default 0.01 (Table `tables/thresholds.tex`).
19
+ - Coarse input size: train 512×512; test 1024×1024 (`method.tex`).
20
+ - Optim: AdamW (lr=6e-5, wd=0.01, poly schedule with power=1), ~40k iters, batch size ~8 (`method.tex`).
21
+
22
+ ## Project Structure
23
+ - `configs/`
24
+ - `default.yaml` (backbone=mit_b3, p=768, coarse_train=512, coarse_test=1024, alpha=0.01, minmax=true, kernel=6, maxpool_label=true, cond_variant=global+binary_mask)
25
+ - `src/wireseghr/`
26
+ - `model/`
27
+ - `encoder.py` (SegFormer MiT-B3, N_in channels expansion)
28
+ - `decoder.py` (two MLP decoders `D_C`, `D_F` for 2 classes)
29
+ - `condition.py` (1×1 conv to collapse coarse 2-ch logits → 1-ch cond)
30
+ - `minmax.py` (6×6 luminance min/max filtering)
31
+ - `label_downsample.py` (MaxPool-based coarse GT downsampling)
32
+ - `data/`
33
+ - `dataset.py` (image/mask loading, full-res to coarse/fine inputs)
34
+ - `sampler.py` (balanced patch sampling with ≥1% wire pixels)
35
+ - `transforms.py` (scaling, rotation, flip, photometric distortion)
36
+ - `train.py` (end-to-end two-branch training)
37
+ - `infer.py` (coarse-to-fine sliding-window inference + stitching)
38
+ - `metrics.py` (IoU, F1, Precision, Recall)
39
+ - `utils.py` (misc: overlap blending, seeding, logging)
40
+ - `tests/` (unit tests for channel wiring, cond alignment, stitching)
41
+ - `README.md` (segmentation-only usage)
42
+
43
+ ## Model Specification
44
+ - Shared encoder `E`: SegFormer MiT-B3.
45
+ - Input channels (default): 3 (RGB) + 2 (MinMax) + 1 (global cond) + 1 (binary location) = 7.
46
+ - For the coarse pass, the cond and location channels are zeros to keep channel count consistent (`method_yq.tex`).
47
+ - Weight init for extra channels: copy mean of RGB conv weights or zero-init.
48
+ - Decoders: two SegFormer MLP decoders
49
+ - `D_C`: coarse logits (2 channels) at coarse resolution.
50
+ - `D_F`: fine logits (2 channels) at patch resolution p×p.
51
+ - Conditioning to fine branch (default):
52
+ - Take coarse pre-softmax logits (2-ch), apply 1×1 conv → 1-ch cond map (`method.tex`).
53
+ - Binary location mask: 1 inside current patch region (in full-image coordinates), 0 elsewhere.
54
+ - Pass patch-aligned cond crop and binary mask as channels to the fine branch input.
55
+ - Notes:
56
+ - We expose a config toggle to switch conditioning variant between: `global+binary_mask` (default) and `global_only` (Table `tables/logit.tex`).
57
+ - We follow the published version (`paper-tex/sections/method_yq.tex`) and use patch-cropped conditioning exclusively; no full-image conditioning variant will be implemented.
58
+
59
+ ## Data and Preprocessing
60
+ - MinMax luminance features (both branches):
61
+ - Y = 0.299R + 0.587G + 0.114B.
62
+ - Y_min = min filter (6×6), Y_max = max filter (6×6).
63
+ - Concat [Y_min, Y_max] to the input image channels.
64
+ - Coarse GT label generation (MaxPool):
65
+ - Downsample full-res mask to coarse size with max-pooling to prevent wire vanishing (`method_yq.tex`).
66
+ - Normalization: standard mean/std per backbone; apply consistently across channels (new channels can be mean=0, std=1 by convention, or min-max scaled).
67
+
68
+ ## Training Pipeline
69
+ - Augment the full-res image (scaling, rotation, horizontal flip, photometric distortion) before constructing coarse/fine inputs (`method.tex`).
70
+ - Coarse input: downsample augmented full image to 512×512; build channels [RGB+MinMax+zeros(2)] → `E` → `D_C`.
71
+ - Fine input (per iteration select 1–k patches):
72
+ - Sample p×p patch (p=768) with ≥1% wire pixels (`method.tex`, `method_yq.tex`).
73
+ - Build cond map from coarse logits via 1×1 conv; crop cond to patch region.
74
+ - Build binary location mask for patch region.
75
+ - Build channels [RGB + MinMax + cond + location] → `E` → `D_F`.
76
+ - Losses:
77
+ - L_glo = CE(Softmax(`D_C(E(coarse))`), G_glo), where G_glo uses MaxPool downsample.
78
+ - L_loc = CE(Softmax(`D_F(E(fine))`), G_loc).
79
+ - L = L_glo + λ L_loc, λ=1 (`method_yq.tex`).
80
+ - Optimization:
81
+ - AdamW (lr=6e-5, wd=0.01), poly schedule (power=1.0), ~40k iterations, batch ≈8 (tune by memory).
82
+ - AMP and grad accumulation recommended for stability/memory.
83
+
84
+ ## Inference Pipeline
85
+ - Coarse pass:
86
+ - Downsample to 1024×1024; predict coarse probability/logits.
87
+ - Window proposal (sliding window on full-res):
88
+ - Tile with patch size p=768. Overlap ~128px (configurable). Compute wire fraction within each window from coarse prediction (prob>0.5).
89
+ - If fraction ≥ α (default 0.01), run fine refinement on that patch; else skip (Table `tables/thresholds.tex`).
90
+ - Fine refinement + stitching:
91
+ - For selected windows, build fine input with cond crop + location mask; predict logits.
92
+ - Stitch logits into full-res canvas; average in overlaps; final argmax over classes.
93
+ - Outputs: full-res binary mask, plus optional probability map.
94
+
95
+ ## Metrics and Reporting
96
+ - Implement: IoU, F1, Precision, Recall (global, and optionally per-size bins if available) matching `tables/component.tex`.
97
+ - Validate α trade-offs following `tables/thresholds.tex`.
98
+ - Ablations: MinMax on/off, MaxPool on/off, conditioning variant (Table `tables/logit.tex`).
99
+
100
+ ## Configuration Surface (key)
101
+ - Backbone/weights: `mit_b3` (pretrained ImageNet-1K).
102
+ - Sizes: `p=768`, `coarse_train=512`, `coarse_test=1024`, `overlap=128`.
103
+ - Conditioning: `use_binary_location=true`, `cond_from='coarse_logits_1x1'`, `cond_crop='patch'`.
104
+ - MinMax: `enable=true`, `kernel=6`.
105
+ - Label: `coarse_label_downsample='maxpool'`.
106
+ - Training: `iters=40000`, `batch=8`, `lr=6e-5`, `wd=0.01`, `schedule='poly'`, `power=1.0`.
107
+ - Inference: `alpha=0.01`, `prob_threshold=0.5` for wire fraction, `stitch='avg_logits'`.
108
+
109
+ ## Risks / Gotchas
110
+ - Channel expansion requires careful initialization; confirm no NaNs and stable early training.
111
+ - Precise spatial alignment of cond and location mask with the patch is critical. Add assertions/tests.
112
+ - Even-sized MinMax window (6×6) requires careful padding to maintain alignment.
113
+ - Memory with p=768 and MiT-B3 may need tuning (AMP, batch size, overlap).
114
+
115
+ ## Milestones
116
+ 1) Skeleton + configs + metrics.
117
+ 2) Encoder channel expansion + two decoders + 1×1 cond.
118
+ 3) MinMax (6×6) + MaxPool label downsampling.
119
+ 4) Training loop with ≥1% wire patch sampling.
120
+ 5) Inference α-threshold + stitching.
121
+ 6) Ablations toggles + scripts + README.
122
+ 7) Tests (channel wiring, cond/mask alignment, stitching correctness).
123
+
124
+ ## References (paper sources)
125
+ - `paper-tex/sections/method.tex`: Two-stage design, shared encoder, 1×1 cond, training/inference sizes, optimizer/schedule.
126
+ - `paper-tex/sections/method_yq.tex`: CE losses, λ, sliding-window with α, MinMax & MaxPool rationale.
127
+ - `paper-tex/figure_tex/pipeline.tex`: System overview; MinMax concatenation.
128
+ - `paper-tex/tables/component.tex`: Ablation of MinMax/MaxPool/coarse.
129
+ - `paper-tex/tables/logit.tex`: Conditioning variants.
130
+ - `paper-tex/tables/thresholds.tex`: α vs speed/quality.
WireSegHR-tex.tar.gz DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:0a17096a2eaad07f51345426465697fcf0ee1a0c5b54aa6742b4ac23406f6bc4
3
- size 33690376
 
 
 
 
configs/default.yaml ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Default configuration for WireSegHR (segmentation-only)
2
+ backbone: mit_b3
3
+
4
+ coarse:
5
+ train_size: 512
6
+ test_size: 1024
7
+
8
+ fine:
9
+ patch_size: 768
10
+ overlap: 128
11
+
12
+ conditioning:
13
+ use_binary_location: true
14
+ cond_from: coarse_logits_1x1
15
+ cond_crop: patch # per published method (method_yq)
16
+
17
+ minmax:
18
+ enable: true
19
+ kernel: 6 # fixed 6x6 luminance min/max
20
+
21
+ label:
22
+ coarse_downsample: maxpool
23
+
24
+ inference:
25
+ alpha: 0.01
26
+ prob_threshold: 0.5
27
+ stitch: avg_logits
28
+
29
+ optim:
30
+ iters: 40000
31
+ batch_size: 8
32
+ lr: 6e-5
33
+ weight_decay: 0.01
34
+ schedule: poly
35
+ power: 1.0
36
+
37
+ # dataset paths (placeholders)
38
+ data:
39
+ train_images: /path/to/train/images
40
+ train_masks: /path/to/train/masks
41
+ val_images: /path/to/val/images
42
+ val_masks: /path/to/val/masks
pytest.ini ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [pytest]
2
+ pythonpath = src
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch>=2.1.0
2
+ torchvision>=0.16.0
3
+ timm>=0.9.8
4
+ numpy>=1.24.0
5
+ opencv-python>=4.8.0.76
6
+ Pillow>=9.5.0
7
+ PyYAML>=6.0.1
8
+ tqdm>=4.65.0
src/wireseghr/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ __all__ = [
2
+ "model",
3
+ "data",
4
+ ]
5
+
6
+ __version__ = "0.1.0"
src/wireseghr/data/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .dataset import WireSegDataset
2
+ from .sampler import BalancedPatchSampler
3
+
4
+ __all__ = [
5
+ "WireSegDataset",
6
+ "BalancedPatchSampler",
7
+ ]
src/wireseghr/data/dataset.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dataset placeholder for wire segmentation
2
+ """WireSeg dataset indexing and loading.
3
+
4
+ Pairs images in `images_dir` with masks in `masks_dir` by matching filename stems.
5
+ Mask is loaded as single-channel 0/1.
6
+ """
7
+
8
+ from typing import Any, Dict, List
9
+
10
+ from pathlib import Path
11
+
12
+ import numpy as np
13
+ import cv2
14
+
15
+
16
+ class WireSegDataset:
17
+ def __init__(self, images_dir: str, masks_dir: str, split: str = "train"):
18
+ self.images_dir = Path(images_dir)
19
+ self.masks_dir = Path(masks_dir)
20
+ self.split = split
21
+ assert self.images_dir.exists(), f"Missing images_dir: {self.images_dir}"
22
+ assert self.masks_dir.exists(), f"Missing masks_dir: {self.masks_dir}"
23
+ self._items: List[tuple[Path, Path]] = self._index_pairs()
24
+
25
+ def __len__(self) -> int:
26
+ return len(self._items)
27
+
28
+ def __getitem__(self, idx: int) -> Dict[str, Any]:
29
+ img_path, mask_path = self._items[idx]
30
+ img_bgr = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
31
+ assert img_bgr is not None, f"Failed to read image: {img_path}"
32
+ img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
33
+ mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
34
+ assert mask is not None, f"Failed to read mask: {mask_path}"
35
+ mask_bin = (mask > 0).astype(np.uint8)
36
+ return {"image": img, "mask": mask_bin, "image_path": str(img_path), "mask_path": str(mask_path)}
37
+
38
+ def _index_pairs(self) -> List[tuple[Path, Path]]:
39
+ exts_img = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"}
40
+ exts_mask = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"}
41
+ imgs: Dict[str, Path] = {}
42
+ for p in sorted(self.images_dir.rglob("*")):
43
+ if p.is_file() and p.suffix.lower() in exts_img:
44
+ imgs[p.stem] = p
45
+ masks: Dict[str, Path] = {}
46
+ for p in sorted(self.masks_dir.rglob("*")):
47
+ if p.is_file() and p.suffix.lower() in exts_mask:
48
+ masks[p.stem] = p
49
+ pairs: List[tuple[Path, Path]] = []
50
+ for stem, ip in imgs.items():
51
+ if stem in masks:
52
+ pairs.append((ip, masks[stem]))
53
+ assert len(pairs) > 0, f"No image-mask pairs found in {self.images_dir} and {self.masks_dir}"
54
+ return pairs
src/wireseghr/data/sampler.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Balanced patch sampler (>=1% wire pixels)
2
+ # TODO: Implement logic over mask to pick patches with wire ratio >= threshold.
3
+
4
+ from dataclasses import dataclass
5
+
6
+
7
+ @dataclass
8
+ class BalancedPatchSampler:
9
+ patch_size: int = 768
10
+ min_wire_ratio: float = 0.01
11
+
12
+ def sample(self, image, mask):
13
+ # TODO: sample and return top-left (y, x) of a valid patch
14
+ return 0, 0
src/wireseghr/data/transforms.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Training-time transforms: scaling, rotation, flip, photometric distortion
2
+ # TODO: Implement deterministic transform composition for reproducibility
3
+
4
+ class TrainTransforms:
5
+ def __init__(self):
6
+ pass
7
+
8
+ def __call__(self, image, mask):
9
+ return image, mask
src/wireseghr/infer.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import pprint
4
+ import yaml
5
+
6
+
7
+ def main():
8
+ parser = argparse.ArgumentParser(description="WireSegHR inference (skeleton)")
9
+ parser.add_argument("--config", type=str, default="configs/default.yaml", help="Path to YAML config")
10
+ parser.add_argument("--image", type=str, required=False, help="Path to input image")
11
+ args = parser.parse_args()
12
+
13
+ cfg_path = args.config
14
+ if not os.path.isabs(cfg_path):
15
+ cfg_path = os.path.join(os.getcwd(), cfg_path)
16
+
17
+ with open(cfg_path, "r") as f:
18
+ cfg = yaml.safe_load(f)
19
+
20
+ print("[WireSegHR][infer] Loaded config from:", cfg_path)
21
+ pprint.pprint(cfg)
22
+ print("[WireSegHR][infer] Image:", args.image)
23
+ print("[WireSegHR][infer] Skeleton OK. Implement inference per SEGMENTATION_PLAN.md.")
24
+
25
+
26
+ if __name__ == "__main__":
27
+ main()
src/wireseghr/metrics.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Metrics placeholder: IoU, F1, Precision, Recall
2
+ # TODO: Implement proper metrics matching paper tables.
3
+
4
+ from typing import Dict
5
+
6
+
7
+ def compute_metrics(pred_mask, gt_mask) -> Dict[str, float]:
8
+ # TODO: implement
9
+ return {"iou": 0.0, "f1": 0.0, "precision": 0.0, "recall": 0.0}
src/wireseghr/model/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .encoder import SegFormerEncoder
2
+ from .decoder import CoarseDecoder, FineDecoder
3
+ from .condition import Conditioning1x1
4
+ from .minmax import MinMaxLuminance
5
+ from .label_downsample import downsample_label_maxpool
6
+ from .model import WireSegHR
7
+
8
+ __all__ = [
9
+ "SegFormerEncoder",
10
+ "CoarseDecoder",
11
+ "FineDecoder",
12
+ "Conditioning1x1",
13
+ "MinMaxLuminance",
14
+ "downsample_label_maxpool",
15
+ "WireSegHR",
16
+ ]
src/wireseghr/model/condition.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 1x1 conv to collapse 2-ch coarse logits into 1-ch conditioning map
2
+ # TODO: Wire with coarse decoder outputs and proper resize/cropping.
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+
8
+ class Conditioning1x1(nn.Module):
9
+ def __init__(self):
10
+ super().__init__()
11
+ self.conv = nn.Conv2d(2, 1, kernel_size=1, bias=True)
12
+
13
+ def forward(self, coarse_logits: torch.Tensor) -> torch.Tensor:
14
+ return self.conv(coarse_logits)
src/wireseghr/model/decoder.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SegFormer-like multi-scale decoder heads for coarse and fine branches.
2
+
3
+ Fuse four feature maps from MiT encoder via 1x1 projections, upsample to the
4
+ highest spatial resolution (stage 0), concatenate, and predict 2-class logits.
5
+ """
6
+
7
+ from typing import List
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ class _ConvBNReLU(nn.Module):
15
+ def __init__(self, in_ch: int, out_ch: int, k: int, s: int = 1, p: int = 0):
16
+ super().__init__()
17
+ self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=k, stride=s, padding=p, bias=False)
18
+ self.bn = nn.BatchNorm2d(out_ch)
19
+ self.relu = nn.ReLU(inplace=True)
20
+
21
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
22
+ x = self.conv(x)
23
+ x = self.bn(x)
24
+ x = self.relu(x)
25
+ return x
26
+
27
+
28
+ class _SegFormerHead(nn.Module):
29
+ def __init__(self, in_chs: List[int], embed_dim: int = 128, num_classes: int = 2):
30
+ super().__init__()
31
+ assert len(in_chs) == 4
32
+ self.proj = nn.ModuleList([nn.Conv2d(c, embed_dim, kernel_size=1) for c in in_chs])
33
+ self.fuse = _ConvBNReLU(embed_dim * 4, embed_dim, k=3, p=1)
34
+ self.cls = nn.Conv2d(embed_dim, num_classes, kernel_size=1)
35
+
36
+ def forward(self, feats: List[torch.Tensor]) -> torch.Tensor:
37
+ assert len(feats) == 4
38
+ h, w = feats[0].shape[2], feats[0].shape[3]
39
+ xs = []
40
+ for f, proj in zip(feats, self.proj):
41
+ x = proj(f)
42
+ if x.shape[2] != h or x.shape[3] != w:
43
+ x = F.interpolate(x, size=(h, w), mode="bilinear", align_corners=False)
44
+ xs.append(x)
45
+ x = torch.cat(xs, dim=1)
46
+ x = self.fuse(x)
47
+ x = self.cls(x)
48
+ return x
49
+
50
+
51
+ class CoarseDecoder(_SegFormerHead):
52
+ def __init__(self, in_chs: List[int] = (64, 128, 320, 512), embed_dim: int = 128, num_classes: int = 2):
53
+ super().__init__(list(in_chs), embed_dim, num_classes)
54
+
55
+
56
+ class FineDecoder(_SegFormerHead):
57
+ def __init__(self, in_chs: List[int] = (64, 128, 320, 512), embed_dim: int = 128, num_classes: int = 2):
58
+ super().__init__(list(in_chs), embed_dim, num_classes)
src/wireseghr/model/encoder.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SegFormer MiT encoder wrapper with adjustable input channels.
2
+
3
+ Uses timm to instantiate MiT (e.g., mit_b3) and returns a list of multi-scale
4
+ features [C1, C2, C3, C4].
5
+ """
6
+
7
+ from typing import List, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import timm
12
+
13
+
14
+ class SegFormerEncoder(nn.Module):
15
+ def __init__(
16
+ self,
17
+ backbone: str = "mit_b3",
18
+ in_channels: int = 7,
19
+ pretrained: bool = True,
20
+ out_indices: Tuple[int, int, int, int] = (0, 1, 2, 3),
21
+ ):
22
+ super().__init__()
23
+ self.backbone_name = backbone
24
+ self.in_channels = in_channels
25
+ self.pretrained = pretrained
26
+ self.out_indices = out_indices
27
+
28
+ # Create MiT with features_only to obtain multi-scale feature maps.
29
+ # in_chans allows expanded inputs (RGB + minmax + cond + loc)
30
+ self.encoder = timm.create_model(
31
+ backbone,
32
+ pretrained=pretrained,
33
+ features_only=True,
34
+ out_indices=out_indices,
35
+ in_chans=in_channels,
36
+ )
37
+
38
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
39
+ feats = self.encoder(x)
40
+ # Ensure list of tensors is returned
41
+ assert isinstance(feats, (list, tuple)) and len(feats) == len(self.out_indices)
42
+ return list(feats)
src/wireseghr/model/label_downsample.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MaxPool-based downsampling for coarse labels
2
+ """Downsample binary masks preserving thin positives.
3
+
4
+ We use area-based resize on float32 masks followed by a >0 threshold.
5
+ This emulates block-wise max pooling: any positive in the source region
6
+ produces a positive in the target pixel.
7
+ """
8
+
9
+ import numpy as np
10
+
11
+
12
+ def downsample_label_maxpool(mask: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
13
+ """
14
+ Args:
15
+ mask: HxW binary (0/1) numpy array
16
+ out_h, out_w: target size
17
+ Returns:
18
+ H'xW' binary array via max-pooling-like downsample
19
+ """
20
+ assert mask.ndim == 2
21
+ # Convert to float32 so area resize yields fractional averages > 0 if any positive present
22
+ import cv2
23
+ m = mask.astype(np.float32)
24
+ r = cv2.resize(m, (out_w, out_h), interpolation=cv2.INTER_AREA)
25
+ out = (r > 0.0).astype(np.uint8)
26
+ return out
src/wireseghr/model/minmax.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MinMax luminance feature computation (6x6 window)
2
+ # Implemented with OpenCV morphology (erode=min, dilate=max) using 6x6 kernel and replicate border.
3
+
4
+ from typing import Tuple
5
+
6
+ import numpy as np
7
+
8
+
9
+ class MinMaxLuminance:
10
+ def __init__(self, kernel: int = 6):
11
+ assert kernel == 6, "Per plan, kernel is fixed to 6x6"
12
+ self.kernel = kernel
13
+
14
+ def __call__(self, img_rgb: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
15
+ """
16
+ Args:
17
+ img_rgb: HxWx3 uint8 or float32 in [0,255] or [0,1]
18
+ Returns:
19
+ (Y_min, Y_max): two HxW float32 arrays
20
+ """
21
+ assert img_rgb.ndim == 3 and img_rgb.shape[2] == 3
22
+ r, g, b = img_rgb[..., 0], img_rgb[..., 1], img_rgb[..., 2]
23
+ y = (0.299 * r + 0.587 * g + 0.114 * b).astype(np.float32)
24
+
25
+ import cv2 # lazy import to avoid test-time dependency at module import
26
+ kernel = np.ones((self.kernel, self.kernel), dtype=np.uint8)
27
+ y_min = cv2.erode(y, kernel, borderType=cv2.BORDER_REPLICATE)
28
+ y_max = cv2.dilate(y, kernel, borderType=cv2.BORDER_REPLICATE)
29
+ return y_min.astype(np.float32), y_max.astype(np.float32)
src/wireseghr/model/model.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from .encoder import SegFormerEncoder
7
+ from .decoder import CoarseDecoder, FineDecoder
8
+ from .condition import Conditioning1x1
9
+
10
+
11
+ class WireSegHR(nn.Module):
12
+ """
13
+ Two-stage WireSegHR model wrapper with shared encoder.
14
+
15
+ Expects callers to prepare input channel stacks according to the plan:
16
+ - Coarse input: RGB + MinMax (and any extra channels per config), shape (B, Cc, Hc, Wc)
17
+ - Fine input: RGB + MinMax + cond_crop + binary_location_mask, shape (B, Cf, p, p)
18
+
19
+ Conditioning 1x1 is applied to coarse logits to produce a single-channel map.
20
+ """
21
+
22
+ def __init__(self, backbone: str = "mit_b3", in_channels: int = 7, pretrained: bool = True):
23
+ super().__init__()
24
+ self.encoder = SegFormerEncoder(backbone=backbone, in_channels=in_channels, pretrained=pretrained)
25
+ # Default MiT-B3 channel dims for stages
26
+ in_chs = (64, 128, 320, 512)
27
+ self.coarse_head = CoarseDecoder(in_chs=in_chs, embed_dim=128, num_classes=2)
28
+ self.fine_head = FineDecoder(in_chs=in_chs, embed_dim=128, num_classes=2)
29
+ self.cond1x1 = Conditioning1x1()
30
+
31
+ def forward_coarse(self, x_coarse: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
32
+ assert x_coarse.dim() == 4
33
+ feats = self.encoder(x_coarse)
34
+ logits_coarse = self.coarse_head(feats)
35
+ cond_map = self.cond1x1(logits_coarse)
36
+ return logits_coarse, cond_map
37
+
38
+ def forward_fine(self, x_fine: torch.Tensor) -> torch.Tensor:
39
+ assert x_fine.dim() == 4
40
+ feats = self.encoder(x_fine)
41
+ logits_fine = self.fine_head(feats)
42
+ return logits_fine
src/wireseghr/train.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import pprint
4
+ import yaml
5
+
6
+
7
+ def main():
8
+ parser = argparse.ArgumentParser(description="WireSegHR training (skeleton)")
9
+ parser.add_argument("--config", type=str, default="configs/default.yaml", help="Path to YAML config")
10
+ args = parser.parse_args()
11
+
12
+ cfg_path = args.config
13
+ if not os.path.isabs(cfg_path):
14
+ cfg_path = os.path.join(os.getcwd(), cfg_path)
15
+
16
+ with open(cfg_path, "r") as f:
17
+ cfg = yaml.safe_load(f)
18
+
19
+ print("[WireSegHR][train] Loaded config from:", cfg_path)
20
+ pprint.pprint(cfg)
21
+ print("[WireSegHR][train] Skeleton OK. Implement training per SEGMENTATION_PLAN.md.")
22
+
23
+
24
+ if __name__ == "__main__":
25
+ main()
src/wireseghr/utils.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ def log(msg: str):
2
+ print(f"[WireSegHR] {msg}")
tests/test_model_forward.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from wireseghr.model import WireSegHR
4
+
5
+
6
+ def test_wireseghr_forward_shapes():
7
+ # Use small input to keep test light and avoid downloading weights
8
+ model = WireSegHR(backbone="mit_b3", in_channels=3, pretrained=False)
9
+
10
+ x = torch.randn(1, 3, 64, 64)
11
+ logits_coarse, cond = model.forward_coarse(x)
12
+ assert logits_coarse.shape[0] == 1 and logits_coarse.shape[1] == 2
13
+ assert cond.shape[0] == 1 and cond.shape[1] == 1
14
+ # Expect stage 0 resolution ~ 1/4 of input for MiT
15
+ assert logits_coarse.shape[2] == 16 and logits_coarse.shape[3] == 16
16
+ assert cond.shape[2] == 16 and cond.shape[3] == 16
17
+
18
+ logits_fine = model.forward_fine(x)
19
+ assert logits_fine.shape == logits_coarse.shape
tests/test_skeleton_imports.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ def test_imports():
2
+ import wireseghr
3
+ import wireseghr.model as m
4
+ import wireseghr.data as d
5
+
6
+ assert hasattr(m, "SegFormerEncoder")
7
+ assert hasattr(d, "WireSegDataset")