Commit
Β·
dfb5c47
1
Parent(s):
12f6482
Fix app
Browse files- app.py +1 -43
- requirements.txt +1 -1
app.py
CHANGED
|
@@ -22,48 +22,6 @@ os.system(f'cp {model_inference} .')
|
|
| 22 |
|
| 23 |
from inference import process_channel_group, _convert_np_uint8, load_example, run_model
|
| 24 |
|
| 25 |
-
def extract_rgb_imgs(input_img, pred_img, channels):
|
| 26 |
-
""" Wrapper function to save Geotiff images (original, reconstructed, masked) per timestamp.
|
| 27 |
-
Args:
|
| 28 |
-
input_img: input torch.Tensor with shape (C, H, W).
|
| 29 |
-
rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
|
| 30 |
-
pred_img: mask torch.Tensor with shape (C, T, H, W).
|
| 31 |
-
channels: list of indices representing RGB channels.
|
| 32 |
-
mean: list of mean values for each band.
|
| 33 |
-
std: list of std values for each band.
|
| 34 |
-
output_dir: directory where to save outputs.
|
| 35 |
-
meta_data: list of dicts with geotiff meta info.
|
| 36 |
-
"""
|
| 37 |
-
rgb_orig_list = []
|
| 38 |
-
rgb_mask_list = []
|
| 39 |
-
rgb_pred_list = []
|
| 40 |
-
|
| 41 |
-
for t in range(input_img.shape[1]):
|
| 42 |
-
rgb_orig, rgb_pred = process_channel_group(orig_img=input_img[:, t, :, :],
|
| 43 |
-
new_img=rec_img[:, t, :, :],
|
| 44 |
-
channels=channels,
|
| 45 |
-
mean=mean,
|
| 46 |
-
std=std)
|
| 47 |
-
|
| 48 |
-
rgb_mask = mask_img[channels, t, :, :] * rgb_orig
|
| 49 |
-
|
| 50 |
-
# extract images
|
| 51 |
-
rgb_orig_list.append(_convert_np_uint8(rgb_orig).transpose(1, 2, 0))
|
| 52 |
-
rgb_mask_list.append(_convert_np_uint8(rgb_mask).transpose(1, 2, 0))
|
| 53 |
-
rgb_pred_list.append(_convert_np_uint8(rgb_pred).transpose(1, 2, 0))
|
| 54 |
-
|
| 55 |
-
# Add white dummy image values for missing timestamps
|
| 56 |
-
dummy = np.ones((20, 20), dtype=np.uint8) * 255
|
| 57 |
-
num_dummies = 4 - len(rgb_orig_list)
|
| 58 |
-
if num_dummies:
|
| 59 |
-
rgb_orig_list.extend([dummy] * num_dummies)
|
| 60 |
-
rgb_mask_list.extend([dummy] * num_dummies)
|
| 61 |
-
rgb_pred_list.extend([dummy] * num_dummies)
|
| 62 |
-
|
| 63 |
-
outputs = rgb_orig_list + rgb_mask_list + rgb_pred_list
|
| 64 |
-
|
| 65 |
-
return outputs
|
| 66 |
-
|
| 67 |
|
| 68 |
def predict_on_images(data_file: str | Path, config_path: str, checkpoint: str):
|
| 69 |
try:
|
|
@@ -81,7 +39,7 @@ def predict_on_images(data_file: str | Path, config_path: str, checkpoint: str):
|
|
| 81 |
# Load model ---------------------------------------------------------------------------------
|
| 82 |
|
| 83 |
lightning_model = LightningInferenceModel.from_config(config_path, checkpoint)
|
| 84 |
-
img_size =
|
| 85 |
|
| 86 |
# Loading data ---------------------------------------------------------------------------------
|
| 87 |
|
|
|
|
| 22 |
|
| 23 |
from inference import process_channel_group, _convert_np_uint8, load_example, run_model
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
def predict_on_images(data_file: str | Path, config_path: str, checkpoint: str):
|
| 27 |
try:
|
|
|
|
| 39 |
# Load model ---------------------------------------------------------------------------------
|
| 40 |
|
| 41 |
lightning_model = LightningInferenceModel.from_config(config_path, checkpoint)
|
| 42 |
+
img_size = 512 # Size from Sen1Floods11 training
|
| 43 |
|
| 44 |
# Loading data ---------------------------------------------------------------------------------
|
| 45 |
|
requirements.txt
CHANGED
|
@@ -5,4 +5,4 @@ rasterio
|
|
| 5 |
einops
|
| 6 |
huggingface_hub
|
| 7 |
gradio
|
| 8 |
-
|
|
|
|
| 5 |
einops
|
| 6 |
huggingface_hub
|
| 7 |
gradio
|
| 8 |
+
terratorch==1.0.2
|