jstuckner commited on
Commit
648ae9c
·
verified ·
1 Parent(s): dc0d59b

Upload encoder.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. encoder.py +26 -0
encoder.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from segmentation_models_pytorch.encoders import encoders
2
+ from segmentation_models_pytorch import Unet
3
+ import torch
4
+
5
+ # Override pretrained settings for your weights
6
+ encoders["dpn98"]["pretrained_settings"]["micronet"] = {
7
+ "url": "https://huggingface.co/jstuckner/microscopy-dpn98-micronet/resolve/main/dpn98_micronet_weights.pth",
8
+ "input_space": "RGB",
9
+ "input_range": [0, 1],
10
+ "mean": [0.485, 0.456, 0.406],
11
+ "std": [0.229, 0.224, 0.225],
12
+ }
13
+
14
+ # Use as normal
15
+ model = Unet(
16
+ encoder_name="dpn98",
17
+ encoder_weights="micronet",
18
+ classes=1,
19
+ activation=None,
20
+ )
21
+
22
+ # Test input
23
+ x = torch.randn(1, 3, 256, 256)
24
+ with torch.no_grad():
25
+ y = model(x)
26
+ print("Output shape:", y.shape)