saburq committed · Commit 9ebc452 · 1 parent: efc88b8

add variables

Files changed (3):
  1. .gitignore +29 -0
  2. app.py +244 -80
  3. gradio_temp/.keep +0 -0
.gitignore CHANGED
@@ -1,3 +1,32 @@
 aifs-single-mse-1.0.ckpt
 flagged/
 gradio_temp/*
+
+# Ignore all files in temp directories except .keep
+gradio_temp/data_cache/*
+!gradio_temp/data_cache/.keep
+
+gradio_temp/forecasts/*
+!gradio_temp/forecasts/.keep
+
+# Python cache files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# Environment directories
+.env
+.venv
+env/
+venv/
+ENV/
+
+# IDE directories
+.idea/
+.vscode/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# Logs
+*.log
app.py CHANGED
@@ -23,6 +23,8 @@ import pickle
 import json
 from typing import List, Dict, Any
 import logging
+import xarray as xr
+import pandas as pd
 
 # Configure logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -71,7 +73,7 @@ for var in ["t", "u", "v", "w", "q", "z"]:
         "q": "Specific Humidity",
         "z": "Geopotential"
     }[var]
-
+
     for level in LEVELS:
        var_id = f"{var}_{level}"
        VARIABLE_GROUPS["Pressure Level Variables"][var_id] = f"{var_name} at {level}hPa"
@@ -99,9 +101,7 @@ def get_cache_key(date: datetime.datetime, params: List[str], levellist: List[int]
 
 def get_cache_path(cache_key: str) -> Path:
     """Get the path to the cache file"""
-    cache_dir = TEMP_DIR / "data_cache"
-    cache_dir.mkdir(exist_ok=True)
-    return cache_dir / f"{cache_key}.pkl"
+    return TEMP_DIR / "data_cache" / f"{cache_key}.pkl"
 
 def save_to_cache(cache_key: str, data: Dict[str, Any]) -> None:
     """Save data to disk cache"""
@@ -142,23 +142,23 @@ def get_open_data(param: List[str], levelist: List[int] = None) -> Dict[str, Any]:
     """Main function to get data with caching"""
     if levelist is None:
         levelist = []
-
+
     # Try disk cache first (more persistent than memory cache)
     cache_key = get_cache_key(DEFAULT_DATE, param, levelist)
     logger.info(f"Checking cache for key: {cache_key}")
-
+
     cached_data = load_from_cache(cache_key)
     if cached_data is not None:
         logger.info(f"Cache hit for {cache_key}")
         return cached_data
-
+
     # If not in cache, download and process the data
     logger.info(f"Cache miss for {cache_key}, downloading fresh data")
     fields = get_open_data_impl(DEFAULT_DATE, param, levelist)
-
+
     # Save to disk cache
     save_to_cache(cache_key, fields)
-
+
     return fields
 
 def get_open_data_impl(date: datetime.datetime, param: List[str], levelist: List[int]) -> Dict[str, Any]:
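
Note: with the mkdir moved out of get_cache_path, the cache directory is now created once at startup by the new setup_directories() at the bottom of this commit. For context, the disk-cache round-trip these hunks rely on looks roughly like the sketch below; the save_to_cache/load_from_cache bodies are not part of this diff, so this is an assumption based on their docstrings and the .pkl extension:

    import pickle
    from pathlib import Path

    TEMP_DIR = Path("gradio_temp")  # assumed; defined in the unchanged part of app.py

    def get_cache_path(cache_key: str) -> Path:
        # Directory creation now happens once in setup_directories(),
        # not on every path lookup.
        return TEMP_DIR / "data_cache" / f"{cache_key}.pkl"

    def save_to_cache(cache_key: str, data) -> None:
        with open(get_cache_path(cache_key), "wb") as f:
            pickle.dump(data, f)

    def load_from_cache(cache_key: str):
        path = get_cache_path(cache_key)
        if not path.exists():
            return None
        with open(path, "rb") as f:
            return pickle.load(f)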
@@ -166,7 +166,7 @@ def get_open_data_impl(date: datetime.datetime, param: List[str], levelist: List[int]) -> Dict[str, Any]:
     fields = {}
     myiterable = [date - datetime.timedelta(hours=6), date]
     logger.info(f"Downloading data for dates: {myiterable}")
-
+
     for current_date in myiterable:
         logger.info(f"Fetching data for {current_date}")
         data = ekd.from_source("ecmwf-open-data", date=current_date, param=param, levelist=levelist)
@@ -178,50 +178,50 @@ def get_open_data_impl(date: datetime.datetime, param: List[str], levelist: List[int]) -> Dict[str, Any]:
             if name not in fields:
                 fields[name] = []
             fields[name].append(values)
-
+
     # Create a single matrix for each parameter
     for param, values in fields.items():
         fields[param] = np.stack(values)
-
+
     return fields
 
 def plot_forecast(state, selected_variable):
     logger.info(f"Plotting forecast for {selected_variable} at time {state['date']}")
-
+
     # Setup the figure and axis
     fig = plt.figure(figsize=(15, 8))
     ax = plt.axes(projection=ccrs.PlateCarree(central_longitude=0))
-
+
     # Get the coordinates
     latitudes, longitudes = state["latitudes"], state["longitudes"]
     fixed_lons = np.where(longitudes > 180, longitudes - 360, longitudes)
     triangulation = tri.Triangulation(fixed_lons, latitudes)
-
+
     # Get the values
     values = state["fields"][selected_variable]
     logger.info(f"Value range: min={np.min(values):.2f}, max={np.max(values):.2f}")
-
+
     # Set map features
     ax.set_global()
     ax.set_extent([-180, 180, -85, 85], crs=ccrs.PlateCarree())
     ax.coastlines(resolution='50m')
     ax.add_feature(cfeature.BORDERS, linestyle=":", alpha=0.5)
     ax.gridlines(draw_labels=True)
-
+
     # Create contour plot
     contour = ax.tricontourf(triangulation, values,
                              levels=20, transform=ccrs.PlateCarree(),
                              cmap='RdBu_r')
-
+
     # Add colorbar
     plt.colorbar(contour, ax=ax, orientation='horizontal', pad=0.05)
-
+
     # Format the date string
     forecast_time = state["date"]
     if isinstance(forecast_time, str):
         forecast_time = datetime.datetime.fromisoformat(forecast_time)
     time_str = forecast_time.strftime("%Y-%m-%d %H:%M UTC")
-
+
     # Get variable description
     var_desc = None
     for group in VARIABLE_GROUPS.values():
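
A quick check of the longitude remap in plot_forecast above: the downloaded fields arrive with longitudes in [0, 360), while PlateCarree(central_longitude=0) expects [-180, 180], hence the np.where shift. A minimal sketch:

    import numpy as np

    lons = np.array([0.0, 90.0, 180.0, 270.0, 359.75])
    fixed = np.where(lons > 180, lons - 360, lons)
    print(fixed)  # [   0.     90.    180.    -90.     -0.25]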
@@ -229,25 +229,25 @@ def plot_forecast(state, selected_variable):
             var_desc = group[selected_variable]
             break
     var_name = var_desc if var_desc else selected_variable
-
+
     ax.set_title(f"{var_name} - {time_str}")
-
+
     # Save as PNG
     temp_file = str(TEMP_DIR / f"forecast_{datetime.datetime.now().timestamp()}.png")
     plt.savefig(temp_file, bbox_inches='tight', dpi=100)
     plt.close()
-
+
     return temp_file
 
 def run_forecast(date: datetime.datetime, lead_time: int, device: str) -> Dict[str, Any]:
     # Get all required fields
     fields = {}
     logger.info(f"Starting forecast for lead_time: {lead_time} hours")
-
+
     # Get surface fields
     logger.info("Getting surface fields...")
     fields.update(get_open_data(param=PARAM_SFC))
-
+
     # Get soil fields and rename them
     logger.info("Getting soil fields...")
     soil = get_open_data(param=PARAM_SOIL, levelist=SOIL_LEVELS)
@@ -257,29 +257,29 @@ def run_forecast(date: datetime.datetime, lead_time: int, device: str) -> Dict[str, Any]:
     }
     for k, v in soil.items():
         fields[mapping[k]] = v
-
+
     # Get pressure level fields
     logger.info("Getting pressure level fields...")
     fields.update(get_open_data(param=PARAM_PL, levelist=LEVELS))
-
+
     # Convert geopotential height to geopotential
     for level in LEVELS:
         gh = fields.pop(f"gh_{level}")
         fields[f"z_{level}"] = gh * 9.80665
-
+
     input_state = dict(date=date, fields=fields)
-
+
     # Use the global model instance
     global MODEL
     if device != MODEL.device:
         MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device=device)
-
+
     # Run the model and get the final state
     final_state = None
     for state in MODEL.run(input_state=input_state, lead_time=lead_time):
         logger.info(f"\n😀 date={state['date']} latitudes={state['latitudes'].shape} "
                     f"longitudes={state['longitudes'].shape} fields={len(state['fields'])}")
-
+
         # Log a few example variables to show we have all fields
         for var in ['2t', 'msl', 't_1000', 'z_850']:
             if var in state['fields']:
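
A note on the unit conversion in this hunk: the open-data download provides gh (geopotential height, m), while the model state uses z (geopotential, m² s⁻²); multiplying by standard gravity 9.80665 m s⁻² converts one to the other. For example, gh_500 = 5572 m gives z_500 = 5572 × 9.80665 ≈ 54,643 m² s⁻².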
@@ -287,29 +287,130 @@ def run_forecast(date: datetime.datetime, lead_time: int, device: str) -> Dict[str, Any]:
                 logger.info(f" {var:<6} shape={values.shape} "
                             f"min={np.min(values):.6f} "
                             f"max={np.max(values):.6f}")
-
+
         final_state = state
-
+
     logger.info(f"Final state contains {len(final_state['fields'])} variables")
     return final_state
 
 def get_available_variables(state):
     """Get available variables from the state and organize them into groups"""
     available_vars = set(state['fields'].keys())
-
+
     # Create dropdown choices only for available variables
     choices = []
     for group_name, variables in VARIABLE_GROUPS.items():
-        group_vars = [(f"{desc} ({var_id})", var_id)
-                      for var_id, desc in variables.items()
+        group_vars = [(f"{desc} ({var_id})", var_id)
+                      for var_id, desc in variables.items()
                       if var_id in available_vars]
-
+
         if group_vars:  # Only add group if it has available variables
             choices.append((f"── {group_name} ──", None))
             choices.extend(group_vars)
-
+
     return choices
 
+def save_forecast_data(state, format='json'):
+    """Save forecast data in specified format"""
+    if state is None:
+        raise ValueError("No forecast data available. Please run a forecast first.")
+
+    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+    forecast_time = state['date'].strftime("%Y%m%d_%H") if isinstance(state['date'], datetime.datetime) else state['date']
+
+    # Use forecasts directory for all outputs
+    output_dir = TEMP_DIR / "forecasts"
+
+    if format == 'json':
+        # Create a JSON-serializable dictionary
+        data = {
+            'metadata': {
+                'forecast_date': forecast_time,
+                'export_date': datetime.datetime.now().isoformat(),
+                'total_points': len(state['latitudes']),
+                'total_variables': len(state['fields'])
+            },
+            'coordinates': {
+                'latitudes': state['latitudes'].tolist(),
+                'longitudes': state['longitudes'].tolist()
+            },
+            'fields': {
+                var_name: {
+                    'values': values.tolist(),
+                    'statistics': {
+                        'min': float(np.min(values)),
+                        'max': float(np.max(values)),
+                        'mean': float(np.mean(values)),
+                        'std': float(np.std(values))
+                    }
+                }
+                for var_name, values in state['fields'].items()
+            }
+        }
+
+        output_file = output_dir / f"forecast_{forecast_time}_{timestamp}.json"
+        with open(output_file, 'w') as f:
+            json.dump(data, f, indent=2)
+
+        return str(output_file)
+
+    elif format == 'netcdf':
+        # Create an xarray Dataset
+        data_vars = {}
+        coords = {
+            'point': np.arange(len(state['latitudes'])),
+            'latitude': ('point', state['latitudes']),
+            'longitude': ('point', state['longitudes']),
+        }
+
+        # Add each field as a variable
+        for var_name, values in state['fields'].items():
+            data_vars[var_name] = (['point'], values)
+
+        # Create the dataset
+        ds = xr.Dataset(
+            data_vars=data_vars,
+            coords=coords,
+            attrs={
+                'forecast_date': forecast_time,
+                'export_date': datetime.datetime.now().isoformat(),
+                'description': 'AIFS Weather Forecast Data'
+            }
+        )
+
+        output_file = output_dir / f"forecast_{forecast_time}_{timestamp}.nc"
+        ds.to_netcdf(output_file)
+
+        return str(output_file)
+
+    elif format == 'csv':
+        # Create a DataFrame with lat/lon and all variables
+        df = pd.DataFrame({
+            'latitude': state['latitudes'],
+            'longitude': state['longitudes']
+        })
+
+        # Add each field as a column
+        for var_name, values in state['fields'].items():
+            df[var_name] = values
+
+        output_file = output_dir / f"forecast_{forecast_time}_{timestamp}.csv"
+        df.to_csv(output_file, index=False)
+
+        return str(output_file)
+
+    else:
+        raise ValueError(f"Unsupported format: {format}")
+
+# Create dropdown choices with groups
+DROPDOWN_CHOICES = []
+for group_name, variables in VARIABLE_GROUPS.items():
+    # Add group separator
+    DROPDOWN_CHOICES.append((f"── {group_name} ──", None))
+    # Add variables in this group
+    for var_id, desc in sorted(variables.items()):
+        DROPDOWN_CHOICES.append((f"{desc} ({var_id})", var_id))
+
 def update_interface():
     with gr.Blocks(css="""
         .centered-header {
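
A hedged usage sketch for the new save_forecast_data helper (assuming a state dict shaped like run_forecast's output, with 'date', 'latitudes', 'longitudes' and 'fields' keys):

    # Hypothetical driver code, not part of the commit.
    state = run_forecast(DEFAULT_DATE, lead_time=12, device="cuda")
    for fmt in ("json", "netcdf", "csv"):
        path = save_forecast_data(state, format=fmt)
        print(f"wrote {path}")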
@@ -328,7 +429,18 @@ def update_interface():
         border-top: 1px solid #eee;
         }
     """) as demo:
-        state = gr.State(None)
+        forecast_state = gr.State(None)
+
+        # Header section
+        gr.Markdown(f"""
+        # AIFS Weather Forecast
+
+        <div class="subtitle">
+        Interactive visualization of ECMWF AIFS weather forecasts.<br>
+        Starting from the latest available data ({DEFAULT_DATE.strftime('%Y-%m-%d %H:%M UTC')}),<br>
+        select how many hours ahead you want to forecast and which meteorological variable to visualize.
+        </div>
+        """)
 
         with gr.Row():
             with gr.Column(scale=1):
@@ -339,90 +451,101 @@ def update_interface():
                     value=12,
                     label="Forecast Hours Ahead"
                 )
+                # Start with the original DROPDOWN_CHOICES
                 variable = gr.Dropdown(
-                    choices=[],  # Start empty
-                    value=None,
+                    choices=DROPDOWN_CHOICES,  # Use original choices at startup
+                    value="2t",
                     label="Select Variable to Plot"
                 )
                 with gr.Row():
                     clear_btn = gr.Button("Clear")
                     run_btn = gr.Button("Run Forecast", variant="primary")
 
-                with gr.Row():
-                    download_json = gr.Button("Download JSON")
-                    download_nc = gr.Button("Download NetCDF")
+                download_nc = gr.Button("Download Forecast (NetCDF)")
+                download_output = gr.File(label="Download Output")
 
             with gr.Column(scale=2):
                 forecast_output = gr.Image()
 
         def run_and_store(lead_time):
             """Run forecast and store state"""
-            state = run_forecast(DEFAULT_DATE, lead_time, "cuda")
-
-            # Get available variables
-            choices = get_available_variables(state)
-
-            # Select first real variable as default
-            default_var = next((var_id for _, var_id in choices if var_id is not None), None)
-
-            # Generate initial plot
-            plot = plot_forecast(state, default_var) if default_var else None
-
-            return [state, gr.Dropdown(choices=choices), default_var, plot]
+            forecast_state = run_forecast(DEFAULT_DATE, lead_time, "cuda")
+            plot = plot_forecast(forecast_state, "2t")  # Default to 2t
+            return forecast_state, plot
 
-        def update_plot_from_state(state, variable):
+        def update_plot_from_state(forecast_state, variable):
             """Update plot using stored state"""
-            if state is None or variable is None:
+            if forecast_state is None or variable is None:
                 return None
             try:
-                return plot_forecast(state, variable)
+                return plot_forecast(forecast_state, variable)
             except KeyError as e:
                 logger.error(f"Variable {variable} not found in state: {e}")
                 return None
 
         def clear():
             """Clear everything"""
-            return [None, None, gr.Dropdown(choices=[]), None]
-
-        def save_json(state):
-            if state is None:
-                return None
-            return save_forecast_data(state, 'json')
+            return [None, None, 12, "2t"]
 
-        def save_netcdf(state):
-            if state is None:
-                return None
-            return save_forecast_data(state, 'netcdf')
+        def save_netcdf(forecast_state):
+            """Save forecast data as NetCDF"""
+            if forecast_state is None:
+                raise ValueError("No forecast data available. Please run a forecast first.")
+
+            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+            forecast_time = forecast_state['date'].strftime("%Y%m%d_%H") if isinstance(forecast_state['date'], datetime.datetime) else forecast_state['date']
+
+            # Create an xarray Dataset
+            data_vars = {}
+            coords = {
+                'point': np.arange(len(forecast_state['latitudes'])),
+                'latitude': ('point', forecast_state['latitudes']),
+                'longitude': ('point', forecast_state['longitudes']),
+            }
+
+            # Add each field as a variable
+            for var_name, values in forecast_state['fields'].items():
+                data_vars[var_name] = (['point'], values)
+
+            # Create the dataset
+            ds = xr.Dataset(
+                data_vars=data_vars,
+                coords=coords,
+                attrs={
+                    'forecast_date': forecast_time,
+                    'export_date': datetime.datetime.now().isoformat(),
+                    'description': 'AIFS Weather Forecast Data'
+                }
+            )
+
+            output_file = TEMP_DIR / "forecasts" / f"forecast_{forecast_time}_{timestamp}.nc"
+            ds.to_netcdf(output_file)
+
+            return str(output_file)
 
         # Connect the components
         run_btn.click(
             fn=run_and_store,
             inputs=[lead_time],
-            outputs=[state, variable, variable, forecast_output]
+            outputs=[forecast_state, forecast_output]
         )
 
         variable.change(
             fn=update_plot_from_state,
-            inputs=[state, variable],
+            inputs=[forecast_state, variable],
             outputs=forecast_output
         )
 
         clear_btn.click(
             fn=clear,
             inputs=[],
-            outputs=[state, forecast_output, variable, variable]
-        )
-
-        download_json.click(
-            fn=save_json,
-            inputs=[state],
-            outputs=gr.File()
+            outputs=[forecast_state, forecast_output, lead_time, variable]
         )
 
         download_nc.click(
             fn=save_netcdf,
-            inputs=[state],
-            outputs=gr.File()
+            inputs=[forecast_state],
+            outputs=[download_output]
         )
 
         return demo
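
The wiring above threads a single gr.State through every callback; stripped of the forecast logic, the pattern is (a minimal sketch, independent of this app):

    import gradio as gr

    with gr.Blocks() as sketch:
        state = gr.State(None)        # holds an arbitrary Python object between events
        run = gr.Button("Run")
        out = gr.Textbox()

        def produce():
            result = {"payload": 42}  # stand-in for the forecast dict
            return result, f"stored {result}"

        # The callback's return values map positionally onto `outputs`.
        run.click(fn=produce, inputs=[], outputs=[state, out])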
@@ -430,3 +553,44 @@ def update_interface():
 # Create and launch the interface
 demo = update_interface()
 demo.launch()
+
+def setup_directories():
+    """Create necessary directories with .keep files"""
+    # Define all required directories
+    directories = {
+        TEMP_DIR / "data_cache": "Cache directory for downloaded weather data",
+        TEMP_DIR / "forecasts": "Directory for forecast outputs (plots and data files)",
+    }
+
+    # Create directories and .keep files
+    for directory, description in directories.items():
+        directory.mkdir(parents=True, exist_ok=True)
+        keep_file = directory / ".keep"
+        if not keep_file.exists():
+            keep_file.write_text(f"# {description}\n# This file ensures the directory is tracked in git\n")
+            logger.info(f"Created directory and .keep file: {directory}")
+
+# Call it during initialization
+setup_directories()
+
+def cleanup_old_files():
+    """Remove old temporary and cache files"""
+    current_time = datetime.datetime.now().timestamp()
+
+    # Clean up forecast files (1 hour old)
+    forecast_dir = TEMP_DIR / "forecasts"
+    for file in forecast_dir.glob("*.*"):
+        if file.name == ".keep":
+            continue
+        if current_time - file.stat().st_mtime > 3600:
+            logger.info(f"Removing old forecast file: {file}")
+            file.unlink(missing_ok=True)
+
+    # Clean up cache files (24 hours old)
+    cache_dir = TEMP_DIR / "data_cache"
+    for file in cache_dir.glob("*.pkl"):
+        if file.name == ".keep":
+            continue
+        if current_time - file.stat().st_mtime > 86400:
+            logger.info(f"Removing old cache file: {file}")
+            file.unlink(missing_ok=True)
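
Note that, as committed, setup_directories() and cleanup_old_files() sit below demo.launch(), which blocks by default in a script, so they only execute once the server stops, and nothing in the commit schedules the cleanup. One hedged, standard-library way to run it periodically, under that assumption:

    import threading

    def schedule_cleanup(interval_seconds: int = 3600) -> None:
        # Runs cleanup_old_files() now, then re-arms itself; daemon timers
        # die with the process, so shutdown is not blocked.
        cleanup_old_files()
        timer = threading.Timer(interval_seconds, schedule_cleanup, args=(interval_seconds,))
        timer.daemon = True
        timer.start()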
gradio_temp/.keep CHANGED
File without changes