# WireSegHR / train.py
# MRiabov's picture
# README and cleanup
# e05cc17
import argparse
import os
import pprint
import yaml
from typing import Tuple, List, Optional, Dict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import autocast
from torch.amp import GradScaler
from tqdm import tqdm
import random
import torch.backends.cudnn as cudnn
import cv2
from torch.utils.data import DataLoader
import time
from src.wireseghr.model import WireSegHR
from src.wireseghr.model.minmax import MinMaxLuminance
from src.wireseghr.data.dataset import WireSegDataset
from src.wireseghr.model.label_downsample import downsample_label_maxpool
from src.wireseghr.data.sampler import BalancedPatchSampler
from src.wireseghr.metrics import compute_metrics
from infer import _coarse_forward, _tiled_fine_forward
from pathlib import Path
class SizeBatchSampler:
    """Batch sampler that groups indices by exact (H, W) so all samples in a batch share size.

    This enables DataLoader prefetching while preserving the existing assumption
    in `_prepare_batch()` that all items in a batch have the same full resolution.

    Args:
        dset: Dataset exposing ``size_bins``, a mapping from a size key to the
            collection of sample indices that have that size.
        batch_size: Number of indices per yielded batch; must be >= 1.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """

    def __init__(self, dset: WireSegDataset, batch_size: int):
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.dset = dset
        self.batch_size = batch_size
        # Epoch length is the total number of *full* batches across all bins;
        # leftovers that cannot fill a batch are dropped (see __iter__).
        self._len = sum(
            len(idxs) // self.batch_size for idxs in self.dset.size_bins.values()
        )

    def __len__(self) -> int:
        """Return the number of batches yielded per epoch."""
        return self._len

    def __iter__(self):
        """Yield lists of ``batch_size`` indices, each list drawn from a single size bin."""
        bins = self.dset.size_bins
        # Randomize both the bin order and the order of samples inside each bin
        keys = list(bins)
        random.shuffle(keys)
        for hw in keys:
            pool = list(bins[hw])
            random.shuffle(pool)
            # Yield only full batches to keep fixed batch size and same-size assumption
            usable = len(pool) - (len(pool) % self.batch_size)
            for start in range(0, usable, self.batch_size):
                yield pool[start : start + self.batch_size]
def collate_train(batch: List[Dict]):
    """Split a list of sample dicts into parallel image and mask lists.

    No stacking or tensor conversion happens here: each sample's ``"image"``
    and ``"mask"`` entry is passed through untouched, in batch order.

    Returns:
        A tuple ``(images, masks)`` of two lists with one entry per sample.
    """
    images, labels = [], []
    for sample in batch:
        images.append(sample["image"])
        labels.append(sample["mask"])
    return images, labels
def _poly_lr(base_lr: float, step: int, iters: int, power: float) -> float:
    """Return the polynomially decayed learning rate for optimizer step `step` of `iters`."""
    return base_lr * ((1.0 - float(step) / float(iters)) ** power)


def main():
    """Train WireSegHR from a YAML config given via ``--config``.

    Reads the config, builds the training/validation/test datasets, the model,
    the AdamW optimizer and the (optionally mixed-precision) training loop.
    Every ``eval_interval`` steps (when a validation set is configured) the
    model is validated, the best-F1 checkpoint is updated, and periodic
    checkpoints / test visualizations are written under ``out_dir``. After the
    last step a final checkpoint is saved and, if a test set is configured, the
    full test set is evaluated and its predictions are dumped.

    Raises:
        ValueError: If a config value is out of range (eval batch sizes,
            precision string).
        RuntimeError: If mixed precision is requested on unsupported hardware,
            or if no full training batch can be formed.
    """
    parser = argparse.ArgumentParser(description="WireSegHR training (skeleton)")
    parser.add_argument(
        "--config", type=str, default="configs/default.yaml", help="Path to YAML config"
    )
    args = parser.parse_args()

    cfg_path = args.config
    if not Path(cfg_path).is_absolute():
        cfg_path = str(Path.cwd() / cfg_path)
    with open(cfg_path, "r") as f:
        cfg = yaml.safe_load(f)
    print("[WireSegHR][train] Loaded config from:", cfg_path)
    pprint.pprint(cfg)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[WireSegHR][train] Device: {device}")

    # Config
    coarse_train = int(cfg["coarse"]["train_size"])  # 512
    coarse_test = int(cfg["coarse"]["test_size"])  # use higher res for eval/infer
    patch_size = int(cfg["fine"]["patch_size"])  # training fine patch size
    overlap = int(cfg["fine"]["overlap"])  # e.g., 128
    eval_patch_size = int(cfg["inference"]["fine_patch_size"])  # 1024 for eval/infer
    eval_cfg = cfg.get("eval", {})
    eval_fine_batch = int(eval_cfg.get("fine_batch", 16))
    if eval_fine_batch < 1:
        raise ValueError(f"eval.fine_batch must be >= 1, got {eval_fine_batch}")
    eval_max_samples = int(eval_cfg.get("max_samples", 16))
    if eval_max_samples < 1:
        raise ValueError(f"eval.max_samples must be >= 1, got {eval_max_samples}")
    iters = int(cfg["optim"]["iters"])  # 40000
    batch_size = int(cfg["optim"]["batch_size"])  # 8
    base_lr = float(cfg["optim"]["lr"])  # 6e-5
    weight_decay = float(cfg["optim"]["weight_decay"])  # 0.01
    power = float(cfg["optim"]["power"])  # 1.0
    precision = str(cfg["optim"].get("precision", "fp32")).lower()
    if precision not in ("fp32", "fp16", "bf16"):
        raise ValueError(
            f"optim.precision must be one of fp32/fp16/bf16, got {precision!r}"
        )

    # Enable AMP only when requested and on CUDA
    amp_enabled = (device.type == "cuda") and (precision in ("fp16", "bf16"))
    # Fail fast on unsupported hardware if mixed precision is requested
    if amp_enabled:
        cc_major, cc_minor = torch.cuda.get_device_capability()
        if precision == "fp16" and cc_major < 7:
            raise RuntimeError(
                f"fp16 requires Volta (SM 7.0)+; current SM {cc_major}.{cc_minor}"
            )
        if precision == "bf16" and cc_major < 8:
            raise RuntimeError(
                f"bf16 requires Ampere (SM 8.0)+; current SM {cc_major}.{cc_minor}"
            )
    amp_dtype = (
        torch.float16
        if precision == "fp16"
        else (torch.bfloat16 if precision == "bf16" else None)
    )

    # Housekeeping
    seed = int(cfg.get("seed", 42))
    out_dir = cfg.get("out_dir", "runs/wireseghr")
    eval_interval = int(cfg["eval_interval"])
    ckpt_interval = int(cfg["ckpt_interval"])
    os.makedirs(out_dir, exist_ok=True)
    set_seed(seed)

    # Dataset
    train_images = cfg["data"]["train_images"]
    train_masks = cfg["data"]["train_masks"]
    dset = WireSegDataset(train_images, train_masks, split="train")

    # DataLoader with prefetching and size-aware batching
    loader_cfg = cfg.get("loader", {})
    num_workers = int(loader_cfg.get("num_workers", 4))
    prefetch_factor = int(loader_cfg.get("prefetch_factor", 2))
    pin_memory = bool(loader_cfg.get("pin_memory", True))
    persistent_workers = (
        bool(loader_cfg.get("persistent_workers", True)) if num_workers > 0 else False
    )
    batch_sampler = SizeBatchSampler(dset, batch_size)
    # The sampler only yields full batches per size bin; with none available the
    # training loop could never fetch data, so fail with a clear message instead
    # of a bare StopIteration from inside the loop.
    if len(batch_sampler) == 0:
        raise RuntimeError(
            f"[WireSegHR][train] No full batch can be formed with batch_size={batch_size}: "
            "every size bin holds fewer images than one batch"
        )
    loader_kwargs = dict(
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=persistent_workers,
        collate_fn=collate_train,
    )
    if num_workers > 0:
        # prefetch_factor is only passed when worker processes are used
        loader_kwargs["prefetch_factor"] = prefetch_factor
    train_loader = DataLoader(dset, **loader_kwargs)

    # Validation and test
    val_images = cfg["data"].get("val_images", None)
    val_masks = cfg["data"].get("val_masks", None)
    test_images = cfg["data"].get("test_images", None)
    test_masks = cfg["data"].get("test_masks", None)
    dset_val = (
        WireSegDataset(val_images, val_masks, split="val")
        if val_images and val_masks
        else None
    )
    dset_test = (
        WireSegDataset(test_images, test_masks, split="test")
        if test_images and test_masks
        else None
    )

    sampler = BalancedPatchSampler(patch_size=patch_size, min_wire_ratio=0.01)
    minmax = (
        MinMaxLuminance(kernel=cfg["minmax"]["kernel"])
        if cfg["minmax"]["enable"]
        else None
    )

    # Inference/eval settings from config
    prob_thresh = float(cfg["inference"]["prob_threshold"])
    mm_enable = bool(cfg["minmax"]["enable"])
    mm_kernel = int(cfg["minmax"]["kernel"])

    # Model
    # Channel definition: RGB(3) + MinMax(2) + cond(1) = 6
    pretrained_flag = bool(cfg.get("pretrained", False))
    model = WireSegHR(
        backbone=cfg["backbone"], in_channels=6, pretrained=pretrained_flag
    )
    model = model.to(device)

    # Optimizer and loss
    optim = torch.optim.AdamW(model.parameters(), lr=base_lr, weight_decay=weight_decay)
    # Loss scaling is only needed for fp16; the scaler is a pass-through otherwise
    scaler = GradScaler("cuda", enabled=(device.type == "cuda" and precision == "fp16"))
    ce = nn.CrossEntropyLoss()

    # Resume
    start_step = 0
    best_f1 = -1.0
    resume_path = cfg.get("resume", None)
    if resume_path:
        if Path(resume_path).is_file():
            print(f"[WireSegHR][train] Resuming from {resume_path}")
            start_step, best_f1 = _load_checkpoint(
                resume_path, model, optim, scaler, device
            )
        else:
            # Keep the original best-effort behavior, but make it visible
            print(
                f"[WireSegHR][train] Warning: resume checkpoint not found: {resume_path}; "
                "starting from scratch"
            )

    # Training loop
    model.train()
    step = start_step
    pbar = tqdm(total=iters - step, initial=0, desc="Train", ncols=100)
    data_iter = iter(train_loader)
    while step < iters:
        optim.zero_grad(set_to_none=True)
        try:
            imgs, masks = next(data_iter)
        except StopIteration:
            # Loader exhausted: start a new pass over the training set
            data_iter = iter(train_loader)
            imgs, masks = next(data_iter)
        batch = _prepare_batch(
            imgs, masks, coarse_train, patch_size, sampler, minmax, device
        )

        # Poly LR schedule (per optimizer step): set the rate for *this* step
        # before the optimizer consumes it.
        lr = _poly_lr(base_lr, step, iters, power)
        for pg in optim.param_groups:
            pg["lr"] = lr

        with autocast(device_type=device.type, dtype=amp_dtype, enabled=amp_enabled):
            logits_coarse, cond_map = model.forward_coarse(
                batch["x_coarse"]
            )  # (B,2,Hc/4,Wc/4) and (B,1,Hc/4,Wc/4)
        # Build fine inputs: crop cond from low-res map to patch, concat with patch RGB+MinMax and loc mask
        hc4, wc4 = cond_map.shape[2], cond_map.shape[3]
        x_fine = _build_fine_inputs(batch, cond_map, device)
        # Targets for the coarse head only depend on the coarse output size
        y_coarse = _build_coarse_targets(batch["mask_full"], hc4, wc4, device)
        with autocast(device_type=device.type, dtype=amp_dtype, enabled=amp_enabled):
            logits_fine = model.forward_fine(x_fine)
            y_fine = _build_fine_targets(
                batch["mask_patches"],
                logits_fine.shape[2],
                logits_fine.shape[3],
                device,
            )
            # Losses are computed under autocast together with the forward pass
            loss_coarse = ce(logits_coarse, y_coarse)
            loss_fine = ce(logits_fine, y_fine)
            loss = loss_coarse + loss_fine
        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()

        if step % 50 == 0:
            print(f"[Iter {step}/{iters}] lr={lr:.6e}")

        # Eval & Checkpoint
        # NOTE(review): periodic checkpoints and test visuals are nested under
        # the eval branch (nesting reconstructed from the "Eval & Checkpoint"
        # grouping) - confirm this is intended, since it means no periodic
        # checkpoints are written without a validation set.
        if (
            eval_interval > 0
            and (step % eval_interval == 0)
            and (dset_val is not None)
        ):
            # Free training-step tensors before eval to lower peak memory
            del (
                x_fine,
                logits_coarse,
                cond_map,
                logits_fine,
                y_coarse,
                y_fine,
                loss_coarse,
                loss_fine,
                loss,
            )
            torch.cuda.empty_cache()
            model.eval()
            print(
                f"[WireSegHR][train] Eval starting... val_size={len(dset_val)} max={eval_max_samples} patch={eval_patch_size} overlap={overlap} stride={eval_patch_size - overlap} fine_batch={eval_fine_batch}",
                flush=True,
            )
            val_stats = validate(
                model,
                dset_val,
                coarse_test,
                device,
                amp_enabled,
                amp_dtype,
                prob_thresh,
                mm_enable,
                mm_kernel,
                eval_patch_size,
                overlap,
                eval_fine_batch,
                eval_max_samples,
            )
            print(
                f"[Val @ {step}][Fine] IoU={val_stats['iou']:.4f} F1={val_stats['f1']:.4f} P={val_stats['precision']:.4f} R={val_stats['recall']:.4f}"
            )
            print(
                f"[Val @ {step}][Coarse] IoU={val_stats['iou_coarse']:.4f} F1={val_stats['f1_coarse']:.4f} P={val_stats['precision_coarse']:.4f} R={val_stats['recall_coarse']:.4f}"
            )
            # Save best
            if val_stats["f1"] > best_f1:
                best_f1 = val_stats["f1"]
                _save_checkpoint(
                    str(Path(out_dir) / "best.pt"),
                    step,
                    model,
                    optim,
                    scaler,
                    best_f1,
                )
            # Save periodic ckpt
            if ckpt_interval > 0 and (step % ckpt_interval == 0):
                _save_checkpoint(
                    str(Path(out_dir) / f"ckpt_{step}.pt"),
                    step,
                    model,
                    optim,
                    scaler,
                    best_f1,
                )
            # Save test visualizations
            if dset_test is not None:
                save_test_visuals(
                    model,
                    dset_test,
                    coarse_test,
                    device,
                    str(Path(out_dir) / f"test_vis_{step}"),
                    amp_enabled,
                    mm_enable,
                    mm_kernel,
                    prob_thresh,
                    max_samples=8,
                )
            model.train()

        step += 1
        pbar.update(1)
    pbar.close()

    # Save a final checkpoint upon completion
    _save_checkpoint(
        str(Path(out_dir) / f"ckpt_{iters}.pt"), step, model, optim, scaler, best_f1
    )

    # Final test evaluation
    if dset_test is not None:
        torch.cuda.empty_cache()
        model.eval()
        print(
            f"[WireSegHR][train] Final test starting... test_size={len(dset_test)} patch={eval_patch_size} overlap={overlap} stride={eval_patch_size - overlap} fine_batch={eval_fine_batch}",
            flush=True,
        )
        test_stats = validate(
            model,
            dset_test,
            coarse_test,
            device,
            amp_enabled,
            amp_dtype,
            prob_thresh,
            mm_enable,
            mm_kernel,
            eval_patch_size,
            overlap,
            eval_fine_batch,
            len(dset_test),
        )
        print(
            f"[Test Final][Fine] IoU={test_stats['iou']:.4f} F1={test_stats['f1']:.4f} P={test_stats['precision']:.4f} R={test_stats['recall']:.4f}"
        )
        print(
            f"[Test Final][Coarse] IoU={test_stats['iou_coarse']:.4f} F1={test_stats['f1_coarse']:.4f} P={test_stats['precision_coarse']:.4f} R={test_stats['recall_coarse']:.4f}"
        )
        # Save final evaluation artifacts
        final_out = Path(out_dir) / f"final_vis_{step}"
        final_out.mkdir(parents=True, exist_ok=True)
        # Dump metrics for record
        with open(final_out / "metrics.yaml", "w") as f:
            yaml.safe_dump({**test_stats, "step": step}, f, sort_keys=False)
        # Save predictions (fine + coarse) for the whole test set
        save_final_visuals(
            model,
            dset_test,
            coarse_test,
            device,
            str(final_out),
            amp_enabled,
            amp_dtype,
            prob_thresh,
            mm_enable,
            mm_kernel,
            eval_patch_size,
            overlap,
            eval_fine_batch,
        )
        model.train()

    print("[WireSegHR][train] Done.")
def _prepare_batch(
imgs: List[np.ndarray],
masks: List[np.ndarray],
coarse_train: int,
patch_size: int,
sampler: BalancedPatchSampler,
minmax: Optional[MinMaxLuminance],
device: torch.device,
):
B = len(imgs)
assert B == len(masks)
# Keep numpy versions for geometry and torch versions for model inputs
full_h = imgs[0].shape[0]
full_w = imgs[0].shape[1]
for im, m in zip(imgs, masks):
assert im.shape[0] == full_h and im.shape[1] == full_w
assert m.shape[0] == full_h and m.shape[1] == full_w
xs_coarse = []
patches_rgb = []
patches_mask = []
patches_min = []
patches_max = []
yx_list: List[tuple[int, int]] = []
for img, mask in zip(imgs, masks):
# Float32 [0,1] on CPU, then move to GPU for heavy ops
imgf = img.astype(np.float32) / 255.0
t_img = (
torch.from_numpy(np.transpose(imgf, (2, 0, 1))).unsqueeze(0).to(device)
) # 1x3xHxW
# Luminance and Min/Max (6x6 replicate) on GPU
y_t = (
0.299 * t_img[:, 0:1] + 0.587 * t_img[:, 1:2] + 0.114 * t_img[:, 2:3]
) # 1x1xHxW
if minmax is not None:
# Asymmetric padding for even kernel to keep same HxW
y_p = F.pad(y_t, (2, 3, 2, 3), mode="replicate")
y_max_full = F.max_pool2d(y_p, kernel_size=6, stride=1)
y_min_full = -F.max_pool2d(-y_p, kernel_size=6, stride=1)
else:
y_min_full = y_t
y_max_full = y_t
# Coarse input: resize on GPU, build 6-ch tensor (RGB + min + max + cond0)
rgb_coarse_t = F.interpolate(
t_img,
size=(coarse_train, coarse_train),
mode="bilinear",
align_corners=False,
)[0]
y_min_c_t = F.interpolate(
y_min_full,
size=(coarse_train, coarse_train),
mode="bilinear",
align_corners=False,
)[0]
y_max_c_t = F.interpolate(
y_max_full,
size=(coarse_train, coarse_train),
mode="bilinear",
align_corners=False,
)[0]
zeros_coarse = torch.zeros(1, coarse_train, coarse_train, device=device)
c_t = torch.cat(
[rgb_coarse_t, y_min_c_t, y_max_c_t, zeros_coarse], dim=0
) # 6xHc x Wc
xs_coarse.append(c_t)
# Sample fine patch (CPU mask), then slice GPU min/max and transfer only patches
y0, x0 = sampler.sample(imgf, mask)
patch_rgb = imgf[y0 : y0 + patch_size, x0 : x0 + patch_size, :]
patch_mask = mask[y0 : y0 + patch_size, x0 : x0 + patch_size]
patches_rgb.append(patch_rgb)
patches_mask.append(patch_mask)
ymin_patch = (
y_min_full[0, 0, y0 : y0 + patch_size, x0 : x0 + patch_size]
.detach()
.cpu()
.numpy()
)
ymax_patch = (
y_max_full[0, 0, y0 : y0 + patch_size, x0 : x0 + patch_size]
.detach()
.cpu()
.numpy()
)
patches_min.append(ymin_patch)
patches_max.append(ymax_patch)
yx_list.append((y0, x0))
x_coarse = torch.stack(xs_coarse, dim=0) # already on device
# Store numpy arrays for fine build
return {
"x_coarse": x_coarse,
"full_h": full_h,
"full_w": full_w,
"rgb_patches": patches_rgb,
"mask_patches": patches_mask,
"ymin_patches": patches_min,
"ymax_patches": patches_max,
"patch_yx": yx_list,
"mask_full": masks,
}
def _build_fine_inputs(
batch, cond_map: torch.Tensor, device: torch.device
) -> torch.Tensor:
# Build fine input tensor Bx6xP x P; crop cond from low-res map, upsample to P
B = cond_map.shape[0]
P = batch["rgb_patches"][0].shape[0]
full_h, full_w = batch["full_h"], batch["full_w"]
hc4, wc4 = cond_map.shape[2], cond_map.shape[3]
xs: List[torch.Tensor] = []
for i in range(B):
rgb = batch["rgb_patches"][i]
ymin = batch["ymin_patches"][i]
ymax = batch["ymax_patches"][i]
y0, x0 = batch["patch_yx"][i]
# Map full-res patch box to low-res cond grid, crop and upsample to P
y1, x1 = y0 + P, x0 + P
y0c = (y0 * hc4) // full_h
y1c = ((y1 * hc4) + full_h - 1) // full_h
x0c = (x0 * wc4) // full_w
x1c = ((x1 * wc4) + full_w - 1) // full_w
cond_sub = cond_map[i : i + 1, :, y0c:y1c, x0c:x1c].float() # 1x1xhxw
cond_patch = F.interpolate(
cond_sub, size=(P, P), mode="bilinear", align_corners=False
).squeeze(1) # 1xPxP
# Convert numpy channels to torch and concat
rgb_t = (
torch.from_numpy(np.transpose(rgb, (2, 0, 1))).to(device).float()
) # 3xPxP
ymin_t = torch.from_numpy(ymin)[None, ...].to(device).float() # 1xPxP
ymax_t = torch.from_numpy(ymax)[None, ...].to(device).float() # 1xPxP
x = torch.cat([rgb_t, ymin_t, ymax_t, cond_patch], dim=0) # 6xPxP
xs.append(x)
x_fine = torch.stack(xs, dim=0)
return x_fine
def _build_coarse_targets(
    masks: List[np.ndarray], out_h: int, out_w: int, device: torch.device
) -> torch.Tensor:
    """Downsample each full-resolution mask to (out_h, out_w) and stack them as int64 targets.

    Each mask goes through `downsample_label_maxpool`; the results are cast to
    int64, stacked along a new batch axis and moved to `device`.
    """
    pooled = [
        torch.from_numpy(
            downsample_label_maxpool(full_mask, out_h, out_w).astype(np.int64)
        )
        for full_mask in masks
    ]
    # BxHc4xWc4; presumably values in {0,1} for binary masks
    return torch.stack(pooled, dim=0).to(device)
def _build_fine_targets(
    mask_patches: List[np.ndarray], out_h: int, out_w: int, device: torch.device
) -> torch.Tensor:
    """Downsample each mask patch to (out_h, out_w) and stack them as int64 targets.

    Each patch goes through `downsample_label_maxpool`; the results are cast to
    int64, stacked along a new batch axis and moved to `device`.
    """
    pooled = [
        torch.from_numpy(
            downsample_label_maxpool(patch_mask, out_h, out_w).astype(np.int64)
        )
        for patch_mask in mask_patches
    ]
    # BxHf4xWf4; presumably values in {0,1} for binary masks
    return torch.stack(pooled, dim=0).to(device)
def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch RNGs and configure cuDNN.

    cuDNN is left in benchmark mode with deterministic mode switched off, so
    seeding fixes the random streams but this function does not request
    deterministic cuDNN kernels.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Deliberate choice: benchmark mode on, deterministic mode off
    cudnn.deterministic = False
    cudnn.benchmark = True
def _save_checkpoint(
path: str,
step: int,
model: nn.Module,
optim: torch.optim.Optimizer,
scaler: GradScaler,
best_f1: float,
):
Path(path).parent.mkdir(parents=True, exist_ok=True)
state = {
"step": step,
"model": model.state_dict(),
"optim": optim.state_dict(),
"scaler": scaler.state_dict(),
"best_f1": best_f1,
}
torch.save(state, path)
print(f"[WireSegHR][train] Saved checkpoint: {path}")
def _load_checkpoint(
path: str,
model: nn.Module,
optim: torch.optim.Optimizer,
scaler: GradScaler,
device: torch.device,
) -> Tuple[int, float]:
ckpt = torch.load(path, map_location=device)
model.load_state_dict(ckpt["model"])
optim.load_state_dict(ckpt["optim"])
try:
scaler.load_state_dict(ckpt["scaler"]) # may not exist
except Exception:
pass
step = int(ckpt.get("step", 0))
best_f1 = float(ckpt.get("best_f1", -1.0))
return step, best_f1
def _count_axis_tiles(extent: int, patch: int, stride: int) -> int:
    """Count sliding-window origins along one axis, with the last tile snapped to the far edge."""
    if extent <= patch or stride <= 0:
        # Image not larger than one patch (or degenerate stride): count one tile
        return 1
    starts = range(0, extent - patch + 1, stride)
    snapped = 1 if starts[-1] != (extent - patch) else 0
    return len(starts) + snapped


@torch.no_grad()
def validate(
    model: WireSegHR,
    dset_val: WireSegDataset,
    coarse_size: int,
    device: torch.device,
    amp_flag: bool,
    amp_dtype,
    prob_thresh: float,
    minmax_enable: bool,
    minmax_kernel: int,
    fine_patch_size: int,
    fine_overlap: int,
    fine_batch: int,
    max_images: int,
) -> Dict[str, float]:
    """Evaluate coarse and fine predictions on a random subset of `dset_val`.

    For each sampled image the coarse pass (`_coarse_forward`) and the tiled
    fine pass (`_tiled_fine_forward`) are run, both probability maps are
    thresholded at `prob_thresh`, and `compute_metrics` is evaluated against
    the ground-truth mask. Metrics are averaged over the evaluated images.

    Args:
        max_images: Upper bound on the number of images evaluated; the subset is
            drawn with ``random.sample`` from the module-level RNG on every call.

    Returns:
        Dict with the fine metrics ``iou``, ``f1``, ``precision``, ``recall`` and
        the coarse metrics under the same names suffixed with ``_coarse``. All
        values are 0.0 when no image was evaluated.
    """
    model = model.to(device)
    metrics_sum = {"iou": 0.0, "f1": 0.0, "precision": 0.0, "recall": 0.0}
    coarse_sum = {"iou": 0.0, "f1": 0.0, "precision": 0.0, "recall": 0.0}
    n = 0
    t0 = time.time()
    total_tiles = 0
    target_n = min(len(dset_val), max_images)
    # NOTE(review): a fresh random subset is drawn per call, so scores from
    # successive evaluations are measured on different images when
    # max_images < len(dset_val).
    idxs = random.sample(range(len(dset_val)), k=target_n)
    print(
        f"[Eval] Started: N={target_n}/{len(dset_val)} coarse={coarse_size} patch={fine_patch_size} overlap={fine_overlap} stride={fine_patch_size - fine_overlap} fine_batch={fine_batch}",
        flush=True,
    )
    for j, i in enumerate(idxs):
        if (j % 2) == 0:
            print(f"[Eval] Running... {j}/{target_n}", flush=True)
        item = dset_val[i]
        img = item["image"].astype(np.float32) / 255.0  # HxWx3
        mask = item["mask"].astype(np.uint8)
        H, W = mask.shape

        # Reuse inference coarse pass
        prob_up, cond_map, t_img, y_min_full, y_max_full = _coarse_forward(
            model,
            img,
            coarse_size,
            minmax_enable,
            int(minmax_kernel),
            device,
            amp_flag,
            amp_dtype,
        )
        # Coarse metrics
        pred_coarse = (prob_up > prob_thresh).to(torch.uint8).cpu().numpy()
        m_c = compute_metrics(pred_coarse, mask)
        for k in coarse_sum:
            # float() keeps the accumulators plain Python floats
            coarse_sum[k] += float(m_c[k])

        # Fine-stage via helper (batched and stitched)
        prob_full = _tiled_fine_forward(
            model,
            t_img,
            cond_map,
            y_min_full,
            y_max_full,
            int(fine_patch_size),
            int(fine_overlap),
            int(fine_batch),
            device,
            amp_flag,
            amp_dtype,
        )
        # Track tiles for throughput reporting only; this does not drive inference
        P = int(fine_patch_size)
        stride = P - int(fine_overlap)
        total_tiles += _count_axis_tiles(H, P, stride) * _count_axis_tiles(W, P, stride)

        pred_fine = (prob_full > prob_thresh).to(torch.uint8).cpu().numpy()
        m_f = compute_metrics(pred_fine, mask)
        for k in metrics_sum:
            metrics_sum[k] += float(m_f[k])
        n += 1

    if n > 0:
        for k in metrics_sum:
            metrics_sum[k] /= n
        for k in coarse_sum:
            coarse_sum[k] /= n

    dt = time.time() - t0
    tp_img = (n / dt) if dt > 0 else 0.0
    tp_tile = (total_tiles / dt) if dt > 0 else 0.0
    print(
        f"[Eval] Done in {dt:.2f}s | imgs={n}, tiles={total_tiles}, imgs/s={tp_img:.2f}, tiles/s={tp_tile:.2f}",
        flush=True,
    )
    out = dict(metrics_sum)
    out.update(
        {
            "iou_coarse": coarse_sum["iou"],
            "f1_coarse": coarse_sum["f1"],
            "precision_coarse": coarse_sum["precision"],
            "recall_coarse": coarse_sum["recall"],
        }
    )
    return out
@torch.no_grad()
def save_test_visuals(
    model: WireSegHR,
    dset_test: WireSegDataset,
    coarse_size: int,
    device: torch.device,
    out_dir: str,
    amp_flag: bool,
    minmax_enable: bool,
    minmax_kernel: int,
    prob_thresh: float,
    max_samples: int = 8,
    amp_dtype=None,
):
    """Save the input image and thresholded coarse prediction for the first test samples.

    Writes ``{i:03d}_input.jpg`` and ``{i:03d}_pred.png`` into `out_dir` for the
    first ``min(max_samples, len(dset_test))`` items.

    Args:
        amp_dtype: Autocast dtype forwarded to `_coarse_forward`. Defaults to
            None, which is what this function always passed before the
            parameter existed; callers running in bf16 can pass their dtype so
            the visuals use the same precision as evaluation.
    """
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    for i in range(min(max_samples, len(dset_test))):
        item = dset_test[i]
        img = item["image"].astype(np.float32) / 255.0
        prob_up, _cond_map, _t_img, _ymin, _ymax = _coarse_forward(
            model,
            img,
            int(coarse_size),
            bool(minmax_enable),
            int(minmax_kernel),
            device,
            bool(amp_flag),
            amp_dtype,
        )
        pred = ((prob_up > prob_thresh).to(torch.uint8) * 255).cpu().numpy()
        # Save input and prediction. Channels are reversed for cv2 (presumably
        # RGB -> BGR); round before the uint8 cast so the float round trip
        # cannot truncate a value down by one.
        img_bgr = np.rint(img[..., ::-1] * 255.0).astype(np.uint8)
        cv2.imwrite(str(Path(out_dir) / f"{i:03d}_input.jpg"), img_bgr)
        cv2.imwrite(str(Path(out_dir) / f"{i:03d}_pred.png"), pred)
@torch.no_grad()
def save_final_visuals(
    model: WireSegHR,
    dset_test: WireSegDataset,
    coarse_size: int,
    device: torch.device,
    out_dir: str,
    amp_flag: bool,
    amp_dtype,
    prob_thresh: float,
    minmax_enable: bool,
    minmax_kernel: int,
    fine_patch_size: int,
    fine_overlap: int,
    fine_batch: int,
):
    """Save input, coarse prediction and fine prediction for every test sample.

    For each item the coarse pass (`_coarse_forward`) and the tiled fine pass
    (`_tiled_fine_forward`) are run, both probability maps are thresholded at
    `prob_thresh`, and three files are written into `out_dir`:
    ``{i:03d}_input.jpg``, ``{i:03d}_coarse_pred.png`` and
    ``{i:03d}_fine_pred.png``.
    """
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    for i in range(len(dset_test)):
        item = dset_test[i]
        img = item["image"].astype(np.float32) / 255.0
        # Coarse pass
        prob_up, cond_map, t_img, y_min_full, y_max_full = _coarse_forward(
            model,
            img,
            int(coarse_size),
            bool(minmax_enable),
            int(minmax_kernel),
            device,
            bool(amp_flag),
            amp_dtype,
        )
        pred_coarse = ((prob_up > prob_thresh).to(torch.uint8) * 255).cpu().numpy()
        # Fine pass (tiled)
        prob_full = _tiled_fine_forward(
            model,
            t_img,
            cond_map,
            y_min_full,
            y_max_full,
            int(fine_patch_size),
            int(fine_overlap),
            int(fine_batch),
            device,
            bool(amp_flag),
            amp_dtype,
        )
        pred_fine = ((prob_full > prob_thresh).to(torch.uint8) * 255).cpu().numpy()
        # Save input and predictions. Channels are reversed for cv2 (presumably
        # RGB -> BGR); round before the uint8 cast so the float round trip
        # cannot truncate a value down by one.
        img_bgr = np.rint(img[..., ::-1] * 255.0).astype(np.uint8)
        base = f"{i:03d}"
        cv2.imwrite(str(Path(out_dir) / f"{base}_input.jpg"), img_bgr)
        cv2.imwrite(str(Path(out_dir) / f"{base}_coarse_pred.png"), pred_coarse)
        cv2.imwrite(str(Path(out_dir) / f"{base}_fine_pred.png"), pred_fine)
# Script entry point: run training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()