from segmentation_models_pytorch.encoders import encoders
from segmentation_models_pytorch import Unet
import torch

# Register an additional pretrained-weights entry named "micronet" for the
# dpn98 encoder (this adds a new key; the existing entries are untouched), so
# that `encoder_weights="micronet"` below resolves to these weights.
# mean/std are the standard ImageNet normalization values.
encoders["dpn98"]["pretrained_settings"]["micronet"] = {
    "url": "https://huggingface.co/jstuckner/microscopy-dpn98-micronet/resolve/main/dpn98_micronet_weights.pth",
    "input_space": "RGB",
    "input_range": [0, 1],
    "mean": [0.485, 0.456, 0.406],
    "std": [0.229, 0.224, 0.225],
}

# Build the model; encoder_weights refers to the "micronet" entry registered above.
model = Unet(
    encoder_name="dpn98",
    encoder_weights="micronet",
    classes=1,
    activation=None,  # no output activation: the model returns raw logits
)
# torch.no_grad() only disables autograd; it does NOT switch BatchNorm/Dropout
# to inference behavior. Without eval(), BatchNorm would normalize with the
# statistics of this single-sample batch instead of its running statistics.
model.eval()

# Smoke test on a random input. The spatial size (256) is a multiple of 32,
# matching the downsampling factor of the default 5-stage encoder.
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    y = model(x)
print("Output shape:", y.shape)