Spaces:
Build error
feat: Add CPU support and improve device handling
This commit adds flexible device support to the AIFS weather forecast app,
allowing it to run on both GPU and CPU environments. The changes improve
robustness and accessibility of the application.
Key changes:
- Add automatic device detection (CUDA/CPU)
- Make model initialization device-aware
- Update requirements.txt for both CPU and GPU installations
- Add memory optimization settings for CPU usage
- Improve logging for device selection and model initialization
Technical details:
- Introduce get_device() function for runtime hardware detection
- Make device parameter optional in run_forecast()
- Update model initialization to use detected device
- Add documentation for CPU-specific configurations
This change ensures the app can run in environments without CUDA support,
albeit at reduced performance. Memory optimization parameters are included
to help manage resource usage on CPU-only systems.
- app.py +24 -6
- requirements.txt +6 -1
--- a/app.py
+++ b/app.py
@@ -78,8 +78,23 @@ for var in ["t", "u", "v", "w", "q", "z"]:
         var_id = f"{var}_{level}"
         VARIABLE_GROUPS["Pressure Level Variables"][var_id] = f"{var_name} at {level}hPa"
 
-[removed line 81 — content not captured in page extraction]
-[removed line 82 — content not captured in page extraction]
+def get_device():
+    """Determine the best available device"""
+    try:
+        import torch
+        if torch.cuda.is_available():
+            logger.info("CUDA is available, using GPU")
+            return "cuda"
+        else:
+            logger.info("CUDA is not available, using CPU")
+            return "cpu"
+    except ImportError:
+        logger.info("PyTorch not found, using CPU")
+        return "cpu"
+
+# Update the model initialization to use the detected device
+DEVICE = get_device()
+MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device=DEVICE)
 
 # Create and set custom temp directory
 TEMP_DIR = Path("./gradio_temp")
@@ -239,10 +254,13 @@ def plot_forecast(state, selected_variable):
 
     return temp_file
 
-def run_forecast(date: datetime.datetime, lead_time: int, device: str) -> Dict[str, Any]:
+def run_forecast(date: datetime.datetime, lead_time: int, device: str = None) -> Dict[str, Any]:
+    # Use the global device if none specified
+    device = device or DEVICE
+
     # Get all required fields
     fields = {}
-    logger.info(f"Starting forecast for lead_time: {lead_time} hours")
+    logger.info(f"Starting forecast for lead_time: {lead_time} hours on {device}")
 
     # Get surface fields
     logger.info("Getting surface fields...")
@@ -469,8 +487,8 @@ def update_interface():
 
     def run_and_store(lead_time):
         """Run forecast and store state"""
-        forecast_state = run_forecast(DEFAULT_DATE, lead_time, [trailing argument not captured in page extraction]
-        plot = plot_forecast(forecast_state, "2t")
+        forecast_state = run_forecast(DEFAULT_DATE, lead_time, DEVICE)  # Use global DEVICE
+        plot = plot_forecast(forecast_state, "2t")
         return forecast_state, plot
 
     def update_plot_from_state(forecast_state, variable):
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,9 @@
-#
+# For CPU-only installation, use:
+torch
+# For CUDA installation, use:
+# --extra-index-url https://download.pytorch.org/whl/cu118
+# torch==2.0.1+cu118
+
 flash-attn
 anemoi-inference[huggingface]==0.4.9
 anemoi-models==0.3.1