MIRAI: Mammography-based Risk Assessment with AI
Model Description
MIRAI is a deep learning model for breast cancer risk prediction from mammography images. Developed by researchers at MIT and Massachusetts General Hospital, it predicts breast cancer risk at five time points (1 to 5 years) from the four standard mammography views and optional clinical risk factors.
The model has been validated across diverse populations and imaging devices, demonstrating robust performance across different demographics and technical settings.
Key Features
- Multi-time-point prediction: Provides risk assessments for years 1-5
- Robust across populations: Validated on diverse ethnicities and age groups
- Device agnostic: Works with mammograms from different manufacturers
- Clinical integration ready: Designed for real-world deployment
Model Architecture
MIRAI uses a two-stage architecture:
Image Encoder (ResNet-based)
- Processes individual mammogram views (L-CC, L-MLO, R-CC, R-MLO)
- Input: 1664×2048 pixel images
- Output: 2048-dimensional feature vectors
Transformer + Risk Factor Module
- Aggregates features from multiple views
- Incorporates 34 clinical risk factors (optional)
- Outputs: cancer risk predictions for years 1-5
Installation
From Hugging Face Hub (Recommended)
pip install torch torchvision numpy pillow huggingface-hub transformers opencv-python
Additional Dependencies (Optional)
pip install pandas scikit-learn pydicom # For data processing and DICOM support
Quick Start
Method 1: Using Hugging Face Hub (Recommended)
import torch
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download
import sys
# Download model from Hugging Face Hub
model_dir = snapshot_download(repo_id="Lab-Rasool/Mirai")
sys.path.insert(0, model_dir)
# Import model components
from modeling_mirai import MiraiModel
from configuration_mirai import MiraiConfig
from preprocessor import MiraiPreprocessor
# Load model and configuration
config = MiraiConfig.from_pretrained(model_dir)
model = MiraiModel.from_pretrained(model_dir, config=config)
model.eval()
# Initialize preprocessor
preprocessor = MiraiPreprocessor()
# Load mammogram images (4 standard views required)
image_paths = {
    'L-CC': 'path/to/left_cc.png',
    'L-MLO': 'path/to/left_mlo.png',
    'R-CC': 'path/to/right_cc.png',
    'R-MLO': 'path/to/right_mlo.png'
}
# Preprocess images
exam_tensor = preprocessor.load_mammogram_exam(image_paths)
# Prepare clinical risk factors (optional but recommended)
risk_factors = {
    'age': 55,
    'density': 2,  # BI-RADS density (1-4)
    'family_history': False,
    'biopsy_benign': False,
    'biopsy_lcis': False,
    'biopsy_atypical': False,
    'menarche_age': 13,
    'menopause_age': 0,  # 0 if pre-menopausal
    'first_pregnancy_age': 28,
    'race': 1,
    'weight': 70,  # kg
    'height': 165,  # cm
    'parous': True,
    'menopausal_status': 0  # 0=pre, 1=post
}
risk_factors_tensor = preprocessor.prepare_risk_factors(risk_factors)
# Prepare batch - transpose to [views, channels, height, width]
exam_tensor = exam_tensor.permute(1, 0, 2, 3) # From [C, V, H, W] to [V, C, H, W]
batch_images = exam_tensor.unsqueeze(0)
batch_risk_factors = risk_factors_tensor.unsqueeze(0)
# Create metadata for the batch
batch_metadata = {
    'time_seq': torch.zeros(1, 4).long(),
    'view_seq': torch.tensor([[0, 1, 0, 1]]),  # CC, MLO, CC, MLO
    'side_seq': torch.tensor([[0, 0, 1, 1]]),  # L, L, R, R
}
# Run inference
with torch.no_grad():
    outputs = model(
        images=batch_images,
        risk_factors=batch_risk_factors,
        batch_metadata=batch_metadata,
        return_dict=True
    )
# Extract probabilities
if hasattr(outputs, 'probabilities'):
    probabilities = outputs.probabilities[0].numpy()
else:
    probabilities = torch.sigmoid(outputs[0][0]).numpy()
# Display results
for year in range(len(probabilities)):
    risk_pct = probabilities[year] * 100
    print(f"Year {year + 1} risk: {risk_pct:.2f}%")
# Risk assessment for last available year
if len(probabilities) >= 5:
    five_year_risk = probabilities[4] * 100
    print(f"\n5-Year Cumulative Risk: {five_year_risk:.2f}%")
elif len(probabilities) > 0:
    last_year = len(probabilities)
    last_risk = probabilities[-1] * 100
    print(f"\n{last_year}-Year Cumulative Risk: {last_risk:.2f}%")
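The `view_seq` and `side_seq` entries above follow the encoding given in their comments (CC=0, MLO=1; L=0, R=1). As a minimal sketch, assuming that encoding and view names of the form `<side>-<view>` as used in `image_paths`, the two sequences can be derived from the view names rather than typed by hand. `build_view_metadata` is a hypothetical helper, not part of the MIRAI package; it returns plain lists that can then be wrapped as `torch.tensor([view_seq])`.

```python
# Hypothetical helper (not part of the MIRAI package): derive the view and
# side index sequences from view names such as 'L-CC' or 'R-MLO'.
VIEW_CODES = {"CC": 0, "MLO": 1}  # encoding taken from the comments above
SIDE_CODES = {"L": 0, "R": 1}


def build_view_metadata(view_names):
    """Return (view_seq, side_seq) as lists of ints, one entry per view."""
    view_seq, side_seq = [], []
    for name in view_names:
        side, _, view = name.partition("-")
        if side not in SIDE_CODES or view not in VIEW_CODES:
            raise ValueError(f"Unrecognized view name: {name!r}")
        view_seq.append(VIEW_CODES[view])
        side_seq.append(SIDE_CODES[side])
    return view_seq, side_seq


view_seq, side_seq = build_view_metadata(["L-CC", "L-MLO", "R-CC", "R-MLO"])
print(view_seq)  # [0, 1, 0, 1]
print(side_seq)  # [0, 0, 1, 1]
```

Deriving both sequences from the same list of names keeps them aligned with the order in which the images are stacked.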
Method 2: Creating Sample Test Data
import numpy as np
from PIL import Image
import os

def create_sample_mammogram_images(output_dir="sample_mammograms"):
    """Create sample mammogram images for testing."""
    os.makedirs(output_dir, exist_ok=True)
    views = ['L-CC', 'L-MLO', 'R-CC', 'R-MLO']
    image_paths = {}
    # Row and column index grids, shared by all views
    rows, cols = np.ogrid[:1664, :2048]
    for seed, view in enumerate(views):
        # Seed from the view index: hash() of a str changes between Python runs
        np.random.seed(seed)
        # Add tissue-like patterns
        background = np.random.normal(8000, 2000, (1664, 2048))
        # Add dense regions
        for _ in range(5):
            center_row = np.random.randint(200, 1400)
            center_col = np.random.randint(200, 1800)
            size = np.random.randint(100, 300)
            density = np.random.normal(12000, 1000)
            dist_sq = (rows - center_row)**2 + (cols - center_col)**2
            mask = dist_sq <= size**2
            background[mask] += density * np.exp(-dist_sq / (2 * size**2))[mask]
        # Add chest wall edge ('L' in view would also match 'R-MLO')
        if view.startswith('L'):
            background[:, :200] += np.linspace(5000, 0, 200)
        else:
            background[:, -200:] += np.linspace(0, 5000, 200)
        # Scale 16-bit values down to 8 bits and save as PNG
        image = np.clip(background, 0, 65535).astype(np.uint16)
        filepath = os.path.join(output_dir, f"{view}.png")
        Image.fromarray((image / 256).astype(np.uint8)).save(filepath)
        image_paths[view] = filepath
    return image_paths
# Use the sample images with the model
image_paths = create_sample_mammogram_images()
# Then follow the steps above to run inference
Input Requirements
Mammogram Images
- Views Required: 4 standard views (L-CC, L-MLO, R-CC, R-MLO)
- Format: PNG16 (converted from DICOM) or PNG8
- Size: 1664×2048 pixels
- Preprocessing: Images should be converted using DCMTK with the +on2 and --min-max-window flags
- Normalization: Applied automatically by preprocessor (mean=7047.99, std=12005.5)
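The normalization constants listed above suggest a per-pixel standardization. The sketch below works through the arithmetic for a few 16-bit pixel values, assuming the preprocessor computes `(pixel - mean) / std`; the function name `normalize_pixel` is illustrative only, and the actual preprocessor may differ in detail.

```python
MEAN = 7047.99  # constants from the input requirements above
STD = 12005.5


def normalize_pixel(value, mean=MEAN, std=STD):
    """Standardize one raw pixel value (assumed formula: (value - mean) / std)."""
    return (value - mean) / std


# Black, roughly mean-valued, and saturated 16-bit pixels
for raw in (0, 7048, 65535):
    print(raw, round(normalize_pixel(raw), 4))
```

Under this assumption, black pixels map to a small negative value and saturated pixels to a value a little under 5, so inputs are not confined to the range [0, 1].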
Risk Factors (Optional but Recommended)
The model can utilize 34 clinical risk factors for improved accuracy:
# Complete list of risk factors with expected formats
risk_factors = {
    # Demographics
    'age': 55,  # Age in years
    'race': 1,  # 1=White, 2=Black, 3=Asian, 4=Other
    'weight': 70,  # Weight in kg
    'height': 165,  # Height in cm
    # Breast density
    'density': 2,  # BI-RADS density: 1=A, 2=B, 3=C, 4=D
    # Family history
    'family_history': False,  # First-degree relative with breast cancer
    # Biopsy history
    'biopsy_benign': False,  # Previous benign biopsy
    'biopsy_lcis': False,  # Previous LCIS diagnosis
    'biopsy_atypical': False,  # Previous atypical hyperplasia
    # Reproductive history
    'menarche_age': 13,  # Age at first menstruation
    'menopause_age': 0,  # Age at menopause (0 if pre-menopausal)
    'first_pregnancy_age': 28,  # Age at first pregnancy (0 if nulliparous)
    'parous': True,  # Has had children
    'menopausal_status': 0  # 0=pre-menopausal, 1=post-menopausal
}
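Because the coded fields above accept only a few values, it can help to check a risk factor dictionary before passing it to the preprocessor. The following is a minimal sketch of such a check; `validate_risk_factors` is a hypothetical helper (not part of the MIRAI package), and its rules are taken only from the comments in the dictionary above.

```python
# Hypothetical validator; allowed values mirror the comments in the dict above.
ALLOWED_VALUES = {
    "race": {1, 2, 3, 4},
    "density": {1, 2, 3, 4},
    "menopausal_status": {0, 1},
}
BOOLEAN_KEYS = ("family_history", "biopsy_benign", "biopsy_lcis",
                "biopsy_atypical", "parous")


def validate_risk_factors(factors):
    """Return a list of human-readable problems (empty if none were found)."""
    problems = []
    for key, allowed in ALLOWED_VALUES.items():
        if key in factors and factors[key] not in allowed:
            problems.append(f"{key}={factors[key]!r} not in {sorted(allowed)}")
    for key in BOOLEAN_KEYS:
        if key in factors and not isinstance(factors[key], bool):
            problems.append(f"{key} should be True or False")
    # The dict comments state menopause_age is 0 for pre-menopausal patients
    if factors.get("menopausal_status") == 0 and factors.get("menopause_age", 0) != 0:
        problems.append("menopause_age should be 0 when pre-menopausal")
    return problems


example = {"age": 55, "race": 1, "density": 5, "family_history": "no",
           "menopausal_status": 0, "menopause_age": 0}
for problem in validate_risk_factors(example):
    print(problem)
```

Returning a list of problems instead of raising on the first one lets a caller report every issue in a record at once.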
Model Performance
Performance metrics from the original paper (Science Translational Medicine, 2021):
Dataset | AUC (5-year) | C-index |
---|---|---|
MGH Test Set | 0.76 | 0.71 |
External Validation (Karolinska) | 0.74 | 0.70 |
External Validation (CGMH) | 0.75 | 0.71 |
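The C-index column measures how well predicted risks rank patients by time to diagnosis. As a rough illustration on made-up toy data (not data from the paper), a simplified Harrell-style concordance index can be computed as below. This sketch is not the evaluation code behind the table, which may treat ties and censoring differently.

```python
def concordance_index(times, events, risks):
    """Simplified Harrell's C: fraction of comparable pairs ranked correctly.

    A pair (i, j) is comparable when subject i had an event and did so
    strictly before subject j's event or censoring time.
    """
    concordant, comparable = 0.0, 0
    n = len(times)
    for i in range(n):
        if not events[i]:
            continue
        for j in range(n):
            if i != j and times[i] < times[j]:
                comparable += 1
                if risks[i] > risks[j]:
                    concordant += 1.0
                elif risks[i] == risks[j]:
                    concordant += 0.5  # ties in predicted risk count half
    if comparable == 0:
        raise ValueError("No comparable pairs")
    return concordant / comparable


# Toy data: years to diagnosis or last follow-up, event flag, predicted risk
times = [1, 3, 5, 5]
events = [True, True, False, False]
risks = [0.30, 0.10, 0.20, 0.05]
print(concordance_index(times, events, risks))  # 0.8
```

A value of 0.5 corresponds to random ranking and 1.0 to a perfect ranking; in the toy data, one of the five comparable pairs is ordered incorrectly.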
Complete Working Example
#!/usr/bin/env python3
"""
Complete example for using MIRAI model from Hugging Face Hub
"""
import torch
import numpy as np
from PIL import Image
import os
from huggingface_hub import snapshot_download
import sys
# Step 1: Download and setup model
print("Downloading MIRAI model from Hugging Face Hub...")
model_dir = snapshot_download(
    repo_id="Lab-Rasool/Mirai",
    cache_dir=".cache"
)
sys.path.insert(0, model_dir)
from modeling_mirai import MiraiModel
from configuration_mirai import MiraiConfig
from preprocessor import MiraiPreprocessor
# Step 2: Load model
config = MiraiConfig.from_pretrained("Lab-Rasool/Mirai")
model = MiraiModel.from_pretrained("Lab-Rasool/Mirai", config=config)
model.eval()
preprocessor = MiraiPreprocessor()
# Step 3: Load your mammogram images
# Replace these paths with actual mammogram images
image_paths = {
    'L-CC': 'path/to/left_cc.png',
    'L-MLO': 'path/to/left_mlo.png',
    'R-CC': 'path/to/right_cc.png',
    'R-MLO': 'path/to/right_mlo.png'
}
# Step 4: Preprocess images
exam_tensor = preprocessor.load_mammogram_exam(image_paths)
# Step 5: Prepare risk factors
risk_factors = {
    'age': 55,
    'density': 2,
    'family_history': False,
    'biopsy_benign': False,
    'biopsy_lcis': False,
    'biopsy_atypical': False,
    'menarche_age': 13,
    'menopause_age': 0,
    'first_pregnancy_age': 28,
    'race': 1,
    'weight': 70,
    'height': 165,
    'parous': True,
    'menopausal_status': 0
}
risk_factors_tensor = preprocessor.prepare_risk_factors(risk_factors)
# Step 6: Prepare batch
# Transpose from [C, V, H, W] to [V, C, H, W], as in the Quick Start
exam_tensor = exam_tensor.permute(1, 0, 2, 3)
batch_images = exam_tensor.unsqueeze(0)
batch_risk_factors = risk_factors_tensor.unsqueeze(0)
batch_metadata = {
    'time_seq': torch.zeros(1, 4).long(),
    'view_seq': torch.tensor([[0, 1, 0, 1]]),  # CC, MLO, CC, MLO
    'side_seq': torch.tensor([[0, 0, 1, 1]]),  # L, L, R, R
}
# Step 7: Run inference
with torch.no_grad():
    outputs = model(
        images=batch_images,
        risk_factors=batch_risk_factors,
        batch_metadata=batch_metadata,
        return_dict=True
    )
# Step 8: Process results
if hasattr(outputs, 'probabilities'):
    probabilities = outputs.probabilities[0].numpy()
else:
    probabilities = torch.sigmoid(outputs[0][0]).numpy()
# Display risk assessment
print("\nBreast Cancer Risk Assessment:")
print("-" * 40)
for year in range(len(probabilities)):
    risk_pct = probabilities[year] * 100
    print(f"Year {year + 1}: {risk_pct:5.2f}%")
# Interpret 5-year risk
five_year_risk = probabilities[4] * 100
print(f"\n5-Year Cumulative Risk: {five_year_risk:.2f}%")
# Risk categorization
if five_year_risk < 1.67:
    print("Risk Category: Low Risk")
elif five_year_risk < 3.0:
    print("Risk Category: Average Risk")
elif five_year_risk < 5.0:
    print("Risk Category: Moderate Risk")
else:
    print("Risk Category: High Risk")
Testing the Model
To test the model with synthetic data:
# Run the test script
python test_huggingface_model.py
This will:
- Download the model from Hugging Face Hub
- Create synthetic mammogram images for testing
- Run inference with sample risk factors
- Display risk predictions and categories
Preprocessing DICOM Images
For DICOM to PNG16 conversion:
# Using DCMTK
dcmj2pnm +on2 --min-max-window input.dcm output.png
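When many DICOM files need converting, the command above can be assembled programmatically. The sketch below only builds the argument list, using the same `dcmj2pnm` flags as shown; `build_dcmj2pnm_command` and the example paths are hypothetical, and actually running the command assumes DCMTK is installed and on the PATH.

```python
from pathlib import Path


def build_dcmj2pnm_command(dicom_path, output_dir):
    """Build the dcmj2pnm argument list for one DICOM file (hypothetical helper)."""
    dicom_path = Path(dicom_path)
    # Name the PNG after the DICOM file, placed in output_dir
    output_path = Path(output_dir) / (dicom_path.stem + ".png")
    return ["dcmj2pnm", "+on2", "--min-max-window",
            str(dicom_path), str(output_path)]


cmd = build_dcmj2pnm_command("exams/left_cc.dcm", "png16")
print(" ".join(cmd))

# To run the conversion (requires DCMTK on PATH):
# import subprocess
# subprocess.run(cmd, check=True)
```

Passing the arguments as a list rather than a single shell string avoids quoting problems with paths that contain spaces.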
Python DICOM Processing
import pydicom
import numpy as np
from PIL import Image
from pydicom.multival import MultiValue

def dicom_to_png(dicom_path, output_path):
    """Convert DICOM to PNG for MIRAI model."""
    # Read DICOM
    ds = pydicom.dcmread(dicom_path)
    # Get pixel array as float for the arithmetic below
    pixel_array = ds.pixel_array.astype(np.float64)
    # Apply window/level if present
    if hasattr(ds, 'WindowCenter') and hasattr(ds, 'WindowWidth'):
        window_center = ds.WindowCenter
        window_width = ds.WindowWidth
        # Multi-valued tags arrive as pydicom MultiValue, not list; use the first
        if isinstance(window_center, (list, MultiValue)):
            window_center = window_center[0]
        if isinstance(window_width, (list, MultiValue)):
            window_width = window_width[0]
        # Apply windowing
        img_min = float(window_center) - float(window_width) / 2
        img_max = float(window_center) + float(window_width) / 2
        pixel_array = np.clip(pixel_array, img_min, img_max)
    # Normalize to the 16-bit range (guard against constant images)
    value_range = pixel_array.max() - pixel_array.min()
    if value_range == 0:
        value_range = 1.0
    pixel_array = ((pixel_array - pixel_array.min()) / value_range * 65535).astype(np.uint16)
    # Save as 8-bit PNG (16-bit values scaled down by 256)
    Image.fromarray((pixel_array / 256).astype(np.uint8)).save(output_path)
    return output_path
Clinical Risk Categories
The model outputs are typically interpreted as:
5-Year Risk | Category | Recommendation |
---|---|---|
< 1.67% | Low Risk | Standard screening |
1.67-3.0% | Average Risk | Annual mammography |
3.0-5.0% | Moderate Risk | Consider supplemental screening |
> 5.0% | High Risk | Discuss risk reduction strategies |
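The thresholds in the table translate directly into a small lookup function. This is a minimal sketch: `categorize_five_year_risk` is a hypothetical helper, and it assumes a value of exactly 5.0% falls in the High Risk category, which matches the `if`/`elif` chain in the complete example but is not stated in the table itself.

```python
def categorize_five_year_risk(risk_pct):
    """Map a 5-year risk in percent to the category names in the table above."""
    if risk_pct < 0:
        raise ValueError("risk_pct must be non-negative")
    if risk_pct < 1.67:
        return "Low Risk"
    if risk_pct < 3.0:
        return "Average Risk"
    if risk_pct < 5.0:
        return "Moderate Risk"
    return "High Risk"  # assumption: exactly 5.0% is treated as High Risk


for pct in (0.8, 2.1, 4.2, 6.5):
    print(f"{pct:.1f}% -> {categorize_five_year_risk(pct)}")
```

Note that the function expects a percentage, so a model probability such as 0.021 should be multiplied by 100 first.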
Troubleshooting
Common Issues and Solutions
1. Model Download Issues
# If automatic download fails, manually download:
from huggingface_hub import snapshot_download
model_dir = snapshot_download(
    repo_id="Lab-Rasool/Mirai",
    cache_dir=".cache",
    force_download=True  # Force re-download
)
2. Missing Views Error
# Ensure all 4 views are provided
required_views = ['L-CC', 'L-MLO', 'R-CC', 'R-MLO']
for view in required_views:
    if view not in image_paths:
        print(f"Missing required view: {view}")
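The loop above only prints a message, so a script would still continue into preprocessing with an incomplete exam. A stricter variant, sketched here with a hypothetical helper name `check_required_views`, collects every missing view and raises a single error before any images are loaded.

```python
REQUIRED_VIEWS = ("L-CC", "L-MLO", "R-CC", "R-MLO")


def check_required_views(image_paths):
    """Raise ValueError listing every required view absent from image_paths."""
    missing = [view for view in REQUIRED_VIEWS if view not in image_paths]
    if missing:
        raise ValueError("Missing required view(s): " + ", ".join(missing))


try:
    check_required_views({"L-CC": "left_cc.png", "R-CC": "right_cc.png"})
except ValueError as err:
    print(err)
```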
3. Image Size Mismatch
# Resize images if needed
from PIL import Image
def resize_mammogram(image_path, target_size=(1664, 2048)):
    img = Image.open(image_path)
    img_resized = img.resize(target_size, Image.LANCZOS)
    return img_resized
4. Memory Issues
# For systems with limited memory
torch.cuda.empty_cache() # Clear GPU cache
model = model.cpu() # Run on CPU instead
API Reference
MiraiModel
model = MiraiModel.from_pretrained(
    "Lab-Rasool/Mirai",
    config=config,
    cache_dir=".cache"  # Optional: specify cache directory
)
MiraiPreprocessor
preprocessor = MiraiPreprocessor()
# Load mammogram exam
exam_tensor = preprocessor.load_mammogram_exam(
    image_paths,  # Dict with 'L-CC', 'L-MLO', 'R-CC', 'R-MLO' keys
    apply_augmentation=False  # Set True for training
)
# Prepare risk factors
risk_tensor = preprocessor.prepare_risk_factors(
    risk_factors_dict,  # Dict with clinical risk factors
    normalize=True  # Apply normalization
)
Model Output Format
outputs = model(images, risk_factors, batch_metadata)
# Output structure:
# - outputs.probabilities: Tensor of shape [batch, 5] with year-wise risks
# - outputs.risk_scores: Additional risk metrics (if available)
# - outputs.features: Extracted image features (if return_features=True)
Limitations
- Requires all 4 standard mammography views
- Optimized for screening mammograms (not diagnostic)
- Risk predictions should be interpreted by healthcare professionals
- Model performance may vary with image quality
- Not validated for tomosynthesis or 3D mammography
Ethical Considerations
- This model is intended for use by healthcare professionals
- Should not be used as sole basis for clinical decisions
- Results should be interpreted in context of full clinical picture
- Ensure appropriate patient consent for AI-assisted analysis
Citation
If you use this model, please cite the original paper:
@article{yala2021toward,
title={Toward robust mammography-based models for breast cancer risk},
author={Yala, Adam and Mikhael, Peter G and Strand, Fredrik and Lin, Gigin and Smith, Kevin and
Wan, Yung-Liang and Lamb, Leslie and Hughes, Kevin and Lehman, Constance and Barzilay, Regina},
journal={Science Translational Medicine},
volume={13},
number={578},
pages={eaba4373},
year={2021},
publisher={American Association for the Advancement of Science}
}
License
This model is released under the MIT License. See LICENSE file for details.
Copyright (c) 2021 Massachusetts Institute of Technology and Massachusetts General Hospital
Acknowledgments
This implementation is based on the original work by the Barzilay Lab at MIT CSAIL and Massachusetts General Hospital. We thank the authors for making their research and model weights publicly available.
Support
For questions about the original research:
- Original Paper: Science Translational Medicine
- MIT CSAIL: Barzilay Lab
Disclaimer
This model is for research purposes. Clinical deployment requires appropriate regulatory approval and validation in the target population.