"""Streamlit app for uploading NetCDF weather data, running a foundation
model (Prithvi-WxC or Aurora), and visualizing the predictions."""

import logging
import os
import random
import tempfile
import traceback
from datetime import datetime
from pathlib import Path

import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import streamlit as st
import torch
import xarray as xr
import yaml

from aurora import Aurora, Batch, Metadata, rollout
from Prithvi import *  # provides PrithviWxC and preproc, among others

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def save_uploaded_files(uploaded_files):
    """Persist Streamlit uploads to temp files and record their paths in session state."""
    if "temp_file_paths" not in st.session_state:
        st.session_state.temp_file_paths = []
    for uploaded_file in uploaded_files:
        suffix = os.path.splitext(uploaded_file.name)[1]
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
        temp_file.write(uploaded_file.read())
        temp_file.close()
        st.session_state.temp_file_paths.append(temp_file.name)

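# The temp files above outlive the Streamlit rerun; long-running deployments
# may want to os.remove() the recorded paths once inference has finished.
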
@st.cache_resource
def load_dataset(file_paths):
    """Open and merge the uploaded NetCDF files into one dataset (cached across reruns)."""
    try:
        ds = xr.open_mfdataset(file_paths, combine="by_coords").load()
        return ds
    except Exception:
        st.error("Error loading dataset:")
        st.error(traceback.format_exc())
        return None

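# st.cache_resource keys on the file-path list, so re-uploading the same files
# (which lands at new temp paths) forces a fresh load instead of a cache hit.
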
st.set_page_config(
    page_title="Weather Data Processor",
    layout="wide",
    initial_sidebar_state="expanded",
)

header_col1, header_col2 = st.columns([4, 1])

with header_col1:
    st.title("🌦️ Weather & Climate Data Processor and Forecaster")

with header_col2:
    st.markdown("### Select a Model")
    selected_model = st.selectbox(
        "Model",
        options=["Aurora", "Climax", "Prithvi", "LSTM"],
        index=0,
        key="model_selector",
        label_visibility="collapsed",
        help="Select the model you want to use for processing the data.",
    )

st.write("---")

left_col, right_col = st.columns([1, 2])

with left_col:
    st.header("🔧 Configuration")

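# get_model_configuration renders its widgets into whichever Streamlit
# container is active at call time; it is called from the left column below.
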
def get_model_configuration(model_name):
    if model_name == "Prithvi":
        st.subheader("Prithvi Model Configuration")

        param1 = st.number_input("Prithvi Parameter 1", value=10, step=1)
        param2 = st.text_input("Prithvi Parameter 2", value="default_prithvi")

        config = {
            "param1": param1,
            "param2": param2,
        }

        st.markdown("### Upload Data Files for Prithvi Model")

        uploaded_surface_files = st.file_uploader(
            "Upload Surface Data Files",
            type=["nc", "netcdf"],
            accept_multiple_files=True,
            key="surface_uploader",
        )

        uploaded_vertical_files = st.file_uploader(
            "Upload Vertical Data Files",
            type=["nc", "netcdf"],
            accept_multiple_files=True,
            key="vertical_uploader",
        )

        st.markdown("### Upload Climatology Files (If Missing)")

        # Default climatology/scaler locations from the Prithvi-WxC examples.
        default_clim_dir = Path("Prithvi-WxC/examples/climatology")
        surf_in_scal_path = default_clim_dir / "musigma_surface.nc"
        vert_in_scal_path = default_clim_dir / "musigma_vertical.nc"
        surf_out_scal_path = default_clim_dir / "anomaly_variance_surface.nc"
        vert_out_scal_path = default_clim_dir / "anomaly_variance_vertical.nc"

        clim_files_exist = all(
            [
                surf_in_scal_path.exists(),
                vert_in_scal_path.exists(),
                surf_out_scal_path.exists(),
                vert_out_scal_path.exists(),
            ]
        )

        if not clim_files_exist:
            st.warning("Climatology files are missing.")
            # Initialize so the return below cannot hit a NameError when the
            # user has not uploaded both files yet.
            clim_surf_path = None
            clim_vert_path = None
            uploaded_clim_surface = st.file_uploader(
                "Upload Climatology Surface File",
                type=["nc", "netcdf"],
                key="clim_surface_uploader",
            )
            uploaded_clim_vertical = st.file_uploader(
                "Upload Climatology Vertical File",
                type=["nc", "netcdf"],
                key="clim_vertical_uploader",
            )

            if uploaded_clim_surface and uploaded_clim_vertical:
                clim_temp_dir = tempfile.mkdtemp()
                clim_surf_path = Path(clim_temp_dir) / uploaded_clim_surface.name
                with open(clim_surf_path, "wb") as f:
                    f.write(uploaded_clim_surface.getbuffer())
                clim_vert_path = Path(clim_temp_dir) / uploaded_clim_vertical.name
                with open(clim_vert_path, "wb") as f:
                    f.write(uploaded_clim_vertical.getbuffer())
                st.success("Climatology files uploaded and saved.")
            else:
                st.warning("Please upload both climatology surface and vertical files.")
        else:
            clim_surf_path = surf_in_scal_path
            clim_vert_path = vert_in_scal_path

        uploaded_config = st.file_uploader(
            "Upload config.yaml",
            type=["yaml", "yml"],
            key="config_uploader",
        )

        if uploaded_config:
            temp_config = tempfile.NamedTemporaryFile(delete=False, suffix=".yaml")
            temp_config.write(uploaded_config.getbuffer())
            temp_config.close()
            config_path = Path(temp_config.name)
            st.success("config.yaml uploaded and saved.")
        else:
            config_path = Path("Prithvi-WxC/examples/config.yaml")
            if not config_path.exists():
                st.error("Default config.yaml not found. Please upload a config file.")
                st.stop()

        uploaded_weights = st.file_uploader(
            "Upload Model Weights (.pt)",
            type=["pt"],
            key="weights_uploader",
        )

        if uploaded_weights:
            temp_weights = tempfile.NamedTemporaryFile(delete=False, suffix=".pt")
            temp_weights.write(uploaded_weights.getbuffer())
            temp_weights.close()
            weights_path = Path(temp_weights.name)
            st.success("Model weights uploaded and saved.")
        else:
            weights_path = Path("Prithvi-WxC/examples/weights/prithvi.wxc.2300m.v1.pt")
            if not weights_path.exists():
                st.error("Default model weights not found. Please upload model weights.")
                st.stop()

        return (
            config,
            uploaded_surface_files,
            uploaded_vertical_files,
            clim_surf_path,
            clim_vert_path,
            config_path,
            weights_path,
        )

    else:
        st.subheader(f"{model_name} Model Data Upload")
        st.markdown("### Drag and Drop Your Data Files Here")
        uploaded_files = st.file_uploader(
            f"Upload Data Files for {model_name}",
            accept_multiple_files=True,
            key=f"{model_name.lower()}_uploader",
            type=["nc", "netcdf", "nc4"],
        )
        return uploaded_files

with left_col:
    if selected_model == "Prithvi":
        (
            config,
            uploaded_surface_files,
            uploaded_vertical_files,
            clim_surf_path,
            clim_vert_path,
            config_path,
            weights_path,
        ) = get_model_configuration(selected_model)
    else:
        uploaded_files = get_model_configuration(selected_model)

st.write("---")

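# Note: st.button is True only on the rerun caused by the click itself, so
# interacting with any widget rendered below triggers a rerun in which this
# block is skipped; the session_state entries set inside it persist regardless.
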
if st.button("🚀 Run Inference"):
    with right_col:
        st.header("📈 Inference Progress & Visualization")

        # Select the compute device; cuDNN benchmark favors speed while
        # deterministic favors reproducibility.
        try:
            torch.jit.enable_onednn_fusion(True)
            if torch.cuda.is_available():
                device = torch.device("cuda")
                st.write(f"Using device: **{torch.cuda.get_device_name()}**")
                torch.backends.cudnn.benchmark = True
                torch.backends.cudnn.deterministic = True
            else:
                device = torch.device("cpu")
                st.write("Using device: **CPU**")
        except Exception:
            st.error("Error initializing device:")
            st.error(traceback.format_exc())
            st.stop()

        # Seed all RNGs for reproducible runs.
        try:
            random.seed(42)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(42)
            torch.manual_seed(42)
            np.random.seed(42)
        except Exception:
            st.error("Error setting random seeds:")
            st.error(traceback.format_exc())
            st.stop()

        # Prithvi-WxC run settings.
        padding = {"level": [0, 0], "lat": [0, -1], "lon": [0, 0]}
        residual = "climate"
        masking_mode = "local"
        decoder_shifting = True
        masking_ratio = 0.99
        positional_encoding = "fourier"

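        # These mirror the Prithvi-WxC example settings: residual="climate"
        # has the model predict anomalies relative to climatology, and
        # masking_ratio sets the fraction of input tokens the encoder masks.
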
        # Dataset initialization.
        try:
            with st.spinner("Initializing dataset..."):
                if selected_model == "Prithvi":
                    # TODO: build the Prithvi dataset from the uploaded
                    # surface/vertical files and climatology paths.
                    pass
                elif selected_model == "Aurora":
                    if uploaded_files:
                        try:
                            save_uploaded_files(uploaded_files)
                            ds = load_dataset(st.session_state.temp_file_paths)

                            if ds is not None:
                                st.success("Files successfully loaded!")
                                st.session_state.ds_subset = ds

                                # Fill missing values with per-variable means.
                                ds = ds.fillna(ds.mean())

                                # The 13 pressure levels (hPa) used by the pretrained Aurora.
                                desired_levels = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]

                                if "lev" not in ds.dims:
                                    raise ValueError("The dataset does not contain a 'lev' (pressure level) dimension.")

                                def _prepare(x: np.ndarray, i: int) -> torch.Tensor:
                                    """Stack the two input time steps Aurora expects (indices i-6 and i)
                                    and add a leading batch dimension."""
                                    selected = x[[i - 6, i]]
                                    selected = selected[None]
                                    selected = selected.copy()  # ensure a contiguous, writable array
                                    return torch.from_numpy(selected)

                                # Match Aurora's grid convention (lat decreasing 90 -> -90,
                                # lon in [0, 360)), assuming the input grid has ascending
                                # latitudes and longitudes in [-180, 180).
                                lat = ds.lat.values * -1
                                lon = ds.lon.values + 180

                                ds_subset = ds.sel(lev=desired_levels, method="nearest")

                                present_levels = ds_subset.lev.values
                                missing_levels = set(desired_levels) - set(present_levels)
                                if missing_levels:
                                    raise ValueError(
                                        f"The following desired pressure levels are missing in the dataset: {missing_levels}"
                                    )

                                lev = ds_subset.lev.values

                                try:
                                    lev_index_1000 = np.where(lev == 1000)[0][0]
                                except IndexError:
                                    raise ValueError("1000 hPa level not found in the 'lev' dimension after subsetting.")

                                # Use the 1000 hPa fields as proxies for Aurora's surface variables.
                                T_surface = ds_subset.T.isel(lev=lev_index_1000).compute()
                                U_surface = ds_subset.U.isel(lev=lev_index_1000).compute()
                                V_surface = ds_subset.V.isel(lev=lev_index_1000).compute()
                                SLP = ds_subset.SLP.compute()

                                # Static surface geopotential (time-invariant).
                                PHIS = ds_subset.PHIS.isel(time=0).compute()

                                # All remaining levels feed the atmospheric variables.
                                atmos_levels = [int(level) for level in lev if level != 1000]

                                T_atm = ds_subset.T.sel(lev=atmos_levels).compute()
                                U_atm = ds_subset.U.sel(lev=atmos_levels).compute()
                                V_atm = ds_subset.V.sel(lev=atmos_levels).compute()

                                # Index of the analysis time; _prepare also reads
                                # index i - 6, so i must be at least 6.
                                num_times = ds_subset.time.size
                                i = 6

                                if i >= num_times or i < 6:
                                    raise IndexError("Time index i is out of bounds.")

                                time_values = ds_subset.time.values
                                current_time = np.datetime64(time_values[i]).astype("datetime64[s]").astype(datetime)

                                surf_vars = {
                                    "2t": _prepare(T_surface.values, i),
                                    "10u": _prepare(U_surface.values, i),
                                    "10v": _prepare(V_surface.values, i),
                                    "msl": _prepare(SLP.values, i),
                                }

                                # NB: the pretrained Aurora also expects "slt" and "lsm"
                                # static fields; only geopotential is provided here.
                                static_vars = {
                                    "z": torch.from_numpy(PHIS.values.copy()),
                                }

                                atmos_vars = {
                                    "t": _prepare(T_atm.values, i),
                                    "u": _prepare(U_atm.values, i),
                                    "v": _prepare(V_atm.values, i),
                                }

                                metadata = Metadata(
                                    lat=torch.from_numpy(lat.copy()),
                                    lon=torch.from_numpy(lon.copy()),
                                    time=(current_time,),
                                    atmos_levels=tuple(atmos_levels),
                                )

                                batch = Batch(
                                    surf_vars=surf_vars,
                                    static_vars=static_vars,
                                    atmos_vars=atmos_vars,
                                    metadata=metadata,
                                )

                                st.session_state["batch"] = batch

                        except Exception as e:
                            st.error(f"An error occurred: {e}")
                    else:
                        st.warning("Please upload data files for Aurora.")
                        st.stop()
                else:
                    dataset = None
                    st.warning("Dataset initialization for this model is not implemented yet.")
                    st.stop()
            st.success("Dataset initialized successfully.")
        except Exception:
            st.error("Error initializing dataset:")
            st.error(traceback.format_exc())
            st.stop()

        # Scaler loading.
        try:
            with st.spinner("Loading scalers..."):
                if selected_model == "Prithvi":
                    # TODO: load the Prithvi input/output scalers (in_mu, in_sig,
                    # output_sig, static_mu, static_sig) from the climatology files.
                    pass
                else:
                    # Aurora handles normalization internally, so no external
                    # scalers are needed here.
                    in_mu, in_sig = None, None
                    output_sig = None
                    static_mu, static_sig = None, None
            st.success("Scalers loaded successfully.")
        except Exception:
            st.error("Error loading scalers:")
            st.error(traceback.format_exc())
            st.stop()

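        # In the Prithvi-WxC examples, these scalers are derived from the
        # climatology files chosen above: musigma_*.nc supplies the input
        # mean/sigma, anomaly_variance_*.nc the output variances.
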
        # Configuration loading.
        try:
            with st.spinner("Loading configuration..."):
                if selected_model == "Prithvi":
                    with open(config_path, "r") as f:
                        config = yaml.safe_load(f)

                    # Fail fast if the YAML is missing anything the model needs.
                    required_params = [
                        "in_channels", "input_size_time", "in_channels_static",
                        "input_scalers_epsilon", "static_input_scalers_epsilon",
                        "n_lats_px", "n_lons_px", "patch_size_px",
                        "mask_unit_size_px", "embed_dim", "n_blocks_encoder",
                        "n_blocks_decoder", "mlp_multiplier", "n_heads",
                        "dropout", "drop_path", "parameter_dropout",
                    ]
                    missing_params = [param for param in required_params if param not in config.get("params", {})]
                    if missing_params:
                        st.error(f"Missing configuration parameters: {missing_params}")
                        st.stop()
                else:
                    # Aurora ships its configuration with the checkpoint.
                    config = {}
            st.success("Configuration loaded successfully.")
        except Exception:
            st.error("Error loading configuration:")
            st.error(traceback.format_exc())
            st.stop()

        # Model initialization.
        try:
            with st.spinner("Initializing model..."):
                if selected_model == "Prithvi":
                    model = PrithviWxC(
                        in_channels=config["params"]["in_channels"],
                        input_size_time=config["params"]["input_size_time"],
                        in_channels_static=config["params"]["in_channels_static"],
                        input_scalers_mu=in_mu,
                        input_scalers_sigma=in_sig,
                        input_scalers_epsilon=config["params"]["input_scalers_epsilon"],
                        static_input_scalers_mu=static_mu,
                        static_input_scalers_sigma=static_sig,
                        static_input_scalers_epsilon=config["params"]["static_input_scalers_epsilon"],
                        output_scalers=output_sig**0.5,
                        n_lats_px=config["params"]["n_lats_px"],
                        n_lons_px=config["params"]["n_lons_px"],
                        patch_size_px=config["params"]["patch_size_px"],
                        mask_unit_size_px=config["params"]["mask_unit_size_px"],
                        mask_ratio_inputs=masking_ratio,
                        embed_dim=config["params"]["embed_dim"],
                        n_blocks_encoder=config["params"]["n_blocks_encoder"],
                        n_blocks_decoder=config["params"]["n_blocks_decoder"],
                        mlp_multiplier=config["params"]["mlp_multiplier"],
                        n_heads=config["params"]["n_heads"],
                        dropout=config["params"]["dropout"],
                        drop_path=config["params"]["drop_path"],
                        parameter_dropout=config["params"]["parameter_dropout"],
                        residual=residual,
                        masking_mode=masking_mode,
                        decoder_shifting=decoder_shifting,
                        positional_encoding=positional_encoding,
                        checkpoint_encoder=[],
                        checkpoint_decoder=[],
                    )
                elif selected_model == "Aurora":
                    # The Aurora model is constructed at inference time below.
                    pass
                else:
                    model = None
                    st.warning("Model initialization for this model is not implemented yet.")
                    st.stop()
            st.success("Model initialized successfully.")
        except Exception:
            st.error("Error initializing model:")
            st.error(traceback.format_exc())
            st.stop()

        # Weight loading.
        try:
            with st.spinner("Loading model weights..."):
                if selected_model == "Prithvi":
                    state_dict = torch.load(weights_path, map_location=device)
                    # Checkpoints may wrap the weights in a "model_state" entry.
                    if "model_state" in state_dict:
                        state_dict = state_dict["model_state"]
                    model.load_state_dict(state_dict, strict=True)
                    model.to(device)
                else:
                    # Aurora downloads its checkpoint during the inference step.
                    pass
            st.success("Model weights loaded successfully.")
        except Exception:
            st.error("Error loading model weights:")
            st.error(traceback.format_exc())
            st.stop()

        # Batch preparation.
        try:
            with st.spinner("Preparing data batch..."):
                if selected_model == "Prithvi":
                    data = next(iter(dataset))
                    batch = preproc([data], padding)
                    for k, v in batch.items():
                        if isinstance(v, torch.Tensor):
                            batch[k] = v.to(device)
                elif selected_model == "Aurora":
                    # Interpolate the batch onto Aurora's native 0.25 degree grid.
                    batch = batch.regrid(res=0.25)
                else:
                    batch = None
                # Keep the batch around for the visualization section below.
                st.session_state["batch"] = batch
            st.success("Data batch prepared successfully.")
        except Exception:
            st.error("Error preparing data batch:")
            st.error(traceback.format_exc())
            st.stop()

        # Inference.
        try:
            with st.spinner("Running model inference..."):
                if selected_model == "Prithvi":
                    model.eval()
                    with torch.no_grad():
                        out = model(batch)
                elif selected_model == "Aurora":
                    # Build the pretrained 0.25 degree Aurora model and fetch
                    # its checkpoint from the Hugging Face Hub.
                    model = Aurora(use_lora=False)
                    model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")
                    model.eval()

                    # Roll the model forward two steps, moving each prediction
                    # to the CPU to free accelerator memory.
                    with torch.inference_mode():
                        out = [pred.to("cpu") for pred in rollout(model, batch, steps=2)]

                    model = model.to("cpu")
                    st.session_state.model = model
                else:
                    # Placeholder output for models without an implementation.
                    out = torch.randn(1, 10, 180, 360)
            st.success("Model inference completed successfully.")
            st.session_state["out"] = out
        except Exception:
            st.error("Error during model inference:")
            st.error(traceback.format_exc())
            st.stop()

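        # Note: with the 0.25 degree pretrained checkpoint, each Aurora rollout
        # step advances the forecast by six hours, so steps=2 leaves two
        # snapshots (t+6 h and t+12 h) in st.session_state["out"].
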
st.markdown("## 📊 Visualization Settings") |
|
|
|
        if "out" in st.session_state and "batch" in st.session_state and selected_model == "Prithvi":
            out_tensor = st.session_state["out"]
            st.write(f"**Output tensor shape:** {out_tensor.shape}")

            if out_tensor.ndim < 4:
                st.error("The output tensor does not have the expected number of dimensions (batch, variables, lat, lon).")
                st.stop()

            num_variables = out_tensor.shape[1]

            # Generic names; map these to physical variables once known.
            variable_names = [f"Variable_{i}" for i in range(num_variables)]

            col1, col2 = st.columns(2)

            with col1:
                selected_variable_name = st.selectbox(
                    "Select Variable to Plot",
                    options=variable_names,
                    index=0,
                    help="Choose the variable you want to visualize.",
                )

                plot_type = st.selectbox(
                    "Select Plot Type",
                    options=["Contour", "Heatmap"],
                    index=0,
                    help="Choose the type of plot to display.",
                )

            with col2:
                cmap = st.selectbox(
                    "Select Color Map",
                    options=plt.colormaps(),
                    index=plt.colormaps().index("viridis"),
                    help="Choose the color map for the plot.",
                )

                if plot_type == "Contour":
                    num_levels = st.slider(
                        "Number of Contour Levels",
                        min_value=5,
                        max_value=100,
                        value=20,
                        step=5,
                        help="Set the number of contour levels.",
                    )
                else:
                    num_levels = None

            variable_index = variable_names.index(selected_variable_name)
            selected_variable = out_tensor[0, variable_index].cpu().numpy()

            # Assume a regular global grid for the plot axes.
            lat = np.linspace(-90, 90, selected_variable.shape[0])
            lon = np.linspace(-180, 180, selected_variable.shape[1])
            X, Y = np.meshgrid(lon, lat)

            st.markdown(f"### Plot of {selected_variable_name}")

            fig, ax = plt.subplots(figsize=(10, 6))

            if plot_type == "Contour":
                contour = ax.contourf(X, Y, selected_variable, levels=num_levels, cmap=cmap)
            elif plot_type == "Heatmap":
                contour = ax.imshow(
                    selected_variable,
                    extent=[-180, 180, -90, 90],
                    cmap=cmap,
                    origin="lower",
                    aspect="auto",
                )

            cbar = plt.colorbar(contour, ax=ax)
            cbar.set_label(f"{selected_variable_name}", fontsize=12)

            ax.set_xlabel("Longitude", fontsize=12)
            ax.set_ylabel("Latitude", fontsize=12)
            ax.set_title(f"{selected_variable_name}", fontsize=14)

            st.pyplot(fig)

            # Interactive counterpart; note that not every Matplotlib colormap
            # name is a valid Plotly colorscale.
            st.markdown("#### Interactive Plot")
            if plot_type == "Contour":
                fig_plotly = go.Figure(data=go.Contour(
                    z=selected_variable,
                    x=lon,
                    y=lat,
                    colorscale=cmap,
                    # ncontours is a property of the trace, not the contours dict.
                    ncontours=num_levels,
                    contours=dict(
                        coloring="fill",
                        showlabels=True,
                        labelfont=dict(size=12, color="white"),
                    ),
                ))
            elif plot_type == "Heatmap":
                fig_plotly = go.Figure(data=go.Heatmap(
                    z=selected_variable,
                    x=lon,
                    y=lat,
                    colorscale=cmap,
                ))

            fig_plotly.update_layout(
                xaxis_title="Longitude",
                yaxis_title="Latitude",
                autosize=False,
                width=800,
                height=600,
            )

            st.plotly_chart(fig_plotly)

        elif "out" in st.session_state and selected_model == "Aurora" and st.session_state["out"] is not None:
            preds = st.session_state["out"]
            ds_subset = st.session_state.get("ds_subset", None)
            batch = st.session_state.get("batch", None)

            try:
                # Atmospheric tensors are (batch, time, level, lat, lon).
                levels = preds[0].atmos_vars["t"].shape[2]
            except Exception:
                st.error("Error determining available levels:")
                st.error(traceback.format_exc())
                levels = None

            if levels is not None:
                selected_level = st.slider(
                    "Select Level",
                    min_value=0,
                    max_value=levels - 1,
                    value=min(11, levels - 1),
                    step=1,
                    help="Select the vertical level for plotting.",
                )

                for idx in range(len(preds)):
                    pred = preds[idx]
                    pred_time = pred.metadata.time[0]

                    st.write(f"### Prediction Time: {pred_time}")

                    try:
                        # Convert from Kelvin to Celsius for display.
                        # NB: ds_subset holds the raw dataset, so selected_level
                        # indexes its native levels, which may not line up with
                        # the prediction's atmos_levels.
                        pred_data = pred.atmos_vars["t"][0][0][selected_level].numpy() - 273.15
                        truth_data = ds_subset.T.isel(lev=selected_level)[idx].values - 273.15
                    except Exception:
                        st.error("Error extracting data for plotting:")
                        st.error(traceback.format_exc())
                        continue

                    try:
                        lat = np.array(pred.metadata.lat)
                        lon = np.array(pred.metadata.lon)
                    except Exception:
                        st.error("Error extracting latitude and longitude:")
                        st.error(traceback.format_exc())
                        continue

                    # Two panels: ground truth and prediction.
                    fig, axes = plt.subplots(
                        1, 2, figsize=(12, 6),
                        subplot_kw={"projection": ccrs.PlateCarree()},
                    )

                    im1 = axes[0].imshow(
                        truth_data,
                        extent=[lon.min(), lon.max(), lat.min(), lat.max()],
                        origin="lower",
                        cmap="coolwarm",
                        transform=ccrs.PlateCarree(),
                    )
                    axes[0].set_title(f"Ground Truth at Level {selected_level} - {pred_time}")
                    axes[0].set_xlabel("Longitude")
                    axes[0].set_ylabel("Latitude")
                    plt.colorbar(im1, ax=axes[0], orientation="horizontal", pad=0.05)

                    im2 = axes[1].imshow(
                        pred_data,
                        extent=[lon.min(), lon.max(), lat.min(), lat.max()],
                        origin="lower",
                        cmap="coolwarm",
                        transform=ccrs.PlateCarree(),
                    )
                    axes[1].set_title(f"Prediction at Level {selected_level} - {pred_time}")
                    axes[1].set_xlabel("Longitude")
                    axes[1].set_ylabel("Latitude")
                    plt.colorbar(im2, ax=axes[1], orientation="horizontal", pad=0.05)

                    plt.tight_layout()

                    st.pyplot(fig)
            else:
                st.error("Could not determine the available levels in the data.")

        else:
            st.warning("No output available to display or visualization is not implemented for this model.")

else:
    with right_col:
        st.header("🖥️ Visualization & Progress")
        st.info("Awaiting inference to display results.")