import os
import warnings
from collections.abc import Mapping

import torch
import torch.nn as nn
from torchvision import models


def _pretrained_weights(enum_name: str, pretrained: bool):
    """Return ``models.<enum_name>.DEFAULT`` when *pretrained* is set, else None.

    If the installed torchvision does not expose the requested weights enum,
    None is returned and a warning is emitted so the caller knows the
    ``pretrained=True`` request could not be honoured.
    """
    if not pretrained:
        return None
    weights_enum = getattr(models, enum_name, None)
    default = getattr(weights_enum, "DEFAULT", None)
    if default is None:
        warnings.warn(
            f"torchvision.models has no {enum_name}.DEFAULT; "
            "building the backbone without pretrained weights",
            stacklevel=3,
        )
    return default


class PairNet(nn.Module):
    """
    A lightweight backbone + dual-head network:
    - Regression head for days-to-delivery
    - Classification head for preterm probability

    This is a scaffold and not the proprietary model from the paper.

    Args:
        backbone_name: ``"efficientnet_b0"`` selects EfficientNet-B0; any
            other value falls back to ResNet-18.
        pretrained: When True, request the torchvision ``DEFAULT`` weights
            for the chosen backbone.
    """

    def __init__(self, backbone_name: str = "efficientnet_b0", pretrained: bool = True):
        super().__init__()

        if backbone_name == "efficientnet_b0":
            backbone = models.efficientnet_b0(
                weights=_pretrained_weights("EfficientNet_B0_Weights", pretrained)
            )
            # classifier is indexed at [1] for its Linear layer; read the
            # feature width before replacing the whole classifier.
            in_feats = backbone.classifier[1].in_features
            backbone.classifier = nn.Identity()
        else:
            # Fallback to resnet18
            backbone = models.resnet18(
                weights=_pretrained_weights("ResNet18_Weights", pretrained)
            )
            in_feats = backbone.fc.in_features
            backbone.fc = nn.Identity()

        # Attribute names are part of the checkpoint format (state_dict keys).
        self.backbone = backbone
        self.reg_head = nn.Linear(in_feats, 1)
        self.cls_head = nn.Linear(in_feats, 1)

    def forward(self, x):
        """Run the backbone once and return ``(days, logits)``.

        Both outputs are the raw results of a linear layer; no activation
        or clamping is applied here.
        """
        feats = self.backbone(x)
        days = self.reg_head(feats)  # unconstrained
        logits = self.cls_head(feats)  # unconstrained
        return days, logits


def _load_checkpoint_into(model: nn.Module, checkpoint_path: str) -> tuple[int, int]:
    """Load a checkpoint file into *model* non-strictly.

    Returns the number of missing and unexpected keys reported by
    ``load_state_dict(strict=False)``.
    """
    # NOTE(security): torch.load unpickles the file. For checkpoints fetched
    # from an arbitrary Hub repo this can execute code on torch versions where
    # weights_only is not the default; consider weights_only=True once the
    # minimum supported torch version allows it.
    state = torch.load(checkpoint_path, map_location="cpu")
    # Some checkpoints wrap the tensors under a "state_dict" key. Only probe
    # for it on mappings so other objects fail in load_state_dict instead.
    if isinstance(state, Mapping) and "state_dict" in state:
        state = state["state_dict"]
    missing, unexpected = model.load_state_dict(state, strict=False)
    return len(missing), len(unexpected)


def load_weights_if_any(model: nn.Module, weights_path: str | None) -> tuple[bool, str]:
    """Try to load weights into *model* from a local file or a Hub repo id.

    Args:
        model: Module that receives the weights (loaded with ``strict=False``).
        weights_path: Path to a local checkpoint file, or otherwise a
            Hugging Face Hub repo id containing ``pytorch_model.bin``.
            A falsy value means "nothing to load".

    Returns:
        ``(loaded, message)``. ``loaded`` is False when no path was given or
        when loading failed; failures from both the local and the Hub path
        are reported through ``message`` rather than raised.
    """
    if not weights_path:
        return False, "No weights path provided"

    if os.path.isfile(weights_path):
        try:
            n_missing, n_unexpected = _load_checkpoint_into(model, weights_path)
        except Exception as e:  # boundary: report instead of crashing the caller
            return False, f"Failed to load weights: {e}"
        return True, f"Loaded local weights. missing={n_missing} unexpected={n_unexpected}"

    # Try huggingface hub repo id
    try:
        # Imported lazily so huggingface_hub stays an optional dependency.
        from huggingface_hub import hf_hub_download

        fp = hf_hub_download(
            repo_id=weights_path, filename="pytorch_model.bin", local_dir="weights"
        )
        n_missing, n_unexpected = _load_checkpoint_into(model, fp)
    except Exception as e:
        return False, f"Failed to load weights: {e}"
    return True, (
        f"Loaded HF weights from {weights_path}. "
        f"missing={n_missing} unexpected={n_unexpected}"
    )