# MediaMixOptimization/pages/3_AI_Model_Transformations.py
# Importing necessary libraries
import streamlit as st
st.set_page_config(
page_title="AI Model Transformations",
page_icon="โš–๏ธ",
layout="wide",
initial_sidebar_state="collapsed",
)
import sys
import pickle
import traceback
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from post_gres_cred import db_cred
from log_application import log_message
from utilities import (
set_header,
load_local_css,
update_db,
project_selection,
delete_entries,
retrieve_pkl_object,
)
from constants import (
predefined_defaults,
lead_min_value,
lead_max_value,
lead_step,
lag_min_value,
lag_max_value,
lag_step,
moving_average_min_value,
moving_average_max_value,
moving_average_step,
saturation_min_value,
saturation_max_value,
saturation_step,
power_min_value,
power_max_value,
power_step,
adstock_min_value,
adstock_max_value,
adstock_step,
display_max_col,
)
schema = db_cred["schema"]
load_local_css("styles.css")
set_header()
# Initialize project name session state
if "project_name" not in st.session_state:
st.session_state["project_name"] = None
# Fetch project dictionary
if "project_dct" not in st.session_state:
project_selection()
st.stop()
# Display Username and Project Name
if "username" in st.session_state and st.session_state["username"] is not None:
cols1 = st.columns([2, 1])
with cols1[0]:
st.markdown(f"**Welcome {st.session_state['username']}**")
with cols1[1]:
st.markdown(f"**Current Project: {st.session_state['project_name']}**")
# Load saved data from project dictionary
if st.session_state["project_dct"]["data_import"]["imputed_tool_df"] is None:
    st.warning(
        "The data import is incomplete. Please go back to the Data Import page and save your data.",
        icon="🔙",
    )
    # Log message
    log_message(
        "warning",
        "The data import is incomplete. Please go back to the Data Import page and save your data.",
        "Transformations",
    )
st.stop()
else:
final_df_loaded = st.session_state["project_dct"]["data_import"][
"imputed_tool_df"
].copy()
bin_dict_loaded = st.session_state["project_dct"]["data_import"][
"category_dict"
].copy()
unique_panels = st.session_state["project_dct"]["data_import"][
"unique_panels"
].copy()
# Initialize project dictionary data
if st.session_state["project_dct"]["transformations"]["final_df"] is None:
st.session_state["project_dct"]["transformations"][
"final_df"
] = final_df_loaded # Default as original dataframe
# Extract original columns for specified categories
original_columns = {
category: bin_dict_loaded[category]
for category in ["Media", "Internal", "Exogenous"]
if category in bin_dict_loaded
}
# Retrieve panel columns
panel = ["panel"] if len(unique_panels) > 1 else []
# Function to clear model metadata
def clear_pages():
# Reset Pages
st.session_state["project_dct"]["model_build"] = {
"sel_target_col": None,
"all_iters_check": False,
"iterations": 0,
"build_button": False,
"show_results_check": False,
"session_state_saved": {},
}
st.session_state["project_dct"]["model_tuning"] = {
"sel_target_col": None,
"sel_model": {},
"flag_expander": False,
"start_date_default": None,
"end_date_default": None,
"repeat_default": "No",
"flags": {},
"select_all_flags_check": {},
"selected_flags": {},
"trend_check": False,
"week_num_check": False,
"sine_cosine_check": False,
"session_state_saved": {},
}
st.session_state["project_dct"]["saved_model_results"] = {
"selected_options": None,
"model_grid_sel": [1],
}
if "model_results_df" in st.session_state:
del st.session_state["model_results_df"]
if "model_results_data" in st.session_state:
del st.session_state["model_results_data"]
if "coefficients_df" in st.session_state:
del st.session_state["coefficients_df"]
# Function to persist a category-level transformation change
def transformation_change(category, transformation, key):
st.session_state["project_dct"]["transformations"][category][transformation] = (
st.session_state[key]
)
# Function to persist a channel-specific transformation change
def transformation_specific_change(channel_name, transformation, key):
st.session_state["project_dct"]["transformations"]["Specific"][transformation][
channel_name
] = st.session_state[key]
# Function to persist the selected transformations for a category
def transformations_to_apply_change(category, key):
st.session_state["project_dct"]["transformations"][category][key] = (
st.session_state[key]
)
# Function to persist the channel selection for specific transformations
def channel_select_specific_change():
st.session_state["project_dct"]["transformations"]["Specific"][
"channel_select_specific"
] = st.session_state["channel_select_specific"]
# Function to persist the selected transformations for a specific channel
def specific_transformation_change(specific_transformation_key):
st.session_state["project_dct"]["transformations"]["Specific"][
specific_transformation_key
] = st.session_state[specific_transformation_key]
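# Note: the callbacks above follow the usual Streamlit pattern. Each widget keeps its
# live value in st.session_state under its `key`, and the on_change callback copies
# that value into the project dictionary so selections survive page switches and reruns.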
# Function to build transformation widgets
def transformation_widgets(category, transform_params, date_granularity):
# Transformation Options
transformation_options = {
"Media": [
"Lag",
"Moving Average",
"Saturation",
"Power",
"Adstock",
],
"Internal": ["Lead", "Lag", "Moving Average"],
"Exogenous": ["Lead", "Lag", "Moving Average"],
}
# Define a helper function to create widgets for each transformation
def create_transformation_widgets(column, transformations):
with column:
for transformation in transformations:
transformation_key = f"{transformation}_{category}"
slider_value = st.session_state["project_dct"]["transformations"][
category
].get(transformation, predefined_defaults[transformation])
# Conditionally create widgets for selected transformations
if transformation == "Lead":
st.markdown(f"**{transformation} ({date_granularity})**")
lead = st.slider(
label="Lead periods",
min_value=lead_min_value,
max_value=lead_max_value,
step=lead_step,
value=slider_value,
key=transformation_key,
label_visibility="collapsed",
on_change=transformation_change,
args=(
category,
transformation,
transformation_key,
),
)
start = lead[0]
end = lead[1]
step = lead_step
transform_params[category][transformation] = np.arange(
start, end + step, step
)
if transformation == "Lag":
st.markdown(f"**{transformation} ({date_granularity})**")
lag = st.slider(
label="Lag periods",
min_value=lag_min_value,
max_value=lag_max_value,
step=lag_step,
value=slider_value,
key=transformation_key,
label_visibility="collapsed",
on_change=transformation_change,
args=(
category,
transformation,
transformation_key,
),
)
start = lag[0]
end = lag[1]
step = lag_step
transform_params[category][transformation] = np.arange(
start, end + step, step
)
if transformation == "Moving Average":
st.markdown(f"**{transformation} ({date_granularity})**")
window = st.slider(
label="Window size for Moving Average",
min_value=moving_average_min_value,
max_value=moving_average_max_value,
step=moving_average_step,
value=slider_value,
key=transformation_key,
label_visibility="collapsed",
on_change=transformation_change,
args=(
category,
transformation,
transformation_key,
),
)
start = window[0]
end = window[1]
step = moving_average_step
transform_params[category][transformation] = np.arange(
start, end + step, step
)
if transformation == "Saturation":
st.markdown(f"**{transformation} (%)**")
saturation_point = st.slider(
label="Saturation Percentage",
min_value=saturation_min_value,
max_value=saturation_max_value,
step=saturation_step,
value=slider_value,
key=transformation_key,
label_visibility="collapsed",
on_change=transformation_change,
args=(
category,
transformation,
transformation_key,
),
)
start = saturation_point[0]
end = saturation_point[1]
step = saturation_step
transform_params[category][transformation] = np.arange(
start, end + step, step
)
if transformation == "Power":
st.markdown(f"**{transformation}**")
power = st.slider(
label="Power",
min_value=power_min_value,
max_value=power_max_value,
step=power_step,
value=slider_value,
key=transformation_key,
label_visibility="collapsed",
on_change=transformation_change,
args=(
category,
transformation,
transformation_key,
),
)
start = power[0]
end = power[1]
step = power_step
transform_params[category][transformation] = np.arange(
start, end + step, step
)
if transformation == "Adstock":
st.markdown(f"**{transformation}**")
rate = st.slider(
label="Decay Factor",
min_value=adstock_min_value,
max_value=adstock_max_value,
step=adstock_step,
value=slider_value,
key=transformation_key,
label_visibility="collapsed",
on_change=transformation_change,
args=(
category,
transformation,
transformation_key,
),
)
start = rate[0]
end = rate[1]
step = adstock_step
adstock_range = [
round(a, 3) for a in np.arange(start, end + step, step)
]
transform_params[category][transformation] = np.array(adstock_range)
with st.expander(f"All {category} Transformations", expanded=True):
transformation_key = f"transformation_{category}"
# Select which transformations to apply
sel_transformations = st.session_state["project_dct"]["transformations"][
category
].get(transformation_key, [])
        # Reset the default transformation selection if the available options changed
        for transformation_sel in sel_transformations:
            if transformation_sel not in transformation_options[category]:
                (
                    st.session_state["project_dct"]["transformations"][category][
                        transformation_key
                    ],
                    sel_transformations,
                ) = ([], [])
transformations_to_apply = st.multiselect(
label="Select transformations to apply",
options=transformation_options[category],
default=sel_transformations,
key=transformation_key,
on_change=transformations_to_apply_change,
args=(
category,
transformation_key,
),
)
# Determine the number of transformations to put in each column
transformations_per_column = (
len(transformations_to_apply) // 2 + len(transformations_to_apply) % 2
)
# Create two columns
col1, col2 = st.columns(2)
# Assign transformations to each column
transformations_col1 = transformations_to_apply[:transformations_per_column]
transformations_col2 = transformations_to_apply[transformations_per_column:]
# Create widgets in each column
create_transformation_widgets(col1, transformations_col1)
create_transformation_widgets(col2, transformations_col2)
# Define a helper function to create widgets for each specific transformation
def create_specific_transformation_widgets(
column,
transformations,
channel_name,
date_granularity,
specific_transform_params,
):
with column:
for transformation in transformations:
transformation_key = f"{transformation}_{channel_name}_specific"
if (
transformation
not in st.session_state["project_dct"]["transformations"]["Specific"]
):
st.session_state["project_dct"]["transformations"]["Specific"][
transformation
] = {}
slider_value = st.session_state["project_dct"]["transformations"][
"Specific"
][transformation].get(channel_name, predefined_defaults[transformation])
# Conditionally create widgets for selected transformations
if transformation == "Lead":
st.markdown(f"**Lead ({date_granularity})**")
lead = st.slider(
label="Lead periods",
min_value=lead_min_value,
max_value=lead_max_value,
step=lead_step,
value=slider_value,
key=transformation_key,
label_visibility="collapsed",
on_change=transformation_specific_change,
args=(
channel_name,
transformation,
transformation_key,
),
)
start = lead[0]
end = lead[1]
step = lead_step
specific_transform_params[channel_name]["Lead"] = np.arange(
start, end + step, step
)
if transformation == "Lag":
st.markdown(f"**Lag ({date_granularity})**")
lag = st.slider(
label="Lag periods",
min_value=lag_min_value,
max_value=lag_max_value,
step=lag_step,
value=slider_value,
key=transformation_key,
label_visibility="collapsed",
on_change=transformation_specific_change,
args=(
channel_name,
transformation,
transformation_key,
),
)
start = lag[0]
end = lag[1]
step = lag_step
specific_transform_params[channel_name]["Lag"] = np.arange(
start, end + step, step
)
if transformation == "Moving Average":
st.markdown(f"**Moving Average ({date_granularity})**")
window = st.slider(
label="Window size for Moving Average",
min_value=moving_average_min_value,
max_value=moving_average_max_value,
step=moving_average_step,
value=slider_value,
key=transformation_key,
label_visibility="collapsed",
on_change=transformation_specific_change,
args=(
channel_name,
transformation,
transformation_key,
),
)
start = window[0]
end = window[1]
step = moving_average_step
specific_transform_params[channel_name]["Moving Average"] = np.arange(
start, end + step, step
)
if transformation == "Saturation":
st.markdown("**Saturation (%)**")
saturation_point = st.slider(
label="Saturation Percentage",
min_value=saturation_min_value,
max_value=saturation_max_value,
step=saturation_step,
value=slider_value,
key=transformation_key,
label_visibility="collapsed",
on_change=transformation_specific_change,
args=(
channel_name,
transformation,
transformation_key,
),
)
start = saturation_point[0]
end = saturation_point[1]
step = saturation_step
specific_transform_params[channel_name]["Saturation"] = np.arange(
start, end + step, step
)
if transformation == "Power":
st.markdown("**Power**")
power = st.slider(
label="Power",
min_value=power_min_value,
max_value=power_max_value,
step=power_step,
value=slider_value,
key=transformation_key,
label_visibility="collapsed",
on_change=transformation_specific_change,
args=(
channel_name,
transformation,
transformation_key,
),
)
start = power[0]
end = power[1]
step = power_step
specific_transform_params[channel_name]["Power"] = np.arange(
start, end + step, step
)
if transformation == "Adstock":
st.markdown("**Adstock**")
rate = st.slider(
label="Decay Factor",
min_value=adstock_min_value,
max_value=adstock_max_value,
step=adstock_step,
value=slider_value,
key=transformation_key,
label_visibility="collapsed",
on_change=transformation_specific_change,
args=(
channel_name,
transformation,
transformation_key,
),
)
start = rate[0]
end = rate[1]
step = adstock_step
adstock_range = [
round(a, 3) for a in np.arange(start, end + step, step)
]
specific_transform_params[channel_name]["Adstock"] = np.array(
adstock_range
)
# Function to apply Lag transformation
def apply_lag(df, lag):
return df.shift(lag)
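# Illustrative example (hypothetical values): apply_lag(pd.Series([1, 2, 3, 4]), 2)
# returns [NaN, NaN, 1, 2]; each value is pushed two periods later, and the leading
# NaNs are filled with 0 later in transform_slice.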
# Function to apply Lead transformation
def apply_lead(df, lead):
return df.shift(-lead)
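# Illustrative example (hypothetical values): apply_lead(pd.Series([1, 2, 3]), 1)
# returns [2, 3, NaN]; each value is pulled one period earlier.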
# Function to apply Moving Average transformation
def apply_moving_average(df, window_size):
return df.rolling(window=window_size).mean()
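# Illustrative example (hypothetical values):
# apply_moving_average(pd.Series([1, 3, 5]), 2) returns [NaN, 2.0, 4.0],
# the mean of each trailing window of 2 observations.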
# Function to apply Saturation transformation
def apply_saturation(df, saturation_percent_100):
    # Clamp to (0.01, 99.99) and convert the percentage to a fraction
saturation_percent = min(max(saturation_percent_100, 0.01), 99.99) / 100.0
# Get the maximum and minimum values
column_max = df.max()
column_min = df.min()
# If the data is constant, scale it directly
if column_min == column_max:
return df.apply(lambda x: x * saturation_percent)
# Compute the saturation point from the data range
saturation_point = (column_min + saturation_percent * column_max) / 2
# Calculate steepness for the saturation curve
numerator = np.log((1 / saturation_percent) - 1)
denominator = np.log(saturation_point / column_max)
steepness = numerator / denominator
# Apply the saturation transformation
transformed_series = df.apply(
lambda x: (1 / (1 + (saturation_point / (x if x != 0 else 1e-9)) ** steepness)) * x
)
return transformed_series
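# Note on the formula above: the multiplier 1 / (1 + (saturation_point / x) ** steepness)
# is algebraically the Hill function x**steepness / (x**steepness + saturation_point**steepness),
# an S-curve whose bend is calibrated to the computed saturation point. One edge case worth
# knowing (an observation, not handled explicitly here): if saturation_point equals
# column_max, the denominator becomes log(1) = 0 and steepness is undefined.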
# Function to apply Power transformation
def apply_power(df, power):
return df**power
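# Illustrative example (hypothetical values): apply_power(pd.Series([2, 3]), 2)
# returns [4, 9]; fractional powers compress large values instead of amplifying them.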
# Function to apply Adstock transformation
def apply_adstock(df, factor):
x = 0
# Use the walrus operator to update x iteratively with the Adstock formula
adstock_var = [x := x * factor + v for v in df]
ans = pd.Series(adstock_var, index=df.index)
return ans
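# Illustrative example (hypothetical values): apply_adstock(pd.Series([100, 0, 0]), 0.5)
# returns [100.0, 50.0, 25.0], since each period carries over `factor` times the
# previous adstocked value: adstock_t = x_t + factor * adstock_(t-1).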
# Function to generate transformed column names
@st.cache_resource(show_spinner=False)
def generate_transformed_columns(
original_columns, transform_params, specific_transform_params
):
transformed_columns, summary = {}, {}
for category, columns in original_columns.items():
for column in columns:
transformed_columns[column] = []
            summary_details = []  # List to hold transformation details for the current column
if (
column in specific_transform_params.keys()
and len(specific_transform_params[column]) > 0
):
for transformation, values in specific_transform_params[column].items():
# Generate transformed column names for each value
for value in values:
transformed_name = f"{column}@{transformation}_{value}"
transformed_columns[column].append(transformed_name)
# Format the values list as a string with commas and "and" before the last item
if len(values) > 1:
formatted_values = (
", ".join(map(str, values[:-1])) + " and " + str(values[-1])
)
else:
formatted_values = str(values[0])
# Add transformation details
summary_details.append(f"{transformation} ({formatted_values})")
else:
if category in transform_params:
for transformation, values in transform_params[category].items():
# Generate transformed column names for each value
if column not in specific_transform_params.keys():
for value in values:
transformed_name = f"{column}@{transformation}_{value}"
transformed_columns[column].append(transformed_name)
# Format the values list as a string with commas and "and" before the last item
if len(values) > 1:
formatted_values = (
", ".join(map(str, values[:-1]))
+ " and "
+ str(values[-1])
)
else:
formatted_values = str(values[0])
# Add transformation details
summary_details.append(
f"{transformation} ({formatted_values})"
)
else:
summary_details = ["No transformation selected"]
# Only add to summary if there are transformation details for the column
if summary_details:
formatted_summary = "โฎ• ".join(summary_details)
# Use <strong> tags to make the column name bold
summary[column] = f"<strong>{column}</strong>: {formatted_summary}"
# Generate a comprehensive summary string for all columns
summary_items = [
f"{idx + 1}. {details}" for idx, details in enumerate(summary.values())
]
summary_string = "\n".join(summary_items)
return transformed_columns, summary_string
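# Naming convention produced above (hypothetical column name): a Media column
# "tv_spend" with Lag values [1, 2] yields "tv_spend@Lag_1" and "tv_spend@Lag_2".
# The "@" separator is what apply_category_transformations later relies on to
# identify transformed columns.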
# Function to transform a DataFrame slice
def transform_slice(
transform_params,
transformation_functions,
panel,
df,
df_slice,
category,
category_df,
):
# Iterate through each transformation and its parameters for the current category
for transformation, parameters in transform_params[category].items():
transformation_function = transformation_functions[transformation]
# Check if there is panel data to group by
if len(panel) > 0:
# Apply the transformation to each group
category_df = pd.concat(
[
df_slice.groupby(panel)
.transform(transformation_function, p)
.add_suffix(f"@{transformation}_{p}")
for p in parameters
],
axis=1,
)
# Replace all NaN or null values in category_df with 0
category_df.fillna(0, inplace=True)
# Update df_slice
df_slice = pd.concat(
[df[panel], category_df],
axis=1,
)
else:
for p in parameters:
# Apply the transformation function to each column
temp_df = df_slice.apply(
lambda x: transformation_function(x, p), axis=0
).rename(
lambda x: f"{x}@{transformation}_{p}",
axis="columns",
)
# Concatenate the transformed DataFrame slice to the category DataFrame
category_df = pd.concat([category_df, temp_df], axis=1)
# Replace all NaN or null values in category_df with 0
category_df.fillna(0, inplace=True)
# Update df_slice
df_slice = pd.concat(
[df[panel], category_df],
axis=1,
)
return category_df, df, df_slice
# Function to apply transformations to DataFrame slices based on specified categories and parameters
@st.cache_resource(show_spinner=False)
def apply_category_transformations(
df_main, bin_dict, transform_params, panel, specific_transform_params
):
# Dictionary for function mapping
transformation_functions = {
"Lead": apply_lead,
"Lag": apply_lag,
"Moving Average": apply_moving_average,
"Saturation": apply_saturation,
"Power": apply_power,
"Adstock": apply_adstock,
}
# List to collect all transformed DataFrames
transformed_dfs = []
# Iterate through each category specified in transform_params
for category in ["Media", "Exogenous", "Internal"]:
if (
category not in transform_params
or category not in bin_dict
or not transform_params[category]
):
continue # Skip categories without transformations
# Initialize category_df as an empty DataFrame
category_df = pd.DataFrame()
# Slice the DataFrame based on the columns specified in bin_dict for the current category
df_slice = df_main[bin_dict[category] + panel].copy()
        # Drop channels that have specific transformations from df_slice so they are not transformed twice
df_slice = df_slice.drop(
columns=list(specific_transform_params.keys()), errors="ignore"
).copy()
category_df, df, df_slice_updated = transform_slice(
transform_params.copy(),
transformation_functions.copy(),
panel,
df_main.copy(),
df_slice.copy(),
category,
category_df.copy(),
)
# Append the transformed category DataFrame to the list if it's not empty
if not category_df.empty:
transformed_dfs.append(category_df)
# Apply channel specific transforms
for channel_specific in specific_transform_params:
# Initialize category_df as an empty DataFrame
category_df = pd.DataFrame()
df_slice_specific = df_main[[channel_specific] + panel].copy()
transform_params_specific = {
"Media": specific_transform_params[channel_specific]
}
category_df, df, df_slice_specific_updated = transform_slice(
transform_params_specific.copy(),
transformation_functions.copy(),
panel,
df_main.copy(),
df_slice_specific.copy(),
"Media",
category_df.copy(),
)
# Append the transformed category DataFrame to the list if it's not empty
if not category_df.empty:
transformed_dfs.append(category_df)
    # If any transformations were applied, concatenate them with the original DataFrame
if len(transformed_dfs) > 0:
final_df = pd.concat([df_main] + transformed_dfs, axis=1)
else:
# If no transformations were applied, use the original DataFrame
final_df = df_main
# Find columns with '@' in their names
columns_with_at = [col for col in final_df.columns if "@" in col]
# Create a set of columns to drop
columns_to_drop = set()
# Iterate through columns with '@' to find shorter names to drop
for col in columns_with_at:
base_name = col.split("@")[0]
for other_col in columns_with_at:
if other_col.startswith(base_name) and len(other_col.split("@")) > len(
col.split("@")
):
columns_to_drop.add(col)
break
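    # Illustrative example (hypothetical names): if both "tv_spend@Lag_2" and a
    # chained "tv_spend@Lag_2@Adstock_0.5" are present, the shorter intermediate
    # column lands in columns_to_drop and only the fully transformed column is kept.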
# Drop the identified columns from the DataFrame
final_df.drop(columns=list(columns_to_drop), inplace=True)
return final_df
# Function to infer the granularity of the date column in a DataFrame
@st.cache_resource(show_spinner=False)
def infer_date_granularity(df):
# Find the most common difference
common_freq = pd.Series(df["date"].unique()).diff().dt.days.dropna().mode()[0]
# Map the most common difference to a granularity
if common_freq == 1:
return "daily"
elif common_freq == 7:
return "weekly"
elif 28 <= common_freq <= 31:
return "monthly"
else:
return "irregular"
# Function to clean display DataFrame
@st.cache_data(show_spinner=False)
def clean_display_df(df, display_max_col=500):
# Sort by 'panel' and 'date'
sort_columns = ["panel", "date"]
sorted_df = df.sort_values(by=sort_columns, ascending=True, na_position="first")
# Drop duplicate columns
sorted_df = sorted_df.loc[:, ~sorted_df.columns.duplicated()]
# Check if the DataFrame has more than display_max_col columns
exceeds_max_col = sorted_df.shape[1] > display_max_col
if exceeds_max_col:
# Create a new DataFrame with 'date' and 'panel' at the start
display_df = sorted_df[["date", "panel"]]
# Add the next display_max_col - 2 columns (as 'date' and 'panel' already occupy 2 columns)
additional_columns = sorted_df.columns.difference(["date", "panel"]).tolist()[
: display_max_col - 2
]
display_df = pd.concat([display_df, sorted_df[additional_columns]], axis=1)
else:
# Ensure 'date' and 'panel' are the first two columns in the final display DataFrame
column_order = ["date", "panel"] + sorted_df.columns.difference(
["date", "panel"]
).tolist()
display_df = sorted_df[column_order]
    # Return the display DataFrame and whether it exceeds display_max_col columns
return display_df, exceeds_max_col
#########################################################################################################################################################
# User input for transformations
#########################################################################################################################################################
try:
# Page Title
st.title("AI Model Transformations")
# Infer date granularity
date_granularity = infer_date_granularity(final_df_loaded)
# Initialize the main dictionary to store the transformation parameters for each category
transform_params = {"Media": {}, "Internal": {}, "Exogenous": {}}
st.markdown("### Select Transformations to Apply")
with st.expander("Specific Media Transformations"):
# Select which transformations to apply
sel_channel_specific = st.session_state["project_dct"]["transformations"][
"Specific"
].get("channel_select_specific", [])
# Reset default selected channels list if options are changed
for channel in sel_channel_specific:
if channel not in bin_dict_loaded["Media"]:
(
st.session_state["project_dct"]["transformations"]["Specific"][
"channel_select_specific"
],
sel_channel_specific,
) = ([], [])
select_specific_channels = st.multiselect(
label="Select channel variable",
default=sel_channel_specific,
options=bin_dict_loaded["Media"],
key="channel_select_specific",
on_change=channel_select_specific_change,
max_selections=30,
)
specific_transform_params = {}
for select_specific_channel in select_specific_channels:
specific_transform_params[select_specific_channel] = {}
st.divider()
channel_name = str(select_specific_channel).replace("_", " ").title()
st.markdown(f"###### {channel_name}")
specific_transformation_key = (
f"specific_transformation_{select_specific_channel}_Media"
)
transformations_options = [
"Lag",
"Moving Average",
"Saturation",
"Power",
"Adstock",
]
# Select which transformations to apply
sel_transformations = st.session_state["project_dct"]["transformations"][
"Specific"
].get(specific_transformation_key, [])
            # Reset the default transformation selection if the available options changed
            for transformation_sel in sel_transformations:
                if transformation_sel not in transformations_options:
                    (
                        st.session_state["project_dct"]["transformations"]["Specific"][
                            specific_transformation_key
                        ],
                        sel_transformations,
                    ) = ([], [])
transformations_to_apply = st.multiselect(
label="Select transformations to apply",
options=transformations_options,
default=sel_transformations,
key=specific_transformation_key,
on_change=specific_transformation_change,
args=(specific_transformation_key,),
)
# Determine the number of transformations to put in each column
transformations_per_column = (
len(transformations_to_apply) // 2 + len(transformations_to_apply) % 2
)
# Create two columns
col1, col2 = st.columns(2)
# Assign transformations to each column
transformations_col1 = transformations_to_apply[:transformations_per_column]
transformations_col2 = transformations_to_apply[transformations_per_column:]
# Create widgets in each column
create_specific_transformation_widgets(
col1,
transformations_col1,
select_specific_channel,
date_granularity,
specific_transform_params,
)
create_specific_transformation_widgets(
col2,
transformations_col2,
select_specific_channel,
date_granularity,
specific_transform_params,
)
# Create Widgets
for category in ["Media", "Internal", "Exogenous"]:
# Skip Internal
if category == "Internal":
continue
# Skip category if no column available
elif (
category not in bin_dict_loaded.keys()
or len(bin_dict_loaded[category]) == 0
):
st.info(
f"{str(category).title()} category has no column associated with it. Skipping transformation step for this category.",
icon="๐Ÿ’ฌ",
)
continue
transformation_widgets(category, transform_params, date_granularity)
#########################################################################################################################################################
# Apply transformations
#########################################################################################################################################################
# Reset transformation selection to default
button_col = st.columns(2)
with button_col[1]:
if st.button("Reset to Default", use_container_width=True):
st.session_state["project_dct"]["transformations"]["Media"] = {}
st.session_state["project_dct"]["transformations"]["Exogenous"] = {}
st.session_state["project_dct"]["transformations"]["Internal"] = {}
st.session_state["project_dct"]["transformations"]["Specific"] = {}
# Log message
log_message(
"info",
"All persistent selections have been reset to their default settings and cleared.",
"Transformations",
)
st.rerun()
# Apply category-based transformations to the DataFrame
with button_col[0]:
if st.button("Accept and Proceed", use_container_width=True):
with st.spinner("Applying transformations ..."):
final_df = apply_category_transformations(
final_df_loaded.copy(),
bin_dict_loaded.copy(),
transform_params.copy(),
panel.copy(),
specific_transform_params.copy(),
)
# Generate a dictionary mapping original column names to lists of transformed column names
transformed_columns_dict, summary_string = generate_transformed_columns(
original_columns, transform_params, specific_transform_params
)
                # Store the transformed DataFrame and summary string in session state
st.session_state["project_dct"]["transformations"][
"final_df"
] = final_df
st.session_state["project_dct"]["transformations"][
"summary_string"
] = summary_string
# Display success message
st.success("Transformation of the DataFrame is successful!", icon="โœ…")
# Log message
log_message(
"info",
"Transformation of the DataFrame is successful!",
"Transformations",
)
#########################################################################################################################################################
# Display the transformed DataFrame and summary
#########################################################################################################################################################
# Display the transformed DataFrame in the Streamlit app
st.markdown("### Transformed DataFrame")
with st.spinner("Please wait while the transformed DataFrame is loading ..."):
final_df = st.session_state["project_dct"]["transformations"]["final_df"].copy()
# Clean display DataFrame
display_df, exceeds_max_col = clean_display_df(final_df, display_max_col)
# Check the number of columns and show only the first display_max_col if there are more
if exceeds_max_col:
            # Display an info message if the DataFrame has more than display_max_col columns
            st.info(
                f"The transformed DataFrame has more than {display_max_col} columns. Displaying only the first {display_max_col} columns.",
                icon="💬",
            )
# Display Final DataFrame
st.dataframe(
display_df,
hide_index=True,
column_config={
"date": st.column_config.DateColumn("date", format="YYYY-MM-DD")
},
)
# Total rows and columns
total_rows, total_columns = st.session_state["project_dct"]["transformations"][
"final_df"
].shape
st.markdown(
f"<p style='text-align: justify;'>The transformed DataFrame contains <strong>{total_rows}</strong> rows and <strong>{total_columns}</strong> columns.</p>",
unsafe_allow_html=True,
)
# Display the summary of transformations as markdown
if (
"summary_string" in st.session_state["project_dct"]["transformations"]
and st.session_state["project_dct"]["transformations"]["summary_string"]
):
with st.expander("Summary of Transformations"):
st.markdown("### Summary of Transformations")
st.markdown(
st.session_state["project_dct"]["transformations"][
"summary_string"
],
unsafe_allow_html=True,
)
#########################################################################################################################################################
# Correlation Plot
#########################################################################################################################################################
    # Filter out the 'date' and 'panel' columns
variables = [
col for col in final_df.columns if col.lower() not in ["date", "panel"]
]
with st.expander("Transformed Variable Correlation Plot"):
selected_vars = st.multiselect(
label="Choose variables for correlation plot:",
options=variables,
max_selections=30,
default=st.session_state["project_dct"]["transformations"][
"correlation_plot_selection"
],
key="correlation_plot_key",
)
# Calculate correlation
if selected_vars:
corr_df = final_df[selected_vars].corr()
# Prepare text annotations with 2 decimal places
annotations = []
for i in range(len(corr_df)):
for j in range(len(corr_df.columns)):
annotations.append(
go.layout.Annotation(
text=f"{corr_df.iloc[i, j]:.2f}",
x=corr_df.columns[j],
y=corr_df.index[i],
showarrow=False,
font=dict(color="black"),
)
)
# Plotly correlation plot using go
heatmap = go.Heatmap(
z=corr_df.values,
x=corr_df.columns,
y=corr_df.index,
colorscale="RdBu",
zmin=-1,
zmax=1,
)
layout = go.Layout(
title="Transformed Variable Correlation Plot",
xaxis=dict(title="Variables"),
yaxis=dict(title="Variables"),
width=1000,
height=1000,
annotations=annotations,
)
fig = go.Figure(data=[heatmap], layout=layout)
st.plotly_chart(fig)
else:
st.write("Please select at least one variable to plot.")
#########################################################################################################################################################
# Accept and Save
#########################################################################################################################################################
# Check for saved model
if (
retrieve_pkl_object(
st.session_state["project_number"], "Model_Build", "best_models", schema
)
is not None
): # db
st.warning(
"Saving transformations will overwrite existing ones and delete all saved models. To keep previous models, please start a new project.",
icon="โš ๏ธ",
)
if st.button("Accept and Save", use_container_width=True):
with st.spinner("Saving Changes"):
# Update correlation plot selection
st.session_state["project_dct"]["transformations"][
"correlation_plot_selection"
] = st.session_state["correlation_plot_key"]
# Clear model metadata
clear_pages()
# Update DB
update_db(
prj_id=st.session_state["project_number"],
page_nam="Transformations",
file_nam="project_dct",
pkl_obj=pickle.dumps(st.session_state["project_dct"]),
schema=schema,
)
# Clear data from DB
delete_entries(
st.session_state["project_number"],
["Model_Build", "Model_Tuning"],
db_cred,
schema,
)
# Success message
st.success("Saved Successfully!", icon="๐Ÿ’พ")
st.toast("Saved Successfully!", icon="๐Ÿ’พ")
# Log message
log_message("info", "Saved Successfully!", "Transformations")
except Exception as e:
# Capture the error details
exc_type, exc_value, exc_traceback = sys.exc_info()
error_message = "".join(
traceback.format_exception(exc_type, exc_value, exc_traceback)
)
# Log message
log_message("error", f"An error occurred: {error_message}.", "Transformations")
# Display a warning message
st.warning(
"Oops! Something went wrong. Please try refreshing the tool or creating a new project.",
icon="โš ๏ธ",
)