Commit f491c9c
Parent(s): 60bda88

Added inference code

Files changed:
- .gitattributes +1 -0
- README.md +85 -0
- inference.py +344 -0
- requirements.txt +6 -0

.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.tif filter=lfs diff=lfs merge=lfs -text

README.md
CHANGED
@@ -1,3 +1,88 @@
---
license: apache-2.0
language:
- en
tags:
- Pytorch
- segmentation
- Flood mapping
- Sentinel-2
- Geospatial
- Foundation model
---

### Model and Inputs

The pretrained [Prithvi-EO-1.0-100m](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M/blob/main/README.md) model is finetuned to segment the extent of floods on Sentinel-2 images from the [Sen1Floods11 dataset](https://github.com/cloudtostreet/Sen1Floods11).

The dataset consists of 446 labeled 512x512 chips that span all 14 biomes, 357 ecoregions, and 6 continents of the world across 11 flood events. The benchmark associated with Sen1Floods11 provides results for fully convolutional neural networks trained in various input/labeled data setups, considering Sentinel-1 and Sentinel-2 imagery.

We extract the following bands for flood mapping:

1. Blue
2. Green
3. Red
4. Narrow NIR
5. SWIR 1
6. SWIR 2

Labels represent no water (class 0), water/flood (class 1), and no data/clouds (class -1).
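
As a rough illustration of these inputs (not part of this repository; file names are placeholders following the Sen1Floods11 naming convention), a chip and its hand-labeled mask could be inspected like this:

```python
# Illustrative sketch only: load a 6-band Sentinel-2 chip and its label mask with rasterio.
import rasterio

BANDS = ["Blue", "Green", "Red", "Narrow NIR", "SWIR 1", "SWIR 2"]

with rasterio.open("India_900498_S2Hand.tif") as src:
    image = src.read()  # (6, 512, 512) array, one plane per band in the order above

with rasterio.open("India_900498_LabelHand.tif") as src:
    labels = src.read(1)  # (512, 512) array with values in {-1, 0, 1}

valid = labels != -1  # ignore no data / cloud pixels
water_fraction = (labels[valid] == 1).mean()
print(f"Bands: {len(BANDS)}, water/flood fraction over valid pixels: {water_fraction:.2%}")
```
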
The Prithvi-100m model was initially pretrained using a sequence length of 3 timesteps. Based on the characteristics of this benchmark dataset, we focus on single-timestamp segmentation. This demonstrates that our model can be used with an arbitrary number of timesteps during finetuning.

![](sen1floods11-finetuning.png)
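
Because only a single timestamp is used here, the input tensor carries a time dimension of length 1. A shape-only sketch (random values, purely illustrative):

```python
# Illustrative only: single-timestamp input layout (batch, bands, time, height, width)
# matching the 512x512, 6-band Sen1Floods11 chips, with time = 1.
import torch

dummy_input = torch.rand(1, 6, 1, 512, 512)  # batch=1, bands=6, timesteps=1
print(dummy_input.shape)  # torch.Size([1, 6, 1, 512, 512])
```
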
### Code

The code for this finetuning is available on [GitHub](https://github.com/NASA-IMPACT/hls-foundation-os/).

The configuration used for finetuning is available in this [config](https://github.com/NASA-IMPACT/hls-foundation-os/blob/main/fine-tuning-examples/configs/sen1floods11.py).

### Results

Finetuning the geospatial foundation model for 100 epochs leads to the following performance on the test dataset:

| **Classes**  | **IoU** | **Acc** |
|:------------:|:-------:|:-------:|
| No water     | 96.90%  | 98.11%  |
| Water/Flood  | 80.46%  | 90.54%  |

| **aAcc** | **mIoU** | **mAcc** |
|:--------:|:--------:|:--------:|
| 97.25%   | 88.68%   | 94.37%   |
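
Here, aAcc is the overall pixel accuracy, while mIoU and mAcc average the per-class IoU and accuracy (e.g., mIoU = (96.90% + 80.46%) / 2 = 88.68%). A minimal sketch of how such metrics can be derived from predictions and labels (this helper is illustrative, not part of the released code):

```python
# Illustrative only: per-class IoU/accuracy and aggregated aAcc/mIoU/mAcc for a binary
# segmentation map, ignoring pixels labeled -1 (no data / clouds).
import numpy as np

def segmentation_metrics(pred: np.ndarray, label: np.ndarray, ignore_index: int = -1) -> dict:
    valid = label != ignore_index
    pred, label = pred[valid], label[valid]
    ious, accs = [], []
    for cls in (0, 1):
        tp = np.sum((pred == cls) & (label == cls))
        fp = np.sum((pred == cls) & (label != cls))
        fn = np.sum((pred != cls) & (label == cls))
        ious.append(tp / (tp + fp + fn))
        accs.append(tp / (tp + fn))  # per-class accuracy (recall for that class)
    return {
        "IoU": ious,
        "Acc": accs,
        "aAcc": float(np.mean(pred == label)),  # overall pixel accuracy
        "mIoU": float(np.mean(ious)),
        "mAcc": float(np.mean(accs)),
    }
```
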
The performance of the model has been further validated on an unseen, holdout flood event in Bolivia. The results are consistent with the performance on the test set:

| **Classes**  | **IoU** | **Acc** |
|:------------:|:-------:|:-------:|
| No water     | 95.37%  | 97.39%  |
| Water/Flood  | 77.95%  | 88.74%  |

| **aAcc** | **mIoU** | **mAcc** |
|:--------:|:--------:|:--------:|
| 96.02%   | 86.66%   | 93.07%   |

Finetuning took ~1 hour on an NVIDIA V100.

### Inference

The GitHub repo includes an inference script for running the flood mapping model on Sentinel-2 images. Inputs must be GeoTIFFs containing the six bands for a single timestep described above (Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2), in that order. There is also a **demo** that leverages the same code **[here](https://huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-100M-sen1floods11-demo)**.
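
As a usage sketch, the `inference.py` added in this commit can be driven from Python roughly as follows; the paths and checkpoint name below are simply the script's argparse defaults and should be adjusted to your setup:

```python
# Minimal usage sketch for inference.py (paths/checkpoint below are the script defaults).
from inference import main

main(
    data_file="examples/India_900498_S2Hand.tif",         # 6-band Sentinel-2 GeoTIFF
    config="config.yaml",                                  # model and data configuration
    checkpoint="Prithvi-EO-V2-300M-TL-Sen1Floods11.ckpt",  # finetuned weights
    output_dir="output",                                   # predictions are written here as GeoTIFFs
    rgb_outputs=True,                                      # also save the RGB composite of the input
    input_indices=[1, 2, 3, 8, 11, 12],                    # positions of Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2
)
```
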
### Feedback

Your feedback is invaluable to us. If you have any feedback about the model, please share it with us by submitting issues on our open-source repository, [hls-foundation-os](https://github.com/NASA-IMPACT/hls-foundation-os/issues), on GitHub.

### Citation

If this model helped your research, please cite our model in your publications. Here is an example BibTeX entry:

```
@misc{Prithvi-100M-flood-mapping,
    author          = {Jakubik, Johannes and Fraccaro, Paolo and Oliveira Borges, Dario and Muszynski, Michal and Weldemariam, Kommy and Zadrozny, Bianca and Ganti, Raghu and Mukkavilli, Karthik},
    month           = aug,
    doi             = {10.57967/hf/0973},
    title           = {{Prithvi 100M flood mapping}},
    repository-code = {https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-sen1floods11},
    year            = {2023}
}
```

inference.py
ADDED
@@ -0,0 +1,344 @@
"""Run flood-mapping inference with a finetuned Prithvi model on Sentinel-2 GeoTIFFs."""

import argparse
import os
from typing import List, Union
import re
import datetime
import numpy as np
import rasterio
import torch
import yaml
from einops import rearrange
from terratorch.cli_tools import LightningInferenceModel

NO_DATA = -9999
NO_DATA_FLOAT = 0.0001
OFFSET = 0
PERCENTILE = 99


def process_channel_group(orig_img, channels):
    """
    Args:
        orig_img: torch.Tensor representing original image (reference) with shape = (bands, H, W).
        channels: list of indices representing RGB channels.

    Returns:
        torch.Tensor with shape (num_channels, height, width) for original image
    """

    orig_img = orig_img[channels, ...]
    valid_mask = torch.ones_like(orig_img, dtype=torch.bool)
    valid_mask[orig_img == NO_DATA_FLOAT] = False

    # Rescale (enhancing contrast)
    max_value = max(3000, np.percentile(orig_img[valid_mask], PERCENTILE))
    min_value = OFFSET

    orig_img = torch.clamp((orig_img - min_value) / (max_value - min_value), 0, 1)

    # No data as zeros
    orig_img[~valid_mask] = 0

    return orig_img


def read_geotiff(file_path: str):
    """Read all bands from *file_path* and return image + meta info.

    Args:
        file_path: path to image file.

    Returns:
        np.ndarray with shape (bands, height, width)
        meta info dict
        (lon, lat) coordinates of the image, or None if they cannot be read
    """

    with rasterio.open(file_path) as src:
        img = src.read()
        meta = src.meta
        try:
            coords = src.lnglat()
        except:
            # Cannot read coords
            coords = None

    return img, meta, coords


def save_geotiff(image, output_path: str, meta: dict):
    """Save multi-band image in Geotiff file.

    Args:
        image: np.ndarray with shape (bands, height, width)
        output_path: path where to save the image
        meta: dict with meta info.
    """

    with rasterio.open(output_path, "w", **meta) as dest:
        for i in range(image.shape[0]):
            dest.write(image[i, :, :], i + 1)

    return


def _convert_np_uint8(float_image: torch.Tensor):
    image = float_image.numpy() * 255.0
    image = image.astype(dtype=np.uint8)

    return image


def load_example(
    file_paths: List[str],
    mean: List[float] = None,
    std: List[float] = None,
    indices: Union[list[int], None] = None,
):
    """Build an input example by loading images in *file_paths*.

    Args:
        file_paths: list of file paths.
        mean: list containing mean values for each band in the images in *file_paths*.
        std: list containing std values for each band in the images in *file_paths*.
        indices: band indices to select from each image (optional).

    Returns:
        np.array containing created example
        list of temporal coords parsed from the file names
        list of location coords read from the image metadata
        list of meta info for each image in *file_paths*
    """

    imgs = []
    metas = []
    temporal_coords = []
    location_coords = []

    for file in file_paths:
        img, meta, coords = read_geotiff(file)

        # Rescaling (don't normalize on nodata)
        img = np.moveaxis(img, 0, -1)  # channels last for rescaling
        if indices is not None:
            img = img[..., indices]
        if mean is not None and std is not None:
            img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)

        imgs.append(img)
        metas.append(meta)
        if coords is not None:
            location_coords.append(coords)

        try:
            match = re.search(r'(\d{7,8}T\d{6})', file)
            if match:
                year = int(match.group(1)[:4])
                julian_day = match.group(1).split('T')[0][4:]
                if len(julian_day) == 3:
                    julian_day = int(julian_day)
                else:
                    julian_day = datetime.datetime.strptime(julian_day, '%m%d').timetuple().tm_yday
                temporal_coords.append([year, julian_day])
        except Exception as e:
            print(f'Could not extract timestamp for {file} ({e})')

    imgs = np.stack(imgs, axis=0)  # num_frames, H, W, C
    imgs = np.moveaxis(imgs, -1, 0).astype("float32")  # C, num_frames, H, W
    imgs = np.expand_dims(imgs, axis=0)  # add batch dim

    return imgs, temporal_coords, location_coords, metas


def run_model(input_data, temporal_coords, location_coords, model, datamodule, img_size):
    # Reflect pad if not divisible by img_size
    original_h, original_w = input_data.shape[-2:]
    pad_h = (img_size - (original_h % img_size)) % img_size
    pad_w = (img_size - (original_w % img_size)) % img_size
    input_data = np.pad(
        input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode="reflect"
    )

    # Build sliding window of non-overlapping img_size x img_size tiles
    batch_size = 1
    batch = torch.tensor(input_data, device="cpu")
    windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
    h1, w1 = windows.shape[3:5]
    windows = rearrange(
        windows, "b c t h1 w1 h w -> (b h1 w1) c t h w", h=img_size, w=img_size
    )

    # Split into batches if number of windows > batch_size
    num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
    windows = torch.tensor_split(windows, num_batches, dim=0)

    if temporal_coords:
        temporal_coords = torch.Tensor(temporal_coords, device=model.device).unsqueeze(0)
    else:
        temporal_coords = None
    if location_coords:
        location_coords = torch.Tensor(location_coords[0], device=model.device).unsqueeze(0)
    else:
        location_coords = None

    # Run model
    pred_imgs = []
    for x in windows:
        # Apply standardization
        x = datamodule.test_transform(image=x.squeeze().numpy().transpose(1, 2, 0))
        x = datamodule.aug(x)['image']

        with torch.no_grad():
            x = x.to(model.device)
            pred = model(x, temporal_coords=temporal_coords, location_coords=location_coords)
            pred = pred.output.detach().cpu()

        y_hat = pred.argmax(dim=1)

        y_hat = torch.nn.functional.interpolate(y_hat.unsqueeze(1).float(), size=img_size, mode="nearest")

        pred_imgs.append(y_hat)

    pred_imgs = torch.concat(pred_imgs, dim=0)

    # Build images from patches
    pred_imgs = rearrange(
        pred_imgs,
        "(b h1 w1) c h w -> b c (h1 h) (w1 w)",
        h=img_size,
        w=img_size,
        b=1,
        c=1,
        h1=h1,
        w1=w1,
    )

    # Cut padded area back to original size
    pred_imgs = pred_imgs[..., :original_h, :original_w]

    # Squeeze (batch size 1)
    pred_imgs = pred_imgs[0]

    return pred_imgs


def main(
    data_file: str,
    config: str,
    checkpoint: str,
    output_dir: str,
    rgb_outputs: bool,
    input_indices: list[int] = None,
):
    os.makedirs(output_dir, exist_ok=True)

    with open(config, "r") as f:
        config_dict = yaml.safe_load(f)

    # Load model ---------------------------------------------------------------------------------

    lightning_model = LightningInferenceModel.from_config(config, checkpoint)
    img_size = 512  # Size of Sen1Floods11 chips

    # Loading data -------------------------------------------------------------------------------

    input_data, temporal_coords, location_coords, meta_data = load_example(
        file_paths=[data_file], indices=input_indices,
    )

    meta_data = meta_data[0]  # only one image

    if input_data.mean() > 1:
        input_data = input_data / 10000  # Convert to range 0-1

    # Running model ------------------------------------------------------------------------------

    lightning_model.model.eval()

    channels = [config_dict['data']['init_args']['bands'].index(b) for b in ["RED", "GREEN", "BLUE"]]  # BGR -> RGB

    pred = run_model(input_data, temporal_coords, location_coords,
                     lightning_model.model, lightning_model.datamodule, img_size)

    # Save pred
    meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0)
    pred_file = os.path.join(output_dir, f"pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff")
    save_geotiff(_convert_np_uint8(pred), pred_file, meta_data)

    # Save image + pred
    meta_data.update(count=3, dtype="uint8", compress="lzw", nodata=0)

    if input_data.mean() < 1:
        input_data = input_data * 10000  # Scale to 0-10000

    rgb_orig = process_channel_group(
        orig_img=torch.Tensor(input_data[0, :, 0, ...]),
        channels=channels,
    )

    pred[pred == 0.] = np.nan
    img_pred = rgb_orig * 0.7 + pred * 0.3
    img_pred[img_pred.isnan()] = rgb_orig[img_pred.isnan()]

    img_pred_file = os.path.join(output_dir, f"rgb_pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff")
    save_geotiff(
        image=_convert_np_uint8(img_pred),
        output_path=img_pred_file,
        meta=meta_data,
    )

    # Save image rgb
    if rgb_outputs:
        rgb_file = os.path.join(output_dir, f"original_rgb_{os.path.splitext(os.path.basename(data_file))[0]}.tiff")
        save_geotiff(
            image=_convert_np_uint8(rgb_orig),
            output_path=rgb_file,
            meta=meta_data,
        )

    print("Done!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser("MAE run inference", add_help=False)

    parser.add_argument(
        "--data_file",
        type=str,
        default="examples/India_900498_S2Hand.tif",
        help="Path to the input file.",
    )
    parser.add_argument(
        "--config",
        "-c",
        type=str,
        default="config.yaml",
        help="Path to yaml file containing model parameters.",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="Prithvi-EO-V2-300M-TL-Sen1Floods11.ckpt",
        help="Path to a checkpoint file to load from.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="Path to the directory where to save outputs.",
    )
    parser.add_argument(
        "--input_indices",
        default=[1, 2, 3, 8, 11, 12],
        type=int,
        nargs="+",
        help="0-based indices of the six Prithvi channels to be selected from the input. By default selects [1,2,3,8,11,12] for S2L1C data.",
    )
    parser.add_argument(
        "--rgb_outputs",
        action="store_true",
        help="If present, output files will only contain RGB channels. "
        "Otherwise, all bands will be saved.",
    )
    args = parser.parse_args()

    main(**vars(args))

requirements.txt
ADDED
@@ -0,0 +1,6 @@
torch
torchvision
timm
einops
rasterio
git+https://github.com/IBM/terratorch.git