MRiabov commited on
Commit
854831b
·
1 Parent(s): ae754fb

(debug) to torch operations in bfloat conversions.

Browse files
Files changed (3) hide show
  1. configs/default.yaml +5 -5
  2. infer.py +12 -9
  3. train.py +3 -3
configs/default.yaml CHANGED
@@ -28,8 +28,8 @@ inference:
28
  stitch: avg_logits
29
 
30
  eval:
31
- max_samples: 16
32
- fine_batch: 32
33
 
34
  optim:
35
  iters: 2000
@@ -38,14 +38,14 @@ optim:
38
  weight_decay: 0.01
39
  schedule: poly
40
  power: 1.0
41
- precision: fp32 # one of: fp32, fp16, bf16
42
 
43
  # training housekeeping
44
  seed: 42
45
  out_dir: runs/wireseghr
46
- eval_interval: 100
47
  ckpt_interval: 300
48
- resume: runs/wireseghr/ckpt_1800.pt # optional
49
 
50
  # dataset paths (placeholders)
51
  data:
 
28
  stitch: avg_logits
29
 
30
  eval:
31
+ max_samples: 12
32
+ fine_batch: 16
33
 
34
  optim:
35
  iters: 2000
 
38
  weight_decay: 0.01
39
  schedule: poly
40
  power: 1.0
41
+ precision: bf16 # one of: fp32, fp16, bf16
42
 
43
  # training housekeeping
44
  seed: 42
45
  out_dir: runs/wireseghr
46
+ eval_interval: 150
47
  ckpt_interval: 300
48
+ # resume: runs/wireseghr/ckpt_1800.pt # optional
49
 
50
  # dataset paths (placeholders)
51
  data:
infer.py CHANGED
@@ -31,7 +31,7 @@ def _coarse_forward(
31
  device: torch.device,
32
  amp_flag: bool,
33
  amp_dtype,
34
- ) -> Tuple[np.ndarray, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
35
  # Convert to tensor on device
36
  t_img = (
37
  torch.from_numpy(np.transpose(img_rgb, (2, 0, 1)))
@@ -76,8 +76,8 @@ def _coarse_forward(
76
  F.interpolate(prob, size=(H, W), mode="bilinear", align_corners=False)[0, 0]
77
  .detach()
78
  .cpu()
79
- .numpy()
80
- )
81
  return prob_up, cond_map, t_img, y_min_full, y_max_full
82
 
83
 
@@ -94,7 +94,7 @@ def _tiled_fine_forward(
94
  device: torch.device,
95
  amp_flag: bool,
96
  amp_dtype,
97
- ) -> np.ndarray:
98
  H = int(t_img.shape[2])
99
  W = int(t_img.shape[3])
100
  P = patch_size
@@ -153,8 +153,8 @@ def _tiled_fine_forward(
153
  prob_sum_t[y0:y1, x0:x1] += prob_f_up[bi]
154
  weight_t[y0:y1, x0:x1] += 1.0
155
 
156
- prob_full = (prob_sum_t / weight_t).detach().cpu().numpy()
157
- return prob_full
158
 
159
 
160
  def _build_model_from_cfg(cfg: dict, device: torch.device) -> WireSegHR:
@@ -216,7 +216,9 @@ def infer_image(
216
  amp_dtype,
217
  )
218
 
219
- pred = (prob_f > prob_thresh).astype(np.uint8) * 255
 
 
220
 
221
  if out_dir is not None:
222
  os.makedirs(out_dir, exist_ok=True)
@@ -225,9 +227,10 @@ def infer_image(
225
  cv2.imwrite(out_mask, pred)
226
  if save_prob:
227
  out_prob = os.path.join(out_dir, f"{stem}_prob.npy")
228
- np.save(out_prob, prob_f.astype(np.float32))
229
 
230
- return pred, prob_f
 
231
 
232
 
233
  def main():
 
31
  device: torch.device,
32
  amp_flag: bool,
33
  amp_dtype,
34
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
35
  # Convert to tensor on device
36
  t_img = (
37
  torch.from_numpy(np.transpose(img_rgb, (2, 0, 1)))
 
76
  F.interpolate(prob, size=(H, W), mode="bilinear", align_corners=False)[0, 0]
77
  .detach()
78
  .cpu()
79
+ .float()
80
+ ) # HxW torch.Tensor on CPU
81
  return prob_up, cond_map, t_img, y_min_full, y_max_full
82
 
83
 
 
94
  device: torch.device,
95
  amp_flag: bool,
96
  amp_dtype,
97
+ ) -> torch.Tensor:
98
  H = int(t_img.shape[2])
99
  W = int(t_img.shape[3])
100
  P = patch_size
 
153
  prob_sum_t[y0:y1, x0:x1] += prob_f_up[bi]
154
  weight_t[y0:y1, x0:x1] += 1.0
155
 
156
+ prob_full = (prob_sum_t / weight_t).detach().cpu().float()
157
+ return prob_full # HxW torch.Tensor on CPU
158
 
159
 
160
  def _build_model_from_cfg(cfg: dict, device: torch.device) -> WireSegHR:
 
216
  amp_dtype,
217
  )
218
 
219
+ # Threshold with torch on CPU; convert to numpy only for saving/returning
220
+ pred_t = (prob_f > prob_thresh).to(torch.uint8) * 255 # HxW uint8 torch
221
+ pred = pred_t.detach().cpu().numpy()
222
 
223
  if out_dir is not None:
224
  os.makedirs(out_dir, exist_ok=True)
 
227
  cv2.imwrite(out_mask, pred)
228
  if save_prob:
229
  out_prob = os.path.join(out_dir, f"{stem}_prob.npy")
230
+ np.save(out_prob, prob_f.detach().cpu().float().numpy())
231
 
232
+ # Return numpy arrays for external consumers, computed via torch
233
+ return pred, prob_f.detach().cpu().numpy()
234
 
235
 
236
  def main():
train.py CHANGED
@@ -635,7 +635,7 @@ def validate(
635
  amp_dtype,
636
  )
637
  # Coarse metrics
638
- pred_coarse = (prob_up > prob_thresh).astype(np.uint8)
639
  m_c = compute_metrics(pred_coarse, mask)
640
  for k in coarse_sum:
641
  coarse_sum[k] += m_c[k]
@@ -664,7 +664,7 @@ def validate(
664
  if xs[-1] != (W - P):
665
  xs.append(W - P)
666
  total_tiles += len(ys) * len(xs)
667
- pred_fine = (prob_full > prob_thresh).astype(np.uint8)
668
  m_f = compute_metrics(pred_fine, mask)
669
  for k in metrics_sum:
670
  metrics_sum[k] += m_f[k]
@@ -721,7 +721,7 @@ def save_test_visuals(
721
  bool(amp_flag),
722
  None,
723
  )
724
- pred = (prob_up > prob_thresh).astype(np.uint8) * 255
725
  # Save input and prediction
726
  img_bgr = (img[..., ::-1] * 255.0).astype(np.uint8)
727
  cv2.imwrite(os.path.join(out_dir, f"{i:03d}_input.jpg"), img_bgr)
 
635
  amp_dtype,
636
  )
637
  # Coarse metrics
638
+ pred_coarse = (prob_up > prob_thresh).to(torch.uint8).cpu().numpy()
639
  m_c = compute_metrics(pred_coarse, mask)
640
  for k in coarse_sum:
641
  coarse_sum[k] += m_c[k]
 
664
  if xs[-1] != (W - P):
665
  xs.append(W - P)
666
  total_tiles += len(ys) * len(xs)
667
+ pred_fine = (prob_full > prob_thresh).to(torch.uint8).cpu().numpy()
668
  m_f = compute_metrics(pred_fine, mask)
669
  for k in metrics_sum:
670
  metrics_sum[k] += m_f[k]
 
721
  bool(amp_flag),
722
  None,
723
  )
724
+ pred = ((prob_up > prob_thresh).to(torch.uint8) * 255).cpu().numpy()
725
  # Save input and prediction
726
  img_bgr = (img[..., ::-1] * 255.0).astype(np.uint8)
727
  cv2.imwrite(os.path.join(out_dir, f"{i:03d}_input.jpg"), img_bgr)