qq1990 commited on
Commit
100edb4
·
1 Parent(s): ed684df
FengWu ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 3bdf35bd6a84600e95c8d534fec727c69e4e7982
Pangu-Weather ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 72bdd99096721e1a1f8912c37a9a3aff9ff0a4f2
Prithvi.py ADDED
@@ -0,0 +1,2505 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import random
4
+ import numpy as np
5
+ import yaml
6
+ from pathlib import Path
7
+ from io import BytesIO
8
+ import random
9
+ from pathlib import Path
10
+ import matplotlib.pyplot as plt
11
+ import numpy as np
12
+ import torch
13
+ from huggingface_hub import hf_hub_download, snapshot_download
14
+ import tempfile
15
+ import traceback
16
+ import functools as ft
17
+ import os
18
+ import random
19
+ import re
20
+ from collections import defaultdict
21
+ from datetime import datetime, timedelta
22
+ from pathlib import Path
23
+ import h5py
24
+ import numpy as np
25
+ import pandas as pd
26
+ import torch
27
+ from torch import Tensor
28
+ from torch.utils.data import Dataset
29
+ import logging
30
+ from Prithvi import *
31
+
32
def preproc(batch: list[dict], padding: dict[tuple[int]]) -> dict[str, Tensor]:
    """Collate a list of MERRA2 samples into padded batch tensors.

    Args:
        batch: List of samples. Every sample is a dictionary with the keys::

            'sur_static': array of shape (3, lat, lon) holding sin(lat), cos(lon), sin(lon).
            'sur_vals': tensor of shape (parameter, time, lat, lon).
            'sur_tars': tensor of shape (parameter, time, lat, lon).
            'ulv_vals': tensor of shape (parameter, level, time, lat, lon).
            'ulv_tars': tensor of shape (parameter, level, time, lat, lon).
            'sur_climate': (optional) tensor of shape (parameter, lat, lon).
            'ulv_climate': (optional) tensor of shape (parameter, level, lat, lon).
            'lead_time': scalar.
            'input_time': scalar.

            Only the keys of the first sample are validated; its shapes size
            the batch buffers.
        padding: Dictionary with keys 'level', 'lat', 'lon'; each value is a
            pair (amount before, amount after) of zero padding.

    Returns:
        Dictionary with the keys::

            'lead_time': [batch]
            'input_time': [batch]
            'static': [batch, parameter, lat, lon]
            'x': [batch, time, parameter, lat, lon]
            'y': [batch, parameter, lat, lon]
            'climate' (only if both climate keys are present): [batch, parameter, lat, lon]

        In 'x', 'y' and 'climate' the parameter axis is the surface parameters
        followed by the flattened (upper parameter x level) axis.

    Raises:
        ValueError: If the first sample lacks an essential key or contains a
            key that is neither essential nor a climate key.
    """  # noqa: E501
    first = batch[0]
    n_samples = len(batch)
    present = set(first.keys())

    required = {
        "sur_static",
        "sur_vals",
        "sur_tars",
        "ulv_vals",
        "ulv_tars",
        "input_time",
        "lead_time",
    }
    optional = {"sur_climate", "ulv_climate"}

    if not required <= present:
        raise ValueError("Missing essential keys.")
    if not present <= (required | optional):
        raise ValueError("Unexpected keys in batch.")

    def collect(key: str) -> Tensor:
        # One buffer per key (default dtype), filled sample by sample.
        stacked = torch.empty((n_samples, *first[key].shape))
        for row, sample in enumerate(batch):
            stacked[row] = sample[key]
        return stacked

    def collect_scalar(key: str) -> Tensor:
        values = torch.empty((n_samples,), dtype=torch.float32)
        for row, sample in enumerate(batch):
            values[row] = sample[key]
        return values

    # torch's pad() expects the padding of the last axis first.
    surface_padding = (*padding["lon"], *padding["lat"])
    upper_padding = (*padding["lon"], *padding["lat"], *padding["level"])

    def pad_surface(tensor: Tensor) -> Tensor:
        return torch.nn.functional.pad(
            tensor, surface_padding, mode="constant", value=0
        )

    def pad_upper(tensor: Tensor) -> Tensor:
        return torch.nn.functional.pad(
            tensor, upper_padding, mode="constant", value=0
        )

    # (batch, parameter, time, lat, lon) -> (batch, time, parameter, lat, lon)
    surface_in = pad_surface(
        collect("sur_vals").permute(0, 2, 1, 3, 4)
    ).contiguous()
    surface_out = pad_surface(
        collect("sur_tars").permute(0, 2, 1, 3, 4)
    ).contiguous()

    # (batch, parameter, level, time, lat, lon) ->
    # (batch, time, parameter, level, lat, lon)
    upper_in = pad_upper(
        collect("ulv_vals").permute(0, 3, 1, 2, 4, 5)
    ).contiguous()
    upper_out = pad_upper(
        collect("ulv_tars").permute(0, 3, 1, 2, 4, 5)
    ).contiguous()

    result = {
        "lead_time": collect_scalar("lead_time"),
        "input_time": collect_scalar("input_time"),
        "static": pad_surface(collect("sur_static")).contiguous(),
    }

    # Targets carry a single time step; drop that axis.
    upper_out = upper_out.squeeze(1)
    surface_out = surface_out.squeeze(1)

    # Merge (parameter, level) into one axis and append it to the surface
    # parameters.
    n_batch, n_time = upper_in.shape[:2]
    upper_in_flat = upper_in.view(n_batch, n_time, -1, *upper_in.shape[4:])
    result["x"] = torch.cat((surface_in, upper_in_flat), dim=2)

    upper_out_flat = upper_out.view(
        upper_out.shape[0], -1, *upper_out.shape[3:]
    )
    result["y"] = torch.cat((surface_out, upper_out_flat), dim=1)

    if optional <= present:
        surface_clim = pad_surface(collect("sur_climate"))
        upper_clim = pad_upper(collect("ulv_climate"))
        upper_clim_flat = upper_clim.view(
            n_samples, -1, *upper_clim.shape[3:]
        )
        result["climate"] = torch.cat((surface_clim, upper_clim_flat), dim=1)

    return result
179
+
180
+
181
def input_scalers(
    surf_vars: list[str],
    vert_vars: list[str],
    levels: list[float],
    surf_path: str | Path,
    vert_path: str | Path,
) -> tuple[Tensor, Tensor]:
    """Load the per-channel input mean and sigma from two scaler files.

    Args:
        surf_vars: surface variables to be used.
        vert_vars: vertical variables to be used.
        levels: levels to use; each must match exactly one entry of the
            ``lev`` dataset in the vertical file.
        surf_path: path to surface scalers file.
        vert_path: path to vertical level scalers file.

    Returns:
        mu (Tensor): float32 means, surface variables first, then the
            flattened (vertical variable, level) entries.
        sig (Tensor): float32 sigmas in the same order, clamped to
            [1e-4, 1e4].
    """

    def stat_rows(handle) -> tuple[int, int]:
        # The "statistic" dataset names the rows of every variable dataset.
        names = [raw.decode().lower() for raw in handle["statistic"][()]]
        return names.index("mu"), names.index("sigma")

    with h5py.File(Path(surf_path), "r", libver="latest") as surf_file:
        mean_row, std_row = stat_rows(surf_file)
        surf_mean = torch.tensor(
            [surf_file[name][()][mean_row] for name in surf_vars]
        )
        surf_std = torch.tensor(
            [surf_file[name][()][std_row] for name in surf_vars]
        )

    with h5py.File(Path(vert_path), "r", libver="latest") as vert_file:
        mean_row, std_row = stat_rows(vert_file)

        file_levels = vert_file["lev"][()]
        # .item() raises unless a level occurs exactly once in the file.
        level_rows = [
            np.where(file_levels == level)[0].item() for level in levels
        ]

        vert_mean = np.array(
            [vert_file[name][()][mean_row, level_rows] for name in vert_vars]
        )
        vert_std = np.array(
            [vert_file[name][()][std_row, level_rows] for name in vert_vars]
        )

    vert_mean_flat = torch.from_numpy(vert_mean).reshape(-1)
    vert_std_flat = torch.from_numpy(vert_std).reshape(-1)

    mu = torch.cat((surf_mean, vert_mean_flat), dim=0).to(torch.float32)
    sig = torch.cat((surf_std, vert_std_flat), dim=0).to(torch.float32)
    return mu, torch.clamp(sig, min=1e-4, max=1e4)
226
+
227
+
228
def static_input_scalers(
    scalar_path: str | Path, stat_vars: list[str], unscaled_params: int = 7
) -> tuple[Tensor, Tensor]:
    """Load mean and sigma for the static inputs from a scaler file.

    Args:
        scalar_path: path to the static scalers file.
        stat_vars: static variables to read from the file.
        unscaled_params: number of leading channels that get mean 0 and
            sigma 1, i.e. are left unscaled.

    Returns:
        mu (Tensor): float32 means: `unscaled_params` zeros followed by one
            entry per variable in `stat_vars`.
        sig (Tensor): float32 sigmas: `unscaled_params` ones followed by one
            entry per variable in `stat_vars`, clamped to [1e-4, 1e4].
    """
    with h5py.File(Path(scalar_path), "r", libver="latest") as scaler_file:
        # The "statistic" dataset names the rows of every variable dataset.
        row_names = [
            raw.decode().lower() for raw in scaler_file["statistic"][()]
        ]
        mean_row = row_names.index("mu")
        std_row = row_names.index("sigma")

        file_mean = torch.tensor(
            [scaler_file[name][()][mean_row] for name in stat_vars]
        )
        file_std = torch.tensor(
            [scaler_file[name][()][std_row] for name in stat_vars]
        )

    # new_zeros/new_ones inherit dtype and device from the source tensor.
    identity_mean = file_mean.new_zeros(unscaled_params)
    identity_std = file_std.new_ones(unscaled_params)

    mu = torch.cat((identity_mean, file_mean), dim=0).to(torch.float32)
    sig = torch.cat((identity_std, file_std), dim=0).to(torch.float32)

    return mu, torch.clamp(sig, min=1e-4, max=1e4)
247
+
248
+
249
def output_scalers(
    surf_vars: list[str],
    vert_vars: list[str],
    levels: list[float],
    surf_path: str | Path,
    vert_path: str | Path,
) -> Tensor:
    """Load the per-channel output scalers from two scaler files.

    Args:
        surf_vars: surface variables to be used.
        vert_vars: vertical variables to be used.
        levels: levels to use; each must match exactly one entry of the
            ``lev`` dataset in the vertical file.
        surf_path: path to surface scalers file.
        vert_path: path to vertical level scalers file.

    Returns:
        Tensor: float32 values, surface variables first, then the flattened
            (vertical variable, level) entries, clamped to [1e-7, 1e7].
    """
    with h5py.File(Path(surf_path), "r", libver="latest") as surf_file:
        surface_values = torch.tensor(
            [surf_file[name][()] for name in surf_vars]
        )

    with h5py.File(Path(vert_path), "r", libver="latest") as vert_file:
        file_levels = vert_file["lev"][()]
        # .item() raises unless a level occurs exactly once in the file.
        level_rows = [
            np.where(file_levels == level)[0].item() for level in levels
        ]
        vertical_values = np.array(
            [vert_file[name][()][level_rows] for name in vert_vars]
        )

    vertical_flat = torch.from_numpy(vertical_values).reshape(-1)
    combined = torch.cat((surface_values, vertical_flat), dim=0)

    return torch.clamp(combined.to(torch.float32), min=1e-7, max=1e7)
271
+
272
+
273
+ class SampleSpec:
274
+ """
275
+ A data class to collect the information used to define a sample.
276
+ """
277
+
278
+ def __init__(
279
+ self,
280
+ inputs: tuple[pd.Timestamp, pd.Timestamp],
281
+ lead_time: int,
282
+ target: pd.Timestamp | list[pd.Timestamp],
283
+ ):
284
+ """
285
+ Args:
286
+ inputs: Tuple of timestamps. In ascending order.
287
+ lead_time: Lead time. In hours.
288
+ target: Timestamp of the target. Can be before or after the inputs.
289
+ """
290
+ if not inputs[0] < inputs[1]:
291
+ raise ValueError(
292
+ "Timestamps in `inputs` should be in strictly ascending order."
293
+ )
294
+
295
+ self.inputs = inputs
296
+ self.input_time = (inputs[1] - inputs[0]).total_seconds() / 3600
297
+ self.lead_time = lead_time
298
+ self.target = target
299
+
300
+ self.times = [*inputs, target]
301
+ self.stat_times = [inputs[-1]]
302
+
303
+ @property
304
+ def climatology_info(self) -> tuple[int, int]:
305
+ """Get the required climatology info.
306
+
307
+ :return: information required to obtain climatology data. Essentially
308
+ this is the day of the year and hour of the day of the target
309
+ timestamp, with the former restricted to the interval [1, 365].
310
+ :rtype: tuple
311
+ """
312
+ return (min(self.target.dayofyear, 365), self.target.hour)
313
+
314
+ @property
315
+ def year(self) -> int:
316
+ return self.inputs[1].year
317
+
318
+ @property
319
+ def dayofyear(self) -> int:
320
+ return self.inputs[1].dayofyear
321
+
322
+ @property
323
+ def hourofday(self) -> int:
324
+ return self.inputs[1].hour
325
+
326
+ def _info_str(self) -> str:
327
+ iso_8601 = "%Y-%m-%dT%H:%M:%S"
328
+
329
+ return (
330
+ f"Issue time: {self.inputs[1].strftime(iso_8601)}\n"
331
+ f"Lead time: {self.lead_time} hours ahead\n"
332
+ f"Input delta: {self.input_time} hours\n"
333
+ f"Target time: {self.target.strftime(iso_8601)}"
334
+ )
335
+
336
+ @classmethod
337
+ def get(cls, timestamp: pd.Timestamp, dt: int, lead_time: int):
338
+ """Given a timestamp and lead time, generates a SampleSpec object
339
+ describing the sample further.
340
+
341
+ Args:
342
+ timestamp: Timstamp of the sample, Ie this is the larger of the two
343
+ input timstamps.
344
+ dt: Time between input samples, in hours.
345
+ lead_time: Lead time. In hours.
346
+
347
+ Returns:
348
+ SampleSpec
349
+ """ # noqa: E501
350
+ assert dt > 0, "dt should be possitive"
351
+ lt = pd.to_timedelta(lead_time, unit="h")
352
+ dt = pd.to_timedelta(dt, unit="h")
353
+
354
+ if lead_time >= 0:
355
+ timestamp_target = timestamp + lt
356
+ else:
357
+ timestamp_target = timestamp - dt + lt
358
+
359
+ spec = cls(
360
+ inputs=(timestamp - dt, timestamp),
361
+ lead_time=lead_time,
362
+ target=timestamp_target,
363
+ )
364
+
365
+ return spec
366
+
367
+ def __repr__(self) -> str:
368
+ return self._info_str()
369
+
370
+ def __str__(self) -> str:
371
+ return self._info_str()
372
+
373
+
374
+ class Merra2Dataset(Dataset):
375
+ """MERRA2 dataset. The dataset unifies surface and vertical data as well as
376
+ optional climatology.
377
+
378
+ Samples come in the form of a dictionary. Not all keys support all
379
+ variables, yet the general ordering of dimensions is
380
+ parameter, level, time, lat, lon
381
+
382
+ Note:
383
+ Data is assumed to be in NetCDF files containing daily data at 3-hourly
384
+ intervals. These follow the naming patterns
385
+ MERRA2_sfc_YYYYMMHH.nc and MERRA_pres_YYYYMMHH.nc and can be located in
386
+ two different locations. Optional climatology data comes from files
387
+ climate_surface_doyDOY_hourHOD.nc and
388
+ climate_vertical_doyDOY_hourHOD.nc.
389
+
390
+
391
+ Note:
392
+ `_get_valid_timestamps` assembles a set of all timestamps for which
393
+ there is data (with hourly resolutions). The result is stored in
394
+ `_valid_timestamps`. `_get_valid_climate_timestamps` does the same with
395
+ climatology data and stores it in `_valid_climate_timestamps`.
396
+
397
+ Based on this information, `samples` generates a list of valid samples,
398
+ stored in `samples`. Here the format is::
399
+
400
+ [
401
+ [
402
+ (timestamp 1, lead time A),
403
+ (timestamp 1, lead time B),
404
+ (timestamp 1, lead time C),
405
+ ],
406
+ [
407
+ (timestamp 2, lead time D),
408
+ (timestamp 2, lead time E),
409
+ ]
410
+ ]
411
+
412
+ That is, the outer list iterates over timestamps (init times), the
413
+ inner over lead times. Only valid entries are stored.
414
+ """
415
+
416
+ valid_vertical_vars = [
417
+ "CLOUD",
418
+ "H",
419
+ "OMEGA",
420
+ "PL",
421
+ "QI",
422
+ "QL",
423
+ "QV",
424
+ "T",
425
+ "U",
426
+ "V",
427
+ ]
428
+ valid_surface_vars = [
429
+ "EFLUX",
430
+ "GWETROOT",
431
+ "HFLUX",
432
+ "LAI",
433
+ "LWGAB",
434
+ "LWGEM",
435
+ "LWTUP",
436
+ "PRECTOT",
437
+ "PS",
438
+ "QV2M",
439
+ "SLP",
440
+ "SWGNT",
441
+ "SWTNT",
442
+ "T2M",
443
+ "TQI",
444
+ "TQL",
445
+ "TQV",
446
+ "TS",
447
+ "U10M",
448
+ "V10M",
449
+ "Z0M",
450
+ ]
451
+ valid_static_surface_vars = ["FRACI", "FRLAND", "FROCEAN", "PHIS"]
452
+
453
+ valid_levels = [
454
+ 34.0,
455
+ 39.0,
456
+ 41.0,
457
+ 43.0,
458
+ 44.0,
459
+ 45.0,
460
+ 48.0,
461
+ 51.0,
462
+ 53.0,
463
+ 56.0,
464
+ 63.0,
465
+ 68.0,
466
+ 71.0,
467
+ 72.0,
468
+ ]
469
+
470
+ timedelta_input = pd.to_timedelta(3, unit="h")
471
+
472
def __init__(
    self,
    time_range: tuple[str | pd.Timestamp, str | pd.Timestamp],
    lead_times: list[int],
    input_times: list[int],
    data_path_surface: str | Path,
    data_path_vertical: str | Path,
    climatology_path_surface: str | Path | None = None,
    climatology_path_vertical: str | Path | None = None,
    surface_vars: list[str] | None = None,
    static_surface_vars: list[str] | None = None,
    vertical_vars: list[str] | None = None,
    levels: list[float] | None = None,
    roll_longitudes: int = 0,
    positional_encoding: str = "absolute",
    rtype: type = np.float32,
    dtype: torch.dtype = torch.float32,
) -> None:
    """
    Args:
        time_range: (start, end) used to subset data; both bounds are
            converted with `pd.to_datetime`.
        lead_times: Lead times for generalized forecasting.
        input_times: Offsets between the two input timestamps. `samples`
            negates each value before passing it as `dt` to
            `SampleSpec.get`, which requires `dt > 0`.
        data_path_surface: Location of surface data.
        data_path_vertical: Location of vertical data.
        climatology_path_surface: Location of (optional) surface
            climatology.
        climatology_path_vertical: Location of (optional) vertical
            climatology.
        surface_vars: Surface variables. Defaults to `valid_surface_vars`.
        static_surface_vars: Static surface variables. Defaults to
            `valid_static_surface_vars`.
        vertical_vars: Vertical variables. Defaults to
            `valid_vertical_vars`.
        levels: Levels. Defaults to `valid_levels`.
        roll_longitudes: Set to non-zero value to data by random amount
            along longitude dimension. Stored as the list
            `range(roll_longitudes + 1)`.
        positional_encoding: possible values are
            ['absolute' (default), 'fourier'].
            'absolute' returns lat lon encoded in 3 dimensions using sine
            and cosine
            'fourier' returns lat/lon to be encoded by model
            <any other key> returns lat/lon to be encoded by model
        rtype: numpy data type used during read
        dtype: torch data type of data output

    Raises:
        ValueError: If a data or climatology directory does not exist, if
            exactly one of the two climatology paths is given, or if a
            variable or level is not among the valid ones.
    """

    self.time_range = (
        pd.to_datetime(time_range[0]),
        pd.to_datetime(time_range[1]),
    )
    self.lead_times = lead_times
    self.input_times = input_times
    self._roll_longitudes = list(range(roll_longitudes + 1))

    self._uvars = vertical_vars or self.valid_vertical_vars
    self._level = levels or self.valid_levels
    self._svars = surface_vars or self.valid_surface_vars
    self._sstat = static_surface_vars or self.valid_static_surface_vars
    self._nuvars = len(self._uvars)
    self._nlevel = len(self._level)
    self._nsvars = len(self._svars)
    self._nsstat = len(self._sstat)

    self.rtype = rtype
    self.dtype = dtype

    self.positional_encoding = positional_encoding

    self._data_path_surface = Path(data_path_surface)
    self._data_path_vertical = Path(data_path_vertical)

    self.dir_exists(self._data_path_surface)
    self.dir_exists(self._data_path_vertical)

    self._get_coordinates()

    # Path(None) raises TypeError and a Path is always truthy, so the
    # None checks must run on the raw arguments before any conversion.
    have_surface_clim = climatology_path_surface is not None
    have_vertical_clim = climatology_path_vertical is not None

    if have_surface_clim != have_vertical_clim:
        raise ValueError(
            "Either both or neither of "
            "`climatology_path_surface` and "
            "`climatology_path_vertical` should be None."
        )

    self._require_clim = have_surface_clim and have_vertical_clim

    if self._require_clim:
        self._climatology_path_surface = Path(climatology_path_surface)
        self._climatology_path_vertical = Path(climatology_path_vertical)
        self.dir_exists(self._climatology_path_surface)
        self.dir_exists(self._climatology_path_vertical)
    else:
        self._climatology_path_surface = None
        self._climatology_path_vertical = None

    if not set(self._svars).issubset(set(self.valid_surface_vars)):
        raise ValueError("Invalid surface variable.")

    if not set(self._sstat).issubset(set(self.valid_static_surface_vars)):
        raise ValueError("Invalid static surface variable.")

    if not set(self._uvars).issubset(set(self.valid_vertical_vars)):
        raise ValueError("Inalid vertical variable.")

    if not set(self._level).issubset(set(self.valid_levels)):
        raise ValueError("Invalid level.")
582
+
583
@staticmethod
def dir_exists(path: Path) -> None:
    """Raise ValueError unless `path` is an existing directory."""
    if path.is_dir():
        return
    raise ValueError(f"Directory {path} does not exist.")
587
+
588
@property
def upper_shape(self) -> tuple:
    """Shape of the vertical-variable data.

    Returns:
        tuple: sizes in the following order::

            [VAR, LEV, TIME, LAT, LON]

        VAR and LEV are the numbers of selected vertical variables and
        levels; TIME, LAT and LON are fixed at 2, 361 and 576.
    """
    n_time, n_lat, n_lon = 2, 361, 576
    return (self._nuvars, self._nlevel, n_time, n_lat, n_lon)
597
+
598
@property
def surface_shape(self) -> tuple:
    """Shape of the surface-variable data.

    Returns:
        tuple: sizes in the following order::

            [VAR, TIME, LAT, LON]

        VAR is the number of selected surface variables; TIME, LAT and LON
        are fixed at 2, 361 and 576.
    """
    n_time, n_lat, n_lon = 2, 361, 576
    return (self._nsvars, n_time, n_lat, n_lon)
608
+
609
def data_file_surface(self, timestamp: pd.Timestamp) -> Path:
    """Path of the daily surface data file that covers `timestamp`.

    Args:
        timestamp: a timestamp; only its date enters the file name.

    Returns:
        Path: `MERRA2_sfc_YYYYMMDD.nc` inside the surface data directory.
    """
    file_name = timestamp.strftime("MERRA2_sfc_%Y%m%d.nc")
    return self._data_path_surface.joinpath(file_name)
621
+
622
def data_file_vertical(self, timestamp: pd.Timestamp) -> Path:
    """Path of the daily vertical data file that covers `timestamp`.

    Args:
        timestamp: a timestamp; only its date enters the file name.

    Returns:
        Path: `MERRA_pres_YYYYMMDD.nc` inside the vertical data directory.
    """
    file_name = timestamp.strftime("MERRA_pres_%Y%m%d.nc")
    return self._data_path_vertical.joinpath(file_name)
634
+
635
+ def data_file_surface_climate(
636
+ self,
637
+ timestamp: pd.Timestamp | None = None,
638
+ dayofyear: int | None = None,
639
+ hourofday: int | None = None,
640
+ ) -> Path:
641
+ """
642
+ Returns the path to a climatology file based either on a timestamp or
643
+ the dayofyear / hourofday combination.
644
+ Args:
645
+ timestamp: A timestamp.
646
+ dayofyear: Day of the year. 1 to 366.
647
+ hourofday: Hour of the day. 0 to 23.
648
+ Returns:
649
+ Path: Path to climatology file.
650
+ """
651
+ if timestamp is not None and (
652
+ (dayofyear is not None) or (hourofday is not None)
653
+ ):
654
+ raise ValueError(
655
+ "Provide either timestamp or both dayofyear and hourofday."
656
+ )
657
+
658
+ if timestamp is not None:
659
+ dayofyear = min(timestamp.dayofyear, 365)
660
+ hourofday = timestamp.hour
661
+
662
+ file_name = f"climate_surface_doy{dayofyear:03}_hour{hourofday:02}.nc"
663
+ data_file = self._climatology_path_surface / file_name
664
+ return data_file
665
+
666
+ def data_file_vertical_climate(
667
+ self,
668
+ timestamp: pd.Timestamp | None = None,
669
+ dayofyear: int | None = None,
670
+ hourofday: int | None = None,
671
+ ) -> Path:
672
+ """Returns the path to a climatology file based either on a timestamp
673
+ or the dayofyear / hourofday combination.
674
+
675
+ Args:
676
+ timestamp: A timestamp. dayofyear: Day of the year. 1 to 366.
677
+ hourofday: Hour of the day. 0 to 23.
678
+ Returns:
679
+ Path: Path to climatology file.
680
+ """
681
+ if timestamp is not None and (
682
+ (dayofyear is not None) or (hourofday is not None)
683
+ ):
684
+ raise ValueError(
685
+ "Provide either timestamp or both dayofyear and hourofday."
686
+ )
687
+
688
+ if timestamp is not None:
689
+ dayofyear = min(timestamp.dayofyear, 365)
690
+ hourofday = timestamp.hour
691
+
692
+ file_name = f"climate_vertical_doy{dayofyear:03}_hour{hourofday:02}.nc"
693
+ data_file = self._climatology_path_vertical / file_name
694
+ return data_file
695
+
696
def _get_coordinates(self) -> None:
    """
    Reads latitudes and longitudes from one surface data file and derives
    the sine/cosine embeddings from them.

    Sets `lats`, `lons`, `_embed_lat` (sin(lat), shape (n_lat, 1)) and
    `_embed_lon` (cos(lon) and sin(lon), shape (2, 1, n_lon)).
    """
    # Any valid timestamp will do; its file supplies the grid.
    some_timestamp = next(iter(self.valid_timestamps))
    source_file = self.data_file_surface(some_timestamp)

    with h5py.File(source_file, "r", libver="latest") as handle:
        lat_deg = handle["lat"][()].astype(self.rtype)
        lon_deg = handle["lon"][()].astype(self.rtype)

    self.lats = lat_deg
    self.lons = lon_deg

    deg_to_rad = np.pi / 180
    lat_rad = lat_deg * deg_to_rad
    lon_rad = lon_deg * deg_to_rad

    self._embed_lat = np.sin(lat_rad).reshape(-1, 1)

    embed_lon = np.empty((2, 1, len(lon_deg)), dtype=self.rtype)
    embed_lon[0, 0] = np.cos(lon_rad)
    embed_lon[1, 0] = np.sin(lon_rad)
    self._embed_lon = embed_lon
714
+
715
@ft.cached_property
def lats(self) -> np.ndarray:
    """Latitudes read from the surface file of one valid timestamp, cast
    to `rtype`."""
    some_timestamp = next(iter(self.valid_timestamps))
    source_file = self.data_file_surface(some_timestamp)

    with h5py.File(source_file, "r", libver="latest") as handle:
        latitudes = handle["lat"][()]
        return latitudes.astype(self.rtype)
722
+
723
@ft.cached_property
def lons(self) -> np.ndarray:
    """Longitudes read from the surface file of one valid timestamp, cast
    to `rtype`."""
    some_timestamp = next(iter(self.valid_timestamps))
    source_file = self.data_file_surface(some_timestamp)

    with h5py.File(source_file, "r", libver="latest") as handle:
        longitudes = handle["lon"][()]
        return longitudes.astype(self.rtype)
730
+
731
@ft.cached_property
def position_signal(self) -> np.ndarray:
    """Generates the "position signal" that is part of the static
    features.

    Returns:
        np.ndarray: Array of dtype `rtype` and shape (parameter, lat, lon).
            With 'absolute' positional encoding the parameters are
            sin(lat), cos(lon), sin(lon); with any other encoding they are
            the raw latitude and longitude grids.
    """
    lat_grid, lon_grid = np.meshgrid(self.lats, self.lons, indexing="ij")

    if self.positional_encoding != "absolute":
        channels = [lat_grid, lon_grid]
    else:
        # Degrees to radians.
        lat_rad = lat_grid / 360 * 2.0 * np.pi
        lon_rad = lon_grid / 360 * 2.0 * np.pi
        channels = [np.sin(lat_rad), np.cos(lon_rad), np.sin(lon_rad)]

    return np.stack(channels, axis=0).astype(self.rtype)
758
+
759
@ft.cached_property
def valid_timestamps(self) -> set[pd.Timestamp]:
    """Generates the set of valid timestamps based on available files. Only
    timestamps for which both surface and vertical information is available
    are considered valid.

    Surface files are looked up in the surface data directory and vertical
    files in the vertical data directory. Each matching day contributes the
    hours 0, 3, ..., 21; the result is restricted to `time_range`
    (inclusive on both ends).

    Returns:
        set: set of timestamps
    """

    def available_days(directory: Path, prefix: str) -> set[datetime]:
        # Days for which `<prefix>_YYYYMMDD.nc` exists in `directory`.
        name_re = re.compile(re.escape(prefix) + r"_(\d{8})\.nc\Z")
        return {
            datetime.strptime(hit[1], "%Y%m%d")
            for entry in directory.glob(f"{prefix}_????????.nc")
            if (hit := name_re.match(os.path.basename(entry)))
        }

    surface_days = available_days(self._data_path_surface, "MERRA2_sfc")
    vertical_days = available_days(self._data_path_vertical, "MERRA_pres")
    common_days = surface_days & vertical_days

    start_time, end_time = self.time_range

    # Each file contains a day at 3 hour intervals
    times = set()
    for day in common_days:
        for hour in range(0, 24, 3):
            moment = day + timedelta(hours=hour)
            if start_time <= moment <= end_time:
                times.add(pd.Timestamp(moment))

    return times
799
+
800
@ft.cached_property
def valid_climate_timestamps(self) -> set[tuple[int, int]]:
    """Generates the set of "timestamps" (dayofyear, hourofday) for which
    climatology data is present. Only instances for which surface and
    vertical data is available are considered valid.

    Returns:
        set: Tuples describing valid climatology instances; empty when
            climatology is not required.
    """
    if not self._require_clim:
        return set()

    def parse(entries, pattern) -> set[tuple[int, int]]:
        # Extract (dayofyear, hourofday) from every matching file name.
        found = set()
        for entry in entries:
            hit = pattern.match(os.path.basename(entry))
            if hit:
                found.add((int(hit[1]), int(hit[2])))
        return found

    surface_entries = self._climatology_path_surface.glob(
        "climate_surface_doy???_hour??.nc"
    )
    vertical_entries = self._climatology_path_vertical.glob(
        "climate_vertical_doy???_hour??.nc"
    )

    surface_times = parse(
        surface_entries,
        re.compile(r"climate_surface_doy(\d{3})_hour(\d{2}).nc\Z"),
    )
    vertical_times = parse(
        vertical_entries,
        re.compile(r"climate_vertical_doy(\d{3})_hour(\d{2}).nc\Z"),
    )

    return surface_times & vertical_times
834
+
835
def _data_available(self, spec: SampleSpec) -> bool:
    """
    Checks whether data is available for a given SampleSpec object. Uses
    the previously constructed sets of available data rather than the file
    system.

    Args:
        spec: SampleSpec object as returned by SampleSpec.get

    Returns:
        bool: True if every timestamp of `spec` is valid and, when
            climatology is required, its climatology info is valid too.
    """
    have_times = set(spec.times).issubset(self.valid_timestamps)

    if not self._require_clim:
        return have_times

    info = spec.climatology_info
    # A list means several (dayofyear, hourofday) pairs; otherwise one.
    wanted = set(info) if isinstance(info, list) else {info}
    have_climate = wanted.issubset(self.valid_climate_timestamps)

    return have_times & have_climate
853
+
854
@ft.cached_property
def samples(self) -> list[list[tuple[pd.Timestamp, int, int]]]:
    """
    Generates all valid samples, grouped by timestamp.

    For every valid timestamp, each combination of input time and lead time
    is tried; a combination is kept when `_data_available` reports that all
    data it needs is present.

    Returns:
        list: One entry per timestamp that has at least one valid
        combination. Each entry is a list of tuples
        (timestamp, input time, lead time).
    """
    valid_samples = []
    dts = [(it, lt) for it in self.input_times for lt in self.lead_times]

    for timestamp in sorted(self.valid_timestamps):
        # The input time is negated when building the spec.
        timestamp_samples = [
            (timestamp, it, lt)
            for it, lt in dts
            if self._data_available(SampleSpec.get(timestamp, -it, lt))
        ]

        # Timestamps without any usable combination are left out entirely.
        if timestamp_samples:
            valid_samples.append(timestamp_samples)

    return valid_samples
876
+
877
+ def _to_torch(
878
+ self,
879
+ data: dict[str, Tensor | list[Tensor]],
880
+ dtype: torch.dtype = torch.float32,
881
+ ) -> dict[str, Tensor | list[Tensor]]:
882
+ out = {}
883
+ for k, v in data.items():
884
+ if isinstance(v, list):
885
+ out[k] = [torch.from_numpy(x).to(dtype) for x in v]
886
+ else:
887
+ out[k] = torch.from_numpy(v).to(dtype)
888
+
889
+ return out
890
+
891
+ def _lat_roll(
892
+ self, data: dict[str, Tensor | list[Tensor]], n: int
893
+ ) -> dict[str, Tensor | list[Tensor]]:
894
+ out = {}
895
+ for k, v in data.items():
896
+ if isinstance(v, list):
897
+ out[k] = [torch.roll(x, shifts=n, dims=-1) for x in v]
898
+ else:
899
+ out[k] = torch.roll(v, shifts=n, dims=-1)
900
+
901
+ return out
902
+
903
def _read_static_data(
    self, file: str | Path, doy: int, hod: int
) -> np.ndarray:
    """Assembles the static input channels for one sample.

    The leading channels hold the position signal, followed by four time
    channels cos(doy), sin(doy), cos(hod), sin(hod), followed by the static
    surface variables listed in `self._sstat` read from `file`.

    Args:
        file: Path of the HDF5/NetCDF file providing the grid and the
            static surface variables.
        doy: Day of year; encoded with a period of 366.
        hod: Hour of day; encoded with a period of 24.
    Returns:
        np.ndarray: Array of shape [channel, lat, lon] and dtype
        `self.rtype`.
    """
    with h5py.File(file, "r", libver="latest") as handle:
        grid_shape = (len(handle["lat"]), len(handle["lon"]))

        n_position = len(self.position_signal)
        n_time = 4
        static_start = n_position + n_time

        channels = np.empty(
            (static_start + self._nsstat, *grid_shape), dtype=self.rtype
        )

        # Static variables occupy the trailing channels.
        for slot, name in enumerate(self._sstat, start=static_start):
            channels[slot] = handle[name][()].astype(dtype=self.rtype)

        # [position signal], cos(doy), sin(doy), cos(hod), sin(hod)
        channels[:n_position] = self.position_signal
        year_phase = 2 * np.pi * doy / 366
        day_phase = 2 * np.pi * hod / 24
        time_signal = (
            np.cos(year_phase),
            np.sin(year_phase),
            np.cos(day_phase),
            np.sin(day_phase),
        )
        for offset, value in enumerate(time_signal):
            channels[n_position + offset] = value

    return channels
929
+
930
+ def _read_surface(
931
+ self, tidx: int, nll: tuple[int, int], handle: h5py.File
932
+ ) -> np.ndarray:
933
+ data = np.empty((self._nsvars, *nll), dtype=self.rtype)
934
+
935
+ for i, key in enumerate(self._svars):
936
+ data[i] = handle[key][tidx][()].astype(dtype=self.rtype)
937
+
938
+ return data
939
+
940
+ def _read_levels(
941
+ self, tidx: int, nll: tuple[int, int], handle: h5py.File
942
+ ) -> np.ndarray:
943
+ lvls = handle["lev"][()]
944
+ lidx = self._level_idxs(lvls)
945
+
946
+ data = np.empty((self._nuvars, self._nlevel, *nll), dtype=self.rtype)
947
+
948
+ for i, key in enumerate(self._uvars):
949
+ data[i] = handle[key][tidx, lidx][()].astype(dtype=self.rtype)
950
+
951
+ return np.ascontiguousarray(np.flip(data, axis=1))
952
+
953
+ def _level_idxs(self, lvls):
954
+ lidx = [np.argwhere(lvls == int(lvl)).item() for lvl in self._level]
955
+ return sorted(lidx)
956
+
957
+ @staticmethod
958
+ def _date_to_tidx(date: datetime | pd.Timestamp, handle: h5py.File) -> int:
959
+ if isinstance(date, pd.Timestamp):
960
+ date = date.to_pydatetime()
961
+
962
+ time = handle["time"]
963
+
964
+ t0 = time.attrs["begin_time"][()].item()
965
+ d0 = f"{time.attrs['begin_date'][()].item()}"
966
+
967
+ offset = datetime.strptime(d0, "%Y%m%d")
968
+
969
+ times = [offset + timedelta(minutes=int(t + t0)) for t in time[()]]
970
+ return times.index(date)
971
+
972
def _read_data(
    self, file_pair: tuple[str, str], date: datetime
) -> dict[str, np.ndarray]:
    """Reads surface and vertical data for one date from a pair of files.

    Args:
        file_pair: Tuple (surface file, vertical file).
        date: Date to look up along the time axis of each file.
    Returns:
        dict: Dictionary with keys "vert" (result of `_read_levels`) and
        "surf" (result of `_read_surface`).
    """
    surface_path, vertical_path = file_pair

    with h5py.File(surface_path, "r", libver="latest") as surface:
        grid = (len(surface["lat"]), len(surface["lon"]))
        when = self._date_to_tidx(date, surface)
        surface_fields = self._read_surface(when, grid, surface)

    # Grid size and time index are looked up again for the vertical file.
    with h5py.File(vertical_path, "r", libver="latest") as vertical:
        grid = (len(vertical["lat"]), len(vertical["lon"]))
        when = self._date_to_tidx(date, vertical)
        level_fields = self._read_levels(when, grid, vertical)

    return {"vert": level_fields, "surf": surface_fields}
1000
+
1001
def _read_climate(
    self, file_pair: tuple[str, str]
) -> dict[str, np.ndarray]:
    """Reads surface and vertical climatology from a pair of files.

    Unlike `_read_data`, the variables are read without a time index.

    Args:
        file_pair: Tuple (surface climatology file, vertical climatology
            file).
    Returns:
        dict: Dictionary with keys "vert", a contiguous array of shape
        [variable, level, lat, lon] with the level axis reversed, and
        "surf", an array of shape [variable, lat, lon].
    """
    surface_path, vertical_path = file_pair

    with h5py.File(surface_path, "r", libver="latest") as surface:
        grid = (len(surface["lat"]), len(surface["lon"]))

        surface_fields = np.empty((self._nsvars, *grid), dtype=self.rtype)
        for slot, name in enumerate(self._svars):
            surface_fields[slot] = surface[name][()].astype(
                dtype=self.rtype
            )

    with h5py.File(vertical_path, "r", libver="latest") as vertical:
        grid = (len(vertical["lat"]), len(vertical["lon"]))

        level_rows = self._level_idxs(vertical["lev"][()])

        level_fields = np.empty(
            (self._nuvars, self._nlevel, *grid), dtype=self.rtype
        )
        for slot, name in enumerate(self._uvars):
            level_fields[slot] = vertical[name][level_rows][()].astype(
                dtype=self.rtype
            )

    return {
        "vert": np.ascontiguousarray(np.flip(level_fields, axis=1)),
        "surf": surface_fields,
    }
1039
+
1040
def get_data_from_sample_spec(
    self, spec: SampleSpec
) -> dict[str, Tensor | int | float]:
    """Loads and assembles sample data given a SampleSpec object.

    Args:
        spec (SampleSpec): Full details regarding the data to be loaded
    Returns:
        dict: Dictionary with the following keys::

            'sur_static': Torch tensor of shape [parameter, lat, lon]. For
                each pixel (lat, lon), the first 7 dimensions index sin(lat),
                cos(lon), sin(lon), cos(doy), sin(doy), cos(hod), sin(hod).
                Where doy is the day of the year [1, 366] and hod the hour of
                the day [0, 23].
            'sur_vals': Torch tensor of shape [parameter, time, lat, lon].
            'sur_tars': Torch tensor of shape [parameter, time, lat, lon].
            'ulv_vals': Torch tensor of shape [parameter, level, time, lat, lon].
            'ulv_tars': Torch tensor of shape [parameter, level, time, lat, lon].
            'sur_climate': Torch tensor of shape [parameter, lat, lon].
            'ulv_climate': Torch tensor of shape [parameter, level, lat, lon].
            'lead_time': Float.
            'input_time': Float.

        The two climate entries are only present when climatology is
        required.
    """  # noqa: E501

    def file_pair(moment):
        return (
            self.data_file_surface(moment),
            self.data_file_vertical(moment),
        )

    # Unique timestamps for which value data / static data is needed.
    value_times = {*spec.times}
    static_times = {*spec.stat_times}

    # Group the timestamps by the (surface, vertical) file pair storing them.
    value_files = defaultdict(list)
    for moment in value_times:
        value_files[file_pair(moment)].append(moment)

    static_files = defaultdict(list)
    for moment in static_times:
        static_files[file_pair(moment)].append(moment)

    # Load the value data, keyed by timestamp.
    loaded = {
        moment: self._read_data(pair, moment)
        for pair, moments in value_files.items()
        for moment in moments
    }

    # Stack the inputs along a new time axis; targets get a time axis of 1.
    sample = {}
    sample["ulv_vals"] = np.stack(
        [loaded[t]["vert"] for t in spec.inputs], axis=2
    )
    sample["ulv_tars"] = loaded[spec.target]["vert"][:, :, None]
    sample["sur_vals"] = np.stack(
        [loaded[t]["surf"] for t in spec.inputs], axis=1
    )
    sample["sur_tars"] = loaded[spec.target]["surf"][:, None]

    # Static data: one file pair is taken from the map and its first
    # timestamp provides (day of year, hour of day).
    static_pair, static_moments = static_files.popitem()
    reference = static_moments[0]
    sample["sur_static"] = self._read_static_data(
        static_pair[0], reference.dayofyear, reference.hour
    )

    # If required, load the climatology.
    if self._require_clim:
        clim_day, clim_hour = spec.climatology_info

        clim_surface_file = self.data_file_surface_climate(
            dayofyear=clim_day,
            hourofday=clim_hour,
        )
        clim_vertical_file = self.data_file_vertical_climate(
            dayofyear=clim_day,
            hourofday=clim_hour,
        )

        climate = self._read_climate((clim_surface_file, clim_vertical_file))

        sample["sur_climate"] = climate["surf"]
        sample["ulv_climate"] = climate["vert"]

    # Move the data from numpy to torch.
    sample = self._to_torch(sample, dtype=self.dtype)

    # Optionally roll by a randomly chosen amount.
    if len(self._roll_longitudes) > 0:
        sample = self._lat_roll(sample, random.choice(self._roll_longitudes))

    # The scalar entries are added after rolling, which handles tensors only.
    sample["lead_time"] = spec.lead_time
    sample["input_time"] = spec.input_time

    return sample
1148
+
1149
def get_data(
    self, timestamp: pd.Timestamp, input_time: int, lead_time: int
) -> dict[str, Tensor | int]:
    """
    Loads data based on timestamp, input time and lead time.

    Args:
        timestamp: Timestamp.
        input_time: time between input samples; it is negated before being
            handed to `SampleSpec.get`.
        lead_time: lead time.
    Returns:
        Dictionary as produced by `get_data_from_sample_spec`, with keys
        'sur_static', 'sur_vals', 'sur_tars', 'ulv_vals', 'ulv_tars',
        'sur_climate', 'ulv_climate', 'lead_time', 'input_time'.
    """
    return self.get_data_from_sample_spec(
        SampleSpec.get(timestamp, -input_time, lead_time)
    )
1166
+
1167
+ def __getitem__(self, idx: int) -> dict[str, Tensor | int]:
1168
+ """
1169
+ Loads data based on sample index and random choice of sample.
1170
+ Args:
1171
+ idx: Sample index.
1172
+ Returns:
1173
+ Dictionary with keys 'sur_static', 'sur_vals', 'sur_tars',
1174
+ 'ulv_vals', 'ulv_tars', 'sur_climate', 'ulv_climate',
1175
+ 'lead_time', 'input_time'.
1176
+ """
1177
+ sample_set = self.samples[idx]
1178
+ timestamp, input_time, lead_time, *nsteps = random.choice(sample_set)
1179
+ sample_data = self.get_data(timestamp, input_time, lead_time)
1180
+ return sample_data
1181
+
1182
+ def __len__(self):
1183
+ return len(self.samples)
1184
+
1185
+ from functools import cached_property
1186
+ from importlib.metadata import version
1187
+
1188
+ from torch import Tensor
1189
+ from torch.utils.checkpoint import checkpoint
1190
+
1191
+ if version("torch") > "2.3.0":
1192
+ from torch.nn.attention import SDPBackend, sdpa_kernel
1193
+ import numpy as np
1194
+ import torch
1195
+ import torch.nn as nn
1196
+ import torch.nn.functional as F
1197
+
1198
+
1199
+ # DropPath code is straight from timm
1200
+ # (https://huggingface.co/spaces/Roll20/pet_score/blame/main/lib/timm/models/layers/drop.py)
1201
def drop_path(
    x: Tensor,
    drop_prob: float = 0.0,
    training: bool = False,
    scale_by_keep: bool = True,
) -> Tensor:
    """Drop paths (Stochastic Depth) per sample (when applied in main path of
    residual blocks). Taken from timm.

    Args:
        x (Tensor): Input tensor.
        drop_prob (float): Probability of dropping `x`, defaults to 0.
            `None` is accepted and treated like 0, because `DropPath`
            defaults its probability to `None`.
        training (bool): Whether the model is in training or eval mode,
            defaults to False.
        scale_by_keep (bool): Whether the kept samples should be scaled by
            `1 / (1 - drop_prob)`, defaults to True.
    Returns:
        Tensor: `x` itself when nothing is dropped, otherwise `x` with
        whole samples (entries along the first axis) zeroed with probability
        `drop_prob`.
    """
    # `not drop_prob` covers both 0.0 and None.
    if not training or not drop_prob:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining axes.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    # Skip the rescale when keep_prob is 0 to avoid dividing by zero.
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor
1229
+
1230
+
1231
+ class DropPath(nn.Module):
1232
+ """
1233
+ Drop paths (Stochastic Depth) per sample (when applied in main path of
1234
+ residual blocks).
1235
+ """
1236
+
1237
+ def __init__(
1238
+ self, drop_prob: float | None = None, scale_by_keep: bool = True
1239
+ ) -> None:
1240
+ super(DropPath, self).__init__()
1241
+ self.drop_prob = drop_prob
1242
+ self.scale_by_keep = scale_by_keep
1243
+
1244
+ def forward(self, x: Tensor) -> Tensor:
1245
+ """Runs drop path on input tensor
1246
+
1247
+ Args:
1248
+ x: input
1249
+
1250
+ Returns:
1251
+ tensor: output after drop_path
1252
+ """
1253
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
1254
+
1255
+
1256
class Mlp(nn.Module):
    """
    Multi layer perceptron: Linear, GELU, Dropout, Linear, Dropout.
    """

    def __init__(
        self, features: int, hidden_features: int, dropout: float = 0.0
    ) -> None:
        """
        Args:
            features: Input/output dimension.
            hidden_features: Hidden dimension.
            dropout: Dropout probability used after the activation and after
                the output projection.
        """
        super().__init__()
        layers = [
            nn.Linear(features, hidden_features),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_features, features),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): Tensor of shape [..., channel]
        Returns:
            Tensor: Tensor of same shape as x.
        """
        out = self.net(x)
        return out
1287
+
1288
+
1289
+ class LayerNormPassThrough(nn.LayerNorm):
1290
+ """Normalising layer that allows the attention mask to be passed through"""
1291
+
1292
+ def __init__(self, *args, **kwargs):
1293
+ super().__init__(*args, **kwargs)
1294
+
1295
+ def forward(
1296
+ self, d: tuple[Tensor, Tensor | None]
1297
+ ) -> tuple[Tensor, Tensor | None]:
1298
+ """Forwards function
1299
+
1300
+ Args:
1301
+ d (tuple): tuple of the data tensor and the attention mask
1302
+ Returns:
1303
+ output (Tensor): normalised output data
1304
+ attn_mask (Tensor): the attention mask that was passed in
1305
+ """
1306
+ input, attn_mask = d
1307
+ output = F.layer_norm(
1308
+ input, self.normalized_shape, self.weight, self.bias, self.eps
1309
+ )
1310
+ return output, attn_mask
1311
+
1312
+
1313
+ class MultiheadAttention(nn.Module):
1314
+ """Multihead attention layer for inputs of shape
1315
+ [..., sequence, features].
1316
+ """
1317
+
1318
+ def __init__(self, features: int, n_heads: int, dropout: float) -> None:
1319
+ """
1320
+ Args:
1321
+ features: Number of features for inputs to the layer.
1322
+ n_heads: Number of attention heads. Should be a factor of features.
1323
+ (I.e. the layer uses features // n_heads.)
1324
+ dropout: Dropout.
1325
+ """ # noqa: E501
1326
+ super().__init__()
1327
+
1328
+ if (features % n_heads) != 0:
1329
+ raise ValueError(
1330
+ f"Features '{features}' is not divisible by heads '{n_heads}'."
1331
+ )
1332
+
1333
+ self.features = features
1334
+ self.n_heads = n_heads
1335
+ self.dropout = dropout
1336
+
1337
+ self.qkv_layer = torch.nn.Linear(features, features * 3, bias=False)
1338
+ self.w_layer = torch.nn.Linear(features, features, bias=False)
1339
+
1340
+ def forward(self, d: tuple[Tensor, Tensor | None]) -> Tensor:
1341
+ """
1342
+ Args:
1343
+ d (tuple): tuple containing Tensor of shape [..., sequence, features] and the attention mask
1344
+ Returns:
1345
+ Tensor: Tensor of shape [..., sequence, features]
1346
+ """ # noqa: E501
1347
+ x, attn_mask = d
1348
+
1349
+ if not x.shape[-1] == self.features:
1350
+ raise ValueError(
1351
+ f"Expecting tensor with last dimension size {self.features}."
1352
+ )
1353
+
1354
+ passenger_dims = x.shape[:-2]
1355
+ B = passenger_dims.numel()
1356
+ S = x.shape[-2]
1357
+ C = x.shape[-1]
1358
+ x = x.reshape(B, S, C)
1359
+
1360
+ # x [B, S, C]
1361
+ # q, k, v [B, H, S, C/H]
1362
+ q, k, v = (
1363
+ self.qkv_layer(x)
1364
+ .view(B, S, self.n_heads, 3 * (C // self.n_heads))
1365
+ .transpose(1, 2)
1366
+ .chunk(chunks=3, dim=3)
1367
+ )
1368
+
1369
+ # Let us enforce either flash (A100+) or memory efficient attention.
1370
+ if version("torch") > "2.3.0":
1371
+ with sdpa_kernel(
1372
+ [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
1373
+ ):
1374
+ # x [B, H, S, C//H]
1375
+ x = F.scaled_dot_product_attention(
1376
+ q, k, v, attn_mask=attn_mask, dropout_p=self.dropout
1377
+ )
1378
+ else:
1379
+ with torch.backends.cuda.sdp_kernel(
1380
+ enable_flash=True, enable_math=False, enable_mem_efficient=True
1381
+ ):
1382
+ # x [B, H, S, C//H]
1383
+ x = F.scaled_dot_product_attention(
1384
+ q, k, v, dropout_p=self.dropout
1385
+ )
1386
+
1387
+ # x [B, S, C]
1388
+ x = x.transpose(1, 2).view(B, S, C)
1389
+
1390
+ # x [B, S, C]
1391
+ x = self.w_layer(x)
1392
+
1393
+ # Back to input shape
1394
+ x = x.view(*passenger_dims, S, self.features)
1395
+ return x
1396
+
1397
+
1398
class Transformer(nn.Module):
    """
    Pre-norm transformer block for inputs of shape [..., S, features]:
    attention and feed forward sub-layers, each with a residual connection.
    """

    def __init__(
        self,
        features: int,
        mlp_multiplier: int,
        n_heads: int,
        dropout: float,
        drop_path: float,
    ) -> None:
        """
        Args:
            features: Number of features for inputs to the layer.
            mlp_multiplier: Model uses features*mlp_multiplier hidden units.
            n_heads: Number of attention heads. Should be a factor of features.
                (I.e. the layer uses features // n_heads.)
            dropout: Dropout.
            drop_path: DropPath probability; values <= 0 disable it.
        """
        super().__init__()

        self.features = features
        self.mlp_multiplier = mlp_multiplier
        self.n_heads = n_heads
        self.dropout = dropout

        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        # The normalisation forwards the mask so that the attention layer
        # receives the (tensor, mask) tuple.
        attention_layers = [
            LayerNormPassThrough(features),
            MultiheadAttention(features, n_heads, dropout),
        ]
        self.attention = nn.Sequential(*attention_layers)

        feed_forward_layers = [
            nn.LayerNorm(features),
            Mlp(
                features=features,
                hidden_features=features * mlp_multiplier,
                dropout=dropout,
            ),
        ]
        self.ff = nn.Sequential(*feed_forward_layers)

    def forward(self, d: tuple[Tensor, Tensor | None]) -> Tensor:
        """
        Args:
            d: Tuple of a tensor of shape [..., sequence, features] and the
                attention mask (or None).
        Returns:
            Tensor: Tensor of shape [..., sequence, features]
        """
        hidden, _ = d
        if hidden.shape[-1] != self.features:
            raise ValueError(
                f"Expecting tensor with last dimension size {self.features}."
            )

        # The attention sub-layer consumes the full (tensor, mask) tuple.
        hidden = hidden + self.drop_path(self.attention(d))
        return hidden + self.drop_path(self.ff(hidden))
1462
+
1463
+
1464
+ class _Shift(nn.Module):
1465
+ """Private base class for the shifter. This allows some behaviour to be
1466
+ easily handled when the shifter isn't used.
1467
+ """
1468
+
1469
+ def __init__(self):
1470
+ super().__init__()
1471
+
1472
+ self._shifted = False
1473
+
1474
+ @torch.no_grad()
1475
+ def reset(self) -> None:
1476
+ """
1477
+ Resets the bool tracking whether the data is shifted
1478
+ """
1479
+ self._shifted: bool = False
1480
+
1481
+ def forward(self, data: Tensor) -> tuple[Tensor, dict[bool, None]]:
1482
+ return data, {True: None, False: None}
1483
+
1484
+
1485
class SWINShift(_Shift):
    """
    Handles the shifting of patches similar to how SWIN works. However if we
    shift the latitudes then the poles will wrap and potentially that might be
    problematic. The position tokens should handle it but masking is safer.

    Each call to `forward` toggles between the shifted and the unshifted
    state (tracked in `self._shifted`); attention masks are only returned
    while the data is shifted.
    """

    def __init__(
        self,
        mu_shape: tuple[int, int],
        global_shape: tuple[int, int],
        local_shape: tuple[int, int],
        patch_shape: tuple[int, int],
        n_context_tokens: int = 2,
    ) -> None:
        """
        Args:
            mu_shape: the shape of the masking units
            global_shape: number of global patches in lat and lon
            local_shape: size of the local patches
            patch_shape: patch size
            n_context_tokens: number of additional context tokens at start of
                _each_ local sequence
        """
        super().__init__()

        self._mu_shape = ms = mu_shape
        self._g_shape = gs = global_shape
        self._l_shape = ls = local_shape
        self._p_shape = ps = patch_shape
        # Interleaved (global, local) sizes per axis; used to split a grid.
        self._lat_patch = (gs[0], ls[0], gs[1], ls[1])
        self._n_context_tokens = n_context_tokens

        # Shift by half a masking unit, measured in patches, per axis. The
        # "from" shift is the negated "to" shift and undoes it.
        self._g_shift_to = tuple(
            int(0.5 * x / p) for x, p in zip(ms, ps, strict=False)
        )
        self._g_shift_from = tuple(
            -int(0.5 * x / p) for x, p in zip(ms, ps, strict=False)
        )

        # Define the attention masks for the shifted MaxViT.
        nglobal = global_shape[0] * global_shape[1]
        nlocal = (
            local_shape[0] * local_shape[1] + self._n_context_tokens
        )  # local patch tokens plus the context tokens

        # Local mask: one [nlocal, nlocal] boolean matrix per global patch.
        # For the first gs[1] global patches, a square block of `mwidth`
        # tokens located right after the context tokens is set to False.
        # NOTE(review): presumably False means "do not attend", following
        # the boolean mask convention of scaled_dot_product_attention.
        lm = torch.ones((nglobal, 1, nlocal, nlocal), dtype=bool)
        mwidth = int(0.5 * local_shape[1]) * local_shape[0]
        lm[
            : gs[1],
            :,
            self._n_context_tokens : mwidth + self._n_context_tokens,
            self._n_context_tokens : mwidth + self._n_context_tokens,
        ] = False
        self.register_buffer("local_mask", lm)

        # Global mask: one [nglobal, nglobal] boolean matrix per local token.
        # For the first int(0.5 * ls[1]) * ls[0] local positions, the block
        # formed by the first gs[1] global patches is set to False.
        gm = torch.ones((nlocal, 1, nglobal, nglobal), dtype=bool)
        gm[: int(0.5 * ls[1]) * ls[0], :, : gs[1], : gs[1]] = False
        self.register_buffer("global_mask", gm)

    def _to_grid_global(self, x: Tensor) -> Tensor:
        """
        Shuffle and reshape the data from the global/local setting back to the
        lat/lon grid setting
        Args:
            x: the data tensor to be shuffled. It is viewed as
                [batch, *global_shape, *local_shape, features].
        Returns:
            x: data in the lat/lon grid setting, with the features moved in
            front of the two grid axes.
        """
        nbatch, *other = x.shape

        y1 = x.view(nbatch, *self._g_shape, *self._l_shape, -1)
        # -> [batch, features, g_lat, l_lat, g_lon, l_lon]
        y2 = y1.permute(0, 5, 1, 3, 2, 4).contiguous()

        s = y2.shape
        # Merge (g_lat, l_lat) and (g_lon, l_lon) into the two grid axes.
        return y2.view((nbatch, -1, s[2] * s[3], s[4] * s[5]))

    def _to_grid_local(self, x: Tensor) -> Tensor:
        """
        Shuffle and reshape the data from the local/global setting to the
        lat/lon grid setting
        Args:
            x: the data tensor to be shuffled.
        Returns:
            x: data in the lat/lon setting.
        """
        # Swap the local and global axes, then reuse the global conversion.
        x = x.transpose(2, 1).contiguous()
        return self._to_grid_global(x)

    def _from_grid_global(self, x: Tensor) -> Tensor:
        """
        Shuffle and reshape the data from the lat/lon grid to the global/local
        setting
        Args:
            x: the data tensor to be shuffled.
        Returns:
            x: data in the global/local setting
        """
        nbatch, *other = x.shape

        # Split both grid axes into (global, local) parts.
        z1 = x.view(nbatch, -1, *self._lat_patch)
        # -> [batch, g_lat, g_lon, l_lat, l_lon, features]
        z2 = z1.permute(0, 2, 4, 3, 5, 1).contiguous()

        s = z2.shape
        return z2.view(nbatch, s[1] * s[2], s[3] * s[4], -1)

    def _from_grid_local(self, x: Tensor) -> Tensor:
        """
        Shuffle and reshape the data from the lat/lon grid to the local/global
        setting
        Args:
            x: the data tensor to be shuffled.
        Returns:
            x: data in the local/global setting
        """
        x = self._from_grid_global(x)
        return x.transpose(2, 1).contiguous()

    def _shift(self, x: Tensor) -> Tensor:
        """
        Shifts data in the gridded lat/lon setting by half the mask unit shape
        Args:
            x: data to be shifted
        Returns:
            x: either the shifted or unshifted data
        """
        # Undo the previous shift if the data is currently shifted.
        shift = self._g_shift_from if self._shifted else self._g_shift_to
        x_shifted = torch.roll(x, shift, (-2, -1))

        # Toggle the state so that the next call reverses this one.
        self._shifted = not self._shifted
        return x_shifted

    def _sep_lt(self, x: Tensor) -> tuple[Tensor, Tensor]:
        """
        Separate off the context tokens (lead time etc.) from the local
        patches
        Args:
            x: data to have the context tokens removed from
        Returns:
            lt_it: the first `n_context_tokens` entries along dimension 1
            x: data without the context tokens in the local patch
        """
        lt_it = x[:, : self._n_context_tokens, :, :]
        x_stripped = x[:, self._n_context_tokens :, :, :]

        return lt_it, x_stripped

    def forward(self, data: Tensor) -> tuple[Tensor, dict[bool, Tensor | None]]:
        """Shift or unshift the data depending on whether the data is
        already shifted, as defined by self._shifted.

        Args:
            data: data to be shifted, in the local/global setting with the
                context tokens at the start of dimension 1
        Returns:
            Tensor: shifted data Tensor with the context tokens re-attached
            dict: attention masks keyed by True (local mask) and False
            (global mask); both are None when the data ends up unshifted
        """
        lt, x = self._sep_lt(data)

        x_grid = self._to_grid_local(x)
        x_shifted = self._shift(x_grid)
        x_patched = self._from_grid_local(x_shifted)

        # Mask has to be repeated based on batch size
        n_batch = x_grid.shape[0]
        local_rep = [n_batch] + [1] * (self.local_mask.ndim - 1)
        global_rep = [n_batch] + [1] * (self.global_mask.ndim - 1)

        # `_shift` has already toggled the state, so this reflects the data
        # being returned.
        if self._shifted:
            attn_mask = {
                True: self.local_mask.repeat(local_rep),
                False: self.global_mask.repeat(global_rep),
            }
        else:
            attn_mask = {True: None, False: None}

        return torch.cat((lt, x_patched), axis=1), attn_mask
1660
+
1661
+
1662
class LocalGlobalLocalBlock(nn.Module):
    """
    Applies alternating block and grid attention. Given a parameter n_blocks,
    the entire module contains 2*n_blocks+1 transformer blocks. The first,
    third, ..., last apply local (block) attention. The second, fourth, ...
    global (grid) attention.

    This is heavily inspired by
    Tu et al. "MaxViT: Multi-Axis Vision Transformer"
    (https://arxiv.org/abs/2204.01697).
    """

    def __init__(
        self,
        features: int,
        mlp_multiplier: int,
        n_heads: int,
        dropout: float,
        n_blocks: int,
        drop_path: float,
        shifter: nn.Module | None = None,
        checkpoint: list[int] | None = None,
    ) -> None:
        """
        Args:
            features: Number of features for inputs to the layer.
            mlp_multiplier: Model uses features*mlp_multiplier hidden units.
            n_heads: Number of attention heads. Should be a factor of features.
                (I.e. the layer uses features // n_heads.)
            dropout: Dropout.
            n_blocks: Number of local-global transformer pairs.
            drop_path: DropPath.
            shifter: Module applied after every global block; a no-op
                `_Shift` is used when omitted.
            checkpoint: Indices of the transformer blocks that are run
                through activation checkpointing.
        Raises:
            ValueError: If a checkpoint index is outside
                0 <= i < 2*n_blocks+1.
        """
        super().__init__()

        n_transformers = 2 * n_blocks + 1

        self.features = features
        self.mlp_multiplier = mlp_multiplier
        self.n_heads = n_heads
        self.dropout = dropout
        self.drop_path = drop_path
        self.n_blocks = n_blocks
        self._checkpoint = checkpoint or []

        if any(not (0 <= c < n_transformers) for c in self._checkpoint):
            raise ValueError(
                "Checkpoints should be 0 <= i < 2*n_blocks+1. "
                f"{self._checkpoint=}."
            )

        blocks = []
        for _ in range(n_transformers):
            blocks.append(
                Transformer(
                    features=features,
                    mlp_multiplier=mlp_multiplier,
                    n_heads=n_heads,
                    dropout=dropout,
                    drop_path=drop_path,
                )
            )
        self.transformers = nn.ModuleList(blocks)

        # One callable per block: checkpointed evaluation for the selected
        # indices, a plain call for all others.
        evaluators = []
        for index in range(len(self.transformers)):
            if index in self._checkpoint:
                evaluators.append(self._checkpoint_wrapper)
            else:
                evaluators.append(lambda m, x: m(x))
        self.evaluator = evaluators

        self.shifter = shifter or _Shift()

    @staticmethod
    def _checkpoint_wrapper(
        model: nn.Module, data: tuple[Tensor, Tensor | None]
    ) -> Tensor:
        """Evaluates `model(data)` through non-reentrant checkpointing."""
        return checkpoint(model, data, use_reentrant=False)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor of shape::

                [batch, global_sequence, local_sequence, features]

        Returns:
            Tensor: Tensor of shape::

                [batch, global_sequence, local_sequence, features]
        """
        if x.shape[-1] != self.features:
            raise ValueError(
                f"Expecting tensor with last dimension size {self.features}."
            )
        if x.ndim != 4:
            raise ValueError(
                f"Expecting tensor with exactly four dimensions. {x.shape=}."
            )

        self.shifter.reset()
        attn_mask = {True: None, False: None}

        stages = zip(self.evaluator, self.transformers, strict=False)
        for position, (run, block) in enumerate(stages):
            # Even positions are local blocks, odd positions global ones.
            local = position % 2 == 0

            if position > 0:
                # We are making exactly 2*n_blocks transposes.
                # So the output has the same shape as input.
                x = x.transpose(1, 2)

            x = run(block, (x, attn_mask[local]))

            # The shifter runs after every global block and supplies the
            # masks for the blocks that follow.
            if not local:
                x, attn_mask = self.shifter(x)

        return x
1782
+
1783
+
1784
+ class PatchEmbed(nn.Module):
1785
+ """
1786
+ Patch embedding via 2D convolution.
1787
+ """
1788
+
1789
+ def __init__(
1790
+ self, patch_size: int | tuple[int, ...], channels: int, embed_dim: int
1791
+ ):
1792
+ super().__init__()
1793
+
1794
+ self.patch_size = patch_size
1795
+ self.channels = channels
1796
+ self.embed_dim = embed_dim
1797
+
1798
+ self.proj = nn.Conv2d(
1799
+ channels,
1800
+ embed_dim,
1801
+ kernel_size=patch_size,
1802
+ stride=patch_size,
1803
+ bias=True,
1804
+ )
1805
+
1806
+ def forward(self, x: Tensor) -> Tensor:
1807
+ """
1808
+ Args:
1809
+ x: Tensor of shape [batch, channels, lat, lon].
1810
+ Returns:
1811
+ Tensor: Tensor with shape
1812
+ [batch, embed_dim, lat//patch_size, lon//patch_size]
1813
+ """
1814
+
1815
+ H, W = x.shape[-2:]
1816
+
1817
+ if W % self.patch_size[1] != 0:
1818
+ raise ValueError(
1819
+ f"Cannot do patch embedding for tensor of shape {x.size()}"
1820
+ " with patch size {self.patch_size}. (Dimensions are BSCHW.)"
1821
+ )
1822
+ if H % self.patch_size[0] != 0:
1823
+ raise ValueError(
1824
+ f"Cannot do patch embedding for tensor of shape {x.size()}"
1825
+ f" with patch size {self.patch_size}. (Dimensions are BSCHW.)"
1826
+ )
1827
+
1828
+ x = self.proj(x)
1829
+
1830
+ return x
1831
+
1832
+
1833
class PrithviWxCEncoderDecoder(nn.Module):
    """
    Hiera-MaxViT encoder/decoder code. Thin wrapper around a single
    `LocalGlobalLocalBlock`.
    """

    def __init__(
        self,
        embed_dim: int,
        n_blocks: int,
        mlp_multiplier: float,
        n_heads: int,
        dropout: float,
        drop_path: float,
        shifter: nn.Module | None = None,
        transformer_cp: list[int] | None = None,
    ) -> None:
        """
        Args:
            embed_dim: Embedding dimension
            n_blocks: Number of local-global transformer pairs.
            mlp_multiplier: MLP multiplier for hidden features in feed forward
                networks.
            n_heads: Number of attention heads.
            dropout: Dropout.
            drop_path: DropPath.
            shifter: Optional shifter module handed to the block.
            transformer_cp: Optional indices of transformer blocks to
                checkpoint, handed to the block.
        """
        super().__init__()

        self.embed_dim = embed_dim
        self.n_blocks = n_blocks
        self.mlp_multiplier = mlp_multiplier
        self.n_heads = n_heads
        self.dropout = dropout
        self._transformer_cp = transformer_cp

        block_config = dict(
            features=embed_dim,
            mlp_multiplier=mlp_multiplier,
            n_heads=n_heads,
            dropout=dropout,
            drop_path=drop_path,
            n_blocks=n_blocks,
            shifter=shifter,
            checkpoint=transformer_cp,
        )
        self.lgl_block = LocalGlobalLocalBlock(**block_config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape
                [batch, global sequence, local sequence, embed_dim]
        Returns:
            Tensor of shape
                [batch, mask_unit_sequence, local_sequence, embed_dim].
                Identical in shape to the input x.
        """
        return self.lgl_block(x)
1893
+
1894
+
1895
+ class PrithviWxC(nn.Module):
1896
+ """Encoder-decoder fusing Hiera with MaxViT. See
1897
+ - Ryali et al. "Hiera: A Hierarchical Vision Transformer without the
1898
+ Bells-and-Whistles" (https://arxiv.org/abs/2306.00989)
1899
+ - Tu et al. "MaxViT: Multi-Axis Vision Transformer"
1900
+ (https://arxiv.org/abs/2204.01697)
1901
+ """
1902
+
1903
+ def __init__(
1904
+ self,
1905
+ in_channels: int,
1906
+ input_size_time: int,
1907
+ in_channels_static: int,
1908
+ input_scalers_mu: Tensor,
1909
+ input_scalers_sigma: Tensor,
1910
+ input_scalers_epsilon: float,
1911
+ static_input_scalers_mu: Tensor,
1912
+ static_input_scalers_sigma: Tensor,
1913
+ static_input_scalers_epsilon: float,
1914
+ output_scalers: Tensor,
1915
+ n_lats_px: int,
1916
+ n_lons_px: int,
1917
+ patch_size_px: tuple[int],
1918
+ mask_unit_size_px: tuple[int],
1919
+ mask_ratio_inputs: float,
1920
+ embed_dim: int,
1921
+ n_blocks_encoder: int,
1922
+ n_blocks_decoder: int,
1923
+ mlp_multiplier: float,
1924
+ n_heads: int,
1925
+ dropout: float,
1926
+ drop_path: float,
1927
+ parameter_dropout: float,
1928
+ residual: str,
1929
+ masking_mode: str,
1930
+ positional_encoding: str,
1931
+ decoder_shifting: bool = False,
1932
+ checkpoint_encoder: list[int] | None = None,
1933
+ checkpoint_decoder: list[int] | None = None,
1934
+ ) -> None:
1935
+ """
1936
+ Args:
1937
+ in_channels: Number of input channels.
1938
+ input_size_time: Number of timestamps in input.
1939
+ in_channels_static: Number of input channels for static data.
1940
+ input_scalers_mu: Tensor of size (in_channels,). Used to rescale
1941
+ input.
1942
+ input_scalers_sigma: Tensor of size (in_channels,). Used to rescale
1943
+ input.
1944
+ input_scalers_epsilon: Float. Used to rescale input.
1945
+ static_input_scalers_mu: Tensor of size (in_channels_static). Used
1946
+ to rescale static inputs.
1947
+ static_input_scalers_sigma: Tensor of size (in_channels_static).
1948
+ Used to rescale static inputs.
1949
+ static_input_scalers_epsilon: Float. Used to rescale static inputs.
1950
+ output_scalers: Tensor of shape (in_channels,). Used to rescale
1951
+ output.
1952
+ n_lats_px: Total latitudes in data. In pixels.
1953
+ n_lons_px: Total longitudes in data. In pixels.
1954
+ patch_size_px: Patch size for tokenization. In pixels lat/lon.
1955
+ mask_unit_size_px: Size of each mask unit. In pixels lat/lon.
1956
+ mask_ratio_inputs: Masking ratio for inputs. 0 to 1.
1957
+ embed_dim: Embedding dimension
1958
+ n_blocks_encoder: Number of local-global transformer pairs in
1959
+ encoder.
1960
+ n_blocks_decoder: Number of local-global transformer pairs in
1961
+ decoder.
1962
+ mlp_multiplier: MLP multiplier for hidden features in feed forward
1963
+ networks.
1964
+ n_heads: Number of attention heads.
1965
+ dropout: Dropout.
1966
+ drop_path: DropPath.
1967
+ parameter_dropout: Dropout applied to parameters.
1968
+ residual: Indicates whether and how model should work as residual
1969
+ model. Accepted values are 'climate', 'temporal' and 'none'
1970
+ positional_encoding: possible values are
1971
+ ['absolute' (default), 'fourier'].
1972
+ 'absolute' lat lon encoded in 3 dimensions using sine and
1973
+ cosine
1974
+ 'fourier' lat/lon to be encoded using various frequencies
1975
+ masking_mode: String ['local', 'global', 'both'] that controls the
1976
+ type of masking used.
1977
+ checkpoint_encoder: List of integers controlling if gradient
1978
+ checkpointing is used on encoder.
1979
+ Format: [] for no gradient checkpointing. [3, 7] for
1980
+ checkpointing after 4th and 8th layer etc.
1981
+ checkpoint_decoder: List of integers controlling if gradient
1982
+ checkpointing is used on decoder.
1983
+ Format: See `checkpoint_encoder`.
1984
+ masking_mode: The type of masking to use
1985
+ {'global', 'local', 'both'}
1986
+ decoder_shifting: Whether to use swin shifting in the decoder.
1987
+ """
1988
+ super().__init__()
1989
+
1990
+ self.in_channels = in_channels
1991
+ self.input_size_time = input_size_time
1992
+ self.in_channels_static = in_channels_static
1993
+ self.n_lats_px = n_lats_px
1994
+ self.n_lons_px = n_lons_px
1995
+ self.patch_size_px = patch_size_px
1996
+ self.mask_unit_size_px = mask_unit_size_px
1997
+ self.mask_ratio_inputs = mask_ratio_inputs
1998
+ self.embed_dim = embed_dim
1999
+ self.n_blocks_encoder = n_blocks_encoder
2000
+ self.n_blocks_decoder = n_blocks_decoder
2001
+ self.mlp_multiplier = mlp_multiplier
2002
+ self.n_heads = n_heads
2003
+ self.dropout = dropout
2004
+ self.drop_path = drop_path
2005
+ self.residual = residual
2006
+ self._decoder_shift = decoder_shifting
2007
+ self.positional_encoding = positional_encoding
2008
+ self._checkpoint_encoder = checkpoint_encoder
2009
+ self._checkpoint_decoder = checkpoint_decoder
2010
+
2011
+ assert self.n_lats_px % self.mask_unit_size_px[0] == 0
2012
+ assert self.n_lons_px % self.mask_unit_size_px[1] == 0
2013
+ assert self.mask_unit_size_px[0] % self.patch_size_px[0] == 0
2014
+ assert self.mask_unit_size_px[1] % self.patch_size_px[1] == 0
2015
+
2016
+ if self.patch_size_px[0] != self.patch_size_px[1]:
2017
+ raise NotImplementedError(
2018
+ "Current pixel shuffle symmetric patches."
2019
+ )
2020
+
2021
+ self.local_shape_mu = (
2022
+ self.mask_unit_size_px[0] // self.patch_size_px[0],
2023
+ self.mask_unit_size_px[1] // self.patch_size_px[1],
2024
+ )
2025
+ self.global_shape_mu = (
2026
+ self.n_lats_px // self.mask_unit_size_px[0],
2027
+ self.n_lons_px // self.mask_unit_size_px[1],
2028
+ )
2029
+
2030
+ assert input_scalers_mu.shape == (in_channels,)
2031
+ assert input_scalers_sigma.shape == (in_channels,)
2032
+ assert output_scalers.shape == (in_channels,)
2033
+
2034
+ if self.positional_encoding != "fourier":
2035
+ assert static_input_scalers_mu.shape == (in_channels_static,)
2036
+ assert static_input_scalers_sigma.shape == (in_channels_static,)
2037
+
2038
+ # Input shape [batch, time, parameter, lat, lon]
2039
+ self.input_scalers_epsilon = input_scalers_epsilon
2040
+ self.register_buffer(
2041
+ "input_scalers_mu", input_scalers_mu.reshape(1, 1, -1, 1, 1)
2042
+ )
2043
+ self.register_buffer(
2044
+ "input_scalers_sigma", input_scalers_sigma.reshape(1, 1, -1, 1, 1)
2045
+ )
2046
+
2047
+ # Static inputs shape [batch, parameter, lat, lon]
2048
+ self.static_input_scalers_epsilon = static_input_scalers_epsilon
2049
+ self.register_buffer(
2050
+ "static_input_scalers_mu",
2051
+ static_input_scalers_mu.reshape(1, -1, 1, 1),
2052
+ )
2053
+ self.register_buffer(
2054
+ "static_input_scalers_sigma",
2055
+ static_input_scalers_sigma.reshape(1, -1, 1, 1),
2056
+ )
2057
+
2058
+ # Output shape [batch, parameter, lat, lon]
2059
+ self.register_buffer(
2060
+ "output_scalers", output_scalers.reshape(1, -1, 1, 1)
2061
+ )
2062
+
2063
+ self.parameter_dropout = nn.Dropout2d(p=parameter_dropout)
2064
+
2065
+ self.patch_embedding = PatchEmbed(
2066
+ patch_size=patch_size_px,
2067
+ channels=in_channels * input_size_time,
2068
+ embed_dim=embed_dim,
2069
+ )
2070
+
2071
+ if self.residual == "climate":
2072
+ self.patch_embedding_static = PatchEmbed(
2073
+ patch_size=patch_size_px,
2074
+ channels=in_channels + in_channels_static,
2075
+ embed_dim=embed_dim,
2076
+ )
2077
+ else:
2078
+ self.patch_embedding_static = PatchEmbed(
2079
+ patch_size=patch_size_px,
2080
+ channels=in_channels_static,
2081
+ embed_dim=embed_dim,
2082
+ )
2083
+
2084
+ self.input_time_embedding = nn.Linear(1, embed_dim // 4, bias=True)
2085
+ self.lead_time_embedding = nn.Linear(1, embed_dim // 4, bias=True)
2086
+
2087
+ self.mask_token = nn.Parameter(torch.randn(1, 1, 1, self.embed_dim))
2088
+ self._nglobal_mu = np.prod(self.global_shape_mu)
2089
+ self._global_idx = torch.arange(self._nglobal_mu)
2090
+
2091
+ self._nlocal_mu = np.prod(self.local_shape_mu)
2092
+ self._local_idx = torch.arange(self._nlocal_mu)
2093
+
2094
+ self.encoder = PrithviWxCEncoderDecoder(
2095
+ embed_dim=embed_dim,
2096
+ n_blocks=n_blocks_encoder,
2097
+ mlp_multiplier=mlp_multiplier,
2098
+ n_heads=n_heads,
2099
+ dropout=dropout,
2100
+ drop_path=drop_path,
2101
+ transformer_cp=checkpoint_encoder,
2102
+ )
2103
+
2104
+ if n_blocks_decoder != 0:
2105
+ if self._decoder_shift:
2106
+ self.decoder_shifter = d_shifter = SWINShift(
2107
+ self.mask_unit_size_px,
2108
+ self.global_shape_mu,
2109
+ self.local_shape_mu,
2110
+ self.patch_size_px,
2111
+ n_context_tokens=0,
2112
+ )
2113
+ else:
2114
+ self.decoder_shifter = d_shifter = None
2115
+
2116
+ self.decoder = PrithviWxCEncoderDecoder(
2117
+ embed_dim=embed_dim,
2118
+ n_blocks=n_blocks_decoder,
2119
+ mlp_multiplier=mlp_multiplier,
2120
+ n_heads=n_heads,
2121
+ dropout=dropout,
2122
+ drop_path=0.0,
2123
+ shifter=d_shifter,
2124
+ transformer_cp=checkpoint_decoder,
2125
+ )
2126
+
2127
+ self.unembed = nn.Linear(
2128
+ self.embed_dim,
2129
+ self.in_channels
2130
+ * self.patch_size_px[0]
2131
+ * self.patch_size_px[1],
2132
+ bias=True,
2133
+ )
2134
+
2135
+ self.masking_mode = masking_mode.lower()
2136
+ match self.masking_mode:
2137
+ case "local":
2138
+ self.generate_mask = self._gen_mask_local
2139
+ case "global":
2140
+ self.generate_mask = self._gen_mask_global
2141
+ case "both":
2142
+ self._mask_both_local: bool = True
2143
+ self.generate_mask = self._gen_mask_both
2144
+ case _:
2145
+ raise ValueError(
2146
+ f"Masking mode '{masking_mode}' not supported"
2147
+ )
2148
+
2149
+ def swap_masking(self) -> None:
2150
+ self._mask_both_local = not self._mask_both_local
2151
+
2152
+ @cached_property
2153
+ def n_masked_global(self):
2154
+ return int(self.mask_ratio_inputs * np.prod(self.global_shape_mu))
2155
+
2156
+ @cached_property
2157
+ def n_masked_local(self):
2158
+ return int(self.mask_ratio_inputs * np.prod(self.local_shape_mu))
2159
+
2160
+ @staticmethod
2161
+ def _shuffle_along_axis(a, axis):
2162
+ idx = torch.argsort(input=torch.rand(*a.shape), dim=axis)
2163
+ return torch.gather(a, dim=axis, index=idx)
2164
+
2165
+ def _gen_mask_local(self, sizes: tuple[int]) -> tuple[Tensor]:
2166
+ """
2167
+ Args:
2168
+ batch_size: Number of elements in batch
2169
+ Returns:
2170
+ Tuple of torch tensors. [indices masked, indices unmasked].
2171
+ Each of these is a tensor of shape (batch, global sequene)
2172
+ """
2173
+ # Identify which indices (values) should be masked
2174
+
2175
+ maskable_indices = self._local_idx.view(1, -1).expand(*sizes[:2], -1)
2176
+
2177
+ maskable_indices = self._shuffle_along_axis(maskable_indices, 2)
2178
+
2179
+ indices_masked = maskable_indices[:, :, : self.n_masked_local]
2180
+ indices_unmasked = maskable_indices[:, :, self.n_masked_local :]
2181
+
2182
+ return indices_masked, indices_unmasked
2183
+
2184
+ def _gen_mask_global(self, sizes: tuple[int]) -> tuple[Tensor]:
2185
+ """
2186
+ Args:
2187
+ batch_size: Number of elements in batch
2188
+ Returns:
2189
+ Tuple of torch tensors. [indices masked, indices unmasked].
2190
+ Each of these is a tensor of shape (batch, global sequene)
2191
+ """
2192
+ # Identify which indices (values) should be masked
2193
+
2194
+ maskable_indices = self._global_idx.view(1, -1).expand(*sizes[:1], -1)
2195
+
2196
+ maskable_indices = self._shuffle_along_axis(maskable_indices, 1)
2197
+
2198
+ indices_masked = maskable_indices[:, : self.n_masked_global]
2199
+ indices_unmasked = maskable_indices[:, self.n_masked_global :]
2200
+
2201
+ return indices_masked, indices_unmasked
2202
+
2203
+ def _gen_mask_both(self, sizes: tuple[int]) -> tuple[Tensor]:
2204
+ if self._mask_both_local:
2205
+ return self._gen_mask_local(sizes)
2206
+ else:
2207
+ return self._gen_mask_global(sizes)
2208
+
2209
+ @staticmethod
2210
+ def reconstruct_batch(
2211
+ idx_masked: Tensor,
2212
+ idx_unmasked: Tensor,
2213
+ data_masked: Tensor,
2214
+ data_unmasked: Tensor,
2215
+ ) -> Tensor:
2216
+ """Reconstructs a tensor along the mask unit dimension. Batched
2217
+ version.
2218
+
2219
+ Args:
2220
+ idx_masked: Tensor of shape `batch, mask unit sequence`.
2221
+ idx_unmasked: Tensor of shape `batch, mask unit sequence`.
2222
+ data_masked: Tensor of shape `batch, mask unit sequence, ...`.
2223
+ Should have same size along mask unit sequence dimension as
2224
+ idx_masked. Dimensions beyond the first two, marked here as ...
2225
+ will typically be `local_sequence, channel` or
2226
+ `channel, lat, lon`. These dimensions should agree with
2227
+ data_unmasked.
2228
+ data_unmasked: Tensor of shape `batch, mask unit sequence, ...`.
2229
+ Should have same size along mask unit sequence dimension as
2230
+ idx_unmasked. Dimensions beyond the first two, marked here as
2231
+ ... will typically be `local_sequence, channel` or `channel,
2232
+ lat, lon`. These dimensions should agree with data_masked.
2233
+ Returns:
2234
+ Tensor: Tensor of same shape as inputs data_masked and
2235
+ data_unmasked. I.e. `batch, mask unit sequence, ...`. Index for
2236
+ the total data composed of the masked and the unmasked part.
2237
+ """
2238
+ dim: int = idx_masked.ndim
2239
+
2240
+ idx_total = torch.argsort(
2241
+ torch.cat([idx_masked, idx_unmasked], dim=-1), dim=-1
2242
+ )
2243
+ idx_total = idx_total.view(
2244
+ *idx_total.shape, *[1] * (data_unmasked.ndim - dim)
2245
+ )
2246
+ idx_total = idx_total.expand(
2247
+ *idx_total.shape[:dim], *data_unmasked.shape[dim:]
2248
+ )
2249
+
2250
+ data = torch.cat([data_masked, data_unmasked], dim=dim - 1)
2251
+ data = torch.gather(data, dim=dim - 1, index=idx_total)
2252
+
2253
+ return data, idx_total
2254
+
2255
+ def fourier_pos_encoding(self, x_static: Tensor) -> Tensor:
2256
+ """
2257
+ Args
2258
+ x_static: B x C x H x W. first two channels are lat, and lon
2259
+ Returns
2260
+ Tensor: Tensor of shape B x E x H x W where E is the embedding
2261
+ dimension.
2262
+ """
2263
+
2264
+ # B x C x H x W -> B x 1 x H/P x W/P
2265
+ latitudes_patch = F.avg_pool2d(
2266
+ x_static[:, [0]],
2267
+ kernel_size=self.patch_size_px,
2268
+ stride=self.patch_size_px,
2269
+ )
2270
+ longitudes_patch = F.avg_pool2d(
2271
+ x_static[:, [1]],
2272
+ kernel_size=self.patch_size_px,
2273
+ stride=self.patch_size_px,
2274
+ )
2275
+
2276
+ modes = (
2277
+ torch.arange(self.embed_dim // 4, device=x_static.device).view(
2278
+ 1, -1, 1, 1
2279
+ )
2280
+ + 1.0
2281
+ )
2282
+ pos_encoding = torch.cat(
2283
+ (
2284
+ torch.sin(latitudes_patch * modes),
2285
+ torch.sin(longitudes_patch * modes),
2286
+ torch.cos(latitudes_patch * modes),
2287
+ torch.cos(longitudes_patch * modes),
2288
+ ),
2289
+ axis=1,
2290
+ )
2291
+
2292
+ return pos_encoding # B x E x H/P x W/P
2293
+
2294
+ def time_encoding(self, input_time, lead_time):
2295
+ """
2296
+ Args:
2297
+ input_time: Tensor of shape [batch].
2298
+ lead_time: Tensor of shape [batch].
2299
+ Returns:
2300
+ Tensor: Tensor of shape [batch, embed_dim, 1, 1]
2301
+ """
2302
+ input_time = self.input_time_embedding(input_time.view(-1, 1, 1, 1))
2303
+ lead_time = self.lead_time_embedding(lead_time.view(-1, 1, 1, 1))
2304
+
2305
+ time_encoding = torch.cat(
2306
+ (
2307
+ torch.cos(input_time),
2308
+ torch.cos(lead_time),
2309
+ torch.sin(input_time),
2310
+ torch.sin(lead_time),
2311
+ ),
2312
+ axis=3,
2313
+ )
2314
+ return time_encoding
2315
+
2316
+ def to_patching(self, x: Tensor) -> Tensor:
2317
+ """Transform data from lat/lon space to two axis patching
2318
+
2319
+ Args: ->
2320
+ x: Tesnor in lat/lon space (N, C, Nlat//P_0, Nlon//P_1)
2321
+
2322
+ Returns:
2323
+ Tensor in patch space (N, G, L, C)
2324
+ """
2325
+ n_batch = x.shape[0]
2326
+
2327
+ x = x.view(
2328
+ n_batch,
2329
+ -1,
2330
+ self.global_shape_mu[0],
2331
+ self.local_shape_mu[0],
2332
+ self.global_shape_mu[1],
2333
+ self.local_shape_mu[1],
2334
+ )
2335
+ x = x.permute(0, 2, 4, 3, 5, 1).contiguous()
2336
+
2337
+ s = x.shape
2338
+ return x.view(n_batch, s[1] * s[2], s[3] * s[4], -1)
2339
+
2340
+ def from_patching(self, x: Tensor) -> Tensor:
2341
+ """Transform data from two axis patching to lat/lon space
2342
+
2343
+ Args:
2344
+ x: Tensor in patch space with shape (N, G, L, C*P_0*P_1)
2345
+
2346
+ Returns:
2347
+ Tensor: Tensor in lat/lon space
2348
+ (N, C*P_0*P_1, Nlat//P_0, Nlon // P_1)
2349
+ """
2350
+ n_batch = x.shape[0]
2351
+
2352
+ x = x.view(
2353
+ n_batch,
2354
+ self.global_shape_mu[0],
2355
+ self.global_shape_mu[1],
2356
+ self.local_shape_mu[0],
2357
+ self.local_shape_mu[1],
2358
+ -1,
2359
+ )
2360
+ x = x.permute(0, 5, 1, 3, 2, 4).contiguous()
2361
+
2362
+ s = x.shape
2363
+ return x.view(n_batch, -1, s[2] * s[3], s[4] * s[5])
2364
+
2365
+ def forward(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
2366
+ """
2367
+ Args:
2368
+ batch: Dictionary the following keys::
2369
+
2370
+ 'x': Tensor of shape [batch, time, parameter, lat, lon]
2371
+ 'y': Tensor of shape [batch, parameter, lat, lon]
2372
+ 'static': Tensor of shape [batch, channel_static, lat, lon]
2373
+ 'climate': Optional tensor of shape [batch, parameter, lat, lon]
2374
+ 'input_time': Tensor of shape [batch]. Or none.
2375
+ 'lead_time': Tensor of shape [batch]. Or none.
2376
+
2377
+ Returns:
2378
+ Tensor: Tensor of shape [batch, parameter, lat, lon].
2379
+ """ # noqa: E501
2380
+ x_rescaled = (batch["x"] - self.input_scalers_mu) / (
2381
+ self.input_scalers_sigma + self.input_scalers_epsilon
2382
+ )
2383
+ batch_size = x_rescaled.shape[0]
2384
+
2385
+ if self.positional_encoding == "fourier":
2386
+ x_static_pos = self.fourier_pos_encoding(batch["static"])
2387
+ x_static = (
2388
+ batch["static"][:, 2:] - self.static_input_scalers_mu[:, 3:]
2389
+ ) / (
2390
+ self.static_input_scalers_sigma[:, 3:]
2391
+ + self.static_input_scalers_epsilon
2392
+ )
2393
+ else:
2394
+ x_static = (batch["static"] - self.static_input_scalers_mu) / (
2395
+ self.static_input_scalers_sigma
2396
+ + self.static_input_scalers_epsilon
2397
+ )
2398
+
2399
+ if self.residual == "temporal":
2400
+ # We create a residual of same shape as y
2401
+ index = torch.where(
2402
+ batch["lead_time"] > 0, batch["x"].shape[1] - 1, 0
2403
+ )
2404
+ index = index.view(-1, 1, 1, 1, 1)
2405
+ index = index.expand(batch_size, 1, *batch["x"].shape[2:])
2406
+ x_hat = torch.gather(batch["x"], dim=1, index=index)
2407
+ x_hat = x_hat.squeeze(1)
2408
+ elif self.residual == "climate":
2409
+ climate_scaled = (
2410
+ batch["climate"] - self.input_scalers_mu.view(1, -1, 1, 1)
2411
+ ) / (
2412
+ self.input_scalers_sigma.view(1, -1, 1, 1)
2413
+ + self.input_scalers_epsilon
2414
+ )
2415
+
2416
+ # [batch, time, parameter, lat, lon]
2417
+ # -> [batch, time x parameter, lat, lon]
2418
+ x_rescaled = x_rescaled.flatten(1, 2)
2419
+ # Parameter dropout
2420
+ x_rescaled = self.parameter_dropout(x_rescaled)
2421
+
2422
+ x_embedded = self.patch_embedding(x_rescaled)
2423
+
2424
+ if self.residual == "climate":
2425
+ static_embedded = self.patch_embedding_static(
2426
+ torch.cat((x_static, climate_scaled), dim=1)
2427
+ )
2428
+ else:
2429
+ static_embedded = self.patch_embedding_static(x_static)
2430
+
2431
+ if self.positional_encoding == "fourier":
2432
+ static_embedded += x_static_pos
2433
+
2434
+ x_embedded = self.to_patching(x_embedded)
2435
+ static_embedded = self.to_patching(static_embedded)
2436
+
2437
+ time_encoding = self.time_encoding(
2438
+ batch["input_time"], batch["lead_time"]
2439
+ )
2440
+
2441
+ tokens = x_embedded + static_embedded + time_encoding
2442
+
2443
+ # Now we generate masks based on masking_mode
2444
+ indices_masked, indices_unmasked = self.generate_mask(
2445
+ (batch_size, self._nglobal_mu)
2446
+ )
2447
+ indices_masked = indices_masked.to(device=tokens.device)
2448
+ indices_unmasked = indices_unmasked.to(device=tokens.device)
2449
+ maskdim: int = indices_masked.ndim
2450
+
2451
+ # Unmasking
2452
+ unmask_view = (*indices_unmasked.shape, *[1] * (tokens.ndim - maskdim))
2453
+ unmasked = torch.gather(
2454
+ tokens,
2455
+ dim=maskdim - 1,
2456
+ index=indices_unmasked.view(*unmask_view).expand(
2457
+ *indices_unmasked.shape, *tokens.shape[maskdim:]
2458
+ ),
2459
+ )
2460
+
2461
+ # Encoder
2462
+ x_encoded = self.encoder(unmasked)
2463
+
2464
+ # Generate and position encode the mask tokens
2465
+ # [1, 1, 1, embed_dim]
2466
+ # -> [batch, global_seq_masked, local seq, embed_dim]
2467
+ mask_view = (*indices_masked.shape, *[1] * (tokens.ndim - maskdim))
2468
+ masking = self.mask_token.repeat(*static_embedded.shape[:3], 1)
2469
+ masked = masking + static_embedded
2470
+ masked = torch.gather(
2471
+ masked,
2472
+ dim=maskdim - 1,
2473
+ index=indices_masked.view(*mask_view).expand(
2474
+ *indices_masked.shape, *tokens.shape[maskdim:]
2475
+ ),
2476
+ )
2477
+
2478
+ recon, _ = self.reconstruct_batch(
2479
+ indices_masked, indices_unmasked, masked, x_encoded
2480
+ )
2481
+
2482
+ x_decoded = self.decoder(recon)
2483
+
2484
+ # Output: [batch, global sequence, local sequence,
2485
+ # in_channels * patch_size[0] * patch_size[1]]
2486
+ x_unembed = self.unembed(x_decoded)
2487
+
2488
+ # Reshape to [batch, global_lat, global_lon, local_lat, local_lon,
2489
+ # in_channels * patch_size[0] * patch_size[1]]
2490
+ x_out = self.from_patching(x_unembed)
2491
+
2492
+ # Pixel shuffle to [batch, in_channels, lat, lon]
2493
+ x_out = F.pixel_shuffle(x_out, self.patch_size_px[0])
2494
+
2495
+ if self.residual == "temporal":
2496
+ x_out = self.output_scalers * x_out + x_hat
2497
+ elif self.residual == "climate":
2498
+ x_out = self.output_scalers * x_out + batch["climate"]
2499
+ elif self.residual == "none":
2500
+ x_out = (
2501
+ self.output_scalers * x_out
2502
+ + self.input_scalers_mu.reshape(1, -1, 1, 1)
2503
+ )
2504
+
2505
+ return x_out
__pycache__/Prithvi.cpython-310.pyc ADDED
Binary file (71.5 kB). View file
 
__pycache__/Prithvi.cpython-312.pyc ADDED
Binary file (111 kB). View file
 
__pycache__/aurora_utils.cpython-310.pyc ADDED
Binary file (3.07 kB). View file
 
__pycache__/config_utils.cpython-310.pyc ADDED
Binary file (445 Bytes). View file
 
__pycache__/data_utils.cpython-310.pyc ADDED
Binary file (1.33 kB). View file
 
__pycache__/fengwu_utils.cpython-310.pyc ADDED
Binary file (5.37 kB). View file
 
__pycache__/inference_utils.cpython-310.pyc ADDED
Binary file (863 Bytes). View file
 
__pycache__/pangu_utils.cpython-310.pyc ADDED
Binary file (8.3 kB). View file
 
__pycache__/plot_utils.cpython-310.pyc ADDED
Binary file (3.99 kB). View file
 
__pycache__/prithvi_utils.cpython-310.pyc ADDED
Binary file (3.47 kB). View file
 
app.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import random
4
+ import numpy as np
5
+ import yaml
6
+ import logging
7
+ import os
8
+ import matplotlib.pyplot as plt
9
+ from pathlib import Path
10
+ import tempfile
11
+ import traceback
12
+
13
+ from data_utils import (
14
+ save_uploaded_files,
15
+ load_dataset,
16
+ )
17
+
18
+ from inference_utils import run_inference
19
+ from config_utils import load_config
20
+ from plot_utils import plot_prithvi_output, plot_aurora_output
21
+ from prithvi_utils import (
22
+ prithvi_config_ui,
23
+ initialize_prithvi_model,
24
+ prepare_prithvi_batch
25
+ )
26
+ from aurora_utils import aurora_config_ui, prepare_aurora_batch, initialize_aurora_model
27
+
28
+ from pangu_utils import (
29
+ pangu_config_data,
30
+ inference_1hr,
31
+ inference_3hrs,
32
+ inference_6hrs,
33
+ inference_24hrs,
34
+ inference_custom_hrs,
35
+ plot_pangu_output,
36
+ )
37
+
38
+ from fengwu_utils import (fengwu_config_data, inference_6hrs_fengwu, inference_12hrs_fengwu, inference_custom_hrs_fengwu, plot_fengwu_output)
39
+
40
+
41
# Streamlit front end: pick a forecasting model, upload its inputs, run
# inference and plot the result. Results are cached in st.session_state so
# they survive Streamlit's script re-runs.
# NOTE(review): block nesting below was reconstructed from a view of the file
# that had lost its indentation; the left/right column placement of the
# duration selector and the run button should be confirmed against the
# original layout.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set page configuration
st.set_page_config(
    page_title="Weather Data Processor",
    layout="wide",
    initial_sidebar_state="expanded",
)

header_col1, header_col2 = st.columns([4, 1])
with header_col1:
    st.title("🌦️ Weather & Climate Data Processor and Forecaster")

with header_col2:
    st.markdown("### Select a Model")
    selected_model = st.selectbox(
        "",
        options=["Pangu-Weather", "FengWu", "Aurora", "Climax", "Prithvi", "LSTM"],
        index=0,
        key="model_selector",
        help="Select the model you want to use."
    )

st.write("---")

# --- Layout: Two Columns ---
left_col, right_col = st.columns([1, 2])

with left_col:
    st.header("🔧 Configuration")

    # Dynamically show configuration UI based on selected model
    if selected_model == "Prithvi":
        (config, uploaded_surface_files, uploaded_vertical_files,
         clim_surf_path, clim_vert_path, config_path, weights_path) = prithvi_config_ui()
    elif selected_model == "Aurora":
        uploaded_files = aurora_config_ui()
    elif selected_model == "Pangu-Weather":
        input_surface_file, input_upper_file = pangu_config_data()
    elif selected_model == "FengWu":
        input_file1_fengwu, input_file2_fengwu = fengwu_config_data()
    else:
        # Generic data upload for other models
        st.subheader(f"{selected_model} Model Data Upload")
        st.markdown("### Drag and Drop Your Data Files Here")
        uploaded_files = st.file_uploader(
            f"Upload Data Files for {selected_model}",
            accept_multiple_files=True,
            key=f"{selected_model.lower()}_uploader",
            type=["nc", "netcdf", "nc4"],
        )

    st.write("---")

    # --- Forecast Duration Selection ---
    st.subheader("Forecast Duration")
    forecast_options = ["1 hour", "3 hours", "6 hours", "24 hours", "Custom"]
    selected_duration = st.selectbox(
        "Select forecast duration",
        forecast_options,
        index=3,  # Default to 24 hours
        help="Select how many hours to forecast."
    )

    # Only set when the user picks "Custom"; None for the preset durations.
    custom_hours = None
    if selected_duration == "Custom":
        custom_hours = st.number_input(
            "Enter custom forecast hours",
            min_value=24,
            max_value=480,
            value=48,
            step=24,
            help="Enter the number of hours you want to forecast."
        )

    st.write("---")

    # Run Inference button
    if st.button("🚀 Run Inference"):
        with right_col:
            st.header("📈 Inference Progress & Visualization")

            # Set seeds and device
            try:
                torch.jit.enable_onednn_fusion(True)
                if torch.cuda.is_available():
                    device = torch.device("cuda")
                    st.write(f"Using device: **{torch.cuda.get_device_name()}**")
                    torch.backends.cudnn.benchmark = True
                    torch.backends.cudnn.deterministic = True
                else:
                    device = torch.device("cpu")
                    st.write("Using device: **CPU**")

                random.seed(42)
                if torch.cuda.is_available():
                    torch.cuda.manual_seed(42)
                torch.manual_seed(42)
                np.random.seed(42)
            except Exception:
                st.error("Error initializing device:")
                st.error(traceback.format_exc())
                st.stop()

            # Use a spinner while running inference
            with st.spinner("Running inference, please wait..."):
                # Initialize and run inference for selected model
                if selected_model == "Prithvi":
                    model, in_mu, in_sig, output_sig, static_mu, static_sig = initialize_prithvi_model(
                        config, config_path, weights_path, device
                    )
                    batch = prepare_prithvi_batch(
                        uploaded_surface_files, uploaded_vertical_files, clim_surf_path, clim_vert_path, device
                    )
                    out = run_inference(selected_model, model, batch, device)
                    # Store results
                    st.session_state['prithvi_out'] = out
                    st.session_state['prithvi_done'] = True

                elif selected_model == "Aurora":
                    if uploaded_files:
                        save_uploaded_files(uploaded_files)
                        ds = load_dataset(st.session_state.temp_file_paths)
                        if ds is not None:
                            batch = prepare_aurora_batch(ds)
                            model = initialize_aurora_model(device)
                            out = run_inference(selected_model, model, batch, device)
                            st.session_state['aurora_out'] = out
                            st.session_state['aurora_ds_subset'] = ds
                            st.session_state['aurora_done'] = True
                        else:
                            st.error("Failed to load dataset for Aurora.")
                            st.stop()
                    else:
                        st.error("Please upload data files for Aurora.")
                        st.stop()

                elif selected_model == "FengWu":
                    if input_file1_fengwu and input_file2_fengwu:
                        try:
                            input1 = np.load(input_file1_fengwu)
                            input2 = np.load(input_file2_fengwu)
                            # Stays None for the durations FengWu cannot
                            # serve, so nothing undefined is stored below.
                            output_fengwu = None
                            if selected_duration == "1 hour":
                                st.warning("1hr inference is not yet available on this model.")
                            elif selected_duration == "3 hours":
                                st.warning("3hrs inference is not yet available on this model.")
                            elif selected_duration == "6 hours":
                                output_fengwu = inference_6hrs_fengwu(input1, input2)
                            elif selected_duration == "12 hours":
                                output_fengwu = inference_12hrs_fengwu(input1, input2)
                            else:
                                # Reached for "24 hours" and "Custom". With
                                # the 24-hour preset custom_hours is None, so
                                # pass the selected 24 explicitly.
                                # NOTE(review): presumably the helper expects
                                # a number of hours; confirm in fengwu_utils.
                                fengwu_hours = custom_hours if custom_hours is not None else 24
                                output_fengwu = inference_custom_hrs_fengwu(input1, input2, fengwu_hours)

                            if output_fengwu is not None:
                                st.session_state['output_fengwu'] = output_fengwu
                                st.session_state['fengwu_done'] = True
                                st.session_state['input_fengwu'] = input_file2_fengwu
                        except Exception as e:
                            st.error(f"An error occurred: {e}")
                    else:
                        st.error("Please upload data files for FengWu.")
                        st.stop()

                elif selected_model == "Pangu-Weather":
                    if input_surface_file and input_upper_file:
                        try:
                            surface_data = np.load(input_surface_file)
                            upper_data = np.load(input_upper_file)

                            # Decide which inference function to use based on selection
                            if selected_duration == "1 hour":
                                out_upper, out_surface = inference_1hr(upper_data, surface_data)
                            elif selected_duration == "3 hours":
                                out_upper, out_surface = inference_3hrs(upper_data, surface_data)
                            elif selected_duration == "6 hours":
                                out_upper, out_surface = inference_6hrs(upper_data, surface_data)
                            elif selected_duration == "24 hours":
                                out_upper, out_surface = inference_24hrs(upper_data, surface_data)
                            else:
                                out_upper, out_surface = inference_custom_hrs(upper_data, surface_data, custom_hours)

                            # Store results in session_state
                            st.session_state['pangu_upper_data'] = upper_data
                            st.session_state['pangu_surface_data'] = surface_data
                            st.session_state['pangu_out_upper'] = out_upper
                            st.session_state['pangu_out_surface'] = out_surface
                            st.session_state['pangu_done'] = True

                            st.write("**Forecast Results:**")
                            st.write("Upper Data Forecast Shape:", out_upper.shape)
                            st.write("Surface Data Forecast Shape:", out_surface.shape)

                        except Exception as e:
                            st.error(f"An error occurred: {e}")
                    else:
                        st.error("Please upload data files for Pangu-Weather.")
                        st.stop()

                else:
                    st.warning("Inference not implemented for this model.")
                    st.stop()

            # Visualization after inference is done
            if selected_model == "Prithvi":
                if 'prithvi_done' in st.session_state and st.session_state['prithvi_done']:
                    plot_prithvi_output(st.session_state['prithvi_out'])
            elif selected_model == "Aurora":
                if 'aurora_done' in st.session_state and st.session_state['aurora_done']:
                    plot_aurora_output(st.session_state['aurora_out'], st.session_state['aurora_ds_subset'])
            elif selected_model == "FengWu":
                if 'fengwu_done' in st.session_state and st.session_state['fengwu_done']:
                    plot_fengwu_output(st.session_state['input_fengwu'], st.session_state['output_fengwu'])
            elif selected_model == "Pangu-Weather":
                if 'pangu_done' in st.session_state and st.session_state['pangu_done']:
                    plot_pangu_output(
                        st.session_state['pangu_upper_data'],
                        st.session_state['pangu_surface_data'],
                        st.session_state['pangu_out_upper'],
                        st.session_state['pangu_out_surface']
                    )
            else:
                st.info("No visualization implemented for this model.")

    else:
        # If not running inference now, but we have previously computed results, show them
        with right_col:
            st.header("🖥️ Visualization & Progress")

            # Check which model was selected and if we have done inference before
            if selected_model == "Prithvi" and 'prithvi_done' in st.session_state and st.session_state['prithvi_done']:
                plot_prithvi_output(st.session_state['prithvi_out'])
            elif selected_model == "Aurora" and 'aurora_done' in st.session_state and st.session_state['aurora_done']:
                plot_aurora_output(st.session_state['aurora_out'], st.session_state['aurora_ds_subset'])
            elif selected_model == "Pangu-Weather" and 'pangu_done' in st.session_state and st.session_state['pangu_done']:
                plot_pangu_output(
                    st.session_state['pangu_upper_data'],
                    st.session_state['pangu_surface_data'],
                    st.session_state['pangu_out_upper'],
                    st.session_state['pangu_out_surface']
                )
            # Test the same key that is read, matching the other branches.
            elif selected_model == "FengWu" and 'fengwu_done' in st.session_state and st.session_state['fengwu_done']:
                plot_fengwu_output(st.session_state['input_fengwu'], st.session_state['output_fengwu'])
            else:
                st.info("Awaiting inference to display results.")
285
+
app1.py ADDED
@@ -0,0 +1,477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import random
4
+ import numpy as np
5
+ import yaml
6
+ from pathlib import Path
7
+ from io import BytesIO
8
+ import random
9
+ from pathlib import Path
10
+ import matplotlib.pyplot as plt
11
+ import numpy as np
12
+ import torch
13
+ from huggingface_hub import hf_hub_download, snapshot_download
14
+ import tempfile
15
+ import traceback
16
+ import functools as ft
17
+ import os
18
+ import random
19
+ import re
20
+ from collections import defaultdict
21
+ from datetime import datetime, timedelta
22
+ from pathlib import Path
23
+ import h5py
24
+ import numpy as np
25
+ import pandas as pd
26
+ import torch
27
+ from torch import Tensor
28
+ from torch.utils.data import Dataset
29
+ import logging
30
+ from Prithvi import *
31
+
32
+
33
# Logging: INFO-level root configuration plus a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Page-level Streamlit settings.
st.set_page_config(
    page_title="MERRA2 Data Processor",
    layout="wide",
    initial_sidebar_state="expanded",
)

# All input widgets of this app are placed in the sidebar.
_sidebar = st.sidebar
_NETCDF_TYPES = ("nc", "netcdf")

dataset_type = _sidebar.selectbox("Select Dataset Type", options=["MERRA2", "GEOS5"], index=0)
st.title("MERRA2 Data Processor with PrithviWxC Model")

# --- Data uploads (surface and vertical files, several files each) ---
_sidebar.header("Upload MERRA2 Data Files")
uploaded_surface_files = _sidebar.file_uploader(
    "Upload Surface Data Files",
    type=list(_NETCDF_TYPES),
    accept_multiple_files=True,
    key="surface_uploader",
)
uploaded_vertical_files = _sidebar.file_uploader(
    "Upload Vertical Data Files",
    type=list(_NETCDF_TYPES),
    accept_multiple_files=True,
    key="vertical_uploader",
)

# --- Optional overrides: model configuration and weights ---
uploaded_config = _sidebar.file_uploader(
    "Upload config.yaml",
    type=["yaml", "yml"],
    key="config_uploader",
)
uploaded_weights = _sidebar.file_uploader(
    "Upload Model Weights (.pt)",
    type=["pt"],
    key="weights_uploader",
)

# --- Task settings ---
_sidebar.header("Task Configuration")
lead_times = _sidebar.multiselect("Select Lead Times", options=[12, 24, 36, 48], default=[12])
input_times = _sidebar.multiselect("Select Input Times", options=[-6, -12, -18, -24], default=[-6])

time_range_start = _sidebar.text_input(
    "Start Time (e.g., 2020-01-01T00:00:00)",
    value="2020-01-01T00:00:00",
)
time_range_end = _sidebar.text_input(
    "End Time (e.g., 2020-01-01T23:59:59)",
    value="2020-01-01T23:59:59",
)

# (start, end) pair of the raw strings typed by the user; passed on as-is.
time_range = (time_range_start, time_range_end)
110
+
111
+ # Function to save uploaded files
112
def save_uploaded_files(uploaded_files, folder_name, max_size_mb=1024):
    """Write uploaded files into a fresh temporary directory.

    Args:
        uploaded_files: Iterable of uploaded-file objects exposing ``name``,
            ``size`` and ``getbuffer()``. May be empty or ``None``.
        folder_name: Human-readable label used only in the status messages.
        max_size_mb: Per-file size limit in megabytes.

    Returns:
        ``Path`` of the temporary directory holding the saved files, or
        ``None`` when nothing was uploaded, a file exceeds the size limit, or
        a file has an unusable name. Nothing is written to disk in the
        ``None`` cases.
    """
    if not uploaded_files:
        st.warning(f"No {folder_name} files uploaded.")
        return None

    max_bytes = max_size_mb * 1024 * 1024
    validated = []
    for uploaded_file in uploaded_files:
        if uploaded_file.size > max_bytes:
            st.error(f"File {uploaded_file.name} exceeds the maximum size of {max_size_mb} MB.")
            return None
        # The name is client-supplied: keep only its final component so that
        # values such as "../x.nc" cannot escape the temporary directory.
        safe_name = Path(uploaded_file.name.replace("\\", "/")).name
        if safe_name in ("", ".", ".."):
            st.error(f"File {uploaded_file.name!r} has an invalid name.")
            return None
        validated.append((safe_name, uploaded_file))

    # Create the directory only after every file passed validation.
    temp_dir = Path(tempfile.mkdtemp())
    with st.spinner(f"Saving {folder_name} files..."):
        for safe_name, uploaded_file in validated:
            with open(temp_dir / safe_name, "wb") as f:
                f.write(uploaded_file.getbuffer())
    st.success(f"Saved {len(uploaded_files)} {folder_name} files.")
    return temp_dir
129
+
130
# Persist the uploaded data files; each call returns a temp directory or None.
surf_dir = save_uploaded_files(uploaded_surface_files, "surface")
vert_dir = save_uploaded_files(uploaded_vertical_files, "vertical")

# List the saved files in the sidebar.
if surf_dir:
    st.sidebar.subheader("Surface Files Uploaded:")
    for file in surf_dir.iterdir():
        st.sidebar.write(file.name)

if vert_dir:
    st.sidebar.subheader("Vertical Files Uploaded:")
    for file in vert_dir.iterdir():
        st.sidebar.write(file.name)

# --- Climatology files ---
st.sidebar.header("Upload Climatology Files (If Missing)")

# Default locations of the scaler/climatology files.
default_clim_dir = Path("Prithvi-WxC/examples/climatology")
surf_in_scal_path = default_clim_dir / "musigma_surface.nc"
vert_in_scal_path = default_clim_dir / "musigma_vertical.nc"
surf_out_scal_path = default_clim_dir / "anomaly_variance_surface.nc"
vert_out_scal_path = default_clim_dir / "anomaly_variance_vertical.nc"

clim_files_exist = all(
    path.exists()
    for path in (surf_in_scal_path, vert_in_scal_path, surf_out_scal_path, vert_out_scal_path)
)

# Define these on every path so later code can test them instead of hitting
# a NameError when the defaults exist or the uploads are incomplete.
uploaded_clim_surface = None
uploaded_clim_vertical = None
clim_surf_path = None
clim_vert_path = None

if clim_files_exist:
    clim_surf_path = surf_in_scal_path
    clim_vert_path = vert_in_scal_path
else:
    st.sidebar.warning("Climatology files are missing.")
    uploaded_clim_surface = st.sidebar.file_uploader(
        "Upload Climatology Surface File",
        type=["nc", "netcdf"],
        key="clim_surface_uploader",
    )
    uploaded_clim_vertical = st.sidebar.file_uploader(
        "Upload Climatology Vertical File",
        type=["nc", "netcdf"],
        key="clim_vertical_uploader",
    )

    if uploaded_clim_surface and uploaded_clim_vertical:
        clim_temp_dir = tempfile.mkdtemp()
        # Keep only the final path component of the client-supplied names.
        clim_surf_path = Path(clim_temp_dir) / Path(uploaded_clim_surface.name).name
        with open(clim_surf_path, "wb") as f:
            f.write(uploaded_clim_surface.getbuffer())
        clim_vert_path = Path(clim_temp_dir) / Path(uploaded_clim_vertical.name).name
        with open(clim_vert_path, "wb") as f:
            f.write(uploaded_clim_vertical.getbuffer())
        st.success("Climatology files uploaded and saved.")
    else:
        st.warning("Please upload both climatology surface and vertical files.")

# --- config.yaml: uploaded copy or the bundled default ---
if uploaded_config:
    # NamedTemporaryFile creates the file atomically; tempfile.mktemp only
    # returns a name and is deprecated because of the resulting race.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".yaml") as f:
        f.write(uploaded_config.getbuffer())
    config_path = Path(f.name)
    st.sidebar.success("Config.yaml uploaded and saved.")
else:
    config_path = Path("Prithvi-WxC/examples/config.yaml")
    if not config_path.exists():
        st.sidebar.error("Default config.yaml not found. Please upload a config file.")
        st.stop()

# --- Model weights: uploaded copy or the bundled default ---
if uploaded_weights:
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as f:
        f.write(uploaded_weights.getbuffer())
    weights_path = Path(f.name)
    st.sidebar.success("Model weights uploaded and saved.")
else:
    weights_path = Path("Prithvi-WxC/examples/weights/prithvi.wxc.2300m.v1.pt")
    if not weights_path.exists():
        st.sidebar.error("Default model weights not found. Please upload model weights.")
        st.stop()
221
+
222
# Everything below runs only on the rerun triggered by the sidebar button.
if st.sidebar.button("Run Inference"):

    # --- Device selection ---
    torch.jit.enable_onednn_fusion(True)
    if torch.cuda.is_available():
        device = torch.device("cuda")
        st.write(f"Using device: {torch.cuda.get_device_name()}")
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True
    else:
        device = torch.device("cpu")
        st.write("Using device: CPU")

    # --- Fixed seeds for Python, torch and NumPy ---
    random.seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)
    torch.manual_seed(42)
    np.random.seed(42)

    # --- Variables and parameters ---
    surface_vars = [
        "EFLUX", "GWETROOT", "HFLUX", "LAI", "LWGAB", "LWGEM", "LWTUP",
        "PS", "QV2M", "SLP", "SWGNT", "SWTNT", "T2M", "TQI", "TQL", "TQV",
        "TS", "U10M", "V10M", "Z0M",
    ]
    static_surface_vars = ["FRACI", "FRLAND", "FROCEAN", "PHIS"]
    vertical_vars = ["CLOUD", "H", "OMEGA", "PL", "QI", "QL", "QV", "T", "U", "V"]
    levels = [
        34.0, 39.0, 41.0, 43.0, 44.0, 45.0, 48.0,
        51.0, 53.0, 56.0, 63.0, 68.0, 71.0, 72.0,
    ]
    padding = {"level": [0, 0], "lat": [0, -1], "lon": [0, 0]}

    residual = "climate"
    masking_mode = "local"
    decoder_shifting = True
    masking_ratio = 0.99

    positional_encoding = "fourier"

    # --- Dataset ---
    # Uploaded data directories take precedence; the bundled example data is
    # only a fallback. (Previously the uploads were saved but never read.)
    example_data_dir = Path("Prithvi-WxC/examples/merra-2")
    try:
        with st.spinner("Initializing dataset..."):
            # Short-circuit keeps the uploaded_clim_* names unevaluated when
            # the default climatology files exist.
            if not clim_files_exist and not (uploaded_clim_surface and uploaded_clim_vertical):
                st.error("Climatology files are missing. Please upload both surface and vertical climatology files.")
                st.stop()

            dataset = Merra2Dataset(
                time_range=time_range,
                lead_times=lead_times,
                input_times=input_times,
                data_path_surface=surf_dir if surf_dir else example_data_dir,
                data_path_vertical=vert_dir if vert_dir else example_data_dir,
                climatology_path_surface=Path("Prithvi-WxC/examples/climatology"),
                climatology_path_vertical=Path("Prithvi-WxC/examples/climatology"),
                surface_vars=surface_vars,
                static_surface_vars=static_surface_vars,
                vertical_vars=vertical_vars,
                levels=levels,
                positional_encoding=positional_encoding,
            )
            # Explicit check instead of `assert`, which is stripped under -O.
            if not len(dataset) > 0:
                raise ValueError("There doesn't seem to be any valid data.")
            st.success("Dataset initialized successfully.")
    except Exception:
        st.error("Error initializing dataset:")
        st.error(traceback.format_exc())
        st.stop()

    # --- Scalers ---
    try:
        with st.spinner("Loading scalers..."):
            # Input scalers are the climatology files; the output scalers are
            # looked up next to them under fixed file names.
            surf_in_scal_path = clim_surf_path
            vert_in_scal_path = clim_vert_path
            surf_out_scal_path = Path(clim_surf_path.parent) / "anomaly_variance_surface.nc"
            vert_out_scal_path = Path(clim_vert_path.parent) / "anomaly_variance_vertical.nc"

            if not surf_out_scal_path.exists() or not vert_out_scal_path.exists():
                st.error("Anomaly variance scaler files are missing.")
                st.stop()

            in_mu, in_sig = input_scalers(
                surface_vars,
                vertical_vars,
                levels,
                surf_in_scal_path,
                vert_in_scal_path,
            )

            output_sig = output_scalers(
                surface_vars,
                vertical_vars,
                levels,
                surf_out_scal_path,
                vert_out_scal_path,
            )

            static_mu, static_sig = static_input_scalers(
                surf_in_scal_path,
                static_surface_vars,
            )
            st.success("Scalers loaded successfully.")
    except Exception:
        st.error("Error loading scalers:")
        st.error(traceback.format_exc())
        st.stop()

    # --- Configuration ---
    try:
        with st.spinner("Loading configuration..."):
            with open(config_path, "r") as f:
                config = yaml.safe_load(f)
            # Every key below is read when the model is constructed.
            required_params = [
                "in_channels", "input_size_time", "in_channels_static",
                "input_scalers_epsilon", "static_input_scalers_epsilon",
                "n_lats_px", "n_lons_px", "patch_size_px",
                "mask_unit_size_px", "embed_dim", "n_blocks_encoder",
                "n_blocks_decoder", "mlp_multiplier", "n_heads",
                "dropout", "drop_path", "parameter_dropout"
            ]
            missing_params = [param for param in required_params if param not in config.get("params", {})]
            if missing_params:
                st.error(f"Missing configuration parameters: {missing_params}")
                st.stop()
            st.success("Configuration loaded successfully.")
    except Exception:
        st.error("Error loading configuration:")
        st.error(traceback.format_exc())
        st.stop()

    # --- Model construction ---
    try:
        with st.spinner("Initializing model..."):
            params = config["params"]
            model = PrithviWxC(
                in_channels=params["in_channels"],
                input_size_time=params["input_size_time"],
                in_channels_static=params["in_channels_static"],
                input_scalers_mu=in_mu,
                input_scalers_sigma=in_sig,
                input_scalers_epsilon=params["input_scalers_epsilon"],
                static_input_scalers_mu=static_mu,
                static_input_scalers_sigma=static_sig,
                static_input_scalers_epsilon=params["static_input_scalers_epsilon"],
                output_scalers=output_sig**0.5,
                n_lats_px=params["n_lats_px"],
                n_lons_px=params["n_lons_px"],
                patch_size_px=params["patch_size_px"],
                mask_unit_size_px=params["mask_unit_size_px"],
                mask_ratio_inputs=masking_ratio,
                embed_dim=params["embed_dim"],
                n_blocks_encoder=params["n_blocks_encoder"],
                n_blocks_decoder=params["n_blocks_decoder"],
                mlp_multiplier=params["mlp_multiplier"],
                n_heads=params["n_heads"],
                dropout=params["dropout"],
                drop_path=params["drop_path"],
                parameter_dropout=params["parameter_dropout"],
                residual=residual,
                masking_mode=masking_mode,
                decoder_shifting=decoder_shifting,
                positional_encoding=positional_encoding,
                checkpoint_encoder=[],
                checkpoint_decoder=[],
            )
            st.success("Model initialized successfully.")
    except Exception:
        st.error("Error initializing model:")
        st.error(traceback.format_exc())
        st.stop()

    # --- Weights ---
    try:
        with st.spinner("Loading model weights..."):
            # SECURITY NOTE(review): weights_path may point at a user-uploaded
            # file and torch.load unpickles it. Consider weights_only=True
            # (if the installed torch supports it) before exposing this app
            # to untrusted users.
            state_dict = torch.load(weights_path, map_location=device)
            if "model_state" in state_dict:
                state_dict = state_dict["model_state"]
            model.load_state_dict(state_dict, strict=True)
            model.to(device)
            st.success("Model weights loaded successfully.")
    except Exception:
        st.error("Error loading model weights:")
        st.error(traceback.format_exc())
        st.stop()

    # --- Batch preparation: first dataset sample, moved to the device ---
    try:
        with st.spinner("Preparing data batch..."):
            data = next(iter(dataset))
            batch = preproc([data], padding)

            for k, v in batch.items():
                if isinstance(v, torch.Tensor):
                    batch[k] = v.to(device)
            st.success("Data batch prepared successfully.")
    except Exception:
        st.error("Error preparing data batch:")
        st.error(traceback.format_exc())
        st.stop()

    # --- Inference ---
    try:
        with st.spinner("Running model inference..."):
            with torch.no_grad():
                model.eval()
                out = model(batch)
            st.success("Model inference completed successfully.")
    except Exception:
        st.error("Error during model inference:")
        st.error(traceback.format_exc())
        st.stop()

    # --- Output ---
    st.header("Inference Results")
    st.write(out)  # Adjust based on the structure of 'out'

    # Optionally, provide download links or visualizations
    # For example, if 'out' contains tensors or dataframes:
    # st.write("Output Tensor:", out["some_key"].cpu().numpy())

else:
    st.info("Please upload the necessary files and click 'Run Inference' to start.")
app2.py ADDED
@@ -0,0 +1,959 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import random
4
+ import numpy as np
5
+ import yaml
6
+ from pathlib import Path
7
+ import tempfile
8
+ import traceback
9
+ import matplotlib.pyplot as plt
10
+ import plotly.graph_objects as go
11
+ from Prithvi import * # Ensure this import includes your model and dataset classes
12
+ import xarray as xr
13
+ from aurora import Batch, Metadata
14
+ from aurora import Aurora, rollout
15
+ import logging
16
+ import matplotlib.pyplot as plt
17
+ import numpy as np
18
+ import cartopy.crs as ccrs
19
+ logging.basicConfig(level=logging.INFO)
20
+ logger = logging.getLogger(__name__)
21
+
22
+ # Function to save uploaded files to temporary files and store paths in session_state
23
def save_uploaded_files(uploaded_files):
    """Copy uploaded files to temporary files and record their paths.

    The resulting paths are stored as a list of strings in
    ``st.session_state.temp_file_paths``. The list is rebuilt on every call,
    so repeated runs do not accumulate paths from earlier uploads (which
    would hand duplicate files to the dataset loader).

    Args:
        uploaded_files: Iterable of uploaded-file objects exposing ``name``
            and ``getvalue()``. ``None`` is treated as an empty upload.
    """
    saved_paths = []
    for uploaded_file in uploaded_files or []:
        # Path is already imported by this module; `os` is not.
        suffix = Path(uploaded_file.name).suffix
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            # getvalue() returns the whole content regardless of the current
            # stream position, unlike read().
            temp_file.write(uploaded_file.getvalue())
        saved_paths.append(temp_file.name)
    st.session_state.temp_file_paths = saved_paths
32
+ # Cached function to load dataset
33
@st.cache_resource
def load_dataset(file_paths):
    """Open several files as one dataset and load it into memory.

    The files are combined by coordinates with ``xr.open_mfdataset`` and the
    result is loaded eagerly. On any exception the traceback is shown in the
    app and ``None`` is returned instead.
    """
    try:
        return xr.open_mfdataset(file_paths, combine='by_coords').load()
    except Exception:
        st.error("Error loading dataset:")
        st.error(traceback.format_exc())
    return None
42
+
43
# Page-level Streamlit settings.
st.set_page_config(
    page_title="Weather Data Processor",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Header row: wide column for the title, narrow column for the model selector.
header_col1, header_col2 = st.columns([4, 1])  # Adjust the ratio as needed

with header_col1:
    st.title("🌦️ Weather & Climate Data Processor and Forecaster")

with header_col2:
    st.markdown("### Select a Model")
    # Streamlit discourages empty widget labels; give the widget a real label
    # and hide it, since the markdown heading above already serves as the
    # visible caption.
    selected_model = st.selectbox(
        "Select a Model",
        options=["Aurora", "Climax", "Prithvi", "LSTM"],
        index=0,
        key="model_selector",
        help="Select the model you want to use for processing the data.",
        label_visibility="collapsed",
    )

st.write("---")  # Horizontal separator

# --- Layout: Two Columns ---
left_col, right_col = st.columns([1, 2])  # Adjust column ratios as needed
72
+
73
+ with left_col:
74
+ st.header("🔧 Configuration")
75
+
76
+ # --- Dynamic Configuration Based on Selected Model ---
77
+ def get_model_configuration(model_name):
78
+ if model_name == "Prithvi":
79
+ st.subheader("Prithvi Model Configuration")
80
+
81
+ # Prithvi-specific configuration inputs
82
+ param1 = st.number_input("Prithvi Parameter 1", value=10, step=1)
83
+ param2 = st.text_input("Prithvi Parameter 2", value="default_prithvi")
84
+ # Add other Prithvi-specific parameters here
85
+
86
+ config = {
87
+ "param1": param1,
88
+ "param2": param2,
89
+ # Include other parameters as needed
90
+ }
91
+
92
+ # --- Prithvi-Specific File Uploads ---
93
+ st.markdown("### Upload Data Files for Prithvi Model")
94
+
95
+ # File uploader for surface data
96
+ uploaded_surface_files = st.file_uploader(
97
+ "Upload Surface Data Files",
98
+ type=["nc", "netcdf"],
99
+ accept_multiple_files=True,
100
+ key="surface_uploader",
101
+ )
102
+
103
+ # File uploader for vertical data
104
+ uploaded_vertical_files = st.file_uploader(
105
+ "Upload Vertical Data Files",
106
+ type=["nc", "netcdf"],
107
+ accept_multiple_files=True,
108
+ key="vertical_uploader",
109
+ )
110
+
111
+ # Handle Climatology Files
112
+ st.markdown("### Upload Climatology Files (If Missing)")
113
+
114
+ # Climatology files paths
115
+ default_clim_dir = Path("Prithvi-WxC/examples/climatology")
116
+ surf_in_scal_path = default_clim_dir / "musigma_surface.nc"
117
+ vert_in_scal_path = default_clim_dir / "musigma_vertical.nc"
118
+ surf_out_scal_path = default_clim_dir / "anomaly_variance_surface.nc"
119
+ vert_out_scal_path = default_clim_dir / "anomaly_variance_vertical.nc"
120
+
121
+ # Check if climatology files exist
122
+ clim_files_exist = all(
123
+ [
124
+ surf_in_scal_path.exists(),
125
+ vert_in_scal_path.exists(),
126
+ surf_out_scal_path.exists(),
127
+ vert_out_scal_path.exists(),
128
+ ]
129
+ )
130
+
131
+ if not clim_files_exist:
132
+ st.warning("Climatology files are missing.")
133
+ uploaded_clim_surface = st.file_uploader(
134
+ "Upload Climatology Surface File",
135
+ type=["nc", "netcdf"],
136
+ key="clim_surface_uploader",
137
+ )
138
+ uploaded_clim_vertical = st.file_uploader(
139
+ "Upload Climatology Vertical File",
140
+ type=["nc", "netcdf"],
141
+ key="clim_vertical_uploader",
142
+ )
143
+
144
+ # Process uploaded climatology files
145
+ if uploaded_clim_surface and uploaded_clim_vertical:
146
+ clim_temp_dir = tempfile.mkdtemp()
147
+ clim_surf_path = Path(clim_temp_dir) / uploaded_clim_surface.name
148
+ with open(clim_surf_path, "wb") as f:
149
+ f.write(uploaded_clim_surface.getbuffer())
150
+ clim_vert_path = Path(clim_temp_dir) / uploaded_clim_vertical.name
151
+ with open(clim_vert_path, "wb") as f:
152
+ f.write(uploaded_clim_vertical.getbuffer())
153
+ st.success("Climatology files uploaded and saved.")
154
+ else:
155
+ st.warning("Please upload both climatology surface and vertical files.")
156
+ else:
157
+ clim_surf_path = surf_in_scal_path
158
+ clim_vert_path = vert_in_scal_path
159
+
160
+ # Optional: Upload config.yaml
161
+ uploaded_config = st.file_uploader(
162
+ "Upload config.yaml",
163
+ type=["yaml", "yml"],
164
+ key="config_uploader",
165
+ )
166
+
167
+ if uploaded_config:
168
+ temp_config = tempfile.mktemp(suffix=".yaml")
169
+ with open(temp_config, "wb") as f:
170
+ f.write(uploaded_config.getbuffer())
171
+ config_path = Path(temp_config)
172
+ st.success("Config.yaml uploaded and saved.")
173
+ else:
174
+ # Use default config.yaml path
175
+ config_path = Path("Prithvi-WxC/examples/config.yaml")
176
+ if not config_path.exists():
177
+ st.error("Default config.yaml not found. Please upload a config file.")
178
+ st.stop()
179
+
180
+ # Optional: Upload model weights
181
+ uploaded_weights = st.file_uploader(
182
+ "Upload Model Weights (.pt)",
183
+ type=["pt"],
184
+ key="weights_uploader",
185
+ )
186
+
187
+ if uploaded_weights:
188
+ temp_weights = tempfile.mktemp(suffix=".pt")
189
+ with open(temp_weights, "wb") as f:
190
+ f.write(uploaded_weights.getbuffer())
191
+ weights_path = Path(temp_weights)
192
+ st.success("Model weights uploaded and saved.")
193
+ else:
194
+ # Use default weights path
195
+ weights_path = Path("Prithvi-WxC/examples/weights/prithvi.wxc.2300m.v1.pt")
196
+ if not weights_path.exists():
197
+ st.error("Default model weights not found. Please upload model weights.")
198
+ st.stop()
199
+
200
+ return config, uploaded_surface_files, uploaded_vertical_files, clim_surf_path, clim_vert_path, config_path, weights_path
201
+
202
+ else:
203
+ # For other models, provide a simple file uploader
204
+ st.subheader(f"{model_name} Model Data Upload")
205
+ st.markdown("### Drag and Drop Your Data Files Here")
206
+ uploaded_files = st.file_uploader(
207
+ f"Upload Data Files for {model_name}",
208
+ accept_multiple_files=True,
209
+ key=f"{model_name.lower()}_uploader",
210
+ type=["nc", "netcdf", "nc4"],
211
+ )
212
+ return uploaded_files
213
+
214
+ # Retrieve model-specific configuration and files
215
+ if selected_model == "Prithvi":
216
+ config, uploaded_surface_files, uploaded_vertical_files, clim_surf_path, clim_vert_path, config_path, weights_path = get_model_configuration(selected_model)
217
+ else:
218
+ uploaded_files = get_model_configuration(selected_model)
219
+
220
+ st.write("---") # Horizontal separator
221
+
222
+ # --- Run Inference Button ---
223
+ if st.button("🚀 Run Inference"):
224
+ with right_col:
225
+ st.header("📈 Inference Progress & Visualization")
226
+
227
+ # Initialize device
228
+ try:
229
+ torch.jit.enable_onednn_fusion(True)
230
+ if torch.cuda.is_available():
231
+ device = torch.device("cuda")
232
+ st.write(f"Using device: **{torch.cuda.get_device_name()}**")
233
+ torch.backends.cudnn.benchmark = True
234
+ torch.backends.cudnn.deterministic = True
235
+ else:
236
+ device = torch.device("cpu")
237
+ st.write("Using device: **CPU**")
238
+ except Exception as e:
239
+ st.error("Error initializing device:")
240
+ st.error(traceback.format_exc())
241
+ st.stop()
242
+
243
+ # Set random seeds
244
+ try:
245
+ random.seed(42)
246
+ if torch.cuda.is_available():
247
+ torch.cuda.manual_seed(42)
248
+ torch.manual_seed(42)
249
+ np.random.seed(42)
250
+ except Exception as e:
251
+ st.error("Error setting random seeds:")
252
+ st.error(traceback.format_exc())
253
+ st.stop()
254
+
255
+ # # Define variables and parameters based on dataset type
256
+ # if dataset_type == "MERRA2":
257
+ # surface_vars = [
258
+ # "EFLUX",
259
+ # "GWETROOT",
260
+ # "HFLUX",
261
+ # "LAI",
262
+ # "LWGAB",
263
+ # "LWGEM",
264
+ # "LWTUP",
265
+ # "PS",
266
+ # "QV2M",
267
+ # "SLP",
268
+ # "SWGNT",
269
+ # "SWTNT",
270
+ # "T2M",
271
+ # "TQI",
272
+ # "TQL",
273
+ # "TQV",
274
+ # "TS",
275
+ # "U10M",
276
+ # "V10M",
277
+ # "Z0M",
278
+ # ]
279
+ # static_surface_vars = ["FRACI", "FRLAND", "FROCEAN", "PHIS"]
280
+ # vertical_vars = ["CLOUD", "H", "OMEGA", "PL", "QI", "QL", "QV", "T", "U", "V"]
281
+ # levels = [
282
+ # 34.0,
283
+ # 39.0,
284
+ # 41.0,
285
+ # 43.0,
286
+ # 44.0,
287
+ # 45.0,
288
+ # 48.0,
289
+ # 51.0,
290
+ # 53.0,
291
+ # 56.0,
292
+ # 63.0,
293
+ # 68.0,
294
+ # 71.0,
295
+ # 72.0,
296
+ # ]
297
+ # elif dataset_type == "GEOS5":
298
+ # # Define GEOS5 specific variables
299
+ # surface_vars = [
300
+ # "GEOS5_EFLUX",
301
+ # "GEOS5_GWETROOT",
302
+ # "GEOS5_HFLUX",
303
+ # "GEOS5_LAI",
304
+ # "GEOS5_LWGAB",
305
+ # "GEOS5_LWGEM",
306
+ # "GEOS5_LWTUP",
307
+ # "GEOS5_PS",
308
+ # "GEOS5_QV2M",
309
+ # "GEOS5_SLP",
310
+ # "GEOS5_SWGNT",
311
+ # "GEOS5_SWTNT",
312
+ # "GEOS5_T2M",
313
+ # "GEOS5_TQI",
314
+ # "GEOS5_TQL",
315
+ # "GEOS5_TQV",
316
+ # "GEOS5_TS",
317
+ # "GEOS5_U10M",
318
+ # "GEOS5_V10M",
319
+ # "GEOS5_Z0M",
320
+ # ]
321
+ # static_surface_vars = ["GEOS5_FRACI", "GEOS5_FRLAND", "GEOS5_FROCEAN", "GEOS5_PHIS"]
322
+ # vertical_vars = ["GEOS5_CLOUD", "GEOS5_H", "GEOS5_OMEGA", "GEOS5_PL", "GEOS5_QI", "GEOS5_QL", "GEOS5_QV", "GEOS5_T", "GEOS5_U", "GEOS5_V"]
323
+ # levels = [
324
+ # # Define levels specific to GEOS5 if different
325
+ # 10.0,
326
+ # 20.0,
327
+ # 30.0,
328
+ # 40.0,
329
+ # 50.0,
330
+ # 60.0,
331
+ # 70.0,
332
+ # 80.0,
333
+ # ]
334
+ # else:
335
+ # st.error("Unsupported dataset type selected.")
336
+ # st.stop()
337
+
338
+ padding = {"level": [0, 0], "lat": [0, -1], "lon": [0, 0]}
339
+
340
+ residual = "climate"
341
+ masking_mode = "local"
342
+ decoder_shifting = True
343
+ masking_ratio = 0.99
344
+
345
+ positional_encoding = "fourier"
346
+
347
+ # --- Initialize Dataset ---
348
+ try:
349
+ with st.spinner("Initializing dataset..."):
350
+ if selected_model == "Prithvi":
351
+ pass
352
+ # # Validate climatology files
353
+ # if not clim_files_exist and not (uploaded_clim_surface and uploaded_clim_vertical):
354
+ # st.error("Climatology files are missing. Please upload both climatology surface and vertical files.")
355
+ # st.stop()
356
+
357
+ # dataset = Merra2Dataset(
358
+ # time_range=time_range,
359
+ # lead_times=lead_times,
360
+ # input_times=input_times,
361
+ # data_path_surface=surf_dir,
362
+ # data_path_vertical=vert_dir,
363
+ # climatology_path_surface=clim_surf_path,
364
+ # climatology_path_vertical=clim_vert_path,
365
+ # surface_vars=surface_vars,
366
+ # static_surface_vars=static_surface_vars,
367
+ # vertical_vars=vertical_vars,
368
+ # levels=levels,
369
+ # positional_encoding=positional_encoding,
370
+ # )
371
+ # assert len(dataset) > 0, "There doesn't seem to be any valid data."
372
+ elif selected_model == "Aurora":
373
+ # TODO just temporary, replace this
374
+ if uploaded_files:
375
+ temp_file_paths = [] # List to store paths of temporary files
376
+ try:
377
+ # Save each uploaded file to a temporary file
378
+ save_uploaded_files(uploaded_files)
379
+ ds = load_dataset(st.session_state.temp_file_paths)
380
+
381
+ # Now, use xarray to open the multiple files
382
+ if ds:
383
+ st.success("Files successfully loaded!")
384
+ st.session_state.ds_subset = ds
385
+
386
+
387
+ # print(ds)
388
+ ds = ds.fillna(ds.mean())
389
+
390
+ desired_levels = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]
391
+
392
+ # Ensure that the 'lev' dimension exists
393
+ if 'lev' not in ds.dims:
394
+ raise ValueError("The dataset does not contain a 'lev' (pressure level) dimension.")
395
+
396
+ # Define the _prepare function
397
+ def _prepare(x: np.ndarray, i: int) -> torch.Tensor:
398
+ # Select previous and current time steps
399
+ selected = x[[i - 6, i]]
400
+
401
+ # Add a batch dimension
402
+ selected = selected[None]
403
+
404
+ # Ensure data is contiguous
405
+ selected = selected.copy()
406
+
407
+ # Convert to PyTorch tensor
408
+ return torch.from_numpy(selected)
409
+
410
+ # Adjust latitudes and longitudes
411
+ lat = ds.lat.values * -1
412
+ lon = ds.lon.values + 180
413
+
414
+ # Subset the dataset to only include the desired pressure levels
415
+ ds_subset = ds.sel(lev=desired_levels, method="nearest")
416
+
417
+ # Verify that all desired levels are present
418
+ present_levels = ds_subset.lev.values
419
+ missing_levels = set(desired_levels) - set(present_levels)
420
+ if missing_levels:
421
+ raise ValueError(f"The following desired pressure levels are missing in the dataset: {missing_levels}")
422
+
423
+ # Extract pressure levels after subsetting
424
+ lev = ds_subset.lev.values # Pressure levels in hPa
425
+
426
+ # Prepare surface variables at 1000 hPa
427
+ try:
428
+ lev_index_1000 = np.where(lev == 1000)[0][0]
429
+ except IndexError:
430
+ raise ValueError("1000 hPa level not found in the 'lev' dimension after subsetting.")
431
+
432
+ T_surface = ds_subset.T.isel(lev=lev_index_1000).compute()
433
+ U_surface = ds_subset.U.isel(lev=lev_index_1000).compute()
434
+ V_surface = ds_subset.V.isel(lev=lev_index_1000).compute()
435
+ SLP = ds_subset.SLP.compute()
436
+
437
+ # Reorder static variables (selecting the first time index to remove the time dimension)
438
+ PHIS = ds_subset.PHIS.isel(time=0).compute()
439
+
440
+ # Prepare atmospheric variables for the desired pressure levels excluding 1000 hPa
441
+ atmos_levels = [int(level) for level in lev if level != 1000]
442
+
443
+ T_atm = (ds_subset.T.sel(lev=atmos_levels)).compute()
444
+ U_atm = (ds_subset.U.sel(lev=atmos_levels)).compute()
445
+ V_atm = (ds_subset.V.sel(lev=atmos_levels)).compute()
446
+
447
+ # Select time index
448
+ num_times = ds_subset.time.size
449
+ i = 6 # Adjust as needed (1 <= i < num_times)
450
+
451
+ if i >= num_times or i < 1:
452
+ raise IndexError("Time index i is out of bounds.")
453
+
454
+ time_values = ds_subset.time.values
455
+ current_time = np.datetime64(time_values[i]).astype('datetime64[s]').astype(datetime)
456
+
457
+ # Prepare surface variables
458
+ surf_vars = {
459
+ "2t": _prepare(T_surface.values, i), # Two-meter temperature
460
+ "10u": _prepare(U_surface.values, i), # Ten-meter eastward wind
461
+ "10v": _prepare(V_surface.values, i), # Ten-meter northward wind
462
+ "msl": _prepare(SLP.values, i), # Mean sea-level pressure
463
+ }
464
+
465
+ # Prepare static variables (now 2D tensors)
466
+ static_vars = {
467
+ "z": torch.from_numpy(PHIS.values.copy()), # Geopotential (h, w)
468
+ # Add 'lsm' and 'slt' if available and needed
469
+ }
470
+
471
+ # Prepare atmospheric variables
472
+ atmos_vars = {
473
+ "t": _prepare(T_atm.values, i), # Temperature at desired levels
474
+ "u": _prepare(U_atm.values, i), # Eastward wind at desired levels
475
+ "v": _prepare(V_atm.values, i), # Southward wind at desired levels
476
+ }
477
+
478
+ # Define metadata
479
+ metadata = Metadata(
480
+ lat=torch.from_numpy(lat.copy()),
481
+ lon=torch.from_numpy(lon.copy()),
482
+ time=(current_time,),
483
+ atmos_levels=tuple(atmos_levels), # Only the desired atmospheric levels
484
+ )
485
+
486
+ # Create the Batch object
487
+ batch = Batch(
488
+ surf_vars=surf_vars,
489
+ static_vars=static_vars,
490
+ atmos_vars=atmos_vars,
491
+ metadata=metadata
492
+ ) # Display the dataset or perform further processing
493
+
494
+ st.session_state['batch'] = batch
495
+
496
+ except Exception as e:
497
+ st.error(f"An error occurred: {e}")
498
+
499
+ # finally:
500
+ # # Clean up: Remove temporary files
501
+ # for path in temp_file_paths:
502
+ # try:
503
+ # os.remove(path)
504
+ # except Exception as e:
505
+ # st.warning(f"Could not delete temp file {path}: {e}")
506
+ else:
507
+ # For other models, implement their specific dataset initialization
508
+ # Placeholder: Replace with actual dataset initialization for other models
509
+ dataset = None # Replace with actual dataset
510
+ st.warning("Dataset initialization for this model is not implemented yet.")
511
+ st.stop()
512
+ st.success("Dataset initialized successfully.")
513
+ except Exception as e:
514
+ st.error("Error initializing dataset:")
515
+ st.error(traceback.format_exc())
516
+ st.stop()
517
+
518
+ # --- Load Scalers ---
519
+ try:
520
+ with st.spinner("Loading scalers..."):
521
+ if selected_model == "Prithvi":
522
+ pass
523
+ # # Assuming the scaler paths are the same as climatology paths
524
+ # surf_in_scal_path = clim_surf_path
525
+ # vert_in_scal_path = clim_vert_path
526
+ # surf_out_scal_path = Path(clim_surf_path.parent) / "anomaly_variance_surface.nc"
527
+ # vert_out_scal_path = Path(clim_vert_path.parent) / "anomaly_variance_vertical.nc"
528
+
529
+ # # Check if output scaler files exist
530
+ # if not surf_out_scal_path.exists() or not vert_out_scal_path.exists():
531
+ # st.error("Anomaly variance scaler files are missing.")
532
+ # st.stop()
533
+
534
+ # in_mu, in_sig = input_scalers(
535
+ # surface_vars,
536
+ # vertical_vars,
537
+ # levels,
538
+ # surf_in_scal_path,
539
+ # vert_in_scal_path,
540
+ # )
541
+
542
+ # output_sig = output_scalers(
543
+ # surface_vars,
544
+ # vertical_vars,
545
+ # levels,
546
+ # surf_out_scal_path,
547
+ # vert_out_scal_path,
548
+ # )
549
+
550
+ # static_mu, static_sig = static_input_scalers(
551
+ # surf_in_scal_path,
552
+ # static_surface_vars,
553
+ # )
554
+ else:
555
+ # Load scalers for other models if applicable
556
+ # Placeholder: Replace with actual scaler loading for other models
557
+ in_mu, in_sig = None, None
558
+ output_sig = None
559
+ static_mu, static_sig = None, None
560
+ st.success("Scalers loaded successfully.")
561
+ except Exception as e:
562
+ st.error("Error loading scalers:")
563
+ st.error(traceback.format_exc())
564
+ st.stop()
565
+
566
+ # --- Load Configuration ---
567
+ try:
568
+ with st.spinner("Loading configuration..."):
569
+ if selected_model == "Prithvi":
570
+ with open(config_path, "r") as f:
571
+ config = yaml.safe_load(f)
572
+ # Validate config
573
+ required_params = [
574
+ "in_channels", "input_size_time", "in_channels_static",
575
+ "input_scalers_epsilon", "static_input_scalers_epsilon",
576
+ "n_lats_px", "n_lons_px", "patch_size_px",
577
+ "mask_unit_size_px", "embed_dim", "n_blocks_encoder",
578
+ "n_blocks_decoder", "mlp_multiplier", "n_heads",
579
+ "dropout", "drop_path", "parameter_dropout"
580
+ ]
581
+ missing_params = [param for param in required_params if param not in config.get("params", {})]
582
+ if missing_params:
583
+ st.error(f"Missing configuration parameters: {missing_params}")
584
+ st.stop()
585
+ else:
586
+ # Load configuration for other models if applicable
587
+ # Placeholder: Replace with actual configuration loading for other models
588
+ config = {}
589
+ st.success("Configuration loaded successfully.")
590
+ except Exception as e:
591
+ st.error("Error loading configuration:")
592
+ st.error(traceback.format_exc())
593
+ st.stop()
594
+
595
+ # --- Initialize the Model ---
596
+ try:
597
+ with st.spinner("Initializing model..."):
598
+ if selected_model == "Prithvi":
599
+ model = PrithviWxC(
600
+ in_channels=config["params"]["in_channels"],
601
+ input_size_time=config["params"]["input_size_time"],
602
+ in_channels_static=config["params"]["in_channels_static"],
603
+ input_scalers_mu=in_mu,
604
+ input_scalers_sigma=in_sig,
605
+ input_scalers_epsilon=config["params"]["input_scalers_epsilon"],
606
+ static_input_scalers_mu=static_mu,
607
+ static_input_scalers_sigma=static_sig,
608
+ static_input_scalers_epsilon=config["params"]["static_input_scalers_epsilon"],
609
+ output_scalers=output_sig**0.5,
610
+ n_lats_px=config["params"]["n_lats_px"],
611
+ n_lons_px=config["params"]["n_lons_px"],
612
+ patch_size_px=config["params"]["patch_size_px"],
613
+ mask_unit_size_px=config["params"]["mask_unit_size_px"],
614
+ mask_ratio_inputs=masking_ratio,
615
+ embed_dim=config["params"]["embed_dim"],
616
+ n_blocks_encoder=config["params"]["n_blocks_encoder"],
617
+ n_blocks_decoder=config["params"]["n_blocks_decoder"],
618
+ mlp_multiplier=config["params"]["mlp_multiplier"],
619
+ n_heads=config["params"]["n_heads"],
620
+ dropout=config["params"]["dropout"],
621
+ drop_path=config["params"]["drop_path"],
622
+ parameter_dropout=config["params"]["parameter_dropout"],
623
+ residual=residual,
624
+ masking_mode=masking_mode,
625
+ decoder_shifting=decoder_shifting,
626
+ positional_encoding=positional_encoding,
627
+ checkpoint_encoder=[],
628
+ checkpoint_decoder=[],
629
+ )
630
+ elif selected_model == "Aurora":
631
+ pass
632
+
633
+ else:
634
+
635
+ # Initialize other models here
636
+ # Placeholder: Replace with actual model initialization for other models
637
+ model = None
638
+ st.warning("Model initialization for this model is not implemented yet.")
639
+ st.stop()
640
+ # model.to(device)
641
+ st.success("Model initialized successfully.")
642
+ except Exception as e:
643
+ st.error("Error initializing model:")
644
+ st.error(traceback.format_exc())
645
+ st.stop()
646
+
647
+ # --- Load Model Weights ---
648
+ try:
649
+ with st.spinner("Loading model weights..."):
650
+ if selected_model == "Prithvi":
651
+ state_dict = torch.load(weights_path, map_location=device)
652
+ if "model_state" in state_dict:
653
+ state_dict = state_dict["model_state"]
654
+ model.load_state_dict(state_dict, strict=True)
655
+ model.to(device)
656
+ else:
657
+ # Load weights for other models if applicable
658
+ # Placeholder: Replace with actual weight loading for other models
659
+ pass
660
+ st.success("Model weights loaded successfully.")
661
+ except Exception as e:
662
+ st.error("Error loading model weights:")
663
+ st.error(traceback.format_exc())
664
+ st.stop()
665
+
666
+ # --- Prepare Data Batch ---
667
+ try:
668
+ with st.spinner("Preparing data batch..."):
669
+ if selected_model == "Prithvi":
670
+ data = next(iter(dataset))
671
+ batch = preproc([data], padding)
672
+ for k, v in batch.items():
673
+ if isinstance(v, torch.Tensor):
674
+ batch[k] = v.to(device)
675
+ elif selected_model == "Aurora":
676
+ batch = batch.regrid(res=0.25)
677
+
678
+ else:
679
+ # Prepare data batch for other models
680
+ # Placeholder: Replace with actual data preparation for other models
681
+ batch = None
682
+ st.success("Data batch prepared successfully.")
683
+ except Exception as e:
684
+ st.error("Error preparing data batch:")
685
+ st.error(traceback.format_exc())
686
+ st.stop()
687
+
688
+ # --- Run Inference ---
689
+ try:
690
+ with st.spinner("Running model inference..."):
691
+ if selected_model == "Prithvi":
692
+ model.eval()
693
+ with torch.no_grad():
694
+ out = model(batch)
695
+ elif selected_model == "Aurora":
696
+
697
+ model = Aurora(use_lora=False)
698
+ # model = Aurora()
699
+ model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")
700
+ # model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")
701
+
702
+ model.eval()
703
+ # model = model.to("cuda") # Uncomment if using a GPU
704
+
705
+ with torch.inference_mode():
706
+ out = [pred.to("cpu") for pred in rollout(model, batch, steps=2)]
707
+
708
+ model = model.to("cpu")
709
+ st.session_state.model = model
710
+ else:
711
+ # Run inference for other models
712
+ # Placeholder: Replace with actual inference code for other models
713
+ out = torch.randn(1, 10, 180, 360) # Dummy tensor
714
+ st.success("Model inference completed successfully.")
715
+ st.session_state['out'] = out
716
+ except Exception as e:
717
+ st.error("Error during model inference:")
718
+ st.error(traceback.format_exc())
719
+ st.stop()
720
+
721
+ # --- Visualization Settings ---
722
+ st.markdown("## 📊 Visualization Settings")
723
+
724
+ if 'out' in st.session_state and 'batch' in st.session_state and selected_model == "Prithvi":
725
+ # Display the shape of the output tensor
726
+ out_tensor = st.session_state['out']
727
+ st.write(f"**Output tensor shape:** {out_tensor.shape}")
728
+
729
+ # Ensure the output tensor has at least 4 dimensions (batch, variables, lat, lon)
730
+ if out_tensor.ndim < 4:
731
+ st.error("The output tensor does not have the expected number of dimensions (batch, variables, lat, lon).")
732
+ st.stop()
733
+
734
+ # Get the number of variables
735
+ num_variables = out_tensor.shape[1]
736
+
737
+ # Define variable names (update with your actual variable names)
738
+ variable_names = [f"Variable_{i}" for i in range(num_variables)]
739
+
740
+ # Visualization settings
741
+ col1, col2 = st.columns(2)
742
+
743
+ with col1:
744
+ # Select variable to plot
745
+ selected_variable_name = st.selectbox(
746
+ "Select Variable to Plot",
747
+ options=variable_names,
748
+ index=0,
749
+ help="Choose the variable you want to visualize."
750
+ )
751
+
752
+ # Select plot type
753
+ plot_type = st.selectbox(
754
+ "Select Plot Type",
755
+ options=["Contour", "Heatmap"],
756
+ index=0,
757
+ help="Choose the type of plot to display."
758
+ )
759
+
760
+ with col2:
761
+ # Select color map
762
+ cmap = st.selectbox(
763
+ "Select Color Map",
764
+ options=plt.colormaps(),
765
+ index=plt.colormaps().index("viridis"),
766
+ help="Choose the color map for the plot."
767
+ )
768
+
769
+ # Set number of levels (for contour plot)
770
+ if plot_type == "Contour":
771
+ num_levels = st.slider(
772
+ "Number of Contour Levels",
773
+ min_value=5,
774
+ max_value=100,
775
+ value=20,
776
+ step=5,
777
+ help="Set the number of contour levels."
778
+ )
779
+ else:
780
+ num_levels = None
781
+
782
+ # Find the index based on the selected name
783
+ variable_index = variable_names.index(selected_variable_name)
784
+
785
+ # Extract the selected variable
786
+ selected_variable = out_tensor[0, variable_index].cpu().numpy()
787
+
788
+ # Generate latitude and longitude arrays
789
+ lat = np.linspace(-90, 90, selected_variable.shape[0])
790
+ lon = np.linspace(-180, 180, selected_variable.shape[1])
791
+ X, Y = np.meshgrid(lon, lat)
792
+
793
+ # Plot the selected variable
794
+ st.markdown(f"### Plot of {selected_variable_name}")
795
+
796
+ # Matplotlib figure
797
+ fig, ax = plt.subplots(figsize=(10, 6))
798
+
799
+ if plot_type == "Contour":
800
+ # Generate the contour plot
801
+ contour = ax.contourf(X, Y, selected_variable, levels=num_levels, cmap=cmap)
802
+ elif plot_type == "Heatmap":
803
+ # Generate the heatmap
804
+ contour = ax.imshow(selected_variable, extent=[-180, 180, -90, 90], cmap=cmap, origin='lower', aspect='auto')
805
+
806
+ # Add a color bar
807
+ cbar = plt.colorbar(contour, ax=ax)
808
+ cbar.set_label(f'{selected_variable_name}', fontsize=12)
809
+
810
+ # Set aspect ratio and labels
811
+ ax.set_xlabel("Longitude", fontsize=12)
812
+ ax.set_ylabel("Latitude", fontsize=12)
813
+ ax.set_title(f"{selected_variable_name}", fontsize=14)
814
+
815
+ # Display the plot in Streamlit
816
+ st.pyplot(fig)
817
+
818
+ # Optional: Provide interactive Plotly plot
819
+ st.markdown("#### Interactive Plot")
820
+ if plot_type == "Contour":
821
+ fig_plotly = go.Figure(data=go.Contour(
822
+ z=selected_variable,
823
+ x=lon,
824
+ y=lat,
825
+ colorscale=cmap,
826
+ contours=dict(
827
+ coloring='fill',
828
+ showlabels=True,
829
+ labelfont=dict(size=12, color='white'),
830
+ ncontours=num_levels
831
+ )
832
+ ))
833
+ elif plot_type == "Heatmap":
834
+ fig_plotly = go.Figure(data=go.Heatmap(
835
+ z=selected_variable,
836
+ x=lon,
837
+ y=lat,
838
+ colorscale=cmap
839
+ ))
840
+
841
+ fig_plotly.update_layout(
842
+ xaxis_title="Longitude",
843
+ yaxis_title="Latitude",
844
+ autosize=False,
845
+ width=800,
846
+ height=600,
847
+ )
848
+
849
+ st.plotly_chart(fig_plotly)
850
+
851
+ elif 'out' in st.session_state and selected_model == "Aurora" and st.session_state['out'] is not None:
852
+ preds = st.session_state['out']
853
+ ds_subset = st.session_state.get('ds_subset', None)
854
+ batch = st.session_state.get('batch', None)
855
+
856
+ # **Determine Available Levels**
857
+ # For example, let's assume levels range from 0 to max_level_index
858
+ # You need to replace 'max_level_index' with the actual maximum level index in your data
859
+ try:
860
+ # Assuming 'lev' dimension exists and is 1D
861
+ levels = preds[0].atmos_vars["t"].shape[2] # Adjust based on your data structure
862
+ level_indices = list(range(levels))
863
+ except Exception as e:
864
+ st.error("Error determining available levels:")
865
+ st.error(traceback.format_exc())
866
+ levels = None # Set to None if levels cannot be determined
867
+
868
+ if levels is not None:
869
+ # **Add a Slider for Level Selection**
870
+ selected_level = st.slider(
871
+ 'Select Level',
872
+ min_value=0,
873
+ max_value=levels - 1,
874
+ value=11, # Default level index
875
+ step=1,
876
+ help="Select the vertical level for plotting."
877
+ )
878
+
879
+ # Loop through predictions and ground truths
880
+ for idx in range(len(preds)):
881
+ pred = preds[idx]
882
+ pred_time = pred.metadata.time[0]
883
+
884
+ # Display prediction time
885
+ st.write(f"### Prediction Time: {pred_time}")
886
+
887
+ # **Extract Data at Selected Level**
888
+ try:
889
+ # Update indices with the selected_level
890
+ pred_data = pred.atmos_vars["t"][0][0][selected_level].numpy() - 273.15
891
+ truth_data = ds_subset.T.isel(lev=selected_level)[idx].values - 273.15
892
+
893
+ except Exception as e:
894
+ st.error("Error extracting data for plotting:")
895
+ st.error(traceback.format_exc())
896
+ continue
897
+
898
+ # Extract latitude and longitude
899
+ try:
900
+ lat = np.array(pred.metadata.lat) # Assuming 'lat' is 1D
901
+ lon = np.array(pred.metadata.lon) # Assuming 'lon' is 1D
902
+ except Exception as e:
903
+ st.error("Error extracting latitude and longitude:")
904
+ st.error(traceback.format_exc())
905
+ continue
906
+
907
+ # Create a meshgrid for plotting
908
+ lon_grid, lat_grid = np.meshgrid(lon, lat)
909
+
910
+ # Create a Matplotlib figure with Cartopy projection
911
+ fig, axes = plt.subplots(
912
+ 1, 3, figsize=(18, 6),
913
+ subplot_kw={'projection': ccrs.PlateCarree()}
914
+ )
915
+
916
+ # **Ground Truth Plot**
917
+ im1 = axes[0].imshow(
918
+ truth_data,
919
+ extent=[lon.min(), lon.max(), lat.min(), lat.max()],
920
+ origin='lower',
921
+ cmap='coolwarm',
922
+ transform=ccrs.PlateCarree()
923
+ )
924
+ axes[0].set_title(f"Ground Truth at Level {selected_level} - {pred_time}")
925
+ axes[0].set_xlabel('Longitude')
926
+ axes[0].set_ylabel('Latitude')
927
+ plt.colorbar(im1, ax=axes[0], orientation='horizontal', pad=0.05)
928
+
929
+ # **Prediction Plot**
930
+ im2 = axes[1].imshow(
931
+ pred_data,
932
+ extent=[lon.min(), lon.max(), lat.min(), lat.max()],
933
+ origin='lower',
934
+ cmap='coolwarm',
935
+ transform=ccrs.PlateCarree()
936
+ )
937
+ axes[1].set_title(f"Prediction at Level {selected_level} - {pred_time}")
938
+ axes[1].set_xlabel('Longitude')
939
+ axes[1].set_ylabel('Latitude')
940
+ plt.colorbar(im2, ax=axes[1], orientation='horizontal', pad=0.05)
941
+
942
+ plt.tight_layout()
943
+
944
+ # Display the plot in Streamlit
945
+ st.pyplot(fig)
946
+ else:
947
+ st.error("Could not determine the available levels in the data.")
948
+
949
+
950
+ else:
951
+ st.warning("No output available to display or visualization is not implemented for this model.")
952
+
953
+ # --- End of Inference Button ---
954
+ else:
955
+ with right_col:
956
+ st.header("🖥️ Visualization & Progress")
957
+ st.info("Awaiting inference to display results.")
958
+
959
+
aurora ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 8b11659b91d06f87c2d22e541dbcd0092baf2157
aurora_utils.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from aurora import Aurora, Batch, Metadata
4
+ import numpy as np
5
+ from datetime import datetime
6
+
7
def aurora_config_ui():
    """Render the Aurora upload section and return the uploader's result.

    Returns:
        Whatever ``st.file_uploader`` yields for a multi-file uploader
        restricted to the ``nc``, ``netcdf`` and ``nc4`` extensions.
    """
    st.subheader("Aurora Model Data Upload")
    st.markdown("### Drag and Drop Your Data Files Here")

    # Keep the widget options together so the call itself stays short.
    uploader_options = {
        "accept_multiple_files": True,
        "key": "aurora_uploader",
        "type": ["nc", "netcdf", "nc4"],
    }
    return st.file_uploader("Upload Data Files for Aurora", **uploader_options)
17
+
18
def prepare_aurora_batch(ds, time_index: int = 6, history_offset: int = 6):
    """Assemble an Aurora ``Batch`` from a dataset with pressure-level fields.

    The dataset is accessed through the attributes ``lat``, ``lon``, ``lev``,
    ``time``, ``T``, ``U``, ``V``, ``SLP`` and ``PHIS``. Every time-dependent
    field contributes two time steps: ``time_index - history_offset`` (the
    "previous" state) and ``time_index`` (the "current" state).

    Args:
        ds: Dataset exposing the variables above and a ``lev`` dimension
            (presumably an ``xarray.Dataset``; the code relies on
            ``.sel`` / ``.isel`` / ``.compute``).
        time_index: Position on the time axis used as the current step.
            Defaults to 6, the value that used to be hard-coded.
        history_offset: Number of positions between the previous and the
            current step. Defaults to 6, the value that used to be hard-coded.

    Returns:
        A ``Batch`` holding surface, static and atmospheric tensors plus
        the associated ``Metadata``.

    Raises:
        ValueError: If ``lev`` is not a dimension, if any desired pressure
            level is absent after nearest-neighbour selection, or if the
            1000 hPa level cannot be located.
        IndexError: If ``time_index`` or the derived previous index falls
            outside the time axis.
    """
    desired_levels = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]

    # Ensure that the 'lev' dimension exists
    if 'lev' not in ds.dims:
        raise ValueError("The dataset does not contain a 'lev' (pressure level) dimension.")

    previous_index = time_index - history_offset

    def _prepare(x: np.ndarray) -> torch.Tensor:
        """Stack the previous and current time steps and add a batch axis."""
        selected = x[[previous_index, time_index]]
        # [None] prepends the batch dimension; copy() makes the data contiguous
        # before handing it to torch.
        return torch.from_numpy(selected[None].copy())

    # NOTE(review): only the coordinate labels are modified here (sign flip on
    # latitude, +180 shift on longitude); the data arrays are neither reversed
    # nor rolled. Confirm this matches the layout of the source grid.
    lat = ds.lat.values * -1
    lon = ds.lon.values + 180

    # Subset the dataset to only include the desired pressure levels
    ds_subset = ds.sel(lev=desired_levels, method="nearest")

    # Nearest-neighbour selection never fails by itself, so verify that each
    # requested level was matched exactly.
    present_levels = ds_subset.lev.values
    missing_levels = set(desired_levels) - set(present_levels)
    if missing_levels:
        raise ValueError(f"The following desired pressure levels are missing in the dataset: {missing_levels}")

    # Pressure levels after subsetting
    lev = ds_subset.lev.values

    try:
        lev_index_1000 = np.where(lev == 1000)[0][0]
    except IndexError:
        raise ValueError("1000 hPa level not found in the 'lev' dimension after subsetting.")

    # Validate the time selection before any data is materialised. The previous
    # index must be checked as well: a negative value would silently wrap around
    # to the end of the time axis instead of failing.
    num_times = ds_subset.time.size
    if history_offset < 1 or previous_index < 0 or time_index >= num_times:
        raise IndexError("Time index i is out of bounds.")

    # The 1000 hPa fields stand in for the surface variables.
    T_surface = ds_subset.T.isel(lev=lev_index_1000).compute()
    U_surface = ds_subset.U.isel(lev=lev_index_1000).compute()
    V_surface = ds_subset.V.isel(lev=lev_index_1000).compute()
    SLP = ds_subset.SLP.compute()

    # Static field: take the first time index to drop the time dimension
    PHIS = ds_subset.PHIS.isel(time=0).compute()

    # Atmospheric variables use every desired level except 1000 hPa
    atmos_levels = [int(level) for level in lev if level != 1000]

    T_atm = (ds_subset.T.sel(lev=atmos_levels)).compute()
    U_atm = (ds_subset.U.sel(lev=atmos_levels)).compute()
    V_atm = (ds_subset.V.sel(lev=atmos_levels)).compute()

    time_values = ds_subset.time.values
    current_time = np.datetime64(time_values[time_index]).astype('datetime64[s]').astype(datetime)

    # Surface variables (filled from the 1000 hPa level, see above)
    surf_vars = {
        "2t": _prepare(T_surface.values),   # Two-meter temperature
        "10u": _prepare(U_surface.values),  # Ten-meter eastward wind
        "10v": _prepare(V_surface.values),  # Ten-meter northward wind
        "msl": _prepare(SLP.values),        # Mean sea-level pressure
    }

    # Static variables (2D tensors)
    static_vars = {
        "z": torch.from_numpy(PHIS.values.copy()),  # Geopotential (h, w)
        # Add 'lsm' and 'slt' if available and needed
    }

    # Atmospheric variables
    atmos_vars = {
        "t": _prepare(T_atm.values),  # Temperature at desired levels
        "u": _prepare(U_atm.values),  # Eastward wind at desired levels
        "v": _prepare(V_atm.values),  # Northward wind at desired levels
    }

    metadata = Metadata(
        lat=torch.from_numpy(lat.copy()),
        lon=torch.from_numpy(lon.copy()),
        time=(current_time,),
        atmos_levels=tuple(atmos_levels),  # Only the desired atmospheric levels
    )

    return Batch(
        surf_vars=surf_vars,
        static_vars=static_vars,
        atmos_vars=atmos_vars,
        metadata=metadata,
    )
123
+
124
def initialize_aurora_model(device):
    """Create an Aurora model, load its checkpoint and move it to ``device``.

    Args:
        device: Target passed straight to the model's ``.to`` call.

    Returns:
        The result of ``.to(device)`` on the freshly loaded model.
    """
    aurora_model = Aurora(use_lora=False)
    # Load pretrained checkpoint if available
    aurora_model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")
    return aurora_model.to(device)
config_utils.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ import yaml
2
+ import streamlit as st
3
+ from pathlib import Path
4
+
5
def load_config(config_path: Path):
    """Parse the YAML file at ``config_path`` and return its content.

    Args:
        config_path: Location of the YAML configuration file.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file's content.
    """
    with open(config_path, "r") as handle:
        return yaml.safe_load(handle)
data_utils.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ import traceback
4
+ import streamlit as st
5
+ import xarray as xr
6
+ from typing import List
7
+ import numpy as np
8
+
9
def save_uploaded_files(uploaded_files):
    """Persist uploaded files to disk and record their paths in session state.

    This function is deliberately not wrapped in a Streamlit cache: its job is
    a side effect (filling ``st.session_state.temp_file_paths``), and a cache
    hit would skip the body and leave the session without any paths.

    Each file is stored under a name derived from the SHA-256 of its content,
    so saving the same upload again yields the same path instead of creating
    another temporary file.

    Args:
        uploaded_files: Iterable of uploaded file objects exposing ``name``
            and ``getvalue()``. ``None`` is treated as empty.

    Returns:
        The list of saved paths, in upload order. The same list replaces
        ``st.session_state.temp_file_paths``, so paths from an earlier upload
        never leak into the current one.
    """
    import hashlib

    target_dir = tempfile.gettempdir()
    saved_paths = []
    for uploaded_file in uploaded_files or ():
        # getvalue() returns the full content regardless of the stream
        # position, whereas read() yields nothing once the stream is consumed.
        payload = uploaded_file.getvalue()
        suffix = os.path.splitext(uploaded_file.name)[1]
        digest = hashlib.sha256(payload).hexdigest()
        target = os.path.join(target_dir, f"upload_{digest}{suffix}")

        if not os.path.exists(target):
            # Write to a scratch file in the same directory, then rename, so a
            # partially written file is never visible under the final name.
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix, dir=target_dir) as scratch:
                scratch.write(payload)
            os.replace(scratch.name, target)

        saved_paths.append(target)

    st.session_state.temp_file_paths = saved_paths
    return saved_paths
19
+
20
+
21
@st.cache_resource
def _load_dataset_cached(file_paths: List[str]):
    """Open, combine and fully load the given files; errors propagate.

    Keeping the error handling out of the cached function means a failed
    attempt is not stored in the cache and will be retried on the next call.
    """
    # The context manager releases the underlying file handles once the data
    # has been loaded into memory.
    with xr.open_mfdataset(file_paths, combine='by_coords') as ds:
        return ds.load()


def load_dataset(file_paths: List[str]):
    """Load several files into one in-memory dataset.

    Successful loads are cached per ``file_paths``. Failures are reported in
    the Streamlit UI and are not cached.

    Args:
        file_paths: Paths of the files to combine by coordinates.

    Returns:
        The loaded dataset, or ``None`` if loading failed.
    """
    try:
        return _load_dataset_cached(file_paths)
    except Exception:
        st.error("Error loading dataset:")
        st.error(traceback.format_exc())
        return None
30
+
31
@st.cache_resource
def _load_pangu_array_cached(file_path: str):
    """Load ``file_path`` with ``np.load``; errors propagate.

    Keeping the error handling out of the cached function means a failed
    attempt is not stored in the cache and will be retried on the next call.
    """
    return np.load(file_path)


def load_dataset_pangu(file_path: str):
    """Load a NumPy file for the Pangu workflow.

    Successful loads are cached per ``file_path``. Failures are reported in
    the Streamlit UI and are not cached.

    Args:
        file_path: Path handed to ``np.load``.

    Returns:
        The object returned by ``np.load``, or ``None`` if loading failed.
    """
    try:
        return _load_pangu_array_cached(file_path)
    except Exception:
        st.error("Error loading dataset:")
        st.error(traceback.format_exc())
        return None
fengwu_utils.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ # from Pangu-Weather import *
4
+ import numpy as np
5
+ from datetime import datetime
6
+ import numpy as np
7
+ import onnx
8
+ import onnxruntime as ort
9
+ import matplotlib.pyplot as plt
10
+ import cartopy.crs as ccrs
11
+ import io
12
+
13
def fengwu_config_data():
    """Render the FengWu input section and return both upload widgets' results.

    The section consists of a description of the expected input arrays, two
    ``.npy`` uploaders (initial time and six hours later) and a list of
    references.

    Returns:
        A tuple ``(input1_file, input2_file)`` with the values returned by the
        two ``st.file_uploader`` calls.
    """
    st.subheader("FengWu Model Data Input")

    # Detailed data description section
    data_description = """
**Input Data Requirements (FengWu):**
FengWu takes **two consecutive six-hour atmospheric states** as input:
1. **First Input (input1.npy)**: Atmospheric data at the initial time.
2. **Second Input (input2.npy)**: Atmospheric data 6 hours later.

**Shape & Variables:**
Each input is a NumPy array with shape `(69, 721, 1440)`:
- **Dimension 0 (69 features):**
The first 4 features are surface variables:
1. U10 (10-meter Eastward Wind)
2. V10 (10-meter Northward Wind)
3. T2M (2-meter Temperature)
4. MSL (Mean Sea Level Pressure)

These are followed by non-surface variables, each with 13 pressure levels:
- Z (Geopotential)
- Q (Specific Humidity)
- U (Eastward Wind)
- V (Northward Wind)
- T (Temperature)

The 13 vertical levels are: [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000] hPa

The total count is:
- Surface vars: 4
- For each non-surface var (Z, Q, U, V, T): 13 levels = 65 vars
4 (surface) + 65 (5 vars * 13 levels) = 69 total features.

**Spatial & Coordinate Details:**
- Latitude dimension (721 points) ranges from 90°N to -90°S with ~0.25° spacing.
- Longitude dimension (1440 points) ranges from 0° to 360°E with ~0.25° spacing.
- Ensure data is single precision floats (`.astype(np.float32)`).

**Data Frequency & Forecasting Scheme:**
- `input1.npy` corresponds to a given time (e.g., 06:00 UTC Jan 1, 2018).
- `input2.npy` corresponds to 6 hours later (e.g., 12:00 UTC Jan 1, 2018).
- The model predicts future states at subsequent 6-hour intervals.

**Converting Your Data:**
- ERA5 `.nc` files or ECMWF `.grib` files can be converted to `.npy` using appropriate Python packages (`netCDF4` or `pygrib`).
- Ensure you follow the exact variable and level ordering as described.


"""
    st.markdown(data_description)

    # File uploaders for FengWu input data (two consecutive time steps)
    st.markdown("### Upload Your FengWu Input Data Files")
    first_upload = st.file_uploader(
        "Upload input1.npy (Initial Time)",
        type=["npy"],
        key="fengwu_input1",
    )
    second_upload = st.file_uploader(
        "Upload input2.npy (6 Hours Later)",
        type=["npy"],
        key="fengwu_input2",
    )

    st.markdown("---")
    st.markdown("### References & Resources")
    references = """
- **Research Paper:** [FengWu: Pushing the Skillful Global Medium-range Weather Forecast beyond 10 Days Lead](https://arxiv.org/abs/2304.02948)

- **GitHub Source Code:** [Fengwu on GitHub](https://github.com/OpenEarthLab/FengWu?tab=readme-ov-file)
"""
    st.markdown(references)

    return first_upload, second_upload
86
+
87
+
88
@st.cache_resource
def inference_6hrs_fengwu(input1, input2):
    """Run a single FengWu step and return the de-normalised forecast.

    Args:
        input1: State at the initial time. It must broadcast against the
            per-channel statistics in ``FengWu/data_mean.npy`` (the upload
            instructions in this module describe a ``(69, 721, 1440)`` array).
        input2: State six hours after ``input1``, in the same layout.

    Returns:
        The first 69 channels of the model output for batch element 0,
        rescaled with the stored standard deviation and mean.
    """
    # Set the behavior of onnxruntime
    session_options = ort.SessionOptions()
    session_options.enable_cpu_mem_arena = False
    session_options.enable_mem_pattern = False
    session_options.enable_mem_reuse = False
    # Increase the number for faster inference and more memory consumption
    session_options.intra_op_num_threads = 1

    # Set the behavior of the CUDA provider
    cuda_provider_options = {'arena_extend_strategy': 'kSameAsRequested'}

    # Initialize the onnxruntime session for the FengWu model. The session reads
    # the model file itself, so no separate onnx.load() of the same file is needed.
    ort_session_6 = ort.InferenceSession(
        'FengWu/fengwu_v2.onnx',
        sess_options=session_options,
        providers=[('CUDAExecutionProvider', cuda_provider_options)],
    )

    # Per-channel statistics, reshaped to broadcast over the two spatial axes.
    data_mean = np.load("FengWu/data_mean.npy")[:, np.newaxis, np.newaxis]
    data_std = np.load("FengWu/data_std.npy")[:, np.newaxis, np.newaxis]

    input1_after_norm = (input1 - data_mean) / data_std
    input2_after_norm = (input2 - data_mean) / data_std
    # Stack both states along the channel axis and prepend a batch axis.
    model_input = np.concatenate((input1_after_norm, input2_after_norm), axis=0)[np.newaxis, :, :, :]
    model_input = model_input.astype(np.float32)

    output = ort_session_6.run(None, {'input': model_input})[0]
    # Undo the normalisation on the first 69 output channels.
    output = (output[0, :69] * data_std) + data_mean

    return output
119
+
120
+
121
@st.cache_resource
def inference_12hrs_fengwu(input1, input2):
    """Run two autoregressive FengWu steps and return the final forecast.

    After each step the model input is rebuilt from the second half of the
    previous input and the first 69 channels of the new output.

    Args:
        input1: State at the initial time. It must broadcast against the
            per-channel statistics in ``FengWu/data_mean.npy`` (the upload
            instructions in this module describe a ``(69, 721, 1440)`` array).
        input2: State six hours after ``input1``, in the same layout.

    Returns:
        The first 69 channels of the last model output for batch element 0,
        rescaled with the stored standard deviation and mean.
    """
    # Set the behavior of onnxruntime
    session_options = ort.SessionOptions()
    session_options.enable_cpu_mem_arena = False
    session_options.enable_mem_pattern = False
    session_options.enable_mem_reuse = False
    # Increase the number for faster inference and more memory consumption
    session_options.intra_op_num_threads = 1

    # Set the behavior of the CUDA provider
    cuda_provider_options = {'arena_extend_strategy': 'kSameAsRequested'}

    # Initialize the onnxruntime session for the FengWu model. The session reads
    # the model file itself, so no separate onnx.load() of the same file is needed.
    ort_session_6 = ort.InferenceSession(
        'FengWu/fengwu_v2.onnx',
        sess_options=session_options,
        providers=[('CUDAExecutionProvider', cuda_provider_options)],
    )

    # Per-channel statistics, reshaped to broadcast over the two spatial axes.
    data_mean = np.load("FengWu/data_mean.npy")[:, np.newaxis, np.newaxis]
    data_std = np.load("FengWu/data_std.npy")[:, np.newaxis, np.newaxis]

    input1_after_norm = (input1 - data_mean) / data_std
    input2_after_norm = (input2 - data_mean) / data_std
    # Stack both states along the channel axis and prepend a batch axis.
    model_input = np.concatenate((input1_after_norm, input2_after_norm), axis=0)[np.newaxis, :, :, :]
    model_input = model_input.astype(np.float32)

    for _ in range(2):
        output = ort_session_6.run(None, {'input': model_input})[0]
        # Slide the window: drop the older state, append the new prediction.
        model_input = np.concatenate((model_input[:, 69:], output[:, :69]), axis=1)

    # Undo the normalisation on the first 69 channels of the final output.
    output = (output[0, :69] * data_std) + data_mean

    return output
155
+
156
@st.cache_resource
def inference_custom_hrs_fengwu(input1, input2, forecast_hours):
    """
    Produce a FengWu forecast for an arbitrary lead time by chaining 6-hour steps.

    Parameters:
    - input1, input2: the two input states. Each is normalized with the
      per-channel statistics stored in FengWu/data_mean.npy and
      FengWu/data_std.npy, then both are stacked along the channel axis.
      (assumes each is a (69, 721, 1440) array and input1 is the earlier
      state -- TODO confirm against the caller)
    - forecast_hours: lead time in hours; must be a positive multiple of 6.

    Returns:
    - np.ndarray: the first 69 channels of the last model output,
      de-normalized with the same mean/std.

    Raises:
    - ValueError: if forecast_hours is not a positive multiple of 6.
    """
    # range() needs an integer; the previous `forecast_hours/6` produced a float
    # and raised TypeError for every input. Validate before any heavy work.
    steps, remainder = divmod(forecast_hours, 6)
    if forecast_hours <= 0 or remainder != 0:
        raise ValueError("forecast_hours must be a positive multiple of 6.")
    steps = int(steps)

    # Set the behavior of onnxruntime
    options = ort.SessionOptions()
    options.enable_cpu_mem_arena = False
    options.enable_mem_pattern = False
    options.enable_mem_reuse = False
    # Increase the number for faster inference and more memory consumption
    options.intra_op_num_threads = 1

    # Set the behavior of the CUDA provider
    cuda_provider_options = {'arena_extend_strategy': 'kSameAsRequested'}

    # Initialize the onnxruntime session for the FengWu model.
    # InferenceSession reads the model file itself, so a separate onnx.load()
    # (which would hold a second copy of the model in memory) is not needed.
    ort_session_6 = ort.InferenceSession(
        'FengWu/fengwu_v2.onnx',
        sess_options=options,
        providers=[('CUDAExecutionProvider', cuda_provider_options)],
    )

    data_mean = np.load("FengWu/data_mean.npy")[:, np.newaxis, np.newaxis]
    data_std = np.load("FengWu/data_std.npy")[:, np.newaxis, np.newaxis]

    input1_after_norm = (input1 - data_mean) / data_std
    input2_after_norm = (input2 - data_mean) / data_std
    model_input = np.concatenate((input1_after_norm, input2_after_norm), axis=0)[np.newaxis, :, :, :]
    model_input = model_input.astype(np.float32)

    for _ in range(steps):
        output = ort_session_6.run(None, {'input': model_input})[0]
        # Slide the two-state window: drop the older state, append the new prediction.
        model_input = np.concatenate((model_input[:, 69:], output[:, :69]), axis=1)

    # Only the final step is returned, so de-normalize once after the loop.
    return (output[0, :69] * data_std) + data_mean
190
+
191
def _render_fengwu_panel(data, label, key_prefix, surface_vars, upper_vars,
                         upper_levels, var_group_start, extent):
    """Render the selection widgets and the map for one Fengwu state (initial or predicted)."""
    st.subheader(f"{label} Data Visualization (Fengwu)")
    choice_col, var_col = st.columns([1, 1])

    with choice_col:
        data_choice = st.selectbox(
            "Data Source",
            ["Upper-Air Data", "Surface Data"],
            key=f"fengwu_{key_prefix}_data_choice",
        )
    is_upper = data_choice == "Upper-Air Data"
    with var_col:
        if is_upper:
            var = st.selectbox("Variable", upper_vars, key=f"fengwu_{key_prefix}_upper_var")
        else:
            var = st.selectbox("Variable", surface_vars, key=f"fengwu_{key_prefix}_surface_var")

    # Resolve the selection to a channel index of the stacked array
    if is_upper:
        level_hpa = st.select_slider(
            "Select Pressure Level (hPa)",
            options=upper_levels,
            value=850,  # Default to 850hPa
            help="Select the pressure level in hPa.",
            key=f"fengwu_{key_prefix}_level_hpa_slider",
        )
        channel = var_group_start[var] + upper_levels.index(level_hpa)
        title = f"{label} Upper-Air: {var} at {level_hpa}hPa"
    else:
        channel = surface_vars.index(var)
        title = f"{label} Surface: {var}"
    field = data[channel, :, :]

    fig, ax = plt.subplots(figsize=(10, 5), subplot_kw={'projection': ccrs.PlateCarree()})
    ax.set_title(title)
    # Row 0 of the field is the northernmost latitude (the grid runs 90N -> 90S),
    # so the image origin must be 'upper'; 'lower' would draw the field
    # upside-down relative to the coastlines.
    image = ax.imshow(field, extent=extent, origin='upper', cmap='coolwarm',
                      transform=ccrs.PlateCarree())
    ax.coastlines()
    plt.colorbar(image, ax=ax, orientation='horizontal', pad=0.05)
    st.pyplot(fig)
    # Release the figure; otherwise one figure per rerun stays registered in pyplot.
    plt.close(fig)


def plot_fengwu_output(initial_data, predicted_data):
    """
    Plot initial and predicted Fengwu model outputs.

    Renders, for each of the two states, a data-source / variable / pressure-level
    selector and a PlateCarree map of the chosen channel, followed by a download
    button that serves predicted_data as a .npy file.

    Parameters:
    - initial_data: np.ndarray of shape (69, 721, 1440) representing the initial or input state.
    - predicted_data: np.ndarray of shape (69, 721, 1440) representing the predicted state by Fengwu.
    """
    # Coordinate setup
    lat = np.linspace(90, -90, 721)  # Latitude from 90N to 90S
    lon = np.linspace(0, 360, 1440)  # Longitude from 0E to 360E
    extent = [lon.min(), lon.max(), lat.min(), lat.max()]

    # Surface and upper-level variable definitions
    surface_vars = ["U10", "V10", "T2M", "MSL"]
    upper_vars = ["Z (Geopotential)", "Q (Specific Humidity)", "U (Eastward Wind)", "V (Northward Wind)", "T (Temperature)"]
    upper_levels = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]

    # Mapping of upper variable groups to their starting indices.
    # Channels 0-3 are the surface variables; each upper group then has 13 levels,
    # so indices shift by 13 for each subsequent group.
    var_group_start = {
        "Z (Geopotential)": 4,        # Z starts at index 4
        "Q (Specific Humidity)": 17,  # Q = 4+13=17
        "U (Eastward Wind)": 30,      # U = 17+13=30
        "V (Northward Wind)": 43,     # V = 30+13=43
        "T (Temperature)": 56,        # T = 43+13=56
    }

    # --- Initial Data Visualization ---
    _render_fengwu_panel(initial_data, "Initial", "init", surface_vars, upper_vars,
                         upper_levels, var_group_start, extent)

    # --- Predicted Data Visualization ---
    _render_fengwu_panel(predicted_data, "Predicted", "pred", surface_vars, upper_vars,
                         upper_levels, var_group_start, extent)

    # --- Download Buttons ---
    st.subheader("Download Predicted Fengwu Data")

    # Convert predicted_data to binary format for download
    buffer_pred = io.BytesIO()
    np.save(buffer_pred, predicted_data)
    buffer_pred.seek(0)

    st.download_button(label="Download Predicted Fengwu Data",
                       data=buffer_pred,
                       file_name="predicted_fengwu.npy",
                       mime="application/octet-stream")
inference_utils.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import streamlit as st
3
+ from aurora import rollout, Aurora
4
+
5
def run_inference(selected_model, model, batch, device):
    """Run one inference pass for the model family named by `selected_model`.

    - "Prithvi": puts the model in eval mode and returns `model(batch)`,
      computed under `torch.no_grad()`.
    - "Aurora": puts the model in eval mode and returns a list with the
      predictions of a 2-step `rollout`, each moved to the CPU.
    - anything else: shows a Streamlit error and returns None.

    `device` is accepted but not used by this function.
    """
    if selected_model not in ("Prithvi", "Aurora"):
        st.error("Inference not implemented for this model.")
        return None

    model.eval()

    if selected_model == "Prithvi":
        with torch.no_grad():
            return model(batch)

    with torch.inference_mode():
        # Example: Predict 2 steps ahead
        return [pred.to("cpu") for pred in rollout(model, batch, steps=2)]
pangu_utils.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ # from Pangu-Weather import *
4
+ import numpy as np
5
+ from datetime import datetime
6
+ import numpy as np
7
+ import onnx
8
+ import onnxruntime as ort
9
+ import matplotlib.pyplot as plt
10
+ import cartopy.crs as ccrs
11
+ import io
12
+
13
+
14
def pangu_config_data():
    """Render the Pangu-Weather data-input section and return the two upload widgets' values.

    Shows a description of the expected input arrays, two `.npy` file
    uploaders (surface and upper-air), and a list of references.

    Returns:
    - (input_surface_file, input_upper_file): the values returned by the two
      `st.file_uploader` calls, in that order.
    """
    st.subheader("Pangu-Weather Model Data Input")

    # Detailed data description section.
    # The latitude line previously read "90°N to -90°S", which is
    # self-contradictory; the grid runs from 90°N to 90°S.
    st.markdown("""
**Input Data Requirements:**
Pangu-Weather uses two NumPy arrays to represent initial atmospheric conditions:
1. **Surface Data (input_surface.npy)**
- Shape: `(4, 721, 1440)`
- Variables: MSLP, U10, V10, T2M in this exact order.
- **MSLP:** Mean Sea Level Pressure
- **U10:** 10-meter Eastward Wind
- **V10:** 10-meter Northward Wind
- **T2M:** 2-meter Temperature
2. **Upper-Air Data (input_upper.npy)**
- Shape: `(5, 13, 721, 1440)`
- Variables (first dim): Z, Q, T, U, V in this exact order
- **Z:** Geopotential (Note: if your source provides geopotential height, multiply by 9.80665 to get geopotential)
- **Q:** Specific Humidity
- **T:** Temperature
- **U:** Eastward Wind
- **V:** Northward Wind
- Pressure Levels (second dim): 1000hPa, 925hPa, 850hPa, 700hPa, 600hPa, 500hPa, 400hPa, 300hPa, 250hPa, 200hPa, 150hPa, 100hPa, 50hPa in this exact order.

**Spatial & Coordinate Details:**
- Latitude dimension (721 points) ranges from 90°N to 90°S with a 0.25° spacing.
- Longitude dimension (1440 points) ranges from 0° to 359.75°E with a 0.25° spacing.
- Data should be single precision floats (`.astype(np.float32)`).

**Supported Data Sources:**
- ERA5 initial fields (strongly recommended).
- ECMWF initial fields (e.g., HRES forecast) can be used, but may result in a slight accuracy drop.
- Other types of initial fields are not currently supported due to potentially large discrepancies in data fields.

**Converting Your Data:**
- ERA5 `.nc` files can be converted to `.npy` using the `netCDF4` Python package.
- ECMWF `.grib` files can be converted to `.npy` using the `pygrib` Python package.
- Ensure the order of variables and pressure levels is exactly as described above.
""")

    # File uploaders for surface and upper data separately
    st.markdown("### Upload Your Input Data Files")
    input_surface_file = st.file_uploader(
        "Upload input_surface.npy",
        type=["npy"],
        key="pangu_input_surface"
    )

    input_upper_file = st.file_uploader(
        "Upload input_upper.npy",
        type=["npy"],
        key="pangu_input_upper"
    )

    st.markdown("---")
    st.markdown("### References & Resources")
    st.markdown("""
- **Research Paper:** [Accurate medium-range global weather forecasting with 3D neural networks](https://www.nature.com/articles/s41586-023-06185-3)
- [Pangu-Weather: A 3D High-Resolution Model for Fast and Accurate Global Weather Forecast](https://arxiv.org/abs/2211.02556)
- **GitHub Source Code:** [Pangu-Weather on GitHub](https://github.com/198808xc/Pangu-Weather?tab=readme-ov-file)
""")

    return input_surface_file, input_upper_file
77
+
78
def _create_pangu_session(model_path):
    """Create a CPU onnxruntime session for one Pangu-Weather model file."""
    # Set the behavior of onnxruntime
    options = ort.SessionOptions()
    options.enable_cpu_mem_arena = False
    options.enable_mem_pattern = False
    options.enable_mem_reuse = False
    # Increase the number for faster inference and more memory consumption
    options.intra_op_num_threads = 1

    # InferenceSession reads the model file itself, so a separate onnx.load()
    # (which would hold a second copy of the model in memory) is not needed.
    return ort.InferenceSession(model_path, sess_options=options, providers=['CPUExecutionProvider'])


def _run_pangu_step(session, input, input_surface):
    """Run one forward pass and return (output, output_surface)."""
    output, output_surface = session.run(None, {'input': input, 'input_surface': input_surface})
    return output, output_surface


def inference_24hrs(input, input_surface):
    """Run the 24-hour Pangu-Weather model once.

    Parameters:
    - input: array fed to the model's 'input' tensor.
    - input_surface: array fed to the model's 'input_surface' tensor.

    Returns:
    - (output, output_surface): the two model outputs, in that order.
    """
    session = _create_pangu_session('Pangu-Weather/pangu_weather_24.onnx')
    return _run_pangu_step(session, input, input_surface)


@st.cache_resource
def inference_6hrs(input, input_surface):
    """Run the 6-hour Pangu-Weather model once.

    Parameters:
    - input: array fed to the model's 'input' tensor.
    - input_surface: array fed to the model's 'input_surface' tensor.

    Returns:
    - (output, output_surface): the two model outputs, in that order.
    """
    session = _create_pangu_session('Pangu-Weather/pangu_weather_6.onnx')
    return _run_pangu_step(session, input, input_surface)


@st.cache_resource
def inference_1hr(input, input_surface):
    """Run the 1-hour Pangu-Weather model once.

    Parameters:
    - input: array fed to the model's 'input' tensor.
    - input_surface: array fed to the model's 'input_surface' tensor.

    Returns:
    - (output, output_surface): the two model outputs, in that order.
    """
    session = _create_pangu_session('Pangu-Weather/pangu_weather_1.onnx')
    return _run_pangu_step(session, input, input_surface)


@st.cache_resource
def inference_3hrs(input, input_surface):
    """Run the 3-hour Pangu-Weather model once.

    Parameters:
    - input: array fed to the model's 'input' tensor.
    - input_surface: array fed to the model's 'input_surface' tensor.

    Returns:
    - (output, output_surface): the two model outputs, in that order.
    """
    session = _create_pangu_session('Pangu-Weather/pangu_weather_3.onnx')
    return _run_pangu_step(session, input, input_surface)


@st.cache_resource
def inference_custom_hrs(input, input_surface, forecast_hours):
    """Forecast `forecast_hours` ahead by applying the 24-hour model repeatedly.

    Parameters:
    - input: array fed to the model's 'input' tensor on the first step.
    - input_surface: array fed to the model's 'input_surface' tensor on the first step.
    - forecast_hours: lead time in hours; must be a multiple of 24.

    Returns:
    - (output, output_surface) of the last step. When forecast_hours // 24
      is not positive, no step runs and the inputs are returned unchanged.

    Raises:
    - ValueError: if forecast_hours is not a multiple of 24.
    """
    # Ensure forecast_hours is a multiple of 24
    if forecast_hours % 24 != 0:
        raise ValueError("forecast_hours must be a multiple of 24.")

    session = _create_pangu_session('Pangu-Weather/pangu_weather_24.onnx')

    # Calculate how many 24-hour steps we need
    steps = forecast_hours // 24

    # Run the 24-hour model repeatedly, feeding each prediction back in
    for _ in range(steps):
        input, input_surface = _run_pangu_step(session, input, input_surface)

    # Return the final predictions after completing all steps
    return input, input_surface
198
+
199
+
200
def _render_pangu_panel(upper, surface, label, key_prefix, upper_vars,
                        upper_hpa_values, surface_vars, extent):
    """Render the selection widgets and the map for one Pangu state (initial or predicted)."""
    st.subheader(f"{label} Data Visualization")
    choice_col, var_col = st.columns([1, 1])

    with choice_col:
        data_choice = st.selectbox(
            "Data Source",
            ["Upper-Air Data", "Surface Data"],
            key=f"{key_prefix}_data_choice",
        )
    is_upper = data_choice == "Upper-Air Data"
    with var_col:
        if is_upper:
            var = st.selectbox("Variable", upper_vars, key=f"{key_prefix}_upper_var")
        else:
            var = st.selectbox("Variable", surface_vars, key=f"{key_prefix}_surface_var")

    if is_upper:
        level_hpa = st.select_slider(
            "Select Pressure Level (hPa)",
            options=upper_hpa_values,
            value=850,  # Default to 850hPa
            help="Select the pressure level in hPa.",
            key=f"{key_prefix}_level_hpa_slider",
        )
        # Find the corresponding indices from the selected variable and hPa value
        level_index = upper_hpa_values.index(level_hpa)
        var_index = upper_vars.index(var)
        field = upper[var_index, level_index, :, :]
        title = f"{label} Upper-Air: {var} at {level_hpa}hPa"
    else:
        var_index = surface_vars.index(var)
        field = surface[var_index, :, :]
        title = f"{label} Surface: {var}"

    fig, ax = plt.subplots(figsize=(10, 5), subplot_kw={'projection': ccrs.PlateCarree()})
    ax.set_title(title)
    # Row 0 of the field is the northernmost latitude (the grid runs 90N -> 90S),
    # so the image origin must be 'upper'; 'lower' would draw the field
    # upside-down relative to the coastlines.
    image = ax.imshow(field, extent=extent, origin='upper', cmap='coolwarm',
                      transform=ccrs.PlateCarree())
    ax.coastlines()
    plt.colorbar(image, ax=ax, orientation='horizontal', pad=0.05)
    st.pyplot(fig)
    # Release the figure; otherwise one figure per rerun stays registered in pyplot.
    plt.close(fig)


def plot_pangu_output(upper_data, surface_data, out_upper, out_surface):
    """Plot initial and predicted Pangu-Weather fields and offer the predictions for download.

    Parameters:
    - upper_data: initial upper-air array, indexed as [variable, level, lat, lon].
    - surface_data: initial surface array, indexed as [variable, lat, lon].
    - out_upper: predicted upper-air array, indexed like upper_data.
    - out_surface: predicted surface array, indexed like surface_data.

    Variable and level order follow the lists defined below.
    """
    # Coordinate setup
    lat = np.linspace(90, -90, 721)  # Latitude grid
    lon = np.linspace(0, 360, 1440)  # Longitude grid
    extent = [lon.min(), lon.max(), lat.min(), lat.max()]

    # Variable and level names
    upper_vars = ["Z (Geopotential)", "Q (Specific Humidity)", "T (Temperature)", "U (Eastward Wind)", "V (Northward Wind)"]
    upper_levels = ["1000hPa", "925hPa", "850hPa", "700hPa", "600hPa", "500hPa",
                    "400hPa", "300hPa", "250hPa", "200hPa", "150hPa", "100hPa", "50hPa"]
    # Extract numeric hPa values for selection
    upper_hpa_values = [int(level.replace("hPa", "")) for level in upper_levels]

    surface_vars = ["MSLP", "U10", "V10", "T2M"]

    # --- Initial Data Visualization ---
    _render_pangu_panel(upper_data, surface_data, "Initial", "init", upper_vars,
                        upper_hpa_values, surface_vars, extent)

    # --- Predicted Data Visualization ---
    _render_pangu_panel(out_upper, out_surface, "Predicted", "pred", upper_vars,
                        upper_hpa_values, surface_vars, extent)

    # --- Download Buttons ---
    st.subheader("Download Predicted Data")

    # Convert out_upper and out_surface to binary format for download
    buffer_upper = io.BytesIO()
    np.save(buffer_upper, out_upper)
    buffer_upper.seek(0)

    buffer_surface = io.BytesIO()
    np.save(buffer_surface, out_surface)
    buffer_surface.seek(0)

    st.download_button(label="Download Predicted Upper-Air Data",
                       data=buffer_upper,
                       file_name="predicted_upper.npy",
                       mime="application/octet-stream")

    st.download_button(label="Download Predicted Surface Data",
                       data=buffer_surface,
                       file_name="predicted_surface.npy",
                       mime="application/octet-stream")
plot_utils.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ import plotly.graph_objects as go
5
+ import cartopy.crs as ccrs
6
+
7
def _matplotlib_cmap_to_plotly(cmap_name, samples=64):
    """Sample a matplotlib colormap into a Plotly colorscale ([[fraction, 'rgb(...)'], ...])."""
    colormap = plt.get_cmap(cmap_name)
    colorscale = []
    for i in range(samples):
        fraction = i / (samples - 1)
        red, green, blue, _alpha = colormap(fraction)
        colorscale.append([fraction, f"rgb({int(red * 255)},{int(green * 255)},{int(blue * 255)})"])
    return colorscale


def plot_prithvi_output(out_tensor):
    """Render visualization controls plus a matplotlib and a Plotly plot of one output variable.

    Parameters:
    - out_tensor: model output with at least 4 dimensions; the plotted field
      is `out_tensor[0, variable_index]`, converted with `.cpu().numpy()`.
      (assumes layout (batch, variable, lat, lon) with latitude ascending
      from -90 -- TODO confirm against the model output)

    Shows a warning and returns when out_tensor is None; shows an error and
    returns when it has fewer than 4 dimensions.
    """
    if out_tensor is None:
        st.warning("No output available for plotting.")
        return

    # Example visualization UI for Prithvi:
    st.markdown("## 📊 Visualization Settings")
    # Extract shape and variable names as needed
    if out_tensor.ndim < 4:
        st.error("The output tensor does not have the expected dimensions.")
        return

    num_variables = out_tensor.shape[1]
    variable_names = [f"Variable_{i}" for i in range(num_variables)]

    col1, col2 = st.columns(2)
    with col1:
        selected_variable_name = st.selectbox(
            "Select Variable to Plot",
            options=variable_names,
            index=0,
            help="Choose the variable to visualize."
        )
        plot_type = st.selectbox("Select Plot Type", ["Contour", "Heatmap"], index=0)

    with col2:
        available_cmaps = plt.colormaps()
        cmap = st.selectbox("Select Color Map", options=available_cmaps, index=available_cmaps.index("viridis"))
        if plot_type == "Contour":
            num_levels = st.slider("Number of Contour Levels", 5, 100, 20, 5)
        else:
            num_levels = None

    variable_index = variable_names.index(selected_variable_name)
    selected_variable = out_tensor[0, variable_index].cpu().numpy()
    lat = np.linspace(-90, 90, selected_variable.shape[0])
    lon = np.linspace(-180, 180, selected_variable.shape[1])
    X, Y = np.meshgrid(lon, lat)

    st.markdown(f"### Plot of {selected_variable_name}")
    fig, ax = plt.subplots(figsize=(10, 6))
    if plot_type == "Contour":
        contour = ax.contourf(X, Y, selected_variable, levels=num_levels, cmap=cmap)
    else:
        contour = ax.imshow(selected_variable, extent=[-180, 180, -90, 90], cmap=cmap, origin='lower', aspect='auto')

    cbar = plt.colorbar(contour, ax=ax)
    cbar.set_label(f'{selected_variable_name}', fontsize=12)
    ax.set_xlabel("Longitude", fontsize=12)
    ax.set_ylabel("Latitude", fontsize=12)
    ax.set_title(selected_variable_name, fontsize=14)
    st.pyplot(fig)
    # Release the figure; otherwise one figure per rerun stays registered in pyplot.
    plt.close(fig)

    # Plotly interactive plot
    st.markdown("#### Interactive Plot")
    # Plotly only knows its own named colorscales, so most matplotlib names
    # would be rejected; pass an explicit colorscale sampled from the colormap.
    plotly_colorscale = _matplotlib_cmap_to_plotly(cmap)
    if plot_type == "Contour":
        fig_plotly = go.Figure(data=go.Contour(
            z=selected_variable,
            x=lon,
            y=lat,
            colorscale=plotly_colorscale,
            # ncontours is a property of the trace itself, not of `contours`.
            ncontours=num_levels,
            contours=dict(coloring='fill', showlabels=True, labelfont=dict(size=12, color='white'))
        ))
    else:
        fig_plotly = go.Figure(data=go.Heatmap(z=selected_variable, x=lon, y=lat, colorscale=plotly_colorscale))

    fig_plotly.update_layout(
        xaxis_title="Longitude",
        yaxis_title="Latitude",
        width=800,
        height=600,
    )
    st.plotly_chart(fig_plotly)
79
+
80
+
81
def plot_aurora_output(preds, ds_subset):
    """Plot Aurora temperature predictions next to the reference data, one figure per step.

    Parameters:
    - preds: sequence of prediction objects; this function reads
      `atmos_vars["t"]`, `metadata.time`, `metadata.lat` and `metadata.lon`
      from each one.
    - ds_subset: dataset providing the reference field through
      `ds_subset.T.isel(lev=...)[idx]`.

    Both fields have 273.15 subtracted before plotting. Shows an error and
    returns when either argument is None or when the number of levels cannot
    be read from the first prediction.
    """
    if preds is None or ds_subset is None:
        st.error("No predictions or dataset subset available for visualization.")
        return

    try:
        levels = preds[0].atmos_vars["t"].shape[2]
    except Exception:
        # Exception (not a bare except) so KeyboardInterrupt/SystemExit still propagate.
        st.error("Could not determine available levels in the data.")
        return

    if levels > 1:
        # Clamp the default so it stays inside the slider range when the data
        # has fewer than 12 levels.
        default_level = min(11, levels - 1)
        selected_level = st.slider('Select Level', 0, levels - 1, default_level, 1)
    else:
        # A single level leaves nothing to choose from.
        selected_level = 0

    for idx, pred in enumerate(preds):
        pred_time = pred.metadata.time[0]

        try:
            pred_data = pred.atmos_vars["t"][0][0][selected_level].numpy() - 273.15
            truth_data = ds_subset.T.isel(lev=selected_level)[idx].values - 273.15
        except Exception as e:
            st.error("Error extracting data for plotting:")
            st.error(e)
            continue

        lat = np.array(pred.metadata.lat)
        lon = np.array(pred.metadata.lon)
        extent = [lon.min(), lon.max(), lat.min(), lat.max()]

        # NOTE(review): three panels are created but only two are drawn;
        # the third axis stays empty.
        fig, axes = plt.subplots(1, 3, figsize=(18, 6), subplot_kw={'projection': ccrs.PlateCarree()})

        # NOTE(review): origin='lower' assumes row 0 is the southernmost
        # latitude for both fields -- confirm against the data ordering.
        # Ground Truth
        im1 = axes[0].imshow(
            truth_data, extent=extent,
            origin='lower', cmap='coolwarm', transform=ccrs.PlateCarree()
        )
        axes[0].set_title(f"Ground Truth at Level {selected_level} - {pred_time}")
        plt.colorbar(im1, ax=axes[0], orientation='horizontal', pad=0.05)

        # Prediction
        im2 = axes[1].imshow(
            pred_data, extent=extent,
            origin='lower', cmap='coolwarm', transform=ccrs.PlateCarree()
        )
        axes[1].set_title(f"Prediction at Level {selected_level} - {pred_time}")
        plt.colorbar(im2, ax=axes[1], orientation='horizontal', pad=0.05)

        plt.tight_layout()
        st.pyplot(fig)
        # Release the figure; otherwise one figure per step and rerun stays registered in pyplot.
        plt.close(fig)
prithvi_utils.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import tempfile
3
+ from pathlib import Path
4
+ import torch
5
+ import traceback
6
+ import yaml
7
+ # from Prithvi import PrithviWxC, Merra2Dataset, input_scalers, output_scalers, static_input_scalers, preproc
8
+
9
def _persist_upload_to_temp(uploaded_file, suffix):
    """Write an uploaded file to a securely created temp file and return its Path."""
    # NamedTemporaryFile creates the file atomically, unlike the deprecated
    # tempfile.mktemp, which only returns a name and is open to a race.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as handle:
        handle.write(uploaded_file.getbuffer())
    return Path(handle.name)


def _persist_upload_with_name(uploaded_file):
    """Write an uploaded file into its own fresh temp directory, keeping its base name."""
    # Only the final path component of the client-supplied name is used, and
    # each upload gets its own directory so same-named files cannot collide.
    target = Path(tempfile.mkdtemp()) / Path(uploaded_file.name).name
    with open(target, "wb") as handle:
        handle.write(uploaded_file.getbuffer())
    return target


def prithvi_config_ui():
    """Render the Prithvi configuration widgets and collect the user's inputs.

    Shows parameter inputs and upload widgets for surface/vertical data,
    climatology files (only when the default ones are missing), a config YAML
    and model weights. Uploaded climatology, config and weights files are
    written to temporary locations on disk. When no config or weights file is
    uploaded and the default path does not exist, an error is shown and the
    Streamlit script is stopped.

    Returns:
        tuple: ``(config, uploaded_surface_files, uploaded_vertical_files,
        clim_surf_path, clim_vert_path, config_path, weights_path)`` where
        ``config`` is a dict with keys ``"param1"`` and ``"param2"``, the two
        ``uploaded_*`` entries are the values returned by the file uploaders,
        the climatology paths are ``Path`` objects or ``None`` when they are
        missing and not yet uploaded, and ``config_path`` / ``weights_path``
        are ``Path`` objects.
    """
    st.subheader("Prithvi Model Configuration")
    param1 = st.number_input("Prithvi Parameter 1", value=10, step=1)
    param2 = st.text_input("Prithvi Parameter 2", value="default_prithvi")

    config = {"param1": param1, "param2": param2}

    st.markdown("### Upload Data Files for Prithvi Model")
    uploaded_surface_files = st.file_uploader(
        "Upload Surface Data Files",
        type=["nc", "netcdf"],
        accept_multiple_files=True,
        key="surface_uploader",
    )

    uploaded_vertical_files = st.file_uploader(
        "Upload Vertical Data Files",
        type=["nc", "netcdf"],
        accept_multiple_files=True,
        key="vertical_uploader",
    )

    st.markdown("### Upload Climatology Files (If Missing)")
    default_clim_dir = Path("Prithvi-WxC/examples/climatology")
    surf_in_scal_path = default_clim_dir / "musigma_surface.nc"
    vert_in_scal_path = default_clim_dir / "musigma_vertical.nc"
    surf_out_scal_path = default_clim_dir / "anomaly_variance_surface.nc"
    vert_out_scal_path = default_clim_dir / "anomaly_variance_vertical.nc"
    clim_files_exist = all([
        surf_in_scal_path.exists(),
        vert_in_scal_path.exists(),
        surf_out_scal_path.exists(),
        vert_out_scal_path.exists(),
    ])

    if not clim_files_exist:
        st.warning("Climatology files are missing.")
        uploaded_clim_surface = st.file_uploader(
            "Upload Climatology Surface File",
            type=["nc", "netcdf"],
            key="clim_surface_uploader",
        )
        uploaded_clim_vertical = st.file_uploader(
            "Upload Climatology Vertical File",
            type=["nc", "netcdf"],
            key="clim_vertical_uploader",
        )
        if uploaded_clim_surface and uploaded_clim_vertical:
            clim_surf_path = _persist_upload_with_name(uploaded_clim_surface)
            clim_vert_path = _persist_upload_with_name(uploaded_clim_vertical)
            st.success("Climatology files uploaded and saved.")
        else:
            st.warning("Please upload both climatology surface and vertical files.")
            clim_surf_path, clim_vert_path = None, None
    else:
        clim_surf_path = surf_in_scal_path
        clim_vert_path = vert_in_scal_path

    uploaded_config = st.file_uploader(
        "Upload config.yaml",
        type=["yaml", "yml"],
        key="config_uploader",
    )

    if uploaded_config:
        config_path = _persist_upload_to_temp(uploaded_config, ".yaml")
        st.success("Config.yaml uploaded and saved.")
    else:
        config_path = Path("Prithvi-WxC/examples/config.yaml")
        if not config_path.exists():
            st.error("Default config.yaml not found. Please upload a config file.")
            st.stop()

    uploaded_weights = st.file_uploader(
        "Upload Model Weights (.pt)",
        type=["pt"],
        key="weights_uploader",
    )

    if uploaded_weights:
        weights_path = _persist_upload_to_temp(uploaded_weights, ".pt")
        st.success("Model weights uploaded and saved.")
    else:
        weights_path = Path("Prithvi-WxC/examples/weights/prithvi.wxc.2300m.v1.pt")
        if not weights_path.exists():
            st.error("Default model weights not found. Please upload model weights.")
            st.stop()

    return config, uploaded_surface_files, uploaded_vertical_files, clim_surf_path, clim_vert_path, config_path, weights_path
109
+
110
+
111
def initialize_prithvi_model(config, config_path, weights_path, device):
    """Placeholder initializer for the Prithvi model.

    Currently only parses the YAML file at ``config_path``; the model and all
    scalers are returned as ``None`` until the real loading logic is written.
    ``config``, ``weights_path`` and ``device`` are accepted but not used yet.

    Returns:
        tuple: ``(model, in_mu, in_sig, output_sig, static_mu, static_sig)``,
        every element ``None``.
    """
    # Read the configuration file. The parsed mapping is not consumed yet
    # because model construction below is still a stub.
    with open(config_path, "r") as config_file:
        cfg = yaml.safe_load(config_file)

    # TODO: validate the configuration and load the scalers, e.g. (pseudo-code):
    #   in_mu, in_sig = input_scalers(...)
    #   output_sig = output_scalers(...)
    #   static_mu, static_sig = static_input_scalers(...)

    # TODO: build the model and restore its weights, e.g. (pseudo-code):
    #   from Prithvi import PrithviWxC
    #   model = PrithviWxC(**cfg["params"], ...)
    #   state_dict = torch.load(weights_path, map_location=device)
    #   model.load_state_dict(state_dict["model_state"] if "model_state" in state_dict else state_dict, strict=True)
    #   model.to(device)

    # Stand-in values until the steps above are implemented.
    model = None
    in_mu = in_sig = output_sig = static_mu = static_sig = None
    return model, in_mu, in_sig, output_sig, static_mu, static_sig
133
+
134
+
135
def prepare_prithvi_batch(uploaded_surface_files, uploaded_vertical_files, clim_surf_path, clim_vert_path, device):
    """Placeholder for building a Prithvi inference batch.

    None of the arguments are used yet; the function always returns ``None``.

    Returns:
        None, until dataset construction and preprocessing are implemented.
    """
    # TODO: assemble the dataset and batch for inference, e.g. (pseudo-code):
    #   dataset = Merra2Dataset(...)
    #   data = next(iter(dataset))
    #   batch = preproc([data], padding={...})
    #   for k, v in batch.items():
    #       if isinstance(v, torch.Tensor):
    #           batch[k] = v.to(device)

    # Stand-in result until the steps above are implemented.
    batch = None
    return batch