puqi commited on
Commit
2e37859
·
1 Parent(s): 942923c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -4
app.py CHANGED
@@ -16,7 +16,7 @@ from tqdm import tqdm
16
 
17
 
18
  st.title('A _Quickstart Notebook_ for :blue[ClimSim]:')
19
- st.link_button("ClimSim", "https://huggingface.co/datasets/LEAP/subsampled_low_res/tree/main",use_container_width=True,type="primary")
20
  st.header('**Step 1:** Import data_utils')
21
  st.code('''from data_utils import *''',language='python')
22
  st.header('**Step 2:** Instantiate class')
@@ -24,7 +24,7 @@ st.header('**Step 3:** Load training and validation data')
24
 
25
  st.header('**Step 4:** Train models')
26
  st.subheader('Train constant prediction model')
27
- st.sidebar.link_button("Go to Original Dataset", "https://huggingface.co/datasets/LEAP/subsampled_low_res/tree/main")
28
 
29
 
30
 
@@ -121,12 +121,73 @@ fig.tight_layout()
121
  st.pyplot(fig)
122
 
123
  # path to target input
124
- data.input_scoring = np.load('score_input_smallnn.npy')
125
-
126
  # path to target output
127
  data.target_scoring = np.load('scoring_target_small.npy')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
 
 
 
 
 
 
 
 
 
 
129
 
 
 
 
 
130
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
  st.markdown('Streamlit p')
 
16
 
17
 
18
  st.title('A _Quickstart Notebook_ for :blue[ClimSim]:')
19
+ st.link_button("ClimSim", "https://huggingface.co/datasets/LEAP/subsampled_low_res/tree/main",use_container_width=True)
20
  st.header('**Step 1:** Import data_utils')
21
  st.code('''from data_utils import *''',language='python')
22
  st.header('**Step 2:** Instantiate class')
 
24
 
25
  st.header('**Step 4:** Train models')
26
  st.subheader('Train constant prediction model')
27
+ st.link_button("Go to Original Dataset", "https://huggingface.co/datasets/LEAP/subsampled_low_res/tree/main",,use_container_width=True)
28
 
29
 
30
 
 
121
  st.pyplot(fig)
122
 
123
  # path to target input
124
+ data.input_scoring = np.load('score_input_small.npy')
 
125
  # path to target output
126
  data.target_scoring = np.load('scoring_target_small.npy')
127
+ data.set_pressure_grid(data_split = 'scoring')
128
+ # constant prediction
129
+ const_pred_scoring = np.repeat(const_model[np.newaxis, :], data.target_scoring.shape[0], axis = 0)
130
+ print(const_pred_scoring.shape)
131
+
132
+ # multiple linear regression
133
+ X_scoring = data.input_scoring
134
+ bias_vector_scoring = np.ones((X_scoring.shape[0], 1))
135
+ X_scoring = np.concatenate((X_scoring, bias_vector_scoring), axis=1)
136
+ mlr_pred_scoring = X_scoring@mlr_weights
137
+ print(mlr_pred_scoring.shape)
138
+
139
+ # Your model prediction here
140
+
141
+ # Load predictions into object
142
+ data.model_names = ['const', 'mlr'] # model name here
143
+ preds = [const_pred_scoring, mlr_pred_scoring] # add prediction here
144
+ data.preds_scoring = dict(zip(data.model_names, preds))
145
+ # weight predictions and target
146
+ data.reweight_target(data_split = 'scoring')
147
+ data.reweight_preds(data_split = 'scoring')
148
+
149
+ # set and calculate metrics
150
+ data.metrics_names = ['MAE', 'RMSE', 'R2', 'bias']
151
+ data.create_metrics_df(data_split = 'scoring')
152
+
153
+ # set plotting settings
154
+ %config InlineBackend.figure_format = 'retina'
155
+ letters = string.ascii_lowercase
156
+
157
+ # create custom dictionary for plotting
158
+ dict_var = data.metrics_var_scoring
159
+ plot_df_byvar = {}
160
+ for metric in data.metrics_names:
161
+ plot_df_byvar[metric] = pd.DataFrame([dict_var[model][metric] for model in data.model_names],
162
+ index=data.model_names)
163
+ plot_df_byvar[metric] = plot_df_byvar[metric].rename(columns = data.var_short_names).transpose()
164
 
165
+ # plot figure
166
+ fig, axes = plt.subplots(nrows = len(data.metrics_names), sharex = True)
167
+ for i in range(len(data.metrics_names)):
168
+ plot_df_byvar[data.metrics_names[i]].plot.bar(
169
+ legend = False,
170
+ ax = axes[i])
171
+ if data.metrics_names[i] != 'R2':
172
+ axes[i].set_ylabel('$W/m^2$')
173
+ else:
174
+ axes[i].set_ylim(0,1)
175
 
176
+ axes[i].set_title(f'({letters[i]}) {data.metrics_names[i]}')
177
+ axes[i].set_xlabel('Output variable')
178
+ axes[i].set_xticklabels(plot_df_byvar[data.metrics_names[i]].index, \
179
+ rotation=0, ha='center')
180
 
181
+ axes[0].legend(columnspacing = .9,
182
+ labelspacing = .3,
183
+ handleheight = .07,
184
+ handlelength = 1.5,
185
+ handletextpad = .2,
186
+ borderpad = .2,
187
+ ncol = 3,
188
+ loc = 'upper right')
189
+ fig.set_size_inches(7,8)
190
+ fig.tight_layout()
191
+ st.pyplot(fig)
192
 
193
  st.markdown('Streamlit p')