Spaces:
Build error
Build error
File size: 6,805 Bytes
a952d46 62a6171 93f7649 62a6171 93f7649 62a6171 93f7649 45b15ae a952d46 45b15ae a952d46 45b15ae a952d46 45b15ae a952d46 45b15ae a952d46 93f7649 a952d46 93f7649 a952d46 62a6171 93f7649 a952d46 62a6171 93f7649 45b15ae 93f7649 a952d46 45b15ae 93f7649 62a6171 93f7649 a952d46 45b15ae a952d46 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 |
import os
# Set memory optimization environment variables
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
os.environ['ANEMOI_INFERENCE_NUM_CHUNKS'] = '16'
import gradio as gr
import datetime
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.tri as tri
from anemoi.inference.runners.simple import SimpleRunner
from ecmwf.opendata import Client as OpendataClient
import earthkit.data as ekd
import earthkit.regrid as ekr
# Input variables requested from ECMWF open data (kept in sync with notebook.py).

# Single-level (surface) parameters.
PARAM_SFC = [
    "10u", "10v", "2d", "2t", "msl", "skt",
    "sp", "tcw", "lsm", "z", "slor", "sdor",
]

# Soil parameters, requested on SOIL_LEVELS.
PARAM_SOIL = ["vsw", "sot"]

# Pressure-level parameters, requested on LEVELS (hPa).
PARAM_PL = ["gh", "t", "u", "v", "w", "q"]

LEVELS = [
    1000, 925, 850, 700, 600, 500, 400,
    300, 250, 200, 150, 100, 50,
]

SOIL_LEVELS = [1, 2]
# Latest base time reported by the ECMWF open-data client; used as the
# forecast start for every request.
# NOTE(review): this presumably queries the open-data server at import time.
_OPENDATA_CLIENT = OpendataClient()
DEFAULT_DATE = _OPENDATA_CLIENT.latest()
# Variables offered in the plotting dropdown, grouped by category. Each group
# maps a key of the forecast state's "fields" dict to a human-readable label.
# NOTE(review): confirm every key listed here (e.g. "lsm", "slor", "sdor") is
# actually present in the model's output state; plot_forecast indexes
# state["fields"] with the selected key.

# Long names for pressure-level variables; each is expanded below into one
# "<var>_<level>" entry per level in LEVELS (e.g. "t_850").
_PRESSURE_LEVEL_NAMES = {
    "t": "Temperature",
    "u": "U Wind Component",
    "v": "V Wind Component",
    "w": "Vertical Velocity",
    "q": "Specific Humidity",
    "z": "Geopotential",
}

VARIABLE_GROUPS = {
    "Surface Variables": {
        "10u": "10m U Wind Component",
        "10v": "10m V Wind Component",
        "2d": "2m Dewpoint Temperature",
        "2t": "2m Temperature",
        "msl": "Mean Sea Level Pressure",
        "skt": "Skin Temperature",
        "sp": "Surface Pressure",
        "tcw": "Total Column Water",
        "lsm": "Land-Sea Mask",
        "z": "Surface Geopotential",
        "slor": "Slope of Sub-gridscale Orography",
        "sdor": "Standard Deviation of Orography",
    },
    "Soil Variables": {
        "stl1": "Soil Temperature Level 1",
        "stl2": "Soil Temperature Level 2",
        "swvl1": "Soil Water Volume Level 1",
        "swvl2": "Soil Water Volume Level 2",
    },
    # Ordered variable-major, then by LEVELS order.
    "Pressure Level Variables": {
        f"{short}_{level_hpa}": f"{long_name} at {level_hpa}hPa"
        for short, long_name in _PRESSURE_LEVEL_NAMES.items()
        for level_hpa in LEVELS
    },
}
# Load the AIFS checkpoint once at import time so all requests share it.
# run_forecast() replaces it only when a different device is requested.
_AIFS_CHECKPOINT = "aifs-single-mse-1.0.ckpt"
MODEL = SimpleRunner(_AIFS_CHECKPOINT, device="cuda")  # Default to CUDA
def get_open_data(param, levelist=None, date=None):
    """Fetch ECMWF open-data fields and interpolate them to the N320 grid.

    Fields are retrieved for two times, six hours before ``date`` and
    ``date`` itself. They are then stacked along a new leading axis in that
    order (earlier time first).

    Args:
        param: List of parameter short names to request.
        levelist: Optional list of levels. When non-empty, output keys are
            ``"<param>_<level>"``; otherwise they are just ``"<param>"``.
        date: Base time of the later field. Defaults to ``DEFAULT_DATE``,
            which is resolved at call time.

    Returns:
        dict mapping field name -> ``np.stack`` of that field's values, one
        entry per retrieved time.

    Raises:
        ValueError: if a retrieved field is not on the expected 721x1440 grid.
    """
    # Avoid a shared mutable default. Still pass a list (as before) to the source.
    levelist = [] if levelist is None else list(levelist)
    if date is None:
        date = DEFAULT_DATE

    fields = {}
    for when in (date - datetime.timedelta(hours=6), date):
        print(f"Fetching data for {when}")
        # sources can be seen https://earthkit-data.readthedocs.io/en/latest/guide/sources.html#id57
        data = ekd.from_source("ecmwf-open-data", date=when, param=param, levelist=levelist)
        for f in data:
            raw = f.to_numpy()
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if raw.shape != (721, 1440):
                raise ValueError(
                    f"Unexpected grid shape {raw.shape} for {f.metadata('param')}; "
                    "expected (721, 1440)"
                )
            # Roll by half the longitude axis before regridding.
            # NOTE(review): presumably shifts the longitude origin to match what
            # the 0.25-degree -> N320 interpolation expects; confirm.
            values = np.roll(raw, -raw.shape[1] // 2, axis=1)
            values = ekr.interpolate(values, {"grid": (0.25, 0.25)}, {"grid": "N320"})
            if levelist:
                name = f"{f.metadata('param')}_{f.metadata('levelist')}"
            else:
                name = f.metadata("param")
            fields.setdefault(name, []).append(values)

    # Create a single matrix for each parameter.
    return {name: np.stack(per_time) for name, per_time in fields.items()}
def run_forecast(date, lead_time, device):
    """Assemble the model input from open data, run AIFS, and return the final state.

    Surface, soil, and pressure-level inputs are fetched via get_open_data.
    Soil fields are renamed to the model's names, and geopotential height is
    converted to geopotential. MODEL is then run for ``lead_time``.

    Args:
        date: Date stored in the model input state.
            NOTE(review): the input fields themselves are always fetched for
            DEFAULT_DATE by get_open_data, regardless of this argument.
        lead_time: Lead time passed to ``MODEL.run`` (the UI supplies hours).
        device: Device for the model. The global MODEL is reloaded on this
            device if it differs from ``MODEL.device``.

    Returns:
        The last state yielded by ``MODEL.run``.

    Raises:
        RuntimeError: if the model yields no states for ``lead_time``.
    """
    global MODEL

    fields = {}
    fields.update(get_open_data(param=PARAM_SFC))

    # Soil fields arrive as "<param>_<level>"; rename them to the model's names.
    soil_names = {
        "sot_1": "stl1", "sot_2": "stl2",
        "vsw_1": "swvl1", "vsw_2": "swvl2",
    }
    for name, values in get_open_data(param=PARAM_SOIL, levelist=SOIL_LEVELS).items():
        fields[soil_names[name]] = values

    fields.update(get_open_data(param=PARAM_PL, levelist=LEVELS))

    # Convert geopotential height to geopotential (multiply by standard gravity).
    for level in LEVELS:
        fields[f"z_{level}"] = fields.pop(f"gh_{level}") * 9.80665

    input_state = dict(date=date, fields=fields)

    # If the device preference changed, reload the model on the new device.
    if device != MODEL.device:
        MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device=device)

    # Keep only the most recent state. Earlier states are never used, and
    # retaining all of them multiplies peak memory by the number of steps.
    last_state = None
    for last_state in MODEL.run(input_state=input_state, lead_time=lead_time):
        pass
    if last_state is None:
        raise RuntimeError(f"Model produced no forecast states for lead_time={lead_time!r}")
    return last_state
def plot_forecast(state, selected_variable):
    """Draw a filled-contour map of one forecast field.

    Args:
        state: Forecast state with "latitudes", "longitudes", "date" and a
            "fields" dict keyed by variable name.
        selected_variable: Key into ``state["fields"]`` to plot.

    Returns:
        The matplotlib Figure. It has already been detached from pyplot's
        figure registry but can still be rendered/saved by the caller.
    """
    latitudes, longitudes = state["latitudes"], state["longitudes"]
    values = state["fields"][selected_variable]

    fig, ax = plt.subplots(figsize=(11, 6), subplot_kw={"projection": ccrs.PlateCarree()})
    ax.coastlines()
    ax.add_feature(cfeature.BORDERS, linestyle=":")

    # Contour on an unstructured triangulation of the output points.
    triangulation = tri.Triangulation(longitudes, latitudes)
    # Use 'RdBu_r' instead of 'RdBu' to reverse the color scheme.
    contour = ax.tricontourf(
        triangulation, values, levels=20,
        transform=ccrs.PlateCarree(),
        cmap="RdBu_r",
    )
    # Use explicit axes/figure methods rather than pyplot's implicit "current" state.
    ax.set_title(f"{selected_variable} at {state['date']}")
    fig.colorbar(contour, ax=ax)

    # pyplot keeps every figure it creates until closed. In a long-running
    # server that leaks one figure per request. Closing only removes it from
    # pyplot's registry; the Figure object can still be saved/rendered.
    plt.close(fig)
    return fig
# Dropdown entries as (label, value) pairs. Each group starts with a header
# entry whose value is None, followed by its variables sorted by key.
# NOTE(review): header entries remain selectable and then yield None.
# NOTE(review): the "ββ" characters look like mis-encoded box-drawing
# characters; confirm the intended glyph.
DROPDOWN_CHOICES = []
for _group_title, _group_vars in VARIABLE_GROUPS.items():
    DROPDOWN_CHOICES.append((f"ββ {_group_title} ββ", None))
    DROPDOWN_CHOICES.extend(
        (f"{label} ({key})", key)
        for key, label in sorted(_group_vars.items())
    )
def gradio_interface(lead_time, selected_variable):
    """Gradio callback: forecast from the latest data and plot one variable.

    Args:
        lead_time: Forecast hours ahead, from the UI slider.
        selected_variable: Variable key chosen in the dropdown.

    Returns:
        The matplotlib figure produced by plot_forecast.

    Raises:
        gr.Error: if a group header (empty value) is selected, or if the
            chosen variable is missing from the forecast output.
    """
    # Group header entries carry no variable. Reject them before the
    # expensive data download and model run, instead of crashing afterwards.
    if not selected_variable:
        raise gr.Error("Please select a variable rather than a group header.")

    # DEFAULT_DATE is only read here, so no `global` declaration is needed.
    state = run_forecast(DEFAULT_DATE, lead_time, "cuda")  # Always use CUDA

    if selected_variable not in state["fields"]:
        raise gr.Error(f"Variable '{selected_variable}' is not available in the forecast output.")
    return plot_forecast(state, selected_variable)
# Start time of the input data, shown in both the slider hint and the description.
_start_label = DEFAULT_DATE.strftime('%Y-%m-%d %H:%M UTC')

_lead_time_slider = gr.Slider(
    minimum=6, maximum=48, step=6, value=12,
    label="Forecast Hours Ahead",
    info=f"Latest data available from: {_start_label}",
)

_variable_dropdown = gr.Dropdown(
    choices=DROPDOWN_CHOICES,
    value="2t",  # Default to 2m temperature
    label="Select Variable to Plot",
    info="Choose a meteorological variable to visualize",
)

demo = gr.Interface(
    fn=gradio_interface,
    inputs=[_lead_time_slider, _variable_dropdown],
    outputs=gr.Plot(),
    title="AIFS Weather Forecast",
    description=f"""
Interactive visualization of ECMWF AIFS weather forecasts.
Starting from the latest available data ({_start_label}),
select how many hours ahead you want to forecast and which meteorological variable to visualize.
""",
)

demo.launch()
|