fmegahed commited on
Commit
9455ec6
·
verified ·
1 Parent(s): 5fe9e2e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +200 -0
app.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import gradio as gr
5
+
6
+ from statsforecast import StatsForecast
7
+ from statsforecast.models import (
8
+ HistoricalAverage,
9
+ Naive,
10
+ SeasonalNaive,
11
+ WindowAverage,
12
+ SeasonalWindowAverage,
13
+ AutoETS,
14
+ AutoARIMA
15
+ )
16
+
17
+ from utilsforecast.evaluation import evaluate
18
+ import tempfile
19
+
20
+ # Function to load and process the CSV file
21
+ def load_data(file):
22
+ if file is None:
23
+ return None, "Please upload a CSV file"
24
+
25
+ try:
26
+ # Safe read using file-like object
27
+ df = pd.read_csv(file)
28
+
29
+ # Check for required columns
30
+ required_cols = ['unique_id', 'ds', 'y']
31
+ missing_cols = [col for col in required_cols if col not in df.columns]
32
+
33
+ if missing_cols:
34
+ return None, f"Missing required columns: {', '.join(missing_cols)}"
35
+
36
+ # Convert 'ds' to datetime
37
+ df['ds'] = pd.to_datetime(df['ds'])
38
+
39
+ # Sort by date
40
+ df = df.sort_values(['unique_id', 'ds'])
41
+
42
+ return df, "Data loaded successfully!"
43
+
44
+ except Exception as e:
45
+ return None, f"Error loading data: {str(e)}"
46
+
47
+ # Forecasting logic
48
+ def run_forecast(
49
+ file,
50
+ frequency,
51
+ eval_strategy,
52
+ horizon,
53
+ step_size,
54
+ num_windows,
55
+ use_historical_avg,
56
+ use_naive,
57
+ use_seasonal_naive,
58
+ seasonality,
59
+ use_window_avg,
60
+ window_size,
61
+ use_seasonal_window_avg,
62
+ seasonal_window_size,
63
+ use_autoets,
64
+ use_autoarima
65
+ ):
66
+ df, message = load_data(file)
67
+ if df is None:
68
+ return None, None, None, message
69
+
70
+ models = []
71
+
72
+ if use_historical_avg:
73
+ models.append(HistoricalAverage(alias='historical_average'))
74
+ if use_naive:
75
+ models.append(Naive(alias='naive'))
76
+ if use_seasonal_naive:
77
+ models.append(SeasonalNaive(m=seasonality, alias='seasonal_naive'))
78
+ if use_window_avg:
79
+ models.append(WindowAverage(window_size=window_size, alias='window_average'))
80
+ if use_seasonal_window_avg:
81
+ models.append(SeasonalWindowAverage(m=seasonality, window_size=seasonal_window_size, alias='seasonal_window_average'))
82
+ if use_autoets:
83
+ models.append(AutoETS(alias='autoets'))
84
+ if use_autoarima:
85
+ models.append(AutoARIMA(alias='autoarima'))
86
+
87
+ if not models:
88
+ return None, None, None, "Please select at least one forecasting model"
89
+
90
+ sf = StatsForecast(models=models, freq=frequency, n_jobs=-1)
91
+
92
+ try:
93
+ if eval_strategy == "Cross Validation":
94
+ cv_results = sf.cross_validation(df=df, h=horizon, step_size=step_size, n_windows=num_windows)
95
+ evaluation = evaluate(cv_results, df, metrics=['me', 'mae', 'rmse', 'mape'])
96
+ eval_df = pd.DataFrame(evaluation).reset_index()
97
+ fig_forecast = create_forecast_plot(cv_results, df)
98
+ return eval_df, cv_results, fig_forecast, "Cross validation completed successfully!"
99
+ else:
100
+ train_size = len(df) - horizon
101
+ if train_size <= 0:
102
+ return None, None, None, f"Not enough data for horizon={horizon}"
103
+
104
+ train_df = df.iloc[:train_size]
105
+ test_df = df.iloc[train_size:]
106
+ sf.fit(train_df)
107
+ forecast = sf.predict(h=horizon)
108
+ evaluation = evaluate(forecast, test_df, metrics=['me', 'mae', 'rmse', 'mape'])
109
+ eval_df = pd.DataFrame(evaluation).reset_index()
110
+ fig_forecast = create_forecast_plot(forecast, df)
111
+ return eval_df, forecast, fig_forecast, "Fixed window evaluation completed successfully!"
112
+
113
+ except Exception as e:
114
+ return None, None, None, f"Error during forecasting: {str(e)}"
115
+
116
+ # Forecast plot
117
+ def create_forecast_plot(forecast_df, original_df):
118
+ plt.figure(figsize=(10, 6))
119
+ unique_ids = forecast_df['unique_id'].unique()
120
+ forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds']]
121
+
122
+ for unique_id in unique_ids:
123
+ original_data = original_df[original_df['unique_id'] == unique_id]
124
+ plt.plot(original_data['ds'], original_data['y'], 'k-', label='Actual')
125
+ forecast_data = forecast_df[forecast_df['unique_id'] == unique_id]
126
+ for col in forecast_cols:
127
+ if col in forecast_data.columns:
128
+ plt.plot(forecast_data['ds'], forecast_data[col], label=col)
129
+
130
+ plt.title('Forecasting Results')
131
+ plt.xlabel('Date')
132
+ plt.ylabel('Value')
133
+ plt.legend()
134
+ plt.grid(True)
135
+ fig = plt.gcf()
136
+ return fig
137
+
138
+ # Download sample file (placeholder path)
139
+ def download_sample():
140
+ return "sample_data.csv"
141
+
142
+ # Gradio UI
143
+ with gr.Blocks(title="StatsForecast Demo") as app:
144
+ gr.Markdown("# 📈 StatsForecast Demo App")
145
+ gr.Markdown("Upload a CSV with `unique_id`, `ds`, `y` columns and configure forecasting models.")
146
+
147
+ with gr.Row():
148
+ with gr.Column(scale=2):
149
+ file_input = gr.File(label="Upload CSV file", file_types=[".csv"])
150
+ download_btn = gr.Button("Download Sample Data")
151
+ download_output = gr.File(interactive=False, label="Sample Data", visible=False)
152
+ download_btn.click(fn=download_sample, outputs=download_output)
153
+
154
+ frequency = gr.Dropdown(
155
+ choices=["H", "D", "WS", "MS", "QS", "YS"],
156
+ label="Frequency",
157
+ value="D"
158
+ )
159
+ eval_strategy = gr.Radio(
160
+ choices=["Fixed Window", "Cross Validation"],
161
+ label="Evaluation Strategy",
162
+ value="Cross Validation"
163
+ )
164
+ horizon = gr.Slider(1, 100, value=14, label="Horizon")
165
+ step_size = gr.Slider(1, 50, value=5, label="Step Size")
166
+ num_windows = gr.Slider(1, 20, value=3, label="Number of Windows")
167
+
168
+ gr.Markdown("### Model Configuration")
169
+ use_historical_avg = gr.Checkbox(label="Use Historical Average", value=True)
170
+ use_naive = gr.Checkbox(label="Use Naive", value=True)
171
+ use_seasonal_naive = gr.Checkbox(label="Use Seasonal Naive")
172
+ seasonality = gr.Number(label="Seasonality", value=7)
173
+ use_window_avg = gr.Checkbox(label="Use Window Average")
174
+ window_size = gr.Number(label="Window Size", value=3)
175
+ use_seasonal_window_avg = gr.Checkbox(label="Use Seasonal Window Average")
176
+ seasonal_window_size = gr.Number(label="Seasonal Window Size", value=2)
177
+ use_autoets = gr.Checkbox(label="Use AutoETS")
178
+ use_autoarima = gr.Checkbox(label="Use AutoARIMA")
179
+
180
+ submit_btn = gr.Button("Run Forecast")
181
+
182
+ with gr.Column(scale=3):
183
+ eval_output = gr.Dataframe(label="Evaluation Results")
184
+ forecast_output = gr.Dataframe(label="Forecast Data")
185
+ plot_output = gr.Plot(label="Forecast Plot")
186
+ message_output = gr.Textbox(label="Message")
187
+
188
+ submit_btn.click(
189
+ fn=run_forecast,
190
+ inputs=[
191
+ file_input, frequency, eval_strategy, horizon, step_size, num_windows,
192
+ use_historical_avg, use_naive, use_seasonal_naive, seasonality,
193
+ use_window_avg, window_size, use_seasonal_window_avg, seasonal_window_size,
194
+ use_autoets, use_autoarima
195
+ ],
196
+ outputs=[eval_output, forecast_output, plot_output, message_output]
197
+ )
198
+
199
+ if __name__ == "__main__":
200
+ app.launch()