import os
import tempfile
from pathlib import Path

# Set memory optimization environment variables
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
os.environ['ANEMOI_INFERENCE_NUM_CHUNKS'] = '16'
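# Rough intent of the two settings above (the specific values are the author's choices):
# - expandable_segments lets the PyTorch CUDA allocator grow segments instead of
#   fragmenting memory across many fixed-size blocks.
# - ANEMOI_INFERENCE_NUM_CHUNKS asks anemoi-inference to process output fields in
#   smaller chunks, trading some speed for a lower peak memory footprint.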
import gradio as gr
import datetime
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.tri as tri
from anemoi.inference.runners.simple import SimpleRunner
from ecmwf.opendata import Client as OpendataClient
import earthkit.data as ekd
import earthkit.regrid as ekr
import matplotlib.animation as animation
from functools import lru_cache
import hashlib
import pickle
import json
from typing import List, Dict, Any
import logging
import xarray as xr
import pandas as pd

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Define parameters (matching notebook.py)
PARAM_SFC = ["10u", "10v", "2d", "2t", "msl", "skt", "sp", "tcw", "lsm", "z", "slor", "sdor"]
PARAM_SOIL = ["vsw", "sot"]
PARAM_PL = ["gh", "t", "u", "v", "w", "q"]
LEVELS = [1000, 925, 850, 700, 600, 500, 400, 300, 250, 200, 150, 100, 50]
SOIL_LEVELS = [1, 2]
DEFAULT_DATE = OpendataClient().latest()
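# OpendataClient().latest() returns the most recent date/time for which ECMWF open
# data is available, so every forecast starts from the freshest analyses.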
# First organize variables into categories
VARIABLE_GROUPS = {
    "Surface Variables": {
        "10u": "10m U Wind Component",
        "10v": "10m V Wind Component",
        "2d": "2m Dewpoint Temperature",
        "2t": "2m Temperature",
        "msl": "Mean Sea Level Pressure",
        "skt": "Skin Temperature",
        "sp": "Surface Pressure",
        "tcw": "Total Column Water",
        "lsm": "Land-Sea Mask",
        "z": "Surface Geopotential",
        "slor": "Slope of Sub-gridscale Orography",
        "sdor": "Standard Deviation of Orography",
    },
    "Soil Variables": {
        "stl1": "Soil Temperature Level 1",
        "stl2": "Soil Temperature Level 2",
        "swvl1": "Soil Water Volume Level 1",
        "swvl2": "Soil Water Volume Level 2",
    },
    "Pressure Level Variables": {},  # Will fill this dynamically
}

# Add pressure level variables dynamically
for var in ["t", "u", "v", "w", "q", "z"]:
    var_name = {
        "t": "Temperature",
        "u": "U Wind Component",
        "v": "V Wind Component",
        "w": "Vertical Velocity",
        "q": "Specific Humidity",
        "z": "Geopotential",
    }[var]
    for level in LEVELS:
        var_id = f"{var}_{level}"
        VARIABLE_GROUPS["Pressure Level Variables"][var_id] = f"{var_name} at {level}hPa"
def get_device():
    """Determine the best available device"""
    try:
        import torch
        if torch.cuda.is_available():
            logger.info("CUDA is available, using GPU")
            return "cuda"
        else:
            logger.info("CUDA is not available, using CPU")
            return "cpu"
    except ImportError:
        logger.info("PyTorch not found, using CPU")
        return "cpu"

# Update the model initialization to use the detected device
DEVICE = get_device()
MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device=DEVICE)

# Create and set custom temp directory
TEMP_DIR = Path("./gradio_temp")
TEMP_DIR.mkdir(exist_ok=True)
os.environ['GRADIO_TEMP_DIR'] = str(TEMP_DIR)
# Disk-cache helpers, keyed on the request date, parameters and levels
def get_cache_key(date: datetime.datetime, params: List[str], levellist: List[int]) -> str:
    """Create a unique cache key based on the request parameters"""
    key_parts = [
        date.isoformat(),
        ",".join(sorted(params)),
        ",".join(str(x) for x in sorted(levellist)) if levellist else "no_levels"
    ]
    key_string = "_".join(key_parts)
    cache_key = hashlib.md5(key_string.encode()).hexdigest()
    logger.info(f"Generated cache key: {cache_key} for {key_string}")
    return cache_key
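# Example (derived from the function above): get_cache_key(datetime.datetime(2024, 1, 1), ["2t", "msl"], [])
# hashes the string "2024-01-01T00:00:00_2t,msl_no_levels" into a stable hex key.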
def get_cache_path(cache_key: str) -> Path:
    """Get the path to the cache file"""
    return TEMP_DIR / "data_cache" / f"{cache_key}.pkl"

def save_to_cache(cache_key: str, data: Dict[str, Any]) -> None:
    """Save data to disk cache"""
    cache_file = get_cache_path(cache_key)
    try:
        with open(cache_file, 'wb') as f:
            pickle.dump(data, f)
        logger.info(f"Successfully saved data to cache: {cache_file}")
    except Exception as e:
        logger.error(f"Failed to save to cache: {e}")

def load_from_cache(cache_key: str) -> Dict[str, Any]:
    """Load data from disk cache"""
    cache_file = get_cache_path(cache_key)
    if cache_file.exists():
        try:
            with open(cache_file, 'rb') as f:
                data = pickle.load(f)
            logger.info(f"Successfully loaded data from cache: {cache_file}")
            return data
        except Exception as e:
            logger.error(f"Failed to load from cache: {e}")
            cache_file.unlink(missing_ok=True)
    logger.info(f"No cache file found: {cache_file}")
    return None
# In-memory caching layer on top of the disk cache (maxsize is an arbitrary small value)
@lru_cache(maxsize=4)
def get_cached_data(date_str: str, param_tuple: tuple, levelist_tuple: tuple) -> Dict[str, Any]:
    """Memory cache wrapper for get_open_data_impl; arguments must be hashable"""
    return get_open_data_impl(
        datetime.datetime.fromisoformat(date_str),
        list(param_tuple),
        list(levelist_tuple) if levelist_tuple else []
    )
def get_open_data(param: List[str], levelist: List[int] = None) -> Dict[str, Any]:
    """Main function to get data with caching"""
    if levelist is None:
        levelist = []

    # Try disk cache first (more persistent than memory cache)
    cache_key = get_cache_key(DEFAULT_DATE, param, levelist)
    logger.info(f"Checking cache for key: {cache_key}")
    cached_data = load_from_cache(cache_key)
    if cached_data is not None:
        logger.info(f"Cache hit for {cache_key}")
        return cached_data

    # If not in cache, download and process the data
    logger.info(f"Cache miss for {cache_key}, downloading fresh data")
    fields = get_open_data_impl(DEFAULT_DATE, param, levelist)

    # Save to disk cache
    save_to_cache(cache_key, fields)
    return fields
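# The function below fetches the two analysis times the model needs (T-6h and T),
# rolls each 0.25-degree lat/lon field so longitudes run from -180 to 180 instead of
# 0 to 360, and regrids it to the N320 reduced Gaussian grid used by AIFS.
# Each returned entry stacks the two times into an array of shape (2, n_grid_points).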
def get_open_data_impl(date: datetime.datetime, param: List[str], levelist: List[int]) -> Dict[str, Any]:
    """Implementation of data download and processing"""
    fields = {}
    dates = [date - datetime.timedelta(hours=6), date]
    logger.info(f"Downloading data for dates: {dates}")

    for current_date in dates:
        logger.info(f"Fetching data for {current_date}")
        data = ekd.from_source("ecmwf-open-data", date=current_date, param=param, levelist=levelist)
        for f in data:
            assert f.to_numpy().shape == (721, 1440)
            values = np.roll(f.to_numpy(), -f.shape[1] // 2, axis=1)
            values = ekr.interpolate(values, {"grid": (0.25, 0.25)}, {"grid": "N320"})
            name = f"{f.metadata('param')}_{f.metadata('levelist')}" if levelist else f.metadata("param")
            if name not in fields:
                fields[name] = []
            fields[name].append(values)

    # Create a single matrix for each parameter
    for name, values in fields.items():
        fields[name] = np.stack(values)
    return fields
def plot_forecast(state, selected_variable):
    logger.info(f"Plotting forecast for {selected_variable} at time {state['date']}")

    # Setup the figure and axis
    fig = plt.figure(figsize=(15, 8))
    ax = plt.axes(projection=ccrs.PlateCarree(central_longitude=0))

    # Get the coordinates
    latitudes, longitudes = state["latitudes"], state["longitudes"]
    fixed_lons = np.where(longitudes > 180, longitudes - 360, longitudes)
    triangulation = tri.Triangulation(fixed_lons, latitudes)
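    # The model state lives on an unstructured (reduced Gaussian) grid, so the plot
    # builds a Delaunay triangulation of the points and uses tricontourf below
    # instead of a regular-grid contourf.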
    # Get the values
    values = state["fields"][selected_variable]
    logger.info(f"Value range: min={np.min(values):.2f}, max={np.max(values):.2f}")

    # Set map features
    ax.set_global()
    ax.set_extent([-180, 180, -85, 85], crs=ccrs.PlateCarree())
    ax.coastlines(resolution='50m')
    ax.add_feature(cfeature.BORDERS, linestyle=":", alpha=0.5)
    ax.gridlines(draw_labels=True)

    # Create contour plot
    contour = ax.tricontourf(triangulation, values,
                             levels=20, transform=ccrs.PlateCarree(),
                             cmap='RdBu_r')

    # Add colorbar
    plt.colorbar(contour, ax=ax, orientation='horizontal', pad=0.05)

    # Format the date string
    forecast_time = state["date"]
    if isinstance(forecast_time, str):
        forecast_time = datetime.datetime.fromisoformat(forecast_time)
    time_str = forecast_time.strftime("%Y-%m-%d %H:%M UTC")

    # Get variable description
    var_desc = None
    for group in VARIABLE_GROUPS.values():
        if selected_variable in group:
            var_desc = group[selected_variable]
            break
    var_name = var_desc if var_desc else selected_variable

    ax.set_title(f"{var_name} - {time_str}")

    # Save as PNG
    temp_file = str(TEMP_DIR / f"forecast_{datetime.datetime.now().timestamp()}.png")
    plt.savefig(temp_file, bbox_inches='tight', dpi=100)
    plt.close()
    return temp_file
def run_forecast(date: datetime.datetime, lead_time: int, device: str = None) -> Dict[str, Any]:
    # Use the global device if none specified
    device = device or DEVICE

    # Get all required fields
    fields = {}
    logger.info(f"Starting forecast for lead_time: {lead_time} hours on {device}")

    # Get surface fields
    logger.info("Getting surface fields...")
    fields.update(get_open_data(param=PARAM_SFC))

    # Get soil fields and rename them
    logger.info("Getting soil fields...")
    soil = get_open_data(param=PARAM_SOIL, levelist=SOIL_LEVELS)
    mapping = {
        'sot_1': 'stl1', 'sot_2': 'stl2',
        'vsw_1': 'swvl1', 'vsw_2': 'swvl2'
    }
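    # The open-data short names (sot = soil temperature, vsw = volumetric soil water)
    # are mapped to the stl*/swvl* names the AIFS checkpoint expects as inputs.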
    for k, v in soil.items():
        fields[mapping[k]] = v

    # Get pressure level fields
    logger.info("Getting pressure level fields...")
    fields.update(get_open_data(param=PARAM_PL, levelist=LEVELS))

    # Convert geopotential height to geopotential
    for level in LEVELS:
        gh = fields.pop(f"gh_{level}")
        fields[f"z_{level}"] = gh * 9.80665
    input_state = dict(date=date, fields=fields)

    # Use the global model instance
    global MODEL
    if device != MODEL.device:
        MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device=device)

    # Run the model and get the final state
    final_state = None
    for state in MODEL.run(input_state=input_state, lead_time=lead_time):
        logger.info(f"\n😀 date={state['date']} latitudes={state['latitudes'].shape} "
                    f"longitudes={state['longitudes'].shape} fields={len(state['fields'])}")
        # Log a few example variables to show we have all fields
        for var in ['2t', 'msl', 't_1000', 'z_850']:
            if var in state['fields']:
                values = state['fields'][var]
                logger.info(f"  {var:<6} shape={values.shape} "
                            f"min={np.min(values):.6f} "
                            f"max={np.max(values):.6f}")
        final_state = state

    logger.info(f"Final state contains {len(final_state['fields'])} variables")
    return final_state
def get_available_variables(state):
    """Get available variables from the state and organize them into groups"""
    available_vars = set(state['fields'].keys())

    # Create dropdown choices only for available variables
    choices = []
    for group_name, variables in VARIABLE_GROUPS.items():
        group_vars = [(f"{desc} ({var_id})", var_id)
                      for var_id, desc in variables.items()
                      if var_id in available_vars]
        if group_vars:  # Only add group if it has available variables
            choices.append((f"── {group_name} ──", None))
            choices.extend(group_vars)
    return choices
def save_forecast_data(state, format='json'):
    """Save forecast data in specified format"""
    if state is None:
        raise ValueError("No forecast data available. Please run a forecast first.")

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    forecast_time = state['date'].strftime("%Y%m%d_%H") if isinstance(state['date'], datetime.datetime) else state['date']

    # Use forecasts directory for all outputs
    output_dir = TEMP_DIR / "forecasts"

    if format == 'json':
        # Create a JSON-serializable dictionary
        data = {
            'metadata': {
                'forecast_date': forecast_time,
                'export_date': datetime.datetime.now().isoformat(),
                'total_points': len(state['latitudes']),
                'total_variables': len(state['fields'])
            },
            'coordinates': {
                'latitudes': state['latitudes'].tolist(),
                'longitudes': state['longitudes'].tolist()
            },
            'fields': {
                var_name: {
                    'values': values.tolist(),
                    'statistics': {
                        'min': float(np.min(values)),
                        'max': float(np.max(values)),
                        'mean': float(np.mean(values)),
                        'std': float(np.std(values))
                    }
                }
                for var_name, values in state['fields'].items()
            }
        }

        output_file = output_dir / f"forecast_{forecast_time}_{timestamp}.json"
        with open(output_file, 'w') as f:
            json.dump(data, f, indent=2)
        return str(output_file)

    elif format == 'netcdf':
        # Create an xarray Dataset
        data_vars = {}
        coords = {
            'point': np.arange(len(state['latitudes'])),
            'latitude': ('point', state['latitudes']),
            'longitude': ('point', state['longitudes']),
        }

        # Add each field as a variable
        for var_name, values in state['fields'].items():
            data_vars[var_name] = (['point'], values)

        # Create the dataset
        ds = xr.Dataset(
            data_vars=data_vars,
            coords=coords,
            attrs={
                'forecast_date': forecast_time,
                'export_date': datetime.datetime.now().isoformat(),
                'description': 'AIFS Weather Forecast Data'
            }
        )

        output_file = output_dir / f"forecast_{forecast_time}_{timestamp}.nc"
        ds.to_netcdf(output_file)
        return str(output_file)

    elif format == 'csv':
        # Create a DataFrame with lat/lon and all variables
        df = pd.DataFrame({
            'latitude': state['latitudes'],
            'longitude': state['longitudes']
        })

        # Add each field as a column
        for var_name, values in state['fields'].items():
            df[var_name] = values

        output_file = output_dir / f"forecast_{forecast_time}_{timestamp}.csv"
        df.to_csv(output_file, index=False)
        return str(output_file)

    else:
        raise ValueError(f"Unsupported format: {format}")
# Create dropdown choices with groups
DROPDOWN_CHOICES = []
for group_name, variables in VARIABLE_GROUPS.items():
    # Add group separator
    DROPDOWN_CHOICES.append((f"── {group_name} ──", None))
    # Add variables in this group
    for var_id, desc in sorted(variables.items()):
        DROPDOWN_CHOICES.append((f"{desc} ({var_id})", var_id))
def update_interface():
    with gr.Blocks(css="""
        .centered-header {
            text-align: center;
            margin-bottom: 20px;
        }
        .subtitle {
            font-size: 1.2em;
            line-height: 1.5;
            margin: 20px 0;
        }
        .footer {
            text-align: center;
            padding: 20px;
            margin-top: 20px;
            border-top: 1px solid #eee;
        }
    """) as demo:
        forecast_state = gr.State(None)

        # Header section
        gr.Markdown(f"""
        # AIFS Weather Forecast
        <div class="subtitle">
        Interactive visualization of ECMWF AIFS weather forecasts.<br>
        Starting from the latest available data ({DEFAULT_DATE.strftime('%Y-%m-%d %H:%M UTC')}),<br>
        select how many hours ahead you want to forecast and which meteorological variable to visualize.
        </div>
        """)

        with gr.Row():
            with gr.Column(scale=1):
                lead_time = gr.Slider(
                    minimum=6,
                    maximum=48,
                    step=6,
                    value=12,
                    label="Forecast Hours Ahead"
                )
                # Start with the original DROPDOWN_CHOICES
                variable = gr.Dropdown(
                    choices=DROPDOWN_CHOICES,  # Use original choices at startup
                    value="2t",
                    label="Select Variable to Plot"
                )
                with gr.Row():
                    clear_btn = gr.Button("Clear")
                    run_btn = gr.Button("Run Forecast", variant="primary")
                download_nc = gr.Button("Download Forecast (NetCDF)")
                download_output = gr.File(label="Download Output")
            with gr.Column(scale=2):
                forecast_output = gr.Image()

        def run_and_store(lead_time):
            """Run forecast and store state"""
            new_state = run_forecast(DEFAULT_DATE, lead_time, DEVICE)  # Use global DEVICE
            plot = plot_forecast(new_state, "2t")
            return new_state, plot

        def update_plot_from_state(forecast_state, variable):
            """Update plot using stored state"""
            if forecast_state is None or variable is None:
                return None
            try:
                return plot_forecast(forecast_state, variable)
            except KeyError as e:
                logger.error(f"Variable {variable} not found in state: {e}")
                return None

        def clear():
            """Clear everything"""
            return [None, None, 12, "2t"]

        def save_netcdf(forecast_state):
            """Save forecast data as NetCDF"""
            if forecast_state is None:
                raise ValueError("No forecast data available. Please run a forecast first.")

            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            forecast_time = forecast_state['date'].strftime("%Y%m%d_%H") if isinstance(forecast_state['date'], datetime.datetime) else forecast_state['date']

            # Create an xarray Dataset
            data_vars = {}
            coords = {
                'point': np.arange(len(forecast_state['latitudes'])),
                'latitude': ('point', forecast_state['latitudes']),
                'longitude': ('point', forecast_state['longitudes']),
            }

            # Add each field as a variable
            for var_name, values in forecast_state['fields'].items():
                data_vars[var_name] = (['point'], values)

            # Create the dataset
            ds = xr.Dataset(
                data_vars=data_vars,
                coords=coords,
                attrs={
                    'forecast_date': forecast_time,
                    'export_date': datetime.datetime.now().isoformat(),
                    'description': 'AIFS Weather Forecast Data'
                }
            )

            output_file = TEMP_DIR / "forecasts" / f"forecast_{forecast_time}_{timestamp}.nc"
            ds.to_netcdf(output_file)
            return str(output_file)

        # Connect the components
        run_btn.click(
            fn=run_and_store,
            inputs=[lead_time],
            outputs=[forecast_state, forecast_output]
        )
        variable.change(
            fn=update_plot_from_state,
            inputs=[forecast_state, variable],
            outputs=forecast_output
        )
        clear_btn.click(
            fn=clear,
            inputs=[],
            outputs=[forecast_state, forecast_output, lead_time, variable]
        )
        download_nc.click(
            fn=save_netcdf,
            inputs=[forecast_state],
            outputs=[download_output]
        )

    return demo
# The interface is created and launched at the bottom of the file, once the cache
# and forecast directories below have been set up.
def setup_directories():
    """Create necessary directories with .keep files"""
    # Define all required directories
    directories = {
        TEMP_DIR / "data_cache": "Cache directory for downloaded weather data",
        TEMP_DIR / "forecasts": "Directory for forecast outputs (plots and data files)",
    }

    # Create directories and .keep files
    for directory, description in directories.items():
        directory.mkdir(parents=True, exist_ok=True)
        keep_file = directory / ".keep"
        if not keep_file.exists():
            keep_file.write_text(f"# {description}\n# This file ensures the directory is tracked in git\n")
            logger.info(f"Created directory and .keep file: {directory}")

# Call it during initialization
setup_directories()

def cleanup_old_files():
    """Remove old temporary and cache files"""
    current_time = datetime.datetime.now().timestamp()

    # Clean up forecast files (1 hour old)
    forecast_dir = TEMP_DIR / "forecasts"
    for file in forecast_dir.glob("*.*"):
        if file.name == ".keep":
            continue
        if current_time - file.stat().st_mtime > 3600:
            logger.info(f"Removing old forecast file: {file}")
            file.unlink(missing_ok=True)

    # Clean up cache files (24 hours old)
    cache_dir = TEMP_DIR / "data_cache"
    for file in cache_dir.glob("*.pkl"):
        if file.name == ".keep":
            continue
        if current_time - file.stat().st_mtime > 86400:
            logger.info(f"Removing old cache file: {file}")
            file.unlink(missing_ok=True)
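# Startup sequence: one housekeeping pass over old files, then build the Blocks app
# and launch it. launch() blocks and serves the app when run as a script, so all
# directories and helpers must already be in place at this point.
cleanup_old_files()
demo = update_interface()
demo.launch()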