# Hugging Face Spaces page status at capture time:
#   Build error
#   Build error
import os

# Runtime tuning knobs, set before the heavy third-party imports that follow
# (presumably so torch / anemoi-inference see them when they initialise).
os.environ.update(
    {
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        "ANEMOI_INFERENCE_NUM_CHUNKS": "16",
    }
)
import gradio as gr | |
import datetime | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import cartopy.crs as ccrs | |
import cartopy.feature as cfeature | |
import matplotlib.tri as tri | |
from anemoi.inference.runners.simple import SimpleRunner | |
from ecmwf.opendata import Client as OpendataClient | |
import earthkit.data as ekd | |
import earthkit.regrid as ekr | |
# ECMWF open-data request parameters used to build the model's initial state.

# Single-level (surface) fields.
PARAM_SFC = [
    "10u", "10v", "2d", "2t", "msl", "skt",
    "sp", "tcw", "lsm", "z", "slor", "sdor",
]

# Soil fields, requested on SOIL_LEVELS.
PARAM_SOIL = ["vsw", "sot"]

# Pressure-level fields, requested on LEVELS (hPa).
PARAM_PL = ["gh", "t", "u", "v", "w", "q"]

LEVELS = [1000, 925, 850, 700, 600, 500, 400, 300, 250, 200, 150, 100, 50]
SOIL_LEVELS = [1, 2]

# Queried once at import time; presumably the most recent run published on
# ECMWF open data.
DEFAULT_DATE = OpendataClient().latest()
# Variables offered for plotting, grouped by category: variable id -> label.
# The surface ids are the same as PARAM_SFC. The soil ids are the model-side
# names that run_forecast renames the open-data "sot"/"vsw" fields to.
VARIABLE_GROUPS = {
    "Surface Variables": {
        "10u": "10m U Wind Component",
        "10v": "10m V Wind Component",
        "2d": "2m Dewpoint Temperature",
        "2t": "2m Temperature",
        "msl": "Mean Sea Level Pressure",
        "skt": "Skin Temperature",
        "sp": "Surface Pressure",
        "tcw": "Total Column Water",
        "lsm": "Land-Sea Mask",
        "z": "Surface Geopotential",
        "slor": "Slope of Sub-gridscale Orography",
        "sdor": "Standard Deviation of Orography",
    },
    "Soil Variables": {
        "stl1": "Soil Temperature Level 1",
        "stl2": "Soil Temperature Level 2",
        "swvl1": "Soil Water Volume Level 1",
        "swvl2": "Soil Water Volume Level 2",
    },
    "Pressure Level Variables": {},  # filled below from LEVELS
}

# Display names for the pressure-level variables. "z" (geopotential) is not
# fetched directly; run_forecast derives it from the open-data "gh" field.
_PRESSURE_LEVEL_NAMES = {
    "t": "Temperature",
    "u": "U Wind Component",
    "v": "V Wind Component",
    "w": "Vertical Velocity",
    "q": "Specific Humidity",
    "z": "Geopotential",
}

# One entry per (variable, level), variable-major, with ids like "t_850".
# A comprehension keeps the loop variables out of the module namespace.
VARIABLE_GROUPS["Pressure Level Variables"].update(
    {
        f"{var}_{level}": f"{var_name} at {level}hPa"
        for var, var_name in _PRESSURE_LEVEL_NAMES.items()
        for level in LEVELS
    }
)
# A single runner is built at import time and shared by every request.
# run_forecast rebuilds it only when a different device is requested.
_DEFAULT_CHECKPOINT = "aifs-single-mse-1.0.ckpt"
MODEL = SimpleRunner(_DEFAULT_CHECKPOINT, device="cuda")
def get_open_data(param, levelist=None):
    """Fetch ECMWF open-data fields for the two input times AIFS is fed.

    For each of ``DEFAULT_DATE - 6h`` and ``DEFAULT_DATE`` (in that order),
    every requested field is:
    1. downloaded via earthkit;
    2. checked to be a 721 x 1440 grid;
    3. rolled by half its width along the longitude axis;
    4. interpolated from the 0.25 degree grid to N320.

    Args:
        param: Open-data short names to request (e.g. ``PARAM_SFC``).
        levelist: Optional levels to request. When non-empty, result keys
            are ``"<param>_<level>"``; otherwise they are just ``"<param>"``.

    Returns:
        dict mapping each field name to ``np.stack`` of its per-date arrays,
        earlier date first.

    Raises:
        ValueError: If a downloaded field is not 721 x 1440.
    """
    # None instead of a shared mutable [] default; ekd still receives a list.
    levelist = [] if levelist is None else levelist

    per_name = {}
    for date in (DEFAULT_DATE - datetime.timedelta(hours=6), DEFAULT_DATE):
        print(f"Fetching data for {date}")
        # Available sources:
        # https://earthkit-data.readthedocs.io/en/latest/guide/sources.html#id57
        data = ekd.from_source("ecmwf-open-data", date=date, param=param, levelist=levelist)
        for f in data:
            grid = f.to_numpy()
            # Explicit check rather than `assert`, which is stripped under -O.
            if grid.shape != (721, 1440):
                raise ValueError(
                    f"Expected a 721x1440 field for {f.metadata('param')}, got {grid.shape}"
                )
            # Shift by half a row along longitude; presumably converts the
            # -180..180 open-data layout to the 0..360 layout the model grid
            # expects (NOTE(review): confirm).
            grid = np.roll(grid, -grid.shape[1] // 2, axis=1)
            values = ekr.interpolate(grid, {"grid": (0.25, 0.25)}, {"grid": "N320"})
            name = f"{f.metadata('param')}_{f.metadata('levelist')}" if levelist else f.metadata("param")
            per_name.setdefault(name, []).append(values)

    # One stacked array per field, with the date axis first.
    return {name: np.stack(arrays) for name, arrays in per_name.items()}
def run_forecast(date, lead_time, device):
    """Assemble the AIFS input state from open data and run the model.

    Args:
        date: Date stamped on the input state handed to the runner.
            NOTE(review): the input fields themselves are always fetched by
            ``get_open_data`` for ``DEFAULT_DATE`` (and 6 h earlier), not for
            ``date``. Callers should pass ``DEFAULT_DATE`` so the two agree.
        lead_time: Forecast horizon, forwarded unchanged to ``MODEL.run``.
        device: Device string for the runner (e.g. ``"cuda"``). If it differs
            from the loaded model's ``device``, the global ``MODEL`` is
            rebuilt on the new device.

    Returns:
        The last state yielded by ``MODEL.run``.

    Raises:
        RuntimeError: If the runner yields no states at all.
    """
    global MODEL

    fields = {}

    # Surface fields keep their open-data names.
    fields.update(get_open_data(param=PARAM_SFC))

    # Soil fields arrive as "<param>_<level>" and are renamed to the model's ids.
    soil_renames = {
        "sot_1": "stl1", "sot_2": "stl2",
        "vsw_1": "swvl1", "vsw_2": "swvl2",
    }
    for name, values in get_open_data(param=PARAM_SOIL, levelist=SOIL_LEVELS).items():
        fields[soil_renames[name]] = values

    # Pressure-level fields.
    fields.update(get_open_data(param=PARAM_PL, levelist=LEVELS))

    # Open data provides geopotential height "gh"; the model takes geopotential
    # "z" = gh * standard gravity (9.80665 m s^-2).
    for level in LEVELS:
        fields[f"z_{level}"] = fields.pop(f"gh_{level}") * 9.80665

    input_state = dict(date=date, fields=fields)

    # Rebuild the shared runner only when a different device is requested.
    if device != MODEL.device:
        MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device=device)

    # Keep only the most recent state. Collecting every yielded state into a
    # list would hold all intermediate forecasts (and their field arrays) in
    # memory at once, although only the final one is returned.
    final_state = None
    for state in MODEL.run(input_state=input_state, lead_time=lead_time):
        final_state = state

    if final_state is None:
        raise RuntimeError(f"Model produced no forecast states for lead_time={lead_time!r}")
    return final_state
def plot_forecast(state, selected_variable):
    """Render one field of a forecast state as a filled-contour world map.

    Args:
        state: Forecast state providing "latitudes", "longitudes", "fields"
            and "date" entries. The per-variable arrays are assumed to be one
            value per (latitude, longitude) point (NOTE(review): confirm).
        selected_variable: Key into ``state["fields"]`` to plot.

    Returns:
        A matplotlib ``Figure`` holding the map, its title and a colorbar.
    """
    latitudes, longitudes = state["latitudes"], state["longitudes"]
    values = state["fields"][selected_variable]

    # Build the Figure directly rather than through plt.figure(). A pyplot
    # figure stays registered in pyplot's global figure manager until it is
    # closed, so a long-running app would leak one figure per request.
    fig = plt.Figure(figsize=(15, 8))
    # Plate Carree projection centred on 0° longitude, showing the whole globe.
    ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree(central_longitude=0))
    ax.set_global()

    # Map features.
    ax.coastlines(resolution="50m")
    ax.add_feature(cfeature.BORDERS, linestyle=":", alpha=0.5)
    ax.gridlines(draw_labels=True)

    # Map longitudes above 180 into the -180..180 range used by the plot.
    fixed_lons = np.where(longitudes > 180, longitudes - 360, longitudes)

    # Triangulate the scattered grid points so tricontourf can fill between them.
    triangulation = tri.Triangulation(fixed_lons, latitudes)
    contour = ax.tricontourf(
        triangulation,
        values,
        levels=20,
        transform=ccrs.PlateCarree(),
        cmap="RdBu_r",
    )

    ax.set_title(f"{selected_variable} at {state['date']}")
    fig.colorbar(contour, ax=ax, orientation="horizontal", pad=0.05)
    return fig
def _dropdown_sort_key(item):
    """Sort (var_id, label) pairs by base name, then by numeric level.

    Pressure-level ids look like "t_850". Comparing the level as an int
    orders "t_50" < "t_100" < ... < "t_1000". A plain string sort would give
    "t_100" < "t_1000" < "t_150" < ... < "t_50". Ids without a numeric
    "_<level>" suffix (surface and soil ids) sort by the id alone, exactly as
    before.
    """
    base, _, level = item[0].partition("_")
    return (base, int(level) if level.isdigit() else -1)


# Dropdown entries as (label, value) pairs. Each group gets a header entry
# whose value is None, followed by that group's variables.
DROPDOWN_CHOICES = []
for group_name, variables in VARIABLE_GROUPS.items():
    # Group separator / header.
    DROPDOWN_CHOICES.append((f"── {group_name} ──", None))
    # Variables in this group, with pressure levels in numeric order.
    for var_id, desc in sorted(variables.items(), key=_dropdown_sort_key):
        DROPDOWN_CHOICES.append((f"{desc} ({var_id})", var_id))
# Initial values for the controls, shared by the components and the Clear button.
DEFAULT_LEAD_TIME = 12
DEFAULT_VARIABLE = "2t"

with gr.Blocks(css="""
.centered-header {
    text-align: center;
    margin-bottom: 20px;
}
""") as demo:
    # Centered header section.
    gr.Markdown(f"""
# AIFS Weather Forecast
Interactive visualization of ECMWF AIFS weather forecasts. Starting from the latest available data ({DEFAULT_DATE.strftime('%Y-%m-%d %H:%M UTC')}),
select how many hours ahead you want to forecast and which meteorological variable to visualize.
""", elem_classes=["centered-header"])

    with gr.Row():
        # Controls column: 1/3 of the width.
        with gr.Column(scale=1):
            lead_time = gr.Slider(
                minimum=6,
                maximum=48,
                step=6,
                value=DEFAULT_LEAD_TIME,
                label="Forecast Hours Ahead",
            )
            variable = gr.Dropdown(
                choices=DROPDOWN_CHOICES,
                value=DEFAULT_VARIABLE,
                label="Select Variable to Plot",
                info="Choose a meteorological variable to visualize",
            )
            # Buttons side by side.
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit", variant="primary")

        # Map column: 2/3 of the width.
        with gr.Column(scale=2):
            plot_output = gr.Plot()

    def update_plot(lead_time, variable):
        """Run the forecast from DEFAULT_DATE and plot the chosen variable.

        Raises ``gr.Error``, which Gradio shows to the user instead of a
        stack trace, when the selection cannot be plotted.
        """
        # Group header entries in the dropdown carry the value None.
        if not variable:
            raise gr.Error("Please select a variable rather than a group header.")
        # Slider values may arrive as floats; the runner is given whole hours.
        state = run_forecast(DEFAULT_DATE, int(lead_time), "cuda")
        # NOTE(review): some listed ids (e.g. input-only fields such as "lsm")
        # may be absent from the model output; confirm against the runner.
        if variable not in state["fields"]:
            raise gr.Error(f"Variable '{variable}' is not available in the forecast output.")
        return plot_forecast(state, variable)

    def clear():
        """Reset the controls to their defaults and empty the plot."""
        # Component.update() classmethods were removed in Gradio 4;
        # gr.update() works in both Gradio 3.x and 4.x.
        return [
            gr.update(value=DEFAULT_LEAD_TIME),
            gr.update(value=DEFAULT_VARIABLE),
            None,  # clears the plot
        ]

    # Wire the buttons.
    submit_btn.click(fn=update_plot, inputs=[lead_time, variable], outputs=plot_output)
    clear_btn.click(fn=clear, inputs=[], outputs=[lead_time, variable, plot_output])

demo.launch()