MIRAI: Mammography-based Risk Assessment with AI

Model Description

MIRAI is a state-of-the-art deep learning model for breast cancer risk prediction using mammography images. Developed by researchers at MIT and Massachusetts General Hospital, it predicts breast cancer risk at multiple time points (1-5 years) using standard mammography views and optional clinical risk factors.

The model has been validated across diverse populations and imaging devices, demonstrating robust performance across different demographics and technical settings.

Key Features

  • Multi-time-point prediction: Provides risk assessments for years 1-5
  • Robust across populations: Validated on diverse ethnicities and age groups
  • Device agnostic: Works with mammograms from different manufacturers
  • Clinical integration ready: Designed for real-world deployment

Model Architecture

MIRAI uses a two-stage architecture (a schematic sketch of the data flow follows the list):

  1. Image Encoder (ResNet-based)

    • Processes individual mammogram views (L-CC, L-MLO, R-CC, R-MLO)
    • Input: 1664×2048 pixel images
    • Output: 2048-dimensional feature vectors
  2. Transformer + Risk Factor Module

    • Aggregates features from multiple views
    • Incorporates 34 clinical risk factors (optional)
    • Output: cumulative cancer risk predictions for years 1-5
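
The actual layer definitions live in modeling_mirai.py in the Hub repository. As a rough, illustrative sketch only (with hypothetical module names and shapes, not the model's real implementation), the two-stage flow can be pictured like this:

import torch
import torch.nn as nn

class TwoStageSketch(nn.Module):
    """Schematic of the two-stage flow described above (illustrative, not the real layers)."""

    def __init__(self, feat_dim=2048, num_risk_factors=34, num_years=5):
        super().__init__()
        # Stage 1: per-view image encoder (the released model uses a ResNet backbone)
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=7, stride=4),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(32, feat_dim),
        )
        # Stage 2: transformer aggregates the four view embeddings
        self.aggregator = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=feat_dim, nhead=8,
                                       dim_feedforward=512, batch_first=True),
            num_layers=1,
        )
        # Optional clinical risk factors are fused before the risk head
        self.risk_head = nn.Linear(feat_dim + num_risk_factors, num_years)

    def forward(self, views, risk_factors):
        # views: [batch, 4, 1, H, W]; risk_factors: [batch, num_risk_factors]
        b, v = views.shape[:2]
        feats = self.encoder(views.flatten(0, 1)).view(b, v, -1)  # [b, 4, feat_dim]
        pooled = self.aggregator(feats).mean(dim=1)               # [b, feat_dim]
        logits = self.risk_head(torch.cat([pooled, risk_factors], dim=1))
        return torch.sigmoid(logits)                              # [b, num_years] year-wise risks

# Tiny random inputs just to show the shapes (real views are 1664x2048)
sketch = TwoStageSketch()
print(sketch(torch.randn(1, 4, 1, 256, 256), torch.randn(1, 34)).shape)  # torch.Size([1, 5])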

Installation

From Hugging Face Hub (Recommended)

pip install torch torchvision numpy pillow huggingface-hub transformers opencv-python

Additional Dependencies (Optional)

pip install pandas scikit-learn pydicom  # For data processing and DICOM support

Quick Start

Method 1: Using Hugging Face Hub (Recommended)

import torch
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download
import sys

# Download model from Hugging Face Hub
model_dir = snapshot_download(repo_id="Lab-Rasool/Mirai")
sys.path.insert(0, model_dir)

# Import model components
from modeling_mirai import MiraiModel
from configuration_mirai import MiraiConfig
from preprocessor import MiraiPreprocessor

# Load model and configuration
config = MiraiConfig.from_pretrained(model_dir)
model = MiraiModel.from_pretrained(model_dir, config=config)
model.eval()

# Initialize preprocessor
preprocessor = MiraiPreprocessor()

# Load mammogram images (4 standard views required)
image_paths = {
    'L-CC': 'path/to/left_cc.png',
    'L-MLO': 'path/to/left_mlo.png',
    'R-CC': 'path/to/right_cc.png',
    'R-MLO': 'path/to/right_mlo.png'
}

# Preprocess images
exam_tensor = preprocessor.load_mammogram_exam(image_paths)

# Prepare clinical risk factors (optional but recommended)
risk_factors = {
    'age': 55,
    'density': 2,  # BI-RADS density (1-4)
    'family_history': False,
    'biopsy_benign': False,
    'biopsy_lcis': False,
    'biopsy_atypical': False,
    'menarche_age': 13,
    'menopause_age': 0,  # 0 if pre-menopausal
    'first_pregnancy_age': 28,
    'race': 1,
    'weight': 70,  # kg
    'height': 165,  # cm
    'parous': True,
    'menopausal_status': 0  # 0=pre, 1=post
}

risk_factors_tensor = preprocessor.prepare_risk_factors(risk_factors)

# Prepare batch - transpose to [views, channels, height, width]
exam_tensor = exam_tensor.permute(1, 0, 2, 3)  # From [C, V, H, W] to [V, C, H, W]
batch_images = exam_tensor.unsqueeze(0)
batch_risk_factors = risk_factors_tensor.unsqueeze(0)

# Create metadata for the batch
batch_metadata = {
    'time_seq': torch.zeros(1, 4).long(),
    'view_seq': torch.tensor([[0, 1, 0, 1]]),  # CC, MLO, CC, MLO
    'side_seq': torch.tensor([[0, 0, 1, 1]]),  # L, L, R, R
}

# Run inference
with torch.no_grad():
    outputs = model(
        images=batch_images,
        risk_factors=batch_risk_factors,
        batch_metadata=batch_metadata,
        return_dict=True
    )

    # Extract probabilities
    if hasattr(outputs, 'probabilities'):
        probabilities = outputs.probabilities[0].numpy()
    else:
        probabilities = torch.sigmoid(outputs[0][0]).numpy()

    # Display results
    for year in range(len(probabilities)):
        risk_pct = probabilities[year] * 100
        print(f"Year {year + 1} risk: {risk_pct:.2f}%")

    # Risk assessment for last available year
    if len(probabilities) >= 5:
        five_year_risk = probabilities[4] * 100
        print(f"\n5-Year Cumulative Risk: {five_year_risk:.2f}%")
    elif len(probabilities) > 0:
        last_year = len(probabilities)
        last_risk = probabilities[-1] * 100
        print(f"\n{last_year}-Year Cumulative Risk: {last_risk:.2f}%")

Method 2: Creating Sample Test Data

import numpy as np
from PIL import Image
import os

def create_sample_mammogram_images(output_dir="sample_mammograms"):
    """Create sample mammogram images for testing."""
    os.makedirs(output_dir, exist_ok=True)

    views = ['L-CC', 'L-MLO', 'R-CC', 'R-MLO']
    image_paths = {}

    for view in views:
        # Create a synthetic mammogram-like image (seeded per view so each view differs)
        np.random.seed(hash(view) % 1000)

        # Add tissue-like patterns
        background = np.random.normal(8000, 2000, (1664, 2048))

        # Add circular dense regions with Gaussian falloff
        for _ in range(5):
            row, col = np.random.randint(200, 1400), np.random.randint(200, 1800)
            size = np.random.randint(100, 300)
            density = np.random.normal(12000, 1000)

            y_coords, x_coords = np.ogrid[:1664, :2048]
            dist_sq = (y_coords - row)**2 + (x_coords - col)**2
            mask = dist_sq <= size**2
            background[mask] += density * np.exp(-dist_sq / (2 * size**2))[mask]

        # Add chest wall edge
        if 'L' in view:
            background[:, :200] += np.linspace(5000, 0, 200)
        else:
            background[:, -200:] += np.linspace(0, 5000, 200)

        # Clip to the 16-bit range, then scale down to 8-bit and save as PNG
        # (the model accepts PNG8 or PNG16 input)
        image = np.clip(background, 0, 65535).astype(np.uint16)
        filepath = os.path.join(output_dir, f"{view}.png")
        Image.fromarray((image / 256).astype(np.uint8)).save(filepath)
        image_paths[view] = filepath

    return image_paths

# Use the sample images with the model
image_paths = create_sample_mammogram_images()
# Then follow the steps above to run inference

Input Requirements

Mammogram Images

  • Views Required: 4 standard views (L-CC, L-MLO, R-CC, R-MLO)
  • Format: PNG16 (converted from DICOM) or PNG8
  • Size: 1664×2048 pixels
  • Preprocessing: Images should be converted using DCMTK with +on2 and --min-max-window flags
  • Normalization: Applied automatically by the preprocessor (mean=7047.99, std=12005.5); a sketch of this step follows the list
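
The MiraiPreprocessor handles loading and normalization for you; the following is only a minimal sketch of what these requirements amount to if a single view were prepared by hand (the function name and error handling are illustrative, not part of the packaged preprocessor):

import numpy as np
import torch
from PIL import Image

MEAN, STD = 7047.99, 12005.5  # normalization constants quoted above

def load_view_manually(png_path):
    """Illustrative only: MiraiPreprocessor.load_mammogram_exam handles this for you."""
    arr = np.asarray(Image.open(png_path), dtype=np.float32)  # PNG8 or PNG16 grayscale
    # numpy arrays are (rows, cols); the expected pixel count is 1664 x 2048
    if arr.size != 1664 * 2048:
        raise ValueError(f"Unexpected image size: {arr.shape}")
    arr = (arr - MEAN) / STD                                   # z-score normalization
    return torch.from_numpy(arr).unsqueeze(0)                  # [1, H, W] single-channel view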

Risk Factors (Optional but Recommended)

The model can incorporate clinical risk factors for improved accuracy. The preprocessor accepts the following fields, which it encodes into the model's risk-factor input:

# Complete list of risk factors with expected formats
risk_factors = {
    # Demographics
    'age': 55,                    # Age in years
    'race': 1,                     # 1=White, 2=Black, 3=Asian, 4=Other
    'weight': 70,                  # Weight in kg
    'height': 165,                 # Height in cm

    # Breast density
    'density': 2,                  # BI-RADS density: 1=A, 2=B, 3=C, 4=D

    # Family history
    'family_history': False,       # First-degree relative with breast cancer

    # Biopsy history
    'biopsy_benign': False,        # Previous benign biopsy
    'biopsy_lcis': False,          # Previous LCIS diagnosis
    'biopsy_atypical': False,      # Previous atypical hyperplasia

    # Reproductive history
    'menarche_age': 13,           # Age at first menstruation
    'menopause_age': 0,           # Age at menopause (0 if pre-menopausal)
    'first_pregnancy_age': 28,    # Age at first pregnancy (0 if nulliparous)
    'parous': True,               # Has had children
    'menopausal_status': 0        # 0=pre-menopausal, 1=post-menopausal
}

Model Performance

Performance metrics from the original paper (Science Translational Medicine, 2021) are summarized below; a sketch of how to compute a comparable 5-year AUC on your own data follows the table.

Dataset                            AUC (5-year)   C-index
MGH Test Set                       0.76           0.71
External Validation (Karolinska)   0.74           0.70
External Validation (CGMH)         0.75           0.71
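
These figures come from the original publication. If you want to sanity-check discrimination on your own labeled validation set, a minimal 5-year AUC computation might look like the sketch below (the arrays are placeholders for per-exam predictions and binary 5-year outcomes you would collect yourself; scikit-learn is listed under the optional dependencies):

import numpy as np
from sklearn.metrics import roc_auc_score

# Placeholder arrays: collect these from your own validation loop
five_year_probs = np.array([0.021, 0.154, 0.007, 0.093])  # probabilities[4] per exam
five_year_labels = np.array([0, 1, 0, 1])                 # cancer diagnosed within 5 years (1) or not (0)

auc = roc_auc_score(five_year_labels, five_year_probs)
print(f"5-year AUC: {auc:.3f}")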

Complete Working Example

#!/usr/bin/env python3
"""
Complete example for using MIRAI model from Hugging Face Hub
"""

import torch
import numpy as np
from PIL import Image
import os
from huggingface_hub import snapshot_download
import sys

# Step 1: Download and setup model
print("Downloading MIRAI model from Hugging Face Hub...")
model_dir = snapshot_download(
    repo_id="Lab-Rasool/Mirai",
    cache_dir=".cache"
)
sys.path.insert(0, model_dir)

from modeling_mirai import MiraiModel
from configuration_mirai import MiraiConfig
from preprocessor import MiraiPreprocessor

# Step 2: Load model
config = MiraiConfig.from_pretrained(model_dir)
model = MiraiModel.from_pretrained(model_dir, config=config)
model.eval()
preprocessor = MiraiPreprocessor()

# Step 3: Load your mammogram images
# Replace these paths with actual mammogram images
image_paths = {
    'L-CC': 'path/to/left_cc.png',
    'L-MLO': 'path/to/left_mlo.png',
    'R-CC': 'path/to/right_cc.png',
    'R-MLO': 'path/to/right_mlo.png'
}

# Step 4: Preprocess images
exam_tensor = preprocessor.load_mammogram_exam(image_paths)

# Step 5: Prepare risk factors
risk_factors = {
    'age': 55,
    'density': 2,
    'family_history': False,
    'biopsy_benign': False,
    'biopsy_lcis': False,
    'biopsy_atypical': False,
    'menarche_age': 13,
    'menopause_age': 0,
    'first_pregnancy_age': 28,
    'race': 1,
    'weight': 70,
    'height': 165,
    'parous': True,
    'menopausal_status': 0
}

risk_factors_tensor = preprocessor.prepare_risk_factors(risk_factors)

# Step 6: Prepare batch - transpose to [views, channels, height, width]
exam_tensor = exam_tensor.permute(1, 0, 2, 3)  # From [C, V, H, W] to [V, C, H, W]
batch_images = exam_tensor.unsqueeze(0)
batch_risk_factors = risk_factors_tensor.unsqueeze(0)
batch_metadata = {
    'time_seq': torch.zeros(1, 4).long(),
    'view_seq': torch.tensor([[0, 1, 0, 1]]),  # CC, MLO, CC, MLO
    'side_seq': torch.tensor([[0, 0, 1, 1]]),  # L, L, R, R
}

# Step 7: Run inference
with torch.no_grad():
    outputs = model(
        images=batch_images,
        risk_factors=batch_risk_factors,
        batch_metadata=batch_metadata,
        return_dict=True
    )

# Step 8: Process results
if hasattr(outputs, 'probabilities'):
    probabilities = outputs.probabilities[0].numpy()
else:
    probabilities = torch.sigmoid(outputs[0][0]).numpy()

# Display risk assessment
print("\nBreast Cancer Risk Assessment:")
print("-" * 40)
for year in range(len(probabilities)):
    risk_pct = probabilities[year] * 100
    print(f"Year {year + 1}: {risk_pct:5.2f}%")

# Interpret 5-year risk
five_year_risk = probabilities[4] * 100
print(f"\n5-Year Cumulative Risk: {five_year_risk:.2f}%")

# Risk categorization
if five_year_risk < 1.67:
    print("Risk Category: Low Risk")
elif five_year_risk < 3.0:
    print("Risk Category: Average Risk")
elif five_year_risk < 5.0:
    print("Risk Category: Moderate Risk")
else:
    print("Risk Category: High Risk")

Testing the Model

To test the model with synthetic data:

# Run the test script
python test_huggingface_model.py

This will:

  1. Download the model from Hugging Face Hub
  2. Create synthetic mammogram images for testing
  3. Run inference with sample risk factors
  4. Display risk predictions and categories

Preprocessing DICOM Images

For DICOM to PNG16 conversion:

# Using DCMTK
dcmj2pnm +on2 --min-max-window input.dcm output.png

Python DICOM Processing

import pydicom
from pydicom.multival import MultiValue
import numpy as np
from PIL import Image

def dicom_to_png(dicom_path, output_path):
    """Convert a DICOM mammogram to a 16-bit PNG for the MIRAI model."""
    # Read DICOM
    ds = pydicom.dcmread(dicom_path)

    # Get pixel array as float for windowing
    pixel_array = ds.pixel_array.astype(np.float64)

    # Apply window/level if present
    if hasattr(ds, 'WindowCenter') and hasattr(ds, 'WindowWidth'):
        window_center = ds.WindowCenter
        window_width = ds.WindowWidth

        # WindowCenter/WindowWidth may be multi-valued; take the first entry
        if isinstance(window_center, MultiValue):
            window_center = window_center[0]
        if isinstance(window_width, MultiValue):
            window_width = window_width[0]
        window_center = float(window_center)
        window_width = float(window_width)

        # Apply windowing
        img_min = window_center - window_width / 2
        img_max = window_center + window_width / 2
        pixel_array = np.clip(pixel_array, img_min, img_max)

    # Rescale to the full 16-bit range
    denom = max(pixel_array.max() - pixel_array.min(), 1e-8)
    pixel_array = ((pixel_array - pixel_array.min()) / denom * 65535).astype(np.uint16)

    # Save as a 16-bit grayscale PNG (matches the PNG16 input described above)
    Image.fromarray(pixel_array).save(output_path)

    return output_path

Clinical Risk Categories

The model outputs are typically interpreted as follows (a small helper implementing these thresholds appears after the table):

5-Year Risk   Category        Recommendation
< 1.67%       Low Risk        Standard screening
1.67-3.0%     Average Risk    Annual mammography
3.0-5.0%      Moderate Risk   Consider supplemental screening
> 5.0%        High Risk       Discuss risk reduction strategies
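
For convenience, the thresholds in this table can be wrapped in a small helper mirroring the categorization used in the complete example above (this function is illustrative and not part of the packaged model API):

def categorize_5yr_risk(risk_percent):
    """Map a 5-year risk (in percent) to the categories in the table above."""
    if risk_percent < 1.67:
        return "Low Risk", "Standard screening"
    elif risk_percent < 3.0:
        return "Average Risk", "Annual mammography"
    elif risk_percent < 5.0:
        return "Moderate Risk", "Consider supplemental screening"
    return "High Risk", "Discuss risk reduction strategies"

category, recommendation = categorize_5yr_risk(2.4)  # e.g. a 2.4% 5-year risk
print(f"Risk Category: {category} ({recommendation})")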

Troubleshooting

Common Issues and Solutions

1. Model Download Issues

# If automatic download fails, manually download:
from huggingface_hub import snapshot_download

model_dir = snapshot_download(
    repo_id="Lab-Rasool/Mirai",
    cache_dir=".cache",
    force_download=True  # Force re-download
)

2. Missing Views Error

# Ensure all 4 views are provided
required_views = ['L-CC', 'L-MLO', 'R-CC', 'R-MLO']
for view in required_views:
    if view not in image_paths:
        print(f"Missing required view: {view}")

3. Image Size Mismatch

# Resize images if needed
from PIL import Image

def resize_mammogram(image_path, target_size=(1664, 2048)):
    # Note: PIL's resize takes (width, height)
    img = Image.open(image_path)
    img_resized = img.resize(target_size, Image.LANCZOS)
    return img_resized

4. Memory Issues

# For systems with limited memory
torch.cuda.empty_cache()  # Clear GPU cache
model = model.cpu()  # Run on CPU instead

API Reference

MiraiModel

model = MiraiModel.from_pretrained(
    "Lab-Rasool/Mirai",
    config=config,
    cache_dir=".cache"  # Optional: specify cache directory
)

MiraiPreprocessor

preprocessor = MiraiPreprocessor()

# Load mammogram exam
exam_tensor = preprocessor.load_mammogram_exam(
    image_paths,  # Dict with 'L-CC', 'L-MLO', 'R-CC', 'R-MLO' keys
    apply_augmentation=False  # Set True for training
)

# Prepare risk factors
risk_tensor = preprocessor.prepare_risk_factors(
    risk_factors_dict,  # Dict with clinical risk factors
    normalize=True  # Apply normalization
)

Model Output Format

outputs = model(images, risk_factors, batch_metadata)

# Output structure:
# - outputs.probabilities: Tensor of shape [batch, 5] with year-wise risks
# - outputs.risk_scores: Additional risk metrics (if available)
# - outputs.features: Extracted image features (if return_features=True)

Limitations

  • Requires all 4 standard mammography views
  • Optimized for screening mammograms (not diagnostic)
  • Risk predictions should be interpreted by healthcare professionals
  • Model performance may vary with image quality
  • Not validated for tomosynthesis or 3D mammography

Ethical Considerations

  • This model is intended for use by healthcare professionals
  • Should not be used as sole basis for clinical decisions
  • Results should be interpreted in context of full clinical picture
  • Ensure appropriate patient consent for AI-assisted analysis

Citation

If you use this model, please cite the original paper:

@article{yala2021toward,
  title={Toward robust mammography-based models for breast cancer risk},
  author={Yala, Adam and Mikhael, Peter G and Strand, Fredrik and Lin, Gigin and Smith, Kevin and
          Wan, Yung-Liang and Lamb, Leslie and Hughes, Kevin and Lehman, Constance and Barzilay, Regina},
  journal={Science Translational Medicine},
  volume={13},
  number={578},
  pages={eaba4373},
  year={2021},
  publisher={American Association for the Advancement of Science}
}

License

This model is released under the MIT License. See LICENSE file for details.

Copyright (c) 2021 Massachusetts Institute of Technology and Massachusetts General Hospital

Acknowledgments

This implementation is based on the original work by the Barzilay Lab at MIT CSAIL and Massachusetts General Hospital. We thank the authors for making their research and model weights publicly available.

Support

For questions about the original research, please refer to the publication cited above and its authors.

Disclaimer

This model is for research purposes. Clinical deployment requires appropriate regulatory approval and validation in the target population.
