puqi committed
Commit d01872e · 1 Parent(s): d738cb1

Upload data_utils.py

Files changed (1)
  1. data_utils.py +993 -0
data_utils.py ADDED
@@ -0,0 +1,993 @@
+ import xarray as xr
+ import numpy as np
+ import pandas as pd
+ import matplotlib.pyplot as plt
+ import pickle
+ import glob, os
+ import re
+ import tensorflow as tf
+ import netCDF4
+ import copy
+ import string
+ import h5py
+ from tqdm import tqdm
+
+ class data_utils:
+     def __init__(self,
+                  grid_info,
+                  input_mean,
+                  input_max,
+                  input_min,
+                  output_scale):
+         self.data_path = None
+         self.input_vars = []
+         self.target_vars = []
+         self.input_feature_len = None
+         self.target_feature_len = None
+         self.grid_info = grid_info
+         self.level_name = 'lev'
+         self.sample_name = 'sample'
+         self.latlonnum = len(self.grid_info['ncol']) # number of unique lat/lon grid points
+         # make area-weights
+         self.grid_info['area_wgt'] = self.grid_info['area']/self.grid_info['area'].mean(dim = 'ncol')
+         self.area_wgt = self.grid_info['area_wgt'].values
+         # map ncol to nsamples dimension
+         # to_xarray = {'area_wgt':(self.sample_name,np.tile(self.grid_info['area_wgt'], int(n_samples/len(self.grid_info['ncol']))))}
+         # to_xarray = xr.Dataset(to_xarray)
+         self.input_mean = input_mean
+         self.input_max = input_max
+         self.input_min = input_min
+         self.output_scale = output_scale
+         self.lats, self.lats_indices = np.unique(self.grid_info['lat'].values, return_index=True)
+         self.lons, self.lons_indices = np.unique(self.grid_info['lon'].values, return_index=True)
+         self.sort_lat_key = np.argsort(self.grid_info['lat'].values[np.sort(self.lats_indices)])
+         self.sort_lon_key = np.argsort(self.grid_info['lon'].values[np.sort(self.lons_indices)])
+         self.indextolatlon = {i: (self.grid_info['lat'].values[i%self.latlonnum], self.grid_info['lon'].values[i%self.latlonnum]) for i in range(self.latlonnum)}
+
+         def find_keys(dictionary, value):
+             keys = []
+             for key, val in dictionary.items():
+                 if val[0] == value:
+                     keys.append(key)
+             return keys
+         indices_list = []
+         for lat in self.lats:
+             indices = find_keys(self.indextolatlon, lat)
+             indices_list.append(indices)
+         indices_list.sort(key = lambda x: x[0])
+         self.lat_indices_list = indices_list
+
+         self.hyam = self.grid_info['hyam'].values
+         self.hybm = self.grid_info['hybm'].values
+         self.p0 = 1e5 # code assumes this will always be a scalar
+
+         self.pressure_grid_train = None
+         self.pressure_grid_val = None
+         self.pressure_grid_scoring = None
+         self.pressure_grid_test = None
+
+         self.dp_train = None
+         self.dp_val = None
+         self.dp_scoring = None
+         self.dp_test = None
+
+         self.train_regexps = None
+         self.train_stride_sample = None
+         self.train_filelist = None
+         self.val_regexps = None
+         self.val_stride_sample = None
+         self.val_filelist = None
+         self.scoring_regexps = None
+         self.scoring_stride_sample = None
+         self.scoring_filelist = None
+         self.test_regexps = None
+         self.test_stride_sample = None
+         self.test_filelist = None
+
+         # physical constants from E3SM_ROOT/share/util/shr_const_mod.F90
+         self.grav = 9.80616 # acceleration of gravity ~ m/s^2
+         self.cp = 1.00464e3 # specific heat of dry air ~ J/kg/K
+         self.lv = 2.501e6 # latent heat of evaporation ~ J/kg
+         self.lf = 3.337e5 # latent heat of fusion ~ J/kg
+         self.lsub = self.lv + self.lf # latent heat of sublimation ~ J/kg
+         self.rho_air = 101325/(6.02214e26*1.38065e-23/28.966)/273.15 # density of dry air at STP ~ kg/m^3
+         # ~ 1.2923182846924677
+         # SHR_CONST_PSTD/(SHR_CONST_RDAIR*SHR_CONST_TKFRZ)
+         # SHR_CONST_RDAIR = SHR_CONST_RGAS/SHR_CONST_MWDAIR
+         # SHR_CONST_RGAS = SHR_CONST_AVOGAD*SHR_CONST_BOLTZ
+         self.rho_h20 = 1.e3 # density of fresh water ~ kg/m^3
+
+         self.v1_inputs = ['state_t',
+                           'state_q0001',
+                           'state_ps',
+                           'pbuf_SOLIN',
+                           'pbuf_LHFLX',
+                           'pbuf_SHFLX']
+
+         self.v1_outputs = ['ptend_t',
+                            'ptend_q0001',
+                            'cam_out_NETSW',
+                            'cam_out_FLWDS',
+                            'cam_out_PRECSC',
+                            'cam_out_PRECC',
+                            'cam_out_SOLS',
+                            'cam_out_SOLL',
+                            'cam_out_SOLSD',
+                            'cam_out_SOLLD']
+
+         self.var_lens = {# inputs
+                          'state_t':60,
+                          'state_q0001':60,
+                          'state_ps':1,
+                          'pbuf_SOLIN':1,
+                          'pbuf_LHFLX':1,
+                          'pbuf_SHFLX':1,
+                          # outputs
+                          'ptend_t':60,
+                          'ptend_q0001':60,
+                          'cam_out_NETSW':1,
+                          'cam_out_FLWDS':1,
+                          'cam_out_PRECSC':1,
+                          'cam_out_PRECC':1,
+                          'cam_out_SOLS':1,
+                          'cam_out_SOLL':1,
+                          'cam_out_SOLSD':1,
+                          'cam_out_SOLLD':1
+                          }
+
+         self.var_short_names = {'ptend_t':'$dT/dt$',
+                                 'ptend_q0001':'$dq/dt$',
+                                 'cam_out_NETSW':'NETSW',
+                                 'cam_out_FLWDS':'FLWDS',
+                                 'cam_out_PRECSC':'PRECSC',
+                                 'cam_out_PRECC':'PRECC',
+                                 'cam_out_SOLS':'SOLS',
+                                 'cam_out_SOLL':'SOLL',
+                                 'cam_out_SOLSD':'SOLSD',
+                                 'cam_out_SOLLD':'SOLLD'}
+
+         self.target_energy_conv = {'ptend_t':self.cp,
+                                    'ptend_q0001':self.lv,
+                                    'cam_out_NETSW':1.,
+                                    'cam_out_FLWDS':1.,
+                                    'cam_out_PRECSC':self.lv*self.rho_h20,
+                                    'cam_out_PRECC':self.lv*self.rho_h20,
+                                    'cam_out_SOLS':1.,
+                                    'cam_out_SOLL':1.,
+                                    'cam_out_SOLSD':1.,
+                                    'cam_out_SOLLD':1.
+                                    }
+
+         # for metrics
+
+         self.input_train = None
+         self.target_train = None
+         self.preds_train = None
+         self.samples_train = None
+         self.target_weighted_train = {}
+         self.preds_weighted_train = {}
+         self.samples_weighted_train = {}
+         self.metrics_train = []
+         self.metrics_idx_train = {}
+         self.metrics_var_train = {}
+
+         self.input_val = None
+         self.target_val = None
+         self.preds_val = None
+         self.samples_val = None
+         self.target_weighted_val = {}
+         self.preds_weighted_val = {}
+         self.samples_weighted_val = {}
+         self.metrics_val = []
+         self.metrics_idx_val = {}
+         self.metrics_var_val = {}
+
+         self.input_scoring = None
+         self.target_scoring = None
+         self.preds_scoring = None
+         self.samples_scoring = None
+         self.target_weighted_scoring = {}
+         self.preds_weighted_scoring = {}
+         self.samples_weighted_scoring = {}
+         self.metrics_scoring = []
+         self.metrics_idx_scoring = {}
+         self.metrics_var_scoring = {}
+
+         self.input_test = None
+         self.target_test = None
+         self.preds_test = None
+         self.samples_test = None
+         self.target_weighted_test = {}
+         self.preds_weighted_test = {}
+         self.samples_weighted_test = {}
+         self.metrics_test = []
+         self.metrics_idx_test = {}
+         self.metrics_var_test = {}
+
+         self.model_names = []
+         self.metrics_names = []
+         self.metrics_dict = {'MAE': self.calc_MAE,
+                              'RMSE': self.calc_RMSE,
+                              'R2': self.calc_R2,
+                              'CRPS': self.calc_CRPS,
+                              'bias': self.calc_bias
+                              }
+         self.linecolors = ['#0072B2',
+                            '#E69F00',
+                            '#882255',
+                            '#009E73',
+                            '#D55E00'
+                            ]
+
+     def set_to_v1_vars(self):
+         '''
+         This function sets the inputs and outputs to the V1 subset.
+         '''
+         self.input_vars = self.v1_inputs
+         self.target_vars = self.v1_outputs
+         self.input_feature_len = 124
+         self.target_feature_len = 128
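+         # These totals follow directly from self.var_lens:
+         # 124 inputs  = 2 profile variables x 60 levels + 4 scalars
+         # 128 outputs = 2 profile variables x 60 levels + 8 scalars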
+
+     def get_xrdata(self, file, file_vars = None):
+         '''
+         This function reads in a file and returns an xarray dataset with the variables specified.
+         file_vars must be a list of strings.
+         '''
+         ds = xr.open_dataset(file, engine = 'netcdf4')
+         if file_vars is not None:
+             ds = ds[file_vars]
+         ds = ds.merge(self.grid_info[['lat','lon']])
+         ds = ds.where((ds['lat']>-999)*(ds['lat']<999), drop=True)
+         ds = ds.where((ds['lon']>-999)*(ds['lon']<999), drop=True)
+         return ds
+
+     def get_input(self, input_file):
+         '''
+         This function reads in a file and returns an xarray dataset with the input variables for the emulator.
+         '''
+         # read inputs
+         return self.get_xrdata(input_file, self.input_vars)
+
+     def get_target(self, input_file):
+         '''
+         This function reads in a file and returns an xarray dataset with the target variables for the emulator.
+         '''
+         # read inputs
+         ds_input = self.get_input(input_file)
+         ds_target = self.get_xrdata(input_file.replace('.mli.','.mlo.'))
+         # each timestep is 20 minutes which corresponds to 1200 seconds
+         ds_target['ptend_t'] = (ds_target['state_t'] - ds_input['state_t'])/1200 # T tendency [K/s]
+         ds_target['ptend_q0001'] = (ds_target['state_q0001'] - ds_input['state_q0001'])/1200 # Q tendency [kg/kg/s]
+         ds_target = ds_target[self.target_vars]
+         return ds_target
+
+     def set_regexps(self, data_split, regexps):
+         '''
+         This function sets the regular expressions used for getting the filelist for train, val, scoring, and test.
+         '''
+         assert data_split in ['train', 'val', 'scoring', 'test'], 'Provided data_split is not valid. Available options are train, val, scoring, and test.'
+         if data_split == 'train':
+             self.train_regexps = regexps
+         elif data_split == 'val':
+             self.val_regexps = regexps
+         elif data_split == 'scoring':
+             self.scoring_regexps = regexps
+         elif data_split == 'test':
+             self.test_regexps = regexps
+
+     def set_stride_sample(self, data_split, stride_sample):
+         '''
+         This function sets the stride_sample for train, val, scoring, and test.
+         '''
+         assert data_split in ['train', 'val', 'scoring', 'test'], 'Provided data_split is not valid. Available options are train, val, scoring, and test.'
+         if data_split == 'train':
+             self.train_stride_sample = stride_sample
+         elif data_split == 'val':
+             self.val_stride_sample = stride_sample
+         elif data_split == 'scoring':
+             self.scoring_stride_sample = stride_sample
+         elif data_split == 'test':
+             self.test_stride_sample = stride_sample
+
+     def set_filelist(self, data_split):
+         '''
+         This function sets the filelists corresponding to data splits for train, val, scoring, and test.
+         '''
+         filelist = []
+         assert data_split in ['train', 'val', 'scoring', 'test'], 'Provided data_split is not valid. Available options are train, val, scoring, and test.'
+         if data_split == 'train':
+             assert self.train_regexps is not None, 'regexps for train is not set.'
+             assert self.train_stride_sample is not None, 'stride_sample for train is not set.'
+             for regexp in self.train_regexps:
+                 filelist = filelist + glob.glob(self.data_path + "*/" + regexp)
+             self.train_filelist = sorted(filelist)[::self.train_stride_sample]
+         elif data_split == 'val':
+             assert self.val_regexps is not None, 'regexps for val is not set.'
+             assert self.val_stride_sample is not None, 'stride_sample for val is not set.'
+             for regexp in self.val_regexps:
+                 filelist = filelist + glob.glob(self.data_path + "*/" + regexp)
+             self.val_filelist = sorted(filelist)[::self.val_stride_sample]
+         elif data_split == 'scoring':
+             assert self.scoring_regexps is not None, 'regexps for scoring is not set.'
+             assert self.scoring_stride_sample is not None, 'stride_sample for scoring is not set.'
+             for regexp in self.scoring_regexps:
+                 filelist = filelist + glob.glob(self.data_path + "*/" + regexp)
+             self.scoring_filelist = sorted(filelist)[::self.scoring_stride_sample]
+         elif data_split == 'test':
+             assert self.test_regexps is not None, 'regexps for test is not set.'
+             assert self.test_stride_sample is not None, 'stride_sample for test is not set.'
+             for regexp in self.test_regexps:
+                 filelist = filelist + glob.glob(self.data_path + "*/" + regexp)
+             self.test_filelist = sorted(filelist)[::self.test_stride_sample]
+
+     def get_filelist(self, data_split):
+         '''
+         This function returns the filelist corresponding to data splits for train, val, scoring, and test.
+         '''
+         assert data_split in ['train', 'val', 'scoring', 'test'], 'Provided data_split is not valid. Available options are train, val, scoring, and test.'
+         if data_split == 'train':
+             assert self.train_filelist is not None, 'filelist for train is not set.'
+             return self.train_filelist
+         elif data_split == 'val':
+             assert self.val_filelist is not None, 'filelist for val is not set.'
+             return self.val_filelist
+         elif data_split == 'scoring':
+             assert self.scoring_filelist is not None, 'filelist for scoring is not set.'
+             return self.scoring_filelist
+         elif data_split == 'test':
+             assert self.test_filelist is not None, 'filelist for test is not set.'
+             return self.test_filelist
+
+     def load_ncdata_with_generator(self, data_split):
+         '''
+         This function works as a dataloader when training the emulator with raw netCDF files.
+         This can be used as a dataloader during training or it can be used to create entire datasets.
+         When used as a dataloader for training, I/O can slow down training considerably.
+         This function also normalizes the data.
+         mli corresponds to input
+         mlo corresponds to target
+         '''
+         filelist = self.get_filelist(data_split)
+         def gen():
+             for file in filelist:
+                 # read inputs
+                 ds_input = self.get_input(file)
+                 # read targets
+                 ds_target = self.get_target(file)
+
+                 # normalization, scaling
+                 ds_input = (ds_input - self.input_mean)/(self.input_max - self.input_min)
+                 ds_target = ds_target*self.output_scale
+
+                 # stack
+                 # ds = ds.stack({'batch':{'sample','ncol'}})
+                 ds_input = ds_input.stack({'batch':{'ncol'}})
+                 ds_input = ds_input.to_stacked_array('mlvar', sample_dims=['batch'], name='mli')
+                 # dso = dso.stack({'batch':{'sample','ncol'}})
+                 ds_target = ds_target.stack({'batch':{'ncol'}})
+                 ds_target = ds_target.to_stacked_array('mlvar', sample_dims=['batch'], name='mlo')
+                 yield (ds_input.values, ds_target.values)
+
+         return tf.data.Dataset.from_generator(
+             gen,
+             output_types = (tf.float64, tf.float64),
+             output_shapes = ((None,124),(None,128))
+         )
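+
+     # Minimal usage sketch (the data path and file regexps below are
+     # illustrative assumptions; they must match a real ClimSim-style
+     # directory of *.mli.*.nc / *.mlo.*.nc files):
+     #   data = data_utils(grid_info, input_mean, input_max, input_min, output_scale)
+     #   data.set_to_v1_vars()
+     #   data.data_path = '/path/to/climsim/train/'
+     #   data.set_regexps('train', ['E3SM-MMF.mli.000[1-7]-*-*-*.nc'])
+     #   data.set_stride_sample('train', 7)
+     #   data.set_filelist('train')
+     #   tf_dataset = data.load_ncdata_with_generator('train')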
+
+     def save_as_npy(self,
+                     data_split,
+                     save_path = '',
+                     save_latlontime_dict = False):
+         '''
+         This function saves the training data as a .npy file. Prefix should be train or val.
+         '''
+         prefix = save_path + data_split
+         data_loader = self.load_ncdata_with_generator(data_split)
+         npy_iterator = list(data_loader.as_numpy_iterator())
+         npy_input = np.concatenate([npy_iterator[x][0] for x in range(len(npy_iterator))])
+         npy_target = np.concatenate([npy_iterator[x][1] for x in range(len(npy_iterator))])
+         with open(prefix + '_input.npy', 'wb') as f: # prefix already includes save_path
+             np.save(f, np.float32(npy_input))
+         with open(prefix + '_target.npy', 'wb') as f:
+             np.save(f, np.float32(npy_target))
+         if data_split == 'train':
+             data_files = self.train_filelist
+         elif data_split == 'val':
+             data_files = self.val_filelist
+         elif data_split == 'scoring':
+             data_files = self.scoring_filelist
+         elif data_split == 'test':
+             data_files = self.test_filelist
+         if save_latlontime_dict:
+             dates = [re.sub(r'^.*mli\.', '', x) for x in data_files]
+             dates = [re.sub(r'\.nc$', '', x) for x in dates]
+             repeat_dates = []
+             for date in dates:
+                 for i in range(self.latlonnum):
+                     repeat_dates.append(date)
+             latlontime = {i: [(self.grid_info['lat'].values[i%self.latlonnum], self.grid_info['lon'].values[i%self.latlonnum]), repeat_dates[i]] for i in range(npy_input.shape[0])}
+             with open(prefix + '_indextolatlontime.pkl', 'wb') as f:
+                 pickle.dump(latlontime, f)
+
+     def reshape_npy(self, var_arr, var_arr_dim):
+         '''
+         This function reshapes a variable in numpy such that time gets its own axis (instead of being num_samples x num_levels).
+         Shape of target would be (timestep, lat/lon combo, num_levels)
+         '''
+         var_arr = var_arr.reshape((int(var_arr.shape[0]/self.latlonnum), self.latlonnum, var_arr_dim))
+         return var_arr
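+     # e.g., a flattened array of shape (num_samples, 60), where
+     # num_samples = n_timesteps * latlonnum, becomes (n_timesteps, latlonnum, 60)
+     # via reshape_npy(arr, 60).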
+
+     @staticmethod
+     def ls(dir_path = ''):
+         '''
+         You can treat this as a Python wrapper for the bash command "ls".
+         '''
+         return os.popen(' '.join(['ls', dir_path])).read().splitlines()
+
+     @staticmethod
+     def set_plot_params():
+         '''
+         This function sets the plot parameters for matplotlib.
+         '''
+         plt.close('all')
+         plt.rcParams.update(plt.rcParamsDefault)
+         plt.rc('font', family='sans-serif')
+         plt.rcParams.update({'font.size': 32,
+                              'lines.linewidth': 2,
+                              'axes.labelsize': 32,
+                              'axes.titlesize': 32,
+                              'xtick.labelsize': 32,
+                              'ytick.labelsize': 32,
+                              'legend.fontsize': 32,
+                              'axes.linewidth': 2,
+                              "pgf.texsystem": "pdflatex"
+                              })
+         # %config InlineBackend.figure_format = 'retina'
+         # use the above line when working in a jupyter notebook
+
+     @staticmethod
+     def load_npy_file(load_path = ''):
+         '''
+         This function loads a prediction .npy file.
+         '''
+         with open(load_path, 'rb') as f:
+             pred = np.load(f)
+         return pred
+
+     @staticmethod
+     def load_h5_file(load_path = ''):
+         '''
+         This function loads a prediction .h5 file.
+         '''
+         hf = h5py.File(load_path, 'r')
+         pred = np.array(hf.get('pred'))
+         return pred
+
+     def set_pressure_grid(self, data_split):
+         '''
+         This function sets the pressure weighting for metrics.
+         '''
+         assert data_split in ['train', 'val', 'scoring', 'test'], 'Provided data_split is not valid. Available options are train, val, scoring, and test.'
+
+         if data_split == 'train':
+             assert self.input_train is not None
+             state_ps = self.input_train[:,120]*(self.input_max['state_ps'].values - self.input_min['state_ps'].values) + self.input_mean['state_ps'].values
+             state_ps = np.reshape(state_ps, (-1, self.latlonnum))
+             pressure_grid_p1 = np.array(self.grid_info['P0']*self.grid_info['hyai'])[:,np.newaxis,np.newaxis]
+             pressure_grid_p2 = self.grid_info['hybi'].values[:, np.newaxis, np.newaxis] * state_ps[np.newaxis, :, :]
+             self.pressure_grid_train = pressure_grid_p1 + pressure_grid_p2
+             self.dp_train = self.pressure_grid_train[1:61,:,:] - self.pressure_grid_train[0:60,:,:]
+             self.dp_train = self.dp_train.transpose((1,2,0))
+         elif data_split == 'val':
+             assert self.input_val is not None
+             state_ps = self.input_val[:,120]*(self.input_max['state_ps'].values - self.input_min['state_ps'].values) + self.input_mean['state_ps'].values
+             state_ps = np.reshape(state_ps, (-1, self.latlonnum))
+             pressure_grid_p1 = np.array(self.grid_info['P0']*self.grid_info['hyai'])[:,np.newaxis,np.newaxis]
+             pressure_grid_p2 = self.grid_info['hybi'].values[:, np.newaxis, np.newaxis] * state_ps[np.newaxis, :, :]
+             self.pressure_grid_val = pressure_grid_p1 + pressure_grid_p2
+             self.dp_val = self.pressure_grid_val[1:61,:,:] - self.pressure_grid_val[0:60,:,:]
+             self.dp_val = self.dp_val.transpose((1,2,0))
+         elif data_split == 'scoring':
+             assert self.input_scoring is not None
+             state_ps = self.input_scoring[:,120]*(self.input_max['state_ps'].values - self.input_min['state_ps'].values) + self.input_mean['state_ps'].values
+             state_ps = np.reshape(state_ps, (-1, self.latlonnum))
+             pressure_grid_p1 = np.array(self.grid_info['P0']*self.grid_info['hyai'])[:,np.newaxis,np.newaxis]
+             pressure_grid_p2 = self.grid_info['hybi'].values[:, np.newaxis, np.newaxis] * state_ps[np.newaxis, :, :]
+             self.pressure_grid_scoring = pressure_grid_p1 + pressure_grid_p2
+             self.dp_scoring = self.pressure_grid_scoring[1:61,:,:] - self.pressure_grid_scoring[0:60,:,:]
+             self.dp_scoring = self.dp_scoring.transpose((1,2,0))
+         elif data_split == 'test':
+             assert self.input_test is not None
+             state_ps = self.input_test[:,120]*(self.input_max['state_ps'].values - self.input_min['state_ps'].values) + self.input_mean['state_ps'].values
+             state_ps = np.reshape(state_ps, (-1, self.latlonnum))
+             pressure_grid_p1 = np.array(self.grid_info['P0']*self.grid_info['hyai'])[:,np.newaxis,np.newaxis]
+             pressure_grid_p2 = self.grid_info['hybi'].values[:, np.newaxis, np.newaxis] * state_ps[np.newaxis, :, :]
+             self.pressure_grid_test = pressure_grid_p1 + pressure_grid_p2
+             self.dp_test = self.pressure_grid_test[1:61,:,:] - self.pressure_grid_test[0:60,:,:]
+             self.dp_test = self.dp_test.transpose((1,2,0))
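+     # The interface pressures above use the standard hybrid sigma-pressure
+     # definition p_i = hyai_i*P0 + hybi_i*ps, so dp = p_(i+1) - p_i is the
+     # pressure thickness of each of the 60 model levels, and dp/g is the air
+     # mass per unit area used to weight tendencies in output_weighting.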
+
+     def get_pressure_grid_plotting(self, data_split):
+         '''
+         This function creates the temporally and zonally averaged pressure grid corresponding to a given data split.
+         '''
+         filelist = self.get_filelist(data_split)
+         ps = np.concatenate([self.get_xrdata(file, ['state_ps'])['state_ps'].values[np.newaxis, :] for file in tqdm(filelist)], axis = 0)[:, :, np.newaxis]
+         hyam_component = self.hyam[np.newaxis, np.newaxis, :]*self.p0
+         hybm_component = self.hybm[np.newaxis, np.newaxis, :]*ps
+         pressures = np.mean(hyam_component + hybm_component, axis = 0)
+         pg_lats = []
+         def find_keys(dictionary, value):
+             keys = []
+             for key, val in dictionary.items():
+                 if val[0] == value:
+                     keys.append(key)
+             return keys
+         for lat in self.lats:
+             indices = find_keys(self.indextolatlon, lat)
+             pg_lats.append(np.mean(pressures[indices, :], axis = 0)[:, np.newaxis])
+         pressure_grid_plotting = np.concatenate(pg_lats, axis = 1)
+         return pressure_grid_plotting
+
+     def output_weighting(self, output, data_split):
+         '''
+         This function does four transformations, and assumes we are using V1 variables:
+         [0] Undoes the output scaling
+         [1] Weights vertical levels by dp/g
+         [2] Weights the horizontal area of each grid cell by a[x]/mean(a[x])
+         [3] Converts units to a common energy unit
+         '''
+         assert data_split in ['train', 'val', 'scoring', 'test'], 'Provided data_split is not valid. Available options are train, val, scoring, and test.'
+         num_samples = output.shape[0]
+         heating = output[:,:60].reshape((int(num_samples/self.latlonnum), self.latlonnum, 60))
+         moistening = output[:,60:120].reshape((int(num_samples/self.latlonnum), self.latlonnum, 60))
+         netsw = output[:,120].reshape((int(num_samples/self.latlonnum), self.latlonnum))
+         flwds = output[:,121].reshape((int(num_samples/self.latlonnum), self.latlonnum))
+         precsc = output[:,122].reshape((int(num_samples/self.latlonnum), self.latlonnum))
+         precc = output[:,123].reshape((int(num_samples/self.latlonnum), self.latlonnum))
+         sols = output[:,124].reshape((int(num_samples/self.latlonnum), self.latlonnum))
+         soll = output[:,125].reshape((int(num_samples/self.latlonnum), self.latlonnum))
+         solsd = output[:,126].reshape((int(num_samples/self.latlonnum), self.latlonnum))
+         solld = output[:,127].reshape((int(num_samples/self.latlonnum), self.latlonnum))
+
+         # heating = heating.transpose((2,0,1))
+         # moistening = moistening.transpose((2,0,1))
+         # scalar_outputs = scalar_outputs.transpose((2,0,1))
+
+         # [0] Undo output scaling
+         heating = heating/self.output_scale['ptend_t'].values[np.newaxis, np.newaxis, :]
+         moistening = moistening/self.output_scale['ptend_q0001'].values[np.newaxis, np.newaxis, :]
+         netsw = netsw/self.output_scale['cam_out_NETSW'].values
+         flwds = flwds/self.output_scale['cam_out_FLWDS'].values
+         precsc = precsc/self.output_scale['cam_out_PRECSC'].values
+         precc = precc/self.output_scale['cam_out_PRECC'].values
+         sols = sols/self.output_scale['cam_out_SOLS'].values
+         soll = soll/self.output_scale['cam_out_SOLL'].values
+         solsd = solsd/self.output_scale['cam_out_SOLSD'].values
+         solld = solld/self.output_scale['cam_out_SOLLD'].values
+
+         # [1] Weight vertical levels by dp/g
+         # only for vertically-resolved variables, e.g. ptend_{t,q0001}
+         # dp/g = -\rho * dz
+         if data_split == 'train':
+             dp = self.dp_train
+         elif data_split == 'val':
+             dp = self.dp_val
+         elif data_split == 'scoring':
+             dp = self.dp_scoring
+         elif data_split == 'test':
+             dp = self.dp_test
+         heating = heating * dp/self.grav
+         moistening = moistening * dp/self.grav
+
+         # [2] weight by area
+         heating = heating * self.area_wgt[np.newaxis, :, np.newaxis]
+         moistening = moistening * self.area_wgt[np.newaxis, :, np.newaxis]
+         netsw = netsw * self.area_wgt[np.newaxis, :]
+         flwds = flwds * self.area_wgt[np.newaxis, :]
+         precsc = precsc * self.area_wgt[np.newaxis, :]
+         precc = precc * self.area_wgt[np.newaxis, :]
+         sols = sols * self.area_wgt[np.newaxis, :]
+         soll = soll * self.area_wgt[np.newaxis, :]
+         solsd = solsd * self.area_wgt[np.newaxis, :]
+         solld = solld * self.area_wgt[np.newaxis, :]
+
+         # [3] unit conversion
+         heating = heating * self.target_energy_conv['ptend_t']
+         moistening = moistening * self.target_energy_conv['ptend_q0001']
+         netsw = netsw * self.target_energy_conv['cam_out_NETSW']
+         flwds = flwds * self.target_energy_conv['cam_out_FLWDS']
+         precsc = precsc * self.target_energy_conv['cam_out_PRECSC']
+         precc = precc * self.target_energy_conv['cam_out_PRECC']
+         sols = sols * self.target_energy_conv['cam_out_SOLS']
+         soll = soll * self.target_energy_conv['cam_out_SOLL']
+         solsd = solsd * self.target_energy_conv['cam_out_SOLSD']
+         solld = solld * self.target_energy_conv['cam_out_SOLLD']
+
+         return {'ptend_t':heating,
+                 'ptend_q0001':moistening,
+                 'cam_out_NETSW':netsw,
+                 'cam_out_FLWDS':flwds,
+                 'cam_out_PRECSC':precsc,
+                 'cam_out_PRECC':precc,
+                 'cam_out_SOLS':sols,
+                 'cam_out_SOLL':soll,
+                 'cam_out_SOLSD':solsd,
+                 'cam_out_SOLLD':solld}
+
+     def reweight_target(self, data_split):
+         '''
+         data_split should be train, val, scoring, or test
+         weights target variables assuming V1 outputs using the output_weighting function
+         '''
+         assert data_split in ['train', 'val', 'scoring', 'test'], 'Provided data_split is not valid. Available options are train, val, scoring, and test.'
+         if data_split == 'train':
+             assert self.target_train is not None
+             self.target_weighted_train = self.output_weighting(self.target_train, data_split)
+         elif data_split == 'val':
+             assert self.target_val is not None
+             self.target_weighted_val = self.output_weighting(self.target_val, data_split)
+         elif data_split == 'scoring':
+             assert self.target_scoring is not None
+             self.target_weighted_scoring = self.output_weighting(self.target_scoring, data_split)
+         elif data_split == 'test':
+             assert self.target_test is not None
+             self.target_weighted_test = self.output_weighting(self.target_test, data_split)
+
+     def reweight_preds(self, data_split):
+         '''
+         weights predictions assuming V1 outputs using the output_weighting function
+         '''
+         assert data_split in ['train', 'val', 'scoring', 'test'], 'Provided data_split is not valid. Available options are train, val, scoring, and test.'
+         assert self.model_names is not None
+
+         if data_split == 'train':
+             assert self.preds_train is not None
+             for model_name in self.model_names:
+                 self.preds_weighted_train[model_name] = self.output_weighting(self.preds_train[model_name], data_split)
+         elif data_split == 'val':
+             assert self.preds_val is not None
+             for model_name in self.model_names:
+                 self.preds_weighted_val[model_name] = self.output_weighting(self.preds_val[model_name], data_split)
+         elif data_split == 'scoring':
+             assert self.preds_scoring is not None
+             for model_name in self.model_names:
+                 self.preds_weighted_scoring[model_name] = self.output_weighting(self.preds_scoring[model_name], data_split)
+         elif data_split == 'test':
+             assert self.preds_test is not None
+             for model_name in self.model_names:
+                 self.preds_weighted_test[model_name] = self.output_weighting(self.preds_test[model_name], data_split)
+
+     def calc_MAE(self, pred, target, avg_grid = True):
+         '''
+         calculate 'globally averaged' mean absolute error
+         for vertically-resolved variables, shape should be time x grid x level
+         for scalars, shape should be time x grid
+
+         returns vector of length level or 1
+         '''
+         assert pred.shape[1] == self.latlonnum
+         assert pred.shape == target.shape
+         mae = np.abs(pred - target).mean(axis = 0)
+         if avg_grid:
+             return mae.mean(axis = 0) # we decided to average globally at end
+         else:
+             return mae
+
+     def calc_RMSE(self, pred, target, avg_grid = True):
+         '''
+         calculate 'globally averaged' root mean squared error
+         for vertically-resolved variables, shape should be time x grid x level
+         for scalars, shape should be time x grid
+
+         returns vector of length level or 1
+         '''
+         assert pred.shape[1] == self.latlonnum
+         assert pred.shape == target.shape
+         sq_diff = (pred - target)**2
+         rmse = np.sqrt(sq_diff.mean(axis = 0)) # mean over time
+         if avg_grid:
+             return rmse.mean(axis = 0) # we decided to separately average globally at end
+         else:
+             return rmse
+
+     def calc_R2(self, pred, target, avg_grid = True):
+         '''
+         calculate 'globally averaged' R-squared
+         for vertically-resolved variables, input shape should be time x grid x level
+         for scalars, input shape should be time x grid
+
+         returns vector of length level or 1
+         '''
+         assert pred.shape[1] == self.latlonnum
+         assert pred.shape == target.shape
+         sq_diff = (pred - target)**2
+         tss_time = (target - target.mean(axis = 0)[np.newaxis, ...])**2 # deviations from the time mean
+         r_squared = 1 - sq_diff.sum(axis = 0)/tss_time.sum(axis = 0) # sum over time
+         if avg_grid:
+             return r_squared.mean(axis = 0) # we decided to separately average globally at end
+         else:
+             return r_squared
+
+     def calc_bias(self, pred, target, avg_grid = True):
+         '''
+         calculate bias
+         for vertically-resolved variables, input shape should be time x grid x level
+         for scalars, input shape should be time x grid
+
+         returns vector of length level or 1
+         '''
+         assert pred.shape[1] == self.latlonnum
+         assert pred.shape == target.shape
+         bias = pred.mean(axis = 0) - target.mean(axis = 0)
+         if avg_grid:
+             return bias.mean(axis = 0) # we decided to separately average globally at end
+         else:
+             return bias
+
+     def calc_CRPS(self, preds, target, avg_grid = True):
+         '''
+         calculate 'globally averaged' continuous ranked probability score
+         for vertically-resolved variables, input shape should be time x grid x level x num_crps_samples
+         for scalars, input shape should be time x grid x num_crps_samples
+
+         returns vector of length level or 1
+         '''
+         assert preds.shape[1] == self.latlonnum
+         num_crps = preds.shape[-1]
+         mae = np.mean(np.abs(preds - target[..., np.newaxis]), axis = (0, -1)) # mean over time and crps samples
+         diff = preds[..., 1:] - preds[..., :-1]
+         count = np.arange(1, num_crps) * np.arange(num_crps - 1, 0, -1)
+         spread = (diff * count[np.newaxis, np.newaxis, np.newaxis, :]).mean(axis = (0, -1)) # mean over time and crps samples
+         crps = mae - spread/(num_crps*(num_crps-1))
+         # already divided by two in spread by exploiting symmetry
+         if avg_grid:
+             return crps.mean(axis = 0) # we decided to separately average globally at end
+         else:
+             return crps
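+     # The spread term uses the order-statistic identity
+     # sum_{i<j} (x_j - x_i) = sum_k k*(n-k)*(x_(k+1) - x_(k)), which is what
+     # `count` encodes; note that this identity holds only when preds is
+     # sorted along its last (CRPS-sample) axis.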
+
+     def create_metrics_df(self, data_split):
+         '''
+         creates a dataframe of metrics for each model
+         '''
+         assert data_split in ['train', 'val', 'scoring', 'test'], \
+             'Provided data_split is not valid. Available options are train, val, scoring, and test.'
+         assert len(self.model_names) != 0
+         assert len(self.metrics_names) != 0
+         assert len(self.target_vars) != 0
+         assert self.target_feature_len is not None
+
+         if data_split == 'train':
+             assert len(self.preds_weighted_train) != 0
+             assert len(self.target_weighted_train) != 0
+             for model_name in self.model_names:
+                 df_var = pd.DataFrame(columns = self.metrics_names, index = self.target_vars)
+                 df_var.index.name = 'variable'
+                 df_idx = pd.DataFrame(columns = self.metrics_names, index = range(self.target_feature_len))
+                 df_idx.index.name = 'output_idx'
+                 for metric_name in self.metrics_names:
+                     current_idx = 0
+                     for target_var in self.target_vars:
+                         metric = self.metrics_dict[metric_name](self.preds_weighted_train[model_name][target_var], self.target_weighted_train[target_var])
+                         df_var.loc[target_var, metric_name] = np.mean(metric)
+                         df_idx.loc[current_idx:current_idx + self.var_lens[target_var] - 1, metric_name] = np.atleast_1d(metric)
+                         current_idx += self.var_lens[target_var]
+                 self.metrics_var_train[model_name] = df_var
+                 self.metrics_idx_train[model_name] = df_idx
+
+         elif data_split == 'val':
+             assert len(self.preds_weighted_val) != 0
+             assert len(self.target_weighted_val) != 0
+             for model_name in self.model_names:
+                 df_var = pd.DataFrame(columns = self.metrics_names, index = self.target_vars)
+                 df_var.index.name = 'variable'
+                 df_idx = pd.DataFrame(columns = self.metrics_names, index = range(self.target_feature_len))
+                 df_idx.index.name = 'output_idx'
+                 for metric_name in self.metrics_names:
+                     current_idx = 0
+                     for target_var in self.target_vars:
+                         metric = self.metrics_dict[metric_name](self.preds_weighted_val[model_name][target_var], self.target_weighted_val[target_var])
+                         df_var.loc[target_var, metric_name] = np.mean(metric)
+                         df_idx.loc[current_idx:current_idx + self.var_lens[target_var] - 1, metric_name] = np.atleast_1d(metric)
+                         current_idx += self.var_lens[target_var]
+                 self.metrics_var_val[model_name] = df_var
+                 self.metrics_idx_val[model_name] = df_idx
+
+         elif data_split == 'scoring':
+             assert len(self.preds_weighted_scoring) != 0
+             assert len(self.target_weighted_scoring) != 0
+             for model_name in self.model_names:
+                 df_var = pd.DataFrame(columns = self.metrics_names, index = self.target_vars)
+                 df_var.index.name = 'variable'
+                 df_idx = pd.DataFrame(columns = self.metrics_names, index = range(self.target_feature_len))
+                 df_idx.index.name = 'output_idx'
+                 for metric_name in self.metrics_names:
+                     current_idx = 0
+                     for target_var in self.target_vars:
+                         metric = self.metrics_dict[metric_name](self.preds_weighted_scoring[model_name][target_var], self.target_weighted_scoring[target_var])
+                         df_var.loc[target_var, metric_name] = np.mean(metric)
+                         df_idx.loc[current_idx:current_idx + self.var_lens[target_var] - 1, metric_name] = np.atleast_1d(metric)
+                         current_idx += self.var_lens[target_var]
+                 self.metrics_var_scoring[model_name] = df_var
+                 self.metrics_idx_scoring[model_name] = df_idx
+
+         elif data_split == 'test':
+             assert len(self.preds_weighted_test) != 0
+             assert len(self.target_weighted_test) != 0
+             for model_name in self.model_names:
+                 df_var = pd.DataFrame(columns = self.metrics_names, index = self.target_vars)
+                 df_var.index.name = 'variable'
+                 df_idx = pd.DataFrame(columns = self.metrics_names, index = range(self.target_feature_len))
+                 df_idx.index.name = 'output_idx'
+                 for metric_name in self.metrics_names:
+                     current_idx = 0
+                     for target_var in self.target_vars:
+                         metric = self.metrics_dict[metric_name](self.preds_weighted_test[model_name][target_var], self.target_weighted_test[target_var])
+                         df_var.loc[target_var, metric_name] = np.mean(metric)
+                         df_idx.loc[current_idx:current_idx + self.var_lens[target_var] - 1, metric_name] = np.atleast_1d(metric)
+                         current_idx += self.var_lens[target_var]
+                 self.metrics_var_test[model_name] = df_var
+                 self.metrics_idx_test[model_name] = df_idx
+
+     def reshape_daily(self, output):
+         '''
+         This function returns two numpy arrays, one for each vertically-resolved variable (heating and moistening).
+         Dimensions of the expected input are num_samples by 128 (number of target features).
+         The output argument is expected to have dimensions of num_samples by features.
+         Heating is expected to be the first feature, and moistening is expected to be the second feature.
+         Data is expected to use a stride_sample of 6 (12 samples per day, 20 minute timestep).
+         '''
+         num_samples = output.shape[0]
+         heating = output[:,:60].reshape((int(num_samples/self.latlonnum), self.latlonnum, 60))
+         moistening = output[:,60:120].reshape((int(num_samples/self.latlonnum), self.latlonnum, 60))
+         heating_daily = np.mean(heating.reshape((heating.shape[0]//12, 12, self.latlonnum, 60)), axis = 1) # Nday x latlonnum x 60
+         moistening_daily = np.mean(moistening.reshape((moistening.shape[0]//12, 12, self.latlonnum, 60)), axis = 1) # Nday x latlonnum x 60
+         heating_daily_long = []
+         moistening_daily_long = []
+         for i in range(len(self.lats)):
+             heating_daily_long.append(np.mean(heating_daily[:,self.lat_indices_list[i],:],axis=1))
+             moistening_daily_long.append(np.mean(moistening_daily[:,self.lat_indices_list[i],:],axis=1))
+         heating_daily_long = np.array(heating_daily_long) # lat x Nday x 60
+         moistening_daily_long = np.array(moistening_daily_long) # lat x Nday x 60
+         return heating_daily_long, moistening_daily_long
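+     # reshape_daily therefore yields zonal-mean daily means on a
+     # (num_lats, num_days, 60) grid, the layout consumed by plot_r2_analysis.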
+
+     def plot_r2_analysis(self, pressure_grid_plotting, save_path = ''):
+         '''
+         This function plots the R2 pressure-latitude figure shown in the SI.
+         '''
+         self.set_plot_params()
+         n_model = len(self.model_names)
+         fig, ax = plt.subplots(2,n_model, figsize=(n_model*12,18))
+         y = np.array(range(60))
+         X, Y = np.meshgrid(np.sin(self.lats*np.pi/180), y)
+         Y = pressure_grid_plotting/100
+         test_heat_daily_long, test_moist_daily_long = self.reshape_daily(self.target_scoring)
+         for i, model_name in enumerate(self.model_names):
+             pred_heat_daily_long, pred_moist_daily_long = self.reshape_daily(self.preds_scoring[model_name])
+             coeff = 1 - np.sum( (pred_heat_daily_long-test_heat_daily_long)**2, axis=1)/np.sum( (test_heat_daily_long-np.mean(test_heat_daily_long, axis=1)[:,None,:])**2, axis=1)
+             coeff = coeff[self.sort_lat_key,:]
+             coeff = coeff.T
+
+             contour_plot = ax[0,i].pcolor(X, Y, coeff, cmap='Blues', vmin = 0, vmax = 1) # pcolormesh
+             ax[0,i].contour(X, Y, coeff, [0.7], colors='orange', linewidths=[4])
+             ax[0,i].contour(X, Y, coeff, [0.9], colors='yellow', linewidths=[4])
+             ax[0,i].set_ylim(ax[0,i].get_ylim()[::-1])
+             ax[0,i].set_title(self.model_names[i] + " - Heating")
+             ax[0,i].set_xticks([])
+
+             coeff = 1 - np.sum( (pred_moist_daily_long-test_moist_daily_long)**2, axis=1)/np.sum( (test_moist_daily_long-np.mean(test_moist_daily_long, axis=1)[:,None,:])**2, axis=1)
+             coeff = coeff[self.sort_lat_key,:]
+             coeff = coeff.T
+
+             contour_plot = ax[1,i].pcolor(X, Y, coeff, cmap='Blues', vmin = 0, vmax = 1) # pcolormesh
+             ax[1,i].contour(X, Y, coeff, [0.7], colors='orange', linewidths=[4])
+             ax[1,i].contour(X, Y, coeff, [0.9], colors='yellow', linewidths=[4])
+             ax[1,i].set_ylim(ax[1,i].get_ylim()[::-1])
+             ax[1,i].set_title(self.model_names[i] + " - Moistening")
+             ax[1,i].xaxis.set_ticks([np.sin(-50/180*np.pi), 0, np.sin(50/180*np.pi)])
+             ax[1,i].xaxis.set_ticklabels([r'50$^\circ$S', r'0$^\circ$', r'50$^\circ$N'])
+             ax[1,i].xaxis.set_tick_params(width = 2)
+
+             if i != 0:
+                 ax[0,i].set_yticks([])
+                 ax[1,i].set_yticks([])
+
+         # the label placement below is valid when 3 models are considered;
+         # we want only one label per axis, so if the number of models differs
+         # from 3, please adjust the label location to center it
+
+         #ax[1,1].xaxis.set_label_coords(-0.10,-0.10)
+
+         ax[0,0].set_ylabel("Pressure [hPa]")
+         ax[0,0].yaxis.set_label_coords(-0.2,-0.09) # (-1.38,-0.09)
+         ax[0,0].yaxis.set_ticks([1000,800,600,400,200,0])
+         ax[1,0].yaxis.set_ticks([1000,800,600,400,200,0])
+
+         fig.subplots_adjust(right=0.8)
+         cbar_ax = fig.add_axes([0.82, 0.12, 0.02, 0.76])
+         cb = fig.colorbar(contour_plot, cax=cbar_ax)
+         cb.set_label("Skill Score "+r'$\left(\mathrm{R^{2}}\right)$', labelpad=50.1)
+         plt.suptitle("Baseline Models Skill for Vertically Resolved Tendencies", y = 0.97)
+         plt.subplots_adjust(hspace=0.13)
+         plt.savefig(save_path + 'press_lat_diff_models.png', bbox_inches='tight', pad_inches=0.1, dpi = 300) # save before show so the figure is not cleared
+         plt.show()
+
+     @staticmethod
+     def reshape_input_for_cnn(npy_input, save_path = ''):
+         '''
+         This function reshapes a numpy input array to be compatible with CNN training.
+         Each variable becomes its own channel.
+         For the input there are 6 channels, each with 60 vertical levels.
+         The last 4 channels correspond to scalars repeated across all 60 levels.
+         This is for V1 data only! (V2 data has more variables.)
+         '''
+         npy_input_cnn = np.stack([
+             npy_input[:, 0:60],
+             npy_input[:, 60:120],
+             np.repeat(npy_input[:, 120][:, np.newaxis], 60, axis = 1),
+             np.repeat(npy_input[:, 121][:, np.newaxis], 60, axis = 1),
+             np.repeat(npy_input[:, 122][:, np.newaxis], 60, axis = 1),
+             np.repeat(npy_input[:, 123][:, np.newaxis], 60, axis = 1)], axis = 2)
+
+         if save_path != '':
+             with open(save_path + 'train_input_cnn.npy', 'wb') as f:
+                 np.save(f, np.float32(npy_input_cnn))
+         return npy_input_cnn
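+     # The stacked input has shape (num_samples, 60, 6): the 60 levels act as
+     # the spatial axis of a 1D CNN and the 6 variables act as channels.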
+
+     @staticmethod
+     def reshape_target_for_cnn(npy_target, save_path = ''):
+         '''
+         This function reshapes a numpy target array to be compatible with CNN training.
+         Each variable becomes its own channel.
+         For the target there are 10 channels, each with 60 vertical levels.
+         The last 8 channels correspond to scalars repeated across all 60 levels.
+         This is for V1 data only! (V2 data has more variables.)
+         '''
+         npy_target_cnn = np.stack([
+             npy_target[:, 0:60],
+             npy_target[:, 60:120],
+             np.repeat(npy_target[:, 120][:, np.newaxis], 60, axis = 1),
+             np.repeat(npy_target[:, 121][:, np.newaxis], 60, axis = 1),
+             np.repeat(npy_target[:, 122][:, np.newaxis], 60, axis = 1),
+             np.repeat(npy_target[:, 123][:, np.newaxis], 60, axis = 1),
+             np.repeat(npy_target[:, 124][:, np.newaxis], 60, axis = 1),
+             np.repeat(npy_target[:, 125][:, np.newaxis], 60, axis = 1),
+             np.repeat(npy_target[:, 126][:, np.newaxis], 60, axis = 1),
+             np.repeat(npy_target[:, 127][:, np.newaxis], 60, axis = 1)], axis = 2)
+
+         if save_path != '':
+             with open(save_path + 'train_target_cnn.npy', 'wb') as f:
+                 np.save(f, np.float32(npy_target_cnn))
+         return npy_target_cnn
+
+     @staticmethod
+     def reshape_target_from_cnn(npy_predict_cnn, save_path = ''):
+         '''
+         This function reshapes the CNN target to (num_samples, 128) for standardized metrics.
+         This is for V1 data only! (V2 data has more variables.)
+         '''
+         npy_predict_cnn_reshaped = np.concatenate([
+             npy_predict_cnn[:,:,0],
+             npy_predict_cnn[:,:,1],
+             np.mean(npy_predict_cnn[:,:,2], axis = 1)[:, np.newaxis],
+             np.mean(npy_predict_cnn[:,:,3], axis = 1)[:, np.newaxis],
+             np.mean(npy_predict_cnn[:,:,4], axis = 1)[:, np.newaxis],
+             np.mean(npy_predict_cnn[:,:,5], axis = 1)[:, np.newaxis],
+             np.mean(npy_predict_cnn[:,:,6], axis = 1)[:, np.newaxis],
+             np.mean(npy_predict_cnn[:,:,7], axis = 1)[:, np.newaxis],
+             np.mean(npy_predict_cnn[:,:,8], axis = 1)[:, np.newaxis],
+             np.mean(npy_predict_cnn[:,:,9], axis = 1)[:, np.newaxis]], axis = 1)
+
+         if save_path != '':
+             with open(save_path + 'cnn_predict_reshaped.npy', 'wb') as f:
+                 np.save(f, np.float32(npy_predict_cnn_reshaped))
+         return npy_predict_cnn_reshaped
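+
+ # Example scoring workflow (a minimal sketch; the .npy file names below are
+ # illustrative assumptions, not files shipped with this module):
+ #   data = data_utils(grid_info, input_mean, input_max, input_min, output_scale)
+ #   data.set_to_v1_vars()
+ #   data.input_scoring = data_utils.load_npy_file('scoring_input.npy')
+ #   data.target_scoring = data_utils.load_npy_file('scoring_target.npy')
+ #   data.set_pressure_grid('scoring')
+ #   data.model_names = ['MLP']
+ #   data.preds_scoring = {'MLP': data_utils.load_npy_file('mlp_preds_scoring.npy')}
+ #   data.metrics_names = ['MAE', 'RMSE', 'R2']
+ #   data.reweight_target('scoring')
+ #   data.reweight_preds('scoring')
+ #   data.create_metrics_df('scoring')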