Upload 5 files
Browse files- README.md +31 -13
- app.py +104 -0
- model.py +62 -0
- requirements.txt +7 -0
- utils.py +54 -0
README.md
CHANGED
|
@@ -1,13 +1,31 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# PAIR-inspired Delivery Timing Predictor (Gradio Space)
|
| 2 |
+
|
| 3 |
+
Important
|
| 4 |
+
This demo is inspired by the PAIR study (Perinatal artificial intelligence in ultrasound) but does not include the proprietary model or the private clinical dataset described in the paper. It is provided only as a technical scaffold and demonstration UI.
|
| 5 |
+
|
| 6 |
+
What this Space does
|
| 7 |
+
1. Lets you upload ultrasound images (PNG/JPG) or DICOM files and returns:
|
| 8 |
+
- Predicted days-until-delivery (regression).
|
| 9 |
+
- Preterm probability and label (binary classification with threshold 0.5).
|
| 10 |
+
2. Aggregates predictions across multiple images from the same exam using a simple mean.
|
| 11 |
+
3. Can load your own PyTorch weights via a Hugging Face repo id or a path to a .pt file.
|
| 12 |
+
|
| 13 |
+
How to use in your own Space
|
| 14 |
+
1) Create a new Space on Hugging Face with SDK set to Gradio.
|
| 15 |
+
2) Upload the files in this repository.
|
| 16 |
+
3) Optional: place your model weights at weights/pair_v4.pt or set the env var HF_WEIGHTS to point to a Hugging Face model repo or a local .pt file.
|
| 17 |
+
4) Click Run. If weights are missing, the app falls back to a constant baseline just to demonstrate the UI.
|
| 18 |
+
|
| 19 |
+
Model input
|
| 20 |
+
- One or more 2D ultrasound images or DICOM frames from one exam. The app will convert grayscale to 3-channels when needed.
|
| 21 |
+
|
| 22 |
+
Model output
|
| 23 |
+
- days_to_delivery: float in [1, 300] (clamped).
|
| 24 |
+
- preterm_proba: float in [0, 1].
|
| 25 |
+
- preterm_label: Term if proba < 0.5, Preterm otherwise.
|
| 26 |
+
- predicted_date: today + days_to_delivery (for demo; in clinical use you would provide the scan date).
|
| 27 |
+
|
| 28 |
+
Notes
|
| 29 |
+
- This is not a medical device and is not for clinical use.
|
| 30 |
+
- Performance will be meaningless without appropriate training on a suitable dataset; this scaffold is for integration and UI only.
|
| 31 |
+
- To replicate the paper, you would need authorized access to a comparable dataset and a trained model.
|
app.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import tempfile
|
| 3 |
+
import zipfile
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import List, Tuple
|
| 6 |
+
|
| 7 |
+
import gradio as gr
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
from model import PairNet, load_weights_if_any
|
| 12 |
+
from utils import load_exam_as_batch, aggregate_predictions, clamp_days, today_plus_days
|
| 13 |
+
|
| 14 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 15 |
+
|
| 16 |
+
def init_model() -> Tuple[torch.nn.Module, str, bool]:
|
| 17 |
+
model = PairNet(pretrained=True).to(DEVICE)
|
| 18 |
+
model.eval()
|
| 19 |
+
weights_hint = os.getenv("HF_WEIGHTS", "").strip()
|
| 20 |
+
ok = False
|
| 21 |
+
msg = "Running in baseline mode (no weights). Output is a UI demo only."
|
| 22 |
+
if weights_hint:
|
| 23 |
+
ok, msg = load_weights_if_any(model, weights_hint)
|
| 24 |
+
elif Path("weights/pair_v4.pt").exists():
|
| 25 |
+
ok, msg = load_weights_if_any(model, "weights/pair_v4.pt")
|
| 26 |
+
return model, msg, ok
|
| 27 |
+
|
| 28 |
+
MODEL, LOAD_MSG, HAS_WEIGHTS = init_model()
|
| 29 |
+
|
| 30 |
+
def predict_on_files(files: List[gr.File]) -> dict:
|
| 31 |
+
# Collect file paths (also support zip of images)
|
| 32 |
+
paths: List[str] = []
|
| 33 |
+
for f in files or []:
|
| 34 |
+
p = Path(f.name)
|
| 35 |
+
if p.suffix.lower() == ".zip":
|
| 36 |
+
with zipfile.ZipFile(p, "r") as z:
|
| 37 |
+
with tempfile.TemporaryDirectory() as td:
|
| 38 |
+
z.extractall(td)
|
| 39 |
+
for ext in (".png", ".jpg", ".jpeg", ".dcm"):
|
| 40 |
+
paths.extend([str(q) for q in Path(td).rglob(f"*{ext}")])
|
| 41 |
+
else:
|
| 42 |
+
paths.append(str(p))
|
| 43 |
+
|
| 44 |
+
if not paths:
|
| 45 |
+
return {"status": "no files received"}
|
| 46 |
+
|
| 47 |
+
x = load_exam_as_batch(paths).to(DEVICE)
|
| 48 |
+
x = x.float()
|
| 49 |
+
|
| 50 |
+
if HAS_WEIGHTS:
|
| 51 |
+
with torch.no_grad():
|
| 52 |
+
days_raw, logits = MODEL(x)
|
| 53 |
+
# Aggregate across frames by mean
|
| 54 |
+
days = days_raw.squeeze(-1).detach().cpu().tolist()
|
| 55 |
+
proba = torch.sigmoid(logits.squeeze(-1)).detach().cpu().tolist()
|
| 56 |
+
else:
|
| 57 |
+
# Baseline demo: constant mid-gestation-ish guess and neutral probability
|
| 58 |
+
days = [150.0 for _ in range(x.shape[0])]
|
| 59 |
+
proba = [0.5 for _ in range(x.shape[0])]
|
| 60 |
+
|
| 61 |
+
days = [clamp_days(float(d)) for d in days]
|
| 62 |
+
preterm = ["Preterm" if p >= 0.5 else "Term" for p in proba]
|
| 63 |
+
|
| 64 |
+
days_mean, proba_mean = aggregate_predictions(days, proba)
|
| 65 |
+
|
| 66 |
+
result = {
|
| 67 |
+
"frames": len(paths),
|
| 68 |
+
"per_frame_days": days,
|
| 69 |
+
"per_frame_preterm_proba": proba,
|
| 70 |
+
"per_frame_preterm_label": preterm,
|
| 71 |
+
"aggregate_days_mean": days_mean,
|
| 72 |
+
"aggregate_predicted_date": today_plus_days(days_mean),
|
| 73 |
+
"aggregate_preterm_proba": proba_mean,
|
| 74 |
+
"aggregate_preterm_label": "Preterm" if proba_mean >= 0.5 else "Term",
|
| 75 |
+
"weights_message": LOAD_MSG
|
| 76 |
+
}
|
| 77 |
+
return result
|
| 78 |
+
|
| 79 |
+
with gr.Blocks(title="PAIR-inspired Delivery Timing Predictor") as demo:
|
| 80 |
+
gr.Markdown(
|
| 81 |
+
"PAIR-inspired Delivery Timing Predictor
|
| 82 |
+
|
| 83 |
+
"
|
| 84 |
+
"This app is a technical scaffold inspired by the PAIR study. "
|
| 85 |
+
"It does not include the proprietary model or clinical dataset. "
|
| 86 |
+
"Not for medical use."
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
with gr.Row():
|
| 90 |
+
with gr.Column():
|
| 91 |
+
in_files = gr.Files(label="Upload ultrasound images or a ZIP (PNG/JPG/DICOM)", file_count="multiple", type="filepath")
|
| 92 |
+
run_btn = gr.Button("Run prediction", variant="primary")
|
| 93 |
+
with gr.Column():
|
| 94 |
+
status = gr.JSON(label="Outputs")
|
| 95 |
+
|
| 96 |
+
note = gr.Markdown(f"Model status: {LOAD_MSG}")
|
| 97 |
+
|
| 98 |
+
def _run(files):
|
| 99 |
+
return predict_on_files(files)
|
| 100 |
+
|
| 101 |
+
run_btn.click(_run, inputs=[in_files], outputs=[status])
|
| 102 |
+
|
| 103 |
+
if __name__ == "__main__":
|
| 104 |
+
demo.launch()
|
model.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torchvision import models
|
| 4 |
+
|
| 5 |
+
class PairNet(nn.Module):
|
| 6 |
+
"""
|
| 7 |
+
A lightweight backbone + dual-head network:
|
| 8 |
+
- Regression head for days-to-delivery
|
| 9 |
+
- Classification head for preterm probability
|
| 10 |
+
This is a scaffold and not the proprietary model from the paper.
|
| 11 |
+
"""
|
| 12 |
+
def __init__(self, backbone_name: str = "efficientnet_b0", pretrained: bool = True):
|
| 13 |
+
super().__init__()
|
| 14 |
+
if backbone_name == "efficientnet_b0":
|
| 15 |
+
try:
|
| 16 |
+
weights = models.EfficientNet_B0_Weights.DEFAULT if pretrained else None
|
| 17 |
+
except Exception:
|
| 18 |
+
weights = None
|
| 19 |
+
backbone = models.efficientnet_b0(weights=weights)
|
| 20 |
+
in_feats = backbone.classifier[1].in_features
|
| 21 |
+
backbone.classifier = nn.Identity()
|
| 22 |
+
else:
|
| 23 |
+
# Fallback to resnet18
|
| 24 |
+
try:
|
| 25 |
+
weights = models.ResNet18_Weights.DEFAULT if pretrained else None
|
| 26 |
+
except Exception:
|
| 27 |
+
weights = None
|
| 28 |
+
backbone = models.resnet18(weights=weights)
|
| 29 |
+
in_feats = backbone.fc.in_features
|
| 30 |
+
backbone.fc = nn.Identity()
|
| 31 |
+
|
| 32 |
+
self.backbone = backbone
|
| 33 |
+
self.reg_head = nn.Linear(in_feats, 1)
|
| 34 |
+
self.cls_head = nn.Linear(in_feats, 1)
|
| 35 |
+
|
| 36 |
+
def forward(self, x):
|
| 37 |
+
feats = self.backbone(x)
|
| 38 |
+
days = self.reg_head(feats) # unconstrained
|
| 39 |
+
logits = self.cls_head(feats) # unconstrained
|
| 40 |
+
return days, logits
|
| 41 |
+
|
| 42 |
+
def load_weights_if_any(model: nn.Module, weights_path: str | None):
|
| 43 |
+
if not weights_path:
|
| 44 |
+
return False, "No weights path provided"
|
| 45 |
+
import os
|
| 46 |
+
if os.path.isfile(weights_path):
|
| 47 |
+
state = torch.load(weights_path, map_location="cpu")
|
| 48 |
+
if "state_dict" in state:
|
| 49 |
+
state = state["state_dict"]
|
| 50 |
+
missing, unexpected = model.load_state_dict(state, strict=False)
|
| 51 |
+
return True, f"Loaded local weights. missing={len(missing)} unexpected={len(unexpected)}"
|
| 52 |
+
# Try huggingface hub repo id
|
| 53 |
+
try:
|
| 54 |
+
from huggingface_hub import hf_hub_download
|
| 55 |
+
fp = hf_hub_download(repo_id=weights_path, filename="pytorch_model.bin", local_dir="weights")
|
| 56 |
+
state = torch.load(fp, map_location="cpu")
|
| 57 |
+
if "state_dict" in state:
|
| 58 |
+
state = state["state_dict"]
|
| 59 |
+
missing, unexpected = model.load_state_dict(state, strict=False)
|
| 60 |
+
return True, f"Loaded HF weights from {weights_path}. missing={len(missing)} unexpected={len(unexpected)}"
|
| 61 |
+
except Exception as e:
|
| 62 |
+
return False, f"Failed to load weights: {e}"
|
requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio>=4.44.0
|
| 2 |
+
torch>=2.2.0
|
| 3 |
+
torchvision>=0.17.0
|
| 4 |
+
pillow>=10.0.0
|
| 5 |
+
numpy>=1.26.0
|
| 6 |
+
pydicom>=2.4.4
|
| 7 |
+
huggingface_hub>=0.23.0
|
utils.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
import os
|
| 3 |
+
import math
|
| 4 |
+
import datetime as dt
|
| 5 |
+
from typing import List, Tuple
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import torch
|
| 10 |
+
import pydicom
|
| 11 |
+
|
| 12 |
+
IMG_SIZE = 224
|
| 13 |
+
|
| 14 |
+
def _read_image(file_path: str) -> Image.Image:
|
| 15 |
+
ext = os.path.splitext(file_path)[1].lower()
|
| 16 |
+
if ext == ".dcm":
|
| 17 |
+
ds = pydicom.dcmread(file_path)
|
| 18 |
+
arr = ds.pixel_array.astype(np.float32)
|
| 19 |
+
# normalize to 0..255
|
| 20 |
+
arr = arr - arr.min()
|
| 21 |
+
if arr.max() > 0:
|
| 22 |
+
arr = arr / arr.max()
|
| 23 |
+
arr = (arr * 255.0).clip(0,255).astype(np.uint8)
|
| 24 |
+
return Image.fromarray(arr)
|
| 25 |
+
else:
|
| 26 |
+
return Image.open(file_path).convert("L") # grayscale
|
| 27 |
+
|
| 28 |
+
def _to_tensor(img: Image.Image) -> torch.Tensor:
|
| 29 |
+
# Resize, center-crop/pad to square, stack to 3 channels, normalize 0..1
|
| 30 |
+
img = img.resize((IMG_SIZE, IMG_SIZE))
|
| 31 |
+
arr = np.array(img).astype(np.float32) / 255.0
|
| 32 |
+
if arr.ndim == 2:
|
| 33 |
+
arr = np.stack([arr, arr, arr], axis=0) # 3xHxW
|
| 34 |
+
elif arr.ndim == 3:
|
| 35 |
+
arr = arr.transpose(2, 0, 1) # HWC -> CHW
|
| 36 |
+
return torch.from_numpy(arr)
|
| 37 |
+
|
| 38 |
+
def load_exam_as_batch(file_paths: List[str]) -> torch.Tensor:
|
| 39 |
+
imgs = [_to_tensor(_read_image(p)) for p in file_paths]
|
| 40 |
+
x = torch.stack(imgs, dim=0) # Nx3xHxW
|
| 41 |
+
return x
|
| 42 |
+
|
| 43 |
+
def aggregate_predictions(days_list: List[float], proba_list: List[float]) -> Tuple[float, float]:
|
| 44 |
+
if len(days_list) == 0:
|
| 45 |
+
return 0.0, 0.0
|
| 46 |
+
return float(np.mean(days_list)), float(np.mean(proba_list))
|
| 47 |
+
|
| 48 |
+
def clamp_days(d: float) -> float:
|
| 49 |
+
return float(max(1.0, min(300.0, d)))
|
| 50 |
+
|
| 51 |
+
def today_plus_days(days: float) -> str:
|
| 52 |
+
base = dt.date.today()
|
| 53 |
+
target = base + dt.timedelta(days=int(round(days)))
|
| 54 |
+
return target.isoformat()
|