fmegahed commited on
Commit
97fbbe3
·
verified ·
1 Parent(s): c9b451d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -47
app.py CHANGED
@@ -2,6 +2,7 @@ import pandas as pd
2
  import matplotlib.pyplot as plt
3
  import gradio as gr
4
  import tempfile
 
5
 
6
  from statsforecast import StatsForecast
7
  from statsforecast.models import (
@@ -17,8 +18,7 @@ from statsforecast.models import (
17
  from utilsforecast.evaluation import evaluate
18
  from utilsforecast.losses import *
19
 
20
- forecast_store = {} # for storing CV results globally
21
-
22
  def load_data(file):
23
  if file is None:
24
  return None, "Please upload a CSV file"
@@ -34,29 +34,49 @@ def load_data(file):
34
  except Exception as e:
35
  return None, f"Error loading data: {str(e)}"
36
 
37
- def create_forecast_plot(forecast_df, original_df, window=None):
 
38
  plt.figure(figsize=(10, 6))
39
  unique_ids = forecast_df['unique_id'].unique()
40
  forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds', 'cutoff']]
41
-
42
- if window is not None and 'cutoff' in forecast_df.columns:
43
- forecast_df = forecast_df[forecast_df['cutoff'] == pd.to_datetime(window)]
44
-
 
45
  for unique_id in unique_ids:
46
  original_data = original_df[original_df['unique_id'] == unique_id]
47
  plt.plot(original_data['ds'], original_data['y'], 'k-', label='Actual')
48
  forecast_data = forecast_df[forecast_df['unique_id'] == unique_id]
49
- for col in forecast_cols:
50
- if col in forecast_data.columns:
51
- plt.plot(forecast_data['ds'], forecast_data[col], label=col)
 
52
 
53
- plt.title(f'Forecasting Results{" (Window: " + str(window) + ")" if window else ""}')
54
  plt.xlabel('Date')
55
  plt.ylabel('Value')
56
  plt.legend()
57
  plt.grid(True)
58
- return plt.gcf()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
 
60
  def run_forecast(
61
  file,
62
  frequency,
@@ -77,7 +97,7 @@ def run_forecast(
77
  ):
78
  df, message = load_data(file)
79
  if df is None:
80
- return None, None, None, message, []
81
 
82
  models = []
83
  model_aliases = []
@@ -105,31 +125,35 @@ def run_forecast(
105
  model_aliases.append('autoarima')
106
 
107
  if not models:
108
- return None, None, None, "Please select at least one forecasting model", []
109
 
110
  sf = StatsForecast(models=models, freq=frequency, n_jobs=-1)
111
 
112
  try:
113
  if eval_strategy == "Cross Validation":
114
  cv_results = sf.cross_validation(df=df, h=horizon, step_size=step_size, n_windows=num_windows)
115
-
116
- # Store for dropdown selection
117
- forecast_store['forecast'] = cv_results
118
- forecast_store['original'] = df
119
-
120
  evaluation = evaluate(df=cv_results, metrics=[bias, mae, rmse, mape], models=model_aliases)
121
  eval_df = pd.DataFrame(evaluation).reset_index()
122
-
123
- # Dropdown cutoffs
124
- unique_cutoffs = sorted(str(c) for c in cv_results['cutoff'].unique())
125
- fig_forecast = create_forecast_plot(cv_results, df, window=pd.to_datetime(unique_cutoffs[0]))
126
-
127
- return eval_df, cv_results, fig_forecast, "Cross validation completed successfully!", unique_cutoffs
128
-
129
- else:
 
 
 
 
 
 
 
 
 
130
  train_size = len(df) - horizon
131
  if train_size <= 0:
132
- return None, None, None, f"Not enough data for horizon={horizon}", []
133
 
134
  train_df = df.iloc[:train_size]
135
  test_df = df.iloc[train_size:]
@@ -138,18 +162,13 @@ def run_forecast(
138
  evaluation = evaluate(df=forecast, metrics=[bias, mae, rmse, mape], models=model_aliases)
139
  eval_df = pd.DataFrame(evaluation).reset_index()
140
  fig_forecast = create_forecast_plot(forecast, df)
141
- return eval_df, forecast, fig_forecast, "Fixed window evaluation completed successfully!", []
 
142
 
143
  except Exception as e:
144
- return None, None, None, f"Error during forecasting: {str(e)}", []
145
-
146
- def update_window_plot(window_str):
147
- if 'forecast' not in forecast_store:
148
- return None
149
- forecast_df = forecast_store['forecast']
150
- original_df = forecast_store['original']
151
- return create_forecast_plot(forecast_df, original_df, window=pd.to_datetime(window_str))
152
 
 
153
  def download_sample():
154
  sample_data = """unique_id,ds,y
155
  series1,2023-01-01,100
@@ -173,30 +192,37 @@ series1,2023-01-15,131
173
  temp.close()
174
  return temp.name
175
 
 
176
  with gr.Blocks(title="StatsForecast Demo") as app:
177
  gr.Markdown("# 📈 StatsForecast Demo App")
178
  gr.Markdown("Upload a CSV with `unique_id`, `ds`, and `y` columns to apply forecasting models.")
179
 
 
 
 
 
180
  with gr.Row():
181
  with gr.Column(scale=2):
182
  file_input = gr.File(label="Upload CSV file", file_types=[".csv"])
 
183
  download_btn = gr.Button("Download Sample Data")
184
  download_output = gr.File(label="Click to download", visible=True)
185
  download_btn.click(fn=download_sample, outputs=download_output)
186
 
187
  frequency = gr.Dropdown(choices=["H", "D", "WS", "MS", "QS", "YS"], label="Frequency", value="D")
188
  eval_strategy = gr.Radio(choices=["Fixed Window", "Cross Validation"], label="Evaluation Strategy", value="Cross Validation")
189
- horizon = gr.Slider(1, 100, value=10, step=1, label="Horizon")
190
  step_size = gr.Slider(1, 50, value=5, step=1, label="Step Size")
191
  num_windows = gr.Slider(1, 20, value=3, step=1, label="Number of Windows")
192
 
 
193
  gr.Markdown("### Model Configuration")
194
  use_historical_avg = gr.Checkbox(label="Use Historical Average", value=True)
195
  use_naive = gr.Checkbox(label="Use Naive", value=True)
196
  use_seasonal_naive = gr.Checkbox(label="Use Seasonal Naive")
197
- seasonality = gr.Number(label="Seasonality", value=5)
198
  use_window_avg = gr.Checkbox(label="Use Window Average")
199
- window_size = gr.Number(label="Window Size", value=10)
200
  use_seasonal_window_avg = gr.Checkbox(label="Use Seasonal Window Average")
201
  seasonal_window_size = gr.Number(label="Seasonal Window Size", value=2)
202
  use_autoets = gr.Checkbox(label="Use AutoETS")
@@ -207,10 +233,19 @@ with gr.Blocks(title="StatsForecast Demo") as app:
207
  with gr.Column(scale=3):
208
  eval_output = gr.Dataframe(label="Evaluation Results")
209
  forecast_output = gr.Dataframe(label="Forecast Data")
 
 
 
 
 
 
 
 
 
210
  plot_output = gr.Plot(label="Forecast Plot")
211
  message_output = gr.Textbox(label="Message")
212
- window_selector = gr.Dropdown(label="Select Forecast Window", choices=[], visible=False)
213
 
 
214
  submit_btn.click(
215
  fn=run_forecast,
216
  inputs=[
@@ -219,12 +254,25 @@ with gr.Blocks(title="StatsForecast Demo") as app:
219
  use_window_avg, window_size, use_seasonal_window_avg, seasonal_window_size,
220
  use_autoets, use_autoarima
221
  ],
222
- outputs=[
223
- eval_output, forecast_output, plot_output, message_output, window_selector
224
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  )
226
-
227
- window_selector.change(fn=update_window_plot, inputs=window_selector, outputs=plot_output)
228
 
229
  if __name__ == "__main__":
230
- app.launch(share=False)
 
2
  import matplotlib.pyplot as plt
3
  import gradio as gr
4
  import tempfile
5
+ import numpy as np
6
 
7
  from statsforecast import StatsForecast
8
  from statsforecast.models import (
 
18
  from utilsforecast.evaluation import evaluate
19
  from utilsforecast.losses import *
20
 
21
+ # Function to load and process uploaded CSV
 
22
  def load_data(file):
23
  if file is None:
24
  return None, "Please upload a CSV file"
 
34
  except Exception as e:
35
  return None, f"Error loading data: {str(e)}"
36
 
37
+ # Function to generate and return a plot
38
+ def create_forecast_plot(forecast_df, original_df, selected_cutoff=None):
39
  plt.figure(figsize=(10, 6))
40
  unique_ids = forecast_df['unique_id'].unique()
41
  forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds', 'cutoff']]
42
+
43
+ # Filter by cutoff if provided and if 'cutoff' column exists
44
+ if selected_cutoff is not None and 'cutoff' in forecast_df.columns:
45
+ forecast_df = forecast_df[forecast_df['cutoff'] == selected_cutoff]
46
+
47
  for unique_id in unique_ids:
48
  original_data = original_df[original_df['unique_id'] == unique_id]
49
  plt.plot(original_data['ds'], original_data['y'], 'k-', label='Actual')
50
  forecast_data = forecast_df[forecast_df['unique_id'] == unique_id]
51
+ if len(forecast_data) > 0: # Only plot if there's data after filtering
52
+ for col in forecast_cols:
53
+ if col in forecast_data.columns:
54
+ plt.plot(forecast_data['ds'], forecast_data[col], label=col)
55
 
56
+ plt.title('Forecasting Results')
57
  plt.xlabel('Date')
58
  plt.ylabel('Value')
59
  plt.legend()
60
  plt.grid(True)
61
+ fig = plt.gcf()
62
+ return fig
63
+
64
+ # Function to update plot based on selected cutoff
65
+ def update_plot(selected_cutoff, cv_results, original_df):
66
+ if cv_results is None or original_df is None:
67
+ return None, "No forecast data available."
68
+
69
+ try:
70
+ # Convert string back to datetime if needed
71
+ if isinstance(selected_cutoff, str):
72
+ selected_cutoff = pd.to_datetime(selected_cutoff)
73
+
74
+ fig = create_forecast_plot(cv_results, original_df, selected_cutoff)
75
+ return fig, f"Showing forecast for cutoff: {selected_cutoff}"
76
+ except Exception as e:
77
+ return None, f"Error updating plot: {str(e)}"
78
 
79
+ # Main forecasting logic
80
  def run_forecast(
81
  file,
82
  frequency,
 
97
  ):
98
  df, message = load_data(file)
99
  if df is None:
100
+ return None, None, None, None, [], message
101
 
102
  models = []
103
  model_aliases = []
 
125
  model_aliases.append('autoarima')
126
 
127
  if not models:
128
+ return None, None, None, None, [], "Please select at least one forecasting model"
129
 
130
  sf = StatsForecast(models=models, freq=frequency, n_jobs=-1)
131
 
132
  try:
133
  if eval_strategy == "Cross Validation":
134
  cv_results = sf.cross_validation(df=df, h=horizon, step_size=step_size, n_windows=num_windows)
 
 
 
 
 
135
  evaluation = evaluate(df=cv_results, metrics=[bias, mae, rmse, mape], models=model_aliases)
136
  eval_df = pd.DataFrame(evaluation).reset_index()
137
+
138
+ # Get unique cutoff dates for the dropdown
139
+ cutoff_dates = cv_results['cutoff'].unique().tolist()
140
+
141
+ # Sort cutoff dates (newest first)
142
+ cutoff_dates.sort(reverse=True)
143
+
144
+ # Use the most recent cutoff for initial plot
145
+ if cutoff_dates:
146
+ latest_cutoff = cutoff_dates[0]
147
+ fig_forecast = create_forecast_plot(cv_results, df, latest_cutoff)
148
+ else:
149
+ fig_forecast = create_forecast_plot(cv_results, df)
150
+
151
+ return eval_df, cv_results, fig_forecast, df, cutoff_dates, "Cross validation completed successfully!"
152
+
153
+ else: # Fixed window
154
  train_size = len(df) - horizon
155
  if train_size <= 0:
156
+ return None, None, None, None, [], f"Not enough data for horizon={horizon}"
157
 
158
  train_df = df.iloc[:train_size]
159
  test_df = df.iloc[train_size:]
 
162
  evaluation = evaluate(df=forecast, metrics=[bias, mae, rmse, mape], models=model_aliases)
163
  eval_df = pd.DataFrame(evaluation).reset_index()
164
  fig_forecast = create_forecast_plot(forecast, df)
165
+ # For fixed window, we don't have cutoff dates
166
+ return eval_df, forecast, fig_forecast, df, [], "Fixed window evaluation completed successfully!"
167
 
168
  except Exception as e:
169
+ return None, None, None, None, [], f"Error during forecasting: {str(e)}"
 
 
 
 
 
 
 
170
 
171
+ # Sample CSV file generation
172
  def download_sample():
173
  sample_data = """unique_id,ds,y
174
  series1,2023-01-01,100
 
192
  temp.close()
193
  return temp.name
194
 
195
+ # Gradio interface
196
  with gr.Blocks(title="StatsForecast Demo") as app:
197
  gr.Markdown("# 📈 StatsForecast Demo App")
198
  gr.Markdown("Upload a CSV with `unique_id`, `ds`, and `y` columns to apply forecasting models.")
199
 
200
+ # Store state variables
201
+ cv_results_state = gr.State(None)
202
+ original_df_state = gr.State(None)
203
+
204
  with gr.Row():
205
  with gr.Column(scale=2):
206
  file_input = gr.File(label="Upload CSV file", file_types=[".csv"])
207
+
208
  download_btn = gr.Button("Download Sample Data")
209
  download_output = gr.File(label="Click to download", visible=True)
210
  download_btn.click(fn=download_sample, outputs=download_output)
211
 
212
  frequency = gr.Dropdown(choices=["H", "D", "WS", "MS", "QS", "YS"], label="Frequency", value="D")
213
  eval_strategy = gr.Radio(choices=["Fixed Window", "Cross Validation"], label="Evaluation Strategy", value="Cross Validation")
214
+ horizon = gr.Slider(1, 100, value=14, step=1, label="Horizon")
215
  step_size = gr.Slider(1, 50, value=5, step=1, label="Step Size")
216
  num_windows = gr.Slider(1, 20, value=3, step=1, label="Number of Windows")
217
 
218
+
219
  gr.Markdown("### Model Configuration")
220
  use_historical_avg = gr.Checkbox(label="Use Historical Average", value=True)
221
  use_naive = gr.Checkbox(label="Use Naive", value=True)
222
  use_seasonal_naive = gr.Checkbox(label="Use Seasonal Naive")
223
+ seasonality = gr.Number(label="Seasonality", value=7)
224
  use_window_avg = gr.Checkbox(label="Use Window Average")
225
+ window_size = gr.Number(label="Window Size", value=3)
226
  use_seasonal_window_avg = gr.Checkbox(label="Use Seasonal Window Average")
227
  seasonal_window_size = gr.Number(label="Seasonal Window Size", value=2)
228
  use_autoets = gr.Checkbox(label="Use AutoETS")
 
233
  with gr.Column(scale=3):
234
  eval_output = gr.Dataframe(label="Evaluation Results")
235
  forecast_output = gr.Dataframe(label="Forecast Data")
236
+
237
+ # Add cutoff selection dropdown
238
+ cutoff_dropdown = gr.Dropdown(
239
+ label="Select Validation Window (Cutoff Date)",
240
+ choices=[],
241
+ interactive=True,
242
+ visible=False
243
+ )
244
+
245
  plot_output = gr.Plot(label="Forecast Plot")
246
  message_output = gr.Textbox(label="Message")
 
247
 
248
+ # Run forecast function with updated outputs
249
  submit_btn.click(
250
  fn=run_forecast,
251
  inputs=[
 
254
  use_window_avg, window_size, use_seasonal_window_avg, seasonal_window_size,
255
  use_autoets, use_autoarima
256
  ],
257
+ outputs=[eval_output, cv_results_state, plot_output, original_df_state, cutoff_dropdown, message_output]
258
+ )
259
+
260
+ # Update cutoff dropdown visibility based on evaluation strategy
261
+ def update_dropdown_visibility(strategy):
262
+ return gr.update(visible=strategy == "Cross Validation")
263
+
264
+ eval_strategy.change(
265
+ fn=update_dropdown_visibility,
266
+ inputs=[eval_strategy],
267
+ outputs=[cutoff_dropdown]
268
+ )
269
+
270
+ # Update plot when cutoff is selected
271
+ cutoff_dropdown.change(
272
+ fn=update_plot,
273
+ inputs=[cutoff_dropdown, cv_results_state, original_df_state],
274
+ outputs=[plot_output, message_output]
275
  )
 
 
276
 
277
  if __name__ == "__main__":
278
+ app.launch(share=False)