fmegahed commited on
Commit
fb6ca91
·
verified ·
1 Parent(s): 8872cab

Bug fixes in the app, with both the evaluate function and the download button

Browse files
Files changed (1) hide show
  1. app.py +73 -63
app.py CHANGED
@@ -1,11 +1,11 @@
1
  import pandas as pd
2
- import numpy as np
3
  import matplotlib.pyplot as plt
4
  import gradio as gr
 
5
 
6
  from statsforecast import StatsForecast
7
  from statsforecast.models import (
8
- HistoricAverage,
9
  Naive,
10
  SeasonalNaive,
11
  WindowAverage,
@@ -15,36 +15,46 @@ from statsforecast.models import (
15
  )
16
 
17
  from utilsforecast.evaluation import evaluate
18
- import tempfile
19
 
20
- # Function to load and process the CSV file
21
  def load_data(file):
22
  if file is None:
23
  return None, "Please upload a CSV file"
24
-
25
  try:
26
- # Safe read using file-like object
27
  df = pd.read_csv(file)
28
-
29
- # Check for required columns
30
  required_cols = ['unique_id', 'ds', 'y']
31
  missing_cols = [col for col in required_cols if col not in df.columns]
32
-
33
  if missing_cols:
34
  return None, f"Missing required columns: {', '.join(missing_cols)}"
35
-
36
- # Convert 'ds' to datetime
37
  df['ds'] = pd.to_datetime(df['ds'])
38
-
39
- # Sort by date
40
  df = df.sort_values(['unique_id', 'ds'])
41
-
42
  return df, "Data loaded successfully!"
43
-
44
  except Exception as e:
45
  return None, f"Error loading data: {str(e)}"
46
 
47
- # Forecasting logic
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  def run_forecast(
49
  file,
50
  frequency,
@@ -68,99 +78,99 @@ def run_forecast(
68
  return None, None, None, message
69
 
70
  models = []
71
-
 
72
  if use_historical_avg:
73
- models.append(HistoricAverage(alias='historical_average'))
 
74
  if use_naive:
75
  models.append(Naive(alias='naive'))
 
76
  if use_seasonal_naive:
77
  models.append(SeasonalNaive(m=seasonality, alias='seasonal_naive'))
 
78
  if use_window_avg:
79
  models.append(WindowAverage(window_size=window_size, alias='window_average'))
 
80
  if use_seasonal_window_avg:
81
  models.append(SeasonalWindowAverage(m=seasonality, window_size=seasonal_window_size, alias='seasonal_window_average'))
 
82
  if use_autoets:
83
  models.append(AutoETS(alias='autoets'))
 
84
  if use_autoarima:
85
  models.append(AutoARIMA(alias='autoarima'))
 
86
 
87
  if not models:
88
  return None, None, None, "Please select at least one forecasting model"
89
 
90
  sf = StatsForecast(models=models, freq=frequency, n_jobs=-1)
91
-
92
  try:
93
  if eval_strategy == "Cross Validation":
94
  cv_results = sf.cross_validation(df=df, h=horizon, step_size=step_size, n_windows=num_windows)
95
- evaluation = evaluate(cv_results, df, metrics=['me', 'mae', 'rmse', 'mape'])
96
  eval_df = pd.DataFrame(evaluation).reset_index()
97
  fig_forecast = create_forecast_plot(cv_results, df)
98
  return eval_df, cv_results, fig_forecast, "Cross validation completed successfully!"
99
- else:
 
100
  train_size = len(df) - horizon
101
  if train_size <= 0:
102
  return None, None, None, f"Not enough data for horizon={horizon}"
103
-
104
  train_df = df.iloc[:train_size]
105
  test_df = df.iloc[train_size:]
106
  sf.fit(train_df)
107
  forecast = sf.predict(h=horizon)
108
- evaluation = evaluate(forecast, test_df, metrics=['me', 'mae', 'rmse', 'mape'])
109
  eval_df = pd.DataFrame(evaluation).reset_index()
110
  fig_forecast = create_forecast_plot(forecast, df)
111
  return eval_df, forecast, fig_forecast, "Fixed window evaluation completed successfully!"
112
-
113
  except Exception as e:
114
  return None, None, None, f"Error during forecasting: {str(e)}"
115
 
116
- # Forecast plot
117
- def create_forecast_plot(forecast_df, original_df):
118
- plt.figure(figsize=(10, 6))
119
- unique_ids = forecast_df['unique_id'].unique()
120
- forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds']]
121
-
122
- for unique_id in unique_ids:
123
- original_data = original_df[original_df['unique_id'] == unique_id]
124
- plt.plot(original_data['ds'], original_data['y'], 'k-', label='Actual')
125
- forecast_data = forecast_df[forecast_df['unique_id'] == unique_id]
126
- for col in forecast_cols:
127
- if col in forecast_data.columns:
128
- plt.plot(forecast_data['ds'], forecast_data[col], label=col)
129
-
130
- plt.title('Forecasting Results')
131
- plt.xlabel('Date')
132
- plt.ylabel('Value')
133
- plt.legend()
134
- plt.grid(True)
135
- fig = plt.gcf()
136
- return fig
137
-
138
- # Download sample file (placeholder path)
139
  def download_sample():
140
- return "sample_data.csv"
141
-
142
- # Gradio UI
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  with gr.Blocks(title="StatsForecast Demo") as app:
144
  gr.Markdown("# 📈 StatsForecast Demo App")
145
- gr.Markdown("Upload a CSV with `unique_id`, `ds`, `y` columns and configure forecasting models.")
146
 
147
  with gr.Row():
148
  with gr.Column(scale=2):
149
  file_input = gr.File(label="Upload CSV file", file_types=[".csv"])
 
150
  download_btn = gr.Button("Download Sample Data")
151
- download_output = gr.File(interactive=False, label="Sample Data", visible=False)
152
  download_btn.click(fn=download_sample, outputs=download_output)
153
 
154
- frequency = gr.Dropdown(
155
- choices=["H", "D", "WS", "MS", "QS", "YS"],
156
- label="Frequency",
157
- value="D"
158
- )
159
- eval_strategy = gr.Radio(
160
- choices=["Fixed Window", "Cross Validation"],
161
- label="Evaluation Strategy",
162
- value="Cross Validation"
163
- )
164
  horizon = gr.Slider(1, 100, value=14, label="Horizon")
165
  step_size = gr.Slider(1, 50, value=5, label="Step Size")
166
  num_windows = gr.Slider(1, 20, value=3, label="Number of Windows")
@@ -197,4 +207,4 @@ with gr.Blocks(title="StatsForecast Demo") as app:
197
  )
198
 
199
  if __name__ == "__main__":
200
- app.launch()
 
1
  import pandas as pd
 
2
  import matplotlib.pyplot as plt
3
  import gradio as gr
4
+ import tempfile
5
 
6
  from statsforecast import StatsForecast
7
  from statsforecast.models import (
8
+ HistoricalAverage,
9
  Naive,
10
  SeasonalNaive,
11
  WindowAverage,
 
15
  )
16
 
17
  from utilsforecast.evaluation import evaluate
 
18
 
19
+ # Function to load and process uploaded CSV
20
  def load_data(file):
21
  if file is None:
22
  return None, "Please upload a CSV file"
 
23
  try:
 
24
  df = pd.read_csv(file)
 
 
25
  required_cols = ['unique_id', 'ds', 'y']
26
  missing_cols = [col for col in required_cols if col not in df.columns]
 
27
  if missing_cols:
28
  return None, f"Missing required columns: {', '.join(missing_cols)}"
 
 
29
  df['ds'] = pd.to_datetime(df['ds'])
 
 
30
  df = df.sort_values(['unique_id', 'ds'])
 
31
  return df, "Data loaded successfully!"
 
32
  except Exception as e:
33
  return None, f"Error loading data: {str(e)}"
34
 
35
+ # Function to generate and return a plot
36
+ def create_forecast_plot(forecast_df, original_df):
37
+ plt.figure(figsize=(10, 6))
38
+ unique_ids = forecast_df['unique_id'].unique()
39
+ forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds']]
40
+
41
+ for unique_id in unique_ids:
42
+ original_data = original_df[original_df['unique_id'] == unique_id]
43
+ plt.plot(original_data['ds'], original_data['y'], 'k-', label='Actual')
44
+ forecast_data = forecast_df[forecast_df['unique_id'] == unique_id]
45
+ for col in forecast_cols:
46
+ if col in forecast_data.columns:
47
+ plt.plot(forecast_data['ds'], forecast_data[col], label=col)
48
+
49
+ plt.title('Forecasting Results')
50
+ plt.xlabel('Date')
51
+ plt.ylabel('Value')
52
+ plt.legend()
53
+ plt.grid(True)
54
+ fig = plt.gcf()
55
+ return fig
56
+
57
+ # Main forecasting logic
58
  def run_forecast(
59
  file,
60
  frequency,
 
78
  return None, None, None, message
79
 
80
  models = []
81
+ model_aliases = []
82
+
83
  if use_historical_avg:
84
+ models.append(HistoricalAverage(alias='historical_average'))
85
+ model_aliases.append('historical_average')
86
  if use_naive:
87
  models.append(Naive(alias='naive'))
88
+ model_aliases.append('naive')
89
  if use_seasonal_naive:
90
  models.append(SeasonalNaive(m=seasonality, alias='seasonal_naive'))
91
+ model_aliases.append('seasonal_naive')
92
  if use_window_avg:
93
  models.append(WindowAverage(window_size=window_size, alias='window_average'))
94
+ model_aliases.append('window_average')
95
  if use_seasonal_window_avg:
96
  models.append(SeasonalWindowAverage(m=seasonality, window_size=seasonal_window_size, alias='seasonal_window_average'))
97
+ model_aliases.append('seasonal_window_average')
98
  if use_autoets:
99
  models.append(AutoETS(alias='autoets'))
100
+ model_aliases.append('autoets')
101
  if use_autoarima:
102
  models.append(AutoARIMA(alias='autoarima'))
103
+ model_aliases.append('autoarima')
104
 
105
  if not models:
106
  return None, None, None, "Please select at least one forecasting model"
107
 
108
  sf = StatsForecast(models=models, freq=frequency, n_jobs=-1)
109
+
110
  try:
111
  if eval_strategy == "Cross Validation":
112
  cv_results = sf.cross_validation(df=df, h=horizon, step_size=step_size, n_windows=num_windows)
113
+ evaluation = evaluate(df=cv_results, metrics=['me', 'mae', 'rmse', 'mape'], models=model_aliases)
114
  eval_df = pd.DataFrame(evaluation).reset_index()
115
  fig_forecast = create_forecast_plot(cv_results, df)
116
  return eval_df, cv_results, fig_forecast, "Cross validation completed successfully!"
117
+
118
+ else: # Fixed window
119
  train_size = len(df) - horizon
120
  if train_size <= 0:
121
  return None, None, None, f"Not enough data for horizon={horizon}"
122
+
123
  train_df = df.iloc[:train_size]
124
  test_df = df.iloc[train_size:]
125
  sf.fit(train_df)
126
  forecast = sf.predict(h=horizon)
127
+ evaluation = evaluate(df=forecast, metrics=['me', 'mae', 'rmse', 'mape'], models=model_aliases)
128
  eval_df = pd.DataFrame(evaluation).reset_index()
129
  fig_forecast = create_forecast_plot(forecast, df)
130
  return eval_df, forecast, fig_forecast, "Fixed window evaluation completed successfully!"
131
+
132
  except Exception as e:
133
  return None, None, None, f"Error during forecasting: {str(e)}"
134
 
135
+ # Sample CSV file generation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  def download_sample():
137
+ sample_data = """unique_id,ds,y
138
+ series1,2023-01-01,100
139
+ series1,2023-01-02,105
140
+ series1,2023-01-03,102
141
+ series1,2023-01-04,107
142
+ series1,2023-01-05,104
143
+ series1,2023-01-06,110
144
+ series1,2023-01-07,108
145
+ series1,2023-01-08,112
146
+ series1,2023-01-09,115
147
+ series1,2023-01-10,118
148
+ series1,2023-01-11,120
149
+ series1,2023-01-12,123
150
+ series1,2023-01-13,126
151
+ series1,2023-01-14,129
152
+ series1,2023-01-15,131
153
+ """
154
+ temp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv", mode='w', newline='')
155
+ temp.write(sample_data)
156
+ temp.close()
157
+ return temp.name
158
+
159
+ # Gradio interface
160
  with gr.Blocks(title="StatsForecast Demo") as app:
161
  gr.Markdown("# 📈 StatsForecast Demo App")
162
+ gr.Markdown("Upload a CSV with `unique_id`, `ds`, and `y` columns to apply forecasting models.")
163
 
164
  with gr.Row():
165
  with gr.Column(scale=2):
166
  file_input = gr.File(label="Upload CSV file", file_types=[".csv"])
167
+
168
  download_btn = gr.Button("Download Sample Data")
169
+ download_output = gr.File(label="Click to download", visible=True)
170
  download_btn.click(fn=download_sample, outputs=download_output)
171
 
172
+ frequency = gr.Dropdown(choices=["H", "D", "WS", "MS", "QS", "YS"], label="Frequency", value="D")
173
+ eval_strategy = gr.Radio(choices=["Fixed Window", "Cross Validation"], label="Evaluation Strategy", value="Cross Validation")
 
 
 
 
 
 
 
 
174
  horizon = gr.Slider(1, 100, value=14, label="Horizon")
175
  step_size = gr.Slider(1, 50, value=5, label="Step Size")
176
  num_windows = gr.Slider(1, 20, value=3, label="Number of Windows")
 
207
  )
208
 
209
  if __name__ == "__main__":
210
+ app.launch(share=True)