
astroPTv2.0: a Large Observation Model for Astronomy
Here we have the model files for the astroPT project. The code to run inference with these models is at https://github.com/smith42/astropt.
The fully trained models (pretrained on 8.6 million galaxies) are in the astropt directory, in folders labelled with each model's parameter count.
Unlike the older models, which were trained on the "image" column of Smith42/galaxies, these models are trained on the cropped galaxies in the "image_crop" column. Those galaxies were cropped and zoomed before upload so that each galaxy takes up the majority of its image.
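The card does not describe how the "image_crop" images were produced. The sketch below is a purely hypothetical illustration of what "cropped and zoomed" means. It centre-crops an (H, W, C) array and then upsamples it back to the original size with nearest-neighbour indexing, so the central object fills more of the frame. `crop_and_zoom` and `crop_frac` are our own illustrative names. They are not part of astroPT or of the dataset pipeline.

```python
import numpy as np


def crop_and_zoom(image: np.ndarray, crop_frac: float = 0.5) -> np.ndarray:
    """Hypothetical illustration (not the dataset's actual procedure):
    centre-crop an (H, W, C) image to crop_frac of its side lengths, then
    upsample back to (H, W) with nearest-neighbour indexing."""
    h, w = image.shape[:2]
    ch = max(1, int(round(h * crop_frac)))
    cw = max(1, int(round(w * crop_frac)))
    top, left = (h - ch) // 2, (w - cw) // 2
    crop = image[top:top + ch, left:left + cw]
    # map every output pixel back to its nearest source pixel in the crop
    rows = np.arange(h) * ch // h
    cols = np.arange(w) * cw // w
    return crop[rows[:, None], cols[None, :]]


img = np.zeros((8, 8, 3), dtype=np.uint8)
img[3:5, 3:5] = 255  # a 2x2 "galaxy" in the middle of an 8x8 frame
zoomed = crop_and_zoom(img, crop_frac=0.5)
print(zoomed.shape)                 # (8, 8, 3)
print((img[..., 0] > 0).mean())     # 0.0625: galaxy covers 1/16 of the frame
print((zoomed[..., 0] > 0).mean())  # 0.25: after zooming it covers 1/4
```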
We see promising scaling behaviour on this new dataset.

Usage
To use these models in anger, install the package with `pip install astropt` and run the following code:
```python
from functools import partial

import numpy as np
import torch
from datasets import load_dataset  # for Smith42/galaxies
from torch.utils.data import DataLoader
from torchvision import transforms

from astropt.local_datasets import GalaxyImageDataset
from astropt.model_utils import load_astropt

# v2.0 models were trained on the cropped images; the older models used "image"
IMAGE_COLUMN = "image_crop"


# boilerplate to preprocess galaxy images
def normalise(x):
    std, mean = torch.std_mean(x, dim=1, keepdim=True)
    return (x - mean) / (std + 1e-8)


def data_transforms():
    return transforms.Compose([transforms.Lambda(normalise)])


def _process_galaxy_wrapper(row, func):
    """This function ensures that the image is tokenised in the same way as the pre-trained model is expecting"""
    galaxy = func(
        torch.from_numpy(np.array(row[IMAGE_COLUMN]).swapaxes(0, 2)).to(float)
    ).to(torch.float)
    galaxy_positions = torch.arange(0, len(galaxy), dtype=torch.long)
    return {
        "images": galaxy,
        "images_positions": galaxy_positions,
    }


# for the 095M parameter model; 015M and 850M models are also available:
model = load_astropt("Smith42/astroPT_v2.0", path="astropt/095M")
model.eval()  # inference mode
galproc = GalaxyImageDataset(
    None,
    spiral=True,
    transform={"images": data_transforms()},
    modality_registry=model.modality_registry,
)
ds = (
    load_dataset("Smith42/galaxies", split="test", revision="v2.0", streaming=True)
    .select_columns(IMAGE_COLUMN)
    .map(
        partial(_process_galaxy_wrapper, func=galproc.process_galaxy),
        remove_columns=[IMAGE_COLUMN],  # keep only the tokenised tensors
    )
    .with_format("torch")
)
dl = DataLoader(ds, batch_size=128, num_workers=32)

zs = []
with torch.no_grad():  # no gradients needed to extract embeddings
    for B in dl:
        zs.append(model.generate_embeddings(B)["images"].cpu().numpy())
zs = np.concatenate(zs)
# do cool stuff with zs...
```
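`GalaxyImageDataset.process_galaxy` handles tokenisation for you, and the card does not document its internals. The NumPy sketch below shows roughly what that preprocessing amounts to, under stated assumptions. It does three things:

- It cuts a (C, H, W) image into flattened patches.
- It orders the patches in a centre-out spiral. This is our reading of `spiral=True`, not a confirmed detail.
- It standardises each patch along dim 1, as `normalise` does. This assumes the transform sees a (num_patches, patch_dim) array.

The helper names, the 16-pixel patch size and the clockwise spiral direction are all hypothetical. Use the real `process_galaxy` for anything you feed to the model.

```python
import numpy as np


def patchify(image_chw: np.ndarray, patch: int) -> np.ndarray:
    """Hypothetical helper: cut a (C, H, W) array into non-overlapping
    patch x patch tiles, returning (grid_h, grid_w, C * patch * patch)."""
    c, h, w = image_chw.shape
    gh, gw = h // patch, w // patch
    x = image_chw[:, :gh * patch, :gw * patch]  # drop any ragged edge
    x = x.reshape(c, gh, patch, gw, patch).transpose(1, 3, 0, 2, 4)
    return x.reshape(gh, gw, c * patch * patch)


def spiral_order(gh: int, gw: int) -> list[tuple[int, int]]:
    """Hypothetical centre-out clockwise spiral over a gh x gw patch grid."""
    r, c = (gh - 1) // 2, (gw - 1) // 2
    order: list[tuple[int, int]] = []
    seen: set[tuple[int, int]] = set()

    def visit(rr: int, cc: int) -> None:
        if 0 <= rr < gh and 0 <= cc < gw and (rr, cc) not in seen:
            seen.add((rr, cc))
            order.append((rr, cc))

    visit(r, c)
    dr, dc, step = 0, 1, 1  # start by moving right
    while len(order) < gh * gw:
        for _ in range(2):  # two legs per step length
            for _ in range(step):
                r, c = r + dr, c + dc
                visit(r, c)
            dr, dc = dc, -dr  # turn 90 degrees clockwise
        step += 1
    return order


def normalise_patches(tokens: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """NumPy analogue of the card's normalise(): standardise along dim 1.
    ddof=1 matches torch.std_mean's default unbiased estimate."""
    mean = tokens.mean(axis=1, keepdims=True)
    std = tokens.std(axis=1, ddof=1, keepdims=True)
    return (tokens - mean) / (std + eps)


rng = np.random.default_rng(0)
img = rng.random((3, 64, 64))          # stand-in (C, H, W) galaxy image
grid = patchify(img, patch=16)         # (4, 4, 768)
order = spiral_order(*grid.shape[:2])  # centre patch first
tokens = normalise_patches(np.stack([grid[r, c] for r, c in order]))
positions = np.arange(len(tokens))     # analogue of "images_positions"
print(tokens.shape, order[:3])         # (16, 768) [(1, 1), (1, 2), (2, 2)]
```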
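One example of "cool stuff" you can do with `zs` is a cosine-similarity nearest-neighbour search, which finds the galaxies whose embeddings lie closest to a query galaxy. This is a minimal NumPy sketch and is not part of astroPT. `nearest_neighbours` is a hypothetical helper. The random 768-dimensional array stands in for real embeddings, whose width depends on the model size.

```python
import numpy as np


def nearest_neighbours(zs: np.ndarray, query_idx: int, k: int = 5) -> np.ndarray:
    """Hypothetical helper: indices of the k rows of zs with the highest
    cosine similarity to zs[query_idx], excluding the query itself."""
    unit = zs / (np.linalg.norm(zs, axis=1, keepdims=True) + 1e-12)
    sims = unit @ unit[query_idx]
    sims[query_idx] = -np.inf  # never return the query itself
    return np.argsort(-sims)[:k]


rng = np.random.default_rng(42)
zs = rng.normal(size=(1000, 768)).astype(np.float32)  # stand-in embeddings
zs[7] = zs[3] * 2.0  # galaxy 7 is a scaled copy of galaxy 3
print(nearest_neighbours(zs, query_idx=3, k=3)[0])  # 7
```

Normalising the rows first turns the dot product into a cosine similarity, so the ranking ignores embedding magnitude. For large catalogues, an approximate nearest-neighbour index would typically replace this brute-force matrix product.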
Updates and community
AstroPT is an open-to-all UniverseTBD project. Please join the UniverseTBD Discord for updates: https://discord.gg/MNEVegvfJq