(debug) Switch to torch operations in bfloat16 conversions

Files changed:
- configs/default.yaml  +5 -5
- infer.py  +12 -9
- train.py  +3 -3
configs/default.yaml CHANGED
@@ -28,8 +28,8 @@ inference:
   stitch: avg_logits

 eval:
-  max_samples:
-  fine_batch:
+  max_samples: 12
+  fine_batch: 16

 optim:
   iters: 2000
@@ -38,14 +38,14 @@ optim:
   weight_decay: 0.01
   schedule: poly
   power: 1.0
-  precision:
+  precision: bf16  # one of: fp32, fp16, bf16

 # training housekeeping
 seed: 42
 out_dir: runs/wireseghr
-eval_interval:
+eval_interval: 150
 ckpt_interval: 300
-resume: runs/wireseghr/ckpt_1800.pt  # optional
+# resume: runs/wireseghr/ckpt_1800.pt  # optional

 # dataset paths (placeholders)
 data:
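Note on the new precision key: the value bf16 presumably selects the amp_flag/amp_dtype pair that infer.py and train.py thread through their forward helpers. A minimal sketch of how such a key could be resolved follows; the helper name resolve_precision and the cfg layout are hypothetical, not part of this commit.

# Hypothetical helper: map the config's precision string to the
# (amp_flag, amp_dtype) pair used by the forward helpers.
import torch

def resolve_precision(precision: str):
    if precision == "fp32":
        return False, None            # autocast disabled
    if precision == "fp16":
        return True, torch.float16
    if precision == "bf16":
        return True, torch.bfloat16
    raise ValueError(f"unknown precision: {precision}")

# Example usage (cfg layout assumed):
# amp_flag, amp_dtype = resolve_precision(cfg["optim"]["precision"])
# with torch.autocast(device_type="cuda", enabled=amp_flag, dtype=amp_dtype):
#     logits = model(x)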
infer.py CHANGED
@@ -31,7 +31,7 @@ def _coarse_forward(
     device: torch.device,
     amp_flag: bool,
     amp_dtype,
-) -> Tuple[
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     # Convert to tensor on device
     t_img = (
         torch.from_numpy(np.transpose(img_rgb, (2, 0, 1)))
@@ -76,8 +76,8 @@ def _coarse_forward(
         F.interpolate(prob, size=(H, W), mode="bilinear", align_corners=False)[0, 0]
         .detach()
         .cpu()
-        .
-    )
+        .float()
+    )  # HxW torch.Tensor on CPU
     return prob_up, cond_map, t_img, y_min_full, y_max_full


@@ -94,7 +94,7 @@ def _tiled_fine_forward(
     device: torch.device,
     amp_flag: bool,
     amp_dtype,
-) ->
+) -> torch.Tensor:
     H = int(t_img.shape[2])
     W = int(t_img.shape[3])
     P = patch_size
@@ -153,8 +153,8 @@ def _tiled_fine_forward(
         prob_sum_t[y0:y1, x0:x1] += prob_f_up[bi]
         weight_t[y0:y1, x0:x1] += 1.0

-    prob_full = (prob_sum_t / weight_t).detach().cpu().
-    return prob_full
+    prob_full = (prob_sum_t / weight_t).detach().cpu().float()
+    return prob_full  # HxW torch.Tensor on CPU


 def _build_model_from_cfg(cfg: dict, device: torch.device) -> WireSegHR:
@@ -216,7 +216,9 @@ def infer_image(
         amp_dtype,
     )

-
+    # Threshold with torch on CPU; convert to numpy only for saving/returning
+    pred_t = (prob_f > prob_thresh).to(torch.uint8) * 255  # HxW uint8 torch
+    pred = pred_t.detach().cpu().numpy()

     if out_dir is not None:
         os.makedirs(out_dir, exist_ok=True)
@@ -225,9 +227,10 @@ def infer_image(
         cv2.imwrite(out_mask, pred)
         if save_prob:
             out_prob = os.path.join(out_dir, f"{stem}_prob.npy")
-            np.save(out_prob, prob_f.
+            np.save(out_prob, prob_f.detach().cpu().float().numpy())

-
+    # Return numpy arrays for external consumers, computed via torch
+    return pred, prob_f.detach().cpu().numpy()


 def main():
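Why the added .float() calls matter: numpy has no bfloat16 dtype, so calling .numpy() directly on a bf16 tensor raises a TypeError. A minimal sketch of the safe conversion pattern used above; to_numpy_safe is an illustrative helper, not a function in this repository.

import numpy as np
import torch

def to_numpy_safe(t: torch.Tensor) -> np.ndarray:
    # Detach from autograd, move to CPU, upcast bf16/fp16 to fp32, then convert.
    return t.detach().cpu().float().numpy()

prob = torch.rand(4, 4, dtype=torch.bfloat16)  # stand-in for an autocast output
arr = to_numpy_safe(prob)                      # float32 ndarray, no dtype error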
train.py CHANGED
@@ -635,7 +635,7 @@ def validate(
             amp_dtype,
         )
         # Coarse metrics
-        pred_coarse = (prob_up > prob_thresh).
+        pred_coarse = (prob_up > prob_thresh).to(torch.uint8).cpu().numpy()
         m_c = compute_metrics(pred_coarse, mask)
         for k in coarse_sum:
             coarse_sum[k] += m_c[k]
@@ -664,7 +664,7 @@ def validate(
         if xs[-1] != (W - P):
             xs.append(W - P)
         total_tiles += len(ys) * len(xs)
-        pred_fine = (prob_full > prob_thresh).
+        pred_fine = (prob_full > prob_thresh).to(torch.uint8).cpu().numpy()
         m_f = compute_metrics(pred_fine, mask)
         for k in metrics_sum:
             metrics_sum[k] += m_f[k]
@@ -721,7 +721,7 @@ def save_test_visuals(
            bool(amp_flag),
            None,
        )
-        pred = (prob_up > prob_thresh).
+        pred = ((prob_up > prob_thresh).to(torch.uint8) * 255).cpu().numpy()
         # Save input and prediction
         img_bgr = (img[..., ::-1] * 255.0).astype(np.uint8)
         cv2.imwrite(os.path.join(out_dir, f"{i:03d}_input.jpg"), img_bgr)
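The thresholding pattern above can be exercised in isolation: compare in torch (which handles bf16 probabilities), cast the boolean mask to uint8, and only then convert to numpy for compute_metrics or cv2.imwrite. A small self-contained sketch, with a random tensor standing in for the model output and prob_thresh chosen arbitrarily; compute_metrics itself is taken as given by the surrounding code:

import torch

prob_up = torch.rand(512, 512, dtype=torch.bfloat16)  # stand-in for a coarse probability map
prob_thresh = 0.5

pred_binary = (prob_up > prob_thresh).to(torch.uint8).cpu().numpy()           # 0/1 mask for metrics
pred_visual = ((prob_up > prob_thresh).to(torch.uint8) * 255).cpu().numpy()   # 0/255 image for cv2.imwrite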