Upload 40 files
- .gitattributes +36 -35
- .streamlit/config.toml +7 -0
- Home.py +1614 -0
- README.md +13 -13
- config.json +0 -0
- config.yaml +0 -0
- constants.py +188 -0
- data_analysis.py +267 -0
- data_prep.py +185 -0
- db/imp_db.db +3 -0
- db_creation.py +114 -0
- log_application.py +157 -0
- logo.png +0 -0
- logs/e111111_None_20250408.log +1 -0
- logs/e111111_na_demo_user_demo_project_20250407.log +1 -0
- logs/e111111_na_demo_user_demo_project_20250408.log +2 -0
- mmm_tool_document.docx +0 -0
- pages/10_Saved_Scenarios.py +344 -0
- pages/11_AI_Model_Media_Recommendation.py +695 -0
- pages/12_Glossary.py +150 -0
- pages/14_User_Management.py +503 -0
- pages/1_Data_Import.py +1213 -0
- pages/2_Data_Assessment.py +467 -0
- pages/3_AI_Model_Transformations.py +1326 -0
- pages/4_AI_Model_Build.py +0 -0
- pages/5_AI Model_Tuning.py +1215 -0
- pages/6_AI_Model_Validation.py +960 -0
- pages/7_AI_Model_Media_Performance.py +733 -0
- pages/8_Response_Curves.py +530 -0
- pages/9_Scenario_Planner.py +0 -0
- post_gres_cred.py +2 -0
- ppt/template.txt +0 -0
- ppt_utils.py +1419 -0
- requirements.txt +96 -0
- scenario.py +763 -0
- single_manifest.yml +0 -0
- styles.css +97 -0
- temp_stdout.txt +0 -0
- utilities.py +2155 -0
- utilities_with_panel.py +1520 -0
.gitattributes
CHANGED
@@ -1,35 +1,36 @@
 *.7z filter=lfs diff=lfs merge=lfs -text
 *.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
 *.bz2 filter=lfs diff=lfs merge=lfs -text
 *.ckpt filter=lfs diff=lfs merge=lfs -text
 *.ftz filter=lfs diff=lfs merge=lfs -text
 *.gz filter=lfs diff=lfs merge=lfs -text
 *.h5 filter=lfs diff=lfs merge=lfs -text
 *.joblib filter=lfs diff=lfs merge=lfs -text
 *.lfs.* filter=lfs diff=lfs merge=lfs -text
 *.mlmodel filter=lfs diff=lfs merge=lfs -text
 *.model filter=lfs diff=lfs merge=lfs -text
 *.msgpack filter=lfs diff=lfs merge=lfs -text
 *.npy filter=lfs diff=lfs merge=lfs -text
 *.npz filter=lfs diff=lfs merge=lfs -text
 *.onnx filter=lfs diff=lfs merge=lfs -text
 *.ot filter=lfs diff=lfs merge=lfs -text
 *.parquet filter=lfs diff=lfs merge=lfs -text
 *.pb filter=lfs diff=lfs merge=lfs -text
 *.pickle filter=lfs diff=lfs merge=lfs -text
 *.pkl filter=lfs diff=lfs merge=lfs -text
 *.pt filter=lfs diff=lfs merge=lfs -text
 *.pth filter=lfs diff=lfs merge=lfs -text
 *.rar filter=lfs diff=lfs merge=lfs -text
 *.safetensors filter=lfs diff=lfs merge=lfs -text
 saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.tar.* filter=lfs diff=lfs merge=lfs -text
 *.tar filter=lfs diff=lfs merge=lfs -text
 *.tflite filter=lfs diff=lfs merge=lfs -text
 *.tgz filter=lfs diff=lfs merge=lfs -text
 *.wasm filter=lfs diff=lfs merge=lfs -text
 *.xz filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+db/imp_db.db filter=lfs diff=lfs merge=lfs -text
.streamlit/config.toml
ADDED
@@ -0,0 +1,7 @@

[theme]
primaryColor="#f6ad51"
backgroundColor="#FFFFFF"
secondaryBackgroundColor="#F0F2F6"
textColor="#31333F"
font="sans serif"
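Streamlit reads this theme from .streamlit/config.toml at startup, so the palette above applies to every page without further code. If a page ever needs the same values at runtime (for example, to keep chart colours on-theme), a minimal sketch, assuming Streamlit's st.get_option exposes the theme keys:

import streamlit as st

# Read back the theme values that .streamlit/config.toml defines.
primary_color = st.get_option("theme.primaryColor")        # "#f6ad51"
background_color = st.get_option("theme.backgroundColor")  # "#FFFFFF"
st.write(f"Theme in use: {primary_color} on {background_color}")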
Home.py
ADDED
@@ -0,0 +1,1614 @@
1 |
+
# importing pacakages
|
2 |
+
import streamlit as st
|
3 |
+
|
4 |
+
st.set_page_config(layout="wide")
|
5 |
+
from utilities import (
|
6 |
+
load_local_css,
|
7 |
+
set_header,
|
8 |
+
ensure_project_dct_structure,
|
9 |
+
store_hashed_password,
|
10 |
+
verify_password,
|
11 |
+
is_pswrd_flag_set,
|
12 |
+
set_pswrd_flag,
|
13 |
+
)
|
14 |
+
import os
|
15 |
+
from datetime import datetime
|
16 |
+
import pandas as pd
|
17 |
+
import pickle
|
18 |
+
import psycopg2
|
19 |
+
|
20 |
+
#
|
21 |
+
import numbers
|
22 |
+
from collections import OrderedDict
|
23 |
+
import re
|
24 |
+
from ppt_utils import create_ppt
|
25 |
+
from constants import default_dct
|
26 |
+
import time
|
27 |
+
from log_application import log_message, delete_old_log_files
|
28 |
+
import sqlite3
|
29 |
+
|
30 |
+
# setting page config
|
31 |
+
load_local_css("styles.css")
|
32 |
+
set_header()
|
33 |
+
db_cred = None
|
34 |
+
|
35 |
+
# --------------Functions----------------------#
|
36 |
+
|
37 |
+
# # schema = db_cred["schema"]
|
38 |
+
|
39 |
+
|
40 |
+
##API DATA#######################
|
41 |
+
# Function to load gold layer data
|
42 |
+
# @st.cache_data(show_spinner=False)
|
43 |
+
def load_gold_layer_data(table_name):
|
44 |
+
# Fetch Table
|
45 |
+
query = f"""
|
46 |
+
SELECT * FROM {table_name};
|
47 |
+
"""
|
48 |
+
|
49 |
+
# Execute the query and get the results
|
50 |
+
results = query_excecuter_postgres(
|
51 |
+
query, db_cred, insert=False, return_dataframe=True
|
52 |
+
)
|
53 |
+
|
54 |
+
if results is not None and not results.empty:
|
55 |
+
# Create a DataFrame
|
56 |
+
gold_layer_df = results
|
57 |
+
|
58 |
+
else:
|
59 |
+
st.warning("No data found for the selected table.")
|
60 |
+
st.stop()
|
61 |
+
|
62 |
+
# Columns to be removed
|
63 |
+
columns_to_remove = [
|
64 |
+
"clnt_nam",
|
65 |
+
"crte_dt_tm",
|
66 |
+
"crte_by_uid",
|
67 |
+
"updt_dt_tm",
|
68 |
+
"updt_by_uid",
|
69 |
+
"campgn_id",
|
70 |
+
"campgn_nam",
|
71 |
+
"ad_id",
|
72 |
+
"ad_nam",
|
73 |
+
"tctc_id",
|
74 |
+
"tctc_nam",
|
75 |
+
"campgn_grp_id",
|
76 |
+
"campgn_grp_nam",
|
77 |
+
"ad_grp_id",
|
78 |
+
"ad_grp_nam",
|
79 |
+
]
|
80 |
+
|
81 |
+
# TEMP CODE
|
82 |
+
gold_layer_df = gold_layer_df.rename(
|
83 |
+
columns={
|
84 |
+
"imprssns_cnt": "mda_imprssns_cnt",
|
85 |
+
"clcks_cnt": "mda_clcks_cnt",
|
86 |
+
"vd_vws_cnt": "mda_vd_vws_cnt",
|
87 |
+
}
|
88 |
+
)
|
89 |
+
|
90 |
+
# Remove specific columns
|
91 |
+
gold_layer_df = gold_layer_df.drop(columns=columns_to_remove, errors="ignore")
|
92 |
+
|
93 |
+
# Convert columns to numeric or datetime as appropriate
|
94 |
+
for col in gold_layer_df.columns:
|
95 |
+
if (
|
96 |
+
col.startswith("rspns_mtrc_")
|
97 |
+
or col.startswith("mda_")
|
98 |
+
or col.startswith("exogenous_")
|
99 |
+
or col.startswith("internal_")
|
100 |
+
or col in ["spnd_amt"]
|
101 |
+
):
|
102 |
+
gold_layer_df[col] = pd.to_numeric(gold_layer_df[col], errors="coerce")
|
103 |
+
elif col == "rcrd_dt":
|
104 |
+
gold_layer_df[col] = pd.to_datetime(gold_layer_df[col], errors="coerce")
|
105 |
+
|
106 |
+
# Replace columns starting with 'mda_' to 'media_'
|
107 |
+
gold_layer_df.columns = [
|
108 |
+
(col.replace("mda_", "media_") if col.startswith("mda_") else col)
|
109 |
+
for col in gold_layer_df.columns
|
110 |
+
]
|
111 |
+
|
112 |
+
# Identify non-numeric columns
|
113 |
+
non_numeric_columns = gold_layer_df.select_dtypes(exclude=["number"]).columns
|
114 |
+
allow_non_numeric_columns = ["rcrd_dt", "aggrgtn_lvl", "sub_chnnl_nam", "panl_nam"]
|
115 |
+
|
116 |
+
# Remove non-numeric columns except for allowed non-numeric columns
|
117 |
+
non_numeric_columns_to_remove = [
|
118 |
+
col for col in non_numeric_columns if col not in allow_non_numeric_columns
|
119 |
+
]
|
120 |
+
gold_layer_df = gold_layer_df.drop(
|
121 |
+
columns=non_numeric_columns_to_remove, errors="ignore"
|
122 |
+
)
|
123 |
+
|
124 |
+
# Remove specific columns
|
125 |
+
allow_columns = ["rcrd_dt", "aggrgtn_lvl", "sub_chnnl_nam", "panl_nam", "spnd_amt"]
|
126 |
+
for col in gold_layer_df.columns:
|
127 |
+
if (
|
128 |
+
col.startswith("rspns_mtrc_")
|
129 |
+
or col.startswith("media_")
|
130 |
+
or col.startswith("exogenous_")
|
131 |
+
or col.startswith("internal_")
|
132 |
+
):
|
133 |
+
allow_columns.append(col)
|
134 |
+
gold_layer_df = gold_layer_df[allow_columns]
|
135 |
+
|
136 |
+
# Rename columns
|
137 |
+
gold_layer_df = gold_layer_df.rename(
|
138 |
+
columns={
|
139 |
+
"rcrd_dt": "date",
|
140 |
+
"sub_chnnl_nam": "channels",
|
141 |
+
"panl_nam": "panel",
|
142 |
+
"spnd_amt": "spends",
|
143 |
+
}
|
144 |
+
)
|
145 |
+
|
146 |
+
# Clean column values
|
147 |
+
gold_layer_df["panel"] = (
|
148 |
+
gold_layer_df["panel"].astype(str).str.lower().str.strip().str.replace(" ", "_")
|
149 |
+
)
|
150 |
+
gold_layer_df["channels"] = (
|
151 |
+
gold_layer_df["channels"]
|
152 |
+
.astype(str)
|
153 |
+
.str.lower()
|
154 |
+
.str.strip()
|
155 |
+
.str.replace(" ", "_")
|
156 |
+
)
|
157 |
+
|
158 |
+
# Replace columns starting with 'rspns_mtrc_' to 'response_metric_'
|
159 |
+
gold_layer_df.columns = [
|
160 |
+
(
|
161 |
+
col.replace("rspns_mtrc_", "response_metric_")
|
162 |
+
if col.startswith("rspns_mtrc_")
|
163 |
+
else col
|
164 |
+
)
|
165 |
+
for col in gold_layer_df.columns
|
166 |
+
]
|
167 |
+
|
168 |
+
# Get the minimum date from the main dataframe
|
169 |
+
min_date = gold_layer_df["date"].min()
|
170 |
+
|
171 |
+
# Get maximum dates for daily and weekly data
|
172 |
+
max_date_daily = None
|
173 |
+
max_date_weekly = None
|
174 |
+
|
175 |
+
if not gold_layer_df[gold_layer_df["aggrgtn_lvl"] == "daily"].empty:
|
176 |
+
max_date_daily = gold_layer_df[gold_layer_df["aggrgtn_lvl"] == "daily"][
|
177 |
+
"date"
|
178 |
+
].max()
|
179 |
+
|
180 |
+
if not gold_layer_df[gold_layer_df["aggrgtn_lvl"] == "weekly"].empty:
|
181 |
+
max_date_weekly = gold_layer_df[gold_layer_df["aggrgtn_lvl"] == "weekly"][
|
182 |
+
"date"
|
183 |
+
].max() + pd.DateOffset(days=6)
|
184 |
+
|
185 |
+
# Determine final maximum date
|
186 |
+
if max_date_daily is not None and max_date_weekly is not None:
|
187 |
+
final_max_date = max(max_date_daily, max_date_weekly)
|
188 |
+
elif max_date_daily is not None:
|
189 |
+
final_max_date = max_date_daily
|
190 |
+
elif max_date_weekly is not None:
|
191 |
+
final_max_date = max_date_weekly
|
192 |
+
|
193 |
+
# Create a date range with daily frequency
|
194 |
+
date_range = pd.date_range(start=min_date, end=final_max_date, freq="D")
|
195 |
+
|
196 |
+
# Create a base DataFrame with all channels and all panels for each channel
|
197 |
+
unique_channels = gold_layer_df["channels"].unique()
|
198 |
+
unique_panels = gold_layer_df["panel"].unique()
|
199 |
+
base_data = [
|
200 |
+
(channel, panel, date)
|
201 |
+
for channel in unique_channels
|
202 |
+
for panel in unique_panels
|
203 |
+
for date in date_range
|
204 |
+
]
|
205 |
+
base_df = pd.DataFrame(base_data, columns=["channels", "panel", "date"])
|
206 |
+
|
207 |
+
# Process weekly data to convert it to daily
|
208 |
+
if not gold_layer_df[gold_layer_df["aggrgtn_lvl"] == "weekly"].empty:
|
209 |
+
weekly_data = gold_layer_df[gold_layer_df["aggrgtn_lvl"] == "weekly"].copy()
|
210 |
+
daily_data = []
|
211 |
+
|
212 |
+
for index, row in weekly_data.iterrows():
|
213 |
+
week_start = pd.to_datetime(row["date"]) - pd.to_timedelta(
|
214 |
+
pd.to_datetime(row["date"]).weekday(), unit="D"
|
215 |
+
)
|
216 |
+
for i in range(7):
|
217 |
+
daily_date = week_start + pd.DateOffset(days=i)
|
218 |
+
new_row = row.copy()
|
219 |
+
new_row["date"] = daily_date
|
220 |
+
for col in new_row.index:
|
221 |
+
if isinstance(new_row[col], numbers.Number):
|
222 |
+
new_row[col] = new_row[col] / 7
|
223 |
+
daily_data.append(new_row)
|
224 |
+
|
225 |
+
daily_data_df = pd.DataFrame(daily_data)
|
226 |
+
daily_data_df["aggrgtn_lvl"] = "daily"
|
227 |
+
gold_layer_df = pd.concat(
|
228 |
+
[gold_layer_df[gold_layer_df["aggrgtn_lvl"] != "weekly"], daily_data_df],
|
229 |
+
ignore_index=True,
|
230 |
+
)
|
231 |
+
|
232 |
+
# Process monthly data to convert it to daily
|
233 |
+
if not gold_layer_df[gold_layer_df["aggrgtn_lvl"] == "monthly"].empty:
|
234 |
+
monthly_data = gold_layer_df[gold_layer_df["aggrgtn_lvl"] == "monthly"].copy()
|
235 |
+
daily_data = []
|
236 |
+
|
237 |
+
for index, row in monthly_data.iterrows():
|
238 |
+
month_start = pd.to_datetime(row["date"]).replace(day=1)
|
239 |
+
next_month_start = (month_start + pd.DateOffset(months=1)).replace(day=1)
|
240 |
+
days_in_month = (next_month_start - month_start).days
|
241 |
+
|
242 |
+
for i in range(days_in_month):
|
243 |
+
daily_date = month_start + pd.DateOffset(days=i)
|
244 |
+
new_row = row.copy()
|
245 |
+
new_row["date"] = daily_date
|
246 |
+
for col in new_row.index:
|
247 |
+
if isinstance(new_row[col], numbers.Number):
|
248 |
+
new_row[col] = new_row[col] / days_in_month
|
249 |
+
daily_data.append(new_row)
|
250 |
+
|
251 |
+
daily_data_df = pd.DataFrame(daily_data)
|
252 |
+
daily_data_df["aggrgtn_lvl"] = "daily"
|
253 |
+
gold_layer_df = pd.concat(
|
254 |
+
[gold_layer_df[gold_layer_df["aggrgtn_lvl"] != "monthly"], daily_data_df],
|
255 |
+
ignore_index=True,
|
256 |
+
)
|
257 |
+
|
258 |
+
# Remove aggrgtn_lvl column
|
259 |
+
gold_layer_df = gold_layer_df.drop(columns=["aggrgtn_lvl"], errors="ignore")
|
260 |
+
|
261 |
+
# Group by 'panel', and 'date'
|
262 |
+
gold_layer_df = gold_layer_df.groupby(["channels", "panel", "date"]).sum()
|
263 |
+
|
264 |
+
# Merge gold_layer_df to base_df on channels, panel and date
|
265 |
+
gold_layer_df_cleaned = pd.merge(
|
266 |
+
base_df, gold_layer_df, on=["channels", "panel", "date"], how="left"
|
267 |
+
)
|
268 |
+
|
269 |
+
# Pivot the dataframe and rename columns
|
270 |
+
pivot_columns = [
|
271 |
+
col
|
272 |
+
for col in gold_layer_df_cleaned.columns
|
273 |
+
if col not in ["channels", "panel", "date"]
|
274 |
+
]
|
275 |
+
gold_layer_df_cleaned = gold_layer_df_cleaned.pivot_table(
|
276 |
+
index=["date", "panel"], columns="channels", values=pivot_columns, aggfunc="sum"
|
277 |
+
).reset_index()
|
278 |
+
|
279 |
+
# Flatten the columns
|
280 |
+
gold_layer_df_cleaned.columns = [
|
281 |
+
"_".join(col).strip() if col[1] else col[0]
|
282 |
+
for col in gold_layer_df_cleaned.columns.values
|
283 |
+
]
|
284 |
+
|
285 |
+
# Replace columns ending with '_all' to '_total'
|
286 |
+
gold_layer_df_cleaned.columns = [
|
287 |
+
col.replace("_all", "_total") if col.endswith("_all") else col
|
288 |
+
for col in gold_layer_df_cleaned.columns
|
289 |
+
]
|
290 |
+
|
291 |
+
# Clean panel column values
|
292 |
+
gold_layer_df_cleaned["panel"] = (
|
293 |
+
gold_layer_df_cleaned["panel"]
|
294 |
+
.astype(str)
|
295 |
+
.str.lower()
|
296 |
+
.str.strip()
|
297 |
+
.str.replace(" ", "_")
|
298 |
+
)
|
299 |
+
|
300 |
+
# Drop all columns that end with '_total' except those starting with 'response_metric_'
|
301 |
+
cols_to_drop = [
|
302 |
+
col
|
303 |
+
for col in gold_layer_df_cleaned.columns
|
304 |
+
if col.endswith("_total") and not col.startswith("response_metric_")
|
305 |
+
]
|
306 |
+
gold_layer_df_cleaned.drop(columns=cols_to_drop, inplace=True)
|
307 |
+
|
308 |
+
return gold_layer_df_cleaned
|
309 |
+
|
310 |
+
|
311 |
+
def check_valid_name():
|
312 |
+
if (
|
313 |
+
not st.session_state["project_name_box"]
|
314 |
+
.lower()
|
315 |
+
.startswith(defualt_project_prefix)
|
316 |
+
):
|
317 |
+
st.session_state["disable_create_project"] = True
|
318 |
+
with warning_box:
|
319 |
+
st.warning("Project Name should follow naming conventions")
|
320 |
+
st.session_state["warning"] = (
|
321 |
+
"Project Name should follow naming conventions!"
|
322 |
+
)
|
323 |
+
|
324 |
+
with warning_box:
|
325 |
+
st.warning("Project Name should follow naming conventions")
|
326 |
+
st.button("Reset Name", on_click=reset_project_text_box, key="2")
|
327 |
+
|
328 |
+
if st.session_state["project_name_box"] == defualt_project_prefix:
|
329 |
+
with warning_box:
|
330 |
+
st.warning("Cannot Name only with Prefix")
|
331 |
+
st.session_state["warning"] = "Cannot Name only with Prefix"
|
332 |
+
st.session_state["disable_create_project"] = True
|
333 |
+
|
334 |
+
if st.session_state["project_name_box"] in user_projects:
|
335 |
+
with warning_box:
|
336 |
+
st.warning("Project already exists please enter new name")
|
337 |
+
st.session_state["warning"] = "Project already exists please enter new name"
|
338 |
+
st.session_state["disable_create_project"] = True
|
339 |
+
else:
|
340 |
+
st.session_state["disable_create_project"] = False
|
341 |
+
|
342 |
+
|
343 |
+
def query_excecuter_postgres(
|
344 |
+
query,
|
345 |
+
db_path=None,
|
346 |
+
params=None,
|
347 |
+
insert=True,
|
348 |
+
insert_retrieve=False,
|
349 |
+
db_cred=None,
|
350 |
+
):
|
351 |
+
"""
|
352 |
+
Executes a SQL query on a SQLite database, handling both insert and select operations.
|
353 |
+
|
354 |
+
Parameters:
|
355 |
+
query (str): The SQL query to be executed.
|
356 |
+
db_path (str): Path to the SQLite database file.
|
357 |
+
params (tuple, optional): Parameters to pass into the SQL query for parameterized execution.
|
358 |
+
insert (bool, default=True): Flag to determine if the query is an insert operation (default) or a select operation.
|
359 |
+
insert_retrieve (bool, default=False): Flag to determine if the query should insert and then return the inserted ID.
|
360 |
+
|
361 |
+
"""
|
362 |
+
try:
|
363 |
+
# Construct a cross-platform path to the database
|
364 |
+
db_dir = os.path.join("db")
|
365 |
+
os.makedirs(db_dir, exist_ok=True) # Make sure the directory exists
|
366 |
+
db_path = os.path.join(db_dir, "imp_db.db")
|
367 |
+
|
368 |
+
# Establish connection to the SQLite database
|
369 |
+
conn = sqlite3.connect(db_path)
|
370 |
+
except sqlite3.Error as e:
|
371 |
+
st.warning(f"Unable to connect to the SQLite database: {e}")
|
372 |
+
st.stop()
|
373 |
+
|
374 |
+
# Create a cursor object to interact with the database
|
375 |
+
c = conn.cursor()
|
376 |
+
|
377 |
+
try:
|
378 |
+
# Execute the query with or without parameters
|
379 |
+
if params:
|
380 |
+
params = tuple(params)
|
381 |
+
query = query.replace("IN (?)", f"IN ({','.join(['?' for _ in params])})")
|
382 |
+
c.execute(query, params)
|
383 |
+
else:
|
384 |
+
c.execute(query)
|
385 |
+
|
386 |
+
if not insert:
|
387 |
+
# If not an insert operation, fetch and return the results
|
388 |
+
results = c.fetchall()
|
389 |
+
return results
|
390 |
+
elif insert_retrieve:
|
391 |
+
# If insert and retrieve operation, commit and return the last inserted row ID
|
392 |
+
conn.commit()
|
393 |
+
return c.lastrowid
|
394 |
+
else:
|
395 |
+
# For standard insert operations, commit the transaction
|
396 |
+
conn.commit()
|
397 |
+
|
398 |
+
except Exception as e:
|
399 |
+
st.write(f"Error executing query: {e}")
|
400 |
+
finally:
|
401 |
+
conn.close()
|
402 |
+
|
403 |
+
|
404 |
+
# Function to check if the input contains any SQL keywords
|
405 |
+
def contains_sql_keywords_check(user_input):
|
406 |
+
|
407 |
+
sql_keywords = [
|
408 |
+
"SELECT",
|
409 |
+
"INSERT",
|
410 |
+
"UPDATE",
|
411 |
+
"DELETE",
|
412 |
+
"DROP",
|
413 |
+
"ALTER",
|
414 |
+
"CREATE",
|
415 |
+
"GRANT",
|
416 |
+
"REVOKE",
|
417 |
+
"UNION",
|
418 |
+
"JOIN",
|
419 |
+
"WHERE",
|
420 |
+
"HAVING",
|
421 |
+
"EXEC",
|
422 |
+
"TRUNCATE",
|
423 |
+
"REPLACE",
|
424 |
+
"MERGE",
|
425 |
+
"DECLARE",
|
426 |
+
"SHOW",
|
427 |
+
"FROM",
|
428 |
+
]
|
429 |
+
|
430 |
+
pattern = "|".join(re.escape(keyword) for keyword in sql_keywords)
|
431 |
+
return re.search(pattern, user_input, re.IGNORECASE)
|
432 |
+
|
433 |
+
|
434 |
+
# def get_table_names(schema):
|
435 |
+
# query = f"""
|
436 |
+
# SELECT table_name
|
437 |
+
# FROM information_schema.tables
|
438 |
+
# WHERE table_schema = '{schema}'
|
439 |
+
# AND table_type = 'BASE TABLE'
|
440 |
+
# AND table_name LIKE '%_mmo_gold';
|
441 |
+
# """
|
442 |
+
# table_names = query_excecuter_postgres(query, db_cred, insert=False)
|
443 |
+
# table_names = [table[0] for table in table_names]
|
444 |
+
|
445 |
+
# return table_names
|
446 |
+
|
447 |
+
|
448 |
+
def update_summary_df():
|
449 |
+
"""
|
450 |
+
Updates the 'project_summary_df' in the session state with the latest project
|
451 |
+
summary information based on the most recent updates.
|
452 |
+
|
453 |
+
This function executes a SQL query to retrieve project metadata from a database
|
454 |
+
and stores the result in the session state.
|
455 |
+
|
456 |
+
Uses:
|
457 |
+
- query_excecuter_postgres(query, params=params, insert=False): A function that
|
458 |
+
executes the provided SQL query on a PostgreSQL database.
|
459 |
+
|
460 |
+
Modifies:
|
461 |
+
- st.session_state['project_summary_df']: Updates the dataframe with columns:
|
462 |
+
'Project Number', 'Project Name', 'Last Modified Page', 'Last Modified Time'.
|
463 |
+
"""
|
464 |
+
|
465 |
+
query = f"""
|
466 |
+
WITH LatestUpdates AS (
|
467 |
+
SELECT
|
468 |
+
prj_id,
|
469 |
+
page_nam,
|
470 |
+
updt_dt_tm,
|
471 |
+
ROW_NUMBER() OVER (PARTITION BY prj_id ORDER BY updt_dt_tm DESC) AS rn
|
472 |
+
FROM
|
473 |
+
mmo_project_meta_data
|
474 |
+
)
|
475 |
+
SELECT
|
476 |
+
p.prj_id,
|
477 |
+
p.prj_nam AS prj_nam,
|
478 |
+
lu.page_nam,
|
479 |
+
lu.updt_dt_tm
|
480 |
+
FROM
|
481 |
+
LatestUpdates lu
|
482 |
+
RIGHT JOIN
|
483 |
+
mmo_projects p ON lu.prj_id = p.prj_id
|
484 |
+
WHERE
|
485 |
+
p.prj_ownr_id = ? AND lu.rn = 1
|
486 |
+
"""
|
487 |
+
|
488 |
+
params = (st.session_state["emp_id"],) # Parameters for the SQL query
|
489 |
+
|
490 |
+
# Execute the query and retrieve project summary data
|
491 |
+
project_summary = query_excecuter_postgres(
|
492 |
+
query, db_cred, params=params, insert=False
|
493 |
+
)
|
494 |
+
|
495 |
+
# Update the session state with the project summary dataframe
|
496 |
+
st.session_state["project_summary_df"] = pd.DataFrame(
|
497 |
+
project_summary,
|
498 |
+
columns=[
|
499 |
+
"Project Number",
|
500 |
+
"Project Name",
|
501 |
+
"Last Modified Page",
|
502 |
+
"Last Modified Time",
|
503 |
+
],
|
504 |
+
)
|
505 |
+
|
506 |
+
st.session_state["project_summary_df"] = st.session_state[
|
507 |
+
"project_summary_df"
|
508 |
+
].sort_values(by=["Last Modified Time"], ascending=False)
|
509 |
+
|
510 |
+
|
511 |
+
def reset_project_text_box():
|
512 |
+
st.session_state["project_name_box"] = defualt_project_prefix
|
513 |
+
st.session_state["disable_create_project"] = True
|
514 |
+
|
515 |
+
|
516 |
+
def query_excecuter_sqlite(
|
517 |
+
insert_projects_query,
|
518 |
+
insert_meta_data_query,
|
519 |
+
db_path=None,
|
520 |
+
params_projects=None,
|
521 |
+
params_meta=None,
|
522 |
+
):
|
523 |
+
"""
|
524 |
+
Executes the project insert and associated metadata insert in an SQLite database.
|
525 |
+
|
526 |
+
Parameters:
|
527 |
+
insert_projects_query (str): SQL query for inserting into the mmo_projects table.
|
528 |
+
insert_meta_data_query (str): SQL query for inserting into the mmo_project_meta_data table.
|
529 |
+
db_path (str): Path to the SQLite database file.
|
530 |
+
params_projects (tuple, optional): Parameters for the mmo_projects table insert.
|
531 |
+
params_meta (tuple, optional): Parameters for the mmo_project_meta_data table insert.
|
532 |
+
|
533 |
+
Returns:
|
534 |
+
bool: True if successful, False otherwise.
|
535 |
+
"""
|
536 |
+
try:
|
537 |
+
# Construct a cross-platform path to the database
|
538 |
+
db_dir = os.path.join("db")
|
539 |
+
os.makedirs(db_dir, exist_ok=True) # Make sure the directory exists
|
540 |
+
db_path = os.path.join(db_dir, "imp_db.db")
|
541 |
+
|
542 |
+
# Establish connection to the SQLite database
|
543 |
+
conn = sqlite3.connect(db_path)
|
544 |
+
cursor = conn.cursor()
|
545 |
+
|
546 |
+
# Execute the first insert query into the mmo_projects table
|
547 |
+
cursor.execute(insert_projects_query, params_projects)
|
548 |
+
|
549 |
+
# Get the last inserted project ID
|
550 |
+
prj_id = cursor.lastrowid
|
551 |
+
|
552 |
+
# Modify the parameters for the metadata table with the inserted prj_id
|
553 |
+
params_meta = (prj_id,) + params_meta
|
554 |
+
|
555 |
+
# Execute the second insert query into the mmo_project_meta_data table
|
556 |
+
cursor.execute(insert_meta_data_query, params_meta)
|
557 |
+
|
558 |
+
# Commit the transaction
|
559 |
+
conn.commit()
|
560 |
+
|
561 |
+
except sqlite3.Error as e:
|
562 |
+
st.warning(f"Error executing query: {e}")
|
563 |
+
return False
|
564 |
+
finally:
|
565 |
+
# Close the connection
|
566 |
+
conn.close()
|
567 |
+
|
568 |
+
return True
|
569 |
+
|
570 |
+
|
571 |
+
def new_project():
|
572 |
+
"""
|
573 |
+
Cleans the project name input and inserts project data into the SQLite database,
|
574 |
+
updating session state and triggering UI rerun if successful.
|
575 |
+
"""
|
576 |
+
|
577 |
+
# Define a dictionary containing project data
|
578 |
+
project_dct = default_dct.copy()
|
579 |
+
|
580 |
+
gold_layer_df = pd.DataFrame()
|
581 |
+
if str(api_name).strip().lower() != "na":
|
582 |
+
try:
|
583 |
+
gold_layer_df = load_gold_layer_data(api_name)
|
584 |
+
except Exception as e:
|
585 |
+
st.toast(
|
586 |
+
"Failed to load gold layer data. Please check the gold layer structure and connection.",
|
587 |
+
icon="⚠️",
|
588 |
+
)
|
589 |
+
log_message(
|
590 |
+
"error",
|
591 |
+
f"Error loading gold layer data: {str(e)}",
|
592 |
+
"Home",
|
593 |
+
)
|
594 |
+
|
595 |
+
project_dct["data_import"]["gold_layer_df"] = gold_layer_df
|
596 |
+
|
597 |
+
# Get current time for database insertion
|
598 |
+
inserted_time = datetime.now().isoformat()
|
599 |
+
|
600 |
+
# Define SQL queries for inserting project and metadata into the SQLite database
|
601 |
+
insert_projects_query = """
|
602 |
+
INSERT INTO mmo_projects (prj_ownr_id, prj_nam, alwd_emp_id, crte_dt_tm, crte_by_uid)
|
603 |
+
VALUES (?, ?, ?, ?, ?);
|
604 |
+
"""
|
605 |
+
insert_meta_data_query = """
|
606 |
+
INSERT INTO mmo_project_meta_data (prj_id, page_nam, file_nam, pkl_obj, crte_dt_tm, crte_by_uid, updt_dt_tm)
|
607 |
+
VALUES (?, ?, ?, ?, ?, ?, ?);
|
608 |
+
"""
|
609 |
+
|
610 |
+
# Get current time for metadata update
|
611 |
+
updt_dt_tm = datetime.now().isoformat()
|
612 |
+
|
613 |
+
# Serialize project_dct using pickle
|
614 |
+
project_pkl = pickle.dumps(project_dct)
|
615 |
+
|
616 |
+
# Prepare data for database insertion
|
617 |
+
projects_data = (
|
618 |
+
st.session_state["emp_id"], # prj_ownr_id
|
619 |
+
project_name, # prj_nam
|
620 |
+
",".join(matching_user_id), # alwd_emp_id
|
621 |
+
inserted_time, # crte_dt_tm
|
622 |
+
st.session_state["emp_id"], # crte_by_uid
|
623 |
+
)
|
624 |
+
|
625 |
+
project_meta_data = (
|
626 |
+
"Home", # page_nam
|
627 |
+
"project_dct", # file_nam
|
628 |
+
project_pkl, # pkl_obj
|
629 |
+
inserted_time, # crte_dt_tm
|
630 |
+
st.session_state["emp_id"], # crte_by_uid
|
631 |
+
updt_dt_tm, # updt_dt_tm
|
632 |
+
)
|
633 |
+
|
634 |
+
# Execute the insertion query for SQLite
|
635 |
+
success = query_excecuter_sqlite(
|
636 |
+
insert_projects_query,
|
637 |
+
insert_meta_data_query,
|
638 |
+
params_projects=projects_data,
|
639 |
+
params_meta=project_meta_data,
|
640 |
+
)
|
641 |
+
|
642 |
+
if success:
|
643 |
+
st.success("Project Created")
|
644 |
+
update_summary_df()
|
645 |
+
else:
|
646 |
+
st.error("Failed to create project.")
|
647 |
+
|
648 |
+
|
649 |
+
def validate_password(user_input):
|
650 |
+
# List of SQL keywords to check for
|
651 |
+
sql_keywords = [
|
652 |
+
"SELECT",
|
653 |
+
"INSERT",
|
654 |
+
"UPDATE",
|
655 |
+
"DELETE",
|
656 |
+
"DROP",
|
657 |
+
"ALTER",
|
658 |
+
"CREATE",
|
659 |
+
"GRANT",
|
660 |
+
"REVOKE",
|
661 |
+
"UNION",
|
662 |
+
"JOIN",
|
663 |
+
"WHERE",
|
664 |
+
"HAVING",
|
665 |
+
"EXEC",
|
666 |
+
"TRUNCATE",
|
667 |
+
"REPLACE",
|
668 |
+
"MERGE",
|
669 |
+
"DECLARE",
|
670 |
+
"SHOW",
|
671 |
+
"FROM",
|
672 |
+
]
|
673 |
+
|
674 |
+
# Create a regex pattern for SQL keywords
|
675 |
+
pattern = "|".join(re.escape(keyword) for keyword in sql_keywords)
|
676 |
+
|
677 |
+
# Check if input contains any SQL keywords
|
678 |
+
if re.search(pattern, user_input, re.IGNORECASE):
|
679 |
+
return "SQL keyword detected."
|
680 |
+
|
681 |
+
# Password validation criteria
|
682 |
+
if len(user_input) < 8:
|
683 |
+
return "Password should be at least 8 characters long."
|
684 |
+
if not re.search(r"[A-Z]", user_input):
|
685 |
+
return "Password should contain at least one uppercase letter."
|
686 |
+
if not re.search(r"[0-9]", user_input):
|
687 |
+
return "Password should contain at least one digit."
|
688 |
+
if not re.search(r"[a-z]", user_input):
|
689 |
+
return "Password should contain at least one lowercase letter."
|
690 |
+
if not re.search(r'[!@#$%^&*(),.?":{}|<>]', user_input):
|
691 |
+
return "Password should contain at least one special character."
|
692 |
+
|
693 |
+
# If all checks pass
|
694 |
+
return "Valid input."
|
695 |
+
|
696 |
+
|
697 |
+
def fetch_and_process_projects(emp_id):
|
698 |
+
query = f"""
|
699 |
+
WITH ProjectAccess AS (
|
700 |
+
SELECT
|
701 |
+
p.prj_id,
|
702 |
+
p.prj_nam,
|
703 |
+
p.alwd_emp_id,
|
704 |
+
u.emp_nam AS project_owner
|
705 |
+
FROM mmo_projects p
|
706 |
+
JOIN mmo_users u ON p.prj_ownr_id = u.emp_id
|
707 |
+
)
|
708 |
+
SELECT
|
709 |
+
pa.prj_id,
|
710 |
+
pa.prj_nam,
|
711 |
+
pa.project_owner
|
712 |
+
FROM
|
713 |
+
ProjectAccess pa
|
714 |
+
WHERE
|
715 |
+
pa.alwd_emp_id LIKE ?
|
716 |
+
ORDER BY
|
717 |
+
pa.prj_id;
|
718 |
+
"""
|
719 |
+
|
720 |
+
params = (f"%{emp_id}%",)
|
721 |
+
results = query_excecuter_postgres(query, db_cred, params=params, insert=False)
|
722 |
+
|
723 |
+
# Process the results to create the desired dictionary structure
|
724 |
+
clone_project_dict = {}
|
725 |
+
for row in results:
|
726 |
+
project_id, project_name, project_owner = row
|
727 |
+
|
728 |
+
if project_owner not in clone_project_dict:
|
729 |
+
clone_project_dict[project_owner] = []
|
730 |
+
|
731 |
+
clone_project_dict[project_owner].append(
|
732 |
+
{"project_name": project_name, "project_id": project_id}
|
733 |
+
)
|
734 |
+
|
735 |
+
return clone_project_dict
|
736 |
+
|
737 |
+
|
738 |
+
def get_project_id_from_dict(projects_dict, owner_name, project_name):
|
739 |
+
if owner_name in projects_dict:
|
740 |
+
for project in projects_dict[owner_name]:
|
741 |
+
if project["project_name"] == project_name:
|
742 |
+
return project["project_id"]
|
743 |
+
return None
|
744 |
+
|
745 |
+
|
746 |
+
# def fetch_project_metadata(prj_id):
|
747 |
+
# query = f"""
|
748 |
+
# SELECT
|
749 |
+
# prj_id, page_nam, file_nam, pkl_obj, dshbrd_ts
|
750 |
+
# FROM
|
751 |
+
# mmo_project_meta_data
|
752 |
+
# WHERE
|
753 |
+
# prj_id = ?;
|
754 |
+
# """
|
755 |
+
|
756 |
+
# params = (prj_id,)
|
757 |
+
# return query_excecuter_postgres(query, db_cred, params=params, insert=False)
|
758 |
+
|
759 |
+
|
760 |
+
def fetch_project_metadata(prj_id):
|
761 |
+
# Query to select project metadata
|
762 |
+
query = """
|
763 |
+
SELECT
|
764 |
+
prj_id, page_nam, file_nam, pkl_obj, dshbrd_ts
|
765 |
+
FROM
|
766 |
+
mmo_project_meta_data
|
767 |
+
WHERE
|
768 |
+
prj_id = ?;
|
769 |
+
"""
|
770 |
+
|
771 |
+
params = (prj_id,)
|
772 |
+
return query_excecuter_postgres(query, db_cred, params=params, insert=False)
|
773 |
+
|
774 |
+
|
775 |
+
# def create_new_project(prj_ownr_id, prj_nam, alwd_emp_id, emp_id):
|
776 |
+
# query = f"""
|
777 |
+
# INSERT INTO mmo_projects (prj_ownr_id, prj_nam, alwd_emp_id, crte_by_uid, crte_dt_tm)
|
778 |
+
# VALUES (?, ?, ?, ?, NOW())
|
779 |
+
# RETURNING prj_id;
|
780 |
+
# """
|
781 |
+
|
782 |
+
# params = (prj_ownr_id, prj_nam, alwd_emp_id, emp_id)
|
783 |
+
# result = query_excecuter_postgres(
|
784 |
+
# query, db_cred, params=params, insert=True, insert_retrieve=True
|
785 |
+
# )
|
786 |
+
# return result[0][0]
|
787 |
+
|
788 |
+
|
789 |
+
# def create_new_project(prj_ownr_id, prj_nam, alwd_emp_id, emp_id):
|
790 |
+
# # Query to insert a new project
|
791 |
+
# insert_query = """
|
792 |
+
# INSERT INTO mmo_projects (prj_ownr_id, prj_nam, alwd_emp_id, crte_by_uid, crte_dt_tm)
|
793 |
+
# VALUES (?, ?, ?, ?, DATETIME('now'));
|
794 |
+
# """
|
795 |
+
|
796 |
+
# params = (prj_ownr_id, prj_nam, alwd_emp_id, emp_id)
|
797 |
+
|
798 |
+
# # Execute the insert query
|
799 |
+
# query_excecuter_postgres(insert_query, db_cred, params=params, insert=True)
|
800 |
+
|
801 |
+
# # Retrieve the last inserted prj_id
|
802 |
+
# retrieve_id_query = "SELECT last_insert_rowid();"
|
803 |
+
# result = query_excecuter_postgres(retrieve_id_query, db_cred, insert_retrieve=True)
|
804 |
+
|
805 |
+
# return result[0][0]
|
806 |
+
|
807 |
+
|
808 |
+
def create_new_project(prj_ownr_id, prj_nam, alwd_emp_id, emp_id):
|
809 |
+
# Query to insert a new project
|
810 |
+
insert_query = """
|
811 |
+
INSERT INTO mmo_projects (prj_ownr_id, prj_nam, alwd_emp_id, crte_by_uid, crte_dt_tm)
|
812 |
+
VALUES (?, ?, ?, ?, DATETIME('now'));
|
813 |
+
"""
|
814 |
+
params = (prj_ownr_id, prj_nam, alwd_emp_id, emp_id)
|
815 |
+
|
816 |
+
# Execute the insert query and retrieve the last inserted prj_id directly
|
817 |
+
last_inserted_id = query_excecuter_postgres(
|
818 |
+
insert_query, params=params, insert_retrieve=True
|
819 |
+
)
|
820 |
+
|
821 |
+
return last_inserted_id
|
822 |
+
|
823 |
+
|
824 |
+
def insert_project_metadata(new_prj_id, metadata, created_emp_id):
|
825 |
+
# query = f"""
|
826 |
+
# INSERT INTO mmo_project_meta_data (
|
827 |
+
# prj_id, page_nam, crte_dt_tm, file_nam, pkl_obj, dshbrd_ts, crte_by_uid
|
828 |
+
# )
|
829 |
+
# VALUES (?, ?, NOW(), ?, ?, ?, ?);
|
830 |
+
# """
|
831 |
+
|
832 |
+
query = """
|
833 |
+
INSERT INTO mmo_project_meta_data (
|
834 |
+
prj_id, page_nam, crte_dt_tm, file_nam, pkl_obj, dshbrd_ts, crte_by_uid
|
835 |
+
)
|
836 |
+
VALUES (?, ?, DATETIME('now'), ?, ?, ?, ?);
|
837 |
+
"""
|
838 |
+
|
839 |
+
for row in metadata:
|
840 |
+
params = (new_prj_id, row[1], row[2], row[3], row[4], created_emp_id)
|
841 |
+
query_excecuter_postgres(query, db_cred, params=params, insert=True)
|
842 |
+
|
843 |
+
|
844 |
+
# def delete_projects_by_ids(prj_ids):
|
845 |
+
# # Ensure prj_ids is a tuple to use with the IN clause
|
846 |
+
# prj_ids_tuple = tuple(prj_ids)
|
847 |
+
|
848 |
+
# # Query to delete project metadata
|
849 |
+
# delete_metadata_query = f"""
|
850 |
+
# DELETE FROM mmo_project_meta_data
|
851 |
+
# WHERE prj_id IN ?;
|
852 |
+
# """
|
853 |
+
|
854 |
+
# delete_projects_query = f"""
|
855 |
+
# DELETE FROM mmo_projects
|
856 |
+
# WHERE prj_id IN ?;
|
857 |
+
# """
|
858 |
+
|
859 |
+
# try:
|
860 |
+
# # Delete from metadata table
|
861 |
+
# query_excecuter_postgres(
|
862 |
+
# delete_metadata_query, db_cred, params=(prj_ids_tuple,), insert=True
|
863 |
+
# )
|
864 |
+
|
865 |
+
# # Delete from projects table
|
866 |
+
# query_excecuter_postgres(
|
867 |
+
# delete_projects_query, db_cred, params=(prj_ids_tuple,), insert=True
|
868 |
+
# )
|
869 |
+
|
870 |
+
# except Exception as e:
|
871 |
+
# st.write(f"Error deleting projects: {e}")
|
872 |
+
|
873 |
+
|
874 |
+
def delete_projects_by_ids(prj_ids):
|
875 |
+
# Ensure prj_ids is a tuple to use with the IN clause
|
876 |
+
prj_ids_tuple = tuple(prj_ids)
|
877 |
+
|
878 |
+
# Dynamically generate placeholders for SQLite
|
879 |
+
placeholders = ", ".join(["?"] * len(prj_ids_tuple))
|
880 |
+
|
881 |
+
# Query to delete project metadata with dynamic placeholders
|
882 |
+
delete_metadata_query = f"""
|
883 |
+
DELETE FROM mmo_project_meta_data
|
884 |
+
WHERE prj_id IN ({placeholders});
|
885 |
+
"""
|
886 |
+
|
887 |
+
delete_projects_query = f"""
|
888 |
+
DELETE FROM mmo_projects
|
889 |
+
WHERE prj_id IN ({placeholders});
|
890 |
+
"""
|
891 |
+
|
892 |
+
try:
|
893 |
+
# Delete from metadata table
|
894 |
+
query_excecuter_postgres(
|
895 |
+
delete_metadata_query, db_cred, params=prj_ids_tuple, insert=True
|
896 |
+
)
|
897 |
+
|
898 |
+
# Delete from projects table
|
899 |
+
query_excecuter_postgres(
|
900 |
+
delete_projects_query, db_cred, params=prj_ids_tuple, insert=True
|
901 |
+
)
|
902 |
+
|
903 |
+
except Exception as e:
|
904 |
+
st.write(f"Error deleting projects: {e}")
|
905 |
+
|
906 |
+
|
907 |
+
def fetch_users_with_access(prj_id):
|
908 |
+
# Query to get allowed employee IDs for the project
|
909 |
+
get_allowed_emps_query = """
|
910 |
+
SELECT alwd_emp_id
|
911 |
+
FROM mmo_projects
|
912 |
+
WHERE prj_id = ?;
|
913 |
+
"""
|
914 |
+
|
915 |
+
# Fetch the allowed employee IDs
|
916 |
+
allowed_emp_ids_result = query_excecuter_postgres(
|
917 |
+
get_allowed_emps_query, db_cred, params=(prj_id,), insert=False
|
918 |
+
)
|
919 |
+
|
920 |
+
if not allowed_emp_ids_result:
|
921 |
+
return []
|
922 |
+
|
923 |
+
# Extract the allowed employee IDs (Assuming alwd_emp_id is a comma-separated string)
|
924 |
+
allowed_emp_ids_str = allowed_emp_ids_result[0][0]
|
925 |
+
allowed_emp_ids = allowed_emp_ids_str.split(",")
|
926 |
+
|
927 |
+
# Convert to tuple for the IN clause
|
928 |
+
allowed_emp_ids_tuple = tuple(allowed_emp_ids)
|
929 |
+
|
930 |
+
# Query to get user details for the allowed employee IDs
|
931 |
+
get_users_query = """
|
932 |
+
SELECT emp_id, emp_nam, emp_typ
|
933 |
+
FROM mmo_users
|
934 |
+
WHERE emp_id IN ({});
|
935 |
+
""".format(
|
936 |
+
",".join("?" * len(allowed_emp_ids_tuple))
|
937 |
+
) # Dynamically construct the placeholder list
|
938 |
+
|
939 |
+
# Fetch user details
|
940 |
+
user_details = query_excecuter_postgres(
|
941 |
+
get_users_query, db_cred, params=allowed_emp_ids_tuple, insert=False
|
942 |
+
)
|
943 |
+
|
944 |
+
return user_details
|
945 |
+
|
946 |
+
|
947 |
+
# def update_project_access(prj_id, user_names, new_user_ids):
|
948 |
+
# # Convert the list of new user IDs to a comma-separated string
|
949 |
+
# new_user_ids_str = ",".join(new_user_ids)
|
950 |
+
|
951 |
+
# # Query to update the alwd_emp_id for the specified project
|
952 |
+
# update_access_query = f"""
|
953 |
+
# UPDATE mmo_projects
|
954 |
+
# SET alwd_emp_id = ?
|
955 |
+
# WHERE prj_id = ?;
|
956 |
+
# """
|
957 |
+
|
958 |
+
# # Execute the update query
|
959 |
+
# query_excecuter_postgres(
|
960 |
+
# update_access_query, db_cred, params=(new_user_ids_str, prj_id), insert=True
|
961 |
+
# )
|
962 |
+
# st.success(f"Project {prj_id} access updated successfully")
|
963 |
+
|
964 |
+
|
965 |
+
def fetch_user_ids_from_dict(user_dict, user_names):
|
966 |
+
user_ids = []
|
967 |
+
# Iterate over the user_dict to find matching user names
|
968 |
+
for user_id, details in user_dict.items():
|
969 |
+
if details[0] in user_names:
|
970 |
+
user_ids.append(user_id)
|
971 |
+
return user_ids
|
972 |
+
|
973 |
+
|
974 |
+
def update_project_access(prj_id, user_names, user_dict):
|
975 |
+
# Fetch the new user IDs based on the provided user names from the dictionary
|
976 |
+
new_user_ids = fetch_user_ids_from_dict(user_dict, user_names)
|
977 |
+
|
978 |
+
# Convert the list of new user IDs to a comma-separated string
|
979 |
+
new_user_ids_str = ",".join(new_user_ids)
|
980 |
+
|
981 |
+
# Query to update the alwd_emp_id for the specified project
|
982 |
+
update_access_query = f"""
|
983 |
+
UPDATE mmo_projects
|
984 |
+
SET alwd_emp_id = ?
|
985 |
+
WHERE prj_id = ?;
|
986 |
+
"""
|
987 |
+
|
988 |
+
# Execute the update query
|
989 |
+
query_excecuter_postgres(
|
990 |
+
update_access_query, db_cred, params=(new_user_ids_str, prj_id), insert=True
|
991 |
+
)
|
992 |
+
st.write(f"Project {prj_id} access updated successfully")
|
993 |
+
|
994 |
+
|
995 |
+
def validate_emp_id():
|
996 |
+
|
997 |
+
if st.session_state.sign_up not in st.session_state["unique_ids"].keys():
|
998 |
+
st.warning("You dont have access to the tool please contact admin")
|
999 |
+
|
1000 |
+
|
1001 |
+
# -------------------Front END-------------------------#
|
1002 |
+
|
1003 |
+
st.header("Manage Projects")
|
1004 |
+
|
1005 |
+
unique_users_query = f"""
|
1006 |
+
SELECT DISTINCT emp_id, emp_nam, emp_typ
|
1007 |
+
FROM mmo_users;
|
1008 |
+
"""
|
1009 |
+
|
1010 |
+
if "unique_ids" not in st.session_state:
|
1011 |
+
|
1012 |
+
unique_users_result = query_excecuter_postgres(
|
1013 |
+
unique_users_query, db_cred, insert=False
|
1014 |
+
) # retrieves all the users who has access to MMO TOOL
|
1015 |
+
|
1016 |
+
if len(unique_users_result) == 0:
|
1017 |
+
st.warning("No users data present in db, please contact admin!")
|
1018 |
+
st.stop()
|
1019 |
+
|
1020 |
+
st.session_state["unique_ids"] = {
|
1021 |
+
emp_id: (emp_nam, emp_type) for emp_id, emp_nam, emp_type in unique_users_result
|
1022 |
+
}
|
1023 |
+
|
1024 |
+
|
1025 |
+
if "toggle" not in st.session_state:
|
1026 |
+
st.session_state["toggle"] = 0
|
1027 |
+
|
1028 |
+
|
1029 |
+
if "emp_id" not in st.session_state:
|
1030 |
+
reset_password = st.radio(
|
1031 |
+
"Select An Option",
|
1032 |
+
options=["Login", "Reset Password"],
|
1033 |
+
index=st.session_state["toggle"],
|
1034 |
+
horizontal=True,
|
1035 |
+
)
|
1036 |
+
|
1037 |
+
if reset_password == "Login":
|
1038 |
+
emp_id = st.text_input("Employee id").lower() # emp id
|
1039 |
+
password = st.text_input("Password", max_chars=15, type="password")
|
1040 |
+
login_button = st.button("Login", use_container_width=True)
|
1041 |
+
|
1042 |
+
else:
|
1043 |
+
|
1044 |
+
emp_id = st.text_input(
|
1045 |
+
"Employee id", key="sign_up", on_change=validate_emp_id
|
1046 |
+
).lower()
|
1047 |
+
|
1048 |
+
current_password = st.text_input(
|
1049 |
+
"Enter Current Password and Press Enter to Validate",
|
1050 |
+
max_chars=15,
|
1051 |
+
type="password",
|
1052 |
+
key="current_password",
|
1053 |
+
)
|
1054 |
+
if emp_id:
|
1055 |
+
|
1056 |
+
if emp_id not in st.session_state["unique_ids"].keys():
|
1057 |
+
st.write("Invalid id!")
|
1058 |
+
st.stop()
|
1059 |
+
else:
|
1060 |
+
if not is_pswrd_flag_set(emp_id):
|
1061 |
+
|
1062 |
+
if verify_password(emp_id, current_password):
|
1063 |
+
st.success("Your password key has been successfully validated!")
|
1064 |
+
|
1065 |
+
elif (
|
1066 |
+
not verify_password(emp_id, current_password)
|
1067 |
+
and len(current_password) > 1
|
1068 |
+
):
|
1069 |
+
st.write("Wrong Password Key Please Try Again")
|
1070 |
+
st.stop()
|
1071 |
+
|
1072 |
+
elif verify_password(emp_id, current_password):
|
1073 |
+
st.success("Your password has been successfully validated!")
|
1074 |
+
|
1075 |
+
elif (
|
1076 |
+
not verify_password(emp_id, current_password)
|
1077 |
+
and len(current_password) > 1
|
1078 |
+
):
|
1079 |
+
st.write("Wrong Password Please Try Again")
|
1080 |
+
st.stop()
|
1081 |
+
|
1082 |
+
new_password = st.text_input(
|
1083 |
+
"Enter New Password", max_chars=15, type="password", key="new_password"
|
1084 |
+
)
|
1085 |
+
|
1086 |
+
st.markdown(
|
1087 |
+
"**Password must be at least 8 to 15 characters long and contain at least one uppercase letter, one lowercase letter, one digit, and one special character. No SQL commands allowed.**"
|
1088 |
+
)
|
1089 |
+
|
1090 |
+
validation_result = validate_password(new_password)
|
1091 |
+
|
1092 |
+
confirm_new_password = st.text_input(
|
1093 |
+
"Confirm New Password",
|
1094 |
+
max_chars=15,
|
1095 |
+
type="password",
|
1096 |
+
key="confirm_new_password",
|
1097 |
+
)
|
1098 |
+
|
1099 |
+
reset_button = st.button("Reset Password", use_container_width=True)
|
1100 |
+
|
1101 |
+
if reset_button:
|
1102 |
+
|
1103 |
+
validation_result = validate_password(new_password)
|
1104 |
+
|
1105 |
+
if validation_result != "Valid input.":
|
1106 |
+
st.warning(validation_result)
|
1107 |
+
st.stop()
|
1108 |
+
elif new_password != confirm_new_password:
|
1109 |
+
st.warning(
|
1110 |
+
"The new password and confirmation password do not match. Please try again."
|
1111 |
+
)
|
1112 |
+
st.stop()
|
1113 |
+
else:
|
1114 |
+
store_hashed_password(emp_id, confirm_new_password)
|
1115 |
+
set_pswrd_flag(emp_id)
|
1116 |
+
st.success("Password Reset Successful!")
|
1117 |
+
|
1118 |
+
with st.spinner("Redirecting to Login"):
|
1119 |
+
time.sleep(3)
|
1120 |
+
st.session_state["toggle"] = 0
|
1121 |
+
st.rerun()
|
1122 |
+
|
1123 |
+
st.stop()
|
1124 |
+
|
1125 |
+
if login_button:
|
1126 |
+
|
1127 |
+
if emp_id not in st.session_state["unique_ids"].keys() or len(password) == 0:
|
1128 |
+
st.warning("invalid id or password!")
|
1129 |
+
|
1130 |
+
st.stop()
|
1131 |
+
|
1132 |
+
if not is_pswrd_flag_set(emp_id):
|
1133 |
+
st.warning("Reset password to continue")
|
1134 |
+
with st.spinner("Redirecting"):
|
1135 |
+
st.session_state["toggle"] = 1
|
1136 |
+
time.sleep(2)
|
1137 |
+
st.rerun()
|
1138 |
+
st.stop()
|
1139 |
+
|
1140 |
+
elif verify_password(emp_id, password):
|
1141 |
+
with st.spinner("Loading Saved Projects"):
|
1142 |
+
st.session_state["emp_id"] = emp_id
|
1143 |
+
|
1144 |
+
update_summary_df() # function call to fetch user saved projects
|
1145 |
+
|
1146 |
+
st.session_state["clone_project_dict"] = fetch_and_process_projects(
|
1147 |
+
st.session_state["emp_id"]
|
1148 |
+
)
|
1149 |
+
if "project_dct" in st.session_state:
|
1150 |
+
del st.session_state["project_dct"]
|
1151 |
+
|
1152 |
+
st.session_state["project_name"] = None
|
1153 |
+
|
1154 |
+
delete_old_log_files()
|
1155 |
+
|
1156 |
+
st.rerun()
|
1157 |
+
|
1158 |
+
if (
|
1159 |
+
len(st.session_state["emp_id"]) == 0
|
1160 |
+
or st.session_state["emp_id"]
|
1161 |
+
not in st.session_state["unique_ids"].keys()
|
1162 |
+
):
|
1163 |
+
st.stop()
|
1164 |
+
else:
|
1165 |
+
st.warning("Invalid user name or password")
|
1166 |
+
|
1167 |
+
st.stop()
|
1168 |
+
|
1169 |
+
if st.button("Logout"):
|
1170 |
+
if "emp_id" in st.session_state:
|
1171 |
+
del st.session_state["emp_id"]
|
1172 |
+
st.rerun()
|
1173 |
+
|
1174 |
+
if st.session_state["emp_id"] in st.session_state["unique_ids"].keys():
|
1175 |
+
|
1176 |
+
if "project_name" not in st.session_state:
|
1177 |
+
st.session_state["project_name"] = None
|
1178 |
+
|
1179 |
+
cols1 = st.columns([2, 1])
|
1180 |
+
|
1181 |
+
st.session_state["username"] = st.session_state["unique_ids"][
|
1182 |
+
st.session_state["emp_id"]
|
1183 |
+
][0]
|
1184 |
+
|
1185 |
+
with cols1[0]:
|
1186 |
+
st.markdown(f"**Welcome {st.session_state['username']}**")
|
1187 |
+
with cols1[1]:
|
1188 |
+
st.markdown(f"**Current Project: {st.session_state['project_name']}**")
|
1189 |
+
|
1190 |
+
st.markdown(
|
1191 |
+
"""
|
1192 |
+
Enter project number in the text below and click on load project to load the project.
|
1193 |
+
|
1194 |
+
"""
|
1195 |
+
)
|
1196 |
+
|
1197 |
+
st.markdown("Select Project")
|
1198 |
+
|
1199 |
+
# st.write(type(st.session_state.keys))
|
1200 |
+
|
1201 |
+
if len(st.session_state["project_summary_df"]) != 0:
|
1202 |
+
|
1203 |
+
# Display an editable data table using Streamlit's data editor component
|
1204 |
+
|
1205 |
+
table = st.dataframe(
|
1206 |
+
st.session_state["project_summary_df"],
|
1207 |
+
use_container_width=True,
|
1208 |
+
hide_index=True,
|
1209 |
+
)
|
1210 |
+
|
1211 |
+
project_number = st.selectbox(
|
1212 |
+
"Enter Project number",
|
1213 |
+
options=st.session_state["project_summary_df"]["Project Number"],
|
1214 |
+
)
|
1215 |
+
|
1216 |
+
log_message(
|
1217 |
+
"info",
|
1218 |
+
f"Project number {project_number} selected by employee {st.session_state['emp_id']}.",
|
1219 |
+
"Home",
|
1220 |
+
)
|
1221 |
+
|
1222 |
+
project_col = st.columns(2)
|
1223 |
+
|
1224 |
+
# if "load_project_key" not in st.session_state:
|
1225 |
+
# st.session_state["load_project_key"] = None\
|
1226 |
+
|
1227 |
+
def load_project_fun():
|
1228 |
+
st.session_state["project_name"] = (
|
1229 |
+
st.session_state["project_summary_df"]
|
1230 |
+
.loc[
|
1231 |
+
st.session_state["project_summary_df"]["Project Number"]
|
1232 |
+
== project_number,
|
1233 |
+
"Project Name",
|
1234 |
+
]
|
1235 |
+
.values[0]
|
1236 |
+
) # fetching project name from project number stored in summary df
|
1237 |
+
|
1238 |
+
project_dct_query = f"""
|
1239 |
+
SELECT pkl_obj
|
1240 |
+
FROM mmo_project_meta_data
|
1241 |
+
WHERE prj_id = ? AND file_nam = ?;
|
1242 |
+
"""
|
1243 |
+
# Execute the query and retrieve the result
|
1244 |
+
project_dct_retrieved = query_excecuter_postgres(
|
1245 |
+
project_dct_query,
|
1246 |
+
db_cred,
|
1247 |
+
params=(project_number, "project_dct"),
|
1248 |
+
insert=False,
|
1249 |
+
)
|
1250 |
+
# retrieves project dict (meta data) stored in db
|
1251 |
+
|
1252 |
+
st.session_state["project_dct"] = pickle.loads(
|
1253 |
+
project_dct_retrieved[0][0]
|
1254 |
+
) # converting bytes data to original objet using pickle
|
1255 |
+
st.session_state["project_number"] = project_number
|
1256 |
+
|
1257 |
+
keys_to_keep = [
|
1258 |
+
"unique_ids",
|
1259 |
+
"emp_id",
|
1260 |
+
"project_dct",
|
1261 |
+
"project_name",
|
1262 |
+
"project_number",
|
1263 |
+
"username",
|
1264 |
+
"project_summary_df",
|
1265 |
+
"clone_project_dict",
|
1266 |
+
]
|
1267 |
+
|
1268 |
+
# Clear all keys in st.session_state except the ones to keep
|
1269 |
+
for key in list(st.session_state.keys()):
|
1270 |
+
if key not in keys_to_keep:
|
1271 |
+
del st.session_state[key]
|
1272 |
+
|
1273 |
+
ensure_project_dct_structure(st.session_state["project_dct"], default_dct)
|
1274 |
+
|
1275 |
+
if st.button(
|
1276 |
+
"Load Project",
|
1277 |
+
use_container_width=True,
|
1278 |
+
key="load_project_key",
|
1279 |
+
on_click=load_project_fun,
|
1280 |
+
):
|
1281 |
+
st.success("Project Loded")
|
1282 |
+
|
1283 |
+
# st.rerun() # refresh the page
|
1284 |
+
|
1285 |
+
# st.write(st.session_state['project_dct'])
|
1286 |
+
if "radio_box_index" not in st.session_state:
|
1287 |
+
st.session_state["radio_box_index"] = 0
|
1288 |
+
|
1289 |
+
projct_radio = st.radio(
|
1290 |
+
"Select Options",
|
1291 |
+
[
|
1292 |
+
"Create New Project",
|
1293 |
+
"Modify Project Access",
|
1294 |
+
"Clone Saved Projects",
|
1295 |
+
"Delete Projects",
|
1296 |
+
],
|
1297 |
+
horizontal=True,
|
1298 |
+
index=st.session_state["radio_box_index"],
|
1299 |
+
)
|
1300 |
+
|
1301 |
+
if projct_radio == "Modify Project Access":
|
1302 |
+
|
1303 |
+
with st.expander("Modify Project Access"):
|
1304 |
+
project_number_for_access = st.selectbox(
|
1305 |
+
"Select Project Number",
|
1306 |
+
st.session_state["project_summary_df"]["Project Number"],
|
1307 |
+
)
|
1308 |
+
|
1309 |
+
with st.spinner("Loading"):
|
1310 |
+
users_who_has_access = fetch_users_with_access(
|
1311 |
+
project_number_for_access
|
1312 |
+
)
|
1313 |
+
|
1314 |
+
users_name_who_has_access = [user[1] for user in users_who_has_access]
|
1315 |
+
modified_users_for_access_options = [
|
1316 |
+
details[0]
|
1317 |
+
for user_id, details in st.session_state["unique_ids"].items()
|
1318 |
+
if user_id != st.session_state["emp_id"]
|
1319 |
+
]
|
1320 |
+
|
1321 |
+
users_name_who_has_access = [
|
1322 |
+
name
|
1323 |
+
for name in users_name_who_has_access
|
1324 |
+
if name in modified_users_for_access_options
|
1325 |
+
]
|
1326 |
+
|
1327 |
+
modified_users_for_access = st.multiselect(
|
1328 |
+
"Select or deselect users to grant or revoke access, then click the 'Modify Access' button to submit changes.",
|
1329 |
+
options=modified_users_for_access_options,
|
1330 |
+
default=users_name_who_has_access,
|
1331 |
+
)
|
1332 |
+
|
1333 |
+
if st.button("Modify Access", use_container_width=True):
|
1334 |
+
with st.spinner("Modifying Access"):
|
1335 |
+
update_project_access(
|
1336 |
+
project_number_for_access,
|
1337 |
+
modified_users_for_access,
|
1338 |
+
st.session_state["unique_ids"],
|
1339 |
+
)
|
1340 |
+
|
1341 |
+
if projct_radio == "Create New Project":
|
1342 |
+
|
1343 |
+
with st.expander("Create New Project", expanded=False):
|
1344 |
+
|
1345 |
+
st.session_state["is_create_project_open"] = True
|
1346 |
+
|
1347 |
+
unique_users = [
|
1348 |
+
user[0] for user in st.session_state["unique_ids"].values()
|
1349 |
+
] # fetching unique users who have access to the tool
|
1350 |
+
|
1351 |
+
user_projects = list(
|
1352 |
+
set(st.session_state["project_summary_df"]["Project Name"])
|
1353 |
+
) # fetching the corresponding user's projects
|
1354 |
+
st.markdown(
|
1355 |
+
"""
|
1356 |
+
To create a new project, follow the instructions below:
|
1357 |
+
|
1358 |
+
1. **Project Name**:
|
1359 |
+
- It should start with the client name, followed by the username.
|
1360 |
+
- It should not contain special characters except for underscores (`_`) and should not contain spaces.
|
1361 |
+
- Example format: `<client_name>_<username>_<project_name>`
|
1362 |
+
|
1363 |
+
2. **Select User**: Select the user you want to give access to this project.
|
1364 |
+
|
1365 |
+
3. **Create New Project**: Click **Create New Project** once the above details are entered.
|
1366 |
+
|
1367 |
+
**Example**:
|
1368 |
+
|
1369 |
+
- For a client named "ClientA" and a user named "UserX" with a project named "NewCampaign", the project name should be:
|
1370 |
+
`ClientA_UserX_NewCampaign`
|
1371 |
+
"""
|
1372 |
+
)
|
1373 |
+
|
1374 |
+
project_col1 = st.columns(3)
|
1375 |
+
|
1376 |
+
with project_col1[0]:
|
1377 |
+
|
1378 |
+
# API_tables = get_table_names(schema) # load API files
|
1379 |
+
|
1380 |
+
slection_tables = ["NA"]
|
1381 |
+
|
1382 |
+
api_name = st.selectbox("Select API data", slection_tables, index=0)
|
1383 |
+
|
1384 |
+
# data available through API
|
1385 |
+
# api_path = API_path_dict[api_name]
|
1386 |
+
|
1387 |
+
with project_col1[1]:
|
1388 |
+
defualt_project_prefix = f"{api_name.split('_mmo_')[0]}_{st.session_state['unique_ids'][st.session_state['emp_id']][0]}_".replace(
|
1389 |
+
" ", "_"
|
1390 |
+
).lower()
|
1391 |
+
|
1392 |
+
if "project_name_box" not in st.session_state:
|
1393 |
+
st.session_state["project_name_box"] = defualt_project_prefix
|
1394 |
+
|
1395 |
+
project_name = st.text_input(
|
1396 |
+
"Enter Project Name", key="project_name_box"
|
1397 |
+
)
|
1398 |
+
warning_box = st.empty()
|
1399 |
+
|
1400 |
+
with project_col1[2]:
|
1401 |
+
|
1402 |
+
allowed_users = st.multiselect(
|
1403 |
+
"Select Users who can access this Project",
|
1404 |
+
[val for val in unique_users],
|
1405 |
+
)
|
1406 |
+
|
1407 |
+
allowed_users = list(allowed_users)
|
1408 |
+
|
1409 |
+
matching_user_id = []
|
1410 |
+
|
1411 |
+
if len(allowed_users) > 0:
|
1412 |
+
|
1413 |
+
# converting the selection to comma-separated values to store in db
|
1414 |
+
|
1415 |
+
for emp_id, details in st.session_state["unique_ids"].items():
|
1416 |
+
for name in allowed_users:
|
1417 |
+
if name in details:
|
1418 |
+
matching_user_id.append(emp_id)
|
1419 |
+
break
|
1420 |
+
|
1421 |
+
st.button(
|
1422 |
+
"Reset Project Name",
|
1423 |
+
on_click=reset_project_text_box,
|
1424 |
+
help="",
|
1425 |
+
use_container_width=True,
|
1426 |
+
)
|
1427 |
+
|
1428 |
+
create = st.button(
|
1429 |
+
"Create New Project",
|
1430 |
+
use_container_width=True,
|
1431 |
+
help="Project Name should follow naming convention",
|
1432 |
+
)
|
1433 |
+
|
1434 |
+
if create:
|
1435 |
+
if not project_name.lower().startswith(defualt_project_prefix):
|
1436 |
+
with warning_box:
|
1437 |
+
st.warning("Project Name should follow naming convention")
|
1438 |
+
st.stop()
|
1439 |
+
|
1440 |
+
if project_name == defualt_project_prefix:
|
1441 |
+
with warning_box:
|
1442 |
+
st.warning("Cannot name only with prefix")
|
1443 |
+
st.stop()
|
1444 |
+
|
1445 |
+
if project_name in user_projects:
|
1446 |
+
with warning_box:
|
1447 |
+
st.warning("Project already exists please enter new name")
|
1448 |
+
st.stop()
|
1449 |
+
|
1450 |
+
if not (
|
1451 |
+
2 <= len(project_name) <= 50
|
1452 |
+
and bool(re.match("^[A-Za-z0-9_]*$", project_name))
|
1453 |
+
):
|
1454 |
+
# Store the warning message details in session state
|
1455 |
+
|
1456 |
+
with warning_box:
|
1457 |
+
st.warning(
|
1458 |
+
"Please provide a valid project name (2-50 characters, only A-Z, a-z, 0-9, and _)."
|
1459 |
+
)
|
1460 |
+
st.stop()
|
1461 |
+
|
1462 |
+
if contains_sql_keywords_check(project_name):
|
1463 |
+
with warning_box:
|
1464 |
+
st.warning(
|
1465 |
+
"Input contains SQL keywords. Please avoid using SQL commands."
|
1466 |
+
)
|
1467 |
+
st.stop()
|
1468 |
+
else:
|
1469 |
+
pass
|
1470 |
+
|
1471 |
+
with st.spinner("Creating Project"):
|
1472 |
+
new_project()
|
1473 |
+
|
1474 |
+
with warning_box:
|
1475 |
+
st.write("Project Created")
|
1476 |
+
|
1477 |
+
st.session_state["radio_box_index"] = 1
|
1478 |
+
|
1479 |
+
log_message(
|
1480 |
+
"info",
|
1481 |
+
f"Employee {st.session_state['emp_id']} created new project {project_name}.",
|
1482 |
+
"Home",
|
1483 |
+
)
|
1484 |
+
|
1485 |
+
st.rerun()
|
1486 |
+
|
1487 |
+
if projct_radio == "Clone Saved Projects":
|
1488 |
+
|
1489 |
+
with st.expander("Clone Saved Projects", expanded=False):
|
1490 |
+
|
1491 |
+
if len(st.session_state["clone_project_dict"]) == 0:
|
1492 |
+
st.warning("You dont have access to any saved projects")
|
1493 |
+
st.stop()
|
1494 |
+
|
1495 |
+
cols = st.columns(2)
|
1496 |
+
|
1497 |
+
with cols[0]:
|
1498 |
+
owners = list(st.session_state["clone_project_dict"].keys())
|
1499 |
+
owner_name = st.selectbox("Select Owner", owners)
|
1500 |
+
|
1501 |
+
with cols[1]:
|
1502 |
+
|
1503 |
+
project_names = [
|
1504 |
+
project["project_name"]
|
1505 |
+
for project in st.session_state["clone_project_dict"][owner_name]
|
1506 |
+
]
|
1507 |
+
project_name_owner = st.selectbox(
|
1508 |
+
"Select a saved Project available to you",
|
1509 |
+
project_names,
|
1510 |
+
)
|
1511 |
+
|
1512 |
+
defualt_project_prefix = f"{project_name_owner.split('_')[0]}_{st.session_state['unique_ids'][st.session_state['emp_id']][0]}_".replace(
|
1513 |
+
" ", "_"
|
1514 |
+
).lower()
|
1515 |
+
user_projects = list(
|
1516 |
+
set(st.session_state["project_summary_df"]["Project Name"])
|
1517 |
+
)
|
1518 |
+
|
1519 |
+
cloned_project_name = st.text_input(
|
1520 |
+
"Enter Project Name",
|
1521 |
+
value=defualt_project_prefix,
|
1522 |
+
)
|
1523 |
+
warning_box = st.empty()
|
1524 |
+
|
1525 |
+
if st.button(
|
1526 |
+
"Load Project", use_container_width=True, key="load_project_button_key"
|
1527 |
+
):
|
1528 |
+
|
1529 |
+
if not cloned_project_name.lower().startswith(defualt_project_prefix):
|
1530 |
+
with warning_box:
|
1531 |
+
st.warning("Project Name should follow naming conventions")
|
1532 |
+
st.stop()
|
1533 |
+
|
1534 |
+
if cloned_project_name == defualt_project_prefix:
|
1535 |
+
with warning_box:
|
1536 |
+
st.warning("Cannot Name only with Prefix")
|
1537 |
+
st.stop()
|
1538 |
+
|
1539 |
+
if cloned_project_name in user_projects:
|
1540 |
+
with warning_box:
|
1541 |
+
st.warning("Project already exists please enter new name")
|
1542 |
+
st.stop()
|
1543 |
+
|
1544 |
+
with st.spinner("Cloning Project"):
|
1545 |
+
old_prj_id = get_project_id_from_dict(
|
1546 |
+
st.session_state["clone_project_dict"],
|
1547 |
+
owner_name,
|
1548 |
+
project_name_owner,
|
1549 |
+
)
|
1550 |
+
old_metadata = fetch_project_metadata(old_prj_id)
|
1551 |
+
|
1552 |
+
new_prj_id = create_new_project(
|
1553 |
+
st.session_state["emp_id"],
|
1554 |
+
cloned_project_name,
|
1555 |
+
"",
|
1556 |
+
st.session_state["emp_id"],
|
1557 |
+
)
|
1558 |
+
|
1559 |
+
insert_project_metadata(
|
1560 |
+
new_prj_id, old_metadata, st.session_state["emp_id"]
|
1561 |
+
)
|
1562 |
+
update_summary_df()
|
1563 |
+
st.success("Project Cloned")
|
1564 |
+
st.rerun()
|
1565 |
+
|
1566 |
+
if projct_radio == "Delete Projects":
|
1567 |
+
if len(st.session_state["project_summary_df"]) != 0:
|
1568 |
+
|
1569 |
+
with st.expander("Delete Projects", expanded=True):
|
1570 |
+
|
1571 |
+
delete_projects = st.multiselect(
|
1572 |
+
"Select all the project numbers you want to delete",
|
1573 |
+
st.session_state["project_summary_df"]["Project Number"],
|
1574 |
+
)
|
1575 |
+
st.warning(
|
1576 |
+
"Projects will be permanently deleted. Other users will not be able to clone them if they have not already done so."
|
1577 |
+
)
|
1578 |
+
if st.button("Delete Projects", use_container_width=True):
|
1579 |
+
if len(delete_projects) > 0:
|
1580 |
+
with st.spinner("Deleting Projects"):
|
1581 |
+
delete_projects_by_ids(delete_projects)
|
1582 |
+
update_summary_df()
|
1583 |
+
st.success("Projects Deleted")
|
1584 |
+
st.rerun()
|
1585 |
+
|
1586 |
+
else:
|
1587 |
+
st.warning("Please select atleast one project number to delete")
|
1588 |
+
|
1589 |
+
if projct_radio == "Download Project PPT":
|
1590 |
+
|
1591 |
+
try:
|
1592 |
+
ppt = create_ppt(
|
1593 |
+
st.session_state["project_name"],
|
1594 |
+
st.session_state["username"],
|
1595 |
+
"panel", # new
|
1596 |
+
)
|
1597 |
+
|
1598 |
+
if ppt is not False:
|
1599 |
+
st.download_button(
|
1600 |
+
"Download",
|
1601 |
+
data=ppt.getvalue(),
|
1602 |
+
file_name=st.session_state["project_name"]
|
1603 |
+
+ " Project Summary.pptx",
|
1604 |
+
use_container_width=True,
|
1605 |
+
)
|
1606 |
+
else:
|
1607 |
+
st.warning("Please make some progress before downloading PPT.")
|
1608 |
+
|
1609 |
+
except Exception as e:
|
1610 |
+
st.warning("PPT Download Faild ")
|
1611 |
+
# new
|
1612 |
+
log_message(
|
1613 |
+
log_type="error", message=f"Error in PPT build: {e}", page_name="Home"
|
1614 |
+
)
|
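For reference, Home.py calls `ensure_project_dct_structure(st.session_state["project_dct"], default_dct)` right after a project is unpickled, but that helper's implementation is not part of this excerpt. A minimal sketch of the idea (back-filling keys that a project saved under an older schema might be missing) could look like the following; the actual helper in the repo may differ.

```python
# Hypothetical sketch only -- the real ensure_project_dct_structure may differ.
def ensure_project_dct_structure(project_dct: dict, default_dct: dict) -> None:
    """Recursively add any keys from default_dct that are missing in project_dct."""
    for key, default_value in default_dct.items():
        if key not in project_dct:
            project_dct[key] = default_value
        elif isinstance(default_value, dict) and isinstance(project_dct[key], dict):
            ensure_project_dct_structure(project_dct[key], default_value)
```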
README.md
CHANGED
@@ -1,13 +1,13 @@
|
|
1 |
-
---
|
2 |
-
title: MediaMixOptimization
|
3 |
-
emoji:
|
4 |
-
colorFrom: red
|
5 |
-
colorTo:
|
6 |
-
sdk: streamlit
|
7 |
-
sdk_version: 1.
|
8 |
-
app_file:
|
9 |
-
pinned: false
|
10 |
-
short_description: This tool helps in optimizing media spends
|
11 |
-
---
|
12 |
-
|
13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
+
---
|
2 |
+
title: MediaMixOptimization
|
3 |
+
emoji: 📊
|
4 |
+
colorFrom: red
|
5 |
+
colorTo: purple
|
6 |
+
sdk: streamlit
|
7 |
+
sdk_version: 1.28.0
|
8 |
+
app_file: Home.py
|
9 |
+
pinned: false
|
10 |
+
short_description: This tool helps in optimizing media spends
|
11 |
+
---
|
12 |
+
|
13 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
config.json
ADDED
File without changes
|
config.yaml
ADDED
File without changes
|
constants.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
###########################################################################################################
|
2 |
+
# Default
|
3 |
+
###########################################################################################################
|
4 |
+
|
5 |
+
import pandas as pd
|
6 |
+
from collections import OrderedDict
|
7 |
+
|
8 |
+
# default_dct
|
9 |
+
default_dct = {
|
10 |
+
"data_import": {
|
11 |
+
"gold_layer_df": pd.DataFrame(),
|
12 |
+
"granularity_selection": "daily",
|
13 |
+
"dashboard_df": None,
|
14 |
+
"tool_df": None,
|
15 |
+
"unique_panels": [],
|
16 |
+
"imputation_df": None,
|
17 |
+
"imputed_tool_df": None,
|
18 |
+
"group_dict": {},
|
19 |
+
"category_dict": {},
|
20 |
+
},
|
21 |
+
"data_validation": {
|
22 |
+
"target_column": 0,
|
23 |
+
"selected_panels": None,
|
24 |
+
"selected_feature": 0,
|
25 |
+
"validated_variables": [],
|
26 |
+
"Non_media_variables": 0,
|
27 |
+
"correlation": [],
|
28 |
+
},
|
29 |
+
"transformations": {
|
30 |
+
"final_df": None,
|
31 |
+
"summary_string": None,
|
32 |
+
"Media": {},
|
33 |
+
"Exogenous": {},
|
34 |
+
"Internal": {},
|
35 |
+
"Specific": {},
|
36 |
+
"correlation_plot_selection": [],
|
37 |
+
},
|
38 |
+
"model_build": {
|
39 |
+
"sel_target_col": None,
|
40 |
+
"all_iters_check": False,
|
41 |
+
"iterations": 0,
|
42 |
+
"build_button": False,
|
43 |
+
"show_results_check": False,
|
44 |
+
"session_state_saved": {},
|
45 |
+
},
|
46 |
+
"model_tuning": {
|
47 |
+
"sel_target_col": None,
|
48 |
+
"sel_model": {},
|
49 |
+
"flag_expander": False,
|
50 |
+
"start_date_default": None,
|
51 |
+
"end_date_default": None,
|
52 |
+
"repeat_default": "No",
|
53 |
+
"flags": {},
|
54 |
+
"select_all_flags_check": {},
|
55 |
+
"selected_flags": {},
|
56 |
+
"trend_check": False,
|
57 |
+
"week_num_check": False,
|
58 |
+
"sine_cosine_check": False,
|
59 |
+
"session_state_saved": {},
|
60 |
+
},
|
61 |
+
"saved_model_results": {
|
62 |
+
"selected_options": None,
|
63 |
+
"model_grid_sel": [1],
|
64 |
+
},
|
65 |
+
"current_media_performance": {
|
66 |
+
"model_outputs": {},
|
67 |
+
},
|
68 |
+
"media_performance": {"start_date": None, "end_date": None},
|
69 |
+
"response_curves": {
|
70 |
+
"original_metadata_file": None,
|
71 |
+
"modified_metadata_file": None,
|
72 |
+
},
|
73 |
+
"scenario_planner": {
|
74 |
+
"original_metadata_file": None,
|
75 |
+
"modified_metadata_file": None,
|
76 |
+
},
|
77 |
+
"saved_scenarios": {"saved_scenarios_dict": OrderedDict()},
|
78 |
+
"optimized_result_analysis": {
|
79 |
+
"selected_scenario_selectbox_visualize": 0,
|
80 |
+
"metric_selectbox_visualize": 0,
|
81 |
+
},
|
82 |
+
}
|
83 |
+
|
84 |
+
###########################################################################################################
|
85 |
+
# Data Import
|
86 |
+
###########################################################################################################
|
87 |
+
|
88 |
+
# Constants Data Import
|
89 |
+
upload_rows_limit = 1000000 # Maximum number of rows allowed for upload
|
90 |
+
upload_column_limit = 1000 # Maximum number of columns allowed for upload
|
91 |
+
word_length_limit_lower = 2 # Minimum allowed length for words
|
92 |
+
word_length_limit_upper = 100 # Maximum allowed length for words
|
93 |
+
minimum_percent_overlap = (
|
94 |
+
1 # Minimum required percentage of overlap with the reference data
|
95 |
+
)
|
96 |
+
minimum_row_req = 30 # Minimum number of rows required
|
97 |
+
percent_drop_col_threshold = 50 # Percentage threshold above which columns are automatically categorized for drop imputation
|
98 |
+
|
99 |
+
###########################################################################################################
|
100 |
+
# Transformations
|
101 |
+
###########################################################################################################
|
102 |
+
|
103 |
+
# Constants Transformations
|
104 |
+
predefined_defaults = {
|
105 |
+
"Lag": (1, 2),
|
106 |
+
"Lead": (1, 2),
|
107 |
+
"Moving Average": (1, 2),
|
108 |
+
"Saturation": (10, 20),
|
109 |
+
"Power": (2, 4),
|
110 |
+
"Adstock": (0.5, 0.7),
|
111 |
+
} # Pre-defined default values of every transformation
|
112 |
+
|
113 |
+
# Transformations min, max and step
|
114 |
+
lead_min_value = 1
|
115 |
+
lead_max_value = 10
|
116 |
+
lead_step = 1
|
117 |
+
lag_min_value = 1
|
118 |
+
lag_max_value = 10
|
119 |
+
lag_step = 1
|
120 |
+
moving_average_min_value = 1
|
121 |
+
moving_average_max_value = 10
|
122 |
+
moving_average_step = 1
|
123 |
+
saturation_min_value = 0
|
124 |
+
saturation_max_value = 100
|
125 |
+
saturation_step = 1
|
126 |
+
power_min_value = 1
|
127 |
+
power_max_value = 5
|
128 |
+
power_step = 1
|
129 |
+
adstock_min_value = 0.0
|
130 |
+
adstock_max_value = 1.0
|
131 |
+
adstock_step = 0.05
|
132 |
+
display_max_col = 500 # Maximum columns to display
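The min/max/step constants above are intended to bound the transformation widgets on the transformations page. As an illustration only (the exact widget wiring in `pages/3_AI_Model_Transformations.py` is not shown here), an adstock range slider built from these constants might look like:

```python
import streamlit as st
from constants import (adstock_min_value, adstock_max_value, adstock_step,
                       predefined_defaults)

# Passing a (low, high) tuple as the default value makes st.slider a range slider.
adstock_range = st.slider(
    "Adstock decay range",
    min_value=adstock_min_value,
    max_value=adstock_max_value,
    value=predefined_defaults["Adstock"],  # (0.5, 0.7)
    step=adstock_step,
)
```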
|
133 |
+
|
134 |
+
###########################################################################################################
|
135 |
+
# Model Build
|
136 |
+
###########################################################################################################
|
137 |
+
|
138 |
+
MAX_COMBINATIONS = 50000 # Max number of model combinations possible
|
139 |
+
MIN_MODEL_NAME_LENGTH = 0 # model can only be saved if len(model_name) is greater than this value
|
140 |
+
MIN_P_VALUE_THRESHOLD = 0.06 # coefficients with p values less than this value are considered valid
|
141 |
+
MODEL_POS_COEFF_RATIO_THRESHOLD = 0 # ratio of positive coefficients/total coefficients for model validity
|
142 |
+
MODEL_P_VALUE_RATIO_THRESHOLD = 0 # ratio of coefficients with p value/total coefficients for model validity
|
143 |
+
MAX_TOP_FEATURES = 5 # max number of top features selected per variable
|
144 |
+
MAX_NUM_FILTERS = 10 # maximum number of filters allowed
|
145 |
+
DEFAULT_FILTER_VALUE = 0.0 # default value of a new filter
|
146 |
+
VIF_LOW_THRESHOLD = 3 # Threshold for VIF to be colored green
|
147 |
+
VIF_HIGH_THRESHOLD = 10 # Threshold for VIF to be colored red
|
148 |
+
DEFAULT_TRAIN_RATIO = 3/4 # default train set ratio (75%)
|
149 |
+
|
150 |
+
###########################################################################################################
|
151 |
+
# Model Tuning
|
152 |
+
###########################################################################################################
|
153 |
+
|
154 |
+
import numpy as np
|
155 |
+
|
156 |
+
NUM_FLAG_COLS_TO_DISPLAY = 4 # Number of columns to be created on UI to display flags
|
157 |
+
HALF_YEAR_THRESHOLD = 6 # Threshold of months to create quarter frequency sine-cosine waves
|
158 |
+
FULL_YEAR_THRESHOLD = 12  # Threshold of months to create annual sine-cosine waves
|
159 |
+
TREND_MIN = 5 # Starting value of trend line
|
160 |
+
ANNUAL_FREQUENCY = 2 * np.pi / 365 # annual frequency
|
161 |
+
QTR_FREQUENCY_FACTOR = 4 # multiplication factor to get quarterly frequency
|
162 |
+
HALF_YEARLY_FREQUENCY_FACTOR = 2 # multiplication factor to get semi-annual frequency
|
163 |
+
|
164 |
+
###########################################################################################################
|
165 |
+
# Scenario Planner
|
166 |
+
###########################################################################################################
|
167 |
+
|
168 |
+
# Constants Scenario Planner
|
169 |
+
xtol_tolerance_per = 1  # Percentage of tolerance
|
170 |
+
mroi_threshold = 0.05 # mROI threshold
|
171 |
+
word_length_limit_lower = 2 # Minimum allowed length for words
|
172 |
+
word_length_limit_upper = 100 # Maximum allowed length for words
|
173 |
+
|
174 |
+
###########################################################################################################
|
175 |
+
# PPT utils
|
176 |
+
###########################################################################################################
|
177 |
+
|
178 |
+
TITLE_FONT_SIZE = 20
|
179 |
+
AXIS_LABEL_FONT_SIZE = 8
|
180 |
+
CHART_TITLE_FONT_SIZE = 14
|
181 |
+
AXIS_TITLE_FONT_SIZE = 12
|
182 |
+
DATA_LABEL_FONT_SIZE = 8
|
183 |
+
LEGEND_FONT_SIZE = 10
|
184 |
+
PIE_LEGEND_FONT_SIZE = 7
|
185 |
+
|
186 |
+
###########################################################################################################
|
187 |
+
# Page Name
|
188 |
+
###########################################################################################################
|
data_analysis.py
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import plotly.express as px
|
3 |
+
import numpy as np
|
4 |
+
import plotly.graph_objects as go
|
5 |
+
from sklearn.metrics import r2_score
|
6 |
+
from collections import OrderedDict
|
7 |
+
import plotly.express as px
|
8 |
+
import plotly.graph_objects as go
|
9 |
+
import pandas as pd
|
10 |
+
import seaborn as sns
|
11 |
+
import matplotlib.pyplot as plt
|
12 |
+
import streamlit as st
|
13 |
+
import re
|
14 |
+
from matplotlib.colors import ListedColormap
|
15 |
+
# from st_aggrid import AgGrid, GridOptionsBuilder
|
16 |
+
# from src.agstyler import PINLEFT, PRECISION_TWO, draw_grid
|
17 |
+
|
18 |
+
|
19 |
+
def format_numbers(x):
|
20 |
+
if abs(x) >= 1e6:
|
21 |
+
# Format as millions with one decimal place and commas
|
22 |
+
return f'{x/1e6:,.1f}M'
|
23 |
+
elif abs(x) >= 1e3:
|
24 |
+
# Format as thousands with one decimal place and commas
|
25 |
+
return f'{x/1e3:,.1f}K'
|
26 |
+
else:
|
27 |
+
# Format with one decimal place and commas for values less than 1000
|
28 |
+
return f'{x:,.1f}'
|
29 |
+
|
30 |
+
|
31 |
+
|
32 |
+
def line_plot(data, x_col, y1_cols, y2_cols, title):
|
33 |
+
"""
|
34 |
+
Create a line plot with two sets of y-axis data.
|
35 |
+
|
36 |
+
Parameters:
|
37 |
+
data (DataFrame): The data containing the columns to be plotted.
|
38 |
+
x_col (str): The column name for the x-axis.
|
39 |
+
y1_cols (list): List of column names for the primary y-axis.
|
40 |
+
y2_cols (list): List of column names for the secondary y-axis.
|
41 |
+
title (str): The title of the plot.
|
42 |
+
|
43 |
+
Returns:
|
44 |
+
fig (Figure): The Plotly figure object with the line plot.
|
45 |
+
"""
|
46 |
+
fig = go.Figure()
|
47 |
+
|
48 |
+
# Add traces for the primary y-axis
|
49 |
+
for y1_col in y1_cols:
|
50 |
+
fig.add_trace(go.Scatter(x=data[x_col], y=data[y1_col], mode='lines', name=y1_col, line=dict(color='#11B6BD')))
|
51 |
+
|
52 |
+
# Add traces for the secondary y-axis
|
53 |
+
for y2_col in y2_cols:
|
54 |
+
fig.add_trace(go.Scatter(x=data[x_col], y=data[y2_col], mode='lines', name=y2_col, yaxis='y2', line=dict(color='#739FAE')))
|
55 |
+
|
56 |
+
# Configure the layout for the secondary y-axis if needed
|
57 |
+
if len(y2_cols) != 0:
|
58 |
+
fig.update_layout(yaxis=dict(), yaxis2=dict(overlaying='y', side='right'))
|
59 |
+
else:
|
60 |
+
fig.update_layout(yaxis=dict(), yaxis2=dict(overlaying='y', side='right'))
|
61 |
+
|
62 |
+
# Add title if provided
|
63 |
+
if title:
|
64 |
+
fig.update_layout(title=title)
|
65 |
+
|
66 |
+
# Customize axes and legend
|
67 |
+
fig.update_xaxes(showgrid=False)
|
68 |
+
fig.update_yaxes(showgrid=False)
|
69 |
+
fig.update_layout(legend=dict(
|
70 |
+
orientation="h",
|
71 |
+
yanchor="top",
|
72 |
+
y=1.1,
|
73 |
+
xanchor="center",
|
74 |
+
x=0.5
|
75 |
+
))
|
76 |
+
|
77 |
+
return fig
|
78 |
+
|
79 |
+
|
80 |
+
|
81 |
+
def line_plot_target(df, target, title):
|
82 |
+
"""
|
83 |
+
Create a line plot with a trendline for a target column.
|
84 |
+
|
85 |
+
Parameters:
|
86 |
+
df (DataFrame): The data containing the columns to be plotted.
|
87 |
+
target (str): The column name for the y-axis.
|
88 |
+
title (str): The title of the plot.
|
89 |
+
|
90 |
+
Returns:
|
91 |
+
fig (Figure): The Plotly figure object with the line plot and trendline.
|
92 |
+
"""
|
93 |
+
# Calculate the trendline coefficients
|
94 |
+
coefficients = np.polyfit(df['date'].view('int64'), df[target], 1)
|
95 |
+
trendline = np.poly1d(coefficients)
|
96 |
+
fig = go.Figure()
|
97 |
+
|
98 |
+
# Add the target line plot
|
99 |
+
fig.add_trace(go.Scatter(x=df['date'], y=df[target], mode='lines', name=target, line=dict(color='#11B6BD')))
|
100 |
+
|
101 |
+
# Calculate and add the trendline plot
|
102 |
+
trendline_x = df['date']
|
103 |
+
trendline_y = trendline(df['date'].view('int64'))
|
104 |
+
fig.add_trace(go.Scatter(x=trendline_x, y=trendline_y, mode='lines', name='Trendline', line=dict(color='#739FAE')))
|
105 |
+
|
106 |
+
# Update layout with title and x-axis type
|
107 |
+
fig.update_layout(
|
108 |
+
title=title,
|
109 |
+
xaxis=dict(type='date')
|
110 |
+
)
|
111 |
+
|
112 |
+
# Add vertical lines at the start of each year
|
113 |
+
for year in df['date'].dt.year.unique()[1:]:
|
114 |
+
january_1 = pd.Timestamp(year=year, month=1, day=1)
|
115 |
+
fig.add_shape(
|
116 |
+
go.layout.Shape(
|
117 |
+
type="line",
|
118 |
+
x0=january_1,
|
119 |
+
x1=january_1,
|
120 |
+
y0=0,
|
121 |
+
y1=1,
|
122 |
+
xref="x",
|
123 |
+
yref="paper",
|
124 |
+
line=dict(color="grey", width=1.5, dash="dash"),
|
125 |
+
)
|
126 |
+
)
|
127 |
+
|
128 |
+
# Customize the legend
|
129 |
+
fig.update_layout(legend=dict(
|
130 |
+
orientation="h",
|
131 |
+
yanchor="top",
|
132 |
+
y=1.1,
|
133 |
+
xanchor="center",
|
134 |
+
x=0.5
|
135 |
+
))
|
136 |
+
|
137 |
+
return fig
|
138 |
+
|
139 |
+
|
140 |
+
def correlation_plot(df, selected_features, target):
|
141 |
+
"""
|
142 |
+
Create a correlation heatmap plot for selected features and target column.
|
143 |
+
|
144 |
+
Parameters:
|
145 |
+
df (DataFrame): The data containing the columns to be plotted.
|
146 |
+
selected_features (list): List of column names to be included in the correlation plot.
|
147 |
+
target (str): The target column name to be included in the correlation plot.
|
148 |
+
|
149 |
+
Returns:
|
150 |
+
fig (Figure): The Matplotlib figure object with the correlation heatmap plot.
|
151 |
+
"""
|
152 |
+
# Define custom colormap
|
153 |
+
custom_cmap = ListedColormap(['#08083B', "#11B6BD"])
|
154 |
+
|
155 |
+
# Select the relevant columns for correlation calculation
|
156 |
+
corr_df = df[selected_features]
|
157 |
+
corr_df = pd.concat([corr_df, df[target]], axis=1)
|
158 |
+
|
159 |
+
# Create a matplotlib figure and axis
|
160 |
+
fig, ax = plt.subplots(figsize=(16, 12))
|
161 |
+
|
162 |
+
# Generate the heatmap with correlation coefficients
|
163 |
+
sns.heatmap(corr_df.corr(), annot=True, cmap='Blues', fmt=".2f", linewidths=0.5, mask=np.triu(corr_df.corr()))
|
164 |
+
|
165 |
+
# Customize the plot
|
166 |
+
plt.xticks(rotation=45)
|
167 |
+
plt.yticks(rotation=0)
|
168 |
+
|
169 |
+
return fig
|
170 |
+
|
171 |
+
|
172 |
+
def summary(data, selected_feature, spends, Target=None):
|
173 |
+
"""
|
174 |
+
Create a summary table of selected features and optionally a target column.
|
175 |
+
|
176 |
+
Parameters:
|
177 |
+
data (DataFrame): The data containing the columns to be summarized.
|
178 |
+
selected_feature (list): List of column names to be included in the summary.
|
179 |
+
spends (str): The column name for the spends data.
|
180 |
+
Target (str, optional): The target column name for additional summary calculations. Default is None.
|
181 |
+
|
182 |
+
Returns:
|
183 |
+
sum_df (DataFrame): The summary DataFrame with formatted values.
|
184 |
+
"""
|
185 |
+
if Target:
|
186 |
+
# Summarize data for the target column
|
187 |
+
sum_df = data[selected_feature]
|
188 |
+
sum_df['Year'] = data['date'].dt.year
|
189 |
+
sum_df = sum_df.groupby('Year')[selected_feature].sum().reset_index()
|
190 |
+
|
191 |
+
# Calculate total sum and append to the DataFrame
|
192 |
+
total_sum = sum_df.sum(numeric_only=True)
|
193 |
+
total_sum['Year'] = 'Total'
|
194 |
+
sum_df = pd.concat([sum_df, total_sum.to_frame().T], axis=0, ignore_index=True).copy()
|
195 |
+
|
196 |
+
# Set 'Year' as index and format numbers
|
197 |
+
sum_df.set_index(['Year'], inplace=True)
|
198 |
+
sum_df = sum_df.applymap(format_numbers)
|
199 |
+
|
200 |
+
# Format spends columns as currency
|
201 |
+
spends_col = [col for col in sum_df.columns if any(keyword in col for keyword in ['spends', 'cost'])]
|
202 |
+
for col in spends_col:
|
203 |
+
sum_df[col] = sum_df[col].map(lambda x: f'${x}')
|
204 |
+
|
205 |
+
return sum_df
|
206 |
+
else:
|
207 |
+
# Include spends in the selected features
|
208 |
+
selected_feature.append(spends)
|
209 |
+
|
210 |
+
# Ensure unique features
|
211 |
+
selected_feature = list(set(selected_feature))
|
212 |
+
|
213 |
+
if len(selected_feature) > 1:
|
214 |
+
imp_clicks = selected_feature[1]
|
215 |
+
spends_col = selected_feature[0]
|
216 |
+
|
217 |
+
# Summarize data for the selected features
|
218 |
+
sum_df = data[selected_feature]
|
219 |
+
sum_df['Year'] = data['date'].dt.year
|
220 |
+
sum_df = sum_df.groupby('Year')[selected_feature].agg('sum')
|
221 |
+
|
222 |
+
# Calculate CPM/CPC
|
223 |
+
sum_df['CPM/CPC'] = (sum_df[spends_col] / sum_df[imp_clicks]) * 1000
|
224 |
+
|
225 |
+
# Calculate grand total and append to the DataFrame
|
226 |
+
sum_df.loc['Grand Total'] = sum_df.sum()
|
227 |
+
|
228 |
+
# Format numbers and replace NaNs
|
229 |
+
sum_df = sum_df.applymap(format_numbers)
|
230 |
+
sum_df.fillna('-', inplace=True)
|
231 |
+
sum_df = sum_df.replace({"0.0": '-', 'nan': '-'})
|
232 |
+
|
233 |
+
# Format spends columns as currency
|
234 |
+
sum_df[spends_col] = sum_df[spends_col].map(lambda x: f'${x}')
|
235 |
+
|
236 |
+
return sum_df
|
237 |
+
else:
|
238 |
+
# Summarize data for a single selected feature
|
239 |
+
sum_df = data[selected_feature]
|
240 |
+
sum_df['Year'] = data['date'].dt.year
|
241 |
+
sum_df = sum_df.groupby('Year')[selected_feature].agg('sum')
|
242 |
+
|
243 |
+
# Calculate grand total and append to the DataFrame
|
244 |
+
sum_df.loc['Grand Total'] = sum_df.sum()
|
245 |
+
|
246 |
+
# Format numbers and replace NaNs
|
247 |
+
sum_df = sum_df.applymap(format_numbers)
|
248 |
+
sum_df.fillna('-', inplace=True)
|
249 |
+
sum_df = sum_df.replace({"0.0": '-', 'nan': '-'})
|
250 |
+
|
251 |
+
# Format spends columns as currency
|
252 |
+
spends_col = [col for col in sum_df.columns if any(keyword in col for keyword in ['spends', 'cost'])]
|
253 |
+
for col in spends_col:
|
254 |
+
sum_df[col] = sum_df[col].map(lambda x: f'${x}')
|
255 |
+
|
256 |
+
return sum_df
|
257 |
+
|
258 |
+
|
259 |
+
|
260 |
+
def sanitize_key(key, prefix=""):
|
261 |
+
# Use regular expressions to remove non-alphanumeric characters and spaces
|
262 |
+
key = re.sub(r'[^a-zA-Z0-9]', '', key)
|
263 |
+
return f"{prefix}{key}"
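A quick usage sketch of the two small helpers in this module (illustrative inputs, not taken from the app):

```python
format_numbers(1_234_567)   # -> '1.2M'
format_numbers(45_600)      # -> '45.6K'
format_numbers(312.4)       # -> '312.4'
sanitize_key("Paid Search - Brand", prefix="ws_")  # -> 'ws_PaidSearchBrand'
```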
|
264 |
+
|
265 |
+
|
266 |
+
|
267 |
+
|
data_prep.py
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
import plotly.express as px
|
4 |
+
import plotly.graph_objects as go
|
5 |
+
import statsmodels.api as sm
|
6 |
+
from sklearn.metrics import mean_absolute_error, r2_score,mean_absolute_percentage_error
|
7 |
+
from sklearn.preprocessing import MinMaxScaler
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
from statsmodels.stats.outliers_influence import variance_inflation_factor
|
10 |
+
from plotly.subplots import make_subplots
|
11 |
+
|
12 |
+
st.set_option('deprecation.showPyplotGlobalUse', False)
|
13 |
+
from datetime import datetime
|
14 |
+
import seaborn as sns
|
15 |
+
|
16 |
+
|
17 |
+
def plot_actual_vs_predicted(date, y, predicted_values, model, target_column=None, flag=None, repeat_all_years=False, is_panel=False):
|
18 |
+
"""
|
19 |
+
Plots actual vs predicted values with optional flags and aggregation for panel data.
|
20 |
+
|
21 |
+
Parameters:
|
22 |
+
date (pd.Series): Series of dates for x-axis.
|
23 |
+
y (pd.Series): Actual values.
|
24 |
+
predicted_values (pd.Series): Predicted values from the model.
|
25 |
+
model (object): Trained model object.
|
26 |
+
target_column (str, optional): Name of the target column.
|
27 |
+
flag (tuple, optional): Start and end dates for flagging periods.
|
28 |
+
repeat_all_years (bool, optional): Whether to repeat flags for all years.
|
29 |
+
is_panel (bool, optional): Whether the data is panel data requiring aggregation.
|
30 |
+
|
31 |
+
Returns:
|
32 |
+
metrics_table (pd.DataFrame): DataFrame containing MAPE, R-squared, and Adjusted R-squared.
|
33 |
+
line_values (list): List of flag values for plotting.
|
34 |
+
fig (go.Figure): Plotly figure object.
|
35 |
+
"""
|
36 |
+
if flag is not None:
|
37 |
+
fig = make_subplots(specs=[[{"secondary_y": True}]])
|
38 |
+
else:
|
39 |
+
fig = go.Figure()
|
40 |
+
|
41 |
+
if is_panel:
|
42 |
+
df = pd.DataFrame()
|
43 |
+
df['date'] = date
|
44 |
+
df['Actual'] = y
|
45 |
+
df['Predicted'] = predicted_values
|
46 |
+
df_agg = df.groupby('date').agg({'Actual': 'sum', 'Predicted': 'sum'}).reset_index()
|
47 |
+
df_agg.columns = ['date', 'Actual', 'Predicted']
|
48 |
+
assert len(df_agg) == pd.Series(date).nunique()
|
49 |
+
|
50 |
+
fig.add_trace(go.Scatter(x=df_agg['date'], y=df_agg['Actual'], mode='lines', name='Actual', line=dict(color='#08083B')))
|
51 |
+
fig.add_trace(go.Scatter(x=df_agg['date'], y=df_agg['Predicted'], mode='lines', name='Predicted', line=dict(color='#11B6BD')))
|
52 |
+
else:
|
53 |
+
fig.add_trace(go.Scatter(x=date, y=y, mode='lines', name='Actual', line=dict(color='#08083B')))
|
54 |
+
fig.add_trace(go.Scatter(x=date, y=predicted_values, mode='lines', name='Predicted', line=dict(color='#11B6BD')))
|
55 |
+
|
56 |
+
line_values = []
|
57 |
+
if flag:
|
58 |
+
min_date, max_date = flag[0], flag[1]
|
59 |
+
min_week = datetime.strptime(str(min_date), "%Y-%m-%d").strftime("%U")
|
60 |
+
max_week = datetime.strptime(str(max_date), "%Y-%m-%d").strftime("%U")
|
61 |
+
|
62 |
+
if repeat_all_years:
|
63 |
+
line_values = list(pd.Series(date).map(lambda x: 1 if (pd.Timestamp(x).week >= int(min_week)) & (pd.Timestamp(x).week <= int(max_week)) else 0))
|
64 |
+
assert len(line_values) == len(date)
|
65 |
+
fig.add_trace(go.Scatter(x=date, y=line_values, mode='lines', name='Flag', line=dict(color='#FF5733')), secondary_y=True)
|
66 |
+
else:
|
67 |
+
line_values = list(pd.Series(date).map(lambda x: 1 if (pd.Timestamp(x) >= pd.Timestamp(min_date)) and (pd.Timestamp(x) <= pd.Timestamp(max_date)) else 0))
|
68 |
+
fig.add_trace(go.Scatter(x=date, y=line_values, mode='lines', name='Flag', line=dict(color='#FF5733')), secondary_y=True)
|
69 |
+
|
70 |
+
mape = mean_absolute_percentage_error(y, predicted_values)
|
71 |
+
r2 = r2_score(y, predicted_values)
|
72 |
+
adjr2 = 1 - (1 - r2) * (len(y) - 1) / (len(y) - len(model.params) - 1)
|
73 |
+
|
74 |
+
metrics_table = pd.DataFrame({
|
75 |
+
'Metric': ['MAPE', 'R-squared', 'AdjR-squared'],
|
76 |
+
'Value': [mape, r2, adjr2]
|
77 |
+
})
|
78 |
+
|
79 |
+
fig.update_layout(
|
80 |
+
xaxis=dict(title='Date'),
|
81 |
+
yaxis=dict(title=target_column),
|
82 |
+
xaxis_tickangle=-30
|
83 |
+
)
|
84 |
+
fig.add_annotation(
|
85 |
+
text=f"MAPE: {mape * 100:0.1f}%, Adj. R-squared: {adjr2 * 100:.1f}%",
|
86 |
+
xref="paper",
|
87 |
+
yref="paper",
|
88 |
+
x=0.95,
|
89 |
+
y=1.2,
|
90 |
+
showarrow=False,
|
91 |
+
)
|
92 |
+
|
93 |
+
return metrics_table, line_values, fig
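A minimal usage sketch for `plot_actual_vs_predicted`, using synthetic data and a plain statsmodels OLS fit; all names below are illustrative and not part of the upload:

```python
import numpy as np
import pandas as pd
import statsmodels.api as sm

dates = pd.date_range("2023-01-01", periods=52, freq="W")
spend = np.random.rand(52) * 100
sales = 50 + 2.5 * spend + np.random.randn(52) * 5

X = sm.add_constant(pd.DataFrame({"spend": spend}))
model = sm.OLS(sales, X).fit()

metrics_table, line_values, fig = plot_actual_vs_predicted(
    pd.Series(dates), pd.Series(sales), model.predict(X), model,
    target_column="sales",
)
```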
|
94 |
+
|
95 |
+
|
96 |
+
def plot_residual_predicted(actual, predicted, df):
|
97 |
+
"""
|
98 |
+
Plots standardized residuals against predicted values.
|
99 |
+
|
100 |
+
Parameters:
|
101 |
+
actual (pd.Series): Actual values.
|
102 |
+
predicted (pd.Series): Predicted values.
|
103 |
+
df (pd.DataFrame): DataFrame containing the data.
|
104 |
+
|
105 |
+
Returns:
|
106 |
+
fig (go.Figure): Plotly figure object.
|
107 |
+
"""
|
108 |
+
df_ = df.copy()
|
109 |
+
df_['Residuals'] = actual - pd.Series(predicted)
|
110 |
+
df_['StdResidual'] = (df_['Residuals'] - df_['Residuals'].mean()) / df_['Residuals'].std()
|
111 |
+
|
112 |
+
fig = px.scatter(df_, x=predicted, y='StdResidual', opacity=0.5, color_discrete_sequence=["#11B6BD"])
|
113 |
+
|
114 |
+
fig.add_hline(y=0, line_dash="dash", line_color="darkorange")
|
115 |
+
fig.add_hline(y=2, line_color="red")
|
116 |
+
fig.add_hline(y=-2, line_color="red")
|
117 |
+
|
118 |
+
fig.update_xaxes(title='Predicted')
|
119 |
+
fig.update_yaxes(title='Standardized Residuals (Actual - Predicted)')
|
120 |
+
|
121 |
+
fig.update_layout(title='2.3.1 Residuals over Predicted Values', autosize=False, width=600, height=400)
|
122 |
+
|
123 |
+
return fig
|
124 |
+
|
125 |
+
|
126 |
+
def residual_distribution(actual, predicted):
|
127 |
+
"""
|
128 |
+
Plots the distribution of residuals.
|
129 |
+
|
130 |
+
Parameters:
|
131 |
+
actual (pd.Series): Actual values.
|
132 |
+
predicted (pd.Series): Predicted values.
|
133 |
+
|
134 |
+
Returns:
|
135 |
+
plt (matplotlib.pyplot): Matplotlib plot object.
|
136 |
+
"""
|
137 |
+
Residuals = actual - pd.Series(predicted)
|
138 |
+
|
139 |
+
sns.set(style="whitegrid")
|
140 |
+
plt.figure(figsize=(6, 4))
|
141 |
+
sns.histplot(Residuals, kde=True, color="#11B6BD")
|
142 |
+
|
143 |
+
plt.title('2.3.3 Distribution of Residuals')
|
144 |
+
plt.xlabel('Residuals')
|
145 |
+
plt.ylabel('Probability Density')
|
146 |
+
|
147 |
+
return plt
|
148 |
+
|
149 |
+
|
150 |
+
def qqplot(actual, predicted):
|
151 |
+
"""
|
152 |
+
Creates a QQ plot of the residuals.
|
153 |
+
|
154 |
+
Parameters:
|
155 |
+
actual (pd.Series): Actual values.
|
156 |
+
predicted (pd.Series): Predicted values.
|
157 |
+
|
158 |
+
Returns:
|
159 |
+
fig (go.Figure): Plotly figure object.
|
160 |
+
"""
|
161 |
+
Residuals = actual - pd.Series(predicted)
|
162 |
+
Residuals = pd.Series(Residuals)
|
163 |
+
Resud_std = (Residuals - Residuals.mean()) / Residuals.std()
|
164 |
+
|
165 |
+
fig = go.Figure()
|
166 |
+
fig.add_trace(go.Scatter(x=sm.ProbPlot(Resud_std).theoretical_quantiles,
|
167 |
+
y=sm.ProbPlot(Resud_std).sample_quantiles,
|
168 |
+
mode='markers',
|
169 |
+
marker=dict(size=5, color="#11B6BD"),
|
170 |
+
name='QQ Plot'))
|
171 |
+
|
172 |
+
diagonal_line = go.Scatter(
|
173 |
+
x=[-2, 2],
|
174 |
+
y=[-2, 2],
|
175 |
+
mode='lines',
|
176 |
+
line=dict(color='red'),
|
177 |
+
name=' '
|
178 |
+
)
|
179 |
+
fig.add_trace(diagonal_line)
|
180 |
+
|
181 |
+
fig.update_layout(title='2.3.2 QQ Plot of Residuals', title_x=0.5, autosize=False, width=600, height=400,
|
182 |
+
xaxis_title='Theoretical Quantiles', yaxis_title='Sample Quantiles')
|
183 |
+
|
184 |
+
return fig
|
185 |
+
|
db/imp_db.db
ADDED
@@ -0,0 +1,3 @@
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:36d945da7748723c6078c2d2bee4aaae5ca0838dcee0e2032a836341dcf7794e
|
3 |
+
size 15618048
|
db_creation.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
1 |
+
# import sqlite3
|
2 |
+
|
3 |
+
# # Connect to SQLite database (or create it if it doesn't exist)
|
4 |
+
# conn = sqlite3.connect('imp_db.db')
|
5 |
+
|
6 |
+
# # Enable foreign key support
|
7 |
+
# conn.execute('PRAGMA foreign_keys = ON;')
|
8 |
+
|
9 |
+
# # Create a cursor object
|
10 |
+
# c = conn.cursor()
|
11 |
+
|
12 |
+
# # SQL queries to create tables
|
13 |
+
# create_mmo_users_table = """
|
14 |
+
# CREATE TABLE IF NOT EXISTS mmo_users (
|
15 |
+
# emp_id TEXT PRIMARY KEY ,
|
16 |
+
# emp_nam TEXT NOT NULL,
|
17 |
+
# emp_typ TEXT NOT NULL,
|
18 |
+
# pswrd_key TEXT NOT NULL,
|
19 |
+
# pswrd_flag INTEGER NOT NULL DEFAULT 0,
|
20 |
+
# crte_dt_tm TEXT DEFAULT (datetime('now')),
|
21 |
+
# crte_by_uid TEXT NOT NULL,
|
22 |
+
# updt_dt_tm TEXT DEFAULT (datetime('now')),
|
23 |
+
# updt_by_uid TEXT
|
24 |
+
# );
|
25 |
+
# """
|
26 |
+
|
27 |
+
# create_mmo_projects_table = """
|
28 |
+
# CREATE TABLE IF NOT EXISTS mmo_projects (
|
29 |
+
# prj_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
30 |
+
# prj_ownr_id TEXT NOT NULL,
|
31 |
+
# prj_nam TEXT NOT NULL,
|
32 |
+
# alwd_emp_id TEXT,
|
33 |
+
# meta_data_agrgt TEXT,
|
34 |
+
# crte_dt_tm TEXT DEFAULT (datetime('now')),
|
35 |
+
# crte_by_uid TEXT NOT NULL,
|
36 |
+
# updt_dt_tm TEXT DEFAULT (datetime('now')),
|
37 |
+
# updt_by_uid TEXT,
|
38 |
+
# FOREIGN KEY (prj_ownr_id) REFERENCES mmo_users(emp_id)
|
39 |
+
# );
|
40 |
+
# """
|
41 |
+
|
42 |
+
# create_mmo_project_meta_data_table = """
|
43 |
+
# CREATE TABLE IF NOT EXISTS mmo_project_meta_data (
|
44 |
+
# prj_guid INTEGER PRIMARY KEY AUTOINCREMENT,
|
45 |
+
# prj_id INTEGER NOT NULL,
|
46 |
+
# page_nam TEXT NOT NULL,
|
47 |
+
# file_nam TEXT NOT NULL,
|
48 |
+
# pkl_obj BLOB,
|
49 |
+
# dshbrd_ts TEXT,
|
50 |
+
# crte_dt_tm TEXT DEFAULT (datetime('now')),
|
51 |
+
# crte_by_uid TEXT NOT NULL,
|
52 |
+
# updt_dt_tm TEXT DEFAULT (datetime('now')),
|
53 |
+
# updt_by_uid TEXT,
|
54 |
+
# FOREIGN KEY (prj_id) REFERENCES mmo_projects(prj_id)
|
55 |
+
# );
|
56 |
+
# """
|
57 |
+
|
58 |
+
# # Execute the queries to create tables
|
59 |
+
# c.execute(create_mmo_users_table)
|
60 |
+
# c.execute(create_mmo_projects_table)
|
61 |
+
# c.execute(create_mmo_project_meta_data_table)
|
62 |
+
|
63 |
+
# # Commit changes and close the connection
|
64 |
+
# conn.commit()
|
65 |
+
# conn.close()
|
66 |
+
import sqlite3
|
67 |
+
|
68 |
+
def add_user_to_db(db_path, user_id, name, user_type, pswrd_key):
|
69 |
+
"""
|
70 |
+
Adds a user to the mmo_users table in the SQLite database.
|
71 |
+
|
72 |
+
Parameters:
|
73 |
+
- db_path (str): The path to the SQLite database file.
|
74 |
+
- user_id (str): The ID of the user.
|
75 |
+
- name (str): The name of the user.
|
76 |
+
- user_type (str): The type of the user.
|
77 |
+
- pswrd_key (str): The password key for the user.
|
78 |
+
"""
|
79 |
+
|
80 |
+
try:
|
81 |
+
# Connect to the SQLite database
|
82 |
+
conn = sqlite3.connect(db_path)
|
83 |
+
cursor = conn.cursor()
|
84 |
+
|
85 |
+
# SQL query to insert a new user
|
86 |
+
insert_query = """
|
87 |
+
INSERT INTO mmo_users (emp_id, emp_nam, emp_typ, pswrd_key,crte_by_uid)
|
88 |
+
VALUES (?, ?, ?, ?,?)
|
89 |
+
"""
|
90 |
+
|
91 |
+
# Execute the query with parameters
|
92 |
+
cursor.execute(insert_query, (user_id, name, user_type, pswrd_key, user_id))
|
93 |
+
|
94 |
+
# Commit the transaction
|
95 |
+
conn.commit()
|
96 |
+
|
97 |
+
print(f"User {name} added successfully.")
|
98 |
+
|
99 |
+
except sqlite3.Error as e:
|
100 |
+
print(f"Error adding user to the database: {e}")
|
101 |
+
|
102 |
+
finally:
|
103 |
+
# Close the database connection
|
104 |
+
conn.close()
|
105 |
+
|
106 |
+
# Define the database path and user details
|
107 |
+
db_path = r'db\imp_db.db' # Update this path to your actual database path
|
108 |
+
user_id = 'e162284'
|
109 |
+
name = 'admin'
|
110 |
+
user_type = 'admin'
|
111 |
+
pswrd_key = '$2b$12$wP7R0usvKWtr4X06qwGWvOFQCkzOZAzSVRAoDv/68x6GS4rHK5mDm'
|
112 |
+
|
113 |
+
# Add the user to the database
|
114 |
+
add_user_to_db(db_path, user_id, name, user_type, pswrd_key)
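The `pswrd_key` value above looks like a bcrypt hash. Assuming the `bcrypt` package is what produced it (not confirmed by this upload), generating and checking such a key would look roughly like:

```python
import bcrypt

plain_password = "changeme"  # hypothetical value for illustration
pswrd_key = bcrypt.hashpw(plain_password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")

# Later, a login check would compare the submitted password against the stored hash:
bcrypt.checkpw(plain_password.encode("utf-8"), pswrd_key.encode("utf-8"))  # -> True
```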
|
log_application.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Contains logging functions to initialize and configure loggers, facilitating efficient tracking of ETL processes.
|
3 |
+
Also Provides methods to log messages with different log levels and additional context information such as channel name and client name
|
4 |
+
"""
|
5 |
+
|
6 |
+
import os
|
7 |
+
import streamlit as st
|
8 |
+
import logging, logging.handlers
|
9 |
+
from datetime import datetime, timezone
|
10 |
+
|
11 |
+
|
12 |
+
def delete_old_log_files():
|
13 |
+
"""
|
14 |
+
Deletes all log files in the 'logs' directory that are older than the current month.
|
15 |
+
|
16 |
+
Returns:
|
17 |
+
None
|
18 |
+
"""
|
19 |
+
# Set the logs directory to 'logs' within the current working directory
|
20 |
+
logs_directory = os.path.join(os.getcwd(), "logs")
|
21 |
+
|
22 |
+
# Ensure the logs directory exists
|
23 |
+
if not os.path.exists(logs_directory):
|
24 |
+
print(f"Directory {logs_directory} does not exist.")
|
25 |
+
return
|
26 |
+
|
27 |
+
# Get the current year and month in UTC
|
28 |
+
current_utc_time = datetime.now(timezone.utc)
|
29 |
+
current_year_month = current_utc_time.strftime("%Y%m")
|
30 |
+
|
31 |
+
# Iterate through all files in the logs directory
|
32 |
+
for filename in os.listdir(logs_directory):
|
33 |
+
# We assume the log files are in the format "eid_YYYYMM.log"
|
34 |
+
if filename.endswith(".log"):
|
35 |
+
# Extract the date portion from the filename
|
36 |
+
try:
|
37 |
+
file_date = filename.split("_")[-1].replace(".log", "")
|
38 |
+
# Compare the file date with the current date
|
39 |
+
if file_date < current_year_month:
|
40 |
+
# Construct the full file path
|
41 |
+
file_path = os.path.join(logs_directory, filename)
|
42 |
+
# Delete the file
|
43 |
+
os.remove(file_path)
|
44 |
+
|
45 |
+
except IndexError:
|
46 |
+
# If the filename doesn't match the expected pattern, skip it
|
47 |
+
pass
|
48 |
+
|
49 |
+
|
50 |
+
def create_log_file_name(emp_id, project_name):
|
51 |
+
"""
|
52 |
+
Generates a log file name using the format eid_YYYYMMDD.log based on UTC time.
|
53 |
+
|
54 |
+
Returns:
|
55 |
+
str: The generated log file name.
|
56 |
+
"""
|
57 |
+
# Get the current UTC time
|
58 |
+
utc_now = datetime.now(timezone.utc)
|
59 |
+
# Format the date as YYYYMMDD
|
60 |
+
date_str = utc_now.strftime("%Y%m%d")
|
61 |
+
# Create the file name with the format eid_YYYYMMDD.log
|
62 |
+
log_file_name = f"{emp_id}_{project_name}_{date_str}.log"
|
63 |
+
return log_file_name
|
64 |
+
|
65 |
+
|
66 |
+
def get_logger(log_file):
|
67 |
+
"""
|
68 |
+
Initializes and configures a logger. If the log file already exists, it appends to it.
|
69 |
+
Returns the same logger instance if called multiple times.
|
70 |
+
|
71 |
+
Args:
|
72 |
+
log_file (str): The path to the log file where logs will be written.
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
logging.Logger: Configured logger instance.
|
76 |
+
"""
|
77 |
+
|
78 |
+
# Define the logger name
|
79 |
+
logger_name = os.path.basename(log_file)
|
80 |
+
logger = logging.getLogger(logger_name)
|
81 |
+
|
82 |
+
# If the logger already has handlers, return it to prevent duplicate handlers
|
83 |
+
if logger.hasHandlers():
|
84 |
+
return logger
|
85 |
+
|
86 |
+
# Set the logging level
|
87 |
+
logging_level = logging.INFO
|
88 |
+
logger.setLevel(logging_level)
|
89 |
+
|
90 |
+
# Ensure the logs directory exists and is writable
|
91 |
+
logs_dir = os.path.dirname(log_file)
|
92 |
+
if not os.path.exists(logs_dir):
|
93 |
+
os.makedirs(logs_dir, exist_ok=True)
|
94 |
+
|
95 |
+
# Change the directory permissions to 0777 (0o777 in Python 3)
|
96 |
+
os.chmod(logs_dir, 0o777)
|
97 |
+
|
98 |
+
# Create a file handler to append to the log file
|
99 |
+
file_handler = logging.FileHandler(log_file)
|
100 |
+
file_handler.setFormatter(
|
101 |
+
logging.Formatter("%(asctime)s %(levelname)s %(message)s", "%Y-%m-%d %H:%M:%S")
|
102 |
+
)
|
103 |
+
logger.addHandler(file_handler)
|
104 |
+
|
105 |
+
# Create a stream handler to print to console
|
106 |
+
stream_handler = logging.StreamHandler()
|
107 |
+
stream_handler.setFormatter(
|
108 |
+
logging.Formatter("%(asctime)s %(levelname)s %(message)s", "%Y-%m-%d %H:%M:%S")
|
109 |
+
)
|
110 |
+
logger.addHandler(stream_handler)
|
111 |
+
|
112 |
+
return logger
|
113 |
+
|
114 |
+
|
115 |
+
def log_message(log_type, message, page_name):
|
116 |
+
"""
|
117 |
+
Logs a message with the specified log type and additional context information.
|
118 |
+
|
119 |
+
Args:
|
120 |
+
log_type (str): The type of log ('info', 'error', 'warning', 'debug').
|
121 |
+
message (str): The message to log.
|
122 |
+
page_name (str): The name of the page associated with the message.
|
123 |
+
|
124 |
+
Returns:
|
125 |
+
None
|
126 |
+
"""
|
127 |
+
|
128 |
+
# Retrieve the employee ID and project name from session state
|
129 |
+
emp_id = st.session_state["emp_id"]
|
130 |
+
project_name = st.session_state["project_name"]
|
131 |
+
|
132 |
+
# Create log file name using a function that generates a name based on the current date and EID
|
133 |
+
log_file_name = create_log_file_name(emp_id, project_name)
|
134 |
+
|
135 |
+
# Construct the full path to the log file within the "logs" within the current working directory
|
136 |
+
file_path = os.path.join(os.getcwd(), "logs", log_file_name)
|
137 |
+
|
138 |
+
# Generate and store the logger instance in session state
|
139 |
+
logger = get_logger(file_path)
|
140 |
+
|
141 |
+
# Construct the log message with all required context information
|
142 |
+
log_message = (
|
143 |
+
f"USER_EID: {emp_id} ; "
|
144 |
+
f"PROJECT_NAME: {project_name} ; "
|
145 |
+
f"PAGE_NAME: {page_name} ; "
|
146 |
+
f"MESSAGE: {message}"
|
147 |
+
)
|
148 |
+
|
149 |
+
# Log the message with the appropriate log level based on the log_type argument
|
150 |
+
if log_type == "info":
|
151 |
+
logger.info(log_message)
|
152 |
+
elif log_type == "error":
|
153 |
+
logger.error(log_message)
|
154 |
+
elif log_type == "warning":
|
155 |
+
logger.warning(log_message)
|
156 |
+
else:
|
157 |
+
logger.debug(log_message)
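A usage sketch for `log_message` (it assumes `emp_id` and `project_name` are already present in `st.session_state`, since the function reads both; the messages and page name below are illustrative):

```python
log_message("info", "Data import completed for the gold layer file.", "Data Import")
log_message("error", "Upload rejected: missing 'date' column.", "Data Import")
```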
|
logo.png
ADDED
logs/e111111_None_20250408.log
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
2025-04-08 13:32:28 INFO USER_EID: e111111 ; PROJECT_NAME: None ; PAGE_NAME: Home ; MESSAGE: Project number 1 selected by employee e111111.
|
logs/e111111_na_demo_user_demo_project_20250407.log
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
2025-04-07 18:10:49 INFO USER_EID: e111111 ; PROJECT_NAME: na_demo_user_demo_project ; PAGE_NAME: Home ; MESSAGE: Project number 642 selected by employee e111111.
|
logs/e111111_na_demo_user_demo_project_20250408.log
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
2025-04-08 13:32:31 INFO USER_EID: e111111 ; PROJECT_NAME: na_demo_user_demo_project ; PAGE_NAME: Home ; MESSAGE: Project number 1 selected by employee e111111.
|
2 |
+
2025-04-08 13:32:53 INFO USER_EID: e111111 ; PROJECT_NAME: na_demo_user_demo_project ; PAGE_NAME: Home ; MESSAGE: Project number 1 selected by employee e111111.
|
mmm_tool_document.docx
ADDED
Binary file (36.1 kB).
|
|
pages/10_Saved_Scenarios.py
ADDED
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Importing necessary libraries
|
2 |
+
import streamlit as st
|
3 |
+
|
4 |
+
st.set_page_config(
|
5 |
+
page_title="Saved Scenarios",
|
6 |
+
page_icon="⚖️",
|
7 |
+
layout="wide",
|
8 |
+
initial_sidebar_state="collapsed",
|
9 |
+
)
|
10 |
+
|
11 |
+
import io
|
12 |
+
import sys
|
13 |
+
import json
|
14 |
+
import pickle
|
15 |
+
import zipfile
|
16 |
+
import traceback
|
17 |
+
import numpy as np
|
18 |
+
import pandas as pd
|
19 |
+
from scenario import numerize
|
20 |
+
from openpyxl import Workbook
|
21 |
+
from post_gres_cred import db_cred
|
22 |
+
from log_application import log_message
|
23 |
+
from utilities import (
|
24 |
+
project_selection,
|
25 |
+
update_db,
|
26 |
+
set_header,
|
27 |
+
load_local_css,
|
28 |
+
name_formating,
|
29 |
+
)
|
30 |
+
|
31 |
+
schema = db_cred["schema"]
|
32 |
+
load_local_css("styles.css")
|
33 |
+
set_header()
|
34 |
+
|
35 |
+
# Initialize project name session state
|
36 |
+
if "project_name" not in st.session_state:
|
37 |
+
st.session_state["project_name"] = None
|
38 |
+
|
39 |
+
# Fetch project dictionary
|
40 |
+
if "project_dct" not in st.session_state:
|
41 |
+
project_selection()
|
42 |
+
st.stop()
|
43 |
+
|
44 |
+
# Display Username and Project Name
|
45 |
+
if "username" in st.session_state and st.session_state["username"] is not None:
|
46 |
+
|
47 |
+
cols1 = st.columns([2, 1])
|
48 |
+
|
49 |
+
with cols1[0]:
|
50 |
+
st.markdown(f"**Welcome {st.session_state['username']}**")
|
51 |
+
with cols1[1]:
|
52 |
+
st.markdown(f"**Current Project: {st.session_state['project_name']}**")
|
53 |
+
|
54 |
+
# Function to get saved scenarios dictionary
|
55 |
+
def get_saved_scenarios_dict():
|
56 |
+
return st.session_state["project_dct"]["saved_scenarios"][
|
57 |
+
"saved_scenarios_dict"
|
58 |
+
]
|
59 |
+
|
60 |
+
|
61 |
+
# Function to format values based on their size
|
62 |
+
def format_value(value):
|
63 |
+
return round(value, 4) if value < 1 else round(value, 1)
|
64 |
+
|
65 |
+
|
66 |
+
# Function to recursively convert non-serializable types to serializable ones
|
67 |
+
def convert_to_serializable(obj):
|
68 |
+
if isinstance(obj, np.ndarray):
|
69 |
+
return obj.tolist()
|
70 |
+
elif isinstance(obj, dict):
|
71 |
+
return {key: convert_to_serializable(value) for key, value in obj.items()}
|
72 |
+
elif isinstance(obj, list):
|
73 |
+
return [convert_to_serializable(element) for element in obj]
|
74 |
+
elif isinstance(obj, (int, float, str, bool, type(None))):
|
75 |
+
return obj
|
76 |
+
else:
|
77 |
+
# Fallback: convert the object to a string
|
78 |
+
return str(obj)
|
79 |
+
|
80 |
+
|
81 |
+
# Function to generate zip file of current scenario
|
82 |
+
@st.cache_data(show_spinner=False)
|
83 |
+
def download_as_zip(
|
84 |
+
df,
|
85 |
+
scenario_data,
|
86 |
+
excel_name="optimization_results.xlsx",
|
87 |
+
json_name="scenario_params.json",
|
88 |
+
):
|
89 |
+
# Create an in-memory bytes buffer for the ZIP file
|
90 |
+
buffer = io.BytesIO()
|
91 |
+
|
92 |
+
# Create a ZipFile object in memory
|
93 |
+
with zipfile.ZipFile(buffer, "w") as zip_file:
|
94 |
+
# Save the DataFrame to an Excel file in the zip using openpyxl
|
95 |
+
excel_buffer = io.BytesIO()
|
96 |
+
workbook = Workbook()
|
97 |
+
sheet = workbook.active
|
98 |
+
sheet.title = "Results"
|
99 |
+
|
100 |
+
# Write DataFrame headers
|
101 |
+
for col_num, column_title in enumerate(df.columns, 1):
|
102 |
+
sheet.cell(row=1, column=col_num, value=column_title)
|
103 |
+
|
104 |
+
# Write DataFrame data
|
105 |
+
for row_num, row_data in enumerate(df.values, 2):
|
106 |
+
for col_num, cell_value in enumerate(row_data, 1):
|
107 |
+
sheet.cell(row=row_num, column=col_num, value=cell_value)
|
108 |
+
|
109 |
+
# Save the workbook to the in-memory buffer
|
110 |
+
workbook.save(excel_buffer)
|
111 |
+
excel_buffer.seek(0) # Rewind the buffer to the beginning
|
112 |
+
zip_file.writestr(excel_name, excel_buffer.getvalue())
|
113 |
+
|
114 |
+
# Save the dictionary to a JSON file in the zip
|
115 |
+
json_buffer = io.BytesIO()
|
116 |
+
json_buffer.write(
|
117 |
+
json.dumps(convert_to_serializable(scenario_data), indent=4).encode("utf-8")
|
118 |
+
)
|
119 |
+
json_buffer.seek(0) # Rewind the buffer to the beginning
|
120 |
+
zip_file.writestr(json_name, json_buffer.getvalue())
|
121 |
+
|
122 |
+
buffer.seek(0) # Rewind the buffer to the beginning
|
123 |
+
|
124 |
+
return buffer
|
125 |
+
|
126 |
+
|
127 |
+
# Function to delete the selected scenario from the saved scenarios dictionary
|
128 |
+
def delete_selected_scenarios(selected_scenario):
|
129 |
+
if (
|
130 |
+
selected_scenario
|
131 |
+
in st.session_state["project_dct"]["saved_scenarios"]["saved_scenarios_dict"]
|
132 |
+
):
|
133 |
+
del st.session_state["project_dct"]["saved_scenarios"]["saved_scenarios_dict"][
|
134 |
+
selected_scenario
|
135 |
+
]
|
136 |
+
|
137 |
+
|
138 |
+
try:
|
139 |
+
# Page Title
|
140 |
+
st.title("Saved Scenarios")
|
141 |
+
|
142 |
+
# Placeholder to display scenarios name
|
143 |
+
scenarios_name_placeholder = st.empty()
|
144 |
+
|
145 |
+
# Get saved scenarios dictionary and scenario name list
|
146 |
+
saved_scenarios_dict = get_saved_scenarios_dict()
|
147 |
+
scenarios_list = list(saved_scenarios_dict.keys())
|
148 |
+
|
149 |
+
# Check if the list of saved scenarios is empty
|
150 |
+
if len(scenarios_list) == 0:
|
151 |
+
# Display a warning message if no scenarios are saved
|
152 |
+
st.warning("No scenarios saved. Please save a scenario to load.", icon="⚠️")
|
153 |
+
|
154 |
+
# Log message
|
155 |
+
log_message(
|
156 |
+
"warning",
|
157 |
+
"No scenarios saved. Please save a scenario to load.",
|
158 |
+
"Saved Scenarios",
|
159 |
+
)
|
160 |
+
|
161 |
+
st.stop()
|
162 |
+
|
163 |
+
# Columns for scenario selection and save progress
|
164 |
+
select_scenario_col, save_progress_col = st.columns(2)
|
165 |
+
save_message_display_placeholder = st.container()
|
166 |
+
|
167 |
+
# Display a dropdown of saved scenarios
|
168 |
+
selected_scenario = select_scenario_col.selectbox(
|
169 |
+
"Pick a Scenario", sorted(scenarios_list), key="selected_scenario"
|
170 |
+
)
|
171 |
+
|
172 |
+
# Save page progress
|
173 |
+
with save_progress_col:
|
174 |
+
st.write("###")
|
175 |
+
with save_message_display_placeholder, st.spinner("Saving Progress ..."):
|
176 |
+
if save_progress_col.button("Save Progress", use_container_width=True):
|
177 |
+
# Update DB
|
178 |
+
update_db(
|
179 |
+
prj_id=st.session_state["project_number"],
|
180 |
+
page_nam="Saved Scenarios",
|
181 |
+
file_nam="project_dct",
|
182 |
+
pkl_obj=pickle.dumps(st.session_state["project_dct"]),
|
183 |
+
schema=schema,
|
184 |
+
)
|
185 |
+
|
186 |
+
# Display success message
|
187 |
+
st.success("Progress saved successfully!", icon="💾")
|
188 |
+
st.toast("Progress saved successfully!", icon="💾")
|
189 |
+
|
190 |
+
# Log message
|
191 |
+
log_message("info", "Progress saved successfully!", "Saved Scenarios")
|
192 |
+
|
193 |
+
selected_scenario_data = saved_scenarios_dict[selected_scenario]
|
194 |
+
|
195 |
+
# Scenarios Name
|
196 |
+
metrics_name = selected_scenario_data["metrics_selected"]
|
197 |
+
panel_name = selected_scenario_data["panel_selected"]
|
198 |
+
optimization_name = selected_scenario_data["optimization"]
|
199 |
+
multiplier = selected_scenario_data["multiplier"]
|
200 |
+
timeframe = selected_scenario_data["timeframe"]
|
201 |
+
|
202 |
+
# Display the scenario details with bold "Metric," "Panel," and "Optimization"
|
203 |
+
scenarios_name_placeholder.markdown(
|
204 |
+
f"**Metric**: {name_formating(metrics_name)}; **Panel**: {name_formating(panel_name)}; **Fix**: {name_formating(optimization_name)}; **Timeframe**: {name_formating(timeframe)}"
|
205 |
+
)
|
206 |
+
|
207 |
+
# Create columns for download and delete buttons
|
208 |
+
download_col, delete_col = st.columns(2)
|
209 |
+
save_message_display_placeholder = st.container()
|
210 |
+
|
211 |
+
# Channel List
|
212 |
+
channels_list = list(selected_scenario_data["channels"].keys())
|
213 |
+
|
214 |
+
# List to hold data for all channels
|
215 |
+
channels_data = []
|
216 |
+
|
217 |
+
# Iterate through each channel and gather required data
|
218 |
+
for channel in channels_list:
|
219 |
+
channel_conversion_rate = selected_scenario_data["channels"][channel][
|
220 |
+
"conversion_rate"
|
221 |
+
]
|
222 |
+
channel_actual_spends = (
|
223 |
+
selected_scenario_data["channels"][channel]["actual_total_spends"]
|
224 |
+
* channel_conversion_rate
|
225 |
+
)
|
226 |
+
channel_optimized_spends = (
|
227 |
+
selected_scenario_data["channels"][channel]["modified_total_spends"]
|
228 |
+
* channel_conversion_rate
|
229 |
+
)
|
230 |
+
|
231 |
+
channel_actual_metrics = selected_scenario_data["channels"][channel][
|
232 |
+
"actual_total_sales"
|
233 |
+
]
|
234 |
+
channel_optimized_metrics = selected_scenario_data["channels"][channel][
|
235 |
+
"modified_total_sales"
|
236 |
+
]
|
237 |
+
|
238 |
+
channel_roi_mroi_data = selected_scenario_data["channel_roi_mroi"][channel]
|
239 |
+
|
240 |
+
# Extract the ROI and MROI data
|
241 |
+
actual_roi = channel_roi_mroi_data["actual_roi"]
|
242 |
+
optimized_roi = channel_roi_mroi_data["optimized_roi"]
|
243 |
+
actual_mroi = channel_roi_mroi_data["actual_mroi"]
|
244 |
+
optimized_mroi = channel_roi_mroi_data["optimized_mroi"]
|
245 |
+
|
246 |
+
# Calculate spends per metric
|
247 |
+
spends_per_metrics_actual = channel_actual_spends / channel_actual_metrics
|
248 |
+
spends_per_metrics_optimized = (
|
249 |
+
channel_optimized_spends / channel_optimized_metrics
|
250 |
+
)
|
251 |
+
|
252 |
+
# Append the collected data as a dictionary to the list
|
253 |
+
channels_data.append(
|
254 |
+
{
|
255 |
+
"Channel Name": channel,
|
256 |
+
"Spends Actual": numerize(channel_actual_spends / multiplier),
|
257 |
+
"Spends Optimized": numerize(channel_optimized_spends / multiplier),
|
258 |
+
f"{name_formating(metrics_name)} Actual": numerize(
|
259 |
+
channel_actual_metrics / multiplier
|
260 |
+
),
|
261 |
+
f"{name_formating(metrics_name)} Optimized": numerize(
|
262 |
+
channel_optimized_metrics / multiplier
|
263 |
+
),
|
264 |
+
"ROI Actual": format_value(actual_roi),
|
265 |
+
"ROI Optimized": format_value(optimized_roi),
|
266 |
+
"MROI Actual": format_value(actual_mroi),
|
267 |
+
"MROI Optimized": format_value(optimized_mroi),
|
268 |
+
f"Spends per {name_formating(metrics_name)} Actual": round(
|
269 |
+
spends_per_metrics_actual, 2
|
270 |
+
),
|
271 |
+
f"Spends per {name_formating(metrics_name)} Optimized": round(
|
272 |
+
spends_per_metrics_optimized, 2
|
273 |
+
),
|
274 |
+
}
|
275 |
+
)
|
276 |
+
|
277 |
+
# Create a DataFrame from the collected data
|
278 |
+
df = pd.DataFrame(channels_data)
|
279 |
+
|
280 |
+
# Display the DataFrame
|
281 |
+
st.dataframe(df, hide_index=True)
|
282 |
+
|
283 |
+
# Generate downloadable data for the selected scenario
|
284 |
+
buffer = download_as_zip(
|
285 |
+
df,
|
286 |
+
selected_scenario_data,
|
287 |
+
excel_name="optimization_results.xlsx",
|
288 |
+
json_name="scenario_params.json",
|
289 |
+
)
|
290 |
+
|
291 |
+
# Provide the buffer as a downloadable ZIP file
|
292 |
+
file_name = f"{selected_scenario}_scenario_data.zip"
|
293 |
+
if download_col.download_button(
|
294 |
+
label="Download",
|
295 |
+
data=buffer,
|
296 |
+
file_name=file_name,
|
297 |
+
mime="application/zip",
|
298 |
+
use_container_width=True,
|
299 |
+
):
|
300 |
+
# Log message
|
301 |
+
log_message(
|
302 |
+
"info",
|
303 |
+
f"FILE_NAME: {file_name} has been successfully downloaded.",
|
304 |
+
"Saved Scenarios",
|
305 |
+
)
|
306 |
+
|
307 |
+
# Button to trigger the deletion of the selected scenario
|
308 |
+
if delete_col.button(
|
309 |
+
"Delete",
|
310 |
+
use_container_width=True,
|
311 |
+
on_click=delete_selected_scenarios,
|
312 |
+
args=(selected_scenario,),
|
313 |
+
):
|
314 |
+
# Display success message
|
315 |
+
with save_message_display_placeholder:
|
316 |
+
st.success(
|
317 |
+
"Selected scenario successfully deleted. Click the 'Save Progress' button to ensure your changes are updated!",
|
318 |
+
icon="🗑️",
|
319 |
+
)
|
320 |
+
st.toast(
|
321 |
+
"Selected scenario successfully deleted. Click the 'Save Progress' button to ensure your changes are updated!",
|
322 |
+
icon="🗑️",
|
323 |
+
)
|
324 |
+
|
325 |
+
# Log message
|
326 |
+
log_message(
|
327 |
+
"info", "Selected scenario successfully deleted.", "Saved Scenarios"
|
328 |
+
)
|
329 |
+
|
330 |
+
except Exception as e:
|
331 |
+
# Capture the error details
|
332 |
+
exc_type, exc_value, exc_traceback = sys.exc_info()
|
333 |
+
error_message = "".join(
|
334 |
+
traceback.format_exception(exc_type, exc_value, exc_traceback)
|
335 |
+
)
|
336 |
+
|
337 |
+
# Log message
|
338 |
+
log_message("error", f"An error occurred: {error_message}.", "Saved Scenarios")
|
339 |
+
|
340 |
+
# Display a warning message
|
341 |
+
st.warning(
|
342 |
+
"Oops! Something went wrong. Please try refreshing the tool or creating a new project.",
|
343 |
+
icon="⚠️",
|
344 |
+
)
|
pages/11_AI_Model_Media_Recommendation.py
ADDED
@@ -0,0 +1,695 @@
1 |
+
import streamlit as st
|
2 |
+
from scenario import numerize
|
3 |
+
import pandas as pd
|
4 |
+
from utilities import (
|
5 |
+
format_numbers,
|
6 |
+
load_local_css,
|
7 |
+
set_header,
|
8 |
+
name_formating,
|
9 |
+
project_selection,
|
10 |
+
)
|
11 |
+
import pickle
|
12 |
+
import yaml
|
13 |
+
from yaml import SafeLoader
|
14 |
+
from scenario import class_from_dict
|
15 |
+
import plotly.express as px
|
16 |
+
import numpy as np
|
17 |
+
import plotly.graph_objects as go
|
18 |
+
import pandas as pd
|
19 |
+
from plotly.subplots import make_subplots
|
20 |
+
import sqlite3
|
21 |
+
from utilities import update_db
|
22 |
+
from collections import OrderedDict
|
23 |
+
import os
|
24 |
+
|
25 |
+
st.set_page_config(layout="wide")
|
26 |
+
load_local_css("styles.css")
|
27 |
+
set_header()
|
28 |
+
|
29 |
+
st.empty()
|
30 |
+
st.header("AI Model Media Recommendation")
|
31 |
+
|
32 |
+
# def get_saved_scenarios_dict():
|
33 |
+
# # Path to the saved scenarios file
|
34 |
+
# saved_scenarios_dict_path = os.path.join(
|
35 |
+
# st.session_state["project_path"], "saved_scenarios.pkl"
|
36 |
+
# )
|
37 |
+
|
38 |
+
# # Load existing scenarios if the file exists
|
39 |
+
# if os.path.exists(saved_scenarios_dict_path):
|
40 |
+
# with open(saved_scenarios_dict_path, "rb") as f:
|
41 |
+
# saved_scenarios_dict = pickle.load(f)
|
42 |
+
# else:
|
43 |
+
# saved_scenarios_dict = OrderedDict()
|
44 |
+
|
45 |
+
# return saved_scenarios_dict
|
46 |
+
|
47 |
+
|
48 |
+
# # Function to format values based on their size
|
49 |
+
# def format_value(value):
|
50 |
+
# return round(value, 4) if value < 1 else round(value, 1)
|
51 |
+
|
52 |
+
|
53 |
+
# # Function to recursively convert non-serializable types to serializable ones
|
54 |
+
# def convert_to_serializable(obj):
|
55 |
+
# if isinstance(obj, np.ndarray):
|
56 |
+
# return obj.tolist()
|
57 |
+
# elif isinstance(obj, dict):
|
58 |
+
# return {key: convert_to_serializable(value) for key, value in obj.items()}
|
59 |
+
# elif isinstance(obj, list):
|
60 |
+
# return [convert_to_serializable(element) for element in obj]
|
61 |
+
# elif isinstance(obj, (int, float, str, bool, type(None))):
|
62 |
+
# return obj
|
63 |
+
# else:
|
64 |
+
# # Fallback: convert the object to a string
|
65 |
+
# return str(obj)
|
66 |
+
|
67 |
+
|
68 |
+
if "username" not in st.session_state:
|
69 |
+
st.session_state["username"] = None
|
70 |
+
|
71 |
+
if "project_name" not in st.session_state:
|
72 |
+
st.session_state["project_name"] = None
|
73 |
+
|
74 |
+
if "project_dct" not in st.session_state:
|
75 |
+
project_selection()
|
76 |
+
st.stop()
|
77 |
+
# if "project_path" not in st.session_state:
|
78 |
+
# st.stop()
|
79 |
+
# if 'username' in st.session_state and st.session_state['username'] is not None:
|
80 |
+
|
81 |
+
# data_path = os.path.join(st.session_state["project_path"], "data_import.pkl")
|
82 |
+
|
83 |
+
# try:
|
84 |
+
# with open(data_path, "rb") as f:
|
85 |
+
# data = pickle.load(f)
|
86 |
+
# except Exception as e:
|
87 |
+
# st.error(f"Please import data from the Data Import Page")
|
88 |
+
# st.stop()
|
89 |
+
# # Get saved scenarios dictionary and scenario name list
|
90 |
+
# saved_scenarios_dict = get_saved_scenarios_dict()
|
91 |
+
# scenarios_list = list(saved_scenarios_dict.keys())
|
92 |
+
|
93 |
+
# #st.write(saved_scenarios_dict)
|
94 |
+
# # Check if the list of saved scenarios is empty
|
95 |
+
# if len(scenarios_list) == 0:
|
96 |
+
# # Display a warning message if no scenarios are saved
|
97 |
+
# st.warning("No scenarios saved. Please save a scenario to load.", icon="⚠️")
|
98 |
+
# st.stop()
|
99 |
+
|
100 |
+
# # Display a dropdown saved scenario list
|
101 |
+
# selected_scenario = st.selectbox(
|
102 |
+
# "Pick a Scenario", sorted(scenarios_list), key="selected_scenario"
|
103 |
+
# )
|
104 |
+
# selected_scenario_data = saved_scenarios_dict[selected_scenario]
|
105 |
+
|
106 |
+
# # Scenarios Name
|
107 |
+
# metrics_name = selected_scenario_data["metrics_selected"]
|
108 |
+
# panel_name = selected_scenario_data["panel_selected"]
|
109 |
+
# optimization_name = selected_scenario_data["optimization"]
|
110 |
+
|
111 |
+
# # Display the scenario details with bold "Metric," "Panel," and "Optimization"
|
112 |
+
|
113 |
+
# # Create columns for download and delete buttons
|
114 |
+
# download_col, delete_col = st.columns(2)
|
115 |
+
|
116 |
+
|
117 |
+
# channels_list = list(selected_scenario_data["channels"].keys())
|
118 |
+
|
119 |
+
# # List to hold data for all channels
|
120 |
+
# channels_data = []
|
121 |
+
|
122 |
+
# # Iterate through each channel and gather required data
|
123 |
+
# for channel in channels_list:
|
124 |
+
# channel_conversion_rate = selected_scenario_data["channels"][channel][
|
125 |
+
# "conversion_rate"
|
126 |
+
# ]
|
127 |
+
# channel_actual_spends = (
|
128 |
+
# selected_scenario_data["channels"][channel]["actual_total_spends"]
|
129 |
+
# * channel_conversion_rate
|
130 |
+
# )
|
131 |
+
# channel_optimized_spends = (
|
132 |
+
# selected_scenario_data["channels"][channel]["modified_total_spends"]
|
133 |
+
# * channel_conversion_rate
|
134 |
+
# )
|
135 |
+
|
136 |
+
# channel_actual_metrics = selected_scenario_data["channels"][channel][
|
137 |
+
# "actual_total_sales"
|
138 |
+
# ]
|
139 |
+
# channel_optimized_metrics = selected_scenario_data["channels"][channel][
|
140 |
+
# "modified_total_sales"
|
141 |
+
# ]
|
142 |
+
|
143 |
+
# channel_roi_mroi_data = selected_scenario_data["channel_roi_mroi"][channel]
|
144 |
+
|
145 |
+
# # Extract the ROI and MROI data
|
146 |
+
# actual_roi = channel_roi_mroi_data["actual_roi"]
|
147 |
+
# optimized_roi = channel_roi_mroi_data["optimized_roi"]
|
148 |
+
# actual_mroi = channel_roi_mroi_data["actual_mroi"]
|
149 |
+
# optimized_mroi = channel_roi_mroi_data["optimized_mroi"]
|
150 |
+
|
151 |
+
# # Calculate spends per metric
|
152 |
+
# spends_per_metrics_actual = channel_actual_spends / channel_actual_metrics
|
153 |
+
# spends_per_metrics_optimized = channel_optimized_spends / channel_optimized_metrics
|
154 |
+
|
155 |
+
# # Append the collected data as a dictionary to the list
|
156 |
+
# channels_data.append(
|
157 |
+
# {
|
158 |
+
# "Channel Name": channel,
|
159 |
+
# "Spends Actual": channel_actual_spends,
|
160 |
+
# "Spends Optimized": channel_optimized_spends,
|
161 |
+
# f"{metrics_name} Actual": channel_actual_metrics,
|
162 |
+
# f"{name_formating(metrics_name)} Optimized": numerize(
|
163 |
+
# channel_optimized_metrics
|
164 |
+
# ),
|
165 |
+
# "ROI Actual": format_value(actual_roi),
|
166 |
+
# "ROI Optimized": format_value(optimized_roi),
|
167 |
+
# "MROI Actual": format_value(actual_mroi),
|
168 |
+
# "MROI Optimized": format_value(optimized_mroi),
|
169 |
+
# f"Spends per {name_formating(metrics_name)} Actual": numerize(
|
170 |
+
# spends_per_metrics_actual
|
171 |
+
# ),
|
172 |
+
# f"Spends per {name_formating(metrics_name)} Optimized": numerize(
|
173 |
+
# spends_per_metrics_optimized
|
174 |
+
# ),
|
175 |
+
# }
|
176 |
+
# )
|
177 |
+
|
178 |
+
# # Create a DataFrame from the collected data
|
179 |
+
|
180 |
+
##NEW CODE##########
|
181 |
+
|
182 |
+
scenarios_name_placeholder = st.empty()
|
183 |
+
|
184 |
+
|
185 |
+
# Function to get saved scenarios dictionary
|
186 |
+
def get_saved_scenarios_dict():
|
187 |
+
return st.session_state["project_dct"]["saved_scenarios"]["saved_scenarios_dict"]
|
188 |
+
|
189 |
+
|
190 |
+
# Function to format values based on their size
|
191 |
+
def format_value(value):
|
192 |
+
return round(value, 4) if value < 1 else round(value, 1)
|
193 |
+
|
194 |
+
|
195 |
+
# Function to recursively convert non-serializable types to serializable ones
|
196 |
+
def convert_to_serializable(obj):
|
197 |
+
if isinstance(obj, np.ndarray):
|
198 |
+
return obj.tolist()
|
199 |
+
elif isinstance(obj, dict):
|
200 |
+
return {key: convert_to_serializable(value) for key, value in obj.items()}
|
201 |
+
elif isinstance(obj, list):
|
202 |
+
return [convert_to_serializable(element) for element in obj]
|
203 |
+
elif isinstance(obj, (int, float, str, bool, type(None))):
|
204 |
+
return obj
|
205 |
+
else:
|
206 |
+
# Fallback: convert the object to a string
|
207 |
+
return str(obj)
|
208 |
+
|
209 |
+
|
210 |
+
# Get saved scenarios dictionary and scenario name list
|
211 |
+
saved_scenarios_dict = get_saved_scenarios_dict()
|
212 |
+
scenarios_list = list(saved_scenarios_dict.keys())
|
213 |
+
|
214 |
+
# Check if the list of saved scenarios is empty
|
215 |
+
if len(scenarios_list) == 0:
|
216 |
+
# Display a warning message if no scenarios are saved
|
217 |
+
st.warning("No scenarios saved. Please save a scenario to load.", icon="⚠️")
|
218 |
+
st.stop()
|
219 |
+
|
220 |
+
# Display a dropdown of saved scenarios
|
221 |
+
selected_scenario = st.selectbox(
|
222 |
+
"Pick a Scenario", sorted(scenarios_list), key="selected_scenario"
|
223 |
+
)
|
224 |
+
selected_scenario_data = saved_scenarios_dict[selected_scenario]
|
225 |
+
|
226 |
+
# Scenarios Name
|
227 |
+
metrics_name = selected_scenario_data["metrics_selected"]
|
228 |
+
panel_name = selected_scenario_data["panel_selected"]
|
229 |
+
optimization_name = selected_scenario_data["optimization"]
|
230 |
+
multiplier = selected_scenario_data["multiplier"]
|
231 |
+
timeframe = selected_scenario_data["timeframe"]
|
232 |
+
|
233 |
+
# Display the scenario details with bold "Metric," "Panel," and "Optimization"
|
234 |
+
scenarios_name_placeholder.markdown(
|
235 |
+
f"**Metric**: {name_formating(metrics_name)}; **Panel**: {name_formating(panel_name)}; **Fix**: {name_formating(optimization_name)}; **Timeframe**: {name_formating(timeframe)}"
|
236 |
+
)
|
237 |
+
|
238 |
+
# Create columns for download and delete buttons
|
239 |
+
download_col, delete_col = st.columns(2)
|
240 |
+
|
241 |
+
# Channel List
|
242 |
+
channels_list = list(selected_scenario_data["channels"].keys())
|
243 |
+
|
244 |
+
# List to hold data for all channels
|
245 |
+
channels_data = []
|
246 |
+
|
247 |
+
# Iterate through each channel and gather required data
|
248 |
+
for channel in channels_list:
|
249 |
+
channel_conversion_rate = selected_scenario_data["channels"][channel][
|
250 |
+
"conversion_rate"
|
251 |
+
]
|
252 |
+
channel_actual_spends = (
|
253 |
+
selected_scenario_data["channels"][channel]["actual_total_spends"]
|
254 |
+
* channel_conversion_rate
|
255 |
+
)
|
256 |
+
channel_optimized_spends = (
|
257 |
+
selected_scenario_data["channels"][channel]["modified_total_spends"]
|
258 |
+
* channel_conversion_rate
|
259 |
+
)
|
260 |
+
|
261 |
+
channel_actual_metrics = selected_scenario_data["channels"][channel][
|
262 |
+
"actual_total_sales"
|
263 |
+
]
|
264 |
+
channel_optimized_metrics = selected_scenario_data["channels"][channel][
|
265 |
+
"modified_total_sales"
|
266 |
+
]
|
267 |
+
|
268 |
+
channel_roi_mroi_data = selected_scenario_data["channel_roi_mroi"][channel]
|
269 |
+
|
270 |
+
# Extract the ROI and MROI data
|
271 |
+
actual_roi = channel_roi_mroi_data["actual_roi"]
|
272 |
+
optimized_roi = channel_roi_mroi_data["optimized_roi"]
|
273 |
+
actual_mroi = channel_roi_mroi_data["actual_mroi"]
|
274 |
+
optimized_mroi = channel_roi_mroi_data["optimized_mroi"]
|
275 |
+
|
276 |
+
# Calculate spends per metric
|
277 |
+
spends_per_metrics_actual = channel_actual_spends / channel_actual_metrics
|
278 |
+
spends_per_metrics_optimized = channel_optimized_spends / channel_optimized_metrics
|
279 |
+
|
280 |
+
# Append the collected data as a dictionary to the list
|
281 |
+
channels_data.append(
|
282 |
+
{
|
283 |
+
"Channel Name": channel,
|
284 |
+
"Spends Actual": (channel_actual_spends / multiplier),
|
285 |
+
"Spends Optimized": (channel_optimized_spends / multiplier),
|
286 |
+
f"{name_formating(metrics_name)} Actual": (
|
287 |
+
channel_actual_metrics / multiplier
|
288 |
+
),
|
289 |
+
f"{name_formating(metrics_name)} Optimized": (
|
290 |
+
channel_optimized_metrics / multiplier
|
291 |
+
),
|
292 |
+
"ROI Actual": format_value(actual_roi),
|
293 |
+
"ROI Optimized": format_value(optimized_roi),
|
294 |
+
"MROI Actual": format_value(actual_mroi),
|
295 |
+
"MROI Optimized": format_value(optimized_mroi),
|
296 |
+
f"Spends per {name_formating(metrics_name)} Actual": round(
|
297 |
+
spends_per_metrics_actual, 2
|
298 |
+
),
|
299 |
+
f"Spends per {name_formating(metrics_name)} Optimized": round(
|
300 |
+
spends_per_metrics_optimized, 2
|
301 |
+
),
|
302 |
+
}
|
303 |
+
)
|
304 |
+
|
305 |
+
# Create a DataFrame from the collected data
|
306 |
+
# df = pd.DataFrame(channels_data)
|
307 |
+
|
308 |
+
# # Display the DataFrame
|
309 |
+
# st.dataframe(df, hide_index=True)
|
310 |
+
|
311 |
+
summary_df_sorted = pd.DataFrame(channels_data).sort_values(by=["Spends Optimized"])
|
312 |
+
|
313 |
+
|
314 |
+
summary_df_sorted["Delta"] = (
|
315 |
+
summary_df_sorted["Spends Optimized"] - summary_df_sorted["Spends Actual"]
|
316 |
+
)
|
317 |
+
|
318 |
+
|
319 |
+
summary_df_sorted["Delta_percent"] = np.round(
|
320 |
+
(summary_df_sorted["Delta"]) / summary_df_sorted["Spends Actual"] * 100, 2
|
321 |
+
)
|
322 |
+
|
323 |
+
# spends_data = pd.read_excel("Overview_data_test.xlsx")
|
324 |
+
|
325 |
+
|
326 |
+
st.header("Optimized Media Spend Overview")
|
327 |
+
|
328 |
+
channel_colors = px.colors.qualitative.Plotly
|
329 |
+
|
330 |
+
fig = make_subplots(
|
331 |
+
rows=1,
|
332 |
+
cols=3,
|
333 |
+
subplot_titles=("Actual Spend", "Spends Optimized", "Delta"),
|
334 |
+
horizontal_spacing=0.05,
|
335 |
+
)
|
336 |
+
|
337 |
+
for i, channel in enumerate(summary_df_sorted["Channel Name"].unique()):
|
338 |
+
channel_df = summary_df_sorted[summary_df_sorted["Channel Name"] == channel]
|
339 |
+
channel_color = channel_colors[i % len(channel_colors)]
|
340 |
+
|
341 |
+
fig.add_trace(
|
342 |
+
go.Bar(
|
343 |
+
x=channel_df["Spends Actual"],
|
344 |
+
y=channel_df["Channel Name"],
|
345 |
+
text=channel_df["Spends Actual"].apply(format_numbers),
|
346 |
+
marker_color=channel_color,
|
347 |
+
orientation="h",
|
348 |
+
),
|
349 |
+
row=1,
|
350 |
+
col=1,
|
351 |
+
)
|
352 |
+
|
353 |
+
fig.add_trace(
|
354 |
+
go.Bar(
|
355 |
+
x=channel_df["Spends Optimized"],
|
356 |
+
y=channel_df["Channel Name"],
|
357 |
+
text=channel_df["Spends Optimized"].apply(format_numbers),
|
358 |
+
marker_color=channel_color,
|
359 |
+
orientation="h",
|
360 |
+
showlegend=False,
|
361 |
+
),
|
362 |
+
row=1,
|
363 |
+
col=2,
|
364 |
+
)
|
365 |
+
|
366 |
+
fig.add_trace(
|
367 |
+
go.Bar(
|
368 |
+
x=channel_df["Delta_percent"],
|
369 |
+
y=channel_df["Channel Name"],
|
370 |
+
text=channel_df["Delta_percent"].apply(lambda x: f"{x:.0f}%"),
|
371 |
+
marker_color=channel_color,
|
372 |
+
orientation="h",
|
373 |
+
showlegend=False,
|
374 |
+
),
|
375 |
+
row=1,
|
376 |
+
col=3,
|
377 |
+
)
|
378 |
+
fig.update_layout(height=600, width=900, title="", showlegend=False)
|
379 |
+
|
380 |
+
fig.update_yaxes(showticklabels=False, row=1, col=2)
|
381 |
+
fig.update_yaxes(showticklabels=False, row=1, col=3)
|
382 |
+
|
383 |
+
fig.update_xaxes(showticklabels=False, row=1, col=1)
|
384 |
+
fig.update_xaxes(showticklabels=False, row=1, col=2)
|
385 |
+
fig.update_xaxes(showticklabels=False, row=1, col=3)
|
386 |
+
|
387 |
+
|
388 |
+
st.plotly_chart(fig, use_container_width=True)
|
389 |
+
|
390 |
+
|
391 |
+
summary_df_sorted["Perc_alloted"] = np.round(
|
392 |
+
summary_df_sorted["Spends Optimized"] / summary_df_sorted["Spends Optimized"].sum(),
|
393 |
+
2,
|
394 |
+
)
|
395 |
+
st.header("Optimized Media Spend Allocation")
|
396 |
+
|
397 |
+
fig = make_subplots(
|
398 |
+
rows=1,
|
399 |
+
cols=2,
|
400 |
+
subplot_titles=("Spends Optimized", "% Split"),
|
401 |
+
horizontal_spacing=0.05,
|
402 |
+
)
|
403 |
+
|
404 |
+
for i, channel in enumerate(summary_df_sorted["Channel Name"].unique()):
|
405 |
+
channel_df = summary_df_sorted[summary_df_sorted["Channel Name"] == channel]
|
406 |
+
channel_color = channel_colors[i % len(channel_colors)]
|
407 |
+
|
408 |
+
fig.add_trace(
|
409 |
+
go.Bar(
|
410 |
+
x=channel_df["Spends Optimized"],
|
411 |
+
y=channel_df["Channel Name"],
|
412 |
+
text=channel_df["Spends Optimized"].apply(format_numbers),
|
413 |
+
marker_color=channel_color,
|
414 |
+
orientation="h",
|
415 |
+
),
|
416 |
+
row=1,
|
417 |
+
col=1,
|
418 |
+
)
|
419 |
+
|
420 |
+
fig.add_trace(
|
421 |
+
go.Bar(
|
422 |
+
x=channel_df["Perc_alloted"],
|
423 |
+
y=channel_df["Channel Name"],
|
424 |
+
text=channel_df["Perc_alloted"].apply(lambda x: f"{100*x:.0f}%"),
|
425 |
+
marker_color=channel_color,
|
426 |
+
orientation="h",
|
427 |
+
showlegend=False,
|
428 |
+
),
|
429 |
+
row=1,
|
430 |
+
col=2,
|
431 |
+
)
|
432 |
+
|
433 |
+
fig.update_layout(height=600, width=900, title="", showlegend=False)
|
434 |
+
|
435 |
+
fig.update_yaxes(showticklabels=False, row=1, col=2)
|
436 |
+
fig.update_yaxes(showticklabels=False, row=1, col=3)
|
437 |
+
|
438 |
+
fig.update_xaxes(showticklabels=False, row=1, col=1)
|
439 |
+
fig.update_xaxes(showticklabels=False, row=1, col=2)
|
440 |
+
fig.update_xaxes(showticklabels=False, row=1, col=3)
|
441 |
+
|
442 |
+
|
443 |
+
st.plotly_chart(fig, use_container_width=True)
|
444 |
+
|
445 |
+
|
446 |
+
st.session_state["cleaned_data"] = st.session_state["project_dct"]["data_import"][
|
447 |
+
"imputed_tool_df"
|
448 |
+
]
|
449 |
+
st.session_state["category_dict"] = st.session_state["project_dct"]["data_import"][
|
450 |
+
"category_dict"
|
451 |
+
]
|
452 |
+
|
453 |
+
effectiveness_overall = pd.DataFrame()
|
454 |
+
|
455 |
+
response_metrics = list(
|
456 |
+
*[
|
457 |
+
st.session_state["category_dict"][key]
|
458 |
+
for key in st.session_state["category_dict"].keys()
|
459 |
+
if key == "Response Metrics"
|
460 |
+
]
|
461 |
+
)
|
462 |
+
|
463 |
+
effectiveness_overall = (
|
464 |
+
st.session_state["cleaned_data"][response_metrics].sum().reset_index()
|
465 |
+
)
|
466 |
+
|
467 |
+
effectiveness_overall.columns = ["ResponseMetricName", "ResponseMetricValue"]
|
468 |
+
|
469 |
+
|
470 |
+
effectiveness_overall["Efficiency"] = effectiveness_overall["ResponseMetricValue"].map(
|
471 |
+
lambda x: x / summary_df_sorted["Spends Optimized"].sum()
|
472 |
+
)
|
473 |
+
|
474 |
+
|
475 |
+
columns6 = st.columns(3)
|
476 |
+
|
477 |
+
effectiveness_overall.sort_values(
|
478 |
+
by=["ResponseMetricValue"], ascending=False, inplace=True
|
479 |
+
)
|
480 |
+
effectiveness_overall = np.round(effectiveness_overall, 2)
|
481 |
+
|
482 |
+
columns4 = st.columns([0.55, 0.45])
|
483 |
+
|
484 |
+
# effectiveness_overall = effectiveness_overall.sort_values(by=["ResponseMetricValue"])
|
485 |
+
|
486 |
+
# with columns4[0]:
|
487 |
+
# fig = px.funnel(
|
488 |
+
# effectiveness_overall,
|
489 |
+
# x="ResponseMetricValue",
|
490 |
+
# y="ResponseMetricName",
|
491 |
+
# color="ResponseMetricName",
|
492 |
+
# title="Effectiveness",
|
493 |
+
# )
|
494 |
+
# fig.update_layout(
|
495 |
+
# showlegend=False,
|
496 |
+
# yaxis=dict(tickmode="array"),
|
497 |
+
# )
|
498 |
+
# fig.update_traces(
|
499 |
+
# textinfo="value",
|
500 |
+
# textposition="inside",
|
501 |
+
# texttemplate="%{x:.2s} ",
|
502 |
+
# hoverinfo="y+x+percent initial",
|
503 |
+
# )
|
504 |
+
# fig.update_traces(
|
505 |
+
# marker=dict(line=dict(color="black", width=2)),
|
506 |
+
# selector=dict(marker=dict(color="blue")),
|
507 |
+
# )
|
508 |
+
|
509 |
+
# st.plotly_chart(fig, use_container_width=True)
|
510 |
+
|
511 |
+
# with columns4[1]:
|
512 |
+
# fig1 = px.bar(
|
513 |
+
# effectiveness_overall.sort_values(by=["ResponseMetricValue"], ascending=False),
|
514 |
+
# x="Efficiency",
|
515 |
+
# y="ResponseMetricName",
|
516 |
+
# color="ResponseMetricName",
|
517 |
+
# text_auto=True,
|
518 |
+
# title="Efficiency",
|
519 |
+
# )
|
520 |
+
|
521 |
+
# # Update layout and traces
|
522 |
+
# fig1.update_traces(
|
523 |
+
# customdata=effectiveness_overall["Efficiency"], textposition="auto"
|
524 |
+
# )
|
525 |
+
# fig1.update_layout(showlegend=False)
|
526 |
+
# fig1.update_yaxes(title="", showticklabels=False)
|
527 |
+
# fig1.update_xaxes(title="", showticklabels=False)
|
528 |
+
# fig1.update_xaxes(tickfont=dict(size=20))
|
529 |
+
# fig1.update_yaxes(tickfont=dict(size=20))
|
530 |
+
# st.plotly_chart(fig1, use_container_width=True)
|
531 |
+
|
532 |
+
# Function to format metric names
|
533 |
+
def format_metric_name(metric_name):
|
534 |
+
return str(metric_name).lower().replace("response_metric_", "").replace("_", " ").strip().title()
|
535 |
+
|
536 |
+
# Apply the formatting function to the 'ResponseMetricName' column
|
537 |
+
effectiveness_overall["FormattedMetricName"] = effectiveness_overall["ResponseMetricName"].apply(format_metric_name)
|
538 |
+
|
539 |
+
# Multiselect widget with all options as default, but using the formatted names for display
|
540 |
+
all_metrics = effectiveness_overall["FormattedMetricName"].unique()
|
541 |
+
selected_metrics = st.multiselect(
|
542 |
+
"Select Metrics to Display",
|
543 |
+
options=all_metrics,
|
544 |
+
default=all_metrics
|
545 |
+
)
|
546 |
+
|
547 |
+
# Filter the data based on the selected metrics (using formatted names)
|
548 |
+
if selected_metrics:
|
549 |
+
filtered_data = effectiveness_overall[
|
550 |
+
effectiveness_overall["FormattedMetricName"].isin(selected_metrics)
|
551 |
+
]
|
552 |
+
|
553 |
+
# Sort values for funnel plot
|
554 |
+
filtered_data = filtered_data.sort_values(by=["ResponseMetricValue"])
|
555 |
+
|
556 |
+
# Generate a consistent color mapping for all selected metrics
|
557 |
+
color_map = {metric: px.colors.qualitative.Plotly[i % len(px.colors.qualitative.Plotly)]
|
558 |
+
for i, metric in enumerate(filtered_data["FormattedMetricName"].unique())}
|
559 |
+
|
560 |
+
# First plot: Funnel
|
561 |
+
with columns4[0]:
|
562 |
+
fig = px.funnel(
|
563 |
+
filtered_data,
|
564 |
+
x="ResponseMetricValue",
|
565 |
+
y="FormattedMetricName", # Use formatted names for y-axis
|
566 |
+
color="FormattedMetricName", # Use formatted names for color
|
567 |
+
color_discrete_map=color_map, # Ensure consistent colors
|
568 |
+
title="Effectiveness",
|
569 |
+
)
|
570 |
+
fig.update_layout(
|
571 |
+
showlegend=False,
|
572 |
+
yaxis=dict(title="Response Metric", tickmode="array"), # Set y-axis label to 'Response Metric'
|
573 |
+
)
|
574 |
+
fig.update_traces(
|
575 |
+
textinfo="value",
|
576 |
+
textposition="inside",
|
577 |
+
texttemplate="%{x:.2s} ",
|
578 |
+
hoverinfo="y+x+percent initial",
|
579 |
+
)
|
580 |
+
fig.update_traces(
|
581 |
+
marker=dict(line=dict(color="black", width=2)),
|
582 |
+
selector=dict(marker=dict(color="blue")),
|
583 |
+
)
|
584 |
+
|
585 |
+
st.plotly_chart(fig, use_container_width=True)
|
586 |
+
|
587 |
+
# Second plot: Bar chart
|
588 |
+
with columns4[1]:
|
589 |
+
fig1 = px.bar(
|
590 |
+
filtered_data.sort_values(by=["ResponseMetricValue"], ascending=False),
|
591 |
+
x="Efficiency",
|
592 |
+
y="FormattedMetricName", # Use formatted names for y-axis
|
593 |
+
color="FormattedMetricName", # Use formatted names for color
|
594 |
+
color_discrete_map=color_map, # Ensure consistent colors
|
595 |
+
text_auto=True,
|
596 |
+
title="Efficiency",
|
597 |
+
)
|
598 |
+
|
599 |
+
# Update layout and traces
|
600 |
+
fig1.update_traces(
|
601 |
+
customdata=filtered_data["Efficiency"], textposition="auto"
|
602 |
+
)
|
603 |
+
fig1.update_layout(showlegend=False)
|
604 |
+
fig1.update_yaxes(title="", showticklabels=False)
|
605 |
+
fig1.update_xaxes(title="", showticklabels=False)
|
606 |
+
fig1.update_xaxes(tickfont=dict(size=20))
|
607 |
+
fig1.update_yaxes(tickfont=dict(size=20))
|
608 |
+
st.plotly_chart(fig1, use_container_width=True)
|
609 |
+
else:
|
610 |
+
st.info("Please select at least one response metric to display the charts.")
|
611 |
+
|
612 |
+
st.header("Return Forecast by Media Channel")
|
613 |
+
|
614 |
+
with st.expander("Return Forecast by Media Channel"):
|
615 |
+
|
616 |
+
|
617 |
+
metric = metrics_name
|
618 |
+
|
619 |
+
metric = metric.lower().replace("_", " ") + " " + "actual"
|
620 |
+
summary_df_sorted.columns = [
|
621 |
+
col.lower().replace("_", " ") for col in summary_df_sorted.columns
|
622 |
+
]
|
623 |
+
|
624 |
+
effectiveness = summary_df_sorted[metric]
|
625 |
+
|
626 |
+
summary_df_sorted["Efficiency"] = (
|
627 |
+
summary_df_sorted[metric] / summary_df_sorted["spends optimized"]
|
628 |
+
)
|
629 |
+
|
630 |
+
channel_colors = px.colors.qualitative.Plotly
|
631 |
+
|
632 |
+
fig = make_subplots(
|
633 |
+
rows=1,
|
634 |
+
cols=3,
|
635 |
+
subplot_titles=("Optimized Spends", "Effectiveness", "Efficiency"),
|
636 |
+
horizontal_spacing=0.05,
|
637 |
+
)
|
638 |
+
|
639 |
+
for i, channel in enumerate(summary_df_sorted["channel name"].unique()):
|
640 |
+
channel_df = summary_df_sorted[summary_df_sorted["channel name"] == channel]
|
641 |
+
channel_color = channel_colors[i % len(channel_colors)]
|
642 |
+
|
643 |
+
fig.add_trace(
|
644 |
+
go.Bar(
|
645 |
+
x=channel_df["spends optimized"],
|
646 |
+
y=channel_df["channel name"],
|
647 |
+
text=channel_df["spends optimized"].apply(format_numbers),
|
648 |
+
marker_color=channel_color,
|
649 |
+
orientation="h",
|
650 |
+
),
|
651 |
+
row=1,
|
652 |
+
col=1,
|
653 |
+
)
|
654 |
+
|
655 |
+
fig.add_trace(
|
656 |
+
go.Bar(
|
657 |
+
x=channel_df[metric],
|
658 |
+
y=channel_df["channel name"],
|
659 |
+
text=channel_df[metric].apply(format_numbers),
|
660 |
+
marker_color=channel_color,
|
661 |
+
orientation="h",
|
662 |
+
showlegend=False,
|
663 |
+
),
|
664 |
+
row=1,
|
665 |
+
col=2,
|
666 |
+
)
|
667 |
+
|
668 |
+
fig.add_trace(
|
669 |
+
go.Bar(
|
670 |
+
x=channel_df["Efficiency"],
|
671 |
+
y=channel_df["channel name"],
|
672 |
+
text=channel_df["Efficiency"].apply(lambda x: f"{x:.2f}"),
|
673 |
+
marker_color=channel_color,
|
674 |
+
orientation="h",
|
675 |
+
showlegend=False,
|
676 |
+
),
|
677 |
+
row=1,
|
678 |
+
col=3,
|
679 |
+
)
|
680 |
+
|
681 |
+
fig.update_layout(
|
682 |
+
height=600,
|
683 |
+
width=900,
|
684 |
+
title="Media Channel Performance",
|
685 |
+
showlegend=False,
|
686 |
+
)
|
687 |
+
|
688 |
+
fig.update_yaxes(showticklabels=False, row=1, col=2)
|
689 |
+
fig.update_yaxes(showticklabels=False, row=1, col=3)
|
690 |
+
|
691 |
+
fig.update_xaxes(showticklabels=False, row=1, col=1)
|
692 |
+
fig.update_xaxes(showticklabels=False, row=1, col=2)
|
693 |
+
fig.update_xaxes(showticklabels=False, row=1, col=3)
|
694 |
+
|
695 |
+
st.plotly_chart(fig, use_container_width=True)
|
pages/12_Glossary.py
ADDED
@@ -0,0 +1,150 @@
1 |
+
# Importing necessary libraries
|
2 |
+
import streamlit as st
|
3 |
+
import base64
|
4 |
+
|
5 |
+
from utilities import set_header, load_local_css
|
6 |
+
|
7 |
+
st.set_page_config(
|
8 |
+
page_title="Glossary",
|
9 |
+
page_icon=":shark:",
|
10 |
+
layout="wide",
|
11 |
+
initial_sidebar_state="collapsed",
|
12 |
+
)
|
13 |
+
|
14 |
+
load_local_css("styles.css")
|
15 |
+
set_header()
|
16 |
+
|
17 |
+
|
18 |
+
st.header("Glossary")
|
19 |
+
|
20 |
+
# Glossary
|
21 |
+
st.markdown(
|
22 |
+
"""
|
23 |
+
|
24 |
+
### 1. Glossary - Home
|
25 |
+
**Blend Employee ID:** Users should have access to the UI through their Blend Employee ID. User data should be available and validated through this ID.
|
26 |
+
|
27 |
+
**API Data:** Incorporate data retrieved from APIs to be used in the tool.
|
28 |
+
|
29 |
+
### 2. Glossary – Data Import
|
30 |
+
**Granularity:** Defines the level of detail to which the data should be aggregated.
|
31 |
+
|
32 |
+
**Panel:** Represents columns corresponding to markets, DMAs (Designated Market Areas).
|
33 |
+
|
34 |
+
**Response Metrics:** Target variables or metrics such as App Installs, Revenue, Form Submission/Conversion. These are the variables used for model building and spends optimization.
|
35 |
+
|
36 |
+
**Spends:** Variables representing spend data for all media channels.
|
37 |
+
|
38 |
+
**Exogenous:** External variables like bank holidays, GDP, online trends, interest rates.
|
39 |
+
|
40 |
+
**Internal:** Internal variables such as discounts or promotions.
|
41 |
+
|
42 |
+
### 3. Glossary – Data Assessment
|
43 |
+
**Trendline:** Represents the linear average movement over the entire time period.
|
44 |
+
|
45 |
+
**Media Variables:** Variables related to media activities such as Impressions, Clicks, Spends, Views. Examples: Bing Search Clicks, Instagram Impressions, YouTube Impressions.
|
46 |
+
|
47 |
+
**Non-Media Variables:** Variables that exclude media-related data. Examples: Discount, Holiday, Google Trend.
|
48 |
+
|
49 |
+
**CPM (Cost per 1000 Impressions):** Calculated as (YouTube Spends / YouTube Impressions) * 1000.
|
50 |
+
|
51 |
+
**CPC (Cost per 1000 Clicks):** Calculated as (YouTube Spends / YouTube Clicks) * 1000.
|
52 |
+
|
53 |
+
**Correlation:** Pearson correlation measures the linear relationship between two continuous variables, indicating the strength and direction of their association.
|
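As an illustration only, these assessment metrics can be computed with short pandas helpers; the function and Series names below are assumptions for the sketch, not the tool's actual code.

```python
import pandas as pd

def cpm(spends: pd.Series, impressions: pd.Series) -> pd.Series:
    # Cost per 1000 impressions, e.g. YouTube Spends / YouTube Impressions * 1000
    return spends / impressions * 1000

def cpc_per_1000(spends: pd.Series, clicks: pd.Series) -> pd.Series:
    # Cost per 1000 clicks, as defined above
    return spends / clicks * 1000

def pearson_correlation(a: pd.Series, b: pd.Series) -> float:
    # Pearson correlation between two continuous variables
    return a.corr(b, method="pearson")
```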
54 |
+
|
55 |
+
### 4. Glossary – Transformations
|
56 |
+
|
57 |
+
**Transformations:** Transformations involve adjusting input variables to capture nonlinear relationships, like diminishing returns and advertising carryover. They are always applied at the panel level if panels exist; if no panel exists, the transformations are applied directly at the aggregated level.
|
58 |
+
|
59 |
+
**Lag:** Shifts the data backward by a specified number of periods. Formula: Lagged Series = X_(t − lag)
|
60 |
+
|
61 |
+
**Lead:** Shifts the data forward by a specified number of periods. Formula: Lead Series = X_(t + lead)
|
62 |
+
|
63 |
+
**Moving Average:** Smooths the data by averaging values over a specified window size n. Formula: Moving Average = (1/n) Σ_(i=0)^(n−1) X_(t−i)
|
64 |
+
|
65 |
+
**Saturation:** Applies a saturation effect to the data based on a specified saturation percentage. Formula: Y_t = (1 / (1 + (saturation point / X_t)^steepness)) × X_t
|
66 |
+
|
67 |
+
**Power:** Raises the data to a specified power. Formula: Y_t = X_t^power
|
68 |
+
|
69 |
+
**Adstock:** Applies a decay effect to the data, simulating advertising carryover that fades over time. Formula: Y_t = X_t + decay rate × Y_(t−1)
|
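For illustration, the transformations above could be sketched roughly as follows for a single panel-level pandas Series; the function and parameter names are assumptions for this sketch, not the implementation in 3_AI_Model_Transformations.py.

```python
import pandas as pd

def lag(x: pd.Series, periods: int) -> pd.Series:
    # Lagged Series = X_(t - lag); the emptied head is filled with 0
    return x.shift(periods, fill_value=0)

def moving_average(x: pd.Series, window: int) -> pd.Series:
    # (1/n) * sum of the current and previous n-1 values
    return x.rolling(window, min_periods=1).mean()

def saturation(x: pd.Series, saturation_point: float, steepness: float) -> pd.Series:
    # Y_t = X_t / (1 + (saturation point / X_t) ** steepness); assumes X_t > 0
    return x / (1 + (saturation_point / x) ** steepness)

def adstock(x: pd.Series, decay_rate: float) -> pd.Series:
    # Y_t = X_t + decay_rate * Y_(t-1)
    carry, out = 0.0, []
    for value in x:
        carry = value + decay_rate * carry
        out.append(carry)
    return pd.Series(out, index=x.index)
```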
70 |
+
|
71 |
+
### 5. Glossary - AI Model Build
|
72 |
+
**Train Set:** The train set is a subset of the data on which the AI Model is trained. It is standard practice to select 70-75% of the data as train set.
|
73 |
+
|
74 |
+
**Test Set:** The test set is the subset of data on which the AI Model’s performance is tested. There will be no common records between the train and test sets.
|
75 |
+
|
76 |
+
**Residual:** Residual is defined as the difference between the true value and the value predicted by the AI Model. Residual = True Value − Predicted Value
|
77 |
+
|
78 |
+
**Actual VS Predicted Plot:** An actual vs. predicted plot visualizes the relationship between the actual values and the values predicted by the AI model, helping to assess model performance and identify patterns or discrepancies.
|
79 |
+
|
80 |
+
**MAPE:** MAPE or Mean Absolute Percentage Error indicates the percentage of error in the AI Model’s predictions. We use a variant of MAPE called Weighted Average Percentage Error. MAPE = Σ |Actual Value − Predicted Value| ÷ Σ |Actual Value|
|
81 |
+
|
82 |
+
**R-Squared:** R-Squared is a number that tells you how well the independent variable(s) in an AI model explain the variation in the dependent variable. It indicates the goodness of fit, with values closer to 1 suggesting a better fit.
|
83 |
+
|
84 |
+
**Adjusted R-Squared:** Adjusted R-squared modifies the R-squared value to account for the number of predictors in the model, providing a more accurate measure of goodness of fit by penalizing the addition of non-significant predictors.
|
85 |
+
|
86 |
+
**Multicollinearity:** Multicollinearity refers to a situation where two or more independent variables (media channels or marketing inputs) are highly correlated with each other. This makes it difficult to isolate the individual effect of each channel on the dependent variable (e.g., sales). It can lead to unreliable coefficient estimates, making it hard to determine the true impact of each media type.
|
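For intuition, the fit metrics above can be written as short NumPy helpers; this is a sketch that follows the stated formulas, not the code used on the AI Model Build page.

```python
import numpy as np

def weighted_mape(actual: np.ndarray, predicted: np.ndarray) -> float:
    # Weighted Average Percentage Error: sum(|actual - predicted|) / sum(|actual|)
    return np.sum(np.abs(actual - predicted)) / np.sum(np.abs(actual))

def r_squared(actual: np.ndarray, predicted: np.ndarray) -> float:
    # R-squared = 1 - SSres / SStot
    ss_res = np.sum((actual - predicted) ** 2)
    ss_tot = np.sum((actual - np.mean(actual)) ** 2)
    return 1 - ss_res / ss_tot

def adjusted_r_squared(r2: float, n_obs: int, n_predictors: int) -> float:
    # Penalizes extra predictors: 1 - (1 - R2) * (n - 1) / (n - p - 1)
    return 1 - (1 - r2) * (n_obs - 1) / (n_obs - n_predictors - 1)
```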
87 |
+
|
88 |
+
### 6. Glossary – AI Model Tuning
|
89 |
+
**Event:** An event refers to a specific occurrence or campaign, such as a promotion, holiday, or product launch, that can impact the performance of the response metric.
|
90 |
+
|
91 |
+
**Trend:** Trend is a straight line which helps capture the underlying direction or pattern in the data over time, helping to identify and quantify the increasing or decreasing trend of the dependent variable.
|
92 |
+
|
93 |
+
**Day of Week:** The "day of week" feature represents the specific day within a week (e.g., Monday, Tuesday) and is used to capture patterns or seasonal effects that occur on specific days.
|
94 |
+
|
95 |
+
**Sine & Cosine Waves:** Sine and cosine waves are mathematical functions used to capture seasonality in the dependent variable
|
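A hypothetical example of how sine and cosine waves encode seasonality as model features, assuming weekly data with a yearly cycle of roughly 52 periods; the column names are made up for the sketch.

```python
import numpy as np
import pandas as pd

def add_seasonal_waves(df: pd.DataFrame, period: float = 52.0) -> pd.DataFrame:
    # One sine/cosine pair completing a full cycle every `period` rows
    t = np.arange(len(df))
    out = df.copy()
    out["season_sin"] = np.sin(2 * np.pi * t / period)
    out["season_cos"] = np.cos(2 * np.pi * t / period)
    return out
```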
96 |
+
|
97 |
+
**Contribution:** Contribution of a channel is the percentage of its contribution to the response metric. Contribution is an output of the AI Model, calculated using the media data and model’s coefficients
|
98 |
+
|
99 |
+
### 7. Glossary – Response Curves
|
100 |
+
|
101 |
+
**Response Curve:** A response curve in media mix modeling represents the relationship between media inputs (e.g., impressions, clicks) on the X-axis and the resulting business outcomes (e.g., revenue, app installs) on the Y-axis, illustrating the impact of media variables on the desired response metric.
|
102 |
+
|
103 |
+
**R-squared (R-squared):** R-squared (R-squared) is a statistical measure that represents the proportion of the variance in the dependent variable that is predictable from the independent variables. It indicates how well the regression model fits the data, with a value between 0 and 1, where 1 indicates a perfect fit.
|
104 |
+
**Formula for R-squared (R-squared):**
|
105 |
+
R-squared = 1 − SSres / SStot
|
106 |
+
Where:
|
107 |
+
- **SSres** is the sum of squares of residuals (errors).
|
108 |
+
- **SStot** is the total sum of squares (the variance of the dependent variable).
|
109 |
+
|
110 |
+
**Actual R-squared:** Actual R-squared is used to evaluate how well an s-curve fits the data. Modified R-squared is used to evaluate how well a modified (manually tuned) s-curve fits the data. The difference between the modified s-curve and the actual s-curve shows how much the fit improves or worsens.
|
111 |
+
|
112 |
+
### 8. Glossary – Scenario Planner
|
113 |
+
|
114 |
+
**CPA (Cost Per Acquisition):** The cost associated with acquiring a customer or conversion. CPA = Total Spend / Number of Acquisitions
|
115 |
+
|
116 |
+
**ROI (Return on Investment):** A measure of the profitability of an investment relative to its cost. ROI = (Revenue - Total Spend) / Total Spend
|
117 |
+
|
118 |
+
**mROI (Marginal Return on Investment):** Measures the additional return generated by an additional unit of investment. It represents the slope of the ROI curve and shows the incremental effectiveness of the investment. mROI = ΔRevenue / ΔSpend
|
119 |
+
|
120 |
+
**Bounds (Upper, Lower):** Constraints applied during optimization to ensure that media channel spends remain within specified limits. These bounds help in defining the feasible range for investments and prevent overspending or underspending.
|
121 |
+
|
122 |
+
**Panel & Timeframe:**
|
123 |
+
- **Panel:** The ability to optimize media strategies at a granular level based on specific categories such as geographies, product types, customer segments, or other defined groups. This allows for tailored strategies that address the unique characteristics and needs of each panel.
|
124 |
+
- **Timeframe:** The ability to optimize media strategies within specific time periods, such as by month, quarter, or year. This enables precise adjustments, ensuring that the approach aligns with specific business cycles.
|
125 |
+
|
126 |
+
**Actual Vs Optimized:**
|
127 |
+
- **Actual:** This category includes the unoptimized values, such as the actual spends and actual response metrics, reflecting the current or historical performance without any optimization applied.
|
128 |
+
- **Optimized:** This category encompasses the optimized values, such as optimized spends and optimized response metrics, representing the improved or adjusted figures after the optimization process.
|
129 |
+
|
130 |
+
**Regions:**
|
131 |
+
- **Yellow Region:** This represents the under-invested area where additional investments can lead to entering a region with a higher ROI than the average. Investing more in this region can yield better returns.
|
132 |
+
- **Green Region:** This is the optimal area where the ROI of the channel is above the average ROI, and the Marginal Return on Investment (mROI) is greater than 1, indicating that each additional dollar spent is generating substantial returns.
|
133 |
+
- **Red Region:** This represents the over-invested or saturated area where the mROI is less than 1, meaning that additional spending yields diminishing returns and is less effective.
|
134 |
+
|
135 |
+
**Important Formulas:**
|
136 |
+
**1. Incremental CPA (%)** = (Total Optimized CPA−Total Actual CPA) / (Total Actual CPA)
|
137 |
+
**2. Incremental Spend (%)** = (Total Optimized Spend−Total Actual Spend) / (Total Actual Spend)
|
138 |
+
**Note:** Calculated based on total actual spends (base spends is always zero)
|
139 |
+
**3. Incremental Response Metric (%) [Media]** = (Total Optimized Response Metric−Total Actual Response Metric) / (Total Actual Response Metric−Total Base Response Metric)
|
140 |
+
**Note:** Calculated based on media portion of total actual response metric only, excluding the fixed base contribution
|
141 |
+
**4. Incremental Response Metric (%) [Total]** = (Total Optimized Response Metric−Total Actual Response Metric) / (Total Actual Response Metric)
|
142 |
+
**Note:** Calculated based on total actual response metric only, including the fixed base contribution
|
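A worked sketch of the four incremental formulas with made-up totals; all numbers below are illustrative assumptions, not tool output.

```python
total_actual_spend, total_optimized_spend = 100.0, 110.0
total_actual_metric, total_optimized_metric = 500.0, 540.0
total_base_metric = 200.0  # fixed base (non-media) contribution
total_actual_cpa, total_optimized_cpa = 0.20, 0.204

incr_cpa = (total_optimized_cpa - total_actual_cpa) / total_actual_cpa                    # 0.02  ->  2%
incr_spend = (total_optimized_spend - total_actual_spend) / total_actual_spend            # 0.10  -> 10%
incr_metric_media = (total_optimized_metric - total_actual_metric) / (
    total_actual_metric - total_base_metric
)                                                                                          # 0.133 -> 13.3%
incr_metric_total = (total_optimized_metric - total_actual_metric) / total_actual_metric  # 0.08  ->  8%
```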
143 |
+
|
144 |
+
### 10. Glossary – Model Optimized Recommendation
|
145 |
+
|
146 |
+
**% Split:** The percentage distribution of total media spend across different channels, optimized to achieve the best results.
|
147 |
+
|
148 |
+
**Return forecast by Media Channel:** Effectiveness and Efficiency for selected response metric.
|
149 |
+
"""
|
150 |
+
)
|
pages/14_User_Management.py
ADDED
@@ -0,0 +1,503 @@
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
st.set_page_config(layout="wide")
|
4 |
+
import pandas as pd
|
5 |
+
from utilities import (
|
6 |
+
load_local_css,
|
7 |
+
set_header,
|
8 |
+
query_excecuter_postgres,
|
9 |
+
store_hashed_password,
|
10 |
+
verify_password,
|
11 |
+
is_pswrd_flag_set,
|
12 |
+
)
|
13 |
+
import psycopg2
|
14 |
+
|
15 |
+
#
|
16 |
+
import bcrypt
|
17 |
+
import re
|
18 |
+
import time
|
19 |
+
import random
|
20 |
+
import string
|
21 |
+
from log_application import log_message
|
22 |
+
|
23 |
+
# setting page config
|
24 |
+
load_local_css("styles.css")
|
25 |
+
set_header()
|
26 |
+
|
27 |
+
# schema=db_cred['schema']
|
28 |
+
|
29 |
+
db_cred = None
|
30 |
+
|
31 |
+
|
32 |
+
def fetch_password_hash_key():
|
33 |
+
query = f"""
|
34 |
+
SELECT emp_id
|
35 |
+
FROM mmo_users
|
36 |
+
WHERE emp_nam ='admin';
|
37 |
+
"""
|
38 |
+
hashkey = query_excecuter_postgres(query, db_cred, insert=False)
|
39 |
+
|
40 |
+
return hashkey
|
41 |
+
|
42 |
+
|
43 |
+
def fetch_users_with_access():
|
44 |
+
# Query to get allowed employee IDs for the project
|
45 |
+
unique_users_query = f"""
|
46 |
+
SELECT DISTINCT emp_id, emp_nam, emp_typ
|
47 |
+
FROM mmo_users;
|
48 |
+
"""
|
49 |
+
try:
|
50 |
+
unique_users_result = query_excecuter_postgres(
|
51 |
+
unique_users_query, db_cred, insert=False
|
52 |
+
)
|
53 |
+
|
54 |
+
unique_users_result_df = pd.DataFrame(unique_users_result)
|
55 |
+
unique_users_result_df.columns = ["User ID", "User Name", "User Type"]
|
56 |
+
|
57 |
+
st.session_state["df_users"] = unique_users_result_df
|
58 |
+
except:
|
59 |
+
st.session_state["df_users"] = pd.DataFrame()
|
60 |
+
|
61 |
+
|
62 |
+
def add_user(emp_id, emp_name, emp_type, admin_id, plain_text_password):
|
63 |
+
"""
|
64 |
+
Adds a new user to the mmo_users table if the user does not already exist.
|
65 |
+
|
66 |
+
Parameters:
|
67 |
+
emp_id (str): Employee ID of the user.
|
68 |
+
emp_name (str): Name of the user.
|
69 |
+
emp_type (str): Type of the user.
|
70 |
+
db_cred (dict): Database credentials with keys 'dbname', 'user', 'password', 'host', 'port'.
|
71 |
+
schema (str): The schema name.
|
72 |
+
|
73 |
+
Returns:
|
74 |
+
str: A success or error message.
|
75 |
+
"""
|
76 |
+
# Query to check if the user already exists
|
77 |
+
check_user_query = f"""
|
78 |
+
SELECT emp_id FROM mmo_users
|
79 |
+
WHERE emp_id = ?;
|
80 |
+
"""
|
81 |
+
params_check = (emp_id,)
|
82 |
+
|
83 |
+
try:
|
84 |
+
# Check if the user already exists
|
85 |
+
existing_user = query_excecuter_postgres(
|
86 |
+
check_user_query, db_cred, params=params_check, insert=False
|
87 |
+
)
|
88 |
+
# st.write(existing_user)
|
89 |
+
|
90 |
+
if existing_user:
|
91 |
+
# If user exists, return a warning message
|
92 |
+
return False
|
93 |
+
|
94 |
+
# Query to add a new user to the mmo_users table
|
95 |
+
else:
|
96 |
+
hashed_password = bcrypt.hashpw(
|
97 |
+
plain_text_password.encode("utf-8"), bcrypt.gensalt()
|
98 |
+
)
|
99 |
+
|
100 |
+
# Convert the byte string to a regular string for storage
|
101 |
+
hashed_password_str = hashed_password.decode("utf-8")
|
102 |
+
add_user_query = f"""
|
103 |
+
INSERT INTO mmo_users
|
104 |
+
(emp_id, emp_nam, emp_typ, crte_dt_tm, crte_by_uid,pswrd_key)
|
105 |
+
VALUES (?, ?, ?,datetime('now'), ?,?);
|
106 |
+
"""
|
107 |
+
params_add = (emp_id, emp_name, emp_type, admin_id, hashed_password_str)
|
108 |
+
|
109 |
+
# Execute the insert query
|
110 |
+
query_excecuter_postgres(
|
111 |
+
add_user_query, db_cred, params=params_add, insert=True
|
112 |
+
)
|
113 |
+
|
114 |
+
return True
|
115 |
+
|
116 |
+
except Exception as e:
|
117 |
+
return False
|
118 |
+
|
119 |
+
|
120 |
+
# def delete_users_by_names(emp_id):
|
121 |
+
|
122 |
+
# # Sanitize and format the list of names for the SQL query
|
123 |
+
# formatted_ids =tuple(emp_id)
|
124 |
+
|
125 |
+
# # Query to delete users based on the employee names
|
126 |
+
# delete_user_query = f"""
|
127 |
+
# DELETE FROM mmo_users
|
128 |
+
# WHERE emp_id IN ?;
|
129 |
+
# """
|
130 |
+
# try:
|
131 |
+
# # Execute the delete query
|
132 |
+
# query_excecuter_postgres(delete_user_query, db_cred,params=(formatted_ids,), insert=True)
|
133 |
+
|
134 |
+
# return f"{len(emp_id)} users deleted successfully."
|
135 |
+
|
136 |
+
# except Exception as e:
|
137 |
+
# st.write(e)
|
138 |
+
# print(e)
|
139 |
+
# return f"An error occurred: {e}"
|
140 |
+
|
141 |
+
|
142 |
+
def delete_users_by_names(emp_ids):
|
143 |
+
"""
|
144 |
+
Deletes users from the mmo_users table and their associated projects from the mmo_project_meta_data
|
145 |
+
and mmo_projects tables.
|
146 |
+
|
147 |
+
Parameters:
|
148 |
+
emp_ids (list): A list of employee IDs to delete.
|
149 |
+
|
150 |
+
Returns:
|
151 |
+
str: A success or error message.
|
152 |
+
"""
|
153 |
+
# Convert the list of employee IDs into a tuple for the SQL query
|
154 |
+
formatted_ids = tuple(emp_ids)
|
155 |
+
|
156 |
+
# Queries to delete projects and users
|
157 |
+
delete_projects_meta_query = """
|
158 |
+
DELETE FROM mmo_project_meta_data
|
159 |
+
WHERE prj_id IN (
|
160 |
+
SELECT prj_id FROM mmo_projects
|
161 |
+
WHERE prj_ownr_id IN ({})
|
162 |
+
);
|
163 |
+
""".format(
|
164 |
+
",".join("?" * len(formatted_ids))
|
165 |
+
)
|
166 |
+
|
167 |
+
delete_projects_query = """
|
168 |
+
DELETE FROM mmo_projects
|
169 |
+
WHERE prj_ownr_id IN ({});
|
170 |
+
""".format(
|
171 |
+
",".join("?" * len(formatted_ids))
|
172 |
+
)
|
173 |
+
|
174 |
+
delete_user_query = """
|
175 |
+
DELETE FROM mmo_users
|
176 |
+
WHERE emp_id IN ({});
|
177 |
+
""".format(
|
178 |
+
",".join("?" * len(formatted_ids))
|
179 |
+
)
|
180 |
+
|
181 |
+
try:
|
182 |
+
# Execute the delete queries using query_excecuter_postgres
|
183 |
+
query_excecuter_postgres(
|
184 |
+
delete_projects_meta_query, params=formatted_ids, insert=True
|
185 |
+
)
|
186 |
+
query_excecuter_postgres(
|
187 |
+
delete_projects_query, params=formatted_ids, insert=True
|
188 |
+
)
|
189 |
+
query_excecuter_postgres(delete_user_query, params=formatted_ids, insert=True)
|
190 |
+
|
191 |
+
return (
|
192 |
+
f"{len(emp_ids)} users and their associated projects deleted successfully."
|
193 |
+
)
|
194 |
+
|
195 |
+
except Exception as e:
|
196 |
+
return f"An error occurred: {e}"
|
197 |
+
|
198 |
+
|
199 |
+
def reset_user_name():
|
200 |
+
st.session_state.user_id = ""
|
201 |
+
st.session_state.user_name = ""
|
202 |
+
|
203 |
+
|
204 |
+
def contains_sql_keywords_check(user_input):
|
205 |
+
|
206 |
+
sql_keywords = [
|
207 |
+
"SELECT",
|
208 |
+
"INSERT",
|
209 |
+
"UPDATE",
|
210 |
+
"DELETE",
|
211 |
+
"DROP",
|
212 |
+
"ALTER",
|
213 |
+
"CREATE",
|
214 |
+
"GRANT",
|
215 |
+
"REVOKE",
|
216 |
+
"UNION",
|
217 |
+
"JOIN",
|
218 |
+
"WHERE",
|
219 |
+
"HAVING",
|
220 |
+
"EXEC",
|
221 |
+
"TRUNCATE",
|
222 |
+
"REPLACE",
|
223 |
+
"MERGE",
|
224 |
+
"DECLARE",
|
225 |
+
"SHOW",
|
226 |
+
"FROM",
|
227 |
+
]
|
228 |
+
|
229 |
+
pattern = "|".join(re.escape(keyword) for keyword in sql_keywords)
|
230 |
+
|
231 |
+
return re.search(pattern, user_input, re.IGNORECASE)
|
232 |
+
|
233 |
+
|
234 |
+
def update_password_and_flag(user_id, plain_text_password, flag=False):
|
235 |
+
"""
|
236 |
+
Hashes a plain text password using bcrypt, converts it to a UTF-8 string, and stores it as text.
|
237 |
+
|
238 |
+
Parameters:
|
239 |
+
plain_text_password (str): The plain text password to be hashed.
|
240 |
+
db_cred (dict): The database credentials including dbname, user, password, host, and port.
|
241 |
+
"""
|
242 |
+
# Hash the plain text password
|
243 |
+
hashed_password = bcrypt.hashpw(
|
244 |
+
plain_text_password.encode("utf-8"), bcrypt.gensalt()
|
245 |
+
)
|
246 |
+
|
247 |
+
# Convert the byte string to a regular string for storage
|
248 |
+
hashed_password_str = hashed_password.decode("utf-8")
|
249 |
+
|
250 |
+
# SQL query to update the pswrd_key for the specified user_id
|
251 |
+
if flag:
|
252 |
+
query = f"""
|
253 |
+
UPDATE mmo_users
|
254 |
+
SET pswrd_key = ?, pswrd_flag = 0
|
255 |
+
WHERE emp_id = ?;
|
256 |
+
"""
|
257 |
+
else:
|
258 |
+
query = f"""
|
259 |
+
UPDATE mmo_users
|
260 |
+
SET pswrd_key = ?
|
261 |
+
WHERE emp_id = ?;
|
262 |
+
"""
|
263 |
+
# Execute the query using the existing query_excecuter_postgres function
|
264 |
+
query_excecuter_postgres(
|
265 |
+
query=query, db_cred=db_cred, params=(hashed_password_str, user_id), insert=True
|
266 |
+
)
|
267 |
+
|
268 |
+
|
269 |
+
def reset_pass_key():
|
270 |
+
st.session_state["pass_key"] = ""
|
271 |
+
|
272 |
+
|
273 |
+
st.title("User Management")
|
274 |
+
|
275 |
+
# hashed_key=fetch_password_hash_key()[0][0]
|
276 |
+
|
277 |
+
if "df_users" not in st.session_state:
|
278 |
+
fetch_users_with_access()
|
279 |
+
# st.write(hashed_key.encode())
|
280 |
+
if "unique_ids_admin" not in st.session_state:
|
281 |
+
unique_users_query = f"""
|
282 |
+
SELECT DISTINCT emp_id, emp_nam, emp_typ from mmo_users
|
283 |
+
Where emp_typ= 'admin';
|
284 |
+
"""
|
285 |
+
unique_users_result = query_excecuter_postgres(
|
286 |
+
unique_users_query, db_cred, insert=False
|
287 |
+
) # retrieves all the users who has access to MMO TOOL
|
288 |
+
st.session_state["unique_ids_admin"] = {
|
289 |
+
emp_id: (emp_nam, emp_type) for emp_id, emp_nam, emp_type in unique_users_result
|
290 |
+
}
|
291 |
+
|
292 |
+
|
293 |
+
if len(st.session_state["unique_ids_admin"]) == 0:
|
294 |
+
|
295 |
+
st.error("No admin found in the database!")
|
296 |
+
st.markdown(
|
297 |
+
"""
|
298 |
+
- You can add an admin to the database.
|
299 |
+
- **Ensure** that you store the passkey securely, as losing it will prevent anyone from accessing the tool.
|
300 |
+
- Once done, reset the password on the "**Home**" page.
|
301 |
+
"""
|
302 |
+
)
|
303 |
+
|
304 |
+
emp_id = st.text_input("employee id", key="emp1111").lower()
|
305 |
+
password = st.text_input("Enter password to access the page", type="password")
|
306 |
+
|
307 |
+
|
308 |
+
if emp_id not in st.session_state["unique_ids_admin"] and len(emp_id) > 4:
|
309 |
+
|
310 |
+
if not is_pswrd_flag_set(emp_id):
|
311 |
+
st.warning("Reset password in Home page to continue")
|
312 |
+
st.stop()
|
313 |
+
|
314 |
+
st.warning(
|
315 |
+
'Incorrect username or password. If you are using the default password, please reset it on the "Home" page.'
|
316 |
+
)
|
317 |
+
|
318 |
+
st.stop()
|
319 |
+
|
320 |
+
|
321 |
+
# st.write(st.session_state["unique_ids_admin"])
|
322 |
+
|
323 |
+
|
324 |
+
# Check password or if no admin is present
|
325 |
+
|
326 |
+
if verify_password(emp_id, password) or len(st.session_state["unique_ids_admin"]) == 0:
|
327 |
+
|
328 |
+
st.success("Access Granted", icon="✅")
|
329 |
+
|
330 |
+
st.header("All Users")
|
331 |
+
|
332 |
+
st.dataframe(
|
333 |
+
st.session_state["df_users"], use_container_width=True, hide_index=True
|
334 |
+
)
|
335 |
+
|
336 |
+
user_manage = st.radio(
|
337 |
+
"Select a Option",
|
338 |
+
["Add New User", "Delete Users", "Manage User Passwords"],
|
339 |
+
horizontal=True,
|
340 |
+
)
|
341 |
+
|
342 |
+
if user_manage == "Add New User":
|
343 |
+
with st.expander("", expanded=True):
|
344 |
+
add_col = st.columns(3)
|
345 |
+
|
346 |
+
with add_col[0]:
|
347 |
+
user_id = st.text_input("Enter User ID", key="user_id").lower()
|
348 |
+
|
349 |
+
with add_col[1]:
|
350 |
+
user_name = st.text_input("Enter User Name", key="user_name").lower()
|
351 |
+
|
352 |
+
with add_col[2]:
|
353 |
+
|
354 |
+
if len(st.session_state["unique_ids_admin"]) == 0:
|
355 |
+
user_types_options = ["Admin"]
|
356 |
+
else:
|
357 |
+
user_types_options = ["Data Scientist", "Media Planner", "Admin"]
|
358 |
+
|
359 |
+
user_type = (
|
360 |
+
st.selectbox("Select User Type", user_types_options)
|
361 |
+
.replace(" ", "_")
|
362 |
+
.lower()
|
363 |
+
)
|
364 |
+
warning_box = st.empty()
|
365 |
+
|
366 |
+
if "passkey" not in st.session_state:
|
367 |
+
st.session_state["passkey"] = ""
|
368 |
+
if len(user_id) < 3 or len(user_name) < 3:
|
369 |
+
st.session_state["passkey"] = ""
|
370 |
+
|
371 |
+
pass_key_col = st.columns(3)
|
372 |
+
with pass_key_col[0]:
|
373 |
+
st.markdown(
|
374 |
+
'<div style="display: flex; position:relative; justify-content: flex-end; align-items: center; height: 100%;">'
|
375 |
+
'<p style="font-weight: bold; margin: 0; position:absolute; top:4px;">Default Password</p>'
|
376 |
+
"</div>",
|
377 |
+
unsafe_allow_html=True,
|
378 |
+
)
|
379 |
+
with pass_key_col[1]:
|
380 |
+
passkey_box = st.empty()
|
381 |
+
with passkey_box:
|
382 |
+
st.code(f"{st.session_state['passkey']}", language="text")
|
383 |
+
|
384 |
+
st.button(
|
385 |
+
"Reset Values", on_click=reset_user_name, use_container_width=True
|
386 |
+
)
|
387 |
+
|
388 |
+
if st.button("Add User", use_container_width=True):
|
389 |
+
|
390 |
+
if user_id in st.session_state["df_users"]["User ID"]:
|
391 |
+
with warning_box:
|
392 |
+
st.warning("User id already exists")
|
393 |
+
st.stop()
|
394 |
+
|
395 |
+
if (
|
396 |
+
len(user_id) == 0
|
397 |
+
or len(user_name) == 0
|
398 |
+
or not user_id.startswith("e")
|
399 |
+
or len(user_id) < 6
|
400 |
+
):
|
401 |
+
with warning_box:
|
402 |
+
st.warning("Enter a Valid User ID and User Name!")
|
403 |
+
st.stop()
|
404 |
+
|
405 |
+
if contains_sql_keywords_check(user_id):
|
406 |
+
with warning_box:
|
407 |
+
st.warning(
|
408 |
+
"Input contains SQL keywords. Please avoid using SQL commands."
|
409 |
+
)
|
410 |
+
st.stop()
|
411 |
+
else:
|
412 |
+
pass
|
413 |
+
|
414 |
+
if not (2 <= len(user_name) <= 50):
|
415 |
+
# Store the warning message details in session state
|
416 |
+
|
417 |
+
with warning_box:
|
418 |
+
st.warning(
|
419 |
+
"Please provide a valid user name (2-50 characters, only A-Z, a-z, 0-9, and _)."
|
420 |
+
)
|
421 |
+
st.stop()
|
422 |
+
|
423 |
+
if contains_sql_keywords_check(user_name):
|
424 |
+
with warning_box:
|
425 |
+
st.warning(
|
426 |
+
"Input contains SQL keywords. Please avoid using SQL commands."
|
427 |
+
)
|
428 |
+
st.stop()
|
429 |
+
else:
|
430 |
+
pass
|
431 |
+
|
432 |
+
characters = string.ascii_letters + string.digits # Letters and digits
|
433 |
+
plain_text_password = "".join(
|
434 |
+
random.choice(characters) for _ in range(10)
|
435 |
+
)
|
436 |
+
st.session_state["passkey"] = plain_text_password
|
437 |
+
|
438 |
+
if add_user(user_id, user_name, user_type, emp_id, plain_text_password):
|
439 |
+
with st.spinner("Adding New User"):
|
440 |
+
|
441 |
+
update_password_and_flag(user_id, plain_text_password)
|
442 |
+
fetch_users_with_access()
|
443 |
+
st.success("User added successfully")
|
444 |
+
st.rerun()
|
445 |
+
else:
|
446 |
+
st.warning(
|
447 |
+
f"User with emp_id {user_id} already exists in the database."
|
448 |
+
)
|
449 |
+
st.stop()
|
450 |
+
|
451 |
+
if user_manage == "Delete Users":
|
452 |
+
with st.expander("", expanded=True):
|
453 |
+
user_names_to_delete = st.multiselect(
|
454 |
+
"Select User IDS to Delete", st.session_state["df_users"]["User ID"]
|
455 |
+
)
|
456 |
+
|
457 |
+
if st.button("Delete", use_container_width=True):
|
458 |
+
delete_users_by_names(user_names_to_delete)
|
459 |
+
st.success("Users Deleted")
|
460 |
+
fetch_users_with_access()
|
461 |
+
st.rerun()
|
462 |
+
|
463 |
+
if user_manage == "Manage User Passwords":
|
464 |
+
|
465 |
+
with st.expander("**Manage User Passwords**", expanded=True):
|
466 |
+
st.markdown(
|
467 |
+
"""
|
468 |
+
1.Click on "**Reset Password**" to generate a new default password for the selected user.
|
469 |
+
2. Share the pass key with the corresponding user
|
470 |
+
""",
|
471 |
+
unsafe_allow_html=True,
|
472 |
+
)
|
473 |
+
|
474 |
+
user_id = st.selectbox(
|
475 |
+
"Select A User ID",
|
476 |
+
st.session_state["df_users"]["User ID"],
|
477 |
+
on_change=reset_pass_key,
|
478 |
+
)
|
479 |
+
|
480 |
+
if "pass_key" not in st.session_state:
|
481 |
+
st.session_state["pass_key"] = ""
|
482 |
+
|
483 |
+
default_passkey = st.code(
|
484 |
+
f"Default Password: {st.session_state['pass_key']}", language="text"
|
485 |
+
)
|
486 |
+
|
487 |
+
if st.button("Reset Password", use_container_width=True, key="reset_pass_button_key"):
|
488 |
+
with st.spinner("Reseting Password"):
|
489 |
+
characters = (
|
490 |
+
string.ascii_letters + string.digits
|
491 |
+
) # Letters and digits
|
492 |
+
plain_text_password = "".join(
|
493 |
+
random.choice(characters) for _ in range(10)
|
494 |
+
)
|
495 |
+
st.session_state["pass_key"] = plain_text_password
|
496 |
+
|
497 |
+
update_password_and_flag(user_id, plain_text_password, flag=True)
|
498 |
+
time.sleep(1)
|
499 |
+
del st.session_state["reset_pass_button_key"]
|
500 |
+
st.rerun()
|
501 |
+
|
502 |
+
elif len(password):
|
503 |
+
st.warning("Wrong user name or password!")
|
pages/1_Data_Import.py
ADDED
@@ -0,0 +1,1213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Importing necessary libraries
|
2 |
+
import streamlit as st
|
3 |
+
|
4 |
+
st.set_page_config(
|
5 |
+
page_title="Data Import",
|
6 |
+
page_icon="⚖️",
|
7 |
+
layout="wide",
|
8 |
+
initial_sidebar_state="collapsed",
|
9 |
+
)
|
10 |
+
|
11 |
+
|
12 |
+
import re
|
13 |
+
import sys
|
14 |
+
import pickle
|
15 |
+
import numbers
|
16 |
+
import traceback
|
17 |
+
import pandas as pd
|
18 |
+
from scenario import numerize
|
19 |
+
from post_gres_cred import db_cred
|
20 |
+
from collections import OrderedDict
|
21 |
+
from log_application import log_message
|
22 |
+
from utilities import set_header, load_local_css, update_db, project_selection
|
23 |
+
from constants import (
|
24 |
+
upload_rows_limit,
|
25 |
+
upload_column_limit,
|
26 |
+
word_length_limit_lower,
|
27 |
+
word_length_limit_upper,
|
28 |
+
minimum_percent_overlap,
|
29 |
+
minimum_row_req,
|
30 |
+
percent_drop_col_threshold,
|
31 |
+
)
|
32 |
+
|
33 |
+
schema = db_cred["schema"]
|
34 |
+
load_local_css("styles.css")
|
35 |
+
set_header()
|
36 |
+
|
37 |
+
|
38 |
+
# Initialize project name session state
|
39 |
+
if "project_name" not in st.session_state:
|
40 |
+
st.session_state["project_name"] = None
|
41 |
+
|
42 |
+
# Fetch project dictionary
|
43 |
+
if "project_dct" not in st.session_state:
|
44 |
+
project_selection()
|
45 |
+
st.stop()
|
46 |
+
|
47 |
+
# Display Username and Project Name
|
48 |
+
if "username" in st.session_state and st.session_state["username"] is not None:
|
49 |
+
|
50 |
+
cols1 = st.columns([2, 1])
|
51 |
+
|
52 |
+
with cols1[0]:
|
53 |
+
st.markdown(f"**Welcome {st.session_state['username']}**")
|
54 |
+
with cols1[1]:
|
55 |
+
st.markdown(f"**Current Project: {st.session_state['project_name']}**")
|
56 |
+
|
57 |
+
|
58 |
+
# Initialize session state keys
|
59 |
+
if "granularity_selection_key" not in st.session_state:
|
60 |
+
st.session_state["granularity_selection_key"] = st.session_state["project_dct"][
|
61 |
+
"data_import"
|
62 |
+
]["granularity_selection"]
|
63 |
+
|
64 |
+
|
65 |
+
# Function to format name
|
66 |
+
def name_format_func(name):
|
67 |
+
return str(name).strip().title()
|
68 |
+
|
69 |
+
|
70 |
+
# Function to get columns with specified prefix and remove prefix
|
71 |
+
def get_columns_with_prefix(df, prefix):
|
72 |
+
return [
|
73 |
+
col.replace(prefix, "")
|
74 |
+
for col in df.columns
|
75 |
+
if col.startswith(prefix) and str(col) != str(prefix)
|
76 |
+
]
|
77 |
+
|
78 |
+
|
79 |
+
# Function to fetch columns info
|
80 |
+
@st.cache_data(show_spinner=False)
|
81 |
+
def fetch_columns(gold_layer_df, data_upload_df):
|
82 |
+
# Get lists of columns starting with 'spends_' and 'response_metric_' from gold_layer_df
|
83 |
+
spends_columns_gold_layer = get_columns_with_prefix(gold_layer_df, "spends_")
|
84 |
+
response_metric_columns_gold_layer = get_columns_with_prefix(
|
85 |
+
gold_layer_df, "response_metric_"
|
86 |
+
)
|
87 |
+
|
88 |
+
# Get lists of columns starting with 'spends_' and 'response_metric_' from data_upload_df
|
89 |
+
spends_columns_upload = get_columns_with_prefix(data_upload_df, "spends_")
|
90 |
+
response_metric_columns_upload = get_columns_with_prefix(
|
91 |
+
data_upload_df, "response_metric_"
|
92 |
+
)
|
93 |
+
|
94 |
+
# Combine lists from both DataFrames
|
95 |
+
spends_columns = spends_columns_gold_layer + spends_columns_upload
|
96 |
+
# Remove 'total' from the spends_columns list if it exists
|
97 |
+
spends_columns = list(
|
98 |
+
set([col for col in spends_columns if not col.endswith("_total")])
|
99 |
+
)
|
100 |
+
|
101 |
+
response_metric_columns = (
|
102 |
+
response_metric_columns_gold_layer + response_metric_columns_upload
|
103 |
+
)
|
104 |
+
# Filter columns ending with '_total' and remove the '_total' suffix
|
105 |
+
response_metric_columns = list(
|
106 |
+
set(
|
107 |
+
[
|
108 |
+
col[:-6]
|
109 |
+
for col in response_metric_columns
|
110 |
+
if col.endswith("_total") and len(col[:-6]) != 0
|
111 |
+
]
|
112 |
+
)
|
113 |
+
)
|
114 |
+
|
115 |
+
# Get list of all columns from both DataFrames
|
116 |
+
gold_layer_columns = list(gold_layer_df.columns)
|
117 |
+
data_upload_columns = list(data_upload_df.columns)
|
118 |
+
|
119 |
+
# Combine all columns and get unique columns
|
120 |
+
all_columns = list(set(gold_layer_columns + data_upload_columns))
|
121 |
+
|
122 |
+
return (
|
123 |
+
spends_columns,
|
124 |
+
response_metric_columns,
|
125 |
+
all_columns,
|
126 |
+
gold_layer_columns,
|
127 |
+
data_upload_columns,
|
128 |
+
)
|
129 |
+
|
130 |
+
|
131 |
+
# Function to format values for display
|
132 |
+
@st.cache_data(show_spinner=False)
|
133 |
+
def format_values_for_display(values_list):
|
134 |
+
# Format value
|
135 |
+
formatted_list = [value.lower().strip() for value in values_list]
|
136 |
+
# Join values with commas and 'and' before the last value
|
137 |
+
if len(formatted_list) > 1:
|
138 |
+
return ", ".join(formatted_list[:-1]) + ", and " + formatted_list[-1]
|
139 |
+
elif formatted_list:
|
140 |
+
return formatted_list[0]
|
141 |
+
return "No values available"
|
142 |
+
|
143 |
+
|
144 |
+
# Function to validate input DataFrame
|
145 |
+
@st.cache_data(show_spinner=False)
|
146 |
+
def valid_input_df(
|
147 |
+
df,
|
148 |
+
spends_columns,
|
149 |
+
response_metric_columns,
|
150 |
+
total_columns,
|
151 |
+
gold_layer_columns,
|
152 |
+
data_upload_columns,
|
153 |
+
):
|
154 |
+
# Check if DataFrame is empty
|
155 |
+
if df.empty or len(df) < 1:
|
156 |
+
return (True, None)
|
157 |
+
|
158 |
+
# Check for invalid column names
|
159 |
+
invalid_columns = [
|
160 |
+
col
|
161 |
+
for col in df.columns
|
162 |
+
if not re.match(r"^[A-Za-z0-9_]+$", col)
|
163 |
+
or not (word_length_limit_lower <= len(col) <= word_length_limit_upper)
|
164 |
+
]
|
165 |
+
if invalid_columns:
|
166 |
+
return (
|
167 |
+
False,
|
168 |
+
f"Invalid column names: {format_values_for_display(invalid_columns)}. Use only letters, numbers, and underscores. Column name length should be {word_length_limit_lower} to {word_length_limit_upper} characters long.",
|
169 |
+
)
|
170 |
+
|
171 |
+
# Ensure 'panel' column values are strings and conform to specified pattern and length
|
172 |
+
if "panel" in df.columns:
|
173 |
+
df["panel"] = df["panel"].astype(str).str.strip()
|
174 |
+
invalid_panel_values = [
|
175 |
+
val
|
176 |
+
for val in df["panel"].unique()
|
177 |
+
if not re.match(r"^[A-Za-z0-9_]+$", val)
|
178 |
+
or not (word_length_limit_lower <= len(val) <= word_length_limit_upper)
|
179 |
+
]
|
180 |
+
if invalid_panel_values:
|
181 |
+
return (
|
182 |
+
False,
|
183 |
+
f"Invalid panel values: {format_values_for_display(invalid_panel_values)}. Use only letters, numbers, and underscores. Panel name length should be {word_length_limit_lower} to {word_length_limit_upper} characters long.",
|
184 |
+
)
|
185 |
+
|
186 |
+
# Check for missing required columns
|
187 |
+
required_columns = ["date", "panel"]
|
188 |
+
missing_columns = [col for col in required_columns if col not in df.columns]
|
189 |
+
if missing_columns:
|
190 |
+
return (
|
191 |
+
False,
|
192 |
+
f"Missing compulsory columns: {format_values_for_display(missing_columns)}.",
|
193 |
+
)
|
194 |
+
|
195 |
+
# Check if all other columns are numeric
|
196 |
+
non_numeric_columns = [
|
197 |
+
col
|
198 |
+
for col in df.columns
|
199 |
+
if col not in required_columns and not pd.api.types.is_numeric_dtype(df[col])
|
200 |
+
]
|
201 |
+
if non_numeric_columns:
|
202 |
+
return (
|
203 |
+
False,
|
204 |
+
f"Non-numeric columns: {format_values_for_display(non_numeric_columns)}. All columns except {format_values_for_display(required_columns)} should be numeric.",
|
205 |
+
)
|
206 |
+
|
207 |
+
# Ensure all columns in data_upload_columns are unique
|
208 |
+
duplicate_columns_in_upload = [
|
209 |
+
col for col in data_upload_columns if data_upload_columns.count(col) > 1
|
210 |
+
]
|
211 |
+
if duplicate_columns_in_upload:
|
212 |
+
return (
|
213 |
+
False,
|
214 |
+
f"Duplicate columns found in the uploaded data: {format_values_for_display(set(duplicate_columns_in_upload))}.",
|
215 |
+
)
|
216 |
+
|
217 |
+
# Convert 'date' column to datetime format
|
218 |
+
try:
|
219 |
+
df["date"] = pd.to_datetime(df["date"], format="%Y-%m-%d")
|
220 |
+
except:
|
221 |
+
return False, "The 'date' column is not in the correct format 'YYYY-MM-DD'."
|
222 |
+
|
223 |
+
# Check date frequency
|
224 |
+
unique_panels = df["panel"].unique()
|
225 |
+
for panel in unique_panels:
|
226 |
+
date_diff = df[df["panel"] == panel]["date"].diff().dropna()
|
227 |
+
if not (
|
228 |
+
(date_diff == pd.Timedelta(days=1)).all()
|
229 |
+
or (date_diff == pd.Timedelta(weeks=1)).all()
|
230 |
+
):
|
231 |
+
return False, "The 'date' column does not have a daily or weekly frequency."
|
232 |
+
|
233 |
+
# Check for null values in 'date' or 'panel' columns
|
234 |
+
if df[required_columns].isnull().any().any():
|
235 |
+
return (
|
236 |
+
False,
|
237 |
+
f"The {format_values_for_display(required_columns)} should not contain null values.",
|
238 |
+
)
|
239 |
+
|
240 |
+
# Check for panels with less than 1% date overlap
|
241 |
+
if not gold_layer_df.empty:
|
242 |
+
panels_with_low_overlap = []
|
243 |
+
unique_panels = list(
|
244 |
+
set(df["panel"].unique()).union(set(gold_layer_df["panel"].unique()))
|
245 |
+
)
|
246 |
+
for panel in unique_panels:
|
247 |
+
gold_layer_dates = set(
|
248 |
+
gold_layer_df[gold_layer_df["panel"] == panel]["date"]
|
249 |
+
)
|
250 |
+
data_upload_dates = set(df[df["panel"] == panel]["date"])
|
251 |
+
if gold_layer_dates and data_upload_dates:
|
252 |
+
overlap = len(gold_layer_dates & data_upload_dates) / len(
|
253 |
+
gold_layer_dates | data_upload_dates
|
254 |
+
)
|
255 |
+
else:
|
256 |
+
overlap = 0
|
257 |
+
if overlap < (minimum_percent_overlap / 100):
|
258 |
+
panels_with_low_overlap.append(panel)
|
259 |
+
|
260 |
+
if panels_with_low_overlap:
|
261 |
+
return (
|
262 |
+
False,
|
263 |
+
f"Date columns in the gold layer and uploaded data do not have at least {minimum_percent_overlap}% overlap for panels: {format_values_for_display(panels_with_low_overlap)}.",
|
264 |
+
)
|
265 |
+
|
266 |
+
# Check if spends_columns is less than two
|
267 |
+
if len(spends_columns) < 2:
|
268 |
+
return False, "Please add at least two spends columns."
|
269 |
+
|
270 |
+
# Check if response_metric_columns is empty
|
271 |
+
if len(response_metric_columns) < 1:
|
272 |
+
return False, "Please add response metric columns."
|
273 |
+
|
274 |
+
# Check if all numeric columns are positive except those starting with 'exogenous_' or 'internal_'
|
275 |
+
valid_prefixes = ["exogenous_", "internal_"]
|
276 |
+
negative_values_columns = [
|
277 |
+
col
|
278 |
+
for col in df.select_dtypes(include=[float, int]).columns
|
279 |
+
if not any(col.startswith(prefix) for prefix in valid_prefixes)
|
280 |
+
and (df[col] < 0).any()
|
281 |
+
]
|
282 |
+
if negative_values_columns:
|
283 |
+
return (
|
284 |
+
False,
|
285 |
+
f"Negative values detected in columns: {format_values_for_display(negative_values_columns)}. Ensure all media and response metric columns are positive.",
|
286 |
+
)
|
287 |
+
|
288 |
+
# Check for unassociated columns
|
289 |
+
detected_channels = spends_columns + ["total"]
|
290 |
+
unassociated_columns = []
|
291 |
+
for col in df.columns:
|
292 |
+
if (col.startswith("_") or col.endswith("_")) or not (
|
293 |
+
col.startswith("exogenous_") # Column starts with "exogenous_"
|
294 |
+
or col.startswith("internal_") # Column starts with "internal_"
|
295 |
+
or any(
|
296 |
+
col == f"spends_{channel}" for channel in spends_columns
|
297 |
+
) # Column is not in the format "spends_<channel>"
|
298 |
+
or any(
|
299 |
+
col == f"response_metric_{metric}_{channel}"
|
300 |
+
for metric in response_metric_columns
|
301 |
+
for channel in detected_channels
|
302 |
+
) # Column is not in the format "response_metric_<metric>_<channel>"
|
303 |
+
or any(
|
304 |
+
col.startswith("media_")
|
305 |
+
and col.endswith(f"_{channel}")
|
306 |
+
and len(col) > len(f"media__{channel}")
|
307 |
+
for channel in spends_columns
|
308 |
+
) # Column is not in the format "media_<media_variable_name>_<channel>"
|
309 |
+
or col in ["date", "panel"]
|
310 |
+
):
|
311 |
+
unassociated_columns.append(col)
|
312 |
+
|
313 |
+
if unassociated_columns:
|
314 |
+
return (
|
315 |
+
False,
|
316 |
+
f"Columns with incorrect format detected: {format_values_for_display(unassociated_columns)}.",
|
317 |
+
)
|
318 |
+
|
319 |
+
return True, "The data is valid and meets all requirements."
|
320 |
+
|
321 |
+
|
322 |
+
# Function to load the uploaded Excel file into a DataFrame
|
323 |
+
@st.cache_data(show_spinner=False)
|
324 |
+
def load_and_transform_data(uploaded_file):
|
325 |
+
# Load the uploaded file into a DataFrame if a file is uploaded
|
326 |
+
if uploaded_file is not None:
|
327 |
+
df = pd.read_excel(uploaded_file)
|
328 |
+
else:
|
329 |
+
df = pd.DataFrame()
|
330 |
+
return df
|
331 |
+
|
332 |
+
# Check if DataFrame exceeds row and column limits
|
333 |
+
if len(df) > upload_rows_limit or len(df.columns) > upload_column_limit:
|
334 |
+
st.warning(
|
335 |
+
f"Data exceeds the row limit of {numerize(upload_rows_limit)} or column limit of {numerize(upload_column_limit)}. Please upload a smaller file.",
|
336 |
+
icon="⚠️",
|
337 |
+
)
|
338 |
+
|
339 |
+
# Log message
|
340 |
+
log_message(
|
341 |
+
"warning",
|
342 |
+
f"Data exceeds the row limit of {numerize(upload_rows_limit)} or column limit of {numerize(upload_column_limit)}. Please upload a smaller file.",
|
343 |
+
"Data Import",
|
344 |
+
)
|
345 |
+
|
346 |
+
return pd.DataFrame()
|
347 |
+
|
348 |
+
# If the DataFrame contains only 'panel' and 'date' columns, return empty DataFrame
|
349 |
+
if set(df.columns) == {"date", "panel"}:
|
350 |
+
return pd.DataFrame()
|
351 |
+
|
352 |
+
# Transform column names: lower, strip start and end, replace spaces with _
|
353 |
+
df.columns = [str(col).strip().lower().replace(" ", "_") for col in df.columns]
|
354 |
+
|
355 |
+
# If 'panel' column exists, clean its values
|
356 |
+
try:
|
357 |
+
if "panel" in df.columns:
|
358 |
+
df["panel"] = (
|
359 |
+
df["panel"].astype(str).str.lower().str.strip().str.replace(" ", "_")
|
360 |
+
)
|
361 |
+
except:
|
362 |
+
return df
|
363 |
+
|
364 |
+
try:
|
365 |
+
df["date"] = pd.to_datetime(df["date"], format="%Y-%m-%d")
|
366 |
+
except:
|
367 |
+
# The 'date' column is not in the correct format 'YYYY-MM-DD'
|
368 |
+
return df
|
369 |
+
|
370 |
+
# Check date frequency and convert to daily if needed
|
371 |
+
date_diff = df["date"].diff().dropna()
|
372 |
+
if (date_diff == pd.Timedelta(days=1)).all():
|
373 |
+
# Data is already at daily level
|
374 |
+
return df
|
375 |
+
elif (date_diff == pd.Timedelta(weeks=1)).all():
|
376 |
+
# Data is at weekly level, convert to daily
|
377 |
+
weekly_data = df.copy()
|
378 |
+
daily_data = []
|
379 |
+
|
380 |
+
for index, row in weekly_data.iterrows():
|
381 |
+
week_start = row["date"] - pd.to_timedelta(row["date"].weekday(), unit="D")
|
382 |
+
for i in range(7):
|
383 |
+
daily_date = week_start + pd.DateOffset(days=i)
|
384 |
+
new_row = row.copy()
|
385 |
+
new_row["date"] = daily_date
|
386 |
+
for col in df.columns:
|
387 |
+
if isinstance(new_row[col], numbers.Number):
|
388 |
+
new_row[col] = new_row[col] / 7
|
389 |
+
daily_data.append(new_row)
|
390 |
+
|
391 |
+
daily_data_df = pd.DataFrame(daily_data)
|
392 |
+
return daily_data_df
|
393 |
+
else:
|
394 |
+
# The 'date' column does not have a daily or weekly frequency
|
395 |
+
return df
|
396 |
+
|
397 |
+
|
398 |
+
# Function to merge DataFrames if present
|
399 |
+
@st.cache_data(show_spinner=False)
|
400 |
+
def merge_dataframes(gold_layer_df, data_upload_df):
|
401 |
+
if gold_layer_df.empty and data_upload_df.empty:
|
402 |
+
return pd.DataFrame()
|
403 |
+
|
404 |
+
if not gold_layer_df.empty and not data_upload_df.empty:
|
405 |
+
# Merge gold_layer_df and data_upload_df on 'panel', and 'date'
|
406 |
+
merged_df = pd.merge(
|
407 |
+
gold_layer_df,
|
408 |
+
data_upload_df,
|
409 |
+
on=["panel", "date"],
|
410 |
+
how="outer",
|
411 |
+
suffixes=("_gold", "_upload"),
|
412 |
+
)
|
413 |
+
|
414 |
+
# Handle duplicate columns
|
415 |
+
for col in merged_df.columns:
|
416 |
+
if col.endswith("_gold"):
|
417 |
+
base_col = col[:-5] # Remove '_gold' suffix
|
418 |
+
upload_col = base_col + "_upload" # Column name in data_upload_df
|
419 |
+
if upload_col in merged_df.columns:
|
420 |
+
# Prefer values from data_upload_df
|
421 |
+
merged_df[base_col] = merged_df[upload_col].combine_first(
|
422 |
+
merged_df[col]
|
423 |
+
)
|
424 |
+
merged_df.drop(columns=[col, upload_col], inplace=True)
|
425 |
+
else:
|
426 |
+
# Rename column to remove the suffix
|
427 |
+
merged_df.rename(columns={col: base_col}, inplace=True)
|
428 |
+
|
429 |
+
elif data_upload_df.empty:
|
430 |
+
merged_df = gold_layer_df.copy()
|
431 |
+
|
432 |
+
elif gold_layer_df.empty:
|
433 |
+
merged_df = data_upload_df.copy()
|
434 |
+
|
435 |
+
return merged_df
|
436 |
+
|
437 |
+
|
438 |
+
# Function to check if all required columns are present in the Uploaded DataFrame
|
439 |
+
@st.cache_data(show_spinner=False)
|
440 |
+
def check_required_columns(df, detected_channels, detected_response_metric):
|
441 |
+
required_columns = []
|
442 |
+
|
443 |
+
# Add all channels with 'spends_' + detected channel name
|
444 |
+
for channel in detected_channels:
|
445 |
+
required_columns.append(f"spends_{channel}")
|
446 |
+
|
447 |
+
# Add all channels with 'response_metric_' + detected channel name
|
448 |
+
for response_metric in detected_response_metric:
|
449 |
+
for channel in detected_channels + ["total"]:
|
450 |
+
required_columns.append(f"response_metric_{response_metric}_{channel}")
|
451 |
+
|
452 |
+
# Check for missing columns
|
453 |
+
missing_columns = [col for col in required_columns if col not in df.columns]
|
454 |
+
|
455 |
+
# Channel groupings
|
456 |
+
no_media_data = []
|
457 |
+
channel_columns_dict = {}
|
458 |
+
for channel in detected_channels:
|
459 |
+
channel_columns = [
|
460 |
+
col
|
461 |
+
for col in merged_df.columns
|
462 |
+
if channel in col
|
463 |
+
and not (
|
464 |
+
col.startswith("response_metric_")
|
465 |
+
or col.startswith("exogenous_")
|
466 |
+
or col.startswith("internal_")
|
467 |
+
)
|
468 |
+
and col.endswith(channel)
|
469 |
+
]
|
470 |
+
channel_columns_dict[channel] = channel_columns
|
471 |
+
|
472 |
+
if len(channel_columns) <= 1:
|
473 |
+
no_media_data.append(channel)
|
474 |
+
|
475 |
+
return missing_columns, no_media_data, channel_columns_dict
|
476 |
+
|
477 |
+
|
478 |
+
# Function to prepare tool DataFrame
|
479 |
+
def prepare_tool_df(merged_df, granularity_selection):
|
480 |
+
# Drop all response metric columns that do not end with '_total'
|
481 |
+
cols_to_drop = [
|
482 |
+
col
|
483 |
+
for col in merged_df.columns
|
484 |
+
if col.startswith("response_metric_") and not col.endswith("_total")
|
485 |
+
]
|
486 |
+
|
487 |
+
# Create a DataFrame to be used for the tool
|
488 |
+
tool_df = merged_df.drop(columns=cols_to_drop)
|
489 |
+
|
490 |
+
# Convert to weekly granularity by aggregating all data for given panel and week
|
491 |
+
if granularity_selection.lower() == "weekly":
|
492 |
+
tool_df.set_index("date", inplace=True)
|
493 |
+
tool_df = (
|
494 |
+
tool_df.groupby(
|
495 |
+
[pd.Grouper(freq="W-MON", closed="left", label="left"), "panel"]
|
496 |
+
)
|
497 |
+
.sum()
|
498 |
+
.reset_index()
|
499 |
+
)
|
500 |
+
|
501 |
+
return tool_df
|
502 |
+
|
503 |
+
|
504 |
+
# Function to generate imputation DataFrame
|
505 |
+
def generate_imputation_df(tool_df):
|
506 |
+
# Initialize lists to store the column details
|
507 |
+
column_names = []
|
508 |
+
categories = []
|
509 |
+
missing_values_info = []
|
510 |
+
zero_values_info = []
|
511 |
+
imputation_methods = []
|
512 |
+
|
513 |
+
# Define the function to calculate the percentage of missing values
|
514 |
+
def calculate_missing_percentage(series):
|
515 |
+
return series.isnull().sum(), (series.isnull().mean() * 100)
|
516 |
+
|
517 |
+
# Define the function to calculate the percentage of zero values
|
518 |
+
def calculate_zero_percentage(series):
|
519 |
+
return (series == 0).sum(), ((series == 0).mean() * 100)
|
520 |
+
|
521 |
+
# Iterate over each column to categorize and calculate missing and zero values
|
522 |
+
for col in tool_df.columns:
|
523 |
+
# Determine category based on column name prefix
|
524 |
+
if col == "date" or col == "panel":
|
525 |
+
continue
|
526 |
+
elif col.startswith("response_metric_"):
|
527 |
+
categories.append("Response Metrics")
|
528 |
+
elif col.startswith("spends_"):
|
529 |
+
categories.append("Spends")
|
530 |
+
elif col.startswith("exogenous_"):
|
531 |
+
categories.append("Exogenous")
|
532 |
+
elif col.startswith("internal_"):
|
533 |
+
categories.append("Internal")
|
534 |
+
else:
|
535 |
+
categories.append("Media")
|
536 |
+
|
537 |
+
# Calculate missing values and percentage
|
538 |
+
missing_count, missing_percentage = calculate_missing_percentage(tool_df[col])
|
539 |
+
missing_values_info.append(f"{missing_count} ({missing_percentage:.1f}%)")
|
540 |
+
|
541 |
+
# Calculate zero values and percentage
|
542 |
+
zero_count, zero_percentage = calculate_zero_percentage(tool_df[col])
|
543 |
+
zero_values_info.append(f"{zero_count} ({zero_percentage:.1f}%)")
|
544 |
+
|
545 |
+
# Determine default imputation method based on conditions
|
546 |
+
if col.startswith("spends_"):
|
547 |
+
imputation_methods.append("Fill with 0")
|
548 |
+
elif col.startswith("response_metric_"):
|
549 |
+
imputation_methods.append("Fill with Mean")
|
550 |
+
elif zero_percentage + missing_percentage > percent_drop_col_threshold:
|
551 |
+
imputation_methods.append("Drop Column")
|
552 |
+
else:
|
553 |
+
imputation_methods.append("Fill with Mean")
|
554 |
+
|
555 |
+
column_names.append(col)
|
556 |
+
|
557 |
+
# Create the DataFrame
|
558 |
+
imputation_df = pd.DataFrame(
|
559 |
+
{
|
560 |
+
"Column Name": column_names,
|
561 |
+
"Category": categories,
|
562 |
+
"Missing Values": missing_values_info,
|
563 |
+
"Zero Values": zero_values_info,
|
564 |
+
"Imputation Method": imputation_methods,
|
565 |
+
}
|
566 |
+
)
|
567 |
+
|
568 |
+
# Define the category order for sorting
|
569 |
+
category_order = {
|
570 |
+
"Response Metrics": 1,
|
571 |
+
"Spends": 2,
|
572 |
+
"Media": 3,
|
573 |
+
"Exogenous": 4,
|
574 |
+
"Internal": 5,
|
575 |
+
}
|
576 |
+
|
577 |
+
# Add a temporary column for sorting based on the category order
|
578 |
+
imputation_df["Category Order"] = imputation_df["Category"].map(category_order)
|
579 |
+
|
580 |
+
# Sort the DataFrame based on the category order and then drop the temporary column
|
581 |
+
imputation_df = imputation_df.sort_values(
|
582 |
+
by=["Category Order", "Column Name"]
|
583 |
+
).drop(columns=["Category Order"])
|
584 |
+
|
585 |
+
return imputation_df
|
586 |
+
|
587 |
+
|
588 |
+
# Function to perform imputation as per user requests
|
589 |
+
def perform_imputation(imputation_df, tool_df):
|
590 |
+
# Detect channels associated with spends
|
591 |
+
detected_channels = [
|
592 |
+
col.replace("spends_", "")
|
593 |
+
for col in tool_df.columns
|
594 |
+
if col.startswith("spends_")
|
595 |
+
]
|
596 |
+
|
597 |
+
# Create a dictionary with keys as channels and values as associated columns
|
598 |
+
group_dict = {
|
599 |
+
channel: [
|
600 |
+
col
|
601 |
+
for col in tool_df.columns
|
602 |
+
if channel in col
|
603 |
+
and not (
|
604 |
+
col.startswith("response_metric_")
|
605 |
+
or col.startswith("exogenous_")
|
606 |
+
or col.startswith("internal_")
|
607 |
+
)
|
608 |
+
]
|
609 |
+
for channel in detected_channels
|
610 |
+
}
|
611 |
+
|
612 |
+
# Create a reverse dictionary with keys as columns and values as channels
|
613 |
+
column_to_channel_dict = {
|
614 |
+
col: channel for channel, cols in group_dict.items() for col in cols
|
615 |
+
}
|
616 |
+
|
617 |
+
# Perform imputation
|
618 |
+
already_dropped = []
|
619 |
+
for index, row in imputation_df.iterrows():
|
620 |
+
col_name = row["Column Name"]
|
621 |
+
impute_method = row["Imputation Method"]
|
622 |
+
|
623 |
+
# Skip already dropped columns
|
624 |
+
if col_name in already_dropped:
|
625 |
+
continue
|
626 |
+
|
627 |
+
# Skip imputation if dropping response metric column and add warning
|
628 |
+
if impute_method == "Drop Column" and col_name.startswith("response_metric_"):
|
629 |
+
return None, {}, f"Cannot drop response metric column: {col_name}"
|
630 |
+
|
631 |
+
# Drop column if requested
|
632 |
+
if impute_method == "Drop Column":
|
633 |
+
# If spends column is dropped, drop all related columns
|
634 |
+
if col_name.startswith("spends_"):
|
635 |
+
tool_df.drop(
|
636 |
+
columns=group_dict[col_name.replace("spends_", "")],
|
637 |
+
inplace=True,
|
638 |
+
)
|
639 |
+
already_dropped += group_dict[col_name.replace("spends_", "")]
|
640 |
+
del group_dict[col_name.replace("spends_", "")]
|
641 |
+
else:
|
642 |
+
tool_df.drop(columns=[col_name], inplace=True)
|
643 |
+
if not (
|
644 |
+
col_name.startswith("exogenous_")
|
645 |
+
or col_name.startswith("internal_")
|
646 |
+
):
|
647 |
+
group_name = column_to_channel_dict[col_name]
|
648 |
+
group_dict[group_name].remove(col_name)
|
649 |
+
|
650 |
+
# Check for channels with one or fewer associated columns and add warning if needed
|
651 |
+
if len(group_dict[group_name]) <= 1:
|
652 |
+
return (
|
653 |
+
None,
|
654 |
+
{},
|
655 |
+
f"No media variable associated with category {col_name.replace('spends_', '')}.",
|
656 |
+
)
|
657 |
+
continue
|
658 |
+
|
659 |
+
# Check for each panel
|
660 |
+
for panel in tool_df["panel"].unique():
|
661 |
+
panel_df = tool_df[tool_df["panel"] == panel]
|
662 |
+
|
663 |
+
# Check if the column is entirely null or empty for the current panel
|
664 |
+
if panel_df[col_name].isnull().all():
|
665 |
+
if impute_method in ["Fill with Mean", "Fill with Median"]:
|
666 |
+
return (
|
667 |
+
None,
|
668 |
+
{},
|
669 |
+
f"Cannot impute for empty column(s) with mean or median. Select 'Fill with 0'. Details: Panel: {panel}, Column: {col_name}",
|
670 |
+
)
|
671 |
+
|
672 |
+
# Fill missing values as requested
|
673 |
+
if impute_method == "Fill with Mean":
|
674 |
+
tool_df[col_name] = tool_df.groupby("panel")[col_name].transform(
|
675 |
+
lambda x: x.fillna(x.mean())
|
676 |
+
)
|
677 |
+
elif impute_method == "Fill with Median":
|
678 |
+
tool_df[col_name] = tool_df.groupby("panel")[col_name].transform(
|
679 |
+
lambda x: x.fillna(x.median())
|
680 |
+
)
|
681 |
+
elif impute_method == "Fill with 0":
|
682 |
+
tool_df[col_name].fillna(0, inplace=True)
|
683 |
+
|
684 |
+
# Check if final DataFrame has at least one response metric and two spends categories
|
685 |
+
response_metrics = [
|
686 |
+
col for col in tool_df.columns if col.startswith("response_metric_")
|
687 |
+
]
|
688 |
+
spends_categories = [col for col in tool_df.columns if col.startswith("spends_")]
|
689 |
+
|
690 |
+
if len(response_metrics) < 1:
|
691 |
+
return (None, {}, "The final DataFrame must have at least one response metric.")
|
692 |
+
if len(spends_categories) < 2:
|
693 |
+
return (
|
694 |
+
None,
|
695 |
+
{},
|
696 |
+
"The final DataFrame must have at least two spends categories.",
|
697 |
+
)
|
698 |
+
|
699 |
+
return tool_df, group_dict, "Imputed Successfully!"
|
700 |
+
|
701 |
+
|
702 |
+
# Function to display groups with custom styling
|
703 |
+
def display_groups(input_dict):
|
704 |
+
# Define custom CSS for pastel light blue rounded rectangle
|
705 |
+
custom_css = """
|
706 |
+
<style>
|
707 |
+
.group-box {
|
708 |
+
background-color: #ffdaab;
|
709 |
+
border-radius: 10px;
|
710 |
+
padding: 10px;
|
711 |
+
margin: 5px 0;
|
712 |
+
}
|
713 |
+
</style>
|
714 |
+
"""
|
715 |
+
st.markdown(custom_css, unsafe_allow_html=True)
|
716 |
+
|
717 |
+
for group_name, values in input_dict.items():
|
718 |
+
group_html = f"<div class='group-box'><strong>{group_name}:</strong> {format_values_for_display(values)}</div>"
|
719 |
+
st.markdown(group_html, unsafe_allow_html=True)
|
720 |
+
|
721 |
+
|
722 |
+
# Function to categorize columns and create an ordered dictionary
|
723 |
+
def create_ordered_category_dict(df):
|
724 |
+
category_dict = {
|
725 |
+
"Response Metrics": [],
|
726 |
+
"Spends": [],
|
727 |
+
"Media": [],
|
728 |
+
"Exogenous": [],
|
729 |
+
"Internal": [],
|
730 |
+
}
|
731 |
+
|
732 |
+
# Define the category order for sorting
|
733 |
+
category_order = {
|
734 |
+
"Response Metrics": 1,
|
735 |
+
"Spends": 2,
|
736 |
+
"Media": 3,
|
737 |
+
"Exogenous": 4,
|
738 |
+
"Internal": 5,
|
739 |
+
}
|
740 |
+
|
741 |
+
for column in df.columns:
|
742 |
+
if column == "date" or column == "panel":
|
743 |
+
continue # Skip 'date' and 'panel' columns
|
744 |
+
|
745 |
+
if column.startswith("response_metric_"):
|
746 |
+
category_dict["Response Metrics"].append(column)
|
747 |
+
elif column.startswith("spends_"):
|
748 |
+
category_dict["Spends"].append(column)
|
749 |
+
elif column.startswith("exogenous_"):
|
750 |
+
category_dict["Exogenous"].append(column)
|
751 |
+
elif column.startswith("internal_"):
|
752 |
+
category_dict["Internal"].append(column)
|
753 |
+
else:
|
754 |
+
category_dict["Media"].append(column)
|
755 |
+
|
756 |
+
# Sort the dictionary based on the defined category order
|
757 |
+
sorted_category_dict = OrderedDict(
|
758 |
+
sorted(category_dict.items(), key=lambda item: category_order[item[0]])
|
759 |
+
)
|
760 |
+
|
761 |
+
return sorted_category_dict
|
762 |
+
|
763 |
+
|
764 |
+
try:
|
765 |
+
# Page Title
|
766 |
+
st.title("Data Import")
|
767 |
+
|
768 |
+
# Create file uploader
|
769 |
+
uploaded_file = st.file_uploader(
|
770 |
+
"Upload Data", type=["xlsx"], accept_multiple_files=False
|
771 |
+
)
|
772 |
+
|
773 |
+
# Expander with markdown for upload rules
|
774 |
+
with st.expander("Upload Rules and Guidelines"):
|
775 |
+
st.markdown(
|
776 |
+
"""
|
777 |
+
### Upload Guidelines
|
778 |
+
|
779 |
+
Please ensure your data adheres to the following rules:
|
780 |
+
|
781 |
+
1. **File Format**:
|
782 |
+
- Upload all data in a single Excel file.
|
783 |
+
|
784 |
+
2. **Compulsory Columns**:
|
785 |
+
- **Date**: Must be in the format `YYYY-MM-DD` only.
|
786 |
+
- **Panel**: If no panel data exists, use `aggregated` as a single panel.
|
787 |
+
|
788 |
+
3. **Column Naming Conventions**:
|
789 |
+
- All columns should start with the associated category prefix.
|
790 |
+
|
791 |
+
**Examples**:
|
792 |
+
|
793 |
+
- **Response Metric Column**:
|
794 |
+
- Format: `response_metric_<response_metric_name>_<channel_name>`
|
795 |
+
- Example: `response_metric_revenue_facebook`
|
796 |
+
|
797 |
+
- **Total Response Metric**:
|
798 |
+
- Format: `response_metric_<response_metric_name>_total`
|
799 |
+
- Example: `response_metric_revenue_total`
|
800 |
+
|
801 |
+
- **Spend Column**:
|
802 |
+
- Format: `spends_<channel_name>`
|
803 |
+
- Example: `spends_facebook`
|
804 |
+
|
805 |
+
- **Media Column**:
|
806 |
+
- Format: `media_<media_variable_name>_<channel_name>`
|
807 |
+
- Example: `media_clicks_facebook`
|
808 |
+
|
809 |
+
- **Exogenous Column**:
|
810 |
+
- Format: `exogenous_<variable_name>`
|
811 |
+
- Example: `exogenous_unemployment_rate`
|
812 |
+
|
813 |
+
- **Internal Column**:
|
814 |
+
- Format: `internal_<variable_name>`
|
815 |
+
- Example: `internal_discount`
|
816 |
+
|
817 |
+
**Notes**:
|
818 |
+
|
819 |
+
- The `total` response metric should represent the total for a particular date and panel, including all channels and organic contributions.
|
820 |
+
- The `date` column for weekly data should be the Monday of that week, representing the data from that Monday to the following Sunday. Example: If the week starts on Monday, August 5th, 2024, and ends on Sunday, August 11th, 2024, the date column for that week should display 2024-08-05.
|
821 |
+
"""
|
822 |
+
)
|
823 |
+
|
824 |
+
# Upload warning placeholder
|
825 |
+
upload_warning_placeholder = st.container()
|
826 |
+
|
827 |
+
# Load the uploaded file into a DataFrame if a file is uploaded
|
828 |
+
data_upload_df = load_and_transform_data(uploaded_file)
|
829 |
+
|
830 |
+
# Columns for user input
|
831 |
+
granularity_col, validate_process_col = st.columns(2)
|
832 |
+
|
833 |
+
# Dropdown for data granularity
|
834 |
+
granularity_selection = granularity_col.selectbox(
|
835 |
+
"Select data granularity",
|
836 |
+
options=["daily", "weekly"],
|
837 |
+
format_func=name_format_func,
|
838 |
+
key="granularity_selection_key",
|
839 |
+
)
|
840 |
+
|
841 |
+
# Gold Layer DataFrame
|
842 |
+
gold_layer_df = st.session_state["project_dct"]["data_import"]["gold_layer_df"]
|
843 |
+
    if not gold_layer_df.empty:
        st.subheader("Gold Layer DataFrame")
        with st.expander("Gold Layer DataFrame"):
            st.dataframe(
                gold_layer_df,
                hide_index=True,
                column_config={
                    "date": st.column_config.DateColumn("date", format="YYYY-MM-DD")
                },
            )
    else:
        st.info(
            "No gold layer data is selected for this project. Please upload data manually.",
            icon="📊",
        )

    # Check input data
    with validate_process_col:
        st.write("##")  # Padding

    if validate_process_col.button("Validate and Process", use_container_width=True):
        with st.spinner("Processing ..."):
            # Check if both DataFrames are empty
            valid_input = True
            if gold_layer_df.empty and data_upload_df.empty:
                # If both gold_layer_df and data_upload_df are empty, display a warning and stop the script
                st.warning(
                    "Both the Gold Layer data and the uploaded data are empty. Please provide at least one data source.",
                    icon="⚠️",
                )

                # Log message
                log_message(
                    "warning",
                    "Both the Gold Layer data and the uploaded data are empty. Please provide at least one data source.",
                    "Data Import",
                )
                valid_input = False

            # If the uploaded DataFrame is empty and the Gold Layer is not, swap them to ensure all validation conditions are checked
            elif not gold_layer_df.empty and data_upload_df.empty:
                data_upload_df, gold_layer_df = (
                    gold_layer_df.copy(),
                    data_upload_df.copy(),
                )
                valid_input = True

            if valid_input:
                # Fetch all necessary columns list
                (
                    spends_columns,
                    response_metric_columns,
                    total_columns,
                    gold_layer_columns,
                    data_upload_columns,
                ) = fetch_columns(gold_layer_df, data_upload_df)

                with upload_warning_placeholder:
                    valid_input, message = valid_input_df(
                        data_upload_df,
                        spends_columns,
                        response_metric_columns,
                        total_columns,
                        gold_layer_columns,
                        data_upload_columns,
                    )
                    if not valid_input:
                        st.warning(message, icon="⚠️")

                        # Log message
                        log_message("warning", message, "Data Import")

            # Merge gold_layer_df and data_upload_df on 'panel' and 'date'
            if valid_input:
                merged_df = merge_dataframes(gold_layer_df, data_upload_df)

                missing_columns, no_media_data, channel_columns_dict = (
                    check_required_columns(
                        merged_df, spends_columns, response_metric_columns
                    )
                )

                with upload_warning_placeholder:
                    # Warning for categories with no media data
                    if no_media_data:
                        st.warning(
                            f"Categories without media data: {format_values_for_display(no_media_data)}. Please upload at least one media column to proceed.",
                            icon="⚠️",
                        )
                        valid_input = False

                        # Log message
                        log_message(
                            "warning",
                            f"Categories without media data: {format_values_for_display(no_media_data)}. Please upload at least one media column to proceed.",
                            "Data Import",
                        )

                    # Warning for insufficient rows
                    elif any(
                        granularity_selection == "daily"
                        and len(merged_df[merged_df["panel"] == panel])
                        < minimum_row_req
                        for panel in merged_df["panel"].unique()
                    ):
                        st.warning(
                            f"Insufficient data. Please provide at least {minimum_row_req} days of data for all panels.",
                            icon="⚠️",
                        )
                        valid_input = False

                        # Log message
                        log_message(
                            "warning",
                            f"Insufficient data. Please provide at least {minimum_row_req} days of data for all panels.",
                            "Data Import",
                        )

                    elif any(
                        granularity_selection == "weekly"
                        and len(merged_df[merged_df["panel"] == panel])
                        < minimum_row_req * 7
                        for panel in merged_df["panel"].unique()
                    ):
                        st.warning(
                            f"Insufficient data. Please provide at least {minimum_row_req} weeks of data for all panels.",
                            icon="⚠️",
                        )
                        valid_input = False

                        # Log message
                        log_message(
                            "warning",
                            f"Insufficient data. Please provide at least {minimum_row_req} weeks of data for all panels.",
                            "Data Import",
                        )

                    # Info for missing columns
                    elif missing_columns:
                        st.info(
                            f"Missing columns: {format_values_for_display(missing_columns)}. Please upload all required columns.",
                            icon="💡",
                        )

                if valid_input:
                    # Create a copy of the merged DataFrame for dashboard purposes
                    dashboard_df = merged_df

                    # Create a DataFrame for tool purposes
                    tool_df = prepare_tool_df(merged_df, granularity_selection)

                    # Create Imputation DataFrame
                    imputation_df = generate_imputation_df(tool_df)

                    # Save data to project dictionary
                    st.session_state["project_dct"]["data_import"][
                        "granularity_selection"
                    ] = st.session_state["granularity_selection_key"]
                    st.session_state["project_dct"]["data_import"][
                        "dashboard_df"
                    ] = dashboard_df
                    st.session_state["project_dct"]["data_import"]["tool_df"] = tool_df
                    st.session_state["project_dct"]["data_import"]["unique_panels"] = (
                        tool_df["panel"].unique()
                    )
                    st.session_state["project_dct"]["data_import"][
                        "imputation_df"
                    ] = imputation_df

                    # Success message
                    with upload_warning_placeholder:
                        st.success("Processed Successfully!", icon="🗂️")
                        st.toast("Processed Successfully!", icon="🗂️")

                        # Log message
                        log_message("info", "Processed Successfully!", "Data Import")

    # Load saved data from project dictionary
    if st.session_state["project_dct"]["data_import"]["tool_df"] is None:
        st.stop()
    else:
        tool_df = st.session_state["project_dct"]["data_import"]["tool_df"]
        imputation_df = st.session_state["project_dct"]["data_import"]["imputation_df"]
        unique_panels = st.session_state["project_dct"]["data_import"]["unique_panels"]

    # Unique Panel
    st.subheader("Unique Panel")

    # Get unique panels count
    total_count = len(unique_panels)

    # Define custom CSS for pastel light blue rounded rectangle
    custom_css = """
    <style>
    .panel-box {
        background-color: #ffdaab;
        border-radius: 10px;
        padding: 10px;
        margin: 0 0;
    }
    </style>
    """

    # Display unique panels with total count
    st.markdown(custom_css, unsafe_allow_html=True)
    panel_html = f"<div class='panel-box'><strong>Unique Panels:</strong> {format_values_for_display(unique_panels)}<br><strong>Total Count:</strong> {total_count}</div>"
    st.markdown(panel_html, unsafe_allow_html=True)
    st.write("##")  # Padding

    # Impute Missing Values
    st.subheader("Impute Missing Values")
    edited_imputation_df = st.data_editor(
        imputation_df,
        column_config={
            "Imputation Method": st.column_config.SelectboxColumn(
                options=[
                    "Drop Column",
                    "Fill with Mean",
                    "Fill with Median",
                    "Fill with 0",
                ],
                required=True,
                default="Fill with 0",
            ),
        },
        column_order=[
            "Column Name",
            "Category",
            "Missing Values",
            "Zero Values",
            "Imputation Method",
        ],
        disabled=["Column Name", "Category", "Missing Values", "Zero Values"],
        hide_index=True,
        use_container_width=True,
        key="imputation_df_key",
    )

    # Expander with markdown for imputation rules
    with st.expander("Impute Missing Values Guidelines"):
        st.markdown(
            f"""
            ### Imputation Guidelines

            Please adhere to the following rules when handling missing values:

            1. **Default Imputation Strategies**:
                - **Response Metrics**: Imputed using the **mean** value of the column.
                - **Spends**: Imputed with **zero** values.
                - **Media, Exogenous, Internal**: Imputation strategy is **dynamic** based on the data.

            2. **Drop Threshold**:
                - If the combined percentage of **zeros** and **null values** in any column exceeds `{percent_drop_col_threshold}%`, the column will be **categorized to drop** by default, which the user can change manually.
                - **Example**: If `spends_facebook` has more than `{percent_drop_col_threshold}%` of zeros and nulls combined, it will be marked for dropping.

            3. **Category Generation and Association**:
                - Categories are automatically generated from the **Spends** columns.
                - **Example**: The column `spends_facebook` will generate the **facebook** category. This means columns like `spends_facebook`, `media_impression_facebook` and `media_clicks_facebook` will also be associated with this category.

            4. **Column Association and Imputation**:
                - Each category must have at least **one Media column** associated with it for imputation to proceed.
                - **Example**: If the **facebook** category does not have any media columns like `media_impression_facebook`, imputation will not be allowed for that category.
                - Solution: Either **drop the entire category** if it is empty, or **impute the columns** associated with the category instead of dropping them.

            5. **Response Metrics and Category Count**:
                - Dropping **Response Metric** columns is **not allowed** under any circumstances.
                - At least **two categories** must exist after imputation, or the imputation will not proceed.
                - **Example**: If only **facebook** remains after selection, imputation will be halted.

            **Notes**:

            - The decision to drop a spends column will result in all associated columns being dropped.
            - **Example**: Dropping `spends_facebook` will also drop all related columns like `media_impression_facebook` and `media_clicks_facebook`.
            """
        )
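The drop-threshold rule in the guidelines above can be made concrete with a small sketch. This is only an illustration: the helper name `default_imputation_method` and the 40% threshold are assumptions, and the tool's actual logic presumably lives in `generate_imputation_df` and `perform_imputation`.

```python
import pandas as pd

# Illustrative only: a column whose combined share of zeros and nulls exceeds
# the threshold is defaulted to "Drop Column"; everything else gets "Fill with 0".
def default_imputation_method(series: pd.Series, percent_drop_col_threshold: float = 40.0) -> str:
    n = len(series)
    if n == 0:
        return "Fill with 0"
    zero_or_null = (series.fillna(0) == 0).sum()  # nulls are counted together with zeros
    if 100.0 * zero_or_null / n > percent_drop_col_threshold:
        return "Drop Column"
    return "Fill with 0"

# Example: 3 of 5 values are zero or null (60%), so the column is marked for dropping
print(default_imputation_method(pd.Series([0, None, 5.0, 0, 2.0])))  # "Drop Column"
```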
    # Imputation Warning Placeholder
    imputation_warning_placeholder = st.container()

    # Save the DataFrame and dictionary from the current session
    if st.button("Impute and Save", use_container_width=True):
        with st.spinner("Imputing ..."):
            with imputation_warning_placeholder:
                # Perform Imputation
                imputed_tool_df, group_dict, message = perform_imputation(
                    edited_imputation_df.copy(), tool_df.copy()
                )

                if imputed_tool_df is None:
                    st.warning(message, icon="⚠️")

                    # Log message
                    log_message("warning", message, "Data Import")

                else:
                    st.session_state["project_dct"]["data_import"][
                        "imputed_tool_df"
                    ] = imputed_tool_df
                    st.session_state["project_dct"]["data_import"][
                        "imputation_df"
                    ] = edited_imputation_df
                    st.session_state["project_dct"]["data_import"][
                        "group_dict"
                    ] = group_dict
                    st.session_state["project_dct"]["data_import"]["category_dict"] = (
                        create_ordered_category_dict(imputed_tool_df)
                    )

            if imputed_tool_df is not None:
                # Update DB
                update_db(
                    prj_id=st.session_state["project_number"],
                    page_nam="Data Import",
                    file_nam="project_dct",
                    pkl_obj=pickle.dumps(st.session_state["project_dct"]),
                    schema=schema,
                )

                # Success message
                st.success("Saved Successfully!", icon="💾")
                st.toast("Saved Successfully!", icon="💾")

                # Log message
                log_message("info", "Saved Successfully!", "Data Import")

    # Load saved data from project dictionary
    if st.session_state["project_dct"]["data_import"]["imputed_tool_df"] is None:
        st.stop()
    else:
        imputed_tool_df = st.session_state["project_dct"]["data_import"][
            "imputed_tool_df"
        ]
        group_dict = st.session_state["project_dct"]["data_import"]["group_dict"]
        category_dict = st.session_state["project_dct"]["data_import"]["category_dict"]

    # Channel Groupings
    st.subheader("Channel Groupings")
    display_groups(group_dict)
    st.write("##")  # Padding

    # Variable Categorization
    st.subheader("Variable Categorization")
    display_groups(category_dict)
    st.write("##")  # Padding

    # Final DataFrame
    st.subheader("Final DataFrame")
    st.dataframe(
        imputed_tool_df,
        hide_index=True,
        column_config={
            "date": st.column_config.DateColumn("date", format="YYYY-MM-DD")
        },
    )
    st.write("##")  # Padding

except Exception as e:
    # Capture the error details
    exc_type, exc_value, exc_traceback = sys.exc_info()
    error_message = "".join(
        traceback.format_exception(exc_type, exc_value, exc_traceback)
    )

    # Log message
    log_message("error", f"An error occurred: {error_message}.", "Data Import")

    # Display a warning message
    st.warning(
        "Oops! Something went wrong. Please try refreshing the tool or creating a new project.",
        icon="⚠️",
    )
pages/2_Data_Assessment.py
ADDED
@@ -0,0 +1,467 @@
import streamlit as st
import pandas as pd
from data_analysis import *
import numpy as np
import pickle
from utilities import set_header, load_local_css, update_db, project_selection
from post_gres_cred import db_cred
import re

st.set_page_config(
    page_title="Data Assessment",
    page_icon=":shark:",
    layout="wide",
    initial_sidebar_state="collapsed",
)

schema = db_cred["schema"]
load_local_css("styles.css")
set_header()

if "username" not in st.session_state:
    st.session_state["username"] = None

if "project_name" not in st.session_state:
    st.session_state["project_name"] = None

if "project_dct" not in st.session_state:
    project_selection()
    st.stop()

if "username" in st.session_state and st.session_state["username"] is not None:

    if st.session_state["project_dct"]["data_import"]["imputed_tool_df"] is None:
        st.error("Please import data from the Data Import Page")
        st.stop()

    st.session_state["cleaned_data"] = st.session_state["project_dct"]["data_import"][
        "imputed_tool_df"
    ]

    st.session_state["category_dict"] = st.session_state["project_dct"]["data_import"][
        "category_dict"
    ]

    # st.write(st.session_state['category_dict'])
    cols1 = st.columns([2, 1])

    with cols1[0]:
        st.markdown(f"**Welcome {st.session_state['username']}**")
    with cols1[1]:
        st.markdown(f"**Current Project: {st.session_state['project_name']}**")

    st.title("Data Assessment")

    target_variables = [
        st.session_state["category_dict"][key]
        for key in st.session_state["category_dict"].keys()
        if key == "Response Metrics"
    ]

    def format_display(inp):
        return (
            inp.title()
            .replace("_", " ")
            .replace("Media", "")
            .replace("Cnt", "")
            .strip()
        )

    target_variables = list(*target_variables)
    target_column = st.selectbox(
        "Select the Target Feature/Dependent Variable (will be used in all charts as reference)",
        target_variables,
        index=st.session_state["project_dct"]["data_validation"]["target_column"],
        format_func=format_display,
    )

    st.session_state["project_dct"]["data_validation"]["target_column"] = (
        target_variables.index(target_column)
    )

    st.session_state["target_column"] = target_column

    if "panel" not in st.session_state["cleaned_data"].columns:
        st.session_state["cleaned_data"]["panel"] = ["Aggregated"] * len(
            st.session_state["cleaned_data"]
        )

        disable = True

    else:
        panels = st.session_state["cleaned_data"]["panel"]

        disable = False

    selected_panels = st.multiselect(
        "Please choose the panels you wish to analyze. If no panels are selected, insights will be derived from the overall data.",
        st.session_state["cleaned_data"]["panel"].unique(),
        default=st.session_state["project_dct"]["data_validation"]["selected_panels"],
        disabled=disable,
    )

    st.session_state["project_dct"]["data_validation"][
        "selected_panels"
    ] = selected_panels

    aggregation_dict = {
        item: "sum" if key == "Media" else "mean"
        for key, value in st.session_state["category_dict"].items()
        for item in value
        if item not in ["date", "panel"]
    }

    aggregation_dict = {
        key: value
        for key, value in aggregation_dict.items()
        if key in st.session_state["cleaned_data"].columns
    }
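To make the two dictionary comprehensions above concrete: Media columns are aggregated with `sum`, every other category with `mean`, and columns that are not present in the cleaned data are dropped. A minimal sketch with a made-up `category_dict` (the category names and column names are illustrative):

```python
# Hypothetical category_dict and the aggregation_dict it produces
category_dict = {
    "Media": ["media_impressions_facebook", "media_clicks_google"],
    "Spends": ["spends_facebook"],
    "Response Metrics": ["revenue"],
}

aggregation_dict = {
    item: "sum" if key == "Media" else "mean"
    for key, value in category_dict.items()
    for item in value
    if item not in ["date", "panel"]
}
print(aggregation_dict)
# {'media_impressions_facebook': 'sum', 'media_clicks_google': 'sum',
#  'spends_facebook': 'mean', 'revenue': 'mean'}
```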
    with st.expander("**Target Variable Analysis**"):

        if len(selected_panels) > 0:
            st.session_state["Cleaned_data_panel"] = st.session_state["cleaned_data"][
                st.session_state["cleaned_data"]["panel"].isin(selected_panels)
            ]

            st.session_state["Cleaned_data_panel"] = (
                st.session_state["Cleaned_data_panel"]
                .groupby(by="date")
                .agg(aggregation_dict)
            )
            st.session_state["Cleaned_data_panel"] = st.session_state[
                "Cleaned_data_panel"
            ].reset_index()
        else:
            # st.write(st.session_state['cleaned_data'])
            st.session_state["Cleaned_data_panel"] = (
                st.session_state["cleaned_data"]
                .groupby(by="date")
                .agg(aggregation_dict)
            )
            st.session_state["Cleaned_data_panel"] = st.session_state[
                "Cleaned_data_panel"
            ].reset_index()

        fig = line_plot_target(
            st.session_state["Cleaned_data_panel"],
            target=target_column,
            title=f"{target_column} Over Time",
        )
        st.plotly_chart(fig, use_container_width=True)

    media_channel = list(
        *[
            st.session_state["category_dict"][key]
            for key in st.session_state["category_dict"].keys()
            if key == "Media"
        ]
    )

    spends_features = list(
        *[
            st.session_state["category_dict"][key]
            for key in st.session_state["category_dict"].keys()
            if key == "Spends"
        ]
    )
    # st.write(media_channel)

    exo_var = list(
        *[
            st.session_state["category_dict"][key]
            for key in st.session_state["category_dict"].keys()
            if key == "Exogenous"
        ]
    )
    internal_var = list(
        *[
            st.session_state["category_dict"][key]
            for key in st.session_state["category_dict"].keys()
            if key == "Internal"
        ]
    )

    Non_media_variables = exo_var + internal_var

    st.markdown("### Annual Data Summary")

    summary_df = summary(
        st.session_state["Cleaned_data_panel"],
        media_channel + [target_column] + spends_features,
        spends=None,
        Target=True,
    )

    st.dataframe(
        summary_df.sort_index(axis=1),
        use_container_width=True,
    )

    if st.checkbox("View Raw Data"):

        @st.cache_resource(show_spinner=False)
        def raw_df_gen():
            # Convert 'date' to datetime but do not convert to string yet for sorting
            dates = pd.to_datetime(st.session_state["Cleaned_data_panel"]["date"])

            # Concatenate the dates with other numeric columns formatted
            raw_df = pd.concat(
                [
                    dates,
                    st.session_state["Cleaned_data_panel"]
                    .select_dtypes(np.number)
                    .applymap(format_numbers),
                ],
                axis=1,
            )

            # Now sort raw_df by the 'date' column, which is still in datetime format
            sorted_raw_df = raw_df.sort_values(by="date", ascending=True)

            # After sorting, convert 'date' to string format for display
            sorted_raw_df["date"] = sorted_raw_df["date"].dt.strftime("%m/%d/%Y")

            return sorted_raw_df

        # Display the sorted DataFrame in Streamlit
        st.dataframe(raw_df_gen())

    col1 = st.columns(1)

    if "selected_feature" not in st.session_state:
        st.session_state["selected_feature"] = None

    # st.warning('Work in Progress')
    with st.expander("Media Variables Analysis"):
        # Get the selected feature
        st.session_state["selected_feature"] = st.selectbox(
            "Select Media", media_channel + spends_features, format_func=format_display
        )

        # st.write(st.session_state["selected_feature"].split('cnt_')[1] )
        # st.session_state["project_dct"]["data_validation"]["selected_feature"] = (
        # )

        # Filter spends features based on the selected feature
        spends_col = st.columns(2)
        spends_feature = [
            col
            for col in spends_features
            if re.split(r"cost_|spends_", col.lower())[1]
            in st.session_state["selected_feature"]
        ]

        with spends_col[0]:
            if len(spends_feature) == 0:
                st.warning(
                    "The selected metric does not include a 'spends' variable in the data. Please verify that the columns are correctly named or select the appropriate columns in the provided selection box."
                )
            else:
                st.write(
                    f'Selected "{spends_feature[0]}" as the corresponding spends variable automatically. If this is incorrect, please click the checkbox to change the variable.'
                )

        with spends_col[1]:
            if len(spends_feature) == 0 or st.checkbox(
                'Select "Spends" variable for CPM and CPC calculation'
            ):
                spends_feature = [st.selectbox("Spends Variable", spends_features)]

        if "validation" not in st.session_state:
            st.session_state["validation"] = st.session_state["project_dct"][
                "data_validation"
            ]["validated_variables"]

        val_variables = [col for col in media_channel if col != "date"]

        if not set(
            st.session_state["project_dct"]["data_validation"]["validated_variables"]
        ).issubset(set(val_variables)):

            st.session_state["validation"] = []

        else:
            fig_row1 = line_plot(
                st.session_state["Cleaned_data_panel"],
                x_col="date",
                y1_cols=[st.session_state["selected_feature"]],
                y2_cols=[target_column],
                title=f'Analysis of {st.session_state["selected_feature"]} and {[target_column][0]} Over Time',
            )
            st.plotly_chart(fig_row1, use_container_width=True)
            st.markdown("### Summary")
            st.dataframe(
                summary(
                    st.session_state["Cleaned_data_panel"],
                    [st.session_state["selected_feature"]],
                    spends=spends_feature[0],
                ),
                use_container_width=True,
            )

        cols2 = st.columns(2)

        if len(
            set(st.session_state["validation"]).intersection(val_variables)
        ) == len(val_variables):
            disable = True
            help = "All media variables are validated"
        else:
            disable = False
            help = ""

        with cols2[0]:
            if st.button("Validate", disabled=disable, help=help):
                st.session_state["validation"].append(
                    st.session_state["selected_feature"]
                )
        with cols2[1]:
            if st.checkbox("Validate All", disabled=disable, help=help):
                st.session_state["validation"].extend(val_variables)
                st.success("All media variables are validated ✅")

        if len(
            set(st.session_state["validation"]).intersection(val_variables)
        ) != len(val_variables):
            validation_data = pd.DataFrame(
                {
                    "Validate": [
                        (True if col in st.session_state["validation"] else False)
                        for col in val_variables
                    ],
                    "Variables": val_variables,
                }
            )

            sorted_validation_df = validation_data.sort_values(
                by="Variables", ascending=True, na_position="first"
            )
            cols3 = st.columns([1, 30])
            with cols3[1]:
                validation_df = st.data_editor(
                    sorted_validation_df,
                    # column_config={
                    #     'Validate': st.column_config.CheckboxColumn(wi)
                    # },
                    column_config={
                        "Validate": st.column_config.CheckboxColumn(
                            default=False,
                            width=100,
                        ),
                        "Variables": st.column_config.TextColumn(width=1000),
                    },
                    hide_index=True,
                )

                selected_rows = validation_df[validation_df["Validate"] == True][
                    "Variables"
                ]

                # st.write(selected_rows)

                st.session_state["validation"].extend(selected_rows)

                st.session_state["project_dct"]["data_validation"][
                    "validated_variables"
                ] = st.session_state["validation"]

                not_validated_variables = [
                    col
                    for col in val_variables
                    if col not in st.session_state["validation"]
                ]

                if not_validated_variables:
                    not_validated_message = f'The following variables are not validated:\n{" , ".join(not_validated_variables)}'
                    st.warning(not_validated_message)

    with st.expander("Non-Media Variables Analysis"):
        if len(Non_media_variables) == 0:
            st.warning("Non-Media variables not present")

        else:
            selected_columns_row4 = st.selectbox(
                "Select Channel",
                Non_media_variables,
                format_func=format_display,
                index=st.session_state["project_dct"]["data_validation"][
                    "Non_media_variables"
                ],
            )

            st.session_state["project_dct"]["data_validation"][
                "Non_media_variables"
            ] = Non_media_variables.index(selected_columns_row4)

            # Create the dual-axis line plot
            fig_row4 = line_plot(
                st.session_state["Cleaned_data_panel"],
                x_col="date",
                y1_cols=[selected_columns_row4],
                y2_cols=[target_column],
                title=f"Analysis of {selected_columns_row4} and {target_column} Over Time",
            )
            st.plotly_chart(fig_row4, use_container_width=True)
            selected_non_media = selected_columns_row4
            sum_df = st.session_state["Cleaned_data_panel"][
                ["date", selected_non_media, target_column]
            ]
            sum_df["Year"] = pd.to_datetime(
                st.session_state["Cleaned_data_panel"]["date"]
            ).dt.year
            # st.dataframe(df)
            # st.dataframe(sum_df.head(2))

            sum_df = sum_df.drop("date", axis=1).groupby("Year").agg("sum")
            sum_df.loc["Grand Total"] = sum_df.sum()
            sum_df = sum_df.applymap(format_numbers)
            sum_df.fillna("-", inplace=True)
            sum_df = sum_df.replace({"0.0": "-", "nan": "-"})
            st.markdown("### Summary")
            st.dataframe(sum_df, use_container_width=True)

    with st.expander("Correlation Analysis"):
        options = list(
            st.session_state["Cleaned_data_panel"].select_dtypes(np.number).columns
        )

        if "correlation" not in st.session_state["project_dct"]["data_import"]:
            st.session_state["project_dct"]["data_import"]["correlation"] = []

        selected_options = st.multiselect(
            "Select Variables for Correlation Plot",
            [var for var in options if var != target_column],
            default=st.session_state["project_dct"]["data_import"]["correlation"],
        )

        st.session_state["project_dct"]["data_import"]["correlation"] = selected_options

        st.pyplot(
            correlation_plot(
                st.session_state["Cleaned_data_panel"],
                selected_options,
                target_column,
            )
        )

    if st.button("Save Changes", use_container_width=True):
        # Update DB
        update_db(
            prj_id=st.session_state["project_number"],
            page_nam="Data Validation and Insights",
            file_nam="project_dct",
            pkl_obj=pickle.dumps(st.session_state["project_dct"]),
            schema=schema,
        )
        st.success("Changes saved")
pages/3_AI_Model_Transformations.py
ADDED
@@ -0,0 +1,1326 @@
1 |
+
# Importing necessary libraries
|
2 |
+
import streamlit as st
|
3 |
+
|
4 |
+
st.set_page_config(
|
5 |
+
page_title="AI Model Transformations",
|
6 |
+
page_icon="⚖️",
|
7 |
+
layout="wide",
|
8 |
+
initial_sidebar_state="collapsed",
|
9 |
+
)
|
10 |
+
|
11 |
+
import sys
|
12 |
+
import pickle
|
13 |
+
import traceback
|
14 |
+
import numpy as np
|
15 |
+
import pandas as pd
|
16 |
+
import plotly.graph_objects as go
|
17 |
+
from post_gres_cred import db_cred
|
18 |
+
from log_application import log_message
|
19 |
+
from utilities import (
|
20 |
+
set_header,
|
21 |
+
load_local_css,
|
22 |
+
update_db,
|
23 |
+
project_selection,
|
24 |
+
delete_entries,
|
25 |
+
retrieve_pkl_object,
|
26 |
+
)
|
27 |
+
from constants import (
|
28 |
+
predefined_defaults,
|
29 |
+
lead_min_value,
|
30 |
+
lead_max_value,
|
31 |
+
lead_step,
|
32 |
+
lag_min_value,
|
33 |
+
lag_max_value,
|
34 |
+
lag_step,
|
35 |
+
moving_average_min_value,
|
36 |
+
moving_average_max_value,
|
37 |
+
moving_average_step,
|
38 |
+
saturation_min_value,
|
39 |
+
saturation_max_value,
|
40 |
+
saturation_step,
|
41 |
+
power_min_value,
|
42 |
+
power_max_value,
|
43 |
+
power_step,
|
44 |
+
adstock_min_value,
|
45 |
+
adstock_max_value,
|
46 |
+
adstock_step,
|
47 |
+
display_max_col,
|
48 |
+
)
|
49 |
+
|
50 |
+
|
51 |
+
schema = db_cred["schema"]
|
52 |
+
load_local_css("styles.css")
|
53 |
+
set_header()
|
54 |
+
|
55 |
+
|
56 |
+
# Initialize project name session state
|
57 |
+
if "project_name" not in st.session_state:
|
58 |
+
st.session_state["project_name"] = None
|
59 |
+
|
60 |
+
# Fetch project dictionary
|
61 |
+
if "project_dct" not in st.session_state:
|
62 |
+
project_selection()
|
63 |
+
st.stop()
|
64 |
+
|
65 |
+
# Display Username and Project Name
|
66 |
+
if "username" in st.session_state and st.session_state["username"] is not None:
|
67 |
+
|
68 |
+
cols1 = st.columns([2, 1])
|
69 |
+
|
70 |
+
with cols1[0]:
|
71 |
+
st.markdown(f"**Welcome {st.session_state['username']}**")
|
72 |
+
with cols1[1]:
|
73 |
+
st.markdown(f"**Current Project: {st.session_state['project_name']}**")
|
74 |
+
|
75 |
+
# Load saved data from project dictionary
|
76 |
+
if st.session_state["project_dct"]["data_import"]["imputed_tool_df"] is None:
|
77 |
+
st.warning(
|
78 |
+
"The data import is incomplete. Please go back to the Data Import page and complete the save.",
|
79 |
+
icon="🔙",
|
80 |
+
)
|
81 |
+
|
82 |
+
# Log message
|
83 |
+
log_message(
|
84 |
+
"warning",
|
85 |
+
"The data import is incomplete. Please go back to the Data Import page and complete the save.",
|
86 |
+
"Transformations",
|
87 |
+
)
|
88 |
+
|
89 |
+
st.stop()
|
90 |
+
else:
|
91 |
+
final_df_loaded = st.session_state["project_dct"]["data_import"][
|
92 |
+
"imputed_tool_df"
|
93 |
+
].copy()
|
94 |
+
bin_dict_loaded = st.session_state["project_dct"]["data_import"][
|
95 |
+
"category_dict"
|
96 |
+
].copy()
|
97 |
+
unique_panels = st.session_state["project_dct"]["data_import"][
|
98 |
+
"unique_panels"
|
99 |
+
].copy()
|
100 |
+
|
101 |
+
# Initialize project dictionary data
|
102 |
+
if st.session_state["project_dct"]["transformations"]["final_df"] is None:
|
103 |
+
st.session_state["project_dct"]["transformations"][
|
104 |
+
"final_df"
|
105 |
+
] = final_df_loaded # Default as original dataframe
|
106 |
+
|
107 |
+
# Extract original columns for specified categories
|
108 |
+
original_columns = {
|
109 |
+
category: bin_dict_loaded[category]
|
110 |
+
for category in ["Media", "Internal", "Exogenous"]
|
111 |
+
if category in bin_dict_loaded
|
112 |
+
}
|
113 |
+
|
114 |
+
# Retrive Panel columns
|
115 |
+
panel = ["panel"] if len(unique_panels) > 1 else []
|
116 |
+
|
117 |
+
|
118 |
+
# Function to clear model metadata
|
119 |
+
def clear_pages():
|
120 |
+
# Reset Pages
|
121 |
+
st.session_state["project_dct"]["model_build"] = {
|
122 |
+
"sel_target_col": None,
|
123 |
+
"all_iters_check": False,
|
124 |
+
"iterations": 0,
|
125 |
+
"build_button": False,
|
126 |
+
"show_results_check": False,
|
127 |
+
"session_state_saved": {},
|
128 |
+
}
|
129 |
+
st.session_state["project_dct"]["model_tuning"] = {
|
130 |
+
"sel_target_col": None,
|
131 |
+
"sel_model": {},
|
132 |
+
"flag_expander": False,
|
133 |
+
"start_date_default": None,
|
134 |
+
"end_date_default": None,
|
135 |
+
"repeat_default": "No",
|
136 |
+
"flags": {},
|
137 |
+
"select_all_flags_check": {},
|
138 |
+
"selected_flags": {},
|
139 |
+
"trend_check": False,
|
140 |
+
"week_num_check": False,
|
141 |
+
"sine_cosine_check": False,
|
142 |
+
"session_state_saved": {},
|
143 |
+
}
|
144 |
+
st.session_state["project_dct"]["saved_model_results"] = {
|
145 |
+
"selected_options": None,
|
146 |
+
"model_grid_sel": [1],
|
147 |
+
}
|
148 |
+
if "model_results_df" in st.session_state:
|
149 |
+
del st.session_state["model_results_df"]
|
150 |
+
if "model_results_data" in st.session_state:
|
151 |
+
del st.session_state["model_results_data"]
|
152 |
+
if "coefficients_df" in st.session_state:
|
153 |
+
del st.session_state["coefficients_df"]
|
154 |
+
|
155 |
+
|
156 |
+
# Function to update transformation change
|
157 |
+
def transformation_change(category, transformation, key):
|
158 |
+
st.session_state["project_dct"]["transformations"][category][transformation] = (
|
159 |
+
st.session_state[key]
|
160 |
+
)
|
161 |
+
|
162 |
+
|
163 |
+
# Function to update specific transformation change
|
164 |
+
def transformation_specific_change(channel_name, transformation, key):
|
165 |
+
st.session_state["project_dct"]["transformations"]["Specific"][transformation][
|
166 |
+
channel_name
|
167 |
+
] = st.session_state[key]
|
168 |
+
|
169 |
+
|
170 |
+
# Function to update transformations to apply change
|
171 |
+
def transformations_to_apply_change(category, key):
|
172 |
+
st.session_state["project_dct"]["transformations"][category][key] = (
|
173 |
+
st.session_state[key]
|
174 |
+
)
|
175 |
+
|
176 |
+
|
177 |
+
# Function to update channel select specific change
|
178 |
+
def channel_select_specific_change():
|
179 |
+
st.session_state["project_dct"]["transformations"]["Specific"][
|
180 |
+
"channel_select_specific"
|
181 |
+
] = st.session_state["channel_select_specific"]
|
182 |
+
|
183 |
+
|
184 |
+
# Function to update specific transformation change
|
185 |
+
def specific_transformation_change(specific_transformation_key):
|
186 |
+
st.session_state["project_dct"]["transformations"]["Specific"][
|
187 |
+
specific_transformation_key
|
188 |
+
] = st.session_state[specific_transformation_key]
|
189 |
+
|
190 |
+
|
191 |
+
# Function to build transformation widgets
|
192 |
+
def transformation_widgets(category, transform_params, date_granularity):
|
193 |
+
# Transformation Options
|
194 |
+
transformation_options = {
|
195 |
+
"Media": [
|
196 |
+
"Lag",
|
197 |
+
"Moving Average",
|
198 |
+
"Saturation",
|
199 |
+
"Power",
|
200 |
+
"Adstock",
|
201 |
+
],
|
202 |
+
"Internal": ["Lead", "Lag", "Moving Average"],
|
203 |
+
"Exogenous": ["Lead", "Lag", "Moving Average"],
|
204 |
+
}
|
205 |
+
|
206 |
+
# Define a helper function to create widgets for each transformation
|
207 |
+
def create_transformation_widgets(column, transformations):
|
208 |
+
with column:
|
209 |
+
for transformation in transformations:
|
210 |
+
transformation_key = f"{transformation}_{category}"
|
211 |
+
|
212 |
+
slider_value = st.session_state["project_dct"]["transformations"][
|
213 |
+
category
|
214 |
+
].get(transformation, predefined_defaults[transformation])
|
215 |
+
|
216 |
+
# Conditionally create widgets for selected transformations
|
217 |
+
if transformation == "Lead":
|
218 |
+
st.markdown(f"**{transformation} ({date_granularity})**")
|
219 |
+
|
220 |
+
lead = st.slider(
|
221 |
+
label="Lead periods",
|
222 |
+
min_value=lead_min_value,
|
223 |
+
max_value=lead_max_value,
|
224 |
+
step=lead_step,
|
225 |
+
value=slider_value,
|
226 |
+
key=transformation_key,
|
227 |
+
label_visibility="collapsed",
|
228 |
+
on_change=transformation_change,
|
229 |
+
args=(
|
230 |
+
category,
|
231 |
+
transformation,
|
232 |
+
transformation_key,
|
233 |
+
),
|
234 |
+
)
|
235 |
+
|
236 |
+
start = lead[0]
|
237 |
+
end = lead[1]
|
238 |
+
step = lead_step
|
239 |
+
transform_params[category][transformation] = np.arange(
|
240 |
+
start, end + step, step
|
241 |
+
)
|
242 |
+
|
243 |
+
if transformation == "Lag":
|
244 |
+
st.markdown(f"**{transformation} ({date_granularity})**")
|
245 |
+
|
246 |
+
lag = st.slider(
|
247 |
+
label="Lag periods",
|
248 |
+
min_value=lag_min_value,
|
249 |
+
max_value=lag_max_value,
|
250 |
+
step=lag_step,
|
251 |
+
value=slider_value,
|
252 |
+
key=transformation_key,
|
253 |
+
label_visibility="collapsed",
|
254 |
+
on_change=transformation_change,
|
255 |
+
args=(
|
256 |
+
category,
|
257 |
+
transformation,
|
258 |
+
transformation_key,
|
259 |
+
),
|
260 |
+
)
|
261 |
+
|
262 |
+
start = lag[0]
|
263 |
+
end = lag[1]
|
264 |
+
step = lag_step
|
265 |
+
transform_params[category][transformation] = np.arange(
|
266 |
+
start, end + step, step
|
267 |
+
)
|
268 |
+
|
269 |
+
if transformation == "Moving Average":
|
270 |
+
st.markdown(f"**{transformation} ({date_granularity})**")
|
271 |
+
|
272 |
+
window = st.slider(
|
273 |
+
label="Window size for Moving Average",
|
274 |
+
min_value=moving_average_min_value,
|
275 |
+
max_value=moving_average_max_value,
|
276 |
+
step=moving_average_step,
|
277 |
+
value=slider_value,
|
278 |
+
key=transformation_key,
|
279 |
+
label_visibility="collapsed",
|
280 |
+
on_change=transformation_change,
|
281 |
+
args=(
|
282 |
+
category,
|
283 |
+
transformation,
|
284 |
+
transformation_key,
|
285 |
+
),
|
286 |
+
)
|
287 |
+
|
288 |
+
start = window[0]
|
289 |
+
end = window[1]
|
290 |
+
step = moving_average_step
|
291 |
+
transform_params[category][transformation] = np.arange(
|
292 |
+
start, end + step, step
|
293 |
+
)
|
294 |
+
|
295 |
+
if transformation == "Saturation":
|
296 |
+
st.markdown(f"**{transformation} (%)**")
|
297 |
+
|
298 |
+
saturation_point = st.slider(
|
299 |
+
label="Saturation Percentage",
|
300 |
+
min_value=saturation_min_value,
|
301 |
+
max_value=saturation_max_value,
|
302 |
+
step=saturation_step,
|
303 |
+
value=slider_value,
|
304 |
+
key=transformation_key,
|
305 |
+
label_visibility="collapsed",
|
306 |
+
on_change=transformation_change,
|
307 |
+
args=(
|
308 |
+
category,
|
309 |
+
transformation,
|
310 |
+
transformation_key,
|
311 |
+
),
|
312 |
+
)
|
313 |
+
|
314 |
+
start = saturation_point[0]
|
315 |
+
end = saturation_point[1]
|
316 |
+
step = saturation_step
|
317 |
+
transform_params[category][transformation] = np.arange(
|
318 |
+
start, end + step, step
|
319 |
+
)
|
320 |
+
|
321 |
+
if transformation == "Power":
|
322 |
+
st.markdown(f"**{transformation}**")
|
323 |
+
|
324 |
+
power = st.slider(
|
325 |
+
label="Power",
|
326 |
+
min_value=power_min_value,
|
327 |
+
max_value=power_max_value,
|
328 |
+
step=power_step,
|
329 |
+
value=slider_value,
|
330 |
+
key=transformation_key,
|
331 |
+
label_visibility="collapsed",
|
332 |
+
on_change=transformation_change,
|
333 |
+
args=(
|
334 |
+
category,
|
335 |
+
transformation,
|
336 |
+
transformation_key,
|
337 |
+
),
|
338 |
+
)
|
339 |
+
|
340 |
+
start = power[0]
|
341 |
+
end = power[1]
|
342 |
+
step = power_step
|
343 |
+
transform_params[category][transformation] = np.arange(
|
344 |
+
start, end + step, step
|
345 |
+
)
|
346 |
+
|
347 |
+
if transformation == "Adstock":
|
348 |
+
st.markdown(f"**{transformation}**")
|
349 |
+
|
350 |
+
rate = st.slider(
|
351 |
+
label="Decay Factor",
|
352 |
+
min_value=adstock_min_value,
|
353 |
+
max_value=adstock_max_value,
|
354 |
+
step=adstock_step,
|
355 |
+
value=slider_value,
|
356 |
+
key=transformation_key,
|
357 |
+
label_visibility="collapsed",
|
358 |
+
on_change=transformation_change,
|
359 |
+
args=(
|
360 |
+
category,
|
361 |
+
transformation,
|
362 |
+
transformation_key,
|
363 |
+
),
|
364 |
+
)
|
365 |
+
|
366 |
+
start = rate[0]
|
367 |
+
end = rate[1]
|
368 |
+
step = adstock_step
|
369 |
+
adstock_range = [
|
370 |
+
round(a, 3) for a in np.arange(start, end + step, step)
|
371 |
+
]
|
372 |
+
transform_params[category][transformation] = np.array(adstock_range)
|
373 |
+
|
374 |
+
with st.expander(f"All {category} Transformations", expanded=True):
|
375 |
+
|
376 |
+
transformation_key = f"transformation_{category}"
|
377 |
+
|
378 |
+
# Select which transformations to apply
|
379 |
+
sel_transformations = st.session_state["project_dct"]["transformations"][
|
380 |
+
category
|
381 |
+
].get(transformation_key, [])
|
382 |
+
|
383 |
+
# Reset default selected channels list if options are changed
|
384 |
+
for channel in sel_transformations:
|
385 |
+
if channel not in transformation_options[category]:
|
386 |
+
(
|
387 |
+
st.session_state["project_dct"]["transformations"][category][
|
388 |
+
transformation_key
|
389 |
+
],
|
390 |
+
sel_transformations,
|
391 |
+
) = ([], [])
|
392 |
+
|
393 |
+
transformations_to_apply = st.multiselect(
|
394 |
+
label="Select transformations to apply",
|
395 |
+
options=transformation_options[category],
|
396 |
+
default=sel_transformations,
|
397 |
+
key=transformation_key,
|
398 |
+
on_change=transformations_to_apply_change,
|
399 |
+
args=(
|
400 |
+
category,
|
401 |
+
transformation_key,
|
402 |
+
),
|
403 |
+
)
|
404 |
+
|
405 |
+
# Determine the number of transformations to put in each column
|
406 |
+
transformations_per_column = (
|
407 |
+
len(transformations_to_apply) // 2 + len(transformations_to_apply) % 2
|
408 |
+
)
|
409 |
+
|
410 |
+
# Create two columns
|
411 |
+
col1, col2 = st.columns(2)
|
412 |
+
|
413 |
+
# Assign transformations to each column
|
414 |
+
transformations_col1 = transformations_to_apply[:transformations_per_column]
|
415 |
+
transformations_col2 = transformations_to_apply[transformations_per_column:]
|
416 |
+
|
417 |
+
# Create widgets in each column
|
418 |
+
create_transformation_widgets(col1, transformations_col1)
|
419 |
+
create_transformation_widgets(col2, transformations_col2)
|
420 |
+
|
421 |
+
|
422 |
+
# Define a helper function to create widgets for each specific transformation
|
423 |
+
def create_specific_transformation_widgets(
|
424 |
+
column,
|
425 |
+
transformations,
|
426 |
+
channel_name,
|
427 |
+
date_granularity,
|
428 |
+
specific_transform_params,
|
429 |
+
):
|
430 |
+
with column:
|
431 |
+
for transformation in transformations:
|
432 |
+
transformation_key = f"{transformation}_{channel_name}_specific"
|
433 |
+
|
434 |
+
if (
|
435 |
+
transformation
|
436 |
+
not in st.session_state["project_dct"]["transformations"]["Specific"]
|
437 |
+
):
|
438 |
+
st.session_state["project_dct"]["transformations"]["Specific"][
|
439 |
+
transformation
|
440 |
+
] = {}
|
441 |
+
|
442 |
+
slider_value = st.session_state["project_dct"]["transformations"][
|
443 |
+
"Specific"
|
444 |
+
][transformation].get(channel_name, predefined_defaults[transformation])
|
445 |
+
|
446 |
+
# Conditionally create widgets for selected transformations
|
447 |
+
if transformation == "Lead":
|
448 |
+
st.markdown(f"**Lead ({date_granularity})**")
|
449 |
+
|
450 |
+
lead = st.slider(
|
451 |
+
label="Lead periods",
|
452 |
+
min_value=lead_min_value,
|
453 |
+
max_value=lead_max_value,
|
454 |
+
step=lead_step,
|
455 |
+
value=slider_value,
|
456 |
+
key=transformation_key,
|
457 |
+
label_visibility="collapsed",
|
458 |
+
on_change=transformation_specific_change,
|
459 |
+
args=(
|
460 |
+
channel_name,
|
461 |
+
transformation,
|
462 |
+
transformation_key,
|
463 |
+
),
|
464 |
+
)
|
465 |
+
|
466 |
+
start = lead[0]
|
467 |
+
end = lead[1]
|
468 |
+
step = lead_step
|
469 |
+
specific_transform_params[channel_name]["Lead"] = np.arange(
|
470 |
+
start, end + step, step
|
471 |
+
)
|
472 |
+
|
473 |
+
if transformation == "Lag":
|
474 |
+
st.markdown(f"**Lag ({date_granularity})**")
|
475 |
+
|
476 |
+
lag = st.slider(
|
477 |
+
label="Lag periods",
|
478 |
+
min_value=lag_min_value,
|
479 |
+
max_value=lag_max_value,
|
480 |
+
step=lag_step,
|
481 |
+
value=slider_value,
|
482 |
+
key=transformation_key,
|
483 |
+
label_visibility="collapsed",
|
484 |
+
on_change=transformation_specific_change,
|
485 |
+
args=(
|
486 |
+
channel_name,
|
487 |
+
transformation,
|
488 |
+
transformation_key,
|
489 |
+
),
|
490 |
+
)
|
491 |
+
|
492 |
+
start = lag[0]
|
493 |
+
end = lag[1]
|
494 |
+
step = lag_step
|
495 |
+
specific_transform_params[channel_name]["Lag"] = np.arange(
|
496 |
+
start, end + step, step
|
497 |
+
)
|
498 |
+
|
499 |
+
if transformation == "Moving Average":
|
500 |
+
st.markdown(f"**Moving Average ({date_granularity})**")
|
501 |
+
|
502 |
+
window = st.slider(
|
503 |
+
label="Window size for Moving Average",
|
504 |
+
min_value=moving_average_min_value,
|
505 |
+
max_value=moving_average_max_value,
|
506 |
+
step=moving_average_step,
|
507 |
+
value=slider_value,
|
508 |
+
key=transformation_key,
|
509 |
+
label_visibility="collapsed",
|
510 |
+
on_change=transformation_specific_change,
|
511 |
+
args=(
|
512 |
+
channel_name,
|
513 |
+
transformation,
|
514 |
+
transformation_key,
|
515 |
+
),
|
516 |
+
)
|
517 |
+
|
518 |
+
start = window[0]
|
519 |
+
end = window[1]
|
520 |
+
step = moving_average_step
|
521 |
+
specific_transform_params[channel_name]["Moving Average"] = np.arange(
|
522 |
+
start, end + step, step
|
523 |
+
)
|
524 |
+
|
525 |
+
if transformation == "Saturation":
|
526 |
+
st.markdown("**Saturation (%)**")
|
527 |
+
|
528 |
+
saturation_point = st.slider(
|
529 |
+
label="Saturation Percentage",
|
530 |
+
min_value=saturation_min_value,
|
531 |
+
max_value=saturation_max_value,
|
532 |
+
step=saturation_step,
|
533 |
+
value=slider_value,
|
534 |
+
key=transformation_key,
|
535 |
+
label_visibility="collapsed",
|
536 |
+
on_change=transformation_specific_change,
|
537 |
+
args=(
|
538 |
+
channel_name,
|
539 |
+
transformation,
|
540 |
+
transformation_key,
|
541 |
+
),
|
542 |
+
)
|
543 |
+
|
544 |
+
start = saturation_point[0]
|
545 |
+
end = saturation_point[1]
|
546 |
+
step = saturation_step
|
547 |
+
specific_transform_params[channel_name]["Saturation"] = np.arange(
|
548 |
+
start, end + step, step
|
549 |
+
)
|
550 |
+
|
551 |
+
if transformation == "Power":
|
552 |
+
st.markdown("**Power**")
|
553 |
+
|
554 |
+
power = st.slider(
|
555 |
+
label="Power",
|
556 |
+
min_value=power_min_value,
|
557 |
+
max_value=power_max_value,
|
558 |
+
step=power_step,
|
559 |
+
value=slider_value,
|
560 |
+
key=transformation_key,
|
561 |
+
label_visibility="collapsed",
|
562 |
+
on_change=transformation_specific_change,
|
563 |
+
args=(
|
564 |
+
channel_name,
|
565 |
+
transformation,
|
566 |
+
transformation_key,
|
567 |
+
),
|
568 |
+
)
|
569 |
+
|
570 |
+
start = power[0]
|
571 |
+
end = power[1]
|
572 |
+
step = power_step
|
573 |
+
specific_transform_params[channel_name]["Power"] = np.arange(
|
574 |
+
start, end + step, step
|
575 |
+
)
|
576 |
+
|
577 |
+
if transformation == "Adstock":
|
578 |
+
st.markdown("**Adstock**")
|
579 |
+
rate = st.slider(
|
580 |
+
label="Decay Factor",
|
581 |
+
min_value=adstock_min_value,
|
582 |
+
max_value=adstock_max_value,
|
583 |
+
step=adstock_step,
|
584 |
+
value=slider_value,
|
585 |
+
key=transformation_key,
|
586 |
+
label_visibility="collapsed",
|
587 |
+
on_change=transformation_specific_change,
|
588 |
+
args=(
|
589 |
+
channel_name,
|
590 |
+
transformation,
|
591 |
+
transformation_key,
|
592 |
+
),
|
593 |
+
)
|
594 |
+
|
595 |
+
start = rate[0]
|
596 |
+
end = rate[1]
|
597 |
+
step = adstock_step
|
598 |
+
adstock_range = [
|
599 |
+
round(a, 3) for a in np.arange(start, end + step, step)
|
600 |
+
]
|
601 |
+
specific_transform_params[channel_name]["Adstock"] = np.array(
|
602 |
+
adstock_range
|
603 |
+
)
|
604 |
+
|
605 |
+
|
606 |
+
# Function to apply Lag transformation
|
607 |
+
def apply_lag(df, lag):
|
608 |
+
return df.shift(lag)
|
609 |
+
|
610 |
+
|
611 |
+
# Function to apply Lead transformation
|
612 |
+
def apply_lead(df, lead):
|
613 |
+
return df.shift(-lead)
|
614 |
+
|
615 |
+
|
616 |
+
# Function to apply Moving Average transformation
|
617 |
+
def apply_moving_average(df, window_size):
|
618 |
+
return df.rolling(window=window_size).mean()
|
619 |
+
|
620 |
+
|
621 |
+
# Function to apply Saturation transformation
|
622 |
+
def apply_saturation(df, saturation_percent_100):
|
623 |
+
# Convert percentage to fraction
|
624 |
+
saturation_percent = min(max(saturation_percent_100, 0.01), 99.99) / 100.0
|
625 |
+
|
626 |
+
# Get the maximum and minimum values
|
627 |
+
column_max = df.max()
|
628 |
+
column_min = df.min()
|
629 |
+
|
630 |
+
# If the data is constant, scale it directly
|
631 |
+
if column_min == column_max:
|
632 |
+
return df.apply(lambda x: x * saturation_percent)
|
633 |
+
|
634 |
+
# Compute the saturation point from the data range
|
635 |
+
saturation_point = (column_min + saturation_percent * column_max) / 2
|
636 |
+
|
637 |
+
# Calculate steepness for the saturation curve
|
638 |
+
numerator = np.log((1 / saturation_percent) - 1)
|
639 |
+
denominator = np.log(saturation_point / column_max)
|
640 |
+
steepness = numerator / denominator
|
641 |
+
|
642 |
+
# Apply the saturation transformation
|
643 |
+
transformed_series = df.apply(
|
644 |
+
lambda x: (1 / (1 + (saturation_point / (x if x != 0 else 1e-9)) ** steepness)) * x
|
645 |
+
)
|
646 |
+
|
647 |
+
return transformed_series
|
648 |
+
|
649 |
+
|
650 |
+
# Function to apply Power transformation
|
651 |
+
def apply_power(df, power):
|
652 |
+
return df**power
|
653 |
+
|
654 |
+
|
655 |
+
# Function to apply Adstock transformation
|
656 |
+
def apply_adstock(df, factor):
|
657 |
+
x = 0
|
658 |
+
# Use the walrus operator to update x iteratively with the Adstock formula
|
659 |
+
adstock_var = [x := x * factor + v for v in df]
|
660 |
+
ans = pd.Series(adstock_var, index=df.index)
|
661 |
+
return ans
|
662 |
+
|
663 |
+
|
664 |
+
# Function to generate transformed column names
|
665 |
+
@st.cache_resource(show_spinner=False)
|
666 |
+
def generate_transformed_columns(
|
667 |
+
original_columns, transform_params, specific_transform_params
|
668 |
+
):
|
669 |
+
transformed_columns, summary = {}, {}
|
670 |
+
|
671 |
+
for category, columns in original_columns.items():
|
672 |
+
for column in columns:
|
673 |
+
transformed_columns[column] = []
|
674 |
+
summary_details = (
|
675 |
+
[]
|
676 |
+
) # List to hold transformation details for the current column
|
677 |
+
|
678 |
+
if (
|
679 |
+
column in specific_transform_params.keys()
|
680 |
+
and len(specific_transform_params[column]) > 0
|
681 |
+
):
|
682 |
+
for transformation, values in specific_transform_params[column].items():
|
683 |
+
# Generate transformed column names for each value
|
684 |
+
for value in values:
|
685 |
+
transformed_name = f"{column}@{transformation}_{value}"
|
686 |
+
transformed_columns[column].append(transformed_name)
|
687 |
+
|
688 |
+
# Format the values list as a string with commas and "and" before the last item
|
689 |
+
if len(values) > 1:
|
690 |
+
formatted_values = (
|
691 |
+
", ".join(map(str, values[:-1])) + " and " + str(values[-1])
|
692 |
+
)
|
693 |
+
else:
|
694 |
+
formatted_values = str(values[0])
|
695 |
+
|
696 |
+
# Add transformation details
|
697 |
+
summary_details.append(f"{transformation} ({formatted_values})")
|
698 |
+
|
699 |
+
else:
|
700 |
+
if category in transform_params:
|
701 |
+
for transformation, values in transform_params[category].items():
|
702 |
+
# Generate transformed column names for each value
|
703 |
+
if column not in specific_transform_params.keys():
|
704 |
+
for value in values:
|
705 |
+
transformed_name = f"{column}@{transformation}_{value}"
|
706 |
+
transformed_columns[column].append(transformed_name)
|
707 |
+
|
708 |
+
# Format the values list as a string with commas and "and" before the last item
|
709 |
+
if len(values) > 1:
|
710 |
+
formatted_values = (
|
711 |
+
", ".join(map(str, values[:-1]))
|
712 |
+
+ " and "
|
713 |
+
+ str(values[-1])
|
714 |
+
)
|
715 |
+
else:
|
716 |
+
formatted_values = str(values[0])
|
717 |
+
|
718 |
+
# Add transformation details
|
719 |
+
summary_details.append(
|
720 |
+
f"{transformation} ({formatted_values})"
|
721 |
+
)
|
722 |
+
|
723 |
+
else:
|
724 |
+
summary_details = ["No transformation selected"]
|
725 |
+
|
726 |
+
# Only add to summary if there are transformation details for the column
|
727 |
+
if summary_details:
|
728 |
+
formatted_summary = "⮕ ".join(summary_details)
|
729 |
+
# Use <strong> tags to make the column name bold
|
730 |
+
summary[column] = f"<strong>{column}</strong>: {formatted_summary}"
|
731 |
+
|
732 |
+
# Generate a comprehensive summary string for all columns
|
733 |
+
summary_items = [
|
734 |
+
f"{idx + 1}. {details}" for idx, details in enumerate(summary.values())
|
735 |
+
]
|
736 |
+
|
737 |
+
summary_string = "\n".join(summary_items)
|
738 |
+
|
739 |
+
return transformed_columns, summary_string
|
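# Minimal sketch (hypothetical channel and parameters) of the naming convention
# produced above: each transformed column is "<column>@<Transformation>_<value>",
# and the summary lists the parameter values per original column.
def _demo_transformed_names():
    cols, summary = generate_transformed_columns(
        {"Media": ["tv_spend"]},
        {"Media": {"Adstock": [0.3, 0.5]}},
        {},  # no channel-specific overrides
    )
    # cols == {"tv_spend": ["tv_spend@Adstock_0.3", "tv_spend@Adstock_0.5"]}
    return cols, summary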
740 |
+
|
741 |
+
|
742 |
+
# Function to transform a DataFrame slice
|
743 |
+
def transform_slice(
|
744 |
+
transform_params,
|
745 |
+
transformation_functions,
|
746 |
+
panel,
|
747 |
+
df,
|
748 |
+
df_slice,
|
749 |
+
category,
|
750 |
+
category_df,
|
751 |
+
):
|
752 |
+
# Iterate through each transformation and its parameters for the current category
|
753 |
+
for transformation, parameters in transform_params[category].items():
|
754 |
+
transformation_function = transformation_functions[transformation]
|
755 |
+
|
756 |
+
# Check if there is panel data to group by
|
757 |
+
if len(panel) > 0:
|
758 |
+
# Apply the transformation to each group
|
759 |
+
category_df = pd.concat(
|
760 |
+
[
|
761 |
+
df_slice.groupby(panel)
|
762 |
+
.transform(transformation_function, p)
|
763 |
+
.add_suffix(f"@{transformation}_{p}")
|
764 |
+
for p in parameters
|
765 |
+
],
|
766 |
+
axis=1,
|
767 |
+
)
|
768 |
+
|
769 |
+
# Replace all NaN or null values in category_df with 0
|
770 |
+
category_df.fillna(0, inplace=True)
|
771 |
+
|
772 |
+
# Update df_slice
|
773 |
+
df_slice = pd.concat(
|
774 |
+
[df[panel], category_df],
|
775 |
+
axis=1,
|
776 |
+
)
|
777 |
+
|
778 |
+
else:
|
779 |
+
for p in parameters:
|
780 |
+
# Apply the transformation function to each column
|
781 |
+
temp_df = df_slice.apply(
|
782 |
+
lambda x: transformation_function(x, p), axis=0
|
783 |
+
).rename(
|
784 |
+
lambda x: f"{x}@{transformation}_{p}",
|
785 |
+
axis="columns",
|
786 |
+
)
|
787 |
+
# Concatenate the transformed DataFrame slice to the category DataFrame
|
788 |
+
category_df = pd.concat([category_df, temp_df], axis=1)
|
789 |
+
|
790 |
+
# Replace all NaN or null values in category_df with 0
|
791 |
+
category_df.fillna(0, inplace=True)
|
792 |
+
|
793 |
+
# Update df_slice
|
794 |
+
df_slice = pd.concat(
|
795 |
+
[df[panel], category_df],
|
796 |
+
axis=1,
|
797 |
+
)
|
798 |
+
|
799 |
+
return category_df, df, df_slice
|
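# Minimal sketch (hypothetical panel and channel names): the groupby/transform/
# add_suffix pattern used above applies a transformation within each panel and
# names the result with the "column@Transformation_value" convention.
def _demo_panel_transform():
    import pandas as pd

    demo = pd.DataFrame(
        {"panel": ["A", "A", "B", "B"], "tv_spend": [10.0, 20.0, 30.0, 40.0]}
    )
    return (
        demo.groupby("panel")
        .transform(apply_lag, 1)  # lag computed separately per panel
        .add_suffix("@Lag_1")     # -> column "tv_spend@Lag_1"
    )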
800 |
+
|
801 |
+
|
802 |
+
# Function to apply transformations to DataFrame slices based on specified categories and parameters
|
803 |
+
@st.cache_resource(show_spinner=False)
|
804 |
+
def apply_category_transformations(
|
805 |
+
df_main, bin_dict, transform_params, panel, specific_transform_params
|
806 |
+
):
|
807 |
+
# Dictionary for function mapping
|
808 |
+
transformation_functions = {
|
809 |
+
"Lead": apply_lead,
|
810 |
+
"Lag": apply_lag,
|
811 |
+
"Moving Average": apply_moving_average,
|
812 |
+
"Saturation": apply_saturation,
|
813 |
+
"Power": apply_power,
|
814 |
+
"Adstock": apply_adstock,
|
815 |
+
}
|
816 |
+
|
817 |
+
# List to collect all transformed DataFrames
|
818 |
+
transformed_dfs = []
|
819 |
+
|
820 |
+
# Iterate through each category specified in transform_params
|
821 |
+
for category in ["Media", "Exogenous", "Internal"]:
|
822 |
+
if (
|
823 |
+
category not in transform_params
|
824 |
+
or category not in bin_dict
|
825 |
+
or not transform_params[category]
|
826 |
+
):
|
827 |
+
continue # Skip categories without transformations
|
828 |
+
|
829 |
+
# Initialize category_df as an empty DataFrame
|
830 |
+
category_df = pd.DataFrame()
|
831 |
+
|
832 |
+
# Slice the DataFrame based on the columns specified in bin_dict for the current category
|
833 |
+
df_slice = df_main[bin_dict[category] + panel].copy()
|
834 |
+
|
835 |
+
# Drop the column from df_slice to skip specific transformations
|
836 |
+
df_slice = df_slice.drop(
|
837 |
+
columns=list(specific_transform_params.keys()), errors="ignore"
|
838 |
+
).copy()
|
839 |
+
|
840 |
+
category_df, df, df_slice_updated = transform_slice(
|
841 |
+
transform_params.copy(),
|
842 |
+
transformation_functions.copy(),
|
843 |
+
panel,
|
844 |
+
df_main.copy(),
|
845 |
+
df_slice.copy(),
|
846 |
+
category,
|
847 |
+
category_df.copy(),
|
848 |
+
)
|
849 |
+
|
850 |
+
# Append the transformed category DataFrame to the list if it's not empty
|
851 |
+
if not category_df.empty:
|
852 |
+
transformed_dfs.append(category_df)
|
853 |
+
|
854 |
+
# Apply channel specific transforms
|
855 |
+
for channel_specific in specific_transform_params:
|
856 |
+
# Initialize category_df as an empty DataFrame
|
857 |
+
category_df = pd.DataFrame()
|
858 |
+
|
859 |
+
df_slice_specific = df_main[[channel_specific] + panel].copy()
|
860 |
+
transform_params_specific = {
|
861 |
+
"Media": specific_transform_params[channel_specific]
|
862 |
+
}
|
863 |
+
|
864 |
+
category_df, df, df_slice_specific_updated = transform_slice(
|
865 |
+
transform_params_specific.copy(),
|
866 |
+
transformation_functions.copy(),
|
867 |
+
panel,
|
868 |
+
df_main.copy(),
|
869 |
+
df_slice_specific.copy(),
|
870 |
+
"Media",
|
871 |
+
category_df.copy(),
|
872 |
+
)
|
873 |
+
|
874 |
+
# Append the transformed category DataFrame to the list if it's not empty
|
875 |
+
if not category_df.empty:
|
876 |
+
transformed_dfs.append(category_df)
|
877 |
+
|
878 |
+
# If category_df has been modified, concatenate it with the panel and response metrics from the original DataFrame
|
879 |
+
if len(transformed_dfs) > 0:
|
880 |
+
final_df = pd.concat([df_main] + transformed_dfs, axis=1)
|
881 |
+
else:
|
882 |
+
# If no transformations were applied, use the original DataFrame
|
883 |
+
final_df = df_main
|
884 |
+
|
885 |
+
# Find columns with '@' in their names
|
886 |
+
columns_with_at = [col for col in final_df.columns if "@" in col]
|
887 |
+
|
888 |
+
# Create a set of columns to drop
|
889 |
+
columns_to_drop = set()
|
890 |
+
|
891 |
+
# Iterate through columns with '@' to find shorter names to drop
|
892 |
+
for col in columns_with_at:
|
893 |
+
base_name = col.split("@")[0]
|
894 |
+
for other_col in columns_with_at:
|
895 |
+
if other_col.startswith(base_name) and len(other_col.split("@")) > len(
|
896 |
+
col.split("@")
|
897 |
+
):
|
898 |
+
columns_to_drop.add(col)
|
899 |
+
break
|
900 |
+
|
901 |
+
# Drop the identified columns from the DataFrame
|
902 |
+
final_df.drop(columns=list(columns_to_drop), inplace=True)
|
903 |
+
|
904 |
+
return final_df
|
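# Minimal sketch (hypothetical column names) of the pruning rule above: when a
# fully chained name exists, the shorter intermediate column is dropped.
def _demo_chained_column_pruning():
    cols = ["tv_spend@Adstock_0.5", "tv_spend@Adstock_0.5@Saturation_20"]
    to_drop = {
        c
        for c in cols
        if any(
            o.startswith(c.split("@")[0]) and len(o.split("@")) > len(c.split("@"))
            for o in cols
        )
    }
    return to_drop  # {"tv_spend@Adstock_0.5"}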
905 |
+
|
906 |
+
|
907 |
+
# Function to infer the granularity of the date column in a DataFrame
|
908 |
+
@st.cache_resource(show_spinner=False)
|
909 |
+
def infer_date_granularity(df):
|
910 |
+
# Find the most common difference
|
911 |
+
common_freq = pd.Series(df["date"].unique()).diff().dt.days.dropna().mode()[0]
|
912 |
+
|
913 |
+
# Map the most common difference to a granularity
|
914 |
+
if common_freq == 1:
|
915 |
+
return "daily"
|
916 |
+
elif common_freq == 7:
|
917 |
+
return "weekly"
|
918 |
+
elif 28 <= common_freq <= 31:
|
919 |
+
return "monthly"
|
920 |
+
else:
|
921 |
+
return "irregular"
|
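# Minimal sketch (hypothetical dates): granularity is inferred from the modal
# day difference, so a 7-day spaced date column maps to "weekly".
def _demo_granularity():
    import pandas as pd

    demo = pd.DataFrame({"date": pd.date_range("2024-01-01", periods=8, freq="7D")})
    return infer_date_granularity(demo)  # "weekly"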
922 |
+
|
923 |
+
|
924 |
+
# Function to clean display DataFrame
|
925 |
+
@st.cache_data(show_spinner=False)
|
926 |
+
def clean_display_df(df, display_max_col=500):
|
927 |
+
# Sort by 'panel' and 'date'
|
928 |
+
sort_columns = ["panel", "date"]
|
929 |
+
sorted_df = df.sort_values(by=sort_columns, ascending=True, na_position="first")
|
930 |
+
|
931 |
+
# Drop duplicate columns
|
932 |
+
sorted_df = sorted_df.loc[:, ~sorted_df.columns.duplicated()]
|
933 |
+
|
934 |
+
# Check if the DataFrame has more than display_max_col columns
|
935 |
+
exceeds_max_col = sorted_df.shape[1] > display_max_col
|
936 |
+
|
937 |
+
if exceeds_max_col:
|
938 |
+
# Create a new DataFrame with 'date' and 'panel' at the start
|
939 |
+
display_df = sorted_df[["date", "panel"]]
|
940 |
+
|
941 |
+
# Add the next display_max_col - 2 columns (as 'date' and 'panel' already occupy 2 columns)
|
942 |
+
additional_columns = sorted_df.columns.difference(["date", "panel"]).tolist()[
|
943 |
+
: display_max_col - 2
|
944 |
+
]
|
945 |
+
display_df = pd.concat([display_df, sorted_df[additional_columns]], axis=1)
|
946 |
+
else:
|
947 |
+
# Ensure 'date' and 'panel' are the first two columns in the final display DataFrame
|
948 |
+
column_order = ["date", "panel"] + sorted_df.columns.difference(
|
949 |
+
["date", "panel"]
|
950 |
+
).tolist()
|
951 |
+
display_df = sorted_df[column_order]
|
952 |
+
|
953 |
+
# Return the display DataFrame and whether it exceeds 500 columns
|
954 |
+
return display_df, exceeds_max_col
|
955 |
+
|
956 |
+
|
957 |
+
#########################################################################################################################################################
|
958 |
+
# User input for transformations
|
959 |
+
#########################################################################################################################################################
|
960 |
+
|
961 |
+
try:
|
962 |
+
# Page Title
|
963 |
+
st.title("AI Model Transformations")
|
964 |
+
|
965 |
+
# Infer date granularity
|
966 |
+
date_granularity = infer_date_granularity(final_df_loaded)
|
967 |
+
|
968 |
+
# Initialize the main dictionary to store the transformation parameters for each category
|
969 |
+
transform_params = {"Media": {}, "Internal": {}, "Exogenous": {}}
|
970 |
+
|
971 |
+
st.markdown("### Select Transformations to Apply")
|
972 |
+
|
973 |
+
with st.expander("Specific Media Transformations"):
|
974 |
+
# Select which transformations to apply
|
975 |
+
sel_channel_specific = st.session_state["project_dct"]["transformations"][
|
976 |
+
"Specific"
|
977 |
+
].get("channel_select_specific", [])
|
978 |
+
|
979 |
+
# Reset default selected channels list if options are changed
|
980 |
+
for channel in sel_channel_specific:
|
981 |
+
if channel not in bin_dict_loaded["Media"]:
|
982 |
+
(
|
983 |
+
st.session_state["project_dct"]["transformations"]["Specific"][
|
984 |
+
"channel_select_specific"
|
985 |
+
],
|
986 |
+
sel_channel_specific,
|
987 |
+
) = ([], [])
|
988 |
+
|
989 |
+
select_specific_channels = st.multiselect(
|
990 |
+
label="Select channel variable",
|
991 |
+
default=sel_channel_specific,
|
992 |
+
options=bin_dict_loaded["Media"],
|
993 |
+
key="channel_select_specific",
|
994 |
+
on_change=channel_select_specific_change,
|
995 |
+
max_selections=30,
|
996 |
+
)
|
997 |
+
|
998 |
+
specific_transform_params = {}
|
999 |
+
for select_specific_channel in select_specific_channels:
|
1000 |
+
specific_transform_params[select_specific_channel] = {}
|
1001 |
+
|
1002 |
+
st.divider()
|
1003 |
+
channel_name = str(select_specific_channel).replace("_", " ").title()
|
1004 |
+
st.markdown(f"###### {channel_name}")
|
1005 |
+
|
1006 |
+
specific_transformation_key = (
|
1007 |
+
f"specific_transformation_{select_specific_channel}_Media"
|
1008 |
+
)
|
1009 |
+
|
1010 |
+
transformations_options = [
|
1011 |
+
"Lag",
|
1012 |
+
"Moving Average",
|
1013 |
+
"Saturation",
|
1014 |
+
"Power",
|
1015 |
+
"Adstock",
|
1016 |
+
]
|
1017 |
+
|
1018 |
+
# Select which transformations to apply
|
1019 |
+
sel_transformations = st.session_state["project_dct"]["transformations"][
|
1020 |
+
"Specific"
|
1021 |
+
].get(specific_transformation_key, [])
|
1022 |
+
|
1023 |
+
# Reset default selected channels list if options are changed
|
1024 |
+
for channel in sel_transformations:
|
1025 |
+
if channel not in transformations_options:
|
1026 |
+
(
|
1027 |
+
st.session_state["project_dct"]["transformations"]["Specific"][
|
1028 |
+
specific_transformation_key
|
1029 |
+
],
|
1030 |
+
sel_channel_specific,
|
1031 |
+
) = ([], [])
|
1032 |
+
|
1033 |
+
transformations_to_apply = st.multiselect(
|
1034 |
+
label="Select transformations to apply",
|
1035 |
+
options=transformations_options,
|
1036 |
+
default=sel_transformations,
|
1037 |
+
key=specific_transformation_key,
|
1038 |
+
on_change=specific_transformation_change,
|
1039 |
+
args=(specific_transformation_key,),
|
1040 |
+
)
|
1041 |
+
|
1042 |
+
# Determine the number of transformations to put in each column
|
1043 |
+
transformations_per_column = (
|
1044 |
+
len(transformations_to_apply) // 2 + len(transformations_to_apply) % 2
|
1045 |
+
)
|
1046 |
+
|
1047 |
+
# Create two columns
|
1048 |
+
col1, col2 = st.columns(2)
|
1049 |
+
|
1050 |
+
# Assign transformations to each column
|
1051 |
+
transformations_col1 = transformations_to_apply[:transformations_per_column]
|
1052 |
+
transformations_col2 = transformations_to_apply[transformations_per_column:]
|
1053 |
+
|
1054 |
+
# Create widgets in each column
|
1055 |
+
create_specific_transformation_widgets(
|
1056 |
+
col1,
|
1057 |
+
transformations_col1,
|
1058 |
+
select_specific_channel,
|
1059 |
+
date_granularity,
|
1060 |
+
specific_transform_params,
|
1061 |
+
)
|
1062 |
+
create_specific_transformation_widgets(
|
1063 |
+
col2,
|
1064 |
+
transformations_col2,
|
1065 |
+
select_specific_channel,
|
1066 |
+
date_granularity,
|
1067 |
+
specific_transform_params,
|
1068 |
+
)
|
1069 |
+
|
1070 |
+
# Create Widgets
|
1071 |
+
for category in ["Media", "Internal", "Exogenous"]:
|
1072 |
+
# Skip Internal
|
1073 |
+
if category == "Internal":
|
1074 |
+
continue
|
1075 |
+
|
1076 |
+
# Skip category if no column available
|
1077 |
+
elif (
|
1078 |
+
category not in bin_dict_loaded.keys()
|
1079 |
+
or len(bin_dict_loaded[category]) == 0
|
1080 |
+
):
|
1081 |
+
st.info(
|
1082 |
+
f"{str(category).title()} category has no column associated with it. Skipping transformation step for this category.",
|
1083 |
+
icon="💬",
|
1084 |
+
)
|
1085 |
+
continue
|
1086 |
+
|
1087 |
+
transformation_widgets(category, transform_params, date_granularity)
|
1088 |
+
|
1089 |
+
#########################################################################################################################################################
|
1090 |
+
# Apply transformations
|
1091 |
+
#########################################################################################################################################################
|
1092 |
+
|
1093 |
+
# Reset transformation selection to default
|
1094 |
+
button_col = st.columns(2)
|
1095 |
+
with button_col[1]:
|
1096 |
+
if st.button("Reset to Default", use_container_width=True):
|
1097 |
+
st.session_state["project_dct"]["transformations"]["Media"] = {}
|
1098 |
+
st.session_state["project_dct"]["transformations"]["Exogenous"] = {}
|
1099 |
+
st.session_state["project_dct"]["transformations"]["Internal"] = {}
|
1100 |
+
st.session_state["project_dct"]["transformations"]["Specific"] = {}
|
1101 |
+
|
1102 |
+
# Log message
|
1103 |
+
log_message(
|
1104 |
+
"info",
|
1105 |
+
"All persistent selections have been reset to their default settings and cleared.",
|
1106 |
+
"Transformations",
|
1107 |
+
)
|
1108 |
+
|
1109 |
+
st.rerun()
|
1110 |
+
|
1111 |
+
# Apply category-based transformations to the DataFrame
|
1112 |
+
with button_col[0]:
|
1113 |
+
if st.button("Accept and Proceed", use_container_width=True):
|
1114 |
+
with st.spinner("Applying transformations ..."):
|
1115 |
+
final_df = apply_category_transformations(
|
1116 |
+
final_df_loaded.copy(),
|
1117 |
+
bin_dict_loaded.copy(),
|
1118 |
+
transform_params.copy(),
|
1119 |
+
panel.copy(),
|
1120 |
+
specific_transform_params.copy(),
|
1121 |
+
)
|
1122 |
+
|
1123 |
+
# Generate a dictionary mapping original column names to lists of transformed column names
|
1124 |
+
transformed_columns_dict, summary_string = generate_transformed_columns(
|
1125 |
+
original_columns, transform_params, specific_transform_params
|
1126 |
+
)
|
1127 |
+
|
1128 |
+
# Store into transformed dataframe and summary session state
|
1129 |
+
st.session_state["project_dct"]["transformations"][
|
1130 |
+
"final_df"
|
1131 |
+
] = final_df
|
1132 |
+
st.session_state["project_dct"]["transformations"][
|
1133 |
+
"summary_string"
|
1134 |
+
] = summary_string
|
1135 |
+
|
1136 |
+
# Display success message
|
1137 |
+
st.success("Transformation of the DataFrame is successful!", icon="✅")
|
1138 |
+
|
1139 |
+
# Log message
|
1140 |
+
log_message(
|
1141 |
+
"info",
|
1142 |
+
"Transformation of the DataFrame is successful!",
|
1143 |
+
"Transformations",
|
1144 |
+
)
|
1145 |
+
|
1146 |
+
#########################################################################################################################################################
|
1147 |
+
# Display the transformed DataFrame and summary
|
1148 |
+
#########################################################################################################################################################
|
1149 |
+
|
1150 |
+
# Display the transformed DataFrame in the Streamlit app
|
1151 |
+
st.markdown("### Transformed DataFrame")
|
1152 |
+
with st.spinner("Please wait while the transformed DataFrame is loading ..."):
|
1153 |
+
final_df = st.session_state["project_dct"]["transformations"]["final_df"].copy()
|
1154 |
+
|
1155 |
+
# Clean display DataFrame
|
1156 |
+
display_df, exceeds_max_col = clean_display_df(final_df, display_max_col)
|
1157 |
+
|
1158 |
+
# Check the number of columns and show only the first display_max_col if there are more
|
1159 |
+
if exceeds_max_col:
|
1160 |
+
# Display an info message if the DataFrame has more than display_max_col columns
|
1161 |
+
st.info(
|
1162 |
+
f"The transformed DataFrame has more than {display_max_col} columns. Displaying only the first {display_max_col} columns.",
|
1163 |
+
icon="💬",
|
1164 |
+
)
|
1165 |
+
|
1166 |
+
# Display Final DataFrame
|
1167 |
+
st.dataframe(
|
1168 |
+
display_df,
|
1169 |
+
hide_index=True,
|
1170 |
+
column_config={
|
1171 |
+
"date": st.column_config.DateColumn("date", format="YYYY-MM-DD")
|
1172 |
+
},
|
1173 |
+
)
|
1174 |
+
|
1175 |
+
# Total rows and columns
|
1176 |
+
total_rows, total_columns = st.session_state["project_dct"]["transformations"][
|
1177 |
+
"final_df"
|
1178 |
+
].shape
|
1179 |
+
st.markdown(
|
1180 |
+
f"<p style='text-align: justify;'>The transformed DataFrame contains <strong>{total_rows}</strong> rows and <strong>{total_columns}</strong> columns.</p>",
|
1181 |
+
unsafe_allow_html=True,
|
1182 |
+
)
|
1183 |
+
|
1184 |
+
# Display the summary of transformations as markdown
|
1185 |
+
if (
|
1186 |
+
"summary_string" in st.session_state["project_dct"]["transformations"]
|
1187 |
+
and st.session_state["project_dct"]["transformations"]["summary_string"]
|
1188 |
+
):
|
1189 |
+
with st.expander("Summary of Transformations"):
|
1190 |
+
st.markdown("### Summary of Transformations")
|
1191 |
+
st.markdown(
|
1192 |
+
st.session_state["project_dct"]["transformations"][
|
1193 |
+
"summary_string"
|
1194 |
+
],
|
1195 |
+
unsafe_allow_html=True,
|
1196 |
+
)
|
1197 |
+
|
1198 |
+
#########################################################################################################################################################
|
1199 |
+
# Correlation Plot
|
1200 |
+
#########################################################################################################################################################
|
1201 |
+
|
1202 |
+
# Filter out the 'date' column
|
1203 |
+
variables = [
|
1204 |
+
col for col in final_df.columns if col.lower() not in ["date", "panel"]
|
1205 |
+
]
|
1206 |
+
|
1207 |
+
with st.expander("Transformed Variable Correlation Plot"):
|
1208 |
+
selected_vars = st.multiselect(
|
1209 |
+
label="Choose variables for correlation plot:",
|
1210 |
+
options=variables,
|
1211 |
+
max_selections=30,
|
1212 |
+
default=st.session_state["project_dct"]["transformations"][
|
1213 |
+
"correlation_plot_selection"
|
1214 |
+
],
|
1215 |
+
key="correlation_plot_key",
|
1216 |
+
)
|
1217 |
+
|
1218 |
+
# Calculate correlation
|
1219 |
+
if selected_vars:
|
1220 |
+
corr_df = final_df[selected_vars].corr()
|
1221 |
+
|
1222 |
+
# Prepare text annotations with 2 decimal places
|
1223 |
+
annotations = []
|
1224 |
+
for i in range(len(corr_df)):
|
1225 |
+
for j in range(len(corr_df.columns)):
|
1226 |
+
annotations.append(
|
1227 |
+
go.layout.Annotation(
|
1228 |
+
text=f"{corr_df.iloc[i, j]:.2f}",
|
1229 |
+
x=corr_df.columns[j],
|
1230 |
+
y=corr_df.index[i],
|
1231 |
+
showarrow=False,
|
1232 |
+
font=dict(color="black"),
|
1233 |
+
)
|
1234 |
+
)
|
1235 |
+
|
1236 |
+
# Plotly correlation plot using go
|
1237 |
+
heatmap = go.Heatmap(
|
1238 |
+
z=corr_df.values,
|
1239 |
+
x=corr_df.columns,
|
1240 |
+
y=corr_df.index,
|
1241 |
+
colorscale="RdBu",
|
1242 |
+
zmin=-1,
|
1243 |
+
zmax=1,
|
1244 |
+
)
|
1245 |
+
|
1246 |
+
layout = go.Layout(
|
1247 |
+
title="Transformed Variable Correlation Plot",
|
1248 |
+
xaxis=dict(title="Variables"),
|
1249 |
+
yaxis=dict(title="Variables"),
|
1250 |
+
width=1000,
|
1251 |
+
height=1000,
|
1252 |
+
annotations=annotations,
|
1253 |
+
)
|
1254 |
+
|
1255 |
+
fig = go.Figure(data=[heatmap], layout=layout)
|
1256 |
+
|
1257 |
+
st.plotly_chart(fig)
|
1258 |
+
else:
|
1259 |
+
st.write("Please select at least one variable to plot.")
|
1260 |
+
|
1261 |
+
#########################################################################################################################################################
|
1262 |
+
# Accept and Save
|
1263 |
+
#########################################################################################################################################################
|
1264 |
+
|
1265 |
+
# Check for saved model
|
1266 |
+
if (
|
1267 |
+
retrieve_pkl_object(
|
1268 |
+
st.session_state["project_number"], "Model_Build", "best_models", schema
|
1269 |
+
)
|
1270 |
+
is not None
|
1271 |
+
): # db
|
1272 |
+
st.warning(
|
1273 |
+
"Saving transformations will overwrite existing ones and delete all saved models. To keep previous models, please start a new project.",
|
1274 |
+
icon="⚠️",
|
1275 |
+
)
|
1276 |
+
|
1277 |
+
if st.button("Accept and Save", use_container_width=True):
|
1278 |
+
|
1279 |
+
with st.spinner("Saving Changes"):
|
1280 |
+
# Update correlation plot selection
|
1281 |
+
st.session_state["project_dct"]["transformations"][
|
1282 |
+
"correlation_plot_selection"
|
1283 |
+
] = st.session_state["correlation_plot_key"]
|
1284 |
+
|
1285 |
+
# Clear model metadata
|
1286 |
+
clear_pages()
|
1287 |
+
|
1288 |
+
# Update DB
|
1289 |
+
update_db(
|
1290 |
+
prj_id=st.session_state["project_number"],
|
1291 |
+
page_nam="Transformations",
|
1292 |
+
file_nam="project_dct",
|
1293 |
+
pkl_obj=pickle.dumps(st.session_state["project_dct"]),
|
1294 |
+
schema=schema,
|
1295 |
+
)
|
1296 |
+
|
1297 |
+
# Clear data from DB
|
1298 |
+
delete_entries(
|
1299 |
+
st.session_state["project_number"],
|
1300 |
+
["Model_Build", "Model_Tuning"],
|
1301 |
+
db_cred,
|
1302 |
+
schema,
|
1303 |
+
)
|
1304 |
+
|
1305 |
+
# Success message
|
1306 |
+
st.success("Saved Successfully!", icon="💾")
|
1307 |
+
st.toast("Saved Successfully!", icon="💾")
|
1308 |
+
|
1309 |
+
# Log message
|
1310 |
+
log_message("info", "Saved Successfully!", "Transformations")
|
1311 |
+
|
1312 |
+
except Exception as e:
|
1313 |
+
# Capture the error details
|
1314 |
+
exc_type, exc_value, exc_traceback = sys.exc_info()
|
1315 |
+
error_message = "".join(
|
1316 |
+
traceback.format_exception(exc_type, exc_value, exc_traceback)
|
1317 |
+
)
|
1318 |
+
|
1319 |
+
# Log message
|
1320 |
+
log_message("error", f"An error occurred: {error_message}.", "Transformations")
|
1321 |
+
|
1322 |
+
# Display a warning message
|
1323 |
+
st.warning(
|
1324 |
+
"Oops! Something went wrong. Please try refreshing the tool or creating a new project.",
|
1325 |
+
icon="⚠️",
|
1326 |
+
)
|
pages/4_AI_Model_Build.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
pages/5_AI Model_Tuning.py
ADDED
@@ -0,0 +1,1215 @@
1 |
+
"""
|
2 |
+
MMO Build Sprint 3
|
3 |
+
date :
|
4 |
+
changes : capability to tune MixedLM as well as simple LR in the same page
|
5 |
+
"""
|
6 |
+
|
7 |
+
import os
|
8 |
+
import streamlit as st
|
9 |
+
import pandas as pd
|
10 |
+
from data_analysis import format_numbers
|
11 |
+
import pickle
|
12 |
+
from utilities import set_header, load_local_css
|
13 |
+
import statsmodels.api as sm
|
14 |
+
import re
|
15 |
+
from sklearn.preprocessing import MaxAbsScaler
|
16 |
+
import matplotlib.pyplot as plt
|
17 |
+
from statsmodels.stats.outliers_influence import variance_inflation_factor
|
18 |
+
import statsmodels.formula.api as smf
|
19 |
+
from data_prep import *
|
20 |
+
import sqlite3
|
21 |
+
from utilities import (
|
22 |
+
set_header,
|
23 |
+
load_local_css,
|
24 |
+
update_db,
|
25 |
+
project_selection,
|
26 |
+
retrieve_pkl_object,
|
27 |
+
)
|
28 |
+
import numpy as np
|
29 |
+
from post_gres_cred import db_cred
|
30 |
+
import re
|
31 |
+
from constants import (
|
32 |
+
NUM_FLAG_COLS_TO_DISPLAY,
|
33 |
+
HALF_YEAR_THRESHOLD,
|
34 |
+
FULL_YEAR_THRESHOLD,
|
35 |
+
TREND_MIN,
|
36 |
+
ANNUAL_FREQUENCY,
|
37 |
+
QTR_FREQUENCY_FACTOR,
|
38 |
+
HALF_YEARLY_FREQUENCY_FACTOR,
|
39 |
+
)
|
40 |
+
from log_application import log_message
|
41 |
+
import sys, traceback
|
42 |
+
|
43 |
+
schema = db_cred["schema"]
|
44 |
+
|
45 |
+
st.set_option("deprecation.showPyplotGlobalUse", False)
|
46 |
+
|
47 |
+
st.set_page_config(
|
48 |
+
page_title="AI Model Tuning",
|
49 |
+
page_icon=":shark:",
|
50 |
+
layout="wide",
|
51 |
+
initial_sidebar_state="collapsed",
|
52 |
+
)
|
53 |
+
load_local_css("styles.css")
|
54 |
+
set_header()
|
55 |
+
|
56 |
+
|
57 |
+
# Define functions
|
58 |
+
# Get random effect from MixedLM Model
|
59 |
+
def get_random_effects(media_data, panel_col, _mdf):
|
60 |
+
# create an empty dataframe
|
61 |
+
random_eff_df = pd.DataFrame(columns=[panel_col, "random_effect"])
|
62 |
+
|
63 |
+
# Iterate over all panel values and add to dataframe
|
64 |
+
for i, market in enumerate(media_data[panel_col].unique()):
|
65 |
+
intercept = _mdf.random_effects[market].values[0]
|
66 |
+
random_eff_df.loc[i, "random_effect"] = intercept
|
67 |
+
random_eff_df.loc[i, panel_col] = market
|
68 |
+
|
69 |
+
return random_eff_df
|
70 |
+
|
71 |
+
|
72 |
+
# Predict on df using MixedLM model
|
73 |
+
def mdf_predict(X_df, mdf, random_eff_df):
|
74 |
+
# Create a copy of the input df and predict using the MixedLM model, i.e. the fixed effect
|
75 |
+
X = X_df.copy()
|
76 |
+
X["fixed_effect"] = mdf.predict(X)
|
77 |
+
|
78 |
+
# Merge random effects
|
79 |
+
X = pd.merge(X, random_eff_df, on=panel_col, how="left")
|
80 |
+
|
81 |
+
# Get final predictions by adding random effect to fixed effect
|
82 |
+
X["pred"] = X["fixed_effect"] + X["random_effect"]
|
83 |
+
|
84 |
+
# Drop intermediate columns
|
85 |
+
X.drop(columns=["fixed_effect", "random_effect"], inplace=True)
|
86 |
+
|
87 |
+
return X["pred"]
|
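# Minimal sketch (hypothetical toy data; assumes the module-level panel_col is
# set to "panel", as it is further down this page): a MixedLM prediction is the
# fixed-effect prediction plus the panel's random intercept.
def _demo_mixedlm_predict():
    import numpy as np
    import pandas as pd
    import statsmodels.formula.api as smf

    rng = np.random.default_rng(0)
    demo = pd.DataFrame(
        {"panel": ["A"] * 10 + ["B"] * 10, "spend": np.tile(np.arange(10.0), 2)}
    )
    demo["sales"] = (
        2 * demo["spend"] + np.repeat([5.0, -5.0], 10) + rng.normal(0, 1, 20)
    )
    mdl = smf.mixedlm("sales ~ spend", data=demo, groups=demo["panel"]).fit()
    rand_df = get_random_effects(demo, "panel", mdl)
    return mdf_predict(demo, mdl, rand_df)  # fixed effect + panel random intercept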
88 |
+
|
89 |
+
|
90 |
+
def format_display(inp):
|
91 |
+
# Format display titles
|
92 |
+
return inp.title().replace("_", " ").strip()
|
93 |
+
|
94 |
+
|
95 |
+
if "username" not in st.session_state:
|
96 |
+
st.session_state["username"] = None
|
97 |
+
if "project_name" not in st.session_state:
|
98 |
+
st.session_state["project_name"] = None
|
99 |
+
if "project_dct" not in st.session_state:
|
100 |
+
project_selection()
|
101 |
+
st.stop()
|
102 |
+
if "Flags" not in st.session_state:
|
103 |
+
st.session_state["Flags"] = {}
|
104 |
+
|
105 |
+
try:
|
106 |
+
# Check Authentications
|
107 |
+
if "username" in st.session_state and st.session_state["username"] is not None:
|
108 |
+
|
109 |
+
if (
|
110 |
+
retrieve_pkl_object(
|
111 |
+
st.session_state["project_number"], "Model_Build", "best_models", schema
|
112 |
+
)
|
113 |
+
is None
|
114 |
+
): # db
|
115 |
+
|
116 |
+
st.error("Please save a model before tuning")
|
117 |
+
log_message(
|
118 |
+
"warning",
|
119 |
+
"No models saved",
|
120 |
+
"Model Tuning",
|
121 |
+
)
|
122 |
+
st.stop()
|
123 |
+
|
124 |
+
# Read previous progress (persistence)
|
125 |
+
if (
|
126 |
+
"session_state_saved"
|
127 |
+
in st.session_state["project_dct"]["model_build"].keys()
|
128 |
+
):
|
129 |
+
for key in [
|
130 |
+
"Model",
|
131 |
+
"date",
|
132 |
+
"saved_model_names",
|
133 |
+
"media_data",
|
134 |
+
"X_test_spends",
|
135 |
+
"spends_data",
|
136 |
+
]:
|
137 |
+
if key not in st.session_state:
|
138 |
+
st.session_state[key] = st.session_state["project_dct"][
|
139 |
+
"model_build"
|
140 |
+
]["session_state_saved"][key]
|
141 |
+
st.session_state["bin_dict"] = st.session_state["project_dct"][
|
142 |
+
"model_build"
|
143 |
+
]["session_state_saved"]["bin_dict"]
|
144 |
+
if (
|
145 |
+
"used_response_metrics" not in st.session_state
|
146 |
+
or st.session_state["used_response_metrics"] == []
|
147 |
+
):
|
148 |
+
st.session_state["used_response_metrics"] = st.session_state[
|
149 |
+
"project_dct"
|
150 |
+
]["model_build"]["session_state_saved"]["used_response_metrics"]
|
151 |
+
else:
|
152 |
+
st.error("Please load a session with a built model")
|
153 |
+
log_message(
|
154 |
+
"error",
|
155 |
+
"Session state saved not found in Project Dictionary",
|
156 |
+
"Model Tuning",
|
157 |
+
)
|
158 |
+
st.stop()
|
159 |
+
|
160 |
+
for key in ["select_all_flags_check", "selected_flags", "sel_model"]:
|
161 |
+
if key not in st.session_state["project_dct"]["model_tuning"].keys():
|
162 |
+
st.session_state["project_dct"]["model_tuning"][key] = {}
|
163 |
+
|
164 |
+
# is_panel = st.session_state['is_panel']
|
165 |
+
# panel_col = 'markets' # set the panel column
|
166 |
+
date_col = "date"
|
167 |
+
|
168 |
+
# set the panel column
|
169 |
+
panel_col = "panel"
|
170 |
+
is_panel = (
|
171 |
+
True if st.session_state["media_data"][panel_col].nunique() > 1 else False
|
172 |
+
)
|
173 |
+
|
174 |
+
if "Model_Tuned" not in st.session_state:
|
175 |
+
st.session_state["Model_Tuned"] = {}
|
176 |
+
cols1 = st.columns([2, 1])
|
177 |
+
with cols1[0]:
|
178 |
+
st.markdown(f"**Welcome {st.session_state['username']}**")
|
179 |
+
with cols1[1]:
|
180 |
+
st.markdown(f"**Current Project: {st.session_state['project_name']}**")
|
181 |
+
|
182 |
+
st.title("AI Model Tuning")
|
183 |
+
|
184 |
+
# Flag indicating that no tuned model exists yet
|
185 |
+
if "is_tuned_model" not in st.session_state:
|
186 |
+
st.session_state["is_tuned_model"] = {}
|
187 |
+
|
188 |
+
# Read all saved models
|
189 |
+
model_dict = retrieve_pkl_object(
|
190 |
+
st.session_state["project_number"], "Model_Build", "best_models", schema
|
191 |
+
)
|
192 |
+
saved_models = model_dict.keys()
|
193 |
+
|
194 |
+
# Get list of response metrics
|
195 |
+
st.session_state["used_response_metrics"] = list(
|
196 |
+
set([model.split("__")[1] for model in saved_models])
|
197 |
+
)
|
198 |
+
|
199 |
+
# Select previously selected response_metric (persistence)
|
200 |
+
default_target_idx = (
|
201 |
+
st.session_state["project_dct"]["model_tuning"].get("sel_target_col", None)
|
202 |
+
if st.session_state["project_dct"]["model_tuning"].get(
|
203 |
+
"sel_target_col", None
|
204 |
+
)
|
205 |
+
is not None
|
206 |
+
else st.session_state["used_response_metrics"][0]
|
207 |
+
)
|
208 |
+
|
209 |
+
# Dropdown to select response metric
|
210 |
+
sel_target_col = st.selectbox(
|
211 |
+
"Select the response metric",
|
212 |
+
st.session_state["used_response_metrics"],
|
213 |
+
index=st.session_state["used_response_metrics"].index(default_target_idx),
|
214 |
+
format_func=format_display,
|
215 |
+
)
|
216 |
+
# Format selected response metrics (target col)
|
217 |
+
target_col = (
|
218 |
+
sel_target_col.lower()
|
219 |
+
.replace(" ", "_")
|
220 |
+
.replace("-", "")
|
221 |
+
.replace(":", "")
|
222 |
+
.replace("__", "_")
|
223 |
+
)
|
224 |
+
st.session_state["project_dct"]["model_tuning"][
|
225 |
+
"sel_target_col"
|
226 |
+
] = sel_target_col
|
227 |
+
|
228 |
+
# Look through all saved models, only show saved models of the selected resp metric (target_col)
|
229 |
+
# Get a list of models saved for selected response metric
|
230 |
+
required_saved_models = [
|
231 |
+
m.split("__")[0] for m in saved_models if m.split("__")[1] == target_col
|
232 |
+
]
|
233 |
+
|
234 |
+
# Get previously selected model if available (persistence)
|
235 |
+
default_model_idx = st.session_state["project_dct"]["model_tuning"][
|
236 |
+
"sel_model"
|
237 |
+
].get(sel_target_col, required_saved_models[0])
|
238 |
+
sel_model = st.selectbox(
|
239 |
+
"Select the model to tune",
|
240 |
+
required_saved_models,
|
241 |
+
index=required_saved_models.index(default_model_idx),
|
242 |
+
)
|
243 |
+
|
244 |
+
st.session_state["project_dct"]["model_tuning"]["sel_model"][
|
245 |
+
sel_target_col
|
246 |
+
] = default_model_idx
|
247 |
+
|
248 |
+
sel_model_dict = model_dict[
|
249 |
+
sel_model + "__" + target_col
|
250 |
+
] # get the model obj of the selected model
|
251 |
+
|
252 |
+
X_train = sel_model_dict["X_train"]
|
253 |
+
X_test = sel_model_dict["X_test"]
|
254 |
+
y_train = sel_model_dict["y_train"]
|
255 |
+
y_test = sel_model_dict["y_test"]
|
256 |
+
df = st.session_state["media_data"]
|
257 |
+
|
258 |
+
st.markdown("### Event Flags")
|
259 |
+
st.markdown("Helps in quantifying the impact of specific occurrences of events")
|
260 |
+
|
261 |
+
try:
|
262 |
+
# Dropdown to add event flags
|
263 |
+
with st.expander("Apply Event Flags"):
|
264 |
+
|
265 |
+
model = sel_model_dict["Model_object"]
|
266 |
+
date = st.session_state["date"]
|
267 |
+
date = pd.to_datetime(date)
|
268 |
+
X_train = sel_model_dict["X_train"]
|
269 |
+
|
270 |
+
features_set = sel_model_dict["feature_set"]
|
271 |
+
|
272 |
+
col = st.columns(3)
|
273 |
+
|
274 |
+
# Get date range
|
275 |
+
min_date = min(date).date()
|
276 |
+
max_date = max(date).date()
|
277 |
+
|
278 |
+
# Get previously selected start and end date of flag (persistence)
|
279 |
+
start_date_default = (
|
280 |
+
st.session_state["project_dct"]["model_tuning"].get(
|
281 |
+
"start_date_default"
|
282 |
+
)
|
283 |
+
if st.session_state["project_dct"]["model_tuning"].get(
|
284 |
+
"start_date_default"
|
285 |
+
)
|
286 |
+
is not None
|
287 |
+
else min_date
|
288 |
+
)
|
289 |
+
start_date_default = (
|
290 |
+
start_date_default if start_date_default > min_date else min_date
|
291 |
+
)
|
292 |
+
start_date_default = (
|
293 |
+
start_date_default if start_date_default < max_date else min_date
|
294 |
+
)
|
295 |
+
|
296 |
+
end_date_default = (
|
297 |
+
st.session_state["project_dct"]["model_tuning"].get(
|
298 |
+
"end_date_default"
|
299 |
+
)
|
300 |
+
if st.session_state["project_dct"]["model_tuning"].get(
|
301 |
+
"end_date_default"
|
302 |
+
)
|
303 |
+
is not None
|
304 |
+
else max_date
|
305 |
+
)
|
306 |
+
end_date_default = (
|
307 |
+
end_date_default if end_date_default > min_date else max_date
|
308 |
+
)
|
309 |
+
end_date_default = (
|
310 |
+
end_date_default if end_date_default < max_date else max_date
|
311 |
+
)
|
312 |
+
|
313 |
+
# Flag start and end date input boxes
|
314 |
+
with col[0]:
|
315 |
+
start_date = st.date_input(
|
316 |
+
"Select Start Date",
|
317 |
+
start_date_default,
|
318 |
+
min_value=min_date,
|
319 |
+
max_value=max_date,
|
320 |
+
)
|
321 |
+
|
322 |
+
if (start_date < min_date) or (start_date > max_date):
|
323 |
+
st.error(
|
324 |
+
"Please select dates in the range of the dates in the data"
|
325 |
+
)
|
326 |
+
st.stop()
|
327 |
+
with col[1]:
|
328 |
+
# Check if end date default > selected start date
|
329 |
+
end_date_default = (
|
330 |
+
end_date_default
|
331 |
+
if pd.Timestamp(end_date_default) >= pd.Timestamp(start_date)
|
332 |
+
else start_date
|
333 |
+
)
|
334 |
+
end_date = st.date_input(
|
335 |
+
"Select End Date",
|
336 |
+
end_date_default,
|
337 |
+
min_value=max(
|
338 |
+
pd.to_datetime(min_date), pd.to_datetime(start_date)
|
339 |
+
),
|
340 |
+
max_value=pd.to_datetime(max_date),
|
341 |
+
)
|
342 |
+
|
343 |
+
if (
|
344 |
+
(start_date < min_date)
|
345 |
+
or (end_date < min_date)
|
346 |
+
or (start_date > max_date)
|
347 |
+
or (end_date > max_date)
|
348 |
+
):
|
349 |
+
st.error(
|
350 |
+
"Please select dates in the range of the dates in the data"
|
351 |
+
)
|
352 |
+
st.stop()
|
353 |
+
if end_date < start_date:
|
354 |
+
st.error("Please select end date after start date")
|
355 |
+
st.stop()
|
356 |
+
with col[2]:
|
357 |
+
# Get default value of repeat check box (persistence)
|
358 |
+
repeat_default = (
|
359 |
+
st.session_state["project_dct"]["model_tuning"].get(
|
360 |
+
"repeat_default"
|
361 |
+
)
|
362 |
+
if st.session_state["project_dct"]["model_tuning"].get(
|
363 |
+
"repeat_default"
|
364 |
+
)
|
365 |
+
is not None
|
366 |
+
else "No"
|
367 |
+
)
|
368 |
+
repeat_default_idx = 0 if repeat_default.lower() == "yes" else 1
|
369 |
+
repeat = st.selectbox(
|
370 |
+
"Repeat Annually", ["Yes", "No"], index=repeat_default_idx
|
371 |
+
)
|
372 |
+
|
373 |
+
# Update selected values to session dictionary (persistence)
|
374 |
+
st.session_state["project_dct"]["model_tuning"][
|
375 |
+
"start_date_default"
|
376 |
+
] = start_date
|
377 |
+
st.session_state["project_dct"]["model_tuning"][
|
378 |
+
"end_date_default"
|
379 |
+
] = end_date
|
380 |
+
st.session_state["project_dct"]["model_tuning"][
|
381 |
+
"repeat_default"
|
382 |
+
] = repeat
|
383 |
+
|
384 |
+
if repeat == "Yes":
|
385 |
+
repeat = True
|
386 |
+
else:
|
387 |
+
repeat = False
|
388 |
+
|
389 |
+
if "flags" in st.session_state["project_dct"]["model_tuning"].keys():
|
390 |
+
st.session_state["Flags"] = st.session_state["project_dct"][
|
391 |
+
"model_tuning"
|
392 |
+
]["flags"]
|
393 |
+
|
394 |
+
if is_panel:
|
395 |
+
# Create flag on Train
|
396 |
+
met, line_values, fig_flag = plot_actual_vs_predicted(
|
397 |
+
X_train[date_col],
|
398 |
+
y_train,
|
399 |
+
model.fittedvalues,
|
400 |
+
model,
|
401 |
+
target_column=sel_target_col,
|
402 |
+
flag=(start_date, end_date),
|
403 |
+
repeat_all_years=repeat,
|
404 |
+
is_panel=True,
|
405 |
+
)
|
406 |
+
st.plotly_chart(fig_flag, use_container_width=True)
|
407 |
+
|
408 |
+
# create flag on test
|
409 |
+
met, test_line_values, fig_flag = plot_actual_vs_predicted(
|
410 |
+
X_test[date_col],
|
411 |
+
y_test,
|
412 |
+
sel_model_dict["pred_test"],
|
413 |
+
model,
|
414 |
+
target_column=sel_target_col,
|
415 |
+
flag=(start_date, end_date),
|
416 |
+
repeat_all_years=repeat,
|
417 |
+
is_panel=True,
|
418 |
+
)
|
419 |
+
|
420 |
+
else:
|
421 |
+
pred_train = model.predict(X_train[features_set])
|
422 |
+
# Create flag on Train
|
423 |
+
met, line_values, fig_flag = plot_actual_vs_predicted(
|
424 |
+
X_train[date_col],
|
425 |
+
y_train,
|
426 |
+
pred_train,
|
427 |
+
model,
|
428 |
+
flag=(start_date, end_date),
|
429 |
+
repeat_all_years=repeat,
|
430 |
+
is_panel=False,
|
431 |
+
)
|
432 |
+
st.plotly_chart(fig_flag, use_container_width=True)
|
433 |
+
|
434 |
+
# create flag on test
|
435 |
+
pred_test = model.predict(X_test[features_set])
|
436 |
+
met, test_line_values, fig_flag = plot_actual_vs_predicted(
|
437 |
+
X_test[date_col],
|
438 |
+
y_test,
|
439 |
+
pred_test,
|
440 |
+
model,
|
441 |
+
flag=(start_date, end_date),
|
442 |
+
repeat_all_years=repeat,
|
443 |
+
is_panel=False,
|
444 |
+
)
|
445 |
+
|
446 |
+
flag_name = "f1_flag"
|
447 |
+
flag_name = st.text_input("Enter Flag Name")
|
448 |
+
|
449 |
+
# add selected target col to flag name
|
450 |
+
# Save the flag name, flag train values, flag test values to session state
|
451 |
+
if st.button("Save flag"):
|
452 |
+
st.session_state["Flags"][flag_name + "_flag__" + target_col] = {}
|
453 |
+
st.session_state["Flags"][flag_name + "_flag__" + target_col][
|
454 |
+
"train"
|
455 |
+
] = line_values
|
456 |
+
st.session_state["Flags"][flag_name + "_flag__" + target_col][
|
457 |
+
"test"
|
458 |
+
] = test_line_values
|
459 |
+
st.success(f'{flag_name + "_flag__" + target_col} stored')
|
460 |
+
|
461 |
+
st.session_state["project_dct"]["model_tuning"]["flags"] = (
|
462 |
+
st.session_state["Flags"]
|
463 |
+
)
|
464 |
+
|
465 |
+
# Only show flags created for the particular target col
|
466 |
+
target_model_flags = [
|
467 |
+
f.split("__")[0]
|
468 |
+
for f in st.session_state["Flags"].keys()
|
469 |
+
if f.split("__")[1] == target_col
|
470 |
+
]
|
471 |
+
options = list(target_model_flags)
|
472 |
+
num_rows = -(-len(options) // NUM_FLAG_COLS_TO_DISPLAY)
|
473 |
+
|
474 |
+
tick = False
|
475 |
+
# Select all flags checkbox
|
476 |
+
if st.checkbox(
|
477 |
+
"Select all",
|
478 |
+
value=st.session_state["project_dct"]["model_tuning"][
|
479 |
+
"select_all_flags_check"
|
480 |
+
].get(sel_target_col, False),
|
481 |
+
):
|
482 |
+
tick = True
|
483 |
+
st.session_state["project_dct"]["model_tuning"][
|
484 |
+
"select_all_flags_check"
|
485 |
+
][sel_target_col] = True
|
486 |
+
else:
|
487 |
+
st.session_state["project_dct"]["model_tuning"][
|
488 |
+
"select_all_flags_check"
|
489 |
+
][sel_target_col] = False
|
490 |
+
|
491 |
+
# Get previous flag selection (persistence)
|
492 |
+
selection_defualts = st.session_state["project_dct"]["model_tuning"][
|
493 |
+
"selected_flags"
|
494 |
+
].get(sel_target_col, [])
|
495 |
+
selected_options = selection_defualts
|
496 |
+
|
497 |
+
# create a checkbox for each available flag for selected response metric
|
498 |
+
for row in range(num_rows):
|
499 |
+
cols = st.columns(NUM_FLAG_COLS_TO_DISPLAY)
|
500 |
+
for col in cols:
|
501 |
+
if options:
|
502 |
+
option = options.pop(0)
|
503 |
+
option_default = True if option in selection_defualts else False
|
504 |
+
selected = col.checkbox(option, value=(tick or option_default))
|
505 |
+
if selected:
|
506 |
+
selected_options.append(option)
|
507 |
+
else:
|
508 |
+
if option in selected_options:
|
509 |
+
selected_options.remove(option)
|
510 |
+
selected_options = list(set(selected_options))
|
511 |
+
|
512 |
+
# Check if flag values match Data
|
513 |
+
# This is necessary because different models can have different train/test dates
|
514 |
+
remove_flags = []
|
515 |
+
for opt in selected_options:
|
516 |
+
train_match = len(
|
517 |
+
st.session_state["Flags"][opt + "__" + target_col]["train"]
|
518 |
+
) == len(X_train[date_col])
|
519 |
+
test_match = len(
|
520 |
+
st.session_state["Flags"][opt + "__" + target_col]["test"]
|
521 |
+
) == len(X_test[date_col])
|
522 |
+
if not train_match:
|
523 |
+
st.warning(f"Flag {opt} can not be used due to train date mismatch")
|
524 |
+
# selected_options.remove(opt)
|
525 |
+
remove_flags.append(opt)
|
526 |
+
if not test_match:
|
527 |
+
st.warning(f"Flag {opt} can not be used due to test date mismatch")
|
528 |
+
# selected_options.remove(opt)
|
529 |
+
remove_flags.append(opt)
|
530 |
+
|
531 |
+
if (
|
532 |
+
len(remove_flags) > 0
|
533 |
+
and len(list(set(selected_options).intersection(set(remove_flags)))) > 0
|
534 |
+
):
|
535 |
+
selected_options = list(set(selected_options) - set(remove_flags))
|
536 |
+
|
537 |
+
st.session_state["project_dct"]["model_tuning"]["selected_flags"][
|
538 |
+
sel_target_col
|
539 |
+
] = selected_options
|
540 |
+
except:
|
541 |
+
# Capture the error details
|
542 |
+
exc_type, exc_value, exc_traceback = sys.exc_info()
|
543 |
+
error_message = "".join(
|
544 |
+
traceback.format_exception(exc_type, exc_value, exc_traceback)
|
545 |
+
)
|
546 |
+
log_message(
|
547 |
+
"error", f"Error while creating flags: {error_message}", "Model Tuning"
|
548 |
+
)
|
549 |
+
st.warning("An error occured, please try again", icon="⚠️")
|
550 |
+
|
551 |
+
try:
|
552 |
+
st.markdown("### Trend and Seasonality Calibration")
|
553 |
+
parameters = st.columns(3)
|
554 |
+
|
555 |
+
# Trend checkbox
|
556 |
+
with parameters[0]:
|
557 |
+
Trend = st.checkbox(
|
558 |
+
"**Trend**",
|
559 |
+
value=st.session_state["project_dct"]["model_tuning"].get(
|
560 |
+
"trend_check", False
|
561 |
+
),
|
562 |
+
)
|
563 |
+
st.markdown(
|
564 |
+
"Helps account for long-term trends or seasonality that could influence advertising effectiveness"
|
565 |
+
)
|
566 |
+
|
567 |
+
# Day of Week (week number) checkbox
|
568 |
+
with parameters[1]:
|
569 |
+
day_of_week = st.checkbox(
|
570 |
+
"**Day of Week**",
|
571 |
+
value=st.session_state["project_dct"]["model_tuning"].get(
|
572 |
+
"week_num_check", False
|
573 |
+
),
|
574 |
+
)
|
575 |
+
st.markdown(
|
576 |
+
"Assists in detecting and incorporating weekly patterns or seasonality"
|
577 |
+
)
|
578 |
+
|
579 |
+
# Sine and cosine Waves checkbox
|
580 |
+
with parameters[2]:
|
581 |
+
sine_cosine = st.checkbox(
|
582 |
+
"**Sine and Cosine Waves**",
|
583 |
+
value=st.session_state["project_dct"]["model_tuning"].get(
|
584 |
+
"sine_cosine_check", False
|
585 |
+
),
|
586 |
+
)
|
587 |
+
st.markdown(
|
588 |
+
"Helps in capturing long term cyclical patterns or seasonality in the data"
|
589 |
+
)
|
590 |
+
|
591 |
+
if sine_cosine:
|
592 |
+
# Drop down to select Frequency of waves
|
593 |
+
xtrain_time_period_months = (
|
594 |
+
X_train[date_col].max() - X_train[date_col].min()
|
595 |
+
).days / 30
|
596 |
+
|
597 |
+
# If we have 6 months of data, only quarter frequency is possible
|
598 |
+
if xtrain_time_period_months <= HALF_YEAR_THRESHOLD:
|
599 |
+
available_frequencies = ["Quarter"]
|
600 |
+
|
601 |
+
# If we have less than 12 months of data, we have quarter and semi-annual frequencies
|
602 |
+
elif xtrain_time_period_months < FULL_YEAR_THRESHOLD:
|
603 |
+
available_frequencies = ["Quarter", "Semi-Annual"]
|
604 |
+
|
605 |
+
# If we have 12 months of data or more, we have quarter, semi-annual and annual frequencies
|
606 |
+
elif xtrain_time_period_months >= FULL_YEAR_THRESHOLD:
|
607 |
+
available_frequencies = ["Quarter", "Semi-Annual", "Annual"]
|
608 |
+
|
609 |
+
wave_freq = st.selectbox("Select Frequency", available_frequencies)
|
610 |
+
except:
|
611 |
+
# Capture the error details
|
612 |
+
exc_type, exc_value, exc_traceback = sys.exc_info()
|
613 |
+
error_message = "".join(
|
614 |
+
traceback.format_exception(exc_type, exc_value, exc_traceback)
|
615 |
+
)
|
616 |
+
log_message(
|
617 |
+
"error",
|
618 |
+
f"Error while selecting tuning parameters: {error_message}",
|
619 |
+
"Model Tuning",
|
620 |
+
)
|
621 |
+
st.warning("An error occured, please try again", icon="⚠️")
|
622 |
+
|
623 |
+
try:
|
624 |
+
# Build tuned model
|
625 |
+
if st.button(
|
626 |
+
"Build model with Selected Parameters and Flags",
|
627 |
+
key="build_tuned_model",
|
628 |
+
use_container_width=True,
|
629 |
+
):
|
630 |
+
new_features = features_set
|
631 |
+
st.header("2.1 Results Summary")
|
632 |
+
ss = MaxAbsScaler()
|
633 |
+
if is_panel == True:
|
634 |
+
X_train_tuned = X_train[features_set]
|
635 |
+
X_train_tuned[target_col] = X_train[target_col]
|
636 |
+
X_train_tuned[date_col] = X_train[date_col]
|
637 |
+
X_train_tuned[panel_col] = X_train[panel_col]
|
638 |
+
|
639 |
+
X_test_tuned = X_test[features_set]
|
640 |
+
X_test_tuned[target_col] = X_test[target_col]
|
641 |
+
X_test_tuned[date_col] = X_test[date_col]
|
642 |
+
X_test_tuned[panel_col] = X_test[panel_col]
|
643 |
+
|
644 |
+
else:
|
645 |
+
X_train_tuned = X_train[features_set]
|
646 |
+
|
647 |
+
X_test_tuned = X_test[features_set]
|
648 |
+
|
649 |
+
for flag in selected_options:
|
650 |
+
# Get the flag values of train and test and add to the data
|
651 |
+
X_train_tuned[flag] = st.session_state["Flags"][
|
652 |
+
flag + "__" + target_col
|
653 |
+
]["train"]
|
654 |
+
X_test_tuned[flag] = st.session_state["Flags"][
|
655 |
+
flag + "__" + target_col
|
656 |
+
]["test"]
|
657 |
+
|
658 |
+
if Trend:
|
659 |
+
st.session_state["project_dct"]["model_tuning"][
|
660 |
+
"trend_check"
|
661 |
+
] = True
|
662 |
+
|
663 |
+
# Group by panel and calculate the trend of each panel separately. Add trend to the new feature set
|
664 |
+
if is_panel:
|
665 |
+
newdata = pd.DataFrame()
|
666 |
+
panel_wise_end_point_train = {}
|
667 |
+
for panel, groupdf in X_train_tuned.groupby(panel_col):
|
668 |
+
groupdf.sort_values(date_col, inplace=True)
|
669 |
+
groupdf["Trend"] = np.arange(
|
670 |
+
TREND_MIN, len(groupdf) + TREND_MIN, 1
|
671 |
+
) # Trend is a straight line with starting point as TREND_MIN
|
672 |
+
newdata = pd.concat([newdata, groupdf])
|
673 |
+
panel_wise_end_point_train[panel] = len(groupdf) + TREND_MIN
|
674 |
+
X_train_tuned = newdata.copy()
|
675 |
+
|
676 |
+
test_newdata = pd.DataFrame()
|
677 |
+
for panel, test_groupdf in X_test_tuned.groupby(panel_col):
|
678 |
+
test_groupdf.sort_values(date_col, inplace=True)
|
679 |
+
start = panel_wise_end_point_train[panel]
|
680 |
+
end = start + len(test_groupdf)
|
681 |
+
test_groupdf["Trend"] = np.arange(start, end, 1)
|
682 |
+
test_newdata = pd.concat([test_newdata, test_groupdf])
|
683 |
+
X_test_tuned = test_newdata.copy()
|
684 |
+
|
685 |
+
new_features = new_features + ["Trend"]
|
686 |
+
|
687 |
+
else:
|
688 |
+
X_train_tuned["Trend"] = np.arange(
|
689 |
+
TREND_MIN, len(X_train_tuned) + TREND_MIN, 1
|
690 |
+
) # Trend is a straight line with starting point as TREND_MIN
|
691 |
+
X_test_tuned["Trend"] = np.arange(
|
692 |
+
len(X_train_tuned) + TREND_MIN,
|
693 |
+
len(X_train_tuned) + len(X_test_tuned) + TREND_MIN,
|
694 |
+
1,
|
695 |
+
)
|
696 |
+
new_features = new_features + ["Trend"]
|
697 |
+
|
698 |
+
else:
|
699 |
+
st.session_state["project_dct"]["model_tuning"][
|
700 |
+
"trend_check"
|
701 |
+
] = False # persistence
|
702 |
+
|
703 |
+
# Add day of week (Week_num) to test & train
|
704 |
+
if day_of_week:
|
705 |
+
st.session_state["project_dct"]["model_tuning"][
|
706 |
+
"week_num_check"
|
707 |
+
] = True
|
708 |
+
|
709 |
+
if is_panel:
|
710 |
+
X_train_tuned[date_col] = pd.to_datetime(
|
711 |
+
X_train_tuned[date_col]
|
712 |
+
)
|
713 |
+
X_train_tuned["day_of_week"] = X_train_tuned[
|
714 |
+
date_col
|
715 |
+
].dt.day_of_week # Day of week
|
716 |
+
# If all dates in the data share the same day-of-week value, this feature can't be used
|
717 |
+
if X_train_tuned["day_of_week"].nunique() == 1:
|
718 |
+
st.error(
|
719 |
+
"All dates in the data are of the same week day. Hence Week number can't be used."
|
720 |
+
)
|
721 |
+
else:
|
722 |
+
X_test_tuned[date_col] = pd.to_datetime(
|
723 |
+
X_test_tuned[date_col]
|
724 |
+
)
|
725 |
+
X_test_tuned["day_of_week"] = X_test_tuned[
|
726 |
+
date_col
|
727 |
+
].dt.day_of_week # Day of week
|
728 |
+
new_features = new_features + ["day_of_week"]
|
729 |
+
|
730 |
+
else:
|
731 |
+
date = pd.to_datetime(date.values)
|
732 |
+
X_train_tuned["day_of_week"] = pd.to_datetime(
|
733 |
+
X_train[date_col]
|
734 |
+
).dt.day_of_week # Day of week
|
735 |
+
X_test_tuned["day_of_week"] = pd.to_datetime(
|
736 |
+
X_test[date_col]
|
737 |
+
).dt.day_of_week # Day of week
|
738 |
+
|
739 |
+
# If all dates in the data share the same day-of-week value, this feature can't be used
|
740 |
+
if X_train_tuned["day_of_week"].nunique() == 1:
|
741 |
+
st.error(
|
742 |
+
"All dates in the data are of the same week day. Hence Week number can't be used."
|
743 |
+
)
|
744 |
+
else:
|
745 |
+
new_features = new_features + ["day_of_week"]
|
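# --- Editor's note ---
# pandas .dt.day_of_week returns 0 (Monday) through 6 (Sunday). With weekly
# data every row falls on the same weekday, so nunique() == 1 and the guard
# above rejects the feature.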
746 |
+
else:
|
747 |
+
st.session_state["project_dct"]["model_tuning"][
|
748 |
+
"week_num_check"
|
749 |
+
] = False
|
750 |
+
|
751 |
+
# create sine and cosine wave and add to data
|
752 |
+
if sine_cosine:
|
753 |
+
st.session_state["project_dct"]["model_tuning"][
|
754 |
+
"sine_cosine_check"
|
755 |
+
] = True
|
756 |
+
frequency = ANNUAL_FREQUENCY # Annual Frequency
|
757 |
+
if wave_freq == "Quarter":
|
758 |
+
frequency = frequency * QTR_FREQUENCY_FACTOR
|
759 |
+
elif wave_freq == "Semi-Annual":
|
760 |
+
frequency = frequency * HALF_YEARLY_FREQUENCY_FACTOR
|
761 |
+
# create panel wise sine cosine waves in xtrain tuned. add to new feature set
|
762 |
+
if is_panel:
|
763 |
+
new_features = new_features + ["sine_wave", "cosine_wave"]
|
764 |
+
newdata = pd.DataFrame()
|
765 |
+
newdata_test = pd.DataFrame()
|
766 |
+
groups = X_train_tuned.groupby(panel_col)
|
767 |
+
|
768 |
+
train_panel_wise_end_point = {}
|
769 |
+
for panel, groupdf in groups:
|
770 |
+
num_samples = len(groupdf)
|
771 |
+
train_panel_wise_end_point[panel] = num_samples
|
772 |
+
days_since_start = np.arange(num_samples)
|
773 |
+
sine_wave = np.sin(frequency * days_since_start)
|
774 |
+
cosine_wave = np.cos(frequency * days_since_start)
|
775 |
+
sine_cosine_df = pd.DataFrame(
|
776 |
+
{"sine_wave": sine_wave, "cosine_wave": cosine_wave}
|
777 |
+
)
|
778 |
+
assert len(sine_cosine_df) == len(groupdf)
|
779 |
+
|
780 |
+
groupdf["sine_wave"] = sine_wave
|
781 |
+
groupdf["cosine_wave"] = cosine_wave
|
782 |
+
newdata = pd.concat([newdata, groupdf])
|
783 |
+
|
784 |
+
X_train_tuned = newdata.copy()
|
785 |
+
|
786 |
+
test_groups = X_test_tuned.groupby(panel_col)
|
787 |
+
for panel, test_groupdf in test_groups:
|
788 |
+
num_samples = len(test_groupdf)
|
789 |
+
start = train_panel_wise_end_point[panel]
|
790 |
+
days_since_start = np.arange(start, start + num_samples, 1)
|
791 |
+
# print("##", panel, num_samples, start, len(np.arange(start, start+num_samples, 1)))
|
792 |
+
sine_wave = np.sin(frequency * days_since_start)
|
793 |
+
cosine_wave = np.cos(frequency * days_since_start)
|
794 |
+
sine_cosine_df = pd.DataFrame(
|
795 |
+
{"sine_wave": sine_wave, "cosine_wave": cosine_wave}
|
796 |
+
)
|
797 |
+
assert len(sine_cosine_df) == len(test_groupdf)
|
798 |
+
# groupdf = pd.concat([groupdf, sine_cosine_df], axis=1)
|
799 |
+
test_groupdf["sine_wave"] = sine_wave
|
800 |
+
test_groupdf["cosine_wave"] = cosine_wave
|
801 |
+
newdata_test = pd.concat([newdata_test, test_groupdf])
|
802 |
+
|
803 |
+
X_test_tuned = newdata_test.copy()
|
804 |
+
|
805 |
+
else:
|
806 |
+
new_features = new_features + ["sine_wave", "cosine_wave"]
|
807 |
+
|
808 |
+
num_samples = len(X_train_tuned)
|
809 |
+
|
810 |
+
days_since_start = np.arange(num_samples)
|
811 |
+
sine_wave = np.sin(frequency * days_since_start)
|
812 |
+
cosine_wave = np.cos(frequency * days_since_start)
|
813 |
+
sine_cosine_df = pd.DataFrame(
|
814 |
+
{"sine_wave": sine_wave, "cosine_wave": cosine_wave}
|
815 |
+
)
|
816 |
+
# Concatenate the sine and cosine waves with the scaled X DataFrame
|
817 |
+
X_train_tuned = pd.concat(
|
818 |
+
[X_train_tuned, sine_cosine_df], axis=1
|
819 |
+
)
|
820 |
+
|
821 |
+
test_num_samples = len(X_test_tuned)
|
822 |
+
start = num_samples
|
823 |
+
days_since_start = np.arange(start, start + test_num_samples, 1)
|
824 |
+
sine_wave = np.sin(frequency * days_since_start)
|
825 |
+
cosine_wave = np.cos(frequency * days_since_start)
|
826 |
+
sine_cosine_df = pd.DataFrame(
|
827 |
+
{"sine_wave": sine_wave, "cosine_wave": cosine_wave}
|
828 |
+
)
|
829 |
+
# Concatenate the sine and cosine waves with the scaled X DataFrame
|
830 |
+
X_test_tuned = pd.concat([X_test_tuned, sine_cosine_df], axis=1)
|
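# --- Editor's note: illustrative sketch. ANNUAL_FREQUENCY and the Quarter /
# Semi-Annual factors are defined in constants.py (not shown here); the values
# below (2*pi/365 for daily data, factors 4 and 2) are assumptions, not the
# app's actual constants.
import numpy as np

ANNUAL_FREQUENCY = 2 * np.pi / 365
QTR_FREQUENCY_FACTOR = 4
HALF_YEARLY_FREQUENCY_FACTOR = 2

def seasonal_waves(n_samples, start=0, wave_freq="Annual"):
    freq = ANNUAL_FREQUENCY
    if wave_freq == "Quarter":
        freq *= QTR_FREQUENCY_FACTOR
    elif wave_freq == "Semi-Annual":
        freq *= HALF_YEARLY_FREQUENCY_FACTOR
    t = np.arange(start, start + n_samples)
    return np.sin(freq * t), np.cos(freq * t)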
831 |
+
|
832 |
+
else:
|
833 |
+
st.session_state["project_dct"]["model_tuning"][
|
834 |
+
"sine_cosine_check"
|
835 |
+
] = False
|
836 |
+
|
837 |
+
# Build model
|
838 |
+
|
839 |
+
# Get list of parameters added and scale
|
840 |
+
# previous features are scaled already during model build
|
841 |
+
added_params = list(set(new_features) - set(features_set))
|
842 |
+
if len(added_params) > 0:
|
843 |
+
concat_df = pd.concat([X_train_tuned, X_test_tuned]).reset_index(
|
844 |
+
drop=True
|
845 |
+
)
|
846 |
+
|
847 |
+
if is_panel:
|
848 |
+
train_max_date = X_train_tuned[date_col].max()
|
849 |
+
|
850 |
+
# concat_df = concat_df.reset_index(drop=True)
|
851 |
+
# concat_df=concat_df[added_params]
|
852 |
+
train_idx = X_train_tuned.index[-1]
|
853 |
+
|
854 |
+
concat_df[added_params] = ss.fit_transform(concat_df[added_params])
|
855 |
+
# added_params_df = pd.DataFrame(added_params_df)
|
856 |
+
# added_params_df.columns = added_params
|
857 |
+
|
858 |
+
if is_panel:
|
859 |
+
X_train_tuned[added_params] = concat_df[
|
860 |
+
concat_df[date_col] <= train_max_date
|
861 |
+
][added_params].reset_index(drop=True)
|
862 |
+
X_test_tuned[added_params] = concat_df[
|
863 |
+
concat_df[date_col] > train_max_date
|
864 |
+
][added_params].reset_index(drop=True)
|
865 |
+
else:
|
866 |
+
added_params_df = concat_df[added_params]
|
867 |
+
X_train_tuned[added_params] = added_params_df[: train_idx + 1]
|
868 |
+
X_test_tuned[added_params] = added_params_df.loc[
|
869 |
+
train_idx + 1 :
|
870 |
+
].reset_index(drop=True)
|
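# --- Editor's note: illustrative sketch of the scaling step above. The newly
# added features (Trend, day_of_week, sine/cosine) are scaled with MaxAbsScaler
# fitted on train and test stacked together, then split back at the train/test
# boundary.
import pandas as pd
from sklearn.preprocessing import MaxAbsScaler

def scale_added_params(train, test, added_params):
    scaler = MaxAbsScaler()
    concat = pd.concat([train, test]).reset_index(drop=True)
    concat[added_params] = scaler.fit_transform(concat[added_params])
    train_out = concat.iloc[: len(train)].reset_index(drop=True)
    test_out = concat.iloc[len(train):].reset_index(drop=True)
    return train_out, test_out, scaler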
871 |
+
|
872 |
+
# Add flags (flags are 0/1 only, so no need to scale)
|
873 |
+
if selected_options:
|
874 |
+
new_features = new_features + selected_options
|
875 |
+
|
876 |
+
# Build Mixed LM model for panel level data
|
877 |
+
if is_panel:
|
878 |
+
X_train_tuned.sort_values([date_col, panel_col]).reset_index(
|
879 |
+
drop=True, inplace=True
|
880 |
+
)
|
881 |
+
|
882 |
+
new_features = list(set(new_features))
|
883 |
+
inp_vars_str = " + ".join(new_features)
|
884 |
+
|
885 |
+
md_str = target_col + " ~ " + inp_vars_str
|
886 |
+
md_tuned = smf.mixedlm(
|
887 |
+
md_str,
|
888 |
+
data=X_train_tuned[[target_col] + new_features],
|
889 |
+
groups=X_train_tuned[panel_col],
|
890 |
+
)
|
891 |
+
model_tuned = md_tuned.fit()
|
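# --- Editor's note: illustrative toy example of the model form used above
# (statsmodels MixedLM with a random intercept per panel, built from a
# "target ~ f1 + f2 + ..." formula). The data below is synthetic; small samples
# may trigger convergence warnings.
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf

rng = np.random.default_rng(0)
toy = pd.DataFrame({
    "panel": np.repeat(["A", "B", "C"], 30),
    "tv": rng.uniform(0, 10, 90),
})
toy["sales"] = 5 + 2 * toy["tv"] + toy["panel"].map({"A": 0, "B": 3, "C": -2}) + rng.normal(0, 1, 90)

fit = smf.mixedlm("sales ~ tv", data=toy, groups=toy["panel"]).fit()
print(fit.fe_params)        # fixed effects (Intercept, tv)
print(fit.random_effects)   # per-panel random intercepts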
892 |
+
|
893 |
+
# plot actual vs predicted for original model and tuned model
|
894 |
+
metrics_table, line, actual_vs_predicted_plot = (
|
895 |
+
plot_actual_vs_predicted(
|
896 |
+
X_train[date_col],
|
897 |
+
y_train,
|
898 |
+
model.fittedvalues,
|
899 |
+
model,
|
900 |
+
target_column=sel_target_col,
|
901 |
+
is_panel=True,
|
902 |
+
)
|
903 |
+
)
|
904 |
+
metrics_table_tuned, line, actual_vs_predicted_plot_tuned = (
|
905 |
+
plot_actual_vs_predicted(
|
906 |
+
X_train_tuned[date_col],
|
907 |
+
X_train_tuned[target_col],
|
908 |
+
model_tuned.fittedvalues,
|
909 |
+
model_tuned,
|
910 |
+
target_column=sel_target_col,
|
911 |
+
is_panel=True,
|
912 |
+
)
|
913 |
+
)
|
914 |
+
|
915 |
+
# Build OLS model for non-panel data
|
916 |
+
else:
|
917 |
+
new_features = list(set(new_features))
|
918 |
+
model_tuned = sm.OLS(y_train, X_train_tuned[new_features]).fit()
|
919 |
+
metrics_table, line, actual_vs_predicted_plot = (
|
920 |
+
plot_actual_vs_predicted(
|
921 |
+
X_train[date_col],
|
922 |
+
y_train,
|
923 |
+
model.predict(X_train[features_set]),
|
924 |
+
model,
|
925 |
+
target_column=sel_target_col,
|
926 |
+
)
|
927 |
+
)
|
928 |
+
|
929 |
+
metrics_table_tuned, line, actual_vs_predicted_plot_tuned = (
|
930 |
+
plot_actual_vs_predicted(
|
931 |
+
X_train[date_col],
|
932 |
+
y_train,
|
933 |
+
model_tuned.predict(X_train_tuned[new_features]),
|
934 |
+
model_tuned,
|
935 |
+
target_column=sel_target_col,
|
936 |
+
)
|
937 |
+
)
|
938 |
+
|
939 |
+
# # ----------------------------------- TESTING -----------------------------------
|
940 |
+
#
|
941 |
+
# Plot Sine & cosine wave to test
|
942 |
+
# sine_cosine_plot = plot_actual_vs_predicted(
|
943 |
+
# X_train[date_col],
|
944 |
+
# y_train,
|
945 |
+
# X_train_tuned['sine_wave'],
|
946 |
+
# model_tuned,
|
947 |
+
# target_column=sel_target_col,
|
948 |
+
# is_panel=True,
|
949 |
+
# )
|
950 |
+
# st.plotly_chart(sine_cosine_plot, use_container_width=True)
|
951 |
+
|
952 |
+
# # Plot Trend line to test
|
953 |
+
# trend_plot = plot_tuned_params(
|
954 |
+
# X_train[date_col],
|
955 |
+
# y_train,
|
956 |
+
# X_train_tuned['Trend'],
|
957 |
+
# model_tuned,
|
958 |
+
# target_column=sel_target_col,
|
959 |
+
# is_panel=True,
|
960 |
+
# )
|
961 |
+
# st.plotly_chart(trend_plot, use_container_width=True)
|
962 |
+
#
|
963 |
+
# # Plot week number to test
|
964 |
+
# week_num_plot = plot_tuned_params(
|
965 |
+
# X_train[date_col],
|
966 |
+
# y_train,
|
967 |
+
# X_train_tuned['day_of_week'],
|
968 |
+
# model_tuned,
|
969 |
+
# target_column=sel_target_col,
|
970 |
+
# is_panel=True,
|
971 |
+
# )
|
972 |
+
# st.plotly_chart(week_num_plot, use_container_width=True)
|
973 |
+
|
974 |
+
# Get model metrics from metric table & display them
|
975 |
+
mape = np.round(metrics_table.iloc[0, 1], 2)
|
976 |
+
r2 = np.round(metrics_table.iloc[1, 1], 2)
|
977 |
+
adjr2 = np.round(metrics_table.iloc[2, 1], 2)
|
978 |
+
|
979 |
+
mape_tuned = np.round(metrics_table_tuned.iloc[0, 1], 2)
|
980 |
+
r2_tuned = np.round(metrics_table_tuned.iloc[1, 1], 2)
|
981 |
+
adjr2_tuned = np.round(metrics_table_tuned.iloc[2, 1], 2)
|
982 |
+
|
983 |
+
parameters_ = st.columns(3)
|
984 |
+
with parameters_[0]:
|
985 |
+
st.metric("R-squared", r2_tuned, np.round(r2_tuned - r2, 2))
|
986 |
+
with parameters_[1]:
|
987 |
+
st.metric(
|
988 |
+
"Adj. R-squared", adjr2_tuned, np.round(adjr2_tuned - adjr2, 2)
|
989 |
+
)
|
990 |
+
with parameters_[2]:
|
991 |
+
st.metric(
|
992 |
+
"MAPE", mape_tuned, np.round(mape_tuned - mape, 2), "inverse"
|
993 |
+
)
|
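# --- Editor's note ---
# st.metric's fourth argument is delta_color; "inverse" flips the colouring so
# that a drop in MAPE (a negative delta) is shown in green, since lower MAPE is
# better, while the R-squared metrics above keep the default colouring.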
994 |
+
|
995 |
+
st.write(model_tuned.summary())
|
996 |
+
|
997 |
+
X_train_tuned[date_col] = X_train[date_col]
|
998 |
+
X_train_tuned[target_col] = y_train
|
999 |
+
X_test_tuned[date_col] = X_test[date_col]
|
1000 |
+
X_test_tuned[target_col] = y_test
|
1001 |
+
|
1002 |
+
st.header("2.2 Actual vs. Predicted Plot (Train)")
|
1003 |
+
if is_panel:
|
1004 |
+
metrics_table, line, actual_vs_predicted_plot = (
|
1005 |
+
plot_actual_vs_predicted(
|
1006 |
+
X_train_tuned[date_col],
|
1007 |
+
X_train_tuned[target_col],
|
1008 |
+
model_tuned.fittedvalues,
|
1009 |
+
model_tuned,
|
1010 |
+
target_column=sel_target_col,
|
1011 |
+
is_panel=True,
|
1012 |
+
)
|
1013 |
+
)
|
1014 |
+
else:
|
1015 |
+
|
1016 |
+
metrics_table, line, actual_vs_predicted_plot = (
|
1017 |
+
plot_actual_vs_predicted(
|
1018 |
+
X_train_tuned[date_col],
|
1019 |
+
X_train_tuned[target_col],
|
1020 |
+
model_tuned.predict(X_train_tuned[new_features]),
|
1021 |
+
model_tuned,
|
1022 |
+
target_column=sel_target_col,
|
1023 |
+
is_panel=False,
|
1024 |
+
)
|
1025 |
+
)
|
1026 |
+
|
1027 |
+
st.plotly_chart(actual_vs_predicted_plot, use_container_width=True)
|
1028 |
+
|
1029 |
+
st.markdown("## 2.3 Residual Analysis (Train)")
|
1030 |
+
if is_panel:
|
1031 |
+
columns = st.columns(2)
|
1032 |
+
with columns[0]:
|
1033 |
+
fig = plot_residual_predicted(
|
1034 |
+
y_train, model_tuned.fittedvalues, X_train_tuned
|
1035 |
+
)
|
1036 |
+
st.plotly_chart(fig)
|
1037 |
+
|
1038 |
+
with columns[1]:
|
1039 |
+
st.empty()
|
1040 |
+
fig = qqplot(y_train, model_tuned.fittedvalues)
|
1041 |
+
st.plotly_chart(fig)
|
1042 |
+
|
1043 |
+
with columns[0]:
|
1044 |
+
fig = residual_distribution(y_train, model_tuned.fittedvalues)
|
1045 |
+
st.pyplot(fig)
|
1046 |
+
else:
|
1047 |
+
columns = st.columns(2)
|
1048 |
+
with columns[0]:
|
1049 |
+
fig = plot_residual_predicted(
|
1050 |
+
y_train,
|
1051 |
+
model_tuned.predict(X_train_tuned[new_features]),
|
1052 |
+
X_train,
|
1053 |
+
)
|
1054 |
+
st.plotly_chart(fig)
|
1055 |
+
|
1056 |
+
with columns[1]:
|
1057 |
+
st.empty()
|
1058 |
+
fig = qqplot(
|
1059 |
+
y_train, model_tuned.predict(X_train_tuned[new_features])
|
1060 |
+
)
|
1061 |
+
st.plotly_chart(fig)
|
1062 |
+
|
1063 |
+
with columns[0]:
|
1064 |
+
fig = residual_distribution(
|
1065 |
+
y_train, model_tuned.predict(X_train_tuned[new_features])
|
1066 |
+
)
|
1067 |
+
st.pyplot(fig)
|
1068 |
+
|
1069 |
+
# st.session_state['is_tuned_model'][target_col] = True
|
1070 |
+
# Save tuned model in a dict
|
1071 |
+
st.session_state["Model_Tuned"][sel_model + "__" + target_col] = {
|
1072 |
+
"Model_object": model_tuned,
|
1073 |
+
"feature_set": new_features,
|
1074 |
+
"X_train_tuned": X_train_tuned,
|
1075 |
+
"X_test_tuned": X_test_tuned,
|
1076 |
+
}
|
1077 |
+
|
1078 |
+
with st.expander("Results Summary Test data"):
|
1079 |
+
if is_panel:
|
1080 |
+
random_eff_df = get_random_effects(
|
1081 |
+
st.session_state.media_data.copy(), panel_col, model_tuned
|
1082 |
+
)
|
1083 |
+
test_pred = mdf_predict(
|
1084 |
+
X_test_tuned, model_tuned, random_eff_df
|
1085 |
+
)
|
1086 |
+
else:
|
1087 |
+
test_pred = model_tuned.predict(X_test_tuned[new_features])
|
1088 |
+
st.header("2.2 Actual vs. Predicted Plot (Test)")
|
1089 |
+
|
1090 |
+
metrics_table, line, actual_vs_predicted_plot = (
|
1091 |
+
plot_actual_vs_predicted(
|
1092 |
+
X_test_tuned[date_col],
|
1093 |
+
y_test,
|
1094 |
+
test_pred,
|
1095 |
+
model,
|
1096 |
+
target_column=sel_target_col,
|
1097 |
+
is_panel=is_panel,
|
1098 |
+
)
|
1099 |
+
)
|
1100 |
+
st.plotly_chart(actual_vs_predicted_plot, use_container_width=True)
|
1101 |
+
st.markdown("## 2.3 Residual Analysis (Test)")
|
1102 |
+
|
1103 |
+
columns = st.columns(2)
|
1104 |
+
with columns[0]:
|
1105 |
+
fig = plot_residual_predicted(y_test, test_pred, X_test_tuned)
|
1106 |
+
st.plotly_chart(fig)
|
1107 |
+
|
1108 |
+
with columns[1]:
|
1109 |
+
st.empty()
|
1110 |
+
fig = qqplot(y_test, test_pred)
|
1111 |
+
st.plotly_chart(fig)
|
1112 |
+
|
1113 |
+
with columns[0]:
|
1114 |
+
fig = residual_distribution(y_test, test_pred)
|
1115 |
+
st.pyplot(fig)
|
1116 |
+
except:
|
1117 |
+
# Capture the error details
|
1118 |
+
exc_type, exc_value, exc_traceback = sys.exc_info()
|
1119 |
+
error_message = "".join(
|
1120 |
+
traceback.format_exception(exc_type, exc_value, exc_traceback)
|
1121 |
+
)
|
1122 |
+
log_message(
|
1123 |
+
"error",
|
1124 |
+
f"Error while building tuned model: {error_message}",
|
1125 |
+
"Model Tuning",
|
1126 |
+
)
|
1127 |
+
st.warning("An error occured, please try again", icon="⚠️")
|
1128 |
+
|
1129 |
+
if (
|
1130 |
+
st.session_state["Model_Tuned"] is not None
|
1131 |
+
and len(list(st.session_state["Model_Tuned"].keys())) > 0
|
1132 |
+
):
|
1133 |
+
if st.button("Use This model for Media Planning", use_container_width=True):
|
1134 |
+
|
1135 |
+
# remove previous tuned models saved for this target col
|
1136 |
+
_remove = [
|
1137 |
+
m
|
1138 |
+
for m in st.session_state["Model_Tuned"].keys()
|
1139 |
+
if m.split("__")[1] == target_col and m.split("__")[0] != sel_model
|
1140 |
+
]
|
1141 |
+
if len(_remove) > 0:
|
1142 |
+
for m in _remove:
|
1143 |
+
del st.session_state["Model_Tuned"][m]
|
1144 |
+
|
1145 |
+
# Flag depicting tuned model for selected response metric
|
1146 |
+
st.session_state["is_tuned_model"][target_col] = True
|
1147 |
+
|
1148 |
+
tuned_model_pkl = pickle.dumps(st.session_state["Model_Tuned"])
|
1149 |
+
|
1150 |
+
update_db(
|
1151 |
+
st.session_state["project_number"],
|
1152 |
+
"Model_Tuning",
|
1153 |
+
"tuned_model",
|
1154 |
+
tuned_model_pkl,
|
1155 |
+
schema,
|
1156 |
+
# resp_mtrc=None,
|
1157 |
+
) # db
|
1158 |
+
|
1159 |
+
log_message(
|
1160 |
+
"info",
|
1161 |
+
f"Tuned model {' '.join(_remove)} removed due to overwrite",
|
1162 |
+
"Model Tuning",
|
1163 |
+
)
|
1164 |
+
|
1165 |
+
# Save session state variables (persistence)
|
1166 |
+
st.session_state["project_dct"]["model_tuning"][
|
1167 |
+
"session_state_saved"
|
1168 |
+
] = {}
|
1169 |
+
for key in [
|
1170 |
+
"bin_dict",
|
1171 |
+
"used_response_metrics",
|
1172 |
+
"is_tuned_model",
|
1173 |
+
"media_data",
|
1174 |
+
"X_test_spends",
|
1175 |
+
"spends_data",
|
1176 |
+
]:
|
1177 |
+
st.session_state["project_dct"]["model_tuning"][
|
1178 |
+
"session_state_saved"
|
1179 |
+
][key] = st.session_state[key]
|
1180 |
+
|
1181 |
+
project_dct_pkl = pickle.dumps(st.session_state["project_dct"])
|
1182 |
+
|
1183 |
+
update_db(
|
1184 |
+
st.session_state["project_number"],
|
1185 |
+
"Model_Tuning",
|
1186 |
+
"project_dct",
|
1187 |
+
project_dct_pkl,
|
1188 |
+
schema,
|
1189 |
+
# resp_mtrc=None,
|
1190 |
+
) # db
|
1191 |
+
|
1192 |
+
log_message(
|
1193 |
+
"info",
|
1194 |
+
f'Tuned Model {sel_model + "__" + target_col} Saved',
|
1195 |
+
"Model Tuning",
|
1196 |
+
)
|
1197 |
+
|
1198 |
+
# Clear page metadata
|
1199 |
+
st.session_state["project_dct"]["scenario_planner"][
|
1200 |
+
"modified_metadata_file"
|
1201 |
+
] = None
|
1202 |
+
st.session_state["project_dct"]["response_curves"][
|
1203 |
+
"modified_metadata_file"
|
1204 |
+
] = None
|
1205 |
+
|
1206 |
+
st.success(sel_model + " for " + target_col + " Tuned saved!")
|
1207 |
+
except:
|
1208 |
+
# Capture the error details
|
1209 |
+
exc_type, exc_value, exc_traceback = sys.exc_info()
|
1210 |
+
error_message = "".join(
|
1211 |
+
traceback.format_exception(exc_type, exc_value, exc_traceback)
|
1212 |
+
)
|
1213 |
+
log_message("error", f"An error has occured : {error_message}", "Model Tuning")
|
1214 |
+
st.warning("An error occured, please try again", icon="⚠️")
|
1215 |
+
# st.write(error_message)
|
pages/6_AI_Model_Validation.py
ADDED
@@ -0,0 +1,960 @@
1 |
+
import plotly.express as px
|
2 |
+
import numpy as np
|
3 |
+
import plotly.graph_objects as go
|
4 |
+
import streamlit as st
|
5 |
+
import pandas as pd
|
6 |
+
import statsmodels.api as sm
|
7 |
+
|
8 |
+
# from sklearn.metrics import mean_absolute_percentage_error
|
9 |
+
import sys
|
10 |
+
import os
|
11 |
+
from utilities import set_header, load_local_css
|
12 |
+
import seaborn as sns
|
13 |
+
import matplotlib.pyplot as plt
|
14 |
+
import tempfile
|
15 |
+
from sklearn.preprocessing import MinMaxScaler
|
16 |
+
|
17 |
+
# from st_aggrid import AgGrid
|
18 |
+
# from st_aggrid import GridOptionsBuilder, GridUpdateMode
|
19 |
+
# from st_aggrid import GridOptionsBuilder
|
20 |
+
import sys
|
21 |
+
import re
|
22 |
+
import pickle
|
23 |
+
from sklearn.metrics import r2_score
|
24 |
+
from data_prep import plot_actual_vs_predicted
|
25 |
+
import sqlite3
|
26 |
+
from utilities import (
|
27 |
+
set_header,
|
28 |
+
load_local_css,
|
29 |
+
update_db,
|
30 |
+
project_selection,
|
31 |
+
retrieve_pkl_object,
|
32 |
+
)
|
33 |
+
from post_gres_cred import db_cred
|
34 |
+
from log_application import log_message
|
35 |
+
import sys, traceback
|
36 |
+
|
37 |
+
schema = db_cred["schema"]
|
38 |
+
|
39 |
+
sys.setrecursionlimit(10**6)
|
40 |
+
|
41 |
+
original_stdout = sys.stdout
|
42 |
+
sys.stdout = open("temp_stdout.txt", "w")
|
43 |
+
sys.stdout.close()
|
44 |
+
sys.stdout = original_stdout
|
45 |
+
|
46 |
+
st.set_page_config(layout="wide")
|
47 |
+
load_local_css("styles.css")
|
48 |
+
set_header()
|
49 |
+
|
50 |
+
|
51 |
+
## DEFINE ALL FUNCTIONS
|
52 |
+
def plot_residual_predicted(actual, predicted, df_):
|
53 |
+
df_["Residuals"] = actual - pd.Series(predicted)
|
54 |
+
df_["StdResidual"] = (df_["Residuals"] - df_["Residuals"].mean()) / df_[
|
55 |
+
"Residuals"
|
56 |
+
].std()
|
57 |
+
|
58 |
+
# Create a Plotly scatter plot
|
59 |
+
fig = px.scatter(
|
60 |
+
df_,
|
61 |
+
x=predicted,
|
62 |
+
y="StdResidual",
|
63 |
+
opacity=0.5,
|
64 |
+
color_discrete_sequence=["#11B6BD"],
|
65 |
+
)
|
66 |
+
|
67 |
+
# Add horizontal lines
|
68 |
+
fig.add_hline(y=0, line_dash="dash", line_color="darkorange")
|
69 |
+
fig.add_hline(y=2, line_color="red")
|
70 |
+
fig.add_hline(y=-2, line_color="red")
|
71 |
+
|
72 |
+
fig.update_xaxes(title="Predicted")
|
73 |
+
fig.update_yaxes(title="Standardized Residuals (Actual - Predicted)")
|
74 |
+
|
75 |
+
# Set the same width and height for both figures
|
76 |
+
fig.update_layout(
|
77 |
+
title="Residuals over Predicted Values",
|
78 |
+
autosize=False,
|
79 |
+
width=600,
|
80 |
+
height=400,
|
81 |
+
)
|
82 |
+
|
83 |
+
return fig
|
84 |
+
|
85 |
+
|
86 |
+
def residual_distribution(actual, predicted):
|
87 |
+
Residuals = actual - pd.Series(predicted)
|
88 |
+
|
89 |
+
# Create a Seaborn distribution plot
|
90 |
+
sns.set(style="whitegrid")
|
91 |
+
plt.figure(figsize=(6, 4))
|
92 |
+
sns.histplot(Residuals, kde=True, color="#11B6BD")
|
93 |
+
|
94 |
+
plt.title(" Distribution of Residuals")
|
95 |
+
plt.xlabel("Residuals")
|
96 |
+
plt.ylabel("Probability Density")
|
97 |
+
|
98 |
+
return plt
|
99 |
+
|
100 |
+
|
101 |
+
def qqplot(actual, predicted):
|
102 |
+
Residuals = actual - pd.Series(predicted)
|
103 |
+
Residuals = pd.Series(Residuals)
|
104 |
+
Resud_std = (Residuals - Residuals.mean()) / Residuals.std()
|
105 |
+
|
106 |
+
# Create a QQ plot using Plotly with custom colors
|
107 |
+
fig = go.Figure()
|
108 |
+
fig.add_trace(
|
109 |
+
go.Scatter(
|
110 |
+
x=sm.ProbPlot(Resud_std).theoretical_quantiles,
|
111 |
+
y=sm.ProbPlot(Resud_std).sample_quantiles,
|
112 |
+
mode="markers",
|
113 |
+
marker=dict(size=5, color="#11B6BD"),
|
114 |
+
name="QQ Plot",
|
115 |
+
)
|
116 |
+
)
|
117 |
+
|
118 |
+
# Add the 45-degree reference line
|
119 |
+
diagonal_line = go.Scatter(
|
120 |
+
x=[
|
121 |
+
-2,
|
122 |
+
2,
|
123 |
+
], # Adjust the x values as needed to fit the range of your data
|
124 |
+
y=[-2, 2], # Adjust the y values accordingly
|
125 |
+
mode="lines",
|
126 |
+
line=dict(color="red"), # Customize the line color and style
|
127 |
+
name=" ",
|
128 |
+
)
|
129 |
+
fig.add_trace(diagonal_line)
|
130 |
+
|
131 |
+
# Customize the layout
|
132 |
+
fig.update_layout(
|
133 |
+
title="QQ Plot of Residuals",
|
134 |
+
title_x=0.5,
|
135 |
+
autosize=False,
|
136 |
+
width=600,
|
137 |
+
height=400,
|
138 |
+
xaxis_title="Theoretical Quantiles",
|
139 |
+
yaxis_title="Sample Quantiles",
|
140 |
+
)
|
141 |
+
|
142 |
+
return fig
|
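# --- Editor's note ---
# The QQ plot standardizes residuals to zero mean and unit variance and plots
# their sample quantiles against theoretical normal quantiles; points close to
# the red 45-degree reference line indicate approximately normal residuals.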
143 |
+
|
144 |
+
|
145 |
+
def get_random_effects(media_data, panel_col, mdf):
|
146 |
+
random_eff_df = pd.DataFrame(columns=[panel_col, "random_effect"])
|
147 |
+
for i, market in enumerate(media_data[panel_col].unique()):
|
148 |
+
print(i, end="\r")
|
149 |
+
intercept = mdf.random_effects[market].values[0]
|
150 |
+
random_eff_df.loc[i, "random_effect"] = intercept
|
151 |
+
random_eff_df.loc[i, panel_col] = market
|
152 |
+
|
153 |
+
return random_eff_df
|
154 |
+
|
155 |
+
|
156 |
+
def mdf_predict(X_df, mdf, random_eff_df):
|
157 |
+
X = X_df.copy()
|
158 |
+
X = pd.merge(
|
159 |
+
X,
|
160 |
+
random_eff_df[[panel_col, "random_effect"]],
|
161 |
+
on=panel_col,
|
162 |
+
how="left",
|
163 |
+
)
|
164 |
+
X["pred_fixed_effect"] = mdf.predict(X)
|
165 |
+
|
166 |
+
X["pred"] = X["pred_fixed_effect"] + X["random_effect"]
|
167 |
+
X.drop(columns=["pred_fixed_effect", "random_effect"], inplace=True)
|
168 |
+
return X
|
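# --- Editor's note (pseudo-usage; `fit` and `new_rows` are hypothetical) ---
# A mixed-effects prediction is the fixed-effects part plus the panel's random
# intercept, which is what mdf_predict() assembles:
#   fixed = fit.predict(new_rows)                                  # X @ beta
#   intercepts = {p: re.values[0] for p, re in fit.random_effects.items()}
#   pred = fixed + new_rows[panel_col].map(intercepts)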
169 |
+
|
170 |
+
|
171 |
+
def metrics_df_panel(model_dict, is_panel):
|
172 |
+
def wmape(actual, forecast):
|
173 |
+
# Weighted MAPE (WMAPE) eliminates the following shortcomings of MAPE & SMAPE
|
174 |
+
## 1. MAPE becomes extremely large when actual values are close to 0
|
175 |
+
## 2. MAPE favours under-forecasting over over-forecasting
|
176 |
+
return np.sum(np.abs(actual - forecast)) / np.sum(np.abs(actual))
|
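# --- Editor's note: worked example of WMAPE = sum(|actual - forecast|) / sum(|actual|).
import numpy as np
actual = np.array([100.0, 0.0, 50.0])
forecast = np.array([110.0, 5.0, 45.0])
print(np.sum(np.abs(actual - forecast)) / np.sum(np.abs(actual)))  # (10+5+5)/150 = 0.1333...
# The ratio stays finite even though one actual is exactly 0, where plain MAPE
# would blow up.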
177 |
+
|
178 |
+
metrics_df = pd.DataFrame(
|
179 |
+
columns=[
|
180 |
+
"Model",
|
181 |
+
"R2",
|
182 |
+
"ADJR2",
|
183 |
+
"Train Mape",
|
184 |
+
"Test Mape",
|
185 |
+
"Summary",
|
186 |
+
"Model_object",
|
187 |
+
]
|
188 |
+
)
|
189 |
+
i = 0
|
190 |
+
for key in model_dict.keys():
|
191 |
+
target = key.split("__")[1]
|
192 |
+
metrics_df.at[i, "Model"] = target
|
193 |
+
y = model_dict[key]["X_train_tuned"][target]
|
194 |
+
|
195 |
+
feature_set = model_dict[key]["feature_set"]
|
196 |
+
|
197 |
+
if is_panel:
|
198 |
+
random_df = get_random_effects(
|
199 |
+
media_data, panel_col, model_dict[key]["Model_object"]
|
200 |
+
)
|
201 |
+
pred = mdf_predict(
|
202 |
+
model_dict[key]["X_train_tuned"],
|
203 |
+
model_dict[key]["Model_object"],
|
204 |
+
random_df,
|
205 |
+
)["pred"]
|
206 |
+
else:
|
207 |
+
pred = model_dict[key]["Model_object"].predict(
|
208 |
+
model_dict[key]["X_train_tuned"][feature_set]
|
209 |
+
)
|
210 |
+
|
211 |
+
ytest = model_dict[key]["X_test_tuned"][target]
|
212 |
+
if is_panel:
|
213 |
+
|
214 |
+
predtest = mdf_predict(
|
215 |
+
model_dict[key]["X_test_tuned"],
|
216 |
+
model_dict[key]["Model_object"],
|
217 |
+
random_df,
|
218 |
+
)["pred"]
|
219 |
+
|
220 |
+
else:
|
221 |
+
predtest = model_dict[key]["Model_object"].predict(
|
222 |
+
model_dict[key]["X_test_tuned"][feature_set]
|
223 |
+
)
|
224 |
+
|
225 |
+
metrics_df.at[i, "R2"] = r2_score(y, pred)
|
226 |
+
metrics_df.at[i, "ADJR2"] = 1 - (1 - metrics_df.loc[i, "R2"]) * (len(y) - 1) / (
|
227 |
+
len(y) - len(model_dict[key]["feature_set"]) - 1
|
228 |
+
)
|
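# --- Editor's note ---
# Adjusted R^2 = 1 - (1 - R^2) * (n - 1) / (n - p - 1), where n is the number
# of training rows and p the number of features; it discounts R^2 for every
# extra regressor, which is what the expression above computes.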
229 |
+
# metrics_df.at[i, "Train Mape"] = mean_absolute_percentage_error(y, pred)
|
230 |
+
# metrics_df.at[i, "Test Mape"] = mean_absolute_percentage_error(
|
231 |
+
# ytest, predtest
|
232 |
+
# )
|
233 |
+
metrics_df.at[i, "Train Mape"] = wmape(y, pred)
|
234 |
+
metrics_df.at[i, "Test Mape"] = wmape(ytest, predtest)
|
235 |
+
metrics_df.at[i, "Summary"] = model_dict[key]["Model_object"].summary()
|
236 |
+
metrics_df.at[i, "Model_object"] = model_dict[key]["Model_object"]
|
237 |
+
i += 1
|
238 |
+
metrics_df = np.round(metrics_df, 2)
|
239 |
+
|
240 |
+
metrics_df.rename(
|
241 |
+
columns={"R2": "R-squared", "ADJR2": "Adj. R-squared"}, inplace=True
|
242 |
+
)
|
243 |
+
return metrics_df
|
244 |
+
|
245 |
+
|
246 |
+
def map_channel(transformed_var, channel_dict):
|
247 |
+
for key, value_list in channel_dict.items():
|
248 |
+
if any(raw_var in transformed_var for raw_var in value_list):
|
249 |
+
return key
|
250 |
+
return transformed_var # Return the original value if no match is found
|
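# --- Editor's note (usage example; names are illustrative) ---
# map_channel("tv_spend_adstock_contr", {"TV": ["tv_spend"], "Search": ["search_clicks"]})
# returns "TV" because the raw variable "tv_spend" is a substring of the
# transformed name; unmatched names are returned unchanged.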
251 |
+
|
252 |
+
|
253 |
+
def contributions_nonpanel(model_dict):
|
254 |
+
# with open(os.path.join(st.session_state["project_path"], "channel_groups.pkl"), "rb") as f:
|
255 |
+
# channels = pickle.load(f)
|
256 |
+
|
257 |
+
channels = st.session_state["project_dct"]["data_import"]["group_dict"] # db
|
258 |
+
media_data = st.session_state["media_data"]
|
259 |
+
contribution_df = pd.DataFrame(columns=["Channel"])
|
260 |
+
|
261 |
+
for key in model_dict.keys():
|
262 |
+
|
263 |
+
best_feature_set = model_dict[key]["feature_set"]
|
264 |
+
model = model_dict[key]["Model_object"]
|
265 |
+
target = key.split("__")[1]
|
266 |
+
X_train = model_dict[key]["X_train_tuned"]
|
267 |
+
contri_df = pd.DataFrame()
|
268 |
+
y = []
|
269 |
+
y_pred = []
|
270 |
+
|
271 |
+
coef_df = pd.DataFrame(model.params)
|
272 |
+
coef_df.reset_index(inplace=True)
|
273 |
+
coef_df.columns = ["feature", "coef"]
|
274 |
+
x_train_contribution = X_train.copy()
|
275 |
+
x_train_contribution["pred"] = model.predict(X_train[best_feature_set])
|
276 |
+
|
277 |
+
for i in range(len(coef_df)):
|
278 |
+
|
279 |
+
coef = coef_df.loc[i, "coef"]
|
280 |
+
col = coef_df.loc[i, "feature"]
|
281 |
+
if col != "const":
|
282 |
+
x_train_contribution[str(col) + "_contr"] = (
|
283 |
+
coef * x_train_contribution[col]
|
284 |
+
)
|
285 |
+
else:
|
286 |
+
x_train_contribution["const"] = coef
|
287 |
+
|
288 |
+
tuning_cols = [
|
289 |
+
c
|
290 |
+
for c in x_train_contribution.filter(regex="contr").columns
|
291 |
+
if c
|
292 |
+
in [
|
293 |
+
"day_of_week_contr",
|
294 |
+
"Trend_contr",
|
295 |
+
"sine_wave_contr",
|
296 |
+
"cosine_wave_contr",
|
297 |
+
]
|
298 |
+
]
|
299 |
+
flag_cols = [
|
300 |
+
c
|
301 |
+
for c in x_train_contribution.filter(regex="contr").columns
|
302 |
+
if "_flag" in c
|
303 |
+
]
|
304 |
+
|
305 |
+
# add exogenous contribution to base
|
306 |
+
all_exog_vars = st.session_state["bin_dict"]["Exogenous"]
|
307 |
+
all_exog_vars = [
|
308 |
+
var.lower()
|
309 |
+
.replace(".", "_")
|
310 |
+
.replace("@", "_")
|
311 |
+
.replace(" ", "_")
|
312 |
+
.replace("-", "")
|
313 |
+
.replace(":", "")
|
314 |
+
.replace("__", "_")
|
315 |
+
for var in all_exog_vars
|
316 |
+
]
|
317 |
+
exog_cols = []
|
318 |
+
if len(all_exog_vars) > 0:
|
319 |
+
for col in x_train_contribution.filter(regex="contr").columns:
|
320 |
+
if len([exog_var for exog_var in all_exog_vars if exog_var in col]) > 0:
|
321 |
+
exog_cols.append(col)
|
322 |
+
|
323 |
+
base_cols = ["const"] + flag_cols + tuning_cols + exog_cols
|
324 |
+
|
325 |
+
x_train_contribution["base_contr"] = x_train_contribution[base_cols].sum(axis=1)
|
326 |
+
x_train_contribution.drop(columns=base_cols, inplace=True)
|
327 |
+
|
328 |
+
contri_df = pd.DataFrame(x_train_contribution.filter(regex="contr").sum(axis=0))
|
329 |
+
|
330 |
+
contri_df.reset_index(inplace=True)
|
331 |
+
contri_df.columns = ["Channel", target]
|
332 |
+
contri_df["Channel"] = contri_df["Channel"].apply(
|
333 |
+
lambda x: map_channel(x, channels)
|
334 |
+
)
|
335 |
+
contri_df[target] = 100 * contri_df[target] / contri_df[target].sum()
|
336 |
+
contri_df["Channel"].replace("base_contr", "base", inplace=True)
|
337 |
+
contribution_df = pd.merge(
|
338 |
+
contribution_df, contri_df, on="Channel", how="outer"
|
339 |
+
)
|
340 |
+
|
341 |
+
return contribution_df
|
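# --- Editor's note: minimal sketch of the decomposition above for an OLS-style
# model. Each feature's contribution is coef * feature value; intercept, flag,
# trend/seasonality and exogenous terms are folded into a "base" bucket and the
# totals are rescaled to percentages. Column/key names are illustrative and X
# is assumed to contain a "const" column of ones.
import pandas as pd

def simple_contributions(X, coefs, base_terms):
    contr = pd.DataFrame({f: coefs[f] * X[f] for f in coefs})
    contr["base"] = contr[base_terms].sum(axis=1)
    contr = contr.drop(columns=base_terms)
    totals = contr.sum(axis=0)
    return 100 * totals / totals.sum()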
342 |
+
|
343 |
+
|
344 |
+
def contributions_panel(model_dict):
|
345 |
+
channels = st.session_state["project_dct"]["data_import"]["group_dict"] # db
|
346 |
+
media_data = st.session_state["media_data"]
|
347 |
+
contribution_df = pd.DataFrame(columns=["Channel"])
|
348 |
+
for key in model_dict.keys():
|
349 |
+
best_feature_set = model_dict[key]["feature_set"]
|
350 |
+
model = model_dict[key]["Model_object"]
|
351 |
+
target = key.split("__")[1]
|
352 |
+
X_train = model_dict[key]["X_train_tuned"]
|
353 |
+
contri_df = pd.DataFrame()
|
354 |
+
|
355 |
+
y = []
|
356 |
+
y_pred = []
|
357 |
+
|
358 |
+
random_eff_df = get_random_effects(media_data, panel_col, model)
|
359 |
+
random_eff_df["fixed_effect"] = model.fe_params["Intercept"]
|
360 |
+
random_eff_df["panel_effect"] = (
|
361 |
+
random_eff_df["random_effect"] + random_eff_df["fixed_effect"]
|
362 |
+
)
|
363 |
+
|
364 |
+
coef_df = pd.DataFrame(model.fe_params)
|
365 |
+
coef_df.reset_index(inplace=True)
|
366 |
+
coef_df.columns = ["feature", "coef"]
|
367 |
+
|
368 |
+
x_train_contribution = X_train.copy()
|
369 |
+
x_train_contribution = mdf_predict(x_train_contribution, model, random_eff_df)
|
370 |
+
|
371 |
+
x_train_contribution = pd.merge(
|
372 |
+
x_train_contribution,
|
373 |
+
random_eff_df[[panel_col, "panel_effect"]],
|
374 |
+
on=panel_col,
|
375 |
+
how="left",
|
376 |
+
)
|
377 |
+
for i in range(len(coef_df)):
|
378 |
+
coef = coef_df.loc[i, "coef"]
|
379 |
+
col = coef_df.loc[i, "feature"]
|
380 |
+
if col.lower() != "intercept":
|
381 |
+
x_train_contribution[str(col) + "_contr"] = (
|
382 |
+
coef * x_train_contribution[col]
|
383 |
+
)
|
384 |
+
|
385 |
+
# x_train_contribution['sum_contributions'] = x_train_contribution.filter(regex="contr").sum(axis=1)
|
386 |
+
# x_train_contribution['sum_contributions'] = x_train_contribution['sum_contributions'] + x_train_contribution[
|
387 |
+
# 'panel_effect']
|
388 |
+
|
389 |
+
# base_cols = ["panel_effect"] + [
|
390 |
+
# c
|
391 |
+
# for c in x_train_contribution.filter(regex="contr").columns
|
392 |
+
# if c
|
393 |
+
# in [
|
394 |
+
# "day_of_week_contr",
|
395 |
+
# "Trend_contr",
|
396 |
+
# "sine_wave_contr",
|
397 |
+
# "cosine_wave_contr",
|
398 |
+
# ]
|
399 |
+
# ]
|
400 |
+
tuning_cols = [
|
401 |
+
c
|
402 |
+
for c in x_train_contribution.filter(regex="contr").columns
|
403 |
+
if c
|
404 |
+
in [
|
405 |
+
"day_of_week_contr",
|
406 |
+
"Trend_contr",
|
407 |
+
"sine_wave_contr",
|
408 |
+
"cosine_wave_contr",
|
409 |
+
]
|
410 |
+
]
|
411 |
+
flag_cols = [
|
412 |
+
c
|
413 |
+
for c in x_train_contribution.filter(regex="contr").columns
|
414 |
+
if "_flag" in c
|
415 |
+
]
|
416 |
+
|
417 |
+
# add exogenous contribution to base
|
418 |
+
all_exog_vars = st.session_state["bin_dict"]["Exogenous"]
|
419 |
+
all_exog_vars = [
|
420 |
+
var.lower()
|
421 |
+
.replace(".", "_")
|
422 |
+
.replace("@", "_")
|
423 |
+
.replace(" ", "_")
|
424 |
+
.replace("-", "")
|
425 |
+
.replace(":", "")
|
426 |
+
.replace("__", "_")
|
427 |
+
for var in all_exog_vars
|
428 |
+
]
|
429 |
+
exog_cols = []
|
430 |
+
if len(all_exog_vars) > 0:
|
431 |
+
for col in x_train_contribution.filter(regex="contr").columns:
|
432 |
+
if len([exog_var for exog_var in all_exog_vars if exog_var in col]) > 0:
|
433 |
+
exog_cols.append(col)
|
434 |
+
|
435 |
+
base_cols = ["panel_effect"] + flag_cols + tuning_cols + exog_cols
|
436 |
+
|
437 |
+
x_train_contribution["base_contr"] = x_train_contribution[base_cols].sum(axis=1)
|
438 |
+
x_train_contribution.drop(columns=base_cols, inplace=True)
|
439 |
+
|
440 |
+
contri_df = pd.DataFrame(x_train_contribution.filter(regex="contr").sum(axis=0))
|
441 |
+
contri_df.reset_index(inplace=True)
|
442 |
+
contri_df.columns = ["Channel", target]
|
443 |
+
|
444 |
+
contri_df[target] = 100 * contri_df[target] / contri_df[target].sum()
|
445 |
+
contri_df["Channel"] = contri_df["Channel"].apply(
|
446 |
+
lambda x: map_channel(x, channels)
|
447 |
+
)
|
448 |
+
|
449 |
+
contri_df["Channel"].replace("base_contr", "base", inplace=True)
|
450 |
+
contribution_df = pd.merge(
|
451 |
+
contribution_df, contri_df, on="Channel", how="outer"
|
452 |
+
)
|
453 |
+
# st.session_state["contribution_df"] = contributions_panel(tuned_model_dict)
|
454 |
+
return contribution_df
|
455 |
+
|
456 |
+
|
457 |
+
def create_grouped_bar_plot(contribution_df, contribution_selections):
|
458 |
+
# Extract the 'Channel' names
|
459 |
+
channel_names = contribution_df["Channel"].tolist()
|
460 |
+
|
461 |
+
# Dictionary to store all contributions except 'const' and 'base'
|
462 |
+
all_contributions = {
|
463 |
+
name: [] for name in channel_names if name not in ["const", "base"]
|
464 |
+
}
|
465 |
+
|
466 |
+
# Dictionary to store base sales for each selection
|
467 |
+
base_sales_dict = {}
|
468 |
+
|
469 |
+
# Accumulate contributions for each channel from each selection
|
470 |
+
for selection in contribution_selections:
|
471 |
+
contributions = contribution_df[selection].values.astype(float)
|
472 |
+
base_sales = 0 # Initialize base sales for the current selection
|
473 |
+
|
474 |
+
for channel_name, contribution in zip(channel_names, contributions):
|
475 |
+
if channel_name in all_contributions:
|
476 |
+
all_contributions[channel_name].append(contribution)
|
477 |
+
elif channel_name == "base":
|
478 |
+
base_sales = (
|
479 |
+
contribution # Capture base sales for the current selection
|
480 |
+
)
|
481 |
+
|
482 |
+
# Store base sales for each selection
|
483 |
+
base_sales_dict[selection] = base_sales
|
484 |
+
|
485 |
+
# Calculate the average of contributions and sort by this average
|
486 |
+
sorted_channels = sorted(all_contributions.items(), key=lambda x: -np.mean(x[1]))
|
487 |
+
sorted_channel_names = [name for name, _ in sorted_channels]
|
488 |
+
sorted_channel_names = [
|
489 |
+
"Base Sales"
|
490 |
+
] + sorted_channel_names # Adding 'Base Sales' at the start
|
491 |
+
|
492 |
+
trace_data = []
|
493 |
+
max_value = 0 # Initialize max_value to find the highest bar for y-axis adjustment
|
494 |
+
|
495 |
+
# Create traces for the grouped bar chart
|
496 |
+
for i, selection in enumerate(contribution_selections):
|
497 |
+
display_name = sorted_channel_names
|
498 |
+
display_contribution = [base_sales_dict[selection]] + [
|
499 |
+
all_contributions[name][i] for name in sorted_channel_names[1:]
|
500 |
+
] # Start with base sales for the current selection
|
501 |
+
|
502 |
+
# Generating text labels for each bar
|
503 |
+
text_values = [
|
504 |
+
f"{val}%" for val in np.round(display_contribution, 0).astype(int)
|
505 |
+
]
|
506 |
+
|
507 |
+
# Find the max value for y-axis calculation
|
508 |
+
max_contribution = max(display_contribution)
|
509 |
+
if max_contribution > max_value:
|
510 |
+
max_value = max_contribution
|
511 |
+
|
512 |
+
# Create a bar trace for each selection
|
513 |
+
trace = go.Bar(
|
514 |
+
x=display_name,
|
515 |
+
y=display_contribution,
|
516 |
+
name=selection,
|
517 |
+
text=text_values,
|
518 |
+
textposition="outside",
|
519 |
+
)
|
520 |
+
trace_data.append(trace)
|
521 |
+
|
522 |
+
# Define layout for the bar chart
|
523 |
+
layout = go.Layout(
|
524 |
+
title="Metrics Contribution by Channel (Train)",
|
525 |
+
xaxis=dict(title="Channel Name"),
|
526 |
+
yaxis=dict(
|
527 |
+
title="Metrics Contribution", range=[0, max_value * 1.2]
|
528 |
+
), # Set y-axis 20% higher than the max bar
|
529 |
+
barmode="group",
|
530 |
+
plot_bgcolor="white",
|
531 |
+
)
|
532 |
+
|
533 |
+
# Create the figure with trace data and layout
|
534 |
+
fig = go.Figure(data=trace_data, layout=layout)
|
535 |
+
|
536 |
+
return fig
|
537 |
+
|
538 |
+
|
539 |
+
def preprocess_and_plot(contribution_df, contribution_selections):
|
540 |
+
# Extract the 'Channel' names
|
541 |
+
channel_names = contribution_df["Channel"].tolist()
|
542 |
+
|
543 |
+
# Dictionary to store all contributions except 'const' and 'base'
|
544 |
+
all_contributions = {
|
545 |
+
name: [] for name in channel_names if name not in ["const", "base"]
|
546 |
+
}
|
547 |
+
|
548 |
+
# Dictionary to store base sales for each selection
|
549 |
+
base_sales_dict = {}
|
550 |
+
|
551 |
+
# Accumulate contributions for each channel from each selection
|
552 |
+
for selection in contribution_selections:
|
553 |
+
contributions = contribution_df[selection].values.astype(float)
|
554 |
+
base_sales = 0 # Initialize base sales for the current selection
|
555 |
+
|
556 |
+
for channel_name, contribution in zip(channel_names, contributions):
|
557 |
+
if channel_name in all_contributions:
|
558 |
+
all_contributions[channel_name].append(contribution)
|
559 |
+
elif channel_name == "base":
|
560 |
+
base_sales = (
|
561 |
+
contribution # Capture base sales for the current selection
|
562 |
+
)
|
563 |
+
|
564 |
+
# Store base sales for each selection
|
565 |
+
base_sales_dict[selection] = base_sales
|
566 |
+
|
567 |
+
# Calculate the average of contributions and sort by this average
|
568 |
+
sorted_channels = sorted(all_contributions.items(), key=lambda x: -np.mean(x[1]))
|
569 |
+
sorted_channel_names = [name for name, _ in sorted_channels]
|
570 |
+
sorted_channel_names = [
|
571 |
+
"Base Sales"
|
572 |
+
] + sorted_channel_names # Adding 'Base Sales' at the start
|
573 |
+
|
574 |
+
# Initialize a Plotly figure
|
575 |
+
fig = go.Figure()
|
576 |
+
|
577 |
+
for i, selection in enumerate(contribution_selections):
|
578 |
+
display_name = ["Base Sales"] + sorted_channel_names[
|
579 |
+
1:
|
580 |
+
] # Channel names for the plot
|
581 |
+
display_contribution = [
|
582 |
+
base_sales_dict[selection]
|
583 |
+
] # Start with base sales for the current selection
|
584 |
+
|
585 |
+
# Append average contributions for other channels
|
586 |
+
for name in sorted_channel_names[1:]:
|
587 |
+
display_contribution.append(all_contributions[name][i])
|
588 |
+
|
589 |
+
# Generating text labels for each bar
|
590 |
+
text_values = [
|
591 |
+
f"{val}%" for val in np.round(display_contribution, 0).astype(int)
|
592 |
+
]
|
593 |
+
|
594 |
+
# Add a waterfall trace for each selection
|
595 |
+
fig.add_trace(
|
596 |
+
go.Waterfall(
|
597 |
+
orientation="v",
|
598 |
+
measure=["relative"] * len(display_contribution),
|
599 |
+
x=display_name,
|
600 |
+
text=text_values,
|
601 |
+
textposition="outside",
|
602 |
+
y=display_contribution,
|
603 |
+
increasing={"marker": {"color": "green"}},
|
604 |
+
decreasing={"marker": {"color": "red"}},
|
605 |
+
totals={"marker": {"color": "blue"}},
|
606 |
+
name=selection,
|
607 |
+
)
|
608 |
+
)
|
609 |
+
|
610 |
+
# Update layout of the figure
|
611 |
+
fig.update_layout(
|
612 |
+
title="Metrics Contribution by Channel (Train)",
|
613 |
+
xaxis={"title": "Channel Name"},
|
614 |
+
yaxis=dict(title="Metrics Contribution", range=[0, 100 * 1.2]),
|
615 |
+
)
|
616 |
+
|
617 |
+
return fig
|
618 |
+
|
619 |
+
|
620 |
+
def selection_change():
|
621 |
+
edited_rows: dict = st.session_state.project_selection["edited_rows"]
|
622 |
+
st.session_state["selected_row_index_gd_table"] = next(iter(edited_rows))
|
623 |
+
st.session_state["gd_table"] = st.session_state["gd_table"].assign(selected=False)
|
624 |
+
|
625 |
+
update_dict = {idx: values for idx, values in edited_rows.items()}
|
626 |
+
|
627 |
+
st.session_state["gd_table"].update(
|
628 |
+
pd.DataFrame.from_dict(update_dict, orient="index")
|
629 |
+
)
|
630 |
+
|
631 |
+
|
632 |
+
if "username" not in st.session_state:
|
633 |
+
st.session_state["username"] = None
|
634 |
+
|
635 |
+
if "project_name" not in st.session_state:
|
636 |
+
st.session_state["project_name"] = None
|
637 |
+
|
638 |
+
if "project_dct" not in st.session_state:
|
639 |
+
project_selection()
|
640 |
+
st.stop()
|
641 |
+
|
642 |
+
try:
|
643 |
+
st.session_state["bin_dict"] = st.session_state["project_dct"]["data_import"][
|
644 |
+
"category_dict"
|
645 |
+
] # db
|
646 |
+
|
647 |
+
except Exception as e:
|
648 |
+
st.warning("Save atleast one tuned model to proceed")
|
649 |
+
log_message("warning", "No tuned models available", "AI Model Results")
|
650 |
+
st.stop()
|
651 |
+
|
652 |
+
|
653 |
+
if "gd_table" not in st.session_state:
|
654 |
+
st.session_state["gd_table"] = pd.DataFrame()
|
655 |
+
|
656 |
+
try:
|
657 |
+
if "username" in st.session_state and st.session_state["username"] is not None:
|
658 |
+
|
659 |
+
if (
|
660 |
+
retrieve_pkl_object(
|
661 |
+
st.session_state["project_number"],
|
662 |
+
"Model_Tuning",
|
663 |
+
"tuned_model",
|
664 |
+
schema,
|
665 |
+
)
|
666 |
+
is None
|
667 |
+
):
|
668 |
+
|
669 |
+
st.error("Please save a tuned model")
|
670 |
+
st.stop()
|
671 |
+
|
672 |
+
if (
|
673 |
+
"session_state_saved"
|
674 |
+
in st.session_state["project_dct"]["model_tuning"].keys()
|
675 |
+
and st.session_state["project_dct"]["model_tuning"]["session_state_saved"]
|
676 |
+
!= []
|
677 |
+
):
|
678 |
+
for key in ["used_response_metrics", "media_data", "bin_dict"]:
|
679 |
+
if key not in st.session_state:
|
680 |
+
st.session_state[key] = st.session_state["project_dct"][
|
681 |
+
"model_tuning"
|
682 |
+
]["session_state_saved"][key]
|
683 |
+
# st.session_state["bin_dict"] = st.session_state["project_dct"][
|
684 |
+
# "model_build"
|
685 |
+
# ]["session_state_saved"]["bin_dict"]
|
686 |
+
|
687 |
+
media_data = st.session_state["media_data"]
|
688 |
+
|
689 |
+
# st.write(media_data.columns)
|
690 |
+
|
691 |
+
# set the panel column
|
692 |
+
panel_col = "panel"
|
693 |
+
is_panel = (
|
694 |
+
True if st.session_state["media_data"][panel_col].nunique() > 1 else False
|
695 |
+
)
|
696 |
+
# st.write(is_panel)
|
697 |
+
|
698 |
+
date_col = "date"
|
699 |
+
|
700 |
+
transformed_data = st.session_state["project_dct"]["transformations"][
|
701 |
+
"final_df"
|
702 |
+
] # db
|
703 |
+
tuned_model_dict = retrieve_pkl_object(
|
704 |
+
st.session_state["project_number"], "Model_Tuning", "tuned_model", schema
|
705 |
+
) # db
|
706 |
+
|
707 |
+
feature_set_dct = {
|
708 |
+
key.split("__")[1]: key_dict["feature_set"]
|
709 |
+
for key, key_dict in tuned_model_dict.items()
|
710 |
+
}
|
711 |
+
|
712 |
+
# """ the above part should be modified so that we are fetching features set from the saved model"""
|
713 |
+
|
714 |
+
if "contribution_df" not in st.session_state:
|
715 |
+
st.session_state["contribution_df"] = None
|
716 |
+
|
717 |
+
metrics_table = metrics_df_panel(tuned_model_dict, is_panel)
|
718 |
+
|
719 |
+
cols1 = st.columns([2, 1])
|
720 |
+
with cols1[0]:
|
721 |
+
st.markdown(f"**Welcome {st.session_state['username']}**")
|
722 |
+
with cols1[1]:
|
723 |
+
st.markdown(f"**Current Project: {st.session_state['project_name']}**")
|
724 |
+
|
725 |
+
st.title("AI Model Validation")
|
726 |
+
|
727 |
+
st.header("Contribution Overview")
|
728 |
+
|
729 |
+
# Get list of response metrics
|
730 |
+
st.session_state["used_response_metrics"] = list(
|
731 |
+
set([model.split("__")[1] for model in tuned_model_dict.keys()])
|
732 |
+
)
|
733 |
+
options = st.session_state["used_response_metrics"]
|
734 |
+
|
735 |
+
if len(options) == 0:
|
736 |
+
st.error("Please save and tune a model")
|
737 |
+
st.stop()
|
738 |
+
options = [
|
739 |
+
opt.lower()
|
740 |
+
.replace(" ", "_")
|
741 |
+
.replace("-", "")
|
742 |
+
.replace(":", "")
|
743 |
+
.replace("__", "_")
|
744 |
+
for opt in options
|
745 |
+
]
|
746 |
+
|
747 |
+
default_options = (
|
748 |
+
st.session_state["project_dct"]["saved_model_results"].get(
|
749 |
+
"selected_options"
|
750 |
+
)
|
751 |
+
if st.session_state["project_dct"]["saved_model_results"].get(
|
752 |
+
"selected_options"
|
753 |
+
)
|
754 |
+
is not None
|
755 |
+
else [options[-1]]
|
756 |
+
)
|
757 |
+
for i in list(default_options):  # iterate over a copy so removal inside the loop is safe
|
758 |
+
if i not in options:
|
759 |
+
# st.write(i)
|
760 |
+
default_options.remove(i)
|
761 |
+
|
762 |
+
def remove_response_metric(name):
|
763 |
+
# Convert the name to a lowercase string and remove any leading or trailing spaces
|
764 |
+
name_str = str(name).lower().strip()
|
765 |
+
|
766 |
+
# Check if the name starts with "response metric" or "response_metric"
|
767 |
+
if name_str.startswith("response metric"):
|
768 |
+
return name[len("response metric") :].replace("_", " ").strip().title()
|
769 |
+
elif name_str.startswith("response_metric"):
|
770 |
+
return name[len("response_metric") :].replace("_", " ").strip().title()
|
771 |
+
else:
|
772 |
+
return name
|
773 |
+
|
774 |
+
contribution_selections = st.multiselect(
|
775 |
+
"Select the Response Metrics to compare contributions",
|
776 |
+
options,
|
777 |
+
default=default_options,
|
778 |
+
format_func=remove_response_metric,
|
779 |
+
)
|
780 |
+
trace_data = []
|
781 |
+
|
782 |
+
if is_panel:
|
783 |
+
st.session_state["contribution_df"] = contributions_panel(tuned_model_dict)
|
784 |
+
|
785 |
+
else:
|
786 |
+
st.session_state["contribution_df"] = contributions_nonpanel(
|
787 |
+
tuned_model_dict
|
788 |
+
)
|
789 |
+
|
790 |
+
# st.write(st.session_state["contribution_df"].columns)
|
791 |
+
# for selection in contribution_selections:
|
792 |
+
|
793 |
+
# trace = go.Bar(
|
794 |
+
# x=st.session_state["contribution_df"]["Channel"],
|
795 |
+
# y=st.session_state["contribution_df"][selection],
|
796 |
+
# name=selection,
|
797 |
+
# text=np.round(st.session_state["contribution_df"][selection], 0)
|
798 |
+
# .astype(int)
|
799 |
+
# .astype(str)
|
800 |
+
# + "%",
|
801 |
+
# textposition="outside",
|
802 |
+
# )
|
803 |
+
# trace_data.append(trace)
|
804 |
+
|
805 |
+
# layout = go.Layout(
|
806 |
+
# title="Metrics Contribution by Channel",
|
807 |
+
# xaxis=dict(title="Channel Name"),
|
808 |
+
# yaxis=dict(title="Metrics Contribution"),
|
809 |
+
# barmode="group",
|
810 |
+
# )
|
811 |
+
# fig = go.Figure(data=trace_data, layout=layout)
|
812 |
+
# st.plotly_chart(fig, use_container_width=True)
|
813 |
+
|
814 |
+
# Display the chart in Streamlit
|
815 |
+
st.plotly_chart(
|
816 |
+
create_grouped_bar_plot(
|
817 |
+
st.session_state["contribution_df"], contribution_selections
|
818 |
+
),
|
819 |
+
use_container_width=True,
|
820 |
+
)
|
821 |
+
|
822 |
+
############################################ Waterfall Chart ############################################
|
823 |
+
|
824 |
+
import plotly.graph_objects as go
|
825 |
+
|
826 |
+
st.plotly_chart(
|
827 |
+
preprocess_and_plot(
|
828 |
+
st.session_state["contribution_df"], contribution_selections
|
829 |
+
),
|
830 |
+
use_container_width=True,
|
831 |
+
)
|
832 |
+
|
833 |
+
############################################ Waterfall Chart ############################################
|
834 |
+
st.header("Analysis of Models Result")
|
835 |
+
gd_table = metrics_table.iloc[:, :-2]
|
836 |
+
target_column = gd_table.at[0, "Model"] # sprint8
|
837 |
+
st.session_state["gd_table"] = gd_table
|
838 |
+
|
839 |
+
with st.container():
|
840 |
+
table = st.data_editor(
|
841 |
+
st.session_state["gd_table"],
|
842 |
+
hide_index=True,
|
843 |
+
# on_change=selection_change,
|
844 |
+
key="project_selection",
|
845 |
+
use_container_width=True,
|
846 |
+
)
|
847 |
+
|
848 |
+
target_column = st.selectbox(
|
849 |
+
"Select a Model to analyse its results",
|
850 |
+
options=st.session_state.used_response_metrics,
|
851 |
+
placeholder=options[0],
|
852 |
+
)
|
853 |
+
feature_set = feature_set_dct[target_column]
|
854 |
+
|
855 |
+
model = metrics_table[metrics_table["Model"] == target_column][
|
856 |
+
"Model_object"
|
857 |
+
].iloc[0]
|
858 |
+
target = metrics_table[metrics_table["Model"] == target_column]["Model"].iloc[0]
|
859 |
+
st.header("Model Summary")
|
860 |
+
st.write(model.summary())
|
861 |
+
|
862 |
+
sel_dict = tuned_model_dict[
|
863 |
+
[k for k in tuned_model_dict.keys() if k.split("__")[1] == target][0]
|
864 |
+
]
|
865 |
+
|
866 |
+
feature_set = sel_dict["feature_set"]
|
867 |
+
X_train = sel_dict["X_train_tuned"]
|
868 |
+
y_train = X_train[target]
|
869 |
+
|
870 |
+
if is_panel:
|
871 |
+
random_effects = get_random_effects(media_data, panel_col, model)
|
872 |
+
pred = mdf_predict(X_train, model, random_effects)["pred"]
|
873 |
+
else:
|
874 |
+
pred = model.predict(X_train[feature_set])
|
875 |
+
|
876 |
+
X_test = sel_dict["X_test_tuned"]
|
877 |
+
y_test = X_test[target]
|
878 |
+
if is_panel:
|
879 |
+
predtest = mdf_predict(X_test, model, random_effects)["pred"]
|
880 |
+
else:
|
881 |
+
predtest = model.predict(X_test[feature_set])
|
882 |
+
|
883 |
+
metrics_table_train, _, fig_train = plot_actual_vs_predicted(
|
884 |
+
X_train[date_col],
|
885 |
+
y_train,
|
886 |
+
pred,
|
887 |
+
model,
|
888 |
+
target_column=target,
|
889 |
+
flag=None,
|
890 |
+
repeat_all_years=False,
|
891 |
+
is_panel=is_panel,
|
892 |
+
)
|
893 |
+
|
894 |
+
metrics_table_test, _, fig_test = plot_actual_vs_predicted(
|
895 |
+
X_test[date_col],
|
896 |
+
y_test,
|
897 |
+
predtest,
|
898 |
+
model,
|
899 |
+
target_column=target,
|
900 |
+
flag=None,
|
901 |
+
repeat_all_years=False,
|
902 |
+
is_panel=is_panel,
|
903 |
+
)
|
904 |
+
|
905 |
+
metrics_table_train = metrics_table_train.set_index("Metric").transpose()
|
906 |
+
metrics_table_train.index = ["Train"]
|
907 |
+
metrics_table_test = metrics_table_test.set_index("Metric").transpose()
|
908 |
+
metrics_table_test.index = ["Test"]
|
909 |
+
metrics_table = np.round(
|
910 |
+
pd.concat([metrics_table_train, metrics_table_test]), 2
|
911 |
+
)
|
912 |
+
|
913 |
+
st.markdown("Result Overview")
|
914 |
+
st.dataframe(np.round(metrics_table, 2), use_container_width=True)
|
915 |
+
|
916 |
+
st.header("Model Accuracy")
|
917 |
+
st.subheader("Actual vs Predicted Plot (Train)")
|
918 |
+
|
919 |
+
st.plotly_chart(fig_train, use_container_width=True)
|
920 |
+
st.subheader("Actual vs Predicted Plot (Test)")
|
921 |
+
st.plotly_chart(fig_test, use_container_width=True)
|
922 |
+
|
923 |
+
st.markdown("## Residual Analysis (Train)")
|
924 |
+
columns = st.columns(2)
|
925 |
+
|
926 |
+
Xtrain1 = X_train.copy()
|
927 |
+
with columns[0]:
|
928 |
+
fig = plot_residual_predicted(y_train, pred, Xtrain1)
|
929 |
+
st.plotly_chart(fig)
|
930 |
+
|
931 |
+
with columns[1]:
|
932 |
+
st.empty()
|
933 |
+
fig = qqplot(y_train, pred)
|
934 |
+
st.plotly_chart(fig)
|
935 |
+
|
936 |
+
with columns[0]:
|
937 |
+
fig = residual_distribution(y_train, pred)
|
938 |
+
st.pyplot(fig)
|
939 |
+
|
940 |
+
if st.button("Save this session", use_container_width=True):
|
941 |
+
project_dct_pkl = pickle.dumps(st.session_state["project_dct"])
|
942 |
+
|
943 |
+
update_db(
|
944 |
+
st.session_state["project_number"],
|
945 |
+
"AI_Model_Results",
|
946 |
+
"project_dct",
|
947 |
+
project_dct_pkl,
|
948 |
+
schema,
|
949 |
+
# resp_mtrc=None,
|
950 |
+
) # db
|
951 |
+
|
952 |
+
log_message("info", "Session saved!", "AI Model Results")
|
953 |
+
st.success("Session Saved!")
|
954 |
+
except:
|
955 |
+
exc_type, exc_value, exc_traceback = sys.exc_info()
|
956 |
+
error_message = "".join(
|
957 |
+
traceback.format_exception(exc_type, exc_value, exc_traceback)
|
958 |
+
)
|
959 |
+
log_message("error", f"Error: {error_message}", "AI Model Results")
|
960 |
+
st.warning("An error occured, please try again", icon="⚠️")
|
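Note (not part of the uploaded files): when `is_panel` is true, the validation flow above predicts by adding each panel's random intercept to the mixed model's fixed-effect prediction, mirroring the `get_random_effects` / `mdf_predict` helpers that appear in pages/7 below. The following is a minimal, self-contained sketch of that pattern using statsmodels; the toy columns `panel`, `spend`, `sales` are illustrative and not from the project data.

```python
# Hedged sketch of the panel-level prediction pattern: fixed-effect prediction + per-panel random intercept.
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf

rng = np.random.default_rng(0)
df = pd.DataFrame(
    {
        "panel": np.repeat(["p1", "p2", "p3"], 50),
        "spend": rng.uniform(0, 1, 150),
    }
)
df["sales"] = 2.0 + 0.8 * df["spend"] + rng.normal(0, 0.1, 150)

# Fit a mixed model with a random intercept per panel (stands in for the tuned model object).
mdf = smf.mixedlm("sales ~ spend", df, groups=df["panel"]).fit()

# Per-panel random intercepts, as in get_random_effects().
random_eff = pd.DataFrame(
    {
        "panel": list(mdf.random_effects.keys()),
        "random_effect": [v.values[0] for v in mdf.random_effects.values()],
    }
)

# Prediction = fixed-effect prediction + the panel's random effect, as in mdf_predict().
pred_df = df.merge(random_eff, on="panel", how="left")
pred_df["pred"] = mdf.predict(pred_df) + pred_df["random_effect"]
print(pred_df[["panel", "sales", "pred"]].head())
```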
pages/7_AI_Model_Media_Performance.py
ADDED
@@ -0,0 +1,733 @@
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
from sklearn.preprocessing import MinMaxScaler
|
4 |
+
import pickle
|
5 |
+
import os
|
6 |
+
from utilities_with_panel import load_local_css, set_header
|
7 |
+
import yaml
|
8 |
+
from yaml import SafeLoader
|
9 |
+
import sqlite3
|
10 |
+
from datetime import timedelta
|
11 |
+
from utilities import (
|
12 |
+
set_header,
|
13 |
+
load_local_css,
|
14 |
+
update_db,
|
15 |
+
project_selection,
|
16 |
+
retrieve_pkl_object,
|
17 |
+
)
|
18 |
+
from utilities_with_panel import (
|
19 |
+
overview_test_data_prep_panel,
|
20 |
+
overview_test_data_prep_nonpanel,
|
21 |
+
initialize_data_cmp,
|
22 |
+
create_channel_summary,
|
23 |
+
create_contribution_pie,
|
24 |
+
create_contribuion_stacked_plot,
|
25 |
+
create_channel_spends_sales_plot,
|
26 |
+
format_numbers,
|
27 |
+
channel_name_formating,
|
28 |
+
)
|
29 |
+
from log_application import log_message
|
30 |
+
import sys, traceback
|
31 |
+
from post_gres_cred import db_cred
|
32 |
+
|
33 |
+
st.set_page_config(layout="wide")
|
34 |
+
load_local_css("styles.css")
|
35 |
+
set_header()
|
36 |
+
|
37 |
+
|
38 |
+
schema = db_cred["schema"]
|
39 |
+
|
40 |
+
if "username" not in st.session_state:
|
41 |
+
st.session_state["username"] = None
|
42 |
+
|
43 |
+
if "project_name" not in st.session_state:
|
44 |
+
st.session_state["project_name"] = None
|
45 |
+
|
46 |
+
if "project_dct" not in st.session_state:
|
47 |
+
project_selection()
|
48 |
+
st.stop()
|
49 |
+
|
50 |
+
tuned_model = retrieve_pkl_object(
|
51 |
+
st.session_state["project_number"], "Model_Tuning", "tuned_model", schema
|
52 |
+
)
|
53 |
+
|
54 |
+
if tuned_model is None:
|
55 |
+
st.error("Please save a tuned model")
|
56 |
+
st.stop()
|
57 |
+
|
58 |
+
if (
|
59 |
+
"session_state_saved" in st.session_state["project_dct"]["model_tuning"].keys()
|
60 |
+
and st.session_state["project_dct"]["model_tuning"]["session_state_saved"] != []
|
61 |
+
):
|
62 |
+
for key in ["used_response_metrics", "media_data", "bin_dict"]:
|
63 |
+
if key not in st.session_state:
|
64 |
+
st.session_state[key] = st.session_state["project_dct"]["model_tuning"][
|
65 |
+
"session_state_saved"
|
66 |
+
][key]
|
67 |
+
|
68 |
+
|
69 |
+
## DEFINE ALL FUNCTIONS
|
70 |
+
def get_random_effects(media_data, panel_col, mdf):
|
71 |
+
random_eff_df = pd.DataFrame(columns=[panel_col, "random_effect"])
|
72 |
+
for i, market in enumerate(media_data[panel_col].unique()):
|
73 |
+
print(i, end="\r")
|
74 |
+
intercept = mdf.random_effects[market].values[0]
|
75 |
+
random_eff_df.loc[i, "random_effect"] = intercept
|
76 |
+
random_eff_df.loc[i, panel_col] = market
|
77 |
+
|
78 |
+
return random_eff_df
|
79 |
+
|
80 |
+
|
81 |
+
def process_train_and_test(train, test, features, panel_col, target_col):
|
82 |
+
X1 = train[features]
|
83 |
+
|
84 |
+
ss = MinMaxScaler()
|
85 |
+
X1 = pd.DataFrame(ss.fit_transform(X1), columns=X1.columns)
|
86 |
+
|
87 |
+
X1[panel_col] = train[panel_col]
|
88 |
+
X1[target_col] = train[target_col]
|
89 |
+
|
90 |
+
if test is not None:
|
91 |
+
X2 = test[features]
|
92 |
+
X2 = pd.DataFrame(ss.transform(X2), columns=X2.columns)
|
93 |
+
X2[panel_col] = test[panel_col]
|
94 |
+
X2[target_col] = test[target_col]
|
95 |
+
return X1, X2
|
96 |
+
return X1
|
97 |
+
|
98 |
+
|
99 |
+
def mdf_predict(X_df, mdf, random_eff_df):
|
100 |
+
X = X_df.copy()
|
101 |
+
X = pd.merge(
|
102 |
+
X,
|
103 |
+
random_eff_df[[panel_col, "random_effect"]],
|
104 |
+
on=panel_col,
|
105 |
+
how="left",
|
106 |
+
)
|
107 |
+
X["pred_fixed_effect"] = mdf.predict(X)
|
108 |
+
|
109 |
+
X["pred"] = X["pred_fixed_effect"] + X["random_effect"]
|
110 |
+
X.drop(columns=["pred_fixed_effect", "random_effect"], inplace=True)
|
111 |
+
|
112 |
+
return X
|
113 |
+
|
114 |
+
|
115 |
+
try:
|
116 |
+
if "username" in st.session_state and st.session_state["username"] is not None:
|
117 |
+
|
118 |
+
# conn = sqlite3.connect(
|
119 |
+
# r"DB/User.db", check_same_thread=False
|
120 |
+
# ) # connection with sql db
|
121 |
+
# c = conn.cursor()
|
122 |
+
|
123 |
+
tuned_model = retrieve_pkl_object(
|
124 |
+
st.session_state["project_number"], "Model_Tuning", "tuned_model", schema
|
125 |
+
)
|
126 |
+
|
127 |
+
if tuned_model is None:
|
128 |
+
st.error("Please save a tuned model")
|
129 |
+
st.stop()
|
130 |
+
|
131 |
+
if (
|
132 |
+
"session_state_saved"
|
133 |
+
in st.session_state["project_dct"]["model_tuning"].keys()
|
134 |
+
and st.session_state["project_dct"]["model_tuning"]["session_state_saved"]
|
135 |
+
!= []
|
136 |
+
):
|
137 |
+
for key in [
|
138 |
+
"used_response_metrics",
|
139 |
+
"is_tuned_model",
|
140 |
+
"media_data",
|
141 |
+
"X_test_spends",
|
142 |
+
"spends_data",
|
143 |
+
]:
|
144 |
+
st.session_state[key] = st.session_state["project_dct"]["model_tuning"][
|
145 |
+
"session_state_saved"
|
146 |
+
][key]
|
147 |
+
elif (
|
148 |
+
"session_state_saved"
|
149 |
+
in st.session_state["project_dct"]["model_build"].keys()
|
150 |
+
and st.session_state["project_dct"]["model_build"]["session_state_saved"]
|
151 |
+
!= []
|
152 |
+
):
|
153 |
+
for key in [
|
154 |
+
"used_response_metrics",
|
155 |
+
"date",
|
156 |
+
"saved_model_names",
|
157 |
+
"media_data",
|
158 |
+
"X_test_spends",
|
159 |
+
]:
|
160 |
+
st.session_state[key] = st.session_state["project_dct"]["model_build"][
|
161 |
+
"session_state_saved"
|
162 |
+
][key]
|
163 |
+
else:
|
164 |
+
st.error("Please tune a model first")
|
165 |
+
st.session_state["bin_dict"] = st.session_state["project_dct"]["model_build"][
|
166 |
+
"session_state_saved"
|
167 |
+
]["bin_dict"]
|
168 |
+
st.session_state["media_data"].columns = [
|
169 |
+
c.lower() for c in st.session_state["media_data"].columns
|
170 |
+
]
|
171 |
+
|
172 |
+
# with open(
|
173 |
+
# os.path.join(st.session_state["project_path"], "data_import.pkl"),
|
174 |
+
# "rb",
|
175 |
+
# ) as f:
|
176 |
+
# data = pickle.load(f)
|
177 |
+
|
178 |
+
# # Accessing the loaded objects
|
179 |
+
|
180 |
+
# st.session_state["orig_media_data"] = data["final_df"]
|
181 |
+
|
182 |
+
st.session_state["orig_media_data"] = st.session_state["project_dct"][
|
183 |
+
"data_import"
|
184 |
+
][
|
185 |
+
"imputed_tool_df"
|
186 |
+
].copy() # db
|
187 |
+
st.session_state["channels"] = st.session_state["project_dct"]["data_import"][
|
188 |
+
"group_dict"
|
189 |
+
].copy()
|
190 |
+
# target='Revenue'
|
191 |
+
|
192 |
+
# set the panel column
|
193 |
+
panel_col = "panel"
|
194 |
+
is_panel = (
|
195 |
+
True if st.session_state["media_data"][panel_col].nunique() > 1 else False
|
196 |
+
)
|
197 |
+
|
198 |
+
date_col = "date"
|
199 |
+
|
200 |
+
cols1 = st.columns([2, 1])
|
201 |
+
with cols1[0]:
|
202 |
+
st.markdown(f"**Welcome {st.session_state['username']}**")
|
203 |
+
with cols1[1]:
|
204 |
+
st.markdown(f"**Current Project: {st.session_state['project_name']}**")
|
205 |
+
|
206 |
+
st.title("AI Model Media Performance")
|
207 |
+
|
208 |
+
def remove_response_metric(name):
|
209 |
+
# Convert the name to a lowercase string and remove any leading or trailing spaces
|
210 |
+
name_str = str(name).lower().strip()
|
211 |
+
|
212 |
+
# Check if the name starts with "response metric" or "response_metric"
|
213 |
+
if name_str.startswith("response metric"):
|
214 |
+
return name[len("response metric") :].replace("_", " ").strip().title()
|
215 |
+
elif name_str.startswith("response_metric"):
|
216 |
+
return name[len("response_metric") :].replace("_", " ").strip().title()
|
217 |
+
else:
|
218 |
+
return name
|
219 |
+
|
220 |
+
sel_target_col = st.selectbox(
|
221 |
+
"Select the response metric",
|
222 |
+
st.session_state["used_response_metrics"],
|
223 |
+
format_func=remove_response_metric,
|
224 |
+
)
|
225 |
+
sel_target_col_frmttd = sel_target_col.replace("_", " ").replace("-", " ")
|
226 |
+
sel_target_col_frmttd = sel_target_col_frmttd.title()
|
227 |
+
target_col = (
|
228 |
+
sel_target_col.lower()
|
229 |
+
.replace(" ", "_")
|
230 |
+
.replace("-", "")
|
231 |
+
.replace(":", "")
|
232 |
+
.replace("__", "_")
|
233 |
+
)
|
234 |
+
target = sel_target_col
|
235 |
+
|
236 |
+
# Contribution
|
237 |
+
if is_panel:
|
238 |
+
# read tuned mixedLM model
|
239 |
+
if st.session_state["is_tuned_model"][target_col] == True:
|
240 |
+
|
241 |
+
model_dict = retrieve_pkl_object(
|
242 |
+
st.session_state["project_number"],
|
243 |
+
"Model_Tuning",
|
244 |
+
"tuned_model",
|
245 |
+
schema,
|
246 |
+
) # db
|
247 |
+
|
248 |
+
saved_models = list(model_dict.keys())
|
249 |
+
required_saved_models = [
|
250 |
+
m.split("__")[0]
|
251 |
+
for m in saved_models
|
252 |
+
if m.split("__")[1] == target_col
|
253 |
+
]
|
254 |
+
|
255 |
+
sel_model = required_saved_models[
|
256 |
+
0
|
257 |
+
] # only 1 tuned model available per resp metric
|
258 |
+
|
259 |
+
sel_model_dict = model_dict[sel_model + "__" + target_col]
|
260 |
+
|
261 |
+
model = sel_model_dict["Model_object"]
|
262 |
+
X_train = sel_model_dict["X_train_tuned"]
|
263 |
+
X_test = sel_model_dict["X_test_tuned"]
|
264 |
+
best_feature_set = sel_model_dict["feature_set"]
|
265 |
+
|
266 |
+
# Calculate contributions
|
267 |
+
|
268 |
+
st.session_state["orig_media_data"].columns = [
|
269 |
+
col.lower()
|
270 |
+
.replace(".", "_")
|
271 |
+
.replace("@", "_")
|
272 |
+
.replace(" ", "_")
|
273 |
+
.replace("-", "")
|
274 |
+
.replace(":", "")
|
275 |
+
.replace("__", "_")
|
276 |
+
for col in st.session_state["orig_media_data"].columns
|
277 |
+
]
|
278 |
+
|
279 |
+
media_data = st.session_state["media_data"]
|
280 |
+
|
281 |
+
contri_df = pd.DataFrame()
|
282 |
+
|
283 |
+
y = []
|
284 |
+
y_pred = []
|
285 |
+
|
286 |
+
random_eff_df = get_random_effects(media_data, panel_col, model)
|
287 |
+
random_eff_df["fixed_effect"] = model.fe_params["Intercept"]
|
288 |
+
random_eff_df["panel_effect"] = (
|
289 |
+
random_eff_df["random_effect"] + random_eff_df["fixed_effect"]
|
290 |
+
)
|
291 |
+
|
292 |
+
coef_df = pd.DataFrame(model.fe_params)
|
293 |
+
coef_df.reset_index(inplace=True)
|
294 |
+
coef_df.columns = ["feature", "coef"]
|
295 |
+
|
296 |
+
x_train_contribution = X_train.copy()
|
297 |
+
x_test_contribution = X_test.copy()
|
298 |
+
|
299 |
+
# preprocessing not needed since X_train is already preprocessed
|
300 |
+
# X1, X2 = process_train_and_test(x_train_contribution, x_test_contribution, best_feature_set, panel_col, target_col)
|
301 |
+
# x_train_contribution[best_feature_set] = X1[best_feature_set]
|
302 |
+
# x_test_contribution[best_feature_set] = X2[best_feature_set]
|
303 |
+
|
304 |
+
x_train_contribution = mdf_predict(
|
305 |
+
x_train_contribution, model, random_eff_df
|
306 |
+
)
|
307 |
+
x_test_contribution = mdf_predict(x_test_contribution, model, random_eff_df)
|
308 |
+
|
309 |
+
x_train_contribution = pd.merge(
|
310 |
+
x_train_contribution,
|
311 |
+
random_eff_df[[panel_col, "panel_effect"]],
|
312 |
+
on=panel_col,
|
313 |
+
how="left",
|
314 |
+
)
|
315 |
+
x_test_contribution = pd.merge(
|
316 |
+
x_test_contribution,
|
317 |
+
random_eff_df[[panel_col, "panel_effect"]],
|
318 |
+
on=panel_col,
|
319 |
+
how="left",
|
320 |
+
)
|
321 |
+
|
322 |
+
for i in range(len(coef_df))[1:]:
|
323 |
+
coef = coef_df.loc[i, "coef"]
|
324 |
+
col = coef_df.loc[i, "feature"]
|
325 |
+
if col.lower() != "intercept":
|
326 |
+
x_train_contribution[str(col) + "_contr"] = (
|
327 |
+
coef * x_train_contribution[col]
|
328 |
+
)
|
329 |
+
x_test_contribution[str(col) + "_contr"] = (
|
330 |
+
coef * x_train_contribution[col]
|
331 |
+
)
|
332 |
+
|
333 |
+
tuning_cols = [
|
334 |
+
c
|
335 |
+
for c in x_train_contribution.filter(regex="contr").columns
|
336 |
+
if c
|
337 |
+
in [
|
338 |
+
"day_of_week_contr",
|
339 |
+
"Trend_contr",
|
340 |
+
"sine_wave_contr",
|
341 |
+
"cosine_wave_contr",
|
342 |
+
]
|
343 |
+
]
|
344 |
+
flag_cols = [
|
345 |
+
c
|
346 |
+
for c in x_train_contribution.filter(regex="contr").columns
|
347 |
+
if "_flag" in c
|
348 |
+
]
|
349 |
+
|
350 |
+
# add exogenous contribution to base
|
351 |
+
all_exog_vars = st.session_state["bin_dict"]["Exogenous"]
|
352 |
+
all_exog_vars = [
|
353 |
+
var.lower()
|
354 |
+
.replace(".", "_")
|
355 |
+
.replace("@", "_")
|
356 |
+
.replace(" ", "_")
|
357 |
+
.replace("-", "")
|
358 |
+
.replace(":", "")
|
359 |
+
.replace("__", "_")
|
360 |
+
for var in all_exog_vars
|
361 |
+
]
|
362 |
+
exog_cols = []
|
363 |
+
if len(all_exog_vars) > 0:
|
364 |
+
for col in x_train_contribution.filter(regex="contr").columns:
|
365 |
+
if (
|
366 |
+
len([exog_var for exog_var in all_exog_vars if exog_var in col])
|
367 |
+
> 0
|
368 |
+
):
|
369 |
+
exog_cols.append(col)
|
370 |
+
|
371 |
+
base_cols = ["panel_effect"] + flag_cols + tuning_cols + exog_cols
|
372 |
+
|
373 |
+
x_train_contribution["base_contr"] = x_train_contribution[base_cols].sum(
|
374 |
+
axis=1
|
375 |
+
)
|
376 |
+
x_train_contribution.drop(columns=base_cols, inplace=True)
|
377 |
+
x_test_contribution["base_contr"] = x_test_contribution[base_cols].sum(
|
378 |
+
axis=1
|
379 |
+
)
|
380 |
+
x_test_contribution.drop(columns=base_cols, inplace=True)
|
381 |
+
|
382 |
+
overall_contributions = pd.concat(
|
383 |
+
[x_train_contribution, x_test_contribution]
|
384 |
+
).reset_index(drop=True)
|
385 |
+
|
386 |
+
overview_test_data_prep_panel(
|
387 |
+
overall_contributions,
|
388 |
+
st.session_state["orig_media_data"],
|
389 |
+
st.session_state["spends_data"],
|
390 |
+
date_col,
|
391 |
+
panel_col,
|
392 |
+
target_col,
|
393 |
+
)
|
394 |
+
|
395 |
+
else: # NON PANEL
|
396 |
+
if st.session_state["is_tuned_model"][target_col] == True: # Sprint4
|
397 |
+
# with open(
|
398 |
+
# os.path.join(st.session_state["project_path"], "tuned_model.pkl"),
|
399 |
+
# "rb",
|
400 |
+
# ) as file:
|
401 |
+
# model_dict = pickle.load(file)
|
402 |
+
|
403 |
+
model_dict = retrieve_pkl_object(
|
404 |
+
st.session_state["project_number"],
|
405 |
+
"Model_Tuning",
|
406 |
+
"tuned_model",
|
407 |
+
schema,
|
408 |
+
) # db
|
409 |
+
|
410 |
+
saved_models = list(model_dict.keys())
|
411 |
+
required_saved_models = [
|
412 |
+
m.split("__")[0]
|
413 |
+
for m in saved_models
|
414 |
+
if m.split("__")[1] == target_col
|
415 |
+
]
|
416 |
+
|
417 |
+
sel_model = required_saved_models[
|
418 |
+
0
|
419 |
+
] # only 1 tuned model available per resp metric
|
420 |
+
sel_model_dict = model_dict[sel_model + "__" + target_col]
|
421 |
+
|
422 |
+
model = sel_model_dict["Model_object"]
|
423 |
+
X_train = sel_model_dict["X_train_tuned"]
|
424 |
+
X_test = sel_model_dict["X_test_tuned"]
|
425 |
+
best_feature_set = sel_model_dict["feature_set"]
|
426 |
+
|
427 |
+
x_train_contribution = X_train.copy()
|
428 |
+
x_test_contribution = X_test.copy()
|
429 |
+
|
430 |
+
x_train_contribution["pred"] = model.predict(
|
431 |
+
x_train_contribution[best_feature_set]
|
432 |
+
)
|
433 |
+
x_test_contribution["pred"] = model.predict(
|
434 |
+
x_test_contribution[best_feature_set]
|
435 |
+
)
|
436 |
+
|
437 |
+
coef_df = pd.DataFrame(model.params)
|
438 |
+
coef_df.reset_index(inplace=True)
|
439 |
+
coef_df.columns = ["feature", "coef"]
|
440 |
+
|
441 |
+
# st.write(coef_df)
|
442 |
+
for i in range(len(coef_df)):
|
443 |
+
coef = coef_df.loc[i, "coef"]
|
444 |
+
col = coef_df.loc[i, "feature"]
|
445 |
+
if col != "const":
|
446 |
+
x_train_contribution[str(col) + "_contr"] = (
|
447 |
+
coef * x_train_contribution[col]
|
448 |
+
)
|
449 |
+
x_test_contribution[str(col) + "_contr"] = (
|
450 |
+
coef * x_test_contribution[col]
|
451 |
+
)
|
452 |
+
else:
|
453 |
+
x_train_contribution["const"] = coef
|
454 |
+
x_test_contribution["const"] = coef
|
455 |
+
|
456 |
+
tuning_cols = [
|
457 |
+
c
|
458 |
+
for c in x_train_contribution.filter(regex="contr").columns
|
459 |
+
if c
|
460 |
+
in [
|
461 |
+
"day_of_week_contr",
|
462 |
+
"Trend_contr",
|
463 |
+
"sine_wave_contr",
|
464 |
+
"cosine_wave_contr",
|
465 |
+
]
|
466 |
+
]
|
467 |
+
flag_cols = [
|
468 |
+
c
|
469 |
+
for c in x_train_contribution.filter(regex="contr").columns
|
470 |
+
if "_flag" in c
|
471 |
+
]
|
472 |
+
|
473 |
+
# add exogenous contribution to base
|
474 |
+
all_exog_vars = st.session_state["bin_dict"]["Exogenous"]
|
475 |
+
all_exog_vars = [
|
476 |
+
var.lower()
|
477 |
+
.replace(".", "_")
|
478 |
+
.replace("@", "_")
|
479 |
+
.replace(" ", "_")
|
480 |
+
.replace("-", "")
|
481 |
+
.replace(":", "")
|
482 |
+
.replace("__", "_")
|
483 |
+
for var in all_exog_vars
|
484 |
+
]
|
485 |
+
exog_cols = []
|
486 |
+
if len(all_exog_vars) > 0:
|
487 |
+
for col in x_train_contribution.filter(regex="contr").columns:
|
488 |
+
if (
|
489 |
+
len([exog_var for exog_var in all_exog_vars if exog_var in col])
|
490 |
+
> 0
|
491 |
+
):
|
492 |
+
exog_cols.append(col)
|
493 |
+
|
494 |
+
base_cols = ["const"] + flag_cols + tuning_cols + exog_cols
|
495 |
+
# st.write(base_cols)
|
496 |
+
x_train_contribution["base_contr"] = x_train_contribution[base_cols].sum(
|
497 |
+
axis=1
|
498 |
+
)
|
499 |
+
x_train_contribution.drop(columns=base_cols, inplace=True)
|
500 |
+
|
501 |
+
x_test_contribution["base_contr"] = x_test_contribution[base_cols].sum(
|
502 |
+
axis=1
|
503 |
+
)
|
504 |
+
x_test_contribution.drop(columns=base_cols, inplace=True)
|
505 |
+
# x_test_contribution.to_csv("Test/test_contr.csv", index=False)
|
506 |
+
|
507 |
+
overall_contributions = pd.concat(
|
508 |
+
[x_train_contribution, x_test_contribution]
|
509 |
+
).reset_index(drop=True)
|
510 |
+
# overall_contributions.to_csv("Test/overall_contributions.csv", index=False)
|
511 |
+
|
512 |
+
overview_test_data_prep_nonpanel(
|
513 |
+
overall_contributions,
|
514 |
+
st.session_state["orig_media_data"].copy(),
|
515 |
+
st.session_state["spends_data"].copy(),
|
516 |
+
date_col,
|
517 |
+
target_col,
|
518 |
+
)
|
519 |
+
# for k, v in st.session_sta
|
520 |
+
# te.items():
|
521 |
+
|
522 |
+
# if k not in ['logout', 'login','config'] and not k.startswith('FormSubmitter'):
|
523 |
+
# st.session_state[k] = v
|
524 |
+
|
525 |
+
# authenticator = st.session_state.get('authenticator')
|
526 |
+
|
527 |
+
# if authenticator is None:
|
528 |
+
# authenticator = load_authenticator()
|
529 |
+
|
530 |
+
# name, authentication_status, username = authenticator.login('Login', 'main')
|
531 |
+
# auth_status = st.session_state['authentication_status']
|
532 |
+
|
533 |
+
# if auth_status:
|
534 |
+
# authenticator.logout('Logout', 'main')
|
535 |
+
|
536 |
+
# is_state_initiaized = st.session_state.get('initialized',False)
|
537 |
+
# if not is_state_initiaized:
|
538 |
+
|
539 |
+
min_date = X_train[date_col].min().date()
|
540 |
+
max_date = X_test[date_col].max().date()
|
541 |
+
if "media_performance" not in st.session_state["project_dct"]:
|
542 |
+
st.session_state["project_dct"]["media_performance"] = {
|
543 |
+
"start_date": None,
|
544 |
+
"end_date": None,
|
545 |
+
}
|
546 |
+
|
547 |
+
start_default = st.session_state["project_dct"]["media_performance"].get(
|
548 |
+
"start_date", None
|
549 |
+
)
|
550 |
+
start_default = start_default if start_default is not None else min_date
|
551 |
+
start_default = start_default if start_default > min_date else min_date
|
552 |
+
start_default = start_default if start_default < max_date else min_date
|
553 |
+
|
554 |
+
end_default = st.session_state["project_dct"]["media_performance"].get(
|
555 |
+
"end_date", None
|
556 |
+
)
|
557 |
+
end_default = end_default if end_default is not None else max_date
|
558 |
+
end_default = end_default if end_default > min_date else max_date
|
559 |
+
end_default = end_default if end_default < max_date else max_date
|
560 |
+
|
561 |
+
st.write("Select a timeline for analysis")
|
562 |
+
date_columns = st.columns(2)
|
563 |
+
|
564 |
+
with date_columns[0]:
|
565 |
+
start_date = st.date_input(
|
566 |
+
"Select Start Date",
|
567 |
+
start_default,
|
568 |
+
min_value=min_date,
|
569 |
+
max_value=max_date,
|
570 |
+
)
|
571 |
+
if (start_date < min_date) or (start_date > max_date):
|
572 |
+
st.error("Please select dates in the range of the dates in the data")
|
573 |
+
st.stop()
|
574 |
+
end_default = (
|
575 |
+
end_default if end_default > start_date + timedelta(days=1) else max_date
|
576 |
+
)
|
577 |
+
with date_columns[1]:
|
578 |
+
end_default = (
|
579 |
+
end_default
|
580 |
+
if pd.Timestamp(end_default) >= pd.Timestamp(start_date)
|
581 |
+
else start_date
|
582 |
+
)
|
583 |
+
|
584 |
+
end_date = st.date_input(
|
585 |
+
"Select End Date",
|
586 |
+
end_default,
|
587 |
+
min_value=start_date + timedelta(days=1),
|
588 |
+
max_value=max_date,
|
589 |
+
)
|
590 |
+
if (
|
591 |
+
(start_date < min_date)
|
592 |
+
or (end_date < min_date)
|
593 |
+
or (start_date > max_date)
|
594 |
+
or (end_date > max_date)
|
595 |
+
):
|
596 |
+
st.error("Please select dates in the range of the dates in the data")
|
597 |
+
st.stop()
|
598 |
+
if end_date < start_date + timedelta(days=1):
|
599 |
+
st.error("Please select end date after start date")
|
600 |
+
st.stop()
|
601 |
+
|
602 |
+
st.session_state["project_dct"]["media_performance"]["start_date"] = start_date
|
603 |
+
st.session_state["project_dct"]["media_performance"]["end_date"] = end_date
|
604 |
+
|
605 |
+
st.header("Overview of Previous Media Spend")
|
606 |
+
|
607 |
+
initialize_data_cmp(target_col, is_panel, panel_col, start_date, end_date)
|
608 |
+
scenario = st.session_state["scenario"]
|
609 |
+
raw_df = st.session_state["raw_df"]
|
610 |
+
|
611 |
+
columns = st.columns(2)
|
612 |
+
|
613 |
+
with columns[0]:
|
614 |
+
st.metric(
|
615 |
+
label="Media Spend",
|
616 |
+
value=format_numbers(float(scenario.actual_total_spends)),
|
617 |
+
)
|
618 |
+
###print(f"##################### {scenario.actual_total_sales} ##################")
|
619 |
+
with columns[1]:
|
620 |
+
st.metric(
|
621 |
+
label=sel_target_col_frmttd,
|
622 |
+
value=format_numbers(
|
623 |
+
float(scenario.actual_total_sales), include_indicator=False
|
624 |
+
),
|
625 |
+
)
|
626 |
+
|
627 |
+
actual_summary_df = create_channel_summary(scenario, sel_target_col_frmttd)
|
628 |
+
actual_summary_df["Channel"] = actual_summary_df["Channel"].apply(
|
629 |
+
channel_name_formating
|
630 |
+
)
|
631 |
+
|
632 |
+
columns = st.columns((3, 1))
|
633 |
+
with columns[0]:
|
634 |
+
with st.expander("Channel wise overview"):
|
635 |
+
st.markdown(
|
636 |
+
actual_summary_df.style.set_table_styles(
|
637 |
+
[
|
638 |
+
{
|
639 |
+
"selector": "th",
|
640 |
+
"props": [("background-color", "#f6dcc7")],
|
641 |
+
},
|
642 |
+
{
|
643 |
+
"selector": "tr:nth-child(even)",
|
644 |
+
"props": [("background-color", "#f6dcc7")],
|
645 |
+
},
|
646 |
+
]
|
647 |
+
).to_html(),
|
648 |
+
unsafe_allow_html=True,
|
649 |
+
)
|
650 |
+
|
651 |
+
st.markdown("<hr>", unsafe_allow_html=True)
|
652 |
+
##############################
|
653 |
+
|
654 |
+
st.plotly_chart(
|
655 |
+
create_contribution_pie(scenario, sel_target_col_frmttd),
|
656 |
+
use_container_width=True,
|
657 |
+
)
|
658 |
+
st.markdown("<hr>", unsafe_allow_html=True)
|
659 |
+
|
660 |
+
################################3
|
661 |
+
st.plotly_chart(
|
662 |
+
create_contribuion_stacked_plot(scenario, sel_target_col_frmttd),
|
663 |
+
use_container_width=True,
|
664 |
+
)
|
665 |
+
st.markdown("<hr>", unsafe_allow_html=True)
|
666 |
+
#######################################
|
667 |
+
|
668 |
+
selected_channel_name = st.selectbox(
|
669 |
+
"Channel",
|
670 |
+
st.session_state["channels_list"] + ["non media"],
|
671 |
+
format_func=channel_name_formating,
|
672 |
+
)
|
673 |
+
selected_channel = scenario.channels.get(selected_channel_name, None)
|
674 |
+
|
675 |
+
st.plotly_chart(
|
676 |
+
create_channel_spends_sales_plot(selected_channel, sel_target_col_frmttd),
|
677 |
+
use_container_width=True,
|
678 |
+
)
|
679 |
+
|
680 |
+
st.markdown("<hr>", unsafe_allow_html=True)
|
681 |
+
|
682 |
+
if st.button("Save this session", use_container_width=True):
|
683 |
+
|
684 |
+
project_dct_pkl = pickle.dumps(st.session_state["project_dct"])
|
685 |
+
|
686 |
+
update_db(
|
687 |
+
st.session_state["project_number"],
|
688 |
+
"Current_Media_Performance",
|
689 |
+
"project_dct",
|
690 |
+
project_dct_pkl,
|
691 |
+
schema,
|
692 |
+
# resp_mtrc=None,
|
693 |
+
) # db
|
694 |
+
|
695 |
+
st.success("Session Saved!")
|
696 |
+
|
697 |
+
# Remove "response_metric_" from the start and "_total" from the end
|
698 |
+
if str(target_col).startswith("response_metric_"):
|
699 |
+
target_col = target_col.replace("response_metric_", "", 1)
|
700 |
+
|
701 |
+
# Remove the last 6 characters (length of "_total")
|
702 |
+
if str(target_col).endswith("_total"):
|
703 |
+
target_col = target_col[:-6]
|
704 |
+
|
705 |
+
if (
|
706 |
+
st.session_state["project_dct"]["current_media_performance"][
|
707 |
+
"model_outputs"
|
708 |
+
][target_col]
|
709 |
+
is not None
|
710 |
+
):
|
711 |
+
if (
|
712 |
+
len(
|
713 |
+
st.session_state["project_dct"]["current_media_performance"][
|
714 |
+
"model_outputs"
|
715 |
+
][target_col]["contribution_data"]
|
716 |
+
)
|
717 |
+
> 0
|
718 |
+
):
|
719 |
+
st.download_button(
|
720 |
+
label="Download Contribution File",
|
721 |
+
data=st.session_state["project_dct"]["current_media_performance"][
|
722 |
+
"model_outputs"
|
723 |
+
][target_col]["contribution_data"].to_csv(),
|
724 |
+
file_name="contributions.csv",
|
725 |
+
key="dwnld_contr",
|
726 |
+
)
|
727 |
+
except:
|
728 |
+
exc_type, exc_value, exc_traceback = sys.exc_info()
|
729 |
+
error_message = "".join(
|
730 |
+
traceback.format_exception(exc_type, exc_value, exc_traceback)
|
731 |
+
)
|
732 |
+
log_message("error", f"Error: {error_message}", "Current Media Performance")
|
733 |
+
st.warning("An error occured, please try again", icon="⚠️")
|
pages/8_Response_Curves.py
ADDED
@@ -0,0 +1,530 @@
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
st.set_page_config(
|
4 |
+
page_title="Response Curves",
|
5 |
+
page_icon="⚖️",
|
6 |
+
layout="wide",
|
7 |
+
initial_sidebar_state="collapsed",
|
8 |
+
)
|
9 |
+
|
10 |
+
# Disable +/- for number input
|
11 |
+
st.markdown(
|
12 |
+
"""
|
13 |
+
<style>
|
14 |
+
button.step-up {display: none;}
|
15 |
+
button.step-down {display: none;}
|
16 |
+
div[data-baseweb] {border-radius: 4px;}
|
17 |
+
</style>""",
|
18 |
+
unsafe_allow_html=True,
|
19 |
+
)
|
20 |
+
|
21 |
+
import sys
|
22 |
+
import json
|
23 |
+
import pickle
|
24 |
+
import traceback
|
25 |
+
import numpy as np
|
26 |
+
import pandas as pd
|
27 |
+
import plotly.express as px
|
28 |
+
import plotly.graph_objects as go
|
29 |
+
from post_gres_cred import db_cred
|
30 |
+
from sklearn.metrics import r2_score
|
31 |
+
from log_application import log_message
|
32 |
+
from utilities import project_selection, update_db, set_header, load_local_css
|
33 |
+
from utilities import (
|
34 |
+
get_panels_names,
|
35 |
+
get_metrics_names,
|
36 |
+
name_formating,
|
37 |
+
generate_rcs_data,
|
38 |
+
load_rcs_metadata_files,
|
39 |
+
)
|
40 |
+
|
41 |
+
schema = db_cred["schema"]
|
42 |
+
load_local_css("styles.css")
|
43 |
+
set_header()
|
44 |
+
|
45 |
+
|
46 |
+
# Initialize project name session state
|
47 |
+
if "project_name" not in st.session_state:
|
48 |
+
st.session_state["project_name"] = None
|
49 |
+
|
50 |
+
# Fetch project dictionary
|
51 |
+
if "project_dct" not in st.session_state:
|
52 |
+
project_selection()
|
53 |
+
st.stop()
|
54 |
+
|
55 |
+
# Display Username and Project Name
|
56 |
+
if "username" in st.session_state and st.session_state["username"] is not None:
|
57 |
+
|
58 |
+
cols1 = st.columns([2, 1])
|
59 |
+
|
60 |
+
with cols1[0]:
|
61 |
+
st.markdown(f"**Welcome {st.session_state['username']}**")
|
62 |
+
with cols1[1]:
|
63 |
+
st.markdown(f"**Current Project: {st.session_state['project_name']}**")
|
64 |
+
|
65 |
+
|
66 |
+
# Function to build s curve
|
67 |
+
def s_curve(x, K, b, a, x0):
|
68 |
+
return K / (1 + b * np.exp(-a * (x - x0)))
|
69 |
+
|
70 |
+
|
71 |
+
# Function to update the RCS parameters in the modified RCS metadata data
|
72 |
+
def modify_rcs_parameters(metrics_selected, panel_selected, channel_selected):
|
73 |
+
# Define unique keys for each parameter based on the selection
|
74 |
+
K_key = f"K_updated_key_{metrics_selected}_{panel_selected}_{channel_selected}"
|
75 |
+
b_key = f"b_updated_key_{metrics_selected}_{panel_selected}_{channel_selected}"
|
76 |
+
a_key = f"a_updated_key_{metrics_selected}_{panel_selected}_{channel_selected}"
|
77 |
+
x0_key = f"x0_updated_key_{metrics_selected}_{panel_selected}_{channel_selected}"
|
78 |
+
|
79 |
+
# Retrieve the updated parameters from session state
|
80 |
+
K_updated, b_updated, a_updated, x0_updated = (
|
81 |
+
st.session_state[K_key],
|
82 |
+
st.session_state[b_key],
|
83 |
+
st.session_state[a_key],
|
84 |
+
st.session_state[x0_key],
|
85 |
+
)
|
86 |
+
|
87 |
+
# Load the existing modified RCS data
|
88 |
+
rcs_data_modified = st.session_state["project_dct"]["response_curves"][
|
89 |
+
"modified_metadata_file"
|
90 |
+
]
|
91 |
+
|
92 |
+
# Update the RCS parameters for the selected metric and panel
|
93 |
+
rcs_data_modified[metrics_selected][panel_selected][channel_selected] = {
|
94 |
+
"K": K_updated,
|
95 |
+
"b": b_updated,
|
96 |
+
"a": a_updated,
|
97 |
+
"x0": x0_updated,
|
98 |
+
}
|
99 |
+
|
100 |
+
|
101 |
+
# Function to reset the parameters to their default values
|
102 |
+
def reset_parameters(
|
103 |
+
metrics_selected, panel_selected, channel_selected, original_channel_data
|
104 |
+
):
|
105 |
+
# Define unique keys for each parameter based on the selection
|
106 |
+
K_key = f"K_updated_key_{metrics_selected}_{panel_selected}_{channel_selected}"
|
107 |
+
b_key = f"b_updated_key_{metrics_selected}_{panel_selected}_{channel_selected}"
|
108 |
+
a_key = f"a_updated_key_{metrics_selected}_{panel_selected}_{channel_selected}"
|
109 |
+
x0_key = f"x0_updated_key_{metrics_selected}_{panel_selected}_{channel_selected}"
|
110 |
+
|
111 |
+
# Reset session state values to original data
|
112 |
+
del st.session_state[K_key]
|
113 |
+
del st.session_state[b_key]
|
114 |
+
del st.session_state[a_key]
|
115 |
+
del st.session_state[x0_key]
|
116 |
+
|
117 |
+
# Reset the modified metadata file with original parameters
|
118 |
+
rcs_data_modified = st.session_state["project_dct"]["response_curves"][
|
119 |
+
"modified_metadata_file"
|
120 |
+
]
|
121 |
+
|
122 |
+
# Update the parameters in the modified data to the original values
|
123 |
+
rcs_data_modified[metrics_selected][panel_selected][channel_selected] = {
|
124 |
+
"K": original_channel_data["K"],
|
125 |
+
"b": original_channel_data["b"],
|
126 |
+
"a": original_channel_data["a"],
|
127 |
+
"x0": original_channel_data["x0"],
|
128 |
+
}
|
129 |
+
|
130 |
+
# Update the modified metadata
|
131 |
+
st.session_state["project_dct"]["response_curves"][
|
132 |
+
"modified_metadata_file"
|
133 |
+
] = rcs_data_modified
|
134 |
+
|
135 |
+
|
136 |
+
# Function to generate updated RCS parameter DataFrame
|
137 |
+
@st.cache_data(show_spinner=False)
|
138 |
+
def updated_parm_gen(original_data, modified_data, metrics_selected, panel_selected):
|
139 |
+
# Retrieve the data for the selected metric and panel
|
140 |
+
original_data_selection = original_data[metrics_selected][panel_selected]
|
141 |
+
modified_data_selection = modified_data[metrics_selected][panel_selected]
|
142 |
+
|
143 |
+
# Initialize an empty list to hold the data for the DataFrame
|
144 |
+
data = []
|
145 |
+
|
146 |
+
# Iterate through each channel in the selected metric and panel
|
147 |
+
for channel in original_data_selection:
|
148 |
+
# Extract original parameters
|
149 |
+
K_o, b_o, a_o, x0_o = (
|
150 |
+
original_data_selection[channel]["K"],
|
151 |
+
original_data_selection[channel]["b"],
|
152 |
+
original_data_selection[channel]["a"],
|
153 |
+
original_data_selection[channel]["x0"],
|
154 |
+
)
|
155 |
+
# Extract modified parameters
|
156 |
+
K_m, b_m, a_m, x0_m = (
|
157 |
+
modified_data_selection[channel]["K"],
|
158 |
+
modified_data_selection[channel]["b"],
|
159 |
+
modified_data_selection[channel]["a"],
|
160 |
+
modified_data_selection[channel]["x0"],
|
161 |
+
)
|
162 |
+
|
163 |
+
# Check if any parameters differ
|
164 |
+
if (K_o != K_m) or (b_o != b_m) or (a_o != a_m) or (x0_o != x0_m):
|
165 |
+
# Append the data to the list only if there is a difference
|
166 |
+
data.append(
|
167 |
+
{
|
168 |
+
"Metric": name_formating(metrics_selected),
|
169 |
+
"Panel": name_formating(panel_selected),
|
170 |
+
"Channel": name_formating(channel),
|
171 |
+
"K (Original)": K_o,
|
172 |
+
"b (Original)": b_o,
|
173 |
+
"a (Original)": a_o,
|
174 |
+
"x0 (Original)": x0_o,
|
175 |
+
"K (Modified)": K_m,
|
176 |
+
"b (Modified)": b_m,
|
177 |
+
"a (Modified)": a_m,
|
178 |
+
"x0 (Modified)": x0_m,
|
179 |
+
}
|
180 |
+
)
|
181 |
+
|
182 |
+
# Create a DataFrame from the collected data
|
183 |
+
df = pd.DataFrame(data)
|
184 |
+
|
185 |
+
return df
|
186 |
+
|
187 |
+
|
188 |
+
# Function to create JSON file for RCS data
|
189 |
+
@st.cache_data(show_spinner=False)
|
190 |
+
def create_json_file():
|
191 |
+
return json.dumps(
|
192 |
+
st.session_state["project_dct"]["response_curves"]["modified_metadata_file"],
|
193 |
+
indent=4,
|
194 |
+
)
|
195 |
+
|
196 |
+
|
197 |
+
try:
|
198 |
+
# Page Title
|
199 |
+
st.title("Response Curves")
|
200 |
+
|
201 |
+
# Retrieve the list of all metric names from the specified directory
|
202 |
+
metrics_list = get_metrics_names()
|
203 |
+
|
204 |
+
# Check if there are any metrics available in the metrics list
|
205 |
+
if not metrics_list:
|
206 |
+
# Display a warning message to the user if no metrics are found
|
207 |
+
st.warning(
|
208 |
+
"Please tune at least one model to generate response curves data.",
|
209 |
+
icon="⚠️",
|
210 |
+
)
|
211 |
+
|
212 |
+
# Log message
|
213 |
+
log_message(
|
214 |
+
"warning",
|
215 |
+
"Please tune at least one model to generate response curves data.",
|
216 |
+
"Response Curves",
|
217 |
+
)
|
218 |
+
|
219 |
+
# Stop further execution as there is no data to process
|
220 |
+
st.stop()
|
221 |
+
|
222 |
+
# Widget columns
|
223 |
+
metric_col, channel_col, panel_col, save_progress_col = st.columns(4)
|
224 |
+
|
225 |
+
# Metrics Selection
|
226 |
+
metrics_selected = metric_col.selectbox(
|
227 |
+
"Response Metrics",
|
228 |
+
sorted(metrics_list),
|
229 |
+
format_func=name_formating,
|
230 |
+
key="response_metrics_selectbox",
|
231 |
+
index=0,
|
232 |
+
)
|
233 |
+
|
234 |
+
# Retrieve the list of all panel names for specified Metrics
|
235 |
+
panel_list = get_panels_names(metrics_selected)
|
236 |
+
|
237 |
+
# Panel Selection
|
238 |
+
panel_selected = panel_col.selectbox(
|
239 |
+
"Panel",
|
240 |
+
sorted(panel_list),
|
241 |
+
format_func=name_formating,
|
242 |
+
key="panel_selected_selectbox",
|
243 |
+
index=0,
|
244 |
+
)
|
245 |
+
|
246 |
+
# Save Progress
|
247 |
+
with save_progress_col:
|
248 |
+
st.write("####") # Padding
|
249 |
+
save_progress_placeholder = st.empty()
|
250 |
+
|
251 |
+
# Placeholder to display message and spinner
|
252 |
+
message_spinner_placeholder = st.container()
|
253 |
+
|
254 |
+
# Save page progress
|
255 |
+
with message_spinner_placeholder, st.spinner("Saving Progress ..."):
|
256 |
+
if save_progress_placeholder.button("Save Progress", use_container_width=True):
|
257 |
+
# Update DB
|
258 |
+
update_db(
|
259 |
+
prj_id=st.session_state["project_number"],
|
260 |
+
page_nam="Response Curves",
|
261 |
+
file_nam="project_dct",
|
262 |
+
pkl_obj=pickle.dumps(st.session_state["project_dct"]),
|
263 |
+
schema=schema,
|
264 |
+
)
|
265 |
+
|
266 |
+
# Store the message details in session state
|
267 |
+
message_spinner_placeholder.success(
|
268 |
+
"Progress saved successfully!", icon="💾"
|
269 |
+
)
|
270 |
+
st.toast("Progress saved successfully!", icon="💾")
|
271 |
+
|
272 |
+
# Log message
|
273 |
+
log_message("info", "Progress saved successfully!", "Response Curves")
|
274 |
+
|
275 |
+
# Check if the RCS metadata file does not exist
|
276 |
+
if (
|
277 |
+
st.session_state["project_dct"]["response_curves"]["original_metadata_file"]
|
278 |
+
is None
|
279 |
+
or st.session_state["project_dct"]["response_curves"]["modified_metadata_file"]
|
280 |
+
is None
|
281 |
+
):
|
282 |
+
# RCS metadata file does not exist. Generating new RCS data
|
283 |
+
generate_rcs_data()
|
284 |
+
|
285 |
+
# Log message
|
286 |
+
log_message(
|
287 |
+
"info",
|
288 |
+
"RCS metadata file does not exist. Generating new RCS data.",
|
289 |
+
"Response Curves",
|
290 |
+
)
|
291 |
+
|
292 |
+
# Load metadata files if they exist
|
293 |
+
original_data, modified_data = load_rcs_metadata_files()
|
294 |
+
|
295 |
+
# Retrieve the list of all channels names for specified Metrics and Panel
|
296 |
+
chanel_list_final = list(original_data[metrics_selected][panel_selected].keys())
|
297 |
+
|
298 |
+
# Channel Selection
|
299 |
+
channel_selected = channel_col.selectbox(
|
300 |
+
"Channel",
|
301 |
+
sorted(chanel_list_final),
|
302 |
+
format_func=name_formating,
|
303 |
+
key="selected_channel_name_selectbox",
|
304 |
+
)
|
305 |
+
|
306 |
+
# Extract original channel data for the selected metric, panel, and channel
|
307 |
+
original_channel_data = original_data[metrics_selected][panel_selected][
|
308 |
+
channel_selected
|
309 |
+
]
|
310 |
+
|
311 |
+
# Extract modified channel data for the same metric, panel, and channel
|
312 |
+
modified_channel_data = modified_data[metrics_selected][panel_selected][
|
313 |
+
channel_selected
|
314 |
+
]
|
315 |
+
|
316 |
+
# X and Y values for plotting
|
317 |
+
x = original_channel_data["x"]
|
318 |
+
y = original_channel_data["y"]
|
319 |
+
|
320 |
+
# Scaling factor for X values and range for S-curve plotting
|
321 |
+
power = original_channel_data["power"]
|
322 |
+
x_plot = original_channel_data["x_plot"]
|
323 |
+
|
324 |
+
# Original S-curve parameters
|
325 |
+
K_orig = original_channel_data["K"]
|
326 |
+
b_orig = original_channel_data["b"]
|
327 |
+
a_orig = original_channel_data["a"]
|
328 |
+
x0_orig = original_channel_data["x0"]
|
329 |
+
|
330 |
+
# Modified S-curve parameters (user-adjusted)
|
331 |
+
K_mod = modified_channel_data["K"]
|
332 |
+
b_mod = modified_channel_data["b"]
|
333 |
+
a_mod = modified_channel_data["a"]
|
334 |
+
x0_mod = modified_channel_data["x0"]
|
335 |
+
|
336 |
+
# Create a scatter plot for the original data points
|
337 |
+
fig = px.scatter(
|
338 |
+
x=x,
|
339 |
+
y=y,
|
340 |
+
title="Original and Modified S-Curve Plot",
|
341 |
+
labels={"x": "Spends", "y": name_formating(metrics_selected)},
|
342 |
+
)
|
343 |
+
|
344 |
+
# Add the modified S-curve trace
|
345 |
+
fig.add_trace(
|
346 |
+
go.Scatter(
|
347 |
+
x=x_plot,
|
348 |
+
y=s_curve(
|
349 |
+
np.array(x_plot) / 10**power,
|
350 |
+
K_mod,
|
351 |
+
b_mod,
|
352 |
+
a_mod,
|
353 |
+
x0_mod,
|
354 |
+
),
|
355 |
+
line=dict(color="red"),
|
356 |
+
name="Modified",
|
357 |
+
),
|
358 |
+
)
|
359 |
+
|
360 |
+
# Add the original S-curve trace
|
361 |
+
fig.add_trace(
|
362 |
+
go.Scatter(
|
363 |
+
x=x_plot,
|
364 |
+
y=s_curve(
|
365 |
+
np.array(x_plot) / 10**power,
|
366 |
+
K_orig,
|
367 |
+
b_orig,
|
368 |
+
a_orig,
|
369 |
+
x0_orig,
|
370 |
+
),
|
371 |
+
line=dict(color="rgba(0, 255, 0, 0.6)"), # Semi-transparent green
|
372 |
+
name="Original",
|
373 |
+
),
|
374 |
+
)
|
375 |
+
|
376 |
+
# Customize the layout of the plot
|
377 |
+
fig.update_layout(
|
378 |
+
title="Comparison of Original and Modified Response-Curves",
|
379 |
+
xaxis_title="Input (Clicks, Impressions, etc..)",
|
380 |
+
yaxis_title=name_formating(metrics_selected),
|
381 |
+
legend_title="Curve Type",
|
382 |
+
)
|
383 |
+
|
384 |
+
# Display s-curve
|
385 |
+
st.plotly_chart(fig, use_container_width=True)
|
386 |
+
|
387 |
+
# Calculate R-squared for the original curve
|
388 |
+
y_orig_pred = s_curve(np.array(x) / 10**power, K_orig, b_orig, a_orig, x0_orig)
|
389 |
+
r2_orig = r2_score(y, y_orig_pred)
|
390 |
+
|
391 |
+
# Calculate R-squared for the modified curve
|
392 |
+
y_mod_pred = s_curve(np.array(x) / 10**power, K_mod, b_mod, a_mod, x0_mod)
|
393 |
+
r2_mod = r2_score(y, y_mod_pred)
|
394 |
+
|
395 |
+
# Calculate the difference in R-squared
|
396 |
+
r2_diff = r2_mod - r2_orig
|
397 |
+
|
398 |
+
# Display R-squared metrics
|
399 |
+
st.write("## R-squared Comparison")
|
400 |
+
r2_col = st.columns(3)
|
401 |
+
|
402 |
+
r2_col[0].metric("R-squared (Original)", f"{r2_orig:.2f}")
|
403 |
+
r2_col[1].metric("R-squared (Modified)", f"{r2_mod:.2f}")
|
404 |
+
r2_col[2].metric("Difference in R-squared", f"{r2_diff:.2f}")
|
405 |
+
|
406 |
+
# Define unique keys for each parameter based on the selection
|
407 |
+
K_key = f"K_updated_key_{metrics_selected}_{panel_selected}_{channel_selected}"
|
408 |
+
b_key = f"b_updated_key_{metrics_selected}_{panel_selected}_{channel_selected}"
|
409 |
+
a_key = f"a_updated_key_{metrics_selected}_{panel_selected}_{channel_selected}"
|
410 |
+
x0_key = f"x0_updated_key_{metrics_selected}_{panel_selected}_{channel_selected}"
|
411 |
+
|
412 |
+
# Initialize session state keys if they do not exist
|
413 |
+
if K_key not in st.session_state:
|
414 |
+
st.session_state[K_key] = K_mod
|
415 |
+
if b_key not in st.session_state:
|
416 |
+
st.session_state[b_key] = b_mod
|
417 |
+
if a_key not in st.session_state:
|
418 |
+
st.session_state[a_key] = a_mod
|
419 |
+
if x0_key not in st.session_state:
|
420 |
+
st.session_state[x0_key] = x0_mod
|
421 |
+
|
422 |
+
# RCS parameters input
|
423 |
+
rsc_ip_col = st.columns(4)
|
424 |
+
with rsc_ip_col[0]:
|
425 |
+
K_updated = st.number_input(
|
426 |
+
"K",
|
427 |
+
step=0.001,
|
428 |
+
min_value=0.000000,
|
429 |
+
format="%.6f",
|
430 |
+
on_change=modify_rcs_parameters,
|
431 |
+
args=(metrics_selected, panel_selected, channel_selected),
|
432 |
+
key=K_key,
|
433 |
+
)
|
434 |
+
with rsc_ip_col[1]:
|
435 |
+
b_updated = st.number_input(
|
436 |
+
"b",
|
437 |
+
step=0.001,
|
438 |
+
min_value=0.000000,
|
439 |
+
format="%.6f",
|
440 |
+
on_change=modify_rcs_parameters,
|
441 |
+
args=(metrics_selected, panel_selected, channel_selected),
|
442 |
+
key=b_key,
|
443 |
+
)
|
444 |
+
with rsc_ip_col[2]:
|
445 |
+
a_updated = st.number_input(
|
446 |
+
"a",
|
447 |
+
step=0.001,
|
448 |
+
min_value=0.000000,
|
449 |
+
format="%.6f",
|
450 |
+
on_change=modify_rcs_parameters,
|
451 |
+
args=(metrics_selected, panel_selected, channel_selected),
|
452 |
+
key=a_key,
|
453 |
+
)
|
454 |
+
with rsc_ip_col[3]:
|
455 |
+
x0_updated = st.number_input(
|
456 |
+
"x0",
|
457 |
+
step=0.001,
|
458 |
+
min_value=0.000000,
|
459 |
+
format="%.6f",
|
460 |
+
on_change=modify_rcs_parameters,
|
461 |
+
args=(metrics_selected, panel_selected, channel_selected),
|
462 |
+
key=x0_key,
|
463 |
+
)
|
464 |
+
|
465 |
+
# Create columns for Reset and Download buttons
|
466 |
+
reset_download_col = st.columns(2)
|
467 |
+
with reset_download_col[0]:
|
468 |
+
if st.button(
|
469 |
+
"Reset",
|
470 |
+
use_container_width=True,
|
471 |
+
):
|
472 |
+
reset_parameters(
|
473 |
+
metrics_selected,
|
474 |
+
panel_selected,
|
475 |
+
channel_selected,
|
476 |
+
original_channel_data,
|
477 |
+
)
|
478 |
+
|
479 |
+
# Log message
|
480 |
+
log_message(
|
481 |
+
"info",
|
482 |
+
f"METRIC: {name_formating(metrics_selected)} ; PANEL: {name_formating(panel_selected)}, CHANNEL: {name_formating(channel_selected)} has been reset to its original value.",
|
483 |
+
"Response Curves",
|
484 |
+
)
|
485 |
+
|
486 |
+
st.rerun()
|
487 |
+
|
488 |
+
with reset_download_col[1]:
|
489 |
+
# Provide a download button for the modified RCS data
|
490 |
+
try:
|
491 |
+
# Create JSON file for RCS data
|
492 |
+
json_data = create_json_file()
|
493 |
+
st.download_button(
|
494 |
+
label="Download",
|
495 |
+
data=json_data,
|
496 |
+
file_name=f"{name_formating(metrics_selected)}_{name_formating(panel_selected)}_rcs_data.json",
|
497 |
+
mime="application/json",
|
498 |
+
use_container_width=True,
|
499 |
+
)
|
500 |
+
except:
|
501 |
+
# Download failed
|
502 |
+
pass
|
503 |
+
|
504 |
+
# Generate the DataFrame showing only non-matching parameters
|
505 |
+
updated_parm_df = updated_parm_gen(
|
506 |
+
original_data, modified_data, metrics_selected, panel_selected
|
507 |
+
)
|
508 |
+
|
509 |
+
# Display the DataFrame or show an informational message if no updates
|
510 |
+
if not updated_parm_df.empty:
|
511 |
+
st.write("## Parameter Comparison for Selected Metric and Panel")
|
512 |
+
st.dataframe(updated_parm_df, hide_index=True)
|
513 |
+
else:
|
514 |
+
st.info("No parameters are updated for the selected Metric and Panel")
|
515 |
+
|
516 |
+
except Exception as e:
|
517 |
+
# Capture the error details
|
518 |
+
exc_type, exc_value, exc_traceback = sys.exc_info()
|
519 |
+
error_message = "".join(
|
520 |
+
traceback.format_exception(exc_type, exc_value, exc_traceback)
|
521 |
+
)
|
522 |
+
|
523 |
+
# Log message
|
524 |
+
log_message("error", f"An error occurred: {error_message}.", "Response Curves")
|
525 |
+
|
526 |
+
# Display a warning message
|
527 |
+
st.warning(
|
528 |
+
"Oops! Something went wrong. Please try refreshing the tool or creating a new project.",
|
529 |
+
icon="⚠️",
|
530 |
+
)
|
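Note (not part of the uploaded files): the Response Curves page above compares an original and a user-modified s-curve of the form K / (1 + b * exp(-a * (x - x0))) using `sklearn.metrics.r2_score`. A minimal sketch of that comparison follows; the parameter values and observed points are purely illustrative.

```python
# Hedged sketch: evaluating original vs. modified s-curve parameters with r2_score,
# as done in the "R-squared Comparison" section of pages/8_Response_Curves.py.
import numpy as np
from sklearn.metrics import r2_score

def s_curve(x, K, b, a, x0):
    return K / (1 + b * np.exp(-a * (x - x0)))

np.random.seed(0)
x = np.linspace(0, 5, 20)  # scaled spend levels (illustrative)
y = s_curve(x, K=100.0, b=3.0, a=1.5, x0=2.0) + np.random.normal(0, 1.0, x.size)

params_orig = dict(K=100.0, b=3.0, a=1.5, x0=2.0)   # fitted parameters (illustrative)
params_mod = dict(K=95.0, b=2.5, a=1.7, x0=2.1)     # user-adjusted parameters (illustrative)

r2_orig = r2_score(y, s_curve(x, **params_orig))
r2_mod = r2_score(y, s_curve(x, **params_mod))
print(f"R2 original={r2_orig:.2f}, modified={r2_mod:.2f}, diff={r2_mod - r2_orig:.2f}")
```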
pages/9_Scenario_Planner.py
ADDED
The diff for this file is too large to render. See raw diff.
post_gres_cred.py
ADDED
@@ -0,0 +1,2 @@
db_cred = {'schema': "mmo_service_owner"}
schema = ""
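Note (not part of the uploaded files): `db_cred['schema']` above is the schema each page passes to `utilities.update_db` when saving its pickled `project_dct`. A rough sketch of that call pattern is shown below; `update_db` is replaced by a stub because its real implementation is not included in this upload, and the project number and dictionary contents are illustrative.

```python
# Hedged sketch of the save-session pattern used across the pages in this upload.
import pickle

from post_gres_cred import db_cred

schema = db_cred["schema"]


def update_db(prj_id, page_nam, file_nam, pkl_obj, schema):
    # Stub: the real helper (utilities.update_db) persists the pickled object
    # for (project, page, file) in the given Postgres schema.
    print(f"would store {len(pkl_obj)} bytes for project {prj_id} / {page_nam} in schema {schema}")


project_dct = {"response_curves": {}, "media_performance": {}}  # illustrative session dict
update_db(
    prj_id=101,  # illustrative project number
    page_nam="Response Curves",
    file_nam="project_dct",
    pkl_obj=pickle.dumps(project_dct),
    schema=schema,
)
```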
ppt/template.txt
ADDED
File without changes
ppt_utils.py
ADDED
@@ -0,0 +1,1419 @@
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
import pptx
|
4 |
+
from pptx import Presentation
|
5 |
+
from pptx.chart.data import CategoryChartData, ChartData
|
6 |
+
from pptx.enum.chart import XL_CHART_TYPE, XL_LEGEND_POSITION, XL_LABEL_POSITION
|
7 |
+
from pptx.enum.chart import XL_TICK_LABEL_POSITION
|
8 |
+
from pptx.util import Inches, Pt
|
9 |
+
import os
|
10 |
+
import pickle
|
11 |
+
from pathlib import Path
|
12 |
+
from sklearn.metrics import (
|
13 |
+
mean_absolute_error,
|
14 |
+
r2_score,
|
15 |
+
mean_absolute_percentage_error,
|
16 |
+
)
|
17 |
+
import streamlit as st
|
18 |
+
from collections import OrderedDict
|
19 |
+
from utilities import get_metrics_names, initialize_data, retrieve_pkl_object_without_warning
|
20 |
+
from io import BytesIO
|
21 |
+
from pptx.dml.color import RGBColor
|
22 |
+
from post_gres_cred import db_cred
|
23 |
+
schema=db_cred['schema']
|
24 |
+
|
25 |
+
from constants import (
|
26 |
+
TITLE_FONT_SIZE,
|
27 |
+
AXIS_LABEL_FONT_SIZE,
|
28 |
+
CHART_TITLE_FONT_SIZE,
|
29 |
+
AXIS_TITLE_FONT_SIZE,
|
30 |
+
DATA_LABEL_FONT_SIZE,
|
31 |
+
LEGEND_FONT_SIZE,
|
32 |
+
PIE_LEGEND_FONT_SIZE
|
33 |
+
)
|
34 |
+
|
35 |
+
|
36 |
+
def format_response_metric(target):
|
37 |
+
if target.startswith('response_metric_'):
|
38 |
+
target = target.replace('response_metric_', '')
|
39 |
+
target = target.replace("_", " ").title()
|
40 |
+
return target
|
41 |
+
|
42 |
+
|
43 |
+
def smape(actual, forecast):
|
44 |
+
# Symmetric MAPE (SMAPE) addresses two shortcomings of MAPE:
|
45 |
+
## 1. MAPE becomes extremely large when the actual value is close to 0
|
46 |
+
## 2. MAPE favours under-forecasting over over-forecasting
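## SMAPE as implemented below: (1/n) * sum(|forecast - actual| / (|actual| + |forecast|)), bounded in [0, 1]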
|
47 |
+
return (1 / len(actual)) * np.sum(np.abs(forecast - actual) / (np.abs(actual) + np.abs(forecast)))
|
48 |
+
|
49 |
+
|
50 |
+
def safe_num_to_per(num):
|
51 |
+
try:
|
52 |
+
return "{:.0%}".format(num)
|
53 |
+
except:
|
54 |
+
return num
|
55 |
+
|
56 |
+
|
57 |
+
# Function to convert numbers to abbreviated format
|
58 |
+
def convert_number_to_abbreviation(number):
|
59 |
+
try:
|
60 |
+
number = float(number)
|
61 |
+
if number >= 1000000:
|
62 |
+
return f'{number / 1000000:.1f} M'
|
63 |
+
elif number >= 1000:
|
64 |
+
return f'{number / 1000:.1f} K'
|
65 |
+
else:
|
66 |
+
return str(number)
|
67 |
+
except:
|
68 |
+
return number
|
69 |
+
|
70 |
+
|
71 |
+
def round_off(x, round_off_decimal=0):
|
72 |
+
# Round x to round_off_decimal places; values with 0 < |x| < 1 keep extra significant decimals
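# e.g. round_off(0.0456, 1) -> 0.05, round_off(1234.56, 1) -> 1234.6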
|
73 |
+
try:
|
74 |
+
x = float(x)
|
75 |
+
if x < 1 and x > 0:
|
76 |
+
round_off_decimal = int(np.floor(np.abs(np.log10(x)))) + max(round_off_decimal, 1)
|
77 |
+
x = np.round(x, round_off_decimal)
|
78 |
+
elif x < 0 and x > -1:
|
79 |
+
round_off_decimal = int(np.floor(np.abs(np.log10(np.abs(x))))) + max(round_off_decimal, 1)
|
80 |
+
x = np.round(x, round_off_decimal)
|
81 |
+
else:
|
82 |
+
x = np.round(x, round_off_decimal)
|
83 |
+
return x
|
84 |
+
except:
|
85 |
+
return x
|
86 |
+
|
87 |
+
|
88 |
+
def fill_table_placeholder(table_placeholder, slide, df, column_width=None, table_height=None):
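# Swap the layout placeholder for a real table: header row from df.columns, one row per
# DataFrame row, optional per-column widths in inches; the placeholder itself is then removed.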
|
89 |
+
cols = len(df.columns)
|
90 |
+
rows = len(df)
|
91 |
+
|
92 |
+
if table_height is None:
|
93 |
+
table_height = table_placeholder.height
|
94 |
+
|
95 |
+
x, y, cx, cy = table_placeholder.left, table_placeholder.top, table_placeholder.width, table_height
|
96 |
+
table = slide.shapes.add_table(rows + 1, cols, x, y, cx, cy).table
|
97 |
+
|
98 |
+
# Populate the table with data from the DataFrame
|
99 |
+
for row_idx, row in enumerate(df.values):
|
100 |
+
for col_idx, value in enumerate(row):
|
101 |
+
cell = table.cell(row_idx + 1, col_idx)
|
102 |
+
cell.text = str(value)
|
103 |
+
for col_idx, value in enumerate(df.columns):
|
104 |
+
cell = table.cell(0, col_idx)
|
105 |
+
cell.text = str(value)
|
106 |
+
|
107 |
+
if column_width is not None:
|
108 |
+
for col_idx, width in column_width.items():
|
109 |
+
table.columns[col_idx].width = Inches(width)
|
110 |
+
|
111 |
+
table_placeholder._element.getparent().remove(table_placeholder._element)
|
112 |
+
|
113 |
+
|
114 |
+
def bar_chart(chart_placeholder, slide, chart_data, titles={}, min_y=None, max_y=None, type='V', legend=True,
|
115 |
+
label_type=None, xaxis_pos=None):
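# Insert a clustered column ('V') or bar ('H') chart in the placeholder's bounding box;
# label_type picks the data-label number format (percent, abbreviated $, M or K).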
|
116 |
+
x, y, cx, cy = chart_placeholder.left, chart_placeholder.top, chart_placeholder.width, chart_placeholder.height
|
117 |
+
if type == 'V':
|
118 |
+
graphic_frame = slide.shapes.add_chart(
|
119 |
+
XL_CHART_TYPE.COLUMN_CLUSTERED, x, y, cx, cy, chart_data
|
120 |
+
)
|
121 |
+
if type == 'H':
|
122 |
+
graphic_frame = slide.shapes.add_chart(
|
123 |
+
XL_CHART_TYPE.BAR_CLUSTERED, x, y, cx, cy, chart_data
|
124 |
+
)
|
125 |
+
chart = graphic_frame.chart
|
126 |
+
|
127 |
+
category_axis = chart.category_axis
|
128 |
+
value_axis = chart.value_axis
|
129 |
+
|
130 |
+
# Add chart title
|
131 |
+
if 'chart_title' in titles.keys():
|
132 |
+
chart.has_title = True
|
133 |
+
chart.chart_title.text_frame.text = titles['chart_title']
|
134 |
+
chart_title = chart.chart_title.text_frame.paragraphs[0].runs[0]
|
135 |
+
chart_title.font.size = Pt(CHART_TITLE_FONT_SIZE)
|
136 |
+
|
137 |
+
# Add axis titles
|
138 |
+
if 'x_axis' in titles.keys():
|
139 |
+
category_axis.has_title = True
|
140 |
+
category_axis.axis_title.text_frame.text = titles['x_axis']
|
141 |
+
category_title = category_axis.axis_title.text_frame.paragraphs[0].runs[0]
|
142 |
+
category_title.font.size = Pt(AXIS_TITLE_FONT_SIZE)
|
143 |
+
|
144 |
+
if 'y_axis' in titles.keys():
|
145 |
+
value_axis.has_title = True
|
146 |
+
value_axis.axis_title.text_frame.text = titles['y_axis']
|
147 |
+
value_title = value_axis.axis_title.text_frame.paragraphs[0].runs[0]
|
148 |
+
value_title.font.size = Pt(AXIS_TITLE_FONT_SIZE)
|
149 |
+
|
150 |
+
if xaxis_pos == 'low':
|
151 |
+
category_axis.tick_label_position = XL_TICK_LABEL_POSITION.LOW
|
152 |
+
|
153 |
+
# Customize the chart
|
154 |
+
if legend:
|
155 |
+
chart.has_legend = True
|
156 |
+
chart.legend.position = XL_LEGEND_POSITION.BOTTOM
|
157 |
+
chart.legend.font.size = Pt(LEGEND_FONT_SIZE)
|
158 |
+
chart.legend.include_in_layout = False
|
159 |
+
|
160 |
+
# Adjust font size for axis labels
|
161 |
+
category_axis.tick_labels.font.size = Pt(AXIS_LABEL_FONT_SIZE)
|
162 |
+
value_axis.tick_labels.font.size = Pt(AXIS_LABEL_FONT_SIZE)
|
163 |
+
|
164 |
+
if min_y is not None:
|
165 |
+
value_axis.minimum_scale = min_y # Adjust this value as needed
|
166 |
+
|
167 |
+
if max_y is not None:
|
168 |
+
value_axis.maximum_scale = max_y # Adjust this value as needed
|
169 |
+
|
170 |
+
plot = chart.plots[0]
|
171 |
+
plot.has_data_labels = True
|
172 |
+
data_labels = plot.data_labels
|
173 |
+
|
174 |
+
if label_type == 'per':
|
175 |
+
data_labels.number_format = '0"%"'
|
176 |
+
elif label_type == '$':
|
177 |
+
data_labels.number_format = '$[>=1000000]#,##0.0,,"M";$[>=1000]#,##0.0,"K";$#,##0'
|
178 |
+
elif label_type == '$1':
|
179 |
+
data_labels.number_format = '$[>=1000000]#,##0,,"M";$[>=1000]#,##0,"K";$#,##0'
|
180 |
+
elif label_type == 'M':
|
181 |
+
data_labels.number_format = '#0.0,,"M"'
|
182 |
+
elif label_type == 'M1':
|
183 |
+
data_labels.number_format = '#0.00,,"M"'
|
184 |
+
elif label_type == 'K':
|
185 |
+
data_labels.number_format = '#0.0,"K"'
|
186 |
+
|
187 |
+
data_labels.font.size = Pt(DATA_LABEL_FONT_SIZE)
|
188 |
+
|
189 |
+
chart_placeholder._element.getparent().remove(chart_placeholder._element)
|
190 |
+
|
191 |
+
|
192 |
+
def line_chart(chart_placeholder, slide, chart_data, titles={}, min_y=None, max_y=None):
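# Line chart for the series in chart_data; the second series is recoloured so that
# 'Predicted' stands out in the Actual vs Predicted plot built by model_result below.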
|
193 |
+
# Add the chart to the slide
|
194 |
+
x, y, cx, cy = chart_placeholder.left, chart_placeholder.top, chart_placeholder.width, chart_placeholder.height
|
195 |
+
|
196 |
+
chart = slide.shapes.add_chart(
|
197 |
+
XL_CHART_TYPE.LINE, x, y, cx, cy, chart_data
|
198 |
+
).chart
|
199 |
+
|
200 |
+
chart.has_legend = True
|
201 |
+
chart.legend.position = XL_LEGEND_POSITION.BOTTOM
|
202 |
+
chart.legend.font.size = Pt(LEGEND_FONT_SIZE)
|
203 |
+
|
204 |
+
category_axis = chart.category_axis
|
205 |
+
value_axis = chart.value_axis
|
206 |
+
|
207 |
+
if min_y is not None:
|
208 |
+
value_axis.minimum_scale = min_y
|
209 |
+
|
210 |
+
if max_y is not None:
|
211 |
+
value_axis.maximum_scale = max_y
|
212 |
+
|
213 |
+
if min_y is not None and max_y is not None:
|
214 |
+
value_axis.major_unit = int((max_y - min_y) / 2)
|
215 |
+
|
216 |
+
if 'chart_title' in titles.keys():
|
217 |
+
chart.has_title = True
|
218 |
+
chart.chart_title.text_frame.text = titles['chart_title']
|
219 |
+
chart_title = chart.chart_title.text_frame.paragraphs[0].runs[0]
|
220 |
+
chart_title.font.size = Pt(CHART_TITLE_FONT_SIZE)
|
221 |
+
|
222 |
+
if 'x_axis' in titles.keys():
|
223 |
+
category_axis.has_title = True
|
224 |
+
category_axis.axis_title.text_frame.text = titles['x_axis']
|
225 |
+
category_title = category_axis.axis_title.text_frame.paragraphs[0].runs[0]
|
226 |
+
category_title.font.size = Pt(AXIS_TITLE_FONT_SIZE)
|
227 |
+
|
228 |
+
if 'y_axis' in titles.keys():
|
229 |
+
value_axis.has_title = True
|
230 |
+
value_axis.axis_title.text_frame.text = titles['y_axis']
|
231 |
+
value_title = value_axis.axis_title.text_frame.paragraphs[0].runs[0]
|
232 |
+
value_title.font.size = Pt(AXIS_TITLE_FONT_SIZE)
|
233 |
+
|
234 |
+
# Adjust font size for axis labels
|
235 |
+
category_axis.tick_labels.font.size = Pt(AXIS_LABEL_FONT_SIZE)
|
236 |
+
value_axis.tick_labels.font.size = Pt(AXIS_LABEL_FONT_SIZE)
|
237 |
+
|
238 |
+
plot = chart.plots[0]
|
239 |
+
series = plot.series[1]
|
240 |
+
line = series.format.line
|
241 |
+
line.color.rgb = RGBColor(141, 47, 0)
|
242 |
+
|
243 |
+
chart_placeholder._element.getparent().remove(chart_placeholder._element)
|
244 |
+
|
245 |
+
|
246 |
+
def pie_chart(chart_placeholder, slide, chart_data, title):
|
247 |
+
# Add the chart to the slide
|
248 |
+
x, y, cx, cy = chart_placeholder.left, chart_placeholder.top, chart_placeholder.width, chart_placeholder.height
|
249 |
+
|
250 |
+
chart = slide.shapes.add_chart(
|
251 |
+
XL_CHART_TYPE.PIE, x, y, cx, cy, chart_data
|
252 |
+
).chart
|
253 |
+
|
254 |
+
chart.has_legend = True
|
255 |
+
chart.legend.position = XL_LEGEND_POSITION.RIGHT
|
256 |
+
chart.legend.include_in_layout = False
|
257 |
+
chart.legend.font.size = Pt(PIE_LEGEND_FONT_SIZE)
|
258 |
+
|
259 |
+
chart.plots[0].has_data_labels = True
|
260 |
+
data_labels = chart.plots[0].data_labels
|
261 |
+
data_labels.number_format = '0%'
|
262 |
+
data_labels.position = XL_LABEL_POSITION.OUTSIDE_END
|
263 |
+
data_labels.font.size = Pt(DATA_LABEL_FONT_SIZE)
|
264 |
+
|
265 |
+
chart.has_title = True
|
266 |
+
chart.chart_title.text_frame.text = title
|
267 |
+
chart_title = chart.chart_title.text_frame.paragraphs[0].runs[0]
|
268 |
+
chart_title.font.size = Pt(CHART_TITLE_FONT_SIZE)
|
269 |
+
|
270 |
+
chart_placeholder._element.getparent().remove(chart_placeholder._element)
|
271 |
+
|
272 |
+
|
273 |
+
def title_and_table(slide, title, df, column_width=None, custom_table_height=False):
|
274 |
+
placeholders = slide.placeholders
|
275 |
+
ph_idx = [ph.placeholder_format.idx for ph in placeholders]
|
276 |
+
title_ph = slide.placeholders[ph_idx[0]]
|
277 |
+
title_ph.text = title
|
278 |
+
title_ph.text_frame.paragraphs[0].font.size = Pt(TITLE_FONT_SIZE)
|
279 |
+
|
280 |
+
table_placeholder = slide.placeholders[ph_idx[1]]
|
281 |
+
|
282 |
+
table_height = None
|
283 |
+
if custom_table_height:
|
284 |
+
if len(df) < 4:
|
285 |
+
table_height = int(np.ceil(table_placeholder.height / 2))
|
286 |
+
|
287 |
+
fill_table_placeholder(table_placeholder, slide, df, column_width, table_height)
|
288 |
+
|
289 |
+
# try:
|
290 |
+
# font_size = 18 # default for 3*3
|
291 |
+
# if cols < 3:
|
292 |
+
# row_diff = 3 - rows
|
293 |
+
# font_size = font_size + ((row_diff)*2) # 1 row less -> 2 pt font size increase & vice versa
|
294 |
+
# else:
|
295 |
+
# row_diff = 2 - rows
|
296 |
+
# font_size = font_size + ((row_diff)*2)
|
297 |
+
# for row in table.rows:
|
298 |
+
# for cell in row.cells:
|
299 |
+
# cell.text_frame.paragraphs[0].runs[0].font.size = Pt(font_size)
|
300 |
+
# except Exception as e :
|
301 |
+
# print("**"*30)
|
302 |
+
# print(e)
|
303 |
+
# else:
|
304 |
+
# except Exception as e:
|
305 |
+
# print('table', e)
|
306 |
+
return slide
|
307 |
+
|
308 |
+
|
309 |
+
def data_import(data, bin_dict):
|
310 |
+
import_df = pd.DataFrame(columns=['Category', 'Value'])
|
311 |
+
|
312 |
+
import_df.at[0, 'Category'] = 'Date Range'
|
313 |
+
|
314 |
+
date_start = data['date'].min().date()
|
315 |
+
date_end = data['date'].max().date()
|
316 |
+
import_df.at[0, 'Value'] = str(date_start) + ' - ' + str(date_end)
|
317 |
+
|
318 |
+
import_df.at[1, 'Category'] = 'Response Metrics'
|
319 |
+
import_df.at[1, 'Value'] = ', '.join(bin_dict['Response Metrics'])
|
320 |
+
|
321 |
+
import_df.at[2, 'Category'] = 'Media Variables'
|
322 |
+
import_df.at[2, 'Value'] = ', '.join(bin_dict['Media'])
|
323 |
+
|
324 |
+
import_df.at[3, 'Category'] = 'Spend Variables'
|
325 |
+
import_df.at[3, 'Value'] = ', '.join(bin_dict['Spends'])
|
326 |
+
|
327 |
+
if bin_dict['Exogenous'] != []:
|
328 |
+
import_df.at[4, 'Category'] = 'Exogenous Variables'
|
329 |
+
import_df.at[4, 'Value'] = ', '.join(bin_dict['Exogenous'])
|
330 |
+
|
331 |
+
return import_df
|
332 |
+
|
333 |
+
|
334 |
+
def channel_groups_df(channel_groups_dct={}, bin_dict={}):
|
335 |
+
df = pd.DataFrame(columns=['Channel', 'Media Variables', 'Spend Variables'])
|
336 |
+
i = 0
|
337 |
+
for channel, vars in channel_groups_dct.items():
|
338 |
+
media_vars = ", ".join(list(set(vars).intersection(set(bin_dict["Media"]))))
|
339 |
+
spend_vars = ", ".join(list(set(vars).intersection(set(bin_dict["Spends"]))))
|
340 |
+
df.at[i, "Channel"] = channel
|
341 |
+
df.at[i, 'Media Variables'] = media_vars
|
342 |
+
df.at[i, 'Spend Variables'] = spend_vars
|
343 |
+
i += 1
|
344 |
+
|
345 |
+
return df
|
346 |
+
|
347 |
+
|
348 |
+
def transformations(transform_dict):
|
349 |
+
transform_df = pd.DataFrame(columns=['Category', 'Transformation', 'Value'])
|
350 |
+
i = 0
|
351 |
+
|
352 |
+
for category in ['Media', 'Exogenous']:
|
353 |
+
transformations = f'transformation_{category}'
|
354 |
+
category_dict = transform_dict[category]
|
355 |
+
if transformations in category_dict.keys():
|
356 |
+
for transformation in category_dict[transformations]:
|
357 |
+
transform_df.at[i, 'Category'] = category
|
358 |
+
transform_df.at[i, 'Transformation'] = transformation
|
359 |
+
transform_df.at[i, 'Value'] = str(category_dict[transformation][0]) + ' - ' + str(
|
360 |
+
category_dict[transformation][1])
|
361 |
+
i += 1
|
362 |
+
return transform_df
|
363 |
+
|
364 |
+
|
365 |
+
def model_metrics(model_dict, is_panel):
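# Fit summary per tuned model: R2, adjusted R2 and train/test errors (computed with the SMAPE
# defined above). The panel branch assumes get_random_effects / mdf_predict / media_data are
# available from the calling namespace; create_ppt currently calls this with is_panel=False.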
|
366 |
+
metrics_df = pd.DataFrame(
|
367 |
+
columns=[
|
368 |
+
"Response Metric",
|
369 |
+
"Model",
|
370 |
+
"R2",
|
371 |
+
"ADJR2",
|
372 |
+
"Train MAPE",
|
373 |
+
"Test MAPE"
|
374 |
+
]
|
375 |
+
)
|
376 |
+
i = 0
|
377 |
+
for key in model_dict.keys():
|
378 |
+
target = key.split("__")[1]
|
379 |
+
metrics_df.at[i, "Response Metric"] = format_response_metric(target)
|
380 |
+
metrics_df.at[i, "Model"] = key.split("__")[0]
|
381 |
+
|
382 |
+
y = model_dict[key]["X_train_tuned"][target]
|
383 |
+
|
384 |
+
feature_set = model_dict[key]["feature_set"]
|
385 |
+
|
386 |
+
if is_panel:
|
387 |
+
random_df = get_random_effects(
|
388 |
+
media_data, panel_col, model_dict[key]["Model_object"]
|
389 |
+
)
|
390 |
+
pred = mdf_predict(
|
391 |
+
model_dict[key]["X_train_tuned"],
|
392 |
+
model_dict[key]["Model_object"],
|
393 |
+
random_df,
|
394 |
+
)["pred"]
|
395 |
+
else:
|
396 |
+
pred = model_dict[key]["Model_object"].predict(model_dict[key]["X_train_tuned"][feature_set])
|
397 |
+
|
398 |
+
ytest = model_dict[key]["X_test_tuned"][target]
|
399 |
+
if is_panel:
|
400 |
+
|
401 |
+
predtest = mdf_predict(
|
402 |
+
model_dict[key]["X_test_tuned"],
|
403 |
+
model_dict[key]["Model_object"],
|
404 |
+
random_df,
|
405 |
+
)["pred"]
|
406 |
+
|
407 |
+
else:
|
408 |
+
predtest = model_dict[key]["Model_object"].predict(model_dict[key]["X_test_tuned"][feature_set])
|
409 |
+
|
410 |
+
metrics_df.at[i, "R2"] = np.round(r2_score(y, pred), 2)
|
411 |
+
adjr2 = 1 - (1 - metrics_df.loc[i, "R2"]) * (
|
412 |
+
len(y) - 1
|
413 |
+
) / (len(y) - len(model_dict[key]["feature_set"]) - 1)
|
414 |
+
metrics_df.at[i, "ADJR2"] = np.round(adjr2, 2)
|
415 |
+
# y = np.where(np.abs(y) < 0.00001, 0.00001, y)
|
416 |
+
metrics_df.at[i, "Train MAPE"] = np.round(smape(y, pred), 2)
|
417 |
+
metrics_df.at[i, "Test MAPE"] = np.round(smape(ytest, predtest), 2)
|
418 |
+
i += 1
|
419 |
+
metrics_df = np.round(metrics_df, 2)
|
420 |
+
|
421 |
+
return metrics_df
|
422 |
+
|
423 |
+
|
424 |
+
def model_result(slide, model_key, model_dict, model_metrics_df, date_col):
|
425 |
+
placeholders = slide.placeholders
|
426 |
+
ph_idx = [ph.placeholder_format.idx for ph in placeholders]
|
427 |
+
title_ph = slide.placeholders[ph_idx[0]]
|
428 |
+
title_ph.text = model_key.split('__')[0]
|
429 |
+
title_ph.text_frame.paragraphs[0].font.size = Pt(TITLE_FONT_SIZE)
|
430 |
+
target = model_key.split('__')[1]
|
431 |
+
|
432 |
+
metrics_table_placeholder = slide.placeholders[ph_idx[1]]
|
433 |
+
metrics_df = model_metrics_df[model_metrics_df['Model'] == model_key.split('__')[0]].reset_index(drop=True)
|
434 |
+
|
435 |
+
# Accuracy = 1 - SMAPE, shown as a percentage
|
436 |
+
metrics_df['Accuracy'] = 100 * (1 - metrics_df['Train MAPE'])
|
437 |
+
metrics_df['Accuracy'] = metrics_df['Accuracy'].apply(lambda x: f'{np.round(x, 0)}%')
|
438 |
+
|
439 |
+
## Removing metrics as requested by Ioannis
|
440 |
+
|
441 |
+
metrics_df = metrics_df.drop(columns=['R2', 'ADJR2', 'Train MAPE', 'Test MAPE'])
|
442 |
+
fill_table_placeholder(metrics_table_placeholder, slide, metrics_df)
|
443 |
+
|
444 |
+
# coeff_table_placeholder = slide.placeholders[ph_idx[2]]
|
445 |
+
# coeff_df = pd.DataFrame(model_dict['Model_object'].params)
|
446 |
+
# coeff_df.reset_index(inplace=True)
|
447 |
+
# coeff_df.columns = ['Feature', 'Coefficent']
|
448 |
+
# fill_table_placeholder(coeff_table_placeholder, slide, coeff_df)
|
449 |
+
|
450 |
+
chart_placeholder = slide.placeholders[ph_idx[2]]
|
451 |
+
full_df = pd.concat([model_dict['X_train_tuned'], model_dict['X_test_tuned']])
|
452 |
+
full_df['Predicted'] = model_dict['Model_object'].predict(full_df[model_dict['feature_set']])
|
453 |
+
pred_df = full_df[[date_col, target, 'Predicted']]
|
454 |
+
pred_df.rename(columns={target: 'Actual'}, inplace=True)
|
455 |
+
|
456 |
+
# Create chart data
|
457 |
+
chart_data = CategoryChartData()
|
458 |
+
chart_data.categories = pred_df[date_col]
|
459 |
+
chart_data.add_series('Actual', pred_df['Actual'])
|
460 |
+
chart_data.add_series('Predicted', pred_df['Predicted'])
|
461 |
+
|
462 |
+
# Set range for y axis
|
463 |
+
min_y = np.floor(min(pred_df['Actual'].min(), pred_df['Predicted'].min()))
|
464 |
+
max_y = np.ceil(max(pred_df['Actual'].max(), pred_df['Predicted'].max()))
|
465 |
+
|
466 |
+
# Create the chart
|
467 |
+
line_chart(chart_placeholder=chart_placeholder,
|
468 |
+
slide=slide,
|
469 |
+
chart_data=chart_data,
|
470 |
+
titles={'chart_title': 'Actual VS Predicted',
|
471 |
+
'x_axis': 'Date',
|
472 |
+
'y_axis': target.title().replace('_', ' ')
|
473 |
+
},
|
474 |
+
min_y=min_y,
|
475 |
+
max_y=max_y
|
476 |
+
)
|
477 |
+
|
478 |
+
return slide
|
479 |
+
|
480 |
+
|
481 |
+
def metrics_contributions(slide, contributions_excels_dict, panel_col):
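# Single clustered-column chart of each channel's % contribution to every response metric,
# with 'base' shown first and channels sorted by the bottom-funnel metric.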
|
482 |
+
# Create data for metrics contributions
|
483 |
+
all_contribution_df = pd.DataFrame(columns=['Channel'])
|
484 |
+
target_sum_dict = {}
|
485 |
+
sort_support_dct = {}
|
486 |
+
for target in contributions_excels_dict.keys():
|
487 |
+
contribution_df = contributions_excels_dict[target]['CONTRIBUTION MMM'].copy()
|
488 |
+
if 'Date' in contribution_df.columns:
|
489 |
+
contribution_df.drop(columns=['Date'], inplace=True)
|
490 |
+
if panel_col in contribution_df.columns:
|
491 |
+
contribution_df.drop(columns=[panel_col], inplace=True)
|
492 |
+
|
493 |
+
contribution_df = pd.DataFrame(np.sum(contribution_df, axis=0)).reset_index()
|
494 |
+
contribution_df.columns = ['Channel', target]
|
495 |
+
target_sum = contribution_df[target].sum()
|
496 |
+
target_sum_dict[target] = target_sum
|
497 |
+
contribution_df[target] = 100 * contribution_df[target] / target_sum
|
498 |
+
|
499 |
+
all_contribution_df = pd.merge(all_contribution_df, contribution_df, on='Channel', how='outer')
|
500 |
+
|
501 |
+
sorted_target_sum_dict = sorted(target_sum_dict.items(), key=lambda kv: kv[1], reverse=True)
|
502 |
+
sorted_target_sum_keys = [kv[0] for kv in sorted_target_sum_dict]
|
503 |
+
if len([metric for metric in sorted_target_sum_keys if metric.lower() == 'revenue']) == 1:
|
504 |
+
rev_metric = [metric for metric in sorted_target_sum_keys if metric.lower() == 'revenue'][0]
|
505 |
+
sorted_target_sum_keys.remove(rev_metric)
|
506 |
+
sorted_target_sum_keys.append(rev_metric)
|
507 |
+
all_contribution_df = all_contribution_df[['Channel'] + sorted_target_sum_keys]
|
508 |
+
|
509 |
+
# for col in all_contribution_df.columns:
|
510 |
+
# all_contribution_df[col]=all_contribution_df[col].apply(lambda x: round_off(x,1))
|
511 |
+
|
512 |
+
# Sort Data by Average contribution of the channels keeping base first <Removed>
|
513 |
+
# all_contribution_df['avg'] = np.mean(all_contribution_df[list(contributions_excels_dict.keys())],axis=1)
|
514 |
+
# all_contribution_df['rank'] = all_contribution_df['avg'].rank(ascending=False)
|
515 |
+
|
516 |
+
# Sort data by contribution of bottom funnel metric
|
517 |
+
bottom_funnel_metric = sorted_target_sum_keys[-1]
|
518 |
+
all_contribution_df['rank'] = all_contribution_df[bottom_funnel_metric].rank(ascending=False)
|
519 |
+
all_contribution_df.loc[all_contribution_df[all_contribution_df['Channel'] == 'base'].index, 'rank'] = 0
|
520 |
+
all_contribution_df = all_contribution_df.sort_values(by='rank')
|
521 |
+
all_contribution_df.drop(columns=['rank'], inplace=True)
|
522 |
+
|
523 |
+
# Add title
|
524 |
+
placeholders = slide.placeholders
|
525 |
+
ph_idx = [ph.placeholder_format.idx for ph in placeholders]
|
526 |
+
title_ph = slide.placeholders[ph_idx[0]]
|
527 |
+
title_ph.text = "Response Metrics Contributions"
|
528 |
+
title_ph.text_frame.paragraphs[0].font.size = Pt(TITLE_FONT_SIZE)
|
529 |
+
|
530 |
+
for target in contributions_excels_dict.keys():
|
531 |
+
all_contribution_df[target] = all_contribution_df[target].astype(float)
|
532 |
+
|
533 |
+
|
534 |
+
# Create chart data
|
535 |
+
chart_data = CategoryChartData()
|
536 |
+
chart_data.categories = all_contribution_df['Channel']
|
537 |
+
for target in sorted_target_sum_keys:
|
538 |
+
chart_data.add_series(format_response_metric(target), all_contribution_df[target])
|
539 |
+
chart_placeholder = slide.placeholders[ph_idx[1]]
|
540 |
+
|
541 |
+
if isinstance(np.min(all_contribution_df.select_dtypes(exclude=['object', 'datetime'])), float):
|
542 |
+
|
543 |
+
# Add the chart to the slide
|
544 |
+
bar_chart(chart_placeholder=chart_placeholder,
|
545 |
+
slide=slide,
|
546 |
+
chart_data=chart_data,
|
547 |
+
titles={'chart_title': 'Response Metrics Contributions',
|
548 |
+
# 'x_axis':'Channels',
|
549 |
+
'y_axis': 'Contributions'},
|
550 |
+
min_y=np.floor(np.min(all_contribution_df.select_dtypes(exclude=['object', 'datetime']))),
|
551 |
+
max_y=np.ceil(np.max(all_contribution_df.select_dtypes(exclude=['object', 'datetime']))),
|
552 |
+
type='V',
|
553 |
+
label_type='per'
|
554 |
+
)
|
555 |
+
else:
|
556 |
+
|
557 |
+
bar_chart(chart_placeholder=chart_placeholder,
|
558 |
+
slide=slide,
|
559 |
+
chart_data=chart_data,
|
560 |
+
titles={'chart_title': 'Response Metrics Contributions',
|
561 |
+
# 'x_axis':'Channels',
|
562 |
+
'y_axis': 'Contributions'},
|
563 |
+
min_y=np.floor(np.min(all_contribution_df.select_dtypes(exclude=['object', 'datetime'])).values[0]),
|
564 |
+
max_y=np.ceil(np.max(all_contribution_df.select_dtypes(exclude=['object', 'datetime'])).values[0]),
|
565 |
+
type='V',
|
566 |
+
label_type='per'
|
567 |
+
)
|
568 |
+
|
569 |
+
return slide
|
570 |
+
|
571 |
+
|
572 |
+
def model_media_performance(slide, target, contributions_excels_dict, date_col='Date', is_panel=False,
|
573 |
+
panel_col='panel'):
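# Two pie charts for one response metric: share of modelled contribution per channel and
# share of media spend per channel (spends pulled from the scenario in session state).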
|
574 |
+
# Add title
|
575 |
+
placeholders = slide.placeholders
|
576 |
+
ph_idx = [ph.placeholder_format.idx for ph in placeholders]
|
577 |
+
title_ph = slide.placeholders[ph_idx[0]]
|
578 |
+
title_ph.text = "Media Performance - " + target.title().replace("_", " ")
|
579 |
+
title_ph.text_frame.paragraphs[0].font.size = Pt(TITLE_FONT_SIZE)
|
580 |
+
|
581 |
+
# CONTRIBUTION CHART
|
582 |
+
# Create contribution data
|
583 |
+
contribution_df = contributions_excels_dict[target]['CONTRIBUTION MMM']
|
584 |
+
if panel_col in contribution_df.columns:
|
585 |
+
contribution_df.drop(columns=[panel_col], inplace=True)
|
586 |
+
# contribution_df.drop(columns=[date_col], inplace=True)
|
587 |
+
contribution_df = pd.DataFrame(np.sum(contribution_df, axis=0)).reset_index()
|
588 |
+
contribution_df.columns = ['Channel', format_response_metric(target)]
|
589 |
+
contribution_df['Channel'] = contribution_df['Channel'].apply(lambda x: x.title())
|
590 |
+
target_sum = contribution_df[format_response_metric(target)].sum()
|
591 |
+
contribution_df[format_response_metric(target)] = contribution_df[format_response_metric(target)] / target_sum
|
592 |
+
contribution_df.sort_values(by=['Channel'], ascending=False, inplace=True)
|
593 |
+
|
594 |
+
# for col in contribution_df.columns:
|
595 |
+
# contribution_df[col] = contribution_df[col].apply(lambda x : round_off(x))
|
596 |
+
|
597 |
+
# Create Chart Data
|
598 |
+
chart_data = ChartData()
|
599 |
+
chart_data.categories = contribution_df['Channel']
|
600 |
+
chart_data.add_series('Contribution', contribution_df[format_response_metric(target)])
|
601 |
+
|
602 |
+
chart_placeholder = slide.placeholders[ph_idx[2]]
|
603 |
+
pie_chart(chart_placeholder=chart_placeholder,
|
604 |
+
slide=slide,
|
605 |
+
chart_data=chart_data,
|
606 |
+
title='Contribution')
|
607 |
+
|
608 |
+
# SPENDS CHART
|
609 |
+
|
610 |
+
initialize_data(panel='aggregated', metrics=target)
|
611 |
+
scenario = st.session_state["scenario"]
|
612 |
+
spends_values = {
|
613 |
+
channel_name: round(
|
614 |
+
scenario.channels[channel_name].actual_total_spends
|
615 |
+
* scenario.channels[channel_name].conversion_rate,
|
616 |
+
1,
|
617 |
+
)
|
618 |
+
for channel_name in st.session_state["channels_list"]
|
619 |
+
}
|
620 |
+
spends_df = pd.DataFrame(columns=['Channel', 'Media Spend'])
|
621 |
+
spends_df['Channel'] = list(spends_values.keys())
|
622 |
+
spends_df['Media Spend'] = list(spends_values.values())
|
623 |
+
spends_sum = spends_df['Media Spend'].sum()
|
624 |
+
spends_df['Media Spend'] = spends_df['Media Spend'] / spends_sum
|
625 |
+
spends_df['Channel'] = spends_df['Channel'].apply(lambda x: x.title())
|
626 |
+
spends_df.sort_values(by='Channel', ascending=False, inplace=True)
|
627 |
+
# for col in spends_df.columns:
|
628 |
+
# spends_df[col] = spends_df[col].apply(lambda x : round_off(x))
|
629 |
+
|
630 |
+
# Create Chart Data
|
631 |
+
spends_chart_data = ChartData()
|
633 |
+
spends_chart_data.categories = spends_df['Channel']
|
634 |
+
spends_chart_data.add_series('Media Spend', spends_df['Media Spend'])
|
635 |
+
|
636 |
+
spends_chart_placeholder = slide.placeholders[ph_idx[1]]
|
637 |
+
pie_chart(chart_placeholder=spends_chart_placeholder,
|
638 |
+
slide=slide,
|
639 |
+
chart_data=spends_chart_data,
|
640 |
+
title='Media Spend')
|
641 |
+
# spends_values.append(0)
|
642 |
+
return contribution_df, spends_df
|
643 |
+
|
644 |
+
|
645 |
+
# def get_saved_scenarios_dict(project_path):
|
646 |
+
# # Path to the saved scenarios file
|
647 |
+
# saved_scenarios_dict_path = os.path.join(
|
648 |
+
# project_path, "saved_scenarios.pkl"
|
649 |
+
# )
|
650 |
+
#
|
651 |
+
# # Load existing scenarios if the file exists
|
652 |
+
# if os.path.exists(saved_scenarios_dict_path):
|
653 |
+
# with open(saved_scenarios_dict_path, "rb") as f:
|
654 |
+
# saved_scenarios_dict = pickle.load(f)
|
655 |
+
# else:
|
656 |
+
# saved_scenarios_dict = OrderedDict()
|
657 |
+
#
|
658 |
+
# return saved_scenarios_dict
|
659 |
+
|
660 |
+
def optimization_summary(slide, scenario, scenario_name):
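# Scenario summary slide: actual vs simulated Media Spend, response metric, CPA and ROI,
# plus a channel-wise actual vs optimized spend chart.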
|
661 |
+
placeholders = slide.placeholders
|
662 |
+
ph_idx = [ph.placeholder_format.idx for ph in placeholders]
|
663 |
+
title_ph = slide.placeholders[ph_idx[0]]
|
664 |
+
title_ph.text = 'Optimization Summary' # + ' (Scenario: ' + scenario_name + ')'
|
665 |
+
title_ph.text_frame.paragraphs[0].font.size = Pt(TITLE_FONT_SIZE)
|
666 |
+
|
667 |
+
multiplier = 1 / float(scenario['multiplier'])
|
668 |
+
# st.write(scenario['multiplier'], multiplier)
|
669 |
+
## The multiplier reflects the selected time frame;
|
670 |
+
## it does not affect CPA.
|
671 |
+
|
672 |
+
opt_on = scenario['optimization']
|
673 |
+
if opt_on.lower() == 'spends':
|
674 |
+
opt_on = 'Media Spend'
|
675 |
+
|
676 |
+
details_ph = slide.placeholders[ph_idx[3]]
|
677 |
+
details_ph.text = 'Scenario Name: ' + scenario_name + \
|
678 |
+
'\nResponse Metric: ' + str(scenario['metrics_selected']).replace("_", " ").title() + \
|
679 |
+
'\nOptimized on: ' + str(opt_on).replace("_", " ").title()
|
680 |
+
|
681 |
+
scenario_df = pd.DataFrame(columns=['Category', 'Actual', 'Simulated', 'Change'])
|
682 |
+
scenario_df.at[0, 'Category'] = 'Media Spend'
|
683 |
+
|
684 |
+
scenario_df.at[0, 'Actual'] = scenario['actual_total_spends'] * multiplier
|
685 |
+
scenario_df.at[0, 'Simulated'] = scenario['modified_total_spends'] * multiplier
|
686 |
+
scenario_df.at[0, 'Change'] = (scenario['modified_total_spends'] - scenario['actual_total_spends']) * multiplier
|
687 |
+
|
688 |
+
scenario_df.at[1, 'Category'] = scenario['metrics_selected'].replace("_", " ").title()
|
689 |
+
scenario_df.at[1, 'Actual'] = scenario['actual_total_sales'] * multiplier
|
690 |
+
scenario_df.at[1, 'Simulated'] = (scenario['modified_total_sales']) * multiplier
|
691 |
+
scenario_df.at[1, 'Change'] = (scenario['modified_total_sales'] - scenario['actual_total_sales']) * multiplier
|
692 |
+
|
693 |
+
scenario_df.at[2, 'Category'] = 'CPA'
|
694 |
+
actual_cpa = scenario['actual_total_spends'] / scenario['actual_total_sales']
|
695 |
+
modified_cpa = scenario['modified_total_spends'] / scenario['modified_total_sales']
|
696 |
+
scenario_df.at[2, 'Actual'] = actual_cpa
|
697 |
+
scenario_df.at[2, 'Simulated'] = modified_cpa
|
698 |
+
scenario_df.at[2, 'Change'] = modified_cpa - actual_cpa
|
699 |
+
|
700 |
+
scenario_df.at[3, 'Category'] = 'ROI'
|
701 |
+
act_roi = scenario['actual_total_sales'] / scenario['actual_total_spends']
|
702 |
+
opt_roi = scenario['modified_total_sales'] / scenario['modified_total_spends']
|
703 |
+
scenario_df.at[3, 'Actual'] = act_roi
|
704 |
+
scenario_df.at[3, 'Simulated'] = opt_roi
|
705 |
+
scenario_df.at[3, 'Change'] = opt_roi - act_roi
|
706 |
+
|
707 |
+
for col in scenario_df.columns:
|
708 |
+
scenario_df[col] = scenario_df[col].apply(lambda x: round_off(x, 1))
|
709 |
+
scenario_df[col] = scenario_df[col].apply(lambda x: convert_number_to_abbreviation(x))
|
710 |
+
|
711 |
+
table_placeholder = slide.placeholders[ph_idx[1]]
|
712 |
+
fill_table_placeholder(table_placeholder, slide, scenario_df)
|
713 |
+
|
714 |
+
channel_spends_df = pd.DataFrame(columns=['Channel', 'Actual Spends', 'Optimized Spends'])
|
715 |
+
for i, channel in enumerate(scenario['channels'].values()):
|
716 |
+
channel_spends_df.at[i, 'Channel'] = channel['name']
|
717 |
+
channel_conversion_rate = channel[
|
718 |
+
"conversion_rate"
|
719 |
+
]
|
720 |
+
channel_spends_df.at[i, 'Actual Spends'] = (
|
721 |
+
channel["actual_total_spends"]
|
722 |
+
* channel_conversion_rate
|
723 |
+
) * multiplier
|
724 |
+
channel_spends_df.at[i, 'Optimized Spends'] = (
|
725 |
+
channel["modified_total_spends"]
|
726 |
+
* channel_conversion_rate
|
727 |
+
) * multiplier
|
728 |
+
channel_spends_df['Actual Spends'] = channel_spends_df['Actual Spends'].astype('float')
|
729 |
+
channel_spends_df['Optimized Spends'] = channel_spends_df['Optimized Spends'].astype('float')
|
730 |
+
|
731 |
+
for col in channel_spends_df.columns:
|
732 |
+
channel_spends_df[col] = channel_spends_df[col].apply(lambda x: round_off(x, 0))
|
733 |
+
|
734 |
+
# Sort data on Actual Spends
|
735 |
+
channel_spends_df.sort_values(by='Actual Spends', inplace=True, ascending=False)
|
736 |
+
|
737 |
+
# Create chart data
|
738 |
+
chart_data = CategoryChartData()
|
739 |
+
chart_data.categories = channel_spends_df['Channel']
|
740 |
+
for col in ['Actual Spends', 'Optimized Spends']:
|
741 |
+
chart_data.add_series(col, channel_spends_df[col])
|
742 |
+
|
743 |
+
chart_placeholder = slide.placeholders[ph_idx[2]]
|
744 |
+
|
745 |
+
# Add the chart to the slide
|
746 |
+
if isinstance(np.max(channel_spends_df.select_dtypes(exclude=['object', 'datetime'])),float):
|
747 |
+
bar_chart(chart_placeholder=chart_placeholder,
|
748 |
+
slide=slide,
|
749 |
+
chart_data=chart_data,
|
750 |
+
titles={'chart_title': 'Channel Wise Spends',
|
751 |
+
# 'x_axis':'Channels',
|
752 |
+
'y_axis': 'Spends'},
|
753 |
+
# min_y=np.floor(np.min(channel_spends_df.select_dtypes(exclude=['object', 'datetime']))),
|
754 |
+
min_y=0,
|
755 |
+
max_y=np.ceil(np.max(channel_spends_df.select_dtypes(exclude=['object', 'datetime']))),
|
756 |
+
label_type='$'
|
757 |
+
)
|
758 |
+
else:
|
759 |
+
# Add the chart to the slide
|
760 |
+
bar_chart(chart_placeholder=chart_placeholder,
|
761 |
+
slide=slide,
|
762 |
+
chart_data=chart_data,
|
763 |
+
titles={'chart_title': 'Channel Wise Spends',
|
764 |
+
# 'x_axis':'Channels',
|
765 |
+
'y_axis': 'Spends'},
|
766 |
+
# min_y=np.floor(np.min(channel_spends_df.select_dtypes(exclude=['object', 'datetime']))),
|
767 |
+
min_y=0,
|
768 |
+
max_y=np.ceil(np.max(channel_spends_df.select_dtypes(exclude=['object', 'datetime'])).values[0]),
|
769 |
+
label_type='$'
|
770 |
+
)
|
771 |
+
|
772 |
+
|
773 |
+
def channel_wise_spends(slide, scenario):
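# Per-channel view of a scenario: spend split %, spend delta % and incremental impact
# on the selected response metric.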
|
774 |
+
placeholders = slide.placeholders
|
775 |
+
ph_idx = [ph.placeholder_format.idx for ph in placeholders]
|
776 |
+
title_ph = slide.placeholders[ph_idx[0]]
|
777 |
+
title_ph.text = 'Channel Spends and Impact'
|
778 |
+
title_ph.text_frame.paragraphs[0].font.size = Pt(TITLE_FONT_SIZE)
|
779 |
+
# print(scenario.keys())
|
780 |
+
|
781 |
+
multiplier = 1 / float(scenario['multiplier'])
|
782 |
+
channel_spends_df = pd.DataFrame(columns=['Channel', 'Actual Spends', 'Optimized Spends'])
|
783 |
+
for i, channel in enumerate(scenario['channels'].values()):
|
784 |
+
channel_spends_df.at[i, 'Channel'] = channel['name']
|
785 |
+
channel_conversion_rate = channel["conversion_rate"]
|
786 |
+
channel_spends_df.at[i, 'Actual Spends'] = (channel[
|
787 |
+
"actual_total_spends"] * channel_conversion_rate) * multiplier
|
788 |
+
channel_spends_df.at[i, 'Optimized Spends'] = (channel[
|
789 |
+
"modified_total_spends"] * channel_conversion_rate) * multiplier
|
790 |
+
channel_spends_df['Actual Spends'] = channel_spends_df['Actual Spends'].astype('float')
|
791 |
+
channel_spends_df['Optimized Spends'] = channel_spends_df['Optimized Spends'].astype('float')
|
792 |
+
|
793 |
+
actual_sum = channel_spends_df['Actual Spends'].sum()
|
794 |
+
opt_sum = channel_spends_df['Optimized Spends'].sum()
|
795 |
+
|
796 |
+
for col in channel_spends_df.columns:
|
797 |
+
channel_spends_df[col] = channel_spends_df[col].apply(lambda x: round_off(x, 0))
|
798 |
+
|
799 |
+
channel_spends_df['Actual Spends %'] = 100 * (channel_spends_df['Actual Spends'] / actual_sum)
|
800 |
+
channel_spends_df['Optimized Spends %'] = 100 * (channel_spends_df['Optimized Spends'] / opt_sum)
|
801 |
+
channel_spends_df['Actual Spends %'] = np.round(channel_spends_df['Actual Spends %'])
|
802 |
+
channel_spends_df['Optimized Spends %'] = np.round(channel_spends_df['Optimized Spends %'])
|
803 |
+
|
804 |
+
# Sort Data based on Actual Spends %
|
805 |
+
channel_spends_df.sort_values(by='Actual Spends %', inplace=True)
|
806 |
+
|
807 |
+
# Create chart data
|
808 |
+
chart_data = CategoryChartData()
|
809 |
+
chart_data.categories = channel_spends_df['Channel']
|
810 |
+
for col in ['Actual Spends %', 'Optimized Spends %']:
|
811 |
+
# for col in ['Actual Spends %']:
|
812 |
+
chart_data.add_series(col, channel_spends_df[col])
|
813 |
+
chart_placeholder = slide.placeholders[ph_idx[1]]
|
814 |
+
|
815 |
+
# Add the chart to the slide
|
816 |
+
if isinstance(np.max(channel_spends_df[['Actual Spends %', 'Optimized Spends %']]), float):
|
817 |
+
bar_chart(chart_placeholder=chart_placeholder,
|
818 |
+
slide=slide,
|
819 |
+
chart_data=chart_data,
|
820 |
+
titles={'chart_title': 'Spend Split %',
|
821 |
+
# 'x_axis':'Channels',
|
822 |
+
'y_axis': 'Spend %'},
|
823 |
+
min_y=0,
|
824 |
+
max_y=np.ceil(np.max(channel_spends_df[['Actual Spends %', 'Optimized Spends %']])),
|
825 |
+
type='H',
|
826 |
+
legend=True,
|
827 |
+
label_type='per',
|
828 |
+
xaxis_pos='low'
|
829 |
+
)
|
830 |
+
else:
|
831 |
+
bar_chart(chart_placeholder=chart_placeholder,
|
832 |
+
slide=slide,
|
833 |
+
chart_data=chart_data,
|
834 |
+
titles={'chart_title': 'Spend Split %',
|
835 |
+
# 'x_axis':'Channels',
|
836 |
+
'y_axis': 'Spend %'},
|
837 |
+
min_y=0,
|
838 |
+
max_y=np.ceil(np.max(channel_spends_df[['Actual Spends %', 'Optimized Spends %']]).values[0]),
|
839 |
+
type='H',
|
840 |
+
legend=True,
|
841 |
+
label_type='per',
|
842 |
+
xaxis_pos='low'
|
843 |
+
)
|
844 |
+
#
|
845 |
+
# # Create chart data
|
846 |
+
# chart_data_1 = CategoryChartData()
|
847 |
+
# chart_data_1.categories = channel_spends_df['Channel']
|
848 |
+
# # for col in ['Actual Spends %', 'Optimized Spends %']:
|
849 |
+
# for col in ['Optimized Spends %']:
|
850 |
+
# chart_data_1.add_series(col, channel_spends_df[col])
|
851 |
+
# chart_placeholder_1 = slide.placeholders[ph_idx[3]]
|
852 |
+
#
|
853 |
+
# # Add the chart to the slide
|
854 |
+
# bar_chart(chart_placeholder=chart_placeholder_1,
|
855 |
+
# slide=slide,
|
856 |
+
# chart_data=chart_data_1,
|
857 |
+
# titles={'chart_title': 'Optimized Spends Split %',
|
858 |
+
# # 'x_axis':'Channels',
|
859 |
+
# 'y_axis': 'Spends %'},
|
860 |
+
# min_y=0,
|
861 |
+
# max_y=np.ceil(np.max(channel_spends_df[['Actual Spends %', 'Optimized Spends %']])),
|
862 |
+
# type='H',
|
863 |
+
# legend=False,
|
864 |
+
# label_type='per'
|
865 |
+
# )
|
866 |
+
|
867 |
+
channel_spends_df['Delta %'] = 100 * (channel_spends_df['Optimized Spends'] - channel_spends_df['Actual Spends']) / \
|
868 |
+
channel_spends_df['Actual Spends']
|
869 |
+
channel_spends_df['Delta %'] = channel_spends_df['Delta %'].apply(lambda x: round_off(x, 0))
|
870 |
+
|
871 |
+
# Create chart data
|
872 |
+
delta_chart_data = CategoryChartData()
|
873 |
+
delta_chart_data.categories = channel_spends_df['Channel']
|
874 |
+
col = 'Delta %'
|
875 |
+
delta_chart_data.add_series(col, channel_spends_df[col])
|
876 |
+
delta_chart_placeholder = slide.placeholders[ph_idx[3]]
|
877 |
+
|
878 |
+
# Add the chart to the slide
|
879 |
+
if isinstance(np.min(channel_spends_df['Delta %']), float):
|
880 |
+
bar_chart(chart_placeholder=delta_chart_placeholder,
|
881 |
+
slide=slide,
|
882 |
+
chart_data=delta_chart_data,
|
883 |
+
titles={'chart_title': 'Spend Delta %',
|
884 |
+
'y_axis': 'Spend Delta %'},
|
885 |
+
min_y=np.floor(np.min(channel_spends_df['Delta %'])),
|
886 |
+
max_y=np.ceil(np.max(channel_spends_df['Delta %'])),
|
887 |
+
type='H',
|
888 |
+
legend=False,
|
889 |
+
label_type='per',
|
890 |
+
xaxis_pos='low'
|
891 |
+
|
892 |
+
)
|
893 |
+
else:
|
894 |
+
bar_chart(chart_placeholder=delta_chart_placeholder,
|
895 |
+
slide=slide,
|
896 |
+
chart_data=delta_chart_data,
|
897 |
+
titles={'chart_title': 'Spend Delta %',
|
898 |
+
'y_axis': 'Spend Delta %'},
|
899 |
+
min_y=np.floor(np.min(channel_spends_df['Delta %']).values[0]),
|
900 |
+
max_y=np.ceil(np.max(channel_spends_df['Delta %']).values[0]),
|
901 |
+
type='H',
|
902 |
+
legend=False,
|
903 |
+
label_type='per',
|
904 |
+
xaxis_pos='low'
|
905 |
+
|
906 |
+
)
|
907 |
+
|
908 |
+
# Incremental Impact
|
909 |
+
channel_inc_df = pd.DataFrame(columns=['Channel', 'Increment'])
|
910 |
+
for i, channel in enumerate(scenario['channels'].values()):
|
911 |
+
channel_inc_df.at[i, 'Channel'] = channel['name']
|
912 |
+
act_impact = channel['actual_total_sales']
|
913 |
+
opt_impact = channel['modified_total_sales']
|
914 |
+
impact = opt_impact - act_impact
|
915 |
+
impact = round_off(impact, 0)
|
916 |
+
impact = impact if abs(impact) > 0.0001 else 0
|
917 |
+
channel_inc_df.at[i, 'Increment'] = impact
|
918 |
+
|
919 |
+
channel_inc_df_1 = pd.merge(channel_spends_df, channel_inc_df, how='left', on='Channel')
|
920 |
+
|
921 |
+
# Create chart data
|
922 |
+
delta_chart_data = CategoryChartData()
|
923 |
+
delta_chart_data.categories = channel_inc_df_1['Channel']
|
924 |
+
col = 'Increment'
|
925 |
+
delta_chart_data.add_series(col, channel_inc_df_1[col])
|
926 |
+
delta_chart_placeholder = slide.placeholders[ph_idx[2]]
|
927 |
+
|
928 |
+
label_req = True
|
929 |
+
if min(np.abs(channel_inc_df_1[col])) > 100000: # 0.1M
|
930 |
+
label_type = 'M'
|
931 |
+
elif min(np.abs(channel_inc_df_1[col])) > 10000 and max(np.abs(channel_inc_df_1[col])) > 1000000:
|
932 |
+
label_type = 'M1'
|
933 |
+
elif min(np.abs(channel_inc_df_1[col])) > 100 and max(np.abs(channel_inc_df_1[col])) > 1000:
|
934 |
+
label_type = 'K'
|
935 |
+
else:
|
936 |
+
label_req = False
|
937 |
+
# Add the chart to the slide
|
938 |
+
if label_req:
|
939 |
+
bar_chart(chart_placeholder=delta_chart_placeholder,
|
940 |
+
slide=slide,
|
941 |
+
chart_data=delta_chart_data,
|
942 |
+
titles={'chart_title': 'Incremental Impact',
|
943 |
+
'y_axis': format_response_metric(scenario['metrics_selected'])},
|
944 |
+
# min_y=np.floor(np.min(channel_inc_df_1['Delta %'])),
|
945 |
+
# max_y=np.ceil(np.max(channel_inc_df_1['Delta %'])),
|
946 |
+
type='H',
|
947 |
+
label_type=label_type,
|
948 |
+
legend=False,
|
949 |
+
xaxis_pos='low'
|
950 |
+
)
|
951 |
+
else:
|
952 |
+
bar_chart(chart_placeholder=delta_chart_placeholder,
|
953 |
+
slide=slide,
|
954 |
+
chart_data=delta_chart_data,
|
955 |
+
titles={'chart_title': 'Incremental Impact',
|
956 |
+
'y_axis': format_response_metric(scenario['metrics_selected'])},
|
957 |
+
# min_y=np.floor(np.min(channel_inc_df_1['Delta %'])),
|
958 |
+
# max_y=np.ceil(np.max(channel_inc_df_1['Delta %'])),
|
959 |
+
type='H',
|
960 |
+
legend=False,
|
961 |
+
xaxis_pos='low'
|
962 |
+
)
|
963 |
+
|
964 |
+
|
965 |
+
def channel_wise_roi(slide, scenario):
|
966 |
+
channel_roi_mroi = scenario['channel_roi_mroi']
|
967 |
+
|
968 |
+
# Add title
|
969 |
+
placeholders = slide.placeholders
|
970 |
+
ph_idx = [ph.placeholder_format.idx for ph in placeholders]
|
971 |
+
title_ph = slide.placeholders[ph_idx[0]]
|
972 |
+
title_ph.text = 'Channel ROIs'
|
973 |
+
title_ph.text_frame.paragraphs[0].font.size = Pt(TITLE_FONT_SIZE)
|
974 |
+
|
975 |
+
channel_roi_df = pd.DataFrame(columns=['Channel', 'Actual ROI', 'Optimized ROI'])
|
976 |
+
for i, channel in enumerate(channel_roi_mroi.keys()):
|
977 |
+
channel_roi_df.at[i, 'Channel'] = channel
|
978 |
+
channel_roi_df.at[i, 'Actual ROI'] = channel_roi_mroi[channel]['actual_roi']
|
979 |
+
channel_roi_df.at[i, 'Optimized ROI'] = channel_roi_mroi[channel]['optimized_roi']
|
980 |
+
channel_roi_df['Actual ROI'] = channel_roi_df['Actual ROI'].astype('float')
|
981 |
+
channel_roi_df['Optimized ROI'] = channel_roi_df['Optimized ROI'].astype('float')
|
982 |
+
|
983 |
+
for col in channel_roi_df.columns:
|
984 |
+
channel_roi_df[col] = channel_roi_df[col].apply(lambda x: round_off(x, 2))
|
985 |
+
|
986 |
+
# Create chart data
|
987 |
+
chart_data = CategoryChartData()
|
988 |
+
chart_data.categories = channel_roi_df['Channel']
|
989 |
+
for col in ['Actual ROI', 'Optimized ROI']:
|
990 |
+
chart_data.add_series(col, channel_roi_df[col])
|
991 |
+
|
992 |
+
chart_placeholder = slide.placeholders[ph_idx[1]]
|
993 |
+
|
994 |
+
# Add the chart to the slide
|
995 |
+
if isinstance(np.max(channel_roi_df.select_dtypes(exclude=['object', 'datetime'])), float):
|
996 |
+
bar_chart(chart_placeholder=chart_placeholder,
|
997 |
+
slide=slide,
|
998 |
+
chart_data=chart_data,
|
999 |
+
titles={'chart_title': 'Channel Wise ROI',
|
1000 |
+
# 'x_axis':'Channels',
|
1001 |
+
'y_axis': 'ROI'},
|
1002 |
+
# min_y=np.floor(np.min(channel_spends_df.select_dtypes(exclude=['object', 'datetime']))),
|
1003 |
+
min_y=0,
|
1004 |
+
max_y=np.max(channel_roi_df.select_dtypes(exclude=['object', 'datetime']))
|
1005 |
+
)
|
1006 |
+
else:
|
1007 |
+
bar_chart(chart_placeholder=chart_placeholder,
|
1008 |
+
slide=slide,
|
1009 |
+
chart_data=chart_data,
|
1010 |
+
titles={'chart_title': 'Channel Wise ROI',
|
1011 |
+
# 'x_axis':'Channels',
|
1012 |
+
'y_axis': 'ROI'},
|
1013 |
+
# min_y=np.floor(np.min(channel_spends_df.select_dtypes(exclude=['object', 'datetime']))),
|
1014 |
+
min_y=0,
|
1015 |
+
max_y=np.max(channel_roi_df.select_dtypes(exclude=['object', 'datetime'])).values[0]
|
1016 |
+
)
|
1017 |
+
# act_roi = scenario['actual_total_sales']/scenario['actual_total_spends']
|
1018 |
+
# opt_roi = scenario['modified_total_sales']/scenario['modified_total_spends']
|
1019 |
+
#
|
1020 |
+
# act_roi_ph = slide.placeholders[ph_idx[2]]
|
1021 |
+
# act_roi_ph.text = 'Actual ROI: ' + str(round_off(act_roi,2))
|
1022 |
+
# opt_roi_ph = slide.placeholders[ph_idx[3]]
|
1023 |
+
# opt_roi_ph.text = 'Optimized ROI: ' + str(round_off(opt_roi, 2))
|
1024 |
+
|
1025 |
+
## Removing mroi chart as per Ioannis' feedback
|
1026 |
+
# channel_mroi_df = pd.DataFrame(columns=['Channel', 'Actual mROI', 'Optimized mROI'])
|
1027 |
+
# for i, channel in enumerate(channel_roi_mroi.keys()):
|
1028 |
+
# channel_mroi_df.at[i, 'Channel'] = channel
|
1029 |
+
# channel_mroi_df.at[i, 'Actual mROI'] = channel_roi_mroi[channel]['actual_mroi']
|
1030 |
+
# channel_mroi_df.at[i, 'Optimized mROI'] = channel_roi_mroi[channel]['optimized_mroi']
|
1031 |
+
# channel_mroi_df['Actual mROI']=channel_mroi_df['Actual mROI'].astype('float')
|
1032 |
+
# channel_mroi_df['Optimized mROI']=channel_mroi_df['Optimized mROI'].astype('float')
|
1033 |
+
#
|
1034 |
+
# for col in channel_mroi_df.columns:
|
1035 |
+
# channel_mroi_df[col]=channel_mroi_df[col].apply(lambda x: round_off(x))
|
1036 |
+
#
|
1037 |
+
# # Create chart data
|
1038 |
+
# mroi_chart_data = CategoryChartData()
|
1039 |
+
# mroi_chart_data.categories = channel_mroi_df['Channel']
|
1040 |
+
# for col in ['Actual mROI', 'Optimized mROI']:
|
1041 |
+
# mroi_chart_data.add_series(col, channel_mroi_df[col])
|
1042 |
+
#
|
1043 |
+
# mroi_chart_placeholder=slide.placeholders[ph_idx[2]]
|
1044 |
+
#
|
1045 |
+
# # Add the chart to the slide
|
1046 |
+
# bar_chart(chart_placeholder=mroi_chart_placeholder,
|
1047 |
+
# slide=slide,
|
1048 |
+
# chart_data=mroi_chart_data,
|
1049 |
+
# titles={'chart_title':'Channel Wise mROI',
|
1050 |
+
# # 'x_axis':'Channels',
|
1051 |
+
# 'y_axis':'mROI'},
|
1052 |
+
# # min_y=np.floor(np.min(channel_mroi_df.select_dtypes(exclude=['object', 'datetime']))),
|
1053 |
+
# min_y=0,
|
1054 |
+
# max_y=np.ceil(np.max(channel_mroi_df.select_dtypes(exclude=['object', 'datetime'])))
|
1055 |
+
# )
|
1056 |
+
|
1057 |
+
|
1058 |
+
def effictiveness_efficiency(slide, final_data, bin_dct, scenario):
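# Effectiveness = total of each response metric; Efficiency = that total divided by the
# simulated total spend. Both are shown as horizontal bar charts.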
|
1059 |
+
# Add title
|
1060 |
+
placeholders = slide.placeholders
|
1061 |
+
ph_idx = [ph.placeholder_format.idx for ph in placeholders]
|
1062 |
+
title_ph = slide.placeholders[ph_idx[0]]
|
1063 |
+
title_ph.text = 'Effectiveness and Efficiency'
|
1064 |
+
title_ph.text_frame.paragraphs[0].font.size = Pt(TITLE_FONT_SIZE)
|
1065 |
+
|
1066 |
+
response_metrics = bin_dct['Response Metrics']
|
1067 |
+
|
1068 |
+
kpi_df = final_data[response_metrics].sum(axis=0).reset_index()
|
1069 |
+
kpi_df.columns = ['Response Metric', 'Effectiveness']
|
1070 |
+
kpi_df['Efficiency'] = kpi_df['Effectiveness'] / scenario['modified_total_spends']
|
1071 |
+
kpi_df['Efficiency'] = kpi_df['Efficiency'].apply(lambda x: round_off(x, 1))
|
1072 |
+
kpi_df.sort_values(by='Effectiveness', inplace=True)
|
1073 |
+
kpi_df['Response Metric'] = kpi_df['Response Metric'].apply(lambda x: format_response_metric(x))
|
1074 |
+
|
1075 |
+
# Create chart data for effectiveness
|
1076 |
+
chart_data = CategoryChartData()
|
1077 |
+
chart_data.categories = kpi_df['Response Metric']
|
1078 |
+
chart_data.add_series('Effectiveness', kpi_df['Effectiveness'])
|
1079 |
+
|
1080 |
+
chart_placeholder = slide.placeholders[ph_idx[1]]
|
1081 |
+
|
1082 |
+
# Add the chart to the slide
|
1083 |
+
bar_chart(chart_placeholder=chart_placeholder,
|
1084 |
+
slide=slide,
|
1085 |
+
chart_data=chart_data,
|
1086 |
+
titles={'chart_title': 'Effectiveness',
|
1087 |
+
# 'x_axis':'Channels',
|
1088 |
+
# 'y_axis': 'ROI'
|
1089 |
+
},
|
1090 |
+
# min_y=np.floor(np.min(channel_spends_df.select_dtypes(exclude=['object', 'datetime']))),
|
1091 |
+
min_y=0,
|
1092 |
+
# max_y=np.max(channel_roi_df.select_dtypes(exclude=['object', 'datetime'])),
|
1093 |
+
type='H',
|
1094 |
+
label_type='M'
|
1095 |
+
)
|
1096 |
+
|
1097 |
+
# Create chart data for efficiency
|
1098 |
+
chart_data_1 = CategoryChartData()
|
1099 |
+
chart_data_1.categories = kpi_df['Response Metric']
|
1100 |
+
chart_data_1.add_series('Efficiency', kpi_df['Efficiency'])
|
1101 |
+
|
1102 |
+
chart_placeholder_1 = slide.placeholders[ph_idx[2]]
|
1103 |
+
|
1104 |
+
# Add the chart to the slide
|
1105 |
+
bar_chart(chart_placeholder=chart_placeholder_1,
|
1106 |
+
slide=slide,
|
1107 |
+
chart_data=chart_data_1,
|
1108 |
+
titles={'chart_title': 'Efficiency',
|
1109 |
+
# 'x_axis':'Channels',
|
1110 |
+
# 'y_axis': 'ROI'
|
1111 |
+
},
|
1112 |
+
# min_y=np.floor(np.min(channel_spends_df.select_dtypes(exclude=['object', 'datetime']))),
|
1113 |
+
min_y=0,
|
1114 |
+
# max_y=np.max(channel_roi_df.select_dtypes(exclude=['object', 'datetime'])),
|
1115 |
+
type='H'
|
1116 |
+
)
|
1117 |
+
|
1118 |
+
definition_ph_1 = slide.placeholders[ph_idx[3]]
|
1119 |
+
definition_ph_1.text = 'Effectiveness is measured as the total sum of the Response Metric'
|
1120 |
+
definition_ph_2 = slide.placeholders[ph_idx[4]]
|
1121 |
+
definition_ph_2.text = 'Efficiency is measured as the ratio of the total Response Metric to the total Media Spend'
|
1122 |
+
|
1123 |
+
|
1124 |
+
def load_pickle(path):
|
1125 |
+
with open(path, "rb") as f:
|
1126 |
+
file_data = pickle.load(f)
|
1127 |
+
return file_data
|
1128 |
+
|
1129 |
+
|
1130 |
+
def read_all_files():
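# Gather deck inputs from session state / the database in the order create_ppt unpacks them:
# data, bin_dict, channel groups, transformations, tuned models, contributions, saved scenarios.
# Missing optional pieces are appended as None or left out, shortening the list.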
|
1131 |
+
files=[]
|
1132 |
+
|
1133 |
+
# Read data and bin dictionary
|
1134 |
+
if st.session_state["project_dct"]["data_import"]["imputed_tool_df"] is not None:
|
1135 |
+
final_df_loaded = st.session_state["project_dct"]["data_import"]["imputed_tool_df"].copy()
|
1136 |
+
bin_dict_loaded = st.session_state["project_dct"]["data_import"]["category_dict"].copy()
|
1137 |
+
|
1138 |
+
files.append(final_df_loaded)
|
1139 |
+
files.append(bin_dict_loaded)
|
1140 |
+
|
1141 |
+
if "group_dict" in st.session_state["project_dct"]["data_import"].keys():
|
1142 |
+
channels = st.session_state["project_dct"]["data_import"]["group_dict"]
|
1143 |
+
files.append(channels)
|
1144 |
+
|
1145 |
+
|
1146 |
+
if st.session_state["project_dct"]["transformations"]["final_df"] is not None:
|
1147 |
+
transform_dict = st.session_state["project_dct"]["transformations"]
|
1148 |
+
files.append(transform_dict)
|
1149 |
+
if retrieve_pkl_object_without_warning(st.session_state['project_number'], "Model_Tuning", "tuned_model", schema) is not None:
|
1150 |
+
tuned_model_dict = retrieve_pkl_object_without_warning(st.session_state['project_number'], "Model_Tuning",
|
1151 |
+
"tuned_model", schema) # db
|
1152 |
+
|
1153 |
+
files.append(tuned_model_dict)
|
1154 |
+
else:
|
1155 |
+
files.append(None)
|
1156 |
+
else:
|
1157 |
+
files.append(None)
|
1158 |
+
|
1159 |
+
if len(list(st.session_state["project_dct"]["current_media_performance"]["model_outputs"].keys()))>0: # check if there are model outputs for at least one metric
|
1160 |
+
metrics_list = list(st.session_state["project_dct"]["current_media_performance"]["model_outputs"].keys())
|
1161 |
+
contributions_excels_dict = {}
|
1162 |
+
for metrics in metrics_list:
|
1163 |
+
# raw_df = st.session_state["project_dct"]["current_media_performance"]["model_outputs"][metrics]["raw_data"]
|
1164 |
+
# spend_df = st.session_state["project_dct"]["current_media_performance"]["model_outputs"][metrics]["spends_data"]
|
1165 |
+
contribution_df = st.session_state["project_dct"]["current_media_performance"]["model_outputs"][metrics]["contribution_data"]
|
1166 |
+
contributions_excels_dict[metrics] = {'CONTRIBUTION MMM':contribution_df}
|
1167 |
+
files.append(contributions_excels_dict)
|
1168 |
+
|
1169 |
+
# Get Saved Scenarios
|
1170 |
+
if len(list(st.session_state["project_dct"]["saved_scenarios"]["saved_scenarios_dict"].keys()))>0:
|
1171 |
+
files.append(st.session_state["project_dct"]["saved_scenarios"]["saved_scenarios_dict"])
|
1172 |
+
|
1173 |
+
# saved_scenarios_loaded = get_saved_scenarios_dict(project_path)
|
1174 |
+
|
1175 |
+
|
1176 |
+
return files
|
1177 |
+
|
1178 |
+
|
1179 |
+
|
1180 |
+
'''
|
1181 |
+
|
1182 |
+
Template Layout
|
1183 |
+
|
1184 |
+
0 : Title
|
1185 |
+
1 : Data Details Section {no changes required}
|
1186 |
+
2 : Data Import
|
1187 |
+
3 : Data Import - Channel Groups
|
1188 |
+
4 : Model Results {Duplicate for each model}
|
1189 |
+
5 : Metrics Contribution
|
1190 |
+
6 : Media performance {Duplicate for each model}
|
1191 |
+
7 : Media performance Tabular View {Duplicate for each model}
|
1192 |
+
8 : Optimization Section {no changes}
|
1193 |
+
9 : Optimization Summary {Duplicate for each section}
|
1194 |
+
10 : Channel Spends {Duplicate for each model}
|
1195 |
+
11 : Channel Wise ROI {Duplicate for each model}
|
1196 |
+
12 : Effectiveness & Efficiency
|
1197 |
+
13 : Appendix
|
1198 |
+
14 : Transformations
|
1199 |
+
15 : Model Summary
|
1200 |
+
16 : Thank You Slide
|
1201 |
+
|
1202 |
+
'''
|
1203 |
+
|
1204 |
+
|
1205 |
+
def create_ppt(project_name, username, panel_col):
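# Build the summary deck by cloning slide layouts from ppt/template.pptx (see the layout
# map above) and filling each placeholder with the helpers defined in this module.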
|
1206 |
+
# Read saved files
|
1207 |
+
files = read_all_files()
|
1208 |
+
transform_dict, tuned_model_dict, contributions_excels_dict, saved_scenarios_loaded = None, None, None, None
|
1209 |
+
|
1210 |
+
if len(files)>0:
|
1211 |
+
# saved_data = files[0]
|
1212 |
+
data = files[0]
|
1213 |
+
bin_dict = files[1]
|
1214 |
+
|
1215 |
+
channel_groups_dct = files[2]
|
1216 |
+
try:
|
1217 |
+
transform_dict = files[3]
|
1218 |
+
tuned_model_dict = files[4]
|
1219 |
+
contributions_excels_dict = files[5]
|
1220 |
+
saved_scenarios_loaded = files[6]
|
1221 |
+
except Exception as e:
|
1222 |
+
print(e)
|
1223 |
+
|
1224 |
+
else:
|
1225 |
+
return False
|
1226 |
+
|
1227 |
+
is_panel = data[panel_col].nunique() > 1
|
1228 |
+
|
1229 |
+
template_path = 'ppt/template.pptx'
|
1230 |
+
# ppt_path = os.path.join('ProjectSummary.pptx')
|
1231 |
+
|
1232 |
+
prs = Presentation(template_path)
|
1233 |
+
num_slides = len(prs.slides)
|
1234 |
+
slides = prs.slides
|
1235 |
+
|
1236 |
+
# Title Slide
|
1237 |
+
title_slide_layout = slides[0].slide_layout
|
1238 |
+
title_slide = prs.slides.add_slide(title_slide_layout)
|
1239 |
+
|
1240 |
+
# Add title & project name
|
1241 |
+
placeholders = title_slide.placeholders
|
1242 |
+
ph_idx = [ph.placeholder_format.idx for ph in placeholders]
|
1243 |
+
title_ph = title_slide.placeholders[ph_idx[0]]
|
1244 |
+
title_ph.text = 'Media Mix Optimization Summary'
|
1245 |
+
txt_ph = title_slide.placeholders[ph_idx[1]]
|
1246 |
+
txt_ph.text = 'Project Name: ' + project_name + '\nCreated By: ' + username
|
1247 |
+
|
1248 |
+
# Model Details Section
|
1249 |
+
model_section_slide_layout = slides[1].slide_layout
|
1250 |
+
model_section_slide = prs.slides.add_slide(model_section_slide_layout)
|
1251 |
+
|
1252 |
+
## Add title
|
1253 |
+
placeholders = model_section_slide.placeholders
|
1254 |
+
ph_idx = [ph.placeholder_format.idx for ph in placeholders]
|
1255 |
+
title_ph = model_section_slide.placeholders[ph_idx[0]]
|
1256 |
+
title_ph.text = 'Model Details'
|
1257 |
+
section_ph = model_section_slide.placeholders[ph_idx[1]]
|
1258 |
+
section_ph.text = 'Section 1'
|
1259 |
+
|
1260 |
+
# Data Import
|
1261 |
+
data_import_slide_layout = slides[2].slide_layout
|
1262 |
+
data_import_slide = prs.slides.add_slide(data_import_slide_layout)
|
1263 |
+
data_import_slide = title_and_table(slide=data_import_slide,
|
1264 |
+
title='Data Import',
|
1265 |
+
df=data_import(data, bin_dict),
|
1266 |
+
column_width={0: 2, 1: 7}
|
1267 |
+
)
|
1268 |
+
|
1269 |
+
# Channel Groups
|
1270 |
+
channel_group_slide_layout = slides[3].slide_layout
|
1271 |
+
channel_group_slide = prs.slides.add_slide(channel_group_slide_layout)
|
1272 |
+
channel_group_slide = title_and_table(slide=channel_group_slide,
|
1273 |
+
title='Channels - Media and Spend',
|
1274 |
+
df=channel_groups_df(channel_groups_dct, bin_dict),
|
1275 |
+
column_width={0: 2, 1: 5, 2: 2}
|
1276 |
+
)
|
1277 |
+
|
1278 |
+
if tuned_model_dict is not None:
|
1279 |
+
model_metrics_df = model_metrics(tuned_model_dict, False)
|
1280 |
+
|
1281 |
+
# Model Results
|
1282 |
+
for model_key, model_dict in tuned_model_dict.items():
|
1283 |
+
model_result_slide_layout = slides[4].slide_layout
|
1284 |
+
model_result_slide = prs.slides.add_slide(model_result_slide_layout)
|
1285 |
+
model_result_slide = model_result(slide=model_result_slide,
|
1286 |
+
model_key=model_key,
|
1287 |
+
model_dict=model_dict,
|
1288 |
+
model_metrics_df=model_metrics_df,
|
1289 |
+
date_col='date')
|
1290 |
+
|
1291 |
+
if contributions_excels_dict is not None:
|
1292 |
+
|
1293 |
+
# Metrics Contributions
|
1294 |
+
metrics_contributions_slide_layout = slides[5].slide_layout
|
1295 |
+
metrics_contributions_slide = prs.slides.add_slide(metrics_contributions_slide_layout)
|
1296 |
+
metrics_contributions_slide = metrics_contributions(slide=metrics_contributions_slide,
|
1297 |
+
contributions_excels_dict=contributions_excels_dict,
|
1298 |
+
panel_col=panel_col
|
1299 |
+
)
|
1300 |
+
|
1301 |
+
# Media Performance
|
1302 |
+
for target in contributions_excels_dict.keys():
|
1303 |
+
|
1304 |
+
# Chart
|
1305 |
+
model_media_perf_slide_layout = slides[6].slide_layout
|
1306 |
+
model_media_perf_slide = prs.slides.add_slide(model_media_perf_slide_layout)
|
1307 |
+
contribution_df, spends_df = model_media_performance(slide=model_media_perf_slide,
|
1308 |
+
target=target,
|
1309 |
+
contributions_excels_dict=contributions_excels_dict
|
1310 |
+
)
|
1311 |
+
|
1312 |
+
# Tabular View
|
1313 |
+
contri_spends_df = pd.merge(spends_df, contribution_df, on='Channel', how='outer')
|
1314 |
+
contri_spends_df.fillna(0, inplace=True)
|
1315 |
+
|
1316 |
+
for col in [c for c in contri_spends_df.columns if c != 'Channel']:
|
1317 |
+
contri_spends_df[col] = contri_spends_df[col].apply(lambda x: safe_num_to_per(x))
|
1318 |
+
|
1319 |
+
media_performance_table_slide_layout = slides[7].slide_layout
|
1320 |
+
media_performance_table_slide = prs.slides.add_slide(media_performance_table_slide_layout)
|
1321 |
+
media_performance_table_slide = title_and_table(slide=media_performance_table_slide,
|
1322 |
+
title='Media and Spends Channels Tabular View',
|
1323 |
+
df=contri_spends_df,
|
1324 |
+
# column_width={0:2, 1:5, 2:2}
|
1325 |
+
)
|
1326 |
+
|
1327 |
+
if saved_scenarios_loaded is not None:
|
1328 |
+
# Optimization Details
|
1329 |
+
opt_section_slide_layout = slides[8].slide_layout
|
1330 |
+
opt_section_slide = prs.slides.add_slide(opt_section_slide_layout)
|
1331 |
+
|
1332 |
+
## Add title
|
1333 |
+
placeholders = opt_section_slide.placeholders
|
1334 |
+
ph_idx = [ph.placeholder_format.idx for ph in placeholders]
|
1335 |
+
title_ph = opt_section_slide.placeholders[ph_idx[0]]
|
1336 |
+
title_ph.text = 'Optimizations Details'
|
1337 |
+
section_ph = opt_section_slide.placeholders[ph_idx[1]]
|
1338 |
+
section_ph.text = 'Section 2'
|
1339 |
+
|
1340 |
+
# Optimization
|
1341 |
+
for scenario_name, scenario in saved_scenarios_loaded.items():
|
1342 |
+
opt_summary_slide_layout = slides[9].slide_layout
|
1343 |
+
opt_summary_slide = prs.slides.add_slide(opt_summary_slide_layout)
|
1344 |
+
optimization_summary(opt_summary_slide, scenario, scenario_name)
|
1345 |
+
|
1346 |
+
channel_spends_slide_layout = slides[10].slide_layout
|
1347 |
+
channel_spends_slide = prs.slides.add_slide(channel_spends_slide_layout)
|
1348 |
+
channel_wise_spends(channel_spends_slide, scenario)
|
1349 |
+
|
1350 |
+
channel_roi_slide_layout = slides[11].slide_layout
|
1351 |
+
channel_roi_slide = prs.slides.add_slide(channel_roi_slide_layout)
|
1352 |
+
channel_wise_roi(channel_roi_slide, scenario)
|
1353 |
+
|
1354 |
+
effictiveness_efficiency_slide_layout = slides[12].slide_layout
|
1355 |
+
effictiveness_efficiency_slide = prs.slides.add_slide(effictiveness_efficiency_slide_layout)
|
1356 |
+
effictiveness_efficiency(effictiveness_efficiency_slide,
|
1357 |
+
data,
|
1358 |
+
bin_dict,
|
1359 |
+
scenario)
|
1360 |
+
|
1361 |
+
# Appendix Section
|
1362 |
+
appendix_section_slide_layout = slides[13].slide_layout
|
1363 |
+
appendix_section_slide = prs.slides.add_slide(appendix_section_slide_layout)
|
1364 |
+
|
1365 |
+
if tuned_model_dict is not None:
|
1366 |
+
|
1367 |
+
## Add title
|
1368 |
+
placeholders = appendix_section_slide.placeholders
|
1369 |
+
ph_idx = [ph.placeholder_format.idx for ph in placeholders]
|
1370 |
+
title_ph = appendix_section_slide.placeholders[ph_idx[0]]
|
1371 |
+
title_ph.text = 'Appendix'
|
1372 |
+
section_ph = appendix_section_slide.placeholders[ph_idx[1]]
|
1373 |
+
section_ph.text = 'Section 3'
|
1374 |
+
|
1375 |
+
# Add transformations
|
1376 |
+
# if transform_dict is not None:
|
1377 |
+
# # Transformations
|
1378 |
+
# transformation_slide_layout = slides[14].slide_layout
|
1379 |
+
# transformation_slide = prs.slides.add_slide(transformation_slide_layout)
|
1380 |
+
# transformation_slide = title_and_table(slide=transformation_slide,
|
1381 |
+
# title='Transformations',
|
1382 |
+
# df=transformations(transform_dict),
|
1383 |
+
# custom_table_height=True
|
1384 |
+
# )
|
1385 |
+
|
1386 |
+
# Add model summary
|
1387 |
+
# Model Summary
|
1388 |
+
model_metrics_df = model_metrics(tuned_model_dict, False)
|
1389 |
+
model_summary_slide_layout = slides[15].slide_layout
|
1390 |
+
model_summary_slide = prs.slides.add_slide(model_summary_slide_layout)
|
1391 |
+
model_summary_slide = title_and_table(slide=model_summary_slide,
|
1392 |
+
title='Model Summary',
|
1393 |
+
df=model_metrics_df,
|
1394 |
+
custom_table_height=True
|
1395 |
+
)
|
1396 |
+
|
1397 |
+
# Last Slide
|
1398 |
+
last_slide_layout = slides[num_slides - 1].slide_layout
|
1399 |
+
last_slide = prs.slides.add_slide(last_slide_layout)
|
1400 |
+
|
1401 |
+
# Add title
|
1402 |
+
placeholders = last_slide.placeholders
|
1403 |
+
ph_idx = [ph.placeholder_format.idx for ph in placeholders]
|
1404 |
+
title_ph = last_slide.placeholders[ph_idx[0]]
|
1405 |
+
title_ph.text = 'Thank You'
|
1406 |
+
|
1407 |
+
# Remove template slides
|
1408 |
+
xml_slides = prs.slides._sldIdLst
|
1409 |
+
slides = list(xml_slides)
|
1410 |
+
for index in range(num_slides):
|
1411 |
+
xml_slides.remove(slides[index])
|
1412 |
+
|
1413 |
+
# prs.save(ppt_path)
|
1414 |
+
|
1415 |
+
# save the output into binary form
|
1416 |
+
binary_output = BytesIO()
|
1417 |
+
prs.save(binary_output)
|
1418 |
+
|
1419 |
+
return binary_output
|
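# Illustrative usage only (not part of the uploaded file): a minimal sketch of how the
# BytesIO returned by create_ppt() could be wired into a Streamlit download button.
# The project name, username, panel column, button label and file name below are
# assumptions for the example, not values taken from this repository.
import streamlit as st

ppt_buffer = create_ppt(project_name="demo_project", username="demo_user", panel_col="panel")
if ppt_buffer:
    st.download_button(
        label="Download Project Summary (PPT)",
        data=ppt_buffer.getvalue(),  # raw bytes of the generated .pptx
        file_name="ProjectSummary.pptx",
        mime="application/vnd.openxmlformats-officedocument.presentationml.presentation",
    )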
requirements.txt
ADDED
@@ -0,0 +1,96 @@
altair == 4.2.0
attrs == 23.1.0
bcrypt == 4.0.1
blinker == 1.6.2
cachetools == 5.3.1
certifi == 2023.7.22
charset-normalizer == 3.2.0
click == 8.1.7
colorama == 0.4.6
contourpy == 1.1.1
cycler == 0.11.0
dacite == 1.8.1
entrypoints == 0.4
et-xmlfile == 1.1.0
extra-streamlit-components == 0.1.56
fonttools == 4.42.1
gitdb == 4.0.10
GitPython == 3.1.35
htmlmin == 0.1.12
idna == 3.4
ImageHash == 4.3.1
importlib-metadata == 6.8.0
importlib-resources == 6.1.0
Jinja2 == 3.1.2
joblib == 1.3.2
jsonschema == 4.19.0
jsonschema-specifications == 2023.7.1
kaleido == 0.2.1
kiwisolver == 1.4.5
markdown-it-py == 3.0.0
MarkupSafe == 2.1.3
matplotlib == 3.7.0
mdurl == 0.1.2
networkx == 3.1
numerize == 0.12
numpy == 1.23.5
openpyxl>=3.1.0
packaging == 23.1
pandas == 1.5.2
pandas-profiling == 3.6.6
patsy == 0.5.3
phik == 0.12.3
Pillow == 10.0.0
pip == 23.2.1
plotly == 5.11.0
protobuf == 3.20.3
pyarrow == 13.0.0
pydantic == 1.10.13
pydeck == 0.8.1b0
Pygments == 2.16.1
PyJWT == 2.8.0
Pympler == 1.0.1
pyparsing == 3.1.1
python-dateutil == 2.8.2
python-decouple == 3.8
pytz == 2023.3.post1
PyWavelets == 1.4.1
PyYAML == 6.0.1
referencing == 0.30.2
requests == 2.31.0
rich == 13.5.2
rpds-py == 0.10.2
scikit-learn == 1.1.3
scipy == 1.9.3
seaborn == 0.12.2
semver == 3.0.1
setuptools == 68.1.2
six == 1.16.0
smmap == 5.0.0
statsmodels == 0.14.0
streamlit == 1.28.0
streamlit-aggrid == 0.3.4.post3
streamlit-authenticator == 0.2.1
streamlit-pandas-profiling == 0.1.3
sweetviz == 2.2.1
tangled-up-in-unicode == 0.2.0
tenacity == 8.2.3
threadpoolctl == 3.2.0
toml == 0.10.2
toolz == 0.12.0
tornado == 6.3.3
tqdm == 4.66.1
typeguard == 2.13.3
typing_extensions == 4.7.1
tzdata == 2023.3
tzlocal == 5.0.1
urllib3 == 2.0.4
validators == 0.22.0
visions == 0.7.5
watchdog == 3.0.0
wheel == 0.41.2
wordcloud == 1.9.2
ydata-profiling == 4.5.1
zipp == 3.16.2
psycopg2 == 2.9.9
python-pptx == 0.6.21
scenario.py
ADDED
@@ -0,0 +1,763 @@
1 |
+
import numpy as np
|
2 |
+
from scipy.optimize import minimize, LinearConstraint, NonlinearConstraint
|
3 |
+
from collections import OrderedDict
|
4 |
+
import pandas as pd
|
5 |
+
from decimal import Decimal
|
6 |
+
|
7 |
+
|
8 |
+
def round_num(n, decimal=2):
|
9 |
+
n = Decimal(n)
|
10 |
+
return n.to_integral() if n == n.to_integral() else round(n.normalize(), decimal)
|
11 |
+
|
12 |
+
|
13 |
+
def numerize(n, decimal=2):
|
14 |
+
# 60 sufixes
|
15 |
+
sufixes = [
|
16 |
+
"",
|
17 |
+
"K",
|
18 |
+
"M",
|
19 |
+
"B",
|
20 |
+
"T",
|
21 |
+
"Qa",
|
22 |
+
"Qu",
|
23 |
+
"S",
|
24 |
+
"Oc",
|
25 |
+
"No",
|
26 |
+
"D",
|
27 |
+
"Ud",
|
28 |
+
"Dd",
|
29 |
+
"Td",
|
30 |
+
"Qt",
|
31 |
+
"Qi",
|
32 |
+
"Se",
|
33 |
+
"Od",
|
34 |
+
"Nd",
|
35 |
+
"V",
|
36 |
+
"Uv",
|
37 |
+
"Dv",
|
38 |
+
"Tv",
|
39 |
+
"Qv",
|
40 |
+
"Qx",
|
41 |
+
"Sx",
|
42 |
+
"Ox",
|
43 |
+
"Nx",
|
44 |
+
"Tn",
|
45 |
+
"Qa",
|
46 |
+
"Qu",
|
47 |
+
"S",
|
48 |
+
"Oc",
|
49 |
+
"No",
|
50 |
+
"D",
|
51 |
+
"Ud",
|
52 |
+
"Dd",
|
53 |
+
"Td",
|
54 |
+
"Qt",
|
55 |
+
"Qi",
|
56 |
+
"Se",
|
57 |
+
"Od",
|
58 |
+
"Nd",
|
59 |
+
"V",
|
60 |
+
"Uv",
|
61 |
+
"Dv",
|
62 |
+
"Tv",
|
63 |
+
"Qv",
|
64 |
+
"Qx",
|
65 |
+
"Sx",
|
66 |
+
"Ox",
|
67 |
+
"Nx",
|
68 |
+
"Tn",
|
69 |
+
"x",
|
70 |
+
"xx",
|
71 |
+
"xxx",
|
72 |
+
"X",
|
73 |
+
"XX",
|
74 |
+
"XXX",
|
75 |
+
"END",
|
76 |
+
]
|
77 |
+
|
78 |
+
sci_expr = [
|
79 |
+
1e0,
|
80 |
+
1e3,
|
81 |
+
1e6,
|
82 |
+
1e9,
|
83 |
+
1e12,
|
84 |
+
1e15,
|
85 |
+
1e18,
|
86 |
+
1e21,
|
87 |
+
1e24,
|
88 |
+
1e27,
|
89 |
+
1e30,
|
90 |
+
1e33,
|
91 |
+
1e36,
|
92 |
+
1e39,
|
93 |
+
1e42,
|
94 |
+
1e45,
|
95 |
+
1e48,
|
96 |
+
1e51,
|
97 |
+
1e54,
|
98 |
+
1e57,
|
99 |
+
1e60,
|
100 |
+
1e63,
|
101 |
+
1e66,
|
102 |
+
1e69,
|
103 |
+
1e72,
|
104 |
+
1e75,
|
105 |
+
1e78,
|
106 |
+
1e81,
|
107 |
+
1e84,
|
108 |
+
1e87,
|
109 |
+
1e90,
|
110 |
+
1e93,
|
111 |
+
1e96,
|
112 |
+
1e99,
|
113 |
+
1e102,
|
114 |
+
1e105,
|
115 |
+
1e108,
|
116 |
+
1e111,
|
117 |
+
1e114,
|
118 |
+
1e117,
|
119 |
+
1e120,
|
120 |
+
1e123,
|
121 |
+
1e126,
|
122 |
+
1e129,
|
123 |
+
1e132,
|
124 |
+
1e135,
|
125 |
+
1e138,
|
126 |
+
1e141,
|
127 |
+
1e144,
|
128 |
+
1e147,
|
129 |
+
1e150,
|
130 |
+
1e153,
|
131 |
+
1e156,
|
132 |
+
1e159,
|
133 |
+
1e162,
|
134 |
+
1e165,
|
135 |
+
1e168,
|
136 |
+
1e171,
|
137 |
+
1e174,
|
138 |
+
1e177,
|
139 |
+
]
|
140 |
+
minus_buff = n
|
141 |
+
n = abs(n)
|
142 |
+
|
143 |
+
if n < 1:
|
144 |
+
return f"{round(n/1000, decimal)}K"
|
145 |
+
|
146 |
+
for x in range(len(sci_expr)):
|
147 |
+
try:
|
148 |
+
if n >= sci_expr[x] and n < sci_expr[x + 1]:
|
149 |
+
sufix = sufixes[x]
|
150 |
+
if n >= 1e3:
|
151 |
+
num = str(round_num(n / sci_expr[x], decimal))
|
152 |
+
else:
|
153 |
+
num = str(round_num(n, decimal))
|
154 |
+
return num + sufix if minus_buff > 0 else "-" + num + sufix
|
155 |
+
except IndexError:
|
156 |
+
pass
|
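# Illustrative aside (not part of the uploaded file): numerize() above condenses a number
# into a short string with a thousands suffix (K, M, B, ...), using round_num() for the
# mantissa, e.g. 2_300_000 maps to the "M" suffix with a two-decimal mantissa.
# A minimal usage sketch, assuming the functions behave exactly as defined above:
def _demo_numerize(values=(950, 1500, 2_300_000, -7_200_000)):  # hypothetical helper
    for value in values:
        print(value, "->", numerize(value))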
157 |
+
|
158 |
+
|
159 |
+
def class_to_dict(class_instance):
|
160 |
+
attr_dict = {}
|
161 |
+
if isinstance(class_instance, Channel):
|
162 |
+
attr_dict["type"] = "Channel"
|
163 |
+
attr_dict["name"] = class_instance.name
|
164 |
+
attr_dict["dates"] = class_instance.dates
|
165 |
+
attr_dict["spends"] = class_instance.actual_spends
|
166 |
+
attr_dict["conversion_rate"] = class_instance.conversion_rate
|
167 |
+
attr_dict["modified_spends"] = class_instance.modified_spends
|
168 |
+
attr_dict["modified_sales"] = class_instance.modified_sales
|
169 |
+
attr_dict["response_curve_type"] = class_instance.response_curve_type
|
170 |
+
attr_dict["response_curve_params"] = class_instance.response_curve_params
|
171 |
+
attr_dict["penalty"] = class_instance.penalty
|
172 |
+
attr_dict["bounds"] = class_instance.bounds
|
173 |
+
attr_dict["actual_total_spends"] = class_instance.actual_total_spends
|
174 |
+
attr_dict["actual_total_sales"] = class_instance.actual_total_sales
|
175 |
+
attr_dict["modified_total_spends"] = class_instance.modified_total_spends
|
176 |
+
attr_dict["modified_total_sales"] = class_instance.modified_total_sales
|
177 |
+
attr_dict["actual_mroi"] = class_instance.get_marginal_roi("actual")
|
178 |
+
attr_dict["modified_mroi"] = class_instance.get_marginal_roi("modified")
|
179 |
+
attr_dict["freeze"] = class_instance.freeze
|
180 |
+
attr_dict["correction"] = class_instance.correction
|
181 |
+
|
182 |
+
elif isinstance(class_instance, Scenario):
|
183 |
+
attr_dict["type"] = "Scenario"
|
184 |
+
attr_dict["name"] = class_instance.name
|
185 |
+
attr_dict["bounds"] = class_instance.bounds
|
186 |
+
channels = []
|
187 |
+
for channel in class_instance.channels.values():
|
188 |
+
channels.append(class_to_dict(channel))
|
189 |
+
attr_dict["channels"] = channels
|
190 |
+
attr_dict["constant"] = class_instance.constant
|
191 |
+
attr_dict["correction"] = class_instance.correction
|
192 |
+
attr_dict["actual_total_spends"] = class_instance.actual_total_spends
|
193 |
+
attr_dict["actual_total_sales"] = class_instance.actual_total_sales
|
194 |
+
attr_dict["modified_total_spends"] = class_instance.modified_total_spends
|
195 |
+
attr_dict["modified_total_sales"] = class_instance.modified_total_sales
|
196 |
+
|
197 |
+
return attr_dict
|
198 |
+
|
199 |
+
|
200 |
+
# def class_convert_to_dict(class_instance):
|
201 |
+
# attr_dict = {}
|
202 |
+
# if isinstance(class_instance, Channel):
|
203 |
+
# attr_dict["type"] = "Channel"
|
204 |
+
# attr_dict["name"] = class_instance.name
|
205 |
+
# attr_dict["dates"] = class_instance.dates
|
206 |
+
# attr_dict["spends"] = class_instance.actual_spends
|
207 |
+
# attr_dict["conversion_rate"] = class_instance.conversion_rate
|
208 |
+
# attr_dict["modified_spends"] = class_instance.modified_spends
|
209 |
+
# attr_dict["modified_sales"] = class_instance.modified_sales
|
210 |
+
# attr_dict["response_curve_type"] = class_instance.response_curve_type
|
211 |
+
# attr_dict["response_curve_params"] = class_instance.response_curve_params
|
212 |
+
# # attr_dict["penalty"] = class_instance.penalty
|
213 |
+
# attr_dict["bounds"] = class_instance.bounds
|
214 |
+
# attr_dict["actual_total_spends"] = class_instance.actual_total_spends
|
215 |
+
# attr_dict["actual_total_sales"] = class_instance.actual_total_sales
|
216 |
+
# attr_dict["modified_total_spends"] = class_instance.modified_total_spends
|
217 |
+
# attr_dict["modified_total_sales"] = class_instance.modified_total_sales
|
218 |
+
# # attr_dict["actual_mroi"] = class_instance.get_marginal_roi("actual")
|
219 |
+
# # attr_dict["modified_mroi"] = class_instance.get_marginal_roi("modified")
|
220 |
+
|
221 |
+
# attr_dict["freeze"] = class_instance.freeze
|
222 |
+
# attr_dict["correction"] = class_instance.correction
|
223 |
+
|
224 |
+
# elif isinstance(class_instance, Scenario):
|
225 |
+
# attr_dict["type"] = "Scenario"
|
226 |
+
# attr_dict["name"] = class_instance.name
|
227 |
+
# channels = {}
|
228 |
+
# for channel in class_instance.channels.values():
|
229 |
+
# channels[channel.name] = class_to_dict(channel)
|
230 |
+
# attr_dict["channels"] = channels
|
231 |
+
# attr_dict["constant"] = class_instance.constant
|
232 |
+
# attr_dict["correction"] = class_instance.correction
|
233 |
+
# attr_dict["actual_total_spends"] = class_instance.actual_total_spends
|
234 |
+
# attr_dict["actual_total_sales"] = class_instance.actual_total_sales
|
235 |
+
# attr_dict["modified_total_spends"] = class_instance.modified_total_spends
|
236 |
+
# attr_dict["modified_total_sales"] = class_instance.modified_total_sales
|
237 |
+
|
238 |
+
# attr_dict["bound_type"] = class_instance.bound_type
|
239 |
+
|
240 |
+
# attr_dict["bounds"] = class_instance.bounds
|
241 |
+
|
242 |
+
# return attr_dict.copy()
|
243 |
+
|
244 |
+
# Function to convert class instance to dictionary
|
245 |
+
def class_convert_to_dict(class_instance):
|
246 |
+
attr_dict = {}
|
247 |
+
|
248 |
+
if isinstance(class_instance, Channel):
|
249 |
+
# Convert Channel instance to dictionary
|
250 |
+
attr_dict["type"] = "Channel"
|
251 |
+
attr_dict["name"] = class_instance.name
|
252 |
+
attr_dict["dates"] = class_instance.dates
|
253 |
+
attr_dict["spends"] = list(class_instance.actual_spends)
|
254 |
+
attr_dict["conversion_rate"] = class_instance.conversion_rate
|
255 |
+
attr_dict["modified_spends"] = list(class_instance.modified_spends)
|
256 |
+
attr_dict["modified_sales"] = list(class_instance.modified_sales)
|
257 |
+
attr_dict["response_curve_type"] = class_instance.response_curve_type
|
258 |
+
attr_dict["response_curve_params"] = class_instance.response_curve_params.copy()
|
259 |
+
attr_dict["bounds"] = class_instance.bounds.copy()
|
260 |
+
attr_dict["actual_total_spends"] = class_instance.actual_total_spends
|
261 |
+
attr_dict["actual_total_sales"] = class_instance.actual_total_sales
|
262 |
+
attr_dict["modified_total_spends"] = class_instance.modified_total_spends
|
263 |
+
attr_dict["modified_total_sales"] = class_instance.modified_total_sales
|
264 |
+
attr_dict["freeze"] = class_instance.freeze
|
265 |
+
attr_dict["correction"] = class_instance.correction.copy()
|
266 |
+
|
267 |
+
elif isinstance(class_instance, Scenario):
|
268 |
+
# Convert Scenario instance to dictionary
|
269 |
+
attr_dict["type"] = "Scenario"
|
270 |
+
attr_dict["name"] = class_instance.name
|
271 |
+
|
272 |
+
channels = {}
|
273 |
+
for channel in class_instance.channels.values():
|
274 |
+
channels[channel.name] = class_convert_to_dict(channel)
|
275 |
+
attr_dict["channels"] = channels
|
276 |
+
|
277 |
+
attr_dict["constant"] = list(class_instance.constant)
|
278 |
+
attr_dict["correction"] = list(class_instance.correction)
|
279 |
+
attr_dict["actual_total_spends"] = class_instance.actual_total_spends
|
280 |
+
attr_dict["actual_total_sales"] = class_instance.actual_total_sales
|
281 |
+
attr_dict["modified_total_spends"] = class_instance.modified_total_spends
|
282 |
+
attr_dict["modified_total_sales"] = class_instance.modified_total_sales
|
283 |
+
attr_dict["bound_type"] = class_instance.bound_type
|
284 |
+
attr_dict["bounds"] = class_instance.bounds.copy()
|
285 |
+
|
286 |
+
return attr_dict
|
287 |
+
|
288 |
+
|
289 |
+
def class_from_dict(attr_dict):
|
290 |
+
if attr_dict["type"] == "Channel":
|
291 |
+
return Channel.from_dict(attr_dict)
|
292 |
+
elif attr_dict["type"] == "Scenario":
|
293 |
+
return Scenario.from_dict(attr_dict)
|
294 |
+
|
295 |
+
|
296 |
+
class Channel:
|
297 |
+
def __init__(
|
298 |
+
self,
|
299 |
+
name,
|
300 |
+
dates,
|
301 |
+
spends,
|
302 |
+
response_curve_type,
|
303 |
+
response_curve_params,
|
304 |
+
bounds,
|
305 |
+
correction,
|
306 |
+
conversion_rate=1,
|
307 |
+
modified_spends=None,
|
308 |
+
penalty=True,
|
309 |
+
freeze=False,
|
310 |
+
):
|
311 |
+
self.name = name
|
312 |
+
self.dates = dates
|
313 |
+
self.conversion_rate = conversion_rate
|
314 |
+
self.actual_spends = spends.copy()
|
315 |
+
self.correction = correction
|
316 |
+
|
317 |
+
if modified_spends is None:
|
318 |
+
self.modified_spends = self.actual_spends.copy()
|
319 |
+
else:
|
320 |
+
self.modified_spends = modified_spends
|
321 |
+
|
322 |
+
self.response_curve_type = response_curve_type
|
323 |
+
self.response_curve_params = response_curve_params
|
324 |
+
self.bounds = bounds
|
325 |
+
self.penalty = penalty
|
326 |
+
self.freeze = freeze
|
327 |
+
|
328 |
+
self.upper_limit = self.actual_spends.max() + self.actual_spends.std()
|
329 |
+
self.power = np.ceil(np.log(self.actual_spends.max()) / np.log(10)) - 3
|
330 |
+
# self.actual_sales = None
|
331 |
+
# self.actual_sales = self.response_curve(self.actual_spends)
|
332 |
+
self.actual_total_spends = self.actual_spends.sum()
|
333 |
+
self.actual_total_sales = self.actual_sales.sum()
|
334 |
+
self.modified_sales = self.calculate_sales()
|
335 |
+
self.modified_total_spends = self.modified_spends.sum()
|
336 |
+
self.modified_total_sales = self.modified_sales.sum()
|
337 |
+
self.delta_spends = self.modified_total_spends - self.actual_total_spends
|
338 |
+
self.delta_sales = self.modified_total_sales - self.actual_total_sales
|
339 |
+
|
340 |
+
@property
|
341 |
+
def actual_sales(self):
|
342 |
+
return self.response_curve(self.actual_spends) + self.correction
|
343 |
+
|
344 |
+
def update_penalty(self, penalty):
|
345 |
+
self.penalty = penalty
|
346 |
+
|
347 |
+
def _modify_spends(self, spends_array, total_spends):
|
348 |
+
return spends_array * total_spends / spends_array.sum()
|
349 |
+
|
350 |
+
def modify_spends(self, total_spends):
|
351 |
+
self.modified_spends = (
|
352 |
+
self.modified_spends * total_spends / self.modified_spends.sum()
|
353 |
+
)
|
354 |
+
|
355 |
+
def calculate_sales(self):
|
356 |
+
return self.response_curve(self.modified_spends) + self.correction
|
357 |
+
|
358 |
+
def response_curve(self, x):
|
359 |
+
if self.penalty:
|
360 |
+
x = np.where(
|
361 |
+
x < self.upper_limit,
|
362 |
+
x,
|
363 |
+
self.upper_limit + (x - self.upper_limit) * self.upper_limit / x,
|
364 |
+
)
|
365 |
+
if self.response_curve_type == "s-curve":
|
366 |
+
if self.power >= 0:
|
367 |
+
x = x / 10**self.power
|
368 |
+
x = x.astype("float64")
|
369 |
+
K = self.response_curve_params["K"]
|
370 |
+
b = self.response_curve_params["b"]
|
371 |
+
a = self.response_curve_params["a"]
|
372 |
+
x0 = self.response_curve_params["x0"]
|
373 |
+
sales = K / (1 + b * np.exp(-a * (x - x0)))
|
374 |
+
if self.response_curve_type == "linear":
|
375 |
+
beta = self.response_curve_params["beta"]
|
376 |
+
sales = beta * x
|
377 |
+
|
378 |
+
return sales
|
379 |
+
|
380 |
+
def get_marginal_roi(self, flag):
|
381 |
+
K = self.response_curve_params["K"]
|
382 |
+
a = self.response_curve_params["a"]
|
383 |
+
# x = self.modified_total_spends
|
384 |
+
# if self.power >= 0 :
|
385 |
+
# x = x / 10**self.power
|
386 |
+
# x = x.astype('float64')
|
387 |
+
# return K*b*a*np.exp(-a*(x-x0)) / (1 + b * np.exp(-a*(x - x0)))**2
|
388 |
+
if flag == "actual":
|
389 |
+
y = self.response_curve(self.actual_spends)
|
390 |
+
# spends_array = self.actual_spends
|
391 |
+
# total_spends = self.actual_total_spends
|
392 |
+
# total_sales = self.actual_total_sales
|
393 |
+
|
394 |
+
else:
|
395 |
+
y = self.response_curve(self.modified_spends)
|
396 |
+
# spends_array = self.modified_spends
|
397 |
+
# total_spends = self.modified_total_spends
|
398 |
+
# total_sales = self.modified_total_sales
|
399 |
+
|
400 |
+
# spends_inc_1 = self._modify_spends(spends_array, total_spends+1)
|
401 |
+
mroi = a * (y) * (1 - y / K)
|
402 |
+
return mroi.sum() / len(self.modified_spends)
|
403 |
+
# spends_inc_1 = self.spends_array + 1
|
404 |
+
# new_total_sales = self.response_curve(spends_inc_1).sum()
|
405 |
+
# return (new_total_sales - total_sales) / len(self.modified_spends)
|
406 |
+
|
407 |
+
def update(self, total_spends):
|
408 |
+
self.modify_spends(total_spends)
|
409 |
+
self.modified_sales = self.calculate_sales()
|
410 |
+
self.modified_total_spends = self.modified_spends.sum()
|
411 |
+
self.modified_total_sales = self.modified_sales.sum()
|
412 |
+
self.delta_spends = self.modified_total_spends - self.actual_total_spends
|
413 |
+
self.delta_sales = self.modified_total_sales - self.actual_total_sales
|
414 |
+
|
415 |
+
def intialize(self):
|
416 |
+
self.new_spends = self.old_spends
|
417 |
+
|
418 |
+
def __str__(self):
|
419 |
+
return f"{self.name},{self.actual_total_sales}, {self.modified_total_spends}"
|
420 |
+
|
421 |
+
@classmethod
|
422 |
+
def from_dict(cls, attr_dict):
|
423 |
+
return Channel(
|
424 |
+
name=attr_dict["name"],
|
425 |
+
dates=attr_dict["dates"],
|
426 |
+
spends=attr_dict["spends"],
|
427 |
+
bounds=attr_dict["bounds"],
|
428 |
+
modified_spends=attr_dict["modified_spends"],
|
429 |
+
response_curve_type=attr_dict["response_curve_type"],
|
430 |
+
response_curve_params=attr_dict["response_curve_params"],
|
431 |
+
penalty=attr_dict["penalty"],
|
432 |
+
correction=attr_dict["correction"],
|
433 |
+
)
|
434 |
+
|
435 |
+
def update_response_curves(self, response_curve_params):
|
436 |
+
self.response_curve_params = response_curve_params
|
437 |
+
|
438 |
+
|
439 |
+
class Scenario:
|
440 |
+
def __init__(
|
441 |
+
self, name, channels, constant, correction, bound_type=False, bounds=[-10, 10]
|
442 |
+
):
|
443 |
+
self.name = name
|
444 |
+
self.channels = channels
|
445 |
+
self.constant = constant
|
446 |
+
self.correction = correction
|
447 |
+
self.bound_type = bound_type
|
448 |
+
self.bounds = bounds
|
449 |
+
|
450 |
+
self.actual_total_spends = self.calculate_actual_total_spends()
|
451 |
+
self.actual_total_sales = self.calculate_actual_total_sales()
|
452 |
+
self.modified_total_sales = self.calculate_modified_total_sales()
|
453 |
+
self.modified_total_spends = self.calculate_modified_total_spends()
|
454 |
+
self.delta_spends = self.modified_total_spends - self.actual_total_spends
|
455 |
+
self.delta_sales = self.modified_total_sales - self.actual_total_sales
|
456 |
+
|
457 |
+
def update_penalty(self, value):
|
458 |
+
for channel in self.channels.values():
|
459 |
+
channel.update_penalty(value)
|
460 |
+
|
461 |
+
def calculate_actual_total_spends(self):
|
462 |
+
total_actual_spends = 0.0
|
463 |
+
for channel in self.channels.values():
|
464 |
+
total_actual_spends += channel.actual_total_spends * channel.conversion_rate
|
465 |
+
return total_actual_spends
|
466 |
+
|
467 |
+
def calculate_modified_total_spends(self):
|
468 |
+
total_modified_spends = 0.0
|
469 |
+
for channel in self.channels.values():
|
470 |
+
# import streamlit as st
|
471 |
+
# st.write(channel.modified_total_spends )
|
472 |
+
total_modified_spends += (
|
473 |
+
channel.modified_total_spends * channel.conversion_rate
|
474 |
+
)
|
475 |
+
return total_modified_spends
|
476 |
+
|
477 |
+
def calculate_actual_total_sales(self):
|
478 |
+
total_actual_sales = self.constant.sum() + self.correction.sum()
|
479 |
+
for channel in self.channels.values():
|
480 |
+
total_actual_sales += channel.actual_total_sales
|
481 |
+
return total_actual_sales
|
482 |
+
|
483 |
+
def calculate_modified_total_sales(self):
|
484 |
+
total_modified_sales = self.constant.sum() + self.correction.sum()
|
485 |
+
for channel in self.channels.values():
|
486 |
+
total_modified_sales += channel.modified_total_sales
|
487 |
+
return total_modified_sales
|
488 |
+
|
489 |
+
def update(self, channel_name, modified_spends):
|
490 |
+
self.channels[channel_name].update(modified_spends)
|
491 |
+
self.modified_total_sales = self.calculate_modified_total_sales()
|
492 |
+
self.modified_total_spends = self.calculate_modified_total_spends()
|
493 |
+
self.delta_spends = self.modified_total_spends - self.actual_total_spends
|
494 |
+
self.delta_sales = self.modified_total_sales - self.actual_total_sales
|
495 |
+
|
496 |
+
# def optimize_spends(self, sales_percent, channels_list, algo="COBYLA"):
|
497 |
+
# desired_sales = self.actual_total_sales * (1 + sales_percent / 100.0)
|
498 |
+
|
499 |
+
# def constraint(x):
|
500 |
+
# for ch, spends in zip(channels_list, x):
|
501 |
+
# self.update(ch, spends)
|
502 |
+
# return self.modified_total_sales - desired_sales
|
503 |
+
|
504 |
+
# bounds = []
|
505 |
+
# for ch in channels_list:
|
506 |
+
# bounds.append(
|
507 |
+
# (1 + np.array([-50.0, 100.0]) / 100.0)
|
508 |
+
# * self.channels[ch].actual_total_spends
|
509 |
+
# )
|
510 |
+
|
511 |
+
# initial_point = []
|
512 |
+
# for bound in bounds:
|
513 |
+
# initial_point.append(bound[0])
|
514 |
+
|
515 |
+
# power = np.ceil(np.log(sum(initial_point)) / np.log(10))
|
516 |
+
|
517 |
+
# constraints = [NonlinearConstraint(constraint, -1.0, 1.0)]
|
518 |
+
|
519 |
+
# res = minimize(
|
520 |
+
# lambda x: sum(x) / 10 ** (power),
|
521 |
+
# bounds=bounds,
|
522 |
+
# x0=initial_point,
|
523 |
+
# constraints=constraints,
|
524 |
+
# method=algo,
|
525 |
+
# options={"maxiter": int(2e7), "catol": 1},
|
526 |
+
# )
|
527 |
+
|
528 |
+
# for channel_name, modified_spends in zip(channels_list, res.x):
|
529 |
+
# self.update(channel_name, modified_spends)
|
530 |
+
|
531 |
+
# return zip(channels_list, res.x)
|
532 |
+
|
533 |
+
def optimize_spends(self, sales_percent, channels_list, algo="trust-constr"):
|
534 |
+
desired_sales = self.actual_total_sales * (1 + sales_percent / 100.0)
|
535 |
+
|
536 |
+
def constraint(x):
|
537 |
+
for ch, spends in zip(channels_list, x):
|
538 |
+
self.update(ch, spends)
|
539 |
+
return self.modified_total_sales - desired_sales
|
540 |
+
|
541 |
+
bounds = []
|
542 |
+
for ch in channels_list:
|
543 |
+
bounds.append(
|
544 |
+
(1 + np.array([-50.0, 100.0]) / 100.0)
|
545 |
+
* self.channels[ch].actual_total_spends
|
546 |
+
)
|
547 |
+
|
548 |
+
initial_point = []
|
549 |
+
for bound in bounds:
|
550 |
+
initial_point.append(bound[0])
|
551 |
+
|
552 |
+
power = np.ceil(np.log(sum(initial_point)) / np.log(10))
|
553 |
+
|
554 |
+
constraints = [NonlinearConstraint(constraint, -1.0, 1.0)]
|
555 |
+
|
556 |
+
res = minimize(
|
557 |
+
lambda x: sum(x) / 10 ** (power),
|
558 |
+
bounds=bounds,
|
559 |
+
x0=initial_point,
|
560 |
+
constraints=constraints,
|
561 |
+
method=algo,
|
562 |
+
options={"maxiter": int(2e7), "xtol": 100},
|
563 |
+
)
|
564 |
+
|
565 |
+
for channel_name, modified_spends in zip(channels_list, res.x):
|
566 |
+
self.update(channel_name, modified_spends)
|
567 |
+
|
568 |
+
return zip(channels_list, res.x)
|
569 |
+
|
570 |
+
def optimize(self, spends_percent, channels_list):
|
571 |
+
# channels_list = self.channels.keys()
|
572 |
+
num_channels = len(channels_list)
|
573 |
+
spends_constant = []
|
574 |
+
spends_constraint = 0.0
|
575 |
+
for channel_name in channels_list:
|
576 |
+
# spends_constraint += self.channels[channel_name].modified_total_spends
|
577 |
+
spends_constant.append(self.channels[channel_name].conversion_rate)
|
578 |
+
spends_constraint += (
|
579 |
+
self.channels[channel_name].actual_total_spends
|
580 |
+
* self.channels[channel_name].conversion_rate
|
581 |
+
)
|
582 |
+
spends_constraint = spends_constraint * (1 + spends_percent / 100)
|
583 |
+
# constraint= LinearConstraint(np.ones((num_channels,)), lb = spends_constraint, ub = spends_constraint)
|
584 |
+
constraint = LinearConstraint(
|
585 |
+
np.array(spends_constant),
|
586 |
+
lb=spends_constraint,
|
587 |
+
ub=spends_constraint,
|
588 |
+
)
|
589 |
+
bounds = []
|
590 |
+
old_spends = []
|
591 |
+
for channel_name in channels_list:
|
592 |
+
_channel_class = self.channels[channel_name]
|
593 |
+
channel_bounds = _channel_class.bounds
|
594 |
+
channel_actual_total_spends = _channel_class.actual_total_spends * (
|
595 |
+
(1 + spends_percent / 100)
|
596 |
+
)
|
597 |
+
old_spends.append(channel_actual_total_spends)
|
598 |
+
bounds.append((1 + channel_bounds / 100) * channel_actual_total_spends)
|
599 |
+
|
600 |
+
def objective_function(x):
|
601 |
+
for channel_name, modified_spends in zip(channels_list, x):
|
602 |
+
self.update(channel_name, modified_spends)
|
603 |
+
return -1 * self.modified_total_sales
|
604 |
+
|
605 |
+
power = np.ceil(np.log(self.modified_total_sales) / np.log(10))
|
606 |
+
|
607 |
+
res = minimize(
|
608 |
+
lambda x: objective_function(x) / 10 ** (power - 1),
|
609 |
+
method="trust-constr",
|
610 |
+
x0=old_spends,
|
611 |
+
constraints=constraint,
|
612 |
+
bounds=bounds,
|
613 |
+
options={"maxiter": int(1e7), "xtol": 0.1},
|
614 |
+
)
|
615 |
+
# res = dual_annealing(
|
616 |
+
# objective_function,
|
617 |
+
# x0=old_spends,
|
618 |
+
# mi
|
619 |
+
# constraints=constraint,
|
620 |
+
# bounds=bounds,
|
621 |
+
# tol=1e-16
|
622 |
+
|
623 |
+
for channel_name, modified_spends in zip(channels_list, res.x):
|
624 |
+
self.update(channel_name, modified_spends)
|
625 |
+
|
626 |
+
return zip(channels_list, res.x)
|
627 |
+
|
628 |
+
def save(self):
|
629 |
+
details = {}
|
630 |
+
actual_list = []
|
631 |
+
modified_list = []
|
632 |
+
data = {}
|
633 |
+
channel_data = []
|
634 |
+
|
635 |
+
summary_rows = []
|
636 |
+
actual_list.append(
|
637 |
+
{
|
638 |
+
"name": "Total",
|
639 |
+
"Spends": self.actual_total_spends,
|
640 |
+
"Sales": self.actual_total_sales,
|
641 |
+
}
|
642 |
+
)
|
643 |
+
modified_list.append(
|
644 |
+
{
|
645 |
+
"name": "Total",
|
646 |
+
"Spends": self.modified_total_spends,
|
647 |
+
"Sales": self.modified_total_sales,
|
648 |
+
}
|
649 |
+
)
|
650 |
+
for channel in self.channels.values():
|
651 |
+
name_mod = channel.name.replace("_", " ")
|
652 |
+
if name_mod.lower().endswith(" imp"):
|
653 |
+
name_mod = name_mod.replace("Imp", " Impressions")
|
654 |
+
summary_rows.append(
|
655 |
+
[
|
656 |
+
name_mod,
|
657 |
+
channel.actual_total_spends,
|
658 |
+
channel.modified_total_spends,
|
659 |
+
channel.actual_total_sales,
|
660 |
+
channel.modified_total_sales,
|
661 |
+
round(
|
662 |
+
channel.actual_total_sales / channel.actual_total_spends,
|
663 |
+
2,
|
664 |
+
),
|
665 |
+
round(
|
666 |
+
channel.modified_total_sales / channel.modified_total_spends,
|
667 |
+
2,
|
668 |
+
),
|
669 |
+
channel.get_marginal_roi("actual"),
|
670 |
+
channel.get_marginal_roi("modified"),
|
671 |
+
]
|
672 |
+
)
|
673 |
+
data[channel.name] = channel.modified_spends
|
674 |
+
data["Date"] = channel.dates
|
675 |
+
data["Sales"] = (
|
676 |
+
data.get("Sales", np.zeros((len(channel.dates),)))
|
677 |
+
+ channel.modified_sales
|
678 |
+
)
|
679 |
+
actual_list.append(
|
680 |
+
{
|
681 |
+
"name": channel.name,
|
682 |
+
"Spends": channel.actual_total_spends,
|
683 |
+
"Sales": channel.actual_total_sales,
|
684 |
+
"ROI": round(
|
685 |
+
channel.actual_total_sales / channel.actual_total_spends,
|
686 |
+
2,
|
687 |
+
),
|
688 |
+
}
|
689 |
+
)
|
690 |
+
modified_list.append(
|
691 |
+
{
|
692 |
+
"name": channel.name,
|
693 |
+
"Spends": channel.modified_total_spends,
|
694 |
+
"Sales": channel.modified_total_sales,
|
695 |
+
"ROI": round(
|
696 |
+
channel.modified_total_sales / channel.modified_total_spends,
|
697 |
+
2,
|
698 |
+
),
|
699 |
+
"Marginal ROI": channel.get_marginal_roi("modified"),
|
700 |
+
}
|
701 |
+
)
|
702 |
+
|
703 |
+
channel_data.append(
|
704 |
+
{
|
705 |
+
"channel": channel.name,
|
706 |
+
"spends_act": channel.actual_total_spends,
|
707 |
+
"spends_mod": channel.modified_total_spends,
|
708 |
+
"sales_act": channel.actual_total_sales,
|
709 |
+
"sales_mod": channel.modified_total_sales,
|
710 |
+
}
|
711 |
+
)
|
712 |
+
summary_rows.append(
|
713 |
+
[
|
714 |
+
"Total",
|
715 |
+
self.actual_total_spends,
|
716 |
+
self.modified_total_spends,
|
717 |
+
self.actual_total_sales,
|
718 |
+
self.modified_total_sales,
|
719 |
+
round(self.actual_total_sales / self.actual_total_spends, 2),
|
720 |
+
round(self.modified_total_sales / self.modified_total_spends, 2),
|
721 |
+
0.0,
|
722 |
+
0.0,
|
723 |
+
]
|
724 |
+
)
|
725 |
+
details["Actual"] = actual_list
|
726 |
+
details["Modified"] = modified_list
|
727 |
+
columns_index = pd.MultiIndex.from_product(
|
728 |
+
[[""], ["Channel"]], names=["first", "second"]
|
729 |
+
)
|
730 |
+
columns_index = columns_index.append(
|
731 |
+
pd.MultiIndex.from_product(
|
732 |
+
[["Spends", "NRPU", "ROI", "MROI"], ["Actual", "Simulated"]],
|
733 |
+
names=["first", "second"],
|
734 |
+
)
|
735 |
+
)
|
736 |
+
details["Summary"] = pd.DataFrame(summary_rows, columns=columns_index)
|
737 |
+
data_df = pd.DataFrame(data)
|
738 |
+
channel_list = list(self.channels.keys())
|
739 |
+
data_df = data_df[["Date", *channel_list, "Sales"]]
|
740 |
+
|
741 |
+
details["download"] = {
|
742 |
+
"data_df": data_df,
|
743 |
+
"channels_df": pd.DataFrame(channel_data),
|
744 |
+
"total_spends_act": self.actual_total_spends,
|
745 |
+
"total_sales_act": self.actual_total_sales,
|
746 |
+
"total_spends_mod": self.modified_total_spends,
|
747 |
+
"total_sales_mod": self.modified_total_sales,
|
748 |
+
}
|
749 |
+
|
750 |
+
return details
|
751 |
+
|
752 |
+
@classmethod
|
753 |
+
def from_dict(cls, attr_dict):
|
754 |
+
channels_list = attr_dict["channels"]
|
755 |
+
channels = {
|
756 |
+
channel["name"]: class_from_dict(channel) for channel in channels_list
|
757 |
+
}
|
758 |
+
return Scenario(
|
759 |
+
name=attr_dict["name"],
|
760 |
+
channels=channels,
|
761 |
+
constant=attr_dict["constant"],
|
762 |
+
correction=attr_dict["correction"],
|
763 |
+
)
|
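# Illustrative usage only (not part of the uploaded file): a minimal sketch of how the
# Channel / Scenario classes above could be exercised. The channel name, dates, spends,
# s-curve parameters and bounds below are made-up placeholder values, not project data.
if __name__ == "__main__":
    demo_dates = pd.date_range("2023-01-01", periods=52, freq="W")
    demo_spends = np.full(52, 1_000.0)
    demo_channel = Channel(
        name="paid_search",  # hypothetical channel name
        dates=demo_dates,
        spends=demo_spends,
        response_curve_type="s-curve",
        response_curve_params={"K": 5_000.0, "b": 1.0, "a": 2.0, "x0": 0.0},
        bounds=np.array([-10.0, 10.0]),  # allow spends to move +/- 10%
        correction=np.zeros(52),
    )
    demo_scenario = Scenario(
        name="demo_scenario",
        channels={demo_channel.name: demo_channel},
        constant=np.zeros(52),
        correction=np.zeros(52),
    )
    # Re-allocate the same total budget (0% change) across the listed channels,
    # then round-trip the result through the dict helpers used when scenarios are saved.
    demo_scenario.optimize(spends_percent=0, channels_list=["paid_search"])
    restored = class_from_dict(class_to_dict(demo_scenario))
    print(restored.modified_total_spends, restored.modified_total_sales)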
single_manifest.yml
ADDED
File without changes
|
styles.css
ADDED
@@ -0,0 +1,97 @@
html {
    margin: 0;
}


#MainMenu {
    visibility: collapse;
}

footer {
    visibility: collapse;
}

div.block-container{
    padding: 2rem 3rem;
}

div[data-testid="stExpander"]{
    border: 1px solid #739FAE;
}

a[data-testid="stPageLink-NavLink"]{
    border: 1px solid rgba(0, 110, 192, 0.2);
    margin: 0;
    padding: 0.25rem 0.5rem;
    border-radius: 0.5rem;
}


hr{
    margin: 0;
    padding: 0;
}

hr.spends-heading-seperator {
    background-color: #11B6BD;
    height: 2px;
}

.spends-header{
    font-size: 1rem;
    font-weight: bold;
    margin: 0;
}

td {
    max-width: 100px;
    /* white-space:nowrap; */
}

.main-header {
    display: flex;
    flex-direction: row;
    justify-content: space-between;
    align-items: center;
}

.blend-logo {
    max-height: 64px;
    /* max-width: 300px; */
    object-fit: cover;
}

table {
    width: 90%;
}

.lime-logo {
    margin: 0;
    padding: 0;
    display: flex;
    align-items: center;
    max-height: 64px;
}

.lime-text {
    color: #00EDED;
    font-size: 30px;
    margin: 0;
    padding: 0;
    line-height: 0%;
}

.lime-img {
    max-height: 30px;
    /* max-height: 64px; */
    /* max-width: 300px; */
    object-fit: cover;
}

img {
    margin: 0;
    padding: 0;
}
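# Illustrative usage only (not part of the uploaded file): one common way a Streamlit app
# applies a stylesheet like styles.css is to read the file and embed it with st.markdown
# and unsafe_allow_html. How the pages in this Space actually load it may differ.
import streamlit as st

def load_local_css(path: str = "styles.css") -> None:  # hypothetical helper
    with open(path) as css_file:
        st.markdown(f"<style>{css_file.read()}</style>", unsafe_allow_html=True)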
temp_stdout.txt
ADDED
File without changes
|
utilities.py
ADDED
@@ -0,0 +1,2155 @@
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
import json
|
4 |
+
from scenario import Channel, Scenario
|
5 |
+
import numpy as np
|
6 |
+
from plotly.subplots import make_subplots
|
7 |
+
import plotly.graph_objects as go
|
8 |
+
from scenario import class_to_dict
|
9 |
+
from collections import OrderedDict
|
10 |
+
import io
|
11 |
+
import plotly
|
12 |
+
from pathlib import Path
|
13 |
+
import pickle
|
14 |
+
import yaml
|
15 |
+
from yaml import SafeLoader
|
16 |
+
from streamlit.components.v1 import html
|
17 |
+
import smtplib
|
18 |
+
from scipy.optimize import curve_fit
|
19 |
+
from sklearn.metrics import r2_score
|
20 |
+
from scenario import class_from_dict, class_convert_to_dict
|
21 |
+
import os
|
22 |
+
import base64
|
23 |
+
import sqlite3
|
24 |
+
import datetime
|
25 |
+
from scenario import numerize
|
26 |
+
import psycopg2
|
27 |
+
|
28 |
+
#
|
29 |
+
import re
|
30 |
+
import bcrypt
|
31 |
+
import os
|
32 |
+
import json
|
33 |
+
import glob
|
34 |
+
import pickle
|
35 |
+
import streamlit as st
|
36 |
+
import streamlit as st
|
37 |
+
import pandas as pd
|
38 |
+
import json
|
39 |
+
from scenario import Channel, Scenario
|
40 |
+
import numpy as np
|
41 |
+
from plotly.subplots import make_subplots
|
42 |
+
import plotly.graph_objects as go
|
43 |
+
from scenario import class_to_dict
|
44 |
+
from collections import OrderedDict
|
45 |
+
import io
|
46 |
+
import plotly
|
47 |
+
from pathlib import Path
|
48 |
+
import pickle
|
49 |
+
import yaml
|
50 |
+
from yaml import SafeLoader
|
51 |
+
from streamlit.components.v1 import html
|
52 |
+
import smtplib
|
53 |
+
from scipy.optimize import curve_fit
|
54 |
+
from sklearn.metrics import r2_score
|
55 |
+
from scenario import class_from_dict, class_convert_to_dict
|
56 |
+
import os
|
57 |
+
import base64
|
58 |
+
import sqlite3
|
59 |
+
import datetime
|
60 |
+
from scenario import numerize
|
61 |
+
import sqlite3
|
62 |
+
|
63 |
+
# # schema = db_cred["schema"]
|
64 |
+
|
65 |
+
color_palette = [
|
66 |
+
"#F3F3F0",
|
67 |
+
"#5E7D7E",
|
68 |
+
"#2FA1FF",
|
69 |
+
"#00EDED",
|
70 |
+
"#00EAE4",
|
71 |
+
"#304550",
|
72 |
+
"#EDEBEB",
|
73 |
+
"#7FBEFD",
|
74 |
+
"#003059",
|
75 |
+
"#A2F3F3",
|
76 |
+
"#E1D6E2",
|
77 |
+
"#B6B6B6",
|
78 |
+
]
|
79 |
+
|
80 |
+
|
81 |
+
CURRENCY_INDICATOR = "$"
|
82 |
+
db_cred = None
|
83 |
+
# database_file = r"DB/User.db"
|
84 |
+
|
85 |
+
# conn = sqlite3.connect(database_file, check_same_thread=False) # connection with sql db
|
86 |
+
# c = conn.cursor()
|
87 |
+
|
88 |
+
|
89 |
+
# def query_excecuter_postgres(
|
90 |
+
# query,
|
91 |
+
# db_cred,
|
92 |
+
# params=None,
|
93 |
+
# insert=True,
|
94 |
+
# insert_retrieve=False,
|
95 |
+
# ):
|
96 |
+
# """
|
97 |
+
# Executes a SQL query on a PostgreSQL database, handling both insert and select operations.
|
98 |
+
|
99 |
+
# Parameters:
|
100 |
+
# query (str): The SQL query to be executed.
|
101 |
+
# params (tuple, optional): Parameters to pass into the SQL query for parameterized execution.
|
102 |
+
# insert (bool, default=True): Flag to determine if the query is an insert operation (default) or a select operation.
|
103 |
+
# insert_retrieve (bool, default=False): Flag to determine if the query should insert and then return the inserted ID.
|
104 |
+
|
105 |
+
# """
|
106 |
+
# # Database connection parameters
|
107 |
+
# dbname = db_cred["dbname"]
|
108 |
+
# user = db_cred["user"]
|
109 |
+
# password = db_cred["password"]
|
110 |
+
# host = db_cred["host"]
|
111 |
+
# port = db_cred["port"]
|
112 |
+
|
113 |
+
# try:
|
114 |
+
# # Establish connection to the PostgreSQL database
|
115 |
+
# conn = psycopg2.connect(
|
116 |
+
# dbname=dbname, user=user, password=password, host=host, port=port
|
117 |
+
# )
|
118 |
+
# except psycopg2.Error as e:
|
119 |
+
# st.warning(f"Unable to connect to the database: {e}")
|
120 |
+
# st.stop()
|
121 |
+
|
122 |
+
# # Create a cursor object to interact with the database
|
123 |
+
# c = conn.cursor()
|
124 |
+
|
125 |
+
# try:
|
126 |
+
# # Execute the query with or without parameters
|
127 |
+
# if params:
|
128 |
+
# c.execute(query, params)
|
129 |
+
# else:
|
130 |
+
# c.execute(query)
|
131 |
+
|
132 |
+
# if not insert:
|
133 |
+
# # If not an insert operation, fetch and return the results
|
134 |
+
# results = c.fetchall()
|
135 |
+
# return results
|
136 |
+
# elif insert_retrieve:
|
137 |
+
# # If insert and retrieve operation, fetch and return the results
|
138 |
+
# conn.commit()
|
139 |
+
# return c.fetchall()
|
140 |
+
# else:
|
141 |
+
# conn.commit()
|
142 |
+
|
143 |
+
# except Exception as e:
|
144 |
+
# st.write(f"Error executing query: {e}")
|
145 |
+
# finally:
|
146 |
+
# conn.close()
|
147 |
+
|
148 |
+
|
149 |
+
db_path = os.path.join("imp_db.db")
|
150 |
+
|
151 |
+
|
152 |
+
def query_excecuter_postgres(
|
153 |
+
query, db_path=None, params=None, insert=True, insert_retrieve=False, db_cred=None
|
154 |
+
):
|
155 |
+
"""
|
156 |
+
Executes a SQL query on a SQLite database, handling both insert and select operations.
|
157 |
+
|
158 |
+
Parameters:
|
159 |
+
query (str): The SQL query to be executed.
|
160 |
+
db_path (str, optional): Currently unused; the connection path is built internally as db/imp_db.db.
|
161 |
+
params (tuple, optional): Parameters to pass into the SQL query for parameterized execution.
|
162 |
+
insert (bool, default=True): Flag to determine if the query is an insert operation (default) or a select operation.
|
163 |
+
insert_retrieve (bool, default=False): Flag to determine if the query should insert and then return the inserted ID.
|
164 |
+
|
165 |
+
"""
|
166 |
+
try:
|
167 |
+
# Construct a cross-platform path to the database
|
168 |
+
db_dir = os.path.join("db")
|
169 |
+
os.makedirs(db_dir, exist_ok=True) # Make sure the directory exists
|
170 |
+
db_path = os.path.join(db_dir, "imp_db.db")
|
171 |
+
|
172 |
+
# Establish connection to the SQLite database
|
173 |
+
conn = sqlite3.connect(db_path)
|
174 |
+
except sqlite3.Error as e:
|
175 |
+
st.warning(f"Unable to connect to the SQLite database: {e}")
|
176 |
+
st.stop()
|
177 |
+
|
178 |
+
# Create a cursor object to interact with the database
|
179 |
+
c = conn.cursor()
|
180 |
+
|
181 |
+
# Prepare the query with proper placeholders
|
182 |
+
if params:
|
183 |
+
# Handle the `IN (?)` clause dynamically
|
184 |
+
query = query.replace("IN (?)", f"IN ({','.join(['?' for _ in params])})")
|
185 |
+
c.execute(query, params)
|
186 |
+
else:
|
187 |
+
c.execute(query)
|
188 |
+
|
189 |
+
try:
|
190 |
+
if not insert:
|
191 |
+
# If not an insert operation, fetch and return the results
|
192 |
+
results = c.fetchall()
|
193 |
+
return results
|
194 |
+
elif insert_retrieve:
|
195 |
+
# If insert and retrieve operation, commit and return the last inserted row ID
|
196 |
+
conn.commit()
|
197 |
+
return c.lastrowid
|
198 |
+
else:
|
199 |
+
# For standard insert operations, commit the transaction
|
200 |
+
conn.commit()
|
201 |
+
|
202 |
+
except Exception as e:
|
203 |
+
st.write(f"Error executing query: {e}")
|
204 |
+
finally:
|
205 |
+
conn.close()
|
206 |
+
|
207 |
+
|
208 |
+
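# Usage sketch (added for illustration; the employee id, project name, and emp_typ
# value are made up, and the insert assumes prj_id autoincrements with the other
# columns nullable). insert=False returns the fetched rows, insert_retrieve=True
# commits and returns the new row id, and the default insert=True just commits.
def _example_query_excecuter_usage():
    admins = query_excecuter_postgres(
        "SELECT emp_id, emp_nam FROM mmo_users WHERE emp_typ IN (?)",
        params=("admin",),
        insert=False,
    )
    new_prj_id = query_excecuter_postgres(
        "INSERT INTO mmo_projects (prj_nam, prj_ownr_id) VALUES (?, ?)",
        params=("demo_project", "e111111"),
        insert_retrieve=True,
    )
    return admins, new_prj_id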
def update_summary_df():
|
209 |
+
"""
|
210 |
+
Updates the 'project_summary_df' in the session state with the latest project
|
211 |
+
summary information based on the most recent updates.
|
212 |
+
|
213 |
+
This function executes a SQL query to retrieve project metadata from a database
|
214 |
+
and stores the result in the session state.
|
215 |
+
|
216 |
+
Uses:
|
217 |
+
- query_excecuter_postgres(query, params=params, insert=False): A function that
|
218 |
+
executes the provided SQL query on the SQLite project database.
|
219 |
+
|
220 |
+
Modifies:
|
221 |
+
- st.session_state['project_summary_df']: Updates the dataframe with columns:
|
222 |
+
'Project Number', 'Project Name', 'Last Modified Page', 'Last Modified Time'.
|
223 |
+
"""
|
224 |
+
|
225 |
+
query = f"""
|
226 |
+
WITH LatestUpdates AS (
|
227 |
+
SELECT
|
228 |
+
prj_id,
|
229 |
+
page_nam,
|
230 |
+
updt_dt_tm,
|
231 |
+
ROW_NUMBER() OVER (PARTITION BY prj_id ORDER BY updt_dt_tm DESC) AS rn
|
232 |
+
FROM
|
233 |
+
mmo_project_meta_data
|
234 |
+
)
|
235 |
+
SELECT
|
236 |
+
p.prj_id,
|
237 |
+
p.prj_nam AS prj_nam,
|
238 |
+
lu.page_nam,
|
239 |
+
lu.updt_dt_tm
|
240 |
+
FROM
|
241 |
+
LatestUpdates lu
|
242 |
+
RIGHT JOIN
|
243 |
+
mmo_projects p ON lu.prj_id = p.prj_id
|
244 |
+
WHERE
|
245 |
+
p.prj_ownr_id = ? AND lu.rn = 1
|
246 |
+
"""
|
247 |
+
|
248 |
+
params = (st.session_state["emp_id"],) # Parameters for the SQL query
|
249 |
+
|
250 |
+
# Execute the query and retrieve project summary data
|
251 |
+
project_summary = query_excecuter_postgres(
|
252 |
+
query, db_cred, params=params, insert=False
|
253 |
+
)
|
254 |
+
|
255 |
+
# Update the session state with the project summary dataframe
|
256 |
+
st.session_state["project_summary_df"] = pd.DataFrame(
|
257 |
+
project_summary,
|
258 |
+
columns=[
|
259 |
+
"Project Number",
|
260 |
+
"Project Name",
|
261 |
+
"Last Modified Page",
|
262 |
+
"Last Modified Time",
|
263 |
+
],
|
264 |
+
)
|
265 |
+
|
266 |
+
st.session_state["project_summary_df"] = st.session_state[
|
267 |
+
"project_summary_df"
|
268 |
+
].sort_values(by=["Last Modified Time"], ascending=False)
|
269 |
+
|
270 |
+
return st.session_state["project_summary_df"]
|
271 |
+
|
272 |
+
|
273 |
+
from constants import default_dct
|
274 |
+
|
275 |
+
|
276 |
+
def ensure_project_dct_structure(session_state, default_dct):
|
277 |
+
for key, value in default_dct.items():
|
278 |
+
if key not in session_state:
|
279 |
+
session_state[key] = value
|
280 |
+
elif isinstance(value, dict):
|
281 |
+
ensure_project_dct_structure(session_state[key], value)
|
282 |
+
|
283 |
+
|
284 |
+
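# Illustration of ensure_project_dct_structure (the keys below are simplified
# assumptions; the real defaults live in constants.default_dct). Missing keys are
# filled in recursively while values that already exist are left untouched.
def _example_ensure_structure():
    defaults = {"data_import": {"granularity": "weekly"}, "responses": {}}
    project_dct = {"data_import": {}}  # missing the nested key and "responses"
    ensure_project_dct_structure(project_dct, defaults)
    # project_dct is now {"data_import": {"granularity": "weekly"}, "responses": {}}
    return project_dct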
def project_selection():
|
285 |
+
|
286 |
+
emp_id = st.text_input("Employee ID", key="emp1111").lower()
|
287 |
+
password = st.text_input("Password", max_chars=15, type="password")
|
288 |
+
|
289 |
+
if st.button("Login"):
|
290 |
+
|
291 |
+
if "unique_ids" not in st.session_state:
|
292 |
+
unique_users_query = f"""
|
293 |
+
SELECT DISTINCT emp_id, emp_nam, emp_typ from mmo_users;
|
294 |
+
"""
|
295 |
+
unique_users_result = query_excecuter_postgres(
|
296 |
+
unique_users_query, db_cred, insert=False
|
297 |
+
) # retrieves all the users who have access to the MMO tool
|
298 |
+
st.session_state["unique_ids"] = {
|
299 |
+
emp_id: (emp_nam, emp_type)
|
300 |
+
for emp_id, emp_nam, emp_type in unique_users_result
|
301 |
+
}
|
302 |
+
|
303 |
+
if emp_id not in st.session_state["unique_ids"].keys() or len(password) == 0:
|
304 |
+
st.warning("invalid id or password!")
|
305 |
+
st.stop()
|
306 |
+
|
307 |
+
if not is_pswrd_flag_set(emp_id):
|
308 |
+
st.warning("Reset password in home page to continue")
|
309 |
+
st.stop()
|
310 |
+
|
311 |
+
elif not verify_password(emp_id, password):
|
312 |
+
st.warning("Invalid user name or password")
|
313 |
+
st.stop()
|
314 |
+
|
315 |
+
else:
|
316 |
+
st.session_state["emp_id"] = emp_id
|
317 |
+
st.session_state["username"] = st.session_state["unique_ids"][
|
318 |
+
st.session_state["emp_id"]
|
319 |
+
][0]
|
320 |
+
|
321 |
+
with st.spinner("Loading Saved Projects"):
|
322 |
+
st.session_state["project_summary_df"] = update_summary_df()
|
323 |
+
|
324 |
+
# st.write(st.session_state["project_name"][0])
|
325 |
+
if len(st.session_state["project_summary_df"]) == 0:
|
326 |
+
st.warning("No projects found please create a project in Home page")
|
327 |
+
st.stop()
|
328 |
+
|
329 |
+
else:
|
330 |
+
|
331 |
+
try:
|
332 |
+
st.session_state["project_name"] = (
|
333 |
+
st.session_state["project_summary_df"]
|
334 |
+
.loc[
|
335 |
+
st.session_state["project_summary_df"]["Project Number"]
|
336 |
+
== st.session_state["project_summary_df"].iloc[0, 0],
|
337 |
+
"Project Name",
|
338 |
+
]
|
339 |
+
.values[0]
|
340 |
+
) # fetching project name from project number stored in summary df
|
341 |
+
|
342 |
+
project_dct_query = f"""
|
343 |
+
|
344 |
+
SELECT pkl_obj FROM mmo_project_meta_data WHERE prj_id = ? AND file_nam=?;
|
345 |
+
|
346 |
+
"""
|
347 |
+
# Execute the query and retrieve the result
|
348 |
+
|
349 |
+
project_number = int(st.session_state["project_summary_df"].iloc[0, 0])
|
350 |
+
|
351 |
+
st.session_state["project_number"] = project_number
|
352 |
+
|
353 |
+
project_dct_retrieved = query_excecuter_postgres(
|
354 |
+
project_dct_query,
|
355 |
+
db_cred,
|
356 |
+
params=(project_number, "project_dct"),
|
357 |
+
insert=False,
|
358 |
+
)
|
359 |
+
# retrieves project dict (meta data) stored in db
|
360 |
+
|
361 |
+
st.session_state["project_dct"] = pickle.loads(
|
362 |
+
project_dct_retrieved[0][0]
|
363 |
+
) # converting bytes data back to the original object using pickle
|
364 |
+
ensure_project_dct_structure(
|
365 |
+
st.session_state["project_dct"], default_dct
|
366 |
+
)
|
367 |
+
|
368 |
+
st.success("Project Loded")
|
369 |
+
st.rerun()
|
370 |
+
|
371 |
+
except Exception as e:
|
372 |
+
|
373 |
+
st.write(
|
374 |
+
"Failed to load project meta data from db please create new project!"
|
375 |
+
)
|
376 |
+
st.stop()
|
377 |
+
|
378 |
+
|
379 |
+
def update_db(prj_id, page_nam, file_nam, pkl_obj, resp_mtrc="", schema=""):
|
380 |
+
|
381 |
+
# Check if an entry already exists
|
382 |
+
|
383 |
+
check_query = f"""
|
384 |
+
SELECT 1 FROM mmo_project_meta_data
|
385 |
+
WHERE prj_id = ? AND file_nam =?;
|
386 |
+
"""
|
387 |
+
|
388 |
+
check_params = (prj_id, file_nam)
|
389 |
+
result = query_excecuter_postgres(
|
390 |
+
check_query, db_cred, params=check_params, insert=False
|
391 |
+
)
|
392 |
+
|
393 |
+
# If entry exists, perform an update
|
394 |
+
if result is not None and result:
|
395 |
+
|
396 |
+
update_query = f"""
|
397 |
+
UPDATE mmo_project_meta_data
|
398 |
+
SET file_nam = ?, pkl_obj = ?, page_nam=? ,updt_dt_tm = datetime('now')
|
399 |
+
|
400 |
+
WHERE prj_id = ? AND file_nam = ?;
|
401 |
+
"""
|
402 |
+
|
403 |
+
update_params = (file_nam, pkl_obj, page_nam, prj_id, file_nam)
|
404 |
+
|
405 |
+
query_excecuter_postgres(
|
406 |
+
update_query, db_cred, params=update_params, insert=True
|
407 |
+
)
|
408 |
+
|
409 |
+
# If entry does not exist, perform an insert
|
410 |
+
else:
|
411 |
+
|
412 |
+
insert_query = f"""
|
413 |
+
INSERT INTO mmo_project_meta_data
|
414 |
+
(prj_id, page_nam, file_nam, pkl_obj,crte_by_uid, crte_dt_tm, updt_dt_tm)
|
415 |
+
VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'));
|
416 |
+
"""
|
417 |
+
|
418 |
+
insert_params = (
|
419 |
+
prj_id,
|
420 |
+
page_nam,
|
421 |
+
file_nam,
|
422 |
+
pkl_obj,
|
423 |
+
st.session_state["emp_id"],
|
424 |
+
)
|
425 |
+
|
426 |
+
query_excecuter_postgres(
|
427 |
+
insert_query, db_cred, params=insert_params, insert=True
|
428 |
+
)
|
429 |
+
|
430 |
+
# st.success(f"Inserted project meta data for project {prj_id}, page {page_nam}")
|
431 |
+
|
432 |
+
|
433 |
+
def retrieve_pkl_object(prj_id, page_nam, file_nam, schema=""):
|
434 |
+
|
435 |
+
query = f"""
|
436 |
+
SELECT pkl_obj FROM mmo_project_meta_data
|
437 |
+
WHERE prj_id = ? AND page_nam = ? AND file_nam = ?;
|
438 |
+
"""
|
439 |
+
|
440 |
+
params = (prj_id, page_nam, file_nam)
|
441 |
+
result = query_excecuter_postgres(
|
442 |
+
query, db_cred=db_cred, params=params, insert=False
|
443 |
+
)
|
444 |
+
|
445 |
+
if result and result[0] and result[0][0]:
|
446 |
+
pkl_obj = result[0][0]
|
447 |
+
# Deserialize the pickle object
|
448 |
+
return pickle.loads(pkl_obj)
|
449 |
+
else:
|
450 |
+
return None
|
451 |
+
|
452 |
+
|
453 |
+
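# Sketch of the intended round trip between update_db and retrieve_pkl_object
# (illustrative only; the project id, page name, and file name are hypothetical,
# and the insert path assumes st.session_state["emp_id"] is already set).
def _example_pkl_roundtrip():
    payload = pickle.dumps({"channels": ["Paid Search", "Programmatic"]})
    update_db(prj_id=1, page_nam="Scenario_Planner", file_nam="demo_obj", pkl_obj=payload)
    return retrieve_pkl_object(prj_id=1, page_nam="Scenario_Planner", file_nam="demo_obj")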
def validate_text(input_text):
|
454 |
+
|
455 |
+
# Check the length of the text
|
456 |
+
if len(input_text) < 2:
|
457 |
+
return False, "Input should be at least 2 characters long."
|
458 |
+
if len(input_text) > 30:
|
459 |
+
return False, "Input should not exceed 30 characters."
|
460 |
+
|
461 |
+
# Check if the text contains only allowed characters
|
462 |
+
if not re.match(r"^[A-Za-z0-9_]+$", input_text):
|
463 |
+
return (
|
464 |
+
False,
|
465 |
+
"Input contains invalid characters. Only letters, numbers and underscores are allowed.",
|
466 |
+
)
|
467 |
+
|
468 |
+
return True, "Input is valid."
|
469 |
+
|
470 |
+
|
471 |
+
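# Quick illustration of validate_text's contract (added for clarity; the strings
# are arbitrary examples). Each call returns (is_valid, message).
def _example_validate_text():
    ok, _ = validate_text("demo_project")      # letters, digits, underscores -> valid
    too_short, _ = validate_text("a")          # fewer than 2 characters -> invalid
    bad_chars, _ = validate_text("bad name!")  # space and "!" are rejected
    return ok, too_short, bad_chars            # (True, False, False)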
def delete_entries(prj_id, page_names, db_cred=None, schema=None):
|
472 |
+
"""
|
473 |
+
Deletes all entries from the project_meta_data table based on prj_id and a list of page names.
|
474 |
+
|
475 |
+
Parameters:
|
476 |
+
prj_id (int): The project ID.
|
477 |
+
page_names (list): A list of page names.
|
478 |
+
db_cred (dict, optional): Unused; retained for compatibility with the earlier PostgreSQL version.
|
479 |
+
schema (str): The schema name.
|
480 |
+
"""
|
481 |
+
# Create placeholders for each page name in the list
|
482 |
+
placeholders = ", ".join(["?"] * len(page_names))
|
483 |
+
query = f"""
|
484 |
+
DELETE FROM mmo_project_meta_data
|
485 |
+
WHERE prj_id = ? AND page_nam IN ({placeholders});
|
486 |
+
"""
|
487 |
+
|
488 |
+
# Combine prj_id and page_names into one list of parameters
|
489 |
+
params = (prj_id, *page_names)
|
490 |
+
|
491 |
+
query_excecuter_postgres(query, db_cred, params=params, insert=True)
|
492 |
+
|
493 |
+
|
494 |
+
# st.success(f"Deleted entries for project {prj_id}, page {page_name}")
|
495 |
+
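# Example (illustrative only; the project id and page names are assumptions):
# deleting the saved metadata for two pages builds one "?" per page, so the final
# statement reads "... WHERE prj_id = ? AND page_nam IN (?, ?)".
def _example_delete_entries():
    delete_entries(1, ["Data Import", "Data Assessment"])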
def store_hashed_password(
|
496 |
+
user_id,
|
497 |
+
plain_text_password,
|
498 |
+
):
|
499 |
+
"""
|
500 |
+
Hashes a plain text password using bcrypt, converts it to a UTF-8 string, and stores it as text.
|
501 |
+
|
502 |
+
Parameters:
|
503 |
+
plain_text_password (str): The plain text password to be hashed.
|
504 |
+
user_id (str): The employee id whose pswrd_key column is updated.
|
505 |
+
"""
|
506 |
+
# Hash the plain text password
|
507 |
+
hashed_password = bcrypt.hashpw(
|
508 |
+
plain_text_password.encode("utf-8"), bcrypt.gensalt()
|
509 |
+
)
|
510 |
+
|
511 |
+
# Convert the byte string to a regular string for storage
|
512 |
+
hashed_password_str = hashed_password.decode("utf-8")
|
513 |
+
|
514 |
+
# SQL query to update the pswrd_key for the specified user_id
|
515 |
+
query = f"""
|
516 |
+
UPDATE mmo_users
|
517 |
+
SET pswrd_key = ?
|
518 |
+
WHERE emp_id = ?;
|
519 |
+
"""
|
520 |
+
|
521 |
+
# Execute the query using the existing query_excecuter_postgres function
|
522 |
+
query_excecuter_postgres(
|
523 |
+
query=query, db_cred=db_cred, params=(hashed_password_str, user_id), insert=True
|
524 |
+
)
|
525 |
+
|
526 |
+
|
527 |
+
def verify_password(user_id, plain_text_password):
|
528 |
+
"""
|
529 |
+
Verifies the plain text password against the stored hashed password for the specified user_id.
|
530 |
+
|
531 |
+
Parameters:
|
532 |
+
user_id (str): The employee id of the user whose password is being verified.
|
533 |
+
plain_text_password (str): The plain text password to verify.
|
534 |
+
|
535 |
+
"""
|
536 |
+
# SQL query to retrieve the hashed password for the user_id
|
537 |
+
query = f"""
|
538 |
+
SELECT pswrd_key FROM mmo_users WHERE emp_id = ?;
|
539 |
+
"""
|
540 |
+
|
541 |
+
# Execute the query using the existing query_excecuter_postgres function
|
542 |
+
result = query_excecuter_postgres(
|
543 |
+
query=query, db_cred=db_cred, params=(user_id,), insert=False
|
544 |
+
)
|
545 |
+
|
546 |
+
if result:
|
547 |
+
|
548 |
+
stored_hashed_password_str = result[0][0]
|
549 |
+
# Convert the stored string back to bytes
|
550 |
+
stored_hashed_password = stored_hashed_password_str.encode("utf-8")
|
551 |
+
|
552 |
+
if bcrypt.checkpw(plain_text_password.encode("utf-8"), stored_hashed_password):
|
553 |
+
|
554 |
+
return True
|
555 |
+
else:
|
556 |
+
|
557 |
+
return False
|
558 |
+
else:
|
559 |
+
|
560 |
+
return False
|
561 |
+
|
562 |
+
|
563 |
+
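# Sketch of the intended password flow (the employee id and password are made up,
# and the row for that employee is assumed to already exist in mmo_users):
# store_hashed_password writes a bcrypt hash to pswrd_key, and verify_password
# checks a candidate password against that stored hash.
def _example_password_flow():
    store_hashed_password("e111111", "S3cret!pass")
    assert verify_password("e111111", "S3cret!pass") is True
    assert verify_password("e111111", "wrong-pass") is False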
def update_password_in_db(user_id, plain_text_password):
|
564 |
+
"""
|
565 |
+
Hashes the plain text password and updates the `pswrd_key`
|
566 |
+
column for the given `emp_id` in the `mmo_users` table.
|
567 |
+
|
568 |
+
Parameters:
|
569 |
+
user_id (str): The employee id of the user whose password needs to be updated.
|
570 |
+
plain_text_password (str): The plain text password to be hashed and stored.
|
571 |
+
|
572 |
+
"""
|
573 |
+
# Hash the plain text password using bcrypt
|
574 |
+
hashed_password = bcrypt.hashpw(
|
575 |
+
plain_text_password.encode("utf-8"), bcrypt.gensalt()
|
576 |
+
)
|
577 |
+
|
578 |
+
# Convert the hashed password from bytes to a string for storage
|
579 |
+
hashed_password_str = hashed_password.decode("utf-8")
|
580 |
+
|
581 |
+
# SQL query to update the password in the database
|
582 |
+
query = f"""
|
583 |
+
UPDATE mmo_users
|
584 |
+
SET pswrd_key = ?
|
585 |
+
WHERE emp_id = ?
|
586 |
+
"""
|
587 |
+
|
588 |
+
# Parameters for the query
|
589 |
+
params = (hashed_password_str, user_id)
|
590 |
+
|
591 |
+
# Execute the query using the query_excecuter_postgres function
|
592 |
+
query_excecuter_postgres(query, db_cred, params=params, insert=True)
|
593 |
+
|
594 |
+
|
595 |
+
def is_pswrd_flag_set(user_id):
|
596 |
+
query = f"""
|
597 |
+
SELECT pswrd_flag
|
598 |
+
FROM mmo_users
|
599 |
+
WHERE emp_id = ?;
|
600 |
+
"""
|
601 |
+
|
602 |
+
# Execute the query
|
603 |
+
result = query_excecuter_postgres(query, db_cred, params=(user_id,), insert=False)
|
604 |
+
|
605 |
+
# Return True if the flag is 1, otherwise return False
|
606 |
+
if result and result[0][0] == 1:
|
607 |
+
return True
|
608 |
+
else:
|
609 |
+
return False
|
610 |
+
|
611 |
+
|
612 |
+
def set_pswrd_flag(user_id):
|
613 |
+
query = f"""
|
614 |
+
UPDATE mmo_users
|
615 |
+
SET pswrd_flag = 1
|
616 |
+
WHERE emp_id = ?;
|
617 |
+
"""
|
618 |
+
|
619 |
+
# Execute the update query
|
620 |
+
query_excecuter_postgres(query, db_cred, params=(user_id,), insert=True)
|
621 |
+
|
622 |
+
|
623 |
+
def retrieve_pkl_object_without_warning(prj_id, page_nam, file_nam, schema):
|
624 |
+
|
625 |
+
query = f"""
|
626 |
+
SELECT pkl_obj FROM mmo_project_meta_data
|
627 |
+
WHERE prj_id = ? AND page_nam = ? AND file_nam = ?;
|
628 |
+
"""
|
629 |
+
|
630 |
+
params = (prj_id, page_nam, file_nam)
|
631 |
+
result = query_excecuter_postgres(
|
632 |
+
query, db_cred=db_cred, params=params, insert=False
|
633 |
+
)
|
634 |
+
|
635 |
+
if result and result[0] and result[0][0]:
|
636 |
+
pkl_obj = result[0][0]
|
637 |
+
# Deserialize the pickle object
|
638 |
+
return pickle.loads(pkl_obj)
|
639 |
+
else:
|
640 |
+
# st.warning(
|
641 |
+
# "Pickle object not found for the given project ID, page name, and file name."
|
642 |
+
# )
|
643 |
+
return None
|
644 |
+
|
645 |
+
|
646 |
+
color_palette = [
|
647 |
+
"#F3F3F0",
|
648 |
+
"#5E7D7E",
|
649 |
+
"#2FA1FF",
|
650 |
+
"#00EDED",
|
651 |
+
"#00EAE4",
|
652 |
+
"#304550",
|
653 |
+
"#EDEBEB",
|
654 |
+
"#7FBEFD",
|
655 |
+
"#003059",
|
656 |
+
"#A2F3F3",
|
657 |
+
"#E1D6E2",
|
658 |
+
"#B6B6B6",
|
659 |
+
]
|
660 |
+
|
661 |
+
|
662 |
+
CURRENCY_INDICATOR = "$"
|
663 |
+
|
664 |
+
|
665 |
+
# database_file = r"DB/User.db"
|
666 |
+
|
667 |
+
# conn = sqlite3.connect(database_file, check_same_thread=False) # connection with sql db
|
668 |
+
# c = conn.cursor()
|
669 |
+
|
670 |
+
|
671 |
+
# def load_authenticator():
|
672 |
+
# with open("config.yaml") as file:
|
673 |
+
# config = yaml.load(file, Loader=SafeLoader)
|
674 |
+
# st.session_state["config"] = config
|
675 |
+
# authenticator = stauth.Authenticate(
|
676 |
+
# credentials=config["credentials"],
|
677 |
+
# cookie_name=config["cookie"]["name"],
|
678 |
+
# key=config["cookie"]["key"],
|
679 |
+
# cookie_expiry_days=config["cookie"]["expiry_days"],
|
680 |
+
# preauthorized=config["preauthorized"],
|
681 |
+
# )
|
682 |
+
# st.session_state["authenticator"] = authenticator
|
683 |
+
# return authenticator
|
684 |
+
|
685 |
+
|
686 |
+
# Authentication
|
687 |
+
# def authenticator():
|
688 |
+
# for k, v in st.session_state.items():
|
689 |
+
# if k not in ["logout", "login", "config"] and not k.startswith(
|
690 |
+
# "FormSubmitter"
|
691 |
+
# ):
|
692 |
+
# st.session_state[k] = v
|
693 |
+
# with open("config.yaml") as file:
|
694 |
+
# config = yaml.load(file, Loader=SafeLoader)
|
695 |
+
# st.session_state["config"] = config
|
696 |
+
# authenticator = stauth.Authenticate(
|
697 |
+
# config["credentials"],
|
698 |
+
# config["cookie"]["name"],
|
699 |
+
# config["cookie"]["key"],
|
700 |
+
# config["cookie"]["expiry_days"],
|
701 |
+
# config["preauthorized"],
|
702 |
+
# )
|
703 |
+
# st.session_state["authenticator"] = authenticator
|
704 |
+
# name, authentication_status, username = authenticator.login(
|
705 |
+
# "Login", "main"
|
706 |
+
# )
|
707 |
+
# auth_status = st.session_state.get("authentication_status")
|
708 |
+
|
709 |
+
# if auth_status == True:
|
710 |
+
# authenticator.logout("Logout", "main")
|
711 |
+
# is_state_initiaized = st.session_state.get("initialized", False)
|
712 |
+
|
713 |
+
# if not is_state_initiaized:
|
714 |
+
|
715 |
+
# if "session_name" not in st.session_state:
|
716 |
+
# st.session_state["session_name"] = None
|
717 |
+
|
718 |
+
# return name
|
719 |
+
|
720 |
+
|
721 |
+
# def authentication():
|
722 |
+
# with open("config.yaml") as file:
|
723 |
+
# config = yaml.load(file, Loader=SafeLoader)
|
724 |
+
|
725 |
+
# authenticator = stauth.Authenticate(
|
726 |
+
# config["credentials"],
|
727 |
+
# config["cookie"]["name"],
|
728 |
+
# config["cookie"]["key"],
|
729 |
+
# config["cookie"]["expiry_days"],
|
730 |
+
# config["preauthorized"],
|
731 |
+
# )
|
732 |
+
|
733 |
+
# name, authentication_status, username = authenticator.login(
|
734 |
+
# "Login", "main"
|
735 |
+
# )
|
736 |
+
# return authenticator, name, authentication_status, username
|
737 |
+
|
738 |
+
|
739 |
+
def nav_page(page_name, timeout_secs=3):
|
740 |
+
nav_script = """
|
741 |
+
<script type="text/javascript">
|
742 |
+
function attempt_nav_page(page_name, start_time, timeout_secs) {
|
743 |
+
var links = window.parent.document.getElementsByTagName("a");
|
744 |
+
for (var i = 0; i < links.length; i++) {
|
745 |
+
if (links[i].href.toLowerCase().endsWith("/" + page_name.toLowerCase())) {
|
746 |
+
links[i].click();
|
747 |
+
return;
|
748 |
+
}
|
749 |
+
}
|
750 |
+
var elapsed = new Date() - start_time;
|
751 |
+
if (elapsed < timeout_secs * 1000) {
|
752 |
+
setTimeout(attempt_nav_page, 100, page_name, start_time, timeout_secs);
|
753 |
+
} else {
|
754 |
+
alert("Unable to navigate to page '" + page_name + "' after " + timeout_secs + " second(s).");
|
755 |
+
}
|
756 |
+
}
|
757 |
+
window.addEventListener("load", function() {
|
758 |
+
attempt_nav_page("%s", new Date(), %d);
|
759 |
+
});
|
760 |
+
</script>
|
761 |
+
""" % (
|
762 |
+
page_name,
|
763 |
+
timeout_secs,
|
764 |
+
)
|
765 |
+
html(nav_script)
|
766 |
+
|
767 |
+
|
768 |
+
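# Usage sketch (illustrative): the call below assumes the target page is reachable
# at a URL ending in /Data_Import, i.e. the pages/ file name without its numeric
# prefix and ".py", which is the usual Streamlit multipage convention.
def _example_nav_to_data_import():
    nav_page("Data_Import")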
# def load_local_css(file_name):
|
769 |
+
# with open(file_name) as f:
|
770 |
+
# st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
|
771 |
+
|
772 |
+
|
773 |
+
# def set_header():
|
774 |
+
# return st.markdown(f"""<div class='main-header'>
|
775 |
+
# <h1>MMM LiME</h1>
|
776 |
+
# <img src="https://assets-global.website-files.com/64c8fffb0e95cbc525815b79/64df84637f83a891c1473c51_Vector%20(Stroke).svg ">
|
777 |
+
# </div>""", unsafe_allow_html=True)
|
778 |
+
|
779 |
+
path = os.path.dirname(__file__)
|
780 |
+
|
781 |
+
file_ = open(f"{path}/logo.png", "rb")
|
782 |
+
|
783 |
+
contents = file_.read()
|
784 |
+
|
785 |
+
data_url = base64.b64encode(contents).decode("utf-8")
|
786 |
+
|
787 |
+
file_.close()
|
788 |
+
|
789 |
+
|
790 |
+
DATA_PATH = "./data"
|
791 |
+
|
792 |
+
IMAGES_PATH = "./data/images_224_224"
|
793 |
+
|
794 |
+
|
795 |
+
def load_local_css(file_name):
|
796 |
+
|
797 |
+
with open(file_name) as f:
|
798 |
+
|
799 |
+
st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
|
800 |
+
|
801 |
+
|
802 |
+
# def set_header():
|
803 |
+
|
804 |
+
# return st.markdown(f"""<div class='main-header'>
|
805 |
+
|
806 |
+
# <h1>H & M Recommendations</h1>
|
807 |
+
|
808 |
+
# <img src="data:image;base64,{data_url}", alt="Logo">
|
809 |
+
|
810 |
+
# </div>""", unsafe_allow_html=True)
|
811 |
+
path1 = os.path.dirname(__file__)
|
812 |
+
|
813 |
+
# file_1 = open(f"{path}/willbank.png", "rb")
|
814 |
+
|
815 |
+
# contents1 = file_1.read()
|
816 |
+
|
817 |
+
# data_url1 = base64.b64encode(contents1).decode("utf-8")
|
818 |
+
|
819 |
+
# file_1.close()
|
820 |
+
|
821 |
+
|
822 |
+
DATA_PATH1 = "./data"
|
823 |
+
|
824 |
+
IMAGES_PATH1 = "./data/images_224_224"
|
825 |
+
|
826 |
+
|
827 |
+
def set_header():
|
828 |
+
return st.markdown(
|
829 |
+
f"""<div class='main-header'>
|
830 |
+
<!-- <h1></h1> -->
|
831 |
+
<div >
|
832 |
+
<img class='blend-logo' src="data:image;base64,{data_url}", alt="Logo">
|
833 |
+
</div>""",
|
834 |
+
unsafe_allow_html=True,
|
835 |
+
)
|
836 |
+
|
837 |
+
|
838 |
+
# def set_header():
|
839 |
+
# logo_path = "./path/to/your/local/LIME_logo.png" # Replace with the actual file path
|
840 |
+
# text = "LiME"
|
841 |
+
# return st.markdown(f"""<div class='main-header'>
|
842 |
+
# <img src="data:image/png;base64,{data_url}" alt="Logo" style="float: left; margin-right: 10px; width: 100px; height: auto;">
|
843 |
+
# <h1>{text}</h1>
|
844 |
+
# </div>""", unsafe_allow_html=True)
|
845 |
+
|
846 |
+
|
847 |
+
def s_curve(x, K, b, a, x0):
|
848 |
+
return K / (1 + b * np.exp(-a * (x - x0)))
|
849 |
+
|
850 |
+
|
851 |
+
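# Quick numeric sanity check of the response-curve shape (parameter values are
# made up): at x = x0 the curve equals K / (1 + b), and for large x it saturates
# toward the ceiling K.
def _example_s_curve_values():
    K, b, a, x0 = 100.0, 1.0, 0.05, 50.0
    midpoint = s_curve(x0, K, b, a, x0)         # 100 / (1 + 1) = 50.0
    saturated = s_curve(10_000.0, K, b, a, x0)  # approaches K = 100.0
    return midpoint, saturated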
def panel_level(input_df, date_column="Date"):
|
852 |
+
# Ensure 'Date' is set as the index
|
853 |
+
if date_column not in input_df.index.names:
|
854 |
+
input_df = input_df.set_index(date_column)
|
855 |
+
|
856 |
+
# Select numeric columns only (excluding 'Date' since it's now the index)
|
857 |
+
numeric_columns_df = input_df.select_dtypes(include="number")
|
858 |
+
|
859 |
+
# Group by 'Date' (which is the index) and sum the numeric columns
|
860 |
+
aggregated_df = numeric_columns_df.groupby(input_df.index).sum()
|
861 |
+
|
862 |
+
# Reset the index to bring the 'Date' column
|
863 |
+
aggregated_df = aggregated_df.reset_index()
|
864 |
+
|
865 |
+
return aggregated_df
|
866 |
+
|
867 |
+
|
868 |
+
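# Small illustration of panel_level (data are made up): numeric columns are summed
# per date across panels, non-numeric columns such as "Panel" are dropped, and the
# date column is restored as a regular column.
def _example_panel_level():
    df = pd.DataFrame(
        {
            "Date": ["2024-01-01", "2024-01-01", "2024-01-08"],
            "Panel": ["North", "South", "North"],
            "Spend": [100.0, 150.0, 120.0],
        }
    )
    return panel_level(df, date_column="Date")
    # -> two rows: 2024-01-01 with Spend 250.0 and 2024-01-08 with Spend 120.0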
def fetch_actual_data(
|
869 |
+
panel=None,
|
870 |
+
target_file="Overview_data_test.xlsx",
|
871 |
+
updated_rcs=None,
|
872 |
+
metrics=None,
|
873 |
+
):
|
874 |
+
excel = pd.read_excel(Path(target_file), sheet_name=None)
|
875 |
+
|
876 |
+
# Extract dataframes for raw data, spend input, and contribution MMM
|
877 |
+
raw_df = excel["RAW DATA MMM"]
|
878 |
+
spend_df = excel["SPEND INPUT"]
|
879 |
+
contri_df = excel["CONTRIBUTION MMM"]
|
880 |
+
|
881 |
+
# Check if the panel is not None
|
882 |
+
if panel is not None and panel != "Aggregated":
|
883 |
+
raw_df = raw_df[raw_df["Panel"] == panel].drop(columns=["Panel"])
|
884 |
+
spend_df = spend_df[spend_df["Panel"] == panel].drop(columns=["Panel"])
|
885 |
+
contri_df = contri_df[contri_df["Panel"] == panel].drop(columns=["Panel"])
|
886 |
+
elif panel == "Aggregated":
|
887 |
+
raw_df = panel_level(raw_df, date_column="Date")
|
888 |
+
spend_df = panel_level(spend_df, date_column="Week")
|
889 |
+
contri_df = panel_level(contri_df, date_column="Date")
|
890 |
+
|
891 |
+
# Revenue_df = excel['Revenue']
|
892 |
+
|
893 |
+
## remove seasonalities, indices, etc.
|
894 |
+
unnamed_cols = [col for col in raw_df.columns if col.lower().startswith("unnamed")]
|
895 |
+
## remove seasonalities, indices, etc.
|
896 |
+
|
897 |
+
exclude_columns = [
|
898 |
+
"Date",
|
899 |
+
"Region",
|
900 |
+
"Controls_Grammarly_Index_SeasonalAVG",
|
901 |
+
"Controls_Quillbot_Index",
|
902 |
+
"Daily_Positive_Outliers",
|
903 |
+
"External_RemoteClass_Index",
|
904 |
+
"Intervals ON 20190520-20190805 | 20200518-20200803 | 20210517-20210802",
|
905 |
+
"Intervals ON 20190826-20191209 | 20200824-20201207 | 20210823-20211206",
|
906 |
+
"Intervals ON 20201005-20201019",
|
907 |
+
"Promotion_PercentOff",
|
908 |
+
"Promotion_TimeBased",
|
909 |
+
"Seasonality_Indicator_Chirstmas",
|
910 |
+
"Seasonality_Indicator_NewYears_Days",
|
911 |
+
"Seasonality_Indicator_Thanksgiving",
|
912 |
+
"Trend 20200302 / 20200803",
|
913 |
+
] + unnamed_cols
|
914 |
+
|
915 |
+
raw_df["Date"] = pd.to_datetime(raw_df["Date"])
|
916 |
+
contri_df["Date"] = pd.to_datetime(contri_df["Date"])
|
917 |
+
input_df = raw_df.sort_values(by="Date")
|
918 |
+
output_df = contri_df.sort_values(by="Date")
|
919 |
+
spend_df["Week"] = pd.to_datetime(
|
920 |
+
spend_df["Week"], format="%Y-%m-%d", errors="coerce"
|
921 |
+
)
|
922 |
+
spend_df.sort_values(by="Week", inplace=True)
|
923 |
+
|
924 |
+
# spend_df['Week'] = pd.to_datetime(spend_df['Week'], errors='coerce')
|
925 |
+
# spend_df = spend_df.sort_values(by='Week')
|
926 |
+
|
927 |
+
channel_list = [col for col in input_df.columns if col not in exclude_columns]
|
928 |
+
channel_list = list(set(channel_list) - set(["fb_level_achieved_tier_1", "ga_app"]))
|
929 |
+
|
930 |
+
infeasible_channels = [
|
931 |
+
c
|
932 |
+
for c in contri_df.select_dtypes(include=["float", "int"]).columns
|
933 |
+
if contri_df[c].sum() <= 0
|
934 |
+
]
|
935 |
+
# st.write(channel_list)
|
936 |
+
channel_list = list(set(channel_list) - set(infeasible_channels))
|
937 |
+
|
938 |
+
upper_limits = {}
|
939 |
+
output_cols = []
|
940 |
+
actual_output_dic = {}
|
941 |
+
actual_input_dic = {}
|
942 |
+
|
943 |
+
for inp_col in channel_list:
|
944 |
+
# st.write(inp_col)
|
945 |
+
spends = input_df[inp_col].values
|
946 |
+
x = spends.copy()
|
947 |
+
# upper limit for penalty
|
948 |
+
upper_limits[inp_col] = 2 * x.max()
|
949 |
+
|
950 |
+
# contribution
|
951 |
+
# out_col = [_col for _col in output_df.columns if _col.startswith(inp_col)][0]
|
952 |
+
out_col = inp_col
|
953 |
+
y = output_df[out_col].values.copy()
|
954 |
+
actual_output_dic[inp_col] = y.copy()
|
955 |
+
actual_input_dic[inp_col] = x.copy()
|
956 |
+
##output cols aggregation
|
957 |
+
output_cols.append(out_col)
|
958 |
+
|
959 |
+
return pd.DataFrame(actual_input_dic), pd.DataFrame(actual_output_dic)
|
960 |
+
|
961 |
+
|
962 |
+
# Function to initialize model results data
|
963 |
+
def initialize_data(panel=None, metrics=None):
|
964 |
+
# Extract dataframes for raw data, spend input, and contribution data
|
965 |
+
raw_df = st.session_state["project_dct"]["current_media_performance"][
|
966 |
+
"model_outputs"
|
967 |
+
][metrics]["raw_data"].copy()
|
968 |
+
spend_df = st.session_state["project_dct"]["current_media_performance"][
|
969 |
+
"model_outputs"
|
970 |
+
][metrics]["spends_data"].copy()
|
971 |
+
contribution_df = st.session_state["project_dct"]["current_media_performance"][
|
972 |
+
"model_outputs"
|
973 |
+
][metrics]["contribution_data"].copy()
|
974 |
+
|
975 |
+
# Check if 'Panel' or 'panel' is in the columns
|
976 |
+
panel_column = None
|
977 |
+
if "Panel" in raw_df.columns:
|
978 |
+
panel_column = "Panel"
|
979 |
+
elif "panel" in raw_df.columns:
|
980 |
+
panel_column = "panel"
|
981 |
+
|
982 |
+
# Filter data by panel if provided
|
983 |
+
if panel and panel.lower() != "aggregated":
|
984 |
+
raw_df = raw_df[raw_df[panel_column] == panel].drop(columns=[panel_column])
|
985 |
+
spend_df = spend_df[spend_df[panel_column] == panel].drop(
|
986 |
+
columns=[panel_column]
|
987 |
+
)
|
988 |
+
contribution_df = contribution_df[contribution_df[panel_column] == panel].drop(
|
989 |
+
columns=[panel_column]
|
990 |
+
)
|
991 |
+
else:
|
992 |
+
raw_df = panel_level(raw_df, date_column="Date")
|
993 |
+
spend_df = panel_level(spend_df, date_column="Date")
|
994 |
+
contribution_df = panel_level(contribution_df, date_column="Date")
|
995 |
+
|
996 |
+
# Remove unnecessary columns
|
997 |
+
unnamed_cols = [col for col in raw_df.columns if col.lower().startswith("unnamed")]
|
998 |
+
exclude_columns = ["Date"] + unnamed_cols
|
999 |
+
|
1000 |
+
# Convert Date columns to datetime
|
1001 |
+
for df in [raw_df, spend_df, contribution_df]:
|
1002 |
+
df["Date"] = pd.to_datetime(df["Date"], format="%Y-%m-%d", errors="coerce")
|
1003 |
+
|
1004 |
+
# Sort data by Date
|
1005 |
+
input_df = raw_df.sort_values(by="Date")
|
1006 |
+
contribution_df = contribution_df.sort_values(by="Date")
|
1007 |
+
spend_df.sort_values(by="Date", inplace=True)
|
1008 |
+
|
1009 |
+
# Extract channels excluding unwanted columns
|
1010 |
+
channel_list = [col for col in input_df.columns if col not in exclude_columns]
|
1011 |
+
|
1012 |
+
# Filter out channels with non-positive contributions
|
1013 |
+
negative_contributions = [
|
1014 |
+
col
|
1015 |
+
for col in contribution_df.select_dtypes(include=["float", "int"]).columns
|
1016 |
+
if contribution_df[col].sum() <= 0
|
1017 |
+
]
|
1018 |
+
channel_list = list(set(channel_list) - set(negative_contributions))
|
1019 |
+
|
1020 |
+
# Initialize dictionaries for metrics and response curves
|
1021 |
+
response_curves, mapes, rmses, upper_limits = {}, {}, {}, {}
|
1022 |
+
r2_scores, powers, conversion_rates, actual_output, actual_input = (
|
1023 |
+
{},
|
1024 |
+
{},
|
1025 |
+
{},
|
1026 |
+
{},
|
1027 |
+
{},
|
1028 |
+
)
|
1029 |
+
channels = {}
|
1030 |
+
sales = None
|
1031 |
+
dates = input_df["Date"].values
|
1032 |
+
|
1033 |
+
# Fit s-curve for each channel
|
1034 |
+
for channel in channel_list:
|
1035 |
+
spends = input_df[channel].values
|
1036 |
+
x = spends.copy()
|
1037 |
+
upper_limits[channel] = 2 * x.max()
|
1038 |
+
|
1039 |
+
# Get corresponding output column
|
1040 |
+
output_col = [
|
1041 |
+
_col for _col in contribution_df.columns if _col.startswith(channel)
|
1042 |
+
][0]
|
1043 |
+
y = contribution_df[output_col].values.copy()
|
1044 |
+
actual_output[channel] = y.copy()
|
1045 |
+
actual_input[channel] = x.copy()
|
1046 |
+
|
1047 |
+
# Scale input data
|
1048 |
+
power = np.ceil(np.log(x.max()) / np.log(10)) - 3
|
1049 |
+
if power >= 0:
|
1050 |
+
x = x / 10**power
|
1051 |
+
x, y = x.astype("float64"), y.astype("float64")
|
1052 |
+
|
1053 |
+
# Set bounds for curve fitting
|
1054 |
+
if y.max() <= 0.01:
|
1055 |
+
bounds = (
|
1056 |
+
(0, 0, 0, 0),
|
1057 |
+
(3 * 0.01, 1000, 1, x.max() if x.max() > 0 else 0.01),
|
1058 |
+
)
|
1059 |
+
else:
|
1060 |
+
bounds = ((0, 0, 0, 0), (3 * y.max(), 1000, 1, x.max()))
|
1061 |
+
|
1062 |
+
# Set y to 0 where x is 0
|
1063 |
+
y[x == 0] = 0
|
1064 |
+
|
1065 |
+
# Fit s-curve and calculate metrics
|
1066 |
+
# params, _ = curve_fit(
|
1067 |
+
# s_curve,
|
1068 |
+
# x
|
1069 |
+
# y,
|
1070 |
+
# p0=(2 * y.max(), 0.01, 1e-5, x.max()),
|
1071 |
+
# bounds=bounds,
|
1072 |
+
# maxfev=int(1e6),
|
1073 |
+
# )
|
1074 |
+
params, _ = curve_fit(
|
1075 |
+
s_curve,
|
1076 |
+
list(x) + [0] * len(x),
|
1077 |
+
list(y) + [0] * len(y),
|
1078 |
+
p0=(2 * y.max(), 0.01, 1e-5, x.max()),
|
1079 |
+
bounds=bounds,
|
1080 |
+
maxfev=int(1e6),
|
1081 |
+
)
|
1082 |
+
|
1083 |
+
mape = (100 * abs(1 - s_curve(x, *params) / y.clip(min=1))).mean()
|
1084 |
+
rmse = np.sqrt(((y - s_curve(x, *params)) ** 2).mean())
|
1085 |
+
r2_score_ = r2_score(y, s_curve(x, *params))
|
1086 |
+
|
1087 |
+
# Store metrics and parameters
|
1088 |
+
response_curves[channel] = {
|
1089 |
+
"K": params[0],
|
1090 |
+
"b": params[1],
|
1091 |
+
"a": params[2],
|
1092 |
+
"x0": params[3],
|
1093 |
+
}
|
1094 |
+
mapes[channel] = mape
|
1095 |
+
rmses[channel] = rmse
|
1096 |
+
r2_scores[channel] = r2_score_
|
1097 |
+
powers[channel] = power
|
1098 |
+
|
1099 |
+
conversion_rate = spend_df[channel].sum() / max(input_df[channel].sum(), 1e-9)
|
1100 |
+
conversion_rates[channel] = conversion_rate
|
1101 |
+
correction = y - s_curve(x, *params)
|
1102 |
+
|
1103 |
+
# Initialize Channel object
|
1104 |
+
channel_obj = Channel(
|
1105 |
+
name=channel,
|
1106 |
+
dates=dates,
|
1107 |
+
spends=spends,
|
1108 |
+
conversion_rate=conversion_rate,
|
1109 |
+
response_curve_type="s-curve",
|
1110 |
+
response_curve_params={
|
1111 |
+
"K": params[0],
|
1112 |
+
"b": params[1],
|
1113 |
+
"a": params[2],
|
1114 |
+
"x0": params[3],
|
1115 |
+
},
|
1116 |
+
bounds=np.array([-10, 10]),
|
1117 |
+
correction=correction,
|
1118 |
+
)
|
1119 |
+
channels[channel] = channel_obj
|
1120 |
+
if sales is None:
|
1121 |
+
sales = channel_obj.actual_sales
|
1122 |
+
else:
|
1123 |
+
sales += channel_obj.actual_sales
|
1124 |
+
|
1125 |
+
# Calculate other contributions
|
1126 |
+
other_contributions = (
|
1127 |
+
contribution_df.drop(columns=[*response_curves.keys()])
|
1128 |
+
.sum(axis=1, numeric_only=True)
|
1129 |
+
.values
|
1130 |
+
)
|
1131 |
+
|
1132 |
+
# Initialize Scenario object
|
1133 |
+
scenario = Scenario(
|
1134 |
+
name="default",
|
1135 |
+
channels=channels,
|
1136 |
+
constant=other_contributions,
|
1137 |
+
correction=np.array([]),
|
1138 |
+
)
|
1139 |
+
|
1140 |
+
# Set session state variables
|
1141 |
+
st.session_state.update(
|
1142 |
+
{
|
1143 |
+
"initialized": True,
|
1144 |
+
"actual_df": input_df,
|
1145 |
+
"raw_df": raw_df,
|
1146 |
+
"contri_df": contribution_df,
|
1147 |
+
"default_scenario_dict": class_to_dict(scenario),
|
1148 |
+
"scenario": scenario,
|
1149 |
+
"channels_list": channel_list,
|
1150 |
+
"optimization_channels": {
|
1151 |
+
channel_name: False for channel_name in channel_list
|
1152 |
+
},
|
1153 |
+
"rcs": response_curves.copy(),
|
1154 |
+
"powers": powers,
|
1155 |
+
"actual_contribution_df": pd.DataFrame(actual_output),
|
1156 |
+
"actual_input_df": pd.DataFrame(actual_input),
|
1157 |
+
"xlsx_buffer": io.BytesIO(),
|
1158 |
+
"saved_scenarios": (
|
1159 |
+
pickle.load(open("../saved_scenarios.pkl", "rb"))
|
1160 |
+
if Path("../saved_scenarios.pkl").exists()
|
1161 |
+
else OrderedDict()
|
1162 |
+
),
|
1163 |
+
"disable_download_button": True,
|
1164 |
+
}
|
1165 |
+
)
|
1166 |
+
|
1167 |
+
for channel in channels.values():
|
1168 |
+
st.session_state[channel.name] = numerize(
|
1169 |
+
channel.actual_total_spends * channel.conversion_rate, 1
|
1170 |
+
)
|
1171 |
+
|
1172 |
+
# Prepare response curve data for output
|
1173 |
+
response_curve_data = {}
|
1174 |
+
for channel, params in st.session_state["rcs"].items():
|
1175 |
+
x = st.session_state["actual_input_df"][channel].values.astype(float)
|
1176 |
+
y = st.session_state["actual_contribution_df"][channel].values.astype(float)
|
1177 |
+
power = float(np.ceil(np.log(max(x)) / np.log(10)) - 3)
|
1178 |
+
x_plot = list(np.linspace(0, 5 * max(x), 100))
|
1179 |
+
|
1180 |
+
response_curve_data[channel] = {
|
1181 |
+
"K": float(params["K"]),
|
1182 |
+
"b": float(params["b"]),
|
1183 |
+
"a": float(params["a"]),
|
1184 |
+
"x0": float(params["x0"]),
|
1185 |
+
"power": power,
|
1186 |
+
"x": list(x),
|
1187 |
+
"y": list(y),
|
1188 |
+
"x_plot": x_plot,
|
1189 |
+
}
|
1190 |
+
|
1191 |
+
return response_curve_data, scenario
|
1192 |
+
|
1193 |
+
|
1194 |
+
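# Illustration of the spend-scaling step used in initialize_data (values are made
# up): spends are divided by 10**power so curve_fit works on numbers of manageable
# magnitude, and power is kept in st.session_state["powers"] so the original scale
# can be recovered later.
def _example_power_scaling():
    x = np.array([1_200.0, 55_000.0, 480_000.0])
    power = np.ceil(np.log(x.max()) / np.log(10)) - 3  # ceil(log10(480000)) - 3 = 3
    x_scaled = x / 10**power                            # array([1.2, 55.0, 480.0])
    return power, x_scaled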
# def initialize_data(panel=None, metrics=None):
|
1195 |
+
# # Extract dataframes for raw data, spend input, and contribution data
|
1196 |
+
# raw_df = st.session_state["project_dct"]["current_media_performance"][
|
1197 |
+
# "model_outputs"
|
1198 |
+
# ][metrics]["raw_data"]
|
1199 |
+
# spend_df = st.session_state["project_dct"]["current_media_performance"][
|
1200 |
+
# "model_outputs"
|
1201 |
+
# ][metrics]["spends_data"]
|
1202 |
+
# contri_df = st.session_state["project_dct"]["current_media_performance"][
|
1203 |
+
# "model_outputs"
|
1204 |
+
# ][metrics]["contribution_data"]
|
1205 |
+
|
1206 |
+
# # Check if the panel is not None
|
1207 |
+
# if panel is not None and panel.lower() != "aggregated":
|
1208 |
+
# raw_df = raw_df[raw_df["Panel"] == panel].drop(columns=["Panel"])
|
1209 |
+
# spend_df = spend_df[spend_df["Panel"] == panel].drop(columns=["Panel"])
|
1210 |
+
# contri_df = contri_df[contri_df["Panel"] == panel].drop(columns=["Panel"])
|
1211 |
+
# elif panel.lower() == "aggregated":
|
1212 |
+
# raw_df = panel_level(raw_df, date_column="Date")
|
1213 |
+
# spend_df = panel_level(spend_df, date_column="Date")
|
1214 |
+
# contri_df = panel_level(contri_df, date_column="Date")
|
1215 |
+
|
1216 |
+
# ## remove seasonalities, indices, etc.
|
1217 |
+
# unnamed_cols = [col for col in raw_df.columns if col.lower().startswith("unnamed")]
|
1218 |
+
|
1219 |
+
# ## remove seasonalities, indices, etc.
|
1220 |
+
# exclude_columns = ["Date"] + unnamed_cols
|
1221 |
+
|
1222 |
+
# raw_df["Date"] = pd.to_datetime(raw_df["Date"], format="%Y-%m-%d", errors="coerce")
|
1223 |
+
# contri_df["Date"] = pd.to_datetime(
|
1224 |
+
# contri_df["Date"], format="%Y-%m-%d", errors="coerce"
|
1225 |
+
# )
|
1226 |
+
# spend_df["Date"] = pd.to_datetime(
|
1227 |
+
# spend_df["Date"], format="%Y-%m-%d", errors="coerce"
|
1228 |
+
# )
|
1229 |
+
|
1230 |
+
# input_df = raw_df.sort_values(by="Date")
|
1231 |
+
# output_df = contri_df.sort_values(by="Date")
|
1232 |
+
# spend_df.sort_values(by="Date", inplace=True)
|
1233 |
+
|
1234 |
+
# channel_list = [col for col in input_df.columns if col not in exclude_columns]
|
1235 |
+
|
1236 |
+
# negative_contribution = [
|
1237 |
+
# c
|
1238 |
+
# for c in contri_df.select_dtypes(include=["float", "int"]).columns
|
1239 |
+
# if contri_df[c].sum() <= 0
|
1240 |
+
# ]
|
1241 |
+
# channel_list = list(set(channel_list) - set(negative_contribution))
|
1242 |
+
|
1243 |
+
# response_curves = {}
|
1244 |
+
# mapes = {}
|
1245 |
+
# rmses = {}
|
1246 |
+
# upper_limits = {}
|
1247 |
+
# powers = {}
|
1248 |
+
# r2 = {}
|
1249 |
+
# conv_rates = {}
|
1250 |
+
# output_cols = []
|
1251 |
+
# channels = {}
|
1252 |
+
# sales = None
|
1253 |
+
# dates = input_df.Date.values
|
1254 |
+
# actual_output_dic = {}
|
1255 |
+
# actual_input_dic = {}
|
1256 |
+
|
1257 |
+
# for inp_col in channel_list:
|
1258 |
+
# spends = input_df[inp_col].values
|
1259 |
+
# x = spends.copy()
|
1260 |
+
# # upper limit for penalty
|
1261 |
+
# upper_limits[inp_col] = 2 * x.max()
|
1262 |
+
|
1263 |
+
# # contribution
|
1264 |
+
# out_col = [_col for _col in output_df.columns if _col.startswith(inp_col)][0]
|
1265 |
+
# y = output_df[out_col].values.copy()
|
1266 |
+
# actual_output_dic[inp_col] = y.copy()
|
1267 |
+
# actual_input_dic[inp_col] = x.copy()
|
1268 |
+
# ##output cols aggregation
|
1269 |
+
# output_cols.append(out_col)
|
1270 |
+
|
1271 |
+
# ## scale the input
|
1272 |
+
# power = np.ceil(np.log(x.max()) / np.log(10)) - 3
|
1273 |
+
# if power >= 0:
|
1274 |
+
# x = x / 10**power
|
1275 |
+
|
1276 |
+
# x = x.astype("float64")
|
1277 |
+
# y = y.astype("float64")
|
1278 |
+
|
1279 |
+
# if y.max() <= 0.01:
|
1280 |
+
# if x.max() <= 0.0:
|
1281 |
+
# bounds = ((0, 0, 0, 0), (3 * 0.01, 1000, 1, 0.01))
|
1282 |
+
|
1283 |
+
# else:
|
1284 |
+
# bounds = ((0, 0, 0, 0), (3 * 0.01, 1000, 1, x.max()))
|
1285 |
+
# else:
|
1286 |
+
# bounds = ((0, 0, 0, 0), (3 * y.max(), 1000, 1, x.max()))
|
1287 |
+
|
1288 |
+
# params, _ = curve_fit(
|
1289 |
+
# s_curve,
|
1290 |
+
# x,
|
1291 |
+
# y,
|
1292 |
+
# p0=(2 * y.max(), 0.01, 1e-5, x.max()),
|
1293 |
+
# bounds=bounds,
|
1294 |
+
# maxfev=int(1e5),
|
1295 |
+
# )
|
1296 |
+
# mape = (100 * abs(1 - s_curve(x, *params) / y.clip(min=1))).mean()
|
1297 |
+
# rmse = np.sqrt(((y - s_curve(x, *params)) ** 2).mean())
|
1298 |
+
# r2_ = r2_score(y, s_curve(x, *params))
|
1299 |
+
|
1300 |
+
# response_curves[inp_col] = {
|
1301 |
+
# "K": params[0],
|
1302 |
+
# "b": params[1],
|
1303 |
+
# "a": params[2],
|
1304 |
+
# "x0": params[3],
|
1305 |
+
# }
|
1306 |
+
|
1307 |
+
# mapes[inp_col] = mape
|
1308 |
+
# rmses[inp_col] = rmse
|
1309 |
+
# r2[inp_col] = r2_
|
1310 |
+
# powers[inp_col] = power
|
1311 |
+
|
1312 |
+
# conv = spend_df[inp_col].sum() / max(input_df[inp_col].sum(), 1e-9)
|
1313 |
+
# conv_rates[inp_col] = conv
|
1314 |
+
|
1315 |
+
# correction = y - s_curve(x, *params)
|
1316 |
+
|
1317 |
+
# channel = Channel(
|
1318 |
+
# name=inp_col,
|
1319 |
+
# dates=dates,
|
1320 |
+
# spends=spends,
|
1321 |
+
# conversion_rate=conv_rates[inp_col],
|
1322 |
+
# response_curve_type="s-curve",
|
1323 |
+
# response_curve_params={
|
1324 |
+
# "K": params[0],
|
1325 |
+
# "b": params[1],
|
1326 |
+
# "a": params[2],
|
1327 |
+
# "x0": params[3],
|
1328 |
+
# },
|
1329 |
+
# bounds=np.array([-10, 10]),
|
1330 |
+
# correction=correction,
|
1331 |
+
# )
|
1332 |
+
|
1333 |
+
# channels[inp_col] = channel
|
1334 |
+
# if sales is None:
|
1335 |
+
# sales = channel.actual_sales
|
1336 |
+
# else:
|
1337 |
+
# sales += channel.actual_sales
|
1338 |
+
|
1339 |
+
# other_contributions = (
|
1340 |
+
# output_df.drop([*output_cols], axis=1).sum(axis=1, numeric_only=True).values
|
1341 |
+
# )
|
1342 |
+
|
1343 |
+
# scenario = Scenario(
|
1344 |
+
# name="default",
|
1345 |
+
# channels=channels,
|
1346 |
+
# constant=other_contributions,
|
1347 |
+
# correction=np.array([]),
|
1348 |
+
# )
|
1349 |
+
|
1350 |
+
# ## setting session variables
|
1351 |
+
# st.session_state["initialized"] = True
|
1352 |
+
# st.session_state["actual_df"] = input_df
|
1353 |
+
# st.session_state["raw_df"] = raw_df
|
1354 |
+
# st.session_state["contri_df"] = output_df
|
1355 |
+
# default_scenario_dict = class_to_dict(scenario)
|
1356 |
+
# st.session_state["default_scenario_dict"] = default_scenario_dict
|
1357 |
+
# st.session_state["scenario"] = scenario
|
1358 |
+
# st.session_state["channels_list"] = channel_list
|
1359 |
+
# st.session_state["optimization_channels"] = {
|
1360 |
+
# channel_name: False for channel_name in channel_list
|
1361 |
+
# }
|
1362 |
+
# st.session_state["rcs"] = response_curves.copy()
|
1363 |
+
|
1364 |
+
# st.session_state["powers"] = powers
|
1365 |
+
# st.session_state["actual_contribution_df"] = pd.DataFrame(actual_output_dic)
|
1366 |
+
# st.session_state["actual_input_df"] = pd.DataFrame(actual_input_dic)
|
1367 |
+
|
1368 |
+
# for channel in channels.values():
|
1369 |
+
# st.session_state[channel.name] = numerize(
|
1370 |
+
# channel.actual_total_spends * channel.conversion_rate, 1
|
1371 |
+
# )
|
1372 |
+
|
1373 |
+
# st.session_state["xlsx_buffer"] = io.BytesIO()
|
1374 |
+
|
1375 |
+
# if Path("../saved_scenarios.pkl").exists():
|
1376 |
+
# with open("../saved_scenarios.pkl", "rb") as f:
|
1377 |
+
# st.session_state["saved_scenarios"] = pickle.load(f)
|
1378 |
+
# else:
|
1379 |
+
# st.session_state["saved_scenarios"] = OrderedDict()
|
1380 |
+
|
1381 |
+
# # st.session_state["total_spends_change"] = 0
|
1382 |
+
# st.session_state["optimization_channels"] = {
|
1383 |
+
# channel_name: False for channel_name in channel_list
|
1384 |
+
# }
|
1385 |
+
# st.session_state["disable_download_button"] = True
|
1386 |
+
|
1387 |
+
# rcs_data = {}
|
1388 |
+
# for channel in st.session_state["rcs"]:
|
1389 |
+
# # Convert to native Python lists and types
|
1390 |
+
# x = list(st.session_state["actual_input_df"][channel].values.astype(float))
|
1391 |
+
# y = list(
|
1392 |
+
# st.session_state["actual_contribution_df"][channel].values.astype(float)
|
1393 |
+
# )
|
1394 |
+
# power = float(np.ceil(np.log(max(x)) / np.log(10)) - 3)
|
1395 |
+
# x_plot = list(np.linspace(0, 5 * max(x), 100))
|
1396 |
+
|
1397 |
+
# rcs_data[channel] = {
|
1398 |
+
# "K": float(st.session_state["rcs"][channel]["K"]),
|
1399 |
+
# "b": float(st.session_state["rcs"][channel]["b"]),
|
1400 |
+
# "a": float(st.session_state["rcs"][channel]["a"]),
|
1401 |
+
# "x0": float(st.session_state["rcs"][channel]["x0"]),
|
1402 |
+
# "power": power,
|
1403 |
+
# "x": x,
|
1404 |
+
# "y": y,
|
1405 |
+
# "x_plot": x_plot,
|
1406 |
+
# }
|
1407 |
+
|
1408 |
+
# return rcs_data, scenario
|
1409 |
+
|
1410 |
+
|
1411 |
+
# def initialize_data():
|
1412 |
+
# # fetch data from excel
|
1413 |
+
# output = pd.read_excel('data.xlsx',sheet_name=None)
|
1414 |
+
# raw_df = output['RAW DATA MMM']
|
1415 |
+
# contribution_df = output['CONTRIBUTION MMM']
|
1416 |
+
# Revenue_df = output['Revenue']
|
1417 |
+
|
1418 |
+
# ## channels to be shows
|
1419 |
+
# channel_list = []
|
1420 |
+
# for col in raw_df.columns:
|
1421 |
+
# if 'click' in col.lower() or 'spend' in col.lower() or 'imp' in col.lower():
|
1422 |
+
# channel_list.append(col)
|
1423 |
+
# else:
|
1424 |
+
# pass
|
1425 |
+
|
1426 |
+
# ## NOTE : Considered only Desktop spends for all calculations
|
1427 |
+
# acutal_df = raw_df[raw_df.Region == 'Desktop'].copy()
|
1428 |
+
# ## NOTE : Considered one year of data
|
1429 |
+
# acutal_df = acutal_df[acutal_df.Date>'2020-12-31']
|
1430 |
+
# actual_df = acutal_df.drop('Region',axis=1).sort_values(by='Date')[[*channel_list,'Date']]
|
1431 |
+
|
1432 |
+
# ##load response curves
|
1433 |
+
# with open('./grammarly_response_curves.json','r') as f:
|
1434 |
+
# response_curves = json.load(f)
|
1435 |
+
|
1436 |
+
# ## create channel dict for scenario creation
|
1437 |
+
# dates = actual_df.Date.values
|
1438 |
+
# channels = {}
|
1439 |
+
# rcs = {}
|
1440 |
+
# constant = 0.
|
1441 |
+
# for i,info_dict in enumerate(response_curves):
|
1442 |
+
# name = info_dict.get('name')
|
1443 |
+
# response_curve_type = info_dict.get('response_curve')
|
1444 |
+
# response_curve_params = info_dict.get('params')
|
1445 |
+
# rcs[name] = response_curve_params
|
1446 |
+
# if name != 'constant':
|
1447 |
+
# spends = actual_df[name].values
|
1448 |
+
# channel = Channel(name=name,dates=dates,
|
1449 |
+
# spends=spends,
|
1450 |
+
# response_curve_type=response_curve_type,
|
1451 |
+
# response_curve_params=response_curve_params,
|
1452 |
+
# bounds=np.array([-30,30]))
|
1453 |
+
|
1454 |
+
# channels[name] = channel
|
1455 |
+
# else:
|
1456 |
+
# constant = info_dict.get('value',0.) * len(dates)
|
1457 |
+
|
1458 |
+
# ## create scenario
|
1459 |
+
# scenario = Scenario(name='default', channels=channels, constant=constant)
|
1460 |
+
# default_scenario_dict = class_to_dict(scenario)
|
1461 |
+
|
1462 |
+
|
1463 |
+
# ## setting session variables
|
1464 |
+
# st.session_state['initialized'] = True
|
1465 |
+
# st.session_state['actual_df'] = actual_df
|
1466 |
+
# st.session_state['raw_df'] = raw_df
|
1467 |
+
# st.session_state['default_scenario_dict'] = default_scenario_dict
|
1468 |
+
# st.session_state['scenario'] = scenario
|
1469 |
+
# st.session_state['channels_list'] = channel_list
|
1470 |
+
# st.session_state['optimization_channels'] = {channel_name : False for channel_name in channel_list}
|
1471 |
+
# st.session_state['rcs'] = rcs
|
1472 |
+
# for channel in channels.values():
|
1473 |
+
# if channel.name not in st.session_state:
|
1474 |
+
# st.session_state[channel.name] = float(channel.actual_total_spends)
|
1475 |
+
|
1476 |
+
# if 'xlsx_buffer' not in st.session_state:
|
1477 |
+
# st.session_state['xlsx_buffer'] = io.BytesIO()
|
1478 |
+
|
1479 |
+
# ## for saving scenarios
|
1480 |
+
# if 'saved_scenarios' not in st.session_state:
|
1481 |
+
# if Path('../saved_scenarios.pkl').exists():
|
1482 |
+
# with open('../saved_scenarios.pkl','rb') as f:
|
1483 |
+
# st.session_state['saved_scenarios'] = pickle.load(f)
|
1484 |
+
|
1485 |
+
# else:
|
1486 |
+
# st.session_state['saved_scenarios'] = OrderedDict()
|
1487 |
+
|
1488 |
+
# if 'total_spends_change' not in st.session_state:
|
1489 |
+
# st.session_state['total_spends_change'] = 0
|
1490 |
+
|
1491 |
+
# if 'optimization_channels' not in st.session_state:
|
1492 |
+
# st.session_state['optimization_channels'] = {channel_name : False for channel_name in channel_list}
|
1493 |
+
|
1494 |
+
# if 'disable_download_button' not in st.session_state:
|
1495 |
+
# st.session_state['disable_download_button'] = True
|
1496 |
+
|
1497 |
+
|
1498 |
+
def create_channel_summary(scenario):
|
1499 |
+
|
1500 |
+
# Provided data
|
1501 |
+
data = {
|
1502 |
+
"Channel": [
|
1503 |
+
"Paid Search",
|
1504 |
+
"Ga will cid baixo risco",
|
1505 |
+
"Digital tactic others",
|
1506 |
+
"Fb la tier 1",
|
1507 |
+
"Fb la tier 2",
|
1508 |
+
"Paid social others",
|
1509 |
+
"Programmatic",
|
1510 |
+
"Kwai",
|
1511 |
+
"Indicacao",
|
1512 |
+
"Infleux",
|
1513 |
+
"Influencer",
|
1514 |
+
],
|
1515 |
+
"Spends": [
|
1516 |
+
"$ 11.3K",
|
1517 |
+
"$ 155.2K",
|
1518 |
+
"$ 50.7K",
|
1519 |
+
"$ 125.4K",
|
1520 |
+
"$ 125.2K",
|
1521 |
+
"$ 105K",
|
1522 |
+
"$ 3.3M",
|
1523 |
+
"$ 47.5K",
|
1524 |
+
"$ 55.9K",
|
1525 |
+
"$ 632.3K",
|
1526 |
+
"$ 48.3K",
|
1527 |
+
],
|
1528 |
+
"Revenue": [
|
1529 |
+
"558.0K",
|
1530 |
+
"3.5M",
|
1531 |
+
"5.2M",
|
1532 |
+
"3.1M",
|
1533 |
+
"3.1M",
|
1534 |
+
"2.1M",
|
1535 |
+
"20.8M",
|
1536 |
+
"1.6M",
|
1537 |
+
"728.4K",
|
1538 |
+
"22.9M",
|
1539 |
+
"4.8M",
|
1540 |
+
],
|
1541 |
+
}
|
1542 |
+
|
1543 |
+
# Create DataFrame
|
1544 |
+
df = pd.DataFrame(data)
|
1545 |
+
|
1546 |
+
# Convert currency strings to numeric values
|
1547 |
+
df["Spends"] = (
|
1548 |
+
df["Spends"]
|
1549 |
+
.replace({"\$": "", "K": "*1e3", "M": "*1e6"}, regex=True)
|
1550 |
+
.map(pd.eval)
|
1551 |
+
.astype(int)
|
1552 |
+
)
|
1553 |
+
df["Revenue"] = (
|
1554 |
+
df["Revenue"]
|
1555 |
+
.replace({"\$": "", "K": "*1e3", "M": "*1e6"}, regex=True)
|
1556 |
+
.map(pd.eval)
|
1557 |
+
.astype(int)
|
1558 |
+
)
|
1559 |
+
|
1560 |
+
# Calculate ROI
|
1561 |
+
df["ROI"] = (df["Revenue"] - df["Spends"]) / df["Spends"]
|
1562 |
+
|
1563 |
+
# Format columns
|
1564 |
+
format_currency = lambda x: f"${x:,.1f}"
|
1565 |
+
format_roi = lambda x: f"{x:.1f}"
|
1566 |
+
|
1567 |
+
df["Spends"] = [
|
1568 |
+
"$ 11.3K",
|
1569 |
+
"$ 155.2K",
|
1570 |
+
"$ 50.7K",
|
1571 |
+
"$ 125.4K",
|
1572 |
+
"$ 125.2K",
|
1573 |
+
"$ 105K",
|
1574 |
+
"$ 3.3M",
|
1575 |
+
"$ 47.5K",
|
1576 |
+
"$ 55.9K",
|
1577 |
+
"$ 632.3K",
|
1578 |
+
"$ 48.3K",
|
1579 |
+
]
|
1580 |
+
df["Revenue"] = [
|
1581 |
+
"$ 536.3K",
|
1582 |
+
"$ 3.4M",
|
1583 |
+
"$ 5M",
|
1584 |
+
"$ 3M",
|
1585 |
+
"$ 3M",
|
1586 |
+
"$ 2M",
|
1587 |
+
"$ 20M",
|
1588 |
+
"$ 1.5M",
|
1589 |
+
"$ 7.1M",
|
1590 |
+
"$ 22M",
|
1591 |
+
"$ 4.6M",
|
1592 |
+
]
|
1593 |
+
df["ROI"] = df["ROI"].apply(format_roi)
|
1594 |
+
|
1595 |
+
return df
|
1596 |
+
|
1597 |
+
|
1598 |
+
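# Standalone illustration of the currency-string conversion used in
# create_channel_summary (values are made up): the regex replace turns "$ 3.3M"
# into the arithmetic string "3.3*1e6", which pd.eval then evaluates to a number.
def _example_currency_parsing():
    spends = pd.Series(["$ 11.3K", "$ 3.3M"])
    numeric = (
        spends.replace({r"\$": "", "K": "*1e3", "M": "*1e6"}, regex=True)
        .map(pd.eval)
        .astype(int)
    )
    return numeric  # 0 -> 11300, 1 -> 3300000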
# @st.cache(allow_output_mutation=True)
# def create_contribution_pie(scenario):
#     #c1f7dc
#     colors_map = {col:color for col,color in zip(st.session_state['channels_list'],plotly.colors.n_colors(plotly.colors.hex_to_rgb('#BE6468'), plotly.colors.hex_to_rgb('#E7B8B7'),23))}
#     total_contribution_fig = make_subplots(rows=1, cols=2,subplot_titles=['Spends','Revenue'],specs=[[{"type": "pie"}, {"type": "pie"}]])
#     total_contribution_fig.add_trace(
#         go.Pie(labels=[channel_name_formating(channel_name) for channel_name in st.session_state['channels_list']] + ['Non Media'],
#                values=[round(scenario.channels[channel_name].actual_total_spends * scenario.channels[channel_name].conversion_rate, 1) for channel_name in st.session_state['channels_list']] + [0],
#                marker=dict(colors=[plotly.colors.label_rgb(colors_map[channel_name]) for channel_name in st.session_state['channels_list']] + ['#F0F0F0']),
#                hole=0.3),
#         row=1, col=1)
#
#     total_contribution_fig.add_trace(
#         go.Pie(labels=[channel_name_formating(channel_name) for channel_name in st.session_state['channels_list']] + ['Non Media'],
#                values=[scenario.channels[channel_name].actual_total_sales for channel_name in st.session_state['channels_list']] + [scenario.correction.sum() + scenario.constant.sum()],
#                hole=0.3),
#         row=1, col=2)
#
#     total_contribution_fig.update_traces(textposition='inside', texttemplate='%{percent:.1%}')
#     total_contribution_fig.update_layout(uniformtext_minsize=12, title='Channel contribution', uniformtext_mode='hide')
#     return total_contribution_fig


# @st.cache(allow_output_mutation=True)
# def create_contribuion_stacked_plot(scenario):
#     weekly_contribution_fig = make_subplots(rows=1, cols=2, subplot_titles=['Spends', 'Revenue'], specs=[[{"type": "bar"}, {"type": "bar"}]])
#     raw_df = st.session_state['raw_df']
#     df = raw_df.sort_values(by='Date')
#     x = df.Date
#     weekly_spends_data = []
#     weekly_sales_data = []
#     for channel_name in st.session_state['channels_list']:
#         weekly_spends_data.append(go.Bar(x=x,
#                                          y=scenario.channels[channel_name].actual_spends * scenario.channels[channel_name].conversion_rate,
#                                          name=channel_name_formating(channel_name),
#                                          hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}",
#                                          legendgroup=channel_name))
#         weekly_sales_data.append(go.Bar(x=x,
#                                         y=scenario.channels[channel_name].actual_sales,
#                                         name=channel_name_formating(channel_name),
#                                         hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
#                                         legendgroup=channel_name, showlegend=False))
#     for _d in weekly_spends_data:
#         weekly_contribution_fig.add_trace(_d, row=1, col=1)
#     for _d in weekly_sales_data:
#         weekly_contribution_fig.add_trace(_d, row=1, col=2)
#     weekly_contribution_fig.add_trace(go.Bar(x=x,
#                                              y=scenario.constant + scenario.correction,
#                                              name='Non Media',
#                                              hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}"), row=1, col=2)
#     weekly_contribution_fig.update_layout(barmode='stack', title='Channel contribuion by week', xaxis_title='Date')
#     weekly_contribution_fig.update_xaxes(showgrid=False)
#     weekly_contribution_fig.update_yaxes(showgrid=False)
#     return weekly_contribution_fig


# @st.cache(allow_output_mutation=True)
# def create_channel_spends_sales_plot(channel):
#     if channel is not None:
#         x = channel.dates
#         _spends = channel.actual_spends * channel.conversion_rate
#         _sales = channel.actual_sales
#         channel_sales_spends_fig = make_subplots(specs=[[{"secondary_y": True}]])
#         channel_sales_spends_fig.add_trace(go.Bar(x=x, y=_sales, marker_color='#c1f7dc', name='Revenue', hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}"), secondary_y=False)
#         channel_sales_spends_fig.add_trace(go.Scatter(x=x, y=_spends, line=dict(color='#005b96'), name='Spends', hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}"), secondary_y=True)
#         channel_sales_spends_fig.update_layout(xaxis_title='Date', yaxis_title='Revenue', yaxis2_title='Spends ($)', title='Channel spends and Revenue week wise')
#         channel_sales_spends_fig.update_xaxes(showgrid=False)
#         channel_sales_spends_fig.update_yaxes(showgrid=False)
#     else:
#         raw_df = st.session_state['raw_df']
#         df = raw_df.sort_values(by='Date')
#         x = df.Date
#         scenario = class_from_dict(st.session_state['default_scenario_dict'])
#         _sales = scenario.constant + scenario.correction
#         channel_sales_spends_fig = make_subplots(specs=[[{"secondary_y": True}]])
#         channel_sales_spends_fig.add_trace(go.Bar(x=x, y=_sales, marker_color='#c1f7dc', name='Revenue', hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}"), secondary_y=False)
#         # channel_sales_spends_fig.add_trace(go.Scatter(x=x, y=_spends, line=dict(color='#15C39A'), name='Spends', hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}"), secondary_y=True)
#         channel_sales_spends_fig.update_layout(xaxis_title='Date', yaxis_title='Revenue', yaxis2_title='Spends ($)', title='Channel spends and Revenue week wise')
#         channel_sales_spends_fig.update_xaxes(showgrid=False)
#         channel_sales_spends_fig.update_yaxes(showgrid=False)
#     return channel_sales_spends_fig


# Define a shared color palette used by all contribution charts in this module
color_palette = [
    "#F3F3F0",
    "#5E7D7E",
    "#2FA1FF",
    "#00EDED",
    "#00EAE4",
    "#304550",
    "#EDEBEB",
    "#7FBEFD",
    "#003059",
    "#A2F3F3",
    "#E1D6E2",
    "#B6B6B6",
]


def create_contribution_pie():
    total_contribution_fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=["Spends", "Revenue"],
        specs=[[{"type": "pie"}, {"type": "pie"}]],
    )

    channels_list = [
        "Paid Search",
        "Ga will cid baixo risco",
        "Digital tactic others",
        "Fb la tier 1",
        "Fb la tier 2",
        "Paid social others",
        "Programmatic",
        "Kwai",
        "Indicacao",
        "Infleux",
        "Influencer",
        "Non Media",
    ]

    # Assign colors from the limited shared palette to channels
    colors_map = {
        col: color_palette[i % len(color_palette)]
        for i, col in enumerate(channels_list)
    }
    colors_map["Non Media"] = color_palette[5]  # Fixed colour for 'Non Media'

    # Hardcoded values for Spends and Revenue
    spends_values = [0.5, 3.36, 1.1, 2.7, 2.7, 2.27, 70.6, 1, 1, 13.7, 1, 0]
    revenue_values = [1, 4, 5, 3, 3, 2, 50.8, 1.5, 0.7, 13, 0, 16]

    # Add trace for the Spends pie chart
    total_contribution_fig.add_trace(
        go.Pie(
            labels=channels_list,
            values=spends_values,
            marker=dict(colors=[colors_map[channel_name] for channel_name in channels_list]),
            hole=0.3,
        ),
        row=1,
        col=1,
    )

    # Add trace for the Revenue pie chart
    total_contribution_fig.add_trace(
        go.Pie(
            labels=channels_list,
            values=revenue_values,
            marker=dict(colors=[colors_map[channel_name] for channel_name in channels_list]),
            hole=0.3,
        ),
        row=1,
        col=2,
    )

    total_contribution_fig.update_traces(
        textposition="inside", texttemplate="%{percent:.1%}"
    )
    total_contribution_fig.update_layout(
        uniformtext_minsize=12,
        title="Channel contribution",
        uniformtext_mode="hide",
    )
    return total_contribution_fig


def create_contribuion_stacked_plot(scenario):
    weekly_contribution_fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=["Spends", "Revenue"],
        specs=[[{"type": "bar"}, {"type": "bar"}]],
    )
    raw_df = st.session_state["raw_df"]
    df = raw_df.sort_values(by="Date")
    x = df.Date
    weekly_spends_data = []
    weekly_sales_data = []

    for i, channel_name in enumerate(st.session_state["channels_list"]):
        color = color_palette[i % len(color_palette)]

        weekly_spends_data.append(
            go.Bar(
                x=x,
                y=scenario.channels[channel_name].actual_spends
                * scenario.channels[channel_name].conversion_rate,
                name=channel_name_formating(channel_name),
                hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}",
                legendgroup=channel_name,
                marker_color=color,
            )
        )

        weekly_sales_data.append(
            go.Bar(
                x=x,
                y=scenario.channels[channel_name].actual_sales,
                name=channel_name_formating(channel_name),
                hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
                legendgroup=channel_name,
                showlegend=False,
                marker_color=color,
            )
        )

    for _d in weekly_spends_data:
        weekly_contribution_fig.add_trace(_d, row=1, col=1)
    for _d in weekly_sales_data:
        weekly_contribution_fig.add_trace(_d, row=1, col=2)

    weekly_contribution_fig.add_trace(
        go.Bar(
            x=x,
            y=scenario.constant + scenario.correction,
            name="Non Media",
            hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
            marker_color=color_palette[-1],
        ),
        row=1,
        col=2,
    )

    weekly_contribution_fig.update_layout(
        barmode="stack",
        title="Channel contribution by week",
        xaxis_title="Date",
    )
    weekly_contribution_fig.update_xaxes(showgrid=False)
    weekly_contribution_fig.update_yaxes(showgrid=False)
    return weekly_contribution_fig


def create_channel_spends_sales_plot(channel):
    if channel is not None:
        x = channel.dates
        _spends = channel.actual_spends * channel.conversion_rate
        _sales = channel.actual_sales
        channel_sales_spends_fig = make_subplots(specs=[[{"secondary_y": True}]])
        channel_sales_spends_fig.add_trace(
            go.Bar(
                x=x,
                y=_sales,
                marker_color=color_palette[3],  # You can choose a color from the palette
                name="Revenue",
                hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
            ),
            secondary_y=False,
        )

        channel_sales_spends_fig.add_trace(
            go.Scatter(
                x=x,
                y=_spends,
                line=dict(color=color_palette[2]),  # You can choose another color from the palette
                name="Spends",
                hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}",
            ),
            secondary_y=True,
        )

        channel_sales_spends_fig.update_layout(
            xaxis_title="Date",
            yaxis_title="Revenue",
            yaxis2_title="Spends ($)",
            title="Channel spends and Revenue week-wise",
        )
        channel_sales_spends_fig.update_xaxes(showgrid=False)
        channel_sales_spends_fig.update_yaxes(showgrid=False)
    else:
        raw_df = st.session_state["raw_df"]
        df = raw_df.sort_values(by="Date")
        x = df.Date
        scenario = class_from_dict(st.session_state["default_scenario_dict"])
        _sales = scenario.constant + scenario.correction
        channel_sales_spends_fig = make_subplots(specs=[[{"secondary_y": True}]])
        channel_sales_spends_fig.add_trace(
            go.Bar(
                x=x,
                y=_sales,
                marker_color=color_palette[0],  # You can choose a color from the palette
                name="Revenue",
                hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
            ),
            secondary_y=False,
        )

        channel_sales_spends_fig.update_layout(
            xaxis_title="Date",
            yaxis_title="Revenue",
            yaxis2_title="Spends ($)",
            title="Channel spends and Revenue week-wise",
        )
        channel_sales_spends_fig.update_xaxes(showgrid=False)
        channel_sales_spends_fig.update_yaxes(showgrid=False)

    return channel_sales_spends_fig


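# Usage sketch (an assumption, not part of the page flow shown here): the three chart
# helpers above are meant to be rendered from a Streamlit page once a scenario has been
# initialized; the scenario and channel objects come from session state, e.g.
#
#     scenario = st.session_state["scenario"]
#     st.plotly_chart(create_contribution_pie(), use_container_width=True)
#     st.plotly_chart(create_contribuion_stacked_plot(scenario), use_container_width=True)
#     first_channel = scenario.channels[st.session_state["channels_list"][0]]
#     st.plotly_chart(create_channel_spends_sales_plot(first_channel), use_container_width=True)

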
def format_numbers(value, n_decimals=1, include_indicator=True):
    if value is None:
        return None
    _value = value if value < 1 else numerize(value, n_decimals)
    if include_indicator:
        return f"{CURRENCY_INDICATOR} {_value}"
    else:
        return f"{_value}"


def decimal_formater(num_string, n_decimals=1):
    parts = num_string.split(".")
    if len(parts) == 1:
        return num_string + "." + "0" * n_decimals
    else:
        to_be_padded = n_decimals - len(parts[-1])
        if to_be_padded > 0:
            return num_string + "0" * to_be_padded
        else:
            return num_string


def channel_name_formating(channel_name):
    name_mod = channel_name.replace("_", " ")
    if name_mod.lower().endswith(" imp"):
        name_mod = name_mod.replace("Imp", "Spend")
    elif name_mod.lower().endswith(" clicks"):
        name_mod = name_mod.replace("Clicks", "Spend")
    return name_mod


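# Worked example (illustrative only; assumes `numerize` abbreviates large numbers the way
# the numerize package does, e.g. 1_234_000 -> "1.2M"):
#
#     format_numbers(1234000)                          # -> "$ 1.2M"
#     format_numbers(0.35)                             # -> "$ 0.35" (values below 1 pass through)
#     format_numbers(5400, include_indicator=False)    # -> "5.4K"
#     decimal_formater("12", n_decimals=2)             # -> "12.00"
#     decimal_formater("12.5", n_decimals=2)           # -> "12.50"
#     channel_name_formating("paid_search_Imp")        # -> "paid search Spend"

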
def send_email(email, message):
    s = smtplib.SMTP("smtp.gmail.com", 587)
    s.starttls()
    s.login("[email protected]", "jgydhpfusuremcol")
    s.sendmail("[email protected]", email, message)
    s.quit()


# if __name__ == "__main__":
#     initialize_data()


#############################################################################################################

import os
import json
import streamlit as st


# Function to get panel names
def get_panels_names(file_selected):
    raw_data_df = st.session_state["project_dct"]["current_media_performance"][
        "model_outputs"
    ][file_selected]["raw_data"]

    if "panel" in raw_data_df.columns:
        panel = list(set(raw_data_df["panel"]))
    elif "Panel" in raw_data_df.columns:
        panel = list(set(raw_data_df["Panel"]))
    else:
        panel = []

    return panel + ["aggregated"]


# Function to get metric names
def get_metrics_names():
    return list(
        st.session_state["project_dct"]["current_media_performance"][
            "model_outputs"
        ].keys()
    )


# Function to load the original and modified RCS metadata files into dictionaries
def load_rcs_metadata_files():
    original_data = st.session_state["project_dct"]["response_curves"][
        "original_metadata_file"
    ]
    modified_data = st.session_state["project_dct"]["response_curves"][
        "modified_metadata_file"
    ]

    return original_data, modified_data


# Function to format a name for display
def name_formating(name):
    # Replace underscores with spaces
    name_mod = name.replace("_", " ")

    # Capitalize the first letter of each word
    name_mod = name_mod.title()

    return name_mod


# Function to load the original and modified scenario metadata files into dictionaries
def load_scenario_metadata_files():
    original_data = st.session_state["project_dct"]["scenario_planner"][
        "original_metadata_file"
    ]
    modified_data = st.session_state["project_dct"]["scenario_planner"][
        "modified_metadata_file"
    ]

    return original_data, modified_data


# Function to generate RCS data and store it as a dictionary
def generate_rcs_data():
    # Retrieve the list of all metric names from the specified directory
    metrics_list = get_metrics_names()

    # Dictionaries to store RCS data for all metrics and their respective panels
    all_rcs_data_original = {}
    all_rcs_data_modified = {}

    # Iterate over each metric in the metrics list
    for metric in metrics_list:
        # Retrieve the list of panel names from the current metric's Excel file
        panel_list = get_panels_names(file_selected=metric)

        # Check if rcs_data_modified exists
        if (
            st.session_state["project_dct"]["response_curves"]["modified_metadata_file"]
            is not None
        ):
            modified_data = st.session_state["project_dct"]["response_curves"][
                "modified_metadata_file"
            ]

        # Iterate over each panel in the panel list
        for panel in panel_list:
            # Initialize the original RCS data for the current panel and metric
            rcs_dict_original, scenario = initialize_data(
                panel=panel,
                metrics=metric,
            )

            # Ensure the dictionary has the metric as a key for original data
            if metric not in all_rcs_data_original:
                all_rcs_data_original[metric] = {}

            # Store the original RCS data under the corresponding panel for the current metric
            all_rcs_data_original[metric][panel] = rcs_dict_original

            # Ensure the dictionary has the metric as a key for modified data
            if metric not in all_rcs_data_modified:
                all_rcs_data_modified[metric] = {}

            # Store the modified RCS data under the corresponding panel for the current metric
            for channel in rcs_dict_original:
                all_rcs_data_modified[metric][panel] = all_rcs_data_modified[
                    metric
                ].get(panel, {})

                try:
                    updated_rcs_dict = modified_data[metric][panel][channel]
                except:
                    updated_rcs_dict = {
                        "K": rcs_dict_original[channel]["K"],
                        "b": rcs_dict_original[channel]["b"],
                        "a": rcs_dict_original[channel]["a"],
                        "x0": rcs_dict_original[channel]["x0"],
                    }

                all_rcs_data_modified[metric][panel][channel] = updated_rcs_dict

    # Write the original RCS data
    st.session_state["project_dct"]["response_curves"][
        "original_metadata_file"
    ] = all_rcs_data_original

    # Write the modified RCS data
    st.session_state["project_dct"]["response_curves"][
        "modified_metadata_file"
    ] = all_rcs_data_modified


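# Shape of the metadata written above (illustrative sketch inferred from this function;
# the metric, panel and channel names shown are hypothetical and depend on the project):
#
#     st.session_state["project_dct"]["response_curves"]["original_metadata_file"] == {
#         "revenue": {
#             "aggregated": {
#                 "paid_search": {"K": ..., "b": ..., "a": ..., "x0": ...},
#                 ...
#             },
#             ...
#         },
#         ...
#     }
#
# i.e. metric -> panel -> channel -> s-curve parameters (K, b, a, x0).

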
# Function to generate scenario data and store it as a dictionary
def generate_scenario_data():
    # Retrieve the list of all metric names from the specified directory
    metrics_list = get_metrics_names()

    # Dictionaries to store scenario data for all metrics and their respective panels
    all_scenario_data_original = {}
    all_scenario_data_modified = {}

    # Iterate over each metric in the metrics list
    for metric in metrics_list:
        # Retrieve the list of panel names from the current metric's Excel file
        panel_list = get_panels_names(metric)

        # Check if scenario_data_modified exists
        if (
            st.session_state["project_dct"]["scenario_planner"][
                "modified_metadata_file"
            ]
            is not None
        ):
            modified_data = st.session_state["project_dct"]["scenario_planner"][
                "modified_metadata_file"
            ]

        # Iterate over each panel in the panel list
        for panel in panel_list:
            # Initialize the original scenario data for the current panel and metric
            rcs_dict_original, scenario = initialize_data(
                panel=panel,
                metrics=metric,
            )

            # Ensure the dictionary has the metric as a key for original data
            if metric not in all_scenario_data_original:
                all_scenario_data_original[metric] = {}

            # Store the original scenario data under the corresponding panel for the current metric
            all_scenario_data_original[metric][panel] = class_convert_to_dict(scenario)

            # Ensure the dictionary has the metric as a key for modified data
            if metric not in all_scenario_data_modified:
                all_scenario_data_modified[metric] = {}

            # Store the modified scenario data under the corresponding panel for the current metric
            try:
                all_scenario_data_modified[metric][panel] = modified_data[metric][panel]
            except:
                all_scenario_data_modified[metric][panel] = class_convert_to_dict(
                    scenario
                )

    # Write the original scenario data
    st.session_state["project_dct"]["scenario_planner"][
        "original_metadata_file"
    ] = all_scenario_data_original

    # Write the modified scenario data
    st.session_state["project_dct"]["scenario_planner"][
        "modified_metadata_file"
    ] = all_scenario_data_modified


#############################################################################################################
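# Usage sketch (an assumption about the intended call order, inferred from the functions
# above): after a model run has been saved to
# project_dct["current_media_performance"]["model_outputs"], a page would typically
# refresh both metadata stores before rendering curves or scenarios, e.g.
#
#     generate_rcs_data()
#     generate_scenario_data()
#     original_rcs, modified_rcs = load_rcs_metadata_files()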
utilities_with_panel.py
ADDED
@@ -0,0 +1,1520 @@
from scenario import numerize
import streamlit as st
import pandas as pd
import json
from scenario import Channel, Scenario
import numpy as np
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from scenario import class_to_dict
from collections import OrderedDict
import io
import plotly
from pathlib import Path
import pickle
import yaml
from yaml import SafeLoader
import streamlit_authenticator as stauth  # assumed: needed by load_authenticator below
from streamlit.components.v1 import html
import smtplib
from scipy.optimize import curve_fit
from sklearn.metrics import r2_score
from scenario import class_from_dict
from utilities import retrieve_pkl_object
import os
import base64


# # schema = db_cred["schema"]

color_palette = [
    "#F3F3F0",
    "#5E7D7E",
    "#2FA1FF",
    "#00EDED",
    "#00EAE4",
    "#304550",
    "#EDEBEB",
    "#7FBEFD",
    "#003059",
    "#A2F3F3",
    "#E1D6E2",
    "#B6B6B6",
]


CURRENCY_INDICATOR = "$"


if "project_dct" not in st.session_state or "project_number" not in st.session_state:
    st.error("No project loaded. Please load or create a project before opening this page.")
    st.stop()

tuned_model = retrieve_pkl_object(
    st.session_state["project_number"], "Model_Tuning", "tuned_model"
)

if tuned_model is None:
    st.error(
        "No tuned model available. Please build and tune a model to generate response curves."
    )
    st.stop()


def load_authenticator():
    with open("config.yaml") as file:
        config = yaml.load(file, Loader=SafeLoader)
        st.session_state["config"] = config
    authenticator = stauth.Authenticate(
        config["credentials"],
        config["cookie"]["name"],
        config["cookie"]["key"],
        config["cookie"]["expiry_days"],
        config["preauthorized"],
    )
    st.session_state["authenticator"] = authenticator
    return authenticator


def nav_page(page_name, timeout_secs=3):
    nav_script = """
        <script type="text/javascript">
            function attempt_nav_page(page_name, start_time, timeout_secs) {
                var links = window.parent.document.getElementsByTagName("a");
                for (var i = 0; i < links.length; i++) {
                    if (links[i].href.toLowerCase().endsWith("/" + page_name.toLowerCase())) {
                        links[i].click();
                        return;
                    }
                }
                var elapsed = new Date() - start_time;
                if (elapsed < timeout_secs * 1000) {
                    setTimeout(attempt_nav_page, 100, page_name, start_time, timeout_secs);
                } else {
                    alert("Unable to navigate to page '" + page_name + "' after " + timeout_secs + " second(s).");
                }
            }
            window.addEventListener("load", function() {
                attempt_nav_page("%s", new Date(), %d);
            });
        </script>
    """ % (
        page_name,
        timeout_secs,
    )
    html(nav_script)


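# Usage sketch (illustrative): nav_page expects the target page's URL slug as it appears
# in the Streamlit sidebar, e.g.
#
#     if st.button("Go to scenario planner"):
#         nav_page("Scenario_Planner")

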
# def load_local_css(file_name):
#     with open(file_name) as f:
#         st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)


# def set_header():
#     return st.markdown(f"""<div class='main-header'>
#                             <h1>MMM LiME</h1>
#                             <img src="https://assets-global.website-files.com/64c8fffb0e95cbc525815b79/64df84637f83a891c1473c51_Vector%20(Stroke).svg ">
#                         </div>""", unsafe_allow_html=True)

path = os.path.dirname(__file__)

file_ = open(f"{path}/logo.png", "rb")
contents = file_.read()
data_url = base64.b64encode(contents).decode("utf-8")
file_.close()


DATA_PATH = "./data"
IMAGES_PATH = "./data/images_224_224"


# is_panel = True if len(panel_col) > 0 else False
# manoj
is_panel = False
date_col = "Date"
# is_panel = False  # flag if set to true - do panel level response curves


def load_local_css(file_name):
    with open(file_name) as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)


# def set_header():
#     return st.markdown(f"""<div class='main-header'>
#                             <h1>H & M Recommendations</h1>
#                             <img src="data:image;base64,{data_url}", alt="Logo">
#                         </div>""", unsafe_allow_html=True)

path1 = os.path.dirname(__file__)

# file_1 = open(f"{path}/willbank.png", "rb")
# contents1 = file_1.read()
# data_url1 = base64.b64encode(contents1).decode("utf-8")
# file_1.close()


DATA_PATH1 = "./data"
IMAGES_PATH1 = "./data/images_224_224"


def set_header():
    return st.markdown(
        f"""<div class='main-header'>
            <!-- <h1></h1> -->
            <div >
                <img class='blend-logo' src="data:image;base64,{data_url}", alt="Logo">
            </div>
            <img class='blend-logo' src="data:image;base64,{data_url}", alt="Logo">
        </div>""",
        unsafe_allow_html=True,
    )


# def set_header():
#     logo_path = "./path/to/your/local/LIME_logo.png"  # Replace with the actual file path
#     text = "LiME"
#     return st.markdown(f"""<div class='main-header'>
#             <img src="data:image/png;base64,{data_url}" alt="Logo" style="float: left; margin-right: 10px; width: 100px; height: auto;">
#             <h1>{text}</h1>
#         </div>""", unsafe_allow_html=True)


def s_curve(x, K, b, a, x0):
    return K / (1 + b * np.exp(-a * (x - x0)))


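# Minimal sketch of how the S-curve above is used elsewhere in this module (names such as
# `spend_grid` are illustrative, not from the original code): spends are fitted against
# contributions with scipy.optimize.curve_fit, and the fitted parameters (K, b, a, x0)
# are what get stored as a channel's response-curve metadata.
#
#     import numpy as np
#     from scipy.optimize import curve_fit
#
#     spend_grid = np.linspace(0, 100, 50)                            # scaled spends
#     contributions = s_curve(spend_grid, K=500, b=5, a=0.1, x0=40)   # synthetic target
#     params, _ = curve_fit(
#         s_curve, spend_grid, contributions,
#         p0=(2 * contributions.max(), 0.01, 1e-5, spend_grid.max()),
#         bounds=((0, 0, 0, 0), (3 * contributions.max(), 1000, 1, spend_grid.max())),
#         maxfev=int(1e5),
#     )
#     K, b, a, x0 = params  # recovered response-curve parameters

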
def overview_test_data_prep_panel(X, df, spends_X, date_col, panel_col, target_col):
|
206 |
+
"""
|
207 |
+
function to create the data which is used in initialize data fn
|
208 |
+
X : X test with contributions
|
209 |
+
df : originally uploaded data (media data) which has raw vars
|
210 |
+
spends_X : spends of dates in X test
|
211 |
+
"""
|
212 |
+
|
213 |
+
channels = st.session_state["channels"]
|
214 |
+
channel_list = channels.keys()
|
215 |
+
|
216 |
+
# map transformed variable to raw variable name & channel name
|
217 |
+
# mapping eg : paid_search_clicks_lag_2 (transformed var) --> paid_search_clicks (raw var) --> paid_search (channel)
|
218 |
+
variables = {}
|
219 |
+
channel_and_variables = {}
|
220 |
+
new_variables = {}
|
221 |
+
new_channels_and_variables = {}
|
222 |
+
|
223 |
+
for transformed_var in [
|
224 |
+
col
|
225 |
+
for col in X.drop(columns=[date_col, panel_col, target_col, "pred"]).columns
|
226 |
+
if "_contr" not in col
|
227 |
+
]:
|
228 |
+
if len([col for col in df.columns if col in transformed_var]) == 1:
|
229 |
+
raw_var = [col for col in df.columns if col in transformed_var][0]
|
230 |
+
|
231 |
+
variables[transformed_var] = raw_var
|
232 |
+
|
233 |
+
# Check if the list comprehension result is not empty before accessing the first element
|
234 |
+
channels_list = [
|
235 |
+
channel for channel, raw_vars in channels.items() if raw_var in raw_vars
|
236 |
+
]
|
237 |
+
if channels_list:
|
238 |
+
channel_and_variables[raw_var] = channels_list[0]
|
239 |
+
else:
|
240 |
+
# Handle the case where channels_list is empty
|
241 |
+
# You might want to set a default value or handle it according to your use case
|
242 |
+
channel_and_variables[raw_var] = None
|
243 |
+
else:
|
244 |
+
new_variables[transformed_var] = transformed_var
|
245 |
+
new_channels_and_variables[transformed_var] = "base"
|
246 |
+
|
247 |
+
# Raw DF
|
248 |
+
raw_X = pd.merge(
|
249 |
+
X[[date_col, panel_col]],
|
250 |
+
df[[date_col, panel_col] + list(variables.values())],
|
251 |
+
how="left",
|
252 |
+
on=[date_col, panel_col],
|
253 |
+
)
|
254 |
+
assert len(raw_X) == len(X)
|
255 |
+
|
256 |
+
raw_X_cols = []
|
257 |
+
for i in raw_X.columns:
|
258 |
+
if i in channel_and_variables.keys():
|
259 |
+
raw_X_cols.append(channel_and_variables[i])
|
260 |
+
else:
|
261 |
+
raw_X_cols.append(i)
|
262 |
+
raw_X.columns = raw_X_cols
|
263 |
+
|
264 |
+
# Contribution DF
|
265 |
+
contr_X = X[
|
266 |
+
[date_col, panel_col]
|
267 |
+
+ [col for col in X.columns if "_contr" in col and "sum_" not in col]
|
268 |
+
].copy()
|
269 |
+
# if "base_contr" in contr_X.columns:
|
270 |
+
# contr_X.rename(columns={'base_contr':'const_contr'},inplace=True)
|
271 |
+
# # new_variables = [
|
272 |
+
# col
|
273 |
+
# for col in contr_X.columns
|
274 |
+
# if "_flag" in col.lower() or "trend" in col.lower() or "sine" in col.lower()
|
275 |
+
# ]
|
276 |
+
# if len(new_variables) > 0:
|
277 |
+
# contr_X["const"] = contr_X[["panel_effect"] + new_variables].sum(axis=1)
|
278 |
+
# contr_X.drop(columns=["panel_effect"], inplace=True)
|
279 |
+
# contr_X.drop(columns=new_variables, inplace=True)
|
280 |
+
# else:
|
281 |
+
# contr_X.rename(columns={"panel_effect": "const"}, inplace=True)
|
282 |
+
|
283 |
+
new_contr_X_cols = []
|
284 |
+
for col in contr_X.columns:
|
285 |
+
col_clean = col.replace("_contr", "")
|
286 |
+
new_contr_X_cols.append(col_clean)
|
287 |
+
contr_X.columns = new_contr_X_cols
|
288 |
+
|
289 |
+
contr_X_cols = []
|
290 |
+
for i in contr_X.columns:
|
291 |
+
if i in variables.keys():
|
292 |
+
contr_X_cols.append(channel_and_variables[variables[i]])
|
293 |
+
else:
|
294 |
+
contr_X_cols.append(i)
|
295 |
+
contr_X.columns = contr_X_cols
|
296 |
+
|
297 |
+
# Spends DF
|
298 |
+
spends_X.columns = [col.replace("_cost", "") for col in spends_X.columns]
|
299 |
+
|
300 |
+
raw_X.rename(columns={"date": "Date"}, inplace=True)
|
301 |
+
contr_X.rename(columns={"date": "Date"}, inplace=True)
|
302 |
+
spends_X.rename(columns={"date": "Week"}, inplace=True)
|
303 |
+
|
304 |
+
spends_X.columns = [
|
305 |
+
col.replace("spends_", "") if col.startswith("spends_") else col
|
306 |
+
for col in spends_X.columns
|
307 |
+
]
|
308 |
+
|
309 |
+
# Rename column to 'Date'
|
310 |
+
spends_X.rename(columns={"Week": "Date"}, inplace=True)
|
311 |
+
|
312 |
+
# Remove "response_metric_" from the start and "_total" from the end
|
313 |
+
if str(target_col).startswith("response_metric_"):
|
314 |
+
target_col = target_col.replace("response_metric_", "", 1)
|
315 |
+
|
316 |
+
# Remove the last 6 characters (length of "_total")
|
317 |
+
if str(target_col).endswith("_total"):
|
318 |
+
target_col = target_col[:-6]
|
319 |
+
|
320 |
+
# Rename column to 'Date'
|
321 |
+
spends_X.rename(columns={"Week": "Date"}, inplace=True)
|
322 |
+
|
323 |
+
# Save raw, spends and contribution data
|
324 |
+
st.session_state["project_dct"]["current_media_performance"]["model_outputs"][
|
325 |
+
target_col
|
326 |
+
] = {
|
327 |
+
"raw_data": raw_X,
|
328 |
+
"contribution_data": contr_X,
|
329 |
+
"spends_data": spends_X,
|
330 |
+
}
|
331 |
+
|
332 |
+
# Clear page metadata
|
333 |
+
st.session_state["project_dct"]["scenario_planner"]["original_metadata_file"] = None
|
334 |
+
st.session_state["project_dct"]["response_curves"]["original_metadata_file"] = None
|
335 |
+
|
336 |
+
# st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] = None
|
337 |
+
# st.session_state["project_dct"]["response_curves"]["modified_metadata_file"] = None
|
338 |
+
|
339 |
+
|
340 |
+
def overview_test_data_prep_nonpanel(X, df, spends_X, date_col, target_col):
|
341 |
+
"""
|
342 |
+
function to create the data which is used in initialize data fn
|
343 |
+
"""
|
344 |
+
|
345 |
+
# with open(
|
346 |
+
# os.path.join(st.session_state["project_path"], "channel_groups.pkl"), "rb"
|
347 |
+
# ) as f:
|
348 |
+
# channels = pickle.load(f)
|
349 |
+
|
350 |
+
# channel_list = list(channels.keys())
|
351 |
+
channels = st.session_state["channels"]
|
352 |
+
channel_list = channels.keys()
|
353 |
+
|
354 |
+
# map transformed variable to raw variable name & channel name
|
355 |
+
# mapping eg : paid_search_clicks_lag_2 (transformed var) --> paid_search_clicks (raw var) --> paid_search (channel)
|
356 |
+
variables = {}
|
357 |
+
channel_and_variables = {}
|
358 |
+
new_variables = {}
|
359 |
+
new_channels_and_variables = {}
|
360 |
+
|
361 |
+
cols_to_del = list(
|
362 |
+
set([date_col, target_col, "pred"]).intersection((set(X.columns)))
|
363 |
+
)
|
364 |
+
|
365 |
+
# remove exog cols from RAW data (exog cols are part of base, raw data needs media vars only)
|
366 |
+
all_exog_vars = st.session_state["bin_dict"]["Exogenous"]
|
367 |
+
all_exog_vars = [
|
368 |
+
var.lower()
|
369 |
+
.replace(".", "_")
|
370 |
+
.replace("@", "_")
|
371 |
+
.replace(" ", "_")
|
372 |
+
.replace("-", "")
|
373 |
+
.replace(":", "")
|
374 |
+
.replace("__", "_")
|
375 |
+
for var in all_exog_vars
|
376 |
+
]
|
377 |
+
exog_cols = []
|
378 |
+
if len(all_exog_vars) > 0:
|
379 |
+
for col in X.columns:
|
380 |
+
if len([exog_var for exog_var in all_exog_vars if exog_var in col]) > 0:
|
381 |
+
exog_cols.append(col)
|
382 |
+
cols_to_del = cols_to_del + exog_cols
|
383 |
+
for transformed_var in [
|
384 |
+
col for col in X.drop(columns=cols_to_del).columns if "_contr" not in col
|
385 |
+
]: # also has 'const'
|
386 |
+
|
387 |
+
if (
|
388 |
+
len([col for col in df.columns if col in transformed_var]) == 1
|
389 |
+
): # col is raw var
|
390 |
+
raw_var = [col for col in df.columns if col in transformed_var][0]
|
391 |
+
variables[transformed_var] = raw_var
|
392 |
+
channel_and_variables[raw_var] = [
|
393 |
+
channel for channel, raw_vars in channels.items() if raw_var in raw_vars
|
394 |
+
][0]
|
395 |
+
else: # when no corresponding raw var then base
|
396 |
+
new_variables[transformed_var] = transformed_var
|
397 |
+
new_channels_and_variables[transformed_var] = "base"
|
398 |
+
|
399 |
+
# Raw DF
|
400 |
+
raw_X = pd.merge(
|
401 |
+
X[[date_col]],
|
402 |
+
df[[date_col] + list(variables.values())],
|
403 |
+
how="left",
|
404 |
+
on=[date_col],
|
405 |
+
)
|
406 |
+
assert len(raw_X) == len(X)
|
407 |
+
|
408 |
+
raw_X_cols = []
|
409 |
+
for i in raw_X.columns:
|
410 |
+
if i in channel_and_variables.keys():
|
411 |
+
raw_X_cols.append(channel_and_variables[i])
|
412 |
+
else:
|
413 |
+
raw_X_cols.append(i)
|
414 |
+
raw_X.columns = raw_X_cols
|
415 |
+
|
416 |
+
# Contribution DF
|
417 |
+
contr_X = X[
|
418 |
+
[date_col] + [col for col in X.columns if "_contr" in col and "sum_" not in col]
|
419 |
+
].copy()
|
420 |
+
|
421 |
+
new_variables = [
|
422 |
+
col
|
423 |
+
for col in contr_X.columns
|
424 |
+
if "_flag" in col.lower() or "trend" in col.lower() or "sine" in col.lower()
|
425 |
+
]
|
426 |
+
if (
|
427 |
+
len(new_variables) > 0
|
428 |
+
): # if new vars are available, their contributions should be added to base (called const)
|
429 |
+
contr_X["const_contr"] = contr_X[["const_contr"] + new_variables].sum(axis=1)
|
430 |
+
contr_X.drop(columns=new_variables, inplace=True)
|
431 |
+
|
432 |
+
new_contr_X_cols = []
|
433 |
+
for col in contr_X.columns:
|
434 |
+
col_clean = col.replace("_contr", "")
|
435 |
+
new_contr_X_cols.append(col_clean)
|
436 |
+
contr_X.columns = new_contr_X_cols
|
437 |
+
|
438 |
+
contr_X_cols = []
|
439 |
+
for i in contr_X.columns:
|
440 |
+
if i in variables.keys():
|
441 |
+
contr_X_cols.append(channel_and_variables[variables[i]])
|
442 |
+
else:
|
443 |
+
contr_X_cols.append(i)
|
444 |
+
contr_X.columns = contr_X_cols
|
445 |
+
|
446 |
+
# Spends DF
|
447 |
+
# spends_X.columns = [
|
448 |
+
# col.replace("_cost", "").replace("_spends", "").replace("_spend", "")
|
449 |
+
# for col in spends_X.columns
|
450 |
+
# ]
|
451 |
+
spends_X_col_map = {
|
452 |
+
col: bucket
|
453 |
+
for col in spends_X.columns
|
454 |
+
for bucket in channels.keys()
|
455 |
+
if col in channels[bucket]
|
456 |
+
}
|
457 |
+
spends_X.rename(columns=spends_X_col_map, inplace=True)
|
458 |
+
|
459 |
+
raw_X.rename(columns={"date": "Date"}, inplace=True)
|
460 |
+
contr_X.rename(columns={"date": "Date"}, inplace=True)
|
461 |
+
spends_X.rename(columns={"date": "Week"}, inplace=True)
|
462 |
+
|
463 |
+
spends_X.columns = [
|
464 |
+
col.replace("spends_", "") if col.startswith("spends_") else col
|
465 |
+
for col in spends_X.columns
|
466 |
+
]
|
467 |
+
|
468 |
+
# Rename column to 'Date'
|
469 |
+
spends_X.rename(columns={"Week": "Date"}, inplace=True)
|
470 |
+
|
471 |
+
# Remove "response_metric_" from the start and "_total" from the end
|
472 |
+
if str(target_col).startswith("response_metric_"):
|
473 |
+
target_col = target_col.replace("response_metric_", "", 1)
|
474 |
+
|
475 |
+
# Remove the last 6 characters (length of "_total")
|
476 |
+
if str(target_col).endswith("_total"):
|
477 |
+
target_col = target_col[:-6]
|
478 |
+
|
479 |
+
# Rename column to 'Date'
|
480 |
+
spends_X.rename(columns={"Week": "Date"}, inplace=True)
|
481 |
+
|
482 |
+
# Save raw, spends and contribution data
|
483 |
+
st.session_state["project_dct"]["current_media_performance"]["model_outputs"][
|
484 |
+
target_col
|
485 |
+
] = {
|
486 |
+
"raw_data": raw_X,
|
487 |
+
"contribution_data": contr_X,
|
488 |
+
"spends_data": spends_X,
|
489 |
+
}
|
490 |
+
|
491 |
+
# Clear page metadata
|
492 |
+
st.session_state["project_dct"]["scenario_planner"]["original_metadata_file"] = None
|
493 |
+
st.session_state["project_dct"]["response_curves"]["original_metadata_file"] = None
|
494 |
+
|
495 |
+
# st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] = None
|
496 |
+
# st.session_state["project_dct"]["response_curves"]["modified_metadata_file"] = None
|
497 |
+
|
498 |
+
|
499 |
+
def initialize_data_cmp(target_col, is_panel, panel_col, start_date, end_date):
|
500 |
+
start_date = pd.to_datetime(start_date)
|
501 |
+
end_date = pd.to_datetime(end_date)
|
502 |
+
|
503 |
+
# Remove "response_metric_" from the start and "_total" from the end
|
504 |
+
if str(target_col).startswith("response_metric_"):
|
505 |
+
target_col = target_col.replace("response_metric_", "", 1)
|
506 |
+
|
507 |
+
# Remove the last 6 characters (length of "_total")
|
508 |
+
if str(target_col).endswith("_total"):
|
509 |
+
target_col = target_col[:-6]
|
510 |
+
|
511 |
+
# Extract dataframes for raw data, spend input, and contribution data
|
512 |
+
raw_df = st.session_state["project_dct"]["current_media_performance"][
|
513 |
+
"model_outputs"
|
514 |
+
][target_col]["raw_data"]
|
515 |
+
spend_df = st.session_state["project_dct"]["current_media_performance"][
|
516 |
+
"model_outputs"
|
517 |
+
][target_col]["spends_data"]
|
518 |
+
contri_df = st.session_state["project_dct"]["current_media_performance"][
|
519 |
+
"model_outputs"
|
520 |
+
][target_col]["contribution_data"]
|
521 |
+
|
522 |
+
# Remove unnecessary columns
|
523 |
+
unnamed_cols = [col for col in raw_df.columns if col.lower().startswith("unnamed")]
|
524 |
+
exclude_columns = ["Date"] + unnamed_cols
|
525 |
+
|
526 |
+
if is_panel:
|
527 |
+
exclude_columns = exclude_columns + [panel_col]
|
528 |
+
|
529 |
+
# Aggregate all 3 dfs to date level (from date-panel level)
|
530 |
+
raw_df[date_col] = pd.to_datetime(raw_df[date_col])
|
531 |
+
raw_df = raw_df[raw_df[date_col] >= start_date]
|
532 |
+
raw_df = raw_df[raw_df[date_col] <= end_date].reset_index(drop=True)
|
533 |
+
raw_df_aggregations = {c: "sum" for c in raw_df.columns if c not in exclude_columns}
|
534 |
+
raw_df = raw_df.groupby(date_col).agg(raw_df_aggregations).reset_index()
|
535 |
+
|
536 |
+
contri_df[date_col] = pd.to_datetime(contri_df[date_col])
|
537 |
+
contri_df = contri_df[contri_df[date_col] >= start_date]
|
538 |
+
contri_df = contri_df[contri_df[date_col] <= end_date].reset_index(drop=True)
|
539 |
+
contri_df_aggregations = {
|
540 |
+
c: "sum" for c in contri_df.columns if c not in exclude_columns
|
541 |
+
}
|
542 |
+
contri_df = contri_df.groupby(date_col).agg(contri_df_aggregations).reset_index()
|
543 |
+
|
544 |
+
input_df = raw_df.sort_values(by=[date_col])
|
545 |
+
|
546 |
+
output_df = contri_df.sort_values(by=[date_col])
|
547 |
+
spend_df["Date"] = pd.to_datetime(
|
548 |
+
spend_df["Date"], format="%Y-%m-%d", errors="coerce"
|
549 |
+
)
|
550 |
+
spend_df = spend_df[spend_df["Date"] >= start_date]
|
551 |
+
spend_df = spend_df[spend_df["Date"] <= end_date].reset_index(drop=True)
|
552 |
+
spend_df_aggregations = {
|
553 |
+
c: "sum" for c in spend_df.columns if c not in exclude_columns
|
554 |
+
}
|
555 |
+
spend_df = spend_df.groupby("Date").agg(spend_df_aggregations).reset_index()
|
556 |
+
# spend_df['Week'] = pd.to_datetime(spend_df['Week'], errors='coerce')
|
557 |
+
# spend_df = spend_df.sort_values(by='Week')
|
558 |
+
|
559 |
+
channel_list = [col for col in input_df.columns if col not in exclude_columns]
|
560 |
+
|
561 |
+
response_curves = {}
|
562 |
+
mapes = {}
|
563 |
+
rmses = {}
|
564 |
+
upper_limits = {}
|
565 |
+
powers = {}
|
566 |
+
r2 = {}
|
567 |
+
conv_rates = {}
|
568 |
+
output_cols = []
|
569 |
+
channels = {}
|
570 |
+
sales = None
|
571 |
+
dates = input_df.Date.values
|
572 |
+
actual_output_dic = {}
|
573 |
+
actual_input_dic = {}
|
574 |
+
|
575 |
+
# channel_list=['programmatic']
|
576 |
+
infeasible_channels = [
|
577 |
+
c
|
578 |
+
for c in contri_df.select_dtypes(include=["float", "int"]).columns
|
579 |
+
if contri_df[c].sum() <= 0
|
580 |
+
]
|
581 |
+
|
582 |
+
channel_list = list(set(channel_list) - set(infeasible_channels))
|
583 |
+
|
584 |
+
for inp_col in channel_list:
|
585 |
+
spends = input_df[inp_col].values
|
586 |
+
|
587 |
+
x = spends.copy()
|
588 |
+
|
589 |
+
# upper limit for penalty
|
590 |
+
upper_limits[inp_col] = 2 * x.max()
|
591 |
+
|
592 |
+
out_col = inp_col
|
593 |
+
y = output_df[out_col].values.copy()
|
594 |
+
|
595 |
+
actual_output_dic[inp_col] = y.copy()
|
596 |
+
actual_input_dic[inp_col] = x.copy()
|
597 |
+
##output cols aggregation
|
598 |
+
output_cols.append(out_col)
|
599 |
+
|
600 |
+
## scale the input
|
601 |
+
power = np.ceil(np.log(x.max()) / np.log(10)) - 3
|
602 |
+
if power >= 0:
|
603 |
+
x = x / 10**power
|
604 |
+
|
605 |
+
x = x.astype("float64")
|
606 |
+
y = y.astype("float64")
|
607 |
+
|
608 |
+
if y.max() <= 0.01:
|
609 |
+
bounds = (
|
610 |
+
(0, 0, 0, 0),
|
611 |
+
(3 * 0.01, 1000, 1, x.max() if x.max() > 0 else 0.01),
|
612 |
+
)
|
613 |
+
else:
|
614 |
+
bounds = ((0, 0, 0, 0), (3 * y.max(), 1000, 1, x.max()))
|
615 |
+
|
616 |
+
params, _ = curve_fit(
|
617 |
+
s_curve,
|
618 |
+
x,
|
619 |
+
y,
|
620 |
+
p0=(2 * y.max(), 0.01, 1e-5, x.max()),
|
621 |
+
bounds=bounds,
|
622 |
+
maxfev=int(1e5),
|
623 |
+
)
|
624 |
+
mape = (100 * abs(1 - s_curve(x, *params) / y.clip(min=1))).mean()
|
625 |
+
rmse = np.sqrt(((y - s_curve(x, *params)) ** 2).mean())
|
626 |
+
r2_ = r2_score(y, s_curve(x, *params))
|
627 |
+
|
628 |
+
response_curves[inp_col] = {
|
629 |
+
"K": params[0],
|
630 |
+
"b": params[1],
|
631 |
+
"a": params[2],
|
632 |
+
"x0": params[3],
|
633 |
+
}
|
634 |
+
mapes[inp_col] = mape
|
635 |
+
rmses[inp_col] = rmse
|
636 |
+
r2[inp_col] = r2_
|
637 |
+
powers[inp_col] = power
|
638 |
+
|
639 |
+
conv = spend_df[inp_col].sum() / max(input_df[inp_col].sum(), 1e-9)
|
640 |
+
conv_rates[inp_col] = conv
|
641 |
+
|
642 |
+
channel = Channel(
|
643 |
+
name=inp_col,
|
644 |
+
dates=dates,
|
645 |
+
spends=spends,
|
646 |
+
# conversion_rate = np.mean(list(conv_rates[inp_col].values())),
|
647 |
+
conversion_rate=conv_rates[inp_col],
|
648 |
+
response_curve_type="s-curve",
|
649 |
+
response_curve_params={
|
650 |
+
"K": params[0],
|
651 |
+
"b": params[1],
|
652 |
+
"a": params[2],
|
653 |
+
"x0": params[3],
|
654 |
+
},
|
655 |
+
bounds=np.array([-10, 10]),
|
656 |
+
correction=y - s_curve(x, *params),
|
657 |
+
)
|
658 |
+
channels[inp_col] = channel
|
659 |
+
if sales is None:
|
660 |
+
sales = channel.actual_sales
|
661 |
+
else:
|
662 |
+
sales += channel.actual_sales
|
663 |
+
|
664 |
+
other_contributions = (
|
665 |
+
output_df.drop([*output_cols], axis=1).sum(axis=1, numeric_only=True).values
|
666 |
+
)
|
667 |
+
correction = output_df.drop(["Date"], axis=1).sum(axis=1).values - (
|
668 |
+
sales + other_contributions
|
669 |
+
)
|
670 |
+
|
671 |
+
# Testing
|
672 |
+
# scenario_test_df = pd.DataFrame(
|
673 |
+
# columns=["other_contributions", "correction", "sales"]
|
674 |
+
# )
|
675 |
+
# scenario_test_df["other_contributions"] = other_contributions
|
676 |
+
# scenario_test_df["correction"] = correction
|
677 |
+
# scenario_test_df["sales"] = sales
|
678 |
+
# scenario_test_df.to_csv("test\scenario_test_df.csv", index=False)
|
679 |
+
# output_df.to_csv("test\output_df.csv", index=False)
|
680 |
+
|
681 |
+
scenario = Scenario(
|
682 |
+
name="default",
|
683 |
+
channels=channels,
|
684 |
+
constant=other_contributions,
|
685 |
+
correction=correction,
|
686 |
+
)
|
687 |
+
## setting session variables
|
688 |
+
st.session_state["initialized"] = True
|
689 |
+
st.session_state["actual_df"] = input_df
|
690 |
+
st.session_state["raw_df"] = raw_df
|
691 |
+
st.session_state["contri_df"] = output_df
|
692 |
+
default_scenario_dict = class_to_dict(scenario)
|
693 |
+
st.session_state["default_scenario_dict"] = default_scenario_dict
|
694 |
+
st.session_state["scenario"] = scenario
|
695 |
+
st.session_state["channels_list"] = channel_list
|
696 |
+
st.session_state["optimization_channels"] = {
|
697 |
+
channel_name: False for channel_name in channel_list
|
698 |
+
}
|
699 |
+
st.session_state["rcs"] = response_curves
|
700 |
+
|
701 |
+
# orig_rcs_path = os.path.join(
|
702 |
+
# st.session_state["project_path"], f"orig_rcs_{target_col}_{panel_col}.json"
|
703 |
+
# )
|
704 |
+
# if Path(orig_rcs_path).exists():
|
705 |
+
# with open(orig_rcs_path, "r") as f:
|
706 |
+
# st.session_state["orig_rcs"] = json.load(f)
|
707 |
+
# else:
|
708 |
+
# st.session_state["orig_rcs"] = response_curves.copy()
|
709 |
+
# with open(orig_rcs_path, "w") as f:
|
710 |
+
# json.dump(st.session_state["orig_rcs"], f)
|
711 |
+
|
712 |
+
st.session_state["powers"] = powers
|
713 |
+
st.session_state["actual_contribution_df"] = pd.DataFrame(actual_output_dic)
|
714 |
+
st.session_state["actual_input_df"] = pd.DataFrame(actual_input_dic)
|
715 |
+
|
716 |
+
for channel in channels.values():
|
717 |
+
st.session_state[channel.name] = numerize(
|
718 |
+
channel.actual_total_spends * channel.conversion_rate, 1
|
719 |
+
)
|
720 |
+
|
721 |
+
# st.session_state["xlsx_buffer"] = io.BytesIO()
|
722 |
+
#
|
723 |
+
# if Path("../saved_scenarios.pkl").exists():
|
724 |
+
# with open("../saved_scenarios.pkl", "rb") as f:
|
725 |
+
# st.session_state["saved_scenarios"] = pickle.load(f)
|
726 |
+
# else:
|
727 |
+
# st.session_state["saved_scenarios"] = OrderedDict()
|
728 |
+
#
|
729 |
+
# st.session_state["total_spends_change"] = 0
|
730 |
+
# st.session_state["optimization_channels"] = {
|
731 |
+
# channel_name: False for channel_name in channel_list
|
732 |
+
# }
|
733 |
+
# st.session_state["disable_download_button"] = True
|
734 |
+
|
735 |
+
|
736 |
+
# def initialize_data():
|
737 |
+
# # fetch data from excel
|
738 |
+
# output = pd.read_excel('data.xlsx',sheet_name=None)
|
739 |
+
# raw_df = output['RAW DATA MMM']
|
740 |
+
# contribution_df = output['CONTRIBUTION MMM']
|
741 |
+
# Revenue_df = output['Revenue']
|
742 |
+
|
743 |
+
# ## channels to be shows
|
744 |
+
# channel_list = []
|
745 |
+
# for col in raw_df.columns:
|
746 |
+
# if 'click' in col.lower() or 'spend' in col.lower() or 'imp' in col.lower():
|
747 |
+
#
|
748 |
+
# channel_list.append(col)
|
749 |
+
# else:
|
750 |
+
# pass
|
751 |
+
|
752 |
+
# ## NOTE : Considered only Desktop spends for all calculations
|
753 |
+
# acutal_df = raw_df[raw_df.Region == 'Desktop'].copy()
|
754 |
+
# ## NOTE : Considered one year of data
|
755 |
+
# acutal_df = acutal_df[acutal_df.Date>'2020-12-31']
|
756 |
+
# actual_df = acutal_df.drop('Region',axis=1).sort_values(by='Date')[[*channel_list,'Date']]
|
757 |
+
|
758 |
+
# ##load response curves
|
759 |
+
# with open('./grammarly_response_curves.json','r') as f:
|
760 |
+
# response_curves = json.load(f)
|
761 |
+
|
762 |
+
# ## create channel dict for scenario creation
|
763 |
+
# dates = actual_df.Date.values
|
764 |
+
# channels = {}
|
765 |
+
# rcs = {}
|
766 |
+
# constant = 0.
|
767 |
+
# for i,info_dict in enumerate(response_curves):
|
768 |
+
# name = info_dict.get('name')
|
769 |
+
# response_curve_type = info_dict.get('response_curve')
|
770 |
+
# response_curve_params = info_dict.get('params')
|
771 |
+
# rcs[name] = response_curve_params
|
772 |
+
# if name != 'constant':
|
773 |
+
# spends = actual_df[name].values
|
774 |
+
# channel = Channel(name=name,dates=dates,
|
775 |
+
# spends=spends,
|
776 |
+
# response_curve_type=response_curve_type,
|
777 |
+
# response_curve_params=response_curve_params,
|
778 |
+
# bounds=np.array([-30,30]))
|
779 |
+
|
780 |
+
# channels[name] = channel
|
781 |
+
# else:
|
782 |
+
# constant = info_dict.get('value',0.) * len(dates)
|
783 |
+
|
784 |
+
# ## create scenario
|
785 |
+
# scenario = Scenario(name='default', channels=channels, constant=constant)
|
786 |
+
# default_scenario_dict = class_to_dict(scenario)
|
787 |
+
|
788 |
+
|
789 |
+
# ## setting session variables
|
790 |
+
# st.session_state['initialized'] = True
|
791 |
+
# st.session_state['actual_df'] = actual_df
|
792 |
+
# st.session_state['raw_df'] = raw_df
|
793 |
+
# st.session_state['default_scenario_dict'] = default_scenario_dict
|
794 |
+
# st.session_state['scenario'] = scenario
|
795 |
+
# st.session_state['channels_list'] = channel_list
|
796 |
+
# st.session_state['optimization_channels'] = {channel_name : False for channel_name in channel_list}
|
797 |
+
# st.session_state['rcs'] = rcs
|
798 |
+
# for channel in channels.values():
|
799 |
+
# if channel.name not in st.session_state:
|
800 |
+
# st.session_state[channel.name] = float(channel.actual_total_spends)
|
801 |
+
|
802 |
+
# if 'xlsx_buffer' not in st.session_state:
|
803 |
+
# st.session_state['xlsx_buffer'] = io.BytesIO()
|
804 |
+
|
805 |
+
# ## for saving scenarios
|
806 |
+
# if 'saved_scenarios' not in st.session_state:
|
807 |
+
# if Path('../saved_scenarios.pkl').exists():
|
808 |
+
# with open('../saved_scenarios.pkl','rb') as f:
|
809 |
+
# st.session_state['saved_scenarios'] = pickle.load(f)
|
810 |
+
|
811 |
+
# else:
|
812 |
+
# st.session_state['saved_scenarios'] = OrderedDict()
|
813 |
+
|
814 |
+
# if 'total_spends_change' not in st.session_state:
|
815 |
+
# st.session_state['total_spends_change'] = 0
|
816 |
+
|
817 |
+
# if 'optimization_channels' not in st.session_state:
|
818 |
+
# st.session_state['optimization_channels'] = {channel_name : False for channel_name in channel_list}
|
819 |
+
|
820 |
+
|
821 |
+
# if 'disable_download_button' not in st.session_state:
|
822 |
+
# st.session_state['disable_download_button'] = True
|
823 |
+
def create_channel_summary(scenario, target_column):
|
824 |
+
def round_off(x, round_off_decimal=0):
|
825 |
+
# round off
|
826 |
+
try:
|
827 |
+
x = float(x)
|
828 |
+
if x < 1 and x > 0:
|
829 |
+
round_off_decimal = int(np.floor(np.abs(np.log10(x)))) + max(
|
830 |
+
round_off_decimal, 1
|
831 |
+
)
|
832 |
+
x = np.round(x, round_off_decimal)
|
833 |
+
elif x < 0 and x > -1:
|
834 |
+
round_off_decimal = int(np.floor(np.abs(np.log10(np.abs(x))))) + max(
|
835 |
+
round_off_decimal, 1
|
836 |
+
)
|
837 |
+
x = np.round(x, round_off_decimal)  # x is already negative; np.round keeps the sign
|
838 |
+
else:
|
839 |
+
x = np.round(x, round_off_decimal)
|
840 |
+
return x
|
841 |
+
except:
|
842 |
+
return x
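# Illustrative behaviour of round_off (worked by hand from the logic above):
#   round_off(0.0234, 2) -> 0.023   (values between 0 and 1 keep extra significant decimals)
#   round_off(12.345)    -> 12.0    (default rounds to whole numbers)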
|
843 |
+
|
844 |
+
summary_columns = []
|
845 |
+
|
846 |
+
actual_spends_rows = []
|
847 |
+
|
848 |
+
actual_sales_rows = []
|
849 |
+
|
850 |
+
actual_roi_rows = []
|
851 |
+
|
852 |
+
for channel in scenario.channels.values():
|
853 |
+
|
854 |
+
name_mod = channel.name.replace("_", " ")
|
855 |
+
|
856 |
+
if name_mod.lower().endswith(" imp"):
|
857 |
+
name_mod = name_mod.replace("Imp", " Impressions")
|
858 |
+
|
859 |
+
summary_columns.append(name_mod)
|
860 |
+
|
861 |
+
actual_spends_rows.append(
|
862 |
+
format_numbers(float(channel.actual_total_spends * channel.conversion_rate))
|
863 |
+
)
|
864 |
+
|
865 |
+
actual_sales_rows.append(format_numbers((float(channel.actual_total_sales))))
|
866 |
+
|
867 |
+
roi = (channel.actual_total_sales) / (
|
868 |
+
channel.actual_total_spends * channel.conversion_rate
|
869 |
+
)
|
870 |
+
if roi < 0.0001:
|
871 |
+
roi = 0
|
872 |
+
|
873 |
+
actual_roi_rows.append(
|
874 |
+
decimal_formater(
|
875 |
+
str(round_off(roi, 2)),
|
876 |
+
n_decimals=2,
|
877 |
+
)
|
878 |
+
)
|
879 |
+
|
880 |
+
actual_summary_df = pd.DataFrame(
|
881 |
+
[
|
882 |
+
summary_columns,
|
883 |
+
actual_spends_rows,
|
884 |
+
actual_sales_rows,
|
885 |
+
actual_roi_rows,
|
886 |
+
]
|
887 |
+
).T
|
888 |
+
|
889 |
+
actual_summary_df.columns = ["Channel", "Spends", target_column, "ROI"]
|
890 |
+
|
891 |
+
actual_summary_df[target_column] = actual_summary_df[target_column].map(
|
892 |
+
lambda x: str(x)[1:]
|
893 |
+
)
|
894 |
+
|
895 |
+
return actual_summary_df
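# Minimal usage sketch (assumes a Scenario object already sits in session state and
# that the target metric column is "Revenue"; both are assumptions, not fixed names):
#
#   summary_df = create_channel_summary(st.session_state["scenario"], "Revenue")
#   st.dataframe(summary_df)   # columns: Channel, Spends, <target>, ROI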
|
896 |
+
|
897 |
+
|
898 |
+
# def create_channel_summary(scenario):
|
899 |
+
#
|
900 |
+
# # Provided data
|
901 |
+
# data = {
|
902 |
+
# 'Channel': ['Paid Search', 'Ga will cid baixo risco', 'Digital tactic others', 'Fb la tier 1', 'Fb la tier 2', 'Paid social others', 'Programmatic', 'Kwai', 'Indicacao', 'Infleux', 'Influencer'],
|
903 |
+
# 'Spends': ['$ 11.3K', '$ 155.2K', '$ 50.7K', '$ 125.4K', '$ 125.2K', '$ 105K', '$ 3.3M', '$ 47.5K', '$ 55.9K', '$ 632.3K', '$ 48.3K'],
|
904 |
+
# 'Revenue': ['558.0K', '3.5M', '5.2M', '3.1M', '3.1M', '2.1M', '20.8M', '1.6M', '728.4K', '22.9M', '4.8M']
|
905 |
+
# }
|
906 |
+
#
|
907 |
+
# # Create DataFrame
|
908 |
+
# df = pd.DataFrame(data)
|
909 |
+
#
|
910 |
+
# # Convert currency strings to numeric values
|
911 |
+
# df['Spends'] = df['Spends'].replace({'\$': '', 'K': '*1e3', 'M': '*1e6'}, regex=True).map(pd.eval).astype(int)
|
912 |
+
# df['Revenue'] = df['Revenue'].replace({'\$': '', 'K': '*1e3', 'M': '*1e6'}, regex=True).map(pd.eval).astype(int)
|
913 |
+
#
|
914 |
+
# # Calculate ROI
|
915 |
+
# df['ROI'] = ((df['Revenue'] - df['Spends']) / df['Spends'])
|
916 |
+
#
|
917 |
+
# # Format columns
|
918 |
+
# format_currency = lambda x: f"${x:,.1f}"
|
919 |
+
# format_roi = lambda x: f"{x:.1f}"
|
920 |
+
#
|
921 |
+
# df['Spends'] = ['$ 11.3K', '$ 155.2K', '$ 50.7K', '$ 125.4K', '$ 125.2K', '$ 105K', '$ 3.3M', '$ 47.5K', '$ 55.9K', '$ 632.3K', '$ 48.3K']
|
922 |
+
# df['Revenue'] = ['$ 536.3K', '$ 3.4M', '$ 5M', '$ 3M', '$ 3M', '$ 2M', '$ 20M', '$ 1.5M', '$ 7.1M', '$ 22M', '$ 4.6M']
|
923 |
+
# df['ROI'] = df['ROI'].apply(format_roi)
|
924 |
+
#
|
925 |
+
# return df
|
926 |
+
|
927 |
+
|
928 |
+
# @st.cache_resource()
|
929 |
+
# def create_contribution_pie(_scenario):
|
930 |
+
# colors_map = {
|
931 |
+
# col: color
|
932 |
+
# for col, color in zip(
|
933 |
+
# st.session_state["channels_list"],
|
934 |
+
# plotly.colors.n_colors(
|
935 |
+
# plotly.colors.hex_to_rgb("#BE6468"),
|
936 |
+
# plotly.colors.hex_to_rgb("#E7B8B7"),
|
937 |
+
# 20,
|
938 |
+
# ),
|
939 |
+
# )
|
940 |
+
# }
|
941 |
+
# total_contribution_fig = make_subplots(
|
942 |
+
# rows=1,
|
943 |
+
# cols=2,
|
944 |
+
# subplot_titles=["Spends", "Revenue"],
|
945 |
+
# specs=[[{"type": "pie"}, {"type": "pie"}]],
|
946 |
+
# )
|
947 |
+
# total_contribution_fig.add_trace(
|
948 |
+
# go.Pie(
|
949 |
+
# labels=[
|
950 |
+
# channel_name_formating(channel_name)
|
951 |
+
# for channel_name in st.session_state["channels_list"]
|
952 |
+
# ]
|
953 |
+
# + ["Non Media"],
|
954 |
+
# values=[
|
955 |
+
# round(
|
956 |
+
# _scenario.channels[channel_name].actual_total_spends
|
957 |
+
# * _scenario.channels[channel_name].conversion_rate,
|
958 |
+
# 1,
|
959 |
+
# )
|
960 |
+
# for channel_name in st.session_state["channels_list"]
|
961 |
+
# ]
|
962 |
+
# + [0],
|
963 |
+
# marker=dict(
|
964 |
+
# colors=[
|
965 |
+
# plotly.colors.label_rgb(colors_map[channel_name])
|
966 |
+
# for channel_name in st.session_state["channels_list"]
|
967 |
+
# ]
|
968 |
+
# + ["#F0F0F0"]
|
969 |
+
# ),
|
970 |
+
# hole=0.3,
|
971 |
+
# ),
|
972 |
+
# row=1,
|
973 |
+
# col=1,
|
974 |
+
# )
|
975 |
+
|
976 |
+
# total_contribution_fig.add_trace(
|
977 |
+
# go.Pie(
|
978 |
+
# labels=[
|
979 |
+
# channel_name_formating(channel_name)
|
980 |
+
# for channel_name in st.session_state["channels_list"]
|
981 |
+
# ]
|
982 |
+
# + ["Non Media"],
|
983 |
+
# values=[
|
984 |
+
# _scenario.channels[channel_name].actual_total_sales
|
985 |
+
# for channel_name in st.session_state["channels_list"]
|
986 |
+
# ]
|
987 |
+
# + [_scenario.correction.sum() + _scenario.constant.sum()],
|
988 |
+
# hole=0.3,
|
989 |
+
# ),
|
990 |
+
# row=1,
|
991 |
+
# col=2,
|
992 |
+
# )
|
993 |
+
|
994 |
+
# total_contribution_fig.update_traces(
|
995 |
+
# textposition="inside", texttemplate="%{percent:.1%}"
|
996 |
+
# )
|
997 |
+
# total_contribution_fig.update_layout(
|
998 |
+
# uniformtext_minsize=12,
|
999 |
+
# title="Channel contribution",
|
1000 |
+
# uniformtext_mode="hide",
|
1001 |
+
# )
|
1002 |
+
# return total_contribution_fig
|
1003 |
+
|
1004 |
+
|
1005 |
+
# @st.cache_resource()
|
1006 |
+
# def create_contribution_pie(_scenario):
|
1007 |
+
# colors = plotly.colors.qualitative.Plotly # A diverse color palette
|
1008 |
+
# colors_map = {
|
1009 |
+
# col: colors[i % len(colors)]
|
1010 |
+
# for i, col in enumerate(st.session_state["channels_list"])
|
1011 |
+
# }
|
1012 |
+
|
1013 |
+
# total_contribution_fig = make_subplots(
|
1014 |
+
# rows=1,
|
1015 |
+
# cols=2,
|
1016 |
+
# subplot_titles=["Spends", "Revenue"],
|
1017 |
+
# specs=[[{"type": "pie"}, {"type": "pie"}]],
|
1018 |
+
# )
|
1019 |
+
|
1020 |
+
# total_contribution_fig.add_trace(
|
1021 |
+
# go.Pie(
|
1022 |
+
# labels=[
|
1023 |
+
# channel_name_formating(channel_name)
|
1024 |
+
# for channel_name in st.session_state["channels_list"]
|
1025 |
+
# ]
|
1026 |
+
# + ["Non Media"],
|
1027 |
+
# values=[
|
1028 |
+
# round(
|
1029 |
+
# _scenario.channels[channel_name].actual_total_spends
|
1030 |
+
# * _scenario.channels[channel_name].conversion_rate,
|
1031 |
+
# 1,
|
1032 |
+
# )
|
1033 |
+
# for channel_name in st.session_state["channels_list"]
|
1034 |
+
# ]
|
1035 |
+
# + [0],
|
1036 |
+
# marker=dict(
|
1037 |
+
# colors=[
|
1038 |
+
# colors_map[channel_name]
|
1039 |
+
# for channel_name in st.session_state["channels_list"]
|
1040 |
+
# ]
|
1041 |
+
# + ["#F0F0F0"]
|
1042 |
+
# ),
|
1043 |
+
# hole=0.3,
|
1044 |
+
# ),
|
1045 |
+
# row=1,
|
1046 |
+
# col=1,
|
1047 |
+
# )
|
1048 |
+
|
1049 |
+
# total_contribution_fig.add_trace(
|
1050 |
+
# go.Pie(
|
1051 |
+
# labels=[
|
1052 |
+
# channel_name_formating(channel_name)
|
1053 |
+
# for channel_name in st.session_state["channels_list"]
|
1054 |
+
# ]
|
1055 |
+
# + ["Non Media"],
|
1056 |
+
# values=[
|
1057 |
+
# _scenario.channels[channel_name].actual_total_sales
|
1058 |
+
# for channel_name in st.session_state["channels_list"]
|
1059 |
+
# ]
|
1060 |
+
# + [_scenario.correction.sum() + _scenario.constant.sum()],
|
1061 |
+
# marker=dict(
|
1062 |
+
# colors=[
|
1063 |
+
# colors_map[channel_name]
|
1064 |
+
# for channel_name in st.session_state["channels_list"]
|
1065 |
+
# ]
|
1066 |
+
# + ["#F0F0F0"]
|
1067 |
+
# ),
|
1068 |
+
# hole=0.3,
|
1069 |
+
# ),
|
1070 |
+
# row=1,
|
1071 |
+
# col=2,
|
1072 |
+
# )
|
1073 |
+
|
1074 |
+
# total_contribution_fig.update_traces(
|
1075 |
+
# textposition="inside", texttemplate="%{percent:.1%}"
|
1076 |
+
# )
|
1077 |
+
# total_contribution_fig.update_layout(
|
1078 |
+
# uniformtext_minsize=12,
|
1079 |
+
# title="Channel contribution",
|
1080 |
+
# uniformtext_mode="hide",
|
1081 |
+
# )
|
1082 |
+
# return total_contribution_fig
|
1083 |
+
|
1084 |
+
|
1085 |
+
# @st.cache_resource()
|
1086 |
+
def create_contribution_pie(_scenario, target_col):
|
1087 |
+
colors = plotly.colors.qualitative.Plotly # A diverse color palette
|
1088 |
+
colors_map = {
|
1089 |
+
col: colors[i % len(colors)]
|
1090 |
+
for i, col in enumerate(st.session_state["channels_list"])
|
1091 |
+
}
|
1092 |
+
|
1093 |
+
spends_values = [
|
1094 |
+
round(
|
1095 |
+
_scenario.channels[channel_name].actual_total_spends
|
1096 |
+
* _scenario.channels[channel_name].conversion_rate,
|
1097 |
+
1,
|
1098 |
+
)
|
1099 |
+
for channel_name in st.session_state["channels_list"]
|
1100 |
+
]
|
1101 |
+
spends_values.append(0) # Adding Non Media value
|
1102 |
+
|
1103 |
+
revenue_values = [
|
1104 |
+
_scenario.channels[channel_name].actual_total_sales
|
1105 |
+
for channel_name in st.session_state["channels_list"]
|
1106 |
+
]
|
1107 |
+
revenue_values.append(
|
1108 |
+
_scenario.correction.sum() + _scenario.constant.sum()
|
1109 |
+
) # Adding Non Media value
|
1110 |
+
|
1111 |
+
total_contribution_fig = make_subplots(
|
1112 |
+
rows=1,
|
1113 |
+
cols=2,
|
1114 |
+
subplot_titles=["Spend", target_col],
|
1115 |
+
specs=[[{"type": "pie"}, {"type": "pie"}]],
|
1116 |
+
)
|
1117 |
+
|
1118 |
+
total_contribution_fig.add_trace(
|
1119 |
+
go.Pie(
|
1120 |
+
labels=[
|
1121 |
+
channel_name_formating(channel_name)
|
1122 |
+
for channel_name in st.session_state["channels_list"]
|
1123 |
+
]
|
1124 |
+
+ ["Non Media"],
|
1125 |
+
values=spends_values,
|
1126 |
+
marker=dict(
|
1127 |
+
colors=[
|
1128 |
+
colors_map[channel_name]
|
1129 |
+
for channel_name in st.session_state["channels_list"]
|
1130 |
+
]
|
1131 |
+
+ ["#F0F0F0"]
|
1132 |
+
),
|
1133 |
+
hole=0.3,
|
1134 |
+
),
|
1135 |
+
row=1,
|
1136 |
+
col=1,
|
1137 |
+
)
|
1138 |
+
|
1139 |
+
total_contribution_fig.add_trace(
|
1140 |
+
go.Pie(
|
1141 |
+
labels=[
|
1142 |
+
channel_name_formating(channel_name)
|
1143 |
+
for channel_name in st.session_state["channels_list"]
|
1144 |
+
]
|
1145 |
+
+ ["Non Media"],
|
1146 |
+
values=revenue_values,
|
1147 |
+
marker=dict(
|
1148 |
+
colors=[
|
1149 |
+
colors_map[channel_name]
|
1150 |
+
for channel_name in st.session_state["channels_list"]
|
1151 |
+
]
|
1152 |
+
+ ["#F0F0F0"]
|
1153 |
+
),
|
1154 |
+
hole=0.3,
|
1155 |
+
),
|
1156 |
+
row=1,
|
1157 |
+
col=2,
|
1158 |
+
)
|
1159 |
+
|
1160 |
+
total_contribution_fig.update_traces(
|
1161 |
+
textposition="inside", texttemplate="%{percent:.1%}"
|
1162 |
+
)
|
1163 |
+
total_contribution_fig.update_layout(
|
1164 |
+
uniformtext_minsize=12,
|
1165 |
+
title="Channel contribution",
|
1166 |
+
uniformtext_mode="hide",
|
1167 |
+
)
|
1168 |
+
return total_contribution_fig
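# Usage sketch (same session-state assumption as above): the figure is a two-pie
# subplot, spend share on the left and target-metric share on the right.
#
#   pie_fig = create_contribution_pie(st.session_state["scenario"], "Revenue")
#   st.plotly_chart(pie_fig, use_container_width=True)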
|
1169 |
+
|
1170 |
+
|
1171 |
+
# def create_contribuion_stacked_plot(scenario):
|
1172 |
+
# weekly_contribution_fig = make_subplots(rows=1, cols=2,subplot_titles=['Spends','Revenue'],specs=[[{"type": "bar"}, {"type": "bar"}]])
|
1173 |
+
# raw_df = st.session_state['raw_df']
|
1174 |
+
# df = raw_df.sort_values(by='Date')
|
1175 |
+
# x = df.Date
|
1176 |
+
# weekly_spends_data = []
|
1177 |
+
# weekly_sales_data = []
|
1178 |
+
# for channel_name in st.session_state['channels_list']:
|
1179 |
+
# weekly_spends_data.append((go.Bar(x=x,
|
1180 |
+
# y=scenario.channels[channel_name].actual_spends * scenario.channels[channel_name].conversion_rate,
|
1181 |
+
# name=channel_name_formating(channel_name),
|
1182 |
+
# hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}",
|
1183 |
+
# legendgroup=channel_name)))
|
1184 |
+
# weekly_sales_data.append((go.Bar(x=x,
|
1185 |
+
# y=scenario.channels[channel_name].actual_sales,
|
1186 |
+
# name=channel_name_formating(channel_name),
|
1187 |
+
# hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
|
1188 |
+
# legendgroup=channel_name, showlegend=False)))
|
1189 |
+
# for _d in weekly_spends_data:
|
1190 |
+
# weekly_contribution_fig.add_trace(_d, row=1, col=1)
|
1191 |
+
# for _d in weekly_sales_data:
|
1192 |
+
# weekly_contribution_fig.add_trace(_d, row=1, col=2)
|
1193 |
+
# weekly_contribution_fig.add_trace(go.Bar(x=x,
|
1194 |
+
# y=scenario.constant + scenario.correction,
|
1195 |
+
# name='Non Media',
|
1196 |
+
# hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}"), row=1, col=2)
|
1197 |
+
# weekly_contribution_fig.update_layout(barmode='stack', title='Channel contribution by week', xaxis_title='Date')
|
1198 |
+
# weekly_contribution_fig.update_xaxes(showgrid=False)
|
1199 |
+
# weekly_contribution_fig.update_yaxes(showgrid=False)
|
1200 |
+
# return weekly_contribution_fig
|
1201 |
+
|
1202 |
+
# @st.cache_resource()
|
1203 |
+
# def create_channel_spends_sales_plot(channel):
|
1204 |
+
# if channel is not None:
|
1205 |
+
# x = channel.dates
|
1206 |
+
# _spends = channel.actual_spends * channel.conversion_rate
|
1207 |
+
# _sales = channel.actual_sales
|
1208 |
+
# channel_sales_spends_fig = make_subplots(specs=[[{"secondary_y": True}]])
|
1209 |
+
# channel_sales_spends_fig.add_trace(go.Bar(x=x, y=_sales,marker_color='#c1f7dc',name='Revenue', hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}"), secondary_y = False)
|
1210 |
+
# channel_sales_spends_fig.add_trace(go.Scatter(x=x, y=_spends,line=dict(color='#005b96'),name='Spends',hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}"), secondary_y = True)
|
1211 |
+
# channel_sales_spends_fig.update_layout(xaxis_title='Date',yaxis_title='Revenue',yaxis2_title='Spends ($)',title='Channel spends and Revenue week wise')
|
1212 |
+
# channel_sales_spends_fig.update_xaxes(showgrid=False)
|
1213 |
+
# channel_sales_spends_fig.update_yaxes(showgrid=False)
|
1214 |
+
# else:
|
1215 |
+
# raw_df = st.session_state['raw_df']
|
1216 |
+
# df = raw_df.sort_values(by='Date')
|
1217 |
+
# x = df.Date
|
1218 |
+
# scenario = class_from_dict(st.session_state['default_scenario_dict'])
|
1219 |
+
# _sales = scenario.constant + scenario.correction
|
1220 |
+
# channel_sales_spends_fig = make_subplots(specs=[[{"secondary_y": True}]])
|
1221 |
+
# channel_sales_spends_fig.add_trace(go.Bar(x=x, y=_sales,marker_color='#c1f7dc',name='Revenue', hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}"), secondary_y = False)
|
1222 |
+
# # channel_sales_spends_fig.add_trace(go.Scatter(x=x, y=_spends,line=dict(color='#15C39A'),name='Spends',hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}"), secondary_y = True)
|
1223 |
+
# channel_sales_spends_fig.update_layout(xaxis_title='Date',yaxis_title='Revenue',yaxis2_title='Spends ($)',title='Channel spends and Revenue week wise')
|
1224 |
+
# channel_sales_spends_fig.update_xaxes(showgrid=False)
|
1225 |
+
# channel_sales_spends_fig.update_yaxes(showgrid=False)
|
1226 |
+
# return channel_sales_spends_fig
|
1227 |
+
|
1228 |
+
|
1229 |
+
# Define a shared color palette (module level; create_channel_spends_sales_plot below
# references color_palette, so it must exist at import time — mirrors the palette used
# inside the other plotting helpers)
color_palette = plotly.colors.qualitative.Plotly
|
1230 |
+
|
1231 |
+
|
1232 |
+
# def create_contribution_pie():
|
1233 |
+
# color_palette = ['#F3F3F0', '#5E7D7E', '#2FA1FF', '#00EDED', '#00EAE4', '#304550', '#EDEBEB', '#7FBEFD', '#003059', '#A2F3F3', '#E1D6E2', '#B6B6B6']
|
1234 |
+
# total_contribution_fig = make_subplots(rows=1, cols=2, subplot_titles=['Spends', 'Revenue'], specs=[[{"type": "pie"}, {"type": "pie"}]])
|
1235 |
+
#
|
1236 |
+
# channels_list = ['Paid Search', 'Ga will cid baixo risco', 'Digital tactic others', 'Fb la tier 1', 'Fb la tier 2', 'Paid social others', 'Programmatic', 'Kwai', 'Indicacao', 'Infleux', 'Influencer', 'Non Media']
|
1237 |
+
#
|
1238 |
+
# # Assign colors from the limited palette to channels
|
1239 |
+
# colors_map = {col: color_palette[i % len(color_palette)] for i, col in enumerate(channels_list)}
|
1240 |
+
# colors_map['Non Media'] = color_palette[5] # Assign fixed green color for 'Non Media'
|
1241 |
+
#
|
1242 |
+
# # Hardcoded values for Spends and Revenue
|
1243 |
+
# spends_values = [0.5, 3.36, 1.1, 2.7, 2.7, 2.27, 70.6, 1, 1, 13.7, 1, 0]
|
1244 |
+
# revenue_values = [1, 4, 5, 3, 3, 2, 50.8, 1.5, 0.7, 13, 0, 16]
|
1245 |
+
#
|
1246 |
+
# # Add trace for Spends pie chart
|
1247 |
+
# total_contribution_fig.add_trace(
|
1248 |
+
# go.Pie(
|
1249 |
+
# labels=[channel_name for channel_name in channels_list],
|
1250 |
+
# values=spends_values,
|
1251 |
+
# marker=dict(colors=[colors_map[channel_name] for channel_name in channels_list]),
|
1252 |
+
# hole=0.3
|
1253 |
+
# ),
|
1254 |
+
# row=1, col=1
|
1255 |
+
# )
|
1256 |
+
#
|
1257 |
+
# # Add trace for Revenue pie chart
|
1258 |
+
# total_contribution_fig.add_trace(
|
1259 |
+
# go.Pie(
|
1260 |
+
# labels=[channel_name for channel_name in channels_list],
|
1261 |
+
# values=revenue_values,
|
1262 |
+
# marker=dict(colors=[colors_map[channel_name] for channel_name in channels_list]),
|
1263 |
+
# hole=0.3
|
1264 |
+
# ),
|
1265 |
+
# row=1, col=2
|
1266 |
+
# )
|
1267 |
+
#
|
1268 |
+
# total_contribution_fig.update_traces(textposition='inside', texttemplate='%{percent:.1%}')
|
1269 |
+
# total_contribution_fig.update_layout(uniformtext_minsize=12, title='Channel contribution', uniformtext_mode='hide')
|
1270 |
+
# return total_contribution_fig
|
1271 |
+
|
1272 |
+
# @st.cache_resource()
|
1273 |
+
# def create_contribuion_stacked_plot(_scenario):
|
1274 |
+
# weekly_contribution_fig = make_subplots(
|
1275 |
+
# rows=1,
|
1276 |
+
# cols=2,
|
1277 |
+
# subplot_titles=["Spends", "Revenue"],
|
1278 |
+
# specs=[[{"type": "bar"}, {"type": "bar"}]],
|
1279 |
+
# )
|
1280 |
+
# raw_df = st.session_state["raw_df"]
|
1281 |
+
# df = raw_df.sort_values(by="Date")
|
1282 |
+
# x = df.Date
|
1283 |
+
# weekly_spends_data = []
|
1284 |
+
# weekly_sales_data = []
|
1285 |
+
|
1286 |
+
# for i, channel_name in enumerate(st.session_state["channels_list"]):
|
1287 |
+
# color = color_palette[i % len(color_palette)]
|
1288 |
+
|
1289 |
+
# weekly_spends_data.append(
|
1290 |
+
# go.Bar(
|
1291 |
+
# x=x,
|
1292 |
+
# y=_scenario.channels[channel_name].actual_spends
|
1293 |
+
# * _scenario.channels[channel_name].conversion_rate,
|
1294 |
+
# name=channel_name_formating(channel_name),
|
1295 |
+
# hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}",
|
1296 |
+
# legendgroup=channel_name,
|
1297 |
+
# marker_color=color,
|
1298 |
+
# )
|
1299 |
+
# )
|
1300 |
+
|
1301 |
+
# weekly_sales_data.append(
|
1302 |
+
# go.Bar(
|
1303 |
+
# x=x,
|
1304 |
+
# y=_scenario.channels[channel_name].actual_sales,
|
1305 |
+
# name=channel_name_formating(channel_name),
|
1306 |
+
# hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
|
1307 |
+
# legendgroup=channel_name,
|
1308 |
+
# showlegend=False,
|
1309 |
+
# marker_color=color,
|
1310 |
+
# )
|
1311 |
+
# )
|
1312 |
+
|
1313 |
+
# for _d in weekly_spends_data:
|
1314 |
+
# weekly_contribution_fig.add_trace(_d, row=1, col=1)
|
1315 |
+
# for _d in weekly_sales_data:
|
1316 |
+
# weekly_contribution_fig.add_trace(_d, row=1, col=2)
|
1317 |
+
|
1318 |
+
# weekly_contribution_fig.add_trace(
|
1319 |
+
# go.Bar(
|
1320 |
+
# x=x,
|
1321 |
+
# y=_scenario.constant + _scenario.correction,
|
1322 |
+
# name="Non Media",
|
1323 |
+
# hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
|
1324 |
+
# marker_color=color_palette[-1],
|
1325 |
+
# ),
|
1326 |
+
# row=1,
|
1327 |
+
# col=2,
|
1328 |
+
# )
|
1329 |
+
|
1330 |
+
# weekly_contribution_fig.update_layout(
|
1331 |
+
# barmode="stack",
|
1332 |
+
# title="Channel contribution by week",
|
1333 |
+
# xaxis_title="Date",
|
1334 |
+
# )
|
1335 |
+
# weekly_contribution_fig.update_xaxes(showgrid=False)
|
1336 |
+
# weekly_contribution_fig.update_yaxes(showgrid=False)
|
1337 |
+
# return weekly_contribution_fig
|
1338 |
+
|
1339 |
+
|
1340 |
+
# @st.cache_resource()
|
1341 |
+
def create_contribuion_stacked_plot(_scenario, target_col):
|
1342 |
+
color_palette = plotly.colors.qualitative.Plotly # A diverse color palette
|
1343 |
+
|
1344 |
+
weekly_contribution_fig = make_subplots(
|
1345 |
+
rows=1,
|
1346 |
+
cols=2,
|
1347 |
+
subplot_titles=["Spend", target_col],
|
1348 |
+
specs=[[{"type": "bar"}, {"type": "bar"}]],
|
1349 |
+
)
|
1350 |
+
|
1351 |
+
raw_df = st.session_state["raw_df"]
|
1352 |
+
df = raw_df.sort_values(by="Date")
|
1353 |
+
x = df.Date
|
1354 |
+
weekly_spends_data = []
|
1355 |
+
weekly_sales_data = []
|
1356 |
+
|
1357 |
+
for i, channel_name in enumerate(st.session_state["channels_list"]):
|
1358 |
+
color = color_palette[i % len(color_palette)]
|
1359 |
+
|
1360 |
+
weekly_spends_data.append(
|
1361 |
+
go.Bar(
|
1362 |
+
x=x,
|
1363 |
+
y=_scenario.channels[channel_name].actual_spends
|
1364 |
+
* _scenario.channels[channel_name].conversion_rate,
|
1365 |
+
name=channel_name_formating(channel_name),
|
1366 |
+
hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}",
|
1367 |
+
legendgroup=channel_name,
|
1368 |
+
marker_color=color,
|
1369 |
+
)
|
1370 |
+
)
|
1371 |
+
|
1372 |
+
weekly_sales_data.append(
|
1373 |
+
go.Bar(
|
1374 |
+
x=x,
|
1375 |
+
y=_scenario.channels[channel_name].actual_sales,
|
1376 |
+
name=channel_name_formating(channel_name),
|
1377 |
+
hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
|
1378 |
+
legendgroup=channel_name,
|
1379 |
+
showlegend=False,
|
1380 |
+
marker_color=color,
|
1381 |
+
)
|
1382 |
+
)
|
1383 |
+
|
1384 |
+
for _d in weekly_spends_data:
|
1385 |
+
weekly_contribution_fig.add_trace(_d, row=1, col=1)
|
1386 |
+
for _d in weekly_sales_data:
|
1387 |
+
weekly_contribution_fig.add_trace(_d, row=1, col=2)
|
1388 |
+
|
1389 |
+
weekly_contribution_fig.add_trace(
|
1390 |
+
go.Bar(
|
1391 |
+
x=x,
|
1392 |
+
y=_scenario.constant + _scenario.correction,
|
1393 |
+
name="Non Media",
|
1394 |
+
hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
|
1395 |
+
marker_color=color_palette[-1],
|
1396 |
+
),
|
1397 |
+
row=1,
|
1398 |
+
col=2,
|
1399 |
+
)
|
1400 |
+
|
1401 |
+
weekly_contribution_fig.update_layout(
|
1402 |
+
barmode="stack",
|
1403 |
+
title="Channel contribution by week",
|
1404 |
+
xaxis_title="Date",
|
1405 |
+
)
|
1406 |
+
weekly_contribution_fig.update_xaxes(showgrid=False)
|
1407 |
+
weekly_contribution_fig.update_yaxes(showgrid=False)
|
1408 |
+
return weekly_contribution_fig
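# Usage sketch: weekly stacked bars of spend and the target metric per channel.
#
#   weekly_fig = create_contribuion_stacked_plot(st.session_state["scenario"], "Revenue")
#   st.plotly_chart(weekly_fig, use_container_width=True)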
|
1409 |
+
|
1410 |
+
|
1411 |
+
def create_channel_spends_sales_plot(channel, target_col):
|
1412 |
+
if channel is not None:
|
1413 |
+
x = channel.dates
|
1414 |
+
_spends = channel.actual_spends * channel.conversion_rate
|
1415 |
+
_sales = channel.actual_sales
|
1416 |
+
channel_sales_spends_fig = make_subplots(specs=[[{"secondary_y": True}]])
|
1417 |
+
channel_sales_spends_fig.add_trace(
|
1418 |
+
go.Bar(
|
1419 |
+
x=x,
|
1420 |
+
y=_sales,
|
1421 |
+
marker_color=color_palette[
|
1422 |
+
3
|
1423 |
+
], # You can choose a color from the palette
|
1424 |
+
name=target_col,
|
1425 |
+
hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
|
1426 |
+
),
|
1427 |
+
secondary_y=False,
|
1428 |
+
)
|
1429 |
+
|
1430 |
+
channel_sales_spends_fig.add_trace(
|
1431 |
+
go.Scatter(
|
1432 |
+
x=x,
|
1433 |
+
y=_spends,
|
1434 |
+
line=dict(
|
1435 |
+
color=color_palette[2]
|
1436 |
+
), # You can choose another color from the palette
|
1437 |
+
name="Spends",
|
1438 |
+
hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}",
|
1439 |
+
),
|
1440 |
+
secondary_y=True,
|
1441 |
+
)
|
1442 |
+
|
1443 |
+
channel_sales_spends_fig.update_layout(
|
1444 |
+
xaxis_title="Date",
|
1445 |
+
yaxis_title=target_col,
|
1446 |
+
yaxis2_title="Spend ($)",
|
1447 |
+
title="Weekly Channel Spends and " + target_col,
|
1448 |
+
)
|
1449 |
+
channel_sales_spends_fig.update_xaxes(showgrid=False)
|
1450 |
+
channel_sales_spends_fig.update_yaxes(showgrid=False)
|
1451 |
+
else:
|
1452 |
+
raw_df = st.session_state["raw_df"]
|
1453 |
+
df = raw_df.sort_values(by="Date")
|
1454 |
+
x = df.Date
|
1455 |
+
scenario = class_from_dict(st.session_state["default_scenario_dict"])
|
1456 |
+
_sales = scenario.constant + scenario.correction
|
1457 |
+
channel_sales_spends_fig = make_subplots(specs=[[{"secondary_y": True}]])
|
1458 |
+
channel_sales_spends_fig.add_trace(
|
1459 |
+
go.Bar(
|
1460 |
+
x=x,
|
1461 |
+
y=_sales,
|
1462 |
+
marker_color=color_palette[
|
1463 |
+
0
|
1464 |
+
], # You can choose a color from the palette
|
1465 |
+
name="Revenue",
|
1466 |
+
hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
|
1467 |
+
),
|
1468 |
+
secondary_y=False,
|
1469 |
+
)
|
1470 |
+
|
1471 |
+
channel_sales_spends_fig.update_layout(
|
1472 |
+
xaxis_title="Date",
|
1473 |
+
yaxis_title="Revenue",
|
1474 |
+
yaxis2_title="Spend ($)",
|
1475 |
+
title="Channel spends and Revenue week-wise",
|
1476 |
+
)
|
1477 |
+
channel_sales_spends_fig.update_xaxes(showgrid=False)
|
1478 |
+
channel_sales_spends_fig.update_yaxes(showgrid=False)
|
1479 |
+
|
1480 |
+
return channel_sales_spends_fig
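# Usage sketch: pass a single Channel object for the per-channel view, or None to fall
# back to the non-media (constant + correction) series of the default scenario.
# "paid_search_clicks" below is a hypothetical channel key, not a real column name.
#
#   ch = st.session_state["scenario"].channels.get("paid_search_clicks")
#   st.plotly_chart(create_channel_spends_sales_plot(ch, "Revenue"), use_container_width=True)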
|
1481 |
+
|
1482 |
+
|
1483 |
+
def format_numbers(value, n_decimals=1, include_indicator=True):
|
1484 |
+
if include_indicator:
|
1485 |
+
return f"{CURRENCY_INDICATOR} {numerize(value,n_decimals)}"
|
1486 |
+
else:
|
1487 |
+
return f"{numerize(value,n_decimals)}"
|
1488 |
+
|
1489 |
+
|
1490 |
+
def decimal_formater(num_string, n_decimals=1):
|
1491 |
+
parts = num_string.split(".")
|
1492 |
+
if len(parts) == 1:
|
1493 |
+
return num_string + "." + "0" * n_decimals
|
1494 |
+
else:
|
1495 |
+
to_be_padded = n_decimals - len(parts[-1])
|
1496 |
+
if to_be_padded > 0:
|
1497 |
+
return num_string + "0" * to_be_padded
|
1498 |
+
else:
|
1499 |
+
return num_string
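# Examples (pure string padding; no rounding is performed):
#   decimal_formater("3.5", n_decimals=2)   -> "3.50"
#   decimal_formater("3", n_decimals=2)     -> "3.00"
#   decimal_formater("3.456", n_decimals=2) -> "3.456"  (already longer, left untouched)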
|
1500 |
+
|
1501 |
+
|
1502 |
+
def channel_name_formating(channel_name):
|
1503 |
+
name_mod = channel_name.replace("_", " ")
|
1504 |
+
if name_mod.lower().endswith(" imp"):
|
1505 |
+
name_mod = name_mod.replace("Imp", "Spend")
|
1506 |
+
elif name_mod.lower().endswith(" clicks"):
|
1507 |
+
name_mod = name_mod.replace("Clicks", "Spend")
|
1508 |
+
return name_mod
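# Examples:
#   channel_name_formating("fb_la_tier_1_Imp")   -> "fb la tier 1 Spend"
#   channel_name_formating("paid_search_Clicks") -> "paid search Spend"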
|
1509 |
+
|
1510 |
+
|
1511 |
+
def send_email(email, message):
|
1512 |
+
s = smtplib.SMTP("smtp.gmail.com", 587)
|
1513 |
+
s.starttls()
|
1514 |
+
s.login("[email protected]", "jgydhpfusuremcol")
|
1515 |
+
s.sendmail("[email protected]", email, message)
|
1516 |
+
s.quit()
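# A safer variant would read the SMTP credentials from the environment instead of
# hard-coding them. Minimal sketch only: SMTP_USER / SMTP_PASSWORD are assumed
# environment variables and send_email_env is a hypothetical helper, not part of the app.
def send_email_env(to_addr, message):
    import os  # local import so the sketch stays self-contained

    user = os.environ["SMTP_USER"]          # assumed env var
    password = os.environ["SMTP_PASSWORD"]  # assumed env var
    with smtplib.SMTP("smtp.gmail.com", 587) as server:
        server.starttls()
        server.login(user, password)
        server.sendmail(user, to_addr, message)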
|
1517 |
+
|
1518 |
+
|
1519 |
+
# if __name__ == "__main__":
|
1520 |
+
# initialize_data()
|