fmegahed commited on
Commit
57f0f2b
·
verified ·
1 Parent(s): 776c727

Fixing the indentation

Browse files
Files changed (1) hide show
  1. app.py +13 -32
app.py CHANGED
@@ -15,7 +15,7 @@ from statsforecast.models import (
15
  )
16
 
17
  from utilsforecast.evaluation import evaluate
18
- from utilsforecast.losses import *
19
 
20
  # Function to load and process uploaded CSV
21
  def load_data(file):
@@ -33,20 +33,15 @@ def load_data(file):
33
  except Exception as e:
34
  return None, f"Error loading data: {str(e)}"
35
 
36
-
37
- # Global store to hold cross-validation forecasts
38
- forecast_store = {}
39
-
40
- # Function to generate and return a plot
41
-
42
  def create_forecast_plot(forecast_df, original_df, window=None):
43
  plt.figure(figsize=(10, 6))
44
  unique_ids = forecast_df['unique_id'].unique()
45
- if window is not None and 'cutoff' in forecast_df.columns:
46
- forecast_df = forecast_df[forecast_df['cutoff'] == window]
47
-
48
  forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds', 'cutoff']]
49
 
 
 
 
50
  for unique_id in unique_ids:
51
  original_data = original_df[original_df['unique_id'] == unique_id]
52
  plt.plot(original_data['ds'], original_data['y'], 'k-', label='Actual')
@@ -55,7 +50,7 @@ def create_forecast_plot(forecast_df, original_df, window=None):
55
  if col in forecast_data.columns:
56
  plt.plot(forecast_data['ds'], forecast_data[col], label=col)
57
 
58
- plt.title('Forecasting Results')
59
  plt.xlabel('Date')
60
  plt.ylabel('Value')
61
  plt.legend()
@@ -84,7 +79,7 @@ def run_forecast(
84
  ):
85
  df, message = load_data(file)
86
  if df is None:
87
- return None, None, None, message
88
 
89
  models = []
90
  model_aliases = []
@@ -112,7 +107,7 @@ def run_forecast(
112
  model_aliases.append('autoarima')
113
 
114
  if not models:
115
- return None, None, None, "Please select at least one forecasting model"
116
 
117
  sf = StatsForecast(models=models, freq=frequency, n_jobs=-1)
118
 
@@ -121,7 +116,6 @@ def run_forecast(
121
  cv_results = sf.cross_validation(df=df, h=horizon, step_size=step_size, n_windows=num_windows)
122
  evaluation = evaluate(df=cv_results, metrics=[bias, mae, rmse, mape], models=model_aliases)
123
  eval_df = pd.DataFrame(evaluation).reset_index()
124
- forecast_store['cv'] = {'forecast': cv_results, 'original': df}
125
  unique_cutoffs = sorted(cv_results['cutoff'].unique())
126
  fig_forecast = create_forecast_plot(cv_results, df, window=unique_cutoffs[0])
127
  return eval_df, cv_results, fig_forecast, "Cross validation completed successfully!", unique_cutoffs
@@ -129,7 +123,7 @@ def run_forecast(
129
  else: # Fixed window
130
  train_size = len(df) - horizon
131
  if train_size <= 0:
132
- return None, None, None, f"Not enough data for horizon={horizon}"
133
 
134
  train_df = df.iloc[:train_size]
135
  test_df = df.iloc[train_size:]
@@ -141,16 +135,7 @@ def run_forecast(
141
  return eval_df, forecast, fig_forecast, "Fixed window evaluation completed successfully!", []
142
 
143
  except Exception as e:
144
- return None, None, None, f"Error during forecasting: {str(e)}"
145
-
146
-
147
- # Function to update forecast plot for selected CV window
148
- def update_forecast_plot(selected_window):
149
- data = forecast_store.get('cv')
150
- if not data:
151
- return None
152
- return create_forecast_plot(data['forecast'], data['original'], window=selected_window)
153
-
154
 
155
  # Sample CSV file generation
156
  def download_sample():
@@ -211,14 +196,11 @@ with gr.Blocks(title="StatsForecast Demo") as app:
211
  submit_btn = gr.Button("Run Forecast")
212
 
213
  with gr.Column(scale=3):
214
- window_selector = gr.Dropdown(label='Select CV Window', choices=[], visible=False)
215
  eval_output = gr.Dataframe(label="Evaluation Results")
216
  forecast_output = gr.Dataframe(label="Forecast Data")
217
  plot_output = gr.Plot(label="Forecast Plot")
218
  message_output = gr.Textbox(label="Message")
219
-
220
- def handle_forecast_output(eval_df, forecast_df, plot, msg, windows):
221
- return eval_df, forecast_df, plot, msg, gr.update(choices=[str(w) for w in windows], visible=bool(windows), value=str(windows[0]) if windows else None)
222
 
223
  submit_btn.click(
224
  fn=run_forecast,
@@ -231,8 +213,7 @@ with gr.Blocks(title="StatsForecast Demo") as app:
231
  outputs=[eval_output, forecast_output, plot_output, message_output, window_selector]
232
  )
233
 
 
 
234
  if __name__ == "__main__":
235
  app.launch(share=False)
236
-
237
- # Update plot when a window is selected
238
- window_selector.change(fn=update_forecast_plot, inputs=window_selector, outputs=plot_output)
 
15
  )
16
 
17
  from utilsforecast.evaluation import evaluate
18
+ from utilsforecast.losses import * # Assuming you need the metrics like bias, mae, rmse, mape
19
 
20
  # Function to load and process uploaded CSV
21
  def load_data(file):
 
33
  except Exception as e:
34
  return None, f"Error loading data: {str(e)}"
35
 
36
+ # Function to generate and return a plot for a specific cross-validation window
 
 
 
 
 
37
  def create_forecast_plot(forecast_df, original_df, window=None):
38
  plt.figure(figsize=(10, 6))
39
  unique_ids = forecast_df['unique_id'].unique()
 
 
 
40
  forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds', 'cutoff']]
41
 
42
+ if window is not None and 'cutoff' in forecast_df.columns:
43
+ forecast_df = forecast_df[forecast_df['cutoff'] == window]
44
+
45
  for unique_id in unique_ids:
46
  original_data = original_df[original_df['unique_id'] == unique_id]
47
  plt.plot(original_data['ds'], original_data['y'], 'k-', label='Actual')
 
50
  if col in forecast_data.columns:
51
  plt.plot(forecast_data['ds'], forecast_data[col], label=col)
52
 
53
+ plt.title(f'Forecasting Results{" (Window: " + str(window) + ")" if window else ""}')
54
  plt.xlabel('Date')
55
  plt.ylabel('Value')
56
  plt.legend()
 
79
  ):
80
  df, message = load_data(file)
81
  if df is None:
82
+ return None, None, None, message, []
83
 
84
  models = []
85
  model_aliases = []
 
107
  model_aliases.append('autoarima')
108
 
109
  if not models:
110
+ return None, None, None, "Please select at least one forecasting model", []
111
 
112
  sf = StatsForecast(models=models, freq=frequency, n_jobs=-1)
113
 
 
116
  cv_results = sf.cross_validation(df=df, h=horizon, step_size=step_size, n_windows=num_windows)
117
  evaluation = evaluate(df=cv_results, metrics=[bias, mae, rmse, mape], models=model_aliases)
118
  eval_df = pd.DataFrame(evaluation).reset_index()
 
119
  unique_cutoffs = sorted(cv_results['cutoff'].unique())
120
  fig_forecast = create_forecast_plot(cv_results, df, window=unique_cutoffs[0])
121
  return eval_df, cv_results, fig_forecast, "Cross validation completed successfully!", unique_cutoffs
 
123
  else: # Fixed window
124
  train_size = len(df) - horizon
125
  if train_size <= 0:
126
+ return None, None, None, f"Not enough data for horizon={horizon}", []
127
 
128
  train_df = df.iloc[:train_size]
129
  test_df = df.iloc[train_size:]
 
135
  return eval_df, forecast, fig_forecast, "Fixed window evaluation completed successfully!", []
136
 
137
  except Exception as e:
138
+ return None, None, None, f"Error during forecasting: {str(e)}", []
 
 
 
 
 
 
 
 
 
139
 
140
  # Sample CSV file generation
141
  def download_sample():
 
196
  submit_btn = gr.Button("Run Forecast")
197
 
198
  with gr.Column(scale=3):
 
199
  eval_output = gr.Dataframe(label="Evaluation Results")
200
  forecast_output = gr.Dataframe(label="Forecast Data")
201
  plot_output = gr.Plot(label="Forecast Plot")
202
  message_output = gr.Textbox(label="Message")
203
+ window_selector = gr.Dropdown(label="Select Forecast Window", choices=[], visible=False)
 
 
204
 
205
  submit_btn.click(
206
  fn=run_forecast,
 
213
  outputs=[eval_output, forecast_output, plot_output, message_output, window_selector]
214
  )
215
 
216
+ window_selector.change(fn=create_forecast_plot, inputs=window_selector, outputs=plot_output)
217
+
218
  if __name__ == "__main__":
219
  app.launch(share=False)