qq1990's picture
init
100edb4
import streamlit as st
import torch
import random
import numpy as np
import yaml
from pathlib import Path
import tempfile
import traceback
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from Prithvi import * # Ensure this import includes your model and dataset classes
import xarray as xr
from aurora import Batch, Metadata
from aurora import Aurora, rollout
import logging
import matplotlib.pyplot as plt
import numpy as np
import cartopy.crs as ccrs
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Function to save uploaded files to temporary files and store paths in session_state
def save_uploaded_files(uploaded_files):
if 'temp_file_paths' not in st.session_state:
st.session_state.temp_file_paths = []
for uploaded_file in uploaded_files:
suffix = os.path.splitext(uploaded_file.name)[1]
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
temp_file.write(uploaded_file.read())
temp_file.close()
st.session_state.temp_file_paths.append(temp_file.name)
# Cached function to load dataset
@st.cache_resource
def load_dataset(file_paths):
try:
ds = xr.open_mfdataset(file_paths, combine='by_coords').load()
return ds
except Exception as e:
st.error("Error loading dataset:")
st.error(traceback.format_exc())
return None
# Set page configuration
st.set_page_config(
page_title="Weather Data Processor",
layout="wide",
initial_sidebar_state="expanded",
)
# Create a header with two columns: one for the title and one for the model selector
header_col1, header_col2 = st.columns([4, 1]) # Adjust the ratio as needed
with header_col1:
st.title("🌦️ Weather & Climate Data Processor and Forecaster")
with header_col2:
st.markdown("### Select a Model")
selected_model = st.selectbox(
"",
options=["Aurora", "Climax", "Prithvi", "LSTM"],
index=0,
key="model_selector",
help="Select the model you want to use for processing the data."
)
st.write("---") # Horizontal separator
# --- Layout: Two Columns ---
left_col, right_col = st.columns([1, 2]) # Adjust column ratios as needed
with left_col:
st.header("🔧 Configuration")
# --- Dynamic Configuration Based on Selected Model ---
def get_model_configuration(model_name):
if model_name == "Prithvi":
st.subheader("Prithvi Model Configuration")
# Prithvi-specific configuration inputs
param1 = st.number_input("Prithvi Parameter 1", value=10, step=1)
param2 = st.text_input("Prithvi Parameter 2", value="default_prithvi")
# Add other Prithvi-specific parameters here
config = {
"param1": param1,
"param2": param2,
# Include other parameters as needed
}
# --- Prithvi-Specific File Uploads ---
st.markdown("### Upload Data Files for Prithvi Model")
# File uploader for surface data
uploaded_surface_files = st.file_uploader(
"Upload Surface Data Files",
type=["nc", "netcdf"],
accept_multiple_files=True,
key="surface_uploader",
)
# File uploader for vertical data
uploaded_vertical_files = st.file_uploader(
"Upload Vertical Data Files",
type=["nc", "netcdf"],
accept_multiple_files=True,
key="vertical_uploader",
)
# Handle Climatology Files
st.markdown("### Upload Climatology Files (If Missing)")
# Climatology files paths
default_clim_dir = Path("Prithvi-WxC/examples/climatology")
surf_in_scal_path = default_clim_dir / "musigma_surface.nc"
vert_in_scal_path = default_clim_dir / "musigma_vertical.nc"
surf_out_scal_path = default_clim_dir / "anomaly_variance_surface.nc"
vert_out_scal_path = default_clim_dir / "anomaly_variance_vertical.nc"
# Check if climatology files exist
clim_files_exist = all(
[
surf_in_scal_path.exists(),
vert_in_scal_path.exists(),
surf_out_scal_path.exists(),
vert_out_scal_path.exists(),
]
)
if not clim_files_exist:
st.warning("Climatology files are missing.")
uploaded_clim_surface = st.file_uploader(
"Upload Climatology Surface File",
type=["nc", "netcdf"],
key="clim_surface_uploader",
)
uploaded_clim_vertical = st.file_uploader(
"Upload Climatology Vertical File",
type=["nc", "netcdf"],
key="clim_vertical_uploader",
)
# Process uploaded climatology files
if uploaded_clim_surface and uploaded_clim_vertical:
clim_temp_dir = tempfile.mkdtemp()
clim_surf_path = Path(clim_temp_dir) / uploaded_clim_surface.name
with open(clim_surf_path, "wb") as f:
f.write(uploaded_clim_surface.getbuffer())
clim_vert_path = Path(clim_temp_dir) / uploaded_clim_vertical.name
with open(clim_vert_path, "wb") as f:
f.write(uploaded_clim_vertical.getbuffer())
st.success("Climatology files uploaded and saved.")
else:
st.warning("Please upload both climatology surface and vertical files.")
else:
clim_surf_path = surf_in_scal_path
clim_vert_path = vert_in_scal_path
# Optional: Upload config.yaml
uploaded_config = st.file_uploader(
"Upload config.yaml",
type=["yaml", "yml"],
key="config_uploader",
)
if uploaded_config:
temp_config = tempfile.mktemp(suffix=".yaml")
with open(temp_config, "wb") as f:
f.write(uploaded_config.getbuffer())
config_path = Path(temp_config)
st.success("Config.yaml uploaded and saved.")
else:
# Use default config.yaml path
config_path = Path("Prithvi-WxC/examples/config.yaml")
if not config_path.exists():
st.error("Default config.yaml not found. Please upload a config file.")
st.stop()
# Optional: Upload model weights
uploaded_weights = st.file_uploader(
"Upload Model Weights (.pt)",
type=["pt"],
key="weights_uploader",
)
if uploaded_weights:
temp_weights = tempfile.mktemp(suffix=".pt")
with open(temp_weights, "wb") as f:
f.write(uploaded_weights.getbuffer())
weights_path = Path(temp_weights)
st.success("Model weights uploaded and saved.")
else:
# Use default weights path
weights_path = Path("Prithvi-WxC/examples/weights/prithvi.wxc.2300m.v1.pt")
if not weights_path.exists():
st.error("Default model weights not found. Please upload model weights.")
st.stop()
return config, uploaded_surface_files, uploaded_vertical_files, clim_surf_path, clim_vert_path, config_path, weights_path
else:
# For other models, provide a simple file uploader
st.subheader(f"{model_name} Model Data Upload")
st.markdown("### Drag and Drop Your Data Files Here")
uploaded_files = st.file_uploader(
f"Upload Data Files for {model_name}",
accept_multiple_files=True,
key=f"{model_name.lower()}_uploader",
type=["nc", "netcdf", "nc4"],
)
return uploaded_files
# Retrieve model-specific configuration and files
if selected_model == "Prithvi":
config, uploaded_surface_files, uploaded_vertical_files, clim_surf_path, clim_vert_path, config_path, weights_path = get_model_configuration(selected_model)
else:
uploaded_files = get_model_configuration(selected_model)
st.write("---") # Horizontal separator
# --- Run Inference Button ---
if st.button("🚀 Run Inference"):
with right_col:
st.header("📈 Inference Progress & Visualization")
# Initialize device
try:
torch.jit.enable_onednn_fusion(True)
if torch.cuda.is_available():
device = torch.device("cuda")
st.write(f"Using device: **{torch.cuda.get_device_name()}**")
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True
else:
device = torch.device("cpu")
st.write("Using device: **CPU**")
except Exception as e:
st.error("Error initializing device:")
st.error(traceback.format_exc())
st.stop()
# Set random seeds
try:
random.seed(42)
if torch.cuda.is_available():
torch.cuda.manual_seed(42)
torch.manual_seed(42)
np.random.seed(42)
except Exception as e:
st.error("Error setting random seeds:")
st.error(traceback.format_exc())
st.stop()
# # Define variables and parameters based on dataset type
# if dataset_type == "MERRA2":
# surface_vars = [
# "EFLUX",
# "GWETROOT",
# "HFLUX",
# "LAI",
# "LWGAB",
# "LWGEM",
# "LWTUP",
# "PS",
# "QV2M",
# "SLP",
# "SWGNT",
# "SWTNT",
# "T2M",
# "TQI",
# "TQL",
# "TQV",
# "TS",
# "U10M",
# "V10M",
# "Z0M",
# ]
# static_surface_vars = ["FRACI", "FRLAND", "FROCEAN", "PHIS"]
# vertical_vars = ["CLOUD", "H", "OMEGA", "PL", "QI", "QL", "QV", "T", "U", "V"]
# levels = [
# 34.0,
# 39.0,
# 41.0,
# 43.0,
# 44.0,
# 45.0,
# 48.0,
# 51.0,
# 53.0,
# 56.0,
# 63.0,
# 68.0,
# 71.0,
# 72.0,
# ]
# elif dataset_type == "GEOS5":
# # Define GEOS5 specific variables
# surface_vars = [
# "GEOS5_EFLUX",
# "GEOS5_GWETROOT",
# "GEOS5_HFLUX",
# "GEOS5_LAI",
# "GEOS5_LWGAB",
# "GEOS5_LWGEM",
# "GEOS5_LWTUP",
# "GEOS5_PS",
# "GEOS5_QV2M",
# "GEOS5_SLP",
# "GEOS5_SWGNT",
# "GEOS5_SWTNT",
# "GEOS5_T2M",
# "GEOS5_TQI",
# "GEOS5_TQL",
# "GEOS5_TQV",
# "GEOS5_TS",
# "GEOS5_U10M",
# "GEOS5_V10M",
# "GEOS5_Z0M",
# ]
# static_surface_vars = ["GEOS5_FRACI", "GEOS5_FRLAND", "GEOS5_FROCEAN", "GEOS5_PHIS"]
# vertical_vars = ["GEOS5_CLOUD", "GEOS5_H", "GEOS5_OMEGA", "GEOS5_PL", "GEOS5_QI", "GEOS5_QL", "GEOS5_QV", "GEOS5_T", "GEOS5_U", "GEOS5_V"]
# levels = [
# # Define levels specific to GEOS5 if different
# 10.0,
# 20.0,
# 30.0,
# 40.0,
# 50.0,
# 60.0,
# 70.0,
# 80.0,
# ]
# else:
# st.error("Unsupported dataset type selected.")
# st.stop()
padding = {"level": [0, 0], "lat": [0, -1], "lon": [0, 0]}
residual = "climate"
masking_mode = "local"
decoder_shifting = True
masking_ratio = 0.99
positional_encoding = "fourier"
# --- Initialize Dataset ---
try:
with st.spinner("Initializing dataset..."):
if selected_model == "Prithvi":
pass
# # Validate climatology files
# if not clim_files_exist and not (uploaded_clim_surface and uploaded_clim_vertical):
# st.error("Climatology files are missing. Please upload both climatology surface and vertical files.")
# st.stop()
# dataset = Merra2Dataset(
# time_range=time_range,
# lead_times=lead_times,
# input_times=input_times,
# data_path_surface=surf_dir,
# data_path_vertical=vert_dir,
# climatology_path_surface=clim_surf_path,
# climatology_path_vertical=clim_vert_path,
# surface_vars=surface_vars,
# static_surface_vars=static_surface_vars,
# vertical_vars=vertical_vars,
# levels=levels,
# positional_encoding=positional_encoding,
# )
# assert len(dataset) > 0, "There doesn't seem to be any valid data."
elif selected_model == "Aurora":
# TODO just temporary, replace this
if uploaded_files:
temp_file_paths = [] # List to store paths of temporary files
try:
# Save each uploaded file to a temporary file
save_uploaded_files(uploaded_files)
ds = load_dataset(st.session_state.temp_file_paths)
# Now, use xarray to open the multiple files
if ds:
st.success("Files successfully loaded!")
st.session_state.ds_subset = ds
# print(ds)
ds = ds.fillna(ds.mean())
desired_levels = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]
# Ensure that the 'lev' dimension exists
if 'lev' not in ds.dims:
raise ValueError("The dataset does not contain a 'lev' (pressure level) dimension.")
# Define the _prepare function
def _prepare(x: np.ndarray, i: int) -> torch.Tensor:
# Select previous and current time steps
selected = x[[i - 6, i]]
# Add a batch dimension
selected = selected[None]
# Ensure data is contiguous
selected = selected.copy()
# Convert to PyTorch tensor
return torch.from_numpy(selected)
# Adjust latitudes and longitudes
lat = ds.lat.values * -1
lon = ds.lon.values + 180
# Subset the dataset to only include the desired pressure levels
ds_subset = ds.sel(lev=desired_levels, method="nearest")
# Verify that all desired levels are present
present_levels = ds_subset.lev.values
missing_levels = set(desired_levels) - set(present_levels)
if missing_levels:
raise ValueError(f"The following desired pressure levels are missing in the dataset: {missing_levels}")
# Extract pressure levels after subsetting
lev = ds_subset.lev.values # Pressure levels in hPa
# Prepare surface variables at 1000 hPa
try:
lev_index_1000 = np.where(lev == 1000)[0][0]
except IndexError:
raise ValueError("1000 hPa level not found in the 'lev' dimension after subsetting.")
T_surface = ds_subset.T.isel(lev=lev_index_1000).compute()
U_surface = ds_subset.U.isel(lev=lev_index_1000).compute()
V_surface = ds_subset.V.isel(lev=lev_index_1000).compute()
SLP = ds_subset.SLP.compute()
# Reorder static variables (selecting the first time index to remove the time dimension)
PHIS = ds_subset.PHIS.isel(time=0).compute()
# Prepare atmospheric variables for the desired pressure levels excluding 1000 hPa
atmos_levels = [int(level) for level in lev if level != 1000]
T_atm = (ds_subset.T.sel(lev=atmos_levels)).compute()
U_atm = (ds_subset.U.sel(lev=atmos_levels)).compute()
V_atm = (ds_subset.V.sel(lev=atmos_levels)).compute()
# Select time index
num_times = ds_subset.time.size
i = 6 # Adjust as needed (1 <= i < num_times)
if i >= num_times or i < 1:
raise IndexError("Time index i is out of bounds.")
time_values = ds_subset.time.values
current_time = np.datetime64(time_values[i]).astype('datetime64[s]').astype(datetime)
# Prepare surface variables
surf_vars = {
"2t": _prepare(T_surface.values, i), # Two-meter temperature
"10u": _prepare(U_surface.values, i), # Ten-meter eastward wind
"10v": _prepare(V_surface.values, i), # Ten-meter northward wind
"msl": _prepare(SLP.values, i), # Mean sea-level pressure
}
# Prepare static variables (now 2D tensors)
static_vars = {
"z": torch.from_numpy(PHIS.values.copy()), # Geopotential (h, w)
# Add 'lsm' and 'slt' if available and needed
}
# Prepare atmospheric variables
atmos_vars = {
"t": _prepare(T_atm.values, i), # Temperature at desired levels
"u": _prepare(U_atm.values, i), # Eastward wind at desired levels
"v": _prepare(V_atm.values, i), # Southward wind at desired levels
}
# Define metadata
metadata = Metadata(
lat=torch.from_numpy(lat.copy()),
lon=torch.from_numpy(lon.copy()),
time=(current_time,),
atmos_levels=tuple(atmos_levels), # Only the desired atmospheric levels
)
# Create the Batch object
batch = Batch(
surf_vars=surf_vars,
static_vars=static_vars,
atmos_vars=atmos_vars,
metadata=metadata
) # Display the dataset or perform further processing
st.session_state['batch'] = batch
except Exception as e:
st.error(f"An error occurred: {e}")
# finally:
# # Clean up: Remove temporary files
# for path in temp_file_paths:
# try:
# os.remove(path)
# except Exception as e:
# st.warning(f"Could not delete temp file {path}: {e}")
else:
# For other models, implement their specific dataset initialization
# Placeholder: Replace with actual dataset initialization for other models
dataset = None # Replace with actual dataset
st.warning("Dataset initialization for this model is not implemented yet.")
st.stop()
st.success("Dataset initialized successfully.")
except Exception as e:
st.error("Error initializing dataset:")
st.error(traceback.format_exc())
st.stop()
# --- Load Scalers ---
try:
with st.spinner("Loading scalers..."):
if selected_model == "Prithvi":
pass
# # Assuming the scaler paths are the same as climatology paths
# surf_in_scal_path = clim_surf_path
# vert_in_scal_path = clim_vert_path
# surf_out_scal_path = Path(clim_surf_path.parent) / "anomaly_variance_surface.nc"
# vert_out_scal_path = Path(clim_vert_path.parent) / "anomaly_variance_vertical.nc"
# # Check if output scaler files exist
# if not surf_out_scal_path.exists() or not vert_out_scal_path.exists():
# st.error("Anomaly variance scaler files are missing.")
# st.stop()
# in_mu, in_sig = input_scalers(
# surface_vars,
# vertical_vars,
# levels,
# surf_in_scal_path,
# vert_in_scal_path,
# )
# output_sig = output_scalers(
# surface_vars,
# vertical_vars,
# levels,
# surf_out_scal_path,
# vert_out_scal_path,
# )
# static_mu, static_sig = static_input_scalers(
# surf_in_scal_path,
# static_surface_vars,
# )
else:
# Load scalers for other models if applicable
# Placeholder: Replace with actual scaler loading for other models
in_mu, in_sig = None, None
output_sig = None
static_mu, static_sig = None, None
st.success("Scalers loaded successfully.")
except Exception as e:
st.error("Error loading scalers:")
st.error(traceback.format_exc())
st.stop()
# --- Load Configuration ---
try:
with st.spinner("Loading configuration..."):
if selected_model == "Prithvi":
with open(config_path, "r") as f:
config = yaml.safe_load(f)
# Validate config
required_params = [
"in_channels", "input_size_time", "in_channels_static",
"input_scalers_epsilon", "static_input_scalers_epsilon",
"n_lats_px", "n_lons_px", "patch_size_px",
"mask_unit_size_px", "embed_dim", "n_blocks_encoder",
"n_blocks_decoder", "mlp_multiplier", "n_heads",
"dropout", "drop_path", "parameter_dropout"
]
missing_params = [param for param in required_params if param not in config.get("params", {})]
if missing_params:
st.error(f"Missing configuration parameters: {missing_params}")
st.stop()
else:
# Load configuration for other models if applicable
# Placeholder: Replace with actual configuration loading for other models
config = {}
st.success("Configuration loaded successfully.")
except Exception as e:
st.error("Error loading configuration:")
st.error(traceback.format_exc())
st.stop()
# --- Initialize the Model ---
try:
with st.spinner("Initializing model..."):
if selected_model == "Prithvi":
model = PrithviWxC(
in_channels=config["params"]["in_channels"],
input_size_time=config["params"]["input_size_time"],
in_channels_static=config["params"]["in_channels_static"],
input_scalers_mu=in_mu,
input_scalers_sigma=in_sig,
input_scalers_epsilon=config["params"]["input_scalers_epsilon"],
static_input_scalers_mu=static_mu,
static_input_scalers_sigma=static_sig,
static_input_scalers_epsilon=config["params"]["static_input_scalers_epsilon"],
output_scalers=output_sig**0.5,
n_lats_px=config["params"]["n_lats_px"],
n_lons_px=config["params"]["n_lons_px"],
patch_size_px=config["params"]["patch_size_px"],
mask_unit_size_px=config["params"]["mask_unit_size_px"],
mask_ratio_inputs=masking_ratio,
embed_dim=config["params"]["embed_dim"],
n_blocks_encoder=config["params"]["n_blocks_encoder"],
n_blocks_decoder=config["params"]["n_blocks_decoder"],
mlp_multiplier=config["params"]["mlp_multiplier"],
n_heads=config["params"]["n_heads"],
dropout=config["params"]["dropout"],
drop_path=config["params"]["drop_path"],
parameter_dropout=config["params"]["parameter_dropout"],
residual=residual,
masking_mode=masking_mode,
decoder_shifting=decoder_shifting,
positional_encoding=positional_encoding,
checkpoint_encoder=[],
checkpoint_decoder=[],
)
elif selected_model == "Aurora":
pass
else:
# Initialize other models here
# Placeholder: Replace with actual model initialization for other models
model = None
st.warning("Model initialization for this model is not implemented yet.")
st.stop()
# model.to(device)
st.success("Model initialized successfully.")
except Exception as e:
st.error("Error initializing model:")
st.error(traceback.format_exc())
st.stop()
# --- Load Model Weights ---
try:
with st.spinner("Loading model weights..."):
if selected_model == "Prithvi":
state_dict = torch.load(weights_path, map_location=device)
if "model_state" in state_dict:
state_dict = state_dict["model_state"]
model.load_state_dict(state_dict, strict=True)
model.to(device)
else:
# Load weights for other models if applicable
# Placeholder: Replace with actual weight loading for other models
pass
st.success("Model weights loaded successfully.")
except Exception as e:
st.error("Error loading model weights:")
st.error(traceback.format_exc())
st.stop()
# --- Prepare Data Batch ---
try:
with st.spinner("Preparing data batch..."):
if selected_model == "Prithvi":
data = next(iter(dataset))
batch = preproc([data], padding)
for k, v in batch.items():
if isinstance(v, torch.Tensor):
batch[k] = v.to(device)
elif selected_model == "Aurora":
batch = batch.regrid(res=0.25)
else:
# Prepare data batch for other models
# Placeholder: Replace with actual data preparation for other models
batch = None
st.success("Data batch prepared successfully.")
except Exception as e:
st.error("Error preparing data batch:")
st.error(traceback.format_exc())
st.stop()
# --- Run Inference ---
try:
with st.spinner("Running model inference..."):
if selected_model == "Prithvi":
model.eval()
with torch.no_grad():
out = model(batch)
elif selected_model == "Aurora":
model = Aurora(use_lora=False)
# model = Aurora()
model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")
# model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")
model.eval()
# model = model.to("cuda") # Uncomment if using a GPU
with torch.inference_mode():
out = [pred.to("cpu") for pred in rollout(model, batch, steps=2)]
model = model.to("cpu")
st.session_state.model = model
else:
# Run inference for other models
# Placeholder: Replace with actual inference code for other models
out = torch.randn(1, 10, 180, 360) # Dummy tensor
st.success("Model inference completed successfully.")
st.session_state['out'] = out
except Exception as e:
st.error("Error during model inference:")
st.error(traceback.format_exc())
st.stop()
# --- Visualization Settings ---
st.markdown("## 📊 Visualization Settings")
if 'out' in st.session_state and 'batch' in st.session_state and selected_model == "Prithvi":
# Display the shape of the output tensor
out_tensor = st.session_state['out']
st.write(f"**Output tensor shape:** {out_tensor.shape}")
# Ensure the output tensor has at least 4 dimensions (batch, variables, lat, lon)
if out_tensor.ndim < 4:
st.error("The output tensor does not have the expected number of dimensions (batch, variables, lat, lon).")
st.stop()
# Get the number of variables
num_variables = out_tensor.shape[1]
# Define variable names (update with your actual variable names)
variable_names = [f"Variable_{i}" for i in range(num_variables)]
# Visualization settings
col1, col2 = st.columns(2)
with col1:
# Select variable to plot
selected_variable_name = st.selectbox(
"Select Variable to Plot",
options=variable_names,
index=0,
help="Choose the variable you want to visualize."
)
# Select plot type
plot_type = st.selectbox(
"Select Plot Type",
options=["Contour", "Heatmap"],
index=0,
help="Choose the type of plot to display."
)
with col2:
# Select color map
cmap = st.selectbox(
"Select Color Map",
options=plt.colormaps(),
index=plt.colormaps().index("viridis"),
help="Choose the color map for the plot."
)
# Set number of levels (for contour plot)
if plot_type == "Contour":
num_levels = st.slider(
"Number of Contour Levels",
min_value=5,
max_value=100,
value=20,
step=5,
help="Set the number of contour levels."
)
else:
num_levels = None
# Find the index based on the selected name
variable_index = variable_names.index(selected_variable_name)
# Extract the selected variable
selected_variable = out_tensor[0, variable_index].cpu().numpy()
# Generate latitude and longitude arrays
lat = np.linspace(-90, 90, selected_variable.shape[0])
lon = np.linspace(-180, 180, selected_variable.shape[1])
X, Y = np.meshgrid(lon, lat)
# Plot the selected variable
st.markdown(f"### Plot of {selected_variable_name}")
# Matplotlib figure
fig, ax = plt.subplots(figsize=(10, 6))
if plot_type == "Contour":
# Generate the contour plot
contour = ax.contourf(X, Y, selected_variable, levels=num_levels, cmap=cmap)
elif plot_type == "Heatmap":
# Generate the heatmap
contour = ax.imshow(selected_variable, extent=[-180, 180, -90, 90], cmap=cmap, origin='lower', aspect='auto')
# Add a color bar
cbar = plt.colorbar(contour, ax=ax)
cbar.set_label(f'{selected_variable_name}', fontsize=12)
# Set aspect ratio and labels
ax.set_xlabel("Longitude", fontsize=12)
ax.set_ylabel("Latitude", fontsize=12)
ax.set_title(f"{selected_variable_name}", fontsize=14)
# Display the plot in Streamlit
st.pyplot(fig)
# Optional: Provide interactive Plotly plot
st.markdown("#### Interactive Plot")
if plot_type == "Contour":
fig_plotly = go.Figure(data=go.Contour(
z=selected_variable,
x=lon,
y=lat,
colorscale=cmap,
contours=dict(
coloring='fill',
showlabels=True,
labelfont=dict(size=12, color='white'),
ncontours=num_levels
)
))
elif plot_type == "Heatmap":
fig_plotly = go.Figure(data=go.Heatmap(
z=selected_variable,
x=lon,
y=lat,
colorscale=cmap
))
fig_plotly.update_layout(
xaxis_title="Longitude",
yaxis_title="Latitude",
autosize=False,
width=800,
height=600,
)
st.plotly_chart(fig_plotly)
elif 'out' in st.session_state and selected_model == "Aurora" and st.session_state['out'] is not None:
preds = st.session_state['out']
ds_subset = st.session_state.get('ds_subset', None)
batch = st.session_state.get('batch', None)
# **Determine Available Levels**
# For example, let's assume levels range from 0 to max_level_index
# You need to replace 'max_level_index' with the actual maximum level index in your data
try:
# Assuming 'lev' dimension exists and is 1D
levels = preds[0].atmos_vars["t"].shape[2] # Adjust based on your data structure
level_indices = list(range(levels))
except Exception as e:
st.error("Error determining available levels:")
st.error(traceback.format_exc())
levels = None # Set to None if levels cannot be determined
if levels is not None:
# **Add a Slider for Level Selection**
selected_level = st.slider(
'Select Level',
min_value=0,
max_value=levels - 1,
value=11, # Default level index
step=1,
help="Select the vertical level for plotting."
)
# Loop through predictions and ground truths
for idx in range(len(preds)):
pred = preds[idx]
pred_time = pred.metadata.time[0]
# Display prediction time
st.write(f"### Prediction Time: {pred_time}")
# **Extract Data at Selected Level**
try:
# Update indices with the selected_level
pred_data = pred.atmos_vars["t"][0][0][selected_level].numpy() - 273.15
truth_data = ds_subset.T.isel(lev=selected_level)[idx].values - 273.15
except Exception as e:
st.error("Error extracting data for plotting:")
st.error(traceback.format_exc())
continue
# Extract latitude and longitude
try:
lat = np.array(pred.metadata.lat) # Assuming 'lat' is 1D
lon = np.array(pred.metadata.lon) # Assuming 'lon' is 1D
except Exception as e:
st.error("Error extracting latitude and longitude:")
st.error(traceback.format_exc())
continue
# Create a meshgrid for plotting
lon_grid, lat_grid = np.meshgrid(lon, lat)
# Create a Matplotlib figure with Cartopy projection
fig, axes = plt.subplots(
1, 3, figsize=(18, 6),
subplot_kw={'projection': ccrs.PlateCarree()}
)
# **Ground Truth Plot**
im1 = axes[0].imshow(
truth_data,
extent=[lon.min(), lon.max(), lat.min(), lat.max()],
origin='lower',
cmap='coolwarm',
transform=ccrs.PlateCarree()
)
axes[0].set_title(f"Ground Truth at Level {selected_level} - {pred_time}")
axes[0].set_xlabel('Longitude')
axes[0].set_ylabel('Latitude')
plt.colorbar(im1, ax=axes[0], orientation='horizontal', pad=0.05)
# **Prediction Plot**
im2 = axes[1].imshow(
pred_data,
extent=[lon.min(), lon.max(), lat.min(), lat.max()],
origin='lower',
cmap='coolwarm',
transform=ccrs.PlateCarree()
)
axes[1].set_title(f"Prediction at Level {selected_level} - {pred_time}")
axes[1].set_xlabel('Longitude')
axes[1].set_ylabel('Latitude')
plt.colorbar(im2, ax=axes[1], orientation='horizontal', pad=0.05)
plt.tight_layout()
# Display the plot in Streamlit
st.pyplot(fig)
else:
st.error("Could not determine the available levels in the data.")
else:
st.warning("No output available to display or visualization is not implemented for this model.")
# --- End of Inference Button ---
else:
with right_col:
st.header("🖥️ Visualization & Progress")
st.info("Awaiting inference to display results.")