import os
import tempfile
from pathlib import Path

# Set memory optimization environment variables
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
os.environ['ANEMOI_INFERENCE_NUM_CHUNKS'] = '16'

import gradio as gr
import datetime
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.tri as tri
from anemoi.inference.runners.simple import SimpleRunner
from ecmwf.opendata import Client as OpendataClient
import earthkit.data as ekd
import earthkit.regrid as ekr
import matplotlib.animation as animation
from functools import lru_cache
import hashlib
import pickle
import json
from typing import List, Dict, Any
import logging
import xarray as xr
import pandas as pd

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Define parameters (matching notebook.py)
PARAM_SFC = ["10u", "10v", "2d", "2t", "msl", "skt", "sp", "tcw", "lsm", "z", "slor", "sdor"]
PARAM_SOIL = ["vsw", "sot"]
PARAM_PL = ["gh", "t", "u", "v", "w", "q"]
LEVELS = [1000, 925, 850, 700, 600, 500, 400, 300, 250, 200, 150, 100, 50]
SOIL_LEVELS = [1, 2]
DEFAULT_DATE = OpendataClient().latest()

# First organize variables into categories
VARIABLE_GROUPS = {
    "Surface Variables": {
        "10u": "10m U Wind Component",
        "10v": "10m V Wind Component",
        "2d": "2m Dewpoint Temperature",
        "2t": "2m Temperature",
        "msl": "Mean Sea Level Pressure",
        "skt": "Skin Temperature",
        "sp": "Surface Pressure",
        "tcw": "Total Column Water",
        "lsm": "Land-Sea Mask",
        "z": "Surface Geopotential",
        "slor": "Slope of Sub-gridscale Orography",
        "sdor": "Standard Deviation of Orography",
    },
    "Soil Variables": {
        "stl1": "Soil Temperature Level 1",
        "stl2": "Soil Temperature Level 2",
        "swvl1": "Soil Water Volume Level 1",
        "swvl2": "Soil Water Volume Level 2",
    },
    "Pressure Level Variables": {},  # Filled dynamically below
}

# Add pressure level variables dynamically
for var in ["t", "u", "v", "w", "q", "z"]:
    var_name = {
        "t": "Temperature",
        "u": "U Wind Component",
        "v": "V Wind Component",
        "w": "Vertical Velocity",
        "q": "Specific Humidity",
        "z": "Geopotential",
    }[var]
    for level in LEVELS:
        var_id = f"{var}_{level}"
        VARIABLE_GROUPS["Pressure Level Variables"][var_id] = f"{var_name} at {level}hPa"


def get_device():
    """Determine the best available device"""
    try:
        import torch
        if torch.cuda.is_available():
            logger.info("CUDA is available, using GPU")
            return "cuda"
        else:
            logger.info("CUDA is not available, using CPU")
            return "cpu"
    except ImportError:
        logger.info("PyTorch not found, using CPU")
        return "cpu"


# Initialize the model on the detected device
DEVICE = get_device()
MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device=DEVICE)

# Create and set a custom temp directory
TEMP_DIR = Path("./gradio_temp")
TEMP_DIR.mkdir(exist_ok=True)
# The disk cache directory must exist before save_to_cache() writes into it
(TEMP_DIR / "data_cache").mkdir(parents=True, exist_ok=True)
os.environ['GRADIO_TEMP_DIR'] = str(TEMP_DIR)


# Cache-related helpers (defined after the MODEL initialization)
def get_cache_key(date: datetime.datetime, params: List[str], levellist: List[int]) -> str:
    """Create a unique cache key based on the request parameters"""
    key_parts = [
        date.isoformat(),
        ",".join(sorted(params)),
        ",".join(str(x) for x in sorted(levellist)) if levellist else "no_levels",
    ]
    key_string = "_".join(key_parts)
    cache_key = hashlib.md5(key_string.encode()).hexdigest()
    logger.info(f"Generated cache key: {cache_key} for {key_string}")
    return cache_key


def get_cache_path(cache_key: str) -> Path:
    """Get the path to the cache file"""
    return TEMP_DIR / "data_cache" / f"{cache_key}.pkl"
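# Illustrative usage of the two cache helpers above (example only, not executed here):
#   key = get_cache_key(DEFAULT_DATE, PARAM_SFC, [])   # md5 of date + sorted params (+ levels)
#   get_cache_path(key)                                # -> gradio_temp/data_cache/<key>.pkl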
def save_to_cache(cache_key: str, data: Dict[str, Any]) -> None:
    """Save data to disk cache"""
    cache_file = get_cache_path(cache_key)
    try:
        with open(cache_file, 'wb') as f:
            pickle.dump(data, f)
        logger.info(f"Successfully saved data to cache: {cache_file}")
    except Exception as e:
        logger.error(f"Failed to save to cache: {e}")


def load_from_cache(cache_key: str) -> Dict[str, Any]:
    """Load data from disk cache"""
    cache_file = get_cache_path(cache_key)
    if cache_file.exists():
        try:
            with open(cache_file, 'rb') as f:
                data = pickle.load(f)
            logger.info(f"Successfully loaded data from cache: {cache_file}")
            return data
        except Exception as e:
            logger.error(f"Failed to load from cache: {e}")
            cache_file.unlink(missing_ok=True)
    logger.info(f"No cache file found: {cache_file}")
    return None


# In-memory cache wrapper (the disk cache is handled in get_open_data)
@lru_cache(maxsize=32)
def get_cached_data(date_str: str, param_tuple: tuple, levelist_tuple: tuple) -> Dict[str, Any]:
    """Memory cache wrapper for get_open_data"""
    return get_open_data_impl(
        datetime.datetime.fromisoformat(date_str),
        list(param_tuple),
        list(levelist_tuple) if levelist_tuple else [],
    )


def get_open_data(param: List[str], levelist: List[int] = None) -> Dict[str, Any]:
    """Main function to get data with caching"""
    if levelist is None:
        levelist = []

    # Try the disk cache first (more persistent than the memory cache)
    cache_key = get_cache_key(DEFAULT_DATE, param, levelist)
    logger.info(f"Checking cache for key: {cache_key}")

    cached_data = load_from_cache(cache_key)
    if cached_data is not None:
        logger.info(f"Cache hit for {cache_key}")
        return cached_data

    # If not in cache, download and process the data
    logger.info(f"Cache miss for {cache_key}, downloading fresh data")
    fields = get_open_data_impl(DEFAULT_DATE, param, levelist)

    # Save to disk cache
    save_to_cache(cache_key, fields)
    return fields


def get_open_data_impl(date: datetime.datetime, param: List[str], levelist: List[int]) -> Dict[str, Any]:
    """Implementation of data download and processing"""
    fields = {}
    dates = [date - datetime.timedelta(hours=6), date]
    logger.info(f"Downloading data for dates: {dates}")

    for current_date in dates:
        logger.info(f"Fetching data for {current_date}")
        data = ekd.from_source("ecmwf-open-data", date=current_date, param=param, levelist=levelist)

        for f in data:
            # ECMWF open data arrives on a 0.25° regular lat/lon grid
            values = f.to_numpy()
            assert values.shape == (721, 1440)
            # Shift longitudes from 0..360 to -180..180
            values = np.roll(values, -values.shape[1] // 2, axis=1)
            # Interpolate to the N320 grid expected by the model
            values = ekr.interpolate(values, {"grid": (0.25, 0.25)}, {"grid": "N320"})

            name = f"{f.metadata('param')}_{f.metadata('levelist')}" if levelist else f.metadata("param")
            if name not in fields:
                fields[name] = []
            fields[name].append(values)

    # Stack the two time steps into a single array per parameter
    for name, values in fields.items():
        fields[name] = np.stack(values)

    return fields
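# Illustrative result of the download step above (example only, not executed here);
# each parameter maps to an array of shape (2 time steps, number of N320 grid points):
#   fields = get_open_data(param=["2t"])
#   fields["2t"].shape   # -> (2, 542080), assuming the standard N320 point count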
def plot_forecast(state, selected_variable):
    logger.info(f"Plotting forecast for {selected_variable} at time {state['date']}")

    # Set up the figure and axis
    fig = plt.figure(figsize=(15, 8))
    ax = plt.axes(projection=ccrs.PlateCarree(central_longitude=0))

    # Get the coordinates
    latitudes, longitudes = state["latitudes"], state["longitudes"]
    fixed_lons = np.where(longitudes > 180, longitudes - 360, longitudes)
    triangulation = tri.Triangulation(fixed_lons, latitudes)

    # Get the values
    values = state["fields"][selected_variable]
    logger.info(f"Value range: min={np.min(values):.2f}, max={np.max(values):.2f}")

    # Set map features
    ax.set_global()
    ax.set_extent([-180, 180, -85, 85], crs=ccrs.PlateCarree())
    ax.coastlines(resolution='50m')
    ax.add_feature(cfeature.BORDERS, linestyle=":", alpha=0.5)
    ax.gridlines(draw_labels=True)

    # Create contour plot
    contour = ax.tricontourf(triangulation, values, levels=20,
                             transform=ccrs.PlateCarree(), cmap='RdBu_r')

    # Add colorbar
    plt.colorbar(contour, ax=ax, orientation='horizontal', pad=0.05)

    # Format the date string
    forecast_time = state["date"]
    if isinstance(forecast_time, str):
        forecast_time = datetime.datetime.fromisoformat(forecast_time)
    time_str = forecast_time.strftime("%Y-%m-%d %H:%M UTC")

    # Get variable description
    var_desc = None
    for group in VARIABLE_GROUPS.values():
        if selected_variable in group:
            var_desc = group[selected_variable]
            break
    var_name = var_desc if var_desc else selected_variable

    ax.set_title(f"{var_name} - {time_str}")

    # Save as PNG
    temp_file = str(TEMP_DIR / f"forecast_{datetime.datetime.now().timestamp()}.png")
    plt.savefig(temp_file, bbox_inches='tight', dpi=100)
    plt.close()

    return temp_file


def run_forecast(date: datetime.datetime, lead_time: int, device: str = None) -> Dict[str, Any]:
    # Use the global device if none specified
    device = device or DEVICE

    # Get all required fields
    fields = {}
    logger.info(f"Starting forecast for lead_time: {lead_time} hours on {device}")

    # Get surface fields
    logger.info("Getting surface fields...")
    fields.update(get_open_data(param=PARAM_SFC))

    # Get soil fields and rename them to the names expected by the model
    logger.info("Getting soil fields...")
    soil = get_open_data(param=PARAM_SOIL, levelist=SOIL_LEVELS)
    mapping = {
        'sot_1': 'stl1',
        'sot_2': 'stl2',
        'vsw_1': 'swvl1',
        'vsw_2': 'swvl2',
    }
    for k, v in soil.items():
        fields[mapping[k]] = v

    # Get pressure level fields
    logger.info("Getting pressure level fields...")
    fields.update(get_open_data(param=PARAM_PL, levelist=LEVELS))

    # Convert geopotential height to geopotential (multiply by standard gravity)
    for level in LEVELS:
        gh = fields.pop(f"gh_{level}")
        fields[f"z_{level}"] = gh * 9.80665

    input_state = dict(date=date, fields=fields)

    # Use the global model instance, re-creating it only if the device changed
    global MODEL
    if device != MODEL.device:
        MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device=device)

    # Run the model and keep the final state
    final_state = None
    for state in MODEL.run(input_state=input_state, lead_time=lead_time):
        logger.info(f"date={state['date']} latitudes={state['latitudes'].shape} "
                    f"longitudes={state['longitudes'].shape} fields={len(state['fields'])}")

        # Log a few example variables to show we have all fields
        for var in ['2t', 'msl', 't_1000', 'z_850']:
            if var in state['fields']:
                values = state['fields'][var]
                logger.info(f"  {var:<6} shape={values.shape} "
                            f"min={np.min(values):.6f} "
                            f"max={np.max(values):.6f}")
        final_state = state

    logger.info(f"Final state contains {len(final_state['fields'])} variables")
    return final_state


def get_available_variables(state):
    """Get available variables from the state and organize them into groups"""
    available_vars = set(state['fields'].keys())

    # Create dropdown choices only for available variables
    choices = []
    for group_name, variables in VARIABLE_GROUPS.items():
        group_vars = [(f"{desc} ({var_id})", var_id)
                      for var_id, desc in variables.items()
                      if var_id in available_vars]
        if group_vars:  # Only add the group if it has available variables
            choices.append((f"── {group_name} ──", None))
            choices.extend(group_vars)

    return choices
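# Illustrative structure of the choices list returned above (example only, not executed here):
#   [("── Surface Variables ──", None),
#    ("10m U Wind Component (10u)", "10u"),
#    ("10m V Wind Component (10v)", "10v"),
#    ...]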
def save_forecast_data(state, format='json'):
    """Save forecast data in the specified format"""
    if state is None:
        raise ValueError("No forecast data available. Please run a forecast first.")

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    forecast_time = (state['date'].strftime("%Y%m%d_%H")
                     if isinstance(state['date'], datetime.datetime)
                     else state['date'])

    # Use the forecasts directory for all outputs
    output_dir = TEMP_DIR / "forecasts"
    output_dir.mkdir(parents=True, exist_ok=True)

    if format == 'json':
        # Create a JSON-serializable dictionary
        data = {
            'metadata': {
                'forecast_date': forecast_time,
                'export_date': datetime.datetime.now().isoformat(),
                'total_points': len(state['latitudes']),
                'total_variables': len(state['fields']),
            },
            'coordinates': {
                'latitudes': state['latitudes'].tolist(),
                'longitudes': state['longitudes'].tolist(),
            },
            'fields': {
                var_name: {
                    'values': values.tolist(),
                    'statistics': {
                        'min': float(np.min(values)),
                        'max': float(np.max(values)),
                        'mean': float(np.mean(values)),
                        'std': float(np.std(values)),
                    },
                }
                for var_name, values in state['fields'].items()
            },
        }

        output_file = output_dir / f"forecast_{forecast_time}_{timestamp}.json"
        with open(output_file, 'w') as f:
            json.dump(data, f, indent=2)
        return str(output_file)

    elif format == 'netcdf':
        # Create an xarray Dataset
        data_vars = {}
        coords = {
            'point': np.arange(len(state['latitudes'])),
            'latitude': ('point', state['latitudes']),
            'longitude': ('point', state['longitudes']),
        }

        # Add each field as a variable
        for var_name, values in state['fields'].items():
            data_vars[var_name] = (['point'], values)

        # Create the dataset
        ds = xr.Dataset(
            data_vars=data_vars,
            coords=coords,
            attrs={
                'forecast_date': forecast_time,
                'export_date': datetime.datetime.now().isoformat(),
                'description': 'AIFS Weather Forecast Data',
            },
        )

        output_file = output_dir / f"forecast_{forecast_time}_{timestamp}.nc"
        ds.to_netcdf(output_file)
        return str(output_file)

    elif format == 'csv':
        # Create a DataFrame with lat/lon and all variables
        df = pd.DataFrame({
            'latitude': state['latitudes'],
            'longitude': state['longitudes'],
        })

        # Add each field as a column
        for var_name, values in state['fields'].items():
            df[var_name] = values

        output_file = output_dir / f"forecast_{forecast_time}_{timestamp}.csv"
        df.to_csv(output_file, index=False)
        return str(output_file)

    else:
        raise ValueError(f"Unsupported format: {format}")


# Create dropdown choices with groups
DROPDOWN_CHOICES = []
for group_name, variables in VARIABLE_GROUPS.items():
    # Add group separator
    DROPDOWN_CHOICES.append((f"── {group_name} ──", None))
    # Add variables in this group
    for var_id, desc in sorted(variables.items()):
        DROPDOWN_CHOICES.append((f"{desc} ({var_id})", var_id))


def update_interface():
    with gr.Blocks(css="""
        .centered-header {
            text-align: center;
            margin-bottom: 20px;
        }
        .subtitle {
            font-size: 1.2em;
            line-height: 1.5;
            margin: 20px 0;
        }
        .footer {
            text-align: center;
            padding: 20px;
            margin-top: 20px;
            border-top: 1px solid #eee;
        }
    """) as demo:
        forecast_state = gr.State(None)

        # Header section
        gr.Markdown(f"""
        # AIFS Weather Forecast