# Importing necessary libraries import streamlit as st st.set_page_config( page_title="AI Model Transformations", page_icon="⚖️", layout="wide", initial_sidebar_state="collapsed", ) import sys import pickle import traceback import numpy as np import pandas as pd import plotly.graph_objects as go from post_gres_cred import db_cred from log_application import log_message from utilities import ( set_header, load_local_css, update_db, project_selection, delete_entries, retrieve_pkl_object, ) from constants import ( predefined_defaults, lead_min_value, lead_max_value, lead_step, lag_min_value, lag_max_value, lag_step, moving_average_min_value, moving_average_max_value, moving_average_step, saturation_min_value, saturation_max_value, saturation_step, power_min_value, power_max_value, power_step, adstock_min_value, adstock_max_value, adstock_step, display_max_col, ) schema = db_cred["schema"] load_local_css("styles.css") set_header() # Initialize project name session state if "project_name" not in st.session_state: st.session_state["project_name"] = None # Fetch project dictionary if "project_dct" not in st.session_state: project_selection() st.stop() # Display Username and Project Name if "username" in st.session_state and st.session_state["username"] is not None: cols1 = st.columns([2, 1]) with cols1[0]: st.markdown(f"**Welcome {st.session_state['username']}**") with cols1[1]: st.markdown(f"**Current Project: {st.session_state['project_name']}**") # Load saved data from project dictionary if st.session_state["project_dct"]["data_import"]["imputed_tool_df"] is None: st.warning( "The data import is incomplete. Please go back to the Data Import page and complete the save.", icon="🔙", ) # Log message log_message( "warning", "The data import is incomplete. Please go back to the Data Import page and complete the save.", "Transformations", ) st.stop() else: final_df_loaded = st.session_state["project_dct"]["data_import"][ "imputed_tool_df" ].copy() bin_dict_loaded = st.session_state["project_dct"]["data_import"][ "category_dict" ].copy() unique_panels = st.session_state["project_dct"]["data_import"][ "unique_panels" ].copy() # Initialize project dictionary data if st.session_state["project_dct"]["transformations"]["final_df"] is None: st.session_state["project_dct"]["transformations"][ "final_df" ] = final_df_loaded # Default as original dataframe # Extract original columns for specified categories original_columns = { category: bin_dict_loaded[category] for category in ["Media", "Internal", "Exogenous"] if category in bin_dict_loaded } # Retrive Panel columns panel = ["panel"] if len(unique_panels) > 1 else [] # Function to clear model metadata def clear_pages(): # Reset Pages st.session_state["project_dct"]["model_build"] = { "sel_target_col": None, "all_iters_check": False, "iterations": 0, "build_button": False, "show_results_check": False, "session_state_saved": {}, } st.session_state["project_dct"]["model_tuning"] = { "sel_target_col": None, "sel_model": {}, "flag_expander": False, "start_date_default": None, "end_date_default": None, "repeat_default": "No", "flags": {}, "select_all_flags_check": {}, "selected_flags": {}, "trend_check": False, "week_num_check": False, "sine_cosine_check": False, "session_state_saved": {}, } st.session_state["project_dct"]["saved_model_results"] = { "selected_options": None, "model_grid_sel": [1], } if "model_results_df" in st.session_state: del st.session_state["model_results_df"] if "model_results_data" in st.session_state: del st.session_state["model_results_data"] if "coefficients_df" in st.session_state: del st.session_state["coefficients_df"] # Function to update transformation change def transformation_change(category, transformation, key): st.session_state["project_dct"]["transformations"][category][transformation] = ( st.session_state[key] ) # Function to update specific transformation change def transformation_specific_change(channel_name, transformation, key): st.session_state["project_dct"]["transformations"]["Specific"][transformation][ channel_name ] = st.session_state[key] # Function to update transformations to apply change def transformations_to_apply_change(category, key): st.session_state["project_dct"]["transformations"][category][key] = ( st.session_state[key] ) # Function to update channel select specific change def channel_select_specific_change(): st.session_state["project_dct"]["transformations"]["Specific"][ "channel_select_specific" ] = st.session_state["channel_select_specific"] # Function to update specific transformation change def specific_transformation_change(specific_transformation_key): st.session_state["project_dct"]["transformations"]["Specific"][ specific_transformation_key ] = st.session_state[specific_transformation_key] # Function to build transformation widgets def transformation_widgets(category, transform_params, date_granularity): # Transformation Options transformation_options = { "Media": [ "Lag", "Moving Average", "Saturation", "Power", "Adstock", ], "Internal": ["Lead", "Lag", "Moving Average"], "Exogenous": ["Lead", "Lag", "Moving Average"], } # Define a helper function to create widgets for each transformation def create_transformation_widgets(column, transformations): with column: for transformation in transformations: transformation_key = f"{transformation}_{category}" slider_value = st.session_state["project_dct"]["transformations"][ category ].get(transformation, predefined_defaults[transformation]) # Conditionally create widgets for selected transformations if transformation == "Lead": st.markdown(f"**{transformation} ({date_granularity})**") lead = st.slider( label="Lead periods", min_value=lead_min_value, max_value=lead_max_value, step=lead_step, value=slider_value, key=transformation_key, label_visibility="collapsed", on_change=transformation_change, args=( category, transformation, transformation_key, ), ) start = lead[0] end = lead[1] step = lead_step transform_params[category][transformation] = np.arange( start, end + step, step ) if transformation == "Lag": st.markdown(f"**{transformation} ({date_granularity})**") lag = st.slider( label="Lag periods", min_value=lag_min_value, max_value=lag_max_value, step=lag_step, value=slider_value, key=transformation_key, label_visibility="collapsed", on_change=transformation_change, args=( category, transformation, transformation_key, ), ) start = lag[0] end = lag[1] step = lag_step transform_params[category][transformation] = np.arange( start, end + step, step ) if transformation == "Moving Average": st.markdown(f"**{transformation} ({date_granularity})**") window = st.slider( label="Window size for Moving Average", min_value=moving_average_min_value, max_value=moving_average_max_value, step=moving_average_step, value=slider_value, key=transformation_key, label_visibility="collapsed", on_change=transformation_change, args=( category, transformation, transformation_key, ), ) start = window[0] end = window[1] step = moving_average_step transform_params[category][transformation] = np.arange( start, end + step, step ) if transformation == "Saturation": st.markdown(f"**{transformation} (%)**") saturation_point = st.slider( label="Saturation Percentage", min_value=saturation_min_value, max_value=saturation_max_value, step=saturation_step, value=slider_value, key=transformation_key, label_visibility="collapsed", on_change=transformation_change, args=( category, transformation, transformation_key, ), ) start = saturation_point[0] end = saturation_point[1] step = saturation_step transform_params[category][transformation] = np.arange( start, end + step, step ) if transformation == "Power": st.markdown(f"**{transformation}**") power = st.slider( label="Power", min_value=power_min_value, max_value=power_max_value, step=power_step, value=slider_value, key=transformation_key, label_visibility="collapsed", on_change=transformation_change, args=( category, transformation, transformation_key, ), ) start = power[0] end = power[1] step = power_step transform_params[category][transformation] = np.arange( start, end + step, step ) if transformation == "Adstock": st.markdown(f"**{transformation}**") rate = st.slider( label="Decay Factor", min_value=adstock_min_value, max_value=adstock_max_value, step=adstock_step, value=slider_value, key=transformation_key, label_visibility="collapsed", on_change=transformation_change, args=( category, transformation, transformation_key, ), ) start = rate[0] end = rate[1] step = adstock_step adstock_range = [ round(a, 3) for a in np.arange(start, end + step, step) ] transform_params[category][transformation] = np.array(adstock_range) with st.expander(f"All {category} Transformations", expanded=True): transformation_key = f"transformation_{category}" # Select which transformations to apply sel_transformations = st.session_state["project_dct"]["transformations"][ category ].get(transformation_key, []) # Reset default selected channels list if options are changed for channel in sel_transformations: if channel not in transformation_options[category]: ( st.session_state["project_dct"]["transformations"][category][ transformation_key ], sel_transformations, ) = ([], []) transformations_to_apply = st.multiselect( label="Select transformations to apply", options=transformation_options[category], default=sel_transformations, key=transformation_key, on_change=transformations_to_apply_change, args=( category, transformation_key, ), ) # Determine the number of transformations to put in each column transformations_per_column = ( len(transformations_to_apply) // 2 + len(transformations_to_apply) % 2 ) # Create two columns col1, col2 = st.columns(2) # Assign transformations to each column transformations_col1 = transformations_to_apply[:transformations_per_column] transformations_col2 = transformations_to_apply[transformations_per_column:] # Create widgets in each column create_transformation_widgets(col1, transformations_col1) create_transformation_widgets(col2, transformations_col2) # Define a helper function to create widgets for each specific transformation def create_specific_transformation_widgets( column, transformations, channel_name, date_granularity, specific_transform_params, ): with column: for transformation in transformations: transformation_key = f"{transformation}_{channel_name}_specific" if ( transformation not in st.session_state["project_dct"]["transformations"]["Specific"] ): st.session_state["project_dct"]["transformations"]["Specific"][ transformation ] = {} slider_value = st.session_state["project_dct"]["transformations"][ "Specific" ][transformation].get(channel_name, predefined_defaults[transformation]) # Conditionally create widgets for selected transformations if transformation == "Lead": st.markdown(f"**Lead ({date_granularity})**") lead = st.slider( label="Lead periods", min_value=lead_min_value, max_value=lead_max_value, step=lead_step, value=slider_value, key=transformation_key, label_visibility="collapsed", on_change=transformation_specific_change, args=( channel_name, transformation, transformation_key, ), ) start = lead[0] end = lead[1] step = lead_step specific_transform_params[channel_name]["Lead"] = np.arange( start, end + step, step ) if transformation == "Lag": st.markdown(f"**Lag ({date_granularity})**") lag = st.slider( label="Lag periods", min_value=lag_min_value, max_value=lag_max_value, step=lag_step, value=slider_value, key=transformation_key, label_visibility="collapsed", on_change=transformation_specific_change, args=( channel_name, transformation, transformation_key, ), ) start = lag[0] end = lag[1] step = lag_step specific_transform_params[channel_name]["Lag"] = np.arange( start, end + step, step ) if transformation == "Moving Average": st.markdown(f"**Moving Average ({date_granularity})**") window = st.slider( label="Window size for Moving Average", min_value=moving_average_min_value, max_value=moving_average_max_value, step=moving_average_step, value=slider_value, key=transformation_key, label_visibility="collapsed", on_change=transformation_specific_change, args=( channel_name, transformation, transformation_key, ), ) start = window[0] end = window[1] step = moving_average_step specific_transform_params[channel_name]["Moving Average"] = np.arange( start, end + step, step ) if transformation == "Saturation": st.markdown("**Saturation (%)**") saturation_point = st.slider( label="Saturation Percentage", min_value=saturation_min_value, max_value=saturation_max_value, step=saturation_step, value=slider_value, key=transformation_key, label_visibility="collapsed", on_change=transformation_specific_change, args=( channel_name, transformation, transformation_key, ), ) start = saturation_point[0] end = saturation_point[1] step = saturation_step specific_transform_params[channel_name]["Saturation"] = np.arange( start, end + step, step ) if transformation == "Power": st.markdown("**Power**") power = st.slider( label="Power", min_value=power_min_value, max_value=power_max_value, step=power_step, value=slider_value, key=transformation_key, label_visibility="collapsed", on_change=transformation_specific_change, args=( channel_name, transformation, transformation_key, ), ) start = power[0] end = power[1] step = power_step specific_transform_params[channel_name]["Power"] = np.arange( start, end + step, step ) if transformation == "Adstock": st.markdown("**Adstock**") rate = st.slider( label="Decay Factor", min_value=adstock_min_value, max_value=adstock_max_value, step=adstock_step, value=slider_value, key=transformation_key, label_visibility="collapsed", on_change=transformation_specific_change, args=( channel_name, transformation, transformation_key, ), ) start = rate[0] end = rate[1] step = adstock_step adstock_range = [ round(a, 3) for a in np.arange(start, end + step, step) ] specific_transform_params[channel_name]["Adstock"] = np.array( adstock_range ) # Function to apply Lag transformation def apply_lag(df, lag): return df.shift(lag) # Function to apply Lead transformation def apply_lead(df, lead): return df.shift(-lead) # Function to apply Moving Average transformation def apply_moving_average(df, window_size): return df.rolling(window=window_size).mean() # Function to apply Saturation transformation def apply_saturation(df, saturation_percent_100): # Convert percentage to fraction saturation_percent = min(max(saturation_percent_100, 0.01), 99.99) / 100.0 # Get the maximum and minimum values column_max = df.max() column_min = df.min() # If the data is constant, scale it directly if column_min == column_max: return df.apply(lambda x: x * saturation_percent) # Compute the saturation point from the data range saturation_point = (column_min + saturation_percent * column_max) / 2 # Calculate steepness for the saturation curve numerator = np.log((1 / saturation_percent) - 1) denominator = np.log(saturation_point / column_max) steepness = numerator / denominator # Apply the saturation transformation transformed_series = df.apply( lambda x: (1 / (1 + (saturation_point / (x if x != 0 else 1e-9)) ** steepness)) * x ) return transformed_series # Function to apply Power transformation def apply_power(df, power): return df**power # Function to apply Adstock transformation def apply_adstock(df, factor): x = 0 # Use the walrus operator to update x iteratively with the Adstock formula adstock_var = [x := x * factor + v for v in df] ans = pd.Series(adstock_var, index=df.index) return ans # Function to generate transformed columns names @st.cache_resource(show_spinner=False) def generate_transformed_columns( original_columns, transform_params, specific_transform_params ): transformed_columns, summary = {}, {} for category, columns in original_columns.items(): for column in columns: transformed_columns[column] = [] summary_details = ( [] ) # List to hold transformation details for the current column if ( column in specific_transform_params.keys() and len(specific_transform_params[column]) > 0 ): for transformation, values in specific_transform_params[column].items(): # Generate transformed column names for each value for value in values: transformed_name = f"{column}@{transformation}_{value}" transformed_columns[column].append(transformed_name) # Format the values list as a string with commas and "and" before the last item if len(values) > 1: formatted_values = ( ", ".join(map(str, values[:-1])) + " and " + str(values[-1]) ) else: formatted_values = str(values[0]) # Add transformation details summary_details.append(f"{transformation} ({formatted_values})") else: if category in transform_params: for transformation, values in transform_params[category].items(): # Generate transformed column names for each value if column not in specific_transform_params.keys(): for value in values: transformed_name = f"{column}@{transformation}_{value}" transformed_columns[column].append(transformed_name) # Format the values list as a string with commas and "and" before the last item if len(values) > 1: formatted_values = ( ", ".join(map(str, values[:-1])) + " and " + str(values[-1]) ) else: formatted_values = str(values[0]) # Add transformation details summary_details.append( f"{transformation} ({formatted_values})" ) else: summary_details = ["No transformation selected"] # Only add to summary if there are transformation details for the column if summary_details: formatted_summary = "⮕ ".join(summary_details) # Use tags to make the column name bold summary[column] = f"{column}: {formatted_summary}" # Generate a comprehensive summary string for all columns summary_items = [ f"{idx + 1}. {details}" for idx, details in enumerate(summary.values()) ] summary_string = "\n".join(summary_items) return transformed_columns, summary_string # Function to transform Dataframe slice def transform_slice( transform_params, transformation_functions, panel, df, df_slice, category, category_df, ): # Iterate through each transformation and its parameters for the current category for transformation, parameters in transform_params[category].items(): transformation_function = transformation_functions[transformation] # Check if there is panel data to group by if len(panel) > 0: # Apply the transformation to each group category_df = pd.concat( [ df_slice.groupby(panel) .transform(transformation_function, p) .add_suffix(f"@{transformation}_{p}") for p in parameters ], axis=1, ) # Replace all NaN or null values in category_df with 0 category_df.fillna(0, inplace=True) # Update df_slice df_slice = pd.concat( [df[panel], category_df], axis=1, ) else: for p in parameters: # Apply the transformation function to each column temp_df = df_slice.apply( lambda x: transformation_function(x, p), axis=0 ).rename( lambda x: f"{x}@{transformation}_{p}", axis="columns", ) # Concatenate the transformed DataFrame slice to the category DataFrame category_df = pd.concat([category_df, temp_df], axis=1) # Replace all NaN or null values in category_df with 0 category_df.fillna(0, inplace=True) # Update df_slice df_slice = pd.concat( [df[panel], category_df], axis=1, ) return category_df, df, df_slice # Function to apply transformations to DataFrame slices based on specified categories and parameters @st.cache_resource(show_spinner=False) def apply_category_transformations( df_main, bin_dict, transform_params, panel, specific_transform_params ): # Dictionary for function mapping transformation_functions = { "Lead": apply_lead, "Lag": apply_lag, "Moving Average": apply_moving_average, "Saturation": apply_saturation, "Power": apply_power, "Adstock": apply_adstock, } # List to collect all transformed DataFrames transformed_dfs = [] # Iterate through each category specified in transform_params for category in ["Media", "Exogenous", "Internal"]: if ( category not in transform_params or category not in bin_dict or not transform_params[category] ): continue # Skip categories without transformations # Initialize category_df as an empty DataFrame category_df = pd.DataFrame() # Slice the DataFrame based on the columns specified in bin_dict for the current category df_slice = df_main[bin_dict[category] + panel].copy() # Drop the column from df_slice to skip specific transformations df_slice = df_slice.drop( columns=list(specific_transform_params.keys()), errors="ignore" ).copy() category_df, df, df_slice_updated = transform_slice( transform_params.copy(), transformation_functions.copy(), panel, df_main.copy(), df_slice.copy(), category, category_df.copy(), ) # Append the transformed category DataFrame to the list if it's not empty if not category_df.empty: transformed_dfs.append(category_df) # Apply channel specific transforms for channel_specific in specific_transform_params: # Initialize category_df as an empty DataFrame category_df = pd.DataFrame() df_slice_specific = df_main[[channel_specific] + panel].copy() transform_params_specific = { "Media": specific_transform_params[channel_specific] } category_df, df, df_slice_specific_updated = transform_slice( transform_params_specific.copy(), transformation_functions.copy(), panel, df_main.copy(), df_slice_specific.copy(), "Media", category_df.copy(), ) # Append the transformed category DataFrame to the list if it's not empty if not category_df.empty: transformed_dfs.append(category_df) # If category_df has been modified, concatenate it with the panel and response metrics from the original DataFrame if len(transformed_dfs) > 0: final_df = pd.concat([df_main] + transformed_dfs, axis=1) else: # If no transformations were applied, use the original DataFrame final_df = df_main # Find columns with '@' in their names columns_with_at = [col for col in final_df.columns if "@" in col] # Create a set of columns to drop columns_to_drop = set() # Iterate through columns with '@' to find shorter names to drop for col in columns_with_at: base_name = col.split("@")[0] for other_col in columns_with_at: if other_col.startswith(base_name) and len(other_col.split("@")) > len( col.split("@") ): columns_to_drop.add(col) break # Drop the identified columns from the DataFrame final_df.drop(columns=list(columns_to_drop), inplace=True) return final_df # Function to infers the granularity of the date column in a DataFrame @st.cache_resource(show_spinner=False) def infer_date_granularity(df): # Find the most common difference common_freq = pd.Series(df["date"].unique()).diff().dt.days.dropna().mode()[0] # Map the most common difference to a granularity if common_freq == 1: return "daily" elif common_freq == 7: return "weekly" elif 28 <= common_freq <= 31: return "monthly" else: return "irregular" # Function to clean display DataFrame @st.cache_data(show_spinner=False) def clean_display_df(df, display_max_col=500): # Sort by 'panel' and 'date' sort_columns = ["panel", "date"] sorted_df = df.sort_values(by=sort_columns, ascending=True, na_position="first") # Drop duplicate columns sorted_df = sorted_df.loc[:, ~sorted_df.columns.duplicated()] # Check if the DataFrame has more than display_max_col columns exceeds_max_col = sorted_df.shape[1] > display_max_col if exceeds_max_col: # Create a new DataFrame with 'date' and 'panel' at the start display_df = sorted_df[["date", "panel"]] # Add the next display_max_col - 2 columns (as 'date' and 'panel' already occupy 2 columns) additional_columns = sorted_df.columns.difference(["date", "panel"]).tolist()[ : display_max_col - 2 ] display_df = pd.concat([display_df, sorted_df[additional_columns]], axis=1) else: # Ensure 'date' and 'panel' are the first two columns in the final display DataFrame column_order = ["date", "panel"] + sorted_df.columns.difference( ["date", "panel"] ).tolist() display_df = sorted_df[column_order] # Return the display DataFrame and whether it exceeds 500 columns return display_df, exceeds_max_col ######################################################################################################################################################### # User input for transformations ######################################################################################################################################################### try: # Page Title st.title("AI Model Transformations") # Infer date granularity date_granularity = infer_date_granularity(final_df_loaded) # Initialize the main dictionary to store the transformation parameters for each category transform_params = {"Media": {}, "Internal": {}, "Exogenous": {}} st.markdown("### Select Transformations to Apply") with st.expander("Specific Media Transformations"): # Select which transformations to apply sel_channel_specific = st.session_state["project_dct"]["transformations"][ "Specific" ].get("channel_select_specific", []) # Reset default selected channels list if options are changed for channel in sel_channel_specific: if channel not in bin_dict_loaded["Media"]: ( st.session_state["project_dct"]["transformations"]["Specific"][ "channel_select_specific" ], sel_channel_specific, ) = ([], []) select_specific_channels = st.multiselect( label="Select channel variable", default=sel_channel_specific, options=bin_dict_loaded["Media"], key="channel_select_specific", on_change=channel_select_specific_change, max_selections=30, ) specific_transform_params = {} for select_specific_channel in select_specific_channels: specific_transform_params[select_specific_channel] = {} st.divider() channel_name = str(select_specific_channel).replace("_", " ").title() st.markdown(f"###### {channel_name}") specific_transformation_key = ( f"specific_transformation_{select_specific_channel}_Media" ) transformations_options = [ "Lag", "Moving Average", "Saturation", "Power", "Adstock", ] # Select which transformations to apply sel_transformations = st.session_state["project_dct"]["transformations"][ "Specific" ].get(specific_transformation_key, []) # Reset default selected channels list if options are changed for channel in sel_transformations: if channel not in transformations_options: ( st.session_state["project_dct"]["transformations"]["Specific"][ specific_transformation_key ], sel_channel_specific, ) = ([], []) transformations_to_apply = st.multiselect( label="Select transformations to apply", options=transformations_options, default=sel_transformations, key=specific_transformation_key, on_change=specific_transformation_change, args=(specific_transformation_key,), ) # Determine the number of transformations to put in each column transformations_per_column = ( len(transformations_to_apply) // 2 + len(transformations_to_apply) % 2 ) # Create two columns col1, col2 = st.columns(2) # Assign transformations to each column transformations_col1 = transformations_to_apply[:transformations_per_column] transformations_col2 = transformations_to_apply[transformations_per_column:] # Create widgets in each column create_specific_transformation_widgets( col1, transformations_col1, select_specific_channel, date_granularity, specific_transform_params, ) create_specific_transformation_widgets( col2, transformations_col2, select_specific_channel, date_granularity, specific_transform_params, ) # Create Widgets for category in ["Media", "Internal", "Exogenous"]: # Skip Internal if category == "Internal": continue # Skip category if no column available elif ( category not in bin_dict_loaded.keys() or len(bin_dict_loaded[category]) == 0 ): st.info( f"{str(category).title()} category has no column associated with it. Skipping transformation step for this category.", icon="💬", ) continue transformation_widgets(category, transform_params, date_granularity) ######################################################################################################################################################### # Apply transformations ######################################################################################################################################################### # Reset transformation selection to default button_col = st.columns(2) with button_col[1]: if st.button("Reset to Default", use_container_width=True): st.session_state["project_dct"]["transformations"]["Media"] = {} st.session_state["project_dct"]["transformations"]["Exogenous"] = {} st.session_state["project_dct"]["transformations"]["Internal"] = {} st.session_state["project_dct"]["transformations"]["Specific"] = {} # Log message log_message( "info", "All persistent selections have been reset to their default settings and cleared.", "Transformations", ) st.rerun() # Apply category-based transformations to the DataFrame with button_col[0]: if st.button("Accept and Proceed", use_container_width=True): with st.spinner("Applying transformations ..."): final_df = apply_category_transformations( final_df_loaded.copy(), bin_dict_loaded.copy(), transform_params.copy(), panel.copy(), specific_transform_params.copy(), ) # Generate a dictionary mapping original column names to lists of transformed column names transformed_columns_dict, summary_string = generate_transformed_columns( original_columns, transform_params, specific_transform_params ) # Store into transformed dataframe and summary session state st.session_state["project_dct"]["transformations"][ "final_df" ] = final_df st.session_state["project_dct"]["transformations"][ "summary_string" ] = summary_string # Display success message st.success("Transformation of the DataFrame is successful!", icon="✅") # Log message log_message( "info", "Transformation of the DataFrame is successful!", "Transformations", ) ######################################################################################################################################################### # Display the transformed DataFrame and summary ######################################################################################################################################################### # Display the transformed DataFrame in the Streamlit app st.markdown("### Transformed DataFrame") with st.spinner("Please wait while the transformed DataFrame is loading ..."): final_df = st.session_state["project_dct"]["transformations"]["final_df"].copy() # Clean display DataFrame display_df, exceeds_max_col = clean_display_df(final_df, display_max_col) # Check the number of columns and show only the first display_max_col if there are more if exceeds_max_col: # Display a info if the DataFrame has more than display_max_col columns st.info( f"The transformed DataFrame has more than {display_max_col} columns. Displaying only the first {display_max_col} columns.", icon="💬", ) # Display Final DataFrame st.dataframe( display_df, hide_index=True, column_config={ "date": st.column_config.DateColumn("date", format="YYYY-MM-DD") }, ) # Total rows and columns total_rows, total_columns = st.session_state["project_dct"]["transformations"][ "final_df" ].shape st.markdown( f"

The transformed DataFrame contains {total_rows} rows and {total_columns} columns.

", unsafe_allow_html=True, ) # Display the summary of transformations as markdown if ( "summary_string" in st.session_state["project_dct"]["transformations"] and st.session_state["project_dct"]["transformations"]["summary_string"] ): with st.expander("Summary of Transformations"): st.markdown("### Summary of Transformations") st.markdown( st.session_state["project_dct"]["transformations"][ "summary_string" ], unsafe_allow_html=True, ) ######################################################################################################################################################### # Correlation Plot ######################################################################################################################################################### # Filter out the 'date' column variables = [ col for col in final_df.columns if col.lower() not in ["date", "panel"] ] with st.expander("Transformed Variable Correlation Plot"): selected_vars = st.multiselect( label="Choose variables for correlation plot:", options=variables, max_selections=30, default=st.session_state["project_dct"]["transformations"][ "correlation_plot_selection" ], key="correlation_plot_key", ) # Calculate correlation if selected_vars: corr_df = final_df[selected_vars].corr() # Prepare text annotations with 2 decimal places annotations = [] for i in range(len(corr_df)): for j in range(len(corr_df.columns)): annotations.append( go.layout.Annotation( text=f"{corr_df.iloc[i, j]:.2f}", x=corr_df.columns[j], y=corr_df.index[i], showarrow=False, font=dict(color="black"), ) ) # Plotly correlation plot using go heatmap = go.Heatmap( z=corr_df.values, x=corr_df.columns, y=corr_df.index, colorscale="RdBu", zmin=-1, zmax=1, ) layout = go.Layout( title="Transformed Variable Correlation Plot", xaxis=dict(title="Variables"), yaxis=dict(title="Variables"), width=1000, height=1000, annotations=annotations, ) fig = go.Figure(data=[heatmap], layout=layout) st.plotly_chart(fig) else: st.write("Please select at least one variable to plot.") ######################################################################################################################################################### # Accept and Save ######################################################################################################################################################### # Check for saved model if ( retrieve_pkl_object( st.session_state["project_number"], "Model_Build", "best_models", schema ) is not None ): # db st.warning( "Saving transformations will overwrite existing ones and delete all saved models. To keep previous models, please start a new project.", icon="⚠️", ) if st.button("Accept and Save", use_container_width=True): with st.spinner("Saving Changes"): # Update correlation plot selection st.session_state["project_dct"]["transformations"][ "correlation_plot_selection" ] = st.session_state["correlation_plot_key"] # Clear model metadata clear_pages() # Update DB update_db( prj_id=st.session_state["project_number"], page_nam="Transformations", file_nam="project_dct", pkl_obj=pickle.dumps(st.session_state["project_dct"]), schema=schema, ) # Clear data from DB delete_entries( st.session_state["project_number"], ["Model_Build", "Model_Tuning"], db_cred, schema, ) # Success message st.success("Saved Successfully!", icon="💾") st.toast("Saved Successfully!", icon="💾") # Log message log_message("info", "Saved Successfully!", "Transformations") except Exception as e: # Capture the error details exc_type, exc_value, exc_traceback = sys.exc_info() error_message = "".join( traceback.format_exception(exc_type, exc_value, exc_traceback) ) # Log message log_message("error", f"An error occurred: {error_message}.", "Transformations") # Display a warning message st.warning( "Oops! Something went wrong. Please try refreshing the tool or creating a new project.", icon="⚠️", )