Spaces:
Sleeping
Sleeping
# Importing necessary libraries | |
import streamlit as st | |
st.set_page_config( | |
page_title="AI Model Transformations", | |
page_icon="โ๏ธ", | |
layout="wide", | |
initial_sidebar_state="collapsed", | |
) | |
import sys | |
import pickle | |
import traceback | |
import numpy as np | |
import pandas as pd | |
import plotly.graph_objects as go | |
from post_gres_cred import db_cred | |
from log_application import log_message | |
from utilities import ( | |
set_header, | |
load_local_css, | |
update_db, | |
project_selection, | |
delete_entries, | |
retrieve_pkl_object, | |
) | |
from constants import ( | |
predefined_defaults, | |
lead_min_value, | |
lead_max_value, | |
lead_step, | |
lag_min_value, | |
lag_max_value, | |
lag_step, | |
moving_average_min_value, | |
moving_average_max_value, | |
moving_average_step, | |
saturation_min_value, | |
saturation_max_value, | |
saturation_step, | |
power_min_value, | |
power_max_value, | |
power_step, | |
adstock_min_value, | |
adstock_max_value, | |
adstock_step, | |
display_max_col, | |
) | |
# Schema entry read from the imported credentials mapping
schema = db_cred["schema"]

# Apply the local stylesheet and render the shared page header
load_local_css("styles.css")
set_header()

# Initialize project name session state
if "project_name" not in st.session_state:
    st.session_state["project_name"] = None

# Fetch project dictionary.
# Without a project dictionary nothing below can run, so show the project
# selection and halt this script run.
if "project_dct" not in st.session_state:
    project_selection()
    st.stop()

# Display Username and Project Name
if "username" in st.session_state and st.session_state["username"] is not None:
    cols1 = st.columns([2, 1])
    with cols1[0]:
        st.markdown(f"**Welcome {st.session_state['username']}**")
    with cols1[1]:
        st.markdown(f"**Current Project: {st.session_state['project_name']}**")

# Load saved data from project dictionary.
# The imputed dataframe is a prerequisite: when it is missing, warn, log and
# stop the run instead of continuing with no data.
# NOTE(review): the icon literal below looks mis-encoded (probably an emoji) —
# confirm against the original file.
if st.session_state["project_dct"]["data_import"]["imputed_tool_df"] is None:
    st.warning(
        "The data import is incomplete. Please go back to the Data Import page and complete the save.",
        icon="๐",
    )
    # Log message
    log_message(
        "warning",
        "The data import is incomplete. Please go back to the Data Import page and complete the save.",
        "Transformations",
    )
    st.stop()
else:
    # Work on copies so the objects stored in the project dictionary are not
    # modified through these names
    final_df_loaded = st.session_state["project_dct"]["data_import"][
        "imputed_tool_df"
    ].copy()
    bin_dict_loaded = st.session_state["project_dct"]["data_import"][
        "category_dict"
    ].copy()
    unique_panels = st.session_state["project_dct"]["data_import"][
        "unique_panels"
    ].copy()

# Initialize project dictionary data
if st.session_state["project_dct"]["transformations"]["final_df"] is None:
    st.session_state["project_dct"]["transformations"][
        "final_df"
    ] = final_df_loaded  # Default as original dataframe

# Extract original columns for specified categories (only categories that are
# actually present in the loaded category dictionary)
original_columns = {
    category: bin_dict_loaded[category]
    for category in ["Media", "Internal", "Exogenous"]
    if category in bin_dict_loaded
}

# Retrieve panel columns: the "panel" column is only used for grouping when
# more than one unique panel exists
panel = ["panel"] if len(unique_panels) > 1 else []
# Function to clear model metadata | |
def clear_pages():
    """Reset the model build, model tuning and saved-results pages.

    Overwrites the three corresponding entries of the project dictionary with
    their default values and removes cached model result objects from the
    Streamlit session state when they are present.
    """
    project = st.session_state["project_dct"]

    # Defaults for the model build page
    project["model_build"] = {
        "sel_target_col": None,
        "all_iters_check": False,
        "iterations": 0,
        "build_button": False,
        "show_results_check": False,
        "session_state_saved": {},
    }

    # Defaults for the model tuning page
    project["model_tuning"] = {
        "sel_target_col": None,
        "sel_model": {},
        "flag_expander": False,
        "start_date_default": None,
        "end_date_default": None,
        "repeat_default": "No",
        "flags": {},
        "select_all_flags_check": {},
        "selected_flags": {},
        "trend_check": False,
        "week_num_check": False,
        "sine_cosine_check": False,
        "session_state_saved": {},
    }

    # Defaults for the saved model results page
    project["saved_model_results"] = {
        "selected_options": None,
        "model_grid_sel": [1],
    }

    # Drop cached result objects, if any were stored in the session
    for cached_key in ("model_results_df", "model_results_data", "coefficients_df"):
        if cached_key in st.session_state:
            del st.session_state[cached_key]
# Function to update transformation change | |
def transformation_change(category, transformation, key):
    """Store the widget value held under ``key`` as the project's setting for
    ``transformation`` within ``category``."""
    widget_value = st.session_state[key]
    category_settings = st.session_state["project_dct"]["transformations"][category]
    category_settings[transformation] = widget_value
# Function to update specific transformation change | |
def transformation_specific_change(channel_name, transformation, key):
    """Store the widget value held under ``key`` as the channel-specific
    setting of ``transformation`` for ``channel_name``."""
    widget_value = st.session_state[key]
    specific_settings = st.session_state["project_dct"]["transformations"]["Specific"]
    specific_settings[transformation][channel_name] = widget_value
# Function to update transformations to apply change | |
def transformations_to_apply_change(category, key):
    """Copy the session-state value stored under ``key`` into the project's
    transformation settings for ``category``, under the same ``key``."""
    selected = st.session_state[key]
    category_settings = st.session_state["project_dct"]["transformations"][category]
    category_settings[key] = selected
# Function to update channel select specific change | |
def channel_select_specific_change():
    """Copy the ``channel_select_specific`` session-state value into the
    project's "Specific" transformation settings."""
    selected_channels = st.session_state["channel_select_specific"]
    specific_settings = st.session_state["project_dct"]["transformations"]["Specific"]
    specific_settings["channel_select_specific"] = selected_channels
# Function to update specific transformation change | |
def specific_transformation_change(specific_transformation_key):
    """Copy the session-state value stored under ``specific_transformation_key``
    into the project's "Specific" transformation settings, under the same key."""
    widget_value = st.session_state[specific_transformation_key]
    specific_settings = st.session_state["project_dct"]["transformations"]["Specific"]
    specific_settings[specific_transformation_key] = widget_value
# Function to build transformation widgets | |
def transformation_widgets(category, transform_params, date_granularity):
    """Render the category-level transformation controls and record the choices.

    Shows a multiselect with the transformations available for ``category``
    and, for every selected transformation, a range slider laid out over two
    columns. The selected range is expanded into an array of parameter values
    and written to ``transform_params[category][transformation]``.

    Args:
        category: One of the keys of the options table below
            ("Media", "Internal" or "Exogenous").
        transform_params: Mapping of category -> {transformation: values};
            updated in place.
        date_granularity: Text shown next to the period-based transformations
            (Lead, Lag, Moving Average) in the slider header.
    """
    # Transformation Options
    transformation_options = {
        "Media": [
            "Lag",
            "Moving Average",
            "Saturation",
            "Power",
            "Adstock",
        ],
        "Internal": ["Lead", "Lag", "Moving Average"],
        "Exogenous": ["Lead", "Lag", "Moving Average"],
    }

    # Slider settings per transformation:
    # (header suffix, slider label, min, max, step, decimals to round to or None)
    slider_specs = {
        "Lead": (
            f" ({date_granularity})",
            "Lead periods",
            lead_min_value,
            lead_max_value,
            lead_step,
            None,
        ),
        "Lag": (
            f" ({date_granularity})",
            "Lag periods",
            lag_min_value,
            lag_max_value,
            lag_step,
            None,
        ),
        "Moving Average": (
            f" ({date_granularity})",
            "Window size for Moving Average",
            moving_average_min_value,
            moving_average_max_value,
            moving_average_step,
            None,
        ),
        "Saturation": (
            " (%)",
            "Saturation Percentage",
            saturation_min_value,
            saturation_max_value,
            saturation_step,
            None,
        ),
        "Power": (
            "",
            "Power",
            power_min_value,
            power_max_value,
            power_step,
            None,
        ),
        # Adstock values are rounded to 3 decimals to hide float noise
        "Adstock": (
            "",
            "Decay Factor",
            adstock_min_value,
            adstock_max_value,
            adstock_step,
            3,
        ),
    }

    def inclusive_range(start, end, step):
        """Return start, start + step, ... up to and including ``end``.

        ``np.arange`` with a non-integer step can yield one element beyond
        ``end`` because of floating-point rounding; anything more than half a
        step past ``end`` is discarded. Integer ranges are unaffected.
        """
        values = np.arange(start, end + step, step)
        return values[values <= end + step / 2]

    # Define a helper function to create widgets for each transformation
    def create_transformation_widgets(column, transformations):
        with column:
            for transformation in transformations:
                transformation_key = f"{transformation}_{category}"
                # Previously saved range, falling back to the predefined default
                slider_value = st.session_state["project_dct"]["transformations"][
                    category
                ].get(transformation, predefined_defaults[transformation])

                spec = slider_specs.get(transformation)
                if spec is None:
                    continue
                suffix, label, min_value, max_value, step, decimals = spec

                st.markdown(f"**{transformation}{suffix}**")
                selected_range = st.slider(
                    label=label,
                    min_value=min_value,
                    max_value=max_value,
                    step=step,
                    value=slider_value,
                    key=transformation_key,
                    label_visibility="collapsed",
                    on_change=transformation_change,
                    args=(
                        category,
                        transformation,
                        transformation_key,
                    ),
                )
                values = inclusive_range(selected_range[0], selected_range[1], step)
                if decimals is not None:
                    values = np.array([round(v, decimals) for v in values])
                transform_params[category][transformation] = values

    with st.expander(f"All {category} Transformations", expanded=True):
        transformation_key = f"transformation_{category}"

        # Select which transformations to apply
        sel_transformations = st.session_state["project_dct"]["transformations"][
            category
        ].get(transformation_key, [])

        # Reset the saved selection if it contains an option that is no longer
        # offered for this category
        if any(
            saved not in transformation_options[category]
            for saved in sel_transformations
        ):
            st.session_state["project_dct"]["transformations"][category][
                transformation_key
            ] = []
            sel_transformations = []

        transformations_to_apply = st.multiselect(
            label="Select transformations to apply",
            options=transformation_options[category],
            default=sel_transformations,
            key=transformation_key,
            on_change=transformations_to_apply_change,
            args=(
                category,
                transformation_key,
            ),
        )

        # Determine the number of transformations to put in each column
        # (the first column takes the extra one when the count is odd)
        transformations_per_column = (
            len(transformations_to_apply) // 2 + len(transformations_to_apply) % 2
        )

        # Create two columns
        col1, col2 = st.columns(2)

        # Assign transformations to each column
        transformations_col1 = transformations_to_apply[:transformations_per_column]
        transformations_col2 = transformations_to_apply[transformations_per_column:]

        # Create widgets in each column
        create_transformation_widgets(col1, transformations_col1)
        create_transformation_widgets(col2, transformations_col2)
# Define a helper function to create widgets for each specific transformation | |
def create_specific_transformation_widgets(
    column,
    transformations,
    channel_name,
    date_granularity,
    specific_transform_params,
):
    """Render channel-specific transformation sliders inside ``column``.

    For every entry of ``transformations`` a range slider is shown; the chosen
    range is expanded into an array of parameter values and written to
    ``specific_transform_params[channel_name][transformation]``.

    Args:
        column: Streamlit container used as a context manager for the widgets.
        transformations: Names of the transformations to render sliders for.
        channel_name: Channel the settings belong to; also part of widget keys.
        date_granularity: Text shown next to the period-based transformations
            (Lead, Lag, Moving Average) in the slider header.
        specific_transform_params: Mapping of channel -> {transformation:
            values}; updated in place. The entry for ``channel_name`` must
            already exist.
    """
    # Slider settings per transformation:
    # (header suffix, slider label, min, max, step, decimals to round to or None)
    slider_specs = {
        "Lead": (
            f" ({date_granularity})",
            "Lead periods",
            lead_min_value,
            lead_max_value,
            lead_step,
            None,
        ),
        "Lag": (
            f" ({date_granularity})",
            "Lag periods",
            lag_min_value,
            lag_max_value,
            lag_step,
            None,
        ),
        "Moving Average": (
            f" ({date_granularity})",
            "Window size for Moving Average",
            moving_average_min_value,
            moving_average_max_value,
            moving_average_step,
            None,
        ),
        "Saturation": (
            " (%)",
            "Saturation Percentage",
            saturation_min_value,
            saturation_max_value,
            saturation_step,
            None,
        ),
        "Power": (
            "",
            "Power",
            power_min_value,
            power_max_value,
            power_step,
            None,
        ),
        # Adstock values are rounded to 3 decimals to hide float noise
        "Adstock": (
            "",
            "Decay Factor",
            adstock_min_value,
            adstock_max_value,
            adstock_step,
            3,
        ),
    }

    def inclusive_range(start, end, step):
        """Return start, start + step, ... up to and including ``end``.

        ``np.arange`` with a non-integer step can yield one element beyond
        ``end`` because of floating-point rounding; anything more than half a
        step past ``end`` is discarded. Integer ranges are unaffected.
        """
        values = np.arange(start, end + step, step)
        return values[values <= end + step / 2]

    with column:
        for transformation in transformations:
            transformation_key = f"{transformation}_{channel_name}_specific"

            # Make sure a per-channel settings dict exists for this transformation
            specific_settings = st.session_state["project_dct"]["transformations"][
                "Specific"
            ]
            if transformation not in specific_settings:
                specific_settings[transformation] = {}

            # Previously saved range, falling back to the predefined default
            slider_value = specific_settings[transformation].get(
                channel_name, predefined_defaults[transformation]
            )

            spec = slider_specs.get(transformation)
            if spec is None:
                continue
            suffix, label, min_value, max_value, step, decimals = spec

            st.markdown(f"**{transformation}{suffix}**")
            selected_range = st.slider(
                label=label,
                min_value=min_value,
                max_value=max_value,
                step=step,
                value=slider_value,
                key=transformation_key,
                label_visibility="collapsed",
                on_change=transformation_specific_change,
                args=(
                    channel_name,
                    transformation,
                    transformation_key,
                ),
            )
            values = inclusive_range(selected_range[0], selected_range[1], step)
            if decimals is not None:
                values = np.array([round(v, decimals) for v in values])
            specific_transform_params[channel_name][transformation] = values
# Function to apply Lag transformation | |
def apply_lag(df, lag):
    """Return ``df`` shifted by ``lag`` positions via ``shift``.

    With a positive ``lag`` the values move to later rows and the vacated
    leading rows become missing.
    """
    lagged = df.shift(periods=lag)
    return lagged
# Function to apply Lead transformation | |
def apply_lead(df, lead):
    """Return ``df`` shifted by ``-lead`` positions via ``shift``.

    With a positive ``lead`` the values move to earlier rows and the vacated
    trailing rows become missing.
    """
    offset = -lead
    return df.shift(periods=offset)
# Function to apply Moving Average transformation | |
def apply_moving_average(df, window_size):
    """Return the rolling mean of ``df`` over ``window_size`` rows.

    Rows that do not yet have a full window are missing in the result.
    """
    rolling_window = df.rolling(window=window_size)
    return rolling_window.mean()
# Function to apply Saturation transformation | |
def apply_saturation(df, saturation_percent_100):
    """Apply a saturation curve to the values of ``df``.

    ``saturation_percent_100`` is clamped to [0.01, 99.99] and converted to a
    fraction. Constant data is simply multiplied by that fraction. Otherwise a
    saturation point is derived from the data's minimum and maximum, a
    steepness is computed from two logarithms, and each value ``x`` becomes
    ``x / (1 + (saturation_point / x) ** steepness)`` (zeros are replaced by
    ``1e-9`` inside the ratio only).

    NOTE(review): the steepness uses ``log(saturation_point / max)``; this
    looks like it assumes positive data — confirm with callers.
    """
    # Percent -> fraction, kept strictly inside (0, 1)
    clamped_percent = max(saturation_percent_100, 0.01)
    clamped_percent = min(clamped_percent, 99.99)
    fraction = clamped_percent / 100.0

    highest = df.max()
    lowest = df.min()

    # Constant data: scale directly
    if lowest == highest:

        def scale(x):
            return x * fraction

        return df.apply(scale)

    # Saturation point derived from the data range
    saturation_point = (lowest + fraction * highest) / 2

    # Steepness of the saturation curve
    steepness = np.log((1 / fraction) - 1) / np.log(saturation_point / highest)

    def saturate(x):
        # Avoid dividing by zero for zero-valued observations
        divisor = x if x != 0 else 1e-9
        return (1 / (1 + (saturation_point / divisor) ** steepness)) * x

    return df.apply(saturate)
# Function to apply Power transformation | |
def apply_power(df, power):
    """Return ``df`` raised element-wise to the exponent ``power``."""
    exponent = power
    powered = df ** exponent
    return powered
# Function to apply Adstock transformation | |
def apply_adstock(df, factor):
    """Return the adstocked version of ``df`` as a Series on the same index.

    Each output value is the previous output multiplied by ``factor`` plus the
    current input value, starting from a carry-over of 0.
    """
    carry_over = 0
    adstocked = []
    for value in df:
        carry_over = carry_over * factor + value
        adstocked.append(carry_over)
    return pd.Series(adstocked, index=df.index)
# Function to generate transformed columns names | |
def generate_transformed_columns(
    original_columns, transform_params, specific_transform_params
):
    """Build the transformed column names and a readable summary per column.

    For every column of every category the applicable parameters are taken
    from ``specific_transform_params[column]`` when that entry exists and is
    non-empty, otherwise from ``transform_params[category]``. Each
    (transformation, value) pair yields the name
    ``"<column>@<transformation>_<value>"``.

    Args:
        original_columns: Mapping of category -> list of column names.
        transform_params: Mapping of category -> {transformation: values}.
        specific_transform_params: Mapping of column -> {transformation: values}.

    Returns:
        Tuple ``(transformed_columns, summary_string)``: a mapping of column ->
        list of generated names, and a numbered, newline-separated summary.
    """

    def describe(values):
        # "a, b and c" for several values, plain "a" for a single one
        if len(values) > 1:
            return ", ".join(map(str, values[:-1])) + " and " + str(values[-1])
        return str(values[0])

    def collect(column, params, names, details):
        # Append generated names and one summary entry per transformation
        for transformation, values in params.items():
            for value in values:
                names.append(f"{column}@{transformation}_{value}")
            details.append(f"{transformation} ({describe(values)})")

    transformed_columns = {}
    summary = {}

    for category, columns in original_columns.items():
        for column in columns:
            names = []
            transformed_columns[column] = names
            details = []  # Transformation details for the current column

            uses_specific = (
                column in specific_transform_params
                and len(specific_transform_params[column]) > 0
            )
            if uses_specific:
                collect(column, specific_transform_params[column], names, details)
            elif category not in transform_params:
                details = ["No transformation selected"]
            elif column not in specific_transform_params:
                # Columns listed as specific (even with no parameters yet) do
                # not receive the category-level transformations
                collect(column, transform_params[category], names, details)

            # Only add to summary if there are transformation details
            if details:
                # NOTE(review): the separator literal looks mis-encoded
                # (probably an arrow glyph) — confirm against the original file.
                joined = "โฎ ".join(details)
                # <strong> tags make the column name bold
                summary[column] = f"<strong>{column}</strong>: {joined}"

    # Numbered summary, one line per column
    numbered = [
        f"{position}. {details}"
        for position, details in enumerate(summary.values(), start=1)
    ]
    return transformed_columns, "\n".join(numbered)
# Function to transform Dataframe slice | |
def transform_slice(
    transform_params,
    transformation_functions,
    panel,
    df,
    df_slice,
    category,
    category_df,
):
    """Apply the transformations configured for ``category`` to ``df_slice``.

    Transformations are chained: after each one, ``df_slice`` is rebuilt from
    the panel columns of ``df`` plus the columns produced so far, so the next
    transformation operates on the previous output. New columns are named
    ``"<column>@<transformation>_<parameter>"`` and missing values are
    replaced by 0.

    With panel columns the transformation runs per panel group and
    ``category_df`` is replaced by the latest output; without them the new
    columns are appended to ``category_df``.

    Returns:
        Tuple ``(category_df, df, df_slice)``.
    """
    grouped_by_panel = len(panel) > 0

    for transformation, parameters in transform_params[category].items():
        func = transformation_functions[transformation]

        if grouped_by_panel:
            # One transformed frame per parameter, computed within each panel
            per_parameter = []
            for p in parameters:
                transformed = df_slice.groupby(panel).transform(func, p)
                per_parameter.append(
                    transformed.add_suffix(f"@{transformation}_{p}")
                )
            category_df = pd.concat(per_parameter, axis=1)
        else:
            for p in parameters:
                # Transform every column, then tag the column names
                transformed = df_slice.apply(lambda col: func(col, p), axis=0)
                transformed = transformed.rename(
                    columns=lambda name: f"{name}@{transformation}_{p}"
                )
                category_df = pd.concat([category_df, transformed], axis=1)

        # Replace all NaN or null values with 0
        category_df.fillna(0, inplace=True)

        # Input for the next transformation in the chain
        df_slice = pd.concat([df[panel], category_df], axis=1)

    return category_df, df, df_slice
# Function to apply transformations to DataFrame slices based on specified categories and parameters | |
def apply_category_transformations(
    df_main, bin_dict, transform_params, panel, specific_transform_params
):
    """Apply category-level and channel-specific transformations to ``df_main``.

    Category-level parameters are applied to every column of the category
    except the channels listed in ``specific_transform_params``; those
    channels are transformed individually with their own parameters. Because
    transformations are chained, columns from intermediate steps can be
    produced as well; only the deepest chain per original column is kept.

    Args:
        df_main: Source dataframe; it is not modified.
        bin_dict: Mapping of category -> list of column names.
        transform_params: Mapping of category -> {transformation: values}.
        panel: List of panel column names (empty when there is no panel).
        specific_transform_params: Mapping of channel -> {transformation: values}.

    Returns:
        A dataframe holding the original columns plus the fully transformed
        columns.
    """
    # Dictionary for function mapping
    transformation_functions = {
        "Lead": apply_lead,
        "Lag": apply_lag,
        "Moving Average": apply_moving_average,
        "Saturation": apply_saturation,
        "Power": apply_power,
        "Adstock": apply_adstock,
    }

    # List to collect all transformed DataFrames
    transformed_dfs = []

    # Category-level transformations
    for category in ["Media", "Exogenous", "Internal"]:
        if (
            category not in transform_params
            or category not in bin_dict
            or not transform_params[category]
        ):
            continue  # Skip categories without transformations

        # Columns of this category, minus channels handled individually below
        df_slice = df_main[bin_dict[category] + panel].copy()
        df_slice = df_slice.drop(
            columns=list(specific_transform_params.keys()), errors="ignore"
        )

        category_df, _, _ = transform_slice(
            transform_params.copy(),
            transformation_functions.copy(),
            panel,
            df_main.copy(),
            df_slice,
            category,
            pd.DataFrame(),
        )

        # Keep the transformed category DataFrame if it is not empty
        if not category_df.empty:
            transformed_dfs.append(category_df)

    # Apply channel specific transforms
    for channel_specific in specific_transform_params:
        df_slice_specific = df_main[[channel_specific] + panel].copy()
        # transform_slice looks parameters up by category; "Media" is only
        # used as the lookup key for this single-channel call
        transform_params_specific = {
            "Media": specific_transform_params[channel_specific]
        }

        category_df, _, _ = transform_slice(
            transform_params_specific,
            transformation_functions.copy(),
            panel,
            df_main.copy(),
            df_slice_specific,
            "Media",
            pd.DataFrame(),
        )

        # Keep the transformed DataFrame if it is not empty
        if not category_df.empty:
            transformed_dfs.append(category_df)

    # Combine the original data with everything that was generated
    if len(transformed_dfs) > 0:
        final_df = pd.concat([df_main] + transformed_dfs, axis=1)
    else:
        # If no transformations were applied, use the original DataFrame
        final_df = df_main

    # Deepest transformation chain (number of '@' separators) per base column.
    # Base names are compared for equality: a prefix match would wrongly tie
    # e.g. "paid@..." to "paid_search@...".
    max_depth = {}
    for col in final_df.columns:
        if "@" in col:
            base_name = col.split("@")[0]
            max_depth[base_name] = max(col.count("@"), max_depth.get(base_name, 0))

    # Intermediate columns are those with a shorter chain than the deepest one
    columns_to_drop = [
        col
        for col in final_df.columns
        if "@" in col and col.count("@") < max_depth[col.split("@")[0]]
    ]

    # Non in-place drop so that df_main is never modified through final_df
    return final_df.drop(columns=columns_to_drop)
# Function to infers the granularity of the date column in a DataFrame | |
def infer_date_granularity(df):
    """Infer the granularity of the ``date`` column of ``df``.

    The most common gap, in days, between consecutive unique dates decides the
    result: 1 -> "daily", 7 -> "weekly", 28 to 31 -> "monthly", anything else
    -> "irregular". Fewer than two unique dates also yield "irregular".

    Args:
        df: DataFrame with a datetime-like ``date`` column.

    Returns:
        One of "daily", "weekly", "monthly" or "irregular".
    """
    # Sort first so the gaps are measured between neighbouring dates even when
    # the rows are not in chronological order
    unique_dates = pd.Series(df["date"].unique()).sort_values()
    day_gaps = unique_dates.diff().dt.days.dropna()

    # No gap can be measured with fewer than two unique dates
    if day_gaps.empty:
        return "irregular"

    # Find the most common difference
    common_freq = day_gaps.mode()[0]

    # Map the most common difference to a granularity
    if common_freq == 1:
        return "daily"
    if common_freq == 7:
        return "weekly"
    if 28 <= common_freq <= 31:
        return "monthly"
    return "irregular"
# Function to clean display DataFrame
def clean_display_df(df, display_max_col=500):
    """Prepare a DataFrame for on-screen display.

    Rows are ordered by 'panel' then 'date' (missing values first), repeated
    column names are reduced to their first occurrence, and the columns are
    arranged as 'date', 'panel', followed by the remaining columns in sorted
    order. When the de-duplicated frame is wider than ``display_max_col``,
    only the first ``display_max_col - 2`` remaining columns are kept.

    Args:
        df: DataFrame containing at least 'date' and 'panel' columns.
        display_max_col: Maximum number of columns to show.

    Returns:
        Tuple of (display DataFrame, flag telling whether the de-duplicated
        frame had more than ``display_max_col`` columns).
    """
    leading_cols = ["date", "panel"]

    # Order rows by panel and date, then keep one column per name
    ordered = df.sort_values(
        by=["panel", "date"], ascending=True, na_position="first"
    )
    keep_mask = ~ordered.columns.duplicated()
    ordered = ordered.loc[:, keep_mask]

    # Everything except the two leading columns, in sorted order
    trailing_cols = ordered.columns.difference(leading_cols).tolist()

    too_wide = ordered.shape[1] > display_max_col
    if too_wide:
        # 'date' and 'panel' already take two of the allowed slots
        trailing_cols = trailing_cols[: display_max_col - 2]

    shown = ordered[leading_cols + trailing_cols]

    # Return the display DataFrame and whether the column limit was exceeded
    return shown, too_wide
######################################################################################################################################################### | |
# User input for transformations | |
######################################################################################################################################################### | |
# Main page flow: collect the transformation choices, apply them on request,
# show the transformed data with a correlation plot, and persist the result.
# Any failure is logged with its traceback and reported to the user as a
# generic warning.
try:
    # Page Title
    st.title("AI Model Transformations")

    # Infer date granularity
    date_granularity = infer_date_granularity(final_df_loaded)

    # Initialize the main dictionary to store the transformation parameters for each category
    transform_params = {"Media": {}, "Internal": {}, "Exogenous": {}}

    st.markdown("### Select Transformations to Apply")

    with st.expander("Specific Media Transformations"):
        # Saved channel-specific selections (same dict object as in session state)
        specific_saved = st.session_state["project_dct"]["transformations"]["Specific"]

        # A project may have no Media columns at all; fall back to no options
        media_options = bin_dict_loaded.get("Media", [])

        # Previously selected channels
        sel_channel_specific = specific_saved.get("channel_select_specific", [])

        # Reset default selected channels list if options are changed
        if any(channel not in media_options for channel in sel_channel_specific):
            specific_saved["channel_select_specific"] = []
            sel_channel_specific = []

        select_specific_channels = st.multiselect(
            label="Select channel variable",
            default=sel_channel_specific,
            options=media_options,
            key="channel_select_specific",
            on_change=channel_select_specific_change,
            max_selections=30,
        )

        # Transformations offered for an individual channel
        transformations_options = [
            "Lag",
            "Moving Average",
            "Saturation",
            "Power",
            "Adstock",
        ]

        specific_transform_params = {}
        for select_specific_channel in select_specific_channels:
            specific_transform_params[select_specific_channel] = {}

            st.divider()
            channel_name = str(select_specific_channel).replace("_", " ").title()
            st.markdown(f"###### {channel_name}")

            specific_transformation_key = (
                f"specific_transformation_{select_specific_channel}_Media"
            )

            # Previously selected transformations for this channel
            sel_transformations = specific_saved.get(specific_transformation_key, [])

            # Reset the saved transformations if any is no longer offered.
            # The local default must be cleared too, otherwise the stale list
            # is still handed to the multiselect below as its default.
            if any(
                transformation not in transformations_options
                for transformation in sel_transformations
            ):
                specific_saved[specific_transformation_key] = []
                sel_transformations = []

            transformations_to_apply = st.multiselect(
                label="Select transformations to apply",
                options=transformations_options,
                default=sel_transformations,
                key=specific_transformation_key,
                on_change=specific_transformation_change,
                args=(specific_transformation_key,),
            )

            # Determine the number of transformations to put in each column
            # (the first column takes the extra one when the count is odd)
            transformations_per_column = (
                len(transformations_to_apply) // 2 + len(transformations_to_apply) % 2
            )

            # Create two columns
            col1, col2 = st.columns(2)

            # Assign transformations to each column
            transformations_col1 = transformations_to_apply[:transformations_per_column]
            transformations_col2 = transformations_to_apply[transformations_per_column:]

            # Create widgets in each column
            create_specific_transformation_widgets(
                col1,
                transformations_col1,
                select_specific_channel,
                date_granularity,
                specific_transform_params,
            )
            create_specific_transformation_widgets(
                col2,
                transformations_col2,
                select_specific_channel,
                date_granularity,
                specific_transform_params,
            )

    # Create Widgets
    for category in ["Media", "Internal", "Exogenous"]:
        # Skip Internal
        if category == "Internal":
            continue

        # Skip category if no column available
        elif (
            category not in bin_dict_loaded.keys()
            or len(bin_dict_loaded[category]) == 0
        ):
            st.info(
                f"{str(category).title()} category has no column associated with it. Skipping transformation step for this category.",
                icon="💬",
            )
            continue

        transformation_widgets(category, transform_params, date_granularity)

    #########################################################################################################################################################
    # Apply transformations
    #########################################################################################################################################################

    # Reset transformation selection to default
    button_col = st.columns(2)
    with button_col[1]:
        if st.button("Reset to Default", use_container_width=True):
            st.session_state["project_dct"]["transformations"]["Media"] = {}
            st.session_state["project_dct"]["transformations"]["Exogenous"] = {}
            st.session_state["project_dct"]["transformations"]["Internal"] = {}
            st.session_state["project_dct"]["transformations"]["Specific"] = {}

            # Log message
            log_message(
                "info",
                "All persistent selections have been reset to their default settings and cleared.",
                "Transformations",
            )

            st.rerun()

    # Apply category-based transformations to the DataFrame
    with button_col[0]:
        if st.button("Accept and Proceed", use_container_width=True):
            with st.spinner("Applying transformations ..."):
                final_df = apply_category_transformations(
                    final_df_loaded.copy(),
                    bin_dict_loaded.copy(),
                    transform_params.copy(),
                    panel.copy(),
                    specific_transform_params.copy(),
                )

                # Build the transformation summary; the column mapping returned
                # alongside it is not used on this page
                _, summary_string = generate_transformed_columns(
                    original_columns, transform_params, specific_transform_params
                )

                # Store into transformed dataframe and summary session state
                st.session_state["project_dct"]["transformations"][
                    "final_df"
                ] = final_df
                st.session_state["project_dct"]["transformations"][
                    "summary_string"
                ] = summary_string

                # Display success message
                st.success("Transformation of the DataFrame is successful!", icon="✅")

                # Log message
                log_message(
                    "info",
                    "Transformation of the DataFrame is successful!",
                    "Transformations",
                )

    #########################################################################################################################################################
    # Display the transformed DataFrame and summary
    #########################################################################################################################################################

    # Display the transformed DataFrame in the Streamlit app
    st.markdown("### Transformed DataFrame")
    with st.spinner("Please wait while the transformed DataFrame is loading ..."):
        final_df = st.session_state["project_dct"]["transformations"]["final_df"].copy()

        # Clean display DataFrame
        display_df, exceeds_max_col = clean_display_df(final_df, display_max_col)

        # Check the number of columns and show only the first display_max_col if there are more
        if exceeds_max_col:
            # Display an info message if the DataFrame has more than display_max_col columns
            st.info(
                f"The transformed DataFrame has more than {display_max_col} columns. Displaying only the first {display_max_col} columns.",
                icon="💬",
            )

        # Display Final DataFrame
        st.dataframe(
            display_df,
            hide_index=True,
            column_config={
                "date": st.column_config.DateColumn("date", format="YYYY-MM-DD")
            },
        )

    # Total rows and columns
    total_rows, total_columns = st.session_state["project_dct"]["transformations"][
        "final_df"
    ].shape
    st.markdown(
        f"<p style='text-align: justify;'>The transformed DataFrame contains <strong>{total_rows}</strong> rows and <strong>{total_columns}</strong> columns.</p>",
        unsafe_allow_html=True,
    )

    # Display the summary of transformations as markdown
    if (
        "summary_string" in st.session_state["project_dct"]["transformations"]
        and st.session_state["project_dct"]["transformations"]["summary_string"]
    ):
        with st.expander("Summary of Transformations"):
            st.markdown("### Summary of Transformations")
            st.markdown(
                st.session_state["project_dct"]["transformations"][
                    "summary_string"
                ],
                unsafe_allow_html=True,
            )

    #########################################################################################################################################################
    # Correlation Plot
    #########################################################################################################################################################

    # Filter out the 'date' and 'panel' columns
    variables = [
        col for col in final_df.columns if col.lower() not in ["date", "panel"]
    ]

    # Keep only saved selections that still exist among the current columns,
    # so a selection saved for an earlier set of transformations cannot be
    # passed as a default that is missing from the options
    saved_corr_selection = (
        st.session_state["project_dct"]["transformations"].get(
            "correlation_plot_selection"
        )
        or []
    )
    default_corr_selection = [var for var in saved_corr_selection if var in variables]

    with st.expander("Transformed Variable Correlation Plot"):
        selected_vars = st.multiselect(
            label="Choose variables for correlation plot:",
            options=variables,
            max_selections=30,
            default=default_corr_selection,
            key="correlation_plot_key",
        )

        # Calculate correlation
        if selected_vars:
            corr_df = final_df[selected_vars].corr()

            # Prepare text annotations with 2 decimal places
            annotations = [
                go.layout.Annotation(
                    text=f"{corr_df.iloc[i, j]:.2f}",
                    x=col_label,
                    y=row_label,
                    showarrow=False,
                    font=dict(color="black"),
                )
                for i, row_label in enumerate(corr_df.index)
                for j, col_label in enumerate(corr_df.columns)
            ]

            # Plotly correlation plot using go
            heatmap = go.Heatmap(
                z=corr_df.values,
                x=corr_df.columns,
                y=corr_df.index,
                colorscale="RdBu",
                zmin=-1,
                zmax=1,
            )

            layout = go.Layout(
                title="Transformed Variable Correlation Plot",
                xaxis=dict(title="Variables"),
                yaxis=dict(title="Variables"),
                width=1000,
                height=1000,
                annotations=annotations,
            )

            fig = go.Figure(data=[heatmap], layout=layout)

            st.plotly_chart(fig)
        else:
            st.write("Please select at least one variable to plot.")

    #########################################################################################################################################################
    # Accept and Save
    #########################################################################################################################################################

    # Check for saved model
    if (
        retrieve_pkl_object(
            st.session_state["project_number"], "Model_Build", "best_models", schema
        )
        is not None
    ):  # db
        st.warning(
            "Saving transformations will overwrite existing ones and delete all saved models. To keep previous models, please start a new project.",
            icon="⚠️",
        )

    if st.button("Accept and Save", use_container_width=True):
        with st.spinner("Saving Changes"):
            # Update correlation plot selection
            st.session_state["project_dct"]["transformations"][
                "correlation_plot_selection"
            ] = st.session_state["correlation_plot_key"]

            # Clear model metadata
            clear_pages()

            # Update DB
            update_db(
                prj_id=st.session_state["project_number"],
                page_nam="Transformations",
                file_nam="project_dct",
                pkl_obj=pickle.dumps(st.session_state["project_dct"]),
                schema=schema,
            )

            # Clear data from DB
            delete_entries(
                st.session_state["project_number"],
                ["Model_Build", "Model_Tuning"],
                db_cred,
                schema,
            )

            # Success message
            st.success("Saved Successfully!", icon="💾")
            st.toast("Saved Successfully!", icon="💾")

            # Log message
            log_message("info", "Saved Successfully!", "Transformations")

except Exception:
    # Capture the error details (full traceback of the active exception)
    error_message = traceback.format_exc()

    # Log message
    log_message("error", f"An error occurred: {error_message}.", "Transformations")

    # Display a warning message
    st.warning(
        "Oops! Something went wrong. Please try refreshing the tool or creating a new project.",
        icon="⚠️",
    )