# Hugging Face Spaces page header captured with this file: "Spaces: Build error".
import datetime
import os

# Memory optimization settings. They are written to the environment before any
# of the third-party libraries below are imported, so whichever library reads
# them sees the values from the start.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
os.environ['ANEMOI_INFERENCE_NUM_CHUNKS'] = '16'

import gradio as gr  # noqa: E402
import numpy as np  # noqa: E402
import matplotlib.pyplot as plt  # noqa: E402
import cartopy.crs as ccrs  # noqa: E402
import cartopy.feature as cfeature  # noqa: E402
import matplotlib.tri as tri  # noqa: E402
from anemoi.inference.runners.simple import SimpleRunner  # noqa: E402
from ecmwf.opendata import Client as OpendataClient  # noqa: E402
import earthkit.data as ekd  # noqa: E402
import earthkit.regrid as ekr  # noqa: E402
# Input variables requested from ECMWF open data (kept in sync with notebook.py).

# Single-level parameters, requested without a level list.
PARAM_SFC = [
    "10u", "10v", "2d", "2t", "msl", "skt",
    "sp", "tcw", "lsm", "z", "slor", "sdor",
]

# Soil parameters, requested on SOIL_LEVELS.
PARAM_SOIL = ["vsw", "sot"]

# Pressure-level parameters, requested on LEVELS.
PARAM_PL = ["gh", "t", "u", "v", "w", "q"]

# Pressure levels, from the lowest in the atmosphere (largest value) upwards.
LEVELS = [
    1000, 925, 850, 700, 600, 500, 400,
    300, 250, 200, 150, 100, 50,
]

# Soil levels requested for PARAM_SOIL.
SOIL_LEVELS = [1, 2]
# Most recent forecast cycle reported by the ECMWF open-data client. This is
# evaluated once, at import time (presumably a request to the open-data
# server), so the app keeps using this cycle until the process restarts.
DEFAULT_DATE = OpendataClient().latest()
def get_open_data(param, levelist=None, *, date=None):
    """Fetch ECMWF open-data fields for two consecutive cycles, regridded to N320.

    For every field returned by the ``ecmwf-open-data`` source, the cycle at
    ``date - 6h`` and the cycle at ``date`` are downloaded. Each field is rolled
    along its longitude axis, interpolated from the regular 0.25 x 0.25 degree
    grid to N320, and the two time steps are stacked.

    Args:
        param: Parameter short name(s) passed to the ``ecmwf-open-data`` source.
        levelist: Levels to request. When non-empty, output keys are
            ``"<param>_<level>"``; otherwise they are the bare parameter name.
            Defaults to no levels.
        date: Latest cycle to fetch; defaults to ``DEFAULT_DATE``.

    Returns:
        dict mapping each field name to ``np.stack`` of its values, in fetch
        order (the ``date - 6h`` cycle first, then ``date``).

    Raises:
        ValueError: If a downloaded field is not on the expected 721 x 1440 grid.
    """
    # A None sentinel avoids sharing one mutable list across calls.
    if levelist is None:
        levelist = []
    if date is None:
        date = DEFAULT_DATE

    fields = {}
    for cycle in (date - datetime.timedelta(hours=6), date):
        data = ekd.from_source("ecmwf-open-data", date=cycle, param=param, levelist=levelist)
        for f in data:
            # Decode the GRIB message once and reuse the array.
            raw = f.to_numpy()
            # Explicit check instead of `assert`, which is stripped under -O.
            if raw.shape != (721, 1440):
                raise ValueError(
                    f"Expected a (721, 1440) field for {f.metadata('param')}, got {raw.shape}"
                )
            # Rotate by half the longitude axis. Presumably the source grid starts
            # at -180 and the interpolator expects 0-360, as in the AIFS example
            # notebook; confirm if the source changes.
            values = np.roll(raw, -raw.shape[1] // 2, axis=1)
            values = ekr.interpolate(values, {"grid": (0.25, 0.25)}, {"grid": "N320"})
            name = f"{f.metadata('param')}_{f.metadata('levelist')}" if levelist else f.metadata("param")
            fields.setdefault(name, []).append(values)

    # One array per field, with the time steps stacked along a new first axis.
    return {name: np.stack(steps) for name, steps in fields.items()}
def run_forecast(date, lead_time, device):
    """Build an AIFS initial state from ECMWF open data and run the forecast.

    Surface, soil and pressure-level fields are fetched with ``get_open_data``.
    Soil fields are renamed (sot_1/sot_2 -> stl1/stl2, vsw_1/vsw_2 -> swvl1/swvl2),
    and geopotential height is converted to geopotential before the
    ``aifs-single-mse-1.0.ckpt`` checkpoint is run.

    Args:
        date: Value stored as the initial state's ``date``. ``get_open_data`` is
            called here without a date, so the fields come from its default
            cycle; callers should pass that same timestamp.
        lead_time: Forecast length forwarded to ``SimpleRunner.run``.
        device: Device string forwarded to ``SimpleRunner`` (e.g. ``"cuda"``).

    Returns:
        The last state yielded by the runner.

    Raises:
        RuntimeError: If the runner yields no state at all.
    """
    # Standard gravity (m s-2): geopotential = geopotential height * g.
    standard_gravity = 9.80665
    soil_renames = {
        "sot_1": "stl1", "sot_2": "stl2",
        "vsw_1": "swvl1", "vsw_2": "swvl2",
    }

    fields = {}
    fields.update(get_open_data(param=PARAM_SFC))
    soil = get_open_data(param=PARAM_SOIL, levelist=SOIL_LEVELS)
    fields.update({soil_renames[name]: values for name, values in soil.items()})
    fields.update(get_open_data(param=PARAM_PL, levelist=LEVELS))

    # Convert geopotential height to geopotential.
    for level in LEVELS:
        fields[f"z_{level}"] = fields.pop(f"gh_{level}") * standard_gravity

    input_state = dict(date=date, fields=fields)
    runner = SimpleRunner("aifs-single-mse-1.0.ckpt", device=device)

    # Keep only the newest state. Collecting every yielded state in a list
    # would hold all of them in memory until the run finishes.
    final_state = None
    for final_state in runner.run(input_state=input_state, lead_time=lead_time):
        pass
    if final_state is None:
        raise RuntimeError("The AIFS runner did not produce any forecast state")
    return final_state
def plot_forecast(state):
    """Plot the 100 m zonal wind component (``100u``) of a state on a global map.

    Args:
        state: Mapping with ``"latitudes"``, ``"longitudes"``, ``"date"`` and a
            ``"fields"`` mapping that contains ``"100u"``.

    Returns:
        A matplotlib ``Figure`` suitable for ``gr.Plot``.
    """
    # Build the figure without pyplot. Figures created on every request are then
    # not kept in pyplot's global figure registry, and no shared "current
    # axes" state is touched.
    from matplotlib.figure import Figure

    latitudes = state["latitudes"]
    longitudes = np.asarray(state["longitudes"])
    values = state["fields"]["100u"]

    # Wrap longitudes > 180 into [-180, 180] so every point lies inside the
    # PlateCarree map extent. This is a no-op if the grid already uses that
    # range; the AIFS example notebook applies the same shift.
    longitudes = np.where(longitudes > 180, longitudes - 360, longitudes)

    fig = Figure(figsize=(11, 6))
    ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
    ax.coastlines()
    ax.add_feature(cfeature.BORDERS, linestyle=":")

    triangulation = tri.Triangulation(longitudes, latitudes)
    contour = ax.tricontourf(triangulation, values, levels=20, transform=ccrs.PlateCarree(), cmap="RdBu")

    # Only the u component is plotted, so say so in the title.
    ax.set_title(f"100m winds (100u) at {state['date']}")
    fig.colorbar(contour, ax=ax, label="100u")
    return fig
def gradio_interface(date_str, lead_time, device):
    """Gradio callback: validate the inputs, run the forecast and plot it.

    Input fields are always fetched for ``DEFAULT_DATE`` (the latest open-data
    cycle), so the requested day must be that cycle's day. The forecast is then
    initialised with ``DEFAULT_DATE`` itself, full timestamp included, so the
    state's date matches the data. Previously it used midnight of the typed day.

    Args:
        date_str: Date typed by the user, in ``YYYY-MM-DD`` format.
        lead_time: Lead time in hours from the slider.
        device: Compute device, ``"cuda"`` or ``"cpu"``.

    Returns:
        The figure produced by ``plot_forecast``.

    Raises:
        gr.Error: If the date is malformed or is not the day of ``DEFAULT_DATE``.
    """
    try:
        requested = datetime.datetime.strptime((date_str or "").strip(), "%Y-%m-%d")
    except ValueError:
        raise gr.Error("Please enter a valid date in YYYY-MM-DD format") from None

    latest_day = DEFAULT_DATE.strftime("%Y-%m-%d")
    if requested.strftime("%Y-%m-%d") != latest_day:
        # Any other date would silently run on DEFAULT_DATE's data.
        raise gr.Error(
            f"Only the latest available cycle ({DEFAULT_DATE:%Y-%m-%d %H:%M}) can be run; "
            f"please enter {latest_day}"
        )

    # gr.Slider passes its value as a float; forward whole hours as an int.
    state = run_forecast(DEFAULT_DATE, int(lead_time), device)
    return plot_forecast(state)
# Web UI: date, lead time and device inputs; a single plot output.
demo = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(value=DEFAULT_DATE.strftime("%Y-%m-%d"), label="Forecast Date (YYYY-MM-DD)"),
        gr.Slider(minimum=6, maximum=48, step=6, value=12, label="Lead Time (Hours)"),
        gr.Radio(choices=["cuda", "cpu"], value="cuda", label="Compute Device"),
    ],
    outputs=gr.Plot(),
    title="AIFS Weather Forecast",
    description="Run ECMWF AIFS forecasts based on selected parameters.",
)

# Start the server only when run as a script (e.g. `python app.py`).
# Importing the module still exposes `demo` without launching anything.
if __name__ == "__main__":
    demo.launch()