fantos commited on
Commit
2a034f9
·
verified ·
1 Parent(s): 45561c1

Upload 5 files

Browse files
Files changed (5) hide show
  1. README.md +31 -13
  2. app.py +104 -0
  3. model.py +62 -0
  4. requirements.txt +7 -0
  5. utils.py +54 -0
README.md CHANGED
@@ -1,13 +1,31 @@
1
- ---
2
- title: PAIR
3
- emoji: 🦀
4
- colorFrom: pink
5
- colorTo: green
6
- sdk: gradio
7
- sdk_version: 5.42.0
8
- app_file: app.py
9
- pinned: false
10
- short_description: PAIR
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PAIR-inspired Delivery Timing Predictor (Gradio Space)
2
+
3
+ Important
4
+ This demo is inspired by the PAIR study (Perinatal artificial intelligence in ultrasound) but does not include the proprietary model or the private clinical dataset described in the paper. It is provided only as a technical scaffold and demonstration UI.
5
+
6
+ What this Space does
7
+ 1. Lets you upload ultrasound images (PNG/JPG) or DICOM files and returns:
8
+ - Predicted days-until-delivery (regression).
9
+ - Preterm probability and label (binary classification with threshold 0.5).
10
+ 2. Aggregates predictions across multiple images from the same exam using a simple mean.
11
+ 3. Can load your own PyTorch weights via a Hugging Face repo id or a path to a .pt file.
12
+
13
+ How to use in your own Space
14
+ 1) Create a new Space on Hugging Face with SDK set to Gradio.
15
+ 2) Upload the files in this repository.
16
+ 3) Optional: place your model weights at weights/pair_v4.pt or set the env var HF_WEIGHTS to point to a Hugging Face model repo or a local .pt file.
17
+ 4) Click Run. If weights are missing, the app falls back to a constant baseline just to demonstrate the UI.
18
+
19
+ Model input
20
+ - One or more 2D ultrasound images or DICOM frames from one exam. The app will convert grayscale to 3-channels when needed.
21
+
22
+ Model output
23
+ - days_to_delivery: float in [1, 300] (clamped).
24
+ - preterm_proba: float in [0, 1].
25
+ - preterm_label: Term if proba < 0.5, Preterm otherwise.
26
+ - predicted_date: today + days_to_delivery (for demo; in clinical use you would provide the scan date).
27
+
28
+ Notes
29
+ - This is not a medical device and is not for clinical use.
30
+ - Performance will be meaningless without appropriate training on a suitable dataset; this scaffold is for integration and UI only.
31
+ - To replicate the paper, you would need authorized access to a comparable dataset and a trained model.
app.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ import zipfile
4
+ from pathlib import Path
5
+ from typing import List, Tuple
6
+
7
+ import gradio as gr
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+ from model import PairNet, load_weights_if_any
12
+ from utils import load_exam_as_batch, aggregate_predictions, clamp_days, today_plus_days
13
+
14
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
15
+
16
+ def init_model() -> Tuple[torch.nn.Module, str, bool]:
17
+ model = PairNet(pretrained=True).to(DEVICE)
18
+ model.eval()
19
+ weights_hint = os.getenv("HF_WEIGHTS", "").strip()
20
+ ok = False
21
+ msg = "Running in baseline mode (no weights). Output is a UI demo only."
22
+ if weights_hint:
23
+ ok, msg = load_weights_if_any(model, weights_hint)
24
+ elif Path("weights/pair_v4.pt").exists():
25
+ ok, msg = load_weights_if_any(model, "weights/pair_v4.pt")
26
+ return model, msg, ok
27
+
28
+ MODEL, LOAD_MSG, HAS_WEIGHTS = init_model()
29
+
30
+ def predict_on_files(files: List[gr.File]) -> dict:
31
+ # Collect file paths (also support zip of images)
32
+ paths: List[str] = []
33
+ for f in files or []:
34
+ p = Path(f.name)
35
+ if p.suffix.lower() == ".zip":
36
+ with zipfile.ZipFile(p, "r") as z:
37
+ with tempfile.TemporaryDirectory() as td:
38
+ z.extractall(td)
39
+ for ext in (".png", ".jpg", ".jpeg", ".dcm"):
40
+ paths.extend([str(q) for q in Path(td).rglob(f"*{ext}")])
41
+ else:
42
+ paths.append(str(p))
43
+
44
+ if not paths:
45
+ return {"status": "no files received"}
46
+
47
+ x = load_exam_as_batch(paths).to(DEVICE)
48
+ x = x.float()
49
+
50
+ if HAS_WEIGHTS:
51
+ with torch.no_grad():
52
+ days_raw, logits = MODEL(x)
53
+ # Aggregate across frames by mean
54
+ days = days_raw.squeeze(-1).detach().cpu().tolist()
55
+ proba = torch.sigmoid(logits.squeeze(-1)).detach().cpu().tolist()
56
+ else:
57
+ # Baseline demo: constant mid-gestation-ish guess and neutral probability
58
+ days = [150.0 for _ in range(x.shape[0])]
59
+ proba = [0.5 for _ in range(x.shape[0])]
60
+
61
+ days = [clamp_days(float(d)) for d in days]
62
+ preterm = ["Preterm" if p >= 0.5 else "Term" for p in proba]
63
+
64
+ days_mean, proba_mean = aggregate_predictions(days, proba)
65
+
66
+ result = {
67
+ "frames": len(paths),
68
+ "per_frame_days": days,
69
+ "per_frame_preterm_proba": proba,
70
+ "per_frame_preterm_label": preterm,
71
+ "aggregate_days_mean": days_mean,
72
+ "aggregate_predicted_date": today_plus_days(days_mean),
73
+ "aggregate_preterm_proba": proba_mean,
74
+ "aggregate_preterm_label": "Preterm" if proba_mean >= 0.5 else "Term",
75
+ "weights_message": LOAD_MSG
76
+ }
77
+ return result
78
+
79
+ with gr.Blocks(title="PAIR-inspired Delivery Timing Predictor") as demo:
80
+ gr.Markdown(
81
+ "PAIR-inspired Delivery Timing Predictor
82
+
83
+ "
84
+ "This app is a technical scaffold inspired by the PAIR study. "
85
+ "It does not include the proprietary model or clinical dataset. "
86
+ "Not for medical use."
87
+ )
88
+
89
+ with gr.Row():
90
+ with gr.Column():
91
+ in_files = gr.Files(label="Upload ultrasound images or a ZIP (PNG/JPG/DICOM)", file_count="multiple", type="filepath")
92
+ run_btn = gr.Button("Run prediction", variant="primary")
93
+ with gr.Column():
94
+ status = gr.JSON(label="Outputs")
95
+
96
+ note = gr.Markdown(f"Model status: {LOAD_MSG}")
97
+
98
+ def _run(files):
99
+ return predict_on_files(files)
100
+
101
+ run_btn.click(_run, inputs=[in_files], outputs=[status])
102
+
103
+ if __name__ == "__main__":
104
+ demo.launch()
model.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import models
4
+
5
+ class PairNet(nn.Module):
6
+ """
7
+ A lightweight backbone + dual-head network:
8
+ - Regression head for days-to-delivery
9
+ - Classification head for preterm probability
10
+ This is a scaffold and not the proprietary model from the paper.
11
+ """
12
+ def __init__(self, backbone_name: str = "efficientnet_b0", pretrained: bool = True):
13
+ super().__init__()
14
+ if backbone_name == "efficientnet_b0":
15
+ try:
16
+ weights = models.EfficientNet_B0_Weights.DEFAULT if pretrained else None
17
+ except Exception:
18
+ weights = None
19
+ backbone = models.efficientnet_b0(weights=weights)
20
+ in_feats = backbone.classifier[1].in_features
21
+ backbone.classifier = nn.Identity()
22
+ else:
23
+ # Fallback to resnet18
24
+ try:
25
+ weights = models.ResNet18_Weights.DEFAULT if pretrained else None
26
+ except Exception:
27
+ weights = None
28
+ backbone = models.resnet18(weights=weights)
29
+ in_feats = backbone.fc.in_features
30
+ backbone.fc = nn.Identity()
31
+
32
+ self.backbone = backbone
33
+ self.reg_head = nn.Linear(in_feats, 1)
34
+ self.cls_head = nn.Linear(in_feats, 1)
35
+
36
+ def forward(self, x):
37
+ feats = self.backbone(x)
38
+ days = self.reg_head(feats) # unconstrained
39
+ logits = self.cls_head(feats) # unconstrained
40
+ return days, logits
41
+
42
+ def load_weights_if_any(model: nn.Module, weights_path: str | None):
43
+ if not weights_path:
44
+ return False, "No weights path provided"
45
+ import os
46
+ if os.path.isfile(weights_path):
47
+ state = torch.load(weights_path, map_location="cpu")
48
+ if "state_dict" in state:
49
+ state = state["state_dict"]
50
+ missing, unexpected = model.load_state_dict(state, strict=False)
51
+ return True, f"Loaded local weights. missing={len(missing)} unexpected={len(unexpected)}"
52
+ # Try huggingface hub repo id
53
+ try:
54
+ from huggingface_hub import hf_hub_download
55
+ fp = hf_hub_download(repo_id=weights_path, filename="pytorch_model.bin", local_dir="weights")
56
+ state = torch.load(fp, map_location="cpu")
57
+ if "state_dict" in state:
58
+ state = state["state_dict"]
59
+ missing, unexpected = model.load_state_dict(state, strict=False)
60
+ return True, f"Loaded HF weights from {weights_path}. missing={len(missing)} unexpected={len(unexpected)}"
61
+ except Exception as e:
62
+ return False, f"Failed to load weights: {e}"
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio>=4.44.0
2
+ torch>=2.2.0
3
+ torchvision>=0.17.0
4
+ pillow>=10.0.0
5
+ numpy>=1.26.0
6
+ pydicom>=2.4.4
7
+ huggingface_hub>=0.23.0
utils.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import math
4
+ import datetime as dt
5
+ from typing import List, Tuple
6
+
7
+ import numpy as np
8
+ from PIL import Image
9
+ import torch
10
+ import pydicom
11
+
12
+ IMG_SIZE = 224
13
+
14
+ def _read_image(file_path: str) -> Image.Image:
15
+ ext = os.path.splitext(file_path)[1].lower()
16
+ if ext == ".dcm":
17
+ ds = pydicom.dcmread(file_path)
18
+ arr = ds.pixel_array.astype(np.float32)
19
+ # normalize to 0..255
20
+ arr = arr - arr.min()
21
+ if arr.max() > 0:
22
+ arr = arr / arr.max()
23
+ arr = (arr * 255.0).clip(0,255).astype(np.uint8)
24
+ return Image.fromarray(arr)
25
+ else:
26
+ return Image.open(file_path).convert("L") # grayscale
27
+
28
+ def _to_tensor(img: Image.Image) -> torch.Tensor:
29
+ # Resize, center-crop/pad to square, stack to 3 channels, normalize 0..1
30
+ img = img.resize((IMG_SIZE, IMG_SIZE))
31
+ arr = np.array(img).astype(np.float32) / 255.0
32
+ if arr.ndim == 2:
33
+ arr = np.stack([arr, arr, arr], axis=0) # 3xHxW
34
+ elif arr.ndim == 3:
35
+ arr = arr.transpose(2, 0, 1) # HWC -> CHW
36
+ return torch.from_numpy(arr)
37
+
38
+ def load_exam_as_batch(file_paths: List[str]) -> torch.Tensor:
39
+ imgs = [_to_tensor(_read_image(p)) for p in file_paths]
40
+ x = torch.stack(imgs, dim=0) # Nx3xHxW
41
+ return x
42
+
43
+ def aggregate_predictions(days_list: List[float], proba_list: List[float]) -> Tuple[float, float]:
44
+ if len(days_list) == 0:
45
+ return 0.0, 0.0
46
+ return float(np.mean(days_list)), float(np.mean(proba_list))
47
+
48
+ def clamp_days(d: float) -> float:
49
+ return float(max(1.0, min(300.0, d)))
50
+
51
+ def today_plus_days(days: float) -> str:
52
+ base = dt.date.today()
53
+ target = base + dt.timedelta(days=int(round(days)))
54
+ return target.isoformat()