# graphcast_amse
This repository contains the model checkpoints trained as part of Subich et al. (2025), which introduces an adjusted mean squared error (AMSE) loss function to eliminate the "double penalty" problem in the training of weather models. This loss function decomposes model error in spherical harmonic space and contains separate terms for amplitude and correlation errors, by total wavenumber.
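For intuition, the sketch below shows this kind of per-wavenumber amplitude/correlation split. It is a minimal illustration, not the paper's exact AMSE formulation: the coefficient layout, the `l_index` mapping, and the absence of the paper's per-wavenumber adjustments are all assumptions here.

```python
import numpy as np

def spectral_error_decomposition(coeffs_pred, coeffs_true, l_index):
    """Illustrative amplitude/correlation error split by total wavenumber.

    coeffs_pred, coeffs_true: complex spherical-harmonic coefficients of a
        predicted and a target field (from any transform, e.g. pyshtools).
    l_index: integer array giving the total wavenumber l of each coefficient.
    Returns per-wavenumber (amplitude_error, correlation_error) arrays.
    """
    l_max = int(l_index.max())
    amp_err = np.zeros(l_max + 1)
    corr_err = np.zeros(l_max + 1)
    for l in range(l_max + 1):
        p = coeffs_pred[l_index == l]
        t = coeffs_true[l_index == l]
        amp_p = np.sqrt(np.sum(np.abs(p) ** 2))  # spectral amplitude at wavenumber l
        amp_t = np.sqrt(np.sum(np.abs(t) ** 2))
        # Amplitude term: mismatch in spectral power
        amp_err[l] = (amp_p - amp_t) ** 2
        # Correlation term: misalignment of the normalized coefficients
        if amp_p > 0 and amp_t > 0:
            corr = np.real(np.sum(p * np.conj(t))) / (amp_p * amp_t)
            corr_err[l] = 2.0 * amp_p * amp_t * (1.0 - corr)
    return amp_err, corr_err
```

Summed over wavenumbers, `amp_err + corr_err` recovers the ordinary spectral MSE exactly; the paper's AMSE adjusts the balance between these terms, which is where the "adjusted" in the name comes from.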
These models are based on the `graphcast-operational` model trained by Google DeepMind, a ¼°, 13-level version of GraphCast trained on the ERA5 dataset and fine-tuned on the HRES initial conditions dataset (both available from WeatherBench 2). The checkpoints here were further fine-tuned on the HRES initial conditions dataset, with a batch size of 8 and the following training curriculum (cosine schedule, warmup of 512 samples or 64 batches); a sketch of one stage's schedule follows the table:
| Length | Batches | Peak LR | End LR |
|---|---|---|---|
| 1 step (6h) | 25,000 | 2.5e-5 | 1.25e-7 |
| 2 steps (12h) | 2,500 | 2.5e-6 | 7.5e-8 |
| 4 steps (24h) | 2,500 | 2.5e-6 | 7.5e-8 |
| 8 steps (48h) | 1,250 | 2.5e-6 | 7.5e-8 |
| 12 steps (72h) | 1,250 | 2.5e-6 | 7.5e-8 |
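As one plausible way to reproduce a stage of this curriculum, the snippet below builds the 1-step schedule with optax. The repository's training code may construct it differently; in particular, counting the warmup inside the total step budget and the choice of optimizer are assumptions.

```python
import optax

# Hypothetical reconstruction of the 1-step (6h) stage's schedule:
# cosine decay from a peak of 2.5e-5 to 1.25e-7 over 25,000 batches,
# with a linear warmup of 64 batches (512 samples at batch size 8).
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,       # warmup start value: assumed, not stated above
    peak_value=2.5e-5,    # peak LR for the 1-step stage
    warmup_steps=64,      # 512 samples / batch size 8
    decay_steps=25_000,   # total batches, warmup included (assumption)
    end_value=1.25e-7,    # end LR for the 1-step stage
)
optimizer = optax.adamw(learning_rate=schedule)  # optimizer choice assumed
```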
The model checkpoints are in the `params/ar{1,12}` directories: `ar1` contains the checkpoints from the end of the first (single-step) training stage, and `ar12` contains the final checkpoints. The models trained are (see the loading sketch after the list):
- `amse.ckpt` -- the primary model, trained with the AMSE loss function
- `mse.ckpt` -- a control model, trained with the ordinary MSE loss function
- `mae.ckpt` -- an ablation study, trained with the mean absolute error (MAE) loss function
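A minimal loading sketch, assuming these files use the same serialization as the upstream graphcast package (https://github.com/google-deepmind/graphcast) and its `checkpoint` module:

```python
from graphcast import checkpoint, graphcast

# Load the final AMSE checkpoint; the path follows the layout described above.
with open("params/ar12/amse.ckpt", "rb") as f:
    ckpt = checkpoint.load(f, graphcast.CheckPoint)

params = ckpt.params              # network weights
model_config = ckpt.model_config  # architecture settings (mesh size, latent dims, ...)
task_config = ckpt.task_config    # input/target variables and pressure levels
```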
The model training code is available separately on GitHub.
Because these models are based on the `graphcast-operational` checkpoint, they retain its CC BY-NC-SA 4.0 license.