samkeet committed on
Commit 00b00eb · verified · 1 Parent(s): 6ce02b0

Upload 40 files

.gitattributes CHANGED
@@ -1,35 +1,36 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ db/imp_db.db filter=lfs diff=lfs merge=lfs -text
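The only net change to these rules is the final line, which puts the bundled SQLite database (db/imp_db.db) under Git LFS. A small sketch, assuming a fresh clone of this Space, to confirm the file was materialized as real content rather than left as an LFS pointer (the pointer-header check is my own assumption, not part of the commit):

```python
# Sketch: verify db/imp_db.db is real SQLite content, not an LFS pointer file.
from pathlib import Path

head = Path("db/imp_db.db").read_bytes()[:64]
if head.startswith(b"version https://git-lfs"):
    print("Still an LFS pointer - run `git lfs pull` first.")
elif head.startswith(b"SQLite format 3"):
    print("SQLite database is present.")
else:
    print("Unexpected file header:", head[:16])
```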
.streamlit/config.toml ADDED
@@ -0,0 +1,7 @@
+
+ [theme]
+ primaryColor="#f6ad51"
+ backgroundColor="#FFFFFF"
+ secondaryBackgroundColor="#F0F2F6"
+ textColor="#31333F"
+ font="sans serif"
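These are standard Streamlit theme keys, so the Space picks them up automatically. If the custom styles.css loaded in Home.py needs the same colors, one option is to read them back from this file at runtime; a minimal sketch, assuming Python 3.11+ for the stdlib tomllib module:

```python
# Sketch: reuse the theme colors defined in .streamlit/config.toml.
import tomllib

with open(".streamlit/config.toml", "rb") as f:
    theme = tomllib.load(f)["theme"]

print(theme["primaryColor"])     # "#f6ad51"
print(theme["backgroundColor"])  # "#FFFFFF"
```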
Home.py ADDED
@@ -0,0 +1,1614 @@
1
+ # importing packages
2
+ import streamlit as st
3
+
4
+ st.set_page_config(layout="wide")
5
+ from utilities import (
6
+ load_local_css,
7
+ set_header,
8
+ ensure_project_dct_structure,
9
+ store_hashed_password,
10
+ verify_password,
11
+ is_pswrd_flag_set,
12
+ set_pswrd_flag,
13
+ )
14
+ import os
15
+ from datetime import datetime
16
+ import pandas as pd
17
+ import pickle
18
+ import psycopg2
19
+
20
+ #
21
+ import numbers
22
+ from collections import OrderedDict
23
+ import re
24
+ from ppt_utils import create_ppt
25
+ from constants import default_dct
26
+ import time
27
+ from log_application import log_message, delete_old_log_files
28
+ import sqlite3
29
+
30
+ # setting page config
31
+ load_local_css("styles.css")
32
+ set_header()
33
+ db_cred = None
34
+
35
+ # --------------Functions----------------------#
36
+
37
+ # # schema = db_cred["schema"]
38
+
39
+
40
+ ##API DATA#######################
41
+ # Function to load gold layer data
42
+ # @st.cache_data(show_spinner=False)
43
+ def load_gold_layer_data(table_name):
44
+ # Fetch Table
45
+ query = f"""
46
+ SELECT * FROM {table_name};
47
+ """
48
+
49
+ # Execute the query and get the results
50
+ results = query_excecuter_postgres(
51
+ query, db_cred, insert=False, return_dataframe=True
52
+ )
53
+
54
+ if results is not None and not results.empty:
55
+ # Create a DataFrame
56
+ gold_layer_df = results
57
+
58
+ else:
59
+ st.warning("No data found for the selected table.")
60
+ st.stop()
61
+
62
+ # Columns to be removed
63
+ columns_to_remove = [
64
+ "clnt_nam",
65
+ "crte_dt_tm",
66
+ "crte_by_uid",
67
+ "updt_dt_tm",
68
+ "updt_by_uid",
69
+ "campgn_id",
70
+ "campgn_nam",
71
+ "ad_id",
72
+ "ad_nam",
73
+ "tctc_id",
74
+ "tctc_nam",
75
+ "campgn_grp_id",
76
+ "campgn_grp_nam",
77
+ "ad_grp_id",
78
+ "ad_grp_nam",
79
+ ]
80
+
81
+ # TEMP CODE
82
+ gold_layer_df = gold_layer_df.rename(
83
+ columns={
84
+ "imprssns_cnt": "mda_imprssns_cnt",
85
+ "clcks_cnt": "mda_clcks_cnt",
86
+ "vd_vws_cnt": "mda_vd_vws_cnt",
87
+ }
88
+ )
89
+
90
+ # Remove specific columns
91
+ gold_layer_df = gold_layer_df.drop(columns=columns_to_remove, errors="ignore")
92
+
93
+ # Convert columns to numeric or datetime as appropriate
94
+ for col in gold_layer_df.columns:
95
+ if (
96
+ col.startswith("rspns_mtrc_")
97
+ or col.startswith("mda_")
98
+ or col.startswith("exogenous_")
99
+ or col.startswith("internal_")
100
+ or col in ["spnd_amt"]
101
+ ):
102
+ gold_layer_df[col] = pd.to_numeric(gold_layer_df[col], errors="coerce")
103
+ elif col == "rcrd_dt":
104
+ gold_layer_df[col] = pd.to_datetime(gold_layer_df[col], errors="coerce")
105
+
106
+ # Replace columns starting with 'mda_' to 'media_'
107
+ gold_layer_df.columns = [
108
+ (col.replace("mda_", "media_") if col.startswith("mda_") else col)
109
+ for col in gold_layer_df.columns
110
+ ]
111
+
112
+ # Identify non-numeric columns
113
+ non_numeric_columns = gold_layer_df.select_dtypes(exclude=["number"]).columns
114
+ allow_non_numeric_columns = ["rcrd_dt", "aggrgtn_lvl", "sub_chnnl_nam", "panl_nam"]
115
+
116
+ # Remove non-numeric columns except for allowed non-numeric columns
117
+ non_numeric_columns_to_remove = [
118
+ col for col in non_numeric_columns if col not in allow_non_numeric_columns
119
+ ]
120
+ gold_layer_df = gold_layer_df.drop(
121
+ columns=non_numeric_columns_to_remove, errors="ignore"
122
+ )
123
+
124
+ # Remove specific columns
125
+ allow_columns = ["rcrd_dt", "aggrgtn_lvl", "sub_chnnl_nam", "panl_nam", "spnd_amt"]
126
+ for col in gold_layer_df.columns:
127
+ if (
128
+ col.startswith("rspns_mtrc_")
129
+ or col.startswith("media_")
130
+ or col.startswith("exogenous_")
131
+ or col.startswith("internal_")
132
+ ):
133
+ allow_columns.append(col)
134
+ gold_layer_df = gold_layer_df[allow_columns]
135
+
136
+ # Rename columns
137
+ gold_layer_df = gold_layer_df.rename(
138
+ columns={
139
+ "rcrd_dt": "date",
140
+ "sub_chnnl_nam": "channels",
141
+ "panl_nam": "panel",
142
+ "spnd_amt": "spends",
143
+ }
144
+ )
145
+
146
+ # Clean column values
147
+ gold_layer_df["panel"] = (
148
+ gold_layer_df["panel"].astype(str).str.lower().str.strip().str.replace(" ", "_")
149
+ )
150
+ gold_layer_df["channels"] = (
151
+ gold_layer_df["channels"]
152
+ .astype(str)
153
+ .str.lower()
154
+ .str.strip()
155
+ .str.replace(" ", "_")
156
+ )
157
+
158
+ # Replace columns starting with 'rspns_mtrc_' to 'response_metric_'
159
+ gold_layer_df.columns = [
160
+ (
161
+ col.replace("rspns_mtrc_", "response_metric_")
162
+ if col.startswith("rspns_mtrc_")
163
+ else col
164
+ )
165
+ for col in gold_layer_df.columns
166
+ ]
167
+
168
+ # Get the minimum date from the main dataframe
169
+ min_date = gold_layer_df["date"].min()
170
+
171
+ # Get maximum dates for daily and weekly data
172
+ max_date_daily = None
173
+ max_date_weekly = None
174
+
175
+ if not gold_layer_df[gold_layer_df["aggrgtn_lvl"] == "daily"].empty:
176
+ max_date_daily = gold_layer_df[gold_layer_df["aggrgtn_lvl"] == "daily"][
177
+ "date"
178
+ ].max()
179
+
180
+ if not gold_layer_df[gold_layer_df["aggrgtn_lvl"] == "weekly"].empty:
181
+ max_date_weekly = gold_layer_df[gold_layer_df["aggrgtn_lvl"] == "weekly"][
182
+ "date"
183
+ ].max() + pd.DateOffset(days=6)
184
+
185
+ # Determine final maximum date
186
+ if max_date_daily is not None and max_date_weekly is not None:
187
+ final_max_date = max(max_date_daily, max_date_weekly)
188
+ elif max_date_daily is not None:
189
+ final_max_date = max_date_daily
190
+ elif max_date_weekly is not None:
191
+ final_max_date = max_date_weekly
192
+
193
+ # Create a date range with daily frequency
194
+ date_range = pd.date_range(start=min_date, end=final_max_date, freq="D")
195
+
196
+ # Create a base DataFrame with all channels and all panels for each channel
197
+ unique_channels = gold_layer_df["channels"].unique()
198
+ unique_panels = gold_layer_df["panel"].unique()
199
+ base_data = [
200
+ (channel, panel, date)
201
+ for channel in unique_channels
202
+ for panel in unique_panels
203
+ for date in date_range
204
+ ]
205
+ base_df = pd.DataFrame(base_data, columns=["channels", "panel", "date"])
206
+
207
+ # Process weekly data to convert it to daily
208
+ if not gold_layer_df[gold_layer_df["aggrgtn_lvl"] == "weekly"].empty:
209
+ weekly_data = gold_layer_df[gold_layer_df["aggrgtn_lvl"] == "weekly"].copy()
210
+ daily_data = []
211
+
212
+ for index, row in weekly_data.iterrows():
213
+ week_start = pd.to_datetime(row["date"]) - pd.to_timedelta(
214
+ pd.to_datetime(row["date"]).weekday(), unit="D"
215
+ )
216
+ for i in range(7):
217
+ daily_date = week_start + pd.DateOffset(days=i)
218
+ new_row = row.copy()
219
+ new_row["date"] = daily_date
220
+ for col in new_row.index:
221
+ if isinstance(new_row[col], numbers.Number):
222
+ new_row[col] = new_row[col] / 7
223
+ daily_data.append(new_row)
224
+
225
+ daily_data_df = pd.DataFrame(daily_data)
226
+ daily_data_df["aggrgtn_lvl"] = "daily"
227
+ gold_layer_df = pd.concat(
228
+ [gold_layer_df[gold_layer_df["aggrgtn_lvl"] != "weekly"], daily_data_df],
229
+ ignore_index=True,
230
+ )
231
+
232
+ # Process monthly data to convert it to daily
233
+ if not gold_layer_df[gold_layer_df["aggrgtn_lvl"] == "monthly"].empty:
234
+ monthly_data = gold_layer_df[gold_layer_df["aggrgtn_lvl"] == "monthly"].copy()
235
+ daily_data = []
236
+
237
+ for index, row in monthly_data.iterrows():
238
+ month_start = pd.to_datetime(row["date"]).replace(day=1)
239
+ next_month_start = (month_start + pd.DateOffset(months=1)).replace(day=1)
240
+ days_in_month = (next_month_start - month_start).days
241
+
242
+ for i in range(days_in_month):
243
+ daily_date = month_start + pd.DateOffset(days=i)
244
+ new_row = row.copy()
245
+ new_row["date"] = daily_date
246
+ for col in new_row.index:
247
+ if isinstance(new_row[col], numbers.Number):
248
+ new_row[col] = new_row[col] / days_in_month
249
+ daily_data.append(new_row)
250
+
251
+ daily_data_df = pd.DataFrame(daily_data)
252
+ daily_data_df["aggrgtn_lvl"] = "daily"
253
+ gold_layer_df = pd.concat(
254
+ [gold_layer_df[gold_layer_df["aggrgtn_lvl"] != "monthly"], daily_data_df],
255
+ ignore_index=True,
256
+ )
257
+
258
+ # Remove aggrgtn_lvl column
259
+ gold_layer_df = gold_layer_df.drop(columns=["aggrgtn_lvl"], errors="ignore")
260
+
261
+ # Group by 'panel', and 'date'
262
+ gold_layer_df = gold_layer_df.groupby(["channels", "panel", "date"]).sum()
263
+
264
+ # Merge gold_layer_df to base_df on channels, panel and date
265
+ gold_layer_df_cleaned = pd.merge(
266
+ base_df, gold_layer_df, on=["channels", "panel", "date"], how="left"
267
+ )
268
+
269
+ # Pivot the dataframe and rename columns
270
+ pivot_columns = [
271
+ col
272
+ for col in gold_layer_df_cleaned.columns
273
+ if col not in ["channels", "panel", "date"]
274
+ ]
275
+ gold_layer_df_cleaned = gold_layer_df_cleaned.pivot_table(
276
+ index=["date", "panel"], columns="channels", values=pivot_columns, aggfunc="sum"
277
+ ).reset_index()
278
+
279
+ # Flatten the columns
280
+ gold_layer_df_cleaned.columns = [
281
+ "_".join(col).strip() if col[1] else col[0]
282
+ for col in gold_layer_df_cleaned.columns.values
283
+ ]
284
+
285
+ # Replace columns ending with '_all' to '_total'
286
+ gold_layer_df_cleaned.columns = [
287
+ col.replace("_all", "_total") if col.endswith("_all") else col
288
+ for col in gold_layer_df_cleaned.columns
289
+ ]
290
+
291
+ # Clean panel column values
292
+ gold_layer_df_cleaned["panel"] = (
293
+ gold_layer_df_cleaned["panel"]
294
+ .astype(str)
295
+ .str.lower()
296
+ .str.strip()
297
+ .str.replace(" ", "_")
298
+ )
299
+
300
+ # Drop all columns that end with '_total' except those starting with 'response_metric_'
301
+ cols_to_drop = [
302
+ col
303
+ for col in gold_layer_df_cleaned.columns
304
+ if col.endswith("_total") and not col.startswith("response_metric_")
305
+ ]
306
+ gold_layer_df_cleaned.drop(columns=cols_to_drop, inplace=True)
307
+
308
+ return gold_layer_df_cleaned
309
+
310
+
311
+ def check_valid_name():
312
+ if (
313
+ not st.session_state["project_name_box"]
314
+ .lower()
315
+ .startswith(defualt_project_prefix)
316
+ ):
317
+ st.session_state["disable_create_project"] = True
318
+ with warning_box:
319
+ st.warning("Project Name should follow naming conventions")
320
+ st.session_state["warning"] = (
321
+ "Project Name should follow naming conventions!"
322
+ )
323
+
324
+ with warning_box:
325
+ st.warning("Project Name should follow naming conventions")
326
+ st.button("Reset Name", on_click=reset_project_text_box, key="2")
327
+
328
+ if st.session_state["project_name_box"] == defualt_project_prefix:
329
+ with warning_box:
330
+ st.warning("Cannot Name only with Prefix")
331
+ st.session_state["warning"] = "Cannot Name only with Prefix"
332
+ st.session_state["disable_create_project"] = True
333
+
334
+ if st.session_state["project_name_box"] in user_projects:
335
+ with warning_box:
336
+ st.warning("Project already exists; please enter a new name")
337
+ st.session_state["warning"] = "Project already exists; please enter a new name"
338
+ st.session_state["disable_create_project"] = True
339
+ else:
340
+ st.session_state["disable_create_project"] = False
341
+
342
+
343
+ def query_excecuter_postgres(
344
+ query,
345
+ db_path=None,
346
+ params=None,
347
+ insert=True,
348
+ insert_retrieve=False,
349
+ db_cred=None,
350
+ ):
351
+ """
352
+ Executes a SQL query on a SQLite database, handling both insert and select operations.
353
+
354
+ Parameters:
355
+ query (str): The SQL query to be executed.
356
+ db_path (str): Path to the SQLite database file.
357
+ params (tuple, optional): Parameters to pass into the SQL query for parameterized execution.
358
+ insert (bool, default=True): Flag to determine if the query is an insert operation (default) or a select operation.
359
+ insert_retrieve (bool, default=False): Flag to determine if the query should insert and then return the inserted ID.
360
+
361
+ """
362
+ try:
363
+ # Construct a cross-platform path to the database
364
+ db_dir = os.path.join("db")
365
+ os.makedirs(db_dir, exist_ok=True) # Make sure the directory exists
366
+ db_path = os.path.join(db_dir, "imp_db.db")
367
+
368
+ # Establish connection to the SQLite database
369
+ conn = sqlite3.connect(db_path)
370
+ except sqlite3.Error as e:
371
+ st.warning(f"Unable to connect to the SQLite database: {e}")
372
+ st.stop()
373
+
374
+ # Create a cursor object to interact with the database
375
+ c = conn.cursor()
376
+
377
+ try:
378
+ # Execute the query with or without parameters
379
+ if params:
380
+ params = tuple(params)
381
+ query = query.replace("IN (?)", f"IN ({','.join(['?' for _ in params])})")
382
+ c.execute(query, params)
383
+ else:
384
+ c.execute(query)
385
+
386
+ if not insert:
387
+ # If not an insert operation, fetch and return the results
388
+ results = c.fetchall()
389
+ return results
390
+ elif insert_retrieve:
391
+ # If insert and retrieve operation, commit and return the last inserted row ID
392
+ conn.commit()
393
+ return c.lastrowid
394
+ else:
395
+ # For standard insert operations, commit the transaction
396
+ conn.commit()
397
+
398
+ except Exception as e:
399
+ st.write(f"Error executing query: {e}")
400
+ finally:
401
+ conn.close()
402
+
403
+
404
+ # Function to check if the input contains any SQL keywords
405
+ def contains_sql_keywords_check(user_input):
406
+
407
+ sql_keywords = [
408
+ "SELECT",
409
+ "INSERT",
410
+ "UPDATE",
411
+ "DELETE",
412
+ "DROP",
413
+ "ALTER",
414
+ "CREATE",
415
+ "GRANT",
416
+ "REVOKE",
417
+ "UNION",
418
+ "JOIN",
419
+ "WHERE",
420
+ "HAVING",
421
+ "EXEC",
422
+ "TRUNCATE",
423
+ "REPLACE",
424
+ "MERGE",
425
+ "DECLARE",
426
+ "SHOW",
427
+ "FROM",
428
+ ]
429
+
430
+ pattern = "|".join(re.escape(keyword) for keyword in sql_keywords)
431
+ return re.search(pattern, user_input, re.IGNORECASE)
432
+
433
+
434
+ # def get_table_names(schema):
435
+ # query = f"""
436
+ # SELECT table_name
437
+ # FROM information_schema.tables
438
+ # WHERE table_schema = '{schema}'
439
+ # AND table_type = 'BASE TABLE'
440
+ # AND table_name LIKE '%_mmo_gold';
441
+ # """
442
+ # table_names = query_excecuter_postgres(query, db_cred, insert=False)
443
+ # table_names = [table[0] for table in table_names]
444
+
445
+ # return table_names
446
+
447
+
448
+ def update_summary_df():
449
+ """
450
+ Updates the 'project_summary_df' in the session state with the latest project
451
+ summary information based on the most recent updates.
452
+
453
+ This function executes a SQL query to retrieve project metadata from a database
454
+ and stores the result in the session state.
455
+
456
+ Uses:
457
+ - query_excecuter_postgres(query, params=params, insert=False): A function that
458
+ executes the provided SQL query on a PostgreSQL database.
459
+
460
+ Modifies:
461
+ - st.session_state['project_summary_df']: Updates the dataframe with columns:
462
+ 'Project Number', 'Project Name', 'Last Modified Page', 'Last Modified Time'.
463
+ """
464
+
465
+ query = f"""
466
+ WITH LatestUpdates AS (
467
+ SELECT
468
+ prj_id,
469
+ page_nam,
470
+ updt_dt_tm,
471
+ ROW_NUMBER() OVER (PARTITION BY prj_id ORDER BY updt_dt_tm DESC) AS rn
472
+ FROM
473
+ mmo_project_meta_data
474
+ )
475
+ SELECT
476
+ p.prj_id,
477
+ p.prj_nam AS prj_nam,
478
+ lu.page_nam,
479
+ lu.updt_dt_tm
480
+ FROM
481
+ LatestUpdates lu
482
+ RIGHT JOIN
483
+ mmo_projects p ON lu.prj_id = p.prj_id
484
+ WHERE
485
+ p.prj_ownr_id = ? AND lu.rn = 1
486
+ """
487
+
488
+ params = (st.session_state["emp_id"],) # Parameters for the SQL query
489
+
490
+ # Execute the query and retrieve project summary data
491
+ project_summary = query_excecuter_postgres(
492
+ query, db_cred, params=params, insert=False
493
+ )
494
+
495
+ # Update the session state with the project summary dataframe
496
+ st.session_state["project_summary_df"] = pd.DataFrame(
497
+ project_summary,
498
+ columns=[
499
+ "Project Number",
500
+ "Project Name",
501
+ "Last Modified Page",
502
+ "Last Modified Time",
503
+ ],
504
+ )
505
+
506
+ st.session_state["project_summary_df"] = st.session_state[
507
+ "project_summary_df"
508
+ ].sort_values(by=["Last Modified Time"], ascending=False)
509
+
510
+
511
+ def reset_project_text_box():
512
+ st.session_state["project_name_box"] = defualt_project_prefix
513
+ st.session_state["disable_create_project"] = True
514
+
515
+
516
+ def query_excecuter_sqlite(
517
+ insert_projects_query,
518
+ insert_meta_data_query,
519
+ db_path=None,
520
+ params_projects=None,
521
+ params_meta=None,
522
+ ):
523
+ """
524
+ Executes the project insert and associated metadata insert in an SQLite database.
525
+
526
+ Parameters:
527
+ insert_projects_query (str): SQL query for inserting into the mmo_projects table.
528
+ insert_meta_data_query (str): SQL query for inserting into the mmo_project_meta_data table.
529
+ db_path (str): Path to the SQLite database file.
530
+ params_projects (tuple, optional): Parameters for the mmo_projects table insert.
531
+ params_meta (tuple, optional): Parameters for the mmo_project_meta_data table insert.
532
+
533
+ Returns:
534
+ bool: True if successful, False otherwise.
535
+ """
536
+ try:
537
+ # Construct a cross-platform path to the database
538
+ db_dir = os.path.join("db")
539
+ os.makedirs(db_dir, exist_ok=True) # Make sure the directory exists
540
+ db_path = os.path.join(db_dir, "imp_db.db")
541
+
542
+ # Establish connection to the SQLite database
543
+ conn = sqlite3.connect(db_path)
544
+ cursor = conn.cursor()
545
+
546
+ # Execute the first insert query into the mmo_projects table
547
+ cursor.execute(insert_projects_query, params_projects)
548
+
549
+ # Get the last inserted project ID
550
+ prj_id = cursor.lastrowid
551
+
552
+ # Modify the parameters for the metadata table with the inserted prj_id
553
+ params_meta = (prj_id,) + params_meta
554
+
555
+ # Execute the second insert query into the mmo_project_meta_data table
556
+ cursor.execute(insert_meta_data_query, params_meta)
557
+
558
+ # Commit the transaction
559
+ conn.commit()
560
+
561
+ except sqlite3.Error as e:
562
+ st.warning(f"Error executing query: {e}")
563
+ return False
564
+ finally:
565
+ # Close the connection
566
+ conn.close()
567
+
568
+ return True
569
+
570
+
571
+ def new_project():
572
+ """
573
+ Cleans the project name input and inserts project data into the SQLite database,
574
+ updating session state and triggering UI rerun if successful.
575
+ """
576
+
577
+ # Define a dictionary containing project data
578
+ project_dct = default_dct.copy()
579
+
580
+ gold_layer_df = pd.DataFrame()
581
+ if str(api_name).strip().lower() != "na":
582
+ try:
583
+ gold_layer_df = load_gold_layer_data(api_name)
584
+ except Exception as e:
585
+ st.toast(
586
+ "Failed to load gold layer data. Please check the gold layer structure and connection.",
587
+ icon="⚠️",
588
+ )
589
+ log_message(
590
+ "error",
591
+ f"Error loading gold layer data: {str(e)}",
592
+ "Home",
593
+ )
594
+
595
+ project_dct["data_import"]["gold_layer_df"] = gold_layer_df
596
+
597
+ # Get current time for database insertion
598
+ inserted_time = datetime.now().isoformat()
599
+
600
+ # Define SQL queries for inserting project and metadata into the SQLite database
601
+ insert_projects_query = """
602
+ INSERT INTO mmo_projects (prj_ownr_id, prj_nam, alwd_emp_id, crte_dt_tm, crte_by_uid)
603
+ VALUES (?, ?, ?, ?, ?);
604
+ """
605
+ insert_meta_data_query = """
606
+ INSERT INTO mmo_project_meta_data (prj_id, page_nam, file_nam, pkl_obj, crte_dt_tm, crte_by_uid, updt_dt_tm)
607
+ VALUES (?, ?, ?, ?, ?, ?, ?);
608
+ """
609
+
610
+ # Get current time for metadata update
611
+ updt_dt_tm = datetime.now().isoformat()
612
+
613
+ # Serialize project_dct using pickle
614
+ project_pkl = pickle.dumps(project_dct)
615
+
616
+ # Prepare data for database insertion
617
+ projects_data = (
618
+ st.session_state["emp_id"], # prj_ownr_id
619
+ project_name, # prj_nam
620
+ ",".join(matching_user_id), # alwd_emp_id
621
+ inserted_time, # crte_dt_tm
622
+ st.session_state["emp_id"], # crte_by_uid
623
+ )
624
+
625
+ project_meta_data = (
626
+ "Home", # page_nam
627
+ "project_dct", # file_nam
628
+ project_pkl, # pkl_obj
629
+ inserted_time, # crte_dt_tm
630
+ st.session_state["emp_id"], # crte_by_uid
631
+ updt_dt_tm, # updt_dt_tm
632
+ )
633
+
634
+ # Execute the insertion query for SQLite
635
+ success = query_excecuter_sqlite(
636
+ insert_projects_query,
637
+ insert_meta_data_query,
638
+ params_projects=projects_data,
639
+ params_meta=project_meta_data,
640
+ )
641
+
642
+ if success:
643
+ st.success("Project Created")
644
+ update_summary_df()
645
+ else:
646
+ st.error("Failed to create project.")
647
+
648
+
649
+ def validate_password(user_input):
650
+ # List of SQL keywords to check for
651
+ sql_keywords = [
652
+ "SELECT",
653
+ "INSERT",
654
+ "UPDATE",
655
+ "DELETE",
656
+ "DROP",
657
+ "ALTER",
658
+ "CREATE",
659
+ "GRANT",
660
+ "REVOKE",
661
+ "UNION",
662
+ "JOIN",
663
+ "WHERE",
664
+ "HAVING",
665
+ "EXEC",
666
+ "TRUNCATE",
667
+ "REPLACE",
668
+ "MERGE",
669
+ "DECLARE",
670
+ "SHOW",
671
+ "FROM",
672
+ ]
673
+
674
+ # Create a regex pattern for SQL keywords
675
+ pattern = "|".join(re.escape(keyword) for keyword in sql_keywords)
676
+
677
+ # Check if input contains any SQL keywords
678
+ if re.search(pattern, user_input, re.IGNORECASE):
679
+ return "SQL keyword detected."
680
+
681
+ # Password validation criteria
682
+ if len(user_input) < 8:
683
+ return "Password should be at least 8 characters long."
684
+ if not re.search(r"[A-Z]", user_input):
685
+ return "Password should contain at least one uppercase letter."
686
+ if not re.search(r"[0-9]", user_input):
687
+ return "Password should contain at least one digit."
688
+ if not re.search(r"[a-z]", user_input):
689
+ return "Password should contain at least one lowercase letter."
690
+ if not re.search(r'[!@#$%^&*(),.?":{}|<>]', user_input):
691
+ return "Password should contain at least one special character."
692
+
693
+ # If all checks pass
694
+ return "Valid input."
695
+
696
+
697
+ def fetch_and_process_projects(emp_id):
698
+ query = f"""
699
+ WITH ProjectAccess AS (
700
+ SELECT
701
+ p.prj_id,
702
+ p.prj_nam,
703
+ p.alwd_emp_id,
704
+ u.emp_nam AS project_owner
705
+ FROM mmo_projects p
706
+ JOIN mmo_users u ON p.prj_ownr_id = u.emp_id
707
+ )
708
+ SELECT
709
+ pa.prj_id,
710
+ pa.prj_nam,
711
+ pa.project_owner
712
+ FROM
713
+ ProjectAccess pa
714
+ WHERE
715
+ pa.alwd_emp_id LIKE ?
716
+ ORDER BY
717
+ pa.prj_id;
718
+ """
719
+
720
+ params = (f"%{emp_id}%",)
721
+ results = query_excecuter_postgres(query, db_cred, params=params, insert=False)
722
+
723
+ # Process the results to create the desired dictionary structure
724
+ clone_project_dict = {}
725
+ for row in results:
726
+ project_id, project_name, project_owner = row
727
+
728
+ if project_owner not in clone_project_dict:
729
+ clone_project_dict[project_owner] = []
730
+
731
+ clone_project_dict[project_owner].append(
732
+ {"project_name": project_name, "project_id": project_id}
733
+ )
734
+
735
+ return clone_project_dict
736
+
737
+
738
+ def get_project_id_from_dict(projects_dict, owner_name, project_name):
739
+ if owner_name in projects_dict:
740
+ for project in projects_dict[owner_name]:
741
+ if project["project_name"] == project_name:
742
+ return project["project_id"]
743
+ return None
744
+
745
+
746
+ # def fetch_project_metadata(prj_id):
747
+ # query = f"""
748
+ # SELECT
749
+ # prj_id, page_nam, file_nam, pkl_obj, dshbrd_ts
750
+ # FROM
751
+ # mmo_project_meta_data
752
+ # WHERE
753
+ # prj_id = ?;
754
+ # """
755
+
756
+ # params = (prj_id,)
757
+ # return query_excecuter_postgres(query, db_cred, params=params, insert=False)
758
+
759
+
760
+ def fetch_project_metadata(prj_id):
761
+ # Query to select project metadata
762
+ query = """
763
+ SELECT
764
+ prj_id, page_nam, file_nam, pkl_obj, dshbrd_ts
765
+ FROM
766
+ mmo_project_meta_data
767
+ WHERE
768
+ prj_id = ?;
769
+ """
770
+
771
+ params = (prj_id,)
772
+ return query_excecuter_postgres(query, db_cred, params=params, insert=False)
773
+
774
+
775
+ # def create_new_project(prj_ownr_id, prj_nam, alwd_emp_id, emp_id):
776
+ # query = f"""
777
+ # INSERT INTO mmo_projects (prj_ownr_id, prj_nam, alwd_emp_id, crte_by_uid, crte_dt_tm)
778
+ # VALUES (?, ?, ?, ?, NOW())
779
+ # RETURNING prj_id;
780
+ # """
781
+
782
+ # params = (prj_ownr_id, prj_nam, alwd_emp_id, emp_id)
783
+ # result = query_excecuter_postgres(
784
+ # query, db_cred, params=params, insert=True, insert_retrieve=True
785
+ # )
786
+ # return result[0][0]
787
+
788
+
789
+ # def create_new_project(prj_ownr_id, prj_nam, alwd_emp_id, emp_id):
790
+ # # Query to insert a new project
791
+ # insert_query = """
792
+ # INSERT INTO mmo_projects (prj_ownr_id, prj_nam, alwd_emp_id, crte_by_uid, crte_dt_tm)
793
+ # VALUES (?, ?, ?, ?, DATETIME('now'));
794
+ # """
795
+
796
+ # params = (prj_ownr_id, prj_nam, alwd_emp_id, emp_id)
797
+
798
+ # # Execute the insert query
799
+ # query_excecuter_postgres(insert_query, db_cred, params=params, insert=True)
800
+
801
+ # # Retrieve the last inserted prj_id
802
+ # retrieve_id_query = "SELECT last_insert_rowid();"
803
+ # result = query_excecuter_postgres(retrieve_id_query, db_cred, insert_retrieve=True)
804
+
805
+ # return result[0][0]
806
+
807
+
808
+ def create_new_project(prj_ownr_id, prj_nam, alwd_emp_id, emp_id):
809
+ # Query to insert a new project
810
+ insert_query = """
811
+ INSERT INTO mmo_projects (prj_ownr_id, prj_nam, alwd_emp_id, crte_by_uid, crte_dt_tm)
812
+ VALUES (?, ?, ?, ?, DATETIME('now'));
813
+ """
814
+ params = (prj_ownr_id, prj_nam, alwd_emp_id, emp_id)
815
+
816
+ # Execute the insert query and retrieve the last inserted prj_id directly
817
+ last_inserted_id = query_excecuter_postgres(
818
+ insert_query, params=params, insert_retrieve=True
819
+ )
820
+
821
+ return last_inserted_id
822
+
823
+
824
+ def insert_project_metadata(new_prj_id, metadata, created_emp_id):
825
+ # query = f"""
826
+ # INSERT INTO mmo_project_meta_data (
827
+ # prj_id, page_nam, crte_dt_tm, file_nam, pkl_obj, dshbrd_ts, crte_by_uid
828
+ # )
829
+ # VALUES (?, ?, NOW(), ?, ?, ?, ?);
830
+ # """
831
+
832
+ query = """
833
+ INSERT INTO mmo_project_meta_data (
834
+ prj_id, page_nam, crte_dt_tm, file_nam, pkl_obj, dshbrd_ts, crte_by_uid
835
+ )
836
+ VALUES (?, ?, DATETIME('now'), ?, ?, ?, ?);
837
+ """
838
+
839
+ for row in metadata:
840
+ params = (new_prj_id, row[1], row[2], row[3], row[4], created_emp_id)
841
+ query_excecuter_postgres(query, db_cred, params=params, insert=True)
842
+
843
+
844
+ # def delete_projects_by_ids(prj_ids):
845
+ # # Ensure prj_ids is a tuple to use with the IN clause
846
+ # prj_ids_tuple = tuple(prj_ids)
847
+
848
+ # # Query to delete project metadata
849
+ # delete_metadata_query = f"""
850
+ # DELETE FROM mmo_project_meta_data
851
+ # WHERE prj_id IN ?;
852
+ # """
853
+
854
+ # delete_projects_query = f"""
855
+ # DELETE FROM mmo_projects
856
+ # WHERE prj_id IN ?;
857
+ # """
858
+
859
+ # try:
860
+ # # Delete from metadata table
861
+ # query_excecuter_postgres(
862
+ # delete_metadata_query, db_cred, params=(prj_ids_tuple,), insert=True
863
+ # )
864
+
865
+ # # Delete from projects table
866
+ # query_excecuter_postgres(
867
+ # delete_projects_query, db_cred, params=(prj_ids_tuple,), insert=True
868
+ # )
869
+
870
+ # except Exception as e:
871
+ # st.write(f"Error deleting projects: {e}")
872
+
873
+
874
+ def delete_projects_by_ids(prj_ids):
875
+ # Ensure prj_ids is a tuple to use with the IN clause
876
+ prj_ids_tuple = tuple(prj_ids)
877
+
878
+ # Dynamically generate placeholders for SQLite
879
+ placeholders = ", ".join(["?"] * len(prj_ids_tuple))
880
+
881
+ # Query to delete project metadata with dynamic placeholders
882
+ delete_metadata_query = f"""
883
+ DELETE FROM mmo_project_meta_data
884
+ WHERE prj_id IN ({placeholders});
885
+ """
886
+
887
+ delete_projects_query = f"""
888
+ DELETE FROM mmo_projects
889
+ WHERE prj_id IN ({placeholders});
890
+ """
891
+
892
+ try:
893
+ # Delete from metadata table
894
+ query_excecuter_postgres(
895
+ delete_metadata_query, db_cred, params=prj_ids_tuple, insert=True
896
+ )
897
+
898
+ # Delete from projects table
899
+ query_excecuter_postgres(
900
+ delete_projects_query, db_cred, params=prj_ids_tuple, insert=True
901
+ )
902
+
903
+ except Exception as e:
904
+ st.write(f"Error deleting projects: {e}")
905
+
906
+
907
+ def fetch_users_with_access(prj_id):
908
+ # Query to get allowed employee IDs for the project
909
+ get_allowed_emps_query = """
910
+ SELECT alwd_emp_id
911
+ FROM mmo_projects
912
+ WHERE prj_id = ?;
913
+ """
914
+
915
+ # Fetch the allowed employee IDs
916
+ allowed_emp_ids_result = query_excecuter_postgres(
917
+ get_allowed_emps_query, db_cred, params=(prj_id,), insert=False
918
+ )
919
+
920
+ if not allowed_emp_ids_result:
921
+ return []
922
+
923
+ # Extract the allowed employee IDs (Assuming alwd_emp_id is a comma-separated string)
924
+ allowed_emp_ids_str = allowed_emp_ids_result[0][0]
925
+ allowed_emp_ids = allowed_emp_ids_str.split(",")
926
+
927
+ # Convert to tuple for the IN clause
928
+ allowed_emp_ids_tuple = tuple(allowed_emp_ids)
929
+
930
+ # Query to get user details for the allowed employee IDs
931
+ get_users_query = """
932
+ SELECT emp_id, emp_nam, emp_typ
933
+ FROM mmo_users
934
+ WHERE emp_id IN ({});
935
+ """.format(
936
+ ",".join("?" * len(allowed_emp_ids_tuple))
937
+ ) # Dynamically construct the placeholder list
938
+
939
+ # Fetch user details
940
+ user_details = query_excecuter_postgres(
941
+ get_users_query, db_cred, params=allowed_emp_ids_tuple, insert=False
942
+ )
943
+
944
+ return user_details
945
+
946
+
947
+ # def update_project_access(prj_id, user_names, new_user_ids):
948
+ # # Convert the list of new user IDs to a comma-separated string
949
+ # new_user_ids_str = ",".join(new_user_ids)
950
+
951
+ # # Query to update the alwd_emp_id for the specified project
952
+ # update_access_query = f"""
953
+ # UPDATE mmo_projects
954
+ # SET alwd_emp_id = ?
955
+ # WHERE prj_id = ?;
956
+ # """
957
+
958
+ # # Execute the update query
959
+ # query_excecuter_postgres(
960
+ # update_access_query, db_cred, params=(new_user_ids_str, prj_id), insert=True
961
+ # )
962
+ # st.success(f"Project {prj_id} access updated successfully")
963
+
964
+
965
+ def fetch_user_ids_from_dict(user_dict, user_names):
966
+ user_ids = []
967
+ # Iterate over the user_dict to find matching user names
968
+ for user_id, details in user_dict.items():
969
+ if details[0] in user_names:
970
+ user_ids.append(user_id)
971
+ return user_ids
972
+
973
+
974
+ def update_project_access(prj_id, user_names, user_dict):
975
+ # Fetch the new user IDs based on the provided user names from the dictionary
976
+ new_user_ids = fetch_user_ids_from_dict(user_dict, user_names)
977
+
978
+ # Convert the list of new user IDs to a comma-separated string
979
+ new_user_ids_str = ",".join(new_user_ids)
980
+
981
+ # Query to update the alwd_emp_id for the specified project
982
+ update_access_query = f"""
983
+ UPDATE mmo_projects
984
+ SET alwd_emp_id = ?
985
+ WHERE prj_id = ?;
986
+ """
987
+
988
+ # Execute the update query
989
+ query_excecuter_postgres(
990
+ update_access_query, db_cred, params=(new_user_ids_str, prj_id), insert=True
991
+ )
992
+ st.write(f"Project {prj_id} access updated successfully")
993
+
994
+
995
+ def validate_emp_id():
996
+
997
+ if st.session_state.sign_up not in st.session_state["unique_ids"].keys():
998
+ st.warning("You don't have access to the tool; please contact the admin")
999
+
1000
+
1001
+ # -------------------Front END-------------------------#
1002
+
1003
+ st.header("Manage Projects")
1004
+
1005
+ unique_users_query = f"""
1006
+ SELECT DISTINCT emp_id, emp_nam, emp_typ
1007
+ FROM mmo_users;
1008
+ """
1009
+
1010
+ if "unique_ids" not in st.session_state:
1011
+
1012
+ unique_users_result = query_excecuter_postgres(
1013
+ unique_users_query, db_cred, insert=False
1014
+ ) # retrieves all the users who has access to MMO TOOL
1015
+
1016
+ if len(unique_users_result) == 0:
1017
+ st.warning("No users data present in db, please contact admin!")
1018
+ st.stop()
1019
+
1020
+ st.session_state["unique_ids"] = {
1021
+ emp_id: (emp_nam, emp_type) for emp_id, emp_nam, emp_type in unique_users_result
1022
+ }
1023
+
1024
+
1025
+ if "toggle" not in st.session_state:
1026
+ st.session_state["toggle"] = 0
1027
+
1028
+
1029
+ if "emp_id" not in st.session_state:
1030
+ reset_password = st.radio(
1031
+ "Select An Option",
1032
+ options=["Login", "Reset Password"],
1033
+ index=st.session_state["toggle"],
1034
+ horizontal=True,
1035
+ )
1036
+
1037
+ if reset_password == "Login":
1038
+ emp_id = st.text_input("Employee id").lower() # emp id
1039
+ password = st.text_input("Password", max_chars=15, type="password")
1040
+ login_button = st.button("Login", use_container_width=True)
1041
+
1042
+ else:
1043
+
1044
+ emp_id = st.text_input(
1045
+ "Employee id", key="sign_up", on_change=validate_emp_id
1046
+ ).lower()
1047
+
1048
+ current_password = st.text_input(
1049
+ "Enter Current Password and Press Enter to Validate",
1050
+ max_chars=15,
1051
+ type="password",
1052
+ key="current_password",
1053
+ )
1054
+ if emp_id:
1055
+
1056
+ if emp_id not in st.session_state["unique_ids"].keys():
1057
+ st.write("Invalid id!")
1058
+ st.stop()
1059
+ else:
1060
+ if not is_pswrd_flag_set(emp_id):
1061
+
1062
+ if verify_password(emp_id, current_password):
1063
+ st.success("Your password key has been successfully validated!")
1064
+
1065
+ elif (
1066
+ not verify_password(emp_id, current_password)
1067
+ and len(current_password) > 1
1068
+ ):
1069
+ st.write("Wrong Password Key Please Try Again")
1070
+ st.stop()
1071
+
1072
+ elif verify_password(emp_id, current_password):
1073
+ st.success("Your password has been successfully validated!")
1074
+
1075
+ elif (
1076
+ not verify_password(emp_id, current_password)
1077
+ and len(current_password) > 1
1078
+ ):
1079
+ st.write("Wrong Password Please Try Again")
1080
+ st.stop()
1081
+
1082
+ new_password = st.text_input(
1083
+ "Enter New Password", max_chars=15, type="password", key="new_password"
1084
+ )
1085
+
1086
+ st.markdown(
1087
+ "**Password must be 8 to 15 characters long and contain at least one uppercase letter, one lowercase letter, one digit, and one special character. No SQL commands allowed.**"
1088
+ )
1089
+
1090
+ validation_result = validate_password(new_password)
1091
+
1092
+ confirm_new_password = st.text_input(
1093
+ "Confirm New Password",
1094
+ max_chars=15,
1095
+ type="password",
1096
+ key="confirm_new_password",
1097
+ )
1098
+
1099
+ reset_button = st.button("Reset Password", use_container_width=True)
1100
+
1101
+ if reset_button:
1102
+
1103
+ validation_result = validate_password(new_password)
1104
+
1105
+ if validation_result != "Valid input.":
1106
+ st.warning(validation_result)
1107
+ st.stop()
1108
+ elif new_password != confirm_new_password:
1109
+ st.warning(
1110
+ "The new password and confirmation password do not match. Please try again."
1111
+ )
1112
+ st.stop()
1113
+ else:
1114
+ store_hashed_password(emp_id, confirm_new_password)
1115
+ set_pswrd_flag(emp_id)
1116
+ st.success("Password Reset Successful!")
1117
+
1118
+ with st.spinner("Redirecting to Login"):
1119
+ time.sleep(3)
1120
+ st.session_state["toggle"] = 0
1121
+ st.rerun()
1122
+
1123
+ st.stop()
1124
+
1125
+ if login_button:
1126
+
1127
+ if emp_id not in st.session_state["unique_ids"].keys() or len(password) == 0:
1128
+ st.warning("Invalid ID or password!")
1129
+
1130
+ st.stop()
1131
+
1132
+ if not is_pswrd_flag_set(emp_id):
1133
+ st.warning("Reset password to continue")
1134
+ with st.spinner("Redirecting"):
1135
+ st.session_state["toggle"] = 1
1136
+ time.sleep(2)
1137
+ st.rerun()
1138
+ st.stop()
1139
+
1140
+ elif verify_password(emp_id, password):
1141
+ with st.spinner("Loading Saved Projects"):
1142
+ st.session_state["emp_id"] = emp_id
1143
+
1144
+ update_summary_df() # function call to fetch user saved projects
1145
+
1146
+ st.session_state["clone_project_dict"] = fetch_and_process_projects(
1147
+ st.session_state["emp_id"]
1148
+ )
1149
+ if "project_dct" in st.session_state:
1150
+ del st.session_state["project_dct"]
1151
+
1152
+ st.session_state["project_name"] = None
1153
+
1154
+ delete_old_log_files()
1155
+
1156
+ st.rerun()
1157
+
1158
+ if (
1159
+ len(st.session_state["emp_id"]) == 0
1160
+ or st.session_state["emp_id"]
1161
+ not in st.session_state["unique_ids"].keys()
1162
+ ):
1163
+ st.stop()
1164
+ else:
1165
+ st.warning("Invalid user name or password")
1166
+
1167
+ st.stop()
1168
+
1169
+ if st.button("Logout"):
1170
+ if "emp_id" in st.session_state:
1171
+ del st.session_state["emp_id"]
1172
+ st.rerun()
1173
+
1174
+ if st.session_state["emp_id"] in st.session_state["unique_ids"].keys():
1175
+
1176
+ if "project_name" not in st.session_state:
1177
+ st.session_state["project_name"] = None
1178
+
1179
+ cols1 = st.columns([2, 1])
1180
+
1181
+ st.session_state["username"] = st.session_state["unique_ids"][
1182
+ st.session_state["emp_id"]
1183
+ ][0]
1184
+
1185
+ with cols1[0]:
1186
+ st.markdown(f"**Welcome {st.session_state['username']}**")
1187
+ with cols1[1]:
1188
+ st.markdown(f"**Current Project: {st.session_state['project_name']}**")
1189
+
1190
+ st.markdown(
1191
+ """
1192
+ Select a project number below and click on Load Project to load the project.
1193
+
1194
+ """
1195
+ )
1196
+
1197
+ st.markdown("Select Project")
1198
+
1199
+ # st.write(type(st.session_state.keys))
1200
+
1201
+ if len(st.session_state["project_summary_df"]) != 0:
1202
+
1203
+ # Display an editable data table using Streamlit's data editor component
1204
+
1205
+ table = st.dataframe(
1206
+ st.session_state["project_summary_df"],
1207
+ use_container_width=True,
1208
+ hide_index=True,
1209
+ )
1210
+
1211
+ project_number = st.selectbox(
1212
+ "Enter Project number",
1213
+ options=st.session_state["project_summary_df"]["Project Number"],
1214
+ )
1215
+
1216
+ log_message(
1217
+ "info",
1218
+ f"Project number {project_number} selected by employee {st.session_state['emp_id']}.",
1219
+ "Home",
1220
+ )
1221
+
1222
+ project_col = st.columns(2)
1223
+
1224
+ # if "load_project_key" not in st.session_state:
1225
+ # st.session_state["load_project_key"] = None\
1226
+
1227
+ def load_project_fun():
1228
+ st.session_state["project_name"] = (
1229
+ st.session_state["project_summary_df"]
1230
+ .loc[
1231
+ st.session_state["project_summary_df"]["Project Number"]
1232
+ == project_number,
1233
+ "Project Name",
1234
+ ]
1235
+ .values[0]
1236
+ ) # fetching project name from project number stored in summary df
1237
+
1238
+ project_dct_query = f"""
1239
+ SELECT pkl_obj
1240
+ FROM mmo_project_meta_data
1241
+ WHERE prj_id = ? AND file_nam = ?;
1242
+ """
1243
+ # Execute the query and retrieve the result
1244
+ project_dct_retrieved = query_excecuter_postgres(
1245
+ project_dct_query,
1246
+ db_cred,
1247
+ params=(project_number, "project_dct"),
1248
+ insert=False,
1249
+ )
1250
+ # retrieves project dict (meta data) stored in db
1251
+
1252
+ st.session_state["project_dct"] = pickle.loads(
1253
+ project_dct_retrieved[0][0]
1254
+ ) # converting bytes data to original objet using pickle
1255
+ st.session_state["project_number"] = project_number
1256
+
1257
+ keys_to_keep = [
1258
+ "unique_ids",
1259
+ "emp_id",
1260
+ "project_dct",
1261
+ "project_name",
1262
+ "project_number",
1263
+ "username",
1264
+ "project_summary_df",
1265
+ "clone_project_dict",
1266
+ ]
1267
+
1268
+ # Clear all keys in st.session_state except the ones to keep
1269
+ for key in list(st.session_state.keys()):
1270
+ if key not in keys_to_keep:
1271
+ del st.session_state[key]
1272
+
1273
+ ensure_project_dct_structure(st.session_state["project_dct"], default_dct)
1274
+
1275
+ if st.button(
1276
+ "Load Project",
1277
+ use_container_width=True,
1278
+ key="load_project_key",
1279
+ on_click=load_project_fun,
1280
+ ):
1281
+ st.success("Project Loaded")
1282
+
1283
+ # st.rerun() # refresh the page
1284
+
1285
+ # st.write(st.session_state['project_dct'])
1286
+ if "radio_box_index" not in st.session_state:
1287
+ st.session_state["radio_box_index"] = 0
1288
+
1289
+ projct_radio = st.radio(
1290
+ "Select Options",
1291
+ [
1292
+ "Create New Project",
1293
+ "Modify Project Access",
1294
+ "Clone Saved Projects",
1295
+ "Delete Projects",
1296
+ ],
1297
+ horizontal=True,
1298
+ index=st.session_state["radio_box_index"],
1299
+ )
1300
+
1301
+ if projct_radio == "Modify Project Access":
1302
+
1303
+ with st.expander("Modify Project Access"):
1304
+ project_number_for_access = st.selectbox(
1305
+ "Select Project Number",
1306
+ st.session_state["project_summary_df"]["Project Number"],
1307
+ )
1308
+
1309
+ with st.spinner("Loading"):
1310
+ users_who_has_access = fetch_users_with_access(
1311
+ project_number_for_access
1312
+ )
1313
+
1314
+ users_name_who_has_access = [user[1] for user in users_who_has_access]
1315
+ modified_users_for_access_options = [
1316
+ details[0]
1317
+ for user_id, details in st.session_state["unique_ids"].items()
1318
+ if user_id != st.session_state["emp_id"]
1319
+ ]
1320
+
1321
+ users_name_who_has_access = [
1322
+ name
1323
+ for name in users_name_who_has_access
1324
+ if name in modified_users_for_access_options
1325
+ ]
1326
+
1327
+ modified_users_for_access = st.multiselect(
1328
+ "Select or deselect users to grant or revoke access, then click the 'Modify Access' button to submit changes.",
1329
+ options=modified_users_for_access_options,
1330
+ default=users_name_who_has_access,
1331
+ )
1332
+
1333
+ if st.button("Modify Access", use_container_width=True):
1334
+ with st.spinner("Modifying Access"):
1335
+ update_project_access(
1336
+ project_number_for_access,
1337
+ modified_users_for_access,
1338
+ st.session_state["unique_ids"],
1339
+ )
1340
+
1341
+ if projct_radio == "Create New Project":
1342
+
1343
+ with st.expander("Create New Project", expanded=False):
1344
+
1345
+ st.session_state["is_create_project_open"] = True
1346
+
1347
+ unique_users = [
1348
+ user[0] for user in st.session_state["unique_ids"].values()
1349
+ ] # fetching unique users who has access to the tool
1350
+
1351
+ user_projects = list(
1352
+ set(st.session_state["project_summary_df"]["Project Name"])
1353
+ ) # fetching the corresponding user's projects
1354
+ st.markdown(
1355
+ """
1356
+ To create a new project, follow the instructions below:
1357
+
1358
+ 1. **Project Name**:
1359
+ - It should start with the client name, followed by the username.
1360
+ - It should not contain special characters except for underscores (`_`) and should not contain spaces.
1361
+ - Example format: `<client_name>_<username>_<project_name>`
1362
+
1363
+ 2. **Select User**: Select the user you want to give access to this project.
1364
+
1365
+ 3. **Create New Project**: Click **Create New Project** once the above details are entered.
1366
+
1367
+ **Example**:
1368
+
1369
+ - For a client named "ClientA" and a user named "UserX" with a project named "NewCampaign", the project name should be:
1370
+ `ClientA_UserX_NewCampaign`
1371
+ """
1372
+ )
1373
+
1374
+ project_col1 = st.columns(3)
1375
+
1376
+ with project_col1[0]:
1377
+
1378
+ # API_tables = get_table_names(schema) # load API files
1379
+
1380
+ slection_tables = ["NA"]
1381
+
1382
+ api_name = st.selectbox("Select API data", slection_tables, index=0)
1383
+
1384
+ # data availabe through API
1385
+ # api_path = API_path_dict[api_name]
1386
+
1387
+ with project_col1[1]:
1388
+ defualt_project_prefix = f"{api_name.split('_mmo_')[0]}_{st.session_state['unique_ids'][st.session_state['emp_id']][0]}_".replace(
1389
+ " ", "_"
1390
+ ).lower()
1391
+
1392
+ if "project_name_box" not in st.session_state:
1393
+ st.session_state["project_name_box"] = defualt_project_prefix
1394
+
1395
+ project_name = st.text_input(
1396
+ "Enter Project Name", key="project_name_box"
1397
+ )
1398
+ warning_box = st.empty()
1399
+
1400
+ with project_col1[2]:
1401
+
1402
+ "Select users who can access this project",
1403
+ "Select Users who can access to this Project",
1404
+ [val for val in unique_users],
1405
+ )
1406
+
1407
+ allowed_users = list(allowed_users)
1408
+
1409
+ matching_user_id = []
1410
+
1411
+ if len(allowed_users) > 0:
1412
+
1413
+ # converting the selection to comma-separated values to store in the db
1414
+
1415
+ for emp_id, details in st.session_state["unique_ids"].items():
1416
+ for name in allowed_users:
1417
+ if name in details:
1418
+ matching_user_id.append(emp_id)
1419
+ break
1420
+
1421
+ st.button(
1422
+ "Reset Project Name",
1423
+ on_click=reset_project_text_box,
1424
+ help="",
1425
+ use_container_width=True,
1426
+ )
1427
+
1428
+ create = st.button(
1429
+ "Create New Project",
1430
+ use_container_width=True,
1431
+ help="Project Name should follow naming convention",
1432
+ )
1433
+
1434
+ if create:
1435
+ if not project_name.lower().startswith(defualt_project_prefix):
1436
+ with warning_box:
1437
+ st.warning("Project Name should follow naming convention")
1438
+ st.stop()
1439
+
1440
+ if project_name == defualt_project_prefix:
1441
+ with warning_box:
1442
+ st.warning("Cannot name only with prefix")
1443
+ st.stop()
1444
+
1445
+ if project_name in user_projects:
1446
+ with warning_box:
1447
+ st.warning("Project already exists; please enter a new name")
1448
+ st.stop()
1449
+
1450
+ if not (
1451
+ 2 <= len(project_name) <= 50
1452
+ and bool(re.match("^[A-Za-z0-9_]*$", project_name))
1453
+ ):
1454
+ # Store the warning message details in session state
1455
+
1456
+ with warning_box:
1457
+ st.warning(
1458
+ "Please provide a valid project name (2-50 characters, only A-Z, a-z, 0-9, and _)."
1459
+ )
1460
+ st.stop()
1461
+
1462
+ if contains_sql_keywords_check(project_name):
1463
+ with warning_box:
1464
+ st.warning(
1465
+ "Input contains SQL keywords. Please avoid using SQL commands."
1466
+ )
1467
+ st.stop()
1468
+ else:
1469
+ pass
1470
+
1471
+ with st.spinner("Creating Project"):
1472
+ new_project()
1473
+
1474
+ with warning_box:
1475
+ st.write("Project Created")
1476
+
1477
+ st.session_state["radio_box_index"] = 1
1478
+
1479
+ log_message(
1480
+ "info",
1481
+ f"Employee {st.session_state['emp_id']} created new project {project_name}.",
1482
+ "Home",
1483
+ )
1484
+
1485
+ st.rerun()
1486
+
1487
+ if projct_radio == "Clone Saved Projects":
1488
+
1489
+ with st.expander("Clone Saved Projects", expanded=False):
1490
+
1491
+ if len(st.session_state["clone_project_dict"]) == 0:
1492
+ st.warning("You don't have access to any saved projects")
1493
+ st.stop()
1494
+
1495
+ cols = st.columns(2)
1496
+
1497
+ with cols[0]:
1498
+ owners = list(st.session_state["clone_project_dict"].keys())
1499
+ owner_name = st.selectbox("Select Owner", owners)
1500
+
1501
+ with cols[1]:
1502
+
1503
+ project_names = [
1504
+ project["project_name"]
1505
+ for project in st.session_state["clone_project_dict"][owner_name]
1506
+ ]
1507
+ project_name_owner = st.selectbox(
1508
+ "Select a saved Project available for you",
1509
+ project_names,
1510
+ )
1511
+
1512
+ defualt_project_prefix = f"{project_name_owner.split('_')[0]}_{st.session_state['unique_ids'][st.session_state['emp_id']][0]}_".replace(
1513
+ " ", "_"
1514
+ ).lower()
1515
+ user_projects = list(
1516
+ set(st.session_state["project_summary_df"]["Project Name"])
1517
+ )
1518
+
1519
+ cloned_project_name = st.text_input(
1520
+ "Enter Project Name",
1521
+ value=defualt_project_prefix,
1522
+ )
1523
+ warning_box = st.empty()
1524
+
1525
+ if st.button(
1526
+ "Load Project", use_container_width=True, key="load_project_button_key"
1527
+ ):
1528
+
1529
+ if not cloned_project_name.lower().startswith(defualt_project_prefix):
1530
+ with warning_box:
1531
+ st.warning("Project Name should follow naming conventions")
1532
+ st.stop()
1533
+
1534
+ if cloned_project_name == defualt_project_prefix:
1535
+ with warning_box:
1536
+ st.warning("Cannot Name only with Prefix")
1537
+ st.stop()
1538
+
1539
+ if cloned_project_name in user_projects:
1540
+ with warning_box:
1541
+ st.warning("Project already exists; please enter a new name")
1542
+ st.stop()
1543
+
1544
+ with st.spinner("Cloning Project"):
1545
+ old_prj_id = get_project_id_from_dict(
1546
+ st.session_state["clone_project_dict"],
1547
+ owner_name,
1548
+ project_name_owner,
1549
+ )
1550
+ old_metadata = fetch_project_metadata(old_prj_id)
1551
+
1552
+ new_prj_id = create_new_project(
1553
+ st.session_state["emp_id"],
1554
+ cloned_project_name,
1555
+ "",
1556
+ st.session_state["emp_id"],
1557
+ )
1558
+
1559
+ insert_project_metadata(
1560
+ new_prj_id, old_metadata, st.session_state["emp_id"]
1561
+ )
1562
+ update_summary_df()
1563
+ st.success("Project Cloned")
1564
+ st.rerun()
1565
+
1566
+ if projct_radio == "Delete Projects":
1567
+ if len(st.session_state["project_summary_df"]) != 0:
1568
+
1569
+ with st.expander("Delete Projects", expanded=True):
1570
+
1571
+ delete_projects = st.multiselect(
1572
+ "Select all the project numbers you want to delete",
1573
+ st.session_state["project_summary_df"]["Project Number"],
1574
+ )
1575
+ st.warning(
1576
+ "Projects will be permanently deleted. Other users will not be able to clone them if they have not already done so."
1577
+ )
1578
+ if st.button("Delete Projects", use_container_width=True):
1579
+ if len(delete_projects) > 0:
1580
+ with st.spinner("Deleting Projects"):
1581
+ delete_projects_by_ids(delete_projects)
1582
+ update_summary_df()
1583
+ st.success("Projects Deleted")
1584
+ st.rerun()
1585
+
1586
+ else:
1587
+ st.warning("Please select atleast one project number to delete")
1588
+
1589
+ if projct_radio == "Download Project PPT":
1590
+
1591
+ try:
1592
+ ppt = create_ppt(
1593
+ st.session_state["project_name"],
1594
+ st.session_state["username"],
1595
+ "panel", # new
1596
+ )
1597
+
1598
+ if ppt is not False:
1599
+ st.download_button(
1600
+ "Download",
1601
+ data=ppt.getvalue(),
1602
+ file_name=st.session_state["project_name"]
1603
+ + " Project Summary.pptx",
1604
+ use_container_width=True,
1605
+ )
1606
+ else:
1607
+ st.warning("Please make some progress before downloading PPT.")
1608
+
1609
+ except Exception as e:
1610
+ st.warning("PPT Download Faild ")
1611
+ # new
1612
+ log_message(
1613
+ log_type="error", message=f"Error in PPT build: {e}", page_name="Home"
1614
+ )
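Note on the clone flow above: a new name is accepted only if it starts with the expected prefix, is longer than the prefix itself, and is not already taken. A minimal standalone sketch of those checks follows; the function name and sample values are illustrative, not the app's actual helpers.

def validate_cloned_name(cloned_name, prefix, existing_names):
    """Return an error message, or None if the name is acceptable."""
    if not cloned_name.lower().startswith(prefix):
        return "Project Name should follow naming conventions"
    if cloned_name == prefix:
        return "Project name cannot be just the prefix"
    if cloned_name in existing_names:
        return "Project already exists, please enter a new name"
    return None

# Example: prefix built from the source project name and the user's unique id
print(validate_cloned_name("demo_u123_retail_mmm", "demo_u123_", ["demo_u123_old"]))  # None -> valid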
README.md CHANGED
@@ -1,13 +1,13 @@
1
- ---
2
- title: MediaMixOptimization
3
- emoji: 🦀
4
- colorFrom: red
5
- colorTo: gray
6
- sdk: streamlit
7
- sdk_version: 1.44.1
8
- app_file: app.py
9
- pinned: false
10
- short_description: This tool helps in optimizing media spends
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ ---
2
+ title: MediaMixOptimization
3
+ emoji: 📊
4
+ colorFrom: red
5
+ colorTo: purple
6
+ sdk: streamlit
7
+ sdk_version: 1.28.0
8
+ app_file: Home.py
9
+ pinned: false
10
+ short_description: This tool helps in optimizing media spends
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
config.json ADDED
File without changes
config.yaml ADDED
File without changes
constants.py ADDED
@@ -0,0 +1,188 @@
1
+ ###########################################################################################################
2
+ # Default
3
+ ###########################################################################################################
4
+
5
+ import pandas as pd
6
+ from collections import OrderedDict
7
+
8
+ # default_dct
9
+ default_dct = {
10
+ "data_import": {
11
+ "gold_layer_df": pd.DataFrame(),
12
+ "granularity_selection": "daily",
13
+ "dashboard_df": None,
14
+ "tool_df": None,
15
+ "unique_panels": [],
16
+ "imputation_df": None,
17
+ "imputed_tool_df": None,
18
+ "group_dict": {},
19
+ "category_dict": {},
20
+ },
21
+ "data_validation": {
22
+ "target_column": 0,
23
+ "selected_panels": None,
24
+ "selected_feature": 0,
25
+ "validated_variables": [],
26
+ "Non_media_variables": 0,
27
+ "correlation": [],
28
+ },
29
+ "transformations": {
30
+ "final_df": None,
31
+ "summary_string": None,
32
+ "Media": {},
33
+ "Exogenous": {},
34
+ "Internal": {},
35
+ "Specific": {},
36
+ "correlation_plot_selection": [],
37
+ },
38
+ "model_build": {
39
+ "sel_target_col": None,
40
+ "all_iters_check": False,
41
+ "iterations": 0,
42
+ "build_button": False,
43
+ "show_results_check": False,
44
+ "session_state_saved": {},
45
+ },
46
+ "model_tuning": {
47
+ "sel_target_col": None,
48
+ "sel_model": {},
49
+ "flag_expander": False,
50
+ "start_date_default": None,
51
+ "end_date_default": None,
52
+ "repeat_default": "No",
53
+ "flags": {},
54
+ "select_all_flags_check": {},
55
+ "selected_flags": {},
56
+ "trend_check": False,
57
+ "week_num_check": False,
58
+ "sine_cosine_check": False,
59
+ "session_state_saved": {},
60
+ },
61
+ "saved_model_results": {
62
+ "selected_options": None,
63
+ "model_grid_sel": [1],
64
+ },
65
+ "current_media_performance": {
66
+ "model_outputs": {},
67
+ },
68
+ "media_performance": {"start_date": None, "end_date": None},
69
+ "response_curves": {
70
+ "original_metadata_file": None,
71
+ "modified_metadata_file": None,
72
+ },
73
+ "scenario_planner": {
74
+ "original_metadata_file": None,
75
+ "modified_metadata_file": None,
76
+ },
77
+ "saved_scenarios": {"saved_scenarios_dict": OrderedDict()},
78
+ "optimized_result_analysis": {
79
+ "selected_scenario_selectbox_visualize": 0,
80
+ "metric_selectbox_visualize": 0,
81
+ },
82
+ }
83
+
84
+ ###########################################################################################################
85
+ # Data Import
86
+ ###########################################################################################################
87
+
88
+ # Constants Data Import
89
+ upload_rows_limit = 1000000 # Maximum number of rows allowed for upload
90
+ upload_column_limit = 1000 # Maximum number of columns allowed for upload
91
+ word_length_limit_lower = 2 # Minimum allowed length for words
92
+ word_length_limit_upper = 100 # Maximum allowed length for words
93
+ minimum_percent_overlap = (
94
+ 1 # Minimum required percentage of overlap with the reference data
95
+ )
96
+ minimum_row_req = 30 # Minimum number of rows required
97
+ percent_drop_col_threshold = 50 # Percentage threshold above which columns are automatically categorized for drop imputation
98
+
99
+ ###########################################################################################################
100
+ # Transformations
101
+ ###########################################################################################################
102
+
103
+ # Constants Transformations
104
+ predefined_defaults = {
105
+ "Lag": (1, 2),
106
+ "Lead": (1, 2),
107
+ "Moving Average": (1, 2),
108
+ "Saturation": (10, 20),
109
+ "Power": (2, 4),
110
+ "Adstock": (0.5, 0.7),
111
+ } # Pre-defined default values of every transformation
112
+
113
+ # Transformations min, max and step
114
+ lead_min_value = 1
115
+ lead_max_value = 10
116
+ lead_step = 1
117
+ lag_min_value = 1
118
+ lag_max_value = 10
119
+ lag_step = 1
120
+ moving_average_min_value = 1
121
+ moving_average_max_value = 10
122
+ moving_average_step = 1
123
+ saturation_min_value = 0
124
+ saturation_max_value = 100
125
+ saturation_step = 1
126
+ power_min_value = 1
127
+ power_max_value = 5
128
+ power_step = 1
129
+ adstock_min_value = 0.0
130
+ adstock_max_value = 1.0
131
+ adstock_step = 0.05
132
+ display_max_col = 500 # Maximum columns to display
133
+
134
+ ###########################################################################################################
135
+ # Model Build
136
+ ###########################################################################################################
137
+
138
+ MAX_COMBINATIONS = 50000 # Max number of model combinations possible
139
+ MIN_MODEL_NAME_LENGTH = 0 # model can only be saved if len(model_name) is greater than this value
140
+ MIN_P_VALUE_THRESHOLD = 0.06 # coefficients with p values less than this value are considered valid
141
+ MODEL_POS_COEFF_RATIO_THRESHOLD = 0 # ratio of positive coefficients/total coefficients for model validity
142
+ MODEL_P_VALUE_RATIO_THRESHOLD = 0 # ratio of coefficients with p value/total coefficients for model validity
143
+ MAX_TOP_FEATURES = 5 # max number of top features selected per variable
144
+ MAX_NUM_FILTERS = 10 # maximum number of filters allowed
145
+ DEFAULT_FILTER_VALUE = 0.0 # default value of a new filter
146
+ VIF_LOW_THRESHOLD = 3 # Threshold for VIF to be colored green
147
+ VIF_HIGH_THRESHOLD = 10 # Threshold for VIF to be colored red
148
+ DEFAULT_TRAIN_RATIO = 3/4 # default train set ratio (75%)
149
+
150
+ ###########################################################################################################
151
+ # Model Tuning
152
+ ###########################################################################################################
153
+
154
+ import numpy as np
155
+
156
+ NUM_FLAG_COLS_TO_DISPLAY = 4 # Number of columns to be created on UI to display flags
157
+ HALF_YEAR_THRESHOLD = 6 # Threshold of months to create quarter frequency sine-cosine waves
158
+ FULL_YEAR_THRESHOLD = 12 # Threshold of months to create annual sine-cosine waves
159
+ TREND_MIN = 5 # Starting value of trend line
160
+ ANNUAL_FREQUENCY = 2 * np.pi / 365 # annual frequency
161
+ QTR_FREQUENCY_FACTOR = 4 # multiplication factor to get quarterly frequency
162
+ HALF_YEARLY_FREQUENCY_FACTOR = 2 # multiplication factor to get semi-annual frequency
163
+
164
+ ###########################################################################################################
165
+ # Scenario Planner
166
+ ###########################################################################################################
167
+
168
+ # Constants Scenario Planner
169
+ xtol_tolerance_per = 1 # Percentage of tolerance
170
+ mroi_threshold = 0.05 # mROI threshold
171
+ word_length_limit_lower = 2 # Minimum allowed length for words
172
+ word_length_limit_upper = 100 # Maximum allowed length for words
173
+
174
+ ###########################################################################################################
175
+ # PPT utils
176
+ ###########################################################################################################
177
+
178
+ TITLE_FONT_SIZE = 20
179
+ AXIS_LABEL_FONT_SIZE = 8
180
+ CHART_TITLE_FONT_SIZE = 14
181
+ AXIS_TITLE_FONT_SIZE = 12
182
+ DATA_LABEL_FONT_SIZE = 8
183
+ LEGEND_FONT_SIZE = 10
184
+ PIE_LEGEND_FONT_SIZE = 7
185
+
186
+ ###########################################################################################################
187
+ # Page Name
188
+ ###########################################################################################################
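As context for the Model Tuning constants above, here is a minimal sketch of how the trend and sine-cosine settings could be combined into seasonality features. The helper name and column names are illustrative and may not match the tool's actual implementation.

import numpy as np
import pandas as pd

# Illustrative values mirroring the constants defined above
TREND_MIN = 5
ANNUAL_FREQUENCY = 2 * np.pi / 365
QTR_FREQUENCY_FACTOR = 4
HALF_YEARLY_FREQUENCY_FACTOR = 2

def add_seasonality_features(df, date_col="date"):
    # Day index relative to the first observation, offset by TREND_MIN
    day_idx = (df[date_col] - df[date_col].min()).dt.days
    df["trend"] = day_idx + TREND_MIN
    # Annual, semi-annual and quarterly sine/cosine waves
    for name, freq in [
        ("annual", ANNUAL_FREQUENCY),
        ("half_yearly", ANNUAL_FREQUENCY * HALF_YEARLY_FREQUENCY_FACTOR),
        ("quarterly", ANNUAL_FREQUENCY * QTR_FREQUENCY_FACTOR),
    ]:
        df[f"sine_{name}"] = np.sin(freq * day_idx)
        df[f"cosine_{name}"] = np.cos(freq * day_idx)
    return df

dates = pd.DataFrame({"date": pd.date_range("2023-01-01", periods=730, freq="D")})
print(add_seasonality_features(dates).head())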
data_analysis.py ADDED
@@ -0,0 +1,267 @@
1
+ import streamlit as st
2
+ import plotly.express as px
3
+ import numpy as np
4
+ import plotly.graph_objects as go
5
+ from sklearn.metrics import r2_score
6
+ from collections import OrderedDict
7
+ import plotly.express as px
8
+ import plotly.graph_objects as go
9
+ import pandas as pd
10
+ import seaborn as sns
11
+ import matplotlib.pyplot as plt
12
+ import streamlit as st
13
+ import re
14
+ from matplotlib.colors import ListedColormap
15
+ # from st_aggrid import AgGrid, GridOptionsBuilder
16
+ # from src.agstyler import PINLEFT, PRECISION_TWO, draw_grid
17
+
18
+
19
+ def format_numbers(x):
20
+ if abs(x) >= 1e6:
21
+ # Format as millions with one decimal place and commas
22
+ return f'{x/1e6:,.1f}M'
23
+ elif abs(x) >= 1e3:
24
+ # Format as thousands with one decimal place and commas
25
+ return f'{x/1e3:,.1f}K'
26
+ else:
27
+ # Format with one decimal place and commas for values less than 1000
28
+ return f'{x:,.1f}'
29
+
30
+
31
+
32
+ def line_plot(data, x_col, y1_cols, y2_cols, title):
33
+ """
34
+ Create a line plot with two sets of y-axis data.
35
+
36
+ Parameters:
37
+ data (DataFrame): The data containing the columns to be plotted.
38
+ x_col (str): The column name for the x-axis.
39
+ y1_cols (list): List of column names for the primary y-axis.
40
+ y2_cols (list): List of column names for the secondary y-axis.
41
+ title (str): The title of the plot.
42
+
43
+ Returns:
44
+ fig (Figure): The Plotly figure object with the line plot.
45
+ """
46
+ fig = go.Figure()
47
+
48
+ # Add traces for the primary y-axis
49
+ for y1_col in y1_cols:
50
+ fig.add_trace(go.Scatter(x=data[x_col], y=data[y1_col], mode='lines', name=y1_col, line=dict(color='#11B6BD')))
51
+
52
+ # Add traces for the secondary y-axis
53
+ for y2_col in y2_cols:
54
+ fig.add_trace(go.Scatter(x=data[x_col], y=data[y2_col], mode='lines', name=y2_col, yaxis='y2', line=dict(color='#739FAE')))
55
+
56
+ # Configure the layout for the secondary y-axis if needed
57
+ if len(y2_cols) != 0:
58
+ fig.update_layout(yaxis=dict(), yaxis2=dict(overlaying='y', side='right'))
59
+ else:
60
+ fig.update_layout(yaxis=dict(), yaxis2=dict(overlaying='y', side='right'))
61
+
62
+ # Add title if provided
63
+ if title:
64
+ fig.update_layout(title=title)
65
+
66
+ # Customize axes and legend
67
+ fig.update_xaxes(showgrid=False)
68
+ fig.update_yaxes(showgrid=False)
69
+ fig.update_layout(legend=dict(
70
+ orientation="h",
71
+ yanchor="top",
72
+ y=1.1,
73
+ xanchor="center",
74
+ x=0.5
75
+ ))
76
+
77
+ return fig
78
+
79
+
80
+
81
+ def line_plot_target(df, target, title):
82
+ """
83
+ Create a line plot with a trendline for a target column.
84
+
85
+ Parameters:
86
+ df (DataFrame): The data containing the columns to be plotted.
87
+ target (str): The column name for the y-axis.
88
+ title (str): The title of the plot.
89
+
90
+ Returns:
91
+ fig (Figure): The Plotly figure object with the line plot and trendline.
92
+ """
93
+ # Calculate the trendline coefficients
94
+ coefficients = np.polyfit(df['date'].view('int64'), df[target], 1)
95
+ trendline = np.poly1d(coefficients)
96
+ fig = go.Figure()
97
+
98
+ # Add the target line plot
99
+ fig.add_trace(go.Scatter(x=df['date'], y=df[target], mode='lines', name=target, line=dict(color='#11B6BD')))
100
+
101
+ # Calculate and add the trendline plot
102
+ trendline_x = df['date']
103
+ trendline_y = trendline(df['date'].view('int64'))
104
+ fig.add_trace(go.Scatter(x=trendline_x, y=trendline_y, mode='lines', name='Trendline', line=dict(color='#739FAE')))
105
+
106
+ # Update layout with title and x-axis type
107
+ fig.update_layout(
108
+ title=title,
109
+ xaxis=dict(type='date')
110
+ )
111
+
112
+ # Add vertical lines at the start of each year
113
+ for year in df['date'].dt.year.unique()[1:]:
114
+ january_1 = pd.Timestamp(year=year, month=1, day=1)
115
+ fig.add_shape(
116
+ go.layout.Shape(
117
+ type="line",
118
+ x0=january_1,
119
+ x1=january_1,
120
+ y0=0,
121
+ y1=1,
122
+ xref="x",
123
+ yref="paper",
124
+ line=dict(color="grey", width=1.5, dash="dash"),
125
+ )
126
+ )
127
+
128
+ # Customize the legend
129
+ fig.update_layout(legend=dict(
130
+ orientation="h",
131
+ yanchor="top",
132
+ y=1.1,
133
+ xanchor="center",
134
+ x=0.5
135
+ ))
136
+
137
+ return fig
138
+
139
+
140
+ def correlation_plot(df, selected_features, target):
141
+ """
142
+ Create a correlation heatmap plot for selected features and target column.
143
+
144
+ Parameters:
145
+ df (DataFrame): The data containing the columns to be plotted.
146
+ selected_features (list): List of column names to be included in the correlation plot.
147
+ target (str): The target column name to be included in the correlation plot.
148
+
149
+ Returns:
150
+ fig (Figure): The Matplotlib figure object with the correlation heatmap plot.
151
+ """
152
+ # Define custom colormap
153
+ custom_cmap = ListedColormap(['#08083B', "#11B6BD"])
154
+
155
+ # Select the relevant columns for correlation calculation
156
+ corr_df = df[selected_features]
157
+ corr_df = pd.concat([corr_df, df[target]], axis=1)
158
+
159
+ # Create a matplotlib figure and axis
160
+ fig, ax = plt.subplots(figsize=(16, 12))
161
+
162
+ # Generate the heatmap with correlation coefficients
163
+ sns.heatmap(corr_df.corr(), annot=True, cmap='Blues', fmt=".2f", linewidths=0.5, mask=np.triu(corr_df.corr()))
164
+
165
+ # Customize the plot
166
+ plt.xticks(rotation=45)
167
+ plt.yticks(rotation=0)
168
+
169
+ return fig
170
+
171
+
172
+ def summary(data, selected_feature, spends, Target=None):
173
+ """
174
+ Create a summary table of selected features and optionally a target column.
175
+
176
+ Parameters:
177
+ data (DataFrame): The data containing the columns to be summarized.
178
+ selected_feature (list): List of column names to be included in the summary.
179
+ spends (str): The column name for the spends data.
180
+ Target (str, optional): The target column name for additional summary calculations. Default is None.
181
+
182
+ Returns:
183
+ sum_df (DataFrame): The summary DataFrame with formatted values.
184
+ """
185
+ if Target:
186
+ # Summarize data for the target column
187
+ sum_df = data[selected_feature]
188
+ sum_df['Year'] = data['date'].dt.year
189
+ sum_df = sum_df.groupby('Year')[selected_feature].sum().reset_index()
190
+
191
+ # Calculate total sum and append to the DataFrame
192
+ total_sum = sum_df.sum(numeric_only=True)
193
+ total_sum['Year'] = 'Total'
194
+ sum_df = pd.concat([sum_df, total_sum.to_frame().T], axis=0, ignore_index=True).copy()
195
+
196
+ # Set 'Year' as index and format numbers
197
+ sum_df.set_index(['Year'], inplace=True)
198
+ sum_df = sum_df.applymap(format_numbers)
199
+
200
+ # Format spends columns as currency
201
+ spends_col = [col for col in sum_df.columns if any(keyword in col for keyword in ['spends', 'cost'])]
202
+ for col in spends_col:
203
+ sum_df[col] = sum_df[col].map(lambda x: f'${x}')
204
+
205
+ return sum_df
206
+ else:
207
+ # Include spends in the selected features
208
+ selected_feature.append(spends)
209
+
210
+ # Ensure unique features
211
+ selected_feature = list(set(selected_feature))
212
+
213
+ if len(selected_feature) > 1:
214
+ imp_clicks = selected_feature[1]
215
+ spends_col = selected_feature[0]
216
+
217
+ # Summarize data for the selected features
218
+ sum_df = data[selected_feature]
219
+ sum_df['Year'] = data['date'].dt.year
220
+ sum_df = sum_df.groupby('Year')[selected_feature].agg('sum')
221
+
222
+ # Calculate CPM/CPC
223
+ sum_df['CPM/CPC'] = (sum_df[spends_col] / sum_df[imp_clicks]) * 1000
224
+
225
+ # Calculate grand total and append to the DataFrame
226
+ sum_df.loc['Grand Total'] = sum_df.sum()
227
+
228
+ # Format numbers and replace NaNs
229
+ sum_df = sum_df.applymap(format_numbers)
230
+ sum_df.fillna('-', inplace=True)
231
+ sum_df = sum_df.replace({"0.0": '-', 'nan': '-'})
232
+
233
+ # Format spends columns as currency
234
+ sum_df[spends_col] = sum_df[spends_col].map(lambda x: f'${x}')
235
+
236
+ return sum_df
237
+ else:
238
+ # Summarize data for a single selected feature
239
+ sum_df = data[selected_feature]
240
+ sum_df['Year'] = data['date'].dt.year
241
+ sum_df = sum_df.groupby('Year')[selected_feature].agg('sum')
242
+
243
+ # Calculate grand total and append to the DataFrame
244
+ sum_df.loc['Grand Total'] = sum_df.sum()
245
+
246
+ # Format numbers and replace NaNs
247
+ sum_df = sum_df.applymap(format_numbers)
248
+ sum_df.fillna('-', inplace=True)
249
+ sum_df = sum_df.replace({"0.0": '-', 'nan': '-'})
250
+
251
+ # Format spends columns as currency
252
+ spends_col = [col for col in sum_df.columns if any(keyword in col for keyword in ['spends', 'cost'])]
253
+ for col in spends_col:
254
+ sum_df[col] = sum_df[col].map(lambda x: f'${x}')
255
+
256
+ return sum_df
257
+
258
+
259
+
260
+ def sanitize_key(key, prefix=""):
261
+ # Use regular expressions to remove non-alphanumeric characters and spaces
262
+ key = re.sub(r'[^a-zA-Z0-9]', '', key)
263
+ return f"{prefix}{key}"
264
+
265
+
266
+
267
+
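A quick, self-contained check of the two small pure-Python helpers above, assuming the module is importable as data_analysis and the app's plotting dependencies are installed (values are chosen to avoid ambiguous rounding):

from data_analysis import format_numbers, sanitize_key

print(format_numbers(2_500_000))  # "2.5M"
print(format_numbers(1_500))      # "1.5K"
print(format_numbers(12.34))      # "12.3"
print(sanitize_key("Paid Search - Brand", prefix="num_"))  # "num_PaidSearchBrand"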
data_prep.py ADDED
@@ -0,0 +1,185 @@
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import plotly.express as px
4
+ import plotly.graph_objects as go
5
+ import statsmodels.api as sm
6
+ from sklearn.metrics import mean_absolute_error, r2_score,mean_absolute_percentage_error
7
+ from sklearn.preprocessing import MinMaxScaler
8
+ import matplotlib.pyplot as plt
9
+ from statsmodels.stats.outliers_influence import variance_inflation_factor
10
+ from plotly.subplots import make_subplots
11
+
12
+ st.set_option('deprecation.showPyplotGlobalUse', False)
13
+ from datetime import datetime
14
+ import seaborn as sns
15
+
16
+
17
+ def plot_actual_vs_predicted(date, y, predicted_values, model, target_column=None, flag=None, repeat_all_years=False, is_panel=False):
18
+ """
19
+ Plots actual vs predicted values with optional flags and aggregation for panel data.
20
+
21
+ Parameters:
22
+ date (pd.Series): Series of dates for x-axis.
23
+ y (pd.Series): Actual values.
24
+ predicted_values (pd.Series): Predicted values from the model.
25
+ model (object): Trained model object.
26
+ target_column (str, optional): Name of the target column.
27
+ flag (tuple, optional): Start and end dates for flagging periods.
28
+ repeat_all_years (bool, optional): Whether to repeat flags for all years.
29
+ is_panel (bool, optional): Whether the data is panel data requiring aggregation.
30
+
31
+ Returns:
32
+ metrics_table (pd.DataFrame): DataFrame containing MAPE, R-squared, and Adjusted R-squared.
33
+ line_values (list): List of flag values for plotting.
34
+ fig (go.Figure): Plotly figure object.
35
+ """
36
+ if flag is not None:
37
+ fig = make_subplots(specs=[[{"secondary_y": True}]])
38
+ else:
39
+ fig = go.Figure()
40
+
41
+ if is_panel:
42
+ df = pd.DataFrame()
43
+ df['date'] = date
44
+ df['Actual'] = y
45
+ df['Predicted'] = predicted_values
46
+ df_agg = df.groupby('date').agg({'Actual': 'sum', 'Predicted': 'sum'}).reset_index()
47
+ df_agg.columns = ['date', 'Actual', 'Predicted']
48
+ assert len(df_agg) == pd.Series(date).nunique()
49
+
50
+ fig.add_trace(go.Scatter(x=df_agg['date'], y=df_agg['Actual'], mode='lines', name='Actual', line=dict(color='#08083B')))
51
+ fig.add_trace(go.Scatter(x=df_agg['date'], y=df_agg['Predicted'], mode='lines', name='Predicted', line=dict(color='#11B6BD')))
52
+ else:
53
+ fig.add_trace(go.Scatter(x=date, y=y, mode='lines', name='Actual', line=dict(color='#08083B')))
54
+ fig.add_trace(go.Scatter(x=date, y=predicted_values, mode='lines', name='Predicted', line=dict(color='#11B6BD')))
55
+
56
+ line_values = []
57
+ if flag:
58
+ min_date, max_date = flag[0], flag[1]
59
+ min_week = datetime.strptime(str(min_date), "%Y-%m-%d").strftime("%U")
60
+ max_week = datetime.strptime(str(max_date), "%Y-%m-%d").strftime("%U")
61
+
62
+ if repeat_all_years:
63
+ line_values = list(pd.Series(date).map(lambda x: 1 if (pd.Timestamp(x).week >= int(min_week)) & (pd.Timestamp(x).week <= int(max_week)) else 0))
64
+ assert len(line_values) == len(date)
65
+ fig.add_trace(go.Scatter(x=date, y=line_values, mode='lines', name='Flag', line=dict(color='#FF5733')), secondary_y=True)
66
+ else:
67
+ line_values = list(pd.Series(date).map(lambda x: 1 if (pd.Timestamp(x) >= pd.Timestamp(min_date)) and (pd.Timestamp(x) <= pd.Timestamp(max_date)) else 0))
68
+ fig.add_trace(go.Scatter(x=date, y=line_values, mode='lines', name='Flag', line=dict(color='#FF5733')), secondary_y=True)
69
+
70
+ mape = mean_absolute_percentage_error(y, predicted_values)
71
+ r2 = r2_score(y, predicted_values)
72
+ adjr2 = 1 - (1 - r2) * (len(y) - 1) / (len(y) - len(model.params) - 1)
73
+
74
+ metrics_table = pd.DataFrame({
75
+ 'Metric': ['MAPE', 'R-squared', 'AdjR-squared'],
76
+ 'Value': [mape, r2, adjr2]
77
+ })
78
+
79
+ fig.update_layout(
80
+ xaxis=dict(title='Date'),
81
+ yaxis=dict(title=target_column),
82
+ xaxis_tickangle=-30
83
+ )
84
+ fig.add_annotation(
85
+ text=f"MAPE: {mape * 100:0.1f}%, Adj. R-squared: {adjr2 * 100:.1f}%",
86
+ xref="paper",
87
+ yref="paper",
88
+ x=0.95,
89
+ y=1.2,
90
+ showarrow=False,
91
+ )
92
+
93
+ return metrics_table, line_values, fig
94
+
95
+
96
+ def plot_residual_predicted(actual, predicted, df):
97
+ """
98
+ Plots standardized residuals against predicted values.
99
+
100
+ Parameters:
101
+ actual (pd.Series): Actual values.
102
+ predicted (pd.Series): Predicted values.
103
+ df (pd.DataFrame): DataFrame containing the data.
104
+
105
+ Returns:
106
+ fig (go.Figure): Plotly figure object.
107
+ """
108
+ df_ = df.copy()
109
+ df_['Residuals'] = actual - pd.Series(predicted)
110
+ df_['StdResidual'] = (df_['Residuals'] - df_['Residuals'].mean()) / df_['Residuals'].std()
111
+
112
+ fig = px.scatter(df_, x=predicted, y='StdResidual', opacity=0.5, color_discrete_sequence=["#11B6BD"])
113
+
114
+ fig.add_hline(y=0, line_dash="dash", line_color="darkorange")
115
+ fig.add_hline(y=2, line_color="red")
116
+ fig.add_hline(y=-2, line_color="red")
117
+
118
+ fig.update_xaxes(title='Predicted')
119
+ fig.update_yaxes(title='Standardized Residuals (Actual - Predicted)')
120
+
121
+ fig.update_layout(title='2.3.1 Residuals over Predicted Values', autosize=False, width=600, height=400)
122
+
123
+ return fig
124
+
125
+
126
+ def residual_distribution(actual, predicted):
127
+ """
128
+ Plots the distribution of residuals.
129
+
130
+ Parameters:
131
+ actual (pd.Series): Actual values.
132
+ predicted (pd.Series): Predicted values.
133
+
134
+ Returns:
135
+ plt (matplotlib.pyplot): Matplotlib plot object.
136
+ """
137
+ Residuals = actual - pd.Series(predicted)
138
+
139
+ sns.set(style="whitegrid")
140
+ plt.figure(figsize=(6, 4))
141
+ sns.histplot(Residuals, kde=True, color="#11B6BD")
142
+
143
+ plt.title('2.3.3 Distribution of Residuals')
144
+ plt.xlabel('Residuals')
145
+ plt.ylabel('Probability Density')
146
+
147
+ return plt
148
+
149
+
150
+ def qqplot(actual, predicted):
151
+ """
152
+ Creates a QQ plot of the residuals.
153
+
154
+ Parameters:
155
+ actual (pd.Series): Actual values.
156
+ predicted (pd.Series): Predicted values.
157
+
158
+ Returns:
159
+ fig (go.Figure): Plotly figure object.
160
+ """
161
+ Residuals = actual - pd.Series(predicted)
162
+ Residuals = pd.Series(Residuals)
163
+ Resud_std = (Residuals - Residuals.mean()) / Residuals.std()
164
+
165
+ fig = go.Figure()
166
+ fig.add_trace(go.Scatter(x=sm.ProbPlot(Resud_std).theoretical_quantiles,
167
+ y=sm.ProbPlot(Resud_std).sample_quantiles,
168
+ mode='markers',
169
+ marker=dict(size=5, color="#11B6BD"),
170
+ name='QQ Plot'))
171
+
172
+ diagonal_line = go.Scatter(
173
+ x=[-2, 2],
174
+ y=[-2, 2],
175
+ mode='lines',
176
+ line=dict(color='red'),
177
+ name=' '
178
+ )
179
+ fig.add_trace(diagonal_line)
180
+
181
+ fig.update_layout(title='2.3.2 QQ Plot of Residuals', title_x=0.5, autosize=False, width=600, height=400,
182
+ xaxis_title='Theoretical Quantiles', yaxis_title='Sample Quantiles')
183
+
184
+ return fig
185
+
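As a usage illustration for plot_actual_vs_predicted above, a small sketch with synthetic weekly data and a statsmodels OLS fit; the data and column names are made up and this is not part of the tool (it assumes the module's Streamlit/Plotly dependencies are installed).

import numpy as np
import pandas as pd
import statsmodels.api as sm
from data_prep import plot_actual_vs_predicted

rng = np.random.default_rng(0)
dates = pd.date_range("2023-01-01", periods=104, freq="W")
X = sm.add_constant(pd.DataFrame({"spend": rng.uniform(10, 100, len(dates))}))
y = 2.5 * X["spend"] + rng.normal(0, 5, len(dates)) + 20

model = sm.OLS(y, X).fit()
metrics_table, line_values, fig = plot_actual_vs_predicted(
    dates, y, model.predict(X), model, target_column="sales"
)
print(metrics_table)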
db/imp_db.db ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:36d945da7748723c6078c2d2bee4aaae5ca0838dcee0e2032a836341dcf7794e
3
+ size 15618048
db_creation.py ADDED
@@ -0,0 +1,114 @@
1
+ # import sqlite3
2
+
3
+ # # Connect to SQLite database (or create it if it doesn't exist)
4
+ # conn = sqlite3.connect('imp_db.db')
5
+
6
+ # # Enable foreign key support
7
+ # conn.execute('PRAGMA foreign_keys = ON;')
8
+
9
+ # # Create a cursor object
10
+ # c = conn.cursor()
11
+
12
+ # # SQL queries to create tables
13
+ # create_mmo_users_table = """
14
+ # CREATE TABLE IF NOT EXISTS mmo_users (
15
+ # emp_id TEXT PRIMARY KEY ,
16
+ # emp_nam TEXT NOT NULL,
17
+ # emp_typ TEXT NOT NULL,
18
+ # pswrd_key TEXT NOT NULL,
19
+ # pswrd_flag INTEGER NOT NULL DEFAULT 0,
20
+ # crte_dt_tm TEXT DEFAULT (datetime('now')),
21
+ # crte_by_uid TEXT NOT NULL,
22
+ # updt_dt_tm TEXT DEFAULT (datetime('now')),
23
+ # updt_by_uid TEXT
24
+ # );
25
+ # """
26
+
27
+ # create_mmo_projects_table = """
28
+ # CREATE TABLE IF NOT EXISTS mmo_projects (
29
+ # prj_id INTEGER PRIMARY KEY AUTOINCREMENT,
30
+ # prj_ownr_id TEXT NOT NULL,
31
+ # prj_nam TEXT NOT NULL,
32
+ # alwd_emp_id TEXT,
33
+ # meta_data_agrgt TEXT,
34
+ # crte_dt_tm TEXT DEFAULT (datetime('now')),
35
+ # crte_by_uid TEXT NOT NULL,
36
+ # updt_dt_tm TEXT DEFAULT (datetime('now')),
37
+ # updt_by_uid TEXT,
38
+ # FOREIGN KEY (prj_ownr_id) REFERENCES mmo_users(emp_id)
39
+ # );
40
+ # """
41
+
42
+ # create_mmo_project_meta_data_table = """
43
+ # CREATE TABLE IF NOT EXISTS mmo_project_meta_data (
44
+ # prj_guid INTEGER PRIMARY KEY AUTOINCREMENT,
45
+ # prj_id INTEGER NOT NULL,
46
+ # page_nam TEXT NOT NULL,
47
+ # file_nam TEXT NOT NULL,
48
+ # pkl_obj BLOB,
49
+ # dshbrd_ts TEXT,
50
+ # crte_dt_tm TEXT DEFAULT (datetime('now')),
51
+ # crte_by_uid TEXT NOT NULL,
52
+ # updt_dt_tm TEXT DEFAULT (datetime('now')),
53
+ # updt_by_uid TEXT,
54
+ # FOREIGN KEY (prj_id) REFERENCES mmo_projects(prj_id)
55
+ # );
56
+ # """
57
+
58
+ # # Execute the queries to create tables
59
+ # c.execute(create_mmo_users_table)
60
+ # c.execute(create_mmo_projects_table)
61
+ # c.execute(create_mmo_project_meta_data_table)
62
+
63
+ # # Commit changes and close the connection
64
+ # conn.commit()
65
+ # conn.close()
66
+ import sqlite3
67
+
68
+ def add_user_to_db(db_path, user_id, name, user_type, pswrd_key):
69
+ """
70
+ Adds a user to the mmo_users table in the SQLite database.
71
+
72
+ Parameters:
73
+ - db_path (str): The path to the SQLite database file.
74
+ - user_id (str): The ID of the user.
75
+ - name (str): The name of the user.
76
+ - user_type (str): The type of the user.
77
+ - pswrd_key (str): The password key for the user.
78
+ """
79
+
80
+ try:
81
+ # Connect to the SQLite database
82
+ conn = sqlite3.connect(db_path)
83
+ cursor = conn.cursor()
84
+
85
+ # SQL query to insert a new user
86
+ insert_query = """
87
+ INSERT INTO mmo_users (emp_id, emp_nam, emp_typ, pswrd_key,crte_by_uid)
88
+ VALUES (?, ?, ?, ?,?)
89
+ """
90
+
91
+ # Execute the query with parameters
92
+ cursor.execute(insert_query, (user_id, name, user_type, pswrd_key, user_id))
93
+
94
+ # Commit the transaction
95
+ conn.commit()
96
+
97
+ print(f"User {name} added successfully.")
98
+
99
+ except sqlite3.Error as e:
100
+ print(f"Error adding user to the database: {e}")
101
+
102
+ finally:
103
+ # Close the database connection
104
+ conn.close()
105
+
106
+ # Define the database path and user details
107
+ db_path = r'db\imp_db.db' # Update this path to your actual database path
108
+ user_id = 'e162284'
109
+ name = 'admin'
110
+ user_type = 'admin'
111
+ pswrd_key = '$2b$12$wP7R0usvKWtr4X06qwGWvOFQCkzOZAzSVRAoDv/68x6GS4rHK5mDm'
112
+
113
+ # Add the user to the database
114
+ add_user_to_db(db_path, user_id, name, user_type, pswrd_key)
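A small follow-up sketch (not part of the script) that reads the row back to confirm the insert, using only the columns defined in the commented-out schema above:

import sqlite3

conn = sqlite3.connect(r"db\imp_db.db")  # same path as above; adjust the separator on non-Windows systems
try:
    cur = conn.cursor()
    cur.execute(
        "SELECT emp_id, emp_nam, emp_typ, crte_dt_tm FROM mmo_users WHERE emp_id = ?",
        ("e162284",),
    )
    print(cur.fetchone())
finally:
    conn.close()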
log_application.py ADDED
@@ -0,0 +1,157 @@
1
+ """
2
+ Contains logging functions to initialize and configure loggers for the application.
3
+ Also provides methods to log messages with different log levels and additional context information such as employee ID, project name, and page name
4
+ """
5
+
6
+ import os
7
+ import streamlit as st
8
+ import logging, logging.handlers
9
+ from datetime import datetime, timezone
10
+
11
+
12
+ def delete_old_log_files():
13
+ """
14
+ Deletes all log files in the 'logs' directory that are older than the current month.
15
+
16
+ Returns:
17
+ None
18
+ """
19
+ # Set the logs directory to 'logs' within the current working directory
20
+ logs_directory = os.path.join(os.getcwd(), "logs")
21
+
22
+ # Ensure the logs directory exists
23
+ if not os.path.exists(logs_directory):
24
+ print(f"Directory {logs_directory} does not exist.")
25
+ return
26
+
27
+ # Get the current year and month in UTC
28
+ current_utc_time = datetime.now(timezone.utc)
29
+ current_year_month = current_utc_time.strftime("%Y%m")
30
+
31
+ # Iterate through all files in the logs directory
32
+ for filename in os.listdir(logs_directory):
33
+ # We assume the log files are in the format "eid_projectname_YYYYMMDD.log"
34
+ if filename.endswith(".log"):
35
+ # Extract the date portion from the filename
36
+ try:
37
+ file_date = filename.split("_")[-1].replace(".log", "")
38
+ # Compare the file date with the current date
39
+ if file_date < current_year_month:
40
+ # Construct the full file path
41
+ file_path = os.path.join(logs_directory, filename)
42
+ # Delete the file
43
+ os.remove(file_path)
44
+
45
+ except IndexError:
46
+ # If the filename doesn't match the expected pattern, skip it
47
+ pass
48
+
49
+
50
+ def create_log_file_name(emp_id, project_name):
51
+ """
52
+ Generates a log file name in the format eid_projectname_YYYYMMDD.log based on UTC time.
53
+
54
+ Returns:
55
+ str: The generated log file name.
56
+ """
57
+ # Get the current UTC time
58
+ utc_now = datetime.now(timezone.utc)
59
+ # Format the date as YYYYMMDD
60
+ date_str = utc_now.strftime("%Y%m%d")
61
+ # Create the file name with the format eid_projectname_YYYYMMDD.log
62
+ log_file_name = f"{emp_id}_{project_name}_{date_str}.log"
63
+ return log_file_name
64
+
65
+
66
+ def get_logger(log_file):
67
+ """
68
+ Initializes and configures a logger. If the log file already exists, it appends to it.
69
+ Returns the same logger instance if called multiple times.
70
+
71
+ Args:
72
+ log_file (str): The path to the log file where logs will be written.
73
+
74
+ Returns:
75
+ logging.Logger: Configured logger instance.
76
+ """
77
+
78
+ # Define the logger name
79
+ logger_name = os.path.basename(log_file)
80
+ logger = logging.getLogger(logger_name)
81
+
82
+ # If the logger already has handlers, return it to prevent duplicate handlers
83
+ if logger.hasHandlers():
84
+ return logger
85
+
86
+ # Set the logging level
87
+ logging_level = logging.INFO
88
+ logger.setLevel(logging_level)
89
+
90
+ # Ensure the logs directory exists and is writable
91
+ logs_dir = os.path.dirname(log_file)
92
+ if not os.path.exists(logs_dir):
93
+ os.makedirs(logs_dir, exist_ok=True)
94
+
95
+ # Change the directory permissions to 0777 (0o777 in Python 3)
96
+ os.chmod(logs_dir, 0o777)
97
+
98
+ # Create a file handler to append to the log file
99
+ file_handler = logging.FileHandler(log_file)
100
+ file_handler.setFormatter(
101
+ logging.Formatter("%(asctime)s %(levelname)s %(message)s", "%Y-%m-%d %H:%M:%S")
102
+ )
103
+ logger.addHandler(file_handler)
104
+
105
+ # Create a stream handler to print to console
106
+ stream_handler = logging.StreamHandler()
107
+ stream_handler.setFormatter(
108
+ logging.Formatter("%(asctime)s %(levelname)s %(message)s", "%Y-%m-%d %H:%M:%S")
109
+ )
110
+ logger.addHandler(stream_handler)
111
+
112
+ return logger
113
+
114
+
115
+ def log_message(log_type, message, page_name):
116
+ """
117
+ Logs a message with the specified log type and additional context information.
118
+
119
+ Args:
120
+ log_type (str): The type of log ('info', 'error', 'warning', 'debug').
121
+ message (str): The message to log.
122
+ page_name (str): The name of the page associated with the message.
123
+
124
+ Returns:
125
+ None
126
+ """
127
+
128
+ # Retrieve the employee ID and project name from session state
129
+ emp_id = st.session_state["emp_id"]
130
+ project_name = st.session_state["project_name"]
131
+
132
+ # Create log file name using a function that generates a name based on the current date and EID
133
+ log_file_name = create_log_file_name(emp_id, project_name)
134
+
135
+ # Construct the full path to the log file within the "logs" directory under the current working directory
136
+ file_path = os.path.join(os.getcwd(), "logs", log_file_name)
137
+
138
+ # Get (or create) the logger instance for this log file
139
+ logger = get_logger(file_path)
140
+
141
+ # Construct the log message with all required context information
142
+ log_message = (
143
+ f"USER_EID: {emp_id} ; "
144
+ f"PROJECT_NAME: {project_name} ; "
145
+ f"PAGE_NAME: {page_name} ; "
146
+ f"MESSAGE: {message}"
147
+ )
148
+
149
+ # Log the message with the appropriate log level based on the log_type argument
150
+ if log_type == "info":
151
+ logger.info(log_message)
152
+ elif log_type == "error":
153
+ logger.error(log_message)
154
+ elif log_type == "warning":
155
+ logger.warning(log_message)
156
+ else:
157
+ logger.debug(log_message)
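For illustration, log_message() pulls the employee id and project name from Streamlit session state, so a page must set both before the first call. A minimal sketch, meant to run inside a Streamlit page; the values below are placeholders taken from the sample log files further down.

import streamlit as st
from log_application import log_message

# Normally set during login / project selection; shown here only for illustration.
st.session_state["emp_id"] = "e111111"
st.session_state["project_name"] = "na_demo_user_demo_project"

log_message("info", "Data import completed.", "Data Import")
# Appends to logs/e111111_na_demo_user_demo_project_<YYYYMMDD>.log and echoes to the console.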
logo.png ADDED
logs/e111111_None_20250408.log ADDED
@@ -0,0 +1 @@
1
+ 2025-04-08 13:32:28 INFO USER_EID: e111111 ; PROJECT_NAME: None ; PAGE_NAME: Home ; MESSAGE: Project number 1 selected by employee e111111.
logs/e111111_na_demo_user_demo_project_20250407.log ADDED
@@ -0,0 +1 @@
1
+ 2025-04-07 18:10:49 INFO USER_EID: e111111 ; PROJECT_NAME: na_demo_user_demo_project ; PAGE_NAME: Home ; MESSAGE: Project number 642 selected by employee e111111.
logs/e111111_na_demo_user_demo_project_20250408.log ADDED
@@ -0,0 +1,2 @@
1
+ 2025-04-08 13:32:31 INFO USER_EID: e111111 ; PROJECT_NAME: na_demo_user_demo_project ; PAGE_NAME: Home ; MESSAGE: Project number 1 selected by employee e111111.
2
+ 2025-04-08 13:32:53 INFO USER_EID: e111111 ; PROJECT_NAME: na_demo_user_demo_project ; PAGE_NAME: Home ; MESSAGE: Project number 1 selected by employee e111111.
mmm_tool_document.docx ADDED
Binary file (36.1 kB). View file
 
pages/10_Saved_Scenarios.py ADDED
@@ -0,0 +1,344 @@
1
+ # Importing necessary libraries
2
+ import streamlit as st
3
+
4
+ st.set_page_config(
5
+ page_title="Saved Scenarios",
6
+ page_icon="⚖️",
7
+ layout="wide",
8
+ initial_sidebar_state="collapsed",
9
+ )
10
+
11
+ import io
12
+ import sys
13
+ import json
14
+ import pickle
15
+ import zipfile
16
+ import traceback
17
+ import numpy as np
18
+ import pandas as pd
19
+ from scenario import numerize
20
+ from openpyxl import Workbook
21
+ from post_gres_cred import db_cred
22
+ from log_application import log_message
23
+ from utilities import (
24
+ project_selection,
25
+ update_db,
26
+ set_header,
27
+ load_local_css,
28
+ name_formating,
29
+ )
30
+
31
+ schema = db_cred["schema"]
32
+ load_local_css("styles.css")
33
+ set_header()
34
+
35
+ # Initialize project name session state
36
+ if "project_name" not in st.session_state:
37
+ st.session_state["project_name"] = None
38
+
39
+ # Fetch project dictionary
40
+ if "project_dct" not in st.session_state:
41
+ project_selection()
42
+ st.stop()
43
+
44
+ # Display Username and Project Name
45
+ if "username" in st.session_state and st.session_state["username"] is not None:
46
+
47
+ cols1 = st.columns([2, 1])
48
+
49
+ with cols1[0]:
50
+ st.markdown(f"**Welcome {st.session_state['username']}**")
51
+ with cols1[1]:
52
+ st.markdown(f"**Current Project: {st.session_state['project_name']}**")
53
+
54
+ # Function to get saved scenarios dictionary
55
+ def get_saved_scenarios_dict():
56
+ return st.session_state["project_dct"]["saved_scenarios"][
57
+ "saved_scenarios_dict"
58
+ ]
59
+
60
+
61
+ # Function to format values based on their size
62
+ def format_value(value):
63
+ return round(value, 4) if value < 1 else round(value, 1)
64
+
65
+
66
+ # Function to recursively convert non-serializable types to serializable ones
67
+ def convert_to_serializable(obj):
68
+ if isinstance(obj, np.ndarray):
69
+ return obj.tolist()
70
+ elif isinstance(obj, dict):
71
+ return {key: convert_to_serializable(value) for key, value in obj.items()}
72
+ elif isinstance(obj, list):
73
+ return [convert_to_serializable(element) for element in obj]
74
+ elif isinstance(obj, (int, float, str, bool, type(None))):
75
+ return obj
76
+ else:
77
+ # Fallback: convert the object to a string
78
+ return str(obj)
79
+
80
+
81
+ # Function to generate zip file of current scenario
82
+ @st.cache_data(show_spinner=False)
83
+ def download_as_zip(
84
+ df,
85
+ scenario_data,
86
+ excel_name="optimization_results.xlsx",
87
+ json_name="scenario_params.json",
88
+ ):
89
+ # Create an in-memory bytes buffer for the ZIP file
90
+ buffer = io.BytesIO()
91
+
92
+ # Create a ZipFile object in memory
93
+ with zipfile.ZipFile(buffer, "w") as zip_file:
94
+ # Save the DataFrame to an Excel file in the zip using openpyxl
95
+ excel_buffer = io.BytesIO()
96
+ workbook = Workbook()
97
+ sheet = workbook.active
98
+ sheet.title = "Results"
99
+
100
+ # Write DataFrame headers
101
+ for col_num, column_title in enumerate(df.columns, 1):
102
+ sheet.cell(row=1, column=col_num, value=column_title)
103
+
104
+ # Write DataFrame data
105
+ for row_num, row_data in enumerate(df.values, 2):
106
+ for col_num, cell_value in enumerate(row_data, 1):
107
+ sheet.cell(row=row_num, column=col_num, value=cell_value)
108
+
109
+ # Save the workbook to the in-memory buffer
110
+ workbook.save(excel_buffer)
111
+ excel_buffer.seek(0) # Rewind the buffer to the beginning
112
+ zip_file.writestr(excel_name, excel_buffer.getvalue())
113
+
114
+ # Save the dictionary to a JSON file in the zip
115
+ json_buffer = io.BytesIO()
116
+ json_buffer.write(
117
+ json.dumps(convert_to_serializable(scenario_data), indent=4).encode("utf-8")
118
+ )
119
+ json_buffer.seek(0) # Rewind the buffer to the beginning
120
+ zip_file.writestr(json_name, json_buffer.getvalue())
121
+
122
+ buffer.seek(0) # Rewind the buffer to the beginning
123
+
124
+ return buffer
125
+
126
+
127
+ # Function to delete the selected scenario from the saved scenarios dictionary
128
+ def delete_selected_scenarios(selected_scenario):
129
+ if (
130
+ selected_scenario
131
+ in st.session_state["project_dct"]["saved_scenarios"]["saved_scenarios_dict"]
132
+ ):
133
+ del st.session_state["project_dct"]["saved_scenarios"]["saved_scenarios_dict"][
134
+ selected_scenario
135
+ ]
136
+
137
+
138
+ try:
139
+ # Page Title
140
+ st.title("Saved Scenarios")
141
+
142
+ # Placeholder to display scenarios name
143
+ scenarios_name_placeholder = st.empty()
144
+
145
+ # Get saved scenarios dictionary and scenario name list
146
+ saved_scenarios_dict = get_saved_scenarios_dict()
147
+ scenarios_list = list(saved_scenarios_dict.keys())
148
+
149
+ # Check if the list of saved scenarios is empty
150
+ if len(scenarios_list) == 0:
151
+ # Display a warning message if no scenarios are saved
152
+ st.warning("No scenarios saved. Please save a scenario to load.", icon="⚠️")
153
+
154
+ # Log message
155
+ log_message(
156
+ "warning",
157
+ "No scenarios saved. Please save a scenario to load.",
158
+ "Saved Scenarios",
159
+ )
160
+
161
+ st.stop()
162
+
163
+ # Columns for scenario selection and save progress
164
+ select_scenario_col, save_progress_col = st.columns(2)
165
+ save_message_display_placeholder = st.container()
166
+
167
+ # Display a dropdown of saved scenarios
168
+ selected_scenario = select_scenario_col.selectbox(
169
+ "Pick a Scenario", sorted(scenarios_list), key="selected_scenario"
170
+ )
171
+
172
+ # Save page progress
173
+ with save_progress_col:
174
+ st.write("###")
175
+ with save_message_display_placeholder, st.spinner("Saving Progress ..."):
176
+ if save_progress_col.button("Save Progress", use_container_width=True):
177
+ # Update DB
178
+ update_db(
179
+ prj_id=st.session_state["project_number"],
180
+ page_nam="Saved Scenarios",
181
+ file_nam="project_dct",
182
+ pkl_obj=pickle.dumps(st.session_state["project_dct"]),
183
+ schema=schema,
184
+ )
185
+
186
+ # Display success message
187
+ st.success("Progress saved successfully!", icon="💾")
188
+ st.toast("Progress saved successfully!", icon="💾")
189
+
190
+ # Log message
191
+ log_message("info", "Progress saved successfully!", "Saved Scenarios")
192
+
193
+ selected_scenario_data = saved_scenarios_dict[selected_scenario]
194
+
195
+ # Scenarios Name
196
+ metrics_name = selected_scenario_data["metrics_selected"]
197
+ panel_name = selected_scenario_data["panel_selected"]
198
+ optimization_name = selected_scenario_data["optimization"]
199
+ multiplier = selected_scenario_data["multiplier"]
200
+ timeframe = selected_scenario_data["timeframe"]
201
+
202
+ # Display the scenario details with bold "Metric," "Panel," and "Optimization"
203
+ scenarios_name_placeholder.markdown(
204
+ f"**Metric**: {name_formating(metrics_name)}; **Panel**: {name_formating(panel_name)}; **Fix**: {name_formating(optimization_name)}; **Timeframe**: {name_formating(timeframe)}"
205
+ )
206
+
207
+ # Create columns for download and delete buttons
208
+ download_col, delete_col = st.columns(2)
209
+ save_message_display_placeholder = st.container()
210
+
211
+ # Channel List
212
+ channels_list = list(selected_scenario_data["channels"].keys())
213
+
214
+ # List to hold data for all channels
215
+ channels_data = []
216
+
217
+ # Iterate through each channel and gather required data
218
+ for channel in channels_list:
219
+ channel_conversion_rate = selected_scenario_data["channels"][channel][
220
+ "conversion_rate"
221
+ ]
222
+ channel_actual_spends = (
223
+ selected_scenario_data["channels"][channel]["actual_total_spends"]
224
+ * channel_conversion_rate
225
+ )
226
+ channel_optimized_spends = (
227
+ selected_scenario_data["channels"][channel]["modified_total_spends"]
228
+ * channel_conversion_rate
229
+ )
230
+
231
+ channel_actual_metrics = selected_scenario_data["channels"][channel][
232
+ "actual_total_sales"
233
+ ]
234
+ channel_optimized_metrics = selected_scenario_data["channels"][channel][
235
+ "modified_total_sales"
236
+ ]
237
+
238
+ channel_roi_mroi_data = selected_scenario_data["channel_roi_mroi"][channel]
239
+
240
+ # Extract the ROI and MROI data
241
+ actual_roi = channel_roi_mroi_data["actual_roi"]
242
+ optimized_roi = channel_roi_mroi_data["optimized_roi"]
243
+ actual_mroi = channel_roi_mroi_data["actual_mroi"]
244
+ optimized_mroi = channel_roi_mroi_data["optimized_mroi"]
245
+
246
+ # Calculate spends per metric
247
+ spends_per_metrics_actual = channel_actual_spends / channel_actual_metrics
248
+ spends_per_metrics_optimized = (
249
+ channel_optimized_spends / channel_optimized_metrics
250
+ )
251
+
252
+ # Append the collected data as a dictionary to the list
253
+ channels_data.append(
254
+ {
255
+ "Channel Name": channel,
256
+ "Spends Actual": numerize(channel_actual_spends / multiplier),
257
+ "Spends Optimized": numerize(channel_optimized_spends / multiplier),
258
+ f"{name_formating(metrics_name)} Actual": numerize(
259
+ channel_actual_metrics / multiplier
260
+ ),
261
+ f"{name_formating(metrics_name)} Optimized": numerize(
262
+ channel_optimized_metrics / multiplier
263
+ ),
264
+ "ROI Actual": format_value(actual_roi),
265
+ "ROI Optimized": format_value(optimized_roi),
266
+ "MROI Actual": format_value(actual_mroi),
267
+ "MROI Optimized": format_value(optimized_mroi),
268
+ f"Spends per {name_formating(metrics_name)} Actual": round(
269
+ spends_per_metrics_actual, 2
270
+ ),
271
+ f"Spends per {name_formating(metrics_name)} Optimized": round(
272
+ spends_per_metrics_optimized, 2
273
+ ),
274
+ }
275
+ )
276
+
277
+ # Create a DataFrame from the collected data
278
+ df = pd.DataFrame(channels_data)
279
+
280
+ # Display the DataFrame
281
+ st.dataframe(df, hide_index=True)
282
+
283
+ # Generate downloadable data for the selected scenario
284
+ buffer = download_as_zip(
285
+ df,
286
+ selected_scenario_data,
287
+ excel_name="optimization_results.xlsx",
288
+ json_name="scenario_params.json",
289
+ )
290
+
291
+ # Provide the buffer as a downloadable ZIP file
292
+ file_name = f"{selected_scenario}_scenario_data.zip"
293
+ if download_col.download_button(
294
+ label="Download",
295
+ data=buffer,
296
+ file_name=file_name,
297
+ mime="application/zip",
298
+ use_container_width=True,
299
+ ):
300
+ # Log message
301
+ log_message(
302
+ "info",
303
+ f"FILE_NAME: {file_name} has been successfully downloaded.",
304
+ "Saved Scenarios",
305
+ )
306
+
307
+ # Button to trigger the deletion of the selected scenario
308
+ if delete_col.button(
309
+ "Delete",
310
+ use_container_width=True,
311
+ on_click=delete_selected_scenarios,
312
+ args=(selected_scenario,),
313
+ ):
314
+ # Display success message
315
+ with save_message_display_placeholder:
316
+ st.success(
317
+ "Selected scenario successfully deleted. Click the 'Save Progress' button to ensure your changes are updated!",
318
+ icon="🗑️",
319
+ )
320
+ st.toast(
321
+ "Selected scenario successfully deleted. Click the 'Save Progress' button to ensure your changes are updated!",
322
+ icon="🗑️",
323
+ )
324
+
325
+ # Log message
326
+ log_message(
327
+ "info", "Selected scenario successfully deleted.", "Saved Scenarios"
328
+ )
329
+
330
+ except Exception as e:
331
+ # Capture the error details
332
+ exc_type, exc_value, exc_traceback = sys.exc_info()
333
+ error_message = "".join(
334
+ traceback.format_exception(exc_type, exc_value, exc_traceback)
335
+ )
336
+
337
+ # Log message
338
+ log_message("error", f"An error occurred: {error_message}.", "Saved Scenarios")
339
+
340
+ # Display a warning message
341
+ st.warning(
342
+ "Oops! Something went wrong. Please try refreshing the tool or creating a new project.",
343
+ icon="⚠️",
344
+ )
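The download path above relies on convert_to_serializable() to make a scenario dictionary JSON-safe before it is zipped. A standalone sketch of that behaviour, reproducing the helper from this page with an illustrative scenario dictionary:

import json
import numpy as np

def convert_to_serializable(obj):
    # Same logic as the page helper: arrays become lists, unknown objects become strings.
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, dict):
        return {key: convert_to_serializable(value) for key, value in obj.items()}
    elif isinstance(obj, list):
        return [convert_to_serializable(element) for element in obj]
    elif isinstance(obj, (int, float, str, bool, type(None))):
        return obj
    else:
        return str(obj)

scenario = {
    "metrics_selected": "sales",
    "multiplier": 1000,
    "channels": {"paid_search": {"actual_total_spends": np.array([120.0, 95.5])}},
}
print(json.dumps(convert_to_serializable(scenario), indent=2))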
pages/11_AI_Model_Media_Recommendation.py ADDED
@@ -0,0 +1,695 @@
1
+ import streamlit as st
2
+ from scenario import numerize
3
+ import pandas as pd
4
+ from utilities import (
5
+ format_numbers,
6
+ load_local_css,
7
+ set_header,
8
+ name_formating,
9
+ project_selection,
10
+ )
11
+ import pickle
12
+ import yaml
13
+ from yaml import SafeLoader
14
+ from scenario import class_from_dict
15
+ import plotly.express as px
16
+ import numpy as np
17
+ import plotly.graph_objects as go
18
+ import pandas as pd
19
+ from plotly.subplots import make_subplots
20
+ import sqlite3
21
+ from utilities import update_db
22
+ from collections import OrderedDict
23
+ import os
24
+
25
+ st.set_page_config(layout="wide")
26
+ load_local_css("styles.css")
27
+ set_header()
28
+
29
+ st.empty()
30
+ st.header("AI Model Media Recommendation")
31
+
32
+ # def get_saved_scenarios_dict():
33
+ # # Path to the saved scenarios file
34
+ # saved_scenarios_dict_path = os.path.join(
35
+ # st.session_state["project_path"], "saved_scenarios.pkl"
36
+ # )
37
+
38
+ # # Load existing scenarios if the file exists
39
+ # if os.path.exists(saved_scenarios_dict_path):
40
+ # with open(saved_scenarios_dict_path, "rb") as f:
41
+ # saved_scenarios_dict = pickle.load(f)
42
+ # else:
43
+ # saved_scenarios_dict = OrderedDict()
44
+
45
+ # return saved_scenarios_dict
46
+
47
+
48
+ # # Function to format values based on their size
49
+ # def format_value(value):
50
+ # return round(value, 4) if value < 1 else round(value, 1)
51
+
52
+
53
+ # # Function to recursively convert non-serializable types to serializable ones
54
+ # def convert_to_serializable(obj):
55
+ # if isinstance(obj, np.ndarray):
56
+ # return obj.tolist()
57
+ # elif isinstance(obj, dict):
58
+ # return {key: convert_to_serializable(value) for key, value in obj.items()}
59
+ # elif isinstance(obj, list):
60
+ # return [convert_to_serializable(element) for element in obj]
61
+ # elif isinstance(obj, (int, float, str, bool, type(None))):
62
+ # return obj
63
+ # else:
64
+ # # Fallback: convert the object to a string
65
+ # return str(obj)
66
+
67
+
68
+ if "username" not in st.session_state:
69
+ st.session_state["username"] = None
70
+
71
+ if "project_name" not in st.session_state:
72
+ st.session_state["project_name"] = None
73
+
74
+ if "project_dct" not in st.session_state:
75
+ project_selection()
76
+ st.stop()
77
+ # if "project_path" not in st.session_state:
78
+ # st.stop()
79
+ # if 'username' in st.session_state and st.session_state['username'] is not None:
80
+
81
+ # data_path = os.path.join(st.session_state["project_path"], "data_import.pkl")
82
+
83
+ # try:
84
+ # with open(data_path, "rb") as f:
85
+ # data = pickle.load(f)
86
+ # except Exception as e:
87
+ # st.error(f"Please import data from the Data Import Page")
88
+ # st.stop()
89
+ # # Get saved scenarios dictionary and scenario name list
90
+ # saved_scenarios_dict = get_saved_scenarios_dict()
91
+ # scenarios_list = list(saved_scenarios_dict.keys())
92
+
93
+ # #st.write(saved_scenarios_dict)
94
+ # # Check if the list of saved scenarios is empty
95
+ # if len(scenarios_list) == 0:
96
+ # # Display a warning message if no scenarios are saved
97
+ # st.warning("No scenarios saved. Please save a scenario to load.", icon="⚠️")
98
+ # st.stop()
99
+
100
+ # # Display a dropdown saved scenario list
101
+ # selected_scenario = st.selectbox(
102
+ # "Pick a Scenario", sorted(scenarios_list), key="selected_scenario"
103
+ # )
104
+ # selected_scenario_data = saved_scenarios_dict[selected_scenario]
105
+
106
+ # # Scenarios Name
107
+ # metrics_name = selected_scenario_data["metrics_selected"]
108
+ # panel_name = selected_scenario_data["panel_selected"]
109
+ # optimization_name = selected_scenario_data["optimization"]
110
+
111
+ # # Display the scenario details with bold "Metric," "Panel," and "Optimization"
112
+
113
+ # # Create columns for download and delete buttons
114
+ # download_col, delete_col = st.columns(2)
115
+
116
+
117
+ # channels_list = list(selected_scenario_data["channels"].keys())
118
+
119
+ # # List to hold data for all channels
120
+ # channels_data = []
121
+
122
+ # # Iterate through each channel and gather required data
123
+ # for channel in channels_list:
124
+ # channel_conversion_rate = selected_scenario_data["channels"][channel][
125
+ # "conversion_rate"
126
+ # ]
127
+ # channel_actual_spends = (
128
+ # selected_scenario_data["channels"][channel]["actual_total_spends"]
129
+ # * channel_conversion_rate
130
+ # )
131
+ # channel_optimized_spends = (
132
+ # selected_scenario_data["channels"][channel]["modified_total_spends"]
133
+ # * channel_conversion_rate
134
+ # )
135
+
136
+ # channel_actual_metrics = selected_scenario_data["channels"][channel][
137
+ # "actual_total_sales"
138
+ # ]
139
+ # channel_optimized_metrics = selected_scenario_data["channels"][channel][
140
+ # "modified_total_sales"
141
+ # ]
142
+
143
+ # channel_roi_mroi_data = selected_scenario_data["channel_roi_mroi"][channel]
144
+
145
+ # # Extract the ROI and MROI data
146
+ # actual_roi = channel_roi_mroi_data["actual_roi"]
147
+ # optimized_roi = channel_roi_mroi_data["optimized_roi"]
148
+ # actual_mroi = channel_roi_mroi_data["actual_mroi"]
149
+ # optimized_mroi = channel_roi_mroi_data["optimized_mroi"]
150
+
151
+ # # Calculate spends per metric
152
+ # spends_per_metrics_actual = channel_actual_spends / channel_actual_metrics
153
+ # spends_per_metrics_optimized = channel_optimized_spends / channel_optimized_metrics
154
+
155
+ # # Append the collected data as a dictionary to the list
156
+ # channels_data.append(
157
+ # {
158
+ # "Channel Name": channel,
159
+ # "Spends Actual": channel_actual_spends,
160
+ # "Spends Optimized": channel_optimized_spends,
161
+ # f"{metrics_name} Actual": channel_actual_metrics,
162
+ # f"{name_formating(metrics_name)} Optimized": numerize(
163
+ # channel_optimized_metrics
164
+ # ),
165
+ # "ROI Actual": format_value(actual_roi),
166
+ # "ROI Optimized": format_value(optimized_roi),
167
+ # "MROI Actual": format_value(actual_mroi),
168
+ # "MROI Optimized": format_value(optimized_mroi),
169
+ # f"Spends per {name_formating(metrics_name)} Actual": numerize(
170
+ # spends_per_metrics_actual
171
+ # ),
172
+ # f"Spends per {name_formating(metrics_name)} Optimized": numerize(
173
+ # spends_per_metrics_optimized
174
+ # ),
175
+ # }
176
+ # )
177
+
178
+ # # Create a DataFrame from the collected data
179
+
180
+ ##NEW CODE##########
181
+
182
+ scenarios_name_placeholder = st.empty()
183
+
184
+
185
+ # Function to get saved scenarios dictionary
186
+ def get_saved_scenarios_dict():
187
+ return st.session_state["project_dct"]["saved_scenarios"]["saved_scenarios_dict"]
188
+
189
+
190
+ # Function to format values based on their size
191
+ def format_value(value):
192
+ return round(value, 4) if value < 1 else round(value, 1)
193
+
194
+
195
+ # Function to recursively convert non-serializable types to serializable ones
196
+ def convert_to_serializable(obj):
197
+ if isinstance(obj, np.ndarray):
198
+ return obj.tolist()
199
+ elif isinstance(obj, dict):
200
+ return {key: convert_to_serializable(value) for key, value in obj.items()}
201
+ elif isinstance(obj, list):
202
+ return [convert_to_serializable(element) for element in obj]
203
+ elif isinstance(obj, (int, float, str, bool, type(None))):
204
+ return obj
205
+ else:
206
+ # Fallback: convert the object to a string
207
+ return str(obj)
208
+
209
+
210
+ # Get saved scenarios dictionary and scenario name list
211
+ saved_scenarios_dict = get_saved_scenarios_dict()
212
+ scenarios_list = list(saved_scenarios_dict.keys())
213
+
214
+ # Check if the list of saved scenarios is empty
215
+ if len(scenarios_list) == 0:
216
+ # Display a warning message if no scenarios are saved
217
+ st.warning("No scenarios saved. Please save a scenario to load.", icon="⚠️")
218
+ st.stop()
219
+
220
+ # Display a dropdown saved scenario list
221
+ selected_scenario = st.selectbox(
222
+ "Pick a Scenario", sorted(scenarios_list), key="selected_scenario"
223
+ )
224
+ selected_scenario_data = saved_scenarios_dict[selected_scenario]
225
+
226
+ # Scenarios Name
227
+ metrics_name = selected_scenario_data["metrics_selected"]
228
+ panel_name = selected_scenario_data["panel_selected"]
229
+ optimization_name = selected_scenario_data["optimization"]
230
+ multiplier = selected_scenario_data["multiplier"]
231
+ timeframe = selected_scenario_data["timeframe"]
232
+
233
+ # Display the scenario details with bold "Metric," "Panel," and "Optimization"
234
+ scenarios_name_placeholder.markdown(
235
+ f"**Metric**: {name_formating(metrics_name)}; **Panel**: {name_formating(panel_name)}; **Fix**: {name_formating(optimization_name)}; **Timeframe**: {name_formating(timeframe)}"
236
+ )
237
+
238
+ # Create columns for download and delete buttons
239
+ download_col, delete_col = st.columns(2)
240
+
241
+ # Channel List
242
+ channels_list = list(selected_scenario_data["channels"].keys())
243
+
244
+ # List to hold data for all channels
245
+ channels_data = []
246
+
247
+ # Iterate through each channel and gather required data
248
+ for channel in channels_list:
249
+ channel_conversion_rate = selected_scenario_data["channels"][channel][
250
+ "conversion_rate"
251
+ ]
252
+ channel_actual_spends = (
253
+ selected_scenario_data["channels"][channel]["actual_total_spends"]
254
+ * channel_conversion_rate
255
+ )
256
+ channel_optimized_spends = (
257
+ selected_scenario_data["channels"][channel]["modified_total_spends"]
258
+ * channel_conversion_rate
259
+ )
260
+
261
+ channel_actual_metrics = selected_scenario_data["channels"][channel][
262
+ "actual_total_sales"
263
+ ]
264
+ channel_optimized_metrics = selected_scenario_data["channels"][channel][
265
+ "modified_total_sales"
266
+ ]
267
+
268
+ channel_roi_mroi_data = selected_scenario_data["channel_roi_mroi"][channel]
269
+
270
+ # Extract the ROI and MROI data
271
+ actual_roi = channel_roi_mroi_data["actual_roi"]
272
+ optimized_roi = channel_roi_mroi_data["optimized_roi"]
273
+ actual_mroi = channel_roi_mroi_data["actual_mroi"]
274
+ optimized_mroi = channel_roi_mroi_data["optimized_mroi"]
275
+
276
+ # Calculate spends per metric
277
+ spends_per_metrics_actual = channel_actual_spends / channel_actual_metrics
278
+ spends_per_metrics_optimized = channel_optimized_spends / channel_optimized_metrics
279
+
280
+ # Append the collected data as a dictionary to the list
281
+ channels_data.append(
282
+ {
283
+ "Channel Name": channel,
284
+ "Spends Actual": (channel_actual_spends / multiplier),
285
+ "Spends Optimized": (channel_optimized_spends / multiplier),
286
+ f"{name_formating(metrics_name)} Actual": (
287
+ channel_actual_metrics / multiplier
288
+ ),
289
+ f"{name_formating(metrics_name)} Optimized": (
290
+ channel_optimized_metrics / multiplier
291
+ ),
292
+ "ROI Actual": format_value(actual_roi),
293
+ "ROI Optimized": format_value(optimized_roi),
294
+ "MROI Actual": format_value(actual_mroi),
295
+ "MROI Optimized": format_value(optimized_mroi),
296
+ f"Spends per {name_formating(metrics_name)} Actual": round(
297
+ spends_per_metrics_actual, 2
298
+ ),
299
+ f"Spends per {name_formating(metrics_name)} Optimized": round(
300
+ spends_per_metrics_optimized, 2
301
+ ),
302
+ }
303
+ )
304
+
305
+ # Create a DataFrame from the collected data
306
+ # df = pd.DataFrame(channels_data)
307
+
308
+ # # Display the DataFrame
309
+ # st.dataframe(df, hide_index=True)
310
+
311
+ summary_df_sorted = pd.DataFrame(channels_data).sort_values(by=["Spends Optimized"])
312
+
313
+
314
+ summary_df_sorted["Delta"] = (
315
+ summary_df_sorted["Spends Optimized"] - summary_df_sorted["Spends Actual"]
316
+ )
317
+
318
+
319
+ summary_df_sorted["Delta_percent"] = np.round(
320
+ (summary_df_sorted["Delta"]) / summary_df_sorted["Spends Actual"] * 100, 2
321
+ )
322
+
323
+ # spends_data = pd.read_excel("Overview_data_test.xlsx")
324
+
325
+
326
+ st.header("Optimized Media Spend Overview")
327
+
328
+ channel_colors = px.colors.qualitative.Plotly
329
+
330
+ fig = make_subplots(
331
+ rows=1,
332
+ cols=3,
333
+ subplot_titles=("Actual Spend", "Spends Optimized", "Delta"),
334
+ horizontal_spacing=0.05,
335
+ )
336
+
337
+ for i, channel in enumerate(summary_df_sorted["Channel Name"].unique()):
338
+ channel_df = summary_df_sorted[summary_df_sorted["Channel Name"] == channel]
339
+ channel_color = channel_colors[i % len(channel_colors)]
340
+
341
+ fig.add_trace(
342
+ go.Bar(
343
+ x=channel_df["Spends Actual"],
344
+ y=channel_df["Channel Name"],
345
+ text=channel_df["Spends Actual"].apply(format_numbers),
346
+ marker_color=channel_color,
347
+ orientation="h",
348
+ ),
349
+ row=1,
350
+ col=1,
351
+ )
352
+
353
+ fig.add_trace(
354
+ go.Bar(
355
+ x=channel_df["Spends Optimized"],
356
+ y=channel_df["Channel Name"],
357
+ text=channel_df["Spends Optimized"].apply(format_numbers),
358
+ marker_color=channel_color,
359
+ orientation="h",
360
+ showlegend=False,
361
+ ),
362
+ row=1,
363
+ col=2,
364
+ )
365
+
366
+ fig.add_trace(
367
+ go.Bar(
368
+ x=channel_df["Delta_percent"],
369
+ y=channel_df["Channel Name"],
370
+ text=channel_df["Delta_percent"].apply(lambda x: f"{x:.0f}%"),
371
+ marker_color=channel_color,
372
+ orientation="h",
373
+ showlegend=False,
374
+ ),
375
+ row=1,
376
+ col=3,
377
+ )
378
+ fig.update_layout(height=600, width=900, title="", showlegend=False)
379
+
380
+ fig.update_yaxes(showticklabels=False, row=1, col=2)
381
+ fig.update_yaxes(showticklabels=False, row=1, col=3)
382
+
383
+ fig.update_xaxes(showticklabels=False, row=1, col=1)
384
+ fig.update_xaxes(showticklabels=False, row=1, col=2)
385
+ fig.update_xaxes(showticklabels=False, row=1, col=3)
386
+
387
+
388
+ st.plotly_chart(fig, use_container_width=True)
389
+
390
+
391
+ summary_df_sorted["Perc_alloted"] = np.round(
392
+ summary_df_sorted["Spends Optimized"] / summary_df_sorted["Spends Optimized"].sum(),
393
+ 2,
394
+ )
395
+ st.header("Optimized Media Spend Allocation")
396
+
397
+ fig = make_subplots(
398
+ rows=1,
399
+ cols=2,
400
+ subplot_titles=("Spends Optimized", "% Split"),
401
+ horizontal_spacing=0.05,
402
+ )
403
+
404
+ for i, channel in enumerate(summary_df_sorted["Channel Name"].unique()):
405
+ channel_df = summary_df_sorted[summary_df_sorted["Channel Name"] == channel]
406
+ channel_color = channel_colors[i % len(channel_colors)]
407
+
408
+ fig.add_trace(
409
+ go.Bar(
410
+ x=channel_df["Spends Optimized"],
411
+ y=channel_df["Channel Name"],
412
+ text=channel_df["Spends Optimized"].apply(format_numbers),
413
+ marker_color=channel_color,
414
+ orientation="h",
415
+ ),
416
+ row=1,
417
+ col=1,
418
+ )
419
+
420
+ fig.add_trace(
421
+ go.Bar(
422
+ x=channel_df["Perc_alloted"],
423
+ y=channel_df["Channel Name"],
424
+ text=channel_df["Perc_alloted"].apply(lambda x: f"{100*x:.0f}%"),
425
+ marker_color=channel_color,
426
+ orientation="h",
427
+ showlegend=False,
428
+ ),
429
+ row=1,
430
+ col=2,
431
+ )
432
+
433
+ fig.update_layout(height=600, width=900, title="", showlegend=False)
434
+
435
+ fig.update_yaxes(showticklabels=False, row=1, col=2)
436
+ fig.update_yaxes(showticklabels=False, row=1, col=3)
437
+
438
+ fig.update_xaxes(showticklabels=False, row=1, col=1)
439
+ fig.update_xaxes(showticklabels=False, row=1, col=2)
440
+ fig.update_xaxes(showticklabels=False, row=1, col=3)
441
+
442
+
443
+ st.plotly_chart(fig, use_container_width=True)
444
+
445
+
446
+ st.session_state["cleaned_data"] = st.session_state["project_dct"]["data_import"][
447
+ "imputed_tool_df"
448
+ ]
449
+ st.session_state["category_dict"] = st.session_state["project_dct"]["data_import"][
450
+ "category_dict"
451
+ ]
452
+
453
+ effectiveness_overall = pd.DataFrame()
454
+
455
+ response_metrics = list(
456
+ *[
457
+ st.session_state["category_dict"][key]
458
+ for key in st.session_state["category_dict"].keys()
459
+ if key == "Response Metrics"
460
+ ]
461
+ )
462
+
463
+ effectiveness_overall = (
464
+ st.session_state["cleaned_data"][response_metrics].sum().reset_index()
465
+ )
466
+
467
+ effectiveness_overall.columns = ["ResponseMetricName", "ResponseMetricValue"]
468
+
469
+
470
+ effectiveness_overall["Efficiency"] = effectiveness_overall["ResponseMetricValue"].map(
471
+ lambda x: x / summary_df_sorted["Spends Optimized"].sum()
472
+ )
473
+
474
+
475
+ columns6 = st.columns(3)
476
+
477
+ effectiveness_overall.sort_values(
478
+ by=["ResponseMetricValue"], ascending=False, inplace=True
479
+ )
480
+ effectiveness_overall = np.round(effectiveness_overall, 2)
481
+
482
+ columns4 = st.columns([0.55, 0.45])
483
+
484
+ # effectiveness_overall = effectiveness_overall.sort_values(by=["ResponseMetricValue"])
485
+
486
+ # with columns4[0]:
487
+ # fig = px.funnel(
488
+ # effectiveness_overall,
489
+ # x="ResponseMetricValue",
490
+ # y="ResponseMetricName",
491
+ # color="ResponseMetricName",
492
+ # title="Effectiveness",
493
+ # )
494
+ # fig.update_layout(
495
+ # showlegend=False,
496
+ # yaxis=dict(tickmode="array"),
497
+ # )
498
+ # fig.update_traces(
499
+ # textinfo="value",
500
+ # textposition="inside",
501
+ # texttemplate="%{x:.2s} ",
502
+ # hoverinfo="y+x+percent initial",
503
+ # )
504
+ # fig.update_traces(
505
+ # marker=dict(line=dict(color="black", width=2)),
506
+ # selector=dict(marker=dict(color="blue")),
507
+ # )
508
+
509
+ # st.plotly_chart(fig, use_container_width=True)
510
+
511
+ # with columns4[1]:
512
+ # fig1 = px.bar(
513
+ # effectiveness_overall.sort_values(by=["ResponseMetricValue"], ascending=False),
514
+ # x="Efficiency",
515
+ # y="ResponseMetricName",
516
+ # color="ResponseMetricName",
517
+ # text_auto=True,
518
+ # title="Efficiency",
519
+ # )
520
+
521
+ # # Update layout and traces
522
+ # fig1.update_traces(
523
+ # customdata=effectiveness_overall["Efficiency"], textposition="auto"
524
+ # )
525
+ # fig1.update_layout(showlegend=False)
526
+ # fig1.update_yaxes(title="", showticklabels=False)
527
+ # fig1.update_xaxes(title="", showticklabels=False)
528
+ # fig1.update_xaxes(tickfont=dict(size=20))
529
+ # fig1.update_yaxes(tickfont=dict(size=20))
530
+ # st.plotly_chart(fig1, use_container_width=True)
531
+
532
+ # Function to format metric names
533
+ def format_metric_name(metric_name):
534
+ return str(metric_name).lower().replace("response_metric_", "").replace("_", " ").strip().title()
535
+
536
+ # Apply the formatting function to the 'ResponseMetricName' column
537
+ effectiveness_overall["FormattedMetricName"] = effectiveness_overall["ResponseMetricName"].apply(format_metric_name)
538
+
539
+ # Multiselect widget with all options as default, but using the formatted names for display
540
+ all_metrics = effectiveness_overall["FormattedMetricName"].unique()
541
+ selected_metrics = st.multiselect(
542
+ "Select Metrics to Display",
543
+ options=all_metrics,
544
+ default=all_metrics
545
+ )
546
+
547
+ # Filter the data based on the selected metrics (using formatted names)
548
+ if selected_metrics:
549
+ filtered_data = effectiveness_overall[
550
+ effectiveness_overall["FormattedMetricName"].isin(selected_metrics)
551
+ ]
552
+
553
+ # Sort values for funnel plot
554
+ filtered_data = filtered_data.sort_values(by=["ResponseMetricValue"])
555
+
556
+ # Generate a consistent color mapping for all selected metrics
557
+ color_map = {metric: px.colors.qualitative.Plotly[i % len(px.colors.qualitative.Plotly)]
558
+ for i, metric in enumerate(filtered_data["FormattedMetricName"].unique())}
559
+
560
+ # First plot: Funnel
561
+ with columns4[0]:
562
+ fig = px.funnel(
563
+ filtered_data,
564
+ x="ResponseMetricValue",
565
+ y="FormattedMetricName", # Use formatted names for y-axis
566
+ color="FormattedMetricName", # Use formatted names for color
567
+ color_discrete_map=color_map, # Ensure consistent colors
568
+ title="Effectiveness",
569
+ )
570
+ fig.update_layout(
571
+ showlegend=False,
572
+ yaxis=dict(title="Response Metric", tickmode="array"), # Set y-axis label to 'Response Metric'
573
+ )
574
+ fig.update_traces(
575
+ textinfo="value",
576
+ textposition="inside",
577
+ texttemplate="%{x:.2s} ",
578
+ hoverinfo="y+x+percent initial",
579
+ )
580
+ fig.update_traces(
581
+ marker=dict(line=dict(color="black", width=2)),
582
+ selector=dict(marker=dict(color="blue")),
583
+ )
584
+
585
+ st.plotly_chart(fig, use_container_width=True)
586
+
587
+ # Second plot: Bar chart
588
+ with columns4[1]:
589
+ fig1 = px.bar(
590
+ filtered_data.sort_values(by=["ResponseMetricValue"], ascending=False),
591
+ x="Efficiency",
592
+ y="FormattedMetricName", # Use formatted names for y-axis
593
+ color="FormattedMetricName", # Use formatted names for color
594
+ color_discrete_map=color_map, # Ensure consistent colors
595
+ text_auto=True,
596
+ title="Efficiency",
597
+ )
598
+
599
+ # Update layout and traces
600
+ fig1.update_traces(
601
+ customdata=filtered_data["Efficiency"], textposition="auto"
602
+ )
603
+ fig1.update_layout(showlegend=False)
604
+ fig1.update_yaxes(title="", showticklabels=False)
605
+ fig1.update_xaxes(title="", showticklabels=False)
606
+ fig1.update_xaxes(tickfont=dict(size=20))
607
+ fig1.update_yaxes(tickfont=dict(size=20))
608
+ st.plotly_chart(fig1, use_container_width=True)
609
+ else:
610
+ st.info("Please select at least one response metric to display the charts.")
611
+
612
+ st.header("Return Forecast by Media Channel")
613
+
614
+ with st.expander("Return Forecast by Media Channel"):
615
+
616
+
617
+ metric = metrics_name
618
+
619
+ metric = metric.lower().replace("_", " ") + " " + "actual"
620
+ summary_df_sorted.columns = [
621
+ col.lower().replace("_", " ") for col in summary_df_sorted.columns
622
+ ]
623
+
624
+ effectiveness = summary_df_sorted[metric]
625
+
626
+ summary_df_sorted["Efficiency"] = (
627
+ summary_df_sorted[metric] / summary_df_sorted["spends optimized"]
628
+ )
629
+
630
+ channel_colors = px.colors.qualitative.Plotly
631
+
632
+ fig = make_subplots(
633
+ rows=1,
634
+ cols=3,
635
+ subplot_titles=("Optimized Spends", "Effectiveness", "Efficiency"),
636
+ horizontal_spacing=0.05,
637
+ )
638
+
639
+ for i, channel in enumerate(summary_df_sorted["channel name"].unique()):
640
+ channel_df = summary_df_sorted[summary_df_sorted["channel name"] == channel]
641
+ channel_color = channel_colors[i % len(channel_colors)]
642
+
643
+ fig.add_trace(
644
+ go.Bar(
645
+ x=channel_df["spends optimized"],
646
+ y=channel_df["channel name"],
647
+ text=channel_df["spends optimized"].apply(format_numbers),
648
+ marker_color=channel_color,
649
+ orientation="h",
650
+ ),
651
+ row=1,
652
+ col=1,
653
+ )
654
+
655
+ fig.add_trace(
656
+ go.Bar(
657
+ x=channel_df[metric],
658
+ y=channel_df["channel name"],
659
+ text=channel_df[metric].apply(format_numbers),
660
+ marker_color=channel_color,
661
+ orientation="h",
662
+ showlegend=False,
663
+ ),
664
+ row=1,
665
+ col=2,
666
+ )
667
+
668
+ fig.add_trace(
669
+ go.Bar(
670
+ x=channel_df["Efficiency"],
671
+ y=channel_df["channel name"],
672
+ text=channel_df["Efficiency"].apply(lambda x: f"{x:.2f}"),
673
+ marker_color=channel_color,
674
+ orientation="h",
675
+ showlegend=False,
676
+ ),
677
+ row=1,
678
+ col=3,
679
+ )
680
+
681
+ fig.update_layout(
682
+ height=600,
683
+ width=900,
684
+ title="Media Channel Performance",
685
+ showlegend=False,
686
+ )
687
+
688
+ fig.update_yaxes(showticklabels=False, row=1, col=2)
689
+ fig.update_yaxes(showticklabels=False, row=1, col=3)
690
+
691
+ fig.update_xaxes(showticklabels=False, row=1, col=1)
692
+ fig.update_xaxes(showticklabels=False, row=1, col=2)
693
+ fig.update_xaxes(showticklabels=False, row=1, col=3)
694
+
695
+ st.plotly_chart(fig, use_container_width=True)
pages/12_Glossary.py ADDED
@@ -0,0 +1,150 @@
1
+ # Importing necessary libraries
2
+ import streamlit as st
3
+ import base64
4
+
5
+ from utilities import set_header, load_local_css
6
+
7
+ st.set_page_config(
8
+ page_title="Glossary",
9
+ page_icon=":shark:",
10
+ layout="wide",
11
+ initial_sidebar_state="collapsed",
12
+ )
13
+
14
+ load_local_css("styles.css")
15
+ set_header()
16
+
17
+
18
+ st.header("Glossary")
19
+
20
+ # Glossary
21
+ st.markdown(
22
+ """
23
+
24
+ ### 1. Glossary - Home​
25
+ **Blend Employee ID:** Users should have access to the UI through their Blend Employee ID. User data should be available and validated through this ID.​
26
+
27
+ **API Data:** Incorporate data retrieved from APIs to be used in the tool.​
28
+
29
+ ### 2. Glossary – Data Import​
30
+ **Granularity:** Defines the level of detail to which the data should be aggregated.​
31
+
32
+ **Panel:** Represents columns corresponding to markets, DMAs (Designated Market Areas).​
33
+
34
+ **Response Metrics:** Target variables or metrics such as App Installs, Revenue, Form Submission/Conversion. These are the variables used for model building and spends optimization.​
35
+
36
+ **Spends:** Variables representing spend data for all media channels.​
37
+
38
+ **Exogenous:** External variables like bank holidays, GDP, online trends, interest rates.​
39
+
40
+ **Internal:** Internal variables such as discounts or promotions.​
41
+
42
+ ### 3. Glossary – Data Assessment​
43
+ **Trendline:** Represents the linear average movement over the entire time period.​
44
+
45
+ **Media Variables:** Variables related to media activities such as Impressions, Clicks, Spends, Views. Examples: Bing Search Clicks, Instagram Impressions, YouTube Impressions.​
46
+
47
+ **Non-Media Variables:** Variables that exclude media-related data. Examples: Discount, Holiday, Google Trend.​
48
+
49
+ **CPM (Cost per 1000 Impressions):** Calculated as (YouTube Spends / YouTube Impressions) * 1000.​
50
+
51
+ **CPC (Cost per 1000 Clicks):** Calculated as (YouTube Spends / YouTube Clicks) * 1000.​
52
+
53
+ **Correlation:** Pearson correlation measures the linear relationship between two continuous variables, indicating the strength and direction of their association.​
54
+
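As a quick illustration of the CPM and CPC definitions above, the following minimal sketch (not part of the uploaded application code; the channel and column names are made up) computes both metrics with pandas:

```python
import pandas as pd

# Hypothetical channel-level data; column names are illustrative only
df = pd.DataFrame(
    {
        "youtube_spends": [1200.0, 900.0],
        "youtube_impressions": [400_000, 350_000],
        "youtube_clicks": [8_000, 6_500],
    }
)

# CPM: cost per 1000 impressions
df["youtube_cpm"] = df["youtube_spends"] / df["youtube_impressions"] * 1000

# CPC as defined in this glossary: cost per 1000 clicks
df["youtube_cpc"] = df["youtube_spends"] / df["youtube_clicks"] * 1000

print(df[["youtube_cpm", "youtube_cpc"]])
```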
55
+ ### 4. Glossary – Transformations​
56
+
57
+ **Transformations:** Transformations involve adjusting input variables to capture nonlinear relationships, such as diminishing returns and advertising carryover. They are always applied at the panel level if panels exist; if no panels exist, the transformations are applied directly at the aggregated level.
58
+
59
+ **Lag:** Shifts the data backward by a specified number of periods. Formula: Lagged Series = X_(t − lag)
60
+
61
+ **Lead:** Shifts the data forward by a specified number of periods. Formula: Lead Series = X_(t + lead)
62
+
63
+ **Moving Average:** Smooths the data by averaging values over a specified window size. Formula: Moving Average = (1/n) × Σ_(i=0)^(n−1) X_(t−i)
64
+
65
+ **Saturation:** Applies a saturation effect to the data based on a specified saturation percentage. Formula: Y_t = (1 / (1 + (saturation point / X_t)^steepness)) × X_t
66
+
67
+ **Power:** Raises the data to a specified power. Formula: Y_t = X_t^power
68
+
69
+ **Adstock:** Applies a decay effect to the data, simulating the carryover of media impact over time. Formula: Y_t = X_t + decay rate × Y_(t−1)
70
+
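For illustration only, a minimal pandas sketch of these transformations applied to a single panel's series could look like the following (the function and parameter names are assumptions, not the tool's actual implementation):

```python
import pandas as pd

def lag(x: pd.Series, periods: int) -> pd.Series:
    # Lagged series: X_(t - lag)
    return x.shift(periods)

def lead(x: pd.Series, periods: int) -> pd.Series:
    # Lead series: X_(t + lead)
    return x.shift(-periods)

def moving_average(x: pd.Series, window: int) -> pd.Series:
    # Average of the current and previous (window - 1) values
    return x.rolling(window=window, min_periods=1).mean()

def saturation(x: pd.Series, saturation_point: float, steepness: float) -> pd.Series:
    # Y_t = X_t / (1 + (saturation_point / X_t)^steepness)
    return x / (1 + (saturation_point / x) ** steepness)

def power(x: pd.Series, p: float) -> pd.Series:
    # Y_t = X_t^p
    return x ** p

def adstock(x: pd.Series, decay_rate: float) -> pd.Series:
    # Y_t = X_t + decay_rate * Y_(t-1), i.e. carryover of past values
    out, carry = [], 0.0
    for value in x:
        carry = value + decay_rate * carry
        out.append(carry)
    return pd.Series(out, index=x.index)

spends = pd.Series([100.0, 120.0, 80.0, 150.0])
print(adstock(saturation(spends, saturation_point=100.0, steepness=2.0), decay_rate=0.5))
```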
71
+ ### 5. Glossary - AI Model Build
72
+ **Train Set:** The train set is a subset of the data on which the AI Model is trained. It is standard practice to select 70-75% of the data as train set.​
73
+
74
+ **Test Set:** The test set is the subset of data on which the AI Model’s performance is tested. There will be no common records between the train and test sets.​
75
+
76
+ **Residual:** Residual is defined as the difference between the true value and the predicted value from the AI Model. Formula: Residual = True Value − Predicted Value
77
+
78
+ **Actual VS Predicted Plot:** An actual vs. predicted plot visualizes the relationship between the actual values and the values predicted by the AI model, helping to assess model performance and identify patterns or discrepancies.​
79
+
80
+ **MAPE:** MAPE, or Mean Absolute Percentage Error, indicates the percentage of error in the AI Model's predictions. We use a variant of MAPE called Weighted Average Percentage Error. Formula: MAPE = Σ |Actual Value − Predicted Value| ÷ Σ |Actual Value|
81
+
82
+ **R-Squared:** R-Squared is a number that tells you how well the independent variable(s) in an AI model explain the variation in the dependent variable. It indicates the goodness of fit, with values closer to 1 suggesting a better fit.​
83
+
84
+ **Adjusted R-Squared:** Adjusted R-squared modifies the R-squared value to account for the number of predictors in the model, providing a more accurate measure of goodness of fit by penalizing the addition of non-significant predictors.​
85
+
86
+ **Multicollinearity:** Multicollinearity refers to a situation where two or more independent variables (media channels or marketing inputs) are highly correlated with each other. This makes it difficult to isolate the individual effect of each channel on the dependent variable (e.g., sales). It can lead to unreliable coefficient estimates, making it hard to determine the true impact of each media type.
87
+
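A small worked sketch of the residual, weighted MAPE, and R-squared definitions above (the numbers are illustrative, not tool output):

```python
import numpy as np

actual = np.array([120.0, 150.0, 90.0, 200.0])
predicted = np.array([110.0, 160.0, 100.0, 190.0])

# Residual = True Value - Predicted Value
residual = actual - predicted

# Weighted MAPE = sum(|Actual - Predicted|) / sum(|Actual|)
wmape = np.sum(np.abs(actual - predicted)) / np.sum(np.abs(actual))

# R-squared = 1 - SS_res / SS_tot
ss_res = np.sum((actual - predicted) ** 2)
ss_tot = np.sum((actual - actual.mean()) ** 2)
r_squared = 1 - ss_res / ss_tot

print(residual, round(wmape, 4), round(r_squared, 4))
```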
88
+ ### 6. Glossary – AI Model Tuning
89
+ **Event:** An event refers to a specific occurrence or campaign, such as a promotion, holiday, or product launch, that can impact the performance of the response metric.​
90
+
91
+ **Trend:** Trend is a straight line which helps capture the underlying direction or pattern in the data over time, helping to identify and quantify the increasing or decreasing trend of the dependent variable.​
92
+
93
+ **Day of Week:** The "day of week" feature represents the specific day within a week (e.g., Monday, Tuesday) and is used to capture patterns or seasonal effects that occur on specific days.​
94
+
95
+ **Sine & Cosine Waves:** Sine and cosine waves are mathematical functions used to capture seasonality in the dependent variable.​
96
+
97
+ **Contribution:** Contribution of a channel is the percentage of its contribution to the response metric. Contribution is an output of the AI Model, calculated using the media data and model’s coefficients​
98
+
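For illustration, trend and sine/cosine seasonality features of this kind could be built as in the sketch below (the weekly granularity and 52-week period are assumptions, not the tool's settings):

```python
import numpy as np
import pandas as pd

dates = pd.date_range("2024-01-01", periods=104, freq="W")
features = pd.DataFrame({"date": dates})

# Linear trend: a straight line over time
features["trend"] = np.arange(len(features))

# Sine and cosine waves capturing yearly seasonality (52-week period assumed)
period = 52
features["sin_52"] = np.sin(2 * np.pi * features["trend"] / period)
features["cos_52"] = np.cos(2 * np.pi * features["trend"] / period)

# Day of week (constant here because the data is weekly, shown only for completeness)
features["day_of_week"] = features["date"].dt.dayofweek

print(features.head())
```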
99
+ ### 7. Glossary – Response Curves
100
+
101
+ **Response Curve:** A response curve in media mix modeling represents the relationship between media inputs (e.g., impressions, clicks) on the X-axis and the resulting business outcomes (e.g., revenue, app installs) on the Y-axis, illustrating the impact of media variables on the desired response metric.
102
+
103
+ **R-squared (R²):** R-squared (R²) is a statistical measure that represents the proportion of the variance in the dependent variable that is predictable from the independent variables. It indicates how well the regression model fits the data, with a value between 0 and 1, where 1 indicates a perfect fit.
104
+ **Formula for R-squared (R²):**
105
+ R-squared = 1 − SSres / SStot
106
+ Where:
107
+ - **SSres** is the sum of squares of residuals (errors).
108
+ - **SStot** is the total sum of squares (the variance of the dependent variable).
109
+
110
+ **Actual R-squared:** Actual R-squared is used to evaluate how well an s-curve fits the data. Modified R-squared is used to evaluate how well a modified (manually tuned) s-curve fits the data. The difference between the modified and actual s-curve fits shows how much the fit improves or worsens.
111
+
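As an illustrative sketch only (it assumes scipy is available, and the curve form and starting values are assumptions rather than the tool's fitting code), an s-curve can be fitted to observed media-input/response points and scored with the R-squared formula above:

```python
import numpy as np
from scipy.optimize import curve_fit

# Hypothetical observed points: media input (x) vs response metric (y)
x = np.array([0.0, 50.0, 100.0, 150.0, 200.0, 250.0])
y = np.array([0.0, 30.0, 55.0, 70.0, 78.0, 82.0])

def s_curve(x, top, half_saturation, steepness):
    # Hill-type s-curve: 0 at x = 0, approaching `top` as x grows
    return top * x**steepness / (x**steepness + half_saturation**steepness)

params, _ = curve_fit(
    s_curve, x, y,
    p0=[90.0, 100.0, 1.5],
    bounds=([1.0, 1.0, 0.1], [1000.0, 1000.0, 10.0]),
)
fitted = s_curve(x, *params)

# R-squared = 1 - SS_res / SS_tot
ss_res = np.sum((y - fitted) ** 2)
ss_tot = np.sum((y - y.mean()) ** 2)
print("R-squared of the fitted s-curve:", round(1 - ss_res / ss_tot, 4))
```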
112
+ ### 8. Glossary – Scenario Planner
113
+
114
+ **CPA (Cost Per Acquisition):** The cost associated with acquiring a customer or conversion. CPA = Total Spend / Number of Acquisitions
115
+
116
+ **ROI (Return on Investment):** A measure of the profitability of an investment relative to its cost. ROI = (Revenue - Total Spend) / Total Spend
117
+
118
+ **mROI (Marginal Return on Investment):** Measures the additional return generated by an additional unit of investment. It represents the slope of the ROI curve and shows the incremental effectiveness of the investment. mROI = ΔRevenue / ΔSpend
119
+
120
+ **Bounds (Upper, Lower):** Constraints applied during optimization to ensure that media channel spends remain within specified limits. These bounds help in defining the feasible range for investments and prevent overspending or underspending.
121
+
122
+ **Panel & Timeframe:**
123
+ - **Panel:** The ability to optimize media strategies at a granular level based on specific categories such as geographies, product types, customer segments, or other defined groups. This allows for tailored strategies that address the unique characteristics and needs of each panel.
124
+ - **Timeframe:** The ability to optimize media strategies within specific time periods, such as by month, quarter, or year. This enables precise adjustments, ensuring that the approach aligns with specific business cycles.
125
+
126
+ **Actual Vs Optimized:**
127
+ - **Actual:** This category includes the unoptimized values, such as the actual spends and actual response metrics, reflecting the current or historical performance without any optimization applied.
128
+ - **Optimized:** This category encompasses the optimized values, such as optimized spends and optimized response metrics, representing the improved or adjusted figures after the optimization process.
129
+
130
+ **Regions:**
131
+ - **Yellow Region:** This represents the under-invested area where additional investments can lead to entering a region with a higher ROI than the average. Investing more in this region can yield better returns.
132
+ - **Green Region:** This is the optimal area where the ROI of the channel is above the average ROI, and the Marginal Return on Investment (mROI) is greater than 1, indicating that each additional dollar spent is generating substantial returns.
133
+ - **Red Region:** This represents the over-invested or saturated area where the mROI is less than 1, meaning that additional spending yields diminishing returns and is less effective.
134
+
135
+ **Important Formulas:**
136
+ **1. Incremental CPA (%)** = (Total Optimized CPA−Total Actual CPA) / (Total Actual CPA​)​
137
+ **2. Incremental Spend (%)** = (Total Optimized Spend−Total Actual Spend) / (Total Actual Spend​)
138
+ **Note:** Calculated based on total actual spends (base spends is always zero)​
139
+ **3. Incremental Response Metric (%) [Media]** = (Total Optimized Response Metric−Total Actual Response​ Metric) / (Total Actual Response Metric−Total Base Response Metric​)
140
+ **Note:** Calculated based on media portion of total actual response metric only, excluding the fixed base contribution
141
+ **4. Incremental Response Metric (%) [Total]** = (Total Optimized Response Metric−Total Actual Response​ Metric) / (Total Actual Response Metric​)
142
+ **Note:** Calculated based on total actual response metric only, including the fixed base contribution
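The sketch below works through the CPA, ROI, and incremental-percentage formulas above with made-up totals (all figures are purely illustrative):

```python
# Hypothetical totals, purely for illustration
actual_spend, optimized_spend = 100_000.0, 100_000.0
actual_revenue, optimized_revenue = 250_000.0, 275_000.0
actual_acquisitions, optimized_acquisitions = 5_000, 5_600
base_response = 50_000.0  # fixed base contribution of the response metric

# CPA = Total Spend / Number of Acquisitions
actual_cpa = actual_spend / actual_acquisitions
optimized_cpa = optimized_spend / optimized_acquisitions

# ROI = (Revenue - Total Spend) / Total Spend
actual_roi = (actual_revenue - actual_spend) / actual_spend

# 1. Incremental CPA (%)
incremental_cpa_pct = (optimized_cpa - actual_cpa) / actual_cpa * 100

# 2. Incremental Spend (%)
incremental_spend_pct = (optimized_spend - actual_spend) / actual_spend * 100

# 3. Incremental Response Metric (%) [Media]: excludes the fixed base contribution
incremental_media_pct = (
    (optimized_revenue - actual_revenue) / (actual_revenue - base_response) * 100
)

# 4. Incremental Response Metric (%) [Total]: includes the fixed base contribution
incremental_total_pct = (optimized_revenue - actual_revenue) / actual_revenue * 100

print(actual_roi, incremental_cpa_pct, incremental_spend_pct,
      incremental_media_pct, incremental_total_pct)
```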
143
+
144
+ ### 10. Glossary – Model Optimized Recommendation
145
+
146
+ **% Split:** The percentage distribution of total media spend across different channels, optimized to achieve the best results.
147
+
148
+ **Return forecast by Media Channel:** Effectiveness and Efficiency for selected response metric.
149
+ """
150
+ )
pages/14_User_Management.py ADDED
@@ -0,0 +1,503 @@
1
+ import streamlit as st
2
+
3
+ st.set_page_config(layout="wide")
4
+ import pandas as pd
5
+ from utilities import (
6
+ load_local_css,
7
+ set_header,
8
+ query_excecuter_postgres,
9
+ store_hashed_password,
10
+ verify_password,
11
+ is_pswrd_flag_set,
12
+ )
13
+ import psycopg2
14
+
15
+ #
16
+ import bcrypt
17
+ import re
18
+ import time
19
+ import random
20
+ import string
21
+ from log_application import log_message
22
+
23
+ # setting page config
24
+ load_local_css("styles.css")
25
+ set_header()
26
+
27
+ # schema=db_cred['schema']
28
+
29
+ db_cred = None
30
+
31
+
32
+ def fetch_password_hash_key():
33
+ query = f"""
34
+ SELECT emp_id
35
+ FROM mmo_users
36
+ WHERE emp_nam ='admin';
37
+ """
38
+ hashkey = query_excecuter_postgres(query, db_cred, insert=False)
39
+
40
+ return hashkey
41
+
42
+
43
+ def fetch_users_with_access():
44
+ # Query to get allowed employee IDs for the project
45
+ unique_users_query = f"""
46
+ SELECT DISTINCT emp_id, emp_nam, emp_typ
47
+ FROM mmo_users;
48
+ """
49
+ try:
50
+ unique_users_result = query_excecuter_postgres(
51
+ unique_users_query, db_cred, insert=False
52
+ )
53
+
54
+ unique_users_result_df = pd.DataFrame(unique_users_result)
55
+ unique_users_result_df.columns = ["User ID", "User Name", "User Type"]
56
+
57
+ st.session_state["df_users"] = unique_users_result_df
58
+ except:
59
+ st.session_state["df_users"] = pd.DataFrame()
60
+
61
+
62
+ def add_user(emp_id, emp_name, emp_type, admin_id, plain_text_password):
63
+ """
64
+ Adds a new user to the mmo_users table if the user does not already exist.
65
+
66
+ Parameters:
67
+ emp_id (str): Employee ID of the user.
68
+ emp_name (str): Name of the user.
69
+ emp_type (str): Type of the user.
70
+ db_cred (dict): Database credentials with keys 'dbname', 'user', 'password', 'host', 'port'.
71
+ schema (str): The schema name.
72
+
73
+ Returns:
74
+ str: A success or error message.
75
+ """
76
+ # Query to check if the user already exists
77
+ check_user_query = f"""
78
+ SELECT emp_id FROM mmo_users
79
+ WHERE emp_id = ?;
80
+ """
81
+ params_check = (emp_id,)
82
+
83
+ try:
84
+ # Check if the user already exists
85
+ existing_user = query_excecuter_postgres(
86
+ check_user_query, db_cred, params=params_check, insert=False
87
+ )
88
+ # st.write(existing_user)
89
+
90
+ if existing_user:
91
+ # If user exists, return a warning message
92
+ return False
93
+
94
+ # Query to add a new user to the mmo_users table
95
+ else:
96
+ hashed_password = bcrypt.hashpw(
97
+ plain_text_password.encode("utf-8"), bcrypt.gensalt()
98
+ )
99
+
100
+ # Convert the byte string to a regular string for storage
101
+ hashed_password_str = hashed_password.decode("utf-8")
102
+ add_user_query = f"""
103
+ INSERT INTO mmo_users
104
+ (emp_id, emp_nam, emp_typ, crte_dt_tm, crte_by_uid,pswrd_key)
105
+ VALUES (?, ?, ?,datetime('now'), ?,?);
106
+ """
107
+ params_add = (emp_id, emp_name, emp_type, admin_id, hashed_password_str)
108
+
109
+ # Execute the insert query
110
+ query_excecuter_postgres(
111
+ add_user_query, db_cred, params=params_add, insert=True
112
+ )
113
+
114
+ return True
115
+
116
+ except Exception as e:
117
+ return False
118
+
119
+
120
+ # def delete_users_by_names(emp_id):
121
+
122
+ # # Sanitize and format the list of names for the SQL query
123
+ # formatted_ids =tuple(emp_id)
124
+
125
+ # # Query to delete users based on the employee names
126
+ # delete_user_query = f"""
127
+ # DELETE FROM mmo_users
128
+ # WHERE emp_id IN ?;
129
+ # """
130
+ # try:
131
+ # # Execute the delete query
132
+ # query_excecuter_postgres(delete_user_query, db_cred,params=(formatted_ids,), insert=True)
133
+
134
+ # return f"{len(emp_id)} users deleted successfully."
135
+
136
+ # except Exception as e:
137
+ # st.write(e)
138
+ # print(e)
139
+ # return f"An error occurred: {e}"
140
+
141
+
142
+ def delete_users_by_names(emp_ids):
143
+ """
144
+ Deletes users from the mmo_users table and their associated projects from the mmo_project_meta_data
145
+ and mmo_projects tables.
146
+
147
+ Parameters:
148
+ emp_ids (list): A list of employee IDs to delete.
149
+
150
+ Returns:
151
+ str: A success or error message.
152
+ """
153
+ # Convert the list of employee IDs into a tuple for the SQL query
154
+ formatted_ids = tuple(emp_ids)
155
+
156
+ # Queries to delete projects and users
157
+ delete_projects_meta_query = """
158
+ DELETE FROM mmo_project_meta_data
159
+ WHERE prj_id IN (
160
+ SELECT prj_id FROM mmo_projects
161
+ WHERE prj_ownr_id IN ({})
162
+ );
163
+ """.format(
164
+ ",".join("?" * len(formatted_ids))
165
+ )
166
+
167
+ delete_projects_query = """
168
+ DELETE FROM mmo_projects
169
+ WHERE prj_ownr_id IN ({});
170
+ """.format(
171
+ ",".join("?" * len(formatted_ids))
172
+ )
173
+
174
+ delete_user_query = """
175
+ DELETE FROM mmo_users
176
+ WHERE emp_id IN ({});
177
+ """.format(
178
+ ",".join("?" * len(formatted_ids))
179
+ )
180
+
181
+ try:
182
+ # Execute the delete queries using query_excecuter_postgres
183
+ query_excecuter_postgres(
184
+ delete_projects_meta_query, params=formatted_ids, insert=True
185
+ )
186
+ query_excecuter_postgres(
187
+ delete_projects_query, params=formatted_ids, insert=True
188
+ )
189
+ query_excecuter_postgres(delete_user_query, params=formatted_ids, insert=True)
190
+
191
+ return (
192
+ f"{len(emp_ids)} users and their associated projects deleted successfully."
193
+ )
194
+
195
+ except Exception as e:
196
+ return f"An error occurred: {e}"
197
+
198
+
199
+ def reset_user_name():
200
+ st.session_state.user_id = ""
201
+ st.session_state.user_name = ""
202
+
203
+
204
+ def contains_sql_keywords_check(user_input):
205
+
206
+ sql_keywords = [
207
+ "SELECT",
208
+ "INSERT",
209
+ "UPDATE",
210
+ "DELETE",
211
+ "DROP",
212
+ "ALTER",
213
+ "CREATE",
214
+ "GRANT",
215
+ "REVOKE",
216
+ "UNION",
217
+ "JOIN",
218
+ "WHERE",
219
+ "HAVING",
220
+ "EXEC",
221
+ "TRUNCATE",
222
+ "REPLACE",
223
+ "MERGE",
224
+ "DECLARE",
225
+ "SHOW",
226
+ "FROM",
227
+ ]
228
+
229
+ pattern = "|".join(re.escape(keyword) for keyword in sql_keywords)
230
+
231
+ return re.search(pattern, user_input, re.IGNORECASE)
232
+
233
+
234
+ def update_password_and_flag(user_id, plain_text_password, flag=False):
235
+ """
236
+ Hashes a plain text password using bcrypt, converts it to a UTF-8 string, and stores it as text.
237
+
238
+ Parameters:
239
+ plain_text_password (str): The plain text password to be hashed.
240
+ db_cred (dict): The database credentials including dbname, user, password, host, and port.
241
+ """
242
+ # Hash the plain text password
243
+ hashed_password = bcrypt.hashpw(
244
+ plain_text_password.encode("utf-8"), bcrypt.gensalt()
245
+ )
246
+
247
+ # Convert the byte string to a regular string for storage
248
+ hashed_password_str = hashed_password.decode("utf-8")
249
+
250
+ # SQL query to update the pswrd_key for the specified user_id
251
+ if flag:
252
+ query = f"""
253
+ UPDATE mmo_users
254
+ SET pswrd_key = ?, pswrd_flag = 0
255
+ WHERE emp_id = ?;
256
+ """
257
+ else:
258
+ query = f"""
259
+ UPDATE mmo_users
260
+ SET pswrd_key = ?
261
+ WHERE emp_id = ?;
262
+ """
263
+ # Execute the query using the existing query_excecuter_postgres function
264
+ query_excecuter_postgres(
265
+ query=query, db_cred=db_cred, params=(hashed_password_str, user_id), insert=True
266
+ )
267
+
268
+
269
+ def reset_pass_key():
270
+ st.session_state["pass_key"] = ""
271
+
272
+
273
+ st.title("User Management")
274
+
275
+ # hashed_key=fetch_password_hash_key()[0][0]
276
+
277
+ if "df_users" not in st.session_state:
278
+ fetch_users_with_access()
279
+ # st.write(hashed_key.encode())
280
+ if "unique_ids_admin" not in st.session_state:
281
+ unique_users_query = f"""
282
+ SELECT DISTINCT emp_id, emp_nam, emp_typ from mmo_users
283
+ Where emp_typ= 'admin';
284
+ """
285
+ unique_users_result = query_excecuter_postgres(
286
+ unique_users_query, db_cred, insert=False
287
+ ) # retrieves all the users who have access to the MMO tool
288
+ st.session_state["unique_ids_admin"] = {
289
+ emp_id: (emp_nam, emp_type) for emp_id, emp_nam, emp_type in unique_users_result
290
+ }
291
+
292
+
293
+ if len(st.session_state["unique_ids_admin"]) == 0:
294
+
295
+ st.error("No admin found in the database!")
296
+ st.markdown(
297
+ """
298
+ - You can add an admin to the database.
299
+ - **Ensure** that you store the passkey securely, as losing it will prevent anyone from accessing the tool.
300
+ - Once done, reset the password on the "**Home**" page.
301
+ """
302
+ )
303
+
304
+ emp_id = st.text_input("employee id", key="emp1111").lower()
305
+ password = st.text_input("Enter password to access the page", type="password")
306
+
307
+
308
+ if emp_id not in st.session_state["unique_ids_admin"] and len(emp_id) > 4:
309
+
310
+ if not is_pswrd_flag_set(emp_id):
311
+ st.warning("Reset password in Home page to continue")
312
+ st.stop()
313
+
314
+ st.warning(
315
+ 'Incorrect username or password. If you are using the default password, please reset it on the "Home" page.'
316
+ )
317
+
318
+ st.stop()
319
+
320
+
321
+ # st.write(st.session_state["unique_ids_admin"])
322
+
323
+
324
+ # Check password or if no admin is present
325
+
326
+ if verify_password(emp_id, password) or len(st.session_state["unique_ids_admin"]) == 0:
327
+
328
+ st.success("Access Granted", icon="✅")
329
+
330
+ st.header("All Users")
331
+
332
+ st.dataframe(
333
+ st.session_state["df_users"], use_container_width=True, hide_index=True
334
+ )
335
+
336
+ user_manage = st.radio(
337
+ "Select an Option",
338
+ ["Add New User", "Delete Users", "Manage User Passwords"],
339
+ horizontal=True,
340
+ )
341
+
342
+ if user_manage == "Add New User":
343
+ with st.expander("", expanded=True):
344
+ add_col = st.columns(3)
345
+
346
+ with add_col[0]:
347
+ user_id = st.text_input("Enter User ID", key="user_id").lower()
348
+
349
+ with add_col[1]:
350
+ user_name = st.text_input("Enter User Name", key="user_name").lower()
351
+
352
+ with add_col[2]:
353
+
354
+ if len(st.session_state["unique_ids_admin"]) == 0:
355
+ user_types_options = ["Admin"]
356
+ else:
357
+ user_types_options = ["Data Scientist", "Media Planner", "Admin"]
358
+
359
+ user_type = (
360
+ st.selectbox("Select User Type", user_types_options)
361
+ .replace(" ", "_")
362
+ .lower()
363
+ )
364
+ warning_box = st.empty()
365
+
366
+ if "passkey" not in st.session_state:
367
+ st.session_state["passkey"] = ""
368
+ if len(user_id) < 3 or len(user_name) < 3:
369
+ st.session_state["passkey"] = ""
370
+
371
+ pass_key_col = st.columns(3)
372
+ with pass_key_col[0]:
373
+ st.markdown(
374
+ '<div style="display: flex; position:relative; justify-content: flex-end; align-items: center; height: 100%;">'
375
+ '<p style="font-weight: bold; margin: 0; position:absolute; top:4px;">Default Password</p>'
376
+ "</div>",
377
+ unsafe_allow_html=True,
378
+ )
379
+ with pass_key_col[1]:
380
+ passkey_box = st.empty()
381
+ with passkey_box:
382
+ st.code(f"{st.session_state['passkey']}", language="text")
383
+
384
+ st.button(
385
+ "Reset Values", on_click=reset_user_name, use_container_width=True
386
+ )
387
+
388
+ if st.button("Add User", use_container_width=True):
389
+
390
+ if user_id in st.session_state["df_users"]["User ID"].values:
391
+ with warning_box:
392
+ st.warning("User id already exists")
393
+ st.stop()
394
+
395
+ if (
396
+ len(user_id) == 0
397
+ or len(user_name) == 0
398
+ or not user_id.startswith("e")
399
+ or len(user_id) < 6
400
+ ):
401
+ with warning_box:
402
+ st.warning("Enter a Valid User ID and User Name!")
403
+ st.stop()
404
+
405
+ if contains_sql_keywords_check(user_id):
406
+ with warning_box:
407
+ st.warning(
408
+ "Input contains SQL keywords. Please avoid using SQL commands."
409
+ )
410
+ st.stop()
411
+ else:
412
+ pass
413
+
414
+ if not (2 <= len(user_name) <= 50):
415
+ # Store the warning message details in session state
416
+
417
+ with warning_box:
418
+ st.warning(
419
+ "Please provide a valid user name (2-50 characters, only A-Z, a-z, 0-9, and _)."
420
+ )
421
+ st.stop()
422
+
423
+ if contains_sql_keywords_check(user_name):
424
+ with warning_box:
425
+ st.warning(
426
+ "Input contains SQL keywords. Please avoid using SQL commands."
427
+ )
428
+ st.stop()
429
+ else:
430
+ pass
431
+
432
+ characters = string.ascii_letters + string.digits # Letters and digits
433
+ plain_text_password = "".join(
434
+ random.choice(characters) for _ in range(10)
435
+ )
436
+ st.session_state["passkey"] = plain_text_password
437
+
438
+ if add_user(user_id, user_name, user_type, emp_id, plain_text_password):
439
+ with st.spinner("Adding New User"):
440
+
441
+ update_password_and_flag(user_id, plain_text_password)
442
+ fetch_users_with_access()
443
+ st.success("User added successfully")
444
+ st.rerun()
445
+ else:
446
+ st.warning(
447
+ f"User with emp_id {user_id} already exists in the database."
448
+ )
449
+ st.stop()
450
+
451
+ if user_manage == "Delete Users":
452
+ with st.expander("", expanded=True):
453
+ user_names_to_delete = st.multiselect(
454
+ "Select User IDS to Delete", st.session_state["df_users"]["User ID"]
455
+ )
456
+
457
+ if st.button("Delete", use_container_width=True):
458
+ delete_users_by_names(user_names_to_delete)
459
+ st.success("Users Deleted")
460
+ fetch_users_with_access()
461
+ st.rerun()
462
+
463
+ if user_manage == "Manage User Passwords":
464
+
465
+ with st.expander("**Manage User Passwords**", expanded=True):
466
+ st.markdown(
467
+ """
468
+ 1. Click on "**Reset Password**" to generate a new default password for the selected user.
469
+ 2. Share the passkey with the corresponding user.
470
+ """,
471
+ unsafe_allow_html=True,
472
+ )
473
+
474
+ user_id = st.selectbox(
475
+ "Select A User ID",
476
+ st.session_state["df_users"]["User ID"],
477
+ on_change=reset_pass_key,
478
+ )
479
+
480
+ if "pass_key" not in st.session_state:
481
+ st.session_state["pass_key"] = ""
482
+
483
+ default_passkey = st.code(
484
+ f"Default Password: {st.session_state['pass_key']}", language="text"
485
+ )
486
+
487
+ if st.button("Reset Password", use_container_width=True, key="reset_pass_button_key"):
488
+ with st.spinner("Resetting Password"):
489
+ characters = (
490
+ string.ascii_letters + string.digits
491
+ ) # Letters and digits
492
+ plain_text_password = "".join(
493
+ random.choice(characters) for _ in range(10)
494
+ )
495
+ st.session_state["pass_key"] = plain_text_password
496
+
497
+ update_password_and_flag(user_id, plain_text_password, flag=True)
498
+ time.sleep(1)
499
+ del st.session_state["reset_pass_button_key"]
500
+ st.rerun()
501
+
502
+ elif len(password):
503
+ st.warning("Wrong user name or password!")
pages/1_Data_Import.py ADDED
@@ -0,0 +1,1213 @@
1
+ # Importing necessary libraries
2
+ import streamlit as st
3
+
4
+ st.set_page_config(
5
+ page_title="Data Import",
6
+ page_icon="⚖️",
7
+ layout="wide",
8
+ initial_sidebar_state="collapsed",
9
+ )
10
+
11
+
12
+ import re
13
+ import sys
14
+ import pickle
15
+ import numbers
16
+ import traceback
17
+ import pandas as pd
18
+ from scenario import numerize
19
+ from post_gres_cred import db_cred
20
+ from collections import OrderedDict
21
+ from log_application import log_message
22
+ from utilities import set_header, load_local_css, update_db, project_selection
23
+ from constants import (
24
+ upload_rows_limit,
25
+ upload_column_limit,
26
+ word_length_limit_lower,
27
+ word_length_limit_upper,
28
+ minimum_percent_overlap,
29
+ minimum_row_req,
30
+ percent_drop_col_threshold,
31
+ )
32
+
33
+ schema = db_cred["schema"]
34
+ load_local_css("styles.css")
35
+ set_header()
36
+
37
+
38
+ # Initialize project name session state
39
+ if "project_name" not in st.session_state:
40
+ st.session_state["project_name"] = None
41
+
42
+ # Fetch project dictionary
43
+ if "project_dct" not in st.session_state:
44
+ project_selection()
45
+ st.stop()
46
+
47
+ # Display Username and Project Name
48
+ if "username" in st.session_state and st.session_state["username"] is not None:
49
+
50
+ cols1 = st.columns([2, 1])
51
+
52
+ with cols1[0]:
53
+ st.markdown(f"**Welcome {st.session_state['username']}**")
54
+ with cols1[1]:
55
+ st.markdown(f"**Current Project: {st.session_state['project_name']}**")
56
+
57
+
58
+ # Initialize session state keys
59
+ if "granularity_selection_key" not in st.session_state:
60
+ st.session_state["granularity_selection_key"] = st.session_state["project_dct"][
61
+ "data_import"
62
+ ]["granularity_selection"]
63
+
64
+
65
+ # Function to format name
66
+ def name_format_func(name):
67
+ return str(name).strip().title()
68
+
69
+
70
+ # Function to get columns with specified prefix and remove prefix
71
+ def get_columns_with_prefix(df, prefix):
72
+ return [
73
+ col.replace(prefix, "")
74
+ for col in df.columns
75
+ if col.startswith(prefix) and str(col) != str(prefix)
76
+ ]
77
+
78
+
79
+ # Function to fetch columns info
80
+ @st.cache_data(show_spinner=False)
81
+ def fetch_columns(gold_layer_df, data_upload_df):
82
+ # Get lists of columns starting with 'spends_' and 'response_metric_' from gold_layer_df
83
+ spends_columns_gold_layer = get_columns_with_prefix(gold_layer_df, "spends_")
84
+ response_metric_columns_gold_layer = get_columns_with_prefix(
85
+ gold_layer_df, "response_metric_"
86
+ )
87
+
88
+ # Get lists of columns starting with 'spends_' and 'response_metric_' from data_upload_df
89
+ spends_columns_upload = get_columns_with_prefix(data_upload_df, "spends_")
90
+ response_metric_columns_upload = get_columns_with_prefix(
91
+ data_upload_df, "response_metric_"
92
+ )
93
+
94
+ # Combine lists from both DataFrames
95
+ spends_columns = spends_columns_gold_layer + spends_columns_upload
96
+ # Remove 'total' from the spends_columns list if it exists
97
+ spends_columns = list(
98
+ set([col for col in spends_columns if not col.endswith("_total")])
99
+ )
100
+
101
+ response_metric_columns = (
102
+ response_metric_columns_gold_layer + response_metric_columns_upload
103
+ )
104
+ # Filter columns ending with '_total' and remove the '_total' suffix
105
+ response_metric_columns = list(
106
+ set(
107
+ [
108
+ col[:-6]
109
+ for col in response_metric_columns
110
+ if col.endswith("_total") and len(col[:-6]) != 0
111
+ ]
112
+ )
113
+ )
114
+
115
+ # Get list of all columns from both DataFrames
116
+ gold_layer_columns = list(gold_layer_df.columns)
117
+ data_upload_columns = list(data_upload_df.columns)
118
+
119
+ # Combine all columns and get unique columns
120
+ all_columns = list(set(gold_layer_columns + data_upload_columns))
121
+
122
+ return (
123
+ spends_columns,
124
+ response_metric_columns,
125
+ all_columns,
126
+ gold_layer_columns,
127
+ data_upload_columns,
128
+ )
129
+
130
+
131
+ # Function to format values for display
132
+ @st.cache_data(show_spinner=False)
133
+ def format_values_for_display(values_list):
134
+ # Format value
135
+ formatted_list = [value.lower().strip() for value in values_list]
136
+ # Join values with commas and 'and' before the last value
137
+ if len(formatted_list) > 1:
138
+ return ", ".join(formatted_list[:-1]) + ", and " + formatted_list[-1]
139
+ elif formatted_list:
140
+ return formatted_list[0]
141
+ return "No values available"
142
+
143
+
144
+ # Function to validate input DataFrame
145
+ @st.cache_data(show_spinner=False)
146
+ def valid_input_df(
147
+ df,
148
+ spends_columns,
149
+ response_metric_columns,
150
+ total_columns,
151
+ gold_layer_columns,
152
+ data_upload_columns,
153
+ ):
154
+ # Check if DataFrame is empty
155
+ if df.empty or len(df) < 1:
156
+ return (True, None)
157
+
158
+ # Check for invalid column names
159
+ invalid_columns = [
160
+ col
161
+ for col in df.columns
162
+ if not re.match(r"^[A-Za-z0-9_]+$", col)
163
+ or not (word_length_limit_lower <= len(col) <= word_length_limit_upper)
164
+ ]
165
+ if invalid_columns:
166
+ return (
167
+ False,
168
+ f"Invalid column names: {format_values_for_display(invalid_columns)}. Use only letters, numbers, and underscores. Column name length should be {word_length_limit_lower} to {word_length_limit_upper} characters long.",
169
+ )
170
+
171
+ # Ensure 'panel' column values are strings and conform to specified pattern and length
172
+ if "panel" in df.columns:
173
+ df["panel"] = df["panel"].astype(str).str.strip()
174
+ invalid_panel_values = [
175
+ val
176
+ for val in df["panel"].unique()
177
+ if not re.match(r"^[A-Za-z0-9_]+$", val)
178
+ or not (word_length_limit_lower <= len(val) <= word_length_limit_upper)
179
+ ]
180
+ if invalid_panel_values:
181
+ return (
182
+ False,
183
+ f"Invalid panel values: {format_values_for_display(invalid_panel_values)}. Use only letters, numbers, and underscores. Panel name length should be {word_length_limit_lower} to {word_length_limit_upper} characters long.",
184
+ )
185
+
186
+ # Check for missing required columns
187
+ required_columns = ["date", "panel"]
188
+ missing_columns = [col for col in required_columns if col not in df.columns]
189
+ if missing_columns:
190
+ return (
191
+ False,
192
+ f"Missing compulsory columns: {format_values_for_display(missing_columns)}.",
193
+ )
194
+
195
+ # Check if all other columns are numeric
196
+ non_numeric_columns = [
197
+ col
198
+ for col in df.columns
199
+ if col not in required_columns and not pd.api.types.is_numeric_dtype(df[col])
200
+ ]
201
+ if non_numeric_columns:
202
+ return (
203
+ False,
204
+ f"Non-numeric columns: {format_values_for_display(non_numeric_columns)}. All columns except {format_values_for_display(required_columns)} should be numeric.",
205
+ )
206
+
207
+ # Ensure all columns in data_upload_columns are unique
208
+ duplicate_columns_in_upload = [
209
+ col for col in data_upload_columns if data_upload_columns.count(col) > 1
210
+ ]
211
+ if duplicate_columns_in_upload:
212
+ return (
213
+ False,
214
+ f"Duplicate columns found in the uploaded data: {format_values_for_display(set(duplicate_columns_in_upload))}.",
215
+ )
216
+
217
+ # Convert 'date' column to datetime format
218
+ try:
219
+ df["date"] = pd.to_datetime(df["date"], format="%Y-%m-%d")
220
+ except:
221
+ return False, "The 'date' column is not in the correct format 'YYYY-MM-DD'."
222
+
223
+ # Check date frequency
224
+ unique_panels = df["panel"].unique()
225
+ for panel in unique_panels:
226
+ date_diff = df[df["panel"] == panel]["date"].diff().dropna()
227
+ if not (
228
+ (date_diff == pd.Timedelta(days=1)).all()
229
+ or (date_diff == pd.Timedelta(weeks=1)).all()
230
+ ):
231
+ return False, "The 'date' column does not have a daily or weekly frequency."
232
+
233
+ # Check for null values in 'date' or 'panel' columns
234
+ if df[required_columns].isnull().any().any():
235
+ return (
236
+ False,
237
+ f"The {format_values_for_display(required_columns)} should not contain null values.",
238
+ )
239
+
240
+ # Check for panels with less than 1% date overlap
241
+ if not gold_layer_df.empty:
242
+ panels_with_low_overlap = []
243
+ unique_panels = list(
244
+ set(df["panel"].unique()).union(set(gold_layer_df["panel"].unique()))
245
+ )
246
+ for panel in unique_panels:
247
+ gold_layer_dates = set(
248
+ gold_layer_df[gold_layer_df["panel"] == panel]["date"]
249
+ )
250
+ data_upload_dates = set(df[df["panel"] == panel]["date"])
251
+ if gold_layer_dates and data_upload_dates:
252
+ overlap = len(gold_layer_dates & data_upload_dates) / len(
253
+ gold_layer_dates | data_upload_dates
254
+ )
255
+ else:
256
+ overlap = 0
257
+ if overlap < (minimum_percent_overlap / 100):
258
+ panels_with_low_overlap.append(panel)
259
+
260
+ if panels_with_low_overlap:
261
+ return (
262
+ False,
263
+ f"Date columns in the gold layer and uploaded data do not have at least {minimum_percent_overlap}% overlap for panels: {format_values_for_display(panels_with_low_overlap)}.",
264
+ )
265
+
266
+ # Check if spends_columns is less than two
267
+ if len(spends_columns) < 2:
268
+ return False, "Please add at least two spends columns."
269
+
270
+ # Check if response_metric_columns is empty
271
+ if len(response_metric_columns) < 1:
272
+ return False, "Please add response metric columns."
273
+
274
+ # Check if all numeric columns are positive except those starting with 'exogenous_' or 'internal_'
275
+ valid_prefixes = ["exogenous_", "internal_"]
276
+ negative_values_columns = [
277
+ col
278
+ for col in df.select_dtypes(include=[float, int]).columns
279
+ if not any(col.startswith(prefix) for prefix in valid_prefixes)
280
+ and (df[col] < 0).any()
281
+ ]
282
+ if negative_values_columns:
283
+ return (
284
+ False,
285
+ f"Negative values detected in columns: {format_values_for_display(negative_values_columns)}. Ensure all media and response metric columns are positive.",
286
+ )
287
+
288
+ # Check for unassociated columns
289
+ detected_channels = spends_columns + ["total"]
290
+ unassociated_columns = []
291
+ for col in df.columns:
292
+ if (col.startswith("_") or col.endswith("_")) or not (
293
+ col.startswith("exogenous_") # Column starts with "exogenous_"
294
+ or col.startswith("internal_") # Column starts with "internal_"
295
+ or any(
296
+ col == f"spends_{channel}" for channel in spends_columns
297
+ ) # Column is not in the format "spends_<channel>"
298
+ or any(
299
+ col == f"response_metric_{metric}_{channel}"
300
+ for metric in response_metric_columns
301
+ for channel in detected_channels
302
+ ) # Column is not in the format "response_metric_<metric>_<channel>"
303
+ or any(
304
+ col.startswith("media_")
305
+ and col.endswith(f"_{channel}")
306
+ and len(col) > len(f"media__{channel}")
307
+ for channel in spends_columns
308
+ ) # Column is not in the format "media_<media_variable_name>_<channel>"
309
+ or col in ["date", "panel"]
310
+ ):
311
+ unassociated_columns.append(col)
312
+
313
+ if unassociated_columns:
314
+ return (
315
+ False,
316
+ f"Columns with incorrect format detected: {format_values_for_display(unassociated_columns)}.",
317
+ )
318
+
319
+ return True, "The data is valid and meets all requirements."
320
+
321
+
322
+ # Function to load the uploaded Excel file into a DataFrame
323
+ @st.cache_data(show_spinner=False)
324
+ def load_and_transform_data(uploaded_file):
325
+ # Load the uploaded file into a DataFrame if a file is uploaded
326
+ if uploaded_file is not None:
327
+ df = pd.read_excel(uploaded_file)
328
+ else:
329
+ df = pd.DataFrame()
330
+ return df
331
+
332
+ # Check if DataFrame exceeds row and column limits
333
+ if len(df) > upload_rows_limit or len(df.columns) > upload_column_limit:
334
+ st.warning(
335
+ f"Data exceeds the row limit of {numerize(upload_rows_limit)} or column limit of {numerize(upload_column_limit)}. Please upload a smaller file.",
336
+ icon="⚠️",
337
+ )
338
+
339
+ # Log message
340
+ log_message(
341
+ "warning",
342
+ f"Data exceeds the row limit of {numerize(upload_rows_limit)} or column limit of {numerize(upload_column_limit)}. Please upload a smaller file.",
343
+ "Data Import",
344
+ )
345
+
346
+ return pd.DataFrame()
347
+
348
+ # If the DataFrame contains only 'panel' and 'date' columns, return empty DataFrame
349
+ if set(df.columns) == {"date", "panel"}:
350
+ return pd.DataFrame()
351
+
352
+ # Transform column names: lower, strip start and end, replace spaces with _
353
+ df.columns = [str(col).strip().lower().replace(" ", "_") for col in df.columns]
354
+
355
+ # If 'panel' column exists, clean its values
356
+ try:
357
+ if "panel" in df.columns:
358
+ df["panel"] = (
359
+ df["panel"].astype(str).str.lower().str.strip().str.replace(" ", "_")
360
+ )
361
+ except Exception:
362
+ return df
363
+
364
+ try:
365
+ df["date"] = pd.to_datetime(df["date"], format="%Y-%m-%d")
366
+ except Exception:
367
+ # The 'date' column is not in the correct format 'YYYY-MM-DD'
368
+ return df
369
+
370
+ # Check date frequency and convert to daily if needed
371
+ date_diff = df["date"].diff().dropna()
372
+ if (date_diff == pd.Timedelta(days=1)).all():
373
+ # Data is already at daily level
374
+ return df
375
+ elif (date_diff == pd.Timedelta(weeks=1)).all():
376
+ # Data is at weekly level, convert to daily
377
+ weekly_data = df.copy()
378
+ daily_data = []
379
+
380
+ for index, row in weekly_data.iterrows():
381
+ week_start = row["date"] - pd.to_timedelta(row["date"].weekday(), unit="D")
382
+ for i in range(7):
383
+ daily_date = week_start + pd.DateOffset(days=i)
384
+ new_row = row.copy()
385
+ new_row["date"] = daily_date
386
+ for col in df.columns:
387
+ if isinstance(new_row[col], numbers.Number):
388
+ new_row[col] = new_row[col] / 7
389
+ daily_data.append(new_row)
390
+
391
+ daily_data_df = pd.DataFrame(daily_data)
392
+ return daily_data_df
393
+ else:
394
+ # The 'date' column does not have a daily or weekly frequency
395
+ return df
396
+
397
+
398
+ # Function to merge DataFrames if present
399
+ @st.cache_data(show_spinner=False)
400
+ def merge_dataframes(gold_layer_df, data_upload_df):
401
+ if gold_layer_df.empty and data_upload_df.empty:
402
+ return pd.DataFrame()
403
+
404
+ if not gold_layer_df.empty and not data_upload_df.empty:
405
+ # Merge gold_layer_df and data_upload_df on 'panel', and 'date'
406
+ merged_df = pd.merge(
407
+ gold_layer_df,
408
+ data_upload_df,
409
+ on=["panel", "date"],
410
+ how="outer",
411
+ suffixes=("_gold", "_upload"),
412
+ )
413
+
414
+ # Handle duplicate columns
415
+ for col in merged_df.columns:
416
+ if col.endswith("_gold"):
417
+ base_col = col[:-5] # Remove '_gold' suffix
418
+ upload_col = base_col + "_upload" # Column name in data_upload_df
419
+ if upload_col in merged_df.columns:
420
+ # Prefer values from data_upload_df
421
+ merged_df[base_col] = merged_df[upload_col].combine_first(
422
+ merged_df[col]
423
+ )
424
+ merged_df.drop(columns=[col, upload_col], inplace=True)
425
+ else:
426
+ # Rename column to remove the suffix
427
+ merged_df.rename(columns={col: base_col}, inplace=True)
428
+
429
+ elif data_upload_df.empty:
430
+ merged_df = gold_layer_df.copy()
431
+
432
+ elif gold_layer_df.empty:
433
+ merged_df = data_upload_df.copy()
434
+
435
+ return merged_df
436
+
437
+
438
+ # Function to check if all required columns are present in the Uploaded DataFrame
439
+ @st.cache_data(show_spinner=False)
440
+ def check_required_columns(df, detected_channels, detected_response_metric):
441
+ required_columns = []
442
+
443
+ # Add all channels with 'spends_' + detected channel name
444
+ for channel in detected_channels:
445
+ required_columns.append(f"spends_{channel}")
446
+
447
+ # Add all channels with 'response_metric_' + detected channel name
448
+ for response_metric in detected_response_metric:
449
+ for channel in detected_channels + ["total"]:
450
+ required_columns.append(f"response_metric_{response_metric}_{channel}")
451
+
452
+ # Check for missing columns
453
+ missing_columns = [col for col in required_columns if col not in df.columns]
454
+
455
+ # Channel groupings
456
+ no_media_data = []
457
+ channel_columns_dict = {}
458
+ for channel in detected_channels:
459
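+ # Columns tied to this channel: the name contains the channel, ends with it, and is not a response/exogenous/internal column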
+ channel_columns = [
460
+ col
461
+ for col in df.columns
462
+ if channel in col
463
+ and not (
464
+ col.startswith("response_metric_")
465
+ or col.startswith("exogenous_")
466
+ or col.startswith("internal_")
467
+ )
468
+ and col.endswith(channel)
469
+ ]
470
+ channel_columns_dict[channel] = channel_columns
471
+
472
+ if len(channel_columns) <= 1:
473
+ no_media_data.append(channel)
474
+
475
+ return missing_columns, no_media_data, channel_columns_dict
476
+
477
+
478
+ # Function to prepare tool DataFrame
479
+ def prepare_tool_df(merged_df, granularity_selection):
480
+ # Drop all response metric columns that do not end with '_total'
481
+ cols_to_drop = [
482
+ col
483
+ for col in merged_df.columns
484
+ if col.startswith("response_metric_") and not col.endswith("_total")
485
+ ]
486
+
487
+ # Create a DataFrame to be used for the tool
488
+ tool_df = merged_df.drop(columns=cols_to_drop)
489
+
490
+ # Convert to weekly granularity by aggregating all data for given panel and week
491
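+ # pd.Grouper(freq="W-MON", closed="left", label="left") buckets rows into Monday-start weeks labeled by that Monday, consistent with the upload guideline above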
+ if granularity_selection.lower() == "weekly":
492
+ tool_df.set_index("date", inplace=True)
493
+ tool_df = (
494
+ tool_df.groupby(
495
+ [pd.Grouper(freq="W-MON", closed="left", label="left"), "panel"]
496
+ )
497
+ .sum()
498
+ .reset_index()
499
+ )
500
+
501
+ return tool_df
502
+
503
+
504
+ # Function to generate imputation DataFrame
505
+ def generate_imputation_df(tool_df):
506
+ # Initialize lists to store the column details
507
+ column_names = []
508
+ categories = []
509
+ missing_values_info = []
510
+ zero_values_info = []
511
+ imputation_methods = []
512
+
513
+ # Define the function to calculate the percentage of missing values
514
+ def calculate_missing_percentage(series):
515
+ return series.isnull().sum(), (series.isnull().mean() * 100)
516
+
517
+ # Define the function to calculate the percentage of zero values
518
+ def calculate_zero_percentage(series):
519
+ return (series == 0).sum(), ((series == 0).mean() * 100)
520
+
521
+ # Iterate over each column to categorize and calculate missing and zero values
522
+ for col in tool_df.columns:
523
+ # Determine category based on column name prefix
524
+ if col == "date" or col == "panel":
525
+ continue
526
+ elif col.startswith("response_metric_"):
527
+ categories.append("Response Metrics")
528
+ elif col.startswith("spends_"):
529
+ categories.append("Spends")
530
+ elif col.startswith("exogenous_"):
531
+ categories.append("Exogenous")
532
+ elif col.startswith("internal_"):
533
+ categories.append("Internal")
534
+ else:
535
+ categories.append("Media")
536
+
537
+ # Calculate missing values and percentage
538
+ missing_count, missing_percentage = calculate_missing_percentage(tool_df[col])
539
+ missing_values_info.append(f"{missing_count} ({missing_percentage:.1f}%)")
540
+
541
+ # Calculate zero values and percentage
542
+ zero_count, zero_percentage = calculate_zero_percentage(tool_df[col])
543
+ zero_values_info.append(f"{zero_count} ({zero_percentage:.1f}%)")
544
+
545
+ # Determine default imputation method based on conditions
546
+ if col.startswith("spends_"):
547
+ imputation_methods.append("Fill with 0")
548
+ elif col.startswith("response_metric_"):
549
+ imputation_methods.append("Fill with Mean")
550
+ elif zero_percentage + missing_percentage > percent_drop_col_threshold:
551
+ imputation_methods.append("Drop Column")
552
+ else:
553
+ imputation_methods.append("Fill with Mean")
554
+
555
+ column_names.append(col)
556
+
557
+ # Create the DataFrame
558
+ imputation_df = pd.DataFrame(
559
+ {
560
+ "Column Name": column_names,
561
+ "Category": categories,
562
+ "Missing Values": missing_values_info,
563
+ "Zero Values": zero_values_info,
564
+ "Imputation Method": imputation_methods,
565
+ }
566
+ )
567
+
568
+ # Define the category order for sorting
569
+ category_order = {
570
+ "Response Metrics": 1,
571
+ "Spends": 2,
572
+ "Media": 3,
573
+ "Exogenous": 4,
574
+ "Internal": 5,
575
+ }
576
+
577
+ # Add a temporary column for sorting based on the category order
578
+ imputation_df["Category Order"] = imputation_df["Category"].map(category_order)
579
+
580
+ # Sort the DataFrame based on the category order and then drop the temporary column
581
+ imputation_df = imputation_df.sort_values(
582
+ by=["Category Order", "Column Name"]
583
+ ).drop(columns=["Category Order"])
584
+
585
+ return imputation_df
586
+
587
+
588
+ # Function to perform imputation as per user requests
589
+ def perform_imputation(imputation_df, tool_df):
590
+ # Detect channels associated with spends
591
+ detected_channels = [
592
+ col.replace("spends_", "")
593
+ for col in tool_df.columns
594
+ if col.startswith("spends_")
595
+ ]
596
+
597
+ # Create a dictionary with keys as channels and values as associated columns
598
+ group_dict = {
599
+ channel: [
600
+ col
601
+ for col in tool_df.columns
602
+ if channel in col
603
+ and not (
604
+ col.startswith("response_metric_")
605
+ or col.startswith("exogenous_")
606
+ or col.startswith("internal_")
607
+ )
608
+ ]
609
+ for channel in detected_channels
610
+ }
611
+
612
+ # Create a reverse dictionary with keys as columns and values as channels
613
+ column_to_channel_dict = {
614
+ col: channel for channel, cols in group_dict.items() for col in cols
615
+ }
616
+
617
+ # Perform imputation
618
+ already_dropped = []
619
+ for index, row in imputation_df.iterrows():
620
+ col_name = row["Column Name"]
621
+ impute_method = row["Imputation Method"]
622
+
623
+ # Skip already dropped columns
624
+ if col_name in already_dropped:
625
+ continue
626
+
627
+ # Skip imputation if dropping response metric column and add warning
628
+ if impute_method == "Drop Column" and col_name.startswith("response_metric_"):
629
+ return None, {}, f"Cannot drop response metric column: {col_name}"
630
+
631
+ # Drop column if requested
632
+ if impute_method == "Drop Column":
633
+ # If spends column is dropped, drop all related columns
634
+ if col_name.startswith("spends_"):
635
+ tool_df.drop(
636
+ columns=group_dict[col_name.replace("spends_", "")],
637
+ inplace=True,
638
+ )
639
+ already_dropped += group_dict[col_name.replace("spends_", "")]
640
+ del group_dict[col_name.replace("spends_", "")]
641
+ else:
642
+ tool_df.drop(columns=[col_name], inplace=True)
643
+ if not (
644
+ col_name.startswith("exogenous_")
645
+ or col_name.startswith("internal_")
646
+ ):
647
+ group_name = column_to_channel_dict[col_name]
648
+ group_dict[group_name].remove(col_name)
649
+
650
+ # Check for channels with one or fewer associated columns and add warning if needed
651
+ if len(group_dict[group_name]) <= 1:
652
+ return (
653
+ None,
654
+ {},
655
+ f"No media variable associated with category {col_name.replace('spends_', '')}.",
656
+ )
657
+ continue
658
+
659
+ # Check for each panel
660
+ for panel in tool_df["panel"].unique():
661
+ panel_df = tool_df[tool_df["panel"] == panel]
662
+
663
+ # Check if the column is entirely null or empty for the current panel
664
+ if panel_df[col_name].isnull().all():
665
+ if impute_method in ["Fill with Mean", "Fill with Median"]:
666
+ return (
667
+ None,
668
+ {},
669
+ f"Cannot impute for empty column(s) with mean or median. Select 'Fill with 0'. Details: Panel: {panel}, Column: {col_name}",
670
+ )
671
+
672
+ # Fill missing values as requested
673
+ if impute_method == "Fill with Mean":
674
+ tool_df[col_name] = tool_df.groupby("panel")[col_name].transform(
675
+ lambda x: x.fillna(x.mean())
676
+ )
677
+ elif impute_method == "Fill with Median":
678
+ tool_df[col_name] = tool_df.groupby("panel")[col_name].transform(
679
+ lambda x: x.fillna(x.median())
680
+ )
681
+ elif impute_method == "Fill with 0":
682
+ tool_df[col_name].fillna(0, inplace=True)
683
+
684
+ # Check if final DataFrame has at least one response metric and two spends categories
685
+ response_metrics = [
686
+ col for col in tool_df.columns if col.startswith("response_metric_")
687
+ ]
688
+ spends_categories = [col for col in tool_df.columns if col.startswith("spends_")]
689
+
690
+ if len(response_metrics) < 1:
691
+ return (None, {}, "The final DataFrame must have at least one response metric.")
692
+ if len(spends_categories) < 2:
693
+ return (
694
+ None,
695
+ {},
696
+ "The final DataFrame must have at least two spends categories.",
697
+ )
698
+
699
+ return tool_df, group_dict, "Imputed Successfully!"
700
+
701
+
702
+ # Function to display groups with custom styling
703
+ def display_groups(input_dict):
704
+ # Define custom CSS for pastel light blue rounded rectangle
705
+ custom_css = """
706
+ <style>
707
+ .group-box {
708
+ background-color: #ffdaab;
709
+ border-radius: 10px;
710
+ padding: 10px;
711
+ margin: 5px 0;
712
+ }
713
+ </style>
714
+ """
715
+ st.markdown(custom_css, unsafe_allow_html=True)
716
+
717
+ for group_name, values in input_dict.items():
718
+ group_html = f"<div class='group-box'><strong>{group_name}:</strong> {format_values_for_display(values)}</div>"
719
+ st.markdown(group_html, unsafe_allow_html=True)
720
+
721
+
722
+ # Function to categorize columns and create an ordered dictionary
723
+ def create_ordered_category_dict(df):
724
+ category_dict = {
725
+ "Response Metrics": [],
726
+ "Spends": [],
727
+ "Media": [],
728
+ "Exogenous": [],
729
+ "Internal": [],
730
+ }
731
+
732
+ # Define the category order for sorting
733
+ category_order = {
734
+ "Response Metrics": 1,
735
+ "Spends": 2,
736
+ "Media": 3,
737
+ "Exogenous": 4,
738
+ "Internal": 5,
739
+ }
740
+
741
+ for column in df.columns:
742
+ if column == "date" or column == "panel":
743
+ continue # Skip 'date' and 'panel' columns
744
+
745
+ if column.startswith("response_metric_"):
746
+ category_dict["Response Metrics"].append(column)
747
+ elif column.startswith("spends_"):
748
+ category_dict["Spends"].append(column)
749
+ elif column.startswith("exogenous_"):
750
+ category_dict["Exogenous"].append(column)
751
+ elif column.startswith("internal_"):
752
+ category_dict["Internal"].append(column)
753
+ else:
754
+ category_dict["Media"].append(column)
755
+
756
+ # Sort the dictionary based on the defined category order
757
+ sorted_category_dict = OrderedDict(
758
+ sorted(category_dict.items(), key=lambda item: category_order[item[0]])
759
+ )
760
+
761
+ return sorted_category_dict
762
+
763
+
764
+ try:
765
+ # Page Title
766
+ st.title("Data Import")
767
+
768
+ # Create file uploader
769
+ uploaded_file = st.file_uploader(
770
+ "Upload Data", type=["xlsx"], accept_multiple_files=False
771
+ )
772
+
773
+ # Expander with markdown for upload rules
774
+ with st.expander("Upload Rules and Guidelines"):
775
+ st.markdown(
776
+ """
777
+ ### Upload Guidelines
778
+
779
+ Please ensure your data adheres to the following rules:
780
+
781
+ 1. **File Format**:
782
+ - Upload all data in a single Excel file.
783
+
784
+ 2. **Compulsory Columns**:
785
+ - **Date**: Must be in the format `YYYY-MM-DD` only.
786
+ - **Panel**: If no panel data exists, use `aggregated` as a single panel.
787
+
788
+ 3. **Column Naming Conventions**:
789
+ - All columns should start with the associated category prefix.
790
+
791
+ **Examples**:
792
+
793
+ - **Response Metric Column**:
794
+ - Format: `response_metric_<response_metric_name>_<channel_name>`
795
+ - Example: `response_metric_revenue_facebook`
796
+
797
+ - **Total Response Metric**:
798
+ - Format: `response_metric_<response_metric_name>_total`
799
+ - Example: `response_metric_revenue_total`
800
+
801
+ - **Spend Column**:
802
+ - Format: `spends_<channel_name>`
803
+ - Example: `spends_facebook`
804
+
805
+ - **Media Column**:
806
+ - Format: `media_<media_variable_name>_<channel_name>`
807
+ - Example: `media_clicks_facebook`
808
+
809
+ - **Exogenous Column**:
810
+ - Format: `exogenous_<variable_name>`
811
+ - Example: `exogenous_unemployment_rate`
812
+
813
+ - **Internal Column**:
814
+ - Format: `internal_<variable_name>`
815
+ - Example: `internal_discount`
816
+
817
+ **Notes**:
818
+
819
+ - The `total` response metric should represent the total for a particular date and panel, including all channels and organic contributions.
820
+ - The `date` column for weekly data should be the Monday of that week, representing the data from that Monday to the following Sunday. Example: If the week starts on Monday, August 5th, 2024, and ends on Sunday, August 11th, 2024, the date column for that week should display 2024-08-05.
821
+ """
822
+ )
823
+
824
+ # Upload warning placeholder
825
+ upload_warning_placeholder = st.container()
826
+
827
+ # Load the uploaded file into a DataFrame if a file is uploaded
828
+ data_upload_df = load_and_transform_data(uploaded_file)
829
+
830
+ # Columns for user input
831
+ granularity_col, validate_process_col = st.columns(2)
832
+
833
+ # Dropdown for data granularity
834
+ granularity_selection = granularity_col.selectbox(
835
+ "Select data granularity",
836
+ options=["daily", "weekly"],
837
+ format_func=name_format_func,
838
+ key="granularity_selection_key",
839
+ )
840
+
841
+ # Gold Layer DataFrame
842
+ gold_layer_df = st.session_state["project_dct"]["data_import"]["gold_layer_df"]
843
+ if not gold_layer_df.empty:
844
+ st.subheader("Gold Layer DataFrame")
845
+ with st.expander("Gold Layer DataFrame"):
846
+ st.dataframe(
847
+ gold_layer_df,
848
+ hide_index=True,
849
+ column_config={
850
+ "date": st.column_config.DateColumn("date", format="YYYY-MM-DD")
851
+ },
852
+ )
853
+ else:
854
+ st.info(
855
+ "No gold layer data is selected for this project. Please upload data manually.",
856
+ icon="📊",
857
+ )
858
+
859
+ # Check input data
860
+ with validate_process_col:
861
+ st.write("##") # Padding
862
+
863
+ if validate_process_col.button("Validate and Process", use_container_width=True):
864
+ with st.spinner("Processing ..."):
865
+ # Check if both DataFrames are empty
866
+ valid_input = True
867
+ if gold_layer_df.empty and data_upload_df.empty:
868
+ # If both gold_layer_df and data_upload_df are empty, display a warning and stop the script
869
+ st.warning(
870
+ "Both the Gold Layer data and the uploaded data are empty. Please provide at least one data source.",
871
+ icon="⚠️",
872
+ )
873
+
874
+ # Log message
875
+ log_message(
876
+ "warning",
877
+ "Both the Gold Layer data and the uploaded data are empty. Please provide at least one data source.",
878
+ "Data Import",
879
+ )
880
+ valid_input = False
881
+
882
+ # If the uploaded DataFrame is empty and the Gold Layer is not, swap them to ensure all validation conditions are checked
883
+ elif not gold_layer_df.empty and data_upload_df.empty:
884
+ data_upload_df, gold_layer_df = (
885
+ gold_layer_df.copy(),
886
+ data_upload_df.copy(),
887
+ )
888
+ valid_input = True
889
+
890
+ if valid_input:
891
+ # Fetch all necessary columns list
892
+ (
893
+ spends_columns,
894
+ response_metric_columns,
895
+ total_columns,
896
+ gold_layer_columns,
897
+ data_upload_columns,
898
+ ) = fetch_columns(gold_layer_df, data_upload_df)
899
+
900
+ with upload_warning_placeholder:
901
+ valid_input, message = valid_input_df(
902
+ data_upload_df,
903
+ spends_columns,
904
+ response_metric_columns,
905
+ total_columns,
906
+ gold_layer_columns,
907
+ data_upload_columns,
908
+ )
909
+ if not valid_input:
910
+ st.warning(message, icon="⚠️")
911
+
912
+ # Log message
913
+ log_message("warning", message, "Data Import")
914
+
915
+ # Merge gold_layer_df and data_upload_df on 'panel' and 'date'
916
+ if valid_input:
917
+ merged_df = merge_dataframes(gold_layer_df, data_upload_df)
918
+
919
+ missing_columns, no_media_data, channel_columns_dict = (
920
+ check_required_columns(
921
+ merged_df, spends_columns, response_metric_columns
922
+ )
923
+ )
924
+
925
+ with upload_warning_placeholder:
926
+ # Warning for categories with no media data
927
+ if no_media_data:
928
+ st.warning(
929
+ f"Categories without media data: {format_values_for_display(no_media_data)}. Please upload at least one media column to proceed.",
930
+ icon="⚠️",
931
+ )
932
+ valid_input = False
933
+
934
+ # Log message
935
+ log_message(
936
+ "warning",
937
+ f"Categories without media data: {format_values_for_display(no_media_data)}. Please upload at least one media column to proceed.",
938
+ "Data Import",
939
+ )
940
+
941
+ # Warning for insufficient rows
942
+ elif any(
943
+ granularity_selection == "daily"
944
+ and len(merged_df[merged_df["panel"] == panel])
945
+ < minimum_row_req
946
+ for panel in merged_df["panel"].unique()
947
+ ):
948
+ st.warning(
949
+ f"Insufficient data. Please provide at least {minimum_row_req} days of data for all panel.",
950
+ icon="⚠️",
951
+ )
952
+ valid_input = False
953
+
954
+ # Log message
955
+ log_message(
956
+ "warning",
957
+ f"Insufficient data. Please provide at least {minimum_row_req} days of data for all panel.",
958
+ "Data Import",
959
+ )
960
+
961
+ elif any(
962
+ granularity_selection == "weekly"
963
+ and len(merged_df[merged_df["panel"] == panel])
964
+ < minimum_row_req * 7
965
+ for panel in merged_df["panel"].unique()
966
+ ):
967
+ st.warning(
968
+ f"Insufficient data. Please provide at least {minimum_row_req} weeks of data for all panel.",
969
+ icon="⚠️",
970
+ )
971
+ valid_input = False
972
+
973
+ # Log message
974
+ log_message(
975
+ "warning",
976
+ f"Insufficient data. Please provide at least {minimum_row_req} weeks of data for all panel.",
977
+ "Data Import",
978
+ )
979
+
980
+ # Info for missing columns
981
+ elif missing_columns:
982
+ st.info(
983
+ f"Missing columns: {format_values_for_display(missing_columns)}. Please upload all required columns.",
984
+ icon="💡",
985
+ )
986
+
987
+ if valid_input:
988
+ # Create a copy of the merged DataFrame for dashboard purposes
989
+ dashboard_df = merged_df
990
+
991
+ # Create a DataFrame for tool purposes
992
+ tool_df = prepare_tool_df(merged_df, granularity_selection)
993
+
994
+ # Create Imputation DataFrame
995
+ imputation_df = generate_imputation_df(tool_df)
996
+
997
+ # Save data to project dictionary
998
+ st.session_state["project_dct"]["data_import"][
999
+ "granularity_selection"
1000
+ ] = st.session_state["granularity_selection_key"]
1001
+ st.session_state["project_dct"]["data_import"][
1002
+ "dashboard_df"
1003
+ ] = dashboard_df
1004
+ st.session_state["project_dct"]["data_import"]["tool_df"] = tool_df
1005
+ st.session_state["project_dct"]["data_import"]["unique_panels"] = (
1006
+ tool_df["panel"].unique()
1007
+ )
1008
+ st.session_state["project_dct"]["data_import"][
1009
+ "imputation_df"
1010
+ ] = imputation_df
1011
+
1012
+ # Success message
1013
+ with upload_warning_placeholder:
1014
+ st.success("Processed Successfully!", icon="🗂️")
1015
+ st.toast("Processed Successfully!", icon="🗂️")
1016
+
1017
+ # Log message
1018
+ log_message("info", "Processed Successfully!", "Data Import")
1019
+
1020
+ # Load saved data from project dictionary
1021
+ if st.session_state["project_dct"]["data_import"]["tool_df"] is None:
1022
+ st.stop()
1023
+ else:
1024
+ tool_df = st.session_state["project_dct"]["data_import"]["tool_df"]
1025
+ imputation_df = st.session_state["project_dct"]["data_import"]["imputation_df"]
1026
+ unique_panels = st.session_state["project_dct"]["data_import"]["unique_panels"]
1027
+
1028
+ # Unique Panel
1029
+ st.subheader("Unique Panel")
1030
+
1031
+ # Get unique panels count
1032
+ total_count = len(unique_panels)
1033
+
1034
+ # Define custom CSS for pastel light blue rounded rectangle
1035
+ custom_css = """
1036
+ <style>
1037
+ .panel-box {
1038
+ background-color: #ffdaab;
1039
+ border-radius: 10px;
1040
+ padding: 10px;
1041
+ margin: 0 0;
1042
+ }
1043
+ </style>
1044
+ """
1045
+
1046
+ # Display unique panels with total count
1047
+ st.markdown(custom_css, unsafe_allow_html=True)
1048
+ panel_html = f"<div class='panel-box'><strong>Unique Panels:</strong> {format_values_for_display(unique_panels)}<br><strong>Total Count:</strong> {total_count}</div>"
1049
+ st.markdown(panel_html, unsafe_allow_html=True)
1050
+ st.write("##") # Padding
1051
+
1052
+ # Impute Missing Values
1053
+ st.subheader("Impute Missing Values")
1054
+ edited_imputation_df = st.data_editor(
1055
+ imputation_df,
1056
+ column_config={
1057
+ "Imputation Method": st.column_config.SelectboxColumn(
1058
+ options=[
1059
+ "Drop Column",
1060
+ "Fill with Mean",
1061
+ "Fill with Median",
1062
+ "Fill with 0",
1063
+ ],
1064
+ required=True,
1065
+ default="Fill with 0",
1066
+ ),
1067
+ },
1068
+ column_order=[
1069
+ "Column Name",
1070
+ "Category",
1071
+ "Missing Values",
1072
+ "Zero Values",
1073
+ "Imputation Method",
1074
+ ],
1075
+ disabled=["Column Name", "Category", "Missing Values", "Zero Values"],
1076
+ hide_index=True,
1077
+ use_container_width=True,
1078
+ key="imputation_df_key",
1079
+ )
1080
+
1081
+ # Expander with markdown for imputation rules
1082
+ with st.expander("Impute Missing Values Guidelines"):
1083
+ st.markdown(
1084
+ f"""
1085
+ ### Imputation Guidelines
1086
+
1087
+ Please adhere to the following rules when handling missing values:
1088
+
1089
+ 1. **Default Imputation Strategies**:
1090
+ - **Response Metrics**: Imputed using the **mean** value of the column.
1091
+ - **Spends**: Imputed with **zero** values.
1092
+ - **Media, Exogenous, Internal**: Imputation strategy is **dynamic** based on the data.
1093
+
1094
+ 2. **Drop Threshold**:
1095
+ - If the combined percentage of **zeros** and **null values** in any column exceeds `{percent_drop_col_threshold}%`, the column will be **marked for dropping** by default, which the user can change manually.
1096
+ - **Example**: If `spends_facebook` has more than `{percent_drop_col_threshold}%` of zeros and nulls combined, it will be marked for dropping.
1097
+
1098
+ 3. **Category Generation and Association**:
1099
+ - Categories are automatically generated from the **Spends** columns.
1100
+ - **Example**: The column `spends_facebook` will generate the **facebook** category. This means columns like `spends_facebook`, `media_impression_facebook` and `media_clicks_facebook` will also be associated with this category.
1101
+
1102
+ 4. **Column Association and Imputation**:
1103
+ - Each category must have at least **one Media column** associated with it for imputation to proceed.
1104
+ - **Example**: If the **facebook** category does not have any media columns like `media_impression_facebook`, imputation will not be allowed for that category.
1105
+ - Solution: Either **drop the entire category** if it is empty, or **impute the columns** associated with the category instead of dropping them.
1106
+
1107
+ 5. **Response Metrics and Category Count**:
1108
+ - Dropping **Response Metric** columns is **not allowed** under any circumstances.
1109
+ - At least **two categories** must exist after imputation, or imputation will not proceed.
1110
+ - **Example**: If only **facebook** remains after selection, imputation will be halted.
1111
+
1112
+ **Notes**:
1113
+
1114
+ - The decision to drop a spends column will result in all associated columns being dropped.
1115
+ - **Example**: Dropping `spends_facebook` will also drop all related columns like `media_impression_facebook` and `media_clicks_facebook`.
1116
+ """
1117
+ )
1118
+
1119
+ # Imputation Warning Placeholder
1120
+ imputation_warning_placeholder = st.container()
1121
+
1122
+ # Save the DataFrame and dictionary from the current session
1123
+ if st.button("Impute and Save", use_container_width=True):
1124
+ with st.spinner("Imputing ..."):
1125
+ with imputation_warning_placeholder:
1126
+ # Perform Imputation
1127
+ imputed_tool_df, group_dict, message = perform_imputation(
1128
+ edited_imputation_df.copy(), tool_df.copy()
1129
+ )
1130
+
1131
+ if imputed_tool_df is None:
1132
+ st.warning(message, icon="⚠️")
1133
+
1134
+ # Log message
1135
+ log_message("warning", message, "Data Import")
1136
+
1137
+ else:
1138
+ st.session_state["project_dct"]["data_import"][
1139
+ "imputed_tool_df"
1140
+ ] = imputed_tool_df
1141
+ st.session_state["project_dct"]["data_import"][
1142
+ "imputation_df"
1143
+ ] = edited_imputation_df
1144
+ st.session_state["project_dct"]["data_import"][
1145
+ "group_dict"
1146
+ ] = group_dict
1147
+ st.session_state["project_dct"]["data_import"]["category_dict"] = (
1148
+ create_ordered_category_dict(imputed_tool_df)
1149
+ )
1150
+
1151
+ if imputed_tool_df is not None:
1152
+ # Update DB
1153
+ update_db(
1154
+ prj_id=st.session_state["project_number"],
1155
+ page_nam="Data Import",
1156
+ file_nam="project_dct",
1157
+ pkl_obj=pickle.dumps(st.session_state["project_dct"]),
1158
+ schema=schema,
1159
+ )
1160
+
1161
+ # Success message
1162
+ st.success("Saved Successfully!", icon="💾")
1163
+ st.toast("Saved Successfully!", icon="💾")
1164
+
1165
+ # Log message
1166
+ log_message("info", "Saved Successfully!", "Data Import")
1167
+
1168
+ # Load saved data from project dictionary
1169
+ if st.session_state["project_dct"]["data_import"]["imputed_tool_df"] is None:
1170
+ st.stop()
1171
+ else:
1172
+ imputed_tool_df = st.session_state["project_dct"]["data_import"][
1173
+ "imputed_tool_df"
1174
+ ]
1175
+ group_dict = st.session_state["project_dct"]["data_import"]["group_dict"]
1176
+ category_dict = st.session_state["project_dct"]["data_import"]["category_dict"]
1177
+
1178
+ # Channel Groupings
1179
+ st.subheader("Channel Groupings")
1180
+ display_groups(group_dict)
1181
+ st.write("##") # Padding
1182
+
1183
+ # Variable Categorization
1184
+ st.subheader("Variable Categorization")
1185
+ display_groups(category_dict)
1186
+ st.write("##") # Padding
1187
+
1188
+ # Final DataFrame
1189
+ st.subheader("Final DataFrame")
1190
+ st.dataframe(
1191
+ imputed_tool_df,
1192
+ hide_index=True,
1193
+ column_config={
1194
+ "date": st.column_config.DateColumn("date", format="YYYY-MM-DD")
1195
+ },
1196
+ )
1197
+ st.write("##") # Padding
1198
+
1199
+ except Exception as e:
1200
+ # Capture the error details
1201
+ exc_type, exc_value, exc_traceback = sys.exc_info()
1202
+ error_message = "".join(
1203
+ traceback.format_exception(exc_type, exc_value, exc_traceback)
1204
+ )
1205
+
1206
+ # Log message
1207
+ log_message("error", f"An error occurred: {error_message}.", "Data Import")
1208
+
1209
+ # Display a warning message
1210
+ st.warning(
1211
+ "Oops! Something went wrong. Please try refreshing the tool or creating a new project.",
1212
+ icon="⚠️",
1213
+ )
pages/2_Data_Assessment.py ADDED
@@ -0,0 +1,467 @@
1
+ import streamlit as st
2
+ import pandas as pd
3
+ from data_analysis import *
4
+ import numpy as np
5
+ import pickle
6
+ import streamlit as st
7
+ from utilities import set_header, load_local_css, update_db, project_selection
8
+ from post_gres_cred import db_cred
9
+ from utilities import update_db
10
+ import re
11
+
12
+ st.set_page_config(
13
+ page_title="Data Assessment​",
14
+ page_icon=":shark:",
15
+ layout="wide",
16
+ initial_sidebar_state="collapsed",
17
+ )
18
+
19
+ schema = db_cred["schema"]
20
+ load_local_css("styles.css")
21
+ set_header()
22
+
23
+ if "username" not in st.session_state:
24
+ st.session_state["username"] = None
25
+
26
+ if "project_name" not in st.session_state:
27
+ st.session_state["project_name"] = None
28
+
29
+ if "project_dct" not in st.session_state:
30
+ project_selection()
31
+ st.stop()
32
+
33
+ if "username" in st.session_state and st.session_state["username"] is not None:
34
+
35
+ if st.session_state["project_dct"]["data_import"]["imputed_tool_df"] is None:
36
+
37
+ st.error(f"Please import data from the Data Import Page")
38
+ st.stop()
39
+
40
+ st.session_state["cleaned_data"] = st.session_state["project_dct"]["data_import"][
41
+ "imputed_tool_df"
42
+ ]
43
+
44
+ st.session_state["category_dict"] = st.session_state["project_dct"]["data_import"][
45
+ "category_dict"
46
+ ]
47
+
48
+ # st.write(st.session_state['category_dict'])
49
+ cols1 = st.columns([2, 1])
50
+
51
+ with cols1[0]:
52
+ st.markdown(f"**Welcome {st.session_state['username']}**")
53
+ with cols1[1]:
54
+ st.markdown(f"**Current Project: {st.session_state['project_name']}**")
55
+
56
+ st.title("Data Assessment​")
57
+
58
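+ # Collect the "Response Metrics" column list from the category dictionary (yields a one-element list of lists, flattened below)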
+ target_variables = [
59
+ st.session_state["category_dict"][key]
60
+ for key in st.session_state["category_dict"].keys()
61
+ if key == "Response Metrics"
62
+ ]
63
+
64
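+ # Turn snake_case column names into display labels, e.g. "media_clicks_facebook" -> "Clicks Facebook"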
+ def format_display(inp):
65
+ return (
66
+ inp.title()
67
+ .replace("_", " ")
68
+ .replace("Media", "")
69
+ .replace("Cnt", "")
70
+ .strip()
71
+ )
72
+
73
+ target_variables = list(*target_variables)
74
+ target_column = st.selectbox(
75
+ "Select the Target Feature/Dependent Variable (will be used in all charts as reference)",
76
+ target_variables,
77
+ index=st.session_state["project_dct"]["data_validation"]["target_column"],
78
+ format_func=format_display,
79
+ )
80
+
81
+ st.session_state["project_dct"]["data_validation"]["target_column"] = (
82
+ target_variables.index(target_column)
83
+ )
84
+
85
+ st.session_state["target_column"] = target_column
86
+
87
+
88
+ if "panel" not in st.session_state["cleaned_data"].columns:
89
+ st.write('True')
90
+ st.session_state["cleaned_data"]["panel"] = ["Aggregated"] * len(
91
+ st.session_state["cleaned_data"]
92
+ )
93
+
94
+ disable = True
95
+
96
+ else:
97
+ panels = st.session_state["cleaned_data"]["panel"]
98
+
99
+ disable = False
100
+
101
+ selected_panels = st.multiselect(
102
+ "Please choose the panels you wish to analyze.If no panels are selected, insights will be derived from the overall data.",
103
+ st.session_state["cleaned_data"]["panel"].unique(),
104
+ default=st.session_state["project_dct"]["data_validation"]["selected_panels"],
105
+ disabled=disable,
106
+ )
107
+
108
+ st.session_state["project_dct"]["data_validation"][
109
+ "selected_panels"
110
+ ] = selected_panels
111
+
112
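+ # Aggregation rules for rolling panels up by date: Media columns are summed, all other categories are averaged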
+ aggregation_dict = {
113
+ item: "sum" if key == "Media" else "mean"
114
+ for key, value in st.session_state["category_dict"].items()
115
+ for item in value
116
+ if item not in ["date", "panel"]
117
+ }
118
+
119
+ aggregation_dict = {
120
+ key: value
121
+ for key, value in aggregation_dict.items()
122
+ if key in st.session_state["cleaned_data"].columns
123
+ }
124
+
125
+ with st.expander("**Target Variable Analysis**"):
126
+
127
+ if len(selected_panels) > 0:
128
+ st.session_state["Cleaned_data_panel"] = st.session_state["cleaned_data"][
129
+ st.session_state["cleaned_data"]["panel"].isin(selected_panels)
130
+ ]
131
+
132
+ st.session_state["Cleaned_data_panel"] = (
133
+ st.session_state["Cleaned_data_panel"]
134
+ .groupby(by="date")
135
+ .agg(aggregation_dict)
136
+ )
137
+ st.session_state["Cleaned_data_panel"] = st.session_state[
138
+ "Cleaned_data_panel"
139
+ ].reset_index()
140
+ else:
141
+ # st.write(st.session_state['cleaned_data'])
142
+ st.session_state["Cleaned_data_panel"] = (
143
+ st.session_state["cleaned_data"]
144
+ .groupby(by="date")
145
+ .agg(aggregation_dict)
146
+ )
147
+ st.session_state["Cleaned_data_panel"] = st.session_state[
148
+ "Cleaned_data_panel"
149
+ ].reset_index()
150
+
151
+ fig = line_plot_target(
152
+ st.session_state["Cleaned_data_panel"],
153
+ target=target_column,
154
+ title=f"{target_column} Over Time",
155
+ )
156
+ st.plotly_chart(fig, use_container_width=True)
157
+
158
+ media_channel = list(
159
+ *[
160
+ st.session_state["category_dict"][key]
161
+ for key in st.session_state["category_dict"].keys()
162
+ if key == "Media"
163
+ ]
164
+ )
165
+
166
+ spends_features = list(
167
+ *[
168
+ st.session_state["category_dict"][key]
169
+ for key in st.session_state["category_dict"].keys()
170
+ if key == "Spends"
171
+ ]
172
+ )
173
+ # st.write(media_channel)
174
+
175
+ exo_var = list(
176
+ *[
177
+ st.session_state["category_dict"][key]
178
+ for key in st.session_state["category_dict"].keys()
179
+ if key == "Exogenous"
180
+ ]
181
+ )
182
+ internal_var = list(
183
+ *[
184
+ st.session_state["category_dict"][key]
185
+ for key in st.session_state["category_dict"].keys()
186
+ if key == "Internal"
187
+ ]
188
+ )
189
+
190
+ Non_media_variables = exo_var + internal_var
191
+
192
+ st.markdown("### Annual Data Summary")
193
+
194
+ summary_df = summary(
195
+ st.session_state["Cleaned_data_panel"],
196
+ media_channel + [target_column] + spends_features,
197
+ spends=None,
198
+ Target=True,
199
+ )
200
+
201
+ st.dataframe(
202
+ summary_df.sort_index(axis=1),
203
+ use_container_width=True,
204
+ )
205
+
206
+ if st.checkbox("View Raw Data"):
207
+ st.cache_resource(show_spinner=False)
208
+
209
+ def raw_df_gen():
210
+ # Convert 'date' to datetime but do not convert to string yet for sorting
211
+ dates = pd.to_datetime(st.session_state["Cleaned_data_panel"]["date"])
212
+
213
+ # Concatenate the dates with other numeric columns formatted
214
+ raw_df = pd.concat(
215
+ [
216
+ dates,
217
+ st.session_state["Cleaned_data_panel"]
218
+ .select_dtypes(np.number)
219
+ .applymap(format_numbers),
220
+ ],
221
+ axis=1,
222
+ )
223
+
224
+ # Now sort raw_df by the 'date' column, which is still in datetime format
225
+ sorted_raw_df = raw_df.sort_values(by="date", ascending=True)
226
+
227
+ # After sorting, convert 'date' to string format for display
228
+ sorted_raw_df["date"] = sorted_raw_df["date"].dt.strftime("%m/%d/%Y")
229
+
230
+ return sorted_raw_df
231
+
232
+ # Display the sorted DataFrame in Streamlit
233
+ st.dataframe(raw_df_gen())
234
+
235
+ col1 = st.columns(1)
236
+
237
+ if "selected_feature" not in st.session_state:
238
+ st.session_state["selected_feature"] = None
239
+
240
+ # st.warning('Work in Progress')
241
+ with st.expander("Media Variables Analysis"):
242
+ # Get the selected feature
243
+
244
+ st.session_state["selected_feature"] = st.selectbox(
245
+ "Select Media", media_channel + spends_features, format_func=format_display
246
+ )
247
+
248
+ # st.write(st.session_state["selected_feature"].split('cnt_')[1] )
249
+ # st.session_state["project_dct"]["data_validation"]["selected_feature"] = (
250
+
251
+ # )
252
+
253
+ # Filter spends features based on the selected feature
254
+ spends_col = st.columns(2)
255
+ spends_feature = [
256
+ col
257
+ for col in spends_features
258
+ if re.split(r"cost_|spends_", col.lower())[1]
259
+ in st.session_state["selected_feature"]
260
+ ]
261
+
262
+ with spends_col[0]:
263
+ if len(spends_feature) == 0:
264
+ st.warning(
265
+ "The selected metric does not include a 'spends' variable in the data. Please verify that the columns are correctly named or select the appropriate columns in the provided selection box."
266
+ )
267
+ else:
268
+ st.write(
269
+ f'Selected "{spends_feature[0]}" as the corresponding spends variable automatically. If this is incorrect, please click the checkbox to change the variable.'
270
+ )
271
+
272
+ with spends_col[1]:
273
+ if len(spends_feature) == 0 or st.checkbox(
274
+ 'Select "Spends" variable for CPM and CPC calculation'
275
+ ):
276
+ spends_feature = [st.selectbox("Spends Variable", spends_features)]
277
+
278
+ if "validation" not in st.session_state:
279
+
280
+ st.session_state["validation"] = st.session_state["project_dct"][
281
+ "data_validation"
282
+ ]["validated_variables"]
283
+
284
+ val_variables = [col for col in media_channel if col != "date"]
285
+
286
+ if not set(
287
+ st.session_state["project_dct"]["data_validation"]["validated_variables"]
288
+ ).issubset(set(val_variables)):
289
+
290
+ st.session_state["validation"] = []
291
+
292
+ else:
293
+ fig_row1 = line_plot(
294
+ st.session_state["Cleaned_data_panel"],
295
+ x_col="date",
296
+ y1_cols=[st.session_state["selected_feature"]],
297
+ y2_cols=[target_column],
298
+ title=f'Analysis of {st.session_state["selected_feature"]} and {[target_column][0]} Over Time',
299
+ )
300
+ st.plotly_chart(fig_row1, use_container_width=True)
301
+ st.markdown("### Summary")
302
+ st.dataframe(
303
+ summary(
304
+ st.session_state["Cleaned_data_panel"],
305
+ [st.session_state["selected_feature"]],
306
+ spends=spends_feature[0],
307
+ ),
308
+ use_container_width=True,
309
+ )
310
+
311
+ cols2 = st.columns(2)
312
+
313
+ if len(
314
+ set(st.session_state["validation"]).intersection(val_variables)
315
+ ) == len(val_variables):
316
+ disable = True
317
+ help = "All media variables are validated"
318
+ else:
319
+ disable = False
320
+ help = ""
321
+
322
+ with cols2[0]:
323
+ if st.button("Validate", disabled=disable, help=help):
324
+ st.session_state["validation"].append(
325
+ st.session_state["selected_feature"]
326
+ )
327
+ with cols2[1]:
328
+
329
+ if st.checkbox("Validate All", disabled=disable, help=help):
330
+ st.session_state["validation"].extend(val_variables)
331
+ st.success("All media variables are validated ✅")
332
+
333
+ if len(
334
+ set(st.session_state["validation"]).intersection(val_variables)
335
+ ) != len(val_variables):
336
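+ # Editable checkbox table of media variables, pre-ticked for those already validated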
+ validation_data = pd.DataFrame(
337
+ {
338
+ "Validate": [
339
+ (True if col in st.session_state["validation"] else False)
340
+ for col in val_variables
341
+ ],
342
+ "Variables": val_variables,
343
+ }
344
+ )
345
+
346
+ sorted_validation_df = validation_data.sort_values(
347
+ by="Variables", ascending=True, na_position="first"
348
+ )
349
+ cols3 = st.columns([1, 30])
350
+ with cols3[1]:
351
+ validation_df = st.data_editor(
352
+ sorted_validation_df,
353
+ # column_config={
354
+ # 'Validate':st.column_config.CheckboxColumn(wi)
355
+ # },
356
+ column_config={
357
+ "Validate": st.column_config.CheckboxColumn(
358
+ default=False,
359
+ width=100,
360
+ ),
361
+ "Variables": st.column_config.TextColumn(width=1000),
362
+ },
363
+ hide_index=True,
364
+ )
365
+
366
+ selected_rows = validation_df[validation_df["Validate"] == True][
367
+ "Variables"
368
+ ]
369
+
370
+ # st.write(selected_rows)
371
+
372
+ st.session_state["validation"].extend(selected_rows)
373
+
374
+ st.session_state["project_dct"]["data_validation"][
375
+ "validated_variables"
376
+ ] = st.session_state["validation"]
377
+
378
+ not_validated_variables = [
379
+ col
380
+ for col in val_variables
381
+ if col not in st.session_state["validation"]
382
+ ]
383
+
384
+ if not_validated_variables:
385
+ not_validated_message = f'The following variables are not validated:\n{" , ".join(not_validated_variables)}'
386
+ st.warning(not_validated_message)
387
+
388
+ with st.expander("Non-Media Variables Analysis"):
389
+ if len(Non_media_variables) == 0:
390
+ st.warning("Non-Media variables not present")
391
+
392
+ else:
393
+ selected_columns_row4 = st.selectbox(
394
+ "Select Channel",
395
+ Non_media_variables,
396
+ format_func=format_display,
397
+ index=st.session_state["project_dct"]["data_validation"][
398
+ "Non_media_variables"
399
+ ],
400
+ )
401
+
402
+ st.session_state["project_dct"]["data_validation"][
403
+ "Non_media_variables"
404
+ ] = Non_media_variables.index(selected_columns_row4)
405
+
406
+ # # Create the dual-axis line plot
407
+ fig_row4 = line_plot(
408
+ st.session_state["Cleaned_data_panel"],
409
+ x_col="date",
410
+ y1_cols=[selected_columns_row4],
411
+ y2_cols=[target_column],
412
+ title=f"Analysis of {selected_columns_row4} and {target_column} Over Time",
413
+ )
414
+ st.plotly_chart(fig_row4, use_container_width=True)
415
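+ # Yearly totals of the selected non-media variable and the target, plus a grand-total row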
+ selected_non_media = selected_columns_row4
416
+ sum_df = st.session_state["Cleaned_data_panel"][
417
+ ["date", selected_non_media, target_column]
418
+ ]
419
+ sum_df["Year"] = pd.to_datetime(
420
+ st.session_state["Cleaned_data_panel"]["date"]
421
+ ).dt.year
422
+ # st.dataframe(df)
423
+ # st.dataframe(sum_df.head(2))
424
+
425
+ sum_df = sum_df.drop("date", axis=1).groupby("Year").agg("sum")
426
+ sum_df.loc["Grand Total"] = sum_df.sum()
427
+ sum_df = sum_df.applymap(format_numbers)
428
+ sum_df.fillna("-", inplace=True)
429
+ sum_df = sum_df.replace({"0.0": "-", "nan": "-"})
430
+ st.markdown("### Summary")
431
+ st.dataframe(sum_df, use_container_width=True)
432
+
433
+ with st.expander("Correlation Analysis"):
434
+ options = list(
435
+ st.session_state["Cleaned_data_panel"].select_dtypes(np.number).columns
436
+ )
437
+
438
+
439
+ if "correlation" not in st.session_state["project_dct"]["data_import"]:
440
+ st.session_state["project_dct"]["data_import"]["correlation"]=[]
441
+
442
+ selected_options = st.multiselect(
443
+ "Select Variables for Correlation Plot",
444
+ [var for var in options if var != target_column],
445
+ default=st.session_state["project_dct"]["data_import"]["correlation"],
446
+ )
447
+
448
+ st.session_state["project_dct"]["data_import"]["correlation"] = selected_options
449
+
450
+ st.pyplot(
451
+ correlation_plot(
452
+ st.session_state["Cleaned_data_panel"],
453
+ selected_options,
454
+ target_column,
455
+ )
456
+ )
457
+
458
+ if st.button("Save Changes", use_container_width=True):
459
+ # Update DB
460
+ update_db(
461
+ prj_id=st.session_state["project_number"],
462
+ page_nam="Data Validation and Insights",
463
+ file_nam="project_dct",
464
+ pkl_obj=pickle.dumps(st.session_state["project_dct"]),
465
+ schema=schema,
466
+ )
467
+ st.success("Changes saved")
pages/3_AI_Model_Transformations.py ADDED
@@ -0,0 +1,1326 @@
1
+ # Importing necessary libraries
2
+ import streamlit as st
3
+
4
+ st.set_page_config(
5
+ page_title="AI Model Transformations",
6
+ page_icon="⚖️",
7
+ layout="wide",
8
+ initial_sidebar_state="collapsed",
9
+ )
10
+
11
+ import sys
12
+ import pickle
13
+ import traceback
14
+ import numpy as np
15
+ import pandas as pd
16
+ import plotly.graph_objects as go
17
+ from post_gres_cred import db_cred
18
+ from log_application import log_message
19
+ from utilities import (
20
+ set_header,
21
+ load_local_css,
22
+ update_db,
23
+ project_selection,
24
+ delete_entries,
25
+ retrieve_pkl_object,
26
+ )
27
+ from constants import (
28
+ predefined_defaults,
29
+ lead_min_value,
30
+ lead_max_value,
31
+ lead_step,
32
+ lag_min_value,
33
+ lag_max_value,
34
+ lag_step,
35
+ moving_average_min_value,
36
+ moving_average_max_value,
37
+ moving_average_step,
38
+ saturation_min_value,
39
+ saturation_max_value,
40
+ saturation_step,
41
+ power_min_value,
42
+ power_max_value,
43
+ power_step,
44
+ adstock_min_value,
45
+ adstock_max_value,
46
+ adstock_step,
47
+ display_max_col,
48
+ )
49
+
50
+
51
+ schema = db_cred["schema"]
52
+ load_local_css("styles.css")
53
+ set_header()
54
+
55
+
56
+ # Initialize project name session state
57
+ if "project_name" not in st.session_state:
58
+ st.session_state["project_name"] = None
59
+
60
+ # Fetch project dictionary
61
+ if "project_dct" not in st.session_state:
62
+ project_selection()
63
+ st.stop()
64
+
65
+ # Display Username and Project Name
66
+ if "username" in st.session_state and st.session_state["username"] is not None:
67
+
68
+ cols1 = st.columns([2, 1])
69
+
70
+ with cols1[0]:
71
+ st.markdown(f"**Welcome {st.session_state['username']}**")
72
+ with cols1[1]:
73
+ st.markdown(f"**Current Project: {st.session_state['project_name']}**")
74
+
75
+ # Load saved data from project dictionary
76
+ if st.session_state["project_dct"]["data_import"]["imputed_tool_df"] is None:
77
+ st.warning(
78
+ "The data import is incomplete. Please go back to the Data Import page and complete the save.",
79
+ icon="🔙",
80
+ )
81
+
82
+ # Log message
83
+ log_message(
84
+ "warning",
85
+ "The data import is incomplete. Please go back to the Data Import page and complete the save.",
86
+ "Transformations",
87
+ )
88
+
89
+ st.stop()
90
+ else:
91
+ final_df_loaded = st.session_state["project_dct"]["data_import"][
92
+ "imputed_tool_df"
93
+ ].copy()
94
+ bin_dict_loaded = st.session_state["project_dct"]["data_import"][
95
+ "category_dict"
96
+ ].copy()
97
+ unique_panels = st.session_state["project_dct"]["data_import"][
98
+ "unique_panels"
99
+ ].copy()
100
+
101
+ # Initialize project dictionary data
102
+ if st.session_state["project_dct"]["transformations"]["final_df"] is None:
103
+ st.session_state["project_dct"]["transformations"][
104
+ "final_df"
105
+ ] = final_df_loaded # Default as original dataframe
106
+
107
+ # Extract original columns for specified categories
108
+ original_columns = {
109
+ category: bin_dict_loaded[category]
110
+ for category in ["Media", "Internal", "Exogenous"]
111
+ if category in bin_dict_loaded
112
+ }
113
+
114
+ # Retrive Panel columns
115
+ panel = ["panel"] if len(unique_panels) > 1 else []
116
+
117
+
118
+ # Function to clear model metadata
119
+ def clear_pages():
120
+ # Reset Pages
121
+ st.session_state["project_dct"]["model_build"] = {
122
+ "sel_target_col": None,
123
+ "all_iters_check": False,
124
+ "iterations": 0,
125
+ "build_button": False,
126
+ "show_results_check": False,
127
+ "session_state_saved": {},
128
+ }
129
+ st.session_state["project_dct"]["model_tuning"] = {
130
+ "sel_target_col": None,
131
+ "sel_model": {},
132
+ "flag_expander": False,
133
+ "start_date_default": None,
134
+ "end_date_default": None,
135
+ "repeat_default": "No",
136
+ "flags": {},
137
+ "select_all_flags_check": {},
138
+ "selected_flags": {},
139
+ "trend_check": False,
140
+ "week_num_check": False,
141
+ "sine_cosine_check": False,
142
+ "session_state_saved": {},
143
+ }
144
+ st.session_state["project_dct"]["saved_model_results"] = {
145
+ "selected_options": None,
146
+ "model_grid_sel": [1],
147
+ }
148
+ if "model_results_df" in st.session_state:
149
+ del st.session_state["model_results_df"]
150
+ if "model_results_data" in st.session_state:
151
+ del st.session_state["model_results_data"]
152
+ if "coefficients_df" in st.session_state:
153
+ del st.session_state["coefficients_df"]
154
+
155
+
156
+ # Function to update transformation change
157
+ def transformation_change(category, transformation, key):
158
+ st.session_state["project_dct"]["transformations"][category][transformation] = (
159
+ st.session_state[key]
160
+ )
161
+
162
+
163
+ # Function to update specific transformation change
164
+ def transformation_specific_change(channel_name, transformation, key):
165
+ st.session_state["project_dct"]["transformations"]["Specific"][transformation][
166
+ channel_name
167
+ ] = st.session_state[key]
168
+
169
+
170
+ # Function to update transformations to apply change
171
+ def transformations_to_apply_change(category, key):
172
+ st.session_state["project_dct"]["transformations"][category][key] = (
173
+ st.session_state[key]
174
+ )
175
+
176
+
177
+ # Function to update channel select specific change
178
+ def channel_select_specific_change():
179
+ st.session_state["project_dct"]["transformations"]["Specific"][
180
+ "channel_select_specific"
181
+ ] = st.session_state["channel_select_specific"]
182
+
183
+
184
+ # Function to update specific transformation change
185
+ def specific_transformation_change(specific_transformation_key):
186
+ st.session_state["project_dct"]["transformations"]["Specific"][
187
+ specific_transformation_key
188
+ ] = st.session_state[specific_transformation_key]
189
+
190
+
191
+ # Function to build transformation widgets
192
+ def transformation_widgets(category, transform_params, date_granularity):
193
+ # Transformation Options
194
+ transformation_options = {
195
+ "Media": [
196
+ "Lag",
197
+ "Moving Average",
198
+ "Saturation",
199
+ "Power",
200
+ "Adstock",
201
+ ],
202
+ "Internal": ["Lead", "Lag", "Moving Average"],
203
+ "Exogenous": ["Lead", "Lag", "Moving Average"],
204
+ }
205
+
206
+ # Define a helper function to create widgets for each transformation
207
+ def create_transformation_widgets(column, transformations):
208
+ with column:
209
+ for transformation in transformations:
210
+ transformation_key = f"{transformation}_{category}"
211
+
212
+ slider_value = st.session_state["project_dct"]["transformations"][
213
+ category
214
+ ].get(transformation, predefined_defaults[transformation])
215
+
216
+ # Conditionally create widgets for selected transformations
217
+ if transformation == "Lead":
218
+ st.markdown(f"**{transformation} ({date_granularity})**")
219
+
220
+ lead = st.slider(
221
+ label="Lead periods",
222
+ min_value=lead_min_value,
223
+ max_value=lead_max_value,
224
+ step=lead_step,
225
+ value=slider_value,
226
+ key=transformation_key,
227
+ label_visibility="collapsed",
228
+ on_change=transformation_change,
229
+ args=(
230
+ category,
231
+ transformation,
232
+ transformation_key,
233
+ ),
234
+ )
235
+
236
+ start = lead[0]
237
+ end = lead[1]
238
+ step = lead_step
239
+ transform_params[category][transformation] = np.arange(
240
+ start, end + step, step
241
+ )
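+ # Note: the default value here appears to be a (low, high) tuple (taken from
+ # predefined_defaults, defined earlier in the file), so st.slider renders a
+ # range slider and returns a (low, high) pair; np.arange(start, end + step, step)
+ # then builds the inclusive grid of candidate values, e.g. a range of (1, 3)
+ # with step 1 yields candidate shifts [1, 2, 3]. The same pattern repeats for
+ # the other transformations below.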
242
+
243
+ if transformation == "Lag":
244
+ st.markdown(f"**{transformation} ({date_granularity})**")
245
+
246
+ lag = st.slider(
247
+ label="Lag periods",
248
+ min_value=lag_min_value,
249
+ max_value=lag_max_value,
250
+ step=lag_step,
251
+ value=slider_value,
252
+ key=transformation_key,
253
+ label_visibility="collapsed",
254
+ on_change=transformation_change,
255
+ args=(
256
+ category,
257
+ transformation,
258
+ transformation_key,
259
+ ),
260
+ )
261
+
262
+ start = lag[0]
263
+ end = lag[1]
264
+ step = lag_step
265
+ transform_params[category][transformation] = np.arange(
266
+ start, end + step, step
267
+ )
268
+
269
+ if transformation == "Moving Average":
270
+ st.markdown(f"**{transformation} ({date_granularity})**")
271
+
272
+ window = st.slider(
273
+ label="Window size for Moving Average",
274
+ min_value=moving_average_min_value,
275
+ max_value=moving_average_max_value,
276
+ step=moving_average_step,
277
+ value=slider_value,
278
+ key=transformation_key,
279
+ label_visibility="collapsed",
280
+ on_change=transformation_change,
281
+ args=(
282
+ category,
283
+ transformation,
284
+ transformation_key,
285
+ ),
286
+ )
287
+
288
+ start = window[0]
289
+ end = window[1]
290
+ step = moving_average_step
291
+ transform_params[category][transformation] = np.arange(
292
+ start, end + step, step
293
+ )
294
+
295
+ if transformation == "Saturation":
296
+ st.markdown(f"**{transformation} (%)**")
297
+
298
+ saturation_point = st.slider(
299
+ label="Saturation Percentage",
300
+ min_value=saturation_min_value,
301
+ max_value=saturation_max_value,
302
+ step=saturation_step,
303
+ value=slider_value,
304
+ key=transformation_key,
305
+ label_visibility="collapsed",
306
+ on_change=transformation_change,
307
+ args=(
308
+ category,
309
+ transformation,
310
+ transformation_key,
311
+ ),
312
+ )
313
+
314
+ start = saturation_point[0]
315
+ end = saturation_point[1]
316
+ step = saturation_step
317
+ transform_params[category][transformation] = np.arange(
318
+ start, end + step, step
319
+ )
320
+
321
+ if transformation == "Power":
322
+ st.markdown(f"**{transformation}**")
323
+
324
+ power = st.slider(
325
+ label="Power",
326
+ min_value=power_min_value,
327
+ max_value=power_max_value,
328
+ step=power_step,
329
+ value=slider_value,
330
+ key=transformation_key,
331
+ label_visibility="collapsed",
332
+ on_change=transformation_change,
333
+ args=(
334
+ category,
335
+ transformation,
336
+ transformation_key,
337
+ ),
338
+ )
339
+
340
+ start = power[0]
341
+ end = power[1]
342
+ step = power_step
343
+ transform_params[category][transformation] = np.arange(
344
+ start, end + step, step
345
+ )
346
+
347
+ if transformation == "Adstock":
348
+ st.markdown(f"**{transformation}**")
349
+
350
+ rate = st.slider(
351
+ label="Decay Factor",
352
+ min_value=adstock_min_value,
353
+ max_value=adstock_max_value,
354
+ step=adstock_step,
355
+ value=slider_value,
356
+ key=transformation_key,
357
+ label_visibility="collapsed",
358
+ on_change=transformation_change,
359
+ args=(
360
+ category,
361
+ transformation,
362
+ transformation_key,
363
+ ),
364
+ )
365
+
366
+ start = rate[0]
367
+ end = rate[1]
368
+ step = adstock_step
369
+ adstock_range = [
370
+ round(a, 3) for a in np.arange(start, end + step, step)
371
+ ]
372
+ transform_params[category][transformation] = np.array(adstock_range)
373
+
374
+ with st.expander(f"All {category} Transformations", expanded=True):
375
+
376
+ transformation_key = f"transformation_{category}"
377
+
378
+ # Select which transformations to apply
379
+ sel_transformations = st.session_state["project_dct"]["transformations"][
380
+ category
381
+ ].get(transformation_key, [])
382
+
383
+ # Reset saved transformation selections if the available options have changed
384
+ for channel in sel_transformations:
385
+ if channel not in transformation_options[category]:
386
+ (
387
+ st.session_state["project_dct"]["transformations"][category][
388
+ transformation_key
389
+ ],
390
+ sel_transformations,
391
+ ) = ([], [])
392
+
393
+ transformations_to_apply = st.multiselect(
394
+ label="Select transformations to apply",
395
+ options=transformation_options[category],
396
+ default=sel_transformations,
397
+ key=transformation_key,
398
+ on_change=transformations_to_apply_change,
399
+ args=(
400
+ category,
401
+ transformation_key,
402
+ ),
403
+ )
404
+
405
+ # Determine the number of transformations to put in each column
406
+ transformations_per_column = (
407
+ len(transformations_to_apply) // 2 + len(transformations_to_apply) % 2
408
+ )
409
+
410
+ # Create two columns
411
+ col1, col2 = st.columns(2)
412
+
413
+ # Assign transformations to each column
414
+ transformations_col1 = transformations_to_apply[:transformations_per_column]
415
+ transformations_col2 = transformations_to_apply[transformations_per_column:]
416
+
417
+ # Create widgets in each column
418
+ create_transformation_widgets(col1, transformations_col1)
419
+ create_transformation_widgets(col2, transformations_col2)
420
+
421
+
422
+ # Define a helper function to create widgets for each specific transformation
423
+ def create_specific_transformation_widgets(
424
+ column,
425
+ transformations,
426
+ channel_name,
427
+ date_granularity,
428
+ specific_transform_params,
429
+ ):
430
+ with column:
431
+ for transformation in transformations:
432
+ transformation_key = f"{transformation}_{channel_name}_specific"
433
+
434
+ if (
435
+ transformation
436
+ not in st.session_state["project_dct"]["transformations"]["Specific"]
437
+ ):
438
+ st.session_state["project_dct"]["transformations"]["Specific"][
439
+ transformation
440
+ ] = {}
441
+
442
+ slider_value = st.session_state["project_dct"]["transformations"][
443
+ "Specific"
444
+ ][transformation].get(channel_name, predefined_defaults[transformation])
445
+
446
+ # Conditionally create widgets for selected transformations
447
+ if transformation == "Lead":
448
+ st.markdown(f"**Lead ({date_granularity})**")
449
+
450
+ lead = st.slider(
451
+ label="Lead periods",
452
+ min_value=lead_min_value,
453
+ max_value=lead_max_value,
454
+ step=lead_step,
455
+ value=slider_value,
456
+ key=transformation_key,
457
+ label_visibility="collapsed",
458
+ on_change=transformation_specific_change,
459
+ args=(
460
+ channel_name,
461
+ transformation,
462
+ transformation_key,
463
+ ),
464
+ )
465
+
466
+ start = lead[0]
467
+ end = lead[1]
468
+ step = lead_step
469
+ specific_transform_params[channel_name]["Lead"] = np.arange(
470
+ start, end + step, step
471
+ )
472
+
473
+ if transformation == "Lag":
474
+ st.markdown(f"**Lag ({date_granularity})**")
475
+
476
+ lag = st.slider(
477
+ label="Lag periods",
478
+ min_value=lag_min_value,
479
+ max_value=lag_max_value,
480
+ step=lag_step,
481
+ value=slider_value,
482
+ key=transformation_key,
483
+ label_visibility="collapsed",
484
+ on_change=transformation_specific_change,
485
+ args=(
486
+ channel_name,
487
+ transformation,
488
+ transformation_key,
489
+ ),
490
+ )
491
+
492
+ start = lag[0]
493
+ end = lag[1]
494
+ step = lag_step
495
+ specific_transform_params[channel_name]["Lag"] = np.arange(
496
+ start, end + step, step
497
+ )
498
+
499
+ if transformation == "Moving Average":
500
+ st.markdown(f"**Moving Average ({date_granularity})**")
501
+
502
+ window = st.slider(
503
+ label="Window size for Moving Average",
504
+ min_value=moving_average_min_value,
505
+ max_value=moving_average_max_value,
506
+ step=moving_average_step,
507
+ value=slider_value,
508
+ key=transformation_key,
509
+ label_visibility="collapsed",
510
+ on_change=transformation_specific_change,
511
+ args=(
512
+ channel_name,
513
+ transformation,
514
+ transformation_key,
515
+ ),
516
+ )
517
+
518
+ start = window[0]
519
+ end = window[1]
520
+ step = moving_average_step
521
+ specific_transform_params[channel_name]["Moving Average"] = np.arange(
522
+ start, end + step, step
523
+ )
524
+
525
+ if transformation == "Saturation":
526
+ st.markdown("**Saturation (%)**")
527
+
528
+ saturation_point = st.slider(
529
+ label="Saturation Percentage",
530
+ min_value=saturation_min_value,
531
+ max_value=saturation_max_value,
532
+ step=saturation_step,
533
+ value=slider_value,
534
+ key=transformation_key,
535
+ label_visibility="collapsed",
536
+ on_change=transformation_specific_change,
537
+ args=(
538
+ channel_name,
539
+ transformation,
540
+ transformation_key,
541
+ ),
542
+ )
543
+
544
+ start = saturation_point[0]
545
+ end = saturation_point[1]
546
+ step = saturation_step
547
+ specific_transform_params[channel_name]["Saturation"] = np.arange(
548
+ start, end + step, step
549
+ )
550
+
551
+ if transformation == "Power":
552
+ st.markdown("**Power**")
553
+
554
+ power = st.slider(
555
+ label="Power",
556
+ min_value=power_min_value,
557
+ max_value=power_max_value,
558
+ step=power_step,
559
+ value=slider_value,
560
+ key=transformation_key,
561
+ label_visibility="collapsed",
562
+ on_change=transformation_specific_change,
563
+ args=(
564
+ channel_name,
565
+ transformation,
566
+ transformation_key,
567
+ ),
568
+ )
569
+
570
+ start = power[0]
571
+ end = power[1]
572
+ step = power_step
573
+ specific_transform_params[channel_name]["Power"] = np.arange(
574
+ start, end + step, step
575
+ )
576
+
577
+ if transformation == "Adstock":
578
+ st.markdown("**Adstock**")
579
+ rate = st.slider(
580
+ label="Decay Factor",
581
+ min_value=adstock_min_value,
582
+ max_value=adstock_max_value,
583
+ step=adstock_step,
584
+ value=slider_value,
585
+ key=transformation_key,
586
+ label_visibility="collapsed",
587
+ on_change=transformation_specific_change,
588
+ args=(
589
+ channel_name,
590
+ transformation,
591
+ transformation_key,
592
+ ),
593
+ )
594
+
595
+ start = rate[0]
596
+ end = rate[1]
597
+ step = adstock_step
598
+ adstock_range = [
599
+ round(a, 3) for a in np.arange(start, end + step, step)
600
+ ]
601
+ specific_transform_params[channel_name]["Adstock"] = np.array(
602
+ adstock_range
603
+ )
604
+
605
+
606
+ # Function to apply Lag transformation
607
+ def apply_lag(df, lag):
608
+ return df.shift(lag)
609
+
610
+
611
+ # Function to apply Lead transformation
612
+ def apply_lead(df, lead):
613
+ return df.shift(-lead)
614
+
615
+
616
+ # Function to apply Moving Average transformation
617
+ def apply_moving_average(df, window_size):
618
+ return df.rolling(window=window_size).mean()
619
+
620
+
621
+ # Function to apply Saturation transformation
622
+ def apply_saturation(df, saturation_percent_100):
623
+ # Convert percentage to fraction
624
+ saturation_percent = min(max(saturation_percent_100, 0.01), 99.99) / 100.0
625
+
626
+ # Get the maximum and minimum values
627
+ column_max = df.max()
628
+ column_min = df.min()
629
+
630
+ # If the data is constant, scale it directly
631
+ if column_min == column_max:
632
+ return df.apply(lambda x: x * saturation_percent)
633
+
634
+ # Compute the saturation point from the data range
635
+ saturation_point = (column_min + saturation_percent * column_max) / 2
636
+
637
+ # Calculate steepness for the saturation curve
638
+ numerator = np.log((1 / saturation_percent) - 1)
639
+ denominator = np.log(saturation_point / column_max)
640
+ steepness = numerator / denominator
641
+
642
+ # Apply the saturation transformation
643
+ transformed_series = df.apply(
644
+ lambda x: (1 / (1 + (saturation_point / (x if x != 0 else 1e-9)) ** steepness)) * x
645
+ )
646
+
647
+ return transformed_series
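+ # Worked check of the formulas above: writing s for the saturation fraction,
+ # saturation_point = (min + s * max) / 2 and
+ # steepness = ln(1/s - 1) / ln(saturation_point / max), so at x == column_max
+ # the multiplier 1 / (1 + (saturation_point / x) ** steepness) evaluates to
+ # exactly s, i.e. the largest observation is scaled to s * column_max; the
+ # remaining points fall on the same S-curve.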
648
+
649
+
650
+ # Function to apply Power transformation
651
+ def apply_power(df, power):
652
+ return df**power
653
+
654
+
655
+ # Function to apply Adstock transformation
656
+ def apply_adstock(df, factor):
657
+ x = 0
658
+ # Use the walrus operator to update x iteratively with the Adstock formula
659
+ adstock_var = [x := x * factor + v for v in df]
660
+ ans = pd.Series(adstock_var, index=df.index)
661
+ return ans
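+ # The list comprehension above implements the geometric-decay recursion
+ # adstock[t] = x[t] + factor * adstock[t-1]. Illustrative values:
+ # apply_adstock(pd.Series([100, 0, 0, 0]), 0.5) -> [100.0, 50.0, 25.0, 12.5]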
662
+
663
+
664
+ # Function to generate transformed column names
665
+ @st.cache_resource(show_spinner=False)
666
+ def generate_transformed_columns(
667
+ original_columns, transform_params, specific_transform_params
668
+ ):
669
+ transformed_columns, summary = {}, {}
670
+
671
+ for category, columns in original_columns.items():
672
+ for column in columns:
673
+ transformed_columns[column] = []
674
+ summary_details = (
675
+ []
676
+ ) # List to hold transformation details for the current column
677
+
678
+ if (
679
+ column in specific_transform_params.keys()
680
+ and len(specific_transform_params[column]) > 0
681
+ ):
682
+ for transformation, values in specific_transform_params[column].items():
683
+ # Generate transformed column names for each value
684
+ for value in values:
685
+ transformed_name = f"{column}@{transformation}_{value}"
686
+ transformed_columns[column].append(transformed_name)
687
+
688
+ # Format the values list as a string with commas and "and" before the last item
689
+ if len(values) > 1:
690
+ formatted_values = (
691
+ ", ".join(map(str, values[:-1])) + " and " + str(values[-1])
692
+ )
693
+ else:
694
+ formatted_values = str(values[0])
695
+
696
+ # Add transformation details
697
+ summary_details.append(f"{transformation} ({formatted_values})")
698
+
699
+ else:
700
+ if category in transform_params:
701
+ for transformation, values in transform_params[category].items():
702
+ # Generate transformed column names for each value
703
+ if column not in specific_transform_params.keys():
704
+ for value in values:
705
+ transformed_name = f"{column}@{transformation}_{value}"
706
+ transformed_columns[column].append(transformed_name)
707
+
708
+ # Format the values list as a string with commas and "and" before the last item
709
+ if len(values) > 1:
710
+ formatted_values = (
711
+ ", ".join(map(str, values[:-1]))
712
+ + " and "
713
+ + str(values[-1])
714
+ )
715
+ else:
716
+ formatted_values = str(values[0])
717
+
718
+ # Add transformation details
719
+ summary_details.append(
720
+ f"{transformation} ({formatted_values})"
721
+ )
722
+
723
+ else:
724
+ summary_details = ["No transformation selected"]
725
+
726
+ # Only add to summary if there are transformation details for the column
727
+ if summary_details:
728
+ formatted_summary = "⮕ ".join(summary_details)
729
+ # Use <strong> tags to make the column name bold
730
+ summary[column] = f"<strong>{column}</strong>: {formatted_summary}"
731
+
732
+ # Generate a comprehensive summary string for all columns
733
+ summary_items = [
734
+ f"{idx + 1}. {details}" for idx, details in enumerate(summary.values())
735
+ ]
736
+
737
+ summary_string = "\n".join(summary_items)
738
+
739
+ return transformed_columns, summary_string
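+ # Naming convention produced above (column name is illustrative): a column
+ # "tv_spend" with Adstock values [0.3, 0.5] yields the transformed names
+ # "tv_spend@Adstock_0.3" and "tv_spend@Adstock_0.5", and a summary entry of
+ # the form "tv_spend: Adstock (0.3 and 0.5)" (the column name is wrapped in
+ # <strong> tags for display).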
740
+
741
+
742
+ # Function to transform Dataframe slice
743
+ def transform_slice(
744
+ transform_params,
745
+ transformation_functions,
746
+ panel,
747
+ df,
748
+ df_slice,
749
+ category,
750
+ category_df,
751
+ ):
752
+ # Iterate through each transformation and its parameters for the current category
753
+ for transformation, parameters in transform_params[category].items():
754
+ transformation_function = transformation_functions[transformation]
755
+
756
+ # Check if there is panel data to group by
757
+ if len(panel) > 0:
758
+ # Apply the transformation to each group
759
+ category_df = pd.concat(
760
+ [
761
+ df_slice.groupby(panel)
762
+ .transform(transformation_function, p)
763
+ .add_suffix(f"@{transformation}_{p}")
764
+ for p in parameters
765
+ ],
766
+ axis=1,
767
+ )
768
+
769
+ # Replace all NaN or null values in category_df with 0
770
+ category_df.fillna(0, inplace=True)
771
+
772
+ # Update df_slice
773
+ df_slice = pd.concat(
774
+ [df[panel], category_df],
775
+ axis=1,
776
+ )
777
+
778
+ else:
779
+ for p in parameters:
780
+ # Apply the transformation function to each column
781
+ temp_df = df_slice.apply(
782
+ lambda x: transformation_function(x, p), axis=0
783
+ ).rename(
784
+ lambda x: f"{x}@{transformation}_{p}",
785
+ axis="columns",
786
+ )
787
+ # Concatenate the transformed DataFrame slice to the category DataFrame
788
+ category_df = pd.concat([category_df, temp_df], axis=1)
789
+
790
+ # Replace all NaN or null values in category_df with 0
791
+ category_df.fillna(0, inplace=True)
792
+
793
+ # Update df_slice
794
+ df_slice = pd.concat(
795
+ [df[panel], category_df],
796
+ axis=1,
797
+ )
798
+
799
+ return category_df, df, df_slice
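+ # Note on the branches above: when panel data exists, each parameter value is
+ # applied via groupby(panel).transform, so lags, leads and rolling windows are
+ # computed within each panel independently; without panels the transformation
+ # is applied column-wise with df.apply. In both cases the new columns carry the
+ # "@{transformation}_{parameter}" suffix and NaNs introduced by shifting or
+ # rolling are filled with 0.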
800
+
801
+
802
+ # Function to apply transformations to DataFrame slices based on specified categories and parameters
803
+ @st.cache_resource(show_spinner=False)
804
+ def apply_category_transformations(
805
+ df_main, bin_dict, transform_params, panel, specific_transform_params
806
+ ):
807
+ # Dictionary for function mapping
808
+ transformation_functions = {
809
+ "Lead": apply_lead,
810
+ "Lag": apply_lag,
811
+ "Moving Average": apply_moving_average,
812
+ "Saturation": apply_saturation,
813
+ "Power": apply_power,
814
+ "Adstock": apply_adstock,
815
+ }
816
+
817
+ # List to collect all transformed DataFrames
818
+ transformed_dfs = []
819
+
820
+ # Iterate through each category specified in transform_params
821
+ for category in ["Media", "Exogenous", "Internal"]:
822
+ if (
823
+ category not in transform_params
824
+ or category not in bin_dict
825
+ or not transform_params[category]
826
+ ):
827
+ continue # Skip categories without transformations
828
+
829
+ # Initialize category_df as an empty DataFrame
830
+ category_df = pd.DataFrame()
831
+
832
+ # Slice the DataFrame based on the columns specified in bin_dict for the current category
833
+ df_slice = df_main[bin_dict[category] + panel].copy()
834
+
835
+ # Drop the column from df_slice to skip specific transformations
836
+ df_slice = df_slice.drop(
837
+ columns=list(specific_transform_params.keys()), errors="ignore"
838
+ ).copy()
839
+
840
+ category_df, df, df_slice_updated = transform_slice(
841
+ transform_params.copy(),
842
+ transformation_functions.copy(),
843
+ panel,
844
+ df_main.copy(),
845
+ df_slice.copy(),
846
+ category,
847
+ category_df.copy(),
848
+ )
849
+
850
+ # Append the transformed category DataFrame to the list if it's not empty
851
+ if not category_df.empty:
852
+ transformed_dfs.append(category_df)
853
+
854
+ # Apply channel specific transforms
855
+ for channel_specific in specific_transform_params:
856
+ # Initialize category_df as an empty DataFrame
857
+ category_df = pd.DataFrame()
858
+
859
+ df_slice_specific = df_main[[channel_specific] + panel].copy()
860
+ transform_params_specific = {
861
+ "Media": specific_transform_params[channel_specific]
862
+ }
863
+
864
+ category_df, df, df_slice_specific_updated = transform_slice(
865
+ transform_params_specific.copy(),
866
+ transformation_functions.copy(),
867
+ panel,
868
+ df_main.copy(),
869
+ df_slice_specific.copy(),
870
+ "Media",
871
+ category_df.copy(),
872
+ )
873
+
874
+ # Append the transformed category DataFrame to the list if it's not empty
875
+ if not category_df.empty:
876
+ transformed_dfs.append(category_df)
877
+
878
+ # If category_df has been modified, concatenate it with the panel and response metrics from the original DataFrame
879
+ if len(transformed_dfs) > 0:
880
+ final_df = pd.concat([df_main] + transformed_dfs, axis=1)
881
+ else:
882
+ # If no transformations were applied, use the original DataFrame
883
+ final_df = df_main
884
+
885
+ # Find columns with '@' in their names
886
+ columns_with_at = [col for col in final_df.columns if "@" in col]
887
+
888
+ # Create a set of columns to drop
889
+ columns_to_drop = set()
890
+
891
+ # Iterate through columns with '@' to find shorter names to drop
892
+ for col in columns_with_at:
893
+ base_name = col.split("@")[0]
894
+ for other_col in columns_with_at:
895
+ if other_col.startswith(base_name) and len(other_col.split("@")) > len(
896
+ col.split("@")
897
+ ):
898
+ columns_to_drop.add(col)
899
+ break
900
+
901
+ # Drop the identified columns from the DataFrame
902
+ final_df.drop(columns=list(columns_to_drop), inplace=True)
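+ # Example of the pruning above (assuming chained transformations): an
+ # intermediate column such as "tv_spend@Lag_1" is dropped when a more fully
+ # transformed version like "tv_spend@Lag_1@Adstock_0.3" exists, so only the
+ # final chained columns are kept.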
903
+
904
+ return final_df
905
+
906
+
907
+ # Function to infer the granularity of the date column in a DataFrame
908
+ @st.cache_resource(show_spinner=False)
909
+ def infer_date_granularity(df):
910
+ # Find the most common difference
911
+ common_freq = pd.Series(df["date"].unique()).diff().dt.days.dropna().mode()[0]
912
+
913
+ # Map the most common difference to a granularity
914
+ if common_freq == 1:
915
+ return "daily"
916
+ elif common_freq == 7:
917
+ return "weekly"
918
+ elif 28 <= common_freq <= 31:
919
+ return "monthly"
920
+ else:
921
+ return "irregular"
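+ # Example: for weekly data the modal day-difference between consecutive unique
+ # dates is 7, so this returns "weekly"; any other modal gap (e.g. 14 days)
+ # is reported as "irregular".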
922
+
923
+
924
+ # Function to clean display DataFrame
925
+ @st.cache_data(show_spinner=False)
926
+ def clean_display_df(df, display_max_col=500):
927
+ # Sort by 'panel' and 'date'
928
+ sort_columns = ["panel", "date"]
929
+ sorted_df = df.sort_values(by=sort_columns, ascending=True, na_position="first")
930
+
931
+ # Drop duplicate columns
932
+ sorted_df = sorted_df.loc[:, ~sorted_df.columns.duplicated()]
933
+
934
+ # Check if the DataFrame has more than display_max_col columns
935
+ exceeds_max_col = sorted_df.shape[1] > display_max_col
936
+
937
+ if exceeds_max_col:
938
+ # Create a new DataFrame with 'date' and 'panel' at the start
939
+ display_df = sorted_df[["date", "panel"]]
940
+
941
+ # Add the next display_max_col - 2 columns (as 'date' and 'panel' already occupy 2 columns)
942
+ additional_columns = sorted_df.columns.difference(["date", "panel"]).tolist()[
943
+ : display_max_col - 2
944
+ ]
945
+ display_df = pd.concat([display_df, sorted_df[additional_columns]], axis=1)
946
+ else:
947
+ # Ensure 'date' and 'panel' are the first two columns in the final display DataFrame
948
+ column_order = ["date", "panel"] + sorted_df.columns.difference(
949
+ ["date", "panel"]
950
+ ).tolist()
951
+ display_df = sorted_df[column_order]
952
+
953
+ # Return the display DataFrame and whether it exceeds 500 columns
954
+ return display_df, exceeds_max_col
955
+
956
+
957
+ #########################################################################################################################################################
958
+ # User input for transformations
959
+ #########################################################################################################################################################
960
+
961
+ try:
962
+ # Page Title
963
+ st.title("AI Model Transformations")
964
+
965
+ # Infer date granularity
966
+ date_granularity = infer_date_granularity(final_df_loaded)
967
+
968
+ # Initialize the main dictionary to store the transformation parameters for each category
969
+ transform_params = {"Media": {}, "Internal": {}, "Exogenous": {}}
970
+
971
+ st.markdown("### Select Transformations to Apply")
972
+
973
+ with st.expander("Specific Media Transformations"):
974
+ # Select which transformations to apply
975
+ sel_channel_specific = st.session_state["project_dct"]["transformations"][
976
+ "Specific"
977
+ ].get("channel_select_specific", [])
978
+
979
+ # Reset default selected channels list if options are changed
980
+ for channel in sel_channel_specific:
981
+ if channel not in bin_dict_loaded["Media"]:
982
+ (
983
+ st.session_state["project_dct"]["transformations"]["Specific"][
984
+ "channel_select_specific"
985
+ ],
986
+ sel_channel_specific,
987
+ ) = ([], [])
988
+
989
+ select_specific_channels = st.multiselect(
990
+ label="Select channel variable",
991
+ default=sel_channel_specific,
992
+ options=bin_dict_loaded["Media"],
993
+ key="channel_select_specific",
994
+ on_change=channel_select_specific_change,
995
+ max_selections=30,
996
+ )
997
+
998
+ specific_transform_params = {}
999
+ for select_specific_channel in select_specific_channels:
1000
+ specific_transform_params[select_specific_channel] = {}
1001
+
1002
+ st.divider()
1003
+ channel_name = str(select_specific_channel).replace("_", " ").title()
1004
+ st.markdown(f"###### {channel_name}")
1005
+
1006
+ specific_transformation_key = (
1007
+ f"specific_transformation_{select_specific_channel}_Media"
1008
+ )
1009
+
1010
+ transformations_options = [
1011
+ "Lag",
1012
+ "Moving Average",
1013
+ "Saturation",
1014
+ "Power",
1015
+ "Adstock",
1016
+ ]
1017
+
1018
+ # Select which transformations to apply
1019
+ sel_transformations = st.session_state["project_dct"]["transformations"][
1020
+ "Specific"
1021
+ ].get(specific_transformation_key, [])
1022
+
1023
+ # Reset saved transformation selections if the available options have changed
1024
+ for channel in sel_transformations:
1025
+ if channel not in transformations_options:
1026
+ (
1027
+ st.session_state["project_dct"]["transformations"]["Specific"][
1028
+ specific_transformation_key
1029
+ ],
1030
+ sel_transformations,
1031
+ ) = ([], [])
1032
+
1033
+ transformations_to_apply = st.multiselect(
1034
+ label="Select transformations to apply",
1035
+ options=transformations_options,
1036
+ default=sel_transformations,
1037
+ key=specific_transformation_key,
1038
+ on_change=specific_transformation_change,
1039
+ args=(specific_transformation_key,),
1040
+ )
1041
+
1042
+ # Determine the number of transformations to put in each column
1043
+ transformations_per_column = (
1044
+ len(transformations_to_apply) // 2 + len(transformations_to_apply) % 2
1045
+ )
1046
+
1047
+ # Create two columns
1048
+ col1, col2 = st.columns(2)
1049
+
1050
+ # Assign transformations to each column
1051
+ transformations_col1 = transformations_to_apply[:transformations_per_column]
1052
+ transformations_col2 = transformations_to_apply[transformations_per_column:]
1053
+
1054
+ # Create widgets in each column
1055
+ create_specific_transformation_widgets(
1056
+ col1,
1057
+ transformations_col1,
1058
+ select_specific_channel,
1059
+ date_granularity,
1060
+ specific_transform_params,
1061
+ )
1062
+ create_specific_transformation_widgets(
1063
+ col2,
1064
+ transformations_col2,
1065
+ select_specific_channel,
1066
+ date_granularity,
1067
+ specific_transform_params,
1068
+ )
1069
+
1070
+ # Create Widgets
1071
+ for category in ["Media", "Internal", "Exogenous"]:
1072
+ # Skip Internal
1073
+ if category == "Internal":
1074
+ continue
1075
+
1076
+ # Skip category if no column available
1077
+ elif (
1078
+ category not in bin_dict_loaded.keys()
1079
+ or len(bin_dict_loaded[category]) == 0
1080
+ ):
1081
+ st.info(
1082
+ f"{str(category).title()} category has no column associated with it. Skipping transformation step for this category.",
1083
+ icon="💬",
1084
+ )
1085
+ continue
1086
+
1087
+ transformation_widgets(category, transform_params, date_granularity)
1088
+
1089
+ #########################################################################################################################################################
1090
+ # Apply transformations
1091
+ #########################################################################################################################################################
1092
+
1093
+ # Reset transformation selection to default
1094
+ button_col = st.columns(2)
1095
+ with button_col[1]:
1096
+ if st.button("Reset to Default", use_container_width=True):
1097
+ st.session_state["project_dct"]["transformations"]["Media"] = {}
1098
+ st.session_state["project_dct"]["transformations"]["Exogenous"] = {}
1099
+ st.session_state["project_dct"]["transformations"]["Internal"] = {}
1100
+ st.session_state["project_dct"]["transformations"]["Specific"] = {}
1101
+
1102
+ # Log message
1103
+ log_message(
1104
+ "info",
1105
+ "All persistent selections have been reset to their default settings and cleared.",
1106
+ "Transformations",
1107
+ )
1108
+
1109
+ st.rerun()
1110
+
1111
+ # Apply category-based transformations to the DataFrame
1112
+ with button_col[0]:
1113
+ if st.button("Accept and Proceed", use_container_width=True):
1114
+ with st.spinner("Applying transformations ..."):
1115
+ final_df = apply_category_transformations(
1116
+ final_df_loaded.copy(),
1117
+ bin_dict_loaded.copy(),
1118
+ transform_params.copy(),
1119
+ panel.copy(),
1120
+ specific_transform_params.copy(),
1121
+ )
1122
+
1123
+ # Generate a dictionary mapping original column names to lists of transformed column names
1124
+ transformed_columns_dict, summary_string = generate_transformed_columns(
1125
+ original_columns, transform_params, specific_transform_params
1126
+ )
1127
+
1128
+ # Store into transformed dataframe and summary session state
1129
+ st.session_state["project_dct"]["transformations"][
1130
+ "final_df"
1131
+ ] = final_df
1132
+ st.session_state["project_dct"]["transformations"][
1133
+ "summary_string"
1134
+ ] = summary_string
1135
+
1136
+ # Display success message
1137
+ st.success("Transformation of the DataFrame is successful!", icon="✅")
1138
+
1139
+ # Log message
1140
+ log_message(
1141
+ "info",
1142
+ "Transformation of the DataFrame is successful!",
1143
+ "Transformations",
1144
+ )
1145
+
1146
+ #########################################################################################################################################################
1147
+ # Display the transformed DataFrame and summary
1148
+ #########################################################################################################################################################
1149
+
1150
+ # Display the transformed DataFrame in the Streamlit app
1151
+ st.markdown("### Transformed DataFrame")
1152
+ with st.spinner("Please wait while the transformed DataFrame is loading ..."):
1153
+ final_df = st.session_state["project_dct"]["transformations"]["final_df"].copy()
1154
+
1155
+ # Clean display DataFrame
1156
+ display_df, exceeds_max_col = clean_display_df(final_df, display_max_col)
1157
+
1158
+ # Check the number of columns and show only the first display_max_col if there are more
1159
+ if exceeds_max_col:
1160
+ # Display an info message if the DataFrame has more than display_max_col columns
1161
+ st.info(
1162
+ f"The transformed DataFrame has more than {display_max_col} columns. Displaying only the first {display_max_col} columns.",
1163
+ icon="💬",
1164
+ )
1165
+
1166
+ # Display Final DataFrame
1167
+ st.dataframe(
1168
+ display_df,
1169
+ hide_index=True,
1170
+ column_config={
1171
+ "date": st.column_config.DateColumn("date", format="YYYY-MM-DD")
1172
+ },
1173
+ )
1174
+
1175
+ # Total rows and columns
1176
+ total_rows, total_columns = st.session_state["project_dct"]["transformations"][
1177
+ "final_df"
1178
+ ].shape
1179
+ st.markdown(
1180
+ f"<p style='text-align: justify;'>The transformed DataFrame contains <strong>{total_rows}</strong> rows and <strong>{total_columns}</strong> columns.</p>",
1181
+ unsafe_allow_html=True,
1182
+ )
1183
+
1184
+ # Display the summary of transformations as markdown
1185
+ if (
1186
+ "summary_string" in st.session_state["project_dct"]["transformations"]
1187
+ and st.session_state["project_dct"]["transformations"]["summary_string"]
1188
+ ):
1189
+ with st.expander("Summary of Transformations"):
1190
+ st.markdown("### Summary of Transformations")
1191
+ st.markdown(
1192
+ st.session_state["project_dct"]["transformations"][
1193
+ "summary_string"
1194
+ ],
1195
+ unsafe_allow_html=True,
1196
+ )
1197
+
1198
+ #########################################################################################################################################################
1199
+ # Correlation Plot
1200
+ #########################################################################################################################################################
1201
+
1202
+ # Filter out the 'date' column
1203
+ variables = [
1204
+ col for col in final_df.columns if col.lower() not in ["date", "panel"]
1205
+ ]
1206
+
1207
+ with st.expander("Transformed Variable Correlation Plot"):
1208
+ selected_vars = st.multiselect(
1209
+ label="Choose variables for correlation plot:",
1210
+ options=variables,
1211
+ max_selections=30,
1212
+ default=st.session_state["project_dct"]["transformations"][
1213
+ "correlation_plot_selection"
1214
+ ],
1215
+ key="correlation_plot_key",
1216
+ )
1217
+
1218
+ # Calculate correlation
1219
+ if selected_vars:
1220
+ corr_df = final_df[selected_vars].corr()
1221
+
1222
+ # Prepare text annotations with 2 decimal places
1223
+ annotations = []
1224
+ for i in range(len(corr_df)):
1225
+ for j in range(len(corr_df.columns)):
1226
+ annotations.append(
1227
+ go.layout.Annotation(
1228
+ text=f"{corr_df.iloc[i, j]:.2f}",
1229
+ x=corr_df.columns[j],
1230
+ y=corr_df.index[i],
1231
+ showarrow=False,
1232
+ font=dict(color="black"),
1233
+ )
1234
+ )
1235
+
1236
+ # Plotly correlation plot using go
1237
+ heatmap = go.Heatmap(
1238
+ z=corr_df.values,
1239
+ x=corr_df.columns,
1240
+ y=corr_df.index,
1241
+ colorscale="RdBu",
1242
+ zmin=-1,
1243
+ zmax=1,
1244
+ )
1245
+
1246
+ layout = go.Layout(
1247
+ title="Transformed Variable Correlation Plot",
1248
+ xaxis=dict(title="Variables"),
1249
+ yaxis=dict(title="Variables"),
1250
+ width=1000,
1251
+ height=1000,
1252
+ annotations=annotations,
1253
+ )
1254
+
1255
+ fig = go.Figure(data=[heatmap], layout=layout)
1256
+
1257
+ st.plotly_chart(fig)
1258
+ else:
1259
+ st.write("Please select at least one variable to plot.")
1260
+
1261
+ #########################################################################################################################################################
1262
+ # Accept and Save
1263
+ #########################################################################################################################################################
1264
+
1265
+ # Check for saved model
1266
+ if (
1267
+ retrieve_pkl_object(
1268
+ st.session_state["project_number"], "Model_Build", "best_models", schema
1269
+ )
1270
+ is not None
1271
+ ): # db
1272
+ st.warning(
1273
+ "Saving transformations will overwrite existing ones and delete all saved models. To keep previous models, please start a new project.",
1274
+ icon="⚠️",
1275
+ )
1276
+
1277
+ if st.button("Accept and Save", use_container_width=True):
1278
+
1279
+ with st.spinner("Saving Changes"):
1280
+ # Update correlation plot selection
1281
+ st.session_state["project_dct"]["transformations"][
1282
+ "correlation_plot_selection"
1283
+ ] = st.session_state["correlation_plot_key"]
1284
+
1285
+ # Clear model metadata
1286
+ clear_pages()
1287
+
1288
+ # Update DB
1289
+ update_db(
1290
+ prj_id=st.session_state["project_number"],
1291
+ page_nam="Transformations",
1292
+ file_nam="project_dct",
1293
+ pkl_obj=pickle.dumps(st.session_state["project_dct"]),
1294
+ schema=schema,
1295
+ )
1296
+
1297
+ # Clear data from DB
1298
+ delete_entries(
1299
+ st.session_state["project_number"],
1300
+ ["Model_Build", "Model_Tuning"],
1301
+ db_cred,
1302
+ schema,
1303
+ )
1304
+
1305
+ # Success message
1306
+ st.success("Saved Successfully!", icon="💾")
1307
+ st.toast("Saved Successfully!", icon="💾")
1308
+
1309
+ # Log message
1310
+ log_message("info", "Saved Successfully!", "Transformations")
1311
+
1312
+ except Exception as e:
1313
+ # Capture the error details
1314
+ exc_type, exc_value, exc_traceback = sys.exc_info()
1315
+ error_message = "".join(
1316
+ traceback.format_exception(exc_type, exc_value, exc_traceback)
1317
+ )
1318
+
1319
+ # Log message
1320
+ log_message("error", f"An error occurred: {error_message}.", "Transformations")
1321
+
1322
+ # Display a warning message
1323
+ st.warning(
1324
+ "Oops! Something went wrong. Please try refreshing the tool or creating a new project.",
1325
+ icon="⚠️",
1326
+ )
pages/4_AI_Model_Build.py ADDED
The diff for this file is too large to render. See raw diff
 
pages/5_AI Model_Tuning.py ADDED
@@ -0,0 +1,1215 @@
1
+ """
2
+ MMO Build Sprint 3
3
+ date :
4
+ changes : capability to tune MixedLM as well as simple LR in the same page
5
+ """
6
+
7
+ import os
8
+ import streamlit as st
9
+ import pandas as pd
10
+ from data_analysis import format_numbers
11
+ import pickle
12
+ from utilities import set_header, load_local_css
13
+ import statsmodels.api as sm
14
+ import re
15
+ from sklearn.preprocessing import MaxAbsScaler
16
+ import matplotlib.pyplot as plt
17
+ from statsmodels.stats.outliers_influence import variance_inflation_factor
18
+ import statsmodels.formula.api as smf
19
+ from data_prep import *
20
+ import sqlite3
21
+ from utilities import (
22
+ set_header,
23
+ load_local_css,
24
+ update_db,
25
+ project_selection,
26
+ retrieve_pkl_object,
27
+ )
28
+ import numpy as np
29
+ from post_gres_cred import db_cred
30
+ import re
31
+ from constants import (
32
+ NUM_FLAG_COLS_TO_DISPLAY,
33
+ HALF_YEAR_THRESHOLD,
34
+ FULL_YEAR_THRESHOLD,
35
+ TREND_MIN,
36
+ ANNUAL_FREQUENCY,
37
+ QTR_FREQUENCY_FACTOR,
38
+ HALF_YEARLY_FREQUENCY_FACTOR,
39
+ )
40
+ from log_application import log_message
41
+ import sys, traceback
42
+
43
+ schema = db_cred["schema"]
44
+
45
+ st.set_option("deprecation.showPyplotGlobalUse", False)
46
+
47
+ st.set_page_config(
48
+ page_title="AI Model Tuning",
49
+ page_icon=":shark:",
50
+ layout="wide",
51
+ initial_sidebar_state="collapsed",
52
+ )
53
+ load_local_css("styles.css")
54
+ set_header()
55
+
56
+
57
+ # Define functions
58
+ # Get random effect from MixedLM Model
59
+ def get_random_effects(media_data, panel_col, _mdf):
60
+ # create an empty dataframe
61
+ random_eff_df = pd.DataFrame(columns=[panel_col, "random_effect"])
62
+
63
+ # Iterate over all panel values and add to dataframe
64
+ for i, market in enumerate(media_data[panel_col].unique()):
65
+ intercept = _mdf.random_effects[market].values[0]
66
+ random_eff_df.loc[i, "random_effect"] = intercept
67
+ random_eff_df.loc[i, panel_col] = market
68
+
69
+ return random_eff_df
70
+
71
+
72
+ # Predict on df using MixedLM model
73
+ def mdf_predict(X_df, mdf, random_eff_df):
74
+ # Create a copy of input df and predict using MixedLM model i.e fixed effect
75
+ X = X_df.copy()
76
+ X["fixed_effect"] = mdf.predict(X)
77
+
78
+ # Merge random effects
79
+ X = pd.merge(X, random_eff_df, on=panel_col, how="left")
80
+
81
+ # Get final predictions by adding random effect to fixed effect
82
+ X["pred"] = X["fixed_effect"] + X["random_effect"]
83
+
84
+ # Drop intermediate columns
85
+ X.drop(columns=["fixed_effect", "random_effect"], inplace=True)
86
+
87
+ return X["pred"]
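+ # Note: the final MixedLM prediction is the fixed-effect prediction plus the
+ # panel's estimated random intercept. Illustrative numbers: if the fixed effect
+ # predicts 120 for a row and that panel's random intercept is -15, "pred" is 105.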
88
+
89
+
90
+ def format_display(inp):
91
+ # Format display titles
92
+ return inp.title().replace("_", " ").strip()
93
+
94
+
95
+ if "username" not in st.session_state:
96
+ st.session_state["username"] = None
97
+ if "project_name" not in st.session_state:
98
+ st.session_state["project_name"] = None
99
+ if "project_dct" not in st.session_state:
100
+ project_selection()
101
+ st.stop()
102
+ if "Flags" not in st.session_state:
103
+ st.session_state["Flags"] = {}
104
+
105
+ try:
106
+ # Check Authentications
107
+ if "username" in st.session_state and st.session_state["username"] is not None:
108
+
109
+ if (
110
+ retrieve_pkl_object(
111
+ st.session_state["project_number"], "Model_Build", "best_models", schema
112
+ )
113
+ is None
114
+ ): # db
115
+
116
+ st.error("Please save a model before tuning")
117
+ log_message(
118
+ "warning",
119
+ "No models saved",
120
+ "Model Tuning",
121
+ )
122
+ st.stop()
123
+
124
+ # Read previous progress (persistence)
125
+ if (
126
+ "session_state_saved"
127
+ in st.session_state["project_dct"]["model_build"].keys()
128
+ ):
129
+ for key in [
130
+ "Model",
131
+ "date",
132
+ "saved_model_names",
133
+ "media_data",
134
+ "X_test_spends",
135
+ "spends_data",
136
+ ]:
137
+ if key not in st.session_state:
138
+ st.session_state[key] = st.session_state["project_dct"][
139
+ "model_build"
140
+ ]["session_state_saved"][key]
141
+ st.session_state["bin_dict"] = st.session_state["project_dct"][
142
+ "model_build"
143
+ ]["session_state_saved"]["bin_dict"]
144
+ if (
145
+ "used_response_metrics" not in st.session_state
146
+ or st.session_state["used_response_metrics"] == []
147
+ ):
148
+ st.session_state["used_response_metrics"] = st.session_state[
149
+ "project_dct"
150
+ ]["model_build"]["session_state_saved"]["used_response_metrics"]
151
+ else:
152
+ st.error("Please load a session with a built model")
153
+ log_message(
154
+ "error",
155
+ "Session state saved not found in Project Dictionary",
156
+ "Model Tuning",
157
+ )
158
+ st.stop()
159
+
160
+ for key in ["select_all_flags_check", "selected_flags", "sel_model"]:
161
+ if key not in st.session_state["project_dct"]["model_tuning"].keys():
162
+ st.session_state["project_dct"]["model_tuning"][key] = {}
163
+
164
+ # is_panel = st.session_state['is_panel']
165
+ # panel_col = 'markets' # set the panel column
166
+ date_col = "date"
167
+
168
+ # set the panel column
169
+ panel_col = "panel"
170
+ is_panel = (
171
+ True if st.session_state["media_data"][panel_col].nunique() > 1 else False
172
+ )
173
+
174
+ if "Model_Tuned" not in st.session_state:
175
+ st.session_state["Model_Tuned"] = {}
176
+ cols1 = st.columns([2, 1])
177
+ with cols1[0]:
178
+ st.markdown(f"**Welcome {st.session_state['username']}**")
179
+ with cols1[1]:
180
+ st.markdown(f"**Current Project: {st.session_state['project_name']}**")
181
+
182
+ st.title("AI Model Tuning")
183
+
184
+ # flag indicating there is not tuned model till now
185
+ if "is_tuned_model" not in st.session_state:
186
+ st.session_state["is_tuned_model"] = {}
187
+
188
+ # # Read all saved models
189
+ model_dict = retrieve_pkl_object(
190
+ st.session_state["project_number"], "Model_Build", "best_models", schema
191
+ )
192
+ saved_models = model_dict.keys()
193
+
194
+ # Get list of response metrics
195
+ st.session_state["used_response_metrics"] = list(
196
+ set([model.split("__")[1] for model in saved_models])
197
+ )
198
+
199
+ # Select previously selected response_metric (persistence)
200
+ default_target_idx = (
201
+ st.session_state["project_dct"]["model_tuning"].get("sel_target_col", None)
202
+ if st.session_state["project_dct"]["model_tuning"].get(
203
+ "sel_target_col", None
204
+ )
205
+ is not None
206
+ else st.session_state["used_response_metrics"][0]
207
+ )
208
+
209
+ # Dropdown to select response metric
210
+ sel_target_col = st.selectbox(
211
+ "Select the response metric",
212
+ st.session_state["used_response_metrics"],
213
+ index=st.session_state["used_response_metrics"].index(default_target_idx),
214
+ format_func=format_display,
215
+ )
216
+ # Format selected response metrics (target col)
217
+ target_col = (
218
+ sel_target_col.lower()
219
+ .replace(" ", "_")
220
+ .replace("-", "")
221
+ .replace(":", "")
222
+ .replace("__", "_")
223
+ )
224
+ st.session_state["project_dct"]["model_tuning"][
225
+ "sel_target_col"
226
+ ] = sel_target_col
227
+
228
+ # Look through all saved models, only show saved models of the selected resp metric (target_col)
229
+ # Get a list of models saved for selected response metric
230
+ required_saved_models = [
231
+ m.split("__")[0] for m in saved_models if m.split("__")[1] == target_col
232
+ ]
233
+
234
+ # Get previously selected model if available (persistence)
235
+ default_model_idx = st.session_state["project_dct"]["model_tuning"][
236
+ "sel_model"
237
+ ].get(sel_target_col, required_saved_models[0])
238
+ sel_model = st.selectbox(
239
+ "Select the model to tune",
240
+ required_saved_models,
241
+ index=required_saved_models.index(default_model_idx),
242
+ )
243
+
244
+ st.session_state["project_dct"]["model_tuning"]["sel_model"][
245
+ sel_target_col
246
+ ] = sel_model
247
+
248
+ sel_model_dict = model_dict[
249
+ sel_model + "__" + target_col
250
+ ] # get the model obj of the selected model
251
+
252
+ X_train = sel_model_dict["X_train"]
253
+ X_test = sel_model_dict["X_test"]
254
+ y_train = sel_model_dict["y_train"]
255
+ y_test = sel_model_dict["y_test"]
256
+ df = st.session_state["media_data"]
257
+
258
+ st.markdown("### Event Flags")
259
+ st.markdown("Helps in quantifying the impact of specific occurrences of events")
260
+
261
+ try:
262
+ # Dropdown to add event flags
263
+ with st.expander("Apply Event Flags"):
264
+
265
+ model = sel_model_dict["Model_object"]
266
+ date = st.session_state["date"]
267
+ date = pd.to_datetime(date)
268
+ X_train = sel_model_dict["X_train"]
269
+
270
+ features_set = sel_model_dict["feature_set"]
271
+
272
+ col = st.columns(3)
273
+
274
+ # Get date range
275
+ min_date = min(date).date()
276
+ max_date = max(date).date()
277
+
278
+ # Get previously selected start and end date of flag (persistence)
279
+ start_date_default = (
280
+ st.session_state["project_dct"]["model_tuning"].get(
281
+ "start_date_default"
282
+ )
283
+ if st.session_state["project_dct"]["model_tuning"].get(
284
+ "start_date_default"
285
+ )
286
+ is not None
287
+ else min_date
288
+ )
289
+ start_date_default = (
290
+ start_date_default if start_date_default > min_date else min_date
291
+ )
292
+ start_date_default = (
293
+ start_date_default if start_date_default < max_date else min_date
294
+ )
295
+
296
+ end_date_default = (
297
+ st.session_state["project_dct"]["model_tuning"].get(
298
+ "end_date_default"
299
+ )
300
+ if st.session_state["project_dct"]["model_tuning"].get(
301
+ "end_date_default"
302
+ )
303
+ is not None
304
+ else max_date
305
+ )
306
+ end_date_default = (
307
+ end_date_default if end_date_default > min_date else max_date
308
+ )
309
+ end_date_default = (
310
+ end_date_default if end_date_default < max_date else max_date
311
+ )
312
+
313
+ # Flag start and end date input boxes
314
+ with col[0]:
315
+ start_date = st.date_input(
316
+ "Select Start Date",
317
+ start_date_default,
318
+ min_value=min_date,
319
+ max_value=max_date,
320
+ )
321
+
322
+ if (start_date < min_date) or (start_date > max_date):
323
+ st.error(
324
+ "Please select dates in the range of the dates in the data"
325
+ )
326
+ st.stop()
327
+ with col[1]:
328
+ # Check if end date default > selected start date
329
+ end_date_default = (
330
+ end_date_default
331
+ if pd.Timestamp(end_date_default) >= pd.Timestamp(start_date)
332
+ else start_date
333
+ )
334
+ end_date = st.date_input(
335
+ "Select End Date",
336
+ end_date_default,
337
+ min_value=max(
338
+ pd.to_datetime(min_date), pd.to_datetime(start_date)
339
+ ),
340
+ max_value=pd.to_datetime(max_date),
341
+ )
342
+
343
+ if (
344
+ (start_date < min_date)
345
+ or (end_date < min_date)
346
+ or (start_date > max_date)
347
+ or (end_date > max_date)
348
+ ):
349
+ st.error(
350
+ "Please select dates in the range of the dates in the data"
351
+ )
352
+ st.stop()
353
+ if end_date < start_date:
354
+ st.error("Please select end date after start date")
355
+ st.stop()
356
+ with col[2]:
357
+ # Get default value of repeat check box (persistence)
358
+ repeat_default = (
359
+ st.session_state["project_dct"]["model_tuning"].get(
360
+ "repeat_default"
361
+ )
362
+ if st.session_state["project_dct"]["model_tuning"].get(
363
+ "repeat_default"
364
+ )
365
+ is not None
366
+ else "No"
367
+ )
368
+ repeat_default_idx = 0 if repeat_default.lower() == "yes" else 1
369
+ repeat = st.selectbox(
370
+ "Repeat Annually", ["Yes", "No"], index=repeat_default_idx
371
+ )
372
+
373
+ # Update selected values to session dictionary (persistence)
374
+ st.session_state["project_dct"]["model_tuning"][
375
+ "start_date_default"
376
+ ] = start_date
377
+ st.session_state["project_dct"]["model_tuning"][
378
+ "end_date_default"
379
+ ] = end_date
380
+ st.session_state["project_dct"]["model_tuning"][
381
+ "repeat_default"
382
+ ] = repeat
383
+
384
+ if repeat == "Yes":
385
+ repeat = True
386
+ else:
387
+ repeat = False
388
+
389
+ if "flags" in st.session_state["project_dct"]["model_tuning"].keys():
390
+ st.session_state["Flags"] = st.session_state["project_dct"][
391
+ "model_tuning"
392
+ ]["flags"]
393
+
394
+ if is_panel:
395
+ # Create flag on Train
396
+ met, line_values, fig_flag = plot_actual_vs_predicted(
397
+ X_train[date_col],
398
+ y_train,
399
+ model.fittedvalues,
400
+ model,
401
+ target_column=sel_target_col,
402
+ flag=(start_date, end_date),
403
+ repeat_all_years=repeat,
404
+ is_panel=True,
405
+ )
406
+ st.plotly_chart(fig_flag, use_container_width=True)
407
+
408
+ # create flag on test
409
+ met, test_line_values, fig_flag = plot_actual_vs_predicted(
410
+ X_test[date_col],
411
+ y_test,
412
+ sel_model_dict["pred_test"],
413
+ model,
414
+ target_column=sel_target_col,
415
+ flag=(start_date, end_date),
416
+ repeat_all_years=repeat,
417
+ is_panel=True,
418
+ )
419
+
420
+ else:
421
+ pred_train = model.predict(X_train[features_set])
422
+ # Create flag on Train
423
+ met, line_values, fig_flag = plot_actual_vs_predicted(
424
+ X_train[date_col],
425
+ y_train,
426
+ pred_train,
427
+ model,
428
+ flag=(start_date, end_date),
429
+ repeat_all_years=repeat,
430
+ is_panel=False,
431
+ )
432
+ st.plotly_chart(fig_flag, use_container_width=True)
433
+
434
+ # create flag on test
435
+ pred_test = model.predict(X_test[features_set])
436
+ met, test_line_values, fig_flag = plot_actual_vs_predicted(
437
+ X_test[date_col],
438
+ y_test,
439
+ pred_test,
440
+ model,
441
+ flag=(start_date, end_date),
442
+ repeat_all_years=repeat,
443
+ is_panel=False,
444
+ )
445
+
446
+ flag_name = "f1_flag"
447
+ flag_name = st.text_input("Enter Flag Name")
448
+
449
+ # add selected target col to flag name
450
+ # Save the flag name, flag train values, flag test values to session state
451
+ if st.button("Save flag"):
452
+ st.session_state["Flags"][flag_name + "_flag__" + target_col] = {}
453
+ st.session_state["Flags"][flag_name + "_flag__" + target_col][
454
+ "train"
455
+ ] = line_values
456
+ st.session_state["Flags"][flag_name + "_flag__" + target_col][
457
+ "test"
458
+ ] = test_line_values
459
+ st.success(f'{flag_name + "_flag__" + target_col} stored')
460
+
461
+ st.session_state["project_dct"]["model_tuning"]["flags"] = (
462
+ st.session_state["Flags"]
463
+ )
464
+
465
+ # Only show flags created for the particular target col
466
+ target_model_flags = [
467
+ f.split("__")[0]
468
+ for f in st.session_state["Flags"].keys()
469
+ if f.split("__")[1] == target_col
470
+ ]
471
+ options = list(target_model_flags)
472
+ num_rows = -(-len(options) // NUM_FLAG_COLS_TO_DISPLAY)
473
+
474
+ tick = False
475
+ # Select all flags checkbox
476
+ if st.checkbox(
477
+ "Select all",
478
+ value=st.session_state["project_dct"]["model_tuning"][
479
+ "select_all_flags_check"
480
+ ].get(sel_target_col, False),
481
+ ):
482
+ tick = True
483
+ st.session_state["project_dct"]["model_tuning"][
484
+ "select_all_flags_check"
485
+ ][sel_target_col] = True
486
+ else:
487
+ st.session_state["project_dct"]["model_tuning"][
488
+ "select_all_flags_check"
489
+ ][sel_target_col] = False
490
+
491
+ # Get previous flag selection (persistence)
492
+ selection_defualts = st.session_state["project_dct"]["model_tuning"][
493
+ "selected_flags"
494
+ ].get(sel_target_col, [])
495
+ selected_options = selection_defualts
496
+
497
+ # create a checkbox for each available flag for selected response metric
498
+ for row in range(num_rows):
499
+ cols = st.columns(NUM_FLAG_COLS_TO_DISPLAY)
500
+ for col in cols:
501
+ if options:
502
+ option = options.pop(0)
503
+ option_default = True if option in selection_defualts else False
504
+ selected = col.checkbox(option, value=(tick or option_default))
505
+ if selected:
506
+ selected_options.append(option)
507
+ else:
508
+ if option in selected_options:
509
+ selected_options.remove(option)
510
+ selected_options = list(set(selected_options))
511
+
512
+ # Check if flag values match Data
513
+ # This is necessary because different models can have different train/test dates
514
+ remove_flags = []
515
+ for opt in selected_options:
516
+ train_match = len(
517
+ st.session_state["Flags"][opt + "__" + target_col]["train"]
518
+ ) == len(X_train[date_col])
519
+ test_match = len(
520
+ st.session_state["Flags"][opt + "__" + target_col]["test"]
521
+ ) == len(X_test[date_col])
522
+ if not train_match:
523
+ st.warning(f"Flag {opt} can not be used due to train date mismatch")
524
+ # selected_options.remove(opt)
525
+ remove_flags.append(opt)
526
+ if not test_match:
527
+ st.warning(f"Flag {opt} can not be used due to test date mismatch")
528
+ # selected_options.remove(opt)
529
+ remove_flags.append(opt)
530
+
531
+ if (
532
+ len(remove_flags) > 0
533
+ and len(list(set(selected_options).intersection(set(remove_flags)))) > 0
534
+ ):
535
+ selected_options = list(set(selected_options) - set(remove_flags))
536
+
537
+ st.session_state["project_dct"]["model_tuning"]["selected_flags"][
538
+ sel_target_col
539
+ ] = selected_options
540
+ except:
541
+ # Capture the error details
542
+ exc_type, exc_value, exc_traceback = sys.exc_info()
543
+ error_message = "".join(
544
+ traceback.format_exception(exc_type, exc_value, exc_traceback)
545
+ )
546
+ log_message(
547
+ "error", f"Error while creating flags: {error_message}", "Model Tuning"
548
+ )
549
+ st.warning("An error occured, please try again", icon="⚠️")
550
+
551
+ try:
552
+ st.markdown("### Trend and Seasonality Calibration")
553
+ parameters = st.columns(3)
554
+
555
+ # Trend checkbox
556
+ with parameters[0]:
557
+ Trend = st.checkbox(
558
+ "**Trend**",
559
+ value=st.session_state["project_dct"]["model_tuning"].get(
560
+ "trend_check", False
561
+ ),
562
+ )
563
+ st.markdown(
564
+ "Helps account for long-term trends or seasonality that could influence advertising effectiveness"
565
+ )
566
+
567
+ # Day of Week (week number) checkbox
568
+ with parameters[1]:
569
+ day_of_week = st.checkbox(
570
+ "**Day of Week**",
571
+ value=st.session_state["project_dct"]["model_tuning"].get(
572
+ "week_num_check", False
573
+ ),
574
+ )
575
+ st.markdown(
576
+ "Assists in detecting and incorporating weekly patterns or seasonality"
577
+ )
578
+
579
+ # Sine and cosine Waves checkbox
580
+ with parameters[2]:
581
+ sine_cosine = st.checkbox(
582
+ "**Sine and Cosine Waves**",
583
+ value=st.session_state["project_dct"]["model_tuning"].get(
584
+ "sine_cosine_check", False
585
+ ),
586
+ )
587
+ st.markdown(
588
+ "Helps in capturing long term cyclical patterns or seasonality in the data"
589
+ )
590
+
591
+ if sine_cosine:
592
+ # Drop down to select Frequency of waves
593
+ xtrain_time_period_months = (
594
+ X_train[date_col].max() - X_train[date_col].min()
595
+ ).days / 30
596
+
597
+ # If we have 6 months of data, only quarter frequency is possible
598
+ if xtrain_time_period_months <= HALF_YEAR_THRESHOLD:
599
+ available_frequencies = ["Quarter"]
600
+
601
+ # If we have less than 12 months of data, we have quarter and semi-annual frequencies
602
+ elif xtrain_time_period_months < FULL_YEAR_THRESHOLD:
603
+ available_frequencies = ["Quarter", "Semi-Annual"]
604
+
605
+ # If we have 12 months of data or more, we have quarter, semi-annual and annual frequencies
606
+ elif xtrain_time_period_months >= FULL_YEAR_THRESHOLD:
607
+ available_frequencies = ["Quarter", "Semi-Annual", "Annual"]
608
+
609
+ wave_freq = st.selectbox("Select Frequency", available_frequencies)
610
+ except:
611
+ # Capture the error details
612
+ exc_type, exc_value, exc_traceback = sys.exc_info()
613
+ error_message = "".join(
614
+ traceback.format_exception(exc_type, exc_value, exc_traceback)
615
+ )
616
+ log_message(
617
+ "error",
618
+ f"Error while selecting tuning parameters: {error_message}",
619
+ "Model Tuning",
620
+ )
621
+ st.warning("An error occured, please try again", icon="⚠️")
622
+
623
+ try:
624
+ # Build tuned model
625
+ if st.button(
626
+ "Build model with Selected Parameters and Flags",
627
+ key="build_tuned_model",
628
+ use_container_width=True,
629
+ ):
630
+ new_features = features_set
631
+ st.header("2.1 Results Summary")
632
+ ss = MaxAbsScaler()
633
+ if is_panel == True:
634
+ X_train_tuned = X_train[features_set]
635
+ X_train_tuned[target_col] = X_train[target_col]
636
+ X_train_tuned[date_col] = X_train[date_col]
637
+ X_train_tuned[panel_col] = X_train[panel_col]
638
+
639
+ X_test_tuned = X_test[features_set]
640
+ X_test_tuned[target_col] = X_test[target_col]
641
+ X_test_tuned[date_col] = X_test[date_col]
642
+ X_test_tuned[panel_col] = X_test[panel_col]
643
+
644
+ else:
645
+ X_train_tuned = X_train[features_set]
646
+
647
+ X_test_tuned = X_test[features_set]
648
+
649
+ for flag in selected_options:
650
+ # Get the flag values of train and test and add to the data
651
+ X_train_tuned[flag] = st.session_state["Flags"][
652
+ flag + "__" + target_col
653
+ ]["train"]
654
+ X_test_tuned[flag] = st.session_state["Flags"][
655
+ flag + "__" + target_col
656
+ ]["test"]
657
+
658
+ if Trend:
659
+ st.session_state["project_dct"]["model_tuning"][
660
+ "trend_check"
661
+ ] = True
662
+
663
+ # Group by panel and calculate the trend of each panel separately; add "Trend" to the new feature set
664
+ if is_panel:
665
+ newdata = pd.DataFrame()
666
+ panel_wise_end_point_train = {}
667
+ for panel, groupdf in X_train_tuned.groupby(panel_col):
668
+ groupdf.sort_values(date_col, inplace=True)
669
+ groupdf["Trend"] = np.arange(
670
+ TREND_MIN, len(groupdf) + TREND_MIN, 1
671
+ ) # Trend is a straight line with starting point as TREND_MIN
672
+ newdata = pd.concat([newdata, groupdf])
673
+ panel_wise_end_point_train[panel] = len(groupdf) + TREND_MIN
674
+ X_train_tuned = newdata.copy()
675
+
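+ # Continue each panel's trend in the test window from where that panel's
+ # train trend ended, so the trend feature is unbroken across the train/test split.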
676
+ test_newdata = pd.DataFrame()
677
+ for panel, test_groupdf in X_test_tuned.groupby(panel_col):
678
+ test_groupdf.sort_values(date_col, inplace=True)
679
+ start = panel_wise_end_point_train[panel]
680
+ end = start + len(test_groupdf)
681
+ test_groupdf["Trend"] = np.arange(start, end, 1)
682
+ test_newdata = pd.concat([test_newdata, test_groupdf])
683
+ X_test_tuned = test_newdata.copy()
684
+
685
+ new_features = new_features + ["Trend"]
686
+
687
+ else:
688
+ X_train_tuned["Trend"] = np.arange(
689
+ TREND_MIN, len(X_train_tuned) + TREND_MIN, 1
690
+ ) # Trend is a straight line with starting point as TREND_MIN
691
+ X_test_tuned["Trend"] = np.arange(
692
+ len(X_train_tuned) + TREND_MIN,
693
+ len(X_train_tuned) + len(X_test_tuned) + TREND_MIN,
694
+ 1,
695
+ )
696
+ new_features = new_features + ["Trend"]
697
+
698
+ else:
699
+ st.session_state["project_dct"]["model_tuning"][
700
+ "trend_check"
701
+ ] = False # persistence
702
+
703
+ # Add day of week (Week_num) to test & train
704
+ if day_of_week:
705
+ st.session_state["project_dct"]["model_tuning"][
706
+ "week_num_check"
707
+ ] = True
708
+
709
+ if is_panel:
710
+ X_train_tuned[date_col] = pd.to_datetime(
711
+ X_train_tuned[date_col]
712
+ )
713
+ X_train_tuned["day_of_week"] = X_train_tuned[
714
+ date_col
715
+ ].dt.day_of_week # Day of week
716
+ # If all dates in the data fall on the same weekday, this feature can't be used
717
+ if X_train_tuned["day_of_week"].nunique() == 1:
718
+ st.error(
719
+ "All dates in the data are of the same week day. Hence Week number can't be used."
720
+ )
721
+ else:
722
+ X_test_tuned[date_col] = pd.to_datetime(
723
+ X_test_tuned[date_col]
724
+ )
725
+ X_test_tuned["day_of_week"] = X_test_tuned[
726
+ date_col
727
+ ].dt.day_of_week # Day of week
728
+ new_features = new_features + ["day_of_week"]
729
+
730
+ else:
731
+ date = pd.to_datetime(date.values)
732
+ X_train_tuned["day_of_week"] = pd.to_datetime(
733
+ X_train[date_col]
734
+ ).dt.day_of_week # Day of week
735
+ X_test_tuned["day_of_week"] = pd.to_datetime(
736
+ X_test[date_col]
737
+ ).dt.day_of_week # Day of week
738
+
739
+ # If all dates in the data fall on the same weekday, this feature can't be used
740
+ if X_train_tuned["day_of_week"].nunique() == 1:
741
+ st.error(
742
+ "All dates in the data are of the same week day. Hence Week number can't be used."
743
+ )
744
+ else:
745
+ new_features = new_features + ["day_of_week"]
746
+ else:
747
+ st.session_state["project_dct"]["model_tuning"][
748
+ "week_num_check"
749
+ ] = False
750
+
751
+ # create sine and cosine wave and add to data
752
+ if sine_cosine:
753
+ st.session_state["project_dct"]["model_tuning"][
754
+ "sine_cosine_check"
755
+ ] = True
756
+ frequency = ANNUAL_FREQUENCY # Annual Frequency
757
+ if wave_freq == "Quarter":
758
+ frequency = frequency * QTR_FREQUENCY_FACTOR
759
+ elif wave_freq == "Semi-Annual":
760
+ frequency = frequency * HALF_YEARLY_FREQUENCY_FACTOR
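+ # The frequency above is the angular step for the Fourier pair created below:
+ # sine_wave = sin(frequency * t) and cosine_wave = cos(frequency * t),
+ # where t indexes observations since the start of the series.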
761
+ # create panel wise sine cosine waves in xtrain tuned. add to new feature set
762
+ if is_panel:
763
+ new_features = new_features + ["sine_wave", "cosine_wave"]
764
+ newdata = pd.DataFrame()
765
+ newdata_test = pd.DataFrame()
766
+ groups = X_train_tuned.groupby(panel_col)
767
+
768
+ train_panel_wise_end_point = {}
769
+ for panel, groupdf in groups:
770
+ num_samples = len(groupdf)
771
+ train_panel_wise_end_point[panel] = num_samples
772
+ days_since_start = np.arange(num_samples)
773
+ sine_wave = np.sin(frequency * days_since_start)
774
+ cosine_wave = np.cos(frequency * days_since_start)
775
+ sine_cosine_df = pd.DataFrame(
776
+ {"sine_wave": sine_wave, "cosine_wave": cosine_wave}
777
+ )
778
+ assert len(sine_cosine_df) == len(groupdf)
779
+
780
+ groupdf["sine_wave"] = sine_wave
781
+ groupdf["cosine_wave"] = cosine_wave
782
+ newdata = pd.concat([newdata, groupdf])
783
+
784
+ X_train_tuned = newdata.copy()
785
+
786
+ test_groups = X_test_tuned.groupby(panel_col)
787
+ for panel, test_groupdf in test_groups:
788
+ num_samples = len(test_groupdf)
789
+ start = train_panel_wise_end_point[panel]
790
+ days_since_start = np.arange(start, start + num_samples, 1)
791
+ # print("##", panel, num_samples, start, len(np.arange(start, start+num_samples, 1)))
792
+ sine_wave = np.sin(frequency * days_since_start)
793
+ cosine_wave = np.cos(frequency * days_since_start)
794
+ sine_cosine_df = pd.DataFrame(
795
+ {"sine_wave": sine_wave, "cosine_wave": cosine_wave}
796
+ )
797
+ assert len(sine_cosine_df) == len(test_groupdf)
798
+ # groupdf = pd.concat([groupdf, sine_cosine_df], axis=1)
799
+ test_groupdf["sine_wave"] = sine_wave
800
+ test_groupdf["cosine_wave"] = cosine_wave
801
+ newdata_test = pd.concat([newdata_test, test_groupdf])
802
+
803
+ X_test_tuned = newdata_test.copy()
804
+
805
+ else:
806
+ new_features = new_features + ["sine_wave", "cosine_wave"]
807
+
808
+ num_samples = len(X_train_tuned)
809
+
810
+ days_since_start = np.arange(num_samples)
811
+ sine_wave = np.sin(frequency * days_since_start)
812
+ cosine_wave = np.cos(frequency * days_since_start)
813
+ sine_cosine_df = pd.DataFrame(
814
+ {"sine_wave": sine_wave, "cosine_wave": cosine_wave}
815
+ )
816
+ # Concatenate the sine and cosine waves with the scaled X DataFrame
817
+ X_train_tuned = pd.concat(
818
+ [X_train_tuned, sine_cosine_df], axis=1
819
+ )
820
+
821
+ test_num_samples = len(X_test_tuned)
822
+ start = num_samples
823
+ days_since_start = np.arange(start, start + test_num_samples, 1)
824
+ sine_wave = np.sin(frequency * days_since_start)
825
+ cosine_wave = np.cos(frequency * days_since_start)
826
+ sine_cosine_df = pd.DataFrame(
827
+ {"sine_wave": sine_wave, "cosine_wave": cosine_wave}
828
+ )
829
+ # Concatenate the sine and cosine waves with the scaled X DataFrame
830
+ X_test_tuned = pd.concat([X_test_tuned, sine_cosine_df], axis=1)
831
+
832
+ else:
833
+ st.session_state["project_dct"]["model_tuning"][
834
+ "sine_cosine_check"
835
+ ] = False
836
+
837
+ # Build model
838
+
839
+ # Get list of parameters added and scale
840
+ # previous features are scaled already during model build
841
+ added_params = list(set(new_features) - set(features_set))
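+ # Train and test are concatenated so MaxAbsScaler sees one consistent range
+ # for the newly added features, then the scaled values are split back apart below.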
842
+ if len(added_params) > 0:
843
+ concat_df = pd.concat([X_train_tuned, X_test_tuned]).reset_index(
844
+ drop=True
845
+ )
846
+
847
+ if is_panel:
848
+ train_max_date = X_train_tuned[date_col].max()
849
+
850
+ # concat_df = concat_df.reset_index(drop=True)
851
+ # concat_df=concat_df[added_params]
852
+ train_idx = X_train_tuned.index[-1]
853
+
854
+ concat_df[added_params] = ss.fit_transform(concat_df[added_params])
855
+ # added_params_df = pd.DataFrame(added_params_df)
856
+ # added_params_df.columns = added_params
857
+
858
+ if is_panel:
859
+ X_train_tuned[added_params] = concat_df[
860
+ concat_df[date_col] <= train_max_date
861
+ ][added_params].reset_index(drop=True)
862
+ X_test_tuned[added_params] = concat_df[
863
+ concat_df[date_col] > train_max_date
864
+ ][added_params].reset_index(drop=True)
865
+ else:
866
+ added_params_df = concat_df[added_params]
867
+ # Split the scaled values back positionally: the first len(X_train_tuned)
+ # rows of the concatenated frame came from train, the rest from test.
+ X_train_tuned[added_params] = added_params_df.iloc[: len(X_train_tuned)].values
868
+ X_test_tuned[added_params] = added_params_df.iloc[
869
+ len(X_train_tuned) :
870
+ ].values
871
+
872
+ # Add flags (flags are 0/1 only, so no scaling is needed)
873
+ if selected_options:
874
+ new_features = new_features + selected_options
875
+
876
+ # Build Mixed LM model for panel level data
877
+ if is_panel:
878
+ X_train_tuned = X_train_tuned.sort_values(
879
+ [date_col, panel_col]
880
+ ).reset_index(drop=True)
881
+
882
+ new_features = list(set(new_features))
883
+ inp_vars_str = " + ".join(new_features)
884
+
885
+ md_str = target_col + " ~ " + inp_vars_str
886
+ md_tuned = smf.mixedlm(
887
+ md_str,
888
+ data=X_train_tuned[[target_col] + new_features],
889
+ groups=X_train_tuned[panel_col],
890
+ )
891
+ model_tuned = md_tuned.fit()
892
+
893
+ # plot actual vs predicted for original model and tuned model
894
+ metrics_table, line, actual_vs_predicted_plot = (
895
+ plot_actual_vs_predicted(
896
+ X_train[date_col],
897
+ y_train,
898
+ model.fittedvalues,
899
+ model,
900
+ target_column=sel_target_col,
901
+ is_panel=True,
902
+ )
903
+ )
904
+ metrics_table_tuned, line, actual_vs_predicted_plot_tuned = (
905
+ plot_actual_vs_predicted(
906
+ X_train_tuned[date_col],
907
+ X_train_tuned[target_col],
908
+ model_tuned.fittedvalues,
909
+ model_tuned,
910
+ target_column=sel_target_col,
911
+ is_panel=True,
912
+ )
913
+ )
914
+
915
+ # Build OLS model for non-panel data
916
+ else:
917
+ new_features = list(set(new_features))
918
+ model_tuned = sm.OLS(y_train, X_train_tuned[new_features]).fit()
919
+ metrics_table, line, actual_vs_predicted_plot = (
920
+ plot_actual_vs_predicted(
921
+ X_train[date_col],
922
+ y_train,
923
+ model.predict(X_train[features_set]),
924
+ model,
925
+ target_column=sel_target_col,
926
+ )
927
+ )
928
+
929
+ metrics_table_tuned, line, actual_vs_predicted_plot_tuned = (
930
+ plot_actual_vs_predicted(
931
+ X_train[date_col],
932
+ y_train,
933
+ model_tuned.predict(X_train_tuned[new_features]),
934
+ model_tuned,
935
+ target_column=sel_target_col,
936
+ )
937
+ )
938
+
939
+ # # ----------------------------------- TESTING -----------------------------------
940
+ #
941
+ # Plot Sine & cosine wave to test
942
+ # sine_cosine_plot = plot_actual_vs_predicted(
943
+ # X_train[date_col],
944
+ # y_train,
945
+ # X_train_tuned['sine_wave'],
946
+ # model_tuned,
947
+ # target_column=sel_target_col,
948
+ # is_panel=True,
949
+ # )
950
+ # st.plotly_chart(sine_cosine_plot, use_container_width=True)
951
+
952
+ # # Plot Trend line to test
953
+ # trend_plot = plot_tuned_params(
954
+ # X_train[date_col],
955
+ # y_train,
956
+ # X_train_tuned['Trend'],
957
+ # model_tuned,
958
+ # target_column=sel_target_col,
959
+ # is_panel=True,
960
+ # )
961
+ # st.plotly_chart(trend_plot, use_container_width=True)
962
+ #
963
+ # # Plot week number to test
964
+ # week_num_plot = plot_tuned_params(
965
+ # X_train[date_col],
966
+ # y_train,
967
+ # X_train_tuned['day_of_week'],
968
+ # model_tuned,
969
+ # target_column=sel_target_col,
970
+ # is_panel=True,
971
+ # )
972
+ # st.plotly_chart(week_num_plot, use_container_width=True)
973
+
974
+ # Get model metrics from metric table & display them
975
+ mape = np.round(metrics_table.iloc[0, 1], 2)
976
+ r2 = np.round(metrics_table.iloc[1, 1], 2)
977
+ adjr2 = np.round(metrics_table.iloc[2, 1], 2)
978
+
979
+ mape_tuned = np.round(metrics_table_tuned.iloc[0, 1], 2)
980
+ r2_tuned = np.round(metrics_table_tuned.iloc[1, 1], 2)
981
+ adjr2_tuned = np.round(metrics_table_tuned.iloc[2, 1], 2)
982
+
983
+ parameters_ = st.columns(3)
984
+ with parameters_[0]:
985
+ st.metric("R-squared", r2_tuned, np.round(r2_tuned - r2, 2))
986
+ with parameters_[1]:
987
+ st.metric(
988
+ "Adj. R-squared", adjr2_tuned, np.round(adjr2_tuned - adjr2, 2)
989
+ )
990
+ with parameters_[2]:
991
+ st.metric(
992
+ "MAPE", mape_tuned, np.round(mape_tuned - mape, 2), "inverse"
993
+ )
994
+
995
+ st.write(model_tuned.summary())
996
+
997
+ X_train_tuned[date_col] = X_train[date_col]
998
+ X_train_tuned[target_col] = y_train
999
+ X_test_tuned[date_col] = X_test[date_col]
1000
+ X_test_tuned[target_col] = y_test
1001
+
1002
+ st.header("2.2 Actual vs. Predicted Plot (Train)")
1003
+ if is_panel:
1004
+ metrics_table, line, actual_vs_predicted_plot = (
1005
+ plot_actual_vs_predicted(
1006
+ X_train_tuned[date_col],
1007
+ X_train_tuned[target_col],
1008
+ model_tuned.fittedvalues,
1009
+ model_tuned,
1010
+ target_column=sel_target_col,
1011
+ is_panel=True,
1012
+ )
1013
+ )
1014
+ else:
1015
+
1016
+ metrics_table, line, actual_vs_predicted_plot = (
1017
+ plot_actual_vs_predicted(
1018
+ X_train_tuned[date_col],
1019
+ X_train_tuned[target_col],
1020
+ model_tuned.predict(X_train_tuned[new_features]),
1021
+ model_tuned,
1022
+ target_column=sel_target_col,
1023
+ is_panel=False,
1024
+ )
1025
+ )
1026
+
1027
+ st.plotly_chart(actual_vs_predicted_plot, use_container_width=True)
1028
+
1029
+ st.markdown("## 2.3 Residual Analysis (Train)")
1030
+ if is_panel:
1031
+ columns = st.columns(2)
1032
+ with columns[0]:
1033
+ fig = plot_residual_predicted(
1034
+ y_train, model_tuned.fittedvalues, X_train_tuned
1035
+ )
1036
+ st.plotly_chart(fig)
1037
+
1038
+ with columns[1]:
1039
+ st.empty()
1040
+ fig = qqplot(y_train, model_tuned.fittedvalues)
1041
+ st.plotly_chart(fig)
1042
+
1043
+ with columns[0]:
1044
+ fig = residual_distribution(y_train, model_tuned.fittedvalues)
1045
+ st.pyplot(fig)
1046
+ else:
1047
+ columns = st.columns(2)
1048
+ with columns[0]:
1049
+ fig = plot_residual_predicted(
1050
+ y_train,
1051
+ model_tuned.predict(X_train_tuned[new_features]),
1052
+ X_train,
1053
+ )
1054
+ st.plotly_chart(fig)
1055
+
1056
+ with columns[1]:
1057
+ st.empty()
1058
+ fig = qqplot(
1059
+ y_train, model_tuned.predict(X_train_tuned[new_features])
1060
+ )
1061
+ st.plotly_chart(fig)
1062
+
1063
+ with columns[0]:
1064
+ fig = residual_distribution(
1065
+ y_train, model_tuned.predict(X_train_tuned[new_features])
1066
+ )
1067
+ st.pyplot(fig)
1068
+
1069
+ # st.session_state['is_tuned_model'][target_col] = True
1070
+ # Save tuned model in a dict
1071
+ st.session_state["Model_Tuned"][sel_model + "__" + target_col] = {
1072
+ "Model_object": model_tuned,
1073
+ "feature_set": new_features,
1074
+ "X_train_tuned": X_train_tuned,
1075
+ "X_test_tuned": X_test_tuned,
1076
+ }
1077
+
1078
+ with st.expander("Results Summary Test data"):
1079
+ if is_panel:
1080
+ random_eff_df = get_random_effects(
1081
+ st.session_state.media_data.copy(), panel_col, model_tuned
1082
+ )
1083
+ test_pred = mdf_predict(
1084
+ X_test_tuned, model_tuned, random_eff_df
1085
+ )
1086
+ else:
1087
+ test_pred = model_tuned.predict(X_test_tuned[new_features])
1088
+ st.header("2.2 Actual vs. Predicted Plot (Test)")
1089
+
1090
+ metrics_table, line, actual_vs_predicted_plot = (
1091
+ plot_actual_vs_predicted(
1092
+ X_test_tuned[date_col],
1093
+ y_test,
1094
+ test_pred,
1095
+ model_tuned,
1096
+ target_column=sel_target_col,
1097
+ is_panel=is_panel,
1098
+ )
1099
+ )
1100
+ st.plotly_chart(actual_vs_predicted_plot, use_container_width=True)
1101
+ st.markdown("## 2.3 Residual Analysis (Test)")
1102
+
1103
+ columns = st.columns(2)
1104
+ with columns[0]:
1105
+ fig = plot_residual_predicted(y_test, test_pred, X_test_tuned)
1106
+ st.plotly_chart(fig)
1107
+
1108
+ with columns[1]:
1109
+ st.empty()
1110
+ fig = qqplot(y_test, test_pred)
1111
+ st.plotly_chart(fig)
1112
+
1113
+ with columns[0]:
1114
+ fig = residual_distribution(y_test, test_pred)
1115
+ st.pyplot(fig)
1116
+ except:
1117
+ # Capture the error details
1118
+ exc_type, exc_value, exc_traceback = sys.exc_info()
1119
+ error_message = "".join(
1120
+ traceback.format_exception(exc_type, exc_value, exc_traceback)
1121
+ )
1122
+ log_message(
1123
+ "error",
1124
+ f"Error while building tuned model: {error_message}",
1125
+ "Model Tuning",
1126
+ )
1127
+ st.warning("An error occured, please try again", icon="⚠️")
1128
+
1129
+ if (
1130
+ st.session_state["Model_Tuned"] is not None
1131
+ and len(list(st.session_state["Model_Tuned"].keys())) > 0
1132
+ ):
1133
+ if st.button("Use This model for Media Planning", use_container_width=True):
1134
+
1135
+ # remove previous tuned models saved for this target col
1136
+ _remove = [
1137
+ m
1138
+ for m in st.session_state["Model_Tuned"].keys()
1139
+ if m.split("__")[1] == target_col and m.split("__")[0] != sel_model
1140
+ ]
1141
+ if len(_remove) > 0:
1142
+ for m in _remove:
1143
+ del st.session_state["Model_Tuned"][m]
1144
+
1145
+ # Flag depicting tuned model for selected response metric
1146
+ st.session_state["is_tuned_model"][target_col] = True
1147
+
1148
+ tuned_model_pkl = pickle.dumps(st.session_state["Model_Tuned"])
1149
+
1150
+ update_db(
1151
+ st.session_state["project_number"],
1152
+ "Model_Tuning",
1153
+ "tuned_model",
1154
+ tuned_model_pkl,
1155
+ schema,
1156
+ # resp_mtrc=None,
1157
+ ) # db
1158
+
1159
+ log_message(
1160
+ "info",
1161
+ f"Tuned model {' '.join(_remove)} removed due to overwrite",
1162
+ "Model Tuning",
1163
+ )
1164
+
1165
+ # Save session state variables (persistence)
1166
+ st.session_state["project_dct"]["model_tuning"][
1167
+ "session_state_saved"
1168
+ ] = {}
1169
+ for key in [
1170
+ "bin_dict",
1171
+ "used_response_metrics",
1172
+ "is_tuned_model",
1173
+ "media_data",
1174
+ "X_test_spends",
1175
+ "spends_data",
1176
+ ]:
1177
+ st.session_state["project_dct"]["model_tuning"][
1178
+ "session_state_saved"
1179
+ ][key] = st.session_state[key]
1180
+
1181
+ project_dct_pkl = pickle.dumps(st.session_state["project_dct"])
1182
+
1183
+ update_db(
1184
+ st.session_state["project_number"],
1185
+ "Model_Tuning",
1186
+ "project_dct",
1187
+ project_dct_pkl,
1188
+ schema,
1189
+ # resp_mtrc=None,
1190
+ ) # db
1191
+
1192
+ log_message(
1193
+ "info",
1194
+ f'Tuned Model {sel_model + "__" + target_col} Saved',
1195
+ "Model Tuning",
1196
+ )
1197
+
1198
+ # Clear page metadata
1199
+ st.session_state["project_dct"]["scenario_planner"][
1200
+ "modified_metadata_file"
1201
+ ] = None
1202
+ st.session_state["project_dct"]["response_curves"][
1203
+ "modified_metadata_file"
1204
+ ] = None
1205
+
1206
+ st.success(sel_model + " for " + target_col + " Tuned saved!")
1207
+ except:
1208
+ # Capture the error details
1209
+ exc_type, exc_value, exc_traceback = sys.exc_info()
1210
+ error_message = "".join(
1211
+ traceback.format_exception(exc_type, exc_value, exc_traceback)
1212
+ )
1213
+ log_message("error", f"An error has occured : {error_message}", "Model Tuning")
1214
+ st.warning("An error occured, please try again", icon="⚠️")
1215
+ # st.write(error_message)
pages/6_AI_Model_Validation.py ADDED
@@ -0,0 +1,960 @@
1
+ import plotly.express as px
2
+ import numpy as np
3
+ import plotly.graph_objects as go
4
+ import streamlit as st
5
+ import pandas as pd
6
+ import statsmodels.api as sm
7
+
8
+ # from sklearn.metrics import mean_absolute_percentage_error
9
+ import sys
10
+ import os
11
+ from utilities import set_header, load_local_css
12
+ import seaborn as sns
13
+ import matplotlib.pyplot as plt
14
+ import tempfile
15
+ from sklearn.preprocessing import MinMaxScaler
16
+
17
+ # from st_aggrid import AgGrid
18
+ # from st_aggrid import GridOptionsBuilder, GridUpdateMode
19
+ # from st_aggrid import GridOptionsBuilder
20
+ import sys
21
+ import re
22
+ import pickle
23
+ from sklearn.metrics import r2_score
24
+ from data_prep import plot_actual_vs_predicted
25
+ import sqlite3
26
+ from utilities import (
27
+ set_header,
28
+ load_local_css,
29
+ update_db,
30
+ project_selection,
31
+ retrieve_pkl_object,
32
+ )
33
+ from post_gres_cred import db_cred
34
+ from log_application import log_message
35
+ import sys, traceback
36
+
37
+ schema = db_cred["schema"]
38
+
39
+ sys.setrecursionlimit(10**6)
40
+
41
+ original_stdout = sys.stdout
42
+ sys.stdout = open("temp_stdout.txt", "w")
43
+ sys.stdout.close()
44
+ sys.stdout = original_stdout
45
+
46
+ st.set_page_config(layout="wide")
47
+ load_local_css("styles.css")
48
+ set_header()
49
+
50
+
51
+ ## DEFINE ALL FUNCTIONS
52
+ def plot_residual_predicted(actual, predicted, df_):
53
+ df_["Residuals"] = actual - pd.Series(predicted)
54
+ df_["StdResidual"] = (df_["Residuals"] - df_["Residuals"].mean()) / df_[
55
+ "Residuals"
56
+ ].std()
57
+
58
+ # Create a Plotly scatter plot
59
+ fig = px.scatter(
60
+ df_,
61
+ x=predicted,
62
+ y="StdResidual",
63
+ opacity=0.5,
64
+ color_discrete_sequence=["#11B6BD"],
65
+ )
66
+
67
+ # Add horizontal lines
68
+ fig.add_hline(y=0, line_dash="dash", line_color="darkorange")
69
+ fig.add_hline(y=2, line_color="red")
70
+ fig.add_hline(y=-2, line_color="red")
71
+
72
+ fig.update_xaxes(title="Predicted")
73
+ fig.update_yaxes(title="Standardized Residuals (Actual - Predicted)")
74
+
75
+ # Set the same width and height for both figures
76
+ fig.update_layout(
77
+ title="Residuals over Predicted Values",
78
+ autosize=False,
79
+ width=600,
80
+ height=400,
81
+ )
82
+
83
+ return fig
84
+
85
+
86
+ def residual_distribution(actual, predicted):
87
+ Residuals = actual - pd.Series(predicted)
88
+
89
+ # Create a Seaborn distribution plot
90
+ sns.set(style="whitegrid")
91
+ plt.figure(figsize=(6, 4))
92
+ sns.histplot(Residuals, kde=True, color="#11B6BD")
93
+
94
+ plt.title(" Distribution of Residuals")
95
+ plt.xlabel("Residuals")
96
+ plt.ylabel("Probability Density")
97
+
98
+ return plt
99
+
100
+
101
+ def qqplot(actual, predicted):
102
+ Residuals = actual - pd.Series(predicted)
103
+ Residuals = pd.Series(Residuals)
104
+ Resud_std = (Residuals - Residuals.mean()) / Residuals.std()
105
+
106
+ # Create a QQ plot using Plotly with custom colors
107
+ fig = go.Figure()
108
+ fig.add_trace(
109
+ go.Scatter(
110
+ x=sm.ProbPlot(Resud_std).theoretical_quantiles,
111
+ y=sm.ProbPlot(Resud_std).sample_quantiles,
112
+ mode="markers",
113
+ marker=dict(size=5, color="#11B6BD"),
114
+ name="QQ Plot",
115
+ )
116
+ )
117
+
118
+ # Add the 45-degree reference line
119
+ diagonal_line = go.Scatter(
120
+ x=[
121
+ -2,
122
+ 2,
123
+ ], # Adjust the x values as needed to fit the range of your data
124
+ y=[-2, 2], # Adjust the y values accordingly
125
+ mode="lines",
126
+ line=dict(color="red"), # Customize the line color and style
127
+ name=" ",
128
+ )
129
+ fig.add_trace(diagonal_line)
130
+
131
+ # Customize the layout
132
+ fig.update_layout(
133
+ title="QQ Plot of Residuals",
134
+ title_x=0.5,
135
+ autosize=False,
136
+ width=600,
137
+ height=400,
138
+ xaxis_title="Theoretical Quantiles",
139
+ yaxis_title="Sample Quantiles",
140
+ )
141
+
142
+ return fig
143
+
144
+
145
+ def get_random_effects(media_data, panel_col, mdf):
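+ # Collect the fitted MixedLM's random intercept for every panel/market into a
+ # lookup table, used later to add the panel effect back onto predictions.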
146
+ random_eff_df = pd.DataFrame(columns=[panel_col, "random_effect"])
147
+ for i, market in enumerate(media_data[panel_col].unique()):
148
+ print(i, end="\r")
149
+ intercept = mdf.random_effects[market].values[0]
150
+ random_eff_df.loc[i, "random_effect"] = intercept
151
+ random_eff_df.loc[i, panel_col] = market
152
+
153
+ return random_eff_df
154
+
155
+
156
+ def mdf_predict(X_df, mdf, random_eff_df):
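+ # MixedLM's .predict() returns the fixed-effects component only, so the
+ # panel-specific random intercept is merged in and added back to obtain
+ # the full prediction.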
157
+ X = X_df.copy()
158
+ X = pd.merge(
159
+ X,
160
+ random_eff_df[[panel_col, "random_effect"]],
161
+ on=panel_col,
162
+ how="left",
163
+ )
164
+ X["pred_fixed_effect"] = mdf.predict(X)
165
+
166
+ X["pred"] = X["pred_fixed_effect"] + X["random_effect"]
167
+ X.drop(columns=["pred_fixed_effect", "random_effect"], inplace=True)
168
+ return X
169
+
170
+
171
+ def metrics_df_panel(model_dict, is_panel):
172
+ def wmape(actual, forecast):
173
+ # Weighted MAPE (WMAPE) avoids two shortcomings of MAPE & SMAPE:
174
+ ## 1. MAPE blows up when the actual value is close to 0
175
+ ## 2. MAPE favours under-forecasting over over-forecasting
176
+ return np.sum(np.abs(actual - forecast)) / np.sum(np.abs(actual))
177
+
178
+ metrics_df = pd.DataFrame(
179
+ columns=[
180
+ "Model",
181
+ "R2",
182
+ "ADJR2",
183
+ "Train Mape",
184
+ "Test Mape",
185
+ "Summary",
186
+ "Model_object",
187
+ ]
188
+ )
189
+ i = 0
190
+ for key in model_dict.keys():
191
+ target = key.split("__")[1]
192
+ metrics_df.at[i, "Model"] = target
193
+ y = model_dict[key]["X_train_tuned"][target]
194
+
195
+ feature_set = model_dict[key]["feature_set"]
196
+
197
+ if is_panel:
198
+ random_df = get_random_effects(
199
+ media_data, panel_col, model_dict[key]["Model_object"]
200
+ )
201
+ pred = mdf_predict(
202
+ model_dict[key]["X_train_tuned"],
203
+ model_dict[key]["Model_object"],
204
+ random_df,
205
+ )["pred"]
206
+ else:
207
+ pred = model_dict[key]["Model_object"].predict(
208
+ model_dict[key]["X_train_tuned"][feature_set]
209
+ )
210
+
211
+ ytest = model_dict[key]["X_test_tuned"][target]
212
+ if is_panel:
213
+
214
+ predtest = mdf_predict(
215
+ model_dict[key]["X_test_tuned"],
216
+ model_dict[key]["Model_object"],
217
+ random_df,
218
+ )["pred"]
219
+
220
+ else:
221
+ predtest = model_dict[key]["Model_object"].predict(
222
+ model_dict[key]["X_test_tuned"][feature_set]
223
+ )
224
+
225
+ metrics_df.at[i, "R2"] = r2_score(y, pred)
226
+ metrics_df.at[i, "ADJR2"] = 1 - (1 - metrics_df.loc[i, "R2"]) * (len(y) - 1) / (
227
+ len(y) - len(model_dict[key]["feature_set"]) - 1
228
+ )
229
+ # metrics_df.at[i, "Train Mape"] = mean_absolute_percentage_error(y, pred)
230
+ # metrics_df.at[i, "Test Mape"] = mean_absolute_percentage_error(
231
+ # ytest, predtest
232
+ # )
233
+ metrics_df.at[i, "Train Mape"] = wmape(y, pred)
234
+ metrics_df.at[i, "Test Mape"] = wmape(ytest, predtest)
235
+ metrics_df.at[i, "Summary"] = model_dict[key]["Model_object"].summary()
236
+ metrics_df.at[i, "Model_object"] = model_dict[key]["Model_object"]
237
+ i += 1
238
+ metrics_df = np.round(metrics_df, 2)
239
+
240
+ metrics_df.rename(
241
+ columns={"R2": "R-squared", "ADJR2": "Adj. R-squared"}, inplace=True
242
+ )
243
+ return metrics_df
244
+
245
+
246
+ def map_channel(transformed_var, channel_dict):
247
+ for key, value_list in channel_dict.items():
248
+ if any(raw_var in transformed_var for raw_var in value_list):
249
+ return key
250
+ return transformed_var # Return the original value if no match is found
251
+
252
+
253
+ def contributions_nonpanel(model_dict):
254
+ # with open(os.path.join(st.session_state["project_path"], "channel_groups.pkl"), "rb") as f:
255
+ # channels = pickle.load(f)
256
+
257
+ channels = st.session_state["project_dct"]["data_import"]["group_dict"] # db
258
+ media_data = st.session_state["media_data"]
259
+ contribution_df = pd.DataFrame(columns=["Channel"])
260
+
261
+ for key in model_dict.keys():
262
+
263
+ best_feature_set = model_dict[key]["feature_set"]
264
+ model = model_dict[key]["Model_object"]
265
+ target = key.split("__")[1]
266
+ X_train = model_dict[key]["X_train_tuned"]
267
+ contri_df = pd.DataFrame()
268
+ y = []
269
+ y_pred = []
270
+
271
+ coef_df = pd.DataFrame(model.params)
272
+ coef_df.reset_index(inplace=True)
273
+ coef_df.columns = ["feature", "coef"]
274
+ x_train_contribution = X_train.copy()
275
+ x_train_contribution["pred"] = model.predict(X_train[best_feature_set])
276
+
277
+ for i in range(len(coef_df)):
278
+
279
+ coef = coef_df.loc[i, "coef"]
280
+ col = coef_df.loc[i, "feature"]
281
+ if col != "const":
282
+ x_train_contribution[str(col) + "_contr"] = (
283
+ coef * x_train_contribution[col]
284
+ )
285
+ else:
286
+ x_train_contribution["const"] = coef
287
+
288
+ tuning_cols = [
289
+ c
290
+ for c in x_train_contribution.filter(regex="contr").columns
291
+ if c
292
+ in [
293
+ "day_of_week_contr",
294
+ "Trend_contr",
295
+ "sine_wave_contr",
296
+ "cosine_wave_contr",
297
+ ]
298
+ ]
299
+ flag_cols = [
300
+ c
301
+ for c in x_train_contribution.filter(regex="contr").columns
302
+ if "_flag" in c
303
+ ]
304
+
305
+ # add exogenous contribution to base
306
+ all_exog_vars = st.session_state["bin_dict"]["Exogenous"]
307
+ all_exog_vars = [
308
+ var.lower()
309
+ .replace(".", "_")
310
+ .replace("@", "_")
311
+ .replace(" ", "_")
312
+ .replace("-", "")
313
+ .replace(":", "")
314
+ .replace("__", "_")
315
+ for var in all_exog_vars
316
+ ]
317
+ exog_cols = []
318
+ if len(all_exog_vars) > 0:
319
+ for col in x_train_contribution.filter(regex="contr").columns:
320
+ if len([exog_var for exog_var in all_exog_vars if exog_var in col]) > 0:
321
+ exog_cols.append(col)
322
+
323
+ base_cols = ["const"] + flag_cols + tuning_cols + exog_cols
324
+
325
+ x_train_contribution["base_contr"] = x_train_contribution[base_cols].sum(axis=1)
326
+ x_train_contribution.drop(columns=base_cols, inplace=True)
327
+
328
+ contri_df = pd.DataFrame(x_train_contribution.filter(regex="contr").sum(axis=0))
329
+
330
+ contri_df.reset_index(inplace=True)
331
+ contri_df.columns = ["Channel", target]
332
+ contri_df["Channel"] = contri_df["Channel"].apply(
333
+ lambda x: map_channel(x, channels)
334
+ )
335
+ contri_df[target] = 100 * contri_df[target] / contri_df[target].sum()
336
+ contri_df["Channel"].replace("base_contr", "base", inplace=True)
337
+ contribution_df = pd.merge(
338
+ contribution_df, contri_df, on="Channel", how="outer"
339
+ )
340
+
341
+ return contribution_df
342
+
343
+
344
+ def contributions_panel(model_dict):
345
+ channels = st.session_state["project_dct"]["data_import"]["group_dict"] # db
346
+ media_data = st.session_state["media_data"]
347
+ contribution_df = pd.DataFrame(columns=["Channel"])
348
+ for key in model_dict.keys():
349
+ best_feature_set = model_dict[key]["feature_set"]
350
+ model = model_dict[key]["Model_object"]
351
+ target = key.split("__")[1]
352
+ X_train = model_dict[key]["X_train_tuned"]
353
+ contri_df = pd.DataFrame()
354
+
355
+ y = []
356
+ y_pred = []
357
+
358
+ random_eff_df = get_random_effects(media_data, panel_col, model)
359
+ random_eff_df["fixed_effect"] = model.fe_params["Intercept"]
360
+ random_eff_df["panel_effect"] = (
361
+ random_eff_df["random_effect"] + random_eff_df["fixed_effect"]
362
+ )
363
+
364
+ coef_df = pd.DataFrame(model.fe_params)
365
+ coef_df.reset_index(inplace=True)
366
+ coef_df.columns = ["feature", "coef"]
367
+
368
+ x_train_contribution = X_train.copy()
369
+ x_train_contribution = mdf_predict(x_train_contribution, model, random_eff_df)
370
+
371
+ x_train_contribution = pd.merge(
372
+ x_train_contribution,
373
+ random_eff_df[[panel_col, "panel_effect"]],
374
+ on=panel_col,
375
+ how="left",
376
+ )
377
+ for i in range(len(coef_df)):
378
+ coef = coef_df.loc[i, "coef"]
379
+ col = coef_df.loc[i, "feature"]
380
+ if col.lower() != "intercept":
381
+ x_train_contribution[str(col) + "_contr"] = (
382
+ coef * x_train_contribution[col]
383
+ )
384
+
385
+ # x_train_contribution['sum_contributions'] = x_train_contribution.filter(regex="contr").sum(axis=1)
386
+ # x_train_contribution['sum_contributions'] = x_train_contribution['sum_contributions'] + x_train_contribution[
387
+ # 'panel_effect']
388
+
389
+ # base_cols = ["panel_effect"] + [
390
+ # c
391
+ # for c in x_train_contribution.filter(regex="contr").columns
392
+ # if c
393
+ # in [
394
+ # "day_of_week_contr",
395
+ # "Trend_contr",
396
+ # "sine_wave_contr",
397
+ # "cosine_wave_contr",
398
+ # ]
399
+ # ]
400
+ tuning_cols = [
401
+ c
402
+ for c in x_train_contribution.filter(regex="contr").columns
403
+ if c
404
+ in [
405
+ "day_of_week_contr",
406
+ "Trend_contr",
407
+ "sine_wave_contr",
408
+ "cosine_wave_contr",
409
+ ]
410
+ ]
411
+ flag_cols = [
412
+ c
413
+ for c in x_train_contribution.filter(regex="contr").columns
414
+ if "_flag" in c
415
+ ]
416
+
417
+ # add exogenous contribution to base
418
+ all_exog_vars = st.session_state["bin_dict"]["Exogenous"]
419
+ all_exog_vars = [
420
+ var.lower()
421
+ .replace(".", "_")
422
+ .replace("@", "_")
423
+ .replace(" ", "_")
424
+ .replace("-", "")
425
+ .replace(":", "")
426
+ .replace("__", "_")
427
+ for var in all_exog_vars
428
+ ]
429
+ exog_cols = []
430
+ if len(all_exog_vars) > 0:
431
+ for col in x_train_contribution.filter(regex="contr").columns:
432
+ if len([exog_var for exog_var in all_exog_vars if exog_var in col]) > 0:
433
+ exog_cols.append(col)
434
+
435
+ base_cols = ["panel_effect"] + flag_cols + tuning_cols + exog_cols
436
+
437
+ x_train_contribution["base_contr"] = x_train_contribution[base_cols].sum(axis=1)
438
+ x_train_contribution.drop(columns=base_cols, inplace=True)
439
+
440
+ contri_df = pd.DataFrame(x_train_contribution.filter(regex="contr").sum(axis=0))
441
+ contri_df.reset_index(inplace=True)
442
+ contri_df.columns = ["Channel", target]
443
+
444
+ contri_df[target] = 100 * contri_df[target] / contri_df[target].sum()
445
+ contri_df["Channel"] = contri_df["Channel"].apply(
446
+ lambda x: map_channel(x, channels)
447
+ )
448
+
449
+ contri_df["Channel"].replace("base_contr", "base", inplace=True)
450
+ contribution_df = pd.merge(
451
+ contribution_df, contri_df, on="Channel", how="outer"
452
+ )
453
+ # st.session_state["contribution_df"] = contributions_panel(tuned_model_dict)
454
+ return contribution_df
455
+
456
+
457
+ def create_grouped_bar_plot(contribution_df, contribution_selections):
458
+ # Extract the 'Channel' names
459
+ channel_names = contribution_df["Channel"].tolist()
460
+
461
+ # Dictionary to store all contributions except 'const' and 'base'
462
+ all_contributions = {
463
+ name: [] for name in channel_names if name not in ["const", "base"]
464
+ }
465
+
466
+ # Dictionary to store base sales for each selection
467
+ base_sales_dict = {}
468
+
469
+ # Accumulate contributions for each channel from each selection
470
+ for selection in contribution_selections:
471
+ contributions = contribution_df[selection].values.astype(float)
472
+ base_sales = 0 # Initialize base sales for the current selection
473
+
474
+ for channel_name, contribution in zip(channel_names, contributions):
475
+ if channel_name in all_contributions:
476
+ all_contributions[channel_name].append(contribution)
477
+ elif channel_name == "base":
478
+ base_sales = (
479
+ contribution # Capture base sales for the current selection
480
+ )
481
+
482
+ # Store base sales for each selection
483
+ base_sales_dict[selection] = base_sales
484
+
485
+ # Calculate the average of contributions and sort by this average
486
+ sorted_channels = sorted(all_contributions.items(), key=lambda x: -np.mean(x[1]))
487
+ sorted_channel_names = [name for name, _ in sorted_channels]
488
+ sorted_channel_names = [
489
+ "Base Sales"
490
+ ] + sorted_channel_names # Adding 'Base Sales' at the start
491
+
492
+ trace_data = []
493
+ max_value = 0 # Initialize max_value to find the highest bar for y-axis adjustment
494
+
495
+ # Create traces for the grouped bar chart
496
+ for i, selection in enumerate(contribution_selections):
497
+ display_name = sorted_channel_names
498
+ display_contribution = [base_sales_dict[selection]] + [
499
+ all_contributions[name][i] for name in sorted_channel_names[1:]
500
+ ] # Start with base sales for the current selection
501
+
502
+ # Generating text labels for each bar
503
+ text_values = [
504
+ f"{val}%" for val in np.round(display_contribution, 0).astype(int)
505
+ ]
506
+
507
+ # Find the max value for y-axis calculation
508
+ max_contribution = max(display_contribution)
509
+ if max_contribution > max_value:
510
+ max_value = max_contribution
511
+
512
+ # Create a bar trace for each selection
513
+ trace = go.Bar(
514
+ x=display_name,
515
+ y=display_contribution,
516
+ name=selection,
517
+ text=text_values,
518
+ textposition="outside",
519
+ )
520
+ trace_data.append(trace)
521
+
522
+ # Define layout for the bar chart
523
+ layout = go.Layout(
524
+ title="Metrics Contribution by Channel (Train)",
525
+ xaxis=dict(title="Channel Name"),
526
+ yaxis=dict(
527
+ title="Metrics Contribution", range=[0, max_value * 1.2]
528
+ ), # Set y-axis 20% higher than the max bar
529
+ barmode="group",
530
+ plot_bgcolor="white",
531
+ )
532
+
533
+ # Create the figure with trace data and layout
534
+ fig = go.Figure(data=trace_data, layout=layout)
535
+
536
+ return fig
537
+
538
+
539
+ def preprocess_and_plot(contribution_df, contribution_selections):
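+ # Same aggregation as create_grouped_bar_plot, but rendered as a waterfall chart
+ # so channel contributions accumulate from Base Sales towards the total.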
540
+ # Extract the 'Channel' names
541
+ channel_names = contribution_df["Channel"].tolist()
542
+
543
+ # Dictionary to store all contributions except 'const' and 'base'
544
+ all_contributions = {
545
+ name: [] for name in channel_names if name not in ["const", "base"]
546
+ }
547
+
548
+ # Dictionary to store base sales for each selection
549
+ base_sales_dict = {}
550
+
551
+ # Accumulate contributions for each channel from each selection
552
+ for selection in contribution_selections:
553
+ contributions = contribution_df[selection].values.astype(float)
554
+ base_sales = 0 # Initialize base sales for the current selection
555
+
556
+ for channel_name, contribution in zip(channel_names, contributions):
557
+ if channel_name in all_contributions:
558
+ all_contributions[channel_name].append(contribution)
559
+ elif channel_name == "base":
560
+ base_sales = (
561
+ contribution # Capture base sales for the current selection
562
+ )
563
+
564
+ # Store base sales for each selection
565
+ base_sales_dict[selection] = base_sales
566
+
567
+ # Calculate the average of contributions and sort by this average
568
+ sorted_channels = sorted(all_contributions.items(), key=lambda x: -np.mean(x[1]))
569
+ sorted_channel_names = [name for name, _ in sorted_channels]
570
+ sorted_channel_names = [
571
+ "Base Sales"
572
+ ] + sorted_channel_names # Adding 'Base Sales' at the start
573
+
574
+ # Initialize a Plotly figure
575
+ fig = go.Figure()
576
+
577
+ for i, selection in enumerate(contribution_selections):
578
+ display_name = ["Base Sales"] + sorted_channel_names[
579
+ 1:
580
+ ] # Channel names for the plot
581
+ display_contribution = [
582
+ base_sales_dict[selection]
583
+ ] # Start with base sales for the current selection
584
+
585
+ # Append average contributions for other channels
586
+ for name in sorted_channel_names[1:]:
587
+ display_contribution.append(all_contributions[name][i])
588
+
589
+ # Generating text labels for each bar
590
+ text_values = [
591
+ f"{val}%" for val in np.round(display_contribution, 0).astype(int)
592
+ ]
593
+
594
+ # Add a waterfall trace for each selection
595
+ fig.add_trace(
596
+ go.Waterfall(
597
+ orientation="v",
598
+ measure=["relative"] * len(display_contribution),
599
+ x=display_name,
600
+ text=text_values,
601
+ textposition="outside",
602
+ y=display_contribution,
603
+ increasing={"marker": {"color": "green"}},
604
+ decreasing={"marker": {"color": "red"}},
605
+ totals={"marker": {"color": "blue"}},
606
+ name=selection,
607
+ )
608
+ )
609
+
610
+ # Update layout of the figure
611
+ fig.update_layout(
612
+ title="Metrics Contribution by Channel (Train)",
613
+ xaxis={"title": "Channel Name"},
614
+ yaxis=dict(title="Metrics Contribution", range=[0, 100 * 1.2]),
615
+ )
616
+
617
+ return fig
618
+
619
+
620
+ def selection_change():
621
+ edited_rows: dict = st.session_state.project_selection["edited_rows"]
622
+ st.session_state["selected_row_index_gd_table"] = next(iter(edited_rows))
623
+ st.session_state["gd_table"] = st.session_state["gd_table"].assign(selected=False)
624
+
625
+ update_dict = {idx: values for idx, values in edited_rows.items()}
626
+
627
+ st.session_state["gd_table"].update(
628
+ pd.DataFrame.from_dict(update_dict, orient="index")
629
+ )
630
+
631
+
632
+ if "username" not in st.session_state:
633
+ st.session_state["username"] = None
634
+
635
+ if "project_name" not in st.session_state:
636
+ st.session_state["project_name"] = None
637
+
638
+ if "project_dct" not in st.session_state:
639
+ project_selection()
640
+ st.stop()
641
+
642
+ try:
643
+ st.session_state["bin_dict"] = st.session_state["project_dct"]["data_import"][
644
+ "category_dict"
645
+ ] # db
646
+
647
+ except Exception as e:
648
+ st.warning("Save atleast one tuned model to proceed")
649
+ log_message("warning", "No tuned models available", "AI Model Results")
650
+ st.stop()
651
+
652
+
653
+ if "gd_table" not in st.session_state:
654
+ st.session_state["gd_table"] = pd.DataFrame()
655
+
656
+ try:
657
+ if "username" in st.session_state and st.session_state["username"] is not None:
658
+
659
+ if (
660
+ retrieve_pkl_object(
661
+ st.session_state["project_number"],
662
+ "Model_Tuning",
663
+ "tuned_model",
664
+ schema,
665
+ )
666
+ is None
667
+ ):
668
+
669
+ st.error("Please save a tuned model")
670
+ st.stop()
671
+
672
+ if (
673
+ "session_state_saved"
674
+ in st.session_state["project_dct"]["model_tuning"].keys()
675
+ and st.session_state["project_dct"]["model_tuning"]["session_state_saved"]
676
+ != []
677
+ ):
678
+ for key in ["used_response_metrics", "media_data", "bin_dict"]:
679
+ if key not in st.session_state:
680
+ st.session_state[key] = st.session_state["project_dct"][
681
+ "model_tuning"
682
+ ]["session_state_saved"][key]
683
+ # st.session_state["bin_dict"] = st.session_state["project_dct"][
684
+ # "model_build"
685
+ # ]["session_state_saved"]["bin_dict"]
686
+
687
+ media_data = st.session_state["media_data"]
688
+
689
+ # st.write(media_data.columns)
690
+
691
+ # set the panel column
692
+ panel_col = "panel"
693
+ is_panel = (
694
+ True if st.session_state["media_data"][panel_col].nunique() > 1 else False
695
+ )
696
+ # st.write(is_panel)
697
+
698
+ date_col = "date"
699
+
700
+ transformed_data = st.session_state["project_dct"]["transformations"][
701
+ "final_df"
702
+ ] # db
703
+ tuned_model_dict = retrieve_pkl_object(
704
+ st.session_state["project_number"], "Model_Tuning", "tuned_model", schema
705
+ ) # db
706
+
707
+ feature_set_dct = {
708
+ key.split("__")[1]: key_dict["feature_set"]
709
+ for key, key_dict in tuned_model_dict.items()
710
+ }
711
+
712
+ # """ the above part should be modified so that we are fetching features set from the saved model"""
713
+
714
+ if "contribution_df" not in st.session_state:
715
+ st.session_state["contribution_df"] = None
716
+
717
+ metrics_table = metrics_df_panel(tuned_model_dict, is_panel)
718
+
719
+ cols1 = st.columns([2, 1])
720
+ with cols1[0]:
721
+ st.markdown(f"**Welcome {st.session_state['username']}**")
722
+ with cols1[1]:
723
+ st.markdown(f"**Current Project: {st.session_state['project_name']}**")
724
+
725
+ st.title("AI Model Validation")
726
+
727
+ st.header("Contribution Overview")
728
+
729
+ # Get list of response metrics
730
+ st.session_state["used_response_metrics"] = list(
731
+ set([model.split("__")[1] for model in tuned_model_dict.keys()])
732
+ )
733
+ options = st.session_state["used_response_metrics"]
734
+
735
+ if len(options) == 0:
736
+ st.error("Please save and tune a model")
737
+ st.stop()
738
+ options = [
739
+ opt.lower()
740
+ .replace(" ", "_")
741
+ .replace("-", "")
742
+ .replace(":", "")
743
+ .replace("__", "_")
744
+ for opt in options
745
+ ]
746
+
747
+ default_options = (
748
+ st.session_state["project_dct"]["saved_model_results"].get(
749
+ "selected_options"
750
+ )
751
+ if st.session_state["project_dct"]["saved_model_results"].get(
752
+ "selected_options"
753
+ )
754
+ is not None
755
+ else [options[-1]]
756
+ )
757
+ for i in default_options:
758
+ if i not in options:
759
+ # st.write(i)
760
+ default_options.remove(i)
761
+
762
+ def remove_response_metric(name):
763
+ # Convert the name to a lowercase string and remove any leading or trailing spaces
764
+ name_str = str(name).lower().strip()
765
+
766
+ # Check if the name starts with "response metric" or "response_metric"
767
+ if name_str.startswith("response metric"):
768
+ return name[len("response metric") :].replace("_", " ").strip().title()
769
+ elif name_str.startswith("response_metric"):
770
+ return name[len("response_metric") :].replace("_", " ").strip().title()
771
+ else:
772
+ return name
773
+
774
+ contribution_selections = st.multiselect(
775
+ "Select the Response Metrics to compare contributions",
776
+ options,
777
+ default=default_options,
778
+ format_func=remove_response_metric,
779
+ )
780
+ trace_data = []
781
+
782
+ if is_panel:
783
+ st.session_state["contribution_df"] = contributions_panel(tuned_model_dict)
784
+
785
+ else:
786
+ st.session_state["contribution_df"] = contributions_nonpanel(
787
+ tuned_model_dict
788
+ )
789
+
790
+ # st.write(st.session_state["contribution_df"].columns)
791
+ # for selection in contribution_selections:
792
+
793
+ # trace = go.Bar(
794
+ # x=st.session_state["contribution_df"]["Channel"],
795
+ # y=st.session_state["contribution_df"][selection],
796
+ # name=selection,
797
+ # text=np.round(st.session_state["contribution_df"][selection], 0)
798
+ # .astype(int)
799
+ # .astype(str)
800
+ # + "%",
801
+ # textposition="outside",
802
+ # )
803
+ # trace_data.append(trace)
804
+
805
+ # layout = go.Layout(
806
+ # title="Metrics Contribution by Channel",
807
+ # xaxis=dict(title="Channel Name"),
808
+ # yaxis=dict(title="Metrics Contribution"),
809
+ # barmode="group",
810
+ # )
811
+ # fig = go.Figure(data=trace_data, layout=layout)
812
+ # st.plotly_chart(fig, use_container_width=True)
813
+
814
+ # Display the chart in Streamlit
815
+ st.plotly_chart(
816
+ create_grouped_bar_plot(
817
+ st.session_state["contribution_df"], contribution_selections
818
+ ),
819
+ use_container_width=True,
820
+ )
821
+
822
+ ############################################ Waterfall Chart ############################################
823
+
824
+ import plotly.graph_objects as go
825
+
826
+ st.plotly_chart(
827
+ preprocess_and_plot(
828
+ st.session_state["contribution_df"], contribution_selections
829
+ ),
830
+ use_container_width=True,
831
+ )
832
+
833
+ ############################################ Waterfall Chart ############################################
834
+ st.header("Analysis of Models Result")
835
+ gd_table = metrics_table.iloc[:, :-2]
836
+ target_column = gd_table.at[0, "Model"] # sprint8
837
+ st.session_state["gd_table"] = gd_table
838
+
839
+ with st.container():
840
+ table = st.data_editor(
841
+ st.session_state["gd_table"],
842
+ hide_index=True,
843
+ # on_change=selection_change,
844
+ key="project_selection",
845
+ use_container_width=True,
846
+ )
847
+
848
+ target_column = st.selectbox(
849
+ "Select a Model to analyse its results",
850
+ options=st.session_state.used_response_metrics,
851
+ placeholder=options[0],
852
+ )
853
+ feature_set = feature_set_dct[target_column]
854
+
855
+ model = metrics_table[metrics_table["Model"] == target_column][
856
+ "Model_object"
857
+ ].iloc[0]
858
+ target = metrics_table[metrics_table["Model"] == target_column]["Model"].iloc[0]
859
+ st.header("Model Summary")
860
+ st.write(model.summary())
861
+
862
+ sel_dict = tuned_model_dict[
863
+ [k for k in tuned_model_dict.keys() if k.split("__")[1] == target][0]
864
+ ]
865
+
866
+ feature_set = sel_dict["feature_set"]
867
+ X_train = sel_dict["X_train_tuned"]
868
+ y_train = X_train[target]
869
+
870
+ if is_panel:
871
+ random_effects = get_random_effects(media_data, panel_col, model)
872
+ pred = mdf_predict(X_train, model, random_effects)["pred"]
873
+ else:
874
+ pred = model.predict(X_train[feature_set])
875
+
876
+ X_test = sel_dict["X_test_tuned"]
877
+ y_test = X_test[target]
878
+ if is_panel:
879
+ predtest = mdf_predict(X_test, model, random_effects)["pred"]
880
+ else:
881
+ predtest = model.predict(X_test[feature_set])
882
+
883
+ metrics_table_train, _, fig_train = plot_actual_vs_predicted(
884
+ X_train[date_col],
885
+ y_train,
886
+ pred,
887
+ model,
888
+ target_column=target,
889
+ flag=None,
890
+ repeat_all_years=False,
891
+ is_panel=is_panel,
892
+ )
893
+
894
+ metrics_table_test, _, fig_test = plot_actual_vs_predicted(
895
+ X_test[date_col],
896
+ y_test,
897
+ predtest,
898
+ model,
899
+ target_column=target,
900
+ flag=None,
901
+ repeat_all_years=False,
902
+ is_panel=is_panel,
903
+ )
904
+
905
+ metrics_table_train = metrics_table_train.set_index("Metric").transpose()
906
+ metrics_table_train.index = ["Train"]
907
+ metrics_table_test = metrics_table_test.set_index("Metric").transpose()
908
+ metrics_table_test.index = ["Test"]
909
+ metrics_table = np.round(
910
+ pd.concat([metrics_table_train, metrics_table_test]), 2
911
+ )
912
+
913
+ st.markdown("Result Overview")
914
+ st.dataframe(np.round(metrics_table, 2), use_container_width=True)
915
+
916
+ st.header("Model Accuracy")
917
+ st.subheader("Actual vs Predicted Plot (Train)")
918
+
919
+ st.plotly_chart(fig_train, use_container_width=True)
920
+ st.subheader("Actual vs Predicted Plot (Test)")
921
+ st.plotly_chart(fig_test, use_container_width=True)
922
+
923
+ st.markdown("## Residual Analysis (Train)")
924
+ columns = st.columns(2)
925
+
926
+ Xtrain1 = X_train.copy()
927
+ with columns[0]:
928
+ fig = plot_residual_predicted(y_train, pred, Xtrain1)
929
+ st.plotly_chart(fig)
930
+
931
+ with columns[1]:
932
+ st.empty()
933
+ fig = qqplot(y_train, pred)
934
+ st.plotly_chart(fig)
935
+
936
+ with columns[0]:
937
+ fig = residual_distribution(y_train, pred)
938
+ st.pyplot(fig)
939
+
940
+ if st.button("Save this session", use_container_width=True):
941
+ project_dct_pkl = pickle.dumps(st.session_state["project_dct"])
942
+
943
+ update_db(
944
+ st.session_state["project_number"],
945
+ "AI_Model_Results",
946
+ "project_dct",
947
+ project_dct_pkl,
948
+ schema,
949
+ # resp_mtrc=None,
950
+ ) # db
951
+
952
+ log_message("info", "Session saved!", "AI Model Results")
953
+ st.success("Session Saved!")
954
+ except:
955
+ exc_type, exc_value, exc_traceback = sys.exc_info()
956
+ error_message = "".join(
957
+ traceback.format_exception(exc_type, exc_value, exc_traceback)
958
+ )
959
+ log_message("error", f"Error: {error_message}", "AI Model Results")
960
+ st.warning("An error occured, please try again", icon="⚠️")
pages/7_AI_Model_Media_Performance.py ADDED
@@ -0,0 +1,733 @@
1
+ import streamlit as st
2
+ import pandas as pd
3
+ from sklearn.preprocessing import MinMaxScaler
4
+ import pickle
5
+ import os
6
+ from utilities_with_panel import load_local_css, set_header
7
+ import yaml
8
+ from yaml import SafeLoader
9
+ import sqlite3
10
+ from datetime import timedelta
11
+ from utilities import (
12
+ set_header,
13
+ load_local_css,
14
+ update_db,
15
+ project_selection,
16
+ retrieve_pkl_object,
17
+ )
18
+ from utilities_with_panel import (
19
+ overview_test_data_prep_panel,
20
+ overview_test_data_prep_nonpanel,
21
+ initialize_data_cmp,
22
+ create_channel_summary,
23
+ create_contribution_pie,
24
+ create_contribuion_stacked_plot,
25
+ create_channel_spends_sales_plot,
26
+ format_numbers,
27
+ channel_name_formating,
28
+ )
29
+ from log_application import log_message
30
+ import sys, traceback
31
+ from post_gres_cred import db_cred
32
+
33
+ st.set_page_config(layout="wide")
34
+ load_local_css("styles.css")
35
+ set_header()
36
+
37
+
38
+ schema = db_cred["schema"]
39
+
40
+ if "username" not in st.session_state:
41
+ st.session_state["username"] = None
42
+
43
+ if "project_name" not in st.session_state:
44
+ st.session_state["project_name"] = None
45
+
46
+ if "project_dct" not in st.session_state:
47
+ project_selection()
48
+ st.stop()
49
+
50
+ tuned_model = retrieve_pkl_object(
51
+ st.session_state["project_number"], "Model_Tuning", "tuned_model", schema
52
+ )
53
+
54
+ if tuned_model is None:
55
+ st.error("Please save a tuned model")
56
+ st.stop()
57
+
58
+ if (
59
+ "session_state_saved" in st.session_state["project_dct"]["model_tuning"].keys()
60
+ and st.session_state["project_dct"]["model_tuning"]["session_state_saved"] != []
61
+ ):
62
+ for key in ["used_response_metrics", "media_data", "bin_dict"]:
63
+ if key not in st.session_state:
64
+ st.session_state[key] = st.session_state["project_dct"]["model_tuning"][
65
+ "session_state_saved"
66
+ ][key]
67
+
68
+
69
+ ## DEFINE ALL FUNCTIONS
70
+ def get_random_effects(media_data, panel_col, mdf):
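+ # Collect the fitted random intercept of every panel (market) from the mixed-effects model results into a lookup DataFrame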
71
+ random_eff_df = pd.DataFrame(columns=[panel_col, "random_effect"])
72
+ for i, market in enumerate(media_data[panel_col].unique()):
73
+ print(i, end="\r")
74
+ intercept = mdf.random_effects[market].values[0]
75
+ random_eff_df.loc[i, "random_effect"] = intercept
76
+ random_eff_df.loc[i, panel_col] = market
77
+
78
+ return random_eff_df
79
+
80
+
81
+ def process_train_and_test(train, test, features, panel_col, target_col):
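+ # Min-max scale the feature columns (scaler fit on train, re-applied to test); panel and target columns pass through unscaled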
82
+ X1 = train[features]
83
+
84
+ ss = MinMaxScaler()
85
+ X1 = pd.DataFrame(ss.fit_transform(X1), columns=X1.columns)
86
+
87
+ X1[panel_col] = train[panel_col]
88
+ X1[target_col] = train[target_col]
89
+
90
+ if test is not None:
91
+ X2 = test[features]
92
+ X2 = pd.DataFrame(ss.transform(X2), columns=X2.columns)
93
+ X2[panel_col] = test[panel_col]
94
+ X2[target_col] = test[target_col]
95
+ return X1, X2
96
+ return X1
97
+
98
+
99
+ def mdf_predict(X_df, mdf, random_eff_df):
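+ # Panel prediction = fixed-effect prediction from the mixed model + the row's panel-specific random intercept (merged in from random_eff_df)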
100
+ X = X_df.copy()
101
+ X = pd.merge(
102
+ X,
103
+ random_eff_df[[panel_col, "random_effect"]],
104
+ on=panel_col,
105
+ how="left",
106
+ )
107
+ X["pred_fixed_effect"] = mdf.predict(X)
108
+
109
+ X["pred"] = X["pred_fixed_effect"] + X["random_effect"]
110
+ X.drop(columns=["pred_fixed_effect", "random_effect"], inplace=True)
111
+
112
+ return X
113
+
114
+
115
+ try:
116
+ if "username" in st.session_state and st.session_state["username"] is not None:
117
+
118
+ # conn = sqlite3.connect(
119
+ # r"DB/User.db", check_same_thread=False
120
+ # ) # connection with sql db
121
+ # c = conn.cursor()
122
+
123
+ tuned_model = retrieve_pkl_object(
124
+ st.session_state["project_number"], "Model_Tuning", "tuned_model", schema
125
+ )
126
+
127
+ if tuned_model is None:
128
+ st.error("Please save a tuned model")
129
+ st.stop()
130
+
131
+ if (
132
+ "session_state_saved"
133
+ in st.session_state["project_dct"]["model_tuning"].keys()
134
+ and st.session_state["project_dct"]["model_tuning"]["session_state_saved"]
135
+ != []
136
+ ):
137
+ for key in [
138
+ "used_response_metrics",
139
+ "is_tuned_model",
140
+ "media_data",
141
+ "X_test_spends",
142
+ "spends_data",
143
+ ]:
144
+ st.session_state[key] = st.session_state["project_dct"]["model_tuning"][
145
+ "session_state_saved"
146
+ ][key]
147
+ elif (
148
+ "session_state_saved"
149
+ in st.session_state["project_dct"]["model_build"].keys()
150
+ and st.session_state["project_dct"]["model_build"]["session_state_saved"]
151
+ != []
152
+ ):
153
+ for key in [
154
+ "used_response_metrics",
155
+ "date",
156
+ "saved_model_names",
157
+ "media_data",
158
+ "X_test_spends",
159
+ ]:
160
+ st.session_state[key] = st.session_state["project_dct"]["model_build"][
161
+ "session_state_saved"
162
+ ][key]
163
+ else:
164
+ st.error("Please tune a model first")
165
+ st.session_state["bin_dict"] = st.session_state["project_dct"]["model_build"][
166
+ "session_state_saved"
167
+ ]["bin_dict"]
168
+ st.session_state["media_data"].columns = [
169
+ c.lower() for c in st.session_state["media_data"].columns
170
+ ]
171
+
172
+ # with open(
173
+ # os.path.join(st.session_state["project_path"], "data_import.pkl"),
174
+ # "rb",
175
+ # ) as f:
176
+ # data = pickle.load(f)
177
+
178
+ # # Accessing the loaded objects
179
+
180
+ # st.session_state["orig_media_data"] = data["final_df"]
181
+
182
+ st.session_state["orig_media_data"] = st.session_state["project_dct"][
183
+ "data_import"
184
+ ][
185
+ "imputed_tool_df"
186
+ ].copy() # db
187
+ st.session_state["channels"] = st.session_state["project_dct"]["data_import"][
188
+ "group_dict"
189
+ ].copy()
190
+ # target='Revenue'
191
+
192
+ # set the panel column
193
+ panel_col = "panel"
194
+ is_panel = (
195
+ True if st.session_state["media_data"][panel_col].nunique() > 1 else False
196
+ )
197
+
198
+ date_col = "date"
199
+
200
+ cols1 = st.columns([2, 1])
201
+ with cols1[0]:
202
+ st.markdown(f"**Welcome {st.session_state['username']}**")
203
+ with cols1[1]:
204
+ st.markdown(f"**Current Project: {st.session_state['project_name']}**")
205
+
206
+ st.title("AI Model Media Performance")
207
+
208
+ def remove_response_metric(name):
209
+ # Convert the name to a lowercase string and remove any leading or trailing spaces
210
+ name_str = str(name).lower().strip()
211
+
212
+ # Check if the name starts with "response metric" or "response_metric"
213
+ if name_str.startswith("response metric"):
214
+ return name[len("response metric") :].replace("_", " ").strip().title()
215
+ elif name_str.startswith("response_metric"):
216
+ return name[len("response_metric") :].replace("_", " ").strip().title()
217
+ else:
218
+ return name
219
+
220
+ sel_target_col = st.selectbox(
221
+ "Select the response metric",
222
+ st.session_state["used_response_metrics"],
223
+ format_func=remove_response_metric,
224
+ )
225
+ sel_target_col_frmttd = sel_target_col.replace("_", " ").replace("-", " ")
226
+ sel_target_col_frmttd = sel_target_col_frmttd.title()
227
+ target_col = (
228
+ sel_target_col.lower()
229
+ .replace(" ", "_")
230
+ .replace("-", "")
231
+ .replace(":", "")
232
+ .replace("__", "_")
233
+ )
234
+ target = sel_target_col
235
+
236
+ # Contribution
237
+ if is_panel:
238
+ # read tuned mixedLM model
239
+ if st.session_state["is_tuned_model"][target_col] == True:
240
+
241
+ model_dict = retrieve_pkl_object(
242
+ st.session_state["project_number"],
243
+ "Model_Tuning",
244
+ "tuned_model",
245
+ schema,
246
+ ) # db
247
+
248
+ saved_models = list(model_dict.keys())
249
+ required_saved_models = [
250
+ m.split("__")[0]
251
+ for m in saved_models
252
+ if m.split("__")[1] == target_col
253
+ ]
254
+
255
+ sel_model = required_saved_models[
256
+ 0
257
+ ] # only 1 tuned model available per resp metric
258
+
259
+ sel_model_dict = model_dict[sel_model + "__" + target_col]
260
+
261
+ model = sel_model_dict["Model_object"]
262
+ X_train = sel_model_dict["X_train_tuned"]
263
+ X_test = sel_model_dict["X_test_tuned"]
264
+ best_feature_set = sel_model_dict["feature_set"]
265
+
266
+ # Calculate contributions
267
+
268
+ st.session_state["orig_media_data"].columns = [
269
+ col.lower()
270
+ .replace(".", "_")
271
+ .replace("@", "_")
272
+ .replace(" ", "_")
273
+ .replace("-", "")
274
+ .replace(":", "")
275
+ .replace("__", "_")
276
+ for col in st.session_state["orig_media_data"].columns
277
+ ]
278
+
279
+ media_data = st.session_state["media_data"]
280
+
281
+ contri_df = pd.DataFrame()
282
+
283
+ y = []
284
+ y_pred = []
285
+
286
+ random_eff_df = get_random_effects(media_data, panel_col, model)
287
+ random_eff_df["fixed_effect"] = model.fe_params["Intercept"]
288
+ random_eff_df["panel_effect"] = (
289
+ random_eff_df["random_effect"] + random_eff_df["fixed_effect"]
290
+ )
291
+
292
+ coef_df = pd.DataFrame(model.fe_params)
293
+ coef_df.reset_index(inplace=True)
294
+ coef_df.columns = ["feature", "coef"]
295
+
296
+ x_train_contribution = X_train.copy()
297
+ x_test_contribution = X_test.copy()
298
+
299
+ # preprocessing not needed since X_train is already preprocessed
300
+ # X1, X2 = process_train_and_test(x_train_contribution, x_test_contribution, best_feature_set, panel_col, target_col)
301
+ # x_train_contribution[best_feature_set] = X1[best_feature_set]
302
+ # x_test_contribution[best_feature_set] = X2[best_feature_set]
303
+
304
+ x_train_contribution = mdf_predict(
305
+ x_train_contribution, model, random_eff_df
306
+ )
307
+ x_test_contribution = mdf_predict(x_test_contribution, model, random_eff_df)
308
+
309
+ x_train_contribution = pd.merge(
310
+ x_train_contribution,
311
+ random_eff_df[[panel_col, "panel_effect"]],
312
+ on=panel_col,
313
+ how="left",
314
+ )
315
+ x_test_contribution = pd.merge(
316
+ x_test_contribution,
317
+ random_eff_df[[panel_col, "panel_effect"]],
318
+ on=panel_col,
319
+ how="left",
320
+ )
321
+
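+ # Per-feature contribution = fixed-effect coefficient * feature value; the intercept (index 0) is skipped here and enters via panel_effect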
322
+ for i in range(len(coef_df))[1:]:
323
+ coef = coef_df.loc[i, "coef"]
324
+ col = coef_df.loc[i, "feature"]
325
+ if col.lower() != "intercept":
326
+ x_train_contribution[str(col) + "_contr"] = (
327
+ coef * x_train_contribution[col]
328
+ )
329
+ x_test_contribution[str(col) + "_contr"] = (
330
+ coef * x_test_contribution[col]
331
+ )
332
+
333
+ tuning_cols = [
334
+ c
335
+ for c in x_train_contribution.filter(regex="contr").columns
336
+ if c
337
+ in [
338
+ "day_of_week_contr",
339
+ "Trend_contr",
340
+ "sine_wave_contr",
341
+ "cosine_wave_contr",
342
+ ]
343
+ ]
344
+ flag_cols = [
345
+ c
346
+ for c in x_train_contribution.filter(regex="contr").columns
347
+ if "_flag" in c
348
+ ]
349
+
350
+ # add exogenous contribution to base
351
+ all_exog_vars = st.session_state["bin_dict"]["Exogenous"]
352
+ all_exog_vars = [
353
+ var.lower()
354
+ .replace(".", "_")
355
+ .replace("@", "_")
356
+ .replace(" ", "_")
357
+ .replace("-", "")
358
+ .replace(":", "")
359
+ .replace("__", "_")
360
+ for var in all_exog_vars
361
+ ]
362
+ exog_cols = []
363
+ if len(all_exog_vars) > 0:
364
+ for col in x_train_contribution.filter(regex="contr").columns:
365
+ if (
366
+ len([exog_var for exog_var in all_exog_vars if exog_var in col])
367
+ > 0
368
+ ):
369
+ exog_cols.append(col)
370
+
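+ # "Base" (non-media) contribution = panel effect + flag, trend/seasonality and exogenous contributions, summed into a single column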
371
+ base_cols = ["panel_effect"] + flag_cols + tuning_cols + exog_cols
372
+
373
+ x_train_contribution["base_contr"] = x_train_contribution[base_cols].sum(
374
+ axis=1
375
+ )
376
+ x_train_contribution.drop(columns=base_cols, inplace=True)
377
+ x_test_contribution["base_contr"] = x_test_contribution[base_cols].sum(
378
+ axis=1
379
+ )
380
+ x_test_contribution.drop(columns=base_cols, inplace=True)
381
+
382
+ overall_contributions = pd.concat(
383
+ [x_train_contribution, x_test_contribution]
384
+ ).reset_index(drop=True)
385
+
386
+ overview_test_data_prep_panel(
387
+ overall_contributions,
388
+ st.session_state["orig_media_data"],
389
+ st.session_state["spends_data"],
390
+ date_col,
391
+ panel_col,
392
+ target_col,
393
+ )
394
+
395
+ else: # NON PANEL
396
+ if st.session_state["is_tuned_model"][target_col] == True: # Sprint4
397
+ # with open(
398
+ # os.path.join(st.session_state["project_path"], "tuned_model.pkl"),
399
+ # "rb",
400
+ # ) as file:
401
+ # model_dict = pickle.load(file)
402
+
403
+ model_dict = retrieve_pkl_object(
404
+ st.session_state["project_number"],
405
+ "Model_Tuning",
406
+ "tuned_model",
407
+ schema,
408
+ ) # db
409
+
410
+ saved_models = list(model_dict.keys())
411
+ required_saved_models = [
412
+ m.split("__")[0]
413
+ for m in saved_models
414
+ if m.split("__")[1] == target_col
415
+ ]
416
+
417
+ sel_model = required_saved_models[
418
+ 0
419
+ ] # only 1 tuned model available per resp metric
420
+ sel_model_dict = model_dict[sel_model + "__" + target_col]
421
+
422
+ model = sel_model_dict["Model_object"]
423
+ X_train = sel_model_dict["X_train_tuned"]
424
+ X_test = sel_model_dict["X_test_tuned"]
425
+ best_feature_set = sel_model_dict["feature_set"]
426
+
427
+ x_train_contribution = X_train.copy()
428
+ x_test_contribution = X_test.copy()
429
+
430
+ x_train_contribution["pred"] = model.predict(
431
+ x_train_contribution[best_feature_set]
432
+ )
433
+ x_test_contribution["pred"] = model.predict(
434
+ x_test_contribution[best_feature_set]
435
+ )
436
+
437
+ coef_df = pd.DataFrame(model.params)
438
+ coef_df.reset_index(inplace=True)
439
+ coef_df.columns = ["feature", "coef"]
440
+
441
+ # st.write(coef_df)
442
+ for i in range(len(coef_df)):
443
+ coef = coef_df.loc[i, "coef"]
444
+ col = coef_df.loc[i, "feature"]
445
+ if col != "const":
446
+ x_train_contribution[str(col) + "_contr"] = (
447
+ coef * x_train_contribution[col]
448
+ )
449
+ x_test_contribution[str(col) + "_contr"] = (
450
+ coef * x_test_contribution[col]
451
+ )
452
+ else:
453
+ x_train_contribution["const"] = coef
454
+ x_test_contribution["const"] = coef
455
+
456
+ tuning_cols = [
457
+ c
458
+ for c in x_train_contribution.filter(regex="contr").columns
459
+ if c
460
+ in [
461
+ "day_of_week_contr",
462
+ "Trend_contr",
463
+ "sine_wave_contr",
464
+ "cosine_wave_contr",
465
+ ]
466
+ ]
467
+ flag_cols = [
468
+ c
469
+ for c in x_train_contribution.filter(regex="contr").columns
470
+ if "_flag" in c
471
+ ]
472
+
473
+ # add exogenous contribution to base
474
+ all_exog_vars = st.session_state["bin_dict"]["Exogenous"]
475
+ all_exog_vars = [
476
+ var.lower()
477
+ .replace(".", "_")
478
+ .replace("@", "_")
479
+ .replace(" ", "_")
480
+ .replace("-", "")
481
+ .replace(":", "")
482
+ .replace("__", "_")
483
+ for var in all_exog_vars
484
+ ]
485
+ exog_cols = []
486
+ if len(all_exog_vars) > 0:
487
+ for col in x_train_contribution.filter(regex="contr").columns:
488
+ if (
489
+ len([exog_var for exog_var in all_exog_vars if exog_var in col])
490
+ > 0
491
+ ):
492
+ exog_cols.append(col)
493
+
494
+ base_cols = ["const"] + flag_cols + tuning_cols + exog_cols
495
+ # st.write(base_cols)
496
+ x_train_contribution["base_contr"] = x_train_contribution[base_cols].sum(
497
+ axis=1
498
+ )
499
+ x_train_contribution.drop(columns=base_cols, inplace=True)
500
+
501
+ x_test_contribution["base_contr"] = x_test_contribution[base_cols].sum(
502
+ axis=1
503
+ )
504
+ x_test_contribution.drop(columns=base_cols, inplace=True)
505
+ # x_test_contribution.to_csv("Test/test_contr.csv", index=False)
506
+
507
+ overall_contributions = pd.concat(
508
+ [x_train_contribution, x_test_contribution]
509
+ ).reset_index(drop=True)
510
+ # overall_contributions.to_csv("Test/overall_contributions.csv", index=False)
511
+
512
+ overview_test_data_prep_nonpanel(
513
+ overall_contributions,
514
+ st.session_state["orig_media_data"].copy(),
515
+ st.session_state["spends_data"].copy(),
516
+ date_col,
517
+ target_col,
518
+ )
519
+ # for k, v in st.session_sta
520
+ # te.items():
521
+
522
+ # if k not in ['logout', 'login','config'] and not k.startswith('FormSubmitter'):
523
+ # st.session_state[k] = v
524
+
525
+ # authenticator = st.session_state.get('authenticator')
526
+
527
+ # if authenticator is None:
528
+ # authenticator = load_authenticator()
529
+
530
+ # name, authentication_status, username = authenticator.login('Login', 'main')
531
+ # auth_status = st.session_state['authentication_status']
532
+
533
+ # if auth_status:
534
+ # authenticator.logout('Logout', 'main')
535
+
536
+ # is_state_initiaized = st.session_state.get('initialized',False)
537
+ # if not is_state_initiaized:
538
+
539
+ min_date = X_train[date_col].min().date()
540
+ max_date = X_test[date_col].max().date()
541
+ if "media_performance" not in st.session_state["project_dct"]:
542
+ st.session_state["project_dct"]["media_performance"] = {
543
+ "start_date": None,
544
+ "end_date": None,
545
+ }
546
+
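+ # Default the date pickers to the previously saved window, clamped to the dates actually covered by the train and test data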
547
+ start_default = st.session_state["project_dct"]["media_performance"].get(
548
+ "start_date", None
549
+ )
550
+ start_default = start_default if start_default is not None else min_date
551
+ start_default = start_default if start_default > min_date else min_date
552
+ start_default = start_default if start_default < max_date else min_date
553
+
554
+ end_default = st.session_state["project_dct"]["media_performance"].get(
555
+ "end_date", None
556
+ )
557
+ end_default = end_default if end_default is not None else max_date
558
+ end_default = end_default if end_default > min_date else max_date
559
+ end_default = end_default if end_default < max_date else max_date
560
+
561
+ st.write("Select a timeline for analysis")
562
+ date_columns = st.columns(2)
563
+
564
+ with date_columns[0]:
565
+ start_date = st.date_input(
566
+ "Select Start Date",
567
+ start_default,
568
+ min_value=min_date,
569
+ max_value=max_date,
570
+ )
571
+ if (start_date < min_date) or (start_date > max_date):
572
+ st.error("Please select dates in the range of the dates in the data")
573
+ st.stop()
574
+ end_default = (
575
+ end_default if end_default > start_date + timedelta(days=1) else max_date
576
+ )
577
+ with date_columns[1]:
578
+ end_default = (
579
+ end_default
580
+ if pd.Timestamp(end_default) >= pd.Timestamp(start_date)
581
+ else start_date
582
+ )
583
+
584
+ end_date = st.date_input(
585
+ "Select End Date",
586
+ end_default,
587
+ min_value=start_date + timedelta(days=1),
588
+ max_value=max_date,
589
+ )
590
+ if (
591
+ (start_date < min_date)
592
+ or (end_date < min_date)
593
+ or (start_date > max_date)
594
+ or (end_date > max_date)
595
+ ):
596
+ st.error("Please select dates in the range of the dates in the data")
597
+ st.stop()
598
+ if end_date < start_date + timedelta(days=1):
599
+ st.error("Please select end date after start date")
600
+ st.stop()
601
+
602
+ st.session_state["project_dct"]["media_performance"]["start_date"] = start_date
603
+ st.session_state["project_dct"]["media_performance"]["end_date"] = end_date
604
+
605
+ st.header("Overview of Previous Media Spend")
606
+
607
+ initialize_data_cmp(target_col, is_panel, panel_col, start_date, end_date)
608
+ scenario = st.session_state["scenario"]
609
+ raw_df = st.session_state["raw_df"]
610
+
611
+ columns = st.columns(2)
612
+
613
+ with columns[0]:
614
+ st.metric(
615
+ label="Media Spend",
616
+ value=format_numbers(float(scenario.actual_total_spends)),
617
+ )
618
+ ###print(f"##################### {scenario.actual_total_sales} ##################")
619
+ with columns[1]:
620
+ st.metric(
621
+ label=sel_target_col_frmttd,
622
+ value=format_numbers(
623
+ float(scenario.actual_total_sales), include_indicator=False
624
+ ),
625
+ )
626
+
627
+ actual_summary_df = create_channel_summary(scenario, sel_target_col_frmttd)
628
+ actual_summary_df["Channel"] = actual_summary_df["Channel"].apply(
629
+ channel_name_formating
630
+ )
631
+
632
+ columns = st.columns((3, 1))
633
+ with columns[0]:
634
+ with st.expander("Channel wise overview"):
635
+ st.markdown(
636
+ actual_summary_df.style.set_table_styles(
637
+ [
638
+ {
639
+ "selector": "th",
640
+ "props": [("background-color", "#f6dcc7")],
641
+ },
642
+ {
643
+ "selector": "tr:nth-child(even)",
644
+ "props": [("background-color", "#f6dcc7")],
645
+ },
646
+ ]
647
+ ).to_html(),
648
+ unsafe_allow_html=True,
649
+ )
650
+
651
+ st.markdown("<hr>", unsafe_allow_html=True)
652
+ ##############################
653
+
654
+ st.plotly_chart(
655
+ create_contribution_pie(scenario, sel_target_col_frmttd),
656
+ use_container_width=True,
657
+ )
658
+ st.markdown("<hr>", unsafe_allow_html=True)
659
+
660
+ ################################3
661
+ st.plotly_chart(
662
+ create_contribuion_stacked_plot(scenario, sel_target_col_frmttd),
663
+ use_container_width=True,
664
+ )
665
+ st.markdown("<hr>", unsafe_allow_html=True)
666
+ #######################################
667
+
668
+ selected_channel_name = st.selectbox(
669
+ "Channel",
670
+ st.session_state["channels_list"] + ["non media"],
671
+ format_func=channel_name_formating,
672
+ )
673
+ selected_channel = scenario.channels.get(selected_channel_name, None)
674
+
675
+ st.plotly_chart(
676
+ create_channel_spends_sales_plot(selected_channel, sel_target_col_frmttd),
677
+ use_container_width=True,
678
+ )
679
+
680
+ st.markdown("<hr>", unsafe_allow_html=True)
681
+
682
+ if st.button("Save this session", use_container_width=True):
683
+
684
+ project_dct_pkl = pickle.dumps(st.session_state["project_dct"])
685
+
686
+ update_db(
687
+ st.session_state["project_number"],
688
+ "Current_Media_Performance",
689
+ "project_dct",
690
+ project_dct_pkl,
691
+ schema,
692
+ # resp_mtrc=None,
693
+ ) # db
694
+
695
+ st.success("Session Saved!")
696
+
697
+ # Remove "response_metric_" from the start and "_total" from the end
698
+ if str(target_col).startswith("response_metric_"):
699
+ target_col = target_col.replace("response_metric_", "", 1)
700
+
701
+ # Remove the last 6 characters (length of "_total")
702
+ if str(target_col).endswith("_total"):
703
+ target_col = target_col[:-6]
704
+
705
+ if (
706
+ st.session_state["project_dct"]["current_media_performance"][
707
+ "model_outputs"
708
+ ][target_col]
709
+ is not None
710
+ ):
711
+ if (
712
+ len(
713
+ st.session_state["project_dct"]["current_media_performance"][
714
+ "model_outputs"
715
+ ][target_col]["contribution_data"]
716
+ )
717
+ > 0
718
+ ):
719
+ st.download_button(
720
+ label="Download Contribution File",
721
+ data=st.session_state["project_dct"]["current_media_performance"][
722
+ "model_outputs"
723
+ ][target_col]["contribution_data"].to_csv(),
724
+ file_name="contributions.csv",
725
+ key="dwnld_contr",
726
+ )
727
+ except:
728
+ exc_type, exc_value, exc_traceback = sys.exc_info()
729
+ error_message = "".join(
730
+ traceback.format_exception(exc_type, exc_value, exc_traceback)
731
+ )
732
+ log_message("error", f"Error: {error_message}", "Current Media Performance")
733
+ st.warning("An error occured, please try again", icon="⚠️")
pages/8_Response_Curves.py ADDED
@@ -0,0 +1,530 @@
1
+ import streamlit as st
2
+
3
+ st.set_page_config(
4
+ page_title="Response Curves",
5
+ page_icon="⚖️",
6
+ layout="wide",
7
+ initial_sidebar_state="collapsed",
8
+ )
9
+
10
+ # Disable +/- for number input
11
+ st.markdown(
12
+ """
13
+ <style>
14
+ button.step-up {display: none;}
15
+ button.step-down {display: none;}
16
+ div[data-baseweb] {border-radius: 4px;}
17
+ </style>""",
18
+ unsafe_allow_html=True,
19
+ )
20
+
21
+ import sys
22
+ import json
23
+ import pickle
24
+ import traceback
25
+ import numpy as np
26
+ import pandas as pd
27
+ import plotly.express as px
28
+ import plotly.graph_objects as go
29
+ from post_gres_cred import db_cred
30
+ from sklearn.metrics import r2_score
31
+ from log_application import log_message
32
+ from utilities import project_selection, update_db, set_header, load_local_css
33
+ from utilities import (
34
+ get_panels_names,
35
+ get_metrics_names,
36
+ name_formating,
37
+ generate_rcs_data,
38
+ load_rcs_metadata_files,
39
+ )
40
+
41
+ schema = db_cred["schema"]
42
+ load_local_css("styles.css")
43
+ set_header()
44
+
45
+
46
+ # Initialize project name session state
47
+ if "project_name" not in st.session_state:
48
+ st.session_state["project_name"] = None
49
+
50
+ # Fetch project dictionary
51
+ if "project_dct" not in st.session_state:
52
+ project_selection()
53
+ st.stop()
54
+
55
+ # Display Username and Project Name
56
+ if "username" in st.session_state and st.session_state["username"] is not None:
57
+
58
+ cols1 = st.columns([2, 1])
59
+
60
+ with cols1[0]:
61
+ st.markdown(f"**Welcome {st.session_state['username']}**")
62
+ with cols1[1]:
63
+ st.markdown(f"**Current Project: {st.session_state['project_name']}**")
64
+
65
+
66
+ # Function to build s curve
67
+ def s_curve(x, K, b, a, x0):
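+ # Logistic-style saturation curve: K is the maximum (saturation) level, a controls steepness, b and x0 shift the curve along the spend axis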
68
+ return K / (1 + b * np.exp(-a * (x - x0)))
69
+
70
+
71
+ # Function to update the RCS parameters in the modified RCS metadata data
72
+ def modify_rcs_parameters(metrics_selected, panel_selected, channel_selected):
73
+ # Define unique keys for each parameter based on the selection
74
+ K_key = f"K_updated_key_{metrics_selected}_{panel_selected}_{channel_selected}"
75
+ b_key = f"b_updated_key_{metrics_selected}_{panel_selected}_{channel_selected}"
76
+ a_key = f"a_updated_key_{metrics_selected}_{panel_selected}_{channel_selected}"
77
+ x0_key = f"x0_updated_key_{metrics_selected}_{panel_selected}_{channel_selected}"
78
+
79
+ # Retrieve the updated parameters from session state
80
+ K_updated, b_updated, a_updated, x0_updated = (
81
+ st.session_state[K_key],
82
+ st.session_state[b_key],
83
+ st.session_state[a_key],
84
+ st.session_state[x0_key],
85
+ )
86
+
87
+ # Load the existing modified RCS data
88
+ rcs_data_modified = st.session_state["project_dct"]["response_curves"][
89
+ "modified_metadata_file"
90
+ ]
91
+
92
+ # Update the RCS parameters for the selected metric and panel
93
+ rcs_data_modified[metrics_selected][panel_selected][channel_selected] = {
94
+ "K": K_updated,
95
+ "b": b_updated,
96
+ "a": a_updated,
97
+ "x0": x0_updated,
98
+ }
99
+
100
+
101
+ # Function to reset the parameters to their default values
102
+ def reset_parameters(
103
+ metrics_selected, panel_selected, channel_selected, original_channel_data
104
+ ):
105
+ # Define unique keys for each parameter based on the selection
106
+ K_key = f"K_updated_key_{metrics_selected}_{panel_selected}_{channel_selected}"
107
+ b_key = f"b_updated_key_{metrics_selected}_{panel_selected}_{channel_selected}"
108
+ a_key = f"a_updated_key_{metrics_selected}_{panel_selected}_{channel_selected}"
109
+ x0_key = f"x0_updated_key_{metrics_selected}_{panel_selected}_{channel_selected}"
110
+
111
+ # Reset session state values to original data
112
+ del st.session_state[K_key]
113
+ del st.session_state[b_key]
114
+ del st.session_state[a_key]
115
+ del st.session_state[x0_key]
116
+
117
+ # Reset the modified metadata file with original parameters
118
+ rcs_data_modified = st.session_state["project_dct"]["response_curves"][
119
+ "modified_metadata_file"
120
+ ]
121
+
122
+ # Update the parameters in the modified data to the original values
123
+ rcs_data_modified[metrics_selected][panel_selected][channel_selected] = {
124
+ "K": original_channel_data["K"],
125
+ "b": original_channel_data["b"],
126
+ "a": original_channel_data["a"],
127
+ "x0": original_channel_data["x0"],
128
+ }
129
+
130
+ # Update the modified metadata
131
+ st.session_state["project_dct"]["response_curves"][
132
+ "modified_metadata_file"
133
+ ] = rcs_data_modified
134
+
135
+
136
+ # Function to generate updated RCS parameter DataFrame
137
+ @st.cache_data(show_spinner=False)
138
+ def updated_parm_gen(original_data, modified_data, metrics_selected, panel_selected):
139
+ # Retrieve the data for the selected metric and panel
140
+ original_data_selection = original_data[metrics_selected][panel_selected]
141
+ modified_data_selection = modified_data[metrics_selected][panel_selected]
142
+
143
+ # Initialize an empty list to hold the data for the DataFrame
144
+ data = []
145
+
146
+ # Iterate through each channel in the selected metric and panel
147
+ for channel in original_data_selection:
148
+ # Extract original parameters
149
+ K_o, b_o, a_o, x0_o = (
150
+ original_data_selection[channel]["K"],
151
+ original_data_selection[channel]["b"],
152
+ original_data_selection[channel]["a"],
153
+ original_data_selection[channel]["x0"],
154
+ )
155
+ # Extract modified parameters
156
+ K_m, b_m, a_m, x0_m = (
157
+ modified_data_selection[channel]["K"],
158
+ modified_data_selection[channel]["b"],
159
+ modified_data_selection[channel]["a"],
160
+ modified_data_selection[channel]["x0"],
161
+ )
162
+
163
+ # Check if any parameters differ
164
+ if (K_o != K_m) or (b_o != b_m) or (a_o != a_m) or (x0_o != x0_m):
165
+ # Append the data to the list only if there is a difference
166
+ data.append(
167
+ {
168
+ "Metric": name_formating(metrics_selected),
169
+ "Panel": name_formating(panel_selected),
170
+ "Channel": name_formating(channel),
171
+ "K (Original)": K_o,
172
+ "b (Original)": b_o,
173
+ "a (Original)": a_o,
174
+ "x0 (Original)": x0_o,
175
+ "K (Modified)": K_m,
176
+ "b (Modified)": b_m,
177
+ "a (Modified)": a_m,
178
+ "x0 (Modified)": x0_m,
179
+ }
180
+ )
181
+
182
+ # Create a DataFrame from the collected data
183
+ df = pd.DataFrame(data)
184
+
185
+ return df
186
+
187
+
188
+ # Function to create JSON file for RCS data
189
+ @st.cache_data(show_spinner=False)
190
+ def create_json_file():
191
+ return json.dumps(
192
+ st.session_state["project_dct"]["response_curves"]["modified_metadata_file"],
193
+ indent=4,
194
+ )
195
+
196
+
197
+ try:
198
+ # Page Title
199
+ st.title("Response Curves")
200
+
201
+ # Retrieve the list of all metric names from the specified directory
202
+ metrics_list = get_metrics_names()
203
+
204
+ # Check if there are any metrics available in the metrics list
205
+ if not metrics_list:
206
+ # Display a warning message to the user if no metrics are found
207
+ st.warning(
208
+ "Please tune at least one model to generate response curves data.",
209
+ icon="⚠️",
210
+ )
211
+
212
+ # Log message
213
+ log_message(
214
+ "warning",
215
+ "Please tune at least one model to generate response curves data.",
216
+ "Response Curves",
217
+ )
218
+
219
+ # Stop further execution as there is no data to process
220
+ st.stop()
221
+
222
+ # Widget columns
223
+ metric_col, channel_col, panel_col, save_progress_col = st.columns(4)
224
+
225
+ # Metrics Selection
226
+ metrics_selected = metric_col.selectbox(
227
+ "Response Metrics",
228
+ sorted(metrics_list),
229
+ format_func=name_formating,
230
+ key="response_metrics_selectbox",
231
+ index=0,
232
+ )
233
+
234
+ # Retrieve the list of all panel names for the specified Metric
235
+ panel_list = get_panels_names(metrics_selected)
236
+
237
+ # Panel Selection
238
+ panel_selected = panel_col.selectbox(
239
+ "Panel",
240
+ sorted(panel_list),
241
+ format_func=name_formating,
242
+ key="panel_selected_selectbox",
243
+ index=0,
244
+ )
245
+
246
+ # Save Progress
247
+ with save_progress_col:
248
+ st.write("####") # Padding
249
+ save_progress_placeholder = st.empty()
250
+
251
+ # Placeholder to display message and spinner
252
+ message_spinner_placeholder = st.container()
253
+
254
+ # Save page progress
255
+ with message_spinner_placeholder, st.spinner("Saving Progress ..."):
256
+ if save_progress_placeholder.button("Save Progress", use_container_width=True):
257
+ # Update DB
258
+ update_db(
259
+ prj_id=st.session_state["project_number"],
260
+ page_nam="Response Curves",
261
+ file_nam="project_dct",
262
+ pkl_obj=pickle.dumps(st.session_state["project_dct"]),
263
+ schema=schema,
264
+ )
265
+
266
+ # Store the message details in session state
267
+ message_spinner_placeholder.success(
268
+ "Progress saved successfully!", icon="💾"
269
+ )
270
+ st.toast("Progress saved successfully!", icon="💾")
271
+
272
+ # Log message
273
+ log_message("info", "Progress saved successfully!", "Response Curves")
274
+
275
+ # Check if the RCS metadata file does not exist
276
+ if (
277
+ st.session_state["project_dct"]["response_curves"]["original_metadata_file"]
278
+ is None
279
+ or st.session_state["project_dct"]["response_curves"]["modified_metadata_file"]
280
+ is None
281
+ ):
282
+ # RCS metadata file does not exist. Generating new RCS data
283
+ generate_rcs_data()
284
+
285
+ # Log message
286
+ log_message(
287
+ "info",
288
+ "RCS metadata file does not exist. Generating new RCS data.",
289
+ "Response Curves",
290
+ )
291
+
292
+ # Load metadata files if they exist
293
+ original_data, modified_data = load_rcs_metadata_files()
294
+
295
+ # Retrieve the list of all channel names for the specified Metric and Panel
296
+ chanel_list_final = list(original_data[metrics_selected][panel_selected].keys())
297
+
298
+ # Channel Selection
299
+ channel_selected = channel_col.selectbox(
300
+ "Channel",
301
+ sorted(chanel_list_final),
302
+ format_func=name_formating,
303
+ key="selected_channel_name_selectbox",
304
+ )
305
+
306
+ # Extract original channel data for the selected metric, panel, and channel
307
+ original_channel_data = original_data[metrics_selected][panel_selected][
308
+ channel_selected
309
+ ]
310
+
311
+ # Extract modified channel data for the same metric, panel, and channel
312
+ modified_channel_data = modified_data[metrics_selected][panel_selected][
313
+ channel_selected
314
+ ]
315
+
316
+ # X and Y values for plotting
317
+ x = original_channel_data["x"]
318
+ y = original_channel_data["y"]
319
+
320
+ # Scaling factor for X values and range for S-curve plotting
321
+ power = original_channel_data["power"]
322
+ x_plot = original_channel_data["x_plot"]
323
+
324
+ # Original S-curve parameters
325
+ K_orig = original_channel_data["K"]
326
+ b_orig = original_channel_data["b"]
327
+ a_orig = original_channel_data["a"]
328
+ x0_orig = original_channel_data["x0"]
329
+
330
+ # Modified S-curve parameters (user-adjusted)
331
+ K_mod = modified_channel_data["K"]
332
+ b_mod = modified_channel_data["b"]
333
+ a_mod = modified_channel_data["a"]
334
+ x0_mod = modified_channel_data["x0"]
335
+
336
+ # Create a scatter plot for the original data points
337
+ fig = px.scatter(
338
+ x=x,
339
+ y=y,
340
+ title="Original and Modified S-Curve Plot",
341
+ labels={"x": "Spends", "y": name_formating(metrics_selected)},
342
+ )
343
+
344
+ # Add the modified S-curve trace
345
+ fig.add_trace(
346
+ go.Scatter(
347
+ x=x_plot,
348
+ y=s_curve(
349
+ np.array(x_plot) / 10**power,
350
+ K_mod,
351
+ b_mod,
352
+ a_mod,
353
+ x0_mod,
354
+ ),
355
+ line=dict(color="red"),
356
+ name="Modified",
357
+ ),
358
+ )
359
+
360
+ # Add the original S-curve trace
361
+ fig.add_trace(
362
+ go.Scatter(
363
+ x=x_plot,
364
+ y=s_curve(
365
+ np.array(x_plot) / 10**power,
366
+ K_orig,
367
+ b_orig,
368
+ a_orig,
369
+ x0_orig,
370
+ ),
371
+ line=dict(color="rgba(0, 255, 0, 0.6)"), # Semi-transparent green
372
+ name="Original",
373
+ ),
374
+ )
375
+
376
+ # Customize the layout of the plot
377
+ fig.update_layout(
378
+ title="Comparison of Original and Modified Response-Curves",
379
+ xaxis_title="Input (Clicks, Impressions, etc..)",
380
+ yaxis_title=name_formating(metrics_selected),
381
+ legend_title="Curve Type",
382
+ )
383
+
384
+ # Display s-curve
385
+ st.plotly_chart(fig, use_container_width=True)
386
+
387
+ # Calculate R-squared for the original curve
388
+ y_orig_pred = s_curve(np.array(x) / 10**power, K_orig, b_orig, a_orig, x0_orig)
389
+ r2_orig = r2_score(y, y_orig_pred)
390
+
391
+ # Calculate R-squared for the modified curve
392
+ y_mod_pred = s_curve(np.array(x) / 10**power, K_mod, b_mod, a_mod, x0_mod)
393
+ r2_mod = r2_score(y, y_mod_pred)
394
+
395
+ # Calculate the difference in R-squared
396
+ r2_diff = r2_mod - r2_orig
397
+
398
+ # Display R-squared metrics
399
+ st.write("## R-squared Comparison")
400
+ r2_col = st.columns(3)
401
+
402
+ r2_col[0].metric("R-squared (Original)", f"{r2_orig:.2f}")
403
+ r2_col[1].metric("R-squared (Modified)", f"{r2_mod:.2f}")
404
+ r2_col[2].metric("Difference in R-squared", f"{r2_diff:.2f}")
405
+
406
+ # Define unique keys for each parameter based on the selection
407
+ K_key = f"K_updated_key_{metrics_selected}_{panel_selected}_{channel_selected}"
408
+ b_key = f"b_updated_key_{metrics_selected}_{panel_selected}_{channel_selected}"
409
+ a_key = f"a_updated_key_{metrics_selected}_{panel_selected}_{channel_selected}"
410
+ x0_key = f"x0_updated_key_{metrics_selected}_{panel_selected}_{channel_selected}"
411
+
412
+ # Initialize session state keys if they do not exist
413
+ if K_key not in st.session_state:
414
+ st.session_state[K_key] = K_mod
415
+ if b_key not in st.session_state:
416
+ st.session_state[b_key] = b_mod
417
+ if a_key not in st.session_state:
418
+ st.session_state[a_key] = a_mod
419
+ if x0_key not in st.session_state:
420
+ st.session_state[x0_key] = x0_mod
421
+
422
+ # RCS parameters input
423
+ rsc_ip_col = st.columns(4)
424
+ with rsc_ip_col[0]:
425
+ K_updated = st.number_input(
426
+ "K",
427
+ step=0.001,
428
+ min_value=0.000000,
429
+ format="%.6f",
430
+ on_change=modify_rcs_parameters,
431
+ args=(metrics_selected, panel_selected, channel_selected),
432
+ key=K_key,
433
+ )
434
+ with rsc_ip_col[1]:
435
+ b_updated = st.number_input(
436
+ "b",
437
+ step=0.001,
438
+ min_value=0.000000,
439
+ format="%.6f",
440
+ on_change=modify_rcs_parameters,
441
+ args=(metrics_selected, panel_selected, channel_selected),
442
+ key=b_key,
443
+ )
444
+ with rsc_ip_col[2]:
445
+ a_updated = st.number_input(
446
+ "a",
447
+ step=0.001,
448
+ min_value=0.000000,
449
+ format="%.6f",
450
+ on_change=modify_rcs_parameters,
451
+ args=(metrics_selected, panel_selected, channel_selected),
452
+ key=a_key,
453
+ )
454
+ with rsc_ip_col[3]:
455
+ x0_updated = st.number_input(
456
+ "x0",
457
+ step=0.001,
458
+ min_value=0.000000,
459
+ format="%.6f",
460
+ on_change=modify_rcs_parameters,
461
+ args=(metrics_selected, panel_selected, channel_selected),
462
+ key=x0_key,
463
+ )
464
+
465
+ # Create columns for Reset and Download buttons
466
+ reset_download_col = st.columns(2)
467
+ with reset_download_col[0]:
468
+ if st.button(
469
+ "Reset",
470
+ use_container_width=True,
471
+ ):
472
+ reset_parameters(
473
+ metrics_selected,
474
+ panel_selected,
475
+ channel_selected,
476
+ original_channel_data,
477
+ )
478
+
479
+ # Log message
480
+ log_message(
481
+ "info",
482
+ f"METRIC: {name_formating(metrics_selected)} ; PANEL: {name_formating(panel_selected)}, CHANNEL: {name_formating(channel_selected)} has been reset to its original value.",
483
+ "Response Curves",
484
+ )
485
+
486
+ st.rerun()
487
+
488
+ with reset_download_col[1]:
489
+ # Provide a download button for the modified RCS data
490
+ try:
491
+ # Create JSON file for RCS data
492
+ json_data = create_json_file()
493
+ st.download_button(
494
+ label="Download",
495
+ data=json_data,
496
+ file_name=f"{name_formating(metrics_selected)}_{name_formating(panel_selected)}_rcs_data.json",
497
+ mime="application/json",
498
+ use_container_width=True,
499
+ )
500
+ except:
501
+ # Download failed
502
+ pass
503
+
504
+ # Generate the DataFrame showing only non-matching parameters
505
+ updated_parm_df = updated_parm_gen(
506
+ original_data, modified_data, metrics_selected, panel_selected
507
+ )
508
+
509
+ # Display the DataFrame or show an informational message if no updates
510
+ if not updated_parm_df.empty:
511
+ st.write("## Parameter Comparison for Selected Metric and Panel")
512
+ st.dataframe(updated_parm_df, hide_index=True)
513
+ else:
514
+ st.info("No parameters are updated for the selected Metric and Panel")
515
+
516
+ except Exception as e:
517
+ # Capture the error details
518
+ exc_type, exc_value, exc_traceback = sys.exc_info()
519
+ error_message = "".join(
520
+ traceback.format_exception(exc_type, exc_value, exc_traceback)
521
+ )
522
+
523
+ # Log message
524
+ log_message("error", f"An error occurred: {error_message}.", "Response Curves")
525
+
526
+ # Display a warning message
527
+ st.warning(
528
+ "Oops! Something went wrong. Please try refreshing the tool or creating a new project.",
529
+ icon="⚠️",
530
+ )
pages/9_Scenario_Planner.py ADDED
The diff for this file is too large to render. See raw diff
 
post_gres_cred.py ADDED
@@ -0,0 +1,2 @@
1
+ db_cred = {'schema': "mmo_service_owner"}
2
+ schema = ""
ppt/template.txt ADDED
File without changes
ppt_utils.py ADDED
@@ -0,0 +1,1419 @@
1
+ import pandas as pd
2
+ import numpy as np
3
+ import pptx
4
+ from pptx import Presentation
5
+ from pptx.chart.data import CategoryChartData, ChartData
6
+ from pptx.enum.chart import XL_CHART_TYPE, XL_LEGEND_POSITION, XL_LABEL_POSITION
7
+ from pptx.enum.chart import XL_TICK_LABEL_POSITION
8
+ from pptx.util import Inches, Pt
9
+ import os
10
+ import pickle
11
+ from pathlib import Path
12
+ from sklearn.metrics import (
13
+ mean_absolute_error,
14
+ r2_score,
15
+ mean_absolute_percentage_error,
16
+ )
17
+ import streamlit as st
18
+ from collections import OrderedDict
19
+ from utilities import get_metrics_names, initialize_data, retrieve_pkl_object_without_warning
20
+ from io import BytesIO
21
+ from pptx.dml.color import RGBColor
22
+ from post_gres_cred import db_cred
23
+ schema = db_cred['schema']
24
+
25
+ from constants import (
26
+ TITLE_FONT_SIZE,
27
+ AXIS_LABEL_FONT_SIZE,
28
+ CHART_TITLE_FONT_SIZE,
29
+ AXIS_TITLE_FONT_SIZE,
30
+ DATA_LABEL_FONT_SIZE,
31
+ LEGEND_FONT_SIZE,
32
+ PIE_LEGEND_FONT_SIZE
33
+ )
34
+
35
+
36
+ def format_response_metric(target):
37
+ if target.startswith('response_metric_'):
38
+ target = target.replace('response_metric_', '')
39
+ target = target.replace("_", " ").title()
40
+ return target
41
+
42
+
43
+ def smape(actual, forecast):
44
+ # Symmetric MAPE (SMAPE) avoids two shortcomings of MAPE:
45
+ ## 1. MAPE blows up when the actual value is close to 0
46
+ ## 2. MAPE penalizes over-forecasts more heavily than under-forecasts, so it favours models that under-forecast
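+ ## Computed below as (1/n) * sum(|F - A| / (|A| + |F|)), which is bounded between 0 and 1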
47
+ return (1 / len(actual)) * np.sum(1 * np.abs(forecast - actual) / (np.abs(actual) + np.abs(forecast)))
48
+
49
+
50
+ def safe_num_to_per(num):
51
+ try:
52
+ return "{:.0%}".format(num)
53
+ except:
54
+ return num
55
+
56
+
57
+ # Function to convert numbers to abbreviated format
58
+ def convert_number_to_abbreviation(number):
59
+ try:
60
+ number = float(number)
61
+ if number >= 1000000:
62
+ return f'{number / 1000000:.1f} M'
63
+ elif number >= 1000:
64
+ return f'{number / 1000:.1f} K'
65
+ else:
66
+ return str(number)
67
+ except:
68
+ return number
69
+
70
+
71
+ def round_off(x, round_off_decimal=0):
72
+ # round off
73
+ try:
74
+ x = float(x)
75
+ if x < 1 and x > 0:
76
+ round_off_decimal = int(np.floor(np.abs(np.log10(x)))) + max(round_off_decimal, 1)
77
+ x = np.round(x, round_off_decimal)
78
+ elif x < 0 and x > -1:
79
+ round_off_decimal = int(np.floor(np.abs(np.log10(np.abs(x))))) + max(round_off_decimal, 1)
80
+ x = np.round(x, round_off_decimal)
81
+ else:
82
+ x = np.round(x, round_off_decimal)
83
+ return x
84
+ except:
85
+ return x
86
+
87
+
88
+ def fill_table_placeholder(table_placeholder, slide, df, column_width=None, table_height=None):
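+ # Replace a slide placeholder with a pptx table (one header row plus one row per DataFrame record), then remove the placeholder shape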
89
+ cols = len(df.columns)
90
+ rows = len(df)
91
+
92
+ if table_height is None:
93
+ table_height = table_placeholder.height
94
+
95
+ x, y, cx, cy = table_placeholder.left, table_placeholder.top, table_placeholder.width, table_height
96
+ table = slide.shapes.add_table(rows + 1, cols, x, y, cx, cy).table
97
+
98
+ # Populate the table with data from the DataFrame
99
+ for row_idx, row in enumerate(df.values):
100
+ for col_idx, value in enumerate(row):
101
+ cell = table.cell(row_idx + 1, col_idx)
102
+ cell.text = str(value)
103
+ for col_idx, value in enumerate(df.columns):
104
+ cell = table.cell(0, col_idx)
105
+ cell.text = str(value)
106
+
107
+ if column_width is not None:
108
+ for col_idx, column_width in column_width.items():
109
+ table.columns[col_idx].width = Inches(column_width)
110
+
111
+ table_placeholder._element.getparent().remove(table_placeholder._element)
112
+
113
+
114
+ def bar_chart(chart_placeholder, slide, chart_data, titles={}, min_y=None, max_y=None, type='V', legend=True,
115
+ label_type=None, xaxis_pos=None):
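+ # Render a clustered bar chart ('V' = vertical columns, 'H' = horizontal bars) at the placeholder's position and remove the placeholder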
116
+ x, y, cx, cy = chart_placeholder.left, chart_placeholder.top, chart_placeholder.width, chart_placeholder.height
117
+ if type == 'V':
118
+ graphic_frame = slide.shapes.add_chart(
119
+ XL_CHART_TYPE.COLUMN_CLUSTERED, x, y, cx, cy, chart_data
120
+ )
121
+ if type == 'H':
122
+ graphic_frame = slide.shapes.add_chart(
123
+ XL_CHART_TYPE.BAR_CLUSTERED, x, y, cx, cy, chart_data
124
+ )
125
+ chart = graphic_frame.chart
126
+
127
+ category_axis = chart.category_axis
128
+ value_axis = chart.value_axis
129
+
130
+ # Add chart title
131
+ if 'chart_title' in titles.keys():
132
+ chart.has_title = True
133
+ chart.chart_title.text_frame.text = titles['chart_title']
134
+ chart_title = chart.chart_title.text_frame.paragraphs[0].runs[0]
135
+ chart_title.font.size = Pt(CHART_TITLE_FONT_SIZE)
136
+
137
+ # Add axis titles
138
+ if 'x_axis' in titles.keys():
139
+ category_axis.has_title = True
140
+ category_axis.axis_title.text_frame.text = titles['x_axis']
141
+ category_title = category_axis.axis_title.text_frame.paragraphs[0].runs[0]
142
+ category_title.font.size = Pt(AXIS_TITLE_FONT_SIZE)
143
+
144
+ if 'y_axis' in titles.keys():
145
+ value_axis.has_title = True
146
+ value_axis.axis_title.text_frame.text = titles['y_axis']
147
+ value_title = value_axis.axis_title.text_frame.paragraphs[0].runs[0]
148
+ value_title.font.size = Pt(AXIS_TITLE_FONT_SIZE)
149
+
150
+ if xaxis_pos == 'low':
151
+ category_axis.tick_label_position = XL_TICK_LABEL_POSITION.LOW
152
+
153
+ # Customize the chart
154
+ if legend:
155
+ chart.has_legend = True
156
+ chart.legend.position = XL_LEGEND_POSITION.BOTTOM
157
+ chart.legend.font.size = Pt(LEGEND_FONT_SIZE)
158
+ chart.legend.include_in_layout = False
159
+
160
+ # Adjust font size for axis labels
161
+ category_axis.tick_labels.font.size = Pt(AXIS_LABEL_FONT_SIZE)
162
+ value_axis.tick_labels.font.size = Pt(AXIS_LABEL_FONT_SIZE)
163
+
164
+ if min_y is not None:
165
+ value_axis.minimum_scale = min_y # Adjust this value as needed
166
+
167
+ if max_y is not None:
168
+ value_axis.maximum_scale = max_y # Adjust this value as needed
169
+
170
+ plot = chart.plots[0]
171
+ plot.has_data_labels = True
172
+ data_labels = plot.data_labels
173
+
174
+ if label_type == 'per':
175
+ data_labels.number_format = '0"%"'
176
+ elif label_type == '$':
177
+ data_labels.number_format = '$[>=1000000]#,##0.0,,"M";$[>=1000]#,##0.0,"K";$#,##0'
178
+ elif label_type == '$1':
179
+ data_labels.number_format = '$[>=1000000]#,##0,,"M";$[>=1000]#,##0,"K";$#,##0'
180
+ elif label_type == 'M':
181
+ data_labels.number_format = '#0.0,,"M"'
182
+ elif label_type == 'M1':
183
+ data_labels.number_format = '#0.00,,"M"'
184
+ elif label_type == 'K':
185
+ data_labels.number_format = '#0.0,"K"'
186
+
187
+ data_labels.font.size = Pt(DATA_LABEL_FONT_SIZE)
188
+
189
+ chart_placeholder._element.getparent().remove(chart_placeholder._element)
190
+
191
+
192
+ def line_chart(chart_placeholder, slide, chart_data, titles={}, min_y=None, max_y=None):
193
+ # Add the chart to the slide
194
+ x, y, cx, cy = chart_placeholder.left, chart_placeholder.top, chart_placeholder.width, chart_placeholder.height
195
+
196
+ chart = slide.shapes.add_chart(
197
+ XL_CHART_TYPE.LINE, x, y, cx, cy, chart_data
198
+ ).chart
199
+
200
+ chart.has_legend = True
201
+ chart.legend.position = XL_LEGEND_POSITION.BOTTOM
202
+ chart.legend.font.size = Pt(LEGEND_FONT_SIZE)
203
+
204
+ category_axis = chart.category_axis
205
+ value_axis = chart.value_axis
206
+
207
+ if min_y is not None:
208
+ value_axis.minimum_scale = min_y
209
+
210
+ if max_y is not None:
211
+ value_axis.maximum_scale = max_y
212
+
213
+ if min_y is not None and max_y is not None:
214
+ value_axis.major_unit = int((max_y - min_y) / 2)
215
+
216
+ if 'chart_title' in titles.keys():
217
+ chart.has_title = True
218
+ chart.chart_title.text_frame.text = titles['chart_title']
219
+ chart_title = chart.chart_title.text_frame.paragraphs[0].runs[0]
220
+ chart_title.font.size = Pt(CHART_TITLE_FONT_SIZE)
221
+
222
+ if 'x_axis' in titles.keys():
223
+ category_axis.has_title = True
224
+ category_axis.axis_title.text_frame.text = titles['x_axis']
225
+ category_title = category_axis.axis_title.text_frame.paragraphs[0].runs[0]
226
+ category_title.font.size = Pt(AXIS_TITLE_FONT_SIZE)
227
+
228
+ if 'y_axis' in titles.keys():
229
+ value_axis.has_title = True
230
+ value_axis.axis_title.text_frame.text = titles['y_axis']
231
+ value_title = value_axis.axis_title.text_frame.paragraphs[0].runs[0]
232
+ value_title.font.size = Pt(AXIS_TITLE_FONT_SIZE)
233
+
234
+ # Adjust font size for axis labels
235
+ category_axis.tick_labels.font.size = Pt(AXIS_LABEL_FONT_SIZE)
236
+ value_axis.tick_labels.font.size = Pt(AXIS_LABEL_FONT_SIZE)
237
+
238
+ plot = chart.plots[0]
239
+ series = plot.series[1]
240
+ line = series.format.line
241
+ line.color.rgb = RGBColor(141, 47, 0)
242
+
243
+ chart_placeholder._element.getparent().remove(chart_placeholder._element)
244
+
245
+
246
+ def pie_chart(chart_placeholder, slide, chart_data, title):
247
+ # Add the chart to the slide
248
+ x, y, cx, cy = chart_placeholder.left, chart_placeholder.top, chart_placeholder.width, chart_placeholder.height
249
+
250
+ chart = slide.shapes.add_chart(
251
+ XL_CHART_TYPE.PIE, x, y, cx, cy, chart_data
252
+ ).chart
253
+
254
+ chart.has_legend = True
255
+ chart.legend.position = XL_LEGEND_POSITION.RIGHT
256
+ chart.legend.include_in_layout = False
257
+ chart.legend.font.size = Pt(PIE_LEGEND_FONT_SIZE)
258
+
259
+ chart.plots[0].has_data_labels = True
260
+ data_labels = chart.plots[0].data_labels
261
+ data_labels.number_format = '0%'
262
+ data_labels.position = XL_LABEL_POSITION.OUTSIDE_END
263
+ data_labels.font.size = Pt(DATA_LABEL_FONT_SIZE)
264
+
265
+ chart.has_title = True
266
+ chart.chart_title.text_frame.text = title
267
+ chart_title = chart.chart_title.text_frame.paragraphs[0].runs[0]
268
+ chart_title.font.size = Pt(CHART_TITLE_FONT_SIZE)
269
+
270
+ chart_placeholder._element.getparent().remove(chart_placeholder._element)
271
+
272
+
273
+ def title_and_table(slide, title, df, column_width=None, custom_table_height=False):
274
+ placeholders = slide.placeholders
275
+ ph_idx = [ph.placeholder_format.idx for ph in placeholders]
276
+ title_ph = slide.placeholders[ph_idx[0]]
277
+ title_ph.text = title
278
+ title_ph.text_frame.paragraphs[0].font.size = Pt(TITLE_FONT_SIZE)
279
+
280
+ table_placeholder = slide.placeholders[ph_idx[1]]
281
+
282
+ table_height = None
283
+ if custom_table_height:
284
+ if len(df) < 4:
285
+ table_height = int(np.ceil(table_placeholder.height / 2))
286
+
287
+ fill_table_placeholder(table_placeholder, slide, df, column_width, table_height)
288
+
289
+ # try:
290
+ # font_size = 18 # default for 3*3
291
+ # if cols < 3:
292
+ # row_diff = 3 - rows
293
+ # font_size = font_size + ((row_diff)*2) # 1 row less -> 2 pt font size increase & vice versa
294
+ # else:
295
+ # row_diff = 2 - rows
296
+ # font_size = font_size + ((row_diff)*2)
297
+ # for row in table.rows:
298
+ # for cell in row.cells:
299
+ # cell.text_frame.paragraphs[0].runs[0].font.size = Pt(font_size)
300
+ # except Exception as e :
301
+ # print("**"*30)
302
+ # print(e)
303
+ # else:
304
+ # except Exception as e:
305
+ # print('table', e)
306
+ return slide
307
+
308
+
309
+ def data_import(data, bin_dict):
310
+ import_df = pd.DataFrame(columns=['Category', 'Value'])
311
+
312
+ import_df.at[0, 'Category'] = 'Date Range'
313
+
314
+ date_start = data['date'].min().date()
315
+ date_end = data['date'].max().date()
316
+ import_df.at[0, 'Value'] = str(date_start) + ' - ' + str(date_end)
317
+
318
+ import_df.at[1, 'Category'] = 'Response Metrics'
319
+ import_df.at[1, 'Value'] = ', '.join(bin_dict['Response Metrics'])
320
+
321
+ import_df.at[2, 'Category'] = 'Media Variables'
322
+ import_df.at[2, 'Value'] = ', '.join(bin_dict['Media'])
323
+
324
+ import_df.at[3, 'Category'] = 'Spend Variables'
325
+ import_df.at[3, 'Value'] = ', '.join(bin_dict['Spends'])
326
+
327
+ if bin_dict['Exogenous'] != []:
328
+ import_df.at[4, 'Category'] = 'Exogenous Variables'
329
+ import_df.at[4, 'Value'] = ', '.join(bin_dict['Exogenous'])
330
+
331
+ return import_df
332
+
333
+
334
+ def channel_groups_df(channel_groups_dct={}, bin_dict={}):
335
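+ # One row per channel group; media and spend variables are found by intersecting the channel's variable list with the Media and Spends bins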
+ df = pd.DataFrame(columns=['Channel', 'Media Variables', 'Spend Variables'])
336
+ i = 0
337
+ for channel, vars in channel_groups_dct.items():
338
+ media_vars = ", ".join(list(set(vars).intersection(set(bin_dict["Media"]))))
339
+ spend_vars = ", ".join(list(set(vars).intersection(set(bin_dict["Spends"]))))
340
+ df.at[i, "Channel"] = channel
341
+ df.at[i, 'Media Variables'] = media_vars
342
+ df.at[i, 'Spend Variables'] = spend_vars
343
+ i += 1
344
+
345
+ return df
346
+
347
+
348
+ def transformations(transform_dict):
349
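+ # Flatten the transformations applied per category (Media / Exogenous) into rows; each transformation value is stored as a two-element range rendered as 'lower - upper'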
+ transform_df = pd.DataFrame(columns=['Category', 'Transformation', 'Value'])
350
+ i = 0
351
+
352
+ for category in ['Media', 'Exogenous']:
353
+ transformations = f'transformation_{category}'
354
+ category_dict = transform_dict[category]
355
+ if transformations in category_dict.keys():
356
+ for transformation in category_dict[transformations]:
357
+ transform_df.at[i, 'Category'] = category
358
+ transform_df.at[i, 'Transformation'] = transformation
359
+ transform_df.at[i, 'Value'] = str(category_dict[transformation][0]) + ' - ' + str(
360
+ category_dict[transformation][1])
361
+ i += 1
362
+ return transform_df
363
+
364
+
365
+ def model_metrics(model_dict, is_panel):
366
+ metrics_df = pd.DataFrame(
367
+ columns=[
368
+ "Response Metric",
369
+ "Model",
370
+ "R2",
371
+ "ADJR2",
372
+ "Train MAPE",
373
+ "Test MAPE"
374
+ ]
375
+ )
376
+ i = 0
377
+ for key in model_dict.keys():
378
+ target = key.split("__")[1]
379
+ metrics_df.at[i, "Response Metric"] = format_response_metric(target)
380
+ metrics_df.at[i, "Model"] = key.split("__")[0]
381
+
382
+ y = model_dict[key]["X_train_tuned"][target]
383
+
384
+ feature_set = model_dict[key]["feature_set"]
385
+
386
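+ # NOTE: the panel branch assumes media_data and panel_col are available from the enclosing module scope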
+ if is_panel:
387
+ random_df = get_random_effects(
388
+ media_data, panel_col, model_dict[key]["Model_object"]
389
+ )
390
+ pred = mdf_predict(
391
+ model_dict[key]["X_train_tuned"],
392
+ model_dict[key]["Model_object"],
393
+ random_df,
394
+ )["pred"]
395
+ else:
396
+ pred = model_dict[key]["Model_object"].predict(model_dict[key]["X_train_tuned"][feature_set])
397
+
398
+ ytest = model_dict[key]["X_test_tuned"][target]
399
+ if is_panel:
400
+
401
+ predtest = mdf_predict(
402
+ model_dict[key]["X_test_tuned"],
403
+ model_dict[key]["Model_object"],
404
+ random_df,
405
+ )["pred"]
406
+
407
+ else:
408
+ predtest = model_dict[key]["Model_object"].predict(model_dict[key]["X_test_tuned"][feature_set])
409
+
410
+ metrics_df.at[i, "R2"] = np.round(r2_score(y, pred), 2)
411
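+ # Adjusted R2 = 1 - (1 - R2) * (n - 1) / (n - k - 1), where n is the number of observations and k the number of features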
+ adjr2 = 1 - (1 - metrics_df.loc[i, "R2"]) * (
412
+ len(y) - 1
413
+ ) / (len(y) - len(model_dict[key]["feature_set"]) - 1)
414
+ metrics_df.at[i, "ADJR2"] = np.round(adjr2, 2)
415
+ # y = np.where(np.abs(y) < 0.00001, 0.00001, y)
416
+ metrics_df.at[i, "Train MAPE"] = np.round(smape(y, pred), 2)
417
+ metrics_df.at[i, "Test MAPE"] = np.round(smape(ytest, predtest), 2)
418
+ i += 1
419
+ metrics_df = np.round(metrics_df, 2)
420
+
421
+ return metrics_df
422
+
423
+
424
+ def model_result(slide, model_key, model_dict, model_metrics_df, date_col):
425
+ placeholders = slide.placeholders
426
+ ph_idx = [ph.placeholder_format.idx for ph in placeholders]
427
+ title_ph = slide.placeholders[ph_idx[0]]
428
+ title_ph.text = model_key.split('__')[0]
429
+ title_ph.text_frame.paragraphs[0].font.size = Pt(TITLE_FONT_SIZE)
430
+ target = model_key.split('__')[1]
431
+
432
+ metrics_table_placeholder = slide.placeholders[ph_idx[1]]
433
+ metrics_df = model_metrics_df[model_metrics_df['Model'] == model_key.split('__')[0]].reset_index(drop=True)
434
+
435
+ # Accuracy is reported as a percentage: 100 * (1 - Train MAPE)
436
+ metrics_df['Accuracy'] = 100 * (1 - metrics_df['Train MAPE'])
437
+ metrics_df['Accuracy'] = metrics_df['Accuracy'].apply(lambda x: f'{np.round(x, 0)}%')
438
+
439
+ ## Removing metrics as requested by Ioannis
440
+
441
+ metrics_df = metrics_df.drop(columns=['R2', 'ADJR2', 'Train MAPE', 'Test MAPE'])
442
+ fill_table_placeholder(metrics_table_placeholder, slide, metrics_df)
443
+
444
+ # coeff_table_placeholder = slide.placeholders[ph_idx[2]]
445
+ # coeff_df = pd.DataFrame(model_dict['Model_object'].params)
446
+ # coeff_df.reset_index(inplace=True)
447
+ # coeff_df.columns = ['Feature', 'Coefficent']
448
+ # fill_table_placeholder(coeff_table_placeholder, slide, coeff_df)
449
+
450
+ chart_placeholder = slide.placeholders[ph_idx[2]]
451
+ full_df = pd.concat([model_dict['X_train_tuned'], model_dict['X_test_tuned']])
452
+ full_df['Predicted'] = model_dict['Model_object'].predict(full_df[model_dict['feature_set']])
453
+ pred_df = full_df[[date_col, target, 'Predicted']]
454
+ pred_df.rename(columns={target: 'Actual'}, inplace=True)
455
+
456
+ # Create chart data
457
+ chart_data = CategoryChartData()
458
+ chart_data.categories = pred_df[date_col]
459
+ chart_data.add_series('Actual', pred_df['Actual'])
460
+ chart_data.add_series('Predicted', pred_df['Predicted'])
461
+
462
+ # Set range for y axis
463
+ min_y = np.floor(min(pred_df['Actual'].min(), pred_df['Predicted'].min()))
464
+ max_y = np.ceil(max(pred_df['Actual'].max(), pred_df['Predicted'].max()))
465
+
466
+ # Create the chart
467
+ line_chart(chart_placeholder=chart_placeholder,
468
+ slide=slide,
469
+ chart_data=chart_data,
470
+ titles={'chart_title': 'Actual VS Predicted',
471
+ 'x_axis': 'Date',
472
+ 'y_axis': target.title().replace('_', ' ')
473
+ },
474
+ min_y=min_y,
475
+ max_y=max_y
476
+ )
477
+
478
+ return slide
479
+
480
+
481
+ def metrics_contributions(slide, contributions_excels_dict, panel_col):
482
+ # Create data for metrics contributions
483
+ all_contribution_df = pd.DataFrame(columns=['Channel'])
484
+ target_sum_dict = {}
485
+ sort_support_dct = {}
486
+ for target in contributions_excels_dict.keys():
487
+ contribution_df = contributions_excels_dict[target]['CONTRIBUTION MMM'].copy()
488
+ if 'Date' in contribution_df.columns:
489
+ contribution_df.drop(columns=['Date'], inplace=True)
490
+ if panel_col in contribution_df.columns:
491
+ contribution_df.drop(columns=[panel_col], inplace=True)
492
+
493
+ contribution_df = pd.DataFrame(np.sum(contribution_df, axis=0)).reset_index()
494
+ contribution_df.columns = ['Channel', target]
495
+ target_sum = contribution_df[target].sum()
496
+ target_sum_dict[target] = target_sum
497
+ contribution_df[target] = 100 * contribution_df[target] / target_sum
498
+
499
+ all_contribution_df = pd.merge(all_contribution_df, contribution_df, on='Channel', how='outer')
500
+
501
+ sorted_target_sum_dict = sorted(target_sum_dict.items(), key=lambda kv: kv[1], reverse=True)
502
+ sorted_target_sum_keys = [kv[0] for kv in sorted_target_sum_dict]
503
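+ # If a metric named 'revenue' is present, move it to the end so it is treated as the bottom-funnel metric used for sorting below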
+ if len([metric for metric in sorted_target_sum_keys if metric.lower() == 'revenue']) == 1:
504
+ rev_metric = [metric for metric in sorted_target_sum_keys if metric.lower() == 'revenue'][0]
505
+ sorted_target_sum_keys.remove(rev_metric)
506
+ sorted_target_sum_keys.append(rev_metric)
507
+ all_contribution_df = all_contribution_df[['Channel'] + sorted_target_sum_keys]
508
+
509
+ # for col in all_contribution_df.columns:
510
+ # all_contribution_df[col]=all_contribution_df[col].apply(lambda x: round_off(x,1))
511
+
512
+ # Sort Data by Average contribution of the channels keeping base first <Removed>
513
+ # all_contribution_df['avg'] = np.mean(all_contribution_df[list(contributions_excels_dict.keys())],axis=1)
514
+ # all_contribution_df['rank'] = all_contribution_df['avg'].rank(ascending=False)
515
+
516
+ # Sort data by contribution of bottom funnel metric
517
+ bottom_funnel_metric = sorted_target_sum_keys[-1]
518
+ all_contribution_df['rank'] = all_contribution_df[bottom_funnel_metric].rank(ascending=False)
519
+ all_contribution_df.loc[all_contribution_df[all_contribution_df['Channel'] == 'base'].index, 'rank'] = 0
520
+ all_contribution_df = all_contribution_df.sort_values(by='rank')
521
+ all_contribution_df.drop(columns=['rank'], inplace=True)
522
+
523
+ # Add title
524
+ placeholders = slide.placeholders
525
+ ph_idx = [ph.placeholder_format.idx for ph in placeholders]
526
+ title_ph = slide.placeholders[ph_idx[0]]
527
+ title_ph.text = "Response Metrics Contributions"
528
+ title_ph.text_frame.paragraphs[0].font.size = Pt(TITLE_FONT_SIZE)
529
+
530
+ for target in contributions_excels_dict.keys():
531
+ all_contribution_df[target] = all_contribution_df[target].astype(float)
532
+
533
+
534
+ # Create chart data
535
+ chart_data = CategoryChartData()
536
+ chart_data.categories = all_contribution_df['Channel']
537
+ for target in sorted_target_sum_keys:
538
+ chart_data.add_series(format_response_metric(target), all_contribution_df[target])
539
+ chart_placeholder = slide.placeholders[ph_idx[1]]
540
+
541
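+ # np.min / np.max over the numeric columns may come back as a scalar or as a per-column Series; the two branches below pick the matching accessor (.values[0] for the Series case)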
+ if isinstance(np.min(all_contribution_df.select_dtypes(exclude=['object', 'datetime'])), float):
542
+
543
+ # Add the chart to the slide
544
+ bar_chart(chart_placeholder=chart_placeholder,
545
+ slide=slide,
546
+ chart_data=chart_data,
547
+ titles={'chart_title': 'Response Metrics Contributions',
548
+ # 'x_axis':'Channels',
549
+ 'y_axis': 'Contributions'},
550
+ min_y=np.floor(np.min(all_contribution_df.select_dtypes(exclude=['object', 'datetime']))),
551
+ max_y=np.ceil(np.max(all_contribution_df.select_dtypes(exclude=['object', 'datetime']))),
552
+ type='V',
553
+ label_type='per'
554
+ )
555
+ else:
556
+
557
+ bar_chart(chart_placeholder=chart_placeholder,
558
+ slide=slide,
559
+ chart_data=chart_data,
560
+ titles={'chart_title': 'Response Metrics Contributions',
561
+ # 'x_axis':'Channels',
562
+ 'y_axis': 'Contributions'},
563
+ min_y=np.floor(np.min(all_contribution_df.select_dtypes(exclude=['object', 'datetime'])).values[0]),
564
+ max_y=np.ceil(np.max(all_contribution_df.select_dtypes(exclude=['object', 'datetime'])).values[0]),
565
+ type='V',
566
+ label_type='per'
567
+ )
568
+
569
+ return slide
570
+
571
+
572
+ def model_media_performance(slide, target, contributions_excels_dict, date_col='Date', is_panel=False,
573
+ panel_col='panel'):
574
+ # Add title
575
+ placeholders = slide.placeholders
576
+ ph_idx = [ph.placeholder_format.idx for ph in placeholders]
577
+ title_ph = slide.placeholders[ph_idx[0]]
578
+ title_ph.text = "Media Performance - " + target.title().replace("_", " ")
579
+ title_ph.text_frame.paragraphs[0].font.size = Pt(TITLE_FONT_SIZE)
580
+
581
+ # CONTRIBUTION CHART
582
+ # Create contribution data
583
+ contribution_df = contributions_excels_dict[target]['CONTRIBUTION MMM']
584
+ if panel_col in contribution_df.columns:
585
+ contribution_df.drop(columns=[panel_col], inplace=True)
586
+ # contribution_df.drop(columns=[date_col], inplace=True)
587
+ contribution_df = pd.DataFrame(np.sum(contribution_df, axis=0)).reset_index()
588
+ contribution_df.columns = ['Channel', format_response_metric(target)]
589
+ contribution_df['Channel'] = contribution_df['Channel'].apply(lambda x: x.title())
590
+ target_sum = contribution_df[format_response_metric(target)].sum()
591
+ contribution_df[format_response_metric(target)] = contribution_df[format_response_metric(target)] / target_sum
592
+ contribution_df.sort_values(by=['Channel'], ascending=False, inplace=True)
593
+
594
+ # for col in contribution_df.columns:
595
+ # contribution_df[col] = contribution_df[col].apply(lambda x : round_off(x))
596
+
597
+ # Create Chart Data
598
+ chart_data = ChartData()
599
+ chart_data.categories = contribution_df['Channel']
600
+ chart_data.add_series('Contribution', contribution_df[format_response_metric(target)])
601
+
602
+ chart_placeholder = slide.placeholders[ph_idx[2]]
603
+ pie_chart(chart_placeholder=chart_placeholder,
604
+ slide=slide,
605
+ chart_data=chart_data,
606
+ title='Contribution')
607
+
608
+ # SPENDS CHART
609
+
610
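+ # Load the aggregated scenario for this response metric from session state and compute each channel's actual spend in spend currency (spends * conversion rate)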
+ initialize_data(panel='aggregated', metrics=target)
611
+ scenario = st.session_state["scenario"]
612
+ spends_values = {
613
+ channel_name: round(
614
+ scenario.channels[channel_name].actual_total_spends
615
+ * scenario.channels[channel_name].conversion_rate,
616
+ 1,
617
+ )
618
+ for channel_name in st.session_state["channels_list"]
619
+ }
620
+ spends_df = pd.DataFrame(columns=['Channel', 'Media Spend'])
621
+ spends_df['Channel'] = list(spends_values.keys())
622
+ spends_df['Media Spend'] = list(spends_values.values())
623
+ spends_sum = spends_df['Media Spend'].sum()
624
+ spends_df['Media Spend'] = spends_df['Media Spend'] / spends_sum
625
+ spends_df['Channel'] = spends_df['Channel'].apply(lambda x: x.title())
626
+ spends_df.sort_values(by='Channel', ascending=False, inplace=True)
627
+ # for col in spends_df.columns:
628
+ # spends_df[col] = spends_df[col].apply(lambda x : round_off(x))
629
+
630
+ # Create Chart Data
631
+ spends_chart_data = ChartData()
632
+ spends_chart_data.categories = spends_df['Channel']
634
+ spends_chart_data.add_series('Media Spend', spends_df['Media Spend'])
635
+
636
+ spends_chart_placeholder = slide.placeholders[ph_idx[1]]
637
+ pie_chart(chart_placeholder=spends_chart_placeholder,
638
+ slide=slide,
639
+ chart_data=spends_chart_data,
640
+ title='Media Spend')
641
+ # spends_values.append(0)
642
+ return contribution_df, spends_df
643
+
644
+
645
+ # def get_saved_scenarios_dict(project_path):
646
+ # # Path to the saved scenarios file
647
+ # saved_scenarios_dict_path = os.path.join(
648
+ # project_path, "saved_scenarios.pkl"
649
+ # )
650
+ #
651
+ # # Load existing scenarios if the file exists
652
+ # if os.path.exists(saved_scenarios_dict_path):
653
+ # with open(saved_scenarios_dict_path, "rb") as f:
654
+ # saved_scenarios_dict = pickle.load(f)
655
+ # else:
656
+ # saved_scenarios_dict = OrderedDict()
657
+ #
658
+ # return saved_scenarios_dict
659
+
660
+ def optimization_summary(slide, scenario, scenario_name):
661
+ placeholders = slide.placeholders
662
+ ph_idx = [ph.placeholder_format.idx for ph in placeholders]
663
+ title_ph = slide.placeholders[ph_idx[0]]
664
+ title_ph.text = 'Optimization Summary' # + ' (Scenario: ' + scenario_name + ')'
665
+ title_ph.text_frame.paragraphs[0].font.size = Pt(TITLE_FONT_SIZE)
666
+
667
+ multiplier = 1 / float(scenario['multiplier'])
668
+ # st.write(scenario['multiplier'], multiplier)
669
+ ## Multiplier is an indicator of the selected time frame
670
+ ## It does not affect CPA
671
+
672
+ opt_on = scenario['optimization']
673
+ if opt_on.lower() == 'spends':
674
+ opt_on = 'Media Spend'
675
+
676
+ details_ph = slide.placeholders[ph_idx[3]]
677
+ details_ph.text = 'Scenario Name: ' + scenario_name + \
678
+ '\nResponse Metric: ' + str(scenario['metrics_selected']).replace("_", " ").title() + \
679
+ '\nOptimized on: ' + str(opt_on).replace("_", " ").title()
680
+
681
+ scenario_df = pd.DataFrame(columns=['Category', 'Actual', 'Simulated', 'Change'])
682
+ scenario_df.at[0, 'Category'] = 'Media Spend'
683
+
684
+ scenario_df.at[0, 'Actual'] = scenario['actual_total_spends'] * multiplier
685
+ scenario_df.at[0, 'Simulated'] = scenario['modified_total_spends'] * multiplier
686
+ scenario_df.at[0, 'Change'] = (scenario['modified_total_spends'] - scenario['actual_total_spends']) * multiplier
687
+
688
+ scenario_df.at[1, 'Category'] = scenario['metrics_selected'].replace("_", " ").title()
689
+ scenario_df.at[1, 'Actual'] = scenario['actual_total_sales'] * multiplier
690
+ scenario_df.at[1, 'Simulated'] = (scenario['modified_total_sales']) * multiplier
691
+ scenario_df.at[1, 'Change'] = (scenario['modified_total_sales'] - scenario['actual_total_sales']) * multiplier
692
+
693
+ scenario_df.at[2, 'Category'] = 'CPA'
694
+ actual_cpa = scenario['actual_total_spends'] / scenario['actual_total_sales']
695
+ modified_cpa = scenario['modified_total_spends'] / scenario['modified_total_sales']
696
+ scenario_df.at[2, 'Actual'] = actual_cpa
697
+ scenario_df.at[2, 'Simulated'] = modified_cpa
698
+ scenario_df.at[2, 'Change'] = modified_cpa - actual_cpa
699
+
700
+ scenario_df.at[3, 'Category'] = 'ROI'
701
+ act_roi = scenario['actual_total_sales'] / scenario['actual_total_spends']
702
+ opt_roi = scenario['modified_total_sales'] / scenario['modified_total_spends']
703
+ scenario_df.at[3, 'Actual'] = act_roi
704
+ scenario_df.at[3, 'Simulated'] = opt_roi
705
+ scenario_df.at[3, 'Change'] = opt_roi - act_roi
706
+
707
+ for col in scenario_df.columns:
708
+ scenario_df[col] = scenario_df[col].apply(lambda x: round_off(x, 1))
709
+ scenario_df[col] = scenario_df[col].apply(lambda x: convert_number_to_abbreviation(x))
710
+
711
+ table_placeholder = slide.placeholders[ph_idx[1]]
712
+ fill_table_placeholder(table_placeholder, slide, scenario_df)
713
+
714
+ channel_spends_df = pd.DataFrame(columns=['Channel', 'Actual Spends', 'Optimized Spends'])
715
+ for i, channel in enumerate(scenario['channels'].values()):
716
+ channel_spends_df.at[i, 'Channel'] = channel['name']
717
+ channel_conversion_rate = channel[
718
+ "conversion_rate"
719
+ ]
720
+ channel_spends_df.at[i, 'Actual Spends'] = (
721
+ channel["actual_total_spends"]
722
+ * channel_conversion_rate
723
+ ) * multiplier
724
+ channel_spends_df.at[i, 'Optimized Spends'] = (
725
+ channel["modified_total_spends"]
726
+ * channel_conversion_rate
727
+ ) * multiplier
728
+ channel_spends_df['Actual Spends'] = channel_spends_df['Actual Spends'].astype('float')
729
+ channel_spends_df['Optimized Spends'] = channel_spends_df['Optimized Spends'].astype('float')
730
+
731
+ for col in channel_spends_df.columns:
732
+ channel_spends_df[col] = channel_spends_df[col].apply(lambda x: round_off(x, 0))
733
+
734
+ # Sort data on Actual Spends
735
+ channel_spends_df.sort_values(by='Actual Spends', inplace=True, ascending=False)
736
+
737
+ # Create chart data
738
+ chart_data = CategoryChartData()
739
+ chart_data.categories = channel_spends_df['Channel']
740
+ for col in ['Actual Spends', 'Optimized Spends']:
741
+ chart_data.add_series(col, channel_spends_df[col])
742
+
743
+ chart_placeholder = slide.placeholders[ph_idx[2]]
744
+
745
+ # Add the chart to the slide
746
+ if isinstance(np.max(channel_spends_df.select_dtypes(exclude=['object', 'datetime'])),float):
747
+ bar_chart(chart_placeholder=chart_placeholder,
748
+ slide=slide,
749
+ chart_data=chart_data,
750
+ titles={'chart_title': 'Channel Wise Spends',
751
+ # 'x_axis':'Channels',
752
+ 'y_axis': 'Spends'},
753
+ # min_y=np.floor(np.min(channel_spends_df.select_dtypes(exclude=['object', 'datetime']))),
754
+ min_y=0,
755
+ max_y=np.ceil(np.max(channel_spends_df.select_dtypes(exclude=['object', 'datetime']))),
756
+ label_type='$'
757
+ )
758
+ else:
759
+ # Add the chart to the slide
760
+ bar_chart(chart_placeholder=chart_placeholder,
761
+ slide=slide,
762
+ chart_data=chart_data,
763
+ titles={'chart_title': 'Channel Wise Spends',
764
+ # 'x_axis':'Channels',
765
+ 'y_axis': 'Spends'},
766
+ # min_y=np.floor(np.min(channel_spends_df.select_dtypes(exclude=['object', 'datetime']))),
767
+ min_y=0,
768
+ max_y=np.ceil(np.max(channel_spends_df.select_dtypes(exclude=['object', 'datetime'])).values[0]),
769
+ label_type='$'
770
+ )
771
+
772
+
773
+ def channel_wise_spends(slide, scenario):
774
+ placeholders = slide.placeholders
775
+ ph_idx = [ph.placeholder_format.idx for ph in placeholders]
776
+ title_ph = slide.placeholders[ph_idx[0]]
777
+ title_ph.text = 'Channel Spends and Impact'
778
+ title_ph.text_frame.paragraphs[0].font.size = Pt(TITLE_FONT_SIZE)
779
+ # print(scenario.keys())
780
+
781
+ multiplier = 1 / float(scenario['multiplier'])
782
+ channel_spends_df = pd.DataFrame(columns=['Channel', 'Actual Spends', 'Optimized Spends'])
783
+ for i, channel in enumerate(scenario['channels'].values()):
784
+ channel_spends_df.at[i, 'Channel'] = channel['name']
785
+ channel_conversion_rate = channel["conversion_rate"]
786
+ channel_spends_df.at[i, 'Actual Spends'] = (channel[
787
+ "actual_total_spends"] * channel_conversion_rate) * multiplier
788
+ channel_spends_df.at[i, 'Optimized Spends'] = (channel[
789
+ "modified_total_spends"] * channel_conversion_rate) * multiplier
790
+ channel_spends_df['Actual Spends'] = channel_spends_df['Actual Spends'].astype('float')
791
+ channel_spends_df['Optimized Spends'] = channel_spends_df['Optimized Spends'].astype('float')
792
+
793
+ actual_sum = channel_spends_df['Actual Spends'].sum()
794
+ opt_sum = channel_spends_df['Optimized Spends'].sum()
795
+
796
+ for col in channel_spends_df.columns:
797
+ channel_spends_df[col] = channel_spends_df[col].apply(lambda x: round_off(x, 0))
798
+
799
+ channel_spends_df['Actual Spends %'] = 100 * (channel_spends_df['Actual Spends'] / actual_sum)
800
+ channel_spends_df['Optimized Spends %'] = 100 * (channel_spends_df['Optimized Spends'] / opt_sum)
801
+ channel_spends_df['Actual Spends %'] = np.round(channel_spends_df['Actual Spends %'])
802
+ channel_spends_df['Optimized Spends %'] = np.round(channel_spends_df['Optimized Spends %'])
803
+
804
+ # Sort Data based on Actual Spends %
805
+ channel_spends_df.sort_values(by='Actual Spends %', inplace=True)
806
+
807
+ # Create chart data
808
+ chart_data = CategoryChartData()
809
+ chart_data.categories = channel_spends_df['Channel']
810
+ for col in ['Actual Spends %', 'Optimized Spends %']:
811
+ # for col in ['Actual Spends %']:
812
+ chart_data.add_series(col, channel_spends_df[col])
813
+ chart_placeholder = slide.placeholders[ph_idx[1]]
814
+
815
+ # Add the chart to the slide
816
+ if isinstance(np.max(channel_spends_df[['Actual Spends %', 'Optimized Spends %']]), float):
817
+ bar_chart(chart_placeholder=chart_placeholder,
818
+ slide=slide,
819
+ chart_data=chart_data,
820
+ titles={'chart_title': 'Spend Split %',
821
+ # 'x_axis':'Channels',
822
+ 'y_axis': 'Spend %'},
823
+ min_y=0,
824
+ max_y=np.ceil(np.max(channel_spends_df[['Actual Spends %', 'Optimized Spends %']])),
825
+ type='H',
826
+ legend=True,
827
+ label_type='per',
828
+ xaxis_pos='low'
829
+ )
830
+ else:
831
+ bar_chart(chart_placeholder=chart_placeholder,
832
+ slide=slide,
833
+ chart_data=chart_data,
834
+ titles={'chart_title': 'Spend Split %',
835
+ # 'x_axis':'Channels',
836
+ 'y_axis': 'Spend %'},
837
+ min_y=0,
838
+ max_y=np.ceil(np.max(channel_spends_df[['Actual Spends %', 'Optimized Spends %']]).values[0]),
839
+ type='H',
840
+ legend=True,
841
+ label_type='per',
842
+ xaxis_pos='low'
843
+ )
844
+ #
845
+ # # Create chart data
846
+ # chart_data_1 = CategoryChartData()
847
+ # chart_data_1.categories = channel_spends_df['Channel']
848
+ # # for col in ['Actual Spends %', 'Optimized Spends %']:
849
+ # for col in ['Optimized Spends %']:
850
+ # chart_data_1.add_series(col, channel_spends_df[col])
851
+ # chart_placeholder_1 = slide.placeholders[ph_idx[3]]
852
+ #
853
+ # # Add the chart to the slide
854
+ # bar_chart(chart_placeholder=chart_placeholder_1,
855
+ # slide=slide,
856
+ # chart_data=chart_data_1,
857
+ # titles={'chart_title': 'Optimized Spends Split %',
858
+ # # 'x_axis':'Channels',
859
+ # 'y_axis': 'Spends %'},
860
+ # min_y=0,
861
+ # max_y=np.ceil(np.max(channel_spends_df[['Actual Spends %', 'Optimized Spends %']])),
862
+ # type='H',
863
+ # legend=False,
864
+ # label_type='per'
865
+ # )
866
+
867
+ channel_spends_df['Delta %'] = 100 * (channel_spends_df['Optimized Spends'] - channel_spends_df['Actual Spends']) / \
868
+ channel_spends_df['Actual Spends']
869
+ channel_spends_df['Delta %'] = channel_spends_df['Delta %'].apply(lambda x: round_off(x, 0))
870
+
871
+ # Create chart data
872
+ delta_chart_data = CategoryChartData()
873
+ delta_chart_data.categories = channel_spends_df['Channel']
874
+ col = 'Delta %'
875
+ delta_chart_data.add_series(col, channel_spends_df[col])
876
+ delta_chart_placeholder = slide.placeholders[ph_idx[3]]
877
+
878
+ # Add the chart to the slide
879
+ if isinstance(np.min(channel_spends_df['Delta %']), float):
880
+ bar_chart(chart_placeholder=delta_chart_placeholder,
881
+ slide=slide,
882
+ chart_data=delta_chart_data,
883
+ titles={'chart_title': 'Spend Delta %',
884
+ 'y_axis': 'Spend Delta %'},
885
+ min_y=np.floor(np.min(channel_spends_df['Delta %'])),
886
+ max_y=np.ceil(np.max(channel_spends_df['Delta %'])),
887
+ type='H',
888
+ legend=False,
889
+ label_type='per',
890
+ xaxis_pos='low'
891
+
892
+ )
893
+ else:
894
+ bar_chart(chart_placeholder=delta_chart_placeholder,
895
+ slide=slide,
896
+ chart_data=delta_chart_data,
897
+ titles={'chart_title': 'Spend Delta %',
898
+ 'y_axis': 'Spend Delta %'},
899
+ min_y=np.floor(np.min(channel_spends_df['Delta %']).values[0]),
900
+ max_y=np.ceil(np.max(channel_spends_df['Delta %']).values[0]),
901
+ type='H',
902
+ legend=False,
903
+ label_type='per',
904
+ xaxis_pos='low'
905
+
906
+ )
907
+
908
+ # Incremental Impact
909
+ channel_inc_df = pd.DataFrame(columns=['Channel', 'Increment'])
910
+ for i, channel in enumerate(scenario['channels'].values()):
911
+ channel_inc_df.at[i, 'Channel'] = channel['name']
912
+ act_impact = channel['actual_total_sales']
913
+ opt_impact = channel['modified_total_sales']
914
+ impact = opt_impact - act_impact
915
+ impact = round_off(impact, 0)
916
+ impact = impact if abs(impact) > 0.0001 else 0
917
+ channel_inc_df.at[i, 'Increment'] = impact
918
+
919
+ channel_inc_df_1 = pd.merge(channel_spends_df, channel_inc_df, how='left', on='Channel')
920
+
921
+ # Create chart data
922
+ delta_chart_data = CategoryChartData()
923
+ delta_chart_data.categories = channel_inc_df_1['Channel']
924
+ col = 'Increment'
925
+ delta_chart_data.add_series(col, channel_inc_df_1[col])
926
+ delta_chart_placeholder = slide.placeholders[ph_idx[2]]
927
+
928
+ label_req = True
929
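+ # Choose a data-label abbreviation from the magnitude of the increments (e.g. 'M' for millions, 'K' for thousands); suppress labels entirely when the values are small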
+ if min(np.abs(channel_inc_df_1[col])) > 100000: # 0.1M
930
+ label_type = 'M'
931
+ elif min(np.abs(channel_inc_df_1[col])) > 10000 and max(np.abs(channel_inc_df_1[col])) > 1000000:
932
+ label_type = 'M1'
933
+ elif min(np.abs(channel_inc_df_1[col])) > 100 and max(np.abs(channel_inc_df_1[col])) > 1000:
934
+ label_type = 'K'
935
+ else:
936
+ label_req = False
937
+ # Add the chart to the slide
938
+ if label_req:
939
+ bar_chart(chart_placeholder=delta_chart_placeholder,
940
+ slide=slide,
941
+ chart_data=delta_chart_data,
942
+ titles={'chart_title': 'Incremental Impact',
943
+ 'y_axis': format_response_metric(scenario['metrics_selected'])},
944
+ # min_y=np.floor(np.min(channel_inc_df_1['Delta %'])),
945
+ # max_y=np.ceil(np.max(channel_inc_df_1['Delta %'])),
946
+ type='H',
947
+ label_type=label_type,
948
+ legend=False,
949
+ xaxis_pos='low'
950
+ )
951
+ else:
952
+ bar_chart(chart_placeholder=delta_chart_placeholder,
953
+ slide=slide,
954
+ chart_data=delta_chart_data,
955
+ titles={'chart_title': 'Increment',
956
+ 'y_axis': scenario['metrics_selected']},
957
+ # min_y=np.floor(np.min(channel_inc_df_1['Delta %'])),
958
+ # max_y=np.ceil(np.max(channel_inc_df_1['Delta %'])),
959
+ type='H',
960
+ legend=False,
961
+ xaxis_pos='low'
962
+ )
963
+
964
+
965
+ def channel_wise_roi(slide, scenario):
966
+ channel_roi_mroi = scenario['channel_roi_mroi']
967
+
968
+ # Add title
969
+ placeholders = slide.placeholders
970
+ ph_idx = [ph.placeholder_format.idx for ph in placeholders]
971
+ title_ph = slide.placeholders[ph_idx[0]]
972
+ title_ph.text = 'Channel ROIs'
973
+ title_ph.text_frame.paragraphs[0].font.size = Pt(TITLE_FONT_SIZE)
974
+
975
+ channel_roi_df = pd.DataFrame(columns=['Channel', 'Actual ROI', 'Optimized ROI'])
976
+ for i, channel in enumerate(channel_roi_mroi.keys()):
977
+ channel_roi_df.at[i, 'Channel'] = channel
978
+ channel_roi_df.at[i, 'Actual ROI'] = channel_roi_mroi[channel]['actual_roi']
979
+ channel_roi_df.at[i, 'Optimized ROI'] = channel_roi_mroi[channel]['optimized_roi']
980
+ channel_roi_df['Actual ROI'] = channel_roi_df['Actual ROI'].astype('float')
981
+ channel_roi_df['Optimized ROI'] = channel_roi_df['Optimized ROI'].astype('float')
982
+
983
+ for col in channel_roi_df.columns:
984
+ channel_roi_df[col] = channel_roi_df[col].apply(lambda x: round_off(x, 2))
985
+
986
+ # Create chart data
987
+ chart_data = CategoryChartData()
988
+ chart_data.categories = channel_roi_df['Channel']
989
+ for col in ['Actual ROI', 'Optimized ROI']:
990
+ chart_data.add_series(col, channel_roi_df[col])
991
+
992
+ chart_placeholder = slide.placeholders[ph_idx[1]]
993
+
994
+ # Add the chart to the slide
995
+ if isinstance(np.max(channel_roi_df.select_dtypes(exclude=['object', 'datetime'])), float):
996
+ bar_chart(chart_placeholder=chart_placeholder,
997
+ slide=slide,
998
+ chart_data=chart_data,
999
+ titles={'chart_title': 'Channel Wise ROI',
1000
+ # 'x_axis':'Channels',
1001
+ 'y_axis': 'ROI'},
1002
+ # min_y=np.floor(np.min(channel_spends_df.select_dtypes(exclude=['object', 'datetime']))),
1003
+ min_y=0,
1004
+ max_y=np.max(channel_roi_df.select_dtypes(exclude=['object', 'datetime']))
1005
+ )
1006
+ else:
1007
+ bar_chart(chart_placeholder=chart_placeholder,
1008
+ slide=slide,
1009
+ chart_data=chart_data,
1010
+ titles={'chart_title': 'Channel Wise ROI',
1011
+ # 'x_axis':'Channels',
1012
+ 'y_axis': 'ROI'},
1013
+ # min_y=np.floor(np.min(channel_spends_df.select_dtypes(exclude=['object', 'datetime']))),
1014
+ min_y=0,
1015
+ max_y=np.max(channel_roi_df.select_dtypes(exclude=['object', 'datetime'])).values[0]
1016
+ )
1017
+ # act_roi = scenario['actual_total_sales']/scenario['actual_total_spends']
1018
+ # opt_roi = scenario['modified_total_sales']/scenario['modified_total_spends']
1019
+ #
1020
+ # act_roi_ph = slide.placeholders[ph_idx[2]]
1021
+ # act_roi_ph.text = 'Actual ROI: ' + str(round_off(act_roi,2))
1022
+ # opt_roi_ph = slide.placeholders[ph_idx[3]]
1023
+ # opt_roi_ph.text = 'Optimized ROI: ' + str(round_off(opt_roi, 2))
1024
+
1025
+ ## Removing mroi chart as per Ioannis' feedback
1026
+ # channel_mroi_df = pd.DataFrame(columns=['Channel', 'Actual mROI', 'Optimized mROI'])
1027
+ # for i, channel in enumerate(channel_roi_mroi.keys()):
1028
+ # channel_mroi_df.at[i, 'Channel'] = channel
1029
+ # channel_mroi_df.at[i, 'Actual mROI'] = channel_roi_mroi[channel]['actual_mroi']
1030
+ # channel_mroi_df.at[i, 'Optimized mROI'] = channel_roi_mroi[channel]['optimized_mroi']
1031
+ # channel_mroi_df['Actual mROI']=channel_mroi_df['Actual mROI'].astype('float')
1032
+ # channel_mroi_df['Optimized mROI']=channel_mroi_df['Optimized mROI'].astype('float')
1033
+ #
1034
+ # for col in channel_mroi_df.columns:
1035
+ # channel_mroi_df[col]=channel_mroi_df[col].apply(lambda x: round_off(x))
1036
+ #
1037
+ # # Create chart data
1038
+ # mroi_chart_data = CategoryChartData()
1039
+ # mroi_chart_data.categories = channel_mroi_df['Channel']
1040
+ # for col in ['Actual mROI', 'Optimized mROI']:
1041
+ # mroi_chart_data.add_series(col, channel_mroi_df[col])
1042
+ #
1043
+ # mroi_chart_placeholder=slide.placeholders[ph_idx[2]]
1044
+ #
1045
+ # # Add the chart to the slide
1046
+ # bar_chart(chart_placeholder=mroi_chart_placeholder,
1047
+ # slide=slide,
1048
+ # chart_data=mroi_chart_data,
1049
+ # titles={'chart_title':'Channel Wise mROI',
1050
+ # # 'x_axis':'Channels',
1051
+ # 'y_axis':'mROI'},
1052
+ # # min_y=np.floor(np.min(channel_mroi_df.select_dtypes(exclude=['object', 'datetime']))),
1053
+ # min_y=0,
1054
+ # max_y=np.ceil(np.max(channel_mroi_df.select_dtypes(exclude=['object', 'datetime'])))
1055
+ # )
1056
+
1057
+
1058
+ def effictiveness_efficiency(slide, final_data, bin_dct, scenario):
1059
+ # Add title
1060
+ placeholders = slide.placeholders
1061
+ ph_idx = [ph.placeholder_format.idx for ph in placeholders]
1062
+ title_ph = slide.placeholders[ph_idx[0]]
1063
+ title_ph.text = 'Effectiveness and Efficiency'
1064
+ title_ph.text_frame.paragraphs[0].font.size = Pt(TITLE_FONT_SIZE)
1065
+
1066
+ response_metrics = bin_dct['Response Metrics']
1067
+
1068
+ kpi_df = final_data[response_metrics].sum(axis=0).reset_index()
1069
+ kpi_df.columns = ['Response Metric', 'Effectiveness']
1070
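+ # Efficiency = total of the response metric divided by the scenario's total (modified) media spend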
+ kpi_df['Efficiency'] = kpi_df['Effectiveness'] / scenario['modified_total_spends']
1071
+ kpi_df['Efficiency'] = kpi_df['Efficiency'].apply(lambda x: round_off(x, 1))
1072
+ kpi_df.sort_values(by='Effectiveness', inplace=True)
1073
+ kpi_df['Response Metric'] = kpi_df['Response Metric'].apply(lambda x: format_response_metric(x))
1074
+
1075
+ # Create chart data for effectiveness
1076
+ chart_data = CategoryChartData()
1077
+ chart_data.categories = kpi_df['Response Metric']
1078
+ chart_data.add_series('Effectiveness', kpi_df['Effectiveness'])
1079
+
1080
+ chart_placeholder = slide.placeholders[ph_idx[1]]
1081
+
1082
+ # Add the chart to the slide
1083
+ bar_chart(chart_placeholder=chart_placeholder,
1084
+ slide=slide,
1085
+ chart_data=chart_data,
1086
+ titles={'chart_title': 'Effectiveness',
1087
+ # 'x_axis':'Channels',
1088
+ # 'y_axis': 'ROI'
1089
+ },
1090
+ # min_y=np.floor(np.min(channel_spends_df.select_dtypes(exclude=['object', 'datetime']))),
1091
+ min_y=0,
1092
+ # max_y=np.max(channel_roi_df.select_dtypes(exclude=['object', 'datetime'])),
1093
+ type='H',
1094
+ label_type='M'
1095
+ )
1096
+
1097
+ # Create chart data for efficiency
1098
+ chart_data_1 = CategoryChartData()
1099
+ chart_data_1.categories = kpi_df['Response Metric']
1100
+ chart_data_1.add_series('Efficiency', kpi_df['Efficiency'])
1101
+
1102
+ chart_placeholder_1 = slide.placeholders[ph_idx[2]]
1103
+
1104
+ # Add the chart to the slide
1105
+ bar_chart(chart_placeholder=chart_placeholder_1,
1106
+ slide=slide,
1107
+ chart_data=chart_data_1,
1108
+ titles={'chart_title': 'Efficiency',
1109
+ # 'x_axis':'Channels',
1110
+ # 'y_axis': 'ROI'
1111
+ },
1112
+ # min_y=np.floor(np.min(channel_spends_df.select_dtypes(exclude=['object', 'datetime']))),
1113
+ min_y=0,
1114
+ # max_y=np.max(channel_roi_df.select_dtypes(exclude=['object', 'datetime'])),
1115
+ type='H'
1116
+ )
1117
+
1118
+ definition_ph_1 = slide.placeholders[ph_idx[3]]
1119
+ definition_ph_1.text = 'Effectiveness is measured as the total sum of the Response Metric'
1120
+ definition_ph_2 = slide.placeholders[ph_idx[4]]
1121
+ definition_ph_2.text = 'Efficiency is measured as the ratio of sum of the Response Metric and sum of Media Spend'
1122
+
1123
+
1124
+ def load_pickle(path):
1125
+ with open(path, "rb") as f:
1126
+ file_data = pickle.load(f)
1127
+ return file_data
1128
+
1129
+
1130
+ def read_all_files():
1131
+ files=[]
1132
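+ # The append order below must match the positional unpacking in create_ppt: data, bin_dict, channel groups, transformations, tuned models, contributions, saved scenarios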
+
1133
+ # Read data and bin dictionary
1134
+ if st.session_state["project_dct"]["data_import"]["imputed_tool_df"] is not None:
1135
+ final_df_loaded = st.session_state["project_dct"]["data_import"]["imputed_tool_df"].copy()
1136
+ bin_dict_loaded = st.session_state["project_dct"]["data_import"]["category_dict"].copy()
1137
+
1138
+ files.append(final_df_loaded)
1139
+ files.append(bin_dict_loaded)
1140
+
1141
+ if "group_dict" in st.session_state["project_dct"]["data_import"].keys():
1142
+ channels = st.session_state["project_dct"]["data_import"]["group_dict"]
1143
+ files.append(channels)
1144
+
1145
+
1146
+ if st.session_state["project_dct"]["transformations"]["final_df"] is not None:
1147
+ transform_dict = st.session_state["project_dct"]["transformations"]
1148
+ files.append(transform_dict)
1149
+ if retrieve_pkl_object_without_warning(st.session_state['project_number'], "Model_Tuning", "tuned_model", schema) is not None:
1150
+ tuned_model_dict = retrieve_pkl_object_without_warning(st.session_state['project_number'], "Model_Tuning",
1151
+ "tuned_model", schema) # db
1152
+
1153
+ files.append(tuned_model_dict)
1154
+ else:
1155
+ files.append(None)
1156
+ else:
1157
+ files.append(None)
1158
+
1159
+ if len(list(st.session_state["project_dct"]["current_media_performance"]["model_outputs"].keys()))>0: # check if there are model outputs for at least one metric
1160
+ metrics_list = list(st.session_state["project_dct"]["current_media_performance"]["model_outputs"].keys())
1161
+ contributions_excels_dict = {}
1162
+ for metrics in metrics_list:
1163
+ # raw_df = st.session_state["project_dct"]["current_media_performance"]["model_outputs"][metrics]["raw_data"]
1164
+ # spend_df = st.session_state["project_dct"]["current_media_performance"]["model_outputs"][metrics]["spends_data"]
1165
+ contribution_df = st.session_state["project_dct"]["current_media_performance"]["model_outputs"][metrics]["contribution_data"]
1166
+ contributions_excels_dict[metrics] = {'CONTRIBUTION MMM':contribution_df}
1167
+ files.append(contributions_excels_dict)
1168
+
1169
+ # Get Saved Scenarios
1170
+ if len(list(st.session_state["project_dct"]["saved_scenarios"]["saved_scenarios_dict"].keys()))>0:
1171
+ files.append(st.session_state["project_dct"]["saved_scenarios"]["saved_scenarios_dict"])
1172
+
1173
+ # saved_scenarios_loaded = get_saved_scenarios_dict(project_path)
1174
+
1175
+
1176
+ return files
1177
+
1178
+
1179
+
1180
+ '''
1181
+
1182
+ Template Layout
1183
+
1184
+ 0 : Title
1185
+ 1 : Data Details Section {no changes required}
1186
+ 2 : Data Import
1187
+ 3 : Data Import - Channel Groups
1188
+ 4 : Model Results {Duplicate for each model}
1189
+ 5 : Metrics Contribution
1190
+ 6 : Media performance {Duplicate for each model}
1191
+ 7 : Media performance Tabular View {Duplicate for each model}
1192
+ 8 : Optimization Section {no changes}
1193
+ 9 : Optimization Summary {Duplicate for each section}
1194
+ 10 : Channel Spends {Duplicate for each model}
1195
+ 11 : Channel Wise ROI {Duplicate for each model}
1196
+ 12 : Effectiveness & Efficiency
1197
+ 13 : Appendix
1198
+ 14 : Transformations
1199
+ 15 : Model Summary
1200
+ 16 : Thank You Slide
1201
+
1202
+ '''
1203
+
1204
+
1205
+ def create_ppt(project_name, username, panel_col):
1206
+ # Read saved files
1207
+ files = read_all_files()
1208
+ transform_dict, tuned_model_dict, contributions_excels_dict, saved_scenarios_loaded = None, None, None, None
1209
+
1210
+ if len(files)>0:
1211
+ # saved_data = files[0]
1212
+ data = files[0]
1213
+ bin_dict = files[1]
1214
+
1215
+ channel_groups_dct = files[2]
1216
+ try:
1217
+ transform_dict = files[3]
1218
+ tuned_model_dict = files[4]
1219
+ contributions_excels_dict = files[5]
1220
+ saved_scenarios_loaded = files[6]
1221
+ except Exception as e:
1222
+ print(e)
1223
+
1224
+ else:
1225
+ return False
1226
+
1227
+ is_panel = data[panel_col].nunique() > 1
1228
+
1229
+ template_path = 'ppt/template.pptx'
1230
+ # ppt_path = os.path.join('ProjectSummary.pptx')
1231
+
1232
+ prs = Presentation(template_path)
1233
+ num_slides = len(prs.slides)
1234
+ slides = prs.slides
1235
+
1236
+ # Title Slide
1237
+ title_slide_layout = slides[0].slide_layout
1238
+ title_slide = prs.slides.add_slide(title_slide_layout)
1239
+
1240
+ # Add title & project name
1241
+ placeholders = title_slide.placeholders
1242
+ ph_idx = [ph.placeholder_format.idx for ph in placeholders]
1243
+ title_ph = title_slide.placeholders[ph_idx[0]]
1244
+ title_ph.text = 'Media Mix Optimization Summary'
1245
+ txt_ph = title_slide.placeholders[ph_idx[1]]
1246
+ txt_ph.text = 'Project Name: ' + project_name + '\nCreated By: ' + username
1247
+
1248
+ # Model Details Section
1249
+ model_section_slide_layout = slides[1].slide_layout
1250
+ model_section_slide = prs.slides.add_slide(model_section_slide_layout)
1251
+
1252
+ ## Add title
1253
+ placeholders = model_section_slide.placeholders
1254
+ ph_idx = [ph.placeholder_format.idx for ph in placeholders]
1255
+ title_ph = model_section_slide.placeholders[ph_idx[0]]
1256
+ title_ph.text = 'Model Details'
1257
+ section_ph = model_section_slide.placeholders[ph_idx[1]]
1258
+ section_ph.text = 'Section 1'
1259
+
1260
+ # Data Import
1261
+ data_import_slide_layout = slides[2].slide_layout
1262
+ data_import_slide = prs.slides.add_slide(data_import_slide_layout)
1263
+ data_import_slide = title_and_table(slide=data_import_slide,
1264
+ title='Data Import',
1265
+ df=data_import(data, bin_dict),
1266
+ column_width={0: 2, 1: 7}
1267
+ )
1268
+
1269
+ # Channel Groups
1270
+ channel_group_slide_layout = slides[3].slide_layout
1271
+ channel_group_slide = prs.slides.add_slide(channel_group_slide_layout)
1272
+ channel_group_slide = title_and_table(slide=channel_group_slide,
1273
+ title='Channels - Media and Spend',
1274
+ df=channel_groups_df(channel_groups_dct, bin_dict),
1275
+ column_width={0: 2, 1: 5, 2: 2}
1276
+ )
1277
+
1278
+ if tuned_model_dict is not None:
1279
+ model_metrics_df = model_metrics(tuned_model_dict, False)
1280
+
1281
+ # Model Results
1282
+ for model_key, model_dict in tuned_model_dict.items():
1283
+ model_result_slide_layout = slides[4].slide_layout
1284
+ model_result_slide = prs.slides.add_slide(model_result_slide_layout)
1285
+ model_result_slide = model_result(slide=model_result_slide,
1286
+ model_key=model_key,
1287
+ model_dict=model_dict,
1288
+ model_metrics_df=model_metrics_df,
1289
+ date_col='date')
1290
+
1291
+ if contributions_excels_dict is not None:
1292
+
1293
+ # Metrics Contributions
1294
+ metrics_contributions_slide_layout = slides[5].slide_layout
1295
+ metrics_contributions_slide = prs.slides.add_slide(metrics_contributions_slide_layout)
1296
+ metrics_contributions_slide = metrics_contributions(slide=metrics_contributions_slide,
1297
+ contributions_excels_dict=contributions_excels_dict,
1298
+ panel_col=panel_col
1299
+ )
1300
+
1301
+ # Media Performance
1302
+ for target in contributions_excels_dict.keys():
1303
+
1304
+ # Chart
1305
+ model_media_perf_slide_layout = slides[6].slide_layout
1306
+ model_media_perf_slide = prs.slides.add_slide(model_media_perf_slide_layout)
1307
+ contribution_df, spends_df = model_media_performance(slide=model_media_perf_slide,
1308
+ target=target,
1309
+ contributions_excels_dict=contributions_excels_dict
1310
+ )
1311
+
1312
+ # Tabular View
1313
+ contri_spends_df = pd.merge(spends_df, contribution_df, on='Channel', how='outer')
1314
+ contri_spends_df.fillna(0, inplace=True)
1315
+
1316
+ for col in [c for c in contri_spends_df.columns if c != 'Channel']:
1317
+ contri_spends_df[col] = contri_spends_df[col].apply(lambda x: safe_num_to_per(x))
1318
+
1319
+ media_performance_table_slide_layout = slides[7].slide_layout
1320
+ media_performance_table_slide = prs.slides.add_slide(media_performance_table_slide_layout)
1321
+ media_performance_table_slide = title_and_table(slide=media_performance_table_slide,
1322
+ title='Media and Spends Channels Tabular View',
1323
+ df=contri_spends_df,
1324
+ # column_width={0:2, 1:5, 2:2}
1325
+ )
1326
+
1327
+ if saved_scenarios_loaded is not None:
1328
+ # Optimization Details
1329
+ opt_section_slide_layout = slides[8].slide_layout
1330
+ opt_section_slide = prs.slides.add_slide(opt_section_slide_layout)
1331
+
1332
+ ## Add title
1333
+ placeholders = opt_section_slide.placeholders
1334
+ ph_idx = [ph.placeholder_format.idx for ph in placeholders]
1335
+ title_ph = opt_section_slide.placeholders[ph_idx[0]]
1336
+ title_ph.text = 'Optimizations Details'
1337
+ section_ph = opt_section_slide.placeholders[ph_idx[1]]
1338
+ section_ph.text = 'Section 2'
1339
+
1340
+ # Optimization
1341
+ for scenario_name, scenario in saved_scenarios_loaded.items():
1342
+ opt_summary_slide_layout = slides[9].slide_layout
1343
+ opt_summary_slide = prs.slides.add_slide(opt_summary_slide_layout)
1344
+ optimization_summary(opt_summary_slide, scenario, scenario_name)
1345
+
1346
+ channel_spends_slide_layout = slides[10].slide_layout
1347
+ channel_spends_slide = prs.slides.add_slide(channel_spends_slide_layout)
1348
+ channel_wise_spends(channel_spends_slide, scenario)
1349
+
1350
+ channel_roi_slide_layout = slides[11].slide_layout
1351
+ channel_roi_slide = prs.slides.add_slide(channel_roi_slide_layout)
1352
+ channel_wise_roi(channel_roi_slide, scenario)
1353
+
1354
+ effictiveness_efficiency_slide_layout = slides[12].slide_layout
1355
+ effictiveness_efficiency_slide = prs.slides.add_slide(effictiveness_efficiency_slide_layout)
1356
+ effictiveness_efficiency(effictiveness_efficiency_slide,
1357
+ data,
1358
+ bin_dict,
1359
+ scenario)
1360
+
1361
+ # Appendix Section
1362
+ appendix_section_slide_layout = slides[13].slide_layout
1363
+ appendix_section_slide = prs.slides.add_slide(appendix_section_slide_layout)
1364
+
1365
+ if tuned_model_dict is not None:
1366
+
1367
+ ## Add title
1368
+ placeholders = appendix_section_slide.placeholders
1369
+ ph_idx = [ph.placeholder_format.idx for ph in placeholders]
1370
+ title_ph = appendix_section_slide.placeholders[ph_idx[0]]
1371
+ title_ph.text = 'Appendix'
1372
+ section_ph = appendix_section_slide.placeholders[ph_idx[1]]
1373
+ section_ph.text = 'Section 3'
1374
+
1375
+ # Add transformations
1376
+ # if transform_dict is not None:
1377
+ # # Transformations
1378
+ # transformation_slide_layout = slides[14].slide_layout
1379
+ # transformation_slide = prs.slides.add_slide(transformation_slide_layout)
1380
+ # transformation_slide = title_and_table(slide=transformation_slide,
1381
+ # title='Transformations',
1382
+ # df=transformations(transform_dict),
1383
+ # custom_table_height=True
1384
+ # )
1385
+
1386
+ # Add model summary
1387
+ # Model Summary
1388
+ model_metrics_df = model_metrics(tuned_model_dict, False)
1389
+ model_summary_slide_layout = slides[15].slide_layout
1390
+ model_summary_slide = prs.slides.add_slide(model_summary_slide_layout)
1391
+ model_summary_slide = title_and_table(slide=model_summary_slide,
1392
+ title='Model Summary',
1393
+ df=model_metrics_df,
1394
+ custom_table_height=True
1395
+ )
1396
+
1397
+ # Last Slide
1398
+ last_slide_layout = slides[num_slides - 1].slide_layout
1399
+ last_slide = prs.slides.add_slide(last_slide_layout)
1400
+
1401
+ # Add title
1402
+ placeholders = last_slide.placeholders
1403
+ ph_idx = [ph.placeholder_format.idx for ph in placeholders]
1404
+ title_ph = last_slide.placeholders[ph_idx[0]]
1405
+ title_ph.text = 'Thank You'
1406
+
1407
+ # Remove template slides
1408
+ xml_slides = prs.slides._sldIdLst
1409
+ slides = list(xml_slides)
1410
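+ # Drop the first num_slides entries (the original template slides), keeping only the generated copies appended after them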
+ for index in range(num_slides):
1411
+ xml_slides.remove(slides[index])
1412
+
1413
+ # prs.save(ppt_path)
1414
+
1415
+ # save the output into binary form
1416
+ binary_output = BytesIO()
1417
+ prs.save(binary_output)
1418
+
1419
+ return binary_output
requirements.txt ADDED
@@ -0,0 +1,96 @@
1
+ altair == 4.2.0
2
+ attrs == 23.1.0
3
+ bcrypt == 4.0.1
4
+ blinker == 1.6.2
5
+ cachetools == 5.3.1
6
+ certifi == 2023.7.22
7
+ charset-normalizer == 3.2.0
8
+ click == 8.1.7
9
+ colorama == 0.4.6
10
+ contourpy == 1.1.1
11
+ cycler == 0.11.0
12
+ dacite == 1.8.1
13
+ entrypoints == 0.4
14
+ et-xmlfile == 1.1.0
15
+ extra-streamlit-components == 0.1.56
16
+ fonttools == 4.42.1
17
+ gitdb == 4.0.10
18
+ GitPython == 3.1.35
19
+ htmlmin == 0.1.12
20
+ idna == 3.4
21
+ ImageHash == 4.3.1
22
+ importlib-metadata == 6.8.0
23
+ importlib-resources == 6.1.0
24
+ Jinja2 == 3.1.2
25
+ joblib == 1.3.2
26
+ jsonschema == 4.19.0
27
+ jsonschema-specifications == 2023.7.1
28
+ kaleido == 0.2.1
29
+ kiwisolver == 1.4.5
30
+ markdown-it-py == 3.0.0
31
+ MarkupSafe == 2.1.3
32
+ matplotlib == 3.7.0
33
+ mdurl == 0.1.2
34
+ networkx == 3.1
35
+ numerize == 0.12
36
+ numpy == 1.23.5
37
+ openpyxl >= 3.1.0
38
+ packaging == 23.1
39
+ pandas == 1.5.2
40
+ pandas-profiling == 3.6.6
41
+ patsy == 0.5.3
42
+ phik == 0.12.3
43
+ Pillow == 10.0.0
44
+ pip == 23.2.1
45
+ plotly == 5.11.0
46
+ protobuf == 3.20.3
47
+ pyarrow == 13.0.0
48
+ pydantic == 1.10.13
49
+ pydeck == 0.8.1b0
50
+ Pygments == 2.16.1
51
+ PyJWT == 2.8.0
52
+ Pympler == 1.0.1
53
+ pyparsing == 3.1.1
54
+ python-dateutil == 2.8.2
55
+ python-decouple == 3.8
56
+ pytz == 2023.3.post1
57
+ PyWavelets == 1.4.1
58
+ PyYAML == 6.0.1
59
+ referencing == 0.30.2
60
+ requests == 2.31.0
61
+ rich == 13.5.2
62
+ rpds-py == 0.10.2
63
+ scikit-learn == 1.1.3
64
+ scipy == 1.9.3
65
+ seaborn == 0.12.2
66
+ semver == 3.0.1
67
+ setuptools == 68.1.2
68
+ six == 1.16.0
69
+ smmap == 5.0.0
70
+ statsmodels == 0.14.0
71
+ streamlit == 1.28.0
72
+ streamlit-aggrid == 0.3.4.post3
73
+ streamlit-authenticator == 0.2.1
74
+ streamlit-pandas-profiling == 0.1.3
75
+ sweetviz == 2.2.1
76
+ tangled-up-in-unicode == 0.2.0
77
+ tenacity == 8.2.3
78
+ threadpoolctl == 3.2.0
79
+ toml == 0.10.2
80
+ toolz == 0.12.0
81
+ tornado == 6.3.3
82
+ tqdm == 4.66.1
83
+ typeguard == 2.13.3
84
+ typing_extensions == 4.7.1
85
+ tzdata == 2023.3
86
+ tzlocal == 5.0.1
87
+ urllib3 == 2.0.4
88
+ validators == 0.22.0
89
+ visions == 0.7.5
90
+ watchdog == 3.0.0
91
+ wheel == 0.41.2
92
+ wordcloud == 1.9.2
93
+ ydata-profiling == 4.5.1
94
+ zipp == 3.16.2
95
+ psycopg2 == 2.9.9
96
+ python-pptx == 0.6.21
scenario.py ADDED
@@ -0,0 +1,763 @@
1
+ import numpy as np
2
+ from scipy.optimize import minimize, LinearConstraint, NonlinearConstraint
3
+ from collections import OrderedDict
4
+ import pandas as pd
5
+ from decimal import Decimal
6
+
7
+
8
+ def round_num(n, decimal=2):
9
+ n = Decimal(n)
10
+ return n.to_integral() if n == n.to_integral() else round(n.normalize(), decimal)
11
+
12
+
13
+ def numerize(n, decimal=2):
14
+ # 60 sufixes
15
+ sufixes = [
16
+ "",
17
+ "K",
18
+ "M",
19
+ "B",
20
+ "T",
21
+ "Qa",
22
+ "Qu",
23
+ "S",
24
+ "Oc",
25
+ "No",
26
+ "D",
27
+ "Ud",
28
+ "Dd",
29
+ "Td",
30
+ "Qt",
31
+ "Qi",
32
+ "Se",
33
+ "Od",
34
+ "Nd",
35
+ "V",
36
+ "Uv",
37
+ "Dv",
38
+ "Tv",
39
+ "Qv",
40
+ "Qx",
41
+ "Sx",
42
+ "Ox",
43
+ "Nx",
44
+ "Tn",
45
+ "Qa",
46
+ "Qu",
47
+ "S",
48
+ "Oc",
49
+ "No",
50
+ "D",
51
+ "Ud",
52
+ "Dd",
53
+ "Td",
54
+ "Qt",
55
+ "Qi",
56
+ "Se",
57
+ "Od",
58
+ "Nd",
59
+ "V",
60
+ "Uv",
61
+ "Dv",
62
+ "Tv",
63
+ "Qv",
64
+ "Qx",
65
+ "Sx",
66
+ "Ox",
67
+ "Nx",
68
+ "Tn",
69
+ "x",
70
+ "xx",
71
+ "xxx",
72
+ "X",
73
+ "XX",
74
+ "XXX",
75
+ "END",
76
+ ]
77
+
78
+ sci_expr = [
79
+ 1e0,
80
+ 1e3,
81
+ 1e6,
82
+ 1e9,
83
+ 1e12,
84
+ 1e15,
85
+ 1e18,
86
+ 1e21,
87
+ 1e24,
88
+ 1e27,
89
+ 1e30,
90
+ 1e33,
91
+ 1e36,
92
+ 1e39,
93
+ 1e42,
94
+ 1e45,
95
+ 1e48,
96
+ 1e51,
97
+ 1e54,
98
+ 1e57,
99
+ 1e60,
100
+ 1e63,
101
+ 1e66,
102
+ 1e69,
103
+ 1e72,
104
+ 1e75,
105
+ 1e78,
106
+ 1e81,
107
+ 1e84,
108
+ 1e87,
109
+ 1e90,
110
+ 1e93,
111
+ 1e96,
112
+ 1e99,
113
+ 1e102,
114
+ 1e105,
115
+ 1e108,
116
+ 1e111,
117
+ 1e114,
118
+ 1e117,
119
+ 1e120,
120
+ 1e123,
121
+ 1e126,
122
+ 1e129,
123
+ 1e132,
124
+ 1e135,
125
+ 1e138,
126
+ 1e141,
127
+ 1e144,
128
+ 1e147,
129
+ 1e150,
130
+ 1e153,
131
+ 1e156,
132
+ 1e159,
133
+ 1e162,
134
+ 1e165,
135
+ 1e168,
136
+ 1e171,
137
+ 1e174,
138
+ 1e177,
139
+ ]
140
+ minus_buff = n
141
+ n = abs(n)
142
+
143
+ if n < 1:
144
+ return f"{round(n/1000, decimal)}K"
145
+
146
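+ # Find the magnitude bracket [sci_expr[x], sci_expr[x + 1]) containing n, attach the matching suffix, and restore the sign from minus_buff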
+ for x in range(len(sci_expr)):
147
+ try:
148
+ if n >= sci_expr[x] and n < sci_expr[x + 1]:
149
+ sufix = sufixes[x]
150
+ if n >= 1e3:
151
+ num = str(round_num(n / sci_expr[x], decimal))
152
+ else:
153
+ num = str(round_num(n, decimal))
154
+ return num + sufix if minus_buff > 0 else "-" + num + sufix
155
+ except IndexError:
156
+ pass
157
+
158
+
159
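+ # Serialize a Channel or Scenario instance into a plain dictionary, tagged with a 'type' key for later reconstruction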
+ def class_to_dict(class_instance):
160
+ attr_dict = {}
161
+ if isinstance(class_instance, Channel):
162
+ attr_dict["type"] = "Channel"
163
+ attr_dict["name"] = class_instance.name
164
+ attr_dict["dates"] = class_instance.dates
165
+ attr_dict["spends"] = class_instance.actual_spends
166
+ attr_dict["conversion_rate"] = class_instance.conversion_rate
167
+ attr_dict["modified_spends"] = class_instance.modified_spends
168
+ attr_dict["modified_sales"] = class_instance.modified_sales
169
+ attr_dict["response_curve_type"] = class_instance.response_curve_type
170
+ attr_dict["response_curve_params"] = class_instance.response_curve_params
171
+ attr_dict["penalty"] = class_instance.penalty
172
+ attr_dict["bounds"] = class_instance.bounds
173
+ attr_dict["actual_total_spends"] = class_instance.actual_total_spends
174
+ attr_dict["actual_total_sales"] = class_instance.actual_total_sales
175
+ attr_dict["modified_total_spends"] = class_instance.modified_total_spends
176
+ attr_dict["modified_total_sales"] = class_instance.modified_total_sales
177
+ attr_dict["actual_mroi"] = class_instance.get_marginal_roi("actual")
178
+ attr_dict["modified_mroi"] = class_instance.get_marginal_roi("modified")
179
+ attr_dict["freeze"] = class_instance.freeze
180
+ attr_dict["correction"] = class_instance.correction
181
+
182
+ elif isinstance(class_instance, Scenario):
183
+ attr_dict["type"] = "Scenario"
184
+ attr_dict["name"] = class_instance.name
185
+ attr_dict["bounds"] = class_instance.bounds
186
+ channels = []
187
+ for channel in class_instance.channels.values():
188
+ channels.append(class_to_dict(channel))
189
+ attr_dict["channels"] = channels
190
+ attr_dict["constant"] = class_instance.constant
191
+ attr_dict["correction"] = class_instance.correction
192
+ attr_dict["actual_total_spends"] = class_instance.actual_total_spends
193
+ attr_dict["actual_total_sales"] = class_instance.actual_total_sales
194
+ attr_dict["modified_total_spends"] = class_instance.modified_total_spends
195
+ attr_dict["modified_total_sales"] = class_instance.modified_total_sales
196
+
197
+ return attr_dict
198
+
199
+
200
+ # def class_convert_to_dict(class_instance):
201
+ # attr_dict = {}
202
+ # if isinstance(class_instance, Channel):
203
+ # attr_dict["type"] = "Channel"
204
+ # attr_dict["name"] = class_instance.name
205
+ # attr_dict["dates"] = class_instance.dates
206
+ # attr_dict["spends"] = class_instance.actual_spends
207
+ # attr_dict["conversion_rate"] = class_instance.conversion_rate
208
+ # attr_dict["modified_spends"] = class_instance.modified_spends
209
+ # attr_dict["modified_sales"] = class_instance.modified_sales
210
+ # attr_dict["response_curve_type"] = class_instance.response_curve_type
211
+ # attr_dict["response_curve_params"] = class_instance.response_curve_params
212
+ # # attr_dict["penalty"] = class_instance.penalty
213
+ # attr_dict["bounds"] = class_instance.bounds
214
+ # attr_dict["actual_total_spends"] = class_instance.actual_total_spends
215
+ # attr_dict["actual_total_sales"] = class_instance.actual_total_sales
216
+ # attr_dict["modified_total_spends"] = class_instance.modified_total_spends
217
+ # attr_dict["modified_total_sales"] = class_instance.modified_total_sales
218
+ # # attr_dict["actual_mroi"] = class_instance.get_marginal_roi("actual")
219
+ # # attr_dict["modified_mroi"] = class_instance.get_marginal_roi("modified")
220
+
221
+ # attr_dict["freeze"] = class_instance.freeze
222
+ # attr_dict["correction"] = class_instance.correction
223
+
224
+ # elif isinstance(class_instance, Scenario):
225
+ # attr_dict["type"] = "Scenario"
226
+ # attr_dict["name"] = class_instance.name
227
+ # channels = {}
228
+ # for channel in class_instance.channels.values():
229
+ # channels[channel.name] = class_to_dict(channel)
230
+ # attr_dict["channels"] = channels
231
+ # attr_dict["constant"] = class_instance.constant
232
+ # attr_dict["correction"] = class_instance.correction
233
+ # attr_dict["actual_total_spends"] = class_instance.actual_total_spends
234
+ # attr_dict["actual_total_sales"] = class_instance.actual_total_sales
235
+ # attr_dict["modified_total_spends"] = class_instance.modified_total_spends
236
+ # attr_dict["modified_total_sales"] = class_instance.modified_total_sales
237
+
238
+ # attr_dict["bound_type"] = class_instance.bound_type
239
+
240
+ # attr_dict["bounds"] = class_instance.bounds
241
+
242
+ # return attr_dict.copy()
243
+
244
+ # Function to convert class instance to dictionary
245
+ def class_convert_to_dict(class_instance):
246
+ attr_dict = {}
247
+
248
+ if isinstance(class_instance, Channel):
249
+ # Convert Channel instance to dictionary
250
+ attr_dict["type"] = "Channel"
251
+ attr_dict["name"] = class_instance.name
252
+ attr_dict["dates"] = class_instance.dates
253
+ attr_dict["spends"] = list(class_instance.actual_spends)
254
+ attr_dict["conversion_rate"] = class_instance.conversion_rate
255
+ attr_dict["modified_spends"] = list(class_instance.modified_spends)
256
+ attr_dict["modified_sales"] = list(class_instance.modified_sales)
257
+ attr_dict["response_curve_type"] = class_instance.response_curve_type
258
+ attr_dict["response_curve_params"] = class_instance.response_curve_params.copy()
259
+ attr_dict["bounds"] = class_instance.bounds.copy()
260
+ attr_dict["actual_total_spends"] = class_instance.actual_total_spends
261
+ attr_dict["actual_total_sales"] = class_instance.actual_total_sales
262
+ attr_dict["modified_total_spends"] = class_instance.modified_total_spends
263
+ attr_dict["modified_total_sales"] = class_instance.modified_total_sales
264
+ attr_dict["freeze"] = class_instance.freeze
265
+ attr_dict["correction"] = class_instance.correction.copy()
266
+
267
+ elif isinstance(class_instance, Scenario):
268
+ # Convert Scenario instance to dictionary
269
+ attr_dict["type"] = "Scenario"
270
+ attr_dict["name"] = class_instance.name
271
+
272
+ channels = {}
273
+ for channel in class_instance.channels.values():
274
+ channels[channel.name] = class_convert_to_dict(channel)
275
+ attr_dict["channels"] = channels
276
+
277
+ attr_dict["constant"] = list(class_instance.constant)
278
+ attr_dict["correction"] = list(class_instance.correction)
279
+ attr_dict["actual_total_spends"] = class_instance.actual_total_spends
280
+ attr_dict["actual_total_sales"] = class_instance.actual_total_sales
281
+ attr_dict["modified_total_spends"] = class_instance.modified_total_spends
282
+ attr_dict["modified_total_sales"] = class_instance.modified_total_sales
283
+ attr_dict["bound_type"] = class_instance.bound_type
284
+ attr_dict["bounds"] = class_instance.bounds.copy()
285
+
286
+ return attr_dict
287
+
288
+
289
+ def class_from_dict(attr_dict):
290
+ if attr_dict["type"] == "Channel":
291
+ return Channel.from_dict(attr_dict)
292
+ elif attr_dict["type"] == "Scenario":
293
+ return Scenario.from_dict(attr_dict)
294
+
295
+
296
+ class Channel:
297
+ def __init__(
298
+ self,
299
+ name,
300
+ dates,
301
+ spends,
302
+ response_curve_type,
303
+ response_curve_params,
304
+ bounds,
305
+ correction,
306
+ conversion_rate=1,
307
+ modified_spends=None,
308
+ penalty=True,
309
+ freeze=False,
310
+ ):
311
+ self.name = name
312
+ self.dates = dates
313
+ self.conversion_rate = conversion_rate
314
+ self.actual_spends = spends.copy()
315
+ self.correction = correction
316
+
317
+ if modified_spends is None:
318
+ self.modified_spends = self.actual_spends.copy()
319
+ else:
320
+ self.modified_spends = modified_spends
321
+
322
+ self.response_curve_type = response_curve_type
323
+ self.response_curve_params = response_curve_params
324
+ self.bounds = bounds
325
+ self.penalty = penalty
326
+ self.freeze = freeze
327
+
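+ # upper_limit caps realistic extrapolation of spends; power is the order-of-magnitude factor used to rescale spends before evaluating the s-curve.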
328
+ self.upper_limit = self.actual_spends.max() + self.actual_spends.std()
329
+ self.power = np.ceil(np.log(self.actual_spends.max()) / np.log(10)) - 3
330
+ # self.actual_sales = None
331
+ # self.actual_sales = self.response_curve(self.actual_spends)
332
+ self.actual_total_spends = self.actual_spends.sum()
333
+ self.actual_total_sales = self.actual_sales.sum()
334
+ self.modified_sales = self.calculate_sales()
335
+ self.modified_total_spends = self.modified_spends.sum()
336
+ self.modified_total_sales = self.modified_sales.sum()
337
+ self.delta_spends = self.modified_total_spends - self.actual_total_spends
338
+ self.delta_sales = self.modified_total_sales - self.actual_total_sales
339
+
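+ # Actual sales reconstruct the MMM contribution as the fitted response-curve output plus the stored residual correction.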
340
+ @property
341
+ def actual_sales(self):
342
+ return self.response_curve(self.actual_spends) + self.correction
343
+
344
+ def update_penalty(self, penalty):
345
+ self.penalty = penalty
346
+
347
+ def _modify_spends(self, spends_array, total_spends):
348
+ return spends_array * total_spends / spends_array.sum()
349
+
350
+ def modify_spends(self, total_spends):
351
+ self.modified_spends = (
352
+ self.modified_spends * total_spends / self.modified_spends.sum()
353
+ )
354
+
355
+ def calculate_sales(self):
356
+ return self.response_curve(self.modified_spends) + self.correction
357
+
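+ # With penalty enabled, spends above upper_limit are damped toward it, discouraging extrapolation far beyond observed spend levels.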
358
+ def response_curve(self, x):
359
+ if self.penalty:
360
+ x = np.where(
361
+ x < self.upper_limit,
362
+ x,
363
+ self.upper_limit + (x - self.upper_limit) * self.upper_limit / x,
364
+ )
365
+ if self.response_curve_type == "s-curve":
366
+ if self.power >= 0:
367
+ x = x / 10**self.power
368
+ x = x.astype("float64")
369
+ K = self.response_curve_params["K"]
370
+ b = self.response_curve_params["b"]
371
+ a = self.response_curve_params["a"]
372
+ x0 = self.response_curve_params["x0"]
373
+ sales = K / (1 + b * np.exp(-a * (x - x0)))
374
+ if self.response_curve_type == "linear":
375
+ beta = self.response_curve_params["beta"]
376
+ sales = beta * x
377
+
378
+ return sales
379
+
380
+ def get_marginal_roi(self, flag):
381
+ K = self.response_curve_params["K"]
382
+ a = self.response_curve_params["a"]
383
+ # x = self.modified_total_spends
384
+ # if self.power >= 0 :
385
+ # x = x / 10**self.power
386
+ # x = x.astype('float64')
387
+ # return K*b*a*np.exp(-a*(x-x0)) / (1 + b * np.exp(-a*(x - x0)))**2
388
+ if flag == "actual":
389
+ y = self.response_curve(self.actual_spends)
390
+ # spends_array = self.actual_spends
391
+ # total_spends = self.actual_total_spends
392
+ # total_sales = self.actual_total_sales
393
+
394
+ else:
395
+ y = self.response_curve(self.modified_spends)
396
+ # spends_array = self.modified_spends
397
+ # total_spends = self.modified_total_spends
398
+ # total_sales = self.modified_total_sales
399
+
400
+ # spends_inc_1 = self._modify_spends(spends_array, total_spends+1)
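+ # For the s-curve K/(1 + b*exp(-a*(x - x0))), the derivative dy/dx equals a*y*(1 - y/K); averaging it over the period gives the marginal ROI per unit spend.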
401
+ mroi = a * (y) * (1 - y / K)
402
+ return mroi.sum() / len(self.modified_spends)
403
+ # spends_inc_1 = self.spends_array + 1
404
+ # new_total_sales = self.response_curve(spends_inc_1).sum()
405
+ # return (new_total_sales - total_sales) / len(self.modified_spends)
406
+
407
+ def update(self, total_spends):
408
+ self.modify_spends(total_spends)
409
+ self.modified_sales = self.calculate_sales()
410
+ self.modified_total_spends = self.modified_spends.sum()
411
+ self.modified_total_sales = self.modified_sales.sum()
412
+ self.delta_spends = self.modified_total_spends - self.actual_total_spends
413
+ self.delta_sales = self.modified_total_sales - self.actual_total_sales
414
+
415
+ def intialize(self):
416
+ self.new_spends = self.old_spends
417
+
418
+ def __str__(self):
419
+ return f"{self.name},{self.actual_total_sales}, {self.modified_total_spends}"
420
+
421
+ @classmethod
422
+ def from_dict(cls, attr_dict):
423
+ return Channel(
424
+ name=attr_dict["name"],
425
+ dates=attr_dict["dates"],
426
+ spends=attr_dict["spends"],
427
+ bounds=attr_dict["bounds"],
428
+ modified_spends=attr_dict["modified_spends"],
429
+ response_curve_type=attr_dict["response_curve_type"],
430
+ response_curve_params=attr_dict["response_curve_params"],
431
+ penalty=attr_dict["penalty"],
432
+ correction=attr_dict["correction"],
433
+ )
434
+
435
+ def update_response_curves(self, response_curve_params):
436
+ self.response_curve_params = response_curve_params
437
+
438
+
439
+ class Scenario:
440
+ def __init__(
441
+ self, name, channels, constant, correction, bound_type=False, bounds=[-10, 10]
442
+ ):
443
+ self.name = name
444
+ self.channels = channels
445
+ self.constant = constant
446
+ self.correction = correction
447
+ self.bound_type = bound_type
448
+ self.bounds = bounds
449
+
450
+ self.actual_total_spends = self.calculate_actual_total_spends()
451
+ self.actual_total_sales = self.calculate_actual_total_sales()
452
+ self.modified_total_sales = self.calculate_modified_total_sales()
453
+ self.modified_total_spends = self.calculate_modified_total_spends()
454
+ self.delta_spends = self.modified_total_spends - self.actual_total_spends
455
+ self.delta_sales = self.modified_total_sales - self.actual_total_sales
456
+
457
+ def update_penalty(self, value):
458
+ for channel in self.channels.values():
459
+ channel.update_penalty(value)
460
+
461
+ def calculate_actual_total_spends(self):
462
+ total_actual_spends = 0.0
463
+ for channel in self.channels.values():
464
+ total_actual_spends += channel.actual_total_spends * channel.conversion_rate
465
+ return total_actual_spends
466
+
467
+ def calculate_modified_total_spends(self):
468
+ total_modified_spends = 0.0
469
+ for channel in self.channels.values():
470
+ # import streamlit as st
471
+ # st.write(channel.modified_total_spends )
472
+ total_modified_spends += (
473
+ channel.modified_total_spends * channel.conversion_rate
474
+ )
475
+ return total_modified_spends
476
+
477
+ def calculate_actual_total_sales(self):
478
+ total_actual_sales = self.constant.sum() + self.correction.sum()
479
+ for channel in self.channels.values():
480
+ total_actual_sales += channel.actual_total_sales
481
+ return total_actual_sales
482
+
483
+ def calculate_modified_total_sales(self):
484
+ total_modified_sales = self.constant.sum() + self.correction.sum()
485
+ for channel in self.channels.values():
486
+ total_modified_sales += channel.modified_total_sales
487
+ return total_modified_sales
488
+
489
+ def update(self, channel_name, modified_spends):
490
+ self.channels[channel_name].update(modified_spends)
491
+ self.modified_total_sales = self.calculate_modified_total_sales()
492
+ self.modified_total_spends = self.calculate_modified_total_spends()
493
+ self.delta_spends = self.modified_total_spends - self.actual_total_spends
494
+ self.delta_sales = self.modified_total_sales - self.actual_total_sales
495
+
496
+ # def optimize_spends(self, sales_percent, channels_list, algo="COBYLA"):
497
+ # desired_sales = self.actual_total_sales * (1 + sales_percent / 100.0)
498
+
499
+ # def constraint(x):
500
+ # for ch, spends in zip(channels_list, x):
501
+ # self.update(ch, spends)
502
+ # return self.modified_total_sales - desired_sales
503
+
504
+ # bounds = []
505
+ # for ch in channels_list:
506
+ # bounds.append(
507
+ # (1 + np.array([-50.0, 100.0]) / 100.0)
508
+ # * self.channels[ch].actual_total_spends
509
+ # )
510
+
511
+ # initial_point = []
512
+ # for bound in bounds:
513
+ # initial_point.append(bound[0])
514
+
515
+ # power = np.ceil(np.log(sum(initial_point)) / np.log(10))
516
+
517
+ # constraints = [NonlinearConstraint(constraint, -1.0, 1.0)]
518
+
519
+ # res = minimize(
520
+ # lambda x: sum(x) / 10 ** (power),
521
+ # bounds=bounds,
522
+ # x0=initial_point,
523
+ # constraints=constraints,
524
+ # method=algo,
525
+ # options={"maxiter": int(2e7), "catol": 1},
526
+ # )
527
+
528
+ # for channel_name, modified_spends in zip(channels_list, res.x):
529
+ # self.update(channel_name, modified_spends)
530
+
531
+ # return zip(channels_list, res.x)
532
+
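+ # Goal seek: minimize total spends subject to hitting the desired sales uplift, with each channel bounded to -50%/+100% of its actual spends.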
533
+ def optimize_spends(self, sales_percent, channels_list, algo="trust-constr"):
534
+ desired_sales = self.actual_total_sales * (1 + sales_percent / 100.0)
535
+
536
+ def constraint(x):
537
+ for ch, spends in zip(channels_list, x):
538
+ self.update(ch, spends)
539
+ return self.modified_total_sales - desired_sales
540
+
541
+ bounds = []
542
+ for ch in channels_list:
543
+ bounds.append(
544
+ (1 + np.array([-50.0, 100.0]) / 100.0)
545
+ * self.channels[ch].actual_total_spends
546
+ )
547
+
548
+ initial_point = []
549
+ for bound in bounds:
550
+ initial_point.append(bound[0])
551
+
552
+ power = np.ceil(np.log(sum(initial_point)) / np.log(10))
553
+
554
+ constraints = [NonlinearConstraint(constraint, -1.0, 1.0)]
555
+
556
+ res = minimize(
557
+ lambda x: sum(x) / 10 ** (power),
558
+ bounds=bounds,
559
+ x0=initial_point,
560
+ constraints=constraints,
561
+ method=algo,
562
+ options={"maxiter": int(2e7), "xtol": 100},
563
+ )
564
+
565
+ for channel_name, modified_spends in zip(channels_list, res.x):
566
+ self.update(channel_name, modified_spends)
567
+
568
+ return zip(channels_list, res.x)
569
+
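+ # Budget allocation: maximize modified sales while holding total (conversion-adjusted) spends fixed at the scaled actual budget, within per-channel bounds.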
570
+ def optimize(self, spends_percent, channels_list):
571
+ # channels_list = self.channels.keys()
572
+ num_channels = len(channels_list)
573
+ spends_constant = []
574
+ spends_constraint = 0.0
575
+ for channel_name in channels_list:
576
+ # spends_constraint += self.channels[channel_name].modified_total_spends
577
+ spends_constant.append(self.channels[channel_name].conversion_rate)
578
+ spends_constraint += (
579
+ self.channels[channel_name].actual_total_spends
580
+ * self.channels[channel_name].conversion_rate
581
+ )
582
+ spends_constraint = spends_constraint * (1 + spends_percent / 100)
583
+ # constraint= LinearConstraint(np.ones((num_channels,)), lb = spends_constraint, ub = spends_constraint)
584
+ constraint = LinearConstraint(
585
+ np.array(spends_constant),
586
+ lb=spends_constraint,
587
+ ub=spends_constraint,
588
+ )
589
+ bounds = []
590
+ old_spends = []
591
+ for channel_name in channels_list:
592
+ _channel_class = self.channels[channel_name]
593
+ channel_bounds = _channel_class.bounds
594
+ channel_actual_total_spends = _channel_class.actual_total_spends * (
595
+ (1 + spends_percent / 100)
596
+ )
597
+ old_spends.append(channel_actual_total_spends)
598
+ bounds.append((1 + channel_bounds / 100) * channel_actual_total_spends)
599
+
600
+ def objective_function(x):
601
+ for channel_name, modified_spends in zip(channels_list, x):
602
+ self.update(channel_name, modified_spends)
603
+ return -1 * self.modified_total_sales
604
+
605
+ power = np.ceil(np.log(self.modified_total_sales) / np.log(10))
606
+
607
+ res = minimize(
608
+ lambda x: objective_function(x) / 10 ** (power - 1),
609
+ method="trust-constr",
610
+ x0=old_spends,
611
+ constraints=constraint,
612
+ bounds=bounds,
613
+ options={"maxiter": int(1e7), "xtol": 0.1},
614
+ )
615
+ # res = dual_annealing(
616
+ # objective_function,
617
+ # x0=old_spends,
618
+ # mi
619
+ # constraints=constraint,
620
+ # bounds=bounds,
621
+ # tol=1e-16
622
+
623
+ for channel_name, modified_spends in zip(channels_list, res.x):
624
+ self.update(channel_name, modified_spends)
625
+
626
+ return zip(channels_list, res.x)
627
+
628
+ def save(self):
629
+ details = {}
630
+ actual_list = []
631
+ modified_list = []
632
+ data = {}
633
+ channel_data = []
634
+
635
+ summary_rows = []
636
+ actual_list.append(
637
+ {
638
+ "name": "Total",
639
+ "Spends": self.actual_total_spends,
640
+ "Sales": self.actual_total_sales,
641
+ }
642
+ )
643
+ modified_list.append(
644
+ {
645
+ "name": "Total",
646
+ "Spends": self.modified_total_spends,
647
+ "Sales": self.modified_total_sales,
648
+ }
649
+ )
650
+ for channel in self.channels.values():
651
+ name_mod = channel.name.replace("_", " ")
652
+ if name_mod.lower().endswith(" imp"):
653
+ name_mod = name_mod.replace("Imp", " Impressions")
654
+ summary_rows.append(
655
+ [
656
+ name_mod,
657
+ channel.actual_total_spends,
658
+ channel.modified_total_spends,
659
+ channel.actual_total_sales,
660
+ channel.modified_total_sales,
661
+ round(
662
+ channel.actual_total_sales / channel.actual_total_spends,
663
+ 2,
664
+ ),
665
+ round(
666
+ channel.modified_total_sales / channel.modified_total_spends,
667
+ 2,
668
+ ),
669
+ channel.get_marginal_roi("actual"),
670
+ channel.get_marginal_roi("modified"),
671
+ ]
672
+ )
673
+ data[channel.name] = channel.modified_spends
674
+ data["Date"] = channel.dates
675
+ data["Sales"] = (
676
+ data.get("Sales", np.zeros((len(channel.dates),)))
677
+ + channel.modified_sales
678
+ )
679
+ actual_list.append(
680
+ {
681
+ "name": channel.name,
682
+ "Spends": channel.actual_total_spends,
683
+ "Sales": channel.actual_total_sales,
684
+ "ROI": round(
685
+ channel.actual_total_sales / channel.actual_total_spends,
686
+ 2,
687
+ ),
688
+ }
689
+ )
690
+ modified_list.append(
691
+ {
692
+ "name": channel.name,
693
+ "Spends": channel.modified_total_spends,
694
+ "Sales": channel.modified_total_sales,
695
+ "ROI": round(
696
+ channel.modified_total_sales / channel.modified_total_spends,
697
+ 2,
698
+ ),
699
+ "Marginal ROI": channel.get_marginal_roi("modified"),
700
+ }
701
+ )
702
+
703
+ channel_data.append(
704
+ {
705
+ "channel": channel.name,
706
+ "spends_act": channel.actual_total_spends,
707
+ "spends_mod": channel.modified_total_spends,
708
+ "sales_act": channel.actual_total_sales,
709
+ "sales_mod": channel.modified_total_sales,
710
+ }
711
+ )
712
+ summary_rows.append(
713
+ [
714
+ "Total",
715
+ self.actual_total_spends,
716
+ self.modified_total_spends,
717
+ self.actual_total_sales,
718
+ self.modified_total_sales,
719
+ round(self.actual_total_sales / self.actual_total_spends, 2),
720
+ round(self.modified_total_sales / self.modified_total_spends, 2),
721
+ 0.0,
722
+ 0.0,
723
+ ]
724
+ )
725
+ details["Actual"] = actual_list
726
+ details["Modified"] = modified_list
727
+ columns_index = pd.MultiIndex.from_product(
728
+ [[""], ["Channel"]], names=["first", "second"]
729
+ )
730
+ columns_index = columns_index.append(
731
+ pd.MultiIndex.from_product(
732
+ [["Spends", "NRPU", "ROI", "MROI"], ["Actual", "Simulated"]],
733
+ names=["first", "second"],
734
+ )
735
+ )
736
+ details["Summary"] = pd.DataFrame(summary_rows, columns=columns_index)
737
+ data_df = pd.DataFrame(data)
738
+ channel_list = list(self.channels.keys())
739
+ data_df = data_df[["Date", *channel_list, "Sales"]]
740
+
741
+ details["download"] = {
742
+ "data_df": data_df,
743
+ "channels_df": pd.DataFrame(channel_data),
744
+ "total_spends_act": self.actual_total_spends,
745
+ "total_sales_act": self.actual_total_sales,
746
+ "total_spends_mod": self.modified_total_spends,
747
+ "total_sales_mod": self.modified_total_sales,
748
+ }
749
+
750
+ return details
751
+
752
+ @classmethod
753
+ def from_dict(cls, attr_dict):
754
+ channels_list = attr_dict["channels"]
755
+ channels = {
756
+ channel["name"]: class_from_dict(channel) for channel in channels_list
757
+ }
758
+ return Scenario(
759
+ name=attr_dict["name"],
760
+ channels=channels,
761
+ constant=attr_dict["constant"],
762
+ correction=attr_dict["correction"],
763
+ )
single_manifest.yml ADDED
File without changes
styles.css ADDED
@@ -0,0 +1,97 @@
1
+ html {
2
+ margin: 0;
3
+ }
4
+
5
+
6
+ #MainMenu {
7
+
8
+ visibility: collapse;
9
+ }
10
+
11
+ footer {
12
+ visibility: collapse;
13
+ }
14
+
15
+ div.block-container{
16
+ padding: 2rem 3rem;
17
+ }
18
+
19
+ div[data-testid="stExpander"]{
20
+ border: 1px solid #739FAE;
21
+ }
22
+
23
+ a[data-testid="stPageLink-NavLink"]{
24
+ border: 1px solid rgba(0, 110, 192, 0.2);
25
+ margin: 0;
26
+ padding: 0.25rem 0.5rem;
27
+ border-radius: 0.5rem;
28
+ }
29
+
30
+
31
+
32
+ hr{
33
+ margin: 0;
34
+ padding: 0;
35
+ }
36
+
37
+ hr.spends-heading-seperator {
38
+ background-color : #11B6BD;
39
+ height: 2px;
40
+ }
41
+
42
+ .spends-header{
43
+ font-size: 1rem;
44
+ font-weight: bold;
45
+ margin: 0;
46
+
47
+ }
48
+
49
+ td {
50
+ max-width: 100px;
51
+ /* white-space:nowrap; */
52
+ }
53
+
54
+ .main-header {
55
+ display: flex;
56
+ flex-direction: row;
57
+ justify-content: space-between;
58
+ align-items: center;
59
+
60
+ }
61
+ .blend-logo {
62
+ max-height: 64px;
63
+ /* max-width: 300px; */
64
+ object-fit: cover;
65
+ }
66
+
67
+ table {
68
+ width: 90%;
69
+ }
70
+
71
+ .lime-logo {
72
+ margin: 0;
73
+ padding: 0;
74
+ display: flex;
75
+ align-items: center ;
76
+ max-height: 64px;
77
+ }
78
+
79
+ .lime-text {
80
+ color: #00EDED;
81
+ font-size: 30px;
82
+ margin: 0;
83
+ padding: 0;
84
+ line-height: 0%;
85
+ }
86
+
87
+ .lime-img {
88
+ max-height: 30px;
89
+ /* max-height: 64px; */
90
+ /* max-width: 300px; */
91
+ object-fit: cover;
92
+ }
93
+
94
+ img {
95
+ margin: 0;
96
+ padding: 0;
97
+ }
temp_stdout.txt ADDED
File without changes
utilities.py ADDED
@@ -0,0 +1,2155 @@
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import json
4
+ from scenario import Channel, Scenario
5
+ import numpy as np
6
+ from plotly.subplots import make_subplots
7
+ import plotly.graph_objects as go
8
+ from scenario import class_to_dict
9
+ from collections import OrderedDict
10
+ import io
11
+ import plotly
12
+ from pathlib import Path
13
+ import pickle
14
+ import yaml
15
+ from yaml import SafeLoader
16
+ from streamlit.components.v1 import html
17
+ import smtplib
18
+ from scipy.optimize import curve_fit
19
+ from sklearn.metrics import r2_score
20
+ from scenario import class_from_dict, class_convert_to_dict
21
+ import os
22
+ import base64
23
+ import sqlite3
24
+ import datetime
25
+ from scenario import numerize
26
+ import psycopg2
27
+
28
+ #
29
+ import re
30
+ import bcrypt
31
+ import os
32
+ import json
33
+ import glob
34
+ import pickle
35
+
63
+ # # schema = db_cred["schema"]
64
+
65
+ color_palette = [
66
+ "#F3F3F0",
67
+ "#5E7D7E",
68
+ "#2FA1FF",
69
+ "#00EDED",
70
+ "#00EAE4",
71
+ "#304550",
72
+ "#EDEBEB",
73
+ "#7FBEFD",
74
+ "#003059",
75
+ "#A2F3F3",
76
+ "#E1D6E2",
77
+ "#B6B6B6",
78
+ ]
79
+
80
+
81
+ CURRENCY_INDICATOR = "$"
82
+ db_cred = None
83
+ # database_file = r"DB/User.db"
84
+
85
+ # conn = sqlite3.connect(database_file, check_same_thread=False) # connection with sql db
86
+ # c = conn.cursor()
87
+
88
+
89
+ # def query_excecuter_postgres(
90
+ # query,
91
+ # db_cred,
92
+ # params=None,
93
+ # insert=True,
94
+ # insert_retrieve=False,
95
+ # ):
96
+ # """
97
+ # Executes a SQL query on a PostgreSQL database, handling both insert and select operations.
98
+
99
+ # Parameters:
100
+ # query (str): The SQL query to be executed.
101
+ # params (tuple, optional): Parameters to pass into the SQL query for parameterized execution.
102
+ # insert (bool, default=True): Flag to determine if the query is an insert operation (default) or a select operation.
103
+ # insert_retrieve (bool, default=False): Flag to determine if the query should insert and then return the inserted ID.
104
+
105
+ # """
106
+ # # Database connection parameters
107
+ # dbname = db_cred["dbname"]
108
+ # user = db_cred["user"]
109
+ # password = db_cred["password"]
110
+ # host = db_cred["host"]
111
+ # port = db_cred["port"]
112
+
113
+ # try:
114
+ # # Establish connection to the PostgreSQL database
115
+ # conn = psycopg2.connect(
116
+ # dbname=dbname, user=user, password=password, host=host, port=port
117
+ # )
118
+ # except psycopg2.Error as e:
119
+ # st.warning(f"Unable to connect to the database: {e}")
120
+ # st.stop()
121
+
122
+ # # Create a cursor object to interact with the database
123
+ # c = conn.cursor()
124
+
125
+ # try:
126
+ # # Execute the query with or without parameters
127
+ # if params:
128
+ # c.execute(query, params)
129
+ # else:
130
+ # c.execute(query)
131
+
132
+ # if not insert:
133
+ # # If not an insert operation, fetch and return the results
134
+ # results = c.fetchall()
135
+ # return results
136
+ # elif insert_retrieve:
137
+ # # If insert and retrieve operation, fetch and return the results
138
+ # conn.commit()
139
+ # return c.fetchall()
140
+ # else:
141
+ # conn.commit()
142
+
143
+ # except Exception as e:
144
+ # st.write(f"Error executing query: {e}")
145
+ # finally:
146
+ # conn.close()
147
+
148
+
149
+ db_path = os.path.join("imp_db.db")
150
+
151
+
152
+ def query_excecuter_postgres(
153
+ query, db_path=None, params=None, insert=True, insert_retrieve=False, db_cred=None
154
+ ):
155
+ """
156
+ Executes a SQL query on a SQLite database, handling both insert and select operations.
157
+
158
+ Parameters:
159
+ query (str): The SQL query to be executed.
160
+ db_path (str): Path to the SQLite database file.
161
+ params (tuple, optional): Parameters to pass into the SQL query for parameterized execution.
162
+ insert (bool, default=True): Flag to determine if the query is an insert operation (default) or a select operation.
163
+ insert_retrieve (bool, default=False): Flag to determine if the query should insert and then return the inserted ID.
164
+
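+ Note: despite the legacy name (kept so existing call sites keep working), this version executes against the local SQLite file db/imp_db.db.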
165
+ """
166
+ try:
167
+ # Construct a cross-platform path to the database
168
+ db_dir = os.path.join("db")
169
+ os.makedirs(db_dir, exist_ok=True) # Make sure the directory exists
170
+ db_path = os.path.join(db_dir, "imp_db.db")
171
+
172
+ # Establish connection to the SQLite database
173
+ conn = sqlite3.connect(db_path)
174
+ except sqlite3.Error as e:
175
+ st.warning(f"Unable to connect to the SQLite database: {e}")
176
+ st.stop()
177
+
178
+     # Create a cursor object to interact with the database
+     c = conn.cursor()
+
+     try:
+         # Prepare and execute the query; execution errors are caught below and the connection is always closed
+         if params:
+             # Handle the `IN (?)` clause dynamically
+             query = query.replace("IN (?)", f"IN ({','.join(['?' for _ in params])})")
+             c.execute(query, params)
+         else:
+             c.execute(query)
+
+         if not insert:
+             # If not an insert operation, fetch and return the results
+             results = c.fetchall()
+             return results
+         elif insert_retrieve:
+             # If insert and retrieve operation, commit and return the last inserted row ID
+             conn.commit()
+             return c.lastrowid
+         else:
+             # For standard insert operations, commit the transaction
+             conn.commit()
+
+     except Exception as e:
+         st.write(f"Error executing query: {e}")
+     finally:
+         conn.close()
206
+
207
+
208
+ def update_summary_df():
209
+ """
210
+ Updates the 'project_summary_df' in the session state with the latest project
211
+ summary information based on the most recent updates.
212
+
213
+ This function executes a SQL query to retrieve project metadata from a database
214
+ and stores the result in the session state.
215
+
216
+ Uses:
217
+ - query_excecuter_postgres(query, params=params, insert=False): A function that
218
+ executes the provided SQL query on the project database (SQLite in this build).
219
+
220
+ Modifies:
221
+ - st.session_state['project_summary_df']: Updates the dataframe with columns:
222
+ 'Project Number', 'Project Name', 'Last Modified Page', 'Last Modified Time'.
223
+ """
224
+
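+ # Note: the RIGHT JOIN below assumes the runtime SQLite is version 3.39 or newer, the first release to support RIGHT/FULL JOIN.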
225
+ query = f"""
226
+ WITH LatestUpdates AS (
227
+ SELECT
228
+ prj_id,
229
+ page_nam,
230
+ updt_dt_tm,
231
+ ROW_NUMBER() OVER (PARTITION BY prj_id ORDER BY updt_dt_tm DESC) AS rn
232
+ FROM
233
+ mmo_project_meta_data
234
+ )
235
+ SELECT
236
+ p.prj_id,
237
+ p.prj_nam AS prj_nam,
238
+ lu.page_nam,
239
+ lu.updt_dt_tm
240
+ FROM
241
+ LatestUpdates lu
242
+ RIGHT JOIN
243
+ mmo_projects p ON lu.prj_id = p.prj_id
244
+ WHERE
245
+ p.prj_ownr_id = ? AND lu.rn = 1
246
+ """
247
+
248
+ params = (st.session_state["emp_id"],) # Parameters for the SQL query
249
+
250
+ # Execute the query and retrieve project summary data
251
+ project_summary = query_excecuter_postgres(
252
+ query, db_cred, params=params, insert=False
253
+ )
254
+
255
+ # Update the session state with the project summary dataframe
256
+ st.session_state["project_summary_df"] = pd.DataFrame(
257
+ project_summary,
258
+ columns=[
259
+ "Project Number",
260
+ "Project Name",
261
+ "Last Modified Page",
262
+ "Last Modified Time",
263
+ ],
264
+ )
265
+
266
+ st.session_state["project_summary_df"] = st.session_state[
267
+ "project_summary_df"
268
+ ].sort_values(by=["Last Modified Time"], ascending=False)
269
+
270
+ return st.session_state["project_summary_df"]
271
+
272
+
273
+ from constants import default_dct
274
+
275
+
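+ # Recursively back-fill keys missing from a saved project dict with defaults, so older saved projects keep working after the expected structure changes.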
276
+ def ensure_project_dct_structure(session_state, default_dct):
277
+ for key, value in default_dct.items():
278
+ if key not in session_state:
279
+ session_state[key] = value
280
+ elif isinstance(value, dict):
281
+ ensure_project_dct_structure(session_state[key], value)
282
+
283
+
284
+ def project_selection():
285
+
286
+ emp_id = st.text_input("employee id", key="emp1111").lower()
287
+ password = st.text_input("Password", max_chars=15, type="password")
288
+
289
+ if st.button("Login"):
290
+
291
+ if "unique_ids" not in st.session_state:
292
+ unique_users_query = f"""
293
+ SELECT DISTINCT emp_id, emp_nam, emp_typ from mmo_users;
294
+ """
295
+ unique_users_result = query_excecuter_postgres(
296
+ unique_users_query, db_cred, insert=False
297
+ ) # retrieves all the users who has access to MMO TOOL
298
+ st.session_state["unique_ids"] = {
299
+ emp_id: (emp_nam, emp_type)
300
+ for emp_id, emp_nam, emp_type in unique_users_result
301
+ }
302
+
303
+ if emp_id not in st.session_state["unique_ids"].keys() or len(password) == 0:
304
+ st.warning("invalid id or password!")
305
+ st.stop()
306
+
307
+ if not is_pswrd_flag_set(emp_id):
308
+ st.warning("Reset password in home page to continue")
309
+ st.stop()
310
+
311
+ elif not verify_password(emp_id, password):
312
+ st.warning("Invalid user name or password")
313
+ st.stop()
314
+
315
+ else:
316
+ st.session_state["emp_id"] = emp_id
317
+ st.session_state["username"] = st.session_state["unique_ids"][
318
+ st.session_state["emp_id"]
319
+ ][0]
320
+
321
+ with st.spinner("Loading Saved Projects"):
322
+ st.session_state["project_summary_df"] = update_summary_df()
323
+
324
+ # st.write(st.session_state["project_name"][0])
325
+ if len(st.session_state["project_summary_df"]) == 0:
326
+ st.warning("No projects found please create a project in Home page")
327
+ st.stop()
328
+
329
+ else:
330
+
331
+ try:
332
+ st.session_state["project_name"] = (
333
+ st.session_state["project_summary_df"]
334
+ .loc[
335
+ st.session_state["project_summary_df"]["Project Number"]
336
+ == st.session_state["project_summary_df"].iloc[0, 0],
337
+ "Project Name",
338
+ ]
339
+ .values[0]
340
+ ) # fetching project name from project number stored in summary df
341
+
342
+ project_dct_query = f"""
343
+
344
+ SELECT pkl_obj FROM mmo_project_meta_data WHERE prj_id = ? AND file_nam=?;
345
+
346
+ """
347
+ # Execute the query and retrieve the result
348
+
349
+ project_number = int(st.session_state["project_summary_df"].iloc[0, 0])
350
+
351
+ st.session_state["project_number"] = project_number
352
+
353
+ project_dct_retrieved = query_excecuter_postgres(
354
+ project_dct_query,
355
+ db_cred,
356
+ params=(project_number, "project_dct"),
357
+ insert=False,
358
+ )
359
+ # retrieves project dict (meta data) stored in db
360
+
361
+ st.session_state["project_dct"] = pickle.loads(
362
+ project_dct_retrieved[0][0]
363
+ ) # converting bytes data to original objet using pickle
364
+ ensure_project_dct_structure(
365
+ st.session_state["project_dct"], default_dct
366
+ )
367
+
368
+ st.success("Project Loded")
369
+ st.rerun()
370
+
371
+ except Exception as e:
372
+
373
+ st.write(
374
+ "Failed to load project meta data from db please create new project!"
375
+ )
376
+ st.stop()
377
+
378
+
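+ # Upsert: update the pickled artifact for (prj_id, file_nam) if a row already exists, otherwise insert a new metadata row.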
379
+ def update_db(prj_id, page_nam, file_nam, pkl_obj, resp_mtrc="", schema=""):
380
+
381
+ # Check if an entry already exists
382
+
383
+ check_query = f"""
384
+ SELECT 1 FROM mmo_project_meta_data
385
+ WHERE prj_id = ? AND file_nam =?;
386
+ """
387
+
388
+ check_params = (prj_id, file_nam)
389
+ result = query_excecuter_postgres(
390
+ check_query, db_cred, params=check_params, insert=False
391
+ )
392
+
393
+ # If entry exists, perform an update
394
+ if result is not None and result:
395
+
396
+ update_query = f"""
397
+ UPDATE mmo_project_meta_data
398
+ SET file_nam = ?, pkl_obj = ?, page_nam=? ,updt_dt_tm = datetime('now')
399
+
400
+ WHERE prj_id = ? AND file_nam = ?;
401
+ """
402
+
403
+ update_params = (file_nam, pkl_obj, page_nam, prj_id, file_nam)
404
+
405
+ query_excecuter_postgres(
406
+ update_query, db_cred, params=update_params, insert=True
407
+ )
408
+
409
+ # If entry does not exist, perform an insert
410
+ else:
411
+
412
+ insert_query = f"""
413
+ INSERT INTO mmo_project_meta_data
414
+ (prj_id, page_nam, file_nam, pkl_obj,crte_by_uid, crte_dt_tm, updt_dt_tm)
415
+ VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'));
416
+ """
417
+
418
+ insert_params = (
419
+ prj_id,
420
+ page_nam,
421
+ file_nam,
422
+ pkl_obj,
423
+ st.session_state["emp_id"],
424
+ )
425
+
426
+ query_excecuter_postgres(
427
+ insert_query, db_cred, params=insert_params, insert=True
428
+ )
429
+
430
+ # st.success(f"Inserted project meta data for project {prj_id}, page {page_nam}")
431
+
432
+
433
+ def retrieve_pkl_object(prj_id, page_nam, file_nam, schema=""):
434
+
435
+ query = f"""
436
+ SELECT pkl_obj FROM mmo_project_meta_data
437
+ WHERE prj_id = ? AND page_nam = ? AND file_nam = ?;
438
+ """
439
+
440
+ params = (prj_id, page_nam, file_nam)
441
+ result = query_excecuter_postgres(
442
+ query, db_cred=db_cred, params=params, insert=False
443
+ )
444
+
445
+ if result and result[0] and result[0][0]:
446
+ pkl_obj = result[0][0]
447
+ # Deserialize the pickle object
448
+ return pickle.loads(pkl_obj)
449
+ else:
450
+ return None
451
+
452
+
453
+ def validate_text(input_text):
454
+
455
+ # Check the length of the text
456
+ if len(input_text) < 2:
457
+ return False, "Input should be at least 2 characters long."
458
+ if len(input_text) > 30:
459
+ return False, "Input should not exceed 30 characters."
460
+
461
+ # Check if the text contains only allowed characters
462
+ if not re.match(r"^[A-Za-z0-9_]+$", input_text):
463
+ return (
464
+ False,
465
+ "Input contains invalid characters. Only letters, numbers and underscores are allowed.",
466
+ )
467
+
468
+ return True, "Input is valid."
469
+
470
+
471
+ def delete_entries(prj_id, page_names, db_cred=None, schema=None):
472
+ """
473
+ Deletes all entries from the project_meta_data table based on prj_id and a list of page names.
474
+
475
+ Parameters:
476
+ prj_id (int): The project ID.
477
+ page_names (list): A list of page names.
478
+ db_cred (dict): Database credentials with keys 'dbname', 'user', 'password', 'host', 'port'.
479
+ schema (str): The schema name.
480
+ """
481
+ # Create placeholders for each page name in the list
482
+ placeholders = ", ".join(["?"] * len(page_names))
483
+ query = f"""
484
+ DELETE FROM mmo_project_meta_data
485
+ WHERE prj_id = ? AND page_nam IN ({placeholders});
486
+ """
487
+
488
+ # Combine prj_id and page_names into one list of parameters
489
+ params = (prj_id, *page_names)
490
+
491
+ query_excecuter_postgres(query, db_cred, params=params, insert=True)
492
+
493
+
494
+ # st.success(f"Deleted entries for project {prj_id}, page {page_name}")
495
+ def store_hashed_password(
496
+ user_id,
497
+ plain_text_password,
498
+ ):
499
+ """
500
+ Hashes a plain text password using bcrypt, converts it to a UTF-8 string, and stores it as text.
501
+
502
+ Parameters:
503
+ plain_text_password (str): The plain text password to be hashed.
504
+ db_cred (dict): The database credentials including dbname, user, password, host, and port.
505
+ """
506
+ # Hash the plain text password
507
+ hashed_password = bcrypt.hashpw(
508
+ plain_text_password.encode("utf-8"), bcrypt.gensalt()
509
+ )
510
+
511
+ # Convert the byte string to a regular string for storage
512
+ hashed_password_str = hashed_password.decode("utf-8")
513
+
514
+ # SQL query to update the pswrd_key for the specified user_id
515
+ query = f"""
516
+ UPDATE mmo_users
517
+ SET pswrd_key = ?
518
+ WHERE emp_id = ?;
519
+ """
520
+
521
+ # Execute the query using the existing query_excecuter_postgres function
522
+ query_excecuter_postgres(
523
+ query=query, db_cred=db_cred, params=(hashed_password_str, user_id), insert=True
524
+ )
525
+
526
+
527
+ def verify_password(user_id, plain_text_password):
528
+ """
529
+ Verifies the plain text password against the stored hashed password for the specified user_id.
530
+
531
+ Parameters:
532
+ user_id (int): The ID of the user whose password is being verified.
533
+ plain_text_password (str): The plain text password to verify.
534
535
+ """
536
+ # SQL query to retrieve the hashed password for the user_id
537
+ query = f"""
538
+ SELECT pswrd_key FROM mmo_users WHERE emp_id = ?;
539
+ """
540
+
541
+ # Execute the query using the existing query_excecuter_postgres function
542
+ result = query_excecuter_postgres(
543
+ query=query, db_cred=db_cred, params=(user_id,), insert=False
544
+ )
545
+
546
+ if result:
547
+
548
+ stored_hashed_password_str = result[0][0]
549
+ # Convert the stored string back to bytes
550
+ stored_hashed_password = stored_hashed_password_str.encode("utf-8")
551
+
552
+ if bcrypt.checkpw(plain_text_password.encode("utf-8"), stored_hashed_password):
553
+
554
+ return True
555
+ else:
556
+
557
+ return False
558
+ else:
559
+
560
+ return False
561
+
562
+
563
+ def update_password_in_db(user_id, plain_text_password):
564
+ """
565
+ Hashes the plain text password and updates the `pswrd_key`
566
+ column for the given `emp_id` in the `mmo_users` table.
567
+
568
+ Parameters:
569
+ user_id: The employee ID whose password needs to be updated.
+ plain_text_password (str): The plain text password to be hashed and stored.
572
+ """
573
+ # Hash the plain text password using bcrypt
574
+ hashed_password = bcrypt.hashpw(
575
+ plain_text_password.encode("utf-8"), bcrypt.gensalt()
576
+ )
577
+
578
+ # Convert the hashed password from bytes to a string for storage
579
+ hashed_password_str = hashed_password.decode("utf-8")
580
+
581
+ # SQL query to update the password in the database
582
+ query = f"""
583
+ UPDATE mmo_users
584
+ SET pswrd_key = ?
585
+ WHERE emp_id = ?
586
+ """
587
+
588
+ # Parameters for the query
589
+ params = (hashed_password_str, user_id)
590
+
591
+ # Execute the query using the query_excecuter_postgres function
592
+ query_excecuter_postgres(query, db_cred, params=params, insert=True)
593
+
594
+
595
+ def is_pswrd_flag_set(user_id):
596
+ query = f"""
597
+ SELECT pswrd_flag
598
+ FROM mmo_users
599
+ WHERE emp_id = ?;
600
+ """
601
+
602
+ # Execute the query
603
+ result = query_excecuter_postgres(query, db_cred, params=(user_id,), insert=False)
604
+
605
+ # Return True if the flag is 1, otherwise return False
606
+ if result and result[0][0] == 1:
607
+ return True
608
+ else:
609
+ return False
610
+
611
+
612
+ def set_pswrd_flag(user_id):
613
+ query = f"""
614
+ UPDATE mmo_users
615
+ SET pswrd_flag = 1
616
+ WHERE emp_id = ?;
617
+ """
618
+
619
+ # Execute the update query
620
+ query_excecuter_postgres(query, db_cred, params=(user_id,), insert=True)
621
+
622
+
623
+ def retrieve_pkl_object_without_warning(prj_id, page_nam, file_nam, schema):
624
+
625
+ query = f"""
626
+ SELECT pkl_obj FROM mmo_project_meta_data
627
+ WHERE prj_id = ? AND page_nam = ? AND file_nam = ?;
628
+ """
629
+
630
+ params = (prj_id, page_nam, file_nam)
631
+ result = query_excecuter_postgres(
632
+ query, db_cred=db_cred, params=params, insert=False
633
+ )
634
+
635
+ if result and result[0] and result[0][0]:
636
+ pkl_obj = result[0][0]
637
+ # Deserialize the pickle object
638
+ return pickle.loads(pkl_obj)
639
+ else:
640
+ # st.warning(
641
+ # "Pickle object not found for the given project ID, page name, and file name."
642
+ # )
643
+ return None
644
+
645
+
646
+ color_palette = [
647
+ "#F3F3F0",
648
+ "#5E7D7E",
649
+ "#2FA1FF",
650
+ "#00EDED",
651
+ "#00EAE4",
652
+ "#304550",
653
+ "#EDEBEB",
654
+ "#7FBEFD",
655
+ "#003059",
656
+ "#A2F3F3",
657
+ "#E1D6E2",
658
+ "#B6B6B6",
659
+ ]
660
+
661
+
662
+ CURRENCY_INDICATOR = "$"
663
+
664
+
665
+ # database_file = r"DB/User.db"
666
+
667
+ # conn = sqlite3.connect(database_file, check_same_thread=False) # connection with sql db
668
+ # c = conn.cursor()
669
+
670
+
671
+ # def load_authenticator():
672
+ # with open("config.yaml") as file:
673
+ # config = yaml.load(file, Loader=SafeLoader)
674
+ # st.session_state["config"] = config
675
+ # authenticator = stauth.Authenticate(
676
+ # credentials=config["credentials"],
677
+ # cookie_name=config["cookie"]["name"],
678
+ # key=config["cookie"]["key"],
679
+ # cookie_expiry_days=config["cookie"]["expiry_days"],
680
+ # preauthorized=config["preauthorized"],
681
+ # )
682
+ # st.session_state["authenticator"] = authenticator
683
+ # return authenticator
684
+
685
+
686
+ # Authentication
687
+ # def authenticator():
688
+ # for k, v in st.session_state.items():
689
+ # if k not in ["logout", "login", "config"] and not k.startswith(
690
+ # "FormSubmitter"
691
+ # ):
692
+ # st.session_state[k] = v
693
+ # with open("config.yaml") as file:
694
+ # config = yaml.load(file, Loader=SafeLoader)
695
+ # st.session_state["config"] = config
696
+ # authenticator = stauth.Authenticate(
697
+ # config["credentials"],
698
+ # config["cookie"]["name"],
699
+ # config["cookie"]["key"],
700
+ # config["cookie"]["expiry_days"],
701
+ # config["preauthorized"],
702
+ # )
703
+ # st.session_state["authenticator"] = authenticator
704
+ # name, authentication_status, username = authenticator.login(
705
+ # "Login", "main"
706
+ # )
707
+ # auth_status = st.session_state.get("authentication_status")
708
+
709
+ # if auth_status == True:
710
+ # authenticator.logout("Logout", "main")
711
+ # is_state_initiaized = st.session_state.get("initialized", False)
712
+
713
+ # if not is_state_initiaized:
714
+
715
+ # if "session_name" not in st.session_state:
716
+ # st.session_state["session_name"] = None
717
+
718
+ # return name
719
+
720
+
721
+ # def authentication():
722
+ # with open("config.yaml") as file:
723
+ # config = yaml.load(file, Loader=SafeLoader)
724
+
725
+ # authenticator = stauth.Authenticate(
726
+ # config["credentials"],
727
+ # config["cookie"]["name"],
728
+ # config["cookie"]["key"],
729
+ # config["cookie"]["expiry_days"],
730
+ # config["preauthorized"],
731
+ # )
732
+
733
+ # name, authentication_status, username = authenticator.login(
734
+ # "Login", "main"
735
+ # )
736
+ # return authenticator, name, authentication_status, username
737
+
738
+
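+ # Client-side navigation helper for multipage Streamlit apps: injects JS that clicks the sidebar link whose href ends with the target page name.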
739
+ def nav_page(page_name, timeout_secs=3):
740
+ nav_script = """
741
+ <script type="text/javascript">
742
+ function attempt_nav_page(page_name, start_time, timeout_secs) {
743
+ var links = window.parent.document.getElementsByTagName("a");
744
+ for (var i = 0; i < links.length; i++) {
745
+ if (links[i].href.toLowerCase().endsWith("/" + page_name.toLowerCase())) {
746
+ links[i].click();
747
+ return;
748
+ }
749
+ }
750
+ var elapsed = new Date() - start_time;
+ if (elapsed < timeout_secs * 1000) {
752
+ setTimeout(attempt_nav_page, 100, page_name, start_time, timeout_secs);
753
+ } else {
754
+ alert("Unable to navigate to page '" + page_name + "' after " + timeout_secs + " second(s).");
755
+ }
756
+ }
757
+ window.addEventListener("load", function() {
758
+ attempt_nav_page("%s", new Date(), %d);
759
+ });
760
+ </script>
761
+ """ % (
762
+ page_name,
763
+ timeout_secs,
764
+ )
765
+ html(nav_script)
766
+
767
+
768
+ # def load_local_css(file_name):
769
+ # with open(file_name) as f:
770
+ # st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
771
+
772
+
773
+ # def set_header():
774
+ # return st.markdown(f"""<div class='main-header'>
775
+ # <h1>MMM LiME</h1>
776
+ # <img src="https://assets-global.website-files.com/64c8fffb0e95cbc525815b79/64df84637f83a891c1473c51_Vector%20(Stroke).svg ">
777
+ # </div>""", unsafe_allow_html=True)
778
+
779
+ path = os.path.dirname(__file__)
780
+
781
+ file_ = open(f"{path}/logo.png", "rb")
782
+
783
+ contents = file_.read()
784
+
785
+ data_url = base64.b64encode(contents).decode("utf-8")
786
+
787
+ file_.close()
788
+
789
+
790
+ DATA_PATH = "./data"
791
+
792
+ IMAGES_PATH = "./data/images_224_224"
793
+
794
+
795
+ def load_local_css(file_name):
796
+
797
+ with open(file_name) as f:
798
+
799
+ st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
800
+
801
+
802
+ # def set_header():
803
+
804
+ # return st.markdown(f"""<div class='main-header'>
805
+
806
+ # <h1>H & M Recommendations</h1>
807
+
808
+ # <img src="data:image;base64,{data_url}", alt="Logo">
809
+
810
+ # </div>""", unsafe_allow_html=True)
811
+ path1 = os.path.dirname(__file__)
812
+
813
+ # file_1 = open(f"{path}/willbank.png", "rb")
814
+
815
+ # contents1 = file_1.read()
816
+
817
+ # data_url1 = base64.b64encode(contents1).decode("utf-8")
818
+
819
+ # file_1.close()
820
+
821
+
822
+ DATA_PATH1 = "./data"
823
+
824
+ IMAGES_PATH1 = "./data/images_224_224"
825
+
826
+
827
+ def set_header():
828
+ return st.markdown(
829
+ f"""<div class='main-header'>
830
+ <!-- <h1></h1> -->
831
+ <div >
832
+ <img class='blend-logo' src="data:image;base64,{data_url}", alt="Logo">
833
+ </div>""",
834
+ unsafe_allow_html=True,
835
+ )
836
+
837
+
838
+ # def set_header():
839
+ # logo_path = "./path/to/your/local/LIME_logo.png" # Replace with the actual file path
840
+ # text = "LiME"
841
+ # return st.markdown(f"""<div class='main-header'>
842
+ # <img src="data:image/png;base64,{data_url}" alt="Logo" style="float: left; margin-right: 10px; width: 100px; height: auto;">
843
+ # <h1>{text}</h1>
844
+ # </div>""", unsafe_allow_html=True)
845
+
846
+
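+ # Generalized logistic response curve: K is the saturation ceiling, b and a control shape/steepness, x0 is the inflection point.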
847
+ def s_curve(x, K, b, a, x0):
848
+ return K / (1 + b * np.exp(-a * (x - x0)))
849
+
850
+
851
+ def panel_level(input_df, date_column="Date"):
852
+ # Ensure 'Date' is set as the index
853
+ if date_column not in input_df.index.names:
854
+ input_df = input_df.set_index(date_column)
855
+
856
+ # Select numeric columns only (excluding 'Date' since it's now the index)
857
+ numeric_columns_df = input_df.select_dtypes(include="number")
858
+
859
+ # Group by 'Date' (which is the index) and sum the numeric columns
860
+ aggregated_df = numeric_columns_df.groupby(input_df.index).sum()
861
+
862
+ # Reset the index to bring the 'Date' column back as a regular column
863
+ aggregated_df = aggregated_df.reset_index()
864
+
865
+ return aggregated_df
866
+
867
+
868
+ def fetch_actual_data(
869
+ panel=None,
870
+ target_file="Overview_data_test.xlsx",
871
+ updated_rcs=None,
872
+ metrics=None,
873
+ ):
874
+ excel = pd.read_excel(Path(target_file), sheet_name=None)
875
+
876
+ # Extract dataframes for raw data, spend input, and contribution MMM
877
+ raw_df = excel["RAW DATA MMM"]
878
+ spend_df = excel["SPEND INPUT"]
879
+ contri_df = excel["CONTRIBUTION MMM"]
880
+
881
+ # Check if the panel is not None
882
+ if panel is not None and panel != "Aggregated":
883
+ raw_df = raw_df[raw_df["Panel"] == panel].drop(columns=["Panel"])
884
+ spend_df = spend_df[spend_df["Panel"] == panel].drop(columns=["Panel"])
885
+ contri_df = contri_df[contri_df["Panel"] == panel].drop(columns=["Panel"])
886
+ elif panel == "Aggregated":
887
+ raw_df = panel_level(raw_df, date_column="Date")
888
+ spend_df = panel_level(spend_df, date_column="Week")
889
+ contri_df = panel_level(contri_df, date_column="Date")
890
+
891
+ # Revenue_df = excel['Revenue']
892
+
893
+ # Collect auto-generated unnamed columns so they can be excluded
+ unnamed_cols = [col for col in raw_df.columns if col.lower().startswith("unnamed")]
+ ## remove seasonalities, indices, and other control variables
896
+
897
+ exclude_columns = [
898
+ "Date",
899
+ "Region",
900
+ "Controls_Grammarly_Index_SeasonalAVG",
901
+ "Controls_Quillbot_Index",
902
+ "Daily_Positive_Outliers",
903
+ "External_RemoteClass_Index",
904
+ "Intervals ON 20190520-20190805 | 20200518-20200803 | 20210517-20210802",
905
+ "Intervals ON 20190826-20191209 | 20200824-20201207 | 20210823-20211206",
906
+ "Intervals ON 20201005-20201019",
907
+ "Promotion_PercentOff",
908
+ "Promotion_TimeBased",
909
+ "Seasonality_Indicator_Chirstmas",
910
+ "Seasonality_Indicator_NewYears_Days",
911
+ "Seasonality_Indicator_Thanksgiving",
912
+ "Trend 20200302 / 20200803",
913
+ ] + unnamed_cols
914
+
915
+ raw_df["Date"] = pd.to_datetime(raw_df["Date"])
916
+ contri_df["Date"] = pd.to_datetime(contri_df["Date"])
917
+ input_df = raw_df.sort_values(by="Date")
918
+ output_df = contri_df.sort_values(by="Date")
919
+ spend_df["Week"] = pd.to_datetime(
920
+ spend_df["Week"], format="%Y-%m-%d", errors="coerce"
921
+ )
922
+ spend_df.sort_values(by="Week", inplace=True)
923
+
924
+ # spend_df['Week'] = pd.to_datetime(spend_df['Week'], errors='coerce')
925
+ # spend_df = spend_df.sort_values(by='Week')
926
+
927
+ channel_list = [col for col in input_df.columns if col not in exclude_columns]
928
+ channel_list = list(set(channel_list) - set(["fb_level_achieved_tier_1", "ga_app"]))
929
+
930
+ infeasible_channels = [
931
+ c
932
+ for c in contri_df.select_dtypes(include=["float", "int"]).columns
933
+ if contri_df[c].sum() <= 0
934
+ ]
935
+ # st.write(channel_list)
936
+ channel_list = list(set(channel_list) - set(infeasible_channels))
937
+
938
+ upper_limits = {}
939
+ output_cols = []
940
+ actual_output_dic = {}
941
+ actual_input_dic = {}
942
+
943
+ for inp_col in channel_list:
944
+ # st.write(inp_col)
945
+ spends = input_df[inp_col].values
946
+ x = spends.copy()
947
+ # upper limit for penalty
948
+ upper_limits[inp_col] = 2 * x.max()
949
+
950
+ # contribution
951
+ # out_col = [_col for _col in output_df.columns if _col.startswith(inp_col)][0]
952
+ out_col = inp_col
953
+ y = output_df[out_col].values.copy()
954
+ actual_output_dic[inp_col] = y.copy()
955
+ actual_input_dic[inp_col] = x.copy()
956
+ ##output cols aggregation
957
+ output_cols.append(out_col)
958
+
959
+ return pd.DataFrame(actual_input_dic), pd.DataFrame(actual_output_dic)
960
+
961
+
962
+ # Function to initialize model results data
963
+ def initialize_data(panel=None, metrics=None):
964
+ # Extract dataframes for raw data, spend input, and contribution data
965
+ raw_df = st.session_state["project_dct"]["current_media_performance"][
966
+ "model_outputs"
967
+ ][metrics]["raw_data"].copy()
968
+ spend_df = st.session_state["project_dct"]["current_media_performance"][
969
+ "model_outputs"
970
+ ][metrics]["spends_data"].copy()
971
+ contribution_df = st.session_state["project_dct"]["current_media_performance"][
972
+ "model_outputs"
973
+ ][metrics]["contribution_data"].copy()
974
+
975
+ # Check if 'Panel' or 'panel' is in the columns
976
+ panel_column = None
977
+ if "Panel" in raw_df.columns:
978
+ panel_column = "Panel"
979
+ elif "panel" in raw_df.columns:
980
+ panel_column = "panel"
981
+
982
+ # Filter data by panel if provided
983
+ if panel and panel.lower() != "aggregated":
984
+ raw_df = raw_df[raw_df[panel_column] == panel].drop(columns=[panel_column])
985
+ spend_df = spend_df[spend_df[panel_column] == panel].drop(
986
+ columns=[panel_column]
987
+ )
988
+ contribution_df = contribution_df[contribution_df[panel_column] == panel].drop(
989
+ columns=[panel_column]
990
+ )
991
+ else:
992
+ raw_df = panel_level(raw_df, date_column="Date")
993
+ spend_df = panel_level(spend_df, date_column="Date")
994
+ contribution_df = panel_level(contribution_df, date_column="Date")
995
+
996
+ # Remove unnecessary columns
997
+ unnamed_cols = [col for col in raw_df.columns if col.lower().startswith("unnamed")]
998
+ exclude_columns = ["Date"] + unnamed_cols
999
+
1000
+ # Convert Date columns to datetime
1001
+ for df in [raw_df, spend_df, contribution_df]:
1002
+ df["Date"] = pd.to_datetime(df["Date"], format="%Y-%m-%d", errors="coerce")
1003
+
1004
+ # Sort data by Date
1005
+ input_df = raw_df.sort_values(by="Date")
1006
+ contribution_df = contribution_df.sort_values(by="Date")
1007
+ spend_df.sort_values(by="Date", inplace=True)
1008
+
1009
+ # Extract channels excluding unwanted columns
1010
+ channel_list = [col for col in input_df.columns if col not in exclude_columns]
1011
+
1012
+ # Filter out channels with non-positive contributions
1013
+ negative_contributions = [
1014
+ col
1015
+ for col in contribution_df.select_dtypes(include=["float", "int"]).columns
1016
+ if contribution_df[col].sum() <= 0
1017
+ ]
1018
+ channel_list = list(set(channel_list) - set(negative_contributions))
1019
+
1020
+ # Initialize dictionaries for metrics and response curves
1021
+ response_curves, mapes, rmses, upper_limits = {}, {}, {}, {}
1022
+ r2_scores, powers, conversion_rates, actual_output, actual_input = (
1023
+ {},
1024
+ {},
1025
+ {},
1026
+ {},
1027
+ {},
1028
+ )
1029
+ channels = {}
1030
+ sales = None
1031
+ dates = input_df["Date"].values
1032
+
1033
+ # Fit s-curve for each channel
1034
+ for channel in channel_list:
1035
+ spends = input_df[channel].values
1036
+ x = spends.copy()
1037
+ upper_limits[channel] = 2 * x.max()
1038
+
1039
+ # Get corresponding output column
1040
+ output_col = [
1041
+ _col for _col in contribution_df.columns if _col.startswith(channel)
1042
+ ][0]
1043
+ y = contribution_df[output_col].values.copy()
1044
+ actual_output[channel] = y.copy()
1045
+ actual_input[channel] = x.copy()
1046
+
1047
+ # Scale input data
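+ # power = ceil(log10(max spend)) - 3, so dividing by 10**power brings the series maximum
+ # into roughly the hundreds; presumably this keeps curve_fit numerically stable.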
1048
+ power = np.ceil(np.log(x.max()) / np.log(10)) - 3
1049
+ if power >= 0:
1050
+ x = x / 10**power
1051
+ x, y = x.astype("float64"), y.astype("float64")
1052
+
1053
+ # Set bounds for curve fitting
1054
+ if y.max() <= 0.01:
1055
+ bounds = (
1056
+ (0, 0, 0, 0),
1057
+ (3 * 0.01, 1000, 1, x.max() if x.max() > 0 else 0.01),
1058
+ )
1059
+ else:
1060
+ bounds = ((0, 0, 0, 0), (3 * y.max(), 1000, 1, x.max()))
1061
+
1062
+ # Set y to 0 where x is 0
1063
+ y[x == 0] = 0
1064
+
1065
+ # Fit s-curve and calculate metrics
1066
+ # params, _ = curve_fit(
1067
+ # s_curve,
1068
+ # x
1069
+ # y,
1070
+ # p0=(2 * y.max(), 0.01, 1e-5, x.max()),
1071
+ # bounds=bounds,
1072
+ # maxfev=int(1e6),
1073
+ # )
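+ # Note: x and y are padded below with an equal number of (0, 0) points, which weights the
+ # fit toward passing through the origin (zero spend contributing zero).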
1074
+ params, _ = curve_fit(
1075
+ s_curve,
1076
+ list(x) + [0] * len(x),
1077
+ list(y) + [0] * len(y),
1078
+ p0=(2 * y.max(), 0.01, 1e-5, x.max()),
1079
+ bounds=bounds,
1080
+ maxfev=int(1e6),
1081
+ )
1082
+
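+ # Fit diagnostics: MAPE is computed against y clipped at a minimum of 1 to avoid division
+ # by zero for weeks with no contribution.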
1083
+ mape = (100 * abs(1 - s_curve(x, *params) / y.clip(min=1))).mean()
1084
+ rmse = np.sqrt(((y - s_curve(x, *params)) ** 2).mean())
1085
+ r2_score_ = r2_score(y, s_curve(x, *params))
1086
+
1087
+ # Store metrics and parameters
1088
+ response_curves[channel] = {
1089
+ "K": params[0],
1090
+ "b": params[1],
1091
+ "a": params[2],
1092
+ "x0": params[3],
1093
+ }
1094
+ mapes[channel] = mape
1095
+ rmses[channel] = rmse
1096
+ r2_scores[channel] = r2_score_
1097
+ powers[channel] = power
1098
+
1099
+ conversion_rate = spend_df[channel].sum() / max(input_df[channel].sum(), 1e-9)
1100
+ conversion_rates[channel] = conversion_rate
1101
+ correction = y - s_curve(x, *params)
1102
+
1103
+ # Initialize Channel object
1104
+ channel_obj = Channel(
1105
+ name=channel,
1106
+ dates=dates,
1107
+ spends=spends,
1108
+ conversion_rate=conversion_rate,
1109
+ response_curve_type="s-curve",
1110
+ response_curve_params={
1111
+ "K": params[0],
1112
+ "b": params[1],
1113
+ "a": params[2],
1114
+ "x0": params[3],
1115
+ },
1116
+ bounds=np.array([-10, 10]),
1117
+ correction=correction,
1118
+ )
1119
+ channels[channel] = channel_obj
1120
+ if sales is None:
1121
+ sales = channel_obj.actual_sales
1122
+ else:
1123
+ sales += channel_obj.actual_sales
1124
+
1125
+ # Calculate other contributions
1126
+ other_contributions = (
1127
+ contribution_df.drop(columns=[*response_curves.keys()])
1128
+ .sum(axis=1, numeric_only=True)
1129
+ .values
1130
+ )
1131
+
1132
+ # Initialize Scenario object
1133
+ scenario = Scenario(
1134
+ name="default",
1135
+ channels=channels,
1136
+ constant=other_contributions,
1137
+ correction=np.array([]),
1138
+ )
1139
+
1140
+ # Set session state variables
1141
+ st.session_state.update(
1142
+ {
1143
+ "initialized": True,
1144
+ "actual_df": input_df,
1145
+ "raw_df": raw_df,
1146
+ "contri_df": contribution_df,
1147
+ "default_scenario_dict": class_to_dict(scenario),
1148
+ "scenario": scenario,
1149
+ "channels_list": channel_list,
1150
+ "optimization_channels": {
1151
+ channel_name: False for channel_name in channel_list
1152
+ },
1153
+ "rcs": response_curves.copy(),
1154
+ "powers": powers,
1155
+ "actual_contribution_df": pd.DataFrame(actual_output),
1156
+ "actual_input_df": pd.DataFrame(actual_input),
1157
+ "xlsx_buffer": io.BytesIO(),
1158
+ "saved_scenarios": (
1159
+ pickle.load(open("../saved_scenarios.pkl", "rb"))
1160
+ if Path("../saved_scenarios.pkl").exists()
1161
+ else OrderedDict()
1162
+ ),
1163
+ "disable_download_button": True,
1164
+ }
1165
+ )
1166
+
1167
+ for channel in channels.values():
1168
+ st.session_state[channel.name] = numerize(
1169
+ channel.actual_total_spends * channel.conversion_rate, 1
1170
+ )
1171
+
1172
+ # Prepare response curve data for output
1173
+ response_curve_data = {}
1174
+ for channel, params in st.session_state["rcs"].items():
1175
+ x = st.session_state["actual_input_df"][channel].values.astype(float)
1176
+ y = st.session_state["actual_contribution_df"][channel].values.astype(float)
1177
+ power = float(np.ceil(np.log(max(x)) / np.log(10)) - 3)
1178
+ x_plot = list(np.linspace(0, 5 * max(x), 100))
1179
+
1180
+ response_curve_data[channel] = {
1181
+ "K": float(params["K"]),
1182
+ "b": float(params["b"]),
1183
+ "a": float(params["a"]),
1184
+ "x0": float(params["x0"]),
1185
+ "power": power,
1186
+ "x": list(x),
1187
+ "y": list(y),
1188
+ "x_plot": x_plot,
1189
+ }
1190
+
1191
+ return response_curve_data, scenario
1192
+
1193
+
1194
+ # def initialize_data(panel=None, metrics=None):
1195
+ # # Extract dataframes for raw data, spend input, and contribution data
1196
+ # raw_df = st.session_state["project_dct"]["current_media_performance"][
1197
+ # "model_outputs"
1198
+ # ][metrics]["raw_data"]
1199
+ # spend_df = st.session_state["project_dct"]["current_media_performance"][
1200
+ # "model_outputs"
1201
+ # ][metrics]["spends_data"]
1202
+ # contri_df = st.session_state["project_dct"]["current_media_performance"][
1203
+ # "model_outputs"
1204
+ # ][metrics]["contribution_data"]
1205
+
1206
+ # # Check if the panel is not None
1207
+ # if panel is not None and panel.lower() != "aggregated":
1208
+ # raw_df = raw_df[raw_df["Panel"] == panel].drop(columns=["Panel"])
1209
+ # spend_df = spend_df[spend_df["Panel"] == panel].drop(columns=["Panel"])
1210
+ # contri_df = contri_df[contri_df["Panel"] == panel].drop(columns=["Panel"])
1211
+ # elif panel.lower() == "aggregated":
1212
+ # raw_df = panel_level(raw_df, date_column="Date")
1213
+ # spend_df = panel_level(spend_df, date_column="Date")
1214
+ # contri_df = panel_level(contri_df, date_column="Date")
1215
+
1216
+ # ## remove sesonalities, indices etc ...
1217
+ # unnamed_cols = [col for col in raw_df.columns if col.lower().startswith("unnamed")]
1218
+
1219
+ # ## remove sesonalities, indices etc ...
1220
+ # exclude_columns = ["Date"] + unnamed_cols
1221
+
1222
+ # raw_df["Date"] = pd.to_datetime(raw_df["Date"], format="%Y-%m-%d", errors="coerce")
1223
+ # contri_df["Date"] = pd.to_datetime(
1224
+ # contri_df["Date"], format="%Y-%m-%d", errors="coerce"
1225
+ # )
1226
+ # spend_df["Date"] = pd.to_datetime(
1227
+ # spend_df["Date"], format="%Y-%m-%d", errors="coerce"
1228
+ # )
1229
+
1230
+ # input_df = raw_df.sort_values(by="Date")
1231
+ # output_df = contri_df.sort_values(by="Date")
1232
+ # spend_df.sort_values(by="Date", inplace=True)
1233
+
1234
+ # channel_list = [col for col in input_df.columns if col not in exclude_columns]
1235
+
1236
+ # negative_contribution = [
1237
+ # c
1238
+ # for c in contri_df.select_dtypes(include=["float", "int"]).columns
1239
+ # if contri_df[c].sum() <= 0
1240
+ # ]
1241
+ # channel_list = list(set(channel_list) - set(negative_contribution))
1242
+
1243
+ # response_curves = {}
1244
+ # mapes = {}
1245
+ # rmses = {}
1246
+ # upper_limits = {}
1247
+ # powers = {}
1248
+ # r2 = {}
1249
+ # conv_rates = {}
1250
+ # output_cols = []
1251
+ # channels = {}
1252
+ # sales = None
1253
+ # dates = input_df.Date.values
1254
+ # actual_output_dic = {}
1255
+ # actual_input_dic = {}
1256
+
1257
+ # for inp_col in channel_list:
1258
+ # spends = input_df[inp_col].values
1259
+ # x = spends.copy()
1260
+ # # upper limit for penalty
1261
+ # upper_limits[inp_col] = 2 * x.max()
1262
+
1263
+ # # contribution
1264
+ # out_col = [_col for _col in output_df.columns if _col.startswith(inp_col)][0]
1265
+ # y = output_df[out_col].values.copy()
1266
+ # actual_output_dic[inp_col] = y.copy()
1267
+ # actual_input_dic[inp_col] = x.copy()
1268
+ # ##output cols aggregation
1269
+ # output_cols.append(out_col)
1270
+
1271
+ # ## scale the input
1272
+ # power = np.ceil(np.log(x.max()) / np.log(10)) - 3
1273
+ # if power >= 0:
1274
+ # x = x / 10**power
1275
+
1276
+ # x = x.astype("float64")
1277
+ # y = y.astype("float64")
1278
+
1279
+ # if y.max() <= 0.01:
1280
+ # if x.max() <= 0.0:
1281
+ # bounds = ((0, 0, 0, 0), (3 * 0.01, 1000, 1, 0.01))
1282
+
1283
+ # else:
1284
+ # bounds = ((0, 0, 0, 0), (3 * 0.01, 1000, 1, x.max()))
1285
+ # else:
1286
+ # bounds = ((0, 0, 0, 0), (3 * y.max(), 1000, 1, x.max()))
1287
+
1288
+ # params, _ = curve_fit(
1289
+ # s_curve,
1290
+ # x,
1291
+ # y,
1292
+ # p0=(2 * y.max(), 0.01, 1e-5, x.max()),
1293
+ # bounds=bounds,
1294
+ # maxfev=int(1e5),
1295
+ # )
1296
+ # mape = (100 * abs(1 - s_curve(x, *params) / y.clip(min=1))).mean()
1297
+ # rmse = np.sqrt(((y - s_curve(x, *params)) ** 2).mean())
1298
+ # r2_ = r2_score(y, s_curve(x, *params))
1299
+
1300
+ # response_curves[inp_col] = {
1301
+ # "K": params[0],
1302
+ # "b": params[1],
1303
+ # "a": params[2],
1304
+ # "x0": params[3],
1305
+ # }
1306
+
1307
+ # mapes[inp_col] = mape
1308
+ # rmses[inp_col] = rmse
1309
+ # r2[inp_col] = r2_
1310
+ # powers[inp_col] = power
1311
+
1312
+ # conv = spend_df[inp_col].sum() / max(input_df[inp_col].sum(), 1e-9)
1313
+ # conv_rates[inp_col] = conv
1314
+
1315
+ # correction = y - s_curve(x, *params)
1316
+
1317
+ # channel = Channel(
1318
+ # name=inp_col,
1319
+ # dates=dates,
1320
+ # spends=spends,
1321
+ # conversion_rate=conv_rates[inp_col],
1322
+ # response_curve_type="s-curve",
1323
+ # response_curve_params={
1324
+ # "K": params[0],
1325
+ # "b": params[1],
1326
+ # "a": params[2],
1327
+ # "x0": params[3],
1328
+ # },
1329
+ # bounds=np.array([-10, 10]),
1330
+ # correction=correction,
1331
+ # )
1332
+
1333
+ # channels[inp_col] = channel
1334
+ # if sales is None:
1335
+ # sales = channel.actual_sales
1336
+ # else:
1337
+ # sales += channel.actual_sales
1338
+
1339
+ # other_contributions = (
1340
+ # output_df.drop([*output_cols], axis=1).sum(axis=1, numeric_only=True).values
1341
+ # )
1342
+
1343
+ # scenario = Scenario(
1344
+ # name="default",
1345
+ # channels=channels,
1346
+ # constant=other_contributions,
1347
+ # correction=np.array([]),
1348
+ # )
1349
+
1350
+ # ## setting session variables
1351
+ # st.session_state["initialized"] = True
1352
+ # st.session_state["actual_df"] = input_df
1353
+ # st.session_state["raw_df"] = raw_df
1354
+ # st.session_state["contri_df"] = output_df
1355
+ # default_scenario_dict = class_to_dict(scenario)
1356
+ # st.session_state["default_scenario_dict"] = default_scenario_dict
1357
+ # st.session_state["scenario"] = scenario
1358
+ # st.session_state["channels_list"] = channel_list
1359
+ # st.session_state["optimization_channels"] = {
1360
+ # channel_name: False for channel_name in channel_list
1361
+ # }
1362
+ # st.session_state["rcs"] = response_curves.copy()
1363
+
1364
+ # st.session_state["powers"] = powers
1365
+ # st.session_state["actual_contribution_df"] = pd.DataFrame(actual_output_dic)
1366
+ # st.session_state["actual_input_df"] = pd.DataFrame(actual_input_dic)
1367
+
1368
+ # for channel in channels.values():
1369
+ # st.session_state[channel.name] = numerize(
1370
+ # channel.actual_total_spends * channel.conversion_rate, 1
1371
+ # )
1372
+
1373
+ # st.session_state["xlsx_buffer"] = io.BytesIO()
1374
+
1375
+ # if Path("../saved_scenarios.pkl").exists():
1376
+ # with open("../saved_scenarios.pkl", "rb") as f:
1377
+ # st.session_state["saved_scenarios"] = pickle.load(f)
1378
+ # else:
1379
+ # st.session_state["saved_scenarios"] = OrderedDict()
1380
+
1381
+ # # st.session_state["total_spends_change"] = 0
1382
+ # st.session_state["optimization_channels"] = {
1383
+ # channel_name: False for channel_name in channel_list
1384
+ # }
1385
+ # st.session_state["disable_download_button"] = True
1386
+
1387
+ # rcs_data = {}
1388
+ # for channel in st.session_state["rcs"]:
1389
+ # # Convert to native Python lists and types
1390
+ # x = list(st.session_state["actual_input_df"][channel].values.astype(float))
1391
+ # y = list(
1392
+ # st.session_state["actual_contribution_df"][channel].values.astype(float)
1393
+ # )
1394
+ # power = float(np.ceil(np.log(max(x)) / np.log(10)) - 3)
1395
+ # x_plot = list(np.linspace(0, 5 * max(x), 100))
1396
+
1397
+ # rcs_data[channel] = {
1398
+ # "K": float(st.session_state["rcs"][channel]["K"]),
1399
+ # "b": float(st.session_state["rcs"][channel]["b"]),
1400
+ # "a": float(st.session_state["rcs"][channel]["a"]),
1401
+ # "x0": float(st.session_state["rcs"][channel]["x0"]),
1402
+ # "power": power,
1403
+ # "x": x,
1404
+ # "y": y,
1405
+ # "x_plot": x_plot,
1406
+ # }
1407
+
1408
+ # return rcs_data, scenario
1409
+
1410
+
1411
+ # def initialize_data():
1412
+ # # fetch data from excel
1413
+ # output = pd.read_excel('data.xlsx',sheet_name=None)
1414
+ # raw_df = output['RAW DATA MMM']
1415
+ # contribution_df = output['CONTRIBUTION MMM']
1416
+ # Revenue_df = output['Revenue']
1417
+
1418
+ # ## channels to be shows
1419
+ # channel_list = []
1420
+ # for col in raw_df.columns:
1421
+ # if 'click' in col.lower() or 'spend' in col.lower() or 'imp' in col.lower():
1422
+ # channel_list.append(col)
1423
+ # else:
1424
+ # pass
1425
+
1426
+ # ## NOTE : Considered only Desktop spends for all calculations
1427
+ # acutal_df = raw_df[raw_df.Region == 'Desktop'].copy()
1428
+ # ## NOTE : Considered one year of data
1429
+ # acutal_df = acutal_df[acutal_df.Date>'2020-12-31']
1430
+ # actual_df = acutal_df.drop('Region',axis=1).sort_values(by='Date')[[*channel_list,'Date']]
1431
+
1432
+ # ##load response curves
1433
+ # with open('./grammarly_response_curves.json','r') as f:
1434
+ # response_curves = json.load(f)
1435
+
1436
+ # ## create channel dict for scenario creation
1437
+ # dates = actual_df.Date.values
1438
+ # channels = {}
1439
+ # rcs = {}
1440
+ # constant = 0.
1441
+ # for i,info_dict in enumerate(response_curves):
1442
+ # name = info_dict.get('name')
1443
+ # response_curve_type = info_dict.get('response_curve')
1444
+ # response_curve_params = info_dict.get('params')
1445
+ # rcs[name] = response_curve_params
1446
+ # if name != 'constant':
1447
+ # spends = actual_df[name].values
1448
+ # channel = Channel(name=name,dates=dates,
1449
+ # spends=spends,
1450
+ # response_curve_type=response_curve_type,
1451
+ # response_curve_params=response_curve_params,
1452
+ # bounds=np.array([-30,30]))
1453
+
1454
+ # channels[name] = channel
1455
+ # else:
1456
+ # constant = info_dict.get('value',0.) * len(dates)
1457
+
1458
+ # ## create scenario
1459
+ # scenario = Scenario(name='default', channels=channels, constant=constant)
1460
+ # default_scenario_dict = class_to_dict(scenario)
1461
+
1462
+
1463
+ # ## setting session variables
1464
+ # st.session_state['initialized'] = True
1465
+ # st.session_state['actual_df'] = actual_df
1466
+ # st.session_state['raw_df'] = raw_df
1467
+ # st.session_state['default_scenario_dict'] = default_scenario_dict
1468
+ # st.session_state['scenario'] = scenario
1469
+ # st.session_state['channels_list'] = channel_list
1470
+ # st.session_state['optimization_channels'] = {channel_name : False for channel_name in channel_list}
1471
+ # st.session_state['rcs'] = rcs
1472
+ # for channel in channels.values():
1473
+ # if channel.name not in st.session_state:
1474
+ # st.session_state[channel.name] = float(channel.actual_total_spends)
1475
+
1476
+ # if 'xlsx_buffer' not in st.session_state:
1477
+ # st.session_state['xlsx_buffer'] = io.BytesIO()
1478
+
1479
+ # ## for saving scenarios
1480
+ # if 'saved_scenarios' not in st.session_state:
1481
+ # if Path('../saved_scenarios.pkl').exists():
1482
+ # with open('../saved_scenarios.pkl','rb') as f:
1483
+ # st.session_state['saved_scenarios'] = pickle.load(f)
1484
+
1485
+ # else:
1486
+ # st.session_state['saved_scenarios'] = OrderedDict()
1487
+
1488
+ # if 'total_spends_change' not in st.session_state:
1489
+ # st.session_state['total_spends_change'] = 0
1490
+
1491
+ # if 'optimization_channels' not in st.session_state:
1492
+ # st.session_state['optimization_channels'] = {channel_name : False for channel_name in channel_list}
1493
+
1494
+ # if 'disable_download_button' not in st.session_state:
1495
+ # st.session_state['disable_download_button'] = True
1496
+
1497
+
1498
+ def create_channel_summary(scenario):
1499
+
1500
+ # Provided data
1501
+ data = {
1502
+ "Channel": [
1503
+ "Paid Search",
1504
+ "Ga will cid baixo risco",
1505
+ "Digital tactic others",
1506
+ "Fb la tier 1",
1507
+ "Fb la tier 2",
1508
+ "Paid social others",
1509
+ "Programmatic",
1510
+ "Kwai",
1511
+ "Indicacao",
1512
+ "Infleux",
1513
+ "Influencer",
1514
+ ],
1515
+ "Spends": [
1516
+ "$ 11.3K",
1517
+ "$ 155.2K",
1518
+ "$ 50.7K",
1519
+ "$ 125.4K",
1520
+ "$ 125.2K",
1521
+ "$ 105K",
1522
+ "$ 3.3M",
1523
+ "$ 47.5K",
1524
+ "$ 55.9K",
1525
+ "$ 632.3K",
1526
+ "$ 48.3K",
1527
+ ],
1528
+ "Revenue": [
1529
+ "558.0K",
1530
+ "3.5M",
1531
+ "5.2M",
1532
+ "3.1M",
1533
+ "3.1M",
1534
+ "2.1M",
1535
+ "20.8M",
1536
+ "1.6M",
1537
+ "728.4K",
1538
+ "22.9M",
1539
+ "4.8M",
1540
+ ],
1541
+ }
1542
+
1543
+ # Create DataFrame
1544
+ df = pd.DataFrame(data)
1545
+
1546
+ # Convert currency strings to numeric values
1547
+ df["Spends"] = (
1548
+ df["Spends"]
1549
+ .replace({r"\$": "", "K": "*1e3", "M": "*1e6"}, regex=True)
1550
+ .map(pd.eval)
1551
+ .astype(int)
1552
+ )
1553
+ df["Revenue"] = (
1554
+ df["Revenue"]
1555
+ .replace({r"\$": "", "K": "*1e3", "M": "*1e6"}, regex=True)
1556
+ .map(pd.eval)
1557
+ .astype(int)
1558
+ )
1559
+
1560
+ # Calculate ROI
1561
+ df["ROI"] = (df["Revenue"] - df["Spends"]) / df["Spends"]
1562
+
1563
+ # Format columns
1564
+ format_currency = lambda x: f"${x:,.1f}"
1565
+ format_roi = lambda x: f"{x:.1f}"
1566
+
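+ # Note: the numeric values parsed above (used for the ROI calculation) are replaced below
+ # with pre-formatted display strings, so the displayed figures differ from the parsed ones.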
1567
+ df["Spends"] = [
1568
+ "$ 11.3K",
1569
+ "$ 155.2K",
1570
+ "$ 50.7K",
1571
+ "$ 125.4K",
1572
+ "$ 125.2K",
1573
+ "$ 105K",
1574
+ "$ 3.3M",
1575
+ "$ 47.5K",
1576
+ "$ 55.9K",
1577
+ "$ 632.3K",
1578
+ "$ 48.3K",
1579
+ ]
1580
+ df["Revenue"] = [
1581
+ "$ 536.3K",
1582
+ "$ 3.4M",
1583
+ "$ 5M",
1584
+ "$ 3M",
1585
+ "$ 3M",
1586
+ "$ 2M",
1587
+ "$ 20M",
1588
+ "$ 1.5M",
1589
+ "$ 7.1M",
1590
+ "$ 22M",
1591
+ "$ 4.6M",
1592
+ ]
1593
+ df["ROI"] = df["ROI"].apply(format_roi)
1594
+
1595
+ return df
1596
+
1597
+
1598
+ # @st.cache(allow_output_mutation=True)
1599
+ # def create_contribution_pie(scenario):
1600
+ # #c1f7dc
1601
+ # colors_map = {col:color for col,color in zip(st.session_state['channels_list'],plotly.colors.n_colors(plotly.colors.hex_to_rgb('#BE6468'), plotly.colors.hex_to_rgb('#E7B8B7'),23))}
1602
+ # total_contribution_fig = make_subplots(rows=1, cols=2,subplot_titles=['Spends','Revenue'],specs=[[{"type": "pie"}, {"type": "pie"}]])
1603
+ # total_contribution_fig.add_trace(
1604
+ # go.Pie(labels=[channel_name_formating(channel_name) for channel_name in st.session_state['channels_list']] + ['Non Media'],
1605
+ # values= [round(scenario.channels[channel_name].actual_total_spends * scenario.channels[channel_name].conversion_rate,1) for channel_name in st.session_state['channels_list']] + [0],
1606
+ # marker=dict(colors = [plotly.colors.label_rgb(colors_map[channel_name]) for channel_name in st.session_state['channels_list']] + ['#F0F0F0']),
1607
+ # hole=0.3),
1608
+ # row=1, col=1)
1609
+
1610
+ # total_contribution_fig.add_trace(
1611
+ # go.Pie(labels=[channel_name_formating(channel_name) for channel_name in st.session_state['channels_list']] + ['Non Media'],
1612
+ # values= [scenario.channels[channel_name].actual_total_sales for channel_name in st.session_state['channels_list']] + [scenario.correction.sum() + scenario.constant.sum()],
1613
+ # hole=0.3),
1614
+ # row=1, col=2)
1615
+
1616
+ # total_contribution_fig.update_traces(textposition='inside',texttemplate='%{percent:.1%}')
1617
+ # total_contribution_fig.update_layout(uniformtext_minsize=12,title='Channel contribution', uniformtext_mode='hide')
1618
+ # return total_contribution_fig
1619
+
1620
+ # @st.cache(allow_output_mutation=True)
1621
+
1622
+ # def create_contribuion_stacked_plot(scenario):
1623
+ # weekly_contribution_fig = make_subplots(rows=1, cols=2,subplot_titles=['Spends','Revenue'],specs=[[{"type": "bar"}, {"type": "bar"}]])
1624
+ # raw_df = st.session_state['raw_df']
1625
+ # df = raw_df.sort_values(by='Date')
1626
+ # x = df.Date
1627
+ # weekly_spends_data = []
1628
+ # weekly_sales_data = []
1629
+ # for channel_name in st.session_state['channels_list']:
1630
+ # weekly_spends_data.append((go.Bar(x=x,
1631
+ # y=scenario.channels[channel_name].actual_spends * scenario.channels[channel_name].conversion_rate,
1632
+ # name=channel_name_formating(channel_name),
1633
+ # hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}",
1634
+ # legendgroup=channel_name)))
1635
+ # weekly_sales_data.append((go.Bar(x=x,
1636
+ # y=scenario.channels[channel_name].actual_sales,
1637
+ # name=channel_name_formating(channel_name),
1638
+ # hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
1639
+ # legendgroup=channel_name, showlegend=False)))
1640
+ # for _d in weekly_spends_data:
1641
+ # weekly_contribution_fig.add_trace(_d, row=1, col=1)
1642
+ # for _d in weekly_sales_data:
1643
+ # weekly_contribution_fig.add_trace(_d, row=1, col=2)
1644
+ # weekly_contribution_fig.add_trace(go.Bar(x=x,
1645
+ # y=scenario.constant + scenario.correction,
1646
+ # name='Non Media',
1647
+ # hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}"), row=1, col=2)
1648
+ # weekly_contribution_fig.update_layout(barmode='stack', title='Channel contribuion by week', xaxis_title='Date')
1649
+ # weekly_contribution_fig.update_xaxes(showgrid=False)
1650
+ # weekly_contribution_fig.update_yaxes(showgrid=False)
1651
+ # return weekly_contribution_fig
1652
+
1653
+ # @st.cache(allow_output_mutation=True)
1654
+ # def create_channel_spends_sales_plot(channel):
1655
+ # if channel is not None:
1656
+ # x = channel.dates
1657
+ # _spends = channel.actual_spends * channel.conversion_rate
1658
+ # _sales = channel.actual_sales
1659
+ # channel_sales_spends_fig = make_subplots(specs=[[{"secondary_y": True}]])
1660
+ # channel_sales_spends_fig.add_trace(go.Bar(x=x, y=_sales,marker_color='#c1f7dc',name='Revenue', hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}"), secondary_y = False)
1661
+ # channel_sales_spends_fig.add_trace(go.Scatter(x=x, y=_spends,line=dict(color='#005b96'),name='Spends',hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}"), secondary_y = True)
1662
+ # channel_sales_spends_fig.update_layout(xaxis_title='Date',yaxis_title='Revenue',yaxis2_title='Spends ($)',title='Channel spends and Revenue week wise')
1663
+ # channel_sales_spends_fig.update_xaxes(showgrid=False)
1664
+ # channel_sales_spends_fig.update_yaxes(showgrid=False)
1665
+ # else:
1666
+ # raw_df = st.session_state['raw_df']
1667
+ # df = raw_df.sort_values(by='Date')
1668
+ # x = df.Date
1669
+ # scenario = class_from_dict(st.session_state['default_scenario_dict'])
1670
+ # _sales = scenario.constant + scenario.correction
1671
+ # channel_sales_spends_fig = make_subplots(specs=[[{"secondary_y": True}]])
1672
+ # channel_sales_spends_fig.add_trace(go.Bar(x=x, y=_sales,marker_color='#c1f7dc',name='Revenue', hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}"), secondary_y = False)
1673
+ # # channel_sales_spends_fig.add_trace(go.Scatter(x=x, y=_spends,line=dict(color='#15C39A'),name='Spends',hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}"), secondary_y = True)
1674
+ # channel_sales_spends_fig.update_layout(xaxis_title='Date',yaxis_title='Revenue',yaxis2_title='Spends ($)',title='Channel spends and Revenue week wise')
1675
+ # channel_sales_spends_fig.update_xaxes(showgrid=False)
1676
+ # channel_sales_spends_fig.update_yaxes(showgrid=False)
1677
+ # return channel_sales_spends_fig
1678
+
1679
+
1680
+ # Define a shared color palette
1681
+
1682
+
1683
+ def create_contribution_pie():
1684
+ color_palette = [
1685
+ "#F3F3F0",
1686
+ "#5E7D7E",
1687
+ "#2FA1FF",
1688
+ "#00EDED",
1689
+ "#00EAE4",
1690
+ "#304550",
1691
+ "#EDEBEB",
1692
+ "#7FBEFD",
1693
+ "#003059",
1694
+ "#A2F3F3",
1695
+ "#E1D6E2",
1696
+ "#B6B6B6",
1697
+ ]
1698
+ total_contribution_fig = make_subplots(
1699
+ rows=1,
1700
+ cols=2,
1701
+ subplot_titles=["Spends", "Revenue"],
1702
+ specs=[[{"type": "pie"}, {"type": "pie"}]],
1703
+ )
1704
+
1705
+ channels_list = [
1706
+ "Paid Search",
1707
+ "Ga will cid baixo risco",
1708
+ "Digital tactic others",
1709
+ "Fb la tier 1",
1710
+ "Fb la tier 2",
1711
+ "Paid social others",
1712
+ "Programmatic",
1713
+ "Kwai",
1714
+ "Indicacao",
1715
+ "Infleux",
1716
+ "Influencer",
1717
+ "Non Media",
1718
+ ]
1719
+
1720
+ # Assign colors from the limited palette to channels
1721
+ colors_map = {
1722
+ col: color_palette[i % len(color_palette)]
1723
+ for i, col in enumerate(channels_list)
1724
+ }
1725
+ colors_map["Non Media"] = color_palette[
1726
+ 5
1727
+ ] # Assign a fixed color for 'Non Media'
1728
+
1729
+ # Hardcoded values for Spends and Revenue
1730
+ spends_values = [0.5, 3.36, 1.1, 2.7, 2.7, 2.27, 70.6, 1, 1, 13.7, 1, 0]
1731
+ revenue_values = [1, 4, 5, 3, 3, 2, 50.8, 1.5, 0.7, 13, 0, 16]
1732
+
1733
+ # Add trace for Spends pie chart
1734
+ total_contribution_fig.add_trace(
1735
+ go.Pie(
1736
+ labels=[channel_name for channel_name in channels_list],
1737
+ values=spends_values,
1738
+ marker=dict(
1739
+ colors=[colors_map[channel_name] for channel_name in channels_list]
1740
+ ),
1741
+ hole=0.3,
1742
+ ),
1743
+ row=1,
1744
+ col=1,
1745
+ )
1746
+
1747
+ # Add trace for Revenue pie chart
1748
+ total_contribution_fig.add_trace(
1749
+ go.Pie(
1750
+ labels=[channel_name for channel_name in channels_list],
1751
+ values=revenue_values,
1752
+ marker=dict(
1753
+ colors=[colors_map[channel_name] for channel_name in channels_list]
1754
+ ),
1755
+ hole=0.3,
1756
+ ),
1757
+ row=1,
1758
+ col=2,
1759
+ )
1760
+
1761
+ total_contribution_fig.update_traces(
1762
+ textposition="inside", texttemplate="%{percent:.1%}"
1763
+ )
1764
+ total_contribution_fig.update_layout(
1765
+ uniformtext_minsize=12,
1766
+ title="Channel contribution",
1767
+ uniformtext_mode="hide",
1768
+ )
1769
+ return total_contribution_fig
1770
+
1771
+
1772
+ def create_contribuion_stacked_plot(scenario):
1773
+ weekly_contribution_fig = make_subplots(
1774
+ rows=1,
1775
+ cols=2,
1776
+ subplot_titles=["Spends", "Revenue"],
1777
+ specs=[[{"type": "bar"}, {"type": "bar"}]],
1778
+ )
1779
+ raw_df = st.session_state["raw_df"]
1780
+ df = raw_df.sort_values(by="Date")
1781
+ x = df.Date
1782
+ weekly_spends_data = []
1783
+ weekly_sales_data = []
1784
+
1785
+ for i, channel_name in enumerate(st.session_state["channels_list"]):
1786
+ color = color_palette[i % len(color_palette)]
1787
+
1788
+ weekly_spends_data.append(
1789
+ go.Bar(
1790
+ x=x,
1791
+ y=scenario.channels[channel_name].actual_spends
1792
+ * scenario.channels[channel_name].conversion_rate,
1793
+ name=channel_name_formating(channel_name),
1794
+ hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}",
1795
+ legendgroup=channel_name,
1796
+ marker_color=color,
1797
+ )
1798
+ )
1799
+
1800
+ weekly_sales_data.append(
1801
+ go.Bar(
1802
+ x=x,
1803
+ y=scenario.channels[channel_name].actual_sales,
1804
+ name=channel_name_formating(channel_name),
1805
+ hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
1806
+ legendgroup=channel_name,
1807
+ showlegend=False,
1808
+ marker_color=color,
1809
+ )
1810
+ )
1811
+
1812
+ for _d in weekly_spends_data:
1813
+ weekly_contribution_fig.add_trace(_d, row=1, col=1)
1814
+ for _d in weekly_sales_data:
1815
+ weekly_contribution_fig.add_trace(_d, row=1, col=2)
1816
+
1817
+ weekly_contribution_fig.add_trace(
1818
+ go.Bar(
1819
+ x=x,
1820
+ y=scenario.constant + scenario.correction,
1821
+ name="Non Media",
1822
+ hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
1823
+ marker_color=color_palette[-1],
1824
+ ),
1825
+ row=1,
1826
+ col=2,
1827
+ )
1828
+
1829
+ weekly_contribution_fig.update_layout(
1830
+ barmode="stack",
1831
+ title="Channel contribution by week",
1832
+ xaxis_title="Date",
1833
+ )
1834
+ weekly_contribution_fig.update_xaxes(showgrid=False)
1835
+ weekly_contribution_fig.update_yaxes(showgrid=False)
1836
+ return weekly_contribution_fig
1837
+
1838
+
1839
+ def create_channel_spends_sales_plot(channel):
1840
+ if channel is not None:
1841
+ x = channel.dates
1842
+ _spends = channel.actual_spends * channel.conversion_rate
1843
+ _sales = channel.actual_sales
1844
+ channel_sales_spends_fig = make_subplots(specs=[[{"secondary_y": True}]])
1845
+ channel_sales_spends_fig.add_trace(
1846
+ go.Bar(
1847
+ x=x,
1848
+ y=_sales,
1849
+ marker_color=color_palette[
1850
+ 3
1851
+ ], # You can choose a color from the palette
1852
+ name="Revenue",
1853
+ hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
1854
+ ),
1855
+ secondary_y=False,
1856
+ )
1857
+
1858
+ channel_sales_spends_fig.add_trace(
1859
+ go.Scatter(
1860
+ x=x,
1861
+ y=_spends,
1862
+ line=dict(
1863
+ color=color_palette[2]
1864
+ ), # You can choose another color from the palette
1865
+ name="Spends",
1866
+ hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}",
1867
+ ),
1868
+ secondary_y=True,
1869
+ )
1870
+
1871
+ channel_sales_spends_fig.update_layout(
1872
+ xaxis_title="Date",
1873
+ yaxis_title="Revenue",
1874
+ yaxis2_title="Spends ($)",
1875
+ title="Channel spends and Revenue week-wise",
1876
+ )
1877
+ channel_sales_spends_fig.update_xaxes(showgrid=False)
1878
+ channel_sales_spends_fig.update_yaxes(showgrid=False)
1879
+ else:
1880
+ raw_df = st.session_state["raw_df"]
1881
+ df = raw_df.sort_values(by="Date")
1882
+ x = df.Date
1883
+ scenario = class_from_dict(st.session_state["default_scenario_dict"])
1884
+ _sales = scenario.constant + scenario.correction
1885
+ channel_sales_spends_fig = make_subplots(specs=[[{"secondary_y": True}]])
1886
+ channel_sales_spends_fig.add_trace(
1887
+ go.Bar(
1888
+ x=x,
1889
+ y=_sales,
1890
+ marker_color=color_palette[
1891
+ 0
1892
+ ], # You can choose a color from the palette
1893
+ name="Revenue",
1894
+ hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
1895
+ ),
1896
+ secondary_y=False,
1897
+ )
1898
+
1899
+ channel_sales_spends_fig.update_layout(
1900
+ xaxis_title="Date",
1901
+ yaxis_title="Revenue",
1902
+ yaxis2_title="Spends ($)",
1903
+ title="Channel spends and Revenue week-wise",
1904
+ )
1905
+ channel_sales_spends_fig.update_xaxes(showgrid=False)
1906
+ channel_sales_spends_fig.update_yaxes(showgrid=False)
1907
+
1908
+ return channel_sales_spends_fig
1909
+
1910
+
1911
+ def format_numbers(value, n_decimals=1, include_indicator=True):
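+ # Formats a number for display with the currency indicator; e.g. format_numbers(1234.5)
+ # is expected to return "$ 1.2K" (the numeric part is formatted by numerize).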
1912
+ if value is None:
1913
+ return None
1914
+ _value = value if value < 1 else numerize(value, n_decimals)
1915
+ if include_indicator:
1916
+ return f"{CURRENCY_INDICATOR} {_value}"
1917
+ else:
1918
+ return f"{_value}"
1919
+
1920
+
1921
+ def decimal_formater(num_string, n_decimals=1):
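+ # Pads a numeric string to a fixed number of decimals, e.g. decimal_formater("12.3", 2)
+ # returns "12.30" and decimal_formater("7") returns "7.0".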
1922
+ parts = num_string.split(".")
1923
+ if len(parts) == 1:
1924
+ return num_string + "." + "0" * n_decimals
1925
+ else:
1926
+ to_be_padded = n_decimals - len(parts[-1])
1927
+ if to_be_padded > 0:
1928
+ return num_string + "0" * to_be_padded
1929
+ else:
1930
+ return num_string
1931
+
1932
+
1933
+ def channel_name_formating(channel_name):
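+ # Formats a channel name for display, e.g. "paid_search_Imp" -> "paid search Spend".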
1934
+ name_mod = channel_name.replace("_", " ")
1935
+ if name_mod.lower().endswith(" imp"):
1936
+ name_mod = name_mod.replace("Imp", "Spend")
1937
+ elif name_mod.lower().endswith(" clicks"):
1938
+ name_mod = name_mod.replace("Clicks", "Spend")
1939
+ return name_mod
1940
+
1941
+
1942
+ def send_email(email, message):
1943
+ s = smtplib.SMTP("smtp.gmail.com", 587)
1944
+ s.starttls()
1945
+ s.login("[email protected]", "jgydhpfusuremcol")
1946
+ s.sendmail("[email protected]", email, message)
1947
+ s.quit()
1948
+
1949
+
1950
+ # if __name__ == "__main__":
1951
+ # initialize_data()
1952
+
1953
+
1954
+ #############################################################################################################
1955
+
1956
+ import os
1957
+ import json
1958
+ import streamlit as st
1959
+
1960
+
1961
+ # Function to get panel names
1962
+ def get_panels_names(file_selected):
1963
+ raw_data_df = st.session_state["project_dct"]["current_media_performance"][
1964
+ "model_outputs"
1965
+ ][file_selected]["raw_data"]
1966
+
1967
+ if "panel" in raw_data_df.columns:
1968
+ panel = list(set(raw_data_df["panel"]))
1969
+ elif "Panel" in raw_data_df.columns:
1970
+ panel = list(set(raw_data_df["Panel"]))
1971
+ else:
1972
+ panel = []
1973
+
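+ # "aggregated" is always offered in addition to any panels found in the data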
1974
+ return panel + ["aggregated"]
1975
+
1976
+
1977
+ # Function to get metric names
1978
+ def get_metrics_names():
1979
+ return list(
1980
+ st.session_state["project_dct"]["current_media_performance"][
1981
+ "model_outputs"
1982
+ ].keys()
1983
+ )
1984
+
1985
+
1986
+ # Function to load the original and modified rcs metadata files into dictionaries
1987
+ def load_rcs_metadata_files():
1988
+ original_data = st.session_state["project_dct"]["response_curves"][
1989
+ "original_metadata_file"
1990
+ ]
1991
+ modified_data = st.session_state["project_dct"]["response_curves"][
1992
+ "modified_metadata_file"
1993
+ ]
1994
+
1995
+ return original_data, modified_data
1996
+
1997
+
1998
+ # Function to format name
1999
+ def name_formating(name):
2000
+ # Replace underscores with spaces
2001
+ name_mod = name.replace("_", " ")
2002
+
2003
+ # Capitalize the first letter of each word
2004
+ name_mod = name_mod.title()
2005
+
2006
+ return name_mod
2007
+
2008
+
2009
+ # Function to load the original and modified scenario metadata files into dictionaries
2010
+ def load_scenario_metadata_files():
2011
+ original_data = st.session_state["project_dct"]["scenario_planner"][
2012
+ "original_metadata_file"
2013
+ ]
2014
+ modified_data = st.session_state["project_dct"]["scenario_planner"][
2015
+ "modified_metadata_file"
2016
+ ]
2017
+
2018
+ return original_data, modified_data
2019
+
2020
+
2021
+ # Function to generate RCS data and store it as dictionary
2022
+ def generate_rcs_data():
2023
+ # Retrieve the list of all metric names from the specified directory
2024
+ metrics_list = get_metrics_names()
2025
+
2026
+ # Dictionary to store RCS data for all metrics and their respective panels
2027
+ all_rcs_data_original = {}
2028
+ all_rcs_data_modified = {}
2029
+
2030
+ # Iterate over each metric in the metrics list
2031
+ for metric in metrics_list:
2032
+ # Retrieve the list of panel names from the current metric's Excel file
2033
+ panel_list = get_panels_names(file_selected=metric)
2034
+
2035
+ # Check if modified RCS metadata exists
2036
+ if (
2037
+ st.session_state["project_dct"]["response_curves"]["modified_metadata_file"]
2038
+ is not None
2039
+ ):
2040
+ modified_data = st.session_state["project_dct"]["response_curves"][
2041
+ "modified_metadata_file"
2042
+ ]
2043
+
2044
+ # Iterate over each panel in the panel list
2045
+ for panel in panel_list:
2046
+ # Initialize the original RCS data for the current panel and metric
2047
+ rcs_dict_original, scenario = initialize_data(
2048
+ panel=panel,
2049
+ metrics=metric,
2050
+ )
2051
+
2052
+ # Ensure the dictionary has the metric as a key for original data
2053
+ if metric not in all_rcs_data_original:
2054
+ all_rcs_data_original[metric] = {}
2055
+
2056
+ # Store the original RCS data under the corresponding panel for the current metric
2057
+ all_rcs_data_original[metric][panel] = rcs_dict_original
2058
+
2059
+ # Ensure the dictionary has the metric as a key for modified data
2060
+ if metric not in all_rcs_data_modified:
2061
+ all_rcs_data_modified[metric] = {}
2062
+
2063
+ # Store the modified RCS data under the corresponding panel for the current metric
2064
+ for channel in rcs_dict_original:
2065
+ all_rcs_data_modified[metric][panel] = all_rcs_data_modified[
2066
+ metric
2067
+ ].get(panel, {})
2068
+
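+ # Reuse previously modified parameters for this channel if present; otherwise fall back
+ # to the freshly fitted originals.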
2069
+ try:
2070
+ updated_rcs_dict = modified_data[metric][panel][channel]
2071
+ except:
2072
+ updated_rcs_dict = {
2073
+ "K": rcs_dict_original[channel]["K"],
2074
+ "b": rcs_dict_original[channel]["b"],
2075
+ "a": rcs_dict_original[channel]["a"],
2076
+ "x0": rcs_dict_original[channel]["x0"],
2077
+ }
2078
+
2079
+ all_rcs_data_modified[metric][panel][channel] = updated_rcs_dict
2080
+
2081
+ # Write the original RCS data
2082
+ st.session_state["project_dct"]["response_curves"][
2083
+ "original_metadata_file"
2084
+ ] = all_rcs_data_original
2085
+
2086
+ # Write the modified RCS data
2087
+ st.session_state["project_dct"]["response_curves"][
2088
+ "modified_metadata_file"
2089
+ ] = all_rcs_data_modified
2090
+
2091
+
2092
+ # Function to generate scenario data and store it as dictionary
2093
+ def generate_scenario_data():
2094
+ # Retrieve the list of all metric names from the specified directory
2095
+ metrics_list = get_metrics_names()
2096
+
2097
+ # Dictionary to store scenario data for all metrics and their respective panels
2098
+ all_scenario_data_original = {}
2099
+ all_scenario_data_modified = {}
2100
+
2101
+ # Iterate over each metric in the metrics list
2102
+ for metric in metrics_list:
2103
+ # Retrieve the list of panel names from the current metric's Excel file
2104
+ panel_list = get_panels_names(metric)
2105
+
2106
+ # Check if modified scenario metadata exists
2107
+ if (
2108
+ st.session_state["project_dct"]["scenario_planner"][
2109
+ "modified_metadata_file"
2110
+ ]
2111
+ is not None
2112
+ ):
2113
+ modified_data = st.session_state["project_dct"]["scenario_planner"][
2114
+ "modified_metadata_file"
2115
+ ]
2116
+
2117
+ # Iterate over each panel in the panel list
2118
+ for panel in panel_list:
2119
+ # Initialize the original scenario data for the current panel and metric
2120
+ rcs_dict_original, scenario = initialize_data(
2121
+ panel=panel,
2122
+ metrics=metric,
2123
+ )
2124
+
2125
+ # Ensure the dictionary has the metric as a key for original data
2126
+ if metric not in all_scenario_data_original:
2127
+ all_scenario_data_original[metric] = {}
2128
+
2129
+ # Store the original scenario data under the corresponding panel for the current metric
2130
+ all_scenario_data_original[metric][panel] = class_convert_to_dict(scenario)
2131
+
2132
+ # Ensure the dictionary has the metric as a key for modified data
2133
+ if metric not in all_scenario_data_modified:
2134
+ all_scenario_data_modified[metric] = {}
2135
+
2136
+ # Store the modified scenario data under the corresponding panel for the current metric
2137
+ try:
2138
+ all_scenario_data_modified[metric][panel] = modified_data[metric][panel]
2139
+ except:
2140
+ all_scenario_data_modified[metric][panel] = class_convert_to_dict(
2141
+ scenario
2142
+ )
2143
+
2144
+ # Write the original scenario data
2145
+ st.session_state["project_dct"]["scenario_planner"][
2146
+ "original_metadata_file"
2147
+ ] = all_scenario_data_original
2148
+
2149
+ # Write the modified scenario data
2150
+ st.session_state["project_dct"]["scenario_planner"][
2151
+ "modified_metadata_file"
2152
+ ] = all_scenario_data_modified
2153
+
2154
+
2155
+ #############################################################################################################
utilities_with_panel.py ADDED
@@ -0,0 +1,1520 @@
1
+ from scenario import numerize
2
+ import streamlit as st
3
+ import pandas as pd
4
+ import json
5
+ from scenario import Channel, Scenario
6
+ import numpy as np
7
+ from plotly.subplots import make_subplots
8
+ import plotly.graph_objects as go
9
+ from scenario import class_to_dict
10
+ from collections import OrderedDict
11
+ import io
12
+ import plotly
13
+ from pathlib import Path
14
+ import pickle
15
+ import yaml
16
+ from yaml import SafeLoader
17
+ from streamlit.components.v1 import html
18
+ import smtplib
19
+ from scipy.optimize import curve_fit
20
+ from sklearn.metrics import r2_score
21
+ from scenario import class_from_dict
22
+ from utilities import retrieve_pkl_object
23
+ import os
24
+ import base64
25
+
26
+
27
+
28
+
29
+ # # schema = db_cred["schema"]
30
+
31
+ color_palette = [
32
+ "#F3F3F0",
33
+ "#5E7D7E",
34
+ "#2FA1FF",
35
+ "#00EDED",
36
+ "#00EAE4",
37
+ "#304550",
38
+ "#EDEBEB",
39
+ "#7FBEFD",
40
+ "#003059",
41
+ "#A2F3F3",
42
+ "#E1D6E2",
43
+ "#B6B6B6",
44
+ ]
45
+
46
+
47
+ CURRENCY_INDICATOR = "$"
48
+
49
+
50
+ if "project_dct" not in st.session_state or "project_number" not in st.session_state:
51
+ st.error(
52
+ "No project found in the session. Please load or create a project before generating response curves."
53
+ )
54
+ st.stop()
55
+
56
+ tuned_model = retrieve_pkl_object(
57
+ st.session_state["project_number"], "Model_Tuning", "tuned_model"
58
+ )
59
+
60
+ if tuned_model is None:
61
+ st.error(
62
+ "No tuned model available. Please build and tune a model to generate response curves."
63
+ )
64
+ st.stop()
65
+
66
+
67
+ def load_authenticator():
68
+ with open("config.yaml") as file:
69
+ config = yaml.load(file, Loader=SafeLoader)
70
+ st.session_state["config"] = config
71
+ authenticator = stauth.Authenticate(
72
+ config["credentials"],
73
+ config["cookie"]["name"],
74
+ config["cookie"]["key"],
75
+ config["cookie"]["expiry_days"],
76
+ config["preauthorized"],
77
+ )
78
+ st.session_state["authenticator"] = authenticator
79
+ return authenticator
80
+
81
+
82
+ def nav_page(page_name, timeout_secs=3):
83
+ nav_script = """
84
+ <script type="text/javascript">
85
+ function attempt_nav_page(page_name, start_time, timeout_secs) {
86
+ var links = window.parent.document.getElementsByTagName("a");
87
+ for (var i = 0; i < links.length; i++) {
88
+ if (links[i].href.toLowerCase().endsWith("/" + page_name.toLowerCase())) {
89
+ links[i].click();
90
+ return;
91
+ }
92
+ }
93
+ var elasped = new Date() - start_time;
94
+ if (elasped < timeout_secs * 1000) {
95
+ setTimeout(attempt_nav_page, 100, page_name, start_time, timeout_secs);
96
+ } else {
97
+ alert("Unable to navigate to page '" + page_name + "' after " + timeout_secs + " second(s).");
98
+ }
99
+ }
100
+ window.addEventListener("load", function() {
101
+ attempt_nav_page("%s", new Date(), %d);
102
+ });
103
+ </script>
104
+ """ % (
105
+ page_name,
106
+ timeout_secs,
107
+ )
108
+ html(nav_script)
109
+
110
+
111
+ # def load_local_css(file_name):
112
+ # with open(file_name) as f:
113
+ # st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
114
+
115
+
116
+ # def set_header():
117
+ # return st.markdown(f"""<div class='main-header'>
118
+ # <h1>MMM LiME</h1>
119
+ # <img src="https://assets-global.website-files.com/64c8fffb0e95cbc525815b79/64df84637f83a891c1473c51_Vector%20(Stroke).svg ">
120
+ # </div>""", unsafe_allow_html=True)
121
+
122
+ path = os.path.dirname(__file__)
123
+
124
+ file_ = open(f"{path}/logo.png", "rb")
125
+
126
+ contents = file_.read()
127
+
128
+ data_url = base64.b64encode(contents).decode("utf-8")
129
+
130
+ file_.close()
131
+
132
+
133
+ DATA_PATH = "./data"
134
+
135
+ IMAGES_PATH = "./data/images_224_224"
136
+
137
+
138
+ # is_panel = True if len(panel_col) > 0 else False
139
+
140
+ # manoj
141
+
142
+ is_panel = False
143
+ date_col = "Date"
144
+ # is_panel = False # flag if set to true - do panel level response curves
145
+
146
+
147
+ def load_local_css(file_name):
148
+
149
+ with open(file_name) as f:
150
+
151
+ st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
152
+
153
+
154
+ # def set_header():
155
+
156
+ # return st.markdown(f"""<div class='main-header'>
157
+
158
+ # <h1>H & M Recommendations</h1>
159
+
160
+ # <img src="data:image;base64,{data_url}", alt="Logo">
161
+
162
+ # </div>""", unsafe_allow_html=True)
163
+ path1 = os.path.dirname(__file__)
164
+
165
+ # file_1 = open(f"{path}/willbank.png", "rb")
166
+
167
+ # contents1 = file_1.read()
168
+
169
+ # data_url1 = base64.b64encode(contents1).decode("utf-8")
170
+
171
+ # file_1.close()
172
+
173
+
174
+ DATA_PATH1 = "./data"
175
+
176
+ IMAGES_PATH1 = "./data/images_224_224"
177
+
178
+
179
+ def set_header():
180
+ return st.markdown(
181
+ f"""<div class='main-header'>
182
+ <!-- <h1></h1> -->
183
+ <div >
184
+ <img class='blend-logo' src="data:image;base64,{data_url}", alt="Logo">
185
+ </div>
186
+ <img class='blend-logo' src="data:image;base64,{data_url}", alt="Logo">
187
+ </div>""",
188
+ unsafe_allow_html=True,
189
+ )
190
+
191
+
192
+ # def set_header():
193
+ # logo_path = "./path/to/your/local/LIME_logo.png" # Replace with the actual file path
194
+ # text = "LiME"
195
+ # return st.markdown(f"""<div class='main-header'>
196
+ # <img src="data:image/png;base64,{data_url}" alt="Logo" style="float: left; margin-right: 10px; width: 100px; height: auto;">
197
+ # <h1>{text}</h1>
198
+ # </div>""", unsafe_allow_html=True)
199
+
200
+
201
+ def s_curve(x, K, b, a, x0):
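+ # Generalised logistic (s-curve) response: K is the saturation ceiling, a the growth rate,
+ # x0 the inflection shift and b a scaling term for the initial offset.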
202
+ return K / (1 + b * np.exp(-a * (x - x0)))
203
+
204
+
205
+ def overview_test_data_prep_panel(X, df, spends_X, date_col, panel_col, target_col):
206
+ """
207
+ Create the data consumed by the initialize_data function.
208
+ X : X test with contributions
209
+ df : originally uploaded data (media data) which has raw vars
210
+ spends_X : spends of dates in X test
211
+ """
212
+
213
+ channels = st.session_state["channels"]
214
+ channel_list = channels.keys()
215
+
216
+ # map transformed variable to raw variable name & channel name
217
+ # mapping eg : paid_search_clicks_lag_2 (transformed var) --> paid_search_clicks (raw var) --> paid_search (channel)
218
+ variables = {}
219
+ channel_and_variables = {}
220
+ new_variables = {}
221
+ new_channels_and_variables = {}
222
+
223
+ for transformed_var in [
224
+ col
225
+ for col in X.drop(columns=[date_col, panel_col, target_col, "pred"]).columns
226
+ if "_contr" not in col
227
+ ]:
228
+ if len([col for col in df.columns if col in transformed_var]) == 1:
229
+ raw_var = [col for col in df.columns if col in transformed_var][0]
230
+
231
+ variables[transformed_var] = raw_var
232
+
233
+ # Check if the list comprehension result is not empty before accessing the first element
234
+ channels_list = [
235
+ channel for channel, raw_vars in channels.items() if raw_var in raw_vars
236
+ ]
237
+ if channels_list:
238
+ channel_and_variables[raw_var] = channels_list[0]
239
+ else:
240
+ # Handle the case where channels_list is empty
241
+ # Default to None when no channel contains this raw variable
242
+ channel_and_variables[raw_var] = None
243
+ else:
244
+ new_variables[transformed_var] = transformed_var
245
+ new_channels_and_variables[transformed_var] = "base"
246
+
247
+ # Raw DF
248
+ raw_X = pd.merge(
249
+ X[[date_col, panel_col]],
250
+ df[[date_col, panel_col] + list(variables.values())],
251
+ how="left",
252
+ on=[date_col, panel_col],
253
+ )
254
+ assert len(raw_X) == len(X)
255
+
256
+ raw_X_cols = []
257
+ for i in raw_X.columns:
258
+ if i in channel_and_variables.keys():
259
+ raw_X_cols.append(channel_and_variables[i])
260
+ else:
261
+ raw_X_cols.append(i)
262
+ raw_X.columns = raw_X_cols
263
+
264
+ # Contribution DF
265
+ contr_X = X[
266
+ [date_col, panel_col]
267
+ + [col for col in X.columns if "_contr" in col and "sum_" not in col]
268
+ ].copy()
269
+ # if "base_contr" in contr_X.columns:
270
+ # contr_X.rename(columns={'base_contr':'const_contr'},inplace=True)
271
+ # # new_variables = [
272
+ # col
273
+ # for col in contr_X.columns
274
+ # if "_flag" in col.lower() or "trend" in col.lower() or "sine" in col.lower()
275
+ # ]
276
+ # if len(new_variables) > 0:
277
+ # contr_X["const"] = contr_X[["panel_effect"] + new_variables].sum(axis=1)
278
+ # contr_X.drop(columns=["panel_effect"], inplace=True)
279
+ # contr_X.drop(columns=new_variables, inplace=True)
280
+ # else:
281
+ # contr_X.rename(columns={"panel_effect": "const"}, inplace=True)
282
+
283
+ new_contr_X_cols = []
284
+ for col in contr_X.columns:
285
+ col_clean = col.replace("_contr", "")
286
+ new_contr_X_cols.append(col_clean)
287
+ contr_X.columns = new_contr_X_cols
288
+
289
+ contr_X_cols = []
290
+ for i in contr_X.columns:
291
+ if i in variables.keys():
292
+ contr_X_cols.append(channel_and_variables[variables[i]])
293
+ else:
294
+ contr_X_cols.append(i)
295
+ contr_X.columns = contr_X_cols
296
+
297
+ # Spends DF
298
+ spends_X.columns = [col.replace("_cost", "") for col in spends_X.columns]
299
+
300
+ raw_X.rename(columns={"date": "Date"}, inplace=True)
301
+ contr_X.rename(columns={"date": "Date"}, inplace=True)
302
+ spends_X.rename(columns={"date": "Week"}, inplace=True)
303
+
304
+ spends_X.columns = [
305
+ col.replace("spends_", "") if col.startswith("spends_") else col
306
+ for col in spends_X.columns
307
+ ]
308
+
309
+ # Rename column to 'Date'
310
+ spends_X.rename(columns={"Week": "Date"}, inplace=True)
311
+
312
+ # Remove "response_metric_" from the start and "_total" from the end
313
+ if str(target_col).startswith("response_metric_"):
314
+ target_col = target_col.replace("response_metric_", "", 1)
315
+
316
+ # Remove the last 6 characters (length of "_total")
317
+ if str(target_col).endswith("_total"):
318
+ target_col = target_col[:-6]
319
+
320
+ # Rename column to 'Date'
321
+ spends_X.rename(columns={"Week": "Date"}, inplace=True)
322
+
323
+ # Save raw, spends and contribution data
324
+ st.session_state["project_dct"]["current_media_performance"]["model_outputs"][
325
+ target_col
326
+ ] = {
327
+ "raw_data": raw_X,
328
+ "contribution_data": contr_X,
329
+ "spends_data": spends_X,
330
+ }
331
+
332
+ # Clear page metadata
333
+ st.session_state["project_dct"]["scenario_planner"]["original_metadata_file"] = None
334
+ st.session_state["project_dct"]["response_curves"]["original_metadata_file"] = None
335
+
336
+ # st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] = None
337
+ # st.session_state["project_dct"]["response_curves"]["modified_metadata_file"] = None
338
+
339
+
340
+ def overview_test_data_prep_nonpanel(X, df, spends_X, date_col, target_col):
341
+ """
342
+ function to create the data which is used in initialize data fn
343
+ """
344
+
345
+ # with open(
346
+ # os.path.join(st.session_state["project_path"], "channel_groups.pkl"), "rb"
347
+ # ) as f:
348
+ # channels = pickle.load(f)
349
+
350
+ # channel_list = list(channels.keys())
351
+ channels = st.session_state["channels"]
352
+ channel_list = channels.keys()
353
+
354
+ # map transformed variable to raw variable name & channel name
355
+ # mapping eg : paid_search_clicks_lag_2 (transformed var) --> paid_search_clicks (raw var) --> paid_search (channel)
356
+ variables = {}
357
+ channel_and_variables = {}
358
+ new_variables = {}
359
+ new_channels_and_variables = {}
360
+
361
+ cols_to_del = list(
362
+ set([date_col, target_col, "pred"]).intersection((set(X.columns)))
363
+ )
364
+
365
+ # remove exog cols from RAW data (exog cols are part of base, raw data needs media vars only)
366
+ all_exog_vars = st.session_state["bin_dict"]["Exogenous"]
367
+ all_exog_vars = [
368
+ var.lower()
369
+ .replace(".", "_")
370
+ .replace("@", "_")
371
+ .replace(" ", "_")
372
+ .replace("-", "")
373
+ .replace(":", "")
374
+ .replace("__", "_")
375
+ for var in all_exog_vars
376
+ ]
377
+ exog_cols = []
378
+ if len(all_exog_vars) > 0:
379
+ for col in X.columns:
380
+ if len([exog_var for exog_var in all_exog_vars if exog_var in col]) > 0:
381
+ exog_cols.append(col)
382
+ cols_to_del = cols_to_del + exog_cols
383
+ for transformed_var in [
384
+ col for col in X.drop(columns=cols_to_del).columns if "_contr" not in col
385
+ ]: # also has 'const'
386
+
387
+ if (
388
+ len([col for col in df.columns if col in transformed_var]) == 1
389
+ ): # col is raw var
390
+ raw_var = [col for col in df.columns if col in transformed_var][0]
391
+ variables[transformed_var] = raw_var
392
+ channel_and_variables[raw_var] = [
393
+ channel for channel, raw_vars in channels.items() if raw_var in raw_vars
394
+ ][0]
395
+ else: # when no corresponding raw var then base
396
+ new_variables[transformed_var] = transformed_var
397
+ new_channels_and_variables[transformed_var] = "base"
398
+
399
+ # Raw DF
400
+ raw_X = pd.merge(
401
+ X[[date_col]],
402
+ df[[date_col] + list(variables.values())],
403
+ how="left",
404
+ on=[date_col],
405
+ )
406
+ assert len(raw_X) == len(X)
407
+
408
+ raw_X_cols = []
409
+ for i in raw_X.columns:
410
+ if i in channel_and_variables.keys():
411
+ raw_X_cols.append(channel_and_variables[i])
412
+ else:
413
+ raw_X_cols.append(i)
414
+ raw_X.columns = raw_X_cols
415
+
416
+ # Contribution DF
417
+ contr_X = X[
418
+ [date_col] + [col for col in X.columns if "_contr" in col and "sum_" not in col]
419
+ ].copy()
420
+
421
+ new_variables = [
422
+ col
423
+ for col in contr_X.columns
424
+ if "_flag" in col.lower() or "trend" in col.lower() or "sine" in col.lower()
425
+ ]
426
+ if (
427
+ len(new_variables) > 0
428
+ ): # if new vars are available, their contributions should be added to base (called const)
429
+ contr_X["const_contr"] = contr_X[["const_contr"] + new_variables].sum(axis=1)
430
+ contr_X.drop(columns=new_variables, inplace=True)
431
+
432
+ new_contr_X_cols = []
433
+ for col in contr_X.columns:
434
+ col_clean = col.replace("_contr", "")
435
+ new_contr_X_cols.append(col_clean)
436
+ contr_X.columns = new_contr_X_cols
437
+
438
+ contr_X_cols = []
439
+ for i in contr_X.columns:
440
+ if i in variables.keys():
441
+ contr_X_cols.append(channel_and_variables[variables[i]])
442
+ else:
443
+ contr_X_cols.append(i)
444
+ contr_X.columns = contr_X_cols
445
+
446
+ # Spends DF
447
+ # spends_X.columns = [
448
+ # col.replace("_cost", "").replace("_spends", "").replace("_spend", "")
449
+ # for col in spends_X.columns
450
+ # ]
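+ # Map each spends column to its channel bucket from the channel groups dictionary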
451
+ spends_X_col_map = {
452
+ col: bucket
453
+ for col in spends_X.columns
454
+ for bucket in channels.keys()
455
+ if col in channels[bucket]
456
+ }
457
+ spends_X.rename(columns=spends_X_col_map, inplace=True)
458
+
459
+ raw_X.rename(columns={"date": "Date"}, inplace=True)
460
+ contr_X.rename(columns={"date": "Date"}, inplace=True)
461
+ spends_X.rename(columns={"date": "Week"}, inplace=True)
462
+
463
+ spends_X.columns = [
464
+ col.replace("spends_", "") if col.startswith("spends_") else col
465
+ for col in spends_X.columns
466
+ ]
467
+
468
+ # Rename column to 'Date'
469
+ spends_X.rename(columns={"Week": "Date"}, inplace=True)
470
+
471
+ # Remove "response_metric_" from the start and "_total" from the end
472
+ if str(target_col).startswith("response_metric_"):
473
+ target_col = target_col.replace("response_metric_", "", 1)
474
+
475
+ # Remove the last 6 characters (length of "_total")
476
+ if str(target_col).endswith("_total"):
477
+ target_col = target_col[:-6]
478
+
479
+ # Rename column to 'Date'
480
+ spends_X.rename(columns={"Week": "Date"}, inplace=True)
481
+
482
+ # Save raw, spends and contribution data
483
+ st.session_state["project_dct"]["current_media_performance"]["model_outputs"][
484
+ target_col
485
+ ] = {
486
+ "raw_data": raw_X,
487
+ "contribution_data": contr_X,
488
+ "spends_data": spends_X,
489
+ }
490
+
491
+ # Clear page metadata
492
+ st.session_state["project_dct"]["scenario_planner"]["original_metadata_file"] = None
493
+ st.session_state["project_dct"]["response_curves"]["original_metadata_file"] = None
494
+
495
+ # st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] = None
496
+ # st.session_state["project_dct"]["response_curves"]["modified_metadata_file"] = None
497
+
498
+
499
+ def initialize_data_cmp(target_col, is_panel, panel_col, start_date, end_date):
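+ # Load the saved model outputs for the target metric, aggregate them to date level,
+ # fit an s-curve response curve per channel, and populate session state with the default Scenario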
500
+ start_date = pd.to_datetime(start_date)
501
+ end_date = pd.to_datetime(end_date)
502
+
503
+ # Remove "response_metric_" from the start and "_total" from the end
504
+ if str(target_col).startswith("response_metric_"):
505
+ target_col = target_col.replace("response_metric_", "", 1)
506
+
507
+ # Remove the last 6 characters (length of "_total")
508
+ if str(target_col).endswith("_total"):
509
+ target_col = target_col[:-6]
510
+
511
+ # Extract dataframes for raw data, spend input, and contribution data
512
+ raw_df = st.session_state["project_dct"]["current_media_performance"][
513
+ "model_outputs"
514
+ ][target_col]["raw_data"]
515
+ spend_df = st.session_state["project_dct"]["current_media_performance"][
516
+ "model_outputs"
517
+ ][target_col]["spends_data"]
518
+ contri_df = st.session_state["project_dct"]["current_media_performance"][
519
+ "model_outputs"
520
+ ][target_col]["contribution_data"]
521
+
522
+ # Remove unnecessary columns
523
+ unnamed_cols = [col for col in raw_df.columns if col.lower().startswith("unnamed")]
524
+ exclude_columns = ["Date"] + unnamed_cols
525
+
526
+ if is_panel:
527
+ exclude_columns = exclude_columns + [panel_col]
528
+
529
+ # Aggregate all 3 dfs to date level (from date-panel level)
530
+ raw_df[date_col] = pd.to_datetime(raw_df[date_col])
531
+ raw_df = raw_df[raw_df[date_col] >= start_date]
532
+ raw_df = raw_df[raw_df[date_col] <= end_date].reset_index(drop=True)
533
+ raw_df_aggregations = {c: "sum" for c in raw_df.columns if c not in exclude_columns}
534
+ raw_df = raw_df.groupby(date_col).agg(raw_df_aggregations).reset_index()
535
+
536
+ contri_df[date_col] = pd.to_datetime(contri_df[date_col])
537
+ contri_df = contri_df[contri_df[date_col] >= start_date]
538
+ contri_df = contri_df[contri_df[date_col] <= end_date].reset_index(drop=True)
539
+ contri_df_aggregations = {
540
+ c: "sum" for c in contri_df.columns if c not in exclude_columns
541
+ }
542
+ contri_df = contri_df.groupby(date_col).agg(contri_df_aggregations).reset_index()
543
+
544
+ input_df = raw_df.sort_values(by=[date_col])
545
+
546
+ output_df = contri_df.sort_values(by=[date_col])
547
+ spend_df["Date"] = pd.to_datetime(
548
+ spend_df["Date"], format="%Y-%m-%d", errors="coerce"
549
+ )
550
+ spend_df = spend_df[spend_df["Date"] >= start_date]
551
+ spend_df = spend_df[spend_df["Date"] <= end_date].reset_index(drop=True)
552
+ spend_df_aggregations = {
553
+ c: "sum" for c in spend_df.columns if c not in exclude_columns
554
+ }
555
+ spend_df = spend_df.groupby("Date").agg(spend_df_aggregations).reset_index()
556
+ # spend_df['Week'] = pd.to_datetime(spend_df['Week'], errors='coerce')
557
+ # spend_df = spend_df.sort_values(by='Week')
558
+
559
+ channel_list = [col for col in input_df.columns if col not in exclude_columns]
560
+
561
+ response_curves = {}
562
+ mapes = {}
563
+ rmses = {}
564
+ upper_limits = {}
565
+ powers = {}
566
+ r2 = {}
567
+ conv_rates = {}
568
+ output_cols = []
569
+ channels = {}
570
+ sales = None
571
+ dates = input_df.Date.values
572
+ actual_output_dic = {}
573
+ actual_input_dic = {}
574
+
575
+ # channel_list=['programmatic']
576
+ infeasible_channels = [
577
+ c
578
+ for c in contri_df.select_dtypes(include=["float", "int"]).columns
579
+ if contri_df[c].sum() <= 0
580
+ ]
581
+
582
+ channel_list = list(set(channel_list) - set(infeasible_channels))
583
+
584
+ for inp_col in channel_list:
585
+ spends = input_df[inp_col].values
586
+
587
+ x = spends.copy()
588
+
589
+ # upper limit for penalty
590
+ upper_limits[inp_col] = 2 * x.max()
591
+
592
+ out_col = inp_col
593
+ y = output_df[out_col].values.copy()
594
+
595
+ actual_output_dic[inp_col] = y.copy()
596
+ actual_input_dic[inp_col] = x.copy()
597
+ ## track media output columns (used later to compute non-media contributions)
598
+ output_cols.append(out_col)
599
+
600
+ ## scale the input
601
+ power = np.ceil(np.log(x.max()) / np.log(10)) - 3
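+ # e.g. x.max() = 45,000 -> power = ceil(log10(45,000)) - 3 = 5 - 3 = 2, so spends are divided by 10**2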
602
+ if power >= 0:
603
+ x = x / 10**power
604
+
605
+ x = x.astype("float64")
606
+ y = y.astype("float64")
607
+
608
+ if y.max() <= 0.01:
609
+ bounds = (
610
+ (0, 0, 0, 0),
611
+ (3 * 0.01, 1000, 1, x.max() if x.max() > 0 else 0.01),
612
+ )
613
+ else:
614
+ bounds = ((0, 0, 0, 0), (3 * y.max(), 1000, 1, x.max()))
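+ # Fit s-curve parameters (K, b, a, x0) to scaled spends (x) vs contribution (y); s_curve is expected to be defined or imported earlier in this module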
615
+
616
+ params, _ = curve_fit(
617
+ s_curve,
618
+ x,
619
+ y,
620
+ p0=(2 * y.max(), 0.01, 1e-5, x.max()),
621
+ bounds=bounds,
622
+ maxfev=int(1e5),
623
+ )
624
+ mape = (100 * abs(1 - s_curve(x, *params) / y.clip(min=1))).mean()
625
+ rmse = np.sqrt(((y - s_curve(x, *params)) ** 2).mean())
626
+ r2_ = r2_score(y, s_curve(x, *params))
627
+
628
+ response_curves[inp_col] = {
629
+ "K": params[0],
630
+ "b": params[1],
631
+ "a": params[2],
632
+ "x0": params[3],
633
+ }
634
+ mapes[inp_col] = mape
635
+ rmses[inp_col] = rmse
636
+ r2[inp_col] = r2_
637
+ powers[inp_col] = power
638
+
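+ # Conversion rate: total spend divided by the total raw input (e.g. clicks or impressions), guarded against division by zero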
639
+ conv = spend_df[inp_col].sum() / max(input_df[inp_col].sum(), 1e-9)
640
+ conv_rates[inp_col] = conv
641
+
642
+ channel = Channel(
643
+ name=inp_col,
644
+ dates=dates,
645
+ spends=spends,
646
+ # conversion_rate = np.mean(list(conv_rates[inp_col].values())),
647
+ conversion_rate=conv_rates[inp_col],
648
+ response_curve_type="s-curve",
649
+ response_curve_params={
650
+ "K": params[0],
651
+ "b": params[1],
652
+ "a": params[2],
653
+ "x0": params[3],
654
+ },
655
+ bounds=np.array([-10, 10]),
656
+ correction=y - s_curve(x, *params),
657
+ )
658
+ channels[inp_col] = channel
659
+ if sales is None:
660
+ sales = channel.actual_sales
661
+ else:
662
+ sales += channel.actual_sales
663
+
664
+ other_contributions = (
665
+ output_df.drop([*output_cols], axis=1).sum(axis=1, numeric_only=True).values
666
+ )
667
+ correction = output_df.drop(["Date"], axis=1).sum(axis=1).values - (
668
+ sales + other_contributions
669
+ )
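+ # correction is the residual between the total modelled contribution and the sum of media sales plus non-media contributions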
670
+
671
+ # Testing
672
+ # scenario_test_df = pd.DataFrame(
673
+ # columns=["other_contributions", "correction", "sales"]
674
+ # )
675
+ # scenario_test_df["other_contributions"] = other_contributions
676
+ # scenario_test_df["correction"] = correction
677
+ # scenario_test_df["sales"] = sales
678
+ # scenario_test_df.to_csv("test\scenario_test_df.csv", index=False)
679
+ # output_df.to_csv("test\output_df.csv", index=False)
680
+
681
+ scenario = Scenario(
682
+ name="default",
683
+ channels=channels,
684
+ constant=other_contributions,
685
+ correction=correction,
686
+ )
687
+ ## setting session variables
688
+ st.session_state["initialized"] = True
689
+ st.session_state["actual_df"] = input_df
690
+ st.session_state["raw_df"] = raw_df
691
+ st.session_state["contri_df"] = output_df
692
+ default_scenario_dict = class_to_dict(scenario)
693
+ st.session_state["default_scenario_dict"] = default_scenario_dict
694
+ st.session_state["scenario"] = scenario
695
+ st.session_state["channels_list"] = channel_list
696
+ st.session_state["optimization_channels"] = {
697
+ channel_name: False for channel_name in channel_list
698
+ }
699
+ st.session_state["rcs"] = response_curves
700
+
701
+ # orig_rcs_path = os.path.join(
702
+ # st.session_state["project_path"], f"orig_rcs_{target_col}_{panel_col}.json"
703
+ # )
704
+ # if Path(orig_rcs_path).exists():
705
+ # with open(orig_rcs_path, "r") as f:
706
+ # st.session_state["orig_rcs"] = json.load(f)
707
+ # else:
708
+ # st.session_state["orig_rcs"] = response_curves.copy()
709
+ # with open(orig_rcs_path, "w") as f:
710
+ # json.dump(st.session_state["orig_rcs"], f)
711
+
712
+ st.session_state["powers"] = powers
713
+ st.session_state["actual_contribution_df"] = pd.DataFrame(actual_output_dic)
714
+ st.session_state["actual_input_df"] = pd.DataFrame(actual_input_dic)
715
+
716
+ for channel in channels.values():
717
+ st.session_state[channel.name] = numerize(
718
+ channel.actual_total_spends * channel.conversion_rate, 1
719
+ )
720
+
721
+ # st.session_state["xlsx_buffer"] = io.BytesIO()
722
+ #
723
+ # if Path("../saved_scenarios.pkl").exists():
724
+ # with open("../saved_scenarios.pkl", "rb") as f:
725
+ # st.session_state["saved_scenarios"] = pickle.load(f)
726
+ # else:
727
+ # st.session_state["saved_scenarios"] = OrderedDict()
728
+ #
729
+ # st.session_state["total_spends_change"] = 0
730
+ # st.session_state["optimization_channels"] = {
731
+ # channel_name: False for channel_name in channel_list
732
+ # }
733
+ # st.session_state["disable_download_button"] = True
734
+
735
+
736
+ # def initialize_data():
737
+ # # fetch data from excel
738
+ # output = pd.read_excel('data.xlsx',sheet_name=None)
739
+ # raw_df = output['RAW DATA MMM']
740
+ # contribution_df = output['CONTRIBUTION MMM']
741
+ # Revenue_df = output['Revenue']
742
+
743
+ # ## channels to be shown
744
+ # channel_list = []
745
+ # for col in raw_df.columns:
746
+ # if 'click' in col.lower() or 'spend' in col.lower() or 'imp' in col.lower():
747
+ #
748
+ # channel_list.append(col)
749
+ # else:
750
+ # pass
751
+
752
+ # ## NOTE : Considered only Desktop spends for all calculations
753
+ # acutal_df = raw_df[raw_df.Region == 'Desktop'].copy()
754
+ # ## NOTE : Considered one year of data
755
+ # acutal_df = acutal_df[acutal_df.Date>'2020-12-31']
756
+ # actual_df = acutal_df.drop('Region',axis=1).sort_values(by='Date')[[*channel_list,'Date']]
757
+
758
+ # ##load response curves
759
+ # with open('./grammarly_response_curves.json','r') as f:
760
+ # response_curves = json.load(f)
761
+
762
+ # ## create channel dict for scenario creation
763
+ # dates = actual_df.Date.values
764
+ # channels = {}
765
+ # rcs = {}
766
+ # constant = 0.
767
+ # for i,info_dict in enumerate(response_curves):
768
+ # name = info_dict.get('name')
769
+ # response_curve_type = info_dict.get('response_curve')
770
+ # response_curve_params = info_dict.get('params')
771
+ # rcs[name] = response_curve_params
772
+ # if name != 'constant':
773
+ # spends = actual_df[name].values
774
+ # channel = Channel(name=name,dates=dates,
775
+ # spends=spends,
776
+ # response_curve_type=response_curve_type,
777
+ # response_curve_params=response_curve_params,
778
+ # bounds=np.array([-30,30]))
779
+
780
+ # channels[name] = channel
781
+ # else:
782
+ # constant = info_dict.get('value',0.) * len(dates)
783
+
784
+ # ## create scenario
785
+ # scenario = Scenario(name='default', channels=channels, constant=constant)
786
+ # default_scenario_dict = class_to_dict(scenario)
787
+
788
+
789
+ # ## setting session variables
790
+ # st.session_state['initialized'] = True
791
+ # st.session_state['actual_df'] = actual_df
792
+ # st.session_state['raw_df'] = raw_df
793
+ # st.session_state['default_scenario_dict'] = default_scenario_dict
794
+ # st.session_state['scenario'] = scenario
795
+ # st.session_state['channels_list'] = channel_list
796
+ # st.session_state['optimization_channels'] = {channel_name : False for channel_name in channel_list}
797
+ # st.session_state['rcs'] = rcs
798
+ # for channel in channels.values():
799
+ # if channel.name not in st.session_state:
800
+ # st.session_state[channel.name] = float(channel.actual_total_spends)
801
+
802
+ # if 'xlsx_buffer' not in st.session_state:
803
+ # st.session_state['xlsx_buffer'] = io.BytesIO()
804
+
805
+ # ## for saving scenarios
806
+ # if 'saved_scenarios' not in st.session_state:
807
+ # if Path('../saved_scenarios.pkl').exists():
808
+ # with open('../saved_scenarios.pkl','rb') as f:
809
+ # st.session_state['saved_scenarios'] = pickle.load(f)
810
+
811
+ # else:
812
+ # st.session_state['saved_scenarios'] = OrderedDict()
813
+
814
+ # if 'total_spends_change' not in st.session_state:
815
+ # st.session_state['total_spends_change'] = 0
816
+
817
+ # if 'optimization_channels' not in st.session_state:
818
+ # st.session_state['optimization_channels'] = {channel_name : False for channel_name in channel_list}
819
+
820
+
821
+ # if 'disable_download_button' not in st.session_state:
822
+ # st.session_state['disable_download_button'] = True
823
+ def create_channel_summary(scenario, target_column):
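+ # Build a per-channel summary table of actual spends, the target metric, and ROI for the given scenario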
824
+ def round_off(x, round_off_decimal=0):
825
+ # Round to the given number of decimals; values between -1 and 1 keep at least one significant digit
826
+ try:
827
+ x = float(x)
828
+ if x < 1 and x > 0:
829
+ round_off_decimal = int(np.floor(np.abs(np.log10(x)))) + max(
830
+ round_off_decimal, 1
831
+ )
832
+ x = np.round(x, round_off_decimal)
833
+ elif x < 0 and x > -1:
834
+ round_off_decimal = int(np.floor(np.abs(np.log10(np.abs(x))))) + max(
835
+ round_off_decimal, 1
836
+ )
837
+ x = -np.round(np.abs(x), round_off_decimal) # round the magnitude and keep the negative sign
838
+ else:
839
+ x = np.round(x, round_off_decimal)
840
+ return x
841
+ except:
842
+ return x
843
+
844
+ summary_columns = []
845
+
846
+ actual_spends_rows = []
847
+
848
+ actual_sales_rows = []
849
+
850
+ actual_roi_rows = []
851
+
852
+ for channel in scenario.channels.values():
853
+
854
+ name_mod = channel.name.replace("_", " ")
855
+
856
+ if name_mod.lower().endswith(" imp"):
857
+ name_mod = name_mod.replace("Imp", " Impressions")
858
+
859
+ summary_columns.append(name_mod)
860
+
861
+ actual_spends_rows.append(
862
+ format_numbers(float(channel.actual_total_spends * channel.conversion_rate))
863
+ )
864
+
865
+ actual_sales_rows.append(format_numbers((float(channel.actual_total_sales))))
866
+
867
+ roi = (channel.actual_total_sales) / (
868
+ channel.actual_total_spends * channel.conversion_rate
869
+ )
870
+ if roi < 0.0001:
871
+ roi = 0
872
+
873
+ actual_roi_rows.append(
874
+ decimal_formater(
875
+ str(round_off(roi, 2)),
876
+ n_decimals=2,
877
+ )
878
+ )
879
+
880
+ actual_summary_df = pd.DataFrame(
881
+ [
882
+ summary_columns,
883
+ actual_spends_rows,
884
+ actual_sales_rows,
885
+ actual_roi_rows,
886
+ ]
887
+ ).T
888
+
889
+ actual_summary_df.columns = ["Channel", "Spends", target_column, "ROI"]
890
+
891
+ actual_summary_df[target_column] = actual_summary_df[target_column].map(
892
+ lambda x: str(x)[1:]
893
+ )
894
+
895
+ return actual_summary_df
896
+
897
+
898
+ # def create_channel_summary(scenario):
899
+ #
900
+ # # Provided data
901
+ # data = {
902
+ # 'Channel': ['Paid Search', 'Ga will cid baixo risco', 'Digital tactic others', 'Fb la tier 1', 'Fb la tier 2', 'Paid social others', 'Programmatic', 'Kwai', 'Indicacao', 'Infleux', 'Influencer'],
903
+ # 'Spends': ['$ 11.3K', '$ 155.2K', '$ 50.7K', '$ 125.4K', '$ 125.2K', '$ 105K', '$ 3.3M', '$ 47.5K', '$ 55.9K', '$ 632.3K', '$ 48.3K'],
904
+ # 'Revenue': ['558.0K', '3.5M', '5.2M', '3.1M', '3.1M', '2.1M', '20.8M', '1.6M', '728.4K', '22.9M', '4.8M']
905
+ # }
906
+ #
907
+ # # Create DataFrame
908
+ # df = pd.DataFrame(data)
909
+ #
910
+ # # Convert currency strings to numeric values
911
+ # df['Spends'] = df['Spends'].replace({'\$': '', 'K': '*1e3', 'M': '*1e6'}, regex=True).map(pd.eval).astype(int)
912
+ # df['Revenue'] = df['Revenue'].replace({'\$': '', 'K': '*1e3', 'M': '*1e6'}, regex=True).map(pd.eval).astype(int)
913
+ #
914
+ # # Calculate ROI
915
+ # df['ROI'] = ((df['Revenue'] - df['Spends']) / df['Spends'])
916
+ #
917
+ # # Format columns
918
+ # format_currency = lambda x: f"${x:,.1f}"
919
+ # format_roi = lambda x: f"{x:.1f}"
920
+ #
921
+ # df['Spends'] = ['$ 11.3K', '$ 155.2K', '$ 50.7K', '$ 125.4K', '$ 125.2K', '$ 105K', '$ 3.3M', '$ 47.5K', '$ 55.9K', '$ 632.3K', '$ 48.3K']
922
+ # df['Revenue'] = ['$ 536.3K', '$ 3.4M', '$ 5M', '$ 3M', '$ 3M', '$ 2M', '$ 20M', '$ 1.5M', '$ 7.1M', '$ 22M', '$ 4.6M']
923
+ # df['ROI'] = df['ROI'].apply(format_roi)
924
+ #
925
+ # return df
926
+
927
+
928
+ # @st.cache_resource()
929
+ # def create_contribution_pie(_scenario):
930
+ # colors_map = {
931
+ # col: color
932
+ # for col, color in zip(
933
+ # st.session_state["channels_list"],
934
+ # plotly.colors.n_colors(
935
+ # plotly.colors.hex_to_rgb("#BE6468"),
936
+ # plotly.colors.hex_to_rgb("#E7B8B7"),
937
+ # 20,
938
+ # ),
939
+ # )
940
+ # }
941
+ # total_contribution_fig = make_subplots(
942
+ # rows=1,
943
+ # cols=2,
944
+ # subplot_titles=["Spends", "Revenue"],
945
+ # specs=[[{"type": "pie"}, {"type": "pie"}]],
946
+ # )
947
+ # total_contribution_fig.add_trace(
948
+ # go.Pie(
949
+ # labels=[
950
+ # channel_name_formating(channel_name)
951
+ # for channel_name in st.session_state["channels_list"]
952
+ # ]
953
+ # + ["Non Media"],
954
+ # values=[
955
+ # round(
956
+ # _scenario.channels[channel_name].actual_total_spends
957
+ # * _scenario.channels[channel_name].conversion_rate,
958
+ # 1,
959
+ # )
960
+ # for channel_name in st.session_state["channels_list"]
961
+ # ]
962
+ # + [0],
963
+ # marker=dict(
964
+ # colors=[
965
+ # plotly.colors.label_rgb(colors_map[channel_name])
966
+ # for channel_name in st.session_state["channels_list"]
967
+ # ]
968
+ # + ["#F0F0F0"]
969
+ # ),
970
+ # hole=0.3,
971
+ # ),
972
+ # row=1,
973
+ # col=1,
974
+ # )
975
+
976
+ # total_contribution_fig.add_trace(
977
+ # go.Pie(
978
+ # labels=[
979
+ # channel_name_formating(channel_name)
980
+ # for channel_name in st.session_state["channels_list"]
981
+ # ]
982
+ # + ["Non Media"],
983
+ # values=[
984
+ # _scenario.channels[channel_name].actual_total_sales
985
+ # for channel_name in st.session_state["channels_list"]
986
+ # ]
987
+ # + [_scenario.correction.sum() + _scenario.constant.sum()],
988
+ # hole=0.3,
989
+ # ),
990
+ # row=1,
991
+ # col=2,
992
+ # )
993
+
994
+ # total_contribution_fig.update_traces(
995
+ # textposition="inside", texttemplate="%{percent:.1%}"
996
+ # )
997
+ # total_contribution_fig.update_layout(
998
+ # uniformtext_minsize=12,
999
+ # title="Channel contribution",
1000
+ # uniformtext_mode="hide",
1001
+ # )
1002
+ # return total_contribution_fig
1003
+
1004
+
1005
+ # @st.cache_resource()
1006
+ # def create_contribution_pie(_scenario):
1007
+ # colors = plotly.colors.qualitative.Plotly # A diverse color palette
1008
+ # colors_map = {
1009
+ # col: colors[i % len(colors)]
1010
+ # for i, col in enumerate(st.session_state["channels_list"])
1011
+ # }
1012
+
1013
+ # total_contribution_fig = make_subplots(
1014
+ # rows=1,
1015
+ # cols=2,
1016
+ # subplot_titles=["Spends", "Revenue"],
1017
+ # specs=[[{"type": "pie"}, {"type": "pie"}]],
1018
+ # )
1019
+
1020
+ # total_contribution_fig.add_trace(
1021
+ # go.Pie(
1022
+ # labels=[
1023
+ # channel_name_formating(channel_name)
1024
+ # for channel_name in st.session_state["channels_list"]
1025
+ # ]
1026
+ # + ["Non Media"],
1027
+ # values=[
1028
+ # round(
1029
+ # _scenario.channels[channel_name].actual_total_spends
1030
+ # * _scenario.channels[channel_name].conversion_rate,
1031
+ # 1,
1032
+ # )
1033
+ # for channel_name in st.session_state["channels_list"]
1034
+ # ]
1035
+ # + [0],
1036
+ # marker=dict(
1037
+ # colors=[
1038
+ # colors_map[channel_name]
1039
+ # for channel_name in st.session_state["channels_list"]
1040
+ # ]
1041
+ # + ["#F0F0F0"]
1042
+ # ),
1043
+ # hole=0.3,
1044
+ # ),
1045
+ # row=1,
1046
+ # col=1,
1047
+ # )
1048
+
1049
+ # total_contribution_fig.add_trace(
1050
+ # go.Pie(
1051
+ # labels=[
1052
+ # channel_name_formating(channel_name)
1053
+ # for channel_name in st.session_state["channels_list"]
1054
+ # ]
1055
+ # + ["Non Media"],
1056
+ # values=[
1057
+ # _scenario.channels[channel_name].actual_total_sales
1058
+ # for channel_name in st.session_state["channels_list"]
1059
+ # ]
1060
+ # + [_scenario.correction.sum() + _scenario.constant.sum()],
1061
+ # marker=dict(
1062
+ # colors=[
1063
+ # colors_map[channel_name]
1064
+ # for channel_name in st.session_state["channels_list"]
1065
+ # ]
1066
+ # + ["#F0F0F0"]
1067
+ # ),
1068
+ # hole=0.3,
1069
+ # ),
1070
+ # row=1,
1071
+ # col=2,
1072
+ # )
1073
+
1074
+ # total_contribution_fig.update_traces(
1075
+ # textposition="inside", texttemplate="%{percent:.1%}"
1076
+ # )
1077
+ # total_contribution_fig.update_layout(
1078
+ # uniformtext_minsize=12,
1079
+ # title="Channel contribution",
1080
+ # uniformtext_mode="hide",
1081
+ # )
1082
+ # return total_contribution_fig
1083
+
1084
+
1085
+ # @st.cache_resource()
1086
+ def create_contribution_pie(_scenario, target_col):
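+ # Side-by-side pie charts of spend share and target-metric contribution per channel, with a 'Non Media' slice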
1087
+ colors = plotly.colors.qualitative.Plotly # A diverse color palette
1088
+ colors_map = {
1089
+ col: colors[i % len(colors)]
1090
+ for i, col in enumerate(st.session_state["channels_list"])
1091
+ }
1092
+
1093
+ spends_values = [
1094
+ round(
1095
+ _scenario.channels[channel_name].actual_total_spends
1096
+ * _scenario.channels[channel_name].conversion_rate,
1097
+ 1,
1098
+ )
1099
+ for channel_name in st.session_state["channels_list"]
1100
+ ]
1101
+ spends_values.append(0) # Adding Non Media value
1102
+
1103
+ revenue_values = [
1104
+ _scenario.channels[channel_name].actual_total_sales
1105
+ for channel_name in st.session_state["channels_list"]
1106
+ ]
1107
+ revenue_values.append(
1108
+ _scenario.correction.sum() + _scenario.constant.sum()
1109
+ ) # Adding Non Media value
1110
+
1111
+ total_contribution_fig = make_subplots(
1112
+ rows=1,
1113
+ cols=2,
1114
+ subplot_titles=["Spend", target_col],
1115
+ specs=[[{"type": "pie"}, {"type": "pie"}]],
1116
+ )
1117
+
1118
+ total_contribution_fig.add_trace(
1119
+ go.Pie(
1120
+ labels=[
1121
+ channel_name_formating(channel_name)
1122
+ for channel_name in st.session_state["channels_list"]
1123
+ ]
1124
+ + ["Non Media"],
1125
+ values=spends_values,
1126
+ marker=dict(
1127
+ colors=[
1128
+ colors_map[channel_name]
1129
+ for channel_name in st.session_state["channels_list"]
1130
+ ]
1131
+ + ["#F0F0F0"]
1132
+ ),
1133
+ hole=0.3,
1134
+ ),
1135
+ row=1,
1136
+ col=1,
1137
+ )
1138
+
1139
+ total_contribution_fig.add_trace(
1140
+ go.Pie(
1141
+ labels=[
1142
+ channel_name_formating(channel_name)
1143
+ for channel_name in st.session_state["channels_list"]
1144
+ ]
1145
+ + ["Non Media"],
1146
+ values=revenue_values,
1147
+ marker=dict(
1148
+ colors=[
1149
+ colors_map[channel_name]
1150
+ for channel_name in st.session_state["channels_list"]
1151
+ ]
1152
+ + ["#F0F0F0"]
1153
+ ),
1154
+ hole=0.3,
1155
+ ),
1156
+ row=1,
1157
+ col=2,
1158
+ )
1159
+
1160
+ total_contribution_fig.update_traces(
1161
+ textposition="inside", texttemplate="%{percent:.1%}"
1162
+ )
1163
+ total_contribution_fig.update_layout(
1164
+ uniformtext_minsize=12,
1165
+ title="Channel contribution",
1166
+ uniformtext_mode="hide",
1167
+ )
1168
+ return total_contribution_fig
1169
+
1170
+
1171
+ # def create_contribuion_stacked_plot(scenario):
1172
+ # weekly_contribution_fig = make_subplots(rows=1, cols=2,subplot_titles=['Spends','Revenue'],specs=[[{"type": "bar"}, {"type": "bar"}]])
1173
+ # raw_df = st.session_state['raw_df']
1174
+ # df = raw_df.sort_values(by='Date')
1175
+ # x = df.Date
1176
+ # weekly_spends_data = []
1177
+ # weekly_sales_data = []
1178
+ # for channel_name in st.session_state['channels_list']:
1179
+ # weekly_spends_data.append((go.Bar(x=x,
1180
+ # y=scenario.channels[channel_name].actual_spends * scenario.channels[channel_name].conversion_rate,
1181
+ # name=channel_name_formating(channel_name),
1182
+ # hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}",
1183
+ # legendgroup=channel_name)))
1184
+ # weekly_sales_data.append((go.Bar(x=x,
1185
+ # y=scenario.channels[channel_name].actual_sales,
1186
+ # name=channel_name_formating(channel_name),
1187
+ # hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
1188
+ # legendgroup=channel_name, showlegend=False)))
1189
+ # for _d in weekly_spends_data:
1190
+ # weekly_contribution_fig.add_trace(_d, row=1, col=1)
1191
+ # for _d in weekly_sales_data:
1192
+ # weekly_contribution_fig.add_trace(_d, row=1, col=2)
1193
+ # weekly_contribution_fig.add_trace(go.Bar(x=x,
1194
+ # y=scenario.constant + scenario.correction,
1195
+ # name='Non Media',
1196
+ # hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}"), row=1, col=2)
1197
+ # weekly_contribution_fig.update_layout(barmode='stack', title='Channel contribuion by week', xaxis_title='Date')
1198
+ # weekly_contribution_fig.update_xaxes(showgrid=False)
1199
+ # weekly_contribution_fig.update_yaxes(showgrid=False)
1200
+ # return weekly_contribution_fig
1201
+
1202
+ # @st.cache_resource()
1203
+ # def create_channel_spends_sales_plot(channel):
1204
+ # if channel is not None:
1205
+ # x = channel.dates
1206
+ # _spends = channel.actual_spends * channel.conversion_rate
1207
+ # _sales = channel.actual_sales
1208
+ # channel_sales_spends_fig = make_subplots(specs=[[{"secondary_y": True}]])
1209
+ # channel_sales_spends_fig.add_trace(go.Bar(x=x, y=_sales,marker_color='#c1f7dc',name='Revenue', hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}"), secondary_y = False)
1210
+ # channel_sales_spends_fig.add_trace(go.Scatter(x=x, y=_spends,line=dict(color='#005b96'),name='Spends',hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}"), secondary_y = True)
1211
+ # channel_sales_spends_fig.update_layout(xaxis_title='Date',yaxis_title='Revenue',yaxis2_title='Spends ($)',title='Channel spends and Revenue week wise')
1212
+ # channel_sales_spends_fig.update_xaxes(showgrid=False)
1213
+ # channel_sales_spends_fig.update_yaxes(showgrid=False)
1214
+ # else:
1215
+ # raw_df = st.session_state['raw_df']
1216
+ # df = raw_df.sort_values(by='Date')
1217
+ # x = df.Date
1218
+ # scenario = class_from_dict(st.session_state['default_scenario_dict'])
1219
+ # _sales = scenario.constant + scenario.correction
1220
+ # channel_sales_spends_fig = make_subplots(specs=[[{"secondary_y": True}]])
1221
+ # channel_sales_spends_fig.add_trace(go.Bar(x=x, y=_sales,marker_color='#c1f7dc',name='Revenue', hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}"), secondary_y = False)
1222
+ # # channel_sales_spends_fig.add_trace(go.Scatter(x=x, y=_spends,line=dict(color='#15C39A'),name='Spends',hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}"), secondary_y = True)
1223
+ # channel_sales_spends_fig.update_layout(xaxis_title='Date',yaxis_title='Revenue',yaxis2_title='Spends ($)',title='Channel spends and Revenue week wise')
1224
+ # channel_sales_spends_fig.update_xaxes(showgrid=False)
1225
+ # channel_sales_spends_fig.update_yaxes(showgrid=False)
1226
+ # return channel_sales_spends_fig
1227
+
1228
+
1229
+ # Define a shared color palette
1230
+
1231
+
1232
+ # def create_contribution_pie():
1233
+ # color_palette = ['#F3F3F0', '#5E7D7E', '#2FA1FF', '#00EDED', '#00EAE4', '#304550', '#EDEBEB', '#7FBEFD', '#003059', '#A2F3F3', '#E1D6E2', '#B6B6B6']
1234
+ # total_contribution_fig = make_subplots(rows=1, cols=2, subplot_titles=['Spends', 'Revenue'], specs=[[{"type": "pie"}, {"type": "pie"}]])
1235
+ #
1236
+ # channels_list = ['Paid Search', 'Ga will cid baixo risco', 'Digital tactic others', 'Fb la tier 1', 'Fb la tier 2', 'Paid social others', 'Programmatic', 'Kwai', 'Indicacao', 'Infleux', 'Influencer', 'Non Media']
1237
+ #
1238
+ # # Assign colors from the limited palette to channels
1239
+ # colors_map = {col: color_palette[i % len(color_palette)] for i, col in enumerate(channels_list)}
1240
+ # colors_map['Non Media'] = color_palette[5] # Assign fixed green color for 'Non Media'
1241
+ #
1242
+ # # Hardcoded values for Spends and Revenue
1243
+ # spends_values = [0.5, 3.36, 1.1, 2.7, 2.7, 2.27, 70.6, 1, 1, 13.7, 1, 0]
1244
+ # revenue_values = [1, 4, 5, 3, 3, 2, 50.8, 1.5, 0.7, 13, 0, 16]
1245
+ #
1246
+ # # Add trace for Spends pie chart
1247
+ # total_contribution_fig.add_trace(
1248
+ # go.Pie(
1249
+ # labels=[channel_name for channel_name in channels_list],
1250
+ # values=spends_values,
1251
+ # marker=dict(colors=[colors_map[channel_name] for channel_name in channels_list]),
1252
+ # hole=0.3
1253
+ # ),
1254
+ # row=1, col=1
1255
+ # )
1256
+ #
1257
+ # # Add trace for Revenue pie chart
1258
+ # total_contribution_fig.add_trace(
1259
+ # go.Pie(
1260
+ # labels=[channel_name for channel_name in channels_list],
1261
+ # values=revenue_values,
1262
+ # marker=dict(colors=[colors_map[channel_name] for channel_name in channels_list]),
1263
+ # hole=0.3
1264
+ # ),
1265
+ # row=1, col=2
1266
+ # )
1267
+ #
1268
+ # total_contribution_fig.update_traces(textposition='inside', texttemplate='%{percent:.1%}')
1269
+ # total_contribution_fig.update_layout(uniformtext_minsize=12, title='Channel contribution', uniformtext_mode='hide')
1270
+ # return total_contribution_fig
1271
+
1272
+ # @st.cache_resource()
1273
+ # def create_contribuion_stacked_plot(_scenario):
1274
+ # weekly_contribution_fig = make_subplots(
1275
+ # rows=1,
1276
+ # cols=2,
1277
+ # subplot_titles=["Spends", "Revenue"],
1278
+ # specs=[[{"type": "bar"}, {"type": "bar"}]],
1279
+ # )
1280
+ # raw_df = st.session_state["raw_df"]
1281
+ # df = raw_df.sort_values(by="Date")
1282
+ # x = df.Date
1283
+ # weekly_spends_data = []
1284
+ # weekly_sales_data = []
1285
+
1286
+ # for i, channel_name in enumerate(st.session_state["channels_list"]):
1287
+ # color = color_palette[i % len(color_palette)]
1288
+
1289
+ # weekly_spends_data.append(
1290
+ # go.Bar(
1291
+ # x=x,
1292
+ # y=_scenario.channels[channel_name].actual_spends
1293
+ # * _scenario.channels[channel_name].conversion_rate,
1294
+ # name=channel_name_formating(channel_name),
1295
+ # hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}",
1296
+ # legendgroup=channel_name,
1297
+ # marker_color=color,
1298
+ # )
1299
+ # )
1300
+
1301
+ # weekly_sales_data.append(
1302
+ # go.Bar(
1303
+ # x=x,
1304
+ # y=_scenario.channels[channel_name].actual_sales,
1305
+ # name=channel_name_formating(channel_name),
1306
+ # hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
1307
+ # legendgroup=channel_name,
1308
+ # showlegend=False,
1309
+ # marker_color=color,
1310
+ # )
1311
+ # )
1312
+
1313
+ # for _d in weekly_spends_data:
1314
+ # weekly_contribution_fig.add_trace(_d, row=1, col=1)
1315
+ # for _d in weekly_sales_data:
1316
+ # weekly_contribution_fig.add_trace(_d, row=1, col=2)
1317
+
1318
+ # weekly_contribution_fig.add_trace(
1319
+ # go.Bar(
1320
+ # x=x,
1321
+ # y=_scenario.constant + _scenario.correction,
1322
+ # name="Non Media",
1323
+ # hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
1324
+ # marker_color=color_palette[-1],
1325
+ # ),
1326
+ # row=1,
1327
+ # col=2,
1328
+ # )
1329
+
1330
+ # weekly_contribution_fig.update_layout(
1331
+ # barmode="stack",
1332
+ # title="Channel contribution by week",
1333
+ # xaxis_title="Date",
1334
+ # )
1335
+ # weekly_contribution_fig.update_xaxes(showgrid=False)
1336
+ # weekly_contribution_fig.update_yaxes(showgrid=False)
1337
+ # return weekly_contribution_fig
1338
+
1339
+
1340
+ # @st.cache_resource()
1341
+ def create_contribuion_stacked_plot(_scenario, target_col):
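+ # Weekly stacked bar charts: spends on the left, target-metric contribution on the right, with 'Non Media' added to contributions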
1342
+ color_palette = plotly.colors.qualitative.Plotly # A diverse color palette
1343
+
1344
+ weekly_contribution_fig = make_subplots(
1345
+ rows=1,
1346
+ cols=2,
1347
+ subplot_titles=["Spend", target_col],
1348
+ specs=[[{"type": "bar"}, {"type": "bar"}]],
1349
+ )
1350
+
1351
+ raw_df = st.session_state["raw_df"]
1352
+ df = raw_df.sort_values(by="Date")
1353
+ x = df.Date
1354
+ weekly_spends_data = []
1355
+ weekly_sales_data = []
1356
+
1357
+ for i, channel_name in enumerate(st.session_state["channels_list"]):
1358
+ color = color_palette[i % len(color_palette)]
1359
+
1360
+ weekly_spends_data.append(
1361
+ go.Bar(
1362
+ x=x,
1363
+ y=_scenario.channels[channel_name].actual_spends
1364
+ * _scenario.channels[channel_name].conversion_rate,
1365
+ name=channel_name_formating(channel_name),
1366
+ hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}",
1367
+ legendgroup=channel_name,
1368
+ marker_color=color,
1369
+ )
1370
+ )
1371
+
1372
+ weekly_sales_data.append(
1373
+ go.Bar(
1374
+ x=x,
1375
+ y=_scenario.channels[channel_name].actual_sales,
1376
+ name=channel_name_formating(channel_name),
1377
+ hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
1378
+ legendgroup=channel_name,
1379
+ showlegend=False,
1380
+ marker_color=color,
1381
+ )
1382
+ )
1383
+
1384
+ for _d in weekly_spends_data:
1385
+ weekly_contribution_fig.add_trace(_d, row=1, col=1)
1386
+ for _d in weekly_sales_data:
1387
+ weekly_contribution_fig.add_trace(_d, row=1, col=2)
1388
+
1389
+ weekly_contribution_fig.add_trace(
1390
+ go.Bar(
1391
+ x=x,
1392
+ y=_scenario.constant + _scenario.correction,
1393
+ name="Non Media",
1394
+ hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
1395
+ marker_color=color_palette[-1],
1396
+ ),
1397
+ row=1,
1398
+ col=2,
1399
+ )
1400
+
1401
+ weekly_contribution_fig.update_layout(
1402
+ barmode="stack",
1403
+ title="Channel contribution by week",
1404
+ xaxis_title="Date",
1405
+ )
1406
+ weekly_contribution_fig.update_xaxes(showgrid=False)
1407
+ weekly_contribution_fig.update_yaxes(showgrid=False)
1408
+ return weekly_contribution_fig
1409
+
1410
+
1411
+ def create_channel_spends_sales_plot(channel, target_col):
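+ # Dual-axis weekly chart: bars for the target metric and a line for spends; falls back to the non-media baseline when channel is None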
1412
+ if channel is not None:
1413
+ x = channel.dates
1414
+ _spends = channel.actual_spends * channel.conversion_rate
1415
+ _sales = channel.actual_sales
1416
+ channel_sales_spends_fig = make_subplots(specs=[[{"secondary_y": True}]])
1417
+ channel_sales_spends_fig.add_trace(
1418
+ go.Bar(
1419
+ x=x,
1420
+ y=_sales,
1421
+ marker_color=color_palette[
1422
+ 3
1423
+ ], # You can choose a color from the palette
1424
+ name=target_col,
1425
+ hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
1426
+ ),
1427
+ secondary_y=False,
1428
+ )
1429
+
1430
+ channel_sales_spends_fig.add_trace(
1431
+ go.Scatter(
1432
+ x=x,
1433
+ y=_spends,
1434
+ line=dict(
1435
+ color=color_palette[2]
1436
+ ), # You can choose another color from the palette
1437
+ name="Spends",
1438
+ hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}",
1439
+ ),
1440
+ secondary_y=True,
1441
+ )
1442
+
1443
+ channel_sales_spends_fig.update_layout(
1444
+ xaxis_title="Date",
1445
+ yaxis_title=target_col,
1446
+ yaxis2_title="Spend ($)",
1447
+ title="Weekly Channel Spends and " + target_col,
1448
+ )
1449
+ channel_sales_spends_fig.update_xaxes(showgrid=False)
1450
+ channel_sales_spends_fig.update_yaxes(showgrid=False)
1451
+ else:
1452
+ raw_df = st.session_state["raw_df"]
1453
+ df = raw_df.sort_values(by="Date")
1454
+ x = df.Date
1455
+ scenario = class_from_dict(st.session_state["default_scenario_dict"])
1456
+ _sales = scenario.constant + scenario.correction
1457
+ channel_sales_spends_fig = make_subplots(specs=[[{"secondary_y": True}]])
1458
+ channel_sales_spends_fig.add_trace(
1459
+ go.Bar(
1460
+ x=x,
1461
+ y=_sales,
1462
+ marker_color=color_palette[
1463
+ 0
1464
+ ], # You can choose a color from the palette
1465
+ name="Revenue",
1466
+ hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
1467
+ ),
1468
+ secondary_y=False,
1469
+ )
1470
+
1471
+ channel_sales_spends_fig.update_layout(
1472
+ xaxis_title="Date",
1473
+ yaxis_title="Revenue",
1474
+ yaxis2_title="Spend ($)",
1475
+ title="Channel spends and Revenue week-wise",
1476
+ )
1477
+ channel_sales_spends_fig.update_xaxes(showgrid=False)
1478
+ channel_sales_spends_fig.update_yaxes(showgrid=False)
1479
+
1480
+ return channel_sales_spends_fig
1481
+
1482
+
1483
+ def format_numbers(value, n_decimals=1, include_indicator=True):
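+ # Format a number with numerize (e.g. 1.2M), optionally prefixed with the currency indicator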
1484
+ if include_indicator:
1485
+ return f"{CURRENCY_INDICATOR} {numerize(value,n_decimals)}"
1486
+ else:
1487
+ return f"{numerize(value,n_decimals)}"
1488
+
1489
+
1490
+ def decimal_formater(num_string, n_decimals=1):
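+ # Pad a numeric string with trailing zeros so it shows exactly n_decimals decimal places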
1491
+ parts = num_string.split(".")
1492
+ if len(parts) == 1:
1493
+ return num_string + "." + "0" * n_decimals
1494
+ else:
1495
+ to_be_padded = n_decimals - len(parts[-1])
1496
+ if to_be_padded > 0:
1497
+ return num_string + "0" * to_be_padded
1498
+ else:
1499
+ return num_string
1500
+
1501
+
1502
+ def channel_name_formating(channel_name):
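+ # Convert an internal channel column name into a display label (underscores to spaces, Imp/Clicks mapped to Spend)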
1503
+ name_mod = channel_name.replace("_", " ")
1504
+ if name_mod.lower().endswith(" imp"):
1505
+ name_mod = name_mod.replace("Imp", "Spend")
1506
+ elif name_mod.lower().endswith(" clicks"):
1507
+ name_mod = name_mod.replace("Clicks", "Spend")
1508
+ return name_mod
1509
+
1510
+
1511
+ def send_email(email, message):
1512
+ s = smtplib.SMTP("smtp.gmail.com", 587)
1513
+ s.starttls()
1514
+ s.login("[email protected]", "jgydhpfusuremcol")
1515
+ s.sendmail("[email protected]", email, message)
1516
+ s.quit()
1517
+
1518
+
1519
+ # if __name__ == "__main__":
1520
+ # initialize_data()