MRiabov commited on
Commit
500cad9
·
1 Parent(s): d2db539

Eval to infer.py

Browse files
configs/default.yaml CHANGED
@@ -7,7 +7,7 @@ coarse:
7
  test_size: 1024
8
 
9
  fine:
10
- patch_size: 768
11
  overlap: 128
12
 
13
  conditioning:
@@ -23,12 +23,13 @@ label:
23
 
24
  inference:
25
  alpha: 0.01
26
- prob_threshold: 0.3 # was 0.5, not actually mentioned in the paper.
 
27
  stitch: avg_logits
28
 
29
  eval:
30
  max_samples: 16
31
- fine_batch: 48
32
 
33
  optim:
34
  iters: 2000
@@ -44,7 +45,7 @@ seed: 42
44
  out_dir: runs/wireseghr
45
  eval_interval: 100
46
  ckpt_interval: 300
47
- resume: runs/wireseghr/ckpt_300.pt # optional
48
 
49
  # dataset paths (placeholders)
50
  data:
 
7
  test_size: 1024
8
 
9
  fine:
10
+ patch_size: 512
11
  overlap: 128
12
 
13
  conditioning:
 
23
 
24
  inference:
25
  alpha: 0.01
26
+ prob_threshold: 0.5 # default inference threshold per paper tuning
27
+ fine_patch_size: 1024
28
  stitch: avg_logits
29
 
30
  eval:
31
  max_samples: 16
32
+ fine_batch: 32
33
 
34
  optim:
35
  iters: 2000
 
45
  out_dir: runs/wireseghr
46
  eval_interval: 100
47
  ckpt_interval: 300
48
+ resume: runs/wireseghr/ckpt_1800.pt # optional
49
 
50
  # dataset paths (placeholders)
51
  data:
infer.py CHANGED
@@ -1,15 +1,260 @@
1
  import argparse
2
  import os
3
  import pprint
 
4
  import yaml
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  def main():
8
- parser = argparse.ArgumentParser(description="WireSegHR inference (skeleton)")
9
  parser.add_argument(
10
  "--config", type=str, default="configs/default.yaml", help="Path to YAML config"
11
  )
12
  parser.add_argument("--image", type=str, required=False, help="Path to input image")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  args = parser.parse_args()
14
 
15
  cfg_path = args.config
@@ -21,10 +266,66 @@ def main():
21
 
22
  print("[WireSegHR][infer] Loaded config from:", cfg_path)
23
  pprint.pprint(cfg)
24
- print("[WireSegHR][infer] Image:", args.image)
25
- print(
26
- "[WireSegHR][infer] Skeleton OK. Implement inference per SEGMENTATION_PLAN.md."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
 
30
  if __name__ == "__main__":
 
1
  import argparse
2
  import os
3
  import pprint
4
+ from typing import List, Tuple, Optional
5
  import yaml
6
 
7
+ import numpy as np
8
+ import cv2
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch.amp import autocast
12
+
13
+ from src.wireseghr.model import WireSegHR
14
+
15
+
16
+ def _pad_for_minmax(kernel: int) -> Tuple[int, int, int, int]:
17
+ # Replicate the padding logic from train.validate for even/odd kernels
18
+ if (kernel % 2) == 0:
19
+ return (kernel // 2 - 1, kernel // 2, kernel // 2 - 1, kernel // 2)
20
+ else:
21
+ return (kernel // 2, kernel // 2, kernel // 2, kernel // 2)
22
+
23
+
24
+ @torch.no_grad()
25
+ def _coarse_forward(
26
+ model: WireSegHR,
27
+ img_rgb: np.ndarray,
28
+ coarse_size: int,
29
+ minmax_enable: bool,
30
+ minmax_kernel: int,
31
+ device: torch.device,
32
+ amp_flag: bool,
33
+ amp_dtype,
34
+ ) -> Tuple[np.ndarray, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
35
+ # Convert to tensor on device
36
+ t_img = (
37
+ torch.from_numpy(np.transpose(img_rgb, (2, 0, 1)))
38
+ .unsqueeze(0)
39
+ .to(device)
40
+ .float()
41
+ ) # 1x3xHxW
42
+ H = img_rgb.shape[0]
43
+ W = img_rgb.shape[1]
44
+
45
+ rgb_c = F.interpolate(
46
+ t_img, size=(coarse_size, coarse_size), mode="bilinear", align_corners=False
47
+ )[0]
48
+ y_t = 0.299 * t_img[:, 0:1] + 0.587 * t_img[:, 1:2] + 0.114 * t_img[:, 2:3]
49
+ if minmax_enable:
50
+ pad = _pad_for_minmax(minmax_kernel)
51
+ y_p = F.pad(y_t, pad, mode="replicate")
52
+ y_max_full = F.max_pool2d(y_p, kernel_size=minmax_kernel, stride=1)
53
+ y_min_full = -F.max_pool2d(-y_p, kernel_size=minmax_kernel, stride=1)
54
+ else:
55
+ y_min_full = y_t
56
+ y_max_full = y_t
57
+ y_min_c = F.interpolate(
58
+ y_min_full,
59
+ size=(coarse_size, coarse_size),
60
+ mode="bilinear",
61
+ align_corners=False,
62
+ )[0]
63
+ y_max_c = F.interpolate(
64
+ y_max_full,
65
+ size=(coarse_size, coarse_size),
66
+ mode="bilinear",
67
+ align_corners=False,
68
+ )[0]
69
+ zeros_c = torch.zeros(1, coarse_size, coarse_size, device=device)
70
+ x_t = torch.cat([rgb_c, y_min_c, y_max_c, zeros_c], dim=0).unsqueeze(0)
71
+
72
+ with autocast(device_type=device.type, dtype=amp_dtype, enabled=amp_flag):
73
+ logits_c, cond_map = model.forward_coarse(x_t)
74
+ prob = torch.softmax(logits_c, dim=1)[:, 1:2]
75
+ prob_up = (
76
+ F.interpolate(prob, size=(H, W), mode="bilinear", align_corners=False)[0, 0]
77
+ .detach()
78
+ .cpu()
79
+ .numpy()
80
+ )
81
+ return prob_up, cond_map, t_img, y_min_full, y_max_full
82
+
83
+
84
+ @torch.no_grad()
85
+ def _tiled_fine_forward(
86
+ model: WireSegHR,
87
+ t_img: torch.Tensor, # 1x3xHxW on device
88
+ cond_map: torch.Tensor, # 1x1xhxw
89
+ y_min_full: torch.Tensor, # 1x1xHxW
90
+ y_max_full: torch.Tensor, # 1x1xHxW
91
+ patch_size: int,
92
+ overlap: int,
93
+ fine_batch: int,
94
+ device: torch.device,
95
+ amp_flag: bool,
96
+ amp_dtype,
97
+ ) -> np.ndarray:
98
+ H = int(t_img.shape[2])
99
+ W = int(t_img.shape[3])
100
+ P = patch_size
101
+ stride = P - overlap
102
+ assert stride > 0
103
+ assert H >= P and W >= P
104
+
105
+ prob_sum_t = torch.zeros((H, W), device=device, dtype=torch.float32)
106
+ weight_t = torch.zeros((H, W), device=device, dtype=torch.float32)
107
+
108
+ hc4, wc4 = cond_map.shape[2], cond_map.shape[3]
109
+
110
+ ys = list(range(0, H - P + 1, stride))
111
+ if ys[-1] != (H - P):
112
+ ys.append(H - P)
113
+ xs = list(range(0, W - P + 1, stride))
114
+ if xs[-1] != (W - P):
115
+ xs.append(W - P)
116
+
117
+ coords: List[Tuple[int, int]] = []
118
+ for y0 in ys:
119
+ for x0 in xs:
120
+ coords.append((y0, x0))
121
+
122
+ for i0 in range(0, len(coords), fine_batch):
123
+ batch_coords = coords[i0 : i0 + fine_batch]
124
+ xs_list: List[torch.Tensor] = []
125
+ for y0, x0 in batch_coords:
126
+ y1, x1 = y0 + P, x0 + P
127
+ # Map to cond grid
128
+ y0c = (y0 * hc4) // H
129
+ y1c = ((y1 * hc4) + H - 1) // H
130
+ x0c = (x0 * wc4) // W
131
+ x1c = ((x1 * wc4) + W - 1) // W
132
+ cond_sub = cond_map[:, :, y0c:y1c, x0c:x1c].float()
133
+ cond_patch = F.interpolate(
134
+ cond_sub, size=(P, P), mode="bilinear", align_corners=False
135
+ ).squeeze(1) # 1xPxP
136
+
137
+ rgb_t = t_img[0, :, y0:y1, x0:x1] # 3xPxP
138
+ ymin_t = y_min_full[0, 0, y0:y1, x0:x1].float().unsqueeze(0) # 1xPxP
139
+ ymax_t = y_max_full[0, 0, y0:y1, x0:x1].float().unsqueeze(0) # 1xPxP
140
+ x_f = torch.cat([rgb_t, ymin_t, ymax_t, cond_patch], dim=0).unsqueeze(0)
141
+ xs_list.append(x_f)
142
+
143
+ x_f_batch = torch.cat(xs_list, dim=0) # Bx6xPxP
144
+ with autocast(device_type=device.type, dtype=amp_dtype, enabled=amp_flag):
145
+ logits_f = model.forward_fine(x_f_batch)
146
+ prob_f = torch.softmax(logits_f, dim=1)[:, 1:2]
147
+ prob_f_up = F.interpolate(
148
+ prob_f, size=(P, P), mode="bilinear", align_corners=False
149
+ )[:, 0, :, :] # BxPxP
150
+
151
+ for bi, (y0, x0) in enumerate(batch_coords):
152
+ y1, x1 = y0 + P, x0 + P
153
+ prob_sum_t[y0:y1, x0:x1] += prob_f_up[bi]
154
+ weight_t[y0:y1, x0:x1] += 1.0
155
+
156
+ prob_full = (prob_sum_t / weight_t).detach().cpu().numpy()
157
+ return prob_full
158
+
159
+
160
+ def _build_model_from_cfg(cfg: dict, device: torch.device) -> WireSegHR:
161
+ pretrained_flag = bool(cfg.get("pretrained", False))
162
+ model = WireSegHR(
163
+ backbone=cfg["backbone"], in_channels=6, pretrained=pretrained_flag
164
+ )
165
+ model = model.to(device)
166
+ return model
167
+
168
+
169
+ @torch.no_grad()
170
+ def infer_image(
171
+ model: WireSegHR,
172
+ img_path: str,
173
+ cfg: dict,
174
+ device: torch.device,
175
+ amp_flag: bool,
176
+ amp_dtype,
177
+ out_dir: Optional[str] = None,
178
+ save_prob: bool = False,
179
+ prob_thresh: Optional[float] = None,
180
+ ) -> Tuple[np.ndarray, np.ndarray]:
181
+ assert os.path.isfile(img_path), f"Image not found: {img_path}"
182
+ bgr = cv2.imread(img_path, cv2.IMREAD_COLOR)
183
+ assert bgr is not None, f"Failed to read {img_path}"
184
+ rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
185
+
186
+ coarse_size = int(cfg["coarse"]["test_size"])
187
+ patch_size = int(cfg["inference"]["fine_patch_size"]) # 1024 for inference
188
+ overlap = int(cfg["fine"]["overlap"])
189
+ minmax_enable = bool(cfg["minmax"]["enable"])
190
+ minmax_kernel = int(cfg["minmax"]["kernel"])
191
+ if prob_thresh is None:
192
+ prob_thresh = float(cfg["inference"]["prob_threshold"])
193
+
194
+ prob_c, cond_map, t_img, y_min_full, y_max_full = _coarse_forward(
195
+ model,
196
+ rgb,
197
+ coarse_size,
198
+ minmax_enable,
199
+ minmax_kernel,
200
+ device,
201
+ amp_flag,
202
+ amp_dtype,
203
+ )
204
+
205
+ prob_f = _tiled_fine_forward(
206
+ model,
207
+ t_img,
208
+ cond_map,
209
+ y_min_full,
210
+ y_max_full,
211
+ patch_size,
212
+ overlap,
213
+ int(cfg.get("eval", {}).get("fine_batch", 16)),
214
+ device,
215
+ amp_flag,
216
+ amp_dtype,
217
+ )
218
+
219
+ pred = (prob_f > prob_thresh).astype(np.uint8) * 255
220
+
221
+ if out_dir is not None:
222
+ os.makedirs(out_dir, exist_ok=True)
223
+ stem = os.path.splitext(os.path.basename(img_path))[0]
224
+ out_mask = os.path.join(out_dir, f"{stem}_pred.png")
225
+ cv2.imwrite(out_mask, pred)
226
+ if save_prob:
227
+ out_prob = os.path.join(out_dir, f"{stem}_prob.npy")
228
+ np.save(out_prob, prob_f.astype(np.float32))
229
+
230
+ return pred, prob_f
231
+
232
 
233
  def main():
234
+ parser = argparse.ArgumentParser(description="WireSegHR inference")
235
  parser.add_argument(
236
  "--config", type=str, default="configs/default.yaml", help="Path to YAML config"
237
  )
238
  parser.add_argument("--image", type=str, required=False, help="Path to input image")
239
+ parser.add_argument(
240
+ "--images_dir",
241
+ type=str,
242
+ required=False,
243
+ help="Directory with .jpg/.jpeg images",
244
+ )
245
+ parser.add_argument(
246
+ "--out", type=str, default="outputs/infer", help="Directory to save predictions"
247
+ )
248
+ parser.add_argument(
249
+ "--ckpt",
250
+ type=str,
251
+ default="",
252
+ help="Optional checkpoint (.pt) with model state",
253
+ )
254
+ parser.add_argument(
255
+ "--save_prob", action="store_true", help="Also save probability .npy"
256
+ )
257
+
258
  args = parser.parse_args()
259
 
260
  cfg_path = args.config
 
266
 
267
  print("[WireSegHR][infer] Loaded config from:", cfg_path)
268
  pprint.pprint(cfg)
269
+
270
+ assert (args.image is not None) ^ (args.images_dir is not None), (
271
+ "Provide exactly one of --image or --images_dir"
272
+ )
273
+
274
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
275
+ precision = str(cfg["optim"].get("precision", "fp32")).lower()
276
+ assert precision in ("fp32", "fp16", "bf16")
277
+ amp_enabled = (device.type == "cuda") and (precision in ("fp16", "bf16"))
278
+ amp_dtype = (
279
+ torch.float16
280
+ if precision == "fp16"
281
+ else (torch.bfloat16 if precision == "bf16" else None)
282
+ )
283
+
284
+ model = _build_model_from_cfg(cfg, device)
285
+
286
+ ckpt_path = args.ckpt if args.ckpt else cfg.get("resume", "")
287
+ if ckpt_path:
288
+ assert os.path.isfile(ckpt_path), f"Checkpoint not found: {ckpt_path}"
289
+ print(f"[WireSegHR][infer] Loading checkpoint: {ckpt_path}")
290
+ state = torch.load(ckpt_path, map_location=device)
291
+ model.load_state_dict(state["model"])
292
+ model.eval()
293
+
294
+ if args.image is not None:
295
+ infer_image(
296
+ model,
297
+ args.image,
298
+ cfg,
299
+ device,
300
+ amp_enabled,
301
+ amp_dtype,
302
+ out_dir=args.out,
303
+ save_prob=args.save_prob,
304
+ )
305
+ print("[WireSegHR][infer] Done.")
306
+ return
307
+
308
+ # Directory mode
309
+ img_dir = args.images_dir
310
+ assert os.path.isdir(img_dir), f"Not a directory: {img_dir}"
311
+ img_files = sorted(
312
+ [p for p in os.listdir(img_dir) if p.lower().endswith((".jpg", ".jpeg"))]
313
  )
314
+ assert len(img_files) > 0, f"No .jpg/.jpeg in {img_dir}"
315
+ os.makedirs(args.out, exist_ok=True)
316
+ for name in img_files:
317
+ path = os.path.join(img_dir, name)
318
+ infer_image(
319
+ model,
320
+ path,
321
+ cfg,
322
+ device,
323
+ amp_enabled,
324
+ amp_dtype,
325
+ out_dir=args.out,
326
+ save_prob=args.save_prob,
327
+ )
328
+ print("[WireSegHR][infer] Done.")
329
 
330
 
331
  if __name__ == "__main__":
scripts/pull_and_preprocess_wireseghr_dataset.py CHANGED
@@ -80,7 +80,11 @@ def download_folder(folder_id, dest, service_account_json, workers: int):
80
  for meta in files_with_paths:
81
  out_path = os.path.join(dest, meta["rel_path"])
82
  os.makedirs(os.path.dirname(out_path), exist_ok=True)
83
- if meta["size"] > 0 and os.path.exists(out_path) and os.path.getsize(out_path) == meta["size"]:
 
 
 
 
84
  skipped += 1
85
  continue
86
  tasks.append((meta["id"], out_path))
@@ -98,7 +102,9 @@ def download_folder(folder_id, dest, service_account_json, workers: int):
98
 
99
  with ThreadPoolExecutor(max_workers=workers) as ex:
100
  futures = [ex.submit(_download_one, fid, path) for fid, path in tasks]
101
- for _ in tqdm(as_completed(futures), total=len(futures), desc="Downloading", unit="file"):
 
 
102
  pass
103
 
104
 
@@ -137,7 +143,9 @@ def pull(args=None):
137
 
138
 
139
  def _index_numeric_pairs(images_dir: Path, masks_dir: Path):
140
- assert images_dir.exists() and images_dir.is_dir(), f"Missing images_dir: {images_dir}"
 
 
141
  assert masks_dir.exists() and masks_dir.is_dir(), f"Missing masks_dir: {masks_dir}"
142
  img_files = sorted([p for p in images_dir.glob("*.jpg") if p.is_file()])
143
  img_files += sorted([p for p in images_dir.glob("*.jpeg") if p.is_file()])
@@ -247,7 +255,9 @@ if __name__ == "__main__":
247
  subs = top.add_subparsers(dest="cmd", required=True)
248
 
249
  sp_pull = subs.add_parser("pull", help="Download dataset from Google Drive")
250
- sp_pull.add_argument("--folder-id", dest="folder_id", default="1fgy3wn_yuHEeMNbfiHNVl1-jEdYOfu6p")
 
 
251
  sp_pull.add_argument("--output-dir", dest="output_dir", default="dataset/")
252
  sp_pull.add_argument("--service-account", default="secrets/drive-json.json")
253
  sp_pull.add_argument("--workers", type=int, default=8)
@@ -265,26 +275,30 @@ if __name__ == "__main__":
265
 
266
  ns = top.parse_args()
267
  if ns.cmd == "pull":
268
- pull([
269
- "--folder-id",
270
- ns.folder_id,
271
- "--output-dir",
272
- ns.output_dir,
273
- "--service-account",
274
- ns.service_account,
275
- "--workers",
276
- str(ns.workers),
277
- ])
 
 
278
  elif ns.cmd == "split_test_train_val":
279
- split_test_train_val([
280
- "--images-dir",
281
- ns.images_dir,
282
- "--masks-dir",
283
- ns.masks_dir,
284
- "--out-dir",
285
- ns.out_dir,
286
- "--seed",
287
- str(ns.seed),
288
- "--link-method",
289
- ns.link_method,
290
- ])
 
 
 
80
  for meta in files_with_paths:
81
  out_path = os.path.join(dest, meta["rel_path"])
82
  os.makedirs(os.path.dirname(out_path), exist_ok=True)
83
+ if (
84
+ meta["size"] > 0
85
+ and os.path.exists(out_path)
86
+ and os.path.getsize(out_path) == meta["size"]
87
+ ):
88
  skipped += 1
89
  continue
90
  tasks.append((meta["id"], out_path))
 
102
 
103
  with ThreadPoolExecutor(max_workers=workers) as ex:
104
  futures = [ex.submit(_download_one, fid, path) for fid, path in tasks]
105
+ for _ in tqdm(
106
+ as_completed(futures), total=len(futures), desc="Downloading", unit="file"
107
+ ):
108
  pass
109
 
110
 
 
143
 
144
 
145
  def _index_numeric_pairs(images_dir: Path, masks_dir: Path):
146
+ assert images_dir.exists() and images_dir.is_dir(), (
147
+ f"Missing images_dir: {images_dir}"
148
+ )
149
  assert masks_dir.exists() and masks_dir.is_dir(), f"Missing masks_dir: {masks_dir}"
150
  img_files = sorted([p for p in images_dir.glob("*.jpg") if p.is_file()])
151
  img_files += sorted([p for p in images_dir.glob("*.jpeg") if p.is_file()])
 
255
  subs = top.add_subparsers(dest="cmd", required=True)
256
 
257
  sp_pull = subs.add_parser("pull", help="Download dataset from Google Drive")
258
+ sp_pull.add_argument(
259
+ "--folder-id", dest="folder_id", default="1fgy3wn_yuHEeMNbfiHNVl1-jEdYOfu6p"
260
+ )
261
  sp_pull.add_argument("--output-dir", dest="output_dir", default="dataset/")
262
  sp_pull.add_argument("--service-account", default="secrets/drive-json.json")
263
  sp_pull.add_argument("--workers", type=int, default=8)
 
275
 
276
  ns = top.parse_args()
277
  if ns.cmd == "pull":
278
+ pull(
279
+ [
280
+ "--folder-id",
281
+ ns.folder_id,
282
+ "--output-dir",
283
+ ns.output_dir,
284
+ "--service-account",
285
+ ns.service_account,
286
+ "--workers",
287
+ str(ns.workers),
288
+ ]
289
+ )
290
  elif ns.cmd == "split_test_train_val":
291
+ split_test_train_val(
292
+ [
293
+ "--images-dir",
294
+ ns.images_dir,
295
+ "--masks-dir",
296
+ ns.masks_dir,
297
+ "--out-dir",
298
+ ns.out_dir,
299
+ "--seed",
300
+ str(ns.seed),
301
+ "--link-method",
302
+ ns.link_method,
303
+ ]
304
+ )
scripts/setup_script.sh CHANGED
@@ -5,9 +5,9 @@ set -euo pipefail
5
 
6
  # 0) Setup env (includes gdown used by scripts/pull_ttpla.sh)
7
  pip install uv
8
- uv venv || true
9
- source .venv/bin/activate
10
- pip install uv
11
  uv pip install -r requirements.txt
12
  uv pip install gdown
13
 
 
5
 
6
  # 0) Setup env (includes gdown used by scripts/pull_ttpla.sh)
7
  pip install uv
8
+ # uv venv || true # note: don't create new venv since one exists in vast.ai pytorch image.
9
+ # source .venv/bin/activate
10
+ # pip install uv
11
  uv pip install -r requirements.txt
12
  uv pip install gdown
13
 
src/wireseghr/data/ttpla_to_masks.py CHANGED
@@ -9,7 +9,9 @@ from PIL import Image, ImageDraw
9
  import numpy as np
10
 
11
 
12
- def _rasterize_cable_mask(shapes: List[dict], height: int, width: int, label: str) -> np.ndarray:
 
 
13
  """Rasterize polygons with given label into a binary mask of shape (H, W), values {0,255}.
14
 
15
  Expects LabelMe-style annotations with shape entries containing keys:
@@ -33,7 +35,7 @@ def _rasterize_cable_mask(shapes: List[dict], height: int, width: int, label: st
33
  pts[:, 0] = np.clip(pts[:, 0], 0, width - 1)
34
  pts[:, 1] = np.clip(pts[:, 1], 0, height - 1)
35
  # PIL expects list of (x, y) tuples
36
- pts_list = [ (int(p[0]), int(p[1])) for p in pts ]
37
  draw.polygon(pts_list, outline=255, fill=255)
38
 
39
  mask = np.asarray(mask_img, dtype=np.uint8)
@@ -46,12 +48,14 @@ def _convert_one(json_path: Path, out_dir: Path, label: str) -> Path | None:
46
 
47
  shapes = data["shapes"]
48
  H = int(data["imageHeight"]) # required by given JSON
49
- W = int(data["imageWidth"]) # required by given JSON
50
  image_path = Path(data["imagePath"]) # e.g. "1_00186.jpg"
51
  # WireSegDataset expects numeric filename stems. Derive a numeric-only stem.
52
  stem_raw = image_path.stem
53
  out_stem = "".join([c for c in stem_raw if c.isdigit()])
54
- assert out_stem.isdigit() and len(out_stem) > 0, f"Non-numeric stem derived from {stem_raw}"
 
 
55
 
56
  mask = _rasterize_cable_mask(shapes, H, W, label)
57
 
@@ -62,7 +66,12 @@ def _convert_one(json_path: Path, out_dir: Path, label: str) -> Path | None:
62
  return out_path
63
 
64
 
65
- def convert_ttpla_jsons_to_masks(input_path: str | Path, output_dir: str | Path, label: str = "cable", recursive: bool = True) -> List[Path]:
 
 
 
 
 
66
  """Convert TTPLA LabelMe JSON annotations into binary masks matching WireSegHR conventions.
67
 
68
  - input_path: directory containing JSONs (or a single .json file)
@@ -76,11 +85,15 @@ def convert_ttpla_jsons_to_masks(input_path: str | Path, output_dir: str | Path,
76
  output_p = Path(output_dir)
77
 
78
  if input_p.is_file():
79
- assert input_p.suffix.lower() == ".json", f"Expected a .json file, got: {input_p}"
 
 
80
  out = _convert_one(input_p, output_p, label)
81
  return [out] if out else []
82
 
83
- assert input_p.is_dir(), f"Input path must be a directory or a .json file: {input_p}"
 
 
84
 
85
  json_iter: Iterable[Path]
86
  if recursive:
@@ -97,11 +110,23 @@ def convert_ttpla_jsons_to_masks(input_path: str | Path, output_dir: str | Path,
97
 
98
 
99
  def main(argv: List[str] | None = None) -> None:
100
- parser = argparse.ArgumentParser(description="Convert TTPLA LabelMe JSONs to WireSegHR-style binary masks")
101
- parser.add_argument("--input", required=True, help="Path to a directory of JSONs or a single JSON file")
102
- parser.add_argument("--output", required=True, help="Output directory for PNG masks")
103
- parser.add_argument("--label", default="cable", help="Label to rasterize (default: cable)")
104
- parser.add_argument("--no-recursive", action="store_true", help="Do not search subdirectories")
 
 
 
 
 
 
 
 
 
 
 
 
105
  args = parser.parse_args(argv)
106
 
107
  convert_ttpla_jsons_to_masks(
 
9
  import numpy as np
10
 
11
 
12
+ def _rasterize_cable_mask(
13
+ shapes: List[dict], height: int, width: int, label: str
14
+ ) -> np.ndarray:
15
  """Rasterize polygons with given label into a binary mask of shape (H, W), values {0,255}.
16
 
17
  Expects LabelMe-style annotations with shape entries containing keys:
 
35
  pts[:, 0] = np.clip(pts[:, 0], 0, width - 1)
36
  pts[:, 1] = np.clip(pts[:, 1], 0, height - 1)
37
  # PIL expects list of (x, y) tuples
38
+ pts_list = [(int(p[0]), int(p[1])) for p in pts]
39
  draw.polygon(pts_list, outline=255, fill=255)
40
 
41
  mask = np.asarray(mask_img, dtype=np.uint8)
 
48
 
49
  shapes = data["shapes"]
50
  H = int(data["imageHeight"]) # required by given JSON
51
+ W = int(data["imageWidth"]) # required by given JSON
52
  image_path = Path(data["imagePath"]) # e.g. "1_00186.jpg"
53
  # WireSegDataset expects numeric filename stems. Derive a numeric-only stem.
54
  stem_raw = image_path.stem
55
  out_stem = "".join([c for c in stem_raw if c.isdigit()])
56
+ assert out_stem.isdigit() and len(out_stem) > 0, (
57
+ f"Non-numeric stem derived from {stem_raw}"
58
+ )
59
 
60
  mask = _rasterize_cable_mask(shapes, H, W, label)
61
 
 
66
  return out_path
67
 
68
 
69
+ def convert_ttpla_jsons_to_masks(
70
+ input_path: str | Path,
71
+ output_dir: str | Path,
72
+ label: str = "cable",
73
+ recursive: bool = True,
74
+ ) -> List[Path]:
75
  """Convert TTPLA LabelMe JSON annotations into binary masks matching WireSegHR conventions.
76
 
77
  - input_path: directory containing JSONs (or a single .json file)
 
85
  output_p = Path(output_dir)
86
 
87
  if input_p.is_file():
88
+ assert input_p.suffix.lower() == ".json", (
89
+ f"Expected a .json file, got: {input_p}"
90
+ )
91
  out = _convert_one(input_p, output_p, label)
92
  return [out] if out else []
93
 
94
+ assert input_p.is_dir(), (
95
+ f"Input path must be a directory or a .json file: {input_p}"
96
+ )
97
 
98
  json_iter: Iterable[Path]
99
  if recursive:
 
110
 
111
 
112
  def main(argv: List[str] | None = None) -> None:
113
+ parser = argparse.ArgumentParser(
114
+ description="Convert TTPLA LabelMe JSONs to WireSegHR-style binary masks"
115
+ )
116
+ parser.add_argument(
117
+ "--input",
118
+ required=True,
119
+ help="Path to a directory of JSONs or a single JSON file",
120
+ )
121
+ parser.add_argument(
122
+ "--output", required=True, help="Output directory for PNG masks"
123
+ )
124
+ parser.add_argument(
125
+ "--label", default="cable", help="Label to rasterize (default: cable)"
126
+ )
127
+ parser.add_argument(
128
+ "--no-recursive", action="store_true", help="Do not search subdirectories"
129
+ )
130
  args = parser.parse_args(argv)
131
 
132
  convert_ttpla_jsons_to_masks(
train.py CHANGED
@@ -23,6 +23,7 @@ from src.wireseghr.data.dataset import WireSegDataset
23
  from src.wireseghr.model.label_downsample import downsample_label_maxpool
24
  from src.wireseghr.data.sampler import BalancedPatchSampler
25
  from src.wireseghr.metrics import compute_metrics
 
26
 
27
 
28
  class SizeBatchSampler:
@@ -40,7 +41,7 @@ class SizeBatchSampler:
40
  self._len = 0
41
  for hw, idxs in bins.items():
42
  _ = hw # unused, clarity
43
- self._len += (len(idxs) // self.batch_size)
44
 
45
  def __len__(self) -> int:
46
  return self._len
@@ -54,7 +55,9 @@ class SizeBatchSampler:
54
  pool = list(bins[hw])
55
  random.shuffle(pool)
56
  # Yield only full batches to keep fixed batch size and same-size assumption
57
- for i in range(0, len(pool) - (len(pool) % self.batch_size), self.batch_size):
 
 
58
  yield pool[i : i + self.batch_size]
59
 
60
 
@@ -87,8 +90,10 @@ def main():
87
 
88
  # Config
89
  coarse_train = int(cfg["coarse"]["train_size"]) # 512
90
- patch_size = int(cfg["fine"]["patch_size"]) # 768
 
91
  overlap = int(cfg["fine"]["overlap"]) # e.g., 128
 
92
  eval_cfg = cfg.get("eval", {})
93
  eval_fine_batch = int(eval_cfg.get("fine_batch", 16))
94
  assert eval_fine_batch >= 1
@@ -107,15 +112,17 @@ def main():
107
  if amp_enabled:
108
  cc_major, cc_minor = torch.cuda.get_device_capability()
109
  if precision == "fp16":
110
- assert (
111
- cc_major >= 7
112
- ), f"fp16 requires Volta (SM 7.0)+; current SM {cc_major}.{cc_minor}"
113
  elif precision == "bf16":
114
- assert (
115
- cc_major >= 8
116
- ), f"bf16 requires Ampere (SM 8.0)+; current SM {cc_major}.{cc_minor}"
117
  amp_dtype = (
118
- torch.float16 if precision == "fp16" else (torch.bfloat16 if precision == "bf16" else None)
 
 
119
  )
120
 
121
  # Housekeeping
@@ -135,7 +142,9 @@ def main():
135
  num_workers = int(loader_cfg.get("num_workers", 4))
136
  prefetch_factor = int(loader_cfg.get("prefetch_factor", 2))
137
  pin_memory = bool(loader_cfg.get("pin_memory", True))
138
- persistent_workers = bool(loader_cfg.get("persistent_workers", True)) if num_workers > 0 else False
 
 
139
  batch_sampler = SizeBatchSampler(dset, batch_size)
140
  loader_kwargs = dict(
141
  batch_sampler=batch_sampler,
@@ -252,24 +261,34 @@ def main():
252
  # Eval & Checkpoint
253
  if (step % eval_interval == 0) and (dset_val is not None):
254
  # Free training-step tensors before eval to lower peak memory
255
- del x_fine, logits_coarse, cond_map, logits_fine, y_coarse, y_fine, loss_coarse, loss_fine, loss
 
 
 
 
 
 
 
 
 
 
256
  torch.cuda.empty_cache()
257
  model.eval()
258
  print(
259
- f"[WireSegHR][train] Eval starting... val_size={len(dset_val)} max={eval_max_samples} patch={patch_size} overlap={overlap} stride={patch_size - overlap} fine_batch={eval_fine_batch}",
260
  flush=True,
261
  )
262
  val_stats = validate(
263
  model,
264
  dset_val,
265
- coarse_train,
266
  device,
267
  amp_enabled,
268
  amp_dtype,
269
  prob_thresh,
270
  mm_enable,
271
  mm_kernel,
272
- patch_size,
273
  overlap,
274
  eval_fine_batch,
275
  eval_max_samples,
@@ -306,7 +325,7 @@ def main():
306
  save_test_visuals(
307
  model,
308
  dset_test,
309
- coarse_train,
310
  device,
311
  os.path.join(out_dir, f"test_vis_{step}"),
312
  amp_enabled,
@@ -604,52 +623,16 @@ def validate(
604
  img = item["image"].astype(np.float32) / 255.0 # HxWx3
605
  mask = item["mask"].astype(np.uint8)
606
  H, W = mask.shape
607
- # Build coarse input (zeros for cond+loc) on GPU
608
- t_img = (
609
- torch.from_numpy(np.transpose(img, (2, 0, 1)))
610
- .unsqueeze(0)
611
- .to(device)
612
- .float()
613
- )
614
- rgb_c = F.interpolate(
615
- t_img, size=(coarse_size, coarse_size), mode="bilinear", align_corners=False
616
- )[0]
617
- y_t = 0.299 * t_img[:, 0:1] + 0.587 * t_img[:, 1:2] + 0.114 * t_img[:, 2:3]
618
- if minmax_enable:
619
- # Asymmetric padding for even kernel to keep same HxW
620
- k = int(minmax_kernel)
621
- if (k % 2) == 0:
622
- pad = (k // 2 - 1, k // 2, k // 2 - 1, k // 2)
623
- else:
624
- pad = (k // 2, k // 2, k // 2, k // 2)
625
- y_p = F.pad(y_t, pad, mode="replicate")
626
- y_max_full = F.max_pool2d(y_p, kernel_size=k, stride=1)
627
- y_min_full = -F.max_pool2d(-y_p, kernel_size=k, stride=1)
628
- else:
629
- y_min_full = y_t
630
- y_max_full = y_t
631
- y_min_c = F.interpolate(
632
- y_min_full,
633
- size=(coarse_size, coarse_size),
634
- mode="bilinear",
635
- align_corners=False,
636
- )[0]
637
- y_max_c = F.interpolate(
638
- y_max_full,
639
- size=(coarse_size, coarse_size),
640
- mode="bilinear",
641
- align_corners=False,
642
- )[0]
643
- zeros_c = torch.zeros(1, coarse_size, coarse_size, device=device)
644
- x_t = torch.cat([rgb_c, y_min_c, y_max_c, zeros_c], dim=0).unsqueeze(0)
645
- with autocast(device_type=device.type, dtype=amp_dtype, enabled=amp_flag):
646
- logits_c, cond_map = model.forward_coarse(x_t)
647
- prob = torch.softmax(logits_c, dim=1)[:, 1:2]
648
- prob_up = (
649
- F.interpolate(prob, size=(H, W), mode="bilinear", align_corners=False)[0, 0]
650
- .detach()
651
- .cpu()
652
- .numpy()
653
  )
654
  # Coarse metrics
655
  pred_coarse = (prob_up > prob_thresh).astype(np.uint8)
@@ -657,75 +640,30 @@ def validate(
657
  for k in coarse_sum:
658
  coarse_sum[k] += m_c[k]
659
 
660
- # Fine-stage tiled inference and stitching (BATCHED)
661
- P = fine_patch_size
662
- stride = P - fine_overlap
663
- assert stride > 0
664
- assert H >= P and W >= P
665
- # Accumulate on device to avoid CPU<->GPU thrash
666
- prob_sum_t = torch.zeros((H, W), device=device, dtype=torch.float32)
667
- weight_t = torch.zeros((H, W), device=device, dtype=torch.float32)
668
-
669
- # Prepare min/max on full-res (already computed above as y_min_full/y_max_full)
670
- hc4, wc4 = cond_map.shape[2], cond_map.shape[3]
671
-
672
- ys = list(range(0, max(H - P, 0) + 1, stride))
 
 
 
 
 
673
  if ys[-1] != (H - P):
674
  ys.append(H - P)
675
- xs = list(range(0, max(W - P, 0) + 1, stride))
676
  if xs[-1] != (W - P):
677
  xs.append(W - P)
678
-
679
- coords: List[Tuple[int, int]] = []
680
- for y0 in ys:
681
- for x0 in xs:
682
- coords.append((y0, x0))
683
- total_tiles += len(coords)
684
-
685
- total_batches = (len(coords) + fine_batch - 1) // fine_batch
686
- for i0 in range(0, len(coords), fine_batch):
687
- batch_coords = coords[i0 : i0 + fine_batch]
688
- xs_list: List[torch.Tensor] = []
689
- batch_idx = i0 // fine_batch
690
- if total_batches > 0 and (batch_idx % max(1, total_batches // 10) == 0):
691
- print(
692
- f"[Eval] Img {i+1}/{target_n} | Tile batch {batch_idx+1}/{total_batches}",
693
- flush=True,
694
- )
695
- for (y0, x0) in batch_coords:
696
- y1, x1 = y0 + P, x0 + P
697
- # Cond crop mapping (same as training _build_fine_inputs)
698
- y0c = (y0 * hc4) // H
699
- y1c = ((y1 * hc4) + H - 1) // H
700
- x0c = (x0 * wc4) // W
701
- x1c = ((x1 * wc4) + W - 1) // W
702
- cond_sub = cond_map[:, :, y0c:y1c, x0c:x1c].float()
703
- cond_patch = F.interpolate(
704
- cond_sub, size=(P, P), mode="bilinear", align_corners=False
705
- ).squeeze(1) # 1xPxP
706
-
707
- # Build fine input channels directly from on-device tensors
708
- rgb_t = t_img[0, :, y0:y1, x0:x1] # 3xPxP
709
- ymin_t = y_min_full[0, 0, y0:y1, x0:x1].float().unsqueeze(0) # 1xPxP
710
- ymax_t = y_max_full[0, 0, y0:y1, x0:x1].float().unsqueeze(0) # 1xPxP
711
- x_f = torch.cat([rgb_t, ymin_t, ymax_t, cond_patch], dim=0).unsqueeze(0)
712
- xs_list.append(x_f)
713
-
714
- x_f_batch = torch.cat(xs_list, dim=0) # Bx6xPxP
715
-
716
- with autocast(device_type=device.type, dtype=amp_dtype, enabled=amp_flag):
717
- logits_f = model.forward_fine(x_f_batch)
718
- prob_f = torch.softmax(logits_f, dim=1)[:, 1:2]
719
- prob_f_up = F.interpolate(
720
- prob_f, size=(P, P), mode="bilinear", align_corners=False
721
- )[:, 0, :, :] # BxPxP
722
-
723
- for bi, (y0, x0) in enumerate(batch_coords):
724
- y1, x1 = y0 + P, x0 + P
725
- prob_sum_t[y0:y1, x0:x1] += prob_f_up[bi]
726
- weight_t[y0:y1, x0:x1] += 1.0
727
-
728
- prob_full = (prob_sum_t / weight_t).detach().cpu().numpy()
729
  pred_fine = (prob_full > prob_thresh).astype(np.uint8)
730
  m_f = compute_metrics(pred_fine, mask)
731
  for k in metrics_sum:
@@ -773,50 +711,15 @@ def save_test_visuals(
773
  item = dset_test[i]
774
  img = item["image"].astype(np.float32) / 255.0
775
  H, W = img.shape[:2]
776
- t_img = (
777
- torch.from_numpy(np.transpose(img, (2, 0, 1)))
778
- .unsqueeze(0)
779
- .to(device)
780
- .float()
781
- )
782
- rgb_c = F.interpolate(
783
- t_img, size=(coarse_size, coarse_size), mode="bilinear", align_corners=False
784
- )[0]
785
- y_t = 0.299 * t_img[:, 0:1] + 0.587 * t_img[:, 1:2] + 0.114 * t_img[:, 2:3]
786
- if minmax_enable:
787
- k = int(minmax_kernel)
788
- if (k % 2) == 0:
789
- pad = (k // 2 - 1, k // 2, k // 2 - 1, k // 2)
790
- else:
791
- pad = (k // 2, k // 2, k // 2, k // 2)
792
- y_p = F.pad(y_t, pad, mode="replicate")
793
- y_max_full = F.max_pool2d(y_p, kernel_size=k, stride=1)
794
- y_min_full = -F.max_pool2d(-y_p, kernel_size=k, stride=1)
795
- else:
796
- y_min_full = y_t
797
- y_max_full = y_t
798
- y_min_c = F.interpolate(
799
- y_min_full,
800
- size=(coarse_size, coarse_size),
801
- mode="bilinear",
802
- align_corners=False,
803
- )[0]
804
- y_max_c = F.interpolate(
805
- y_max_full,
806
- size=(coarse_size, coarse_size),
807
- mode="bilinear",
808
- align_corners=False,
809
- )[0]
810
- zeros_c = torch.zeros(1, coarse_size, coarse_size, device=device)
811
- x_t = torch.cat([rgb_c, y_min_c, y_max_c, zeros_c], dim=0).unsqueeze(0)
812
- with autocast(device_type=device.type, dtype=None, enabled=amp_flag):
813
- logits_c, _ = model.forward_coarse(x_t)
814
- prob = torch.softmax(logits_c, dim=1)[:, 1:2]
815
- prob_up = (
816
- F.interpolate(prob, size=(H, W), mode="bilinear", align_corners=False)[0, 0]
817
- .detach()
818
- .cpu()
819
- .numpy()
820
  )
821
  pred = (prob_up > prob_thresh).astype(np.uint8) * 255
822
  # Save input and prediction
 
23
  from src.wireseghr.model.label_downsample import downsample_label_maxpool
24
  from src.wireseghr.data.sampler import BalancedPatchSampler
25
  from src.wireseghr.metrics import compute_metrics
26
+ from infer import _coarse_forward, _tiled_fine_forward
27
 
28
 
29
  class SizeBatchSampler:
 
41
  self._len = 0
42
  for hw, idxs in bins.items():
43
  _ = hw # unused, clarity
44
+ self._len += len(idxs) // self.batch_size
45
 
46
  def __len__(self) -> int:
47
  return self._len
 
55
  pool = list(bins[hw])
56
  random.shuffle(pool)
57
  # Yield only full batches to keep fixed batch size and same-size assumption
58
+ for i in range(
59
+ 0, len(pool) - (len(pool) % self.batch_size), self.batch_size
60
+ ):
61
  yield pool[i : i + self.batch_size]
62
 
63
 
 
90
 
91
  # Config
92
  coarse_train = int(cfg["coarse"]["train_size"]) # 512
93
+ coarse_test = int(cfg["coarse"]["test_size"]) # use higher res for eval/infer
94
+ patch_size = int(cfg["fine"]["patch_size"]) # training fine patch size
95
  overlap = int(cfg["fine"]["overlap"]) # e.g., 128
96
+ eval_patch_size = int(cfg["inference"]["fine_patch_size"]) # 1024 for eval/infer
97
  eval_cfg = cfg.get("eval", {})
98
  eval_fine_batch = int(eval_cfg.get("fine_batch", 16))
99
  assert eval_fine_batch >= 1
 
112
  if amp_enabled:
113
  cc_major, cc_minor = torch.cuda.get_device_capability()
114
  if precision == "fp16":
115
+ assert cc_major >= 7, (
116
+ f"fp16 requires Volta (SM 7.0)+; current SM {cc_major}.{cc_minor}"
117
+ )
118
  elif precision == "bf16":
119
+ assert cc_major >= 8, (
120
+ f"bf16 requires Ampere (SM 8.0)+; current SM {cc_major}.{cc_minor}"
121
+ )
122
  amp_dtype = (
123
+ torch.float16
124
+ if precision == "fp16"
125
+ else (torch.bfloat16 if precision == "bf16" else None)
126
  )
127
 
128
  # Housekeeping
 
142
  num_workers = int(loader_cfg.get("num_workers", 4))
143
  prefetch_factor = int(loader_cfg.get("prefetch_factor", 2))
144
  pin_memory = bool(loader_cfg.get("pin_memory", True))
145
+ persistent_workers = (
146
+ bool(loader_cfg.get("persistent_workers", True)) if num_workers > 0 else False
147
+ )
148
  batch_sampler = SizeBatchSampler(dset, batch_size)
149
  loader_kwargs = dict(
150
  batch_sampler=batch_sampler,
 
261
  # Eval & Checkpoint
262
  if (step % eval_interval == 0) and (dset_val is not None):
263
  # Free training-step tensors before eval to lower peak memory
264
+ del (
265
+ x_fine,
266
+ logits_coarse,
267
+ cond_map,
268
+ logits_fine,
269
+ y_coarse,
270
+ y_fine,
271
+ loss_coarse,
272
+ loss_fine,
273
+ loss,
274
+ )
275
  torch.cuda.empty_cache()
276
  model.eval()
277
  print(
278
+ f"[WireSegHR][train] Eval starting... val_size={len(dset_val)} max={eval_max_samples} patch={eval_patch_size} overlap={overlap} stride={eval_patch_size - overlap} fine_batch={eval_fine_batch}",
279
  flush=True,
280
  )
281
  val_stats = validate(
282
  model,
283
  dset_val,
284
+ coarse_test,
285
  device,
286
  amp_enabled,
287
  amp_dtype,
288
  prob_thresh,
289
  mm_enable,
290
  mm_kernel,
291
+ eval_patch_size,
292
  overlap,
293
  eval_fine_batch,
294
  eval_max_samples,
 
325
  save_test_visuals(
326
  model,
327
  dset_test,
328
+ coarse_test,
329
  device,
330
  os.path.join(out_dir, f"test_vis_{step}"),
331
  amp_enabled,
 
623
  img = item["image"].astype(np.float32) / 255.0 # HxWx3
624
  mask = item["mask"].astype(np.uint8)
625
  H, W = mask.shape
626
+ # Reuse inference coarse pass
627
+ prob_up, cond_map, t_img, y_min_full, y_max_full = _coarse_forward(
628
+ model,
629
+ img,
630
+ coarse_size,
631
+ minmax_enable,
632
+ int(minmax_kernel),
633
+ device,
634
+ amp_flag,
635
+ amp_dtype,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
636
  )
637
  # Coarse metrics
638
  pred_coarse = (prob_up > prob_thresh).astype(np.uint8)
 
640
  for k in coarse_sum:
641
  coarse_sum[k] += m_c[k]
642
 
643
+ # Fine-stage via helper (batched and stitched)
644
+ prob_full = _tiled_fine_forward(
645
+ model,
646
+ t_img,
647
+ cond_map,
648
+ y_min_full,
649
+ y_max_full,
650
+ int(fine_patch_size),
651
+ int(fine_overlap),
652
+ int(fine_batch),
653
+ device,
654
+ amp_flag,
655
+ amp_dtype,
656
+ )
657
+ # Track tiles for throughput parity
658
+ P = int(fine_patch_size)
659
+ stride = P - int(fine_overlap)
660
+ ys = list(range(0, H - P + 1, stride))
661
  if ys[-1] != (H - P):
662
  ys.append(H - P)
663
+ xs = list(range(0, W - P + 1, stride))
664
  if xs[-1] != (W - P):
665
  xs.append(W - P)
666
+ total_tiles += len(ys) * len(xs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
667
  pred_fine = (prob_full > prob_thresh).astype(np.uint8)
668
  m_f = compute_metrics(pred_fine, mask)
669
  for k in metrics_sum:
 
711
  item = dset_test[i]
712
  img = item["image"].astype(np.float32) / 255.0
713
  H, W = img.shape[:2]
714
+ prob_up, _cond_map, _t_img, _ymin, _ymax = _coarse_forward(
715
+ model,
716
+ img,
717
+ int(coarse_size),
718
+ bool(minmax_enable),
719
+ int(minmax_kernel),
720
+ device,
721
+ bool(amp_flag),
722
+ None,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
723
  )
724
  pred = (prob_up > prob_thresh).astype(np.uint8) * 255
725
  # Save input and prediction