fmegahed commited on
Commit
c9b451d
·
verified ·
1 Parent(s): 08c1c42

Fix some bugs to allow multiple window plotting

Browse files
Files changed (1) hide show
  1. app.py +28 -17
app.py CHANGED
@@ -15,9 +15,10 @@ from statsforecast.models import (
15
  )
16
 
17
  from utilsforecast.evaluation import evaluate
18
- from utilsforecast.losses import * # Assuming you need the metrics like bias, mae, rmse, mape
 
 
19
 
20
- # Function to load and process uploaded CSV
21
  def load_data(file):
22
  if file is None:
23
  return None, "Please upload a CSV file"
@@ -33,7 +34,6 @@ def load_data(file):
33
  except Exception as e:
34
  return None, f"Error loading data: {str(e)}"
35
 
36
- # Function to generate and return a plot for a specific cross-validation window
37
  def create_forecast_plot(forecast_df, original_df, window=None):
38
  plt.figure(figsize=(10, 6))
39
  unique_ids = forecast_df['unique_id'].unique()
@@ -55,10 +55,8 @@ def create_forecast_plot(forecast_df, original_df, window=None):
55
  plt.ylabel('Value')
56
  plt.legend()
57
  plt.grid(True)
58
- fig = plt.gcf()
59
- return fig
60
 
61
- # Main forecasting logic
62
  def run_forecast(
63
  file,
64
  frequency,
@@ -114,13 +112,21 @@ def run_forecast(
114
  try:
115
  if eval_strategy == "Cross Validation":
116
  cv_results = sf.cross_validation(df=df, h=horizon, step_size=step_size, n_windows=num_windows)
 
 
 
 
 
117
  evaluation = evaluate(df=cv_results, metrics=[bias, mae, rmse, mape], models=model_aliases)
118
  eval_df = pd.DataFrame(evaluation).reset_index()
 
 
119
  unique_cutoffs = sorted(str(c) for c in cv_results['cutoff'].unique())
120
- fig_forecast = create_forecast_plot(cv_results, df, window=unique_cutoffs[0])
 
121
  return eval_df, cv_results, fig_forecast, "Cross validation completed successfully!", unique_cutoffs
122
 
123
- else: # Fixed window
124
  train_size = len(df) - horizon
125
  if train_size <= 0:
126
  return None, None, None, f"Not enough data for horizon={horizon}", []
@@ -137,7 +143,13 @@ def run_forecast(
137
  except Exception as e:
138
  return None, None, None, f"Error during forecasting: {str(e)}", []
139
 
140
- # Sample CSV file generation
 
 
 
 
 
 
141
  def download_sample():
142
  sample_data = """unique_id,ds,y
143
  series1,2023-01-01,100
@@ -161,7 +173,6 @@ series1,2023-01-15,131
161
  temp.close()
162
  return temp.name
163
 
164
- # Gradio interface
165
  with gr.Blocks(title="StatsForecast Demo") as app:
166
  gr.Markdown("# 📈 StatsForecast Demo App")
167
  gr.Markdown("Upload a CSV with `unique_id`, `ds`, and `y` columns to apply forecasting models.")
@@ -169,25 +180,23 @@ with gr.Blocks(title="StatsForecast Demo") as app:
169
  with gr.Row():
170
  with gr.Column(scale=2):
171
  file_input = gr.File(label="Upload CSV file", file_types=[".csv"])
172
-
173
  download_btn = gr.Button("Download Sample Data")
174
  download_output = gr.File(label="Click to download", visible=True)
175
  download_btn.click(fn=download_sample, outputs=download_output)
176
 
177
  frequency = gr.Dropdown(choices=["H", "D", "WS", "MS", "QS", "YS"], label="Frequency", value="D")
178
  eval_strategy = gr.Radio(choices=["Fixed Window", "Cross Validation"], label="Evaluation Strategy", value="Cross Validation")
179
- horizon = gr.Slider(1, 100, value=14, step=1, label="Horizon")
180
  step_size = gr.Slider(1, 50, value=5, step=1, label="Step Size")
181
  num_windows = gr.Slider(1, 20, value=3, step=1, label="Number of Windows")
182
 
183
-
184
  gr.Markdown("### Model Configuration")
185
  use_historical_avg = gr.Checkbox(label="Use Historical Average", value=True)
186
  use_naive = gr.Checkbox(label="Use Naive", value=True)
187
  use_seasonal_naive = gr.Checkbox(label="Use Seasonal Naive")
188
- seasonality = gr.Number(label="Seasonality", value=7)
189
  use_window_avg = gr.Checkbox(label="Use Window Average")
190
- window_size = gr.Number(label="Window Size", value=3)
191
  use_seasonal_window_avg = gr.Checkbox(label="Use Seasonal Window Average")
192
  seasonal_window_size = gr.Number(label="Seasonal Window Size", value=2)
193
  use_autoets = gr.Checkbox(label="Use AutoETS")
@@ -210,10 +219,12 @@ with gr.Blocks(title="StatsForecast Demo") as app:
210
  use_window_avg, window_size, use_seasonal_window_avg, seasonal_window_size,
211
  use_autoets, use_autoarima
212
  ],
213
- outputs=[eval_output, forecast_output, plot_output, message_output, window_selector]
 
 
214
  )
215
 
216
- window_selector.change(fn=create_forecast_plot, inputs=window_selector, outputs=plot_output)
217
 
218
  if __name__ == "__main__":
219
  app.launch(share=False)
 
15
  )
16
 
17
  from utilsforecast.evaluation import evaluate
18
+ from utilsforecast.losses import *
19
+
20
+ forecast_store = {} # for storing CV results globally
21
 
 
22
  def load_data(file):
23
  if file is None:
24
  return None, "Please upload a CSV file"
 
34
  except Exception as e:
35
  return None, f"Error loading data: {str(e)}"
36
 
 
37
  def create_forecast_plot(forecast_df, original_df, window=None):
38
  plt.figure(figsize=(10, 6))
39
  unique_ids = forecast_df['unique_id'].unique()
 
55
  plt.ylabel('Value')
56
  plt.legend()
57
  plt.grid(True)
58
+ return plt.gcf()
 
59
 
 
60
  def run_forecast(
61
  file,
62
  frequency,
 
112
  try:
113
  if eval_strategy == "Cross Validation":
114
  cv_results = sf.cross_validation(df=df, h=horizon, step_size=step_size, n_windows=num_windows)
115
+
116
+ # Store for dropdown selection
117
+ forecast_store['forecast'] = cv_results
118
+ forecast_store['original'] = df
119
+
120
  evaluation = evaluate(df=cv_results, metrics=[bias, mae, rmse, mape], models=model_aliases)
121
  eval_df = pd.DataFrame(evaluation).reset_index()
122
+
123
+ # Dropdown cutoffs
124
  unique_cutoffs = sorted(str(c) for c in cv_results['cutoff'].unique())
125
+ fig_forecast = create_forecast_plot(cv_results, df, window=pd.to_datetime(unique_cutoffs[0]))
126
+
127
  return eval_df, cv_results, fig_forecast, "Cross validation completed successfully!", unique_cutoffs
128
 
129
+ else:
130
  train_size = len(df) - horizon
131
  if train_size <= 0:
132
  return None, None, None, f"Not enough data for horizon={horizon}", []
 
143
  except Exception as e:
144
  return None, None, None, f"Error during forecasting: {str(e)}", []
145
 
146
+ def update_window_plot(window_str):
147
+ if 'forecast' not in forecast_store:
148
+ return None
149
+ forecast_df = forecast_store['forecast']
150
+ original_df = forecast_store['original']
151
+ return create_forecast_plot(forecast_df, original_df, window=pd.to_datetime(window_str))
152
+
153
  def download_sample():
154
  sample_data = """unique_id,ds,y
155
  series1,2023-01-01,100
 
173
  temp.close()
174
  return temp.name
175
 
 
176
  with gr.Blocks(title="StatsForecast Demo") as app:
177
  gr.Markdown("# 📈 StatsForecast Demo App")
178
  gr.Markdown("Upload a CSV with `unique_id`, `ds`, and `y` columns to apply forecasting models.")
 
180
  with gr.Row():
181
  with gr.Column(scale=2):
182
  file_input = gr.File(label="Upload CSV file", file_types=[".csv"])
 
183
  download_btn = gr.Button("Download Sample Data")
184
  download_output = gr.File(label="Click to download", visible=True)
185
  download_btn.click(fn=download_sample, outputs=download_output)
186
 
187
  frequency = gr.Dropdown(choices=["H", "D", "WS", "MS", "QS", "YS"], label="Frequency", value="D")
188
  eval_strategy = gr.Radio(choices=["Fixed Window", "Cross Validation"], label="Evaluation Strategy", value="Cross Validation")
189
+ horizon = gr.Slider(1, 100, value=10, step=1, label="Horizon")
190
  step_size = gr.Slider(1, 50, value=5, step=1, label="Step Size")
191
  num_windows = gr.Slider(1, 20, value=3, step=1, label="Number of Windows")
192
 
 
193
  gr.Markdown("### Model Configuration")
194
  use_historical_avg = gr.Checkbox(label="Use Historical Average", value=True)
195
  use_naive = gr.Checkbox(label="Use Naive", value=True)
196
  use_seasonal_naive = gr.Checkbox(label="Use Seasonal Naive")
197
+ seasonality = gr.Number(label="Seasonality", value=5)
198
  use_window_avg = gr.Checkbox(label="Use Window Average")
199
+ window_size = gr.Number(label="Window Size", value=10)
200
  use_seasonal_window_avg = gr.Checkbox(label="Use Seasonal Window Average")
201
  seasonal_window_size = gr.Number(label="Seasonal Window Size", value=2)
202
  use_autoets = gr.Checkbox(label="Use AutoETS")
 
219
  use_window_avg, window_size, use_seasonal_window_avg, seasonal_window_size,
220
  use_autoets, use_autoarima
221
  ],
222
+ outputs=[
223
+ eval_output, forecast_output, plot_output, message_output, window_selector
224
+ ]
225
  )
226
 
227
+ window_selector.change(fn=update_window_plot, inputs=window_selector, outputs=plot_output)
228
 
229
  if __name__ == "__main__":
230
  app.launch(share=False)