Spaces:
Build error
Build error
add animation
Browse files
app.py
CHANGED
@@ -1,4 +1,6 @@
|
|
1 |
import os
|
|
|
|
|
2 |
# Set memory optimization environment variables
|
3 |
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
|
4 |
os.environ['ANEMOI_INFERENCE_NUM_CHUNKS'] = '16'
|
@@ -14,6 +16,7 @@ from anemoi.inference.runners.simple import SimpleRunner
|
|
14 |
from ecmwf.opendata import Client as OpendataClient
|
15 |
import earthkit.data as ekd
|
16 |
import earthkit.regrid as ekr
|
|
|
17 |
|
18 |
# Define parameters (updating to match notebook.py)
|
19 |
PARAM_SFC = ["10u", "10v", "2d", "2t", "msl", "skt", "sp", "tcw", "lsm", "z", "slor", "sdor"]
|
@@ -66,6 +69,11 @@ for var in ["t", "u", "v", "w", "q", "z"]:
|
|
66 |
# Load the model once at startup
|
67 |
MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device="cuda") # Default to CUDA
|
68 |
|
|
|
|
|
|
|
|
|
|
|
69 |
def get_open_data(param, levelist=[]):
|
70 |
fields = {}
|
71 |
# Get the data for the current date and the previous date
|
@@ -90,6 +98,91 @@ def get_open_data(param, levelist=[]):
|
|
90 |
|
91 |
return fields
|
92 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
def run_forecast(date, lead_time, device):
|
94 |
# Get all required fields
|
95 |
fields = {}
|
@@ -122,42 +215,24 @@ def run_forecast(date, lead_time, device):
|
|
122 |
if device != MODEL.device:
|
123 |
MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device=device)
|
124 |
|
125 |
-
|
|
|
126 |
for state in MODEL.run(input_state=input_state, lead_time=lead_time):
|
127 |
-
|
128 |
-
return
|
129 |
|
130 |
-
def
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
ax.coastlines(resolution='50m')
|
143 |
-
ax.add_feature(cfeature.BORDERS, linestyle=":", alpha=0.5)
|
144 |
-
ax.gridlines(draw_labels=True)
|
145 |
-
|
146 |
-
# Fix longitudes to be -180 to 180
|
147 |
-
fixed_lons = np.where(longitudes > 180, longitudes - 360, longitudes)
|
148 |
-
|
149 |
-
# Create triangulation with fixed longitudes
|
150 |
-
triangulation = tri.Triangulation(fixed_lons, latitudes)
|
151 |
-
|
152 |
-
# Create the contour plot
|
153 |
-
contour = ax.tricontourf(triangulation, values, levels=20,
|
154 |
-
transform=ccrs.PlateCarree(),
|
155 |
-
cmap='RdBu_r')
|
156 |
-
|
157 |
-
plt.title(f"{selected_variable} at {state['date']}")
|
158 |
-
plt.colorbar(contour, orientation='horizontal', pad=0.05)
|
159 |
-
|
160 |
-
return fig
|
161 |
|
162 |
# Create dropdown choices with groups
|
163 |
DROPDOWN_CHOICES = []
|
@@ -173,17 +248,31 @@ with gr.Blocks(css="""
|
|
173 |
text-align: center;
|
174 |
margin-bottom: 20px;
|
175 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
""") as demo:
|
177 |
-
#
|
178 |
gr.Markdown(f"""
|
179 |
# AIFS Weather Forecast
|
180 |
|
181 |
-
|
|
|
|
|
182 |
select how many hours ahead you want to forecast and which meteorological variable to visualize.
|
183 |
-
|
|
|
184 |
|
|
|
185 |
with gr.Row():
|
186 |
-
# Controls column - takes up 1/3 of the width
|
187 |
with gr.Column(scale=1):
|
188 |
lead_time = gr.Slider(
|
189 |
minimum=6,
|
@@ -195,34 +284,52 @@ with gr.Blocks(css="""
|
|
195 |
variable = gr.Dropdown(
|
196 |
choices=DROPDOWN_CHOICES,
|
197 |
value="2t",
|
198 |
-
label="Select Variable to Plot"
|
199 |
-
info="Choose a meteorological variable to visualize"
|
200 |
)
|
201 |
-
|
202 |
-
# Add buttons in a row
|
203 |
with gr.Row():
|
204 |
clear_btn = gr.Button("Clear")
|
205 |
submit_btn = gr.Button("Submit", variant="primary")
|
206 |
|
207 |
-
# Map column - takes up 2/3 of the width
|
208 |
with gr.Column(scale=2):
|
209 |
-
|
210 |
|
211 |
-
#
|
212 |
-
|
213 |
-
|
214 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
|
216 |
-
# Clear function to reset to defaults
|
217 |
def clear():
|
218 |
return [
|
219 |
-
12,
|
220 |
-
"2t",
|
221 |
-
None
|
222 |
]
|
223 |
|
224 |
-
# Connect the
|
225 |
-
submit_btn.click(
|
226 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
227 |
|
228 |
demo.launch()
|
|
|
1 |
import os
|
2 |
+
import tempfile
|
3 |
+
from pathlib import Path
|
4 |
# Set memory optimization environment variables
|
5 |
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
|
6 |
os.environ['ANEMOI_INFERENCE_NUM_CHUNKS'] = '16'
|
|
|
16 |
from ecmwf.opendata import Client as OpendataClient
|
17 |
import earthkit.data as ekd
|
18 |
import earthkit.regrid as ekr
|
19 |
+
import matplotlib.animation as animation
|
20 |
|
21 |
# Define parameters (updating to match notebook.py)
|
22 |
PARAM_SFC = ["10u", "10v", "2d", "2t", "msl", "skt", "sp", "tcw", "lsm", "z", "slor", "sdor"]
|
|
|
69 |
# Load the model once at startup
|
70 |
MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device="cuda") # Default to CUDA
|
71 |
|
72 |
+
# Create and set custom temp directory
|
73 |
+
TEMP_DIR = Path("./gradio_temp")
|
74 |
+
TEMP_DIR.mkdir(exist_ok=True)
|
75 |
+
os.environ['GRADIO_TEMP_DIR'] = str(TEMP_DIR)
|
76 |
+
|
77 |
def get_open_data(param, levelist=[]):
|
78 |
fields = {}
|
79 |
# Get the data for the current date and the previous date
|
|
|
98 |
|
99 |
return fields
|
100 |
|
101 |
+
def plot_forecast_animation(states, selected_variable):
|
102 |
+
# Setup the figure and axis
|
103 |
+
fig = plt.figure(figsize=(15, 8))
|
104 |
+
ax = plt.axes(projection=ccrs.PlateCarree(central_longitude=0))
|
105 |
+
|
106 |
+
# Get the first state to setup the plot
|
107 |
+
first_state = states[0]
|
108 |
+
latitudes, longitudes = first_state["latitudes"], first_state["longitudes"]
|
109 |
+
fixed_lons = np.where(longitudes > 180, longitudes - 360, longitudes)
|
110 |
+
triangulation = tri.Triangulation(fixed_lons, latitudes)
|
111 |
+
|
112 |
+
# Find global min/max for consistent colorbar
|
113 |
+
all_values = [state["fields"][selected_variable] for state in states]
|
114 |
+
vmin, vmax = np.min(all_values), np.max(all_values)
|
115 |
+
|
116 |
+
# Create a single colorbar that will be reused
|
117 |
+
contour = None
|
118 |
+
cbar_ax = None
|
119 |
+
|
120 |
+
def update(frame):
|
121 |
+
nonlocal contour, cbar_ax
|
122 |
+
ax.clear()
|
123 |
+
|
124 |
+
# Set map features
|
125 |
+
ax.set_global()
|
126 |
+
ax.set_extent([-180, 180, -85, 85], crs=ccrs.PlateCarree())
|
127 |
+
ax.coastlines(resolution='50m')
|
128 |
+
ax.add_feature(cfeature.BORDERS, linestyle=":", alpha=0.5)
|
129 |
+
ax.gridlines(draw_labels=True)
|
130 |
+
|
131 |
+
state = states[frame]
|
132 |
+
values = state["fields"][selected_variable]
|
133 |
+
|
134 |
+
# Clear the previous colorbar axis if it exists
|
135 |
+
if cbar_ax:
|
136 |
+
cbar_ax.remove()
|
137 |
+
|
138 |
+
# Create new contour plot
|
139 |
+
contour = ax.tricontourf(triangulation, values,
|
140 |
+
levels=20, transform=ccrs.PlateCarree(),
|
141 |
+
cmap='RdBu_r', vmin=vmin, vmax=vmax)
|
142 |
+
|
143 |
+
# Create new colorbar
|
144 |
+
cbar_ax = fig.add_axes([0.1, 0.05, 0.8, 0.03]) # [left, bottom, width, height]
|
145 |
+
plt.colorbar(contour, cax=cbar_ax, orientation='horizontal')
|
146 |
+
|
147 |
+
# Format the date string properly
|
148 |
+
forecast_time = state["date"]
|
149 |
+
if isinstance(forecast_time, str):
|
150 |
+
try:
|
151 |
+
forecast_time = datetime.datetime.strptime(forecast_time, "%Y-%m-%d %H:%M:%S")
|
152 |
+
except ValueError:
|
153 |
+
try:
|
154 |
+
forecast_time = datetime.datetime.strptime(forecast_time, "%Y-%m-%d %H:%M:%S.%f")
|
155 |
+
except ValueError:
|
156 |
+
forecast_time = DEFAULT_DATE
|
157 |
+
|
158 |
+
time_str = forecast_time.strftime("%Y-%m-%d %H:%M UTC")
|
159 |
+
|
160 |
+
# Get variable description from VARIABLE_GROUPS
|
161 |
+
var_desc = None
|
162 |
+
for group in VARIABLE_GROUPS.values():
|
163 |
+
if selected_variable in group:
|
164 |
+
var_desc = group[selected_variable]
|
165 |
+
break
|
166 |
+
var_name = var_desc if var_desc else selected_variable
|
167 |
+
|
168 |
+
ax.set_title(f"{var_name} - {time_str}")
|
169 |
+
|
170 |
+
# Create animation
|
171 |
+
anim = animation.FuncAnimation(
|
172 |
+
fig, update,
|
173 |
+
frames=len(states),
|
174 |
+
interval=1000, # 1 second between frames
|
175 |
+
repeat=True,
|
176 |
+
blit=False # Must be False to update the colorbar
|
177 |
+
)
|
178 |
+
|
179 |
+
# Save as MP4
|
180 |
+
temp_file = str(TEMP_DIR / f"forecast_{datetime.datetime.now().timestamp()}.mp4")
|
181 |
+
anim.save(temp_file, writer='ffmpeg', fps=1)
|
182 |
+
plt.close()
|
183 |
+
|
184 |
+
return temp_file
|
185 |
+
|
186 |
def run_forecast(date, lead_time, device):
|
187 |
# Get all required fields
|
188 |
fields = {}
|
|
|
215 |
if device != MODEL.device:
|
216 |
MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device=device)
|
217 |
|
218 |
+
# Collect all states instead of just the last one
|
219 |
+
states = []
|
220 |
for state in MODEL.run(input_state=input_state, lead_time=lead_time):
|
221 |
+
states.append(state)
|
222 |
+
return states
|
223 |
|
224 |
+
def update_plot(lead_time, variable):
|
225 |
+
cleanup_old_files() # Clean up old files before creating new ones
|
226 |
+
states = run_forecast(DEFAULT_DATE, lead_time, "cuda")
|
227 |
+
return plot_forecast_animation(states, variable)
|
228 |
+
|
229 |
+
# Add cleanup function for old files
|
230 |
+
def cleanup_old_files():
|
231 |
+
# Remove files older than 1 hour
|
232 |
+
current_time = datetime.datetime.now().timestamp()
|
233 |
+
for file in TEMP_DIR.glob("*.mp4"): # Changed from *.gif to *.mp4
|
234 |
+
if current_time - file.stat().st_mtime > 3600: # 1 hour in seconds
|
235 |
+
file.unlink(missing_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
236 |
|
237 |
# Create dropdown choices with groups
|
238 |
DROPDOWN_CHOICES = []
|
|
|
248 |
text-align: center;
|
249 |
margin-bottom: 20px;
|
250 |
}
|
251 |
+
.subtitle {
|
252 |
+
font-size: 1.2em;
|
253 |
+
line-height: 1.5;
|
254 |
+
margin: 20px 0;
|
255 |
+
}
|
256 |
+
.footer {
|
257 |
+
text-align: center;
|
258 |
+
padding: 20px;
|
259 |
+
margin-top: 20px;
|
260 |
+
border-top: 1px solid #eee;
|
261 |
+
}
|
262 |
""") as demo:
|
263 |
+
# Header section
|
264 |
gr.Markdown(f"""
|
265 |
# AIFS Weather Forecast
|
266 |
|
267 |
+
<div class="subtitle">
|
268 |
+
Interactive visualization of ECMWF AIFS weather forecasts.<br>
|
269 |
+
Starting from the latest available data ({DEFAULT_DATE.strftime('%Y-%m-%d %H:%M UTC')}),<br>
|
270 |
select how many hours ahead you want to forecast and which meteorological variable to visualize.
|
271 |
+
</div>
|
272 |
+
""")
|
273 |
|
274 |
+
# Main content
|
275 |
with gr.Row():
|
|
|
276 |
with gr.Column(scale=1):
|
277 |
lead_time = gr.Slider(
|
278 |
minimum=6,
|
|
|
284 |
variable = gr.Dropdown(
|
285 |
choices=DROPDOWN_CHOICES,
|
286 |
value="2t",
|
287 |
+
label="Select Variable to Plot"
|
|
|
288 |
)
|
|
|
|
|
289 |
with gr.Row():
|
290 |
clear_btn = gr.Button("Clear")
|
291 |
submit_btn = gr.Button("Submit", variant="primary")
|
292 |
|
|
|
293 |
with gr.Column(scale=2):
|
294 |
+
animation_output = gr.Video()
|
295 |
|
296 |
+
# Footer with fork instructions and model reference
|
297 |
+
gr.Markdown("""
|
298 |
+
<div class="footer">
|
299 |
+
<h3>Want to run this on your own?</h3>
|
300 |
+
You can fork this space and run it yourself:
|
301 |
+
|
302 |
+
1. Visit <a href="https://huggingface.co/spaces/geobase/aifs-forecast" target="_blank">https://huggingface.co/spaces/geobase/aifs-forecast</a>\n
|
303 |
+
2. Click the "Duplicate this Space" button in the top right\n
|
304 |
+
3. Select your hardware requirements (GPU recommended)\n
|
305 |
+
4. Wait for your copy to deploy
|
306 |
+
|
307 |
+
<h3>Model Information</h3>
|
308 |
+
This demo uses the <a href="https://huggingface.co/ecmwf/aifs-single-1.0" target="_blank">AIFS Single 1.0</a> model from ECMWF,
|
309 |
+
which is their first operationally supported Artificial Intelligence Forecasting System. The model produces highly skilled forecasts
|
310 |
+
for upper-air variables, surface weather parameters, and tropical cyclone tracks.
|
311 |
+
|
312 |
+
Note: If you encounter any issues with this demo, trying your own fork might work better!
|
313 |
+
</div>
|
314 |
+
""")
|
315 |
|
|
|
316 |
def clear():
|
317 |
return [
|
318 |
+
12,
|
319 |
+
"2t",
|
320 |
+
None
|
321 |
]
|
322 |
|
323 |
+
# Connect the inputs to the forecast function
|
324 |
+
submit_btn.click(
|
325 |
+
fn=update_plot,
|
326 |
+
inputs=[lead_time, variable],
|
327 |
+
outputs=animation_output
|
328 |
+
)
|
329 |
+
clear_btn.click(
|
330 |
+
fn=clear,
|
331 |
+
inputs=[],
|
332 |
+
outputs=[lead_time, variable, animation_output]
|
333 |
+
)
|
334 |
|
335 |
demo.launch()
|