File size: 4,080 Bytes
a952d46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
# Memory-related runtime settings. They are set here, ahead of the third-party
# imports below, presumably so those libraries see them when they initialise.
#   PYTORCH_CUDA_ALLOC_CONF      -> PyTorch CUDA caching-allocator options
#                                   (expandable segments reduce fragmentation).
#   ANEMOI_INFERENCE_NUM_CHUNKS  -> presumably read by anemoi-inference to split
#                                   model work into chunks; TODO confirm.
os.environ.update(
    {
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        "ANEMOI_INFERENCE_NUM_CHUNKS": "16",
    }
)

import gradio as gr
import datetime
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.tri as tri
from anemoi.inference.runners.simple import SimpleRunner
from ecmwf.opendata import Client as OpendataClient
import earthkit.data as ekd
import earthkit.regrid as ekr

# Variables requested from ECMWF open data to build the AIFS initial state.

# Single-level (surface) parameters, requested without a level list.
PARAM_SFC = [
    "10u", "10v", "2d", "2t", "msl", "skt",
    "sp", "tcw", "lsm", "z", "slor", "sdor",
]

# Soil parameters, requested on SOIL_LEVELS and renamed in run_forecast.
PARAM_SOIL = ["vsw", "sot"]

# Pressure-level parameters, requested on LEVELS.
PARAM_PL = ["gh", "t", "u", "v", "w", "q"]

LEVELS = [
    1000, 925, 850, 700, 600,
    500, 400, 300, 250, 200,
    150, 100, 50,
]
SOIL_LEVELS = [1, 2]

# Reference time reported by the open-data client's ``latest()``. Evaluated once,
# at import time, so importing this module requires network access.
DEFAULT_DATE = OpendataClient().latest()

def get_open_data(param, levelist=None, date=None):
    """Download ECMWF open data for ``param`` and regrid it to N320.

    Two times are fetched, ``date - 6h`` and ``date`` (in that order), and the
    fields for each name are stacked along a new leading axis in that order.

    Args:
        param: List of parameter short names to request.
        levelist: Optional list of levels. When non-empty, result keys are
            ``"<param>_<level>"``; otherwise they are the bare parameter name.
        date: The later of the two times to fetch. Defaults to
            ``DEFAULT_DATE``.

    Returns:
        Dict mapping field name to ``np.stack`` of its per-time arrays.

    Raises:
        ValueError: If a downloaded field is not a 721x1440 array.
    """
    # ``None`` instead of a shared mutable ``[]`` default.
    levelist = [] if levelist is None else levelist
    date = DEFAULT_DATE if date is None else date

    grouped = {}
    for run_time in (date - datetime.timedelta(hours=6), date):
        data = ekd.from_source("ecmwf-open-data", date=run_time, param=param, levelist=levelist)
        for f in data:
            values = f.to_numpy()
            # Explicit check rather than ``assert``, which is stripped under -O.
            if values.shape != (721, 1440):
                raise ValueError(
                    f"Unexpected grid shape {values.shape} for field "
                    f"{f.metadata('param')!r}; expected (721, 1440)"
                )
            # Roll by half the longitude axis to move the grid's starting
            # longitude (presumably from -180 to 0 degrees; TODO confirm).
            values = np.roll(values, -values.shape[1] // 2, axis=1)
            # Regrid from the regular 0.25-degree grid to N320.
            values = ekr.interpolate(values, {"grid": (0.25, 0.25)}, {"grid": "N320"})
            name = f"{f.metadata('param')}_{f.metadata('levelist')}" if levelist else f.metadata("param")
            grouped.setdefault(name, []).append(values)

    # One stacked matrix per field name, ordered previous time -> ``date``.
    return {name: np.stack(arrays) for name, arrays in grouped.items()}


def run_forecast(date, lead_time, device):
    """Build the initial state from open data for ``date`` and run AIFS.

    The input fields are fetched for ``date`` itself (and six hours earlier),
    so the data always matches the date the input state is labelled with.

    Args:
        date: Initial time of the forecast; must be a run still available
            from the ECMWF open-data service.
        lead_time: Forecast length, passed through to ``SimpleRunner.run``.
        device: Device string for ``SimpleRunner`` (``"cuda"`` or ``"cpu"``).

    Returns:
        The last state yielded by ``runner.run``.

    Raises:
        RuntimeError: If the runner yields no states.
    """
    fields = {}

    # Surface fields.
    fields.update(get_open_data(param=PARAM_SFC, date=date))

    # Soil fields, renamed to the names used in the input state.
    soil_names = {
        "sot_1": "stl1", "sot_2": "stl2",
        "vsw_1": "swvl1", "vsw_2": "swvl2",
    }
    soil = get_open_data(param=PARAM_SOIL, levelist=SOIL_LEVELS, date=date)
    fields.update({soil_names[name]: values for name, values in soil.items()})

    # Pressure-level fields.
    fields.update(get_open_data(param=PARAM_PL, levelist=LEVELS, date=date))

    # Convert geopotential height (gh) to geopotential (z) using standard gravity.
    for level in LEVELS:
        fields[f"z_{level}"] = fields.pop(f"gh_{level}") * 9.80665

    input_state = dict(date=date, fields=fields)
    runner = SimpleRunner("aifs-single-mse-1.0.ckpt", device=device)

    # Only the final state is returned, so don't keep earlier ones alive --
    # each yielded state carries its own output fields.
    last_state = None
    for last_state in runner.run(input_state=input_state, lead_time=lead_time):
        pass
    if last_state is None:
        raise RuntimeError(f"Forecast runner produced no states for lead_time={lead_time!r}")
    return last_state

def plot_forecast(state, param="100u"):
    """Render one field of a forecast state as a filled-contour world map.

    Args:
        state: Forecast state providing ``"latitudes"``, ``"longitudes"``,
            ``"date"`` and a ``"fields"`` mapping that contains ``param``.
        param: Name of the field to plot. Defaults to ``"100u"``.

    Returns:
        The matplotlib figure. It is closed in pyplot before being returned,
        so repeated requests do not accumulate open figures; it can still be
        saved/rendered by the caller.
    """
    latitudes = np.asarray(state["latitudes"])
    # Map longitudes into [-180, 180] to match the PlateCarree map. The inputs
    # were rolled onto a 0..360-style grid, so outputs may use that convention.
    longitudes = np.asarray(state["longitudes"])
    longitudes = np.where(longitudes > 180, longitudes - 360, longitudes)
    values = state["fields"][param]

    fig, ax = plt.subplots(figsize=(11, 6), subplot_kw={"projection": ccrs.PlateCarree()})
    ax.coastlines()
    ax.add_feature(cfeature.BORDERS, linestyle=":")

    # RdBu is diverging: use levels symmetric about zero so zero maps to the
    # neutral middle colour. Fall back to automatic levels for all-zero/NaN data.
    vmax = float(np.nanmax(np.abs(values)))
    levels = np.linspace(-vmax, vmax, 21) if vmax > 0 else 20

    triangulation = tri.Triangulation(longitudes, latitudes)
    contour = ax.tricontourf(triangulation, values, levels=levels, transform=ccrs.PlateCarree(), cmap="RdBu")

    label = "100m u-component of wind" if param == "100u" else param
    ax.set_title(f"{label} at {state['date']}")
    fig.colorbar(contour, ax=ax)

    # Detach from pyplot's registry; otherwise every request leaks a figure.
    plt.close(fig)
    return fig

def gradio_interface(date_str, lead_time, device):
    """Gradio callback: parse the UI inputs, run the forecast and plot it.

    Args:
        date_str: Forecast date entered by the user, ``YYYY-MM-DD``.
        lead_time: Lead time in hours from the slider. Coerced to ``int``
            because the slider may deliver a float.
        device: Compute device chosen in the UI (``"cuda"`` or ``"cpu"``).

    Returns:
        The matplotlib figure for the ``gr.Plot`` output.

    Raises:
        gr.Error: On an unparsable date, or when the forecast pipeline fails,
            so the user sees a readable message in the UI.
    """
    import logging

    try:
        date = datetime.datetime.strptime((date_str or "").strip(), "%Y-%m-%d")
    except ValueError:
        raise gr.Error("Please enter a valid date in YYYY-MM-DD format") from None

    # UI boundary: log the full traceback server-side, show a message to the user.
    try:
        state = run_forecast(date, int(lead_time), device)
    except Exception as err:
        logging.getLogger(__name__).exception("Forecast failed")
        raise gr.Error(f"Forecast failed: {err}") from err
    return plot_forecast(state)

# Gradio UI: date, lead time and device inputs; a single plot output.
demo = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(value=DEFAULT_DATE.strftime("%Y-%m-%d"), label="Forecast Date (YYYY-MM-DD)"),
        gr.Slider(minimum=6, maximum=48, step=6, value=12, label="Lead Time (Hours)"),
        gr.Radio(choices=["cuda", "cpu"], value="cuda", label="Compute Device"),
    ],
    outputs=gr.Plot(),
    title="AIFS Weather Forecast",
    description="Run ECMWF AIFS forecasts based on selected parameters.",
)

# Start the server only when executed as a script, so importing this module
# (e.g. from tests or tooling) does not launch the app as a side effect.
if __name__ == "__main__":
    demo.launch()