Spaces: add variables (Build error)

Files changed:
- .gitignore +29 -0
- app.py +244 -80
- gradio_temp/.keep +0 -0
.gitignore
CHANGED
@@ -1,3 +1,32 @@
 aifs-single-mse-1.0.ckpt
 flagged/
 gradio_temp/*
+
+# Ignore all files in temp directories except .keep
+gradio_temp/data_cache/*
+!gradio_temp/data_cache/.keep
+
+gradio_temp/forecasts/*
+!gradio_temp/forecasts/.keep
+
+# Python cache files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# Environment directories
+.env
+.venv
+env/
+venv/
+ENV/
+
+# IDE directories
+.idea/
+.vscode/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# Logs
+*.log
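Note: the gradio_temp/data_cache/* and !gradio_temp/data_cache/.keep pairs above are intended to ignore everything inside the temp directories while re-including the .keep placeholders so the directories stay tracked. The pre-existing gradio_temp/* rule also excludes the data_cache and forecasts directories themselves, and git cannot re-include a file whose parent directory is excluded, so it is worth verifying the net effect. A minimal sketch of that check from a checkout of this repo, assuming git is on PATH; the example.pkl path is hypothetical:

import subprocess

# git check-ignore -v prints the matching rule for each ignored path;
# a successfully re-included .keep file produces no output line.
result = subprocess.run(
    ["git", "check-ignore", "-v",
     "gradio_temp/data_cache/example.pkl",
     "gradio_temp/data_cache/.keep"],
    capture_output=True, text=True,
)
print(result.stdout)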
app.py
CHANGED
@@ -23,6 +23,8 @@ import pickle
 import json
 from typing import List, Dict, Any
 import logging
+import xarray as xr
+import pandas as pd
 
 # Configure logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -71,7 +73,7 @@ for var in ["t", "u", "v", "w", "q", "z"]:
         "q": "Specific Humidity",
         "z": "Geopotential"
     }[var]
-
+
     for level in LEVELS:
        var_id = f"{var}_{level}"
        VARIABLE_GROUPS["Pressure Level Variables"][var_id] = f"{var_name} at {level}hPa"
@@ -99,9 +101,7 @@ def get_cache_key(date: datetime.datetime, params: List[str], levellist: List[int]
 
 def get_cache_path(cache_key: str) -> Path:
     """Get the path to the cache file"""
-
-    cache_dir.mkdir(exist_ok=True)
-    return cache_dir / f"{cache_key}.pkl"
+    return TEMP_DIR / "data_cache" / f"{cache_key}.pkl"
 
 def save_to_cache(cache_key: str, data: Dict[str, Any]) -> None:
     """Save data to disk cache"""
@@ -142,23 +142,23 @@ def get_open_data(param: List[str], levelist: List[int] = None) -> Dict[str, Any]:
     """Main function to get data with caching"""
     if levelist is None:
         levelist = []
-
+
     # Try disk cache first (more persistent than memory cache)
     cache_key = get_cache_key(DEFAULT_DATE, param, levelist)
     logger.info(f"Checking cache for key: {cache_key}")
-
+
     cached_data = load_from_cache(cache_key)
     if cached_data is not None:
         logger.info(f"Cache hit for {cache_key}")
         return cached_data
-
+
     # If not in cache, download and process the data
     logger.info(f"Cache miss for {cache_key}, downloading fresh data")
     fields = get_open_data_impl(DEFAULT_DATE, param, levelist)
-
+
     # Save to disk cache
     save_to_cache(cache_key, fields)
-
+
     return fields
 
 def get_open_data_impl(date: datetime.datetime, param: List[str], levelist: List[int]) -> Dict[str, Any]:
@@ -166,7 +166,7 @@ def get_open_data_impl(date: datetime.datetime, param: List[str], levelist: List[int]) -> Dict[str, Any]:
     fields = {}
     myiterable = [date - datetime.timedelta(hours=6), date]
     logger.info(f"Downloading data for dates: {myiterable}")
-
+
     for current_date in myiterable:
         logger.info(f"Fetching data for {current_date}")
         data = ekd.from_source("ecmwf-open-data", date=current_date, param=param, levelist=levelist)
@@ -178,50 +178,50 @@ def get_open_data_impl(date: datetime.datetime, param: List[str], levelist: List[int]) -> Dict[str, Any]:
             if name not in fields:
                 fields[name] = []
             fields[name].append(values)
-
+
     # Create a single matrix for each parameter
     for param, values in fields.items():
         fields[param] = np.stack(values)
-
+
     return fields
 
 def plot_forecast(state, selected_variable):
     logger.info(f"Plotting forecast for {selected_variable} at time {state['date']}")
-
+
     # Setup the figure and axis
     fig = plt.figure(figsize=(15, 8))
     ax = plt.axes(projection=ccrs.PlateCarree(central_longitude=0))
-
+
     # Get the coordinates
     latitudes, longitudes = state["latitudes"], state["longitudes"]
     fixed_lons = np.where(longitudes > 180, longitudes - 360, longitudes)
     triangulation = tri.Triangulation(fixed_lons, latitudes)
-
+
     # Get the values
     values = state["fields"][selected_variable]
     logger.info(f"Value range: min={np.min(values):.2f}, max={np.max(values):.2f}")
-
+
     # Set map features
     ax.set_global()
     ax.set_extent([-180, 180, -85, 85], crs=ccrs.PlateCarree())
     ax.coastlines(resolution='50m')
     ax.add_feature(cfeature.BORDERS, linestyle=":", alpha=0.5)
     ax.gridlines(draw_labels=True)
-
+
     # Create contour plot
     contour = ax.tricontourf(triangulation, values,
                              levels=20, transform=ccrs.PlateCarree(),
                              cmap='RdBu_r')
-
+
     # Add colorbar
     plt.colorbar(contour, ax=ax, orientation='horizontal', pad=0.05)
-
+
     # Format the date string
     forecast_time = state["date"]
     if isinstance(forecast_time, str):
         forecast_time = datetime.datetime.fromisoformat(forecast_time)
     time_str = forecast_time.strftime("%Y-%m-%d %H:%M UTC")
-
+
     # Get variable description
     var_desc = None
     for group in VARIABLE_GROUPS.values():
@@ -229,25 +229,25 @@ def plot_forecast(state, selected_variable):
             var_desc = group[selected_variable]
             break
     var_name = var_desc if var_desc else selected_variable
-
+
     ax.set_title(f"{var_name} - {time_str}")
-
+
     # Save as PNG
     temp_file = str(TEMP_DIR / f"forecast_{datetime.datetime.now().timestamp()}.png")
     plt.savefig(temp_file, bbox_inches='tight', dpi=100)
     plt.close()
-
+
     return temp_file
 
 def run_forecast(date: datetime.datetime, lead_time: int, device: str) -> Dict[str, Any]:
     # Get all required fields
     fields = {}
     logger.info(f"Starting forecast for lead_time: {lead_time} hours")
-
+
     # Get surface fields
     logger.info("Getting surface fields...")
     fields.update(get_open_data(param=PARAM_SFC))
-
+
     # Get soil fields and rename them
     logger.info("Getting soil fields...")
     soil = get_open_data(param=PARAM_SOIL, levelist=SOIL_LEVELS)
@@ -257,29 +257,29 @@ def run_forecast(date: datetime.datetime, lead_time: int, device: str) -> Dict[str, Any]:
     }
     for k, v in soil.items():
         fields[mapping[k]] = v
-
+
     # Get pressure level fields
     logger.info("Getting pressure level fields...")
     fields.update(get_open_data(param=PARAM_PL, levelist=LEVELS))
-
+
     # Convert geopotential height to geopotential
     for level in LEVELS:
         gh = fields.pop(f"gh_{level}")
         fields[f"z_{level}"] = gh * 9.80665
-
+
     input_state = dict(date=date, fields=fields)
-
+
     # Use the global model instance
     global MODEL
     if device != MODEL.device:
         MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device=device)
-
+
     # Run the model and get the final state
     final_state = None
     for state in MODEL.run(input_state=input_state, lead_time=lead_time):
         logger.info(f"\n😀 date={state['date']} latitudes={state['latitudes'].shape} "
                     f"longitudes={state['longitudes'].shape} fields={len(state['fields'])}")
-
+
         # Log a few example variables to show we have all fields
         for var in ['2t', 'msl', 't_1000', 'z_850']:
             if var in state['fields']:
@@ -287,29 +287,130 @@ def run_forecast(date: datetime.datetime, lead_time: int, device: str) -> Dict[str, Any]:
                 logger.info(f"  {var:<6} shape={values.shape} "
                             f"min={np.min(values):.6f} "
                             f"max={np.max(values):.6f}")
-
+
         final_state = state
-
+
     logger.info(f"Final state contains {len(final_state['fields'])} variables")
     return final_state
 
 def get_available_variables(state):
     """Get available variables from the state and organize them into groups"""
     available_vars = set(state['fields'].keys())
-
+
     # Create dropdown choices only for available variables
     choices = []
     for group_name, variables in VARIABLE_GROUPS.items():
-        group_vars = [(f"{desc} ({var_id})", var_id)
-                      for var_id, desc in variables.items()
+        group_vars = [(f"{desc} ({var_id})", var_id)
+                      for var_id, desc in variables.items()
                       if var_id in available_vars]
-
+
         if group_vars:  # Only add group if it has available variables
             choices.append((f"── {group_name} ──", None))
             choices.extend(group_vars)
-
+
     return choices
 
+def save_forecast_data(state, format='json'):
+    """Save forecast data in specified format"""
+    if state is None:
+        raise ValueError("No forecast data available. Please run a forecast first.")
+
+    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+    forecast_time = state['date'].strftime("%Y%m%d_%H") if isinstance(state['date'], datetime.datetime) else state['date']
+
+    # Use forecasts directory for all outputs
+    output_dir = TEMP_DIR / "forecasts"
+
+    if format == 'json':
+        # Create a JSON-serializable dictionary
+        data = {
+            'metadata': {
+                'forecast_date': forecast_time,
+                'export_date': datetime.datetime.now().isoformat(),
+                'total_points': len(state['latitudes']),
+                'total_variables': len(state['fields'])
+            },
+            'coordinates': {
+                'latitudes': state['latitudes'].tolist(),
+                'longitudes': state['longitudes'].tolist()
+            },
+            'fields': {
+                var_name: {
+                    'values': values.tolist(),
+                    'statistics': {
+                        'min': float(np.min(values)),
+                        'max': float(np.max(values)),
+                        'mean': float(np.mean(values)),
+                        'std': float(np.std(values))
+                    }
+                }
+                for var_name, values in state['fields'].items()
+            }
+        }
+
+        output_file = output_dir / f"forecast_{forecast_time}_{timestamp}.json"
+        with open(output_file, 'w') as f:
+            json.dump(data, f, indent=2)
+
+        return str(output_file)
+
+    elif format == 'netcdf':
+        # Create an xarray Dataset
+        data_vars = {}
+        coords = {
+            'point': np.arange(len(state['latitudes'])),
+            'latitude': ('point', state['latitudes']),
+            'longitude': ('point', state['longitudes']),
+        }
+
+        # Add each field as a variable
+        for var_name, values in state['fields'].items():
+            data_vars[var_name] = (['point'], values)
+
+        # Create the dataset
+        ds = xr.Dataset(
+            data_vars=data_vars,
+            coords=coords,
+            attrs={
+                'forecast_date': forecast_time,
+                'export_date': datetime.datetime.now().isoformat(),
+                'description': 'AIFS Weather Forecast Data'
+            }
+        )
+
+        output_file = output_dir / f"forecast_{forecast_time}_{timestamp}.nc"
+        ds.to_netcdf(output_file)
+
+        return str(output_file)
+
+    elif format == 'csv':
+        # Create a DataFrame with lat/lon and all variables
+        df = pd.DataFrame({
+            'latitude': state['latitudes'],
+            'longitude': state['longitudes']
+        })
+
+        # Add each field as a column
+        for var_name, values in state['fields'].items():
+            df[var_name] = values
+
+        output_file = output_dir / f"forecast_{forecast_time}_{timestamp}.csv"
+        df.to_csv(output_file, index=False)
+
+        return str(output_file)
+
+    else:
+        raise ValueError(f"Unsupported format: {format}")
+
+# Create dropdown choices with groups
+DROPDOWN_CHOICES = []
+for group_name, variables in VARIABLE_GROUPS.items():
+    # Add group separator
+    DROPDOWN_CHOICES.append((f"── {group_name} ──", None))
+    # Add variables in this group
+    for var_id, desc in sorted(variables.items()):
+        DROPDOWN_CHOICES.append((f"{desc} ({var_id})", var_id))
+
 def update_interface():
     with gr.Blocks(css="""
         .centered-header {
@@ -328,7 +429,18 @@ def update_interface():
            border-top: 1px solid #eee;
        }
    """) as demo:
-
+        forecast_state = gr.State(None)
+
+        # Header section
+        gr.Markdown(f"""
+        # AIFS Weather Forecast
+
+        <div class="subtitle">
+        Interactive visualization of ECMWF AIFS weather forecasts.<br>
+        Starting from the latest available data ({DEFAULT_DATE.strftime('%Y-%m-%d %H:%M UTC')}),<br>
+        select how many hours ahead you want to forecast and which meteorological variable to visualize.
+        </div>
+        """)
 
        with gr.Row():
            with gr.Column(scale=1):
@@ -339,90 +451,101 @@ def update_interface():
                    value=12,
                    label="Forecast Hours Ahead"
                )
+                # Start with the original DROPDOWN_CHOICES
                variable = gr.Dropdown(
-                    choices=…
-                    value=…
+                    choices=DROPDOWN_CHOICES,  # Use original choices at startup
+                    value="2t",
                    label="Select Variable to Plot"
                )
                with gr.Row():
                    clear_btn = gr.Button("Clear")
                    run_btn = gr.Button("Run Forecast", variant="primary")
 
-                …
-                …
-                download_nc = gr.Button("Download NetCDF")
+                download_nc = gr.Button("Download Forecast (NetCDF)")
+                download_output = gr.File(label="Download Output")
 
            with gr.Column(scale=2):
                forecast_output = gr.Image()
 
        def run_and_store(lead_time):
            """Run forecast and store state"""
-            …
-            …
-            …
-            choices = get_available_variables(state)
-
-            # Select first real variable as default
-            default_var = next((var_id for _, var_id in choices if var_id is not None), None)
-
-            # Generate initial plot
-            plot = plot_forecast(state, default_var) if default_var else None
-
-            return [state, gr.Dropdown(choices=choices), default_var, plot]
+            forecast_state = run_forecast(DEFAULT_DATE, lead_time, "cuda")
+            plot = plot_forecast(forecast_state, "2t")  # Default to 2t
+            return forecast_state, plot
 
-        def update_plot_from_state(…
+        def update_plot_from_state(forecast_state, variable):
            """Update plot using stored state"""
-            if …
                return None
            try:
-                return plot_forecast(…
+            if forecast_state is None or variable is None:
+                return None
+            try:
+                return plot_forecast(forecast_state, variable)
            except KeyError as e:
                logger.error(f"Variable {variable} not found in state: {e}")
                return None
 
        def clear():
            """Clear everything"""
-            return [None, None, …
-
-        def save_json(state):
-            if state is None:
-                return None
-            return save_forecast_data(state, 'json')
+            return [None, None, 12, "2t"]
 
-        def save_netcdf(…
-            …
-            …
-            …
+        def save_netcdf(forecast_state):
+            """Save forecast data as NetCDF"""
+            if forecast_state is None:
+                raise ValueError("No forecast data available. Please run a forecast first.")
+
+            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+            forecast_time = forecast_state['date'].strftime("%Y%m%d_%H") if isinstance(forecast_state['date'], datetime.datetime) else forecast_state['date']
+
+            # Create an xarray Dataset
+            data_vars = {}
+            coords = {
+                'point': np.arange(len(forecast_state['latitudes'])),
+                'latitude': ('point', forecast_state['latitudes']),
+                'longitude': ('point', forecast_state['longitudes']),
+            }
+
+            # Add each field as a variable
+            for var_name, values in forecast_state['fields'].items():
+                data_vars[var_name] = (['point'], values)
+
+            # Create the dataset
+            ds = xr.Dataset(
+                data_vars=data_vars,
+                coords=coords,
+                attrs={
+                    'forecast_date': forecast_time,
+                    'export_date': datetime.datetime.now().isoformat(),
+                    'description': 'AIFS Weather Forecast Data'
+                }
+            )
+
+            output_file = TEMP_DIR / "forecasts" / f"forecast_{forecast_time}_{timestamp}.nc"
+            ds.to_netcdf(output_file)
+
+            return str(output_file)
 
        # Connect the components
        run_btn.click(
            fn=run_and_store,
            inputs=[lead_time],
-            outputs=[…
+            outputs=[forecast_state, forecast_output]
        )
 
        variable.change(
            fn=update_plot_from_state,
-            inputs=[…
+            inputs=[forecast_state, variable],
            outputs=forecast_output
        )
 
        clear_btn.click(
            fn=clear,
            inputs=[],
-            outputs=[…
-        )
-
-        download_json.click(
-            fn=save_json,
-            inputs=[state],
-            outputs=gr.File()
+            outputs=[forecast_state, forecast_output, lead_time, variable]
        )
 
        download_nc.click(
            fn=save_netcdf,
-            inputs=[…
-            outputs=…
+            inputs=[forecast_state],
+            outputs=[download_output]
        )
 
        return demo
@@ -430,3 +553,44 @@ def update_interface():
 # Create and launch the interface
 demo = update_interface()
 demo.launch()
+
+def setup_directories():
+    """Create necessary directories with .keep files"""
+    # Define all required directories
+    directories = {
+        TEMP_DIR / "data_cache": "Cache directory for downloaded weather data",
+        TEMP_DIR / "forecasts": "Directory for forecast outputs (plots and data files)",
+    }
+
+    # Create directories and .keep files
+    for directory, description in directories.items():
+        directory.mkdir(parents=True, exist_ok=True)
+        keep_file = directory / ".keep"
+        if not keep_file.exists():
+            keep_file.write_text(f"# {description}\n# This file ensures the directory is tracked in git\n")
+            logger.info(f"Created directory and .keep file: {directory}")
+
+# Call it during initialization
+setup_directories()
+
+def cleanup_old_files():
+    """Remove old temporary and cache files"""
+    current_time = datetime.datetime.now().timestamp()
+
+    # Clean up forecast files (1 hour old)
+    forecast_dir = TEMP_DIR / "forecasts"
+    for file in forecast_dir.glob("*.*"):
+        if file.name == ".keep":
+            continue
+        if current_time - file.stat().st_mtime > 3600:
+            logger.info(f"Removing old forecast file: {file}")
+            file.unlink(missing_ok=True)
+
+    # Clean up cache files (24 hours old)
+    cache_dir = TEMP_DIR / "data_cache"
+    for file in cache_dir.glob("*.pkl"):
+        if file.name == ".keep":
+            continue
+        if current_time - file.stat().st_mtime > 86400:
+            logger.info(f"Removing old cache file: {file}")
+            file.unlink(missing_ok=True)
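The NetCDF export added above stores the forecast as a flat point cloud: a single point dimension, latitude and longitude attached as coordinates, and one data variable per field. A minimal sketch of reading such a file back, assuming xarray with a NetCDF backend is installed; the file name stands in for whatever save_netcdf returns, and 2t is one of the surface fields the app plots by default:

import xarray as xr

ds = xr.open_dataset("forecast_20250101_00_20250101_120000.nc")
print(ds.attrs["description"])   # "AIFS Weather Forecast Data"
t2m = ds["2t"].values            # one value per grid point
lats = ds["latitude"].values     # same length as t2m
lons = ds["longitude"].values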
gradio_temp/.keep
CHANGED
File without changes
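This .keep placeholder, like the ones setup_directories() writes into data_cache and forecasts, exists only so the temp directories stay tracked; the cache itself is plain pickle files addressed by get_cache_path. A minimal sketch of that disk-cache round-trip, assuming TEMP_DIR points at gradio_temp as in app.py; the helper names here are illustrative, not the app's own:

import pickle
from pathlib import Path

TEMP_DIR = Path("gradio_temp")

def cache_path(cache_key: str) -> Path:
    # Mirrors the new get_cache_path: fixed directory, one .pkl per key
    return TEMP_DIR / "data_cache" / f"{cache_key}.pkl"

def save(cache_key: str, data: dict) -> None:
    path = cache_path(cache_key)
    path.parent.mkdir(parents=True, exist_ok=True)  # setup_directories() normally does this
    with open(path, "wb") as f:
        pickle.dump(data, f)

def load(cache_key: str):
    path = cache_path(cache_key)
    if not path.exists():
        return None
    with open(path, "rb") as f:
        return pickle.load(f)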