Upload encoder.py with huggingface_hub
Browse files- encoder.py +26 -0
encoder.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from segmentation_models_pytorch.encoders import encoders
|
2 |
+
from segmentation_models_pytorch import Unet
|
3 |
+
import torch
|
4 |
+
|
5 |
+
# Override pretrained settings for your weights
|
6 |
+
encoders["dpn98"]["pretrained_settings"]["micronet"] = {
|
7 |
+
"url": "https://huggingface.co/jstuckner/microscopy-dpn98-micronet/resolve/main/dpn98_micronet_weights.pth",
|
8 |
+
"input_space": "RGB",
|
9 |
+
"input_range": [0, 1],
|
10 |
+
"mean": [0.485, 0.456, 0.406],
|
11 |
+
"std": [0.229, 0.224, 0.225],
|
12 |
+
}
|
13 |
+
|
14 |
+
# Use as normal
|
15 |
+
model = Unet(
|
16 |
+
encoder_name="dpn98",
|
17 |
+
encoder_weights="micronet",
|
18 |
+
classes=1,
|
19 |
+
activation=None,
|
20 |
+
)
|
21 |
+
|
22 |
+
# Test input
|
23 |
+
x = torch.randn(1, 3, 256, 256)
|
24 |
+
with torch.no_grad():
|
25 |
+
y = model(x)
|
26 |
+
print("Output shape:", y.shape)
|