MRiabov commited on
Commit
6e8ad0e
·
1 Parent(s): c2fb738

(test) test ResNet backbone

Browse files
configs/default.yaml CHANGED
@@ -32,7 +32,7 @@ eval:
32
  fine_batch: 16
33
 
34
  optim:
35
- iters: 10000
36
  batch_size: 4
37
  lr: 6e-5
38
  weight_decay: 0.01
@@ -45,7 +45,7 @@ seed: 42
45
  out_dir: runs/wireseghr
46
  eval_interval: 200
47
  ckpt_interval: 400
48
- resume: runs/wireseghr/ckpt_1800.pt # optional
49
 
50
  # dataset paths (placeholders)
51
  data:
 
32
  fine_batch: 16
33
 
34
  optim:
35
+ iters: 5000
36
  batch_size: 4
37
  lr: 6e-5
38
  weight_decay: 0.01
 
45
  out_dir: runs/wireseghr
46
  eval_interval: 200
47
  ckpt_interval: 400
48
+ resume: runs/wireseghr/ckpt_4800.pt # optional
49
 
50
  # dataset paths (placeholders)
51
  data:
tests/test_model_forward.py CHANGED
@@ -17,3 +17,19 @@ def test_wireseghr_forward_shapes():
17
 
18
  logits_fine = model.forward_fine(x)
19
  assert logits_fine.shape == logits_coarse.shape
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  logits_fine = model.forward_fine(x)
19
  assert logits_fine.shape == logits_coarse.shape
20
+
21
+
22
+ def test_wireseghr_forward_shapes_resnet50():
23
+ # Ensure ResNet-50 alt backbone works and keeps 1/4 stage0 resolution
24
+ model = WireSegHR(backbone="resnet50", in_channels=3, pretrained=False)
25
+
26
+ x = torch.randn(1, 3, 64, 64)
27
+ logits_coarse, cond = model.forward_coarse(x)
28
+ assert logits_coarse.shape[0] == 1 and logits_coarse.shape[1] == 2
29
+ assert cond.shape[0] == 1 and cond.shape[1] == 1
30
+ # ResNet stage0 is also 1/4 of input
31
+ assert logits_coarse.shape[2] == 16 and logits_coarse.shape[3] == 16
32
+ assert cond.shape[2] == 16 and cond.shape[3] == 16
33
+
34
+ logits_fine = model.forward_fine(x)
35
+ assert logits_fine.shape == logits_coarse.shape