saburq commited on
Commit
a952d46
·
0 Parent(s):
Files changed (3) hide show
  1. .gitignore +2 -0
  2. app.py +109 -0
  3. notebook.py +184 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ aifs-single-mse-1.0.ckpt
2
+ flagged/
app.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
# Memory-related settings, applied ahead of the imports that follow
# (presumably so the libraries pick them up at start-up - keep this ordering).
os.environ.update(
    {
        'PYTORCH_CUDA_ALLOC_CONF': 'expandable_segments:True',
        'ANEMOI_INFERENCE_NUM_CHUNKS': '16',
    }
)
5
+
6
+ import gradio as gr
7
+ import datetime
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+ import cartopy.crs as ccrs
11
+ import cartopy.feature as cfeature
12
+ import matplotlib.tri as tri
13
+ from anemoi.inference.runners.simple import SimpleRunner
14
+ from ecmwf.opendata import Client as OpendataClient
15
+ import earthkit.data as ekd
16
+ import earthkit.regrid as ekr
17
+
18
# Request lists for the open-data retrievals (kept in sync with notebook.py).

# Parameters requested without a level list.
PARAM_SFC = [
    "10u", "10v", "2d", "2t", "msl", "skt",
    "sp", "tcw", "lsm", "z", "slor", "sdor",
]

# Soil parameters, requested on SOIL_LEVELS.
PARAM_SOIL = ["vsw", "sot"]
SOIL_LEVELS = [1, 2]

# Pressure-level parameters, requested on LEVELS.
PARAM_PL = ["gh", "t", "u", "v", "w", "q"]
LEVELS = [1000, 925, 850, 700, 600, 500, 400, 300, 250, 200, 150, 100, 50]

# Latest date reported by the open-data client; queried once, at import time.
DEFAULT_DATE = OpendataClient().latest()
25
+
26
def get_open_data(param, levelist=None, date=None):
    """Download ECMWF open-data fields for two consecutive 6-hourly steps.

    Fields are retrieved for ``date - 6h`` and ``date``, rolled by half of the
    longitude axis, interpolated from the 0.25 degree grid to N320 and stacked
    per field name.

    Args:
        param: Parameter short names to request.
        levelist: Levels to request. When empty or ``None`` the results are
            keyed by parameter name only, otherwise by ``"<param>_<level>"``.
        date: Datetime of the most recent step. Defaults to ``DEFAULT_DATE``.

    Returns:
        dict mapping each field name to an array stacked along a new first
        axis, in retrieval order (earlier step first).

    Raises:
        ValueError: If a retrieved field is not on the expected 721 x 1440 grid.
    """
    if date is None:
        date = DEFAULT_DATE
    # A fresh list per call avoids sharing one mutable default between calls.
    levelist = [] if levelist is None else levelist

    collected = {}
    for step_date in (date - datetime.timedelta(hours=6), date):
        data = ekd.from_source("ecmwf-open-data", date=step_date, param=param, levelist=levelist)
        for f in data:
            grid = f.to_numpy()
            # Explicit check instead of `assert`, which is stripped under -O.
            if grid.shape != (721, 1440):
                raise ValueError(
                    f"Unexpected grid shape {grid.shape} for {f.metadata('param')}; expected (721, 1440)"
                )
            # Roll by half of the longitude axis (notebook.py: shift from
            # -180..180 to 0..360) before regridding.
            values = np.roll(grid, -(grid.shape[1] // 2), axis=1)
            values = ekr.interpolate(values, {"grid": (0.25, 0.25)}, {"grid": "N320"})
            name = f"{f.metadata('param')}_{f.metadata('levelist')}" if levelist else f.metadata("param")
            collected.setdefault(name, []).append(values)

    # One stacked array per field name.
    return {name: np.stack(steps) for name, steps in collected.items()}


def run_forecast(date, lead_time, device):
    """Build the initial state for ``date`` and run the AIFS model.

    The initial conditions are downloaded for the same ``date`` that is stored
    in the input state, so the model is never given fields that belong to a
    different date than the one it is told it starts from.

    Args:
        date: Datetime of the initial conditions.
        lead_time: Forecast lead time, passed through to the runner.
        device: Device name handed to ``SimpleRunner`` (e.g. ``"cuda"``).

    Returns:
        The last state yielded by the runner.

    Raises:
        RuntimeError: If the runner yields no state at all.
    """
    fields = {}

    # Fields requested without a level list.
    fields.update(get_open_data(param=PARAM_SFC, date=date))

    # Soil fields are renamed to the keys used below (notebook.py notes that the
    # parameters were renamed since the model was trained).
    soil_names = {
        'sot_1': 'stl1', 'sot_2': 'stl2',
        'vsw_1': 'swvl1', 'vsw_2': 'swvl2',
    }
    soil = get_open_data(param=PARAM_SOIL, levelist=SOIL_LEVELS, date=date)
    for name, values in soil.items():
        fields[soil_names[name]] = values

    # Pressure-level fields.
    fields.update(get_open_data(param=PARAM_PL, levelist=LEVELS, date=date))

    # Convert geopotential height to geopotential.
    for level in LEVELS:
        fields[f"z_{level}"] = fields.pop(f"gh_{level}") * 9.80665

    input_state = dict(date=date, fields=fields)
    runner = SimpleRunner("aifs-single-mse-1.0.ckpt", device=device)

    # Only the final state is needed, so do not accumulate the earlier ones.
    final_state = None
    for final_state in runner.run(input_state=input_state, lead_time=lead_time):
        pass
    if final_state is None:
        raise RuntimeError("The forecast run produced no output state")
    return final_state
76
+
77
def plot_forecast(state):
    """Draw the ``100u`` field of a forecast state as a filled-contour map.

    Args:
        state: Forecast state providing ``"latitudes"``, ``"longitudes"``,
            ``"date"`` and a ``"fields"`` mapping that contains ``"100u"``.

    Returns:
        The matplotlib figure holding the map.
    """
    latitudes = state["latitudes"]
    # notebook.py shifts longitudes from 0-360 to -180..180 before
    # triangulating; apply the same shift so both plots agree.
    longitudes = np.asarray(state["longitudes"])
    longitudes = np.where(longitudes > 180, longitudes - 360, longitudes)
    values = state["fields"]["100u"]

    fig, ax = plt.subplots(figsize=(11, 6), subplot_kw={"projection": ccrs.PlateCarree()})
    ax.coastlines()
    ax.add_feature(cfeature.BORDERS, linestyle=":")

    triangulation = tri.Triangulation(longitudes, latitudes)
    contour = ax.tricontourf(triangulation, values, levels=20, transform=ccrs.PlateCarree(), cmap="RdBu")

    # Address the figure and axes directly rather than pyplot's implicit
    # "current figure", which is shared process-wide.
    ax.set_title(f"100m winds at {state['date']}")
    fig.colorbar(contour, ax=ax)
    return fig
88
+
89
def gradio_interface(date_str, lead_time, device):
    """Gradio callback: parse the date, run the forecast and plot the result.

    Args:
        date_str: Forecast date typed by the user, in ``YYYY-MM-DD`` format.
            Surrounding whitespace is ignored.
        lead_time: Lead time value from the slider, passed to ``run_forecast``.
        device: Compute device selected in the UI, passed to ``run_forecast``.

    Returns:
        The figure produced by ``plot_forecast``.

    Raises:
        gr.Error: If ``date_str`` is empty or not a valid ``YYYY-MM-DD`` date.
    """
    try:
        # An empty/None value falls through to the same user-facing error.
        date = datetime.datetime.strptime((date_str or "").strip(), "%Y-%m-%d")
    except ValueError as err:
        # Chain the parsing failure so it stays visible in server logs.
        raise gr.Error("Please enter a valid date in YYYY-MM-DD format") from err
    return plot_forecast(run_forecast(date, lead_time, device))
96
+
97
# Input widgets, listed in the positional order of gradio_interface's arguments.
_date_input = gr.Textbox(value=DEFAULT_DATE.strftime("%Y-%m-%d"), label="Forecast Date (YYYY-MM-DD)")
_lead_time_input = gr.Slider(minimum=6, maximum=48, step=6, value=12, label="Lead Time (Hours)")
_device_input = gr.Radio(choices=["cuda", "cpu"], value="cuda", label="Compute Device")

# Web UI: one form in, one plot out.
demo = gr.Interface(
    fn=gradio_interface,
    inputs=[_date_input, _lead_time_input, _device_input],
    outputs=gr.Plot(),
    title="AIFS Weather Forecast",
    description="Run ECMWF AIFS forecasts based on selected parameters.",
)

demo.launch()
notebook.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """run_AIFS_v1.ipynb
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/#fileId=https%3A//huggingface.co/ecmwf/aifs-single-1.0/blob/94409c197d36d39c380467c6f3130a2be6eb722b/run_AIFS_v1.ipynb
8
+
9
+ This notebook runs ECMWF's aifs-single-v1 data-driven model, using ECMWF's [open data](https://www.ecmwf.int/en/forecasts/datasets/open-data) dataset and the [anemoi-inference](https://anemoi-inference.readthedocs.io/en/latest/apis/level1.html) package.
10
+
11
+ # 1. Install Required Packages and Imports
12
+ """
13
+
14
+ # Uncomment the lines below to install the required packages
15
+
16
+ # !pip install -q anemoi-inference[huggingface]==0.4.9 anemoi-models==0.3.1
17
+ # !pip install -q earthkit-regrid==0.4.0 ecmwf-opendata
18
+ # !pip install -q flash_attn
19
+
20
+ import datetime
21
+ from collections import defaultdict
22
+
23
+ import numpy as np
24
+ import earthkit.data as ekd
25
+ import earthkit.regrid as ekr
26
+
27
+ from anemoi.inference.runners.simple import SimpleRunner
28
+ from anemoi.inference.outputs.printer import print_state
29
+
30
+ from ecmwf.opendata import Client as OpendataClient
31
+
32
+ """# 2. Retrieve Initial Conditions from ECMWF Open Data
33
+
34
+ ### List of parameters to retrieve from ECMWF open data
35
+ """
36
+
37
# Parameters requested without a level list.
PARAM_SFC = [
    "10u", "10v", "2d", "2t", "msl", "skt",
    "sp", "tcw", "lsm", "z", "slor", "sdor",
]

# Soil parameters, requested on SOIL_LEVELS.
PARAM_SOIL = ["vsw", "sot"]
SOIL_LEVELS = [1, 2]

# Pressure-level parameters, requested on LEVELS.
PARAM_PL = ["gh", "t", "u", "v", "w", "q"]
LEVELS = [1000, 925, 850, 700, 600, 500, 400, 300, 250, 200, 150, 100, 50]
42
+
43
+ """### Select a date"""
44
+
45
+ DATE = OpendataClient().latest()
46
+
47
+ print("Initial date is", DATE)
48
+
49
+ """### Get the data from the ECMWF Open Data API"""
50
+
51
def get_open_data(param, levelist=None, date=None):
    """Download ECMWF open-data fields for two consecutive 6-hourly steps.

    Fields are retrieved for ``date - 6h`` and ``date``, shifted in longitude,
    interpolated from the 0.25 degree grid to N320 and stacked per field name.

    Args:
        param: Parameter short names to request.
        levelist: Levels to request. When empty or ``None`` the results are
            keyed by parameter name only, otherwise by ``"<param>_<level>"``.
        date: Datetime of the most recent step. Defaults to the global ``DATE``.

    Returns:
        A plain dict mapping each field name to an array stacked along a new
        first axis, in retrieval order (earlier step first).

    Raises:
        ValueError: If a retrieved field is not on the expected 721 x 1440 grid.
    """
    if date is None:
        date = DATE
    # A fresh list per call avoids sharing one mutable default between calls.
    levelist = [] if levelist is None else levelist

    collected = defaultdict(list)
    # Get the data for the current date and the previous date
    for step_date in (date - datetime.timedelta(hours=6), date):
        data = ekd.from_source("ecmwf-open-data", date=step_date, param=param, levelist=levelist)
        for f in data:
            grid = f.to_numpy()
            # Explicit check instead of `assert`, which is stripped under -O.
            if grid.shape != (721, 1440):
                raise ValueError(
                    f"Unexpected grid shape {grid.shape} for {f.metadata('param')}; expected (721, 1440)"
                )
            # Open data is between -180 and 180, we need to shift it to 0-360
            values = np.roll(grid, -(grid.shape[1] // 2), axis=1)
            # Interpolate the data from the 0.25 degree grid to N320
            values = ekr.interpolate(values, {"grid": (0.25, 0.25)}, {"grid": "N320"})
            name = f"{f.metadata('param')}_{f.metadata('levelist')}" if levelist else f.metadata("param")
            collected[name].append(values)

    # Create a single matrix for each parameter. A plain dict is returned so a
    # lookup of a missing name raises KeyError instead of silently creating an
    # empty entry.
    return {name: np.stack(steps) for name, steps in collected.items()}
71
+
72
+ """### Get Input Fields"""
73
+
74
fields = {}

"""#### Add the single levels fields"""

fields.update(get_open_data(param=PARAM_SFC))

soil = get_open_data(param=PARAM_SOIL, levelist=SOIL_LEVELS)

"""Soil parameters have been renamed since training this model, we need to rename to the original names"""

mapping = {
    'sot_1': 'stl1',
    'sot_2': 'stl2',
    'vsw_1': 'swvl1',
    'vsw_2': 'swvl2',
}
# Store every soil field under its original (training-time) name.
fields.update({mapping[name]: values for name, values in soil.items()})

"""#### Add the pressure levels fields"""

fields.update(get_open_data(param=PARAM_PL, levelist=LEVELS))

"""#### Convert geopotential height into geopotential"""

# Replace each gh_<level> entry with z_<level> = gh * 9.80665
for level in LEVELS:
    fields[f"z_{level}"] = fields.pop(f"gh_{level}") * 9.80665

"""### Create Initial State"""

input_state = {"date": DATE, "fields": fields}
103
+
104
+ """# 3. Load the Model and Run the Forecast
105
+
106
+ ### Download the Model's Checkpoint from Hugging Face & create a Runner
107
+ """
108
+
109
# Load the checkpoint from a local file. The previous first assignment,
# {"huggingface": "ecmwf/aifs-single-1.0"}, was overwritten immediately and
# never used; assign that mapping instead to fetch the weights from Hugging Face.
checkpoint = 'aifs-single-mse-1.0.ckpt'
112
+
113
+ """To reduce the memory usage of the model, one can set certain environment variables, such as the number of chunks of the model's mapper.
114
+ Please refer to:
115
+ - https://anemoi.readthedocs.io/projects/models/en/latest/modules/layers.html#anemoi-inference-num-chunks
116
+ - https://pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf
117
+
118
+ for more information. To do so, you can use the code below:
119
+ ```
120
+ import os
121
+ os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
122
+ os.environ['ANEMOI_INFERENCE_NUM_CHUNKS']='16'
123
+ ```
124
+ """
125
+
126
+ runner = SimpleRunner(checkpoint, device="cuda")
127
+
128
+ """**Note - changing the device from GPU to CPU**
129
+
130
+ - Running the transformer model used on the CPU is tricky, it depends on the FlashAttention library which only supports Nvidia and AMD GPUs, and is optimised for performance and memory usage
131
+ - In newer versions of anemoi-models, v0.4.2 and above, there is an option to switch off flash attention and use PyTorch's Scaled Dot Product Attention (SDPA). The code snippet below shows how to overwrite a model from a checkpoint to use SDPA. Unfortunately it's not optimised for memory usage in the same way, leading to much greater memory usage. Please refer to https://github.com/ecmwf/anemoi-inference/issues/119 for more details
132
+
133
+ #### Run the forecast
134
+ """
135
+
136
+ for state in runner.run(input_state=input_state, lead_time=12):
137
+ print_state(state)
138
+
139
+ """**Note**
140
+ Due to the non-determinism of GPUs, users will be unable to exactly reproduce an official AIFS forecast when running AIFS Single themselves.
141
+ If you want to enforce determinism at the GPU level, you can do so by applying the following settings:
142
+
143
+ ```
144
+ #First in your terminal
145
+ export CUBLAS_WORKSPACE_CONFIG=:4096:8
146
+
147
+ #And then before running inference:
148
+ import torch
149
+ torch.backends.cudnn.benchmark = False
150
+ torch.backends.cudnn.deterministic = True
151
+ torch.use_deterministic_algorithms(True)
152
+
153
+ ```
154
+ Using the above will lead to a significant increase in runtime. Additionally, the input conditions here are provided by open data. The reprojection performed on open data differs from the one carried out at the operational level, hence small differences in the forecast are expected.
155
+
156
+ # 4. Inspect the generated forecast
157
+
158
+ #### Plot a field
159
+ """
160
+
161
+ import matplotlib.pyplot as plt
162
+ import cartopy.crs as ccrs
163
+ import cartopy.feature as cfeature
164
+ import matplotlib.tri as tri
165
+
166
def fix(lons):
    """Shift longitudes from the 0-360 convention to -180..180.

    Values greater than 180 have 360 subtracted; all other values are
    returned unchanged (so exactly 180 stays 180).

    Args:
        lons: Longitudes in degrees, as a scalar or any array-like
            (NumPy array, list, tuple).

    Returns:
        A NumPy array with the same shape as ``lons``.
    """
    # Accept plain sequences as well as arrays; `list > 180` would raise.
    lons = np.asarray(lons)
    return np.where(lons > 180, lons - 360, lons)
169
+
170
# Grid coordinates and the field to draw, taken from the last state the
# forecast loop produced.
latitudes, longitudes = state["latitudes"], state["longitudes"]
values = state["fields"]["100u"]

# Map axes with coastlines and dotted country borders.
map_kwargs = {"projection": ccrs.PlateCarree()}
fig, ax = plt.subplots(figsize=(11, 6), subplot_kw=map_kwargs)
ax.coastlines()
ax.add_feature(cfeature.BORDERS, linestyle=":")

# Triangulate the grid points after shifting the longitudes with fix().
triangulation = tri.Triangulation(fix(longitudes), latitudes)

contour = ax.tricontourf(
    triangulation,
    values,
    levels=20,
    transform=ccrs.PlateCarree(),
    cmap="RdBu",
)
cbar = fig.colorbar(
    contour,
    ax=ax,
    orientation="vertical",
    shrink=0.7,
    label="100u",
)

plt.title("100m winds (100u) at {}".format(state["date"]))
plt.show()