|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torchvision import models |
|
|
|
|
|
class PairNet(nn.Module):
    """
    A lightweight backbone + dual-head network:

    - Regression head for days-to-delivery
    - Classification head for preterm probability

    This is a scaffold and not the proprietary model from the paper.

    Both heads are single-output linear layers that read the same backbone
    feature vector. ``forward`` returns the classification output as a raw
    logit; no sigmoid is applied inside the module.
    """

    def __init__(self, backbone_name: str = "efficientnet_b0", pretrained: bool = True):
        """Build the feature extractor and the two linear heads.

        Args:
            backbone_name: ``"efficientnet_b0"`` selects EfficientNet-B0.
                Any other value selects ResNet-18 (there is no validation
                of the name).
            pretrained: When truthy, request the torchvision ``DEFAULT``
                weights for the chosen backbone; otherwise pass
                ``weights=None`` to the torchvision constructor.
        """
        super().__init__()

        use_efficientnet = backbone_name == "efficientnet_b0"
        enum_name = "EfficientNet_B0_Weights" if use_efficientnet else "ResNet18_Weights"

        # If the weights enum cannot be resolved on `models`, fall back to
        # building the backbone with weights=None instead of failing here.
        try:
            weights = getattr(models, enum_name).DEFAULT if pretrained else None
        except Exception:
            weights = None

        # Swap the stock classification layer for an identity so the
        # backbone emits the feature vector that layer used to consume.
        if use_efficientnet:
            net = models.efficientnet_b0(weights=weights)
            feat_dim = net.classifier[1].in_features
            net.classifier = nn.Identity()
        else:
            net = models.resnet18(weights=weights)
            feat_dim = net.fc.in_features
            net.fc = nn.Identity()

        self.backbone = net
        self.reg_head = nn.Linear(feat_dim, 1)
        self.cls_head = nn.Linear(feat_dim, 1)

    def forward(self, x):
        """Return ``(days, logits)`` computed from shared backbone features."""
        shared = self.backbone(x)
        return self.reg_head(shared), self.cls_head(shared)
|
|
|
|
|
def load_weights_if_any(model: nn.Module, weights_path: str | None): |
|
|
if not weights_path: |
|
|
return False, "No weights path provided" |
|
|
import os |
|
|
if os.path.isfile(weights_path): |
|
|
state = torch.load(weights_path, map_location="cpu") |
|
|
if "state_dict" in state: |
|
|
state = state["state_dict"] |
|
|
missing, unexpected = model.load_state_dict(state, strict=False) |
|
|
return True, f"Loaded local weights. missing={len(missing)} unexpected={len(unexpected)}" |
|
|
|
|
|
try: |
|
|
from huggingface_hub import hf_hub_download |
|
|
fp = hf_hub_download(repo_id=weights_path, filename="pytorch_model.bin", local_dir="weights") |
|
|
state = torch.load(fp, map_location="cpu") |
|
|
if "state_dict" in state: |
|
|
state = state["state_dict"] |
|
|
missing, unexpected = model.load_state_dict(state, strict=False) |
|
|
return True, f"Loaded HF weights from {weights_path}. missing={len(missing)} unexpected={len(unexpected)}" |
|
|
except Exception as e: |
|
|
return False, f"Failed to load weights: {e}" |
|
|
|