fmegahed commited on
Commit
776c727
·
verified ·
1 Parent(s): 188cf42

Trying to plot each cross validation window separately

Browse files
Files changed (1) hide show
  1. app.py +31 -5
app.py CHANGED
@@ -33,10 +33,18 @@ def load_data(file):
33
  except Exception as e:
34
  return None, f"Error loading data: {str(e)}"
35
 
 
 
 
 
36
  # Function to generate and return a plot
37
- def create_forecast_plot(forecast_df, original_df):
 
38
  plt.figure(figsize=(10, 6))
39
  unique_ids = forecast_df['unique_id'].unique()
 
 
 
40
  forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds', 'cutoff']]
41
 
42
  for unique_id in unique_ids:
@@ -113,8 +121,10 @@ def run_forecast(
113
  cv_results = sf.cross_validation(df=df, h=horizon, step_size=step_size, n_windows=num_windows)
114
  evaluation = evaluate(df=cv_results, metrics=[bias, mae, rmse, mape], models=model_aliases)
115
  eval_df = pd.DataFrame(evaluation).reset_index()
116
- fig_forecast = create_forecast_plot(cv_results, df)
117
- return eval_df, cv_results, fig_forecast, "Cross validation completed successfully!"
 
 
118
 
119
  else: # Fixed window
120
  train_size = len(df) - horizon
@@ -128,11 +138,20 @@ def run_forecast(
128
  evaluation = evaluate(df=forecast, metrics=[bias, mae, rmse, mape], models=model_aliases)
129
  eval_df = pd.DataFrame(evaluation).reset_index()
130
  fig_forecast = create_forecast_plot(forecast, df)
131
- return eval_df, forecast, fig_forecast, "Fixed window evaluation completed successfully!"
132
 
133
  except Exception as e:
134
  return None, None, None, f"Error during forecasting: {str(e)}"
135
 
 
 
 
 
 
 
 
 
 
136
  # Sample CSV file generation
137
  def download_sample():
138
  sample_data = """unique_id,ds,y
@@ -192,11 +211,15 @@ with gr.Blocks(title="StatsForecast Demo") as app:
192
  submit_btn = gr.Button("Run Forecast")
193
 
194
  with gr.Column(scale=3):
 
195
  eval_output = gr.Dataframe(label="Evaluation Results")
196
  forecast_output = gr.Dataframe(label="Forecast Data")
197
  plot_output = gr.Plot(label="Forecast Plot")
198
  message_output = gr.Textbox(label="Message")
199
 
 
 
 
200
  submit_btn.click(
201
  fn=run_forecast,
202
  inputs=[
@@ -205,8 +228,11 @@ with gr.Blocks(title="StatsForecast Demo") as app:
205
  use_window_avg, window_size, use_seasonal_window_avg, seasonal_window_size,
206
  use_autoets, use_autoarima
207
  ],
208
- outputs=[eval_output, forecast_output, plot_output, message_output]
209
  )
210
 
211
  if __name__ == "__main__":
212
  app.launch(share=False)
 
 
 
 
33
  except Exception as e:
34
  return None, f"Error loading data: {str(e)}"
35
 
36
+
37
+ # Global store to hold cross-validation forecasts
38
+ forecast_store = {}
39
+
40
  # Function to generate and return a plot
41
+
42
+ def create_forecast_plot(forecast_df, original_df, window=None):
43
  plt.figure(figsize=(10, 6))
44
  unique_ids = forecast_df['unique_id'].unique()
45
+ if window is not None and 'cutoff' in forecast_df.columns:
46
+ forecast_df = forecast_df[forecast_df['cutoff'] == window]
47
+
48
  forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds', 'cutoff']]
49
 
50
  for unique_id in unique_ids:
 
121
  cv_results = sf.cross_validation(df=df, h=horizon, step_size=step_size, n_windows=num_windows)
122
  evaluation = evaluate(df=cv_results, metrics=[bias, mae, rmse, mape], models=model_aliases)
123
  eval_df = pd.DataFrame(evaluation).reset_index()
124
+ forecast_store['cv'] = {'forecast': cv_results, 'original': df}
125
+ unique_cutoffs = sorted(cv_results['cutoff'].unique())
126
+ fig_forecast = create_forecast_plot(cv_results, df, window=unique_cutoffs[0])
127
+ return eval_df, cv_results, fig_forecast, "Cross validation completed successfully!", unique_cutoffs
128
 
129
  else: # Fixed window
130
  train_size = len(df) - horizon
 
138
  evaluation = evaluate(df=forecast, metrics=[bias, mae, rmse, mape], models=model_aliases)
139
  eval_df = pd.DataFrame(evaluation).reset_index()
140
  fig_forecast = create_forecast_plot(forecast, df)
141
+ return eval_df, forecast, fig_forecast, "Fixed window evaluation completed successfully!", []
142
 
143
  except Exception as e:
144
  return None, None, None, f"Error during forecasting: {str(e)}"
145
 
146
+
147
+ # Function to update forecast plot for selected CV window
148
+ def update_forecast_plot(selected_window):
149
+ data = forecast_store.get('cv')
150
+ if not data:
151
+ return None
152
+ return create_forecast_plot(data['forecast'], data['original'], window=selected_window)
153
+
154
+
155
  # Sample CSV file generation
156
  def download_sample():
157
  sample_data = """unique_id,ds,y
 
211
  submit_btn = gr.Button("Run Forecast")
212
 
213
  with gr.Column(scale=3):
214
+ window_selector = gr.Dropdown(label='Select CV Window', choices=[], visible=False)
215
  eval_output = gr.Dataframe(label="Evaluation Results")
216
  forecast_output = gr.Dataframe(label="Forecast Data")
217
  plot_output = gr.Plot(label="Forecast Plot")
218
  message_output = gr.Textbox(label="Message")
219
 
220
+ def handle_forecast_output(eval_df, forecast_df, plot, msg, windows):
221
+ return eval_df, forecast_df, plot, msg, gr.update(choices=[str(w) for w in windows], visible=bool(windows), value=str(windows[0]) if windows else None)
222
+
223
  submit_btn.click(
224
  fn=run_forecast,
225
  inputs=[
 
228
  use_window_avg, window_size, use_seasonal_window_avg, seasonal_window_size,
229
  use_autoets, use_autoarima
230
  ],
231
+ outputs=[eval_output, forecast_output, plot_output, message_output, window_selector]
232
  )
233
 
234
  if __name__ == "__main__":
235
  app.launch(share=False)
236
+
237
+ # Update plot when a window is selected
238
+ window_selector.change(fn=update_forecast_plot, inputs=window_selector, outputs=plot_output)