# Source: Hugging Face Hub file page (uploader: jstuckner).
# Commit 648ae9c (verified): "Upload encoder.py with huggingface_hub"
from segmentation_models_pytorch.encoders import encoders
from segmentation_models_pytorch import Unet
import torch
# Register a custom "micronet" entry in the dpn98 encoder's pretrained
# settings so that `encoder_weights="micronet"` below resolves to the
# checkpoint at this URL. The remaining keys describe the input
# preprocessing associated with these weights.
encoders["dpn98"]["pretrained_settings"]["micronet"] = {
    "url": "https://huggingface.co/jstuckner/microscopy-dpn98-micronet/resolve/main/dpn98_micronet_weights.pth",
    "input_space": "RGB",
    "input_range": [0, 1],
    "mean": [0.485, 0.456, 0.406],
    "std": [0.229, 0.224, 0.225],
}

# Build the model as usual; the encoder weights are looked up under the
# "micronet" key registered above.
model = Unet(
    encoder_name="dpn98",
    encoder_weights="micronet",
    classes=1,
    activation=None,
)

# Smoke test: push one random 3-channel 256x256 image through the model
# and report the output shape.
x = torch.randn(1, 3, 256, 256)

# `torch.no_grad()` only disables gradient tracking. Layers such as
# BatchNorm and Dropout still behave as in training unless the module is
# in eval mode, and a training-mode forward pass on random data would
# update BatchNorm running statistics loaded from the checkpoint.
# Run the check in eval mode, then restore whatever mode the model had so
# later code sees `model` in the same mode as before.
_was_training = model.training
model.eval()
try:
    with torch.no_grad():
        y = model(x)
finally:
    model.train(_was_training)

print("Output shape:", y.shape)