Project skeleton generated
- .gitignore +116 -0
- .windsurf/rules/creating-new-class-variables.md +18 -0
- .windsurf/rules/defensive-logic.md +8 -0
- .windsurf/rules/running-tests.md +6 -0
- .windsurf/rules/when-pytest-not-found.md +6 -0
- README.md +31 -0
- SEGMENTATION_PLAN.md +130 -0
- WireSegHR-tex.tar.gz +0 -3
- configs/default.yaml +42 -0
- pytest.ini +2 -0
- requirements.txt +8 -0
- src/wireseghr/__init__.py +6 -0
- src/wireseghr/data/__init__.py +7 -0
- src/wireseghr/data/dataset.py +54 -0
- src/wireseghr/data/sampler.py +14 -0
- src/wireseghr/data/transforms.py +9 -0
- src/wireseghr/infer.py +27 -0
- src/wireseghr/metrics.py +9 -0
- src/wireseghr/model/__init__.py +16 -0
- src/wireseghr/model/condition.py +14 -0
- src/wireseghr/model/decoder.py +58 -0
- src/wireseghr/model/encoder.py +42 -0
- src/wireseghr/model/label_downsample.py +26 -0
- src/wireseghr/model/minmax.py +29 -0
- src/wireseghr/model/model.py +42 -0
- src/wireseghr/train.py +25 -0
- src/wireseghr/utils.py +2 -0
- tests/test_model_forward.py +19 -0
- tests/test_skeleton_imports.py +7 -0
.gitignore
ADDED
@@ -0,0 +1,116 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
pytestdebug.log

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
Pipfile.lock

# poetry
poetry.lock

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype
.pytype/

# Cython debug symbols
cython_debug/

# VS Code
.vscode/

# Mac
.DS_Store

# Environment
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
.windsurf/rules/creating-new-class-variables.md
ADDED
@@ -0,0 +1,18 @@
---
trigger: model_decision
description: When creating new class variable
---

It is preferable to create class variable docstrings instead of comments. E.g.:

```py
class Class123:
    var1: int
    """Variable description"""
```
is preferred over
```py
class Class123:
    # Variable description
    var1: int
```
.windsurf/rules/defensive-logic.md
ADDED
@@ -0,0 +1,8 @@
---
trigger: always_on
---

When deciding whether to write defensive logic, e.g. dimensionality handling: `tensor1 = tensor1.unsqueeze(0) if tensor1.ndim == 1 else tensor1`, or None handling: `var1 = var1 if var1 else torch.zeros(...)`, just don't write these things. In my code, shapes are always static, and there is one execution path for all code. I prefer `assert` over defensive logic. If you are writing something like this to fix the tests and it seems necessary, it's likely that the tests are set up incorrectly.
The reason I don't want it is that defensive logic leads to silent failures, and these are bad for debugging.

In addition, `int()` type casting is also a piece of defensive logic, and it slows down the application. Don't write it unless really necessary, e.g. when converting an int to a string. In most cases, indexing with a tensor should be better.
.windsurf/rules/running-tests.md
ADDED
@@ -0,0 +1,6 @@
---
trigger: model_decision
description: When deciding which tests to run
---

When calling `pytest`, never execute the whole test suite, because it takes over 5 minutes to run. So, never run `pytest -q` without a test file or `-k` filter. Instead, run an individual test function, class, or module, or a combination of them.
.windsurf/rules/when-pytest-not-found.md
ADDED
@@ -0,0 +1,6 @@
---
trigger: model_decision
description: When pytest is not found
---

Sometimes when running tests, `pytest` is not found. This means that the `venv` is not activated; you can activate it with `source ../.venv/bin/activate`. After that, it will certainly work. Do not try to run `./venv/bin/pytest` directly; activate the venv and then run `pytest` as usual.
README.md
ADDED
@@ -0,0 +1,31 @@
# WireSegHR (Segmentation Only)

This repository contains the segmentation-only implementation plan and code skeleton for the two-stage WireSegHR model (global-to-local, shared encoder).

- Paper sources live under `paper-tex/`.
- Long-term navigation plan: `SEGMENTATION_PLAN.md`.

## Quick Start (skeleton)

1) Create a virtual environment and install requirements:

```bash
python -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
```

2) Print configuration and verify the skeleton runs:

```bash
python src/wireseghr/train.py --config configs/default.yaml
python src/wireseghr/infer.py --config configs/default.yaml --image /path/to/image.png
```

3) Next steps:
- Implement encoder/decoders/condition/minmax/label downsampling per `SEGMENTATION_PLAN.md`.
- Implement training and inference logic, then metrics and ablations.

## Notes
- This is a segmentation-only codebase. Inpainting is out of scope here.
- Defaults locked: MiT-B3 encoder, patch size 768, MinMax 6×6, global+binary mask conditioning with patch-cropped global map.
SEGMENTATION_PLAN.md
ADDED
@@ -0,0 +1,130 @@
# WireSegHR Segmentation-Only Implementation Plan

This plan distills the model and pipeline described in the paper sources:
- `paper-tex/sections/method.tex`
- `paper-tex/sections/method_yq.tex`
- `paper-tex/figure_tex/pipeline.tex`
- `paper-tex/tables/{component,logit,thresholds}.tex`

Focus: segmentation only (no dataset collection or inpainting).

## Decisions and Defaults (locked)
- Backbone: SegFormer MiT-B3 (shared encoder `E`).
- Fine/local patch size p: 768.
- Conditioning: global map + binary location mask by default (Table `tables/logit.tex`).
- Conditioning map scope: patch-cropped from the global map per `paper-tex/sections/method_yq.tex` (no full-image concatenation variant).
- MinMax feature augmentation: luminance min and max with a fixed 6×6 window; channels concatenated to inputs (Figure `figure_tex/pipeline.tex`, Sec. "Wire Feature Preservation" in `method_yq.tex`).
- Loss: CE on both branches, λ = 1 (`method_yq.tex`).
- α-threshold for refining windows: default 0.01 (Table `tables/thresholds.tex`).
- Coarse input size: train 512×512; test 1024×1024 (`method.tex`).
- Optim: AdamW (lr=6e-5, wd=0.01, poly schedule with power=1), ~40k iters, batch size ~8 (`method.tex`).

## Project Structure
- `configs/`
  - `default.yaml` (backbone=mit_b3, p=768, coarse_train=512, coarse_test=1024, alpha=0.01, minmax=true, kernel=6, maxpool_label=true, cond_variant=global+binary_mask)
- `src/wireseghr/`
  - `model/`
    - `encoder.py` (SegFormer MiT-B3, N_in channels expansion)
    - `decoder.py` (two MLP decoders `D_C`, `D_F` for 2 classes)
    - `condition.py` (1×1 conv to collapse coarse 2-ch logits → 1-ch cond)
    - `minmax.py` (6×6 luminance min/max filtering)
    - `label_downsample.py` (MaxPool-based coarse GT downsampling)
  - `data/`
    - `dataset.py` (image/mask loading, full-res to coarse/fine inputs)
    - `sampler.py` (balanced patch sampling with ≥1% wire pixels)
    - `transforms.py` (scaling, rotation, flip, photometric distortion)
  - `train.py` (end-to-end two-branch training)
  - `infer.py` (coarse-to-fine sliding-window inference + stitching)
  - `metrics.py` (IoU, F1, Precision, Recall)
  - `utils.py` (misc: overlap blending, seeding, logging)
- `tests/` (unit tests for channel wiring, cond alignment, stitching)
- `README.md` (segmentation-only usage)

## Model Specification
- Shared encoder `E`: SegFormer MiT-B3.
  - Input channels (default): 3 (RGB) + 2 (MinMax) + 1 (global cond) + 1 (binary location) = 7.
  - For the coarse pass, the cond and location channels are zeros to keep the channel count consistent (`method_yq.tex`).
  - Weight init for extra channels: copy mean of RGB conv weights or zero-init.
- Decoders: two SegFormer MLP decoders
  - `D_C`: coarse logits (2 channels) at coarse resolution.
  - `D_F`: fine logits (2 channels) at patch resolution p×p.
- Conditioning to fine branch (default):
  - Take coarse pre-softmax logits (2-ch), apply 1×1 conv → 1-ch cond map (`method.tex`).
  - Binary location mask: 1 inside the current patch region (in full-image coordinates), 0 elsewhere.
  - Pass the patch-aligned cond crop and binary mask as channels to the fine branch input.
- Notes:
  - We expose a config toggle to switch the conditioning variant between `global+binary_mask` (default) and `global_only` (Table `tables/logit.tex`).
  - We follow the published version (`paper-tex/sections/method_yq.tex`) and use patch-cropped conditioning exclusively; no full-image conditioning variant will be implemented.

## Data and Preprocessing
- MinMax luminance features (both branches):
  - Y = 0.299R + 0.587G + 0.114B.
  - Y_min = min filter (6×6), Y_max = max filter (6×6).
  - Concat [Y_min, Y_max] to the input image channels.
- Coarse GT label generation (MaxPool):
  - Downsample the full-res mask to coarse size with max-pooling to prevent wire vanishing (`method_yq.tex`).
- Normalization: standard mean/std per backbone; apply consistently across channels (new channels can be mean=0, std=1 by convention, or min-max scaled).

## Training Pipeline
- Augment the full-res image (scaling, rotation, horizontal flip, photometric distortion) before constructing coarse/fine inputs (`method.tex`).
- Coarse input: downsample the augmented full image to 512×512; build channels [RGB + MinMax + zeros(2)] → `E` → `D_C`.
- Fine input (per iteration, select 1–k patches):
  - Sample a p×p patch (p=768) with ≥1% wire pixels (`method.tex`, `method_yq.tex`).
  - Build the cond map from coarse logits via 1×1 conv; crop cond to the patch region.
  - Build the binary location mask for the patch region.
  - Build channels [RGB + MinMax + cond + location] → `E` → `D_F`.
- Losses:
  - L_glo = CE(Softmax(`D_C(E(coarse))`), G_glo), where G_glo uses MaxPool downsample.
  - L_loc = CE(Softmax(`D_F(E(fine))`), G_loc).
  - L = L_glo + λ L_loc, λ=1 (`method_yq.tex`).
- Optimization:
  - AdamW (lr=6e-5, wd=0.01), poly schedule (power=1.0), ~40k iterations, batch ≈8 (tune by memory).
  - AMP and grad accumulation recommended for stability/memory.

## Inference Pipeline
- Coarse pass:
  - Downsample to 1024×1024; predict coarse probability/logits.
- Window proposal (sliding window on full-res):
  - Tile with patch size p=768 and ~128px overlap (configurable). Compute the wire fraction within each window from the coarse prediction (prob > 0.5).
  - If the fraction ≥ α (default 0.01), run fine refinement on that patch; else skip (Table `tables/thresholds.tex`).
- Fine refinement + stitching:
  - For selected windows, build the fine input with cond crop + location mask; predict logits.
  - Stitch logits into a full-res canvas; average in overlaps; final argmax over classes.
- Outputs: full-res binary mask, plus optional probability map.

## Metrics and Reporting
- Implement: IoU, F1, Precision, Recall (global, and optionally per-size bins if available) matching `tables/component.tex`.
- Validate α trade-offs following `tables/thresholds.tex`.
- Ablations: MinMax on/off, MaxPool on/off, conditioning variant (Table `tables/logit.tex`).

## Configuration Surface (key)
- Backbone/weights: `mit_b3` (pretrained ImageNet-1K).
- Sizes: `p=768`, `coarse_train=512`, `coarse_test=1024`, `overlap=128`.
- Conditioning: `use_binary_location=true`, `cond_from='coarse_logits_1x1'`, `cond_crop='patch'`.
- MinMax: `enable=true`, `kernel=6`.
- Label: `coarse_label_downsample='maxpool'`.
- Training: `iters=40000`, `batch=8`, `lr=6e-5`, `wd=0.01`, `schedule='poly'`, `power=1.0`.
- Inference: `alpha=0.01`, `prob_threshold=0.5` for wire fraction, `stitch='avg_logits'`.

## Risks / Gotchas
- Channel expansion requires careful initialization; confirm no NaNs and stable early training.
- Precise spatial alignment of the cond and location mask with the patch is critical. Add assertions/tests.
- The even-sized MinMax window (6×6) requires careful padding to maintain alignment.
- Memory with p=768 and MiT-B3 may need tuning (AMP, batch size, overlap).

## Milestones
1) Skeleton + configs + metrics.
2) Encoder channel expansion + two decoders + 1×1 cond.
3) MinMax (6×6) + MaxPool label downsampling.
4) Training loop with ≥1% wire patch sampling.
5) Inference α-threshold + stitching.
6) Ablation toggles + scripts + README.
7) Tests (channel wiring, cond/mask alignment, stitching correctness).

## References (paper sources)
- `paper-tex/sections/method.tex`: Two-stage design, shared encoder, 1×1 cond, training/inference sizes, optimizer/schedule.
- `paper-tex/sections/method_yq.tex`: CE losses, λ, sliding-window with α, MinMax & MaxPool rationale.
- `paper-tex/figure_tex/pipeline.tex`: System overview; MinMax concatenation.
- `paper-tex/tables/component.tex`: Ablation of MinMax/MaxPool/coarse.
- `paper-tex/tables/logit.tex`: Conditioning variants.
- `paper-tex/tables/thresholds.tex`: α vs speed/quality.
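As a reading aid for the window-proposal step above, here is a minimal sketch of α-threshold tiling over the coarse probability map. The function name, the `stride = p - overlap` tiling, and the border clamping are assumptions for illustration; the skeleton's `infer.py` does not implement this yet, and it presumes the image is at least p pixels on each side.

```py
import numpy as np


def propose_windows(coarse_prob: np.ndarray, full_h: int, full_w: int,
                    p: int = 768, overlap: int = 128, alpha: float = 0.01):
    """Yield (y, x) corners of full-res p x p tiles whose coarse wire fraction >= alpha."""
    assert full_h >= p and full_w >= p
    ch, cw = coarse_prob.shape
    coarse_bin = coarse_prob > 0.5  # prob_threshold from the config
    stride = p - overlap
    ys = list(range(0, full_h - p, stride)) + [full_h - p]  # clamp the last tile to the border
    xs = list(range(0, full_w - p, stride)) + [full_w - p]
    for y in ys:
        for x in xs:
            # Map the full-res window into coarse-map coordinates
            y0, y1 = y * ch // full_h, (y + p) * ch // full_h
            x0, x1 = x * cw // full_w, (x + p) * cw // full_w
            frac = coarse_bin[y0:y1, x0:x1].mean()
            if frac >= alpha:
                yield y, x
```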
WireSegHR-tex.tar.gz
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0a17096a2eaad07f51345426465697fcf0ee1a0c5b54aa6742b4ac23406f6bc4
size 33690376
configs/default.yaml
ADDED
@@ -0,0 +1,42 @@
# Default configuration for WireSegHR (segmentation-only)
backbone: mit_b3

coarse:
  train_size: 512
  test_size: 1024

fine:
  patch_size: 768
  overlap: 128

conditioning:
  use_binary_location: true
  cond_from: coarse_logits_1x1
  cond_crop: patch  # per published method (method_yq)

minmax:
  enable: true
  kernel: 6  # fixed 6x6 luminance min/max

label:
  coarse_downsample: maxpool

inference:
  alpha: 0.01
  prob_threshold: 0.5
  stitch: avg_logits

optim:
  iters: 40000
  batch_size: 8
  lr: 6e-5
  weight_decay: 0.01
  schedule: poly
  power: 1.0

# dataset paths (placeholders)
data:
  train_images: /path/to/train/images
  train_masks: /path/to/train/masks
  val_images: /path/to/val/images
  val_masks: /path/to/val/masks
pytest.ini
ADDED
@@ -0,0 +1,2 @@
[pytest]
pythonpath = src
requirements.txt
ADDED
@@ -0,0 +1,8 @@
torch>=2.1.0
torchvision>=0.16.0
timm>=0.9.8
numpy>=1.24.0
opencv-python>=4.8.0.76
Pillow>=9.5.0
PyYAML>=6.0.1
tqdm>=4.65.0
src/wireseghr/__init__.py
ADDED
@@ -0,0 +1,6 @@
__all__ = [
    "model",
    "data",
]

__version__ = "0.1.0"
src/wireseghr/data/__init__.py
ADDED
@@ -0,0 +1,7 @@
from .dataset import WireSegDataset
from .sampler import BalancedPatchSampler

__all__ = [
    "WireSegDataset",
    "BalancedPatchSampler",
]
src/wireseghr/data/dataset.py
ADDED
@@ -0,0 +1,54 @@
# Dataset placeholder for wire segmentation
"""WireSeg dataset indexing and loading.

Pairs images in `images_dir` with masks in `masks_dir` by matching filename stems.
Masks are loaded as single-channel 0/1.
"""

from pathlib import Path
from typing import Any, Dict, List, Tuple

import cv2
import numpy as np


class WireSegDataset:
    def __init__(self, images_dir: str, masks_dir: str, split: str = "train"):
        self.images_dir = Path(images_dir)
        self.masks_dir = Path(masks_dir)
        self.split = split
        assert self.images_dir.exists(), f"Missing images_dir: {self.images_dir}"
        assert self.masks_dir.exists(), f"Missing masks_dir: {self.masks_dir}"
        self._items: List[Tuple[Path, Path]] = self._index_pairs()

    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        img_path, mask_path = self._items[idx]
        img_bgr = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
        assert img_bgr is not None, f"Failed to read image: {img_path}"
        img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        assert mask is not None, f"Failed to read mask: {mask_path}"
        mask_bin = (mask > 0).astype(np.uint8)
        return {"image": img, "mask": mask_bin, "image_path": str(img_path), "mask_path": str(mask_path)}

    def _index_pairs(self) -> List[Tuple[Path, Path]]:
        # Images and masks share the same set of accepted extensions
        exts = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"}
        imgs: Dict[str, Path] = {}
        for p in sorted(self.images_dir.rglob("*")):
            if p.is_file() and p.suffix.lower() in exts:
                imgs[p.stem] = p
        masks: Dict[str, Path] = {}
        for p in sorted(self.masks_dir.rglob("*")):
            if p.is_file() and p.suffix.lower() in exts:
                masks[p.stem] = p
        pairs: List[Tuple[Path, Path]] = []
        for stem, ip in imgs.items():
            if stem in masks:
                pairs.append((ip, masks[stem]))
        assert len(pairs) > 0, f"No image-mask pairs found in {self.images_dir} and {self.masks_dir}"
        return pairs
src/wireseghr/data/sampler.py
ADDED
@@ -0,0 +1,14 @@
# Balanced patch sampler (>=1% wire pixels)
# TODO: Implement logic over the mask to pick patches with wire ratio >= threshold.

from dataclasses import dataclass


@dataclass
class BalancedPatchSampler:
    patch_size: int = 768
    min_wire_ratio: float = 0.01

    def sample(self, image, mask):
        # TODO: sample and return the top-left (y, x) of a valid patch
        return 0, 0
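A minimal sketch of what `sample` might become, assuming `mask` is an HxW 0/1 array: rejection-sample random corners until the wire ratio clears `min_wire_ratio`, falling back to the wire-richest candidate seen. The `rng` and `tries` parameters are assumptions, not part of the skeleton.

```py
import numpy as np


def sample_patch(mask: np.ndarray, rng: np.random.Generator,
                 patch_size: int = 768, min_wire_ratio: float = 0.01,
                 tries: int = 50):
    """Return the (y, x) top-left corner of a patch whose wire ratio >= min_wire_ratio."""
    h, w = mask.shape
    assert h >= patch_size and w >= patch_size
    best_yx, best_ratio = (0, 0), -1.0
    for _ in range(tries):
        y = rng.integers(0, h - patch_size + 1)
        x = rng.integers(0, w - patch_size + 1)
        # Mean of a 0/1 mask crop is exactly the wire-pixel ratio
        ratio = mask[y:y + patch_size, x:x + patch_size].mean()
        if ratio >= min_wire_ratio:
            return y, x
        if ratio > best_ratio:
            best_yx, best_ratio = (y, x), ratio
    return best_yx  # fall back to the wire-richest candidate seen
```

Usage would look like `y, x = sample_patch(item["mask"], np.random.default_rng(0))` on an item from `WireSegDataset`.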
src/wireseghr/data/transforms.py
ADDED
@@ -0,0 +1,9 @@
# Training-time transforms: scaling, rotation, flip, photometric distortion
# TODO: Implement deterministic transform composition for reproducibility

class TrainTransforms:
    def __init__(self):
        pass

    def __call__(self, image, mask):
        return image, mask
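For the flip-and-scale part of that TODO, one possible cv2-based sketch; the `scale_range` default and the shared `rng` are assumptions, and rotation and photometric distortion would compose similarly:

```py
import cv2
import numpy as np


def flip_and_scale(image: np.ndarray, mask: np.ndarray,
                   rng: np.random.Generator, scale_range=(0.5, 2.0)):
    """Random horizontal flip + rescale, applied identically to image and mask."""
    if rng.random() < 0.5:
        image, mask = image[:, ::-1].copy(), mask[:, ::-1].copy()
    s = rng.uniform(*scale_range)
    h, w = mask.shape
    size = (round(w * s), round(h * s))  # cv2 takes (width, height)
    image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)
    # INTER_NEAREST keeps the mask strictly 0/1 after resizing
    mask = cv2.resize(mask, size, interpolation=cv2.INTER_NEAREST)
    return image, mask
```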
src/wireseghr/infer.py
ADDED
@@ -0,0 +1,27 @@
import argparse
import os
import pprint

import yaml


def main():
    parser = argparse.ArgumentParser(description="WireSegHR inference (skeleton)")
    parser.add_argument("--config", type=str, default="configs/default.yaml", help="Path to YAML config")
    parser.add_argument("--image", type=str, required=False, help="Path to input image")
    args = parser.parse_args()

    cfg_path = args.config
    if not os.path.isabs(cfg_path):
        cfg_path = os.path.join(os.getcwd(), cfg_path)

    with open(cfg_path, "r") as f:
        cfg = yaml.safe_load(f)

    print("[WireSegHR][infer] Loaded config from:", cfg_path)
    pprint.pprint(cfg)
    print("[WireSegHR][infer] Image:", args.image)
    print("[WireSegHR][infer] Skeleton OK. Implement inference per SEGMENTATION_PLAN.md.")


if __name__ == "__main__":
    main()
src/wireseghr/metrics.py
ADDED
@@ -0,0 +1,9 @@
# Metrics placeholder: IoU, F1, Precision, Recall
# TODO: Implement proper metrics matching the paper tables.

from typing import Dict


def compute_metrics(pred_mask, gt_mask) -> Dict[str, float]:
    # TODO: implement
    return {"iou": 0.0, "f1": 0.0, "precision": 0.0, "recall": 0.0}
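A minimal sketch of the eventual computation for the wire (positive) class, assuming both inputs are HxW 0/1 numpy arrays; the `eps` guard for empty masks is an assumption to keep the ratios finite:

```py
import numpy as np


def binary_metrics(pred: np.ndarray, gt: np.ndarray, eps: float = 1e-8) -> dict:
    """IoU, F1, Precision, Recall for the positive class of two 0/1 masks."""
    assert pred.shape == gt.shape
    pred, gt = pred.astype(bool), gt.astype(bool)
    tp = np.sum(pred & gt)
    fp = np.sum(pred & ~gt)
    fn = np.sum(~pred & gt)
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    return {
        "iou": tp / (tp + fp + fn + eps),
        "f1": 2 * precision * recall / (precision + recall + eps),
        "precision": precision,
        "recall": recall,
    }
```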
src/wireseghr/model/__init__.py
ADDED
@@ -0,0 +1,16 @@
from .encoder import SegFormerEncoder
from .decoder import CoarseDecoder, FineDecoder
from .condition import Conditioning1x1
from .minmax import MinMaxLuminance
from .label_downsample import downsample_label_maxpool
from .model import WireSegHR

__all__ = [
    "SegFormerEncoder",
    "CoarseDecoder",
    "FineDecoder",
    "Conditioning1x1",
    "MinMaxLuminance",
    "downsample_label_maxpool",
    "WireSegHR",
]
src/wireseghr/model/condition.py
ADDED
@@ -0,0 +1,14 @@
# 1x1 conv to collapse 2-ch coarse logits into a 1-ch conditioning map
# TODO: Wire with coarse decoder outputs and proper resize/cropping.

import torch
import torch.nn as nn


class Conditioning1x1(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size=1, bias=True)

    def forward(self, coarse_logits: torch.Tensor) -> torch.Tensor:
        return self.conv(coarse_logits)
src/wireseghr/model/decoder.py
ADDED
@@ -0,0 +1,58 @@
"""SegFormer-like multi-scale decoder heads for the coarse and fine branches.

Fuse four feature maps from the MiT encoder via 1x1 projections, upsample to the
highest spatial resolution (stage 0), concatenate, and predict 2-class logits.
"""

from typing import List, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F


class _ConvBNReLU(nn.Module):
    def __init__(self, in_ch: int, out_ch: int, k: int, s: int = 1, p: int = 0):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=k, stride=s, padding=p, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class _SegFormerHead(nn.Module):
    def __init__(self, in_chs: List[int], embed_dim: int = 128, num_classes: int = 2):
        super().__init__()
        assert len(in_chs) == 4
        self.proj = nn.ModuleList([nn.Conv2d(c, embed_dim, kernel_size=1) for c in in_chs])
        self.fuse = _ConvBNReLU(embed_dim * 4, embed_dim, k=3, p=1)
        self.cls = nn.Conv2d(embed_dim, num_classes, kernel_size=1)

    def forward(self, feats: List[torch.Tensor]) -> torch.Tensor:
        assert len(feats) == 4
        h, w = feats[0].shape[2], feats[0].shape[3]
        xs = []
        for f, proj in zip(feats, self.proj):
            x = proj(f)
            if x.shape[2] != h or x.shape[3] != w:
                x = F.interpolate(x, size=(h, w), mode="bilinear", align_corners=False)
            xs.append(x)
        x = torch.cat(xs, dim=1)
        x = self.fuse(x)
        x = self.cls(x)
        return x


class CoarseDecoder(_SegFormerHead):
    def __init__(self, in_chs: Sequence[int] = (64, 128, 320, 512), embed_dim: int = 128, num_classes: int = 2):
        super().__init__(list(in_chs), embed_dim, num_classes)


class FineDecoder(_SegFormerHead):
    def __init__(self, in_chs: Sequence[int] = (64, 128, 320, 512), embed_dim: int = 128, num_classes: int = 2):
        super().__init__(list(in_chs), embed_dim, num_classes)
src/wireseghr/model/encoder.py
ADDED
@@ -0,0 +1,42 @@
"""SegFormer MiT encoder wrapper with adjustable input channels.

Uses timm to instantiate MiT (e.g., mit_b3) and returns a list of multi-scale
features [C1, C2, C3, C4].
"""

from typing import List, Tuple

import timm
import torch
import torch.nn as nn


class SegFormerEncoder(nn.Module):
    def __init__(
        self,
        backbone: str = "mit_b3",
        in_channels: int = 7,
        pretrained: bool = True,
        out_indices: Tuple[int, int, int, int] = (0, 1, 2, 3),
    ):
        super().__init__()
        self.backbone_name = backbone
        self.in_channels = in_channels
        self.pretrained = pretrained
        self.out_indices = out_indices

        # Create MiT with features_only to obtain multi-scale feature maps.
        # in_chans allows expanded inputs (RGB + minmax + cond + loc).
        self.encoder = timm.create_model(
            backbone,
            pretrained=pretrained,
            features_only=True,
            out_indices=out_indices,
            in_chans=in_channels,
        )

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        feats = self.encoder(x)
        # Ensure a list of tensors is returned
        assert isinstance(feats, (list, tuple)) and len(feats) == len(self.out_indices)
        return list(feats)
src/wireseghr/model/label_downsample.py
ADDED
@@ -0,0 +1,26 @@
# MaxPool-based downsampling for coarse labels
"""Downsample binary masks while preserving thin positives.

We use area-based resize on float32 masks followed by a > 0 threshold.
This emulates block-wise max pooling: any positive in the source region
produces a positive in the target pixel.
"""

import cv2
import numpy as np


def downsample_label_maxpool(mask: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    """
    Args:
        mask: HxW binary (0/1) numpy array
        out_h, out_w: target size
    Returns:
        out_h x out_w binary array via a max-pooling-like downsample
    """
    assert mask.ndim == 2
    # float32 so that area resize yields fractional averages > 0 if any positive is present
    m = mask.astype(np.float32)
    r = cv2.resize(m, (out_w, out_h), interpolation=cv2.INTER_AREA)
    out = (r > 0.0).astype(np.uint8)
    return out
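A quick illustration of that thin-structure property (imports assume `src` is on `PYTHONPATH`, as configured in `pytest.ini`): a one-pixel-wide wire survives the 16x downsample, whereas plain nearest-neighbor resizing may keep or drop it depending on grid phase.

```py
import numpy as np
from wireseghr.model.label_downsample import downsample_label_maxpool

mask = np.zeros((1024, 1024), dtype=np.uint8)
mask[500, :] = 1  # a one-pixel-wide horizontal wire

small = downsample_label_maxpool(mask, 64, 64)
assert small[31, :].all()  # row 500 falls in block row 31 and is preserved
```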
src/wireseghr/model/minmax.py
ADDED
@@ -0,0 +1,29 @@
# MinMax luminance feature computation (6x6 window)
# Implemented with OpenCV morphology (erode = min, dilate = max) using a 6x6 kernel and replicate border.

from typing import Tuple

import numpy as np


class MinMaxLuminance:
    def __init__(self, kernel: int = 6):
        assert kernel == 6, "Per plan, kernel is fixed to 6x6"
        self.kernel = kernel

    def __call__(self, img_rgb: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Args:
            img_rgb: HxWx3 uint8 or float32 in [0, 255] or [0, 1]
        Returns:
            (Y_min, Y_max): two HxW float32 arrays
        """
        assert img_rgb.ndim == 3 and img_rgb.shape[2] == 3
        r, g, b = img_rgb[..., 0], img_rgb[..., 1], img_rgb[..., 2]
        y = (0.299 * r + 0.587 * g + 0.114 * b).astype(np.float32)

        import cv2  # lazy import to avoid a test-time dependency at module import
        kernel = np.ones((self.kernel, self.kernel), dtype=np.uint8)
        y_min = cv2.erode(y, kernel, borderType=cv2.BORDER_REPLICATE)
        y_max = cv2.dilate(y, kernel, borderType=cv2.BORDER_REPLICATE)
        return y_min.astype(np.float32), y_max.astype(np.float32)
src/wireseghr/model/model.py
ADDED
@@ -0,0 +1,42 @@
from typing import Tuple

import torch
import torch.nn as nn

from .encoder import SegFormerEncoder
from .decoder import CoarseDecoder, FineDecoder
from .condition import Conditioning1x1


class WireSegHR(nn.Module):
    """
    Two-stage WireSegHR model wrapper with a shared encoder.

    Expects callers to prepare input channel stacks according to the plan:
    - Coarse input: RGB + MinMax (and any extra channels per config), shape (B, Cc, Hc, Wc)
    - Fine input: RGB + MinMax + cond_crop + binary_location_mask, shape (B, Cf, p, p)

    The conditioning 1x1 conv is applied to coarse logits to produce a single-channel map.
    """

    def __init__(self, backbone: str = "mit_b3", in_channels: int = 7, pretrained: bool = True):
        super().__init__()
        self.encoder = SegFormerEncoder(backbone=backbone, in_channels=in_channels, pretrained=pretrained)
        # Default MiT-B3 channel dims for the four stages
        in_chs = (64, 128, 320, 512)
        self.coarse_head = CoarseDecoder(in_chs=in_chs, embed_dim=128, num_classes=2)
        self.fine_head = FineDecoder(in_chs=in_chs, embed_dim=128, num_classes=2)
        self.cond1x1 = Conditioning1x1()

    def forward_coarse(self, x_coarse: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        assert x_coarse.dim() == 4
        feats = self.encoder(x_coarse)
        logits_coarse = self.coarse_head(feats)
        cond_map = self.cond1x1(logits_coarse)
        return logits_coarse, cond_map

    def forward_fine(self, x_fine: torch.Tensor) -> torch.Tensor:
        assert x_fine.dim() == 4
        feats = self.encoder(x_fine)
        logits_fine = self.fine_head(feats)
        return logits_fine
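For orientation, a hedged sketch of how a caller could assemble the 7-channel fine input this wrapper expects. The helper name, the tensor shapes, and the interpolation of the cond crop up to patch resolution are assumptions; the actual wiring belongs in `train.py`/`infer.py`.

```py
import torch
import torch.nn.functional as F


def build_fine_input(rgb: torch.Tensor, y_min: torch.Tensor, y_max: torch.Tensor,
                     cond_crop: torch.Tensor, loc_mask: torch.Tensor,
                     p: int = 768) -> torch.Tensor:
    """Stack [RGB(3) + MinMax(2) + cond(1) + location(1)] = 7 channels at p x p.

    Assumed shapes: rgb (B, 3, p, p); y_min/y_max (B, 1, p, p);
    cond_crop/loc_mask (B, 1, h, w) at coarse resolution.
    """
    # Upsample the patch-cropped cond map to patch resolution; nearest keeps the mask binary
    cond_up = F.interpolate(cond_crop, size=(p, p), mode="bilinear", align_corners=False)
    loc_up = F.interpolate(loc_mask, size=(p, p), mode="nearest")
    x = torch.cat([rgb, y_min, y_max, cond_up, loc_up], dim=1)
    assert x.shape[1] == 7
    return x
```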
src/wireseghr/train.py
ADDED
@@ -0,0 +1,25 @@
import argparse
import os
import pprint

import yaml


def main():
    parser = argparse.ArgumentParser(description="WireSegHR training (skeleton)")
    parser.add_argument("--config", type=str, default="configs/default.yaml", help="Path to YAML config")
    args = parser.parse_args()

    cfg_path = args.config
    if not os.path.isabs(cfg_path):
        cfg_path = os.path.join(os.getcwd(), cfg_path)

    with open(cfg_path, "r") as f:
        cfg = yaml.safe_load(f)

    print("[WireSegHR][train] Loaded config from:", cfg_path)
    pprint.pprint(cfg)
    print("[WireSegHR][train] Skeleton OK. Implement training per SEGMENTATION_PLAN.md.")


if __name__ == "__main__":
    main()
src/wireseghr/utils.py
ADDED
@@ -0,0 +1,2 @@
def log(msg: str):
    print(f"[WireSegHR] {msg}")
tests/test_model_forward.py
ADDED
@@ -0,0 +1,19 @@
import torch

from wireseghr.model import WireSegHR


def test_wireseghr_forward_shapes():
    # Use a small input to keep the test light and avoid downloading weights
    model = WireSegHR(backbone="mit_b3", in_channels=3, pretrained=False)

    x = torch.randn(1, 3, 64, 64)
    logits_coarse, cond = model.forward_coarse(x)
    assert logits_coarse.shape[0] == 1 and logits_coarse.shape[1] == 2
    assert cond.shape[0] == 1 and cond.shape[1] == 1
    # Expect stage-0 resolution ~ 1/4 of the input for MiT
    assert logits_coarse.shape[2] == 16 and logits_coarse.shape[3] == 16
    assert cond.shape[2] == 16 and cond.shape[3] == 16

    logits_fine = model.forward_fine(x)
    assert logits_fine.shape == logits_coarse.shape
tests/test_skeleton_imports.py
ADDED
@@ -0,0 +1,7 @@
def test_imports():
    import wireseghr
    import wireseghr.model as m
    import wireseghr.data as d

    assert hasattr(m, "SegFormerEncoder")
    assert hasattr(d, "WireSegDataset")