Upload 3 files
Browse files- .gitattributes +1 -0
- README.md +55 -3
- config.yaml +202 -0
- model_architecture.png +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
model_architecture.png filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,3 +1,55 @@
|
|
| 1 |
-
---
|
| 2 |
-
license: apache-2.0
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
library_name: terratorch
|
| 4 |
+
tags:
|
| 5 |
+
- Pytorch
|
| 6 |
+
- Earth Observation
|
| 7 |
+
- Foundation Model
|
| 8 |
+
- IBM
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
# granite-geospatial-ocean
|
| 12 |
+
|
| 13 |
+
The granite-geospatial-ocean foundation model was jointly developed by IBM, STFC, Plymouth Marine Laboratory and University of Exeter. This pre-trained model supports a range of potential uses cases in ocean ecosystem health, fisheries management, pollution and other ocean processes that can be monitored using ocean colour observations. We provide an example to fine tune the model to quantify primary production by phytoplankton (carbon sequestration which determine's the ocean's role in climate change).
|
| 14 |
+
|
| 15 |
+
## Architecture Overview
|
| 16 |
+
|
| 17 |
+
The granite-geospatial-ocean model is a transformer-based geospatial foundation model trained on Sentinel-3 Ocean Land Colour Instrument (OLCI) and Sea and Land Surface Temperature Radiometer (SLSTR) images. The model consists of a self-supervised encoder developed with a ViT architecture and Masked AutoEncoder (MAE) learning strategy, with an MSE loss function and follows the same architecture as [Prithvi-EO](https://huggingface.co/collections/ibm-nasa-geospatial/prithvi-for-earth-observation-6740a7a81883466bf41d93d6).
|
| 18 |
+
|
| 19 |
+
We used a 42x42 image size and 16 bands of Level-2 sentinel-3 OLCI(OL1 to OL12, OL16, OL17, OL18 and OL21) and also a further band of Level-2 SLSTR sea surface temperature data were in the pre-training. In total of 512,000 images were used for pre-training.
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
<figure>
|
| 23 |
+
<img src='./model_architecture.png' alt='missing' />
|
| 24 |
+
<!-- <figcaption>Model architecture -->
|
| 25 |
+
</figcaption>
|
| 26 |
+
</figure>
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
## How to Get Started with the Model
|
| 30 |
+
|
| 31 |
+
We have provided an example of fine-tuning the model for primary production quantification which can be found [here](https://github.com/ibm-granite/geospatial/blob/main/granite-geospatial-ocean/notebooks/fine_tuning.ipynb). These examples make use of [TerraTorch](https://github.com/IBM/terratorch) for fine-tuning and prediction.
|
| 32 |
+
|
| 33 |
+
Example Notebooks:
|
| 34 |
+
|
| 35 |
+
[Primary Production Quantification](https://github.com/ibm-granite/geospatial/blob/main/granite-geospatial-ocean/notebooks/fine_tuning.ipynb) [<b><i>>>Try it on Colab<<</i></b>](https://colab.research.google.com/github/ibm-granite/geospatial/blob/main/granite-geospatial-ocean/notebooks/fine_tuning.ipynb) (Choose T4 GPU runtime)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
### Feedback
|
| 39 |
+
|
| 40 |
+
Your feedback is invaluable to us. If you have any feedback about the model, please feel free to share it with us. You can do this by starting a discussion in this HF repository or submitting an issue to [TerraTorch](https://github.com/IBM/terratorch) on GitHub.
|
| 41 |
+
|
| 42 |
+
### Model Card Authors
|
| 43 |
+
Geoffrey Dawson, Remy Vandaele, Andrew Taylor, David Moffat, Helen Tamura-Wicks, Sarah Jackson, Chunbo Luo, Paolo Fraccaro, Hywel Williams, Rosie Lickorish and Anne Jones
|
| 44 |
+
|
| 45 |
+
### IBM Public Repository Disclosure:
|
| 46 |
+
All content in this repository including code has been provided by IBM under the associated open source software license and IBM is under no obligation to provide enhancements, updates, or support. IBM developers produced this code as an open source project (not as an IBM product), and IBM makes no assertions as to the level of quality nor security, and will not be maintaining this code going forward.
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
<!-- ### Citation
|
| 50 |
+
|
| 51 |
+
If this model helped your research, please cite [Granite-ocean-gfm]() in your publications.
|
| 52 |
+
|
| 53 |
+
```
|
| 54 |
+
@article{}
|
| 55 |
+
``` -->
|
config.yaml
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# lightning.pytorch==2.1.1
|
| 2 |
+
seed_everything: 42
|
| 3 |
+
out_dtype: float32
|
| 4 |
+
custom_modules_path: ./../custom_modules/
|
| 5 |
+
### Trainer configuration
|
| 6 |
+
trainer:
|
| 7 |
+
accelerator: auto
|
| 8 |
+
strategy: auto
|
| 9 |
+
devices: auto
|
| 10 |
+
num_nodes: 1
|
| 11 |
+
# precision: 16-mixed
|
| 12 |
+
logger:
|
| 13 |
+
class_path: TensorBoardLogger
|
| 14 |
+
init_args:
|
| 15 |
+
save_dir: ./../data/
|
| 16 |
+
name: model_runs
|
| 17 |
+
callbacks:
|
| 18 |
+
- class_path: LearningRateMonitor
|
| 19 |
+
init_args:
|
| 20 |
+
logging_interval: epoch
|
| 21 |
+
- class_path: EarlyStopping
|
| 22 |
+
init_args:
|
| 23 |
+
monitor: val/loss
|
| 24 |
+
patience: 100
|
| 25 |
+
max_epochs: 1
|
| 26 |
+
check_val_every_n_epoch: 1
|
| 27 |
+
log_every_n_steps: 5
|
| 28 |
+
enable_checkpointing: true
|
| 29 |
+
default_root_dir: ./../data/
|
| 30 |
+
|
| 31 |
+
### Data configuration
|
| 32 |
+
data:
|
| 33 |
+
class_path: terratorch.datamodules.GenericNonGeoPixelwiseRegressionDataModule
|
| 34 |
+
init_args:
|
| 35 |
+
batch_size: 8
|
| 36 |
+
num_workers: 2
|
| 37 |
+
train_transform:
|
| 38 |
+
- class_path: albumentations.HorizontalFlip
|
| 39 |
+
init_args:
|
| 40 |
+
p: 0.5
|
| 41 |
+
- class_path: albumentations.RandomCrop
|
| 42 |
+
init_args:
|
| 43 |
+
height: 42
|
| 44 |
+
width: 42
|
| 45 |
+
- class_path: albumentations.Rotate
|
| 46 |
+
init_args:
|
| 47 |
+
limit: 30
|
| 48 |
+
border_mode: 0 # cv2.BORDER_CONSTANT
|
| 49 |
+
value: 0
|
| 50 |
+
# mask_value: 1
|
| 51 |
+
p: 0.5
|
| 52 |
+
- class_path: ToTensorV2
|
| 53 |
+
# Specify all bands which are in the input data.
|
| 54 |
+
# -1 are placeholders for bands that are in the data but that we will discard
|
| 55 |
+
dataset_bands:
|
| 56 |
+
- Oa01_reflectance
|
| 57 |
+
- Oa02_reflectance
|
| 58 |
+
- Oa03_reflectance
|
| 59 |
+
- Oa04_reflectance
|
| 60 |
+
- Oa05_reflectance
|
| 61 |
+
- Oa06_reflectance
|
| 62 |
+
- Oa07_reflectance
|
| 63 |
+
- Oa08_reflectance
|
| 64 |
+
- Oa09_reflectance
|
| 65 |
+
- Oa10_reflectance
|
| 66 |
+
- Oa11_reflectance
|
| 67 |
+
- Oa12_reflectance
|
| 68 |
+
- Oa16_reflectance
|
| 69 |
+
- Oa17_reflectance
|
| 70 |
+
- Oa18_reflectance
|
| 71 |
+
- Oa21_reflectance
|
| 72 |
+
- SST
|
| 73 |
+
output_bands: #Specify the bands which are used from the input data.
|
| 74 |
+
- Oa01_reflectance
|
| 75 |
+
- Oa02_reflectance
|
| 76 |
+
- Oa03_reflectance
|
| 77 |
+
- Oa04_reflectance
|
| 78 |
+
- Oa05_reflectance
|
| 79 |
+
- Oa06_reflectance
|
| 80 |
+
- Oa07_reflectance
|
| 81 |
+
- Oa08_reflectance
|
| 82 |
+
- Oa09_reflectance
|
| 83 |
+
- Oa10_reflectance
|
| 84 |
+
- Oa11_reflectance
|
| 85 |
+
- Oa12_reflectance
|
| 86 |
+
- Oa16_reflectance
|
| 87 |
+
- Oa17_reflectance
|
| 88 |
+
- Oa18_reflectance
|
| 89 |
+
- Oa21_reflectance
|
| 90 |
+
rgb_indices:
|
| 91 |
+
- 2
|
| 92 |
+
- 1
|
| 93 |
+
- 0
|
| 94 |
+
# Directory roots to training, validation and test datasplits:
|
| 95 |
+
test_data_root: ./../data/fine-tuning
|
| 96 |
+
test_label_data_root: ./../data/fine-tuning
|
| 97 |
+
test_split: ./../data/fine-tuning/test_data.txt
|
| 98 |
+
train_data_root: ./../data/fine-tuning
|
| 99 |
+
train_label_data_root: ./../data/fine-tuning
|
| 100 |
+
train_split: ./../data/fine-tuning/train_data.txt
|
| 101 |
+
val_data_root: ./../data/fine-tuning
|
| 102 |
+
val_label_data_root: ./../data/fine-tuning
|
| 103 |
+
val_split: ./../data/fine-tuning/val_data.txt
|
| 104 |
+
img_grep: "*_img.tif"
|
| 105 |
+
label_grep: "*_lab.tif"
|
| 106 |
+
means: # Mean value of the training dataset per band
|
| 107 |
+
- 11378.33724842
|
| 108 |
+
- 11379.51141294
|
| 109 |
+
- 11291.99698672
|
| 110 |
+
- 11116.38807044
|
| 111 |
+
- 10898.95680699
|
| 112 |
+
- 10686.41604621
|
| 113 |
+
- 10466.67864162
|
| 114 |
+
- 10456.52999209
|
| 115 |
+
- 10462.41327758
|
| 116 |
+
- 10464.24100298
|
| 117 |
+
- 10443.59591923
|
| 118 |
+
- 10448.53157824
|
| 119 |
+
- 10470.36129347
|
| 120 |
+
- 10454.74328843
|
| 121 |
+
- 10453.79858959
|
| 122 |
+
- 10452.88001737
|
| 123 |
+
stds: # Standard deviation of the training dataset per band
|
| 124 |
+
- 3125.36214152
|
| 125 |
+
- 3118.65965249
|
| 126 |
+
- 3088.88720386
|
| 127 |
+
- 3055.0881767
|
| 128 |
+
- 3026.73186213
|
| 129 |
+
- 2997.72812315
|
| 130 |
+
- 2968.12838628
|
| 131 |
+
- 2968.75857855
|
| 132 |
+
- 2969.94390514
|
| 133 |
+
- 2970.39202078
|
| 134 |
+
- 2964.1543642
|
| 135 |
+
- 2973.0155451
|
| 136 |
+
- 2985.89318262
|
| 137 |
+
- 2975.50852528
|
| 138 |
+
- 2973.00652761
|
| 139 |
+
- 2973.00330406
|
| 140 |
+
# Nodata value in label data
|
| 141 |
+
no_label_replace: -1
|
| 142 |
+
# Nodata value in the input data
|
| 143 |
+
no_data_replace: 0
|
| 144 |
+
### Model configuration
|
| 145 |
+
model:
|
| 146 |
+
class_path: terratorch.tasks.PixelwiseRegressionTask
|
| 147 |
+
init_args:
|
| 148 |
+
model_args:
|
| 149 |
+
backbone_pretrained: true
|
| 150 |
+
backbone: prithvi_s3_v1
|
| 151 |
+
backbone_pretrained_cfg_overlay:
|
| 152 |
+
file: ./../data/checkpoints/checkpoint.pt
|
| 153 |
+
backbone_pretrain_img_size: 42
|
| 154 |
+
backbone_drop_path: 0.1
|
| 155 |
+
backbone_bands:
|
| 156 |
+
- Oa01_reflectance
|
| 157 |
+
- Oa02_reflectance
|
| 158 |
+
- Oa03_reflectance
|
| 159 |
+
- Oa04_reflectance
|
| 160 |
+
- Oa05_reflectance
|
| 161 |
+
- Oa06_reflectance
|
| 162 |
+
- Oa07_reflectance
|
| 163 |
+
- Oa08_reflectance
|
| 164 |
+
- Oa09_reflectance
|
| 165 |
+
- Oa10_reflectance
|
| 166 |
+
- Oa11_reflectance
|
| 167 |
+
- Oa12_reflectance
|
| 168 |
+
- Oa16_reflectance
|
| 169 |
+
- Oa17_reflectance
|
| 170 |
+
- Oa18_reflectance
|
| 171 |
+
- Oa21_reflectance
|
| 172 |
+
head_dropout: 0.16194593880230534
|
| 173 |
+
head_channel_list: [64]
|
| 174 |
+
necks:
|
| 175 |
+
- name: SelectIndices
|
| 176 |
+
indices: [2, 5, 8, 11]
|
| 177 |
+
- name: ReshapeTokensToImage
|
| 178 |
+
- name: LearnedInterpolateToPyramidal
|
| 179 |
+
decoder: UNetDecoder
|
| 180 |
+
decoder_channels: [256, 128, 64, 32]
|
| 181 |
+
head_dropout: 0.1
|
| 182 |
+
loss: rmse
|
| 183 |
+
ignore_index: -1
|
| 184 |
+
freeze_backbone: false
|
| 185 |
+
freeze_decoder: false
|
| 186 |
+
model_factory: EncoderDecoderFactory
|
| 187 |
+
tiled_inference_parameters:
|
| 188 |
+
h_crop: 64
|
| 189 |
+
h_stride: 4
|
| 190 |
+
w_crop: 64
|
| 191 |
+
w_stride: 4
|
| 192 |
+
delta: 8
|
| 193 |
+
average_patches: true
|
| 194 |
+
optimizer:
|
| 195 |
+
class_path: torch.optim.AdamW
|
| 196 |
+
init_args:
|
| 197 |
+
lr: 0.00012
|
| 198 |
+
weight_decay: 0.3
|
| 199 |
+
lr_scheduler:
|
| 200 |
+
class_path: ReduceLROnPlateau
|
| 201 |
+
init_args:
|
| 202 |
+
monitor: val/loss
|
model_architecture.png
ADDED
|
Git LFS Details
|