fmegahed commited on
Commit
6f155ab
·
verified ·
1 Parent(s): 2039666

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -189
app.py CHANGED
@@ -2,7 +2,6 @@ import pandas as pd
2
  import matplotlib.pyplot as plt
3
  import gradio as gr
4
  import tempfile
5
- from datetime import datetime
6
 
7
  from statsforecast import StatsForecast
8
  from statsforecast.models import (
@@ -35,126 +34,18 @@ def load_data(file):
35
  return None, f"Error loading data: {str(e)}"
36
 
37
  # Function to generate and return a plot
38
- def create_forecast_plot(forecast_df, original_df, selected_cutoff=None):
39
- if forecast_df is None or original_df is None:
40
- return None
41
-
42
  plt.figure(figsize=(10, 6))
43
  unique_ids = forecast_df['unique_id'].unique()
44
-
45
- # Check if the DataFrame has a cutoff column (cross-validation format)
46
- is_cv_format = 'cutoff' in forecast_df.columns
47
-
48
- if is_cv_format:
49
- # For cross-validation format
50
- cutoffs = forecast_df['cutoff'].unique()
51
-
52
- # Use selected cutoff if provided, otherwise use the latest
53
- if selected_cutoff is not None and selected_cutoff in cutoffs:
54
- cutoff_to_use = selected_cutoff
55
- else:
56
- cutoff_to_use = max(cutoffs)
57
-
58
- # Get model names - StatsForecast uses dash (-) not underscore (_)
59
- model_names = set()
60
- for col in forecast_df.columns:
61
- if col not in ['unique_id', 'ds', 'cutoff', 'y'] and '-' in col:
62
- model_name = col.split('-')[0]
63
- model_names.add(model_name)
64
-
65
- # Print some debug info
66
- print(f"Available columns: {forecast_df.columns.tolist()}")
67
- print(f"Detected model names: {model_names}")
68
- print(f"Selected cutoff: {cutoff_to_use}")
69
-
70
- for unique_id in unique_ids:
71
- # Filter forecast data for the selected cutoff
72
- forecast_data = forecast_df[(forecast_df['unique_id'] == unique_id) &
73
- (forecast_df['cutoff'] == cutoff_to_use)]
74
-
75
- if forecast_data.empty:
76
- print(f"No forecast data for unique_id={unique_id} and cutoff={cutoff_to_use}")
77
- continue
78
-
79
- # Get original data
80
- original_data = original_df[original_df['unique_id'] == unique_id].copy()
81
-
82
- # Determine the forecast horizon based on the available data
83
- horizons = []
84
- for col in forecast_data.columns:
85
- if '-' in col:
86
- try:
87
- h = int(col.split('-')[1])
88
- horizons.append(h)
89
- except (ValueError, IndexError):
90
- continue
91
-
92
- if not horizons:
93
- print(f"No valid horizons found for models")
94
- continue
95
-
96
- max_horizon = max(horizons)
97
-
98
- # Split original data into "before cutoff" and "after cutoff"
99
- train_data = original_data[original_data['ds'] <= cutoff_to_use]
100
- test_data = original_data[original_data['ds'] > cutoff_to_use]
101
-
102
- # Limit test data to horizon length
103
- test_data = test_data.iloc[:max_horizon]
104
-
105
- # Plot training data
106
- plt.plot(train_data['ds'], train_data['y'], 'k-', label='Historical Data')
107
-
108
- # Plot test data (actual values during forecast period)
109
- if not test_data.empty:
110
- plt.plot(test_data['ds'], test_data['y'], 'k--', label='Actual (Test)')
111
-
112
- # Plot forecasts for each model
113
- for model in model_names:
114
- model_forecast_data = []
115
- model_forecast_dates = []
116
-
117
- # Get columns for this model with different horizons
118
- for h in range(1, max_horizon + 1):
119
- col = f"{model}-{h}"
120
- if col in forecast_data.columns:
121
- # There is only one row per unique_id and cutoff
122
- forecast_value = forecast_data[col].iloc[0] if not forecast_data.empty else None
123
- if forecast_value is not None:
124
- # Calculate the date for this horizon step
125
- # Use the frequency from original data to set the timedelta
126
- if frequency == 'D':
127
- forecast_date = cutoff_to_use + pd.Timedelta(days=h)
128
- elif frequency == 'H':
129
- forecast_date = cutoff_to_use + pd.Timedelta(hours=h)
130
- elif frequency == 'WS':
131
- forecast_date = cutoff_to_use + pd.Timedelta(weeks=h)
132
- elif frequency == 'MS':
133
- forecast_date = cutoff_to_use + pd.DateOffset(months=h)
134
- elif frequency == 'QS':
135
- forecast_date = cutoff_to_use + pd.DateOffset(months=3*h)
136
- elif frequency == 'YS':
137
- forecast_date = cutoff_to_use + pd.DateOffset(years=h)
138
- else:
139
- forecast_date = cutoff_to_use + pd.Timedelta(days=h)
140
-
141
- model_forecast_dates.append(forecast_date)
142
- model_forecast_data.append(forecast_value)
143
-
144
- if model_forecast_data:
145
- plt.plot(model_forecast_dates, model_forecast_data, '-o', label=model)
146
-
147
- else:
148
- # For fixed window format
149
- forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds', 'cutoff', 'y']]
150
-
151
- for unique_id in unique_ids:
152
- original_data = original_df[original_df['unique_id'] == unique_id]
153
- plt.plot(original_data['ds'], original_data['y'], 'k-', label='Actual')
154
- forecast_data = forecast_df[forecast_df['unique_id'] == unique_id]
155
- for col in forecast_cols:
156
- if col in forecast_data.columns:
157
- plt.plot(forecast_data['ds'], forecast_data[col], label=col)
158
 
159
  plt.title('Forecasting Results')
160
  plt.xlabel('Date')
@@ -164,26 +55,10 @@ def create_forecast_plot(forecast_df, original_df, selected_cutoff=None):
164
  fig = plt.gcf()
165
  return fig
166
 
167
- # Function to update plot when cutoff is selected
168
- def update_plot(forecast_df, original_df, selected_cutoff, freq):
169
- if forecast_df is None or original_df is None:
170
- return None
171
-
172
- # Convert the selected cutoff string back to datetime if needed
173
- if isinstance(selected_cutoff, str) and 'cutoff' in forecast_df.columns:
174
- # If forecast_df cutoffs are datetime objects
175
- if isinstance(forecast_df['cutoff'].iloc[0], pd.Timestamp):
176
- selected_cutoff = pd.to_datetime(selected_cutoff)
177
-
178
- return create_forecast_plot(forecast_df, original_df, selected_cutoff)
179
-
180
- # Global variable to store frequency
181
- frequency = 'D'
182
-
183
  # Main forecasting logic
184
  def run_forecast(
185
  file,
186
- freq,
187
  eval_strategy,
188
  horizon,
189
  step_size,
@@ -199,12 +74,9 @@ def run_forecast(
199
  use_autoets,
200
  use_autoarima
201
  ):
202
- global frequency
203
- frequency = freq # Store for use in create_forecast_plot
204
-
205
  df, message = load_data(file)
206
  if df is None:
207
- return None, None, None, None, message, gr.Dropdown(visible=False), frequency
208
 
209
  models = []
210
  model_aliases = []
@@ -232,7 +104,7 @@ def run_forecast(
232
  model_aliases.append('autoarima')
233
 
234
  if not models:
235
- return None, None, None, None, "Please select at least one forecasting model", gr.Dropdown(visible=False), frequency
236
 
237
  sf = StatsForecast(models=models, freq=frequency, n_jobs=-1)
238
 
@@ -241,28 +113,13 @@ def run_forecast(
241
  cv_results = sf.cross_validation(df=df, h=horizon, step_size=step_size, n_windows=num_windows)
242
  evaluation = evaluate(df=cv_results, metrics=[bias, mae, rmse, mape], models=model_aliases)
243
  eval_df = pd.DataFrame(evaluation).reset_index()
244
-
245
- # Get the cutoff dates for the dropdown
246
- cutoffs = sorted(cv_results['cutoff'].unique())
247
- cutoff_strs = [str(cutoff) for cutoff in cutoffs]
248
-
249
- # Create dropdown with cutoff dates
250
- cutoff_dropdown = gr.Dropdown(
251
- choices=cutoff_strs,
252
- value=cutoff_strs[-1] if cutoff_strs else None,
253
- label="Select Window Cutoff Date",
254
- visible=True
255
- )
256
-
257
- # Default to latest cutoff for initial plot
258
- fig_forecast = create_forecast_plot(cv_results, df, cutoffs[-1] if cutoffs else None)
259
-
260
- return eval_df, cv_results, df, fig_forecast, "Cross validation completed successfully!", cutoff_dropdown, frequency
261
 
262
  else: # Fixed window
263
  train_size = len(df) - horizon
264
  if train_size <= 0:
265
- return None, None, None, None, f"Not enough data for horizon={horizon}", gr.Dropdown(visible=False), frequency
266
 
267
  train_df = df.iloc[:train_size]
268
  test_df = df.iloc[train_size:]
@@ -270,16 +127,11 @@ def run_forecast(
270
  forecast = sf.predict(h=horizon)
271
  evaluation = evaluate(df=forecast, metrics=[bias, mae, rmse, mape], models=model_aliases)
272
  eval_df = pd.DataFrame(evaluation).reset_index()
273
-
274
- # No cutoff dropdown needed for fixed window
275
- cutoff_dropdown = gr.Dropdown(visible=False)
276
-
277
  fig_forecast = create_forecast_plot(forecast, df)
278
-
279
- return eval_df, forecast, df, fig_forecast, "Fixed window evaluation completed successfully!", cutoff_dropdown, frequency
280
 
281
  except Exception as e:
282
- return None, None, None, None, f"Error during forecasting: {str(e)}", gr.Dropdown(visible=False), frequency
283
 
284
  # Sample CSV file generation
285
  def download_sample():
@@ -310,11 +162,6 @@ with gr.Blocks(title="StatsForecast Demo") as app:
310
  gr.Markdown("# 📈 StatsForecast Demo App")
311
  gr.Markdown("Upload a CSV with `unique_id`, `ds`, and `y` columns to apply forecasting models.")
312
 
313
- # Store data for reuse between components
314
- original_df_state = gr.State(None)
315
- forecast_df_state = gr.State(None)
316
- frequency_state = gr.State("D")
317
-
318
  with gr.Row():
319
  with gr.Column(scale=2):
320
  file_input = gr.File(label="Upload CSV file", file_types=[".csv"])
@@ -323,12 +170,13 @@ with gr.Blocks(title="StatsForecast Demo") as app:
323
  download_output = gr.File(label="Click to download", visible=True)
324
  download_btn.click(fn=download_sample, outputs=download_output)
325
 
326
- freq_input = gr.Dropdown(choices=["H", "D", "WS", "MS", "QS", "YS"], label="Frequency", value="D")
327
  eval_strategy = gr.Radio(choices=["Fixed Window", "Cross Validation"], label="Evaluation Strategy", value="Cross Validation")
328
  horizon = gr.Slider(1, 100, value=14, step=1, label="Horizon")
329
  step_size = gr.Slider(1, 50, value=5, step=1, label="Step Size")
330
  num_windows = gr.Slider(1, 20, value=3, step=1, label="Number of Windows")
331
 
 
332
  gr.Markdown("### Model Configuration")
333
  use_historical_avg = gr.Checkbox(label="Use Historical Average", value=True)
334
  use_naive = gr.Checkbox(label="Use Naive", value=True)
@@ -344,35 +192,20 @@ with gr.Blocks(title="StatsForecast Demo") as app:
344
  submit_btn = gr.Button("Run Forecast")
345
 
346
  with gr.Column(scale=3):
347
- # Add cutoff selector dropdown (initially hidden)
348
- cutoff_selector = gr.Dropdown(
349
- choices=[],
350
- label="Select Window Cutoff Date",
351
- visible=False
352
- )
353
-
354
  eval_output = gr.Dataframe(label="Evaluation Results")
355
  forecast_output = gr.Dataframe(label="Forecast Data")
356
  plot_output = gr.Plot(label="Forecast Plot")
357
  message_output = gr.Textbox(label="Message")
358
 
359
- # Run forecast button click event
360
  submit_btn.click(
361
  fn=run_forecast,
362
  inputs=[
363
- file_input, freq_input, eval_strategy, horizon, step_size, num_windows,
364
  use_historical_avg, use_naive, use_seasonal_naive, seasonality,
365
  use_window_avg, window_size, use_seasonal_window_avg, seasonal_window_size,
366
  use_autoets, use_autoarima
367
  ],
368
- outputs=[eval_output, forecast_df_state, original_df_state, plot_output, message_output, cutoff_selector, frequency_state]
369
- )
370
-
371
- # Cutoff selector change event
372
- cutoff_selector.change(
373
- fn=update_plot,
374
- inputs=[forecast_df_state, original_df_state, cutoff_selector, frequency_state],
375
- outputs=[plot_output]
376
  )
377
 
378
  if __name__ == "__main__":
 
2
  import matplotlib.pyplot as plt
3
  import gradio as gr
4
  import tempfile
 
5
 
6
  from statsforecast import StatsForecast
7
  from statsforecast.models import (
 
34
  return None, f"Error loading data: {str(e)}"
35
 
36
  # Function to generate and return a plot
37
+ def create_forecast_plot(forecast_df, original_df):
 
 
 
38
  plt.figure(figsize=(10, 6))
39
  unique_ids = forecast_df['unique_id'].unique()
40
+ forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds', 'cutoff']]
41
+
42
+ for unique_id in unique_ids:
43
+ original_data = original_df[original_df['unique_id'] == unique_id]
44
+ plt.plot(original_data['ds'], original_data['y'], 'k-', label='Actual')
45
+ forecast_data = forecast_df[forecast_df['unique_id'] == unique_id]
46
+ for col in forecast_cols:
47
+ if col in forecast_data.columns:
48
+ plt.plot(forecast_data['ds'], forecast_data[col], label=col)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  plt.title('Forecasting Results')
51
  plt.xlabel('Date')
 
55
  fig = plt.gcf()
56
  return fig
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  # Main forecasting logic
59
  def run_forecast(
60
  file,
61
+ frequency,
62
  eval_strategy,
63
  horizon,
64
  step_size,
 
74
  use_autoets,
75
  use_autoarima
76
  ):
 
 
 
77
  df, message = load_data(file)
78
  if df is None:
79
+ return None, None, None, message
80
 
81
  models = []
82
  model_aliases = []
 
104
  model_aliases.append('autoarima')
105
 
106
  if not models:
107
+ return None, None, None, "Please select at least one forecasting model"
108
 
109
  sf = StatsForecast(models=models, freq=frequency, n_jobs=-1)
110
 
 
113
  cv_results = sf.cross_validation(df=df, h=horizon, step_size=step_size, n_windows=num_windows)
114
  evaluation = evaluate(df=cv_results, metrics=[bias, mae, rmse, mape], models=model_aliases)
115
  eval_df = pd.DataFrame(evaluation).reset_index()
116
+ fig_forecast = create_forecast_plot(cv_results, df)
117
+ return eval_df, cv_results, fig_forecast, "Cross validation completed successfully!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  else: # Fixed window
120
  train_size = len(df) - horizon
121
  if train_size <= 0:
122
+ return None, None, None, f"Not enough data for horizon={horizon}"
123
 
124
  train_df = df.iloc[:train_size]
125
  test_df = df.iloc[train_size:]
 
127
  forecast = sf.predict(h=horizon)
128
  evaluation = evaluate(df=forecast, metrics=[bias, mae, rmse, mape], models=model_aliases)
129
  eval_df = pd.DataFrame(evaluation).reset_index()
 
 
 
 
130
  fig_forecast = create_forecast_plot(forecast, df)
131
+ return eval_df, forecast, fig_forecast, "Fixed window evaluation completed successfully!"
 
132
 
133
  except Exception as e:
134
+ return None, None, None, f"Error during forecasting: {str(e)}"
135
 
136
  # Sample CSV file generation
137
  def download_sample():
 
162
  gr.Markdown("# 📈 StatsForecast Demo App")
163
  gr.Markdown("Upload a CSV with `unique_id`, `ds`, and `y` columns to apply forecasting models.")
164
 
 
 
 
 
 
165
  with gr.Row():
166
  with gr.Column(scale=2):
167
  file_input = gr.File(label="Upload CSV file", file_types=[".csv"])
 
170
  download_output = gr.File(label="Click to download", visible=True)
171
  download_btn.click(fn=download_sample, outputs=download_output)
172
 
173
+ frequency = gr.Dropdown(choices=["H", "D", "WS", "MS", "QS", "YS"], label="Frequency", value="D")
174
  eval_strategy = gr.Radio(choices=["Fixed Window", "Cross Validation"], label="Evaluation Strategy", value="Cross Validation")
175
  horizon = gr.Slider(1, 100, value=14, step=1, label="Horizon")
176
  step_size = gr.Slider(1, 50, value=5, step=1, label="Step Size")
177
  num_windows = gr.Slider(1, 20, value=3, step=1, label="Number of Windows")
178
 
179
+
180
  gr.Markdown("### Model Configuration")
181
  use_historical_avg = gr.Checkbox(label="Use Historical Average", value=True)
182
  use_naive = gr.Checkbox(label="Use Naive", value=True)
 
192
  submit_btn = gr.Button("Run Forecast")
193
 
194
  with gr.Column(scale=3):
 
 
 
 
 
 
 
195
  eval_output = gr.Dataframe(label="Evaluation Results")
196
  forecast_output = gr.Dataframe(label="Forecast Data")
197
  plot_output = gr.Plot(label="Forecast Plot")
198
  message_output = gr.Textbox(label="Message")
199
 
 
200
  submit_btn.click(
201
  fn=run_forecast,
202
  inputs=[
203
+ file_input, frequency, eval_strategy, horizon, step_size, num_windows,
204
  use_historical_avg, use_naive, use_seasonal_naive, seasonality,
205
  use_window_avg, window_size, use_seasonal_window_avg, seasonal_window_size,
206
  use_autoets, use_autoarima
207
  ],
208
+ outputs=[eval_output, forecast_output, plot_output, message_output]
 
 
 
 
 
 
 
209
  )
210
 
211
  if __name__ == "__main__":