File size: 8,097 Bytes
a952d46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62a6171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93f7649
 
62a6171
93f7649
 
 
 
 
 
 
 
 
 
 
62a6171
 
93f7649
45b15ae
 
 
a952d46
 
 
45b15ae
 
a952d46
45b15ae
 
a952d46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45b15ae
 
 
 
 
 
 
a952d46
45b15ae
a952d46
 
 
93f7649
a952d46
93f7649
62a6171
9a29800
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62a6171
 
 
9a29800
93f7649
9a29800
 
a952d46
 
62a6171
 
 
 
 
 
 
 
93f7649
9a29800
 
 
 
 
 
 
 
 
 
 
45b15ae
9a29800
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a952d46
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
import os
# Memory-related settings for PyTorch's CUDA allocator and anemoi-inference.
# They are applied before anemoi.inference (and the libraries it pulls in) is
# imported below. ``setdefault`` keeps any value already provided by the
# deployment environment instead of silently overwriting it.
os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')
# NOTE(review): presumably makes anemoi-inference split model execution into
# 16 chunks to lower peak GPU memory — confirm against the anemoi-inference docs.
os.environ.setdefault('ANEMOI_INFERENCE_NUM_CHUNKS', '16')

import gradio as gr
import datetime
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.tri as tri
from anemoi.inference.runners.simple import SimpleRunner
from ecmwf.opendata import Client as OpendataClient
import earthkit.data as ekd
import earthkit.regrid as ekr

# ECMWF open-data request parameters used to build the AIFS input state.

# Single-level (surface) parameters, requested without a level list.
PARAM_SFC = [
    "10u", "10v", "2d", "2t",
    "msl", "skt", "sp", "tcw",
    "lsm", "z", "slor", "sdor",
]

# Soil parameters; run_forecast renames them to stl1/stl2/swvl1/swvl2.
PARAM_SOIL = ["vsw", "sot"]

# Pressure-level parameters; "gh" is converted to "z" in run_forecast.
PARAM_PL = ["gh", "t", "u", "v", "w", "q"]

# Pressure levels (hPa) requested for every PARAM_PL entry.
LEVELS = [
    1000, 925, 850, 700, 600, 500, 400,
    300, 250, 200, 150, 100, 50,
]

# Soil levels requested for every PARAM_SOIL entry.
SOIL_LEVELS = [1, 2]
# Most recent date reported by the ECMWF open-data client, looked up once when
# the module is imported (presumably a network request to the open-data server).
DEFAULT_DATE = OpendataClient().latest()

# Display labels for every variable offered in the UI, grouped by category.
# Surface and soil entries are spelled out; the pressure-level group starts
# empty and is populated by the loop below.
VARIABLE_GROUPS = {
    "Surface Variables": {
        "10u": "10m U Wind Component",
        "10v": "10m V Wind Component",
        "2d": "2m Dewpoint Temperature",
        "2t": "2m Temperature",
        "msl": "Mean Sea Level Pressure",
        "skt": "Skin Temperature",
        "sp": "Surface Pressure",
        "tcw": "Total Column Water",
        "lsm": "Land-Sea Mask",
        "z": "Surface Geopotential",
        "slor": "Slope of Sub-gridscale Orography",
        "sdor": "Standard Deviation of Orography",
    },
    "Soil Variables": {
        "stl1": "Soil Temperature Level 1",
        "stl2": "Soil Temperature Level 2",
        "swvl1": "Soil Water Volume Level 1",
        "swvl2": "Soil Water Volume Level 2",
    },
    "Pressure Level Variables": {},
}

# One entry per (variable, pressure level) pair, keyed "<var>_<level>",
# e.g. "t_850" -> "Temperature at 850hPa".
for var, var_name in {
    "t": "Temperature",
    "u": "U Wind Component",
    "v": "V Wind Component",
    "w": "Vertical Velocity",
    "q": "Specific Humidity",
    "z": "Geopotential",
}.items():
    for level in LEVELS:
        var_id = f"{var}_{level}"
        VARIABLE_GROUPS["Pressure Level Variables"][var_id] = f"{var_name} at {level}hPa"

# Single shared AIFS runner, created when the module is imported so every
# request reuses it. run_forecast() swaps it for a new runner when a
# different device is requested.
MODEL = SimpleRunner(
    "aifs-single-mse-1.0.ckpt",
    device="cuda",
)

def get_open_data(param, levelist=None, date=None):
    """Download ECMWF open-data fields for the two input times AIFS needs.

    Each field is fetched for ``date - 6h`` and ``date``, rolled by half its
    width along the longitude axis, and interpolated from the 0.25° regular
    grid to N320.

    Args:
        param: Parameter name(s) passed straight to the ``ecmwf-open-data``
            source.
        levelist: Optional level list. When non-empty, result keys are
            ``"<param>_<level>"``; otherwise they are just ``"<param>"``.
        date: Latest input date; defaults to ``DEFAULT_DATE``.

    Returns:
        dict mapping field name to ``np.stack`` of its per-date arrays, so the
        leading axis follows the fetch order (older date first).

    Raises:
        ValueError: if a downloaded field is not on the expected 721x1440 grid.
    """
    if levelist is None:
        levelist = []
    if date is None:
        date = DEFAULT_DATE

    fields = {}
    input_dates = [date - datetime.timedelta(hours=6), date]
    for input_date in input_dates:
        print(f"Fetching data for {input_date}")
        # sources can be seen https://earthkit-data.readthedocs.io/en/latest/guide/sources.html#id57
        data = ekd.from_source("ecmwf-open-data", date=input_date, param=param, levelist=levelist)
        for f in data:
            raw = f.to_numpy()  # decode once; reused for the check and the roll
            # Explicit check rather than ``assert`` so it still runs under ``python -O``.
            if raw.shape != (721, 1440):
                raise ValueError(
                    f"Unexpected grid shape {raw.shape} for {f.metadata('param')}; "
                    "expected (721, 1440)"
                )
            # NOTE(review): presumably re-centres longitudes from -180..180 to
            # 0..360 as the regridder expects — confirm against earthkit-regrid.
            values = np.roll(raw, -raw.shape[1] // 2, axis=1)
            values = ekr.interpolate(values, {"grid": (0.25, 0.25)}, {"grid": "N320"})
            name = f"{f.metadata('param')}_{f.metadata('levelist')}" if levelist else f.metadata("param")
            fields.setdefault(name, []).append(values)

    # Create a single matrix for each field.
    return {name: np.stack(arrays) for name, arrays in fields.items()}

def _collect_input_fields():
    """Fetch surface, soil and pressure-level fields under the model's variable names."""
    fields = {}

    # Surface fields
    fields.update(get_open_data(param=PARAM_SFC))

    # Soil fields arrive as sot/vsw per level; rename them to the stl*/swvl*
    # ids used in VARIABLE_GROUPS.
    soil_names = {
        'sot_1': 'stl1', 'sot_2': 'stl2',
        'vsw_1': 'swvl1', 'vsw_2': 'swvl2',
    }
    soil = get_open_data(param=PARAM_SOIL, levelist=SOIL_LEVELS)
    for name, values in soil.items():
        fields[soil_names[name]] = values

    # Pressure-level fields
    fields.update(get_open_data(param=PARAM_PL, levelist=LEVELS))

    # Convert geopotential height to geopotential by multiplying by standard
    # gravity (9.80665 m s^-2).
    for level in LEVELS:
        fields[f"z_{level}"] = fields.pop(f"gh_{level}") * 9.80665

    return fields


def run_forecast(date, lead_time, device):
    """Run AIFS from freshly downloaded open data and return the final state.

    Args:
        date: Date stored in the input state handed to the model.
        lead_time: Forecast length passed to ``MODEL.run`` (the UI slider
            supplies hours).
        device: Device for the runner, e.g. ``"cuda"``; the model is reloaded
            when this differs from ``MODEL.device``.

    Returns:
        The last state yielded by ``MODEL.run``.

    Raises:
        RuntimeError: if ``MODEL.run`` yields no states.
    """
    global MODEL

    # NOTE(review): get_open_data is called without a date here, so the fields
    # come from DEFAULT_DATE whatever ``date`` is; callers should pass
    # DEFAULT_DATE (as the UI does) so the state's date matches its data.
    fields = _collect_input_fields()
    input_state = dict(date=date, fields=fields)

    # If the device preference changed, reload the same checkpoint on the new device.
    if device != MODEL.device:
        MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device=device)

    # Keep only the most recent state instead of accumulating every
    # intermediate step in a list; earlier states can then be freed.
    final_state = None
    for final_state in MODEL.run(input_state=input_state, lead_time=lead_time):
        pass
    if final_state is None:
        raise RuntimeError(f"Model produced no forecast states for lead_time={lead_time!r}")
    return final_state

def plot_forecast(state, selected_variable):
    """Draw one forecast field as filled contours on a global map.

    Args:
        state: Forecast state as returned by ``run_forecast``; this reads its
            ``"latitudes"``, ``"longitudes"``, ``"date"`` and ``"fields"`` entries.
        selected_variable: Key of the field in ``state["fields"]`` to plot.

    Returns:
        matplotlib.figure.Figure with the map and a horizontal colorbar,
        suitable for returning to ``gr.Plot``.
    """
    # Build the Figure without pyplot: pyplot registers every figure it creates
    # and keeps it alive until plt.close(), so a server returning a fresh
    # figure per request would accumulate them indefinitely.
    from matplotlib.figure import Figure

    latitudes, longitudes = state["latitudes"], state["longitudes"]
    values = state["fields"][selected_variable]

    # Figure with a PlateCarree projection centred on 0° longitude, full globe.
    fig = Figure(figsize=(15, 8))
    ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree(central_longitude=0))
    ax.set_global()

    # Add map features
    ax.coastlines(resolution='50m')
    ax.add_feature(cfeature.BORDERS, linestyle=":", alpha=0.5)
    ax.gridlines(draw_labels=True)

    # Shift longitudes above 180 down by 360 so every point lies in [-180, 180].
    fixed_lons = np.where(longitudes > 180, longitudes - 360, longitudes)

    # Triangulate the scattered grid points so tricontourf can fill between them.
    triangulation = tri.Triangulation(fixed_lons, latitudes)
    contour = ax.tricontourf(
        triangulation,
        values,
        levels=20,
        transform=ccrs.PlateCarree(),
        cmap='RdBu_r',
    )

    ax.set_title(f"{selected_variable} at {state['date']}")
    fig.colorbar(contour, ax=ax, orientation='horizontal', pad=0.05)

    return fig

def _choice_sort_key(item):
    """Sort key for (var_id, label) pairs that orders pressure levels numerically.

    "t_850" -> ("t", 850), so levels sort 50, 100, ..., 1000 instead of the
    lexical "t_100", "t_1000", "t_150". Ids without a numeric "_<level>"
    suffix keep plain alphabetical order, e.g. "2t" -> ("2t", 0).
    """
    var_id = item[0]
    base, sep, level = var_id.rpartition("_")
    if sep and level.isdigit():
        return (base, int(level))
    return (var_id, 0)


# Create dropdown choices with groups
DROPDOWN_CHOICES = []
for group_name, variables in VARIABLE_GROUPS.items():
    # Group separator entry. NOTE(review): it carries the value None, so
    # selecting it passes None to the submit handler.
    DROPDOWN_CHOICES.append((f"── {group_name} ──", None))
    # Variables in this group, with pressure levels in numeric order.
    for var_id, desc in sorted(variables.items(), key=_choice_sort_key):
        DROPDOWN_CHOICES.append((f"{desc} ({var_id})", var_id))

with gr.Blocks(css="""
    .centered-header {
        text-align: center;
        margin-bottom: 20px;
    }
""") as demo:
    # Centered header section
    gr.Markdown(f"""
    # AIFS Weather Forecast
    
    Interactive visualization of ECMWF AIFS weather forecasts. Starting from the latest available data ({DEFAULT_DATE.strftime('%Y-%m-%d %H:%M UTC')}), 
    select how many hours ahead you want to forecast and which meteorological variable to visualize.
    """, elem_classes=["centered-header"])

    with gr.Row():
        # Controls column - takes up 1/3 of the width
        with gr.Column(scale=1):
            lead_time = gr.Slider(
                minimum=6,
                maximum=48,
                step=6,
                value=12,
                label="Forecast Hours Ahead"
            )
            variable = gr.Dropdown(
                choices=DROPDOWN_CHOICES,
                value="2t",
                label="Select Variable to Plot",
                info="Choose a meteorological variable to visualize"
            )

            # Add buttons in a row
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit", variant="primary")

        # Map column - takes up 2/3 of the width
        with gr.Column(scale=2):
            plot_output = gr.Plot()

    def update_plot(lead_time, variable):
        """Run a forecast for the chosen lead time and plot the chosen variable."""
        # Group-separator choices carry the value None. Reject them before
        # running the expensive forecast, instead of failing with a KeyError
        # in plot_forecast afterwards.
        if not variable:
            raise gr.Error("Please select a variable to plot (not a group heading).")
        state = run_forecast(DEFAULT_DATE, lead_time, "cuda")
        return plot_forecast(state, variable)

    def clear():
        """Reset the slider and dropdown to their defaults and empty the plot."""
        # Plain values work across Gradio versions; the gr.Slider.update /
        # gr.Dropdown.update helpers were removed in Gradio 4.
        return [
            12,    # lead_time slider
            "2t",  # variable dropdown
            None,  # clear the plot
        ]

    # Connect the buttons
    submit_btn.click(fn=update_plot, inputs=[lead_time, variable], outputs=plot_output)
    clear_btn.click(fn=clear, inputs=[], outputs=[lead_time, variable, plot_output])

demo.launch()