saburq committed on
Commit 30ffd75 · 1 Parent(s): f692b41

feat: Add CPU support and improve device handling


This commit adds flexible device support to the AIFS weather forecast app,
allowing it to run in both GPU and CPU environments. The changes improve the
robustness and accessibility of the application.

Key changes:
- Add automatic device detection (CUDA/CPU)
- Make model initialization device-aware
- Update requirements.txt for both CPU and GPU installations
- Add memory optimization settings for CPU usage
- Improve logging for device selection and model initialization

Technical details:
- Introduce get_device() function for runtime hardware detection
- Make device parameter optional in run_forecast()
- Update model initialization to use detected device
- Add documentation for CPU-specific configurations

This change ensures the app can run in environments without CUDA support,
albeit at reduced performance. Memory optimization parameters are included
to help manage resource usage on CPU-only systems.
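
The memory optimization settings referred to above do not appear in the hunks below, so the following is only a rough sketch of what such CPU-oriented settings could look like, assuming they rely on standard PyTorch controls; the actual parameters used by the commit are not visible here:

# Hypothetical sketch only -- the commit's real memory settings are not shown
# in the diff. These are standard PyTorch/OS knobs for CPU-only inference.
import os
import torch

def apply_cpu_memory_settings() -> None:
    """Cap thread fan-out and drop autograd bookkeeping for CPU inference."""
    threads = os.cpu_count() or 1
    os.environ.setdefault("OMP_NUM_THREADS", str(threads))  # limit OpenMP threads
    torch.set_num_threads(threads)   # cap intra-op parallelism
    torch.set_grad_enabled(False)    # inference only: no gradient buffers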

Files changed (2)
  1. app.py +24 -6
  2. requirements.txt +6 -1
app.py CHANGED
@@ -78,8 +78,23 @@ for var in ["t", "u", "v", "w", "q", "z"]:
         var_id = f"{var}_{level}"
         VARIABLE_GROUPS["Pressure Level Variables"][var_id] = f"{var_name} at {level}hPa"
 
-# Load the model once at startup
-MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device="cuda")  # Default to CUDA
+def get_device():
+    """Determine the best available device"""
+    try:
+        import torch
+        if torch.cuda.is_available():
+            logger.info("CUDA is available, using GPU")
+            return "cuda"
+        else:
+            logger.info("CUDA is not available, using CPU")
+            return "cpu"
+    except ImportError:
+        logger.info("PyTorch not found, using CPU")
+        return "cpu"
+
+# Update the model initialization to use the detected device
+DEVICE = get_device()
+MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device=DEVICE)
 
 # Create and set custom temp directory
 TEMP_DIR = Path("./gradio_temp")
@@ -239,10 +254,13 @@ def plot_forecast(state, selected_variable):
 
     return temp_file
 
-def run_forecast(date: datetime.datetime, lead_time: int, device: str) -> Dict[str, Any]:
+def run_forecast(date: datetime.datetime, lead_time: int, device: str = None) -> Dict[str, Any]:
+    # Use the global device if none specified
+    device = device or DEVICE
+
     # Get all required fields
     fields = {}
-    logger.info(f"Starting forecast for lead_time: {lead_time} hours")
+    logger.info(f"Starting forecast for lead_time: {lead_time} hours on {device}")
 
     # Get surface fields
     logger.info("Getting surface fields...")
@@ -469,8 +487,8 @@ def update_interface():
 
     def run_and_store(lead_time):
        """Run forecast and store state"""
-        forecast_state = run_forecast(DEFAULT_DATE, lead_time, "cuda")
-        plot = plot_forecast(forecast_state, "2t")  # Default to 2t
+        forecast_state = run_forecast(DEFAULT_DATE, lead_time, DEVICE)  # Use global DEVICE
+        plot = plot_forecast(forecast_state, "2t")
         return forecast_state, plot
 
     def update_plot_from_state(forecast_state, variable):
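
Because the device parameter is now optional, callers can omit it and let run_forecast() fall back to the module-level DEVICE. A minimal usage sketch (the 72-hour lead time is an illustrative value, not taken from the commit):

# device omitted: run_forecast() falls back to the DEVICE chosen by get_device()
state = run_forecast(DEFAULT_DATE, 72)   # "cuda" if available, otherwise "cpu"
plot = plot_forecast(state, "2t")        # 2 m temperature, as in run_and_store()
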
requirements.txt CHANGED
@@ -1,4 +1,9 @@
-# torch # uncomment on cuda
+# For CPU-only installation, use:
+torch
+# For CUDA installation, use:
+# --extra-index-url https://download.pytorch.org/whl/cu118
+# torch==2.0.1+cu118
+
 flash-attn
 anemoi-inference[huggingface]==0.4.9
 anemoi-models==0.3.1
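
Note that pip honours --extra-index-url lines inside a requirements file, so switching to the CUDA build only requires commenting out the bare torch entry and uncommenting the two CUDA lines before rerunning pip install -r requirements.txt.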