"""
MediaMixOptimization / pages /5_AI Model_Tuning.py

MMO Build Sprint 3
date :
changes : capability to tune MixedLM as well as simple LR on the same page
"""
import os
import streamlit as st
import pandas as pd
from data_analysis import format_numbers
import pickle
import statsmodels.api as sm
import re
from sklearn.preprocessing import MaxAbsScaler
import matplotlib.pyplot as plt
from statsmodels.stats.outliers_influence import variance_inflation_factor
import statsmodels.formula.api as smf
from data_prep import *
import sqlite3
from utilities import (
set_header,
load_local_css,
update_db,
project_selection,
retrieve_pkl_object,
)
import numpy as np
from post_gres_cred import db_cred
from constants import (
NUM_FLAG_COLS_TO_DISPLAY,
HALF_YEAR_THRESHOLD,
FULL_YEAR_THRESHOLD,
TREND_MIN,
ANNUAL_FREQUENCY,
QTR_FREQUENCY_FACTOR,
HALF_YEARLY_FREQUENCY_FACTOR,
)
from log_application import log_message
import sys, traceback
schema = db_cred["schema"]
st.set_option("deprecation.showPyplotGlobalUse", False)
st.set_page_config(
page_title="AI Model Tuning",
page_icon=":shark:",
layout="wide",
initial_sidebar_state="collapsed",
)
load_local_css("styles.css")
set_header()
# Define functions
# Get random effect from MixedLM Model
def get_random_effects(media_data, panel_col, _mdf):
# create an empty dataframe
random_eff_df = pd.DataFrame(columns=[panel_col, "random_effect"])
# Iterate over all panel values and add to dataframe
for i, market in enumerate(media_data[panel_col].unique()):
intercept = _mdf.random_effects[market].values[0]
random_eff_df.loc[i, "random_effect"] = intercept
random_eff_df.loc[i, panel_col] = market
return random_eff_df
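# Illustrative output (panel names and values below are hypothetical): one row per
# panel with its estimated random intercept, e.g.
#     panel   random_effect
#     east    12.4
#     west    -3.1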
# Predict on df using MixedLM model
def mdf_predict(X_df, mdf, random_eff_df):
# Create a copy of the input df and predict with the MixedLM model (the fixed-effect part)
X = X_df.copy()
X["fixed_effect"] = mdf.predict(X)
# Merge random effects
X = pd.merge(X, random_eff_df, on=panel_col, how="left")
# Get final predictions by adding random effect to fixed effect
X["pred"] = X["fixed_effect"] + X["random_effect"]
# Drop intermediate columns
X.drop(columns=["fixed_effect", "random_effect"], inplace=True)
return X["pred"]
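# Note: mdf_predict relies on the module-level `panel_col` defined later in the script;
# the final prediction is simply the fixed-effect prediction from the MixedLM formula
# plus the panel's random intercept merged in from random_eff_df.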
def format_display(inp):
# Format display titles
return inp.title().replace("_", " ").strip()
if "username" not in st.session_state:
st.session_state["username"] = None
if "project_name" not in st.session_state:
st.session_state["project_name"] = None
if "project_dct" not in st.session_state:
project_selection()
st.stop()
if "Flags" not in st.session_state:
st.session_state["Flags"] = {}
try:
# Check Authentications
if "username" in st.session_state and st.session_state["username"] is not None:
if (
retrieve_pkl_object(
st.session_state["project_number"], "Model_Build", "best_models", schema
)
is None
): # db
st.error("Please save a model before tuning")
log_message(
"warning",
"No models saved",
"Model Tuning",
)
st.stop()
# Read previous progress (persistence)
if (
"session_state_saved"
in st.session_state["project_dct"]["model_build"].keys()
):
for key in [
"Model",
"date",
"saved_model_names",
"media_data",
"X_test_spends",
"spends_data",
]:
if key not in st.session_state:
st.session_state[key] = st.session_state["project_dct"][
"model_build"
]["session_state_saved"][key]
st.session_state["bin_dict"] = st.session_state["project_dct"][
"model_build"
]["session_state_saved"]["bin_dict"]
if (
"used_response_metrics" not in st.session_state
or st.session_state["used_response_metrics"] == []
):
st.session_state["used_response_metrics"] = st.session_state[
"project_dct"
]["model_build"]["session_state_saved"]["used_response_metrics"]
else:
st.error("Please load a session with a built model")
log_message(
"error",
"Session state saved not found in Project Dictionary",
"Model Tuning",
)
st.stop()
for key in ["select_all_flags_check", "selected_flags", "sel_model"]:
if key not in st.session_state["project_dct"]["model_tuning"].keys():
st.session_state["project_dct"]["model_tuning"][key] = {}
# is_panel = st.session_state['is_panel']
# panel_col = 'markets' # set the panel column
date_col = "date"
# set the panel column
panel_col = "panel"
is_panel = st.session_state["media_data"][panel_col].nunique() > 1
if "Model_Tuned" not in st.session_state:
st.session_state["Model_Tuned"] = {}
cols1 = st.columns([2, 1])
with cols1[0]:
st.markdown(f"**Welcome {st.session_state['username']}**")
with cols1[1]:
st.markdown(f"**Current Project: {st.session_state['project_name']}**")
st.title("AI Model Tuning")
# Flag indicating that no tuned model exists yet
if "is_tuned_model" not in st.session_state:
st.session_state["is_tuned_model"] = {}
# Read all saved models
model_dict = retrieve_pkl_object(
st.session_state["project_number"], "Model_Build", "best_models", schema
)
saved_models = model_dict.keys()
# Get list of response metrics
st.session_state["used_response_metrics"] = list(
set([model.split("__")[1] for model in saved_models])
)
# Select previously selected response_metric (persistence)
default_target_idx = (
st.session_state["project_dct"]["model_tuning"].get("sel_target_col", None)
if st.session_state["project_dct"]["model_tuning"].get(
"sel_target_col", None
)
is not None
else st.session_state["used_response_metrics"][0]
)
# Dropdown to select response metric
sel_target_col = st.selectbox(
"Select the response metric",
st.session_state["used_response_metrics"],
index=st.session_state["used_response_metrics"].index(default_target_idx),
format_func=format_display,
)
# Format selected response metrics (target col)
target_col = (
sel_target_col.lower()
.replace(" ", "_")
.replace("-", "")
.replace(":", "")
.replace("__", "_")
)
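# Worked example of the formatting above (metric name is hypothetical):
# "Total Approved Accounts - Revenue" -> "total_approved_accounts_revenue"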
st.session_state["project_dct"]["model_tuning"][
"sel_target_col"
] = sel_target_col
# Only show models saved for the selected response metric (target_col)
required_saved_models = [
m.split("__")[0] for m in saved_models if m.split("__")[1] == target_col
]
# Get previously selected model if available (persistence)
default_model_idx = st.session_state["project_dct"]["model_tuning"][
"sel_model"
].get(sel_target_col, required_saved_models[0])
sel_model = st.selectbox(
"Select the model to tune",
required_saved_models,
index=required_saved_models.index(default_model_idx),
)
st.session_state["project_dct"]["model_tuning"]["sel_model"][
sel_target_col
] = default_model_idx
sel_model_dict = model_dict[
sel_model + "__" + target_col
] # get the model obj of the selected model
X_train = sel_model_dict["X_train"]
X_test = sel_model_dict["X_test"]
y_train = sel_model_dict["y_train"]
y_test = sel_model_dict["y_test"]
df = st.session_state["media_data"]
st.markdown("### Event Flags")
st.markdown("Helps in quantifying the impact of specific occurrences of events")
try:
# Dropdown to add event flags
with st.expander("Apply Event Flags"):
model = sel_model_dict["Model_object"]
date = st.session_state["date"]
date = pd.to_datetime(date)
X_train = sel_model_dict["X_train"]
features_set = sel_model_dict["feature_set"]
col = st.columns(3)
# Get date range
min_date = min(date).date()
max_date = max(date).date()
# Get previously selected start and end date of flag (persistence)
start_date_default = (
st.session_state["project_dct"]["model_tuning"].get(
"start_date_default"
)
if st.session_state["project_dct"]["model_tuning"].get(
"start_date_default"
)
is not None
else min_date
)
# Clamp the persisted start date to the data's date range
start_date_default = (
start_date_default
if min_date < start_date_default < max_date
else min_date
)
end_date_default = (
st.session_state["project_dct"]["model_tuning"].get(
"end_date_default"
)
if st.session_state["project_dct"]["model_tuning"].get(
"end_date_default"
)
is not None
else max_date
)
# Clamp the persisted end date to the data's date range
end_date_default = (
end_date_default
if min_date < end_date_default < max_date
else max_date
)
# Flag start and end date input boxes
with col[0]:
start_date = st.date_input(
"Select Start Date",
start_date_default,
min_value=min_date,
max_value=max_date,
)
if (start_date < min_date) or (start_date > max_date):
st.error(
"Please select dates in the range of the dates in the data"
)
st.stop()
with col[1]:
# Check if end date default > selected start date
end_date_default = (
end_date_default
if pd.Timestamp(end_date_default) >= pd.Timestamp(start_date)
else start_date
)
end_date = st.date_input(
"Select End Date",
end_date_default,
min_value=max(
pd.to_datetime(min_date), pd.to_datetime(start_date)
),
max_value=pd.to_datetime(max_date),
)
if (
(start_date < min_date)
or (end_date < min_date)
or (start_date > max_date)
or (end_date > max_date)
):
st.error(
"Please select dates in the range of the dates in the data"
)
st.stop()
if end_date < start_date:
st.error("Please select end date after start date")
st.stop()
with col[2]:
# Get default value of repeat check box (persistence)
repeat_default = (
st.session_state["project_dct"]["model_tuning"].get(
"repeat_default"
)
if st.session_state["project_dct"]["model_tuning"].get(
"repeat_default"
)
is not None
else "No"
)
repeat_default_idx = 0 if repeat_default.lower() == "yes" else 1
repeat = st.selectbox(
"Repeat Annually", ["Yes", "No"], index=repeat_default_idx
)
# Update selected values to session dictionary (persistence)
st.session_state["project_dct"]["model_tuning"][
"start_date_default"
] = start_date
st.session_state["project_dct"]["model_tuning"][
"end_date_default"
] = end_date
st.session_state["project_dct"]["model_tuning"][
"repeat_default"
] = repeat
repeat = repeat == "Yes"  # convert the dropdown choice to a boolean
if "flags" in st.session_state["project_dct"]["model_tuning"].keys():
st.session_state["Flags"] = st.session_state["project_dct"][
"model_tuning"
]["flags"]
if is_panel:
# Create flag on Train
met, line_values, fig_flag = plot_actual_vs_predicted(
X_train[date_col],
y_train,
model.fittedvalues,
model,
target_column=sel_target_col,
flag=(start_date, end_date),
repeat_all_years=repeat,
is_panel=True,
)
st.plotly_chart(fig_flag, use_container_width=True)
# create flag on test
met, test_line_values, fig_flag = plot_actual_vs_predicted(
X_test[date_col],
y_test,
sel_model_dict["pred_test"],
model,
target_column=sel_target_col,
flag=(start_date, end_date),
repeat_all_years=repeat,
is_panel=True,
)
else:
pred_train = model.predict(X_train[features_set])
# Create flag on Train
met, line_values, fig_flag = plot_actual_vs_predicted(
X_train[date_col],
y_train,
pred_train,
model,
flag=(start_date, end_date),
repeat_all_years=repeat,
is_panel=False,
)
st.plotly_chart(fig_flag, use_container_width=True)
# create flag on test
pred_test = model.predict(X_test[features_set])
met, test_line_values, fig_flag = plot_actual_vs_predicted(
X_test[date_col],
y_test,
pred_test,
model,
flag=(start_date, end_date),
repeat_all_years=repeat,
is_panel=False,
)
flag_name = "f1_flag"
flag_name = st.text_input("Enter Flag Name")
# add selected target col to flag name
# Save the flag name, flag train values, flag test values to session state
if st.button("Save flag"):
st.session_state["Flags"][flag_name + "_flag__" + target_col] = {}
st.session_state["Flags"][flag_name + "_flag__" + target_col][
"train"
] = line_values
st.session_state["Flags"][flag_name + "_flag__" + target_col][
"test"
] = test_line_values
st.success(f'{flag_name + "_flag__" + target_col} stored')
st.session_state["project_dct"]["model_tuning"]["flags"] = (
st.session_state["Flags"]
)
# Only show flags created for the particular target col
target_model_flags = [
f.split("__")[0]
for f in st.session_state["Flags"].keys()
if f.split("__")[1] == target_col
]
options = list(target_model_flags)
num_rows = -(-len(options) // NUM_FLAG_COLS_TO_DISPLAY)
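# -(-n // k) is integer ceiling division: the number of checkbox rows needed to show
# all flags at NUM_FLAG_COLS_TO_DISPLAY per row.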
tick = False
# Select all flags checkbox
if st.checkbox(
"Select all",
value=st.session_state["project_dct"]["model_tuning"][
"select_all_flags_check"
].get(sel_target_col, False),
):
tick = True
st.session_state["project_dct"]["model_tuning"][
"select_all_flags_check"
][sel_target_col] = True
else:
st.session_state["project_dct"]["model_tuning"][
"select_all_flags_check"
][sel_target_col] = False
# Get previous flag selection (persistence)
selection_defaults = st.session_state["project_dct"]["model_tuning"][
"selected_flags"
].get(sel_target_col, [])
selected_options = list(selection_defaults)  # copy so the persisted defaults aren't mutated in place
# create a checkbox for each available flag for selected response metric
for row in range(num_rows):
cols = st.columns(NUM_FLAG_COLS_TO_DISPLAY)
for col in cols:
if options:
option = options.pop(0)
option_default = option in selection_defaults
selected = col.checkbox(option, value=(tick or option_default))
if selected:
selected_options.append(option)
else:
if option in selected_options:
selected_options.remove(option)
selected_options = list(set(selected_options))
# Check if flag values match Data
# This is necessary because different models can have different train/test dates
remove_flags = []
for opt in selected_options:
train_match = len(
st.session_state["Flags"][opt + "__" + target_col]["train"]
) == len(X_train[date_col])
test_match = len(
st.session_state["Flags"][opt + "__" + target_col]["test"]
) == len(X_test[date_col])
if not train_match:
st.warning(f"Flag {opt} can not be used due to train date mismatch")
# selected_options.remove(opt)
remove_flags.append(opt)
if not test_match:
st.warning(f"Flag {opt} can not be used due to test date mismatch")
# selected_options.remove(opt)
remove_flags.append(opt)
if remove_flags:
selected_options = list(set(selected_options) - set(remove_flags))
st.session_state["project_dct"]["model_tuning"]["selected_flags"][
sel_target_col
] = selected_options
except:
# Capture the error details
exc_type, exc_value, exc_traceback = sys.exc_info()
error_message = "".join(
traceback.format_exception(exc_type, exc_value, exc_traceback)
)
log_message(
"error", f"Error while creating flags: {error_message}", "Model Tuning"
)
st.warning("An error occured, please try again", icon="⚠️")
try:
st.markdown("### Trend and Seasonality Calibration")
parameters = st.columns(3)
# Trend checkbox
with parameters[0]:
Trend = st.checkbox(
"**Trend**",
value=st.session_state["project_dct"]["model_tuning"].get(
"trend_check", False
),
)
st.markdown(
"Helps account for long-term trends or seasonality that could influence advertising effectiveness"
)
# Day of Week (week number) checkbox
with parameters[1]:
day_of_week = st.checkbox(
"**Day of Week**",
value=st.session_state["project_dct"]["model_tuning"].get(
"week_num_check", False
),
)
st.markdown(
"Assists in detecting and incorporating weekly patterns or seasonality"
)
# Sine and cosine Waves checkbox
with parameters[2]:
sine_cosine = st.checkbox(
"**Sine and Cosine Waves**",
value=st.session_state["project_dct"]["model_tuning"].get(
"sine_cosine_check", False
),
)
st.markdown(
"Helps in capturing long term cyclical patterns or seasonality in the data"
)
if sine_cosine:
# Drop down to select Frequency of waves
xtrain_time_period_months = (
X_train[date_col].max() - X_train[date_col].min()
).days / 30
# If we have 6 months of data, only quarter frequency is possible
if xtrain_time_period_months <= HALF_YEAR_THRESHOLD:
available_frequencies = ["Quarter"]
# If we have less than 12 months of data, we have quarter and semi-annual frequencies
elif xtrain_time_period_months < FULL_YEAR_THRESHOLD:
available_frequencies = ["Quarter", "Semi-Annual"]
# If we have 12 months of data or more, we have quarter, semi-annual and annual frequencies
elif xtrain_time_period_months >= FULL_YEAR_THRESHOLD:
available_frequencies = ["Quarter", "Semi-Annual", "Annual"]
wave_freq = st.selectbox("Select Frequency", available_frequencies)
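# HALF_YEAR_THRESHOLD / FULL_YEAR_THRESHOLD come from constants.py; per the comments
# above they correspond to roughly 6 and 12 months of training data, gating which
# wave frequencies are offered.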
except:
# Capture the error details
exc_type, exc_value, exc_traceback = sys.exc_info()
error_message = "".join(
traceback.format_exception(exc_type, exc_value, exc_traceback)
)
log_message(
"error",
f"Error while selecting tuning parameters: {error_message}",
"Model Tuning",
)
st.warning("An error occured, please try again", icon="⚠️")
try:
# Build tuned model
if st.button(
"Build model with Selected Parameters and Flags",
key="build_tuned_model",
use_container_width=True,
):
new_features = features_set
st.header("2.1 Results Summary")
ss = MaxAbsScaler()
if is_panel:
X_train_tuned = X_train[features_set].copy()
X_train_tuned[target_col] = X_train[target_col]
X_train_tuned[date_col] = X_train[date_col]
X_train_tuned[panel_col] = X_train[panel_col]
X_test_tuned = X_test[features_set].copy()
X_test_tuned[target_col] = X_test[target_col]
X_test_tuned[date_col] = X_test[date_col]
X_test_tuned[panel_col] = X_test[panel_col]
else:
X_train_tuned = X_train[features_set].copy()
X_test_tuned = X_test[features_set].copy()
for flag in selected_options:
# Get the flag values of train and test and add to the data
X_train_tuned[flag] = st.session_state["Flags"][
flag + "__" + target_col
]["train"]
X_test_tuned[flag] = st.session_state["Flags"][
flag + "__" + target_col
]["test"]
if Trend:
st.session_state["project_dct"]["model_tuning"][
"trend_check"
] = True
# Group by panel and calculate the trend of each panel separately; add Trend to the new feature set
if is_panel:
newdata = pd.DataFrame()
panel_wise_end_point_train = {}
for panel, groupdf in X_train_tuned.groupby(panel_col):
groupdf.sort_values(date_col, inplace=True)
groupdf["Trend"] = np.arange(
TREND_MIN, len(groupdf) + TREND_MIN, 1
) # Trend is a straight line with starting point as TREND_MIN
newdata = pd.concat([newdata, groupdf])
panel_wise_end_point_train[panel] = len(groupdf) + TREND_MIN
X_train_tuned = newdata.copy()
test_newdata = pd.DataFrame()
for panel, test_groupdf in X_test_tuned.groupby(panel_col):
test_groupdf.sort_values(date_col, inplace=True)
start = panel_wise_end_point_train[panel]
end = start + len(test_groupdf)
test_groupdf["Trend"] = np.arange(start, end, 1)
test_newdata = pd.concat([test_newdata, test_groupdf])
X_test_tuned = test_newdata.copy()
new_features = new_features + ["Trend"]
else:
X_train_tuned["Trend"] = np.arange(
TREND_MIN, len(X_train_tuned) + TREND_MIN, 1
) # Trend is a straight line with starting point as TREND_MIN
X_test_tuned["Trend"] = np.arange(
len(X_train_tuned) + TREND_MIN,
len(X_train_tuned) + len(X_test_tuned) + TREND_MIN,
1,
)
new_features = new_features + ["Trend"]
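# In both branches the test-set Trend continues from where the train Trend ends,
# so the feature stays continuous across the train/test split.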
else:
st.session_state["project_dct"]["model_tuning"][
"trend_check"
] = False # persistence
# Add day of week (Week_num) to test & train
if day_of_week:
st.session_state["project_dct"]["model_tuning"][
"week_num_check"
] = True
if is_panel:
X_train_tuned[date_col] = pd.to_datetime(
X_train_tuned[date_col]
)
X_train_tuned["day_of_week"] = X_train_tuned[
date_col
].dt.day_of_week # Day of week
# If all dates share the same day-of-week value, this feature can't be used
if X_train_tuned["day_of_week"].nunique() == 1:
st.error(
"All dates in the data fall on the same weekday, so the Day of Week feature can't be used."
)
else:
X_test_tuned[date_col] = pd.to_datetime(
X_test_tuned[date_col]
)
X_test_tuned["day_of_week"] = X_test_tuned[
date_col
].dt.day_of_week # Day of week
new_features = new_features + ["day_of_week"]
else:
date = pd.to_datetime(date.values)
X_train_tuned["day_of_week"] = pd.to_datetime(
X_train[date_col]
).dt.day_of_week # Day of week
X_test_tuned["day_of_week"] = pd.to_datetime(
X_test[date_col]
).dt.day_of_week # Day of week
# If all dates share the same day-of-week value, this feature can't be used
if X_train_tuned["day_of_week"].nunique() == 1:
st.error(
"All dates in the data fall on the same weekday, so the Day of Week feature can't be used."
)
else:
new_features = new_features + ["day_of_week"]
else:
st.session_state["project_dct"]["model_tuning"][
"week_num_check"
] = False
# create sine and cosine wave and add to data
if sine_cosine:
st.session_state["project_dct"]["model_tuning"][
"sine_cosine_check"
] = True
frequency = ANNUAL_FREQUENCY # Annual Frequency
if wave_freq == "Quarter":
frequency = frequency * QTR_FREQUENCY_FACTOR
elif wave_freq == "Semi-Annual":
frequency = frequency * HALF_YEARLY_FREQUENCY_FACTOR
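# ANNUAL_FREQUENCY and the QTR/HALF-YEARLY factors live in constants.py; the
# assumption is that quarterly and semi-annual waves scale the annual frequency by
# fixed multipliers (presumably 4x and 2x), but the exact values are defined there.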
# create panel wise sine cosine waves in xtrain tuned. add to new feature set
if is_panel:
new_features = new_features + ["sine_wave", "cosine_wave"]
newdata = pd.DataFrame()
newdata_test = pd.DataFrame()
groups = X_train_tuned.groupby(panel_col)
train_panel_wise_end_point = {}
for panel, groupdf in groups:
num_samples = len(groupdf)
train_panel_wise_end_point[panel] = num_samples
days_since_start = np.arange(num_samples)
sine_wave = np.sin(frequency * days_since_start)
cosine_wave = np.cos(frequency * days_since_start)
sine_cosine_df = pd.DataFrame(
{"sine_wave": sine_wave, "cosine_wave": cosine_wave}
)
assert len(sine_cosine_df) == len(groupdf)
groupdf["sine_wave"] = sine_wave
groupdf["cosine_wave"] = cosine_wave
newdata = pd.concat([newdata, groupdf])
X_train_tuned = newdata.copy()
test_groups = X_test_tuned.groupby(panel_col)
for panel, test_groupdf in test_groups:
num_samples = len(test_groupdf)
start = train_panel_wise_end_point[panel]
days_since_start = np.arange(start, start + num_samples, 1)
# print("##", panel, num_samples, start, len(np.arange(start, start+num_samples, 1)))
sine_wave = np.sin(frequency * days_since_start)
cosine_wave = np.cos(frequency * days_since_start)
sine_cosine_df = pd.DataFrame(
{"sine_wave": sine_wave, "cosine_wave": cosine_wave}
)
assert len(sine_cosine_df) == len(test_groupdf)
# groupdf = pd.concat([groupdf, sine_cosine_df], axis=1)
test_groupdf["sine_wave"] = sine_wave
test_groupdf["cosine_wave"] = cosine_wave
newdata_test = pd.concat([newdata_test, test_groupdf])
X_test_tuned = newdata_test.copy()
else:
new_features = new_features + ["sine_wave", "cosine_wave"]
num_samples = len(X_train_tuned)
days_since_start = np.arange(num_samples)
sine_wave = np.sin(frequency * days_since_start)
cosine_wave = np.cos(frequency * days_since_start)
sine_cosine_df = pd.DataFrame(
{"sine_wave": sine_wave, "cosine_wave": cosine_wave}
)
# Concatenate the sine and cosine waves with the scaled X DataFrame
X_train_tuned = pd.concat(
[X_train_tuned, sine_cosine_df], axis=1
)
test_num_samples = len(X_test_tuned)
start = num_samples
days_since_start = np.arange(start, start + test_num_samples, 1)
sine_wave = np.sin(frequency * days_since_start)
cosine_wave = np.cos(frequency * days_since_start)
sine_cosine_df = pd.DataFrame(
{"sine_wave": sine_wave, "cosine_wave": cosine_wave}
)
# Concatenate the sine and cosine waves with the scaled X DataFrame
X_test_tuned = pd.concat([X_test_tuned, sine_cosine_df], axis=1)
else:
st.session_state["project_dct"]["model_tuning"][
"sine_cosine_check"
] = False
# Build model
# Get list of parameters added and scale
# previous features are scaled already during model build
added_params = list(set(new_features) - set(features_set))
if len(added_params) > 0:
concat_df = pd.concat([X_train_tuned, X_test_tuned]).reset_index(
drop=True
)
if is_panel:
train_max_date = X_train_tuned[date_col].max()
# concat_df = concat_df.reset_index(drop=True)
# concat_df=concat_df[added_params]
train_idx = X_train_tuned.index[-1]
concat_df[added_params] = ss.fit_transform(concat_df[added_params])
# added_params_df = pd.DataFrame(added_params_df)
# added_params_df.columns = added_params
if is_panel:
X_train_tuned[added_params] = concat_df[
concat_df[date_col] <= train_max_date
][added_params].reset_index(drop=True)
X_test_tuned[added_params] = concat_df[
concat_df[date_col] > train_max_date
][added_params].reset_index(drop=True)
else:
added_params_df = concat_df[added_params]
X_train_tuned[added_params] = added_params_df[: train_idx + 1]
X_test_tuned[added_params] = added_params_df.loc[
train_idx + 1 :
].reset_index(drop=True)
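# Note: the MaxAbsScaler is fit on train+test concatenated, then split back by the
# train end date (panel) or row index (non-panel). Only the newly added
# trend/seasonality columns are scaled, since the original features were already
# scaled during model build.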
# Add flags (flags are 0/1 only, so no need to scale them)
if selected_options:
new_features = new_features + selected_options
# Build Mixed LM model for panel level data
if is_panel:
# sort_values is not in-place; assign the result back so the ordering takes effect
X_train_tuned = X_train_tuned.sort_values(
[date_col, panel_col]
).reset_index(drop=True)
new_features = list(set(new_features))
inp_vars_str = " + ".join(new_features)
md_str = target_col + " ~ " + inp_vars_str
md_tuned = smf.mixedlm(
md_str,
data=X_train_tuned[[target_col] + new_features],
groups=X_train_tuned[panel_col],
)
model_tuned = md_tuned.fit()
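# Illustrative formula string (feature names are hypothetical):
#   "revenue ~ tv_spend + paid_search + Trend + sine_wave + cosine_wave + my_event_flag"
# fitted with a random intercept per panel via groups=X_train_tuned[panel_col].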
# plot actual vs predicted for original model and tuned model
metrics_table, line, actual_vs_predicted_plot = (
plot_actual_vs_predicted(
X_train[date_col],
y_train,
model.fittedvalues,
model,
target_column=sel_target_col,
is_panel=True,
)
)
metrics_table_tuned, line, actual_vs_predicted_plot_tuned = (
plot_actual_vs_predicted(
X_train_tuned[date_col],
X_train_tuned[target_col],
model_tuned.fittedvalues,
model_tuned,
target_column=sel_target_col,
is_panel=True,
)
)
# Build OLS model for non-panel data
else:
new_features = list(set(new_features))
model_tuned = sm.OLS(y_train, X_train_tuned[new_features]).fit()
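# Note: no explicit intercept is added here (sm.add_constant is not used), which is
# assumed to mirror how the original model was built.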
metrics_table, line, actual_vs_predicted_plot = (
plot_actual_vs_predicted(
X_train[date_col],
y_train,
model.predict(X_train[features_set]),
model,
target_column=sel_target_col,
)
)
metrics_table_tuned, line, actual_vs_predicted_plot_tuned = (
plot_actual_vs_predicted(
X_train[date_col],
y_train,
model_tuned.predict(X_train_tuned[new_features]),
model_tuned,
target_column=sel_target_col,
)
)
# # ----------------------------------- TESTING -----------------------------------
#
# Plot Sine & cosine wave to test
# sine_cosine_plot = plot_actual_vs_predicted(
# X_train[date_col],
# y_train,
# X_train_tuned['sine_wave'],
# model_tuned,
# target_column=sel_target_col,
# is_panel=True,
# )
# st.plotly_chart(sine_cosine_plot, use_container_width=True)
# # Plot Trend line to test
# trend_plot = plot_tuned_params(
# X_train[date_col],
# y_train,
# X_train_tuned['Trend'],
# model_tuned,
# target_column=sel_target_col,
# is_panel=True,
# )
# st.plotly_chart(trend_plot, use_container_width=True)
#
# # Plot week number to test
# week_num_plot = plot_tuned_params(
# X_train[date_col],
# y_train,
# X_train_tuned['day_of_week'],
# model_tuned,
# target_column=sel_target_col,
# is_panel=True,
# )
# st.plotly_chart(week_num_plot, use_container_width=True)
# Get model metrics from metric table & display them
mape = np.round(metrics_table.iloc[0, 1], 2)
r2 = np.round(metrics_table.iloc[1, 1], 2)
adjr2 = np.round(metrics_table.iloc[2, 1], 2)
mape_tuned = np.round(metrics_table_tuned.iloc[0, 1], 2)
r2_tuned = np.round(metrics_table_tuned.iloc[1, 1], 2)
adjr2_tuned = np.round(metrics_table_tuned.iloc[2, 1], 2)
parameters_ = st.columns(3)
with parameters_[0]:
st.metric("R-squared", r2_tuned, np.round(r2_tuned - r2, 2))
with parameters_[1]:
st.metric(
"Adj. R-squared", adjr2_tuned, np.round(adjr2_tuned - adjr2, 2)
)
with parameters_[2]:
st.metric(
"MAPE", mape_tuned, np.round(mape_tuned - mape, 2), "inverse"
)
st.write(model_tuned.summary())
X_train_tuned[date_col] = X_train[date_col]
X_train_tuned[target_col] = y_train
X_test_tuned[date_col] = X_test[date_col]
X_test_tuned[target_col] = y_test
st.header("2.2 Actual vs. Predicted Plot (Train)")
if is_panel:
metrics_table, line, actual_vs_predicted_plot = (
plot_actual_vs_predicted(
X_train_tuned[date_col],
X_train_tuned[target_col],
model_tuned.fittedvalues,
model_tuned,
target_column=sel_target_col,
is_panel=True,
)
)
else:
metrics_table, line, actual_vs_predicted_plot = (
plot_actual_vs_predicted(
X_train_tuned[date_col],
X_train_tuned[target_col],
model_tuned.predict(X_train_tuned[new_features]),
model_tuned,
target_column=sel_target_col,
is_panel=False,
)
)
st.plotly_chart(actual_vs_predicted_plot, use_container_width=True)
st.markdown("## 2.3 Residual Analysis (Train)")
if is_panel:
columns = st.columns(2)
with columns[0]:
fig = plot_residual_predicted(
y_train, model_tuned.fittedvalues, X_train_tuned
)
st.plotly_chart(fig)
with columns[1]:
st.empty()
fig = qqplot(y_train, model_tuned.fittedvalues)
st.plotly_chart(fig)
with columns[0]:
fig = residual_distribution(y_train, model_tuned.fittedvalues)
st.pyplot(fig)
else:
columns = st.columns(2)
with columns[0]:
fig = plot_residual_predicted(
y_train,
model_tuned.predict(X_train_tuned[new_features]),
X_train,
)
st.plotly_chart(fig)
with columns[1]:
st.empty()
fig = qqplot(
y_train, model_tuned.predict(X_train_tuned[new_features])
)
st.plotly_chart(fig)
with columns[0]:
fig = residual_distribution(
y_train, model_tuned.predict(X_train_tuned[new_features])
)
st.pyplot(fig)
# st.session_state['is_tuned_model'][target_col] = True
# Save tuned model in a dict
st.session_state["Model_Tuned"][sel_model + "__" + target_col] = {
"Model_object": model_tuned,
"feature_set": new_features,
"X_train_tuned": X_train_tuned,
"X_test_tuned": X_test_tuned,
}
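# The tuned model is keyed as "<model_name>__<target_col>" so it can be looked up per
# response metric downstream (presumably by the media planning / scenario pages).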
with st.expander("Results Summary (Test Data)"):
if is_panel:
random_eff_df = get_random_effects(
st.session_state.media_data.copy(), panel_col, model_tuned
)
test_pred = mdf_predict(
X_test_tuned, model_tuned, random_eff_df
)
else:
test_pred = model_tuned.predict(X_test_tuned[new_features])
st.header("2.2 Actual vs. Predicted Plot (Test)")
metrics_table, line, actual_vs_predicted_plot = (
plot_actual_vs_predicted(
X_test_tuned[date_col],
y_test,
test_pred,
model,
target_column=sel_target_col,
is_panel=is_panel,
)
)
st.plotly_chart(actual_vs_predicted_plot, use_container_width=True)
st.markdown("## 2.3 Residual Analysis (Test)")
columns = st.columns(2)
with columns[0]:
fig = plot_residual_predicted(y_test, test_pred, X_test_tuned)
st.plotly_chart(fig)
with columns[1]:
st.empty()
fig = qqplot(y_test, test_pred)
st.plotly_chart(fig)
with columns[0]:
fig = residual_distribution(y_test, test_pred)
st.pyplot(fig)
except:
# Capture the error details
exc_type, exc_value, exc_traceback = sys.exc_info()
error_message = "".join(
traceback.format_exception(exc_type, exc_value, exc_traceback)
)
log_message(
"error",
f"Error while building tuned model: {error_message}",
"Model Tuning",
)
st.warning("An error occured, please try again", icon="⚠️")
if (
st.session_state["Model_Tuned"] is not None
and len(list(st.session_state["Model_Tuned"].keys())) > 0
):
if st.button("Use This model for Media Planning", use_container_width=True):
# remove previous tuned models saved for this target col
_remove = [
m
for m in st.session_state["Model_Tuned"].keys()
if m.split("__")[1] == target_col and m.split("__")[0] != sel_model
]
if len(_remove) > 0:
for m in _remove:
del st.session_state["Model_Tuned"][m]
log_message(
"info",
f"Tuned model {' '.join(_remove)} removed due to overwrite",
"Model Tuning",
)
# Flag depicting tuned model for selected response metric
st.session_state["is_tuned_model"][target_col] = True
tuned_model_pkl = pickle.dumps(st.session_state["Model_Tuned"])
update_db(
st.session_state["project_number"],
"Model_Tuning",
"tuned_model",
tuned_model_pkl,
schema,
# resp_mtrc=None,
) # db
# Save session state variables (persistence)
st.session_state["project_dct"]["model_tuning"][
"session_state_saved"
] = {}
for key in [
"bin_dict",
"used_response_metrics",
"is_tuned_model",
"media_data",
"X_test_spends",
"spends_data",
]:
st.session_state["project_dct"]["model_tuning"][
"session_state_saved"
][key] = st.session_state[key]
project_dct_pkl = pickle.dumps(st.session_state["project_dct"])
update_db(
st.session_state["project_number"],
"Model_Tuning",
"project_dct",
project_dct_pkl,
schema,
# resp_mtrc=None,
) # db
log_message(
"info",
f'Tuned Model {sel_model + "__" + target_col} Saved',
"Model Tuning",
)
# Clear page metadata
st.session_state["project_dct"]["scenario_planner"][
"modified_metadata_file"
] = None
st.session_state["project_dct"]["response_curves"][
"modified_metadata_file"
] = None
st.success(f"Tuned model {sel_model} for {target_col} saved!")
except:
# Capture the error details
exc_type, exc_value, exc_traceback = sys.exc_info()
error_message = "".join(
traceback.format_exception(exc_type, exc_value, exc_traceback)
)
log_message("error", f"An error has occured : {error_message}", "Model Tuning")
st.warning("An error occured, please try again", icon="⚠️")
# st.write(error_message)