|
---
license: apache-2.0
---
|
|
|
<img src="pixcell_256_banner.png" alt="pixcell_256_banner" width="500"/> |
|
|
|
# PixCell: A generative foundation model for digital histopathology images |
|
|
|
[[📄 arXiv]](https://arxiv.org/abs/2506.05127) [[🤗 PixCell-1024]](https://huggingface.co/StonyBrook-CVLab/PixCell-1024) [[🤗 PixCell-256]](https://huggingface.co/StonyBrook-CVLab/PixCell-256) [[🤗 PixCell-256-Cell-ControlNet]](https://huggingface.co/StonyBrook-CVLab/PixCell-256-Cell-ControlNet) [[💾 Synthetic SBU-1M]](https://huggingface.co/datasets/StonyBrook-CVLab/Synthetic-SBU-1M)
|
|
|
### Load PixCell-256 model |
|
|
|
```python
import torch

from diffusers import AutoencoderKL, DiffusionPipeline

device = torch.device('cuda')

# We do not host the weights of the SD3 VAE -- load it from StabilityAI
sd3_vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-3.5-large", subfolder="vae")

pipeline = DiffusionPipeline.from_pretrained(
    "StonyBrook-CVLab/PixCell-256",
    vae=sd3_vae,
    custom_pipeline="StonyBrook-CVLab/PixCell-pipeline",
    trust_remote_code=True,
    torch_dtype=torch.float16,
)

pipeline.to(device)
```
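
For reproducible sampling, you can seed a `torch.Generator` and pass it to the pipeline call. This is a minimal sketch assuming the custom PixCell pipeline forwards the standard diffusers `generator` argument; check the pipeline source if seeding has no effect.

```python
# Assumption: the custom pipeline accepts the standard diffusers `generator` argument
generator = torch.Generator(device=device).manual_seed(42)
# Example: pipeline(uni_embeds=..., generator=generator)
```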
|
|
|
### Load [[UNI2-h]](https://huggingface.co/MahmoodLab/UNI2-h) for conditioning
|
```python
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

# Architecture hyperparameters for UNI2-h (see the MahmoodLab model card)
timm_kwargs = {
    'img_size': 224,
    'patch_size': 14,
    'depth': 24,
    'num_heads': 24,
    'init_values': 1e-5,
    'embed_dim': 1536,
    'mlp_ratio': 2.66667 * 2,
    'num_classes': 0,
    'no_embed_class': True,
    'mlp_layer': timm.layers.SwiGLUPacked,
    'act_layer': torch.nn.SiLU,
    'reg_tokens': 8,
    'dynamic_img_size': True,
}

uni_model = timm.create_model("hf-hub:MahmoodLab/UNI2-h", pretrained=True, **timm_kwargs)
transform = create_transform(**resolve_data_config(uni_model.pretrained_cfg, model=uni_model))
uni_model.eval()
uni_model.to(device)
```
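
As a quick sanity check (our addition, not part of the original example), you can confirm that the encoder produces the 1536-dimensional embeddings PixCell is conditioned on:

```python
# With num_classes=0, UNI2-h returns pooled features rather than class logits
with torch.inference_mode():
    dummy = torch.randn(1, 3, 224, 224, device=device)
    print(uni_model(dummy).shape)  # expected: torch.Size([1, 1536])
```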
|
|
|
### Unconditional generation |
|
```python
# Draw an unconditional UNI embedding for a batch of one image
uncond = pipeline.get_unconditional_embedding(1)

with torch.amp.autocast('cuda'):
    samples = pipeline(uni_embeds=uncond, negative_uni_embeds=None, guidance_scale=1.0).images
```
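
With the `.images` accessor, `samples` is a list of PIL images, so saving a result is a one-liner (the filename is arbitrary):

```python
# Write the first generated 256x256 patch to disk
samples[0].save("pixcell_uncond_sample.png")
```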
|
|
|
### Conditional generation |
|
```python
from PIL import Image
from huggingface_hub import hf_hub_download

# Load an example image we provide
path = hf_hub_download(repo_id="StonyBrook-CVLab/PixCell-256", filename="test_image.png")
image = Image.open(path).convert("RGB")

# Extract the UNI embedding from the image
uni_inp = transform(image).unsqueeze(dim=0)
with torch.inference_mode():
    uni_emb = uni_model(uni_inp.to(device))

# Reshape the UNI embedding to (batch_size, 1, D)
uni_emb = uni_emb.unsqueeze(1)
print("Extracted UNI:", uni_emb.shape)

# Get the unconditional embedding for classifier-free guidance
uncond = pipeline.get_unconditional_embedding(uni_emb.shape[0])

# Generate new samples conditioned on the UNI embedding
with torch.amp.autocast('cuda'):
    samples = pipeline(uni_embeds=uni_emb, negative_uni_embeds=uncond, guidance_scale=3.0, num_images_per_prompt=1).images
```
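
To eyeball how closely the generation tracks the conditioning, a small sketch (our addition) pastes the reference patch next to a generated sample:

```python
# Side-by-side comparison of the conditioning image and one generated sample
w, h = samples[0].size
comparison = Image.new("RGB", (2 * w, h))
comparison.paste(image.resize((w, h)), (0, 0))
comparison.paste(samples[0], (w, 0))
comparison.save("pixcell_comparison.png")
```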