Update app.py
Browse files
app.py
CHANGED
@@ -16,7 +16,7 @@ from tqdm import tqdm
|
|
16 |
|
17 |
|
18 |
st.title('A _Quickstart Notebook_ for :blue[ClimSim]:')
|
19 |
-
st.link_button("ClimSim", "https://huggingface.co/datasets/LEAP/subsampled_low_res/tree/main",use_container_width=True
|
20 |
st.header('**Step 1:** Import data_utils')
|
21 |
st.code('''from data_utils import *''',language='python')
|
22 |
st.header('**Step 2:** Instantiate class')
|
@@ -24,7 +24,7 @@ st.header('**Step 3:** Load training and validation data')
|
|
24 |
|
25 |
st.header('**Step 4:** Train models')
|
26 |
st.subheader('Train constant prediction model')
|
27 |
-
st.
|
28 |
|
29 |
|
30 |
|
@@ -121,12 +121,73 @@ fig.tight_layout()
|
|
121 |
st.pyplot(fig)
|
122 |
|
123 |
# path to target input
|
124 |
-
data.input_scoring = np.load('
|
125 |
-
|
126 |
# path to target output
|
127 |
data.target_scoring = np.load('scoring_target_small.npy')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
|
|
|
|
|
|
|
|
|
130 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
|
132 |
st.markdown('Streamlit p')
|
|
|
16 |
|
17 |
|
18 |
st.title('A _Quickstart Notebook_ for :blue[ClimSim]:')
|
19 |
+
st.link_button("ClimSim", "https://huggingface.co/datasets/LEAP/subsampled_low_res/tree/main",use_container_width=True)
|
20 |
st.header('**Step 1:** Import data_utils')
|
21 |
st.code('''from data_utils import *''',language='python')
|
22 |
st.header('**Step 2:** Instantiate class')
|
|
|
24 |
|
25 |
st.header('**Step 4:** Train models')
|
26 |
st.subheader('Train constant prediction model')
|
27 |
+
st.link_button("Go to Original Dataset", "https://huggingface.co/datasets/LEAP/subsampled_low_res/tree/main",,use_container_width=True)
|
28 |
|
29 |
|
30 |
|
|
|
121 |
st.pyplot(fig)
|
122 |
|
123 |
# path to target input
|
124 |
+
data.input_scoring = np.load('score_input_small.npy')
|
|
|
125 |
# path to target output
|
126 |
data.target_scoring = np.load('scoring_target_small.npy')
|
127 |
+
data.set_pressure_grid(data_split = 'scoring')
|
128 |
+
# constant prediction
|
129 |
+
const_pred_scoring = np.repeat(const_model[np.newaxis, :], data.target_scoring.shape[0], axis = 0)
|
130 |
+
print(const_pred_scoring.shape)
|
131 |
+
|
132 |
+
# multiple linear regression
|
133 |
+
X_scoring = data.input_scoring
|
134 |
+
bias_vector_scoring = np.ones((X_scoring.shape[0], 1))
|
135 |
+
X_scoring = np.concatenate((X_scoring, bias_vector_scoring), axis=1)
|
136 |
+
mlr_pred_scoring = X_scoring@mlr_weights
|
137 |
+
print(mlr_pred_scoring.shape)
|
138 |
+
|
139 |
+
# Your model prediction here
|
140 |
+
|
141 |
+
# Load predictions into object
|
142 |
+
data.model_names = ['const', 'mlr'] # model name here
|
143 |
+
preds = [const_pred_scoring, mlr_pred_scoring] # add prediction here
|
144 |
+
data.preds_scoring = dict(zip(data.model_names, preds))
|
145 |
+
# weight predictions and target
|
146 |
+
data.reweight_target(data_split = 'scoring')
|
147 |
+
data.reweight_preds(data_split = 'scoring')
|
148 |
+
|
149 |
+
# set and calculate metrics
|
150 |
+
data.metrics_names = ['MAE', 'RMSE', 'R2', 'bias']
|
151 |
+
data.create_metrics_df(data_split = 'scoring')
|
152 |
+
|
153 |
+
# set plotting settings
|
154 |
+
%config InlineBackend.figure_format = 'retina'
|
155 |
+
letters = string.ascii_lowercase
|
156 |
+
|
157 |
+
# create custom dictionary for plotting
|
158 |
+
dict_var = data.metrics_var_scoring
|
159 |
+
plot_df_byvar = {}
|
160 |
+
for metric in data.metrics_names:
|
161 |
+
plot_df_byvar[metric] = pd.DataFrame([dict_var[model][metric] for model in data.model_names],
|
162 |
+
index=data.model_names)
|
163 |
+
plot_df_byvar[metric] = plot_df_byvar[metric].rename(columns = data.var_short_names).transpose()
|
164 |
|
165 |
+
# plot figure
|
166 |
+
fig, axes = plt.subplots(nrows = len(data.metrics_names), sharex = True)
|
167 |
+
for i in range(len(data.metrics_names)):
|
168 |
+
plot_df_byvar[data.metrics_names[i]].plot.bar(
|
169 |
+
legend = False,
|
170 |
+
ax = axes[i])
|
171 |
+
if data.metrics_names[i] != 'R2':
|
172 |
+
axes[i].set_ylabel('$W/m^2$')
|
173 |
+
else:
|
174 |
+
axes[i].set_ylim(0,1)
|
175 |
|
176 |
+
axes[i].set_title(f'({letters[i]}) {data.metrics_names[i]}')
|
177 |
+
axes[i].set_xlabel('Output variable')
|
178 |
+
axes[i].set_xticklabels(plot_df_byvar[data.metrics_names[i]].index, \
|
179 |
+
rotation=0, ha='center')
|
180 |
|
181 |
+
axes[0].legend(columnspacing = .9,
|
182 |
+
labelspacing = .3,
|
183 |
+
handleheight = .07,
|
184 |
+
handlelength = 1.5,
|
185 |
+
handletextpad = .2,
|
186 |
+
borderpad = .2,
|
187 |
+
ncol = 3,
|
188 |
+
loc = 'upper right')
|
189 |
+
fig.set_size_inches(7,8)
|
190 |
+
fig.tight_layout()
|
191 |
+
st.pyplot(fig)
|
192 |
|
193 |
st.markdown('Streamlit p')
|