Spaces:
Running
Running
Trying to plot each cross validation window separately
Browse files
app.py
CHANGED
@@ -33,10 +33,18 @@ def load_data(file):
|
|
33 |
except Exception as e:
|
34 |
return None, f"Error loading data: {str(e)}"
|
35 |
|
|
|
|
|
|
|
|
|
36 |
# Function to generate and return a plot
|
37 |
-
|
|
|
38 |
plt.figure(figsize=(10, 6))
|
39 |
unique_ids = forecast_df['unique_id'].unique()
|
|
|
|
|
|
|
40 |
forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds', 'cutoff']]
|
41 |
|
42 |
for unique_id in unique_ids:
|
@@ -113,8 +121,10 @@ def run_forecast(
|
|
113 |
cv_results = sf.cross_validation(df=df, h=horizon, step_size=step_size, n_windows=num_windows)
|
114 |
evaluation = evaluate(df=cv_results, metrics=[bias, mae, rmse, mape], models=model_aliases)
|
115 |
eval_df = pd.DataFrame(evaluation).reset_index()
|
116 |
-
|
117 |
-
|
|
|
|
|
118 |
|
119 |
else: # Fixed window
|
120 |
train_size = len(df) - horizon
|
@@ -128,11 +138,20 @@ def run_forecast(
|
|
128 |
evaluation = evaluate(df=forecast, metrics=[bias, mae, rmse, mape], models=model_aliases)
|
129 |
eval_df = pd.DataFrame(evaluation).reset_index()
|
130 |
fig_forecast = create_forecast_plot(forecast, df)
|
131 |
-
return eval_df, forecast, fig_forecast, "Fixed window evaluation completed successfully!"
|
132 |
|
133 |
except Exception as e:
|
134 |
return None, None, None, f"Error during forecasting: {str(e)}"
|
135 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
# Sample CSV file generation
|
137 |
def download_sample():
|
138 |
sample_data = """unique_id,ds,y
|
@@ -192,11 +211,15 @@ with gr.Blocks(title="StatsForecast Demo") as app:
|
|
192 |
submit_btn = gr.Button("Run Forecast")
|
193 |
|
194 |
with gr.Column(scale=3):
|
|
|
195 |
eval_output = gr.Dataframe(label="Evaluation Results")
|
196 |
forecast_output = gr.Dataframe(label="Forecast Data")
|
197 |
plot_output = gr.Plot(label="Forecast Plot")
|
198 |
message_output = gr.Textbox(label="Message")
|
199 |
|
|
|
|
|
|
|
200 |
submit_btn.click(
|
201 |
fn=run_forecast,
|
202 |
inputs=[
|
@@ -205,8 +228,11 @@ with gr.Blocks(title="StatsForecast Demo") as app:
|
|
205 |
use_window_avg, window_size, use_seasonal_window_avg, seasonal_window_size,
|
206 |
use_autoets, use_autoarima
|
207 |
],
|
208 |
-
outputs=[eval_output, forecast_output, plot_output, message_output]
|
209 |
)
|
210 |
|
211 |
if __name__ == "__main__":
|
212 |
app.launch(share=False)
|
|
|
|
|
|
|
|
33 |
except Exception as e:
|
34 |
return None, f"Error loading data: {str(e)}"
|
35 |
|
36 |
+
|
37 |
+
# Global store to hold cross-validation forecasts
|
38 |
+
forecast_store = {}
|
39 |
+
|
40 |
# Function to generate and return a plot
|
41 |
+
|
42 |
+
def create_forecast_plot(forecast_df, original_df, window=None):
|
43 |
plt.figure(figsize=(10, 6))
|
44 |
unique_ids = forecast_df['unique_id'].unique()
|
45 |
+
if window is not None and 'cutoff' in forecast_df.columns:
|
46 |
+
forecast_df = forecast_df[forecast_df['cutoff'] == window]
|
47 |
+
|
48 |
forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds', 'cutoff']]
|
49 |
|
50 |
for unique_id in unique_ids:
|
|
|
121 |
cv_results = sf.cross_validation(df=df, h=horizon, step_size=step_size, n_windows=num_windows)
|
122 |
evaluation = evaluate(df=cv_results, metrics=[bias, mae, rmse, mape], models=model_aliases)
|
123 |
eval_df = pd.DataFrame(evaluation).reset_index()
|
124 |
+
forecast_store['cv'] = {'forecast': cv_results, 'original': df}
|
125 |
+
unique_cutoffs = sorted(cv_results['cutoff'].unique())
|
126 |
+
fig_forecast = create_forecast_plot(cv_results, df, window=unique_cutoffs[0])
|
127 |
+
return eval_df, cv_results, fig_forecast, "Cross validation completed successfully!", unique_cutoffs
|
128 |
|
129 |
else: # Fixed window
|
130 |
train_size = len(df) - horizon
|
|
|
138 |
evaluation = evaluate(df=forecast, metrics=[bias, mae, rmse, mape], models=model_aliases)
|
139 |
eval_df = pd.DataFrame(evaluation).reset_index()
|
140 |
fig_forecast = create_forecast_plot(forecast, df)
|
141 |
+
return eval_df, forecast, fig_forecast, "Fixed window evaluation completed successfully!", []
|
142 |
|
143 |
except Exception as e:
|
144 |
return None, None, None, f"Error during forecasting: {str(e)}"
|
145 |
|
146 |
+
|
147 |
+
# Function to update forecast plot for selected CV window
|
148 |
+
def update_forecast_plot(selected_window):
|
149 |
+
data = forecast_store.get('cv')
|
150 |
+
if not data:
|
151 |
+
return None
|
152 |
+
return create_forecast_plot(data['forecast'], data['original'], window=selected_window)
|
153 |
+
|
154 |
+
|
155 |
# Sample CSV file generation
|
156 |
def download_sample():
|
157 |
sample_data = """unique_id,ds,y
|
|
|
211 |
submit_btn = gr.Button("Run Forecast")
|
212 |
|
213 |
with gr.Column(scale=3):
|
214 |
+
window_selector = gr.Dropdown(label='Select CV Window', choices=[], visible=False)
|
215 |
eval_output = gr.Dataframe(label="Evaluation Results")
|
216 |
forecast_output = gr.Dataframe(label="Forecast Data")
|
217 |
plot_output = gr.Plot(label="Forecast Plot")
|
218 |
message_output = gr.Textbox(label="Message")
|
219 |
|
220 |
+
def handle_forecast_output(eval_df, forecast_df, plot, msg, windows):
|
221 |
+
return eval_df, forecast_df, plot, msg, gr.update(choices=[str(w) for w in windows], visible=bool(windows), value=str(windows[0]) if windows else None)
|
222 |
+
|
223 |
submit_btn.click(
|
224 |
fn=run_forecast,
|
225 |
inputs=[
|
|
|
228 |
use_window_avg, window_size, use_seasonal_window_avg, seasonal_window_size,
|
229 |
use_autoets, use_autoarima
|
230 |
],
|
231 |
+
outputs=[eval_output, forecast_output, plot_output, message_output, window_selector]
|
232 |
)
|
233 |
|
234 |
if __name__ == "__main__":
|
235 |
app.launch(share=False)
|
236 |
+
|
237 |
+
# Update plot when a window is selected
|
238 |
+
window_selector.change(fn=update_forecast_plot, inputs=window_selector, outputs=plot_output)
|