saburq commited on
Commit
e9a1c0f
·
1 Parent(s): c3b681a

add animation

Browse files
Files changed (1) hide show
  1. app.py +162 -55
app.py CHANGED
@@ -1,4 +1,6 @@
1
  import os
 
 
2
  # Set memory optimization environment variables
3
  os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
4
  os.environ['ANEMOI_INFERENCE_NUM_CHUNKS'] = '16'
@@ -14,6 +16,7 @@ from anemoi.inference.runners.simple import SimpleRunner
14
  from ecmwf.opendata import Client as OpendataClient
15
  import earthkit.data as ekd
16
  import earthkit.regrid as ekr
 
17
 
18
  # Define parameters (updating to match notebook.py)
19
  PARAM_SFC = ["10u", "10v", "2d", "2t", "msl", "skt", "sp", "tcw", "lsm", "z", "slor", "sdor"]
@@ -66,6 +69,11 @@ for var in ["t", "u", "v", "w", "q", "z"]:
66
  # Load the model once at startup
67
  MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device="cuda") # Default to CUDA
68
 
 
 
 
 
 
69
  def get_open_data(param, levelist=[]):
70
  fields = {}
71
  # Get the data for the current date and the previous date
@@ -90,6 +98,91 @@ def get_open_data(param, levelist=[]):
90
 
91
  return fields
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  def run_forecast(date, lead_time, device):
94
  # Get all required fields
95
  fields = {}
@@ -122,42 +215,24 @@ def run_forecast(date, lead_time, device):
122
  if device != MODEL.device:
123
  MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device=device)
124
 
125
- results = []
 
126
  for state in MODEL.run(input_state=input_state, lead_time=lead_time):
127
- results.append(state)
128
- return results[-1]
129
 
130
- def plot_forecast(state, selected_variable):
131
- latitudes, longitudes = state["latitudes"], state["longitudes"]
132
- values = state["fields"][selected_variable]
133
-
134
- # Create figure with specific projection centered on 0°
135
- fig = plt.figure(figsize=(15, 8))
136
- ax = plt.axes(projection=ccrs.PlateCarree(central_longitude=0))
137
-
138
- ax.set_global()
139
- ax.set_extent([-180, 180, -90, 90], crs=ccrs.PlateCarree())
140
-
141
- # Add map features
142
- ax.coastlines(resolution='50m')
143
- ax.add_feature(cfeature.BORDERS, linestyle=":", alpha=0.5)
144
- ax.gridlines(draw_labels=True)
145
-
146
- # Fix longitudes to be -180 to 180
147
- fixed_lons = np.where(longitudes > 180, longitudes - 360, longitudes)
148
-
149
- # Create triangulation with fixed longitudes
150
- triangulation = tri.Triangulation(fixed_lons, latitudes)
151
-
152
- # Create the contour plot
153
- contour = ax.tricontourf(triangulation, values, levels=20,
154
- transform=ccrs.PlateCarree(),
155
- cmap='RdBu_r')
156
-
157
- plt.title(f"{selected_variable} at {state['date']}")
158
- plt.colorbar(contour, orientation='horizontal', pad=0.05)
159
-
160
- return fig
161
 
162
  # Create dropdown choices with groups
163
  DROPDOWN_CHOICES = []
@@ -173,17 +248,31 @@ with gr.Blocks(css="""
173
  text-align: center;
174
  margin-bottom: 20px;
175
  }
 
 
 
 
 
 
 
 
 
 
 
176
  """) as demo:
177
- # Centered header section
178
  gr.Markdown(f"""
179
  # AIFS Weather Forecast
180
 
181
- Interactive visualization of ECMWF AIFS weather forecasts. Starting from the latest available data ({DEFAULT_DATE.strftime('%Y-%m-%d %H:%M UTC')}),
 
 
182
  select how many hours ahead you want to forecast and which meteorological variable to visualize.
183
- """, elem_classes=["centered-header"])
 
184
 
 
185
  with gr.Row():
186
- # Controls column - takes up 1/3 of the width
187
  with gr.Column(scale=1):
188
  lead_time = gr.Slider(
189
  minimum=6,
@@ -195,34 +284,52 @@ with gr.Blocks(css="""
195
  variable = gr.Dropdown(
196
  choices=DROPDOWN_CHOICES,
197
  value="2t",
198
- label="Select Variable to Plot",
199
- info="Choose a meteorological variable to visualize"
200
  )
201
-
202
- # Add buttons in a row
203
  with gr.Row():
204
  clear_btn = gr.Button("Clear")
205
  submit_btn = gr.Button("Submit", variant="primary")
206
 
207
- # Map column - takes up 2/3 of the width
208
  with gr.Column(scale=2):
209
- plot_output = gr.Plot()
210
 
211
- # Connect the inputs to the forecast function
212
- def update_plot(lead_time, variable):
213
- state = run_forecast(DEFAULT_DATE, lead_time, "cuda")
214
- return plot_forecast(state, variable)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
- # Clear function to reset to defaults
217
  def clear():
218
  return [
219
- 12, # Reset slider to default value
220
- "2t", # Reset dropdown to default value
221
- None # Clear the plot
222
  ]
223
 
224
- # Connect the buttons
225
- submit_btn.click(fn=update_plot, inputs=[lead_time, variable], outputs=plot_output)
226
- clear_btn.click(fn=clear, inputs=[], outputs=[lead_time, variable, plot_output])
 
 
 
 
 
 
 
 
227
 
228
  demo.launch()
 
1
  import os
2
+ import tempfile
3
+ from pathlib import Path
4
  # Set memory optimization environment variables
5
  os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
6
  os.environ['ANEMOI_INFERENCE_NUM_CHUNKS'] = '16'
 
16
  from ecmwf.opendata import Client as OpendataClient
17
  import earthkit.data as ekd
18
  import earthkit.regrid as ekr
19
+ import matplotlib.animation as animation
20
 
21
  # Define parameters (updating to match notebook.py)
22
  PARAM_SFC = ["10u", "10v", "2d", "2t", "msl", "skt", "sp", "tcw", "lsm", "z", "slor", "sdor"]
 
69
  # Load the model once at startup
70
  MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device="cuda") # Default to CUDA
71
 
72
+ # Create and set custom temp directory
73
+ TEMP_DIR = Path("./gradio_temp")
74
+ TEMP_DIR.mkdir(exist_ok=True)
75
+ os.environ['GRADIO_TEMP_DIR'] = str(TEMP_DIR)
76
+
77
  def get_open_data(param, levelist=[]):
78
  fields = {}
79
  # Get the data for the current date and the previous date
 
98
 
99
  return fields
100
 
101
+ def plot_forecast_animation(states, selected_variable):
102
+ # Setup the figure and axis
103
+ fig = plt.figure(figsize=(15, 8))
104
+ ax = plt.axes(projection=ccrs.PlateCarree(central_longitude=0))
105
+
106
+ # Get the first state to setup the plot
107
+ first_state = states[0]
108
+ latitudes, longitudes = first_state["latitudes"], first_state["longitudes"]
109
+ fixed_lons = np.where(longitudes > 180, longitudes - 360, longitudes)
110
+ triangulation = tri.Triangulation(fixed_lons, latitudes)
111
+
112
+ # Find global min/max for consistent colorbar
113
+ all_values = [state["fields"][selected_variable] for state in states]
114
+ vmin, vmax = np.min(all_values), np.max(all_values)
115
+
116
+ # Create a single colorbar that will be reused
117
+ contour = None
118
+ cbar_ax = None
119
+
120
+ def update(frame):
121
+ nonlocal contour, cbar_ax
122
+ ax.clear()
123
+
124
+ # Set map features
125
+ ax.set_global()
126
+ ax.set_extent([-180, 180, -85, 85], crs=ccrs.PlateCarree())
127
+ ax.coastlines(resolution='50m')
128
+ ax.add_feature(cfeature.BORDERS, linestyle=":", alpha=0.5)
129
+ ax.gridlines(draw_labels=True)
130
+
131
+ state = states[frame]
132
+ values = state["fields"][selected_variable]
133
+
134
+ # Clear the previous colorbar axis if it exists
135
+ if cbar_ax:
136
+ cbar_ax.remove()
137
+
138
+ # Create new contour plot
139
+ contour = ax.tricontourf(triangulation, values,
140
+ levels=20, transform=ccrs.PlateCarree(),
141
+ cmap='RdBu_r', vmin=vmin, vmax=vmax)
142
+
143
+ # Create new colorbar
144
+ cbar_ax = fig.add_axes([0.1, 0.05, 0.8, 0.03]) # [left, bottom, width, height]
145
+ plt.colorbar(contour, cax=cbar_ax, orientation='horizontal')
146
+
147
+ # Format the date string properly
148
+ forecast_time = state["date"]
149
+ if isinstance(forecast_time, str):
150
+ try:
151
+ forecast_time = datetime.datetime.strptime(forecast_time, "%Y-%m-%d %H:%M:%S")
152
+ except ValueError:
153
+ try:
154
+ forecast_time = datetime.datetime.strptime(forecast_time, "%Y-%m-%d %H:%M:%S.%f")
155
+ except ValueError:
156
+ forecast_time = DEFAULT_DATE
157
+
158
+ time_str = forecast_time.strftime("%Y-%m-%d %H:%M UTC")
159
+
160
+ # Get variable description from VARIABLE_GROUPS
161
+ var_desc = None
162
+ for group in VARIABLE_GROUPS.values():
163
+ if selected_variable in group:
164
+ var_desc = group[selected_variable]
165
+ break
166
+ var_name = var_desc if var_desc else selected_variable
167
+
168
+ ax.set_title(f"{var_name} - {time_str}")
169
+
170
+ # Create animation
171
+ anim = animation.FuncAnimation(
172
+ fig, update,
173
+ frames=len(states),
174
+ interval=1000, # 1 second between frames
175
+ repeat=True,
176
+ blit=False # Must be False to update the colorbar
177
+ )
178
+
179
+ # Save as MP4
180
+ temp_file = str(TEMP_DIR / f"forecast_{datetime.datetime.now().timestamp()}.mp4")
181
+ anim.save(temp_file, writer='ffmpeg', fps=1)
182
+ plt.close()
183
+
184
+ return temp_file
185
+
186
  def run_forecast(date, lead_time, device):
187
  # Get all required fields
188
  fields = {}
 
215
  if device != MODEL.device:
216
  MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device=device)
217
 
218
+ # Collect all states instead of just the last one
219
+ states = []
220
  for state in MODEL.run(input_state=input_state, lead_time=lead_time):
221
+ states.append(state)
222
+ return states
223
 
224
+ def update_plot(lead_time, variable):
225
+ cleanup_old_files() # Clean up old files before creating new ones
226
+ states = run_forecast(DEFAULT_DATE, lead_time, "cuda")
227
+ return plot_forecast_animation(states, variable)
228
+
229
+ # Add cleanup function for old files
230
+ def cleanup_old_files():
231
+ # Remove files older than 1 hour
232
+ current_time = datetime.datetime.now().timestamp()
233
+ for file in TEMP_DIR.glob("*.mp4"): # Changed from *.gif to *.mp4
234
+ if current_time - file.stat().st_mtime > 3600: # 1 hour in seconds
235
+ file.unlink(missing_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
 
237
  # Create dropdown choices with groups
238
  DROPDOWN_CHOICES = []
 
248
  text-align: center;
249
  margin-bottom: 20px;
250
  }
251
+ .subtitle {
252
+ font-size: 1.2em;
253
+ line-height: 1.5;
254
+ margin: 20px 0;
255
+ }
256
+ .footer {
257
+ text-align: center;
258
+ padding: 20px;
259
+ margin-top: 20px;
260
+ border-top: 1px solid #eee;
261
+ }
262
  """) as demo:
263
+ # Header section
264
  gr.Markdown(f"""
265
  # AIFS Weather Forecast
266
 
267
+ <div class="subtitle">
268
+ Interactive visualization of ECMWF AIFS weather forecasts.<br>
269
+ Starting from the latest available data ({DEFAULT_DATE.strftime('%Y-%m-%d %H:%M UTC')}),<br>
270
  select how many hours ahead you want to forecast and which meteorological variable to visualize.
271
+ </div>
272
+ """)
273
 
274
+ # Main content
275
  with gr.Row():
 
276
  with gr.Column(scale=1):
277
  lead_time = gr.Slider(
278
  minimum=6,
 
284
  variable = gr.Dropdown(
285
  choices=DROPDOWN_CHOICES,
286
  value="2t",
287
+ label="Select Variable to Plot"
 
288
  )
 
 
289
  with gr.Row():
290
  clear_btn = gr.Button("Clear")
291
  submit_btn = gr.Button("Submit", variant="primary")
292
 
 
293
  with gr.Column(scale=2):
294
+ animation_output = gr.Video()
295
 
296
+ # Footer with fork instructions and model reference
297
+ gr.Markdown("""
298
+ <div class="footer">
299
+ <h3>Want to run this on your own?</h3>
300
+ You can fork this space and run it yourself:
301
+
302
+ 1. Visit <a href="https://huggingface.co/spaces/geobase/aifs-forecast" target="_blank">https://huggingface.co/spaces/geobase/aifs-forecast</a>\n
303
+ 2. Click the "Duplicate this Space" button in the top right\n
304
+ 3. Select your hardware requirements (GPU recommended)\n
305
+ 4. Wait for your copy to deploy
306
+
307
+ <h3>Model Information</h3>
308
+ This demo uses the <a href="https://huggingface.co/ecmwf/aifs-single-1.0" target="_blank">AIFS Single 1.0</a> model from ECMWF,
309
+ which is their first operationally supported Artificial Intelligence Forecasting System. The model produces highly skilled forecasts
310
+ for upper-air variables, surface weather parameters, and tropical cyclone tracks.
311
+
312
+ Note: If you encounter any issues with this demo, trying your own fork might work better!
313
+ </div>
314
+ """)
315
 
 
316
  def clear():
317
  return [
318
+ 12,
319
+ "2t",
320
+ None
321
  ]
322
 
323
+ # Connect the inputs to the forecast function
324
+ submit_btn.click(
325
+ fn=update_plot,
326
+ inputs=[lead_time, variable],
327
+ outputs=animation_output
328
+ )
329
+ clear_btn.click(
330
+ fn=clear,
331
+ inputs=[],
332
+ outputs=[lead_time, variable, animation_output]
333
+ )
334
 
335
  demo.launch()