PAIR / model.py
fantos's picture
Upload 5 files
2a034f9 verified
import torch
import torch.nn as nn
from torchvision import models
class PairNet(nn.Module):
    """
    Shared torchvision backbone feeding two single-output linear heads:
    - Regression head for days-to-delivery
    - Classification head for preterm probability

    Both heads emit raw, unactivated values. Passing "efficientnet_b0" as
    ``backbone_name`` selects EfficientNet-B0; any other name selects
    ResNet-18.

    This is a scaffold and not the proprietary model from the paper.
    """

    def __init__(self, backbone_name: str = "efficientnet_b0", pretrained: bool = True):
        super().__init__()
        use_efficientnet = backbone_name == "efficientnet_b0"

        # Resolve the default pretrained weights for the chosen architecture.
        # If the lookup fails for any reason, build the backbone without
        # pretrained weights instead of aborting construction.
        pretrained_weights = None
        if pretrained:
            try:
                if use_efficientnet:
                    pretrained_weights = models.EfficientNet_B0_Weights.DEFAULT
                else:
                    pretrained_weights = models.ResNet18_Weights.DEFAULT
            except Exception:
                pretrained_weights = None

        # Build the backbone and replace its own classifier with a
        # pass-through so it returns the pooled feature vector.
        if use_efficientnet:
            net = models.efficientnet_b0(weights=pretrained_weights)
            feature_dim = net.classifier[1].in_features
            net.classifier = nn.Identity()
        else:
            # Fallback to resnet18
            net = models.resnet18(weights=pretrained_weights)
            feature_dim = net.fc.in_features
            net.fc = nn.Identity()

        self.backbone = net
        self.reg_head = nn.Linear(feature_dim, 1)
        self.cls_head = nn.Linear(feature_dim, 1)

    def forward(self, x):
        """Return ``(days, logits)``; both are unconstrained head outputs."""
        shared = self.backbone(x)
        return self.reg_head(shared), self.cls_head(shared)
def load_weights_if_any(model: nn.Module, weights_path: str | None):
if not weights_path:
return False, "No weights path provided"
import os
if os.path.isfile(weights_path):
state = torch.load(weights_path, map_location="cpu")
if "state_dict" in state:
state = state["state_dict"]
missing, unexpected = model.load_state_dict(state, strict=False)
return True, f"Loaded local weights. missing={len(missing)} unexpected={len(unexpected)}"
# Try huggingface hub repo id
try:
from huggingface_hub import hf_hub_download
fp = hf_hub_download(repo_id=weights_path, filename="pytorch_model.bin", local_dir="weights")
state = torch.load(fp, map_location="cpu")
if "state_dict" in state:
state = state["state_dict"]
missing, unexpected = model.load_state_dict(state, strict=False)
return True, f"Loaded HF weights from {weights_path}. missing={len(missing)} unexpected={len(unexpected)}"
except Exception as e:
return False, f"Failed to load weights: {e}"